In [1]:
pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
Note: you may need to restart the kernel to use updated packages.


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split
import os
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F

In [3]:
train_npz = np.load('/kaggle/input/argoverse-modified/train.npz')
train_data = train_npz['data']
test_npz  = np.load('/kaggle/input/argoverse-modified/test_input.npz')
test_data  = test_npz['data']

In [4]:
print(train_data.shape, test_data.shape)

# Split once for later use
X_train = train_data[..., :50, :]
Y_train = train_data[:, 0, 50:, :2]

(10000, 50, 110, 6) (2100, 50, 50, 6)


In [5]:
def rotate(x, y, heading):
    cos_theta = np.cos(-heading)
    sin_theta = np.sin(-heading)
    x_rot = x * cos_theta - y * sin_theta
    y_rot = x * sin_theta + y * cos_theta
    return x_rot, y_rot

In [6]:
class MotionEmbeddingMLP(nn.Module):
    def __init__(self, input_dim = 12, hidden_dim=32, output_dim=16):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        # x: (num_agents, seq_len, input_dim)
        return self.mlp(x)

In [7]:
class TrajectoryDatasetTrain(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        """
        data: Shape (N, 50, 110, 6) Training data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        augment: Whether to apply data augmentation (only for training)
        """
        self.data = data
        self.scale = scale
        self.augment = augment
        

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        # Getting 50 historical timestamps and 60 future timestamps
        hist = scene[:, :50, :].copy()    # (agents=50, time_seq=50, 6)
        future = torch.tensor(scene[0, 50:, :2].copy(), dtype=torch.float32)  # (60, 2)
        
        # Data augmentation(only for training)
        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                # Rotate the historical trajectory and future trajectory
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future = future @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1

            # Gaussian position jitter
            if np.random.rand() < 0.5:
                noise = np.random.normal(0, 0.1, size=hist[..., :2].shape)
                hist[..., :2] += noise

            # Time warping
            if np.random.rand() < 0.3:
                scale = np.random.uniform(0.9, 1.1)
                hist[..., :2] *= scale
                hist[..., 2:4] *= scale
                future *= scale

            # Velocity perturbation
            if np.random.rand() < 0.4:
                perturb = np.random.normal(0, 0.05, size=hist[..., 2:4].shape)
                hist[..., 2:4] += perturb
                hist[..., :2] += perturb * 0.1  # assuming 0.1s per timestep

        # Use the last timeframe of the historical trajectory as the origin
        origin = hist[0, 49, :2].copy()  # (2,)
        hist[..., :2] = hist[..., :2] - origin
        future = future - origin

        # Normalize the historical trajectory and future trajectory
        hist[..., :4] = hist[..., :4] / self.scale
        future = future / self.scale

        vx = hist[..., 2]
        vy = hist[..., 3]
        ego_heading = np.arctan2(vy[0], vx[0])
        ego_heading = ego_heading[None, :] 
        dvx = np.diff(vx, axis=1, prepend=vx[:, 0:1])
        dvy = np.diff(vy, axis=1, prepend=vy[:, 0:1])
        acceleration = np.sqrt(dvx**2 + dvy**2)

        acc_x_ego, acc_y_ego = rotate(dvx, dvy, ego_heading)

        vel_x_ego, vel_y_ego = rotate(vx, vy, ego_heading)

        theta = np.arctan2(vy, vx)
        dtheta = np.diff(theta, axis=1, prepend=theta[:, 0:1])

        ego_pos = hist[0:1, :, :2]  # (1, 50, 2)
        rel_pos = hist[:, :, :2] - ego_pos  # (50, 50, 2)
        dist_to_ego = np.linalg.norm(rel_pos, axis=-1, keepdims=True)  # (50, 50, 1)
        
        rel_x = rel_pos[:, :, 0]
        rel_y = rel_pos[:, :, 1]

        rot_x, rot_y = rotate(rel_x, rel_y, ego_heading)

        rotated_rel = np.stack([rot_x, rot_y], axis=-1)
        rotated_vel = np.stack([vel_x_ego, vel_y_ego], axis=-1) # (50, 50, 2)
        rotated_acc = np.stack([acc_x_ego, acc_y_ego], axis=-1)


        hist = np.concatenate([hist, dvx[..., None], dvy[..., None], acceleration[..., None], dtheta[..., None],
        rel_pos, rotated_rel, rotated_vel, rotated_acc, dist_to_ego], axis=-1)
        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            y=future.type(torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )

        return data_item
    

class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        """
        data: Shape (N, 50, 110, 6) Testing data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        """
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Testing data only contains historical trajectory
        scene = self.data[idx]  # (50, 50, 6)
        hist = scene.copy()
        
        origin = hist[0, 49, :2].copy()
        hist[..., :2] = hist[..., :2] - origin
        hist[..., :4] = hist[..., :4] / self.scale

        vx = hist[..., 2]
        vy = hist[..., 3]
        ego_heading = np.arctan2(vy[0], vx[0])
        ego_heading = ego_heading[None, :] 
        dvx = np.diff(vx, axis=1, prepend=vx[:, 0:1])
        dvy = np.diff(vy, axis=1, prepend=vy[:, 0:1])
        acceleration = np.sqrt(dvx**2 + dvy**2)

        acc_x_ego, acc_y_ego = rotate(dvx, dvy, ego_heading)

        vel_x_ego, vel_y_ego = rotate(vx, vy, ego_heading)

        theta = np.arctan2(vy, vx)
        dtheta = np.diff(theta, axis=1, prepend=theta[:, 0:1])

        ego_pos = hist[0:1, :, :2]  # (1, 50, 2)
        rel_pos = hist[:, :, :2] - ego_pos
        dist_to_ego = np.linalg.norm(rel_pos, axis=-1, keepdims=True)  # (50, 50, 1)

        rel_x = rel_pos[:, :, 0]
        rel_y = rel_pos[:, :, 1]

        rot_x, rot_y = rotate(rel_x, rel_y, ego_heading)

        rotated_rel = np.stack([rot_x, rot_y], axis=-1)
        rotated_vel = np.stack([vel_x_ego, vel_y_ego], axis=-1) # (50, 50, 2)
        rotated_acc = np.stack([acc_x_ego, acc_y_ego], axis=-1)

        hist = np.concatenate([hist, dvx[..., None], dvy[..., None], acceleration[..., None], dtheta[..., None],
        rel_pos, rotated_rel, rotated_vel, rotated_acc, dist_to_ego], axis=-1)

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        return data_item

In [8]:
torch.manual_seed(251)
np.random.seed(42)

scale = 7.0

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using CUDA GPU


In [9]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 50):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (d_model/2)

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices

        self.register_buffer('pe', pe)  # makes sure it's not a model parameter

    def forward(self, seq_len):
        return self.pe[:seq_len]

In [10]:
class SpatialBlock(nn.Module):
    def __init__(self, hidden_size, nhead):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, nhead, batch_first=True)

    def forward(self, ego, others):
        # ego: [B, 1, H], others: [B, N, H]
        attn_out, _ = self.attn(ego, others, others)
        return ego + attn_out  # [B, 1, H]

class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.query = nn.Parameter(torch.randn(hidden_dim))  # learnable [H]
        self.scale = hidden_dim ** 0.5

    def forward(self, x, mask=None):
        # x: [B, N, T, H]
        B, N, T, H = x.shape
        q = self.query.view(1, 1, 1, H)  # [1, 1, 1, H]
        attn_scores = (x * q).sum(dim=-1) / self.scale  # [B, N, T]

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask.squeeze(-1) == 0, -1e9)

        attn_weights = F.softmax(attn_scores, dim=-1)  # [B, N, T]
        pooled = (attn_weights.unsqueeze(-1) * x).sum(dim=2)  # [B, N, H]
        return pooled

class AgentAwarePredictor(nn.Module):
    def __init__(self, input_stats=None, target_stats=None):
        super().__init__()
        self.hidden_size = 128
        self.nhead = 4
        self.num_layers = 2
        self.num_agent_types = 10
        self.dropout_rate = 0.0

        self.time_pool = AttentionPooling(self.hidden_size * 2)

        agent_embed_dim = self.hidden_size // 4
        self.agent_embed = nn.Embedding(self.num_agent_types, agent_embed_dim)

        self.input_proj = nn.Linear(19 + agent_embed_dim, self.hidden_size)
        self.input_norm = nn.LayerNorm(self.hidden_size)
        self.pos_encoder = PositionalEncoding(self.hidden_size)

        encoder_layer = TransformerEncoderLayer(
            d_model=self.hidden_size * 2,
            nhead=self.nhead,
            dim_feedforward=2 * self.hidden_size,
            dropout=self.dropout_rate,
            batch_first=True
        )
        self.temporal_encoder = TransformerEncoder(encoder_layer, self.num_layers)

        self.cross_attn = SpatialBlock(self.hidden_size * 2, self.nhead)

        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size * 2, 256),
            nn.ReLU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(256, 60 * 2)
        )

    def forward(self, data):
        x = data.x  # shape: [B*50, 50, 19]
        x = x.view(-1, 50, 50, 19)  # [B, N=50, T=50, F=19]
        B, N, T, F = x.shape
    
        # Mask dead agents
        dead_agents = (x.abs().sum(dim=(2, 3)) == 0)  # [B, N]
        alive_mask = ~dead_agents  # [B, N]
        max_agents = alive_mask.sum(dim=1).max().item()
        new_x = torch.zeros(B, max_agents, T, F, device=x.device)
    
        for b in range(B):
            new_x[b, :alive_mask[b].sum()] = x[b, alive_mask[b]]
    
        x = new_x
        B, N, T, F = x.shape  # updated N
    
        # Agent types from feature index 5
        agent_types = x[:, :, 0, 5].long()
        type_embed = self.agent_embed(agent_types).unsqueeze(2).expand(-1, -1, T, -1)
    
        # Combine features + agent embedding
        x_combined = torch.cat([x, type_embed], dim=-1)  # [B, N, T, F+E]
        x_proj = self.input_proj(x_combined)
        x_proj = self.input_norm(x_proj)
    
        # Flatten for temporal encoding
        x_flat = x_proj.view(B * N, T, self.hidden_size)
    
        # Positional encoding
        pos_emb = self.pos_encoder(T).unsqueeze(0).expand(B * N, -1, -1)
        x_flat = torch.cat([x_flat, pos_emb], dim=-1)  # [B*N, T, H*2]
    
        # Temporal Transformer
        x_encoded = self.temporal_encoder(x_flat).view(B, N, T, -1)
    
        # Time attention pooling
        mask = (x.abs().sum(dim=-1) > 0).float().unsqueeze(-1)  # [B, N, T, 1]
        x_encoded = self.time_pool(x_encoded, mask=mask)  # [B, N, H]
    
        # Ego attends to all
        h_ego = x_encoded[:, 0:1, :]  # [B, 1, H]
        h_attended = self.cross_attn(h_ego, x_encoded)  # [B, 1, H]
        h_attended = h_attended.squeeze(1)  # [B, H]
    
        pred_norm = self.decoder(h_attended).view(B, 60, 2)
        return pred_norm

    def predict_denorm(self, x):
        return self.forward(x)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 50):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (d_model/2)

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices

        self.register_buffer('pe', pe)  # makes sure it's not a model parameter

    def forward(self, seq_len):
        return self.pe[:seq_len]

In [11]:
model = AgentAwarePredictor().to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) # You can try different schedulers
early_stopping_patience = 10
best_val_loss = float('inf')
no_improvement = 0
criterion = nn.MSELoss()

In [12]:
def combined_loss_fn(pred_norm, y_norm, batch, fde_weight=0.2):
    """
    Combined loss function with MSE + FDE in world coordinates
    """
    # 1) Raw MSE (in normalized coordinates)
    raw_loss = F.mse_loss(pred_norm, y_norm)
    
    # 2) Denormalize to world coordinates
    scale = batch.scale.view(-1, 1, 1)   # (B,1,1)
    origin = batch.origin.view(-1, 1, 2)  # (B,1,2)
    pred_world = pred_norm * scale + origin  # (B,60,2)
    y_world = y_norm * scale + origin
    
    # 3) World MSE
    world_loss = F.mse_loss(pred_world, y_world)
    
    # 4) Final Displacement Error (FDE) at t=60 in world coords
    fde = F.mse_loss(pred_world[:, -1, :], y_world[:, -1, :])
    
    # 5) Combined loss: world_loss + fde_weight * fde
    total_loss = world_loss + fde_weight * fde
    
    return total_loss, raw_loss, world_loss, fde

In [13]:
#torch.cuda.empty_cache()
for epoch in tqdm.tqdm(range(100), desc="Epoch", unit="epoch"):
    # ---- Training ----
    model.train()
    train_loss = 0
    for batch in train_dataloader:
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(batch.num_graphs, 60, 2)
        loss, raw_loss, world_loss, fde = combined_loss_fn(pred, y, batch)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += loss.item()
    
    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            
            combined_val_loss, _, _, _ = combined_loss_fn(pred, y, batch)
            val_loss += combined_val_loss.item()

            # show MAE and MSE with unnormalized data
            pred = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            val_mae += nn.L1Loss()(pred, y).item()
            val_mse += nn.MSELoss()(pred, y).item()
    
    train_loss /= len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    scheduler.step()
    # scheduler.step(val_loss)
    print("epoch 1 almost done")
    
    tqdm.tqdm.write(f"Epoch {epoch:03d} | Learning rate {optimizer.param_groups[0]['lr']:.6f} | train normalized MSE {train_loss:8.4f} | val normalized MSE {val_loss:8.4f}, | val MAE {val_mae:8.4f} | val MSE {val_mse:8.4f}")
    if val_loss < best_val_loss - 1e-3:
        best_val_loss = val_loss
        no_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        no_improvement += 1
        if no_improvement >= early_stopping_patience:
            no_improvement = 0
            optimizer.param_groups[0]['lr'] /= 2
            # print("Early stop!")
            # break

Epoch:   1%|          | 1/100 [00:53<1:28:43, 53.77s/epoch]

epoch 1 almost done
Epoch 000 | Learning rate 0.001000 | train normalized MSE  73.4450 | val normalized MSE  40.9533, | val MAE   2.9722 | val MSE  24.2212


Epoch:   2%|▏         | 2/100 [01:45<1:26:12, 52.78s/epoch]

epoch 1 almost done
Epoch 001 | Learning rate 0.001000 | train normalized MSE  33.8285 | val normalized MSE  30.0519, | val MAE   2.6652 | val MSE  17.5594


Epoch:   3%|▎         | 3/100 [02:37<1:24:34, 52.31s/epoch]

epoch 1 almost done
Epoch 002 | Learning rate 0.001000 | train normalized MSE  29.9577 | val normalized MSE  33.3440, | val MAE   2.7291 | val MSE  19.6319


Epoch:   4%|▍         | 4/100 [03:29<1:23:17, 52.06s/epoch]

epoch 1 almost done
Epoch 003 | Learning rate 0.001000 | train normalized MSE  27.9871 | val normalized MSE  34.1481, | val MAE   2.8471 | val MSE  20.2417


Epoch:   5%|▌         | 5/100 [04:21<1:22:31, 52.12s/epoch]

epoch 1 almost done
Epoch 004 | Learning rate 0.001000 | train normalized MSE  25.9533 | val normalized MSE  23.8382, | val MAE   1.9627 | val MSE  13.3663


Epoch:   6%|▌         | 6/100 [05:14<1:22:12, 52.47s/epoch]

epoch 1 almost done
Epoch 005 | Learning rate 0.001000 | train normalized MSE  24.3838 | val normalized MSE  21.4336, | val MAE   1.9341 | val MSE  12.0876


Epoch:   7%|▋         | 7/100 [06:07<1:21:27, 52.55s/epoch]

epoch 1 almost done
Epoch 006 | Learning rate 0.001000 | train normalized MSE  24.6724 | val normalized MSE  26.0570, | val MAE   2.2099 | val MSE  14.9005


Epoch:   8%|▊         | 8/100 [07:00<1:20:52, 52.74s/epoch]

epoch 1 almost done
Epoch 007 | Learning rate 0.001000 | train normalized MSE  23.8342 | val normalized MSE  19.4877, | val MAE   1.7464 | val MSE  10.7091


Epoch:   9%|▉         | 9/100 [07:53<1:20:15, 52.92s/epoch]

epoch 1 almost done
Epoch 008 | Learning rate 0.001000 | train normalized MSE  22.8632 | val normalized MSE  23.1485, | val MAE   2.0731 | val MSE  13.2691


Epoch:  10%|█         | 10/100 [08:46<1:19:22, 52.91s/epoch]

epoch 1 almost done
Epoch 009 | Learning rate 0.000500 | train normalized MSE  22.1604 | val normalized MSE  21.9678, | val MAE   1.9548 | val MSE  12.3882


Epoch:  11%|█         | 11/100 [09:39<1:18:18, 52.79s/epoch]

epoch 1 almost done
Epoch 010 | Learning rate 0.000500 | train normalized MSE  19.5466 | val normalized MSE  17.5406, | val MAE   1.5802 | val MSE   9.5730


Epoch:  12%|█▏        | 12/100 [10:31<1:17:10, 52.61s/epoch]

epoch 1 almost done
Epoch 011 | Learning rate 0.000500 | train normalized MSE  18.6283 | val normalized MSE  18.9957, | val MAE   1.7708 | val MSE  10.5108


Epoch:  13%|█▎        | 13/100 [11:24<1:16:16, 52.60s/epoch]

epoch 1 almost done
Epoch 012 | Learning rate 0.000500 | train normalized MSE  18.6025 | val normalized MSE  18.9398, | val MAE   1.7862 | val MSE  10.4451


Epoch:  14%|█▍        | 14/100 [12:16<1:15:15, 52.51s/epoch]

epoch 1 almost done
Epoch 013 | Learning rate 0.000500 | train normalized MSE  18.5768 | val normalized MSE  17.6524, | val MAE   1.5805 | val MSE   9.7208


Epoch:  15%|█▌        | 15/100 [13:08<1:14:15, 52.42s/epoch]

epoch 1 almost done
Epoch 014 | Learning rate 0.000500 | train normalized MSE  18.2544 | val normalized MSE  18.7454, | val MAE   1.7198 | val MSE  10.3695


Epoch:  16%|█▌        | 16/100 [14:00<1:13:07, 52.23s/epoch]

epoch 1 almost done
Epoch 015 | Learning rate 0.000500 | train normalized MSE  18.0028 | val normalized MSE  17.8971, | val MAE   1.7122 | val MSE   9.9027


Epoch:  17%|█▋        | 17/100 [14:52<1:12:14, 52.22s/epoch]

epoch 1 almost done
Epoch 016 | Learning rate 0.000500 | train normalized MSE  18.0980 | val normalized MSE  17.8458, | val MAE   1.6534 | val MSE   9.7788


Epoch:  18%|█▊        | 18/100 [15:44<1:11:13, 52.12s/epoch]

epoch 1 almost done
Epoch 017 | Learning rate 0.000500 | train normalized MSE  17.9182 | val normalized MSE  17.2733, | val MAE   1.5693 | val MSE   9.3328


Epoch:  19%|█▉        | 19/100 [16:36<1:10:14, 52.04s/epoch]

epoch 1 almost done
Epoch 018 | Learning rate 0.000500 | train normalized MSE  17.8653 | val normalized MSE  17.3232, | val MAE   1.5596 | val MSE   9.4653


Epoch:  20%|██        | 20/100 [17:27<1:09:10, 51.88s/epoch]

epoch 1 almost done
Epoch 019 | Learning rate 0.000250 | train normalized MSE  17.8237 | val normalized MSE  17.8381, | val MAE   1.6212 | val MSE   9.8918


Epoch:  21%|██        | 21/100 [18:19<1:08:09, 51.77s/epoch]

epoch 1 almost done
Epoch 020 | Learning rate 0.000250 | train normalized MSE  16.3849 | val normalized MSE  16.1515, | val MAE   1.4745 | val MSE   8.7464


Epoch:  22%|██▏       | 22/100 [19:10<1:07:15, 51.73s/epoch]

epoch 1 almost done
Epoch 021 | Learning rate 0.000250 | train normalized MSE  16.3690 | val normalized MSE  15.6926, | val MAE   1.4247 | val MSE   8.5049


Epoch:  23%|██▎       | 23/100 [20:02<1:06:19, 51.68s/epoch]

epoch 1 almost done
Epoch 022 | Learning rate 0.000250 | train normalized MSE  16.3905 | val normalized MSE  15.5112, | val MAE   1.4129 | val MSE   8.3502


Epoch:  24%|██▍       | 24/100 [20:54<1:05:29, 51.70s/epoch]

epoch 1 almost done
Epoch 023 | Learning rate 0.000250 | train normalized MSE  16.1062 | val normalized MSE  16.3612, | val MAE   1.4869 | val MSE   8.8285


Epoch:  25%|██▌       | 25/100 [21:46<1:04:57, 51.96s/epoch]

epoch 1 almost done
Epoch 024 | Learning rate 0.000250 | train normalized MSE  15.9480 | val normalized MSE  15.6493, | val MAE   1.4093 | val MSE   8.4008


Epoch:  26%|██▌       | 26/100 [22:39<1:04:25, 52.23s/epoch]

epoch 1 almost done
Epoch 025 | Learning rate 0.000250 | train normalized MSE  16.0472 | val normalized MSE  17.6170, | val MAE   1.6621 | val MSE   9.6877


Epoch:  27%|██▋       | 27/100 [23:32<1:03:36, 52.28s/epoch]

epoch 1 almost done
Epoch 026 | Learning rate 0.000250 | train normalized MSE  15.8317 | val normalized MSE  15.8244, | val MAE   1.4423 | val MSE   8.5368


Epoch:  28%|██▊       | 28/100 [24:24<1:02:46, 52.32s/epoch]

epoch 1 almost done
Epoch 027 | Learning rate 0.000250 | train normalized MSE  15.7414 | val normalized MSE  15.5212, | val MAE   1.3978 | val MSE   8.3679


Epoch:  29%|██▉       | 29/100 [25:17<1:02:04, 52.45s/epoch]

epoch 1 almost done
Epoch 028 | Learning rate 0.000250 | train normalized MSE  15.7179 | val normalized MSE  15.7234, | val MAE   1.4261 | val MSE   8.4746


Epoch:  30%|███       | 30/100 [26:09<1:01:13, 52.48s/epoch]

epoch 1 almost done
Epoch 029 | Learning rate 0.000125 | train normalized MSE  15.6910 | val normalized MSE  16.2521, | val MAE   1.4745 | val MSE   8.7851


Epoch:  31%|███       | 31/100 [27:02<1:00:23, 52.51s/epoch]

epoch 1 almost done
Epoch 030 | Learning rate 0.000125 | train normalized MSE  14.9263 | val normalized MSE  15.2095, | val MAE   1.3616 | val MSE   8.1609


Epoch:  32%|███▏      | 32/100 [27:54<59:25, 52.43s/epoch]  

epoch 1 almost done
Epoch 031 | Learning rate 0.000125 | train normalized MSE  14.6290 | val normalized MSE  14.6598, | val MAE   1.3403 | val MSE   7.8547


Epoch:  33%|███▎      | 33/100 [28:46<58:27, 52.35s/epoch]

epoch 1 almost done
Epoch 032 | Learning rate 0.000125 | train normalized MSE  14.5400 | val normalized MSE  14.9081, | val MAE   1.3512 | val MSE   7.9979


Epoch:  34%|███▍      | 34/100 [29:38<57:30, 52.28s/epoch]

epoch 1 almost done
Epoch 033 | Learning rate 0.000125 | train normalized MSE  14.5972 | val normalized MSE  15.3689, | val MAE   1.4141 | val MSE   8.2737


Epoch:  35%|███▌      | 35/100 [30:31<56:41, 52.34s/epoch]

epoch 1 almost done
Epoch 034 | Learning rate 0.000125 | train normalized MSE  14.3734 | val normalized MSE  15.4346, | val MAE   1.4410 | val MSE   8.3591


Epoch:  36%|███▌      | 36/100 [31:23<55:52, 52.39s/epoch]

epoch 1 almost done
Epoch 035 | Learning rate 0.000125 | train normalized MSE  14.3239 | val normalized MSE  14.9857, | val MAE   1.3772 | val MSE   8.0214


Epoch:  37%|███▋      | 37/100 [32:16<55:01, 52.41s/epoch]

epoch 1 almost done
Epoch 036 | Learning rate 0.000125 | train normalized MSE  14.3924 | val normalized MSE  14.6499, | val MAE   1.3038 | val MSE   7.8476


Epoch:  38%|███▊      | 38/100 [33:08<54:08, 52.40s/epoch]

epoch 1 almost done
Epoch 037 | Learning rate 0.000125 | train normalized MSE  14.2355 | val normalized MSE  14.9610, | val MAE   1.3611 | val MSE   8.0468


Epoch:  39%|███▉      | 39/100 [34:01<53:22, 52.50s/epoch]

epoch 1 almost done
Epoch 038 | Learning rate 0.000125 | train normalized MSE  14.1398 | val normalized MSE  14.6629, | val MAE   1.2824 | val MSE   7.8367


Epoch:  40%|████      | 40/100 [34:53<52:21, 52.35s/epoch]

epoch 1 almost done
Epoch 039 | Learning rate 0.000063 | train normalized MSE  14.1164 | val normalized MSE  14.5867, | val MAE   1.3746 | val MSE   7.8395


Epoch:  41%|████      | 41/100 [35:45<51:25, 52.30s/epoch]

epoch 1 almost done
Epoch 040 | Learning rate 0.000063 | train normalized MSE  13.7569 | val normalized MSE  13.9591, | val MAE   1.2464 | val MSE   7.4467


Epoch:  42%|████▏     | 42/100 [36:38<50:39, 52.41s/epoch]

epoch 1 almost done
Epoch 041 | Learning rate 0.000063 | train normalized MSE  13.6034 | val normalized MSE  14.2764, | val MAE   1.2835 | val MSE   7.6420


Epoch:  43%|████▎     | 43/100 [37:31<49:55, 52.54s/epoch]

epoch 1 almost done
Epoch 042 | Learning rate 0.000063 | train normalized MSE  13.5383 | val normalized MSE  13.7746, | val MAE   1.2379 | val MSE   7.3296


Epoch:  44%|████▍     | 44/100 [38:23<49:05, 52.59s/epoch]

epoch 1 almost done
Epoch 043 | Learning rate 0.000063 | train normalized MSE  13.4186 | val normalized MSE  14.2476, | val MAE   1.2535 | val MSE   7.5958


Epoch:  45%|████▌     | 45/100 [39:16<48:13, 52.61s/epoch]

epoch 1 almost done
Epoch 044 | Learning rate 0.000063 | train normalized MSE  13.6256 | val normalized MSE  13.9623, | val MAE   1.2658 | val MSE   7.4538


Epoch:  46%|████▌     | 46/100 [40:10<47:36, 52.89s/epoch]

epoch 1 almost done
Epoch 045 | Learning rate 0.000063 | train normalized MSE  13.2401 | val normalized MSE  14.2129, | val MAE   1.2494 | val MSE   7.5794


Epoch:  47%|████▋     | 47/100 [41:03<46:53, 53.09s/epoch]

epoch 1 almost done
Epoch 046 | Learning rate 0.000063 | train normalized MSE  13.4325 | val normalized MSE  14.2774, | val MAE   1.2736 | val MSE   7.6399


Epoch:  48%|████▊     | 48/100 [41:55<45:47, 52.84s/epoch]

epoch 1 almost done
Epoch 047 | Learning rate 0.000063 | train normalized MSE  13.3235 | val normalized MSE  13.7728, | val MAE   1.2597 | val MSE   7.3371


Epoch:  49%|████▉     | 49/100 [42:48<44:56, 52.87s/epoch]

epoch 1 almost done
Epoch 048 | Learning rate 0.000063 | train normalized MSE  13.3773 | val normalized MSE  14.0002, | val MAE   1.2758 | val MSE   7.4653


Epoch:  50%|█████     | 50/100 [43:41<43:56, 52.74s/epoch]

epoch 1 almost done
Epoch 049 | Learning rate 0.000031 | train normalized MSE  13.3384 | val normalized MSE  13.8825, | val MAE   1.2722 | val MSE   7.4027


Epoch:  51%|█████     | 51/100 [44:33<43:01, 52.68s/epoch]

epoch 1 almost done
Epoch 050 | Learning rate 0.000031 | train normalized MSE  13.1134 | val normalized MSE  13.8464, | val MAE   1.2446 | val MSE   7.3958


Epoch:  52%|█████▏    | 52/100 [45:25<42:01, 52.53s/epoch]

epoch 1 almost done
Epoch 051 | Learning rate 0.000031 | train normalized MSE  12.9791 | val normalized MSE  13.6305, | val MAE   1.2269 | val MSE   7.2721


Epoch:  53%|█████▎    | 53/100 [46:18<41:14, 52.64s/epoch]

epoch 1 almost done
Epoch 052 | Learning rate 0.000031 | train normalized MSE  13.0537 | val normalized MSE  13.7702, | val MAE   1.2348 | val MSE   7.3334


Epoch:  54%|█████▍    | 54/100 [47:11<40:22, 52.66s/epoch]

epoch 1 almost done
Epoch 053 | Learning rate 0.000031 | train normalized MSE  12.9941 | val normalized MSE  13.8190, | val MAE   1.2287 | val MSE   7.3601


Epoch:  55%|█████▌    | 55/100 [48:03<39:26, 52.59s/epoch]

epoch 1 almost done
Epoch 054 | Learning rate 0.000031 | train normalized MSE  12.9974 | val normalized MSE  13.8121, | val MAE   1.2290 | val MSE   7.3610


Epoch:  56%|█████▌    | 56/100 [48:55<38:26, 52.42s/epoch]

epoch 1 almost done
Epoch 055 | Learning rate 0.000031 | train normalized MSE  13.0792 | val normalized MSE  13.5948, | val MAE   1.2184 | val MSE   7.2392


Epoch:  57%|█████▋    | 57/100 [49:48<37:34, 52.44s/epoch]

epoch 1 almost done
Epoch 056 | Learning rate 0.000031 | train normalized MSE  12.9112 | val normalized MSE  13.5376, | val MAE   1.2131 | val MSE   7.2042


Epoch:  58%|█████▊    | 58/100 [50:41<36:50, 52.63s/epoch]

epoch 1 almost done
Epoch 057 | Learning rate 0.000031 | train normalized MSE  12.7723 | val normalized MSE  13.6052, | val MAE   1.2271 | val MSE   7.2509


Epoch:  59%|█████▉    | 59/100 [51:34<35:59, 52.66s/epoch]

epoch 1 almost done
Epoch 058 | Learning rate 0.000031 | train normalized MSE  12.9449 | val normalized MSE  13.5944, | val MAE   1.2231 | val MSE   7.2394


Epoch:  60%|██████    | 60/100 [52:26<35:05, 52.65s/epoch]

epoch 1 almost done
Epoch 059 | Learning rate 0.000016 | train normalized MSE  12.8134 | val normalized MSE  13.6474, | val MAE   1.2255 | val MSE   7.2706


Epoch:  61%|██████    | 61/100 [53:19<34:16, 52.72s/epoch]

epoch 1 almost done
Epoch 060 | Learning rate 0.000016 | train normalized MSE  12.8725 | val normalized MSE  13.6336, | val MAE   1.2164 | val MSE   7.2570


Epoch:  62%|██████▏   | 62/100 [54:13<33:33, 52.98s/epoch]

epoch 1 almost done
Epoch 061 | Learning rate 0.000016 | train normalized MSE  12.7137 | val normalized MSE  13.6120, | val MAE   1.2126 | val MSE   7.2490


Epoch:  63%|██████▎   | 63/100 [55:06<32:41, 53.02s/epoch]

epoch 1 almost done
Epoch 062 | Learning rate 0.000016 | train normalized MSE  12.6998 | val normalized MSE  13.4984, | val MAE   1.2146 | val MSE   7.1833


Epoch:  64%|██████▍   | 64/100 [56:00<31:55, 53.20s/epoch]

epoch 1 almost done
Epoch 063 | Learning rate 0.000016 | train normalized MSE  12.5954 | val normalized MSE  13.5302, | val MAE   1.2276 | val MSE   7.1990


Epoch:  65%|██████▌   | 65/100 [56:53<31:01, 53.19s/epoch]

epoch 1 almost done
Epoch 064 | Learning rate 0.000016 | train normalized MSE  12.6043 | val normalized MSE  13.6644, | val MAE   1.2202 | val MSE   7.2749


Epoch:  66%|██████▌   | 66/100 [57:45<30:03, 53.04s/epoch]

epoch 1 almost done
Epoch 065 | Learning rate 0.000016 | train normalized MSE  12.5992 | val normalized MSE  13.6582, | val MAE   1.2130 | val MSE   7.2674


Epoch:  67%|██████▋   | 67/100 [58:38<29:07, 52.94s/epoch]

epoch 1 almost done
Epoch 066 | Learning rate 0.000016 | train normalized MSE  12.6538 | val normalized MSE  13.7019, | val MAE   1.2235 | val MSE   7.2952


Epoch:  68%|██████▊   | 68/100 [59:31<28:14, 52.96s/epoch]

epoch 1 almost done
Epoch 067 | Learning rate 0.000016 | train normalized MSE  12.5909 | val normalized MSE  13.5452, | val MAE   1.2074 | val MSE   7.2057


Epoch:  69%|██████▉   | 69/100 [1:00:24<27:18, 52.84s/epoch]

epoch 1 almost done
Epoch 068 | Learning rate 0.000016 | train normalized MSE  12.5571 | val normalized MSE  13.4475, | val MAE   1.2114 | val MSE   7.1529


Epoch:  70%|███████   | 70/100 [1:01:17<26:31, 53.05s/epoch]

epoch 1 almost done
Epoch 069 | Learning rate 0.000008 | train normalized MSE  12.6733 | val normalized MSE  13.4753, | val MAE   1.2045 | val MSE   7.1706


Epoch:  71%|███████   | 71/100 [1:02:10<25:38, 53.04s/epoch]

epoch 1 almost done
Epoch 070 | Learning rate 0.000008 | train normalized MSE  12.5097 | val normalized MSE  13.3972, | val MAE   1.2035 | val MSE   7.1282


Epoch:  72%|███████▏  | 72/100 [1:03:03<24:43, 52.99s/epoch]

epoch 1 almost done
Epoch 071 | Learning rate 0.000008 | train normalized MSE  12.5842 | val normalized MSE  13.4929, | val MAE   1.2114 | val MSE   7.1766


Epoch:  73%|███████▎  | 73/100 [1:03:57<23:55, 53.15s/epoch]

epoch 1 almost done
Epoch 072 | Learning rate 0.000008 | train normalized MSE  12.4635 | val normalized MSE  13.5475, | val MAE   1.2124 | val MSE   7.2083


Epoch:  74%|███████▍  | 74/100 [1:04:50<23:03, 53.22s/epoch]

epoch 1 almost done
Epoch 073 | Learning rate 0.000008 | train normalized MSE  12.5523 | val normalized MSE  13.3739, | val MAE   1.2026 | val MSE   7.1146


Epoch:  75%|███████▌  | 75/100 [1:05:43<22:11, 53.28s/epoch]

epoch 1 almost done
Epoch 074 | Learning rate 0.000008 | train normalized MSE  12.5564 | val normalized MSE  13.5129, | val MAE   1.2145 | val MSE   7.1882


Epoch:  76%|███████▌  | 76/100 [1:06:37<21:20, 53.34s/epoch]

epoch 1 almost done
Epoch 075 | Learning rate 0.000008 | train normalized MSE  12.5017 | val normalized MSE  13.5000, | val MAE   1.2041 | val MSE   7.1831


Epoch:  77%|███████▋  | 77/100 [1:07:30<20:25, 53.27s/epoch]

epoch 1 almost done
Epoch 076 | Learning rate 0.000008 | train normalized MSE  12.4407 | val normalized MSE  13.4788, | val MAE   1.2022 | val MSE   7.1696


Epoch:  78%|███████▊  | 78/100 [1:08:22<19:21, 52.79s/epoch]

epoch 1 almost done
Epoch 077 | Learning rate 0.000008 | train normalized MSE  12.5094 | val normalized MSE  13.4541, | val MAE   1.2129 | val MSE   7.1538


Epoch:  79%|███████▉  | 79/100 [1:09:14<18:23, 52.54s/epoch]

epoch 1 almost done
Epoch 078 | Learning rate 0.000008 | train normalized MSE  12.5160 | val normalized MSE  13.4376, | val MAE   1.2101 | val MSE   7.1480


Epoch:  80%|████████  | 80/100 [1:10:05<17:24, 52.23s/epoch]

epoch 1 almost done
Epoch 079 | Learning rate 0.000004 | train normalized MSE  12.4756 | val normalized MSE  13.5334, | val MAE   1.2087 | val MSE   7.1998


Epoch:  81%|████████  | 81/100 [1:10:58<16:34, 52.36s/epoch]

epoch 1 almost done
Epoch 080 | Learning rate 0.000004 | train normalized MSE  12.3296 | val normalized MSE  13.5298, | val MAE   1.2084 | val MSE   7.1970


Epoch:  82%|████████▏ | 82/100 [1:11:50<15:42, 52.34s/epoch]

epoch 1 almost done
Epoch 081 | Learning rate 0.000004 | train normalized MSE  12.3410 | val normalized MSE  13.4692, | val MAE   1.2019 | val MSE   7.1619


Epoch:  83%|████████▎ | 83/100 [1:12:42<14:49, 52.32s/epoch]

epoch 1 almost done
Epoch 082 | Learning rate 0.000004 | train normalized MSE  12.3747 | val normalized MSE  13.5397, | val MAE   1.2069 | val MSE   7.2045


Epoch:  84%|████████▍ | 84/100 [1:13:35<13:57, 52.36s/epoch]

epoch 1 almost done
Epoch 083 | Learning rate 0.000004 | train normalized MSE  12.4580 | val normalized MSE  13.5220, | val MAE   1.2063 | val MSE   7.1967


Epoch:  85%|████████▌ | 85/100 [1:14:28<13:07, 52.49s/epoch]

epoch 1 almost done
Epoch 084 | Learning rate 0.000002 | train normalized MSE  12.5080 | val normalized MSE  13.4886, | val MAE   1.2037 | val MSE   7.1778


Epoch:  86%|████████▌ | 86/100 [1:15:21<12:17, 52.69s/epoch]

epoch 1 almost done
Epoch 085 | Learning rate 0.000002 | train normalized MSE  12.4003 | val normalized MSE  13.5173, | val MAE   1.2066 | val MSE   7.1935


Epoch:  87%|████████▋ | 87/100 [1:16:14<11:27, 52.87s/epoch]

epoch 1 almost done
Epoch 086 | Learning rate 0.000002 | train normalized MSE  12.4944 | val normalized MSE  13.4839, | val MAE   1.2026 | val MSE   7.1738


Epoch:  88%|████████▊ | 88/100 [1:17:07<10:32, 52.73s/epoch]

epoch 1 almost done
Epoch 087 | Learning rate 0.000002 | train normalized MSE  12.5519 | val normalized MSE  13.4839, | val MAE   1.2049 | val MSE   7.1756


Epoch:  89%|████████▉ | 89/100 [1:17:59<09:38, 52.56s/epoch]

epoch 1 almost done
Epoch 088 | Learning rate 0.000002 | train normalized MSE  12.3706 | val normalized MSE  13.4918, | val MAE   1.2068 | val MSE   7.1807


Epoch:  90%|█████████ | 90/100 [1:18:52<08:46, 52.65s/epoch]

epoch 1 almost done
Epoch 089 | Learning rate 0.000001 | train normalized MSE  12.4852 | val normalized MSE  13.4685, | val MAE   1.2037 | val MSE   7.1646


Epoch:  91%|█████████ | 91/100 [1:19:43<07:51, 52.43s/epoch]

epoch 1 almost done
Epoch 090 | Learning rate 0.000001 | train normalized MSE  12.3814 | val normalized MSE  13.4711, | val MAE   1.2032 | val MSE   7.1681


Epoch:  92%|█████████▏| 92/100 [1:20:35<06:58, 52.27s/epoch]

epoch 1 almost done
Epoch 091 | Learning rate 0.000001 | train normalized MSE  12.4991 | val normalized MSE  13.4598, | val MAE   1.2030 | val MSE   7.1628


Epoch:  93%|█████████▎| 93/100 [1:21:28<06:06, 52.31s/epoch]

epoch 1 almost done
Epoch 092 | Learning rate 0.000001 | train normalized MSE  12.4043 | val normalized MSE  13.4616, | val MAE   1.2033 | val MSE   7.1632


Epoch:  94%|█████████▍| 94/100 [1:22:21<05:15, 52.50s/epoch]

epoch 1 almost done
Epoch 093 | Learning rate 0.000001 | train normalized MSE  12.3183 | val normalized MSE  13.4566, | val MAE   1.2019 | val MSE   7.1612


Epoch:  95%|█████████▌| 95/100 [1:23:14<04:23, 52.69s/epoch]

epoch 1 almost done
Epoch 094 | Learning rate 0.000000 | train normalized MSE  12.3894 | val normalized MSE  13.4524, | val MAE   1.2023 | val MSE   7.1583


Epoch:  96%|█████████▌| 96/100 [1:24:07<03:31, 52.84s/epoch]

epoch 1 almost done
Epoch 095 | Learning rate 0.000000 | train normalized MSE  12.4390 | val normalized MSE  13.4543, | val MAE   1.2032 | val MSE   7.1582


Epoch:  97%|█████████▋| 97/100 [1:25:00<02:38, 52.80s/epoch]

epoch 1 almost done
Epoch 096 | Learning rate 0.000000 | train normalized MSE  12.3619 | val normalized MSE  13.4489, | val MAE   1.2029 | val MSE   7.1558


Epoch:  98%|█████████▊| 98/100 [1:25:53<01:45, 52.90s/epoch]

epoch 1 almost done
Epoch 097 | Learning rate 0.000000 | train normalized MSE  12.4232 | val normalized MSE  13.4529, | val MAE   1.2029 | val MSE   7.1586


Epoch:  99%|█████████▉| 99/100 [1:26:46<00:52, 52.82s/epoch]

epoch 1 almost done
Epoch 098 | Learning rate 0.000000 | train normalized MSE  12.3710 | val normalized MSE  13.4496, | val MAE   1.2027 | val MSE   7.1560


Epoch: 100%|██████████| 100/100 [1:27:39<00:00, 52.59s/epoch]

epoch 1 almost done
Epoch 099 | Learning rate 0.000000 | train normalized MSE  12.3581 | val normalized MSE  13.4495, | val MAE   1.2025 | val MSE   7.1560





In [14]:
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=lambda xs: Batch.from_data_list(xs))

best_model = torch.load("best_model.pt")
#model = LinearRegressionModel().to(device)
# model = MLP(50 * 50 * 6, 60 * 2).to(device)
model = AgentAwarePredictor().to(device)

model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred_norm = model(batch)
        
        # Reshape the prediction to (N, 60, 2)
        pred = pred_norm * batch.scale.view(-1,1,1) + batch.origin.unsqueeze(1)
        pred_list.append(pred.cpu().numpy())
pred_list = np.concatenate(pred_list, axis=0)  # (N,60,2)
pred_output = pred_list.reshape(-1, 2)  # (N*60, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('submission.csv', index=True, float_format='%.8f')