In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import os
import torch
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from tqdm import tqdm
import torch.nn.functional as F


In [2]:
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Apple GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("CUDA GPU")
else:
    device = torch.device('cpu')

Apple GPU


In [3]:
def getData(path):
    train_file = np.load(path+"/train.npz")
    train_data = train_file['data']
    test_file = np.load(path+"/test_input.npz")
    test_data = test_file['data']
    print(f"Training Data's shape is {train_data.shape} and Test Data's is {test_data.shape}")
    return train_data, test_data
trainData, testData = getData("./data/")

Training Data's shape is (10000, 50, 110, 6) and Test Data's is (2100, 50, 50, 6)


In [None]:
class WindowedNormalizedDataset(Dataset):
    def __init__(self, data, scale=10.0):
        self.data = data  
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx].copy()  

        presence = (scene[..., 0] != 0) | (scene[..., 1] != 0)  

        origin = scene[0, 49].copy()
        tx, ty, _, _, theta, _ = origin

        cos_theta = np.cos(-theta)
        sin_theta = np.sin(-theta)

        normalized_scene = np.zeros((50, 110, 7), dtype=np.float32)

        # --- Feature Normalization + Scaling ---
        x = scene[..., 0] - tx
        y = scene[..., 1] - ty
        x_n = x * cos_theta - y * sin_theta
        y_n = x * sin_theta + y * cos_theta
        normalized_scene[..., 0] = x_n / self.scale
        normalized_scene[..., 1] = y_n / self.scale

        vx = scene[..., 2]
        vy = scene[..., 3]
        vx_n = vx * cos_theta - vy * sin_theta
        vy_n = vx * sin_theta + vy * cos_theta
        normalized_scene[..., 2] = vx_n / self.scale
        normalized_scene[..., 3] = vy_n / self.scale

        heading = scene[..., 4]
        normalized_heading = heading - theta
        normalized_heading = (normalized_heading + np.pi) % (2 * np.pi) - np.pi
        normalized_scene[..., 4] = normalized_heading

        normalized_scene[..., 5] = scene[..., 5]  
        normalized_scene[..., 6] = presence.astype(np.float32)

        missing_mask = np.expand_dims(~presence, -1)
        normalized_scene[..., :5] = np.where(missing_mask, 0, normalized_scene[..., :5])

        X = normalized_scene[:, :50, :]
        ego_future = normalized_scene[0, 50:]
        Y = np.zeros((60, 3), dtype=np.float32)
        Y[:, :2] = ego_future[:, :2]
        Y[:, 2] = ego_future[:, 6]

        return (
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(Y, dtype=torch.float32),
            torch.tensor(origin, dtype=torch.float32)
        )


In [None]:
class WindowedNormalizedDataset(Dataset):
    def __init__(self, data, scale=10.0):
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx].copy()
        presence = (scene[..., 0] != 0) | (scene[..., 1] != 0)

        origin = scene[0, 49].copy()
        tx, ty, _, _, theta, _ = origin
        cos_theta = np.cos(-theta)
        sin_theta = np.sin(-theta)

        normalized_scene = np.zeros((50, 110, 13), dtype=np.float32)

        # Positions
        x = scene[..., 0] - tx
        y = scene[..., 1] - ty
        x_n = x * cos_theta - y * sin_theta
        y_n = x * sin_theta + y * cos_theta
        normalized_scene[..., 0] = x_n / self.scale
        normalized_scene[..., 1] = y_n / self.scale
        
        # Velocities
        vx = scene[..., 2]
        vy = scene[..., 3]
        vx_n = vx * cos_theta - vy * sin_theta
        vy_n = vx * sin_theta + vy * cos_theta
        normalized_scene[..., 2] = vx_n / self.scale
        normalized_scene[..., 3] = vy_n / self.scale
        
        # Heading
        heading = scene[..., 4]
        normalized_heading = heading - theta
        normalized_heading = (normalized_heading + np.pi) % (2 * np.pi) - np.pi
        normalized_scene[..., 4] = normalized_heading
        
        # Agent Type & Presence
        normalized_scene[..., 5] = scene[..., 5]
        normalized_scene[..., 6] = presence.astype(np.float32)

        # --- New Feature 1: Acceleration Magnitude (Index 7) ---
        vx_norm = normalized_scene[..., 2]
        vy_norm = normalized_scene[..., 3]
        ax = np.zeros_like(vx_norm)
        ay = np.zeros_like(vy_norm)
        ax[:, 1:] = (vx_norm[:, 1:] - vx_norm[:, :-1]) * 10.0  # 10Hz data
        ay[:, 1:] = (vy_norm[:, 1:] - vy_norm[:, :-1]) * 10.0
        acc_mag = np.sqrt(ax**2 + ay**2)
        normalized_scene[..., 7] = acc_mag

        # --- New Feature 2: Relative Heading to Ego (Index 8) ---
        ego_heading = normalized_scene[0, :, 4]  
        agent_heading = normalized_scene[..., 4] 
        relative_heading = agent_heading - ego_heading[np.newaxis, :]
        relative_heading = (relative_heading + np.pi) % (2 * np.pi) - np.pi
        normalized_scene[..., 8] = relative_heading

        # --- New Feature 3: Interaction Criticality (Index 9) ---
        dist_to_ego = np.linalg.norm(scene[..., :2] - scene[0, :, :2], axis=-1)
        normalized_scene[..., 10] = dist_to_ego / self.scale
        
        # Compute time-to-collision proxy
        rel_vel = np.linalg.norm(scene[..., 2:4] - scene[0, :, 2:4], axis=-1)
        ttc_proxy = np.zeros_like(dist_to_ego)
        valid_mask = (dist_to_ego > 0.1) & (rel_vel > 0.1)
        ttc_proxy[valid_mask] = dist_to_ego[valid_mask] / rel_vel[valid_mask]
        normalized_scene[..., 9] = np.clip(ttc_proxy, 0, 10)  # Cap at 10s

        # --- Existing Features ---
        normalized_scene[..., 11] = np.sqrt(vx**2 + vy**2) / self.scale  
        normalized_scene[..., 12] = (np.arange(50)[:, None] == 0).astype(np.float32)  
        
        # Mask invalid timesteps
        missing_mask = np.expand_dims(~presence, -1)
        normalized_scene[..., :13] = np.where(missing_mask, 0, normalized_scene[..., :14])

        # Inputs: first 50 timesteps
        X = normalized_scene[:, :50, :]  

        # Target: ego agent's future positions and presence
        ego_future = normalized_scene[0, 50:]
        Y = np.zeros((60, 3), dtype=np.float32)
        Y[:, :2] = ego_future[:, :2]
        Y[:, 2] = ego_future[:, 6]  

        return (
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(Y, dtype=torch.float32),
            torch.tensor(origin, dtype=torch.float32)
        )

In [None]:
class WindowedNormalizedTestDataset(Dataset):
    def __init__(self, data, scale):
        self.data = data  
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx].copy()  
        
        # Create presence flag: 1 if agent present, 0 if missing (both x and y are zero)
        presence = (scene[..., 0] != 0) | (scene[..., 1] != 0)  
        
        # Get reference values from ego vehicle (agent 0) at t=49
        origin = scene[0, 49].copy()  
        tx, ty, _, _, theta, _ = origin  
        
        # Compute rotation matrix for -theta (align ego heading to x-axis)
        cos_theta = np.cos(-theta)
        sin_theta = np.sin(-theta)
        
        # Initialize normalized scene with 7 features
        normalized_scene = np.zeros((50, 50, 7), dtype=np.float32)
        
        # --- Feature Normalization ---
        x = scene[..., 0] - tx
        y = scene[..., 1] - ty
        normalized_scene[..., 0] = x * cos_theta - y * sin_theta  
        normalized_scene[..., 1] = x * sin_theta + y * cos_theta  
        normalized_scene[..., 0] = normalized_scene[..., 0]/ self.scale
        normalized_scene[..., 1] = normalized_scene[..., 1]/ self.scale
        
        # Velocities
        vx = scene[..., 2]
        vy = scene[..., 3]
        normalized_scene[..., 2] = vx * cos_theta - vy * sin_theta  
        normalized_scene[..., 3] = vx * sin_theta + vy * cos_theta  
        normalized_scene[..., 2] = normalized_scene[..., 2] / self.scale
        normalized_scene[..., 3] = normalized_scene[..., 3] / self.scale
        
        # Heading (normalize relative to ego and wrap to [-π, π])
        heading = scene[..., 4]
        normalized_heading = heading - theta
        normalized_heading = (normalized_heading + np.pi) % (2 * np.pi) - np.pi
        normalized_scene[..., 4] = normalized_heading
        
        # Agent type (unchanged) and presence flag
        normalized_scene[..., 5] = scene[..., 5]  
        normalized_scene[..., 6] = presence.astype(np.float32)  
        
        # Reset features for missing agents (where presence=0)
        missing_mask = np.expand_dims(~presence, -1)  
        normalized_scene[..., :5] = np.where(missing_mask, 0, normalized_scene[..., :5])
        
        # Split into input (0-49) and target (50-109)
        X = normalized_scene[:, :50, :]  
        
        return (
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(origin, dtype=torch.float32)
        )

def denormalize_ego(predicted, origin):
    """
    Convert normalized ego predictions back to global coordinates
    predicted: (..., 2) tensor of normalized [x, y] positions
    origin: (6,) tensor of ego's reference state at t=49
    """
    tx, ty, _, _, theta, _ = origin
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    
    x_rot = predicted[..., 0] * cos_theta - predicted[..., 1] * sin_theta
    y_rot = predicted[..., 0] * sin_theta + predicted[..., 1] * cos_theta
    
    x_global = x_rot + tx
    y_global = y_rot + ty
    
    return torch.stack([x_global, y_global], dim=-1)

In [None]:
class WindowedNormalizedTestDataset(Dataset):
    def __init__(self, data, scale=10.0):
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx].copy()
        presence = (scene[..., 0] != 0) | (scene[..., 1] != 0)

        origin = scene[0, 49].copy()
        tx, ty, _, _, theta, _ = origin
        cos_theta = np.cos(-theta)
        sin_theta = np.sin(-theta)

        normalized_scene = np.zeros((50, 50, 13), dtype=np.float32)

        # Positions
        x = scene[..., 0] - tx
        y = scene[..., 1] - ty
        x_n = x * cos_theta - y * sin_theta
        y_n = x * sin_theta + y * cos_theta
        normalized_scene[..., 0] = x_n / self.scale
        normalized_scene[..., 1] = y_n / self.scale
        
        # Velocities
        vx = scene[..., 2]
        vy = scene[..., 3]
        vx_n = vx * cos_theta - vy * sin_theta
        vy_n = vx * sin_theta + vy * cos_theta
        normalized_scene[..., 2] = vx_n / self.scale
        normalized_scene[..., 3] = vy_n / self.scale
        
        # Heading
        heading = scene[..., 4]
        normalized_heading = heading - theta
        normalized_heading = (normalized_heading + np.pi) % (2 * np.pi) - np.pi
        normalized_scene[..., 4] = normalized_heading
        
        # Agent Type & Presence
        normalized_scene[..., 5] = scene[..., 5]
        normalized_scene[..., 6] = presence.astype(np.float32)

        # --- New Feature 1: Acceleration Magnitude (Index 7) ---
        vx_norm = normalized_scene[..., 2]
        vy_norm = normalized_scene[..., 3]
        ax = np.zeros_like(vx_norm)
        ay = np.zeros_like(vy_norm)
        ax[:, 1:] = (vx_norm[:, 1:] - vx_norm[:, :-1]) * 10.0  # 10Hz data
        ay[:, 1:] = (vy_norm[:, 1:] - vy_norm[:, :-1]) * 10.0
        acc_mag = np.sqrt(ax**2 + ay**2)
        normalized_scene[..., 7] = acc_mag

        # --- New Feature 2: Relative Heading to Ego (Index 8) ---
        ego_heading = normalized_scene[0, :, 4]  
        agent_heading = normalized_scene[..., 4]  
        relative_heading = agent_heading - ego_heading[np.newaxis, :]
        relative_heading = (relative_heading + np.pi) % (2 * np.pi) - np.pi
        normalized_scene[..., 8] = relative_heading

        # --- New Feature 3: Interaction Criticality (Index 9) ---
        dist_to_ego = np.linalg.norm(scene[..., :2] - scene[0, :, :2], axis=-1)
        normalized_scene[..., 10] = dist_to_ego / self.scale
        
        # Compute time-to-collision proxy
        rel_vel = np.linalg.norm(scene[..., 2:4] - scene[0, :, 2:4], axis=-1)
        ttc_proxy = np.zeros_like(dist_to_ego)
        valid_mask = (dist_to_ego > 0.1) & (rel_vel > 0.1)
        ttc_proxy[valid_mask] = dist_to_ego[valid_mask] / rel_vel[valid_mask]
        normalized_scene[..., 9] = np.clip(ttc_proxy, 0, 10)

        # --- Existing Features ---
        normalized_scene[..., 11] = np.sqrt(vx**2 + vy**2) / self.scale  
        normalized_scene[..., 12] = (np.arange(50)[:, None] == 0).astype(np.float32)  

        # Mask invalid timesteps
        missing_mask = np.expand_dims(~presence, -1)
        normalized_scene[..., :13] = np.where(missing_mask, 0, normalized_scene[..., :14])

        # Inputs: first 50 timesteps
        X = normalized_scene[:, :50, :]  

        return (
            torch.tensor(X, dtype=torch.float32),
            torch.tensor(origin, dtype=torch.float32)
        )

In [None]:
def denormalize_ego_batch(predicted, origin, scale=10.0):
    """
    Convert batch of normalized (and scaled) ego predictions back to global coordinates.

    predicted: (B, ..., 2) tensor of normalized [x, y] positions
    origin: (B, 6) tensor of ego's reference state at t=49
    Returns:
        (B, ..., 2) tensor of global [x, y] positions
    """
    tx = origin[:, 0]  
    ty = origin[:, 1]  
    theta = origin[:, 4]  

    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)

    # Expand for broadcasting
    while len(cos_theta.shape) < len(predicted.shape) - 1:
        cos_theta = cos_theta.unsqueeze(1)
        sin_theta = sin_theta.unsqueeze(1)
        tx = tx.unsqueeze(1)
        ty = ty.unsqueeze(1)

    x = predicted[..., 0] * scale
    y = predicted[..., 1] * scale

    x_rot = x * cos_theta - y * sin_theta
    y_rot = x * sin_theta + y * cos_theta

    x_global = x_rot + tx
    y_global = y_rot + ty

    return torch.stack([x_global, y_global], dim=-1)


In [22]:
trainData[1, 0, 49, :], trainData[1, 0, 50, :]

(array([ 3.16906469e+03,  1.68248551e+03,  5.46145515e+00, -5.85380650e+00,
        -8.22467566e-01,  0.00000000e+00]),
 array([ 3.16959927e+03,  1.68191109e+03,  5.35655550e+00, -5.75120145e+00,
        -8.22600550e-01,  0.00000000e+00]))

In [23]:
data = WindowedNormalizedDataset(trainData)
X, Y, origin = data.__getitem__(1)
X[0, 49, :], Y[0, :], origin.shape

(tensor([0.0000, 0.0000, 0.8006, 0.0019, 0.0000, 0.0000, 1.0000, 0.0572, 0.0000,
         0.0000, 0.0000, 0.8006, 1.0000]),
 tensor([7.8468e-02, 9.1270e-05, 1.0000e+00]),
 torch.Size([6]))

In [24]:
# x, y = denormalize_ego(Y[0, :2], origin)
# x, y

In [None]:
import torch
import torch.nn as nn

class TrajectoryTransformer(nn.Module):
    def __init__(self, input_dim=650, model_dim=256, num_heads=8, num_layers=6, dropout=0.1, pred_len=60, num_agents=50):
        super().__init__()
        self.model_dim = model_dim
        self.pred_len = pred_len
        self.num_agents = num_agents
        
        self.trajectory_encoder = nn.Sequential(
            nn.Linear(input_dim, model_dim),
            nn.LayerNorm(model_dim),
            nn.ReLU(),
            nn.Linear(model_dim, model_dim),
            nn.LayerNorm(model_dim),
            nn.ReLU(),
            nn.Linear(model_dim, model_dim),
            nn.LayerNorm(model_dim),
            nn.ReLU()
        )
        
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=model_dim, 
                nhead=num_heads, 
                dropout=dropout, 
                batch_first=True
            ),
            num_layers=num_layers
        )
        
        self.output_fcpre = nn.Linear(model_dim, model_dim)  
        self.output_fc = nn.Linear(model_dim, pred_len * 2)  
    
    def forward(self, x):
        B, N, T, Ft = x.shape
        x = x.view(B, N, T * Ft)  
        agent_tokens = self.trajectory_encoder(x)  
        encoded_tokens = self.transformer_encoder(agent_tokens)  
        ego_token = encoded_tokens[:, 0, :]  
        output = F.relu(self.output_fcpre(ego_token))  
        output = self.output_fc(output)  
        output = output.view(B, self.pred_len, 2) 
        
        return output

model = TrajectoryTransformer()
x = torch.randn(1, 50, 50, 13)  
out = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")  

print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

Input shape: torch.Size([1, 50, 50, 13])
Output shape: torch.Size([1, 60, 2])

Model parameters: 8,286,840


In [26]:
model = TrajectoryTransformer().to(device=device)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 8286840


In [None]:
np.random.seed(42)
num_samples = trainData.shape[0]
indices = np.random.permutation(num_samples)
split_index = int(0.9 * num_samples)
train_idx, val_idx = indices[:split_index], indices[split_index:]

train_data = trainData[train_idx]
val_data = trainData[val_idx]

print("Train shape:", train_data.shape)
print("Validation shape:", val_data.shape)

Train shape: (9000, 50, 110, 6)
Validation shape: (1000, 50, 110, 6)


In [28]:
trainTensor = WindowedNormalizedDataset(train_data)
testTensor = WindowedNormalizedDataset(val_data)
train_dataloader = DataLoader(trainTensor, batch_size=128, shuffle=True)
val_dataloader = DataLoader(testTensor, batch_size=128, shuffle=False)

In [29]:
import torch
import torch.nn as nn

class WeightedMSELoss(nn.Module):
    def __init__(self, weights):
        super(WeightedMSELoss, self).__init__()
        self.weights = weights

    def forward(self, input, target):
        loss = (input - target) ** 2
        weighted_loss = loss * self.weights
        return weighted_loss.mean()

feature_weights = torch.tensor([1.0, 1.0, 0.1, 0.1, 0.01, 0.01], device=device)


In [None]:
# export PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
#  train MSE 0.0130673476 | train val MSE 0.0830958160 | val MAE 4.7582098953 | val MSE 8.3095822744

torch.cuda.empty_cache()

best_model = torch.load("./models/final/best_model.pt")
model.load_state_dict(best_model)

epochs = 1000
lossFn = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.25)
best_val_loss = 0.0830958160 # float('inf')
best_train_loss = 0.0130673476 # float('inf')
position_scale = 1.0
velocity_scale = 1.0
all_losses = {
    'training_mse_loss':[],
    'validation_mse_loss':[],
    'true_mse':[],
    'true_mae':[]
}

for each_epoch in range(epochs):
    model.train()
    runningLoss = 0.0
    loop = tqdm(train_dataloader, desc=f"Epoch [{each_epoch+1}/{epochs}]")
    
    for batchX, batchY, origin in loop:
        batchX = batchX.to(device)
        batchY = batchY.to(device)
        origin = origin.to(device)

        
        pred = model(batchX)  
        
        loss = lossFn(pred[..., :2], batchY[..., :2]).to(device)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        runningLoss += loss.item()        
    
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    
    with torch.no_grad():
        for batchX, batchY, origin in loop:
            batchX = batchX.to(device)
            batchY = batchY.to(device)
            origin = origin.to(device)

            
            pred = model(batchX)  
            
            loss = lossFn(pred[..., :2], batchY[..., :2]).to(device)
            unnorm_pred = denormalize_ego_batch(pred, origin)
            unnorm_true = denormalize_ego_batch(batchY, origin)


            
            val_loss += loss.item()
            val_mae += nn.L1Loss()(unnorm_pred[..., :2], unnorm_true[..., :2]).item()
            val_mse += nn.MSELoss()(unnorm_pred[..., :2], unnorm_true[..., :2]).item()
    # break
    train_loss = runningLoss/len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    
    all_losses["training_mse_loss"].append(train_loss)
    all_losses["validation_mse_loss"].append(val_loss)
    all_losses["true_mse"].append(val_mse)
    all_losses["true_mae"].append(val_mae)
    
    loop.write(f" train MSE {train_loss:.10f} | train val MSE {val_loss:.10f} | val MAE {val_mae:.10f} | val MSE {val_mse:.10f}")
    scheduler.step()
    
    if train_loss < best_train_loss and val_loss < best_val_loss: 
        best_val_loss = val_loss
        best_train_loss = train_loss
        no_improvement = 0
        torch.save(model.state_dict(), "./models/final/best_model.pt")
        loop.write(f" model Saved")
    torch.cuda.empty_cache()


Epoch [1/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.73it/s]


 train MSE 0.0499150728 | train val MSE 0.2085464783 | val MAE 7.5795354098 | val MSE 20.8546437919


Epoch [2/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.72it/s]


 train MSE 0.0311552232 | train val MSE 0.2784796467 | val MAE 8.8686763421 | val MSE 27.8479639292


Epoch [3/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.73it/s]


 train MSE 0.0319098578 | train val MSE 0.1992247645 | val MAE 7.2741061896 | val MSE 19.9224709123


Epoch [4/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.84it/s]


 train MSE 0.0321231385 | train val MSE 0.2626639118 | val MAE 8.5264174789 | val MSE 26.2663877606


Epoch [5/1000]: 100%|██████████| 71/71 [00:24<00:00,  2.85it/s]


 train MSE 0.0379377645 | train val MSE 0.2547864101 | val MAE 8.2255081087 | val MSE 25.4786435366


Epoch [6/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.77it/s]


 train MSE 0.0287870184 | train val MSE 0.1966206619 | val MAE 7.3442446142 | val MSE 19.6620670855


Epoch [7/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.67it/s]


 train MSE 0.0321630158 | train val MSE 0.2204909322 | val MAE 7.7978917137 | val MSE 22.0490975827


Epoch [8/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.63it/s]


 train MSE 0.0283659850 | train val MSE 0.1968259843 | val MAE 7.3824988157 | val MSE 19.6826030314


Epoch [9/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.51it/s]


 train MSE 0.0259744003 | train val MSE 0.2080170959 | val MAE 7.4894277677 | val MSE 20.8017103225


Epoch [10/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.54it/s]


 train MSE 0.0264493448 | train val MSE 0.2232943475 | val MAE 8.0467425510 | val MSE 22.3294377923


Epoch [11/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.62it/s]


 train MSE 0.0351753289 | train val MSE 0.1481648892 | val MAE 6.3130003065 | val MSE 14.8164912760


Epoch [12/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.65it/s]


 train MSE 0.0271189914 | train val MSE 0.1685080252 | val MAE 6.9895960391 | val MSE 16.8508049399


Epoch [13/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0237936873 | train val MSE 0.2127864058 | val MAE 7.6560075879 | val MSE 21.2786399424


Epoch [14/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.72it/s]


 train MSE 0.0226731807 | train val MSE 0.2004004067 | val MAE 7.8564288393 | val MSE 20.0400392413


Epoch [15/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.70it/s]


 train MSE 0.0204320759 | train val MSE 0.1530469658 | val MAE 6.2972009182 | val MSE 15.3046978116


Epoch [16/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.78it/s]


 train MSE 0.0228313403 | train val MSE 0.5201866589 | val MAE 12.2006958425 | val MSE 52.0186747909


Epoch [17/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.73it/s]


 train MSE 0.0222364129 | train val MSE 0.1053062193 | val MAE 5.3975016996 | val MSE 10.5306211486


Epoch [18/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.77it/s]


 train MSE 0.0206663434 | train val MSE 0.1774998656 | val MAE 6.8616340607 | val MSE 17.7499851137


Epoch [19/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.74it/s]


 train MSE 0.0206794469 | train val MSE 0.1282569980 | val MAE 5.7689941525 | val MSE 12.8257020190


Epoch [20/1000]: 100%|██████████| 71/71 [00:24<00:00,  2.88it/s]


 train MSE 0.0232882001 | train val MSE 0.1404173720 | val MAE 6.0200119540 | val MSE 14.0417369977


Epoch [21/1000]: 100%|██████████| 71/71 [00:24<00:00,  2.88it/s]


 train MSE 0.0134837750 | train val MSE 0.0715988021 | val MAE 4.3877554908 | val MSE 7.1598815620


Epoch [22/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.79it/s]


 train MSE 0.0107979802 | train val MSE 0.0662451473 | val MAE 4.1897197515 | val MSE 6.6245144457
 model Saved


Epoch [23/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.69it/s]


 train MSE 0.0102114081 | train val MSE 0.0682543400 | val MAE 4.3814989775 | val MSE 6.8254322819


Epoch [24/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.61it/s]


 train MSE 0.0096481100 | train val MSE 0.0640045556 | val MAE 4.2222650312 | val MSE 6.4004565626
 model Saved


Epoch [25/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.48it/s]


 train MSE 0.0091968604 | train val MSE 0.0584511136 | val MAE 4.0444317572 | val MSE 5.8451104797
 model Saved


Epoch [26/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.47it/s]


 train MSE 0.0090340535 | train val MSE 0.0607732341 | val MAE 4.1522447132 | val MSE 6.0773239397


Epoch [27/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.49it/s]


 train MSE 0.0087572588 | train val MSE 0.0623244427 | val MAE 4.2121823132 | val MSE 6.2324455343


Epoch [28/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.47it/s]


 train MSE 0.0086746623 | train val MSE 0.0557554971 | val MAE 3.9709502161 | val MSE 5.5755517855
 model Saved


Epoch [29/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.57it/s]


 train MSE 0.0085592406 | train val MSE 0.0545682913 | val MAE 3.9547514543 | val MSE 5.4568292648
 model Saved


Epoch [30/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.65it/s]


 train MSE 0.0083907081 | train val MSE 0.0535582034 | val MAE 3.8835042194 | val MSE 5.3558222465
 model Saved


Epoch [31/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.68it/s]


 train MSE 0.0081994923 | train val MSE 0.0516128750 | val MAE 3.7734197341 | val MSE 5.1612875573
 model Saved


Epoch [32/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.71it/s]


 train MSE 0.0080672345 | train val MSE 0.0495765212 | val MAE 3.6975575462 | val MSE 4.9576517344
 model Saved


Epoch [33/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.68it/s]


 train MSE 0.0081818507 | train val MSE 0.0559888261 | val MAE 4.0095484108 | val MSE 5.5988819189


Epoch [34/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.69it/s]


 train MSE 0.0080354988 | train val MSE 0.0492692125 | val MAE 3.6695245057 | val MSE 4.9269230627
 model Saved


Epoch [35/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.71it/s]


 train MSE 0.0076277808 | train val MSE 0.0518086400 | val MAE 3.7889718600 | val MSE 5.1808620207


Epoch [36/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.70it/s]


 train MSE 0.0077144813 | train val MSE 0.0529370919 | val MAE 3.8534200527 | val MSE 5.2937107980


Epoch [37/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.68it/s]


 train MSE 0.0076629672 | train val MSE 0.0611232058 | val MAE 4.2531394958 | val MSE 6.1123246439


Epoch [38/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.62it/s]


 train MSE 0.0076639126 | train val MSE 0.0489246879 | val MAE 3.6587978266 | val MSE 4.8924697228
 model Saved


Epoch [39/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.62it/s]


 train MSE 0.0076143392 | train val MSE 0.0445210207 | val MAE 3.5023617223 | val MSE 4.4521030411
 model Saved


Epoch [40/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.60it/s]


 train MSE 0.0072894553 | train val MSE 0.0453951674 | val MAE 3.5845779516 | val MSE 4.5395171642


Epoch [41/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.58it/s]


 train MSE 0.0068741515 | train val MSE 0.0407064095 | val MAE 3.3677727804 | val MSE 4.0706410669
 model Saved


Epoch [42/1000]: 100%|██████████| 71/71 [00:29<00:00,  2.44it/s]


 train MSE 0.0065517337 | train val MSE 0.0408393057 | val MAE 3.3855041675 | val MSE 4.0839314237


Epoch [43/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.54it/s]


 train MSE 0.0065581365 | train val MSE 0.0404655494 | val MAE 3.3483489491 | val MSE 4.0465547070
 model Saved


Epoch [44/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.52it/s]


 train MSE 0.0064902451 | train val MSE 0.0421037277 | val MAE 3.4648433588 | val MSE 4.2103737593


Epoch [45/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.61it/s]


 train MSE 0.0064049064 | train val MSE 0.0406975549 | val MAE 3.4100480638 | val MSE 4.0697568469


Epoch [46/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0064093540 | train val MSE 0.0405009829 | val MAE 3.3771473244 | val MSE 4.0500980616


Epoch [47/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.64it/s]


 train MSE 0.0063795413 | train val MSE 0.0393589267 | val MAE 3.3542716987 | val MSE 3.9358947314
 model Saved


Epoch [48/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.63it/s]


 train MSE 0.0062987720 | train val MSE 0.0404428916 | val MAE 3.4313536808 | val MSE 4.0442901812


Epoch [49/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.69it/s]


 train MSE 0.0063023077 | train val MSE 0.0391924966 | val MAE 3.3356545679 | val MSE 3.9192499183
 model Saved


Epoch [50/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.70it/s]


 train MSE 0.0062884005 | train val MSE 0.0377030847 | val MAE 3.2678051479 | val MSE 3.7703085206
 model Saved


Epoch [51/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.68it/s]


 train MSE 0.0062983623 | train val MSE 0.0403422858 | val MAE 3.4208843596 | val MSE 4.0342295356


Epoch [52/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.74it/s]


 train MSE 0.0061263559 | train val MSE 0.0359261111 | val MAE 3.1717117317 | val MSE 3.5926105734
 model Saved


Epoch [53/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.72it/s]


 train MSE 0.0061221950 | train val MSE 0.0407784468 | val MAE 3.4696651809 | val MSE 4.0778456219


Epoch [54/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.67it/s]


 train MSE 0.0060672105 | train val MSE 0.0384974492 | val MAE 3.3345988132 | val MSE 3.8497450761


Epoch [55/1000]: 100%|██████████| 71/71 [00:29<00:00,  2.43it/s]


 train MSE 0.0060371836 | train val MSE 0.0371077524 | val MAE 3.2765251137 | val MSE 3.7107767537


Epoch [56/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.47it/s]


 train MSE 0.0060697496 | train val MSE 0.0366604480 | val MAE 3.2636108212 | val MSE 3.6660452858


Epoch [57/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.51it/s]


 train MSE 0.0060812640 | train val MSE 0.0382514102 | val MAE 3.3426315486 | val MSE 3.8251423948


Epoch [58/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.55it/s]


 train MSE 0.0060052038 | train val MSE 0.0378805633 | val MAE 3.3204964325 | val MSE 3.7880559377


Epoch [59/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0059590684 | train val MSE 0.0382351881 | val MAE 3.3750612140 | val MSE 3.8235196099


Epoch [60/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.64it/s]


 train MSE 0.0060183033 | train val MSE 0.0366433789 | val MAE 3.2838323228 | val MSE 3.6643368416


Epoch [61/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.73it/s]


 train MSE 0.0058714986 | train val MSE 0.0357607437 | val MAE 3.2280724756 | val MSE 3.5760742743
 model Saved


Epoch [62/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.79it/s]


 train MSE 0.0058202881 | train val MSE 0.0367112177 | val MAE 3.2809953429 | val MSE 3.6711219177


Epoch [63/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.84it/s]


 train MSE 0.0058240575 | train val MSE 0.0363525362 | val MAE 3.2646885552 | val MSE 3.6352534816


Epoch [64/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.82it/s]


 train MSE 0.0057537971 | train val MSE 0.0349365822 | val MAE 3.1797295548 | val MSE 3.4936563559
 model Saved


Epoch [65/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.78it/s]


 train MSE 0.0057422791 | train val MSE 0.0348023170 | val MAE 3.1795534529 | val MSE 3.4802315440
 model Saved


Epoch [66/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.80it/s]


 train MSE 0.0057050726 | train val MSE 0.0371449379 | val MAE 3.3252222985 | val MSE 3.7144935988


Epoch [67/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.77it/s]


 train MSE 0.0056541791 | train val MSE 0.0361118315 | val MAE 3.2527700514 | val MSE 3.6111843269


Epoch [68/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.78it/s]


 train MSE 0.0057705613 | train val MSE 0.0348773858 | val MAE 3.1828266084 | val MSE 3.4877390042


Epoch [69/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.73it/s]


 train MSE 0.0057156650 | train val MSE 0.0359220522 | val MAE 3.2443377152 | val MSE 3.5922064912


Epoch [70/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.68it/s]


 train MSE 0.0056513002 | train val MSE 0.0352546585 | val MAE 3.2187677100 | val MSE 3.5254670233


Epoch [71/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0056923823 | train val MSE 0.0332264274 | val MAE 3.0864981823 | val MSE 3.3226445056
 model Saved


Epoch [72/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.61it/s]


 train MSE 0.0056887320 | train val MSE 0.0350831668 | val MAE 3.2066097520 | val MSE 3.5083154123


Epoch [73/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.64it/s]


 train MSE 0.0056652790 | train val MSE 0.0353731815 | val MAE 3.2255336568 | val MSE 3.5373199135


Epoch [74/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.58it/s]


 train MSE 0.0057131986 | train val MSE 0.0345736145 | val MAE 3.1811678782 | val MSE 3.4573620185


Epoch [75/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.61it/s]


 train MSE 0.0056930810 | train val MSE 0.0350116632 | val MAE 3.2094091251 | val MSE 3.5011668988


Epoch [76/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0056552439 | train val MSE 0.0345969949 | val MAE 3.1842822284 | val MSE 3.4597005416


Epoch [77/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.74it/s]


 train MSE 0.0055927766 | train val MSE 0.0339075167 | val MAE 3.1403625906 | val MSE 3.3907507248


Epoch [78/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.74it/s]


 train MSE 0.0056290269 | train val MSE 0.0345782216 | val MAE 3.1928644218 | val MSE 3.4578226060


Epoch [79/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.76it/s]


 train MSE 0.0056444469 | train val MSE 0.0351215554 | val MAE 3.2149368227 | val MSE 3.5121562909


Epoch [80/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.81it/s]


 train MSE 0.0056505493 | train val MSE 0.0336701788 | val MAE 3.1354993694 | val MSE 3.3670193143


Epoch [81/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.75it/s]


 train MSE 0.0057196610 | train val MSE 0.0341555955 | val MAE 3.1676599123 | val MSE 3.4155595340


Epoch [82/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.73it/s]


 train MSE 0.0056178983 | train val MSE 0.0350721831 | val MAE 3.2234403044 | val MSE 3.5072194524


Epoch [83/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.72it/s]


 train MSE 0.0055173596 | train val MSE 0.0350449097 | val MAE 3.2246289998 | val MSE 3.5044926964


Epoch [84/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.77it/s]


 train MSE 0.0055155661 | train val MSE 0.0343794118 | val MAE 3.1790350266 | val MSE 3.4379426036


Epoch [85/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.67it/s]


 train MSE 0.0055091609 | train val MSE 0.0345587326 | val MAE 3.1923380345 | val MSE 3.4558746405


Epoch [86/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.70it/s]


 train MSE 0.0055547809 | train val MSE 0.0344179103 | val MAE 3.1826935261 | val MSE 3.4417921081


Epoch [87/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.64it/s]


 train MSE 0.0055485645 | train val MSE 0.0347556378 | val MAE 3.2053226344 | val MSE 3.4755651932


Epoch [88/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.63it/s]


 train MSE 0.0056063356 | train val MSE 0.0347349923 | val MAE 3.1990989596 | val MSE 3.4734997544


Epoch [89/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.46it/s]


 train MSE 0.0055644850 | train val MSE 0.0347066151 | val MAE 3.2032412291 | val MSE 3.4706627298


Epoch [90/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.48it/s]


 train MSE 0.0055754782 | train val MSE 0.0338923973 | val MAE 3.1565449871 | val MSE 3.3892416451


Epoch [91/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.52it/s]


 train MSE 0.0055872221 | train val MSE 0.0343432598 | val MAE 3.1786103584 | val MSE 3.4343269933


Epoch [92/1000]: 100%|██████████| 71/71 [00:28<00:00,  2.51it/s]


 train MSE 0.0055488601 | train val MSE 0.0348140547 | val MAE 3.2076223940 | val MSE 3.4814042840


Epoch [93/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.61it/s]


 train MSE 0.0056007847 | train val MSE 0.0344219369 | val MAE 3.1869321242 | val MSE 3.4421947133


Epoch [94/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.65it/s]


 train MSE 0.0055687946 | train val MSE 0.0343721806 | val MAE 3.1885647215 | val MSE 3.4372180067


Epoch [95/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.75it/s]


 train MSE 0.0055766211 | train val MSE 0.0344083933 | val MAE 3.1899268255 | val MSE 3.4408396408


Epoch [96/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.78it/s]


 train MSE 0.0055538434 | train val MSE 0.0333857352 | val MAE 3.1215149984 | val MSE 3.3385745734


Epoch [97/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.77it/s]


 train MSE 0.0055185818 | train val MSE 0.0341104180 | val MAE 3.1650780141 | val MSE 3.4110419322


Epoch [98/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.80it/s]


 train MSE 0.0055013598 | train val MSE 0.0344961378 | val MAE 3.1918270811 | val MSE 3.4496126268


Epoch [99/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.71it/s]


 train MSE 0.0055566135 | train val MSE 0.0342742519 | val MAE 3.1825122423 | val MSE 3.4274275955


Epoch [100/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.72it/s]


 train MSE 0.0055247376 | train val MSE 0.0346040342 | val MAE 3.1982904524 | val MSE 3.4604040626


Epoch [101/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.82it/s]


 train MSE 0.0055407872 | train val MSE 0.0336217399 | val MAE 3.1422905661 | val MSE 3.3621756770


Epoch [102/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.67it/s]


 train MSE 0.0054873598 | train val MSE 0.0341218498 | val MAE 3.1602436006 | val MSE 3.4121856224


Epoch [103/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.68it/s]


 train MSE 0.0055850397 | train val MSE 0.0342311621 | val MAE 3.1803673990 | val MSE 3.4231150597


Epoch [104/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0054906128 | train val MSE 0.0340910267 | val MAE 3.1724495329 | val MSE 3.4091035631


Epoch [105/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.60it/s]


 train MSE 0.0055009535 | train val MSE 0.0337287530 | val MAE 3.1504543647 | val MSE 3.3728759978


Epoch [106/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.63it/s]


 train MSE 0.0055849169 | train val MSE 0.0337669071 | val MAE 3.1461553052 | val MSE 3.3766911402


Epoch [107/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.63it/s]


 train MSE 0.0056056242 | train val MSE 0.0336879973 | val MAE 3.1416700222 | val MSE 3.3688021991


Epoch [108/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.55it/s]


 train MSE 0.0055869456 | train val MSE 0.0339136356 | val MAE 3.1600344926 | val MSE 3.3913653642


Epoch [109/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.55it/s]


 train MSE 0.0055152936 | train val MSE 0.0335157636 | val MAE 3.1378103010 | val MSE 3.3515762743


Epoch [110/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.71it/s]


 train MSE 0.0054965863 | train val MSE 0.0339461717 | val MAE 3.1632740796 | val MSE 3.3946172800


Epoch [111/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0055687800 | train val MSE 0.0338007931 | val MAE 3.1524358504 | val MSE 3.3800800350


Epoch [112/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.71it/s]


 train MSE 0.0055267752 | train val MSE 0.0339765864 | val MAE 3.1644526273 | val MSE 3.3976596557


Epoch [113/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.75it/s]


 train MSE 0.0055526871 | train val MSE 0.0338395276 | val MAE 3.1565872692 | val MSE 3.3839541450


Epoch [114/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.68it/s]


 train MSE 0.0054646755 | train val MSE 0.0337240922 | val MAE 3.1486219466 | val MSE 3.3724094629


Epoch [115/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.75it/s]


 train MSE 0.0055300910 | train val MSE 0.0348503705 | val MAE 3.1906350888 | val MSE 3.4850366004


Epoch [116/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.79it/s]


 train MSE 0.0055345767 | train val MSE 0.0339068895 | val MAE 3.1557714306 | val MSE 3.3906900138


Epoch [117/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.74it/s]


 train MSE 0.0055018857 | train val MSE 0.0339874201 | val MAE 3.1657439098 | val MSE 3.3987418339


Epoch [118/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.78it/s]


 train MSE 0.0055613803 | train val MSE 0.0338442372 | val MAE 3.1604656838 | val MSE 3.3844262976


Epoch [119/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.79it/s]


 train MSE 0.0055247911 | train val MSE 0.0337233819 | val MAE 3.1509701572 | val MSE 3.3723400328


Epoch [120/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.82it/s]


 train MSE 0.0055371978 | train val MSE 0.0333723575 | val MAE 3.1319328286 | val MSE 3.3372361511


Epoch [121/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.80it/s]


 train MSE 0.0055350082 | train val MSE 0.0339026868 | val MAE 3.1604007147 | val MSE 3.3902699426


Epoch [122/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.62it/s]


 train MSE 0.0055373858 | train val MSE 0.0336689116 | val MAE 3.1485008560 | val MSE 3.3668921404


Epoch [123/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.66it/s]


 train MSE 0.0055079571 | train val MSE 0.0338163796 | val MAE 3.1575809456 | val MSE 3.3816382550


Epoch [124/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.65it/s]


 train MSE 0.0054710306 | train val MSE 0.0338666409 | val MAE 3.1591316052 | val MSE 3.3866644520


Epoch [125/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.64it/s]


 train MSE 0.0055630891 | train val MSE 0.0337792249 | val MAE 3.1557580642 | val MSE 3.3779218439


Epoch [126/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.64it/s]


 train MSE 0.0054957325 | train val MSE 0.0338985292 | val MAE 3.1618039832 | val MSE 3.3898538258


Epoch [127/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.62it/s]


 train MSE 0.0055310168 | train val MSE 0.0339800441 | val MAE 3.1637682579 | val MSE 3.3980047442


Epoch [128/1000]: 100%|██████████| 71/71 [00:27<00:00,  2.61it/s]


 train MSE 0.0055896784 | train val MSE 0.0337952654 | val MAE 3.1539724357 | val MSE 3.3795282040


Epoch [129/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.71it/s]


 train MSE 0.0056017904 | train val MSE 0.0338860804 | val MAE 3.1564804614 | val MSE 3.3886080291


Epoch [130/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.74it/s]


 train MSE 0.0055249263 | train val MSE 0.0337566743 | val MAE 3.1519104950 | val MSE 3.3756677713


Epoch [131/1000]: 100%|██████████| 71/71 [00:25<00:00,  2.78it/s]


 train MSE 0.0055931962 | train val MSE 0.0337440327 | val MAE 3.1509329565 | val MSE 3.3744038045


Epoch [132/1000]: 100%|██████████| 71/71 [00:26<00:00,  2.73it/s]


 train MSE 0.0055336726 | train val MSE 0.0336720448 | val MAE 3.1481382884 | val MSE 3.3672049008


Epoch [133/1000]:  39%|███▉      | 28/71 [00:11<00:17,  2.45it/s]


KeyboardInterrupt: 

In [None]:
test_dataset = WindowedNormalizedTestDataset(testData)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)


best_model = torch.load("./models/final/best_model.pt")
model = model = TrajectoryTransformer().to(device=device)
# model = model = TrajectoryTransformerPlus().to(device=device)


model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batchX, origin in test_loader:
        batchX = batchX.to(device)
        batchY = batchY.to(device)
        origin = origin.to(device)

        
        pred = model(batchX)  
        
        unnorm_pred = denormalize_ego_batch(pred[..., :2], origin)
        # print(unnorm_pred.shape)
        pred_list.append(unnorm_pred.cpu().numpy())
        # print(len(pred))
        

pred_list = np.concatenate(pred_list, axis=0)  
pred_output = pred_list.reshape(-1, 2)  
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('./models/modelI/testTransFormer.csv', index=True)

In [None]:
#  train MSE 0.0058623942 | train val MSE 0.0470800689 | val MAE 3.5586496554 | val MSE 4.7080054618  --- Test: 8.5
#  train MSE 0.0027955156 | train val MSE 0.0230247062 | val MAE 2.6988429800 | val MSE 2.3024725225 ---- Test: 8.43
