## Setting Up:

In [None]:
import pandas as pd
import os
import warnings
from tqdm.notebook import tqdm
import re
from calendar import month_abbr
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
from pandas.tseries.offsets import MonthEnd
import hashlib

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.xgb_helpers import *
from scripts.geodata import *

# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation and dtype warnings from
# pandas/xarray — consider narrowing the filter.
warnings.filterwarnings('ignore')
# IPython autoreload in mode 2, so edits to imported modules (e.g. the
# `scripts` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Project configuration shared by all cells below.
# NOTE(review): `metaData` presumably lists columns carried along as
# identifiers rather than model features, and `notMetaDataNotFeatures`
# the target column — confirm against mbm.SwitzerlandConfig.
cfg = mbm.SwitzerlandConfig(
    metaData=["RGIId", "POINT_ID", "ID", "GLWD_ID", "N_MONTHS", "MONTHS", "PERIOD", "GLACIER", "YEAR", "POINT_LAT", "POINT_LON"],
    notMetaDataNotFeatures=["POINT_BALANCE"],
)

In [None]:
# Seed with the configured seed before anything else runs in this cell.
# NOTE(review): seed_all / free_up_cuda come from the star-imported
# project helpers; presumably they seed the RNGs and release GPU memory.
seed_all(cfg.seed)
free_up_cuda()

# Plot styles:
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
# Ten colours derived from the batlow colormap; the first one is reused
# below as the "dark blue" series colour.
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue = colors[0]
color_pink = '#c51b7d'

# Read glacier ids:
glacier_ids = get_glacier_ids(cfg)

# Names of the climate variables used as model features.
# NOTE(review): these look like ERA5 short names — confirm.
vois_climate = [
    't2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str', 'u10', 'v10'
]

# Names of the topographical variables used as model features; the
# trailing comment on each entry records its data source. The OGGM
# aspect/slope entries are disabled in favour of the SGI ones.
vois_topographical = [
    # "aspect", # OGGM
    # "slope", # OGGM
    "aspect_sgi",  # SGI
    "slope_sgi",  # SGI
    "hugonnet_dhdt",  # OGGM
    "consensus_ice_thickness",  # OGGM
    "millan_v",  # OGGM
]

## Read GL data:

In [None]:
# Load the full point mass balance dataset (winter + annual stakes).
glamos_csv = cfg.dataPath + path_PMB_GLAMOS_csv + 'CH_wgms_dataset_all.csv'
data_glamos = pd.read_csv(glamos_csv)

glacier_names = data_glamos['GLACIER'].unique()
print('Number of glaciers:', len(glacier_names))
print('Number of winter and annual samples:', len(data_glamos))
for period in ('annual', 'winter'):
    print(f'Number of {period} samples:',
          len(data_glamos[data_glamos.PERIOD == period]))

# Capitalize glacier names:
# The two Clariden stakes need an underscore that str.capitalize()
# cannot produce, so they are looked up explicitly.
special_names = {'claridenu': 'Clariden_U', 'claridenl': 'Clariden_L'}
glacierCap = {}
for gl in glacier_names:
    if not isinstance(gl, str):  # e.g. NaN in the GLACIER column
        print(f"Warning: Non-string glacier name encountered: {gl}")
        continue
    glacierCap[gl] = special_names.get(gl.lower(), gl.capitalize())

data_glamos.head(2)

### Glaciers with potential radiation data:

In [None]:
# Glaciers with data of potential clear sky radiation
# Format to same names as stakes:
# Only directory entries named like 'xr_direct_<glacier>.zarr' are used.
# Anything else in the folder (hidden files, other stores) is skipped
# rather than crashing on `None.group(1)` from a failed regex match.
pcsr_pattern = re.compile(r'xr_direct_(.*?)\.zarr')
pcsr_matches = (pcsr_pattern.search(f)
                for f in os.listdir(cfg.dataPath + path_pcsr + 'zarr/'))
glDirect = np.sort([m.group(1) for m in pcsr_matches if m])

# NOTE(review): `Diff` is a project helper; presumably it returns the
# names that are not shared by the two lists — confirm its direction.
restgl = np.sort(Diff(list(glDirect), list(data_glamos.GLACIER.unique())))

print('Glaciers with potential clear sky radiation data:\n', glDirect)
print('Number of glaciers:', len(glDirect))
print('Glaciers without potential clear sky radiation data:\n', restgl)

# Filter out glaciers without data:
data_glamos = data_glamos[data_glamos.GLACIER.isin(glDirect)]

# Look at the data of the ERA5 dataset:
# (last expression of the cell, so the notebook displays the dataset)
xr.open_dataset(cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc')

## Geodetic MB:

In [None]:
def check_geodetic_grids_present(folder_path, glacier_name,
                                 periods_per_glacier):
    """
    Check that a yearly grid file exists for every year covered by a
    glacier's geodetic periods.

    The expected years run from the earliest period start to the latest
    period end (both inclusive) among periods_per_glacier[glacier_name].
    A year counts as available when folder_path contains a file named
    exactly "{glacier_name}_grid_{YYYY}.parquet".

    Parameters:
    - folder_path: Path to the folder containing the grid files.
    - glacier_name: Name of the glacier; used both as the file-name prefix
      and as the key into periods_per_glacier.
    - periods_per_glacier: Mapping from glacier name to a collection of
      (start_year, end_year) pairs.

    Returns:
    - missing_years: Sorted list of expected years without a grid file.
    - all_years_present: True when missing_years is empty.

    Raises:
    - ValueError: if the glacier has no periods (min/max of an empty
      collection).
    - OSError: if folder_path cannot be listed (e.g. FileNotFoundError).
    """
    periods = list(periods_per_glacier[glacier_name])
    min_start = min(p[0] for p in periods)
    max_end = max(p[1] for p in periods)

    # The glacier name is escaped so that regex metacharacters in it are
    # matched literally; fullmatch anchors both ends of the file name, which
    # replaces the separate startswith/endswith pre-filter.
    year_pattern = re.compile(
        rf"{re.escape(glacier_name)}_grid_(\d{{4}})\.parquet")
    matches = (year_pattern.fullmatch(f) for f in os.listdir(folder_path))
    available_years = {int(m.group(1)) for m in matches if m}

    expected_years = set(range(min_start, max_end + 1))
    missing_years = sorted(expected_years - available_years)

    return missing_years, not missing_years

### Pre-process geodetic MB:

In [None]:
geodetic_mb = get_geodetic_MB(cfg)

# Keep only glaciers that have potential clear sky radiation data
has_pcsr = geodetic_mb.glacier_name.isin(glDirect)
geodetic_mb = geodetic_mb[has_pcsr]

# Unique start ('Astart') and end ('Aend') years per glacier, as plain
# lists keyed by glacier name
by_glacier = geodetic_mb.groupby('glacier_name')
years_start_per_gl = by_glacier['Astart'].unique().apply(list).to_dict()
years_end_per_gl = by_glacier['Aend'].unique().apply(list).to_dict()

# Only the first element of the returned pair is used here
periods_per_glacier, _ = build_periods_per_glacier(geodetic_mb)

### Get glacier list:

In [None]:
# All glaciers that have at least one geodetic period entry
glacier_list = list(periods_per_glacier.keys())

# Glacier areas, used below as the sort key
gl_area = get_gl_area(cfg)
# 'clariden' reuses the area stored under 'claridenL'
gl_area['clariden'] = gl_area['claridenL']

# Sort the lists by area if available in gl_area
def sort_by_area(glacier_list, gl_area):
    """Return the glacier names ordered from smallest to largest area.

    Names without an entry in gl_area are treated as having area 0 and
    therefore come first. The sort is stable and the input list is not
    modified.
    """

    def area_of(name):
        return gl_area.get(name, 0)

    ordered = list(glacier_list)
    ordered.sort(key=area_of)
    return ordered

# Order the glaciers by area (ascending) and report the result
glacier_list = sort_by_area(glacier_list, gl_area)
print(f'Number of glaciers: {len(glacier_list)}')
print(f'Glaciers: {glacier_list}')

### Missing data:

In [None]:
# Folder with the masked xarray grids; its entries are named
# '<glacier>_<year>.zarr' as checked further down.
path_xr_masked_grids = cfg.dataPath + path_GLAMOS_topo + '/xr_masked_grids/'
# Listed once up front so the per-year membership tests stay cheap.
existing_files = set(os.listdir(path_xr_masked_grids))

for glacier_name in glacier_list:
    print(f'{glacier_name.capitalize()}:')

    gl_periods = periods_per_glacier[glacier_name]
    min_start = min(p[0] for p in gl_periods)
    max_end = max(p[1] for p in gl_periods)

    print(f'Longest geodetic period: {min_start} - {max_end}')
    print(f'Geodetic periods: {gl_periods}')

    # Geodetic MB: one parquet grid per year of the longest period
    grid_folder = os.path.join(cfg.dataPath, path_glacier_grid_glamos,
                               glacier_name)
    missing_years, all_years_present = check_geodetic_grids_present(
        grid_folder, glacier_name, periods_per_glacier)
    if not all_years_present:
        print(f'Missing DEMS geodetic MB: {missing_years}')

    # Gridded MB: only checked when glacier-wide data is available
    print('...')
    GLAMOS_glwmb = get_GLAMOS_glwmb(glacier_name, cfg)
    if GLAMOS_glwmb is not None:
        start = max(GLAMOS_glwmb.index.min(), 1951)
        end = GLAMOS_glwmb.index.max()

        print(f'Gridded MB period: {start} - {end}')

        # Every year from 1951 on needs a matching zarr store
        wanted = ((year, f'{glacier_name}_{year}.zarr')
                  for year in range(start, end + 1) if year >= 1951)
        missing_years = [
            year for year, store in wanted if store not in existing_files
        ]

        if missing_years:
            print(f'Missing DEMS gridded MB: {missing_years}')

    print('-------------------------------')

In [None]:
# SGI-ID values of the geodetic rows belonging to Corvatsch
geodetic_mb.loc[geodetic_mb.glacier_name == 'corvatsch', 'SGI-ID']

# One glacier example: Gries

## Stake data:
### Input dataset:

In [None]:
glacier_name = 'gries'

# Stake measurements of the selected glacier only
data_gl = data_glamos[data_glamos.GLACIER == glacier_name]

min_start = min(p[0] for p in periods_per_glacier[glacier_name])
max_end = max(p[1] for p in periods_per_glacier[glacier_name])

print(f'Longest geodetic period: {min_start} - {max_end}')
print(f'Geodetic periods: {periods_per_glacier[glacier_name]}')

# Geodetic MB:
missing_years, all_years_present = check_geodetic_grids_present(
    os.path.join(cfg.dataPath, path_glacier_grid_glamos, glacier_name), glacier_name,
    periods_per_glacier)
if not all_years_present:
    print(f'Missing DEMS geodetic MB: {missing_years}')

# Gridded MB:
print('...')
GLAMOS_glwmb = get_GLAMOS_glwmb(glacier_name, cfg)
if GLAMOS_glwmb is None:
    # get_GLAMOS_glwmb can return None (the all-glacier overview loop guards
    # against it too); report it instead of failing on `.index` below.
    print(f'No gridded MB data found for {glacier_name}')
else:
    # Grids are only expected from 1951 onwards
    start = max(GLAMOS_glwmb.index.min(), 1951)
    end = GLAMOS_glwmb.index.max()

    print(f'Gridded MB period: {start} - {end}')

    # Check that each year in the range has an xr_masked_grids
    # (start is already >= 1951, so no extra year filter is needed)
    missing_years = [
        year for year in range(start, end + 1)
        if f'{glacier_name}_{year}.zarr' not in existing_files
    ]

    if missing_years:
        print(f'Missing DEMS gridded MB: {missing_years}')

In [None]:
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# Transform data to monthly format (run or load data):
# Input/output locations handed to process_or_load_data.
paths = {
    'csv_path': cfg.dataPath + path_PMB_GLAMOS_csv,
    'era5_climate_data': cfg.dataPath + path_ERA5_raw + 'era5_monthly_averaged_data.nc',
    'geopotential_data': cfg.dataPath + path_ERA5_raw + 'era5_geopotential_pressure.nc',
    'radiation_save_path': cfg.dataPath + path_pcsr + 'zarr/'
}
# NOTE(review): with RUN = False the helper presumably loads the
# previously written `output_file` instead of recomputing — confirm.
RUN = False
dataloader_gl = process_or_load_data(run_flag=RUN,
                                     data_glamos=data_gl,
                                     paths=paths,
                                     cfg=cfg,
                                     vois_climate=vois_climate,
                                     vois_topographical=vois_topographical,
                                     output_file='CH_wgms_dataset_gries.csv')

# Same object as dataloader_gl.data, so the column added below is written
# into the loader's frame as well.
data_monthly = dataloader_gl.data

# One identifier per (glacier, year) pair, built by hashing
# "<GLACIER>_<YEAR>" row by row and stored as a string.
data_monthly['GLWD_ID'] = data_monthly.apply(
    lambda x: mbm.data_processing.utils.get_hash(f"{x.GLACIER}_{x.YEAR}"), axis=1)
data_monthly['GLWD_ID'] = data_monthly['GLWD_ID'].astype(str)

# NOTE(review): presumably collapses the monthly rows to one row per
# season — confirm against transform_df_to_seasonal.
data_seas = transform_df_to_seasonal(data_monthly)
print('Number of seasonal rows', len(data_seas))

# Rebind dataloader_gl to a loader over the seasonal frame; every later
# cell therefore works on seasonal, not monthly, data.
dataloader_gl = mbm.dataloader.DataLoader(cfg,
                               data=data_seas,
                               random_seed=cfg.seed,
                               meta_data_columns=cfg.metaData)

data_seas.head(2)

### Blocking on stakes:

In [None]:
# Split on measurements (IDs), holding out 10 % as test set
splits, test_set, train_set = get_CV_splits(
    dataloader_gl,
    test_split_on='ID',
    random_state=cfg.seed,
    test_size=0.1,
)

data_train = train_set['df_X']
data_test = test_set['df_X']

# Train and test set must not share any measurement ID
assert set(data_train.ID).isdisjoint(set(data_test.ID))

# Annual and winter subsets of both sets
data_train_annual = data_train[data_train.PERIOD == 'annual']
data_train_winter = data_train[data_train.PERIOD == 'winter']
data_test_annual = data_test[data_test.PERIOD == 'annual']
data_test_winter = data_test[data_test.PERIOD == 'winter']

# Number of annual versus winter measurements, train first, then test
for header, frame, annual, winter in (
    ('Train:', data_train, data_train_annual, data_train_winter),
    ('Test:', data_test, data_test_annual, data_test_winter),
):
    print(header)
    print('Number of winter and annual samples:', len(frame))
    print('Number of annual samples:', len(annual))
    print('Number of winter samples:', len(winter))

# Row counts of the full dataset held by the loader
data_all = dataloader_gl.data
print('Total:')
print('Number of rows:', len(data_all))
print('Number of annual rows:', len(data_all[data_all.PERIOD == 'annual']))
print('Number of winter rows:', len(data_all[data_all.PERIOD == 'winter']))

visualiseSplits(test_set['y'], train_set['y'], splits)
visualiseInputs(train_set, test_set, vois_climate)

In [None]:
# Number of measurements per year:
fig, ax = plt.subplots(2, 1, figsize=(15, 10))
data_test.groupby(['YEAR', 'PERIOD']).size().unstack().plot(
    kind='bar', stacked=True, color=[color_dark_blue, color_pink], ax=ax[0])
ax[0].set_title('Number of measurements per year for test set')

# Number of measurements per year:
data_train.groupby(['YEAR', 'PERIOD']).size().unstack().plot(
    kind='bar', stacked=True, color=[color_dark_blue, color_pink], ax=ax[1])
ax[1].set_title('Number of measurements per year for train set')
plt.tight_layout()

## XGBoost:

In [None]:
# Grid search
# For each of the XGBoost parameter, define the grid range
param_grid = {
    'max_depth': [2, 3, 4, 5, 6, 7, 8],
    'n_estimators':
    [50, 100, 200, 300, 400, 500, 600,
     700],  # number of trees (too many = overfitting, too few = underfitting)
    'learning_rate': [0.01, 0.1, 0.15, 0.2, 0.25, 0.3]
}

param_init = {}
param_init['device'] = 'cuda:0'
param_init['tree_method'] = 'hist'
param_init["random_state"] = cfg.seed
param_init["n_jobs"] = cfg.numJobs

### Predictions of custom parameters:

In [None]:
# Hand-picked hyper-parameters used instead of a grid search.
# Note that n_estimators=800 lies outside the n_estimators values listed
# in param_grid (max. 700).
custom_params = {'learning_rate': 0.01, 'max_depth': 6, 'n_estimators': 800}

# Feature columns:
feature_columns = [
    'ELEVATION_DIFFERENCE'
] + list(vois_climate) + list(vois_topographical) + ['pcsr']
# Features plus the non-feature fields the model wrapper expects alongside
# them (taken from cfg.fieldsNotFeatures).
all_columns = feature_columns + cfg.fieldsNotFeatures
df_X_train_subset = train_set['df_X'][all_columns]
print('Shape of training dataset:', df_X_train_subset.shape)
print('Shape of testing dataset:', test_set['df_X'][all_columns].shape)
print('Running with features:', feature_columns)

# custom_params come last, so they win over param_init on duplicate keys.
params = {**param_init, **custom_params}
print(params)
custom_model = mbm.models.CustomXGBoostRegressor(cfg, **params)

# Fit on train data:
custom_model.fit(train_set['df_X'][all_columns], train_set['y'])

# Make predictions on test
# Training used the device from param_init ('cuda:0'); prediction is
# switched to the CPU here.
custom_model = custom_model.set_params(device='cpu')
# NOTE(review): relies on the private helper _create_features_metadata of
# the model wrapper — it may change without notice.
features_test, metadata_test = custom_model._create_features_metadata(
    test_set['df_X'][all_columns])
y_pred = custom_model.predict(features_test)
print('Shape of the test:', features_test.shape)

# Make predictions aggr to meas ID:
y_pred_agg = custom_model.aggrPredict(metadata_test, features_test)

# Calculate scores
score = custom_model.score(test_set['df_X'][all_columns],
                           test_set['y'])  # negative
# The score is reported as a magnitude because the model returns it negated.
print('Overall score:', np.abs(score))

grouped_ids = getDfAggregatePred(test_set, y_pred_agg, all_columns)
PlotPredictions(grouped_ids, y_pred, metadata_test, test_set, custom_model)
plt.suptitle(f'MBM tested on test stakes', fontsize=20)
plt.tight_layout()

In [None]:
# Plot for the fitted model over the feature columns.
# NOTE(review): judging by its name, FIPlot draws feature importances;
# the role of the vois_climate argument is not visible here — confirm.
FIPlot(custom_model, feature_columns, vois_climate)

In [None]:
# Quick look at the first rows of the frame held by the data loader.
dataloader_gl.data.head(3)

## Predictions on geodetic MB:

#### Create input array:

In [None]:
glacier_name = 'gries'

# Years (inclusive) of the first geodetic period listed for this glacier
geodetic_period = periods_per_glacier[glacier_name][0]
geodetic_range = range(geodetic_period[0], geodetic_period[1] + 1)
folder_path = os.path.join(cfg.dataPath, path_glacier_grid_glamos, glacier_name)

# check that parquet files for each year
files = [
    f for f in os.listdir(folder_path)
    if f.startswith(f"{glacier_name}_grid_") and f.endswith(".parquet")
]

# Extract available years from filenames
# (glacier name escaped so regex metacharacters in it match literally;
# each file name is matched only once)
year_pattern = re.compile(
    rf"{re.escape(glacier_name)}_grid_(\d{{4}})\.parquet")
year_matches = (year_pattern.search(f) for f in files)
available_years = {int(m.group(1)) for m in year_matches if m}

# check that period overlaps with available years:
# every year of the geodetic period needs its own grid file
missing_geodetic_years = sorted(set(geodetic_range) - available_years)
assert not missing_geodetic_years, (
    f'No grid file for years: {missing_geodetic_years}')

# Create geodetic input array for MBM for one glacier:
df_X_geod = create_geodetic_input(cfg,
                                  glacier_name,
                                  periods_per_glacier,
                                  to_seasonal=True)

# Check that each ID has two seasons only.
# Compared element-wise and reduced with .all(): asserting on the array
# returned by `.unique() == 2` raises ValueError (ambiguous truth value)
# instead of AssertionError as soon as the counts differ between IDs.
seasons_per_id = df_X_geod.groupby('ID').count().SEASON
assert not seasons_per_id.empty and (seasons_per_id == 2).all(), (
    'Every ID must have exactly two seasonal rows')

print('Shape of the geodetic input array:', df_X_geod.shape)

df_X_geod.head(3)

#### Make predictions:

In [None]:
# Generate annual predictions
pred_annual = custom_model.glacier_wide_pred(df_X_geod[all_columns])

# Calculate mean SMB per year and store in a DataFrame
mean_SMB = pred_annual.groupby('GLWD_ID').agg({
    'pred':
    'mean',
    'YEAR':
    'first',
})
mean_SMB = mean_SMB.sort_values(by='YEAR').reset_index().set_index('YEAR')
mean_SMB

In [None]:
# Calculate the geodetic mb for each range (manually)
geodetic_MB_pred, geodetic_MB_target = [], []
for geodetic_period in periods_per_glacier[glacier_name]:
    geodetic_range = range(geodetic_period[0], geodetic_period[1] + 1)
    geodetic_MB_pred.append(mean_SMB.loc[geodetic_range].pred.mean())

y_target = prepareGeoTargets(geodetic_mb, periods_per_glacier, glacier_name)
score = -(
    (np.array(geodetic_MB_pred) - np.array(y_target))**2).mean()

# Calculate the geodetic mb for each range (with implemented function)
score_2 = custom_model.score_geod(df_X_geod[all_columns],
                                  y_target,
                                  periods=periods_per_glacier[glacier_name])

# Test: should be the same
print('Score:', score)
print('Score 2:', score_2)