## Setting Up:

In [None]:
import sys, os

sys.path.append(os.path.join(os.getcwd(),
                             '../../'))  # Add root of repo to import MBM
import csv
from functools import partial

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
import itertools
import random
import pickle
from collections import Counter
import ast

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: load the autoreload extension and use mode 2, so edited
# modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

# Project configuration object; later cells read cfg.seed, cfg.dataPath and
# cfg.metaData from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed the session with the configured seed before any dataset, loader or
# model is created. seed_all comes from the project star imports
# (presumably seeds the Python/NumPy/torch RNGs -- confirm in scripts).
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
import torch.nn as nn

# Resolve the compute device once, then report it. When a GPU is present the
# project helper free_up_cuda() is called as well.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")

In [None]:
# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
color_pink = '#c51b7d'

# RGI ids: read the glacier-id table, strip whitespace from the column
# names, then sort and index the rows by glacier short name.
rgi_df = (
    pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
    .rename(columns=lambda col: col.strip())
    .sort_values(by='short_name')
    .set_index('short_name')
)

# Climate variable names handed to process_or_load_data as `vois_climate`.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variable names handed to process_or_load_data as
# `vois_topographical`.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

## Read GLAMOS data:

In [None]:
# Load the dataset returned by the project helper getStakesData
# (presumably the GLAMOS stake measurements -- confirm in scripts).
data_glamos = getStakesData(cfg)

# Head/tail month padding derived from the loaded data; both values are
# passed to MBSequenceDataset.from_dataframe when the sequences are built.
months_head_pad, months_tail_pad = mbm.data_processing.utils._compute_head_tail_pads_from_df(
    data_glamos)

## Input data:
### Input dataset:

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# File locations handed to process_or_load_data; every entry is rooted at
# cfg.dataPath.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# With RUN=False the helper is expected to load the previously saved
# `output_file` instead of recomputing it -- confirm in process_or_load_data.
# NOTE: RUN is reassigned (to True) by the grid-search cell further down.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')

# Create DataLoader
# (the project's mbm.dataloader.DataLoader, not torch.utils.data.DataLoader)
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

## Blocking on glaciers:

In [None]:
# Hold out TEST_GLACIERS for testing; every other glacier present in the
# monthly data is available for training/validation.

# Ensure all test glaciers exist in the dataset
glaciers_in_data = list(data_monthly.GLACIER.unique())  # order of appearance
existing_glaciers = set(glaciers_in_data)
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]

if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Define training glaciers in order of first appearance in the data. Iterating
# the set instead would give an order that depends on string-hash
# randomisation and therefore changes between sessions.
test_glaciers_lookup = set(TEST_GLACIERS)
train_glaciers = [g for g in glaciers_in_data if g not in test_glaciers_lookup]

data_test = data_monthly[data_monthly.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

# The reported percentage is test rows relative to TRAIN rows (not to the
# total number of rows).
if len(data_train) == 0:
    print("Warning: No training data available!")
else:
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))

# Glacier-blocked split produced by the project helper.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# Same guard as above: an empty training set must not raise
# ZeroDivisionError here.
if len(train_set['df_X']) == 0:
    print("Warning: No training data available!")
else:
    test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

# Validation and train split:
# From here on data_train / data_test are the frames returned by
# get_CV_splits, with the target attached as column 'y'. The assignment adds
# the column to the very same objects stored in train_set / test_set.
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

## LSTM:

In [None]:
# Columns passed to MBSequenceDataset.from_dataframe as the monthly inputs.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr', 'ELEVATION_DIFFERENCE',
]

# Columns passed to MBSequenceDataset.from_dataframe as the static inputs.
STATIC_COLS = [
    'aspect_sgi',
    'slope_sgi',
    'hugonnet_dhdt',
    'consensus_ice_thickness',
    'millan_v',
]

# All model inputs: monthly columns first, static columns after.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

#### Build sequences:

In [None]:
# Work on copies so data_train / data_test stay untouched, and normalise the
# PERIOD labels (strip whitespace, lower-case).
df_train = data_train.copy()
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()

df_test = data_test.copy()
df_test['PERIOD'] = df_test['PERIOD'].str.strip().str.lower()

# --- build train dataset from dataframe ---
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True)

# --- build test dataset with the same columns and padding ---
ds_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_test,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True  # the test frame also carries the target column 'y'
)

# Fixed train/validation index split (80/20 via val_ratio=0.2) over the
# training sequences; reused by the grid search and the final retraining.
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train), val_ratio=0.2, seed=cfg.seed)

## Grid search:

In [None]:
from itertools import product

# Search space: the grid search tries every combination of these values.
param_grid = dict(
    lr=[1e-3, 5e-4, 1e-4],
    weight_decay=[0.0, 1e-5, 1e-4],
    hidden_size=[64, 128],
    num_layers=[1, 2],
    dropout=[0.0, 0.2],
    head_dropout=[0.0, 0.1],
)

# Static-branch configurations as (static_layers, static_hidden,
# static_dropout) triplets; each one is crossed with the whole grid.
static = [
    (0, 0, None),  # identity (use 0 here for robustness)
    (2, [128, 64], 0.1),  # small two-layer MLP
]

def pack(static_triplet):
    """Turn a static-branch triplet into keyword-style parameters.

    Args:
        static_triplet: Three-item sequence ordered as
            (static_layers, static_hidden, static_dropout).

    Returns:
        dict with the keys "static_layers", "static_hidden" and
        "static_dropout", in that order.
    """
    n_layers, hidden, dropout = static_triplet
    return {
        "static_layers": n_layers,
        "static_hidden": hidden,
        "static_dropout": dropout,
    }

# ---- settings shared by every parameter combination ----
const_params = dict(
    Fm=ds_train.Xm.shape[-1],  # monthly features
    Fs=ds_train.Xs.shape[-1],  # static features
    bidirectional=False,
    loss_name="neutral",
    loss_spec=None,
)

def grid_iter_with_static_and_const(grid, static_list, const):
    """Yield one parameter dict per (grid point, static configuration) pair.

    Args:
        grid: Mapping of hyperparameter name -> list of candidate values.
        static_list: Sequence of static-branch triplets understood by `pack`.
        const: Mapping of settings added unchanged to every combination; on a
            key clash these values win because they are applied last.

    Yields:
        dict holding the grid values, then the packed static configuration,
        then the constants.
    """
    names = list(grid)
    axes = [grid[name] for name in names]
    for *grid_values, static_cfg in product(*axes, static_list):
        combo = dict(zip(names, grid_values))  # non-static hyperparams
        combo.update(pack(static_cfg))  # static-branch configuration
        combo.update(const)  # shared constants
        yield combo


# ---- generate all sampled param sets ----
# Materialised as a list so it can be counted, indexed and iterated again by
# the grid-search cell; one entry per (grid point, static config) pair.
sampled_params = list(
    grid_iter_with_static_and_const(param_grid, static, const_params))
print(len(sampled_params))
print(sampled_params[0])  # preview one combo

In [None]:
# Master switch for the grid search: everything below sits under `if RUN:`,
# so with RUN = False this cell does nothing.
RUN = True
if RUN:
    # Output folders for the CSV progress log and the model checkpoint.
    os.makedirs("logs", exist_ok=True)
    os.makedirs("models", exist_ok=True)

    # The log name carries only the date and the file is opened with mode='w'
    # below, so a second run on the same day overwrites the first one.
    log_filename = f'logs/lstm_simple_param_search_progress_{datetime.now().strftime("%Y-%m-%d")}.csv'

    # create log with header
    # Columns: the keys of a parameter set followed by the three metrics.
    with open(log_filename, mode='w', newline='') as log_file:
        writer = csv.DictWriter(log_file,
                                fieldnames=list(sampled_params[0].keys()) +
                                ['valid_loss', 'test_rmse_a', 'test_rmse_w'])
        writer.writeheader()

    results = []
    # Running best configuration, selected on validation loss.
    best_overall = {"val": float('inf'), "row": None, "params": None}

    for i, params in enumerate(sampled_params):
        # Re-seed with the same seed before every configuration.
        seed_all(cfg.seed)
        # One checkpoint path shared by all configurations.
        model_filename = 'models/best_lstm_mb_gs_simple.pt'

        # delete existing model file:
        # (so a checkpoint left by the previous configuration cannot be
        # picked up when the weights are reloaded after training)
        if os.path.exists(model_filename):
            os.remove(model_filename)
            print(f"Deleted existing model file: {model_filename}")

        # Build dataloaders:
        # Work on clones of ds_train / ds_test (untransformed copies, going
        # by the helper's name) so the scaling applied below does not leak
        # into the next configuration.
        ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
            ds_train)

        ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
            ds_test)

        # --- make train/val loaders and get split indices (scalers fit on TRAIN, applied to whole ds) ---
        train_dl, val_dl = ds_train_copy.make_loaders(
            train_idx=train_idx,
            val_idx=val_idx,
            batch_size_train=64,
            batch_size_val=128,
            seed=cfg.seed,
            fit_and_transform=
            True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
            shuffle_train=True,
            use_weighted_sampler=True  # use weighted sampler for training
        )

        # --- build test dataset and loader (reuses train scalers) ---
        test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
            ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

        print(f"\n--- Running config {i+1}/{len(sampled_params)} ---")
        print(params)

        # Build model
        model = mbm.models.LSTM_MB.build_model_from_params(cfg, params, device)

        # Choose loss
        loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(params)

        # Train
        # Only lr and weight_decay come from the searched parameters; the
        # clipping, scheduler and early-stopping settings are fixed here.
        history, best_val, best_state = model.train_loop(
            device=device,
            train_dl=train_dl,
            val_dl=val_dl,
            epochs=150,
            lr=params['lr'],
            weight_decay=params['weight_decay'],
            clip_val=1,
            # scheduler
            sched_factor=0.5,
            sched_patience=6,
            sched_threshold=0.01,
            sched_threshold_mode="rel",
            sched_cooldown=1,
            sched_min_lr=1e-6,
            # early stopping
            es_patience=15,
            es_min_delta=1e-4,
            # logging
            log_every=5,
            verbose=False,
            # checkpoint
            save_best_path=model_filename,
            loss_fn=loss_fn,
        )

        # Load the best weights
        # NOTE: this replaces the best_state returned by train_loop with the
        # checkpoint read from disk; it fails if no checkpoint was written.
        best_state = torch.load(model_filename, map_location=device)
        model.load_state_dict(best_state)

        # Evaluate on test
        test_metrics, test_df_preds = model.evaluate_with_preds(
            device, test_dl, ds_test_copy)
        test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
            'RMSE_winter']

        # Log row
        # Key order (params first, then the three metrics) matches the header
        # written above.
        row = {
            **params, 'valid_loss': float(best_val),
            'test_rmse_a': float(test_rmse_a),
            'test_rmse_w': float(test_rmse_w)
        }

        print(test_rmse_a, test_rmse_w)

        # Append after every configuration so partial results survive an
        # interrupted search.
        with open(log_filename, mode='a', newline='') as log_file:
            writer = csv.DictWriter(log_file, fieldnames=list(row.keys()))
            writer.writerow(row)

        results.append(row)

        # Track best by validation loss
        # (the test RMSEs are logged but never used for selection)
        if best_val < best_overall['val']:
            best_overall = {"val": best_val, "row": row, "params": params}

    print("\n=== Best config by validation loss ===")
    print(best_overall['params'])
    print(best_overall['row'])

## Test best combination:

In [None]:
# --- grid-search log to evaluate: pinned to the run of 2025-09-16 ---
# NOTE: the grid-search cell names its log after the current date, so this
# path has to be updated by hand to evaluate a newer run.
log_path = 'logs/lstm_simple_param_search_progress_2025-09-16.csv'

df_log = pd.read_csv(log_path)
# Row with the lowest validation loss, converted to a plain dict.
best_row = df_log.loc[df_log["valid_loss"].idxmin()].to_dict()


# --- normalize params from CSV ---
def _as_bool(x):
    """Coerce a value read back from the CSV log into a bool.

    Args:
        x: A bool, a number, or anything whose string form spells a flag.

    Returns:
        `x` itself for bools; for ints/floats the truthiness of the value
        truncated to an int; otherwise True only when the stripped,
        lower-cased text is one of "1", "true", "t", "yes", "y".
    """
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, float)):
        return bool(int(x))
    token = str(x).strip().lower()
    return token in ("1", "true", "t", "yes", "y")


def _as_opt_list(x):
    """Parse an optional list-valued hyperparameter read back from the CSV log.

    Args:
        x: Raw cell value: a missing value (None/NaN), an already-parsed
            list or tuple, or the text of a Python literal such as
            "[128, 64]".

    Returns:
        None for missing values, blank text, the text "none" (any case) or
        text that is not a valid Python literal; a list for list/tuple
        input; otherwise whatever ast.literal_eval parses from the text
        (e.g. "[128, 64]" -> [128, 64], "0" -> 0).
    """
    # Already-parsed sequences are returned directly: pd.isna() answers
    # element-wise for them, and the truth value of that array is ambiguous.
    if isinstance(x, (list, tuple)):
        return list(x)
    if pd.isna(x):
        return None
    text = str(x).strip()
    if text.lower() in {"", "none"}:
        return None
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # The exception types ast.literal_eval documents for malformed
        # input; such text is treated as "not provided".
        return None


# Typed hyperparameters for the retraining run, rebuilt from the best CSV row.
# clip_val, sched_factor and sched_patience fall back to the values
# hard-coded in the grid-search training call when the log has no such column.
best_params = {
    "Fm": int(best_row["Fm"]),
    "Fs": int(best_row["Fs"]),
    "hidden_size": int(best_row["hidden_size"]),
    "num_layers": int(best_row["num_layers"]),
    "bidirectional": _as_bool(best_row["bidirectional"]),
    "dropout": float(best_row["dropout"]),
    "static_layers": int(best_row["static_layers"]),
    "static_hidden": _as_opt_list(best_row.get("static_hidden")),
    # an empty CSV cell is read back as NaN and stands for "no dropout value"
    "static_dropout": (None if pd.isna(best_row.get("static_dropout")) else
                       float(best_row["static_dropout"])),
    "simple": True,  # fixed flag, not one of the searched parameters
    "head_dropout": float(best_row["head_dropout"]),
    "lr": float(best_row["lr"]),
    "weight_decay": float(best_row["weight_decay"]),
    "clip_val": float(best_row.get("clip_val", 1.0)),
    "sched_factor": float(best_row.get("sched_factor", 0.5)),
    "sched_patience": int(best_row.get("sched_patience", 6)),
    "loss_name": str(best_row.get("loss_name", "neutral")),
}

# A "weighted" loss is requested with an empty option dict; any other loss
# name carries no loss spec.
if best_params["loss_name"] == "weighted":
    best_params["loss_spec"] = ("weighted", {})
else:
    best_params["loss_spec"] = None

# --- deterministic seeding (process-wide) ---
# setdefault: only sets the cuBLAS workspace config when the environment does
# not define it already.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
import random

# Seed each RNG explicitly with the configured seed.
# NOTE(review): `np` and `torch` are not imported by name in this notebook;
# they are expected to arrive through the `scripts.*` star imports -- confirm.
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
torch.cuda.manual_seed_all(cfg.seed)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Optional strict mode:
# torch.use_deterministic_algorithms(True)

# --- build datasets ONCE (fresh/pristine each run of this cell) ---
# Rebinds ds_train / ds_test to newly built datasets, because make_loaders
# below is asked to fit and transform (fit_and_transform=True).
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True)
ds_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_test,
    MONTHLY_COLS,
    STATIC_COLS,
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True)

# --- reuse fixed split if present, else compute with cfg.seed ---
# Referencing the names raises NameError when the split was not created
# earlier in the session; only then is it recomputed.
try:
    train_idx, val_idx  # existence check only; the values are not used here
except NameError:
    train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
        len(ds_train), val_ratio=0.2, seed=cfg.seed)

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Same batch sizes and sampler settings as in the grid search.
train_dl, val_dl = ds_train.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    seed=cfg.seed,
    batch_size_train=64,
    batch_size_val=128,
    fit_and_transform=True,
    shuffle_train=True,
    use_weighted_sampler=True,
)

# --- test loader (copies TRAIN scalers into ds_test and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test, ds_train, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model_path = "models/lstm_best_simple_retrain.pt"
# Make sure the checkpoint folder exists even when the grid-search cell
# (which is what creates "models/") was not run in this session.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
# Remove a stale checkpoint so the file reloaded below cannot come from an
# earlier run.
if os.path.exists(model_path):
    os.remove(model_path)

model = mbm.models.LSTM_MB.build_model_from_params(cfg, best_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(best_params)

# Optimiser, clipping and scheduler factor/patience come from best_params;
# the remaining scheduler and early-stopping settings repeat the values used
# during the grid search.
history, best_val, _ = model.train_loop(
    device=device,
    train_dl=train_dl,
    val_dl=val_dl,
    epochs=150,
    lr=best_params["lr"],
    weight_decay=best_params["weight_decay"],
    clip_val=best_params["clip_val"],
    # scheduler
    sched_factor=best_params["sched_factor"],
    sched_patience=best_params["sched_patience"],
    sched_threshold=0.01,
    sched_threshold_mode="rel",
    sched_cooldown=1,
    sched_min_lr=1e-6,
    # early stopping
    es_patience=15,
    es_min_delta=1e-4,
    # logging
    log_every=5,
    verbose=True,
    # checkpoint
    save_best_path=model_path,
    loss_fn=loss_fn,
)

# Reload the checkpoint from disk before evaluating.
state = torch.load(model_path, map_location=device)
model.load_state_dict(state)

# --- evaluate (IMPORTANT: pass ds_test, which now has scalers) ---
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test)
print("\nTest metrics:", test_metrics)
