## Setting Up:

In [None]:
import pandas as pd
import os
import warnings
from tqdm.notebook import tqdm
import re
from calendar import month_abbr
import matplotlib.pyplot as plt
import seaborn as sns
from cmcrameri import cm
import xarray as xr
import massbalancemachine as mbm
from collections import defaultdict
import logging
import cartopy.io.img_tiles as cimgt
from cartopy import crs as ccrs, feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.enums import Resampling as RResampling
import numpy as np
from skorch.callbacks import EarlyStopping, LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import joypy
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import geopandas as gpd
from matplotlib.patches import Wedge, Patch
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import pickle 

from scripts.helpers import *
from scripts.glamos_preprocess import *
from scripts.plots import *
from scripts.config_CH import *
from scripts.xgb_helpers import *
from scripts.geodata import *
from scripts.NN_networks import *
from scripts.nn_helpers import *
from scripts.geodata_plots import *

# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')
# IPython magics: load the autoreload extension and enable mode 2.
%load_ext autoreload
%autoreload 2

# Switzerland configuration object; its dataPath, seed and metaData
# attributes are read by the cells below.
cfg = mbm.SwitzerlandConfig()

In [None]:
# Fix the seed through the project helper and report which one is in use.
seed_all(cfg.seed)
print("Using seed:", cfg.seed)

from torch.utils.data import (
    DataLoader,
    Dataset,
    Subset,
    SubsetRandomSampler,
    WeightedRandomSampler,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn as nn

# Choose the compute device once, then report it. On CUDA machines the
# project's free_up_cuda helper is called as well.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available")
    free_up_cuda()
else:
    print("CUDA is NOT available")

In [None]:
seed_all(cfg.seed)
free_up_cuda()

# Plot styles: project style sheet plus a 10-step batlow palette.
path_style_sheet = 'scripts/example.mplstyle'
plt.style.use(path_style_sheet)
colors = get_cmap_hex(cm.batlow, 10)
color_dark_blue, color_pink = colors[0], '#c51b7d'

# RGI ids: one row per glacier, sorted and indexed by its short name.
rgi_df = (
    pd.read_csv(cfg.dataPath + path_glacier_ids, sep=',')
    .rename(columns=str.strip)  # header names may carry stray whitespace
    .sort_values(by='short_name')
    .set_index('short_name')
)

# Climate variables of interest.
vois_climate = ['t2m', 'tp', 'slhf', 'sshf', 'ssrd', 'fal', 'str']

# Topographical variables of interest (source given per entry).
vois_topographical = [
    "aspect_sgi",  # SGI
    "slope_sgi",  # SGI
    "hugonnet_dhdt",  # OGGM
    "consensus_ice_thickness",  # OGGM
    "millan_v",  # OGGM
]

## Read GL data:

In [None]:
data_glamos = pd.read_csv(cfg.dataPath + path_PMB_GLAMOS_csv +
                          'CH_wgms_dataset_all.csv')

# Display names: capitalised glacier name, except for the two Clariden
# branches which get an explicit suffix.
special_names = {'claridenu': 'Clariden_U', 'claridenl': 'Clariden_L'}
glacierCap = {}
for gl in data_glamos['GLACIER'].unique():
    if not isinstance(gl, str):
        # Non-string entries (e.g. missing names) are reported and skipped.
        print(f"Warning: Non-string glacier name encountered: {gl}")
        continue
    glacierCap[gl] = special_names.get(gl.lower(), gl.capitalize())

# Cut to glaciers with pcsr:
glacier_list = [
    'adler', 'albigna', 'aletsch', 'allalin', 'basodino', 'clariden',
    'corbassiere', 'corvatsch', 'findelen', 'forno', 'gietro', 'gorner',
    'gries', 'hohlaub', 'joeri', 'limmern', 'morteratsch', 'murtel',
    'oberaar', 'otemma', 'pizol', 'plattalva', 'rhone', 'sanktanna',
    'schwarzbach', 'schwarzberg', 'sexrouge', 'silvretta', 'tortin',
    'tsanfleuron',
]

keep_rows = data_glamos['GLACIER'].isin(glacier_list)
data_glamos = data_glamos[keep_rows]

In [None]:
# Glacier outlines: RGI outlines and the SGI 2016 shapefile.
glacier_outline_rgi = gpd.read_file(cfg.dataPath + path_rgi_outlines)
glacier_outline_sgi = gpd.read_file(
    os.path.join(cfg.dataPath, path_SGI_topo, 'inventory_sgi2016_r2020',
                 'SGI_2016_glaciers_copy.shp'))

# Number of measurements per glacier, largest first, indexed by GLACIER.
glacier_info = (data_glamos.groupby('GLACIER').size()
                .sort_values(ascending=False)
                .rename('Nb. measurements')
                .to_frame())

# Mean measurement location per glacier.
glacier_loc = data_glamos.groupby('GLACIER')[['POINT_LAT', 'POINT_LON']].mean()

# Measurement counts per glacier, one column per PERIOD value.
glacier_period = (data_glamos.groupby(['GLACIER', 'PERIOD']).size()
                  .unstack()
                  .fillna(0)
                  .astype(int))

# Assemble the summary table: location, total count, per-period counts.
glacier_info = (glacier_loc.merge(glacier_info, on='GLACIER')
                .merge(glacier_period, on='GLACIER'))

# Label each glacier by membership in TEST_GLACIERS.
glacier_info['Train/Test glacier'] = [
    'Test' if name in TEST_GLACIERS else 'Train'
    for name in glacier_info.index
]
glacier_info.head(2)

### Assign glaciers to river basin names:

In [None]:
# === Load RGI glacier IDs (sorted and indexed by short name) ===
rgi_df = pd.read_csv(cfg.dataPath + path_glacier_ids)
rgi_df = rgi_df.rename(columns=str.strip)
rgi_df = rgi_df.sort_values(by='short_name').set_index('short_name')

# === Load SGI region geometries ===
SGI_regions = gpd.read_file(
    os.path.join(cfg.dataPath, path_SGI_topo, 'inventory_sgi2016_r2020',
                 'sgi_regions.geojson'))

# Strip surrounding whitespace from every text (object-dtype) column.
for text_col in SGI_regions.select_dtypes(include='object').columns:
    SGI_regions[text_col] = SGI_regions[text_col].str.strip()

# Remove duplicate / incomplete rows, then index by region key.
SGI_regions = (SGI_regions.drop_duplicates()
               .dropna()
               .set_index('pk_sgi_region'))

# === Map to Level 0 river basins ===
# The first character of the SGI id selects the level-0 basin.
catchment_lv0 = dict(A='Rhine', B='Rhone', C='Po', D='Adige', E='Danube')
rgi_df['rvr_lv0'] = rgi_df['sgi-id'].str.get(0).map(catchment_lv0)


# === Map to Level 1 river basins using SGI regions ===
def get_river_basin(sgi_id):
    """Return the level-1 river basin name for an SGI glacier id.

    The part of ``sgi_id`` before the first ``'-'`` is looked up in the index
    of the module-level ``SGI_regions`` table and the matching
    ``'river_basin_name'`` entry is returned.

    Args:
        sgi_id: SGI identifier string. Missing values (``None``, NaN or any
            other non-string) are tolerated.

    Returns:
        The basin name, or ``None`` when the id is missing, the region key is
        not in ``SGI_regions``, or no non-null basin name is stored for it.
    """
    # A missing id arrives as NaN (a float) when applied over a column;
    # calling .split() on it would raise AttributeError.
    if not isinstance(sgi_id, str):
        return None

    key = sgi_id.split('-')[0]
    if key not in SGI_regions.index:
        return None

    basin = SGI_regions.loc[key, 'river_basin_name']
    if isinstance(basin, pd.Series):
        # Several rows share this region key: take the first non-null name.
        names = basin.dropna().unique()
        return names[0] if len(names) else None
    return basin if pd.notna(basin) else None


# Level-1 basin for every glacier, resolved from its SGI id.
rgi_df['rvr_lv1'] = rgi_df['sgi-id'].map(get_river_basin)

# Final formatting: the short-name index is relabelled 'GLACIER'.
rgi_df = rgi_df.rename_axis('GLACIER')
rgi_df.head()

In [None]:
# Attach the level-0 / level-1 river basin columns to the per-glacier table.
# Left join: every row of glacier_info is kept, with NaN basins when there
# is no match in rgi_df.
# NOTE(review): re-running this cell merges onto the already-extended table.
glacier_info = glacier_info.merge(rgi_df[['rvr_lv0', 'rvr_lv1']],
                                  on='GLACIER',
                                  how='left')
glacier_info.head()

## Intro & methods:

### Geoplots:


#### sqrt scaling:

In [None]:
# Relief raster used as map background.
tif_name = "landesforstinventar-vegetationshoehenmodell_relief_sentinel_2024_2056.tif"
tif_path = os.path.join(cfg.dataPath, 'GLAMOS/RGI/', tif_name)

# Target grid: WGS84 lat/lon at 0.0009 deg per pixel (approx. 100 m).
output_crs = "EPSG:4326"
target_res = 0.0009

with rasterio.open(tif_path) as src:
    # Affine transform and pixel dimensions of the coarser target grid.
    transform, width, height = calculate_default_transform(
        src.crs, output_crs, src.width, src.height, *src.bounds,
        resolution=target_res)

    # Source metadata, updated to describe the target grid.
    kwargs = src.meta.copy()
    kwargs.update(crs=output_crs,
                  transform=transform,
                  width=width,
                  height=height)

    # Buffer that receives band 1 after warping.
    destination = np.empty((height, width), dtype=src.dtypes[0])

    # Averaging resampling reduces noise when downsampling.
    reproject(
        source=rasterio.band(src, 1),
        destination=destination,
        src_transform=src.transform,
        src_crs=src.crs,
        dst_transform=transform,
        dst_crs=output_crs,
        resampling=Resampling.average,
    )

# [left, right, bottom, top] of the warped grid, derived from the affine
# coefficients (c, f = upper-left corner; a, e = pixel width, pixel height).
extent = [
    transform.c,
    transform.c + transform.a * width,
    transform.f + transform.e * height,
    transform.f,
]

In [None]:
# ---- 1. Preprocessing ----
# Marker sizes grow with the square root of the number of measurements.
glacier_info['sqrt_size'] = np.sqrt(glacier_info['Nb. measurements'])

# Dataset-wide bounds of the sqrt-transformed counts.
sqrt_min, sqrt_max = (glacier_info['sqrt_size'].min(),
                      glacier_info['sqrt_size'].max())

# Output range for scatter marker sizes (points^2): (smallest, largest).
sizes = (30, 1500)


def scaled_size(val, min_out=sizes[0], max_out=sizes[1]):
    """Map a measurement count onto the marker-size range.

    The square root of ``val`` is rescaled linearly so that the dataset
    minimum maps to ``min_out`` and the dataset maximum to ``max_out``.
    When all glaciers have the same count, the midpoint of the range is
    returned.
    """
    root = np.sqrt(val)
    if sqrt_max == sqrt_min:
        return (min_out + max_out) / 2
    fraction = (root - sqrt_min) / (sqrt_max - sqrt_min)
    return min_out + (max_out - min_out) * fraction


# Marker size of every glacier in the actual plot.
glacier_info['scaled_size'] = glacier_info['Nb. measurements'].map(
    scaled_size)

# ---- 2. Create figure and base map ----
fig = plt.figure(figsize=(18, 10))

# Map window in degrees; the commented line is an alternative with a higher
# northern limit.
#latN, latS = 48, 45.8
latN, latS = 47.1, 45.8
lonW, lonE = 5.8, 10.5
# Plate Carree serves as the axes projection here and as the data transform
# of the layers drawn onto ax2 below.
projPC = ccrs.PlateCarree()
ax2 = plt.axes(projection=projPC)
ax2.set_extent([lonW, lonE, latS, latN], crs=ccrs.Geodetic())

# Built-in cartopy base layers (country borders currently disabled).
ax2.add_feature(cfeature.COASTLINE)
ax2.add_feature(cfeature.LAKES)
ax2.add_feature(cfeature.RIVERS)
# ax2.add_feature(cfeature.BORDERS, linestyle='-', linewidth=1)

# Add the relief raster to the cartopy map as a semi-transparent background.
# Pixels equal to 0 are masked so they are drawn in the colormap's "bad"
# color instead of black.
masked_destination = np.ma.masked_where(destination == 0, destination)

# Work on a private copy: `plt.cm.gray` is the shared colormap object, and
# calling set_bad() on it would leak into every later plot that uses 'gray'.
cmap = plt.cm.gray.copy()
cmap.set_bad(color='white')  # Set masked (bad) values to white
ax2.imshow(
    masked_destination,
    origin='upper',
    extent=extent,
    transform=ccrs.PlateCarree(),  # raster was reprojected to EPSG:4326
    cmap=cmap,
    alpha=0.6,  # transparency
    zorder=0)

# Glacier outlines
# NOTE(review): drawn with a Plate Carree transform, which assumes the
# shapefile coordinates are lon/lat degrees — confirm the file's CRS.
glacier_outline_sgi.plot(ax=ax2, transform=projPC, color='black')

# ---- 3. Scatterplot ----
# One bubble per glacier: colour = train/test membership, area = scaled
# number of measurements.
custom_palette = dict(Train=color_dark_blue, Test='#b2182b')

g = sns.scatterplot(
    data=glacier_info,
    x='POINT_LON',
    y='POINT_LAT',
    hue='Train/Test glacier',
    palette=custom_palette,
    size='scaled_size',
    sizes=sizes,
    alpha=0.6,
    zorder=10,
    transform=projPC,
    ax=ax2,
    legend=True,  # handles are harvested for the custom legend below
)

# ---- 4. Gridlines ----
gl = ax2.gridlines(draw_labels=True,
                   linestyle='--',
                   linewidth=1,
                   color='gray',
                   alpha=0.5)
# Label only the bottom and left edges.
gl.top_labels = False
gl.right_labels = False
gl.xformatter, gl.yformatter = LONGITUDE_FORMATTER, LATITUDE_FORMATTER
gl.xlabel_style = dict(size=16, color='black')
gl.ylabel_style = dict(size=16, color='black')

# ---- 5. Custom Combined Legend ----

# Hue entries: keep only the seaborn handles whose label is a palette key.
handles, labels = g.get_legend_handles_labels()
expected_labels = list(custom_palette)
hue_entries = [(handle, label) for handle, label in zip(handles, labels)
               if label in expected_labels]

# Size entries: reference measurement counts shown as grey circles.
size_values = [30, 100, 1000, 6000]


def _size_handle(val):
    """Legend proxy whose marker matches the scatter bubble for ``val``."""
    # Line2D marker sizes are lengths in points while scatter sizes are
    # areas in points^2, hence the square root.
    return Line2D([], [],
                  linestyle='None',
                  marker='o',
                  markersize=np.sqrt(scaled_size(val)),
                  markerfacecolor='gray',
                  alpha=0.6,
                  label=f'{val}')


size_handles = [_size_handle(val) for val in size_values]

# Invisible patch that can act as a section header inside the legend.
separator_handle = Patch(facecolor='none',
                         edgecolor='none',
                         label='Nb. measurements')

# Variant with the separator entry:
# combined_handles = [h for h, _ in hue_entries] + [separator_handle] + size_handles
# combined_labels = [l for _, l in hue_entries] + ['Nb. measurements'] + [str(v) for v in size_values]

# Variant in use: hue entries followed directly by the size entries.
combined_handles = [*(entry[0] for entry in hue_entries), *size_handles]
combined_labels = [*(entry[1] for entry in hue_entries),
                   *map(str, size_values)]

# Final legend
ax2.legend(combined_handles,
           combined_labels,
           title='Nb. measurements',
           title_fontsize=18,
           fontsize=18,
           loc='lower right',
           ncol=3,
           frameon=True,
           borderpad=1.2,
           labelspacing=1.2)
# ax2.set_title('Glacier measurement locations', fontsize = 25)
plt.tight_layout()
plt.show()


In [None]:
# Number of measurements per year: count non-null POINT_ID per
# (YEAR, PERIOD) and stack the periods in one bar per year.
(data_glamos.groupby(['YEAR', 'PERIOD'])['POINT_ID'].count().unstack()
 .plot.bar(stacked=True,
           figsize=(20, 5),
           color=[color_dark_blue, '#abd9e9']))
# plt.title('Number of measurements per year for all glaciers', fontsize = 25)
# Legend for the two periods:
plt.legend(title='Period', title_fontsize=20, fontsize=18, ncol=2)

In [None]:
# Non-null POINT_ID counts per year (rows) and period (columns), then the
# total per period.
meas_period = (data_glamos.groupby(['YEAR', 'PERIOD'])['POINT_ID']
               .count()
               .unstack())
meas_period.sum()

### Input data:

In [None]:
# Initialize logging
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)

# Input/output locations handed to process_or_load_data.
paths = dict(
    csv_path=cfg.dataPath + path_PMB_GLAMOS_csv,
    era5_climate_data=(cfg.dataPath + path_ERA5_raw +
                       'era5_monthly_averaged_data.nc'),
    geopotential_data=(cfg.dataPath + path_ERA5_raw +
                       'era5_geopotential_pressure.nc'),
    radiation_save_path=cfg.dataPath + path_pcsr + 'zarr/',
)

# Transform data to monthly format (run or load data); RUN is forwarded
# as run_flag.
RUN = False
data_monthly = process_or_load_data(
    run_flag=RUN,
    data_glamos=data_glamos,
    paths=paths,
    cfg=cfg,
    vois_climate=vois_climate,
    vois_topographical=vois_topographical,
    output_file='CH_wgms_dataset_monthly_LSTM.csv',
)

# Create DataLoader
dataloader_gl = mbm.dataloader.DataLoader(
    cfg,
    data=data_monthly,
    random_seed=cfg.seed,
    meta_data_columns=cfg.metaData,
)
(months_head_pad, months_tail_pad
 ) = mbm.data_processing.utils.build_head_tail_pads_from_monthly_df(
     data_monthly)

In [None]:
def _n_period(df, period):
    """Number of rows of ``df`` whose PERIOD column equals ``period``."""
    return int((df.PERIOD == period).sum())


# Ensure all test glaciers exist in the dataset
existing_glaciers = set(dataloader_gl.data.GLACIER.unique())
missing_glaciers = [
    name for name in TEST_GLACIERS if name not in existing_glaciers
]
if missing_glaciers:
    print(
        f"Warning: The following test glaciers are not in the dataset: {missing_glaciers}"
    )

# Training glaciers: everything present that is not a test glacier.
train_glaciers = [
    name for name in existing_glaciers if name not in TEST_GLACIERS
]

# Split the rows by glacier membership.
data_test = dataloader_gl.data.loc[dataloader_gl.data.GLACIER.isin(
    TEST_GLACIERS)]
print('Size of test data:', len(data_test))

data_train = dataloader_gl.data.loc[dataloader_gl.data.GLACIER.isin(
    train_glaciers)]
print('Size of train data:', len(data_train))

# Test size expressed relative to the train size.
if len(data_train):
    test_perc = (len(data_test) / len(data_train)) * 100
    print('Percentage of test size: {:.2f}%'.format(test_perc))
else:
    print("Warning: No training data available!")

# Number of annual versus winter measurements:
print('Train:')
print('Number of winter and annual samples:', len(data_train))
print('Number of annual samples:', _n_period(data_train, 'annual'))
print('Number of winter samples:', _n_period(data_train, 'winter'))

# Same for test (the subsets are kept as variables).
data_test_annual = data_test.loc[data_test.PERIOD == 'annual']
data_test_winter = data_test.loc[data_test.PERIOD == 'winter']

print('Test:')
print('Number of winter and annual samples:', len(data_test))
print('Number of annual samples:', len(data_test_annual))
print('Number of winter samples:', len(data_test_winter))

print('Total:')
print('Number of monthly rows:', len(dataloader_gl.data))
print('Number of annual rows:', _n_period(dataloader_gl.data, 'annual'))
print('Number of winter rows:', _n_period(dataloader_gl.data, 'winter'))

#### Heatmap annual:

In [None]:
# Heatmap for the annual period, built from the filtered GLAMOS table and the
# display-name mapping.
# NOTE(review): plotHeatmap is defined outside this file (star import);
# what exactly it draws cannot be verified from here.
plotHeatmap(TEST_GLACIERS, data_glamos, glacierCap, period='annual')

In [None]:
# Count glaciers with more than 30 distinct years of annual measurements.
data_annual = data_glamos.loc[data_glamos.PERIOD == 'annual']
num_years = data_annual.groupby('GLACIER')['YEAR'].nunique().sort_values()
int((num_years > 30).sum())

In [None]:
# Annual rows with FROM_DATE (YYYYMMDD) parsed into a datetime column.
# .assign() returns a new frame, so data_glamos is left untouched.
# Unparseable values become NaT.
data_annual = data_glamos[data_glamos.PERIOD == 'annual'].assign(
    FROM_date=lambda df: pd.to_datetime(
        df['FROM_DATE'].astype(str), format='%Y%m%d', errors='coerce'))

# Winter rows with TO_DATE (YYYYMMDD) parsed the same way.
data_winter = data_glamos[data_glamos.PERIOD == 'winter'].assign(
    TO_date=lambda df: pd.to_datetime(
        df['TO_DATE'].astype(str), format='%Y%m%d', errors='coerce'))

# Helper: mean calendar date and spread (in days) of a date Series, ignoring
# the year by mapping every date onto the dummy year 2000.
def mean_date_and_std(date_series: pd.Series, circular: bool = False):
    """Return the mean day-of-year (as a date in 2000) and its std in days.

    Args:
        date_series: Series of datetimes; NaT entries are ignored.
        circular: If False, use the arithmetic mean and the pandas sample
            standard deviation of the day-of-year. If True, use circular
            statistics over the year, which is appropriate when the dates
            wrap around New Year.

    Returns:
        Tuple ``(mean_date, std_days)``. ``mean_date`` is a ``pd.Timestamp``
        in the year 2000. ``(pd.NaT, nan)`` is returned when the series has
        no valid dates.
    """
    s = date_series.dropna()
    if s.empty:
        return pd.NaT, np.nan

    # Map to a dummy, fixed year so we can compute day-of-year. 2000 is a
    # leap year, so Feb 29 stays valid and day-of-year spans 1..366.
    dummy = pd.to_datetime({
        'year': 2000,
        'month': s.dt.month,
        'day': s.dt.day
    },
                           errors='coerce').dropna()

    doy = dummy.dt.dayofyear.astype(float)

    if not circular:
        mean_doy = doy.mean()
        std_doy = doy.std()
    else:
        # The period must match the dummy year's length: with 365, Dec 31
        # (day 366) would land on the same angle as Jan 1.
        days_in_year = 366.0
        theta = 2 * np.pi * (doy - 1) / days_in_year
        C = np.mean(np.cos(theta))
        S = np.mean(np.sin(theta))
        mean_ang = np.arctan2(S, C)
        if mean_ang < 0:
            mean_ang += 2 * np.pi
        mean_doy = (mean_ang / (2 * np.pi)) * days_in_year + 1
        # Mean resultant length; clamped to avoid log(0) for uniform spread.
        R = np.sqrt(C**2 + S**2)
        # Convert circular std (radians) to days
        std_ang = np.sqrt(-2 * np.log(max(R, 1e-12)))
        std_doy = std_ang * days_in_year / (2 * np.pi)

    mean_date = pd.Timestamp('2000-01-01') + pd.to_timedelta(mean_doy - 1,
                                                             unit='D')
    return mean_date, std_doy


# Linear (non-circular) statistics of the annual start dates and the winter
# end dates. Note that the "*_from_winter" variables hold TO_date results.
mean_from_annual, std_from_annual = mean_date_and_std(
    data_annual['FROM_date'])
mean_from_winter, std_from_winter = mean_date_and_std(data_winter['TO_date'])

print(f"ANNUAL FROM_DATE -> mean: {mean_from_annual.strftime('%m-%d')} | std: {std_from_annual:.2f} days")
print(f"WINTER TO_DATE   -> mean: {mean_from_winter.strftime('%m-%d')} | std: {std_from_winter:.2f} days")

In [None]:
# Mean elevation of the annual measurement points per glacier, highest first.
gl_per_el = (data_glamos.loc[data_glamos.PERIOD == 'annual']
             .groupby('GLACIER')['POINT_ELEVATION']
             .mean()
             .sort_values(ascending=False))

# Plot elevation from lowest to highest glacier, without tick labels or
# axis titles.
fig, ax = plt.subplots(figsize=(10, 2))
sns.lineplot(gl_per_el.sort_values(ascending=True),
             marker='v',
             color='gray',
             ax=ax)
ax.set_xticklabels('', rotation=90)
ax.set_ylabel('')
ax.set_xlabel('')

In [None]:
# Distinct years with annual measurements per glacier, ascending.
data_annual.groupby('GLACIER')['YEAR'].nunique().sort_values()

In [None]:
sns.boxplot(gl_per_el)

# Box statistics of the per-glacier mean elevations.
median = np.median(gl_per_el.values)
lower_quartile, upper_quartile = np.percentile(gl_per_el.values, [25, 75])
iqr = upper_quartile - lower_quartile

# Whiskers: most extreme values still within 1.5 * IQR of the quartiles.
within_upper = gl_per_el.values <= upper_quartile + 1.5 * iqr
within_lower = gl_per_el.values >= lower_quartile - 1.5 * iqr
upper_whisker = gl_per_el.values[within_upper].max()
lower_whisker = gl_per_el.values[within_lower].min()

print(f"Median elevation: {median:.2f} m")
print(f"Upper quartile elevation: {upper_quartile:.2f} m")
print(f"Lower quartile elevation: {lower_quartile:.2f} m")
print(f"Upper whisker elevation: {upper_whisker:.2f} m")
print(f"Lower whisker elevation: {lower_whisker:.2f} m")