In [None]:
import copy
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

# Keys whose per-sample tensors differ in length along dim 0; collate_single_cpu
# zero-pads them to the longest sample in the batch.
padding_keys = ['HISTORY', 'FUTURE', 'NORM_CENTER', 'LANE_VECTORS', 'NEW_LANES', 'CLASS_LIST']
# Keys whose per-sample tensors share one shape; they are stacked on a new batch dim.
stacking_keys = ['VALID_LEN']
# Keys kept as plain Python lists (one entry per sample). collate_single_cpu does not
# reference this list; any key not in the two lists above is left as a list anyway.
listing_keys = ['TARGET_MASK', 'LANE_ID', 'HEADINGS', 'RAW_HISTORY', 'RAW_FUTURE']


def collate_single_cpu(batch):
    """Collate a list of sample dicts into a single batch dict.

    The key set is taken from the first sample. Every key first becomes a list
    with one value per sample. Keys in ``stacking_keys`` are then stacked along a
    new leading batch dimension. Keys in ``padding_keys`` are zero-padded along
    dim 0 with ``pad_sequence(batch_first=True)``. All remaining keys stay lists.
    """
    collated = {key: [] for key in batch[0].keys()}
    for sample in batch:
        for key, value in sample.items():
            collated[key].append(value)

    collated.update({key: torch.stack(collated[key], dim=0) for key in stacking_keys})
    collated.update({key: pad_sequence(collated[key], batch_first=True) for key in padding_keys})
    return collated

class inDDataset(Dataset):
    """Map-style dataset over pre-processed samples.

    ``data`` is indexed with an integer; each element is expected to be a dict
    (presumably produced by the preprocessing step — confirm). Each item is a
    shallow copy in which every ``np.ndarray`` value is converted to a torch
    tensor on the global ``device``, which must be defined before items are
    fetched. float64 arrays are down-cast to float32. Other values (lists,
    scalars, strings) are passed through unchanged.
    """

    def __init__(self, data):
        super(inDDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.get_data(idx)

    def get_data(self, idx):
        """Return sample ``idx`` with NumPy arrays converted to tensors on ``device``."""
        out_dict = {}
        out_dict.update(self.data[idx])

        for k, v in out_dict.items():
            if isinstance(v, np.ndarray):
                tensor = torch.from_numpy(v)
                # Down-cast on the host before the transfer, so float64 data is
                # copied to the device once and at half the size.
                if tensor.dtype == torch.double:
                    tensor = tensor.type(torch.float32)
                out_dict[k] = tensor.to(device)

        return out_dict

def pad_track(
        track_df: pd.DataFrame,
        seq_timestamps: np.ndarray,
        base: int,
        track_len: int,
        raw_data_format: Dict[str, int],
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
    """Edge-pad one track to the window ``seq_timestamps[base:base + track_len]``.

    The rows of ``track_df`` are placed at the window positions of its first and
    last ``'frame'`` value. The first/last row is repeated ("edge" padding) to
    fill the window, and column 0 of the result is overwritten with the window's
    timestamps.

    Args:
        track_df: rows of one track; must contain a ``'frame'`` column.
        seq_timestamps: all timestamps of the sequence.
        base: start offset of the window in ``seq_timestamps``.
        track_len: window length.
        raw_data_format: unused; kept for interface compatibility.

    Returns:
        ``(padded_track_array, mask, True)`` on success. ``padded_track_array``
        has ``track_len`` rows and ``mask`` is 1.0 on observed rows and 0.0 on
        padded ones. Returns ``(None, None, False)`` when the track cannot be
        aligned to a full window. This covers:
        - the window is shorter than ``track_len``;
        - the track is empty or its first/last frame is not in the window;
        - the track has missing or duplicated frames.
    """
    track_vals = track_df.values
    track_timestamps = track_df['frame'].values
    window = seq_timestamps[base:base + track_len]

    if len(window) < track_len or len(track_timestamps) == 0:
        return None, None, False
    start_hits = np.flatnonzero(window == track_timestamps[0])
    end_hits = np.flatnonzero(window == track_timestamps[-1])
    if start_hits.size == 0 or end_hits.size == 0:
        return None, None, False
    start_idx, end_idx = start_hits[0], end_hits[0]
    if end_idx < start_idx:
        return None, None, False

    pad_width = (start_idx, track_len - end_idx - 1)
    padded_track_array = np.pad(track_vals, (pad_width, (0, 0)), "edge")
    # A gap-free track fills exactly track_len rows; fewer means missing frames,
    # more means duplicated frames — neither can be aligned to the window.
    if padded_track_array.shape[0] != track_len:
        return None, None, False

    mask = np.pad(np.ones(end_idx + 1 - start_idx), pad_width, 'constant')
    # Column 0 is overwritten with the window timestamps (padded rows included).
    padded_track_array[:, 0] = window
    assert mask.shape[0] == padded_track_array.shape[0]
    return padded_track_array, mask, True



In [None]:
import pickle

# Load the four pre-processed intersections in order A, B, C, D.
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# if they come from a trusted source.
_intersection_pickles = (
    r'intersectionA_data_with_map.pkl',
    r'intersectionB_data_with_map.pkl',
    r'intersectionC_data_with_map.pkl',
    r'intersectionD_data_with_map.pkl',
)
_loaded_intersections = []
for processed_data_path in _intersection_pickles:
    with open(processed_data_path, 'rb') as f:
        _loaded_intersections.append(pickle.load(f))
data_location1, data_location2, data_location3, data_location4 = _loaded_intersections

In [None]:
# Intersection A is the target domain; the other three intersections (B, C, D) are the source domains.

def _tile_to_size(items, size, split_name='split'):
    """Repeat ``items`` end to end and truncate the result to exactly ``size`` elements."""
    if size == 0:
        return items[:0]
    if len(items) == 0:
        raise ValueError(f"cannot impute an empty {split_name} split to size {size}")
    repeats = -(-size // len(items))  # ceil(size / len(items)) without needing `math`
    return np.concatenate([items] * repeats, axis=0)[:size]


def series_impute(small_train,small_val,small_test,size_train,size_val,size_test):
    """Up-sample three splits by cyclic repetition to the requested sizes.

    If all three splits already have exactly the requested sizes, they are
    returned unchanged (the same objects). Otherwise each split is repeated end
    to end and truncated to its size. The result is whatever ``np.concatenate``
    produces, e.g. an object array for a list of dicts.

    Raises:
        ValueError: if a split is empty but a positive size is requested.
    """
    if len(small_train) == size_train and len(small_val) == size_val and len(small_test) == size_test:
        return small_train,small_val,small_test
    train_tmp = _tile_to_size(small_train, size_train, 'train')
    val_tmp = _tile_to_size(small_val, size_val, 'val')
    test_tmp = _tile_to_size(small_test, size_test, 'test')
    return train_tmp, val_tmp, test_tmp

def random_split_data(total_data,mr,seed=None):
    """Shuffle ``total_data`` and split it 80/10/10 into train/val/test lists.

    Args:
        total_data: indexable sequence of samples.
        mr: fraction in [0, 1] of the 80 % training portion to keep, e.g.
            0.0001 keeps almost none of it. Val and test are never reduced.
        seed: optional int. When given, a local ``np.random.default_rng(seed)``
            drives the shuffle, so the split is reproducible. When None (the
            default), the global NumPy RNG is used as before.

    Returns:
        ``(train_data, val_data, test_data)`` as lists.

    Raises:
        ValueError: if ``mr`` is outside [0, 1]; ``mr > 1`` would make the
            train indices overlap the validation indices.
    """
    if not 0 <= mr <= 1:
        raise ValueError(f"mr must be in [0, 1], got {mr}")
    n = len(total_data)
    if seed is None:
        shuffled_indices = np.random.permutation(n)
    else:
        shuffled_indices = np.random.default_rng(seed).permutation(n)

    train_end = int(0.8*mr*n)
    val_start = int(n*0.8)
    test_start = int(n*0.9)

    train_data = [total_data[i] for i in shuffled_indices[:train_end]]
    val_data = [total_data[i] for i in shuffled_indices[val_start:test_start]]
    test_data = [total_data[i] for i in shuffled_indices[test_start:]]
    return train_data,val_data,test_data

# Split every intersection 80/10/10. The target domain (data1, intersection A)
# keeps only a fraction 0.0001 of its training portion; the sources keep all of it.
train_data1,val_data1,test_data1 = random_split_data(data_location1,0.0001)
train_data2,val_data2,test_data2 = random_split_data(data_location2,1)
train_data3,val_data3,test_data3 = random_split_data(data_location3,1)
train_data4,val_data4,test_data4 = random_split_data(data_location4,1)

print(len(train_data1),len(val_data1),len(test_data1))
print(len(train_data2),len(val_data2),len(test_data2))
print(len(train_data3),len(val_data3),len(test_data3))
print(len(train_data4),len(val_data4),len(test_data4))

import math

# Every source split is up-sampled to the size of the largest domain's split.
size_train = max(len(train_data1),len(train_data2),len(train_data3),len(train_data4))
size_val = max(len(val_data1),len(val_data2),len(val_data3),len(val_data4))
size_test = max(len(test_data1),len(test_data2),len(test_data3),len(test_data4))

# The target domain (data1) is intentionally not imputed.
train_data2,val_data2,test_data2 = series_impute(train_data2,val_data2,test_data2,size_train,size_val,size_test)
train_data3,val_data3,test_data3 = series_impute(train_data3,val_data3,test_data3,size_train,size_val,size_test)
train_data4,val_data4,test_data4 = series_impute(train_data4,val_data4,test_data4,size_train,size_val,size_test)

print(len(train_data1),len(val_data1),len(test_data1))
print(len(train_data2),len(val_data2),len(test_data2))
print(len(train_data3),len(val_data3),len(test_data3))
print(len(train_data4),len(val_data4),len(test_data4))

BATCH_SIZE = 32
TARGET_TEST_BATCH_SIZE = 16  # the target test loader historically used a smaller batch


def _make_loader(dataset, shuffle, batch_size=BATCH_SIZE):
    """Build a DataLoader with the settings shared by every split in this notebook."""
    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, num_workers=0,
                      collate_fn=collate_single_cpu, drop_last=True)


# Source domain s1 = intersection D (data4).
train_s1_Dataset = inDDataset(train_data4)
train_s1_dataloader = _make_loader(train_s1_Dataset, shuffle=True)
val_s1_Dataset = inDDataset(val_data4)
val_s1_dataloader = _make_loader(val_s1_Dataset, shuffle=False)
test_s1_Dataset = inDDataset(test_data4)
test_s1_dataloader = _make_loader(test_s1_Dataset, shuffle=False)

# Source domain s2 = intersection B (data2).
train_s2_Dataset = inDDataset(train_data2)
train_s2_dataloader = _make_loader(train_s2_Dataset, shuffle=True)
val_s2_Dataset = inDDataset(val_data2)
val_s2_dataloader = _make_loader(val_s2_Dataset, shuffle=False)
test_s2_Dataset = inDDataset(test_data2)
test_s2_dataloader = _make_loader(test_s2_Dataset, shuffle=False)

# Source domain s3 = intersection C (data3).
train_s3_Dataset = inDDataset(train_data3)
train_s3_dataloader = _make_loader(train_s3_Dataset, shuffle=True)
val_s3_Dataset = inDDataset(val_data3)
val_s3_dataloader = _make_loader(val_s3_Dataset, shuffle=False)
test_s3_Dataset = inDDataset(test_data3)
test_s3_dataloader = _make_loader(test_s3_Dataset, shuffle=False)

# Target domain = intersection A (data1).
train_t_Dataset = inDDataset(train_data1)
train_t_dataloader = _make_loader(train_t_Dataset, shuffle=True)
val_t_Dataset = inDDataset(val_data1)
val_t_dataloader = _make_loader(val_t_Dataset, shuffle=False)
test_t_Dataset = inDDataset(test_data1)
test_t_dataloader = _make_loader(test_t_Dataset, shuffle=False, batch_size=TARGET_TEST_BATCH_SIZE)

dataloaders = {
    'train_s1': train_s1_dataloader,
    'train_s2': train_s2_dataloader,
    'train_s3': train_s3_dataloader,
    'val_s1': val_s1_dataloader,
    'val_s2': val_s2_dataloader,
    'val_s3': val_s3_dataloader,
    'test_s1': test_s1_dataloader,
    'test_s2': test_s2_dataloader,
    'test_s3': test_s3_dataloader,
    'train_t': train_t_dataloader,
    'val_t': val_t_dataloader,
    'test_t': test_t_dataloader,
}

# '*_t' sizes describe the target domain (data1), matching dataloaders['*_t'].
# '*_s' sizes describe a source domain; all sources share the same imputed sizes.
dataset_sizes = {
    'train_s': len(train_data4),
    'val_s': len(val_data4),
    'test_s': len(test_data4),
    'train_t': len(train_data1),
    'val_t': len(val_data1),
    'test_t': len(test_data1),
}

In [None]:
# Re-split the target domain (intersection A) and rebuild only the target loaders.
# NOTE(review): this draws a fresh random permutation, so the target val/test
# split differs from any split built earlier. It also re-creates `dataloaders`
# and `dataset_sizes` from scratch, dropping any source-domain entries they held.
# Confirm this is intended if both cells are run.
train_data1,val_data1,test_data1 = random_split_data(data_location1,0.0001)

_target_loader_kwargs = dict(num_workers=0, collate_fn=collate_single_cpu, drop_last=True)

train_t_Dataset = inDDataset(train_data1)
val_t_Dataset = inDDataset(val_data1)
test_t_Dataset = inDDataset(test_data1)

train_t_dataloader = DataLoader(train_t_Dataset, shuffle=True, batch_size=32, **_target_loader_kwargs)
val_t_dataloader = DataLoader(val_t_Dataset, shuffle=False, batch_size=32, **_target_loader_kwargs)
test_t_dataloader = DataLoader(test_t_Dataset, shuffle=False, batch_size=16, **_target_loader_kwargs)

dataloaders = {
    'train_t': train_t_dataloader,
    'val_t': val_t_dataloader,
    'test_t': test_t_dataloader,
}

dataset_sizes = {
    'train_t': len(train_data1),
    'val_t': len(val_data1),
    'test_t': len(test_data1),
}
print(dataset_sizes['test_t'])

In [None]:
import copy
import torch
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
from multi_attention_forward import multi_head_attention_forward

class LaneNet(nn.Module):
    """Point-set encoder built from stacked MLP + max-pool layers.

    Each layer applies an ``MLP`` to every point (last dim = channels). It then
    max-pools over the point dimension (dim -2) and concatenates the pooled
    vector back onto every point, which doubles the channel width. The output
    is a final max-pool over dim -2.

    Input: ``(..., num_points, in_channels)``; any number of leading dims.
    Output: ``(..., 2 * hidden_unit)`` when ``num_subgraph_layers >= 1``.
    """

    def __init__(self, in_channels, hidden_unit, num_subgraph_layers):
        super(LaneNet, self).__init__()
        self.num_subgraph_layers = num_subgraph_layers
        self.layer_seq = nn.Sequential()
        for i in range(num_subgraph_layers):
            self.layer_seq.add_module(
                f'lmlp_{i}', MLP(in_channels, hidden_unit))
            # Each layer outputs [per-point features, pooled features].
            in_channels = hidden_unit*2

    def forward(self, lane):
        x = lane
        for layer in self.layer_seq:
            if isinstance(layer, MLP):
                x = layer(x)
                pooled = torch.max(x, -2)[0]
                # Broadcast the pooled vector back to every point; works for any rank.
                x = torch.cat([x, pooled.unsqueeze(-2).expand_as(x)], dim=-1)
        return torch.max(x, -2)[0]

class MLP(nn.Module):
    """Linear -> LayerNorm -> ReLU on the last dimension. ``verbose`` is unused."""

    def __init__(self, in_channels, hidden_unit, verbose=False):
        super(MLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_unit),
            nn.LayerNorm(hidden_unit),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.mlp(x)
        return x

def preprocess_lane(lane_subgraph,lane,B, lane_valid_len, max_lane_num):
    """Encode lane polylines and build a validity mask over lane slots.

    Consecutive points ``lane[:, :, t, :2]`` and ``lane[:, :, t+1, :2]`` are
    joined into 4-dim segment vectors, moved to the global ``device``, and passed
    through ``lane_subgraph``. The mask has shape ``(B, 1, max_lane_num)``. It is
    1.0 on the first ``lane_valid_len[i]`` slots of row ``i`` and 0.0 elsewhere.

    Returns:
        ``(lane_feature, lane_mask)``.
    """
    segment_starts = lane[:, :, :-1, :2]
    segment_ends = lane[:, :, 1:, :2]
    lane_v = torch.cat([segment_starts, segment_ends], dim=-1).to(device)

    lane_mask = torch.zeros((B, 1, int(max_lane_num)), device=device)
    for row, n_valid in enumerate(lane_valid_len):
        lane_mask[row, 0, :n_valid] = 1

    lane_feature = lane_subgraph(lane_v)
    return lane_feature, lane_mask

def preprocess_traj(traj, B, traj_valid_len, max_agent_num):
    """Return a ``(B, 1, max_agent_num)`` float mask of valid agent slots on ``device``.

    Row ``i`` is 1.0 on its first ``traj_valid_len[i]`` slots and 0.0 elsewhere.
    ``traj`` is not used.
    """
    social_mask = torch.zeros((B, 1, int(max_agent_num)), device=device)
    for batch_idx in range(B):
        social_mask[batch_idx, 0, :traj_valid_len[batch_idx]] = 1
    return social_mask


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise RuntimeError("activation should be relu/gelu, not %s." % activation)

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def subsequent_mask(size):
    """Boolean causal mask of shape ``(1, size, size)``: True where column <= row."""
    allowed = np.tril(np.ones((1, size, size)))
    return torch.from_numpy(allowed) == 1

def _generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1).to(device)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

class MultiheadAttention(nn.Module):
    """Multi-head attention whose forward delegates to ``multi_head_attention_forward``.

    ``multi_head_attention_forward`` is imported from the project module
    ``multi_attention_forward``. When ``kdim`` and ``vdim`` both equal
    ``embed_dim``, a single packed ``in_proj_weight`` of shape
    ``(3 * embed_dim, embed_dim)`` is used. Otherwise separate
    ``q_proj_weight`` / ``k_proj_weight`` / ``v_proj_weight`` are created.

    NOTE(review): parameters are registered and initialised in a fixed order.
    Reordering the statements changes RNG consumption and the order of
    ``parameters()`` (and therefore optimizer state).
    """
    
    __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None,
                 vdim=None):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # True when q, k and v share embed_dim, so one packed projection can be used.
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        # Stored as a float probability and passed to the functional forward.
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            # torch.Tensor(...) is uninitialised memory; _reset_parameters overwrites it.
            self.q_proj_weight = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
            self.k_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.kdim))
            self.v_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.vdim))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform input projections, zero biases, Xavier-normal bias_k/bias_v.

        out_proj.weight is not re-initialised here; it keeps nn.Linear's default init.
        """
        if self._qkv_same_embed_dim:
            nn.init.xavier_uniform_(self.in_proj_weight)
        else:
            nn.init.xavier_uniform_(self.q_proj_weight)
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            # in_proj_bias and out_proj.bias both exist exactly when bias=True.
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Presumably for unpickling older instances that lack `_qkv_same_embed_dim`;
        # such instances are treated as using the packed projection.
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        """Run attention via ``multi_head_attention_forward``.

        The separate q/k/v weights are passed only when they exist. Callers in
        this file unpack the result as ``(output, attention_weights)``. Tensor
        layouts are whatever ``multi_head_attention_forward`` expects —
        presumably (seq, batch, embed) as in PyTorch; TODO confirm.
        """
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)


class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer that also returns its attention weights.

    ``src = norm1(src + dropout1(self_attn(src)))`` followed by
    ``src = norm2(src + dropout2(FFN(src)))``, where
    ``FFN = linear2(dropout(activation(linear1(.))))``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        # Submodules are created in a fixed order so parameter init consumes the RNG identically.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def _feed_forward(self, x):
        # Fall back to ReLU for instances that lack `activation` (e.g. older pickles).
        act = self.activation if hasattr(self, "activation") else F.relu
        return self.linear2(self.dropout(act(self.linear1(x))))

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Return ``(output, attention_weights)``."""
        attended, attn = self.self_attn(src, src, src, attn_mask=src_mask,
                                        key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attended))
        src = self.norm2(src + self.dropout2(self._feed_forward(src)))
        return src, attn


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers with an optional final norm.

    ``forward`` returns ``(output, atts)``, where ``atts`` lists each layer's
    attention weights in order.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        output = src
        atts = []
        for layer in (self.layers[idx] for idx in range(self.num_layers)):
            output, attn = layer(output, src_mask=mask,
                                 src_key_padding_mask=src_key_padding_mask)
            atts.append(attn)
        if self.norm:
            output = self.norm(output)
        return output, atts

class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer that also returns both attention maps.

    The layer runs three residual steps, each followed by its own LayerNorm:
    target self-attention, then cross-attention to ``src``, then a feed-forward
    block. ``forward`` returns ``(target, attn_tgt, attn_src)``.

    NOTE(review): the target self-attention uses ``src_key_padding_mask`` as its
    key padding mask. That is only meaningful if target and source share length
    and padding — confirm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        # Submodules are created in a fixed order so parameter init consumes the RNG identically.
        self.tgt_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.src_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model), nn.LayerNorm(d_model),
                                              nn.LayerNorm(d_model))
        self.dropout1, self.dropout2, self.dropout3 = (nn.Dropout(dropout), nn.Dropout(dropout),
                                                       nn.Dropout(dropout))
        self.activation = _get_activation_fn(activation)

    def _feed_forward(self, x):
        # Fall back to ReLU for instances that lack `activation` (e.g. older pickles).
        act = self.activation if hasattr(self, "activation") else F.relu
        return self.linear2(self.dropout(act(self.linear1(x))))

    def forward(self, src, target, src_mask=None, target_mask=None, src_key_padding_mask=None):
        attended, attn_tgt = self.tgt_attn(target, target, target, attn_mask=target_mask,
                                           key_padding_mask=src_key_padding_mask)
        target = self.norm1(target + self.dropout1(attended))

        attended, attn_src = self.src_attn(target, src, src, attn_mask=src_mask,
                                           key_padding_mask=src_key_padding_mask)
        target = self.norm2(target + self.dropout2(attended))

        target = self.norm3(target + self.dropout3(self._feed_forward(target)))
        return target, attn_tgt, attn_src
    
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` deep-copied decoder layers with an optional final norm.

    ``forward`` returns ``(target, atts_tgt, atts_src)``, holding each layer's
    self- and cross-attention weights in order.
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, target, src_mask=None, target_mask=None, src_key_padding_mask=None):
        atts_tgt, atts_src = [], []
        for layer in (self.layers[idx] for idx in range(self.num_layers)):
            target, attn_tgt, attn_src = layer(src, target, src_mask=src_mask, target_mask=target_mask,
                                               src_key_padding_mask=src_key_padding_mask)
            atts_tgt.append(attn_tgt)
            atts_src.append(attn_src)
        if self.norm:
            target = self.norm(target)
        return target, atts_tgt, atts_src

import os
import sys
import random
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import argparse

import time
import math
import scipy.io as scp
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from scipy.ndimage.interpolation import rotate

import io

from multiprocessing import Pool
import torch.multiprocessing

def construct_target(phy_tgt, traj, num_queries):
    """Fill ``phy_tgt[:, k, 0]`` with a constant-velocity extrapolation for every agent ``k``.

    For each agent slot ``k < phy_tgt.shape[1]``, the last two observed
    positions ``traj[:, k, -2:, :2]`` give a per-step displacement. The future is
    ``last_position + step * displacement`` for ``step = 1 .. H``, with
    ``H = phy_tgt.shape[3]``. Previously H was hard-coded to 12, which the
    assignment forced ``phy_tgt.shape[3]`` to equal anyway.

    NOTE(review): the x displacement is zeroed (``da_x - da_x``), so x stays at
    its last observed value and only y is extrapolated. This is presumably an
    agent-centric frame with heading along +y — confirm.

    Args:
        phy_tgt: tensor written in place after being moved to the global
            ``device``. Assumed shape ``(B, K, M, H, 2)`` — TODO confirm.
        traj: indexed as ``traj[:, k, t, :2]``; needs at least K agents and 2 steps.
        num_queries: unused; kept for interface compatibility.

    Returns:
        ``phy_tgt`` on ``device`` with slot 0 of every agent filled.
    """
    phy_tgt = phy_tgt.to(device)
    num_agents = phy_tgt.shape[1]
    horizon = phy_tgt.shape[3]

    last_two = traj[:, :num_agents, -2:, :2]
    if not last_two.is_floating_point():
        last_two = last_two.float()
    last_pos = last_two[:, :, 1, :]
    displacement = last_two[:, :, 1, :] - last_two[:, :, 0, :]
    disp_x = displacement[..., 0] - displacement[..., 0]
    displacement = torch.stack([disp_x, displacement[..., 1]], dim=-1)

    # Vectorised over batch, agents and time: no per-sample .item() device syncs.
    steps = torch.arange(1, horizon + 1, dtype=last_pos.dtype, device=last_pos.device)
    future = last_pos.unsqueeze(2) + displacement.unsqueeze(2) * steps.view(1, 1, horizon, 1)
    phy_tgt[:, :, 0, :, :] = future.to(phy_tgt.device)
    return phy_tgt


In [None]:
#global interactor: refer to Z. Zhou, L. Ye, J. Wang, K. Wu, and K. Lu. "HiVT: Hierarchical vector Transformer for multi-agent motion prediction," In Proc. IEEE Conf. Comput. Vis. Pattern Recognit. (CVPR), New Orleans, Louisiana, USA, pp. 8823-8833, Jun. 2022

from typing import List, Optional

import torch
import torch.nn as nn

def init_weights(m: nn.Module) -> None:
    """Initialise a single module in place; intended for ``model.apply(init_weights)``.

    - Linear: Xavier-uniform weight, zero bias.
    - Conv1d/2d/3d: uniform in ``+-sqrt(6 / (fan_in + fan_out))``, with fans
      from channels/groups only (kernel size is not included); zero bias.
    - Embedding: ``N(0, 0.02)``.
    - BatchNorm/LayerNorm: weight 1, bias 0. Layers built without affine
      parameters (weight/bias None) are skipped instead of crashing.
    - nn.MultiheadAttention: uniform or Xavier input projections, Xavier
      ``out_proj``, zero biases, ``N(0, 0.02)`` for ``bias_k``/``bias_v``.
    - LSTM/GRU: per-gate Xavier input weights and orthogonal hidden weights;
      biases are zero. For LSTM, the second ``bias_hh`` chunk (the forget gate
      in PyTorch's i, f, g, o layout) is set to 1.

    Other module types are left untouched. This includes attention classes that
    subclass ``nn.Module`` directly rather than ``nn.MultiheadAttention``.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        fan_in = m.in_channels / m.groups
        fan_out = m.out_channels / m.groups
        bound = (6.0 / (fan_in + fan_out)) ** 0.5
        nn.init.uniform_(m.weight, -bound, bound)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
        # affine=False / elementwise_affine=False layers have no weight or bias.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.MultiheadAttention):
        if m.in_proj_weight is not None:
            fan_in = m.embed_dim
            fan_out = m.embed_dim
            bound = (6.0 / (fan_in + fan_out)) ** 0.5
            nn.init.uniform_(m.in_proj_weight, -bound, bound)
        else:
            nn.init.xavier_uniform_(m.q_proj_weight)
            nn.init.xavier_uniform_(m.k_proj_weight)
            nn.init.xavier_uniform_(m.v_proj_weight)
        if m.in_proj_bias is not None:
            nn.init.zeros_(m.in_proj_bias)
        nn.init.xavier_uniform_(m.out_proj.weight)
        if m.out_proj.bias is not None:
            nn.init.zeros_(m.out_proj.bias)
        if m.bias_k is not None:
            nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
        if m.bias_v is not None:
            nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
    elif isinstance(m, nn.LSTM):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                for ih in param.chunk(4, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                for hh in param.chunk(4, 0):
                    nn.init.orthogonal_(hh)
            elif 'weight_hr' in name:
                nn.init.xavier_uniform_(param)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)
                nn.init.ones_(param.chunk(4, 0)[1])
    elif isinstance(m, nn.GRU):
        for name, param in m.named_parameters():
            if 'weight_ih' in name:
                for ih in param.chunk(3, 0):
                    nn.init.xavier_uniform_(ih)
            elif 'weight_hh' in name:
                for hh in param.chunk(3, 0):
                    nn.init.orthogonal_(hh)
            elif 'bias_ih' in name:
                nn.init.zeros_(param)
            elif 'bias_hh' in name:
                nn.init.zeros_(param)

class MultipleInputEmbedding(nn.Module):
    """Embed several continuous inputs with per-input MLPs, sum them, then refine.

    One ``Linear -> LayerNorm -> ReLU -> Linear`` branch is built per entry of
    ``in_channels``. The branch outputs are summed, together with the optional
    already-embedded ``categorical_inputs``, and then passed through
    ``LayerNorm -> ReLU -> Linear -> LayerNorm``.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channel: int) -> None:
        super(MultipleInputEmbedding, self).__init__()
        self.module_list = nn.ModuleList(
            [nn.Sequential(nn.Linear(in_channel, out_channel),
                           nn.LayerNorm(out_channel),
                           nn.ReLU(inplace=True),
                           nn.Linear(out_channel, out_channel))
             for in_channel in in_channels])
        self.aggr_embed = nn.Sequential(
            nn.LayerNorm(out_channel),
            nn.ReLU(inplace=True),
            nn.Linear(out_channel, out_channel),
            nn.LayerNorm(out_channel))
        self.apply(init_weights)

    def forward(self,
                continuous_inputs: List[torch.Tensor],
                categorical_inputs: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Return the aggregated embedding; the caller's input lists are not modified."""
        # Work on a copy so the caller's list is not overwritten with embeddings.
        embedded = list(continuous_inputs)
        for i, branch in enumerate(self.module_list):
            embedded[i] = branch(embedded[i])
        output = torch.stack(embedded).sum(dim=0)
        if categorical_inputs is not None:
            output += torch.stack(categorical_inputs).sum(dim=0)
        return self.aggr_embed(output)

import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj
from torch_geometric.typing import OptTensor
from torch_geometric.typing import Size
from torch_geometric.utils import softmax
from torch_geometric.utils import subgraph

from itertools import permutations

class GlobalInteractorLayer(MessagePassing):
    """Pre-norm graph-attention block over agent nodes with edge features.

    ``forward`` computes ``x + MHA(norm1(x))`` and then ``x + FFN(norm2(x))``.
    Messages are aggregated by summation (``aggr='add'``) along node dim 0.

    NOTE(review): PyG's ``MessagePassing`` fills the ``message``/``update``
    arguments by name (``x_i``/``x_j`` are gathered from ``x``; ``index``,
    ``ptr`` and ``size_i`` are supplied by ``propagate``). Those parameter names
    must not be renamed.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 **kwargs) -> None:
        super(GlobalInteractorLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Query from the receiving node; keys/values from the sending node plus the edge.
        self.lin_q_node = nn.Linear(embed_dim, embed_dim)
        self.lin_k_node = nn.Linear(embed_dim, embed_dim)
        self.lin_k_edge = nn.Linear(embed_dim, embed_dim)
        self.lin_v_node = nn.Linear(embed_dim, embed_dim)
        self.lin_v_edge = nn.Linear(embed_dim, embed_dim)
        # Used by the gated update that mixes aggregated messages with the node's own feature.
        self.lin_self = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.lin_ih = nn.Linear(embed_dim, embed_dim)
        self.lin_hh = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout))

    def forward(self,
                x: torch.Tensor,
                edge_index: Adj,
                edge_attr: torch.Tensor,
                size: Size = None) -> torch.Tensor:
        """Apply attention then feed-forward, each as a pre-norm residual update of ``x``.

        ``edge_attr`` must have ``embed_dim`` features per edge, because it is
        projected by ``embed_dim -> embed_dim`` linears.
        """
        x = x + self._mha_block(self.norm1(x), edge_index, edge_attr, size)
        x = x + self._ff_block(self.norm2(x))
        return x

    def message(self,
                x_i: torch.Tensor,
                x_j: torch.Tensor,
                edge_attr: torch.Tensor,
                index: torch.Tensor,
                ptr: OptTensor,
                size_i: Optional[int]) -> torch.Tensor:
        """Per-edge, per-head attention message.

        Scores are ``q(x_i) . (k(x_j) + k(edge)) / sqrt(head_dim)``. They are
        normalised with PyG ``softmax`` over edges sharing the same ``index``
        (the receiving node), then used to weight ``v(x_j) + v(edge)``.
        Returns shape ``(num_edges, num_heads, head_dim)``.
        """
        query = self.lin_q_node(x_i).view(-1, self.num_heads, self.embed_dim // self.num_heads)
        key_node = self.lin_k_node(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
        key_edge = self.lin_k_edge(edge_attr).view(-1, self.num_heads, self.embed_dim // self.num_heads)
        value_node = self.lin_v_node(x_j).view(-1, self.num_heads, self.embed_dim // self.num_heads)
        value_edge = self.lin_v_edge(edge_attr).view(-1, self.num_heads, self.embed_dim // self.num_heads)
        scale = (self.embed_dim // self.num_heads) ** 0.5
        alpha = (query * (key_node + key_edge)).sum(dim=-1) / scale
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = self.attn_drop(alpha)
        return (value_node + value_edge) * alpha.unsqueeze(-1)

    def update(self,
               inputs: torch.Tensor,
               x: torch.Tensor) -> torch.Tensor:
        """Gated combination of the aggregated messages and the node's own feature.

        ``gate = sigmoid(lin_ih(inputs) + lin_hh(x))`` and the result is
        ``inputs + gate * (lin_self(x) - inputs)``, an interpolation between the
        two controlled by the gate.
        """
        # Merge the per-head message channels back into embed_dim.
        inputs = inputs.view(-1, self.embed_dim)
        gate = torch.sigmoid(self.lin_ih(inputs) + self.lin_hh(x))
        return inputs + gate * (self.lin_self(x) - inputs)

    def _mha_block(self,
                   x: torch.Tensor,
                   edge_index: Adj,
                   edge_attr: torch.Tensor,
                   size: Size) -> torch.Tensor:
        """Message passing followed by an output projection and dropout."""
        x = self.out_proj(self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr, size=size))
        return self.proj_drop(x)

    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        """Position-wise feed-forward network (embed_dim -> 4*embed_dim -> embed_dim)."""
        return self.mlp(x)

class GlobalInteractor(nn.Module):
    """Fully connected agent-to-agent attention, run independently for each scene.

    For every scene the valid agents form a complete directed graph. Edge features are
    the relative position and the cos/sin of the relative heading, embedded by
    ``rel_embed``. A stack of ``GlobalInteractorLayer`` is applied, followed by
    LayerNorm and a projection to ``num_modes`` copies of the embedding.
    """

    def __init__(self,
                 historical_steps: int,
                 embed_dim: int,
                 edge_dim: int,
                 num_modes: int = 1,
                 num_heads: int = 8,
                 num_layers: int = 3,
                 dropout: float = 0.1) -> None:
        super(GlobalInteractor, self).__init__()
        self.historical_steps = historical_steps
        self.embed_dim = embed_dim
        self.num_modes = num_modes

        self.rel_embed = MultipleInputEmbedding(in_channels=[edge_dim, edge_dim], out_channel=embed_dim)
        self.global_interactor_layers = nn.ModuleList(
            [GlobalInteractorLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
             for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)
        self.multihead_proj = nn.Linear(embed_dim, num_modes * embed_dim)
        self.apply(init_weights)

    def forward(self,data,social_inp) -> torch.Tensor:
        """Run the interaction graph scene by scene.

        Args:
            data: batch dict. Reads ``data['VALID_LEN'][m, 0]`` (used as the number of
                graph nodes in scene ``m``), ``data['NORM_CENTER'][m]`` (per-node positions
                for relative offsets) and ``data['HEADINGS'][m]`` (per-node headings).
            social_inp: 3-D tensor indexed as ``social_inp[m, node, feature]``; assumed
                to be (B, N_max, embed_dim). TODO confirm.

        Returns:
            Per-scene outputs concatenated along dim 0. Each multi-agent scene
            contributes ``(num_modes, N_max, embed_dim)``. Scenes with at most one node
            contribute their input slice unchanged, as ``(1, N_max, D)``.
        """
        dev = social_inp.device
        per_scene = []  # collected and concatenated once (avoids quadratic torch.cat growth)
        for m in range(social_inp.shape[0]):
            num_nodes = int(data['VALID_LEN'][m, 0].item())
            if num_nodes <= 1:
                # No edges to attend over (0 nodes would otherwise crash on an empty edge_index).
                # NOTE(review): this path also skips self.norm / self.multihead_proj, so these
                # scenes keep the un-projected input features; kept for behavioral compatibility.
                per_scene.append(social_inp[[m], :, :])
                continue

            # Complete directed graph over the valid nodes: row 0 = source, row 1 = target.
            edge_index = torch.tensor(list(permutations(range(num_nodes), 2)),
                                      dtype=torch.long, device=dev).t().contiguous()
            norm_center = data['NORM_CENTER'][m, :]
            heading = data['HEADINGS'][m]

            rel_pos = norm_center[edge_index[0]] - norm_center[edge_index[1]]
            rel_theta = heading[edge_index[0]] - heading[edge_index[1]]
            rel_embed = self.rel_embed(
                [rel_pos, torch.cat((torch.cos(rel_theta), torch.sin(rel_theta)), dim=-1)])

            # All N_max rows are passed in; only the first num_nodes appear in edge_index.
            x = social_inp[m, :, :]
            for layer in self.global_interactor_layers:
                x = layer(x, edge_index, rel_embed)
            x = self.norm(x)
            x = self.multihead_proj(x).view(-1, self.num_modes, self.embed_dim)
            per_scene.append(x.transpose(0, 1))

        if not per_scene:
            return social_inp.new_zeros((0, social_inp.shape[1], social_inp.shape[2]))
        return torch.cat(per_scene, dim=0)


In [None]:
class MultiPredictionHeader(nn.Module):
    """Class-specific trajectory regression head fused with a physics-based prior.

    Each agent's feature vector is routed to a vehicle (class 1), bicycle (class 2) or
    pedestrian (class 3) MLP that regresses ``out_size // 2`` per-step (x, y)
    displacements. The displacements are fused with a physics target from
    ``construct_target``, in two ways: cumulatively by ``fusion3`` and per-step by
    ``fusion4``.

    Args:
        d_model: input feature width.
        out_size: flattened output length (future steps * 2).
        dropout, dis_h_dim, cls_h_dim: accepted for interface compatibility; unused.
        reg_h_dim: hidden width of the regression MLPs.
    """

    def __init__(self, d_model, out_size, dropout, reg_h_dim=128, dis_h_dim=128, cls_h_dim=128):
        super(MultiPredictionHeader, self).__init__()
        self.out_size = out_size
        # Built in the original order (veh, bic, ped, fusion3, fusion4) so parameter
        # initialisation draws from the RNG in the same sequence.
        self.reg_mlp_veh = self._make_reg_mlp(d_model, reg_h_dim, out_size)
        self.reg_mlp_bic = self._make_reg_mlp(d_model, reg_h_dim, out_size)
        self.reg_mlp_ped = self._make_reg_mlp(d_model, reg_h_dim, out_size)
        self.fusion3 = nn.Linear(4, 2, bias=True)
        self.fusion4 = nn.Linear(4, 2, bias=True)
        self.num_modes = 1

    @staticmethod
    def _make_reg_mlp(d_model, reg_h_dim, out_size):
        """Build one class-specific regression MLP (layout fixed for checkpoint compatibility)."""
        return nn.Sequential(
            nn.Linear(d_model, reg_h_dim * 2, bias=True),
            nn.LayerNorm(reg_h_dim * 2),
            nn.ReLU(),
            nn.Linear(reg_h_dim * 2, reg_h_dim, bias=True),
            # NOTE(review): no activation between this Linear and the previous one, so the
            # pair collapses to a single affine map; kept so state_dict keys stay stable.
            nn.Linear(reg_h_dim, out_size, bias=True))

    def forward(self, feature_out,data,traj):
        """Predict fused trajectories.

        Args:
            feature_out: (B, N, d_model) agent features.
            data: batch dict; ``data['CLASS_LIST']`` has shape (B, N) and selects the head
                for each agent. Agents with any other class id keep zero regression output.
            traj: history tensor forwarded to ``construct_target``.

        Returns:
            (final_cum, final_ori), each of shape (B, N, out_size // 2, 2). ``final_cum``
            fuses cumulative positions; ``final_ori`` fuses per-step displacements.
        """
        num_steps = self.out_size // 2
        class_list = data['CLASS_LIST']

        pred = feature_out.new_zeros((*feature_out.shape[:2], self.out_size))
        for class_id, head in ((1, self.reg_mlp_veh), (2, self.reg_mlp_bic), (3, self.reg_mlp_ped)):
            mask = class_list == class_id
            pred[mask] = head(feature_out[mask])

        ori_pred = pred.view(*pred.shape[:-1], -1, 2)   # per-step displacements
        pred = ori_pred.cumsum(dim=-2)                  # cumulative positions

        # Physics prior over the same horizon as the regression output.
        fusion_phy = feature_out.new_zeros((pred.shape[0], pred.shape[1], self.num_modes, num_steps, 2))
        fusion_phy = construct_target(fusion_phy, traj, self.num_modes)

        hist_outputs = fusion_phy.squeeze(2)
        hist = hist_outputs.cumsum(dim=-2)

        final_cum = self.fusion3(torch.cat([pred, hist], -1))
        final_ori = self.fusion4(torch.cat([ori_pred, hist_outputs], -1))

        return final_cum, final_ori
    

In [None]:
# Transformer_utils refer to Y. Liu, J. Zhang, L. Fang, Q. Jiang, and B. Zhou, "Multimodal motion prediction with stacked transformers," In Proc. IEEE Conf. Comput. Vis. Pattern Recognit. (CVPR), virtually, pp. 7577-7586, Jun. 2021

from Transformer_utils import (Decoder, DecoderLayer, Encoder, EncoderDecoder,
                                 EncoderLayer, LinearEmbedding, MultiHeadAttention,
                                 PointerwiseFeedforward, PositionalEncoding,
                                 SublayerConnection)

class MA_STTN_MAP(nn.Module):
    """Multi-agent spatio-temporal transformer feature extractor with map attention.

    Encodes agent histories (``hist_tf``, decoded against an embedded physics target
    from ``construct_target``), attends to lane features (``lane_enc``/``lane_dec``) and
    mixes agents with a ``GlobalInteractor``. ``forward`` returns the concatenation of
    social, lane and physics features (3 * d_model wide when ``num_queries == 1``).

    The module also owns the learnable uncertainty weights ``w1``..``w3`` and the fixed
    mixing coefficient ``theta`` used by the ``cal_*_loss_*`` helpers.

    Args:
        hist_inp_size: per-step input width of ``traj``.
        num_queries: number of learned queries / modes.
        dec_inp_size: accepted for compatibility; unused.
        dec_out_size: flattened future length (steps * 2); sets the physics-target horizon.
        N, N_social: depth of the history and social encoder/decoder stacks.
        d_model, d_ff, h, dropout: transformer width, feed-forward width, heads, dropout.
        pos_dim, dist_dim: widths of the position and lane-distance embeddings.
        num_lane_layers: lane encoder/decoder depth; defaults to the global ``N_lane``.
        lane_input_size: lane feature width; defaults to the global ``lane_inp_size``.
    """

    def __init__(self, hist_inp_size, num_queries, dec_inp_size, dec_out_size, N, N_social,
                d_model, d_ff, pos_dim, dist_dim, h, dropout, num_lane_layers=None, lane_input_size=None):
        super(MA_STTN_MAP, self).__init__()
        # Backward compatibility: these values used to be read silently from notebook globals.
        if num_lane_layers is None:
            num_lane_layers = N_lane
        if lane_input_size is None:
            lane_input_size = lane_inp_size
        self.num_queries = num_queries
        # Physics-target horizon; must match the prediction horizon (dec_out_size = steps * 2).
        self.num_future_steps = dec_out_size // 2
        c = copy.deepcopy
        dropout_atten = dropout
        attn = MultiHeadAttention(h, d_model, dropout=dropout_atten)
        ff = PointerwiseFeedforward(d_model, d_ff, dropout)
        position = PositionalEncoding(d_model, dropout)

        self.hist_tf = EncoderDecoder(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
            Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
            nn.Sequential(LinearEmbedding(hist_inp_size, d_model), c(position)))
        self.lane_enc = Encoder(EncoderLayer(
            d_model, c(attn), c(ff), dropout), num_lane_layers)
        self.lane_dec = Decoder(DecoderLayer(
            d_model, c(attn), c(attn), c(ff), dropout), num_lane_layers)
        self.lane_emb = LinearEmbedding(lane_input_size, d_model)
        self.phy_emb = nn.Sequential(
            nn.Flatten(start_dim=-2,end_dim=-1),
            nn.Linear(self.num_future_steps*2, d_model, bias=True),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=True))
        self.pos_emb = nn.Sequential(
            nn.Linear(2, pos_dim, bias=True),
            nn.LayerNorm(pos_dim),
            nn.ReLU(),
            nn.Linear(pos_dim, pos_dim, bias=True))
        self.dist_emb = nn.Sequential(
            nn.Linear(num_queries*d_model, dist_dim, bias=True),
            nn.LayerNorm(dist_dim),
            nn.ReLU(),
            nn.Linear(dist_dim, dist_dim, bias=True))

        self.fusion1 = nn.Sequential(
            nn.Linear(d_model+pos_dim, d_model, bias=True),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=True))
        self.fusion2 = nn.Sequential(
            nn.Linear(dist_dim+pos_dim, d_model, bias=True),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=True))
        self.fusion3 = nn.Linear(4, 2, bias=True)
        self.fusion4 = nn.Linear(4, 2, bias=True)
        # NOTE(review): social_enc / social_dec are constructed (and saved in checkpoints)
        # but are not used by forward.
        self.social_enc = Encoder(EncoderLayer(
            d_model, c(attn), c(ff), dropout), N_social)
        self.social_dec = Decoder(DecoderLayer(
            d_model, c(attn), c(attn), c(ff), dropout), N_social)

        self.historical_steps = 8
        self.edge_dim = 2
        self.embed_dim = d_model
        self.num_modes = num_queries
        self.num_heads = 8
        self.num_global_layers = 3
        # Device placement is left to the caller's `.to(...)` on the whole model so that
        # all submodules always live on the same device.
        self.global_interactor = GlobalInteractor(historical_steps=self.historical_steps,
                                    embed_dim=self.embed_dim,
                                    edge_dim=self.edge_dim,
                                    num_modes=self.num_modes,
                                    num_heads=self.num_heads,
                                    num_layers=self.num_global_layers,
                                    dropout=dropout)

        # Learnable per-source uncertainty weights and a fixed trajectory/domain mix.
        self.w1 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.w2 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.w3 = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.theta = nn.Parameter(torch.FloatTensor(1), requires_grad=False)

        self.w1.data.fill_(3**0.5)
        self.w2.data.fill_(3**0.5)
        self.w3.data.fill_(3**0.5)
        self.theta.data.fill_(0.9)

        for name, param in self.named_parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

        # Created after the xavier loop so it keeps its own orthogonal initialisation.
        self.query_embed = nn.Embedding(self.num_queries, d_model)
        self.query_embed.weight.requires_grad = True
        nn.init.orthogonal_(self.query_embed.weight)

    def forward(self, traj, pos, max_agent_num, social_mask,lane_vec, lane_mask,data):
        """Compute per-agent features.

        Args:
            traj: agent histories; ``traj.shape[:2]`` is used as (B, N).
            pos: per-agent 2-D positions, embedded by ``pos_emb``.
            max_agent_num: number of agent slots that lane memory is repeated over.
            social_mask: accepted for interface compatibility; not used here.
            lane_vec, lane_mask: lane features and mask for ``lane_enc``/``lane_dec``.
            data: batch dict forwarded to ``GlobalInteractor``.

        Returns:
            Concatenation of social, lane and physics features along the last dim.
        """
        self.query_batches = self.query_embed.weight.view(
            1, 1, *self.query_embed.weight.shape).repeat(*traj.shape[:2], 1, 1)
        # Physics target: zero buffer (B, N, num_queries, steps, 2) handed to construct_target.
        phy_tgt = torch.zeros([*self.query_batches.shape[:3], self.num_future_steps, 2], device=traj.device)
        phy_tgt = construct_target(phy_tgt, traj, self.num_queries)
        phy_tgt = self.phy_emb(phy_tgt)

        # Historical information
        hist_out = self.hist_tf(traj, phy_tgt, None, None, self.query_batches)
        pos = self.pos_emb(pos)
        hist_out = torch.cat([pos.unsqueeze(dim=2).repeat(
                    1, 1, self.num_queries, 1), hist_out], dim=-1)
        hist_out = self.fusion1(hist_out)

        # Lane encoder: memory is shared by every agent slot of a scene.
        social_num = max_agent_num
        lane_mem = self.lane_enc(self.lane_emb(lane_vec), lane_mask)
        lane_mem = lane_mem.unsqueeze(1).repeat(1, social_num, 1, 1)
        lane_mask = lane_mask.unsqueeze(1).repeat(1, social_num, 1, 1)

        # Lane decoder
        lane_out = self.lane_dec(hist_out, lane_mem, lane_mask, None)

        # Fuse position information
        dist = lane_out.view(*traj.shape[0:2], -1)
        dist = self.dist_emb(dist)

        # Global interactor
        social_inp = self.fusion2(torch.cat([pos, dist], -1))
        social_out = self.global_interactor(data,social_inp)

        feature_out = torch.cat([social_out,lane_out.squeeze(2),phy_tgt.squeeze(2)], -1)

        return feature_out

    def _uncertainty_weighted(self, loss_s1, loss_s2, loss_s3):
        """Weight three per-source losses by 1/w_k^2 and add the log(1 + w_k) regulariser."""
        weighted = (1/(self.w1)**2)*loss_s1+(1/(self.w2)**2)*loss_s2+(1/(self.w3)**2)*loss_s3
        return weighted + torch.log1p(self.w1) + torch.log1p(self.w2) + torch.log1p(self.w3)

    def cal_traj_loss_dann(self,loss_traj_s1,loss_traj_s2,loss_traj_s3):
        """Uncertainty-weighted sum of the three source trajectory losses."""
        return self._uncertainty_weighted(loss_traj_s1, loss_traj_s2, loss_traj_s3)

    def cal_total_loss_dann(self,loss_traj_s1,loss_traj_s2,loss_traj_s3,
                          loss_domain_s1,loss_domain_s2,loss_domain_s3):
        """Return (theta * traj + (1 - theta) * domain, traj, domain), both uncertainty-weighted."""
        traj_loss = self._uncertainty_weighted(loss_traj_s1, loss_traj_s2, loss_traj_s3)
        domain_loss = self._uncertainty_weighted(loss_domain_s1, loss_domain_s2, loss_domain_s3)
        total_loss = self.theta*traj_loss + (1.0-self.theta)*domain_loss
        return total_loss,traj_loss,domain_loss

    def cal_total_loss_target(self,loss_traj_t,loss_domain_t):
        """Return theta * loss_traj_t + (1 - theta) * loss_domain_t."""
        total_loss = self.theta*loss_traj_t + (1.0-self.theta)*loss_domain_t
        return total_loss

In [None]:
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
from torch.nn import Parameter

class GradReverse(torch.autograd.Function):
    """Gradient reversal: identity in the forward pass, gradient multiplied by
    ``-constant`` in the backward pass (the usual trick for adversarial domain adaptation).

    New-style ``autograd.Function`` requires ``forward``/``backward`` to be static methods.
    """

    @staticmethod
    def forward(ctx, x, constant):
        ctx.constant = constant
        # view_as yields a new tensor object so autograd records this Function's node.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.neg() * ctx.constant
        # ``constant`` is not differentiable.
        return grad_output, None

    @staticmethod
    def grad_reverse(x, constant):
        """Convenience wrapper around ``GradReverse.apply(x, constant)``."""
        return GradReverse.apply(x, constant)

class Domain_D(nn.Module):
    """Domain discriminator: an MLP mapping (..., 3 * d_model) features to one logit."""

    def __init__(self,d_model):
        super(Domain_D, self).__init__()
        in_features = d_model * 3
        # Layer order (and hence state_dict keys classify.0 / .2 / .4) is significant.
        layers = [
            nn.Linear(in_features, 512, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1, bias=True),
        ]
        self.classify = nn.Sequential(*layers)

    def forward(self,inputs):
        """Return the raw (unnormalised) discriminator score for ``inputs``."""
        return self.classify(inputs)
    
class TrajectoryAttentionMechanism(nn.Module):
    """Self-attention encoder over per-agent features; returns agent slot 0's encoding.

    ``fin`` and ``fout`` are accepted for interface compatibility but are not used.
    """

    def __init__(self,fin,d_model,fout,N_domain,h_domain,d_ff,dropout_domain):
        super(TrajectoryAttentionMechanism,self).__init__()
        width = d_model * 3
        attention = MultiHeadAttention(h_domain, width, dropout=dropout_domain)
        feed_forward = PointerwiseFeedforward(width, d_ff, dropout_domain)
        layer = EncoderLayer(width, copy.deepcopy(attention), copy.deepcopy(feed_forward), dropout_domain)
        self.domain_enc = Encoder(layer, N_domain)

    def forward(self,feature_out,social_mask):
        """Encode ``feature_out`` under ``social_mask`` and keep index 0 along dim 1."""
        encoded = self.domain_enc(feature_out, social_mask)
        return encoded[:, 0, :]

In [None]:
#parameter
# Global compute device. Many blocks in this notebook read the module-level name
# `device` at call time (e.g. `.to(device)`), so this cell must run before any of them is called.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#HTN feature extractor
# NOTE(review): no random seed is set before the models below are built, so their
# initial weights depend on the global torch RNG state and on the construction order here.
hist_inp_size = 5
num_queries = 1
# Also read as a global by MA_STTN_MAP.__init__ (input width of its lane embedding).
lane_inp_size = 64
# Passed to MA_STTN_MAP but not used by its constructor.
dec_inp_size = 64
# output steps*2: 12*2
dec_out_size = 24
N = 2
# Also read as a global by MA_STTN_MAP.__init__ (depth of its lane encoder/decoder).
N_lane = 2
N_social = 2
d_model = 128
d_ff = 256
pos_dim = 64
dist_dim = 128
h = 4
dropout = 0.1
PredictionModel = MA_STTN_MAP(hist_inp_size, num_queries, dec_inp_size, dec_out_size, N, N_social, d_model, d_ff, pos_dim, dist_dim, h, dropout).to(device)

#LaneNet for trajectory preprocessing
# LaneNet is defined elsewhere; these are presumably its input channels, hidden width
# and number of subgraph layers (TODO confirm against its definition).
lane_channels= 4
subgraph_width = 32
num_subgraph_layres =2
lane_subgraph = LaneNet(lane_channels, subgraph_width, num_subgraph_layres).to(device)

#Trajectory prediction header
# Input width d_model*3 matches the concatenated social/lane/physics features.
prediction_header = MultiPredictionHeader(d_model*3, dec_out_size, dropout).to(device)

#domain discriminator
dropout_domain = 0.1
N_domain = 2
h_domain = 4
# Re-assigned to the same values used above.
d_model = 128
d_ff = 256
# fin / fout are accepted but not used by TrajectoryAttentionMechanism.
fin = d_model*3
fout = 1
TraAttM = TrajectoryAttentionMechanism(fin,d_model,fout,N_domain,h_domain,d_ff,dropout_domain).to(device)

# Discriminator over the 3*d_model encodings returned by TraAttM.
domain_dis = Domain_D(d_model).to(device)

In [None]:
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-bandwidth Gaussian (RBF) kernel matrix over the rows of source and target.

    (The misspelled name is kept for compatibility with existing callers.)

    Args:
        source: (n_s, d) tensor.
        target: (n_t, d) tensor.
        kernel_mul: ratio between consecutive bandwidths.
        kernel_num: number of bandwidths whose kernels are summed.
        fix_sigma: if truthy, the base bandwidth to use instead of the data-driven one.

    Returns:
        (n_s + n_t, n_s + n_t) tensor: sum over bandwidths bw of exp(-||x_i - x_j||^2 / bw).
    """
    total = torch.cat([source, target], dim=0)
    n_samples = int(total.size(0))
    # Pairwise squared Euclidean distances via broadcasting.
    diff = total.unsqueeze(0) - total.unsqueeze(1)
    L2_distance = (diff ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Mean squared distance over off-diagonal pairs, detached so the bandwidth is a
        # constant for autograd. Guard the pair count (single sample) and clamp to eps so
        # identical points give finite kernels instead of 0/0 = NaN.
        num_pairs = max(n_samples ** 2 - n_samples, 1)
        bandwidth = torch.sum(L2_distance.detach()) / num_pairs
        bandwidth = torch.clamp(bandwidth, min=torch.finfo(L2_distance.dtype).eps)
    # Centre the geometric bandwidth ladder on the base bandwidth.
    bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bw) for bw in bandwidth_list]
    return sum(kernel_val)

def cal_mmd(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Biased squared-MMD estimate between the row sets source and target.

    Uses the multi-bandwidth Gaussian kernel from ``guassian_kernel``.

    Args:
        source: (n_s, d) tensor.
        target: (n_t, d) tensor. n_t may differ from n_s.
        kernel_mul, kernel_num, fix_sigma: forwarded to ``guassian_kernel``.

    Returns:
        Scalar tensor mean(K_ss) + mean(K_tt) - mean(K_st) - mean(K_ts).
    """
    n_source = int(source.size(0))
    kernels = guassian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    XX = kernels[:n_source, :n_source]
    YY = kernels[n_source:, n_source:]
    XY = kernels[:n_source, n_source:]
    YX = kernels[n_source:, :n_source]
    # Average each block separately. This equals mean(XX + YY - XY - YX) for equal sizes,
    # but also works when the batches differ in size (e.g. the last batch of one loader),
    # where the element-wise sum would fail to broadcast.
    return XX.mean() + YY.mean() - XY.mean() - YX.mean()

import torch.autograd as autograd

# Legacy FloatTensor constructor matching the available device.
if torch.cuda.is_available():
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor
# Weight of the WGAN-GP gradient-penalty term in the discriminator loss.
lambda_gp = 10

def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP

    Interpolates x = alpha * real + (1 - alpha) * fake, with one alpha ~ U[0, 1) per
    sample, and returns mean((||dD(x)/dx||_2 - 1) ** 2).

    Args:
        D: callable critic mapping a (batch, ...) tensor to per-sample scores.
        real_samples, fake_samples: tensors of identical shape (batch, ...).

    Returns:
        Scalar tensor. It is differentiable w.r.t. D's parameters because of
        ``create_graph=True``.
    """
    batch = real_samples.size(0)
    # One mixing weight per sample, broadcast over all remaining dimensions. It is drawn
    # from numpy's global RNG as before, and placed on the input's device and dtype.
    alpha_shape = (batch,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.as_tensor(np.random.random(alpha_shape),
                            dtype=real_samples.dtype, device=real_samples.device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

def cal_train_loss(traj_lab,outputs_coord,labels_mask):
    """Masked average displacement error used as the trajectory training loss.

    Args:
        traj_lab: ground-truth positions (B, N, T, 2); channel 0 = x, channel 1 = y.
        outputs_coord: predicted positions with the same layout.
        labels_mask: (B, N, T, 1) per-step validity.

    Returns:
        Sum of Euclidean errors over fully valid agents divided by (num_valid * T). An
        agent counts only if all T steps are valid. If there are no valid agents, a
        zero tensor of shape (1,) is returned on ``outputs_coord``'s device.
    """
    dev = outputs_coord.device
    num_steps = traj_lab.shape[2]
    gt = traj_lab.to(dev)
    # NOTE(review): sqrt has an infinite gradient at exactly zero error.
    each_err = torch.sqrt((gt[:, :, :, [0]] - outputs_coord[:, :, :, [0]]) ** 2
                          + (gt[:, :, :, [1]] - outputs_coord[:, :, :, [1]]) ** 2)
    sum_target_err = each_err.sum(dim=2)
    # floor(valid_steps / T) is 1 only when every step of the agent is valid.
    error_mask = torch.floor(labels_mask.to(dev).sum(dim=2) / num_steps)
    num_valid = error_mask.sum()
    if num_valid == 0:
        # Same device/dtype as the predictions so downstream arithmetic does not mix devices.
        return outputs_coord.new_zeros(1)
    return (sum_target_err * error_mask).sum() / (num_valid * num_steps)

def cal_metric(traj_lab,outputs_coord,labels_mask):
    """Masked ADE / FDE over agents whose whole future horizon is valid.

    Args:
        traj_lab: ground-truth positions (B, N, T, 2); channel 0 = x, channel 1 = y.
        outputs_coord: predicted positions with the same layout.
        labels_mask: (B, N, T, 1) per-step validity.

    Returns:
        (mean_error, final_error). These are the per-agent mean and last-step Euclidean
        errors, averaged over fully valid agents. If there are no valid agents, both are
        zero tensors of shape (1,) on ``outputs_coord``'s device.
    """
    dev = outputs_coord.device
    num_steps = traj_lab.shape[2]
    gt = traj_lab.to(dev)
    each_err = torch.sqrt((gt[:, :, :, [0]] - outputs_coord[:, :, :, [0]]) ** 2
                          + (gt[:, :, :, [1]] - outputs_coord[:, :, :, [1]]) ** 2)
    mean_target_err = each_err.mean(dim=2)
    final_target_err = each_err[:, :, -1, :]
    # floor(valid_steps / T) is 1 only when every step of the agent is valid.
    error_mask = torch.floor(labels_mask.to(dev).sum(dim=2) / num_steps)
    num_valid = error_mask.sum()
    if num_valid == 0:
        zero = outputs_coord.new_zeros(1)
        return zero, zero.clone()
    mean_error = (mean_target_err * error_mask).sum() / num_valid
    final_error = (final_target_err * error_mask).sum() / num_valid
    return mean_error, final_error

In [None]:
# Module-level per-epoch loss history lists (three independent, initially empty lists).
train_loss, val_loss, validation_loss = [], [], []

from torch.nn import functional as f

def get_input_data(data,lane_subgraph,lane_key='NEW_LANES'):
    """Turn a collated batch dict into the tensors consumed by the prediction model.

    Args:
        data: batch dict. Reads 'HISTORY', 'NORM_CENTER', ``lane_key``, 'VALID_LEN'
            (column 0 = agent count, column 1 = lane count) and 'FUTURE'.
        lane_subgraph: lane network forwarded to ``preprocess_lane``.
        lane_key: which lane entry of ``data`` to encode (default 'NEW_LANES').

    Returns:
        (traj, pos, max_agent_num, social_mask, lane_vec, lane_mask, traj_lab, labels_mask),
        all moved to ``device``. ``traj_lab`` is the cumulative sum over the time axis
        (dim -2) of the first two channels of 'FUTURE'. ``labels_mask`` is its last channel.
    """
    B = data['HISTORY'].shape[0]
    traj_valid_len = data['VALID_LEN'][:, 0]
    lane_valid_len = data['VALID_LEN'][:, 1]
    max_agent_num = torch.max(traj_valid_len)
    max_lane_num = torch.max(lane_valid_len)

    social_mask = preprocess_traj(data['HISTORY'], B, traj_valid_len, max_agent_num)
    lane_vec, lane_mask = preprocess_lane(lane_subgraph, data[lane_key], B, lane_valid_len, max_lane_num)

    future = data['FUTURE']
    # Per-step displacements -> cumulative positions, for the whole batch at once.
    traj_lab = future[..., :2].cumsum(dim=-2)
    labels_mask = future[..., [-1]]

    # Tensors are returned directly rather than wrapped in the legacy ``Variable``.
    # ``Variable(t)`` returns a *detached* tensor, which cut lane_vec (produced by
    # preprocess_lane from lane_subgraph) out of the autograd graph.
    return (data['HISTORY'].to(device), data['NORM_CENTER'].to(device), max_agent_num.to(device),
            social_mask.to(device), lane_vec.to(device), lane_mask.to(device),
            traj_lab.to(device), labels_mask.to(device))

def _set_train_mode(modules, training):
    """Put every module in ``modules`` into train (True) or eval (False) mode."""
    for module in modules:
        module.train(training)


def _save_dg_checkpoints(named_modules, suffix=''):
    """Save each (name, module) state_dict to '<name>_DG_intersectionA<suffix>.tar'."""
    for name, module in named_modules:
        torch.save(module.state_dict(), '%s_DG_intersectionA%s.tar' % (name, suffix))


def _train_one_epoch(PredictionModel, prediction_header, lane_subgraph, TraAttM, domain_dis,
                     p_optimizer, d_optimizer, epoch, log_interval):
    """Run one adversarial training pass over the three source loaders.

    Returns the sample-weighted generator (P) loss divided by ``dataset_sizes['train_s']``.
    """
    running_loss = 0.0
    source_loaders = zip(dataloaders['train_s1'], dataloaders['train_s2'], dataloaders['train_s3'])
    for batch_idx, sdatas in enumerate(source_loaders):
        # Every stage runs over s1, s2, s3 in turn, so dropout / numpy RNG draws happen in
        # the same sequence as a fully unrolled implementation.
        batches = [get_input_data(sdata, lane_subgraph) for sdata in sdatas]
        B = batches[0][0].shape[0]

        d_optimizer.zero_grad()
        p_optimizer.zero_grad()

        # batch[:6] == (traj, pos, max_agent_num, social_mask, lane_vec, lane_mask)
        features = [PredictionModel(*batch[:6], sdata) for batch, sdata in zip(batches, sdatas)]
        pred_cums = [prediction_header(feat, sdata, batch[0])[0]
                     for feat, sdata, batch in zip(features, sdatas, batches)]
        D_inputs = [TraAttM(feat, batch[3]) for feat, batch in zip(features, batches)]

        # Laplace(0, 1) noise is used in place of target-domain encodings.
        fake_inputs_t = np.random.laplace(0, 1, size=D_inputs[0].shape).astype('float32')
        fake_inputs_t = torch.tensor(fake_inputs_t).to(device)

        D_preds = [domain_dis(d_in) for d_in in D_inputs]
        D_pred_t = domain_dis(fake_inputs_t)

        # MMD alignment between every pair of sources.
        mmd_s12 = cal_mmd(D_inputs[0], D_inputs[1])
        mmd_s13 = cal_mmd(D_inputs[0], D_inputs[2])
        mmd_s23 = cal_mmd(D_inputs[1], D_inputs[2])

        # NOTE(review): the fake batch is sized like source 1 and interpolated with each
        # source, so the three source batches must have equal size here.
        penalties = [compute_gradient_penalty(domain_dis, fake_inputs_t, d_in) for d_in in D_inputs]
        D_losses = [torch.mean(D_pred_t) - torch.mean(d_pred) + lambda_gp * gp
                    for d_pred, gp in zip(D_preds, penalties)]
        train_D_loss = (D_losses[0] + D_losses[1] + D_losses[2]) / 12
        # Graph is retained because train_P_loss below reuses the same forward pass.
        train_D_loss.backward(retain_graph=True)

        G_losses = [(-torch.mean(D_pred_t) + torch.mean(d_pred)) / (12 / 3) for d_pred in D_preds]
        traj_losses = [cal_train_loss(batch[6], pred, batch[7]) for batch, pred in zip(batches, pred_cums)]
        traj_loss = PredictionModel.cal_traj_loss_dann(*traj_losses)
        train_P_loss = traj_loss + (G_losses[0] + G_losses[1] + G_losses[2]) / 3 + mmd_s12 + mmd_s13 + mmd_s23
        train_P_loss.backward()
        d_optimizer.step()
        p_optimizer.step()

        running_loss += train_P_loss.item() * B

        if batch_idx % log_interval == 0:
            mmd_mean = (mmd_s12.item() + mmd_s13.item() + mmd_s23.item()) / 3
            print('Train Epoch: {} [{}/{} ({:.1f}%)] P Loss: {:.6f} Traj Loss: {:.6f} D Loss: {:.6f} mmd: {:.6f} w1:{:.6f} w2:{:.6f} w3:{:.6f}'.format(
                epoch, batch_idx * B, dataset_sizes['train_s'], batch_idx * B / dataset_sizes['train_s'] * 100,
                train_P_loss.item(), traj_loss.item(), train_D_loss.item(), mmd_mean,
                PredictionModel.w1[0].item(), PredictionModel.w2[0].item(), PredictionModel.w3[0].item()))

    return running_loss / dataset_sizes['train_s']


def _validate_one_epoch(PredictionModel, prediction_header, lane_subgraph):
    """Evaluate on the three source val loaders and the target val loader.

    Prints per-source and target mean/final errors. Returns the sample-weighted
    uncertainty-weighted source trajectory loss divided by ``dataset_sizes['val_s']``.
    """
    running_loss = 0.0
    src_mean = [0.0, 0.0, 0.0]
    src_final = [0.0, 0.0, 0.0]
    tgt_mean = 0.0
    tgt_final = 0.0
    val_loaders = zip(dataloaders['val_s1'], dataloaders['val_s2'],
                      dataloaders['val_s3'], dataloaders['val_t'])
    for all_data in val_loaders:
        batches = [get_input_data(d, lane_subgraph) for d in all_data]
        B = batches[0][0].shape[0]
        features = [PredictionModel(*batch[:6], d) for batch, d in zip(batches, all_data)]
        pred_cums = [prediction_header(feat, d, batch[0])[0]
                     for feat, d, batch in zip(features, all_data, batches)]

        traj_losses = []
        for k in range(3):
            traj_lab, labels_mask = batches[k][6], batches[k][7]
            traj_losses.append(cal_train_loss(traj_lab, pred_cums[k], labels_mask))
            vali_mean_k, vali_final_k = cal_metric(traj_lab, pred_cums[k], labels_mask)
            src_mean[k] += vali_mean_k.item() * B
            src_final[k] += vali_final_k.item() * B
        traj_loss = PredictionModel.cal_traj_loss_dann(*traj_losses)

        vali_mean, vali_final = cal_metric(batches[3][6], pred_cums[3], batches[3][7])
        tgt_mean += vali_mean.item() * B
        tgt_final += vali_final.item() * B
        running_loss += traj_loss.item() * B

    n_val = dataset_sizes['val_s']
    epoch_loss = running_loss / n_val
    print('{} Loss: {:.6f} Source 1 Mean_Error: {:.6f} Source 1 Final_Error {:.6f} Source 2 Mean_Error: {:.6f} Source 2 Final_Error {:.6f} Source 3 Mean_Error: {:.6f} Source 3 Final_Error {:.6f}'
          .format('val', epoch_loss, src_mean[0] / n_val, src_final[0] / n_val, src_mean[1] / n_val,
                  src_final[1] / n_val, src_mean[2] / n_val, src_final[2] / n_val))
    print('{} Mean_Error: {:.6f} Final_Error {:.6f}'.format('Target Val Error :', tgt_mean / n_val, tgt_final / n_val))
    return epoch_loss


def train_model(PredictionModel,prediction_header,lane_subgraph,TraAttM,domain_dis,p_optimizer,d_optimizer,num_epochs,log_interval=25,scheduler=None):
    """Adversarial multi-source training with per-epoch validation and checkpointing.

    Each epoch runs one training pass and then one validation pass (under
    ``torch.no_grad``). Epoch losses are appended to the module-level ``train_loss`` and
    ``val_loss`` lists. Whenever the validation loss improves, all five modules are saved
    as '<name>_DG_intersectionA(<epoch+1>).tar'. ``scheduler`` (if given) is stepped once
    per epoch. At the end, the final weights are saved as '(Final)', the best weights are
    restored, and the best weights are saved without a suffix.

    NOTE(review): ``get_input_data`` is looked up at call time; a later cell in this
    notebook redefines it with a different lane key.
    """
    since = time.time()
    named_modules = [
        ('PredictionModel', PredictionModel),
        ('prediction_header', prediction_header),
        ('lane_subgraph', lane_subgraph),
        ('TraAttM', TraAttM),
        ('domain_dis', domain_dis),
    ]
    modules = [module for _, module in named_modules]
    best_wts = {name: copy.deepcopy(module.state_dict()) for name, module in named_modules}

    best_loss = 100
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        _set_train_mode(modules, True)
        epoch_loss = _train_one_epoch(PredictionModel, prediction_header, lane_subgraph, TraAttM,
                                      domain_dis, p_optimizer, d_optimizer, epoch, log_interval)
        train_loss.append(epoch_loss)

        # No backward pass during validation, so don't build autograd graphs.
        _set_train_mode(modules, False)
        with torch.no_grad():
            epoch_loss = _validate_one_epoch(PredictionModel, prediction_header, lane_subgraph)
        val_loss.append(epoch_loss)

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_wts = {name: copy.deepcopy(module.state_dict()) for name, module in named_modules}
            _save_dg_checkpoints(named_modules, '(%d)' % (epoch + 1))

        # LR schedules such as StepLR count epochs, so step once per epoch
        # (previously this ran only once, after the whole training loop).
        if scheduler is not None:
            scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Loss: {:8f}'.format(best_loss))

    # Save the last-epoch weights, then restore and save the best-validation weights.
    _save_dg_checkpoints(named_modules, '(Final)')
    for name, module in named_modules:
        module.load_state_dict(best_wts[name])
    _save_dg_checkpoints(named_modules)


In [None]:
import torch.optim as optim
import copy
from typing import Any,Dict,List

# Predictor-side optimiser: backbone, lane encoder and prediction head share one Adam
# instance, each module as its own parameter group (all at the same lr).
p_optimizer = torch.optim.Adam(
    [{'params': module.parameters()} for module in (PredictionModel, lane_subgraph, prediction_header)],
    lr=0.0005,
)
# Second optimiser over the trajectory-attention module and the domain discriminator.
d_optimizer = torch.optim.Adam(
    [{'params': module.parameters()} for module in (TraAttM, domain_dis)],
    lr=0.0005,
)

# StepLR multiplies p_optimizer's lr by gamma=0.1 every 15 calls to scheduler.step().
# NOTE(review): only p_optimizer is scheduled; d_optimizer keeps lr=0.0005 — confirm intended.
# NOTE(review): in train_model, `scheduler.step()` appears at function-body indentation (after the
# epoch loop), which would mean it runs once after training and this decay never takes effect — confirm.
scheduler = torch.optim.lr_scheduler.StepLR(p_optimizer, step_size=15, gamma=0.1)

train_model(
    PredictionModel, prediction_header, lane_subgraph, TraAttM, domain_dis,
    p_optimizer, d_optimizer,
    num_epochs=50, log_interval=100, scheduler=scheduler,
)


In [None]:
def get_input_data(data,lane_subgraph):
    """Unpack a collated batch into the tensors consumed by the prediction model.

    Args:
        data: batch dict with keys 'HISTORY', 'NORM_CENTER', 'LANE_VECTORS',
            'VALID_LEN' and 'FUTURE' (e.g. as produced by ``collate_single_cpu``).
        lane_subgraph: lane encoder module, forwarded to ``preprocess_lane``.

    Returns:
        Tuple ``(traj, pos, max_agent_num, social_mask, lane_vec, lane_mask,
        traj_lab, labels_mask)``. Every tensor is moved to the module-level
        ``device`` and detached from the autograd graph.
        - ``traj_lab`` is the cumulative sum over the time axis (dim -2) of the
          first two channels of ``FUTURE``. Presumably these turn per-step offsets
          into positions — TODO confirm.
        - ``labels_mask`` is the last channel of ``FUTURE``, kept with a
          trailing singleton dim.
    """
    B = data['HISTORY'].shape[0]
    pos = data['NORM_CENTER']
    traj = data['HISTORY']
    lane = data['LANE_VECTORS']
    # Column 0 / 1 of VALID_LEN are the per-sample valid trajectory / lane counts (presumably).
    traj_valid_len = data['VALID_LEN'][:, 0]
    max_agent_num = torch.max(traj_valid_len)
    lane_valid_len = data['VALID_LEN'][:, 1]
    max_lane_num = torch.max(lane_valid_len)

    social_mask = preprocess_traj(data['HISTORY'], B, traj_valid_len, max_agent_num)
    lane_vec, lane_mask = preprocess_lane(lane_subgraph, lane, B, lane_valid_len, max_lane_num)

    future = data['FUTURE']
    # One vectorised cumsum replaces the per-sample torch.cat loop.
    # That loop was quadratic in batch size and only worked for exactly 12 future steps.
    traj_lab = future[:, :, :, :2].cumsum(dim=-2)
    labels_mask = future[:, :, :, [-1]]

    # torch.autograd.Variable is deprecated; its constructor returned a detached copy of the
    # input, so `.to(device).detach()` preserves the original semantics.
    return tuple(
        t.to(device).detach()
        for t in (traj, pos, max_agent_num, social_mask, lane_vec, lane_mask, traj_lab, labels_mask)
    )

def cal_test_metric(traj_lab,outputs_coord,labels_mask,num_horizons=4,horizon_step=3):
    """Masked average / final displacement error at several horizons for one batch.

    Args:
        traj_lab: ground-truth positions, indexed as (B, A, T, >=2); channels 0/1 are x/y.
        outputs_coord: predicted positions with the same layout as ``traj_lab``.
        labels_mask: per-step validity, indexed as (B, A, T, 1). An agent is
            counted only if its mask sums to T over the time axis, i.e. its
            whole future is valid (assuming a 0/1 mask).
        num_horizons: number of horizons to evaluate.
        horizon_step: spacing of the horizons. Horizons are
            ``horizon_step * k`` steps for ``k = 1..num_horizons``. The defaults
            (4, 3) give the 3/6/9/12-step evaluation.

    Returns:
        ``(err_mean, err_final)``: detached CPU float tensors of length
        ``num_horizons``. Entry k holds the mean-over-time error (ADE) and the
        last-step error (FDE) up to horizon k, averaged over fully-observed
        agents. Both are all zeros when no agent is fully observed.
    """
    pred = outputs_coord.to(device)
    gt = traj_lab.to(device)
    # Euclidean error per step, keeping a trailing singleton channel: (B, A, T, 1).
    each_err = torch.sqrt(torch.pow(gt[:, :, :, [0]] - pred[:, :, :, [0]], 2)
                          + torch.pow(gt[:, :, :, [1]] - pred[:, :, :, [1]], 2))
    # 1.0 where the mask covers all T steps, else 0.0. Divides by the actual T (was hard-coded 12).
    num_steps = labels_mask.shape[2]
    error_mask = torch.floor(torch.sum(labels_mask, dim=2) / num_steps).to(each_err.device)

    err_mean = torch.zeros([num_horizons])
    err_final = torch.zeros([num_horizons])
    num_valid = torch.sum(error_mask)
    if num_valid == 0:
        return err_mean, err_final
    for i in range(num_horizons):
        horizon = horizon_step * (i + 1)
        mean_target_err = torch.mean(each_err[:, :, :horizon, :], dim=2)
        final_target_err = each_err[:, :, horizon - 1, :]
        # Detach so the returned tensors do not keep the model's autograd graph alive.
        err_mean[i] = (torch.sum(mean_target_err * error_mask) / num_valid).detach()
        err_final[i] = (torch.sum(final_target_err * error_mask) / num_valid).detach()
    return err_mean, err_final

def cal_RFDE(outputs,labels,num_horizons=4,horizon_step=3):
    """Numerator and denominator sums of the relative final displacement error (RFDE).

    Args:
        outputs: predicted positions, indexed as (N, T, >=2); channels 0/1 are x/y.
        labels: ground-truth positions with the same layout as ``outputs``.
        num_horizons: number of horizons to evaluate.
        horizon_step: spacing of the horizons. Horizon k uses time index
            ``horizon_step * k - 1`` for ``k = 1..num_horizons``. The defaults
            (4, 3) reproduce steps 3/6/9/12.

    Returns:
        ``(err_final_rela_son, err_final_rela_mom)``: detached CPU float tensors
        of length ``num_horizons``.
        - ``son`` (numerator) is the sum over the N agents of the Euclidean
          error at that step.
        - ``mom`` (denominator) is the sum over the agents of the norm of the
          ground-truth (x, y) at that step.
        Callers divide the running sums of the two to obtain RFDE. With N == 0
        both are zeros.
    """
    # Per-agent, per-step Euclidean error: (N, T).
    step_err = torch.sqrt(torch.pow(labels[:, :, 0] - outputs[:, :, 0], 2)
                          + torch.pow(labels[:, :, 1] - outputs[:, :, 1], 2))
    err_sum = torch.sum(step_err, dim=0)
    err_final_rela_son = torch.zeros([num_horizons])
    err_final_rela_mom = torch.zeros([num_horizons])
    for i in range(num_horizons):
        t = horizon_step * (i + 1) - 1
        # Detach so the returned tensors do not keep the model's autograd graph alive.
        err_final_rela_son[i] = err_sum[t].detach()
        gt_norm = torch.sqrt(torch.pow(labels[:, t, 0], 2) + torch.pow(labels[:, t, 1], 2))
        err_final_rela_mom[i] = torch.sum(gt_norm).detach()
    return err_final_rela_son, err_final_rela_mom

# Evaluate on the 'test_t' split: per-horizon mean/final displacement error and RFDE.
PredictionModel.eval()   # Set model to evaluate mode
prediction_header.eval()
lane_subgraph.eval()

# Running sums over the test set, one entry per horizon (4 = default horizon count of the metric helpers).
errors_mean = torch.zeros([4])
errors_final = torch.zeros([4])
errors_final_son = torch.zeros([4])
errors_final_mom = torch.zeros([4])

# Evaluation only: disable autograd so no graph is built and retained for every batch.
with torch.no_grad():
    for data in dataloaders['test_t']:
        traj,pos,max_agent_num,social_mask,lane_vec,lane_mask,traj_lab,labels_mask = get_input_data(data,lane_subgraph)
        B = traj.shape[0]

        feature_out = PredictionModel(traj,pos,max_agent_num,social_mask,lane_vec,lane_mask,data)
        pred_cum, _ = prediction_header(feature_out,data,traj)

        # Per-batch averages over fully-observed agents, weighted by batch size B.
        # NOTE(review): a batch with no fully-observed agent contributes zeros but its B samples still
        # count in the division by dataset_sizes['test_t'] below — confirm this weighting is intended.
        err_mean, err_final = cal_test_metric(traj_lab,pred_cum,labels_mask)
        errors_mean += err_mean.detach().cpu() * B
        errors_final += err_final.detach().cpu() * B

        # RFDE uses only agents whose validity mask covers every future step.
        error_mask = torch.floor(torch.sum(labels_mask, dim=2) / labels_mask.shape[2])
        fully_observed = error_mask[:, :, 0] == 1
        err_final_rela_son, err_final_rela_mom = cal_RFDE(pred_cum[fully_observed], traj_lab[fully_observed])
        errors_final_son += err_final_rela_son.detach().cpu()
        errors_final_mom += err_final_rela_mom.detach().cpu()

errors_mean = errors_mean / dataset_sizes['test_t']
errors_final = errors_final / dataset_sizes['test_t']
print('Target Mean Error(m) :',errors_mean)
print('Target Final Error(m) :',errors_final)
print('Target RFDE :',errors_final_son/errors_final_mom)