In [2]:
#to figure out M6 post and to_world.
import os
import numpy as np
from numpy import ndarray
from fractions import gcd
from numbers import Number
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
import torch.optim as optim
import random
from utils import collate_fn,gpu,to_long
import logging
from memory_profiler import profile
import gc
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

seed = 33

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

config = dict()
config['n_actornet'] = 128
config['num_epochs'] = 50
config['lr'] = 1e-2
config['train_split'] = '/home/avt/prediction/Waymo/data_processed/train'
config['val_split'] = '/home/avt/prediction/Waymo/data_processed/validation'
config["num_scales"] = 6
config["n_map"] = 128
config["n_actor"] = 128
config["actor2map_dist"] = 7.0
config["map2actor_dist"] = 6.0
config["actor2actor_dist"] = 100.0
config["num_mods"] = 6
config["pred_size"] = 80
config["pred_step"] = 1
config["num_preds"] = config["pred_size"] // config["pred_step"]
config["cls_th"] = 2.0 #5.0
config["cls_ignore"] = 0.2
config["mgn"] = 0.2
config["cls_coef"] = 1.0
config["reg_coef"] = 1.0
#test

In [24]:

class Linear(nn.Module):
    def __init__(self, n_in, n_out, norm='GN', ng=32, act=True):
        super(Linear, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])

        self.linear = nn.Linear(n_in, n_out, bias=False)
        
        if norm == 'GN':
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        else:
            exit('SyncBN has not been added!')
        
        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        out = self.linear(x)
        out = self.norm(out)
        if self.act:
            out = self.relu(out)
        return out


class LinearRes(nn.Module):
    def __init__(self, n_in, n_out, norm='GN', ng=32):
        super(LinearRes, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])

        self.linear1 = nn.Linear(n_in, n_out, bias=False)
        self.linear2 = nn.Linear(n_out, n_out, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if norm == 'GN':
            self.norm1 = nn.GroupNorm(gcd(ng, n_out), n_out)
            self.norm2 = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm1 = nn.BatchNorm1d(n_out)
            self.norm2 = nn.BatchNorm1d(n_out)
        else:   
            exit('SyncBN has not been added!')

        if n_in != n_out:
            if norm == 'GN':
                self.transform = nn.Sequential(
                    nn.Linear(n_in, n_out, bias=False),
                    nn.GroupNorm(gcd(ng, n_out), n_out))
            elif norm == 'BN':
                self.transform = nn.Sequential(
                    nn.Linear(n_in, n_out, bias=False),
                    nn.BatchNorm1d(n_out))
            else:
                exit('SyncBN has not been added!')
        else:
            self.transform = None

    def forward(self, x):
        out = self.linear1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.norm2(out)

        if self.transform is not None:
            out += self.transform(x)
            
        else:
            out += x
        out = self.relu(out) 

        return out


class Conv1d(nn.Module):
    def __init__(self, n_in, n_out, kernel_size=3, stride=1, norm='GN', ng=32, act=True):
        super(Conv1d, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])

        self.conv = nn.Conv1d(n_in, n_out, kernel_size=kernel_size, padding=(int(kernel_size) - 1) // 2, stride=stride, bias=False)

        if norm == 'GN':
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        else:
            exit('SyncBN has not been added!')

        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        if self.act:
            out = self.relu(out)
        return out
 

class Res1d(nn.Module):
    def __init__(self, n_in, n_out, kernel_size=3, stride=1, norm='GN', ng=32, act=True):
        super(Res1d, self).__init__()
        assert(norm in ['GN', 'BN', 'SyncBN'])
        padding = (int(kernel_size) - 1) // 2
        self.conv1 = nn.Conv1d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv2 = nn.Conv1d(n_out, n_out, kernel_size=kernel_size, padding=padding, bias=False)
        self.relu = nn.ReLU(inplace = True)

        # All use name bn1 and bn2 to load imagenet pretrained weights
        if norm == 'GN':
            self.bn1 = nn.GroupNorm(gcd(ng, n_out), n_out)
            self.bn2 = nn.GroupNorm(gcd(ng, n_out), n_out)
        elif norm == 'BN':
            self.bn1 = nn.BatchNorm1d(n_out)
            self.bn2 = nn.BatchNorm1d(n_out)
        else:
            exit('SyncBN has not been added!')

        if stride != 1 or n_out != n_in:
            if norm == 'GN':
                self.downsample = nn.Sequential(
                        nn.Conv1d(n_in, n_out, kernel_size=1, stride=stride, bias=False),
                        nn.GroupNorm(gcd(ng, n_out), n_out))
            elif norm == 'BN':
                self.downsample = nn.Sequential(
                        nn.Conv1d(n_in, n_out, kernel_size=1, stride=stride, bias=False),
                        nn.BatchNorm1d(n_out))
            else:
                exit('SyncBN has not been added!')
        else:
            self.downsample = None

        self.act = act

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            x = self.downsample(x)

        out += x
        if self.act:
            out = self.relu(out)
        return out


def actor_gather(actors: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
    """
    actors is data['feat']
     
    """

    batch_size = len(actors)
    num_actors = [len(x) for x in actors]

    actors = [torch.stack(x).transpose(1, 2) for x in actors]
    actors = torch.cat(actors, 0)

    actor_idcs = []
    count = 0
    for i in range(batch_size):
        idcs = torch.arange(count, count + num_actors[i]).to(actors.device)
        actor_idcs.append(idcs)
        count += num_actors[i]
    return actors, actor_idcs


def graph_gather(graphs):
    batch_size = len(graphs)
    node_idcs = []
    count = 0
    counts = []

    for i in range(batch_size):
        counts.append(count)
        idcs = torch.arange(count, count + graphs[i]["num_nodes"])
        node_idcs.append(idcs)
        count = count + graphs[i]["num_nodes"]

    graph = dict()
    graph["idcs"] = node_idcs
    graph["ctrs"] = [x["ctrs"] for x in graphs]

    graph['feats'] = torch.cat([x['feats'] for x in graphs], 0)

    for k1 in ["pre", "suc"]:
        graph[k1] = []
        for i in range(len(graphs[0]["pre"])):
            graph[k1].append(dict())
            for k2 in ["u", "v"]:
                graph[k1][i][k2] = torch.cat(
                    [graphs[j][k1][i][k2] + counts[j] for j in range(batch_size)], 0
                )

    for k1 in ["left", "right"]:
        graph[k1] = dict()
        for k2 in ["u", "v"]:
            temp = [graphs[i][k1][k2] + counts[i] for i in range(batch_size)]
            temp = [
                x if x.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
                for x in temp
            ]
            graph[k1][k2] = torch.cat(temp)
    
    return graph


class ActorNet(nn.Module):
    def __init__(self,config) -> None:
        super(ActorNet,self).__init__()
        self.config = config
        norm = "GN"
        ng = 1

        n_in = 4
        n_out = [32, 64, 128]
        blocks = [Res1d, Res1d, Res1d]
        num_blocks = [2, 2, 2]

        groups = []

        for i in range(len(num_blocks)):

            group = []

            if i == 0:
                group.append(blocks[i](n_in, n_out[i], norm=norm, ng=ng))
            else:
                group.append(blocks[i](n_in, n_out[i], stride=2, norm=norm, ng=ng))

            for j in range(1, num_blocks[i]):
                group.append(blocks[i](n_out[i], n_out[i], norm=norm, ng=ng))
            
            groups.append(nn.Sequential(*group))
            
            n_in = n_out[i]

        self.groups = nn.ModuleList(groups)

        n = config['n_actornet']#128
        
        lateral = []
        for i in range(len(n_out)):
            lateral.append(Conv1d(n_out[i], n, norm=norm, ng=ng, act=False))
        self.lateral = nn.ModuleList(lateral)

        self.outlayer = Res1d(n, n, norm=norm, ng=ng)

    def forward(self, actors: Tensor) -> Tensor:
        #actors [batch_size,feature_dim(4),time_step(11)]
        
        out = actors

        outputs = []
        for i in range(len(self.groups)):
            out = self.groups[i](out)
            outputs.append(out)
        
        out = self.lateral[-1](outputs[-1])

        for i in range(len(outputs) - 2, -1, -1):

            out = F.interpolate(out, scale_factor=2, mode="linear", align_corners=False)
            tmp = self.lateral[i](outputs[i])

            if out.shape != tmp.shape:
                out = out[:,:,:tmp.shape[2]]

            out += tmp
        
        out = self.outlayer(out)[:,:,-1]

        return out


class MapNet(nn.Module):
  
    def __init__(self, config):
        super(MapNet, self).__init__()
        self.config = config
        n_map = 128
        norm = "GN"
        ng = 1

        self.input = nn.Sequential(
            nn.Linear(3, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )
        self.seg = nn.Sequential(
            nn.Linear(3, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )

        keys = ["ctr", "norm", "ctr2", "left", "right"]
        for i in range(config["num_scales"]):
            keys.append("pre" + str(i))
            keys.append("suc" + str(i))

        fuse = dict()
        for key in keys:
            fuse[key] = []

        for i in range(4):
            for key in fuse:
                if key in ["norm"]:
                    fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
                elif key in ["ctr2"]:
                    fuse[key].append(Linear(n_map, n_map, norm=norm, ng=ng, act=False))
                else:
                    fuse[key].append(nn.Linear(n_map, n_map, bias=False))

        for key in fuse:
            fuse[key] = nn.ModuleList(fuse[key])
        self.fuse = nn.ModuleDict(fuse)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, graph):
        if (
            len(graph["feats"]) == 0
            or len(graph["pre"][-1]["u"]) == 0
            or len(graph["suc"][-1]["u"]) == 0
        ):
            temp = graph["feats"]
            return (
                temp.new().resize_(0),
                [temp.new().long().resize_(0) for x in graph["node_idcs"]],
                temp.new().resize_(0),
            )

        ctrs = torch.cat(graph["ctrs"], 0)
        feat = self.input(ctrs)
        feat += self.seg(graph["feats"])
        feat = self.relu(feat)

        """fuse map"""
        res = feat
        for i in range(len(self.fuse["ctr"])):
            temp = self.fuse["ctr"][i](feat)
            for key in self.fuse:
                if key.startswith("pre") or key.startswith("suc"):
                    k1 = key[:3]
                    k2 = int(key[3:])
                    temp.index_add_(
                        0,
                        graph[k1][k2]["u"],
                        self.fuse[key][i](feat[graph[k1][k2]["v"]]),
                    )

            if len(graph["left"]["u"] > 0):
                temp.index_add_(
                    0,
                    graph["left"]["u"],
                    self.fuse["left"][i](feat[graph["left"]["v"]]),
                )
            if len(graph["right"]["u"] > 0):
                temp.index_add_(
                    0,
                    graph["right"]["u"],
                    self.fuse["right"][i](feat[graph["right"]["v"]]),
                )

            feat = self.fuse["norm"][i](temp)
            feat = self.relu(feat)

            feat = self.fuse["ctr2"][i](feat)
            feat += res
            feat = self.relu(feat)
            res = feat

        return feat , graph["idcs"], graph["ctrs"]


class Att(nn.Module):
    def __init__(self, n_agt: int, n_ctx: int) -> None:
        super(Att, self).__init__()
        norm = "GN"
        ng = 1

        self.dist = nn.Sequential(
            nn.Linear(3, n_ctx),
            nn.ReLU(inplace=True),
            Linear(n_ctx, n_ctx, norm=norm, ng=ng),
        )

        self.query = Linear(n_agt, n_ctx, norm=norm, ng=ng)

        self.ctx = nn.Sequential(
            Linear(3 * n_ctx, n_agt, norm=norm, ng=ng),
            nn.Linear(n_agt, n_agt, bias=False),
        )

        self.agt = nn.Linear(n_agt, n_agt, bias=False)
        self.norm = nn.GroupNorm(gcd(ng, n_agt), n_agt)
        self.linear = Linear(n_agt, n_agt, norm=norm, ng=ng, act=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, agts: Tensor, agt_idcs: List[Tensor], agt_ctrs: List[Tensor], ctx: Tensor, ctx_idcs: List[Tensor], ctx_ctrs: List[Tensor], dist_th: float) -> Tensor:
        # feat, graph["idcs"], graph["ctrs"], actors, actor_idcs, actor_ctrs, config["actor2map_dist"]      
        res = agts
        if len(ctx) == 0:
            agts = self.agt(agts)
            agts = self.relu(agts)
            agts = self.linear(agts)
            agts += res
            agts = self.relu(agts)
            return agts

        batch_size = len(agt_idcs)
        hi, wi = [], []
        hi_count, wi_count = 0, 0

        
        for i in range(batch_size):
        
            dist = agt_ctrs[i].view(-1, 1, 3) - ctx_ctrs[i].view(1, -1, 3)
            dist = torch.sqrt((dist ** 2).sum(2))
            mask = dist <= dist_th

            idcs = torch.nonzero(mask, as_tuple=False)
            if len(idcs) == 0:
                continue

            hi.append(idcs[:, 0] + hi_count)
            wi.append(idcs[:, 1] + wi_count)
            hi_count += len(agt_idcs[i])
            wi_count += len(ctx_idcs[i])

        if hi == []:
            print('WARNING!!! - Attention')

        hi = torch.cat(hi, 0)
        wi = torch.cat(wi, 0)

        agt_ctrs = torch.cat(agt_ctrs, 0)
        ctx_ctrs = torch.cat(ctx_ctrs, 0)
        dist = agt_ctrs[hi] - ctx_ctrs[wi]
        dist = self.dist(dist)

        query = self.query(agts[hi])

        ctx = ctx[wi]
        ctx = torch.cat((dist, query, ctx), 1)
        ctx = self.ctx(ctx)

        agts = self.agt(agts)
        agts.index_add_(0, hi, ctx)
        agts = self.norm(agts)
        agts = self.relu(agts)

        agts = self.linear(agts)
        agts += res
        agts = self.relu(agts)

        return agts


class A2M(nn.Module):
    """
    Actor to Map Fusion:  fuses real-time traffic information from
    actor nodes to lane nodes
    """
    def __init__(self, config):
        super(A2M, self).__init__()
        self.config = config
        n_map = config["n_map"]
        norm = "GN"
        ng = 1

        """fuse meta, static, dyn"""
        self.meta = Linear(n_map, n_map, norm=norm, ng=ng)
        att = []
        for i in range(2):
            att.append(Att(n_map, config["n_actor"]))
        self.att = nn.ModuleList(att)

    def forward(self, feat: Tensor, graph: Dict[str, Union[List[Tensor], Tensor, List[Dict[str, Tensor]], Dict[str, Tensor]]], actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor]) -> Tensor:
        """meta, static and dyn fuse using attention"""
        
        feat = self.meta(feat)

        for i in range(len(self.att)):
            feat = self.att[i](
                feat,
                graph["idcs"],
                graph["ctrs"],
                actors,
                actor_idcs,
                actor_ctrs,
                self.config["actor2map_dist"],
            )
        return feat


class M2M(nn.Module):
 
    def __init__(self, config):
        super(M2M, self).__init__()
        self.config = config
        n_map = config["n_map"]
        norm = "GN"
        ng = 1

        keys = ["ctr", "norm", "ctr2", "left", "right"]
        for i in range(config["num_scales"]):
            keys.append("pre" + str(i))
            keys.append("suc" + str(i))

        fuse = dict()
        for key in keys:
            fuse[key] = []

        for i in range(4):
            for key in fuse:
                if key in ["norm"]:
                    fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
                elif key in ["ctr2"]:
                    fuse[key].append(Linear(n_map, n_map, norm=norm, ng=ng, act=False))
                else:
                    fuse[key].append(nn.Linear(n_map, n_map, bias=False))

        for key in fuse:
            fuse[key] = nn.ModuleList(fuse[key])
        self.fuse = nn.ModuleDict(fuse)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, feat: Tensor, graph: Dict) -> Tensor:
        """fuse map"""
        res = feat
        for i in range(len(self.fuse["ctr"])):
            temp = self.fuse["ctr"][i](feat)
            for key in self.fuse:
                if key.startswith("pre") or key.startswith("suc"):
                    k1 = key[:3]
                    k2 = int(key[3:])
                    temp.index_add_(
                        0,
                        graph[k1][k2]["u"],
                        self.fuse[key][i](feat[graph[k1][k2]["v"]]),
                    )

            if len(graph["left"]["u"] > 0):
                temp.index_add_(
                    0,
                    graph["left"]["u"],
                    self.fuse["left"][i](feat[graph["left"]["v"]]),
                )
            if len(graph["right"]["u"] > 0):
                temp.index_add_(
                    0,
                    graph["right"]["u"],
                    self.fuse["right"][i](feat[graph["right"]["v"]]),
                )

            feat = self.fuse["norm"][i](temp)
            feat = self.relu(feat)

            feat = self.fuse["ctr2"][i](feat)
            feat += res
            feat = self.relu(feat)
            res = feat
        return feat


class M2A(nn.Module):
    """
    The lane to actor block fuses updated
        map information from lane nodes to actor nodes
    """
    def __init__(self, config):
        super(M2A, self).__init__()
        self.config = config
        norm = "GN"
        ng = 1

        n_actor = config["n_actor"]
        n_map = config["n_map"]

        att = []
        for i in range(2):
            att.append(Att(n_actor, n_map))
        self.att = nn.ModuleList(att)

    def forward(self, actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor], nodes: Tensor, node_idcs: List[Tensor], node_ctrs: List[Tensor]) -> Tensor:
        for i in range(len(self.att)):
            actors = self.att[i](
                actors,
                actor_idcs,
                actor_ctrs,
                nodes,
                node_idcs,
                node_ctrs,
                self.config["map2actor_dist"],
            )
        return actors


class A2A(nn.Module):
    """
    The actor to actor block performs interactions among actors.
    """
    def __init__(self, config):
        super(A2A, self).__init__()
        self.config = config
        norm = "GN"
        ng = 1

        n_actor = config["n_actor"]
        n_map = config["n_map"]

        att = []
        for i in range(2):
            att.append(Att(n_actor, n_actor))
        self.att = nn.ModuleList(att)

    def forward(self, actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor]) -> Tensor:
        for i in range(len(self.att)):
            actors = self.att[i](
                actors,
                actor_idcs,
                actor_ctrs,
                actors,
                actor_idcs,
                actor_ctrs,
                self.config["actor2actor_dist"],
            )
        return actors


class PredNet(nn.Module):
    def __init__(self, config):
        super(PredNet, self).__init__()
        self.config = config
        norm = "GN"
        ng = 1

        n_actor = config["n_actor"]

        pred = []
        for i in range(config["num_mods"]):
            pred.append(
                nn.Sequential(
                    LinearRes(n_actor, n_actor, norm=norm, ng=ng),
                    nn.Linear(n_actor, 2 * config["num_preds"]),
                )
            )
        self.pred = nn.ModuleList(pred)

        self.att_dest = AttDest(n_actor)
        self.cls = nn.Sequential(
            LinearRes(n_actor, n_actor, norm=norm, ng=ng), nn.Linear(n_actor, 1)
        )

    def forward(self, actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor]) -> Dict[str, List[Tensor]]:
        preds = []
        for i in range(len(self.pred)):
            preds.append(self.pred[i](actors))
        reg = torch.cat([x.unsqueeze(1) for x in preds], 1)
        reg = reg.view(reg.size(0), reg.size(1), -1, 2)

        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 3)
            reg[idcs] = reg[idcs] + ctrs[:,:,:,:2]

        dest_ctrs = reg[:, :, -1].detach()
        feats = self.att_dest(actors, torch.cat(actor_ctrs, 0), dest_ctrs)
        cls = self.cls(feats).view(-1, self.config["num_mods"])

        cls, sort_idcs = cls.sort(1, descending=True)
        row_idcs = torch.arange(len(sort_idcs)).long().to(sort_idcs.device)
        row_idcs = row_idcs.view(-1, 1).repeat(1, sort_idcs.size(1)).view(-1)
        sort_idcs = sort_idcs.view(-1)
        reg = reg[row_idcs, sort_idcs].view(cls.size(0), cls.size(1), -1, 2)

        out = dict()
        out["cls"], out["reg"] = [], []
        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            out["cls"].append(cls[idcs])
            out["reg"].append(reg[idcs])
        return out


class AttDest(nn.Module):
    def __init__(self, n_agt: int):
        super(AttDest, self).__init__()
        norm = "GN"
        ng = 1

        self.dist = nn.Sequential(
            nn.Linear(2, n_agt),
            nn.ReLU(inplace=True),
            Linear(n_agt, n_agt, norm=norm, ng=ng),
        )

        self.agt = Linear(2 * n_agt, n_agt, norm=norm, ng=ng)

    def forward(self, agts: Tensor, agt_ctrs: Tensor, dest_ctrs: Tensor) -> Tensor:
        n_agt = agts.size(1)
        num_mods = dest_ctrs.size(1)

        dist = (agt_ctrs[:,:2].unsqueeze(1) - dest_ctrs).view(-1, 2)
        dist = self.dist(dist)
        agts = agts.unsqueeze(1).repeat(1, num_mods, 1).view(-1, n_agt)

        agts = torch.cat((dist, agts), 1)
        agts = self.agt(agts)
        return agts


class GreatNet(nn.Module):
    def __init__(self,config) -> None:
        super().__init__()

        self.config = config

        self.actor_net = ActorNet(config)
        self.map_net = MapNet(config)

        self.a2m = A2M(config)
        self.m2m = M2M(config)
        self.m2a = M2A(config)
        self.a2a = A2A(config)
        
        self.pred_net = PredNet(config)
    
    def forward(self, data: Dict) -> Tensor:

        actors, actor_idcs = actor_gather(data["feats"])
        actor_ctrs = [torch.stack(i,0) for i in data["ctrs"]]

        actors = gpu(actors)
        actor_idcs = gpu(actor_idcs)
        actor_ctrs = gpu(actor_ctrs)

        actors = self.actor_net(actors)

        #------------------------------------------------------------#

        graph = to_long(data['graph'])
        graph = graph_gather(graph)

        graph = gpu(graph)

        nodes, node_idcs, node_ctrs = self.map_net(graph)

        #------------------------------------------------------------#
        
        nodes = self.a2m(nodes, graph, actors, actor_idcs, actor_ctrs)
        nodes = self.m2m(nodes, graph)
        actors = self.m2a(actors, actor_idcs, actor_ctrs, nodes, node_idcs, node_ctrs)
        actors = self.a2a(actors, actor_idcs, actor_ctrs)
        
        out = self.pred_net(actors, actor_idcs, actor_ctrs)
        rot, orig = gpu(data["rot"]), gpu(data["orig"])

        # to_global
        for i in range(len(out["reg"])):
            out["reg"][i] = torch.matmul(out["reg"][i], rot[i]) + orig[i][:2].view(1, 1, 1, -1)

        return out


class PredLoss(nn.Module):
    def __init__(self, config):
        super(PredLoss, self).__init__()
        self.config = config
        self.reg_loss = nn.SmoothL1Loss(reduction="sum")

    def forward(self, out: Dict[str, List[Tensor]], data) -> Dict[str, Union[Tensor, int]]:
        cls, reg = out["cls"], out["reg"]
        cls = torch.cat([x for x in cls], 0)
        reg = torch.cat([x for x in reg], 0)
        has_preds = pre_gather(data['has_preds']).cuda()
        gt_preds = pre_gather(data['gt2_preds']).float()[:,:,:2].cuda()

        loss_out = dict()
        zero = 0.0 * (cls.sum() + reg.sum())
        loss_out["cls_loss"] = zero.clone()
        loss_out["num_cls"] = 0
        loss_out["reg_loss"] = zero.clone()
        loss_out["num_reg"] = 0

        num_mods, num_preds = self.config["num_mods"], self.config["num_preds"]
        # assert(has_preds.all())

        last = has_preds.float() + 0.1 * torch.arange(num_preds).float().to(
            has_preds.device
        ) / float(num_preds)
        max_last, last_idcs = last.max(1)
        mask = max_last > 1.0

        cls = cls[mask]
        reg = reg[mask]
        gt_preds = gt_preds[mask]
        has_preds = has_preds[mask]
        last_idcs = last_idcs[mask]

        row_idcs = torch.arange(len(last_idcs)).long().to(last_idcs.device)
        dist = []
        for j in range(num_mods):
            dist.append(
                torch.sqrt(
                    (
                        (reg[row_idcs, j, last_idcs] - gt_preds[row_idcs, last_idcs])
                        ** 2
                    ).sum(1)
                )
            )
        dist = torch.cat([x.unsqueeze(1) for x in dist], 1)
        min_dist, min_idcs = dist.min(1)
        row_idcs = torch.arange(len(min_idcs)).long().to(min_idcs.device)

        mgn = cls[row_idcs, min_idcs].unsqueeze(1) - cls
        mask0 = (min_dist < self.config["cls_th"]).view(-1, 1)
        mask1 = dist - min_dist.view(-1, 1) > self.config["cls_ignore"]
        mgn = mgn[mask0 * mask1]
        mask = mgn < self.config["mgn"]
        coef = self.config["cls_coef"]
        loss_out["cls_loss"] += coef * (
            self.config["mgn"] * mask.sum() - mgn[mask].sum()
        )
        loss_out["num_cls"] += mask.sum().item()

        reg = reg[row_idcs, min_idcs]
        coef = self.config["reg_coef"]
        loss_out["reg_loss"] += coef * self.reg_loss(
            reg[has_preds], gt_preds[has_preds]
        )
        loss_out["num_reg"] += has_preds.sum().item()

        return loss_out


class Loss(nn.Module):
    def __init__(self, config):
        super(Loss, self).__init__()
        self.config = config
        self.pred_loss = PredLoss(config)

    def forward(self, out: Dict, data: Dict) -> Dict:
        loss_out = self.pred_loss(out, data)
        loss_out["loss"] = loss_out["cls_loss"] / (
            loss_out["num_cls"] + 1e-10
        ) + loss_out["reg_loss"] / (loss_out["num_reg"] + 1e-10)
        return loss_out
    


class PostProcess(nn.Module):
    def __init__(self, config):
        super(PostProcess, self).__init__()
        self.config = config

    def forward(self, out,data):
        post_out = dict()
        post_out["preds"] = [x[0:1].detach().cpu().numpy() for x in out["reg"]]
        
        post_out["gt_preds"] = [x[0:1].numpy() for x in gather(data["gt_preds"])]
        post_out["has_preds"] = [x[0:1].numpy() for x in gather(data["has_preds"])]
        return post_out

    def append(self, metrics: Dict, loss_out: Dict, post_out: Optional[Dict[str, List[ndarray]]]=None) -> Dict:
        if len(metrics.keys()) == 0:
            for key in loss_out:
                if key != "loss":
                    metrics[key] = 0.0

            for key in post_out:
                metrics[key] = []

        for key in loss_out:
            if key == "loss":
                continue
            if isinstance(loss_out[key], torch.Tensor):
                metrics[key] += loss_out[key].item()
            else:
                metrics[key] += loss_out[key]

        for key in post_out:
            metrics[key] += post_out[key]
        return metrics

    def display(self, metrics, dt, epoch, lr=None):
        """Every display-iters print training/val information"""
        if lr is not None:
            print("Epoch %3.3f, lr %.5f, time %3.2f" % (epoch, lr, dt))
        else:
            print(
                "************************* Validation, time %3.2f *************************"
                % dt
            )

        cls = metrics["cls_loss"] / (metrics["num_cls"] + 1e-10)
        reg = metrics["reg_loss"] / (metrics["num_reg"] + 1e-10)
        loss = cls + reg

        preds = np.concatenate(metrics["preds"], 0)
        gt_preds = np.concatenate(metrics["gt_preds"], 0)
        has_preds = np.concatenate(metrics["has_preds"], 0)
        ade1, fde1, ade, fde, min_idcs = pred_metrics(preds, gt_preds, has_preds)

        print(
            "loss %2.4f %2.4f %2.4f, ade1 %2.4f, fde1 %2.4f, ade %2.4f, fde %2.4f"
            % (loss, cls, reg, ade1, fde1, ade, fde)
        )
        print()



def pred_metrics(preds, gt_preds, has_preds):
    assert has_preds.all()
    preds = np.asarray(preds, np.float32)
    gt_preds = np.asarray(gt_preds, np.float32)

    """batch_size x num_mods x num_preds"""
    err = np.sqrt(((preds - np.expand_dims(gt_preds, 1)) ** 2).sum(3))

    ade1 = err[:, 0].mean()
    fde1 = err[:, 0, -1].mean()

    min_idcs = err[:, :, -1].argmin(1)
    row_idcs = np.arange(len(min_idcs)).astype(np.int64)
    err = err[row_idcs, min_idcs]
    ade = err.mean()
    fde = err[:, -1].mean()
    return ade1, fde1, ade, fde, min_idcs



def pre_gather(gts: List) -> Tensor:
    tmp = list()
    for g in gts:
        tmp += g
    
    tmp = torch.stack(tmp)

    return tmp

def gather(gts) -> list:

    tmp = list()
    for i,g in enumerate(gts):
        zz = torch.stack(g, dim=0)
        tmp.append(zz)
    
    return tmp





In [3]:

class W_Dataset(Dataset):
    def __init__(self,path) -> None:

        self.path = path
        self.files = os.listdir(path)
    
    def __getitem__(self, index) -> dict:

        data_path = os.path.join(self.path,self.files[index])
        data = torch.load(data_path)

        return data
    
    def __len__(self) -> int:

        return len(self.files)

In [4]:
batch_size = 4
dataset_train = W_Dataset(config['train_split'])
train_loader = DataLoader(dataset_train, 
                        batch_size = batch_size ,
                        collate_fn = collate_fn, 
                        shuffle = True, 
                        drop_last=True)

In [26]:
net = GreatNet(config).cuda()
loss_f = Loss(config).cuda()
post_process = PostProcess(config).cuda()

  self.bn1 = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.bn2 = nn.GroupNorm(gcd(ng, n_out), n_out)
  nn.GroupNorm(gcd(ng, n_out), n_out))
  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
  fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
  self.norm = nn.GroupNorm(gcd(ng, n_agt), n_agt)
  fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
  self.norm1 = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.norm2 = nn.GroupNorm(gcd(ng, n_out), n_out)


In [6]:
for i, data in enumerate(train_loader):
        if i > 0:
            break

In [25]:
out = net(data)
loss_out = loss_f(out,data)

In [54]:
[out['reg'][i].shape for i in range(4)]

[torch.Size([47, 6, 80, 2]),
 torch.Size([19, 6, 80, 2]),
 torch.Size([43, 6, 80, 2]),
 torch.Size([47, 6, 80, 2])]

In [55]:
[gather(data['gt_preds'])[i].shape for i in range(4)]

[torch.Size([47, 80, 3]),
 torch.Size([19, 80, 3]),
 torch.Size([43, 80, 3]),
 torch.Size([47, 80, 3])]

In [56]:
[gather(data['has_preds'])[i].shape for i in range(4)]

[torch.Size([47, 80]),
 torch.Size([19, 80]),
 torch.Size([43, 80]),
 torch.Size([47, 80])]

In [60]:
xx = out['reg'][0]
mask = gather(data['has_preds'])[0]
xx.shape, mask.shape

(torch.Size([47, 6, 80, 2]), torch.Size([47, 80]))

In [163]:
cls, reg = out["cls"], out["reg"]
gt_preds, has_preds = gather(data['gt_preds']), gather(data['has_preds'])

In [75]:
reg[0].shape

torch.Size([47, 6, 80, 2])

In [173]:
cls, reg = out["cls"], out["reg"]
gt_preds, has_preds = gather(data['gt_preds']), gather(data['has_preds'])

reg = torch.cat([x for x in reg], 0)
gt_preds = torch.cat([x for x in gt_preds], 0)
has_preds = torch.cat([x for x in has_preds], 0)

last = has_preds.float() + 0.1 * torch.arange(config["num_preds"]).float().to(
            has_preds.device
        ) / float(config["num_preds"])

max_last, last_idcs = last.max(1)
mask = max_last >1.0

reg = reg[mask]
gt_preds = gt_preds[mask][:,:,:2]
has_preds = has_preds[mask]
last_idcs = last_idcs[mask]

row_idcs = torch.arange(len(last_idcs)).long().to(last_idcs.device)

reg.shape, gt_preds.shape, has_preds.shape, last_idcs.shape


(torch.Size([151, 6, 80, 2]),
 torch.Size([151, 80, 2]),
 torch.Size([151, 80]),
 torch.Size([151]))

In [1]:
#fde
dist_6m = []
for i in range(config["num_mods"]):

    rr = reg[row_idcs,i,last_idcs]
    gg = gt_preds[row_idcs,last_idcs]
    (rr.shape - gg.shape)

print('hello')
        

#         dd = torch.sqrt(((rr[last_idcs] - gg[last_idcs])**2).sum(1))
#         dist.append(dd.mean().item())
    
#     dist_6m.append(torch.tensor(dist).view(-1,1))

# zz = torch.cat(dist_6m,1)
# min_dist, min_idcs = zz.min(1)
# # mask = ~torch.isnan(min_dist)
# # mask
# ade = min_dist.mean().item()
# ade

NameError: name 'config' is not defined

In [171]:
#ade
dist_6m = []
for i in range(config["num_mods"]):
    reg_m = reg[:,i]
    
    dist = []
    for j in range(len(reg_m)):
    
        rr = reg_m[j]
        gg = gt_preds[j].cuda()
        hh = has_preds[j].cuda()

        dd = torch.sqrt(((rr[hh] - gg[hh])**2).sum(1))
        dist.append(dd.mean().item())
    
    dist_6m.append(torch.tensor(dist).view(-1,1))

zz = torch.cat(dist_6m,1)
min_dist, min_idcs = zz.min(1)
# mask = ~torch.isnan(min_dist)
# mask
ade = min_dist.mean().item()
ade

15.718441009521484

In [76]:
reg = torch.cat([x for x in reg], 0)
gt_preds = torch.cat([x for x in gt_preds], 0)
has_preds = torch.cat([x for x in has_preds], 0)

reg.shape, gt_preds.shape, has_preds.shape

(torch.Size([156, 6, 80, 2]), torch.Size([156, 80, 3]), torch.Size([156, 80]))

In [77]:
last = has_preds.float() + 0.1 * torch.arange(80).float().to(
            has_preds.device
        ) / float(80)

In [83]:
last.shape

torch.Size([156, 80])

In [81]:
last

tensor([[1.0000, 1.0013, 1.0025,  ..., 0.0963, 0.0975, 0.0988],
        [1.0000, 1.0013, 1.0025,  ..., 1.0963, 1.0975, 1.0987],
        [1.0000, 1.0013, 1.0025,  ..., 1.0963, 1.0975, 1.0987],
        ...,
        [1.0000, 1.0013, 1.0025,  ..., 0.0963, 0.0975, 0.0988],
        [1.0000, 1.0013, 1.0025,  ..., 0.0963, 0.0975, 0.0988],
        [1.0000, 1.0013, 1.0025,  ..., 1.0963, 1.0975, 1.0987]])

In [82]:
max_last, last_idcs = last.max(1)
max_last, last_idcs

(tensor([1.0612, 1.0987, 1.0987, 1.0987, 1.0612, 1.0987, 1.0987, 1.0987, 1.0987,
         1.0013, 1.0275, 1.0987, 1.0662, 1.0063, 1.0137, 1.0263, 1.0275, 1.0125,
         1.0987, 1.0488, 1.0362, 1.0987, 1.0987, 1.0300, 1.0462, 1.0512, 1.0013,
         1.0163, 1.0287, 1.0987, 1.0987, 1.0375, 1.0987, 1.0438, 1.0975, 1.0013,
         1.0987, 1.0638, 1.0987, 1.0925, 1.0187, 1.0037, 1.0987, 1.0987, 0.0988,
         1.0200, 1.0987, 1.0325, 1.0987, 1.0987, 1.0987, 1.0150, 1.0213, 1.0987,
         1.0987, 1.0987, 1.0987, 1.0550, 1.0000, 1.0125, 1.0063, 1.0225, 1.0175,
         1.0987, 1.0287, 1.0987, 1.0987, 1.0987, 1.0000, 1.0538, 1.0987, 1.0987,
         1.0000, 1.0987, 1.0987, 1.0987, 1.0987, 1.0987, 1.0987, 1.0987, 1.0987,
         1.0987, 1.0987, 1.0987, 1.0987, 1.0987, 1.0987, 1.0750, 1.0987, 1.0987,
         1.0987, 1.0987, 1.0987, 1.0838, 1.0987, 1.0987, 1.0987, 1.0475, 1.0987,
         1.0987, 1.0100, 1.0300, 1.0000, 1.0987, 1.0987, 1.0137, 1.0987, 1.0475,
         1.0987, 1.0987, 1.0

In [84]:
mask = max_last >1.0

In [96]:
mask.shape

torch.Size([156])

In [88]:
has_preds[mask].shape

torch.Size([151, 80])

In [89]:
row_idcs = torch.arange(len(last_idcs[mask])).long().to(last_idcs.device)

In [90]:
row_idcs

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150])

In [91]:
reg.shape

torch.Size([156, 6, 80, 2])

In [107]:
first = reg[:,0]
first.shape

torch.Size([156, 80, 2])

In [109]:
has_preds.shape

torch.Size([156, 80])

In [117]:
a = first[0]

In [118]:
b = has_preds[0]

In [119]:
a[b]

tensor([[-501.2417, -937.2134],
        [-500.6860, -940.0228],
        [-499.3760, -938.0051],
        [-499.7210, -939.3006],
        [-498.9965, -938.2569],
        [-501.8896, -936.3184],
        [-501.8258, -938.5580],
        [-499.9292, -938.9136],
        [-502.3118, -940.4543],
        [-502.0541, -937.3940],
        [-500.5880, -940.0506],
        [-501.8033, -937.0641],
        [-503.0638, -937.3415],
        [-500.6946, -935.3120],
        [-502.0236, -936.0708],
        [-499.6173, -937.1816],
        [-501.1579, -938.4704],
        [-500.0850, -938.2658],
        [-498.2083, -937.1018],
        [-501.6024, -936.8612],
        [-500.5796, -936.5544],
        [-500.5533, -938.7656],
        [-502.1509, -938.9677],
        [-500.3987, -938.3334],
        [-498.9292, -937.1847],
        [-499.7605, -941.2291],
        [-502.0973, -937.3139],
        [-502.5291, -938.0525],
        [-500.0315, -937.5228],
        [-500.6879, -936.5007],
        [-500.8666, -935.1537],
        

In [126]:
dist = []
for i in range(len(first)):
    a = first[i]
    b = has_preds[i]
    c = a[b].sum(1)
    dist.append(c.mean().item())

len(dist)

156

In [140]:
gt_preds.shape

torch.Size([156, 80, 3])

tensor([[ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

In [159]:

dist_6m = []
for i in range(6):
    reg_m = reg[:,i]
    has_m = has_preds[:,i]

    dist = []
    for j in range(len(reg_m)):
    
        rr = reg_m[j]
        gg = gt_preds[j][:,:2].cuda()
        hh = has_m[j].cuda()

        dd = torch.sqrt(((rr[hh] - gg[hh])**2).sum(1))
        dist.append(dd.mean().item())
    
    dist_6m.append(torch.tensor(dist).view(-1,1))

zz = torch.cat(dist_6m,1)
min_dist, min_idcs = zz.min(1)
mask = ~torch.isnan(min_dist)
mask
    

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True,  True,  True, False,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True, False,
         True, False,  True,  True, False, False,  True,  True,  True,  True,
        False, False,  True,  True, False,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False, False,
         True,  True, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, 

In [93]:
reg[row_idcs,0].shape

torch.Size([151, 80, 2])

In [95]:
last_idcs.shape

torch.Size([156])

In [106]:
reg[row_idcs,0,last_idcs[mask]].sum(1)

tensor([ -1437.0986,  -1427.7796,  -1440.4379,  -1435.2506,  -1447.6534,
         -1422.9018,  -1427.7830,  -1417.6063,  -1443.2016,  -1479.2214,
         -1393.2641,  -1440.1887,  -1450.4697,  -1479.7942,  -1437.1132,
         -1440.9873,  -1454.4685,  -1444.5574,  -1412.6508,  -1458.4474,
         -1446.1409,  -1422.2487,  -1461.3977,  -1448.3022,  -1453.1905,
         -1456.3722,  -1436.1091,  -1518.7122,  -1509.2657,  -1466.4433,
         -1444.7665,  -1510.5346,  -1467.4415,  -1516.9710,  -1469.5888,
         -1475.9035,  -1469.8074,  -1463.1251,  -1471.2118,  -1469.7432,
         -1476.3749,  -1419.9907,  -1462.9936,  -1420.2063,  -1408.9870,
         -1475.7822,  -1427.5165,   -639.1908,   -688.2452,   -710.3316,
          -716.0579,   -689.0257,   -674.8936,   -734.8500,   -764.3897,
          -705.5470,   -759.7037,   -710.8603,   -744.6027,   -704.9582,
          -719.4709,   -675.8939,   -690.5137,   -761.2961,   -653.4922,
          -730.5037,   -959.1937,   -957.5865,   -9

In [27]:
post_out = post_process(out, data)

In [42]:
post_process.append(metrics, loss_out, post_out)

{'cls_loss': 0.0,
 'num_cls': 0.0,
 'reg_loss': 56878543.04597525,
 'num_reg': 9248.0,
 'preds': [array([[[[-501.24170139, -937.21344195],
           [-500.68597057, -940.02283154],
           [-499.37603405, -938.0050603 ],
           [-499.72096279, -939.30063973],
           [-498.99648312, -938.25689421],
           [-501.88964108, -936.31839953],
           [-501.82576206, -938.55797205],
           [-499.92919758, -938.91361104],
           [-502.31182697, -940.45430551],
           [-502.05405262, -937.39397583],
           [-500.58799008, -940.05062161],
           [-501.80330113, -937.06411181],
           [-503.06379345, -937.34145937],
           [-500.69461086, -935.31204854],
           [-502.02356938, -936.07083807],
           [-499.61726978, -937.18156777],
           [-501.15785626, -938.47036562],
           [-500.08502987, -938.26584921],
           [-498.20830944, -937.10184584],
           [-501.60241726, -936.86121521],
           [-500.57956913, -936.55443917],
 

In [46]:
metrics.keys()

dict_keys(['cls_loss', 'num_cls', 'reg_loss', 'num_reg', 'preds', 'gt_preds', 'has_preds'])

In [51]:
preds = np.concatenate(metrics["preds"], 0)

In [53]:
preds.shape

(4, 6, 80, 2)

In [50]:
metrics['gt_preds'][0].shape

(1, 80, 3)

In [28]:
post_out.keys()

dict_keys(['preds', 'gt_preds', 'has_preds'])

In [23]:
gather(data["gt_preds"])

torch.Size([47, 80, 3])

In [8]:
out.keys()

dict_keys(['cls', 'reg'])

In [13]:
out['reg'][0].shape

torch.Size([47, 6, 80, 2])

In [9]:
loss_out.keys()

dict_keys(['cls_loss', 'num_cls', 'reg_loss', 'num_reg', 'loss'])

In [None]:
[out["reg"][i].size() for i in range(len(out['reg']))]

In [None]:
post_out = dict()
post_out["preds"] = [x[0:1].detach().cpu().numpy() for x in out["reg"]]

In [None]:
out['cls'][0].size()

In [None]:
out['reg'][0].size()

In [None]:
for i in range(len(out["reg"])):
    1

In [None]:
orig = torch.ones(3)

In [None]:
orig

In [None]:
orig.view(1,-1)

In [None]:
orig.view(1,1,1,-1)