## DENSE-TNT

In [1]:
import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor

class LayerNorm(nn.Module):
    r"""
    Layer normalization over the last dimension with a learnable affine map.

    The input is centred on its mean and divided by the square root of its
    (biased) variance plus ``eps``, both computed along the last axis, then
    multiplied by ``weight`` and shifted by ``bias``.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        # weight starts at ones and bias at zeros, i.e. an identity affine map.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias


class MLP(nn.Module):
    """
    One linear layer followed by layer normalization and a ReLU.

    ``out_features`` defaults to ``hidden_size`` when it is not given.
    """

    def __init__(self, hidden_size, out_features=None):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features
        self.linear = nn.Linear(hidden_size, out_features)
        self.layer_norm = LayerNorm(out_features)

    def forward(self, hidden_states):
        projected = self.linear(hidden_states)
        normalized = self.layer_norm(projected)
        return torch.nn.functional.relu(normalized)


class GlobalGraph(nn.Module):
    r"""
    Global graph, implemented as (multi-head) scaled dot-product self-attention.
    """

    def __init__(self, hidden_size, attention_head_size=None, num_attention_heads=1):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        if attention_head_size is None:
            attention_head_size = hidden_size // num_attention_heads
        self.attention_head_size = attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.num_qkv = 1

        projection_width = self.all_head_size * self.num_qkv
        self.query = nn.Linear(hidden_size, projection_width)
        self.key = nn.Linear(hidden_size, projection_width)
        self.value = nn.Linear(hidden_size, projection_width)

    def get_extended_attention_mask(self, attention_mask):
        """
        Turn a 0/1 mask into an additive bias for the attention scores.

        1 (attend) becomes 0 and 0 (do not attend) becomes -10000.0, which is
        effectively zero probability after the softmax. A head axis is inserted
        at dim 1 so the result broadcasts over heads.
        """
        return (1.0 - attention_mask.unsqueeze(1)) * -10000.0

    def transpose_for_scores(self, x):
        # (batch, max_vector_num, head, head_size)
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # (batch, head, max_vector_num, head_size)
        return x.view(*split_shape).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, mapping=None, return_scores=False):
        """
        Attend from ``hidden_states`` to itself.

        ``mapping`` is accepted but not used here. With ``return_scores=True``
        (single head only, enforced by an assert) the attention probabilities
        are returned alongside the context.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        # Only the key weight is applied; self.key.bias is not used in this forward pass.
        key_layer = self.transpose_for_scores(nn.functional.linear(hidden_states, self.key.weight))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        scaled_query = query_layer / math.sqrt(self.attention_head_size)
        attention_scores = torch.matmul(scaled_query, key_layer.transpose(-1, -2))
        if attention_mask is not None:
            attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask)
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Merge the heads back: (batch, head, n, head_size) -> (batch, n, all_head_size)
        context_layer = torch.matmul(attention_probs, value_layer).permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)
        if not return_scores:
            return context_layer

        assert attention_probs.shape[1] == 1
        attention_probs = torch.squeeze(attention_probs, dim=1)
        assert len(attention_probs.shape) == 3
        return context_layer, attention_probs


class CrossAttention(GlobalGraph):
    """
    Attention where queries come from one sequence and keys/values from another.

    ``query_hidden_size`` / ``key_hidden_size`` replace the inherited
    projections when the two inputs have a feature size other than
    ``hidden_size``.
    """

    def __init__(self, hidden_size, attention_head_size=None, num_attention_heads=1, key_hidden_size=None,
                 query_hidden_size=None):
        super().__init__(hidden_size, attention_head_size, num_attention_heads)
        projection_width = self.all_head_size * self.num_qkv
        if query_hidden_size is not None:
            self.query = nn.Linear(query_hidden_size, projection_width)
        if key_hidden_size is not None:
            # Keys and values are both projected from the key-side input.
            self.key = nn.Linear(key_hidden_size, projection_width)
            self.value = nn.Linear(key_hidden_size, projection_width)

    def forward(self, hidden_states_query, hidden_states_key=None, attention_mask=None, mapping=None,
                return_scores=False):
        """
        Attend from ``hidden_states_query`` to ``hidden_states_key``.

        ``mapping`` is accepted but not used here. When a mask is given, its
        dims 1 and 2 must match the query and key lengths respectively.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states_query))
        key_layer = self.transpose_for_scores(self.key(hidden_states_key))
        value_layer = self.transpose_for_scores(self.value(hidden_states_key))

        scaled_query = query_layer / math.sqrt(self.attention_head_size)
        attention_scores = torch.matmul(scaled_query, key_layer.transpose(-1, -2))
        if attention_mask is not None:
            assert hidden_states_query.shape[1] == attention_mask.shape[1] \
                   and hidden_states_key.shape[1] == attention_mask.shape[2]
            attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask)
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # Merge the heads back into a single feature axis.
        context_layer = torch.matmul(attention_probs, value_layer).permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)
        if not return_scores:
            return context_layer
        return context_layer, torch.squeeze(attention_probs, dim=1)


class GlobalGraphRes(nn.Module):
    """
    Two parallel GlobalGraph attentions whose outputs are concatenated.

    Each branch uses a head size of ``hidden_size // 2``; the two contexts are
    joined along the last axis.
    """

    def __init__(self, hidden_size):
        super().__init__()
        half_size = hidden_size // 2
        self.global_graph = GlobalGraph(hidden_size, half_size)
        self.global_graph2 = GlobalGraph(hidden_size, half_size)

    def forward(self, hidden_states, attention_mask=None, mapping=None):
        first_branch = self.global_graph(hidden_states, attention_mask, mapping)
        second_branch = self.global_graph2(hidden_states, attention_mask, mapping)
        return torch.cat([first_branch, second_branch], dim=-1)


class PointSubGraph(nn.Module):
    """
    Encode 2D goals conditioned on target agent
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        half_size = hidden_size // 2
        self.layers = nn.ModuleList([MLP(2, half_size),
                                     MLP(hidden_size, half_size),
                                     MLP(hidden_size, hidden_size)])

    def forward(self, hidden_states: Tensor, agent: Tensor):
        """
        Run the points through the MLP stack, re-injecting the agent feature.

        The first layer sees the raw points only; every later layer sees the
        previous output concatenated with the first ``hidden_size // 2``
        channels of ``agent``, repeated for each point.
        """
        predict_agent_num, point_num = hidden_states.shape[0], hidden_states.shape[1]
        hidden_size = self.hidden_size
        half_size = hidden_size // 2
        assert (agent.shape[0], agent.shape[1]) == (predict_agent_num, hidden_size)
        agent = agent[:, :half_size].unsqueeze(1).expand([predict_agent_num, point_num, half_size])

        first_layer, *conditioned_layers = self.layers
        hidden_states = first_layer(hidden_states)
        for layer in conditioned_layers:
            hidden_states = layer(torch.cat([hidden_states, agent], dim=-1))

        return hidden_states

In [2]:
import argparse
import inspect
import json
import math
import multiprocessing
import os
import pickle
import random
import subprocess
import sys
import time
import pdb
from collections import defaultdict
from multiprocessing import Process
from random import randint
from typing import Dict, List, Tuple, NamedTuple, Any, Union, Optional

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.path import Path
from matplotlib.pyplot import MultipleLocator
from torch import Tensor


def add_eval_param(param):
    """Append ``param`` to ``args.eval_params`` unless it is already listed."""
    eval_params = args.eval_params
    if param in eval_params:
        return
    eval_params.append(param)


def get_name(name='', append_time=False):
    """
    Decorate ``name`` with run-mode prefixes and, optionally, the start time.

    A name that already ends with ``time_begin`` is returned untouched.
    Otherwise the prefix is built from ``args`` as
    ``[<add_prefix>.][debug.][test.|eval.]`` and ``.<time_begin>`` is appended
    when ``append_time`` is true.
    """
    if name.endswith(time_begin):
        return name

    if args.do_test:
        prefix = 'test.'
    elif args.do_eval and not args.do_train:
        prefix = 'eval.'
    else:
        prefix = ''

    if args.debug:
        prefix = 'debug.' + prefix
    if args.add_prefix is not None:
        prefix = args.add_prefix + '.' + prefix

    suffix = ''
    if append_time:
        suffix = '.' + time_begin
    return prefix + str(name) + suffix


# Absolute tolerance used by the float comparison helpers larger() and equal().
eps = 1e-5

# Per-batch origin data; both are reassigned as numpy arrays by batch_init().
origin_point = None
origin_angle = None


def get_pad_vector(li):
    """
    Zero-pad ``li`` in place up to ``args.hidden_size`` entries and return it.

    The same list object is returned; a list longer than ``args.hidden_size``
    trips the assertion.
    """
    missing = args.hidden_size - len(li)
    assert missing >= 0
    li.extend([0] * missing)
    return li


def batch_list_to_batch_tensors(batch):
    """Return the items of ``batch`` as a new list (shallow copy, same order)."""
    return list(batch)


def batch_list_to_batch_tensors_old(batch):
    """
    Transpose a batch of per-example tuples into a list of per-field tuples.

    ``[(a0, b0), (a1, b1)]`` becomes ``[(a0, a1), (b0, b1)]``.
    """
    return list(zip(*batch))


def round_value(v):
    """Divide ``v`` by 100 and round to the nearest integer with built-in ``round``."""
    hundredths = v / 100
    return round(hundredths)


def get_dis(points: np.ndarray, point_label):
    """
    Euclidean distance from every row of ``points`` to ``point_label``.

    Only columns 0 and 1 of ``points`` and entries 0 and 1 of ``point_label``
    are used; one distance per row is returned.
    """
    delta_x = points[:, 0] - point_label[0]
    delta_y = points[:, 1] - point_label[1]
    return np.sqrt(np.square(delta_x) + np.square(delta_y))


def get_dis_point2point(point, point_=(0.0, 0.0)):
    """Euclidean distance between two 2D points; ``point_`` defaults to the origin."""
    delta_x = point[0] - point_[0]
    delta_y = point[1] - point_[1]
    return np.sqrt(np.square(delta_x) + np.square(delta_y))


def get_angle(x, y):
    """Angle in radians of the vector ``(x, y)``, as given by ``math.atan2(y, x)``."""
    angle = math.atan2(y, x)
    return angle


def get_sub_matrix(traj, object_type, x=0, y=0, angle=None):
    """
    Turn a flat ``[x0, y0, x1, y1, ...]`` trajectory into consecutive segments.

    Each segment is ``[start_x, start_y, end_x, end_y]`` after subtracting the
    offset ``(x, y)``; when ``angle`` is given both endpoints are additionally
    rotated by it. ``object_type`` is accepted but not used.
    """
    res = []
    for i in range(2, len(traj), 2):
        start_x, start_y = traj[i - 2] - x, traj[i - 1] - y
        end_x, end_y = traj[i] - x, traj[i + 1] - y
        if angle is not None:
            start_x, start_y = rotate(start_x, start_y, angle)
            end_x, end_y = rotate(end_x, end_y, angle)
        res.append([start_x, start_y, end_x, end_y])
    return res


def rotate(x, y, angle):
    """Rotate the point ``(x, y)`` about the origin by ``angle`` radians (counter-clockwise)."""
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    return x * cos_a - y * sin_a, x * sin_a + y * cos_a


def rotate_(x, y, cos, sin):
    """Rotate ``(x, y)`` using a precomputed cosine and sine of the rotation angle."""
    rotated = (x * cos - y * sin, x * sin + y * cos)
    return rotated


# NOTE(review): not referenced anywhere in the visible part of this file — confirm usage elsewhere.
index_file = 0

# Cache that the load_file2pred helper nested in batch_init() fills from a pickle file.
file2pred = {}


def __iter__(self):  # iterator to load data
    """
    Yield ``ceil(len(ex_list) / batch_size)`` randomly sampled batches.

    Every batch holds ``self.batch_size`` examples drawn uniformly with
    replacement from ``self.ex_list`` indices via ``self.__getitem__``.
    """
    batch_count = math.ceil(len(self.ex_list) / float(self.batch_size))
    for _ in range(batch_count):
        batch = [self.__getitem__(randint(0, len(self.ex_list) - 1))
                 for _ in range(self.batch_size)]
        # To Tensor
        yield batch_list_to_batch_tensors(batch)


# Log file paths that logging() has already truncated in this process (path -> 1);
# later calls for the same path only append.
files_written = {}


def logging(*inputs, prob=1.0, type='1', is_json=False, affi=True, sep=' ', to_screen=False, append_time=False, as_pickle=False):
    """
    Print args into log file in a convenient style.

    The target file is ``args.log_dir / get_name(type, append_time)``.

    - ``to_screen`` also echoes the inputs to stdout.
    - ``prob`` is the probability that this call writes anything to the file;
      nothing is written either when ``args`` has no ``log_dir``.
    - ``as_pickle`` dumps the single input with pickle and returns.
    - Otherwise the file is truncated the first time it is used in this
      process and appended to afterwards. Tensors are converted to numpy
      arrays first. With ``is_json`` each input is written as indented JSON;
      else the inputs are printed joined by ``sep`` and, when ``affi`` is set
      and a tensor was among them, the last tensor is also dumped as a JSON
      list.
    """
    if to_screen:
        print(*inputs, sep=sep)

    sampled = random.random() <= prob
    if not (sampled and hasattr(args, 'log_dir')):
        return

    file = os.path.join(args.log_dir, get_name(type, append_time))

    if as_pickle:
        with open(file, 'wb') as pickle_file:
            assert len(inputs) == 1
            pickle.dump(inputs[0], pickle_file)
        return

    if file not in files_written:
        # Opening in "w" mode just empties the file on first use.
        with open(file, "w", encoding='utf-8'):
            files_written[file] = 1

    inputs = list(inputs)
    the_tensor = None
    for position, item in enumerate(inputs):
        if not isinstance(item, torch.Tensor):
            continue
        # torch.Tensor(a), a must be Float tensor
        if item.is_cuda:
            item = item.cpu()
        the_tensor = inputs[position] = item.data.numpy()
    # Print arrays in full rather than with "..." elision.
    np.set_printoptions(threshold=np.inf)

    with open(file, "a", encoding='utf-8') as fout:
        if is_json:
            for item in inputs:
                print(json.dumps(item, indent=4), file=fout)
        else:
            print(*inputs, file=fout, sep=sep)
            if affi and the_tensor is not None:
                print(json.dumps(the_tensor.tolist()), file=fout)
            print(file=fout)


# Select matplotlib's non-interactive Agg backend; figures in this file are written with plt.savefig().
mpl.use('Agg')


def larger(a, b):
    """Return whether ``a`` exceeds ``b`` by more than the module tolerance ``eps``."""
    threshold = b + eps
    return a > threshold


def equal(a, b):
    """Return ``True`` when ``a`` and ``b`` differ by less than the module tolerance ``eps``."""
    return bool(abs(a - b) < eps)


def get_valid_lens(matrix: np.ndarray):
    """
    For every row, find where the zero padding starts.

    The result holds, per row, the first even column index ``j >= 2`` for which
    ``matrix[i][j]`` and ``matrix[i][j + 1]`` are both (approximately) zero.
    A row without such a pair trips the assertion.
    """
    valid_lens = []
    for i in range(matrix.shape[0]):
        row = matrix[i]
        first_zero_pair = next(
            (j for j in range(2, matrix.shape[1], 2) if equal(row[j], 0) and equal(row[j + 1], 0)),
            None,
        )
        assert first_zero_pair is not None
        valid_lens.append(first_zero_pair)
    return valid_lens


# Number of figures saved so far; incremented at the end of visualize_goals_2D().
visualize_num = 0


def rot(verts, rad):
    """
    Rotate row-vector vertices by ``rad`` radians.

    The rotation matrix is built for ``-rad`` and applied on the right
    (``verts . R``); the input is copied into a new array first.
    """
    theta = -rad
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return np.array(verts).dot(rotation)


class CustomMarker(Path):
    """
    Matplotlib marker path built from a fixed SVG icon, rotated to a heading.

    ``icon`` is accepted for interface compatibility but is not used: the same
    built-in SVG path is always drawn. ``az`` is the heading in radians; an
    extra half turn is added before rotating the icon.
    """

    def __init__(self, icon, az):
        import svgpath2mpl
        # The previous revision kept a commented-out full <svg> document here
        # whose comment was broken across two lines, leaving a bare
        # non-comment line (a syntax error). Only the path data is needed.
        svg = "M812.875093 411.578027l-0.003413 0.01536-43.562667-11.671894V203.436373c0-102.367573-112.216747-185.35424-250.63936-185.35424s-250.641067 82.986667-250.641066 185.35424l-0.360107 10.238294v187.89376l-41.89696 11.226453-0.006827-0.01536c-26.519893 7.120213-44.946773 24.33536-41.166506 38.469973l47.930026-12.84096 0.003414 0.013654 35.136853-9.413974v484.061867l0.360107 7.022933c0 48.899413 112.218453 88.546987 250.641066 88.546987s250.63936-39.645867 250.63936-88.546987V427.36128l36.800854 9.86112 0.00512-0.01536 47.92832 12.84096c3.77856-14.134613-14.64832-31.34976-41.168214-38.469973zM658.152107 87.01952c13.34272-9.344 37.459627 2.075307 53.86752 25.506133s18.889387 49.998507 5.543253 59.342507c-13.34272 9.347413-37.46304-2.0736-53.86752-25.50272-16.406187-23.432533-18.891093-50.00192-5.543253-59.34592z m65.14176 231.66976l-42.922667 82.507093c-88.410453-28.182187-231.00416-29.134507-323.060053-2.84672l-41.96352-79.786666c92.73856-87.42912 315.33056-87.386453 407.94624 0.126293zM325.08416 111.418027c16.406187-23.430827 40.521387-34.850133 53.865813-25.506134 13.346133 9.344 10.862933 35.91168-5.543253 59.342507-16.402773 23.430827-40.521387 34.850133-53.865813 25.504427-13.34784-9.340587-10.86464-35.909973 5.543253-59.3408zM307.2 348.16c28.352853 17.481387 41.51808 150.084267 38.674773 276.48H307.2V348.16z m0 501.76V648.533333h37.94432c-3.43552 88.183467-14.849707 169.470293-34.530987 201.386667h-3.413333z m15.423147 21.143893l47.071573-118.454613 0.116053-0.114347c32.37888 32.37888 262.442667 30.34112 295.401814-2.618026l1.11104 0.269653 49.6896 117.439147c-43.892053 43.88864-350.266027 46.600533-393.39008 3.478186zM737.08032 846.506667h-3.413333c-19.679573-31.916373-31.095467-113.2032-34.52928-201.386667h37.942613v201.386667z m0-225.28h-38.673067c-2.843307-126.395733 10.320213-258.998613 38.673067-276.48v276.48z"
        az = az + math.radians(180)
        # Parse once and reuse both the vertices and the path codes.
        parsed = svgpath2mpl.parse_path(svg)
        verts = parsed.vertices
        # Shift the icon so the midpoint of the x-range [180, 867] and the
        # y-range [18, 1008] lands on the origin before rotating.
        verts[:, 0] -= (867 - 180) / 2 + 180
        verts[:, 1] -= (1008 - 18) / 2 + 18
        vertices = rot(verts, az)
        super().__init__(vertices, codes=parsed.codes)


def visualize_goals_2D(mapping, goals_2D, scores: np.ndarray, future_frame_num, loss=None, labels: np.ndarray = None,
                       labels_is_valid=None, predict: np.ndarray = None):
    """
    Draw one prediction and save it as a PNG under ``args.log_dir``.

    The figure contains the lanes from ``mapping['vis_lanes']``, the history of
    the first trajectory in ``mapping['trajs']`` joined to the first label
    point, the candidate goals coloured by (clipped) score, the six predicted
    trajectories and the ground-truth future.

    Args:
        mapping: per-sample dict; reads ``file_name``, ``trajs``, ``vis_lanes``
            and, when present, ``speed`` and ``eval_time``.
        goals_2D: candidate goal points (N x 2) or ``None`` to skip the scatter.
        scores: one score per goal; clipped from below at ``log(1e-5)``.
        future_frame_num: number of future frames in ``labels`` / ``predict``.
        loss: value shown as ``FDE`` in the title and file name.
        labels: ground-truth future, shape ``(future_frame_num, 2)``.
        labels_is_valid: optional per-frame validity mask for ``labels``
            (required when ``mapping`` has ``eval_time``).
        predict: predictions reshapeable to ``(6, future_frame_num, 2)``.

    Side effects: writes the image, increments ``visualize_num`` and, after
    more than 200 figures (unless ``vis_video`` / ``vis_all`` is in
    ``args.other_params``), blocks on ``input()``.
    """
    global visualize_num

    print('in visualize_goals_2D', mapping['file_name'])
    # Fixed key typo: this used to look up 'seep' and therefore always printed None.
    print('speed', mapping.get('speed', None))

    assert predict is not None
    predict = predict.reshape([6, future_frame_num, 2])
    assert labels.shape == (future_frame_num, 2)

    if 'eval_time' in mapping:
        # Truncate everything to the evaluated horizon.
        assert labels.shape[0] == labels_is_valid.shape[0] == future_frame_num
        eval_time = mapping['eval_time']
        labels = labels[:eval_time]
        predict = predict[:, :eval_time, :]
        labels_is_valid = labels_is_valid[:eval_time]
        future_frame_num = eval_time

    if labels_is_valid is not None:
        assert labels.shape[0] == labels_is_valid.shape[0]
        # Keep only the frames flagged as valid.
        labels = np.array([label for label, valid in zip(labels, labels_is_valid) if valid])

    assert labels is not None
    # Flatten to [x0, y0, x1, y1, ...] for the strided plotting below.
    labels = labels.reshape([-1])

    fig_scale = 1.0
    marker_size_scale = 2
    target_agent_color = '#4bad34'
    linewidth = 5
    add_end = True

    def get_scaled_int(a):
        return round(a * fig_scale)

    plt.cla()
    plt.figure(0, figsize=(get_scaled_int(45), get_scaled_int(38)))

    plt.xlim(-100, 100)
    plt.ylim(-30, 100)

    # plt.get_cmap is available across matplotlib versions; matplotlib.cm.get_cmap
    # was deprecated and later removed.
    cmap = plt.get_cmap('Reds')
    vmin = np.log(0.00001)
    scores = np.clip(scores.copy(), a_min=vmin, a_max=np.inf)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=np.max(scores)))
    # The mappable is not attached to any artist, so name the axes explicitly.
    plt.colorbar(sm, ax=plt.gca())

    trajs = mapping['trajs']
    # `name` used to be unbound (UnboundLocalError) whenever args.argoverse was false.
    name = ''
    if args.argoverse:
        name = os.path.split(mapping['file_name'])[1].split('.')[0]
    name = name + '.FDE={}'.format(loss)

    for lane in mapping['vis_lanes']:
        lane = lane[:, :2]
        assert lane.shape == (len(lane), 2), lane.shape
        plt.plot(lane[:, 0], lane[:, 1], linestyle="-", color="black", marker=None,
                 markersize=0,
                 alpha=0.5,
                 linewidth=2,
                 zorder=0)

    for i, traj in enumerate(trajs):
        assert isinstance(traj, np.ndarray)
        assert traj.ndim == 2 and traj.shape[1] == 2, traj.shape
        if i != 0:
            # Only the first trajectory is drawn; the others are validated and skipped.
            continue
        flat = np.array(traj).reshape([-1])
        # History followed by the first ground-truth point, so the line is continuous.
        t = np.zeros(len(flat) + 2)
        t[:len(flat)] = flat
        t[-2] = labels[0]
        t[-1] = labels[1]
        plt.plot(t[0::2], t[1::2], linestyle="-", color=target_agent_color, marker=None,
                 alpha=1,
                 linewidth=linewidth,
                 zorder=0)

    if goals_2D is not None:
        goals_2D = np.array(goals_2D)
        marker_size = 70  # s is size, default 20
        plt.scatter(goals_2D[:, 0], goals_2D[:, 1], c=scores, cmap=cmap, norm=sm.norm, s=marker_size, alpha=0.5,
                    marker=',')

    for each in predict:
        function2 = plt.plot(each[:, 0], each[:, 1], linestyle="-", color="darkorange", marker=None,
                             linewidth=linewidth,
                             zorder=0, label='Predicted trajectory')

        if add_end:
            plt.plot(each[-1, 0], each[-1, 1], markersize=15 * marker_size_scale, color="darkorange", marker="*",
                     markeredgecolor='black')

    if add_end:
        plt.plot(labels[-2], labels[-1], markersize=15 * marker_size_scale, color=target_agent_color, marker="*",
                 markeredgecolor='black')

    function1 = plt.plot(labels[0::2], labels[1::2], linestyle="-", color=target_agent_color, linewidth=linewidth,
                         zorder=0, label='Ground truth trajectory')

    # One legend entry for the ground truth and one for the (last) prediction line.
    functions = function1 + function2
    fun_labels = [f.get_label() for f in functions]
    plt.legend(functions, fun_labels, loc=0)

    plt.title('FDE={} file_name={}'.format(loss, mapping['file_name']))
    ax = plt.gca()
    ax.set_aspect(1)
    ax.xaxis.set_major_locator(MultipleLocator(4))
    ax.yaxis.set_major_locator(MultipleLocator(4))

    out_dir = os.path.join(args.log_dir, 'visualize_' + time_begin)
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, get_name("visualize" + ("" if name == "" else "_" + name) + ".png")),
                bbox_inches='tight')
    plt.close()

    visualize_num += 1
    if visualize_num > 200 and 'vis_video' not in args.other_params and 'vis_all' not in args.other_params:
        print('press any key to continue')
        input()


def load_model(model, state_dict, prefix=''):
    """
    Load ``state_dict`` into ``model`` module by module and report mismatches.

    Every module is loaded via ``_load_from_state_dict`` with its dotted name
    as key prefix. Unless the module-level ``logger`` is ``None``, missing
    keys, unexpected keys and error messages are printed afterwards.
    """
    missing_keys, unexpected_keys, error_msgs = [], [], []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Metadata is keyed by the module path without the trailing dot.
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for child_name, child in module._modules.items():
            if child is None:
                continue
            load(child, prefix + child_name + '.')

    load(model, prefix)

    if logger is None:
        return

    model_name = model.__class__.__name__
    if len(missing_keys) > 0:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model_name, json.dumps(missing_keys, indent=4)))
    if len(unexpected_keys) > 0:
        print("Weights from pretrained model not used in {}: {}".format(
            model_name, json.dumps(unexpected_keys, indent=4)))
    if len(error_msgs) > 0:
        print('\n'.join(error_msgs))


# Declared `global` in batch_init() but never assigned in the visible code.
# NOTE(review): purpose unclear from here — confirm against other users.
traj_last = None


def batch_init(mapping):
    """
    Reset the module-level per-batch origin arrays from ``mapping``.

    For every sample ``i`` the global ``origin_point[i]`` becomes
    ``(-cent_x, -cent_y)`` rotated by the sample's ``angle`` and
    ``origin_angle[i]`` becomes ``-angle``.
    """
    global traj_last, origin_point, origin_angle
    batch_size = len(mapping)

    origin_point = np.zeros([batch_size, 2])
    origin_angle = np.zeros([batch_size])
    for i in range(batch_size):
        sample = mapping[i]
        rotated_x, rotated_y = rotate(0 - sample['cent_x'], 0 - sample['cent_y'], sample['angle'])
        origin_point[i][0] = rotated_x
        origin_point[i][1] = rotated_y
        origin_angle[i] = -sample['angle']

    # Helper kept from the original code; it is defined but not invoked in this function.
    def load_file2pred():
        global file2pred
        if len(file2pred) == 0:
            with open(args.other_params['set_predict_file2pred'], 'rb') as pickle_file:
                file2pred = pickle.load(pickle_file)


# NOTE(review): neither name below is read or written by live code in the visible part of this file
# (second_span only appears in a commented-out line in GlobalGraph.forward) — confirm usage elsewhere.
second_span = False

li_vector_num = None


def turn_traj(traj: np.ndarray, object_type='AGENT'):
    """
    Convert a trajectory into padded segment vectors.

    ``traj`` is reshaped to ``(-1, 2)``; each pair of consecutive points yields
    ``[prev_x, prev_y, x, y, i * 0.1, is_AV, is_AGENT, is_OTHERS, 0, i]``
    (``i`` is the index of the later point), padded by ``get_pad_vector``.
    """
    points = traj.reshape([-1, 2])
    is_av = object_type == 'AV'
    is_agent = object_type == 'AGENT'
    is_others = object_type == 'OTHERS'

    vectors = []
    for i in range(1, len(points)):
        previous, current = points[i - 1], points[i]
        vector = [previous[0], previous[1], current[0], current[1], i * 0.1,
                  is_av, is_agent, is_others, 0, i]
        vectors.append(get_pad_vector(vector))
    return vectors


def merge_tensors(tensors: List[torch.Tensor], device, hidden_size=None) -> Tuple[Tensor, List[int]]:
    """
    merge a list of tensors into a tensor

    Each entry is copied into a zero tensor of shape
    ``(len(tensors), max_length, hidden_size)``; ``None`` entries count as
    length 0 and stay all-zero. ``hidden_size`` falls back to
    ``args.hidden_size``. Returns the merged tensor and the per-entry lengths.
    """
    if hidden_size is None:
        hidden_size = args.hidden_size
    lengths = [0 if tensor is None else tensor.shape[0] for tensor in tensors]
    res = torch.zeros([len(tensors), max(lengths), hidden_size], device=device)
    for row, tensor in enumerate(tensors):
        if tensor is None:
            continue
        res[row][:tensor.shape[0]] = tensor
    return res, lengths


def de_merge_tensors(tensor: Tensor, lengths):
    """Split a padded batch back into a list, keeping the first ``lengths[i]`` rows of entry ``i``."""
    pieces = []
    for row, length in enumerate(lengths):
        pieces.append(tensor[row, :length])
    return pieces


def gather_tensors(tensor: torch.Tensor, indices: List[list]):
    """
    Reorder the rows of each batch entry according to ``indices``.

    For batch entry ``i`` the output rows ``0 .. len(indices[i]) - 1`` are
    ``tensor[i, indices[i][k]]``; the remaining rows are zero. The output has
    the same shape as ``tensor`` (indexed as batch x rows x features here).

    The caller's ``indices`` lists are no longer padded in place: padding is
    applied to copies, so the argument is left untouched.

    Returns:
        ``(gathered, lengths)`` where ``lengths[i] == len(indices[i])``.
    """
    lengths = [len(each) for each in indices]
    assert tensor.shape[0] == len(indices)
    row_count = tensor.shape[1]
    # Pad with index 0 so every entry has row_count indices; the padded rows are zeroed below.
    padded = [list(each) + [0] * (row_count - len(each)) for each in indices]
    index = torch.tensor(padded, device=tensor.device)
    # Broadcast each row index over the feature axis, as torch.gather expects.
    index = index.unsqueeze(2).expand(tensor.shape)
    tensor = torch.gather(tensor, 1, index)
    for i, length in enumerate(lengths):
        tensor[i, length:, :].fill_(0)
    return tensor, lengths


def get_closest_polygon(pred: np.ndarray, new_polygons) -> np.ndarray:
    """
    Return the polygon that lies closest to the predicted trajectory.

    The distance of a polygon is the sum, over all predicted points, of the
    smallest L1 (Manhattan) distance from that point to any polygon vertex.
    The first polygon with the strictly smallest total wins; ``None`` is
    returned when ``new_polygons`` is empty.

    ``pred`` may hold any number of 2D points (it is reshaped to ``(-1, 2)``);
    previously exactly 30 points were required.
    """
    points = pred.reshape([-1, 2])

    def trajectory_to_polygon_distance(polygon):
        total = 0
        for point in points:
            total += np.min(np.abs(polygon[:, 0] - point[0]) + np.abs(polygon[:, 1] - point[1]))
        return total

    best_distance = np.inf
    closest_polygon = None
    for polygon in new_polygons:
        distance = trajectory_to_polygon_distance(polygon)
        if distance < best_distance:
            best_distance = distance
            closest_polygon = polygon
    return closest_polygon


# NOTE(review): none of the four constants below is read in the visible part of this file.
# Presumably they are candidate NMS distance thresholds and the index/start values used when
# searching over them — confirm against the code that consumes them.
NMS_LIST = [2.0, 1.7, 1.4, 2.3, 2.6] + [2.9, 3.2, 3.5, 3.8, 4.1] + [2.7, 2.8, 3.0, 3.1]

NMS_START = 6

DYNAMIC_NMS_START = 30

DYNAMIC_NMS_LIST = [3.2, 3.8, 4.8, 5.4, 6.0] + [6.6, 7.2, 7.8, 8.4, 0.0] + \
                   [2.0, 2.6, 1.5, 0.1]


def select_goals_by_NMS(mapping: Dict, goals_2D: np.ndarray, scores: np.ndarray, threshold, speed, gt_goal=None, mode_num=6):
    """
    Pick ``mode_num`` goals by greedy non-maximum suppression.

    Goals are visited in descending score order; a goal is kept unless it lies
    closer than ``threshold`` to an already kept one. If fewer than
    ``mode_num`` goals survive, the rest are filled with uniformly random
    goals. The selection is stored in ``mapping['pred_goals']`` and
    ``mapping['pred_probs']``; nothing is returned. ``speed`` is currently not
    used (the speed scale factor is fixed to 1).
    """
    order = np.argsort(-scores)
    goals_2D = goals_2D[order]
    scores = scores[order]

    add_eval_param(f'DY_NMS={threshold}')

    speed_scale_factor = 1  # utils_cython.speed_scale_factor(speed)
    threshold = threshold * speed_scale_factor

    def in_predict(pred_goals, point, threshold):
        return np.min(get_dis_point_2_points(point, pred_goals)) < threshold

    pred_goals = []
    pred_probs = []
    for candidate, score in zip(goals_2D, scores):
        if pred_goals and in_predict(np.array(pred_goals), candidate, threshold):
            # Suppressed: too close to a goal that was already selected.
            continue
        pred_goals.append(candidate)
        pred_probs.append(score)
        if len(pred_goals) == mode_num:
            break

    # Not enough distinct goals: pad with random candidates.
    while len(pred_goals) < mode_num:
        i = np.random.randint(0, len(goals_2D))
        pred_goals.append(goals_2D[i])
        pred_probs.append(scores[i])

    pred_goals = np.array(pred_goals)
    pred_probs = np.array(pred_probs)

    # Best final displacement to the ground truth; computed but not stored or returned.
    FDE = np.inf
    if gt_goal is not None:
        for each in pred_goals:
            FDE = min(FDE, get_dis_point2point(each, gt_goal))

    mapping['pred_goals'] = pred_goals
    mapping['pred_probs'] = pred_probs


def select_goal_pairs_by_NMS(mapping: Dict, mapping_oppo: Dict, goals_4D: np.ndarray, scores_4D: np.ndarray, threshold, speed, speed_oppo,
                             mode_num=6):
    """
    Pick ``mode_num`` joint goal pairs by greedy non-maximum suppression.

    Each row of ``goals_4D`` is reshaped to ``(2, 2)``: row 0 is the goal of
    the agent described by ``mapping`` and row 1 the goal of the agent
    described by ``mapping_oppo``. Pairs are visited in descending score
    order; a pair is suppressed only when BOTH of its goals lie within the
    respective speed-scaled threshold of an already selected goal.

    If fewer than ``mode_num`` pairs survive, the rest are filled with
    uniformly random candidates from ``goals_4D`` (same policy as
    ``select_goals_by_NMS``). The previous code drew the random index from
    ``len(pred_goal_pairs)`` while indexing ``goals_4D`` with it.

    Results are written to ``pred_goals`` / ``pred_probs`` of both mapping
    dicts; nothing is returned.
    """
    order = np.argsort(-scores_4D)

    goals_4D = goals_4D[order]
    scores_4D = scores_4D[order]

    def in_predict(pred_goal_pairs, goal_pair, thresholds):
        # pred_goal_pairs [..., 2, 2]
        return np.min(get_dis_point_2_points(goal_pair[0], pred_goal_pairs[:, 0, :])) < thresholds[0] \
               and np.min(get_dis_point_2_points(goal_pair[1], pred_goal_pairs[:, 1, :])) < thresholds[1]

    add_eval_param(f'DY_NMS={threshold}')

    thresholds = (threshold * utils_cython.speed_scale_factor(speed), threshold * utils_cython.speed_scale_factor(speed_oppo))

    pred_goal_pairs = []
    pred_probs = []

    for candidate, score in zip(goals_4D, scores_4D):
        goal_pair = candidate.reshape((2, 2))
        if pred_goal_pairs and in_predict(np.array(pred_goal_pairs), goal_pair, thresholds):
            continue
        pred_goal_pairs.append(goal_pair)
        pred_probs.append(score)
        if len(pred_goal_pairs) == mode_num:
            break

    # Not enough distinct pairs: pad with random candidates from the full pool.
    while len(pred_goal_pairs) < mode_num:
        i = np.random.randint(0, len(goals_4D))
        pred_goal_pairs.append(goals_4D[i].reshape((2, 2)))
        pred_probs.append(scores_4D[i])

    pred_goal_pairs = np.array(pred_goal_pairs)
    pred_probs = np.array(pred_probs)

    mapping['pred_goals'] = pred_goal_pairs[:, 0, :]
    mapping['pred_probs'] = pred_probs
    mapping_oppo['pred_goals'] = pred_goal_pairs[:, 1, :]
    mapping_oppo['pred_probs'] = pred_probs


# Exponent p of the re-weighting ``a -> a + a ** p`` used by get_FDE for methods 1..5.
_FDE_RESCALE_EXPONENTS = {1: 1.2, 2: 0.9, 3: 1.1, 4: 0.8, 5: 0.7}


def _rescale_top_scores(points, scores, exponent):
    """Keep candidates with probability >= 0.001 and re-weight them, preserving their total mass.

    ``scores`` must be sorted in descending order (searchsorted relies on it).
    """
    idx = np.searchsorted(-scores, -0.001, side='right')
    scores, points = scores[:idx], points[:idx]
    scaled_scores = scores + scores ** exponent
    # Renormalise so that the re-weighted scores sum to the same value as before.
    scaled_scores = scaled_scores / np.sum(scaled_scores) * np.sum(scores)
    return points, scaled_scores


def _select_goals_for_FDE(points, scores, threshold, gt_goal, method, idx_in_batch, mode_num):
    """Greedy NMS over score-sorted ``points``; records the selection and its FDE in module-level dicts."""
    predict = []
    ans_point_scores = []

    for i in range(len(points)):
        # Skip candidates closer than ``threshold`` to an already selected goal.
        if len(predict) > 0 and np.min(get_dis_point_2_points(points[i], np.array(predict))) < threshold:
            continue
        predict.append(points[i])
        ans_point_scores.append(scores[i])
        if len(predict) == mode_num:
            break
    # Pad with random candidates when suppression left fewer than mode_num goals.
    while len(predict) < mode_num:
        i = np.random.randint(0, len(points))
        predict.append(points[i])
        ans_point_scores.append(scores[i])

    idx_in_batch_2_ans_points[idx_in_batch] = np.array(predict)
    idx_in_batch_2_ans_point_scores[idx_in_batch] = np.array(ans_point_scores)
    FDE = np.inf
    if gt_goal is not None:
        for each in predict:
            FDE = min(FDE, get_dis_point2point(each, gt_goal))
    method2FDEs[method].append(FDE)


def get_FDE(points: np.ndarray, scores: np.ndarray, mapping, gt_goal=None, method=0, idx_in_batch=0, mode_num=6):
    """
    Post-process goal candidates according to ``method``; returns nothing.

    ``points``/``scores`` are copied, sorted by descending score, optionally
    divided by ``mapping['scale']``, and the scores are exponentiated.

    * ``method`` in 1..5: drop candidates with probability below 0.001 and
      re-weight the rest with ``a + a ** p`` (p depends on the method).
    * ``NMS_START <= method < NMS_START + len(NMS_LIST)``: greedy NMS with a
      fixed threshold; the selection is stored in
      ``idx_in_batch_2_ans_points`` / ``idx_in_batch_2_ans_point_scores`` and
      the resulting FDE appended to ``method2FDEs[method]``.
    * the ``DYNAMIC_NMS_*`` range: same, with the threshold scaled by
      ``utils_cython.speed_scale_factor(mapping['speed'])``.
    * anything else fails an assertion.

    For ``method < 6`` the (possibly re-weighted) candidates are also dumped
    to ``<args.temp_file_dir>/<time_begin>/cpp_input<idx_in_batch>``.
    """
    points = points.copy()
    scores = scores.copy()

    li = sorted([(point, score) for (point, score) in zip(points, scores)], key=lambda x: x[1], reverse=True)
    points = np.array([each[0] for each in li])
    if 'scale' in mapping:
        scale = mapping['scale']
        points *= 1.0 / scale

    scores = np.exp(np.array([each[1] for each in li]))

    if method in _FDE_RESCALE_EXPONENTS:
        points, scores = _rescale_top_scores(points, scores, _FDE_RESCALE_EXPONENTS[method])
    elif NMS_START <= method < NMS_START + len(NMS_LIST):
        threshold = NMS_LIST[method - NMS_START]
        add_eval_param(f'NMS={threshold}')
        _select_goals_for_FDE(points, scores, threshold, gt_goal, method, idx_in_batch, mode_num)
    elif DYNAMIC_NMS_START <= method < DYNAMIC_NMS_START + len(DYNAMIC_NMS_LIST):
        threshold = DYNAMIC_NMS_LIST[method - DYNAMIC_NMS_START]
        add_eval_param(f'DY_NMS={threshold}')
        threshold = threshold * utils_cython.speed_scale_factor(mapping['speed'])
        _select_goals_for_FDE(points, scores, threshold, gt_goal, method, idx_in_batch, mode_num)
    else:
        assert False

    if method < 6:
        # Note 'method > 0' in train.py
        with open(os.path.join(args.temp_file_dir, time_begin, "cpp_input" + str(idx_in_batch)), "w") as fout:
            print(len(points), file=fout)
            for point, score in zip(points, scores):
                print(point[0], point[1], score, file=fout)


def get_subdivide_points(polygon, include_self=False, threshold=1.0, include_beside=False, return_unit_vectors=False):
    """
    Insert evenly spaced points on every segment of ``polygon``.

    Every segment is cut into ``divide_num`` equal parts, where ``divide_num``
    is the smallest integer for which the polyline's average segment length
    divided by it does not exceed ``threshold``.

    :param include_self: also return the original vertices.
    :param include_beside: also return the original vertices plus points
        shifted 1 and 2 units to both sides along the rotated direction.
    :param return_unit_vectors: additionally return, per generated point, the
        unit direction of its segment (not combinable with the two flags above).
    :return: list of ``(x, y)`` tuples, or ``(points, unit_vectors)``.
    """

    def segment_length(start, end):
        return np.sqrt((start[0] - end[0]) ** 2 + (start[1] - end[1]) ** 2)

    def interpolate(start, end, ratio):
        return (start[0] * (1 - ratio) + end[0] * ratio,
                start[1] * (1 - ratio) + end[1] * ratio)

    average_dis = 0
    previous = None
    for index, current in enumerate(polygon):
        if index > 0:
            average_dis += segment_length(current, previous)
        previous = current
    average_dis /= len(polygon) - 1

    points = []
    unit_vectors = []
    if return_unit_vectors:
        assert not include_self and not include_beside

    divide_num = 1
    while average_dis / divide_num > threshold:
        divide_num += 1

    keep_vertices = include_self or include_beside
    previous = None
    for index, current in enumerate(polygon):
        if index > 0:
            for step in range(1, divide_num):
                points.append(interpolate(previous, current, step / divide_num))
                if return_unit_vectors:
                    unit_vectors.append(get_unit_vector(previous, current))
        if keep_vertices:
            points.append(current)
        previous = current

    if include_beside:
        side_points = []
        previous = None
        for index, current in enumerate(points):
            if index > 0:
                dx = current[0] - previous[0]
                dy = current[1] - previous[1]
                inverse_length = 1 / math.sqrt(dx ** 2 + dy ** 2)
                dx, dy = rotate(dx * inverse_length, dy * inverse_length, math.pi / 2)
                for k in (-2, -1, 1, 2):
                    side_points.append((current[0] + k * dx, current[1] + k * dy))
                    # The very first point has no predecessor of its own, so shift it here.
                    if index == 1:
                        side_points.append((previous[0] + k * dx, previous[1] + k * dy))
            previous = current
        points.extend(side_points)

    if return_unit_vectors:
        return points, unit_vectors
    return points


def get_one_subdivide_polygon(polygon):
    """Return ``polygon`` with the midpoint of every consecutive pair inserted between them.

    Points must support ``+`` and ``/`` (e.g. numpy arrays).
    """
    refined = []
    for index in range(len(polygon)):
        if index > 0:
            refined.append((polygon[index - 1] + polygon[index]) / 2)
        refined.append(polygon[index])
    return refined


def get_subdivide_polygons(polygon, threshold=2.0):
    """
    Recursively split ``polygon`` into pieces whose average segment length is
    at most ``threshold``.

    A single-point polygon is duplicated into two points; an odd-length
    polygon has its middle point dropped so the length becomes even. When the
    average segment length exceeds ``threshold`` the polygon is cut in the
    middle, each half is refined with ``get_one_subdivide_polygon`` and the
    halves are processed recursively.

    :param polygon: sequence of points supporting ``+`` and ``/`` (e.g. numpy arrays).
    :param threshold: maximum allowed average segment length; now also used
        by the recursive calls (they previously fell back to the default 2.0).
    :return: list of polygons.
    """
    if len(polygon) == 1:
        polygon = [polygon[0], polygon[0]]
    elif len(polygon) % 2 == 1:
        polygon = list(polygon)
        polygon = polygon[:len(polygon) // 2] + polygon[-(len(polygon) // 2):]
    assert_(len(polygon) >= 2)

    def get_dis(point_a, point_b):
        return np.sqrt((point_a[0] - point_b[0]) ** 2 + (point_a[1] - point_b[1]) ** 2)

    def get_average_dis(polygon):
        average_dis = 0
        for i, point in enumerate(polygon):
            if i > 0:
                average_dis += get_dis(point, point_pre)
            point_pre = point
        average_dis /= len(polygon) - 1
        return average_dis

    average_dis = get_average_dis(polygon)

    if average_dis <= threshold:
        return [polygon]

    length = len(polygon)
    point_a = polygon[length // 2 - 1]
    point_b = polygon[length // 2]
    # Both halves share the midpoint of the cut so that they stay connected.
    point_mid = (point_a + point_b) / 2
    polygon_a = get_one_subdivide_polygon(polygon[:length // 2]) + [point_mid]
    polygon_b = [point_mid] + get_one_subdivide_polygon(polygon[length // 2:])
    assert_(len(polygon) == len(polygon_a))
    return get_subdivide_polygons(polygon_a, threshold) + get_subdivide_polygons(polygon_b, threshold)


# Maps a goal-selection method id to the list of FDE values recorded for it
# (appended to by get_FDE and select_goals_by_optimization).
method2FDEs = defaultdict(list)


def get_neighbour_points(points, topk_ids=None, mapping=None, neighbour_dis=2):
    """
    Return the de-duplicated integer grid cells within ``neighbour_dis`` (per
    axis) of every rounded input point, in first-seen order.

    ``topk_ids`` and ``mapping`` are accepted but not used.
    """
    offsets = range(-neighbour_dis, neighbour_dis + 1)
    seen = {}
    for point in points:
        centre_x = round(float(point[0]))
        centre_y = round(float(point[1]))

        # not compatible argo
        for dx in offsets:
            for dy in offsets:
                seen[(centre_x + dx, centre_y + dy)] = 1
    return list(seen)


def get_neighbour_points_new(points, neighbour_dis=2, density=1.0):
    """
    Sample a grid with spacing ``density`` in the square of half-width
    ``neighbour_dis`` around every rounded input point.

    Points whose rounded coordinates fall outside ``[-100, 100]`` on either
    axis are ignored. The result is de-duplicated via
    ``get_points_remove_repeated`` using ``density`` as cell size.
    """
    cells = {}

    for point in points:
        x, y = round(float(point[0])), round(float(point[1]))
        if not (-100 <= x <= 100 and -100 <= y <= 100):
            continue
        grid_x = x - neighbour_dis
        while grid_x < x + neighbour_dis + eps:
            grid_y = y - neighbour_dis
            while grid_y < y + neighbour_dis + eps:
                cells[(grid_x, grid_y)] = True
                grid_y += density
            grid_x += density
    return get_points_remove_repeated(list(cells.keys()), density)


def get_neighbour_points_for_lanes(polygons):
    """Flatten all lane polygons into one point list and return its grid neighbourhood."""
    flattened = [point for polygon in polygons for point in polygon]
    return get_neighbour_points(flattened)


def calc_bitmap(bitmap, polygon):
    """Rasterise every consecutive segment of ``polygon`` into ``bitmap`` via ``walk_bitmap``."""
    is_first = True
    previous = None
    for current in polygon:
        if not is_first:
            walk_bitmap(bitmap, previous, current, calc_bitmap=True)
        is_first = False
        previous = current


def walk_bitmap(bitmap, point_a, point_b, calc_bitmap=False, check_bitmap=False):
    """
    Walk on the integer grid from ``point_a`` to ``point_b`` in unit steps.

    Coordinates are rounded and shifted by 150; only cells inside the
    300 x 300 window are touched. Each step moves to the 4-neighbour closest
    to the target.

    :param calc_bitmap: mark every visited in-window cell with 1.
    :param check_bitmap: return True as soon as a visited in-window cell is set.
    :return: True if ``check_bitmap`` hit a set cell, otherwise False.
    """
    current = (round(float(point_a[0])) + 150, round(float(point_a[1])) + 150)
    target = (round(float(point_b[0])) + 150, round(float(point_b[1])) + 150)
    # Order matters: on ties the earliest direction wins.
    steps = ((0, 1), (0, -1), (1, 0), (-1, 0))

    def distance_to_target(cell):
        return np.sqrt((cell[0] - target[0]) ** 2 + (cell[1] - target[1]) ** 2)

    while True:
        row, col = current
        if 0 <= row < 300 and 0 <= col < 300:
            if calc_bitmap:
                bitmap[row][col] = 1
            if check_bitmap and bitmap[row][col]:
                return True
        if current == target:
            return False
        current = min(((row + d_row, col + d_col) for d_row, d_col in steps), key=distance_to_target)


def get_unit_vector(point_a, point_b):
    """Return the unit-length direction ``(dx, dy)`` pointing from ``point_a`` to ``point_b``."""
    delta_x = point_b[0] - point_a[0]
    delta_y = point_b[1] - point_a[1]
    inverse_length = 1 / math.sqrt(delta_x ** 2 + delta_y ** 2)
    return (delta_x * inverse_length, delta_y * inverse_length)


# Per-sample results of the NMS branches of get_FDE, keyed by the sample's
# index in the batch: the selected goal points and their scores.
idx_in_batch_2_ans_points = {}
idx_in_batch_2_ans_point_scores = {}


def run_process_todo(queue, queue_res, speed=None, eval_time=None):
    """Placeholder worker: only logs that it was entered, with a timestamp and a random tag."""
    worker_tag = np.random.randint(5)
    print('in run_process_todo', get_time(), worker_tag)


def run_process(queue, queue_res, args):
    """
    Worker loop that turns goal candidates into optimised target sets.

    Reads ``(idx_in_batch, file_name, (goals_2D, scores), kwargs)`` items from
    ``queue`` until ``None`` arrives, runs
    ``utils_cython.get_optimal_targets`` on each and puts
    ``(idx_in_batch, expectation, ans_points, pred_probs)`` on ``queue_res``.

    Options read from ``args.other_params``:

    * ``'MRminFDE'`` switches the objective from ``'MR'`` to ``'MRminFDE'``
      (requires ``'cnt_sample'``); its value is the MR ratio, with ``True``
      meaning 1.0.
    * ``'opti_time'``: optimisation time budget (default 10000.0).
    * ``'cnt_sample'``: when present, sampling parameters are added to the
      item's ``kwargs``.
    """
    worker_tag = np.random.randint(5)
    utils_cython.args = args
    other_params = args.other_params

    use_mr_min_fde = 'MRminFDE' in other_params
    objective = 'MRminFDE' if use_mr_min_fde else 'MR'
    opti_time = float(other_params.get('opti_time', 10000.0))

    # MRratio used to be bound only when 'MRminFDE' was set, so 'cnt_sample'
    # without 'MRminFDE' raised NameError. Default to the same 1.0 that a bare
    # 'MRminFDE' flag means.
    MRratio = 1.0
    if use_mr_min_fde:
        assert 'cnt_sample' in other_params
        if other_params['MRminFDE'] is not True:
            MRratio = float(other_params['MRminFDE'])

    while True:
        value = queue.get()
        if value is None:
            break
        idx_in_batch, file_name, (goals_2D, scores), kwargs = value
        scores = np.exp(scores)
        # Debug output for one specific sample, kept from the original code.
        if file_name == 'test_obs/data/33670.csv':
            print('aaa', len(scores), np.sum(scores), scores, goals_2D)

        if 'cnt_sample' in other_params:
            kwargs.update(dict(
                num_step=1000,
                cnt_sample=other_params['cnt_sample'],
                MRratio=MRratio,
            ))
            assert other_params['cnt_sample'] > 1

        expectation, ans_points, pred_probs = utils_cython.get_optimal_targets(
            goals_2D, scores, file_name, objective, opti_time, kwargs=kwargs)
        queue_res.put((idx_in_batch, expectation, ans_points, pred_probs))

    print('out run_process', get_time(), worker_tag)


def select_goals_by_optimization(batch_gt_points, mapping, close=False):
    """
    Select goal sets for a whole batch with a pool of ``run_process`` workers.

    On the first call ``args.core_num`` worker processes are started and kept
    as attributes of this function (``processes``, ``queue``, ``queue_res``),
    so later calls reuse them.

    :param batch_gt_points: array unpacked as ``(batch_size, future_frame_num, _)``;
        the last frame of each sample is used as ground-truth goal.
    :param mapping: per-sample dicts; reads ``'file_name'`` and ``'goals_2D_scores'``.
    :param close: when True, send one ``None`` sentinel per worker, join the
        workers and return ``None`` without processing anything.
    :return: ``(batch_ans_points, batch_pred_probs)`` with shapes
        ``[batch_size, 6, 2]`` and ``[batch_size, 6]``.
    """
    this = select_goals_by_optimization
    if not hasattr(this, 'processes'):
        # if end:
        #     return
        # The task queue is bounded by the number of workers; results are unbounded.
        queue = multiprocessing.Queue(args.core_num)
        queue_res = multiprocessing.Queue()
        processes = [
            Process(target=run_process, args=(queue, queue_res, args,))
            for _ in range(args.core_num)]
        for each in processes:
            each.start()
        this.processes = processes
        this.queue = queue
        this.queue_res = queue_res

    queue = this.queue
    queue_res = this.queue_res

    if close:
        for i in range(args.core_num):
            queue.put(None)
        for each in select_goals_by_optimization.processes:
            each.join()
        return

    start_time = time.time()
    batch_size, future_frame_num, _ = batch_gt_points.shape

    batch_file_name = get_from_mapping(mapping, 'file_name')

    assert args.core_num >= 2

    # Every sample is submitted run_times times; the best run is kept below.
    run_times = 8
    for _ in range(run_times):
        for i in range(batch_size):
            kwargs = {}
            pass

            queue.put((i, batch_file_name[i], mapping[i]['goals_2D_scores'], kwargs))

    # Busy-wait until the workers have taken every task off the queue.
    while not queue.empty():
        pass

    expectations = np.ones(batch_size) * 10000.0
    batch_ans_points = np.zeros([batch_size, 6, 2])
    batch_pred_probs = np.zeros([batch_size, 6])
    # Collect all results and keep, per sample, the one with the lowest expectation.
    for _ in range(run_times * batch_size):
        i, expectation, ans_points, pred_probs = queue_res.get()
        if expectation < expectations[i]:
            expectations[i] = expectation
            batch_ans_points[i] = ans_points
            batch_pred_probs[i] = pred_probs

    # print('here', round(time.time() - start_time, 2))

    for i in range(batch_size):
        # FDE stays infinite in test mode.
        FDE = np.inf
        if not args.do_test:
            FDE = np.min(get_dis_point_2_points(batch_gt_points[i][-1], batch_ans_points[i]))
        method2FDEs[0].append(FDE)

        # NOTE(review): this transformed copy is not used afterwards in this function.
        ans_points = batch_ans_points[i].copy()
        if args.argoverse:
            to_origin_coordinate(ans_points, i)

    return batch_ans_points, batch_pred_probs


def to_origin_coordinate(points, idx_in_batch, scale=None):
    """
    Transform ``points`` in place using the origin and angle stored for sample
    ``idx_in_batch`` in the module-level ``origin_point`` / ``origin_angle``.

    Each point is shifted by the origin, rotated by the angle and, if
    ``scale`` is given, multiplied by it.
    """
    for point in points:
        origin = origin_point[idx_in_batch]
        rotated = rotate(point[0] - origin[0], point[1] - origin[1], origin_angle[idx_in_batch])
        point[0], point[1] = rotated
        if scale is None:
            continue
        point[0] *= scale
        point[1] *= scale


def to_relative_coordinate(points, x, y, angle):
    """Shift every point by ``(-x, -y)`` and rotate it by ``angle``, in place."""
    for point in points:
        shifted_x = point[0] - x
        shifted_y = point[1] - y
        point[0], point[1] = rotate(shifted_x, shifted_y, angle)


def get_time():
    """Return the current local time formatted as ``YYYY-mm-dd-HH-MM-SS``."""
    now = time.localtime()
    return time.strftime("%Y-%m-%d-%H-%M-%S", now)


# Timestamp taken once, when this module is first executed.
time_begin = get_time()


def assert_(satisfied, info=None):
    """
    Assert ``satisfied``; on failure first print ``info`` (if given) followed
    by this file's name and the caller's line number.
    """
    if satisfied:
        return
    if info is not None:
        print(info)
    frame = sys._getframe()
    print(frame.f_code.co_filename, frame.f_back.f_lineno)
    assert satisfied


def get_miss_rate(li_FDE, dis=2.0):
    """Return the fraction of FDE values strictly greater than ``dis``, or None for an empty list."""
    if len(li_FDE) == 0:
        return None
    missed = np.array(li_FDE) > dis
    return np.sum(missed) / len(li_FDE)


def ids_to_matrix(ids_list: List[List[int]], size, device):
    """
    Build a ``[len(ids_list), size]`` float matrix whose row ``r`` holds 1.0
    at every index listed in ``ids_list[r]`` and 0.0 elsewhere.
    """
    matrix = torch.zeros([len(ids_list), size], device=device)
    for row_idx, ids in enumerate(ids_list):
        if len(ids) > 0:
            positions = torch.tensor(ids, device=device)
            matrix[row_idx].scatter_(0, positions, 1.0)
    return matrix


def get_max_hidden(hidden_states: Tensor, pooling_mask: Tensor):
    """
    Masked max-pooling.

    :param hidden_states: ``[num_key, hidden_size]``
    :param pooling_mask: ``[num_query, num_key]``; entries of 1 select a key,
        entries of 0 push it down by 10000 before the max.
    :return: ``[num_query, hidden_size]`` element-wise maximum over the keys.
    """
    num_query, num_key = pooling_mask.shape[0], pooling_mask.shape[1]
    assert num_key == hidden_states.shape[0]
    hidden_size = hidden_states.shape[1]
    penalty = (1.0 - pooling_mask) * -10000.0
    tiled = hidden_states.unsqueeze(0).expand([num_query, num_key, hidden_size])
    masked = tiled + penalty.unsqueeze(2)
    return masked.max(dim=1)[0]


# Holds everything but the first model output when running distributed training.
return_values = None


def model_return(*inputs):
    """
    Return the model outputs.

    With ``args.distributed_training`` only the first value is returned and
    the remaining ones are stashed in the module-level ``return_values``;
    otherwise the whole tuple is returned unchanged.
    """
    global return_values
    if not args.distributed_training:
        return inputs
    return_values = inputs[1:]
    return inputs[0]


def get_color_text(text, color='red'):
    """Wrap ``text`` in ANSI escape codes for the given colour; only 'red' is supported."""
    if color != 'red':
        assert False
        return None
    red, reset = "\033[31m", "\033[0m"
    return red + text + reset


# Error values collected per error type.
other_errors_dict = defaultdict(list)


def other_errors_put(error_type, error):
    """Record one ``error`` value under ``error_type``."""
    bucket = other_errors_dict[error_type]
    bucket.append(error)


def other_errors_to_string():
    """Return the string form of a dict mapping every error type to the mean of its values."""
    means = {error_type: np.mean(errors) for error_type, errors in other_errors_dict.items()}
    return str(means)


def get_points_remove_repeated(points, threshold=1.0):
    """
    Snap points to a grid of cell size ``threshold`` and drop duplicates.

    :return: one ``(x, y)`` float tuple per occupied grid cell, in first-seen order.
    """
    cells = {}
    for point in points:
        cell = (round(point[0] / threshold), round(point[1] / threshold))
        cells[cell] = True
    return [(float(cell_x * threshold), float(cell_y * threshold)) for cell_x, cell_y in cells]


def get_dis_point_2_points(point, points):
    """Return the Euclidean distance from ``point`` to every row of the 2-D array ``points``."""
    assert points.ndim == 2
    squared_dx = np.square(points[:, 0] - point[0])
    squared_dy = np.square(points[:, 1] - point[1])
    return np.sqrt(squared_dx + squared_dy)


def get_dis_point_2_polygons(point, polygons):
    """
    Return, for every polygon, the distance from ``point`` to its nearest vertex.

    Each polygon must be a 2-D array with x in column 0 and y in column 1.
    """
    nearest = np.zeros(len(polygons))
    for idx, polygon in enumerate(polygons):
        squared = np.square(polygon[:, 0] - point[0]) + np.square(polygon[:, 1] - point[1])
        nearest[idx] = np.min(np.sqrt(squared))
    return nearest


# Keep a handle on the builtin before it is shadowed by the strict version below.
_zip = zip


def zip(*inputs):
    """Like the builtin ``zip`` but asserts that all inputs have the same length."""
    if inputs:
        expected = len(inputs[0])
        for each in inputs:
            assert len(each) == expected
    return _zip(*inputs)


def zip_enum(*inputs):
    """Strict ``zip`` that additionally yields the running index as first element."""
    length = len(inputs[0])
    assert all(len(each) == length for each in inputs)
    return zip(range(length), *inputs)


def point_in_points(point, points):
    """
    Tell whether ``point`` lies within ``1.0 + eps`` of any of ``points``.

    Returns False when ``points`` does not convert to a 2-D array.
    """
    candidates = np.array(points)
    if candidates.ndim != 2:
        return False
    nearest = np.min(get_dis_point_2_points(point, candidates))
    return nearest < 1.0 + eps


def get_pseudo_label(predicts, labels, self_cost=None, kwargs=None):
    """
    Compute pseudo labels for ``predicts`` via ``utils_cython.get_pseudo_label``.

    :param predicts: numpy array of predictions (cast to float32).
    :param labels: a numpy array, or a list of arrays; for a list the
        candidate with the lowest cost is chosen.
    :param self_cost: optional per-prediction cost; defaults to zeros.
    :param kwargs: extra options forwarded unchanged. Defaults to a fresh dict
        per call (the former ``{}`` default was one object shared by all calls).
    :return: for a list of labels ``(pseudo_label, cost, None)`` of the best
        candidate; otherwise whatever ``utils_cython.get_pseudo_label`` returns.
    """
    if kwargs is None:
        kwargs = {}
    if self_cost is None:
        self_cost = np.zeros(len(predicts))

    if not isinstance(labels, list):
        return utils_cython.get_pseudo_label(predicts.astype(np.float32), labels.astype(np.float32), self_cost.astype(np.float32), kwargs)

    cost_list = []
    pseudo_label_list = []
    for each in labels:
        pseudo_label, cost, _ = \
            utils_cython.get_pseudo_label(predicts.astype(np.float32), each.astype(np.float32), self_cost.astype(np.float32), kwargs)
        pseudo_label_list.append(pseudo_label)
        cost_list.append(cost)

    argmin = np.argmin(np.array(cost_list))
    return pseudo_label_list[argmin], cost_list[argmin], None


def get_file_name_int(file_name):
    """Return the integer encoded in the file's base name, ignoring its last four characters (e.g. ``.csv``)."""
    base_name = os.path.basename(file_name)
    stem = base_name[:-4]
    return int(stem)


def assign(a, b, n=2):
    """Copy the first two elements of ``b`` into ``a`` in place; only ``n == 2`` is supported."""
    if n != 2:
        assert False
        return
    first, second = b[0], b[1]
    a[0] = first
    a[1] = second


def my_print(*values):
    """Forward all positional arguments to the builtin ``print``."""
    print(*values)


# Module-level slot, presumably the current training epoch index set by the
# training loop; not assigned anywhere in this part of the file. TODO confirm.
i_epoch = None


def get_from_mapping(mapping: List[Dict], key=None):
    """
    Collect ``each[key]`` for every dict in ``mapping``.

    When ``key`` is omitted it is taken from the caller's source line: the
    text left of the first ``=`` (i.e. the name being assigned to) is used.
    """
    if key is None:
        caller = inspect.currentframe().f_back
        source_line = inspect.getframeinfo(caller).code_context[0]
        key = source_line.split('=')[0].strip()
    values = []
    for each in mapping:
        values.append(each[key])
    return values


# Collects positive metric values when metric_values_to_string is called with
# append=True. Starts as None, so it must be set to a list before that call.
ap_list = None


def metric_values_to_string(metric_values, metric_names, metric=None, index=None, append=False):
    """
    Format a metric table as ``'<metric>/<name>: <value>'`` lines.

    :param metric_values: indexable as ``metric_values[i][j]`` where ``i``
        follows the order min_ade, min_fde, miss_rate, overlap_rate, map and
        ``j`` follows ``metric_names``. ``None`` prints a notice and returns None.
    :param metric_names: names for the second axis.
    :param metric: if given, only this metric is reported.
    :param index: if given, only this position of ``metric_names`` is reported.
    :param append: also append every reported value > 0.0 to the module-level
        ``ap_list``.
    :return: the lines joined by newlines, or None when ``metric_values`` is None.
    """
    # Identity check: ``== None`` is evaluated element-wise for numpy arrays and
    # makes the ``if`` raise "truth value of an array is ambiguous".
    if metric_values is None:
        print('metric_values is None')
        return
    lines = []
    for i, m in enumerate(
            ['min_ade', 'min_fde', 'miss_rate', 'overlap_rate', 'map']):
        if metric is not None and metric != m:
            continue
        for j, n in enumerate(metric_names):
            if index is not None and index != j:
                continue
            if append and metric_values[i][j] > 0.0:
                ap_list.append(float(metric_values[i][j]))
            lines.append('{}/{}: {}'.format(m, n, metric_values[i][j]))
    return '\n'.join(lines)


def pool_forward(rank, queue, result_queue, run):
    """
    Worker loop of ``Pool``: take argument tuples from ``queue``, call
    ``run(*arguments)`` and put the result on ``result_queue`` until a
    ``None`` sentinel is received. ``rank`` is not used.
    """
    while True:
        task = queue.get()
        if task is None:
            return
        result_queue.put(run(*task))


class Pool:
    """
    Minimal process pool: runs ``run(*file)`` for every entry of ``files`` in
    ``core_num`` worker processes (see ``pool_forward``).

    The constructor starts the workers and submits all tasks; ``join``
    collects the results and shuts the workers down.
    """

    def __init__(self, core_num, files, run):
        self.core_num = core_num
        # Bounded task queue: submission blocks while the workers are busy.
        self.queue = multiprocessing.Queue(core_num)
        # The result queue must be unbounded. Nothing reads results before
        # join(), so with a bounded queue the workers block on put() as soon
        # as it is full, stop taking tasks, and the submission loop below
        # never returns once there are more than about 3 * core_num tasks.
        self.result_queue = multiprocessing.Queue()
        self.processes = [multiprocessing.Process(target=pool_forward, args=(rank, self.queue, self.result_queue, run,)) for rank in
                          range(self.core_num)]
        self.files = files
        for each in self.processes:
            each.start()
        for file in files:
            # None is reserved as the shutdown sentinel.
            assert file is not None
            self.queue.put(file)

    def join(self):
        """Return the results (in completion order) and stop all workers."""
        # One result per submitted task; once all have arrived the task queue
        # has been fully drained by the workers.
        results = [self.result_queue.get() for _ in range(len(self.files))]

        for _ in range(self.core_num):
            self.queue.put(None)

        for each in self.processes:
            each.join()

        return results


# Module-level slots for evaluation state; nothing in this part of the file
# reads or writes them. Presumably filled in by the evaluation code. TODO confirm.
motion_metrics = None
metric_names = None

# Presumably one metrics object per trajectory type. TODO confirm against callers.
trajectory_type_2_motion_metrics = {}


def get_trajectory_upsample(inputs: np.ndarray, future_frame_num, future_test_frame_num):
    """
    Scatter a down-sampled trajectory back onto the full frame grid.

    :param inputs: array of shape ``[..., future_test_frame_num, 2]`` with at
        least one leading dimension.
    :param future_frame_num: number of frames in the output.
    :param future_test_frame_num: number of frames in ``inputs``.
    :return: float array of shape ``[..., future_frame_num, 2]``. Input frame
        ``k`` is written to output frame ``(k + 1) * stride - 1`` with
        ``stride = future_frame_num // future_test_frame_num``; all other
        frames are zero.
    """
    stride = future_frame_num // future_test_frame_num
    shape_prefix = list(inputs.shape[:-2])
    assert len(shape_prefix) > 0
    # The last axis holds the (x, y) coordinates, not future_frame_num values.
    inputs = inputs.reshape(-1, future_test_frame_num, 2)
    # np.zeros takes the shape as ONE tuple argument.
    outputs = np.zeros((len(inputs), future_frame_num, 2))
    # Bounding the slice guarantees exactly future_test_frame_num target frames
    # even when future_frame_num is not a multiple of future_test_frame_num.
    outputs[:, stride - 1:stride * future_test_frame_num:stride, :] = inputs
    return outputs.reshape(*shape_prefix, future_frame_num, 2)


def get_eval_identifier():
    """
    Build a name for the current evaluation run.

    Starts from the last path component of ``args.model_recover_path`` and
    appends one dot-separated token per entry of ``args.eval_params``. Tokens
    longer than 15 characters are cut at the first ``=`` if they contain one,
    and replaced by ``'long'`` if still too long. The result is passed through
    ``get_name(..., append_time=True)``.
    """
    parts = [args.model_recover_path.split('/')[-1]]
    for param in args.eval_params:
        token = str(param)
        if len(token) > 15 and '=' in token:
            token = token.split('=')[0]
        if len(token) > 15:
            token = 'long'
        parts.append(token)
    return get_name('.'.join(parts), append_time=True)


def get_wait5_rank(rank):
    """Return ``(rank + 1) // 2``, i.e. ``rank / 2`` rounded up for non-negative integers."""
    return (rank + 1) // 2


# def shape_equal(shape, shape_):
#     if len(shape) != len(shape_):
#         return False

class Normalizer:
    """
    Rigid transform defined by a translation ``(x, y)`` and a rotation ``yaw``.

    Calling the object shifts points by ``(-x, -y)`` and rotates them by
    ``yaw``; with ``reverse=True`` it shifts by ``-origin`` and rotates by
    ``-yaw``, where ``origin`` is the forward transform of ``(0, 0)``.
    """

    def __init__(self, x, y, yaw):
        self.x, self.y, self.yaw = x, y, yaw
        self.origin = rotate(0.0 - x, 0.0 - y, yaw)

    def __call__(self, points, reverse=False):
        """
        Transform ``points`` and return the result as a new numpy array.

        Accepts a single ``(2,)`` point (returned with shape ``(1, 2)``), an
        ``(N, 2)`` array or an ``(M, N, 2)`` array.
        """
        points = np.array(points)
        if points.shape == (2,):
            points = points.reshape(1, 2)
        rank = len(points.shape)
        assert rank <= 3
        if rank == 3:
            # Handle every (N, 2) slice on its own and write it back.
            for group in points:
                group[:] = self(group, reverse)
            return points

        assert rank == 2
        if reverse:
            shift_x, shift_y, angle = self.origin[0], self.origin[1], -self.yaw
        else:
            shift_x, shift_y, angle = self.x, self.y, self.yaw
        for point in points:
            point[0], point[1] = rotate(point[0] - shift_x, point[1] - shift_y, angle)
        return points


def satisfy_one_of(conds, other_params):
    """Return True if at least one entry of ``conds`` is contained in ``other_params``."""
    return any(cond in other_params for cond in conds)


def get_static_var(obj, name, default=None, path=None):
    """
    Return attribute ``name`` of ``obj``, creating it on first use.

    A missing attribute is initialised with ``default`` if that is not None,
    otherwise with ``structs.load(path)`` if ``path`` is given; with neither
    an assertion fails.
    """
    if hasattr(obj, name):
        return getattr(obj, name)
    if default is not None:
        value = default
    elif path is not None:
        value = structs.load(path)
    else:
        assert False
    setattr(obj, name, value)
    return getattr(obj, name)


def to_numpy(tensor):
    """Detach ``tensor`` from the autograd graph, move it to the CPU and return it as a numpy array."""
    detached = tensor.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

In [3]:
from typing import Dict, List, Tuple, NamedTuple, Any

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor

#from modeling.decoder import Decoder, DecoderResCat
#from modeling.lib import MLP, GlobalGraph, LayerNorm, CrossAttention, GlobalGraphRes
#import utils
# Feature flags consulted by the model code below (e.g. Decoder checks for
# 'variety_loss' and 'goals_2D').
other_params = ['goals_2D', 'stage_one_dynamic']

class NewSubGraph(nn.Module):
    """
    Sub graph that encodes the vectors of each polyline with stacked
    self-attention layers and pools them into one vector per polyline.
    """

    def __init__(self, hidden_size, depth=None):
        super(NewSubGraph, self).__init__()
        if depth is None:
            depth = 3 #args.sub_graph_depth
        # NOTE(review): this ModuleList is replaced a few lines below, so these
        # MLPs are constructed and then discarded.
        self.layers = nn.ModuleList([MLP(hidden_size, hidden_size // 2) for _ in range(depth)])

        self.layer_0 = MLP(hidden_size)
        self.layers = nn.ModuleList([GlobalGraph(hidden_size, num_attention_heads=2) for _ in range(depth)])
        self.layers_2 = nn.ModuleList([LayerNorm(hidden_size) for _ in range(depth)])
        # NOTE(review): layers_3 and layers_4 are registered but not used in forward().
        self.layers_3 = nn.ModuleList([LayerNorm(hidden_size) for _ in range(depth)])
        self.layers_4 = nn.ModuleList([GlobalGraph(hidden_size) for _ in range(depth)])
        #if 'point_level-4-3' in args.other_params:
        self.layer_0_again = MLP(hidden_size)

    def forward(self, input_list: list):
        """
        :param input_list: one tensor of vectors per polyline; they are padded
            and stacked by ``merge_tensors``.
        :return: a pair: the max over dim 1 of the final hidden states (one
            row per polyline), and the concatenation of the per-polyline
            states returned by ``de_merge_tensors``.
        """
        batch_size = len(input_list)
        device = input_list[0].device
        hidden_states, lengths = merge_tensors(input_list, device)
        hidden_size = hidden_states.shape[2]
        max_vector_num = hidden_states.shape[1]

        attention_mask = torch.zeros([batch_size, max_vector_num, max_vector_num], device=device)
        hidden_states = self.layer_0(hidden_states)

        #if 'point_level-4-3' in args.other_params:
        hidden_states = self.layer_0_again(hidden_states)
        # Mark the real (non-padded) vectors of every polyline as attendable.
        for i in range(batch_size):
            assert lengths[i] > 0
            attention_mask[i, :lengths[i], :lengths[i]].fill_(1)

        for layer_index, layer in enumerate(self.layers):
            temp = hidden_states
            # hidden_states = layer(hidden_states, attention_mask)
            # hidden_states = self.layers_2[layer_index](hidden_states)
            # hidden_states = F.relu(hidden_states) + temp
            # Attention -> ReLU -> residual connection -> layer norm.
            hidden_states = layer(hidden_states, attention_mask)
            hidden_states = F.relu(hidden_states)
            hidden_states = hidden_states + temp
            hidden_states = self.layers_2[layer_index](hidden_states)

        return torch.max(hidden_states, dim=1)[0], torch.cat(de_merge_tensors(hidden_states, lengths))


class VectorNet(nn.Module):
    r"""
    VectorNet
    It has two main components, sub graph and global graph.
    Sub graph encodes a polyline as a single vector.
    """

    def __init__(self, args_: None):
        super(VectorNet, self).__init__()
        #global args
        #args = args_
        hidden_size = 128

        self.point_level_sub_graph = NewSubGraph(hidden_size)
        self.point_level_cross_attention = CrossAttention(hidden_size)

        # NOTE(review): this GlobalGraph is replaced by GlobalGraphRes right
        # below; kept so that module construction order stays unchanged.
        self.global_graph = GlobalGraph(hidden_size)
        #if 'enhance_global_graph' in args.other_params:
        self.global_graph = GlobalGraphRes(hidden_size)
        #if 'laneGCN' in args.other_params:
        self.laneGCN_A2L = CrossAttention(hidden_size)
        self.laneGCN_L2L = GlobalGraphRes(hidden_size)
        self.laneGCN_L2A = CrossAttention(hidden_size)

        self.decoder = Decoder(None, self)

        #if 'complete_traj' in args.other_params:
        self.decoder.complete_traj_cross_attention = CrossAttention(hidden_size)
        self.decoder.complete_traj_decoder = DecoderResCat(hidden_size, hidden_size * 3, out_features=self.decoder.future_frame_num * 2)

    def forward_encode_sub_graph(self, mapping: List[Dict], matrix: List[np.ndarray], polyline_spans: List[List[slice]],
                                 device, batch_size) -> Tuple[List[Tensor], List[Tensor]]:
        """
        :param matrix: each value in list is vectors of all element (shape [-1, 128])
        :param polyline_spans: vectors of i_th element is matrix[polyline_spans[i]]
        :return: hidden states of all elements and hidden states of lanes
        """
        input_list_list = []
        # TODO(cyrushx): This is not used? Is it because input_list_list includes map data as well?
        # Yes, input_list_list includes map data, this will be used in the future release.
        map_input_list_list = []
        lane_states_batch = None
        for i in range(batch_size):
            input_list = []
            map_input_list = []
            map_start_polyline_idx = mapping[i]['map_start_polyline_idx']
            for j, polyline_span in enumerate(polyline_spans[i]):
                tensor = torch.tensor(matrix[i][polyline_span], device=device)
                input_list.append(tensor)
                # Polylines from map_start_polyline_idx on belong to the map.
                if j >= map_start_polyline_idx:
                    map_input_list.append(tensor)

            input_list_list.append(input_list)
            map_input_list_list.append(map_input_list)

        if True:
            element_states_batch = []
            for i in range(batch_size):
                a, b = self.point_level_sub_graph(input_list_list[i])
                element_states_batch.append(a)

        #if 'stage_one' in args.other_params:
        lane_states_batch = []
        for i in range(batch_size):
            a, b = self.point_level_sub_graph(map_input_list_list[i])
            lane_states_batch.append(a)

        #if 'laneGCN' in args.other_params:
        inputs_before_laneGCN, inputs_lengths_before_laneGCN = merge_tensors(element_states_batch, device=device)
        for i in range(batch_size):
            map_start_polyline_idx = mapping[i]['map_start_polyline_idx']
            agents = element_states_batch[i][:map_start_polyline_idx]
            lanes = element_states_batch[i][map_start_polyline_idx:]
            #if 'laneGCN-4' in args.other_params:
            # Lanes attend to all lanes plus the first agent polyline.
            lanes = lanes + self.laneGCN_A2L(lanes.unsqueeze(0), torch.cat([lanes, agents[0:1]]).unsqueeze(0)).squeeze(0)
            #else:
            #    lanes = lanes + self.laneGCN_A2L(lanes.unsqueeze(0), agents.unsqueeze(0)).squeeze(0)
            #    lanes = lanes + self.laneGCN_L2L(lanes.unsqueeze(0)).squeeze(0)
            #    agents = agents + self.laneGCN_L2A(agents.unsqueeze(0), lanes.unsqueeze(0)).squeeze(0)
            element_states_batch[i] = torch.cat([agents, lanes])

        return element_states_batch, lane_states_batch

    # @profile
    def forward(self, mapping: List[Dict], device):
        """
        Encode every sample of the batch and run the decoder.

        :param mapping: per-sample dicts; reads ``'matrix'``,
            ``'polyline_spans'`` and ``'map_start_polyline_idx'``.
        :param device: torch device the input tensors are created on.
        :return: whatever ``self.decoder`` returns.
        """
        import time
        global starttime
        starttime = time.time()

        matrix = get_from_mapping(mapping, 'matrix')
        # TODO(cyrushx): Can you explain the structure of polyline spans?
        # vectors of i_th element is matrix[polyline_spans[i]]
        polyline_spans = get_from_mapping(mapping, 'polyline_spans')

        batch_size = len(matrix)
        # for i in range(batch_size):
        # polyline_spans[i] = [slice(polyline_span[0], polyline_span[1]) for polyline_span in polyline_spans[i]]

        if True:#args.argoverse:
            batch_init(mapping)

        element_states_batch, lane_states_batch = self.forward_encode_sub_graph(mapping, matrix, polyline_spans, device, batch_size)

        inputs, inputs_lengths = merge_tensors(element_states_batch, device=device)
        max_poly_num = max(inputs_lengths)
        attention_mask = torch.zeros([batch_size, max_poly_num, max_poly_num], device=device)
        for i, length in enumerate(inputs_lengths):
            # Index rows AND columns. The former ``[i][:length][:length]``
            # sliced the row axis twice, so whole rows were set to 1 and padded
            # key columns stayed visible to attention.
            # NOTE(review): weights trained with the old mask may give slightly
            # different outputs after this fix.
            attention_mask[i, :length, :length].fill_(1)

        hidden_states = self.global_graph(inputs, attention_mask, mapping)

        #utils.logging('time3', round(time.time() - starttime, 2), 'secs')

        return self.decoder(mapping, batch_size, lane_states_batch, inputs, inputs_lengths, hidden_states, device)

In [4]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor

#import structs
#import utils_cython
#from modeling.lib import PointSubGraph, GlobalGraphRes, CrossAttention, GlobalGraph, MLP
#import utils


class DecoderRes(nn.Module):
    """Residual MLP block followed by a linear output projection."""

    def __init__(self, hidden_size, out_features=60):
        super().__init__()
        self.mlp = MLP(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, out_features)

    def forward(self, hidden_states):
        """Return ``fc(hidden_states + mlp(hidden_states))``."""
        residual = self.mlp(hidden_states)
        return self.fc(hidden_states + residual)


class DecoderResCat(nn.Module):
    """MLP whose output is concatenated to its input before a linear output projection."""

    def __init__(self, hidden_size, in_features, out_features=60):
        super().__init__()
        self.mlp = MLP(in_features, hidden_size)
        self.fc = nn.Linear(hidden_size + in_features, out_features)

    def forward(self, hidden_states):
        """Return ``fc(concat(hidden_states, mlp(hidden_states)))`` along the last axis."""
        encoded = self.mlp(hidden_states)
        stacked = torch.cat([hidden_states, encoded], dim=-1)
        return self.fc(stacked)



class Decoder(nn.Module):

    def __init__(self, _args:None, vectornet):
        """
        Build the decoder heads.

        :param _args: not referenced anywhere in this constructor
        :param vectornet: the enclosing model; only touched in the 'set_predict' branch, where this
                          decoder is attached to it, a checkpoint is loaded into it and its
                          parameters are frozen

        NOTE(review): configuration is read from the bare name ``other_params`` (and
        ``model_recover_path``), which are not defined in this class; the other methods read
        ``args.other_params`` instead. Presumably these are module-level globals -- confirm.

        NOTE(review): ``self.complete_traj_cross_attention`` and ``self.complete_traj_decoder`` are
        used by ``goals_2D_per_example_calc_loss`` and ``goals_2D_eval`` but are not created in this
        constructor -- confirm where they are expected to come from.
        """
        super(Decoder, self).__init__()
        #global args
        #args = args_
        # Fixed model dimensions (no longer taken from args).
        hidden_size = 128 
        self.future_frame_num = 30 
        self.mode_num = 6 

        self.decoder = DecoderRes(hidden_size, out_features=2)

        if 'variety_loss' in other_params:
            # Regresses 6 trajectories of future_frame_num (x, y) points each.
            self.variety_loss_decoder = DecoderResCat(hidden_size, hidden_size, out_features=6 * self.future_frame_num * 2)

            if 'variety_loss-prob' in other_params:
                # Replaces the head above with one that has 6 extra outputs (one score per trajectory).
                self.variety_loss_decoder = DecoderResCat(hidden_size, hidden_size, out_features=6 * self.future_frame_num * 2 + 6)
        elif 'goals_2D' in other_params:
            #self.decoder = DecoderResCat(hidden_size, hidden_size, out_features=self.future_frame_num * 2)
            # Embeds a 2D goal coordinate into hidden_size features.
            self.goals_2D_mlps = nn.Sequential(
                MLP(2, hidden_size),
                MLP(hidden_size),
                MLP(hidden_size)
            )
            #self.goals_2D_decoder = DecoderRes(hidden_size * 3, out_features=1)
            # Scores one goal from three concatenated hidden_size feature vectors (see get_scores).
            self.goals_2D_decoder = DecoderResCat(hidden_size, hidden_size * 3, out_features=1)
            self.goals_2D_cross_attention = CrossAttention(hidden_size)
            #if 'point_sub_graph' in other_params:
            self.goals_2D_point_sub_graph = PointSubGraph(hidden_size)
        
        # The stage-one modules are created unconditionally (the guard below is commented out).
        #if 'stage_one' in other_params:
        self.stage_one_cross_attention = CrossAttention(hidden_size)
        self.stage_one_decoder = DecoderResCat(hidden_size, hidden_size * 3, out_features=1)
        self.stage_one_goals_2D_decoder = DecoderResCat(hidden_size, hidden_size * 4, out_features=1)

        if 'set_predict' in other_params:
            if True:
                # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
                if 'set_predict-train_recover' in other_params:
                    model_recover = torch.load(other_params['set_predict-train_recover'])
                else:
                    model_recover = torch.load(model_recover_path)
                # Attach this decoder first so the checkpoint is loaded into the complete model.
                vectornet.decoder = self
                utils.load_model(vectornet, model_recover)
                # self must be vectornet
                # Freeze everything registered so far; the set-predict modules created below are
                # registered after this loop and are therefore not affected by it.
                for p in vectornet.parameters():
                    p.requires_grad = False

            # Embeds (x, y, score) triples; see run_set_predict.
            self.set_predict_point_feature = nn.Sequential(MLP(3, hidden_size), MLP(hidden_size, hidden_size))

            # One encoder/decoder pair per head; the head count is other_params['set_predict'].
            self.set_predict_encoders = nn.ModuleList(
                [GlobalGraphRes(hidden_size) for _ in range(other_params['set_predict'])])

            # 13 outputs per head: run_set_predict reads 1 group score followed by 6 (x, y) points.
            self.set_predict_decoders = nn.ModuleList(
                [DecoderResCat(hidden_size, hidden_size * 2, out_features=13) for _ in range(other_params['set_predict'])])

    def goals_2D_per_example_stage_one(self, i, mapping, lane_states_batch, inputs, inputs_lengths,
                                       hidden_states, device, loss):
        """
        Stage one for example ``i``: score every lane, add the lane classification loss to
        ``loss[i]`` and select the top lanes.

        Writes 'stage_one_scores' and 'stage_one_topk' into ``mapping[i]`` and reports
        'stage_one_recall' (plus 'stage_one_k' in dynamic mode) through ``utils.other_errors_put``.

        :return: indices of the selected lanes, ordered by decreasing score
        """
        lane_hidden = lane_states_batch[i]
        lane_attention = self.stage_one_cross_attention(
            lane_hidden.unsqueeze(0), inputs[i][:inputs_lengths[i]].unsqueeze(0)).squeeze(0)
        # The first element's global feature is broadcast to every lane.
        agent_hidden = hidden_states[i, 0, :].unsqueeze(0).expand(lane_hidden.shape)
        decoder_input = torch.cat([agent_hidden, lane_hidden, lane_attention], dim=-1)
        stage_one_scores = F.log_softmax(self.stage_one_decoder(decoder_input).squeeze(-1), dim=-1)

        assert len(stage_one_scores) == len(mapping[i]['polygons'])
        mapping[i]['stage_one_scores'] = stage_one_scores

        label = mapping[i]['stage_one_label']
        loss[i] += F.nll_loss(stage_one_scores.unsqueeze(0), torch.tensor([label], device=device))

        if 'stage_one_dynamic' in other_params:
            # Keep the shortest prefix of the ranking whose cumulative probability exceeds the threshold.
            _, stage_one_topk_ids = torch.topk(stage_one_scores, k=len(stage_one_scores))
            threshold = float(other_params['stage_one_dynamic'])
            cumulative = 0.0
            for rank, prob in enumerate(torch.exp(stage_one_scores[stage_one_topk_ids])):
                cumulative += prob
                if cumulative > threshold:
                    stage_one_topk_ids = stage_one_topk_ids[:rank + 1]
                    break
            utils.other_errors_put('stage_one_k', len(stage_one_topk_ids))
        else:
            _, stage_one_topk_ids = torch.topk(stage_one_scores, k=min(stage_one_K, len(stage_one_scores)))

        recalled = label in stage_one_topk_ids.tolist()
        utils.other_errors_put('stage_one_recall', 1.0 if recalled else 0.0)

        mapping[i]['stage_one_topk'] = lane_states_batch[i][stage_one_topk_ids]

        return stage_one_topk_ids

    def goals_2D_per_example_lazy_points(self, i, goals_2D, mapping, labels, device, scores,
                                         get_scores_inputs, stage_one_topk_ids=None, gt_points=None):
        """
        Refine the candidate set around the best-scoring goals and score it again.

        The neighbours returned by ``utils.get_neighbour_points`` for the top-scoring goals are
        placed in front of the original candidates, and the combined set is scored with
        ``get_scores``. Unless ``args.do_test`` is set, the index of the candidate closest to the
        ground-truth point at 'final_idx' is stored in ``mapping[i]['goals_2D_labels']``.

        :return: (scores of the combined set, its best-scoring point, the combined set as an array)
        """
        # Number of top goals to expand (the commented-out non-argoverse setting was 40).
        max_seeds = 150
        seed_ids = torch.topk(scores, k=min(max_seeds, len(scores)))[1].tolist()

        neighbours = utils.get_neighbour_points(goals_2D[seed_ids], topk_ids=seed_ids, mapping=mapping[i])

        dense_goals_tensor = torch.cat(
            [torch.tensor(neighbours, device=device, dtype=torch.float),
             torch.tensor(goals_2D, device=device, dtype=torch.float)], dim=0)
        dense_goals = np.array(dense_goals_tensor.tolist())

        scores = self.get_scores(dense_goals_tensor, *get_scores_inputs)

        best = torch.argmax(scores).item()
        point = np.array(dense_goals_tensor[best].tolist())

        if not args.do_test:
            label = np.array(labels[i]).reshape([self.future_frame_num, 2])
            final_idx = mapping[i].get('final_idx', -1)
            mapping[i]['goals_2D_labels'] = np.argmin(utils.get_dis(dense_goals, label[final_idx]))

        return scores, point, dense_goals

    def goals_2D_per_example_calc_loss(self, i: int, goals_2D: np.ndarray, mapping: List[Dict], inputs: Tensor,
                                       inputs_lengths: List[int], hidden_states: Tensor, device, loss: Tensor,
                                       DE: np.ndarray, gt_points: np.ndarray, scores: Tensor, highest_goal: np.ndarray,
                                       labels_is_valid: List[np.ndarray]):
        """
        Calculate loss for a training example
        """
        final_idx = mapping[i].get('final_idx', -1)
        gt_goal = gt_points[final_idx]
        DE[i][final_idx] = np.sqrt((highest_goal[0] - gt_points[final_idx][0]) ** 2 + (highest_goal[1] - gt_points[final_idx][1]) ** 2)
        #if 'complete_traj' in args.other_params:
        target_feature = self.goals_2D_mlps(torch.tensor(gt_points[final_idx], dtype=torch.float, device=device))
        pass
        if True:
            target_feature.detach_()
            hidden_attention = self.complete_traj_cross_attention(
                target_feature.unsqueeze(0).unsqueeze(0), inputs[i][:inputs_lengths[i]].detach().unsqueeze(0)).squeeze(
                0).squeeze(0)
            predict_traj = self.complete_traj_decoder(
                torch.cat([hidden_states[i, 0, :].detach(), target_feature, hidden_attention], dim=-1)).view(
                [self.future_frame_num, 2])
        loss[i] += (F.smooth_l1_loss(predict_traj, torch.tensor(gt_points, dtype=torch.float, device=device), reduction='none') * \
                    torch.tensor(labels_is_valid[i], dtype=torch.float, device=device).view(self.future_frame_num, 1)).mean()

        loss[i] += F.nll_loss(scores.unsqueeze(0),
                              torch.tensor([mapping[i]['goals_2D_labels']], device=device))

    def goals_2D_per_example(self, i: int, goals_2D: np.ndarray, mapping: List[Dict], lane_states_batch: List[Tensor],
                             inputs: Tensor, inputs_lengths: List[int], hidden_states: Tensor, labels: List[np.ndarray],
                             labels_is_valid: List[np.ndarray], device, loss: Tensor, DE: np.ndarray):
        """
        Run goal prediction for one example: optional stage one, goal scoring, lazy-point
        refinement, then loss computation (training) and/or goal selection (evaluation).

        :param i: example index in batch
        :param goals_2D: candidate goals sampled from map (shape ['goal num', 2])
        :param lane_states_batch: each value in list is hidden states of lanes (value shape ['lane num', hidden_size])
        :param inputs: hidden states of all elements before encoding by global graph (shape [batch_size, 'element num', hidden_size])
        :param inputs_lengths: valid element number of each example
        :param hidden_states: hidden states of all elements after encoding by global graph (shape [batch_size, -1, hidden_size])
        :param loss: (shape [batch_size])
        :param DE: displacement error (shape [batch_size, self.future_frame_num])
        """
        if args.do_train:
            final_idx = mapping[i].get('final_idx', -1)
            assert labels_is_valid[i][final_idx]

        gt_points = labels[i].reshape([self.future_frame_num, 2])

        stage_one_topk_ids = None
        if 'stage_one' in args.other_params:
            stage_one_topk_ids = self.goals_2D_per_example_stage_one(
                i, mapping, lane_states_batch, inputs, inputs_lengths, hidden_states, device, loss)

        get_scores_inputs = (inputs, hidden_states, inputs_lengths, i, mapping, device)
        coarse_goals = torch.tensor(goals_2D, device=device, dtype=torch.float)
        scores = self.get_scores(coarse_goals, *get_scores_inputs)
        # Best coarse candidate; it is superseded by the refinement result right below.
        highest_goal = goals_2D[torch.argmax(scores).item()]

        # The lazy-points refinement is always applied (formerly gated on 'lazy_points').
        scores, highest_goal, goals_2D = self.goals_2D_per_example_lazy_points(
            i, goals_2D, mapping, labels, device, scores, get_scores_inputs, stage_one_topk_ids, gt_points)

        if args.do_train:
            self.goals_2D_per_example_calc_loss(i, goals_2D, mapping, inputs, inputs_lengths,
                                                hidden_states, device, loss, DE, gt_points, scores,
                                                highest_goal, labels_is_valid)

        if args.visualize:
            example = mapping[i]
            example['vis.goals_2D'] = goals_2D
            example['vis.scores'] = np.array(scores.tolist())
            example['vis.labels'] = gt_points
            example['vis.labels_is_valid'] = labels_is_valid[i]

        if 'set_predict' in args.other_params:
            self.run_set_predict(goals_2D, scores, mapping, device, loss, i)
            if args.visualize:
                # Goal-only trajectories; the array is not stored anywhere in this method.
                set_predict_ans_points = mapping[i]['set_predict_ans_points']
                predict_trajs = np.zeros((6, self.future_frame_num, 2))
                predict_trajs[:, -1, :] = set_predict_ans_points
        elif args.do_eval:
            if args.nms_threshold is not None:
                utils.select_goals_by_NMS(mapping[i], goals_2D, np.array(scores.tolist()),
                                          args.nms_threshold, mapping[i]['speed'])
            elif 'optimization' in args.other_params:
                mapping[i]['goals_2D_scores'] = (goals_2D.astype(np.float32),
                                                 np.array(scores.tolist(), dtype=np.float32))
            else:
                assert False

    def goals_2D_eval(self, batch_size, mapping, labels, hidden_states, inputs, inputs_lengths, device):
        """
        Collect the predicted goals of a batch and complete each of them into a full trajectory.

        Goals come from the set-predict results, from ``utils.select_goals_by_optimization`` or
        from the NMS results stored in ``mapping``, depending on ``args``.

        :param batch_size: number of examples in ``mapping``
        :param mapping: per-example dicts; 'vis.predict_trajs' is written for every example
        :param labels: ground-truth trajectories, only used by the 'optimization' goal selection
        :param hidden_states: hidden states of all elements after encoding by global graph
        :param inputs: hidden states of all elements before encoding by global graph
        :param inputs_lengths: valid element number of each example
        :return: (pred_trajs_batch, pred_probs_batch, None) where pred_trajs_batch has shape
                 [batch_size, self.mode_num, self.future_frame_num, 2]
        :raises NotImplementedError: if 'complete_traj' is not in ``args.other_params``
        """
        if 'set_predict' in args.other_params:
            pred_goals_batch = [mapping[i]['set_predict_ans_points'] for i in range(batch_size)]
            pred_probs_batch = np.zeros((batch_size, 6))
        elif 'optimization' in args.other_params:
            pred_goals_batch, pred_probs_batch = utils.select_goals_by_optimization(
                np.array(labels).reshape([batch_size, self.future_frame_num, 2]), mapping)
        elif args.nms_threshold is not None:
            pred_goals_batch = [mapping[i]['pred_goals'] for i in range(batch_size)]
            pred_probs_batch = [mapping[i]['pred_probs'] for i in range(batch_size)]
        else:
            assert False

        pred_goals_batch = np.array(pred_goals_batch)
        pred_probs_batch = np.array(pred_probs_batch)
        assert pred_goals_batch.shape == (batch_size, self.mode_num, 2)
        assert pred_probs_batch.shape == (batch_size, self.mode_num)

        if 'complete_traj' not in args.other_params:
            # Without trajectory completion there is nothing to return: this path used to fail
            # with an UnboundLocalError on pred_trajs_batch (or a KeyError on 'vis.predict_trajs'
            # when visualizing). Fail with an explicit message instead.
            raise NotImplementedError(
                "goals_2D_eval requires 'complete_traj' in args.other_params to build trajectories")

        pred_trajs_batch = []
        for i in range(batch_size):
            targets_feature = self.goals_2D_mlps(torch.tensor(pred_goals_batch[i], dtype=torch.float, device=device))
            hidden_attention = self.complete_traj_cross_attention(
                targets_feature.unsqueeze(0), inputs[i][:inputs_lengths[i]].unsqueeze(0)).squeeze(0)
            predict_trajs = self.complete_traj_decoder(
                torch.cat([hidden_states[i, 0, :].unsqueeze(0).expand(len(targets_feature), -1), targets_feature,
                           hidden_attention], dim=-1)).view([self.mode_num, self.future_frame_num, 2])
            predict_trajs = np.array(predict_trajs.tolist())
            # The point at 'final_idx' is overwritten with the selected goal itself.
            final_idx = mapping[i].get('final_idx', -1)
            predict_trajs[:, final_idx, :] = pred_goals_batch[i]
            # Copy taken before the coordinate conversion below.
            mapping[i]['vis.predict_trajs'] = predict_trajs.copy()

            if args.argoverse:
                # presumably converts each trajectory in place -- the return value is not used
                for each in predict_trajs:
                    utils.to_origin_coordinate(each, i)
            pred_trajs_batch.append(predict_trajs)
        pred_trajs_batch = np.array(pred_trajs_batch)

        if args.visualize:
            for i in range(batch_size):
                utils.visualize_goals_2D(mapping[i], mapping[i]['vis.goals_2D'], mapping[i]['vis.scores'], self.future_frame_num,
                                         labels=mapping[i]['vis.labels'],
                                         labels_is_valid=mapping[i]['vis.labels_is_valid'],
                                         predict=mapping[i]['vis.predict_trajs'])

        return pred_trajs_batch, pred_probs_batch, None

    def variety_loss(self, mapping: List[Dict], hidden_states: Tensor, batch_size, inputs: Tensor,
                     inputs_lengths: List[int], labels_is_valid: List[np.ndarray], loss: Tensor,
                     DE: np.ndarray, device, labels: List[np.ndarray]):
        """
        Regress 6 trajectories per example and train only the one whose end point is closest to
        the ground truth.

        :param hidden_states: hidden states of all elements after encoding by global graph (shape [batch_size, -1, hidden_size])
        :param inputs: hidden states of all elements before encoding by global graph (shape [batch_size, 'element num', hidden_size])
        :param inputs_lengths: valid element number of each example
        :param DE: displacement error (shape [batch_size, self.future_frame_num])
        :return: ``(trajectories, log-probabilities or None, None)`` when ``args.do_eval`` is set,
                 otherwise ``(loss.mean(), DE, None)``
        """
        outputs = self.variety_loss_decoder(hidden_states[:, 0, :])
        with_probs = 'variety_loss-prob' in args.other_params
        pred_probs = None
        if with_probs:
            # The last 6 outputs are the per-trajectory scores.
            pred_probs = F.log_softmax(outputs[:, -6:], dim=-1)
            outputs = outputs[:, :-6]
        outputs = outputs.view([batch_size, 6, self.future_frame_num, 2])

        for i in range(batch_size):
            valid = labels_is_valid[i]
            if args.do_train:
                assert valid[-1]
            gt_points = np.array(labels[i]).reshape([self.future_frame_num, 2])
            end_points = np.array(outputs[i, :, -1, :].tolist())
            best_mode = np.argmin(utils.get_dis_point_2_points(gt_points[-1], end_points))

            point_loss = F.smooth_l1_loss(outputs[i, best_mode],
                                          torch.tensor(gt_points, device=device, dtype=torch.float),
                                          reduction='none')
            point_loss = point_loss * torch.tensor(valid, device=device, dtype=torch.float).view(self.future_frame_num, 1)
            valid_count = valid.sum()
            if valid_count > utils.eps:
                loss[i] += point_loss.sum() / valid_count

            if with_probs:
                loss[i] += F.nll_loss(pred_probs[i].unsqueeze(0), torch.tensor([best_mode], device=device))

        if not args.do_eval:
            return loss.mean(), DE, None

        outputs = np.array(outputs.tolist())
        if pred_probs is not None:
            pred_probs = np.array(pred_probs.tolist(), dtype=np.float32)
        for i in range(batch_size):
            for each in outputs[i]:
                utils.to_origin_coordinate(each, i)

        return outputs, pred_probs, None

    def forward(self, mapping: List[Dict], batch_size, lane_states_batch: List[Tensor], inputs: Tensor,
                inputs_lengths: List[int], hidden_states: Tensor, device):
        """
        Dispatch to the decoding scheme selected in ``args.other_params``
        ('variety_loss' or 'goals_2D').

        :param lane_states_batch: each value in list is hidden states of lanes (value shape ['lane num', hidden_size])
        :param inputs: hidden states of all elements before encoding by global graph (shape [batch_size, 'element num', hidden_size])
        :param inputs_lengths: valid element number of each example
        :param hidden_states: hidden states of all elements after encoding by global graph (shape [batch_size, 'element num', hidden_size])
        :return: the result of ``variety_loss`` or ``goals_2D_eval``, or ``(loss.mean(), DE, None)``
                 for the goals_2D scheme outside evaluation
        """
        labels = utils.get_from_mapping(mapping, 'labels')
        labels_is_valid = utils.get_from_mapping(mapping, 'labels_is_valid')
        loss = torch.zeros(batch_size, device=device)
        DE = np.zeros([batch_size, self.future_frame_num])

        if 'variety_loss' in args.other_params:
            return self.variety_loss(mapping, hidden_states, batch_size, inputs, inputs_lengths,
                                     labels_is_valid, loss, DE, device, labels)
        elif 'goals_2D' in args.other_params:
            for i in range(batch_size):
                self.goals_2D_per_example(i, mapping[i]['goals_2D'], mapping, lane_states_batch, inputs,
                                          inputs_lengths, hidden_states, labels, labels_is_valid,
                                          device, loss, DE)

            if args.do_eval:
                return self.goals_2D_eval(batch_size, mapping, labels, hidden_states, inputs, inputs_lengths, device)

            if args.visualize:
                for i in range(batch_size):
                    example = mapping[i]
                    # No trajectories are produced on this path, so an all-zero prediction is drawn.
                    blank_predict = np.zeros((self.mode_num, self.future_frame_num, 2))
                    utils.visualize_goals_2D(example, example['vis.goals_2D'], example['vis.scores'],
                                             self.future_frame_num,
                                             labels=example['vis.labels'],
                                             labels_is_valid=example['vis.labels_is_valid'],
                                             predict=blank_predict)
            return loss.mean(), DE, None
        else:
            assert False

    def get_scores(self, goals_2D_tensor: Tensor, inputs, hidden_states, inputs_lengths, i, mapping, device):
        """
        Score candidate goals of example ``i``.

        :param goals_2D_tensor: candidate goals sampled from map (shape ['goal num', 2])
        :return: log scores of goals (shape ['goal num'])
        """
        if 'point_sub_graph' in args.other_params:
            goals_2D_hidden = self.goals_2D_point_sub_graph(
                goals_2D_tensor.unsqueeze(0), hidden_states[i, 0:1, :]).squeeze(0)
        else:
            goals_2D_hidden = self.goals_2D_mlps(goals_2D_tensor)

        def attend(context):
            # Cross-attention from every goal to the rows of `context`.
            return self.goals_2D_cross_attention(
                goals_2D_hidden.unsqueeze(0), context.unsqueeze(0)).squeeze(0)

        element_attention = attend(inputs[i][:inputs_lengths[i]])

        if 'stage_one' in args.other_params:
            # Additionally attend to the lanes selected in stage one.
            lane_attention = attend(mapping[i]['stage_one_topk'])
            decoder = self.stage_one_goals_2D_decoder
            extra_features = [lane_attention]
        else:
            decoder = self.goals_2D_decoder
            extra_features = []

        agent_hidden = hidden_states[i, 0, :].unsqueeze(0).expand(goals_2D_hidden.shape)
        features = torch.cat([agent_hidden, goals_2D_hidden, element_attention] + extra_features, dim=-1)

        return F.log_softmax(decoder(features).squeeze(-1), dim=-1)

    def run_set_predict(self, goals_2D, scores, mapping, device, loss, i):
        """
        Predict a set of 6 goal points for example ``i`` from the scored candidate goals.

        Every (encoder, decoder) head proposes 6 points plus one group score. The proposal of the
        head with the highest group score is stored in ``mapping[i]['set_predict_ans_points']``.
        When ``args.do_train`` is set, ``loss[i]`` is overwritten with the set-prediction loss.

        :param goals_2D: NumPy array of candidate goals (one (x, y) row per goal)
        :param scores: tensor with one log-score per candidate goal (exponentiated below)
        :param mapping: per-example dicts; 'labels' is read and 'set_predict_ans_points' is written
        :param loss: per-example loss tensor, indexed by ``i``
        :param i: example index in batch
        """
        gt_points = mapping[i]['labels'].reshape((self.future_frame_num, 2))

        if args.argoverse:
            if 'set_predict-topk' in args.other_params:
                # Optionally restrict the candidates to the best-scoring ones.
                topk_num = args.other_params['set_predict-topk']

                if topk_num == 0:
                    # 0 means: keep every goal whose probability exceeds 1e-5.
                    topk_num = torch.sum(scores > np.log(0.00001)).item()

                _, topk_ids = torch.topk(scores, k=min(topk_num, len(scores)))
                goals_2D = goals_2D[topk_ids.cpu().numpy()]
                scores = scores[topk_ids]

        scores_positive_np = np.exp(np.array(scores.tolist(), dtype=np.float32))
        goals_2D = goals_2D.astype(np.float32)

        max_point_idx = torch.argmax(scores)
        # One (x, y, score) row per goal. Rebuilding the tensor from Python lists yields a fresh
        # tensor that carries no autograd history from `scores`.
        vectors_3D = torch.cat([torch.tensor(goals_2D, device=device, dtype=torch.float), scores.unsqueeze(1)], dim=-1)
        vectors_3D = torch.tensor(vectors_3D.tolist(), device=device, dtype=torch.float)

        # Express the coordinates relative to the highest-scoring goal.
        vectors_3D[:, 0] -= goals_2D[max_point_idx, 0]
        vectors_3D[:, 1] -= goals_2D[max_point_idx, 1]

        points_feature = self.set_predict_point_feature(vectors_3D)
        costs = np.zeros(args.other_params['set_predict'])
        pseudo_labels = []
        predicts = []

        # NOTE(review): set_predict_trajs_list and start_time are never read in this method.
        set_predict_trajs_list = []
        group_scores = torch.zeros([len(self.set_predict_encoders)], device=device)

        start_time = utils.time.time()

        if True:
            for k, (encoder, decoder) in enumerate(zip(self.set_predict_encoders, self.set_predict_decoders)):
                if 'set_predict-one_encoder' in args.other_params:
                    encoder = self.set_predict_encoders[0]

                if True:
                    if 'set_predict-one_encoder' in args.other_params and k > 0:
                        # Shared-encoder mode: reuse the encoding computed in the k == 0 iteration.
                        pass
                    else:
                        encoding = encoder(points_feature.unsqueeze(0)).squeeze(0)

                    # Max- and mean-pool over the goals, then decode 13 numbers:
                    # 1 group score followed by 6 (x, y) points.
                    decoding = decoder(torch.cat([torch.max(encoding, dim=0)[0], torch.mean(encoding, dim=0)], dim=-1)).view([13])
                    group_scores[k] = decoding[0]
                    predict = decoding[1:].view([6, 2])

                    # Undo the shift applied to vectors_3D (in place).
                    predict[:, 0] += goals_2D[max_point_idx, 0]
                    predict[:, 1] += goals_2D[max_point_idx, 1]

                predicts.append(predict)

                if args.do_eval:
                    pass
                else:
                    # Cost of this head's proposal given the candidate goals and their probabilities.
                    selected_points = np.array(predict.tolist(), dtype=np.float32)
                    temp = None
                    assert goals_2D.dtype == np.float32, goals_2D.dtype
                    kwargs = None
                    if 'set_predict-MRratio' in args.other_params:
                        kwargs = {}
                        kwargs['set_predict-MRratio'] = args.other_params['set_predict-MRratio']
                    costs[k] = utils_cython.set_predict_get_value(goals_2D, scores_positive_np, selected_points, kwargs=kwargs)

                    # temp is always None here, so pseudo_labels only collects None entries.
                    pseudo_labels.append(temp)

        # Despite its name, this is the index of the head with the HIGHEST group score.
        argmin = torch.argmax(group_scores).item()

        if args.do_train:
            utils.other_errors_put('set_hungary', np.min(costs))
            group_scores = F.log_softmax(group_scores, dim=-1)
            min_cost_idx = np.argmin(costs)
            # Discards whatever was accumulated in loss[i] before this call.
            loss[i] = 0

            if True:
                # Regress the lowest-cost proposal towards the label returned by
                # utils_cython.set_predict_next_step.
                selected_points = np.array(predicts[min_cost_idx].tolist(), dtype=np.float32)
                kwargs = None
                if 'set_predict-MRratio' in args.other_params:
                    kwargs = {}
                    kwargs['set_predict-MRratio'] = args.other_params['set_predict-MRratio']
                _, dynamic_label = utils_cython.set_predict_next_step(goals_2D, scores_positive_np, selected_points,
                                                                      lr=args.set_predict_lr, kwargs=kwargs)
                # loss[i] += 2.0 / globals.set_predict_lr * \
                #            F.l1_loss(predicts[min_cost_idx], torch.tensor(dynamic_label, device=device, dtype=torch.float))
                loss[i] += 2.0 * F.l1_loss(predicts[min_cost_idx], torch.tensor(dynamic_label, device=device, dtype=torch.float))

            # Train the group scores to rank the lowest-cost head first.
            loss[i] += F.nll_loss(group_scores.unsqueeze(0), torch.tensor([min_cost_idx], device=device))

            t = np.array(predicts[min_cost_idx].tolist())

            # Miss (closest point farther than 2.0 from the final ground-truth point) and
            # minimum final displacement of the lowest-cost proposal.
            utils.other_errors_put('set_MR_mincost', np.min(utils.get_dis_point_2_points(gt_points[-1], t)) > 2.0)
            utils.other_errors_put('set_minFDE_mincost', np.min(utils.get_dis_point_2_points(gt_points[-1], t)))

        predict = np.array(predicts[argmin].tolist())

        # Order the chosen points by the score of their nearest candidate goal, best first.
        set_predict_ans_points = predict.copy()
        li = []
        for point in set_predict_ans_points:
            li.append((point, scores[np.argmin(utils.get_dis_point_2_points(point, goals_2D))]))
        li = sorted(li, key=lambda x: -x[1])
        set_predict_ans_points = np.array([each[0] for each in li])
        mapping[i]['set_predict_ans_points'] = set_predict_ans_points

        if args.argoverse:
            # Same two metrics for the proposal that is actually returned.
            utils.other_errors_put('set_MR_pred', np.min(utils.get_dis_point_2_points(gt_points[-1], predict)) > 2.0)
            utils.other_errors_put('set_minFDE_pred', np.min(utils.get_dis_point_2_points(gt_points[-1], predict)))


In [5]:
 # Instantiate the DenseTNT model; None is passed as VectorNet's only argument.
 model_densetnt = VectorNet(None)

## LANE-GCN

In [6]:
try:
    from math import gcd  # Python 3.5+
except ImportError:
    from fractions import gcd  # older interpreters; removed from fractions in Python 3.9
from numbers import Number

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


# Conv layer with norm (gn or bn) and relu. 
class Conv(nn.Module):
    """2D convolution (no bias) followed by GroupNorm or BatchNorm and an optional ReLU."""

    def __init__(self, n_in, n_out, kernel_size=3, stride=1, norm='GN', ng=32, act=True):
        super().__init__()
        assert norm in ('GN', 'BN', 'SyncBN')

        padding = (int(kernel_size) - 1) // 2
        self.conv = nn.Conv2d(n_in, n_out, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)

        if norm == 'BN':
            self.norm = nn.BatchNorm2d(n_out)
        elif norm == 'GN':
            # gcd(ng, n_out) groups, so the group count always divides n_out.
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        else:
            exit('SyncBN has not been added!')

        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        """Apply conv -> norm, then ReLU when ``act`` is set."""
        out = self.norm(self.conv(x))
        return self.relu(out) if self.act else out


class Conv1d(nn.Module):
    """1D convolution (no bias) followed by GroupNorm or BatchNorm and an optional ReLU."""

    def __init__(self, n_in, n_out, kernel_size=3, stride=1, norm='GN', ng=32, act=True):
        super().__init__()
        assert norm in ('GN', 'BN', 'SyncBN')

        padding = (int(kernel_size) - 1) // 2
        self.conv = nn.Conv1d(n_in, n_out, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)

        if norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        elif norm == 'GN':
            # gcd(ng, n_out) groups, so the group count always divides n_out.
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        else:
            exit('SyncBN has not been added!')

        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        """Apply conv -> norm, then ReLU when ``act`` is set."""
        out = self.norm(self.conv(x))
        return self.relu(out) if self.act else out


class Linear(nn.Module):
    """Linear layer (no bias) followed by GroupNorm or BatchNorm and an optional ReLU."""

    def __init__(self, n_in, n_out, norm='GN', ng=32, act=True):
        super().__init__()
        assert norm in ('GN', 'BN', 'SyncBN')

        self.linear = nn.Linear(n_in, n_out, bias=False)

        if norm == 'BN':
            self.norm = nn.BatchNorm1d(n_out)
        elif norm == 'GN':
            # gcd(ng, n_out) groups, so the group count always divides n_out.
            self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
        else:
            exit('SyncBN has not been added!')

        self.relu = nn.ReLU(inplace=True)
        self.act = act

    def forward(self, x):
        """Apply linear -> norm, then ReLU when ``act`` is set."""
        out = self.norm(self.linear(x))
        return self.relu(out) if self.act else out


# Post residual layer
class PostRes(nn.Module):
    """Residual block of two 3x3 2D convolutions with a projected shortcut when shapes differ."""

    def __init__(self, n_in, n_out, stride=1, norm='GN', ng=32, act=True):
        super().__init__()
        assert norm in ('GN', 'BN', 'SyncBN')

        def make_norm():
            # Norm layer over n_out channels for the requested scheme.
            if norm == 'GN':
                return nn.GroupNorm(gcd(ng, n_out), n_out)
            if norm == 'BN':
                return nn.BatchNorm2d(n_out)
            exit('SyncBN has not been added!')

        self.conv1 = nn.Conv2d(n_in, n_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(n_out, n_out, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # All use name bn1 and bn2 to load imagenet pretrained weights
        self.bn1 = make_norm()
        self.bn2 = make_norm()

        # A 1x1 projection is only needed when the shortcut changes resolution or width.
        needs_projection = stride != 1 or n_out != n_in
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(n_in, n_out, kernel_size=1, stride=stride, bias=False),
                make_norm())
        else:
            self.downsample = None

        self.act = act

    def forward(self, x):
        """Return ``bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)``, with a final ReLU when ``act`` is set."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut

        return self.relu(out) if self.act else out


class Res1d(nn.Module):
    """1-D residual block: two convolutions plus a skip connection.

    Args:
        n_in: Number of input channels.
        n_out: Number of output channels.
        kernel_size: Kernel size of both convolutions; padding is
            ``(kernel_size - 1) // 2``.
        stride: Stride of the first convolution (and of the projection
            on the skip path, when one is needed).
        norm: Normalization type, ``'GN'`` (GroupNorm) or ``'BN'``
            (BatchNorm1d). ``'SyncBN'`` is a recognized name but is not
            implemented.
        ng: Upper bound on the number of GroupNorm groups; the actual
            group count is ``gcd(ng, n_out)``.
        act: Whether to apply ReLU after the residual addition.

    Raises:
        ValueError: If ``norm`` is not one of the recognized names.
        NotImplementedError: If ``norm`` is ``'SyncBN'``.
    """

    def __init__(self, n_in, n_out, kernel_size=3, stride=1, norm='GN', ng=32, act=True):
        super(Res1d, self).__init__()
        # Fail with a regular exception instead of terminating the
        # interpreter: building a layer must never kill the host process.
        if norm not in ('GN', 'BN', 'SyncBN'):
            raise ValueError("norm must be 'GN', 'BN' or 'SyncBN', got %r" % (norm,))
        if norm == 'SyncBN':
            raise NotImplementedError('SyncBN has not been added!')

        padding = (int(kernel_size) - 1) // 2
        self.conv1 = nn.Conv1d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv2 = nn.Conv1d(n_out, n_out, kernel_size=kernel_size, padding=padding, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # All use name bn1 and bn2 to load imagenet pretrained weights
        self.bn1 = self._make_norm(norm, ng, n_out)
        self.bn2 = self._make_norm(norm, ng, n_out)

        # The skip path needs a projection whenever the main path changes
        # the sequence length or the channel count.
        if stride != 1 or n_out != n_in:
            self.downsample = nn.Sequential(
                nn.Conv1d(n_in, n_out, kernel_size=1, stride=stride, bias=False),
                self._make_norm(norm, ng, n_out))
        else:
            self.downsample = None

        self.act = act

    @staticmethod
    def _make_norm(norm, ng, n_out):
        """Build the normalization layer for ``n_out`` channels."""
        if norm == 'GN':
            return nn.GroupNorm(gcd(ng, n_out), n_out)
        return nn.BatchNorm1d(n_out)

    def forward(self, x):
        """Return ``conv-norm-relu-conv-norm(x) + skip(x)``, optionally ReLU'd."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            x = self.downsample(x)

        out += x
        if self.act:
            out = self.relu(out)
        return out


class LinearRes(nn.Module):
    """Residual block built from two bias-free linear layers.

    Args:
        n_in: Number of input features.
        n_out: Number of output features.
        norm: Normalization type, ``'GN'`` (GroupNorm) or ``'BN'``
            (BatchNorm1d). ``'SyncBN'`` is a recognized name but is not
            implemented.
        ng: Upper bound on the number of GroupNorm groups; the actual
            group count is ``gcd(ng, n_out)``.

    Raises:
        ValueError: If ``norm`` is not one of the recognized names.
        NotImplementedError: If ``norm`` is ``'SyncBN'``.
    """

    def __init__(self, n_in, n_out, norm='GN', ng=32):
        super(LinearRes, self).__init__()
        # Fail with a regular exception instead of terminating the
        # interpreter: building a layer must never kill the host process.
        if norm not in ('GN', 'BN', 'SyncBN'):
            raise ValueError("norm must be 'GN', 'BN' or 'SyncBN', got %r" % (norm,))
        if norm == 'SyncBN':
            raise NotImplementedError('SyncBN has not been added!')

        self.linear1 = nn.Linear(n_in, n_out, bias=False)
        self.linear2 = nn.Linear(n_out, n_out, bias=False)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = self._make_norm(norm, ng, n_out)
        self.norm2 = self._make_norm(norm, ng, n_out)

        # The skip path only needs a projection when the width changes.
        if n_in != n_out:
            self.transform = nn.Sequential(
                nn.Linear(n_in, n_out, bias=False),
                self._make_norm(norm, ng, n_out))
        else:
            self.transform = None

    @staticmethod
    def _make_norm(norm, ng, n_out):
        """Build the normalization layer for ``n_out`` features."""
        if norm == 'GN':
            return nn.GroupNorm(gcd(ng, n_out), n_out)
        return nn.BatchNorm1d(n_out)

    def forward(self, x):
        """Return ``relu(linear-norm-relu-linear-norm(x) + skip(x))``."""
        out = self.linear1(x)
        out = self.norm1(out)
        out = self.relu(out)
        out = self.linear2(out)
        out = self.norm2(out)

        if self.transform is not None:
            out += self.transform(x)
        else:
            out += x

        out = self.relu(out)
        return out


class Null(nn.Module):
    """Identity module: ``forward`` hands its argument back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (no copy, no computation)."""
        return x


def linear_interp(x, n_max):
    """Return linear interpolation weights and indices for normalized positions.

    Example: for absolute position 1.2, the neighboring pixels have indices
    0 and 1, corresponding to coordinates 0.5 and 1.5 (pixel centers), and
    the linear weights are 0.3 and 0.7.

    Args:
        x: Normalized positions, ranging from 0 to 1, float Tensor.
        n_max: Size of the dimension (pixels); ``x`` is multiplied by it to
            get absolute positions.

    Returns:
        Tuple ``(left_weight, left_index, right_weight, right_index)``.
        Positions are clamped to ``[0, n_max - 1]`` in pixel-center
        coordinates, and the right index is clamped to ``n_max - 1``.
    """
    # Shift by half a pixel so integer values sit on pixel centers, then
    # keep everything inside the valid index range.
    pos = torch.clamp(x * n_max - 0.5, min=0, max=n_max - 1)
    base = torch.floor(pos)

    right_w = pos - base
    left_w = 1.0 - right_w

    left_idx = base.long()
    right_idx = torch.clamp(left_idx + 1, max=n_max - 1)

    return left_w, left_idx, right_w, right_idx


def get_pixel_feat(fm, bboxes, pts_range):
    """Bilinearly sample a feature map at the given (x, y) positions.

    Args:
        fm: Feature map, 3-D tensor indexed as (channel, row, column).
        bboxes: 2-D tensor whose columns 0 and 1 are the x and y
            coordinates to sample at.
        pts_range: Sequence whose first four entries are
            (x_min, x_max, y_min, y_max), used to normalize coordinates.
            The y axis is flipped, so ``y_max`` maps to row 0.

    Returns:
        Tensor with one row per position and one column per channel.
    """
    x_min, x_max, y_min, y_max = pts_range[:4]
    u = (bboxes[:, 0] - x_min) / (x_max - x_min)
    v = (y_max - bboxes[:, 1]) / (y_max - y_min)

    _, fm_h, fm_w = fm.size()
    x_lo_w, x_lo, x_hi_w, x_hi = linear_interp(u, fm_w)
    y_lo_w, y_lo, y_hi_w, y_hi = linear_interp(v, fm_h)

    # Weighted sum over the four surrounding pixels, accumulated in a
    # fixed order: (x_lo, y_lo), (x_lo, y_hi), (x_hi, y_lo), (x_hi, y_hi).
    corners = (
        (x_lo_w * y_lo_w, y_lo, x_lo),
        (x_lo_w * y_hi_w, y_hi, x_lo),
        (x_hi_w * y_lo_w, y_lo, x_hi),
        (x_hi_w * y_hi_w, y_hi, x_hi),
    )
    feat = None
    for weight, rows, cols in corners:
        term = weight.unsqueeze(1) * fm[:, rows, cols].transpose(0, 1)
        feat = term if feat is None else feat + term
    return feat


def get_roi_feat(fm, bboxes, roi_size, pts_range):
    """Given a set of BEV bboxes, get their BEV ROI features.

    Each box is divided into a ``roi_size[0] x roi_size[1]`` grid of bin
    centers, the grid is rotated by the box angle and translated to the box
    center, and the feature map is bilinearly sampled at those points.
    Points that fall outside the open unit square after normalization get
    an all-zero feature.

    Args:
        fm: Feature map, float tensor, chw.
        bboxes: BEV bboxes, n x 5 float tensor (cx, cy, wid, hgt, theta).
        roi_size: ROI size (number of bins), [int] or int.
        pts_range: Range of points, (x_min, x_max, y_min, y_max, ...); only
            the first four entries are used.

    Returns:
        Extracted features of size (num_roi, c, roi_size[0], roi_size[1]).
    """
    if isinstance(roi_size, Number):
        roi_size = [roi_size, roi_size]
    n_rows, n_cols = roi_size[0], roi_size[1]
    num_bboxes = len(bboxes)

    cx, cy = bboxes[:, 0], bboxes[:, 1]
    wid, hgt = bboxes[:, 2], bboxes[:, 3]
    theta = bboxes[:, 4]
    sin_t = torch.sin(theta)
    cos_t = torch.cos(theta)

    # Per-box 2x2 rotation matrix [[cos, -sin], [sin, cos]].
    rot_mat = bboxes.new_empty((num_bboxes, 2, 2))
    rot_mat[:, 0, 0] = cos_t
    rot_mat[:, 0, 1] = -sin_t
    rot_mat[:, 1, 0] = sin_t
    rot_mat[:, 1, 1] = cos_t

    # Bin centers in box-local coordinates, as fractions in (-0.5, 0.5)
    # scaled by the box size. Rows run from the largest y to the smallest.
    x_bin = (torch.arange(n_cols).float().to(bboxes.device) + 0.5) / n_cols - 0.5
    y_bin = (torch.arange(n_rows - 1, -1, -1).float().to(bboxes.device) + 0.5) / n_rows - 0.5
    offset = bboxes.new_empty((num_bboxes, n_rows, n_cols, 2))
    offset[..., 0] = x_bin.view(1, 1, -1) * wid.view(-1, 1, 1)
    offset[..., 1] = y_bin.view(1, -1, 1) * hgt.view(-1, 1, 1)

    # Rotate every bin offset by its box angle.
    offset = torch.matmul(
        rot_mat.view(num_bboxes, 1, 1, 2, 2),
        offset.view(num_bboxes, n_rows, n_cols, 2, 1),
    ).view(num_bboxes, n_rows, n_cols, 2)

    x = (cx.view(-1, 1, 1) + offset[..., 0]).view(-1)
    y = (cy.view(-1, 1, 1) + offset[..., 1]).view(-1)

    # Normalize to [0, 1]; y is flipped so y_max maps to row 0.
    x_min, x_max, y_min, y_max = pts_range[:4]
    x = (x - x_min) / (x_max - x_min)
    y = (y_max - y) / (y_max - y_min)

    fm_c, fm_h, fm_w = fm.size()
    feat = fm.new_empty((num_bboxes * n_rows * n_cols, fm_c), dtype=torch.float)
    inside = (x > 0) & (x < 1) & (y > 0) & (y < 1)
    x = x[inside]
    y = y[inside]

    x_lo_w, x_lo, x_hi_w, x_hi = linear_interp(x, fm_w)
    y_lo_w, y_lo, y_hi_w, y_hi = linear_interp(y, fm_h)
    sampled = (x_lo_w * y_lo_w).unsqueeze(1) * fm[:, y_lo, x_lo].transpose(0, 1)
    sampled = sampled + (x_lo_w * y_hi_w).unsqueeze(1) * fm[:, y_hi, x_lo].transpose(0, 1)
    sampled = sampled + (x_hi_w * y_lo_w).unsqueeze(1) * fm[:, y_lo, x_hi].transpose(0, 1)
    sampled = sampled + (x_hi_w * y_hi_w).unsqueeze(1) * fm[:, y_hi, x_hi].transpose(0, 1)

    feat[inside] = sampled
    feat[~inside] = 0

    feat = feat.view(num_bboxes, n_rows * n_cols, fm_c)
    feat = feat.transpose(1, 2).contiguous().view(num_bboxes, -1, n_rows, n_cols)
    return feat


In [7]:
import numpy as np
import os
import sys
from fractions import gcd
from numbers import Number

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from numpy import float64, ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union


### config ###
# Train: iteration/epoch schedule, optimizer choice and data-loader settings.
config = {
    "display_iters": 205942,
    "val_iters": 205942 * 2,
    "save_freq": 1.0,
    "epoch": 0,
    "horovod": True,
    "opt": "adam",
    "num_epochs": 36,
    "lr": [1e-3, 1e-4],
    "lr_epochs": [32],
    # "lr_func" is intentionally not set here:
    # config["lr_func"] = StepLR(config["lr"], config["lr_epochs"])
    "batch_size": 32,
    "val_batch_size": 32,
    "workers": 0,
}
# Validation uses the same number of loader workers as training.
config["val_workers"] = config["workers"]

# Model: feature widths, fusion distance thresholds and prediction layout.
config.update(
    {
        "rot_aug": False,
        "pred_range": [-100.0, 100.0, -100.0, 100.0],
        "num_scales": 6,
        "n_actor": 128,
        "n_map": 128,
        "actor2map_dist": 7.0,
        "map2actor_dist": 6.0,
        "actor2actor_dist": 100.0,
        "pred_size": 30,
        "pred_step": 1,
    }
)
# Number of predicted steps is derived from the two entries above.
config["num_preds"] = config["pred_size"] // config["pred_step"]
config.update(
    {
        "num_mods": 6,
        "cls_coef": 1.0,
        "reg_coef": 1.0,
        "mgn": 0.2,
        "cls_th": 2.0,
        "cls_ignore": 0.2,
    }
)
### end of config ###

class Net(nn.Module):
    """
    Lane Graph Network variant assembled from the following components:
        1. ActorNet: a 1D CNN that processes the trajectory input
        2. MapNet: LaneGraphCNN that learns map representations from
           vectorized map data
        3. Actor-Map fusion. The full cycle is A2M -> M2M -> M2A -> A2A;
           in this variant the A2M and M2M stages are disabled, so only
            a. M2A: fuses map features into the actor features
            b. A2A: handles the interaction between actors
           are applied.
        4. PredNet: prediction header for motion forecasting using the
           features produced by A2A
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.actor_net = ActorNet(config)
        self.map_net = MapNet(config)

        # A2M / M2M are left out of this variant:
        #self.a2m = A2M(config)
        #self.m2m = M2M(config)
        self.m2a = M2A(config)
        self.a2a = A2A(config)

        self.pred_net = PredNet(config)

    def forward(self, data: Dict) -> Dict[str, List[Tensor]]:
        """Run the network on a batch dict and return the PredNet output.

        Reads ``data["feats"]``, ``data["ctrs"]``, ``data["graph"]``,
        ``data["rot"]`` and ``data["orig"]``. Every entry of the returned
        ``out["reg"]`` is multiplied by the sample's ``rot`` and shifted by
        its ``orig`` before being returned.
        """
        # Actor branch.
        actors, actor_idcs = actor_gather(gpu(data["feats"]))
        actor_ctrs = gpu(data["ctrs"])
        actors = self.actor_net(actors)

        # Map branch.
        graph = graph_gather(to_long(gpu(data["graph"])))
        nodes, node_idcs, node_ctrs = self.map_net(graph)

        # Fusion (A2M and M2M disabled, see __init__).
        #nodes = self.a2m(nodes, graph, actors, actor_idcs, actor_ctrs)
        #nodes = self.m2m(nodes, graph)
        actors = self.m2a(actors, actor_idcs, actor_ctrs, nodes, node_idcs, node_ctrs)
        actors = self.a2a(actors, actor_idcs, actor_ctrs)

        # Prediction, then transform back to world coordinates per sample.
        out = self.pred_net(actors, actor_idcs, actor_ctrs)
        rot, orig = gpu(data["rot"]), gpu(data["orig"])
        for i, reg in enumerate(out["reg"]):
            out["reg"][i] = torch.matmul(reg, rot[i]) + orig[i].view(1, 1, 1, -1)
        return out



def actor_gather(actors: List[Tensor]) -> Tuple[Tensor, List[Tensor]]:
    """Concatenate per-sample actor tensors into one batch tensor.

    Each element of ``actors`` has its dimensions 1 and 2 swapped before
    all elements are concatenated along dimension 0.

    Returns:
        The concatenated tensor, and for every sample a tensor holding the
        row indices its actors occupy in the concatenated tensor.
    """
    sizes = [len(sample) for sample in actors]
    stacked = torch.cat([sample.transpose(1, 2) for sample in actors], 0)

    actor_idcs = []
    start = 0
    for size in sizes:
        actor_idcs.append(torch.arange(start, start + size).to(stacked.device))
        start += size
    return stacked, actor_idcs


def graph_gather(graphs):
    """Merge a list of per-sample lane graphs into one batched graph.

    Node-indexed tensors ("feats", "turn", "control", "intersect") are
    concatenated, and every edge index ("u"/"v" under "pre", "suc", "left",
    "right") is shifted by the number of nodes in the preceding samples.

    Returns:
        Dict with keys "idcs" (per-sample node index tensors), "ctrs"
        (list of per-sample "ctrs" entries), the concatenated node tensors,
        "pre"/"suc" (one {"u", "v"} dict per entry of the first sample's
        "pre" list) and "left"/"right" ({"u", "v"} dicts).
    """
    offsets = []
    node_idcs = []
    total = 0
    for sample in graphs:
        offsets.append(total)
        idcs = torch.arange(total, total + sample["num_nodes"]).to(
            sample["feats"].device
        )
        node_idcs.append(idcs)
        total = total + sample["num_nodes"]

    graph = dict()
    graph["idcs"] = node_idcs
    graph["ctrs"] = [sample["ctrs"] for sample in graphs]

    for key in ("feats", "turn", "control", "intersect"):
        graph[key] = torch.cat([sample[key] for sample in graphs], 0)

    # The number of scales is taken from the first sample's "pre" list
    # and used for both "pre" and "suc".
    num_scales = len(graphs[0]["pre"])
    for rel in ("pre", "suc"):
        graph[rel] = [
            {
                end: torch.cat(
                    [
                        sample[rel][scale][end] + off
                        for sample, off in zip(graphs, offsets)
                    ],
                    0,
                )
                for end in ("u", "v")
            }
            for scale in range(num_scales)
        ]

    for side in ("left", "right"):
        graph[side] = dict()
        for end in ("u", "v"):
            shifted = [
                sample[side][end] + off for sample, off in zip(graphs, offsets)
            ]
            # 0-dim entries are replaced with an empty index tensor so
            # that the concatenation below works.
            shifted = [
                t if t.dim() > 0 else graph["pre"][0]["u"].new().resize_(0)
                for t in shifted
            ]
            graph[side][end] = torch.cat(shifted)
    return graph


class ActorNet(nn.Module):
    """
    Actor feature extractor with Conv1D.

    Three stages of ``Res1d`` blocks (32, 64 and 128 channels; the second
    and third stage start with a stride-2 block) are followed by lateral
    ``Conv1d`` layers that project every stage output to
    ``config["n_actor"]`` channels, and by a final ``Res1d`` block.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        norm, ng = "GN", 1

        block_types = [Res1d, Res1d, Res1d]
        widths = [32, 64, 128]
        depths = [2, 2, 2]

        stages = []
        in_ch = 3
        for stage_idx in range(len(depths)):
            block, out_ch = block_types[stage_idx], widths[stage_idx]
            # Only the first stage keeps the input resolution.
            if stage_idx == 0:
                layers = [block(in_ch, out_ch, norm=norm, ng=ng)]
            else:
                layers = [block(in_ch, out_ch, stride=2, norm=norm, ng=ng)]
            for _ in range(1, depths[stage_idx]):
                layers.append(block(out_ch, out_ch, norm=norm, ng=ng))
            stages.append(nn.Sequential(*layers))
            in_ch = out_ch
        self.groups = nn.ModuleList(stages)

        n = config["n_actor"]
        self.lateral = nn.ModuleList(
            [Conv1d(width, n, norm=norm, ng=ng, act=False) for width in widths]
        )

        self.output = Res1d(n, n, norm=norm, ng=ng)

    def forward(self, actors: Tensor) -> Tensor:
        """Return one feature vector per actor (last position of the output)."""
        stage_outputs = []
        out = actors
        for stage in self.groups:
            out = stage(out)
            stage_outputs.append(out)

        # Top-down pass: start from the coarsest stage, upsample by 2 and
        # add the lateral projection of the next finer stage.
        out = self.lateral[-1](stage_outputs[-1])
        for i in reversed(range(len(stage_outputs) - 1)):
            out = F.interpolate(out, scale_factor=2, mode="linear", align_corners=False)
            out += self.lateral[i](stage_outputs[i])

        return self.output(out)[:, :, -1]


class MapNet(nn.Module):
    """
    Map Graph feature extractor with LaneGraphCNN.

    Node features are built from ``graph["ctrs"]`` and ``graph["feats"]``
    and refined by two fusion layers that aggregate features along the
    "pre"/"suc" edges of every scale and along the "left"/"right" edges.
    """
    def __init__(self, config):
        super(MapNet, self).__init__()
        self.config = config
        n_map = config["n_map"]
        norm = "GN"
        ng = 1

        self.input = nn.Sequential(
            nn.Linear(2, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )
        self.seg = nn.Sequential(
            nn.Linear(2, n_map),
            nn.ReLU(inplace=True),
            Linear(n_map, n_map, norm=norm, ng=ng, act=False),
        )

        keys = ["ctr", "norm", "ctr2", "left", "right"]
        for i in range(config["num_scales"]):
            keys.append("pre" + str(i))
            keys.append("suc" + str(i))

        fuse = dict()
        for key in keys:
            fuse[key] = []

        # Two fusion layers; every key gets one module per layer.
        for i in range(2):
            for key in fuse:
                if key in ["norm"]:
                    fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
                elif key in ["ctr2"]:
                    fuse[key].append(Linear(n_map, n_map, norm=norm, ng=ng, act=False))
                else:
                    fuse[key].append(nn.Linear(n_map, n_map, bias=False))

        for key in fuse:
            fuse[key] = nn.ModuleList(fuse[key])
        self.fuse = nn.ModuleDict(fuse)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, graph):
        """Compute lane-node features for a batched graph.

        Args:
            graph: Batched graph dict. Reads the keys "feats", "ctrs",
                "idcs", "pre", "suc", "left" and "right".

        Returns:
            Tuple ``(feat, idcs, ctrs)``. For a graph without nodes, or
            without edges at the last "pre"/"suc" scale, ``feat`` and
            ``ctrs`` are empty tensors and ``idcs`` is a list with one
            empty long tensor per sample.
        """
        if (
            len(graph["feats"]) == 0
            or len(graph["pre"][-1]["u"]) == 0
            or len(graph["suc"][-1]["u"]) == 0
        ):
            temp = graph["feats"]
            # The per-sample index list lives under "idcs" (the key the
            # regular return path below uses); "node_idcs" is never set.
            return (
                temp.new().resize_(0),
                [temp.new().long().resize_(0) for _ in graph["idcs"]],
                temp.new().resize_(0),
            )

        ctrs = torch.cat(graph["ctrs"], 0)
        feat = self.input(ctrs)
        feat += self.seg(graph["feats"])
        feat = self.relu(feat)

        # fuse map
        res = feat
        for i in range(len(self.fuse["ctr"])):
            temp = self.fuse["ctr"][i](feat)
            for key in self.fuse:
                if key.startswith("pre") or key.startswith("suc"):
                    k1 = key[:3]
                    k2 = int(key[3:])
                    temp.index_add_(
                        0,
                        graph[k1][k2]["u"],
                        self.fuse[key][i](feat[graph[k1][k2]["v"]]),
                    )

            # Skip the lateral edges when there are none.
            if len(graph["left"]["u"]) > 0:
                temp.index_add_(
                    0,
                    graph["left"]["u"],
                    self.fuse["left"][i](feat[graph["left"]["v"]]),
                )
            if len(graph["right"]["u"]) > 0:
                temp.index_add_(
                    0,
                    graph["right"]["u"],
                    self.fuse["right"][i](feat[graph["right"]["v"]]),
                )

            feat = self.fuse["norm"][i](temp)
            feat = self.relu(feat)

            feat = self.fuse["ctr2"][i](feat)
            feat += res
            feat = self.relu(feat)
            res = feat
        return feat, graph["idcs"], graph["ctrs"]


class A2M(nn.Module):
    """
    Actor to Map Fusion: fuses real-time traffic information from
    actor nodes to lane nodes.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        n_map = config["n_map"]
        norm, ng = "GN", 1

        # Fuses node features with the 4 extra meta channels.
        self.meta = Linear(n_map + 4, n_map, norm=norm, ng=ng)
        self.att = nn.ModuleList(
            [Att(n_map, config["n_actor"]) for _ in range(2)]
        )

    def forward(self, feat: Tensor, graph: Dict[str, Union[List[Tensor], Tensor, List[Dict[str, Tensor]], Dict[str, Tensor]]], actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor]) -> Tensor:
        """Fuse meta information into ``feat``, then attend to the actors.

        ``graph["turn"]``, ``graph["control"]`` and ``graph["intersect"]``
        are concatenated to the node features before ``self.meta``; each
        attention layer then uses ``config["actor2map_dist"]`` as its
        distance threshold.
        """
        meta = torch.cat(
            (
                graph["turn"],
                graph["control"].unsqueeze(1),
                graph["intersect"].unsqueeze(1),
            ),
            1,
        )
        feat = self.meta(torch.cat((feat, meta), 1))

        dist_th = self.config["actor2map_dist"]
        for att in self.att:
            feat = att(
                feat,
                graph["idcs"],
                graph["ctrs"],
                actors,
                actor_idcs,
                actor_ctrs,
                dist_th,
            )
        return feat


class M2M(nn.Module):
    """
    The lane to lane block: propagates information over lane
    graphs and updates the features of lane nodes.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        n_map = config["n_map"]
        norm, ng = "GN", 1

        keys = ["ctr", "norm", "ctr2", "left", "right"]
        for scale in range(config["num_scales"]):
            keys += ["pre" + str(scale), "suc" + str(scale)]

        # Four fusion layers; every key gets one module per layer.
        fuse = {key: [] for key in keys}
        for _ in range(4):
            for key, layers in fuse.items():
                if key == "norm":
                    layers.append(nn.GroupNorm(gcd(ng, n_map), n_map))
                elif key == "ctr2":
                    layers.append(Linear(n_map, n_map, norm=norm, ng=ng, act=False))
                else:
                    layers.append(nn.Linear(n_map, n_map, bias=False))

        self.fuse = nn.ModuleDict(
            {key: nn.ModuleList(layers) for key, layers in fuse.items()}
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, feat: Tensor, graph: Dict) -> Tensor:
        """Refine node features by aggregating along the graph edges.

        Per layer: a linear map of every node is summed with linear maps of
        its "pre"/"suc" neighbors at every scale and of its "left"/"right"
        neighbors, normalized, passed through ``ctr2`` and added to the
        layer input as a residual.
        """
        res = feat
        for i in range(len(self.fuse["ctr"])):
            temp = self.fuse["ctr"][i](feat)
            for key in self.fuse:
                if key.startswith(("pre", "suc")):
                    edges = graph[key[:3]][int(key[3:])]
                    temp.index_add_(
                        0, edges["u"], self.fuse[key][i](feat[edges["v"]])
                    )

            # Lateral neighbors are only aggregated when edges exist.
            for side in ("left", "right"):
                edges = graph[side]
                if len(edges["u"]) > 0:
                    temp.index_add_(
                        0, edges["u"], self.fuse[side][i](feat[edges["v"]])
                    )

            feat = self.relu(self.fuse["norm"][i](temp))

            feat = self.fuse["ctr2"][i](feat)
            feat += res
            feat = self.relu(feat)
            res = feat
        return feat


class M2A(nn.Module):
    """
    The lane to actor block: fuses map information from lane nodes
    into actor nodes through two stacked attention layers.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        n_actor = config["n_actor"]
        n_map = config["n_map"]

        self.att = nn.ModuleList([Att(n_actor, n_map) for _ in range(2)])

    def forward(self, actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor], nodes: Tensor, node_idcs: List[Tensor], node_ctrs: List[Tensor]) -> Tensor:
        """Update ``actors`` with context from ``nodes``.

        Every attention layer uses ``config["map2actor_dist"]`` as its
        distance threshold.
        """
        dist_th = self.config["map2actor_dist"]
        for att in self.att:
            actors = att(
                actors,
                actor_idcs,
                actor_ctrs,
                nodes,
                node_idcs,
                node_ctrs,
                dist_th,
            )
        return actors


class A2A(nn.Module):
    """
    The actor to actor block: models interactions among actors through
    two stacked attention layers in which actors attend to actors.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        n_actor = config["n_actor"]
        # Not used below; the lookup is retained so that a config without
        # "n_map" still fails here.
        n_map = config["n_map"]

        self.att = nn.ModuleList([Att(n_actor, n_actor) for _ in range(2)])

    def forward(self, actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor]) -> Tensor:
        """Update ``actors`` using the actors themselves as context.

        Every attention layer uses ``config["actor2actor_dist"]`` as its
        distance threshold.
        """
        dist_th = self.config["actor2actor_dist"]
        for att in self.att:
            actors = att(
                actors,
                actor_idcs,
                actor_ctrs,
                actors,
                actor_idcs,
                actor_ctrs,
                dist_th,
            )
        return actors


class EncodeDist(nn.Module):
    """Embed 2-D offsets after a signed log compression.

    Args:
        n: Width of the embedding.
        linear: When true, a second ``nn.Linear(n, n)`` follows the ReLU.
    """

    def __init__(self, n, linear=True):
        super().__init__()

        layers = [nn.Linear(2, n), nn.ReLU(inplace=True)]
        if linear:
            layers.append(nn.Linear(n, n))
        self.block = nn.Sequential(*layers)

    def forward(self, dist):
        """Map every value ``c`` to ``sign(c) * log(|c| + 1)``, then embed."""
        columns = (dist[:, :1], dist[:, 1:])
        compressed = torch.cat(
            [torch.sign(c) * torch.log(torch.abs(c) + 1.0) for c in columns],
            1,
        )
        return self.block(compressed)


class PredNet(nn.Module):
    """
    Final motion forecasting with Linear Residual blocks.

    One regression head per mode predicts ``2 * config["num_preds"]``
    values per actor; a shared classification head scores every mode from
    the actor feature and the predicted end point.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        norm, ng = "GN", 1

        n_actor = config["n_actor"]

        heads = []
        for _ in range(config["num_mods"]):
            heads.append(
                nn.Sequential(
                    LinearRes(n_actor, n_actor, norm=norm, ng=ng),
                    nn.Linear(n_actor, 2 * config["num_preds"]),
                )
            )
        self.pred = nn.ModuleList(heads)

        self.att_dest = AttDest(n_actor)
        self.cls = nn.Sequential(
            LinearRes(n_actor, n_actor, norm=norm, ng=ng), nn.Linear(n_actor, 1)
        )

    def forward(self, actors: Tensor, actor_idcs: List[Tensor], actor_ctrs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """Predict per-mode trajectories and scores for every actor.

        Returns:
            Dict with "cls" and "reg", each a list with one tensor per
            sample. Modes are ordered by descending score.
        """
        # One row of (x, y) pairs per actor and mode.
        reg = torch.stack([head(actors) for head in self.pred], 1)
        reg = reg.view(reg.size(0), reg.size(1), -1, 2)

        # Predictions are offsets from the actor centers.
        for i, idcs in enumerate(actor_idcs):
            reg[idcs] = reg[idcs] + actor_ctrs[i].view(-1, 1, 1, 2)

        # Score every mode from its (detached) final predicted position.
        dest_ctrs = reg[:, :, -1].detach()
        feats = self.att_dest(actors, torch.cat(actor_ctrs, 0), dest_ctrs)
        cls = self.cls(feats).view(-1, self.config["num_mods"])

        # Reorder the modes of every actor by descending score.
        cls, sort_idcs = cls.sort(1, descending=True)
        row_idcs = torch.arange(len(sort_idcs)).long().to(sort_idcs.device)
        row_idcs = row_idcs.view(-1, 1).expand_as(sort_idcs).reshape(-1)
        reg = reg[row_idcs, sort_idcs.view(-1)].view(cls.size(0), cls.size(1), -1, 2)

        out = dict()
        out["cls"] = [cls[idcs] for idcs in actor_idcs]
        out["reg"] = [reg[idcs] for idcs in actor_idcs]
        return out


class Att(nn.Module):
    """
    Attention block to pass context nodes information to target nodes
    This is used in Actor2Map, Actor2Actor, Map2Actor and Map2Map

    Args:
        n_agt: Feature width of the target ("agent") nodes.
        n_ctx: Feature width of the context nodes.
    """
    def __init__(self, n_agt: int, n_ctx: int) -> None:
        super(Att, self).__init__()
        norm = "GN"
        ng = 1

        # Embeds the 2-D offset between a target node and a context node.
        self.dist = nn.Sequential(
            nn.Linear(2, n_ctx),
            nn.ReLU(inplace=True),
            Linear(n_ctx, n_ctx, norm=norm, ng=ng),
        )

        self.query = Linear(n_agt, n_ctx, norm=norm, ng=ng)

        # Consumes the concatenation (offset embedding, query, context).
        self.ctx = nn.Sequential(
            Linear(3 * n_ctx, n_agt, norm=norm, ng=ng),
            nn.Linear(n_agt, n_agt, bias=False),
        )

        self.agt = nn.Linear(n_agt, n_agt, bias=False)
        self.norm = nn.GroupNorm(gcd(ng, n_agt), n_agt)
        self.linear = Linear(n_agt, n_agt, norm=norm, ng=ng, act=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, agts: Tensor, agt_idcs: List[Tensor], agt_ctrs: List[Tensor], ctx: Tensor, ctx_idcs: List[Tensor], ctx_ctrs: List[Tensor], dist_th: float) -> Tensor:
        """Update target features with messages from nearby context nodes.

        Args:
            agts: Target node features, one row per node over the batch.
            agt_idcs: Per-sample row indices into ``agts``.
            agt_ctrs: Per-sample 2-D centers of the target nodes.
            ctx: Context node features, one row per node over the batch.
            ctx_idcs: Per-sample row indices into ``ctx``.
            ctx_ctrs: Per-sample 2-D centers of the context nodes.
            dist_th: A context node contributes to a target node only if
                their Euclidean distance is at most this value.

        Returns:
            Updated target features with the same shape as ``agts``.
        """
        res = agts
        if len(ctx) == 0:
            agts = self.agt(agts)
            agts = self.relu(agts)
            agts = self.linear(agts)
            agts += res
            agts = self.relu(agts)
            return agts

        batch_size = len(agt_idcs)
        hi, wi = [], []
        hi_count, wi_count = 0, 0
        for i in range(batch_size):
            dist = agt_ctrs[i].view(-1, 1, 2) - ctx_ctrs[i].view(1, -1, 2)
            dist = torch.sqrt((dist ** 2).sum(2))
            mask = dist <= dist_th

            idcs = torch.nonzero(mask, as_tuple=False)
            if len(idcs) > 0:
                hi.append(idcs[:, 0] + hi_count)
                wi.append(idcs[:, 1] + wi_count)
            # The offsets must advance for every sample, including those
            # without any pair in range; otherwise the pairs of all later
            # samples would point at the wrong rows of the batch tensors.
            hi_count += len(agt_idcs[i])
            wi_count += len(ctx_idcs[i])

        if hi:
            hi = torch.cat(hi, 0)
            wi = torch.cat(wi, 0)

            agt_ctrs = torch.cat(agt_ctrs, 0)
            ctx_ctrs = torch.cat(ctx_ctrs, 0)
            dist = agt_ctrs[hi] - ctx_ctrs[wi]
            dist = self.dist(dist)

            query = self.query(agts[hi])

            ctx = ctx[wi]
            ctx = torch.cat((dist, query, ctx), 1)
            ctx = self.ctx(ctx)

            agts = self.agt(agts)
            agts.index_add_(0, hi, ctx)
        else:
            # No target/context pair is within range anywhere in the
            # batch: nothing to aggregate (torch.cat rejects empty lists).
            agts = self.agt(agts)

        agts = self.norm(agts)
        agts = self.relu(agts)

        agts = self.linear(agts)
        agts += res
        agts = self.relu(agts)
        return agts


class AttDest(nn.Module):
    """Combine actor features with an embedding of the offset between the
    actor center and each candidate destination."""

    def __init__(self, n_agt: int):
        super().__init__()
        norm, ng = "GN", 1

        # Embeds the 2-D offset (actor center minus destination).
        self.dist = nn.Sequential(
            nn.Linear(2, n_agt),
            nn.ReLU(inplace=True),
            Linear(n_agt, n_agt, norm=norm, ng=ng),
        )

        self.agt = Linear(2 * n_agt, n_agt, norm=norm, ng=ng)

    def forward(self, agts: Tensor, agt_ctrs: Tensor, dest_ctrs: Tensor) -> Tensor:
        """Return one fused feature row per (actor, destination) pair.

        Rows are ordered actor-major: all destinations of the first actor
        come first, then those of the second actor, and so on.
        """
        feat_dim = agts.size(1)
        num_mods = dest_ctrs.size(1)

        offsets = (agt_ctrs.unsqueeze(1) - dest_ctrs).view(-1, 2)
        offset_emb = self.dist(offsets)

        # Repeat each actor feature once per destination.
        repeated = agts.unsqueeze(1).expand(-1, num_mods, -1).reshape(-1, feat_dim)

        return self.agt(torch.cat((offset_emb, repeated), 1))

## COUNT PARAMETERS

In [8]:
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable
    parameters (those with ``requires_grad`` set)."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [9]:
# Build the network from the module-level config and print the number of
# trainable parameters of both models. `model_densetnt` is not created in
# this cell; it has to exist already when the cell runs.
model_lanegcn_ours = Net(config)
print("LANEGCN--", count_parameters(model_lanegcn_ours)) #Original-LaneGCN:3701161
print("DENSE-TNT--", count_parameters(model_densetnt)) 

LANEGCN-- 1842601
DENSE-TNT-- 1052225


  self.bn1 = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.bn2 = nn.GroupNorm(gcd(ng, n_out), n_out)
  nn.GroupNorm(gcd(ng, n_out), n_out))
  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.norm = nn.GroupNorm(gcd(ng, n_out), n_out)
  fuse[key].append(nn.GroupNorm(gcd(ng, n_map), n_map))
  self.norm = nn.GroupNorm(gcd(ng, n_agt), n_agt)
  self.norm1 = nn.GroupNorm(gcd(ng, n_out), n_out)
  self.norm2 = nn.GroupNorm(gcd(ng, n_out), n_out)
