In [1]:
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.stats import pearsonr
from joblib import Parallel, delayed

# ----------------------------
# Configuration
# ----------------------------
data_dir = Path("../data")
results_dir = Path("../evaluation_results")
results_residual_dir = Path("../evaluation_results_residual")

variables = ['tas', 'pr']

# 'g6solar' is only listed for tas; pr has no g6solar entry.
scenarios_tas = ['historical', 'ssp126', 'ssp245', 'ssp585', 'g6sulfur', 'g6solar']
scenarios_pr = ['historical', 'ssp126', 'ssp245', 'ssp585', 'g6sulfur']

# Internal model key -> column label used in the printed tables.
models = {
    'bilinear': 'BI',
    'qdm': 'QDM',
    'mld': 'MLd',
    'mls': 'MLs'
}

# Unit string shown in table titles for each variable.
var_units = {
    'tas': '°C',
    'pr': 'mm/day'
}

# ----------------------------
# Latitude weights
# ----------------------------
# The grid is read once from a reference file. `Dataset.sizes` is the
# supported mapping from dimension name to length; indexing `Dataset.dims`
# by name is deprecated and emits a FutureWarning. The `with` block closes
# the file even if reading the coordinates fails.
with xr.open_dataset(results_dir / "tas_evaluation_historical.nc") as ds_tmp:
    lats = ds_tmp['lat'].values
    n_lon = ds_tmp.sizes['lon']

# cos(latitude) weights repeated along longitude -> shape (n_lat, n_lon).
lat_weights_2d = np.cos(np.deg2rad(lats))[:, np.newaxis] * np.ones((1, n_lon))
w_total = np.sum(lat_weights_2d)

# ----------------------------
# Vectorized weighted functions
# ----------------------------
def weighted_spatial_mean(data):
    """Collapse a (time, lat, lon) array to a (time,) weighted spatial mean.

    Uses the module-level ``lat_weights_2d`` grid and its sum ``w_total``.
    """
    weights = lat_weights_2d[np.newaxis, :, :]
    weighted_total = np.sum(data * weights, axis=(1, 2))
    return weighted_total / w_total

def compute_temporal_metrics(y_true, y_pred):
    """Return ``(rmse, pearson_r)`` between two 1-D time series."""
    residual = y_true - y_pred
    mean_squared = np.mean(residual ** 2)
    # pearsonr returns (statistic, p-value); only the statistic is kept.
    correlation = pearsonr(y_true, y_pred)[0]
    return np.sqrt(mean_squared), correlation

def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Weighted spatial RMSE per time step, using ``lat_weights_2d``.

    Both inputs are indexed as (time, lat, lon); the result has one value
    per time step.
    """
    weights = lat_weights_2d[np.newaxis, :, :]
    squared_error = (predictions - groundtruth) ** 2
    weighted_mse = np.sum(squared_error * weights, axis=(1, 2)) / w_total
    return np.sqrt(weighted_mse)

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Weighted spatial R² per time step, using ``lat_weights_2d``.

    Both inputs are indexed as (time, lat, lon). For each time step the
    weighted residual sum of squares is compared with the weighted variance
    of the ground truth around its weighted mean. Time steps whose ground
    truth has no positive weighted variance get a score of 0.0.
    """
    w = lat_weights_2d[np.newaxis, :, :]
    weighted_mean = np.sum(w * groundtruth, axis=(1, 2), keepdims=True) / w_total
    ss_res = np.sum(w * (groundtruth - predictions)**2, axis=(1, 2))
    ss_tot = np.sum(w * (groundtruth - weighted_mean)**2, axis=(1, 2))
    # np.where evaluates both branches, so the division also runs where
    # ss_tot == 0. Those entries are discarded below; silence the spurious
    # divide/invalid RuntimeWarning they would otherwise raise.
    with np.errstate(divide='ignore', invalid='ignore'):
        r2 = 1 - ss_res / ss_tot
    return np.where(ss_tot > 0, r2, 0.0)

# ----------------------------
# Loading and computing function for one (var, scenario)
# ----------------------------
def _score_prediction(pred_values, gt_values, gt_mean):
    """Return ``(temporal, spatial)`` metrics for one prediction field.

    ``temporal`` is the ``(rmse, corr)`` pair computed on the weighted
    spatial-mean time series; ``spatial`` is the pair of per-time-step
    ``(rmse_array, r2_array)``.
    """
    pred_mean = weighted_spatial_mean(pred_values)
    temporal = compute_temporal_metrics(gt_mean, pred_mean)
    spatial = (
        compute_spatial_rmse_timeseries(pred_values, gt_values),
        compute_spatial_r2_timeseries(pred_values, gt_values),
    )
    return temporal, spatial


def load_and_compute(var, scenario):
    """Load all model data and compute metrics for one (var, scenario).

    Returns ``(var, scenario, temporal, spatial)`` where both dicts are keyed
    by model key ('bilinear', 'qdm', 'mld', 'mls'). Models whose file or
    variable is missing are simply absent from the dicts (non-g6solar
    scenarios only; the g6solar files and variables are required).

    Every dataset is opened in a ``with`` block so the file handle is
    released even if a metric computation raises.
    """
    temporal = {}
    spatial = {}

    if scenario == 'g6solar':
        # g6solar has its own layout: one combined file plus a QDM file.
        # All four models are scored against the ground truth of the
        # combined file.
        with xr.open_dataset(data_dir / "g6solar_downscaled_results.nc") as g6solar_main:
            with xr.open_dataset(results_dir / "tas_evaluation_g6solar_qdm.nc") as g6solar_qdm:
                gt_values = g6solar_main['groundtruth'].values
                gt_mean = weighted_spatial_mean(gt_values)

                model_map = {
                    'bilinear': ('input', g6solar_main),
                    'mld': ('pred_zscore_pixel', g6solar_main),
                    'mls': ('pred_grid', g6solar_main),
                    'qdm': ('pred_qdm', g6solar_qdm),
                }
                for model_key, (var_name, ds) in model_map.items():
                    temporal[model_key], spatial[model_key] = _score_prediction(
                        ds[var_name].values, gt_values, gt_mean
                    )
        return var, scenario, temporal, spatial

    # (file, ((model key, variable name), ...)). Each file is scored against
    # its own 'groundtruth' variable.
    sources = (
        (
            results_dir / f"{var}_evaluation_{scenario}.nc",
            (('bilinear', 'input'), ('mld', 'pred_zscore_pixel')),
        ),
        (
            results_dir / f"{var}_evaluation_{scenario}_qdm.nc",
            (('qdm', 'pred_qdm'),),
        ),
        (
            results_residual_dir / f"{var}_evaluation_{scenario}.nc",
            (('mls', 'pred_grid'),),
        ),
    )

    for path, model_vars in sources:
        if not path.exists():
            continue
        with xr.open_dataset(path) as ds:
            gt_values = ds['groundtruth'].values
            gt_mean = weighted_spatial_mean(gt_values)
            for model_key, var_name in model_vars:
                if var_name not in ds:
                    continue
                temporal[model_key], spatial[model_key] = _score_prediction(
                    ds[var_name].values, gt_values, gt_mean
                )

    return var, scenario, temporal, spatial

# ----------------------------
# Build task list and run in parallel
# ----------------------------
# One task per (variable, scenario) pair; tas and pr use different scenario lists.
tasks = [
    (var, scenario)
    for var in variables
    for scenario in (scenarios_tas if var == 'tas' else scenarios_pr)
]

# At most 12 workers, and never more workers than tasks.
worker_count = min(len(tasks), 12)
run_parallel = Parallel(n_jobs=worker_count, verbose=1)
parallel_results = run_parallel(delayed(load_and_compute)(*task) for task in tasks)

# Organize results as metric[var][scenario] -> {model key: value}
all_temporal = {var: {} for var in variables}
all_spatial = {var: {} for var in variables}

for var, scenario, temporal, spatial in parallel_results:
    all_temporal[var][scenario] = temporal
    all_spatial[var][scenario] = spatial

# ----------------------------
# Print Tables
# ----------------------------
def _emit_table(title, header, scenario_names, metrics, keys, width, render):
    """Print one framed table: a row per scenario, a column per model key.

    ``metrics`` maps scenario -> {model key -> stored entry}; ``render``
    turns a stored entry into the cell text. Model keys missing for a
    scenario are shown as 'N/A'.
    """
    frame = "=" * len(header)
    print("\n" + frame)
    print(title)
    print(frame)
    print(header)
    print("-" * len(header))
    for name in scenario_names:
        available = metrics[name]
        line = f"{name:<12} "
        for key in keys:
            text = render(available[key]) if key in available else 'N/A'
            line += f"{text:>{width}} "
        print(line)


def _mean_with_percentiles(series, digits):
    """Format an array as 'mean [5th, 95th]' with ``digits`` decimals."""
    centre = np.mean(series)
    low = np.percentile(series, 5)
    high = np.percentile(series, 95)
    return f"{centre:.{digits}f} [{low:.{digits}f}, {high:.{digits}f}]"


for var in variables:
    scenarios = scenarios_tas if var == 'tas' else scenarios_pr
    unit = var_units[var]
    banner = "=" * 80

    print("\n" + banner)
    print(f"VARIABLE: {var.upper()} ({unit})")
    print(banner)

    model_keys = list(models.keys())
    model_names = [models[m] for m in model_keys]

    # Narrow columns for scalar metrics, wide ones for 'mean [p5, p95]' cells.
    col_width = 12
    col_width_spatial = 22
    header = f"{'Scenario':<12} " + " ".join(f"{name:>{col_width}}" for name in model_names)
    header_spatial = f"{'Scenario':<12} " + " ".join(
        f"{name:>{col_width_spatial}}" for name in model_names
    )

    # Table 1: Temporal RMSE (first element of the stored temporal pair)
    _emit_table(
        f"Table 1: RMSE ({unit}) - Temporal Mean",
        header, scenarios, all_temporal[var], model_keys, col_width,
        lambda entry: f"{entry[0]:.4f}",
    )

    # Table 2: Temporal Correlation (second element of the temporal pair)
    _emit_table(
        f"Table 2: Correlation - Temporal Mean",
        header, scenarios, all_temporal[var], model_keys, col_width,
        lambda entry: f"{entry[1]:.4f}",
    )

    # Table 3: Spatial RMSE summarised over time
    _emit_table(
        f"Table 3: RMSE ({unit}) - Spatial (mean [5th, 95th])",
        header_spatial, scenarios, all_spatial[var], model_keys, col_width_spatial,
        lambda entry: _mean_with_percentiles(entry[0], 2),
    )

    # Table 4: Spatial R² summarised over time
    _emit_table(
        f"Table 4: R² - Spatial (mean [5th, 95th])",
        header_spatial, scenarios, all_spatial[var], model_keys, col_width_spatial,
        lambda entry: _mean_with_percentiles(entry[1], 4),
    )

print("\n" + "=" * 80)
print(f"Total models compared: {len(models)}")
print(f"Variables: {', '.join(variables)}")
print("=" * 80)

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  n_lon = ds_tmp.dims['lon']
[Parallel(n_jobs=11)]: Using backend LokyBackend with 11 concurrent workers.
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core imp


VARIABLE: TAS (°C)

Table 1: RMSE (°C) - Temporal Mean
Scenario               BI          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.4365       0.1749       0.4300       0.1719 
ssp126             0.4907       0.1885       0.7331       0.1973 
ssp245             0.4299       0.1873       0.9182       0.1756 
ssp585             0.4422       0.2051       1.3235       0.1899 
g6sulfur           0.2657       0.3734       0.8442       0.2687 
g6solar            0.2551       0.3833       0.7033       0.2687 

Table 2: Correlation - Temporal Mean
Scenario               BI          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.9919       0.9924       0.9886       0.9937 
ssp126             0.9898       0.9907       0.9841       0.9941 
ssp245             0.9912       0.9917       0.9719       0.9936 
ssp585             0.9922       0.9927       0.9518 

[Parallel(n_jobs=11)]: Done  11 out of  11 | elapsed:   22.7s finished


In [2]:
import numpy as np
import xarray as xr
from pathlib import Path
from scipy.stats import pearsonr
import pandas as pd
from joblib import Parallel, delayed

# ----------------------------
# Configuration
# ----------------------------
results_dir = Path("../evaluation_results")
results_residual_dir = Path("../evaluation_results_residual")
var = 'tas'
scenarios = ['historical', 'ssp126', 'ssp245', 'ssp585']

# Names looked up as 'pred_<name>' variables in the main results files.
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global',
                  'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Short column labels for the normalization variants.
norm_labels = {
    'none': 'None',
    'minmax_global': 'MM-G',
    'minmax_pixel': 'MM-P',
    'zscore_global': 'ZS-G',
    'zscore_pixel': 'ZS-P',
    'instance_zscore': 'Inst-ZS',
    'instance_minmax': 'Inst-MM'
}

# Names looked up as 'pred_<name>' variables in the residual results files.
residual_models = ['raw', 'gma', 'gmt', 'grid']
residual_labels = {
    'raw': 'LR-Res',
    'gma': 'GMA-Res',
    'gmt': 'GMT-Res',
    'grid': 'PLD-Res'
}

# Row labels used in the printed DataFrames.
scenario_labels = {
    'historical': 'Historical',
    'ssp126': 'SSP1-2.6',
    'ssp245': 'SSP2-4.5',
    'ssp585': 'SSP5-8.5'
}

# ----------------------------
# Latitude weights
# ----------------------------
# `Dataset.sizes` is the supported mapping from dimension name to length;
# indexing `Dataset.dims` by name is deprecated and emits a FutureWarning.
# The `with` block closes the file even if reading the coordinates fails.
with xr.open_dataset(results_dir / f"{var}_evaluation_historical.nc") as ds_tmp:
    lats = ds_tmp['lat'].values
    n_lon = ds_tmp.sizes['lon']

# cos(latitude) weights repeated along longitude -> shape (n_lat, n_lon).
lat_weights_2d = np.cos(np.deg2rad(lats))[:, np.newaxis] * np.ones((1, n_lon))
w_total = np.sum(lat_weights_2d)

# ----------------------------
# Vectorized weighted functions
# ----------------------------
def weighted_spatial_mean(data):
    """Weighted mean over the last two axes of a (time, lat, lon) array.

    Weights come from the module-level ``lat_weights_2d``; the sum is
    normalised by ``w_total``.
    """
    weights = lat_weights_2d[np.newaxis, :, :]
    numerator = np.sum(data * weights, axis=(1, 2))
    return numerator / w_total

def compute_temporal_metrics(y_true, y_pred):
    """Return ``(rmse, pearson_r)`` for a pair of 1-D time series."""
    squared_error = (y_true - y_pred) ** 2
    # Index 0 of pearsonr's result is the correlation; the p-value is unused.
    correlation = pearsonr(y_true, y_pred)[0]
    return np.sqrt(np.mean(squared_error)), correlation

def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Weighted spatial RMSE for each time step of (time, lat, lon) inputs."""
    weights = lat_weights_2d[np.newaxis, :, :]
    squared_error = (predictions - groundtruth) ** 2
    weighted_mse = np.sum(squared_error * weights, axis=(1, 2)) / w_total
    return np.sqrt(weighted_mse)

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Weighted spatial R² for each time step of (time, lat, lon) inputs.

    Uses the module-level ``lat_weights_2d`` / ``w_total``. Time steps whose
    ground truth has no positive weighted variance are scored 0.0.
    """
    w = lat_weights_2d[np.newaxis, :, :]
    weighted_mean = np.sum(w * groundtruth, axis=(1, 2), keepdims=True) / w_total
    ss_res = np.sum(w * (groundtruth - predictions)**2, axis=(1, 2))
    ss_tot = np.sum(w * (groundtruth - weighted_mean)**2, axis=(1, 2))
    # The ratio is computed for every entry, including ss_tot == 0, before
    # np.where picks a branch. Those entries are replaced by 0.0 anyway, so
    # suppress the divide/invalid RuntimeWarning they would trigger.
    with np.errstate(divide='ignore', invalid='ignore'):
        r2 = 1 - ss_res / ss_tot
    return np.where(ss_tot > 0, r2, 0.0)

# ----------------------------
# Load and compute for one scenario
# ----------------------------
def _evaluate_predictions(path, model_names):
    """Score every available ``pred_<name>`` variable in one results file.

    Returns ``{name: metrics}`` for each name in ``model_names`` whose
    ``pred_<name>`` variable exists in the file, scored against the file's
    own 'groundtruth'. Returns an empty dict when ``path`` does not exist.
    The dataset is opened in a ``with`` block so it is closed even if a
    metric computation raises.
    """
    scores = {}
    if not path.exists():
        return scores

    with xr.open_dataset(path) as ds:
        gt_data = ds['groundtruth'].values
        gt_mean = weighted_spatial_mean(gt_data)

        for name in model_names:
            pred_key = f'pred_{name}'
            if pred_key not in ds:
                continue

            pred_data = ds[pred_key].values
            pred_mean = weighted_spatial_mean(pred_data)

            rmse_ts = compute_spatial_rmse_timeseries(pred_data, gt_data)
            r2_ts = compute_spatial_r2_timeseries(pred_data, gt_data)
            t_rmse, t_corr = compute_temporal_metrics(gt_mean, pred_mean)

            # Spatial metrics are averaged over time; the "global" metrics
            # compare the weighted spatial-mean time series.
            scores[name] = {
                'spatial_rmse': np.mean(rmse_ts),
                'spatial_r2': np.mean(r2_ts),
                'global_rmse': t_rmse,
                'global_corr': t_corr
            }
    return scores


def load_and_compute(scenario):
    """Compute metrics for all normalization and residual models of a scenario.

    Reads the module-level ``var`` to build the file names. Returns
    ``(scenario, results)`` where ``results`` maps model name to its metric
    dict; models without data are omitted.
    """
    results = {}

    # Normalization models
    results.update(_evaluate_predictions(
        results_dir / f"{var}_evaluation_{scenario}.nc", normalizations
    ))

    # Residual models
    results.update(_evaluate_predictions(
        results_residual_dir / f"{var}_evaluation_{scenario}.nc", residual_models
    ))

    return scenario, results

# ----------------------------
# Run in parallel
# ----------------------------
# One worker per scenario.
worker_pool = Parallel(n_jobs=len(scenarios), verbose=1)
parallel_results = worker_pool(delayed(load_and_compute)(s) for s in scenarios)

# ----------------------------
# Organize results
# ----------------------------
all_models = normalizations + residual_models
all_labels = dict(norm_labels)
all_labels.update(residual_labels)

# metric[model][scenario] -> value, one independent store per metric
spatial_rmse = {model: {} for model in all_models}
spatial_r2 = {model: {} for model in all_models}
global_rmse = {model: {} for model in all_models}
global_corr = {model: {} for model in all_models}

metric_stores = {
    'spatial_rmse': spatial_rmse,
    'spatial_r2': spatial_r2,
    'global_rmse': global_rmse,
    'global_corr': global_corr,
}

for scenario, results in parallel_results:
    for model, metrics in results.items():
        for metric_name, store in metric_stores.items():
            store[model][scenario] = metrics[metric_name]

# ----------------------------
# Create DataFrames
# ----------------------------
# Only models that produced a result for at least one scenario get a column.
active_models = [m for m in all_models if spatial_rmse[m]]


def _scenario_frame(store):
    """Build a scenario-by-model DataFrame from one metric store."""
    columns = {}
    for m in active_models:
        columns[all_labels[m]] = {
            scenario_labels[s]: store[m][s] for s in scenarios if s in store[m]
        }
    return pd.DataFrame(columns)


df_spatial_rmse = _scenario_frame(spatial_rmse)
df_spatial_r2 = _scenario_frame(spatial_r2)
df_global_rmse = _scenario_frame(global_rmse)
df_global_corr = _scenario_frame(global_corr)

# ----------------------------
# Display Tables
# ----------------------------
report = [
    ("TABLE 1: Spatial RMSE (°C) - Mean across time", df_spatial_rmse),
    ("TABLE 2: Spatial R² - Mean across time", df_spatial_r2),
    ("TABLE 3: Global Mean Timeseries RMSE (°C)", df_global_rmse),
    ("TABLE 4: Global Mean Timeseries Correlation", df_global_corr),
]

for position, (title, frame) in enumerate(report):
    # Blank separator between tables, none after the last one.
    if position:
        print("\n")
    print("=" * 80)
    print(title)
    print("=" * 80)
    print(frame.round(4))

  n_lon = ds_tmp.dims['lon']
[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


TABLE 1: Spatial RMSE (°C) - Mean across time
              None    MM-G    MM-P    ZS-G    ZS-P  Inst-ZS  Inst-MM  LR-Res  \
Historical  1.8247  3.4492  1.6072  2.0397  1.6000   2.2081   3.2364  1.4887   
SSP1-2.6    2.0456  3.4419  1.7779  2.1625  1.8026   2.1838   3.1394  1.5263   
SSP2-4.5    2.2627  3.5212  1.9316  2.3318  1.9769   2.1844   3.1135  1.5573   
SSP5-8.5    2.6677  3.6186  2.2402  2.6705  2.3047   2.2020   3.0804  1.6141   

            GMA-Res  GMT-Res  PLD-Res  
Historical   1.5009   1.4882   1.4812  
SSP1-2.6     1.5666   1.5371   1.4688  
SSP2-4.5     1.5953   1.5733   1.4780  
SSP5-8.5     1.6537   1.6381   1.5075  


TABLE 2: Spatial R² - Mean across time
              None    MM-G    MM-P    ZS-G    ZS-P  Inst-ZS  Inst-MM  LR-Res  \
Historical  0.9854  0.9476  0.9886  0.9818  0.9887   0.9787   0.9542  0.9903   
SSP1-2.6    0.9812  0.9465  0.9857  0.9790  0.9852   0.9786   0.9557  0.9895   
SSP2-4.5    0.9766  0.9438  0.9828  0.9753  0.9818   0.9785   0.9563  0.

[Parallel(n_jobs=4)]: Done   2 out of   4 | elapsed:   50.6s remaining:   50.6s
[Parallel(n_jobs=4)]: Done   4 out of   4 | elapsed:   50.6s finished


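# old: earlier version of the cells above, using unweighted spatial means instead of cos(latitude) weights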

In [5]:
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.stats import pearsonr

# ----------------------------
# Configuration
# ----------------------------
data_dir = Path("..", "data")
results_dir = Path("..", "evaluation_results")
results_residual_dir = Path("..", "evaluation_results_residual")

variables = ['tas', 'pr']

# pr uses the five common scenarios; tas additionally has g6solar.
scenarios_pr = ['historical', 'ssp126', 'ssp245', 'ssp585', 'g6sulfur']
scenarios_tas = scenarios_pr + ['g6solar']

# Model key -> table column label (only the selected models).
models = dict(bilinear='BI', qdm='QDM', mld='MLd', mls='MLs')

# Unit string per variable, shown in table titles.
var_units = dict(tas='°C', pr='mm/day')

# ----------------------------
# Functions
# ----------------------------
def compute_spatial_mean(data):
    """Average ``data`` over its 'lat' and 'lon' dimensions and return ``.values``."""
    averaged = data.mean(dim=['lat', 'lon'])
    return averaged.values

def compute_spatial_mean_numpy(data):
    """Unweighted mean over axes 1 and 2 of a (time, lat, lon) numpy array."""
    spatial_axes = (1, 2)
    return np.mean(data, axis=spatial_axes)

def compute_temporal_metrics(y_true, y_pred):
    """Return ``(rmse, pearson_r)`` between two 1-D time series."""
    residual = y_true - y_pred
    mean_squared = np.mean(residual ** 2)
    # pearsonr returns (statistic, p-value); only the statistic is kept.
    correlation = pearsonr(y_true, y_pred)[0]
    return np.sqrt(mean_squared), correlation

def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Unweighted RMSE over all grid cells, one value per time step.

    The number of time steps is taken from ``predictions.shape[0]``.
    """
    def frame_rmse(step):
        error = predictions[step].flatten() - groundtruth[step].flatten()
        return np.sqrt(np.mean(error ** 2))

    return np.array([frame_rmse(step) for step in range(predictions.shape[0])])

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Unweighted R² over all grid cells, one value per time step.

    A time step whose ground truth has no positive variance is scored 0.
    """
    scores = []
    for step in range(predictions.shape[0]):
        estimate = predictions[step].flatten()
        truth = groundtruth[step].flatten()
        residual_ss = np.sum((truth - estimate) ** 2)
        total_ss = np.sum((truth - np.mean(truth)) ** 2)
        if total_ss > 0:
            scores.append(1 - (residual_ss / total_ss))
        else:
            scores.append(0)
    return np.array(scores)

# ----------------------------
# Process each variable
# ----------------------------
def _score_unweighted(pred_values, pred_mean, gt_values, gt_mean):
    """Return ``(temporal, spatial)`` metrics for one prediction field.

    ``temporal`` is ``(rmse, corr)`` between the spatial-mean time series;
    ``spatial`` is ``(rmse_array, r2_array)`` with one value per time step.
    """
    temporal = compute_temporal_metrics(gt_mean, pred_mean)
    spatial = (
        compute_spatial_rmse_timeseries(pred_values, gt_values),
        compute_spatial_r2_timeseries(pred_values, gt_values),
    )
    return temporal, spatial


def _print_metric_table(title, header, scenario_names, metrics, keys, width, render):
    """Print one framed table: a row per scenario, a column per model key.

    ``metrics`` maps scenario -> {model key -> stored entry}; ``render``
    turns an entry into cell text. Missing model keys are shown as 'N/A'.
    """
    frame = "=" * len(header)
    print("\n" + frame)
    print(title)
    print(frame)
    print(header)
    print("-" * len(header))
    for name in scenario_names:
        available = metrics[name]
        line = f"{name:<12} "
        for key in keys:
            text = render(available[key]) if key in available else 'N/A'
            line += f"{text:>{width}} "
        print(line)


def _mean_p5_p95(series, digits):
    """Format an array as 'mean [5th, 95th]' with ``digits`` decimals."""
    centre = np.mean(series)
    low = np.percentile(series, 5)
    high = np.percentile(series, 95)
    return f"{centre:.{digits}f} [{low:.{digits}f}, {high:.{digits}f}]"


# ----------------------------
# Process each variable
# ----------------------------
for var in variables:
    # Select appropriate scenarios for this variable
    scenarios = scenarios_tas if var == 'tas' else scenarios_pr

    print("\n" + "=" * 80)
    print(f"VARIABLE: {var.upper()} ({var_units[var]})")
    print("=" * 80)

    # Storage for metrics: metric[scenario][model key]
    temporal_metrics = {scen: {} for scen in scenarios}
    spatial_metrics = {scen: {} for scen in scenarios}

    # ----------------------------
    # Load data and compute metrics
    # ----------------------------
    # Every dataset is opened in a `with` block so the file is closed even
    # if a metric computation raises.
    for scenario in scenarios:
        t_out = temporal_metrics[scenario]
        s_out = spatial_metrics[scenario]

        # g6solar has a different file layout: one combined file plus a QDM
        # file, all scored against the combined file's ground truth, with
        # spatial means taken on the raw numpy arrays.
        if scenario == 'g6solar':
            with xr.open_dataset(data_dir / "g6solar_downscaled_results.nc") as g6solar_main:
                with xr.open_dataset(results_dir / "tas_evaluation_g6solar_qdm.nc") as g6solar_qdm:
                    gt_values = g6solar_main['groundtruth'].values
                    gt_mean = compute_spatial_mean_numpy(gt_values)

                    g6solar_sources = (
                        ('bilinear', g6solar_main, 'input'),
                        ('mld', g6solar_main, 'pred_zscore_pixel'),
                        ('mls', g6solar_main, 'pred_grid'),
                        ('qdm', g6solar_qdm, 'pred_qdm'),
                    )
                    for model_key, ds, var_name in g6solar_sources:
                        pred_values = ds[var_name].values
                        t_out[model_key], s_out[model_key] = _score_unweighted(
                            pred_values,
                            compute_spatial_mean_numpy(pred_values),
                            gt_values,
                            gt_mean,
                        )
            continue

        # (file, ((model key, variable name), ...)); each file is scored
        # against its own 'groundtruth' variable. Missing files or variables
        # leave the model absent from the tables ('N/A').
        sources = (
            (
                results_dir / f"{var}_evaluation_{scenario}.nc",
                (('bilinear', 'input'), ('mld', 'pred_zscore_pixel')),
            ),
            (
                results_dir / f"{var}_evaluation_{scenario}_qdm.nc",
                (('qdm', 'pred_qdm'),),
            ),
            (
                results_residual_dir / f"{var}_evaluation_{scenario}.nc",
                (('mls', 'pred_grid'),),
            ),
        )
        for path, model_vars in sources:
            if not path.exists():
                continue
            with xr.open_dataset(path) as ds:
                gt = ds['groundtruth']
                gt_values = gt.values
                gt_mean = compute_spatial_mean(gt)
                for model_key, var_name in model_vars:
                    if var_name not in ds:
                        continue
                    pred_data = ds[var_name]
                    t_out[model_key], s_out[model_key] = _score_unweighted(
                        pred_data.values,
                        compute_spatial_mean(pred_data),
                        gt_values,
                        gt_mean,
                    )

    # ----------------------------
    # Print Tables for this variable
    # ----------------------------
    model_keys = list(models.keys())
    model_names = [models[m] for m in model_keys]

    # Narrow columns for scalar metrics, wide ones for 'mean [p5, p95]' cells.
    col_width = 12
    col_width_spatial = 22
    header = f"{'Scenario':<12} " + " ".join([f"{name:>{col_width}}" for name in model_names])
    header_spatial = f"{'Scenario':<12} " + " ".join(
        [f"{name:>{col_width_spatial}}" for name in model_names]
    )

    # Table 1: RMSE - Temporal Mean
    _print_metric_table(
        f"Table 1: RMSE ({var_units[var]}) - Temporal Mean",
        header, scenarios, temporal_metrics, model_keys, col_width,
        lambda entry: f"{entry[0]:.4f}",
    )

    # Table 2: Correlation - Temporal Mean
    _print_metric_table(
        f"Table 2: Correlation - Temporal Mean",
        header, scenarios, temporal_metrics, model_keys, col_width,
        lambda entry: f"{entry[1]:.4f}",
    )

    # Table 3: RMSE - Spatial (mean [5th, 95th percentile])
    _print_metric_table(
        f"Table 3: RMSE ({var_units[var]}) - Spatial (mean [5th, 95th])",
        header_spatial, scenarios, spatial_metrics, model_keys, col_width_spatial,
        lambda entry: _mean_p5_p95(entry[0], 2),
    )

    # Table 4: R² - Spatial (mean [5th, 95th percentile])
    _print_metric_table(
        f"Table 4: R² - Spatial (mean [5th, 95th])",
        header_spatial, scenarios, spatial_metrics, model_keys, col_width_spatial,
        lambda entry: _mean_p5_p95(entry[1], 4),
    )

print("\n" + "=" * 80)
print(f"Total models compared: {len(models)}")
print(f"Variables: {', '.join(variables)}")
print("=" * 80)


VARIABLE: TAS (°C)

Table 1: RMSE (°C) - Temporal Mean
Scenario               BI          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.3941       0.3056       0.5075       0.2965 
ssp126             0.4674       0.3183       0.8816       0.3417 
ssp245             0.4213       0.3280       1.0831       0.3234 
ssp585             0.4154       0.3184       1.5299       0.3229 
g6sulfur           0.5133       0.7069       0.9731       0.4908 
g6solar            0.4771       0.6722       0.8185       0.4454 

Table 2: Correlation - Temporal Mean
Scenario               BI          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.9836       0.9838       0.9824       0.9859 
ssp126             0.9784       0.9796       0.9711       0.9853 
ssp245             0.9788       0.9794       0.9573       0.9844 
ssp585             0.9857       0.9860       0.9369 

In [2]:
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.stats import pearsonr

# ----------------------------
# Configuration
# ----------------------------
# Directory with the main "<var>_evaluation_<scenario>.nc" files and the
# "<var>_evaluation_<scenario>_qdm.nc" files read further down in this cell.
results_dir = Path("../evaluation_results")
# Directory with the residual-model "<var>_evaluation_<scenario>.nc" files.
results_residual_dir = Path("../evaluation_results_residual")

# Variables to evaluate; each one gets its own set of four tables.
variables = ['tas', 'pr']
# A single scenario list is applied to every variable in this cell.
scenarios = ['historical', 'ssp126', 'ssp245', 'ssp585', 'g6sulfur']

# Selected models only: internal model key -> column heading in the printed
# tables. The insertion order here is the column order.
models = {
    'bilinear': 'Bilinear',
    'qdm': 'QDM',
    'mld': 'MLd',
    'mls': 'MLs'
}

# Display units for each variable (used in headings and RMSE table titles).
var_units = {
    'tas': '°C',
    'pr': 'mm/day'
}

# ----------------------------
# Functions
# ----------------------------
def compute_spatial_mean(data):
    """Average ``data`` over its 'lat' and 'lon' dimensions.

    No area weights are applied. The result's ``.values`` attribute is
    returned, i.e. the underlying array rather than the labelled object.
    """
    spatial_avg = data.mean(dim=['lat', 'lon'])
    return spatial_avg.values

def compute_temporal_metrics(y_true, y_pred):
    """Compare two temporal-mean time series.

    Returns:
        ``(rmse, corr)``: root-mean-square error of ``y_true - y_pred`` and
        the Pearson correlation coefficient between the two series.
    """
    corr, _p_value = pearsonr(y_true, y_pred)
    residual = y_true - y_pred
    mean_sq_err = np.mean(residual ** 2)
    return np.sqrt(mean_sq_err), corr

def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Calculate spatial RMSE for each time point.

    Args:
        predictions: Array whose first axis is time. All remaining axes are
            pooled with equal weight (no area weighting).
        groundtruth: Reference array with the same shape as ``predictions``.

    Returns:
        1-D array with one RMSE value per entry along the first axis.
    """
    sq_err = (np.asarray(predictions) - np.asarray(groundtruth)) ** 2
    # Reduce over every non-time axis in one call instead of looping over
    # time steps in Python.
    spatial_axes = tuple(range(1, sq_err.ndim))
    return np.sqrt(np.mean(sq_err, axis=spatial_axes))

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Calculate spatial R² for each time point.

    For every time step, R² = 1 - SS_res / SS_tot, where both sums run over
    all non-time axes with equal weight (no area weighting).

    Args:
        predictions: Array whose first axis is time.
        groundtruth: Reference array with the same shape as ``predictions``.

    Returns:
        1-D float array with one R² value per time step. Steps whose total
        sum of squares is not strictly positive (e.g. a constant reference
        field) are reported as 0.
    """
    pred = np.asarray(predictions)
    truth = np.asarray(groundtruth)
    spatial_axes = tuple(range(1, truth.ndim))

    ss_res = np.sum((truth - pred) ** 2, axis=spatial_axes)
    truth_mean = np.mean(truth, axis=spatial_axes, keepdims=True)
    ss_tot = np.sum((truth - truth_mean) ** 2, axis=spatial_axes)

    # Only divide where SS_tot > 0; every other step keeps the 0 placeholder.
    r2 = np.zeros(np.shape(ss_tot), dtype=float)
    valid = ss_tot > 0
    r2[valid] = 1.0 - ss_res[valid] / ss_tot[valid]
    return r2

# ----------------------------
# Process each variable
# ----------------------------
# Results files consulted for every (variable, scenario) pair, in processing
# order: (directory, file-name suffix, {dataset variable -> model key}).
_RESULT_SOURCES = [
    # Main results: bilinear baseline ('input') and MLd ('pred_zscore_pixel').
    (results_dir, "", {'input': 'bilinear', 'pred_zscore_pixel': 'mld'}),
    # QDM results.
    (results_dir, "_qdm", {'pred_qdm': 'qdm'}),
    # Residual results (MLs = PLD-Res).
    (results_residual_dir, "", {'pred_grid': 'mls'}),
]


def _collect_file_metrics(path, var_to_model, temporal_out, spatial_out):
    """Score the listed prediction variables of one results file.

    Nothing happens when ``path`` does not exist. Otherwise each variable in
    ``var_to_model`` that is present in the file is compared against the
    file's own 'groundtruth' variable and stored under its model key:

    * ``temporal_out[key] = (rmse, corr)`` of the spatial-mean time series
    * ``spatial_out[key] = (rmse_per_time, r2_per_time)``
    """
    if not path.exists():
        return
    # The context manager closes the file even if a metric computation fails.
    with xr.open_dataset(path) as ds:
        gt = ds['groundtruth']
        gt_values = gt.values
        gt_mean = compute_spatial_mean(gt)
        for ds_var, model_key in var_to_model.items():
            if ds_var not in ds:
                continue
            pred = ds[ds_var]
            pred_values = pred.values
            pred_mean = compute_spatial_mean(pred)
            temporal_out[model_key] = compute_temporal_metrics(gt_mean, pred_mean)
            spatial_out[model_key] = (
                compute_spatial_rmse_timeseries(pred_values, gt_values),
                compute_spatial_r2_timeseries(pred_values, gt_values),
            )


def _print_metric_table(title, metrics, model_keys, model_names, width, format_cell):
    """Print one scenario-by-model table.

    ``metrics`` maps scenario -> {model key -> stored metric}. ``format_cell``
    turns a stored metric into the cell text; models without an entry for a
    scenario are shown as 'N/A'. Every cell is right-aligned to ``width``.
    """
    header = f"{'Scenario':<12} " + " ".join(f"{name:>{width}}" for name in model_names)
    rule = "=" * len(header)
    print("\n" + rule)
    print(title)
    print(rule)
    print(header)
    print("-" * len(header))
    for scenario, per_model in metrics.items():
        row = f"{scenario:<12} "
        for key in model_keys:
            cell = format_cell(per_model[key]) if key in per_model else 'N/A'
            row += f"{cell:>{width}} "
        print(row)


def _mean_with_percentiles(values, decimals):
    """Format ``values`` as 'mean [5th percentile, 95th percentile]'."""
    mean_val = np.mean(values)
    p5 = np.percentile(values, 5)
    p95 = np.percentile(values, 95)
    return f"{mean_val:.{decimals}f} [{p5:.{decimals}f}, {p95:.{decimals}f}]"


for var in variables:
    print("\n" + "=" * 80)
    print(f"VARIABLE: {var.upper()} ({var_units[var]})")
    print("=" * 80)

    # Storage for metrics: scenario -> {model key -> metric}
    temporal_metrics = {scen: {} for scen in scenarios}
    spatial_metrics = {scen: {} for scen in scenarios}

    # ----------------------------
    # Load data and compute metrics
    # ----------------------------
    for scenario in scenarios:
        for directory, suffix, var_to_model in _RESULT_SOURCES:
            _collect_file_metrics(
                directory / f"{var}_evaluation_{scenario}{suffix}.nc",
                var_to_model,
                temporal_metrics[scenario],
                spatial_metrics[scenario],
            )

    # ----------------------------
    # Print Tables for this variable
    # ----------------------------
    model_keys = list(models.keys())
    model_names = [models[m] for m in model_keys]
    col_width = 12          # single-number cells (Tables 1 and 2)
    col_width_spatial = 22  # "mean [p5, p95]" cells (Tables 3 and 4)

    # Table 1: RMSE - Temporal Mean
    _print_metric_table(
        f"Table 1: RMSE ({var_units[var]}) - Temporal Mean",
        temporal_metrics, model_keys, model_names, col_width,
        lambda pair: f"{pair[0]:.4f}",
    )

    # Table 2: Correlation - Temporal Mean
    _print_metric_table(
        "Table 2: Correlation - Temporal Mean",
        temporal_metrics, model_keys, model_names, col_width,
        lambda pair: f"{pair[1]:.4f}",
    )

    # Table 3: RMSE - Spatial (mean [5th, 95th percentile])
    _print_metric_table(
        f"Table 3: RMSE ({var_units[var]}) - Spatial (mean [5th, 95th])",
        spatial_metrics, model_keys, model_names, col_width_spatial,
        lambda pair: _mean_with_percentiles(pair[0], 2),
    )

    # Table 4: R² - Spatial (mean [5th, 95th percentile])
    _print_metric_table(
        "Table 4: R² - Spatial (mean [5th, 95th])",
        spatial_metrics, model_keys, model_names, col_width_spatial,
        lambda pair: _mean_with_percentiles(pair[1], 4),
    )

print("\n" + "=" * 80)
print(f"Total models compared: {len(models)}")
print(f"Variables: {', '.join(variables)}")
print("=" * 80)


VARIABLE: TAS (°C)

Table 1: RMSE (°C) - Temporal Mean
Scenario         Bilinear          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.3941       0.3056       0.5075       0.2965 
ssp126             0.4674       0.3183       0.8816       0.3417 
ssp245             0.4213       0.3280       1.0831       0.3234 
ssp585             0.4154       0.3184       1.5299       0.3229 
g6sulfur           0.5133       0.7069       0.9731       0.4908 

Table 2: Correlation - Temporal Mean
Scenario         Bilinear          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.9836       0.9838       0.9824       0.9859 
ssp126             0.9784       0.9796       0.9711       0.9853 
ssp245             0.9788       0.9794       0.9573       0.9844 
ssp585             0.9857       0.9860       0.9369       0.9890 
g6sulfur           0.9662       0.9656       0.9372 

In [3]:
import numpy as np
import xarray as xr
from pathlib import Path
from scipy.stats import pearsonr
import pandas as pd

# ----------------------------
# Configuration
# ----------------------------
# Directory with the "<var>_evaluation_<scenario>.nc" files read below.
results_dir = Path("../evaluation_results")
# This cell evaluates a single variable.
var = 'tas'
scenarios = ['historical', 'ssp126', 'ssp245', 'ssp585']
# Normalization identifiers; predictions are looked up in each file as the
# variable named "pred_<normalization>".
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global', 
                  'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Row labels for the printed tables, keyed by normalization identifier.
norm_labels = {
    'none': 'Raw',
    'minmax_global': 'MM-G',
    'minmax_pixel': 'MM-P',
    'zscore_global': 'ZS-G',
    'zscore_pixel': 'ZS-P',
    'instance_zscore': 'Inst-ZS',
    'instance_minmax': 'Inst-MM'
}

# Column labels for the printed tables, keyed by scenario identifier.
scenario_labels = {
    'historical': 'Historical',
    'ssp126': 'SSP1-2.6',
    'ssp245': 'SSP2-4.5',
    'ssp585': 'SSP5-8.5'
}

# ----------------------------
# Functions
# ----------------------------
def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Calculate spatial RMSE for each time point.

    Args:
        predictions: Array whose first axis is time. All remaining axes are
            pooled with equal weight (no area weighting).
        groundtruth: Reference array with the same shape as ``predictions``.

    Returns:
        1-D array with one RMSE value per entry along the first axis.
    """
    sq_err = (np.asarray(predictions) - np.asarray(groundtruth)) ** 2
    # Reduce over every non-time axis in one call instead of looping over
    # time steps in Python.
    spatial_axes = tuple(range(1, sq_err.ndim))
    return np.sqrt(np.mean(sq_err, axis=spatial_axes))

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Calculate spatial R² for each time point.

    For every time step, R² = 1 - SS_res / SS_tot, where both sums run over
    all non-time axes with equal weight (no area weighting).

    Args:
        predictions: Array whose first axis is time.
        groundtruth: Reference array with the same shape as ``predictions``.

    Returns:
        1-D float array with one R² value per time step. Steps whose total
        sum of squares is not strictly positive (e.g. a constant reference
        field) are reported as 0.
    """
    pred = np.asarray(predictions)
    truth = np.asarray(groundtruth)
    spatial_axes = tuple(range(1, truth.ndim))

    ss_res = np.sum((truth - pred) ** 2, axis=spatial_axes)
    truth_mean = np.mean(truth, axis=spatial_axes, keepdims=True)
    ss_tot = np.sum((truth - truth_mean) ** 2, axis=spatial_axes)

    # Only divide where SS_tot > 0; every other step keeps the 0 placeholder.
    r2 = np.zeros(np.shape(ss_tot), dtype=float)
    valid = ss_tot > 0
    r2[valid] = 1.0 - ss_res[valid] / ss_tot[valid]
    return r2

def compute_spatial_mean(data):
    """Average ``data`` over its 'lat' and 'lon' dimensions.

    No area weights are applied. The result's ``.values`` attribute is
    returned, i.e. the underlying array rather than the labelled object.
    """
    spatial_avg = data.mean(dim=['lat', 'lon'])
    return spatial_avg.values

def compute_global_metrics(y_true, y_pred):
    """Compare two spatial-mean time series.

    Returns:
        ``(rmse, corr)``: root-mean-square error of ``y_true - y_pred`` and
        the Pearson correlation coefficient between the two series.
    """
    corr, _p_value = pearsonr(y_true, y_pred)
    residual = y_true - y_pred
    mean_sq_err = np.mean(residual ** 2)
    return np.sqrt(mean_sq_err), corr

# ----------------------------
# Compute all metrics
# ----------------------------
# metric[normalization][scenario] -> scalar
spatial_rmse = {norm: {} for norm in normalizations}
spatial_r2 = {norm: {} for norm in normalizations}
global_rmse = {norm: {} for norm in normalizations}
global_corr = {norm: {} for norm in normalizations}

for scenario in scenarios:
    results_file = results_dir / f"{var}_evaluation_{scenario}.nc"
    # The context manager closes the file once its metrics are computed,
    # including when a computation raises.
    with xr.open_dataset(results_file) as ds:
        gt_data = ds['groundtruth'].values
        gt_spatial_mean = compute_spatial_mean(ds['groundtruth'])

        for norm in normalizations:
            pred_key = f'pred_{norm}'
            if pred_key not in ds:
                continue

            pred_data = ds[pred_key].values
            pred_spatial_mean = compute_spatial_mean(ds[pred_key])

            # Spatial metrics (mean across time)
            rmse_ts = compute_spatial_rmse_timeseries(pred_data, gt_data)
            r2_ts = compute_spatial_r2_timeseries(pred_data, gt_data)
            spatial_rmse[norm][scenario] = np.mean(rmse_ts)
            spatial_r2[norm][scenario] = np.mean(r2_ts)

            # Global mean timeseries metrics
            rmse, corr = compute_global_metrics(gt_spatial_mean, pred_spatial_mean)
            global_rmse[norm][scenario] = rmse
            global_corr[norm][scenario] = corr


# ----------------------------
# Create DataFrames
# ----------------------------
def _norm_metric_frame(metric):
    """Build a normalization-by-scenario table from ``metric[norm][scenario]``.

    Rows and columns are aligned by label (``reindex``) before the display
    names are applied, so an entry that is missing for some normalization or
    scenario shows up as NaN instead of shifting the labels.
    """
    return (
        pd.DataFrame(metric).T
        .reindex(index=normalizations, columns=scenarios)
        .rename(index=norm_labels, columns=scenario_labels)
    )


df_spatial_rmse = _norm_metric_frame(spatial_rmse)
df_spatial_r2 = _norm_metric_frame(spatial_r2)
df_global_rmse = _norm_metric_frame(global_rmse)
df_global_corr = _norm_metric_frame(global_corr)

# ----------------------------
# Display Tables
# ----------------------------
_tables = [
    ("TABLE 1: Spatial RMSE (°C) - Mean across time", df_spatial_rmse),
    ("TABLE 2: Spatial R² - Mean across time", df_spatial_r2),
    ("TABLE 3: Global Mean Timeseries RMSE (°C)", df_global_rmse),
    ("TABLE 4: Global Mean Timeseries Correlation", df_global_corr),
]
for table_no, (title, frame) in enumerate(_tables):
    if table_no:
        # Blank separator between consecutive tables (none after the last).
        print("\n")
    print("=" * 80)
    print(title)
    print("=" * 80)
    print(frame.round(4))

TABLE 1: Spatial RMSE (°C) - Mean across time
         Historical  SSP1-2.6  SSP2-4.5  SSP5-8.5
Raw          2.6181    2.7092    2.8388    3.1587
MM-G         4.4511    4.3072    4.3325    4.3183
MM-P         2.1215    2.3391    2.4934    2.8214
ZS-G         2.6656    2.6307    2.7132    2.9547
ZS-P         2.0433    2.2696    2.4441    2.7790
Inst-ZS      3.1534    3.0709    3.0573    3.0370
Inst-MM      4.1309    3.9698    3.9580    3.8915


TABLE 2: Spatial R² - Mean across time
         Historical  SSP1-2.6  SSP2-4.5  SSP5-8.5
Raw          0.9845    0.9828    0.9808    0.9744
MM-G         0.9553    0.9569    0.9560    0.9555
MM-P         0.9897    0.9869    0.9849    0.9793
ZS-G         0.9839    0.9838    0.9825    0.9782
ZS-P         0.9904    0.9876    0.9854    0.9798
Inst-ZS      0.9774    0.9779    0.9778    0.9776
Inst-MM      0.9613    0.9633    0.9631    0.9637


TABLE 3: Global Mean Timeseries RMSE (°C)
         Historical  SSP1-2.6  SSP2-4.5  SSP5-8.5
Raw          0.3088

In [4]:
import numpy as np
import xarray as xr
from pathlib import Path
from scipy.stats import pearsonr
import pandas as pd

# ----------------------------
# Configuration
# ----------------------------
# Directory with the normalization-experiment "<var>_evaluation_<scenario>.nc" files.
results_dir = Path("../evaluation_results")
# Directory with the residual-model "<var>_evaluation_<scenario>.nc" files.
results_residual_dir = Path("../evaluation_results_residual")
# This cell evaluates a single variable.
var = 'tas'
scenarios = ['historical', 'ssp126', 'ssp245', 'ssp585']

# Normalization methods; predictions are looked up in the main results files
# as the variable named "pred_<normalization>".
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global', 
                  'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Row labels for the printed tables, keyed by normalization identifier.
norm_labels = {
    'none': 'None',
    'minmax_global': 'MM-G',
    'minmax_pixel': 'MM-P',
    'zscore_global': 'ZS-G',
    'zscore_pixel': 'ZS-P',
    'instance_zscore': 'Inst-ZS',
    'instance_minmax': 'Inst-MM'
}

# Residual models; predictions are looked up in the residual results files
# as the variable named "pred_<model>".
residual_models = ['raw', 'gma', 'gmt', 'grid']
# Row labels for the printed tables, keyed by residual-model identifier.
residual_labels = {
    'raw': 'LR-Res',
    'gma': 'GMA-Res',
    'gmt': 'GMT-Res',
    'grid': 'PLD-Res'
}

# Column labels for the printed tables, keyed by scenario identifier.
scenario_labels = {
    'historical': 'Historical',
    'ssp126': 'SSP1-2.6',
    'ssp245': 'SSP2-4.5',
    'ssp585': 'SSP5-8.5'
}

# ----------------------------
# Functions
# ----------------------------
def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Calculate spatial RMSE for each time point.

    Args:
        predictions: Array whose first axis is time. All remaining axes are
            pooled with equal weight (no area weighting).
        groundtruth: Reference array with the same shape as ``predictions``.

    Returns:
        1-D array with one RMSE value per entry along the first axis.
    """
    sq_err = (np.asarray(predictions) - np.asarray(groundtruth)) ** 2
    # Reduce over every non-time axis in one call instead of looping over
    # time steps in Python.
    spatial_axes = tuple(range(1, sq_err.ndim))
    return np.sqrt(np.mean(sq_err, axis=spatial_axes))

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Calculate spatial R² for each time point.

    For every time step, R² = 1 - SS_res / SS_tot, where both sums run over
    all non-time axes with equal weight (no area weighting).

    Args:
        predictions: Array whose first axis is time.
        groundtruth: Reference array with the same shape as ``predictions``.

    Returns:
        1-D float array with one R² value per time step. Steps whose total
        sum of squares is not strictly positive (e.g. a constant reference
        field) are reported as 0.
    """
    pred = np.asarray(predictions)
    truth = np.asarray(groundtruth)
    spatial_axes = tuple(range(1, truth.ndim))

    ss_res = np.sum((truth - pred) ** 2, axis=spatial_axes)
    truth_mean = np.mean(truth, axis=spatial_axes, keepdims=True)
    ss_tot = np.sum((truth - truth_mean) ** 2, axis=spatial_axes)

    # Only divide where SS_tot > 0; every other step keeps the 0 placeholder.
    r2 = np.zeros(np.shape(ss_tot), dtype=float)
    valid = ss_tot > 0
    r2[valid] = 1.0 - ss_res[valid] / ss_tot[valid]
    return r2

def compute_spatial_mean(data):
    """Average ``data`` over its 'lat' and 'lon' dimensions.

    No area weights are applied. The result's ``.values`` attribute is
    returned, i.e. the underlying array rather than the labelled object.
    """
    spatial_avg = data.mean(dim=['lat', 'lon'])
    return spatial_avg.values

def compute_global_metrics(y_true, y_pred):
    """Compare two spatial-mean time series.

    Returns:
        ``(rmse, corr)``: root-mean-square error of ``y_true - y_pred`` and
        the Pearson correlation coefficient between the two series.
    """
    corr, _p_value = pearsonr(y_true, y_pred)
    residual = y_true - y_pred
    mean_sq_err = np.mean(residual ** 2)
    return np.sqrt(mean_sq_err), corr

# ----------------------------
# Compute all metrics
# ----------------------------
all_models = list(normalizations) + residual_models
# metric[model][scenario] -> scalar
spatial_rmse = {model: {} for model in all_models}
spatial_r2 = {model: {} for model in all_models}
global_rmse = {model: {} for model in all_models}
global_corr = {model: {} for model in all_models}


def _score_results_file(path, scenario, model_keys):
    """Compute all four metrics for the listed models from one results file.

    Nothing happens when ``path`` does not exist. Each model is read from the
    variable "pred_<model>" and compared against the file's own 'groundtruth';
    models whose variable is absent are skipped. Results are written into the
    module-level metric dictionaries under ``[model][scenario]``.
    """
    if not path.exists():
        return
    # The context manager closes the file even if a metric computation fails.
    with xr.open_dataset(path) as ds:
        gt_data = ds['groundtruth'].values
        gt_spatial_mean = compute_spatial_mean(ds['groundtruth'])

        for model in model_keys:
            pred_key = f'pred_{model}'
            if pred_key not in ds:
                continue

            pred_data = ds[pred_key].values
            pred_spatial_mean = compute_spatial_mean(ds[pred_key])

            # Spatial metrics (mean across time)
            rmse_ts = compute_spatial_rmse_timeseries(pred_data, gt_data)
            r2_ts = compute_spatial_r2_timeseries(pred_data, gt_data)
            spatial_rmse[model][scenario] = np.mean(rmse_ts)
            spatial_r2[model][scenario] = np.mean(r2_ts)

            # Global mean timeseries metrics
            rmse, corr = compute_global_metrics(gt_spatial_mean, pred_spatial_mean)
            global_rmse[model][scenario] = rmse
            global_corr[model][scenario] = corr


for scenario in scenarios:
    # Normalization results
    _score_results_file(
        results_dir / f"{var}_evaluation_{scenario}.nc", scenario, normalizations
    )
    # Residual model results
    _score_results_file(
        results_residual_dir / f"{var}_evaluation_{scenario}.nc", scenario, residual_models
    )

# ----------------------------
# Create DataFrames with combined models
# ----------------------------
# Combine labels
all_labels = {**norm_labels, **residual_labels}


def _model_metric_frame(metric):
    """Build a model-by-scenario table from ``metric[model][scenario]``.

    Only models with at least one result get a row. Rows and columns are
    selected by label (``reindex``) before the display names are applied, so
    a model without results or a missing scenario file cannot cause a length
    mismatch; missing entries show up as NaN.
    """
    present = [m for m in all_models if metric[m]]
    return (
        pd.DataFrame(metric).T
        .reindex(index=present, columns=scenarios)
        .rename(index=all_labels, columns=scenario_labels)
    )


df_spatial_rmse = _model_metric_frame(spatial_rmse)
df_spatial_r2 = _model_metric_frame(spatial_r2)
df_global_rmse = _model_metric_frame(global_rmse)
df_global_corr = _model_metric_frame(global_corr)

# ----------------------------
# Display Tables
# ----------------------------
_tables = [
    ("TABLE 1: Spatial RMSE (°C) - Mean across time", df_spatial_rmse),
    ("TABLE 2: Spatial R² - Mean across time", df_spatial_r2),
    ("TABLE 3: Global Mean Timeseries RMSE (°C)", df_global_rmse),
    ("TABLE 4: Global Mean Timeseries Correlation", df_global_corr),
]
for table_no, (title, frame) in enumerate(_tables):
    if table_no:
        # Blank separator between consecutive tables (none after the last).
        print("\n")
    print("=" * 80)
    print(title)
    print("=" * 80)
    print(frame.round(4))

TABLE 1: Spatial RMSE (°C) - Mean across time
         Historical  SSP1-2.6  SSP2-4.5  SSP5-8.5
None         2.6181    2.7092    2.8388    3.1587
MM-G         4.4511    4.3072    4.3325    4.3183
MM-P         2.1215    2.3391    2.4934    2.8214
ZS-G         2.6656    2.6307    2.7132    2.9547
ZS-P         2.0433    2.2696    2.4441    2.7790
Inst-ZS      3.1534    3.0709    3.0573    3.0370
Inst-MM      4.1309    3.9698    3.9580    3.8915
LR-Res       1.9032    1.9535    1.9869    2.0281
GMA-Res      1.9250    2.0612    2.1028    2.1756
GMT-Res      1.9004    1.9905    2.0337    2.1038
PLD-Res      1.9007    1.8924    1.9080    1.9336


TABLE 2: Spatial R² - Mean across time
         Historical  SSP1-2.6  SSP2-4.5  SSP5-8.5
None         0.9845    0.9828    0.9808    0.9744
MM-G         0.9553    0.9569    0.9560    0.9555
MM-P         0.9897    0.9869    0.9849    0.9793
ZS-G         0.9839    0.9838    0.9825    0.9782
ZS-P         0.9904    0.9876    0.9854    0.9798
Inst-ZS      