In [None]:
from plot_trajectory import plot_paths
import training
import json
import torch
from torch.utils.data import DataLoader
from dataloader import load_val
import joblib
import os
import matplotlib.pyplot as plt
from model_autoregressive import Seq2SeqLSTM
from transformer_model import TrajectoryTransformer30to10
import numpy as np
from tqdm import tqdm
from paths import RESULTS_FILTERED_DIR, RESULTS_UNFILTERED_DIR
from model_selection import rank_models, plot_model_losses, haversine_np

In [None]:
# Ask the project's training helper which compute device to use, then report it.
device = training.determine_device()
print(f"Using device: {device}")

In [None]:
# Candidate validation metrics; runs in RESULTS_FILTERED_DIR are ranked on the
# first entry (validation MSE).
metrics = ["val_mse", "val_rmse", "val_mae"]
(
    best_model_name,
    best_score,
    best_model_data,
) = rank_models(RESULTS_FILTERED_DIR, metrics[0])

In [None]:
# Loss history of the run that ranked best on validation MSE in the cell above.
plot_model_losses(best_model_data)

In [None]:
# Load the validation data twice: once with filter_stationary=True and once
# without. Each variant is normalised with its own saved scaler.
batch_size = 512

scaler_filtered = joblib.load("scaler_filtered.save")
val_ds_filtered = load_val(filter_stationary=True, scaler=scaler_filtered)

scaler_unfiltered = joblib.load("scaler_unfiltered.save")
val_ds_unfiltered = load_val(filter_stationary=False, scaler=scaler_unfiltered)

# shuffle=False keeps the iteration order fixed, so sample indices collected
# later map back to the same dataset positions on every run.
val_loader_filtered = DataLoader(
    val_ds_filtered,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
)
val_loader_unfiltered = DataLoader(
    val_ds_unfiltered,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
)

In [None]:
# Build the model from the ranked best run's config and load its weights.
# The architecture kwargs come from `best_model_data` (the ranking result),
# but the checkpoint file name is fixed by hand. If the two refer to different
# runs, load_state_dict may fail on mismatched keys/shapes, or silently load
# weights that do not belong to the ranked model.
best_model_path = "deeper_autoreg_lstm_2_best.pt"
full_best_model_path = os.path.join(RESULTS_FILTERED_DIR, best_model_path)

# Cheap sanity check: warn when the hard-coded checkpoint name does not contain
# the ranked best model's name.
# NOTE(review): consider deriving the file name from best_model_name instead;
# the exact naming convention of rank_models' output is not visible here.
best_run_stem = os.path.splitext(os.path.basename(str(best_model_name)))[0]
if best_run_stem not in best_model_path:
    print(
        f"WARNING: checkpoint '{best_model_path}' does not appear to belong to the "
        f"ranked best model '{best_model_name}'; config and weights may not match."
    )

model = Seq2SeqLSTM(**best_model_data["config"]["model_kwargs"]).to(device)
model.load_state_dict(torch.load(full_best_model_path, map_location=device))
model.eval()

In [None]:
# Alternative (disabled): evaluate the transformer checkpoint instead of the LSTM.
# NOTE(review): the original comment instantiated Seq2SeqLSTM for this transformer
# checkpoint. The imported TrajectoryTransformer30to10 is presumably the matching
# class; confirm. Also, best_model_data holds the config of the run ranked best
# above, which may not be this transformer's config, so its model_kwargs may not
# fit this checkpoint.
# best_model_path = "deeper_transformer_best.pt" 
# full_best_model_path = os.path.join(RESULTS_FILTERED_DIR, best_model_path)

# model = TrajectoryTransformer30to10(**best_model_data["config"]["model_kwargs"]).to(device)
# model.load_state_dict(torch.load(full_best_model_path, map_location=device))
# model.eval()

In [None]:
# Per-sample MSE on the filtered validation set.
# Each sample's value averages the squared error over dims 1 and 2 of the
# output tensor. The cell also keeps every sample's (scaled) input, target and
# prediction as NumPy arrays for the plotting / Haversine cells below.
# NOTE(review): an earlier note said the resulting loss is "20 times smaller"
# than the training loss. That presumably means training sums over the 20
# output elements per sample rather than averaging; confirm against training.py.
all_mse = []
all_samples = []

# Evaluation only: disable autograd for the whole loop, not just the forward pass.
with torch.no_grad():
    for x, y in tqdm(val_loader_filtered, desc="Computing per-sample MSE"):
        x, y = x.to(device), y.to(device)
        y_pred = model(x)

        mse_per_sample = torch.mean((y_pred - y) ** 2, dim=[1, 2])
        all_mse.append(mse_per_sample.cpu().numpy())

        # Copy each batch to host once, instead of three small transfers per sample.
        x_np = x.cpu().numpy()
        y_np = y.cpu().numpy()
        y_pred_np = y_pred.cpu().numpy()
        all_samples.extend(
            {"x": x_i, "y": y_i, "y_pred": y_pred_i}
            for x_i, y_i, y_pred_i in zip(x_np, y_np, y_pred_np)
        )

if not all_mse:
    raise ValueError("val_loader_filtered yielded no batches; nothing to evaluate")

all_mse = np.concatenate(all_mse)
# Indices of samples ordered from lowest to highest MSE.
sorted_indices = np.argsort(all_mse)

n = len(all_mse)

In [None]:
# Distribution of per-sample validation MSE at selected percentiles.
percentiles = [10, 30, 50, 80, 90, 95, 99, 99.99, 100]
print("\nMSE percentile summary:\n")
thresholds = np.percentile(all_mse, percentiles)
for p, threshold in zip(percentiles, thresholds):
    print(f"{p:>3}% of samples have MSE ≤ {threshold:.6f}")

In [None]:
# Three consecutive samples from anchor positions of the ascending MSE ranking:
# the start, the quartiles, the median, and the three highest-MSE samples.
groups = {
    label: sorted_indices[start:start + 3]
    for label, start in (
        ("Best", 0),
        ("Q1", n // 4),
        ("Median", n // 2),
        ("Q3", 3 * n // 4),
    )
}
groups["Worst"] = sorted_indices[-3:]

def plot_sample(sample, title, scaler=None):
    """Plot one validation sample via ``plot_paths``.

    Args:
        sample: dict with "x", "y" and "y_pred" arrays, as collected in the
            per-sample MSE cell (values are in the scaler's normalised space).
        title: figure title forwarded to ``plot_paths``.
        scaler: scaler forwarded to ``plot_paths``. Defaults to
            ``scaler_filtered``, matching the filtered validation set. Pass
            ``scaler_unfiltered`` when plotting samples from the unfiltered set.
    """
    if scaler is None:
        scaler = scaler_filtered
    plot_paths(sample["x"], sample["y"], sample["y_pred"], title, scaler=scaler)

# Render the selected samples of every group, each titled with its MSE.
for label, member_indices in groups.items():
    print(f"\nPlotting 3 samples from {label} group:")
    for sample_idx in member_indices:
        chosen = all_samples[sample_idx]
        caption = f"{label} Sample (MSE={all_mse[sample_idx]:.6f})"
        plot_sample(chosen, caption)

In [None]:
# Haversine error between true and predicted future positions for the samples
# picked in each MSE group. Targets and predictions are first mapped back
# through scaler_filtered.inverse_transform.
# NOTE(review): this assumes the scaler was fit on the same feature columns as
# `y` / `y_pred`, so inverse_transform accepts them directly; confirm.
print("\n=== Haversine Distance Evaluation by Groups ===")

# group name -> mean of its samples' mean Haversine error (km), for later use.
group_mean_hav_km = {}

for group_name, indices in groups.items():
    group_means = []

    print(f"\n### {group_name} group ###")

    for idx in indices:
        sample = all_samples[idx]

        # inverse scale
        y_true_scaled = sample["y"]
        y_pred_scaled = sample["y_pred"]
        y_true_unscaled = scaler_filtered.inverse_transform(y_true_scaled)
        y_pred_unscaled = scaler_filtered.inverse_transform(y_pred_scaled)

        # haversine_np returns per-step distances and their mean (reported in km).
        dists_km, mean_hav_km = haversine_np(y_true_unscaled, y_pred_unscaled)

        # Collect this sample's mean; aggregated per group after the loop.
        group_means.append(mean_hav_km)

        # --- Pretty step-wise print for groups ---
        print(f"\nSample {idx} (MSE={all_mse[idx]:.6f}) – Haversine per step:")
        for step, d in enumerate(dists_km, start=1):
            print(f"  Step {step:02d} → {d:10.6f} km")

        print(f"  → Mean Haversine for this sample: {mean_hav_km:.6f} km")

    # Report the group-level result. Previously group_means was collected but never used.
    if group_means:
        group_mean_hav_km[group_name] = float(np.mean(group_means))
        print(
            f"\n  ⇒ Mean Haversine for {group_name} group "
            f"({len(group_means)} samples): {group_mean_hav_km[group_name]:.6f} km"
        )