## Imports

In [1]:
import numpy as np
import os
import sys
from fractions import gcd
from numbers import Number

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from numpy import float64, ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from fractions import gcd
from numbers import Number

## Utils

In [2]:
# CONFIG
config = dict()
"""Train"""
config["display_iters"] = 205942
config["val_iters"] = 205942 * 2
config["save_freq"] = 1.0
config["epoch"] = 0
config["horovod"] = True
config["opt"] = "adam"
config["num_epochs"] = 36
config["lr"] = [1e-3, 1e-4]
config["lr_epochs"] = [32]
#config["lr_func"] = StepLR(config["lr"], config["lr_epochs"])

config["batch_size"] = 32
config["val_batch_size"] = 32
config["workers"] = 0
config["val_workers"] = config["workers"]

"""Model"""
config["rot_aug"] = False
config["pred_range"] = [-100.0, 100.0, -100.0, 100.0]
config["num_scales"] = 6
config["n_actor"] = 128
config["n_map"] = 128
config["actor2map_dist"] = 7.0
config["map2actor_dist"] = 6.0
config["actor2actor_dist"] = 100.0
config["pred_size"] = 30
config["pred_step"] = 1
config["num_preds"] = config["pred_size"] // config["pred_step"]
config["num_mods"] = 6
config["cls_coef"] = 1.0
config["reg_coef"] = 1.0
config["mgn"] = 0.2
config["cls_th"] = 2.0
config["cls_ignore"] = 0.2

In [3]:
class Linear(nn.Module):
    def __init__(self, n_in, n_out, norm='GN', ng=32, act=True):
        super(Linear, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])

        self.linear = nn.Linear(n_in, n_out, bias=False)
        
        if norm == 'GN':
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        else:
            exit('SyncBN has not been added!')
        
        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        out = self.linear(x)
        out = self.norm(out)
        if self.act:
            out = self.relu(out)
        return out
    
    
class LinearRes(nn.Module):
    def __init__(self, n_in, n_out, norm='GN', ng=32):
        super(LinearRes, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])

        self.linear1 = nn.Linear(n_in, n_out, bias=False)
        self.linear2 = nn.Linear(n_out, n_out, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if norm == 'GN':
            self.norm1 = nn.GroupNorm(gcd(ng, n_out), n_out)
            self.norm2 = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm1 = nn.BatchNorm1d(n_out)
            self.norm2 = nn.BatchNorm1d(n_out)
        else:   
            exit('SyncBN has not been added!')

        if n_in != n_out:
            if norm == 'GN':
                self.transform = nn.Sequential(
                    nn.Linear(n_in, n_out, bias=False),
                    nn.GroupNorm(gcd(ng, n_out), n_out))
            elif norm == 'BN':
                self.transform = nn.Sequential(
                    nn.Linear(n_in, n_out, bias=False),
                    nn.BatchNorm1d(n_out))
            else:
                exit('SyncBN has not been added!')
        else:
            self.transform = None

    def forward(self, x):
        out = self.linear1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.norm2(out)

        if self.transform is not None:
            out += self.transform(x)
        else:
            out += x

        out = self.relu(out)
        return out

In [4]:
M=10
K=6
a = torch.randn([10, 1, 2])
b = torch.randn([10, 6, 2])
print((a-b).view(M*K, -1).shape)

class AttDest(nn.Module):
    def __init__(self, n_agt):
        super(AttDest, self).__init__()
        norm = "GN"
        ng = 1
        
        # take position, get target-sized vector
        self.dist = nn.Sequential(
            nn.Linear(2, n_agt),
            nn.ReLU(inplace=True),
            Linear(n_agt, n_agt, norm=norm, ng=ng),
        )
        self.agt = Linear(2 * n_agt, n_agt, norm=norm, ng=ng)
        
    def forward(self, agts: Tensor, agt_ctrs: Tensor, dest_ctrs: Tensor) -> Tensor:
        'dest_ctrs: MxKx2?'
        n_agt = agts.size(1)
        num_mods = dest_ctrs.size(1)
        
        dist = (agt_ctrs.unsqueeze(1) - dest_ctrs).view(-1, 2) #
        dist = self.dist(dist)
        
        agts = agts.unsqueeze(1).repeat(1, num_mods, 1).view(-1, n_agt)
        
        agts = torch.cat((dist, agts), 1)
        agts = self.agt(agts) #[MxK, 2*agt_dim]
        return agts

torch.Size([60, 2])


## Prednet

In [5]:
pred = []
n_actor = config["n_actor"]
norm = "GN"
ng = 1

for i in range(config["num_mods"]):
    pred.append(
        nn.Sequential(
            LinearRes(n_actor, n_actor, norm=norm, ng=ng),
            nn.Linear(n_actor, 2 * config["num_preds"]),
        )
    )
print(nn.ModuleList(pred))

ModuleList(
  (0): Sequential(
    (0): LinearRes(
      (linear1): Linear(in_features=128, out_features=128, bias=False)
      (linear2): Linear(in_features=128, out_features=128, bias=False)
      (relu): ReLU(inplace=True)
      (norm1): GroupNorm(1, 128, eps=1e-05, affine=True)
      (norm2): GroupNorm(1, 128, eps=1e-05, affine=True)
    )
    (1): Linear(in_features=128, out_features=60, bias=True)
  )
  (1): Sequential(
    (0): LinearRes(
      (linear1): Linear(in_features=128, out_features=128, bias=False)
      (linear2): Linear(in_features=128, out_features=128, bias=False)
      (relu): ReLU(inplace=True)
      (norm1): GroupNorm(1, 128, eps=1e-05, affine=True)
      (norm2): GroupNorm(1, 128, eps=1e-05, affine=True)
    )
    (1): Linear(in_features=128, out_features=60, bias=True)
  )
  (2): Sequential(
    (0): LinearRes(
      (linear1): Linear(in_features=128, out_features=128, bias=False)
      (linear2): Linear(in_features=128, out_features=128, bias=False)
      (re

  self.norm1 = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.norm2 = nn.GroupNorm(gcd(ng, n_out), n_out)


In [6]:
class PredNet(nn.Module):
    """
    Final motion forecasting with Linear Residual block
    """
    def __init__(self, config):
        super(PredNet, self).__init__()
        self.config = config
        norm = "GN"
        ng = 1
        
        n_actor = config["n_actor"]
        
        pred = []
        for i in range(config["num_mods"]):
            pred.append(
                nn.Sequential(
                    LinearRes(n_actor, n_actor, norm=norm, ng=ng),
                    nn.Linear(n_actor, 2 * config["num_preds"]),
                )
            )
        self.pred = nn.ModuleList(pred)
        
        # return a concatenated and linearly transformed vector of [agt,dist_from_final_pos]: [MxK, 2*agt_dim]
        self.att_dest = AttDest(n_actor)
        
        # classification
        self.cls = nn.Sequential(
            LinearRes(n_actor, n_actor, norm=norm, ng=ng), 
            nn.Linear(n_actor, 1)
        )
        
    def forward(self, actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor]) -> Dict[str, List[Tensor]]:
        preds = []
        
        for i in range(len(self.pred)): #pred:6
            preds.append(self.pred[i](actors))
        
        reg = torch.cat([x.unsqueeze(1) for x in preds], 1) #MxKx60
        reg = reg.view(reg.size(0), reg.size(1), -1, 2) #(MxKx30x2)
        
        for i in range(len(actor_idcs)):
            # what is happening here? Add ctrs to reg to get final locations
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 2)
            reg[idcs] = reg[idcs] + ctrs
            
        # what is happening here?
        # we have the regression - onto classification
        dest_ctrs = reg[:, :, -1].detach()
        feats = self.att_dest(actors, torch.cat(actor_ctrs, 0), dest_ctrs) # concatenate with end-points and transform to feed to classifier
        # feats: [MxK, 2*agt_dim]; feed to scoring function
        cls = self.cls(feats).view(-1, self.config["num_mods"]) #(MxK,1)->(M, K)
        
        # Sorting time
        cls, sort_idcs = cls.sort(1, descending=True)
        row_idcs = torch.arange(len(sort_idcs)).long().to(sort_idcs.device)
        row_idcs = row_idcs.view(-1, 1).repeat(1, sort_idcs.size(1)).view(-1)
        sort_idcs = sort_idcs.view(-1)
        reg = reg[row_idcs, sort_idcs].view(cls.size(0), cls.size(1), -1, 2) # sort the regression to be in order of confidence
        
        # time to ouput
        out = dict()
        out["cls"], out["reg"] = [], []
        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            out["cls"].append(cls[idcs])
            out["reg"].append(reg[idcs])
        return out

## Convert to original coordinates as gt_preds

In [7]:
#for i in range(len(out["reg"])):
#    out["reg"][i] = torch.matmul(out["reg"][i], rot[i]) + orig[i].view(1, 1, 1, -1)

## Loss

In [8]:
dist = []
reg = torch.randn([10, 6, 30, 2])
gt_preds = torch.randn([10, 30, 2])
row_idcs = torch.arange(10)

for j in range(6):
    dist.append(
        torch.sqrt(((reg[row_idcs, j, -1] - gt_preds[row_idcs, -1])** 2).sum(1))
    )
dist = torch.cat([x.unsqueeze(1) for x in dist], 1)
min_dist, min_idcs = dist.min(1)
row_idcs = torch.arange(len(min_idcs)).long().to(min_idcs.device)
print(row_idcs, min_idcs)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) tensor([2, 2, 5, 1, 5, 3, 0, 1, 3, 2])


In [9]:
class PredLoss(nn.Module):
    def __init__(self, config):
        super(PredLoss, self).__init__()
        self.config = config
        self.reg_loss = nn.SmoothL1Loss(reduction="sum")
        
    def forward(self, out, gt_preds, has_preds):
        """
        Loss = Margin loss + regression loss
        Input:
        out: Dict[str, List[Tensor]], 
        gt_preds: List[Tensor], 
        has_preds: List[Tensor])
        Output:
        loss_out: Dict[str, Union[Tensor, int]]
        """
        cls, reg = out["cls"], out["reg"]
        
        #------from list to tensor
        cls = torch.cat([x for x in cls], 0) # convert from list[tensor]-> tensor with size(0)=len(list[tensor])
        reg = torch.cat([x for x in reg], 0)
        
        gt_preds = torch.cat([x for x in gt_preds], 0)
        has_preds = torch.cat([x for x in has_preds], 0)
        
        #-------init
        loss_out = dict()
        zero = 0.0 * (cls.sum() + reg.sum()) #K?
        loss_out["cls_loss"] = zero.clone()
        loss_out["num_cls"] = 0
        loss_out["reg_loss"] = zero.clone()
        loss_out["num_reg"] = 0
        
        num_mods, num_preds = self.config["num_mods"], self.config["num_preds"]
        # assert(has_preds.all())
        
        #-------masking based on has_preds
        last = has_preds.float() + 0.1 * torch.arange(num_preds).float().to(
            has_preds.device
        ) / float(num_preds)
        max_last, last_idcs = last.max(1)
        mask = max_last > 1.0

        cls = cls[mask]
        reg = reg[mask]
        gt_preds = gt_preds[mask]
        has_preds = has_preds[mask]
        last_idcs = last_idcs[mask]
        
        #-------finding what is k^ based on min_dist of mode to gt_pred
        row_idcs = torch.arange(len(last_idcs)).long().to(last_idcs.device)
        dist = []
        for j in range(num_mods):
            dist.append(torch.sqrt(((reg[row_idcs, j, last_idcs] - gt_preds[row_idcs, last_idcs])** 2).sum(1)))
        dist = torch.cat([x.unsqueeze(1) for x in dist], 1)
        # Returns a named tuple (values, indices) where values is the minimum value of each row of the input 
        # tensor in the given dimension dim. And indices is the index location of each minimum value found (argmin).
        min_dist, min_idcs = dist.min(1)
        row_idcs = torch.arange(len(min_idcs)).long().to(min_idcs.device)
        
        #-------classification loss (margin loss)
        mgn = cls[row_idcs, min_idcs].unsqueeze(1) - cls
        mask0 = (min_dist < self.config["cls_th"]).view(-1, 1)
        mask1 = dist - min_dist.view(-1, 1) > self.config["cls_ignore"]
        mgn = mgn[mask0 * mask1]
        mask = mgn < self.config["mgn"]
        coef = self.config["cls_coef"]
        loss_out["cls_loss"] += coef * (
            self.config["mgn"] * mask.sum() - mgn[mask].sum()
        )
        loss_out["num_cls"] += mask.sum().item()
        
        #--------regression loss: easier
        reg = reg[row_idcs, min_idcs]
        coef = self.config["reg_coef"]
        loss_out["reg_loss"] += coef * self.reg_loss(reg[has_preds], gt_preds[has_preds])
        loss_out["num_reg"] += has_preds.sum().item()
        
        return loss_out

In [10]:
class Loss(nn.Module):
    def __init__(self, config):
        super(Loss, self).__init__()
        self.config = config
        self.pred_loss = PredLoss(config)

    def forward(self, out: Dict, data: Dict) -> Dict:
        loss_out = self.pred_loss(out, 
                                  gpu(data["gt_preds"]), 
                                  gpu(data["has_preds"]))
        #cls_loss (divide by (num_traj-1)*num_actors) 
        #reg_loss (divide by num_actors*future_time_steps)
        loss_out["loss"] = loss_out["cls_loss"] / (loss_out["num_cls"] + 1e-10) + loss_out["reg_loss"] / (loss_out["num_reg"] + 1e-10)
        return loss_out