In [2]:
import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet
from scipy import stats

# Runtime configuration: compute device plus the data and checkpoint folders.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# (start, end) year labels per scenario; later used as slice(start, end)
# when selecting along the time coordinate.
test_periods = dict(
    historical=('2001', '2014'),
    ssp126=('2015', '2100'),
    ssp245=('2015', '2100'),
    ssp585=('2015', '2100'),
)

# Variables to evaluate and the (shortened) names of the model input variants.
variables = ['pr', 'tas']
input_types = ['raw', 'gma', 'grid', 'gmt']

# File names of the residual/detrended datasets, keyed by scenario.
_residual_files = {
    'historical': "MPI-ESM1-2-HR-LR_historical_residual_detrended.nc",
    'ssp126': "MPI-ESM1-2-HR-LR_ssp126_residual_detrended.nc",
    'ssp245': "MPI-ESM1-2-HR-LR_ssp245_residual_detrended.nc",
    'ssp585': "MPI-ESM1-2-HR-LR_ssp585_residual_detrended.nc",
}

# File names of the original (undetrended) datasets, keyed by scenario.
_original_files = {
    'historical': "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc",
    'ssp126': "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc",
    'ssp245': "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc",
    'ssp585': "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc",
}

print("Loading datasets...")
datasets = {
    scenario: xr.open_dataset(data_dir / file_name)
    for scenario, file_name in _residual_files.items()
}
datasets_original = {
    scenario: xr.open_dataset(data_dir / file_name)
    for scenario, file_name in _original_files.items()
}

# The normalization statistics are unpickled from a file in the data folder;
# only load pickles that come from a trusted source.
print("Loading normalization statistics...")
_stats_path = data_dir / "norm_stats_zscore_pixel_residual_detrended.pkl"
with open(_stats_path, 'rb') as f:
    norm_stats = pickle.load(f)

def evaluate_model(model_path, var, input_type, dataset, dataset_original, test_period):
    """Run one residual-prediction checkpoint over a test window.

    Uses the module-level ``device`` and ``norm_stats`` objects.

    Args:
        model_path: Checkpoint file; the loaded object must contain a
            ``'model_state_dict'`` entry.
        var: Base variable name (e.g. ``'pr'``); used to build dataset
            variable names and to index ``norm_stats``.
        input_type: ``'raw'`` for the ``{var}_lr_interp`` input, or one of
            ``'gma'``, ``'grid'``, ``'gmt'`` for the ``{var}_lr_detrend_*``
            inputs.
        dataset: Dataset providing the model input and ``{var}_residual``.
        dataset_original: Dataset providing ``{var}_lr_interp`` and
            ``{var}_hr``.
        test_period: ``(start, end)`` pair applied as ``slice(start, end)``
            on the ``time`` coordinate.

    Returns:
        Tuple ``(hr_pred, hr_original, lr_original, residual_pred,
        residual_true)`` of numpy arrays, where ``hr_pred`` is the
        denormalized residual prediction plus ``lr_original``.

    Raises:
        KeyError: If ``input_type`` is not one of the names above, or a
            required dataset variable / statistics entry is missing.
        ValueError: If the test window selects no samples.
    """
    # Rebuild the network with the fixed architecture and restore its weights.
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # One key serves both as the dataset-variable suffix and as the
    # norm_stats entry for the chosen input variant.
    if input_type == 'raw':
        input_key = 'lr_interp'
    else:
        # Map simplified names back to full detrend names
        detrend_map = {'gma': 'detrend_gma', 'grid': 'detrend_grid', 'gmt': 'detrend_gmt'}
        input_key = f'lr_{detrend_map[input_type]}'
    input_var = f"{var}_{input_key}"

    # Extract the test window once and reuse it for every field.
    window = slice(test_period[0], test_period[1])
    lr_data = dataset[input_var].sel(time=window).values
    residual_true = dataset[f"{var}_residual"].sel(time=window).values
    lr_original = dataset_original[f"{var}_lr_interp"].sel(time=window).values
    hr_original = dataset_original[f"{var}_hr"].sel(time=window).values

    # Fail early with a readable message instead of letting np.concatenate
    # complain about an empty list further down.
    if len(lr_data) == 0:
        raise ValueError(
            f"No samples of '{input_var}' found in test period "
            f"{test_period[0]}-{test_period[1]}"
        )

    # Apply zscore_pixel normalization to input
    lr_mean = norm_stats[var][input_key]['pixel_mean']
    lr_std = norm_stats[var][input_key]['pixel_std']
    lr_normalized = (lr_data - lr_mean) / (lr_std + 1e-8)

    # Predict in batches
    batch_size = 32
    n_samples = len(lr_normalized)
    predictions_norm = []

    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_normalized[i:i+batch_size]
            # unsqueeze(1) inserts the single channel axis the model is built for.
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            batch_pred = model(batch_tensor)
            predictions_norm.append(batch_pred.cpu().numpy())

    # Drop the channel axis again after stitching the batches together.
    predictions_norm = np.concatenate(predictions_norm, axis=0).squeeze(1)

    # Denormalize residual predictions
    residual_mean = norm_stats[var]['residual']['pixel_mean']
    residual_std = norm_stats[var]['residual']['pixel_std']
    residual_pred = predictions_norm * residual_std + residual_mean

    # Reconstruct full HR prediction by adding back LR_interp
    hr_pred = residual_pred + lr_original

    return hr_pred, hr_original, lr_original, residual_pred, residual_true

# Evaluate every (scenario, variable, input type) combination and collect the
# results as all_results[scenario][variable][name] -> DataArray.
_banner = "=" * 80
print("\n" + _banner)
print("EVALUATING RESIDUAL MODELS")
print(_banner)

all_results = {}

for scenario_name, dataset in datasets.items():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-" * 40)

    dataset_original = datasets_original[scenario_name]
    test_period = test_periods[scenario_name]
    time_window = slice(test_period[0], test_period[1])
    scenario_results = {}

    for var in variables:
        print(f"\n  Variable: {var}")

        # Coordinates for the prediction arrays come from the HR field of the
        # residual dataset, restricted to the test window along time.
        hr_field = dataset[f"{var}_hr"]
        time_coords = hr_field.sel(time=time_window).time
        lat_coords = hr_field.lat
        lon_coords = hr_field.lon

        # Reference fields are taken from the original dataset (full HR field,
        # not the residual) together with the interpolated LR input.
        hr_true = dataset_original[f"{var}_hr"].sel(time=time_window)
        lr_input = dataset_original[f"{var}_lr_interp"].sel(time=time_window)
        var_results = {'groundtruth': hr_true, 'input': lr_input}

        for input_type in input_types:
            # Checkpoint naming: "<var>_lr_to_residual.pth" for the raw input,
            # "<var>_lr_<input_type>_to_residual.pth" otherwise.
            stem = "lr" if input_type == 'raw' else f"lr_{input_type}"
            model_filename = f"{var}_{stem}_to_residual.pth"
            model_path = ckpt_dir / model_filename

            if not model_path.exists():
                print(f"    Model not found: {model_path}")
                continue

            # A failure in one model is reported and the loop moves on.
            try:
                print(f"    Evaluating {input_type}...", end=" ")

                hr_pred, hr_true_np, lr_orig, residual_pred, residual_true = evaluate_model(
                    model_path, var, input_type, dataset, dataset_original, test_period
                )

                # Wrap the raw prediction array with the coordinates above.
                pred_da = xr.DataArray(
                    hr_pred,
                    dims=['time', 'lat', 'lon'],
                    coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                    name=f'{var}_pred_{input_type}',
                )

                var_results[f'pred_{input_type}'] = pred_da
                print("Success")
            except Exception as e:
                print(f"Error: {e}")

        scenario_results[var] = var_results

    all_results[scenario_name] = scenario_results

# Write one netCDF file per (variable, scenario) pair.
_rule = "=" * 80
print("\n" + _rule)
print("SAVING RESULTS")
print(_rule)

output_dir = Path("evaluation_results_residual")
output_dir.mkdir(exist_ok=True)

for scenario_name, scenario_results in all_results.items():
    for var, var_results in scenario_results.items():
        # Ground truth and input are always stored; predictions only for the
        # input types that were actually evaluated.
        stored_keys = ['groundtruth', 'input']
        stored_keys += [
            f'pred_{input_type}'
            for input_type in input_types
            if f'pred_{input_type}' in var_results
        ]

        ds_result = xr.Dataset()
        for key in stored_keys:
            ds_result[key] = var_results[key]

        # Save to netCDF
        output_path = output_dir / f"{var}_evaluation_{scenario_name}.nc"
        ds_result.to_netcdf(output_path)
        print(f"Saved: {output_path}")

print("\nEvaluation complete!")

Loading datasets...
Loading normalization statistics...

EVALUATING RESIDUAL MODELS

HISTORICAL Scenario
----------------------------------------

  Variable: pr
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

  Variable: tas
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

SSP126 Scenario
----------------------------------------

  Variable: pr
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

  Variable: tas
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

SSP245 Scenario
----------------------------------------

  Variable: pr
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

  Variable: tas
    Evaluating raw... Success
    

In [3]:
# evaluate_pr.py

import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet

# ----------------------------
# Configuration
# ----------------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Names of the HR target variable, its base name (used to index the
# normalization statistics) and the interpolated LR input variable.
variable = 'pr_hr'
var_base = 'pr'
var_lr = 'pr_lr_interp'

# Normalization schemes; each one has its own checkpoint "pr_<name>.pth".
normalizations = [
    'none',
    'minmax_global',
    'minmax_pixel',
    'zscore_global',
    'zscore_pixel',
    'instance_zscore',
    'instance_minmax',
]

# (start, end) year labels per scenario; later used as slice(start, end)
# when selecting along the time coordinate.
test_periods = dict(
    historical=('2001', '2014'),
    ssp126=('2015', '2100'),
    ssp245=('2015', '2100'),
    ssp585=('2015', '2100'),
)

# ----------------------------
# Load data
# ----------------------------
print("Loading datasets...")
ds_hist = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc")
ds_ssp126 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp245 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp585 = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")

# Scenario name -> opened dataset, in the order the scenarios are evaluated.
datasets = dict(
    historical=ds_hist,
    ssp126=ds_ssp126,
    ssp245=ds_ssp245,
    ssp585=ds_ssp585,
)

# The normalization statistics are unpickled from a file in the data folder;
# only load pickles that come from a trusted source.
print("Loading normalization statistics...")
_stats_path = data_dir / "norm_stats.pkl"
with open(_stats_path, 'rb') as f:
    norm_stats = pickle.load(f)

_stat_variables = list(norm_stats.keys())
print("Normalization stats loaded for variables:", _stat_variables)

# ----------------------------
# Normalization helpers 
# ----------------------------
def apply_normalization(data, method, norm_stats, var_base, resolution='lr_interp'):
    """Normalize ``data`` with the requested scheme.

    Args:
        data: Array-like input (numpy array or an object exposing ``.values``).
        method: One of ``'none'``, ``'minmax_global'``, ``'minmax_pixel'``,
            ``'zscore_global'``, ``'zscore_pixel'``, ``'instance_zscore'``,
            ``'instance_minmax'``.
        norm_stats: Nested mapping ``norm_stats[var_base][resolution][key]``
            holding the ``global_*`` / ``pixel_*`` statistics. Not consulted
            for ``'none'`` and the ``instance_*`` methods.
        var_base: First-level key into ``norm_stats``.
        resolution: Second-level key into ``norm_stats``.

    Returns:
        ``data`` itself for ``'none'``; otherwise the normalized values.
        Min-max variants scale to roughly ``[-1, 1]``; z-score variants
        subtract the mean and divide by the standard deviation. The
        ``instance_*`` variants always return a numpy array, computing
        their statistics over axes 1 and 2 of each entry along axis 0
        (so ``data`` needs at least three dimensions).

    Raises:
        ValueError: If ``method`` is not one of the supported names
            (previously this silently returned ``None``).
        KeyError: If a required statistic is missing from ``norm_stats``.
    """
    eps = 1e-8  # keeps the divisions finite when a range or std is zero

    if method == 'none':
        return data

    if method in ('minmax_global', 'minmax_pixel'):
        scope = method.split('_')[1]  # 'global' or 'pixel'
        stats = norm_stats[var_base][resolution]
        data_min = stats[f'{scope}_min']
        data_max = stats[f'{scope}_max']
        return 2 * (data - data_min) / (data_max - data_min + eps) - 1

    if method in ('zscore_global', 'zscore_pixel'):
        scope = method.split('_')[1]  # 'global' or 'pixel'
        stats = norm_stats[var_base][resolution]
        data_mean = stats[f'{scope}_mean']
        data_std = stats[f'{scope}_std']
        return (data - data_mean) / (data_std + eps)

    if method in ('instance_zscore', 'instance_minmax'):
        # Per-sample statistics are computed on a plain numpy array.
        data_np = data.values if hasattr(data, 'values') else data
        if method == 'instance_zscore':
            data_mean = np.mean(data_np, axis=(1, 2), keepdims=True)
            data_std = np.std(data_np, axis=(1, 2), keepdims=True)
            return (data_np - data_mean) / (data_std + eps)
        data_min = np.min(data_np, axis=(1, 2), keepdims=True)
        data_max = np.max(data_np, axis=(1, 2), keepdims=True)
        return 2 * (data_np - data_min) / (data_max - data_min + eps) - 1

    raise ValueError(f"Unknown normalization method: {method!r}")

def denormalize_predictions(predictions, method, norm_stats, var_base, input_data=None):
    """Map normalized predictions back to the original scale.

    Args:
        predictions: Array of model outputs in normalized space.
        method: One of ``'none'``, ``'minmax_global'``, ``'minmax_pixel'``,
            ``'zscore_global'``, ``'zscore_pixel'``, ``'instance_zscore'``,
            ``'instance_minmax'``.
        norm_stats: Nested mapping; the ``norm_stats[var_base]['hr']``
            statistics are used for the global/pixel methods.
        var_base: First-level key into ``norm_stats``.
        input_data: Un-normalized input samples. Required for the
            ``instance_*`` methods, whose inverse transform uses per-sample
            statistics of this array taken over axes 1 and 2; ignored
            otherwise. Objects exposing ``.values`` are unwrapped first.

    Returns:
        ``predictions`` itself for ``'none'``; otherwise the rescaled array.

    Raises:
        ValueError: If ``method`` is not supported (previously this silently
            returned ``None``), or if an ``instance_*`` method is requested
            without ``input_data``.
        KeyError: If a required statistic is missing from ``norm_stats``.
    """
    if method == 'none':
        return predictions

    if method in ('minmax_global', 'minmax_pixel'):
        scope = method.split('_')[1]  # 'global' or 'pixel'
        hr_stats = norm_stats[var_base]['hr']
        hr_min = hr_stats[f'{scope}_min']
        hr_max = hr_stats[f'{scope}_max']
        return ((predictions + 1) / 2) * (hr_max - hr_min) + hr_min

    if method in ('zscore_global', 'zscore_pixel'):
        scope = method.split('_')[1]  # 'global' or 'pixel'
        hr_stats = norm_stats[var_base]['hr']
        hr_mean = hr_stats[f'{scope}_mean']
        hr_std = hr_stats[f'{scope}_std']
        return predictions * hr_std + hr_mean

    if method in ('instance_zscore', 'instance_minmax'):
        if input_data is None:
            raise ValueError(
                f"input_data is required to denormalize with method {method!r}"
            )
        # Same unwrapping as the forward normalization helper.
        input_np = input_data.values if hasattr(input_data, 'values') else input_data
        if method == 'instance_zscore':
            input_mean = np.mean(input_np, axis=(1, 2), keepdims=True)
            input_std = np.std(input_np, axis=(1, 2), keepdims=True)
            return predictions * input_std + input_mean
        input_min = np.min(input_np, axis=(1, 2), keepdims=True)
        input_max = np.max(input_np, axis=(1, 2), keepdims=True)
        return ((predictions + 1) / 2) * (input_max - input_min) + input_min

    raise ValueError(f"Unknown normalization method: {method!r}")

# ----------------------------
# Model evaluation
# ----------------------------
def evaluate_model(model_path, norm_method, dataset, test_period):
    """Run one checkpoint over a test window and return denormalized output.

    Uses the module-level ``device``, ``norm_stats``, ``variable``,
    ``var_lr`` and ``var_base`` settings, and the helpers
    ``apply_normalization`` / ``denormalize_predictions``.

    Args:
        model_path: Checkpoint file; the loaded object must contain a
            ``'model_state_dict'`` entry.
        norm_method: Normalization scheme name, applied to the input and
            inverted on the output.
        dataset: Dataset providing the ``var_lr`` input and the ``variable``
            target.
        test_period: ``(start, end)`` pair applied as ``slice(start, end)``
            on the ``time`` coordinate.

    Returns:
        Tuple ``(predictions_denorm, hr_data_raw, lr_data_raw)`` of numpy
        arrays.

    Raises:
        ValueError: If the test window selects no samples.
    """
    print(f"    Evaluating {norm_method}...")

    # Rebuild the network with the fixed architecture and restore its weights.
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Extract the test window once and reuse it for input and target.
    window = slice(test_period[0], test_period[1])
    lr_data_xr = dataset[var_lr].sel(time=window)
    hr_data_xr = dataset[variable].sel(time=window)

    lr_data_raw = lr_data_xr.values
    hr_data_raw = hr_data_xr.values

    # Fail early with a readable message instead of letting np.concatenate
    # complain about an empty list further down.
    if len(lr_data_raw) == 0:
        raise ValueError(
            f"No samples of '{var_lr}' found in test period "
            f"{test_period[0]}-{test_period[1]}"
        )

    # Apply normalization to input
    lr_data_norm = apply_normalization(lr_data_xr, norm_method, norm_stats, var_base)

    # Some schemes return an xarray object; the batching below needs numpy.
    if hasattr(lr_data_norm, 'values'):
        lr_data_norm = lr_data_norm.values

    # Predict in batches
    batch_size = 32
    n_samples = len(lr_data_norm)
    predictions = []

    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_data_norm[i:i+batch_size]
            # unsqueeze(1) inserts the single channel axis the model is built for.
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)

            batch_pred = model(batch_tensor)
            predictions.append(batch_pred.cpu().numpy())

    # Drop the channel axis again after stitching the batches together.
    predictions = np.concatenate(predictions, axis=0).squeeze(1)

    # The raw LR input is passed along because the instance_* schemes derive
    # their inverse transform from per-sample statistics of it.
    predictions_denorm = denormalize_predictions(
        predictions, norm_method, norm_stats, var_base, lr_data_raw
    )

    return predictions_denorm, hr_data_raw, lr_data_raw

# ----------------------------
# Main evaluation loop
# ----------------------------
_banner = "=" * 80
print("\n" + _banner)
print("EVALUATING PRECIPITATION (PR) MODELS")
print(_banner)

# all_results[scenario][name] -> DataArray (ground truth, input, predictions).
all_results = {}

for scenario_name, dataset in datasets.items():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-" * 40)

    test_period = test_periods[scenario_name]
    time_window = slice(test_period[0], test_period[1])

    # Reference fields for this scenario, restricted to the test window; the
    # prediction arrays reuse the coordinates of the HR field.
    hr_data = dataset[variable].sel(time=time_window)
    lr_data = dataset[var_lr].sel(time=time_window)
    time_coords = hr_data.time
    lat_coords = hr_data.lat
    lon_coords = hr_data.lon

    scenario_results = {'groundtruth': hr_data, 'input': lr_data}

    # One checkpoint per normalization scheme.
    for norm_method in normalizations:
        model_path = ckpt_dir / f"pr_{norm_method}.pth"

        if not model_path.exists():
            print(f"Model not found: {model_path}")
            continue

        # A failure in one model is reported and the loop moves on.
        try:
            predictions, groundtruth, inputs = evaluate_model(
                model_path, norm_method, dataset, test_period
            )

            # Wrap the raw prediction array with the coordinates above.
            pred_da = xr.DataArray(
                predictions,
                dims=['time', 'lat', 'lon'],
                coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                name=f'pr_pred_{norm_method}',
            )

            scenario_results[f'pred_{norm_method}'] = pred_da
            print("Success")
        except Exception as e:
            print(f"Error: {e}")

    all_results[scenario_name] = scenario_results

# ----------------------------
# Save results
# ----------------------------
_rule = "=" * 80
print("\n" + _rule)
print("SAVING RESULTS")
print(_rule)

output_dir = Path("evaluation_results")
output_dir.mkdir(exist_ok=True)

# One netCDF file per scenario.
for scenario_name, results in all_results.items():
    # Ground truth and input are always stored; predictions only for the
    # normalization schemes that were actually evaluated.
    stored_keys = ['groundtruth', 'input']
    stored_keys += [
        f'pred_{norm_method}'
        for norm_method in normalizations
        if f'pred_{norm_method}' in results
    ]

    ds_result = xr.Dataset()
    for key in stored_keys:
        ds_result[key] = results[key]

    output_path = output_dir / f"pr_evaluation_{scenario_name}.nc"
    ds_result.to_netcdf(output_path)
    print(f"Saved: {output_path}")

print("\nEvaluation complete!")

Loading datasets...
Loading normalization statistics...
Normalization stats loaded for variables: ['pr', 'tas', 'hurs', 'sfcWind']

EVALUATING PRECIPITATION (PR) MODELS

HISTORICAL Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP126 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Success
    Evaluating zscore_global...
Success
    Evaluating zscore_pixel...
Success
    Evaluating instance_zscore...
Success
    Evaluating instance_minmax...
Success

SSP245 Scenario
----------------------------------------
    Evaluating none...
Success
    Evaluating minmax_global...
Success
    Evaluating minmax_pixel...
Succe