## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import pickle 
from collections import Counter
import ast

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# Silence all Python warnings for the rest of the notebook (this also hides library deprecation warnings).
warnings.filterwarnings('ignore')
# IPython magics: automatically reload edited modules (e.g. scripts.*) before executing each cell.
%load_ext autoreload
%autoreload 2

# Swiss configuration object; its dataPath, seed and metaData attributes are used throughout the notebook.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Plot styles: shared style sheet and the two accent colours used by the figures below.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# RGI ids table: strip stray whitespace from the headers, then index by glacier short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda column: column.strip())
          .sort_values(by='short_name')
          .set_index('short_name'))

# Climate variables (monthly, from the ERA5 files used when building the dataset).
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Static topographic / glaciological variables.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

In [None]:
# Seed every RNG for reproducibility, then report whether a GPU can be used.
seed_all(cfg.seed)

print("Using seed:", cfg.seed)

if not torch.cuda.is_available():
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()


## Read GLAMOS stake data:

In [None]:
data_glamos = getStakesData(cfg)

# Display names for the raw glacier ids; the two Clariden parts need an explicit spelling.
special_caps = {'claridenu': 'Clariden_U', 'claridenl': 'Clariden_L'}
glacierCap = {}
for gl in data_glamos['GLACIER'].unique():
    if not isinstance(gl, str):  # non-string ids are reported, not mapped
        print(f"Warning: Non-string glacier name encountered: {gl}")
        continue
    glacierCap[gl] = special_caps.get(gl.lower(), gl.capitalize())

# Exclude taelliboden and plainemorte from the analysis.
excluded = ['taelliboden', 'plainemorte']
data_glamos = data_glamos[~data_glamos['GLACIER'].isin(excluded)]

print('-------------------')
print('Number of glaciers:', len(data_glamos['GLACIER'].unique()))
print('Number of winter and annual samples:', len(data_glamos))
for label, period in (('Number of annual samples:', 'annual'),
                      ('Number of winter samples:', 'winter')):
    print(label, len(data_glamos[data_glamos.PERIOD == period]))


## Input data:
### Input dataset:

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# Input/output locations handed to process_or_load_data: stake CSV, ERA5 climate and
# geopotential files, and the radiation (pcsr) zarr store.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# NOTE(review): with RUN=False, process_or_load_data presumably loads the previously
# saved 'CH_wgms_dataset_monthly_NN.csv' instead of recomputing it — confirm.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_NN.csv')

# Month padding derived from the monthly data; passed to every dataset / GeoData built later.
months_head_pad, months_tail_pad = mbm.data_processing.utils.build_head_tail_pads_from_monthly_df(
    data_monthly)
# NOTE(review): uses a private (underscore) helper; month_pos is not used later in this notebook.
_, month_pos = mbm.data_processing.utils._rebuild_month_index(
    months_head_pad, months_tail_pad)

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)


In [None]:
# Parse the YYYYMMDD stake dates. Columns that are already datetimes are left alone so
# the cell can be re-run: re-parsing str(Timestamp) with "%Y%m%d" would raise.
for date_col in ("FROM_DATE", "TO_DATE"):
    if not pd.api.types.is_datetime64_any_dtype(data_glamos[date_col]):
        data_glamos[date_col] = pd.to_datetime(
            data_glamos[date_col].astype(str), format="%Y%m%d")

# Extract first and last months as numbers (1–12)
data_glamos["FIRST_MONTH_NUM"] = data_glamos["FROM_DATE"].dt.month
data_glamos["LAST_MONTH_NUM"] = data_glamos["TO_DATE"].dt.month

# Compute min of all FIRST and max of all LAST
global_first = data_glamos["FIRST_MONTH_NUM"].min()
global_last = data_glamos["LAST_MONTH_NUM"].max()

# Convert back to abbreviations if needed (1 -> 'jan', ..., 12 -> 'dec')
month_abbr = {
    i: pd.to_datetime(str(i), format="%m").strftime("%b").lower()
    for i in range(1, 13)
}
global_first_abbr = month_abbr[global_first]
global_last_abbr = month_abbr[global_last]

# print("Global earliest first month:", global_first_abbr)
# print("Global latest last month:",global_last_abbr)

In [None]:
# Split the monthly data by period and summarise how many months each sample spans.
data_annual = dataloader_gl.data[dataloader_gl.data.PERIOD == 'annual']
data_winter = dataloader_gl.data[dataloader_gl.data.PERIOD == 'winter']

month_stats = (
    ('Mean number of months:', 'Std number of months:', data_annual),
    ('Mean number of months (winter):', 'Std number of months (winter):',
     data_winter),
)
for mean_label, std_label, subset in month_stats:
    print(mean_label, subset.N_MONTHS.mean())
    print(std_label, subset.N_MONTHS.std())

## Blocking on glaciers:

In [None]:
# Ensure all test glaciers exist in the dataset
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]

if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Training glaciers = every glacier that is not a test glacier. Sorted so the order is
# reproducible: iterating a set of strings is not stable across interpreter runs
# (hash randomisation). key=str keeps sorting safe even for non-string ids.
train_glaciers = sorted(
    (g for g in existing_glaciers if g not in TEST_GLACIERS), key=str)

data_test = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(
    train_glaciers)]
print('Size of monthly train data:', len(data_train))

if len(data_train) == 0:
    print("Warning: No training data available!")
else:
    # Test size expressed relative to the training size (not to the total).
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))

# Number of annual versus winter measurements:
print('-------------\nTrain:')
print('Number of monthly winter and annual samples:', len(data_train))
print('Number of monthly annual samples:',
      len(data_train[data_train.PERIOD == 'annual']))
print('Number of monthly winter samples:',
      len(data_train[data_train.PERIOD == 'winter']))

# Same for test
data_test_annual = data_test[data_test.PERIOD == 'annual']
data_test_winter = data_test[data_test.PERIOD == 'winter']

print('Test:')
print('Number of monthly winter and annual samples:', len(data_test))
print('Number of monthly annual samples:', len(data_test_annual))
print('Number of monthly winter samples:', len(data_test_winter))

print('Total:')
print('Number of monthly rows:', len(dataloader_gl.data))
print('Number of annual rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'annual']))
print('Number of winter rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'winter']))

# same for original data:
print('-------------\nIn annual format:')
print('Number of annual train rows:',
      len(data_glamos[data_glamos.GLACIER.isin(train_glaciers)]))
print('Number of annual test rows:',
      len(data_glamos[data_glamos.GLACIER.isin(TEST_GLACIERS)]))


In [None]:
# Hold out TEST_GLACIERS as the test set (split on the GLACIER column); `splits` presumably
# holds the cross-validation folds over the remaining glaciers — confirm in get_CV_splits.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# Test size relative to the training size; would raise ZeroDivisionError if train_set were empty.
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

In [None]:
# Validation and train split:
# Work on a copy: assigning the 'y' column directly on train_set['df_X'] would silently
# mutate the frame still referenced by train_set (and may trigger chained-assignment
# warnings if it is a slice of another frame).
data_train = train_set['df_X'].copy()
data_train['y'] = train_set['y']
# NOTE(review): unlike dataloader_gl, no random_seed is passed here, so the DataLoader
# default is used — confirm it matches cfg.seed for a reproducible train/val split.
dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# Get all indices of the training and validation datasets at once from the iterators.
# Once consumed, the iterators are empty.
train_indices, val_indices = list(train_itr), list(val_itr)

df_X_train = data_train.iloc[train_indices]
y_train = df_X_train['POINT_BALANCE'].values

# Get val set
df_X_val = data_train.iloc[val_indices]
y_val = df_X_val['POINT_BALANCE'].values

## Neural Network:

### Parameter grid search results:

In [None]:
# Open grid_search results (log of the 2025-09-19 search), best validation loss first.
gs_results = pd.read_csv(
    'logs/nn_param_search_progress_2025-09-19.csv').sort_values(
        by='valid_loss', ascending=True)
# Row with the lowest valid_loss.
best_params = gs_results.iloc[0].to_dict()

# NOTE(review): informational only — the hyperparameters actually used for training are
# written out by hand in the `params` dict further down, not taken from best_params.
print('Best parameters from grid search:')
for key, value in best_params.items():
    if key not in ['valid_loss', 'train_loss']:
        print(f"{key}: {value}")

# Display the 10 best configurations as cell output.
gs_results.head(10)

### Initialise network:

In [None]:
# Topographic inputs: elevation difference, potential clear-sky radiation (pcsr) plus the
# static topographical variables defined at the top of the notebook.
features_topo = [
    'ELEVATION_DIFFERENCE',
    'pcsr',
] + list(vois_topographical)

# Full model input = topographic features followed by the climate variables.
feature_columns = features_topo + list(vois_climate)

cfg.setFeatures(feature_columns)

# Feature columns plus the non-feature bookkeeping columns required by cfg.
all_columns = feature_columns + cfg.fieldsNotFeatures

# Because CH has some extra columns, we need to cut those
df_X_train_subset = df_X_train[all_columns]
df_X_val_subset = df_X_val[all_columns]
df_X_test_subset = test_set['df_X'][all_columns]

print('Shape of training dataset:', df_X_train_subset.shape)
print('Shape of validation dataset:', df_X_val_subset.shape)
print('Shape of testing dataset:', df_X_test_subset.shape)
print('Running with features:', feature_columns)

# Sanity check: the stored targets are exactly the POINT_BALANCE column.
assert all(train_set['df_X'].POINT_BALANCE == train_set['y'])

# create the same but for winter only:
df_X_train_subset_winter = df_X_train_subset[df_X_train_subset.PERIOD ==
                                             'winter']
df_X_val_subset_winter = df_X_val_subset[df_X_val_subset.PERIOD == 'winter']
y_train_w = df_X_train_subset_winter['POINT_BALANCE'].values
y_val_w = df_X_val_subset_winter['POINT_BALANCE'].values
print('Shape of training dataset only winter:', df_X_train_subset_winter.shape)
print('Shape of validation dataset only winter:', df_X_val_subset_winter.shape)

In [None]:
# Stop training once valid_loss has not improved (relative threshold 1e-4) for 15 epochs.
# load_best=True restores the weights of the best epoch when training stops; skorch's
# default (False) would keep the last-epoch weights, and those are what save_model writes.
early_stop = EarlyStopping(
    monitor='valid_loss',
    patience=15,
    threshold=1e-4,  # Optional: stop only when improvement is very small
    load_best=True,
)

# Halve the learning rate when valid_loss has not improved by 1% (relative) for 5 epochs.
lr_scheduler_cb = LRScheduler(policy=ReduceLROnPlateau,
                              monitor='valid_loss',
                              mode='min',
                              factor=0.5,
                              patience=5,
                              threshold=0.01,
                              threshold_mode='rel',
                              verbose=True)

dataset = dataset_val = None  # Initialized hereafter


def my_train_split(ds, y=None, **fit_params):
    """Return the pre-built ``(dataset, dataset_val)`` pair, ignoring all arguments.

    Used as skorch's ``train_split``. The module-level ``dataset`` and
    ``dataset_val`` names are looked up at call time, so they must be assigned
    before ``fit`` is called.
    """
    split = (dataset, dataset_val)
    return split


# Device for the network; swap in the commented line below to train on the first GPU.
# param_init = {'device': 'cuda:0'}
param_init = {'device': 'cpu'}  # Use CPU for training
# Network input width = number of feature columns.
nInp = len(feature_columns)
print('Number of input features:', nInp)

In [None]:
# Previous hyperparameter set, kept for reference:
# params = {
#     'lr': 0.001,
#     'batch_size': 128,
#     'optimizer': torch.optim.Adam,
#     'optimizer__weight_decay': 1e-05,
#     'module__hidden_layers': [128, 128, 64, 32],
#     'module__dropout': 0.2,
#     'module__use_batchnorm': True,
# }

# Hyperparameters used for this run (set by hand, not read from the grid-search log).
params = {
    'lr': 0.0005,
    'batch_size': 128,
    'optimizer': torch.optim.Adam,
    'optimizer__weight_decay': 0.0,
    'module__hidden_layers': [128, 128, 64, 64, 32],
    'module__dropout': 0.2,
    'module__use_batchnorm': False,
}

# skorch-style constructor arguments: 'module__*' go to FlexibleNetwork, 'optimizer__*'
# to the optimizer; train_split hands back the pre-built train/validation datasets.
args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 200,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}

custom_nn = mbm.models.CustomNeuralNetRegressor(cfg, **args, **param_init)

### Create datasets:

In [None]:
# Split each subset into feature arrays and metadata (as required by AggregatedDataset).
features, metadata = mbm.data_processing.utils.create_features_metadata(
    cfg, df_X_train_subset)

features_val, metadata_val = mbm.data_processing.utils.create_features_metadata(
    cfg, df_X_val_subset)

# Define the dataset for the NN
dataset = mbm.data_processing.AggregatedDataset(
    cfg,
    features=features,
    metadata=metadata,
    months_head_pad=months_head_pad,
    months_tail_pad=months_tail_pad,
    targets=y_train)
# Rebind as X/y views: SliceDataset idx=0 / idx=1 picks element 0 / 1 of every dataset item.
# These globals are what my_train_split returns to skorch.
dataset = mbm.data_processing.SliceDatasetBinding(SliceDataset(dataset, idx=0),
                                                  SliceDataset(dataset, idx=1))
print("train:", dataset.X.shape, dataset.y.shape)

dataset_val = mbm.data_processing.AggregatedDataset(
    cfg,
    features=features_val,
    metadata=metadata_val,
    months_head_pad=months_head_pad,
    months_tail_pad=months_tail_pad,
    targets=y_val)
dataset_val = mbm.data_processing.SliceDatasetBinding(
    SliceDataset(dataset_val, idx=0), SliceDataset(dataset_val, idx=1))
print("validation:", dataset_val.X.shape, dataset_val.y.shape)

### Train custom model:

In [None]:
# Set TRAIN = False to skip retraining; the load cell below always reads the fixed
# 2025-09-22 model/params files regardless of what is trained here.
TRAIN = True
if TRAIN:
    custom_nn.seed_all()

    print("Training the model...")
    print('Model parameters:')
    for key, value in args.items():
        print(f"{key}: {value}")
    custom_nn.fit(dataset.X, dataset.y)
    # The dataset provided in fit is not used as the datasets are overwritten in the provided train_split function

    # Generate filename with current date
    current_date = datetime.now().strftime("%Y-%m-%d")
    model_filename = f"nn_model_{current_date}"

    plot_training_history(custom_nn.history, skip_first_n=5)

    # NOTE(review): the saved weights are the best-epoch ones only if they were restored
    # at the end of fit (e.g. EarlyStopping(load_best=True)) — confirm.
    # The params pickle below is written to models/; make sure the directory exists.
    os.makedirs("models", exist_ok=True)
    # Save the model
    custom_nn.save_model(model_filename)

    # Save the exact args dict so the load cell can rebuild the same network.
    params_filename = f"nn_params_{current_date}.pkl"

    with open(f"models/{params_filename}", "wb") as f:
        pickle.dump(args, f)

### Load model and make predictions:

In [None]:
# Load model and set to CPU
model_filename = "nn_model_2025-09-22.pt"  # Replace with actual date if needed

# read pickle with params
# NOTE: pickle.load executes arbitrary code from the file — only load trusted, locally
# produced files (here: the args dict written by the training cell).
params_filename = "nn_params_2025-09-22.pkl"  # Replace with actual date if needed
with open(f"models/{params_filename}", "rb") as f:
    custom_params = pickle.load(f)

# The pickle holds the saved `args` dict, whose keys include the ones read below.
params = custom_params

# Rebuild the constructor arguments so the loaded network has the saved architecture.
# max_epochs (300) differs from the training cell (200) but is irrelevant for inference.
args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 300,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}

loaded_model = mbm.models.CustomNeuralNetRegressor.load_model(
    cfg,
    model_filename,
    **{
        **args,
        **param_init
    },
)
# Force inference on CPU.
loaded_model = loaded_model.set_params(device='cpu')
loaded_model = loaded_model.to('cpu')

#### On test:

In [None]:
# Predict on the held-out test glaciers; predictions are aggregated per measurement
# into grouped_ids (with 'target' and 'pred' columns, as used below).
grouped_ids, scores_NN, ids_NN, y_pred_NN = evaluate_model_and_group_predictions(
    loaded_model, df_X_test_subset, test_set['y'], cfg, months_head_pad,
    months_tail_pad)

baseline_score = scores_NN['rmse']
print('Baseline RMSE:', baseline_score)

# Separate annual and winter scores, then the predicted-vs-observed summary figure.
scores_annual, scores_winter = compute_seasonal_scores(grouped_ids,
                                                       target_col='target',
                                                       pred_col='pred')
fig = plot_predictions_summary(grouped_ids=grouped_ids,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               predVSTruth=predVSTruth,
                               plotMeanPred=plotMeanPred,
                               ax_xlim=(-8, 6),
                               ax_ylim=(-8, 6))

# calculate RMSE on the grouped predictions (shown as cell output)
root_mean_squared_error(grouped_ids['target'], grouped_ids['pred'])

In [None]:
# Mean elevation of the annual stakes per glacier, used to order the panels.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered by increasing mean stake elevation. Selecting with isin (instead
# of label indexing) skips a test glacier without annual stakes rather than raising KeyError.
test_gl_per_el = gl_per_el[gl_per_el.index.isin(
    TEST_GLACIERS)].sort_values().index

fig, axs = plt.subplots(3, 3, figsize=(20, 15), sharex=True)

PlotIndividualGlacierPredVsTruth(grouped_ids,
                                 axs=axs,
                                 color_annual=color_dark_blue,
                                 color_winter=color_pink,
                                 custom_order=test_gl_per_el)


#### On train:

In [None]:
# Evaluate on the training glaciers. At this point data_train is train_set['df_X'] (train +
# validation rows, see the train/val split cell), so this is an in-sample check.
grouped_ids_NN_train, scores_NN_train, ids_train, y_pred_train = evaluate_model_and_group_predictions(
    loaded_model, data_train[all_columns], data_train['POINT_BALANCE'].values,
    cfg, months_head_pad, months_tail_pad)
scores_annual_NN, scores_winter_NN = compute_seasonal_scores(
    grouped_ids_NN_train, target_col='target', pred_col='pred')
# Wider axis limits than the test plot.
fig = plot_predictions_summary(grouped_ids=grouped_ids_NN_train,
                               scores_annual=scores_annual_NN,
                               scores_winter=scores_winter_NN,
                               predVSTruth=predVSTruth,
                               plotMeanPred=plotMeanPred,
                               ax_xlim=(-14, 8),
                               ax_ylim=(-14, 8))


In [None]:
# Training glaciers ordered by increasing mean annual-stake elevation (gl_per_el comes from
# the test-glacier plotting cell). isin() skips glaciers without annual stakes instead of
# raising KeyError.
train_gl_per_el = gl_per_el[gl_per_el.index.isin(
    train_glaciers)].sort_values().index

# One panel per glacier on a 3-column grid. The row count follows the number of glaciers
# (the previous fixed 7x3 grid could not hold more than 21); the height per row matches
# the original 30-inch / 7-row layout. squeeze=False keeps axs 2-D even for a single row.
n_cols = 3
n_rows = max(1, -(-len(train_gl_per_el) // n_cols))
fig, axs = plt.subplots(n_rows,
                        n_cols,
                        figsize=(20, 30 * n_rows / 7),
                        sharex=False,
                        squeeze=False)

PlotIndividualGlacierPredVsTruth(grouped_ids_NN_train,
                                 axs=axs,
                                 color_annual=color_dark_blue,
                                 color_winter=color_pink,
                                 custom_order=train_gl_per_el,
                                 ax_xlim=None)


## Extrapolate in space:

In [None]:
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# NOTE(review): this list comes from the stake data; the "with pcsr" label presumably
# assumes every stake glacier has radiation grids — confirm.
glacier_list = list(data_glamos.GLACIER.unique())
print('Number of glaciers with pcsr:', len(glacier_list))

geodetic_glaciers = periods_per_glacier.keys()
print('Number of glaciers with geodetic MB:', len(geodetic_glaciers))

# Intersection of both, sorted by name so that glaciers with equal (or missing) area keep
# a reproducible order after the stable area sort below (set order varies between runs).
common_glaciers = sorted(set(geodetic_glaciers) & set(glacier_list))
print('Number of common glaciers:', len(common_glaciers))

# Sort glaciers by area
gl_area = get_gl_area(cfg)
# Clariden as a whole is given the area of its lower part (claridenL).
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return *glacier_list* as a new list ordered from smallest to largest area.

    Glaciers absent from *gl_area* count as area 0 and therefore come first;
    glaciers with equal area keep their input order (``sorted`` is stable).
    """
    def area_of(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=area_of)


# Common glaciers ordered by increasing area; the bare name displays the list as cell output.
glacier_list = sort_by_area(common_glaciers, gl_area)
glacier_list

#### GLAMOS grids:

In [None]:
# Define paths
path_save_glw = os.path.join(
    cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
    'MBM/testing_combis/glamos_dems_NN_SEB_full_OGGM')
os.makedirs(path_save_glw, exist_ok=True)
# Masked DEM grids per glacier/year. This path already starts with cfg.dataPath.
path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS', 'topo', 'GLAMOS_DEM',
                             'xr_masked_grids')

# With RUN = True, everything under path_save_glw is deleted and recomputed.
RUN = True
if RUN:
    emptyfolder(path_save_glw)
    for glacier_name in glacier_list:
        glacier_path = os.path.join(cfg.dataPath + path_glacier_grid_glamos,
                                    glacier_name)

        if not os.path.exists(glacier_path):
            print(f"Folder not found for {glacier_name}, skipping...")
            continue

        glacier_files = sorted(
            [f for f in os.listdir(glacier_path) if glacier_name in f])

        # All years spanned by the geodetic periods of this glacier (inclusive).
        geodetic_range = range(np.min(periods_per_glacier[glacier_name]),
                               np.max(periods_per_glacier[glacier_name]) + 1)

        # Files are named "<glacier>_grid_<year>.<ext>"; keep years inside the geodetic range.
        years = [
            int(file_name.split('_')[2].split('.')[0])
            for file_name in glacier_files
        ]
        years = [y for y in years if y in geodetic_range]

        print(
            f"Processing {glacier_name} ({len(years)} files): {geodetic_range}"
        )

        for year in tqdm(years, desc=f"Processing {glacier_name}",
                         leave=False):
            file_name = f"{glacier_name}_grid_{year}.parquet"

            # Load parquet input glacier grid file in monthly format (pre-processed)
            df_grid_monthly = pd.read_parquet(
                os.path.join(cfg.dataPath + path_glacier_grid_glamos,
                             glacier_name, file_name))

            df_grid_monthly.drop_duplicates(inplace=True)

            # Keep only necessary columns, avoiding missing columns issues
            df_grid_monthly = df_grid_monthly[[
                col for col in all_columns if col in df_grid_monthly.columns
            ]]
            df_grid_monthly = df_grid_monthly.dropna()

            # Create geodata object
            geoData = mbm.geodata.GeoData(df_grid_monthly,
                                          months_head_pad=months_head_pad,
                                          months_tail_pad=months_tail_pad)

            # Computes and saves gridded MB for a year and glacier.
            # path_xr_grids already contains cfg.dataPath; joining cfg.dataPath again
            # duplicated the prefix whenever cfg.dataPath is a relative path.
            path_glacier_dem = os.path.join(path_xr_grids,
                                            f"{glacier_name}_{year}.zarr")
            geoData.gridded_MB_pred(
                df_grid_monthly,
                loaded_model,
                glacier_name,
                year,
                all_columns,
                path_glacier_dem,
                path_save_glw,
                save_monthly_pred=True,
                type_model='NN',
            )

# Quick visual check of one saved annual prediction grid.
glacier_name = 'aletsch'
year = 2008
# open xarray
xr.open_dataset(
    os.path.join(
        path_save_glw,
        f'{glacier_name}/{glacier_name}_{year}_annual.zarr')).pred_masked.plot(
            cmap='RdBu')

In [None]:
# Glaciers for which gridded predictions were saved (one sub-folder per glacier).
glaciers_in_glamos = os.listdir(path_save_glw)

# NOTE(review): the geodetic data and areas below repeat the "Extrapolate in space" cell,
# so this cell also works when run on its own.
geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers with geodetic MB data:
# Sort glaciers by area
gl_area = get_gl_area(cfg)
# Clariden as a whole is given the area of its lower part (claridenL).
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Order glaciers by increasing area and return them as a new list.

    A glacier with no entry in *gl_area* is treated as having zero area; ties
    preserve the original order because ``sorted`` is stable.
    """
    def area_of(name):
        return gl_area.get(name, 0)

    return sorted(glacier_list, key=area_of)


# Glaciers with both geodetic MB periods and saved predictions, smallest area first.
glacier_list = sort_by_area(
    [g for g in periods_per_glacier if g in glaciers_in_glamos], gl_area)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=path_save_glw,  # or another path if needed
    cfg=cfg)

# Drop rows where any required columns are NaN
df_all_nn = df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])
df_all_nn = df_all_nn.sort_values(by='Area')
df_all_nn['GLACIER'] = df_all_nn['GLACIER'].str.capitalize()

# Compute RMSE and Pearson correlation
rmse_nn = root_mean_squared_error(df_all_nn["Geodetic MB"],
                                  df_all_nn["MBM MB"])
corr_nn = np.corrcoef(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"])[0, 1]
# Report the summary metrics (previously computed but never shown).
print(f"RMSE (MBM vs geodetic): {rmse_nn:.3f}, Pearson r: {corr_nn:.3f}")

# NOTE(review): the 4th bin is (10, 100] but labelled '>10'; presumably acceptable given
# max_bins=4 — confirm against plot_mbm_vs_geodetic_by_area_bin.
plot_mbm_vs_geodetic_by_area_bin(df_all_nn,
                                 bins=[0, 1, 5, 10, 100, np.inf],
                                 labels=['<1', '1-5', '5–10', '>10', '>100'],
                                 max_bins=4)


## Permutation importance:

In [None]:
# Permutation feature importance on the test set: shuffle one feature at a time and
# record how much the RMSE rises relative to the unshuffled baseline.
rng = np.random.default_rng(cfg.seed)
importances = {feature: [] for feature in feature_columns}

# Compute baseline
_, scores_baseline, _, _ = evaluate_model_and_group_predictions(
    loaded_model, df_X_test_subset, test_set['y'], cfg, months_head_pad,
    months_tail_pad)

baseline_score = scores_baseline['rmse']
print(f"Baseline RMSE: {baseline_score:.4f}")

n_repeats = 10
for col in tqdm(feature_columns):
    column_values = df_X_test_subset[col].values
    for _ in range(n_repeats):
        df_permuted = df_X_test_subset.copy()
        df_permuted[col] = rng.permutation(column_values)

        # Evaluate model on permuted data; positive difference = worse performance.
        _, scores_perm, _, _ = evaluate_model_and_group_predictions(
            loaded_model, df_permuted, test_set['y'], cfg, months_head_pad,
            months_tail_pad)
        importances[col].append(scores_perm['rmse'] - baseline_score)

# Aggregate results (population std over the repeats)
mean_imp = [np.mean(importances[c]) for c in feature_columns]
std_imp = [np.std(importances[c]) for c in feature_columns]
df_importances = pd.DataFrame({
    "feature": feature_columns,
    "mean_importance": mean_imp,
    "std_importance": std_imp,
}).sort_values(by="mean_importance", ascending=False)
plot_permutation_importance(df_importances, top_n=20)

## Maps:

In [None]:
# Recompute the comparison table without the dropna/sort/capitalisation applied earlier,
# so GLACIER stays lower-case for the name filter in the next cell.
df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=path_save_glw,  # or another path if needed
    cfg=cfg)

In [None]:
# Load stake data ONCE instead of for every glacier
stake_file = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv,
                          "CH_wgms_dataset_all.csv")
df_stakes = pd.read_csv(stake_file)

# Example usage
GLACIER_NAME = 'silvretta'
df_nn = df_all_nn[df_all_nn.GLACIER == GLACIER_NAME]

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# NOTE(review): color_annual / color_winter are not defined in this notebook (earlier cells
# use color_dark_blue / color_pink); presumably they come from a scripts.* star import — confirm.
plot_scatter_comparison(axs[0],
                        df_nn,
                        GLACIER_NAME,
                        color_mbm=color_annual,
                        color_glamos=color_winter,
                        title_suffix="(NN)")

# Load GLAMOS data
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# Glacier-wide MBM predictions read from the saved grids.
MBM_glwmb_nn = mbm_glwd_pred(path_save_glw, GLACIER_NAME)
MBM_glwmb_nn.rename(columns={"MBM Balance": "MBM Balance NN"}, inplace=True)

# Merge with GLAMOS data
MBM_glwmb_nn = MBM_glwmb_nn.join(GLAMOS_glwmb)

# Drop NaN values to avoid plotting errors
MBM_glwmb_nn = MBM_glwmb_nn.dropna()

MBM_glwmb_nn.plot(ax=axs[1],
                  y=['MBM Balance NN', 'GLAMOS Balance'],
                  marker="o",
                  color=[color_annual, color_winter])

# NOTE(review): this loop applies the same title and the "Year" x-label to both panels,
# overriding whatever plot_scatter_comparison set on axs[0] — confirm that is intended.
for ax in axs:
    ax.set_title(f"{GLACIER_NAME.capitalize()} Glacier", fontsize=24)
    ax.set_ylabel("Mass Balance [m w.e.]", fontsize=18)
    ax.set_xlabel("Year", fontsize=18)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend(fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
# One comparison figure (GLAMOS vs NN distributed MB, with stakes) for every year that has
# both glacier-wide MBM and GLAMOS balances (rows kept after the dropna in the previous cell).
for year in MBM_glwmb_nn.index:
    plot_mass_balance_comparison_annual_glamos_nn(
        glacier_name=GLACIER_NAME,
        year=year,
        cfg=cfg,
        df_stakes=df_stakes,
        path_distributed_mb=path_distributed_MB_glamos,
        path_pred_nn=path_save_glw,
        get_glamos_func=get_GLAMOS_glwmb,
        get_pred_func=get_predicted_mb,
        get_glamos_pred_func=get_predicted_mb_glamos,
        load_grid_func=load_grid_file,
        to_wgs84_func=transform_xarray_coords_lv95_to_wgs84,
        apply_filter_func=apply_gaussian_filter,
        get_colormaps_func=get_color_maps)
