In [1]:
import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet
import time

# ============================
# Notebook configuration
# ============================

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Input data and model checkpoints, relative to the working directory.
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Climate variables for which models are evaluated.
variables = ['tas', 'pr']

# One direct model is looked up per (variable, normalization method).
normalizations = [
    'none',
    'minmax_global',
    'minmax_pixel',
    'zscore_global',
    'zscore_pixel',
    'instance_zscore',
    'instance_minmax',
]

# One residual model is looked up per (variable, input type).
residual_input_types = [
    'raw',
    'gma',
    'grid',
    'gmt',
]

# Inclusive (start, end) year labels used to slice the time axis per scenario.
test_periods = {
    'historical': ('2001', '2014'),
    'ssp126': ('2015', '2100'),
    'ssp245': ('2015', '2100'),
    'ssp585': ('2015', '2100'),
    'g6sulfur': ('2020', '2099'),
}

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


Using device: cuda


In [2]:
print("Loading datasets...")
load_start = time.time()

# Original datasets (inputs for the direct models and the ground truth),
# keyed by scenario name in the same order as the scenario list below.
datasets_original = {
    scenario: xr.open_dataset(data_dir / filename)
    for scenario, filename in (
        ('historical', "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc"),
        ('ssp126', "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc"),
        ('ssp245', "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc"),
        ('ssp585', "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc"),
        ('g6sulfur', "MPI-ESM1-2-HR-LR_g6sulfur_r1i1p1f1_2020_2099_allvars.nc"),
    )
}

# Residual/detrended datasets (inputs for the residual models).
datasets_residual = {
    scenario: xr.open_dataset(data_dir / filename)
    for scenario, filename in (
        ('historical', "MPI-ESM1-2-HR-LR_historical_residual_detrended.nc"),
        ('ssp126', "MPI-ESM1-2-HR-LR_ssp126_residual_detrended.nc"),
        ('ssp245', "MPI-ESM1-2-HR-LR_ssp245_residual_detrended.nc"),
        ('ssp585', "MPI-ESM1-2-HR-LR_ssp585_residual_detrended.nc"),
        ('g6sulfur', "MPI-ESM1-2-HR-LR_g6sulfur_residual_detrended.nc"),
    )
}

# Normalization statistics are stored as pickles next to the data.
# NOTE(review): unpickling executes code embedded in the file; only load
# statistics files that were produced by a trusted pipeline.
print("Loading normalization statistics...")
norm_stats_direct = pickle.loads((data_dir / "norm_stats.pkl").read_bytes())
norm_stats_residual = pickle.loads(
    (data_dir / "norm_stats_zscore_pixel_residual_detrended.pkl").read_bytes()
)

load_time = time.time() - load_start
print(f"Direct norm stats loaded for: {list(norm_stats_direct.keys())}")
print(f"Residual norm stats loaded for: {list(norm_stats_residual.keys())}")
print(f"\nDatasets loaded in {load_time:.2f} seconds")

Loading datasets...
Loading normalization statistics...
Direct norm stats loaded for: ['pr', 'tas', 'hurs', 'sfcWind']
Residual norm stats loaded for: ['pr', 'tas']

Datasets loaded in 0.86 seconds


In [3]:
# Cell 3: Helper Functions

def apply_normalization(data, method, norm_stats, var, resolution='lr_interp'):
    """Normalize ``data`` with the requested ``method``.

    Args:
        data: Array to normalize, or any object exposing the array as
            ``.values`` (the attribute is unwrapped when present).
        method: One of 'none', 'minmax_global', 'minmax_pixel',
            'zscore_global', 'zscore_pixel', 'instance_zscore',
            'instance_minmax'.
        norm_stats: Nested mapping read as ``norm_stats[var][resolution][key]``;
            only consulted by the global/pixel methods.
        var: Variable name used as the first key into ``norm_stats``.
        resolution: Second key into ``norm_stats``.

    Returns:
        The unwrapped array itself for 'none', otherwise a new normalized
        array. Min-max methods map [min, max] onto roughly [-1, 1].

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    values = data.values if hasattr(data, 'values') else data

    def rescale(lo, hi):
        # Linear map of [lo, hi] to [-1, 1]; the epsilon guards hi == lo.
        return 2 * (values - lo) / (hi - lo + 1e-8) - 1

    def standardize(center, spread):
        # Z-score; the epsilon guards a zero standard deviation.
        return (values - center) / (spread + 1e-8)

    if method == 'none':
        return values

    # Methods driven by precomputed statistics.
    if method in ('minmax_global', 'minmax_pixel', 'zscore_global', 'zscore_pixel'):
        stats = norm_stats[var][resolution]
        if method == 'minmax_global':
            return rescale(stats['global_min'], stats['global_max'])
        if method == 'minmax_pixel':
            return rescale(stats['pixel_min'], stats['pixel_max'])
        if method == 'zscore_global':
            return standardize(stats['global_mean'], stats['global_std'])
        return standardize(stats['pixel_mean'], stats['pixel_std'])

    # Per-sample methods: statistics are reduced over axes 1 and 2 of the
    # input, separately for every entry along axis 0.
    if method == 'instance_zscore':
        return standardize(
            np.mean(values, axis=(1, 2), keepdims=True),
            np.std(values, axis=(1, 2), keepdims=True),
        )
    if method == 'instance_minmax':
        return rescale(
            np.min(values, axis=(1, 2), keepdims=True),
            np.max(values, axis=(1, 2), keepdims=True),
        )

    raise ValueError(f"Unknown normalization method: {method}")


def denormalize_predictions(predictions, method, norm_stats, var, input_data=None):
    """Map model outputs from normalized space back to the original scale.

    Args:
        predictions: Array of normalized model outputs.
        method: Normalization method the model was trained with; one of
            'none', 'minmax_global', 'minmax_pixel', 'zscore_global',
            'zscore_pixel', 'instance_zscore', 'instance_minmax'.
        norm_stats: Nested mapping read as ``norm_stats[var]['hr'][key]``;
            only consulted by the global/pixel methods.
        var: Variable name used as the first key into ``norm_stats``.
        input_data: Un-normalized input array. Required for the two
            ``instance_*`` methods, whose statistics are recomputed from it
            over axes 1 and 2; ignored by every other method.

    Returns:
        ``predictions`` unchanged for 'none', otherwise a rescaled array.

    Raises:
        ValueError: If ``method`` is unknown, or if an ``instance_*`` method
            is requested without ``input_data``.
    """
    if method == 'none':
        return predictions
    elif method == 'minmax_global':
        hr_min = norm_stats[var]['hr']['global_min']
        hr_max = norm_stats[var]['hr']['global_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    elif method == 'minmax_pixel':
        hr_min = norm_stats[var]['hr']['pixel_min']
        hr_max = norm_stats[var]['hr']['pixel_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    elif method == 'zscore_global':
        hr_mean = norm_stats[var]['hr']['global_mean']
        hr_std = norm_stats[var]['hr']['global_std']
        return predictions * hr_std + hr_mean
    elif method == 'zscore_pixel':
        hr_mean = norm_stats[var]['hr']['pixel_mean']
        hr_std = norm_stats[var]['hr']['pixel_std']
        return predictions * hr_std + hr_mean
    elif method in ('instance_zscore', 'instance_minmax'):
        # The per-sample statistics only exist in the raw input; without it
        # NumPy would fail with an unrelated axis/type error, so fail early
        # with a message that names the missing argument.
        if input_data is None:
            raise ValueError(
                f"input_data is required to denormalize with method '{method}'"
            )
        if method == 'instance_zscore':
            input_mean = np.mean(input_data, axis=(1, 2), keepdims=True)
            input_std = np.std(input_data, axis=(1, 2), keepdims=True)
            return predictions * input_std + input_mean
        input_min = np.min(input_data, axis=(1, 2), keepdims=True)
        input_max = np.max(input_data, axis=(1, 2), keepdims=True)
        return ((predictions + 1) / 2) * (input_max - input_min) + input_min
    else:
        raise ValueError(f"Unknown normalization method: {method}")


def load_model(model_path):
    """Build a UNet, restore its weights from ``model_path`` and return it.

    The checkpoint is expected to be a mapping with a 'model_state_dict'
    entry. The returned network is placed on the notebook-level ``device``
    and switched to evaluation mode.
    """
    net = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's defaults; only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=device)['model_state_dict']
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net


def predict_batched(model, data, batch_size=32):
    """Run ``model`` over ``data`` in mini-batches without tracking gradients.

    Each batch ``data[i:i+batch_size]`` is converted to a float32 tensor, gets
    a singleton axis inserted at position 1, and is moved to the device the
    model's parameters live on (CPU when the model exposes no parameters), so
    the function does not depend on a notebook-level ``device`` variable.

    Args:
        model: Callable taking the batch tensor and returning a tensor whose
            axis 1 has size 1.
        data: Sliceable, sized collection of samples (e.g. a NumPy array
            indexed along its first axis).
        batch_size: Number of samples per forward pass; must be >= 1.

    Returns:
        NumPy array with the outputs of all batches concatenated along
        axis 0 and axis 1 squeezed away.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or ``data`` is empty
            (for example because a time slice selected no samples).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("predict_batched received no samples to predict on")

    # Feed inputs to wherever the weights are instead of a global device.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = torch.device('cpu')

    predictions = []
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            batch = data[start:start + batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(target_device)
            batch_pred = model(batch_tensor)
            predictions.append(batch_pred.cpu().numpy())

    return np.concatenate(predictions, axis=0).squeeze(1)


print("Helper functions defined.")

Helper functions defined.


In [4]:
# Cell 4: Evaluation Functions

def evaluate_direct_models(var, scenario_name, dataset, test_period, timing_results):
    """Run every available direct model for one variable and scenario.

    For each entry of the notebook-level ``normalizations`` list the
    checkpoint ``ckpt_dir / f"{var}_{norm_method}.pth"`` is looked up; missing
    checkpoints are reported and skipped. Available models are applied to the
    normalized ``{var}_lr_interp`` field and their outputs denormalized.

    Args:
        var: Variable name used to build dataset variable and checkpoint names.
        scenario_name: Key under which timings are recorded.
        dataset: Dataset providing ``{var}_lr_interp`` and ``{var}_hr``.
        test_period: ``(start, end)`` labels used to slice the time axis.
        timing_results: Nested dict updated in place at
            ``['direct'][var][scenario_name][norm_method]`` with the seconds
            spent per model (measured from just before the checkpoint load).

    Returns:
        Dict with 'groundtruth', 'input' and one ``pred_<norm_method>``
        DataArray per successfully evaluated model. A failure of a single
        model is printed and does not abort the remaining ones.
    """
    period = slice(test_period[0], test_period[1])
    lr_field = dataset[f'{var}_lr_interp'].sel(time=period)
    hr_field = dataset[f'{var}_hr'].sel(time=period)

    # Raw input values are needed again when undoing instance normalization.
    lr_values = lr_field.values

    # Predictions are labelled with the coordinates of the target field.
    target_coords = {
        'time': hr_field.time,
        'lat': hr_field.lat,
        'lon': hr_field.lon,
    }

    results = {'groundtruth': hr_field, 'input': lr_field}

    for norm_method in normalizations:
        checkpoint = ckpt_dir / f"{var}_{norm_method}.pth"
        if not checkpoint.exists():
            print(f"        Model not found: {checkpoint}")
            continue

        try:
            started = time.time()

            network = load_model(checkpoint)
            normalized_input = apply_normalization(lr_field, norm_method, norm_stats_direct, var)
            raw_output = predict_batched(network, normalized_input)
            restored = denormalize_predictions(
                raw_output, norm_method, norm_stats_direct, var, lr_values
            )

            results[f'pred_{norm_method}'] = xr.DataArray(
                restored,
                coords=target_coords,
                dims=['time', 'lat', 'lon'],
            )

            elapsed = time.time() - started
            timing_results['direct'][var][scenario_name][norm_method] = elapsed
            print(f"        {norm_method}: {elapsed:.2f}s")
        except Exception as e:
            # Best effort: report the failure and move on to the next model.
            print(f"        {norm_method}: Error - {e}")

    return results


def evaluate_residual_models(var, scenario_name, dataset_residual, dataset_original, test_period, timing_results):
    """Run every available residual model for one variable and scenario.

    Each model takes a z-score (per-pixel) normalized input field from
    ``dataset_residual`` and outputs a normalized residual. The residual is
    denormalized with the 'residual' statistics and added to the original
    ``{var}_lr_interp`` field to form the final prediction.

    Args:
        var: Variable name used to build dataset variable and checkpoint names.
        scenario_name: Key under which timings are recorded.
        dataset_residual: Dataset providing the model input fields.
        dataset_original: Dataset providing ``{var}_lr_interp`` and ``{var}_hr``.
        test_period: ``(start, end)`` labels used to slice the time axis.
        timing_results: Nested dict updated in place at
            ``['residual'][var][scenario_name][input_type]``.

    Returns:
        Dict with 'groundtruth', 'input' and one ``pred_<input_type>``
        DataArray per successfully evaluated model. A failure of a single
        model is printed and does not abort the remaining ones.
    """
    period = slice(test_period[0], test_period[1])
    lr_original = dataset_original[f'{var}_lr_interp'].sel(time=period)
    hr_original = dataset_original[f'{var}_hr'].sel(time=period)

    # Predictions are labelled with the coordinates of the target field.
    target_coords = {
        'time': hr_original.time,
        'lat': hr_original.lat,
        'lon': hr_original.lon,
    }

    results = {'groundtruth': hr_original, 'input': lr_original}

    # Baseline that the predicted residual is added back onto.
    lr_baseline = lr_original.values

    # Suffix of the detrended input variable for each non-raw input type.
    detrend_map = {'gma': 'detrend_gma', 'grid': 'detrend_grid', 'gmt': 'detrend_gmt'}

    for input_type in residual_input_types:
        if input_type == 'raw':
            input_var = f"{var}_lr_interp"
            input_key = 'lr_interp'
            model_filename = f"{var}_lr_to_residual.pth"
        else:
            detrend_name = detrend_map[input_type]
            input_var = f"{var}_lr_{detrend_name}"
            input_key = f'lr_{detrend_name}'
            model_filename = f"{var}_lr_{input_type}_to_residual.pth"

        checkpoint = ckpt_dir / model_filename
        if not checkpoint.exists():
            print(f"        Model not found: {checkpoint}")
            continue

        try:
            started = time.time()

            network = load_model(checkpoint)

            model_input = dataset_residual[input_var].sel(time=period).values

            # Per-pixel z-score of the input field.
            input_stats = norm_stats_residual[var][input_key]
            input_mean = input_stats['pixel_mean']
            input_std = input_stats['pixel_std']
            model_input_norm = (model_input - input_mean) / (input_std + 1e-8)

            output_norm = predict_batched(network, model_input_norm)

            # Undo the per-pixel z-score of the residual target.
            residual_mean = norm_stats_residual[var]['residual']['pixel_mean']
            residual_std = norm_stats_residual[var]['residual']['pixel_std']
            residual_pred = output_norm * residual_std + residual_mean

            # Final prediction = predicted residual + original input field.
            results[f'pred_{input_type}'] = xr.DataArray(
                residual_pred + lr_baseline,
                coords=target_coords,
                dims=['time', 'lat', 'lon'],
            )

            elapsed = time.time() - started
            timing_results['residual'][var][scenario_name][input_type] = elapsed
            print(f"        {input_type}: {elapsed:.2f}s")
        except Exception as e:
            # Best effort: report the failure and move on to the next model.
            print(f"        {input_type}: Error - {e}")

    return results


print("Evaluation functions defined.")

Evaluation functions defined.


In [6]:
print("="*80, "EVALUATING DIRECT MODELS", "="*80, sep="\n")

# Timing store shared with the residual evaluation and the timing summary:
# timing_results[family][variable][scenario][model] -> seconds.
timing_results = {
    family: {v: {s: {} for s in test_periods} for v in variables}
    for family in ('direct', 'residual')
}

output_dir_direct = Path("evaluation_results")
output_dir_direct.mkdir(exist_ok=True)

total_start = time.time()

for var in variables:
    print(f"\n{'='*60}")
    print(f"Variable: {var.upper()}")
    print("="*60)

    for scenario_name, test_period in test_periods.items():
        print(f"\n    Scenario: {scenario_name}")
        print(f"    {'-'*40}")

        scenario_start = time.time()

        try:
            results = evaluate_direct_models(
                var,
                scenario_name,
                datasets_original[scenario_name],
                test_period,
                timing_results,
            )

            # Collect ground truth, input and all predictions in one file.
            ds_result = xr.Dataset()
            for key, value in results.items():
                ds_result[key] = value

            output_path = output_dir_direct / f"{var}_evaluation_{scenario_name}.nc"
            ds_result.to_netcdf(output_path)

            scenario_time = time.time() - scenario_start
            print(f"    Scenario total: {scenario_time:.2f}s | Saved: {output_path}")
        except Exception as e:
            # Best effort: a failing scenario must not stop the remaining ones.
            print(f"    Error processing {scenario_name}: {e}")

direct_total_time = time.time() - total_start
print(f"\n{'='*60}")
print(f"Direct evaluation complete! Total time: {direct_total_time:.2f}s ({direct_total_time/60:.2f} min)")
print("="*60)

EVALUATING DIRECT MODELS

Variable: TAS

    Scenario: historical
    ----------------------------------------
        none: 1.15s
        minmax_global: 1.24s
        minmax_pixel: 1.61s
        zscore_global: 1.41s
        zscore_pixel: 1.46s
        instance_zscore: 1.70s
        instance_minmax: 1.59s
    Scenario total: 10.99s | Saved: evaluation_results/tas_evaluation_historical.nc

    Scenario: ssp126
    ----------------------------------------
        none: 4.59s
        minmax_global: 5.12s
        minmax_pixel: 5.94s
        zscore_global: 5.12s
        zscore_pixel: 5.49s
        instance_zscore: 7.02s
        instance_minmax: 6.07s
    Scenario total: 43.23s | Saved: evaluation_results/tas_evaluation_ssp126.nc

    Scenario: ssp245
    ----------------------------------------
        none: 3.85s
        minmax_global: 4.68s
        minmax_pixel: 5.31s
        zscore_global: 4.67s
        zscore_pixel: 5.29s
        instance_zscore: 7.69s
        instance_minmax: 6.08s
   

In [9]:
print("="*80, "EVALUATING RESIDUAL MODELS", "="*80, sep="\n")

output_dir_residual = Path("evaluation_results_residual")
output_dir_residual.mkdir(exist_ok=True)

total_start = time.time()

# NOTE: timings are written into the `timing_results` dict created by the
# direct-model evaluation cell, so that cell has to run first.
for var in variables:
    print(f"\n{'='*60}")
    print(f"Variable: {var.upper()}")
    print("="*60)

    for scenario_name, test_period in test_periods.items():
        print(f"\n    Scenario: {scenario_name}")
        print(f"    {'-'*40}")

        scenario_start = time.time()

        try:
            results = evaluate_residual_models(
                var,
                scenario_name,
                datasets_residual[scenario_name],
                datasets_original[scenario_name],
                test_period,
                timing_results,
            )

            # Collect ground truth, input and all predictions in one file.
            ds_result = xr.Dataset()
            for key, value in results.items():
                ds_result[key] = value

            output_path = output_dir_residual / f"{var}_evaluation_{scenario_name}.nc"
            ds_result.to_netcdf(output_path)

            scenario_time = time.time() - scenario_start
            print(f"    Scenario total: {scenario_time:.2f}s | Saved: {output_path}")
        except Exception as e:
            # Best effort: a failing scenario must not stop the remaining ones.
            print(f"    Error processing {scenario_name}: {e}")

residual_total_time = time.time() - total_start
print(f"\n{'='*60}")
print(f"Residual evaluation complete! Total time: {residual_total_time:.2f}s ({residual_total_time/60:.2f} min)")
print("="*60)

EVALUATING RESIDUAL MODELS

Variable: TAS

    Scenario: historical
    ----------------------------------------
        raw: 1.41s
        gma: 1.51s
        grid: 1.59s
        gmt: 1.54s
    Scenario total: 6.93s | Saved: evaluation_results_residual/tas_evaluation_historical.nc

    Scenario: ssp126
    ----------------------------------------
        raw: 6.12s
        gma: 6.66s
        grid: 6.90s
        gmt: 6.47s
    Scenario total: 31.14s | Saved: evaluation_results_residual/tas_evaluation_ssp126.nc

    Scenario: ssp245
    ----------------------------------------
        raw: 6.66s
        gma: 6.74s
        grid: 7.12s
        gmt: 6.93s
    Scenario total: 33.59s | Saved: evaluation_results_residual/tas_evaluation_ssp245.nc

    Scenario: ssp585
    ----------------------------------------
        raw: 7.28s
        gma: 6.99s
        grid: 6.75s
        gmt: 6.40s
    Scenario total: 32.45s | Saved: evaluation_results_residual/tas_evaluation_ssp585.nc

    Scenario: g6su

In [11]:
# Cell: Re-run residual evaluation for g6sulfur ONLY

print("="*60)
print("RE-EVALUATING RESIDUAL MODELS FOR G6SULFUR")
print("="*60)

scenario_name = 'g6sulfur'
test_period = test_periods[scenario_name]

# Reload the updated g6sulfur residual dataset.
# The dataset opened earlier still holds a handle on the old file; close it
# first so the handle is released and the re-open cannot be served from a
# cached file object that predates the update.
stale_dataset = datasets_residual.get(scenario_name)
if stale_dataset is not None:
    stale_dataset.close()
datasets_residual[scenario_name] = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_g6sulfur_residual_detrended.nc")

for var in variables:
    print(f"\n  Variable: {var}")
    print(f"  {'-'*40}")

    scenario_start = time.time()

    # Timings for g6sulfur in `timing_results` are overwritten by this run.
    results = evaluate_residual_models(
        var, scenario_name,
        datasets_residual[scenario_name],
        datasets_original[scenario_name],
        test_period,
        timing_results
    )

    # Save results (replaces the file written by the full residual run).
    ds_result = xr.Dataset()
    for key, value in results.items():
        ds_result[key] = value

    output_path = output_dir_residual / f"{var}_evaluation_{scenario_name}.nc"
    ds_result.to_netcdf(output_path)

    scenario_time = time.time() - scenario_start
    print(f"  Saved: {output_path} ({scenario_time:.2f}s)")

print("\n" + "="*60)
print("G6sulfur residual evaluation complete!")
print("="*60)

RE-EVALUATING RESIDUAL MODELS FOR G6SULFUR

  Variable: tas
  ----------------------------------------
        raw: 6.10s
        gma: 6.23s
        grid: 6.18s
        gmt: 6.24s
  Saved: evaluation_results_residual/tas_evaluation_g6sulfur.nc (29.03s)

  Variable: pr
  ----------------------------------------
        raw: 6.15s
        gma: 5.98s
        grid: 6.84s
        gmt: 6.57s
  Saved: evaluation_results_residual/pr_evaluation_g6sulfur.nc (30.59s)

G6sulfur residual evaluation complete!


In [10]:
# Cell 7: Print Timing Summary

def _print_timing_table(timings, columns):
    """Print one scenario-by-model timing table per variable.

    Args:
        timings: ``{variable: {scenario: {column: seconds}}}``, i.e. one
            family of ``timing_results``.
        columns: Ordered model identifiers used as table columns.

    Column headers show the full identifier and each column is sized to fit
    it, so names sharing a prefix (e.g. 'instance_zscore' and
    'instance_minmax') stay distinguishable. Missing timings print as 'N/A'
    and count as 0 in the TOTAL row.
    """
    label_width = 15
    widths = [max(len(name), 8) + 2 for name in columns]
    # Every cell is followed by one space, as is the row label.
    rule = "-" * (label_width + 1 + sum(width + 1 for width in widths))

    for var_name in variables:
        per_scenario = timings[var_name]

        print(f"\n{var_name.upper()}:")
        header_cells = "".join(f"{name:<{width}} " for name, width in zip(columns, widths))
        print(f"{'Scenario':<{label_width}} " + header_cells)
        print(rule)

        for scen in test_periods:
            cells = []
            for name, width in zip(columns, widths):
                if name in per_scenario[scen]:
                    cells.append(f"{per_scenario[scen][name]:<{width}.2f} ")
                else:
                    cells.append(f"{'N/A':<{width}} ")
            print(f"{scen:<{label_width}} " + "".join(cells))

        # Per-column totals over all scenarios.
        print(rule)
        total_cells = "".join(
            f"{sum(per_scenario[scen].get(name, 0) for scen in test_periods):<{width}.2f} "
            for name, width in zip(columns, widths)
        )
        print(f"{'TOTAL':<{label_width}} " + total_cells)


print("="*80)
print("TIMING SUMMARY")
print("="*80)

# Direct Models Summary
print("\n" + "-"*60)
print("DIRECT MODELS")
print("-"*60)
_print_timing_table(timing_results['direct'], normalizations)

# Residual Models Summary
print("\n" + "-"*60)
print("RESIDUAL MODELS")
print("-"*60)
_print_timing_table(timing_results['residual'], residual_input_types)

# Grand Total
print("\n" + "="*60)
direct_total = sum(
    timing_results['direct'][var][scen].get(norm, 0)
    for var in variables
    for scen in test_periods.keys()
    for norm in normalizations
)
residual_total = sum(
    timing_results['residual'][var][scen].get(input_type, 0)
    for var in variables
    for scen in test_periods.keys()
    for input_type in residual_input_types
)
grand_total = direct_total + residual_total

print(f"Direct models total:   {direct_total:>10.2f}s ({direct_total/60:.2f} min)")
print(f"Residual models total: {residual_total:>10.2f}s ({residual_total/60:.2f} min)")
print(f"{'='*40}")
print(f"GRAND TOTAL:           {grand_total:>10.2f}s ({grand_total/60:.2f} min)")
print("="*60)

TIMING SUMMARY

------------------------------------------------------------
DIRECT MODELS
------------------------------------------------------------

TAS:
Scenario        none       minmax_g   minmax_p   zscore_g   zscore_p   instance   instance  
-------------------------------------------------------------------------------------
historical      1.15       1.24       1.61       1.41       1.46       1.70       1.59       
ssp126          4.59       5.12       5.94       5.12       5.49       7.02       6.07       
ssp245          3.85       4.68       5.31       4.67       5.29       7.69       6.08       
ssp585          4.32       4.71       5.83       4.80       5.10       6.73       6.23       
g6sulfur        3.52       4.32       4.89       4.40       4.90       6.58       5.74       
-------------------------------------------------------------------------------------
TOTAL           17.44      20.07      23.58      20.39      22.24      29.72      25.71      

PR:
Scenario