## Input

In [1]:
import argparse
import csv
import glob
import os

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from GameFormer.predictor import GameFormer
from GameFormer.train_utils import *

num_neighbors = 20
batch_size = 32
# Set up data loaders. Samples are expected one directory level below train_path:
# <train_path>/<subdir>/*.npz. Sorting makes the file order deterministic, since
# neither os.listdir nor glob guarantees an order.
train_path = 'nuplan/processed_data/train'
train_files = sorted(
    f
    for d in os.listdir(train_path)
    for f in glob.glob(os.path.join(train_path, d, "*.npz"))
)
train_set = DrivingData(train_files, num_neighbors)
# os.cpu_count() returns None when the CPU count cannot be determined; DataLoader
# cannot take None, so fall back to 0 (load in the main process).
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count() or 0)

In [2]:
# Pull a single batch to drive the step-by-step walkthrough below.
# next(iter(...)) replaces the old "tqdm loop + immediate break" pattern, which
# only left a stalled 0/N progress bar behind.
device = 'cuda'
batch = next(iter(train_loader))

# The first five collated fields are the model inputs, in this order.
input_keys = ('ego_agent_past', 'neighbor_agents_past', 'map_lanes', 'map_crosswalks', 'route_lanes')
inputs = {key: batch[i].to(device) for i, key in enumerate(input_keys)}

# Fields 5 and 6 are the ego and neighbor ground-truth futures.
ego_future = batch[5].to(device)
neighbors_future = batch[6].to(device)
# Elementwise mask of the non-zero entries in the first two future features
# (presumably x, y); zeros are treated as padding.
neighbors_future_valid = torch.ne(neighbors_future[..., :2], 0)

Training:   0%|          | 0/1690 [00:26<?, ?batch/s]


## Encoder

### agent encoding

In [3]:
from GameFormer.predictor_modules import *

# Ego history, encoded below with agent_dim=7. Author's note (unverified here):
# shape (32, T, 7), features (x, y, heading, vx, vy, ax, ay).
ego = inputs['ego_agent_past']
# Neighbor histories, one track per index along dim 1, encoded with agent_dim=11.
# Author's note (unverified here): shape (32, N, T, 11).
neighbors = inputs['neighbor_agents_past']

# Separate history encoders for ego and neighbors (LSTM-based per the original notes).
# Creation order is kept (ego first) so parameter initialisation draws the same RNG.
ego_encoder = AgentEncoder(agent_dim=7).cuda()
encoded_ego = ego_encoder(ego)

agent_encoder = AgentEncoder(agent_dim=11).cuda()
# One encoding per neighbor track; unbind(1) yields the same slices as neighbors[:, i].
encoded_neighbors = [agent_encoder(track) for track in neighbors.unbind(dim=1)]

# Ego first, then neighbors, stacked on a new actor axis (dim 1).
encoded_actors = torch.stack([encoded_ego, *encoded_neighbors], dim=1)

# First five features of every actor, ego first, concatenated on the actor axis.
actors = torch.cat([ego[:, None, :, :5], neighbors[..., :5]], dim=1)
# An actor counts as padding when its last-time-step features (of those five) sum to 0.
actors_mask = actors[:, :, -1].sum(-1).eq(0)

### map encoding

In [4]:
# Polyline sizes handed to VectorMapEncoder as (feature count, points per polyline).
_lane_len, _lane_feature = 50, 7
_crosswalk_len, _crosswalk_feature = 30, 3

# Author's shape notes (unverified here): lanes (32, 40, 50, 7), crosswalks (32, 5, 30, 3).
map_lanes, map_crosswalks = inputs['map_lanes'], inputs['map_crosswalks']

# Each encoder returns (token encodings, padding mask); the masks are later used as
# src_key_padding_mask in the fusion encoder.
lane_encoder = VectorMapEncoder(_lane_feature, _lane_len).cuda()
encoded_map_lanes, lanes_mask = lane_encoder(map_lanes)   # author's note: (32, 200, 256)

crosswalk_encoder = VectorMapEncoder(_crosswalk_feature, _crosswalk_len).cuda()
encoded_map_crosswalks, crosswalks_mask = crosswalk_encoder(map_crosswalks)   # author's note: (32, 15, 256)

### attention fusion encoding

In [5]:
# Concatenate actor, lane and crosswalk tokens into one sequence (dim 1) and run
# self-attention over it. Author's note: (32, 236, 256). The tensor is named
# `fusion_input` rather than `input` so the notebook namespace no longer shadows
# the builtin input().
fusion_input = torch.cat([encoded_actors, encoded_map_lanes, encoded_map_crosswalks], dim=1)
mask = torch.cat([actors_mask, lanes_mask, crosswalks_mask], dim=1)
# One padding-mask entry per token; catch a length mismatch here rather than
# deep inside the transformer.
assert fusion_input.shape[:2] == mask.shape, (fusion_input.shape, mask.shape)

dim, layers, heads, dropout = 256, 6, 8, 0.1
attention_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4,
                                             activation='gelu', dropout=dropout, batch_first=True)
# TransformerEncoder stacks `layers` clones of attention_layer.
fusion_encoder = nn.TransformerEncoder(attention_layer, layers).cuda()
encoding = fusion_encoder(fusion_input, src_key_padding_mask=mask)

In [6]:
# Bundle the encoder results under the keys the decoder cells read.
encoder_outputs = dict(
    actors=actors,
    encoding=encoding,
    mask=mask,
    route_lanes=inputs['route_lanes'],
)

## Decoder

In [7]:
decoder_outputs = dict()

# State of every actor at the last observed time step (index -1 on the time axis).
current_states = encoder_outputs['actors'][:, :, -1]
# Fused scene tokens (actors, then lanes, then crosswalks) and their padding mask.
encoding = encoder_outputs['encoding']
mask = encoder_outputs['mask']

In [8]:
# Decoder hyper-parameters: neighbor agents decoded, prediction modes, and the number
# of InteractionDecoder levels built for level-k reasoning.
# NOTE(review): `neighbors` rebinds the name that held the neighbor-history tensor in
# the agent-encoding cell; from here on it is an int.
neighbors, modalities, levels = 10, 6, 3

class GMMPredictor(nn.Module):
    """Gaussian trajectory head plus a scalar score for every query embedding.

    Args:
        modalities: number of prediction modes; stored as ``self.modalities`` but not
            used by ``forward`` (the mode axis comes from the input's shape).
        dim: feature size of the input embeddings (previously hard-coded to 256).
        future_len: number of predicted future steps (previously hard-coded to 80).
    """

    def __init__(self, modalities=6, dim=256, future_len=80):
        super(GMMPredictor, self).__init__()
        self.modalities = modalities
        self._future_len = future_len
        # Each future step predicts 4 values: mu_x, mu_y, log_sig_x, log_sig_y.
        self.gaussian = nn.Sequential(nn.Linear(dim, 512), nn.ELU(), nn.Dropout(0.1), nn.Linear(512, self._future_len*4))
        self.score = nn.Sequential(nn.Linear(dim, 64), nn.ELU(), nn.Linear(64, 1))

    def forward(self, input):
        """Predict trajectories and scores.

        Args:
            input: embeddings of shape (..., dim), e.g. (B, N, M, dim).

        Returns:
            traj: (..., future_len, 4) holding mu_x, mu_y, log_sig_x, log_sig_y.
            score: (...), one scalar per embedding.
        """
        # Keep all leading axes, so any batch layout works, not only 4-D (B, N, M, dim).
        leading = input.shape[:-1]
        traj = self.gaussian(input).view(*leading, self._future_len, 4)
        score = self.score(input).squeeze(-1)

        return traj, score


class InitialPredictionDecoder(nn.Module):
    """Level-0 decoder producing multi-modal predictions for the ego and its neighbors.

    One query is formed per (agent, mode) pair by adding a learnable mode embedding
    and a learnable agent embedding to that agent's token from ``encoding``. Each
    agent's queries are passed through ``self.query_encoder`` with the full encoding
    as key/value, then through ``self.predictor``.

    Args:
        modalities: number of prediction modes M.
        neighbors: number of neighbor agents decoded; the ego is added, so N = neighbors + 1.
        dim: size of the query embeddings.
    """

    def __init__(self, modalities, neighbors, dim=256):
        super(InitialPredictionDecoder, self).__init__()
        self._modalities = modalities
        self._agents = neighbors + 1
        # Learnable embeddings that make each query mode- and agent-specific.
        self.multi_modal_query_embedding = nn.Embedding(modalities, dim)
        self.agent_query_embedding = nn.Embedding(self._agents, dim)
        self.query_encoder = CrossTransformer()
        self.predictor = GMMPredictor()
        # Index tensors fed to the embeddings; registered as buffers so they follow .cuda()/.to().
        self.register_buffer('modal', torch.arange(modalities).long())
        self.register_buffer('agent', torch.arange(self._agents).long())

    def forward(self, current_states, encoding, mask):
        """Decode level-0 predictions.

        Args:
            current_states: per-actor current state; only [:, :N, :2] is used, as an
                offset added to the first two prediction channels (mu_x, mu_y).
            encoding: fused scene tokens, assumed (B, L, dim) with the N decoded
                agents as the first N tokens.
            mask: padding mask for ``encoding``, forwarded to the query encoder.

        Returns:
            (query_content, predictions, scores). query_content is (B, N, M, dim),
            assuming the query encoder preserves its query's shape. predictions and
            scores come from ``self.predictor``; with the visible GMMPredictor they are
            (B, N, M, T, 4) and (B, N, M).
        """
        num_agents = self._agents
        mode_embed = self.multi_modal_query_embedding(self.modal)    # (M, dim)
        agent_embed = self.agent_query_embedding(self.agent)         # (N, dim)

        # Broadcast (B, N, 1, dim) + (1, M, dim) + (N, 1, dim) -> (B, N, M, dim).
        agent_tokens = encoding[:, :num_agents].unsqueeze(2)
        query = agent_tokens + mode_embed.unsqueeze(0) + agent_embed.unsqueeze(1)

        # Decode each agent's M mode queries separately, then restack on the agent axis.
        per_agent = []
        for idx in range(num_agents):
            per_agent.append(self.query_encoder(query[:, idx], encoding, encoding, mask))
        query_content = torch.stack(per_agent, dim=1)

        predictions, scores = self.predictor(query_content)
        # Shift the predicted means (mu_x, mu_y) by each agent's current state[..., :2].
        predictions[..., :2] += current_states[:, :num_agents, None, None, :2]

        return query_content, predictions, scores
    
# Build the level-0 decoder. Its forward pass runs in the "level 0 decode" cell;
# running it here too duplicated the work, and because the modules are in training
# mode (dropout active), this run produced different outputs that were then
# overwritten.
initial_predictor = InitialPredictionDecoder(modalities, neighbors).cuda()


In [9]:
# Level-0 decode: its content, predictions and scores seed the level-k loop.
last_content, last_level, last_score = initial_predictor(current_states, encoding, mask)
decoder_outputs.update(level_0_interactions=last_level, level_0_scores=last_score)

In [10]:
# One InteractionDecoder per reasoning level (`levels` of them). Every level receives
# the same FutureEncoder instance rather than its own copy.
future_encoder = FutureEncoder().cuda()
interaction_stage = nn.ModuleList(
    InteractionDecoder(modalities, future_encoder).cuda() for _ in range(levels)
)

In [11]:
# Level-k reasoning: each decoder consumes the previous level's content, predictions
# and scores. Iterate over every decoder built in interaction_stage (levels 1..levels)
# instead of the hard-coded range(1, 3), which silently skipped the last decoder.
for k, interaction_decoder in enumerate(interaction_stage, start=1):
    last_content, last_level, last_score = interaction_decoder(current_states, last_level, last_score, last_content, encoding, mask)
    decoder_outputs[f'level_{k}_interactions'] = last_level
    decoder_outputs[f'level_{k}_scores'] = last_score


In [12]:
# Row 0 on the agent axis (the ego, which was stacked first) of the final level's content.
env_encoding = last_content.select(1, 0)
env_encoding.shape

torch.Size([32, 6, 256])

loss