## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from collections import Counter
import ast
from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader

from typing import Dict, List, Tuple, Optional

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: enable the autoreload extension in mode 2.
%load_ext autoreload
%autoreload 2

# Configuration object; the cells below read cfg.dataPath, cfg.seed and
# cfg.metaData from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Plot styles: apply the project style sheet and pick the plot colours.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
color_pink = '#c51b7d'

# RGI ids: read the glacier id table, strip whitespace from the column
# names, then sort and index it by 'short_name'.
rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
rgi_df.rename(columns=lambda x: x.strip(), inplace=True)
rgi_df.sort_values(by='short_name', inplace=True)
rgi_df.set_index('short_name', inplace=True)

# Alternative variable list that additionally contains 'u10' and 'v10'
# (currently disabled):
# vois_climate = [
#     't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'u10', 'v10'
# ]

# Climate variables handed to process_or_load_data as `vois_climate`.
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
]

# Topographical variables handed to process_or_load_data as
# `vois_topographical`.
vois_topographical = [
    "aspect_sgi",
    "slope_sgi",
    "hugonnet_dhdt",
    "consensus_ice_thickness",
    "millan_v",
]

In [None]:
# Seed with the configured value (seed_all is not defined in this notebook;
# it is expected to come from one of the `scripts.*` star imports).
seed_all(cfg.seed)

print("Using seed:", cfg.seed)

# Report whether a CUDA device is visible; free_up_cuda() is only called
# when one is.
if torch.cuda.is_available():
    print("CUDA is available")
    free_up_cuda()

    # Disabled: optional limits on CPU threads during a random search.
    # # Try to limit CPU usage of random search
    # torch.set_num_threads(2)  # or 1
    # os.environ["OMP_NUM_THREADS"] = "1"
    # os.environ["MKL_NUM_THREADS"] = "1"
else:
    print("CUDA is NOT available")


In [None]:
# Device the model and every batch are moved to: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Read GLAMOS stake data:

In [None]:
# Load the stake measurements.
data_glamos = getStakesData(cfg)

# Map every glacier name to a display name: the two Clariden glaciers get a
# hand-written spelling, every other name is simply capitalized.
glacierCap = {}
for gl in data_glamos['GLACIER'].unique():
    if not isinstance(gl, str):
        # Non-string entries are reported and left out of the mapping.
        print(f"Warning: Non-string glacier name encountered: {gl}")
        continue
    lowered = gl.lower()
    if lowered == 'claridenu':
        glacierCap[gl] = 'Clariden_U'
    elif lowered == 'claridenl':
        glacierCap[gl] = 'Clariden_L'
    else:
        glacierCap[gl] = gl.capitalize()

# Remove taelliboden and plainemorte when they are part of the data.
for excluded in ('taelliboden', 'plainemorte'):
    if excluded in data_glamos['GLACIER'].unique():
        data_glamos = data_glamos[data_glamos['GLACIER'] != excluded]

# Summary of what is left after the exclusion.
n_glaciers = len(data_glamos['GLACIER'].unique())
n_annual = len(data_glamos[data_glamos.PERIOD == 'annual'])
n_winter = len(data_glamos[data_glamos.PERIOD == 'winter'])

print('-------------------')
print('Number of glaciers:', n_glaciers)
print('Number of winter and annual samples:', len(data_glamos))
print('Number of annual samples:', n_annual)
print('Number of winter samples:', n_winter)


## Input data:
### Input dataset:

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# Input/output locations, all built by string concatenation on cfg.dataPath.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# NOTE(review): presumably RUN = False loads the previously written
# `output_file` instead of recomputing it -- confirm in process_or_load_data.
RUN = False
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_NN.csv')
data_monthly = dataloader_gl.data

# Add a string id hashed from "<GLACIER>_<YEAR>": rows of the same glacier
# and year are hashed from the same input string.
data_monthly['GLWD_ID'] = data_monthly.apply(
    lambda x: mbm.data_processing.utils.get_hash(f"{x.GLACIER}_{x.YEAR}"),
    axis=1)
data_monthly['GLWD_ID'] = data_monthly['GLWD_ID'].astype(str)

# Wrap the table (now including GLWD_ID) in a new data loader; this rebinds
# `dataloader_gl`, replacing the loader returned by process_or_load_data.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)


In [None]:
# Split the monthly table by measurement period.
monthly_rows = dataloader_gl.data
data_annual = monthly_rows[monthly_rows.PERIOD == 'annual']
data_winter = monthly_rows[monthly_rows.PERIOD == 'winter']

# Mean and standard deviation of N_MONTHS over the annual rows ...
n_months_annual = data_annual.N_MONTHS
print('Mean number of months:', n_months_annual.mean())
print('Std number of months:', n_months_annual.std())

# ... and over the winter rows.
n_months_winter = data_winter.N_MONTHS
print('Mean number of months (winter):', n_months_winter.mean())
print('Std number of months (winter):', n_months_winter.std())

## Blocking on glaciers:

In [None]:
# Check that every glacier listed in TEST_GLACIERS occurs in the dataset;
# missing ones are only reported, not treated as an error.
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]

if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Training glaciers are all glaciers in the dataset that are not test glaciers.
train_glaciers = [i for i in existing_glaciers if i not in TEST_GLACIERS]

data_test = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(
    train_glaciers)]
print('Size of monthly train data:', len(data_train))

if len(data_train) == 0:
    print("Warning: No training data available!")
else:
    # NOTE(review): this is the test/train row ratio, not the test share of
    # all rows.
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))

# Number of annual versus winter measurements:
print('-------------\nTrain:')
print('Number of monthly winter and annual samples:', len(data_train))
print('Number of monthly annual samples:',
      len(data_train[data_train.PERIOD == 'annual']))
print('Number of monthly winter samples:',
      len(data_train[data_train.PERIOD == 'winter']))

# Same for test
data_test_annual = data_test[data_test.PERIOD == 'annual']
data_test_winter = data_test[data_test.PERIOD == 'winter']

print('Test:')
print('Number of monthly winter and annual samples:', len(data_test))
print('Number of monthly annual samples:', len(data_test_annual))
print('Number of monthly winter samples:', len(data_test_winter))

print('Total:')
print('Number of monthly rows:', len(dataloader_gl.data))
print('Number of annual rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'annual']))
print('Number of winter rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'winter']))

# Same counts on the original (non-monthly) stake table.
# NOTE(review): the two counts below are not filtered by PERIOD, so they
# include every row of data_glamos for those glaciers although the labels
# say "annual".
print('-------------\nIn annual format:')
print('Number of annual train rows:',
      len(data_glamos[data_glamos.GLACIER.isin(train_glaciers)]))
print('Number of annual test rows:',
      len(data_glamos[data_glamos.GLACIER.isin(TEST_GLACIERS)]))


In [None]:
# Split on the GLACIER column with TEST_GLACIERS as the held-out values.
# The returned train/test sets are used below as dicts with the keys
# 'df_X', 'y' and 'splits_vals'.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# NOTE(review): test/train row ratio; unlike the previous cell there is no
# guard against an empty training set here.
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

In [None]:
# Validation and train split:
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

# dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

# train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# # Get all indices of the training and valing dataset at once from the iterators. Once called, the iterators are empty.
# train_indices, val_indices = list(train_itr), list(val_itr)

# df_X_train = data_train.iloc[train_indices]
# y_train = df_X_train['POINT_BALANCE'].values

# # Get val set
# df_X_val = data_train.iloc[val_indices]
# y_val = df_X_val['POINT_BALANCE'].values

## LSTM format:

In [None]:
# Per-month feature columns fed to the LSTM.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'pcsr',
]
# Static feature columns.
STATIC_COLS = [
    'ELEVATION_DIFFERENCE',
    'aspect_sgi',
    'slope_sgi',
    'hugonnet_dhdt',
    'consensus_ice_thickness',
    'millan_v',
]

# Month tokens in October-to-September order, and the position of each
# token in that order.
HYDRO_MONTHS = [
    'oct', 'nov', 'dec', 'jan', 'feb', 'mar',
    'apr', 'may', 'jun', 'jul', 'aug', 'sep',
]
HYDRO_POS = dict(zip(HYDRO_MONTHS, range(len(HYDRO_MONTHS))))

In [None]:
def build_sequences(
    df: pd.DataFrame,
    monthly_cols: List[str],
    static_cols: List[str],
    hydro_pos: Dict[str, int],
    *,
    expect_target: bool = True,
    show_progress: bool = True,
    check_unique: bool = True
) -> Dict[str, np.ndarray]:
    """
    Build 12-step LSTM sequences (Oct..Sep hydrological order) from a monthly table.

    Every (GLACIER, YEAR, ID, PERIOD) group of ``df`` becomes one sample. Rows
    sharing the same MONTHS token inside a group are averaged first.

    Parameters
    ----------
    df : DataFrame
        Must contain columns: GLACIER, YEAR, ID, PERIOD, MONTHS, all of
        ``monthly_cols`` and ``static_cols``, and POINT_BALANCE when
        ``expect_target`` is True. PERIOD is stripped and lower-cased on a
        copy; the caller's frame is not modified.
    monthly_cols : list[str]
        Per-month feature columns (Fm).
    static_cols : list[str]
        Static feature columns (Fs). They are read from the first aggregated
        month of each group, i.e. assumed constant within a group.
    hydro_pos : dict[str,int]
        Mapping from month token to 0..11 index (e.g., {'oct':0, ..., 'sep':11}).
    expect_target : bool, default True
        If True, POINT_BALANCE is required and 'y' holds its group mean.
        If False, POINT_BALANCE is not read and 'y' is filled with NaN.
    show_progress : bool, default True
        Show tqdm progress bar.
    check_unique : bool, default True
        Verify that keys (GLACIER,YEAR,ID,PERIOD) are unique and print a
        confirmation message when they are.

    Returns
    -------
    data_dict : dict
        {
          'X_monthly': (B,12,Fm) float32, zeros for months without data,
          'X_static':  (B,Fs)    float32,
          'mask_valid':(B,12)    float32 (1 where the month has data),
          'mask_w':    (B,12)    float32 (first seven positions = 1, i.e.
                                          Oct..Apr for an Oct..Sep hydro_pos),
          'mask_a':    (B,12)    float32 (all 1),
          'y':         (B,)      float32 (NaN if expect_target=False),
          'is_winter': (B,)      bool,
          'is_annual': (B,)      bool,
          'keys':      list[(GLACIER,YEAR,ID,PERIOD)]
        }
        For an empty ``df`` the stacked arrays have shape (0,).

    Raises
    ------
    KeyError
        If a required column is missing.
    ValueError
        If a MONTHS token is not a key of ``hydro_pos`` or, with
        ``check_unique``, if two groups end up with the same key.
    """
    monthly_cols = list(monthly_cols)
    static_cols = list(static_cols)

    # Validate everything that is read below, including the target, so a
    # missing column is reported here and not from inside the group loop.
    req = {'GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS', *monthly_cols, *static_cols}
    if expect_target:
        req.add('POINT_BALANCE')
    missing = req - set(df.columns)
    if missing:
        raise KeyError(f"Missing required columns: {sorted(missing)}")

    # Normalize PERIOD just in case of stray whitespace/case
    df = df.copy()
    df['PERIOD'] = df['PERIOD'].str.strip().str.lower()

    mask_w_template = np.zeros(12, dtype=np.float32)
    mask_w_template[:7] = 1.0  # Oct..Apr
    mask_a_template = np.ones(12, dtype=np.float32)

    # Columns averaged per month; loop-invariant.
    agg_cols = monthly_cols + static_cols + (['POINT_BALANCE'] if expect_target else [])

    X_monthly, X_static = [], []
    mask_valid, mask_w, mask_a = [], [], []
    y, is_winter, is_annual, keys = [], [], [], []

    # Materialised so that tqdm knows the total number of groups.
    groups = list(df.groupby(['GLACIER', 'YEAR', 'ID', 'PERIOD']))
    iterator = tqdm(groups, desc="Building sequences") if show_progress else groups

    for (g, yr, mid, per), sub in iterator:
        # Average duplicates within a month, if any
        subm = (sub.groupby('MONTHS', as_index=False)[agg_cols]
                   .mean(numeric_only=True))

        months = subm['MONTHS'].tolist()
        unexpected = [m for m in months if m not in hydro_pos]
        if unexpected:
            raise ValueError(
                f"Unexpected month token '{unexpected[0]}'. Expected one of {list(hydro_pos.keys())}.")
        pos = [hydro_pos[m] for m in months]

        # Build 12×Fm monthly matrix and valid mask in one assignment
        mat = np.zeros((12, len(monthly_cols)), dtype=np.float32)
        mv = np.zeros(12, dtype=np.float32)
        mat[pos, :] = subm[monthly_cols].to_numpy(np.float32)
        mv[pos] = 1.0

        # Static features: take from first row (should be identical within group)
        s = subm[static_cols].iloc[0].to_numpy(np.float32)

        # Target: mean of the per-month means, NaN when no target is expected
        target = float(subm['POINT_BALANCE'].mean()) if expect_target else float('nan')

        # Append once per group
        X_monthly.append(mat)
        X_static.append(s)
        mask_valid.append(mv)
        mask_w.append(mask_w_template.copy())
        mask_a.append(mask_a_template.copy())
        y.append(target)
        is_winter.append(per == 'winter')
        is_annual.append(per == 'annual')
        keys.append((g, int(yr), int(mid), per))

    def stack(a):
        return np.stack(a, axis=0) if len(a) else np.empty((0,))

    data_dict = dict(
        X_monthly=stack(X_monthly),
        X_static=stack(X_static),
        mask_valid=stack(mask_valid),
        mask_w=stack(mask_w),
        mask_a=stack(mask_a),
        y=np.asarray(y, dtype=np.float32),
        is_winter=np.asarray(is_winter, dtype=bool),
        is_annual=np.asarray(is_annual, dtype=bool),
        keys=keys
    )

    # Uniqueness check: group keys are unique, but the int() casts of YEAR
    # and ID above could map two different groups to the same key.
    if check_unique:
        if len(keys) != len(set(keys)):
            dupes = [k for k, c in Counter(keys).items() if c > 1]
            raise ValueError(f"Found {len(dupes)} duplicate keys, e.g. {dupes[:5]}")
        else:
            print(f"All {len(keys)} keys are unique.")

    return data_dict

In [None]:
# TRAIN (with targets)
df_train = data_train.copy()
df_train['PERIOD'] = df_train['PERIOD'].str.strip().str.lower()
train_dict = build_sequences(
    df_train, MONTHLY_COLS, STATIC_COLS, HYDRO_POS,
    show_progress=True, check_unique=True
)

# TEST (targets are built as well: POINT_BALANCE is aggregated into 'y',
# which the evaluation compares the predictions against)
df_test = data_test.copy()
df_test['PERIOD'] = df_test['PERIOD'].str.strip().str.lower()
test_dict = build_sequences(
    df_test, MONTHLY_COLS, STATIC_COLS, HYDRO_POS,
    show_progress=True, check_unique=True
)

In [None]:
class MBSequenceDataset(Dataset):
    """Tensor view of a sequence dict with standardisation helpers.

    The dict is expected to hold the numpy arrays 'X_monthly', 'X_static',
    'mask_valid', 'mask_w', 'mask_a', 'y', 'is_winter' and 'is_annual', and
    optionally 'keys'. Intended order of use: ``fit_scalers`` (or
    ``set_scalers_from``) and then ``transform_inplace`` exactly once.
    """

    # Names of the six scaler attributes that must be set before scaling.
    _SCALER_ATTRS = ('month_mean', 'month_std', 'static_mean', 'static_std',
                     'y_mean', 'y_std')

    def __init__(self, data_dict):
        # raw (unscaled) numpy -> tensors (float / bool)
        self.Xm = torch.from_numpy(data_dict['X_monthly']).float()   # (B,12,Fm)
        self.Xs = torch.from_numpy(data_dict['X_static']).float()    # (B,Fs)
        self.mv = torch.from_numpy(data_dict['mask_valid']).float()  # (B,12)
        self.mw = torch.from_numpy(data_dict['mask_w']).float()      # (B,12)
        self.ma = torch.from_numpy(data_dict['mask_a']).float()      # (B,12)
        self.y = torch.from_numpy(data_dict['y']).float()            # (B,)
        self.iw = torch.from_numpy(data_dict['is_winter']).bool()    # (B,)
        self.ia = torch.from_numpy(data_dict['is_annual']).bool()    # (B,)
        self.keys = data_dict.get('keys', None)

        # placeholders for scaling (set later)
        self.month_mean = None
        self.month_std = None
        self.static_mean = None
        self.static_std = None
        self.y_mean = None
        self.y_std = None
        # True once transform_inplace() has rewritten Xm, Xs and y.
        self._scaled = False

    def __len__(self):
        return self.Xm.shape[0]

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors (keys are not included)."""
        return {
            "x_m": self.Xm[idx], "x_s": self.Xs[idx],
            "mv": self.mv[idx], "mw": self.mw[idx], "ma": self.ma[idx],
            "y": self.y[idx], "iw": self.iw[idx], "ia": self.ia[idx],
        }

    def _has_scalers(self):
        """Return True when all six scaler attributes are set."""
        return all(getattr(self, name) is not None for name in self._SCALER_ATTRS)

    # ---- scaling helpers ----
    def fit_scalers(self, idx_train):
        """Compute mean/std of monthly features, static features and target.

        Statistics use only the samples selected by ``idx_train``. Monthly
        statistics additionally ignore months whose ``mask_valid`` is 0.
        Variances are floored at 1e-8 before the square root.

        Raises
        ------
        RuntimeError
            If the data was already scaled by ``transform_inplace``.
        ValueError
            If ``idx_train`` selects no sample.
        """
        if self._scaled:
            raise RuntimeError(
                "fit_scalers() must be called before transform_inplace(): "
                "the stored data is already scaled.")

        # monthly scaler: use only valid months in TRAIN
        Xm = self.Xm[idx_train].numpy()    # (N,12,Fm)
        Mv = self.mv[idx_train].numpy()    # (N,12)
        if Xm.shape[0] == 0:
            # mean/var of an empty selection would be NaN and poison the data
            raise ValueError("fit_scalers() needs at least one training sample.")
        mask3 = Mv[..., None]                                 # (N,12,1)
        den = np.maximum(mask3.sum(axis=(0, 1)), 1e-8)        # (1,)
        month_mean = (Xm * mask3).sum(axis=(0, 1)) / den      # (Fm,)
        var = (((Xm - month_mean) * mask3) ** 2).sum(axis=(0, 1)) / den
        month_std = np.sqrt(np.maximum(var, 1e-8))

        # static scaler: simple mean/std across TRAIN rows
        Xs = self.Xs[idx_train].numpy()
        static_mean = Xs.mean(axis=0)
        static_std = np.sqrt(np.maximum(Xs.var(axis=0), 1e-8))

        # target scaler
        y = self.y[idx_train].numpy()
        y_mean = float(np.mean(y))
        y_std = float(np.sqrt(max(np.var(y), 1e-8)))

        # store
        self.month_mean = torch.from_numpy(month_mean).float()
        self.month_std = torch.from_numpy(month_std).float()
        self.static_mean = torch.from_numpy(static_mean).float()
        self.static_std = torch.from_numpy(static_std).float()
        self.y_mean = torch.tensor(y_mean, dtype=torch.float32)
        self.y_std = torch.tensor(y_std, dtype=torch.float32)

    def transform_inplace(self):
        """Replace Xm, Xs and y by their standardised versions ``(x - mean) / std``.

        All months of Xm are scaled, including those with ``mask_valid`` 0.

        Raises
        ------
        RuntimeError
            If the scalers are not set yet, or if the data was already
            scaled (a second call would scale the data twice).
        """
        if not self._has_scalers():
            raise RuntimeError(
                "Scalers are not set: call fit_scalers() or set_scalers_from() first.")
        if self._scaled:
            raise RuntimeError("transform_inplace() was already applied to this dataset.")

        # apply (x - mean)/std; monthly uses broadcasting over (B,12,Fm)
        self.Xm = (self.Xm - self.month_mean) / self.month_std
        self.Xs = (self.Xs - self.static_mean) / self.static_std
        # scale targets
        self.y = (self.y - self.y_mean) / self.y_std
        self._scaled = True

    def set_scalers_from(self, other_ds: "MBSequenceDataset"):
        """Copy fitted scalers from another dataset (usually the TRAIN ds).

        Raises
        ------
        RuntimeError
            If ``other_ds`` has no scalers yet.
        """
        if not other_ds._has_scalers():
            raise RuntimeError("The source dataset has no fitted scalers to copy.")
        self.month_mean = other_ds.month_mean.clone()
        self.month_std = other_ds.month_std.clone()
        self.static_mean = other_ds.static_mean.clone()
        self.static_std = other_ds.static_std.clone()
        self.y_mean = other_ds.y_mean.clone()
        self.y_std = other_ds.y_std.clone()

In [None]:
# Build the dataset from the prepared training dict (still unscaled here).
ds_train = MBSequenceDataset(train_dict)

# simple random split (for a quick test)
def split_indices(n, val_ratio=0.2):
    """Shuffle ``0..n-1`` with NumPy's global RNG and cut it in two.

    Returns ``(train_idx, val_idx)``. The training part holds
    ``int(n * (1 - val_ratio))`` indices, but never fewer than one slot, and
    the validation part holds the remainder.
    """
    shuffled = np.arange(n)
    np.random.shuffle(shuffled)
    n_train = int(n * (1 - val_ratio))
    if n_train < 1:
        n_train = 1
    return shuffled[:n_train], shuffled[n_train:]

# Random 80/20 split at sample level (uses NumPy's global RNG state).
train_idx, val_idx = split_indices(len(ds_train), val_ratio=0.2)

# fit scalers on TRAIN only, then transform whole dataset
ds_train.fit_scalers(train_idx)
ds_train.transform_inplace()

# Both subsets are views on the same, already scaled dataset.
train_ds = Subset(ds_train, train_idx)
val_ds   = Subset(ds_train, val_idx)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
val_dl   = DataLoader(val_ds, batch_size=128, shuffle=False)

# Sanity checks
# 1) Target scale
print("y mean/std (train):", float(ds_train.y_mean), float(ds_train.y_std))

# 2) Any NaNs?
for name, arr in [("Xm", ds_train.Xm), ("Xs", ds_train.Xs), ("y", ds_train.y)]:
    has_nan = torch.isnan(arr).any().item()
    print(name, "has NaN?", bool(has_nan))

# 3) Do winter/annual rows exist in both splits?
print("Train counts:", int(ds_train.iw[train_idx].sum()), "winter |", int(ds_train.ia[train_idx].sum()), "annual")
print("Val   counts:", int(ds_train.iw[val_idx].sum()),   "winter |", int(ds_train.ia[val_idx].sum()),   "annual")

In [None]:
class LSTM_MB(nn.Module):
    """LSTM over the monthly sequence combined with an MLP on static features.

    A linear head maps every time step to one scalar. ``forward`` returns
    these per-step values (masked by ``mv``) together with their sums
    weighted by ``mw`` and by ``ma``.
    """

    def __init__(self, Fm, Fs, hidden=128, layers=1, bidir=True, dropout=0.1):
        super().__init__()
        # The LSTM only receives a dropout rate when it has more than one layer.
        lstm_dropout = dropout if layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=Fm,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            bidirectional=bidir,
            dropout=lstm_dropout,
        )
        self.static_mlp = nn.Sequential(
            nn.Linear(Fs, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        # A bidirectional LSTM emits twice the hidden size per step.
        seq_width = 2 * hidden if bidir else hidden
        self.head = nn.Linear(seq_width + 64, 1)  # per-month scalar

    def forward(self, x_m, x_s, mv, mw, ma):
        seq_feats, _ = self.lstm(x_m)                  # (B,12,H or 2H)
        static_feats = self.static_mlp(x_s)            # (B,64)
        # Repeat the static embedding along the time axis and append it.
        n_steps = seq_feats.size(1)
        tiled = static_feats[:, None, :].expand(-1, n_steps, -1)
        joint = torch.cat((seq_feats, tiled), dim=-1)  # (B,12,H+64)
        monthly = self.head(joint).squeeze(-1)         # (B,12)
        # Invalid months are zeroed before any summation.
        monthly = monthly * mv
        winter_sum = (monthly * mw).sum(dim=1)         # (B,)
        annual_sum = (monthly * ma).sum(dim=1)         # (B,)
        return monthly, winter_sum, annual_sum

In [None]:
# Feature counts are read from the last axis of the dataset tensors.
Fm = ds_train.Xm.shape[-1]
Fs = ds_train.Xs.shape[-1]
# Build the network and move it to the selected device.
model = LSTM_MB(Fm=Fm, Fs=Fs, hidden=128, layers=1, bidir=True, dropout=0.1).to(device)

In [None]:
def seasonal_mse(outputs, batch):
    """Mean of the winter MSE and the annual MSE of one batch.

    Parameters
    ----------
    outputs : tuple
        ``(y_month, y_w_pred, y_a_pred)`` as returned by the model; only the
        two seasonal sums are used.
    batch : dict
        Needs 'y' (targets, in the same scaled space as the predictions),
        'iw' and 'ia' (boolean masks selecting winter / annual samples).

    Returns
    -------
    Scalar tensor. Winter samples are compared with ``y_w_pred``, annual
    samples with ``y_a_pred``; each group that is present contributes its
    own MSE and the result is the average over the present groups. When
    neither group is present the result is a zero that is still connected
    to the predictions, so calling ``backward()`` on it does not fail.
    """
    _, y_w_pred, y_a_pred = outputs
    # predictions are in scaled space because inputs & target are scaled.
    y_true = batch['y']  # already scaled

    iw, ia = batch['iw'], batch['ia']
    terms = []

    if iw.any():
        terms.append(torch.mean((y_w_pred[iw] - y_true[iw]) ** 2))
    if ia.any():
        terms.append(torch.mean((y_a_pred[ia] - y_true[ia]) ** 2))

    if not terms:
        # A fresh torch.tensor(0.0) has no grad_fn and would make
        # loss.backward() raise during training; multiply the predictions by
        # zero instead to keep the graph.
        return (y_w_pred.sum() + y_a_pred.sum()) * 0.0
    return sum(terms) / len(terms)

In [None]:
# Optimiser over all model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Maximum gradient norm passed to clip_grad_norm_ in run_epoch.
clip_val = 1.0
# Number of passes of the training loop.
epochs = 50

def to_device(batch):
    """Return a new dict with every tensor value moved to the global ``device``.

    Values that are not tensors are passed through unchanged.
    """
    moved = {}
    for name, value in batch.items():
        moved[name] = value.to(device) if torch.is_tensor(value) else value
    return moved

def run_epoch(dl, train=True):
    """Run one pass over ``dl`` and return the sample-weighted mean loss.

    With ``train=True`` the global ``model`` is put in training mode and
    updated with the global ``optimizer``; with ``train=False`` it is put in
    eval mode and gradients are disabled. Returns 0.0 for an empty loader.
    """
    model.train(train)
    # tot: sum of batch losses weighted by batch size; n: samples seen
    tot, n = 0.0, 0
    with torch.set_grad_enabled(train):
        for batch in dl:
            batch = to_device(batch)
            y_m, y_w, y_a = model(batch['x_m'], batch['x_s'], batch['mv'], batch['mw'], batch['ma'])
            loss = seasonal_mse((y_m, y_w, y_a), batch)
            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                # clip the global gradient norm to clip_val before the update
                nn.utils.clip_grad_norm_(model.parameters(), clip_val)
                optimizer.step()
            bs = batch['x_m'].shape[0]
            tot += loss.item() * bs; n += bs
    # max(n, 1) avoids a division by zero when dl yields no batch
    return tot / max(n,1)

# Train for `epochs` epochs and remember the weights with the lowest
# validation loss.
best_state, best_val = None, float('inf')
for ep in range(1, epochs+1):
    tr = run_epoch(train_dl, True)
    va = run_epoch(val_dl, False)
    if va < best_val:
        best_val = va
        # CPU copy of the weights, detached from the live model
        best_state = {k: v.detach().cpu().clone() for k,v in model.state_dict().items()}
    # report on the first epoch and then every fifth one
    if ep % 5 == 0 or ep == 1:
        print(f"Epoch {ep:03d} | train {tr:.4f} | val {va:.4f}")

# load best
if best_state is not None:
    model.load_state_dict(best_state)

### Make predictions:

In [None]:
@torch.no_grad()
def evaluate_with_preds(dl, ds):
    """Predict every sample of ``dl`` and return RMSE metrics plus a prediction table.

    Parameters
    ----------
    dl : iterable of batches
        Must yield the samples of ``ds`` in dataset order (no shuffling, no
        subset), because predictions are matched to ``ds.keys`` by position.
    ds : dataset
        Provides ``keys`` (one ``(GLACIER, YEAR, ID, PERIOD)`` tuple per
        sample) and the target scalers ``y_mean`` / ``y_std`` used to map
        targets and predictions back to the unscaled space.

    Returns
    -------
    metrics : dict
        'RMSE_winter' and 'RMSE_annual'; NaN for a period without samples.
    df_preds : DataFrame
        Columns target, ID, pred, PERIOD, GLACIER, YEAR. For winter samples
        ``pred`` is the winter sum, for annual samples the annual sum.

    Raises
    ------
    ValueError
        If ``ds.keys`` is missing or shorter than the number of samples
        yielded by ``dl``, or if a key has an unexpected PERIOD.
    """
    model.eval()

    # we assume Dataset has attribute .keys aligned with ds order
    all_keys = ds.keys
    if all_keys is None:
        raise ValueError("ds.keys is required to label the predictions.")

    # Move the target scalers to the device once, not once per batch.
    y_std = ds.y_std.to(device)
    y_mean = ds.y_mean.to(device)

    columns = ["target", "ID", "pred", "PERIOD", "GLACIER", "YEAR"]
    rows = []

    offset = 0  # running index to map back to keys
    for batch in dl:
        batch_size = batch['x_m'].shape[0]
        batch_keys = all_keys[offset:offset + batch_size]
        if len(batch_keys) != batch_size:
            raise ValueError(
                f"ds.keys has {len(all_keys)} entries but the loader yielded more samples.")
        offset += batch_size

        batch = to_device(batch)
        _, y_w, y_a = model(batch['x_m'], batch['x_s'],
                            batch['mv'], batch['mw'], batch['ma'])

        # invert scaling, then transfer each vector to the CPU in one go
        # instead of one device-to-host copy per sample
        targets = (batch['y'] * y_std + y_mean).cpu().tolist()
        preds_w = (y_w * y_std + y_mean).cpu().tolist()
        preds_a = (y_a * y_std + y_mean).cpu().tolist()

        for (g, yr, mid, per), target, pred_w, pred_a in zip(
                batch_keys, targets, preds_w, preds_a):
            if per == "winter":
                pred = pred_w
            elif per == "annual":
                pred = pred_a
            else:
                raise ValueError(f"Unexpected PERIOD: {per}")
            rows.append({
                "target": target,
                "ID": mid,
                "pred": pred,
                "PERIOD": per,
                "GLACIER": g,
                "YEAR": yr
            })

    # Explicit columns keep the table usable when the loader was empty.
    df_preds = pd.DataFrame(rows, columns=columns)

    # Compute RMSEs
    def rmse(df, period):
        d = df[df["PERIOD"] == period]
        if len(d) == 0:
            return float("nan")
        return np.sqrt(((d["pred"] - d["target"])**2).mean())

    metrics = {
        "RMSE_winter": rmse(df_preds, "winter"),
        "RMSE_annual": rmse(df_preds, "annual"),
    }

    return metrics, df_preds

In [None]:
# --- TEST: use the SAME scalers from TRAIN, then transform TEST ---
ds_test = MBSequenceDataset(test_dict)       # test_dict was built from your test DF (with targets)
ds_test.set_scalers_from(ds_train)           # <-- crucial: copy TRAIN scalers
ds_test.transform_inplace()                  # scales inputs & targets using TRAIN stats

# No shuffling: evaluate_with_preds matches predictions to ds_test.keys by position.
test_dl = DataLoader(ds_test, batch_size=128, shuffle=False)

In [None]:
# Evaluate on the held-out test set: RMSE per period plus one prediction row per sample.
test_metrics, test_df_preds = evaluate_with_preds(test_dl, ds_test)
print("TEST metrics:", test_metrics)
print(test_df_preds.head())

In [None]:
# Seasonal scores computed from the 'target' and 'pred' columns of the test predictions.
scores_annual, scores_winter = compute_seasonal_scores(test_df_preds,
                                                       target_col='target',
                                                       pred_col='pred')

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Summary figure; predVSTruth and plotMeanPred are not defined in this
# notebook and are passed through to the plotting helper as given
# (presumably they come from the `scripts.*` star imports -- confirm).
fig = plot_predictions_summary(grouped_ids=test_df_preds,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               predVSTruth=predVSTruth,
                               plotMeanPred=plotMeanPred,
                               ax_xlim=(-8, 6),
                               ax_ylim=(-8, 6))

In [None]:
# Mean POINT_ELEVATION of the annual rows per glacier, highest first.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered by ascending mean elevation.
# NOTE(review): this label-based lookup needs every entry of TEST_GLACIERS
# to be present in gl_per_el.
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index

# NOTE(review): the 3x3 grid offers nine axes -- confirm this matches the
# number of test glaciers.
fig, axs = plt.subplots(3, 3, figsize=(20, 15), sharex=True)

PlotIndividualGlacierPredVsTruth(test_df_preds,
                                 axs=axs,
                                 color_annual=color_dark_blue,
                                 color_winter=color_pink,
                                 custom_order=test_gl_per_el)
