## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM
from sklearn.metrics import r2_score

import pickle
import pandas as pd
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import joypy
from skorch.callbacks import EarlyStopping, LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import torch.nn as nn
from skorch.helper import SliceDataset
from pandas.api.types import CategoricalDtype
from matplotlib.patches import Patch
import matplotlib.gridspec as gridspec

import matplotlib.pyplot as plt
import geopandas as gpd

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.nn_helpers import *
from scripts.geodata_plots import *
from scripts.NN_networks import *

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: load the autoreload extension and reload modules before each run.
%load_ext autoreload
%autoreload 2

# Configuration object used throughout the notebook (seed, data path, features).
cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed from the configured seed, then run the project's CUDA clean-up helper.
seed_all(cfg.seed)
free_up_cuda()

# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# Glacier outlines: SGI 2016 inventory shapefile and RGI outlines.
_sgi_inventory_dir = os.path.join(cfg.dataPath, path_SGI_topo,
                                  'inventory_sgi2016_r2020')
glacier_outline_sgi = gpd.read_file(
    os.path.join(_sgi_inventory_dir, 'SGI_2016_glaciers_copy.shp'))
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)

# Climate variables of interest, passed on to the data processing below.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables of interest.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

## Read GLAMOS data:

In [None]:
# Load the stake measurements and report how many samples each period has.
data_glamos = getStakesData(cfg)

n_glaciers = len(data_glamos['GLACIER'].unique())
print('-------------------')
print('Number of glaciers:', n_glaciers)
print('Number of winter and annual samples:', len(data_glamos))
for _label, _period in (('Number of annual samples:', 'annual'),
                        ('Number of winter samples:', 'winter')):
    print(_label, len(data_glamos[data_glamos.PERIOD == _period]))

In [None]:
# Compare TEST_GLACIERS with the glaciers that actually occur in the data.
existing_glaciers = set(data_glamos.GLACIER.unique())
missing_glaciers = [
    name for name in TEST_GLACIERS if name not in existing_glaciers
]

if len(missing_glaciers) > 0:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Every glacier in the data that is not reserved for testing is a train glacier.
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

In [None]:
# Measurements per glacier (largest first), as a frame indexed by GLACIER.
_n_per_glacier = data_glamos.groupby('GLACIER').size().sort_values(
    ascending=False)
glacier_info = _n_per_glacier.to_frame(name='Nb. measurements')

# Mean stake position per glacier.
glacier_loc = data_glamos.groupby('GLACIER')[['POINT_LAT', 'POINT_LON']].mean()
glacier_info = glacier_loc.merge(glacier_info, on='GLACIER')

# Sample counts per period, one column per PERIOD value (0 where absent).
_period_counts = data_glamos.groupby(['GLACIER', 'PERIOD']).size()
glacier_period = _period_counts.unstack().fillna(0).astype(int)
glacier_info = glacier_info.merge(glacier_period, on='GLACIER')

# Flag each glacier (the index label) as belonging to the test or train set.
glacier_info['Train/Test glacier'] = [
    'Test' if name in TEST_GLACIERS else 'Train'
    for name in glacier_info.index
]
glacier_info.head(2)

### Assign glaciers to river basin names:

In [None]:
# === Load RGI glacier IDs ===
rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids)
rgi_df.columns = rgi_df.columns.str.strip()  # header names may carry whitespace
rgi_df = rgi_df.sort_values(by='short_name').set_index('short_name')

# === Load SGI region geometries ===
_sgi_regions_file = os.path.join(cfg.dataPath, path_SGI_topo,
                                 'inventory_sgi2016_r2020',
                                 'sgi_regions.geojson')
SGI_regions = gpd.read_file(_sgi_regions_file)

# Strip surrounding whitespace from every object-typed column.
for _col in SGI_regions.select_dtypes(include='object').columns:
    SGI_regions[_col] = SGI_regions[_col].str.strip()

SGI_regions = SGI_regions.drop_duplicates().dropna()
SGI_regions = SGI_regions.set_index('pk_sgi_region')

# === Map to Level 0 river basins ===
# The first character of the SGI id selects the level-0 basin.
catchment_lv0 = dict(A='Rhine', B='Rhone', C='Po', D='Adige', E='Danube')
_basin_letter = rgi_df['sgi-id'].str[0]
rgi_df['rvr_lv0'] = _basin_letter.map(catchment_lv0)


# === Map to Level 1 river basins using SGI regions ===
def get_river_basin(sgi_id):
    """Return the level-1 river basin name for an SGI glacier id.

    The part of ``sgi_id`` before the first ``'-'`` is looked up in the index
    of the module-level ``SGI_regions`` table, and its ``'river_basin_name'``
    entry is returned.

    Args:
        sgi_id: SGI identifier string. Missing values (NaN/None) or any other
            non-string input yield ``None`` instead of raising.

    Returns:
        The basin name, or ``None`` when the id is not a string, the region
        key is unknown, or every matching entry is missing.
    """
    # Ids read from CSV can be NaN (a float); `.split` would raise on those.
    if not isinstance(sgi_id, str):
        return None
    key = sgi_id.split('-')[0]
    if key not in SGI_regions.index:
        return None
    basin = SGI_regions.loc[key, 'river_basin_name']
    if isinstance(basin, pd.Series):
        # Duplicate region keys: take the first non-missing name.
        candidates = basin.dropna()
        return None if candidates.empty else candidates.iloc[0]
    return basin if pd.notna(basin) else None


# Level-1 basin for every glacier, looked up from its SGI id.
rgi_df['rvr_lv1'] = rgi_df['sgi-id'].map(get_river_basin)

# Final formatting: the 'short_name' index becomes the 'GLACIER' index.
rgi_df = rgi_df.rename_axis('GLACIER')

# Attach both basin levels; glaciers without an RGI entry keep NaN (left join).
_basin_cols = rgi_df[['rvr_lv0', 'rvr_lv1']]
glacier_info = glacier_info.merge(_basin_cols, on='GLACIER', how='left')
glacier_info.head()

### Read input data:

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Input locations for the monthly transformation (run or load data):
_era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=_era5_dir + 'era5_monthly_averaged_data.nc',
    geopotential_data=_era5_dir + 'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# RUN is forwarded as `run_flag`; see process_or_load_data for its meaning.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_NN.csv')

# Month paddings derived from the monthly frame (used by the NN evaluation).
_pads = mbm.data_processing.utils.build_head_tail_pads_from_monthly_df(
    data_monthly)
months_head_pad, months_tail_pad = _pads

dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

In [None]:
# Split by glacier: TEST_GLACIERS form the test set, the rest the train set.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

n_test, n_train = len(test_set['df_X']), len(train_set['df_X'])

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# NOTE(review): this is the test/train ratio, not the test share of all samples.
test_perc = (n_test / n_train) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', n_test)
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', n_train)

In [None]:
# Validation and train split:
data_train = train_set['df_X']
# data_train is the same object as train_set['df_X'], so 'y' is added there too.
data_train['y'] = train_set['y']
dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# The iterators are empty once consumed, so materialise both index lists now.
train_indices = list(train_itr)
val_indices = list(val_itr)

# Positional selection of the training and validation rows with their targets.
df_X_train, df_X_val = data_train.iloc[train_indices], data_train.iloc[
    val_indices]
y_train = df_X_train['POINT_BALANCE'].values
y_val = df_X_val['POINT_BALANCE'].values

In [None]:
# Model inputs: elevation difference, pcsr, topographical and climate variables.
features_topo = ['ELEVATION_DIFFERENCE', 'pcsr', *vois_topographical]
feature_columns = [*features_topo, *vois_climate]
cfg.setFeatures(feature_columns)

all_columns = feature_columns + cfg.fieldsNotFeatures

# Because CH has some extra columns, we need to cut those
df_X_train_subset = df_X_train[all_columns]
df_X_val_subset = df_X_val[all_columns]
df_X_test_subset = test_set['df_X'][all_columns]

for _label, _frame in (('Shape of training dataset:', df_X_train_subset),
                       ('Shape of validation dataset:', df_X_val_subset),
                       ('Shape of testing dataset:', df_X_test_subset)):
    print(_label, _frame.shape)
print('Running with features:', feature_columns)

# Sanity check: the stored targets match the POINT_BALANCE column.
assert all(train_set['df_X'].POINT_BALANCE == train_set['y'])

## Models:
### XGBoost model:

In [None]:
# Fixed settings for the regressor (CPU device, 'hist' tree method).
param_init_xgb = dict(
    # device='cuda:0',
    device='cpu',
    tree_method='hist',
    random_state=cfg.seed,
    n_jobs=cfg.numJobs,
)

# Hyper-parameters applied on top of the fixed settings.
custom_params = dict(learning_rate=0.01, max_depth=6, n_estimators=800)

params = dict(param_init_xgb)
params.update(custom_params)
print(params)
custom_xgb_model = mbm.models.CustomXGBoostRegressor(cfg, **params)

# Fit on train data:
custom_xgb_model.fit(df_X_train_subset, y_train)

In [None]:
# Make predictions on test
custom_xgb_model = custom_xgb_model.set_params(device='cpu')
_split_features = mbm.data_processing.utils.create_features_metadata
features_test, metadata_test = _split_features(cfg,
                                               test_set['df_X'][all_columns])
y_pred = custom_xgb_model.predict(features_test)
print('Shape of the test:', features_test.shape)

# Make predictions aggr to meas ID:
y_pred_agg = custom_xgb_model.aggrPredict(metadata_test, features_test)

# Calculate scores (the model's score is negative, hence the absolute value).
score = custom_xgb_model.score(test_set['df_X'][all_columns], test_set['y'])
print('Overall score:', np.abs(score))
grouped_ids_xgb = getDfAggregatePred(test_set, y_pred_agg, all_columns)

_seasonal = compute_seasonal_scores(grouped_ids_xgb,
                                    target_col='target',
                                    pred_col='pred')
scores_annual_xgb, scores_winter_xgb = _seasonal

In [None]:
# Summary figure of the aggregated XGB test predictions.
_summary_kwargs = dict(grouped_ids=grouped_ids_xgb,
                       scores_annual=scores_annual_xgb,
                       scores_winter=scores_winter_xgb,
                       predVSTruth=predVSTruth,
                       plotMeanPred=plotMeanPred)
fig = plot_predictions_summary(**_summary_kwargs)

In [None]:
# One row per monthly XGB test prediction, with its metadata.
df_pred_xgb = pd.DataFrame(metadata_test, columns=cfg.metaData)
df_pred_xgb['pred'] = y_pred

# Collect predictions and month labels per measurement ID as full lists.
# `unique()` would drop repeated values (a second 'sep', or two identical
# predictions), which misaligns predictions with their months and makes the
# '<month>_' handling below unreachable.
pred_per_id = df_pred_xgb.groupby('ID')[['pred', 'MONTHS']].agg(list)
pred_per_id.reset_index(inplace=True)

# df_pred_months_annual = df_pred_months[df_pred_months['PERIOD'] == 'annual']
months_extended = [
    'sep', 'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr', 'may', 'jun',
    'jul', 'aug', 'sep_', 'oct_'
]

# Build one {month: prediction} mapping per measurement ID.
_rows = []
for _months, _preds in zip(pred_per_id['MONTHS'], pred_per_id['pred']):
    _month_to_pred = {}
    for month, pred in zip(_months, _preds):
        # A month name seen a second time within the same ID is stored under
        # '<month>_' (cf. 'sep_' / 'oct_' in months_extended).
        if month in _month_to_pred:
            month = month + '_'
        _month_to_pred[month] = pred
    _rows.append(_month_to_pred)

# Assemble the table once instead of concatenating inside the loop. Months
# missing for an ID become NaN; unexpected month labels are kept at the end.
df_months_xgb = pd.DataFrame(_rows)
_extra_cols = [c for c in df_months_xgb.columns if c not in months_extended]
df_months_xgb = df_months_xgb.reindex(columns=months_extended + _extra_cols)
df_months_xgb = df_months_xgb.dropna(axis=1, how='all')

In [None]:
# Same evaluation as for the test set, now on the training data.
# NOTE: y_pred, y_pred_agg and the seasonal score variables are overwritten.
_split_features = mbm.data_processing.utils.create_features_metadata
features_train, metadata_train = _split_features(cfg, data_train[all_columns])
y_pred = custom_xgb_model.predict(features_train)
y_pred_agg = custom_xgb_model.aggrPredict(metadata_train, features_train)

grouped_ids_xgb_train = getDfAggregatePred(train_set, y_pred_agg, all_columns)

_seasonal = compute_seasonal_scores(grouped_ids_xgb_train,
                                    target_col='target',
                                    pred_col='pred')
scores_annual_xgb, scores_winter_xgb = _seasonal

_axis_limits = (-14, 8)
fig = plot_predictions_summary(grouped_ids=grouped_ids_xgb_train,
                               scores_annual=scores_annual_xgb,
                               scores_winter=scores_winter_xgb,
                               predVSTruth=predVSTruth,
                               plotMeanPred=plotMeanPred,
                               ax_xlim=_axis_limits,
                               ax_ylim=_axis_limits)

In [None]:
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

# Mean elevation of the annual stakes per glacier, highest first.
_annual_stakes = data_glamos[data_glamos.PERIOD == 'annual']
gl_per_el = _annual_stakes.groupby(['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Train glaciers ordered from lowest to highest mean elevation.
train_gl_per_el = gl_per_el[train_glaciers].sort_values().index

fig, axs = plt.subplots(8, 3, figsize=(20, 30), sharex=False)

# One panel per train glacier, in the elevation order computed above.
PlotIndividualGlacierPredVsTruth(grouped_ids_xgb_train,
                                 axs=axs,
                                 color_annual=color_dark_blue,
                                 color_winter=color_pink,
                                 custom_order=train_gl_per_el,
                                 ax_xlim=None)


### NN model:

#### Initialize model:

In [None]:
# Module-level placeholders, initialized hereafter.
dataset = None
dataset_val = None


def my_train_split(ds, y=None, **fit_params):
    """Return the module-level ``(dataset, dataset_val)`` pair.

    All arguments are accepted and ignored; the pair is read from the
    globals at call time.
    """
    train_part, valid_part = dataset, dataset_val
    return train_part, valid_part


# --- features ---
features_topo = ['ELEVATION_DIFFERENCE', 'pcsr', *vois_topographical]
feature_columns = [*features_topo, *vois_climate]
cfg.setFeatures(feature_columns)

# Features plus the non-feature fields listed in the config.
all_columns = feature_columns + cfg.fieldsNotFeatures

# --- test subset (copied as-is; no NaN filtering here) ---
df_X_test_subset_MLP = test_set['df_X'][all_columns].copy()
y_test = test_set['y']

# --- load saved model (no callbacks/scheduler/early stop needed) ---
from pathlib import Path
import pickle

model_filename = "nn_model_2025-09-22.pt"  # adjust if needed
params_filename = "nn_params_2025-09-22.pkl"  # adjust if needed

# NOTE(security): unpickling runs arbitrary code; only load files you trust.
with open(f"models/{params_filename}", "rb") as f:
    params = pickle.load(f)

nInp = len(feature_columns)

args = {'module': FlexibleNetwork, 'nbFeatures': nInp, 'module__input_dim': nInp}
# Architecture settings are copied unchanged from the stored parameters.
args.update({
    name: params[name]
    for name in ('module__dropout', 'module__hidden_layers',
                 'module__use_batchnorm')
})
# the rest is irrelevant for inference but harmless if present:
args['batch_size'] = params.get('batch_size', 128)
args['verbose'] = 0

loaded_MLP = mbm.models.CustomNeuralNetRegressor.load_model(
    cfg,
    model_filename,
    **args,
    device='cpu',
).to('cpu')


#### Scatter plots:

##### Test:

In [None]:
# --- evaluate on test ---
_nn_test_eval = evaluate_model_and_group_predictions(loaded_MLP,
                                                     df_X_test_subset_MLP,
                                                     y_test, cfg,
                                                     months_head_pad,
                                                     months_tail_pad)
grouped_ids_NN, scores_NN, ids_NN, y_pred_NN = _nn_test_eval

print("Test scores:", scores_NN)

In [None]:
# Seasonal scores of the NN test predictions, followed by the summary figure.
_seasonal = compute_seasonal_scores(grouped_ids_NN,
                                    target_col='target',
                                    pred_col='pred')
scores_annual_NN, scores_winter_NN = _seasonal

_summary_kwargs = dict(grouped_ids=grouped_ids_NN,
                       scores_annual=scores_annual_NN,
                       scores_winter=scores_winter_NN,
                       predVSTruth=predVSTruth,
                       plotMeanPred=plotMeanPred)
fig = plot_predictions_summary(**_summary_kwargs)

##### Train set:

In [None]:
# Evaluate the loaded network on the training data.
# NOTE: scores_annual_NN / scores_winter_NN now hold the train scores.
_nn_train_eval = evaluate_model_and_group_predictions(
    loaded_MLP, data_train[all_columns], data_train['POINT_BALANCE'].values,
    cfg, months_head_pad, months_tail_pad)
grouped_ids_NN_train, scores_NN_train, ids_NN_train, y_pred_NN_train = _nn_train_eval

_seasonal = compute_seasonal_scores(grouped_ids_NN_train,
                                    target_col='target',
                                    pred_col='pred')
scores_annual_NN, scores_winter_NN = _seasonal

_axis_limits = (-14, 8)
fig = plot_predictions_summary(grouped_ids=grouped_ids_NN_train,
                               scores_annual=scores_annual_NN,
                               scores_winter=scores_winter_NN,
                               predVSTruth=predVSTruth,
                               plotMeanPred=plotMeanPred,
                               ax_xlim=_axis_limits,
                               ax_ylim=_axis_limits)


In [None]:
# Mean elevation of the annual stakes per glacier, highest first.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered from lowest to highest mean elevation (computed once).
elvs = gl_per_el[TEST_GLACIERS].sort_values()
test_gl_per_el = elvs.index

fig, axs = plt.subplots(3, 3, figsize=(20, 15))

PlotIndividualGlacierPredVsTruth(grouped_ids_NN,
                                 color_annual=color_dark_blue,
                                 color_winter=color_pink,
                                 axs=axs,
                                 custom_order=test_gl_per_el)

# Title each panel with the glacier and its mean elevation. zip stops at the
# shorter sequence, so fewer than nine test glaciers no longer raises
# IndexError on the unused panels.
for ax, (test_gl, el) in zip(axs.flatten(), elvs.items()):
    ax.set_title(f'{test_gl.capitalize()}: {round(el, 2)} m', fontsize=24)

#### Look for differences between test glaciers:

In [None]:
f, axs = plt.subplots(4, 4, figsize=(15, 15))

# Target, elevation difference, year and all model inputs: 16 variables,
# one per panel of the 4x4 grid.
vois = ['POINT_BALANCE', 'ELEVATION_DIFFERENCE', 'YEAR'
        ] + vois_climate + vois_topographical + ['pcsr']

axs = axs.flatten()

# Density histogram (with KDE) per variable, coloured by test glacier.
for ax, voi in zip(axs, vois):
    ax.set_title(voi, fontsize=12)
    sns.histplot(df_X_test_subset,
                 x=voi,
                 ax=ax,
                 alpha=0.6,
                 hue='GLACIER',
                 stat='density',
                 kde=True,
                 fill=False)
    ax.set_xlabel('')
    # Remove the per-panel legend; guard against panels that have none.
    if ax.legend_ is not None:
        ax.legend_.remove()

plt.tight_layout()
plt.show()


In [None]:
# Training rows of a single glacier.
_is_aletsch = df_X_train_subset.GLACIER == 'aletsch'
df_X_train_subset_gl = df_X_train_subset[_is_aletsch]

f, axs = plt.subplots(5, 4, figsize=(15, 15))

vois = ['POINT_BALANCE', 'ELEVATION_DIFFERENCE', 'YEAR'
        ] + vois_climate + vois_topographical + ['pcsr']

axs = axs.flatten()

# Density histogram (with KDE) per variable; no hue, hence no legends to remove.
_hist_style = dict(alpha=0.6, stat='density', kde=True, fill=False)
for i, voi in enumerate(vois):
    _panel = axs[i]
    _panel.set_title(voi, fontsize=12)
    g = sns.histplot(df_X_train_subset_gl, x=voi, ax=_panel, **_hist_style)
    _panel.set_xlabel('')

plt.tight_layout()
plt.show()


## Compare models:

### Scatter on test glaciers:

In [None]:
def _scores_annotation(scores_annual, scores_winter):
    """Three-line text with annual/winter RMSE, R2 and bias."""
    return "\n".join((
        (r"$\mathrm{RMSE_a}=%.3f$, $\mathrm{RMSE_w}=%.3f$," %
         (scores_annual["rmse"], scores_winter["rmse"])),
        (r"$\mathrm{R^2_a}=%.3f$, $\mathrm{R^2_w}=%.3f$" %
         (scores_annual["R2"], scores_winter["R2"])),
        r"$\mathrm{B_a}=%.3f$, $\mathrm{B_w}=%.3f$" %
        (scores_annual["Bias"], scores_winter["Bias"]),
    ))


def _annotate_scores(ax, text):
    """Write `text` in the upper-left corner of `ax` (axes coordinates)."""
    ax.text(0.03,
            0.96,
            text,
            transform=ax.transAxes,
            verticalalignment="top",
            fontsize=18,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0))


# Recompute the seasonal test scores of both models.
scores_annual_NN, scores_winter_NN = compute_seasonal_scores(
    grouped_ids_NN, target_col='target', pred_col='pred')
scores_annual_xgb, scores_winter_xgb = compute_seasonal_scores(
    grouped_ids_xgb, target_col='target', pred_col='pred')

fig = plt.figure(figsize=(15, 5))
_period_palette = [color_dark_blue, color_pink]

# Right panel: XGB.
ax1 = plt.subplot(1, 2, 2)
ax1.set_title('XGB predictions', fontsize=20)
predVSTruth(ax1,
            grouped_ids_xgb,
            scores_annual_xgb,
            hue='PERIOD',
            palette=_period_palette)
legend_xgb = _scores_annotation(scores_annual_xgb, scores_winter_xgb)
_annotate_scores(ax1, legend_xgb)

# Left panel: NN.
ax2 = plt.subplot(1, 2, 1)
ax2.set_title('NN predictions', fontsize=20)
predVSTruth(ax2,
            grouped_ids_NN,
            scores_annual_NN,
            hue='PERIOD',
            palette=_period_palette)
legend_NN = _scores_annotation(scores_annual_NN, scores_winter_NN)
_annotate_scores(ax2, legend_NN)

plt.tight_layout()

### Spatial reconstruction:

In [None]:
# Folder with one sub-folder per glacier holding '<glacier>_<year>_<mon>.zarr'.
path_save_glw = os.path.join(
    cfg.dataPath, 'GLAMOS', 'distributed_MB_grids',
    'MBM/testing_combis/glamos_dems_NN_SEB_full_OGGM')

glaciers = os.listdir(path_save_glw)
hydro_months = [
    'sep', 'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr', 'may', 'jun',
    'jul', 'aug'
]
# Initialize final storage for all glacier data
all_glacier_data = []

# Loop over glaciers
for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(path_save_glw, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # File-name pattern for this glacier; the name is escaped so characters
    # that are special in a regex are matched literally.
    pattern = re.compile(
        rf'{re.escape(glacier_name)}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Years for which at least one monthly file exists.
    years = sorted({
        int(found.group(1))
        for found in map(pattern.match, os.listdir(glacier_path)) if found
    })

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        elevation = None
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            # Open inside a context manager so the store is closed again;
            # to_dataframe() loads the values before the dataset is closed.
            with xr.open_dataset(zarr_path) as ds:
                df_pred = ds.pred_masked.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()
                df_elev = ds.masked_elev.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()

            # Keep only the cells that have a prediction.
            valid = df_pred.pred_masked.notna()
            monthly_data[month] = df_pred.pred_masked[valid].values
            # As before, the elevation of the last available month is used.
            elevation = df_elev.masked_elev[valid].values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = elevation
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

# Final full DataFrame for all glaciers
df_months_NN = pd.concat(all_glacier_data, axis=0, ignore_index=True)
df_months_NN

In [None]:
# Folder with the XGB grids, one sub-folder per glacier.
path_save_glw = cfg.dataPath + '/GLAMOS/distributed_MB_grids/MBM/glamos_dems/'

hydro_months = [
    'sep', 'oct', 'nov', 'dec', 'jan', 'feb', 'mar', 'apr', 'may', 'jun',
    'jul', 'aug'
]
# Initialize final storage for all glacier data
all_glacier_data = []

# Loop over glaciers.
# NOTE: `glaciers` is the listing of the NN output folder from the previous
# cell; names without a folder here are skipped by the isdir check.
for glacier_name in tqdm(glaciers):
    glacier_path = os.path.join(path_save_glw, glacier_name)
    if not os.path.isdir(glacier_path):
        continue  # skip non-directories

    # File-name pattern for this glacier; the name is escaped so characters
    # that are special in a regex are matched literally.
    pattern = re.compile(
        rf'{re.escape(glacier_name)}_(\d{{4}})_[a-z]{{3}}\.zarr')

    # Years for which at least one monthly file exists.
    years = sorted({
        int(found.group(1))
        for found in map(pattern.match, os.listdir(glacier_path)) if found
    })

    # Collect all year-month data
    all_years_data = []
    for year in years:
        monthly_data = {}
        elevation = None
        for month in hydro_months:
            zarr_path = os.path.join(glacier_path,
                                     f'{glacier_name}_{year}_{month}.zarr')
            if not os.path.exists(zarr_path):
                continue

            # Open inside a context manager so the store is closed again;
            # to_dataframe() loads the values before the dataset is closed.
            with xr.open_dataset(zarr_path) as ds:
                df_pred = ds.pred_masked.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()
                df_elev = ds.masked_elev.to_dataframe().drop(
                    ['x', 'y'], axis=1).reset_index()

            # Keep only the cells that have a prediction.
            valid = df_pred.pred_masked.notna()
            monthly_data[month] = df_pred.pred_masked[valid].values
            # As before, the elevation of the last available month is used.
            elevation = df_elev.masked_elev[valid].values

        if monthly_data:
            df_months = pd.DataFrame(monthly_data)
            df_months['year'] = year
            df_months['glacier'] = glacier_name  # add glacier name
            df_months['elevation'] = elevation
            all_years_data.append(df_months)

    # Concatenate this glacier's data
    if all_years_data:
        df_glacier = pd.concat(all_years_data, axis=0, ignore_index=True)
        all_glacier_data.append(df_glacier)

# Final full DataFrame for all glaciers
df_months_XGB = pd.concat(all_glacier_data, axis=0, ignore_index=True)
df_months_XGB

In [None]:
# Rows per (glacier, year) in each model's table.
_keys = ['glacier', 'year']
counts_XGB = df_months_XGB.groupby(_keys).size().rename('count_XGB')
counts_NN = df_months_NN.groupby(_keys).size().rename('count_NN')

# Outer join on the (glacier, year) index; a pair missing on one side counts 0.
comparison = counts_XGB.to_frame().join(counts_NN, how='outer')
comparison = comparison.fillna(0).astype(int)

# Keep only the pairs whose counts disagree between the two models.
_differs = comparison['count_XGB'].ne(comparison['count_NN'])
discrepancies = comparison.loc[_differs]

# Show the result
discrepancies.reset_index()

#### Glacier-wide:

In [None]:
# Glacier-wide values: mean over all grid cells per glacier and year.
_keys = ['glacier', 'year']
glwd_months_NN = df_months_NN.groupby(_keys, as_index=False).mean()
glwd_months_XGB = df_months_XGB.groupby(_keys, as_index=False).mean()
glwd_months_XGB

In [None]:
# Calendar order used for the categorical 'Month' column.
month_order = [
    'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct',
    'nov', 'dec'
]
cat_month = CategoricalDtype(month_order, ordered=True)

# NOTE: `df_months_xgb` is re-bound here to the glacier-wide monthly columns.
df_months_xgb = glwd_months_XGB[month_order]
df_months_nn = glwd_months_NN[month_order]

# One array per month for each model, plus a matching array of month labels.
array_nn = [df_months_nn[col].values for col in df_months_nn.columns]
array_xgb = [df_months_xgb[col].values for col in df_months_nn.columns]
months = [
    np.tile(col, len(df_months_nn[col])) for col in df_months_nn.columns
]

# Stack the months below each other into one long table.
df_months_nn_long = pd.DataFrame(
    data={
        'mb_nn': np.concatenate(array_nn),
        'mb_xgb': np.concatenate(array_xgb),
        'Month': np.concatenate(months)
    })

# order df_months_nn_long
df_months_nn_long['Month'] = df_months_nn_long['Month'].astype(cat_month)
df_months_nn_long

In [None]:
# Line colours for the two models (order matches `column` below: XGB, NN).
model_colors = [color_annual, color_winter]
alpha = 1

# Centimetre-to-inch factor for figsize. It is deliberately not called `cm`:
# that name is the cmcrameri colormap module imported at the top of the
# notebook, and shadowing it breaks `cm.batlow` when earlier cells are re-run.
cm_to_inch = 1 / 2.54
# joypy.joyplot returns (figure, list of axes), in that order.
fig, axes = joypy.joyplot(df_months_nn_long,
                          by='Month',
                          column=['mb_xgb', 'mb_nn'],
                          alpha=0.8,
                          overlap=0,
                          fill=False,
                          linewidth=1.5,
                          xlabelsize=8.5,
                          ylabelsize=8.5,
                          x_range=[-2.2, 2.2],
                          grid=False,
                          color=model_colors,
                          figsize=(12 * cm_to_inch, 14 * cm_to_inch),
                          ylim='own')

# Vertical reference line at zero mass balance.
vline_alpha = 0.5
plt.axvline(x=0, color='grey', alpha=vline_alpha, linewidth=1)

plt.xlabel('Mass balance (m w.e.)', fontsize=8.5)
plt.yticks(ticks=range(1, 13), labels=month_order, fontsize=8.5)
plt.gca().set_yticklabels(month_order)

# Manual legend: one patch per model.
legend_patches = [
    Patch(facecolor=color, label=model, alpha=alpha, edgecolor='k')
    for model, color in zip(['XGB', 'NN'], model_colors)
]
plt.legend(handles=legend_patches,
           loc='upper center',
           bbox_to_anchor=(0.48, -0.1),
           ncol=4,
           fontsize=8.5,
           handletextpad=0.5,
           columnspacing=1)

#### Geodetic MB:

In [None]:
# Folders holding the distributed predictions of each model.
_nn_subdir = 'MBM/testing_combis/glamos_dems_NN_SEB_full_OGGM'
PATH_PREDICTIONS_NN = os.path.join(cfg.dataPath, 'GLAMOS',
                                   'distributed_MB_grids', _nn_subdir)
PATH_PREDICTIONS_XGB = (cfg.dataPath + path_distributed_MB_glamos +
                        'MBM/glamos_dems/')

In [None]:
# Glaciers for which NN predictions exist on disk.
glaciers_in_glamos = os.listdir(PATH_PREDICTIONS_NN)

geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier: unique start ('Astart') and end ('Aend') values
_by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = _by_glacier['Astart'].unique().apply(list).to_dict()
years_end_per_gl = _by_glacier['Aend'].unique().apply(list).to_dict()

_periods_and_mb = build_periods_per_glacier(geodetic_mb)
periods_per_glacier, geoMB_per_glacier = _periods_and_mb

# Glacier areas, used below to sort the glaciers.
gl_area = get_gl_area(cfg)
# 'clariden' reuses the area stored under 'claridenL'.
gl_area['clariden'] = gl_area['claridenL']

# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glacier names sorted by ascending area.

    Areas are read with ``gl_area.get(name, 0)``, so names without an entry
    count as area 0. The input list is not modified.
    """

    def _area_of(name):
        return gl_area.get(name, 0)

    ordered = list(glacier_list)
    ordered.sort(key=_area_of)
    return ordered


# Glaciers that have geodetic periods and NN predictions, smallest area first.
_with_predictions = [
    name for name in list(periods_per_glacier.keys())
    if name in glaciers_in_glamos
]
glacier_list = sort_by_area(_with_predictions, gl_area)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

In [None]:
# Arguments shared by both model comparisons.
_geodetic_common = dict(
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    cfg=cfg,
)

# NOTE(review): XGB is run on every entry of its prediction folder, whereas NN
# uses the filtered, area-sorted `glacier_list`; confirm this is intended.
df_all_xgb = process_geodetic_mass_balance_comparison(
    glacier_list=os.listdir(PATH_PREDICTIONS_XGB),
    path_predictions=PATH_PREDICTIONS_XGB,
    **_geodetic_common)

df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_predictions=PATH_PREDICTIONS_NN,
    **_geodetic_common)


In [None]:
# Drop rows where any required columns are NaN
df_all_nn = df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])

# RMSE and Pearson correlation between geodetic and modelled mass balance.
# The RMSE is the square root of the MSE: the `squared=False` argument of
# mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.
rmse_nn = np.sqrt(
    mean_squared_error(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"]))
corr_nn = np.corrcoef(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"])[0, 1]

# Drop rows where any required columns are NaN
df_all_xgb = df_all_xgb.dropna(subset=['Geodetic MB', 'MBM MB'])

# Same metrics for XGB.
rmse_xgb = np.sqrt(
    mean_squared_error(df_all_xgb["Geodetic MB"], df_all_xgb["MBM MB"]))
corr_xgb = np.corrcoef(df_all_xgb["Geodetic MB"], df_all_xgb["MBM MB"])[0, 1]

# Define figure and axes (NN on the left, XGB on the right)
fig, axs = plt.subplots(1, 2, figsize=(10, 8), sharex=True, sharey=True)

# Plot MBM MB vs Geodetic MB
plot_scatter_geodetic_MB(df_all_xgb, 'GLACIER', True, axs[1], "MBM MB",
                         rmse_xgb, corr_xgb)

axs[1].set_title('XGB predictions', fontsize=20)
axs[1].set_ylabel('Predicted mass balance [m w.e.]', fontsize=18)
axs[1].set_xlabel('Geodetic mass balance [m w.e.]', fontsize=18)

plot_scatter_geodetic_MB(df_all_nn, 'GLACIER', True, axs[0], "MBM MB", rmse_nn,
                         corr_nn)
axs[0].set_ylabel('Predicted mass balance [m w.e.]', fontsize=18)
axs[0].set_xlabel('Geodetic mass balance [m w.e.]', fontsize=18)
axs[0].set_title('NN predictions', fontsize=20)

# Adjust legend outside of plot (handles are taken from the XGB panel)
handles, labels = axs[1].get_legend_handles_labels()
fig.legend(handles,
           labels,
           bbox_to_anchor=(1.05, 1),
           loc="upper left",
           borderaxespad=0.,
           ncol=2,
           fontsize=20)

plt.tight_layout()
plt.show()

### Look at maps:

In [None]:
# Read the stake table a single time; the plotting cells below reuse it.
_stake_dir = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv)
stake_file = os.path.join(_stake_dir, "CH_wgms_dataset_all.csv")
df_stakes = pd.read_csv(stake_file)

In [None]:
# Example usage: geodetic comparison rows of one glacier for both models.
GLACIER_NAME = 'aletsch'
df_xgb = df_all_xgb[df_all_xgb.GLACIER == GLACIER_NAME]
df_nn = df_all_nn[df_all_nn.GLACIER == GLACIER_NAME]

fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)

# Left panel: XGB, right panel: NN.
for _panel, _frame, _suffix in ((axs[0], df_xgb, "(XGB)"),
                                (axs[1], df_nn, "(NN)")):
    plot_scatter_comparison(_panel,
                            _frame,
                            GLACIER_NAME,
                            color_mbm=color_annual,
                            color_glamos=color_winter,
                            title_suffix=_suffix)

plt.tight_layout()
plt.show()

In [None]:
# Load GLAMOS data
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# Glacier-wide predictions of both models, each with a model-specific column.
MBM_glwmb_nn = mbm_glwd_pred(PATH_PREDICTIONS_NN, GLACIER_NAME)
MBM_glwmb_nn.rename(columns={"MBM Balance": "MBM Balance NN"}, inplace=True)
MBM_glwmb_xgb = mbm_glwd_pred(PATH_PREDICTIONS_XGB, GLACIER_NAME)
MBM_glwmb_xgb.rename(columns={"MBM Balance": "MBM Balance XGB"}, inplace=True)

# Merge with GLAMOS data and drop NaN values to avoid plotting errors.
MBM_glwmb_nn = MBM_glwmb_nn.join(GLAMOS_glwmb).dropna()

MBM_glwmb = MBM_glwmb_nn.join(MBM_glwmb_xgb)

# Plot the data: XGB on the left, NN on the right, each against GLAMOS.
fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
_panels = (
    (axs[0], 'MBM Balance XGB',
     f"{GLACIER_NAME.capitalize()} Glacier (XGB)"),
    (axs[1], 'MBM Balance NN', f"{GLACIER_NAME.capitalize()} Glacier (NN)"),
)
for ax, _model_col, _title in _panels:
    MBM_glwmb.plot(ax=ax,
                   y=[_model_col, 'GLAMOS Balance'],
                   marker="o",
                   color=[color_annual, color_winter])
    ax.set_ylabel("Mass Balance [m w.e.]", fontsize=18)
    ax.set_xlabel("Year", fontsize=18)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend(fontsize=14)
    ax.set_title(_title, fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
# Arguments that stay the same for every year.
_annual_plot_kwargs = dict(
    glacier_name=GLACIER_NAME,
    cfg=cfg,
    df_stakes=df_stakes,
    path_distributed_mb=path_distributed_MB_glamos,
    path_pred_xgb=PATH_PREDICTIONS_XGB,
    path_pred_nn=PATH_PREDICTIONS_NN,
    get_glamos_func=get_GLAMOS_glwmb,
    get_pred_func=get_predicted_mb,
    get_glamos_pred_func=get_predicted_mb_glamos,
    load_grid_func=load_grid_file,
    to_wgs84_func=transform_xarray_coords_lv95_to_wgs84,
    apply_filter_func=apply_gaussian_filter,
    get_colormaps_func=get_color_maps,
)

# One comparison figure per year in the index of the NN/GLAMOS table.
for year in MBM_glwmb_nn.index:
    plot_mass_balance_comparison_annual(year=year, **_annual_plot_kwargs)


In [None]:
glacier_name = 'gorner'
# NOTE(review): `year` is not used in this cell; the grid file loaded below is
# the 2023 one. Confirm which year is intended.
year = 2008

# Stake coordinates of the selected glacier.
stake_loc = df_stakes[df_stakes.GLACIER == glacier_name][[
    'POINT_LAT', 'POINT_LON'
]]

# Build the grid path from `glacier_name` so that changing the glacier above
# also changes the grid that is loaded (it used to be hard-coded to gorner).
df_sgi_grid = pd.read_parquet(
    cfg.dataPath +
    f'../data/GLAMOS/topo/gridded_topo_inputs/SGI_grid/{glacier_name}/{glacier_name}_grid_2023.parquet'
)

# Grid points in black, stake locations on top in red.
sns.scatterplot(df_sgi_grid, x='POINT_LON', y='POINT_LAT', s=1, color='black')
sns.scatterplot(stake_loc,
                x='POINT_LON',
                y='POINT_LAT',
                s=20,
                color='red',
                label='Stake location')