# Relative Humidity

In [1]:
# evaluate_hurs.py

import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet

# ----------------------------
# Configuration
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Variable configuration
variable = 'hurs_hr'
var_base = 'hurs'
var_lr = 'hurs_lr_interp'

# Models to evaluate
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global', 
                 'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Test periods
test_periods = {
    'historical': ('2001', '2014'),
    'ssp126': ('2015', '2100'),
    'ssp245': ('2015', '2100'), 
    'ssp585': ('2015', '2100')
}

# ----------------------------
# Load data 
# ----------------------------
print("Loading datasets...")
ds_hist = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc")
ds_ssp126 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp245 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp585 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")

datasets = {
    'historical': ds_hist,
    'ssp126': ds_ssp126,
    'ssp245': ds_ssp245,
    'ssp585': ds_ssp585
}

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

print("Normalization stats loaded for variables:", list(norm_stats.keys()))

# ----------------------------
# Normalization helpers 
# ----------------------------
def apply_normalization(data, method, norm_stats, var_base, resolution='lr_interp'):
    """Apply normalization method to data."""
    if method == 'none':
        return data
    
    elif method == 'minmax_global':
        data_min = norm_stats[var_base][resolution]['global_min']
        data_max = norm_stats[var_base][resolution]['global_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'minmax_pixel':
        # CHANGED: pixel stats are already numpy arrays
        data_min = norm_stats[var_base][resolution]['pixel_min']
        data_max = norm_stats[var_base][resolution]['pixel_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'zscore_global':
        data_mean = norm_stats[var_base][resolution]['global_mean']
        data_std = norm_stats[var_base][resolution]['global_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'zscore_pixel':
        # CHANGED: pixel stats are already numpy arrays
        data_mean = norm_stats[var_base][resolution]['pixel_mean']
        data_std = norm_stats[var_base][resolution]['pixel_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'instance_zscore':
        # Per-sample normalization - convert to numpy immediately
        data_np = data.values if hasattr(data, 'values') else data
        data_mean = np.mean(data_np, axis=(1, 2), keepdims=True)
        data_std = np.std(data_np, axis=(1, 2), keepdims=True)
        data_normalized = (data_np - data_mean) / (data_std + 1e-8)
        return data_normalized  # Return numpy array
    
    elif method == 'instance_minmax':
        # Per-sample min-max to [-1, 1] - convert to numpy immediately
        data_np = data.values if hasattr(data, 'values') else data
        data_min = np.min(data_np, axis=(1, 2), keepdims=True)
        data_max = np.max(data_np, axis=(1, 2), keepdims=True)
        data_normalized = 2 * (data_np - data_min) / (data_max - data_min + 1e-8) - 1
        return data_normalized  # Return numpy array

def denormalize_predictions(predictions, method, norm_stats, var_base, input_data=None):
    """Denormalize predictions back to original scale."""
    if method == 'none':
        return predictions
    
    elif method == 'minmax_global':
        hr_min = norm_stats[var_base]['hr']['global_min']
        hr_max = norm_stats[var_base]['hr']['global_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'minmax_pixel':
        hr_min = norm_stats[var_base]['hr']['pixel_min']
        hr_max = norm_stats[var_base]['hr']['pixel_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'zscore_global':
        hr_mean = norm_stats[var_base]['hr']['global_mean']
        hr_std = norm_stats[var_base]['hr']['global_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'zscore_pixel':
        hr_mean = norm_stats[var_base]['hr']['pixel_mean']
        hr_std = norm_stats[var_base]['hr']['pixel_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'instance_zscore':
        # Need input statistics for denormalization
        input_mean = np.mean(input_data, axis=(1, 2), keepdims=True)
        input_std = np.std(input_data, axis=(1, 2), keepdims=True)
        return predictions * input_std + input_mean
    
    elif method == 'instance_minmax':
        # Need input statistics for denormalization
        input_min = np.min(input_data, axis=(1, 2), keepdims=True)
        input_max = np.max(input_data, axis=(1, 2), keepdims=True)
        return ((predictions + 1) / 2) * (input_max - input_min) + input_min

# ----------------------------
# Model evaluation
# ----------------------------
def evaluate_model(model_path, norm_method, dataset, test_period):
    """Evaluate a single model on a dataset."""
    print(f"    Evaluating {norm_method}...")
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Extract test data - CONVERT TO NUMPY IMMEDIATELY
    lr_data_xr = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    hr_data_xr = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    
    lr_data_raw = lr_data_xr.values
    hr_data_raw = hr_data_xr.values
    
    # Apply normalization to input (works with numpy or xarray)
    lr_data_norm = apply_normalization(lr_data_xr, norm_method, norm_stats, var_base)
    
    # Ensure it's numpy
    if hasattr(lr_data_norm, 'values'):
        lr_data_norm = lr_data_norm.values
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_data_norm)
    predictions = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_data_norm[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            
            # Get predictions
            batch_pred = model(batch_tensor)
            predictions.append(batch_pred.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0).squeeze(1)
    
    # Denormalize predictions
    predictions_denorm = denormalize_predictions(
        predictions, norm_method, norm_stats, var_base, lr_data_raw
    )
    
    return predictions_denorm, hr_data_raw, lr_data_raw

# ----------------------------
# Main evaluation loop
# ----------------------------
print("\n" + "="*80)
print("EVALUATING RELATIVE HUMIDITY (HURS) MODELS")
print("="*80)

# Store all results
all_results = {}

for scenario_name, dataset in datasets.items():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-"*40)
    
    test_period = test_periods[scenario_name]
    scenario_results = {}
    
    # Get ground truth (same for all models)
    hr_data = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    lr_data = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    time_coords = hr_data.time
    lat_coords = hr_data.lat
    lon_coords = hr_data.lon
    
    # Store ground truth and input
    scenario_results['groundtruth'] = hr_data
    scenario_results['input'] = lr_data
    
    # Evaluate each model
    for norm_method in normalizations:
        model_path = ckpt_dir / f"hurs_{norm_method}.pth"
        
        if not model_path.exists():
            print(f"Model not found: {model_path}")
            continue
        
        try:
            predictions, groundtruth, inputs = evaluate_model(
                model_path, norm_method, dataset, test_period
            )
            
            # Create xarray DataArray for predictions
            pred_da = xr.DataArray(
                predictions,
                coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                dims=['time', 'lat', 'lon'],
                name=f'hurs_pred_{norm_method}'
            )
            
            scenario_results[f'pred_{norm_method}'] = pred_da
            print(f"Success")
            
        except Exception as e:
            print(f"Error: {e}")
            continue
    
    all_results[scenario_name] = scenario_results

# ----------------------------
# Save results
# ----------------------------
print("\n" + "="*80)
print("SAVING RESULTS")
print("="*80)

output_dir = Path("evaluation_results")
output_dir.mkdir(exist_ok=True)

for scenario_name, results in all_results.items():
    # Create dataset with all predictions
    ds_result = xr.Dataset()
    
    # Add ground truth and input
    ds_result['groundtruth'] = results['groundtruth']
    ds_result['input'] = results['input']
    
    # Add all predictions
    for norm_method in normalizations:
        key = f'pred_{norm_method}'
        if key in results:
            ds_result[key] = results[key]
    
    # Save to NetCDF
    output_path = output_dir / f"hurs_evaluation_{scenario_name}.nc"
    ds_result.to_netcdf(output_path)
    print(f" Saved: {output_path}")


  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


Loading datasets...
Loading normalization statistics...
Normalization stats loaded for variables: ['pr', 'tas', 'hurs', 'sfcWind']

EVALUATING RELATIVE HUMIDITY (HURS) MODELS

HISTORICAL Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP126 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP245 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...

# Temperature

In [2]:
# evaluate_tas.py

import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet

# ----------------------------
# Configuration
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Variable configuration
variable = 'tas_hr'
var_base = 'tas'
var_lr = 'tas_lr_interp'

# Models to evaluate
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global', 
                 'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Test periods
test_periods = {
    'historical': ('2001', '2014'),
    'ssp126': ('2015', '2100'),
    'ssp245': ('2015', '2100'), 
    'ssp585': ('2015', '2100')
}

# ----------------------------
# Load data 
# ----------------------------
print("Loading datasets...")
ds_hist = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc")
ds_ssp126 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp245 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp585 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")

datasets = {
    'historical': ds_hist,
    'ssp126': ds_ssp126,
    'ssp245': ds_ssp245,
    'ssp585': ds_ssp585
}

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

print("Normalization stats loaded for variables:", list(norm_stats.keys()))

# ----------------------------
# Normalization helpers 
# ----------------------------
def apply_normalization(data, method, norm_stats, var_base, resolution='lr_interp'):
    """Apply normalization method to data."""
    if method == 'none':
        return data
    
    elif method == 'minmax_global':
        data_min = norm_stats[var_base][resolution]['global_min']
        data_max = norm_stats[var_base][resolution]['global_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'minmax_pixel':
        data_min = norm_stats[var_base][resolution]['pixel_min']
        data_max = norm_stats[var_base][resolution]['pixel_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'zscore_global':
        data_mean = norm_stats[var_base][resolution]['global_mean']
        data_std = norm_stats[var_base][resolution]['global_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'zscore_pixel':
        data_mean = norm_stats[var_base][resolution]['pixel_mean']
        data_std = norm_stats[var_base][resolution]['pixel_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'instance_zscore':
        data_np = data.values if hasattr(data, 'values') else data
        data_mean = np.mean(data_np, axis=(1, 2), keepdims=True)
        data_std = np.std(data_np, axis=(1, 2), keepdims=True)
        data_normalized = (data_np - data_mean) / (data_std + 1e-8)
        return data_normalized
    
    elif method == 'instance_minmax':
        data_np = data.values if hasattr(data, 'values') else data
        data_min = np.min(data_np, axis=(1, 2), keepdims=True)
        data_max = np.max(data_np, axis=(1, 2), keepdims=True)
        data_normalized = 2 * (data_np - data_min) / (data_max - data_min + 1e-8) - 1
        return data_normalized

def denormalize_predictions(predictions, method, norm_stats, var_base, input_data=None):
    """Denormalize predictions back to original scale."""
    if method == 'none':
        return predictions
    
    elif method == 'minmax_global':
        hr_min = norm_stats[var_base]['hr']['global_min']
        hr_max = norm_stats[var_base]['hr']['global_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'minmax_pixel':
        hr_min = norm_stats[var_base]['hr']['pixel_min']
        hr_max = norm_stats[var_base]['hr']['pixel_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'zscore_global':
        hr_mean = norm_stats[var_base]['hr']['global_mean']
        hr_std = norm_stats[var_base]['hr']['global_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'zscore_pixel':
        hr_mean = norm_stats[var_base]['hr']['pixel_mean']
        hr_std = norm_stats[var_base]['hr']['pixel_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'instance_zscore':
        input_mean = np.mean(input_data, axis=(1, 2), keepdims=True)
        input_std = np.std(input_data, axis=(1, 2), keepdims=True)
        return predictions * input_std + input_mean
    
    elif method == 'instance_minmax':
        input_min = np.min(input_data, axis=(1, 2), keepdims=True)
        input_max = np.max(input_data, axis=(1, 2), keepdims=True)
        return ((predictions + 1) / 2) * (input_max - input_min) + input_min

# ----------------------------
# Model evaluation
# ----------------------------
def evaluate_model(model_path, norm_method, dataset, test_period):
    """Evaluate a single model on a dataset."""
    print(f"    Evaluating {norm_method}...")
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Extract test data
    lr_data_xr = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    hr_data_xr = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    
    lr_data_raw = lr_data_xr.values
    hr_data_raw = hr_data_xr.values
    
    # Apply normalization to input
    lr_data_norm = apply_normalization(lr_data_xr, norm_method, norm_stats, var_base)
    
    # Ensure it's numpy
    if hasattr(lr_data_norm, 'values'):
        lr_data_norm = lr_data_norm.values
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_data_norm)
    predictions = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_data_norm[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            
            batch_pred = model(batch_tensor)
            predictions.append(batch_pred.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0).squeeze(1)
    
    # Denormalize predictions
    predictions_denorm = denormalize_predictions(
        predictions, norm_method, norm_stats, var_base, lr_data_raw
    )
    
    return predictions_denorm, hr_data_raw, lr_data_raw

# ----------------------------
# Main evaluation loop
# ----------------------------
print("\n" + "="*80)
print("EVALUATING TEMPERATURE (TAS) MODELS")
print("="*80)

all_results = {}

for scenario_name, dataset in datasets.items():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-"*40)
    
    test_period = test_periods[scenario_name]
    scenario_results = {}
    
    # Get ground truth
    hr_data = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    lr_data = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    time_coords = hr_data.time
    lat_coords = hr_data.lat
    lon_coords = hr_data.lon
    
    # Store ground truth and input
    scenario_results['groundtruth'] = hr_data
    scenario_results['input'] = lr_data
    
    # Evaluate each model
    for norm_method in normalizations:
        model_path = ckpt_dir / f"tas_{norm_method}.pth"
        
        if not model_path.exists():
            print(f"Model not found: {model_path}")
            continue
        
        try:
            predictions, groundtruth, inputs = evaluate_model(
                model_path, norm_method, dataset, test_period
            )
            
            pred_da = xr.DataArray(
                predictions,
                coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                dims=['time', 'lat', 'lon'],
                name=f'tas_pred_{norm_method}'
            )
            
            scenario_results[f'pred_{norm_method}'] = pred_da
            print(f"Success")
            
        except Exception as e:
            print(f"Error: {e}")
            continue
    
    all_results[scenario_name] = scenario_results

# ----------------------------
# Save results
# ----------------------------
print("\n" + "="*80)
print("SAVING RESULTS")
print("="*80)

output_dir = Path("evaluation_results")
output_dir.mkdir(exist_ok=True)

for scenario_name, results in all_results.items():
    ds_result = xr.Dataset()
    
    ds_result['groundtruth'] = results['groundtruth']
    ds_result['input'] = results['input']
    
    for norm_method in normalizations:
        key = f'pred_{norm_method}'
        if key in results:
            ds_result[key] = results[key]
    
    output_path = output_dir / f"tas_evaluation_{scenario_name}.nc"
    ds_result.to_netcdf(output_path)
    print(f"Saved: {output_path}")

print("\nEvaluation complete!")

Loading datasets...
Loading normalization statistics...
Normalization stats loaded for variables: ['pr', 'tas', 'hurs', 'sfcWind']

EVALUATING TEMPERATURE (TAS) MODELS

HISTORICAL Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP126 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP245 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Succes

In [1]:
import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet

# ----------------------------
# Configuration
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Variable configuration
variable = 'tas_hr'
var_base = 'tas'
var_lr = 'tas_lr_interp'

# Models to evaluate
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global', 
                 'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Test period for G6sulfur
test_period_g6 = ('2020', '2099')

# ----------------------------
# Load data 
# ----------------------------
print("Loading G6sulfur dataset...")
ds_g6sulfur = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_g6sulfur_r1i1p1f1_2020_2099_allvars.nc")

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

# ----------------------------
# Normalization helpers 
# ----------------------------
def apply_normalization(data, method, norm_stats, var_base, resolution='lr_interp'):
    """Apply normalization method to data."""
    if method == 'none':
        return data
    
    elif method == 'minmax_global':
        data_min = norm_stats[var_base][resolution]['global_min']
        data_max = norm_stats[var_base][resolution]['global_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'minmax_pixel':
        data_min = norm_stats[var_base][resolution]['pixel_min']
        data_max = norm_stats[var_base][resolution]['pixel_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'zscore_global':
        data_mean = norm_stats[var_base][resolution]['global_mean']
        data_std = norm_stats[var_base][resolution]['global_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'zscore_pixel':
        data_mean = norm_stats[var_base][resolution]['pixel_mean']
        data_std = norm_stats[var_base][resolution]['pixel_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'instance_zscore':
        data_np = data.values if hasattr(data, 'values') else data
        data_mean = np.mean(data_np, axis=(1, 2), keepdims=True)
        data_std = np.std(data_np, axis=(1, 2), keepdims=True)
        data_normalized = (data_np - data_mean) / (data_std + 1e-8)
        return data_normalized
    
    elif method == 'instance_minmax':
        data_np = data.values if hasattr(data, 'values') else data
        data_min = np.min(data_np, axis=(1, 2), keepdims=True)
        data_max = np.max(data_np, axis=(1, 2), keepdims=True)
        data_normalized = 2 * (data_np - data_min) / (data_max - data_min + 1e-8) - 1
        return data_normalized

def denormalize_predictions(predictions, method, norm_stats, var_base, input_data=None):
    """Denormalize predictions back to original scale."""
    if method == 'none':
        return predictions
    
    elif method == 'minmax_global':
        hr_min = norm_stats[var_base]['hr']['global_min']
        hr_max = norm_stats[var_base]['hr']['global_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'minmax_pixel':
        hr_min = norm_stats[var_base]['hr']['pixel_min']
        hr_max = norm_stats[var_base]['hr']['pixel_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'zscore_global':
        hr_mean = norm_stats[var_base]['hr']['global_mean']
        hr_std = norm_stats[var_base]['hr']['global_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'zscore_pixel':
        hr_mean = norm_stats[var_base]['hr']['pixel_mean']
        hr_std = norm_stats[var_base]['hr']['pixel_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'instance_zscore':
        input_mean = np.mean(input_data, axis=(1, 2), keepdims=True)
        input_std = np.std(input_data, axis=(1, 2), keepdims=True)
        return predictions * input_std + input_mean
    
    elif method == 'instance_minmax':
        input_min = np.min(input_data, axis=(1, 2), keepdims=True)
        input_max = np.max(input_data, axis=(1, 2), keepdims=True)
        return ((predictions + 1) / 2) * (input_max - input_min) + input_min

# ----------------------------
# Model evaluation
# ----------------------------
def evaluate_model(model_path, norm_method, dataset, test_period):
    """Evaluate a single model on a dataset."""
    print(f"    Evaluating {norm_method}...")
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Extract test data
    lr_data_xr = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    hr_data_xr = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    
    lr_data_raw = lr_data_xr.values
    hr_data_raw = hr_data_xr.values
    
    # Apply normalization to input
    lr_data_norm = apply_normalization(lr_data_xr, norm_method, norm_stats, var_base)
    
    # Ensure it's numpy
    if hasattr(lr_data_norm, 'values'):
        lr_data_norm = lr_data_norm.values
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_data_norm)
    predictions = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_data_norm[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            
            batch_pred = model(batch_tensor)
            predictions.append(batch_pred.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0).squeeze(1)
    
    # Denormalize predictions
    predictions_denorm = denormalize_predictions(
        predictions, norm_method, norm_stats, var_base, lr_data_raw
    )
    
    return predictions_denorm, hr_data_raw, lr_data_raw

# ----------------------------
# Main evaluation for G6sulfur
# ----------------------------
print("\n" + "="*80)
print("EVALUATING G6SULFUR TEMPERATURE (TAS) MODELS")
print("="*80)

g6_results = {}

# Get ground truth
hr_data = ds_g6sulfur[variable].sel(time=slice(test_period_g6[0], test_period_g6[1]))
lr_data = ds_g6sulfur[var_lr].sel(time=slice(test_period_g6[0], test_period_g6[1]))
time_coords = hr_data.time
lat_coords = hr_data.lat
lon_coords = hr_data.lon

# Store ground truth and input
g6_results['groundtruth'] = hr_data
g6_results['input'] = lr_data

# Evaluate each model
for norm_method in normalizations:
    model_path = ckpt_dir / f"tas_{norm_method}.pth"
    
    if not model_path.exists():
        print(f"Model not found: {model_path}")
        continue
    
    try:
        predictions, groundtruth, inputs = evaluate_model(
            model_path, norm_method, ds_g6sulfur, test_period_g6
        )
        
        pred_da = xr.DataArray(
            predictions,
            coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
            dims=['time', 'lat', 'lon'],
            name=f'tas_pred_{norm_method}'
        )
        
        g6_results[f'pred_{norm_method}'] = pred_da
        print(f"    {norm_method}: Success")
        
    except Exception as e:
        print(f"    {norm_method}: Error - {e}")
        continue

# ----------------------------
# Save results
# ----------------------------
print("\n" + "="*80)
print("SAVING G6SULFUR RESULTS")
print("="*80)

output_dir = Path("evaluation_results")
output_dir.mkdir(exist_ok=True)

ds_result = xr.Dataset()
ds_result['groundtruth'] = g6_results['groundtruth']
ds_result['input'] = g6_results['input']

for norm_method in normalizations:
    key = f'pred_{norm_method}'
    if key in g6_results:
        ds_result[key] = g6_results[key]

output_path = output_dir / "tas_evaluation_g6sulfur.nc"
ds_result.to_netcdf(output_path)
print(f"   Saved: {output_path}")
print(f"   Variables: {list(ds_result.data_vars)}")
print(f"   Time range: {test_period_g6[0]} to {test_period_g6[1]}")

print("\nG6sulfur evaluation complete!")

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


Loading G6sulfur dataset...
Loading normalization statistics...

EVALUATING G6SULFUR TEMPERATURE (TAS) MODELS
    Evaluating none...
    none: Success
    Evaluating minmax_global...
    minmax_global: Success
    Evaluating minmax_pixel...
    minmax_pixel: Success
    Evaluating zscore_global...
    zscore_global: Success
    Evaluating zscore_pixel...
    zscore_pixel: Success
    Evaluating instance_zscore...
    instance_zscore: Success
    Evaluating instance_minmax...
    instance_minmax: Success

SAVING G6SULFUR RESULTS
   Saved: evaluation_results/tas_evaluation_g6sulfur.nc
   Variables: ['groundtruth', 'input', 'pred_none', 'pred_minmax_global', 'pred_minmax_pixel', 'pred_zscore_global', 'pred_zscore_pixel', 'pred_instance_zscore', 'pred_instance_minmax']
   Time range: 2020 to 2099

G6sulfur evaluation complete!


# Precipitation

In [3]:
# evaluate_pr.py

import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet

# ----------------------------
# Configuration
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Variable configuration
variable = 'pr_hr'
var_base = 'pr'
var_lr = 'pr_lr_interp'

# Models to evaluate
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global', 
                 'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Test periods
test_periods = {
    'historical': ('2001', '2014'),
    'ssp126': ('2015', '2100'),
    'ssp245': ('2015', '2100'), 
    'ssp585': ('2015', '2100')
}

# ----------------------------
# Load data 
# ----------------------------
print("Loading datasets...")
ds_hist = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc")
ds_ssp126 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp245 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp585 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")

datasets = {
    'historical': ds_hist,
    'ssp126': ds_ssp126,
    'ssp245': ds_ssp245,
    'ssp585': ds_ssp585
}

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

print("Normalization stats loaded for variables:", list(norm_stats.keys()))

# ----------------------------
# Normalization helpers 
# ----------------------------
def apply_normalization(data, method, norm_stats, var_base, resolution='lr_interp'):
    """Apply normalization method to data."""
    if method == 'none':
        return data
    
    elif method == 'minmax_global':
        data_min = norm_stats[var_base][resolution]['global_min']
        data_max = norm_stats[var_base][resolution]['global_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'minmax_pixel':
        data_min = norm_stats[var_base][resolution]['pixel_min']
        data_max = norm_stats[var_base][resolution]['pixel_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'zscore_global':
        data_mean = norm_stats[var_base][resolution]['global_mean']
        data_std = norm_stats[var_base][resolution]['global_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'zscore_pixel':
        data_mean = norm_stats[var_base][resolution]['pixel_mean']
        data_std = norm_stats[var_base][resolution]['pixel_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'instance_zscore':
        data_np = data.values if hasattr(data, 'values') else data
        data_mean = np.mean(data_np, axis=(1, 2), keepdims=True)
        data_std = np.std(data_np, axis=(1, 2), keepdims=True)
        data_normalized = (data_np - data_mean) / (data_std + 1e-8)
        return data_normalized
    
    elif method == 'instance_minmax':
        data_np = data.values if hasattr(data, 'values') else data
        data_min = np.min(data_np, axis=(1, 2), keepdims=True)
        data_max = np.max(data_np, axis=(1, 2), keepdims=True)
        data_normalized = 2 * (data_np - data_min) / (data_max - data_min + 1e-8) - 1
        return data_normalized

def denormalize_predictions(predictions, method, norm_stats, var_base, input_data=None):
    """Denormalize predictions back to original scale."""
    if method == 'none':
        return predictions
    
    elif method == 'minmax_global':
        hr_min = norm_stats[var_base]['hr']['global_min']
        hr_max = norm_stats[var_base]['hr']['global_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'minmax_pixel':
        hr_min = norm_stats[var_base]['hr']['pixel_min']
        hr_max = norm_stats[var_base]['hr']['pixel_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'zscore_global':
        hr_mean = norm_stats[var_base]['hr']['global_mean']
        hr_std = norm_stats[var_base]['hr']['global_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'zscore_pixel':
        hr_mean = norm_stats[var_base]['hr']['pixel_mean']
        hr_std = norm_stats[var_base]['hr']['pixel_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'instance_zscore':
        input_mean = np.mean(input_data, axis=(1, 2), keepdims=True)
        input_std = np.std(input_data, axis=(1, 2), keepdims=True)
        return predictions * input_std + input_mean
    
    elif method == 'instance_minmax':
        input_min = np.min(input_data, axis=(1, 2), keepdims=True)
        input_max = np.max(input_data, axis=(1, 2), keepdims=True)
        return ((predictions + 1) / 2) * (input_max - input_min) + input_min

# ----------------------------
# Model evaluation
# ----------------------------
def evaluate_model(model_path, norm_method, dataset, test_period):
    """Evaluate a single model on a dataset."""
    print(f"    Evaluating {norm_method}...")
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Extract test data
    lr_data_xr = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    hr_data_xr = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    
    lr_data_raw = lr_data_xr.values
    hr_data_raw = hr_data_xr.values
    
    # Apply normalization to input
    lr_data_norm = apply_normalization(lr_data_xr, norm_method, norm_stats, var_base)
    
    # Ensure it's numpy
    if hasattr(lr_data_norm, 'values'):
        lr_data_norm = lr_data_norm.values
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_data_norm)
    predictions = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_data_norm[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            
            batch_pred = model(batch_tensor)
            predictions.append(batch_pred.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0).squeeze(1)
    
    # Denormalize predictions
    predictions_denorm = denormalize_predictions(
        predictions, norm_method, norm_stats, var_base, lr_data_raw
    )
    
    return predictions_denorm, hr_data_raw, lr_data_raw

# ----------------------------
# Main evaluation loop
# ----------------------------
print("\n" + "="*80)
print("EVALUATING PRECIPITATION (PR) MODELS")
print("="*80)

all_results = {}

for scenario_name, dataset in datasets.items():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-"*40)
    
    test_period = test_periods[scenario_name]
    scenario_results = {}
    
    # Get ground truth
    hr_data = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    lr_data = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    time_coords = hr_data.time
    lat_coords = hr_data.lat
    lon_coords = hr_data.lon
    
    # Store ground truth and input
    scenario_results['groundtruth'] = hr_data
    scenario_results['input'] = lr_data
    
    # Evaluate each model
    for norm_method in normalizations:
        model_path = ckpt_dir / f"pr_{norm_method}.pth"
        
        if not model_path.exists():
            print(f"Model not found: {model_path}")
            continue
        
        try:
            predictions, groundtruth, inputs = evaluate_model(
                model_path, norm_method, dataset, test_period
            )
            
            pred_da = xr.DataArray(
                predictions,
                coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                dims=['time', 'lat', 'lon'],
                name=f'pr_pred_{norm_method}'
            )
            
            scenario_results[f'pred_{norm_method}'] = pred_da
            print(f"Success")
            
        except Exception as e:
            print(f"Error: {e}")
            continue
    
    all_results[scenario_name] = scenario_results

# ----------------------------
# Save results
# ----------------------------
print("\n" + "="*80)
print("SAVING RESULTS")
print("="*80)

output_dir = Path("evaluation_results")
output_dir.mkdir(exist_ok=True)

for scenario_name, results in all_results.items():
    ds_result = xr.Dataset()
    
    ds_result['groundtruth'] = results['groundtruth']
    ds_result['input'] = results['input']
    
    for norm_method in normalizations:
        key = f'pred_{norm_method}'
        if key in results:
            ds_result[key] = results[key]
    
    output_path = output_dir / f"pr_evaluation_{scenario_name}.nc"
    ds_result.to_netcdf(output_path)
    print(f"Saved: {output_path}")

print("\nEvaluation complete!")

Loading datasets...
Loading normalization statistics...
Normalization stats loaded for variables: ['pr', 'tas', 'hurs', 'sfcWind']

EVALUATING PRECIPITATION (PR) MODELS

HISTORICAL Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP126 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP245 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Succe

# sfcWind

In [4]:
# evaluate_sfcWind.py

import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet

# ----------------------------
# Configuration
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Variable configuration
variable = 'sfcWind_hr'
var_base = 'sfcWind'
var_lr = 'sfcWind_lr_interp'

# Models to evaluate
normalizations = ['none', 'minmax_global', 'minmax_pixel', 'zscore_global', 
                 'zscore_pixel', 'instance_zscore', 'instance_minmax']

# Test periods
test_periods = {
    'historical': ('2001', '2014'),
    'ssp126': ('2015', '2100'),
    'ssp245': ('2015', '2100'), 
    'ssp585': ('2015', '2100')
}

# ----------------------------
# Load data 
# ----------------------------
print("Loading datasets...")
ds_hist = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc")
ds_ssp126 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp245 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp585 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")

datasets = {
    'historical': ds_hist,
    'ssp126': ds_ssp126,
    'ssp245': ds_ssp245,
    'ssp585': ds_ssp585
}

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

print("Normalization stats loaded for variables:", list(norm_stats.keys()))

# ----------------------------
# Normalization helpers 
# ----------------------------
def apply_normalization(data, method, norm_stats, var_base, resolution='lr_interp'):
    """Apply normalization method to data."""
    if method == 'none':
        return data
    
    elif method == 'minmax_global':
        data_min = norm_stats[var_base][resolution]['global_min']
        data_max = norm_stats[var_base][resolution]['global_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'minmax_pixel':
        data_min = norm_stats[var_base][resolution]['pixel_min']
        data_max = norm_stats[var_base][resolution]['pixel_max']
        return 2 * (data - data_min) / (data_max - data_min + 1e-8) - 1
    
    elif method == 'zscore_global':
        data_mean = norm_stats[var_base][resolution]['global_mean']
        data_std = norm_stats[var_base][resolution]['global_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'zscore_pixel':
        data_mean = norm_stats[var_base][resolution]['pixel_mean']
        data_std = norm_stats[var_base][resolution]['pixel_std']
        return (data - data_mean) / (data_std + 1e-8)
    
    elif method == 'instance_zscore':
        data_np = data.values if hasattr(data, 'values') else data
        data_mean = np.mean(data_np, axis=(1, 2), keepdims=True)
        data_std = np.std(data_np, axis=(1, 2), keepdims=True)
        data_normalized = (data_np - data_mean) / (data_std + 1e-8)
        return data_normalized
    
    elif method == 'instance_minmax':
        data_np = data.values if hasattr(data, 'values') else data
        data_min = np.min(data_np, axis=(1, 2), keepdims=True)
        data_max = np.max(data_np, axis=(1, 2), keepdims=True)
        data_normalized = 2 * (data_np - data_min) / (data_max - data_min + 1e-8) - 1
        return data_normalized

def denormalize_predictions(predictions, method, norm_stats, var_base, input_data=None):
    """Denormalize predictions back to original scale."""
    if method == 'none':
        return predictions
    
    elif method == 'minmax_global':
        hr_min = norm_stats[var_base]['hr']['global_min']
        hr_max = norm_stats[var_base]['hr']['global_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'minmax_pixel':
        hr_min = norm_stats[var_base]['hr']['pixel_min']
        hr_max = norm_stats[var_base]['hr']['pixel_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min
    
    elif method == 'zscore_global':
        hr_mean = norm_stats[var_base]['hr']['global_mean']
        hr_std = norm_stats[var_base]['hr']['global_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'zscore_pixel':
        hr_mean = norm_stats[var_base]['hr']['pixel_mean']
        hr_std = norm_stats[var_base]['hr']['pixel_std']
        return predictions * hr_std + hr_mean
    
    elif method == 'instance_zscore':
        input_mean = np.mean(input_data, axis=(1, 2), keepdims=True)
        input_std = np.std(input_data, axis=(1, 2), keepdims=True)
        return predictions * input_std + input_mean
    
    elif method == 'instance_minmax':
        input_min = np.min(input_data, axis=(1, 2), keepdims=True)
        input_max = np.max(input_data, axis=(1, 2), keepdims=True)
        return ((predictions + 1) / 2) * (input_max - input_min) + input_min

# ----------------------------
# Model evaluation
# ----------------------------
def evaluate_model(model_path, norm_method, dataset, test_period):
    """Evaluate a single model on a dataset."""
    print(f"    Evaluating {norm_method}...")
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Extract test data
    lr_data_xr = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    hr_data_xr = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    
    lr_data_raw = lr_data_xr.values
    hr_data_raw = hr_data_xr.values
    
    # Apply normalization to input
    lr_data_norm = apply_normalization(lr_data_xr, norm_method, norm_stats, var_base)
    
    # Ensure it's numpy
    if hasattr(lr_data_norm, 'values'):
        lr_data_norm = lr_data_norm.values
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_data_norm)
    predictions = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_data_norm[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            
            batch_pred = model(batch_tensor)
            predictions.append(batch_pred.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0).squeeze(1)
    
    # Denormalize predictions
    predictions_denorm = denormalize_predictions(
        predictions, norm_method, norm_stats, var_base, lr_data_raw
    )
    
    return predictions_denorm, hr_data_raw, lr_data_raw

# ----------------------------
# Main evaluation loop
# ----------------------------
print("\n" + "="*80)
print("EVALUATING WIND SPEED (SFCWIND) MODELS")
print("="*80)

all_results = {}

for scenario_name, dataset in datasets.items():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-"*40)
    
    test_period = test_periods[scenario_name]
    scenario_results = {}
    
    # Get ground truth
    hr_data = dataset[variable].sel(time=slice(test_period[0], test_period[1]))
    lr_data = dataset[var_lr].sel(time=slice(test_period[0], test_period[1]))
    time_coords = hr_data.time
    lat_coords = hr_data.lat
    lon_coords = hr_data.lon
    
    # Store ground truth and input
    scenario_results['groundtruth'] = hr_data
    scenario_results['input'] = lr_data
    
    # Evaluate each model
    for norm_method in normalizations:
        model_path = ckpt_dir / f"sfcWind_{norm_method}.pth"
        
        if not model_path.exists():
            print(f"Model not found: {model_path}")
            continue
        
        try:
            predictions, groundtruth, inputs = evaluate_model(
                model_path, norm_method, dataset, test_period
            )
            
            pred_da = xr.DataArray(
                predictions,
                coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                dims=['time', 'lat', 'lon'],
                name=f'sfcWind_pred_{norm_method}'
            )
            
            scenario_results[f'pred_{norm_method}'] = pred_da
            print(f"Success")
            
        except Exception as e:
            print(f"Error: {e}")
            continue
    
    all_results[scenario_name] = scenario_results

# ----------------------------
# Save results
# ----------------------------
print("\n" + "="*80)
print("SAVING RESULTS")
print("="*80)

output_dir = Path("evaluation_results")
output_dir.mkdir(exist_ok=True)

for scenario_name, results in all_results.items():
    ds_result = xr.Dataset()
    
    ds_result['groundtruth'] = results['groundtruth']
    ds_result['input'] = results['input']
    
    for norm_method in normalizations:
        key = f'pred_{norm_method}'
        if key in results:
            ds_result[key] = results[key]
    
    output_path = output_dir / f"sfcWind_evaluation_{scenario_name}.nc"
    ds_result.to_netcdf(output_path)
    print(f"Saved: {output_path}")

print("\nEvaluation complete!")

Loading datasets...
Loading normalization statistics...
Normalization stats loaded for variables: ['pr', 'tas', 'hurs', 'sfcWind']

EVALUATING WIND SPEED (SFCWIND) MODELS

HISTORICAL Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP126 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP245 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Suc