In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataloader import load_train, load_val
import numpy as np
import matplotlib.pyplot as plt
from plot_trajectory import plot_paths
from transformer_model import TrajectoryTransformer30to10, PositionalEncoding
import training
#from metrics import rmse, mse, mae
import math
import os
import json


In [None]:
# Pick the compute device via the project's training helper; the plotting cells
# below move the cached validation batches onto this device before inference.
device = training.determine_device()

In [None]:
#Test different Models
# Hyper-parameter search grid. Each entry has a run "name", an "epochs" budget,
# "model_kwargs" (presumably forwarded to TrajectoryTransformer30to10 by
# training.train_all — TODO confirm) and optional "optimizer_args" overrides.
# NOTE(review): this list is reassigned by the smoke-test cell further down
# before training runs, so these configs are currently never trained.
configs = [
    {   
        "name": "mini_transformer",
        "epochs": 10,
        "model_kwargs": {
            "d_model": 128,
            "nhead": 4,
            "num_layers": 3,
            # NOTE(review): 10 is far below d_model=128 (the other entries use
            # 4 * d_model) — possibly a typo; confirm.
            "dim_feedforward": 10,
            "dropout": 0.1,
        },
    },
    {
        "name": "small_transformer",
        "epochs": 40,
        "model_kwargs": {
            "d_model": 128,
            "nhead": 4,
            "num_layers": 3,
            "dim_feedforward": 512,
            "dropout": 0.1,
        },
    },
    {
        "name": "medium_transformer",
        "epochs": 40,
        "model_kwargs": {
            "d_model": 256,
            "nhead": 8,
            "num_layers": 3,
            "dim_feedforward": 1024,
            "dropout": 0.1,
        },
        "optimizer_args": {
            "weight_decay": 1e-4,
        },
    },
    {
        # Same architecture as "deeper_transformer" below, but trained longer
        # with weight decay instead of a custom learning rate.
        "name": "deeper_transformer_2",
        "epochs": 90,
        "model_kwargs": {
            "d_model": 256,
            "nhead": 8,
            "num_layers": 4,
            "dim_feedforward": 1024,
            "dropout": 0.1,
        },
        "optimizer_args": {
            "weight_decay": 1e-4,
        },
    },
    {
        "name": "deeper_transformer",
        "epochs": 50,
        "model_kwargs": {
            "d_model": 256,
            "nhead": 8,
            "num_layers": 4,
            "dim_feedforward": 1024,
            "dropout": 0.1,
        },
        "optimizer_args": {
            "lr": 5e-4,
        },
    },
]




In [None]:
# Plotting code
# Ensure both output directories exist up front; exist_ok makes re-runs harmless.
for _out_dir in ("plots", "checkpoints"):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

def plot_val_curve(history, name, out_dir="plots"):
    """Plot validation RMSE per epoch for one run and save it as a PNG.

    Args:
        history: Training history mapping containing sequences of equal length
            under ``"epoch"`` (x-axis) and ``"val_rmse"`` (y-axis).
        name: Run name, used in the plot title and the output filename.
        out_dir: Directory for the image, created if missing. Defaults to
            ``"plots"``, the location used before this parameter existed.

    Returns:
        Path of the saved image, ``<out_dir>/<name>_val_rmse.png``.
    """
    os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        ax.plot(history["epoch"], history["val_rmse"], marker="o")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Val RMSE")
        ax.set_title(f"Val RMSE – {name}")
        fig.tight_layout()

        save_path = os.path.join(out_dir, f"{name}_val_rmse.png")
        fig.savefig(save_path, dpi=150)
    finally:
        # Close exactly this figure, even if plotting/saving failed, so figures
        # do not accumulate when called in a loop over many runs.
        plt.close(fig)
    return save_path


def plot_paths_saved(x_np, y_np, y_pred_np, idx, name, out_dir="plots"):
    """Build and save a single trajectory plot: history, true future, predicted future.

    For every array, column 1 is drawn on the x-axis and column 0 on the y-axis.
    NOTE(review): the author's comment described this as lat = col 0,
    lon = col 1 — confirm against the dataloader's feature order.

    Args:
        x_np: Observed history, 2-D array with at least two columns.
        y_np: Ground-truth future, 2-D array with at least two columns.
        y_pred_np: Predicted future, same column layout as ``y_np``.
        idx: Sample/segment index, used in the title and the filename.
        name: Run name, used in the title and the filename.
        out_dir: Directory for the image, created if missing. Defaults to
            ``"plots"``, the location used before this parameter existed.

    Returns:
        Path of the saved image, ``<out_dir>/<name>_traj_seg<idx>.png``.
    """
    os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.plot(x_np[:, 1], x_np[:, 0], 'o-', label='history')
        ax.plot(y_np[:, 1], y_np[:, 0], 'x--', label='true future')
        ax.plot(y_pred_np[:, 1], y_pred_np[:, 0], 's--', label='pred future')

        ax.set_title(f"{name} – val sample {idx}")
        ax.legend()
        fig.tight_layout()

        save_path = os.path.join(out_dir, f"{name}_traj_seg{idx}.png")
        fig.savefig(save_path, dpi=150)
    finally:
        # Close exactly this figure, even on error, so figures do not pile up.
        plt.close(fig)
    return save_path




In [None]:
# Load the datasets. The scaler returned by load_train() is handed to load_val(),
# presumably so validation data uses the same scaling (TODO confirm in dataloader);
# the plotting cells also use it to inverse-transform model outputs.
train_ds, scaler = load_train()
val_ds = load_val(scaler)

# shuffle=False keeps batch order fixed, so the cached plotting subset below
# contains the same validation samples on every run.
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=64, shuffle=False
)

def make_small_val_subset(val_loader, num_batches=2):
    """Cache the first ``num_batches`` ``(X, Y)`` batches yielded by ``val_loader``.

    Each cached tensor is an independent ``.clone()`` of the loader's batch.
    Fewer entries are returned if the loader is exhausted first, and an empty
    list is returned when ``num_batches`` is zero or negative.
    """
    # zip() stops as soon as range() runs out, which caps the number of batches.
    return [
        (features.clone(), targets.clone())
        for _, (features, targets) in zip(range(num_batches), val_loader)
    ]

# Cache two validation batches for the per-model trajectory plots (one plot per batch).
# NOTE(review): the plotting loops allow up to 3 segments, but only 2 are cached here.
small_val_batches = make_small_val_subset(val_loader, num_batches=2)


In [None]:
# Smoke-test override: this replaces the full search grid defined earlier with a
# single tiny model trained for 1 epoch (presumably to check the pipeline end to
# end quickly). Remove or skip this cell to train the full grid.
configs = [
    {   
        "name": "micro_transformer",
        "epochs": 1,
        "model_kwargs": {
            "d_model": 8,
            "nhead": 4,
            "num_layers": 1,
            "dim_feedforward": 2,
            "dropout": 0.1,
        },
    }
]

In [None]:
# Train every entry of `configs`. The following cells read each result's
# "config", "model" and "history" entries.
results = training.train_all(TrajectoryTransformer30to10, configs, train=train_ds, val=val_ds)

In [None]:
# Select the run with the lowest final validation RMSE and write a JSON summary.
# training.train_all's return shape is not visible here: later cells iterate and
# index `results` like a list, while this cell used `.values()`; accept both.
_result_list = list(results.values()) if isinstance(results, dict) else list(results)


def _to_jsonable(obj):
    """json.dump fallback: unwrap numpy/torch arrays and scalars, stringify anything else."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    if hasattr(obj, "item"):
        return obj.item()
    return str(obj)


best_result = None
best_rmse = None
summary = []
for result in _result_list:
    run_name = result["config"]["name"]
    val_rmse_hist = result["history"]["val_rmse"]
    # A run without any validation point cannot be ranked (and [-1] would raise).
    final_val_rmse = float(val_rmse_hist[-1]) if len(val_rmse_hist) else None

    summary.append({
        "name": run_name,
        "config": result["config"],
        "history": result["history"],
        "final_val_rmse": final_val_rmse,
        "checkpoint_path": result.get("checkpoint_path"),
    })

    # Track the best score in a local: result dicts carry no "final_val_rmse"
    # key, so the old best_result["final_val_rmse"] lookup raised KeyError.
    if final_val_rmse is not None and (best_rmse is None or final_val_rmse < best_rmse):
        best_rmse = final_val_rmse
        # Shallow copy with "name"/"final_val_rmse" added for the report cell;
        # keys already present in `result` take precedence.
        best_result = {"name": run_name, "final_val_rmse": final_val_rmse, **result}

# The raw results contain model objects, which json cannot serialize: dumping
# them raised TypeError after a partial write. Write the plain summary instead.
with open("search_results.json", "w") as f:
    json.dump(summary, f, indent=2, default=_to_jsonable)

In [None]:
# Per run: save the validation-RMSE curve, then plot the first sample of up to
# MAX_PLOTTED_SEGMENTS cached validation batches in original (inverse-scaled) units.
MAX_PLOTTED_SEGMENTS = 3

# `results` may be a dict keyed by run or a list of result dicts.
_runs = results.values() if isinstance(results, dict) else results

for result in _runs:
    name = result["config"]["name"]
    model = result["model"]
    history = result["history"]

    # Save the validation curve.
    plot_val_curve(history, name)

    model.eval()
    with torch.no_grad():
        for seg_idx, (X_val_small, Y_val_small) in enumerate(small_val_batches):
            if seg_idx >= MAX_PLOTTED_SEGMENTS:
                break

            Xs = X_val_small.to(device)
            Ys = Y_val_small.to(device)
            Y_pred = model(Xs)

            # .copy() is required: on CPU, .to(device) and .cpu() return the same
            # tensor and .numpy() shares its memory, so the in-place inverse
            # scaling below would otherwise overwrite `small_val_batches` and
            # feed already-unscaled inputs to every later model.
            x_np = Xs[0].cpu().numpy().copy()
            y_np = Ys[0].cpu().numpy().copy()
            y_pred_np = Y_pred[0].cpu().numpy().copy()

            # Inverse scaling of the first two feature columns.
            # NOTE(review): assumes `scaler` was fit on exactly these two
            # columns — confirm against dataloader.load_train.
            x_np[:, :2] = scaler.inverse_transform(x_np[:, :2])
            y_np[:, :2] = scaler.inverse_transform(y_np[:, :2])
            y_pred_np[:, :2] = scaler.inverse_transform(y_pred_np[:, :2])

            # MSE/RMSE after inverse scaling, restricted to the inverse-scaled
            # columns so scaled and unscaled features are never mixed.
            err = y_pred_np[:, :2] - y_np[:, :2]
            mse = np.mean(err ** 2)
            rmse = np.sqrt(mse)
            print(f"[{name}] Segment {seg_idx}: MSE={mse:.4f}, RMSE={rmse:.4f}")

            # Save the trajectory plot.
            plot_paths_saved(x_np, y_np, y_pred_np, seg_idx, name)




In [None]:

# Second pass: the same trajectories in normalized (model-input) space, without
# inverse scaling. Both passes used to write to the same
# "plots/<name>_traj_seg<idx>.png" paths, so this pass overwrote the
# inverse-scaled plots. It now saves under "<name>_scaled". It also no longer
# re-plots the validation curve, which the per-run pass already saves.
for result in (results.values() if isinstance(results, dict) else results):
    name = result["config"]["name"]
    model = result["model"]

    model.eval()
    with torch.no_grad():
        for seg_idx, (X_val_small, Y_val_small) in enumerate(small_val_batches):
            if seg_idx >= 3:  # at most 3 segments
                break

            Xs = X_val_small.to(device)
            Ys = Y_val_small.to(device)
            preds_small = model(Xs)

            # Plot the first sample of the batch; nothing is modified in place here.
            x_np = Xs[0].cpu().numpy()
            y_np = Ys[0].cpu().numpy()
            y_pred_np = preds_small[0].cpu().numpy()

            plot_paths_saved(x_np, y_np, y_pred_np, seg_idx, f"{name}_scaled")


# Report the best run chosen by the selection cell (`best_result`).
print("\n======================")
if best_result is None:
    print("No run produced a validation RMSE; nothing to report.")
else:
    best_config = best_result["config"]
    # The result dicts may lack "name", "final_val_rmse", "train_kwargs" and
    # "checkpoint_path" (the configs above define no "train_kwargs"), so derive
    # or default them instead of raising KeyError.
    best_name = best_result.get("name", best_config["name"])
    best_final_rmse = best_result.get(
        "final_val_rmse", best_result["history"]["val_rmse"][-1]
    )
    best_train_args = best_config.get(
        "train_kwargs",
        {key: best_config[key] for key in ("epochs", "optimizer_args") if key in best_config},
    )

    print("Best config:", best_name)
    print("Best final Val RMSE:", best_final_rmse)
    print("Best model args:", best_config["model_kwargs"])
    print("Best train args:", best_train_args)
    print("Best checkpoint:", best_result.get("checkpoint_path", "n/a"))



# --- Plot training & validation MSE for one chosen run ---

# `results` may be a dict keyed by run or a list of result dicts.
runs_for_loss_plot = list(results.values()) if isinstance(results, dict) else list(results)

chosen_idx = 1  # which run to plot: 0-based index into the trained results
if not 0 <= chosen_idx < len(runs_for_loss_plot):
    # e.g. the single-config smoke test only produces index 0.
    print(f"chosen_idx={chosen_idx} is out of range for "
          f"{len(runs_for_loss_plot)} run(s); plotting run 0 instead.")
    chosen_idx = 0

chosen = runs_for_loss_plot[chosen_idx]
history = chosen["history"]
name = chosen["config"]["name"]  # result dicts have no "A_model" key

# NOTE(review): assumes history records "train_mse" and "val_mse" once per epoch
# with equal lengths; other cells only read "epoch" and "val_rmse" — confirm.
epochs = range(1, len(history["train_mse"]) + 1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs, history["train_mse"], 'r-', label='Train MSE')
ax.plot(epochs, history["val_mse"], 'b-', label='Val MSE')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE')
ax.set_title(f'Training & Validation MSE — Model {name}')
ax.legend()
ax.grid(True)
plt.show()
