In [2]:
import numpy as np
import xarray as xr
from pathlib import Path
from multiprocessing import Pool
import time

# ----------------------------
# Configuration
# ----------------------------
# Directory the input NetCDF files are read from / directory results are written to
data_dir = Path("data")
output_dir = Path("evaluation_results")

# (start, end) labels used as slice bounds when selecting the training data
train_period = ('1850', '2000')

# Number of worker processes handed to the multiprocessing pool
n_cpus = 16

# Variables to process, mapped to the kind of QDM delta applied to each:
#   tas -> additive (absolute changes), pr -> multiplicative (relative changes)
variables = dict(
    tas='additive',
    pr='multiplicative',
)

# ----------------------------
# QDM Functions 
# ----------------------------
def apply_qdm_single_point_additive(args):
    """Quantile delta mapping for one grid cell with ADDITIVE deltas (temperature).

    Parameters
    ----------
    args : tuple
        ``(lat_idx, lon_idx, lr_train, hr_train, lr_scenario, train_months,
        scenario_months)``. The two indices are passed through unchanged. The
        three series are 1-D arrays; each ``*_months`` array holds the
        calendar month (1-12) of every entry of the matching series.

    Returns
    -------
    tuple
        ``(lat_idx, lon_idx, corrected)`` where ``corrected`` has the length of
        ``lr_scenario``. Positions whose month is not matched by any of the
        twelve passes keep their initial value of 0.
    """
    cell_lat, cell_lon, coarse_hist, fine_hist, coarse_scen, hist_month_of, scen_month_of = args

    corrected_series = np.zeros(len(coarse_scen))

    # Each calendar month is corrected against its own training climatology.
    for cal_month in range(1, 13):
        hist_sel = hist_month_of == cal_month
        scen_sel = scen_month_of == cal_month

        coarse_hist_m = coarse_hist[hist_sel]
        fine_hist_m = fine_hist[hist_sel]
        coarse_scen_m = coarse_scen[scen_sel]

        sample_count = len(coarse_scen_m)
        if sample_count == 0:
            continue

        # Double argsort gives the 0-based rank of every value; shift to 1-based,
        # convert to mid-point plotting positions and keep them strictly in (0, 1).
        one_based_rank = np.argsort(np.argsort(coarse_scen_m)) + 1
        prob = np.clip((one_based_rank - 0.5) / sample_count, 1e-6, 1 - 1e-6)

        coarse_hist_q = np.quantile(coarse_hist_m, prob)
        coarse_scen_q = np.quantile(coarse_scen_m, prob)
        fine_hist_q = np.quantile(fine_hist_m, prob)

        # Additive delta: shift the high-res historical quantile by the
        # low-res scenario-minus-historical difference at the same probability.
        corrected_series[scen_sel] = fine_hist_q + (coarse_scen_q - coarse_hist_q)

    return cell_lat, cell_lon, corrected_series


def apply_qdm_single_point_multiplicative(args):
    """Apply QDM to a single grid point - MULTIPLICATIVE (for precipitation).

    Values below a trace threshold are replaced by small random values before
    the quantiles are computed, and corrected values that end up below the
    threshold are set back to zero afterwards.

    The random replacement uses a generator seeded from ``(lat_idx, lon_idx)``,
    so the result for a grid point is reproducible from run to run and does not
    depend on which worker process handles it. (Drawing from the process-global
    ``np.random`` state is unseeded, and forked pool workers can start from
    identical copies of the parent's state.)

    Parameters
    ----------
    args : tuple
        ``(lat_idx, lon_idx, lr_train, hr_train, lr_scenario, train_months,
        scenario_months)``. The indices must be non-negative integers; they
        seed the generator and are returned unchanged. The three series are
        1-D array-likes; each ``*_months`` array holds the calendar month
        (1-12) of every entry of the matching series. Inputs are not modified.

    Returns
    -------
    tuple
        ``(lat_idx, lon_idx, corrected)`` with ``corrected`` a float array of
        the same length as ``lr_scenario``.
    """
    lat_idx, lon_idx, lr_train, hr_train, lr_scenario, train_months, scenario_months = args

    # Trace threshold for precipitation.
    # NOTE(review): 0.05 presumes units in which this is a negligible amount
    # (e.g. mm/day). If the data are stored in kg m-2 s-1, nearly every value
    # would fall below it - confirm against the dataset's units.
    trace = 0.05
    epsilon = 1e-10

    # Per-grid-point generator: deterministic and distinct for every cell.
    rng = np.random.default_rng([int(lat_idx), int(lon_idx)])

    def _jitter_below_trace(series):
        """Return a float copy of `series` with sub-trace values randomised."""
        # Float copy: leaves the caller's array untouched and keeps the jitter
        # from being truncated to 0 if the input has an integer dtype.
        adjusted = np.array(series, dtype=float)
        below = adjusted < trace
        adjusted[below] = rng.uniform(epsilon, trace, size=int(below.sum()))
        return adjusted

    # Draw order (train LR, train HR, scenario LR) is fixed for reproducibility.
    lr_train_adj = _jitter_below_trace(lr_train)
    hr_train_adj = _jitter_below_trace(hr_train)
    lr_scenario_adj = _jitter_below_trace(lr_scenario)

    result = np.zeros(len(lr_scenario_adj))

    for month in range(1, 13):
        train_mask = train_months == month
        scenario_mask = scenario_months == month

        lr_month_train = lr_train_adj[train_mask]
        hr_month_train = hr_train_adj[train_mask]
        lr_scenario_month = lr_scenario_adj[scenario_mask]

        n_scenario = len(lr_scenario_month)
        if n_scenario == 0:
            continue

        # 1-based ranks -> mid-point plotting positions, kept strictly in (0, 1)
        ranks = np.argsort(np.argsort(lr_scenario_month)) + 1
        p_values = np.clip((ranks - 0.5) / n_scenario, 1e-6, 1 - 1e-6)

        q_lr_hist = np.quantile(lr_month_train, p_values)
        q_lr_future = np.quantile(lr_scenario_month, p_values)
        q_hr_hist = np.quantile(hr_month_train, p_values)

        # Multiplicative delta (relative change); epsilon guards the division
        delta = q_lr_future / (q_lr_hist + epsilon)
        result[scenario_mask] = q_hr_hist * delta

    # Set values below trace threshold back to zero
    result[result < trace] = 0.0

    return lat_idx, lon_idx, result


def apply_qdm_scenario(lr_train, hr_train, lr_scenario, method='additive', n_cpus=16):
    """Apply QDM to every grid point of a scenario, in parallel.

    The arrays are indexed as ``[time, lat, lon]``; every ``(lat, lon)`` column
    is handed to a single-point QDM function through a process pool.

    Parameters
    ----------
    lr_train, hr_train : xarray.DataArray
        Low-resolution and high-resolution training data. ``lr_train`` must
        have a ``time`` coordinate supporting ``.dt.month``.
    lr_scenario : xarray.DataArray
        Low-resolution scenario data to correct; it must expose ``time``,
        ``lat`` and ``lon`` coordinates, which are reused for the output.
    method : {'additive', 'multiplicative'}
        Which delta formulation to use.
    n_cpus : int
        Number of worker processes in the pool.

    Returns
    -------
    xarray.DataArray
        Corrected values with dims ``('time', 'lat', 'lon')`` and the
        coordinates of ``lr_scenario``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names. (Previously any value
        other than ``'additive'`` silently selected the multiplicative
        variant, so a typo could apply the wrong correction.)
    """
    # Resolve the per-point function first so a bad `method` fails before any
    # data is sliced or any worker is started.
    point_functions = {
        'additive': apply_qdm_single_point_additive,
        'multiplicative': apply_qdm_single_point_multiplicative,
    }
    try:
        qdm_func = point_functions[method]
    except KeyError:
        raise ValueError(
            f"Unknown QDM method {method!r}; expected one of {sorted(point_functions)}"
        ) from None

    lr_train_data = lr_train.values
    hr_train_data = hr_train.values
    lr_scenario_data = lr_scenario.values

    train_months = lr_train.time.dt.month.values
    scenario_months = lr_scenario.time.dt.month.values

    n_time, n_lat, n_lon = lr_scenario_data.shape[:3]

    # One task per grid point: indices plus the three time series at that point.
    tasks = [
        (
            i, j,
            lr_train_data[:, i, j],
            hr_train_data[:, i, j],
            lr_scenario_data[:, i, j],
            train_months,
            scenario_months,
        )
        for i in range(n_lat)
        for j in range(n_lon)
    ]

    with Pool(n_cpus) as pool:
        results = pool.map(qdm_func, tasks)

    # Reassemble the per-point series into a (time, lat, lon) cube.
    output = np.zeros((n_time, n_lat, n_lon))
    for lat_idx, lon_idx, values in results:
        output[:, lat_idx, lon_idx] = values

    result_da = xr.DataArray(
        output,
        coords={
            'time': lr_scenario.time,
            'lat': lr_scenario.lat,
            'lon': lr_scenario.lon
        },
        dims=['time', 'lat', 'lon']
    )

    return result_da

# ----------------------------
# Load Data
# ----------------------------
print("Loading datasets...")
# Open the five experiment files in one pass; the order of the file names
# below determines which dataset is bound to which name on the left.
ds_hist, ds_ssp126, ds_ssp245, ds_ssp585, ds_g6 = [
    xr.open_dataset(data_dir / file_name)
    for file_name in (
        "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc",
        "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc",
        "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc",
        "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc",
        "MPI-ESM1-2-HR-LR_g6sulfur_r1i1p1f1_2020_2099_allvars.nc",
    )
]

# ----------------------------
# Define Scenarios
# ----------------------------
# scenario name -> (dataset, first label, last label) of the evaluation slice
scenarios = dict(
    historical=(ds_hist, '2001', '2014'),
    ssp126=(ds_ssp126, '2015', '2100'),
    ssp245=(ds_ssp245, '2015', '2100'),
    ssp585=(ds_ssp585, '2015', '2100'),
    g6sulfur=(ds_g6, '2020', '2099'),
)

# ----------------------------
# Apply QDM and Save Results
# ----------------------------
# Wall-clock seconds per variable and scenario: timing_results[var][scenario_name]
timing_results = {}

# Create the output directory up front: otherwise a missing folder would only
# surface as a `to_netcdf` failure after the slow QDM step has already run.
output_dir.mkdir(parents=True, exist_ok=True)

for var, method in variables.items():
    print("\n" + "="*60)
    print(f"APPLYING QDM FOR {var.upper()} DOWNSCALING ({method.upper()})")
    print("="*60)

    # Prepare Training Data (same training slice for every scenario)
    lr_train = ds_hist[f'{var}_lr_interp'].sel(time=slice(train_period[0], train_period[1]))
    hr_train = ds_hist[f'{var}_hr'].sel(time=slice(train_period[0], train_period[1]))

    timing_results[var] = {}

    for scenario_name, (ds, start, end) in scenarios.items():
        print(f"\nProcessing {scenario_name} ({start}-{end})...")

        # Get LR input and HR ground truth
        lr_scenario = ds[f'{var}_lr_interp'].sel(time=slice(start, end))
        hr_true = ds[f'{var}_hr'].sel(time=slice(start, end))

        # Time the QDM application
        start_time = time.time()
        hr_qdm = apply_qdm_scenario(lr_train, hr_train, lr_scenario, method=method, n_cpus=n_cpus)
        end_time = time.time()

        elapsed_time = end_time - start_time
        timing_results[var][scenario_name] = elapsed_time

        print(f"  QDM completed in {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")

        # Create output dataset: reference, raw input and QDM prediction side by side
        ds_out = xr.Dataset({
            'groundtruth': hr_true,
            'input': lr_scenario,
            'pred_qdm': hr_qdm,
        })

        # Save
        output_path = output_dir / f"{var}_evaluation_{scenario_name}_qdm.nc"
        ds_out.to_netcdf(output_path)
        print(f"  Saved: {output_path}")

# ----------------------------
# Print Timing Summary
# ----------------------------
# Banner: blank line, rule, title, rule
print("\n" + "="*60, "TIMING SUMMARY", "="*60, sep="\n")

# One table per variable, one row per scenario, followed by the variable total
for var in variables:
    per_scenario = timing_results[var]

    print(f"\n{var.upper()}:")
    print(f"{'Scenario':<15} {'Time (s)':<12} {'Time (min)':<12}")
    print("-"*39)
    for scenario_name, elapsed in per_scenario.items():
        print(f"{scenario_name:<15} {elapsed:<12.2f} {elapsed/60:<12.2f}")
    print("-"*39)

    total_time = sum(per_scenario.values())
    print(f"{'TOTAL':<15} {total_time:<12.2f} {total_time/60:<12.2f}")

# Sum of the per-variable totals across everything that was timed
grand_total = sum([sum(timings.values()) for timings in timing_results.values()])
print(f"\n{'GRAND TOTAL':<15} {grand_total:<12.2f} {grand_total/60:<12.2f}")

print("\nQDM evaluation complete!")

Loading datasets...

APPLYING QDM FOR TAS DOWNSCALING (ADDITIVE)

Processing historical (2001-2014)...
  QDM completed in 25.51 seconds (0.43 minutes)
  Saved: evaluation_results/tas_evaluation_historical_qdm.nc

Processing ssp126 (2015-2100)...
  QDM completed in 31.64 seconds (0.53 minutes)
  Saved: evaluation_results/tas_evaluation_ssp126_qdm.nc

Processing ssp245 (2015-2100)...
  QDM completed in 33.05 seconds (0.55 minutes)
  Saved: evaluation_results/tas_evaluation_ssp245_qdm.nc

Processing ssp585 (2015-2100)...
  QDM completed in 32.41 seconds (0.54 minutes)
  Saved: evaluation_results/tas_evaluation_ssp585_qdm.nc

Processing g6sulfur (2020-2099)...
  QDM completed in 32.48 seconds (0.54 minutes)
  Saved: evaluation_results/tas_evaluation_g6sulfur_qdm.nc

APPLYING QDM FOR PR DOWNSCALING (MULTIPLICATIVE)

Processing historical (2001-2014)...
  QDM completed in 25.95 seconds (0.43 minutes)
  Saved: evaluation_results/pr_evaluation_historical_qdm.nc

Processing ssp126 (2015-2100)..