# Glacier grids from RGI:

Creates monthly grid files for the MassBalanceMachine (MBM) so that point mass balance (PMB) can be predicted over each glacier's full grid. The grids are built for the RGI glaciers using OGGM topography. Computing them takes a long time because of the conversion to monthly format.
## Setting up:

In [None]:
import pandas as pd
import os
import warnings
from tqdm.notebook import tqdm
import re
import massbalancemachine as mbm
import geopandas as gpd
import matplotlib.pyplot as plt
from cmcrameri import cm
from oggm import utils, workflow
from oggm import cfg as oggmCfg
import geopandas as gpd
import geopandas as gpd
import traceback
import salem
import oggm
import pickle
from skorch.callbacks import EarlyStopping, LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from skorch.helper import SliceDataset
from cartopy import crs as ccrs, feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from collections import Counter

# scripts
from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.geodata import *
from scripts.xgb_helpers import *
from scripts.config_CH import *
from scripts.NN_networks import *
from scripts.nn_helpers import *
# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')
# IPython magics: automatically reload edited modules (e.g. scripts/*)
# before each cell is executed.
%load_ext autoreload
%autoreload 2

# Swiss MBM configuration; provides dataPath, seed, metaData, ... used below.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Seed all random generators for reproducibility.
seed_all(cfg.seed)
free_up_cuda()  # in case no memory (presumably releases cached GPU memory)

# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)

# Climate columns (ERA5 variables) extracted for every grid point when the
# monthly grids are built.
vois_climate = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'u10', 'v10'
]
# Topographical columns (OGGM gridded variables) kept in the monthly grids.
# Both lists are redefined (shorter) in the "Train NN model" section.
vois_topographical = [
    "aspect",
    "slope",
    "hugonnet_dhdt",
    "consensus_ice_thickness",
    "millan_v",
    "topo",
]

# RGI glacier outlines; used for the maps and for glacier areas (`Area`).
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)


In [None]:
# Initialise OGGM glacier directories for RGI v6 region 11 from the OGGM v1.6
# pre-processed directories (L3-L5, elevation bands, W5E5 climate; see base_url).
gdirs, rgidf = initialize_oggm_glacier_directories(
    cfg,
    rgi_region="11",
    rgi_version="6",
    base_url=
    "https://cluster.klima.uni-bremen.de/~oggm/gdirs/oggm_v1.6/L3-L5_files/2023.1/elev_bands/W5E5_w_data/",
    log_level='WARNING',
    task_list=None,
)
# Export the OGGM gridded data of all glaciers in RGI v6 region 11. The
# per-glacier `.zarr` grids are read back from `<dataPath><path_OGGM>xr_grids/`
# (presumably written there by this helper; verify).
export_oggm_grids(cfg, gdirs)

In [None]:
# Glacier id lookup table: strip whitespace from the headers, sort and index
# by the short glacier name, then show the Rhone entry.
rgi_df = (pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
          .rename(columns=lambda col: col.strip())
          .sort_values(by='short_name')
          .set_index('short_name'))
rgi_df.loc['rhone']

## Create RGI grids for all glaciers:

In [None]:
# Folder with one OGGM `.zarr` grid per glacier
path_RGIs = cfg.dataPath + path_OGGM + 'xr_grids/'
glaciers = os.listdir(path_RGIs)

print(f"Found {len(glaciers)} glaciers in RGI region 11.6")

# Inspect one example glacier (Rhone, RGI60-11.01238)
# rgi_gl = gdirs[0].rgi_id
rgi_gl = 'RGI60-11.01238'
ds = xr.open_dataset(path_RGIs + rgi_gl + '.zarr')

# NaN outside the glacier, 1 inside: multiplying by it blanks off-glacier cells
glacier_mask = np.where(ds['glacier_mask'].values == 0, np.nan,
                        ds['glacier_mask'].values)

# Core fields are always masked...
ds = ds.assign(**{
    masked_name: glacier_mask * ds[source_name]
    for masked_name, source_name in (
        ('masked_slope', 'slope'),
        ('masked_elev', 'topo'),
        ('masked_aspect', 'aspect'),
        ('masked_dis', 'dis_from_border'),
    )
})
# ...optional fields only when the glacier has them
ds = ds.assign(**{
    masked_name: glacier_mask * ds[source_name]
    for masked_name, source_name in (
        ('masked_hug', 'hugonnet_dhdt'),
        ('masked_cit', 'consensus_ice_thickness'),
        ('masked_miv', 'millan_v'),
    )
    if source_name in ds
})

glacier_indices = np.where(ds['glacier_mask'].values == 1)

# Titles are set after plotting because xarray's .plot sets its own title
fig, axs = plt.subplots(1, 4, figsize=(16, 8), sharey=True)

ds.masked_aspect.plot(ax=axs[0], cmap='twilight_shifted', add_colorbar=False)
ds.masked_slope.plot(ax=axs[1], cmap='cividis', add_colorbar=False)
ds.masked_elev.plot(ax=axs[2], cmap='terrain', add_colorbar=False)
ds.glacier_mask.plot(ax=axs[3], cmap='binary', add_colorbar=False)

axs[0].set_title("Aspect OGGM")
axs[1].set_title("Slope OGGM")
axs[2].set_title("DEM OGGM")
axs[3].set_title("Glacier mask OGGM")

In [None]:
def create_masked_glacier(path_RGIs, rgi_gl):
    """Open an OGGM glacier grid and add glacier-masked copies of its fields.

    Each ``masked_*`` variable equals its source field where
    ``glacier_mask == 1`` and NaN where ``glacier_mask == 0``.

    Parameters
    ----------
    path_RGIs : str
        Folder (with trailing separator) containing ``<rgi_gl>.zarr`` grids.
    rgi_gl : str
        RGI glacier id, e.g. ``'RGI60-11.01238'``.

    Returns
    -------
    ds : xarray.Dataset
        Dataset with ``masked_slope``, ``masked_elev``, ``masked_aspect`` and
        ``masked_dis`` added. ``masked_hug``, ``masked_cit`` and
        ``masked_miv`` are added only when their source fields exist.
    glacier_indices : tuple of numpy.ndarray
        ``np.where`` indices of the cells with ``glacier_mask == 1``.

    Raises
    ------
    ValueError
        If ``glacier_mask`` or a required field (``slope``, ``topo``,
        ``aspect``, ``dis_from_border``) is missing. The caller skips
        glaciers on ValueError, so a missing field must not surface as
        KeyError.
    """
    # masked name -> source field, always required
    required = {
        'masked_slope': 'slope',
        'masked_elev': 'topo',
        'masked_aspect': 'aspect',
        'masked_dis': 'dis_from_border',
    }
    # masked name -> source field, only present for some glaciers
    optional = {
        'masked_hug': 'hugonnet_dhdt',
        'masked_cit': 'consensus_ice_thickness',
        'masked_miv': 'millan_v',
    }

    ds = xr.open_dataset(path_RGIs + rgi_gl + '.zarr')

    if 'glacier_mask' not in ds:
        ds.close()
        raise ValueError(
            f"'glacier_mask' variable not found in dataset {rgi_gl}")

    missing = [src for src in required.values() if src not in ds]
    if missing:
        ds.close()
        raise ValueError(
            f"Required variable(s) {missing} not found in dataset {rgi_gl}")

    # NaN outside the glacier, 1 inside, so multiplication blanks other cells
    glacier_mask = np.where(ds['glacier_mask'].values == 0, np.nan,
                            ds['glacier_mask'].values)

    for masked_name, source_name in {**required, **optional}.items():
        if source_name in ds:
            ds = ds.assign({masked_name: glacier_mask * ds[source_name]})

    # Indices where glacier_mask == 1
    glacier_indices = np.where(ds['glacier_mask'].values == 1)

    return ds, glacier_indices

### Create masked grids:

In [None]:
# Output folder for the masked grids reprojected to lat/lon
path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS/topo/RGI_v6_11/',
                             'xr_masked_grids/')
RUN = False
if RUN:
    emptyfolder(path_xr_grids)

    for gdir in tqdm(gdirs):
        rgi_gl = gdir.rgi_id

        try:
            # Create masked glacier dataset
            ds, glacier_indices = create_masked_glacier(path_RGIs, rgi_gl)
        except ValueError as e:
            print(f"Skipping {rgi_gl}: {e}")
            continue  # Skip to next glacier

        dx_m, dy_m = get_res_from_projected(ds)

        # Coarsen grids with a resolution strictly between 20 m and 50 m to
        # 50 m; grids at <= 20 m or >= 50 m are kept unchanged.
        if 20 < dx_m < 50:
            ds = coarsenDS_mercator(ds, target_res_m=50)
            dx_m, dy_m = get_res_from_projected(ds)

        # Reproject from the OGGM projection (`pyproj_srs`) to lat/lon
        original_proj = ds.pyproj_srs
        ds = ds.rio.write_crs(original_proj)
        ds_latlon = ds.rio.reproject("EPSG:4326")
        ds_latlon = ds_latlon.rename({'x': 'lon', 'y': 'lat'})

        # Save xarray dataset
        save_path = os.path.join(path_xr_grids, f"{rgi_gl}.zarr")
        ds_latlon.to_zarr(save_path)

# Open the Rhone example (RGI60-11.01238). Fail clearly if it is absent
# instead of a NameError or silently reusing a stale `gdir_rhone`.
gdir_rhone = next((g for g in gdirs if g.rgi_id == 'RGI60-11.01238'), None)
if gdir_rhone is None:
    raise ValueError("Glacier RGI60-11.01238 not found in gdirs")

rgi_gl_rhone = gdir_rhone.rgi_id
ds = xr.open_dataset(path_xr_grids + rgi_gl_rhone + '.zarr')
fig, axs = plt.subplots(1, 4, figsize=(15, 6))
ds.masked_aspect.plot(ax=axs[0], cmap='twilight_shifted', add_colorbar=True)
ds.masked_slope.plot(ax=axs[1], cmap='cividis', add_colorbar=True)
ds.masked_elev.plot(ax=axs[2], cmap='terrain', add_colorbar=True)
ds.glacier_mask.plot(ax=axs[3], cmap='binary', add_colorbar=False)

axs[0].set_title("Aspect")
axs[1].set_title("Slope")
axs[2].set_title("DEM")
axs[3].set_title("Glacier mask")
plt.tight_layout()

In [None]:
# Open a second example glacier (RGI60-11.00878). It gets its own variable
# names so that `gdir_rhone` keeps pointing at the Rhone glacier. Fail clearly
# if the id is absent instead of re-plotting a previously selected glacier.
gdir_example = next((g for g in gdirs if g.rgi_id == 'RGI60-11.00878'),
                    None)
if gdir_example is None:
    raise ValueError("Glacier RGI60-11.00878 not found in gdirs")

rgi_gl_example = gdir_example.rgi_id
ds = xr.open_dataset(path_xr_grids + rgi_gl_example + '.zarr')
fig, axs = plt.subplots(1, 4, figsize=(15, 6))
ds.masked_aspect.plot(ax=axs[0], cmap='twilight_shifted', add_colorbar=True)
ds.masked_slope.plot(ax=axs[1], cmap='cividis', add_colorbar=True)
ds.masked_elev.plot(ax=axs[2], cmap='terrain', add_colorbar=True)
ds.glacier_mask.plot(ax=axs[3], cmap='binary', add_colorbar=False)

axs[0].set_title("Aspect")
axs[1].set_title("Slope")
axs[2].set_title("DEM")
axs[3].set_title("Glacier mask")
plt.tight_layout()

### Create monthly dataframes:

In [None]:
RUN = False
path_rgi_alps = os.path.join(cfg.dataPath,
                             'GLAMOS/topo/gridded_topo_inputs/RGI_v6_11/')

if RUN:
    years = range(2000, 2024)

    # Create the output folder if needed (harmless if it already exists).
    # Emptying it would discard finished glaciers, so that stays opt-in:
    os.makedirs(path_rgi_alps, exist_ok=True)
    #emptyfolder(path_rgi_alps)

    # Glaciers that have a masked lat/lon grid (set: O(1) membership tests)
    valid_rgis = {
        f.replace('.zarr', '') for f in os.listdir(path_xr_grids)
        if f.endswith('.zarr')
    }

    # Glaciers with an existing output folder are skipped (resume support).
    # NOTE(review): a glacier interrupted mid-way also has a folder and is
    # skipped next time; delete its folder to reprocess it.
    processed_rgis = set(os.listdir(path_rgi_alps))
    rest_rgis = list(valid_rgis - processed_rgis)
    print(f"Number of glaciers to process: {len(rest_rgis)}")

    # ERA5 inputs are the same for every glacier and year
    era5_climate_data = os.path.join(cfg.dataPath, path_ERA5_raw,
                                     'era5_monthly_averaged_data_Alps.nc')
    geopotential_data = os.path.join(cfg.dataPath, path_ERA5_raw,
                                     'era5_geopotential_pressure_Alps.nc')

    for gdir in tqdm(gdirs, desc="Processing glaciers"):
        # for gdir in [gdir_rhone]:  # For testing, only process one glacier
        rgi_gl = gdir.rgi_id

        if rgi_gl not in valid_rgis:
            print(f"Skipping {rgi_gl}: not found in valid RGI glaciers")
            continue
        if rgi_gl in processed_rgis:
            continue
        try:
            file_path = os.path.join(path_xr_grids, f"{rgi_gl}.zarr")
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Missing file: {file_path}")

            # Prefer consolidated zarr metadata, fall back if unavailable
            try:
                ds = xr.open_zarr(file_path, consolidated=True)
            except Exception:
                ds = xr.open_zarr(file_path)

            # Create glacier grid for all years
            try:
                df_grid = create_glacier_grid_RGI(ds, years, rgi_gl)
            except Exception as e:
                print(f"Failed creating glacier grid for {rgi_gl}: {e}")
                continue

            df_grid.reset_index(drop=True, inplace=True)

            # GLWD_ID = hash of "<RGIId>_<YEAR>"
            df_grid['GLWD_ID'] = [
                mbm.data_processing.utils.get_hash(f"{r}_{y}") for r, y in zip(
                    df_grid['RGIId'].astype(str), df_grid['YEAR'].astype(str))
            ]
            df_grid['GLWD_ID'] = df_grid['GLWD_ID'].astype(str)
            df_grid['GLACIER'] = df_grid['RGIId']

            # Prepare output folder (one parquet file per year inside)
            folder_path = os.path.join(path_rgi_alps, rgi_gl)
            os.makedirs(folder_path, exist_ok=True)

            for year in years:
                try:
                    df_grid_y = df_grid[df_grid.YEAR == year].copy()
                    if df_grid_y.empty:
                        continue

                    # Dataset creation & climate feature extraction
                    try:
                        dataset_grid_yearly = mbm.data_processing.Dataset(
                            cfg=cfg,
                            data=df_grid_y,
                            region_name='CH',
                            data_path=os.path.join(cfg.dataPath,
                                                   path_PMB_GLAMOS_csv))

                        dataset_grid_yearly.get_climate_features(
                            climate_data=era5_climate_data,
                            geopotential_data=geopotential_data,
                            change_units=True,
                            smoothing_vois={
                                'vois_climate': vois_climate,
                                'vois_other': ['ALTITUDE_CLIMATE']
                            })
                    except Exception as e:
                        print(
                            f"Failed adding climate features for {rgi_gl}: {e}"
                        )
                        continue

                    # Only the topographical variables this glacier has
                    vois_topographical_sub = [
                        voi for voi in vois_topographical
                        if voi in df_grid_y.columns
                    ]

                    dataset_grid_yearly.convert_to_monthly(
                        meta_data_columns=cfg.metaData,
                        vois_climate=vois_climate,
                        vois_topographical=vois_topographical_sub)

                    save_path = os.path.join(folder_path,
                                             f"{rgi_gl}_grid_{year}.parquet")
                    dataset_grid_yearly.data.to_parquet(save_path,
                                                        engine="pyarrow",
                                                        compression="snappy")

                except Exception as e:
                    print(f"Failed processing {rgi_gl} for year {year}: {e}")
                    continue

        except Exception as e:
            print(f"Error with glacier {rgi_gl}: {e}")
            continue

In [None]:
# Find the Rhone glacier directory (RGI60-11.01238)
for gdir in gdirs:
    if gdir.rgi_id == 'RGI60-11.01238':
        gdir_rhone = gdir

# Compare the September t2m values of the Rhone grid in two years
rgi_gl = gdir_rhone.rgi_id
for year in (2000, 2008):
    df = pd.read_parquet(
        os.path.join(path_rgi_alps, rgi_gl, f"{rgi_gl}_grid_{year}.parquet"))
    df = df[df.MONTHS == 'sep']
    print(df['t2m'].unique())

In [None]:
# Locate the Rhone glacier directory (RGI60-11.01238)
for gdir in gdirs:
    if gdir.rgi_id == 'RGI60-11.01238':
        gdir_rhone = gdir

# Look at one example: September rows of the 2008 monthly grid, with six
# variables shown as coloured lon/lat point clouds (one panel each).
year = 2008
rgi_gl = gdir_rhone.rgi_id

df = pd.read_parquet(
    os.path.join(path_rgi_alps, rgi_gl, f"{rgi_gl}_grid_{year}.parquet"))
df = df[df.MONTHS == 'sep']
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
# Variables to map, one per panel of the 2 x 3 grid
voi = [
    't2m', 'tp', 'ALTITUDE_CLIMATE', 'ELEVATION_DIFFERENCE', 'hugonnet_dhdt',
    'consensus_ice_thickness'
]
axs = axs.flatten()
for i, var in enumerate(voi):
    sns.scatterplot(df,
                    x='POINT_LON',
                    y='POINT_LAT',
                    hue=var,
                    s=5,
                    alpha=0.5,
                    palette='twilight_shifted',
                    ax=axs[i])

### Location of all glaciers:

In [None]:
# Mean grid-point position of every processed glacier, read from its `year`
# grid file.
# NOTE(review): `year` is inherited from an earlier cell (2008 when the
# notebook runs top to bottom).
# Only folders hold glacier outputs; ignore stray files such as .DS_Store.
rgi_ids = [
    entry for entry in os.listdir(path_rgi_alps)
    if os.path.isdir(os.path.join(path_rgi_alps, entry))
]
pos_gl = []
for rgi_gl in tqdm(rgi_ids):
    file_path = os.path.join(path_rgi_alps, rgi_gl,
                             f"{rgi_gl}_grid_{year}.parquet")
    if not os.path.exists(file_path):
        print(f"Skipping {rgi_gl}: no grid file for {year}")
        continue
    # Only the coordinates are needed, so do not load every column
    df = pd.read_parquet(file_path, columns=['POINT_LAT', 'POINT_LON'])
    # Store the id with its position so rows stay aligned when files are skipped
    pos_gl.append((df.POINT_LAT.mean(), df.POINT_LON.mean(), rgi_gl))
df_pos_all = pd.DataFrame(pos_gl, columns=['lat', 'lon', 'rgi_id'])

In [None]:
print('Number of glaciers in RGI region 11.6:', len(df_pos_all))

# ---- Base map of the Alps (lon 4-14 E, lat 44-48 N) ----
fig = plt.figure(figsize=(18, 10))

latN, latS = 48, 44
lonW, lonE = 4, 14
projPC = ccrs.PlateCarree()
ax2 = plt.axes(projection=projPC)
ax2.set_extent([lonW, lonE, latS, latN], crs=ccrs.Geodetic())

ax2.add_feature(cfeature.COASTLINE)
ax2.add_feature(cfeature.LAKES)
ax2.add_feature(cfeature.RIVERS)
ax2.add_feature(cfeature.BORDERS, linestyle='-', linewidth=1)

# ---- Mean grid-point position of each glacier ----
g = sns.scatterplot(
    data=df_pos_all,
    x='lon',
    y='lat',
    alpha=0.6,
    transform=projPC,
    ax=ax2,
    zorder=10,
    legend=True  # NOTE(review): no hue/size/style is set, so no legend entries are expected
)

# ---- RGI outlines ----
glacier_outline_rgi.plot(ax=ax2, transform=projPC, color='black')

# ---- Gridlines ----
gl = ax2.gridlines(draw_labels=True,
                   linewidth=1,
                   color='gray',
                   alpha=0.5,
                   linestyle='--')
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = {'size': 16, 'color': 'black'}
gl.ylabel_style = {'size': 16, 'color': 'black'}
gl.top_labels = gl.right_labels = False

## Train NN model:

### Input data:

In [None]:
# Feature subsets used to train the NN. Compared with the lists used to build
# the grids, the wind components (u10, v10) and 'topo' are left out.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']
vois_topographical = [
    "aspect", "slope", "hugonnet_dhdt", "consensus_ice_thickness", "millan_v"
]

In [None]:
data_glamos = getStakesData(cfg)

# drop taelliboden if in there
if 'taelliboden' in data_glamos['GLACIER'].unique():
    data_glamos = data_glamos[data_glamos['GLACIER'] != 'taelliboden']

print('-------------------')
print('Number of glaciers:', len(data_glamos['GLACIER'].unique()))
print('Number of winter and annual samples:', len(data_glamos))
print('Number of annual samples:',
      len(data_glamos[data_glamos.PERIOD == 'annual']))
print('Number of winter samples:',
      len(data_glamos[data_glamos.PERIOD == 'winter']))

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data':
    cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data':
    cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
RUN = False
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_central_alps.csv')
data_monthly = dataloader_gl.data
# Padding months at the start/end of the monthly sequences; passed to the
# AggregatedDataset and evaluation helpers.
months_head_pad, months_tail_pad = mbm.data_processing.utils.build_head_tail_pads_from_monthly_df(
    data_monthly)

# GLWD_ID = hash of "<GLACIER>_<YEAR>". A comprehension over the two columns
# gives the same strings as a row-wise DataFrame.apply, which is much slower on
# the full monthly table. It also matches how GLWD_ID is built for the grids.
data_monthly['GLWD_ID'] = [
    mbm.data_processing.utils.get_hash(f"{glacier}_{yr}")
    for glacier, yr in zip(data_monthly['GLACIER'], data_monthly['YEAR'])
]
data_monthly['GLWD_ID'] = data_monthly['GLWD_ID'].astype(str)

dataloader_gl = mbm.dataloader.DataLoader(cfg,
                                          data=data_monthly,
                                          random_seed=cfg.seed,
                                          meta_data_columns=cfg.metaData)


In [None]:
# Ensure all test glaciers exist in the dataset
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_glaciers = [g for g in TEST_GLACIERS if g not in existing_glaciers]

if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Training glaciers: every glacier that is not a test glacier (sorted so the
# order is reproducible across runs)
train_glaciers = sorted(g for g in existing_glaciers if g not in TEST_GLACIERS)

data_test = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(TEST_GLACIERS)]
print('Size of monthly test data:', len(data_test))

data_train = dataloader_gl.data[dataloader_gl.data.GLACIER.isin(
    train_glaciers)]
print('Size of monthly train data:', len(data_train))

if len(data_train) == 0:
    print("Warning: No training data available!")
else:
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))

# Number of annual versus winter measurements:
print('-------------\nTrain:')
print('Number of monthly winter and annual samples:', len(data_train))
print('Number of monthly annual samples:',
      len(data_train[data_train.PERIOD == 'annual']))
print('Number of monthly winter samples:',
      len(data_train[data_train.PERIOD == 'winter']))

# Same for test
data_test_annual = data_test[data_test.PERIOD == 'annual']
data_test_winter = data_test[data_test.PERIOD == 'winter']

print('Test:')
print('Number of monthly winter and annual samples:', len(data_test))
print('Number of monthly annual samples:', len(data_test_annual))
print('Number of monthly winter samples:', len(data_test_winter))

print('Total:')
print('Number of monthly rows:', len(dataloader_gl.data))
print('Number of annual rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'annual']))
print('Number of winter rows:',
      len(dataloader_gl.data[dataloader_gl.data.PERIOD == 'winter']))

# same for original data:
print('-------------\nIn annual format:')
print('Number of annual train rows:',
      len(data_glamos[data_glamos.GLACIER.isin(train_glaciers)]))
print('Number of annual test rows:',
      len(data_glamos[data_glamos.GLACIER.isin(TEST_GLACIERS)]))

print('---------------\n CV splits:')
splits, test_set, train_set = get_CV_splits(dataloader_gl,
                                            test_split_on='GLACIER',
                                            test_splits=TEST_GLACIERS,
                                            random_state=cfg.seed)

print('Test glaciers: ({}) {}'.format(len(test_set['splits_vals']),
                                      test_set['splits_vals']))
# Guard the ratio like above: an empty train set must not raise ZeroDivisionError
if len(train_set['df_X']) == 0:
    print("Warning: No training data available!")
else:
    test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
print('Size of test set:', len(test_set['df_X']))
print('Train glaciers: ({}) {}'.format(len(train_set['splits_vals']),
                                       train_set['splits_vals']))
print('Size of train set:', len(train_set['df_X']))


In [None]:
# Validation and train split. Work on a copy so that adding the 'y' column
# does not mutate train_set['df_X'] in place.
data_train = train_set['df_X'].copy()
data_train['y'] = train_set['y']
dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# Materialise all train/validation indices at once: the iterators are
# exhausted after being consumed.
train_indices, val_indices = list(train_itr), list(val_itr)

df_X_train = data_train.iloc[train_indices]
y_train = df_X_train['POINT_BALANCE'].values

# Get val set
df_X_val = data_train.iloc[val_indices]
y_val = df_X_val['POINT_BALANCE'].values

### Initialise network:

In [None]:
# Model inputs: elevation difference, then topographical and climate
# variables ('pcsr' is currently not used).
features_topo = ['ELEVATION_DIFFERENCE', *vois_topographical]
feature_columns = [*features_topo, *vois_climate]
cfg.setFeatures(feature_columns)

# Feature columns plus the configured non-feature columns
all_columns = feature_columns + cfg.fieldsNotFeatures

# CH data carries extra columns: restrict every split to all_columns
df_X_train_subset, df_X_val_subset, df_X_test_subset = (
    split[all_columns] for split in (df_X_train, df_X_val, test_set['df_X']))

print('Shape of training dataset:', df_X_train_subset.shape)
print('Shape of validation dataset:', df_X_val_subset.shape)
print('Shape of testing dataset:', df_X_test_subset.shape)
print('Running with features:', feature_columns)

# Sanity check: the training targets are the POINT_BALANCE column
assert all(train_set['df_X'].POINT_BALANCE == train_set['y'])

# Winter-only versions of the train/validation sets
df_X_train_subset_winter = df_X_train_subset.loc[
    df_X_train_subset['PERIOD'] == 'winter']
df_X_val_subset_winter = df_X_val_subset.loc[
    df_X_val_subset['PERIOD'] == 'winter']
y_train_w = df_X_train_subset_winter['POINT_BALANCE'].values
y_val_w = df_X_val_subset_winter['POINT_BALANCE'].values
print('Shape of training dataset only winter:', df_X_train_subset_winter.shape)
print('Shape of validation dataset only winter:', df_X_val_subset_winter.shape)

In [None]:
# Stop training when valid_loss has not improved by more than `threshold`
# (relative, skorch's default threshold_mode) for 15 epochs.
# NOTE(review): skorch's EarlyStopping restores the best weights only with
# load_best=True, which is not set here; confirm CustomNeuralNetRegressor
# handles this.
early_stop = EarlyStopping(
    monitor='valid_loss',
    patience=15,
    threshold=1e-4,  # Optional: stop only when improvement is very small
)

# Halve the learning rate when valid_loss improves by less than 1 %
# (relative) for 5 epochs.
lr_scheduler_cb = LRScheduler(policy=ReduceLROnPlateau,
                              monitor='valid_loss',
                              mode='min',
                              factor=0.5,
                              patience=5,
                              threshold=0.01,
                              threshold_mode='rel',
                              verbose=True)

# Train/validation datasets returned by my_train_split; built in the
# "Create datasets" cell.
dataset = dataset_val = None  # Initialized hereafter


def my_train_split(ds, y=None, **fit_params):
    """Skorch ``train_split`` hook returning the prebuilt datasets.

    The data skorch passes in (``ds``, ``y``, ``fit_params``) is ignored. The
    notebook-level ``dataset`` and ``dataset_val`` objects are returned as the
    (train, validation) pair instead.
    """
    train_and_valid = (dataset, dataset_val)
    return train_and_valid


# Train on CPU; use the commented line instead to train on the first GPU.
# param_init = {'device': 'cuda:0'}
param_init = {'device': 'cpu'}  # Use CPU for training
nInp = len(feature_columns)
print('Number of input features:', nInp)

# Network and optimiser hyperparameters
params = {
    'lr': 0.001,
    'batch_size': 256,
    'optimizer': torch.optim.AdamW,
    'optimizer__weight_decay': 1e-05,
    'module__hidden_layers': [256, 128, 64, 32],
    'module__dropout': 0.2,
    'module__use_batchnorm': True,
}

# Constructor arguments for CustomNeuralNetRegressor (presumably skorch-based:
# `module__*` / `optimizer__*` keys are routed to the network / optimiser).
# This dict is also pickled next to the trained model.
args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 200,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}

custom_nn = mbm.models.CustomNeuralNetRegressor(cfg, **args, **param_init)

### Create datasets:

In [None]:
# Split each frame into feature values and metadata (per cfg's feature list)
features, metadata = mbm.data_processing.utils.create_features_metadata(cfg, df_X_train_subset)

features_val, metadata_val = mbm.data_processing.utils.create_features_metadata(
    cfg, df_X_val_subset)

# Define the dataset for the NN: monthly rows padded with months_head_pad /
# months_tail_pad (presumably aggregated per measurement), then exposed as
# X / y through SliceDatasetBinding so skorch can consume them.
dataset = mbm.data_processing.AggregatedDataset(cfg,
                                                features=features,
                                                metadata=metadata,
                                                months_head_pad=months_head_pad,
                                                months_tail_pad=months_tail_pad,
                                                targets=y_train)
dataset = mbm.data_processing.SliceDatasetBinding(SliceDataset(dataset, idx=0),
                                                  SliceDataset(dataset, idx=1))
print("train:", dataset.X.shape, dataset.y.shape)

# Same wrapping for the validation split
dataset_val = mbm.data_processing.AggregatedDataset(cfg,
                                                    features=features_val,
                                                    metadata=metadata_val,
                                                    months_head_pad=months_head_pad,
                                                    months_tail_pad=months_tail_pad,
                                                    targets=y_val)
dataset_val = mbm.data_processing.SliceDatasetBinding(
    SliceDataset(dataset_val, idx=0), SliceDataset(dataset_val, idx=1))
print("validation:", dataset_val.X.shape, dataset_val.y.shape)

### Train custom NN model:

In [None]:
TRAIN = False
if TRAIN:
    # Explicit import: used for the date-stamped output file names
    from datetime import datetime

    custom_nn.seed_all()

    print("Training the model...")
    print('Model parameters:')
    for key, value in args.items():
        print(f"{key}: {value}")
    # The data passed to fit() is not used: my_train_split returns the
    # prebuilt `dataset` / `dataset_val` instead.
    custom_nn.fit(dataset.X, dataset.y)

    # Generate filename with current date
    current_date = datetime.now().strftime("%Y-%m-%d")
    model_filename = f"nn_model_{current_date}_CA"

    plot_training_history(custom_nn.history, skip_first_n=5)

    # NOTE(review): "best weights already loaded" requires EarlyStopping with
    # load_best=True or equivalent handling in CustomNeuralNetRegressor.
    custom_nn.save_model(model_filename)

    # Save the constructor arguments so the network can be rebuilt for loading;
    # make sure the target folder exists first.
    params_filename = f"nn_params_{current_date}_CA.pkl"
    os.makedirs("models", exist_ok=True)
    with open(f"models/{params_filename}", "wb") as f:
        pickle.dump(args, f)

### Load trained model:

In [None]:
# Load model and set to CPU
model_filename = "nn_model_2025-08-26_CA.pt"  # Replace with actual date if needed
# read pickle with params: this is the `args` dict saved at training time.
# NOTE: pickle.load can execute code from the file; only load trusted local files.
params_filename = "nn_params_2025-08-26_CA.pkl"  # Replace with actual date if needed
with open(f"models/{params_filename}", "rb") as f:
    custom_params = pickle.load(f)

params = custom_params

# The network is rebuilt for the current feature set. A model trained with a
# different number of inputs presumably cannot load its weights, so fail with
# a clear message instead of a cryptic shape mismatch.
saved_input_dim = custom_params.get('module__input_dim', nInp)
if saved_input_dim != nInp:
    raise ValueError(
        f"Saved model expects {saved_input_dim} input features but {nInp} "
        f"are configured: {feature_columns}")

args = {
    'module': FlexibleNetwork,
    'nbFeatures': nInp,
    'module__input_dim': nInp,
    'module__dropout': params['module__dropout'],
    'module__hidden_layers': params['module__hidden_layers'],
    'train_split': my_train_split,
    'batch_size': params['batch_size'],
    'verbose': 1,
    'iterator_train__shuffle': True,
    'lr': params['lr'],
    'max_epochs': 300,
    'optimizer': params['optimizer'],
    'optimizer__weight_decay': params['optimizer__weight_decay'],
    'module__use_batchnorm': params['module__use_batchnorm'],
    'callbacks': [
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ]
}

loaded_model = mbm.models.CustomNeuralNetRegressor.load_model(
    cfg,
    model_filename,
    **{
        **args,
        **param_init
    },
)
loaded_model = loaded_model.set_params(device='cpu')
loaded_model = loaded_model.to('cpu')

In [None]:
# Predict the test glaciers with the loaded model. `grouped_ids` holds the
# grouped 'target' and 'pred' values used for the annual/winter scores.
grouped_ids, scores_NN, ids_NN, y_pred_NN = evaluate_model_and_group_predictions(
    loaded_model, df_X_test_subset, test_set['y'], cfg, months_head_pad, months_tail_pad)
scores_annual, scores_winter = compute_seasonal_scores(grouped_ids,
                                                       target_col='target',
                                                       pred_col='pred')
# Predicted vs. observed summary, both axes limited to -8..6
fig = plot_predictions_summary(grouped_ids=grouped_ids,
                               scores_annual=scores_annual,
                               scores_winter=scores_winter,
                               predVSTruth=predVSTruth,
                               plotMeanPred=plotMeanPred,
                               ax_xlim=(-8, 6),
                               ax_ylim=(-8, 6))

### Extrapolate in space to RGI glaciers:

#### Check glaciers with missing topo:

In [None]:
# Only folders hold glacier outputs; ignore stray files such as .DS_Store.
rgi_ids = [
    entry for entry in os.listdir(path_rgi_alps)
    if os.path.isdir(os.path.join(path_rgi_alps, entry))
]
year = 2000  # Example year, change as needed

# Outline area per RGI id (first outline if an id repeats), built once instead
# of filtering the whole GeoDataFrame for every glacier.
area_by_rgi = glacier_outline_rgi.drop_duplicates('RGIId').set_index(
    'RGIId')['Area']

incomplete_rgis = {}
total_area_rgi, area_incomplete_rgi = 0, 0
for rgi_gl in tqdm(rgi_ids):
    file_path = os.path.join(path_rgi_alps, rgi_gl,
                             f"{rgi_gl}_grid_{year}.parquet")
    if not os.path.exists(file_path):
        print(f"Skipping {rgi_gl}: no grid file for {year}")
        continue
    df = pd.read_parquet(file_path)

    glacier_area = area_by_rgi.loc[rgi_gl]
    total_area_rgi += glacier_area

    # check if all vois_topographical in df.columns
    missing_vois = [voi for voi in vois_topographical if voi not in df.columns]
    if missing_vois:
        incomplete_rgis[rgi_gl] = missing_vois
        area_incomplete_rgi += glacier_area

df_pos_all['incomplete_topo'] = df_pos_all['rgi_id'].apply(
    lambda x: x in incomplete_rgis)

print('Number of incomplete RGI glaciers:', len(incomplete_rgis))

In [None]:
print(f"Total area of glaciers with missing topo: {area_incomplete_rgi:.2f} km²")
print(f"Total area of all glaciers in RGI region 11.6: {total_area_rgi:.2f} km²")
# Share of the total glacier AREA (not of the glacier count) with missing topo
perc_incomplete = area_incomplete_rgi / total_area_rgi * 100
print(f"Percentage of glaciers with missing topo: {perc_incomplete:.2f}% "
      f"({len(incomplete_rgis)} glaciers)")

In [None]:
# Step 1: Flatten the per-glacier lists of missing variables into one list
all_vars = [var for sublist in incomplete_rgis.values() for var in sublist]

# Step 2: Count how often each variable is missing
var_counts = Counter(all_vars)

# Step 3: Convert to a DataFrame sorted by count for plotting
var_counts_df = pd.DataFrame.from_dict(var_counts,
                                       orient='index',
                                       columns=['count'
                                                ]).sort_values(by='count',
                                                               ascending=False)

# Step 4: Plot. DataFrame.plot opens its own figure when no `ax` is given, so
# the size is passed here; a separate plt.figure(...) beforehand would only
# leave an empty extra figure.
if var_counts_df.empty:
    print('No topo variables are missing in the RGI glaciers.')
else:
    var_counts_df.plot(kind='bar', legend=False, figsize=(5, 5))
    plt.title('Frequency of topo variables missing in RGI glaciers')
    plt.ylabel('Count')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

In [None]:
# Re-read the RGI outlines (same file as loaded in the setup cell); the
# outline overlay below is commented out, but the next cell uses the areas.
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)

# Same Alps base map as the overview, with glacier positions coloured by
# whether any topographical input is missing.
fig = plt.figure(figsize=(18, 10))

latN, latS = 48, 44
lonW, lonE = 4, 14
projPC = ccrs.PlateCarree()
ax2 = plt.axes(projection=projPC)
ax2.set_extent([lonW, lonE, latS, latN], crs=ccrs.Geodetic())

ax2.add_feature(cfeature.COASTLINE)
ax2.add_feature(cfeature.LAKES)
ax2.add_feature(cfeature.RIVERS)
ax2.add_feature(cfeature.BORDERS, linestyle='-', linewidth=1)

g = sns.scatterplot(
    data=df_pos_all,
    x='lon',
    y='lat',
    alpha=0.6,
    hue='incomplete_topo',
    transform=projPC,
    ax=ax2,
    zorder=10,
    legend=True  # legend shows the incomplete_topo (True/False) classes
)

# glacier_outline_rgi.plot(ax=ax2, transform=projPC, color='black')

gl = ax2.gridlines(draw_labels=True,
                   linewidth=1,
                   color='gray',
                   alpha=0.5,
                   linestyle='--')
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = {'size': 16, 'color': 'black'}
gl.ylabel_style = {'size': 16, 'color': 'black'}
gl.top_labels = gl.right_labels = False

In [None]:
# Outline area of every glacier flagged as missing topographical inputs
area_gl = []
for rgi_gl in incomplete_rgis:
    area_gl.append(glacier_outline_rgi.loc[glacier_outline_rgi.RGIId == rgi_gl,
                                           'Area'].values[0])

# Histogram of those areas (30 bins)
plt.figure(figsize=(10, 5))
plt.hist(area_gl, bins=30, edgecolor='black')
plt.xlabel('Glacier Area (km²)')
plt.ylabel('Frequency')
plt.title('Distribution of glacier areas with missing topo')
plt.grid(True)
plt.tight_layout()
plt.show()

### Extrapolate (only on glaciers with complete topographical data):

In [None]:
# Glaciers whose grids contain every topographical input
full_rgis = list(filter(lambda rgi: rgi not in incomplete_rgis, rgi_ids))
print('Number of glaciers with complete topo data:', len(full_rgis))

In [None]:
# Define paths
path_save_glw = cfg.dataPath + '/GLAMOS/distributed_MB_grids/MBM/central_europe/'
path_xr_grids = os.path.join(cfg.dataPath, 'GLAMOS/topo/RGI_v6_11/',
                             'xr_masked_grids/')
RUN = False
# RESUME = False (default): clear the output folder and predict every glacier.
# RESUME = True: keep an existing glacier_mean_MB.csv and only predict the
# glaciers not in it yet.
# NOTE(review): when resuming, a glacier interrupted half-way (some years
# written) counts as done.
RESUME = False
if RUN:
    output_file = os.path.join(path_save_glw, "glacier_mean_MB.csv")

    if RESUME and os.path.exists(output_file):
        output_df = pd.read_csv(output_file)
        # Continue the running index after the rows already written
        index_counter = (int(output_df['Index'].max()) +
                         1 if len(output_df) else 0)
    else:
        emptyfolder(path_save_glw)
        os.makedirs(path_save_glw, exist_ok=True)
        # Start with header
        with open(output_file, 'w') as f:
            f.write("Index,RGIId,Year,Mean_MB\n")
        output_df = pd.read_csv(output_file)
        index_counter = 0

    # Glaciers already in the CSV, collected once (set lookup)
    done_rgis = set(output_df.RGIId.unique())
    missing_rgis = [rgi_gl for rgi_gl in full_rgis if rgi_gl not in done_rgis]

    for rgi_gl in tqdm(missing_rgis):
        glacier_path = os.path.join(path_rgi_alps, rgi_gl)

        if not os.path.exists(glacier_path):
            print(f"Folder not found for {rgi_gl}, skipping...")
            continue

        # Yearly files are named "<RGIId>_grid_<year>.parquet"; other files
        # would break the year parsing below.
        glacier_files = sorted(
            f for f in os.listdir(glacier_path)
            if rgi_gl in f and f.endswith('.parquet'))
        years = [
            int(file_name.split('_')[2].split('.')[0])
            for file_name in glacier_files
        ]

        for year in years:
            file_path = os.path.join(glacier_path,
                                     f"{rgi_gl}_grid_{year}.parquet")

            try:
                # Load parquet file
                df_grid_monthly = pd.read_parquet(file_path)
                df_grid_monthly.drop_duplicates(inplace=True)

                # Keep the model columns, then drop rows with missing inputs
                df_grid_monthly = df_grid_monthly[[
                    col for col in all_columns
                    if col in df_grid_monthly.columns
                ]]
                df_grid_monthly = df_grid_monthly.dropna()

                # Predict annual MB
                pred_annual, df_pred_months_annual = loaded_model.glacier_wide_pred(
                    df_grid_monthly[all_columns], type_pred='annual')
                pred_y_annual = pred_annual.drop(columns=['YEAR'],
                                                 errors='ignore')

                # Glacier-wide mean MB: mean of the `pred` column
                # (presumably one value per grid point)
                mean_MB = pred_y_annual.pred.mean()

                # Write row to CSV
                with open(output_file, 'a') as f:
                    f.write(f"{index_counter},{rgi_gl},{year},{mean_MB:.4f}\n")

                index_counter += 1

            except Exception as e:
                print(f"Error processing {rgi_gl} {year}: {e}")

In [None]:
# Load the per-glacier, per-year mean MB table; `output_file` now holds the
# DataFrame (the next cell reads it under that name).
output_file = pd.read_csv(os.path.join(path_save_glw, "glacier_mean_MB.csv"))
output_file

In [None]:
# Complete glaciers without any prediction yet. The processed ids are
# collected once instead of recomputing unique() for every glacier.
processed_rgi_ids = set(output_file.RGIId.unique())
missing_rgis = [
    rgi_gl for rgi_gl in full_rgis if rgi_gl not in processed_rgi_ids
]
missing_rgis

In [None]:
# Total number of rows expected: one per complete glacier and year 2000-2023
len(full_rgis) * (2024 - 2000)

### Mean predicted MB:

In [None]:
# Open output file: one row per glacier and year with its glacier-wide mean MB
output_df = pd.read_csv(os.path.join(path_save_glw, "glacier_mean_MB.csv"))

# Per-year statistics across glaciers. 'mean' is the unweighted average of the
# glacier-wide means; 'sum' adds the glacier means (not itself a mass balance)
# and 'count' is the number of glaciers predicted that year.
# NOTE(review): an area-weighted mean would give the regional mass balance.
output_df.groupby('Year')['Mean_MB'].agg(['mean', 'sum', 'count'])