In [1]:
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.stats import pearsonr

# ----------------------------
# Configuration
# ----------------------------
# Directories holding the evaluation NetCDF files (relative to the notebook).
results_dir = Path("../evaluation_results")
results_residual_dir = Path("../evaluation_results_residual")

# Variable name and the scenarios that get one table row each.
var = 'tas'
scenarios = ['historical', 'ssp126', 'ssp245', 'ssp585', 'g6sulfur']

# Model key -> short column label, grouped by family.
_normalization_labels = {
    'none': 'None',
    'minmax_global': 'MM-G',
    'minmax_pixel': 'MM-P',
    'zscore_global': 'ZS-G',
    'zscore_pixel': 'ZS-P',
    'instance_zscore': 'Inst-ZS',
    'instance_minmax': 'Inst-MM',
}
_residual_labels = {
    'raw_res': 'Raw-Res',
    'gma_res': 'GMA-Res',
    'gmt_res': 'GMT-Res',
    'pld_res': 'PLD-Res',
}
_baseline_labels = {
    'qdm': 'QDM',
    'bilinear': 'Bilinear',
}

# All models; insertion order fixes the column order of the printed tables.
models = {**_normalization_labels, **_residual_labels, **_baseline_labels}

# ----------------------------
# Functions
# ----------------------------
def compute_spatial_mean(data):
    """Average ``data`` over its 'lat' and 'lon' dimensions.

    ``data`` must offer ``.mean(dim=...)`` and the result must expose
    ``.values`` (in this script it is a variable taken from an xarray
    dataset). Returns that ``.values`` attribute of the averaged object.
    """
    spatial_dims = ['lat', 'lon']
    averaged = data.mean(dim=spatial_dims)
    return averaged.values

def compute_temporal_metrics(y_true, y_pred):
    """RMSE and Pearson correlation between two equally long series.

    In this script the inputs are the spatial-mean time series of the
    ground truth and of a prediction.

    Parameters
    ----------
    y_true, y_pred : array-like
        Reference and predicted values. Anything ``np.asarray`` accepts
        (NumPy arrays, lists, tuples) is allowed.

    Returns
    -------
    tuple
        ``(rmse, corr)`` where ``rmse`` is the root-mean-square difference
        and ``corr`` is the first value returned by ``scipy.stats.pearsonr``.
    """
    # Coerce up front so plain Python sequences work in the subtraction below.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    corr, _ = pearsonr(y_true, y_pred)
    return rmse, corr

def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Return one RMSE per index along the first axis of ``predictions``.

    For every step ``t`` the slices ``predictions[t]`` and ``groundtruth[t]``
    are flattened and their root-mean-square difference is computed. The
    results are returned as a 1-D NumPy array with ``predictions.shape[0]``
    entries.
    """
    def _rmse_at(step):
        # Flatten both fields so the mean runs over every grid cell.
        diff = predictions[step].flatten() - groundtruth[step].flatten()
        return np.sqrt(np.mean(diff ** 2))

    return np.array([_rmse_at(step) for step in range(predictions.shape[0])])

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Calculate the spatial R² for each index along the first axis.

    For every step ``t`` the slices ``predictions[t]`` and ``groundtruth[t]``
    are flattened and ``R² = 1 - SS_res / SS_tot`` is computed, where
    ``SS_tot`` is the sum of squared deviations of the ground truth from its
    own mean.

    Special cases per time step:
    * ``SS_tot == 0`` (constant ground truth): R² is reported as ``0.0``.
    * NaN in either field: R² is NaN. A NaN ``SS_tot`` fails the ``> 0`` test,
      so without this check the step would be reported as a valid 0.

    Returns a 1-D NumPy array with ``predictions.shape[0]`` entries.
    """
    n_time = predictions.shape[0]
    r2_values = []
    for t in range(n_time):
        pred_flat = predictions[t].flatten()
        gt_flat = groundtruth[t].flatten()
        ss_res = np.sum((gt_flat - pred_flat) ** 2)
        ss_tot = np.sum((gt_flat - np.mean(gt_flat)) ** 2)
        if np.isnan(ss_res) or np.isnan(ss_tot):
            # Propagate missing data instead of masking it as R² = 0.
            r2 = np.nan
        elif ss_tot > 0:
            r2 = 1 - (ss_res / ss_tot)
        else:
            # No variance in the ground truth: R² is undefined, report 0.0.
            r2 = 0.0
        r2_values.append(r2)
    return np.array(r2_values)

# ----------------------------
# Load data and compute metrics
# ----------------------------
# Temporal metrics: scenario -> model key -> (rmse, corr)
temporal_metrics = {scen: {} for scen in scenarios}
# Spatial metrics: scenario -> model key -> (rmse_array, r2_array)
spatial_metrics = {scen: {} for scen in scenarios}

# Normalization models stored in the main results file as 'pred_<name>'.
norm_models = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global',
               'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Model key -> dataset variable in the main results file.
# 'bilinear' is read from the 'input' variable.
main_mapping = {'bilinear': 'input'}
main_mapping.update({norm: f'pred_{norm}' for norm in norm_models})

# Model key -> dataset variable in the QDM results file.
qdm_mapping = {'qdm': 'pred_qdm'}

# Model key -> dataset variable in the residual results file.
res_mapping = {
    'raw_res': 'pred_raw',
    'gma_res': 'pred_gma',
    'gmt_res': 'pred_gmt',
    'pld_res': 'pred_grid',
}


def _record_dataset_metrics(scenario, path, var_by_model):
    """Compute and store metrics for every mapped variable found in ``path``.

    Does nothing when ``path`` does not exist. Otherwise opens the file,
    reads its 'groundtruth' variable and, for each ``model_key -> variable``
    pair whose variable is in the dataset, writes the temporal and spatial
    metrics into ``temporal_metrics[scenario]`` / ``spatial_metrics[scenario]``.
    Variables missing from the dataset are skipped.
    """
    if not path.exists():
        return
    # The context manager closes the file once all arrays have been read.
    with xr.open_dataset(path) as ds:
        gt = ds['groundtruth']
        gt_values = gt.values
        gt_mean = compute_spatial_mean(gt)

        for model_key, var_name in var_by_model.items():
            if var_name not in ds:
                continue
            pred_data = ds[var_name]
            pred_values = pred_data.values
            pred_mean = compute_spatial_mean(pred_data)
            temporal_metrics[scenario][model_key] = compute_temporal_metrics(
                gt_mean, pred_mean
            )
            spatial_metrics[scenario][model_key] = (
                compute_spatial_rmse_timeseries(pred_values, gt_values),
                compute_spatial_r2_timeseries(pred_values, gt_values),
            )


for scenario in scenarios:
    # Main results (normalization models + bilinear input)
    _record_dataset_metrics(
        scenario, results_dir / f"{var}_evaluation_{scenario}.nc", main_mapping
    )
    # QDM results
    _record_dataset_metrics(
        scenario, results_dir / f"{var}_evaluation_{scenario}_qdm.nc", qdm_mapping
    )
    # Residual results
    _record_dataset_metrics(
        scenario, results_residual_dir / f"{var}_evaluation_{scenario}.nc", res_mapping
    )

# ----------------------------
# Print Tables
# ----------------------------
model_keys = list(models)
model_names = [models[key] for key in model_keys]

# Layout shared by the two temporal tables
col_width = 12
header = f"{'Scenario':<12} " + " ".join(f"{label:>{col_width}}" for label in model_names)
separator = "-" * len(header)

# Wider layout shared by the two spatial tables
col_width_spatial = 22
header_spatial = f"{'Scenario':<12} " + " ".join(
    f"{label:>{col_width_spatial}}" for label in model_names
)
separator_spatial = "-" * len(header_spatial)


def _print_table(title, head, rule, width, metrics, render):
    """Print one table: banner, header, rule, then one row per scenario.

    ``render`` turns a stored metric entry into the cell text; models with
    no entry for a scenario are shown as 'N/A'. Every cell is right-aligned
    to ``width`` and followed by a single space.
    """
    banner = "=" * len(head)
    print("\n" + banner)
    print(title)
    print(banner)
    print(head)
    print(rule)
    for scen in scenarios:
        per_model = metrics[scen]
        cells = []
        for key in model_keys:
            text = render(per_model[key]) if key in per_model else 'N/A'
            cells.append(f"{text:>{width}} ")
        print(f"{scen:<12} " + "".join(cells))


def _mean_with_percentiles(values, digits):
    """Format ``values`` as 'mean [5th, 95th]' with ``digits`` decimals."""
    centre = np.mean(values)
    low = np.percentile(values, 5)
    high = np.percentile(values, 95)
    return f"{centre:.{digits}f} [{low:.{digits}f}, {high:.{digits}f}]"


# Table 1: RMSE - Temporal Mean
_print_table(
    "Table 1: RMSE (°C) - Temporal Mean",
    header, separator, col_width, temporal_metrics,
    lambda entry: f"{entry[0]:.4f}",
)

# Table 2: Correlation - Temporal Mean
_print_table(
    "Table 2: Correlation - Temporal Mean",
    header, separator, col_width, temporal_metrics,
    lambda entry: f"{entry[1]:.4f}",
)

# Table 3: RMSE - Spatial (mean [5th, 95th percentile])
_print_table(
    "Table 3: RMSE (°C) - Spatial (mean [5th, 95th])",
    header_spatial, separator_spatial, col_width_spatial, spatial_metrics,
    lambda entry: _mean_with_percentiles(entry[0], 2),
)

# Table 4: R² - Spatial (mean [5th, 95th percentile])
_print_table(
    "Table 4: R² - Spatial (mean [5th, 95th])",
    header_spatial, separator_spatial, col_width_spatial, spatial_metrics,
    lambda entry: _mean_with_percentiles(entry[1], 4),
)

footer_rule = "=" * 50
print("\n" + footer_rule)
print(f"Total models compared: {len(models)}")
print(footer_rule)

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (



Table 1: RMSE (°C) - Temporal Mean
Scenario             None         MM-G         MM-P         ZS-G         ZS-P      Inst-ZS      Inst-MM      Raw-Res      GMA-Res      GMT-Res      PLD-Res          QDM     Bilinear
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
historical         0.3088       1.2254       0.5159       0.3061       0.5075       0.7443       1.1327       0.2779       0.3331       0.2887       0.2965       0.3056       0.3941 
ssp126             0.8384       1.0173       0.9116       0.5086       0.8816       0.6094       1.0525       0.3421       0.5226       0.3616       0.3417       0.3183       0.4674 
ssp245             1.1747       0.9952       1.1130       0.7956       1.0831       0.6429       1.1308       0.3407       0.5127       0.3543       0.3234       0.3280       0.4213 
ssp585             1.8075       0.8753       1.5493