# Set up

In [None]:
import torch
import torch.nn as nn
import numpy as np
import xarray as xr
from torch.utils.data import TensorDataset, DataLoader
from train import apply_inverse_zca_whitening_4d_torch
from unet import UNet  
from utils import *

# Reproducibility and compute device (both helpers come from the star-import
# of `utils`).
setup_random_seeds(42)
device = get_device()

# Location of the zarr stores and the options handed to `open_zarr`.
base_path = "gs://leap-persistent/YueWang/SSH/data"
storage_opts = dict(token="cloud", asynchronous=False)

train = open_zarr(f"{base_path}/train_80_sst.zarr", storage_opts)
test = open_zarr(f"{base_path}/test_80_sst.zarr", storage_opts)
zca = open_zarr(f"{base_path}/zca_80.zarr", storage_opts)


def _as_device_tensor(array):
    """Convert a NumPy array to a float32 tensor living on `device`."""
    return torch.from_numpy(array).float().to(device)


# ZCA parameters later passed to `apply_inverse_zca_whitening_4d_torch`.
Vt = _as_device_tensor(zca.ubm_Vt.values)
scale = _as_device_tensor(zca.ubm_scale.values)
mean = _as_device_tensor(zca.ubm_mean.values)

# Model 1: ZCA NLL Loss with SSH+SST input


def _evaluate_model(model, loader, ssh_original, sst_original=None, n_ensemble=None):
    """Run ``model`` over ``loader`` and gather per-sample fields as NumPy arrays.

    Shared implementation of the evaluation loop used by all three models in
    this cell. It reads the module-level ``device``, ``Vt``, ``scale`` and
    ``mean`` and calls ``apply_inverse_zca_whitening_4d_torch`` /
    ``generate_gaussian_samples`` exactly as the per-model loops did.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate. Output channel 0 is read as the ZCA-space mean;
        channel 1 is read as the ZCA-space log-sigma and is only used when
        ``n_ensemble`` is given.
    loader : DataLoader
        Must iterate its dataset in order (``shuffle=False``): batches are
        matched to ``ssh_original`` / ``sst_original`` by running offset.
        Channel 0 of each target batch is used as the true UBM.
    ssh_original : torch.Tensor
        Un-normalized SSH, indexed along dim 0 in dataset order.
    sst_original : torch.Tensor, optional
        Un-normalized SST; when given it is stored under ``'sst'``.
    n_ensemble : int, optional
        Number of samples to draw per input. ``None`` skips sampling and the
        two ``*_pred_ensembles`` entries are omitted.

    Returns
    -------
    (dict, list)
        Dict of arrays concatenated along axis 0, and the list of dataset
        indices that were visited.
    """
    has_sst = sst_original is not None
    has_ensembles = n_ensemble is not None

    # Key order mirrors the result dictionaries this helper replaces.
    keys = ['ssh']
    if has_sst:
        keys.append('sst')
    keys += ['ubm_true', 'bm_true', 'ubm_pred_mu', 'bm_pred_mu']
    if has_ensembles:
        keys += ['ubm_pred_ensembles', 'bm_pred_ensembles']
    collected = {key: [] for key in keys}
    visited = []

    model.eval()
    start = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            # Offsets come from the real batch length, so a short final batch
            # needs no special handling.
            stop = start + batch_x.shape[0]
            visited.extend(range(start, stop))

            batch_x = batch_x.to(device)
            ubm_true = batch_y[:, 0:1, ...].to(device)

            outputs = model(batch_x)

            # BM is always formed from SSH in its original (un-normalized) scale.
            ssh_batch = ssh_original[start:stop]
            bm_true = ssh_batch - ubm_true

            # Predicted mean, mapped back through the inverse ZCA whitening.
            mu_zca = outputs[:, 0, ...]
            ubm_pred_mu = apply_inverse_zca_whitening_4d_torch(
                mu_zca.unsqueeze(1), Vt, scale, mean)
            bm_pred_mu = ssh_batch - ubm_pred_mu

            batch_fields = {
                'ssh': ssh_batch,
                'ubm_true': ubm_true,
                'bm_true': bm_true,
                'ubm_pred_mu': ubm_pred_mu,
                'bm_pred_mu': bm_pred_mu,
            }
            if has_sst:
                batch_fields['sst'] = sst_original[start:stop]

            if has_ensembles:
                log_sigma_zca = outputs[:, 1, ...]
                zca_samples = generate_gaussian_samples(
                    mu_zca, log_sigma_zca, n_samples=n_ensemble)
                B, n_samples, H, W = zca_samples.shape
                # Fold the sample axis into the batch axis for the 4-D inverse
                # transform, then unfold it again.
                zca_samples_flat = zca_samples.reshape(B * n_samples, 1, H, W)
                ubm_samples_flat = apply_inverse_zca_whitening_4d_torch(
                    zca_samples_flat, Vt, scale, mean)
                ubm_samples = ubm_samples_flat.reshape(B, n_samples, 1, H, W)

                ssh_expanded = ssh_batch.unsqueeze(1).expand(-1, n_samples, -1, -1, -1)
                batch_fields['ubm_pred_ensembles'] = ubm_samples
                batch_fields['bm_pred_ensembles'] = ssh_expanded - ubm_samples

            for key in keys:
                collected[key].append(batch_fields[key].cpu().numpy())

            start = stop

    results = {key: np.concatenate(chunks, axis=0) for key, chunks in collected.items()}
    return results, visited


# Prepare training data for normalization statistics (SSH+SST)
x_train_ssh = torch.from_numpy(train.ssh.values).float().unsqueeze(1).to(device)
x_train_sst = torch.from_numpy(train.sst.values).float().unsqueeze(1).to(device)
x_train = torch.cat([x_train_ssh, x_train_sst], dim=1)
x_train_normalized, min_vals_sst, max_vals_sst = min_max_normalize(x_train)

# Prepare test data (SSH+SST)
x_test_ssh_original = torch.from_numpy(test.ssh.values).float().unsqueeze(1).to(device)
x_test_sst_original = torch.from_numpy(test.sst.values).float().unsqueeze(1).to(device)
x_test_original = torch.cat([x_test_ssh_original, x_test_sst_original], dim=1)

# Normalize test data for model input, reusing the training-set statistics
x_test_normalized, _, _ = min_max_normalize(x_test_original, min_vals_sst, max_vals_sst)

# Prepare test targets: channel 0 = physical UBM, channel 1 = ZCA-whitened UBM
y_test_physical = torch.from_numpy(test.ubm.values).float().unsqueeze(1).to(device)
y_test_zca = torch.from_numpy(test.zca_ubm.values).float().unsqueeze(1).to(device)
y_test = torch.cat([y_test_physical, y_test_zca], dim=1)

# Create test dataset and loader (shuffle=False: _evaluate_model relies on order)
test_dataset_sst = TensorDataset(x_test_normalized, y_test)
test_loader_sst = DataLoader(test_dataset_sst, batch_size=32, shuffle=False)

# Load model
model_sst_ssh = UNet(in_channels=2, out_channels=2, initial_features=32, depth=4)
model_sst_ssh.to(device)

checkpoint = torch.load('/home/jovyan/GRL_ssh/checkpoints/sst_ssh.pth', map_location=device)
model_sst_ssh.load_state_dict(checkpoint['model_state_dict'])

# Evaluate model (mean prediction + 30-member ensemble)
results_sst_ssh, sample_indices = _evaluate_model(
    model_sst_ssh, test_loader_sst, x_test_ssh_original,
    sst_original=x_test_sst_original, n_ensemble=30)

print("Model 1 evaluation complete!")

# Model 2: SSH input only

# Prepare training data for normalization statistics (SSH only)
x_train_ssh_only = torch.from_numpy(train.ssh.values).float().unsqueeze(1).to(device)
x_train_normalized_ssh, min_vals_ssh, max_vals_ssh = min_max_normalize(x_train_ssh_only)

# Prepare test data (SSH only)
x_test_ssh_only_original = torch.from_numpy(test.ssh.values).float().unsqueeze(1).to(device)
x_test_normalized_ssh, _, _ = min_max_normalize(x_test_ssh_only_original, min_vals_ssh, max_vals_ssh)

# Create test dataset and loader
test_dataset_ssh = TensorDataset(x_test_normalized_ssh, y_test)
test_loader_ssh = DataLoader(test_dataset_ssh, batch_size=32, shuffle=False)

# Load model
model_ssh_only = UNet(in_channels=1, out_channels=2, initial_features=32, depth=4)
model_ssh_only.to(device)

checkpoint = torch.load('/home/jovyan/GRL_ssh/checkpoints/ssh_input_only.pth', map_location=device)
model_ssh_only.load_state_dict(checkpoint['model_state_dict'])

# Evaluate model (no SST field to store)
results_ssh_only, sample_indices_ssh = _evaluate_model(
    model_ssh_only, test_loader_ssh, x_test_ssh_only_original, n_ensemble=30)

print("Model 2 evaluation complete!")

# Model 3: MSE loss only
model_mse_only = UNet(in_channels=2, out_channels=2, initial_features=32, depth=4)
model_mse_only.to(device)

checkpoint = torch.load('/home/jovyan/GRL_ssh/checkpoints/mse_loss_only.pth', map_location=device)
model_mse_only.load_state_dict(checkpoint['model_state_dict'])

# Evaluate model on the SSH+SST loader; only the first output channel (the
# mean) is used, so no ensemble is generated.
results_mse_only, sample_indices_mse = _evaluate_model(
    model_mse_only, test_loader_sst, x_test_ssh_original,
    sst_original=x_test_sst_original)

print("Model 3 evaluation complete!")

# Build one evaluation dataset per model.
# Each entry is (model name, results dict, has ensembles?, has SST?).
models_results = [
    ('sst_ssh', results_sst_ssh, True, True),
    ('ssh_only', results_ssh_only, True, False),
    ('mse_only', results_mse_only, False, True),
]

# model name -> dataset returned by create_evaluation_dataset
eval_datasets = {}

for entry in models_results:
    model_name, results, has_ensembles, has_sst = entry
    print(f"Creating dataset for {model_name}...")
    eval_datasets[model_name] = create_evaluation_dataset(
        results, model_name, has_ensembles, has_sst)


Using device: cpu


  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  return cls(fs=fs, path=path, read_only=read_only, allowed_exceptions=allowed_exceptions)
  checkpoint = torch.load('/home/jovyan/GRL_ssh/checkpoints/sst_ssh.pth', map_location=device)


# Table

In [2]:
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score

# -------------------------------------------------
# physical constants and grid spacing
# -------------------------------------------------
g = 9.81             # gravitational acceleration, m s⁻²
dx = dy = 1_500.0    # grid spacing in both directions, m (1.5 km grid)
f_cor = -8.6e-5      # Coriolis parameter, s⁻¹ (Agulhas region)

# -------------------------------------------------
# samples whose true UBM field contains no NaN
# -------------------------------------------------
_has_nan = test.ubm.isnull().any(dim=("i", "j")).values
clean_mask = ~_has_nan
clean_idx = np.flatnonzero(clean_mask)
print(f"Total samples={len(clean_mask)}, clean={len(clean_idx)}, skipped={len(clean_mask)-len(clean_idx)}")

# -------------------------------------------------
# helpers
# -------------------------------------------------
def flatten_clean(da):
    """Select the clean samples of ``da`` and flatten ``i``/``j`` into one axis.

    Uses the module-level ``clean_idx`` and returns the result as a NumPy
    array with one row per clean sample.
    """
    clean_only = da.isel(sample=clean_idx)
    flattened = clean_only.stack(pixels=("i", "j"))
    return flattened.values

def r2_corr(truth, pred):
    """Summarise per-sample R² and Pearson correlation between two row sets.

    ``truth`` and ``pred`` are iterated in lockstep, one row per sample. For
    each pair only positions finite in both rows are scored; a pair with
    fewer than two such positions contributes NaN to both metrics.

    Returns ``((r2_mean, r2_p05, r2_p95), (corr_mean, corr_p05, corr_p95))``,
    all computed with NaN-ignoring reductions.
    """
    def _score_pair(t, p):
        usable = np.isfinite(t) & np.isfinite(p)
        if usable.sum() < 2:
            return np.nan, np.nan
        t_ok, p_ok = t[usable], p[usable]
        return r2_score(t_ok, p_ok), np.corrcoef(t_ok, p_ok)[0, 1]

    def _mean_and_band(values):
        return (np.nanmean(values),
                np.nanpercentile(values, 5),
                np.nanpercentile(values, 95))

    scores = [_score_pair(t, p) for t, p in zip(truth, pred)]
    r2_vals = [pair[0] for pair in scores]
    c_vals = [pair[1] for pair in scores]

    return _mean_and_band(r2_vals), _mean_and_band(c_vals)

def geostrophic_vel(field_2d_array, *, gravity=None, coriolis=None,
                    spacing_x=None, spacing_y=None):
    """Compute geostrophic velocities from a sea-surface-height field.

    Evaluates ``u = -(g / f) * dη/dy`` and ``v = (g / f) * dη/dx`` with
    second-order finite differences (``np.gradient(..., edge_order=2)``).

    Derivatives are taken along the last two axes (``-2`` is y, ``-1`` is x),
    so a plain 2-D field behaves as before and any leading axes (for example
    a sample axis) are processed in a single call.

    Parameters
    ----------
    field_2d_array : array_like
        Height field with at least two dimensions and at least three points
        along each of the last two axes (required by ``edge_order=2``).
    gravity, coriolis, spacing_x, spacing_y : float, optional
        Keyword-only overrides. When left as ``None`` the module-level
        ``g``, ``f_cor``, ``dx`` and ``dy`` are used.

    Returns
    -------
    (u, v) : tuple of numpy.ndarray
        Velocity components with the same shape as the input.
    """
    # Fall back to the notebook-level constants only when no override is given.
    gravity = g if gravity is None else gravity
    coriolis = f_cor if coriolis is None else coriolis
    spacing_x = dx if spacing_x is None else spacing_x
    spacing_y = dy if spacing_y is None else spacing_y

    dη_dy = np.gradient(field_2d_array, spacing_y, axis=-2, edge_order=2)
    dη_dx = np.gradient(field_2d_array, spacing_x, axis=-1, edge_order=2)

    u = -gravity / coriolis * dη_dy
    v = gravity / coriolis * dη_dx
    return u, v

def geostrophic_vel_xarray(da):
    """Apply ``geostrophic_vel`` sample by sample to an xarray DataArray.

    Iterates over the leading axis of ``da`` (selected by the ``sample``
    dimension name) and returns ``(u_da, v_da)``: two copies of ``da`` whose
    values are replaced by the stacked velocity components.
    """
    per_sample = [
        geostrophic_vel(da.isel(sample=k).values) for k in range(da.shape[0])
    ]
    u_array = np.stack([uv[0] for uv in per_sample], axis=0)
    v_array = np.stack([uv[1] for uv in per_sample], axis=0)

    def _same_layout(values):
        # Copy the input so dims/coords are carried over, then swap in the data.
        wrapped = da.copy()
        wrapped.values = values
        return wrapped

    return _same_layout(u_array), _same_layout(v_array)

# Reference BM field and the geostrophic velocities derived from it
bm_true = test.bm
u_true, v_true = geostrophic_vel_xarray(bm_true)

# -------------------------------------------------
# evaluation datasets to score, keyed by model name
# -------------------------------------------------
models = {key: eval_datasets[key] for key in ("sst_ssh", "ssh_only", "mse_only")}

# Column order shared by the three tables (avg_all is appended per row).
_METRIC_COLUMNS = ["UBM_R2", "UBM_corr", "BM_R2", "BM_corr",
                   "u_R2", "u_corr", "v_R2", "v_corr"]

# -------------------------------------------------
# score every model
# -------------------------------------------------
records = []      # mean over clean samples
records_p05 = []  # 5th percentile
records_p95 = []  # 95th percentile

for name, dataset in models.items():
    print(f"Processing model: {name}")

    # UBM: predicted mean against the true UBM
    ubm_pred = dataset.ubm_pred_mean
    ubm_stats = r2_corr(flatten_clean(test.ubm), flatten_clean(ubm_pred))

    # BM: SSH minus predicted UBM, against the true BM
    bm_pred = test.ssh - ubm_pred
    bm_stats = r2_corr(flatten_clean(bm_true), flatten_clean(bm_pred))

    # Geostrophic velocities derived from the predicted BM
    u_pred, v_pred = geostrophic_vel_xarray(bm_pred)
    u_stats = r2_corr(flatten_clean(u_true), flatten_clean(u_pred))
    v_stats = r2_corr(flatten_clean(v_true), flatten_clean(v_pred))

    # level 0 = mean, 1 = 5th percentile, 2 = 95th percentile
    for level, target in enumerate((records, records_p05, records_p95)):
        values = [
            stat[level]
            for r2_stat, corr_stat in (ubm_stats, bm_stats, u_stats, v_stats)
            for stat in (r2_stat, corr_stat)
        ]
        row = {"model": name, **dict(zip(_METRIC_COLUMNS, values))}
        # average of the eight numbers at this level
        row["avg_all"] = np.nanmean(values)
        target.append(row)


def _ranked_table(rows):
    """Index by model, round to 3 decimals and sort by avg_all (best first)."""
    table = pd.DataFrame(rows).set_index("model").round(3)
    return table.sort_values("avg_all", ascending=False)


metrics_all = _ranked_table(records)
metrics_p05 = _ranked_table(records_p05)
metrics_p95 = _ranked_table(records_p95)


print("\n=== Combined Summary (Mean ± Range) ===")
combined_summary = pd.DataFrame(index=metrics_all.index)

# Each cell reads "mean (p05, p95)"; Series addition aligns rows on the model index.
for col in _METRIC_COLUMNS + ['avg_all']:
    mean_txt = metrics_all[col].round(3).astype(str)
    low_txt = metrics_p05[col].round(3).astype(str)
    high_txt = metrics_p95[col].round(3).astype(str)
    combined_summary[col] = mean_txt + ' (' + low_txt + ', ' + high_txt + ')'

print(combined_summary)

Total samples=3645, clean=2907, skipped=738
Processing model: sst_ssh
Processing model: ssh_only
Processing model: mse_only

=== Summary metrics (mean over clean samples) ===
          UBM_R2  UBM_corr  BM_R2  BM_corr   u_R2  u_corr   v_R2  v_corr  \
model                                                                      
sst_ssh    0.010     0.478  0.969    0.987  0.848   0.925  0.849   0.924   
ssh_only  -0.176     0.401  0.955    0.980  0.812   0.904  0.813   0.904   
mse_only  -0.254     0.389  0.958    0.981  0.761   0.885  0.754   0.883   

          avg_all  
model              
sst_ssh     0.749  
ssh_only    0.699  
mse_only    0.670  

=== Summary metrics (5th percentile over clean samples) ===
          UBM_R2  UBM_corr  BM_R2  BM_corr   u_R2  u_corr   v_R2  v_corr  \
model                                                                      
sst_ssh   -0.938     0.010  0.869    0.942  0.567   0.778  0.582   0.774   
ssh_only  -1.349    -0.107  0.800    0.909  0.445   0.7

# SI Figures

## Gaussian Filter

In [13]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import xarray as xr
import cmocean  # Added import for cmocean
from matplotlib import gridspec
from matplotlib.colors import Normalize
from scipy.ndimage import gaussian_filter

# Pick the sample recorded as the 'max' entry for the sst_ssh model
# (`extreme_samples` is defined outside this cell).
_best_entry = extreme_samples['sst_ssh']['max']
best_sample_idx = _best_entry['sample_idx']
best_r2_value = _best_entry['r2']

# True and predicted BM fields of that sample, as NumPy arrays
_sst_ssh_eval = eval_datasets['sst_ssh']
best_bm_true_sample = _sst_ssh_eval.bm_truth.isel(sample=best_sample_idx).values
best_bm_pred_sample = _sst_ssh_eval.bm_pred_mean.isel(sample=best_sample_idx).values

# Global matplotlib styling for this figure: sans-serif text, STIX math font,
# no grid, white figure/axes backgrounds.
plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans', 'Helvetica', 'sans-serif'],
        'mathtext.fontset': 'stix',  
        'axes.grid': False,
        'figure.facecolor': 'white',
        'axes.facecolor': 'white'
    })

# Title font size and padding (read by format_ax and the row labels)
title_fontsize = 40 
title_pad = 8

# Font size of colorbar tick labels (read by the colorbar helpers)
colorbar_tick_fontsize = 35

# ----------------------
# Define plotting parameters
# ----------------------
# Color limits per column for the BM panels (rows 1 and 2); the colorbars
# label these values '(m)'.
# Column 1: Original data
vmin_orig, vmax_orig = -0.08, 0.08

# Column 2: Features < 60km 
vmin_60, vmax_60 = -0.02, 0.02

# Column 3: Features < 30km 
vmin_30, vmax_30 = -0.02, 0.02

# Column 4: Features < 10km 
vmin_10, vmax_10 = -0.002, 0.002

# Define parameters for normalized error
# Using appropriate scales for each column 
# NOTE: the error-row imshow calls pass vmin=0 literally and read only the
# vmax_norm_* values; the vmin_norm_* names are not used by them.
vmin_norm_orig, vmax_norm_orig = 0, 2.0  # Normalized absolute error for original data
vmin_norm_60, vmax_norm_60 = 0, 2.0  # Normalized absolute error for 60km features
vmin_norm_30, vmax_norm_30 = 0, 2.0  # Normalized absolute error for 30km features
vmin_norm_10, vmax_norm_10 = 0, 2.0  # Normalized absolute error for 10km features

# Define tick values in actual data units 
# (the colorbar helper rescales them by the exponent it is given)
ticks_orig_data = np.array([-0.08, -0.04, 0.0, 0.04, 0.08])
ticks_60_data = np.array([-0.02, -0.01, 0.0, 0.01, 0.02])  
ticks_30_data = np.array([-0.02, -0.01, 0.0, 0.01, 0.02])  
ticks_10_data = np.array([-0.002, -0.001, 0.0, 0.001, 0.002])  

# Define ticks for normalized absolute error (dimensionless)
ticks_norm_orig_data = np.array([0, 0.5, 1.0, 1.5, 2.0])
ticks_norm_60_data = np.array([0, 0.5, 1.0, 1.5, 2.0])
ticks_norm_30_data = np.array([0, 0.5, 1.0, 1.5, 2.0])
ticks_norm_10_data = np.array([0, 0.5, 1.0, 1.5, 2.0])

# Index vectors along each grid axis (the field is taken to be square:
# only shape[0] is read).
size = best_bm_true_sample.shape[0]
i = np.arange(size)
j = np.arange(size)

# Physical length scales -> Gaussian widths in grid units
pixel_size = 1.5  # km per grid point

# Cut-off scale in grid points for the 60 / 30 / 10 km bands
scale_60km, scale_30km, scale_10km = (
    cutoff_km / pixel_size for cutoff_km in (60, 30, 10)
)
# Gaussian sigma for each band: scale divided by sqrt(12)
sigma_60km, sigma_30km, sigma_10km = (
    scale / np.sqrt(12) for scale in (scale_60km, scale_30km, scale_10km)
)


def _split_scales(field, sigma):
    """Return (low-pass, high-pass): the Gaussian-smoothed field and the residual."""
    smooth = gaussian_filter(field, sigma=sigma)
    return smooth, field - smooth


# High-pass fields = original minus its Gaussian low-pass, per band
true_low_pass_60km, true_high_pass_60km = _split_scales(best_bm_true_sample, sigma_60km)
pred_low_pass_60km, pred_high_pass_60km = _split_scales(best_bm_pred_sample, sigma_60km)

true_low_pass_30km, true_high_pass_30km = _split_scales(best_bm_true_sample, sigma_30km)
pred_low_pass_30km, pred_high_pass_30km = _split_scales(best_bm_pred_sample, sigma_30km)

true_low_pass_10km, true_high_pass_10km = _split_scales(best_bm_true_sample, sigma_10km)
pred_low_pass_10km, pred_high_pass_10km = _split_scales(best_bm_pred_sample, sigma_10km)

# Trim a border of `crop_width` grid points from every side.
# NOTE(review): the 60 km sigma is ~11.5 grid points, far wider than this
# 2-point crop, so filter edge effects may remain — confirm this is intended.
crop_width = 2
_interior = (slice(crop_width, -crop_width),) * 2

true_high_pass_60km_cropped = true_high_pass_60km[_interior]
pred_high_pass_60km_cropped = pred_high_pass_60km[_interior]

true_high_pass_30km_cropped = true_high_pass_30km[_interior]
pred_high_pass_30km_cropped = pred_high_pass_30km[_interior]

true_high_pass_10km_cropped = true_high_pass_10km[_interior]
pred_high_pass_10km_cropped = pred_high_pass_10km[_interior]

# The unfiltered fields get the same crop so all panels share one extent
best_bm_true_sample_cropped = best_bm_true_sample[_interior]
best_bm_pred_sample_cropped = best_bm_pred_sample[_interior]


def _normalized_abs_error(pred, truth):
    """Return (|pred - truth|, std(truth), |pred - truth| / std(truth))."""
    abs_err = np.abs(pred - truth)
    spread = np.std(truth)
    return abs_err, spread, abs_err / spread


# Normalized absolute errors. The `squared_error_*` names are historical:
# they hold absolute errors, not squared errors.
squared_error_orig, std_orig, norm_absolute_error_orig = _normalized_abs_error(
    best_bm_pred_sample_cropped, best_bm_true_sample_cropped)

squared_error_60km, std_60km, norm_absolute_error_60km = _normalized_abs_error(
    pred_high_pass_60km_cropped, true_high_pass_60km_cropped)

squared_error_30km, std_30km, norm_absolute_error_30km = _normalized_abs_error(
    pred_high_pass_30km_cropped, true_high_pass_30km_cropped)

squared_error_10km, std_10km, norm_absolute_error_10km = _normalized_abs_error(
    pred_high_pass_10km_cropped, true_high_pass_10km_cropped)

# Create the figure that hosts the 3 x 4 panel layout and its colorbars
fig = plt.figure(figsize=(38, 20))  

# Create a grid with proper alignment - 3 rows, 15 columns.
# Column usage in the plotting code below: 0/4/8/12 hold the image panels,
# 2/6/10/14 hold their colorbars; the 0.001-wide columns (1/5/9/13) and the
# 0.3-wide columns (3/7/11) are left empty as spacing.
gs = gridspec.GridSpec(3, 15,  
                      width_ratios=[0.94, 0.001, 0.08, 0.3, 0.94, 0.001, 0.08, 0.3, 0.94, 0.001, 0.08, 0.3, 0.94, 0.001, 0.08],  
                      height_ratios=[1, 1, 1],
                      wspace=-0.05, hspace=0.12)

# Colormaps: cmocean 'ice' for the BM fields, cmocean 'amp' for the
# normalized-error row
cmap_bm = cmocean.cm.ice
cmap_norm_error = cmocean.cm.amp
# Helper function to format the axes with centered titles
def format_ax(ax, title):
    """Strip the ticks from ``ax`` and set a centred title.

    Font size and padding come from the module-level ``title_fontsize`` and
    ``title_pad``.
    """
    # No tick marks on either axis
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])

    ax.set_title(title, loc='center', pad=title_pad, fontsize=title_fontsize)
    

# Helper function to create colorbar with scientific notation header
def create_colorbar_with_scientific(im, cax, ticks, unit, exponent, extend_type='both'):
    """Draw a colorbar whose tick labels are rescaled by a power of ten.

    Tick positions stay at ``ticks`` (data units); each label shows
    ``tick * 10**(-exponent)`` with one decimal, and a ``×10^exponent``
    multiplier is written above the bar unless ``exponent`` is 0.

    Parameters: ``im`` is the mappable, ``cax`` the axes to draw into,
    ``unit`` the y-label text, ``extend_type`` is forwarded as ``extend``.
    Returns the created colorbar.
    """
    cbar = plt.colorbar(im, cax=cax, extend=extend_type, ticks=ticks)
    bar_ax = cbar.ax

    # Unit label and ticks both go on the right-hand side
    bar_ax.set_ylabel(unit, rotation=270, labelpad=25, fontsize=30)
    bar_ax.yaxis.tick_right()
    bar_ax.yaxis.set_label_position('right')

    # Labels are plain numbers; the power of ten is shown once above the bar
    rescale = 10 ** (-exponent)
    bar_ax.set_yticklabels([f'{tick * rescale:.1f}' for tick in ticks])
    bar_ax.tick_params(labelsize=colorbar_tick_fontsize, length=8, width=2)

    if exponent == 0:
        return cbar

    bar_ax.text(0.75, 1.04, f'×10$^{{{exponent}}}$',
                transform=bar_ax.transAxes,
                ha='center', va='bottom', fontsize=28)
    return cbar

# Helper function for normalized error colorbars 
def create_normalized_colorbar(im, cax, ticks, extend_type='max'):
    """Draw an unlabelled colorbar for the dimensionless error panels.

    Ticks sit at ``ticks`` and are printed with one decimal on the right-hand
    side of ``cax``. Returns the created colorbar.
    """
    cbar = plt.colorbar(im, cax=cax, extend=extend_type, ticks=ticks)
    bar_ax = cbar.ax

    bar_ax.set_ylabel('', rotation=270, labelpad=25, fontsize=30)
    bar_ax.yaxis.tick_right()
    bar_ax.yaxis.set_label_position('right')

    one_decimal = ['{:.1f}'.format(value) for value in ticks]
    bar_ax.set_yticklabels(one_decimal)
    bar_ax.tick_params(labelsize=colorbar_tick_fontsize, length=8, width=2)
    return cbar

def _image_panel(row, col, field, cmap, vmin, vmax, title=''):
    """Add an image panel at ``gs[row, col]`` and return ``(axes, image)``."""
    panel_ax = fig.add_subplot(gs[row, col])
    panel_im = panel_ax.imshow(field, cmap=cmap, vmin=vmin, vmax=vmax)
    format_ax(panel_ax, title)
    return panel_ax, panel_im


def _row_label(panel_ax, text):
    """Write a rotated row label just left of ``panel_ax``."""
    panel_ax.text(-0.06, 0.5, text, transform=panel_ax.transAxes, rotation=90,
                  va='center', ha='center', fontsize=title_fontsize)


def _sci_colorbar(row, col, image, ticks, exponent):
    """Add a '(m)' colorbar at ``gs[row, col]``; return ``(axes, colorbar)``."""
    bar_ax = fig.add_subplot(gs[row, col])
    return bar_ax, create_colorbar_with_scientific(image, bar_ax, ticks, '(m)', exponent)


def _error_colorbar(row, col, image, ticks):
    """Add a normalized-error colorbar at ``gs[row, col]``; return ``(axes, colorbar)``."""
    bar_ax = fig.add_subplot(gs[row, col])
    return bar_ax, create_normalized_colorbar(image, bar_ax, ticks)


# Row 1: true BM — unfiltered, then the <60 / <30 / <10 km high-pass bands
ax1, im1 = _image_panel(0, 0, best_bm_true_sample_cropped, cmap_bm,
                        vmin_orig, vmax_orig, '<120 km')
ax1_5, im1_5 = _image_panel(0, 4, true_high_pass_60km_cropped, cmap_bm,
                            vmin_60, vmax_60, '<60 km')
ax2, im2 = _image_panel(0, 8, true_high_pass_30km_cropped, cmap_bm,
                        vmin_30, vmax_30, '<30 km')
ax3, im3 = _image_panel(0, 12, true_high_pass_10km_cropped, cmap_bm,
                        vmin_10, vmax_10, '<10 km')
_row_label(ax1, 'True')

# Row 2: predicted BM, same columns and color limits as row 1 (no titles)
ax4, im4 = _image_panel(1, 0, best_bm_pred_sample_cropped, cmap_bm,
                        vmin_orig, vmax_orig)
ax4_5, im4_5 = _image_panel(1, 4, pred_high_pass_60km_cropped, cmap_bm,
                            vmin_60, vmax_60)
ax5, im5 = _image_panel(1, 8, pred_high_pass_30km_cropped, cmap_bm,
                        vmin_30, vmax_30)
ax6, im6 = _image_panel(1, 12, pred_high_pass_10km_cropped, cmap_bm,
                        vmin_10, vmax_10)
_row_label(ax4, 'ZCA+SST')

# Row 3: normalized absolute error, lower color limit fixed at 0
ax7, im7 = _image_panel(2, 0, norm_absolute_error_orig, cmap_norm_error,
                        0, vmax_norm_orig)
ax7_5, im7_5 = _image_panel(2, 4, norm_absolute_error_60km, cmap_norm_error,
                            0, vmax_norm_60)
ax8, im8 = _image_panel(2, 8, norm_absolute_error_30km, cmap_norm_error,
                        0, vmax_norm_30)
ax9, im9 = _image_panel(2, 12, norm_absolute_error_10km, cmap_norm_error,
                        0, vmax_norm_10)
_row_label(ax7, 'Error')

# Colorbars, one per panel, in the grid column two to the right of it.
# Rows 1-2 use a ×10^-2 multiplier, except the <10 km column (×10^-3).
cax1, cbar1 = _sci_colorbar(0, 2, im1, ticks_orig_data, -2)
cax1_5, cbar1_5 = _sci_colorbar(0, 6, im1_5, ticks_60_data, -2)
cax2, cbar2 = _sci_colorbar(0, 10, im2, ticks_30_data, -2)
cax3, cbar3 = _sci_colorbar(0, 14, im3, ticks_10_data, -3)

cax4, cbar4 = _sci_colorbar(1, 2, im4, ticks_orig_data, -2)
cax4_5, cbar4_5 = _sci_colorbar(1, 6, im4_5, ticks_60_data, -2)
cax5, cbar5 = _sci_colorbar(1, 10, im5, ticks_30_data, -2)
cax6, cbar6 = _sci_colorbar(1, 14, im6, ticks_10_data, -3)

# Row 3 colorbars are dimensionless
cax7, cbar7 = _error_colorbar(2, 2, im7, ticks_norm_orig_data)
cax7_5, cbar7_5 = _error_colorbar(2, 6, im7_5, ticks_norm_60_data)
cax8, cbar8 = _error_colorbar(2, 10, im8, ticks_norm_30_data)
cax9, cbar9 = _error_colorbar(2, 14, im9, ticks_norm_10_data)

# Write the figure as PNG and PDF, then display it
plt.savefig('/home/jovyan/GRL_ssh/figures/SI/scales_bm.png', bbox_inches='tight', dpi=300, transparent=True)
plt.savefig('/home/jovyan/GRL_ssh/figures/SI/scales_bm.pdf', bbox_inches='tight', dpi=300, transparent=True)
plt.show()

## Scatter plot

In [11]:
import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import r2_score

plt.style.use('default')
plt.rcParams.update({
    'font.size': 18,
    'font.family': 'serif',
    'axes.linewidth': 1.2,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'xtick.major.width': 1.2,
    'ytick.major.width': 1.2,
    'xtick.minor.width': 0.8,
    'ytick.minor.width': 0.8,
    'axes.grid': True,
    'grid.alpha': 0.3,
    'grid.linewidth': 0.8,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'axes.labelsize': 18,
    'axes.titlesize': 18
})


models = ['mse_only', 'ssh_only', 'sst_ssh']
model_titles = ['MSE+SST', 'ZCA', 'ZCA+SST']

# Plot window. Both axes are logarithmic, so samples with R^2 <= 0 cannot be
# drawn and are dropped together with anything outside these bounds.
R2_MIN = 0.001  # Minimum R^2 to include (removes very negative values)
R2_MAX = 1.0
REL_MAG_MIN = 0.001  # Minimum relative magnitude
REL_MAG_MAX = 1   # Maximum relative magnitude


def _masked_r2(truth, pred):
    """Return R^2 over the entries where both inputs are non-NaN (NaN if < 2 remain)."""
    valid = ~(np.isnan(truth) | np.isnan(pred))
    if np.sum(valid) > 1:
        return r2_score(truth[valid], pred[valid])
    return np.nan


def _mean_abs(field):
    """Return the NaN-ignoring mean of |field|, or NaN when no entry is valid."""
    if np.sum(~np.isnan(field)) > 0:
        return np.nanmean(np.abs(field))
    return np.nan


def _draw_r2_panel(ax, rel_mag, r2, cmap, title, show_xlabel, show_ylabel):
    """Draw one log-log hexbin panel of per-sample R^2 against relative magnitude."""
    finite = ~(np.isnan(r2) | np.isnan(rel_mag))
    if np.sum(finite) > 0:
        r2_finite = r2[finite]
        rel_mag_finite = rel_mag[finite]

        # Keep only the samples that fall inside the plot window.
        in_window = ((r2_finite >= R2_MIN) & (r2_finite <= R2_MAX) &
                     (rel_mag_finite >= REL_MAG_MIN) & (rel_mag_finite <= REL_MAG_MAX))

        if np.sum(in_window) > 0:
            ax.hexbin(rel_mag_finite[in_window], r2_finite[in_window],
                      gridsize=60, cmap=cmap, mincnt=1,
                      xscale='log', yscale='log', vmin=0, vmax=3)

    ax.set_title(title, fontsize=18, fontweight='bold', pad=20)
    if show_xlabel:
        ax.set_xlabel('Relative Magnitude (UBM/BM)', fontsize=16)
    if show_ylabel:
        ax.set_ylabel('R²', fontsize=16)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(1e-3, 1e0)
    ax.set_ylim(1e-3, 1e0)
    ax.tick_params(which='both', direction='in', top=True, right=True)


fig, axes = plt.subplots(2, 3, figsize=(22, 12))


for col, model in enumerate(models):

    ubm_truth = eval_datasets[model].ubm_truth
    bm_truth = eval_datasets[model].bm_truth
    ubm_pred = eval_datasets[model].ubm_pred_mean
    bm_pred = eval_datasets[model].bm_pred_mean

    # Get number of samples
    n_samples = ubm_truth.shape[0]

    # Pull the NumPy arrays out once; indexing the xarray objects inside the
    # per-sample loop repeats the conversion for every sample.
    ubm_truth_all = ubm_truth.values
    bm_truth_all = bm_truth.values
    ubm_pred_all = ubm_pred.values
    bm_pred_all = bm_pred.values

    # Per-sample metrics
    bm_r2_values = []
    ubm_r2_values = []
    relative_magnitudes = []

    for sample_idx in range(n_samples):
        # Flattened 2D fields for the current sample
        ubm_true_2d = ubm_truth_all[sample_idx].ravel()
        bm_true_2d = bm_truth_all[sample_idx].ravel()
        ubm_pred_2d = ubm_pred_all[sample_idx].ravel()
        bm_pred_2d = bm_pred_all[sample_idx].ravel()

        ubm_r2_values.append(_masked_r2(ubm_true_2d, ubm_pred_2d))
        bm_r2_values.append(_masked_r2(bm_true_2d, bm_pred_2d))

        # Relative magnitude = mean |UBM truth| / mean |BM truth|
        ubm_magnitude = _mean_abs(ubm_true_2d)
        bm_magnitude = _mean_abs(bm_true_2d)
        if not np.isnan(bm_magnitude) and bm_magnitude != 0 and not np.isnan(ubm_magnitude):
            relative_magnitudes.append(ubm_magnitude / bm_magnitude)
        else:
            relative_magnitudes.append(np.nan)

    bm_r2_values = np.array(bm_r2_values)
    ubm_r2_values = np.array(ubm_r2_values)
    relative_magnitudes = np.array(relative_magnitudes)

    show_ylabel = col == 0  # Only label y-axis for leftmost column

    # Top row: BM (no x-axis label)
    _draw_r2_panel(axes[0, col], relative_magnitudes, bm_r2_values, 'Reds',
                   f'{model_titles[col]} (BM)',
                   show_xlabel=False, show_ylabel=show_ylabel)

    # Bottom row: UBM (carries the shared x-axis label)
    _draw_r2_panel(axes[1, col], relative_magnitudes, ubm_r2_values, 'Blues',
                   f'{model_titles[col]} (UBM)',
                   show_xlabel=True, show_ylabel=show_ylabel)


plt.tight_layout(rect=[0, 0, 1, 0.93])
plt.subplots_adjust(hspace=0.35, wspace=0.15)

# Make sure the output directory exists so savefig cannot fail on a fresh checkout.
os.makedirs('figures/SI', exist_ok=True)
fig.savefig('figures/SI/scatter.png', dpi=300, bbox_inches='tight',
            facecolor='white', edgecolor='none')
fig.savefig('figures/SI/scatter.pdf', dpi=300, bbox_inches='tight',
            facecolor='white', edgecolor='none')

## R^2 & Std Map

In [16]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
from sklearn.metrics import r2_score
import xarray as xr
import os

def calculate_spatial_r2(true_data, pred_data):
    """
    Calculate R^2 for each spatial location across the sample dimension.

    For every pixel, R^2 = 1 - SS_res / SS_tot is evaluated over the samples
    where both the truth and the prediction are non-NaN. The computation is
    vectorized over the whole grid instead of looping over pixels.

    Args:
        true_data: Array of shape (n_samples, H, W)
        pred_data: Array of shape (n_samples, H, W)

    Returns:
        r2_map: Array of shape (H, W) with R^2 values for each spatial location.
            A location is NaN when it has fewer than two valid samples or when
            its valid truth values have zero variance.

    Raises:
        ValueError: If the inputs are not 3-D or their shapes differ.
    """
    true_arr = np.asarray(true_data, dtype=float)
    pred_arr = np.asarray(pred_data, dtype=float)
    if true_arr.ndim != 3 or true_arr.shape != pred_arr.shape:
        raise ValueError(
            "true_data and pred_data must both have shape (n_samples, H, W); "
            f"got {true_arr.shape} and {pred_arr.shape}"
        )

    # A sample contributes to a pixel only when truth AND prediction are valid.
    valid = ~(np.isnan(true_arr) | np.isnan(pred_arr))
    n_valid = valid.sum(axis=0)

    # Mean of the valid truth values per pixel (guard against 0 valid samples).
    true_sum = np.where(valid, true_arr, 0.0).sum(axis=0)
    true_mean = true_sum / np.maximum(n_valid, 1)

    # Residual and total sums of squares, with invalid samples contributing zero.
    ss_res = np.where(valid, (true_arr - pred_arr) ** 2, 0.0).sum(axis=0)
    ss_tot = np.where(valid, (true_arr - true_mean) ** 2, 0.0).sum(axis=0)

    # R^2 is undefined without at least two samples or without truth variance.
    undefined = (n_valid < 2) | (ss_tot == 0)
    with np.errstate(divide='ignore', invalid='ignore'):
        r2_map = 1.0 - ss_res / ss_tot
    r2_map[undefined] = np.nan

    return r2_map

def create_discrete_colormap_and_norm(data, cmap_name, n_levels=10, percentile_range=(5, 95)):
    """
    Create discrete colormap and normalization based on data distribution.

    The colour range spans the requested percentiles of `data` (NaNs ignored)
    and is split into `n_levels` equal-width bins.

    Args:
        data: Input data array
        cmap_name: Name of the colormap
        n_levels: Number of discrete levels
        percentile_range: Tuple of (low, high) percentiles for range

    Returns:
        cmap, norm, levels: the colormap resampled to `n_levels` colours, the
        matching BoundaryNorm (out-of-range values are clipped), and the
        `n_levels + 1` bin edges.
    """
    # Get data range based on percentiles
    vmin, vmax = np.nanpercentile(data, percentile_range)

    # Create discrete levels
    levels = np.linspace(vmin, vmax, n_levels + 1)

    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap accepts the same (name, lut) arguments.
    cmap = plt.get_cmap(cmap_name, n_levels)
    norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

    return cmap, norm, levels

def _distance_ticks(n_pixels, km_per_pixel, n_ticks=9):
    """Return evenly spaced pixel tick positions and km labels (every other tick labelled)."""
    positions = np.round(np.linspace(0, n_pixels - 1, n_ticks)).astype(int)
    labels = [f'{pos * km_per_pixel:.0f}' if idx % 2 == 0 else ''
              for idx, pos in enumerate(positions)]
    return positions, labels


def _style_map_axis(ax, x_ticks, x_labels, y_ticks, y_labels, panel_tag):
    """Apply the shared distance ticks, axis labels, panel tag and frame to one map axis."""
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)
    ax.tick_params(axis='both', which='major', length=7, width=1.5, labelsize=14)
    ax.set_xlabel('Distance (km)', fontsize=18)
    ax.set_ylabel('Distance (km)', fontsize=18)
    ax.text(0.5, 1.05, panel_tag, transform=ax.transAxes,
            fontsize=26, fontweight='bold', ha='center', va='bottom')

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1.2)
    ax.grid(False)


def plot_r2_and_uncertainty_maps(eval_dataset, figsize=(12, 6), dpi=300, save_path="figures/",
                                 n_levels_r2=10, n_levels_unc=10, km_per_pixel=1.5):
    """
    Plot the per-pixel UBM R^2 map next to the ensemble-spread map.

    Panel (a) shows R^2 of 'ubm_pred_mean' against 'ubm_truth' at every grid
    point (clipped to [-1, 1]). Panel (b) shows the geometric mean, over the
    first axis, of the standard deviation of 'ubm_pred_samples' taken along
    axis 1. Both panels use discrete colour levels spanning the 5th-95th
    percentile of their data.

    Note: this resets the Matplotlib style to 'default' and updates the global
    rcParams; those settings stay in effect after the call.

    Args:
        eval_dataset: Mapping/dataset providing 'ubm_truth', 'ubm_pred_mean'
            and 'ubm_pred_samples', each exposing `.values`.
        figsize: Figure size passed to plt.subplots.
        dpi: Resolution used for rcParams and for the saved PNG.
        save_path: Directory for 'performance_maps.png'; falsy to skip saving.
        n_levels_r2: Number of discrete colour levels for the R^2 panel.
        n_levels_unc: Number of discrete colour levels for the spread panel.
        km_per_pixel: Grid spacing used to convert pixel ticks to kilometres.

    Returns:
        fig, ubm_r2_map, ubm_uncertainty
    """
    plt.style.use('default')

    plt.rcParams.update({
        'font.size': 11,
        'font.family': 'sans-serif',
        'axes.linewidth': 1.2,
        'axes.spines.top': True,
        'axes.spines.right': True,
        'axes.spines.bottom': True,
        'axes.spines.left': True,
        'axes.grid': False,
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'xtick.major.size': 4,
        'ytick.major.size': 4,
        'legend.frameon': True,
        'legend.fancybox': False,
        'legend.shadow': False,
        'legend.edgecolor': 'black',
        'savefig.dpi': dpi,
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.1
    })

    # Extract the data
    ubm_true = eval_dataset['ubm_truth'].values
    ubm_pred = eval_dataset['ubm_pred_mean'].values

    # R^2 map, clipped so strongly negative pixels do not dominate the colour scale
    ubm_r2_map = calculate_spatial_r2(ubm_true, ubm_pred)
    ubm_r2_map = np.clip(ubm_r2_map, -1.0, 1.0)

    # Ensemble spread: std along axis 1 for each sample -> (n_test_samples, H, W)
    ubm_ensemble = eval_dataset['ubm_pred_samples'].values
    ubm_ensemble_std_per_sample = np.std(ubm_ensemble, axis=1)

    # Geometric mean via log-space averaging; the 1e-8 offset keeps log() finite
    ubm_uncertainty = np.exp(np.mean(np.log(ubm_ensemble_std_per_sample + 1e-8), axis=0))  # Shape: (H, W)

    # Create discrete colormaps and normalizations
    r2_cmap, r2_norm, r2_levels = create_discrete_colormap_and_norm(
        ubm_r2_map, 'RdYlBu_r', n_levels_r2, (5, 95)
    )

    unc_cmap, unc_norm, unc_levels = create_discrete_colormap_and_norm(
        ubm_uncertainty, 'plasma', n_levels_unc, (5, 95)
    )

    # Create figure
    fig, axes = plt.subplots(1, 2, figsize=figsize, facecolor='white')
    fig.patch.set_facecolor('white')

    # Distance ticks follow the actual map size instead of a hard-coded 80 pixels
    n_rows, n_cols = ubm_r2_map.shape
    x_ticks, x_labels = _distance_ticks(n_cols, km_per_pixel)
    y_ticks, y_labels = _distance_ticks(n_rows, km_per_pixel)

    # Panel (a): R^2 map
    im1 = axes[0].imshow(ubm_r2_map, cmap=r2_cmap, norm=r2_norm,
                         origin='lower', interpolation='bilinear')
    _style_map_axis(axes[0], x_ticks, x_labels, y_ticks, y_labels, '(a)')

    cbar1 = fig.colorbar(im1, ax=axes[0], shrink=0.62, aspect=20, extend='both',
                         boundaries=r2_levels, ticks=r2_levels)
    cbar1.set_label(r'$R^2$', rotation=270, labelpad=25, fontsize=20)
    cbar1.ax.tick_params(labelsize=14)

    # Switch to scientific notation only when every level is very small
    if np.max(np.abs(r2_levels)) < 1e-2:
        cbar1.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.2e}'))
    else:
        cbar1.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.2f}'))

    # Panel (b): ensemble spread
    im2 = axes[1].imshow(ubm_uncertainty, cmap=unc_cmap, norm=unc_norm,
                         origin='lower', interpolation='bilinear')
    _style_map_axis(axes[1], x_ticks, x_labels, y_ticks, y_labels, '(b)')

    cbar2 = fig.colorbar(im2, ax=axes[1], shrink=0.62, aspect=20, extend='both',
                         boundaries=unc_levels, ticks=unc_levels)
    cbar2.set_label(r'$\sigma_{geo}$', rotation=270, labelpad=30, fontsize=22)
    cbar2.ax.tick_params(labelsize=14)
    cbar2.ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x:.2e}'))

    # Adjust layout with proper spacing
    plt.tight_layout(pad=2.0, w_pad=3.0, h_pad=3.0)

    # Save figure
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        full_path = os.path.join(save_path, 'performance_maps.png')
        plt.savefig(full_path, dpi=dpi, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
        print(f"Figure saved to: {full_path}")

    plt.show()

    return fig, ubm_r2_map, ubm_uncertainty


# Render the R^2 / ensemble-spread maps for the ZCA+SST model and keep the
# returned figure and arrays for later inspection.
fig, ubm_r2_map, ubm_uncertainty = plot_r2_and_uncertainty_maps(
    eval_dataset=eval_datasets['sst_ssh'],
    save_path='figures/SI',
    figsize=(12, 6),
    dpi=300,
    n_levels_unc=10,
    n_levels_r2=10,
)

Figure saved to: figures/SI/performance_maps.png


## 3 PSD

In [12]:
import warnings
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import xarray as xr
import xrft
from sklearn.metrics import r2_score

# Suppress warnings
# NOTE(review): the first call ignores every warning category, so the
# xrft-specific FutureWarning filter below adds nothing on top of it. The
# blanket filter also hides NumPy/Matplotlib warnings raised by later code.
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore", category=FutureWarning, module='xrft')

def _ensemble_psd_envelope(ensemble, reference_psd):
    """Return the 5th/95th percentile PSD across all ensemble members that are not all-NaN."""
    member_psds = [
        calculate_psd_km(member).values
        for member in ensemble
        if not np.all(np.isnan(member))
    ]
    if not member_psds:
        # No usable member: return NaN bands with the same length as the truth PSD.
        empty_band = np.full_like(reference_psd, np.nan, dtype=float)
        return empty_band, empty_band.copy()

    lower, upper = np.nanpercentile(np.array(member_psds), [5, 95], axis=0)
    return lower, upper


def get_sample_psd_data(sample_idx, eval_data):
    """
    Get truth PSDs and ensemble PSD envelopes for one test sample.

    Args:
        sample_idx: Position along the 'sample' dimension of `eval_data`.
        eval_data: Dataset exposing `ubm_truth`, `bm_truth`,
            `ubm_pred_samples` and `bm_pred_samples`.

    Returns:
        dict with the wavenumber axis ('freq'), the truth PSDs ('ubm_true',
        'bm_true') and the 5th/95th percentile PSDs across ensemble members
        ('ubm_05', 'ubm_95', 'bm_05', 'bm_95'). Ensemble members that are
        entirely NaN are skipped.
    """
    # Truth fields for the sample
    ubm_true = eval_data.ubm_truth.isel(sample=sample_idx).values
    bm_true = eval_data.bm_truth.isel(sample=sample_idx).values

    # Ensemble predictions for the sample; members are indexed along the first axis
    ubm_ensemble = eval_data.ubm_pred_samples.isel(sample=sample_idx).values
    bm_ensemble = eval_data.bm_pred_samples.isel(sample=sample_idx).values

    # Calculate PSDs for truth
    psd_ubm_true = calculate_psd_km(ubm_true)
    psd_bm_true = calculate_psd_km(bm_true)

    # Envelope bounds over however many members each ensemble actually holds
    # (previously hard-coded to 30).
    ubm_psd_05, ubm_psd_95 = _ensemble_psd_envelope(ubm_ensemble, psd_ubm_true.values)
    bm_psd_05, bm_psd_95 = _ensemble_psd_envelope(bm_ensemble, psd_bm_true.values)

    return {
        'freq': psd_ubm_true.freq_r.values,
        'ubm_true': psd_ubm_true.values,
        'bm_true': psd_bm_true.values,
        'ubm_05': ubm_psd_05,
        'ubm_95': ubm_psd_95,
        'bm_05': bm_psd_05,
        'bm_95': bm_psd_95
    }

plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'DejaVu Sans', 'Helvetica', 'sans-serif'],
    'mathtext.fontset': 'stix',
    'axes.grid': False,
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'text.usetex': False,
})


# Panels show the best / median / worst sample of the ZCA+SST model, looked up
# under the 'max' / 'median' / 'min' keys of `extreme_samples`.
sample_types = ['best', 'median', 'worst']
sample_keys = ['max', 'median', 'min']
sample_indices = [
    extreme_samples['sst_ssh'][sample_key]['sample_idx']
    for sample_key in sample_keys
]


# Load the ZCA+SST evaluation dataset
sst_ssh_data = eval_datasets['sst_ssh']

# Create figure with 1x3 subplots
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

colors = {
    'ubm_true': '#d62728',      # Red
    'bm_true': '#1f77b4',       # Blue
    'ubm_envelope': '#ff7f0e',  # Orange
    'bm_envelope': '#2ca02c',   # Green
}

# Per-panel R^2 for UBM prediction (best, median, worst)
# NOTE(review): these are hard-coded and must be kept in sync with the samples
# selected above if `extreme_samples` changes.
r2_vals = [0.9127, 0.0866, -6.2663]

# Panel letters, one per subplot
letters = ['a', 'b', 'c']

# Labelled y-axis ticks: every second decade from 1e-10 to 1e0
y_ticks_labeled = [10**j for j in range(-10, 1, 2)]

# Loop through each sample and create plots
for i, (sample_idx, sample_type) in enumerate(zip(sample_indices, sample_types)):
    ax = axes[i]

    # Get PSD data for this sample
    print(f"Computing PSDs for {sample_type} sample #{sample_idx}...")
    psd_data = get_sample_psd_data(sample_idx, sst_ssh_data)

    # Plot true values (solid lines)
    ax.loglog(psd_data['freq'], psd_data['ubm_true'],
              color=colors['ubm_true'], linewidth=3.5, label='UBM True')
    ax.loglog(psd_data['freq'], psd_data['bm_true'],
              color=colors['bm_true'], linewidth=3.5, label='BM True')

    # Add uncertainty envelopes (5th-95th percentile, no mean lines)
    ax.fill_between(psd_data['freq'], psd_data['ubm_05'], psd_data['ubm_95'],
                    alpha=0.6, color=colors['ubm_envelope'], label='UBM 90% CI',
                    edgecolor='none')
    ax.fill_between(psd_data['freq'], psd_data['bm_05'], psd_data['bm_95'],
                    alpha=0.6, color=colors['bm_envelope'], label='BM 90% CI',
                    edgecolor='none')

    # Customize plot appearance
    ax.set_xlim(8e-3, 5e-1)
    ax.set_ylim(1e-12, 1e1)

    # Set custom tick locations for y-axis
    ax.yaxis.set_major_locator(ticker.FixedLocator(y_ticks_labeled))
    ax.yaxis.set_major_formatter(ticker.LogFormatterMathtext())

    # Professional tick styling
    ax.tick_params(axis='both', which='major', labelsize=22, length=8,
                   width=1.5, direction='in')
    ax.tick_params(axis='both', which='minor', length=4,
                   width=0.8, direction='in')

    # Hide y-axis labels for middle and right subplots
    if i > 0:
        ax.tick_params(axis='y', labelleft=False)

    # Panel labels
    ax.text(0.02, 0.98, f'{letters[i]})', transform=ax.transAxes,
            fontsize=24, fontweight='bold', va='top', ha='left')

    # X-axis label for all subplots
    ax.set_xlabel(r'Wavenumber (cpkm)', fontsize=24, fontweight='normal')

    # Y-label only on leftmost subplot
    if i == 0:
        ax.set_ylabel(r'PSD (m$^2$ cpkm$^{-1}$)', fontsize=24, fontweight='normal')

    # Annotation inside each panel (top-right corner)
    ax.text(
        0.98, 0.95, f"R$^2$ (UBM pred) = {r2_vals[i]:.4f}",
        transform=ax.transAxes, ha='right', va='top',
        fontsize=18
    )


# Shared legend below the panels, built from the first panel's handles
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(
    handles, labels,
    loc='lower center', bbox_to_anchor=(0.5, -0.15),
    ncol=4, frameon=False, fontsize=22, columnspacing=2.5
)

plt.tight_layout()
plt.subplots_adjust(bottom=0.2)

# Save through the figure handle BEFORE plt.show(): once the figure has been
# shown it may be closed, and a later plt.savefig would write an empty canvas.
fig.savefig('/home/jovyan/GRL_ssh/figures/SI/3_psd.png',
            bbox_inches='tight', dpi=300, facecolor='white',
            edgecolor='none', format='png')
fig.savefig('/home/jovyan/GRL_ssh/figures/SI/3_psd.pdf',
            bbox_inches='tight', facecolor='white',
            edgecolor='none', format='pdf')

plt.show()


Computing PSDs for best sample #3041...
Computing PSDs for median sample #519...
Computing PSDs for worst sample #3323...
