In [None]:
import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet
from scipy import stats

# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Test periods - ADDED g6sulfur
test_periods = {
    'historical': ('2001', '2014'),
    'ssp126': ('2015', '2100'),
    'ssp245': ('2015', '2100'), 
    'ssp585': ('2015', '2100'),
    'g6sulfur': ('2020', '2099')  # Added
}

# Variables and input types to evaluate
variables = ['pr', 'tas']
input_types = ['raw', 'gma', 'grid', 'gmt']  # Simplified names

# Load datasets - ADDED g6sulfur
print("Loading datasets...")
datasets = {
    'historical': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_residual_detrended.nc"),
    'ssp126': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_residual_detrended.nc"),
    'ssp245': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_residual_detrended.nc"),
    'ssp585': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_residual_detrended.nc"),
    'g6sulfur': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_g6sulfur_residual_detrended.nc")  # Added
}

# Load original datasets for undetrended versions - ADDED g6sulfur
datasets_original = {
    'historical': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc"),
    'ssp126': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc"),
    'ssp245': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc"),
    'ssp585': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc"),
    'g6sulfur': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_g6sulfur_r1i1p1f1_2020_2099_allvars.nc")  # Added
}

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats_zscore_pixel_residual_detrended.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

def evaluate_model(model_path, var, input_type, dataset, dataset_original, test_period):
    """Evaluate a single model on a dataset."""
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Get variable names based on input type
    if input_type == 'raw':
        input_var = f"{var}_lr_interp"
        input_key = 'lr_interp'
    else:
        # Map simplified names back to full detrend names
        detrend_map = {'gma': 'detrend_gma', 'grid': 'detrend_grid', 'gmt': 'detrend_gmt'}
        input_var = f"{var}_lr_{detrend_map[input_type]}"
        input_key = f'lr_{detrend_map[input_type]}'
    
    # Extract test data
    lr_data = dataset[input_var].sel(time=slice(test_period[0], test_period[1])).values
    residual_true = dataset[f"{var}_residual"].sel(time=slice(test_period[0], test_period[1])).values
    lr_original = dataset_original[f"{var}_lr_interp"].sel(time=slice(test_period[0], test_period[1])).values
    hr_original = dataset_original[f"{var}_hr"].sel(time=slice(test_period[0], test_period[1])).values
    
    # Apply zscore_pixel normalization to input
    lr_mean = norm_stats[var][input_key]['pixel_mean']
    lr_std = norm_stats[var][input_key]['pixel_std']
    lr_normalized = (lr_data - lr_mean) / (lr_std + 1e-8)
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_normalized)
    predictions_norm = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_normalized[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            batch_pred = model(batch_tensor)
            predictions_norm.append(batch_pred.cpu().numpy())
    
    predictions_norm = np.concatenate(predictions_norm, axis=0).squeeze(1)
    
    # Denormalize residual predictions
    residual_mean = norm_stats[var]['residual']['pixel_mean']
    residual_std = norm_stats[var]['residual']['pixel_std']
    residual_pred = predictions_norm * residual_std + residual_mean
    
    # Reconstruct full HR prediction by adding back LR_interp
    hr_pred = residual_pred + lr_original
    
    return hr_pred, hr_original, lr_original, residual_pred, residual_true

# Main evaluation loop
print("\n" + "="*80)
print("EVALUATING RESIDUAL MODELS")
print("="*80)

all_results = {}

for scenario_name in datasets.keys():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-"*40)
    
    dataset = datasets[scenario_name]
    dataset_original = datasets_original[scenario_name]
    test_period = test_periods[scenario_name]
    scenario_results = {}
    
    for var in variables:
        print(f"\n  Variable: {var}")
        var_results = {}
        
        # Get coordinates for creating xarray DataArrays
        time_coords = dataset[f"{var}_hr"].sel(time=slice(test_period[0], test_period[1])).time
        lat_coords = dataset[f"{var}_hr"].lat
        lon_coords = dataset[f"{var}_hr"].lon
        
        # Store ground truth (original HR, not residual)
        hr_true = dataset_original[f"{var}_hr"].sel(time=slice(test_period[0], test_period[1]))
        lr_input = dataset_original[f"{var}_lr_interp"].sel(time=slice(test_period[0], test_period[1]))
        
        var_results['groundtruth'] = hr_true
        var_results['input'] = lr_input
        
        for input_type in input_types:
            # Build model filename based on your naming convention
            if input_type == 'raw':
                model_filename = f"{var}_lr_to_residual.pth"
            else:
                model_filename = f"{var}_lr_{input_type}_to_residual.pth"
            
            model_path = ckpt_dir / model_filename
            
            if not model_path.exists():
                print(f"    Model not found: {model_path}")
                continue
            
            try:
                print(f"    Evaluating {input_type}...", end=" ")
                
                hr_pred, hr_true_np, lr_orig, residual_pred, residual_true = evaluate_model(
                    model_path, var, input_type, dataset, dataset_original, test_period
                )
                
                # Create xarray DataArray for predictions
                pred_da = xr.DataArray(
                    hr_pred,
                    coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                    dims=['time', 'lat', 'lon'],
                    name=f'{var}_pred_{input_type}'
                )
                
                var_results[f'pred_{input_type}'] = pred_da
                print("Success")
                
            except Exception as e:
                print(f"Error: {e}")
                continue
        
        scenario_results[var] = var_results
    
    all_results[scenario_name] = scenario_results

# Save results
print("\n" + "="*80)
print("SAVING RESULTS")
print("="*80)

output_dir = Path("evaluation_results_residual")
output_dir.mkdir(exist_ok=True)

for scenario_name, scenario_results in all_results.items():
    for var, var_results in scenario_results.items():
        ds_result = xr.Dataset()
        
        # Store ground truth and input
        ds_result['groundtruth'] = var_results['groundtruth']
        ds_result['input'] = var_results['input']
        
        # Store predictions
        for input_type in input_types:
            key = f'pred_{input_type}'
            if key in var_results:
                ds_result[key] = var_results[key]
        
        # Save to netCDF
        output_path = output_dir / f"{var}_evaluation_{scenario_name}.nc"
        ds_result.to_netcdf(output_path)
        print(f"Saved: {output_path}")

print("\nEvaluation complete!")

In [2]:
import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet
from scipy import stats

# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Test periods
test_periods = {
    'historical': ('2001', '2014'),
    'ssp126': ('2015', '2100'),
    'ssp245': ('2015', '2100'), 
    'ssp585': ('2015', '2100')
}

# Variables and input types to evaluate
variables = ['pr', 'tas']
input_types = ['raw', 'gma', 'grid', 'gmt']  # Simplified names

# Load datasets
print("Loading datasets...")
datasets = {
    'historical': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_residual_detrended.nc"),
    'ssp126': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_residual_detrended.nc"),
    'ssp245': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_residual_detrended.nc"),
    'ssp585': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_residual_detrended.nc")
}

# Load original datasets for undetrended versions
datasets_original = {
    'historical': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc"),
    'ssp126': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc"),
    'ssp245': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc"),
    'ssp585': xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")
}

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats_zscore_pixel_residual_detrended.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

def evaluate_model(model_path, var, input_type, dataset, dataset_original, test_period):
    """Evaluate a single model on a dataset."""
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Get variable names based on input type
    if input_type == 'raw':
        input_var = f"{var}_lr_interp"
        input_key = 'lr_interp'
    else:
        # Map simplified names back to full detrend names
        detrend_map = {'gma': 'detrend_gma', 'grid': 'detrend_grid', 'gmt': 'detrend_gmt'}
        input_var = f"{var}_lr_{detrend_map[input_type]}"
        input_key = f'lr_{detrend_map[input_type]}'
    
    # Extract test data
    lr_data = dataset[input_var].sel(time=slice(test_period[0], test_period[1])).values
    residual_true = dataset[f"{var}_residual"].sel(time=slice(test_period[0], test_period[1])).values
    lr_original = dataset_original[f"{var}_lr_interp"].sel(time=slice(test_period[0], test_period[1])).values
    hr_original = dataset_original[f"{var}_hr"].sel(time=slice(test_period[0], test_period[1])).values
    
    # Apply zscore_pixel normalization to input
    lr_mean = norm_stats[var][input_key]['pixel_mean']
    lr_std = norm_stats[var][input_key]['pixel_std']
    lr_normalized = (lr_data - lr_mean) / (lr_std + 1e-8)
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_normalized)
    predictions_norm = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_normalized[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            batch_pred = model(batch_tensor)
            predictions_norm.append(batch_pred.cpu().numpy())
    
    predictions_norm = np.concatenate(predictions_norm, axis=0).squeeze(1)
    
    # Denormalize residual predictions
    residual_mean = norm_stats[var]['residual']['pixel_mean']
    residual_std = norm_stats[var]['residual']['pixel_std']
    residual_pred = predictions_norm * residual_std + residual_mean
    
    # Reconstruct full HR prediction by adding back LR_interp
    hr_pred = residual_pred + lr_original
    
    return hr_pred, hr_original, lr_original, residual_pred, residual_true

# Main evaluation loop
print("\n" + "="*80)
print("EVALUATING RESIDUAL MODELS")
print("="*80)

all_results = {}

for scenario_name in datasets.keys():
    print(f"\n{scenario_name.upper()} Scenario")
    print("-"*40)
    
    dataset = datasets[scenario_name]
    dataset_original = datasets_original[scenario_name]
    test_period = test_periods[scenario_name]
    scenario_results = {}
    
    for var in variables:
        print(f"\n  Variable: {var}")
        var_results = {}
        
        # Get coordinates for creating xarray DataArrays
        time_coords = dataset[f"{var}_hr"].sel(time=slice(test_period[0], test_period[1])).time
        lat_coords = dataset[f"{var}_hr"].lat
        lon_coords = dataset[f"{var}_hr"].lon
        
        # Store ground truth (original HR, not residual)
        hr_true = dataset_original[f"{var}_hr"].sel(time=slice(test_period[0], test_period[1]))
        lr_input = dataset_original[f"{var}_lr_interp"].sel(time=slice(test_period[0], test_period[1]))
        
        var_results['groundtruth'] = hr_true
        var_results['input'] = lr_input
        
        for input_type in input_types:
            # Build model filename based on your naming convention
            if input_type == 'raw':
                model_filename = f"{var}_lr_to_residual.pth"
            else:
                model_filename = f"{var}_lr_{input_type}_to_residual.pth"
            
            model_path = ckpt_dir / model_filename
            
            if not model_path.exists():
                print(f"    Model not found: {model_path}")
                continue
            
            try:
                print(f"    Evaluating {input_type}...", end=" ")
                
                hr_pred, hr_true_np, lr_orig, residual_pred, residual_true = evaluate_model(
                    model_path, var, input_type, dataset, dataset_original, test_period
                )
                
                # Create xarray DataArray for predictions
                pred_da = xr.DataArray(
                    hr_pred,
                    coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
                    dims=['time', 'lat', 'lon'],
                    name=f'{var}_pred_{input_type}'
                )
                
                var_results[f'pred_{input_type}'] = pred_da
                print("Success")
                
            except Exception as e:
                print(f"Error: {e}")
                continue
        
        scenario_results[var] = var_results
    
    all_results[scenario_name] = scenario_results

# Save results
print("\n" + "="*80)
print("SAVING RESULTS")
print("="*80)

output_dir = Path("evaluation_results_residual")
output_dir.mkdir(exist_ok=True)

for scenario_name, scenario_results in all_results.items():
    for var, var_results in scenario_results.items():
        ds_result = xr.Dataset()
        
        # Store ground truth and input
        ds_result['groundtruth'] = var_results['groundtruth']
        ds_result['input'] = var_results['input']
        
        # Store predictions
        for input_type in input_types:
            key = f'pred_{input_type}'
            if key in var_results:
                ds_result[key] = var_results[key]
        
        # Save to netCDF
        output_path = output_dir / f"{var}_evaluation_{scenario_name}.nc"
        ds_result.to_netcdf(output_path)
        print(f"Saved: {output_path}")

print("\nEvaluation complete!")

Loading datasets...
Loading normalization statistics...

EVALUATING RESIDUAL MODELS

HISTORICAL Scenario
----------------------------------------

  Variable: pr
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

  Variable: tas
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

SSP126 Scenario
----------------------------------------

  Variable: pr
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

  Variable: tas
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

SSP245 Scenario
----------------------------------------

  Variable: pr
    Evaluating raw... Success
    Evaluating gma... Success
    Evaluating grid... Success
    Evaluating gmt... Success

  Variable: tas
    Evaluating raw... Success
    

In [3]:
import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet

# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_dir = Path("data")
ckpt_dir = Path("ckpts")

# Test period for G6sulfur
test_period_g6 = ('2020', '2099')

# Variable and input types
var = 'tas'
input_types = ['raw', 'gma', 'grid', 'gmt']

# Load datasets
print("Loading G6sulfur datasets...")
dataset = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_g6sulfur_residual_detrended.nc")
dataset_original = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_g6sulfur_r1i1p1f1_2020_2099_allvars.nc")

# Load normalization statistics
print("Loading normalization statistics...")
with open(data_dir / "norm_stats_zscore_pixel_residual_detrended.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

def evaluate_model(model_path, var, input_type, dataset, dataset_original, test_period):
    """Evaluate a single model on a dataset."""
    
    # Load model
    model = UNet(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    # Get variable names based on input type
    if input_type == 'raw':
        input_var = f"{var}_lr_interp"
        input_key = 'lr_interp'
    else:
        # Map simplified names back to full detrend names
        detrend_map = {'gma': 'detrend_gma', 'grid': 'detrend_grid', 'gmt': 'detrend_gmt'}
        input_var = f"{var}_lr_{detrend_map[input_type]}"
        input_key = f'lr_{detrend_map[input_type]}'
    
    # Extract test data
    lr_data = dataset[input_var].sel(time=slice(test_period[0], test_period[1])).values
    residual_true = dataset[f"{var}_residual"].sel(time=slice(test_period[0], test_period[1])).values
    lr_original = dataset_original[f"{var}_lr_interp"].sel(time=slice(test_period[0], test_period[1])).values
    hr_original = dataset_original[f"{var}_hr"].sel(time=slice(test_period[0], test_period[1])).values
    
    # Apply zscore_pixel normalization to input
    lr_mean = norm_stats[var][input_key]['pixel_mean']
    lr_std = norm_stats[var][input_key]['pixel_std']
    lr_normalized = (lr_data - lr_mean) / (lr_std + 1e-8)
    
    # Predict in batches
    batch_size = 32
    n_samples = len(lr_normalized)
    predictions_norm = []
    
    with torch.no_grad():
        for i in range(0, n_samples, batch_size):
            batch = lr_normalized[i:i+batch_size]
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(device)
            batch_pred = model(batch_tensor)
            predictions_norm.append(batch_pred.cpu().numpy())
    
    predictions_norm = np.concatenate(predictions_norm, axis=0).squeeze(1)
    
    # Denormalize residual predictions
    residual_mean = norm_stats[var]['residual']['pixel_mean']
    residual_std = norm_stats[var]['residual']['pixel_std']
    residual_pred = predictions_norm * residual_std + residual_mean
    
    # Reconstruct full HR prediction by adding back LR_interp
    hr_pred = residual_pred + lr_original
    
    return hr_pred, hr_original, lr_original, residual_pred, residual_true

# Main evaluation for G6sulfur
print("\n" + "="*80)
print("EVALUATING G6SULFUR RESIDUAL MODELS")
print("="*80)

g6_results = {}

# Get coordinates for creating xarray DataArrays
time_coords = dataset[f"{var}_hr"].sel(time=slice(test_period_g6[0], test_period_g6[1])).time
lat_coords = dataset[f"{var}_hr"].lat
lon_coords = dataset[f"{var}_hr"].lon

# Store ground truth (original HR, not residual)
hr_true = dataset_original[f"{var}_hr"].sel(time=slice(test_period_g6[0], test_period_g6[1]))
lr_input = dataset_original[f"{var}_lr_interp"].sel(time=slice(test_period_g6[0], test_period_g6[1]))

g6_results['groundtruth'] = hr_true
g6_results['input'] = lr_input

for input_type in input_types:
    # Build model filename
    if input_type == 'raw':
        model_filename = f"{var}_lr_to_residual.pth"
    else:
        model_filename = f"{var}_lr_{input_type}_to_residual.pth"
    
    model_path = ckpt_dir / model_filename
    
    if not model_path.exists():
        print(f"Model not found: {model_path}")
        continue
    
    try:
        print(f"Evaluating {input_type}...", end=" ")
        
        hr_pred, hr_true_np, lr_orig, residual_pred, residual_true = evaluate_model(
            model_path, var, input_type, dataset, dataset_original, test_period_g6
        )
        
        # Create xarray DataArray for predictions
        pred_da = xr.DataArray(
            hr_pred,
            coords={'time': time_coords, 'lat': lat_coords, 'lon': lon_coords},
            dims=['time', 'lat', 'lon'],
            name=f'{var}_pred_{input_type}'
        )
        
        g6_results[f'pred_{input_type}'] = pred_da
        print("Success")
        
    except Exception as e:
        print(f"Error: {e}")
        continue

# Save results
print("\n" + "="*80)
print("SAVING G6SULFUR RESULTS")
print("="*80)

output_dir = Path("evaluation_results_residual")
output_dir.mkdir(exist_ok=True)

ds_result = xr.Dataset()

# Store ground truth and input
ds_result['groundtruth'] = g6_results['groundtruth']
ds_result['input'] = g6_results['input']

# Store predictions
for input_type in input_types:
    key = f'pred_{input_type}'
    if key in g6_results:
        ds_result[key] = g6_results[key]

# Save to netCDF
output_path = output_dir / f"{var}_evaluation_g6sulfur.nc"
ds_result.to_netcdf(output_path)
print(f"   Saved: {output_path}")
print(f"   Variables: {list(ds_result.data_vars)}")
print(f"   Time range: {test_period_g6[0]} to {test_period_g6[1]}")

print("\nG6sulfur residual evaluation complete!")

Loading G6sulfur datasets...
Loading normalization statistics...

EVALUATING G6SULFUR RESIDUAL MODELS
Evaluating raw... Success
Evaluating gma... Success
Evaluating grid... Success
Evaluating gmt... Success

SAVING G6SULFUR RESULTS
   Saved: evaluation_results_residual/tas_evaluation_g6sulfur.nc
   Variables: ['groundtruth', 'input', 'pred_raw', 'pred_gma', 'pred_grid', 'pred_gmt']
   Time range: 2020 to 2099

G6sulfur residual evaluation complete!
