In [1]:
import os
import random
import json
from datetime import datetime
from pathlib import Path
import numpy as np
from hydra.utils import instantiate, to_absolute_path
import torch
from importlib import import_module
import pytorch_lightning as pl

teacher_dict = {
    "_target_": "src.model.trainer_forecast.Trainer",
    "historical_steps": 50,
    "future_steps": 60,
    "dim": 128,
    "encoder_depth": 4,
    "num_heads": 8,
    "mlp_ratio": 4.0,
    "qkv_bias": False,
    "drop_path": 0.2,
    "attention_type": "standard",  # "standard", "linear", "performer"
    "decoder": "mlp",  # Decoder type: "mlp" or "detr"
    "decoder_embed_dim": None,
    "decoder_num_modes": None,
    "decoder_hidden_dim": None,
    "pretrained_weights": None,
    "lr": 0.001,
    "weight_decay": 1e-4,
    "epochs": 20,
    "warmup_epochs": 10,
}

if teacher_dict["decoder"] == "mlp":
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empm.ckpt"
else:
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empd.ckpt"
assert os.path.exists(checkpoint), f"Checkpoint {checkpoint} does not exist"

model_path = teacher_dict["_target_"]
module = import_module(model_path[: model_path.rfind(".")])
Model: pl.LightningModule = getattr(module, model_path[model_path.rfind(".") + 1 :])
teacher = Model.load_from_checkpoint(checkpoint, **teacher_dict)

# Cut the layer that gives the logits and keep the "penultimate" layer
teacher.net.dense_predictor.pop(-1)  
teacher.net.decoder.loc.pop(-1)
teacher.net.decoder.pi.pop(-1)


  from .autonotebook import tqdm as notebook_tqdm


Using decoder type: mlp
RUN EMP-M


Linear(in_features=128, out_features=1, bias=True)

In [2]:
# Print the model layers
print("Teacher model instantiated with the following configuration:")
for name, layer in teacher.named_children():
    print(f"{name}: {layer}")

Teacher model instantiated with the following configuration:
net: EMP(
  (h_proj): Linear(in_features=5, out_features=128, bias=True)
  (h_embed): ModuleList(
    (0): Block(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (drop_path1): Identity()
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (drop_path2): Identity()
    )
    (1): Block(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=

In [3]:
student_dict = {
    "_target_": "src.model.trainer_forecast.Trainer",
    "dim": 64,
    "historical_steps": 50,
    "future_steps": 60,
    "encoder_depth": 3,
    "num_heads": 4,
    "mlp_ratio": 2.0,
    "qkv_bias": False,
    "drop_path": 0.2,
    "attention_type": "performer",  # "standard", "linear", "performer"
    "decoder": "mlp",
    "decoder_embed_dim": 64,
    "decoder_num_modes": 9,
    "decoder_hidden_dim": 256,
    "pretrained_weights": None,
    "lr": 0.001,
    "weight_decay": 1e-4,
    "epochs": 60,
    "warmup_epochs": 10,
}

student = instantiate(student_dict)
# Remove the last layer from the encoder or decoder
student.net.dense_predictor = student.net.dense_predictor[:-1]  
student.net.decoder.loc = student.net.decoder.loc[:-1]
student.net.decoder.pi = student.net.decoder.pi[:-1]

# Print the model layers
print("Student model instantiated with the following configuration:")
for name, layer in student.named_children():
    print(f"{name}: {layer}")

Using decoder type: mlp
RUN EMP-M
Student model instantiated with the following configuration:
net: EMP(
  (h_proj): Linear(in_features=5, out_features=64, bias=True)
  (h_embed): ModuleList(
    (0): Block(
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attn): PerformerAttention(
        (q_proj): Linear(in_features=64, out_features=64, bias=False)
        (k_proj): Linear(in_features=64, out_features=64, bias=False)
        (v_proj): Linear(in_features=64, out_features=64, bias=False)
        (out_proj): Linear(in_features=64, out_features=64, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (drop_path1): Identity()
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=64, out_features=128, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=128, out_features=64, bias=True)
        (drop

In [4]:
# Create datamodule
datamodule_dict = {
    "_target_": "src.datamodule.av2_datamodule.Av2DataModule",
    "data_root": "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data",
    "data_folder": "emp",
    "train_batch_size": 64,
    "val_batch_size": 64,
    "test_batch_size": 64,
    "shuffle": True,
    "num_workers": 18,
    "pin_memory": True,
}

datamodule = instantiate(datamodule_dict)

FRACTION = 0.5  

datamodule.setup(stage="fit", train_fraction=FRACTION, val_fraction=FRACTION, test_fraction=FRACTION)

train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()

print(f"Number of training samples: {len(train_dataloader.dataset)} out of {len(datamodule.full_train_dataset)}")
print(f"Number of validation samples: {len(val_dataloader.dataset)} out of {len(datamodule.full_val_dataset)}")

data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/train, total number of files: 199908
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/val, total number of files: 24988
Using 50.0% of the training and validation datasets.
Number of training samples: 99954 out of 199908
Number of validation samples: 12494 out of 24988


In [5]:
# Check if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:

# from tqdm import tqdm
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# teacher = teacher.to(device)
# teacher.eval()
# features_list_train = []
# features_list_val = []
# points_train = []
# points_val = []
# with torch.no_grad():
#     for data in tqdm(train_dataloader):
#         for k in data.keys():
#             if torch.is_tensor(data[k]): data[k] = data[k].to(device)
        
#         points_train.append(data)
        
#         features = teacher.net.get_features(data)
#         features_list_train.append(features)
        
        
#     for data in tqdm(val_dataloader):
#         for k in data.keys():
#             if torch.is_tensor(data[k]): data[k] = data[k].to(device)
        
#         features = teacher.net.get_features(data)
#         features_list_val.append(features)
#         points_val.append(data)
        
# # Save the features to a file
# import pickle as pkl
# output_file_train = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_train.pkl"
# with open(output_file_train, "wb") as f:
#     pkl.dump(features_list_train, f)
    
# output_file_val = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_val.pkl"
# with open(output_file_val, "wb") as f:
#     pkl.dump(features_list_val, f)


# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_train.pkl"
# with open(output_file_data, "wb") as f:
#     pkl.dump(points_train, f)

# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_val.pkl"
# with open(output_file_data, "wb") as f:
#     pkl.dump(points_val, f)


In [7]:
import torch
import torch.nn.functional as F

def h_score(t_feat, s_feat, tau=0.1, N_over_M=1.0):
    """
    Compute h(T, S) as defined in Eq (19) in the paper.
    """
    # Normalize features 
    t_feat = F.normalize(t_feat, dim=1)
    s_feat = F.normalize(s_feat, dim=1)

    # Similarity scores
    sim = torch.sum(t_feat * s_feat, dim=1) / tau  # Exponent of e for numerator
    numerator = torch.exp(sim)
    denominator = numerator + N_over_M
    h = numerator / denominator  # shape: (B,)
    return h


# Test function
def test_h_score():
    # Create dummy features
    t_feat = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    s_feat = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # Compute h-score
    h = h_score(t_feat, s_feat)

    # Print result
    print("h-score:", h)
    # Expected output: h-score: tensor([1.0000, 1.0000, 1.0000])
    assert torch.allclose(h, torch.tensor([1.0, 1.0, 1.0]), atol=1e-4), "h-score test failed!"
    
test_h_score()

h-score: tensor([1.0000, 1.0000, 1.0000])


In [8]:
def critic_loss(t_pos, s_pos, t_neg, s_neg, tau=0.1, N=1.0, M = 1.0):
    """
    Compute the full critic loss:
    - t_pos, s_pos: matched teacher/student features (C=1)
    - t_neg, s_neg: mismatched features (C=0)
    """
    N_over_M = N / M if M != 0 else 1.0  # Avoid division by zero
    
    # Positive pair scores
    h_pos = h_score(t_pos, s_pos, tau, N_over_M)
    loss_pos = torch.log(h_pos + 1e-8).mean() # We use mean as Monte Carlo estimate

    # Negative pair scores
    h_neg = h_score(t_neg, s_neg, tau, N_over_M)
    loss_neg = torch.log(1 - h_neg + 1e-8).mean()

    # Combine as in Eq (18)
    loss = -(loss_pos + N * loss_neg)
    return loss

# Test function
def test_critic_loss():
    # Here I tried to find two tensors that would yield a loss of almost 0
    t_pos = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    s_pos = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

    t_neg = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    s_neg = torch.tensor([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0]])

    # Compute critic loss
    loss = critic_loss(t_pos, s_pos, t_neg, s_neg)

    # Print result
    print("Critic loss:", loss)
    # Expected output: Critic loss: tensor(1.)
    assert torch.isclose(loss, torch.tensor(0.), atol=1e-4), "Critic loss test failed!"
    
test_critic_loss()

Critic loss: tensor(    0.0001)


In [9]:
# # Load features from the saved file
# from tqdm import tqdm
# import pickle as pkl


# output_file_train = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_train.pkl"
# with open(output_file_train, "rb") as f:
#     features_list_train = pkl.load(f)
    
# output_file_val = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_val.pkl"
# with open(output_file_val, "rb") as f:
#     features_list_val = pkl.load(f)
    
# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_train.pkl"
# with open(output_file_data, "rb") as f:
#     points_train = pkl.load(f)
    
# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_val.pkl"
# with open(output_file_data, "rb") as f:
#     points_val = pkl.load(f)
    
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:

# student.train()
# student = student.to(device)
# teacher.to(device)
# teacher.eval()
# batch_size = 64

# # Already organized in batches


# # If the embedding dim of the student is different from the teacher, add a projection layer
# # Get a batch of data
# batch = points_train[0]
# for k in batch.keys():
#     if torch.is_tensor(batch[k]): 
#         batch[k] = batch[k].to(device)

# with torch.no_grad():
#     dim_teacher = teacher.net.get_features(batch).shape[1]
#     dim_student = student.net.get_features(batch).shape[1]

# print(f"Teacher feature dimension: {dim_teacher}, Student feature dimension: {dim_student}")
# projection_layer = None
# if dim_teacher != dim_student:
#     projection_layer = torch.nn.Linear(dim_student, dim_teacher)
#     projection_layer.to(device)
#     print(f"Projection layer added: {projection_layer}")
    
#     # Freeze the projection layer
#     for param in projection_layer.parameters():
#         param.requires_grad = False
    
# optimizer, scheduler = student.configure_optimizers()
# optimizer = optimizer[0] if isinstance(optimizer, list) else optimizer
# scheduler = scheduler[0] if isinstance(scheduler, list) else scheduler
    
# # M = nr of positive pairs, N = nr of negative pairs
# M = len(features_list_train[0])  # Number of positive pairs in a batch
# N = len(features_list_train[0])  # Number of negative pairs in a batch
# for i in range(10):
#     losses = []
#     print(f"Epoch {i+1}/{10}")
#     for idx in tqdm(range(len(points_train))):
#         batch = points_train[idx]
#         for k in batch.keys():
#             if torch.is_tensor(batch[k]): batch[k] = batch[k].to(device)
#         # Eliminate "y" key if it exists
#         if 'y' in batch:
#             del batch['y']
#         feature_contrast = student.net.get_features(batch)

#         if projection_layer is not None:
#             feature_contrast = projection_layer(feature_contrast)


#         batch_feats = features_list_train[idx]

#         # Get a random feature from the batch
#         random_idx = random.randint(0, len(batch_feats) - 1)
#         t_pos = batch_feats[random_idx].unsqueeze(0)  # [1, feature_dim]
#         s_pos = feature_contrast[random_idx].unsqueeze(0)  # [1, feature_dim]
        
#         # Get a random negative feature from the batch
#         neg_idx = random.randint(0, len(batch_feats) - 1)
#         while neg_idx == random_idx:  # Ensure it's different from the positive index
#             neg_idx = random.randint(0, len(batch_feats) - 1)
#         t_neg = batch_feats[neg_idx].unsqueeze(0)  # [1, feature_dim]
#         s_neg = feature_contrast[neg_idx].unsqueeze(0)  # [1, feature_dim]
#         # Compute the critic loss
#         loss = critic_loss(t_pos, s_pos, t_neg, s_neg, tau=0.1, N=N, M=M)
#         # print(f"Critic loss: {loss.item()}")
        
#         # Backpropagate the loss
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         losses.append(loss.item())
#     losses = np.mean(losses)
#     print(f"Epoch {i+1} losses: {losses}")
#     if scheduler is not None:
#         if isinstance(scheduler, list):
#             for sch in scheduler:
#                 sch.step()
#         else:
#             scheduler.step()        

In [11]:
teacher.to(device)
student.to(device)

batch = next(iter(train_dataloader))
for k in batch.keys():
    if torch.is_tensor(batch[k]): batch[k] = batch[k].to(device)
# Find the expected output size
with torch.no_grad():
    dim_teacher = teacher.net.get_features(batch).shape[1]
    dim_student = student.net.get_features(batch).shape[1]

print(f"Teacher feature dimension: {dim_teacher}, Student feature dimension: {dim_student}")

# Create learnable projection layer
projection_layer = None
if dim_teacher != dim_student:
    projection_layer = torch.nn.Linear(dim_student, dim_teacher).to(device)
    print(f"Learnable projection layer added: {projection_layer}")
    
    # Freeze the projection layer
    for param in projection_layer.parameters():
        param.requires_grad = False



Teacher feature dimension: 640, Student feature dimension: 448
Learnable projection layer added: Linear(in_features=448, out_features=640, bias=True)


In [12]:
from tqdm import tqdm
import matplotlib.pyplot as plt

teacher.to(device)
student.to(device)
student.train()
teacher.eval()

memory_bank_size = 2048  # Reduced memory bank size for better management
memory_buffer = []

# Create optimizer
if projection_layer is not None:
    all_params = list(student.parameters()) + list(projection_layer.parameters())
    optimizer = torch.optim.AdamW(all_params, lr=1e-4, weight_decay=1e-4)
else:
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

def contrastive_loss_batch_online(student_feats, pos_features, neg_features=None, tau=0.1, NEG_RATIO=1.0):
    """
    Online contrastive loss function using your custom critic loss
    """
    batch_size = student_feats.shape[0]
    
    if batch_size == 1:
        return F.mse_loss(student_feats, pos_features)
    
    # For negative pairs
    if neg_features is not None and neg_features.shape[0] > 0:
        # Sample negative features randomly based on ratio
        num_negatives = min(int(batch_size * NEG_RATIO), neg_features.shape[0])
        if num_negatives > 0:
            neg_indices = torch.randperm(neg_features.shape[0])[:num_negatives]
            
            t_pos = pos_features  # [B, feature_dim]
            s_pos = student_feats  # [B, feature_dim]
            
            t_neg = neg_features[neg_indices]  # [neg_samples, feature_dim]
            
            # Create corresponding student negatives by cycling through student features
            s_neg_indices = torch.arange(num_negatives) % batch_size
            s_neg = student_feats[s_neg_indices]  # [neg_samples, feature_dim]
            
            # print("s_pos shape:", s_pos.shape)
            # print("t_pos shape:", t_pos.shape)
            # print("t_neg shape:", t_neg.shape)
            # print("s_neg shape:", s_neg.shape)
            
            M = t_pos.shape[0]  # Number of positive pairs
            N = t_neg.shape[0]  # Number of negative pairs
            
            loss = critic_loss(t_pos, s_pos, t_neg, s_neg, tau=tau, N=N, M=M)
        else:
            loss = F.mse_loss(student_feats, pos_features)
    else:
        # Fallback to MSE if no negatives available
        loss = F.mse_loss(student_feats, pos_features)
    
    return loss

epochs = 5
print("Starting online contrastive training with custom loss and memory bank...")

NEG_RATIO = 0.2

losses = []
LOSS_BUFFER_SIZE = 3000

for epoch in range(epochs):
    epoch_loss = 0
    num_batches = 0
    
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    # The training loop iterates directly over the train_dataloader
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        for k in batch.keys():
            if torch.is_tensor(batch[k]): 
                batch[k] = batch[k].to(device)
        
        # Eliminate "y" key if it exists
        if 'y' in batch:
            del batch['y']
        
        # Get student features
        student_features = student.net.get_features(batch)
        
        # Apply projection
        if projection_layer is not None:
            student_features = projection_layer(student_features)
        
        # Generate positive features on-the-fly from the teacher
        with torch.no_grad():
            pos_features = teacher.net.get_features(batch)
        
            # Update memory buffer with new features (FIFO queue)
            pos_features_cpu = pos_features.cpu()
            
            # Concatenate new features to memory buffer
            for feature in pos_features_cpu:
                memory_buffer.append(feature)
            
            # Limit memory buffer size
            if len(memory_buffer) > memory_bank_size:
                memory_buffer.pop(0)
        
        # Prepare negative features for loss computation
        neg_features_for_loss = None
        if len(memory_buffer) > 0:
            # Convert memory buffer back to GPU tensors for loss computation
            neg_features_for_loss = torch.stack(memory_buffer).to(device)
            
        loss = contrastive_loss_batch_online(
            student_features, 
            pos_features, 
            neg_features=neg_features_for_loss, 
            tau=0.1, 
            NEG_RATIO=NEG_RATIO
        )

        losses.append(loss.item())
        if len(losses) > LOSS_BUFFER_SIZE:
            losses.pop(0)

        optimizer.zero_grad()
        loss.backward()        
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        
        if num_batches % 10 == 0:
            plt.figure(figsize=(10, 5))
            if len(losses) > 50:
                window_size = min(50, len(losses) // 10)
                smoothed_losses = np.convolve(losses, np.ones(window_size)/window_size, mode='valid')
                smooth_x = np.arange(window_size-1, len(losses))
                plt.plot(smooth_x, smoothed_losses, color='blue', linewidth=2, label=f'Moving Average (window={window_size})')
            else:
                plt.plot(np.arange(len(losses)), losses, color='lightblue', linewidth=1, label='Raw Loss')

            
            plt.xlabel('Batch')
            plt.ylabel('Loss Value')
            plt.title(f'Epoch {epoch + 1} Loss Progress (Moving Average)')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(f"loss_distillation_smoothed_{epoch + 1}.png", dpi=150, bbox_inches='tight')
            plt.close()

        if num_batches % 100 == 0:
            print(f"Batch {num_batches}: Loss = {loss.item():.4f}")
    
    scheduler.step()
    
    if num_batches > 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")
    else:
        print(f"Epoch {epoch + 1} completed. No batches were processed.")

print("Online contrastive training completed!")

Starting online contrastive training with custom loss and memory bank...

Epoch 1/5


Epoch 1:   6%|▋         | 100/1562 [00:42<10:06,  2.41it/s]

Batch 100: Loss = 3.9611


Epoch 1:  13%|█▎        | 200/1562 [01:23<09:28,  2.39it/s]

Batch 200: Loss = 3.8717


Epoch 1:  19%|█▉        | 300/1562 [02:04<09:52,  2.13it/s]

Batch 300: Loss = 3.4674


Epoch 1:  26%|██▌       | 400/1562 [02:45<09:46,  1.98it/s]

Batch 400: Loss = 3.6989


Epoch 1:  32%|███▏      | 500/1562 [03:28<07:30,  2.36it/s]

Batch 500: Loss = 3.2551


Epoch 1:  38%|███▊      | 600/1562 [04:11<07:07,  2.25it/s]

Batch 600: Loss = 3.5387


Epoch 1:  45%|████▍     | 700/1562 [04:54<06:42,  2.14it/s]

Batch 700: Loss = 3.5656


Epoch 1:  51%|█████     | 800/1562 [05:38<05:59,  2.12it/s]

Batch 800: Loss = 4.2850


Epoch 1:  58%|█████▊    | 900/1562 [06:23<05:11,  2.13it/s]

Batch 900: Loss = 3.8064


Epoch 1:  64%|██████▍   | 1000/1562 [07:09<04:23,  2.13it/s]

Batch 1000: Loss = 3.5742


Epoch 1:  70%|███████   | 1100/1562 [07:56<03:48,  2.02it/s]

Batch 1100: Loss = 3.4285


Epoch 1:  77%|███████▋  | 1200/1562 [08:44<03:04,  1.96it/s]

Batch 1200: Loss = 3.4903


Epoch 1:  83%|████████▎ | 1300/1562 [09:33<02:14,  1.94it/s]

Batch 1300: Loss = 3.3120


Epoch 1:  90%|████████▉ | 1400/1562 [10:22<01:23,  1.94it/s]

Batch 1400: Loss = 3.7938


Epoch 1:  96%|█████████▌| 1500/1562 [11:13<00:34,  1.82it/s]

Batch 1500: Loss = 3.5449


Epoch 1: 100%|██████████| 1562/1562 [11:45<00:00,  2.21it/s]


Epoch 1 completed. Average Loss: 3.9268

Epoch 2/5


Epoch 2:   6%|▋         | 100/1562 [00:54<13:10,  1.85it/s]

Batch 100: Loss = 3.6565


Epoch 2:  13%|█▎        | 200/1562 [01:49<12:52,  1.76it/s]

Batch 200: Loss = 3.4739


Epoch 2:  19%|█▉        | 300/1562 [02:45<12:00,  1.75it/s]

Batch 300: Loss = 3.0539


Epoch 2:  26%|██▌       | 400/1562 [03:41<11:03,  1.75it/s]

Batch 400: Loss = 3.1822


Epoch 2:  32%|███▏      | 500/1562 [04:38<10:40,  1.66it/s]

Batch 500: Loss = 3.2209


Epoch 2:  38%|███▊      | 600/1562 [05:37<09:36,  1.67it/s]

Batch 600: Loss = 3.4616


Epoch 2:  45%|████▍     | 700/1562 [06:36<08:40,  1.66it/s]

Batch 700: Loss = 3.3599


Epoch 2:  51%|█████     | 800/1562 [07:36<08:02,  1.58it/s]

Batch 800: Loss = 3.0171


Epoch 2:  58%|█████▊    | 900/1562 [08:38<06:59,  1.58it/s]

Batch 900: Loss = 3.4033


Epoch 2:  64%|██████▍   | 1000/1562 [09:40<05:59,  1.56it/s]

Batch 1000: Loss = 3.0143


Epoch 2:  70%|███████   | 1100/1562 [10:43<05:07,  1.50it/s]

Batch 1100: Loss = 3.1490


Epoch 2:  77%|███████▋  | 1200/1562 [11:48<04:02,  1.49it/s]

Batch 1200: Loss = 3.0140


Epoch 2:  83%|████████▎ | 1300/1562 [12:53<02:50,  1.53it/s]

Batch 1300: Loss = 3.4125


Epoch 2:  90%|████████▉ | 1400/1562 [14:00<01:49,  1.48it/s]

Batch 1400: Loss = 3.3966


Epoch 2:  96%|█████████▌| 1500/1562 [15:08<00:43,  1.43it/s]

Batch 1500: Loss = 3.0139


Epoch 2: 100%|██████████| 1562/1562 [15:50<00:00,  1.64it/s]


Epoch 2 completed. Average Loss: 3.3757

Epoch 3/5


Epoch 3:   6%|▋         | 100/1562 [01:11<17:22,  1.40it/s]

Batch 100: Loss = 3.0966


Epoch 3:  13%|█▎        | 200/1562 [02:23<16:26,  1.38it/s]

Batch 200: Loss = 3.6980


Epoch 3:  19%|█▉        | 300/1562 [03:36<16:03,  1.31it/s]

Batch 300: Loss = 3.5765


Epoch 3:  26%|██▌       | 400/1562 [04:49<14:33,  1.33it/s]

Batch 400: Loss = 3.4281


Epoch 3:  32%|███▏      | 500/1562 [06:04<13:34,  1.30it/s]

Batch 500: Loss = 4.0264


Epoch 3:  38%|███▊      | 600/1562 [07:20<12:07,  1.32it/s]

Batch 600: Loss = 3.4876


Epoch 3:  45%|████▍     | 700/1562 [08:34<10:48,  1.33it/s]

Batch 700: Loss = 3.1494


Epoch 3:  51%|█████     | 800/1562 [09:49<10:11,  1.25it/s]

Batch 800: Loss = 3.3929


Epoch 3:  58%|█████▊    | 900/1562 [11:08<08:59,  1.23it/s]

Batch 900: Loss = 3.4166


Epoch 3:  64%|██████▍   | 1000/1562 [12:24<07:27,  1.26it/s]

Batch 1000: Loss = 3.2519


Epoch 3:  70%|███████   | 1100/1562 [13:41<06:00,  1.28it/s]

Batch 1100: Loss = 3.9392


Epoch 3:  71%|███████   | 1108/1562 [13:47<05:39,  1.34it/s]


KeyboardInterrupt: 

In [13]:
# Save the model
torch.save({
    'student_state_dict': student.state_dict(),
    'projection_layer_state_dict': projection_layer.state_dict() if projection_layer else None,
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epochs,
}, 'student_model_custom_contrastive.pth')

print("Model saved as 'student_model_custom_contrastive.pth'")

Model saved as 'student_model_custom_contrastive.pth'


In [21]:
# Load the model for testing
checkpoint = torch.load('student_model_custom_contrastive.pth')
student = instantiate(student_dict)
# Remove the last layer from the encoder or decoder
student.net.dense_predictor = student.net.dense_predictor[:-1]  
student.net.decoder.loc = student.net.decoder.loc[:-1]
student.net.decoder.pi = student.net.decoder.pi[:-1]
student.load_state_dict(checkpoint['student_state_dict'])  # <-- FIXED LINE

teacher.eval()
student.eval()
teacher.to(device)
student.to(device)

batch_pos = next(iter(train_dataloader))
for k in batch_pos.keys():
    if torch.is_tensor(batch_pos[k]): 
        batch_pos[k] = batch_pos[k].to(device)
# Eliminate "y" key if it exists
if 'y' in batch_pos:
    del batch_pos['y']
    
batch_neg = next(iter(train_dataloader))
for k in batch_neg.keys():
    if torch.is_tensor(batch_neg[k]): 
        batch_neg[k] = batch_neg[k].to(device)
# Eliminate "y" key if it exists
if 'y' in batch_neg:
    del batch_neg['y']

# After loading student and teacher and before loss calculation
with torch.no_grad():
    pos_teacher_features = teacher.net.get_features(batch_pos)
    neg_teacher_features = teacher.net.get_features(batch_neg)
    pos_student_features = student.net.get_features(batch_pos)

    # Project student features if needed
    if pos_teacher_features.shape[1] != pos_student_features.shape[1]:
        projection_layer = torch.nn.Linear(pos_student_features.shape[1], pos_teacher_features.shape[1]).to(device)
        # Optionally load projection weights if you saved them
        if 'projection_layer_state_dict' in checkpoint and checkpoint['projection_layer_state_dict'] is not None:
            projection_layer.load_state_dict(checkpoint['projection_layer_state_dict'])
        pos_student_features = projection_layer(pos_student_features)

loss = contrastive_loss_batch_online(
    pos_student_features, 
    pos_teacher_features, 
    neg_features=neg_teacher_features, 
    tau=0.1, 
    NEG_RATIO=NEG_RATIO
)
print(f"Test loss after loading model: {loss.item():.4f}")

Using decoder type: mlp
RUN EMP-M
Test loss after loading model: 3.2877


In [28]:
student = instantiate(student_dict)
checkpoint = torch.load('student_model_custom_contrastive.pth')
student.load_state_dict(checkpoint['student_state_dict'], strict=False)

student.to(device)
student.train()

optimizer, scheduler = student.configure_optimizers()
optimizer = optimizer[0] if isinstance(optimizer, list) else optimizer
scheduler = scheduler[0] if isinstance(scheduler, list) else scheduler

dataloader = datamodule.train_dataloader()

for epoch in range(5):
    print(f"Epoch {epoch+1}/5")
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        for k in batch.keys():
            if torch.is_tensor(batch[k]):
                batch[k] = batch[k].to(device)
        preds = student(batch)
        # preds['y_hat']: (batch, 9, 60, 2), batch['y'][:, 0, :, :] : (batch, 60, 2)
        gt_traj_ego = batch['y'][:, 0, :, :]  # (batch, 60, 2)
        pred_modes = preds['y_hat']           # (batch, 9, 60, 2)
        # Compute ADE for each mode
        gt_expanded = gt_traj_ego.unsqueeze(1).expand(-1, pred_modes.shape[1], -1, -1)  # (batch, 9, 60, 2)
        displacement_errors = torch.norm(pred_modes - gt_expanded, dim=-1)  # (batch, 9, 60)
        ade_per_mode = displacement_errors.mean(dim=-1)  # (batch, 9)
        best_mode_idx = ade_per_mode.argmin(dim=1)       # (batch,)
        batch_indices = torch.arange(pred_modes.shape[0])
        best_pred = pred_modes[batch_indices, best_mode_idx]  # (batch, 60, 2)
        # Now compute loss
        loss = torch.nn.functional.mse_loss(best_pred, gt_traj_ego)
        print(f"Loss for epoch {epoch+1}, batch: {loss.item():.4f}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if scheduler is not None:
        scheduler.step()
    print(f"Epoch {epoch+1} done.")

print("Fine-tuning complete.")

Using decoder type: mlp
RUN EMP-M
Epoch 1/5


Epoch 1:   0%|          | 1/313 [00:01<08:57,  1.72s/it]

Loss for epoch 1, batch: 285.2229


Epoch 1:   1%|          | 2/313 [00:01<04:29,  1.16it/s]

Loss for epoch 1, batch: 333.0374


Epoch 1:   1%|          | 3/313 [00:02<02:58,  1.74it/s]

Loss for epoch 1, batch: 362.2126


Epoch 1:   1%|▏         | 4/313 [00:02<02:22,  2.17it/s]

Loss for epoch 1, batch: 480.9289


Epoch 1:   2%|▏         | 5/313 [00:02<01:59,  2.59it/s]

Loss for epoch 1, batch: 336.8015


Epoch 1:   2%|▏         | 6/313 [00:03<01:43,  2.96it/s]

Loss for epoch 1, batch: 350.4828


Epoch 1:   2%|▏         | 7/313 [00:03<01:34,  3.24it/s]

Loss for epoch 1, batch: 245.2197


Epoch 1:   3%|▎         | 8/313 [00:03<01:26,  3.55it/s]

Loss for epoch 1, batch: 364.3583


Epoch 1:   3%|▎         | 9/313 [00:03<01:19,  3.83it/s]

Loss for epoch 1, batch: 370.1012


Epoch 1:   3%|▎         | 10/313 [00:03<01:15,  4.00it/s]

Loss for epoch 1, batch: 307.4930


Epoch 1:   4%|▎         | 11/313 [00:04<01:10,  4.28it/s]

Loss for epoch 1, batch: 336.2839


Epoch 1:   4%|▍         | 12/313 [00:04<01:10,  4.27it/s]

Loss for epoch 1, batch: 371.9744


Epoch 1:   4%|▍         | 13/313 [00:04<01:09,  4.31it/s]

Loss for epoch 1, batch: 407.6709


Epoch 1:   4%|▍         | 14/313 [00:04<01:08,  4.40it/s]

Loss for epoch 1, batch: 370.7653


Epoch 1:   5%|▍         | 15/313 [00:05<01:07,  4.39it/s]

Loss for epoch 1, batch: 428.9459
Loss for epoch 1, batch: 362.8737


Epoch 1:   5%|▌         | 17/313 [00:05<01:02,  4.72it/s]

Loss for epoch 1, batch: 356.6921


Epoch 1:   6%|▌         | 18/313 [00:05<01:02,  4.73it/s]

Loss for epoch 1, batch: 397.5059


Epoch 1:   6%|▌         | 19/313 [00:05<01:02,  4.72it/s]

Loss for epoch 1, batch: 393.9336


Epoch 1:   6%|▋         | 20/313 [00:06<01:01,  4.79it/s]

Loss for epoch 1, batch: 242.1806


Epoch 1:   7%|▋         | 21/313 [00:06<01:01,  4.75it/s]

Loss for epoch 1, batch: 449.3475


Epoch 1:   7%|▋         | 22/313 [00:06<01:01,  4.74it/s]

Loss for epoch 1, batch: 313.8423
Loss for epoch 1, batch: 353.8736


Epoch 1:   8%|▊         | 24/313 [00:06<00:59,  4.86it/s]

Loss for epoch 1, batch: 284.5814


Epoch 1:   8%|▊         | 25/313 [00:07<01:00,  4.74it/s]

Loss for epoch 1, batch: 371.0218


Epoch 1:   8%|▊         | 26/313 [00:07<01:01,  4.64it/s]

Loss for epoch 1, batch: 289.1300


Epoch 1:   9%|▊         | 27/313 [00:07<01:03,  4.53it/s]

Loss for epoch 1, batch: 366.1032


Epoch 1:   9%|▉         | 28/313 [00:07<01:02,  4.53it/s]

Loss for epoch 1, batch: 440.2894


Epoch 1:   9%|▉         | 29/313 [00:07<01:03,  4.47it/s]

Loss for epoch 1, batch: 332.6043


Epoch 1:  10%|▉         | 30/313 [00:08<01:01,  4.60it/s]

Loss for epoch 1, batch: 364.6490


Epoch 1:  10%|▉         | 31/313 [00:08<01:03,  4.47it/s]

Loss for epoch 1, batch: 314.9058


Epoch 1:  10%|█         | 32/313 [00:08<01:05,  4.31it/s]

Loss for epoch 1, batch: 433.1504


Epoch 1:  11%|█         | 33/313 [00:08<01:05,  4.28it/s]

Loss for epoch 1, batch: 382.0585


Epoch 1:  11%|█         | 34/313 [00:09<01:04,  4.32it/s]

Loss for epoch 1, batch: 314.3686


Epoch 1:  11%|█         | 35/313 [00:09<01:03,  4.41it/s]

Loss for epoch 1, batch: 318.2587


Epoch 1:  12%|█▏        | 36/313 [00:09<01:02,  4.42it/s]

Loss for epoch 1, batch: 335.9106


Epoch 1:  12%|█▏        | 37/313 [00:09<01:03,  4.32it/s]

Loss for epoch 1, batch: 385.7719


Epoch 1:  12%|█▏        | 38/313 [00:10<01:04,  4.24it/s]

Loss for epoch 1, batch: 330.5747


Epoch 1:  12%|█▏        | 39/313 [00:10<01:06,  4.11it/s]

Loss for epoch 1, batch: 489.9847


Epoch 1:  13%|█▎        | 40/313 [00:10<01:05,  4.14it/s]

Loss for epoch 1, batch: 499.9015


Epoch 1:  13%|█▎        | 41/313 [00:10<01:05,  4.18it/s]

Loss for epoch 1, batch: 418.8035


Epoch 1:  13%|█▎        | 42/313 [00:11<01:04,  4.21it/s]

Loss for epoch 1, batch: 388.6758


Epoch 1:  14%|█▎        | 43/313 [00:11<01:04,  4.21it/s]

Loss for epoch 1, batch: 414.4447


Epoch 1:  14%|█▍        | 44/313 [00:11<01:04,  4.15it/s]

Loss for epoch 1, batch: 282.5040


Epoch 1:  14%|█▍        | 45/313 [00:11<01:06,  4.05it/s]

Loss for epoch 1, batch: 399.8519


Epoch 1:  15%|█▍        | 46/313 [00:12<01:04,  4.17it/s]

Loss for epoch 1, batch: 280.3168
Loss for epoch 1, batch: 255.6490


Epoch 1:  15%|█▌        | 48/313 [00:12<00:59,  4.45it/s]

Loss for epoch 1, batch: 364.4626


Epoch 1:  16%|█▌        | 49/313 [00:12<00:59,  4.45it/s]

Loss for epoch 1, batch: 400.9965


Epoch 1:  16%|█▌        | 50/313 [00:12<00:57,  4.58it/s]

Loss for epoch 1, batch: 395.7108


Epoch 1:  16%|█▋        | 51/313 [00:13<00:55,  4.68it/s]

Loss for epoch 1, batch: 291.4653


Epoch 1:  17%|█▋        | 52/313 [00:13<00:55,  4.73it/s]

Loss for epoch 1, batch: 338.6885


Epoch 1:  17%|█▋        | 53/313 [00:13<00:54,  4.81it/s]

Loss for epoch 1, batch: 368.1588


Epoch 1:  17%|█▋        | 54/313 [00:13<00:54,  4.75it/s]

Loss for epoch 1, batch: 468.1343


Epoch 1:  18%|█▊        | 55/313 [00:13<00:55,  4.62it/s]

Loss for epoch 1, batch: 285.1781


Epoch 1:  18%|█▊        | 56/313 [00:14<00:57,  4.51it/s]

Loss for epoch 1, batch: 320.2094


Epoch 1:  18%|█▊        | 57/313 [00:14<00:58,  4.38it/s]

Loss for epoch 1, batch: 412.5793


Epoch 1:  19%|█▊        | 58/313 [00:14<00:57,  4.40it/s]

Loss for epoch 1, batch: 347.4369


Epoch 1:  19%|█▉        | 59/313 [00:14<00:58,  4.38it/s]

Loss for epoch 1, batch: 328.6171


Epoch 1:  19%|█▉        | 60/313 [00:15<00:56,  4.46it/s]

Loss for epoch 1, batch: 482.7646


Epoch 1:  19%|█▉        | 61/313 [00:15<00:55,  4.58it/s]

Loss for epoch 1, batch: 296.8341


Epoch 1:  20%|█▉        | 62/313 [00:15<00:57,  4.38it/s]

Loss for epoch 1, batch: 294.6202


Epoch 1:  20%|██        | 63/313 [00:15<00:58,  4.25it/s]

Loss for epoch 1, batch: 360.9269


Epoch 1:  20%|██        | 64/313 [00:16<00:59,  4.17it/s]

Loss for epoch 1, batch: 372.7393


Epoch 1:  21%|██        | 65/313 [00:16<00:59,  4.16it/s]

Loss for epoch 1, batch: 385.9076


Epoch 1:  21%|██        | 66/313 [00:16<01:01,  4.04it/s]

Loss for epoch 1, batch: 278.7875


Epoch 1:  21%|██▏       | 67/313 [00:16<01:01,  3.97it/s]

Loss for epoch 1, batch: 437.1541


Epoch 1:  22%|██▏       | 68/313 [00:17<01:00,  4.02it/s]

Loss for epoch 1, batch: 380.9137


Epoch 1:  22%|██▏       | 69/313 [00:17<01:00,  4.06it/s]

Loss for epoch 1, batch: 533.2612


Epoch 1:  22%|██▏       | 70/313 [00:17<00:58,  4.14it/s]

Loss for epoch 1, batch: 378.6333


Epoch 1:  23%|██▎       | 71/313 [00:17<00:58,  4.12it/s]

Loss for epoch 1, batch: 446.9014


Epoch 1:  23%|██▎       | 72/313 [00:17<00:57,  4.21it/s]

Loss for epoch 1, batch: 377.2294


Epoch 1:  23%|██▎       | 73/313 [00:18<00:55,  4.33it/s]

Loss for epoch 1, batch: 315.6651


Epoch 1:  24%|██▎       | 74/313 [00:18<00:54,  4.40it/s]

Loss for epoch 1, batch: 359.1827


Epoch 1:  24%|██▍       | 75/313 [00:18<00:52,  4.51it/s]

Loss for epoch 1, batch: 364.1428


Epoch 1:  24%|██▍       | 76/313 [00:18<00:54,  4.38it/s]

Loss for epoch 1, batch: 463.2545


Epoch 1:  25%|██▍       | 77/313 [00:19<00:55,  4.27it/s]

Loss for epoch 1, batch: 259.9407


Epoch 1:  25%|██▍       | 78/313 [00:19<00:55,  4.23it/s]

Loss for epoch 1, batch: 409.2771


Epoch 1:  25%|██▌       | 79/313 [00:19<00:54,  4.26it/s]

Loss for epoch 1, batch: 415.5723


Epoch 1:  26%|██▌       | 80/313 [00:19<00:53,  4.35it/s]

Loss for epoch 1, batch: 407.3905


Epoch 1:  26%|██▌       | 81/313 [00:20<00:53,  4.37it/s]

Loss for epoch 1, batch: 331.9213


Epoch 1:  26%|██▌       | 82/313 [00:20<00:51,  4.45it/s]

Loss for epoch 1, batch: 258.4328


Epoch 1:  27%|██▋       | 83/313 [00:20<00:51,  4.43it/s]

Loss for epoch 1, batch: 395.5394


Epoch 1:  27%|██▋       | 84/313 [00:20<00:52,  4.36it/s]

Loss for epoch 1, batch: 403.2826


Epoch 1:  27%|██▋       | 85/313 [00:20<00:52,  4.37it/s]

Loss for epoch 1, batch: 410.1336


Epoch 1:  27%|██▋       | 86/313 [00:21<00:51,  4.45it/s]

Loss for epoch 1, batch: 323.2997


Epoch 1:  28%|██▊       | 87/313 [00:21<00:50,  4.45it/s]

Loss for epoch 1, batch: 323.6117


Epoch 1:  28%|██▊       | 88/313 [00:21<00:50,  4.45it/s]

Loss for epoch 1, batch: 378.9841


Epoch 1:  28%|██▊       | 89/313 [00:21<00:51,  4.37it/s]

Loss for epoch 1, batch: 385.6454


Epoch 1:  29%|██▉       | 90/313 [00:22<00:51,  4.31it/s]

Loss for epoch 1, batch: 237.3598


Epoch 1:  29%|██▉       | 91/313 [00:22<00:51,  4.28it/s]

Loss for epoch 1, batch: 323.6263


Epoch 1:  29%|██▉       | 92/313 [00:22<00:51,  4.32it/s]

Loss for epoch 1, batch: 339.5918


Epoch 1:  30%|██▉       | 93/313 [00:22<00:51,  4.31it/s]

Loss for epoch 1, batch: 264.4967


Epoch 1:  30%|███       | 94/313 [00:23<00:50,  4.36it/s]

Loss for epoch 1, batch: 447.7757


Epoch 1:  30%|███       | 95/313 [00:23<00:49,  4.38it/s]

Loss for epoch 1, batch: 332.0335


Epoch 1:  31%|███       | 96/313 [00:23<00:49,  4.35it/s]

Loss for epoch 1, batch: 413.1617


Epoch 1:  31%|███       | 97/313 [00:23<00:49,  4.36it/s]

Loss for epoch 1, batch: 334.3734


Epoch 1:  31%|███▏      | 98/313 [00:23<00:48,  4.44it/s]

Loss for epoch 1, batch: 294.8307


Epoch 1:  32%|███▏      | 99/313 [00:24<00:47,  4.46it/s]

Loss for epoch 1, batch: 295.7199


Epoch 1:  32%|███▏      | 100/313 [00:24<00:48,  4.35it/s]

Loss for epoch 1, batch: 386.0481


Epoch 1:  32%|███▏      | 101/313 [00:24<00:47,  4.43it/s]

Loss for epoch 1, batch: 289.5653


Epoch 1:  33%|███▎      | 102/313 [00:24<00:47,  4.48it/s]

Loss for epoch 1, batch: 331.3228


Epoch 1:  33%|███▎      | 103/313 [00:25<00:46,  4.55it/s]

Loss for epoch 1, batch: 312.6013


Epoch 1:  33%|███▎      | 104/313 [00:25<00:47,  4.44it/s]

Loss for epoch 1, batch: 346.6962


Epoch 1:  34%|███▎      | 105/313 [00:25<00:48,  4.31it/s]

Loss for epoch 1, batch: 397.0311


Epoch 1:  34%|███▍      | 106/313 [00:25<00:48,  4.27it/s]

Loss for epoch 1, batch: 304.3552


Epoch 1:  34%|███▍      | 107/313 [00:26<00:49,  4.13it/s]

Loss for epoch 1, batch: 452.3322


Epoch 1:  35%|███▍      | 108/313 [00:26<00:47,  4.27it/s]

Loss for epoch 1, batch: 370.4235


Epoch 1:  35%|███▍      | 109/313 [00:26<00:45,  4.46it/s]

Loss for epoch 1, batch: 360.1095


Epoch 1:  35%|███▌      | 110/313 [00:26<00:44,  4.52it/s]

Loss for epoch 1, batch: 343.8559


Epoch 1:  35%|███▌      | 111/313 [00:26<00:45,  4.41it/s]

Loss for epoch 1, batch: 408.3548


Epoch 1:  36%|███▌      | 112/313 [00:27<00:46,  4.33it/s]

Loss for epoch 1, batch: 322.8294


Epoch 1:  36%|███▌      | 113/313 [00:27<00:46,  4.27it/s]

Loss for epoch 1, batch: 289.5149


Epoch 1:  36%|███▋      | 114/313 [00:27<00:45,  4.39it/s]

Loss for epoch 1, batch: 327.5196


Epoch 1:  37%|███▋      | 115/313 [00:27<00:45,  4.34it/s]

Loss for epoch 1, batch: 360.7697


Epoch 1:  37%|███▋      | 116/313 [00:28<00:46,  4.22it/s]

Loss for epoch 1, batch: 305.5026


Epoch 1:  37%|███▋      | 117/313 [00:28<00:46,  4.17it/s]

Loss for epoch 1, batch: 249.6579


Epoch 1:  38%|███▊      | 118/313 [00:28<00:45,  4.25it/s]

Loss for epoch 1, batch: 350.3159


Epoch 1:  38%|███▊      | 119/313 [00:28<00:45,  4.26it/s]

Loss for epoch 1, batch: 268.0307


Epoch 1:  38%|███▊      | 120/313 [00:29<00:46,  4.19it/s]

Loss for epoch 1, batch: 250.5069


Epoch 1:  39%|███▊      | 121/313 [00:29<00:47,  4.03it/s]

Loss for epoch 1, batch: 329.1532


Epoch 1:  39%|███▉      | 122/313 [00:29<00:44,  4.28it/s]

Loss for epoch 1, batch: 325.7759


Epoch 1:  39%|███▉      | 123/313 [00:29<00:44,  4.31it/s]

Loss for epoch 1, batch: 280.1406


Epoch 1:  40%|███▉      | 124/313 [00:29<00:44,  4.24it/s]

Loss for epoch 1, batch: 394.5261


Epoch 1:  40%|███▉      | 125/313 [00:30<00:44,  4.24it/s]

Loss for epoch 1, batch: 367.1279


Epoch 1:  40%|████      | 126/313 [00:30<00:42,  4.37it/s]

Loss for epoch 1, batch: 302.1962


Epoch 1:  41%|████      | 127/313 [00:30<00:41,  4.44it/s]

Loss for epoch 1, batch: 276.5529


Epoch 1:  41%|████      | 128/313 [00:30<00:42,  4.40it/s]

Loss for epoch 1, batch: 304.6687


Epoch 1:  41%|████      | 129/313 [00:31<00:41,  4.41it/s]

Loss for epoch 1, batch: 342.6182


Epoch 1:  42%|████▏     | 130/313 [00:31<00:40,  4.47it/s]

Loss for epoch 1, batch: 293.7610


Epoch 1:  42%|████▏     | 131/313 [00:31<00:40,  4.48it/s]

Loss for epoch 1, batch: 377.8810


Epoch 1:  42%|████▏     | 132/313 [00:31<00:40,  4.51it/s]

Loss for epoch 1, batch: 349.5352


Epoch 1:  42%|████▏     | 133/313 [00:31<00:40,  4.46it/s]

Loss for epoch 1, batch: 433.0593


Epoch 1:  43%|████▎     | 134/313 [00:32<00:39,  4.48it/s]

Loss for epoch 1, batch: 360.4385


Epoch 1:  43%|████▎     | 135/313 [00:32<00:40,  4.42it/s]

Loss for epoch 1, batch: 340.4433


Epoch 1:  43%|████▎     | 136/313 [00:32<00:40,  4.42it/s]

Loss for epoch 1, batch: 314.9773


Epoch 1:  44%|████▍     | 137/313 [00:32<00:39,  4.41it/s]

Loss for epoch 1, batch: 266.6794


Epoch 1:  44%|████▍     | 138/313 [00:33<00:40,  4.30it/s]

Loss for epoch 1, batch: 257.1228


Epoch 1:  44%|████▍     | 139/313 [00:33<00:41,  4.22it/s]

Loss for epoch 1, batch: 291.4588


Epoch 1:  45%|████▍     | 140/313 [00:33<00:41,  4.14it/s]

Loss for epoch 1, batch: 287.6133


Epoch 1:  45%|████▌     | 141/313 [00:33<00:40,  4.25it/s]

Loss for epoch 1, batch: 369.4890


Epoch 1:  45%|████▌     | 142/313 [00:34<00:40,  4.26it/s]

Loss for epoch 1, batch: 250.5443


Epoch 1:  46%|████▌     | 143/313 [00:34<00:40,  4.20it/s]

Loss for epoch 1, batch: 318.0269


Epoch 1:  46%|████▌     | 144/313 [00:34<00:39,  4.28it/s]

Loss for epoch 1, batch: 359.5394


Epoch 1:  46%|████▋     | 145/313 [00:34<00:39,  4.24it/s]

Loss for epoch 1, batch: 278.4197


Epoch 1:  47%|████▋     | 146/313 [00:35<00:39,  4.22it/s]

Loss for epoch 1, batch: 409.6145


Epoch 1:  47%|████▋     | 147/313 [00:35<00:39,  4.23it/s]

Loss for epoch 1, batch: 229.5742


Epoch 1:  47%|████▋     | 148/313 [00:35<00:38,  4.26it/s]

Loss for epoch 1, batch: 303.3301
Loss for epoch 1, batch: 181.6818


Epoch 1:  48%|████▊     | 150/313 [00:35<00:36,  4.48it/s]

Loss for epoch 1, batch: 300.0626


Epoch 1:  48%|████▊     | 151/313 [00:36<00:39,  4.13it/s]

Loss for epoch 1, batch: 284.3603


Epoch 1:  49%|████▊     | 152/313 [00:36<00:38,  4.18it/s]

Loss for epoch 1, batch: 349.4675


Epoch 1:  49%|████▉     | 153/313 [00:36<00:37,  4.27it/s]

Loss for epoch 1, batch: 237.0809


Epoch 1:  49%|████▉     | 154/313 [00:36<00:37,  4.28it/s]

Loss for epoch 1, batch: 194.1778


Epoch 1:  50%|████▉     | 155/313 [00:37<00:37,  4.26it/s]

Loss for epoch 1, batch: 285.3707


Epoch 1:  50%|████▉     | 156/313 [00:37<00:36,  4.31it/s]

Loss for epoch 1, batch: 249.5206


Epoch 1:  50%|█████     | 157/313 [00:37<00:36,  4.32it/s]

Loss for epoch 1, batch: 249.7773


Epoch 1:  50%|█████     | 158/313 [00:37<00:36,  4.30it/s]

Loss for epoch 1, batch: 272.3960


Epoch 1:  51%|█████     | 159/313 [00:38<00:35,  4.37it/s]

Loss for epoch 1, batch: 255.5233


Epoch 1:  51%|█████     | 160/313 [00:38<00:34,  4.48it/s]

Loss for epoch 1, batch: 291.2833


Epoch 1:  51%|█████▏    | 161/313 [00:38<00:33,  4.59it/s]

Loss for epoch 1, batch: 374.6580


Epoch 1:  52%|█████▏    | 162/313 [00:38<00:33,  4.53it/s]

Loss for epoch 1, batch: 287.0134


Epoch 1:  52%|█████▏    | 163/313 [00:38<00:33,  4.44it/s]

Loss for epoch 1, batch: 227.4325


Epoch 1:  52%|█████▏    | 164/313 [00:39<00:33,  4.44it/s]

Loss for epoch 1, batch: 321.2265


Epoch 1:  53%|█████▎    | 165/313 [00:39<00:34,  4.30it/s]

Loss for epoch 1, batch: 211.6369


Epoch 1:  53%|█████▎    | 166/313 [00:39<00:35,  4.11it/s]

Loss for epoch 1, batch: 249.5514


Epoch 1:  53%|█████▎    | 167/313 [00:39<00:35,  4.07it/s]

Loss for epoch 1, batch: 300.5354


Epoch 1:  54%|█████▎    | 168/313 [00:40<00:35,  4.11it/s]

Loss for epoch 1, batch: 392.0760


Epoch 1:  54%|█████▍    | 169/313 [00:40<00:34,  4.14it/s]

Loss for epoch 1, batch: 219.8867


Epoch 1:  54%|█████▍    | 170/313 [00:40<00:34,  4.16it/s]

Loss for epoch 1, batch: 340.9252


Epoch 1:  55%|█████▍    | 171/313 [00:40<00:35,  4.04it/s]

Loss for epoch 1, batch: 251.6275


Epoch 1:  55%|█████▍    | 172/313 [00:41<00:35,  3.94it/s]

Loss for epoch 1, batch: 299.3339


Epoch 1:  55%|█████▌    | 173/313 [00:41<00:35,  3.92it/s]

Loss for epoch 1, batch: 194.5157


Epoch 1:  56%|█████▌    | 174/313 [00:41<00:35,  3.94it/s]

Loss for epoch 1, batch: 270.9695


Epoch 1:  56%|█████▌    | 175/313 [00:41<00:35,  3.92it/s]

Loss for epoch 1, batch: 248.2954


Epoch 1:  56%|█████▌    | 176/313 [00:42<00:34,  3.96it/s]

Loss for epoch 1, batch: 213.4404


Epoch 1:  57%|█████▋    | 177/313 [00:42<00:33,  4.07it/s]

Loss for epoch 1, batch: 221.4819


Epoch 1:  57%|█████▋    | 178/313 [00:42<00:33,  4.09it/s]

Loss for epoch 1, batch: 250.6827


Epoch 1:  57%|█████▋    | 179/313 [00:42<00:32,  4.14it/s]

Loss for epoch 1, batch: 249.5582


Epoch 1:  58%|█████▊    | 180/313 [00:43<00:32,  4.13it/s]

Loss for epoch 1, batch: 216.4513


Epoch 1:  58%|█████▊    | 181/313 [00:43<00:31,  4.21it/s]

Loss for epoch 1, batch: 238.9790


Epoch 1:  58%|█████▊    | 182/313 [00:43<00:30,  4.30it/s]

Loss for epoch 1, batch: 230.4812


Epoch 1:  58%|█████▊    | 183/313 [00:43<00:29,  4.34it/s]

Loss for epoch 1, batch: 229.6236


Epoch 1:  59%|█████▉    | 184/313 [00:44<00:29,  4.33it/s]

Loss for epoch 1, batch: 173.5539


Epoch 1:  59%|█████▉    | 184/313 [00:44<00:31,  4.15it/s]


KeyboardInterrupt: 

In [26]:
# Load the checkpoint without the logit layers
checkpoint_path = 'student_model_custom_contrastive.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)

# Create a fresh student model with full layers
student_full = instantiate(student_dict)

# Load the trained weights
try:
    student_full.load_state_dict(checkpoint['student_state_dict'], strict=False)
    print("Loaded trained weights (non-strict mode)")
except Exception as e:
    print(f"Error loading weights: {e}")
    model_dict = student_full.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint['student_state_dict'].items() if k in model_dict}
    model_dict.update(pretrained_dict)
    student_full.load_state_dict(model_dict)
    print("Loaded compatible weights manually")

student_full = student_full.to(device)
student_full.eval()

torch.save({
    'model_state_dict': student_full.state_dict(),
    'model_config': student_dict,
    'training_info': {
        'epochs_trained': checkpoint.get('epoch', 'unknown'),
        'architecture': 'student_with_contrastive_learning'
    }
}, 'student_model_full_with_predictions.pth')

datamodule.setup(stage="validate", train_fraction=0.1, val_fraction=0.1, test_fraction=0.1)
val_dataloader = datamodule.val_dataloader()

print("Evaluating trajectory prediction model...")

# First, let's inspect the prediction structure
sample_batch = next(iter(val_dataloader))
for k in sample_batch.keys():
    if torch.is_tensor(sample_batch[k]):
        sample_batch[k] = sample_batch[k].to(device)

with torch.no_grad():
    sample_preds = student_full(sample_batch)

print("Inspection of predictions and ground truth:")
print(f"Prediction type: {type(sample_preds)}")
if isinstance(sample_preds, dict):
    for key, value in sample_preds.items():
        if torch.is_tensor(value):
            print(f"  {key}: {value.shape}")

print(f"Ground truth 'y' shape: {sample_batch['y'].shape}")

# Multi-modal trajectory prediction evaluation
mse_errors = []
mae_errors = []
ade_errors = []
fde_errors = []

def evaluate_multimodal_prediction(pred_modes, gt_traj_ego):
    """
    Evaluate multi-modal prediction by selecting the best mode
    pred_modes: (batch, num_modes, time, 2)
    gt_traj_ego: (batch, time, 2)
    """
    batch_size, num_modes, time_steps, _ = pred_modes.shape
    
    # Compute ADE for each mode
    # Expand gt_traj_ego to match pred_modes shape
    gt_expanded = gt_traj_ego.unsqueeze(1).expand(-1, num_modes, -1, -1)  # (batch, num_modes, time, 2)
    
    # Compute displacement errors for all modes
    displacement_errors = torch.norm(pred_modes - gt_expanded, dim=-1)  # (batch, num_modes, time)
    
    # ADE for each mode: average over time
    ade_per_mode = displacement_errors.mean(dim=-1)  # (batch, num_modes)
    
    # FDE for each mode: final time step
    fde_per_mode = displacement_errors[:, :, -1]  # (batch, num_modes)
    
    # Select best mode (lowest ADE)
    best_mode_idx = ade_per_mode.argmin(dim=1)  # (batch,)
    
    # Extract best ADE and FDE
    batch_indices = torch.arange(batch_size)
    best_ade = ade_per_mode[batch_indices, best_mode_idx]  # (batch,)
    best_fde = fde_per_mode[batch_indices, best_mode_idx]  # (batch,)
    
    # Get best trajectory for MSE/MAE computation
    best_traj = pred_modes[batch_indices, best_mode_idx]  # (batch, time, 2)
    
    return best_traj, best_ade, best_fde

for batch in val_dataloader:
    for k in batch.keys():
        if torch.is_tensor(batch[k]):
            batch[k] = batch[k].to(device)

    with torch.no_grad():
        preds = student_full(batch)
        
        # Extract multi-modal predictions
        if isinstance(preds, dict) and 'y_hat' in preds:
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
        else:
            print("Could not find 'y_hat' in predictions")
            break
        
        # Extract ego agent ground truth (first agent, index 0)
        gt_traj_all = batch['y']  # (batch, 42, 60, 2)
        gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2) - ego agent only
        
        # Evaluate multi-modal prediction
        best_traj, batch_ade, batch_fde = evaluate_multimodal_prediction(pred_modes, gt_traj_ego)
        
        # Compute MSE and MAE using best trajectory
        mse = torch.nn.functional.mse_loss(best_traj, gt_traj_ego)
        mae = torch.nn.functional.l1_loss(best_traj, gt_traj_ego)
        
        mse_errors.append(mse.item())
        mae_errors.append(mae.item())
        ade_errors.extend(batch_ade.cpu().numpy())
        fde_errors.extend(batch_fde.cpu().numpy())

print(f"Validation MSE (best mode): {np.mean(mse_errors):.4f} ± {np.std(mse_errors):.4f}")
print(f"Validation MAE (best mode): {np.mean(mae_errors):.4f} ± {np.std(mae_errors):.4f}")
print(f"Validation ADE (best mode): {np.mean(ade_errors):.4f} ± {np.std(ade_errors):.4f}")
print(f"Validation FDE (best mode): {np.mean(fde_errors):.4f} ± {np.std(fde_errors):.4f}")

# Additional analysis: evaluate using prediction confidence (pi weights)
if 'pi' in sample_preds:
    print("\nEvaluating using prediction confidence weights...")
    conf_mse_errors = []
    conf_mae_errors = []
    conf_ade_errors = []
    conf_fde_errors = []
    
    for batch in val_dataloader:
        for k in batch.keys():
            if torch.is_tensor(batch[k]):
                batch[k] = batch[k].to(device)

        with torch.no_grad():
            preds = student_full(batch)
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
            pi_weights = preds['pi']  # (batch, 9) - confidence scores
            
            gt_traj_all = batch['y']
            gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2)
            
            # Select mode with highest confidence
            best_conf_idx = pi_weights.argmax(dim=1)  # (batch,)
            batch_indices = torch.arange(pred_modes.shape[0])
            conf_best_traj = pred_modes[batch_indices, best_conf_idx]  # (batch, 60, 2)
            
            # Compute metrics
            conf_mse = torch.nn.functional.mse_loss(conf_best_traj, gt_traj_ego)
            conf_mae = torch.nn.functional.l1_loss(conf_best_traj, gt_traj_ego)
            
            # ADE/FDE for confidence-based selection
            displacement_errors = torch.norm(conf_best_traj - gt_traj_ego, dim=-1)
            conf_ade = displacement_errors.mean(dim=-1)
            conf_fde = displacement_errors[:, -1]
            
            conf_mse_errors.append(conf_mse.item())
            conf_mae_errors.append(conf_mae.item())
            conf_ade_errors.extend(conf_ade.cpu().numpy())
            conf_fde_errors.extend(conf_fde.cpu().numpy())
    
    print(f"Confidence-based MSE: {np.mean(conf_mse_errors):.4f} ± {np.std(conf_mse_errors):.4f}")
    print(f"Confidence-based MAE: {np.mean(conf_mae_errors):.4f} ± {np.std(conf_mae_errors):.4f}")
    print(f"Confidence-based ADE: {np.mean(conf_ade_errors):.4f} ± {np.std(conf_ade_errors):.4f}")
    print(f"Confidence-based FDE: {np.mean(conf_fde_errors):.4f} ± {np.std(conf_fde_errors):.4f}")

Using decoder type: mlp
RUN EMP-M
Loaded trained weights (non-strict mode)
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/train, total number of files: 199908
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/val, total number of files: 24988
Using 10.0% of the training and validation datasets.
Evaluating trajectory prediction model...
Inspection of predictions and ground truth:
Prediction type: <class 'dict'>
  y_hat: torch.Size([64, 9, 60, 2])
  pi: torch.Size([64, 9])
  y_hat_others: torch.Size([64, 50, 60, 2])
  y_hat_eps: torch.Size([64, 9, 2])
  x_agent: torch.Size([64, 64])
Ground truth 'y' shape: torch.Size([64, 51, 60, 2])
Validation MSE (best mode): 388.4232 ± 73.3702
Validation MAE (best mode): 10.5574 ± 1.1664
Validation ADE (best mode): 19.5709 ± 13.8144
Validation FDE (best mode): 38.2336 ± 26.4699

Evaluating using prediction confidence weights...
Confidence-based MSE: 388.4663 ± 73.3734
Confidence-based MAE: 10.5602 ± 1.1661
Conf

In [27]:
# Load the checkpoint without the logit layers
checkpoint_path = 'student_model_custom_contrastive.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)

# # Create a fresh student model with full layers
# student_full = instantiate(student_dict)

# # Load the trained weights
# try:
#     student_full.load_state_dict(checkpoint['student_state_dict'], strict=False)
#     print("Loaded trained weights (non-strict mode)")
# except Exception as e:
#     print(f"Error loading weights: {e}")
#     model_dict = student_full.state_dict()
#     pretrained_dict = {k: v for k, v in checkpoint['student_state_dict'].items() if k in model_dict}
#     model_dict.update(pretrained_dict)
#     student_full.load_state_dict(model_dict)
#     print("Loaded compatible weights manually")

# student_full = student_full.to(device)
# student_full.eval()

# torch.save({
#     'model_state_dict': student_full.state_dict(),
#     'model_config': student_dict,
#     'training_info': {
#         'epochs_trained': checkpoint.get('epoch', 'unknown'),
#         'architecture': 'student_with_contrastive_learning'
#     }
# }, 'student_model_full_with_predictions.pth')

if teacher_dict["decoder"] == "mlp":
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empm.ckpt"
else:
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empd.ckpt"
assert os.path.exists(checkpoint), f"Checkpoint {checkpoint} does not exist"

model_path = teacher_dict["_target_"]
module = import_module(model_path[: model_path.rfind(".")])
Model: pl.LightningModule = getattr(module, model_path[model_path.rfind(".") + 1 :])
teacher = Model.load_from_checkpoint(checkpoint, **teacher_dict)
teacher.to(device)
teacher.eval()

datamodule.setup(stage="validate", train_fraction=0.1, val_fraction=0.1, test_fraction=0.1)
val_dataloader = datamodule.val_dataloader()

print("Evaluating trajectory prediction model...")

# First, let's inspect the prediction structure
sample_batch = next(iter(val_dataloader))
for k in sample_batch.keys():
    if torch.is_tensor(sample_batch[k]):
        sample_batch[k] = sample_batch[k].to(device)

with torch.no_grad():
    sample_preds = teacher(sample_batch)

print("Inspection of predictions and ground truth:")
print(f"Prediction type: {type(sample_preds)}")
if isinstance(sample_preds, dict):
    for key, value in sample_preds.items():
        if torch.is_tensor(value):
            print(f"  {key}: {value.shape}")

print(f"Ground truth 'y' shape: {sample_batch['y'].shape}")

# Multi-modal trajectory prediction evaluation
mse_errors = []
mae_errors = []
ade_errors = []
fde_errors = []

def evaluate_multimodal_prediction(pred_modes, gt_traj_ego):
    """
    Evaluate multi-modal prediction by selecting the best mode
    pred_modes: (batch, num_modes, time, 2)
    gt_traj_ego: (batch, time, 2)
    """
    batch_size, num_modes, time_steps, _ = pred_modes.shape
    
    # Compute ADE for each mode
    # Expand gt_traj_ego to match pred_modes shape
    gt_expanded = gt_traj_ego.unsqueeze(1).expand(-1, num_modes, -1, -1)  # (batch, num_modes, time, 2)
    
    # Compute displacement errors for all modes
    displacement_errors = torch.norm(pred_modes - gt_expanded, dim=-1)  # (batch, num_modes, time)
    
    # ADE for each mode: average over time
    ade_per_mode = displacement_errors.mean(dim=-1)  # (batch, num_modes)
    
    # FDE for each mode: final time step
    fde_per_mode = displacement_errors[:, :, -1]  # (batch, num_modes)
    
    # Select best mode (lowest ADE)
    best_mode_idx = ade_per_mode.argmin(dim=1)  # (batch,)
    
    # Extract best ADE and FDE
    batch_indices = torch.arange(batch_size)
    best_ade = ade_per_mode[batch_indices, best_mode_idx]  # (batch,)
    best_fde = fde_per_mode[batch_indices, best_mode_idx]  # (batch,)
    
    # Get best trajectory for MSE/MAE computation
    best_traj = pred_modes[batch_indices, best_mode_idx]  # (batch, time, 2)
    
    return best_traj, best_ade, best_fde

for batch in val_dataloader:
    for k in batch.keys():
        if torch.is_tensor(batch[k]):
            batch[k] = batch[k].to(device)

    with torch.no_grad():
        preds = teacher(batch)
        
        # Extract multi-modal predictions
        if isinstance(preds, dict) and 'y_hat' in preds:
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
        else:
            print("Could not find 'y_hat' in predictions")
            break
        
        # Extract ego agent ground truth (first agent, index 0)
        gt_traj_all = batch['y']  # (batch, 42, 60, 2)
        gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2) - ego agent only
        
        # Evaluate multi-modal prediction
        best_traj, batch_ade, batch_fde = evaluate_multimodal_prediction(pred_modes, gt_traj_ego)
        
        # Compute MSE and MAE using best trajectory
        mse = torch.nn.functional.mse_loss(best_traj, gt_traj_ego)
        mae = torch.nn.functional.l1_loss(best_traj, gt_traj_ego)
        
        mse_errors.append(mse.item())
        mae_errors.append(mae.item())
        ade_errors.extend(batch_ade.cpu().numpy())
        fde_errors.extend(batch_fde.cpu().numpy())

print(f"Validation MSE (best mode): {np.mean(mse_errors):.4f} ± {np.std(mse_errors):.4f}")
print(f"Validation MAE (best mode): {np.mean(mae_errors):.4f} ± {np.std(mae_errors):.4f}")
print(f"Validation ADE (best mode): {np.mean(ade_errors):.4f} ± {np.std(ade_errors):.4f}")
print(f"Validation FDE (best mode): {np.mean(fde_errors):.4f} ± {np.std(fde_errors):.4f}")

# Additional analysis: evaluate using prediction confidence (pi weights)
if 'pi' in sample_preds:
    print("\nEvaluating using prediction confidence weights...")
    conf_mse_errors = []
    conf_mae_errors = []
    conf_ade_errors = []
    conf_fde_errors = []
    
    for batch in val_dataloader:
        for k in batch.keys():
            if torch.is_tensor(batch[k]):
                batch[k] = batch[k].to(device)

        with torch.no_grad():
            preds = teacher(batch)
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
            pi_weights = preds['pi']  # (batch, 9) - confidence scores
            
            gt_traj_all = batch['y']
            gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2)
            
            # Select mode with highest confidence
            best_conf_idx = pi_weights.argmax(dim=1)  # (batch,)
            batch_indices = torch.arange(pred_modes.shape[0])
            conf_best_traj = pred_modes[batch_indices, best_conf_idx]  # (batch, 60, 2)
            
            # Compute metrics
            conf_mse = torch.nn.functional.mse_loss(conf_best_traj, gt_traj_ego)
            conf_mae = torch.nn.functional.l1_loss(conf_best_traj, gt_traj_ego)
            
            # ADE/FDE for confidence-based selection
            displacement_errors = torch.norm(conf_best_traj - gt_traj_ego, dim=-1)
            conf_ade = displacement_errors.mean(dim=-1)
            conf_fde = displacement_errors[:, -1]
            
            conf_mse_errors.append(conf_mse.item())
            conf_mae_errors.append(conf_mae.item())
            conf_ade_errors.extend(conf_ade.cpu().numpy())
            conf_fde_errors.extend(conf_fde.cpu().numpy())
    
    print(f"Confidence-based MSE: {np.mean(conf_mse_errors):.4f} ± {np.std(conf_mse_errors):.4f}")
    print(f"Confidence-based MAE: {np.mean(conf_mae_errors):.4f} ± {np.std(conf_mae_errors):.4f}")
    print(f"Confidence-based ADE: {np.mean(conf_ade_errors):.4f} ± {np.std(conf_ade_errors):.4f}")
    print(f"Confidence-based FDE: {np.mean(conf_fde_errors):.4f} ± {np.std(conf_fde_errors):.4f}")

Using decoder type: mlp
RUN EMP-M
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/train, total number of files: 199908
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/val, total number of files: 24988
Using 10.0% of the training and validation datasets.
Evaluating trajectory prediction model...
Inspection of predictions and ground truth:
Prediction type: <class 'dict'>
  y_hat: torch.Size([64, 6, 60, 2])
  pi: torch.Size([64, 6])
  y_hat_others: torch.Size([64, 44, 60, 2])
  y_hat_eps: torch.Size([64, 6, 2])
  x_agent: torch.Size([64, 128])
Ground truth 'y' shape: torch.Size([64, 45, 60, 2])
Validation MSE (best mode): 0.8327 ± 0.6760
Validation MAE (best mode): 0.4323 ± 0.0488
Validation ADE (best mode): 0.7080 ± 0.7275
Validation FDE (best mode): 1.6843 ± 2.0672

Evaluating using prediction confidence weights...
Confidence-based MSE: 5.3941 ± 2.0615
Confidence-based MAE: 0.9974 ± 0.1312
Confidence-based ADE: 1.7182 ± 1.8246
Confidence-based F