In [2]:
import xarray as xr
import numpy as np
from pathlib import Path
from scipy.stats import pearsonr

# ----------------------------
# Configuration
# ----------------------------
# Directories that hold the evaluation NetCDF files. The paths are relative,
# so they resolve against the current working directory.
results_dir = Path("../evaluation_results")
results_residual_dir = Path("../evaluation_results_residual")

# Variables and scenarios to evaluate; both appear in the result file names
# (``<var>_evaluation_<scenario>[...].nc``).
variables = ["tas", "pr"]
scenarios = [
    "historical",
    "ssp126",
    "ssp245",
    "ssp585",
    "g6sulfur",
]

# Models included in the comparison: internal key -> column label used in
# the printed tables. Insertion order fixes the column order.
models = dict(
    bilinear="Bilinear",
    qdm="QDM",
    mld="MLd",
    mls="MLs",
)

# Unit label shown next to each variable in headings and table titles.
var_units = {"tas": "°C", "pr": "mm/day"}

# ----------------------------
# Functions
# ----------------------------
def compute_spatial_mean(data):
    """Average ``data`` over its ``lat`` and ``lon`` dimensions.

    ``data`` must support ``.mean(dim=[...])`` with named dimensions and the
    result must expose ``.values`` (presumably an ``xarray.DataArray`` --
    confirm against callers). The ``.values`` of the reduced object is
    returned.
    """
    horizontal_dims = ['lat', 'lon']
    reduced = data.mean(dim=horizontal_dims)
    return reduced.values

def compute_temporal_metrics(y_true, y_pred):
    """Return ``(rmse, corr)`` between two equally long 1-D series.

    ``rmse`` is the root of the mean squared element-wise difference and
    ``corr`` is the Pearson correlation coefficient from
    ``scipy.stats.pearsonr`` (its p-value is discarded).
    """
    difference = y_true - y_pred
    rmse = np.sqrt(np.mean(difference**2))
    # pearsonr yields (statistic, p-value); only the statistic is reported.
    corr = pearsonr(y_true, y_pred)[0]
    return rmse, corr

def compute_spatial_rmse_timeseries(predictions, groundtruth):
    """Return one RMSE value per index along the first axis.

    For every index ``t`` in ``range(predictions.shape[0])`` the slices
    ``predictions[t]`` and ``groundtruth[t]`` are flattened and the root
    mean squared difference between them is computed. The results are
    returned as a 1-D NumPy array with ``predictions.shape[0]`` entries.
    """
    def _rmse_at(step):
        error = predictions[step].flatten() - groundtruth[step].flatten()
        return np.sqrt(np.mean(error**2))

    return np.array([_rmse_at(step) for step in range(predictions.shape[0])])

def compute_spatial_r2_timeseries(predictions, groundtruth):
    """Return one R² value per index along the first axis.

    For every index ``t`` in ``range(predictions.shape[0])`` the slices
    ``predictions[t]`` and ``groundtruth[t]`` are flattened and
    ``1 - SS_res / SS_tot`` is computed, where ``SS_tot`` is the sum of
    squared deviations of the ground truth from its own mean. Whenever
    ``SS_tot > 0`` is false (zero variance, or a NaN total) the score for
    that step is ``0``.
    """
    def _r2_at(step):
        estimate = predictions[step].flatten()
        truth = groundtruth[step].flatten()
        residual_ss = np.sum((truth - estimate)**2)
        total_ss = np.sum((truth - np.mean(truth))**2)
        if total_ss > 0:
            return 1 - (residual_ss / total_ss)
        # Undefined R² (no variance in the ground truth) is reported as 0.
        return 0

    return np.array([_r2_at(step) for step in range(predictions.shape[0])])

# ----------------------------
# Process each variable
# ----------------------------
def _collect_metrics(dataset, source_to_model, temporal_out, spatial_out):
    """Score every available prediction field of one opened result file.

    ``dataset`` must contain a ``'groundtruth'`` variable. For each
    ``dataset variable name -> model key`` pair in ``source_to_model`` whose
    variable is present, two entries are stored under the model key:

    * ``temporal_out``: ``(rmse, corr)`` between the lat/lon-mean series of
      ground truth and prediction.
    * ``spatial_out``: ``(rmse_per_step, r2_per_step)`` arrays.

    Variables missing from the file are skipped silently, so the tables can
    show ``N/A`` for them.
    """
    truth = dataset['groundtruth']
    truth_values = truth.values
    truth_mean = compute_spatial_mean(truth)

    for source_name, model_key in source_to_model.items():
        if source_name not in dataset:
            continue
        field = dataset[source_name]
        field_values = field.values
        field_mean = compute_spatial_mean(field)
        temporal_out[model_key] = compute_temporal_metrics(truth_mean, field_mean)
        spatial_out[model_key] = (
            compute_spatial_rmse_timeseries(field_values, truth_values),
            compute_spatial_r2_timeseries(field_values, truth_values),
        )


def _format_spread(values, digits):
    """Format ``values`` as ``mean [5th percentile, 95th percentile]``."""
    mean_val = np.mean(values)
    p5 = np.percentile(values, 5)
    p95 = np.percentile(values, 95)
    return f"{mean_val:.{digits}f} [{p5:.{digits}f}, {p95:.{digits}f}]"


def _print_table(title, metrics, width, format_entry):
    """Print one scenario-by-model table.

    ``metrics`` maps ``scenario -> {model key: entry}``. ``format_entry``
    turns an entry into the cell text, which is right-aligned to ``width``.
    Models without an entry for a scenario are shown as ``N/A``.
    """
    header = f"{'Scenario':<12} " + " ".join(
        f"{label:>{width}}" for label in models.values()
    )
    rule = "=" * len(header)

    print("\n" + rule)
    print(title)
    print(rule)
    print(header)
    print("-" * len(header))

    for scenario in scenarios:
        scored = metrics[scenario]
        row = f"{scenario:<12} "
        for model_key in models:
            cell = format_entry(scored[model_key]) if model_key in scored else 'N/A'
            row += f"{cell:>{width}} "
        print(row)


# ----------------------------
# Process each variable
# ----------------------------
for var in variables:
    print("\n" + "=" * 80)
    print(f"VARIABLE: {var.upper()} ({var_units[var]})")
    print("=" * 80)

    # Storage for metrics: scenario -> {model key: metric tuple}
    temporal_metrics = {scen: {} for scen in scenarios}
    spatial_metrics = {scen: {} for scen in scenarios}

    # ----------------------------
    # Load data and compute metrics
    # ----------------------------
    for scenario in scenarios:
        # Each result file paired with the dataset variables it provides and
        # the model key each one is reported under.
        sources = [
            # Main results: bilinear baseline ('input') and MLd
            (
                results_dir / f"{var}_evaluation_{scenario}.nc",
                {'input': 'bilinear', 'pred_zscore_pixel': 'mld'},
            ),
            # QDM results
            (
                results_dir / f"{var}_evaluation_{scenario}_qdm.nc",
                {'pred_qdm': 'qdm'},
            ),
            # Residual results (MLs)
            (
                results_residual_dir / f"{var}_evaluation_{scenario}.nc",
                {'pred_grid': 'mls'},
            ),
        ]

        for path, source_to_model in sources:
            if not path.exists():
                continue
            # The context manager closes the file even if a metric
            # computation raises.
            with xr.open_dataset(path) as ds:
                _collect_metrics(
                    ds,
                    source_to_model,
                    temporal_metrics[scenario],
                    spatial_metrics[scenario],
                )

    # ----------------------------
    # Print Tables for this variable
    # ----------------------------
    col_width = 12
    col_width_spatial = 22

    # Tables 1-2: metrics of the spatial-mean time series.
    _print_table(
        f"Table 1: RMSE ({var_units[var]}) - Temporal Mean",
        temporal_metrics,
        col_width,
        lambda entry: f"{entry[0]:.4f}",
    )
    _print_table(
        "Table 2: Correlation - Temporal Mean",
        temporal_metrics,
        col_width,
        lambda entry: f"{entry[1]:.4f}",
    )

    # Tables 3-4: per-time-step spatial metrics, summarised as
    # mean [5th, 95th percentile].
    _print_table(
        f"Table 3: RMSE ({var_units[var]}) - Spatial (mean [5th, 95th])",
        spatial_metrics,
        col_width_spatial,
        lambda entry: _format_spread(entry[0], 2),
    )
    _print_table(
        "Table 4: R² - Spatial (mean [5th, 95th])",
        spatial_metrics,
        col_width_spatial,
        lambda entry: _format_spread(entry[1], 4),
    )

# Closing summary banner: model count and the variables that were processed.
summary_rule = "=" * 80
print(
    "\n" + summary_rule,
    f"Total models compared: {len(models)}",
    f"Variables: {', '.join(variables)}",
    summary_rule,
    sep="\n",
)


VARIABLE: TAS (°C)

Table 1: RMSE (°C) - Temporal Mean
Scenario         Bilinear          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.3941       0.3056       0.5075       0.2965 
ssp126             0.4674       0.3183       0.8816       0.3417 
ssp245             0.4213       0.3280       1.0831       0.3234 
ssp585             0.4154       0.3184       1.5299       0.3229 
g6sulfur           0.5133       0.7069       0.9731       0.4908 

Table 2: Correlation - Temporal Mean
Scenario         Bilinear          QDM          MLd          MLs
----------------------------------------------------------------
historical         0.9836       0.9838       0.9824       0.9859 
ssp126             0.9784       0.9796       0.9711       0.9853 
ssp245             0.9788       0.9794       0.9573       0.9844 
ssp585             0.9857       0.9860       0.9369       0.9890 
g6sulfur           0.9662       0.9656       0.9372 