In [12]:
import os
import random
import json
from datetime import datetime
from pathlib import Path
import numpy as np
from hydra.utils import instantiate, to_absolute_path
import torch
from importlib import import_module
import pytorch_lightning as pl

teacher_dict = {
    "_target_": "src.model.trainer_forecast.Trainer",
    "historical_steps": 50,
    "future_steps": 60,
    "dim": 128,
    "encoder_depth": 4,
    "num_heads": 8,
    "mlp_ratio": 4.0,
    "qkv_bias": False,
    "drop_path": 0.2,
    "attention_type": "standard",  # "standard", "linear", "performer"
    "decoder": "mlp",  # Decoder type: "mlp" or "detr"
    "decoder_embed_dim": None,
    "decoder_num_modes": None,
    "decoder_hidden_dim": None,
    "pretrained_weights": None,
    "lr": 0.001,
    "weight_decay": 1e-4,
    "epochs": 20,
    "warmup_epochs": 10,
}

if teacher_dict["decoder"] == "mlp":
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empm.ckpt"
else:
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empd.ckpt"
assert os.path.exists(checkpoint), f"Checkpoint {checkpoint} does not exist"

model_path = teacher_dict["_target_"]
module = import_module(model_path[: model_path.rfind(".")])
Model: pl.LightningModule = getattr(module, model_path[model_path.rfind(".") + 1 :])
teacher = Model.load_from_checkpoint(checkpoint, **teacher_dict)

# Cut the layer that gives the logits and keep the "penultimate" layer
teacher.net.dense_predictor.pop(-1)  
teacher.net.decoder.loc.pop(-1)
teacher.net.decoder.loc.pop(-1)
teacher.net.decoder.pi.pop(-1)
teacher.net.decoder.pi.pop(-1)


Using decoder type: mlp
RUN EMP-M


ReLU()

In [13]:
# Print the model layers
print("Teacher model instantiated with the following configuration:")
for name, layer in teacher.named_children():
    print(f"{name}: {layer}")

Teacher model instantiated with the following configuration:
net: EMP(
  (h_proj): Linear(in_features=5, out_features=128, bias=True)
  (h_embed): ModuleList(
    (0): Block(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (drop_path1): Identity()
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=128, out_features=512, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=512, out_features=128, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (drop_path2): Identity()
    )
    (1): Block(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=

In [14]:
student_dict = {
    "_target_": "src.model.trainer_forecast.Trainer",
    "dim": 64,
    "historical_steps": 50,
    "future_steps": 60,
    "encoder_depth": 3,
    "num_heads": 4,
    "mlp_ratio": 2.0,
    "qkv_bias": False,
    "drop_path": 0.2,
    "attention_type": "performer",  # "standard", "linear", "performer"
    "decoder": "mlp",
    "decoder_embed_dim": 64,
    "decoder_num_modes": 9,
    "decoder_hidden_dim": 256,
    "pretrained_weights": None,
    "lr": 0.001,
    "weight_decay": 1e-4,
    "epochs": 60,
    "warmup_epochs": 10,
}

student = instantiate(student_dict)
# Remove the last layer from the encoder or decoder
student.net.dense_predictor = student.net.dense_predictor[:-1]  
student.net.decoder.loc.pop(-1)
student.net.decoder.loc.pop(-1)
student.net.decoder.pi.pop(-1)
student.net.decoder.pi.pop(-1)


# Print the model layers
print("Student model instantiated with the following configuration:")
for name, layer in student.named_children():
    print(f"{name}: {layer}")

Using decoder type: mlp
RUN EMP-M
Student model instantiated with the following configuration:
net: EMP(
  (h_proj): Linear(in_features=5, out_features=64, bias=True)
  (h_embed): ModuleList(
    (0): Block(
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (attn): PerformerAttention(
        (q_proj): Linear(in_features=64, out_features=64, bias=False)
        (k_proj): Linear(in_features=64, out_features=64, bias=False)
        (v_proj): Linear(in_features=64, out_features=64, bias=False)
        (out_proj): Linear(in_features=64, out_features=64, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (drop_path1): Identity()
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=64, out_features=128, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=128, out_features=64, bias=True)
        (drop

In [15]:
# Create datamodule
datamodule_dict = {
    "_target_": "src.datamodule.av2_datamodule.Av2DataModule",
    "data_root": "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data",
    "data_folder": "emp",
    "train_batch_size": 64,
    "val_batch_size": 64,
    "test_batch_size": 64,
    "shuffle": True,
    "num_workers": 18,
    "pin_memory": True,
}

datamodule = instantiate(datamodule_dict)

FRACTION = 0.5  

datamodule.setup(stage="fit", train_fraction=FRACTION, val_fraction=FRACTION, test_fraction=FRACTION)

train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()

print(f"Number of training samples: {len(train_dataloader.dataset)} out of {len(datamodule.full_train_dataset)}")
print(f"Number of validation samples: {len(val_dataloader.dataset)} out of {len(datamodule.full_val_dataset)}")

data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/train, total number of files: 199908
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/val, total number of files: 24988
Using 50.0% of the training and validation datasets.
Number of training samples: 99954 out of 199908
Number of validation samples: 12494 out of 24988


In [16]:
# Check if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [17]:

# from tqdm import tqdm
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# teacher = teacher.to(device)
# teacher.eval()
# features_list_train = []
# features_list_val = []
# points_train = []
# points_val = []
# with torch.no_grad():
#     for data in tqdm(train_dataloader):
#         for k in data.keys():
#             if torch.is_tensor(data[k]): data[k] = data[k].to(device)
        
#         points_train.append(data)
        
#         features = teacher.net.get_features(data)
#         features_list_train.append(features)
        
        
#     for data in tqdm(val_dataloader):
#         for k in data.keys():
#             if torch.is_tensor(data[k]): data[k] = data[k].to(device)
        
#         features = teacher.net.get_features(data)
#         features_list_val.append(features)
#         points_val.append(data)
        
# # Save the features to a file
# import pickle as pkl
# output_file_train = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_train.pkl"
# with open(output_file_train, "wb") as f:
#     pkl.dump(features_list_train, f)
    
# output_file_val = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_val.pkl"
# with open(output_file_val, "wb") as f:
#     pkl.dump(features_list_val, f)


# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_train.pkl"
# with open(output_file_data, "wb") as f:
#     pkl.dump(points_train, f)

# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_val.pkl"
# with open(output_file_data, "wb") as f:
#     pkl.dump(points_val, f)


In [26]:
import torch
import torch.nn.functional as F

def h_score(t_feat, s_feat, tau=0.1, N_over_M=1.0):
    """
    Compute h(T, S) as defined in Eq (19) in the paper.
    """
    # Normalize features 
    t_feat = F.normalize(t_feat, dim=1)
    s_feat = F.normalize(s_feat, dim=1)

    # Similarity scores
    sim = torch.sum(t_feat * s_feat, dim=1) / tau  # Exponent of e for numerator
    numerator = torch.exp(sim)
    denominator = numerator + N_over_M
    h = numerator / denominator  # shape: (B,)
    return h


# Test function
def test_h_score():
    # Create dummy features
    t_feat = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    s_feat = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

    # Compute h-score
    h = h_score(t_feat, s_feat)

    # Print result
    print("h-score:", h)
    # Expected output: h-score: tensor([1.0000, 1.0000, 1.0000])
    assert torch.allclose(h, torch.tensor([1.0, 1.0, 1.0]), atol=1e-4), "h-score test failed!"
    
test_h_score()

h-score: tensor([1.0000, 1.0000, 1.0000])


In [32]:
def critic_loss(t_pos, s_pos, t_neg, s_neg, tau=0.1, N=1.0, M = 1.0):
    """
    Compute the full critic loss:
    - t_pos, s_pos: matched teacher/student features (C=1)
    - t_neg, s_neg: mismatched features (C=0)
    """
    N_over_M = N / M if M != 0 else 1.0  # Avoid division by zero
    
    # Positive pair scores
    h_pos = h_score(t_pos, s_pos, tau, N_over_M)
    loss_pos = torch.median(torch.log(h_pos + 1e-8))  # Explicit Monte Carlo estimate

    # Negative pair scores
    h_neg = h_score(t_neg, s_neg, tau, N_over_M)
    loss_neg = torch.median(torch.log(1 - h_neg + 1e-8))  # Explicit Monte Carlo estimate

    # Combine as in Eq (18)
    loss = -(loss_pos + N * loss_neg)
    return loss

# Test function
def test_critic_loss():
    # Here I tried to find two tensors that would yield a loss of almost 0
    t_pos = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    s_pos = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

    t_neg = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    s_neg = torch.tensor([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0]])

    # Compute critic loss
    loss = critic_loss(t_pos, s_pos, t_neg, s_neg, tau=0.1)

    # Print result
    print("Critic loss:", loss)
    # Expected output: Critic loss: tensor(1.)
    assert torch.isclose(loss, torch.tensor(0.), atol=1e-4), "Critic loss test failed!"
    
test_critic_loss()

Critic loss: tensor(    0.0001)


In [9]:
# # Load features from the saved file
# from tqdm import tqdm
# import pickle as pkl


# output_file_train = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_train.pkl"
# with open(output_file_train, "rb") as f:
#     features_list_train = pkl.load(f)
    
# output_file_val = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/features_val.pkl"
# with open(output_file_val, "rb") as f:
#     features_list_val = pkl.load(f)
    
# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_train.pkl"
# with open(output_file_data, "rb") as f:
#     points_train = pkl.load(f)
    
# output_file_data = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/data_points_val.pkl"
# with open(output_file_data, "rb") as f:
#     points_val = pkl.load(f)
    
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:

# student.train()
# student = student.to(device)
# teacher.to(device)
# teacher.eval()
# batch_size = 64

# # Already organized in batches


# # If the embedding dim of the student is different from the teacher, add a projection layer
# # Get a batch of data
# batch = points_train[0]
# for k in batch.keys():
#     if torch.is_tensor(batch[k]): 
#         batch[k] = batch[k].to(device)

# with torch.no_grad():
#     dim_teacher = teacher.net.get_features(batch).shape[1]
#     dim_student = student.net.get_features(batch).shape[1]

# print(f"Teacher feature dimension: {dim_teacher}, Student feature dimension: {dim_student}")
# projection_layer = None
# if dim_teacher != dim_student:
#     projection_layer = torch.nn.Linear(dim_student, dim_teacher)
#     projection_layer.to(device)
#     print(f"Projection layer added: {projection_layer}")
    
#     # Freeze the projection layer
#     for param in projection_layer.parameters():
#         param.requires_grad = False
    
# optimizer, scheduler = student.configure_optimizers()
# optimizer = optimizer[0] if isinstance(optimizer, list) else optimizer
# scheduler = scheduler[0] if isinstance(scheduler, list) else scheduler
    
# # M = nr of positive pairs, N = nr of negative pairs
# M = len(features_list_train[0])  # Number of positive pairs in a batch
# N = len(features_list_train[0])  # Number of negative pairs in a batch
# for i in range(10):
#     losses = []
#     print(f"Epoch {i+1}/{10}")
#     for idx in tqdm(range(len(points_train))):
#         batch = points_train[idx]
#         for k in batch.keys():
#             if torch.is_tensor(batch[k]): batch[k] = batch[k].to(device)
#         # Eliminate "y" key if it exists
#         if 'y' in batch:
#             del batch['y']
#         feature_contrast = student.net.get_features(batch)

#         if projection_layer is not None:
#             feature_contrast = projection_layer(feature_contrast)


#         batch_feats = features_list_train[idx]

#         # Get a random feature from the batch
#         random_idx = random.randint(0, len(batch_feats) - 1)
#         t_pos = batch_feats[random_idx].unsqueeze(0)  # [1, feature_dim]
#         s_pos = feature_contrast[random_idx].unsqueeze(0)  # [1, feature_dim]
        
#         # Get a random negative feature from the batch
#         neg_idx = random.randint(0, len(batch_feats) - 1)
#         while neg_idx == random_idx:  # Ensure it's different from the positive index
#             neg_idx = random.randint(0, len(batch_feats) - 1)
#         t_neg = batch_feats[neg_idx].unsqueeze(0)  # [1, feature_dim]
#         s_neg = feature_contrast[neg_idx].unsqueeze(0)  # [1, feature_dim]
#         # Compute the critic loss
#         loss = critic_loss(t_pos, s_pos, t_neg, s_neg, tau=0.1, N=N, M=M)
#         # print(f"Critic loss: {loss.item()}")
        
#         # Backpropagate the loss
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         losses.append(loss.item())
#     losses = np.mean(losses)
#     print(f"Epoch {i+1} losses: {losses}")
#     if scheduler is not None:
#         if isinstance(scheduler, list):
#             for sch in scheduler:
#                 sch.step()
#         else:
#             scheduler.step()        

In [33]:
teacher.to(device)
student.to(device)

batch = next(iter(train_dataloader))
for k in batch.keys():
    if torch.is_tensor(batch[k]): batch[k] = batch[k].to(device)
# Find the expected output size
with torch.no_grad():
    dim_teacher = teacher.net.get_features(batch).shape[1]
    dim_student = student.net.get_features(batch).shape[1]

print(f"Teacher feature dimension: {dim_teacher}, Student feature dimension: {dim_student}")

# Create learnable projection layer
projection_layer = None
if dim_teacher != dim_student:
    projection_layer = torch.nn.Linear(dim_student, dim_teacher).to(device)
    print(f"Learnable projection layer added: {projection_layer}")
    
    # # Freeze the projection layer
    # for param in projection_layer.parameters():
    #     param.requires_grad = False



Teacher feature dimension: 512, Student feature dimension: 384
Learnable projection layer added: Linear(in_features=384, out_features=512, bias=True)


In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt

teacher.to(device)
student.to(device)
student.train()
teacher.eval()

memory_bank_size = 2048  # Reduced memory bank size for better management
memory_buffer = []

# Create optimizer
if projection_layer is not None:
    all_params = list(student.parameters()) + list(projection_layer.parameters())
    optimizer = torch.optim.AdamW(all_params, lr=1e-4, weight_decay=1e-4)
else:
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

def contrastive_loss_batch_online(student_feats, pos_features, neg_features=None, tau=0.1, NEG_RATIO=1.0):
    """
    Online contrastive loss function using your custom critic loss
    """
    batch_size = student_feats.shape[0]
    
    if batch_size == 1:
        return F.mse_loss(student_feats, pos_features)
    
    # For negative pairs
    if neg_features is not None and neg_features.shape[0] > 0:
        # Sample negative features randomly based on ratio
        num_negatives = min(int(batch_size * NEG_RATIO), neg_features.shape[0])
        if num_negatives > 0:
            neg_indices = torch.randperm(neg_features.shape[0])[:num_negatives]
            
            t_pos = pos_features  # [B, feature_dim]
            s_pos = student_feats  # [B, feature_dim]
            
            t_neg = neg_features[neg_indices]  # [neg_samples, feature_dim]
            
            # Create corresponding student negatives by cycling through student features
            s_neg_indices = torch.arange(num_negatives) % batch_size
            s_neg = student_feats[s_neg_indices]  # [neg_samples, feature_dim]
            
            M = t_pos.shape[0]  # Number of positive pairs
            N = t_neg.shape[0]  # Number of negative pairs
            
            loss = critic_loss(t_pos, s_pos, t_neg, s_neg, tau=tau, N=N, M=M)
        else:
            loss = F.mse_loss(student_feats, pos_features)
    else:
        # Fallback to MSE if no negatives available
        loss = F.mse_loss(student_feats, pos_features)
    
    return loss

epochs = 3
print("Starting online contrastive training with custom loss and memory bank...")

NEG_RATIO = 0.4

losses = []
LOSS_BUFFER_SIZE = 3000

for epoch in range(epochs):
    epoch_loss = 0
    num_batches = 0
    
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    # The training loop iterates directly over the train_dataloader
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        for k in batch.keys():
            if torch.is_tensor(batch[k]): 
                batch[k] = batch[k].to(device)
        
        # Eliminate "y" key if it exists
        if 'y' in batch:
            del batch['y']
        
        # Get student features
        student_features = student.net.get_features(batch)
        
        # Apply projection
        if projection_layer is not None:
            student_features = projection_layer(student_features)
        
        # Generate positive features on-the-fly from the teacher
        with torch.no_grad():
            pos_features = teacher.net.get_features(batch)
        
            # Update memory buffer with new features (FIFO queue)
            pos_features_cpu = pos_features.cpu()
            
            # Concatenate new features to memory buffer
            for feature in pos_features_cpu:
                memory_buffer.append(feature)
            
            # Limit memory buffer size
            if len(memory_buffer) > memory_bank_size:
                memory_buffer.pop(0)
        
        # Prepare negative features for loss computation
        neg_features_for_loss = None
        if len(memory_buffer) > 0:
            # Convert memory buffer back to GPU tensors for loss computation
            neg_features_for_loss = torch.stack(memory_buffer).to(device)
            
        loss = contrastive_loss_batch_online(
            student_features, 
            pos_features, 
            neg_features=neg_features_for_loss, 
            tau=0.05, 
            NEG_RATIO=NEG_RATIO
        )

        losses.append(loss.item())
        if len(losses) > LOSS_BUFFER_SIZE:
            losses.pop(0)

        optimizer.zero_grad()
        loss.backward()        
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        
        if num_batches % 10 == 0:
            plt.figure(figsize=(10, 5))
            if len(losses) > 50:
                window_size = min(50, len(losses) // 10)
                smoothed_losses = np.convolve(losses, np.ones(window_size)/window_size, mode='valid')
                smooth_x = np.arange(window_size-1, len(losses))
                plt.plot(smooth_x, smoothed_losses, color='blue', linewidth=2, label=f'Moving Average (window={window_size})')
            else:
                plt.plot(np.arange(len(losses)), losses, color='lightblue', linewidth=1, label='Raw Loss')

            
            plt.xlabel('Batch')
            plt.ylabel('Loss Value')
            plt.title(f'Epoch {epoch + 1} Loss Progress (Moving Average)')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(f"loss_distillation_smoothed.png", dpi=150, bbox_inches='tight')
            plt.close()

        if num_batches % 100 == 0:
            print(f"Batch {num_batches}: Loss = {loss.item():.4f}")
    
    scheduler.step()
    
    if num_batches > 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")
    else:
        print(f"Epoch {epoch + 1} completed. No batches were processed.")

print("Online contrastive training completed!")

Starting online contrastive training with custom loss and memory bank...

Epoch 1/3


Epoch 1:   6%|▋         | 100/1562 [00:37<08:45,  2.78it/s]

Batch 100: Loss = 4.2903


Epoch 1:  13%|█▎        | 200/1562 [01:14<09:59,  2.27it/s]

Batch 200: Loss = 4.1856


Epoch 1:  19%|█▉        | 300/1562 [01:53<08:06,  2.59it/s]

Batch 300: Loss = 3.8885


Epoch 1:  26%|██▌       | 400/1562 [02:30<08:02,  2.41it/s]

Batch 400: Loss = 3.9247


Epoch 1:  32%|███▏      | 500/1562 [03:09<07:11,  2.46it/s]

Batch 500: Loss = 3.5616


Epoch 1:  38%|███▊      | 600/1562 [03:49<06:50,  2.34it/s]

Batch 600: Loss = 3.9076


Epoch 1:  45%|████▍     | 700/1562 [04:28<05:40,  2.54it/s]

Batch 700: Loss = 3.4219


Epoch 1:  51%|█████     | 800/1562 [05:08<05:20,  2.38it/s]

Batch 800: Loss = 3.3985


Epoch 1:  58%|█████▊    | 900/1562 [05:49<04:36,  2.39it/s]

Batch 900: Loss = 3.3754


Epoch 1:  64%|██████▍   | 1000/1562 [06:30<03:56,  2.38it/s]

Batch 1000: Loss = 2.9685


Epoch 1:  70%|███████   | 1100/1562 [07:12<03:18,  2.33it/s]

Batch 1100: Loss = 2.9214


Epoch 1:  77%|███████▋  | 1200/1562 [07:56<02:47,  2.16it/s]

Batch 1200: Loss = 3.2257


Epoch 1:  83%|████████▎ | 1300/1562 [08:40<02:02,  2.14it/s]

Batch 1300: Loss = 3.5046


Epoch 1:  90%|████████▉ | 1400/1562 [09:29<01:22,  1.96it/s]

Batch 1400: Loss = 3.1751


Epoch 1:  96%|█████████▌| 1500/1562 [10:15<00:30,  2.05it/s]

Batch 1500: Loss = 2.8062


Epoch 1: 100%|██████████| 1562/1562 [10:43<00:00,  2.43it/s]


Epoch 1 completed. Average Loss: 3.8118

Epoch 2/3


Epoch 2:   6%|▋         | 100/1562 [00:47<12:08,  2.01it/s]

Batch 100: Loss = 3.1506


Epoch 2:  13%|█▎        | 200/1562 [01:35<10:57,  2.07it/s]

Batch 200: Loss = 3.0198


Epoch 2:  19%|█▉        | 300/1562 [02:24<10:36,  1.98it/s]

Batch 300: Loss = 3.6756


Epoch 2:  26%|██▌       | 400/1562 [03:12<09:40,  2.00it/s]

Batch 400: Loss = 3.3759


Epoch 2:  32%|███▏      | 500/1562 [04:02<09:08,  1.94it/s]

Batch 500: Loss = 2.4824


Epoch 2:  38%|███▊      | 600/1562 [04:52<08:07,  1.97it/s]

Batch 600: Loss = 2.8384


Epoch 2:  45%|████▍     | 700/1562 [05:43<07:38,  1.88it/s]

Batch 700: Loss = 3.9443


Epoch 2:  51%|█████     | 800/1562 [06:35<06:39,  1.91it/s]

Batch 800: Loss = 2.8144


Epoch 2:  58%|█████▊    | 900/1562 [07:27<06:00,  1.84it/s]

Batch 900: Loss = 3.5903


Epoch 2:  64%|██████▍   | 1000/1562 [08:21<05:26,  1.72it/s]

Batch 1000: Loss = 2.7193


Epoch 2:  70%|███████   | 1100/1562 [09:14<04:17,  1.79it/s]

Batch 1100: Loss = 2.7227


Epoch 2:  77%|███████▋  | 1200/1562 [10:10<03:30,  1.72it/s]

Batch 1200: Loss = 2.9943


Epoch 2:  83%|████████▎ | 1300/1562 [11:05<02:30,  1.75it/s]

Batch 1300: Loss = 3.1031


Epoch 2:  90%|████████▉ | 1400/1562 [12:02<01:35,  1.70it/s]

Batch 1400: Loss = 2.8758


Epoch 2:  96%|█████████▌| 1500/1562 [13:02<00:39,  1.59it/s]

Batch 1500: Loss = 2.9377


Epoch 2: 100%|██████████| 1562/1562 [13:41<00:00,  1.90it/s]


Epoch 2 completed. Average Loss: 3.0069

Epoch 3/3


Epoch 3:   6%|▋         | 100/1562 [01:05<16:00,  1.52it/s]

Batch 100: Loss = 2.3023


Epoch 3:  13%|█▎        | 200/1562 [02:09<14:40,  1.55it/s]

Batch 200: Loss = 3.7220


Epoch 3:  19%|█▉        | 300/1562 [03:12<13:18,  1.58it/s]

Batch 300: Loss = 2.8158


Epoch 3:  26%|██▌       | 400/1562 [04:17<12:50,  1.51it/s]

Batch 400: Loss = 3.2722


Epoch 3:  32%|███▏      | 500/1562 [05:23<11:48,  1.50it/s]

Batch 500: Loss = 3.1270


Epoch 3:  38%|███▊      | 600/1562 [06:30<11:14,  1.43it/s]

Batch 600: Loss = 3.7962


Epoch 3:  45%|████▍     | 700/1562 [07:38<09:55,  1.45it/s]

Batch 700: Loss = 2.4666


Epoch 3:  51%|█████     | 800/1562 [08:47<08:57,  1.42it/s]

Batch 800: Loss = 2.3797


Epoch 3:  58%|█████▊    | 900/1562 [09:57<08:00,  1.38it/s]

Batch 900: Loss = 2.3092


Epoch 3:  64%|██████▍   | 1000/1562 [11:06<06:51,  1.37it/s]

Batch 1000: Loss = 2.3858


Epoch 3:  70%|███████   | 1100/1562 [12:17<05:39,  1.36it/s]

Batch 1100: Loss = 2.7917


Epoch 3:  77%|███████▋  | 1200/1562 [13:28<04:35,  1.32it/s]

Batch 1200: Loss = 3.2554


Epoch 3:  83%|████████▎ | 1300/1562 [14:41<03:18,  1.32it/s]

Batch 1300: Loss = 3.0826


Epoch 3:  90%|████████▉ | 1400/1562 [15:55<01:53,  1.42it/s]

Batch 1400: Loss = 2.1903


Epoch 3:  96%|█████████▌| 1500/1562 [17:07<00:46,  1.32it/s]

Batch 1500: Loss = 2.5514


Epoch 3: 100%|██████████| 1562/1562 [17:52<00:00,  1.46it/s]

Epoch 3 completed. Average Loss: 2.7446
Online contrastive training completed!





In [17]:
# Save the model
torch.save({
    'student_state_dict': student.state_dict(),
    'projection_layer_state_dict': projection_layer.state_dict() if projection_layer else None,
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epochs,
}, 'student_trainableProj_loc_3ep_05tau.pth')

print("Model saved as 'student_trainableProj_loc_3ep_05tau.pth'")

Model saved as 'student_trainableProj_loc_3ep_05tau.pth'


In [13]:
# More training
checkpoint = torch.load('student_trainableProj_locPiOthers.pth')
student.load_state_dict(checkpoint['student_state_dict'])
if projection_layer and checkpoint['projection_layer_state_dict']:
    projection_layer.load_state_dict(checkpoint['projection_layer_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint.get('epoch', 0) + 1

for epoch in range(start_epoch, start_epoch + epochs):
    epoch_loss = 0
    num_batches = 0
    
    print(f"\nEpoch {epoch + 1}/{epochs}")
    
    # The training loop iterates directly over the train_dataloader
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}"):
        for k in batch.keys():
            if torch.is_tensor(batch[k]): 
                batch[k] = batch[k].to(device)
        
        # Eliminate "y" key if it exists
        if 'y' in batch:
            del batch['y']
        
        # Get student features
        student_features = student.net.get_features(batch)
        
        # Apply projection
        if projection_layer is not None:
            student_features = projection_layer(student_features)
        
        # Generate positive features on-the-fly from the teacher
        with torch.no_grad():
            pos_features = teacher.net.get_features(batch)
        
            # Update memory buffer with new features (FIFO queue)
            pos_features_cpu = pos_features.cpu()
            
            # Concatenate new features to memory buffer
            for feature in pos_features_cpu:
                memory_buffer.append(feature)
            
            # Limit memory buffer size
            if len(memory_buffer) > memory_bank_size:
                memory_buffer.pop(0)
        
        # Prepare negative features for loss computation
        neg_features_for_loss = None
        if len(memory_buffer) > 0:
            # Convert memory buffer back to GPU tensors for loss computation
            neg_features_for_loss = torch.stack(memory_buffer).to(device)
            
        loss = contrastive_loss_batch_online(
            student_features, 
            pos_features, 
            neg_features=neg_features_for_loss, 
            tau=0.05, 
            NEG_RATIO=NEG_RATIO
        )

        losses.append(loss.item())
        if len(losses) > LOSS_BUFFER_SIZE:
            losses.pop(0)

        optimizer.zero_grad()
        loss.backward()        
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        
        if num_batches % 10 == 0:
            plt.figure(figsize=(10, 5))
            if len(losses) > 50:
                window_size = min(50, len(losses) // 10)
                smoothed_losses = np.convolve(losses, np.ones(window_size)/window_size, mode='valid')
                smooth_x = np.arange(window_size-1, len(losses))
                plt.plot(smooth_x, smoothed_losses, color='blue', linewidth=2, label=f'Moving Average (window={window_size})')
            else:
                plt.plot(np.arange(len(losses)), losses, color='lightblue', linewidth=1, label='Raw Loss')

            
            plt.xlabel('Batch')
            plt.ylabel('Loss Value')
            plt.title(f'Epoch {epoch + 1} Loss Progress (Moving Average)')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(f"loss_distillation_smoothed.png", dpi=150, bbox_inches='tight')
            plt.close()

        if num_batches % 100 == 0:
            print(f"Batch {num_batches}: Loss = {loss.item():.4f}")
    
    scheduler.step()
    
    if num_batches > 0:
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")
    else:
        print(f"Epoch {epoch + 1} completed. No batches were processed.")

print("Online contrastive training completed!")

FileNotFoundError: [Errno 2] No such file or directory: 'student_trainableProj_locPiOthers.pth'

In [19]:
# Load the model for testing
checkpoint = torch.load('student_trainableProj_loc_3ep_05tau.pth')
student = instantiate(student_dict)
# Remove the last layer from the encoder or decoder
student.net.dense_predictor = student.net.dense_predictor[:-1]  
student.net.decoder.loc = student.net.decoder.loc[:-1]
student.net.decoder.pi = student.net.decoder.pi[:-1]
student.load_state_dict(checkpoint['student_state_dict'])  # <-- FIXED LINE

teacher.eval()
student.eval()
teacher.to(device)
student.to(device)

batch_pos = next(iter(train_dataloader))
for k in batch_pos.keys():
    if torch.is_tensor(batch_pos[k]): 
        batch_pos[k] = batch_pos[k].to(device)
# Eliminate "y" key if it exists
if 'y' in batch_pos:
    del batch_pos['y']
    
batch_neg = next(iter(train_dataloader))
for k in batch_neg.keys():
    if torch.is_tensor(batch_neg[k]): 
        batch_neg[k] = batch_neg[k].to(device)
# Eliminate "y" key if it exists
if 'y' in batch_neg:
    del batch_neg['y']

# After loading student and teacher and before loss calculation
with torch.no_grad():
    pos_teacher_features = teacher.net.get_features(batch_pos)
    neg_teacher_features = teacher.net.get_features(batch_neg)
    pos_student_features = student.net.get_features(batch_pos)

    # Project student features if needed
    if pos_teacher_features.shape[1] != pos_student_features.shape[1]:
        projection_layer = torch.nn.Linear(pos_student_features.shape[1], pos_teacher_features.shape[1]).to(device)
        # Optionally load projection weights if you saved them
        if 'projection_layer_state_dict' in checkpoint and checkpoint['projection_layer_state_dict'] is not None:
            projection_layer.load_state_dict(checkpoint['projection_layer_state_dict'])
        pos_student_features = projection_layer(pos_student_features)

loss = contrastive_loss_batch_online(
    pos_student_features, 
    pos_teacher_features, 
    neg_features=neg_teacher_features, 
    tau=0.05, 
    NEG_RATIO=NEG_RATIO
)
print(f"Test loss after loading model: {loss.item():.4f}")

Using decoder type: mlp
RUN EMP-M
Test loss after loading model: 3.1843


In [20]:
student = instantiate(student_dict)
checkpoint = torch.load('student_trainableProj_locPiOthers.pth')
student.load_state_dict(checkpoint['student_state_dict'], strict=False)

student.to(device)
student.train()

optimizer, scheduler = student.configure_optimizers()
optimizer = optimizer[0] if isinstance(optimizer, list) else optimizer
scheduler = scheduler[0] if isinstance(scheduler, list) else scheduler

dataloader = datamodule.train_dataloader()

# Freeze all layers except the last one
for name, param in student.named_parameters():
    if 'dense_predictor' in name or 'decoder' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

for epoch in range(5):
    print(f"Epoch {epoch+1}/5")
    losses = []
    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        for k in batch.keys():
            if torch.is_tensor(batch[k]):
                batch[k] = batch[k].to(device)
        preds = student(batch)
        # preds['y_hat']: (batch, 9, 60, 2), batch['y'][:, 0, :, :] : (batch, 60, 2)
        gt_traj_ego = batch['y'][:, 0, :, :]  # (batch, 60, 2)
        pred_modes = preds['y_hat']           # (batch, 9, 60, 2)
        # Compute ADE for each mode
        gt_expanded = gt_traj_ego.unsqueeze(1).expand(-1, pred_modes.shape[1], -1, -1)  # (batch, 9, 60, 2)
        displacement_errors = torch.norm(pred_modes - gt_expanded, dim=-1)  # (batch, 9, 60)
        ade_per_mode = displacement_errors.mean(dim=-1)  # (batch, 9)
        best_mode_idx = ade_per_mode.argmin(dim=1)       # (batch,)
        batch_indices = torch.arange(pred_modes.shape[0])
        best_pred = pred_modes[batch_indices, best_mode_idx]  # (batch, 60, 2)
        # Now compute loss
        loss = torch.nn.functional.mse_loss(best_pred, gt_traj_ego)
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if scheduler is not None:
        scheduler.step()
    print(f"Epoch {epoch+1} done.")
    print(f"Average loss: {np.mean(losses):.4f}")

print("Fine-tuning complete.")

Using decoder type: mlp
RUN EMP-M
Epoch 1/5


Epoch 1: 100%|██████████| 1562/1562 [02:25<00:00, 10.71it/s]


Epoch 1 done.
Average loss: 79.2203
Epoch 2/5


Epoch 2: 100%|██████████| 1562/1562 [02:26<00:00, 10.69it/s]


Epoch 2 done.
Average loss: 11.2786
Epoch 3/5


Epoch 3: 100%|██████████| 1562/1562 [02:23<00:00, 10.92it/s]


Epoch 3 done.
Average loss: 7.9010
Epoch 4/5


Epoch 4: 100%|██████████| 1562/1562 [02:24<00:00, 10.79it/s]


Epoch 4 done.
Average loss: 5.3608
Epoch 5/5


Epoch 5: 100%|██████████| 1562/1562 [02:25<00:00, 10.70it/s]

Epoch 5 done.
Average loss: 4.5518
Fine-tuning complete.





In [21]:
# Save the model
torch.save({
    'student_state_dict': student.state_dict(),
    'projection_layer_state_dict': projection_layer.state_dict() if projection_layer else None,
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epochs,
}, 'student_trainableProj_loc_3ep_05tau_more.pth')

print("Model saved as 'student_trainableProj_loc_3ep_05tau_more.pth'")

Model saved as 'student_trainableProj_loc_3ep_05tau_more.pth'


In [23]:
# Load the checkpoint without the logit layers
checkpoint_path = 'student_trainableProj_loc_3ep_05tau_more.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)

# Create a fresh student model with full layers
student_full = instantiate(student_dict)

# Load the trained weights
try:
    student_full.load_state_dict(checkpoint['student_state_dict'], strict=False)
    print("Loaded trained weights (non-strict mode)")
except Exception as e:
    print(f"Error loading weights: {e}")
    model_dict = student_full.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint['student_state_dict'].items() if k in model_dict}
    model_dict.update(pretrained_dict)
    student_full.load_state_dict(model_dict)
    print("Loaded compatible weights manually")

student_full = student_full.to(device)
student_full.eval()

# torch.save({
#     'model_state_dict': student_full.state_dict(),
#     'model_config': student_dict,
#     'training_info': {
#         'epochs_trained': checkpoint.get('epoch', 'unknown'),
#         'architecture': 'student_with_contrastive_learning'
#     }
# }, 'student_model_full_with_predictions.pth')

datamodule.setup(stage="validate", train_fraction=0.1, val_fraction=0.1, test_fraction=0.1)
val_dataloader = datamodule.val_dataloader()

print("Evaluating trajectory prediction model...")

# First, let's inspect the prediction structure
sample_batch = next(iter(val_dataloader))
for k in sample_batch.keys():
    if torch.is_tensor(sample_batch[k]):
        sample_batch[k] = sample_batch[k].to(device)

with torch.no_grad():
    sample_preds = student_full(sample_batch)

print("Inspection of predictions and ground truth:")
print(f"Prediction type: {type(sample_preds)}")
if isinstance(sample_preds, dict):
    for key, value in sample_preds.items():
        if torch.is_tensor(value):
            print(f"  {key}: {value.shape}")

print(f"Ground truth 'y' shape: {sample_batch['y'].shape}")

# Multi-modal trajectory prediction evaluation
mse_errors = []
mae_errors = []
ade_errors = []
fde_errors = []

def evaluate_multimodal_prediction(pred_modes, gt_traj_ego):
    """
    Evaluate multi-modal prediction by selecting the best mode
    pred_modes: (batch, num_modes, time, 2)
    gt_traj_ego: (batch, time, 2)
    """
    batch_size, num_modes, time_steps, _ = pred_modes.shape
    
    # Compute ADE for each mode
    # Expand gt_traj_ego to match pred_modes shape
    gt_expanded = gt_traj_ego.unsqueeze(1).expand(-1, num_modes, -1, -1)  # (batch, num_modes, time, 2)
    
    # Compute displacement errors for all modes
    displacement_errors = torch.norm(pred_modes - gt_expanded, dim=-1)  # (batch, num_modes, time)
    
    # ADE for each mode: average over time
    ade_per_mode = displacement_errors.mean(dim=-1)  # (batch, num_modes)
    
    # FDE for each mode: final time step
    fde_per_mode = displacement_errors[:, :, -1]  # (batch, num_modes)
    
    # Select best mode (lowest ADE)
    best_mode_idx = ade_per_mode.argmin(dim=1)  # (batch,)
    
    # Extract best ADE and FDE
    batch_indices = torch.arange(batch_size)
    best_ade = ade_per_mode[batch_indices, best_mode_idx]  # (batch,)
    best_fde = fde_per_mode[batch_indices, best_mode_idx]  # (batch,)
    
    # Get best trajectory for MSE/MAE computation
    best_traj = pred_modes[batch_indices, best_mode_idx]  # (batch, time, 2)
    
    return best_traj, best_ade, best_fde

for batch in val_dataloader:
    for k in batch.keys():
        if torch.is_tensor(batch[k]):
            batch[k] = batch[k].to(device)

    with torch.no_grad():
        preds = student_full(batch)
        
        # Extract multi-modal predictions
        if isinstance(preds, dict) and 'y_hat' in preds:
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
        else:
            print("Could not find 'y_hat' in predictions")
            break
        
        # Extract ego agent ground truth (first agent, index 0)
        gt_traj_all = batch['y']  # (batch, 42, 60, 2)
        gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2) - ego agent only
        
        # Evaluate multi-modal prediction
        best_traj, batch_ade, batch_fde = evaluate_multimodal_prediction(pred_modes, gt_traj_ego)
        
        # Compute MSE and MAE using best trajectory
        mse = torch.nn.functional.mse_loss(best_traj, gt_traj_ego)
        mae = torch.nn.functional.l1_loss(best_traj, gt_traj_ego)
        
        mse_errors.append(mse.item())
        mae_errors.append(mae.item())
        ade_errors.extend(batch_ade.cpu().numpy())
        fde_errors.extend(batch_fde.cpu().numpy())

print(f"Validation MSE (best mode): {np.mean(mse_errors):.4f} ± {np.std(mse_errors):.4f}")
print(f"Validation MAE (best mode): {np.mean(mae_errors):.4f} ± {np.std(mae_errors):.4f}")
print(f"Validation ADE (best mode): {np.mean(ade_errors):.4f} ± {np.std(ade_errors):.4f}")
print(f"Validation FDE (best mode): {np.mean(fde_errors):.4f} ± {np.std(fde_errors):.4f}")

# Additional analysis: evaluate using prediction confidence (pi weights)
if 'pi' in sample_preds:
    print("\nEvaluating using prediction confidence weights...")
    conf_mse_errors = []
    conf_mae_errors = []
    conf_ade_errors = []
    conf_fde_errors = []
    
    for batch in val_dataloader:
        for k in batch.keys():
            if torch.is_tensor(batch[k]):
                batch[k] = batch[k].to(device)

        with torch.no_grad():
            preds = student_full(batch)
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
            pi_weights = preds['pi']  # (batch, 9) - confidence scores
            
            gt_traj_all = batch['y']
            gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2)
            
            # Select mode with highest confidence
            best_conf_idx = pi_weights.argmax(dim=1)  # (batch,)
            batch_indices = torch.arange(pred_modes.shape[0])
            conf_best_traj = pred_modes[batch_indices, best_conf_idx]  # (batch, 60, 2)
            
            # Compute metrics
            conf_mse = torch.nn.functional.mse_loss(conf_best_traj, gt_traj_ego)
            conf_mae = torch.nn.functional.l1_loss(conf_best_traj, gt_traj_ego)
            
            # ADE/FDE for confidence-based selection
            displacement_errors = torch.norm(conf_best_traj - gt_traj_ego, dim=-1)
            conf_ade = displacement_errors.mean(dim=-1)
            conf_fde = displacement_errors[:, -1]
            
            conf_mse_errors.append(conf_mse.item())
            conf_mae_errors.append(conf_mae.item())
            conf_ade_errors.extend(conf_ade.cpu().numpy())
            conf_fde_errors.extend(conf_fde.cpu().numpy())
    
    print(f"Confidence-based MSE: {np.mean(conf_mse_errors):.4f} ± {np.std(conf_mse_errors):.4f}")
    print(f"Confidence-based MAE: {np.mean(conf_mae_errors):.4f} ± {np.std(conf_mae_errors):.4f}")
    print(f"Confidence-based ADE: {np.mean(conf_ade_errors):.4f} ± {np.std(conf_ade_errors):.4f}")
    print(f"Confidence-based FDE: {np.mean(conf_fde_errors):.4f} ± {np.std(conf_fde_errors):.4f}")

Using decoder type: mlp
RUN EMP-M
Loaded trained weights (non-strict mode)
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/train, total number of files: 199908
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/val, total number of files: 24988
Using 10.0% of the training and validation datasets.
Evaluating trajectory prediction model...
Inspection of predictions and ground truth:
Prediction type: <class 'dict'>
  y_hat: torch.Size([64, 9, 60, 2])
  pi: torch.Size([64, 9])
  y_hat_others: torch.Size([64, 48, 60, 2])
  y_hat_eps: torch.Size([64, 9, 2])
  x_agent: torch.Size([64, 64])
Ground truth 'y' shape: torch.Size([64, 49, 60, 2])
Validation MSE (best mode): 3.9397 ± 1.2164
Validation MAE (best mode): 1.1643 ± 0.1050
Validation ADE (best mode): 1.8989 ± 1.2829
Validation FDE (best mode): 4.3558 ± 3.5862

Evaluating using prediction confidence weights...
Confidence-based MSE: 118.1811 ± 14.9803
Confidence-based MAE: 6.3760 ± 0.3610
Confidence-ba

In [27]:
# Load the checkpoint without the logit layers
checkpoint_path = 'student_model_custom_contrastive.pth'
checkpoint = torch.load(checkpoint_path, map_location=device)

# # Create a fresh student model with full layers
# student_full = instantiate(student_dict)

# # Load the trained weights
# try:
#     student_full.load_state_dict(checkpoint['student_state_dict'], strict=False)
#     print("Loaded trained weights (non-strict mode)")
# except Exception as e:
#     print(f"Error loading weights: {e}")
#     model_dict = student_full.state_dict()
#     pretrained_dict = {k: v for k, v in checkpoint['student_state_dict'].items() if k in model_dict}
#     model_dict.update(pretrained_dict)
#     student_full.load_state_dict(model_dict)
#     print("Loaded compatible weights manually")

# student_full = student_full.to(device)
# student_full.eval()

# torch.save({
#     'model_state_dict': student_full.state_dict(),
#     'model_config': student_dict,
#     'training_info': {
#         'epochs_trained': checkpoint.get('epoch', 'unknown'),
#         'architecture': 'student_with_contrastive_learning'
#     }
# }, 'student_model_full_with_predictions.pth')

if teacher_dict["decoder"] == "mlp":
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empm.ckpt"
else:
    checkpoint = "/home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/emp-main/checkpoints/empd.ckpt"
assert os.path.exists(checkpoint), f"Checkpoint {checkpoint} does not exist"

model_path = teacher_dict["_target_"]
module = import_module(model_path[: model_path.rfind(".")])
Model: pl.LightningModule = getattr(module, model_path[model_path.rfind(".") + 1 :])
teacher = Model.load_from_checkpoint(checkpoint, **teacher_dict)
teacher.to(device)
teacher.eval()

datamodule.setup(stage="validate", train_fraction=0.1, val_fraction=0.1, test_fraction=0.1)
val_dataloader = datamodule.val_dataloader()

print("Evaluating trajectory prediction model...")

# First, let's inspect the prediction structure
sample_batch = next(iter(val_dataloader))
for k in sample_batch.keys():
    if torch.is_tensor(sample_batch[k]):
        sample_batch[k] = sample_batch[k].to(device)

with torch.no_grad():
    sample_preds = teacher(sample_batch)

print("Inspection of predictions and ground truth:")
print(f"Prediction type: {type(sample_preds)}")
if isinstance(sample_preds, dict):
    for key, value in sample_preds.items():
        if torch.is_tensor(value):
            print(f"  {key}: {value.shape}")

print(f"Ground truth 'y' shape: {sample_batch['y'].shape}")

# Multi-modal trajectory prediction evaluation
mse_errors = []
mae_errors = []
ade_errors = []
fde_errors = []

def evaluate_multimodal_prediction(pred_modes, gt_traj_ego):
    """
    Evaluate multi-modal prediction by selecting the best mode
    pred_modes: (batch, num_modes, time, 2)
    gt_traj_ego: (batch, time, 2)
    """
    batch_size, num_modes, time_steps, _ = pred_modes.shape
    
    # Compute ADE for each mode
    # Expand gt_traj_ego to match pred_modes shape
    gt_expanded = gt_traj_ego.unsqueeze(1).expand(-1, num_modes, -1, -1)  # (batch, num_modes, time, 2)
    
    # Compute displacement errors for all modes
    displacement_errors = torch.norm(pred_modes - gt_expanded, dim=-1)  # (batch, num_modes, time)
    
    # ADE for each mode: average over time
    ade_per_mode = displacement_errors.mean(dim=-1)  # (batch, num_modes)
    
    # FDE for each mode: final time step
    fde_per_mode = displacement_errors[:, :, -1]  # (batch, num_modes)
    
    # Select best mode (lowest ADE)
    best_mode_idx = ade_per_mode.argmin(dim=1)  # (batch,)
    
    # Extract best ADE and FDE
    batch_indices = torch.arange(batch_size)
    best_ade = ade_per_mode[batch_indices, best_mode_idx]  # (batch,)
    best_fde = fde_per_mode[batch_indices, best_mode_idx]  # (batch,)
    
    # Get best trajectory for MSE/MAE computation
    best_traj = pred_modes[batch_indices, best_mode_idx]  # (batch, time, 2)
    
    return best_traj, best_ade, best_fde

for batch in val_dataloader:
    for k in batch.keys():
        if torch.is_tensor(batch[k]):
            batch[k] = batch[k].to(device)

    with torch.no_grad():
        preds = teacher(batch)
        
        # Extract multi-modal predictions
        if isinstance(preds, dict) and 'y_hat' in preds:
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
        else:
            print("Could not find 'y_hat' in predictions")
            break
        
        # Extract ego agent ground truth (first agent, index 0)
        gt_traj_all = batch['y']  # (batch, 42, 60, 2)
        gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2) - ego agent only
        
        # Evaluate multi-modal prediction
        best_traj, batch_ade, batch_fde = evaluate_multimodal_prediction(pred_modes, gt_traj_ego)
        
        # Compute MSE and MAE using best trajectory
        mse = torch.nn.functional.mse_loss(best_traj, gt_traj_ego)
        mae = torch.nn.functional.l1_loss(best_traj, gt_traj_ego)
        
        mse_errors.append(mse.item())
        mae_errors.append(mae.item())
        ade_errors.extend(batch_ade.cpu().numpy())
        fde_errors.extend(batch_fde.cpu().numpy())

print(f"Validation MSE (best mode): {np.mean(mse_errors):.4f} ± {np.std(mse_errors):.4f}")
print(f"Validation MAE (best mode): {np.mean(mae_errors):.4f} ± {np.std(mae_errors):.4f}")
print(f"Validation ADE (best mode): {np.mean(ade_errors):.4f} ± {np.std(ade_errors):.4f}")
print(f"Validation FDE (best mode): {np.mean(fde_errors):.4f} ± {np.std(fde_errors):.4f}")

# Additional analysis: evaluate using prediction confidence (pi weights)
if 'pi' in sample_preds:
    print("\nEvaluating using prediction confidence weights...")
    conf_mse_errors = []
    conf_mae_errors = []
    conf_ade_errors = []
    conf_fde_errors = []
    
    for batch in val_dataloader:
        for k in batch.keys():
            if torch.is_tensor(batch[k]):
                batch[k] = batch[k].to(device)

        with torch.no_grad():
            preds = teacher(batch)
            pred_modes = preds['y_hat']  # (batch, 9, 60, 2)
            pi_weights = preds['pi']  # (batch, 9) - confidence scores
            
            gt_traj_all = batch['y']
            gt_traj_ego = gt_traj_all[:, 0, :, :]  # (batch, 60, 2)
            
            # Select mode with highest confidence
            best_conf_idx = pi_weights.argmax(dim=1)  # (batch,)
            batch_indices = torch.arange(pred_modes.shape[0])
            conf_best_traj = pred_modes[batch_indices, best_conf_idx]  # (batch, 60, 2)
            
            # Compute metrics
            conf_mse = torch.nn.functional.mse_loss(conf_best_traj, gt_traj_ego)
            conf_mae = torch.nn.functional.l1_loss(conf_best_traj, gt_traj_ego)
            
            # ADE/FDE for confidence-based selection
            displacement_errors = torch.norm(conf_best_traj - gt_traj_ego, dim=-1)
            conf_ade = displacement_errors.mean(dim=-1)
            conf_fde = displacement_errors[:, -1]
            
            conf_mse_errors.append(conf_mse.item())
            conf_mae_errors.append(conf_mae.item())
            conf_ade_errors.extend(conf_ade.cpu().numpy())
            conf_fde_errors.extend(conf_fde.cpu().numpy())
    
    print(f"Confidence-based MSE: {np.mean(conf_mse_errors):.4f} ± {np.std(conf_mse_errors):.4f}")
    print(f"Confidence-based MAE: {np.mean(conf_mae_errors):.4f} ± {np.std(conf_mae_errors):.4f}")
    print(f"Confidence-based ADE: {np.mean(conf_ade_errors):.4f} ± {np.std(conf_ade_errors):.4f}")
    print(f"Confidence-based FDE: {np.mean(conf_fde_errors):.4f} ± {np.std(conf_fde_errors):.4f}")

Using decoder type: mlp
RUN EMP-M
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/train, total number of files: 199908
data root: /home/alex/UNI EMERGENCY/Project/DeepLeaning-Lab/data/emp/val, total number of files: 24988
Using 10.0% of the training and validation datasets.
Evaluating trajectory prediction model...
Inspection of predictions and ground truth:
Prediction type: <class 'dict'>
  y_hat: torch.Size([64, 6, 60, 2])
  pi: torch.Size([64, 6])
  y_hat_others: torch.Size([64, 44, 60, 2])
  y_hat_eps: torch.Size([64, 6, 2])
  x_agent: torch.Size([64, 128])
Ground truth 'y' shape: torch.Size([64, 45, 60, 2])
Validation MSE (best mode): 0.8327 ± 0.6760
Validation MAE (best mode): 0.4323 ± 0.0488
Validation ADE (best mode): 0.7080 ± 0.7275
Validation FDE (best mode): 1.6843 ± 2.0672

Evaluating using prediction confidence weights...
Confidence-based MSE: 5.3941 ± 2.0615
Confidence-based MAE: 0.9974 ± 0.1312
Confidence-based ADE: 1.7182 ± 1.8246
Confidence-based F