In [1]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import math


class SftLayer(nn.Module):
    def __init__(self, d_edge, d_model, d_ffn, n_head = 8, dropout = 0.1, update_edge = True):
        super().__init__()
        self.update_edge = update_edge

        self.proj_memory = nn.Sequential(
            nn.Linear(d_model + d_model + d_edge, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(inplace=True)
        )

        if self.update_edge:
            self.proj_edge = nn.Sequential(
                nn.Linear(d_model, d_edge),
                nn.LayerNorm(d_edge),
                nn.ReLU(inplace=True)
            )
            self.norm_edge = nn.LayerNorm(d_edge)

        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=n_head, dropout=dropout, batch_first=False)

        # Feedforward model
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)

        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, node, edge):
        # update node
        x, edge, memory = self._build_memory(node, edge)
        x_prime, _ = self._mha_block(x, memory, attn_mask=None, key_padding_mask=None)
        x = self.norm2(x + x_prime).squeeze()
        x = self.norm3(x + self._ff_block(x))
        return x, edge

    def _build_memory(self, node, edge):
        n_token = node.shape[0]

        # 1. build memory
        src_x = node.unsqueeze(dim=0).repeat([n_token, 1, 1])  # (N, N, d_model)
        tar_x = node.unsqueeze(dim=1).repeat([1, n_token, 1])  # (N, N, d_model)
        tmp = torch.cat([edge, src_x, tar_x], dim=-1)
        memory = self.proj_memory(torch.cat([edge, src_x, tar_x], dim=-1))  # (N, N, d_model)
        # 2. (optional) update edge (with residual)
        if self.update_edge:
            edge = self.norm_edge(edge + self.proj_edge(memory))  # (N, N, d_edge)

        return node.unsqueeze(dim=0), edge, memory

    # multihead attention block
    def _mha_block(self, x, mem, attn_mask, key_padding_mask):
        x, _ = self.multihead_attn(x, mem, mem,
                                   attn_mask=attn_mask,
                                   key_padding_mask=key_padding_mask,
                                   need_weights=False)  # return average attention weights
        return self.dropout2(x), None

    # feed forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)



class FusionNet(nn.Module):
    def __init__(self, d_model, d_edge, n_head=8, n_layers=6, dropout=0.1, update_edge=True):
        super().__init__()
        fusion = []
        for i in range(n_layers):
            need_update_edge = False if i == n_layers - 1 else update_edge
            fusion.append(SftLayer(d_edge=d_edge,
                                   d_model=d_model,
                                   d_ffn=d_model*2,
                                   n_head=n_head,
                                   dropout=dropout,
                                   update_edge=need_update_edge))
        self.fusion = nn.ModuleList(fusion)

    def forward(self, agent_feats, agent_mask, map_feats, map_mask, rpe_feats, rpe_mask):
        """
        agent_feats: batch_size, agent_num, dim
        agent_mask: batch_size, agent_num
        map_feats: batch_size, map_num, dim
        map_mask: batch_size, map_num
        rpe_feats: batch_size, N, N, dim_1
        rpe_mask: batch_size, N, N
        """
        x = torch.cat((agent_feats, map_feats), dim=1) 
        x_mask = torch.cat((agent_mask, map_mask), dim=1)
        batch_size, all_num, dim = x.shape
        agent_num = agent_feats.shape[1]
        
        agents_new, maps_new = list(), list()
        for i in range(batch_size):
            x_frame = x[i]
            x_mask_frame = x_mask[i]
            x_frame = x_frame[x_mask_frame].view(-1, dim) # (valid_num,dim)
            valid_num = x_frame.shape[0]
            rpe_frame = rpe_feats[i][rpe_mask[i]].view(valid_num, valid_num, -1)
            for mod in self.fusion:
                x_frame, rpe_frame= mod(x_frame, rpe_frame)
            out = x_frame.new_zeros(all_num, x_frame.shape[-1])
            out[x_mask_frame] = x_frame
            agents_new.append(out[:agent_num])
            maps_new.append(out[agent_num:])
        agent_feats = torch.stack(agents_new)
        map_feats = torch.stack(maps_new)
        return agent_feats, map_feats

In [2]:
class ResMLP(nn.Module):
    def __init__(self, in_channel, out_channel, hidden=64, bias=True, activation="relu", norm='layer'):
        super().__init__()

        # define the activation function
        if activation == "relu":
            act_layer = nn.ReLU
        elif activation == "relu6":
            act_layer = nn.ReLU6
        elif activation == "leaky":
            act_layer = nn.LeakyReLU
        elif activation == "prelu":
            act_layer = nn.PReLU
        else:
            raise NotImplementedError

        # define the normalization function
        if norm == "layer":
            norm_layer = nn.LayerNorm
        elif norm == "batch":
            norm_layer = nn.BatchNorm1d
        else:
            raise NotImplementedError

        # insert the layers
        self.linear1 = nn.Linear(in_channel, hidden, bias=bias)
        self.linear2 = nn.Linear(hidden, out_channel, bias=bias)

        self.norm1 = norm_layer(hidden)
        self.norm2 = norm_layer(out_channel)

        self.act1 = act_layer(inplace=True)
        self.act2 = act_layer(inplace=True)

        self.shortcut = None
        if in_channel != out_channel:
            self.shortcut = nn.Sequential(
                nn.Linear(in_channel, out_channel, bias=bias),
                norm_layer(out_channel)
            )

    def forward(self, x):
        out = self.linear1(x)
        out = self.norm1(out)
        out = self.act1(out)
        out = self.linear2(out)
        out = self.norm2(out)

        if self.shortcut:
            out += self.shortcut(x)
        else:
            out += x
        return self.act2(out)


class TrajDecoder(nn.Module):
    def __init__(self, input_size, hidden_size, n_order=7, m=50):
        super().__init__()
        self.m = m
        self.taregt_prob_layer = nn.Sequential(
            ResMLP(input_size + 2, hidden_size, hidden_size),
            nn.Linear(hidden_size, 1)
        )
        self.target_offset_layer = nn.Sequential(
            ResMLP(input_size + 2, hidden_size, hidden_size),
            nn.Linear(hidden_size, 2)
        )
        self.motion_estimator_layer = nn.Sequential(
            ResMLP(input_size + 2, hidden_size, hidden_size),
            nn.Linear(hidden_size, n_order*2)
        )
        
        self.traj_prob_layer = nn.Sequential(
            ResMLP(input_size + (n_order+1)*2, hidden_size, hidden_size),
            nn.Linear(hidden_size, 1),
            nn.Softmax(dim=2)
        )
    def forward(self, feats, tar_candidate, target_gt, candidate_mask=None):
        """
        feats: B,N,D
        tar_candidate: B, N, M, 2
        target_gt:  B, N, 1, 2
        candidate_mask: B, N, M
        """
        
        B, N, M, _ = tar_candidate.shape
        feats_repeat = feats.unsqueeze(2).repeat(1, 1, M, 1)

        # stack the target candidates to the end of input feature
        feats_tar = torch.cat([feats_repeat, tar_candidate], dim=-1) # B, N, M, D+2
        # compute probability for each candidate
        prob_tensor = self.taregt_prob_layer(feats_tar).squeeze(-1) # B,N,M
        target_probs = self.masked_softmax(prob_tensor, candidate_mask, dim=-1) # B, N, M
        
        tar_offsets = self.target_offset_layer(feats_tar) # B, N, M, 2
        
        m = min(target_probs.shape[2], self.m)
        _, topk_indices = target_probs.topk(m, dim=2)
        tar_indices = topk_indices.unsqueeze(-1).expand(topk_indices.shape[0], 
                                                        topk_indices.shape[1], 
                                                        topk_indices.shape[2], 
                                                        tar_candidate.shape[-1])
        target_pred_se = torch.gather(tar_candidate, dim=2, index=tar_indices) # B, N, m, 2
        offset_pred_se = torch.gather(tar_offsets, dim=2, index=tar_indices) # B, N, m, 2
        
        target_pred = target_pred_se + offset_pred_se
        feat_indices = topk_indices.unsqueeze(-1).expand(topk_indices.shape[0], 
                                                        topk_indices.shape[1], 
                                                        topk_indices.shape[2], 
                                                        feats_repeat.shape[-1])
        feats_traj = torch.gather(feats_repeat, dim=2, index=feat_indices) # B, N, m, D
        feats_traj = torch.cat([feats_traj, target_pred], dim=-1) # B, N, m, D+2

        param = self.motion_estimator_layer(feats_traj) # B,N,m,n_order*2
        feats_traj = torch.cat([feats_traj, param], dim=-1)
        traj_probs = self.traj_prob_layer(feats_traj).squeeze(-1) # B, N, m
        
        # 预测轨迹(teacher_force)
        feat_traj_with_gt = torch.cat([feats.unsqueeze(2), target_gt], dim=-1) # B, N, 1, D+2
        param_with_gt = self.motion_estimator_layer(feat_traj_with_gt) # B,N,1,n_order*2
        
        return target_probs, target_pred, tar_offsets, param, param_with_gt, traj_probs
    
    
    def masked_softmax(self, vector, mask, dim=-1, memory_efficient=True, mask_fill_value=-1e32):
        if mask is None:
            result = F.softmax(vector, dim=dim)
        else:
            mask = mask.float()
            while mask.dim() < vector.dim():
                mask = mask.unsqueeze(-1)
            if not memory_efficient:
                # To limit numerical errors from large vector elements outside the mask, we zero these out.
                result = F.softmax(vector * mask, dim=dim)
                result = result * mask
                result = result / (result.sum(dim=dim, keepdim=True) + 1e-13)
                result = result.masked_fill((1 - mask).bool(), 0.0)
            else:
                masked_vector = vector.masked_fill((1 - mask).bool(), mask_fill_value)
                result = F.softmax(masked_vector, dim=dim)
                result = result.masked_fill((1 - mask).bool(), 0.0)
        return result



class PolylineNet(nn.Module):
    def __init__(self, input_size, hidden_size, out_size=None):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(input_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
            nn.ReLU(inplace=True)
        )
        
        self.fc2 = nn.Sequential(
            nn.Linear(hidden_size*2, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.LayerNorm(hidden_size),
            nn.ReLU(inplace=True)
        )
        
        if out_size is not None:
            self.fc_out = nn.Sequential(
                nn.Linear(hidden_size, hidden_size, bias=True),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_size, out_size, bias=True)
            )
        else:
            self.fc_out = None 
        

    def forward(self, polylines, polylines_mask):
        """
        Args:
            polylines (batch_size, num_polylines, num_points_each_polylines, C):
            polylines_mask (batch_size, num_polylines, num_points_each_polylines):

        Returns:
        """
        bs, poly_num, point_num, C = polylines.shape
        poly_feat_valid = self.fc1(polylines[polylines_mask])  # (N, C)
        poly_feat = polylines.new_zeros(bs, poly_num, point_num, poly_feat_valid.shape[-1])
        poly_feat[polylines_mask] = poly_feat_valid
        
        # get global feature
        pooled_feat = poly_feat.max(dim=2)[0]
        poly_feat = torch.cat((poly_feat, pooled_feat[:, :, None, :].repeat(1, 1, point_num, 1)), dim=-1)
        # mlp
        poly_feat_valid = self.fc2(poly_feat[polylines_mask])
        feat_buffers = poly_feat.new_zeros(bs, poly_num, point_num, poly_feat_valid.shape[-1])
        feat_buffers[polylines_mask] = poly_feat_valid
        # max-pooling
        feat_buffers = feat_buffers.max(dim=2)[0]  # (batch_size, num_polylines, C)
        
        # out-mlp 
        if self.fc_out is not None:
            valid_mask = (polylines_mask.sum(dim=-1) > 0)
            feat_buffers_valid = self.fc_out(feat_buffers[valid_mask])  # (N, C)
            feat_buffers = feat_buffers.new_zeros(bs, poly_num, feat_buffers_valid.shape[-1])
            feat_buffers[valid_mask] = feat_buffers_valid
        return feat_buffers
    
class PlanNet(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.plan_mlp = PolylineNet(input_size, hidden_size)
        self.gate = nn.Sequential(
            nn.Linear(hidden_size*2, 64),
            nn.ReLU(inplace=True), 
            nn.Linear(64, 1), 
            nn.Sigmoid())

    def forward(self, agent_feats, plan_traj, plan_traj_mask):
        """
        plan_traj: B,N,d,2
        plan_traj_mask: B,N,d
        """
        batch_size, agent_num, _ = agent_feats.size() 
        ego_feat = agent_feats[:, 0].unsqueeze(1).expand(-1, agent_num, -1)
        gate_feat = torch.cat((agent_feats, ego_feat), dim=-1)
        gate = self.gate(gate_feat).squeeze(-1) # B*N
        
        if plan_traj.dim() == 3:
            plan_traj = plan_traj.unsqueeze(1)
            plan_traj_mask = plan_traj_mask.unsqueeze(1)
        plan_feat = self.plan_mlp(plan_traj, plan_traj_mask).expand(-1, agent_num, -1) # B*N*D
        plan_feat = torch.einsum('bnd,bn->bnd', plan_feat, gate)
        agent_feats = agent_feats + plan_feat
        return agent_feats, gate


class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.device = self.args.device
        self.agent_net = PolylineNet(  
            input_size=self.args.agent_input_size,
            hidden_size=self.args.agent_hidden_size,
            out_size=self.args.d_model)
        
        self.map_net = PolylineNet(
            input_size=self.args.map_input_size,
            hidden_size=self.args.map_hidden_size,
            out_size=self.args.d_model)
        
        self.rpe_net = nn.Sequential(
            nn.Linear(self.args.rpe_input_size, self.args.rpe_hidden_size),
            nn.LayerNorm(self.args.rpe_hidden_size),
            nn.ReLU(inplace=True)
        )
        
        self.fusion_net = FusionNet(
            d_model=self.args.d_model, 
            d_edge=self.args.rpe_hidden_size, 
            dropout=self.args.dropout,
            update_edge=self.args.update_edge)
        
        self.plan_net = PlanNet(
            input_size=self.args.plan_input_size, 
            hidden_size=self.args.d_model
        )
        
        self.traj_decoder = TrajDecoder(
            input_size=self.args.d_model,
            hidden_size=self.args.decoder_hidden_size, 
            n_order=self.args.bezier_order, 
            m=self.args.m)
        
        self.mat_T = self._get_T_matrix_bezier(n_order=self.args.bezier_order, n_step=50).to(self.device)
        
        self.relation_decoder = None
        if self.args.init_weights:
            self.apply(self._init_weights)
        
    def forward(self, batch_dict):
        agent_polylines, agent_polylines_mask = batch_dict['agent_polylines'], batch_dict['agent_polylines_mask'].bool() 
        map_polylines, map_polylines_mask = batch_dict['map_polylines'], batch_dict['map_polylines_mask'].bool() 
        agent_feats = self.agent_net(agent_polylines, agent_polylines_mask)
        map_feats = self.map_net(map_polylines, map_polylines_mask)
        
        rpe, rpe_mask = batch_dict['rpe'], batch_dict['rpe_mask'].bool()
        batch_size, N, _, _= rpe.shape
        rpe_feats_valid = self.rpe_net(rpe[rpe_mask])  # (N, C)
        rpe_feats = rpe_feats_valid.new_zeros(batch_size, N, N, rpe_feats_valid.shape[-1])
        rpe_feats[rpe_mask] = rpe_feats_valid
        
        agent_mask = (agent_polylines_mask.sum(dim=-1) > 0)  
        map_mask = (map_polylines_mask.sum(dim=-1) > 0)  
        agent_feats, map_feat = self.fusion_net(agent_feats, agent_mask, map_feats, map_mask, rpe_feats, rpe_mask)
        
        plan_traj, plan_traj_mask = batch_dict['plan_traj'], batch_dict['plan_traj_mask']
        agent_feats, gate = self.plan_net(agent_feats, plan_traj, plan_traj_mask)
        
        tar_candidate, candidate_mask = batch_dict['tar_candidate'], batch_dict['candidate_mask'].bool() 
        target_gt = batch_dict['gt_preds'][:, :, -1, :2]
        target_gt = target_gt.view(target_gt.shape[0], target_gt.shape[1], 1, 2) # B, N, 1, 2
        target_probs, pred_targets, pred_offsets, param, param_with_gt, traj_probs = self.traj_decoder(agent_feats, tar_candidate, target_gt, candidate_mask)
           
        # 由贝塞尔控制点反推轨迹
        bezier_param = torch.cat([param, pred_targets], dim=-1) # B, N, m, (n_order+1)*2
        bezier_control_points = bezier_param.view(bezier_param.shape[0],
                                                  bezier_param.shape[1],
                                                  bezier_param.shape[2], -1, 2) # B, N, m, n_order+1, 2
        trajs = torch.matmul(self.mat_T, bezier_control_points) # B,N,m,future_steps,2
        
        bezier_param_with_gt = torch.cat([param_with_gt, target_gt], dim=-1) # B, N, 1, (n_order+1)*2
        bezier_control_points_with_gt = bezier_param_with_gt.view(bezier_param_with_gt.shape[0],
                                                          bezier_param_with_gt.shape[1],
                                                          bezier_param_with_gt.shape[2], -1, 2) # B, N, 1, n_order+1, 2
        traj_with_gt = torch.matmul(self.mat_T, bezier_control_points_with_gt) # B,N,1,future_steps,2

        return {"target_probs": target_probs,
                "pred_offsets": pred_offsets,
                "traj_with_gt": traj_with_gt,
                "trajs": trajs,
                "traj_probs": traj_probs
               }  
        
    def _get_T_matrix_bezier(self, n_order, n_step):
        ts = np.linspace(0.0, 1.0, n_step, endpoint=True)
        T = []
        for i in range(n_order + 1):
            coeff = math.factorial(n_order) // (math.factorial(i) * math.factorial(n_order - i)) * (1.0 - ts)**(n_order - i) * ts**i
            # coeff = math.comb(n_order, i) * (1.0 - ts)**(n_order - i) * ts**i
            T.append(coeff)
        return torch.Tensor(np.array(T).T)

    @staticmethod
    def _init_weights(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.MultiheadAttention):
            if m.in_proj_weight is not None:
                fan_in = m.embed_dim
                fan_out = m.embed_dim
                bound = (6.0 / (fan_in + fan_out)) ** 0.5
                nn.init.uniform_(m.in_proj_weight, -bound, bound)
            else:
                nn.init.xavier_uniform_(m.q_proj_weight)
                nn.init.xavier_uniform_(m.k_proj_weight)
                nn.init.xavier_uniform_(m.v_proj_weight)
            if m.in_proj_bias is not None:
                nn.init.zeros_(m.in_proj_bias)
            nn.init.xavier_uniform_(m.out_proj.weight)
            if m.out_proj.bias is not None:
                nn.init.zeros_(m.out_proj.bias)
            if m.bias_k is not None:
                nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
            if m.bias_v is not None:
                nn.init.normal_(m.bias_v, mean=0.0, std=0.02)

In [32]:
def parse_arguments():
    """Arguments for running the baseline.

    Returns:
        parsed arguments
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--agent_input_size', default=4, type=int)
    parser.add_argument('--agent_hidden_size', default=16, type=int)
    parser.add_argument('--map_input_size', default=4, type=int)
    parser.add_argument('--map_hidden_size', default=16, type=int)
    parser.add_argument('--d_model', default=128, type=int)

    parser.add_argument('--rpe_input_size', default=4, type=int)
    parser.add_argument('--rpe_hidden_size', default=32, type=int)
    parser.add_argument('--dropout', default=0.1, type=float)
    
    parser.add_argument('--plan_input_size', default=4, type=int)
    parser.add_argument('--decoder_hidden_size', default=128, type=int)
    parser.add_argument('--bezier_order', default=7, type=int)
    parser.add_argument('--m', default=50, type=int)
    
    parser.add_argument('--device', default="cpu", type=str) 
    parser.add_argument('--update_edge', default=True, type=bool)
    parser.add_argument('--init_weights', default=True, type=bool)
   
    return parser.parse_args([])
if __name__ == "__main__":
    args = parse_arguments()
    model = Model(args)
    agent_feats = torch.randn([2, 2, 5, 4])
    agent_mask =  torch.tensor([[[True,  True,  True, True, False],
                     [ True,  True,  True, True, False]],
                    [[True, True, False,  False, False],
                     [False,  False,  False, False, False]]])
    map_feats = torch.randn([2, 3, 5, 4])
    map_mask =  torch.tensor([[[True,  True,  True, True, False],
             [ True,  True,  True, True, False],
             [False,  False,  False, False, False]],
            [[True, True, False,  False, False],
             [False,  False,  False, False, False],
             [False,  False,  False, False, False]]])
    rpe_feats = torch.randn([2, 5, 5, 4])
    rpe_mask = torch.tensor([[[True,  True,  True, True, False],
             [ True,  True,  True, True, False],
             [ True,  True,  True, True, False],
             [ True,  True,  True, True, False],
             [False,  False,  False, False, False]],
            [[True, True, False,  False, False],
             [True, True, False,  False, False],
             [False,  False,  False, False, False],
             [False,  False,  False, False, False],
             [False,  False,  False, False, False]]])
    tar_candidate = torch.randn(2, 2, 300, 2)
    candidate_mask = torch.randn(2, 2, 300)>0.5
    gt_preds = torch.randn(2, 2, 50, 2)
    
    plan_traj = torch.randn(2, 2, 50, 4)
    plan_traj_mask = torch.randn(2, 2, 50)>0.5
    
    gt_candts = torch.randn(2, 2, 300)
    gt_tar_offset = torch.randn(2, 2, 2)
    
    batch_dict = {'agent_polylines':agent_feats, 'agent_polylines_mask':agent_mask,
                  'map_polylines':map_feats, 'map_polylines_mask':map_mask, 'rpe':rpe_feats,
                  'rpe_mask':rpe_mask, 'tar_candidate': tar_candidate, 'candidate_mask': candidate_mask, 'gt_preds':gt_preds,
                  'plan_traj':plan_traj, 'plan_traj_mask':plan_traj_mask, 'gt_candts':gt_candts, 'gt_tar_offset':gt_tar_offset, 
                  'gt_preds':gt_preds
                 }
    output_dict = model(batch_dict)

In [56]:

class Loss(nn.Module):
    def __init__(self,):
        """
        reduction: loss reduction, "sum" or "mean" (batch mean);
        """
        super(Loss, self).__init__()
        self.lambda1 = 1
        self.lambda2 = 0.1
        self.lambda3 = 1
        self.lambda4 = 0.5

        self.temper = 0.01
        self.d_safe = 3.0
    
    def forward(self, batch_dict, output_dict, epoch=1):
        loss = 0.0
        pred_mask = (batch_dict['candidate_mask'].sum(dim=-1) > 0) # B, N    
        # 1、target_loss
        gt_probs = batch_dict['gt_candts'].float() # B, N, M 

        gt_probs = gt_probs[pred_mask] # S, M
        pred_probs = output_dict['target_probs'][pred_mask]
        pred_num = pred_probs.shape[0]
        cls_loss = F.binary_cross_entropy(pred_probs, gt_probs, reduction='sum')/pred_num
        
        gt_tar_offset = batch_dict["gt_tar_offset"] # B, N ,2
        gt_tar_offset = gt_tar_offset[pred_mask] # S, 2
        gt_idx = gt_probs.nonzero()[:pred_num] 
        print(gt_idx.shape)
        pred_offsets = output_dict['pred_offsets'][pred_mask] # S, M, 2
        pred_offsets = pred_offsets[gt_idx[:, 0], gt_idx[:, 1]] # S, 2
        print(pred_offsets.shape)
        offset_loss = F.smooth_l1_loss(pred_offsets, gt_tar_offset, reduction='sum')/pred_num
        
        # 2、motion reg loss
        traj_with_gt = output_dict['traj_with_gt'].squeeze(2)[pred_mask] # S, 50, 2
        gt_trajs = batch_dict['gt_preds'][pred_mask] # S, 50, 2
        reg_loss = F.smooth_l1_loss(traj_with_gt, gt_trajs, reduction="sum")/pred_num
        
        # 3、score_loss
        pred_trajs = output_dict['trajs'][pred_mask] # S, m, 50, 2
        S, m, horizon, dim = pred_trajs.shape
        pred_trajs = pred_trajs.view(S, m , horizon*dim)
        gt_trajs = gt_trajs.view(S, horizon*dim)
        score_gt = F.softmax(-self.distance_metric(pred_trajs,  gt_trajs)/self.temper, dim=-1).detach()
        score_loss = F.binary_cross_entropy(output_dict['traj_probs'][pred_mask], score_gt, reduction='sum')/pred_num

        
        if epoch > 10:
            pred_trajs_t = pred_trajs.view(S, m, horizon, dim)
            plan_traj = batch_dict["plan_traj"][pred_mask][:, :, :2] # S, 50, 2
            plan_traj_mask = batch_dict['plan_traj_mask'][pred_mask] # S, 50
  
            distances = torch.sqrt(torch.sum((pred_trajs_t - plan_traj.unsqueeze(1))**2, dim=-1)) # S, m, 50
            masked_distances = distances.masked_fill(~plan_traj_mask.unsqueeze(1), 1000) # S, m, 50
            min_distances = torch.min(masked_distances, dim=2)[0] # S, m

            w_min_distances = output_dict['traj_probs'][pred_mask] * min_distances  # S, m
            min_distances_sum = torch.sum(w_min_distances, dim=-1)
            min_distances_sum = torch.clamp(min_distances_sum, max=self.d_safe)
            safety_loss = -torch.mean(min_distances_sum)

            loss = self.lambda1 * (cls_loss + offset_loss) + self.lambda2 * reg_loss + self.lambda3 * score_loss + self.lambda4 * safety_loss
            loss_dict = {"tar_cls_loss": self.lambda1*cls_loss,
                         "tar_offset_loss": self.lambda1*offset_loss,
                         "traj_loss": self.lambda2*reg_loss,
                         "score_loss": self.lambda3*score_loss,
                         "safety_loss": self.lambda4 * safety_loss
                        }
        else:
            loss = self.lambda1 * (cls_loss + offset_loss) + self.lambda2 * reg_loss + self.lambda3 * score_loss
            loss_dict = {"tar_cls_loss": self.lambda1*cls_loss,
                         "tar_offset_loss": self.lambda1*offset_loss,
                         "traj_loss": self.lambda2*reg_loss,
                         "score_loss": self.lambda3*score_loss}
        return loss, loss_dict
    
     
    def distance_metric(self, traj_candidate: torch.Tensor, traj_gt: torch.Tensor):
        """
        compute the distance between the candidate trajectories and gt trajectory
        :param traj_candidate: torch.Tensor, [batch_size, M, horizon * 2] or [M, horizon * 2]
        :param traj_gt: torch.Tensor, [batch_size, horizon * 2] or [1, horizon * 2]
        :return: distance, torch.Tensor, [batch_size, M] or [1, M]
        """
        assert traj_gt.dim() == 2, "Error dimension in ground truth trajectory"
        if traj_candidate.dim() == 3:
            # batch case
            pass

        elif traj_candidate.dim() == 2:
            traj_candidate = traj_candidate.unsqueeze(1)
        else:
            raise NotImplementedError

        assert traj_candidate.size()[2] == traj_gt.size()[1], "Miss match in prediction horizon!"

        _, M, horizon_2_times = traj_candidate.size()
        dis = torch.pow(traj_candidate - traj_gt.unsqueeze(1), 2).view(-1, M, int(horizon_2_times / 2), 2)

        dis, _ = torch.max(torch.sum(dis, dim=3), dim=2)

        return dis

In [57]:
loss = Loss()
loss, loss_dict = loss(batch_dict, output_dict, 11)

torch.Size([4, 2])
torch.Size([4, 2])


In [58]:
loss_dict

{'tar_cls_loss': tensor(-129.0979, grad_fn=<MulBackward0>),
 'tar_offset_loss': tensor(1.0893, grad_fn=<MulBackward0>),
 'traj_loss': tensor(5.4936, grad_fn=<MulBackward0>),
 'score_loss': tensor(4.9789, grad_fn=<MulBackward0>),
 'safety_loss': tensor(-0.1335, grad_fn=<MulBackward0>)}