In [1]:
import numpy as np
import xarray as xr
from pathlib import Path
from multiprocessing import Pool

# ----------------------------
# Configuration
# ----------------------------
# Variable-name prefix: used to build the dataset keys
# (f"{var}_lr_interp", f"{var}_hr") and the output file names.
var = "tas"

# (start, end) labels of the time slice used as training data.
train_period = ("1850", "2000")

# Number of worker processes handed to apply_qdm_scenario.
n_cpus = 16

# Input NetCDF files are read from data_dir; results are written to output_dir.
data_dir = Path("data")
output_dir = Path("evaluation_results")

# ----------------------------
# QDM Functions 
# ----------------------------
def apply_qdm_single_point(args):
    """Run the month-by-month QDM correction for one grid point.

    ``args`` is a single tuple (so the function can be used with
    ``Pool.map``) holding, in order:

    ``lat_idx, lon_idx``
        Grid indices; passed through unchanged so the caller can place
        the result.
    ``lr_train, hr_train``
        1-D arrays of the low-/high-resolution training series.
    ``lr_scenario``
        1-D array of the low-resolution series to correct.
    ``train_months, scenario_months``
        1-D arrays of calendar-month numbers (1..12) aligned with the
        training and scenario series respectively.

    Returns ``(lat_idx, lon_idx, corrected)`` where ``corrected`` has the
    same length as ``lr_scenario``. Entries whose month is not in 1..12
    keep their initial value of 0.

    NOTE(review): the scenario value itself is not used in the delta; it
    is replaced by the interpolated empirical quantile of the scenario
    sample at its own plotting position. Confirm this smoothing is
    intended.
    """
    (lat_idx, lon_idx, lr_train, hr_train,
     lr_scenario, train_months, scenario_months) = args

    corrected_series = np.zeros(len(lr_scenario))

    for month in range(1, 13):
        in_train = train_months == month
        in_scenario = scenario_months == month

        lr_hist_vals = lr_train[in_train]
        hr_hist_vals = hr_train[in_train]
        scenario_vals = lr_scenario[in_scenario]

        sample_size = len(scenario_vals)
        if sample_size == 0:
            # Nothing to correct for this calendar month.
            continue

        # Double argsort turns values into 1-based ranks within the month.
        order = np.argsort(scenario_vals)
        ranks = np.argsort(order) + 1

        # Mid-point plotting positions, kept strictly inside (0, 1).
        probs = np.clip((ranks - 0.5) / sample_size, 1e-6, 1 - 1e-6)

        lr_hist_q = np.quantile(lr_hist_vals, probs)
        lr_future_q = np.quantile(scenario_vals, probs)
        hr_hist_q = np.quantile(hr_hist_vals, probs)

        # High-res historical quantile plus the low-res quantile change.
        corrected_series[in_scenario] = hr_hist_q + (lr_future_q - lr_hist_q)

    return lat_idx, lon_idx, corrected_series

def apply_qdm_scenario(lr_train, hr_train, lr_scenario, n_cpus=16):
    """Apply the per-grid-point QDM correction to a whole scenario field.

    Parameters
    ----------
    lr_train, hr_train : DataArray
        Training fields; indexed here as ``[:, lat, lon]`` and expected to
        carry a ``time`` coordinate with a ``.dt.month`` accessor.
    lr_scenario : DataArray
        Field to correct, indexed as (time, lat, lon). Its ``time``,
        ``lat`` and ``lon`` coordinates are reused for the result.
    n_cpus : int
        Size of the ``multiprocessing.Pool`` that processes grid points.

    Returns
    -------
    DataArray
        Corrected field with dims ``('time', 'lat', 'lon')`` and the same
        shape as ``lr_scenario``.
    """
    # Work on plain NumPy arrays so each task ships only 1-D series.
    train_cube = lr_train.values
    target_cube = hr_train.values
    scenario_cube = lr_scenario.values

    months_train = lr_train.time.dt.month.values
    months_scenario = lr_scenario.time.dt.month.values

    dims = scenario_cube.shape
    n_time, n_lat, n_lon = dims[0], dims[1], dims[2]

    # One task per grid cell, in row-major (lat, then lon) order.
    tasks = [
        (
            i, j,
            train_cube[:, i, j],
            target_cube[:, i, j],
            scenario_cube[:, i, j],
            months_train,
            months_scenario,
        )
        for i in range(n_lat)
        for j in range(n_lon)
    ]

    with Pool(n_cpus) as pool:
        point_results = pool.map(apply_qdm_single_point, tasks)

    # Scatter each corrected series back to its grid cell.
    corrected = np.zeros((n_time, n_lat, n_lon))
    for i, j, series in point_results:
        corrected[:, i, j] = series

    return xr.DataArray(
        corrected,
        coords={
            'time': lr_scenario.time,
            'lat': lr_scenario.lat,
            'lon': lr_scenario.lon,
        },
        dims=['time', 'lat', 'lon'],
    )

# ----------------------------
# Load Data
# ----------------------------
print("Loading datasets...")
# Open the historical run and the four future experiments in a fixed order;
# the names on the left line up one-to-one with the file names below.
ds_hist, ds_ssp126, ds_ssp245, ds_ssp585, ds_g6 = (
    xr.open_dataset(data_dir / file_name)
    for file_name in (
        "MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc",
        "MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc",
        "MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc",
        "MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc",
        "MPI-ESM1-2-HR-LR_g6sulfur_r1i1p1f1_2020_2099_allvars.nc",
    )
)

# ----------------------------
# Prepare Training Data
# ----------------------------
# Both training series come from the historical dataset and are cut to the
# same time window, given by the (start, end) pair in train_period.
lr_train = ds_hist[f"{var}_lr_interp"].sel(time=slice(*train_period))
hr_train = ds_hist[f"{var}_hr"].sel(time=slice(*train_period))

# ----------------------------
# Define Scenarios
# ----------------------------
# scenario name -> (dataset, first year, last year) of the period to correct.
# The historical entry starts right after the training window ends.
scenarios = dict(
    historical=(ds_hist, "2001", "2014"),
    ssp126=(ds_ssp126, "2015", "2100"),
    ssp245=(ds_ssp245, "2015", "2100"),
    ssp585=(ds_ssp585, "2015", "2100"),
    g6sulfur=(ds_g6, "2020", "2099"),
)

# ----------------------------
# Apply QDM and Save Results
# ----------------------------
print("\n" + "="*60)
# Banner follows the configured variable instead of a hard-coded "TAS".
print(f"APPLYING QDM FOR {var.upper()} DOWNSCALING")
print("="*60)

# Create the output directory up front: otherwise a missing directory is only
# discovered at the first write, after a full scenario has been computed.
output_dir.mkdir(parents=True, exist_ok=True)

for scenario_name, (ds, start, end) in scenarios.items():
    print(f"\nProcessing {scenario_name} ({start}-{end})...")

    # Low-res input and high-res reference for the scenario's time window.
    lr_scenario = ds[f'{var}_lr_interp'].sel(time=slice(start, end))
    hr_true = ds[f'{var}_hr'].sel(time=slice(start, end))

    # Correct the scenario using the shared training series.
    hr_qdm = apply_qdm_scenario(lr_train, hr_train, lr_scenario, n_cpus=n_cpus)

    # Bundle reference, input and prediction into one file for evaluation.
    ds_out = xr.Dataset({
        'groundtruth': hr_true,
        'input': lr_scenario,
        'pred_qdm': hr_qdm,
    })

    # One NetCDF file per scenario.
    output_path = output_dir / f"{var}_evaluation_{scenario_name}_qdm.nc"
    ds_out.to_netcdf(output_path)
    print(f"Saved: {output_path}")

print("\nQDM evaluation complete!")

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


Loading datasets...

APPLYING QDM FOR TAS DOWNSCALING

Processing historical (2001-2014)...
Saved: evaluation_results/tas_evaluation_historical_qdm.nc

Processing ssp126 (2015-2100)...
Saved: evaluation_results/tas_evaluation_ssp126_qdm.nc

Processing ssp245 (2015-2100)...
Saved: evaluation_results/tas_evaluation_ssp245_qdm.nc

Processing ssp585 (2015-2100)...
Saved: evaluation_results/tas_evaluation_ssp585_qdm.nc

Processing g6sulfur (2020-2099)...
Saved: evaluation_results/tas_evaluation_g6sulfur_qdm.nc

QDM evaluation complete!
