## Setup:

In [None]:
import pandas as pd
import os
import warnings
from tqdm.notebook import tqdm
import re
from calendar import month_abbr
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import cartopy.io.img_tiles as cimgt
from cartopy import crs as ccrs, feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.enums import Resampling as RResampling
import numpy as np
from skorch.callbacks import EarlyStopping, LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import geopandas as gpd
from matplotlib.patches import Wedge, Patch
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import pickle 

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.nn_helpers import *
from scripts.geodata_plots import *

# Silence all Python warnings for a cleaner notebook output (this also hides
# pandas/matplotlib deprecation warnings).
warnings.filterwarnings('ignore')
# IPython magics: reload edited modules (e.g. scripts.*) before each cell runs.
%load_ext autoreload
%autoreload 2

# Swiss configuration object used throughout the notebook (dataPath, seed,
# metaData, ...).
cfg = mbm.SwitzerlandConfig()

In [None]:
# Reproducibility first: fix the RNG seeds, then release cached GPU memory.
seed_all(cfg.seed)
free_up_cuda()

# ---- Plot styles ----
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)  # presumably 10 hex colours from batlow
color_dark_blue, color_pink = colors[0], '#c51b7d'

# ---- RGI ids ----
# Header names are stripped of stray whitespace; rows are indexed by the
# glacier short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=str.strip)
          .sort_values(by='short_name')
          .set_index('short_name'))

# ERA5 climate variables used as features.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'u10', 'v10']

# Topographic features (data source noted per entry). The OGGM aspect/slope
# variants are not used; the SGI versions are used instead.
vois_topographical = [
    "aspect_sgi",  # SGI
    "slope_sgi",  # SGI
    "hugonnet_dhdt",  # OGGM
    "consensus_ice_thickness",  # OGGM
    "millan_v",  # OGGM
]

## Read GLAMOS stake data:

In [None]:
data_glamos = getStakesData(cfg)

# Display-name lookup for every glacier key. The two Clariden parts get
# explicit names; every other key is simply capitalized.
glacierCap = {}
for gl in data_glamos['GLACIER'].unique():
    if not isinstance(gl, str):
        # Non-string keys (e.g. NaN) cannot be capitalized: report and skip.
        print(f"Warning: Non-string glacier name encountered: {gl}")
        continue
    glacierCap[gl] = ('Clariden_U' if gl.lower() == 'claridenu' else
                      'Clariden_L' if gl.lower() == 'claridenl' else
                      gl.capitalize())

data_glamos.head(2)

In [None]:
# Glacier outlines: RGI polygons and the SGI 2016 inventory shapefile.
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)
glacier_outline_sgi = gpd.read_file(
    os.path.join(cfg.dataPath, path_SGI_topo, 'inventory_sgi2016_r2020',
                 'SGI_2016_glaciers_copy.shp'))

In [None]:
# Per-glacier summary table.
# Number of point measurements, most-measured glacier first.
glacier_info = (data_glamos.groupby('GLACIER').size()
                .sort_values(ascending=False)
                .rename('Nb. measurements')
                .to_frame())

# Mean location of each glacier's measurements.
glacier_loc = data_glamos.groupby('GLACIER')[['POINT_LAT', 'POINT_LON']].mean()
glacier_info = glacier_loc.merge(glacier_info, on='GLACIER')

# Measurement counts split by PERIOD (one column per period, 0 when absent).
glacier_period = (data_glamos.groupby(['GLACIER', 'PERIOD']).size()
                  .unstack(fill_value=0)
                  .astype(int))
glacier_info = glacier_info.merge(glacier_period, on='GLACIER')

# Held-out test glaciers; every other glacier is a training glacier.
test_glaciers = [
    'tortin', 'plattalva', 'sanktanna', 'schwarzberg', 'hohlaub', 'pizol',
    'corvatsch', 'tsanfleuron', 'forno'
]
glacier_info['Train/Test glacier'] = np.where(
    glacier_info.index.isin(test_glaciers), 'Test', 'Train')
glacier_info.head(2)

### Assign glaciers to river basin names:

In [None]:
# === Load RGI glacier IDs ===
# Re-read the id table with whitespace-free headers, indexed by short name.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids)
          .rename(columns=str.strip)
          .sort_values(by='short_name')
          .set_index('short_name'))

# === Load SGI region geometries ===
SGI_regions = gpd.read_file(
    os.path.join(cfg.dataPath, path_SGI_topo, 'inventory_sgi2016_r2020',
                 'sgi_regions.geojson'))

# Strip surrounding whitespace from every text (object-dtype) column.
SGI_regions[SGI_regions.select_dtypes(include='object').columns] = (
    SGI_regions.select_dtypes(include='object').apply(
        lambda column: column.str.strip()))

# Remove duplicate / incomplete rows and index regions by their SGI key.
SGI_regions = (SGI_regions.drop_duplicates().dropna()
               .set_index('pk_sgi_region'))

# === Map to Level 0 river basins ===
# The first character of the SGI id selects the major catchment.
catchment_lv0 = {'A': 'Rhine', 'B': 'Rhone', 'C': 'Po', 'D': 'Adige', 'E': 'Danube'}
rgi_df['rvr_lv0'] = rgi_df['sgi-id'].str[0].map(catchment_lv0)


# === Map to Level 1 river basins using SGI regions ===
def get_river_basin(sgi_id, regions=None):
    """Return the Level-1 river basin name for an SGI glacier id.

    The region key is the part of ``sgi_id`` before the first ``'-'``; it is
    looked up in the ``river_basin_name`` column of ``regions``.

    Args:
        sgi_id: SGI glacier identifier. Non-string values (e.g. NaN for a
            glacier without an SGI id) yield ``None`` instead of raising.
        regions: Region table indexed by region key. Defaults to the
            module-level ``SGI_regions`` table.

    Returns:
        The basin name, or ``None`` when the id is missing, its key is not in
        ``regions``, or the basin is empty. If the key matches several rows,
        the first non-null basin name is returned.
    """
    if regions is None:
        regions = SGI_regions
    if not isinstance(sgi_id, str):
        return None
    key = sgi_id.split('-')[0]
    if key not in regions.index:
        return None
    basin = regions.loc[key, 'river_basin_name']
    if isinstance(basin, pd.Series):
        # A duplicated index key returns a Series: use the first non-null.
        non_null = basin.dropna()
        return non_null.iloc[0] if not non_null.empty else None
    return basin if pd.notna(basin) else None


# Level-1 river basin for every glacier, derived from its SGI id.
rgi_df['rvr_lv1'] = rgi_df['sgi-id'].map(get_river_basin)

# Final formatting: call the short-name index "GLACIER" so it aligns with
# glacier_info.
rgi_df = rgi_df.rename_axis('GLACIER')
rgi_df.head()

In [None]:
# Attach Level-0 / Level-1 river basins; the left join keeps every glacier of
# glacier_info, with NaN basins when there is no RGI entry.
glacier_info = pd.merge(glacier_info, rgi_df[['rvr_lv0', 'rvr_lv1']],
                        how='left', on='GLACIER')
glacier_info.head()

## Intro & methods:

### Geoplots:


#### Square-root scaling of marker sizes:

In [None]:
# Open the original raster
tif_name = "landesforstinventar-vegetationshoehenmodell_relief_sentinel_2024_2056.tif"
tif_path = os.path.join(cfg.dataPath, 'GLAMOS/RGI/', tif_name)

# Output grid: WGS84 lon/lat with a coarse cell size to keep the background
# image light. Approx. 100 m in degrees: ~0.0009 deg.
target_res = 0.0009
output_crs = "EPSG:4326"  # WGS84

with rasterio.open(tif_path) as src:
    # Transform and shape of the coarser, reprojected grid
    transform, width, height = calculate_default_transform(
        src.crs,
        output_crs,
        src.width,
        src.height,
        *src.bounds,
        resolution=target_res)

    # Metadata describing the destination grid
    kwargs = src.meta.copy()
    kwargs.update({
        'crs': output_crs,
        'transform': transform,
        'width': width,
        'height': height
    })

    # Start from zeros rather than np.empty: reproject may not write every
    # destination pixel (e.g. outside the source footprint when no nodata
    # value is defined), and the map masks pixels equal to 0. Uninitialised
    # memory would show up as random noise.
    destination = np.zeros((height, width), dtype=src.dtypes[0])

    # Reproject with coarsening; averaging reduces noise when downsampling.
    reproject(source=rasterio.band(src, 1),
              destination=destination,
              src_transform=src.transform,
              src_crs=src.crs,
              dst_transform=transform,
              dst_crs=output_crs,
              resampling=Resampling.average)

    # imshow extent [left, right, bottom, top] from the affine transform
    # (transform[0]/[4] are the pixel width/height, [2]/[5] the origin).
    extent = [
        transform[2], transform[2] + transform[0] * width,
        transform[5] + transform[4] * height, transform[5]
    ]

In [None]:
# ---- 1. Preprocessing ----
# Marker area follows sqrt(count) so a few heavily sampled glaciers do not
# dwarf all others.
glacier_info['sqrt_size'] = np.sqrt(glacier_info['Nb. measurements'])

# Dataset-wide range of sqrt(count); read by scaled_size().
sqrt_min, sqrt_max = glacier_info['sqrt_size'].agg(['min', 'max'])

# Marker-area range in points^2 (smallest, largest glacier).
sizes = (30, 1500)


# Map raw measurement counts to marker areas on one consistent scale.
def scaled_size(val, min_out=sizes[0], max_out=sizes[1]):
    """Convert a measurement count into a scatter marker area (points^2).

    The mapping is linear in sqrt(val): the module-level ``sqrt_min`` maps to
    ``min_out`` and ``sqrt_max`` to ``max_out``. Counts outside the dataset
    range are extrapolated, not clipped. When all counts are equal, the
    midpoint of the output range is returned.
    """
    if sqrt_max == sqrt_min:  # degenerate range: avoid dividing by zero
        return (min_out + max_out) / 2
    fraction = (np.sqrt(val) - sqrt_min) / (sqrt_max - sqrt_min)
    return min_out + (max_out - min_out) * fraction


# Apply scaling to full dataset for the actual plot
glacier_info['scaled_size'] = glacier_info['Nb. measurements'].apply(
    scaled_size)

# ---- 2. Create figure and base map ----
fig = plt.figure(figsize=(18, 10))

# Map window in degrees.
latN, latS = 47.1, 45.8
lonW, lonE = 5.8, 10.5
projPC = ccrs.PlateCarree()
ax2 = plt.axes(projection=projPC)
ax2.set_extent([lonW, lonE, latS, latN], crs=ccrs.Geodetic())

ax2.add_feature(cfeature.COASTLINE)
ax2.add_feature(cfeature.LAKES)
ax2.add_feature(cfeature.RIVERS)

# Relief background: pixels equal to 0 are masked and drawn white.
# Work on a copy of the gray colormap. Calling set_bad on plt.cm.gray itself
# mutates matplotlib's shared instance, which can leak into later plots.
masked_destination = np.ma.masked_where(destination == 0, destination)
cmap = plt.cm.gray.copy()
cmap.set_bad(color='white')
ax2.imshow(
    masked_destination,
    origin='upper',
    extent=extent,
    transform=ccrs.PlateCarree(),  # raster was reprojected to EPSG:4326
    cmap=cmap,
    alpha=0.6,  # transparency
    zorder=0)

# Glacier outlines
glacier_outline_sgi.plot(ax=ax2, transform=projPC, color='black')

# ---- 3. Scatterplot ----
custom_palette = {'Train': color_dark_blue, 'Test': '#b2182b'}

g = sns.scatterplot(
    data=glacier_info,
    x='POINT_LON',
    y='POINT_LAT',
    size='scaled_size',
    hue='Train/Test glacier',
    sizes=sizes,
    alpha=0.6,
    palette=custom_palette,
    transform=projPC,
    ax=ax2,
    zorder=10,
    legend=True  # keep seaborn's entries so the hue handles can be reused below
)

# ---- 4. Gridlines ----
gl = ax2.gridlines(draw_labels=True,
                   linewidth=1,
                   color='gray',
                   alpha=0.5,
                   linestyle='--')
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = {'size': 16, 'color': 'black'}
gl.ylabel_style = {'size': 16, 'color': 'black'}
gl.top_labels = gl.right_labels = False

# ---- 5. Custom combined legend ----
# Hue (Train/Test) handles produced by seaborn.
handles, labels = g.get_legend_handles_labels()
expected_labels = list(custom_palette.keys())
hue_entries = [(h, lab) for h, lab in zip(handles, labels)
               if lab in expected_labels]

# Reference sizes for the "Nb. measurements" entries. scaled_size extrapolates
# values outside the data range, so these need not be observed counts.
size_values = [30, 100, 1000, 6000]
size_handles = [
    Line2D(
        [],
        [],
        marker='o',
        linestyle='None',
        # scatter sizes are areas in points^2, whereas Line2D markersize is a
        # length in points, so take the square root to match.
        markersize=np.sqrt(scaled_size(val)),
        markerfacecolor='gray',
        alpha=0.6,
        label=f'{val}') for val in size_values
]

combined_handles = [h for h, _ in hue_entries] + size_handles
combined_labels = [lab for _, lab in hue_entries] + [str(v) for v in size_values]

# Final legend
ax2.legend(combined_handles,
           combined_labels,
           title='Nb. measurements',
           loc='lower right',
           frameon=True,
           fontsize=18,
           title_fontsize=18,
           borderpad=1.2,
           labelspacing=1.2,
           ncol=3)
plt.tight_layout()
plt.show()


In [None]:
# Number of measurements per year, stacked by period.
(data_glamos.groupby(['YEAR', 'PERIOD'])['POINT_ID'].count()
 .unstack()
 .plot(kind='bar',
       stacked=True,
       figsize=(20, 5),
       color=[color_dark_blue, '#abd9e9']))
# Legend for the period colours.
plt.legend(title='Period', fontsize=18, title_fontsize=20, ncol=2)

### Input data:

In [None]:
# Initialize logging
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Inputs for transforming the stake data to monthly format.
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=cfg.dataPath + path_ERA5_raw +
    'era5_monthly_averaged_data.nc',
    geopotential_data=cfg.dataPath + path_ERA5_raw +
    'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Run flag for process_or_load_data ("run or load" the monthly data);
# presumably False loads a previously saved result.
RUN = False
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
)
data_monthly = dataloader_gl.data

In [None]:
# Held-out test glaciers.
test_glaciers = [
    'tortin', 'plattalva', 'sanktanna', 'schwarzberg', 'hohlaub', 'pizol',
    'corvatsch', 'tsanfleuron', 'forno'
]

# Warn about test glaciers that have no rows in the monthly dataset.
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_glaciers = [
    name for name in test_glaciers if name not in existing_glaciers
]
if missing_glaciers:
    print("Warning: The following test glaciers are not in the dataset: "
          f"{missing_glaciers}")

# Every glacier that is not held out is a training glacier.
train_glaciers = [
    name for name in existing_glaciers if name not in test_glaciers
]

data_test = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(test_glaciers)]
print('Size of test data:', len(data_test))

data_train = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(
    train_glaciers)]
print('Size of train data:', len(data_train))

if len(data_train) > 0:
    # Test size expressed relative to the training size (not the total).
    test_perc = len(data_test) / len(data_train) * 100
    print(f'Percentage of test size: {test_perc:.2f}%')
else:
    print("Warning: No training data available!")

# Row counts per period (rows of the monthly dataset).
print('Train:')
print('Number of winter and annual samples:', len(data_train))
print('Number of annual samples:', (data_train.PERIOD == 'annual').sum())
print('Number of winter samples:', (data_train.PERIOD == 'winter').sum())

data_test_annual = data_test[data_test.PERIOD == 'annual']
data_test_winter = data_test[data_test.PERIOD == 'winter']

print('Test:')
print('Number of winter and annual samples:', len(data_test))
print('Number of annual samples:', len(data_test_annual))
print('Number of winter samples:', len(data_test_winter))

print('Total:')
print('Number of monthly rows:', len(dataloader_gl.data))
print('Number of annual rows:', (dataloader_gl.data.PERIOD == 'annual').sum())
print('Number of winter rows:', (dataloader_gl.data.PERIOD == 'winter').sum())

#### Heatmap of annual measurements:

In [None]:
# Heatmap of the annual-period measurements. glacierCap supplies display
# names; test_glaciers presumably lets the helper mark held-out glaciers
# (TODO confirm).
plotHeatmap(test_glaciers, data_glamos, glacierCap, period='annual')

In [None]:
# Mean elevation of the annual measurements per glacier, highest first.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Elevation strip from the lowest to the highest glacier.
fig, ax = plt.subplots(figsize=(10, 2))
sns.lineplot(gl_per_el.sort_values(ascending=True),
             ax=ax,
             color='gray',
             marker='v')
# Hide the x tick labels explicitly. set_xticklabels('') only worked because
# iterating an empty string yields no labels, and it pins a FixedFormatter.
ax.tick_params(axis='x', labelbottom=False)
ax.set_ylabel('')
ax.set_xlabel('')

In [None]:
# Number of glaciers with at least one annual measurement (entries of gl_per_el).
len(gl_per_el)

## Results:

### Load NN model:

In [None]:
# Initialize logging (basicConfig has no effect if the root logger already
# has handlers, e.g. from the earlier input-data cell).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data), using the NN-specific
# output file.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_NN.csv')
data_monthly = dataloader_gl.data

# One hashed ID per glacier-year ("<GLACIER>_<YEAR>"). Iterating the two
# columns directly produces the same strings as the previous row-wise
# DataFrame.apply(axis=1). It avoids building a Series per row, which is very
# slow on the monthly dataset.
data_monthly['GLWD_ID'] = [
    mbm.data_processing.utils.get_hash(f"{glacier}_{year}")
    for glacier, year in zip(data_monthly['GLACIER'], data_monthly['YEAR'])
]
data_monthly['GLWD_ID'] = data_monthly['GLWD_ID'].astype(str)

dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)

In [None]:
# Split by glacier: the glaciers in test_glaciers form the test set.
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=test_glaciers,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# Test size relative to the training size (not to the total).
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))

# Validation and train split:
# NOTE: data_train is the same object as train_set['df_X'] (no copy), so the
# 'y' column is also added to train_set['df_X'].
data_train = train_set['df_X']
data_train['y'] = train_set['y']
dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

# 80/20 train/validation split of the training glaciers' rows.
train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# Get all indices of the training and validation datasets at once from the
# iterators. Once consumed, the iterators are empty.
train_indices, val_indices = list(train_itr), list(val_itr)

# Targets are taken from POINT_BALANCE (the assert at the end checks that it
# equals train_set['y']).
df_X_train = data_train.iloc[train_indices]
y_train = df_X_train['POINT_BALANCE'].values

# Get val set
df_X_val = data_train.iloc[val_indices]
y_val = df_X_val['POINT_BALANCE'].values

# Model inputs: topographic features plus ERA5 climate variables.
features_topo = [
    'ELEVATION_DIFFERENCE',
    'pcsr',
] + list(vois_topographical)

feature_columns = features_topo + list(vois_climate)

cfg.setFeatures(feature_columns)

# Features plus the non-feature fields kept by the config (cfg.fieldsNotFeatures).
all_columns = feature_columns + cfg.fieldsNotFeatures

# Because CH has some extra columns, we need to cut those
df_X_train_subset = df_X_train[all_columns]
df_X_val_subset = df_X_val[all_columns]
df_X_test_subset = test_set['df_X'][all_columns]

print('Shape of training dataset:', df_X_train_subset.shape)
print('Shape of validation dataset:', df_X_val_subset.shape)
print('Shape of testing dataset:', df_X_test_subset.shape)
print('Running with features:', feature_columns)

# Sanity check (skipped under `python -O`): targets match POINT_BALANCE.
assert all(train_set['df_X'].POINT_BALANCE == train_set['y'])

In [None]:
# Halve the learning rate when valid_loss stops improving by at least 1 %
# (relative) for 5 consecutive epochs.
lr_scheduler_cb = LRScheduler(policy=ReduceLROnPlateau,
                              monitor='valid_loss',
                              mode='min',
                              factor=0.5,
                              patience=5,
                              threshold=0.01,
                              threshold_mode='rel',
                              verbose=True)

# Stop after 15 epochs without a valid_loss improvement beyond the 1e-4
# threshold.
early_stop = EarlyStopping(monitor='valid_loss', patience=15, threshold=1e-4)

# Placeholders returned by my_train_split at call time; they stay None
# unless reassigned before training.
dataset = dataset_val = None


def my_train_split(ds, y=None, **fit_params):
    """Custom ``train_split`` hook that ignores its arguments.

    Returns the module-level ``(dataset, dataset_val)`` pair as it stands at
    call time, so the train/validation data are controlled from outside.
    """
    split = (dataset, dataset_val)
    return split


# Run on the CPU; use {'device': 'cuda:0'} instead to train on a GPU.
param_init = dict(device='cpu')
# Network input width = number of selected feature columns.
nInp = len(feature_columns)

In [None]:
# Load model and set to CPU.
# Saved weights and their hyper-parameter pickle; adjust the date stamp to
# load a different training run.
model_filename = "nn_model_2025-07-08.pt"
params_filename = "nn_params_2025-07-08.pkl"

# NOTE: pickle.load can execute code embedded in the file; only load
# parameter files from a trusted source.
with open(f"models/{params_filename}", "rb") as f:
    custom_params = pickle.load(f)
params = custom_params

# Keyword arguments forwarded to load_model: fixed settings plus the tuned
# hyper-parameters from the pickle.
args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 300,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [('early_stop', early_stop),
                  ('lr_scheduler', lr_scheduler_cb)],
}

# Rebuild the net around the stored weights (param_init overrides args), then
# pin it to the CPU.
custom_NN_model = mbm.models.CustomNeuralNetRegressor.load_model(
    cfg, model_filename, **{**args, **param_init})
custom_NN_model = custom_NN_model.set_params(device='cpu').to('cpu')

### Scatter plot test:

In [None]:
# Predict on the held-out test glaciers; predictions are grouped per ID.
(grouped_ids_NN, scores_NN, ids_NN,
 y_pred_NN) = evaluate_model_and_group_predictions(custom_NN_model,
                                                   df_X_test_subset,
                                                   test_set['y'], cfg, mbm)

# Attach the months of each measurement ID to the grouped predictions.
months_per_id = (test_set['df_X'][all_columns]
                 .groupby('ID')['MONTHS'].unique())
grouped_ids_NN = grouped_ids_NN.merge(months_per_id, on='ID')

# Seasonal scores and the summary figure.
scores_annual_NN, scores_winter_NN = compute_seasonal_scores(
    grouped_ids_NN, pred_col='pred', target_col='target')
fig = plot_predictions_summary(
    grouped_ids=grouped_ids_NN,
    scores_annual=scores_annual_NN,
    scores_winter=scores_winter_NN,
    predVSTruth=predVSTruth,
    plotMeanPred=plotMeanPred,
)

In [None]:
# Mean elevation of the annual measurements per glacier.
gl_per_el = data_glamos[data_glamos.PERIOD == 'annual'].groupby(
    ['GLACIER'])['POINT_ELEVATION'].mean()
gl_per_el = gl_per_el.sort_values(ascending=False)

# Test glaciers ordered from lowest to highest mean elevation (sorted once).
elvs = gl_per_el[test_glaciers].sort_values()
test_gl_per_el = elvs.index

fig, axs = plt.subplots(3, 3, figsize=(20, 15))

PlotIndividualGlacierPredVsTruth(grouped_ids_NN,
                                 color_annual=color_dark_blue,
                                 color_winter=color_pink,
                                 axs=axs,
                                 custom_order=test_gl_per_el)

# Title each panel with glacier name and mean elevation. zip stops at the
# shorter of panels / glaciers, so a glacier list shorter than the grid no
# longer raises IndexError.
for ax, (test_gl, el) in zip(axs.flatten(), elvs.items()):
    ax.set_title(f'{test_gl.capitalize()}: {round(el, 2)} m', fontsize=24)

### Geodetic mass balance:

In [None]:
PATH_PREDICTIONS_NN = cfg.dataPath + path_distributed_MB_glamos + 'MBM/glamos_dems_NN_full/'

# Directory entries of the NN prediction folder (presumably one per glacier).
glaciers_in_glamos = os.listdir(PATH_PREDICTIONS_NN)

geodetic_mb = get_geodetic_MB(cfg)

# Distinct start / end years of the geodetic periods, per glacier.
years_start_per_gl = {
    glacier: list(years) for glacier, years in
    geodetic_mb.groupby('glacier_name')['Astart'].unique().items()
}
years_end_per_gl = {
    glacier: list(years) for glacier, years in
    geodetic_mb.groupby('glacier_name')['Aend'].unique().items()
}

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glacier areas used for sorting; 'clariden' reuses the area of 'claridenL'.
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']

# Order glaciers by area (glaciers without a known area count as 0).
def sort_by_area(glacier_list, gl_area):
    """Return ``glacier_list`` sorted from smallest to largest area.

    Areas are looked up in the ``gl_area`` mapping; missing glaciers are
    treated as area 0. The sort is stable, so ties keep their input order.
    """
    def area_of(glacier):
        return gl_area.get(glacier, 0)

    return sorted(glacier_list, key=area_of)


# Glaciers that appear in periods_per_glacier and also have a prediction
# folder, ordered from smallest to largest area.
glacier_list = [
    name for name in periods_per_glacier if name in glaciers_in_glamos
]
glacier_list = sort_by_area(glacier_list, gl_area)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

df_all_nn = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=test_glaciers,
    path_predictions=PATH_PREDICTIONS_NN,  # or another path if needed
    cfg=cfg)

# Drop rows where any required columns are NaN, then order by glacier area.
df_all_nn = df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])
df_all_nn = df_all_nn.sort_values(by='Area')
df_all_nn['GLACIER'] = df_all_nn['GLACIER'].str.capitalize()

# RMSE and Pearson correlation between geodetic and modelled MB.
# mean_squared_error(..., squared=False) was deprecated in scikit-learn 1.4
# and removed in 1.6, so take the square root of the MSE explicitly. This
# works on every version.
rmse_nn = np.sqrt(
    mean_squared_error(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"]))
corr_nn = np.corrcoef(df_all_nn["Geodetic MB"], df_all_nn["MBM MB"])[0, 1]


In [None]:
fig, ax = plt.subplots(figsize=(6, 10))

# One colour per glacier. The palette is sized from the data so it always
# matches the number of hue levels (it was hard-coded to 17).
sns.scatterplot(
    data=df_all_nn,
    x="Geodetic MB",
    y='MBM MB',
    hue='GLACIER',
    alpha=0.7,
    ax=ax,
    palette=sns.color_palette("hls", df_all_nn['GLACIER'].nunique()))

# Reference lines: 1:1 diagonal and the zero axes.
ax.axline((0, 0), slope=1, color="grey", linestyle="--", linewidth=1)
ax.axvline(0, color="grey", linestyle="--", linewidth=1)
ax.axhline(0, color="grey", linestyle="--", linewidth=1)
ax.grid(True, linestyle="--", linewidth=0.5)
# Label both axes with units.
ax.set_xlabel("Geodetic MB [m w.e.]")
ax.set_ylabel("MBM MB [m w.e.]")

# RMSE and correlation annotation
legend_text = "\n".join(
    (r"$\mathrm{RMSE}=%.3f$" % rmse_nn, r"$\mathrm{\rho}=%.3f$" % corr_nn))
props = dict(boxstyle="round", facecolor="white", alpha=0.5)
ax.text(0.03,
        0.94,
        legend_text,
        transform=ax.transAxes,
        verticalalignment="top",
        fontsize=18,
        bbox=props)

# Glacier legend outside the plot area, in two columns.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles,
          labels,
          bbox_to_anchor=(1.05, 1),
          loc="upper left",
          borderaxespad=0.,
          ncol=2,
          fontsize=14)

In [None]:
# Area classes in km² (left-closed bins: [0, 5), [5, 10), [10, 100), [100, inf))
bins = [0, 5, 10, 100, np.inf]
labels = ['<5', '5–10', '10–100', '>100']

df_all_nn['Area_bin'] = pd.cut(df_all_nn['Area'],
                               bins=bins,
                               labels=labels,
                               right=False)

# Plot setup: one panel per area class, at most three panels.
fig, axs = plt.subplots(1, 3, figsize=(15, 6), sharex=True, sharey=True)
axs = axs.flatten()

# Non-empty bins in category order (smallest area first). Following the
# categorical order keeps the panels sorted whatever the row order of
# df_all_nn; unique() would follow row order instead.
unique_bins = [
    area_bin for area_bin in df_all_nn['Area_bin'].cat.categories
    if (df_all_nn['Area_bin'] == area_bin).any()
]

# zip stops at the number of panels, so a fourth non-empty bin is not drawn.
for ax, area_bin in zip(axs, unique_bins):
    df_all_nn_bin = df_all_nn[df_all_nn['Area_bin'] == area_bin]

    sns.scatterplot(
        data=df_all_nn_bin,
        x="Geodetic MB",
        y="MBM MB",
        hue="GLACIER",
        alpha=0.7,
        ax=ax,
        palette=sns.color_palette("Paired",
                                  df_all_nn_bin['GLACIER'].nunique()),
        legend=False,  # no per-panel legend
    )

    # Zero axes, grid and labels
    ax.axvline(0, color="grey", linestyle="--", linewidth=1)
    ax.axhline(0, color="grey", linestyle="--", linewidth=1)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.set_xlabel("Geodetic MB [m w.e.]")
    ax.set_ylabel("MBM MB [m w.e.]")
    ax.set_title(f"Area: {area_bin} km²")

# Square limits spanning both axes so the 1:1 line is a true diagonal.
for ax in axs:
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()
    min_limit = min(xmin, ymin)
    max_limit = max(xmax, ymax)

    ax.set_xlim(min_limit, max_limit)
    ax.set_ylim(min_limit, max_limit)

    # Identity line
    ax.axline((0, 0), slope=1, color="grey", linestyle="--", linewidth=1)

plt.tight_layout()
plt.show()