## Setting Up:

In [None]:
import sys, os

sys.path.append(os.path.join(os.getcwd(),
                             '../../'))  # Add root of repo to import MBM
import csv
from functools import partial

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
import itertools
import random
import pickle
from collections import Counter
import ast

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.geodata_plots import *

# NOTE: this silences every warning category for the whole session,
# including any warnings raised by the cells below.
warnings.filterwarnings('ignore')
# IPython magics: enable autoreload so edited modules are re-imported
# automatically before code is executed.
%load_ext autoreload
%autoreload 2

# Configuration object used throughout the notebook (it provides the seed,
# data path and metadata columns read by the cells below).
cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed the random number generators through the project's seed_all helper
# and report which seed is in use.
seed_all(cfg.seed)
print(f"Using seed: {cfg.seed}")

from torch.utils.data import Subset
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler, SubsetRandomSampler
import torch.nn as nn

# Report whether a CUDA device is visible (calling free_up_cuda() when one
# is) and pick the torch device used for model loading and evaluation.
if not torch.cuda.is_available():
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

device = (torch.device("cuda")
          if torch.cuda.is_available() else torch.device("cpu"))

In [None]:
# Plot styles: apply the project's matplotlib style sheet and define the two
# colours passed as the palette to the scatter plots further down.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# RGI ids: read the glacier id table, strip whitespace from the column
# names, then sort by and index on the 'short_name' column.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda col: col.strip())
          .sort_values(by='short_name')
          .set_index('short_name'))

# Climate variables of interest (column names passed to the data processing).
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables of interest.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

## Read glacier (GL) stake data:

In [None]:
# Read the stake data, then compute the head/tail month padding values that
# are passed to the sequence datasets and to the MLP evaluation below.
data_glamos = getStakesData(cfg)

_compute_pads = mbm.data_processing.utils._compute_head_tail_pads_from_df
months_head_pad, months_tail_pad = _compute_pads(data_glamos)

### Read input data:

In [None]:
# Initialize logging
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Input/output locations handed to the monthly pre-processing step.
_era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=_era5_dir + 'era5_monthly_averaged_data.nc',
    geopotential_data=_era5_dir + 'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Transform data to monthly format; RUN is forwarded as run_flag and selects
# between running the processing and loading the data.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv',
)

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)

In [None]:
# Sanity check: every hold-out glacier should be present in the dataset.
existing_glaciers = set(data_monthly.GLACIER.unique())
missing_glaciers = [
    name for name in TEST_GLACIERS if name not in existing_glaciers
]

if len(missing_glaciers) > 0:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Every glacier that is not held out for testing is used for training.
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

data_test = data_monthly[data_monthly.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = data_monthly[data_monthly.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

# The ratio is reported relative to the training set size, so it is only
# defined when there is training data.
if len(data_train) != 0:
    test_perc = (len(data_test) / len(data_train)) * 100
    print(f'Percentage of test size: {test_perc:.2f}%')
else:
    print("Warning: No training data available!")

# Build the cross-validation splits with the test set defined by glacier.
splits, test_set, train_set = get_CV_splits(
    dataloader_gl,
    test_split_on='GLACIER',
    test_splits=TEST_GLACIERS,
    random_state=cfg.seed,
)

# Summarise the resulting split (test size is reported relative to train).
print(
    f"Test glaciers: ({len(test_set['splits_vals'])}) {test_set['splits_vals']}"
)
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print(f'Percentage of test size: {test_perc:.2f}%')
print('Size of test set:', len(test_set['df_X']))
print(
    f"Train glaciers: ({len(train_set['splits_vals'])}) {train_set['splits_vals']}"
)
print('Size of train set:', len(train_set['df_X']))

# Attach the targets to the feature tables returned by get_CV_splits.
# NOTE: no copy is made here; data_train / data_test are the very same
# objects as train_set['df_X'] / test_set['df_X'], so the new 'y' column is
# added to those objects as well.
data_train = train_set['df_X']
data_train['y'] = train_set['y']

data_test = test_set['df_X']
data_test['y'] = test_set['y']

## Models:

In [None]:
# Inputs that vary per month.
MONTHLY_COLS = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str',
    'pcsr', 'ELEVATION_DIFFERENCE',
]

# Static inputs.
STATIC_COLS = [
    'aspect_sgi',
    'slope_sgi',
    'hugonnet_dhdt',
    'consensus_ice_thickness',
    'millan_v',
]

# Full feature list: monthly columns first, static columns after.
feature_columns = [*MONTHLY_COLS, *STATIC_COLS]

### LSTM:

In [None]:
seed_all(cfg.seed)

# Work on copies so that normalising PERIOD (strip + lower-case) leaves
# data_train / data_test untouched.
df_train = data_train.copy()
df_test = data_test.copy()
for _df in (df_train, df_test):
    _df['PERIOD'] = _df['PERIOD'].str.strip().str.lower()

# --- build train and test sequence datasets from the dataframes ---
_SequenceDataset = mbm.data_processing.MBSequenceDataset
_seq_kwargs = dict(
    months_tail_pad=months_tail_pad,
    months_head_pad=months_head_pad,
    expect_target=True,
)
ds_train = _SequenceDataset.from_dataframe(df_train, MONTHLY_COLS,
                                           STATIC_COLS, **_seq_kwargs)
ds_test = _SequenceDataset.from_dataframe(df_test, MONTHLY_COLS, STATIC_COLS,
                                          **_seq_kwargs)

# --- train/validation indices over the training dataset (20% validation) ---
train_idx, val_idx = _SequenceDataset.split_indices(len(ds_train),
                                                    val_ratio=0.2,
                                                    seed=cfg.seed)

# Size of the last axis of the monthly (Xm) and static (Xs) arrays.
Fm, Fs = ds_train.Xm.shape[-1], ds_train.Xs.shape[-1]

#### Simple LSTM:

In [None]:
# Hyper-parameters of the simple (single-head) LSTM.
# NOTE(review): Fm / Fs are hard-coded; presumably they have to match the
# number of monthly / static input columns. Confirm before changing features.
custom_params = dict(
    Fm=9,
    Fs=5,
    hidden_size=128,
    num_layers=1,
    bidirectional=False,
    dropout=0.0,
    static_layers=0,
    static_hidden=None,
    static_dropout=None,
    lr=0.0001,
    weight_decay=0.0,
    loss_name='neutral',
    loss_spec=None,
    two_heads=False,
    head_dropout=0,
)

# Build dataloaders from fresh, untransformed clones of the datasets.
_SeqDS = mbm.data_processing.MBSequenceDataset
ds_train_copy = _SeqDS._clone_untransformed_dataset(ds_train)
ds_test_copy = _SeqDS._clone_untransformed_dataset(ds_test)

# --- train/val loaders (scalers fit on TRAIN, applied to whole ds) ---
_loader_kwargs = dict(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN, transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True,  # use weighted sampler for training
)
train_dl, val_dl = ds_train_copy.make_loaders(**_loader_kwargs)

# --- test loader (reuses train scalers) ---
test_dl = _SeqDS.make_test_loader(ds_test_copy,
                                  ds_train_copy,
                                  batch_size=128,
                                  seed=cfg.seed)

# --- build model, resolve loss, load saved weights ---
# Nothing is trained in this cell: a previously saved state dict is read
# from disk. To target a file dated today, use
# f"models/lstm_model_{current_date}_simple.pt" instead.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = "models/lstm_model_2025-09-19_simple.pt"
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Load the best weights
best_state = torch.load(model_filename, map_location=device)
model.load_state_dict(best_state)

# Evaluate on test
test_metrics_LSTM_simple, grouped_ids_LSTM_simple = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics_LSTM_simple['RMSE_annual']
test_rmse_w = test_metrics_LSTM_simple['RMSE_winter']

print(f'Test RMSE annual: {test_rmse_a:.3f} | winter: {test_rmse_w:.3f}')

#### Two heads LSTM:

In [None]:
# Hyper-parameters of the two-heads LSTM (hard-coded in this cell).
# NOTE(review): Fm / Fs are hard-coded; presumably they have to match the
# number of monthly / static input columns. Confirm before changing features.
custom_params = dict(
    Fm=9,
    Fs=5,
    hidden_size=128,
    num_layers=2,
    bidirectional=False,
    dropout=0.0,
    static_layers=2,
    static_hidden=[128, 64],
    static_dropout=0.1,
    lr=0.0005,
    weight_decay=0.0001,
    loss_name='neutral',
    loss_spec=None,
    two_heads=True,
    head_dropout=0.0,
)

# Build dataloaders from fresh, untransformed clones of the datasets.
_SeqDS = mbm.data_processing.MBSequenceDataset
ds_train_copy = _SeqDS._clone_untransformed_dataset(ds_train)
ds_test_copy = _SeqDS._clone_untransformed_dataset(ds_test)

# --- train/val loaders (scalers fit on TRAIN, applied to whole ds) ---
_loader_kwargs = dict(
    train_idx=train_idx,
    val_idx=val_idx,
    batch_size_train=64,
    batch_size_val=128,
    seed=cfg.seed,
    fit_and_transform=True,  # fit scalers on TRAIN, transform Xm/Xs/y in-place
    shuffle_train=True,
    use_weighted_sampler=True,  # use weighted sampler for training
)
train_dl, val_dl = ds_train_copy.make_loaders(**_loader_kwargs)

# --- test loader (reuses train scalers) ---
test_dl = _SeqDS.make_test_loader(ds_test_copy,
                                  ds_train_copy,
                                  batch_size=128,
                                  seed=cfg.seed)

# --- build model, resolve loss, load saved weights ---
# Nothing is trained in this cell: a previously saved state dict is read
# from disk. The original assigned a date-stamped file name and overwrote it
# on the very next line; that dead assignment is now a comment, matching the
# simple-LSTM cell. To target a file dated today, use
# f"models/lstm_model_{current_date}_two_heads.pt" instead.
current_date = datetime.now().strftime("%Y-%m-%d")
model_filename = "models/lstm_model_2025-09-19_two_heads.pt"
model = mbm.models.LSTM_MB.build_model_from_params(cfg, custom_params, device)
loss_fn = mbm.models.LSTM_MB.resolve_loss_fn(custom_params)

# Load the best weights
best_state = torch.load(model_filename, map_location=device)
model.load_state_dict(best_state)

# Evaluate on test
test_metrics_LSTM_2heads, grouped_ids_LSTM_2heads = model.evaluate_with_preds(
    device, test_dl, ds_test_copy)
test_rmse_a = test_metrics_LSTM_2heads['RMSE_annual']
test_rmse_w = test_metrics_LSTM_2heads['RMSE_winter']

print(f'Test RMSE annual: {test_rmse_a:.3f} | winter: {test_rmse_w:.3f}')

### MLP:

In [None]:
# Module-level slots read by my_train_split; both start out as None.
dataset, dataset_val = None, None


def my_train_split(ds, y=None, **fit_params):
    """Return the module-level ``(dataset, dataset_val)`` pair.

    All arguments are accepted and ignored; the function only hands back
    whatever the two globals hold at call time.
    """
    train_part, val_part = dataset, dataset_val
    return train_part, val_part


# --- features ---
# Elevation difference and pcsr first, then the topographical variables,
# then the climate variables.
features_topo = ['ELEVATION_DIFFERENCE', 'pcsr', *vois_topographical]
feature_columns = [*features_topo, *vois_climate]
cfg.setFeatures(feature_columns)

# Features plus the non-feature fields listed in the config.
all_columns = feature_columns + cfg.fieldsNotFeatures

# --- test subset (rows are kept as they are; no dropna is applied here) ---
df_X_test_subset_MLP = test_set['df_X'].loc[:, all_columns].copy()
y_test = test_set['y']

# --- load saved model (no callbacks/scheduler/early stop needed) ---
from pathlib import Path
import pickle

# File names of the saved MLP model and of its pickled parameters
# (adjust both if needed).
model_filename, params_filename = ("nn_model_2025-09-22.pt",
                                   "nn_params_2025-09-22.pkl")

# NOTE: unpickling can execute arbitrary code; only load trusted files.
with open("models/" + params_filename, "rb") as f:
    params = pickle.load(f)

# The number of inputs equals the number of feature columns.
nInp = len(feature_columns)

# Arguments for rebuilding the regressor; the architecture settings come
# from the pickled parameters.
args = dict(
    module=FlexibleNetwork,
    nbFeatures=nInp,
    module__input_dim=nInp,
    module__dropout=params['module__dropout'],
    module__hidden_layers=params['module__hidden_layers'],
    module__use_batchnorm=params['module__use_batchnorm'],
    # the rest is irrelevant for inference but harmless if present:
    batch_size=params.get('batch_size', 128),
    verbose=0,
)

# Load the saved model on the CPU.
loaded_MLP = mbm.models.CustomNeuralNetRegressor.load_model(
    cfg, model_filename, **args, device='cpu')
loaded_MLP = loaded_MLP.to('cpu')

# --- evaluate on test ---
_mlp_eval = evaluate_model_and_group_predictions(loaded_MLP,
                                                 df_X_test_subset_MLP,
                                                 y_test, cfg,
                                                 months_head_pad,
                                                 months_tail_pad)
grouped_ids_NN, scores_NN, ids_NN, y_pred_NN = _mlp_eval

print("Test scores:", scores_NN)

## Compare models:

### Scatter on test glaciers:

In [None]:
def _metrics_legend(rmse_a, rmse_w, r2_a, r2_w, bias_a, bias_w):
    """Build the three-line RMSE / R^2 / bias annotation text of a panel."""
    return "\n".join((
        r"$\mathrm{RMSE_a}=%.3f$, $\mathrm{RMSE_w}=%.3f$," % (rmse_a, rmse_w),
        r"$\mathrm{R^2_a}=%.3f$, $\mathrm{R^2_w}=%.3f$" % (r2_a, r2_w),
        r"$\mathrm{B_a}=%.3f$, $\mathrm{B_w}=%.3f$" % (bias_a, bias_w),
    ))


def _lstm_legend(metrics):
    """Annotation text built from an LSTM test-metrics mapping."""
    return _metrics_legend(metrics["RMSE_annual"], metrics["RMSE_winter"],
                           metrics["R2_annual"], metrics["R2_winter"],
                           metrics["Bias_annual"], metrics["Bias_winter"])


def _scatter_panel(ax, title, grouped_ids):
    """Title the axis and draw the scatter plot coloured by PERIOD."""
    ax.set_title(title, fontsize=20)
    predVSTruth(ax,
                grouped_ids,
                hue='PERIOD',
                palette=[color_dark_blue, color_pink])


def _annotate(ax, text):
    """Write the metrics text in the top-left corner of the axis."""
    ax.text(0.03,
            0.96,
            text,
            transform=ax.transAxes,
            verticalalignment="top",
            fontsize=18,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0))


fig = plt.figure(figsize=(18, 8))

# Panel 1: simple LSTM.
ax1 = plt.subplot(1, 3, 1)
_scatter_panel(ax1, 'LSTM simple', grouped_ids_LSTM_simple)
legend = _lstm_legend(test_metrics_LSTM_simple)
_annotate(ax1, legend)

# Panel 2: two-heads LSTM.
ax2 = plt.subplot(1, 3, 2)
_scatter_panel(ax2, 'LSTM two heads', grouped_ids_LSTM_2heads)
legend = _lstm_legend(test_metrics_LSTM_2heads)
_annotate(ax2, legend)

# Panel 3: MLP; its seasonal scores are computed from the grouped predictions.
ax3 = plt.subplot(1, 3, 3)
scores_annual_NN, scores_winter_NN = compute_seasonal_scores(
    grouped_ids_NN, target_col='target', pred_col='pred')
_scatter_panel(ax3, 'MLP', grouped_ids_NN)
legend_NN = _metrics_legend(scores_annual_NN["rmse"],
                            scores_winter_NN["rmse"], scores_annual_NN["R2"],
                            scores_winter_NN["R2"], scores_annual_NN["Bias"],
                            scores_winter_NN["Bias"])
_annotate(ax3, legend_NN)

plt.tight_layout()

### Geodetic MB:

In [None]:
# All prediction folders live under <dataPath>/GLAMOS/distributed_MB_grids.
_MB_GRIDS_DIR = os.path.join(cfg.dataPath, 'GLAMOS', 'distributed_MB_grids')

PATH_PREDICTIONS_NN = os.path.join(
    _MB_GRIDS_DIR, 'MBM/testing_combis/glamos_dems_NN_SEB_full_OGGM')
PATH_PREDICTIONS_LSTM_simple = os.path.join(_MB_GRIDS_DIR,
                                            'MBM/testing_LSTM/LSTM_simple')
PATH_PREDICTIONS_LSTM_two_heads = os.path.join(
    _MB_GRIDS_DIR, 'MBM/testing_LSTM/LSTM_two_heads_best')

PATH_PREDICTIONS_LSTM_no_oggm = os.path.join(
    _MB_GRIDS_DIR, 'MBM/testing_LSTM/LSTM_no_oggm')

In [None]:
# Entries of the MLP prediction folder.
glaciers_in_glamos = os.listdir(PATH_PREDICTIONS_NN)

geodetic_mb = get_geodetic_MB(cfg)

# get years per glacier: unique 'Astart' / 'Aend' values as lists, keyed by
# glacier name
_by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = _by_glacier['Astart'].unique().apply(list).to_dict()
years_end_per_gl = _by_glacier['Aend'].unique().apply(list).to_dict()

periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glacier areas; 'clariden' is added as an alias of the 'claridenL' entry.
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']


def sort_by_area(glacier_list, gl_area):
    """Return a new list of glaciers ordered by ascending area.

    Glaciers without an entry in ``gl_area`` are treated as having area 0.
    The sort is stable, so glaciers with equal area keep their input order.
    """
    ordered = list(glacier_list)
    ordered.sort(key=lambda name: gl_area.get(name, 0))
    return ordered


# Keep the glaciers that have geodetic periods and also appear in the MLP
# prediction folder, ordered by area.
glacier_list = sort_by_area(
    [name for name in periods_per_glacier if name in glaciers_in_glamos],
    gl_area)
print('Number of glaciers:', len(glacier_list))
print('Glaciers:', glacier_list)

In [None]:
# Arguments shared by the geodetic comparisons in this block; only the
# glacier list and the prediction folder differ between models.
_compare_to_geodetic = partial(
    process_geodetic_mass_balance_comparison,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    cfg=cfg,
)

# MLP: uses every entry of its prediction folder rather than glacier_list.
df_all_nn = _compare_to_geodetic(
    glacier_list=os.listdir(PATH_PREDICTIONS_NN),
    path_predictions=PATH_PREDICTIONS_NN)

# Simple LSTM.
df_lstm_simple = _compare_to_geodetic(
    glacier_list=glacier_list,
    path_predictions=PATH_PREDICTIONS_LSTM_simple)

# Two-heads LSTM.
df_lstm_two_heads = _compare_to_geodetic(
    glacier_list=glacier_list,
    path_predictions=PATH_PREDICTIONS_LSTM_two_heads)

# Geodetic comparison for the LSTM trained without OGGM features.
# Fix: this call previously read PATH_PREDICTIONS_LSTM_two_heads, so the
# "no_oggm" results were a duplicate of the two-heads results. It now reads
# the dedicated PATH_PREDICTIONS_LSTM_no_oggm folder.
df_lstm_no_oggm = process_geodetic_mass_balance_comparison(
    glacier_list=glacier_list,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=TEST_GLACIERS,
    path_predictions=PATH_PREDICTIONS_LSTM_no_oggm,
    cfg=cfg)


In [None]:
def _rmse_and_corr(df):
    """Return (RMSE, Pearson correlation) of 'Geodetic MB' vs 'MBM MB'."""
    rmse = root_mean_squared_error(df["Geodetic MB"], df["MBM MB"])
    corr = np.corrcoef(df["Geodetic MB"], df["MBM MB"])[0, 1]
    return rmse, corr


# For each model: drop rows where either required column is NaN, then
# compute RMSE and Pearson correlation on what remains.
df_all_nn = df_all_nn.dropna(subset=['Geodetic MB', 'MBM MB'])
rmse_nn, corr_nn = _rmse_and_corr(df_all_nn)

df_lstm_simple = df_lstm_simple.dropna(subset=['Geodetic MB', 'MBM MB'])
rmse_lstm_simple, corr_lstm_simple = _rmse_and_corr(df_lstm_simple)

df_lstm_two_heads = df_lstm_two_heads.dropna(subset=['Geodetic MB', 'MBM MB'])
rmse_lstm_two_heads, corr_lstm_two_heads = _rmse_and_corr(df_lstm_two_heads)

df_lstm_no_oggm = df_lstm_no_oggm.dropna(subset=['Geodetic MB', 'MBM MB'])
rmse_lstm_no_oggm, corr_lstm_no_oggm = _rmse_and_corr(df_lstm_no_oggm)

def _plot_by_area_bin(df):
    """Plot MBM vs geodetic MB per area bin with the shared bin settings.

    Returns whatever plot_mbm_vs_geodetic_by_area_bin returns, so the value
    of the cell's last expression is unchanged.
    """
    return plot_mbm_vs_geodetic_by_area_bin(
        df,
        bins=[0, 1, 5, 10, 100, np.inf],
        labels=['<1', '1-5', '5–10', '>10', '>100'],
        max_bins=4)


# Print the headline metrics of each model, then draw its per-area-bin plot.
print(f'NN MLP: RMSE = {rmse_nn:.3f}, Corr = {corr_nn:.3f}')
_plot_by_area_bin(df_all_nn)

print(
    f'LSTM simple: RMSE = {rmse_lstm_simple:.3f}, Corr = {corr_lstm_simple:.3f}'
)
_plot_by_area_bin(df_lstm_simple)

print(
    f'LSTM two heads: RMSE = {rmse_lstm_two_heads:.3f}, Corr = {corr_lstm_two_heads:.3f}'
)
_plot_by_area_bin(df_lstm_two_heads)

print(
    f'LSTM no_oggm: RMSE = {rmse_lstm_no_oggm:.3f}, Corr = {corr_lstm_no_oggm:.3f}'
)
_plot_by_area_bin(df_lstm_no_oggm)