## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging

from skorch.callbacks import EarlyStopping, LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
import torch.nn as nn
from skorch.helper import SliceDataset

import matplotlib.pyplot as plt
import geopandas as gpd

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.nn_helpers import *
from scripts.geodata_plots import *

# Silence every warning to keep the notebook output readable.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used in the cells below.
warnings.filterwarnings('ignore')
# IPython magics: load the autoreload extension in mode 2.
%load_ext autoreload
%autoreload 2

# Switzerland-specific MBM configuration; the cells below read
# cfg.dataPath, cfg.seed, cfg.numJobs, cfg.metaData, ... from it.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Helpers from the scripts package: seed with the configured seed and
# (presumably) release cached GPU memory -- TODO confirm in scripts.helpers.
seed_all(cfg.seed)
free_up_cuda()

# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
# Ten hex colours derived from the batlow colour map; the first one is used
# as "dark blue" in the comparison plots further down.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
color_pink = '#c51b7d'

# Glacier outlines:
glacier_outline_sgi = gpd.read_file(
    os.path.join(cfg.dataPath, path_SGI_topo, 'inventory_sgi2016_r2020',
                 'SGI_2016_glaciers_copy.shp'))  # Load the shapefile
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)

# Glaciers held out for testing (passed as `test_splits` to get_CV_splits).
test_glaciers = [
    'tortin', 'plattalva', 'sanktanna', 'schwarzberg', 'hohlaub', 'pizol',
    'corvatsch', 'tsanfleuron', 'forno'
]

# Climate variables of interest, used as model features.
# NOTE(review): names look like ERA5 short names -- confirm.
vois_climate = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'u10', 'v10'
]

# Topographical variables of interest, used as model features.
vois_topographical = [
    "aspect_sgi",
    "slope_sgi",
    "hugonnet_dhdt",
    "consensus_ice_thickness",
    "millan_v",
]

## Read GL data:

In [None]:
# Load the GLAMOS point measurements and drop every row containing a NaN.
data_glamos = pd.read_csv(cfg.dataPath + path_PMB_GLAMOS_csv +
                          'CH_wgms_dataset_all.csv')
data_glamos = data_glamos.dropna()

# Glaciers with data of potential clear sky radiation
# Format to same names as stakes:
# Only directory entries named 'xr_direct_<glacier>.zarr' are used. Anything
# else in the folder (hidden files, logs, ...) is skipped instead of raising
# AttributeError on the failed regex match.
pcsr_pattern = re.compile(r'xr_direct_(.*?)\.zarr')
pcsr_matches = (pcsr_pattern.search(f)
                for f in os.listdir(cfg.dataPath + path_pcsr + 'zarr/'))
glDirect = np.sort([m.group(1) for m in pcsr_matches if m is not None])

# NOTE(review): `Diff` comes from the scripts helpers; presumably this is the
# set of glaciers present in only one of the two lists -- confirm.
restgl = np.sort(Diff(list(glDirect), list(data_glamos.GLACIER.unique())))

# Filter out glaciers without data:
data_glamos = data_glamos[data_glamos.GLACIER.isin(glDirect)]

print('Number of glaciers:', len(data_glamos['GLACIER'].unique()))
print('Number of winter and annual samples:', len(data_glamos))
print('Number of annual samples:',
      len(data_glamos[data_glamos.PERIOD == 'annual']))
print('Number of winter samples:',
      len(data_glamos[data_glamos.PERIOD == 'winter']))

data_glamos.head(2)

In [None]:
# Measurement count per glacier (largest first), indexed by glacier name.
n_measurements = data_glamos.groupby('GLACIER').size().sort_values(
    ascending=False)
glacier_info = n_measurements.rename('Nb. measurements').to_frame()

# Mean stake position per glacier.
stake_coords = ['POINT_LAT', 'POINT_LON']
glacier_loc = data_glamos.groupby('GLACIER')[stake_coords].mean()
glacier_info = glacier_loc.merge(glacier_info, on='GLACIER')

# Sample count per glacier and measurement period (0 where a period is absent).
period_counts = data_glamos.groupby(['GLACIER', 'PERIOD']).size()
glacier_period = period_counts.unstack().fillna(0).astype(int)
glacier_info = glacier_info.merge(glacier_period, on='GLACIER')


def _split_label(row):
    """Return 'Test' when the row label is one of the test glaciers."""
    if row.name in test_glaciers:
        return 'Test'
    return 'Train'


glacier_info['Train/Test glacier'] = glacier_info.apply(_split_label, axis=1)
glacier_info.head(2)

### Assign glaciers to river basin names:

In [None]:
# === Load RGI glacier IDs ===
# Column names are stripped of surrounding whitespace before use.
rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids)
rgi_df.columns = rgi_df.columns.str.strip()
rgi_df = rgi_df.sort_values(by='short_name')
rgi_df = rgi_df.set_index('short_name')

# === Load SGI region geometries ===
sgi_regions_file = os.path.join(cfg.dataPath, path_SGI_topo,
                                'inventory_sgi2016_r2020',
                                'sgi_regions.geojson')
SGI_regions = gpd.read_file(sgi_regions_file)

# Clean object columns
text_columns = SGI_regions.select_dtypes(include='object').columns
SGI_regions[text_columns] = SGI_regions.select_dtypes(
    include='object').apply(lambda col: col.str.strip())

# Remove duplicated / incomplete rows, then index by the SGI region key.
SGI_regions = SGI_regions.drop_duplicates().dropna().set_index('pk_sgi_region')

# === Map to Level 0 river basins ===
# The first character of the SGI id selects the basin.
catchment_lv0 = dict(A='Rhine', B='Rhone', C='Po', D='Adige', E='Danube')
sgi_first_letter = rgi_df['sgi-id'].str[0]
rgi_df['rvr_lv0'] = sgi_first_letter.map(catchment_lv0)


# === Map to Level 1 river basins using SGI regions ===
def get_river_basin(sgi_id):
    """Return the level-1 river basin name for an SGI glacier id.

    The part of ``sgi_id`` before the first ``'-'`` is looked up in the index
    of the module-level ``SGI_regions`` table and the matching
    ``river_basin_name`` is returned.

    Args:
        sgi_id: SGI identifier such as ``'A10-05'``. Missing values
            (``None``/NaN) and other non-string inputs are accepted.

    Returns:
        The basin name, or ``None`` when the id is not a string, the region
        key is unknown, or no non-null basin name is recorded for it.
    """
    # A missing 'sgi-id' arrives as NaN/None; without this guard `.split`
    # would raise AttributeError and abort the whole `.apply`.
    if not isinstance(sgi_id, str):
        return None
    key = sgi_id.split('-')[0]
    if key not in SGI_regions.index:
        return None
    basin = SGI_regions.loc[key, 'river_basin_name']
    # A duplicated region key makes `.loc` return a Series instead of a scalar.
    if isinstance(basin, pd.Series):
        names = basin.dropna().unique()
        return names[0] if len(names) else None
    return basin if pd.notna(basin) else None


# Level-1 basin for every glacier, looked up element-wise from its SGI id.
rgi_df['rvr_lv1'] = rgi_df['sgi-id'].map(get_river_basin)

# Final formatting: the index (glacier short name) is labelled 'GLACIER' so
# it lines up with the other per-glacier tables.
rgi_df = rgi_df.rename_axis('GLACIER')
rgi_df.head()

In [None]:
# Attach the level-0 / level-1 river basin names to the per-glacier table.
# A left join keeps every glacier of glacier_info, including those without a
# matching entry in rgi_df.
glacier_info = glacier_info.merge(rgi_df[['rvr_lv0', 'rvr_lv1']],
                                  on='GLACIER',
                                  how='left')
glacier_info.head()

## Models:
### XGBoost model:

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = True
dataloader_gl = process_or_load_data(run_flag=RUN,
                                     data_glamos=data_glamos,
                                     paths=paths,
                                     cfg=cfg,
                                     vois_climate=vois_climate,
                                     vois_topographical=vois_topographical)
data_monthly = dataloader_gl.data

In [None]:
# Fixed XGBoost settings (device, tree method, seed, parallelism); the tuned
# hyper-parameters are merged in later.
param_init_xgb = {
    'device': 'cuda:0',
    'tree_method': 'hist',
    "random_state": cfg.seed,
    "n_jobs": cfg.numJobs
}

# Split by glacier, with `test_glaciers` passed as the test split.
# NOTE(review): `splits` presumably holds the cross-validation folds -- confirm.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=test_glaciers,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# NOTE(review): this is the test size relative to the TRAIN size
# (test / train), not the test share of the whole dataset -- confirm that
# this is the number that should be reported.
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

In [None]:
# Hyper-parameters merged on top of param_init_xgb below.
custom_params = {'learning_rate': 0.01, 'max_depth': 6, 'n_estimators': 800}

# Feature columns:
feature_columns = [
    'ELEVATION_DIFFERENCE'
] + list(vois_climate) + list(vois_topographical) + ['pcsr']
# Alternative feature set without 'pcsr':
# feature_columns = ['ELEVATION_DIFFERENCE'
#                    ] + list(vois_climate) + list(vois_topographical)
# The model input also carries the non-feature fields listed in the config.
all_columns = feature_columns + cfg.fieldsNotFeatures
df_X_train_subset = train_set['df_X'][all_columns]
print('Shape of training dataset:', df_X_train_subset.shape)
print('Shape of testing dataset:', test_set['df_X'][all_columns].shape)
print('Running with features:', feature_columns)

# Keys in custom_params win over param_init_xgb on conflict.
params = {**param_init_xgb, **custom_params}
print(params)
custom_xgb_model = mbm.models.CustomXGBoostRegressor(cfg, **params)

# Fit on train data:
custom_xgb_model.fit(train_set['df_X'][all_columns], train_set['y'])

In [None]:
# Make predictions on test
custom_xgb_model = custom_xgb_model.set_params(device='cpu')
features_test, metadata_test = custom_xgb_model._create_features_metadata(
    test_set['df_X'][all_columns])
y_pred = custom_xgb_model.predict(features_test)
print('Shape of the test:', features_test.shape)

# Make predictions aggr to meas ID:
y_pred_agg = custom_xgb_model.aggrPredict(metadata_test, features_test)

# Calculate scores
score = custom_xgb_model.score(test_set['df_X'][all_columns],
                               test_set['y'])  # negative
print('Overall score:', np.abs(score))

grouped_ids_xgb = getDfAggregatePred(test_set, y_pred_agg, all_columns)
PlotPredictions(grouped_ids_xgb, y_pred, metadata_test, test_set,
                custom_xgb_model)
plt.suptitle(f'MBM tested on {test_glaciers}', fontsize=20)
plt.tight_layout()

### NN model:

#### Initialize model:

In [None]:
# Remove columns that are metadata or neither used in metadata or features
feature_columns = list(train_set['df_X'].columns.difference(cfg.metaData).drop(
    cfg.notMetaDataNotFeatures))

all_columns = feature_columns + cfg.fieldsNotFeatures

nInp = len(feature_columns)
cfg.setFeatures(feature_columns)

params_NN = {
    'lr': 0.01,
    'batch_size': 256,
    'optimizer': torch.optim.SGD,
    'module__0__out_features': 16,
    'module__2__out_features': 8
}

network = nn.Sequential(
    nn.Linear(nInp, params_NN['module__0__out_features']),
    nn.ReLU(),
    nn.Linear(params_NN['module__0__out_features'],
              params_NN['module__2__out_features']),
    nn.ReLU(),
    nn.Linear(params_NN['module__2__out_features'], 1),
)

early_stop = EarlyStopping(
    monitor='valid_loss',
    patience=10,
    threshold=1e-4,  # Optional: stop only when improvement is very small
)

lr_scheduler_cb = LRScheduler(policy=ReduceLROnPlateau,
                              monitor='valid_loss',
                              mode='min',
                              factor=0.5,
                              patience=5,
                              threshold=0.01,
                              threshold_mode='rel',
                              verbose=True)

param_init_NN = {'device': 'cuda:0'}

dataset = dataset_val = None  # Initialized hereafter


def my_train_split(ds, y=None, **fit_params):
    """Return the module-level ``(dataset, dataset_val)`` pair.

    Passed as ``train_split`` in ``args_NN``. All arguments are ignored; the
    train/validation pair is read from the globals at call time.
    """
    fixed_split = (dataset, dataset_val)
    return fixed_split


# Keyword arguments for the neural-network regressor; they are merged with
# param_init_NN and forwarded to CustomNeuralNetRegressor.load_model below.
args_NN = {
    'module': network,
    'nbFeatures': nInp,
    # Use the fixed train/validation pair returned by my_train_split.
    'train_split': my_train_split,
    'batch_size': params_NN['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params_NN['lr'],
    'max_epochs': 300,
    'optimizer': params_NN['optimizer'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}

#### Load model:

In [None]:
# Load model and set to CPU
model_filename = "nn_model_2025-06-02.pt"  # Replace with actual date if needed

# current_date = datetime.now().strftime("%Y-%m-%d")
# model_filename = f"nn_model_{current_date}.pt"

custom_NN_model = mbm.models.CustomNeuralNetRegressor.load_model(
    cfg,
    model_filename,
    **{
        **args_NN,
        **param_init_NN
    },
)
custom_NN_model = custom_NN_model.set_params(device='cpu')
custom_NN_model = custom_NN_model.to('cpu')

In [None]:
# Create features and metadata
features_test, metadata_test = custom_NN_model._create_features_metadata(
    test_set['df_X'][all_columns])

# Ensure all tensors are on CPU if they are torch tensors
if hasattr(features_test, 'cpu'):
    features_test = features_test.cpu()

# Ensure targets are also on CPU
targets_test = test_set['y']
if hasattr(targets_test, 'cpu'):
    targets_test = targets_test.cpu()

# Create the dataset
dataset_test = mbm.data_processing.AggregatedDataset(cfg,
                                                     features=features_test,
                                                     metadata=metadata_test,
                                                     targets=targets_test)

dataset_test = [
    SliceDataset(dataset_test, idx=0),
    SliceDataset(dataset_test, idx=1)
]

# Make predictions aggr to meas ID
y_pred = custom_NN_model.predict(dataset_test[0])
y_pred_agg = custom_NN_model.aggrPredict(dataset_test[0])

batchIndex = np.arange(len(y_pred_agg))
y_true = np.array([e for e in dataset_test[1][batchIndex]])

# Calculate scores
score = custom_NN_model.score(dataset_test[0], dataset_test[1])
mse, rmse, mae, pearson = custom_NN_model.evalMetrics(y_pred, y_true)

# Aggregate predictions
id = dataset_test[0].dataset.indexToId(batchIndex)
data = {
    'target': [e[0] for e in dataset_test[1]],
    'ID': id,
    'pred': y_pred_agg
}
grouped_ids_NN = pd.DataFrame(data)

# Add period
periods_per_ids = test_set['df_X'][all_columns].groupby('ID')['PERIOD'].first()
grouped_ids_NN = grouped_ids_NN.merge(periods_per_ids, on='ID')

# Add glacier name
glacier_per_ids = test_set['df_X'][all_columns].groupby(
    'ID')['GLACIER'].first()
grouped_ids_NN = grouped_ids_NN.merge(glacier_per_ids, on='ID')

# Add YEAR
years_per_ids = test_set['df_X'][all_columns].groupby('ID')['YEAR'].first()
grouped_ids_NN = grouped_ids_NN.merge(years_per_ids, on='ID')

In [None]:
# Plot the NN predictions from the table built in the previous cell.
PlotPredictions_NN(grouped_ids_NN)

## Compare models:

### Scatter on test glaciers:

In [None]:
def _scores_for_period(grouped_ids, period):
    """Return MSE, RMSE, MAE and Pearson correlation for one PERIOD value.

    Args:
        grouped_ids: table with 'PERIOD', 'target' and 'pred' columns.
        period: value of the 'PERIOD' column to keep ('annual' or 'winter').

    Returns:
        dict with the keys 'mse', 'rmse', 'mae' and 'pearson_corr'.
    """
    subset = grouped_ids[grouped_ids.PERIOD == period]
    y_true, y_hat = subset['target'], subset['pred']
    return {
        'mse': mean_squared_error(y_true, y_hat),
        'rmse': root_mean_squared_error(y_true, y_hat),
        'mae': mean_absolute_error(y_true, y_hat),
        'pearson_corr': np.corrcoef(y_true, y_hat)[0, 1]
    }


def _add_scores_box(ax, scores_annual, scores_winter):
    """Write annual/winter RMSE and correlation in the top-left corner of ax."""
    legend = "\n".join(
        ((r"$\mathrm{RMSE_a}=%.3f$, $\mathrm{RMSE_w}=%.3f$," %
          (scores_annual["rmse"], scores_winter["rmse"])),
         (r"$\mathrm{\rho_a}=%.3f$, $\mathrm{\rho_w}=%.3f$" %
          (scores_annual["pearson_corr"], scores_winter["pearson_corr"]))))
    ax.text(0.03,
            0.98,
            legend,
            transform=ax.transAxes,
            verticalalignment="top",
            fontsize=20,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))


fig = plt.figure(figsize=(15, 5))

# Scores per model and measurement period.
scores_annual_NN = _scores_for_period(grouped_ids_NN, 'annual')
scores_annual_xgb = _scores_for_period(grouped_ids_xgb, 'annual')
scores_winter_NN = _scores_for_period(grouped_ids_NN, 'winter')
scores_winter_xgb = _scores_for_period(grouped_ids_xgb, 'winter')

# Left panel: XGBoost. predVSTruth receives the annual scores; the box added
# afterwards reports both periods.
ax1 = plt.subplot(1, 2, 1)
ax1.set_title('XGB predictions', fontsize=20)
predVSTruth(ax1,
            grouped_ids_xgb,
            scores_annual_xgb,
            hue='PERIOD',
            add_legend=False,
            palette=[color_dark_blue, color_pink])
_add_scores_box(ax1, scores_annual_xgb, scores_winter_xgb)

# Right panel: neural network, same layout.
ax2 = plt.subplot(1, 2, 2)
ax2.set_title('NN predictions', fontsize=20)
predVSTruth(ax2,
            grouped_ids_NN,
            scores_annual_NN,
            hue='PERIOD',
            add_legend=False,
            palette=[color_dark_blue, color_pink])
_add_scores_box(ax2, scores_annual_NN, scores_winter_NN)

plt.tight_layout()

### Geodetic MB:

In [None]:
# Folders holding the distributed mass balance predictions of each model
# (NN and XGBoost respectively).
PATH_PREDICTIONS_NN = cfg.dataPath + path_distributed_MB_glamos + 'MBM/glamos_dems_NN/'
PATH_PREDICTIONS_XGB = cfg.dataPath + path_distributed_MB_glamos + 'MBM/glamos_dems/'

In [None]:
# Entries of the NN prediction folder (one per glacier, judging by the
# membership test applied to glacier names further down).
glaciers_in_glamos = os.listdir(PATH_PREDICTIONS_NN)

geodetic_mb = get_geodetic_MB(cfg)

# filter to glaciers with potential clear sky radiation data
geodetic_mb = geodetic_mb[geodetic_mb.glacier_name.isin(glDirect)]

# get years per glacier
years_start_per_gl = geodetic_mb.groupby(
    'glacier_name')['Astart'].unique().apply(list).to_dict()
years_end_per_gl = geodetic_mb.groupby('glacier_name')['Aend'].unique().apply(
    list).to_dict()

# Keep BOTH return values: `geoMB_per_glacier` is passed to
# process_geodetic_mass_balance_comparison below and is defined nowhere else
# in this notebook (discarding it as `_` led to a NameError there).
# NOTE(review): assumes the second return value is the per-glacier geodetic
# MB mapping -- confirm against build_periods_per_glacier.
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers with geodetic MB data:
# Sort glaciers by area
gl_area = get_gl_area(cfg)
# Make the area stored under 'claridenL' also available as 'clariden'.
gl_area['clariden'] = gl_area['claridenL']


# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glacier names ordered from smallest to largest area.

    Args:
        glacier_list: iterable of glacier names.
        gl_area: mapping with a ``.get(name, default)`` method giving the
            area per glacier; names without an entry count as area 0.

    Returns:
        A new list; ``glacier_list`` itself is left untouched and names with
        equal area keep their original relative order.
    """

    def area_of(name):
        return gl_area.get(name, 0)

    ordered = list(glacier_list)
    ordered.sort(key=area_of)
    return ordered

# Glaciers that have geodetic periods AND an entry in the NN prediction
# folder, ordered by increasing area.
glaciers_with_periods = list(periods_per_glacier.keys())
glacier_list = sort_by_area(
    [name for name in glaciers_with_periods if name in glaciers_in_glamos],
    gl_area)
# print len and list
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

In [None]:
# Arguments shared by both comparisons; only the prediction folder differs.
geodetic_comparison_kwargs = dict(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=test_glaciers,
    cfg=cfg,
)

# XGBoost predictions first, then the neural network ones.
df_all_xgb = process_geodetic_mass_balance_comparison(
    path_predictions=PATH_PREDICTIONS_XGB, **geodetic_comparison_kwargs)

df_all_nn = process_geodetic_mass_balance_comparison(
    path_predictions=PATH_PREDICTIONS_NN, **geodetic_comparison_kwargs)


In [None]:
# Drop rows where any required columns are NaN
df_all_nn = df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB', 'GLAMOS MB'])

# Compute RMSE and Pearson correlation
rmse_nn = mean_squared_error(df_all_nn["Geodetic MB"],
                             df_all_nn["MBM MB"],
                             squared=False)
corr_nn = np.corrcoef(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"])[0, 1]

# Drop rows where any required columns are NaN
df_all_xgb = df_all_xgb.dropna(subset=['Geodetic MB', 'MBM MB', 'GLAMOS MB'])

# Compute RMSE and Pearson correlation
rmse_xgb = mean_squared_error(df_all_xgb["Geodetic MB"],
                              df_all_xgb["MBM MB"],
                              squared=False)
corr_xgb = np.corrcoef(df_all_xgb["Geodetic MB"], df_all_xgb["MBM MB"])[0, 1]

# Define figure and axes
fig, axs = plt.subplots(1, 2, figsize=(10, 8), sharex=True)

# Plot MBM MB vs Geodetic MB
plot_scatter(df_all_xgb, 'GLACIER', True, axs[0], "MBM MB", rmse_xgb, corr_xgb)

axs[0].set_title('XGB predictions', fontsize=20)
axs[0].set_ylabel('Predicted mass balance [m w.e.]', fontsize=18)
axs[0].set_xlabel('Geodetic mass balance [m w.e.]', fontsize=18)

plot_scatter(df_all_nn, 'GLACIER', True, axs[1], "MBM MB", rmse_nn, corr_nn)
axs[1].set_ylabel('Predicted mass balance [m w.e.]', fontsize=18)
axs[1].set_xlabel('Geodetic mass balance [m w.e.]', fontsize=18)
axs[1].set_title('NN predictions', fontsize=20)

# Adjust legend outside of plot
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles,
           labels,
           bbox_to_anchor=(1.05, 1),
           loc="upper left",
           borderaxespad=0.,
           ncol=2,
           fontsize=20)

plt.tight_layout()
plt.show()