## Imports

In [1]:
import os
import sys
from numbers import Number
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

# fractions.gcd was removed in Python 3.9; math.gcd (Python >= 3.5) is the
# replacement. Keep the fractions import only as a fallback for old interpreters.
try:
    from math import gcd
except ImportError:
    from fractions import gcd

import numpy as np
import torch
from numpy import float64, ndarray
from torch import Tensor, nn
from torch.nn import functional as F

## Load the preprocessed data file into `data_argo`

In [2]:
# Preprocessed LaneGCN file (the "val" prefix suggests the validation split).
data_path = os.path.join(
    '../LaneGCN/',
    "dataset",
    "preprocess",
    "val_crs_dist6_angle90.p",
)
# SECURITY NOTE: allow_pickle=True unpickles arbitrary objects, which can run
# arbitrary code -- only load files from a trusted source.
data_argo = np.load(data_path, allow_pickle=True)

In [3]:
## CHOOSE SCENE!
idx = 3
data = data_argo[idx]
# Show which fields are stored for the selected scene.
for key in data.keys():
    print(key)

idx
city
feats
ctrs
orig
theta
rot
gt_preds
has_preds
graph


In [4]:
# CONFIG
config = {}

# --- Train ---
# 205942 is used as the display/validation iteration interval; presumably the
# number of training samples (one epoch) -- TODO confirm.
config.update(
    display_iters=205942,
    val_iters=205942 * 2,
    save_freq=1.0,
    epoch=0,
    horovod=True,
    opt="adam",
    num_epochs=36,
    lr=[1e-3, 1e-4],
    lr_epochs=[32],
)
# NOTE: "lr_func" (a StepLR schedule over lr / lr_epochs) is left unset here;
# StepLR is not defined or imported in this notebook.
config.update(
    batch_size=32,
    val_batch_size=32,
    workers=0,
)
config["val_workers"] = config["workers"]

# --- Model ---
config.update(
    rot_aug=False,
    pred_range=[-100.0, 100.0, -100.0, 100.0],
    num_scales=6,
    n_actor=128,
    n_map=128,
    actor2map_dist=7.0,
    map2actor_dist=6.0,
    actor2actor_dist=100.0,
    pred_size=30,
    pred_step=1,
)
# Number of predicted steps is derived from the two values just above.
config["num_preds"] = config["pred_size"] // config["pred_step"]
config.update(
    num_mods=6,
    cls_coef=1.0,
    reg_coef=1.0,
    mgn=0.2,
    cls_th=2.0,
    cls_ignore=0.2,
)

## Utils

In [5]:
class Linear(nn.Module):
    """Bias-free linear layer followed by a normalization and an optional ReLU.

    Args:
        n_in: input feature size.
        n_out: output feature size.
        norm: 'GN' (GroupNorm) or 'BN' (BatchNorm1d). 'SyncBN' passes the
            argument check but is not implemented and raises
            NotImplementedError.
        ng: requested number of GroupNorm groups; gcd(ng, n_out) groups are
            actually used so the group count always divides n_out.
        act: apply ReLU after the normalization when True.
    """
    def __init__(self, n_in, n_out, norm='GN', ng=32, act=True):
        super(Linear, self).__init__()
        assert norm in ('GN', 'BN', 'SyncBN'), 'unsupported norm: %r' % (norm,)

        self.linear = nn.Linear(n_in, n_out, bias=False)

        if norm == 'GN':
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        else:
            # Raise rather than exit(): exit() would terminate the interpreter
            # (or the notebook kernel) instead of reporting an error.
            raise NotImplementedError('SyncBN has not been added!')

        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        """Apply linear -> norm -> (optional) ReLU to ``x``."""
        out = self.linear(x)
        out = self.norm(out)
        if self.act:
            out = self.relu(out)
        return out

## Attention layer to be used for 3 kinds of fusion (A2A, L2A, A2L)

In [6]:
# Toy example of the pairing logic used by Att: N_SCENES scenes, each with
# N_AGT actors and N_CTX map nodes at random 2-D centers. The counts are named
# once so the tensor shapes and the per-scene index offsets cannot drift apart.
N_SCENES, N_AGT, N_CTX = 3, 5, 100
DEMO_DIST_TH = 1.0
# Private generator: reproducible output without reseeding the global RNG.
gen = torch.Generator().manual_seed(0)

agt_ctrs = torch.randn([N_SCENES, N_AGT, 2], generator=gen)
ctx_ctrs = torch.randn([N_SCENES, N_CTX, 2], generator=gen)

hi, wi = [], []
hi_count, wi_count = 0, 0

for i in range(N_SCENES):
    # Displacement between every actor and every map node of scene i.
    dist = agt_ctrs[i].view(-1, 1, 2) - ctx_ctrs[i].view(1, -1, 2)

    if i == 0:
        print("Shape of dist vector: ", dist.shape)
    dist = torch.sqrt((dist ** 2).sum(2))
    mask = dist <= DEMO_DIST_TH
    idcs = torch.nonzero(mask, as_tuple=False)
    if i == 0:
        print("Shape of idcs: ", idcs.shape)
    # hi: actor index, wi: map-node index; offsets make them batch-global.
    hi.append(idcs[:, 0] + hi_count)
    wi.append(idcs[:, 1] + wi_count)
    hi_count += N_AGT
    wi_count += N_CTX

print("Agent idxs: ", hi[0])
print("Map idxs: ", wi[0])
hi = torch.cat(hi, 0)
wi = torch.cat(wi, 0)
agt_ctrs = agt_ctrs.view(-1, 2)
ctx_ctrs = ctx_ctrs.view(-1, 2)

dist = agt_ctrs[hi] - ctx_ctrs[wi]
print("Dist shape: ", dist.shape)

agt = torch.randn([N_SCENES * N_AGT, 12], generator=gen)
ctx = torch.randn([N_SCENES * N_CTX, 128], generator=gen)

print("\nVector be concat shapes: ", dist.shape, agt[hi].shape, ctx[wi].shape)

Shape of dist vector:  torch.Size([5, 100, 2])
Shape of idcs:  torch.Size([97, 2])
Agent idxs:  tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4])
Map idxs:  tensor([29, 38, 49, 50, 59, 72, 74, 75, 79, 85, 88,  7, 10, 16, 18, 66,  2, 11,
        27, 32, 51, 58, 60, 68, 70, 73, 76, 77, 78, 81, 82, 84, 92,  5,  8, 14,
        29, 38, 49, 50, 52, 59, 63, 72, 74, 75, 79, 83, 85, 88, 93, 97, 98,  1,
         4,  9, 12, 13, 15, 21, 24, 25, 26, 30, 31, 33, 34, 36, 37, 39, 42, 45,
        47, 51, 52, 53, 54, 57, 61, 62, 63, 65, 67, 70, 77, 78, 80, 81, 82, 83,
        84, 89, 90, 91, 92, 95, 97])
Dist shape:  torch.Size([317, 2])

Vector be concat shapes:  torch.Size([317, 2]) torch.Size([317, 12]) torch.Size([317, 128])

## Summary

Say we have actor x_i. The map nodes are denoted by y_j.

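First compute the displacement $d_{ij}$ between the centers of every actor $i$ and every map node $j$ in the same scene, and keep only the pairs whose distance $\lVert d_{ij} \rVert$ is at most the threshold.

For each kept pair, concatenate the (encoded) displacement, the actor feature and the map feature into $[d_{ij}, x_i, y_j]$.

Transform each concatenated vector and sum the results over all $j$ paired with a particular $x_i$. Adding this sum to (a linear transform of) $x_i$ gives the new vector for $x_i$.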

In [7]:
class Att(nn.Module):
    """
    Attention block to pass context nodes information to target nodes.
    This is used in Actor2Map, Actor2Actor, Map2Actor.

    For every scene in the batch, each target node is paired with every
    context node of the same scene whose center lies within ``dist_th``.
    For each pair, the encoded displacement, a projection of the target
    feature and the context feature are concatenated, transformed, and summed
    onto a linear transform of the target feature. The result is normalized,
    passed through a final linear layer and added back to the input as a
    residual.
    """
    def __init__(self, n_agt, n_ctx):
        """
        Args:
            n_agt: feature dimension of the target nodes.
            n_ctx: feature dimension of the context (source) nodes.
        """
        super(Att, self).__init__()
        norm = "GN"
        ng = 1

        # Encode the 2-D displacement between a target and a context node
        # into an n_ctx-dim vector so it can be concatenated with features.
        # (Previously referenced the undefined name `n_g` -> NameError.)
        self.dist = nn.Sequential(
            nn.Linear(2, n_ctx),
            nn.ReLU(inplace=True),
            Linear(n_ctx, n_ctx, norm=norm, ng=ng),
        )

        # Project target features to n_ctx so the three concatenated parts
        # (dist, query, ctx) share the same width.
        self.query = Linear(n_agt, n_ctx, norm=norm, ng=ng)

        # Map the concatenated [dist, query, ctx] tensor back to target dim.
        self.ctx = nn.Sequential(
            Linear(3 * n_ctx, n_agt, norm=norm, ng=ng),
            nn.Linear(n_agt, n_agt, bias=False),
        )

        # Linear transform on the targets; context messages are summed onto it.
        self.agt = nn.Linear(n_agt, n_agt, bias=False)

        self.norm = nn.GroupNorm(gcd(ng, n_agt), n_agt)
        self.relu = nn.ReLU(inplace=True)

        # Final transform whose output is added to the residual.
        self.linear = Linear(n_agt, n_agt, norm=norm, ng=ng, act=False)

    @staticmethod
    def _pair_indices(agt_idcs, agt_ctrs, ctx_idcs, ctx_ctrs, dist_th):
        """Return batch-global (target, context) row indices of pairs within dist_th."""
        hi, wi = [], []
        hi_count, wi_count = 0, 0

        # go over every scene of the batch separately
        for i in range(len(agt_idcs)):
            dist = agt_ctrs[i].view(-1, 1, 2) - ctx_ctrs[i].view(1, -1, 2)
            dist = torch.sqrt((dist ** 2).sum(2))
            idcs = torch.nonzero(dist <= dist_th, as_tuple=False)

            hi.append(idcs[:, 0] + hi_count)
            wi.append(idcs[:, 1] + wi_count)
            # Advance the offsets even when scene i has no pairs; skipping
            # this would make every later scene's indices point into the
            # rows of earlier scenes.
            hi_count += len(agt_idcs[i])
            wi_count += len(ctx_idcs[i])

        if not hi:
            empty = torch.zeros(0, dtype=torch.long)
            return empty, empty
        return torch.cat(hi, 0), torch.cat(wi, 0)

    def forward(self,
                agts, agt_idcs, agt_ctrs,
                ctx, ctx_idcs, ctx_ctrs,
                dist_th):
        """
        Args:
            agts: target features; rows of all scenes are assumed to be
                concatenated in scene order (implied by the index offsets
                computed from ``agt_idcs`` -- confirm against the caller).
            agt_idcs: per-scene sequence; only ``len(agt_idcs[i])`` is used,
                as the number of target rows belonging to scene i.
            agt_ctrs: per-scene list of target centers (reshaped to (-1, 2)).
            ctx: context features, concatenated across scenes like ``agts``.
            ctx_idcs: per-scene sequence; ``len(ctx_idcs[i])`` is the number
                of context rows belonging to scene i.
            ctx_ctrs: per-scene list of context centers (reshaped to (-1, 2)).
            dist_th: target/context pairs farther apart than this are ignored.

        Returns:
            Updated target features with the same shape as ``agts``.
        """
        res = agts

        if len(ctx) == 0:
            # No context at all: transform the targets and add the residual.
            agts = self.agt(agts)
            agts = self.relu(agts)
            agts = self.linear(agts)
            agts += res
            agts = self.relu(agts)
            return agts

        # src: wi, target: hi
        hi, wi = self._pair_indices(agt_idcs, agt_ctrs, ctx_idcs, ctx_ctrs, dist_th)

        agts = self.agt(res)
        # With no pair in any scene, nothing is aggregated (previously
        # torch.cat on an empty list raised here).
        if hi.numel() > 0:
            # distance encoding between src and target less than a threshold
            all_agt_ctrs = torch.cat(agt_ctrs, 0)
            all_ctx_ctrs = torch.cat(ctx_ctrs, 0)
            dist = self.dist(all_agt_ctrs[hi] - all_ctx_ctrs[wi])

            # linear transform for the paired targets (input features)
            query = self.query(res[hi])

            # concat and transform back to the target dimension
            msg = self.ctx(torch.cat((dist, query, ctx[wi]), 1))

            # sum every target's messages onto its row
            agts.index_add_(0, hi, msg)

        # normalize, activate
        agts = self.norm(agts)
        agts = self.relu(agts)

        agts = self.linear(agts)

        agts += res
        agts = self.relu(agts)
        return agts