In [4]:
# eval_xhr.py

import xarray as xr
import numpy as np
import torch
import pickle
from pathlib import Path
from unet import UNet
import time

# ----------------------------
# Configuration
# ----------------------------
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Input data, model checkpoints and evaluation outputs (paths are relative to the CWD).
data_dir, ckpt_dir = Path("data"), Path("ckpts")
output_dir = Path("evaluation_results_xhr")
output_dir.mkdir(exist_ok=True)

model_path = ckpt_dir.joinpath("xhr_model.pth")

# ----------------------------
# Helper Functions
# ----------------------------

def load_model(model_path):
    """Build the XHR UNet, restore its weights from ``model_path`` and return it in eval mode.

    The checkpoint is loaded with ``map_location`` set to the module-level ``device`` and
    must provide 'model_state_dict', 'best_val_loss' and 'epoch' entries. The last two are
    only printed; 'epoch' is reported as ``epoch + 1`` (presumably a 0-based index of the
    last saved epoch - confirm against the training script).

    The architecture arguments below are fixed here and must match the training run.
    """
    arch_kwargs = dict(in_channels=1, out_channels=1, initial_features=32, depth=5, dropout=0.2)
    net = UNet(**arch_kwargs)

    state = torch.load(model_path, map_location=device)
    net.load_state_dict(state['model_state_dict'])
    # Module.to / Module.eval both return the module itself, so they can be chained.
    net.to(device).eval()

    print(f"Model loaded from {model_path}")
    print(f"  Best val loss: {state['best_val_loss']:.6f}")
    print(f"  Trained for {state['epoch'] + 1} epochs")
    return net


def normalize_zscore_pixel(data, mean, std):
    """Standardize ``data`` per pixel as ``(data - mean) / (std + 1e-8)``.

    ``mean`` and ``std`` broadcast against ``data`` with NumPy semantics. The 1e-8 term
    avoids division by zero for zero-variance pixels. Inverse of
    ``denormalize_zscore_pixel``.
    """
    centered = data - mean
    safe_std = std + 1e-8
    return centered / safe_std


def denormalize_zscore_pixel(data, mean, std):
    """Map standardized ``data`` back to physical units: ``data * (std + 1e-8) + mean``.

    Uses the same 1e-8 epsilon as ``normalize_zscore_pixel`` so the two are exact inverses
    up to floating-point rounding. ``mean``/``std`` broadcast against ``data``.
    """
    scale = std + 1e-8
    return data * scale + mean


def predict_batched(model, data, batch_size=1):
    """Run ``model`` over ``data`` in mini-batches and return single-channel predictions.

    Args:
        model: ``torch.nn.Module`` called on float32 tensors of shape ``(B, 1, H, W)``.
            It is called as-is (no ``eval()`` here), so put it in eval mode beforehand.
        data: NumPy array of shape ``(N, H, W)`` or ``(N, 1, H, W)``. For 3-D input a
            channel axis is inserted at position 1.
        batch_size: Samples per forward pass; must be >= 1.

    Returns:
        NumPy array of shape ``(N, H', W')``: the model output with its channel axis
        removed, in the model's output dtype. For ``N == 0`` the model is not called and
        an empty float32 ``(0, H, W)`` array is returned (this assumes the model keeps the
        spatial shape, as the reconstruction step of this script already requires).

    Raises:
        ValueError: if ``batch_size < 1``, or if the model output does not have exactly
            one channel (raised by ``ndarray.squeeze(1)``).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    n_samples = data.shape[0]
    if n_samples == 0:
        # np.concatenate([]) would raise here; an empty result is the natural answer.
        print(f"    Processing {n_samples}/{n_samples}... Done!")
        return np.empty((0,) + tuple(data.shape[-2:]), dtype=np.float32)

    # Send inputs to wherever the model's weights live so the two always match;
    # parameter-free modules fall back to the CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Preallocated once the first batch reveals the output shape/dtype. This avoids
    # holding every batch in a list *and* a concatenated copy at the same time.
    output = None
    next_report = 0  # report progress whenever another multiple of 100 samples is crossed

    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            if start >= next_report:
                print(f"    Processing {start}/{n_samples}...")
                next_report = (start // 100 + 1) * 100

            batch = data[start:start + batch_size]
            if batch.ndim == 3:
                batch = np.expand_dims(batch, axis=1)

            batch_tensor = torch.tensor(batch, dtype=torch.float32).to(target_device)
            batch_pred = model(batch_tensor).cpu().numpy().squeeze(1)

            if output is None:
                output = np.empty((n_samples,) + batch_pred.shape[1:], dtype=batch_pred.dtype)
            output[start:start + batch_pred.shape[0]] = batch_pred

    print(f"    Processing {n_samples}/{n_samples}... Done!")
    return output


# ----------------------------
# Main
# ----------------------------
print("=" * 80)
print("XHR G6SULFUR DOWNSCALING")
print("=" * 80)

print("\nLoading data...")
load_start = time.time()

# Load test data (G6sulfur): the detrended model input and the interpolated CMIP6 field
# that the predicted residual is later added to.
test_input = xr.open_dataarray(data_dir / "cmip6/g6_test_detrend_xhr.nc")
test_cmip_interp = xr.open_dataarray(data_dir / "cmip6/g6_test_interp_xhr.nc")

# Load normalization stats.
# NOTE: pickle.load can execute arbitrary code; only point this at a stats file you produced.
with open(data_dir / "norm_stats_xhr.pkl", 'rb') as f:
    norm_stats = pickle.load(f)

load_time = time.time() - load_start
print(f"Data loaded in {load_time:.2f}s")
print(f"  test_input (detrended): {test_input.shape}")
print(f"  test_cmip_interp: {test_cmip_interp.shape}")

# Fail fast, before the long inference run, on inputs the later sections cannot use:
# - results are written with dims ('time', 'lat', 'lon') and test_input's coordinates.
# - the reconstruction adds the two fields as plain numpy arrays, which would silently
#   broadcast a size-1 axis, so dims and shape must agree exactly.
# - both normalization groups are needed; 'residual' is otherwise first read only after
#   inference has finished. (Assumes norm_stats is a dict of dicts, as indexed below.)
_expected_dims = ('time', 'lat', 'lon')
if test_input.dims != _expected_dims:
    raise ValueError(f"test_input must have dims {_expected_dims}, got {test_input.dims}")
if test_cmip_interp.dims != test_input.dims or test_cmip_interp.shape != test_input.shape:
    raise ValueError(
        "test_cmip_interp must match test_input in dims and shape: "
        f"{test_cmip_interp.dims} {test_cmip_interp.shape} vs "
        f"{test_input.dims} {test_input.shape}"
    )
_missing_stats = [
    f"{group}/{stat}"
    for group in ('input_detrend', 'residual')
    for stat in ('mean', 'std')
    if group not in norm_stats or stat not in norm_stats[group]
]
if _missing_stats:
    raise KeyError(f"norm_stats_xhr.pkl is missing entries: {_missing_stats}")

# Load model
print("\n" + "-" * 60)
model = load_model(model_path)

# ----------------------------
# Run Inference
# ----------------------------
print(f"\n{'=' * 60}")
print("RUNNING INFERENCE (G6sulfur 2015-2100)")
print('=' * 60)

# Everything up to and including the reconstruction is timed as "inference".
inference_start = time.time()

# Materialize both fields as plain numpy arrays.
test_input_np, test_cmip_interp_np = test_input.values, test_cmip_interp.values

# Standardize the detrended input with the stored per-pixel input statistics.
print("\nNormalizing input...")
test_input_norm = normalize_zscore_pixel(
    test_input_np,
    mean=norm_stats['input_detrend']['mean'],
    std=norm_stats['input_detrend']['std'],
)

# The network predicts the residual in normalized units, one sample per forward pass.
print("\nPredicting residual...")
residual_pred_norm = predict_batched(model, test_input_norm, batch_size=1)

# Map the predicted residual back to physical units with the residual statistics.
print("\nDenormalizing predictions...")
residual_pred = denormalize_zscore_pixel(
    residual_pred_norm,
    mean=norm_stats['residual']['mean'],
    std=norm_stats['residual']['std'],
)

# downscaled = interpolated CMIP6 field + predicted residual
print("\nReconstructing downscaled output...")
downscaled = test_cmip_interp_np + residual_pred

inference_time = time.time() - inference_start
print(f"\nInference completed in {inference_time:.2f}s")

# ----------------------------
# Save Results
# ----------------------------
print(f"\n{'-' * 60}")
print("Saving results...")

# All output variables share test_input's (time, lat, lon) grid.
coords = dict(
    time=test_input.time,
    lat=test_input.lat,
    lon=test_input.lon,
)

ds_output = xr.Dataset({
    var_name: xr.DataArray(values, coords=coords, dims=['time', 'lat', 'lon'])
    for var_name, values in (
        ('tas_downscaled', downscaled),
        ('tas_cmip6_interp', test_cmip_interp_np),
        ('tas_residual_pred', residual_pred),
    )
})

output_path = output_dir / "g6sulfur_downscaled_xhr.nc"
ds_output.to_netcdf(output_path)

print(f"Saved: {output_path}")
print(f"  tas_downscaled: {downscaled.shape}")
print(f"  tas_cmip6_interp: {test_cmip_interp_np.shape}")
print(f"  tas_residual_pred: {residual_pred.shape}")

# ----------------------------
# Summary
# ----------------------------
# The total is measured from the start of data loading (`load_start`). It therefore covers
# data loading, model restore, inference and writing the NetCDF file. `inference_time` covers
# only normalize -> predict -> denormalize -> reconstruct, so it is reported separately.
total_time = time.time() - load_start

print("\n" + "=" * 80)
print("COMPLETE!")
print("=" * 80)
print(f"Output: {output_path}")
print(f"Time range: {test_input.time.values[0]} to {test_input.time.values[-1]}")
print(f"Spatial resolution: {len(coords['lat'])} x {len(coords['lon'])}")
print(f"Inference time: {inference_time:.2f}s")
print(f"Total time: {total_time:.2f}s")

Using device: cuda
XHR G6SULFUR DOWNSCALING

Loading data...
Data loaded in 0.68s
  test_input (detrended): (1020, 721, 1440)
  test_cmip_interp: (1020, 721, 1440)

------------------------------------------------------------
Model loaded from ckpts/xhr_model.pth
  Best val loss: 0.612789
  Trained for 32 epochs

RUNNING INFERENCE (G6sulfur 2015-2100)

Normalizing input...

Predicting residual...
    Processing 0/1020...
    Processing 100/1020...
    Processing 200/1020...
    Processing 300/1020...
    Processing 400/1020...
    Processing 500/1020...
    Processing 600/1020...
    Processing 700/1020...
    Processing 800/1020...
    Processing 900/1020...
    Processing 1000/1020...
    Processing 1020/1020... Done!

Denormalizing predictions...

Reconstructing downscaled output...

Inference completed in 158.36s

------------------------------------------------------------
Saving results...
Saved: evaluation_results_xhr/g6sulfur_downscaled_xhr.nc
  tas_downscaled: (1020, 721, 1440