# Loss analysis

In this notebook we compute the loss value for each data point in the dataset to gain insight into how the model is behaving and where it fails.

In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tabulate import tabulate
from matplotlib import gridspec
import imageio
from IPython.display import Video
from tqdm import tqdm

from src.encoder_temporal import VICRegJEPAEncoder
from src.data.dataset import PointMazeTransitions

# Ask PyTorch for its 'high' float32 matmul precision setting (see the
# torch.set_float32_matmul_precision documentation for what each level permits).
torch.set_float32_matmul_precision('high')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cuda


Load the dataset

In [2]:
# Transition dataset read from the stored trajectories archive.
# normalize=True is forwarded to the dataset class; what exactly gets normalised
# is defined there (not visible from this notebook).
dataset = PointMazeTransitions("data/train_trajectories_10_100_4_64.npz", normalize=True)

# One transition per batch, in dataset order, so that every loss value computed
# later can be attributed to a single data point.
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False)

[Dataset] Frames resized to 64×64
[Dataset] Loaded 10 episodes, 1000 transitions.
[Dataset] Frame shape: (64, 64, 3)


Load the model

In [3]:
# Restore the encoder from the most recent checkpoint. strict=False is passed
# through to the loader (presumably to tolerate missing/unexpected keys in the
# checkpoint — confirm against the training code).
ckpt_path = "checkpoints/visual_encoder/last.ckpt"
model = VICRegJEPAEncoder.load_from_checkpoint(ckpt_path, strict=False)

# Move the weights to the selected device and switch to evaluation mode.
model.to(device)
model.eval();  # the trailing ';' keeps the module repr out of the cell output

Compute the loss for each data point

In [7]:
# Per-batch loss sweep: run the model's shared_step on every batch produced by
# `dataloader` and keep (a) the current state and (b) every returned loss term
# whose name contains "curr".
# NOTE(review): the captured output of this cell shows a warning from an
# `x.std(dim=0)` call and some loss terms are later reported as all-NaN; that
# looks like a side effect of feeding single-sample batches — confirm.
losses = []
states = []
with torch.inference_mode():
    for batch in tqdm(dataloader):
        (state_curr, frame_curr), action, (state_next, frame_next), _ = batch

        # Keep a NumPy copy of the state, taken before anything is moved to `device`.
        states.append(state_curr.numpy())

        # Rebuild the batch on the model's device; the fourth slot is replaced by None.
        curr_pair = (state_curr.to(device), frame_curr.to(device))
        next_pair = (state_next.to(device), frame_next.to(device))
        batch = (curr_pair, action.to(device), next_pair, None)

        # The second argument is presumably the batch index (unused here) — confirm.
        step_losses = model.shared_step(batch, -1)

        # Convert the selected loss tensors to plain Python floats.
        loss = {}
        for name, value in step_losses.items():
            if "curr" in name:
                loss[name] = value.item()
        losses.append(loss)

# Stack the per-batch state arrays along the first axis.
states = np.concatenate(states, axis=0)
# Number of distinct loss terms recorded per batch.
n_losses = len(losses[0])

  std = x.std(dim=0)
100%|██████████| 1000/1000 [00:03<00:00, 273.92it/s]


In [26]:
from sklearn.ensemble import RandomForestClassifier
import numpy as np

# For every recorded loss term, flag the samples whose loss is at or above a
# high percentile, fit a random forest that predicts that flag from the state
# vector, and store the state features ordered from most to least important
# (impurity-based importances) in `rankings[loss_name]`.

# Samples at or above this percentile of a loss term count as "high loss".
TOP_LOSS_PERCENTILE = 99

rankings = dict()

for loss_name in losses[0].keys():
    loss_values = np.array([loss[loss_name] for loss in losses], dtype=float)
    nan_mask = np.isnan(loss_values)
    if nan_mask.all():
        print(f"Skipping {loss_name}: all NaN")
        continue

    # fill NaNs with median so thresholding works
    if nan_mask.any():
        loss_values[nan_mask] = np.nanmedian(loss_values)

    threshold = np.percentile(loss_values, TOP_LOSS_PERCENTILE)
    labels = (loss_values >= threshold).astype(int)
    n_pos = labels.sum()

    # If every sample reaches the threshold (e.g. a constant loss term) there is
    # no negative class: the forest would report all-zero importances and the
    # resulting "ranking" would be arbitrary, so do not record one.
    if n_pos == labels.size:
        print(f"Skipping {loss_name}: all samples reach the threshold, no negative class")
        continue

    print(f"\n=== {loss_name} | threshold={threshold:.6g} | positives={n_pos} ===")

    # Fixed random_state keeps the fitted forest (and thus the ranking) reproducible.
    clf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
    clf.fit(states, labels)

    # feature importances, sorted in descending order
    importances = clf.feature_importances_
    indices = np.argsort(importances)[::-1]
    print("Feature importances:")
    for rank, i in enumerate(indices, start=1):
        print(f"{rank}. feature_{i} ({importances[i]:.2%})")

    rankings[loss_name] = indices


=== recon_loss_curr | threshold=0.0975807 | positives=10 ===
Feature importances:
1. feature_0 (32.08%)
2. feature_2 (26.49%)
3. feature_3 (25.17%)
4. feature_1 (16.25%)

=== proprio_recon_loss_curr | threshold=0.0938972 | positives=10 ===
Feature importances:
1. feature_0 (32.08%)
2. feature_2 (26.49%)
3. feature_3 (25.17%)
4. feature_1 (16.25%)

=== visual_recon_loss_curr | threshold=0.00425967 | positives=10 ===
Feature importances:
1. feature_2 (31.47%)
2. feature_3 (30.97%)
3. feature_0 (19.46%)
4. feature_1 (18.10%)
Skipping vcreg_cls_curr: all NaN

=== vcreg_patch_curr | threshold=4.27778 | positives=10 ===
Feature importances:
1. feature_0 (26.81%)
2. feature_2 (25.15%)
3. feature_1 (24.97%)
4. feature_3 (23.07%)
Skipping vcreg_state_curr: all NaN

=== vicreg_cross_curr | threshold=1.72793 | positives=10 ===
Feature importances:
1. feature_0 (37.77%)
2. feature_3 (22.31%)
3. feature_2 (21.54%)
4. feature_1 (18.38%)

=== state_from_patches_loss_curr | threshold=0.102001 | posit

In [37]:
# Summarise, across all ranked loss terms, how highly the two feature groups
# rank. Each entry of `rankings` is an array of feature indices; a feature's
# cost is the square of the (0-based) slot it occupies in that array, summed per
# group and over all loss terms.
# NOTE(review): assumes features 0-1 are position and 2-3 are velocity, as the
# variable names imply — confirm against the dataset's state layout.
pos_cost, vel_cost = 0, 0
for ranking in rankings.values():
    is_position = ranking < 2
    is_velocity = (ranking >= 2) & (ranking < 4)
    pos_cost += sum(np.where(is_position)[0] ** 2)
    vel_cost += sum(np.where(is_velocity)[0] ** 2)
print(f"\nTotal position cost: {pos_cost}")
print(f"Total velocity cost: {vel_cost}")


Total position cost: 48
Total velocity cost: 36
