## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
import csv

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
import itertools
import random
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler

import pickle 
from collections import Counter
import ast
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset


from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *
from scripts.lstm import *

warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

cfg = mbm.SwitzerlandConfig()

In [None]:
# --- Plot styling ----------------------------------------------------------
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# --- RGI ids ---------------------------------------------------------------
# Read the glacier id table, strip whitespace from the column names, then
# order and index it by the glacier short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda col: col.strip())
          .sort_values(by='short_name')
          .set_index('short_name'))

# Climate variables passed to the data-processing helpers.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables passed to the data-processing helpers.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

In [None]:
# Make the run reproducible before anything stochastic happens.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

# Pick the compute device once and report which one is in use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")

## Read GLAMOS stake data:

In [None]:
data_glamos = getStakesData(cfg)

# Display names: the two Clariden entries get explicit spellings, every other
# glacier name is simply capitalised.
_special_caps = {'claridenu': 'Clariden_U', 'claridenl': 'Clariden_L'}
glacierCap = {}
for gl in data_glamos['GLACIER'].unique():
    if not isinstance(gl, str):
        print(f"Warning: Non-string glacier name encountered: {gl}")
        continue
    glacierCap[gl] = _special_caps.get(gl.lower(), gl.capitalize())

# Drop taelliboden and plainemorte if they are in the dataset.
for _excluded in ('taelliboden', 'plainemorte'):
    if _excluded in data_glamos['GLACIER'].unique():
        data_glamos = data_glamos[data_glamos['GLACIER'] != _excluded]

# Summary of what is left.
n_annual_samples = len(data_glamos[data_glamos.PERIOD == 'annual'])
n_winter_samples = len(data_glamos[data_glamos.PERIOD == 'winter'])
print('-------------------')
print('Number of glaciers:', len(data_glamos['GLACIER'].unique()))
print('Number of winter and annual samples:', len(data_glamos))
print('Number of annual samples:', n_annual_samples)
print('Number of winter samples:', n_winter_samples)


## Input data:
### Input dataset:

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# All input/output locations are built relative to cfg.dataPath.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# RUN is forwarded as run_flag; presumably False loads the cached
# `output_file` instead of re-processing — confirm in process_or_load_data.
RUN = False
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv')
data_monthly = dataloader_gl.data

# Add a glacier-wide id: a hash of "<GLACIER>_<YEAR>", stored as a string.
# The row-wise apply is evaluated once per monthly row.
data_monthly['GLWD_ID'] = data_monthly.apply(
    lambda x: mbm.data_processing.utils.get_hash(f"{x.GLACIER}_{x.YEAR}"),
    axis=1)
data_monthly['GLWD_ID'] = data_monthly['GLWD_ID'].astype(str)

# Rebuild the MBM dataloader around the frame that now carries GLWD_ID.
# NOTE: this rebinds `dataloader_gl`; later cells use this second instance.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

data_annual = dataloader_gl.data[dataloader_gl.data.PERIOD == 'annual']

# print mean and std of N_MONTHS
print('Mean number of months:', data_annual.N_MONTHS.mean())
print('Std number of months:', data_annual.N_MONTHS.std())

# same for winter
data_winter = dataloader_gl.data[dataloader_gl.data.PERIOD == 'winter']
print('Mean number of months (winter):', data_winter.N_MONTHS.mean())
print('Std number of months (winter):', data_winter.N_MONTHS.std())

## Train/test split (blocked by glacier):

In [None]:
# Compare the configured hold-out glaciers with what the dataset contains.
monthly_all = dataloader_gl.data
existing_glaciers = set(monthly_all.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]
if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Every glacier that is not held out for testing is used for training.
train_glaciers = [g for g in existing_glaciers if g not in TEST_GLACIERS]

data_test = monthly_all[monthly_all.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = monthly_all[monthly_all.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

if len(data_train) > 0:
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
else:
    print("Warning: No training data available!")


def _rows_for(df, period):
    """Return the rows of `df` whose PERIOD equals `period`."""
    return df[df.PERIOD == period]


# Monthly-row counts per measurement period, for train / test / everything.
print('-------------\nTrain:')
print('Number of monthly winter and annual samples:', len(data_train))
print('Number of monthly annual samples:',
      len(_rows_for(data_train, 'annual')))
print('Number of monthly winter samples:',
      len(_rows_for(data_train, 'winter')))

data_test_annual = _rows_for(data_test, 'annual')
data_test_winter = _rows_for(data_test, 'winter')

print('Test:')
print('Number of monthly winter and annual samples:', len(data_test))
print('Number of monthly annual samples:', len(data_test_annual))
print('Number of monthly winter samples:', len(data_test_winter))

print('Total:')
print('Number of monthly rows:', len(monthly_all))
print('Number of annual rows:', len(_rows_for(monthly_all, 'annual')))
print('Number of winter rows:', len(_rows_for(monthly_all, 'winter')))

# Same split counted on the original (one row per measurement) stake data.
n_orig_train = len(data_glamos[data_glamos.GLACIER.isin(train_glaciers)])
n_orig_test = len(data_glamos[data_glamos.GLACIER.isin(TEST_GLACIERS)])
print('-------------\nIn annual format:')
print('Number of annual train rows:', n_orig_train)
print('Number of annual test rows:', n_orig_test)

# Split the dataset by glacier: TEST_GLACIERS form the test set, the remaining
# glaciers the training set (which get_CV_splits also divides into CV folds).
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# Guard the ratio the same way the earlier sanity check does: an empty
# training set must produce a warning, not a ZeroDivisionError.
if len(train_set['df_X']) == 0:
    print("Warning: No training data available!")
else:
    test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

# Validation and train split:
# Attach the targets as column 'y' to the feature frames.
# NOTE: this writes into the frames held by train_set / test_set.
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

## LSTM:

In [None]:
# Features that vary month by month (one value per time step).
MONTHLY_COLS = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'pcsr',
    'ELEVATION_DIFFERENCE',
]

# Features that are constant over a sequence.
STATIC_COLS = [
    'aspect_sgi',
    'slope_sgi',
    'hugonnet_dhdt',
    'consensus_ice_thickness',
    'millan_v',
]

# Months in the order used for the sequences (October first).
HYDRO_MONTHS = 'oct nov dec jan feb mar apr may jun jul aug sep'.split()

# Month abbreviation -> position within HYDRO_MONTHS.
HYDRO_POS = dict(zip(HYDRO_MONTHS, range(len(HYDRO_MONTHS))))

feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
# Work on copies whose PERIOD labels are trimmed and lower-cased, so that
# 'annual' / 'winter' match regardless of spacing or capitalisation.
df_train = data_train.assign(
    PERIOD=data_train['PERIOD'].str.strip().str.lower())

df_test = data_test.assign(
    PERIOD=data_test['PERIOD'].str.strip().str.lower())


In [None]:
# --- build train dataset from dataframe ---
# Sequences are built from the monthly rows using the monthly/static feature
# lists and the month -> position map defined above.
ds_train = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_train, MONTHLY_COLS, STATIC_COLS, HYDRO_POS, expect_target=True)

# --- make train/val loaders and get split indices (scalers fit on TRAIN, applied to whole ds) ---
# NOTE(review): both shuffle_train and use_weighted_sampler are enabled; how
# make_loaders combines the two is not visible here — confirm.
# NOTE(review): the training batch size is fixed to 64 here, so a
# 'batch_size' entry in a hyper-parameter dict has no effect on these loaders.
train_dl, val_dl, train_idx, val_idx = ds_train.make_loaders(
    val_ratio=0.2,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# sanity checks: target statistics, NaN presence, and per-split season counts
print("y mean/std (train):", float(ds_train.y_mean), float(ds_train.y_std))
for name, arr in [("Xm", ds_train.Xm), ("Xs", ds_train.Xs), ("y", ds_train.y)]:
    print(name, "has NaN?", bool(torch.isnan(arr).any().item()))
# iw / ia are summed as winter / annual counts (presumably 0/1 masks).
print("Train counts:", int(ds_train.iw[train_idx].sum()), "winter |",
      int(ds_train.ia[train_idx].sum()), "annual")
print("Val   counts:", int(ds_train.iw[val_idx].sum()), "winter |",
      int(ds_train.ia[val_idx].sum()), "annual")

# --- build test dataset and loader (reuses train scalers) ---
ds_test = mbm.data_processing.MBSequenceDataset.from_dataframe(
    df_test,
    MONTHLY_COLS,
    STATIC_COLS,
    HYDRO_POS,
    expect_target=True  # the test glaciers have measured targets ('y' column)
)
# ds_train is passed alongside ds_test so the loader can reuse its scalers.
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test, ds_train, batch_size=128)

In [None]:
# Weight for the annual loss term: ratio of winter to annual samples
# (the denominator is clamped to 1 to avoid dividing by zero).
N_w = int(ds_train.iw.sum())
N_a = int(ds_train.ia.sum())
w_annual = N_w / max(N_a, 1)
w_annual

### Define & train model:

In [None]:
# load grid search results:
# NOTE(review): the log file name carries a hard-coded date; update it when a
# newer search has been run.
# NOTE(review): this first sort by 'valid_loss' is superseded by the sort on
# 'avg_test_loss' below.
gs_results = pd.read_csv(
    'logs/lstm_param_search_progress_2025-09-06_SEB_all_OGGM.csv').sort_values(
        by='valid_loss', ascending=True)

# Rank configurations by the mean of the annual and winter test RMSE.
# NOTE(review): choosing hyper-parameters on test metrics makes the reported
# test error optimistic — consider ranking on 'valid_loss' instead.
gs_results['avg_test_loss'] = (gs_results['test_rmse_a'] +
                               gs_results['test_rmse_w']) / 2
gs_results = gs_results.sort_values(by='avg_test_loss', ascending=True)
best_params = gs_results.iloc[0].to_dict()

# Show the best row, leaving out the loss columns.
print('Best parameters from grid search:')
for key, value in best_params.items():
    if key not in ['valid_loss', 'train_loss']:
        print(f"{key}: {value}")

gs_results.head(10)

In [None]:
# Number of monthly and static input features, read off the training tensors.
Fm, Fs = ds_train.Xm.shape[-1], ds_train.Xs.shape[-1]

# Hyper-parameters of the model trained below.
params = dict(
    hidden_size=128,
    num_layers=2,
    bidirectional=True,
    dropout=0.3,
    static_hidden=64,
    lr=0.001,
    weight_decay=0.0001,
    batch_size=32,
)

# Build model
model = mbm.models.LSTM_MB(
    Fm=Fm,
    Fs=Fs,
    hidden_size=params['hidden_size'],
    num_layers=params['num_layers'],
    bidirectional=params['bidirectional'],
    dropout=params['dropout'],
    static_hidden=params['static_hidden'],
    two_heads=True,
    head_dropout=0.1,
).to(device)


def _seasonal_loss(outs, b):
    """Seasonally weighted MSE; `w_annual` is read when the loss is called."""
    return mbm.models.LSTM_MB.seasonal_mse_weighted(outs,
                                                    b,
                                                    w_winter=1.0,
                                                    w_annual=w_annual)


TRAIN = True
if TRAIN:
    # Checkpoints are named after the day the training was started.
    current_date = datetime.now().strftime("%Y-%m-%d")
    model_filename = f"models/lstm_model_{current_date}_two_heads.pt"

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=params['lr'],
        weight_decay=params['weight_decay'],
        clip_val=1.0,
        loss_fn=_seasonal_loss,
        # learning-rate scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint location
        save_best_path=model_filename,
    )

### Make predictions:

#### Evaluate on test:

In [None]:
# Load the best weights.
# When the training cell has just run (TRAIN is True) the best state was
# saved to `model_filename`, so evaluate that file; the dated checkpoint is
# only a fallback for sessions that skip training. Loading the hard-coded
# date unconditionally would evaluate a stale or missing file.
archived_ckpt = "models/lstm_model_2025-09-09_two_heads.pt"
ckpt_to_eval = model_filename if TRAIN else archived_ckpt
print('Evaluating checkpoint:', ckpt_to_eval)
best_state = torch.load(ckpt_to_eval, map_location=device)
model.load_state_dict(best_state)

# Evaluate on test for this config
test_metrics, test_df_preds = mbm.models.LSTM_MB.evaluate_with_preds(
    model, device, test_dl, ds_test)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

In [None]:
# Skill scores of the test predictions, split by measurement period.
scores_annual, scores_winter = compute_seasonal_scores(
    test_df_preds, target_col='target', pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure; both axes share the same limits.
_pred_axis_limits = (-8, 6)
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    predVSTruth=predVSTruth,
    plotMeanPred=plotMeanPred,
    ax_xlim=_pred_axis_limits,
    ax_ylim=_pred_axis_limits,
)

In [None]:
# Mean elevation of the annual stake measurements per glacier (highest first).
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Order the test glaciers by increasing mean elevation. Only glaciers that
# have annual measurements are selected: indexing a Series with a list that
# contains a missing label raises KeyError, and the split cell explicitly
# allows for test glaciers that are absent from the dataset.
test_glaciers_present = [g for g in TEST_GLACIERS if g in gl_per_el.index]
skipped_glaciers = [g for g in TEST_GLACIERS if g not in gl_per_el.index]
if skipped_glaciers:
    print(
        f"Warning: no annual measurements for test glaciers: {skipped_glaciers}"
    )
test_gl_per_el = gl_per_el[test_glaciers_present].sort_values().index

# NOTE(review): the 3x3 grid has room for nine glaciers.
fig, axs = plt.subplots(3, 3, figsize=(20, 15), sharex=True)

PlotIndividualGlacierPredVsTruth(test_df_preds,
                                 axs=axs,
                                 color_annual=color_dark_blue,
                                 color_winter=color_pink,
                                 custom_order=test_gl_per_el)


## Grid search:

In [None]:
# param_space = {
#     "hidden_size": [64, 128, 256],
#     "num_layers": [1, 2],
#     "bidirectional": [True, False],
#     "dropout": [0.0, 0.1, 0.3],
#     "static_hidden": [32, 64, 128],
#     "lr": [1e-3, 5e-4, 1e-4],
#     "weight_decay": [1e-4, 1e-5],
#     "batch_size": [32, 64, 128],
# }

# def sample_param_combinations(space, n_samples):
#     all_keys = list(space.keys())
#     all_vals = list(space.values())
#     all_combinations = list(itertools.product(*all_vals))
#     random.shuffle(all_combinations)
#     sampled = all_combinations[:n_samples]
#     return [dict(zip(all_keys, vals)) for vals in sampled]

# sampled_params = sample_param_combinations(param_space, n_samples=100)
# sampled_params


In [None]:
# RUN = False
# if RUN:
#     # make sure logs dir exists
#     os.makedirs("logs", exist_ok=True)
#     log_filename = f'logs/lstm_param_search_progress_{datetime.now().strftime("%Y-%m-%d")}_SEB_all_OGGM.csv'

#     # create log with header
#     with open(log_filename, mode='w', newline='') as log_file:
#         writer = csv.DictWriter(log_file,
#                                 fieldnames=list(sampled_params[0].keys()) +
#                                 ['valid_loss', 'test_rmse_a', 'test_rmse_w'])
#         writer.writeheader()

#     results = []
#     best_overall = {"val": float('inf'), "row": None, "params": None}

#     for i, params in enumerate(sampled_params):
#         seed_all(cfg.seed)
#         print(f"\n--- Running config {i+1}/{len(sampled_params)} ---")
#         print(params)

#         # Build model
#         model = mbm.models.LSTM_MB(
#             Fm=ds_train.Xm.shape[-1],
#             Fs=ds_train.Xs.shape[-1],
#             hidden_size=params['hidden_size'],
#             num_layers=params['num_layers'],
#             bidirectional=params['bidirectional'],
#             dropout=params['dropout'],
#             static_hidden=params['static_hidden'],
#             two_heads=True,
#             head_dropout=params['head_dropout']).to(device)

#         history, best_val, best_state = model.train_loop(
#             device=device,
#             train_dl=train_dl,
#             val_dl=val_dl,
#             epochs=150,
#             lr=params['lr'],
#             weight_decay=params['weight_decay'],
#             clip_val=1.0,
#             # scheduler
#             sched_factor=0.5,
#             sched_patience=6,
#             sched_threshold=0.01,
#             sched_threshold_mode="rel",
#             sched_cooldown=1,
#             sched_min_lr=1e-6,
#             # early stopping
#             es_patience=15,
#             es_min_delta=1e-4,
#             # logging
#             log_every=5,
#             verbose=False,
#             # optional checkpoint
#             save_best_path="models/best_lstm_mb_gs.pt",
#         )

#         # Load the best weights
#         best_state = torch.load("models/best_lstm_mb_gs.pt",
#                                 map_location=device)
#         model.load_state_dict(best_state)

#         # Evaluate on test for this config
#         test_metrics, test_df_preds = mbm.models.LSTM_MB.evaluate_with_preds(
#             model, device, test_dl, ds_test)
#         test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
#             'RMSE_winter']

#         # Log a row for this config
#         row = {
#             **params, 'valid_loss': float(best_val),
#             'test_rmse_a': float(test_rmse_a),
#             'test_rmse_w': float(test_rmse_w)
#         }
#         with open(log_filename, mode='a', newline='') as log_file:
#             writer = csv.DictWriter(log_file, fieldnames=list(row.keys()))
#             writer.writerow(row)

#         results.append(row)

#         # Track overall best by validation loss (or swap to test RMSE if you prefer)
#         if best_val < best_overall['val']:
#             best_overall = {"val": best_val, "row": row, "params": params}

#     print("\n=== Best config by validation loss ===")
#     print(best_overall['params'])
#     print(best_overall['row'])

In [None]:
import itertools
from functools import partial


def make_capped_grid(Fm, Fs, w_annual=3.33):
    """
    Build the fixed hyper-parameter grid used for the LSTM search.

    Returns exactly 150 configs:
      - 144 with two_heads=True (broad coverage)
      - 6 with two_heads=False (light baseline)

    Design:
      Trunks (12): hs in {96,128,192} x nl in {1,2} x bi in {True, False}
        - if nl==1 -> dropout=0.0
        - if nl==2 -> dropout=0.1
      Static options (2): identity; 2-layer [128,64] with dropout 0.1
      Head dropout (2): {0.0, 0.2}
      Train combos (3): fixed (lr, wd) pairs, each tied to one loss name
      Single-head slice: 3 representative trunks, static identity only,
                         head_dropout=0.0, 2 train combos.

    Parameters
    ----------
    Fm : int
        Number of monthly features; copied into every config.
    Fs : int
        Number of static features; copied into every config.
    w_annual : float, optional
        Annual weight stored in the ``loss_spec`` of every "weighted" config
        (the winter weight is fixed to 1.0). Defaults to 3.33, the value that
        used to be hard-coded, so existing calls are unaffected.

    Returns
    -------
    list[dict]
        One dict per configuration. Two-head configs come first, ordered by
        trunk, then static option, then head dropout, then train combo; the
        six single-head configs follow.
    """

    # ---- helper to package a config ----
    def pack(trunk, static, two_heads, head_dropout, train_combo, loss_name):
        hs, nl, bi, do = trunk
        sl, sh, sd = static  # layers, hidden, static_dropout
        lr, wd = train_combo
        # "neutral" carries no loss spec; anything else is the weighted loss.
        if loss_name == "neutral":
            loss_spec = None
        else:
            loss_spec = ("weighted", dict(w_winter=1.0, w_annual=w_annual))
        return dict(
            Fm=Fm,
            Fs=Fs,
            hidden_size=hs,
            num_layers=nl,
            bidirectional=bi,
            dropout=do,
            static_layers=sl,
            static_hidden=sh,
            static_dropout=sd,
            two_heads=two_heads,
            head_dropout=head_dropout,
            lr=lr,
            weight_decay=wd,
            clip_val=1.0,  # fixed to keep the cap
            sched_factor=0.5,
            sched_patience=6,  # fixed to keep the cap
            loss_name=loss_name,
            loss_spec=loss_spec)

    # ---- trunk set (12 total) ----
    # hs x nl x bi; dropout is tied to the number of layers (pruned rule).
    trunks = [(hs, nl, bi, 0.0 if nl == 1 else 0.1)
              for hs, nl, bi in itertools.product([96, 128, 192], [1, 2],
                                                  [True, False])]
    assert len(trunks) == 12

    # ---- static options (2 total) ----
    static_id = (0, None, None)  # identity
    static_2L = (2, [128, 64], 0.1)  # two-layer MLP with modest dropout
    statics = [static_id, static_2L]

    # ---- head dropout (2) ----
    head_dropouts = [0.0, 0.2]

    # ---- training combos & loss variants ----
    # Each (lr, wd) pair is tied to exactly one loss name; the pairs are not
    # crossed with the loss names.
    train_combos_th = [
        (3e-4, 0.0),  # lr, wd
        (1e-3, 1e-4),
        (3e-4, 1e-4),
    ]
    loss_names_th = ["weighted", "neutral", "weighted"]  # align length=3
    two_head_training = list(zip(train_combos_th, loss_names_th))

    # Two combos for the single-head slice:
    train_combos_sh = [
        (3e-4, 0.0),
        (1e-3, 1e-4),
    ]
    loss_names_sh = ["weighted", "neutral"]
    single_head_training = list(zip(train_combos_sh, loss_names_sh))

    # ---- 144 two-head configs ----
    # product() iterates with the right-most factor fastest, i.e. in the same
    # order as nested loops over trunk > static > head dropout > train combo.
    grid = [
        pack(trunk, static, True, hd, combo, loss_name)
        for trunk, static, hd, (combo, loss_name) in itertools.product(
            trunks, statics, head_dropouts, two_head_training)
    ]

    # ---- 6 single-head configs (light baseline) ----
    # Choose 3 representative trunks (covering nl/bidirectional variety)
    rep_trunks = [
        (128, 1, True, 0.0),
        (128, 2, True, 0.1),
        (192, 2, False, 0.1),
    ]
    grid.extend(
        pack(trunk, static_id, False, 0.0, combo, loss_name)
        for trunk, (combo, loss_name) in itertools.product(
            rep_trunks, single_head_training))

    assert len(grid) == 150
    return grid


def build_model_from_params(params, mbm, device):
    """Instantiate ``mbm.models.LSTM_MB`` from a config dict and move it to `device`.

    Only the architecture entries of `params` are forwarded to the model
    constructor; every one of them must be present (a missing key raises
    KeyError).
    """
    architecture_keys = (
        'Fm',
        'Fs',
        'hidden_size',
        'num_layers',
        'bidirectional',
        'dropout',
        'static_hidden',
        'static_layers',
        'static_dropout',
        'two_heads',
        'head_dropout',
    )
    model_kwargs = {key: params[key] for key in architecture_keys}
    return mbm.models.LSTM_MB(**model_kwargs).to(device)


def resolve_loss_fn(params, mbm):
    """Pick the training loss described by ``params['loss_spec']``.

    A spec of the form ``("weighted", kwargs)`` yields
    ``LSTM_MB.seasonal_mse_weighted`` with `kwargs` pre-bound; a spec of
    ``None`` or of any other kind yields ``LSTM_MB.custom_loss``.
    """
    spec = params['loss_spec']
    if spec is not None:
        kind, loss_kwargs = spec
        if kind == "weighted":
            return partial(mbm.models.LSTM_MB.seasonal_mse_weighted,
                           **loss_kwargs)
    return mbm.models.LSTM_MB.custom_loss

In [None]:
# Build the capped grid for the feature dimensions of the training dataset.
sampled_params = make_capped_grid(Fm=ds_train.Xm.shape[-1],
                                  Fs=ds_train.Xs.shape[-1])
print("Total configs:", len(sampled_params))  # should print 150
sampled_params[:20]

In [None]:
# Grid search over `sampled_params`: train each config, evaluate it on the
# test loader and append one row per config to a dated CSV log.
#
# The loop uses its own names (`gs_params`, `gs_model`, `gs_test_*`) so that
# it does not overwrite the notebook-level `params`, `model` and
# `test_df_preds`; later cells load a checkpoint into `model`, which only
# works if `model` still has the architecture that checkpoint was saved with.
RUN = True
if RUN:
    os.makedirs("logs", exist_ok=True)
    os.makedirs("models", exist_ok=True)

    log_filename = f'logs/lstm_param_search_progress_{datetime.now().strftime("%Y-%m-%d")}_SEB_all_OGGM.csv'
    gs_ckpt_path = "models/best_lstm_mb_gs.pt"

    # One field list shared by the header and every row keeps columns aligned.
    log_fields = list(sampled_params[0].keys()) + [
        'valid_loss', 'test_rmse_a', 'test_rmse_w'
    ]

    # create log with header
    with open(log_filename, mode='w', newline='') as log_file:
        writer = csv.DictWriter(log_file, fieldnames=log_fields)
        writer.writeheader()

    results = []
    best_overall = {"val": float('inf'), "row": None, "params": None}

    for i, gs_params in enumerate(sampled_params):
        # Re-seed so every config starts from the same random state.
        seed_all(cfg.seed)
        print(f"\n--- Running config {i+1}/{len(sampled_params)} ---")
        print(gs_params)

        # Build model
        gs_model = build_model_from_params(gs_params, mbm, device)

        # Choose loss
        loss_fn = resolve_loss_fn(gs_params, mbm)

        # Train
        history, best_val, best_state = gs_model.train_loop(
            device=device,
            train_dl=train_dl,
            val_dl=val_dl,
            epochs=150,
            lr=gs_params['lr'],
            weight_decay=gs_params['weight_decay'],
            clip_val=gs_params['clip_val'],
            # scheduler
            sched_factor=gs_params['sched_factor'],
            sched_patience=gs_params['sched_patience'],
            sched_threshold=0.01,
            sched_threshold_mode="rel",
            sched_cooldown=1,
            sched_min_lr=1e-6,
            # early stopping
            es_patience=15,
            es_min_delta=1e-4,
            # logging
            log_every=5,
            verbose=False,
            # checkpoint (the same file is overwritten by every config)
            save_best_path=gs_ckpt_path,
            loss_fn=loss_fn,
        )

        # Load the best weights
        # NOTE(review): if train_loop ever skips saving, this would load the
        # previous config's weights — confirm it always writes the file.
        best_state = torch.load(gs_ckpt_path, map_location=device)
        gs_model.load_state_dict(best_state)

        # Evaluate on test (same class-level call form as the evaluation cell)
        gs_test_metrics, gs_test_df_preds = (
            mbm.models.LSTM_MB.evaluate_with_preds(gs_model, device, test_dl,
                                                   ds_test))
        test_rmse_a = gs_test_metrics['RMSE_annual']
        test_rmse_w = gs_test_metrics['RMSE_winter']

        # Log row
        row = {
            **gs_params, 'valid_loss': float(best_val),
            'test_rmse_a': float(test_rmse_a),
            'test_rmse_w': float(test_rmse_w)
        }
        with open(log_filename, mode='a', newline='') as log_file:
            writer = csv.DictWriter(log_file, fieldnames=log_fields)
            writer.writerow(row)

        results.append(row)

        # Track best by validation loss
        if best_val < best_overall['val']:
            best_overall = {"val": best_val, "row": row, "params": gs_params}

    print("\n=== Best config by validation loss ===")
    print(best_overall['params'])
    print(best_overall['row'])


## Extrapolate in space:

In [None]:
# Geodetic mass-balance table used to decide which glaciers/years to predict.
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier
# Unique 'Astart' / 'Aend' values per glacier, as {glacier_name: [values]}.
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers present in the stake dataset.
# NOTE(review): the message says "with pcsr" but the list is simply the
# unique glacier names in data_glamos — confirm the label.
glacier_list = list(data_glamos.GLACIER.unique())
print('Number of glaciers with pcsr:', len(glacier_list))

geodetic_glaciers = periods_per_glacier.keys()
print('Number of glaciers with geodetic MB:', len(geodetic_glaciers))

# Intersection of both
common_glaciers = list(set(geodetic_glaciers) & set(glacier_list))
print('Number of common glaciers:', len(common_glaciers))

# Sort glaciers by area
gl_area = get_gl_area(cfg)
# 'clariden' has no entry of its own; it reuses the value of 'claridenL'.
gl_area['clariden'] = gl_area['claridenL']


# Order glacier names from smallest to largest area.
def sort_by_area(glacier_list, gl_area):
    """Return `glacier_list` sorted by ascending area.

    Areas are looked up with ``gl_area.get``; a glacier without an entry
    counts as area 0. The sort is stable, so ties keep their input order.
    """

    def _area(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=_area)


# Glaciers with both stake data and geodetic MB, smallest area first.
# NOTE: this rebinds `glacier_list`, which previously held all stake glaciers.
glacier_list = sort_by_area(common_glaciers, gl_area)
glacier_list

In [None]:
# Predict annual mass balance on every glacier grid cell, year by year, and
# save one prediction grid per glacier and year under `path_save_glw`.
all_columns = feature_columns + cfg.fieldsNotFeatures

# Required by the dataset builder regardless of your feature list
REQUIRED = ['GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS']

# Paths
path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
                             'MBM/testing_LSTM/LSTM_two_heads')
os.makedirs(path_save_glw, exist_ok=True)
path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS', 'topo', 'GLAMOS_DEM',
                             'xr_masked_grids')

# Load model once
# NOTE(review): `model` must have the architecture this checkpoint was saved
# with; rebuild it first if another architecture was assigned to `model`
# since training. The date also differs from the checkpoint loaded in the
# evaluation cell — confirm which file is intended.
ckpt_path = "models/lstm_model_2025-09-08_two_heads.pt"
best_state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(best_state)
model.eval()

RUN = True
if RUN:
    # Start from an empty output folder so stale grids are not mixed in.
    emptyfolder(path_save_glw)

    # Loop invariants: columns worth keeping and the months a full
    # hydrological year must contain.
    wanted_columns = set(all_columns) | set(REQUIRED)
    need = {
        'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul',
        'aug', 'sep'
    }

    for glacier_name in glacier_list:
        glacier_path = os.path.join(cfg.dataPath + path_glacier_grid_glamos,
                                    glacier_name)
        if not os.path.exists(glacier_path):
            print(f"Folder not found for {glacier_name}, skipping...")
            continue

        glacier_files = sorted(
            [f for f in os.listdir(glacier_path) if glacier_name in f])

        # Only years covered by the glacier's geodetic periods are predicted.
        geodetic_range = range(np.min(periods_per_glacier[glacier_name]),
                               np.max(periods_per_glacier[glacier_name]) + 1)

        # File names follow "<glacier>_grid_<year>.parquet"; the year is the
        # third underscore-separated token.
        # NOTE(review): this breaks for glacier names containing '_'.
        years = [int(f.split('_')[2].split('.')[0]) for f in glacier_files]
        years = [y for y in years if y in geodetic_range]

        print(
            f"Processing {glacier_name} ({len(years)} files): {geodetic_range}"
        )

        for year in tqdm(years, desc=f"Processing {glacier_name}",
                         leave=False):
            file_name = f"{glacier_name}_grid_{year}.parquet"
            df_grid_monthly = pd.read_parquet(
                os.path.join(glacier_path, file_name)).copy()

            df_grid_monthly.drop_duplicates(inplace=True)

            # Keep required + feature columns; DON'T drop PERIOD/MONTHS/YEAR/ID/GLACIER
            keep = [c for c in wanted_columns if c in df_grid_monthly.columns]
            df_grid_monthly = df_grid_monthly[keep]

            # Ensure PERIOD is set to 'annual' BEFORE sequence building
            if 'PERIOD' not in df_grid_monthly.columns:
                df_grid_monthly['PERIOD'] = 'annual'
            else:
                df_grid_monthly['PERIOD'] = df_grid_monthly['PERIOD'].fillna(
                    'annual')
            df_grid_monthly['PERIOD'] = df_grid_monthly['PERIOD'].str.strip(
            ).str.lower()

            # (Optional) minimal NaN clean-up: only drop if ID or MONTHS missing
            df_grid_monthly = df_grid_monthly.dropna(subset=['ID', 'MONTHS'])

            # (Optional) hydrological coverage check — warn, but keep going.
            have = set(df_grid_monthly['MONTHS'].str.lower().unique())
            if not need.issubset(have):
                missing = sorted(list(need - have))
                print(
                    f"WARNING [{glacier_name} {year}]: missing hydro months: {missing}"
                )

            # --- Build ds_gl WITHOUT targets ---
            ds_gl = mbm.data_processing.MBSequenceDataset.from_dataframe(
                df_grid_monthly,
                MONTHLY_COLS,
                STATIC_COLS,
                HYDRO_POS,
                expect_target=False,
                show_progress=False)
            gl_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
                ds_gl, ds_train,
                batch_size=128)  # copies train scalers & transforms ds_gl

            # Predict (no metrics)
            df_preds = model.predict_with_keys(device, gl_dl, ds_gl)

            # Join preds back to unique cell IDs for saving
            data = df_preds[['ID', 'pred']].set_index('ID')
            grouped_ids = df_grid_monthly.groupby('ID')[[
                'YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID'
            ]].first()
            grouped_ids = grouped_ids.merge(data,
                                            left_index=True,
                                            right_index=True,
                                            how='left')

            months_per_id = df_grid_monthly.groupby('ID')['MONTHS'].unique()
            grouped_ids = grouped_ids.merge(months_per_id,
                                            left_index=True,
                                            right_index=True)

            grouped_ids.reset_index(inplace=True)
            grouped_ids.sort_values(by='ID', inplace=True)
            grouped_ids['PERIOD'] = 'annual'

            pred_y_annual = grouped_ids.drop(columns=['YEAR'], errors='ignore')

            # Save
            # `path_xr_grids` already starts with cfg.dataPath; joining
            # cfg.dataPath in front of it again would duplicate the prefix
            # whenever cfg.dataPath is a relative path.
            path_glacier_dem = os.path.join(path_xr_grids,
                                            f"{glacier_name}_{year}.zarr")
            # Close the DEM dataset once the prediction has been written so
            # file handles do not pile up over the glacier/year loop.
            with xr.open_dataset(path_glacier_dem) as ds:
                geoData = mbm.geodata.GeoData(df_grid_monthly)
                geoData._save_prediction(ds, pred_y_annual, glacier_name,
                                         year, path_save_glw, "annual")

# quick viz
glacier_name = 'aletsch'
year = 2008
xr.open_dataset(os.path.join(path_save_glw, f'{glacier_name}/{glacier_name}_{year}_annual.zarr'))\
  .pred_masked.plot(cmap='RdBu')

In [None]:
# Glaciers for which prediction grids were written (one sub-folder each).
glaciers_in_glamos = os.listdir(path_save_glw)

# NOTE(review): the geodetic table, periods and areas below repeat the setup
# of the "Extrapolate in space" cell; they are recomputed here so this cell
# can run on its own.
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers with geodetic MB data:
# Sort glaciers by area
gl_area = get_gl_area(cfg)
# 'clariden' has no entry of its own; it reuses the value of 'claridenL'.
gl_area['clariden'] = gl_area['claridenL']


# Order glacier names from smallest to largest area.
def sort_by_area(glacier_list, gl_area):
    """Return `glacier_list` sorted by ascending area.

    Areas are looked up with ``gl_area.get``; a glacier without an entry
    counts as area 0. The sort is stable, so ties keep their input order.
    """

    def _area(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=_area)


# Glaciers that have both geodetic periods and saved prediction grids,
# ordered from smallest to largest area.
glacier_list = [
    f for f in list(periods_per_glacier.keys()) if f in glaciers_in_glamos
]
glacier_list = sort_by_area(glacier_list, gl_area)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

# Compare the modelled mass balance with the geodetic one, per glacier/period.
df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=path_save_glw,  # or another path if needed
    cfg=cfg)

# Drop rows where any required columns are NaN
df_all_nn = df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])
df_all_nn = df_all_nn.sort_values(by='Area')
df_all_nn['GLACIER'] = df_all_nn['GLACIER'].apply(lambda x: x.capitalize())

# Compute RMSE and Pearson correlation
# The `squared=False` argument of mean_squared_error is deprecated and
# removed in recent scikit-learn releases; taking the square root of the MSE
# gives the same RMSE on every version.
rmse_nn = np.sqrt(
    mean_squared_error(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"]))
corr_nn = np.corrcoef(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"])[0, 1]

# NOTE(review): with these edges the fourth bin covers 10-100 but is labelled
# '>10' — confirm the intended labels.
plot_mbm_vs_geodetic_by_area_bin(df_all_nn,
                                 bins=[0, 1, 5, 10, 100, np.inf],
                                 labels=['<1', '1-5', '5–10', '>10', '>100'],
                                 max_bins=4)
