In [1]:
import pathlib
import pickle
import sys

import arviz as az
import emcee
import matplotlib.cm as cm
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
import xarray as xr
from emcee.autocorr import integrated_time
from scipy import stats
from scipy.stats import halfnorm, truncnorm
from scipy.stats import norm
from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score

## Script to analyse outputs of emulator-based MCMC runs

In [None]:
# Resolve the COSIPY data directory for the current OS. Note: a substring test
# like `'win' in sys.platform` would also match "darwin" (macOS), so use startswith.
if sys.platform.startswith("win"):
    path = "E:/OneDrive/PhD/PhD/Data/Hintereisferner/COSIPY/"
else:
    path = "/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/COSIPY/"

# Posterior of the emulator-based DEMcz run to analyse (alternatives kept for reference).
#ds = az.from_netcdf(path+"point_demcz_posterior_combined.nc")
#ds = az.from_netcdf(path+"stage1_demczsyserr_posterior_combined.nc") #snowline + alb summer sys err
#ds = az.from_netcdf(path+"stage1_albmb_demczsyserr_posterior_combined.nc")
ds = az.from_netcdf(path + "stage2_final_demczsyserr_posterior_combined.nc")

#ds = az.from_netcdf(path+"stage1_albmbtimes3_demczsyserr_posterior_combined.nc")
#ds = az.from_netcdf(path+"full_demczsyserr_posterior_combined.nc") #snowline + alb summer sys err
#ds = az.from_netcdf(path+"stage1_syserr_demcz_posterior_combined.nc") #snowline sys err
#ds = az.from_netcdf(path+"demcz_posterior_combined.nc")
#ds = az.from_netcdf(path+"stage1_demcz_posterior_combined.nc")
#ds = az.from_netcdf(path+"stage2_demcz_posterior_combined.nc")
ds

In [None]:
# Trace plots of the physical model parameters. Some posteriors also sample
# "centersnow": include it only when it is present, rather than retrying after
# *any* exception (which would hide genuine plotting errors).
param_names = ["rrrfactor", "albice", "albsnow", "albfirn", "albaging", "albdepth", "iceroughness"]
if "centersnow" in ds.posterior.data_vars:
    param_names = param_names + ["centersnow"]
az.plot_trace(ds, var_names=param_names)

In [None]:
ds.posterior.loglike_alb[:,:].plot.hist()

In [None]:
ds.posterior.loglike_tsl[:,:].plot.hist()

In [None]:
ds.posterior.loglike_mb[:,:].plot.hist()

In [None]:
ds.posterior.mu_mb.isel(chain=2, mu_mb_dim_0=0).plot.hist()

In [None]:
ds.posterior.total_loglike.isel(chain=2).plot.hist()

In [None]:
# ArviZ summary table with a 95% HDI, restricted to the model parameters plus
# whichever systematic-error terms exist in this posterior.
summary_stage1 = az.summary(ds, hdi_prob=0.95)
rows_of_interest = summary_stage1.index.intersection(param_names + ["sigma_alb_summer", "sigma_tsl_summer"])
summary_stage1.loc[rows_of_interest]


In [None]:
## Derive initial values for the 2nd stage: each Stage 2 chain starts from a
## randomly chosen (chain, draw) of this posterior.
num_chains_stage2 = 20
STAGE2_INIT_SEED = 2024  # fixed seed so the pickled start points can be regenerated

# Parameters carried over to Stage 2, in the order used for the init dicts.
# 'sigma_tsl_summer' is intentionally left out (it was commented out here); add it
# back if the Stage 2 model samples it.
stage2_init_params = ["rrrfactor", "albsnow", "albfirn", "albaging", "albdepth",
                      "albice", "iceroughness", "sigma_alb_summer"]

# Size of the Stage 1 trace
num_draws = ds.posterior.draw.size
num_chains_stage1 = ds.posterior.chain.size

rng_stage2 = np.random.default_rng(STAGE2_INIT_SEED)
initial_values_stage2 = []
for _ in range(num_chains_stage2):
    rand_chain = rng_stage2.integers(0, num_chains_stage1)
    rand_draw = rng_stage2.integers(0, num_draws)
    init_dict = {
        name: ds.posterior[name].isel(chain=rand_chain, draw=rand_draw).values
        for name in stage2_init_params
    }
    initial_values_stage2.append(init_dict)

# `path` is the OS-specific COSIPY directory chosen when the posterior was loaded.
with open(path + "MiscTests/LHS/stage2_initial_values.pkl", "wb") as f:
    pickle.dump(initial_values_stage2, f)
initial_values_stage2

In [None]:
#unc terms
try:
    #param_unc_names = ["sigma_alb_summer"]
    param_unc_names = ["sigma_tsl_summer", "sigma_alb_summer"]
    az.plot_trace(ds, var_names=param_unc_names)
except:
    pass

In [None]:
#### TODO ####
## Integrative autocorrelation time - IAT for all parameters + average, must flatten 25000
## ESS 
## R-hat
## Monte Carlo Standard Error
ds.posterior

In [None]:
# Rank chains by mean total log-likelihood and optionally drop the worst ones.
loglikes = ds.posterior.total_loglike.values  # (n_chains, n_draws)

# Mean log-likelihood per chain
mean_ll = loglikes.mean(axis=1)

n_cut = 0  # number of worst chains to exclude (previously 4)
n_chains_to_keep = loglikes.shape[0] - n_cut
best_chain_idx = np.argsort(-mean_ll)[:n_chains_to_keep]  # positions, best -> worst
print(best_chain_idx)

# `best_chain_idx` holds array positions, so select positionally with isel;
# `sel` would match chain coordinate labels, which only coincide with positions
# when the chains are labelled 0..n-1.
posterior_subset = ds.posterior.isel(chain=best_chain_idx)
posterior_subset



In [None]:
# Corner KDE pairplot of the posterior parameters.
# -- CONFIGURATION --
dpi = 300
figsize = (13, 13)  # adjust based on number of parameters
font_size = 22
sample_size = 20000
palette = sns.color_palette("crest", as_cmap=True)

# -- SET STYLE --
sns.set(style="white", font_scale=1.6)
plt.rcParams.update({
    "font.size": font_size,
    "axes.labelsize": font_size,
    "axes.titlesize": font_size,
    "xtick.labelsize": font_size * 0.8,
    "ytick.labelsize": font_size * 0.8,
    "figure.dpi": dpi
})

# -- EXTRACT POSTERIOR AND SAMPLE --
pairplot_labels = {'rrrfactor': r'$p_{f}$', 'albice': r'$\alpha_{ice}$', 'albsnow': r'$\alpha_{fs}$',
                   'albfirn': r'$\alpha_{firn}$', 'albaging': r'$\alpha_{aging}$',
                   'albdepth': r'$\alpha_{depth}$', 'iceroughness': r'$z0_{ice}$'}
posi = ds.posterior[param_names]
posterior_stacked = posi.stack(sample=("chain", "draw"))  # collapse chains
posterior_df = posterior_stacked.to_dataframe()[param_names]
# Subsample for KDE speed. DataFrame.sample raises if asked for more rows than
# exist (without replacement), so cap at the number of available draws.
posi_samples_kde = (
    posterior_df
    .sample(min(sample_size, len(posterior_df)), random_state=42)
    .rename(columns=pairplot_labels)
)

# -- PAIRPLOT --
g = sns.pairplot(posi_samples_kde, kind="kde", diag_kind="kde",
                 plot_kws={"fill": True, "cmap": "crest", "levels": 8},
                 diag_kws={"fill": True, "color": "black"}, corner=True)
g.fig.set_size_inches(figsize)
for ax in g.axes.flat:
    if ax is not None:  # upper triangle is None with corner=True
        ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=2))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=2))

# -- SAVE FIGURE -- (the path was previously Linux-only)
if sys.platform.startswith("win"):
    fig_dir = "E:/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/"
else:
    fig_dir = "/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/"
plt.tight_layout()
plt.savefig(fig_dir + "posterior_param_pairplots.pdf", bbox_inches="tight")

In [None]:
ds.posterior["sigma_alb_summer"].values.shape

In [None]:
## Prepare integrated autocorr times
comb_full = np.stack([ds.posterior["rrrfactor"].values, ds.posterior["albice"].values, ds.posterior["albsnow"].values,
                 ds.posterior["albfirn"].values, ds.posterior["albaging"].values, ds.posterior["albdepth"].values,
                 ds.posterior["iceroughness"].values, ds.posterior["sigma_tsl_summer"].values, ds.posterior["sigma_alb_summer"].values])
print(comb_full.shape)

## For subset without bad chains:
comb = np.stack([posterior_subset["rrrfactor"].values, posterior_subset["albice"].values, posterior_subset["albsnow"].values,
                 posterior_subset["albfirn"].values, posterior_subset["albaging"].values, posterior_subset["albdepth"].values,
                 posterior_subset["iceroughness"].values, posterior_subset["sigma_tsl_summer"].values, posterior_subset["sigma_alb_summer"].values])
print(comb.shape)

params_in_order = ["rrrfactor", "albice", "albsnow", "albfirn", "albaging", "albdepth", "iceroughness", "sigma_tsl_summer", "sigma_alb_summer"]

samples_full = comb_full.transpose(2, 1, 0)
samples = comb.transpose(2, 1, 0)  # Now (n_samples, n_chains, n_param)
  # shape: (n_samples, n_chains, n_param)
print(samples.shape)


In [None]:
# Drop the first 5000 draws as burn-in. Re-derived from `comb` (instead of
# `samples = samples[5000:]`) so that re-running this cell is idempotent.
BURN_IN_IAT = 5000
samples = comb.transpose(2, 1, 0)[BURN_IN_IAT:]
print(samples.shape)

In [17]:
def compute_iat_curve(chain_samples, checkpoints):
    """Integrated autocorrelation time per parameter as the chains grow.

    Parameters
    ----------
    chain_samples : ndarray, shape (n_steps, n_walkers, n_param)
        Layout expected by ``emcee.autocorr.integrated_time``.
    checkpoints : iterable of int
        Chain lengths at which the IAT is evaluated (the first ``n`` steps).

    Returns
    -------
    list of list
        One list per parameter with one IAT per checkpoint; NaN where emcee
        raised ``AutocorrError``.
    """
    n_param_local = chain_samples.shape[2]
    curves = [[] for _ in range(n_param_local)]
    for n in checkpoints:
        try:
            # quiet=True makes emcee warn instead of raising for short chains
            tau = integrated_time(chain_samples[:n, :, :], quiet=True)  # shape (n_param,)
        except emcee.autocorr.AutocorrError:
            tau = np.full(n_param_local, np.nan)  # pad with NaN if not computable
        for i in range(n_param_local):
            curves[i].append(tau[i])
    return curves


# Burn-in-trimmed ranked subset; checkpoints every 5000 steps starting at 5000
n_steps, n_walkers, n_param = samples.shape
ns_subset = np.arange(5000, n_steps, 5000)
taus_by_param = compute_iat_curve(samples, ns_subset)

## Repeat for the full (untrimmed) dataset. The checkpoints must span the full
## chain length `n_steps_full`, not the trimmed `n_steps`.
n_steps_full, n_walkers_full, n_param_full = samples_full.shape
ns = np.arange(5000, n_steps_full, 5000)
taus_by_param_full = compute_iat_curve(samples_full, ns)

In [None]:
# Load first stage results to paste into approach
ds_first = az.from_netcdf(path+"stage1_final_demczsyserr_posterior_combined.nc")
ds_first

In [None]:
## Note: Due to 2-stage approach, we need TSLA from this run but MB and ALB should come from first logl. run! 
# 
total_loglikes = ds.posterior.total_loglike.values
print(total_loglikes.shape)
total_loglikes = total_loglikes[:,:]
print(total_loglikes.shape)

mb_loglikes = ds_first.posterior.loglike_mb.values
mb_loglikes = mb_loglikes[:,:,0,0]
print(mb_loglikes.shape)

alb_loglikes = ds_first.posterior.loglike_alb.values
print(alb_loglikes.shape)

tsla_loglikes = ds.posterior.loglike_tsl.values
print(tsla_loglikes.shape)

In [None]:
# Rank chains by mean total log-likelihood and give every rank its own colour.
# total_loglikes shape: (n_chains, n_steps)
n_chains, n_steps = total_loglikes.shape

# Mean log-likelihood per chain
mean_ll = total_loglikes.mean(axis=1)
print(mean_ll)

# Chain indices from best (highest mean) to worst
sorted_idx = np.argsort(-mean_ll)

# Preferred colormap with a fallback; getattr replaces the former bare `except:`.
cmap = getattr(cm, "cividis_r", cm.viridis)
# Evenly spaced colours from rank 0 (best) to the worst rank. linspace avoids
# the division by zero that i / (n_chains - 1) hits when there is one chain.
colors = [cmap(x) for x in np.linspace(0.0, 1.0, n_chains)]

In [21]:
mb_loglike_mean = mb_loglikes.mean(axis=0)
tsla_loglike_mean = tsla_loglikes.mean(axis=0)
alb_loglike_mean = alb_loglikes.mean(axis=0)
total_loglike_mean = total_loglikes.mean(axis=0)


In [None]:
taus_array = np.array(taus_by_param_full)
taus_array.shape

In [None]:
## Convergence figure, 4 panels: log-likelihood traces for B_geod (a), SLA (b) and
## albedo (c), plus the integrated autocorrelation time per parameter (d).
plt.rcParams.update({'font.size': 22})
fig, axes = plt.subplots(2, 2, figsize=(20, 14), dpi=300, sharex=True)
axes = axes.flatten()

# MB and albedo log-likelihoods come from the Stage 1 posterior, which may hold a
# different number of chains than this run. Only plot chain indices that exist in
# those arrays (this bound was previously hard-coded as 15).
n_chains_mb = mb_loglikes.shape[0]
n_chains_alb = alb_loglikes.shape[0]

for rank, chain_idx in enumerate(sorted_idx):
    line_style = dict(color=colors[rank], alpha=0.7, label=f"Chain {chain_idx} (Rank {rank+1})")
    if chain_idx < n_chains_mb:
        axes[0].plot(mb_loglikes[chain_idx, :], **line_style)
    axes[1].plot(tsla_loglikes[chain_idx, :], **line_style)
    if chain_idx < n_chains_alb:
        axes[2].plot(alb_loglikes[chain_idx, :], **line_style)

# Ensemble-mean trace (black) on each log-likelihood panel
for ax, ll_mean in zip(axes[:3], (mb_loglike_mean, tsla_loglike_mean, alb_loglike_mean)):
    ax.plot(ll_mean, color="black", alpha=0.4, zorder=7)

# Axis labels
axes[0].set_ylabel(r'$\mathcal{L}(B_{geod}|\theta)$')
axes[1].set_ylabel(r'$\mathcal{L}(SLA|\theta)$')
axes[2].set_ylabel(r'$\mathcal{L}(\bar{\alpha}|\theta)$')

# Panel d: IAT curve per parameter (full dataset) and their mean (black)
for i in range(n_param_full):
    axes[3].plot(ns, taus_by_param_full[i], label=f"{params_in_order[i]}", marker='o')
axes[3].plot(ns, taus_array.mean(axis=0), marker='o', color="black", zorder=7)

# Bottom row: sample-count x axis with rotated tick labels
xtick_positions = np.arange(0, 100000+10000, 10000)
for ax in (axes[2], axes[3]):
    ax.set_xticks(xtick_positions)
    ax.set_xticklabels(xtick_positions, rotation=30)
    ax.set_xlabel("Samples in chains")

axes[3].set_yticks(np.arange(20, 40, 5))
axes[3].set_ylabel("Integr. Autocorr. Time")

# Panel letters
fig.text(0.01, 0.98, 'a)', transform=fig.transFigure, fontsize=24)
fig.text(0.49, 0.98, 'b)', transform=fig.transFigure, fontsize=24)
fig.text(0.01, 0.53, 'c)', transform=fig.transFigure, fontsize=24)
fig.text(0.49, 0.53, 'd)', transform=fig.transFigure, fontsize=24)

# LaTeX symbols for parameter names (also used by the trace/KDE figure later on)
y_label_dict = {'rrrfactor': r'$p_{f}$', 'albice': r'$\alpha_{ice}$', 'albsnow': r'$\alpha_{fs}$','albfirn': r'$\alpha_{firn}$', 'albaging': r'$\alpha_{aging}$',
                'albdepth': r'$\alpha_{depth}$','iceroughness': r'$z0_{ice}$', 'mb_logp': r'$\mathcal{L}(MB|\theta)$', 'sigma_tsl_summer': r'$\sigma_{\eta}^{SLA}$',
                'sigma_alb_summer': r'$\sigma_{\eta}^{\bar{\alpha}}$'}
fig.tight_layout()

# Shared legend for panel d, with parameter names replaced by their symbols
handles, labels_old = axes[3].get_legend_handles_labels()
labels = [y_label_dict[x] for x in labels_old]
fig.legend(handles, labels,
           loc='lower center', ncol=9, frameon=True, bbox_to_anchor=(0.5, -0.04))

if sys.platform.startswith("win"):
    fig_dir = "E:/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/"
else:
    fig_dir = "/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/"
plt.savefig(fig_dir + "Fig05_mcmc_convergence_plots.png", bbox_inches="tight")

In [None]:
# ARCHIVED (inert): the triple-quoted string below is never executed. It keeps the
# earlier 3x2 convergence figure (total/MB/TSLA/ALB log-likelihoods + two IAT panels)
# from before the two-stage approach, for reference only.
# To revive it, move the code into its own cell and check that the variables it uses
# (total_loglikes, sorted_idx, colors, ns, taus_by_param, n_param, ...) still have matching shapes.
""" ARCHIVED because of two-stage approach now.

plt.rcParams.update({'font.size': 22})
fig, axes = plt.subplots(3, 2, figsize=(16, 12), dpi=300, sharex=True)
axes = axes.flatten()

for rank, chain_idx in enumerate(sorted_idx):
    axes[0].plot(
        total_loglikes[chain_idx, :],
        color=colors[rank], alpha=0.7, 
        label=f"Chain {chain_idx} (Rank {rank+1})"
    )
    axes[0].axvline(x=5000, linestyle="dashed", color="black", lw=0.5)
    
    axes[1].plot(
        mb_loglikes[chain_idx, :],
        color=colors[rank], alpha=0.7, 
        label=f"Chain {chain_idx} (Rank {rank+1})"
    )
    axes[1].axvline(x=5000, linestyle="dashed", color="black", lw=0.5) 
       
    axes[2].plot(
        tsla_loglikes[chain_idx, :],
        color=colors[rank], alpha=0.7, 
        label=f"Chain {chain_idx} (Rank {rank+1})"
    )
    axes[2].axvline(x=5000, linestyle="dashed", color="black", lw=0.5)
    
    axes[3].plot(
        alb_loglikes[chain_idx, :],
        color=colors[rank], alpha=0.7, 
        label=f"Chain {chain_idx} (Rank {rank+1})"
    )
    axes[3].axvline(x=5000, linestyle="dashed", color="black", lw=0.5)
# Add axes labels
axes[0].set_ylabel(r'$\mathcal{L}(total|\theta)$')
axes[1].set_ylabel(r'$\mathcal{L}(MB|\theta)$')
axes[2].set_ylabel(r'$\mathcal{L}(TSLA|\theta)$')
axes[3].set_ylabel(r'$\mathcal{L}(ALB|\theta)$')
axes[3].set_yticks(np.arange(1, 1.6+0.2, 0.2))

# Add colorbar to show ranking
# Shared vertical colorbar for first 4 plots
# Position: [left, bottom, width, height]
cbar_ax = fig.add_axes([0.92, 0.37, 0.02, 0.52])  # spans top two rows
sm = cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=n_chains - 1))
sm.set_array([])
cbar = plt.colorbar(sm, cax=cbar_ax)
cbar.set_label("Chain Performance Rank")

for i in range(n_param_full):
    axes[4].plot(ns, taus_by_param_full[i], label=f"{params_in_order[i]}", marker='o')

for i in range(n_param):
    axes[5].plot(ns, taus_by_param[i], label=f"{params_in_order[i]}", marker='o')    

axes[4].set_xticks(np.arange(0, 100000+10000, 10000))
axes[5].set_xticks(np.arange(0, 100000+10000, 10000))
axes[4].set_xticklabels(np.arange(0, 100000+10000, 10000), rotation=30)
axes[5].set_xticklabels(np.arange(0, 100000+10000, 10000), rotation=30)

axes[4].set_yticks(np.arange(20, 65+15, 15))
axes[5].set_yticks(np.arange(20, 45+5, 5))
axes[4].set_ylabel("Integr. Autocorr. Time")
axes[4].set_xlabel("Samples in chains")
axes[5].set_xlabel("Samples in chains")
# Add legend (example: fake classes or categories)
y_label_dict = {'rrrfactor': r'$p_{f}$', 'albice': r'$\alpha_{ice}$', 'albsnow': r'$\alpha_{fs}$','albfirn': r'$\alpha_{firn}$', 'albaging': r'$\alpha_{aging}$',
                'albdepth': r'$\alpha_{depth}$','iceroughness': r'$z0_{ice}$', 'mb_logp': r'$\mathcal{L}(MB|\theta)$', 'sigma_tsl_summer': r'$\sigma_{tsla}^{sys}$',
                'sigma_alb_summer': r'$\sigma_{alb}^{sys}$'}

handles, labels_old = axes[4].get_legend_handles_labels()
labels = [y_label_dict[x] for x in labels_old]
fig.legend(handles, labels,
           loc='lower center', ncol=7, frameon=True, bbox_to_anchor=(0.5, -0.07))
#plt.subplots_adjust(wspace=0.0, hspace=0.4, right=0.9, bottom=0.1)
#fig.tight_layout()
## rename params in legends
if 'win' in sys.platform:
    plt.savefig("E:/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/mcmc_convergence_plots.png", bbox_inches="tight")
else:
    plt.savefig("/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/mcmc_convergence_plots.png", bbox_inches="tight")
"""

In [25]:
## Load priors and add to plot below: the priors are sampled with scipy so they
## can be overlaid on the posterior KDEs.
# Reference PyMC prior specification (Stage 1 params: TSLA + ALB only):
#   rrr = pm.TruncatedNormal('rrrfactor', mu=0.7785, sigma=0.781, lower=0.648, upper=0.946)
#   snow = pm.TruncatedNormal("albsnow", mu=0.903, sigma=0.1, lower=0.887, upper=0.928)
#   ice = pm.TruncatedNormal("albice", mu=0.17523, sigma=0.1, lower=0.1182, upper=0.2302)
#   firn = pm.TruncatedNormal("albfirn", mu=0.6036, sigma=0.1, lower=0.51, upper=0.6747)
#   aging = pm.TruncatedNormal("albaging", mu=13.82, sigma=5.372, lower=5, upper=24.77)
#   depth = pm.TruncatedNormal("albdepth", mu=1.776, sigma=0.666, lower=1.0, upper=4)
#   rough = pm.TruncatedNormal("iceroughness", mu=8.612, sigma=9, lower=1.2, upper=19.65)
#   sigma_alb_summer = pm.HalfNormal("sigma_alb_summer", sigma=0.02)
#   sigma_tsl_summer = pm.HalfNormal("sigma_tsl_summer", sigma=0.03)
# NOTE(review): the spec above lists sigma=0.781 for rrrfactor, while the samples
# below use sigma=0.0781. Confirm which value the MCMC run actually used.


def get_truncnorm_samples(mu, sigma, lower, upper, size=1000, random_state=None):
    """Sample a normal(mu, sigma) distribution truncated to [lower, upper].

    scipy's ``truncnorm`` expects its bounds in standard-deviation units relative
    to ``loc``, so the absolute bounds are converted first.

    Parameters
    ----------
    mu, sigma : float
        Mean and standard deviation of the untruncated normal.
    lower, upper : float
        Absolute truncation bounds.
    size : int
        Number of samples.
    random_state : None, int or numpy Generator
        Passed to ``truncnorm.rvs`` for reproducible draws.

    Returns
    -------
    ndarray of shape (size,)
    """
    a, b = (lower - mu) / sigma, (upper - mu) / sigma
    return truncnorm.rvs(a, b, loc=mu, scale=sigma, size=size, random_state=random_state)


# Sample size
n_samples = 100000
PRIOR_SEED = 123  # fixed so the prior KDEs are reproducible
prior_rng = np.random.default_rng(PRIOR_SEED)

# Define priors and sample
priors = {
    'rrrfactor': get_truncnorm_samples(0.7785, 0.0781, 0.648, 0.946, n_samples, random_state=prior_rng),
    'albsnow': get_truncnorm_samples(0.903, 0.1, 0.887, 0.928, n_samples, random_state=prior_rng),
    'albice': get_truncnorm_samples(0.17523, 0.1, 0.1182, 0.2302, n_samples, random_state=prior_rng),
    'albfirn': get_truncnorm_samples(0.6036, 0.1, 0.51, 0.6747, n_samples, random_state=prior_rng),
    'albaging': get_truncnorm_samples(13.82, 5.372, 5, 24.77, n_samples, random_state=prior_rng),
    'albdepth': get_truncnorm_samples(1.776, 0.666, 1.0, 4, n_samples, random_state=prior_rng),
    'iceroughness': get_truncnorm_samples(8.612, 9, 1.2, 19.65, n_samples, random_state=prior_rng),
    'sigma_alb_summer': halfnorm.rvs(scale=0.02, size=n_samples, random_state=prior_rng),
    'sigma_tsl_summer': halfnorm.rvs(scale=0.03, size=n_samples, random_state=prior_rng),
}


In [None]:
# Quick visual check of one sampled prior (fresh-snow albedo).
fig, ax = plt.subplots()
sns.kdeplot(priors["albsnow"], ax=ax)

In [None]:
## Custom plot for chains based on az.trace.
# comb_full is (n_param, n_chains, n_draws); reorder it to (n_chains, n_draws, n_param)
# and drop the first 5000 draws as burn-in, in one step.
chains = comb_full.transpose(1, 2, 0)[:, 5000:, :]
print(chains.shape)

n_chains, n_samples, n_params = chains.shape

In [None]:
# Per-parameter figure. Left column: every chain's trace, coloured by performance
# rank, plus the mean trace. Right column: prior vs. pooled-posterior KDE with the
# posterior median (dashed) and 95% interval (dotted).
plt.rcParams.update({'font.size': 22})
fig, axes = plt.subplots(n_params, 2, figsize=(12, 2.5 * n_params), dpi=300)
if n_params == 1:
    axes = np.expand_dims(axes, 0)  # keep 2-D indexing for a single row

for i in range(n_params):
    ax_trace, ax_kde = axes[i, 0], axes[i, 1]

    # Each chain, colour-coded by performance rank
    for rank, chain_idx in enumerate(sorted_idx):
        ax_trace.plot(
            chains[chain_idx, :, i],
            color=colors[rank],
            alpha=0.7,
            label=f"Chain {chain_idx} (Rank {rank+1})" if i == 0 else None,
        )

    # Mean trace across chains
    mean_trace = np.mean(chains[:, :, i], axis=0)
    ax_trace.plot(mean_trace, color='black', alpha=0.4)

    # Median and 95% interval of all chains pooled
    combined_samples = chains[:, :, i].flatten()
    median = np.median(combined_samples)
    lower, upper = np.percentile(combined_samples, [2.5, 97.5])

    ax_trace.set_ylabel(y_label_dict[params_in_order[i]])
    ax_trace.set_xlabel("")
    if i != n_params - 1:
        # Only the bottom trace panel keeps its x tick labels (was hard-coded to i != 8)
        ax_trace.tick_params(labelbottom=False)

    # KDE plot: prior vs. posterior
    sns.kdeplot(priors[params_in_order[i]], fill=True, ax=ax_kde, color="coral", alpha=0.7, bw_adjust=0.7, common_norm=False, label="Prior")
    sns.kdeplot(combined_samples, fill=True, ax=ax_kde, color="slateblue", alpha=0.7, bw_adjust=0.7, common_norm=False, label="Posterior")
    ax_kde.axvline(median, color='black', linestyle='--')
    ax_kde.axvline(lower, color='gray', linestyle=':')
    ax_kde.axvline(upper, color='gray', linestyle=':')
    ax_kde.set_ylabel("")

# Panel letters a), b), ... row by row (left column, then right column).
# NOTE: y_top / y_bottom were set for a 9-row layout; adjust if n_params changes.
y_top = 0.990
y_bottom = 0.120
rows = n_params
step = (y_top - y_bottom) / max(rows - 1, 1)
labels = [chr(ord('a') + k) + ')' for k in range(2 * rows)]

for i in range(rows):
    y = y_top - i * step
    fig.text(0.02, y, labels[2*i], transform=fig.transFigure, fontsize=22)    # left column
    fig.text(0.52, y, labels[2*i+1], transform=fig.transFigure, fontsize=22)  # right column

axes[0, 1].legend()
axes[-1, 0].set_xlabel("Samples in chains")
fig.tight_layout()

if sys.platform.startswith("win"):
    fig_dir = "E:/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/"
else:
    fig_dir = "/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/"
plt.savefig(fig_dir + "FigA01_mcmc_chains_traces_plots.png", bbox_inches="tight")

# TODO (optional): shared chain legend, e.g. fig.legend(*ax_trace.get_legend_handles_labels(), ...)

In [None]:
## Identify chains to drop ...
az.rhat(ds, var_names=param_names)

In [None]:
az.ess(ds, var_names=param_names)


In [None]:
posterior_subset.mu_mb[:,:,0].plot.hist()

In [None]:
## Compute stats
az.ess(posterior_subset, var_names=param_names)

In [None]:
az.rhat(posterior_subset, var_names=param_names)

In [None]:
taus = np.array(taus_by_param)
np.max(taus, axis=1)

In [None]:
posterior_subset

In [None]:
## Thin the posterior and check.
# Discard the first 5000 draws as burn-in, then keep every 40th draw. isel
# (positional) matches the numpy `[5000:]` slicing used for the IAT samples,
# regardless of the draw coordinate labels.
BURN_IN_THIN = 5000
THIN_STEP = 40
subset = posterior_subset.isel(draw=slice(BURN_IN_THIN, None))
idata_thinned = subset.isel(draw=slice(None, None, THIN_STEP))
idata_thinned

In [None]:
az.plot_trace(idata_thinned, var_names=param_names)

In [None]:
## create size
post_param_holder = idata_thinned["rrrfactor"].values
flat_param_holder = post_param_holder.reshape(-1)
flat_param_size = flat_param_holder.size
random_indices = np.random.choice(flat_param_size, size=300, replace=False) #generate 300 random samples
random_indices

In [None]:
## Sample n=200 and check
dic_samples = {}
for param in param_names + ["sigma_tsl_summer","sigma_alb_summer"]:
    print("Sampling ", param)
    post_param = idata_thinned[param].values
    print(post_param.shape)
    
    flat_param = post_param.reshape(-1)  # shape: (n_chains * n_samples,)
    print(flat_param.shape)

    final_samples = flat_param[random_indices]
    dic_samples[param] = final_samples
    

dic_samples

In [None]:
import pandas as pd
df = pd.DataFrame(dic_samples)
df

In [None]:
# Convert DataFrame to dict of variables (each as shape (1, 200))
df_copy = df.copy()
df_copy["chain"] = 0
df_copy["draw"] = np.arange(len(df_copy), dtype=int)
df_copy = df_copy.set_index(["chain", "draw"])
xdata = xr.Dataset.from_dataframe(df_copy)

dataset = az.InferenceData(posterior=xdata)
dataset

In [None]:
az.plot_trace(dataset, var_names=param_names)

In [None]:
column_names_for_spotpy = {'rrrfactor': 'parRRR_factor', 'albice': 'paralb_ice', 'albsnow': 'paralb_snow',
                           'albfirn': 'paralb_firn', 'albaging': 'paralbedo_aging', 'albdepth': 'paralbedo_depth',
                           'iceroughness': 'parroughness_ice'}

df = df.rename(columns=column_names_for_spotpy)
df

In [36]:
save = False
if save != False:
    if 'win' in sys.platform:
        df.to_csv("E:/OneDrive/PhD/PhD/Data/Hintereisferner/COSIPY/MiscTests/LHS/final_samples_mcmc.csv", index=False)
    else:
        df.to_csv("/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/COSIPY/MiscTests/LHS/final_samples_mcmc.csv", index=False)
        
## Store these to .csv and run the full COSIPY model with them (no emulator)

In [None]:
### Load the parameter samples used for the full COSIPY runs (validation plots).
# These were drawn without a fixed seed and cannot be regenerated, so take care
# not to overwrite the CSV when saving. Reading simply rebinds `df`; no `del` needed.
df = pd.read_csv(path + "MiscTests/LHS/final_samples_mcmc.csv")
df

In [None]:
## Load observations: albedo, snowlines and geodetic mass balance.
if sys.platform.startswith("win"):
    data_root = "E:/OneDrive/PhD/PhD/Data/"
else:
    data_root = "/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/"

albobs = xr.open_dataset(data_root + "Hintereisferner/Climate/HEF_processed_HRZ-30CC-filter_albedos.nc")
tsla = pd.read_csv(data_root + "Hintereisferner/Climate/snowlines/HEF-snowlines-1999-2010_manual_filtered.csv")
albobs = albobs.sortby("time")

## Snowlines within the analysis period
time_start_dt = pd.to_datetime("2000-01-01") #config starts with spinup - need to add 1year
time_end_dt = pd.to_datetime("2009-12-31")
print("Start date:", time_start_dt)
print("End date:", time_end_dt)

tsla_true_obs = tsla.copy()
tsla_true_obs['LS_DATE'] = pd.to_datetime(tsla_true_obs['LS_DATE'])
in_period = (tsla_true_obs['LS_DATE'] > time_start_dt) & (tsla_true_obs['LS_DATE'] <= time_end_dt)
tsla_true_obs = tsla_true_obs.loc[in_period].set_index('LS_DATE')

# Normalise the snowline standard deviation by the glacier elevation range
elev_range = tsla_true_obs['glacier_DEM_max'] - tsla_true_obs['glacier_DEM_min']
tsla_true_obs['SC_stdev'] = tsla_true_obs['SC_stdev'] / elev_range

# Model resolution (20 m) in the same normalised units
thres_unc = 20 / elev_range.iloc[0]
print(thres_unc)

## Raise observational uncertainty to at least model resolution; keep larger values
tsla_true_obs['SC_stdev'] = tsla_true_obs['SC_stdev'].clip(lower=thres_unc)

## Geodetic mass balance reference (Hugonnet_21_MB per-glacier rates) for this glacier and period
rgi_id = "RGI60-11.00897"
geod_ref = pd.read_csv(data_root + "Hugonnet_21_MB/dh_11_rgi60_pergla_rates.csv")
geod_ref = geod_ref.loc[(geod_ref['rgiid'] == rgi_id) & (geod_ref['period'] == "2000-01-01_2010-01-01")]
print(geod_ref)

In [None]:
print(tsla_true_obs['SC_stdev'].max())
print(albobs.sigma_albedo.max())

In [16]:
## Need to run the cirrus scripts before we can do this!

In [None]:
season_lookup = {
    12: "winter", 1: "winter", 2: "winter",
    3: "winter", 4: "winter", 5: "summer",
    6: "summer", 7: "summer", 8: "summer",
    9: "summer", 10: "winter", 11: "winter"
}

months = albobs["time"].dt.month
season_str = xr.DataArray([season_lookup[m.item()] for m in months], coords={"time": albobs["time"]}, dims="time")
albobs = albobs.assign_coords(season=season_str)
albobs

In [None]:
# Post-process the full (non-emulated) COSIPY runs forced with the posterior
# parameter samples: collect simulated albedo/MB and draw posterior-predictive samples.
if sys.platform.startswith("win"):
    mcmc_output_dir = "E:/OneDrive/PhD/PhD/Data/Hintereisferner/Output/albedo_files/MCMC/"
else:
    mcmc_output_dir = "/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Output/albedo_files/MCMC/"
nc_files = sorted(pathlib.Path(mcmc_output_dir).glob("*.nc"))

# Prepare containers (one entry per model run)
filenames = []
albfull_data = []      # full simulated albedo series
alb_data = []          # simulated albedo at observation times
ppc_albdata = []       # predictive albedo draws incl. observational + systematic error
mb_annual_means = []   # mean over years of the annual MB sums
ppc_mb_means = []      # predictive mean annual MB incl. geodetic uncertainty
mb_cumul_means = []    # cumulative daily MB

np.random.seed(77)

# Loop invariants: observation uncertainty, summer mask, geodetic MB error
sigma_obs = albobs['sigma_albedo'].values
is_summer = (albobs['season'].values == "summer")
geod_mb_err = geod_ref['err_dmdtda'].item()

for fp in nc_files:
    print(fp)
    # NOTE(review): this rebinds `ds` (earlier the ArviZ posterior) to the current
    # run's Dataset; after this cell `ds` refers to the last run's output.
    ds = xr.open_dataset(fp).sel(time=slice("2000-01-01","2010-01-01")) # last timestamp is 2009-12-31T23:00:00
    albfull_data.append(ds['ALBEDO_weighted'].data)
    fname = str(fp).split('MCMC-ensemble')[-1]
    # Run number parsed from "...num<k>.nc"; minus 3 gives this run's row in `df`
    # (offset presumably from the run numbering — confirm). `run_id` avoids shadowing builtin `id`.
    run_id = int(fname.split('.nc')[0].split('num')[-1]) - 3
    sys_err_alb = df.loc[run_id, 'sigma_alb_summer']
    filenames.append(str(fname))

    # MASS BALANCE: mean over years of the annual sums
    mb_mean = ds['MB_weighted'].groupby('time.year').sum('time').mean().item()
    daily_mb = ds['MB_weighted'].resample(time="1D").sum().data
    mb_annual_means.append(mb_mean)

    # Here we only use the observational uncertainty of the mean annual value
    y_pred_mean_annual_mb_i = np.random.normal(loc=mb_mean, scale=geod_mb_err)
    ppc_mb_means.append(y_pred_mean_annual_mb_i)

    mb_cumul_means.append(np.cumsum(daily_mb))

    # ALBEDO at the observation times
    alb_filtered = ds['ALBEDO_weighted'].sel(time=albobs.time)
    mu_i = alb_filtered.data
    alb_data.append(mu_i)

    # Systematic error applies in summer only; combine with the observational
    # error by adding variances, then taking the square root
    sigma_sys_vector_i = np.where(is_summer, sys_err_alb, 0)
    sigma_total_i = np.sqrt(sigma_obs**2 + sigma_sys_vector_i**2)

    # Draw one plausible "reality" from the predictive distribution
    y_pred_i = np.random.normal(loc=mu_i, scale=sigma_total_i)
    ppc_albdata.append(y_pred_i)

# --- Post-Processing: Convert lists to NumPy arrays ---
# This is your parametric uncertainty
model_runs_arr = np.array(alb_data)
# This is your total predictive uncertainty
simulated_predictions_arr = np.array(ppc_albdata)

print(f"Shape of model runs (mu) array: {model_runs_arr.shape}")
print(f"Shape of predictive ensemble (Y_pred) array: {simulated_predictions_arr.shape}")

# For VALIDATION (Posterior Predictive Check)
print("\nCalculating statistics for validation plot...")
# Central line is the median of the predictive ensemble
ppc_median = np.median(simulated_predictions_arr, axis=0)
# Uncertainty band is the 95% interval of the predictive ensemble
ppc_lower, ppc_upper = np.percentile(simulated_predictions_arr, [2.5, 97.5], axis=0)

mb_model_runs_mean_annual_arr = np.array(mb_annual_means)
mb_simulated_predictions_mean_annual_arr = np.array(ppc_mb_means)

mb_model_runs_cumulative_arr = np.array(mb_cumul_means)


In [None]:
## Print the observed geodetic MB (value, 1-sigma uncertainty) next to the mean and
## sample standard deviation of the posterior ensemble's annual MB values.
print(geod_ref['dmdtda'].item(), geod_ref['err_dmdtda'].item())
# ddof=1 -> sample standard deviation, consistent with the statistics table in the next cell
print(np.mean(mb_model_runs_mean_annual_arr), np.std(mb_model_runs_mean_annual_arr, ddof=1))
# Cohen's d (how large and meaningful the model bias is) is computed in the next cell.

In [None]:
## Do some stats testing on the MB distributions. Inspired by ChatGPT

# The geodetic observation is a single value with a 1-sigma uncertainty, so a
# synthetic "observed" sample is drawn from N(dmdtda, err_dmdtda) for comparison.
# A dedicated, seeded generator keeps every statistic below reproducible across
# re-runs. The sample has the same size as the posterior ensemble.
SYNTH_OBS_SEED = 42
synth_obs_rng = np.random.default_rng(SYNTH_OBS_SEED)
observed_values = synth_obs_rng.normal(
    loc=geod_ref['dmdtda'].item(),
    scale=geod_ref['err_dmdtda'].item(),
    size=len(mb_model_runs_mean_annual_arr),
)

model_mean = np.mean(mb_model_runs_mean_annual_arr)
model_std = np.std(mb_model_runs_mean_annual_arr, ddof=1)  # ddof=1 for sample std dev
model_skew = stats.skew(mb_model_runs_mean_annual_arr)
model_kurtosis = stats.kurtosis(mb_model_runs_mean_annual_arr)  # Fisher definition: excess kurtosis, 0 for a normal

# Calculate stats for the (synthetic) observed distribution
synth_obs_mean = np.mean(observed_values)
synth_obs_std = np.std(observed_values, ddof=1)
synth_obs_skew = stats.skew(observed_values)
synth_obs_kurtosis = stats.kurtosis(observed_values)

# Print the results in a formatted table
print(f"{'Metric':<20} {'Modelled':<15} {'Observed (Synthetic)':<20}")
print("-" * 55)
print(f"{'Mean':<20} {model_mean:<15.3f} {synth_obs_mean:<20.3f}")
print(f"{'Standard Deviation':<20} {model_std:<15.3f} {synth_obs_std:<20.3f}")
print(f"{'Skewness':<20} {model_skew:<15.3f} {synth_obs_skew:<20.3f}")
print(f"{'Kurtosis (Excess)':<20} {model_kurtosis:<15.3f} {synth_obs_kurtosis:<20.3f}\n")
print("Interpretation:")
print("- Skewness close to 0 indicates symmetry.")
print("- Excess Kurtosis close to 0 indicates a peak/tail profile similar to a normal distribution.\n")

def calculate_cohens_d(group1, group2):
    """Return Cohen's d, the standardized mean difference of two independent samples.

    The difference ``mean(group1) - mean(group2)`` is divided by the pooled
    standard deviation: the square root of the two sample variances (``ddof=1``)
    weighted by their degrees of freedom ``n - 1``.

    Args:
        group1: first sample (any sequence accepted by ``len`` and NumPy reductions).
        group2: second sample.

    Returns:
        Cohen's d; positive when group1 has the larger mean.
    """
    dof_1 = len(group1) - 1
    dof_2 = len(group2) - 1
    var_1 = np.std(group1, ddof=1) ** 2
    var_2 = np.std(group2, ddof=1) ** 2

    pooled_sd = np.sqrt((dof_1 * var_1 + dof_2 * var_2) / (dof_1 + dof_2))
    mean_gap = np.mean(group1) - np.mean(group2)
    return mean_gap / pooled_sd

print("--- Quantifying Bias with Cohen's d ---")
cohen_d_value = calculate_cohens_d(mb_model_runs_mean_annual_arr, observed_values)
print(f"Cohen's d: {cohen_d_value:.3f}")

# Interpretation of |d| using Cohen's conventional thresholds:
# 0.2 = small, 0.5 = medium, 0.8 = large; anything below 0.2 is negligible.
if abs(cohen_d_value) >= 0.8:
    print("Interpretation: This is a LARGE effect size, indicating a substantial and meaningful difference between the two means.\n")
elif abs(cohen_d_value) >= 0.5:
    print("Interpretation: This is a MEDIUM effect size.\n")
elif abs(cohen_d_value) >= 0.2:
    print("Interpretation: This is a SMALL effect size.\n")
else:
    print("Interpretation: This is a NEGLIGIBLE effect size (|d| < 0.2).\n")

print("--- Formal Shape Comparison (K-S Test) ---")
# Standardize both datasets to have a mean of 0 and std of 1.
# This removes the effect of location and scale, leaving only the shape.
# NOTE: the mean/std used here are estimated from the samples themselves, so the
# nominal K-S p-value below is only approximate.
modelled_standardized = (mb_model_runs_mean_annual_arr - model_mean) / model_std
observed_standardized = (observed_values - synth_obs_mean) / synth_obs_std

# Perform the two-sample Kolmogorov-Smirnov test
ks_statistic, p_value = stats.ks_2samp(modelled_standardized, observed_standardized)

print(f"K-S Statistic: {ks_statistic:.3f}")
print(f"P-value: {p_value:.3f}")

print("Interpretation:")
print("The K-S test checks if the two (standardized) samples come from the same distribution.")
if p_value < 0.05:
    print("Result: The p-value is less than 0.05. We REJECT the null hypothesis.")
    print("This suggests the underlying shapes of the distributions are statistically different.")
else:
    # Failing to reject H0 is absence of evidence for a difference, not proof of equality.
    print("Result: The p-value is 0.05 or greater. We FAIL to reject the null hypothesis.")
    print("No statistically significant difference in shape was detected. This is consistent with the visual observation that 'the shapes match', but it does not prove the shapes are identical.")


In [None]:
## Prepare data into a dataframe: one row per run, pairing its file name with its annual MB value
mb_df = pd.DataFrame(
    {"filename": filenames, "value": mb_annual_means}
)
mb_df["value"].hist()


In [None]:
## Wrap the per-run cumulative MB curves in a labelled (simulation, time) DataArray
# Daily time axis of the simulations, 2000-01-01 through 2009-12-31 inclusive.
time_rng = pd.date_range("2000-01-01", "2009-12-31", freq='D')

# Derive the ensemble size from the data rather than hard-coding 300 runs, and
# fail with a clear message if the daily axis does not match the simulation period.
n_simulations, n_days = mb_model_runs_cumulative_arr.shape
if n_days != len(time_rng):
    raise ValueError(
        f"Cumulative MB runs have {n_days} daily values but time_rng has {len(time_rng)} dates"
    )

cumulative_mb_ensemble = xr.DataArray(
    data=mb_model_runs_cumulative_arr,
    dims=("simulation", "time"),
    coords={
        "simulation": np.arange(n_simulations),
        "time": time_rng,
    },
    name="cumulative_mass_balance",
)

# Ensemble median and 95% interval across runs, per day
cummb_median = cumulative_mb_ensemble.median(dim='simulation')
cummb_cilower = cumulative_mb_ensemble.quantile(0.025, dim="simulation")
cummb_ciupper = cumulative_mb_ensemble.quantile(0.975, dim="simulation")

cumulative_mb_ensemble

In [50]:
# Tag every snowline observation with its calendar month and a two-season label:
# months 5..9 (May-September, inclusive) -> "summer", every other month -> "winter".
tsla_true_obs['month'] = tsla_true_obs.index.month
tsla_true_obs['season'] = np.where(tsla_true_obs['month'].between(5, 9), 'summer', 'winter')

In [None]:
## Load snowlines: one CSV of modelled snowlines per posterior run
# sys.platform is 'win32' on Windows; a substring test `'win' in sys.platform`
# would also match 'darwin' (macOS), so use startswith.
if sys.platform.startswith('win'):
    snowline_dir = pathlib.Path("E:/OneDrive/PhD/PhD/Data/Hintereisferner/Output/snowlines/bestfiles")
else:
    snowline_dir = pathlib.Path("/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Output/snowlines/bestfiles")
csv_files = sorted(snowline_dir.glob("*.csv"))  # sort to keep order consistent

# Observation-side error terms are identical for every run, so build them once.
sigma_obs_tsl = tsla_true_obs['SC_stdev'].values
is_summer_tsl = (tsla_true_obs['season'].values == "summer")

# Prepare containers
snowline_model_runs = []
snowline_simulated_predictions = []
filenames_list = []

for fp in csv_files:
    tslsub = pd.read_csv(fp, parse_dates=True, index_col='time')  # time column becomes the index
    # Run number is the integer after 'num' in the file stem; minus 3 gives the row
    # label in the parameter table `df`.
    # NOTE(review): the offset of 3 presumably reflects how runs were numbered — confirm.
    run_idx = int(fp.stem.split('num')[-1]) - 3
    sys_err_offset = df.loc[run_idx, 'sigma_tsl_summer']

    # Align the model output to the observation dates *in observation order*, so that
    # element k of every array below refers to the same date as row k of tsla_true_obs.
    # (An isin() filter keeps the file's own row order and silently returns a shorter
    # array when a date is missing.)
    missing_dates = tsla_true_obs.index.difference(tslsub.index)
    if len(missing_dates) > 0:
        raise ValueError(
            f"{fp.name}: no modelled snowline for {len(missing_dates)} observation date(s), "
            f"e.g. {missing_dates[0]}"
        )
    tslsub = tslsub.reindex(tsla_true_obs.index)

    # Store the modelled snowline (parametric uncertainty only)
    mu_tsl_i = tslsub['Med_TSL'].values
    snowline_model_runs.append(mu_tsl_i)

    # Summer-only systematic error of this run, combined with the observation
    # error by adding variances, then taking the square root
    sigma_sys_tsl_vector_i = np.where(is_summer_tsl, sys_err_offset, 0)
    sigma_total_tsl_i = np.sqrt(sigma_obs_tsl**2 + sigma_sys_tsl_vector_i**2)

    # Draw one plausible "reality" from the predictive distribution
    y_pred_tsl_i = np.random.normal(loc=mu_tsl_i, scale=sigma_total_tsl_i)
    snowline_simulated_predictions.append(y_pred_tsl_i)
    filenames_list.append(fp.name)

# This is your parametric uncertainty
snowline_model_runs_arr = np.array(snowline_model_runs)
# This is your total predictive uncertainty
snowline_simulated_predictions_arr = np.array(snowline_simulated_predictions)

print(f"Shape of snowline model runs (mu) array: {snowline_model_runs_arr.shape}")
print(f"Shape of snowline predictive ensemble (Y_pred) array: {snowline_simulated_predictions_arr.shape}")

# For VALIDATION (Posterior Predictive Check)
print("\nCalculating statistics for snowline validation plot...")
# Central line is the median of the predictive ensemble
ppc_tsl_median = np.median(snowline_simulated_predictions_arr, axis=0)
# Uncertainty band is the 95% interval of the predictive ensemble
ppc_tsl_lower, ppc_tsl_upper = np.percentile(snowline_simulated_predictions_arr, [2.5, 97.5], axis=0)

In [None]:
## Load WGMS (FoG 2022-09) mass-balance records for the cumulative MB comparison
# sys.platform is 'win32' on Windows; `'win' in sys.platform` would also match 'darwin'.
if sys.platform.startswith('win'):
    wgms_path = "E:/OneDrive/PhD/PhD/Data/DOI-WGMS-FoG-2022-09/data/"
else:
    wgms_path = "/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/DOI-WGMS-FoG-2022-09/data/"

wgms = pd.read_csv(wgms_path+"mass_balance.csv")
# Hintereisferner, balance years 2002-2009 inclusive
wgms = wgms.loc[(wgms['NAME'] == "HINTEREIS F.") & (wgms['YEAR'] > 2001) & (wgms['YEAR'] <= 2009)]
if wgms.empty:
    raise ValueError("No 'HINTEREIS F.' rows for 2002-2009 found in mass_balance.csv")
print(wgms.NAME.iloc[0], np.unique(wgms.WGMS_ID))
# Non-inplace drop: produces a fresh frame, so the cell stays re-runnable
wgms = wgms.drop(columns=['POLITICAL_UNIT', 'NAME', 'REMARKS'])
# LOWER_BOUND == 9999 selects the glacier-wide rows (WGMS FoG convention — verify against FoG docs)
wgms = wgms.loc[wgms['LOWER_BOUND'] == 9999]
wgms

In [None]:
## Prepare data: WGMS glaciological and Klug et al. (2018) geodetic series for the cumulative MB panel
# Chronological order is required for the cumulative sums and for the positional
# assignment of the Klug et al. values below.
wgms = wgms.sort_values('YEAR')

# WGMS hydrological year ends on 30 September of YEAR.
# (Jan 1 + DateOffset(months=8, days=30) would land on 1 October, one day late.)
wgms['hydro_date'] = pd.to_datetime(
    pd.DataFrame({'year': wgms['YEAR'], 'month': 9, 'day': 30})
)

# Balances are converted from mm w.e. to m w.e. (/1000).
wgms['CUM_BALANCE'] = wgms['ANNUAL_BALANCE'].cumsum() / 1000
# Keep the raw uncertainty in its own column so re-running this cell does not
# divide ANNUAL_BALANCE_UNC by 1000 a second time.
if 'ANNUAL_BALANCE_UNC_MM' not in wgms.columns:
    wgms['ANNUAL_BALANCE_UNC_MM'] = wgms['ANNUAL_BALANCE_UNC']
wgms['ANNUAL_BALANCE_UNC'] = wgms['ANNUAL_BALANCE_UNC_MM'] / 1000

# Klug et al. (2018) annual geodetic MB and uncertainty, one value per WGMS row in
# year order (presumably m w.e., matching the converted WGMS values — confirm).
klug_etal_geod = np.array([-0.685,-2.713,-0.654,-1.028,-2.091,-1.363,-1.252,-1.209])
klug_etal_unc = np.array([0.062, 0.183, 0.063, 0.056, 0.1, 0.041, 0.046, 0.06])
if len(klug_etal_geod) != len(wgms):
    raise ValueError(
        f"Expected {len(klug_etal_geod)} WGMS years to pair with Klug et al. values, got {len(wgms)}"
    )
wgms['klug_mb'] = klug_etal_geod
wgms['klug_unc'] = klug_etal_unc
wgms['CUM_KLUG'] = wgms['klug_mb'].cumsum()
wgms


In [None]:
# Legacy version of the ensemble-evaluation figure, kept for reference only.
# The code below is wrapped in a triple-quoted string, so it is never executed;
# the cell merely evaluates the string, which Jupyter then displays as output.
# NOTE(review): it references names that are not defined in this part of the
# notebook (e.g. mb_vals, median_alb, alb_err, median_tsla, tsla_err,
# alb_cilower/alb_ciupper, tsla_cilower/tsla_ciupper — the alb_ci* accumulation
# is commented out in the ensemble loop), so it would not run as-is.
"""
fig, axes = plt.subplots(2, 3, figsize=(18, 10), dpi=300)

# 1. Histogram of Mass Balance
#axes[0, 0].hist(mb_vals, bins=20, color='skyblue', edgecolor='black')
sns.kdeplot(mb_vals, fill=True, color="skyblue", ax=axes[0,0])
#
mu_obs = geod_ref['dmdtda'].item()  # mean observed MB in m w.e.
sigma_obs = geod_ref['err_dmdtda'].item()

# Plot PDF
x = np.linspace(mu_obs - 4*sigma_obs, mu_obs + 4*sigma_obs, 200)
pdf = norm.pdf(x, loc=mu_obs, scale=sigma_obs)
axes[0, 0].plot(x, pdf, linestyle="dashed", color="black")

#axes[0, 0].axvline(x=geod_ref['dmdtda'].item(), color="black")
#axes[0, 0].axvline(x=(geod_ref['dmdtda'] - geod_ref['err_dmdtda']).item(), linestyle="dashed", color="black")
#axes[0, 0].axvline(x=(geod_ref['dmdtda'] + geod_ref['err_dmdtda']).item(), linestyle="dashed", color="black")
axes[0, 0].set_xlabel('MB (m w.e. a$^{-1}$)')
axes[0, 0].set_ylabel('Frequency')

# 2. Albedo scatter
axes[0, 1].errorbar(median_alb, albobs.median_albedo.values,
                    xerr=alb_err, yerr=albobs.sigma_albedo.values,
                    fmt='o', alpha=0.6, capsize=3, color='orange')

# Metrics
r2 = r2_score(albobs.median_albedo.values, median_alb)
rmse = root_mean_squared_error(albobs.median_albedo.values, median_alb)
mae = mean_absolute_error(albobs.median_albedo.values, median_alb)

axes[0, 1].text(0.05, 0.95, f"R² = {r2:.2f}\nRMSE = {rmse:.2f}\nMAE = {mae:.2f}",
                transform=axes[0, 1].transAxes,
                verticalalignment='top',
                fontsize=16,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

axes[0, 1].axline((0, 0), slope=1, linestyle='--', color='gray')
axes[0, 1].set_xlabel('Mod. Albedo (-)')
axes[0, 1].set_ylabel('Obs. Albedo (-)')
axes[0, 1].set_xlim(0, 1)
axes[0, 1].set_ylim(0, 1)
axes[0, 1].grid()

# 3. Snowline scatter
axes[0, 2].errorbar(median_tsla, tsla_true_obs['TSL_normalized'].values,
                    xerr=tsla_err, yerr=tsla_true_obs['SC_stdev'].values,
                    fmt='o', alpha=0.6, capsize=3, color='green')

# Plot the highlighted points
#axes[0, 2].errorbar(highlighted_model_tsla, highlighted_obs_tsla,
#                    xerr=highlighted_model_sigma, yerr=highlighted_obs_sigma,
#                    fmt='o', alpha=0.6, capsize=3, color='green')
                

r2 = r2_score(tsla_true_obs['TSL_normalized'].values, median_tsla)
rmse = root_mean_squared_error(tsla_true_obs['TSL_normalized'].values, median_tsla)
mae = mean_absolute_error(tsla_true_obs['TSL_normalized'].values, median_tsla)

axes[0, 2].text(0.05, 0.95, f"R² = {r2:.2f}\nRMSE = {rmse:.2f}\nMAE = {mae:.2f}",
                transform=axes[0, 2].transAxes,
                verticalalignment='top',
                fontsize=16,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

axes[0, 2].axline((0, 0), slope=1, linestyle='--', color='gray')
axes[0, 2].set_xlabel('Mod. TSLA (-)')
axes[0, 2].set_ylabel('Obs. TSLA (-)')
axes[0, 2].set_xlim(0, 1)
axes[0, 2].set_ylim(0, 1)
axes[0, 2].grid()

# 4. Cumulative MB plot (find MB observations) #taken it times 0.77 works wonders
axes[1, 0].plot(cummb_median.time, cummb_median.values, label='Modelled', color='blue')
axes[1, 0].fill_between(cummb_median.time,
                 cummb_cilower.values,
                 cummb_ciupper.values,
                 color='blue',
                 alpha=0.3,
                 label='95% CI')
axes[1, 0].errorbar(wgms['hydro_date'], wgms['CUM_BALANCE'], yerr=wgms['ANNUAL_BALANCE_UNC'], fmt='o', color='black', label='WGMS')

# Add geod. MB (roughly)
geod_mb_rate = geod_ref['dmdtda'].item()
geod_uncertainty = geod_ref['err_dmdtda'].item()  # ±m w.e./yr
geod_start = pd.Timestamp("2000-01-01")
geod_end = pd.Timestamp("2010-01-01")
# Compute duration in years
years = (geod_end - geod_start).days / 365.25
geod_cum = geod_mb_rate * years
geod_cum_uncert = geod_uncertainty * years

# Plot as a line between start and end
axes[1, 0].plot([geod_start, geod_end], [0, geod_cum], color='black', linestyle='--', label='Obs.')

# Optional: Add uncertainty ribbon
axes[1, 0].fill_between([geod_start, geod_end],
                [0 - geod_cum_uncert, geod_cum - geod_cum_uncert],
                [0 + geod_cum_uncert, geod_cum + geod_cum_uncert],
                color='gray', alpha=0.3)


axes[1, 0].set_xlabel('Time')
axes[1, 0].set_ylabel('Cum. MB (m.w.e.)')
axes[1, 0].legend()

axes[1, 0].set_xlim(pd.to_datetime("2000-01-01"),pd.to_datetime("2010-01-01"))
axes[1, 0].xaxis.set_major_locator(mdates.MonthLocator(bymonth=(1)))
axes[1, 0].xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f"{mdates.num2date(x).year}"))
axes[1, 0].xaxis.set_minor_locator(mdates.MonthLocator(bymonth=(7)))
axes[1, 0].tick_params(axis='x', rotation=30)
axes[1, 0].grid()


# 5. Albedo time series
axes[1, 1].errorbar(albobs.time.values, albobs.median_albedo.values,
                    yerr=albobs.sigma_albedo.values,
                    fmt='o', label='Observed', alpha=0.6)
axes[1, 1].plot(albobs.time.values, median_alb, label='Modelled', color='orange')
axes[1, 1].fill_between(albobs.time.values, alb_cilower, alb_ciupper, color='orange', alpha=0.4, label='Confidence Interval')
axes[1, 1].set_xlabel('Time')
axes[1, 1].set_ylabel('Albedo (-)')
#axes[1, 1].legend()
axes[1, 1].set_ylim(0, 1)

axes[1, 1].set_xlim(pd.to_datetime("2000-01-01"),pd.to_datetime("2010-01-01"))
axes[1, 1].xaxis.set_major_locator(mdates.MonthLocator(bymonth=(1)))
axes[1, 1].xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f"{mdates.num2date(x).year}"))
axes[1, 1].xaxis.set_minor_locator(mdates.MonthLocator(bymonth=(7)))
axes[1, 1].tick_params(axis='x', rotation=30)
axes[1, 1].grid()

# 6. Snowline time series
axes[1, 2].errorbar(tsla_true_obs.index, tsla_true_obs['TSL_normalized'],
                    yerr=tsla_true_obs['SC_stdev'],
                    fmt='o', color="red", label="Observed", alpha=0.6)
#axes[1, 2].errorbar(filtered_tsl.index, filtered_tsl['TSL_normalized'],
#                    yerr=filtered_tsl['SC_stdev'],
#                    fmt='o', label='Observed')
axes[1, 2].plot(tsla_true_obs.index, median_tsla, label='Modelled', color='green')
axes[1, 2].fill_between(tsla_true_obs.index, tsla_cilower, tsla_ciupper, color='green', alpha=0.4, label='Confidence Interval')
axes[1, 2].set_xlabel('Time')
axes[1, 2].set_ylabel('TSLA (-)')
#axes[1, 2].legend()
axes[1, 2].set_ylim(0, 1)

axes[1, 2].set_xlim(pd.to_datetime("2000-01-01"),pd.to_datetime("2010-01-01"))
axes[1, 2].xaxis.set_major_locator(mdates.MonthLocator(bymonth=(1)))
axes[1, 2].xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f"{mdates.num2date(x).year}"))
axes[1, 2].xaxis.set_minor_locator(mdates.MonthLocator(bymonth=(7)))
axes[1, 2].tick_params(axis='x', rotation=30)
axes[1, 2].grid()


fig.tight_layout()
if 'win' in sys.platform:
    plt.savefig("E:/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/mcmc_ensemble_eval.png", bbox_inches="tight")
else:
    plt.savefig("/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/mcmc_ensemble_eval.png", bbox_inches="tight")
"""

In [None]:
# Overall maximum of the observational albedo uncertainty (sigma_albedo)
albobs.sigma_albedo.max()

In [None]:
## Six-panel posterior predictive evaluation figure:
##   a) mean annual MB density, b) albedo scatter, c) snowline scatter,
##   d) cumulative MB, e) albedo time series, f) snowline time series.
plt.rcParams.update({'font.size': 22})
plt.rcParams['axes.axisbelow'] = True

# Median of the albedo predictive ensemble across runs
ppc_median = np.median(simulated_predictions_arr, axis=0)

fig, axes = plt.subplots(2, 3, figsize=(20, 12), dpi=300)

# a) PPC for Mean Annual Mass Balance
# We compare the distribution of simulated mean values to the observed mean,
# drawn as the normal PDF N(dmdtda, err_dmdtda).
sns.kdeplot(mb_simulated_predictions_mean_annual_arr, fill=True, color="#D81B1B", ax=axes[0,0], label="Posterior Predictive")
# .item() -> plain scalars (as elsewhere in the notebook), so x_obs / pdf_obs stay 1-D
mu_obs = geod_ref['dmdtda'].item()
sigma_obs = geod_ref['err_dmdtda'].item()
x_obs = np.linspace(mu_obs - 4*sigma_obs, mu_obs + 4*sigma_obs, 200)
pdf_obs = norm.pdf(x_obs, loc=mu_obs, scale=sigma_obs)
axes[0, 0].plot(x_obs, pdf_obs, color="black", linestyle="--", label="Observed")
axes[0, 0].set_xlabel(r'$B_{geod}$'+ ' (m w.e. a$^{-1}$)')
axes[0, 0].set_ylabel('Density')
axes[0, 0].legend()


# b) Albedo scatter of PPCs: x = predictive median with its asymmetric 95% interval,
# y = observed median albedo with its 1-sigma error
x_err_alb = [ppc_median - ppc_lower, ppc_upper - ppc_median]
axes[0, 1].errorbar(ppc_median, albobs['median_albedo'].values,
                    xerr=x_err_alb, yerr=albobs['sigma_albedo'].values,
                    fmt='o', alpha=0.6, capsize=3, color='#1E80E5', ecolor='gray')

# Metrics calculated against the median of the predictive distribution.
r2 = r2_score(albobs['median_albedo'].values, ppc_median)
rmse = root_mean_squared_error(albobs['median_albedo'].values, ppc_median)
mae = mean_absolute_error(albobs['median_albedo'].values, ppc_median)

axes[0, 1].text(0.05, 0.95, f"R²={r2:.2f}\nRMSE={rmse:.2f}\nMAE={mae:.2f}",
                transform=axes[0, 1].transAxes, verticalalignment='top', fontsize=16,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

axes[0, 1].axline((0, 0), slope=1, linestyle='--', color='black')
axes[0, 1].set_xlabel(r'Modelled $\bar{\alpha}$ (-)')
axes[0, 1].set_ylabel(r'Observed $\bar{\alpha}$ (-)')
axes[0, 1].set_xlim(0, 1)
axes[0, 1].set_ylim(0, 1)
axes[0, 1].set_xticks(np.arange(0,1+0.2,0.2))
axes[0, 1].set_yticks(np.arange(0,1+0.2,0.2))
axes[0, 1].grid(True, zorder=-1)


# c) Snowline scatter of PPCs
x_err_tsl = [ppc_tsl_median - ppc_tsl_lower, ppc_tsl_upper - ppc_tsl_median]
axes[0, 2].errorbar(ppc_tsl_median, tsla_true_obs['TSL_normalized'].values,
                    xerr=x_err_tsl, yerr=tsla_true_obs['SC_stdev'].values,
                    fmt='o', alpha=0.6, capsize=3, color='#A5781B', ecolor='gray')

# Metrics calculated against the median of the predictive distribution.
r2 = r2_score(tsla_true_obs['TSL_normalized'].values, ppc_tsl_median)
rmse = root_mean_squared_error(tsla_true_obs['TSL_normalized'].values, ppc_tsl_median)
mae = mean_absolute_error(tsla_true_obs['TSL_normalized'].values, ppc_tsl_median)

axes[0, 2].text(0.05, 0.95, f"R²={r2:.2f}\nRMSE={rmse:.2f}\nMAE={mae:.2f}",
                transform=axes[0, 2].transAxes, verticalalignment='top', fontsize=16,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

axes[0, 2].set_xlim(-0.2, 1)
axes[0, 2].set_ylim(-0.2, 1)
axes[0, 2].axline((-0.2, -0.2), slope=1, linestyle='--', color='black')
axes[0, 2].set_xticks(np.arange(-0.2,1+0.2,0.2))
axes[0, 2].set_yticks(np.arange(-0.2,1+0.2,0.2))
axes[0, 2].set_xlabel('Modelled Norm. SLA (-)')
axes[0, 2].set_ylabel('Observed Norm. SLA (-)')
axes[0, 2].grid(True, zorder=-1)


# d) Cumulative MB: ensemble median + 95% band vs. WGMS, Klug et al. and geodetic
cummb_median = np.median(mb_model_runs_cumulative_arr, axis=0)
cummb_lower, cummb_upper = np.percentile(mb_model_runs_cumulative_arr, [2.5, 97.5], axis=0)
axes[1, 0].errorbar(wgms['hydro_date'], wgms['CUM_BALANCE'], yerr=wgms['ANNUAL_BALANCE_UNC'], fmt='o', color='darkgreen', label='WGMS')
axes[1, 0].errorbar(wgms['hydro_date'], wgms['CUM_KLUG'], yerr=wgms['klug_unc'], fmt='o', color='steelblue', label='Klug et al., 2018')
axes[1, 0].plot(time_rng, cummb_median, label='Posterior Ens. Median', color='#D81B1B')
axes[1, 0].fill_between(time_rng, cummb_lower, cummb_upper, color='#D81B1B',
                 alpha=0.3, label='95% CI')

# Geodetic cumulative MB: a straight line from 0 at geod_start to rate * duration at geod_end
geod_mb_rate = geod_ref['dmdtda'].item()
geod_uncertainty = geod_ref['err_dmdtda'].item()  # ±m w.e./yr
geod_start = pd.Timestamp("2000-01-01")
geod_end = pd.Timestamp("2010-01-01")
# Compute duration in years
years = (geod_end - geod_start).days / 365.25
geod_cum = geod_mb_rate * years
geod_cum_uncert = geod_uncertainty * years

# Plot as a line between start and end
axes[1, 0].plot([geod_start, geod_end], [0, geod_cum], color='black', linestyle='--', label='Obs.')

# Cumulative change is zero by definition at geod_start, so the rate uncertainty
# fans out from 0 to ±geod_cum_uncert at geod_end (a cone), not a constant-width band.
axes[1, 0].fill_between([geod_start, geod_end],
                [0, geod_cum - geod_cum_uncert],
                [0, geod_cum + geod_cum_uncert],
                color='gray', alpha=0.3)
axes[1, 0].set_xlabel('Time')
axes[1, 0].set_ylabel('Cumulative MB (m w.e.)')
axes[1, 0].legend(prop = { "size": 18 })
axes[1, 0].grid(True, zorder=-1)


# e) Albedo time series
axes[1, 1].errorbar(pd.to_datetime(albobs.time), albobs['median_albedo'].values,
                    yerr=albobs['sigma_albedo'].values,
                    fmt='o', ms=5, label='Observed', alpha=0.6, color='black', ecolor='gray')
# Plot the full predictive distribution
axes[1, 1].plot(pd.to_datetime(albobs.time), ppc_median, label='Modelled Median', color='#1E80E5', marker='o', ms=4)
axes[1, 1].fill_between(pd.to_datetime(albobs.time), ppc_lower, ppc_upper, color='#1E80E5',
                 alpha=0.4, label='95% Prediction Interval')
axes[1, 1].set_xlabel('Time')
axes[1, 1].set_ylabel(r'$\bar{\alpha}$ (-)')
axes[1, 1].set_ylim(0, 1)
axes[1, 1].grid(True, zorder=-1)


# f) Snowline time series
axes[1, 2].errorbar(tsla_true_obs.index, tsla_true_obs['TSL_normalized'].values,
                    yerr=tsla_true_obs['SC_stdev'].values,
                    fmt='o', ms=5, color="black", label="Observed", alpha=0.6, ecolor='gray')
# Plot the full predictive distribution (lower bound clipped to the axis floor)
axes[1, 2].plot(tsla_true_obs.index, ppc_tsl_median, label='Modelled Median', color='#A5781B', marker='o', ms=4)
axes[1, 2].fill_between(tsla_true_obs.index, np.maximum(-0.2, ppc_tsl_lower), ppc_tsl_upper, color='#A5781B',
                 alpha=0.4, label='95% Prediction Interval')
axes[1, 2].set_xlabel('Time')
axes[1, 2].set_ylim(-0.2, 1.0)
axes[1, 2].set_yticks(np.arange(-0.2, 1+0.2, 0.2))
axes[1, 2].set_ylabel('Norm. SLA (-)')
axes[1, 2].grid(True, zorder=-1)

# Shared date-axis formatting for the three time-series panels
for ax in [axes[1, 0], axes[1, 1], axes[1, 2]]:
    ax.set_xlim(pd.to_datetime("2000-01-01"), pd.to_datetime("2010-01-01"))
    ax.xaxis.set_major_locator(mdates.YearLocator(2))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.tick_params(axis='x', rotation=30)

# Panel labels (figure-fraction coordinates)
fig.text(0.01, 0.95, 'a)', transform=fig.transFigure, fontsize=24)
fig.text(0.34, 0.95, 'b)', transform=fig.transFigure, fontsize=24)
fig.text(0.66, 0.95, 'c)', transform=fig.transFigure, fontsize=24)
#
fig.text(0.01, 0.5, 'd)', transform=fig.transFigure, fontsize=24)
fig.text(0.34, 0.5, 'e)', transform=fig.transFigure, fontsize=24)
fig.text(0.66, 0.5, 'f)', transform=fig.transFigure, fontsize=24)

fig.tight_layout()

# sys.platform is 'win32' on Windows; `'win' in sys.platform` would also match 'darwin' (macOS).
if sys.platform.startswith('win'):
    fig.savefig("E:/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/mcmc_ensemble_eval.pdf", bbox_inches="tight")
else:
    fig.savefig("/mnt/C4AEBBABAEBB9500/OneDrive/PhD/PhD/Data/Hintereisferner/Figures/mcmc_ensemble_eval.pdf", bbox_inches="tight")


In [None]:
## Interpolated lines for models - need to show full time series