### Setup

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler, RobustScaler
import matplotlib.pyplot as plt

from src.lstm import LSTM
from src.trainer import Trainer, EarlyStopping
from data.data_creation import get_trajectories, plot_trajectories, plot_boxplots

# Reproducibility: seed the RNGs used for weight initialisation and any shuffling,
# so "Restart Kernel -> Run All" gives the same models and losses.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds CUDA generators

# Use the GPU when available; the cells below place models and tensors on `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

### Load and prepare data

In [None]:
TRAJECTORIES_PATH = 'data/trajectories.npy'

if not os.path.exists(TRAJECTORIES_PATH):
    x, y, vx, vy, t = get_trajectories()
    # Store the five arrays in an explicit 1-D object array. Passing a plain list
    # to np.save makes NumPy build one rectangular array from it, which raises
    # ValueError on NumPy >= 1.24 if the arrays differ in shape (e.g. if `t` has a
    # different shape from the state arrays). Filling element-by-element avoids
    # that broadcasting; np.load + unpacking below works for both the old and new
    # cache layout.
    trajectories = np.empty(5, dtype=object)
    for i, arr in enumerate((x, y, vx, vy, t)):
        trajectories[i] = arr
    os.makedirs(os.path.dirname(TRAJECTORIES_PATH), exist_ok=True)
    np.save(TRAJECTORIES_PATH, trajectories, allow_pickle=True)
else:
    # allow_pickle is needed for object arrays; only load caches this notebook wrote.
    x, y, vx, vy, t = np.load(TRAJECTORIES_PATH, allow_pickle=True)

# Stack into a (num_samples, 12) array.
# NOTE(review): this assumes x, y, vx, vy are each (3, num_samples) so that the
# transpose puts timesteps on rows — TODO confirm against get_trajectories().
data = np.vstack((x, y, vx, vy)).T
assert data.ndim == 2, f"expected a 2-D (timesteps, features) array, got {data.shape}"
print(f"Data shape: {data.shape}")


In [None]:
train_test_split = 0.85
# Chronological split: the first 85% of rows train the model, the remainder is held out.
train_end = int(train_test_split * len(data))
train_data, test_data = data[:train_end], data[train_end:]
print(f"Train: {train_data.shape}, Test: {test_data.shape}")

# --- Normalize Data ---
# Statistics come from the training rows only; the test rows reuse them unchanged.
scaler = RobustScaler()  # or MinMaxScaler()
train_data_scaled = scaler.fit_transform(train_data)
test_data_scaled = scaler.transform(test_data)

# Persist the fitted scaler so the scaling can be undone in a later session.
os.makedirs("data", exist_ok=True)
with open("data/scaler.pkl", "wb") as scaler_file:
    pickle.dump(scaler, scaler_file)

In [None]:
def generate_xy(data, lag=50):
    """Build (input, target) tensor pairs where each target is `lag` rows ahead.

    Args:
        data: Array-like accepted by ``torch.tensor`` whose rows are consecutive
            timesteps (in this notebook a 2-D ``(timesteps, features)`` array).
        lag: Non-negative number of rows between an input row and its target row.

    Returns:
        Tuple ``(X, y)`` of float32 tensors with ``max(len(data) - lag, 0)`` rows each,
        where ``X[i] = data[i]`` and ``y[i] = data[i + lag]``. Both are empty when
        ``lag >= len(data)``.

    Raises:
        ValueError: If ``lag`` is negative.
    """
    if lag < 0:
        raise ValueError(f"lag must be non-negative, got {lag}")
    # Use an explicit end index: `data[:-lag]` would be empty for lag == 0.
    n_pairs = max(len(data) - lag, 0)
    X = torch.tensor(data[:n_pairs], dtype=torch.float32)
    y = torch.tensor(data[lag:], dtype=torch.float32)
    return X, y

# Number of timesteps between each input row and its target row (same for both splits).
lag = 50

X_train, y_train = generate_xy(train_data_scaled, lag=lag)
X_test, y_test = generate_xy(test_data_scaled, lag=lag)

### Boxplots (IQR-clipped features)

In [None]:
def remove_outliers(data, whisker=1.5):
    """Clip values to Tukey's IQR fences, computed per column (along axis 0).

    Despite the name, nothing is removed. Values outside
    ``[Q1 - whisker * IQR, Q3 + whisker * IQR]`` are clipped to the nearest fence,
    so the result has the same shape as ``data``.

    Args:
        data: Array-like; quartiles are taken over axis 0.
        whisker: Fence multiplier on the IQR. The default of 1.5 is the conventional
            boxplot whisker length, matching the previous hard-coded behaviour.

    Returns:
        A new ``np.ndarray`` with the same shape as ``data``, clipped column-wise.
    """
    q1, q3 = np.percentile(data, [25, 75], axis=0)
    iqr = q3 - q1
    return np.clip(data, q1 - whisker * iqr, q3 + whisker * iqr)

# Draw the boxplots from an IQR-clipped copy of the full dataset (train + test);
# np.clip returns a new array, so `data` itself is left unchanged.
data_clean = remove_outliers(data=data)
plot_boxplots(data_clean)

### Train stateful models

In [None]:
hidden_sizes = [256, 512, 1024]  # adjust as needed (1028 was a typo for 1024)
num_epochs = 100
# NOTE(review): patience=1 is very aggressive; confirm against EarlyStopping's semantics.
patience = 1

# Feature count of the stacked state vector; replaces the hard-coded 12.
n_features = X_train.shape[-1]

os.makedirs("stateful_models", exist_ok=True)

# Copy the datasets to the device once instead of once per hidden size.
X_train_dev, y_train_dev = X_train.to(device), y_train.to(device)
X_test_dev, y_test_dev = X_test.to(device), y_test.to(device)

for hidden_size in hidden_sizes:
    print(f"\n--- Training LSTM with hidden size {hidden_size} ---")
    model = LSTM(input_size=n_features, hidden_size=hidden_size, output_size=n_features,
                 initializer_method='xavier').to(device)
    early_stopping = EarlyStopping(patience=patience)
    trainer = Trainer(model, learning_rate=0.001, early_stopping=early_stopping)

    # NOTE(review): the held-out test split doubles as the validation set here, so
    # early stopping sees test data; consider carving a separate validation split.
    trainer.train(X_train_dev, y_train_dev, X_test_dev, y_test_dev, epochs=num_epochs)

    # Save weights and per-epoch loss history; the loss and prediction cells reload these files.
    model_path = f"stateful_models/lstm_stateful_h{hidden_size}.pt"
    torch.save(model.state_dict(), model_path)
    np.save(f"stateful_models/train_losses_h{hidden_size}.npy", np.array(trainer.train_losses))
    np.save(f"stateful_models/val_losses_h{hidden_size}.npy", np.array(trainer.val_losses))

    print(f"Saved model and losses for hidden size {hidden_size}")

### Training/validation loss plot

In [None]:
hidden_size = 256  # choose which model to plot

# Per-epoch loss histories written by the training cell for this hidden size.
train_losses = np.load(f"stateful_models/train_losses_h{hidden_size}.npy", allow_pickle=True)
val_losses = np.load(f"stateful_models/val_losses_h{hidden_size}.npy", allow_pickle=True)

fig, ax = plt.subplots(figsize=(8,5))
ax.plot(train_losses, color='black', label='Training Loss')
ax.plot(val_losses, color='blue', label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title=f'LSTM Loss - Hidden size {hidden_size}')
ax.grid(True)
ax.legend()
plt.show()


### Iterative predictions vs ground truth

In [None]:
initialization_steps = 1000
steps = 2000

# Rebuild a model whose architecture matches the checkpoint being loaded. Reusing the
# `model` left over from the training loop fails with a size mismatch whenever its
# hidden size (the last entry of `hidden_sizes`) differs from `hidden_size` above.
n_features = X_test.shape[-1]
model = LSTM(input_size=n_features, hidden_size=hidden_size, output_size=n_features,
             initializer_method='xavier').to(device)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(f"stateful_models/lstm_stateful_h{hidden_size}.pt", map_location=device))
model.eval()

with torch.no_grad():  # inference only: no autograd graph needed
    output = model.generate_timeseries(X_test[:initialization_steps].to(device), steps=steps)
# NOTE(review): `true` has steps - initialization_steps rows. Confirm that
# generate_timeseries(steps=...) returns the same number of rows and that the
# `lag` offset between X_test and y_test lines the two up.
true = y_test[initialization_steps:steps]

# Convert back to original scale (tensors must be on the CPU before .numpy()).
true_np = scaler.inverse_transform(true.detach().cpu().numpy())
output_np = scaler.inverse_transform(output.detach().cpu().numpy())

plot_trajectories(true_np, output_np)


### Predicted features over time

In [None]:
# Plot every predicted feature against timestep (shaded down to zero), using the
# actual feature count instead of a hard-coded 12.
fig, ax = plt.subplots(figsize=(12,5))
timesteps = range(len(output_np))
for i in range(output_np.shape[1]):
    ax.plot(output_np[:, i], alpha=0.6, label=f"Feature {i+1}")
    ax.fill_between(timesteps, output_np[:, i], alpha=0.1)
ax.set_title("Predicted Feature Distribution")
ax.set_xlabel("Timestep")
ax.set_ylabel("Feature value")
ax.legend()
plt.show()