### Dataset downloading

Before starting the experiments, you need to download the dataset generated by SUMO, which is accessible at [[training set](https://drive.google.com/file/d/1GM3CMnkQcRPQNsgBjqHzDqwB8zLJoNaC/view?usp=share_link)] and [[validation set](https://drive.google.com/file/d/1ar6zcAqJJJCe18O6XlJ6ja9dFKwcgf52/view?usp=share_link)]. Then please run the following command to create a folder named `csv` and move the downloaded files `train_pre.zip` and `val_pre.zip` into this folder.

In [1]:
!mkdir csv

To extract the dataset, you can use the following commands:

In [5]:
!unzip -q csv/train_pre.zip -d csv
!unzip -q csv/val_pre.zip -d csv

### Modules loading

In [25]:
import argparse
import os
import pickle

import numpy as np
import torch
from torch import nn
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GraphConv as GNNConv
from tqdm import tqdm

from dataset import CarDataset

### Parameters setting

In [26]:
batch_size = 200
train_folder = "csv/train_pre"
val_folder = "csv/val_pre"
exp_id = "sumo_test"
model_path = f"trained_params/{exp_id}"
mlp = False
collision_penalty = False
lr = 1e-3
n_epoch = 150

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs(model_path, exist_ok=True)

In [28]:
train_dataset = CarDataset(preprocess_folder=train_folder, mlp=False, mpc_aug=True)
val_dataset = CarDataset(preprocess_folder=val_folder, mlp=False, mpc_aug=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)

### Model definition & instantiation

In [29]:
class GNN_mtl_gnn(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        torch.manual_seed(21)
        self.conv1 = GNNConv(hidden_channels, hidden_channels)
        self.conv2 = GNNConv(hidden_channels, hidden_channels)
        self.linear1 = nn.Linear(6, 64)
        self.linear2 = nn.Linear(64, hidden_channels)
        self.linear3 = nn.Linear(hidden_channels, hidden_channels)
        self.linear4 = nn.Linear(hidden_channels, hidden_channels)
        self.linear5 = nn.Linear(hidden_channels, 30*2)

    def forward(self, x, edge_index):
        x = self.linear1(x).relu()
        x = self.linear2(x).relu()
        x = self.linear3(x).relu() + x
        x = self.linear4(x).relu() + x
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.linear5(x)
        return x

In [30]:
model = GNN_mtl_gnn(hidden_channels=128)
print(model)
model = model.to(device)

GNN_mtl_gnn(
  (conv1): GraphConv(128, 128)
  (conv2): GraphConv(128, 128)
  (linear1): Linear(in_features=6, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=128, bias=True)
  (linear3): Linear(in_features=128, out_features=128, bias=True)
  (linear4): Linear(in_features=128, out_features=128, bias=True)
  (linear5): Linear(in_features=128, out_features=60, bias=True)
)


### Functions definition

In [31]:
def rotation_matrix_back(yaw):
    """
    Rotate back from the local coordinate system to the global coordinate system. 
    """
    rotation = np.array([[np.cos(-np.pi/2+yaw), -np.sin(-np.pi/2+yaw)],[np.sin(-np.pi/2+yaw), np.cos(-np.pi/2+yaw)]])
    rotation = torch.tensor(rotation).float()
    return rotation

In [32]:
def train(model, device, data_loader, optimizer, collision_penalty=False):
    """ Performs an epoch of model training.

    Parameters:
    model (nn.Module): Model to be trained.
    device (torch.Device): Device used for training.
    data_loader (torch.utils.data.DataLoader): Data loader containing all batches.
    optimizer (torch.optim.Optimizer): Optimizer used to update model.
    collision_penalty (bool): set it to True if you want to use collision penalty.

    Returns:
    float: Total loss for epoch.
    """
    model.train()
    total_loss = 0

    dist_threshold = 4

    for batch in data_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x[:,[0, 1, 3, 4, 5, 6]], batch.edge_index)   # [x, y, yaw, 3-bit intention]
        out = out.reshape(-1, 30, 2)
        out = out.permute(0, 2, 1)    # [v, 2, pred]
        yaw = batch.x[:, 3].detach().cpu().numpy()
        rotations = torch.stack([rotation_matrix_back(yaw[i])  for i in range(batch.x.shape[0])]).to(out.device)
        out = torch.bmm(rotations, out).permute(0, 2, 1)       # [v, pred, 2]
        out += batch.x[:,[0, 1]].unsqueeze(1)
        gt = batch.y.reshape(-1, 30, 6)[:,:,[0, 1]]
        error = ((gt-out).square().sum(-1)).sum(-1)
        loss = (batch.weights * error).nanmean()
        
        if collision_penalty:
            mask = (batch.edge_index[0, :] < batch.edge_index[1, :])
            _edge = batch.edge_index[:, mask].T   # [edge',2]
            dist = torch.linalg.norm(out[_edge[:, 0]] - out[_edge[:, 1]], dim=-1)
            dist = dist_threshold - dist[dist < dist_threshold]
            _collision_penalty = dist.square().mean()
            loss += _collision_penalty * 20

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(data_loader)

In [33]:
def evaluate(model, device, data_loader):
    """ Performs the evaluation.

    Parameters:
    model (nn.Module): Model to be trained.
    device (torch.Device): Device used for training.
    data_loader (torch.utils.data.DataLoader): Data loader containing all batches.

    Returns:
    list of evaluation metrics (including ADE, FDE, etc.).
    """

    dist_threshold = 4
    mr_threshold = 4
    model.eval()
    ade, fde = [], []
    n_edge, n_collision = [], []
    val_losses, collision_penalties = [], []
    with torch.no_grad():
        for batch in data_loader:
            batch = batch.to(device)
            out = model(batch.x[:,[0, 1, 3, 4, 5, 6]], batch.edge_index)
            out = out.reshape(-1, 30, 2)
            out = out.permute(0, 2, 1)
            yaw = batch.x[:, 3].detach().cpu().numpy()
            rotations = torch.stack([rotation_matrix_back(yaw[i])  for i in range(batch.x.shape[0])]).to(out.device)
            out = torch.bmm(rotations, out).permute(0, 2, 1)       # [v, pred, 2]
            out += batch.x[:,[0, 1]].unsqueeze(1)
            
            gt = batch.y.reshape(-1, 30, 6)[:, :, [0, 1]]
            _error = (gt-out).square().sum(-1)
            error = _error.clone() ** 0.5
            _error = _error.sum(-1)
            val_loss = (batch.weights * _error).nanmean()
            val_losses.append(val_loss)
            fde.append(error[:,-1])
            ade.append(error.mean(dim=-1))

            mask = (batch.edge_index[0,:] < batch.edge_index[1,:])
            _edge = batch.edge_index[:, mask].T   # [edge',2]
            dist = torch.linalg.norm(out[_edge[:,0]] - out[_edge[:,1]], dim=-1)
            collision_penalty = dist_threshold - dist[dist < dist_threshold]
            collision_penalty = collision_penalty.square().mean() * 20
            collision_penalties.append(collision_penalty)

            dist = torch.min(dist, dim=-1)[0]
            n_edge.append(len(dist))
            n_collision.append((dist < 2).sum().item())
    
    ade = torch.cat(ade).mean()
    fde = torch.cat(fde)
    mr = ((fde > mr_threshold).sum() / len(fde)).item()
    fde = fde.mean()
    collision_rate = sum(n_collision) / sum(n_edge)
    collision_penalties = torch.tensor(collision_penalties).mean()
    val_losses = torch.tensor(val_losses).mean()
    
    return ade.item(), fde.item(), mr, collision_rate, val_losses.item(), collision_penalties.item()

### Training & evaluation

In [34]:
min_ade = 1e6
min_fde = 1e6
best_epoch = 0
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
record = []

for epoch in tqdm(range(0, n_epoch)):
    loss = train(model, device, train_loader, optimizer)
    if epoch % 10 == 0:
        ade, fde, mr, collision_rate, val_losses, collision_penalties = evaluate(model, device, val_loader)
        record.append([ade, fde, mr, collision_rate, val_losses, collision_penalties])
        print(f"Epoch {epoch}: Train Loss: {loss}, ADE: {ade}, FDE: {fde}, MR: {mr}, CR:{collision_rate}, \
            Val_loss: {val_losses}, CP: {collision_penalties}, lr: {optimizer.param_groups[0]['lr']}.")
        torch.save(model.state_dict(), model_path + \
            f"/model_{'mlp' if mlp else 'gnn'}_{'wp' if collision_penalty else 'np'}_{exp_id}_e3_{str(epoch).zfill(4)}.pth")
        if fde < min_fde:
            min_ade, min_fde = ade, fde
            best_epoch = epoch
            print("New smallest FDE!!")

pkl_file = f"model_gnn_{'wp' if collision_penalty else 'np'}_{exp_id}.pkl"
with open(f'{model_path}/{pkl_file}', 'wb') as handle:
    pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL)

  1%|          | 1/150 [00:01<03:44,  1.50s/it]

Epoch 0: Train Loss: 204191.59598214287, ADE: 17.99057388305664, FDE: 25.204317092895508, MR: 0.9461342096328735, CR:0.033681462140992165,             Val_loss: 210975.859375, CP: nan, lr: 0.001.
New smallest FDE!!


  7%|▋         | 11/150 [00:09<01:51,  1.25it/s]

Epoch 10: Train Loss: 14771.683837890625, ADE: 4.887154579162598, FDE: 8.52444839477539, MR: 0.7570093274116516, CR:0.4651436031331593,             Val_loss: 19282.283203125, CP: 93.7400894165039, lr: 0.001.
New smallest FDE!!


 14%|█▍        | 21/150 [00:16<01:38,  1.31it/s]

Epoch 20: Train Loss: 5518.316162109375, ADE: 3.237579584121704, FDE: 5.9307074546813965, MR: 0.5901443958282471, CR:0.574934725848564,             Val_loss: 9008.125, CP: 102.0328369140625, lr: 0.001.
New smallest FDE!!


 21%|██        | 31/150 [00:24<01:37,  1.22it/s]

Epoch 30: Train Loss: 3577.1335013253347, ADE: 2.703300952911377, FDE: 4.948358058929443, MR: 0.46745961904525757, CR:0.6518276762402089,             Val_loss: 6970.52294921875, CP: 114.27143096923828, lr: 0.001.
New smallest FDE!!


 27%|██▋       | 41/150 [00:30<01:22,  1.32it/s]

Epoch 40: Train Loss: 2800.9671456473216, ADE: 2.601576805114746, FDE: 4.953203201293945, MR: 0.4700084924697876, CR:0.7093994778067885,             Val_loss: 6769.8330078125, CP: 120.59954071044922, lr: 0.001.


 34%|███▍      | 51/150 [00:38<01:18,  1.26it/s]

Epoch 50: Train Loss: 2409.840105329241, ADE: 2.693805694580078, FDE: 5.066558361053467, MR: 0.46661001443862915, CR:0.7515665796344647,             Val_loss: 7231.583984375, CP: 115.84384155273438, lr: 0.001.


 41%|████      | 61/150 [00:45<01:10,  1.27it/s]

Epoch 60: Train Loss: 1868.4940708705358, ADE: 2.571370840072632, FDE: 4.882014274597168, MR: 0.43857261538505554, CR:0.7580939947780679,             Val_loss: 6846.5830078125, CP: 119.41574096679688, lr: 0.001.
New smallest FDE!!


 47%|████▋     | 71/150 [00:53<01:01,  1.28it/s]

Epoch 70: Train Loss: 1610.8249773297991, ADE: 2.5238358974456787, FDE: 4.843796253204346, MR: 0.44621917605400085, CR:0.7941253263707572,             Val_loss: 7047.064453125, CP: 128.0289764404297, lr: 0.001.
New smallest FDE!!


 54%|█████▍    | 81/150 [01:01<00:58,  1.17it/s]

Epoch 80: Train Loss: 1448.5355224609375, ADE: 2.4165570735931396, FDE: 4.650563716888428, MR: 0.4088360071182251, CR:0.7787206266318538,             Val_loss: 6962.2314453125, CP: 129.2513427734375, lr: 0.001.
New smallest FDE!!


 61%|██████    | 91/150 [01:08<00:50,  1.18it/s]

Epoch 90: Train Loss: 1310.5351867675781, ADE: 2.413524866104126, FDE: 4.625000476837158, MR: 0.4056074619293213, CR:0.8039164490861619,             Val_loss: 7094.41259765625, CP: 128.8695068359375, lr: 0.001.
New smallest FDE!!


 67%|██████▋   | 101/150 [01:16<00:38,  1.27it/s]

Epoch 100: Train Loss: 1067.4303131103516, ADE: 2.3872785568237305, FDE: 4.655385494232178, MR: 0.4093457758426666, CR:0.7913838120104438,             Val_loss: 7174.46044921875, CP: 134.4886474609375, lr: 0.001.


 74%|███████▍  | 111/150 [01:24<00:32,  1.20it/s]

Epoch 110: Train Loss: 1110.0423496791295, ADE: 2.421121835708618, FDE: 4.702868461608887, MR: 0.4280373752117157, CR:0.8006527415143603,             Val_loss: 7294.501953125, CP: 130.607421875, lr: 0.001.


 81%|████████  | 121/150 [01:32<00:22,  1.27it/s]

Epoch 120: Train Loss: 910.7719116210938, ADE: 2.330911159515381, FDE: 4.541915416717529, MR: 0.4079864025115967, CR:0.8169712793733681,             Val_loss: 6962.04736328125, CP: 136.26275634765625, lr: 0.001.
New smallest FDE!!


 87%|████████▋ | 131/150 [01:39<00:15,  1.26it/s]

Epoch 130: Train Loss: 848.9992937360491, ADE: 2.3191120624542236, FDE: 4.5340576171875, MR: 0.3972811996936798, CR:0.8325065274151436,             Val_loss: 6909.578125, CP: 133.60231018066406, lr: 0.001.
New smallest FDE!!


 94%|█████████▍| 141/150 [01:47<00:07,  1.27it/s]

Epoch 140: Train Loss: 769.4565734863281, ADE: 2.387826919555664, FDE: 4.717695713043213, MR: 0.4276975095272064, CR:0.8134464751958225,             Val_loss: 7048.095703125, CP: 135.67608642578125, lr: 0.001.


100%|██████████| 150/150 [01:54<00:00,  1.32it/s]
