In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import os
import torch

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from tqdm import tqdm
import torch.nn.functional as F


In [2]:
# if torch.backends.mps.is_available():
#     device = torch.device('mps')
#     print("Apple GPU")
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("CUDA GPU")
else:
    device = torch.device('cpu')

In [3]:
def getData(path):
    train_file = np.load(path+"/train.npz")
    train_data = train_file['data']
    test_file = np.load(path+"/test_input.npz")
    test_data = test_file['data']
    print(f"Training Data's shape is {train_data.shape} and Test Data's is {test_data.shape}")
    return train_data, test_data
trainData, testData = getData("./data/")

Training Data's shape is (10000, 50, 110, 6) and Test Data's is (2100, 50, 50, 6)


In [4]:
def compute_feature_ranges(data):
    # Feature indices:
    # 0: pos_x, 1: pos_y, 2: vel_x, 3: vel_y
    pos_x = data[..., 0]
    pos_y = data[..., 1]
    vel_x = data[..., 2]
    vel_y = data[..., 3]

    print(f"Position X range: {pos_x.min().item():.3f} to {pos_x.max().item():.3f}")
    print(f"Position Y range: {pos_y.min().item():.3f} to {pos_y.max().item():.3f}")
    print(f"Velocity X range: {vel_x.min().item():.3f} to {vel_x.max().item():.3f}")
    print(f"Velocity Y range: {vel_y.min().item():.3f} to {vel_y.max().item():.3f}")

compute_feature_ranges(trainData)

Position X range: -9305.083 to 13261.225
Position Y range: -4581.698 to 6655.741
Velocity X range: -40.043 to 42.910
Velocity Y range: -35.110 to 32.641


In [5]:
pos_x = trainData[..., 0]
pos_y = trainData[..., 1]
vel_x = trainData[..., 2]
vel_y = trainData[..., 3]

POS_X_MIN, POS_X_MAX = pos_x.min(), pos_x.max()
POS_Y_MIN, POS_Y_MAX = pos_y.min(), pos_y.max()
VEL_X_MIN, VEL_X_MAX = vel_x.min(), vel_x.max()
VEL_Y_MIN, VEL_Y_MAX = vel_y.min(), vel_y.max()

# print(f"Position X range: {pos_x.min().item():.3f} to {pos_x.max().item():.3f}")
# print(f"Position Y range: {pos_y.min().item():.3f} to {pos_y.max().item():.3f}")
# print(f"Velocity X range: {vel_x.min().item():.3f} to {vel_x.max().item():.3f}")
# print(f"Velocity Y range: {vel_y.min().item():.3f} to {vel_y.max().item():.3f}")

In [6]:
def normalize_batchX(data):
    data = data.clone()
    ref_t = 49

    pos_x = data[..., 0]
    pos_y = data[..., 1]
    vel_x = data[..., 2]
    vel_y = data[..., 3]
    heading = data[..., 4]

    x_ref = pos_x[:, :, ref_t].unsqueeze(-1)
    y_ref = pos_y[:, :, ref_t].unsqueeze(-1)
    theta_ref = heading[:, :, ref_t].unsqueeze(-1)

    cos_theta = torch.cos(-theta_ref)
    sin_theta = torch.sin(-theta_ref)

    dx = pos_x - x_ref
    dy = pos_y - y_ref

    new_x = dx * cos_theta - dy * sin_theta
    new_y = dx * sin_theta + dy * cos_theta
    new_vx = vel_x * cos_theta - vel_y * sin_theta
    new_vy = vel_x * sin_theta + vel_y * cos_theta
    new_heading = (heading - theta_ref + torch.pi) % (2 * torch.pi) - torch.pi

    def minmax_norm(t, t_min, t_max):
        return (t - t_min) / (t_max - t_min)

    new_x = minmax_norm(new_x, POS_X_MIN, POS_X_MAX)
    new_y = minmax_norm(new_y, POS_Y_MIN, POS_Y_MAX)
    new_vx = minmax_norm(new_vx, VEL_X_MIN, VEL_X_MAX)
    new_vy = minmax_norm(new_vy, VEL_Y_MIN, VEL_Y_MAX)

    # Mask for timestamps where all features are zero (B, 50 agents, 50 timestamps)
    all_zero_mask = torch.all(data == 0.0, dim=-1)
    # all_zero_mask = torch.all(torch.isclose(batchY, torch.zeros_like(batchY)), dim=-1)
    
    # Apply transformation only where not all zeros
    mask_inv = ~all_zero_mask

    data[..., 0] = torch.where(mask_inv, new_x, data[..., 0])
    data[..., 1] = torch.where(mask_inv, new_y, data[..., 1])
    data[..., 2] = torch.where(mask_inv, new_vx, data[..., 2])
    data[..., 3] = torch.where(mask_inv, new_vy, data[..., 3])
    data[..., 4] = torch.where(mask_inv, new_heading, data[..., 4])

    return data, x_ref, y_ref, theta_ref


In [7]:
def normalize_batchY(batchY, x_ref, y_ref, theta_ref):
    batchY = batchY.clone()

    cos_theta = torch.cos(-theta_ref).squeeze(-1)  # (B,50)
    sin_theta = torch.sin(-theta_ref).squeeze(-1)  # (B,50)

    cos_theta = cos_theta[:, 0].unsqueeze(-1)  # (B,1)
    sin_theta = sin_theta[:, 0].unsqueeze(-1)  # (B,1)
    x_ref = x_ref[:, 0, 0].unsqueeze(-1)       # (B,1)
    y_ref = y_ref[:, 0, 0].unsqueeze(-1)       # (B,1)
    theta_ref = theta_ref[:, 0, 0].unsqueeze(-1)  # (B,1)

    dx = batchY[..., 0] - x_ref
    dy = batchY[..., 1] - y_ref

    new_x = dx * cos_theta - dy * sin_theta
    new_y = dx * sin_theta + dy * cos_theta

    new_vx = batchY[..., 2] * cos_theta - batchY[..., 3] * sin_theta
    new_vy = batchY[..., 2] * sin_theta + batchY[..., 3] * cos_theta

    new_heading = (batchY[..., 4] - theta_ref + torch.pi) % (2 * torch.pi) - torch.pi

    def minmax_norm(t, t_min, t_max):
        return (t - t_min) / (t_max - t_min)

    new_x = minmax_norm(new_x, POS_X_MIN, POS_X_MAX)
    new_y = minmax_norm(new_y, POS_Y_MIN, POS_Y_MAX)
    new_vx = minmax_norm(new_vx, VEL_X_MIN, VEL_X_MAX)
    new_vy = minmax_norm(new_vy, VEL_Y_MIN, VEL_Y_MAX)

    all_zero_mask = torch.all(batchY == 0.0, dim=-1)  # (B,60)
    # all_zero_mask = torch.all(torch.isclose(batchY, torch.zeros_like(batchY)), dim=-1)


    mask_inv = ~all_zero_mask

    batchY[..., 0] = torch.where(mask_inv, new_x, batchY[..., 0])
    batchY[..., 1] = torch.where(mask_inv, new_y, batchY[..., 1])
    batchY[..., 2] = torch.where(mask_inv, new_vx, batchY[..., 2])
    batchY[..., 3] = torch.where(mask_inv, new_vy, batchY[..., 3])
    batchY[..., 4] = torch.where(mask_inv, new_heading, batchY[..., 4])

    return batchY


In [8]:
def unnormalize_batchY(batchY_pred, x_ref, y_ref, theta_ref):
    batchY_pred = batchY_pred.clone()

    cos_theta = torch.cos(theta_ref).squeeze(-1)
    sin_theta = torch.sin(theta_ref).squeeze(-1)

    cos_theta = cos_theta[:, 0].unsqueeze(-1)
    sin_theta = sin_theta[:, 0].unsqueeze(-1)
    x_ref = x_ref[:, 0, 0].unsqueeze(-1)
    y_ref = y_ref[:, 0, 0].unsqueeze(-1)
    theta_ref = theta_ref[:, 0, 0].unsqueeze(-1)

    x = batchY_pred[..., 0]
    y = batchY_pred[..., 1]

    def minmax_unnorm(t, t_min, t_max):
        return t * (t_max - t_min) + t_min

    xx = minmax_unnorm(x, POS_X_MIN, POS_X_MAX)
    yy = minmax_unnorm(y, POS_Y_MIN, POS_Y_MAX)

    vx = batchY_pred[..., 2]
    vy = batchY_pred[..., 3]

    vvx = minmax_unnorm(vx, VEL_X_MIN, VEL_X_MAX)
    vvy = minmax_unnorm(vy, VEL_Y_MIN, VEL_Y_MAX)
    

    global_x = xx * cos_theta - yy * sin_theta + x_ref
    global_y = xx * sin_theta + yy * cos_theta + y_ref

    vx = batchY_pred[..., 2]
    vy = batchY_pred[..., 3]

    global_vx = vvx * cos_theta - vvy * sin_theta
    global_vy = vvx * sin_theta + vvy * cos_theta

    global_heading = (batchY_pred[..., 4] + theta_ref + torch.pi) % (2 * torch.pi) - torch.pi

    all_zero_mask = torch.all(batchY_pred == 0.0, dim=-1)
    # all_zero_mask = torch.all(torch.isclose(batchY, torch.zeros_like(batchY)), dim=-1)


    mask_inv = ~all_zero_mask

    batchY_pred[..., 0] = torch.where(mask_inv, global_x, batchY_pred[..., 0])
    batchY_pred[..., 1] = torch.where(mask_inv, global_y, batchY_pred[..., 1])
    batchY_pred[..., 2] = torch.where(mask_inv, global_vx, batchY_pred[..., 2])
    batchY_pred[..., 3] = torch.where(mask_inv, global_vy, batchY_pred[..., 3])
    batchY_pred[..., 4] = torch.where(mask_inv, global_heading, batchY_pred[..., 4])

    return batchY_pred


In [9]:
tempX = 100 * torch.rand(10, 50, 50, 6).to(device)
tempY = 300 * torch.rand(10, 60, 6).to(device)

tempY[0, 0, :]

tensor([ 14.8431,  85.2577, 109.9855,  95.4430, 149.1436, 152.8894])

In [10]:
normX, x_ref, y_ref, theta_ref = normalize_batchX(tempX)
normY = normalize_batchY(tempY, x_ref, y_ref, theta_ref)
unNormY = unnormalize_batchY(normY, x_ref, y_ref, theta_ref)
tempY[0, 0, :], normY[0, 0, :], unNormY[0, 0, :]

(tensor([ 14.8431,  85.2577, 109.9855,  95.4430, 149.1436, 152.8894]),
 tensor([  0.4130,   0.4153,   1.8086,   1.9270,  -1.6529, 152.8894]),
 tensor([ 14.8430,  85.2578, 109.9855,  95.4430,  -1.6529, 152.8894]))

In [11]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (T, D)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (T, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # (1, T, D)

    def forward(self, x):
        # x: (B, T, D)
        return x + self.pe[:, :x.size(1)].to(x.device)

class TrajectoryTransformer(nn.Module):
    def __init__(self, input_dim = 6, output_dim = 6, model_dim=128, num_heads=4, num_layers=3, dropout=0.1, pred_len=60):
        super().__init__()
        self.model_dim = model_dim
        self.pred_len = pred_len

        self.input_fc = nn.Linear(input_dim, model_dim)
        self.temporal_encoding = PositionalEncoding(model_dim, max_len=128)

        self.time_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

        # Agent-wise encoder: (50 agents) * (50 time) → Agent embeddings
        self.agent_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

        self.query_embed = nn.Parameter(torch.randn(1, pred_len, model_dim))
        
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

        # Final prediction: output 2D coordinates
        self.output_fc = nn.Linear(model_dim, output_dim)

    def forward(self, x):
        # x: (B, 50 agents, 50 timesteps, f features)
        B, N, T, F = x.shape

        x = x.view(B * N, T, F)
        x = self.input_fc(x)  
        x = self.temporal_encoding(x) 
        # x = x.mean(dim=1)

        # x = self.temporal_encoding(x)              # (B*N, T, D)
        x = self.time_encoder(x)                   # (B*N, T, D)
        x = x[:, -1, :]                            # Use last token representation (you could also try mean pooling)
        # x = x.view(B, N, self.model_dim)

        x = x.view(B, N, self.model_dim)
        # print(x.shape)

        encoded_agents = self.agent_encoder(x)  

        agent0_embed = encoded_agents[:, :, :]  # (B, D)
        # agent0_embed = agent0_embed.unsqueeze(1)  # (B, 1, D)

        agent0_embed = encoded_agents[:, 0:1, :]  # shape: (B, 1, D)
        query = agent0_embed.repeat(1, self.pred_len, 1)  # shape: (B, 60, D)
        out = self.decoder(query, encoded_agents)
        # query = self.query_embed.repeat(B, 1, 1)  # (B, 60, D)

        # out = self.decoder(query, agent0_embed)  # (B, 60, D)

        out = self.output_fc(out)
        return out



model = TrajectoryTransformer()
x = torch.randn(1, 50, 50, 6)
out = model(x)
print(out.shape)

torch.Size([1, 60, 6])


In [12]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # (T, D)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (T, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)  # (1, T, D)

    def forward(self, x):
        # x: (B, T, D)
        return x + self.pe[:, :x.size(1)].to(x.device)

class TrajectoryTransformer(nn.Module):
    def __init__(self, input_dim=6, output_dim=6, model_dim=128, num_heads=4, num_layers=3, dropout=0.1, pred_len=60, num_agents=50):
        super().__init__()
        self.model_dim = model_dim
        self.pred_len = pred_len
        self.num_agents = num_agents

        self.input_fc = nn.Linear(input_dim, model_dim)
        self.temporal_encoding = PositionalEncoding(model_dim, max_len=128)

        self.time_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

        self.agent_embedding = nn.Embedding(num_agents, model_dim)  # Option 2: agent ID embedding

        self.agent_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout, batch_first=True),
            num_layers=num_layers
        )

        self.output_fc = nn.Linear(model_dim, output_dim)

    def forward(self, x):
        # x: (B, N=50 agents, T=50 timesteps, F=6 features)
        B, N, T, F = x.shape

        x = x.view(B * N, T, F)               # (B*N, T, F)
        x = self.input_fc(x)                  # (B*N, T, D)
        x = self.temporal_encoding(x)         # (B*N, T, D)
        x = self.time_encoder(x)              # (B*N, T, D)
        x = x[:, -1, :]                       # (B*N, D) → use last time step
        x = x.view(B, N, self.model_dim)      # (B, N, D)

        # Add agent identity embeddings (Option 2)
        agent_ids = torch.arange(0, N, device=x.device).unsqueeze(0).repeat(B, 1)  # (B, N)
        agent_id_embeds = self.agent_embedding(agent_ids)                         # (B, N, D)
        x = x + agent_id_embeds                                                   # (B, N, D)

        encoded_agents = self.agent_encoder(x)  # (B, N, D)

        # Use 0th agent embedding as query (Option 1)
        agent0_embed = encoded_agents[:, 0:1, :]                 # (B, 1, D)
        query = agent0_embed.repeat(1, self.pred_len, 1)         # (B, 60, D)

        out = self.decoder(query, encoded_agents)                # (B, 60, D)
        out = self.output_fc(out)                                # (B, 60, output_dim)

        return out

# Test run
model = TrajectoryTransformer()
x = torch.randn(1, 50, 50, 6)  # (B=1, 50 agents, 50 time, 6 features)
out = model(x)
print(out.shape)  # Expected: (1, 60, 6)


torch.Size([1, 60, 6])


In [13]:
model = TrajectoryTransformer().to(device=device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 5544198


In [14]:
np.random.seed(42)
num_samples = trainData.shape[0]
indices = np.random.permutation(num_samples)
split_index = int(0.9 * num_samples)
train_idx, val_idx = indices[:split_index], indices[split_index:]

# Split the data
train_data = trainData[train_idx]
val_data = trainData[val_idx]

print("Train shape:", train_data.shape)
print("Validation shape:", val_data.shape)

Train shape: (9000, 50, 110, 6)
Validation shape: (1000, 50, 110, 6)


In [15]:
trainX, trainY, validX, validY = train_data[:, :, :50, :], train_data[:, 0, 50:, :], val_data[:, :, :50, :], val_data[:, 0, 50:, :]
trainX.shape, trainY.shape, validX.shape, validY.shape

((9000, 50, 50, 6), (9000, 60, 6), (1000, 50, 50, 6), (1000, 60, 6))

In [16]:
class TensorDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
        
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        X_ = self.X[idx]
        Y_ = self.Y[idx]
        
        return (torch.tensor(X_, dtype=torch.float32), 
                torch.tensor(Y_, dtype=torch.float32))

In [17]:
trainTensor = TensorDataset(trainX, trainY)
testTensor = TensorDataset(validX, validY)

train_dataloader = DataLoader(trainTensor, batch_size=32, shuffle=True)
val_dataloader = DataLoader(testTensor, batch_size=32, shuffle=False)

In [18]:
import torch
import torch.nn as nn

class WeightedMSELoss(nn.Module):
    def __init__(self, weights):
        super(WeightedMSELoss, self).__init__()
        self.weights = weights

    def forward(self, input, target):
        loss = (input - target) ** 2
        weighted_loss = loss * self.weights
        return weighted_loss.mean()

feature_weights = torch.tensor([1.0, 1.0, 0.1, 0.1, 0.01, 0.01], device=device)


In [19]:
# export PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0

torch.cuda.empty_cache()
epochs = 1000
# lossFn = nn.MSELoss()
lossFn = WeightedMSELoss(feature_weights)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.25)
best_val_loss = float('inf')
best_train_loss = float('inf')
position_scale = 1.0
velocity_scale = 1.0

all_losses = {
    'training_mse_loss':[],
    'validation_mse_loss':[],
    'true_mse':[],
    'true_mae':[]
}

for each_epoch in range(epochs):
    model.train()
    runningLoss = 0.0
    loop = tqdm(train_dataloader, desc=f"Epoch [{each_epoch+1}/{epochs}]")
    for batchX, batchY in loop:
        normalizedX, x_ref, y_ref, theta_ref = normalize_batchX(batchX) 
        normalizedY = normalize_batchY(batchY, x_ref, y_ref, theta_ref)
        
        normalizedX = normalizedX.to(device, non_blocking=True)   
        normalizedY = normalizedY.to(device, non_blocking=True)  
            
        # compute_feature_ranges(normalizedX)
        # compute_feature_ranges(normalizedY)
        pred = model(normalizedX)
        loss = lossFn(pred[...,:], normalizedY[..., :]).to(device)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        runningLoss += loss.item()        
        # print(pred.shape)
        # break
    # break

    
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batchX, batchY in val_dataloader:   
            normalizedX, x_ref, y_ref, theta_ref = normalize_batchX(batchX) 
            x_ref = x_ref.to(device)
            y_ref = y_ref.to(device)
            theta_ref = theta_ref.to(device)
            batchX = batchX.to(device, non_blocking=True)   
            batchY = batchY.to(device, non_blocking=True) 
            normalizedY = normalize_batchY(batchY, x_ref, y_ref, theta_ref)
            
            normalizedX = normalizedX.to(device, non_blocking=True)   
            normalizedY = normalizedY.to(device, non_blocking=True)  

            pred = model(normalizedX)
            loss = lossFn(pred[...,:], normalizedY[..., :]).to(device)
            
            
            unnorm_pred = unnormalize_batchY(pred, x_ref, y_ref, theta_ref)
            unnorm_true = unnormalize_batchY(normalizedY, x_ref, y_ref, theta_ref)

            # print(batchY[0, 34, :], '\n', normalizedY[0, 34, :], '\n', unnorm_pred[0, 34, :], '\n',  unnorm_true[0, 34, :])
            # break
            val_loss += loss.item()
            val_mae += nn.L1Loss()(unnorm_pred[..., :2], unnorm_true[..., :2]).item()
            val_mse += nn.MSELoss()(unnorm_pred[..., :2], unnorm_true[..., :2]).item()
    #         break
    # break
            
    train_loss = runningLoss/len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    all_losses["training_mse_loss"].append(train_loss)
    all_losses["validation_mse_loss"].append(val_loss)
    all_losses["true_mse"].append(val_mse)
    all_losses["true_mae"].append(val_mae)

    loop.write(f" train MSE {train_loss:.10f} | train val MSE {val_loss:.10f} | val MAE {val_mae:.10f} | val MSE {val_mse:.10f}")
    scheduler.step()
    
    if train_loss < best_train_loss and val_loss < best_val_loss - 1e-3:
        best_val_loss = val_loss
        best_train_loss = train_loss
        no_improvement = 0
        torch.save(model.state_dict(), "./models/modelI/best_model.pt")
        loop.write(f" model Saved")
    torch.cuda.empty_cache()
    # else:
    #     no_improvement += 1
    # break
    
    

Epoch [1/1000]: 100%|██████████| 282/282 [25:57<00:00,  5.52s/it]


 train MSE 0.0191030771 | train val MSE 0.0142462896 | val MAE 2104.3061027527 | val MSE 7279939.1562500000
 model Saved


Epoch [2/1000]:  10%|▉         | 27/282 [02:31<23:51,  5.61s/it]


KeyboardInterrupt: 