## Imports

In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy import sparse
import os
import copy
from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
from argoverse.map_representation.map_api import ArgoverseMap
from skimage.transform import rotate

In [2]:
import numpy as np
import os
import sys
from fractions import gcd
from numbers import Number

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from numpy import float64, ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

## Load from Data storage path into data_argo

In [3]:
data_path = os.path.join('../LaneGCN/', "dataset","preprocess", "val_crs_dist6_angle90.p")
data_argo = np.load(data_path, allow_pickle=True)

In [4]:
## CHOOSE SCENE!
idx = 3
for keys,_ in data_argo[idx].items():
    print(keys)
    
data = data_argo[idx]

idx
city
feats
ctrs
orig
theta
rot
gt_preds
has_preds
graph


In [5]:
# CONFIG
config = dict()
"""Train"""
config["display_iters"] = 205942
config["val_iters"] = 205942 * 2
config["save_freq"] = 1.0
config["epoch"] = 0
config["horovod"] = True
config["opt"] = "adam"
config["num_epochs"] = 36
config["lr"] = [1e-3, 1e-4]
config["lr_epochs"] = [32]
#config["lr_func"] = StepLR(config["lr"], config["lr_epochs"])

config["batch_size"] = 32
config["val_batch_size"] = 32
config["workers"] = 0
config["val_workers"] = config["workers"]

"""Model"""
config["rot_aug"] = False
config["pred_range"] = [-100.0, 100.0, -100.0, 100.0]
config["num_scales"] = 6
config["n_actor"] = 128
config["n_map"] = 128
config["actor2map_dist"] = 7.0
config["map2actor_dist"] = 6.0
config["actor2actor_dist"] = 100.0
config["pred_size"] = 30
config["pred_step"] = 1
config["num_preds"] = config["pred_size"] // config["pred_step"]
config["num_mods"] = 6
config["cls_coef"] = 1.0
config["reg_coef"] = 1.0
config["mgn"] = 0.2
config["cls_th"] = 2.0
config["cls_ignore"] = 0.2

## Layers

In [6]:
class Conv1d(nn.Module):
    def __init__(self, n_in, n_out, kernel_size=3, stride=1, norm='GN', ng=32, act=True):
        super(Conv1d, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])

        self.conv = nn.Conv1d(n_in, n_out, kernel_size=kernel_size, padding=(int(kernel_size) - 1) // 2, stride=stride, bias=False)

        if norm == 'GN':
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        else:
            exit('SyncBN has not been added!')

        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        if self.act:
            out = self.relu(out)
        return out

In [7]:
class Res1d(nn.Module):
    def __init__(self, n_in, n_out, kernel_size=3, stride=1, norm='GN', ng=32, act=True):
        super(Res1d, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])
        padding = (int(kernel_size) - 1) // 2
        self.conv1 = nn.Conv1d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv2 = nn.Conv1d(n_out, n_out, kernel_size=kernel_size, padding=padding, bias=False)
        self.relu = nn.ReLU(inplace = True)

        # All use name bn1 and bn2 to load imagenet pretrained weights
        if norm == 'GN':
            self.bn1 = nn.GroupNorm(gcd(ng, n_out), n_out)
            self.bn2 = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.bn1 = nn.BatchNorm1d(n_out)
            self.bn2 = nn.BatchNorm1d(n_out)
        else:
            exit('SyncBN has not been added!')

        if stride != 1 or n_out != n_in:
            if norm == 'GN':
                self.downsample = nn.Sequential(
                        nn.Conv1d(n_in, n_out, kernel_size=1, stride=stride, bias=False),
                        nn.GroupNorm(gcd(ng, n_out), n_out))
            elif norm == 'BN':
                self.downsample = nn.Sequential(
                        nn.Conv1d(n_in, n_out, kernel_size=1, stride=stride, bias=False),
                        nn.BatchNorm1d(n_out))
            else:
                exit('SyncBN has not been added!')
        else:
            self.downsample = None

        self.act = act

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            x = self.downsample(x)

        out += x
        if self.act:
            out = self.relu(out)
        return out

In [8]:
class Linear(nn.Module):
    def __init__(self, n_in, n_out, norm='GN', ng=32, act=True):
        super(Linear, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])

        self.linear = nn.Linear(n_in, n_out, bias=False)
        
        if norm == 'GN':
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        else:
            exit('SyncBN has not been added!')
        
        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        out = self.linear(x)
        out = self.norm(out)
        if self.act:
            out = self.relu(out)
        return out

## Actor-Net

In [9]:
# actor-gather function
# Basic function: convert actor shape from [num, 20, 3] ->[num, 3, 20];; Also generate unique idx for each actor across batches
def actor_gather(actors):
    batch_size = len(actors)
    num_actors = [len(x) for x in actors]
    actors = [torch.tensor(x).transpose(1, 2) for x in actors] #(num_actors, 3, 20)
    actors = torch.cat(actors, 0)
    
    actor_idcs = []
    count = 0
    
    for i in range(batch_size):
        #---(how to take care of the batch-dimension)
        idcs = torch.arange(count, count + num_actors[i]) # say [0,1,2,...10] for batch-1; [11,12,13,...21] for batch-2
        actor_idcs.append(idcs) # give arange-ID to actors across batches
        count += num_actors[i]
    return actors, actor_idcs

# input
actors = [data['feats']]    
actors, actor_idcs = actor_gather(actors)
print("Num actors: ", len(actors), ";\tActors shape: ", actors[0].shape)

Num actors:  24 ;	Actors shape:  torch.Size([3, 20])


In [10]:
# define actor-net: essentially a bunch of 1D-Convs with FPN
# basic components: Res1D, blocks, groups, lateral, interpolates, outputs
# 1 group/1blocks = many Res1Ds; 1 groups = many group; laterals = no of side branch outputs

class ActorNet(nn.Module):
    """
    Actor feature extractor with Conv1D
    """
    def __init__(self, config):
        super(ActorNet, self).__init__()
        self.config = config
        norm = "GN"
        ng = 1
        
        n_in = 3
        n_out = [32, 64, 128]
        blocks = [Res1d, Res1d, Res1d]
        num_blocks = [2, 2, 2]
        
        groups = []
        # first loop is across no of blocks (here 3!)
        for i in range(len(num_blocks)):
            group = []
            if i == 0:
                group.append(blocks[i](n_in, n_out[i], norm=norm, ng=ng))
            else:
                group.append(blocks[i](n_in, n_out[i], stride=2, norm=norm, ng=ng)) # have a lever to set the stride
            
            for j in range(1, num_blocks[i]):
                group.append(blocks[i](n_out[i], n_out[i], norm=norm, ng=ng))
            
            groups.append(nn.Sequential(*group))
            
            n_in = n_out[i]    
        
        self.groups = nn.ModuleList(groups) # moduleLIST: groups
            
        # lateral
        n = config["n_actor"]
        lateral = []
        for i in range(len(n_out)):
            lateral.append(Conv1d(n_out[i], n, norm=norm, ng=ng, act=False)) # convert every middle representation across len(n_out) layers to 128 dim 
        self.lateral = nn.ModuleList(lateral) # moduleLIST: lateral
        
        # output: just a transformation
        self.output = Res1d(n, n, norm=norm, ng=ng)
        
        
    def forward(self, actors):
        out = actors
        
        outputs = []
        ## do convs with groups
        for i in range(len(self.groups)):
            out = self.groups[i](out)
            outputs.append(out)
            
        ## interpolate and lateral add
        out = self.lateral[-1](outputs[-1])
        for i in range(len(outputs) - 2, -1, -1):
            out = F.interpolate(out, scale_factor=2, mode="linear", align_corners=False)
            out += self.lateral[i](outputs[i])
        
        ## output
        out = self.output(out)[:, :, -1] # take only the last time-step's representation
        return out

In [11]:
actor_net = ActorNet(config)
actors = actor_net(actors)

  self.bn1 = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.bn2 = nn.GroupNorm(gcd(ng, n_out), n_out)
  nn.GroupNorm(gcd(ng, n_out), n_out))
  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)


In [12]:
## Input: actor features [num_actors per scene*batch, 3, 20]
## Output: actor features at the last time-step [num_actors per scene*batch, 128]
print("Shape of actors after actor-net: ", actors.shape)

Shape of actors after actor-net:  torch.Size([24, 128])


## Map-Gather (Organize Map Data for Map-Net)

In [13]:
# test on 2 graphs
graphs = [data_argo[idx]['graph'] #]
          , data_argo[idx+1]['graph']]

In [14]:
def graph_gather(graphs):
    batch_size = len(graphs)
    node_idcs = []
    count = 0
    counts = []
    for i in range(batch_size):
        counts.append(count)
        idcs = torch.arange(count, count + graphs[i]["num_nodes"])
        node_idcs.append(idcs)
        count = count + graphs[i]["num_nodes"]
    
    graph = dict()
    graph["idcs"] = node_idcs
    
    #-----------Features and Attributes Order for Lane-Conv (how to take care of the batch-dimension)
    graph["ctrs"] = [torch.tensor(x["ctrs"]) for x in graphs]
    print("Shape of graph[ctrs]: G1 ", graph['ctrs'][0].shape, "G2: ", graph['ctrs'][1].shape)
    
    for key in ["feats", "turn", "control", "intersect"]:
        graph[key] = torch.cat([torch.tensor(x[key]) for x in graphs], 0)
        print("Shape of ", key, ':\t', graph[key].shape)
        
    #-----------Edges Order for Lane-Conv (how to take care of the batch-dimension)
    #------k1: pre or suc or left or right keys; i: 6(because 6 scales/dilations of lane-conv); k2:u/v keys
    for k1 in ["pre", "suc"]: # go through graph keys predecessor and successor
        graph[k1] = []
        for i in range(len(graphs[0]["pre"])):
            graph[k1].append(dict())
            for k2 in ['u', 'v']:
                 graph[k1][i][k2] = torch.cat([torch.LongTensor(graphs[j][k1][i][k2]) + counts[j] 
                                               for j in range(batch_size)], 0) 
                    # add proper index counters so each node is unique and concatenate
                    
    print("At scale 6, what does the predecessor u-v's look like: ?", "u: ", graph['pre'][5]['u'].shape, '\tv:', 
          graph['pre'][5]['v'].shape)
    
    #------For left and right, there is no dilation! SO loop through only k1 and k2
    for k1 in ["left", "right"]:
        graph[k1] = dict()
        for k2 in ["u", "v"]:
            temp = [torch.LongTensor(graphs[i][k1][k2]) + counts[i] for i in range(batch_size)]
            temp = [
                x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
                for x in temp
            ]
            graph[k1][k2] = torch.cat(temp)
            
            
    print("Shape of individual left-right graphs?", len(graphs[0][k1][k2]), ';\t', len(graphs[1][k1][k2]))
    print("What does left-right look like ?", "u: ", graph['left']['u'].shape, '\tv:', 
          graph['left']['v'].shape)
    
    return graph

In [15]:
graph = graph_gather(graphs)

Shape of graph[ctrs]: G1  torch.Size([1017, 2]) G2:  torch.Size([1476, 2])
Shape of  feats :	 torch.Size([2493, 2])
Shape of  turn :	 torch.Size([2493, 2])
Shape of  control :	 torch.Size([2493])
Shape of  intersect :	 torch.Size([2493])
At scale 6, what does the predecessor u-v's look like: ? u:  torch.Size([3137]) 	v: torch.Size([3137])
Shape of individual left-right graphs? 558 ;	 270
What does left-right look like ? u:  torch.Size([828]) 	v: torch.Size([828])


## Map-Net

In [16]:
# check which layer is for what!
# self.input (Linear/RELU/Linear), self.seg (Linear/RELU/Linear), self.fuse(4 blocks of lane-conv layers)
n_map = config["n_map"]
norm = "GN"
ng = 1

keys = ["ctr", "norm", "ctr2", "left", "right"]
for i in range(config["num_scales"]):
    keys.append("pre" + str(i))
    keys.append("suc" + str(i))
print("6 scales so Keys are: ", keys)

fuse = dict()
for key in keys:
    fuse[key] = []
print("\nInit fuse dict: ", fuse)

for i in range(4):
    for key in fuse:
        if key in ["norm"]:
                    fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
        elif key in ["ctr2"]:
            fuse[key].append(Linear(n_map, n_map, norm=norm, ng=ng, act=False))
        else:
            fuse[key].append(nn.Linear(n_map, n_map, bias=False))

print("\nThe fuse module is: \n",fuse)

6 scales so Keys are:  ['ctr', 'norm', 'ctr2', 'left', 'right', 'pre0', 'suc0', 'pre1', 'suc1', 'pre2', 'suc2', 'pre3', 'suc3', 'pre4', 'suc4', 'pre5', 'suc5']

Init fuse dict:  {'ctr': [], 'norm': [], 'ctr2': [], 'left': [], 'right': [], 'pre0': [], 'suc0': [], 'pre1': [], 'suc1': [], 'pre2': [], 'suc2': [], 'pre3': [], 'suc3': [], 'pre4': [], 'suc4': [], 'pre5': [], 'suc5': []}

The fuse module is: 
 {'ctr': [Linear(in_features=128, out_features=128, bias=False), Linear(in_features=128, out_features=128, bias=False), Linear(in_features=128, out_features=128, bias=False), Linear(in_features=128, out_features=128, bias=False)], 'norm': [GroupNorm(1, 128, eps=1e-05, affine=True), GroupNorm(1, 128, eps=1e-05, affine=True), GroupNorm(1, 128, eps=1e-05, affine=True), GroupNorm(1, 128, eps=1e-05, affine=True)], 'ctr2': [Linear(
  (linear): Linear(in_features=128, out_features=128, bias=False)
  (norm): GroupNorm(1, 128, eps=1e-05, affine=True)
  (relu): ReLU(inplace=True)
), Linear(
  (line

  fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)


In [17]:
# Implement eqn 1 of LaneGCN: output should be [[number of nodes x 128]]
inputnn = nn.Sequential(
            nn.Linear(2, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )
seg = nn.Sequential(
            nn.Linear(2, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )
    
ctrs = torch.cat(graph["ctrs"], 0)
feat = inputnn(ctrs)
feat += seg(graph["feats"])
feat = nn.ReLU()(feat)

print(feat.shape)

## Next implement Eqn-3 without the XW0
"""fuse map"""
res = feat
for i in range(len(fuse["ctr"])): # essentially looping 4 times
    temp = fuse["ctr"][i](feat) # first linear
    
    if i==0:
        k1 = 'pre'
        k2 = 0
        print("Print an example of graph: ", graph['pre'][0]['u'])
    
        index = torch.tensor(graph[k1][k2]["u"].clone(), dtype=torch.long)
        print(index)
    
    for key in fuse:
        if key.startswith("pre") or key.startswith("suc"):
            k1 = key[:3]
            k2 = int(key[3:])
            temp.index_add_(
                0,
                graph[k1][k2]["u"],
                fuse[key][i](feat[graph[k1][k2]["v"]]),
            )


torch.Size([2493, 128])
Print an example of graph:  tensor([   1,    2,    3,  ..., 2490, 2491, 2492])
tensor([   1,    2,    3,  ..., 2490, 2491, 2492])


  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
  index = torch.tensor(graph[k1][k2]["u"].clone(), dtype=torch.long)


In [18]:
class MapNet(nn.Module):
    """
    Map Graph feature extractor with LaneGraphCNN
    """
    def __init__(self, config):
        super(MapNet, self).__init__()
        self.config = config
        n_map = config["n_map"]
        norm = "GN"
        ng = 1

        self.input = nn.Sequential(
            nn.Linear(2, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )
        self.seg = nn.Sequential(
            nn.Linear(2, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )

        keys = ["ctr", "norm", "ctr2", "left", "right"]
        for i in range(config["num_scales"]):
            keys.append("pre" + str(i))
            keys.append("suc" + str(i))

        fuse = dict()
        for key in keys:
            fuse[key] = []

        for i in range(4):
            for key in fuse:
                if key in ["norm"]:
                    fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
                elif key in ["ctr2"]:
                    fuse[key].append(Linear(n_map, n_map, norm=norm, ng=ng, act=False))
                else:
                    fuse[key].append(nn.Linear(n_map, n_map, bias=False))

        for key in fuse:
            fuse[key] = nn.ModuleList(fuse[key])
        self.fuse = nn.ModuleDict(fuse)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, graph):
        if (
            len(graph["feats"]) == 0
            or len(graph["pre"][-1]["u"]) == 0
            or len(graph["suc"][-1]["u"]) == 0
        ):
            temp = graph["feats"]
            return (
                temp.new().resize_(0),
                [temp.new().long().resize_(0) for x in graph["node_idcs"]],
                temp.new().resize_(0),
            )

        ctrs = torch.cat(graph["ctrs"], 0)
        feat = self.input(ctrs)
        feat += self.seg(graph["feats"])
        feat = self.relu(feat)

        """fuse map"""
        res = feat
        for i in range(len(self.fuse["ctr"])):
            temp = self.fuse["ctr"][i](feat)
            for key in self.fuse:
                # Eqn-3, third component
                if key.startswith("pre") or key.startswith("suc"):
                    k1 = key[:3]
                    k2 = int(key[3:])
                    temp.index_add_(
                        0,
                        graph[k1][k2]["u"],
                        self.fuse[key][i](feat[graph[k1][k2]["v"]]),
                    )
            # Eqn-3, second component
            if len(graph["left"]["u"] > 0):
                temp.index_add_(
                    0,
                    graph["left"]["u"],
                    self.fuse["left"][i](feat[graph["left"]["v"]]),
                )
            if len(graph["right"]["u"] > 0):
                temp.index_add_(
                    0,
                    graph["right"]["u"],
                    self.fuse["right"][i](feat[graph["right"]["v"]]),
                )

            feat = self.fuse["norm"][i](temp)
            feat = self.relu(feat)

            feat = self.fuse["ctr2"][i](feat)
            # Eqn-3, first component
            feat += res
            feat = self.relu(feat)
            res = feat
        return feat, graph["idcs"], graph["ctrs"]

In [19]:
map_net = MapNet(config)
nodes, node_idcs, node_ctrs = map_net(graph)

  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
  fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))


In [20]:
print("What is nodes, node_idcs, node_ctrs?")
print("\nFeat: ", nodes.shape)
print(node_idcs)
print("\nNode_ctrs: ", node_ctrs[0].shape, node_ctrs[1].shape)

What is nodes, node_idcs, node_ctrs?

Feat:  torch.Size([2493, 128])
[tensor([   0,    1,    2,  ..., 1014, 1015, 1016]), tensor([1017, 1018, 1019,  ..., 2490, 2491, 2492])]

Node_ctrs:  torch.Size([1017, 2]) torch.Size([1476, 2])
