## Set up:

In [None]:
import os, sys
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import matplotlib.pyplot as plt
import numpy as np
from os import listdir
from os.path import join
import pandas as pd
import re
import xarray as xr
import geopandas as gpd
from shapely.geometry import box
from tqdm.notebook import tqdm
from sklearn.metrics import mean_squared_error
import massbalancemachine as mbm
from cmcrameri import cm
from collections import defaultdict
import warnings

from scripts.geodata import *
from scripts.geodata_plots import *
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.xgb_helpers import *
from scripts.config_CH import *

# Suppress ALL Python warnings for the rest of the session. The original intent
# was to silence Cartopy warnings when downloading data files, but 'ignore'
# applies to every warning category and module, so it also hides e.g.
# deprecation warnings from pandas / scikit-learn / xarray.
warnings.filterwarnings('ignore')

# IPython magics (notebook-only syntax, not valid in a plain .py script):
# automatically reload imported modules before executing code, so edits to
# the `scripts/*` helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# MassBalanceMachine configuration for Switzerland; `cfg.dataPath` is the root
# prefix used to build the data paths throughout this notebook.
cfg = mbm.SwitzerlandConfig()

# Root folder of the MBM distributed predictions; the cells below read
# per-glacier sub-folders containing "<glacier>_<year>_annual.zarr" stores.
# Change this to the path where the distributed model data is stored.
# NOTE(review): built by plain string concatenation, so it assumes cfg.dataPath
# and `path_distributed_MB_glamos` (presumably provided by one of the
# `from scripts.* import *` lines — confirm) both end with '/'.
PATH_PREDICTIONS = cfg.dataPath + path_distributed_MB_glamos + 'MBM/glamos_dems_corr/'

# Glaciers compared against the geodetic MB in the "All glaciers in list" section.
GLACIER_LIST = ['aletsch', 'silvretta', 'rhone']

In [None]:
# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

cmap = cm.devon
color_palette_glaciers = sns.color_palette(get_cmap_hex(cmap, 15))

colors = get_cmap_hex(cm.batlow, 2)
color_xgb = colors[0]
color_tim = '#c51b7d'

vois_climate = [
    't2m_corr', 'tp_corr', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'u10', 'v10'
]
vois_topographical = [
    "aspect_sgi",
    "slope_sgi",
    "hugonnet_dhdt",
    "consensus_ice_thickness",
    "millan_v",
]
# Feature columns:
feature_columns = [
    'ELEVATION_DIFFERENCE'
] + list(vois_climate) + list(vois_topographical) + ['pcsr']
all_columns = feature_columns + cfg.fieldsNotFeatures

In [None]:
# === Load RGI glacier IDs ===
rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids)
rgi_df.columns = rgi_df.columns.str.strip()
rgi_df = rgi_df.sort_values(by='short_name').set_index('short_name')

# === Load SGI region geometries ===
SGI_regions = gpd.read_file(
    os.path.join(cfg.dataPath, path_SGI_topo, 'inventory_sgi2016_r2020',
                 'sgi_regions.geojson'))

# Clean object columns
SGI_regions[SGI_regions.select_dtypes(include='object').columns] = \
    SGI_regions.select_dtypes(include='object').apply(lambda col: col.str.strip())

SGI_regions = SGI_regions.drop_duplicates().dropna()
SGI_regions = SGI_regions.set_index('pk_sgi_region')


# Final formatting
rgi_df = rgi_df.reset_index().rename(columns={
    'short_name': 'GLACIER'
}).set_index('GLACIER')

geodetic_mb = get_geodetic_MB(cfg)
periods_per_glacier, geoMB_per_glacier = build_periods_per_glacier(geodetic_mb)

# Glaciers with geodetic MB data:
# Sort glaciers by area
gl_area = get_gl_area(cfg)
gl_area['clariden'] = gl_area['claridenL']

## Geodetic MB:

#### Individual match:

In [None]:
# For one glacier, compare over each geodetic period: the mean annual MB
# predicted by MBM, the mean GLAMOS glacier-wide MB, and the geodetic MB.
GLACIER_NAME = 'aletsch'
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# If GLAMOS data is missing, skip both the processing and the plot
if GLAMOS_glwmb is None:
    print(f"Skipping {GLACIER_NAME}: No GLAMOS data available.")
else:
    # Geodetic periods (indexable as [start_year, end_year]) and the geodetic
    # MB value for each period, in the same order.
    periods = periods_per_glacier.get(GLACIER_NAME, [])
    geoMBs = geoMB_per_glacier.get(GLACIER_NAME, [])
    if len(periods) != len(geoMBs):
        raise ValueError(
            f"{GLACIER_NAME}: {len(periods)} geodetic periods but "
            f"{len(geoMBs)} geodetic MB values.")

    # PATH_PREDICTIONS already starts with cfg.dataPath, so it must not be
    # prefixed with cfg.dataPath again when building the file paths below.
    path_mbm_pred = os.path.join(PATH_PREDICTIONS, GLACIER_NAME)

    # Per-period results. `geodetic_mb_period` is deliberately not called
    # `geodetic_mb`, so the global geodetic table loaded earlier is not clobbered.
    mbm_mb_mean, glamos_mb_mean, geodetic_mb_period, target_period = [], [], [], []

    for period, geo_mb in zip(periods, geoMBs):
        mbm_mb, glamos_mb = [], []

        for year in range(period[0], period[1] + 1):
            file_path = os.path.join(path_mbm_pred,
                                     f"{GLACIER_NAME}_{year}_annual.zarr")

            # Missing prediction store -> NaN for that year
            if not os.path.exists(file_path):
                print(
                    f"Warning: Missing MBM file for {GLACIER_NAME} ({year}). Skipping..."
                )
                mbm_mb.append(np.nan)
            else:
                # Glacier-wide MB = mean of the masked prediction grid; the
                # context manager closes the store once the value is computed.
                with xr.open_dataset(file_path) as ds:
                    mbm_mb.append(ds["pred_masked"].mean().item())

            # GLAMOS balance for the year, or NaN if missing
            glamos_mb.append(GLAMOS_glwmb["GLAMOS Balance"].get(year, np.nan))

        # Period means ignoring missing years (an all-NaN period yields NaN)
        mbm_mb_mean.append(np.nanmean(mbm_mb))
        glamos_mb_mean.append(np.nanmean(glamos_mb))
        geodetic_mb_period.append(geo_mb)
        target_period.append(period)

    # Store everything in a DataFrame (column names expected by plot_geodetic_MB)
    df = pd.DataFrame({
        'mbm_mb_mean': mbm_mb_mean,
        'glamos_mb_mean': glamos_mb_mean,
        'geodetic_mb': geodetic_mb_period,
        'target_period': target_period,
        'end_year': [period[1] for period in target_period],
        'start_year': [period[0] for period in target_period],
    })

    fig = plot_geodetic_MB(df, GLACIER_NAME, color_xgb, color_tim)

#### All glaciers in list:

In [None]:
# Glaciers passed to the comparison helper as `test_glaciers`
# (presumably the glaciers held out from training — confirm).
test_glaciers = [
    'tortin', 'plattalva', 'sanktanna', 'schwarzberg', 'hohlaub', 'pizol',
    'corvatsch', 'tsanfleuron', 'forno'
]
df_all = process_geodetic_mass_balance_comparison(
    glacier_list=GLACIER_LIST,
    path_SMB_GLAMOS_csv=cfg.dataPath + path_SMB_GLAMOS_csv,
    periods_per_glacier=periods_per_glacier,
    geoMB_per_glacier=geoMB_per_glacier,
    gl_area=gl_area,
    test_glaciers=test_glaciers,
    path_predictions=PATH_PREDICTIONS,  # or another path if needed
    cfg=cfg,
)

# Keep only rows where geodetic, MBM and GLAMOS MB are all available
df_all = df_all.dropna(subset=['Geodetic MB', 'MBM MB', 'GLAMOS MB'])

if df_all.empty:
    # mean_squared_error raises on zero samples, so stop here instead
    print("No periods with geodetic, MBM and GLAMOS MB all available; "
          "nothing to plot.")
else:
    # RMSE = sqrt(MSE). `mean_squared_error(..., squared=False)` was deprecated
    # in scikit-learn 1.4 and removed in 1.6, so the root is taken explicitly;
    # this works on every scikit-learn version.
    rmse_mbm = np.sqrt(
        mean_squared_error(df_all["Geodetic MB"], df_all["MBM MB"]))
    corr_mbm = np.corrcoef(df_all["Geodetic MB"], df_all["MBM MB"])[0, 1]
    rmse_glamos = np.sqrt(
        mean_squared_error(df_all["Geodetic MB"], df_all["GLAMOS MB"]))
    corr_glamos = np.corrcoef(df_all["Geodetic MB"], df_all["GLAMOS MB"])[0, 1]

    # Single axes: MBM MB vs geodetic MB
    fig, axs = plt.subplots(1, 1, figsize=(10, 8))
    plot_scatter(df_all, 'GLACIER', True, axs, "MBM MB",
                 rmse_mbm, corr_mbm)

    axs.set_title('Geodetic target vs Mass Balance Machine', fontsize=24)
    axs.set_ylabel('Predicted mass balance [m w.e.]', fontsize=18)
    axs.set_xlabel('Geodetic mass balance [m w.e.]', fontsize=18)

    # Legend placed outside the plot area
    handles, labels = axs.get_legend_handles_labels()
    fig.legend(handles,
               labels,
               bbox_to_anchor=(1.05, 1),
               loc="upper left",
               borderaxespad=0.,
               ncol=2,
               fontsize=20)

    plt.tight_layout()
    plt.show()

## Glacier wide MB:
Compare the MBM-predicted annual glacier-wide mass balance against GLAMOS.

In [None]:
# Define glacier name
GLACIER_NAME = "aletsch"

# Load GLAMOS data
GLAMOS_glwmb = get_GLAMOS_glwmb(GLACIER_NAME, cfg)

# Define the path to model predictions
path_results = os.path.join(PATH_PREDICTIONS, GLACIER_NAME)

# Extract available years from NetCDF filenames
years = sorted([
    int(f.split("_")[1]) for f in os.listdir(path_results)
    if f.endswith("_annual.zarr")
])

# Extract model-predicted mass balance
pred_gl = []
for year in years:
    file_path = os.path.join(path_results,
                             f"{GLACIER_NAME}_{year}_annual.zarr")
    if not os.path.exists(file_path):
        print(
            f"Warning: Missing MBM file for {GLACIER_NAME} ({year}). Skipping..."
        )
        pred_gl.append(np.nan)
        continue

    ds = xr.open_dataset(file_path)
    pred_gl.append(ds.pred_masked.mean().item())

# Create DataFrame
MBM_glwmb = pd.DataFrame(pred_gl, index=years, columns=["MBM Balance"])

# Merge with GLAMOS data
MBM_glwmb = MBM_glwmb.join(GLAMOS_glwmb)

# Drop NaN values to avoid plotting errors
MBM_glwmb = MBM_glwmb.dropna()

# Plot the data
fig, ax = plt.subplots(figsize=(12, 6))
MBM_glwmb.plot(ax=ax, marker="o", color=[color_xgb, color_tim])
ax.set_title(f"{GLACIER_NAME.capitalize()} Glacier", fontsize=24)
ax.set_ylabel("Mass Balance [m w.e.]", fontsize=18)
ax.set_xlabel("Year", fontsize=18)
ax.grid(True, linestyle="--", linewidth=0.5)

# increase font size of legends
ax.legend(fontsize=14)
plt.show()


## 2D mass balance maps:

In [None]:
# Load stake data ONCE instead of for every glacier
stake_file = os.path.join(cfg.dataPath, path_PMB_GLAMOS_csv, "CH_wgms_dataset_all.csv")
df_stakes = pd.read_csv(stake_file)

### Validate against stake data:

In [None]:
# Map the 2D mass balance of one glacier-year using the GLAMOS distributed MB
# folder, the MBM predictions and the stake data (`plot_mass_balance` is a
# helper defined outside this notebook).
GLACIER_NAME, year = "aletsch", 2022
glamos_dist_dir = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                               'GLAMOS')
fig = plot_mass_balance(GLACIER_NAME, year, df_stakes, glamos_dist_dir,
                        PATH_PREDICTIONS)
plt.show()

In [None]:
# Optional: render and save the 2D mass-balance map for every available year
# of one glacier. Disabled by default; set RUN_ALL_YEARS = True to run it.
# Kept as real code behind a flag (rather than commented-out lines) so it
# stays syntactically valid as the helpers evolve.
RUN_ALL_YEARS = False

if RUN_ALL_YEARS:
    GLACIER_NAME = "aletsch"

    # Years parsed from "<glacier>_<year>_annual.zarr" prediction stores
    pred_name_re = re.compile(
        rf"{re.escape(GLACIER_NAME)}_(\d+)_annual\.zarr")
    years = sorted(
        int(match.group(1))
        for match in map(pred_name_re.fullmatch,
                         os.listdir(os.path.join(PATH_PREDICTIONS,
                                                 GLACIER_NAME)))
        if match
    )

    # Output folder: create it if missing, otherwise empty it
    # (`emptyfolder` is presumably provided by `scripts.helpers` — confirm).
    out_dir = f"figures/dst_mb/{GLACIER_NAME}"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    else:
        emptyfolder(out_dir)

    glamos_dist_dir = os.path.join(cfg.dataPath, path_distributed_MB_glamos,
                                   'GLAMOS')
    for year in tqdm(years):
        print(f"Processing: {GLACIER_NAME}, Year: {year}")

        fig = plot_mass_balance(GLACIER_NAME, year, df_stakes,
                                glamos_dist_dir, PATH_PREDICTIONS)

        # plot_mass_balance may return no figure; only save/close real ones.
        # Close that specific figure so figures do not accumulate in memory.
        if fig:
            fig.savefig(os.path.join(out_dir, f"{GLACIER_NAME}_{year}.png"),
                        dpi=300,
                        bbox_inches="tight")
            plt.close(fig)