# CSE251B Project Milestone Starter File

## Step 1: Import Dependencies:

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm

## Step 2: Load the Dataset

#### You need to describe in your own words what the dataset is about, and use mathematical language and formulate your prediction task on the submitted PDF file for Question 1 Problem A.

#### Here we are loading the dataset from the local directory. And answer Question 1 Problem B

In [2]:
train_npz = np.load('cse-251-b-2025/train.npz')
train_data = train_npz['data']
test_npz  = np.load('cse-251-b-2025/test_input.npz')
test_data  = test_npz['data']

In [3]:
X_train = train_data[..., :50, :]

aug_X = []
max_swaps_per_scene = 8  #
for scene in train_data:
    vehicle_indices = []

    # Find vehicle agents (object_type == 0)
    for i in range(1, 50):  # skip agent 0
        agent = scene[i]
        if np.all(agent == 0):
            continue
        if agent[0, 5] == 0:  # 'vehicle'
            vehicle_indices.append(i)

    # Randomly select up to max_swaps_per_scene
    selected = np.random.choice(vehicle_indices, size=min(max_swaps_per_scene, len(vehicle_indices)), replace=False)

    for i in selected:
        swapped_scene = scene.copy()
        swapped_scene[[0, i]] = swapped_scene[[i, 0]]
        aug_X.append(swapped_scene[:, :50, :])  # (50, 50, 6)

# Stack and combine
aug_X = np.stack(aug_X)                          # (N_aug, 50, 50, 6)
X_train_aug = np.concatenate([X_train, aug_X], axis=0)

print("Original X_train shape:", X_train.shape)
print("Augmented X_train shape:", X_train_aug.shape)

X_train = X_train_aug
Y_train = train_data[:, 0, 50:, :2]

Original X_train shape: (10000, 50, 50, 6)
Augmented X_train shape: (89784, 50, 50, 6)


In [4]:
def plot_heatmap(data, title=None, bins=5):
    plt.figure(figsize=(6, 6))

    x_max = data[..., 0].max()
    x_min = data[..., 0].min()
    y_max = data[..., 1].max()
    y_min = data[..., 1].min()

    plt.hist2d(data[:, 0], data[:, 1], bins=bins, cmap='hot')
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.title(title)
    plt.colorbar(label='Density')
    plt.xlabel('X-axis')
    plt.ylabel('Y-axis')
    plt.show()

In [5]:
xy_in = train_data[:, :, :50, :2].reshape(-1, 2)
# only find the x, y != 0
xy_in_not_0 = xy_in[(xy_in[:, 0] != 0) & (xy_in[:, 1] != 0)]

In [8]:
plot_heatmap(xy_in, title='Heatmap of XY In', bins=5)
plot_heatmap(xy_in_not_0, title='Heatmap of XY In (non-zero)', bins=5)

NameError: name 'plot_heatmap' is not defined

In [None]:
plot_heatmap(xy_in, title='Heatmap of XY In', bins=50)
plot_heatmap(xy_in_not_0, title='Heatmap of XY In (non-zero)', bins=50)

#### Try to play around with dataset for training and testing, make exploratory analysis on the dataset for bonus points(up to 2)

## Step 3: Setting up the Training and Testing

### Example Code:

In [4]:
class TrajectoryDatasetTrain(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        """
        data: Shape (N, 50, 110, 6) Training data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        augment: Whether to apply data augmentation (only for training)
        """
        self.data = data
        self.scale = scale
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        # Getting 50 historical timestamps and 60 future timestamps
        hist = scene[:, :50, :].copy()    # (agents=50, time_seq=50, 6)
        future = torch.tensor(scene[0, 50:, :2].copy(), dtype=torch.float32)  # (60, 2)
        
        # Data augmentation(only for training)
        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                # Rotate the historical trajectory and future trajectory
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future = future @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1

        # Use the last timeframe of the historical trajectory as the origin
        origin = hist[0, 49, :2].copy()  # (2,)
        hist[..., :2] = hist[..., :2] - origin
        future = future - origin

        # Normalize the historical trajectory and future trajectory
        hist[..., :4] = hist[..., :4] / self.scale
        future = future / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            y=future.type(torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )

        return data_item
    

class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        """
        data: Shape (N, 50, 110, 6) Testing data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        """
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Testing data only contains historical trajectory
        scene = self.data[idx]  # (50, 50, 6)
        hist = scene.copy()
        
        origin = hist[0, 49, :2].copy()
        hist[..., :2] = hist[..., :2] - origin
        hist[..., :4] = hist[..., :4] / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        return data_item

In [5]:
torch.manual_seed(251)
np.random.seed(42)

scale = 7.0

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using CUDA GPU


# Use Transformer to Reconstruct

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader

In [27]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (B, T, D)
        return x + self.pe[:, :x.size(1), :]

class TransformerModel(nn.Module):
    def __init__(self, input_dim=6, d_model=512, nhead=16, num_layers=4, 
                 output_dim=60 * 2, seq_len=50):
        super(TransformerModel, self).__init__()
        self.seq_len = seq_len
        
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=seq_len)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=d_model*4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.predict_head = nn.Sequential(
            nn.Linear(d_model * seq_len, d_model*2),
            nn.ReLU(),
            nn.Linear(d_model*2, output_dim)
        )

    def forward(self, data):
        x = data.x.view(-1, 50, 50, 6)[:, 0, :, :]  # (B, T, F)
        B, T, F = x.shape
        
        x = self.input_proj(x)  # (B, T, d_model)
        x = self.pos_encoder(x)
        
        transformer_out = self.transformer(x)  # (B, T, d_model)

        flattened = transformer_out.reshape(B, -1)  # (B, T * d_model)
        
        out = self.predict_head(flattened)  # (B, 120)
        return out.view(-1, 60, 2)  # (B, 60, 2)

In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransformerModel(d_model=512, nhead=64, num_layers=4).to(device)


optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)


scheduler = optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=10),
        optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40, eta_min=1e-6)
    ],
    milestones=[10]
)


criterion = nn.SmoothL1Loss(beta=0.5)

early_stopping_patience = 30
best_val_loss = float('inf')
no_improvement = 0

In [40]:
for epoch in tqdm.tqdm(range(100), desc="Epoch", unit="epoch"):
    # ---- Training ----
    model.train()
    train_loss = 0
    for batch in train_dataloader:
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(batch.num_graphs, 60, 2)
        loss = criterion(pred, y)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        train_loss += loss.item()
    
    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            val_loss += criterion(pred, y).item()

            pred_denorm = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y_denorm = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            
            val_mae += nn.L1Loss()(pred_denorm, y_denorm).item()
            val_mse += nn.MSELoss()(pred_denorm, y_denorm).item()
    
    train_loss /= len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    
    scheduler.step()
    
    tqdm.tqdm.write(f"Epoch {epoch:03d} | LR {optimizer.param_groups[0]['lr']:.6f} | "
                   f"Train Loss {train_loss:.4f} | Val Loss {val_loss:.4f} | "
                   f"Val MAE {val_mae:.4f} | Val MSE {val_mse:.4f}")
    
    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        no_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        no_improvement += 1
        if no_improvement >= early_stopping_patience:
            print("Early stopping triggered!")
            break

  future = future @ R
  future = future - origin
Epoch:   1%|▋                                                                       | 1/100 [00:09<15:06,  9.16s/epoch]

Epoch 000 | LR 0.000033 | Train Loss 0.6456 | Val Loss 0.4200 | Val MAE 4.0030 | Val MSE 51.1283


Epoch:   2%|█▍                                                                      | 2/100 [00:18<15:02,  9.21s/epoch]

Epoch 001 | LR 0.000062 | Train Loss 0.3597 | Val Loss 0.2927 | Val MAE 3.1312 | Val MSE 29.2511


Epoch:   3%|██▏                                                                     | 3/100 [00:27<14:50,  9.18s/epoch]

Epoch 002 | LR 0.000092 | Train Loss 0.3198 | Val Loss 0.2575 | Val MAE 2.9388 | Val MSE 23.5629


Epoch:   4%|██▉                                                                     | 4/100 [00:36<14:31,  9.08s/epoch]

Epoch 003 | LR 0.000122 | Train Loss 0.2684 | Val Loss 0.2854 | Val MAE 3.2189 | Val MSE 24.8803


Epoch:   5%|███▌                                                                    | 5/100 [00:45<14:21,  9.07s/epoch]

Epoch 004 | LR 0.000151 | Train Loss 0.2364 | Val Loss 0.2426 | Val MAE 2.9256 | Val MSE 19.0794


Epoch:   6%|████▎                                                                   | 6/100 [00:54<14:06,  9.00s/epoch]

Epoch 005 | LR 0.000181 | Train Loss 0.2243 | Val Loss 0.2116 | Val MAE 2.5824 | Val MSE 17.0857


Epoch:   7%|█████                                                                   | 7/100 [01:03<13:51,  8.95s/epoch]

Epoch 006 | LR 0.000211 | Train Loss 0.2220 | Val Loss 0.2007 | Val MAE 2.4291 | Val MSE 17.2817


Epoch:   8%|█████▊                                                                  | 8/100 [01:12<13:39,  8.90s/epoch]

Epoch 007 | LR 0.000241 | Train Loss 0.2194 | Val Loss 0.1751 | Val MAE 2.1897 | Val MSE 14.8366


Epoch:   9%|██████▍                                                                 | 9/100 [01:21<13:36,  8.97s/epoch]

Epoch 008 | LR 0.000270 | Train Loss 0.2194 | Val Loss 0.1971 | Val MAE 2.3819 | Val MSE 17.4312


Epoch:  10%|███████                                                                | 10/100 [01:30<13:37,  9.08s/epoch]

Epoch 009 | LR 0.000300 | Train Loss 0.2236 | Val Loss 0.2337 | Val MAE 2.7662 | Val MSE 19.8968


Epoch:  11%|███████▊                                                               | 11/100 [01:39<13:37,  9.18s/epoch]

Epoch 010 | LR 0.000300 | Train Loss 0.2160 | Val Loss 0.2021 | Val MAE 2.4606 | Val MSE 16.9119


Epoch:  12%|████████▌                                                              | 12/100 [01:49<13:33,  9.24s/epoch]

Epoch 011 | LR 0.000298 | Train Loss 0.2034 | Val Loss 0.2279 | Val MAE 2.7036 | Val MSE 19.4298


Epoch:  13%|█████████▏                                                             | 13/100 [01:58<13:30,  9.32s/epoch]

Epoch 012 | LR 0.000296 | Train Loss 0.2036 | Val Loss 0.1657 | Val MAE 2.0858 | Val MSE 13.9908


Epoch:  14%|█████████▉                                                             | 14/100 [02:08<13:26,  9.38s/epoch]

Epoch 013 | LR 0.000293 | Train Loss 0.1834 | Val Loss 0.1532 | Val MAE 1.9868 | Val MSE 12.6797


Epoch:  15%|██████████▋                                                            | 15/100 [02:17<13:17,  9.38s/epoch]

Epoch 014 | LR 0.000289 | Train Loss 0.1767 | Val Loss 0.1761 | Val MAE 2.1913 | Val MSE 14.2747


Epoch:  16%|███████████▎                                                           | 16/100 [02:26<13:06,  9.36s/epoch]

Epoch 015 | LR 0.000284 | Train Loss 0.1709 | Val Loss 0.1913 | Val MAE 2.4211 | Val MSE 15.5339


Epoch:  17%|████████████                                                           | 17/100 [02:36<12:48,  9.26s/epoch]

Epoch 016 | LR 0.000278 | Train Loss 0.1665 | Val Loss 0.1683 | Val MAE 2.1067 | Val MSE 13.5495


Epoch:  18%|████████████▊                                                          | 18/100 [02:45<12:35,  9.21s/epoch]

Epoch 017 | LR 0.000271 | Train Loss 0.1543 | Val Loss 0.1679 | Val MAE 2.0383 | Val MSE 14.5839


Epoch:  19%|█████████████▍                                                         | 19/100 [02:54<12:19,  9.13s/epoch]

Epoch 018 | LR 0.000264 | Train Loss 0.1543 | Val Loss 0.1419 | Val MAE 1.8256 | Val MSE 11.9551


Epoch:  20%|██████████████▏                                                        | 20/100 [03:02<12:02,  9.04s/epoch]

Epoch 019 | LR 0.000256 | Train Loss 0.1503 | Val Loss 0.1422 | Val MAE 1.8729 | Val MSE 11.5029


Epoch:  21%|██████████████▉                                                        | 21/100 [03:11<11:50,  9.00s/epoch]

Epoch 020 | LR 0.000248 | Train Loss 0.1455 | Val Loss 0.1263 | Val MAE 1.6150 | Val MSE 10.4687


Epoch:  22%|███████████████▌                                                       | 22/100 [03:20<11:36,  8.92s/epoch]

Epoch 021 | LR 0.000238 | Train Loss 0.1427 | Val Loss 0.1296 | Val MAE 1.6230 | Val MSE 10.7492


Epoch:  23%|████████████████▎                                                      | 23/100 [03:29<11:24,  8.89s/epoch]

Epoch 022 | LR 0.000229 | Train Loss 0.1412 | Val Loss 0.1452 | Val MAE 1.9348 | Val MSE 11.5500


Epoch:  24%|█████████████████                                                      | 24/100 [03:38<11:12,  8.85s/epoch]

Epoch 023 | LR 0.000218 | Train Loss 0.1398 | Val Loss 0.1270 | Val MAE 1.6626 | Val MSE 10.1630


Epoch:  25%|█████████████████▊                                                     | 25/100 [03:47<11:06,  8.88s/epoch]

Epoch 024 | LR 0.000208 | Train Loss 0.1376 | Val Loss 0.1204 | Val MAE 1.5978 | Val MSE 9.9441


Epoch:  26%|██████████████████▍                                                    | 26/100 [03:55<10:55,  8.86s/epoch]

Epoch 025 | LR 0.000197 | Train Loss 0.1339 | Val Loss 0.1315 | Val MAE 1.7308 | Val MSE 10.7383


Epoch:  27%|███████████████████▏                                                   | 27/100 [04:04<10:49,  8.89s/epoch]

Epoch 026 | LR 0.000185 | Train Loss 0.1330 | Val Loss 0.1178 | Val MAE 1.6024 | Val MSE 9.4706


Epoch:  28%|███████████████████▉                                                   | 28/100 [04:13<10:40,  8.90s/epoch]

Epoch 027 | LR 0.000174 | Train Loss 0.1297 | Val Loss 0.1210 | Val MAE 1.6502 | Val MSE 9.8423


Epoch:  29%|████████████████████▌                                                  | 29/100 [04:23<10:42,  9.05s/epoch]

Epoch 028 | LR 0.000162 | Train Loss 0.1287 | Val Loss 0.1132 | Val MAE 1.5335 | Val MSE 9.2464


Epoch:  30%|█████████████████████▎                                                 | 30/100 [04:32<10:38,  9.13s/epoch]

Epoch 029 | LR 0.000150 | Train Loss 0.1252 | Val Loss 0.1141 | Val MAE 1.5060 | Val MSE 9.1675


Epoch:  31%|██████████████████████                                                 | 31/100 [04:41<10:31,  9.16s/epoch]

Epoch 030 | LR 0.000139 | Train Loss 0.1253 | Val Loss 0.1173 | Val MAE 1.5587 | Val MSE 9.6117


Epoch:  32%|██████████████████████▋                                                | 32/100 [04:51<10:27,  9.22s/epoch]

Epoch 031 | LR 0.000127 | Train Loss 0.1233 | Val Loss 0.1113 | Val MAE 1.4886 | Val MSE 8.8747


Epoch:  33%|███████████████████████▍                                               | 33/100 [05:00<10:18,  9.23s/epoch]

Epoch 032 | LR 0.000116 | Train Loss 0.1219 | Val Loss 0.1123 | Val MAE 1.5398 | Val MSE 9.1236


Epoch:  34%|████████████████████████▏                                              | 34/100 [05:09<10:09,  9.23s/epoch]

Epoch 033 | LR 0.000104 | Train Loss 0.1203 | Val Loss 0.1164 | Val MAE 1.5084 | Val MSE 9.6655


Epoch:  35%|████████████████████████▊                                              | 35/100 [05:18<09:56,  9.18s/epoch]

Epoch 034 | LR 0.000093 | Train Loss 0.1177 | Val Loss 0.1183 | Val MAE 1.5807 | Val MSE 9.6715


Epoch:  36%|█████████████████████████▌                                             | 36/100 [05:27<09:38,  9.04s/epoch]

Epoch 035 | LR 0.000083 | Train Loss 0.1162 | Val Loss 0.1261 | Val MAE 1.6214 | Val MSE 10.3177


Epoch:  37%|██████████████████████████▎                                            | 37/100 [05:36<09:26,  9.00s/epoch]

Epoch 036 | LR 0.000072 | Train Loss 0.1141 | Val Loss 0.1110 | Val MAE 1.4823 | Val MSE 9.0102


Epoch:  38%|██████████████████████████▉                                            | 38/100 [05:45<09:20,  9.05s/epoch]

Epoch 037 | LR 0.000063 | Train Loss 0.1114 | Val Loss 0.1043 | Val MAE 1.3330 | Val MSE 8.9001


Epoch:  39%|███████████████████████████▋                                           | 39/100 [05:54<09:07,  8.98s/epoch]

Epoch 038 | LR 0.000053 | Train Loss 0.1098 | Val Loss 0.1074 | Val MAE 1.4603 | Val MSE 8.8453


Epoch:  40%|████████████████████████████▍                                          | 40/100 [06:03<08:56,  8.94s/epoch]

Epoch 039 | LR 0.000045 | Train Loss 0.1093 | Val Loss 0.1084 | Val MAE 1.4016 | Val MSE 8.8755


Epoch:  41%|█████████████████████████████                                          | 41/100 [06:11<08:45,  8.91s/epoch]

Epoch 040 | LR 0.000037 | Train Loss 0.1071 | Val Loss 0.1037 | Val MAE 1.3439 | Val MSE 8.6820


Epoch:  42%|█████████████████████████████▊                                         | 42/100 [06:20<08:37,  8.92s/epoch]

Epoch 041 | LR 0.000030 | Train Loss 0.1061 | Val Loss 0.1135 | Val MAE 1.4596 | Val MSE 9.4986


Epoch:  43%|██████████████████████████████▌                                        | 43/100 [06:30<08:36,  9.05s/epoch]

Epoch 042 | LR 0.000023 | Train Loss 0.1044 | Val Loss 0.1026 | Val MAE 1.3263 | Val MSE 8.3740


Epoch:  44%|███████████████████████████████▏                                       | 44/100 [06:39<08:31,  9.14s/epoch]

Epoch 043 | LR 0.000017 | Train Loss 0.1037 | Val Loss 0.1023 | Val MAE 1.3174 | Val MSE 8.5373


Epoch:  45%|███████████████████████████████▉                                       | 45/100 [06:48<08:26,  9.21s/epoch]

Epoch 044 | LR 0.000012 | Train Loss 0.1037 | Val Loss 0.1018 | Val MAE 1.3335 | Val MSE 8.4421


Epoch:  46%|████████████████████████████████▋                                      | 46/100 [06:58<08:20,  9.27s/epoch]

Epoch 045 | LR 0.000008 | Train Loss 0.1021 | Val Loss 0.1007 | Val MAE 1.2967 | Val MSE 8.3595


Epoch:  47%|█████████████████████████████████▎                                     | 47/100 [07:07<08:12,  9.30s/epoch]

Epoch 046 | LR 0.000005 | Train Loss 0.1010 | Val Loss 0.1002 | Val MAE 1.2816 | Val MSE 8.4338


Epoch:  48%|██████████████████████████████████                                     | 48/100 [07:16<08:02,  9.28s/epoch]

Epoch 047 | LR 0.000003 | Train Loss 0.1010 | Val Loss 0.1019 | Val MAE 1.3030 | Val MSE 8.4977


Epoch:  49%|██████████████████████████████████▊                                    | 49/100 [07:26<07:53,  9.28s/epoch]

Epoch 048 | LR 0.000001 | Train Loss 0.1012 | Val Loss 0.1005 | Val MAE 1.2841 | Val MSE 8.4091


Epoch:  50%|███████████████████████████████████▌                                   | 50/100 [07:35<07:44,  9.30s/epoch]

Epoch 049 | LR 0.000001 | Train Loss 0.1007 | Val Loss 0.1001 | Val MAE 1.2745 | Val MSE 8.3732


Epoch:  51%|████████████████████████████████████▏                                  | 51/100 [07:44<07:30,  9.20s/epoch]

Epoch 050 | LR 0.000001 | Train Loss 0.0999 | Val Loss 0.1003 | Val MAE 1.2807 | Val MSE 8.3952


Epoch:  52%|████████████████████████████████████▉                                  | 52/100 [07:53<07:22,  9.23s/epoch]

Epoch 051 | LR 0.000003 | Train Loss 0.1000 | Val Loss 0.0995 | Val MAE 1.2706 | Val MSE 8.3416


Epoch:  53%|█████████████████████████████████████▋                                 | 53/100 [08:02<07:10,  9.16s/epoch]

Epoch 052 | LR 0.000005 | Train Loss 0.1010 | Val Loss 0.0988 | Val MAE 1.2574 | Val MSE 8.2853


Epoch:  54%|██████████████████████████████████████▎                                | 54/100 [08:11<06:56,  9.06s/epoch]

Epoch 053 | LR 0.000008 | Train Loss 0.1006 | Val Loss 0.1009 | Val MAE 1.2897 | Val MSE 8.4123


Epoch:  55%|███████████████████████████████████████                                | 55/100 [08:20<06:43,  8.96s/epoch]

Epoch 054 | LR 0.000012 | Train Loss 0.1009 | Val Loss 0.0997 | Val MAE 1.2728 | Val MSE 8.3152


Epoch:  56%|███████████████████████████████████████▊                               | 56/100 [08:29<06:32,  8.91s/epoch]

Epoch 055 | LR 0.000017 | Train Loss 0.1011 | Val Loss 0.0991 | Val MAE 1.2846 | Val MSE 8.2249


Epoch:  57%|████████████████████████████████████████▍                              | 57/100 [08:38<06:27,  9.02s/epoch]

Epoch 056 | LR 0.000023 | Train Loss 0.1019 | Val Loss 0.1034 | Val MAE 1.3442 | Val MSE 8.7183


Epoch:  58%|█████████████████████████████████████████▏                             | 58/100 [08:47<06:21,  9.09s/epoch]

Epoch 057 | LR 0.000030 | Train Loss 0.1037 | Val Loss 0.1013 | Val MAE 1.3060 | Val MSE 8.4388


Epoch:  59%|█████████████████████████████████████████▉                             | 59/100 [08:56<06:14,  9.14s/epoch]

Epoch 058 | LR 0.000037 | Train Loss 0.1026 | Val Loss 0.1027 | Val MAE 1.3246 | Val MSE 8.6262


Epoch:  60%|██████████████████████████████████████████▌                            | 60/100 [09:06<06:07,  9.18s/epoch]

Epoch 059 | LR 0.000045 | Train Loss 0.1038 | Val Loss 0.1019 | Val MAE 1.3199 | Val MSE 8.5942


Epoch:  61%|███████████████████████████████████████████▎                           | 61/100 [09:15<05:58,  9.19s/epoch]

Epoch 060 | LR 0.000053 | Train Loss 0.1041 | Val Loss 0.1068 | Val MAE 1.3874 | Val MSE 8.7825


Epoch:  62%|████████████████████████████████████████████                           | 62/100 [09:24<05:50,  9.23s/epoch]

Epoch 061 | LR 0.000063 | Train Loss 0.1061 | Val Loss 0.1047 | Val MAE 1.3839 | Val MSE 8.6417


Epoch:  63%|████████████████████████████████████████████▋                          | 63/100 [09:33<05:41,  9.24s/epoch]

Epoch 062 | LR 0.000072 | Train Loss 0.1069 | Val Loss 0.1071 | Val MAE 1.4268 | Val MSE 8.8099


Epoch:  64%|█████████████████████████████████████████████▍                         | 64/100 [09:43<05:32,  9.24s/epoch]

Epoch 063 | LR 0.000083 | Train Loss 0.1089 | Val Loss 0.1027 | Val MAE 1.3461 | Val MSE 8.5386


Epoch:  65%|██████████████████████████████████████████████▏                        | 65/100 [09:52<05:23,  9.25s/epoch]

Epoch 064 | LR 0.000093 | Train Loss 0.1088 | Val Loss 0.1078 | Val MAE 1.4203 | Val MSE 8.9782


Epoch:  66%|██████████████████████████████████████████████▊                        | 66/100 [10:01<05:14,  9.25s/epoch]

Epoch 065 | LR 0.000104 | Train Loss 0.1110 | Val Loss 0.1178 | Val MAE 1.5389 | Val MSE 9.7043


Epoch:  67%|███████████████████████████████████████████████▌                       | 67/100 [10:10<05:03,  9.19s/epoch]

Epoch 066 | LR 0.000116 | Train Loss 0.1112 | Val Loss 0.1090 | Val MAE 1.5015 | Val MSE 8.7728


Epoch:  68%|████████████████████████████████████████████████▎                      | 68/100 [10:19<04:48,  9.01s/epoch]

Epoch 067 | LR 0.000127 | Train Loss 0.1141 | Val Loss 0.1110 | Val MAE 1.4911 | Val MSE 9.1882


Epoch:  69%|████████████████████████████████████████████████▉                      | 69/100 [10:28<04:37,  8.95s/epoch]

Epoch 068 | LR 0.000139 | Train Loss 0.1138 | Val Loss 0.1120 | Val MAE 1.4607 | Val MSE 9.1032


Epoch:  70%|█████████████████████████████████████████████████▋                     | 70/100 [10:37<04:31,  9.05s/epoch]

Epoch 069 | LR 0.000151 | Train Loss 0.1160 | Val Loss 0.1199 | Val MAE 1.5829 | Val MSE 9.7894


Epoch:  71%|██████████████████████████████████████████████████▍                    | 71/100 [10:46<04:24,  9.13s/epoch]

Epoch 070 | LR 0.000162 | Train Loss 0.1172 | Val Loss 0.1129 | Val MAE 1.4771 | Val MSE 9.2476


Epoch:  72%|███████████████████████████████████████████████████                    | 72/100 [10:56<04:16,  9.17s/epoch]

Epoch 071 | LR 0.000174 | Train Loss 0.1189 | Val Loss 0.1072 | Val MAE 1.4619 | Val MSE 8.5840


Epoch:  73%|███████████████████████████████████████████████████▊                   | 73/100 [11:05<04:07,  9.18s/epoch]

Epoch 072 | LR 0.000185 | Train Loss 0.1193 | Val Loss 0.1330 | Val MAE 1.7290 | Val MSE 10.5871


Epoch:  74%|████████████████████████████████████████████████████▌                  | 74/100 [11:14<03:57,  9.12s/epoch]

Epoch 073 | LR 0.000197 | Train Loss 0.1204 | Val Loss 0.1188 | Val MAE 1.5947 | Val MSE 9.8926


Epoch:  75%|█████████████████████████████████████████████████████▎                 | 75/100 [11:23<03:49,  9.17s/epoch]

Epoch 074 | LR 0.000208 | Train Loss 0.1217 | Val Loss 0.1103 | Val MAE 1.4708 | Val MSE 9.0680


Epoch:  76%|█████████████████████████████████████████████████████▉                 | 76/100 [11:32<03:41,  9.21s/epoch]

Epoch 075 | LR 0.000218 | Train Loss 0.1224 | Val Loss 0.1158 | Val MAE 1.5611 | Val MSE 9.4219


Epoch:  77%|██████████████████████████████████████████████████████▋                | 77/100 [11:42<03:32,  9.25s/epoch]

Epoch 076 | LR 0.000229 | Train Loss 0.1216 | Val Loss 0.1108 | Val MAE 1.4945 | Val MSE 8.8026


Epoch:  78%|███████████████████████████████████████████████████████▍               | 78/100 [11:51<03:23,  9.23s/epoch]

Epoch 077 | LR 0.000238 | Train Loss 0.1237 | Val Loss 0.1175 | Val MAE 1.5182 | Val MSE 9.7302


Epoch:  79%|████████████████████████████████████████████████████████               | 79/100 [12:00<03:14,  9.25s/epoch]

Epoch 078 | LR 0.000248 | Train Loss 0.1263 | Val Loss 0.1274 | Val MAE 1.6731 | Val MSE 10.2905


Epoch:  80%|████████████████████████████████████████████████████████▊              | 80/100 [12:10<03:05,  9.29s/epoch]

Epoch 079 | LR 0.000256 | Train Loss 0.1259 | Val Loss 0.1101 | Val MAE 1.4739 | Val MSE 8.7418


Epoch:  81%|█████████████████████████████████████████████████████████▌             | 81/100 [12:19<02:56,  9.29s/epoch]

Epoch 080 | LR 0.000264 | Train Loss 0.1261 | Val Loss 0.1241 | Val MAE 1.6682 | Val MSE 9.6419


Epoch:  82%|██████████████████████████████████████████████████████████▏            | 82/100 [12:28<02:47,  9.28s/epoch]

Epoch 081 | LR 0.000271 | Train Loss 0.1268 | Val Loss 0.1402 | Val MAE 1.8702 | Val MSE 11.2679


Epoch:  82%|██████████████████████████████████████████████████████████▏            | 82/100 [12:37<02:46,  9.24s/epoch]

Epoch 082 | LR 0.000278 | Train Loss 0.1256 | Val Loss 0.1277 | Val MAE 1.7105 | Val MSE 10.4434
Early stopping triggered!





In [43]:
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=lambda xs: Batch.from_data_list(xs))

best_model = torch.load("best_model.pt")
model = TransformerModel(d_model=512, nhead=64, num_layers=4).to(device)
model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred_norm = model(batch)
        
        pred = pred_norm * batch.scale.view(-1,1,1) + batch.origin.unsqueeze(1)
        pred_list.append(pred.cpu().numpy())

pred_list = np.concatenate(pred_list, axis=0)
pred_output = pred_list.reshape(-1, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('submission.csv', index=True)