## Setting Up:

In [None]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), '../../')) # Add root of repo to import MBM

import pandas as pd
import warnings
from tqdm.notebook import tqdm
import re
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import torch.nn as nn
from skorch.helper import SliceDataset
from datetime import datetime
from skorch.callbacks import EarlyStopping, LRScheduler, Checkpoint
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.nn_helpers import *
from scripts.xgb_helpers import *

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: enable the autoreload extension in mode 2 so that edited
# modules are re-imported before each cell execution (no kernel restart needed).
%load_ext autoreload
%autoreload 2

# Project configuration object; later cells read its data path, seed and
# metadata/feature column lists.
cfg = mbm.SwitzerlandConfig()
# Bare expression on the last line of the cell: the notebook displays its repr.
cfg

In [None]:
# Matplotlib style sheet and colours shared by the figures of this notebook.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# Table of RGI ids: strip stray whitespace from the header names, then sort
# and index the rows by the glacier short name.
rgi_df = (
    pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
    .rename(columns=str.strip)
    .sort_values(by='short_name')
    .set_index('short_name')
)

# Climate variables of interest.
vois_climate = [
    't2m',
    'tp',
    'slhf',
    'sshf',
    'ssrd',
    'fal',
    'str',
    'u10',
    'v10',
]

# Topographical variables of interest.
vois_topographical = [
    "aspect_sgi", "slope_sgi", "hugonnet_dhdt", "consensus_ice_thickness",
    "millan_v"
]

In [None]:
# Seed with the value stored in the configuration.
seed_all(cfg.seed)

# Report whether a GPU is visible; when it is, also call the project's
# free_up_cuda() helper.
if not torch.cuda.is_available():
    print("CUDA is NOT available")
else:
    print("CUDA is available")
    free_up_cuda()

    # Optional knobs to limit the CPU usage of the random search:
    # torch.set_num_threads(2)  # or 1
    # os.environ["OMP_NUM_THREADS"] = "1"
    # os.environ["MKL_NUM_THREADS"] = "1"


## Read GL data:

In [None]:
# Point mass balance measurements (winter and annual) for all glaciers.
data_glamos = pd.read_csv(cfg.dataPath + path_PMB_GLAMOS_csv + 'CH_wgms_dataset_all.csv')

# Glaciers with data of potential clear sky radiation.
# The glacier name is extracted from the zarr store name so that it has the
# same format as the stake names. Directory entries that do not follow the
# 'xr_direct_<glacier>.zarr' pattern (hidden files, temporary folders, ...)
# are skipped instead of crashing on a failed match.
zarr_name_pattern = re.compile(r'xr_direct_(.*?)\.zarr')
glDirect = np.sort([
    match.group(1)
    for match in map(zarr_name_pattern.search,
                     os.listdir(cfg.dataPath + path_pcsr + 'zarr/'))
    if match is not None
])

# Difference between the two glacier lists, as computed by the project's
# Diff helper.
restgl = np.sort(Diff(list(glDirect), list(data_glamos.GLACIER.unique())))

print('Glaciers with potential clear sky radiation data:\n', glDirect)
print('Number of glaciers:', len(glDirect))
print('Glaciers without potential clear sky radiation data:\n', restgl)

# Filter out glaciers without radiation data:
data_glamos = data_glamos[data_glamos.GLACIER.isin(glDirect)]

print('-------------------')
print('Number of glaciers:', len(data_glamos['GLACIER'].unique()))
print('Number of winter and annual samples:', len(data_glamos))
print('Number of annual samples:',
      len(data_glamos[data_glamos.PERIOD == 'annual']))
print('Number of winter samples:',
      len(data_glamos[data_glamos.PERIOD == 'winter']))

## Input data:
### Input dataset:

In [None]:
# Timestamped INFO-level logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Flag forwarded as run_flag: presumably True re-runs the monthly
# transformation and False loads a previously saved result -- confirm in
# process_or_load_data.
RUN = False

# Locations handed to process_or_load_data.
pmb_csv_dir = cfg.dataPath + path_PMB_GLAMOS_csv
era5_dir = cfg.dataPath + path_ERA5_raw
paths = dict(
    csv_path=pmb_csv_dir,
    era5_climate_data=era5_dir + 'era5_monthly_averaged_data.nc',
    geopotential_data=era5_dir + 'era5_geopotential_pressure.nc',
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Transform data to monthly format (run or load data):
dataloader_gl = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_NN.csv',
)
data_monthly = dataloader_gl.data

## Blocking on glaciers:

In [None]:
# Glaciers held out for testing; all remaining glaciers are used for training.
test_glaciers = [
    'tortin',
    'plattalva',
    'sanktanna',
    'schwarzberg',
    'hohlaub',
    'pizol',
    'corvatsch',
    'tsanfleuron',
    'forno',
]


def _period_count(frame, period):
    """Return the number of rows of `frame` whose PERIOD equals `period`."""
    return len(frame[frame.PERIOD == period])


monthly_df = dataloader_gl.data

# Warn (without failing) about requested test glaciers absent from the data.
existing_glaciers = set(monthly_df.GLACIER.unique())
missing_glaciers = [
    name for name in test_glaciers if name not in existing_glaciers
]
if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

train_glaciers = [
    name for name in existing_glaciers if name not in test_glaciers
]

data_test = monthly_df[monthly_df.GLACIER.isin(test_glaciers)]
print('Size of monthly test data:', len(data_test))

data_train = monthly_df[monthly_df.GLACIER.isin(train_glaciers)]
print('Size of monthly train data:', len(data_train))

# The ratio below is test rows over train rows (not over the total).
if len(data_train) > 0:
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
else:
    print("Warning: No training data available!")

# Monthly rows per measurement period, for the train set ...
print('-------------\nTrain:')
print('Number of monthly winter and annual samples:', len(data_train))
print('Number of monthly annual samples:', _period_count(data_train, 'annual'))
print('Number of monthly winter samples:', _period_count(data_train, 'winter'))

# ... for the test set ...
data_test_annual = data_test[data_test.PERIOD == 'annual']
data_test_winter = data_test[data_test.PERIOD == 'winter']

print('Test:')
print('Number of monthly winter and annual samples:', len(data_test))
print('Number of monthly annual samples:', len(data_test_annual))
print('Number of monthly winter samples:', len(data_test_winter))

# ... and for the whole monthly dataset.
print('Total:')
print('Number of monthly rows:', len(monthly_df))
print('Number of annual rows:', _period_count(monthly_df, 'annual'))
print('Number of winter rows:', _period_count(monthly_df, 'winter'))

# Same split applied to the original (non-monthly) measurements.
print('-------------\nIn annual format:')
in_train = data_glamos.GLACIER.isin(train_glaciers)
print('Number of annual train rows:', len(data_glamos[in_train]))
in_test = data_glamos.GLACIER.isin(test_glaciers)
print('Number of annual test rows:', len(data_glamos[in_test]))


In [None]:
# Build the cross-validation splits with the chosen glaciers held out as the
# test set.
splits, test_set, train_set = get_CV_splits(
    dataloader_gl,
    test_split_on='GLACIER',
    test_splits=test_glaciers,
    random_state=cfg.seed,
)

test_vals = test_set['splits_vals']
print(f'Test glaciers: ({len(test_vals)}) {test_vals}')
# Ratio of test rows over train rows (not over the total).
test_perc = (len(test_set['df_X']) / len(train_set['df_X'])) * 100
print(f'Percentage of test size: {test_perc:.2f}%')
print('Size of test set:', len(test_set['df_X']))

train_vals = train_set['splits_vals']
print(f'Train glaciers: ({len(train_vals)}) {train_vals}')
print('Size of train set:', len(train_set['df_X']))

In [None]:
# Validation and train split:
# `assign` returns a new DataFrame, so the target column is added without
# mutating train_set['df_X'] in place (no side effect on re-runs of earlier
# cells, no chained-assignment issue if df_X is a slice).
data_train = train_set['df_X'].assign(y=train_set['y'])
dataloader = mbm.dataloader.DataLoader(cfg, data=data_train)

# Hold out 20% of the training data for validation.
train_itr, val_itr = dataloader.set_train_test_split(test_size=0.2)

# Get all indices of the training and validation datasets at once from the
# iterators. Once consumed, the iterators are empty.
train_indices, val_indices = list(train_itr), list(val_itr)

# Training set (the indices are used positionally)
df_X_train = data_train.iloc[train_indices]
y_train = df_X_train['POINT_BALANCE'].values

# Validation set
df_X_val = data_train.iloc[val_indices]
y_val = df_X_val['POINT_BALANCE'].values

## Neural Network:

In [None]:
# NOTE: do not assemble the feature list by hand, e.g.
#     ['ELEVATION_DIFFERENCE'] + list(vois_climate) + list(vois_topographical) + ['pcsr']
# This does not give the features in the right order and results in an
# inappropriate normalization.

# Keep every column that is neither metadata, nor listed in
# cfg.notMetaDataNotFeatures, nor the target column 'y'.
candidate_columns = data_train.columns.difference(cfg.metaData)
candidate_columns = candidate_columns.drop(cfg.notMetaDataNotFeatures)
feature_columns = list(candidate_columns.drop('y'))

# Features plus the extra fields listed in cfg.fieldsNotFeatures.
all_columns = feature_columns + cfg.fieldsNotFeatures
df_X_train_subset = df_X_train[all_columns]
df_X_val_subset = df_X_val[all_columns]

print(f'Shape of training dataset: {df_X_train_subset.shape}')
print(f'Shape of validation dataset: {df_X_val_subset.shape}')
print(f"Shape of testing dataset: {test_set['df_X'][all_columns].shape}")
print(f'Running with features: {feature_columns}')

### Initialise network:

In [None]:
# Train on the first GPU when one is visible, otherwise fall back to the CPU
# so that the notebook also runs on machines without CUDA.
param_init = {}
param_init['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'

nInp = len(feature_columns)
cfg.setFeatures(feature_columns)

# Small fully connected regressor: nInp -> 12 -> 4 -> 1 with ReLU activations.
network = nn.Sequential(
    nn.Linear(nInp, 12),
    nn.ReLU(),
    nn.Linear(12, 4),
    nn.ReLU(),
    nn.Linear(4, 1),
)

# Stop when the validation loss has not improved for 10 epochs.
early_stop = EarlyStopping(
    monitor='valid_loss',
    patience=10,
    threshold=1e-4,  # Optional: stop only when improvement is very small
)

# `monitor` must be a boolean history entry. 'valid_loss_best' is True only on
# epochs where the validation loss improved; a raw loss value such as
# 'train_loss' is always truthy and would checkpoint every epoch.
checkpoint = Checkpoint(
    f_params='best_model.pt',  # Save model weights
    monitor='valid_loss_best',  # Checkpoint only when validation loss improves
    load_best=True  # Load best weights after training
)

# Halve the learning rate when the validation loss plateaus for 5 epochs.
lr_scheduler_cb = LRScheduler(
    policy=ReduceLROnPlateau,
    monitor='valid_loss',
    mode='min',
    factor=0.5,
    patience=5,
    threshold=0.01,
    threshold_mode='rel',
    verbose=True)

dataset = dataset_val = None # Initialized hereafter
def my_train_split(ds, y=None, **fit_params):
    """Return the pre-built (train, validation) datasets.

    The arguments are ignored: the module-level `dataset` and `dataset_val`
    are looked up at call time, so they must be assigned before `fit` runs.
    """
    return dataset, dataset_val

custom_nn = mbm.models.CustomNeuralNetRegressor(
    cfg,
    module=network,
    nbFeatures=nInp,
    train_split=my_train_split,
    batch_size=256,
    verbose=1,
    iterator_train__shuffle=True,
    lr=0.001,
    max_epochs=50,
    optimizer=torch.optim.Adam,
    callbacks=[
        # Registered before early stopping so that the checkpoint logic still
        # runs on the epoch where early stopping interrupts training.
        ('checkpoint', checkpoint),
        ('early_stop', early_stop),
        ('lr_scheduler', lr_scheduler_cb),
    ],
    **param_init)

### Create datasets:

In [None]:
def _bind_slices(aggregated):
    """Bind the idx=0 and idx=1 slices of `aggregated` into one object."""
    return mbm.data_processing.SliceDatasetBinding(
        SliceDataset(aggregated, idx=0), SliceDataset(aggregated, idx=1))


features, metadata = custom_nn._create_features_metadata(
    df_X_train_subset[all_columns])
features_val, metadata_val = custom_nn._create_features_metadata(
    df_X_val_subset[all_columns])

# Training dataset for the NN. `dataset` is the module-level name read by
# my_train_split during fit.
aggregated_train = mbm.data_processing.AggregatedDataset(
    cfg, features=features, metadata=metadata, targets=y_train)
# splits = aggregated_train.mapSplitsToDataset(splits)
dataset = _bind_slices(aggregated_train)
print("train:", dataset.X.shape, dataset.y.shape)

# Validation dataset, built the same way (also read by my_train_split).
aggregated_val = mbm.data_processing.AggregatedDataset(
    cfg, features=features_val, metadata=metadata_val, targets=y_val)
dataset_val = _bind_slices(aggregated_val)
print("validation:", dataset_val.X.shape, dataset_val.y.shape)

# First cross-validation fold, printed for inspection.
train_idx, val_idx = splits[0]
print("train_idx, val_idx =", train_idx, val_idx)

### Train custom model:

In [None]:
# True: fit and save a new model; False: load a previously saved one.
TRAIN = True

if not TRAIN:
    # Load a saved model and move it to the CPU.
    model_filename = "nn_model_2025-05-23.pkl"  # Replace with actual date if needed
    loaded_model = (
        mbm.models.CustomNeuralNetRegressor.load_model(model_filename)
        .set_params(device='cpu')
        .to('cpu')
    )
else:
    # The data passed to fit is not used: the train_split function given to
    # the regressor substitutes the pre-built train/validation datasets.
    custom_nn.fit(dataset.X, dataset.y)

    # Save the model under a name stamped with today's date (YYYY-MM-DD).
    current_date = datetime.now().date().isoformat()
    model_filename = "nn_model_" + current_date + ".pkl"
    custom_nn.save_model(model_filename)

    plot_training_history(custom_nn, skip_first_n=5)

### Load model and make predictions:

In [None]:
# Load model and set to CPU.
# Reuse the filename set by the training cell (it is defined there both when
# training and when loading), so that the model evaluated here is the one that
# was just trained. The dated default is only used when this cell is run
# without the training cell.
model_filename = globals().get(
    'model_filename',
    "nn_model_2025-05-23.pkl")  # Replace with actual date if needed
loaded_model = mbm.models.CustomNeuralNetRegressor.load_model(model_filename)
loaded_model = loaded_model.set_params(device='cpu')
loaded_model = loaded_model.to('cpu')

# Create features and metadata
df_test = test_set['df_X'][all_columns]
features_test, metadata_test = loaded_model._create_features_metadata(df_test)

# Ensure all tensors are on CPU if they are torch tensors
if hasattr(features_test, 'cpu'):
    features_test = features_test.cpu()
if hasattr(metadata_test, 'cpu'):
    metadata_test = metadata_test.cpu()

# Ensure targets are also on CPU
targets_test = test_set['y']
if hasattr(targets_test, 'cpu'):
    targets_test = targets_test.cpu()

# Create the dataset
dataset_test = mbm.data_processing.AggregatedDataset(
    cfg,
    features=features_test,
    metadata=metadata_test,
    targets=targets_test)

# Element 0 is the idx=0 slice (passed to predict), element 1 the idx=1 slice
# (the targets).
dataset_test = [
    SliceDataset(dataset_test, idx=0),
    SliceDataset(dataset_test, idx=1)
]

# Make predictions aggr to meas ID
y_pred = loaded_model.predict(dataset_test[0])
y_pred_agg = loaded_model.aggrPredict(dataset_test[0])

batchIndex = np.arange(len(y_pred_agg))
y_true = np.array([e for e in dataset_test[1][batchIndex]])

# Calculate scores
score = loaded_model.score(dataset_test[0], dataset_test[1])
mse, rmse, mae, pearson = loaded_model.evalMetrics(y_pred, y_true)

# Aggregate predictions: one row per measurement ID
# (named `measurement_ids` so the builtin `id` is not shadowed).
measurement_ids = dataset_test[0].dataset.indexToId(batchIndex)
data = {
    'target': [e[0] for e in dataset_test[1]],
    'ID': measurement_ids,
    'pred': y_pred_agg
}
grouped_ids = pd.DataFrame(data)

# Add period, glacier name and year: first value per measurement ID, fetched
# in a single group-by and merged on 'ID'.
meta_per_id = df_test.groupby('ID')[['PERIOD', 'GLACIER', 'YEAR']].first()
grouped_ids = grouped_ids.merge(meta_per_id, on='ID')

grouped_ids

In [None]:
# Plot the per-measurement predictions of the NN on the test glaciers
# (project helper, presumably from scripts.plots; figure layout not visible here).
PlotPredictions_NN(grouped_ids)

In [None]:
# Plot predictions against targets for all test measurements, passing the MAE
# and RMSE computed above (project helper; figure layout not visible here).
predVSTruth_all(grouped_ids, mae, rmse, title='NN on test')

In [None]:
# Predictions vs. truth for individual glaciers, from the predictions
# aggregated per measurement ID
# (project helper; presumably one panel per glacier -- TODO confirm).
PlotIndividualGlacierPredVsTruth(grouped_ids, figsize=(20, 15))

### Grid search:

In [None]:
# Hyperparameter search is off by default.
RUN_GRIDSEARCH = False
if RUN_GRIDSEARCH:
    # Exhaustive alternative:
    # custom_nn.gridsearch(parameters=parameters, splits=splits, dataset=dataset, num_jobs=-1)

    # Randomised search: n_iter is the number of parameter settings that are
    # sampled, a trade-off between goodness of the solution and runtime.
    param_grid = {
        'lr': [0.001, 0.01],
        'max_epochs': [1000, 2000],
    }
    custom_nn.randomsearch(parameters=param_grid,
                           n_iter=20,
                           splits=splits,
                           dataset=dataset)

    search = custom_nn.param_search
    best_params = params = search.best_params_
    best_estimator = search.best_estimator_
    print("Best parameters:\n", best_params)
    print("Best score:\n", search.best_score_)

    # Save the model
    custom_nn.save_model('gs_model.pkl')

    # Folder receiving the figure (created if missing)
    save_dir = "figures"
    os.makedirs(save_dir, exist_ok=True)

    # Mean train and test scores, one point per sampled hyperparameter set
    fig, ax = plt.subplots(figsize=(8, 5))
    for results_key, curve_label in (('mean_train_score', 'Train score'),
                                     ('mean_test_score', 'Test score')):
        ax.plot(search.cv_results_[results_key], label=curve_label)
    ax.set_xlabel('Hyperparameter Set Index')
    ax.set_ylabel('Score')
    ax.set_title('RandomizedSearchCV Results')
    ax.legend()
    ax.grid(True)

    # Write the figure to disk, then close it so the notebook does not also
    # display it
    fig.savefig(os.path.join(save_dir, "param_search_scores.png"),
                dpi=300,
                bbox_inches='tight')
    plt.close(fig)