## Input

In [71]:
import os
import csv
import torch
import argparse
import numpy as np
from tqdm import tqdm
from torch import nn, optim
from GameFormer.predictor import GameFormer
from torch.utils.data import DataLoader
from GameFormer.train_utils import *

num_neighbors = 20
batch_size = 32
# set up data loaders
train_path = '/data/fyy/GameFormer-Planner/nuplan/processed_data/train'
train_files = [f for d in os.listdir(train_path) for f in glob.glob(os.path.join(train_path, d, "*.npz"))]
train_set = DrivingData(train_files, num_neighbors)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count())

In [72]:
count = 0
with tqdm(train_loader, desc="Training", unit="batch") as data_epoch:
    for batch in data_epoch:
        count += 1
        if count == 2:
            break
        # prepare data
        inputs = {
            'ego_agent_past': batch[0].to('cuda'),
            'neighbor_agents_past': batch[1].to('cuda'),
            'map_lanes': batch[2].to('cuda'),
            'map_crosswalks': batch[3].to('cuda'),
            'route_lanes': batch[4].to('cuda')
        }

        ego_future = batch[5].to('cuda')
        neighbors_future = batch[6].to('cuda')
        neighbors_future_valid = torch.ne(neighbors_future[..., :2], 0)
        
        break

Training:   0%|          | 0/1690 [00:02<?, ?batch/s]


## Encoder

### agent encoding

In [73]:
from GameFormer.predictor_modules import *

# shape = (32, T, 7) = (x, y, heading, vx, vy, ax, ay)
ego = inputs['ego_agent_past']   
# shape = (32, N, T, 11) = (x, y, heading, vx, vy, yaw, length, width, 1, 0, 0)   
neighbors = inputs['neighbor_agents_past']    
# shape = (32, 1+N, T, 5) = (x, y, heading, vx, vy)
actors = torch.cat([ego[:, None, :, :5], neighbors[..., :5]], dim=1)

ego_encoder = AgentEncoder(agent_dim=7).cuda()  # LSTM
# shape = (32, 256)
encoded_ego = ego_encoder(ego)

agent_encoder = AgentEncoder(agent_dim=11).cuda()  # LSTM
# shape = (N, 32, 256)
encoded_neighbors = [agent_encoder(neighbors[:, i]) for i in range(neighbors.shape[1])]

# shape = (32, N+1, 256)
encoded_actors = torch.stack([encoded_ego] + encoded_neighbors, dim=1)  
actors_mask = torch.eq(actors[:, :, -1].sum(-1), 0)

### map encoding

In [74]:
_lane_len = 50
_lane_feature = 7
_crosswalk_len = 30
_crosswalk_feature = 3

map_lanes = inputs['map_lanes']    # shape = (32, 40, 50, 7)
map_crosswalks = inputs['map_crosswalks']   # shape = (32, 5, 30, 3)

lane_encoder = VectorMapEncoder(_lane_feature, _lane_len).cuda()
# shape = (32, 200, 256)
encoded_map_lanes, lanes_mask = lane_encoder(map_lanes)

crosswalk_encoder = VectorMapEncoder(_crosswalk_feature, _crosswalk_len).cuda()
# shape = (32, 15, 256)
encoded_map_crosswalks, crosswalks_mask = crosswalk_encoder(map_crosswalks)

### attention fusion encoding

In [75]:
# shape = (32, 236, 256)
input = torch.cat([encoded_actors, encoded_map_lanes, encoded_map_crosswalks], dim=1)
mask = torch.cat([actors_mask, lanes_mask, crosswalks_mask], dim=1)

dim, layers, heads, dropout = 256, 6, 8, 0.1
attention_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4,
                                                activation='gelu', dropout=dropout, batch_first=True)
fusion_encoder = nn.TransformerEncoder(attention_layer, layers).cuda()
encoding = fusion_encoder(input, src_key_padding_mask=mask)

In [76]:
encoder_outputs = {
    'actors': actors,
    'encoding': encoding,
    'mask': mask,
    'route_lanes': inputs['route_lanes']
}
# torch.Size([32, 21, 21, 5]) torch.Size([32, 236, 256]) torch.Size([32, 236]) torch.Size([32, 10, 50, 3])
# torch.Size([32, 21, 21, 5]) torch.Size([32, 236, 256]) torch.Size([32, 236]) torch.Size([32, 10, 50, 3])

In [77]:
print(actors.shape, encoding.shape, mask.shape, inputs['route_lanes'].shape)

torch.Size([32, 21, 21, 5]) torch.Size([32, 236, 256]) torch.Size([32, 236]) torch.Size([32, 10, 50, 3])


## Decoder

In [78]:
decoder_outputs = {}

# shape = (32, 1+N, 5)
current_states = encoder_outputs['actors'][:, :, -1]
# shape = (32, 200+15+21, 256)
encoding, mask = encoder_outputs['encoding'], encoder_outputs['mask']

In [79]:
neighbors = 10
modalities = 6
levels = 3

class GMMPredictor(nn.Module):
    def __init__(self, modalities=6):
        super(GMMPredictor, self).__init__()
        self.modalities = modalities
        self._future_len = 80
        self.gaussian = nn.Sequential(nn.Linear(256, 512), nn.ELU(), nn.Dropout(0.1), nn.Linear(512, self._future_len*4))
        self.score = nn.Sequential(nn.Linear(256, 64), nn.ELU(), nn.Linear(64, 1))
    
    def forward(self, input):
        B, N, M, _ = input.shape
        traj = self.gaussian(input).view(B, N, M, self._future_len, 4) # mu_x, mu_y, log_sig_x, log_sig_y
        score = self.score(input).squeeze(-1)

        return traj, score


class InitialPredictionDecoder(nn.Module):
    def __init__(self, modalities, neighbors, dim=256):
        super(InitialPredictionDecoder, self).__init__()
        self._modalities = modalities
        self._agents = neighbors + 1
        self.multi_modal_query_embedding = nn.Embedding(modalities, dim)
        self.agent_query_embedding = nn.Embedding(self._agents, dim)
        self.query_encoder = CrossTransformer()
        self.predictor = GMMPredictor()
        self.register_buffer('modal', torch.arange(modalities).long())
        self.register_buffer('agent', torch.arange(self._agents).long())

    def forward(self, current_states, encoding, mask):
        N = self._agents   # N = 1 + 10 = 11
        multi_modal_query = self.multi_modal_query_embedding(self.modal)   # 可学习的嵌入向量
        # self.modal.shape = (6)
        # multi_modal_query.shape = (6, 256)
        agent_query = self.agent_query_embedding(self.agent)
        # self.agent.shape = (N)
        # agent_query.shape = (N, 256)
        query = encoding[:, :N, None] + multi_modal_query[None, :, :] + agent_query[:, None, :]
        query_content = torch.stack([self.query_encoder(query[:, i], encoding, encoding, mask) for i in range(N)], dim=1)
        # query.shape = (32, N, M, 256)
        # query_content.shape = (32, N, M, 256)
        predictions, scores = self.predictor(query_content)
        # predictions.shape = (32, N, M, 4 * T, 4), mu_x, mu_y, log_sig_x, log_sig_y
        # scores.shape = (32, N, M)
        predictions[..., :2] += current_states[:, :N, None, None, :2]

        return query_content, predictions, scores
    
initial_predictor = InitialPredictionDecoder(modalities, neighbors).cuda()
last_content, last_level, last_score = initial_predictor(current_states, encoding, mask)
content_list = [last_content]


In [80]:
# level 0 decode
last_content, last_level, last_score = initial_predictor(current_states, encoding, mask)
decoder_outputs['level_0_interactions'] = last_level
decoder_outputs['level_0_scores'] = last_score

In [81]:
last_content[:, 0].shape

torch.Size([32, 6, 256])

In [82]:
future_encoder = FutureEncoder().cuda()
# 把这个网络复制2份
interaction_stage = nn.ModuleList([InteractionDecoder(modalities, future_encoder).cuda() for _ in range(levels)])

In [83]:

# level k reasoning
for k in range(1, 3):
    interaction_decoder = interaction_stage[k-1]
    last_content, last_level, last_score = interaction_decoder(current_states, last_level, last_score, last_content, encoding, mask)
    decoder_outputs[f'level_{k}_interactions'] = last_level
    decoder_outputs[f'level_{k}_scores'] = last_score
    content_list.append(last_content)


In [95]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentFeatureExtractor(nn.Module):
    def __init__(self, input_dim, out_dim):
        super(ContentFeatureExtractor, self).__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.fc2 = nn.Linear(512, out_dim)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x

class FeatureFusionAttention(nn.Module):
    def __init__(self, feature_size, num_features):
        super(FeatureFusionAttention, self).__init__()
        self.query = nn.Linear(feature_size, feature_size)
        self.key = nn.Linear(feature_size, feature_size)
        self.value = nn.Linear(feature_size, feature_size)
        self.num_features = num_features

    def forward(self, x):
        # x 的形状是 [batch_size, num_tensors, seq_len, num_features, feature_size]
        batch_size, num_tensors, seq_len, num_features, feature_size = x.shape
        
        # 将x变形为 [batch_size * seq_len * num_features, num_tensors, feature_size]
        x_reshaped = x.view(-1, num_tensors, feature_size)
        
        # 计算查询（Q）、键（K）和值（V）
        Q = self.query(x_reshaped)
        K = self.key(x_reshaped)
        V = self.value(x_reshaped)
        # 计算注意力分数
        attention_scores = Q @ K.transpose(-2, -1) / (feature_size ** 0.5)
        
        # 应用softmax函数获取注意力权重
        attention_weights = F.softmax(attention_scores, dim=-1)
        
        # 应用注意力权重到V（值）
        attended = attention_weights @ V
        # 将结果变形回 [batch_size, seq_len, num_features, num_tensors, feature_size]
        attended = attended.view(batch_size, seq_len, num_features, num_tensors, feature_size)
        print(attended.shape)
        # 在num_tensors维度上求和或者平均
        fused_features = attended.sum(dim=3)  # 使用 mean(dim=3) 来取平均
        print(fused_features.shape)
        return fused_features
    
content_feature_extractor = ContentFeatureExtractor(256, 256).cuda()
feature_fusion_attention = FeatureFusionAttention(feature_size=256, num_features=6).cuda()

content_dim = []
for input in content_list:
    batch_size, num_v, modalities, feature_dim = input.shape
    out = content_feature_extractor(input.view(-1, feature_dim)).view(batch_size, num_v, modalities, -1)
    content_dim.append(out)
    
combined_feature_tensors = torch.stack((content_dim[0], content_dim[1], content_dim[2]), dim=1) 
attention_output = feature_fusion_attention(combined_feature_tensors)

torch.Size([32, 11, 6, 3, 256])
torch.Size([32, 11, 6, 256])


In [85]:
(content_dim[0] + content_dim[1] + content_dim[2]).shape

torch.Size([32, 11, 6, 256])

In [86]:
input.shape

torch.Size([32, 11, 6, 256])

In [87]:
combined_feature_tensors.shape

torch.Size([32, 3, 11, 6, 256])

In [88]:
class AttPredictionDecoder(nn.Module):
    def __init__(self, neighbors):
        super(AttPredictionDecoder, self).__init__()
        self._agents = neighbors + 1
        self.predictor = GMMPredictor()

    def forward(self, current_states, att_content):
        N = self._agents
        predictions, scores = self.predictor(att_content)
        predictions[..., :2] += current_states[:, :N, None, None, :2]
        return att_content, predictions, scores

att_predictor = AttPredictionDecoder(neighbors).cuda()
att_content, predictions, scores = att_predictor(current_states, attention_output)

In [89]:
att_content[:, 0].shape

torch.Size([32, 6, 256])

In [90]:
predictions.shape

torch.Size([32, 11, 6, 80, 4])