In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from scaled_dataloader import load_train, load_val
from sklearn.preprocessing import StandardScaler
import numpy as np
import matplotlib.pyplot as plt
from plot_trajectory import plot_paths
#from metrics import rmse, mse, mae
import joblib
import math
import os
import json
import copy


In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
#Transformer Model


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch-first sequence.

    The sine/cosine table is built once in ``__init__`` and registered as a
    buffer named ``pe`` with shape ``[1, max_len, d_model]``; it is not a
    trainable parameter.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-torch.log(torch.tensor(10000.0)) / d_model))

        # Even feature columns carry the sine, odd columns the cosine.
        pe[:, 0::2] = torch.sin(pos * div)
        # For an odd d_model there is one cosine column fewer than sine
        # columns, so the frequency vector is trimmed to match.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])

        pe = pe.unsqueeze(0)     # -> [1, max_len, d_model]
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` with the positional table added.

        Args:
            x: Tensor of shape ``[batch, seq_len, d_model]``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` is larger than the precomputed table.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            # Without this check the slice below would be shorter than x and
            # the addition would fail (or silently broadcast when max_len == 1).
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional encoding table (max_len={max_len})."
            )
        return x + self.pe[:, :seq_len, :]


class TrajectoryTransformer30to10(nn.Module):
    """Transformer encoder that predicts a fixed number of future steps.

    The input sequence is projected to ``d_model``, combined with positional
    encodings and passed through a stack of ``nn.TransformerEncoderLayer``s
    (batch-first). The encoder output at the last time step is mapped by a
    single linear layer to all future steps at once.

    Args:
        input_dim: Number of features per input time step.
        output_dim: Number of features per predicted time step.
        d_model: Width of the transformer.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden size of each encoder layer's feed-forward part.
        dropout: Dropout probability used inside the encoder layers.
        future_steps: Number of time steps to predict. Defaults to 10, the
            value that was previously hard-coded.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 3,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        future_steps: int = 10,
    ):
        super().__init__()

        self.output_dim = output_dim
        self.future_steps = future_steps

        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # inputs are [batch, seq, feature]
        )

        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # One linear head emits every future step in a single flat vector.
        self.out = nn.Linear(d_model, output_dim * self.future_steps)

    def forward(self, src):
        """Predict the future steps for a batch of input sequences.

        Args:
            src: Tensor of shape ``[batch, seq_len, input_dim]``
                (the original docstring documents ``seq_len`` as 30).

        Returns:
            Tensor of shape ``[batch, future_steps, output_dim]``.
        """
        x = self.input_proj(src)       # [B, seq_len, d_model]
        x = self.pos_enc(x)            # [B, seq_len, d_model]
        x = self.encoder(x)            # [B, seq_len, d_model]

        last_state = x[:, -1, :]       # [B, d_model]

        out = self.out(last_state)     # [B, future_steps * output_dim]

        # Use the configured horizon rather than a literal so the head size
        # and the reshape can never disagree.
        return out.reshape(-1, self.future_steps, self.output_dim)


In [None]:
# load_train() hands back the training dataset together with a scaler object;
# that same scaler is passed on to load_val() for the validation split.
train_ds, scaler = load_train()
val_ds = load_val(scaler)

# Validation batches keep their order; training batches are reshuffled.
val_loader = torch.utils.data.DataLoader(dataset=val_ds, batch_size=64, shuffle=False)
train_loader = torch.utils.data.DataLoader(dataset=train_ds, batch_size=64, shuffle=True)

In [None]:
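# Configurations

# Other model variants that were tried. They are parked inside the string
# literal below, so none of them is executed; copy an entry into `configs`
# to train it.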
'''{   
        "name": "Mini_Model",
        "model_kwargs": {
            "input_dim": 2,
            "output_dim": 2,
            "d_model": 128,
            "nhead": 4,
            "num_layers": 3,
            "dim_feedforward": 10,
            "dropout": 0.1,
        },
        "train_kwargs": {
            "num_epochs": 10,
            "learning_rate": 1e-3,
            "weight_decay": 1e-4,
        },
    },
    {
        "name": "A_small",
        "model_kwargs": {
            "input_dim": 2,
            "output_dim": 2,
            "d_model": 128,
            "nhead": 4,
            "num_layers": 3,
            "dim_feedforward": 512,
            "dropout": 0.1,
        },
        "train_kwargs": {
            "num_epochs": 40,
            "learning_rate": 1e-3,
            "weight_decay": 1e-4,
        },
    },
       {
        "name": "Mini_Test",
        "model_kwargs": {
            "input_dim": 2,
            "output_dim": 2,
            "d_model": 128,
            "nhead": 8,
            "num_layers": 2,
            "dim_feedforward": 100,
            "dropout": 0.1,
        },
        "train_kwargs": {
            "num_epochs": 5,
            "learning_rate": 5e-4,
            "weight_decay": 1e-4,
            "patience": 5,
        },
    },
    {
        "name": "C_deeper_2",
        "model_kwargs": {
            "input_dim": 2,
            "output_dim": 2,
            "d_model": 256,
            "nhead": 8,
            "num_layers": 4,
            "dim_feedforward": 1024,
            "dropout": 0.1,
        },
        "train_kwargs": {
            "num_epochs": 90,
            "learning_rate": 5e-4,
            "weight_decay": 1e-4,
            "patience": 5,
        },
    },
    {
        "name": "C_deeper",
        "model_kwargs": {
            "input_dim": 2,
            "output_dim": 2,
            "d_model": 256,
            "nhead": 8,
            "num_layers": 4,
            "dim_feedforward": 1024,
            "dropout": 0.1,
        },
        "train_kwargs": {
            "num_epochs": 50,
            "learning_rate": 5e-4,
            "weight_decay": 1e-4,
            "patience": 5,
        },
    },
'''

# Only the B_medium variant is active at the moment; every dict in this list
# is trained in turn by the cells further down.
configs = [
    dict(
        name="B_medium",
        model_kwargs=dict(
            input_dim=2,
            output_dim=2,
            d_model=256,
            nhead=8,
            num_layers=3,
            dim_feedforward=1024,
            dropout=0.1,
        ),
        train_kwargs=dict(
            num_epochs=40,
            learning_rate=5e-4,
            weight_decay=1e-4,
        ),
    ),
]


In [None]:
#Training loop I used for the models



def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return sample-weighted ``(mse, mae)``.

    With an ``optimizer`` the model is put into train mode and updated after
    every batch; without one the model is put into eval mode and gradient
    tracking is switched off.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    mse_sum = 0.0
    mae_sum = 0.0
    n_samples = 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            # Datasets may yield extra items; only the first two are used.
            X, Y = batch[:2]
            X, Y = X.to(device), Y.to(device)

            if training:
                optimizer.zero_grad()
            preds = model(X)
            loss = criterion(preds, Y)
            if training:
                loss.backward()
                optimizer.step()

            # Weight the per-batch means by batch size so a smaller final
            # batch does not skew the epoch average.
            bs = X.size(0)
            mse_sum += loss.item() * bs
            # detach(): the MAE is bookkeeping only and must not extend the graph.
            mae_sum += torch.mean(torch.abs(preds.detach() - Y)).item() * bs
            n_samples += bs

    if n_samples == 0:
        raise ValueError("Data loader yielded no samples; cannot compute epoch metrics.")

    return mse_sum / n_samples, mae_sum / n_samples


def train_one_config(config, train_loader, val_loader, device):
    """Train one model configuration and return its best validation result.

    Args:
        config: Dict with the keys ``"name"``, ``"model_kwargs"`` (passed to
            ``TrajectoryTransformer30to10``) and ``"train_kwargs"``
            (``num_epochs``, ``learning_rate``, ``weight_decay`` and an
            optional ``patience`` for early stopping).
        train_loader: Iterable of training batches whose first two items are
            the inputs and the targets.
        val_loader: Iterable of validation batches with the same layout.
        device: Device the model and the batches are moved to.

    Returns:
        Dict with ``name``, ``config``, ``history`` (per-epoch train/val MSE,
        RMSE and MAE), ``final_val_rmse`` and ``final_val_mse`` (both taken
        from the epoch with the lowest validation RMSE) and ``model`` (with
        the weights of that epoch loaded). ``final_val_mse`` is ``None`` only
        if no epoch produced a usable validation RMSE.

    Raises:
        ValueError: If either loader yields no samples.
    """
    name = config["name"]
    model_kwargs = config["model_kwargs"]
    train_kwargs = config["train_kwargs"]

    num_epochs    = train_kwargs["num_epochs"]
    learning_rate = train_kwargs["learning_rate"]
    weight_decay  = train_kwargs["weight_decay"]
    patience      = train_kwargs.get("patience", None)  # e.g. 5, or None to disable early stopping

    print(f"\n=== Training config: {name} ===")
    print("Model args:", model_kwargs)
    print("Train args:", train_kwargs)

    model = TrajectoryTransformer30to10(**model_kwargs).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    history = {
        "epoch": [],
        "train_mse": [],
        "train_rmse": [],
        "train_mae": [],
        "val_mse": [],
        "val_rmse": [],
        "val_mae": [],
    }

    best_val_rmse = float("inf")
    best_val_mse = None
    best_state = None
    best_epoch = None
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):
        # ---- TRAIN ----
        train_mse, train_mae = _run_epoch(model, train_loader, criterion, device, optimizer)
        train_rmse = math.sqrt(train_mse)

        # ---- VALIDATION ----
        val_mse, val_mae = _run_epoch(model, val_loader, criterion, device)
        val_rmse = math.sqrt(val_mse)

        history["epoch"].append(epoch)
        history["train_mse"].append(train_mse)
        history["train_rmse"].append(train_rmse)
        history["train_mae"].append(train_mae)
        history["val_mse"].append(val_mse)
        history["val_rmse"].append(val_rmse)
        history["val_mae"].append(val_mae)

        print(
            f"Epoch {epoch:03d} | "
            f"Train RMSE={train_rmse:.4f}, Val RMSE={val_rmse:.4f}"
        )

        # ---- Early stopping / remember best model ----
        if val_rmse < best_val_rmse:
            best_val_rmse = val_rmse
            best_val_mse = val_mse
            best_epoch = epoch
            # Deep copy: state_dict() returns references to the live tensors,
            # which later optimizer steps would otherwise overwrite.
            best_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if patience is not None and epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch} (no improvement for {patience} epochs).")
            break

    # Restore the weights of the best epoch
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Loaded best model from epoch {best_epoch} with Val RMSE={best_val_rmse:.4f}")

    return {
        "name": name,
        "config": config,
        "history": history,
        "final_val_rmse": best_val_rmse,
        "final_val_mse": best_val_mse,
        "model": model,
    }


In [None]:
# Train every entry in `configs` and keep the run with the lowest validation
# RMSE. Adapt the reporting below to whatever you want to inspect; no
# checkpoint is written to disk here.
results = []
best_result = None

for cfg in configs:
    result = train_one_config(cfg, train_loader, val_loader, device)
    results.append(result)

    if best_result is None or result["final_val_rmse"] < best_result["final_val_rmse"]:
        best_result = result

print("\n======================")
if best_result is None:
    # `configs` was empty, so there is nothing to summarise.
    print("No configs were trained.")
else:
    print("Best config:", best_result["name"])
    print("Best final Val RMSE:", best_result["final_val_rmse"])
    print("Best model args:", best_result["config"]["model_kwargs"])
    print("Best train args:", best_result["config"]["train_kwargs"])