## Setting Up:

In [None]:
import sys, os

sys.path.append(os.path.join(os.getcwd(),
                             '../../'))  # Add root of repo to import MBM
import csv
from functools import partial

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
import itertools
import random
import pickle
from collections import Counter
import ast

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence all Python warnings for the rest of the session to keep cell output
# readable. NOTE(review): this also hides warnings that may point at real
# problems (e.g. chained-assignment warnings from pandas).
warnings.filterwarnings('ignore')
# IPython magics: automatically reload edited project modules before each
# cell is executed, so changes in `scripts/` are picked up without a restart.
%load_ext autoreload
%autoreload 2

# Project configuration object; later cells read cfg.seed, cfg.dataPath,
# cfg.metaData and cfg.fieldsNotFeatures from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Fix the random seed via the project helper `seed_all` (defined outside this
# file; presumably seeds Python/NumPy/torch — confirm in scripts.helpers).
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
import torch.nn as nn

# Report whether a GPU is visible and select the matching torch device.
cuda_ok = torch.cuda.is_available()
print("CUDA is available" if cuda_ok else "CUDA is NOT available")
if cuda_ok:
    # Project helper (defined outside this file) that releases GPU memory.
    free_up_cuda()

device = torch.device("cuda" if cuda_ok else "cpu")

In [None]:
# --- Plot styling ---
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# --- RGI ids ---
# Read the glacier-id table, strip stray whitespace from the column names,
# then sort by and index on the glacier short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda name: name.strip())
          .sort_values(by='short_name')
          .set_index('short_name'))

# Climate variables handed to the monthly preprocessing step.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables handed to the monthly preprocessing step.
vois_topographical = ["aspect_sgi", "slope_sgi"]

## Read GL data:

In [None]:
# Stake measurements loaded through the project helper.
data_glamos = getStakesData(cfg)

# Head/tail month padding computed from the stake data; the same two values
# are passed to every MBSequenceDataset / GeoData built further down.
compute_pads = mbm.data_processing.utils._compute_head_tail_pads_from_df
months_head_pad, months_tail_pad = compute_pads(data_glamos)

## Input data:
### Input dataset:

In [None]:
# Timestamped INFO-level logging for the preprocessing step.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Locations consumed by process_or_load_data (inputs and radiation cache).
era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=era5_dir + 'era5_monthly_averaged_data.nc',
    geopotential_data=era5_dir + 'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Monthly-format dataset: recomputed when RUN is True, otherwise loaded from
# the previously written output file.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv',
)

# MBM data loader wrapping the monthly table; used below for the CV splits.
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

In [None]:
# Sanity check: show the months stored for a single stake and year.
df = dataloader_gl.data
is_adler_26_2006 = (df.POINT_ID == 'adler_26') & (df.YEAR == 2006)
df.loc[is_adler_26_2006, 'MONTHS']

## Blocking on glaciers:

In [None]:
# Glaciers that actually occur in the monthly dataset.
existing_glaciers = set(data_monthly.GLACIER.unique())

# Report held-out glaciers for which there is no data at all.
missing_glaciers = [
    name for name in TEST_GLACIERS if name not in existing_glaciers
]
if len(missing_glaciers) > 0:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Every glacier that is not held out for testing is used for training.
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

is_test_row = data_monthly.GLACIER.isin(TEST_GLACIERS)
data_test = data_monthly[is_test_row]
print('Size of monthly test data:', len(data_test))

is_train_row = data_monthly.GLACIER.isin(train_glaciers)
data_train = data_monthly[is_train_row]
print('Size of monthly train data:', len(data_train))

# Test size expressed relative to the train size (guarding against an empty
# training set).
if len(data_train) != 0:
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
else:
    print("Warning: No training data available!")

# Same glacier-wise blocking, this time through the project CV helper, which
# also returns the cross-validation splits.
splits, test_set, train_set = get_CV_splits(
    dataloader_gl,
    test_split_on='GLACIER',
    test_splits=TEST_GLACIERS,
    random_state=cfg.seed,
)

n_test_rows = len(test_set['df_X'])
n_train_rows = len(train_set['df_X'])

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
test_perc = (n_test_rows / n_train_rows) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', n_test_rows)
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', n_train_rows)

# Final train/test tables: features with the target attached as column 'y'.
data_train, data_test = train_set['df_X'], test_set['df_X']
data_train['y'] = train_set['y']
data_test['y'] = test_set['y']

## LSTM:

In [None]:
# Columns passed to MBSequenceDataset as the monthly (sequence) inputs.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr', 'ELEVATION_DIFFERENCE',
]

# Columns passed to MBSequenceDataset as the static inputs.
STATIC_COLS = ['aspect_sgi', 'slope_sgi']

# Full feature list: monthly columns first, then static ones.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### Build LSTM dataloaders:

In [None]:
seed_all(cfg.seed)

# Work on copies and normalise the PERIOD labels (trim + lower-case) before
# the sequences are built.
df_train, df_test = data_train.copy(), data_test.copy()
for frame in (df_train, df_test):
    frame['PERIOD'] = frame['PERIOD'].str.strip().str.lower()

# Options shared by the train and test sequence datasets.
seq_kwargs = dict(
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
)
build_sequences = mbm.data_processing.MBSequenceDataset.from_dataframe

# --- sequence datasets built from the dataframes (targets included) ---
ds_train = build_sequences(df_train, MONTHLY_COLS, STATIC_COLS, **seq_kwargs)
ds_test = build_sequences(df_test, MONTHLY_COLS, STATIC_COLS, **seq_kwargs)

# --- 80/20 train/validation index split of the training dataset ---
train_idx, val_idx = mbm.data_processing.MBSequenceDataset.split_indices(
    len(ds_train),
    val_ratio=0.2,
    seed=cfg.seed,
)

### Define & train model:

In [None]:
# CSV log written by the hyper-parameter search. The project helper picks one
# configuration from it, ranked by the 'test_rmse_a' column.
log_path = 'logs/lstm_two_heads_param_search_progress_2025-09-17.csv'
best_params = get_best_params_for_lstm(log_path, select_by='test_rmse_a')
# Notebook display of the selected parameters.
best_params

In [None]:
# Load the full search log, add the mean of the two test RMSE columns and
# show the two runs with the lowest 'test_rmse_a'.
log_path = Path(log_path)
df = pd.read_csv(log_path)
df = df.assign(
    avg_test_loss=(df["test_rmse_a"] + df["test_rmse_w"]) / 2
).sort_values(by="test_rmse_a")
df.head(2)

In [None]:
# Plot hyper-parameter distributions for the k=10 runs selected by the
# "avg_test_loss" metric (project plotting helper; reads the log file itself).
plot_topk_param_distributions(log_path, k=10, metric="avg_test_loss")

In [None]:
# Reference parameter set kept for manual experiments (not used below):
# custom_params = {'Fm': 9,
#  'Fs': 5,
#  'hidden_size': 128,
#  'num_layers': 1,
#  'bidirectional': False,
#  'dropout': 0.1,
#  'static_layers': 0,
#  'static_hidden': None,
#  'static_dropout': None,
#  '': False,
#  'head_dropout': 0.2,
#  'lr': 0.001,
#  'weight_decay': 0.0001,
#  'clip_val': 1.0,
#  'sched_factor': 0.5,
#  'sched_patience': 6,
#  'loss_name': 'neutral',
#  'loss_spec': None}

# Start from the best logged configuration, but on a COPY: the overrides below
# must not silently rewrite `best_params` (the original code aliased the two).
custom_params = dict(best_params)

custom_params['two_heads'] = True
custom_params['head_dropout'] = 0.0
custom_params['Fm'] = 9  # == len(MONTHLY_COLS)
custom_params['Fs'] = 2  # == len(STATIC_COLS)

# Checkpoint path (one file per calendar day). The checkpoint is only removed
# inside the TRAIN branch below; deleting it unconditionally made TRAIN=False
# unusable because there was nothing left to load.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = f"models/lstm_model_{current_date}_oggm.pt"
os.makedirs(os.path.dirname(model_filename), exist_ok=True)

# --- loaders (fit scalers on TRAIN, apply to whole ds_train) ---
# Fresh untransformed clones so that re-running this cell never scales the
# same tensors twice.
ds_train_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_train)

ds_test_copy = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
    ds_test)

train_dl, val_dl = ds_train_copy.make_loaders(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=
    True,  # fit scalers on TRAIN and transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True  # use weighted sampler for training
)

# --- test loader (copies TRAIN scalers into ds_test_copy and transforms it) ---
test_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
    ds_test_copy, ds_train_copy, batch_size=128, seed=cfg.seed)

# --- build model, resolve loss, train, reload best ---
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

TRAIN = True
history = None  # stays None when an existing checkpoint is reused
if TRAIN:
    # Start from a clean checkpoint so that the file loaded below is
    # guaranteed to come from this training run.
    if os.path.exists(model_filename):
        os.remove(model_filename)

    history, best_val, best_state = model.train_loop(
        device=device,
        train_dl=train_dl,
        val_dl=val_dl,
        epochs=150,
        lr=custom_params['lr'],
        weight_decay=custom_params['weight_decay'],
        clip_val=1,
        # scheduler
        sched_factor=0.5,
        sched_patience=6,
        sched_threshold=0.01,
        sched_threshold_mode="rel",
        sched_cooldown=1,
        sched_min_lr=1e-6,
        # early stopping
        es_patience=15,
        es_min_delta=1e-4,
        # logging
        log_every=5,
        verbose=True,
        # checkpoint
        save_best_path=model_filename,
        loss_fn=loss_fn,
    )

# Reload the best checkpoint (freshly trained, or an existing one from today
# when TRAIN is False).
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state)

# Evaluate on test
test_metrics, test_df_preds = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a, test_rmse_w = test_metrics['RMSE_annual'], test_metrics[
    'RMSE_winter']

print('Test RMSE annual: {:.3f} | winter: {:.3f}'.format(
    test_rmse_a, test_rmse_w))

# A training history only exists when the model was trained in this run.
if history is not None:
    plot_history_lstm(history)

In [None]:
# Seasonal scores of the test predictions, split into annual and winter.
seasonal_scores = compute_seasonal_scores(test_df_preds,
                                          target_col='target',
                                          pred_col='pred')
scores_annual, scores_winter = seasonal_scores

print("Annual scores:", scores_annual)
print("Winter scores:", scores_winter)

# Same axis range on x and y for the summary figure.
axis_limits = (-8, 6)
fig = plot_predictions_summary(
    grouped_ids=test_df_preds,
    scores_annual=scores_annual,
    scores_winter=scores_winter,
    predVSTruth=predVSTruth,
    plotMeanPred=plotMeanPred,
    ax_xlim=axis_limits,
    ax_ylim=axis_limits,
)

In [None]:
# Mean stake elevation per glacier (annual measurements only), highest first.
annual_stakes = data_glamos[data_glamos.PERIOD == 'annual']
gl_per_el = (annual_stakes.groupby(['GLACIER'])['POINT_ELEVATION']
             .mean()
             .sort_values(ascending=False))

# Test glaciers ordered from lowest to highest mean stake elevation.
test_gl_per_el = gl_per_el[TEST_GLACIERS].sort_values().index

# One panel per test glacier on a 3x3 grid sharing the x axis.
fig, axs = plt.subplots(3, 3, figsize=(20, 15), sharex=True)

PlotIndividualGlacierPredVsTruth(
    test_df_preds,
    axs=axs,
    color_annual=color_dark_blue,
    color_winter=color_pink,
    custom_order=test_gl_per_el,
)


## Extrapolate in space:

In [None]:
geodetic_mb = get_geodetic_MB(cfg)

# Unique start / end years of the geodetic surveys, per glacier.
by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = by_glacier['Astart'].unique().apply(list).to_dict()
years_end_per_gl = by_glacier['Aend'].unique().apply(list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers present in the stake dataset.
glacier_list = list(data_glamos.GLACIER.unique())
print('Number of glaciers with pcsr:', len(glacier_list))

# Glaciers for which geodetic periods are available.
geodetic_glaciers = periods_per_glacier.keys()
print('Number of glaciers with geodetic MB:', len(geodetic_glaciers))

# Glaciers available in both sources.
common_glaciers = list(set(geodetic_glaciers) & set(glacier_list))
print('Number of common glaciers:', len(common_glaciers))

# Glacier areas; 'clariden' reuses the area stored under 'claridenL'.
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glacier names ordered by ascending area.

    Names that are missing from ``gl_area`` are treated as having area 0, so
    they come first. The input list is left untouched.
    """

    def _area_of(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=_area_of)


# Glaciers with both stake and geodetic data, ordered from smallest to
# largest area; this is the list the extrapolation loop iterates over.
glacier_list = sort_by_area(common_glaciers, gl_area)
# Notebook display of the ordered list.
glacier_list

In [None]:
all_columns = feature_columns + cfg.fieldsNotFeatures

# Required by the dataset builder regardless of your feature list
REQUIRED = ['GLACIER', 'YEAR', 'ID', 'PERIOD', 'MONTHS']
# Columns kept from every grid file; independent of glacier and year, so the
# set is built once.
wanted_columns = set(all_columns) | set(REQUIRED)

# Months whose presence is checked for every grid file (coverage warning only).
hydro_months = {
    'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul',
    'aug', 'sep'
}

# Paths
path_save_glw = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
                             'MBM/testing_LSTM/LSTM_no_oggm')
os.makedirs(path_save_glw, exist_ok=True)
path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS', 'topo', 'GLAMOS_DEM',
                             'xr_masked_grids')

# Load model once
best_state = torch.load(model_filename, map_location=device)
model.load_state_dict(best_state)
model.eval()

RUN = True
if RUN:
    # Start from an empty output folder so stale predictions cannot survive.
    emptyfolder(path_save_glw)

    # Scalers depend only on the training data, so they are fitted ONCE here
    # instead of re-cloning and re-fitting for every glacier/year.
    # FIT ONLY: the clone itself is not transformed; make_test_loader copies
    # its scalers into each glacier dataset.
    # NOTE(review): assumes make_test_loader only reads from the scaler
    # dataset and does not modify it — confirm in MBSequenceDataset.
    seed_all(cfg.seed)
    ds_scalers = mbm.data_processing.MBSequenceDataset._clone_untransformed_dataset(
        ds_train)
    ds_scalers.fit_scalers(train_idx)

    for glacier_name in glacier_list:
        glacier_path = os.path.join(cfg.dataPath + path_glacier_grid_glamos,
                                    glacier_name)
        if not os.path.exists(glacier_path):
            print(f"Folder not found for {glacier_name}, skipping...")
            continue

        glacier_files = sorted(
            f for f in os.listdir(glacier_path) if glacier_name in f)

        # Only years covered by the geodetic observations are predicted.
        geodetic_range = range(np.min(periods_per_glacier[glacier_name]),
                               np.max(periods_per_glacier[glacier_name]) + 1)

        # File names follow "<glacier>_grid_<year>.parquet".
        years = [int(f.split('_')[2].split('.')[0]) for f in glacier_files]
        years = [y for y in years if y in geodetic_range]

        print(
            f"Processing {glacier_name} ({len(years)} files): {geodetic_range}"
        )

        for year in tqdm(years, desc=f"Processing {glacier_name}",
                         leave=False):
            seed_all(cfg.seed)

            file_name = f"{glacier_name}_grid_{year}.parquet"
            df_grid_monthly = pd.read_parquet(
                os.path.join(glacier_path, file_name))

            df_grid_monthly.drop_duplicates(inplace=True)

            # Keep required + feature columns; DON'T drop PERIOD/MONTHS/YEAR/ID/GLACIER.
            # Iterating the frame's own columns keeps a deterministic column
            # order; the explicit copy makes the assignments below safe.
            keep = [c for c in df_grid_monthly.columns if c in wanted_columns]
            df_grid_monthly = df_grid_monthly[keep].copy()

            # Ensure PERIOD is set to 'annual' BEFORE sequence building
            if 'PERIOD' not in df_grid_monthly.columns:
                df_grid_monthly['PERIOD'] = 'annual'
            else:
                df_grid_monthly['PERIOD'] = df_grid_monthly['PERIOD'].fillna(
                    'annual')
            df_grid_monthly['PERIOD'] = df_grid_monthly['PERIOD'].str.strip(
            ).str.lower()

            # (Optional) minimal NaN clean-up: only drop if ID or MONTHS missing
            df_grid_monthly = df_grid_monthly.dropna(subset=['ID', 'MONTHS'])

            # (Optional) hydrological coverage check
            have = set(df_grid_monthly['MONTHS'].str.lower().unique())
            if not hydro_months.issubset(have):
                missing = sorted(hydro_months - have)
                print(
                    f"WARNING [{glacier_name} {year}]: missing hydro months: {missing}"
                )

            # --- Build ds_gl WITHOUT targets ---
            ds_gl = mbm.data_processing.MBSequenceDataset.from_dataframe(
                df_grid_monthly,
                MONTHLY_COLS,
                STATIC_COLS,
                months_tail_pad=months_tail_pad,
                months_head_pad=months_head_pad,
                expect_target=False,
                show_progress=False)

            test_gl_dl = mbm.data_processing.MBSequenceDataset.make_test_loader(
                ds_gl, ds_scalers, seed=cfg.seed,
                batch_size=128)  # copies train scalers & transforms ds_gl

            # Predict (no metrics)
            df_preds = model.predict_with_keys(device, test_gl_dl, ds_gl)

            # Join preds back to unique cell IDs for saving
            data = df_preds[['ID', 'pred']].set_index('ID')
            grouped_ids = df_grid_monthly.groupby('ID')[[
                'YEAR', 'POINT_LAT', 'POINT_LON', 'GLWD_ID'
            ]].first()
            grouped_ids = grouped_ids.merge(data,
                                            left_index=True,
                                            right_index=True,
                                            how='left')

            months_per_id = df_grid_monthly.groupby('ID')['MONTHS'].unique()
            grouped_ids = grouped_ids.merge(months_per_id,
                                            left_index=True,
                                            right_index=True)

            grouped_ids.reset_index(inplace=True)
            grouped_ids.sort_values(by='ID', inplace=True)
            grouped_ids['PERIOD'] = 'annual'

            pred_y_annual = grouped_ids.drop(columns=['YEAR'], errors='ignore')

            # Save
            # path_xr_grids already starts with cfg.dataPath; joining
            # cfg.dataPath in front again duplicated the prefix whenever the
            # data path is relative.
            path_glacier_dem = os.path.join(path_xr_grids,
                                            f"{glacier_name}_{year}.zarr")
            # Context manager closes the DEM dataset after each glacier/year
            # instead of leaving one open handle per iteration.
            with xr.open_dataset(path_glacier_dem) as ds:
                geoData = mbm.geodata.GeoData(df_grid_monthly,
                                              months_head_pad=months_head_pad,
                                              months_tail_pad=months_tail_pad)
                geoData._save_prediction(ds, pred_y_annual, glacier_name,
                                         year, path_save_glw, "annual")

# Quick visual check: map of the saved annual prediction for one glacier/year.
glacier_name, year = 'aletsch', 2008
path_pred_zarr = os.path.join(
    path_save_glw, f'{glacier_name}/{glacier_name}_{year}_annual.zarr')
xr.open_dataset(path_pred_zarr).pred_masked.plot(cmap='RdBu')

In [None]:
# Entries of the prediction folder (the quick-look cell reads one sub-folder
# per glacier from it).
glaciers_in_glamos = os.listdir(path_save_glw)

geodetic_mb = get_geodetic_MB(cfg)

# Unique start / end years of the geodetic surveys, per glacier.
by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = by_glacier['Astart'].unique().apply(list).to_dict()
years_end_per_gl = by_glacier['Aend'].unique().apply(list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glacier areas; 'clariden' reuses the area stored under 'claridenL'.
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return a new list of glacier names sorted from smallest to largest area.

    Areas are looked up with ``gl_area.get``; unknown glaciers count as area 0.
    """
    return sorted(glacier_list, key=lambda name: gl_area.get(name, 0))


# Glaciers that have geodetic periods AND a prediction folder, by area.
glacier_list = sort_by_area(
    [name for name in periods_per_glacier.keys() if name in glaciers_in_glamos],
    gl_area,
)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

# Table comparing modelled and geodetic mass balance (project helper).
df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=path_save_glw,  # or another path if needed
    cfg=cfg,
)

# Keep only rows where both values are present, ordered by glacier area, with
# capitalised glacier names.
df_all_nn = (df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])
             .sort_values(by='Area'))
df_all_nn['GLACIER'] = df_all_nn['GLACIER'].map(str.capitalize)

# RMSE and Pearson correlation between geodetic and modelled mass balance.
geodetic_vals, mbm_vals = df_all_nn["Geodetic MB"], df_all_nn["MBM MB"]
rmse_nn = root_mean_squared_error(geodetic_vals, mbm_vals)
corr_nn = np.corrcoef(geodetic_vals, mbm_vals)[0, 1]

area_bins = [0, 1, 5, 10, 100, np.inf]
area_labels = ['<1', '1-5', '5–10', '>10', '>100']
plot_mbm_vs_geodetic_by_area_bin(df_all_nn,
                                 bins=area_bins,
                                 labels=area_labels,
                                 max_bins=4)


## Permutation importance:

In [None]:
import copy
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset, Subset


_PERMUTATION_METRICS = ("RMSE_annual", "RMSE_winter", "RMSE_mean")


def _score_from_metrics(metrics: dict, metric: str) -> float:
    """Collapse a metrics dict into the scalar score selected by ``metric``."""
    if metric in ("RMSE_annual", "RMSE_winter"):
        return float(metrics[metric])
    if metric == "RMSE_mean":
        # Mean over every finite "RMSE_*" entry that is present.
        vals = [
            v for k, v in metrics.items()
            if k.startswith("RMSE_") and np.isfinite(v)
        ]
        return float(np.mean(vals)) if vals else float("nan")
    raise ValueError(f"Unknown metric '{metric}'")


def _permutation_deltas(model, device, ds_test, attr, col_idx, rng, n_repeats,
                        baseline_score, metric, loader_kwargs):
    """Return score changes after shuffling one feature of ``ds_test.<attr>``."""
    deltas = []
    for _ in range(n_repeats):
        # Shallow copy of the dataset object; only the mutated tensor is cloned,
        # so the caller's dataset is never modified.
        ds_perm = copy.copy(ds_test)
        X_perm = getattr(ds_test, attr).clone()

        # The feature axis is last. Permuting along axis 0 shuffles samples:
        # single values for 2-D tensors, whole per-sample sequences for 3-D
        # tensors (the order inside each sequence is preserved).
        values = X_perm[..., col_idx].cpu().numpy()
        X_perm[..., col_idx] = torch.from_numpy(rng.permutation(values)).to(
            X_perm.dtype)
        setattr(ds_perm, attr, X_perm)

        dl = DataLoader(ds_perm, shuffle=False, **loader_kwargs)
        m, _ = model.evaluate_with_preds(device, dl, ds_perm)
        deltas.append(_score_from_metrics(m, metric) - baseline_score)
    return deltas


def permutation_importance_lstm(
    model,
    device,
    ds_test,  # MBSequenceDataset (already scaled)
    MONTHLY_COLS,
    STATIC_COLS,
    *,
    batch_size: int = 128,
    n_repeats: int = 10,
    seed: int = 42,
    metric: str = "RMSE_mean",  # "RMSE_annual" | "RMSE_winter" | "RMSE_mean"
    num_workers: int = 0,
    pin_memory: bool = False,
):
    """
    Compute permutation importance for an LSTM with sequence + static features.

    Importance = (score after permutation) - (baseline score); higher => more important.

    Static features are shuffled across samples one column at a time; monthly
    features are shuffled as whole per-sample sequences. Static features are
    processed first, then monthly ones, all drawing from one seeded generator.

    Assumes:
      - ds_test exposes tensors ``Xs`` (samples x static features) and ``Xm``
        (samples x time steps x monthly features) whose feature order matches
        STATIC_COLS / MONTHLY_COLS, and is already scaled like at inference.
      - model has .evaluate_with_preds(device, dl, ds) returning
        (metrics dict, predictions).

    Parameters
    ----------
    batch_size, num_workers, pin_memory : forwarded to the torch DataLoader.
    n_repeats : number of independent shuffles per feature.
    seed : seed of the NumPy generator used for the shuffles.
    metric : "RMSE_annual", "RMSE_winter" or "RMSE_mean" (mean of the finite
        "RMSE_*" entries of the metrics dict).

    Returns
    -------
    df_imp : DataFrame with columns: feature, group, metric, mean_importance,
        std_importance; sorted by decreasing mean_importance.
    baseline_score : float, score of the unpermuted dataset.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names (checked before any
        model evaluation is run).
    """
    if metric not in _PERMUTATION_METRICS:
        raise ValueError(f"Unknown metric '{metric}'")

    rng = np.random.default_rng(seed)
    model.eval()

    loader_kwargs = dict(batch_size=batch_size,
                         num_workers=num_workers,
                         pin_memory=pin_memory)

    # Baseline score on the unmodified test dataset.
    base_dl = DataLoader(ds_test, shuffle=False, **loader_kwargs)
    base_metrics, _ = model.evaluate_with_preds(device, base_dl, ds_test)
    baseline_score = _score_from_metrics(base_metrics, metric)

    results = []
    feature_groups = (
        ("static", "Xs", STATIC_COLS),
        ("monthly", "Xm", MONTHLY_COLS),
    )
    for group, attr, columns in feature_groups:
        for col_idx, feat in enumerate(columns):
            deltas = _permutation_deltas(model, device, ds_test, attr,
                                         col_idx, rng, n_repeats,
                                         baseline_score, metric,
                                         loader_kwargs)
            results.append({
                "feature": feat,
                "group": group,
                "metric": metric,
                "mean_importance": float(np.mean(deltas)),
                # Sample standard deviation needs at least two repeats.
                "std_importance":
                float(np.std(deltas, ddof=1) if len(deltas) > 1 else 0.0),
            })

    df_imp = pd.DataFrame(results).sort_values(
        "mean_importance", ascending=False).reset_index(drop=True)
    return df_imp, baseline_score


In [None]:
# The importance must be measured on the SCALED test dataset, i.e. the clone
# that make_test_loader filled with the train scalers and transformed
# (`ds_test_copy`). `ds_test` itself is the untransformed dataset, so feeding
# it to the model would evaluate on unscaled inputs.
df_imp, baseline = permutation_importance_lstm(
    model,
    device,
    ds_test_copy,
    MONTHLY_COLS,
    STATIC_COLS,
    batch_size=128,
    n_repeats=10,
    seed=cfg.seed,
    metric="RMSE_mean",  # or "RMSE_annual" / "RMSE_winter"
)

print(f"Baseline {df_imp.metric.iloc[0]}: {baseline:.4f}")
print(df_imp.head(20))

In [None]:
# Plot the permutation importances computed above; top_n=20 exceeds the 11
# features used here (project plotting helper).
plot_permutation_importance(df_imp, top_n=20)