# In [None]:
"""
Seq1 Training + Recursive inference (single-step repeated) using LSTM encoder regressor.

- Train 7 models using LOOKBACKS = [2,3,4,5,6,7,8]
- For each lookback:
      • Build windows
      • Train model
      • Compute validation R²
- Choose best lookback by highest R² (tie → smallest lookback)
- Save ONLY that model + scaler (under the normal filenames)
- Run recursive inference ONLY AFTER best model is chosen

"""

import os
import numpy as np
import pandas as pd
import joblib
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# ---------------- USER SETTINGS ----------------
# Input locations. These are placeholders: point them at real files before running.
INPUT_CSV = "Path to your CSV"
FUTURE_SLR_FILE = "Path to your SLR data csv"

# Directory receiving every artefact this script writes; created eagerly.
OUTDIR = "Output directory file"
os.makedirs(OUTDIR, exist_ok=True)

TARGET = "nsm"  # column used as the regression target
TRAIN_END_YEAR = 2021  # year where you expect the model to end training
VAL_YEAR = 2025  # inclusive upper bound on end_year for the validation frame

# Candidate window lengths; one model is trained per entry.
LOOKBACKS = list(range(2, 9))

# Optimisation hyper-parameters.
BATCH_SIZE = 64
LR = 1e-2
EPOCHS = 500
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the torch and numpy RNGs so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# FEATURES (can change based on what features your CSV contains)
BASE_SEQUENCE_FEATURES = [
    "swh_max",
    "swh_mean",
    "swh_std",
    "swh_storm_ratio",
    "attack_sin",
    "attack_cos",
    "slr_mm",
]
INTERACTION_FEATURES = [
    "slr_x_storm",
    "slr_x_sin",
    "slr_x_cos",
]
# Order matters: it fixes the feature axis of every window fed to the model.
SEQUENCE_FEATURES = [*BASE_SEQUENCE_FEATURES, *INTERACTION_FEATURES]

# Years for which inference results are exported.
FUTURE_YEARS = [2030, 2035, 2040, 2045, 2050]

# Artefact paths for the selected scaler and model weights.
SCALER_PATH = os.path.join(OUTDIR, "scaler_{}.joblib".format(VAL_YEAR))
MODEL_PATH = os.path.join(OUTDIR, "lstm_encoder_model_{}.pt".format(VAL_YEAR))

# ---------------- 1. LOAD + FEATURE ENGINEER ----------------
print("Loading dataset:", INPUT_CSV)
# Strip stray whitespace from header names so later column lookups match.
df = pd.read_csv(INPUT_CSV).rename(columns=str.strip)

# Ratio of peak to mean significant wave height.
# NOTE(review): rows with swh_mean == 0 produce inf/NaN here — confirm the data excludes them.
df["swh_storm_ratio"] = df["swh_max"].div(df["swh_mean"])

if {"wd_mean_deg", "bearing_deg"}.issubset(df.columns):
    # Fold the direction difference into [0, 180] before taking sin/cos.
    angle_diff = df["wd_mean_deg"] - df["bearing_deg"]
    wrapped = (angle_diff % 360).abs()
    df["angle_of_attack"] = np.minimum(wrapped, 360 - wrapped)
    attack_rad = np.radians(df["angle_of_attack"])
    df["attack_sin"] = np.sin(attack_rad)
    df["attack_cos"] = np.cos(attack_rad)
else:
    # No direction columns: keep existing attack_* columns or fall back to 0.0.
    for trig_col in ("attack_sin", "attack_cos"):
        df[trig_col] = df.get(trig_col, 0.0)

# Sea-level-rise interaction terms (slr_mm multiplied by each partner column).
for product_col, factor_col in (
    ("slr_x_storm", "swh_storm_ratio"),
    ("slr_x_sin", "attack_sin"),
    ("slr_x_cos", "attack_cos"),
):
    df[product_col] = df["slr_mm"] * df[factor_col]

# NOTE(review): both masks use "<=", so val_df contains every train_df row
# (plus the later years up to VAL_YEAR) — confirm this overlap is intended.
train_df = df.loc[df["end_year"] <= TRAIN_END_YEAR].copy()
val_df = df.loc[df["end_year"] <= VAL_YEAR].copy()

print("Train rows:", len(train_df), "Val rows:", len(val_df))

# ---------------- HELPER: build windows ----------------
def build_single_step_windows(df_in, lookback, seq_features, target=None):
    """Build sliding (window -> next-row target) samples per transect.

    Rows are grouped by ``transect_id`` and ordered by ``end_year``. For each
    row ``i >= lookback`` in a group, the features of rows
    ``[i - lookback, i)`` form one window and the target value of row ``i``
    is its label. Groups with ``lookback`` rows or fewer contribute nothing.

    Args:
        df_in: DataFrame holding ``transect_id``, ``end_year``, every column
            in ``seq_features`` and the target column.
        lookback: Number of consecutive rows per window (must be >= 1).
        seq_features: Ordered list of feature column names.
        target: Name of the target column. Defaults to the module-level
            ``TARGET`` when omitted.

    Returns:
        A tuple ``(X, y, meta)`` where ``X`` has shape
        ``(n_samples, lookback, len(seq_features))``, ``y`` has shape
        ``(n_samples,)`` and ``meta`` is a DataFrame with one row per sample
        holding the full input row that supplied the target (input column
        dtypes are preserved). With no samples, ``X`` and ``y`` are empty
        arrays of matching rank and ``meta`` is an empty DataFrame.

    Raises:
        ValueError: If ``lookback`` is smaller than 1.
    """
    if lookback < 1:
        raise ValueError(f"lookback must be >= 1, got {lookback}")
    target_col = TARGET if target is None else target

    windows, targets, meta_parts = [], [], []
    for _, group in df_in.groupby("transect_id"):
        group = group.sort_values("end_year").reset_index(drop=True)
        n_rows = len(group)
        if n_rows <= lookback:
            continue
        feats = group[seq_features].to_numpy()
        windows.extend(feats[i - lookback:i] for i in range(lookback, n_rows))
        # Rows lookback..n-1 are exactly the rows that supply a target, so
        # slice once instead of indexing the frame row by row.
        targets.append(group[target_col].to_numpy()[lookback:])
        meta_parts.append(group.iloc[lookback:])

    if not windows:
        return (
            np.empty((0, lookback, len(seq_features))),
            np.empty((0,)),
            pd.DataFrame(),
        )
    return (
        np.stack(windows),
        np.concatenate(targets),
        pd.concat(meta_parts, ignore_index=True),
    )

# ---------------- MODEL CLASS ----------------
class LSTMEncoderRegressor(nn.Module):
    """LSTM encoder with a linear head that emits one scalar per sequence.

    Args:
        n_features: Size of the feature axis of each time step.
        hidden_size: Width of the LSTM hidden state.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout value forwarded to ``nn.LSTM``.
    """

    def __init__(self, n_features, hidden_size=64, n_layers=1, dropout=0.0):
        super().__init__()
        lstm_config = {
            "input_size": n_features,
            "hidden_size": hidden_size,
            "num_layers": n_layers,
            "batch_first": True,  # inputs are (batch, seq_len, n_features)
            "dropout": dropout,
        }
        self.lstm = nn.LSTM(**lstm_config)
        self.fc = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(self, x):
        """Return a ``(batch,)`` tensor of predictions for a batch of windows."""
        encoded, _state = self.lstm(x)
        # Only the output at the final time step feeds the regression head.
        final_step = encoded[:, -1]
        prediction = self.fc(final_step)
        return prediction.squeeze(-1)

# ---------------- 2A. TRAIN OVER ALL LOOKBACKS ----------------
# One tuple per successfully trained lookback:
#   (lookback, r2, mae, rmse, best_state_dict, scaler, X_val_scaled, Y_val)
results = []

print("\n==============================")
print("TRAINING OVER LOOKBACKS: 2–8")
print("==============================")

for LOOKBACK in LOOKBACKS:
    print(f"\n---- Training for LOOKBACK = {LOOKBACK} ----")

    # build windows
    # NOTE(review): val_df is filtered with end_year <= VAL_YEAR, so these
    # validation windows also include windows from the training years —
    # confirm that overlap is intended for early stopping / model selection.
    X_train, Y_train, _ = build_single_step_windows(train_df, LOOKBACK, SEQUENCE_FEATURES)
    X_val, Y_val, _ = build_single_step_windows(val_df, LOOKBACK, SEQUENCE_FEATURES)

    if X_train.shape[0] == 0:
        print(f"Skipping LOOKBACK {LOOKBACK} — insufficient historical rows")
        continue
    if X_val.shape[0] == 0:
        # Nothing to early-stop or score on, and reshape(0, -1) below would raise.
        print(f"Skipping LOOKBACK {LOOKBACK} — no validation windows")
        continue

    # scale: each window is flattened so the scaler sees one column per
    # (time step, feature) position, then restored to its 3-D shape.
    n_features = X_train.shape[2]
    scaler = StandardScaler()
    flat_train = X_train.reshape(X_train.shape[0], -1)
    flat_val = X_val.reshape(X_val.shape[0], -1)

    flat_train_scaled = scaler.fit_transform(flat_train)
    flat_val_scaled = scaler.transform(flat_val)

    X_train_scaled = flat_train_scaled.reshape(X_train.shape)
    X_val_scaled = flat_val_scaled.reshape(X_val.shape)

    # loaders
    train_ds = torch.utils.data.TensorDataset(
        torch.tensor(X_train_scaled, dtype=torch.float32),
        torch.tensor(Y_train, dtype=torch.float32),
    )
    val_ds = torch.utils.data.TensorDataset(
        torch.tensor(X_val_scaled, dtype=torch.float32),
        torch.tensor(Y_val, dtype=torch.float32),
    )

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # model
    model = LSTMEncoderRegressor(n_features).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.MSELoss()

    # training with early stopping on the validation loss
    best_state = None
    best_loss = np.inf
    patience = 12
    pat = 0

    for epoch in range(1, EPOCHS + 1):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

        # validate
        model.eval()
        vlosses = []
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                vlosses.append(criterion(model(xb), yb).item())
        vloss = np.mean(vlosses)

        # early stop
        if vloss < best_loss:
            best_loss = vloss
            # state_dict() hands back tensors that share storage with the live
            # parameters; clone them, otherwise later optimizer steps would
            # silently overwrite this "best" snapshot.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            pat = 0
        else:
            pat += 1
        if pat >= patience:
            break

    if best_state is None:
        # The validation loss never dropped below inf (NaN or inf throughout),
        # so there is no checkpoint to restore.
        print(f"Skipping LOOKBACK {LOOKBACK} — validation loss was never finite")
        continue

    # compute final metrics on the restored best checkpoint
    model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        preds = model(torch.tensor(X_val_scaled, dtype=torch.float32).to(DEVICE)).cpu().numpy()
    r2 = r2_score(Y_val, preds)
    mae = mean_absolute_error(Y_val, preds)
    rmse = np.sqrt(mean_squared_error(Y_val, preds))

    print(f"LOOKBACK {LOOKBACK}: R2 = {r2:.4f}, MAE = {mae:.4f}, RMSE = {rmse:.4f}")

    results.append((LOOKBACK, r2, mae, rmse, best_state, scaler, X_val_scaled, Y_val))

# ---------------- PART B: select best lookback, save model+scaler, THEN run inference ----------------

if not results:
    raise RuntimeError("No trained results found. Ensure Part A training loop ran and filled `results`.")

# Select best by highest R2, tie-breaker = smallest lookback
sorted_results = sorted(results, key=lambda t: (-t[1], t[0]))  # sort by r2 desc, lookback asc
(
    best_lookback,
    best_r2,
    best_mae,
    best_rmse,
    best_state,
    best_scaler,
    best_Xval_scaled,
    best_Yval,
) = sorted_results[0]

print("\n==============================")
print(f"BEST LOOKBACK SELECTED: {best_lookback}  (R2 = {best_r2:.4f}, MAE = {best_mae:.4f}, RMSE = {best_rmse:.4f})")
print("==============================\n")


# Save best scaler and model state
joblib.dump(best_scaler, SCALER_PATH)
print("Saved scaler ->", SCALER_PATH)

# The feature axis is the same for every lookback; only the window length differs.
n_features = len(SEQUENCE_FEATURES)
best_model = LSTMEncoderRegressor(n_features=n_features).to(DEVICE)
best_model.load_state_dict(best_state)
torch.save(best_model.state_dict(), MODEL_PATH)
print("Saved model ->", MODEL_PATH)

# Rebuild the validation windows for the best lookback; this is the only way
# to recover the per-sample metadata needed for the export below.
X_val_best, Y_val_best, meta_val_best = build_single_step_windows(val_df, best_lookback, SEQUENCE_FEATURES)

if len(X_val_best) > 0:
    # Scale using the best scaler
    flat_val_best = X_val_best.reshape(X_val_best.shape[0], -1)
    flat_val_best_scaled = best_scaler.transform(flat_val_best)
    X_val_best_scaled = flat_val_best_scaled.reshape(X_val_best.shape)

    # Get predictions using the best model
    best_model.eval()
    with torch.no_grad():
        y_pred_val_best = best_model(torch.tensor(X_val_best_scaled, dtype=torch.float32).to(DEVICE)).cpu().numpy()

    # Create validation export DataFrame. region / bearing_deg are optional in
    # the other exports of this script, so tolerate their absence here too.
    val_columns = {
        "transect_id": meta_val_best["transect_id"].to_numpy(),
        "observed_nsm": Y_val_best,
        "predicted_nsm_RNN": y_pred_val_best,
    }
    for optional_col in ("region", "bearing_deg"):
        val_columns[optional_col] = (
            meta_val_best[optional_col].to_numpy()
            if optional_col in meta_val_best.columns
            else None
        )
    val_export = pd.DataFrame(val_columns)

    # Calculate metrics
    val_mae = mean_absolute_error(Y_val_best, y_pred_val_best)
    val_rmse = np.sqrt(mean_squared_error(Y_val_best, y_pred_val_best))
    val_r2 = r2_score(Y_val_best, y_pred_val_best)

    print(f"\nValidation MAE: {val_mae:.4f}, RMSE: {val_rmse:.4f}, R2: {val_r2:.4f}")

    # Save validation predictions; the filename follows VAL_YEAR like the
    # scaler and model artefact paths do.
    val_pred_path = os.path.join(OUTDIR, f"validation_predictions_{VAL_YEAR}.csv")
    val_export.to_csv(val_pred_path, index=False)
    print(f"Saved validation predictions -> {val_pred_path}")
    print(f"Validation dataset shape: {val_export.shape}")
    print(f"Sample validation \n{val_export.head()}")

# --------------- Quick sanity metrics on best model (val R2/MAE) ---------------
# Re-scores the scaled validation windows stored during training with the
# reloaded weights.
best_model.eval()
with torch.no_grad():
    y_pred_val = best_model(torch.tensor(best_Xval_scaled, dtype=torch.float32).to(DEVICE)).cpu().numpy()
val_mae = mean_absolute_error(best_Yval, y_pred_val)
val_rmse = np.sqrt(mean_squared_error(best_Yval, y_pred_val))
val_r2 = r2_score(best_Yval, y_pred_val)
print(f"\nBest model validation MAE: {val_mae:.4f}, RMSE: {val_rmse:.4f}, R2: {val_r2:.4f}")

# Extract and save 2021 baseline shoreline points
baseline_2021 = df[df["end_year"] == 2021].copy()
available_columns = [col for col in ["region", "transect_id", "bearing_deg", "nsm"] if col in baseline_2021.columns]
baseline_2021_export = baseline_2021[available_columns].copy()
if "nsm" in baseline_2021_export.columns:
    baseline_2021_export = baseline_2021_export.rename(columns={"nsm": "nsm_2021_baseline"})
baseline_path = os.path.join(OUTDIR, "baseline_2021_shoreline.csv")
baseline_2021_export.to_csv(baseline_path, index=False)
print(f"Saved 2021 baseline shoreline -> {baseline_path}")
print(f"Baseline 2021 dataset shape: {baseline_2021_export.shape}")
print(f"Sample baseline \n{baseline_2021_export.head()}")

# ---------------- 7. RECURSIVE INFERENCE (single-step repeated) ----------------
print("\nLoading transect-specific future SLR:", FUTURE_SLR_FILE)
future_slr = pd.read_csv(FUTURE_SLR_FILE)

# Accept "year" as an alias for "end_year" when only the former is present.
has_year_only = "year" in future_slr.columns and "end_year" not in future_slr.columns
if has_year_only:
    future_slr = future_slr.rename(columns={"year": "end_year"})

# Fail early if any column needed for the lookup below is absent.
required = ["transect_id", "end_year", "slr_mm"]
missing_required = [c for c in required if c not in future_slr.columns]
if missing_required:
    raise RuntimeError(f"FUTURE_SLR_FILE must contain columns: {required}")

# Series of slr_mm addressed by (transect_id, end_year).
slr_lookup = future_slr.set_index(["transect_id", "end_year"])["slr_mm"]

# Prepare historical last-known rows (up to 2025)
hist_last = df.loc[df["end_year"] <= 2025].copy()

# Helper: build initial window for selected best_lookback
def build_initial_window_for_lookback(transect_id, lookback):
    """Return the most recent ``lookback`` feature rows for one transect.

    Reads the module-level ``hist_last`` frame and ``SEQUENCE_FEATURES`` list.

    Returns:
        ``(window, meta_row)`` where ``window`` is a copied array of the last
        ``lookback`` rows (ordered by ``end_year``) restricted to
        ``SEQUENCE_FEATURES`` and ``meta_row`` is the latest row of that
        transect; ``None`` when the transect has fewer than ``lookback`` rows.
    """
    is_transect = hist_last["transect_id"] == transect_id
    history = hist_last.loc[is_transect].sort_values("end_year")
    if len(history) < lookback:
        return None
    tail = history.iloc[-lookback:]
    return tail[SEQUENCE_FEATURES].values.copy(), history.iloc[-1]

# scale helper using saved scaler
def scale_window_with_saved_scaler(window):
    """Scale one window with ``best_scaler`` and return it as a batch of one.

    The window is flattened to a single row before ``transform`` and the
    result is reshaped to ``(1, best_lookback, n_features)``; ``best_scaler``,
    ``best_lookback`` and ``n_features`` are read from module scope.
    """
    scaled_row = best_scaler.transform(window.reshape(1, -1))
    batch_shape = (1, best_lookback, n_features)
    return scaled_row.reshape(batch_shape)

# Setup outputs container
outputs_per_year = {y: [] for y in FUTURE_YEARS}
transect_ids = sorted(hist_last["transect_id"].unique())
print(f"\nStarting recursive inference for {len(transect_ids)} transects using lookback={best_lookback} ...")

# indices used repeatedly
slr_pos = SEQUENCE_FEATURES.index("slr_mm")
ix_swh_storm = SEQUENCE_FEATURES.index("swh_storm_ratio")
ix_attack_sin = SEQUENCE_FEATURES.index("attack_sin")
ix_attack_cos = SEQUENCE_FEATURES.index("attack_cos")
ix_slr_x_storm = SEQUENCE_FEATURES.index("slr_x_storm")
ix_slr_x_sin = SEQUENCE_FEATURES.index("slr_x_sin")
ix_slr_x_cos = SEQUENCE_FEATURES.index("slr_x_cos")

# Precompute sorted SLR years after 2025 (rows with a missing year are ignored,
# since int(NaN) would raise).
slr_years_sorted = sorted(
    int(y) for y in future_slr["end_year"].dropna().unique() if int(y) > 2025
)

# Plain dict keyed by (transect_id, end_year): constant-time lookups in the
# hot loop, no exception-driven control flow, and duplicate keys resolve to
# the last row instead of yielding a Series.
slr_by_key = slr_lookup.to_dict()

# NOTE(review): current_window is never advanced between steps, so every
# prediction depends only on the initial window and that year's SLR value.
# Years that are not exported can therefore be skipped without changing any
# output. If true recursion (rolling the window forward) is intended, this
# filter must be removed together with adding the window update.
requested_years = set(FUTURE_YEARS)
requested_steps = [y for y in slr_years_sorted if y in requested_years]

best_model.eval()
with torch.no_grad():
    for tid in transect_ids:
        res = build_initial_window_for_lookback(tid, best_lookback)
        if res is None:
            continue
        window, meta_row = res
        current_window = window.copy()
        # keep region and bearing from meta_row (if present)
        region_val = meta_row.get("region", None)
        bearing_val = meta_row.get("bearing_deg", None)

        for step_year in requested_steps:
            # get transect-specific slr; skip years with no entry
            slr_val = slr_by_key.get((tid, step_year))
            if slr_val is None:
                continue

            # update last row's slr + interactions
            updated_last = current_window[-1].copy()
            updated_last[slr_pos] = slr_val
            updated_last[ix_slr_x_storm] = slr_val * updated_last[ix_swh_storm]
            updated_last[ix_slr_x_sin] = slr_val * updated_last[ix_attack_sin]
            updated_last[ix_slr_x_cos] = slr_val * updated_last[ix_attack_cos]

            temp_window = current_window.copy()
            temp_window[-1] = updated_last

            # scale & predict
            X_in = scale_window_with_saved_scaler(temp_window)   # shape (1, lookback, n_features)
            X_tensor = torch.tensor(X_in, dtype=torch.float32).to(DEVICE)
            pred_nsm = float(best_model(X_tensor).item())

            outputs_per_year[step_year].append({
                "transect_id": tid,
                "predicted_nsm_RNN": pred_nsm,
                "scenario_year": step_year,
                "region": region_val,
                "bearing_deg": bearing_val
            })

# ---------------- SAVE PER-YEAR CSVs (with region + bearing_deg included) ----------------
# Fixed column order for every exported file.
cols = ["region", "transect_id", "bearing_deg", "scenario_year", "predicted_nsm_RNN"]

for year in FUTURE_YEARS:
    rows = outputs_per_year.get(year, [])
    if not rows:
        print(f"No predictions for {year} — check FUTURE_SLR_FILE or historical depth.")
        continue

    out_df = pd.DataFrame(rows)
    # Pad any column the collected rows did not supply, then enforce the order.
    absent = [c for c in cols if c not in out_df.columns]
    for c in absent:
        out_df[c] = None
    out_df = out_df[cols]

    out_path = os.path.join(OUTDIR, f"predictions_rnn_{year}.csv")
    out_df.to_csv(out_path, index=False)
    print("Saved predictions:", out_path)

print("\n✅ Done. Best model saved and recursive predictions exported.")