In [53]:
import numpy as np
import os
import sys
from fractions import gcd
from numbers import Number

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from data import ArgoDataset, collate_fn
from utils import gpu, to_long,  Optimizer, StepLR

from layers import Conv1d, Res1d, Linear, LinearRes, Null
from numpy import float64, ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from lanegcn import PredNet, get_model
import torch
from torch.utils.data import Sampler, DataLoader

import matplotlib.pyplot as plt

config, Dataset, collate_fn, net, loss, post_process, opt = get_model()
import os

import argparse
import numpy as np
import random
import sys
import time
import shutil
from importlib import import_module
from numbers import Number

import torch
from torch.utils.data import Sampler, DataLoader


from utils import Logger, load_pretrain
def worker_init_fn(pid):
    np_seed = int(pid)
    np.random.seed(np_seed)
    random_seed = np.random.randint(2 ** 32 - 1)
    random.seed(random_seed)


dataset = Dataset(config["train_split"], config, train=True)
train_loader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    num_workers=config["workers"],
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
    drop_last=True,
)


for i, data in enumerate(train_loader):
    break


In [54]:
data.keys()

dict_keys(['city', 'orig', 'gt_preds', 'has_preds', 'theta', 'rot', 'feats', 'ctrs', 'graph', 'trajs2', 'traj1'])

In [55]:
from lanegcn import ActorNet, MapNet, actor_gather, graph_gather
actor_net = ActorNet(config)
map_net = MapNet(config)

# construct actor feature
actors, actor_idcs = actor_gather(gpu(data["feats"]))
actor_ctrs = gpu(data["ctrs"])
actors = actor_net(actors)

# construct map features
graph = graph_gather(to_long(gpu(data["graph"])))
nodes, node_idcs, node_ctrs = map_net(graph)

In [56]:
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, d_k, attn_dropout=0.1):
        super().__init__()
        self.d_k = d_k
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        attn = torch.matmul(q, k.transpose(2, 3)) / self.d_k

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn

In [57]:

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_x, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_x, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_x, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_x, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_x, bias=False)

        self.attention = ScaledDotProductAttention(d_k=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_x, eps=1e-6)

    def forward(self, x, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        batch, len_x = x.size(0), x.size(1)

        residual = x

        # Pass through the pre-attention projection: b x len_x x (n*d_v)
        # Separate different heads: b x len_x x n x d_v
        q = self.w_qs(x).view(batch, len_x, n_head, d_k)
        k = self.w_ks(x).view(batch, len_x, n_head, d_k)
        v = self.w_vs(x).view(batch, len_x, n_head, d_v)

        # Transpose for attention dot product: b x n x len_x x d_v
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.

        out, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x len_x x n x d_v
        # Combine the last two dimensions to concatenate all the heads together: b x len_x x (n*d_v)
        out = out.transpose(1, 2).contiguous().view(batch, len_x, -1)
        out = self.dropout(self.fc(out))
        out += residual

        out = self.layer_norm(out)

        return out, attn

class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)
        return x


class MultiHeadAttnEncoderLayer(nn.Module):
    def __init__(self, d_x, d_k, d_v, n_head, d_inner, dropout=0.1):
        super(MultiHeadAttnEncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(
            n_head, d_x, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_x, d_inner, dropout=dropout)

    def forward(self, enc_input, self_attn_mask=None):
        enc_output, enc_self_attn = self.self_attn(
            enc_input, mask=self_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_self_attn



In [58]:
class Conv1dAggreBlock(nn.Module):
    """
        Aaggregation block using max-pooling
    """

    def __init__(self, n_feat: int, dropout: float = 0.0) -> None:
        super(Conv1dAggreBlock, self).__init__()
        norm = "GN"
        ng = 1
        self.n_feat = n_feat

        self.conv_1 = Conv1d(n_feat, n_feat, kernel_size=1, norm=norm, ng=ng)
        self.conv_2 = Conv1d(n_feat*2, n_feat, kernel_size=1, norm=norm, ng=ng)

        self.aggre_func = F.adaptive_avg_pool1d

        self.conv_3 = Conv1d(n_feat, n_feat, kernel_size=1, norm=norm, ng=ng, act=False)
        self.relu = nn.ReLU(inplace=True)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, feats):
        '''
            feats: (batch, c, N)
        '''
        res = feats
        feats = self.conv_1(feats)
        feats_mp, _ = feats.max(dim=1)  # global max-pooling

        feats_mp = feats_mp.unsqueeze(1).repeat((1, self.n_feat, 1))
        feats = torch.cat([feats, feats_mp], dim=1)
        feats = self.conv_2(feats)
        feats = self.dropout(feats)

        feats = self.conv_3(feats)
        feats += res
        feats = self.relu(feats)

        return feats

In [59]:
class GoalDecoder(nn.Module):
    def __init__(self, config, n_feat=32, n_pts=200):
        super(GoalDecoder, self).__init__()
        norm = "GN"
        ng = 1

        self.aggre_1 = Conv1dAggreBlock(n_feat=n_feat, dropout=0.1)
        self.conv_1 = Conv1d(n_feat, 8, kernel_size=1, norm=norm, ng=ng)

        self.aggre_2 = Conv1dAggreBlock(n_feat=8, dropout=0.1)
        self.conv_2 = Conv1d(8, 4, kernel_size=1, norm=norm, ng=ng)

        self.conv_3 = Conv1d(4, 1, kernel_size=1, norm=norm, ng=ng, act=False)

        self.dropout = nn.Dropout(p=0.1)

    def forward(self, feat, coord):
        '''
            feat:   (batch, N, n_feat)
            coord:  (batch, N, 2)
        '''
        feat = feat.transpose(1, 2)

        feat = self.aggre_1(feat)
        feat = self.conv_1(feat)
        feat = self.dropout(feat)

        feat = self.aggre_2(feat)
        feat = self.conv_2(feat)
        feat = self.dropout(feat)

        feat = self.conv_3(feat)

        weights = F.softmax(feat, dim=-1).transpose(1, 2)  # weights, (batch, N, 1)
        goal = torch.sum(coord * weights, dim=1)

        return goal.unsqueeze(1), weights  # (batch, 1, 2)


In [60]:
def prob_traj_output(x):
    # e.g., [batch, seq, 1] or # [batch, n_dec, seq, 1]
    muX = x[..., 0:1]  # [..., 1]
    muY = x[..., 1:2]  # [..., 1]
    sigX = x[..., 2:3]  # [..., 1]
    sigY = x[..., 3:4]  # [..., 1]
    rho = x[..., 4:5]  # [..., 1]
    sigX = torch.exp(sigX)
    sigY = torch.exp(sigY)
    rho = torch.tanh(rho)

    out = torch.cat([muX, muY, sigX, sigY, rho], dim=-1)  # [..., 5]
    return out

In [79]:
class GodaClassifierDaOnly(nn.Module):
    def __init__(self, config):
        super(GodaClassifierDaOnly, self).__init__()
        n_map = 128
        norm = "GN"
        ng = 1

        self.fc_1 = Linear(n_map, 16, norm=norm, ng=ng, act=True)
        self.fc_2 = Linear(16, 8, norm=norm, ng=ng, act=True)
        self.fc_3 = nn.Linear(8, 1, bias=False)

        # self.dropout = nn.Dropout(p=config["p_dropout"])

    def forward(self, feat):
        feat = self.fc_1(feat)
        feat = self.fc_2(feat)
        # feat = self.dropout(feat)

        feat = self.fc_3(feat)
        out = torch.sigmoid_(feat).view(-1)
        return out

class GoalGenerator(nn.Module):
    def __init__(self, config, n_blk=2):
        super(GoalGenerator, self).__init__()
        n_mode = 6
        n_feat = 128
        norm = "GN"
        ng = 1
        self.n_blk = n_blk

        self.conv_1 = Conv1d(35, n_feat, kernel_size=1, norm=norm, ng=ng)

        self.aggre = nn.ModuleList([
            MultiHeadAttnEncoderLayer(d_x=n_feat, d_k=n_feat, d_v=n_feat, n_head=2,
                                      d_inner=n_feat, dropout=0.1)
            for _ in range(self.n_blk)])

        self.multihead_decoder = nn.ModuleList([
            GoalDecoder(config=config, n_feat=n_feat, n_pts=200) for _ in range(n_mode)
        ])

    def forward(self, score, coord, goda_feat):
        feat = torch.cat([coord, score.unsqueeze(2), goda_feat], dim=2).transpose(1, 2)  # (batch, 35, N)
        feat = self.conv_1(feat)
        feat = feat.transpose(1, 2)  # (batch, N, n_feat)

        for enc_layer in self.aggre:
            feat, _ = enc_layer(feat, self_attn_mask=None)  # (batch, N, n_feat)

        goals = []
        weights = []
        for decoder in self.multihead_decoder:
            goals_mode, weights_mode = decoder(feat, coord)
            goals.append(goals_mode)
            weights.append(weights_mode)
        goals = torch.cat(goals, dim=1)  # (batch, n_mode, 2)

        return goals

class TrajCompletor(nn.Module):
    def __init__(self, config, prob_output=True):
        super(TrajCompletor, self).__init__()
        self.prob_output = prob_output
        norm = "GN"
        ng = 1

        self.fc_1 = LinearRes(130, 128, norm=norm, ng=ng)
        # self.fc_2 = LinearRes(128, 128, norm=norm, ng=ng)
        self.dropout = nn.Dropout(p=0.1)

        if self.prob_output:
            self.fc_d = nn.Linear(128, 30*5, bias=False)
        else:
            self.fc_d = nn.Linear(128, 30*2, bias=False)

    def forward(self, traj_enc, goal):
        '''
            traj_enc:   (batch, 128)
            goal:       (batch, n_mode, 2)
        '''
        n_batch = goal.shape[0]
        n_mode = goal.shape[1]
        x = torch.cat([traj_enc.unsqueeze(1).repeat((1, n_mode, 1)), goal], dim=2)

        x = x.reshape(-1, 130)
        x = self.fc_1(x)
        # x = self.fc_2(x)
        x = self.dropout(x)

        if self.prob_output:
            traj_pred = self.fc_d(x).reshape(n_batch, n_mode, 30, 5)
            traj_pred = prob_traj_output(traj_pred)
        else:
            traj_pred = self.fc_d(x).reshape(n_batch, n_mode, 30, 2)

        return traj_pred

In [80]:
# ~ decoders
goda_classifier = GodaClassifierDaOnly(config)
traj_completor = TrajCompletor(config)

# ~ final goal generation
goal_generator = GoalGenerator(config, n_blk=2)


Decoding

In [81]:
from torch_scatter import scatter_max, scatter_mean, scatter_add
device = torch.device('cpu')
cfg = config
training = True

def decode_goal_and_traj(feat_da, graph_da, actors, actor_idcs, goal_gt):
    batch_size = len(graph_da['ctrs'])
    agent_idcs = torch.LongTensor([idcs[0] for idcs in actor_idcs])
    agent_feat = actors[agent_idcs]

    goda_cls = goda_classifier(feat_da)

    DEFAULT_GODA_NUM = 200
    goda_score = []
    goda_coord = []
    goda_feat = []
    for i in range(batch_size):
        scores = goda_cls[graph_da['idcs'][i]]
        _, idcs = torch.sort(scores, descending=True)
        assert len(idcs) > DEFAULT_GODA_NUM, 'Invalid goda number'

        idcs = idcs[:DEFAULT_GODA_NUM]
        goda_score.append(scores[idcs])
        goda_coord.append(graph_da['ctrs'][i][idcs])
        goda_feat.append(feat_da[graph_da['idcs'][i]][idcs])

    goda_score = torch.stack(goda_score, dim=0)
    goda_coord = torch.stack(goda_coord, dim=0)
    goda_feat = torch.stack(goda_feat, dim=0)

    goal_pred = goal_generator(goda_score, goda_coord, goda_feat)  # (batch, n_mode, 2)

    '''
        train: use GT goal
        test/val: use generated goals
    '''
    if training:
        goal_mock = gpu(torch.stack(goal_gt), device=device).unsqueeze(1)
        traj_pred = traj_completor(agent_feat, goal_mock)
        traj_pred = traj_pred.repeat((1, cfg['n_mode'], 1, 1))  # [batch, n_mode, len_pred, 2]
        score_pred = torch.ones(traj_pred.shape[0], traj_pred.shape[1])
    else:
        # sum/mean before sampling
        edges = graph_da['ms_edges'][0]
        goda_cls_tmp = goda_cls.clone()
        goda_cls_tmp = scatter_add(goda_cls_tmp.index_select(0, edges['u']), edges['v'], out=goda_cls_tmp, dim=0)

        traj_pred = traj_completor(agent_feat, goal_pred)
        # ! tmp, overwrite
        traj_pred[:, :, -1, 0:2] = goal_pred

        # score_pred = torch.ones(traj_pred.shape[0], traj_pred.shape[1])  # uniform distribution
        # ~ score from heatmap
        score_pred = []
        for b in range(batch_size):
            goals = graph_da['ctrs'][b]
            goal_scores = goda_cls_tmp[graph_da['idcs'][b]]
            score_tmp = []
            for g in goal_pred[b]:
                dist = torch.sqrt(torch.sum((goals - g.unsqueeze(0))**2, dim=1))
                min_idx = torch.argmin(dist)
                score_tmp.append(goal_scores[min_idx])
            score_pred.append(score_tmp)
        score_pred = torch.Tensor(score_pred).to(device)

    return goda_cls, traj_pred, goal_pred, score_pred

In [82]:
goda_cls, traj_pred, goal_pred, score_pred = decode_goal_and_traj(actors, graph, actors, actor_idcs, data['ctrs'])
print("decoding: goda_cls:", goda_cls)
print("decoding: traj_pred:", traj_pred)
print("decoding: goal_pred:", goal_pred)
print("decoding: score_pred:", score_pred)

IndexError: index 528 is out of bounds for dimension 0 with size 528