In [2]:
import xarray as xr
import numpy as np

# Open the G6solar file; it carries both the interpolated LR field (tas_lr_interp) and the HR field (tas_hr)
ds_g6solar = xr.open_dataset("data/MPI-ESM1-2-HR-LR_G6solar_r1i1p1f1_allvars.nc")
g6solar_lr_interp, g6solar_hr = ds_g6solar['tas_lr_interp'], ds_g6solar['tas_hr']

# Quick sanity report: both fields should share the same (time, lat, lon) shape
for label, field in (("LR interp", g6solar_lr_interp), ("HR", g6solar_hr)):
    print(f"{label} shape: {field.shape}")
print(f"Time range: {g6solar_lr_interp.time.values[0]} to {g6solar_lr_interp.time.values[-1]}")

# Pixelwise linear detrending function
def detrend_pixelwise(data):
    """Remove a least-squares linear trend in time from every grid point.

    The line is fitted independently at each grid point against the time index
    0..n_time-1, so the slope is in units per time *step*, not per calendar unit.
    The per-pixel mean is removed as well, so every detrended series has zero mean.

    Parameters
    ----------
    data : xarray.DataArray or numpy.ndarray
        Array whose first axis is time. Any number of trailing axes is accepted,
        e.g. (time, lat, lon). A pixel with a NaN anywhere in its series comes
        back all-NaN.

    Returns
    -------
    xarray.DataArray or numpy.ndarray
        The residuals ``data - trend``, with the same shape as ``data``.
        A DataArray input gives a DataArray carrying its coords, dims and attrs.
        A plain ndarray input gives a plain ndarray.

    Raises
    ------
    ValueError
        If ``data`` has no time steps.
    """
    is_plain_array = isinstance(data, np.ndarray)
    data_np = data if is_plain_array else data.values
    n_time = data_np.shape[0]
    if n_time == 0:
        raise ValueError("detrend_pixelwise: data has no time steps")

    t = np.arange(n_time).astype(float)
    t_mean = np.mean(t)
    t_centered = t - t_mean
    denominator = np.sum(t_centered**2)

    # Flatten all non-time axes so every pixel is one column.
    data_reshaped = data_np.reshape(n_time, -1)
    data_mean = np.mean(data_reshaped, axis=0)
    data_centered = data_reshaped - data_mean

    if denominator > 0:
        # OLS slope per column: sum(tc * yc) / sum(tc^2)
        slopes = np.sum(t_centered[:, np.newaxis] * data_centered, axis=0) / denominator
    else:
        # A single time step has no defined slope (0/0); remove only the mean
        # instead of returning all-NaN.
        slopes = np.zeros_like(data_mean)
    intercepts = data_mean - slopes * t_mean

    trends = intercepts[np.newaxis, :] + slopes[np.newaxis, :] * t[:, np.newaxis]
    detrended = (data_reshaped - trends).reshape(data_np.shape)

    if is_plain_array:
        return detrended
    return xr.DataArray(detrended, coords=data.coords, dims=data.dims, attrs=data.attrs)

# Residual-model input: strip each pixel's linear trend from the interpolated LR field
print("\nDetrending LR interp...")
g6solar_lr_detrend = detrend_pixelwise(g6solar_lr_interp)
nan_percent = g6solar_lr_detrend.isnull().sum().values / g6solar_lr_detrend.size * 100
print(f"LR detrended shape: {g6solar_lr_detrend.shape}")
print(f"LR detrended NaN: {nan_percent:.2f}%")

# Bundle raw LR, detrended LR and HR into the single file the downstream cells read
processed_path = "data/g6solar_processed.nc"
ds_out = xr.Dataset({
    'tas_lr_interp': g6solar_lr_interp,
    'tas_lr_detrend': g6solar_lr_detrend,
    'tas_hr': g6solar_hr,
})
ds_out.to_netcdf(processed_path)
print(f"\nSaved: {processed_path}")

# Re-open what was written and report shape and value range per variable
print("\n=== Verification ===")
with xr.open_dataset(processed_path) as verify:
    for label, var_name in (("LR interp", 'tas_lr_interp'),
                            ("LR detrend", 'tas_lr_detrend'),
                            ("HR", 'tas_hr')):
        field = verify[var_name]
        print(f"{label}: {field.shape}, range: [{float(field.min()):.2f}, {float(field.max()):.2f}]°C")

ds_g6solar.close()
print("\nDone!")

LR interp shape: (1020, 192, 384)
HR shape: (1020, 192, 384)
Time range: 2015-01-16T12:00:00.000000000 to 2099-12-16T12:00:00.000000000

Detrending LR interp...
LR detrended shape: (1020, 192, 384)
LR detrended NaN: 0.00%

Saved: data/g6solar_processed.nc

=== Verification ===
LR interp: (1020, 192, 384), range: [-73.07, 46.55]°C
LR detrend: (1020, 192, 384), range: [-30.31, 30.46]°C
HR: (1020, 192, 384), range: [-72.24, 48.20]°C

Done!


In [3]:
import os

# Pick the GPU *before* torch (or anything that may import it, such as `unet`) is loaded.
# CUDA reads CUDA_VISIBLE_DEVICES once, when its runtime initialises in this process,
# and ignores later changes. In a kernel that already initialised CUDA, a restart is
# still needed for a new value to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import pickle
from pathlib import Path

import numpy as np
import torch
import xarray as xr

from unet import UNet

# ----------------------------
# Configuration
# ----------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

data_dir = Path("data")
ckpt_dir = Path("ckpts")

# ----------------------------
# Helper Functions
# ----------------------------
def load_model(model_path, in_channels=1, out_channels=1, initial_features=32,
               depth=5, dropout=0.2):
    """Rebuild a UNet, load its trained weights and return it ready for inference.

    The architecture keywords default to the configuration that both checkpoints in
    this notebook are loaded with. Pass other values for checkpoints trained with a
    different architecture; mismatched layer names or shapes make
    ``load_state_dict`` raise.

    Parameters
    ----------
    model_path : str or pathlib.Path
        File written by ``torch.save``. Either a dict holding the weights under
        ``'model_state_dict'``, or a bare state dict.
    in_channels, out_channels, initial_features, depth, dropout
        Forwarded unchanged to ``UNet``.

    Returns
    -------
    UNet
        The model, moved to the module-level ``device`` and switched to ``eval()``.
    """
    model = UNet(in_channels=in_channels, out_channels=out_channels,
                 initial_features=initial_features, depth=depth, dropout=dropout)
    # map_location remaps tensors saved on another device onto `device`.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        # Fall back to treating the file as a bare state dict.
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference mode for layers such as dropout.
    model.eval()
    return model

def predict_batched(model, data, batch_size=32, target_device=None):
    """Run ``model`` over ``data`` in mini-batches without tracking gradients.

    Parameters
    ----------
    model : callable, typically torch.nn.Module
        Network taking a single input channel. A channel axis of size 1 is added
        to every batch, and axis 1 of the stacked output is squeezed away, so the
        model is expected to emit one output channel.
    data : array-like
        Samples stacked along axis 0, e.g. ``(n_samples, lat, lon)``.
    batch_size : int
        Samples per forward pass; must be >= 1.
    target_device : torch.device, optional
        Device the batches are sent to. Defaults to the module-level ``device``.

    Returns
    -------
    numpy.ndarray
        Model outputs for all samples, concatenated along axis 0 with the channel
        axis removed.

    Raises
    ------
    ValueError
        If ``data`` is empty or ``batch_size`` < 1.
    """
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("predict_batched: `data` is empty, nothing to predict")
    if batch_size < 1:
        raise ValueError(f"predict_batched: batch_size must be >= 1, got {batch_size}")
    if target_device is None:
        # The module-level device chosen in the configuration section.
        target_device = device

    predictions = []
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            batch = data[start:start + batch_size]
            # (B, ...) -> (B, 1, ...): add the single input channel.
            batch_tensor = torch.tensor(batch, dtype=torch.float32).unsqueeze(1).to(target_device)
            predictions.append(model(batch_tensor).cpu().numpy())

    # Stack all batches and drop the singleton output-channel axis.
    return np.concatenate(predictions, axis=0).squeeze(1)

# ----------------------------
# Load Data
# ----------------------------
print("Loading G6solar data...")
ds_g6solar = xr.open_dataset(data_dir / "g6solar_processed.nc")

g6solar_lr_interp = ds_g6solar['tas_lr_interp']
g6solar_lr_detrend = ds_g6solar['tas_lr_detrend']
g6solar_hr = ds_g6solar['tas_hr']

print(f"LR interp shape: {g6solar_lr_interp.shape}")
print(f"LR detrend shape: {g6solar_lr_detrend.shape}")
print(f"HR (ground truth) shape: {g6solar_hr.shape}")
print(f"Time range: {g6solar_lr_interp.time.values[0]} to {g6solar_lr_interp.time.values[-1]}")

# Load normalization statistics (from training)
print("\nLoading normalization statistics...")
# NOTE(review): pickle.load can execute arbitrary code -- only load stats files produced locally.
with open(data_dir / "norm_stats.pkl", 'rb') as f:
    norm_stats_direct = pickle.load(f)

with open(data_dir / "norm_stats_zscore_pixel_residual_detrended.pkl", 'rb') as f:
    norm_stats_residual = pickle.load(f)

print(f"Direct norm stats keys: {list(norm_stats_direct['tas'].keys())}")
print(f"Residual norm stats keys: {list(norm_stats_residual['tas'].keys())}")

# Both predictions are written on the time axis and grid of the interpolated LR input
coords = {
    'time': g6solar_lr_interp.time,
    'lat': g6solar_lr_interp.lat,
    'lon': g6solar_lr_interp.lon
}

# ----------------------------
# Method 1: Direct (zscore_pixel, no detrending)
# Input: interpolated LR -> Predict: HR directly
# ----------------------------
print("\n" + "="*60)
print("Method 1: Direct prediction (zscore_pixel)")
print("="*60)

# Load direct model
model_direct = load_model(ckpt_dir / "tas_zscore_pixel.pth")

# Normalize input using training stats (1e-8 avoids division by zero at zero-variance pixels)
lr_mean = norm_stats_direct['tas']['lr_interp']['pixel_mean']
lr_std = norm_stats_direct['tas']['lr_interp']['pixel_std']
g6solar_norm = (g6solar_lr_interp.values - lr_mean) / (lr_std + 1e-8)

print("Predicting...")
pred_direct_norm = predict_batched(model_direct, g6solar_norm)

# Denormalize using HR stats: the model outputs HR in HR z-score space
hr_mean = norm_stats_direct['tas']['hr']['pixel_mean']
hr_std = norm_stats_direct['tas']['hr']['pixel_std']
pred_direct = pred_direct_norm * hr_std + hr_mean

pred_direct_da = xr.DataArray(pred_direct, coords=coords, dims=['time', 'lat', 'lon'])
print(f"Direct prediction shape: {pred_direct_da.shape}")
print(f"Direct prediction range: [{float(pred_direct_da.min()):.2f}, {float(pred_direct_da.max()):.2f}]°C")

# ----------------------------
# Method 2: Residual (zscore_pixel, with detrending)
# Input: detrended LR -> Predict: residual -> Add to interp LR
# ----------------------------
print("\n" + "="*60)
print("Method 2: Residual prediction (zscore_pixel + detrend)")
print("="*60)

# Load residual model
model_residual = load_model(ckpt_dir / "tas_lr_grid_to_residual.pth")

# Normalize detrended input using training stats (1e-8 avoids division by zero)
lr_detrend_mean = norm_stats_residual['tas']['lr_detrend_grid']['pixel_mean']
lr_detrend_std = norm_stats_residual['tas']['lr_detrend_grid']['pixel_std']
g6solar_detrend_norm = (g6solar_lr_detrend.values - lr_detrend_mean) / (lr_detrend_std + 1e-8)

print("Predicting residual...")
pred_residual_norm = predict_batched(model_residual, g6solar_detrend_norm)

# Denormalize residual using training stats
residual_mean = norm_stats_residual['tas']['residual']['pixel_mean']
residual_std = norm_stats_residual['tas']['residual']['pixel_std']
pred_residual = pred_residual_norm * residual_std + residual_mean

# Add residual to interpolated input to get final HR prediction
pred_final = pred_residual + g6solar_lr_interp.values

pred_residual_da = xr.DataArray(pred_final, coords=coords, dims=['time', 'lat', 'lon'])
print(f"Residual prediction shape: {pred_residual_da.shape}")
print(f"Residual prediction range: [{float(pred_residual_da.min()):.2f}, {float(pred_residual_da.max()):.2f}]°C")

# ----------------------------
# Save Results
# ----------------------------
print("\n" + "="*60)
print("Saving results...")
print("="*60)

ds_results = xr.Dataset({
    'groundtruth': g6solar_hr,
    'input': g6solar_lr_interp,
    'input_detrend': g6solar_lr_detrend,
    'pred_zscore_pixel': pred_direct_da,
    'pred_grid': pred_residual_da
})

output_path = data_dir / "g6solar_downscaled_results.nc"
ds_results.to_netcdf(output_path)
print(f"Saved: {output_path}")

# Release the handle on g6solar_processed.nc, mirroring ds_g6solar.close() in the
# preprocessing cell. Left open, it can make re-running that cell (which overwrites
# this file) fail in the same kernel with the netCDF4/HDF5 backend.
# NOTE(review): the lazily loaded g6solar_* variables above are presumably reopened
# on demand by xarray if read again -- confirm if a later cell relies on them.
ds_g6solar.close()


Using device: cuda
Loading G6solar data...
LR interp shape: (1020, 192, 384)
LR detrend shape: (1020, 192, 384)
HR (ground truth) shape: (1020, 192, 384)
Time range: 2015-01-16T12:00:00.000000000 to 2099-12-16T12:00:00.000000000

Loading normalization statistics...
Direct norm stats keys: ['hr', 'lr_interp']
Residual norm stats keys: ['hr', 'lr_interp', 'residual', 'lr_detrend_gma', 'lr_detrend_grid', 'lr_detrend_gmt']

Method 1: Direct prediction (zscore_pixel)
Predicting...
Direct prediction shape: (1020, 192, 384)
Direct prediction range: [-68.22, 42.57]°C

Method 2: Residual prediction (zscore_pixel + detrend)
Predicting residual...
Residual prediction shape: (1020, 192, 384)
Residual prediction range: [-70.41, 48.68]°C

Saving results...
Saved: data/g6solar_downscaled_results.nc


In [4]:
import numpy as np
import xarray as xr
from pathlib import Path
from multiprocessing import Pool
import time

# ----------------------------
# Configuration
# ----------------------------
# Historical window used to calibrate QDM (string bounds for .sel(time=slice(...)))
train_period = ('1850', '2000')
# Worker processes for the per-grid-point QDM pool
n_cpus = 16

data_dir = Path("data")
output_dir = Path("evaluation_results")
output_dir.mkdir(exist_ok=True)

# ----------------------------
# QDM Functions 
# ----------------------------
def apply_qdm_single_point_additive(args):
    """Apply additive quantile delta mapping (QDM) to one grid point's time series.

    Implements additive QDM (Cannon et al., 2015) calibrated separately for each
    calendar month:

        tau(t)   = non-exceedance probability of x_scen(t) within its month,
                   using the Hazen plotting position (rank - 0.5) / n
        delta(t) = x_scen(t) - F_lr_train^{-1}(tau(t))
        y(t)     = F_hr_train^{-1}(tau(t)) + delta(t)

    The LR model's change at each quantile is added on top of the HR training
    climatology at that quantile (suited to temperature).

    Parameters
    ----------
    args : tuple
        ``(lat_idx, lon_idx, lr_train, hr_train, lr_scenario, train_months,
        scenario_months)``, packed into one tuple for use with ``Pool.map``.
        - ``lr_train`` / ``hr_train``: 1-D training series sharing the month
          labels ``train_months``.
        - ``lr_scenario``: the 1-D series to correct, labelled by
          ``scenario_months`` (integers 1-12).

    Returns
    -------
    tuple
        ``(lat_idx, lon_idx, result)``, where ``result`` is a float array as long
        as ``lr_scenario``. Entries whose month label is outside 1..12 stay 0.

    Raises
    ------
    ValueError
        If a month present in the scenario has no training samples.
    """
    lat_idx, lon_idx, lr_train, hr_train, lr_scenario, train_months, scenario_months = args

    result = np.zeros(len(lr_scenario))

    for month in range(1, 13):
        scenario_mask = scenario_months == month
        lr_scenario_month = lr_scenario[scenario_mask]
        if len(lr_scenario_month) == 0:
            continue

        train_mask = train_months == month
        lr_month_train = lr_train[train_mask]
        hr_month_train = hr_train[train_mask]
        if len(lr_month_train) == 0:
            raise ValueError(
                f"apply_qdm_single_point_additive: no training samples for month {month} "
                f"at (lat_idx={lat_idx}, lon_idx={lon_idx})"
            )

        n_scenario = len(lr_scenario_month)
        # Double argsort turns values into 1-based ranks (ties broken by position).
        ranks = np.argsort(np.argsort(lr_scenario_month)) + 1
        p_values = (ranks - 0.5) / n_scenario
        p_values = np.clip(p_values, 1e-6, 1 - 1e-6)

        q_lr_hist = np.quantile(lr_month_train, p_values)
        q_hr_hist = np.quantile(hr_month_train, p_values)

        # Additive delta relative to the *actual* scenario value, as in canonical QDM.
        # Re-estimating it with np.quantile(lr_scenario_month, p_values) is not an
        # identity: np.quantile's default linear (type 7) interpolation does not
        # invert the Hazen positions used for p_values, which pulls the extremes
        # toward the centre.
        delta = lr_scenario_month - q_lr_hist
        result[scenario_mask] = q_hr_hist + delta

    return lat_idx, lon_idx, result


def apply_qdm_scenario(lr_train, hr_train, lr_scenario, n_cpus=16):
    """Run additive QDM independently at every (lat, lon) point of ``lr_scenario``.

    Each grid point's LR/HR training series and LR scenario series are sent to a
    worker process. The worker (``apply_qdm_single_point_additive``) corrects them,
    and the results are reassembled into a (time, lat, lon) field.

    Parameters
    ----------
    lr_train, hr_train : xarray.DataArray
        Training fields indexed (time, lat, lon), assumed to share a time axis.
        The calendar months are read from ``lr_train`` and used for both.
    lr_scenario : xarray.DataArray
        (time, lat, lon) field to downscale.
    n_cpus : int
        Number of worker processes.

    Returns
    -------
    xarray.DataArray
        Float64 corrected field on ``lr_scenario``'s time/lat/lon coordinates.

    Notes
    -----
    The worker is sent to the pool by reference. Defining it in the notebook works
    with the ``fork`` start method. NOTE(review): this likely fails under ``spawn``
    (e.g. macOS/Windows) -- confirm.
    """
    lr_train_values = lr_train.values
    hr_train_values = hr_train.values
    scenario_values = lr_scenario.values

    months_train = lr_train.time.dt.month.values
    months_scenario = lr_scenario.time.dt.month.values

    n_time = scenario_values.shape[0]
    n_lat = scenario_values.shape[1]
    n_lon = scenario_values.shape[2]

    # One task per grid point, row-major over (lat, lon); each carries 1-D column views.
    tasks = [
        (lat, lon,
         lr_train_values[:, lat, lon],
         hr_train_values[:, lat, lon],
         scenario_values[:, lat, lon],
         months_train,
         months_scenario)
        for lat in range(n_lat)
        for lon in range(n_lon)
    ]

    with Pool(n_cpus) as pool:
        per_point = pool.map(apply_qdm_single_point_additive, tasks)

    downscaled = np.zeros((n_time, n_lat, n_lon))
    for lat, lon, series in per_point:
        downscaled[:, lat, lon] = series

    return xr.DataArray(
        downscaled,
        dims=['time', 'lat', 'lon'],
        coords={'time': lr_scenario.time, 'lat': lr_scenario.lat, 'lon': lr_scenario.lon},
    )

# ----------------------------
# Load Data
# ----------------------------
print("Loading datasets...")

# Historical data for training
ds_hist = xr.open_dataset(data_dir / "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc")

# G6solar data
ds_g6solar = xr.open_dataset(data_dir / "g6solar_processed.nc")

print(f"Historical shape: {ds_hist['tas_lr_interp'].shape}")
print(f"G6solar LR interp shape: {ds_g6solar['tas_lr_interp'].shape}")
print(f"G6solar HR shape: {ds_g6solar['tas_hr'].shape}")

# ----------------------------
# Prepare Training Data
# ----------------------------
print("\nPreparing training data...")
# String bounds are inclusive of whole years with xarray's datetime slicing
lr_train = ds_hist['tas_lr_interp'].sel(time=slice(train_period[0], train_period[1]))
hr_train = ds_hist['tas_hr'].sel(time=slice(train_period[0], train_period[1]))

print(f"Training period: {train_period[0]} - {train_period[1]}")
print(f"LR train shape: {lr_train.shape}")
print(f"HR train shape: {hr_train.shape}")

# ----------------------------
# Apply QDM to G6solar
# ----------------------------
print("\n" + "="*60)
print("APPLYING QDM FOR G6SOLAR TAS DOWNSCALING")
print("="*60)

lr_scenario = ds_g6solar['tas_lr_interp']
hr_true = ds_g6solar['tas_hr']

print(f"G6solar time range: {lr_scenario.time.values[0]} to {lr_scenario.time.values[-1]}")

# perf_counter is monotonic and meant for measuring elapsed intervals
start_time = time.perf_counter()
hr_qdm = apply_qdm_scenario(lr_train, hr_train, lr_scenario, n_cpus=n_cpus)
elapsed_time = time.perf_counter() - start_time

print(f"QDM completed in {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")

# ----------------------------
# Save Results
# ----------------------------
print("\n" + "="*60)
print("Saving results...")
print("="*60)

ds_out = xr.Dataset({
    'groundtruth': hr_true,
    'input': lr_scenario,
    'pred_qdm': hr_qdm,
})

output_path = output_dir / "tas_evaluation_g6solar_qdm.nc"
ds_out.to_netcdf(output_path)
print(f"Saved: {output_path}")

# Release both input files once the results are on disk, as the preprocessing cell
# does for its inputs. Otherwise g6solar_processed.nc stays open and can block
# re-running the cell that overwrites it.
# NOTE(review): ds_out's lazily loaded variables are presumably reopened on demand
# by xarray if their values are read later -- confirm.
ds_hist.close()
ds_g6solar.close()



Loading datasets...
Historical shape: (1980, 192, 384)
G6solar LR interp shape: (1020, 192, 384)
G6solar HR shape: (1020, 192, 384)

Preparing training data...
Training period: 1850 - 2000
LR train shape: (1812, 192, 384)
HR train shape: (1812, 192, 384)

APPLYING QDM FOR G6SOLAR TAS DOWNSCALING
G6solar time range: 2015-01-16T12:00:00.000000000 to 2099-12-16T12:00:00.000000000
QDM completed in 34.69 seconds (0.58 minutes)

Saving results...
Saved: evaluation_results/tas_evaluation_g6solar_qdm.nc


In [5]:
# Display the QDM dataset (groundtruth / input / pred_qdm) written to
# tas_evaluation_g6solar_qdm.nc in the previous cell
ds_out