# Data preparation: ERA5 and CMIP6 (historical, G6sulfur) near-surface air temperature

In [2]:
import xarray as xr

# Open datasets
era5 = xr.open_dataset("data/era5/era5_t2m_1940-2014_monthly.nc")
hist = xr.open_dataset("data/cmip6/hist_lr_tas_monthly_1940_2014.nc")
g6sulfur = xr.open_dataset("data/cmip6/g6sulfur_lr_tas_monthly_2015_2100.nc")

# Extract DataArrays
era5_t2m = era5['t2m']
hist_tas = hist['tas']
g6_tas = g6sulfur['tas']

# --- Clean ERA5 ---
# Rename ERA5's axes to the CMIP6 names (time/lat/lon) so the arrays can be combined later.
era5_t2m = (era5_t2m
    .rename({'valid_time': 'time', 'latitude': 'lat', 'longitude': 'lon'})
    .drop_vars(['number', 'expver'], errors='ignore')
)

# --- Clean HIST and G6SULFUR ---
# Both CMIP6 files carry the same auxiliary coordinates; drop them (if present) identically.
CMIP6_AUX_COORDS = ['height', 'variant_label', 'sub_experiment_id']
hist_tas = hist_tas.drop_vars(CMIP6_AUX_COORDS, errors='ignore')
g6_tas = g6_tas.drop_vars(CMIP6_AUX_COORDS, errors='ignore')

# --- Align time: use hist's time for era5 ---
# ERA5's time labels are replaced outright by HIST's (timestamps for the same month can
# differ between products). That is only valid if both series cover the same months in
# the same order, so verify length and (year, month) before relabelling; otherwise the
# later residual would silently compare different months.
if era5_t2m.sizes['time'] != hist_tas.sizes['time']:
    raise ValueError(
        f"ERA5 has {era5_t2m.sizes['time']} time steps but HIST has "
        f"{hist_tas.sizes['time']}; cannot relabel ERA5 time"
    )
months_match = bool(
    (era5_t2m.time.dt.year.values == hist_tas.time.dt.year.values).all()
    and (era5_t2m.time.dt.month.values == hist_tas.time.dt.month.values).all()
)
if not months_match:
    raise ValueError("ERA5 and HIST (year, month) sequences differ; cannot relabel ERA5 time")
era5_t2m = era5_t2m.assign_coords(time=hist_tas.time)

# Verify
print("ERA5:    ", era5_t2m.dims, era5_t2m.shape)
print("HIST:    ", hist_tas.dims, hist_tas.shape)
print("G6:      ", g6_tas.dims, g6_tas.shape)

ERA5:     ('time', 'lat', 'lon') (900, 721, 1440)
HIST:     ('time', 'lat', 'lon') (900, 96, 192)
G6:       ('time', 'lat', 'lon') (1032, 96, 192)


In [5]:
import dask
from dask.diagnostics import ProgressBar

# Enable parallel computation: use dask's threaded scheduler for every compute() in
# this session. Cap the pool at the available cores (and at most 64) instead of
# always requesting 64 threads, which oversubscribes smaller machines.
import os

DASK_NUM_WORKERS = min(64, os.cpu_count() or 1)  # os.cpu_count() may return None
dask.config.set(scheduler='threads', num_workers=DASK_NUM_WORKERS)

# --- Bilinear interpolation (parallel) ---
def _pad_cyclic_lon(data):
    """Add one wrapped longitude column at each end of `data` when lon spans the full circle.

    The last column is prepended at ``lon - 360`` and the first column is appended at
    ``lon + 360``, so interpolation can bridge the 0/360 seam. `data` is returned
    unchanged unless its 'lon' coordinate is strictly ascending with a first step
    that makes it cover 360 degrees (within half a grid step).
    """
    lon = data['lon'].values
    if lon.size < 2 or not (lon[1:] > lon[:-1]).all():
        return data
    spacing = float(lon[1] - lon[0])
    span = float(lon[-1] - lon[0]) + spacing
    if abs(span - 360.0) > 0.5 * spacing:
        return data  # regional or irregular grid: no wrap-around to exploit

    left = data.isel(lon=[-1])
    left = left.assign_coords(lon=left['lon'] - 360.0)
    right = data.isel(lon=[0])
    right = right.assign_coords(lon=right['lon'] + 360.0)
    return xr.concat([left, data, right], dim='lon')


def interpolate_lr_to_hr_grid(lr_data, hr_lat, hr_lon, time_chunk_size=100, periodic_lon=True):
    """Bilinearly interpolate a low-resolution field onto a target lat/lon grid, in parallel.

    Parameters
    ----------
    lr_data : xr.DataArray
        Source field with 'time', 'lat' and 'lon' coordinates.
    hr_lat, hr_lon : array-like or xr.DataArray
        Target latitudes and longitudes.
    time_chunk_size : int
        Time steps per dask chunk; chunks are interpolated by the configured scheduler.
    periodic_lon : bool
        If True and ``lr_data.lon`` covers the full circle (see `_pad_cyclic_lon`),
        target longitudes beyond the last source column are interpolated across the
        0/360 seam instead of being linearly extrapolated.

    Returns
    -------
    xr.DataArray
        The computed (in-memory) field on the (time, hr_lat, hr_lon) grid. Targets
        outside the source latitude range (e.g. near the poles) are still linearly
        extrapolated via ``fill_value='extrapolate'``.
    """
    print(f"  Interpolating {lr_data.shape} to ({len(hr_lat)}, {len(hr_lon)})...")

    if periodic_lon:
        lr_data = _pad_cyclic_lon(lr_data)

    # Chunk along time only, so each chunk holds complete lat/lon slices.
    lr_chunked = lr_data.chunk({'time': time_chunk_size})

    # Interpolate (lazy)
    interpolated = lr_chunked.interp(
        lat=hr_lat,
        lon=hr_lon,
        method='linear',
        kwargs={'fill_value': 'extrapolate'}
    )

    # Compute with progress bar
    with ProgressBar():
        result = interpolated.compute()

    print(f"  Result shape: {result.shape}")
    return result

# Get ERA5 target coordinates
hr_lat = era5_t2m.lat
hr_lon = era5_t2m.lon

# Interpolate
print("\nInterpolating HIST...")
hist_interp = interpolate_lr_to_hr_grid(hist_tas, hr_lat, hr_lon)

print("\nInterpolating G6SULFUR...")
g6_interp = interpolate_lr_to_hr_grid(g6_tas, hr_lat, hr_lon)

# --- Calculate residual ---
print("\nCalculating residual (ERA5 - HIST_interp)...")
# By default, xarray arithmetic aligns operands on coordinate labels with an inner join,
# which would silently drop any time/lat/lon labels that differ between the two arrays.
# 'exact' makes such a mismatch raise instead.
with xr.set_options(arithmetic_join='exact'):
    residual = era5_t2m - hist_interp

# Final summary
print("\n=== Final shapes ===")
print(f"ERA5:        {era5_t2m.shape}")
print(f"HIST_interp: {hist_interp.shape}")
print(f"G6_interp:   {g6_interp.shape}")
print(f"Residual:    {residual.shape}")

# Residual stats
print("\n=== Residual stats ===")
print(f"Mean: {float(residual.mean()):.3f} K")
print(f"Std:  {float(residual.std()):.3f} K")


Interpolating HIST...
  Interpolating (900, 96, 192) to (721, 1440)...
[########################################] | 100% Completed | 10.85 ss
  Result shape: (900, 721, 1440)

Interpolating G6SULFUR...
  Interpolating (1032, 96, 192) to (721, 1440)...
[########################################] | 100% Completed | 12.71 ss
  Result shape: (1032, 721, 1440)

Calculating residual (ERA5 - HIST_interp)...

=== Final shapes ===
ERA5:        (900, 721, 1440)
HIST_interp: (900, 721, 1440)
G6_interp:   (1032, 721, 1440)
Residual:    (900, 721, 1440)

=== Residual stats ===
Mean: -0.059 K
Std:  3.262 K


In [13]:
import numpy as np
import pickle

# --- Define time splits ---
# String slices select whole calendar years, inclusive at both ends.
train_slice = slice('1940', '2000')
val_slice = slice('2001', '2014')

# --- Split data ---
print("Splitting data...")

# Inputs (to be detrended); G6sulfur is used whole as the test set.
hist_train, hist_val = (hist_interp.sel(time=period) for period in (train_slice, val_slice))
g6_test = g6_interp

# Targets (NO detrending)
residual_train, residual_val = (residual.sel(time=period) for period in (train_slice, val_slice))

# ERA5 (for reference)
era5_train, era5_val = (era5_t2m.sel(time=period) for period in (train_slice, val_slice))

# Report the first and last timestamp of each split.
for _prefix, _split in (
    ("Train period: ", hist_train),
    ("Val period:   ", hist_val),
    ("Test period:  ", g6_test),
):
    print(f"{_prefix}{_split.time.values[0]} to {_split.time.values[-1]}")
del _prefix, _split

# --- Detrend function (pixel-wise linear) ---
def detrend_pixelwise(data):
    """Remove an ordinary-least-squares linear trend over time, independently at every grid point.

    The regressor is the time-step index 0..n_time-1 (not the actual dates).

    Parameters
    ----------
    data : xr.DataArray
        Field whose FIRST axis is time (here ('time', 'lat', 'lon')); any number of
        trailing axes is accepted. Values are loaded into memory via ``.values``.

    Returns
    -------
    xr.DataArray
        Same shape, dims, coords, name and attrs as `data`, holding
        ``data - (intercept + slope * t)`` for each grid point.

    Raises
    ------
    ValueError
        If `data` has fewer than two time steps (the slope is undefined).
    """
    values = np.asarray(data.values)
    n_time = values.shape[0]
    if n_time < 2:
        raise ValueError(f"detrend_pixelwise needs at least 2 time steps, got {n_time}")

    t_centered = np.arange(n_time, dtype=float)
    t_centered -= t_centered.mean()
    denominator = t_centered @ t_centered

    flat = values.reshape(n_time, -1)  # (time, every grid point)
    flat_mean = flat.mean(axis=0)

    # OLS slope = sum(tc * (y - y_mean)) / sum(tc**2). Since sum(tc) == 0 the y_mean
    # term vanishes, so a matrix-vector product yields all slopes without materialising
    # the full (time, grid point) product array.
    slopes = (t_centered @ flat) / denominator

    # Fitted line y_mean + slope * tc equals intercept + slope * t.
    detrended = flat - flat_mean - np.outer(t_centered, slopes)

    # copy(data=...) keeps name/attrs/coords of `data`; an unnamed DataArray would be
    # written by to_netcdf under a placeholder variable name.
    return data.copy(data=detrended.reshape(values.shape))

# --- Detrend inputs (separately for each split) ---
# Each split gets its own pixel-wise linear fit; progress is printed before each one.
print("\nDetrending inputs...")

_detrended_splits = []
for _message, _split in (
    ("  Train input...", hist_train),
    ("  Val input...", hist_val),
    ("  Test input (G6)...", g6_test),
):
    print(_message)
    _detrended_splits.append(detrend_pixelwise(_split))
hist_train_detrend, hist_val_detrend, g6_test_detrend = _detrended_splits
del _detrended_splits, _message, _split

# --- Compute normalization stats from TRAINING set only ---
print("\nComputing normalization stats from training set...")

def compute_pixel_stats(data):
    """Per-grid-point statistics over the 'time' dimension.

    Returns a dict with the raw arrays (``.values``) of the temporal mean under
    'mean' and the temporal standard deviation under 'std'.
    """
    pixel_mean = data.mean(dim='time').values
    pixel_std = data.std(dim='time').values
    return {'mean': pixel_mean, 'std': pixel_std}

# Statistics are taken from the training period only.
norm_stats_xhr = {
    stats_key: compute_pixel_stats(stats_source)
    for stats_key, stats_source in (
        ('input_detrend', hist_train_detrend),
        ('residual', residual_train),
        ('era5', era5_train),
    )
}

# Report the range (over grid points) of the per-pixel mean and std for inputs and targets.
for _key, _label in (('input_detrend', 'Input detrend'), ('residual', 'Residual')):
    _stats = norm_stats_xhr[_key]
    print(f"  {_label} - mean range: [{_stats['mean'].min():.2f}, {_stats['mean'].max():.2f}]")
    print(f"  {_label} - std range:  [{_stats['std'].min():.4f}, {_stats['std'].max():.4f}]")
del _key, _label, _stats

# --- Save normalization stats ---
print("\nSaving normalization stats...")
with open('data/norm_stats_xhr.pkl', 'wb') as f:
    pickle.dump(norm_stats_xhr, f)

# --- Save preprocessed data ---
# Written in this order: detrended CMIP6 inputs, interpolated (non-detrended) CMIP6
# inputs, ERA5 reference fields, then residual targets.
print("\nSaving preprocessed data...")
for _array, _path in (
    (hist_train_detrend, 'data/cmip6/hist_train_detrend_xhr.nc'),
    (hist_val_detrend, 'data/cmip6/hist_val_detrend_xhr.nc'),
    (g6_test_detrend, 'data/cmip6/g6_test_detrend_xhr.nc'),
    (hist_train, 'data/cmip6/hist_train_interp_xhr.nc'),
    (hist_val, 'data/cmip6/hist_val_interp_xhr.nc'),
    (g6_test, 'data/cmip6/g6_test_interp_xhr.nc'),
    (era5_train, 'data/era5/era5_train_xhr.nc'),
    (era5_val, 'data/era5/era5_val_xhr.nc'),
    (residual_train, 'data/era5/residual_train_xhr.nc'),
    (residual_val, 'data/era5/residual_val_xhr.nc'),
):
    _array.to_netcdf(_path)

# --- Summary ---
print("\n=== Final shapes ===")
for _name, _array in (
    ("hist_train_detrend", hist_train_detrend),
    ("hist_val_detrend", hist_val_detrend),
    ("g6_test_detrend", g6_test_detrend),
    ("residual_train", residual_train),
    ("residual_val", residual_val),
    ("era5_train", era5_train),
    ("era5_val", era5_val),
):
    # Label plus colon, left-justified to a 20-character column.
    print(f"{_name}:".ljust(20) + str(_array.shape))
del _array, _path, _name

print("\nDone! All files saved.")

Splitting data...
Train period: 1940-01-16T12:00:00.000000000 to 2000-12-16T12:00:00.000000000
Val period:   2001-01-16T12:00:00.000000000 to 2014-12-16T12:00:00.000000000
Test period:  2015-01-16T12:00:00.000000000 to 2100-12-16T12:00:00.000000000

Detrending inputs...
  Train input...
  Val input...
  Test input (G6)...

Computing normalization stats from training set...
  Input detrend - mean range: [-0.00, 0.00]
  Input detrend - std range:  [0.2839, 17.5948]
  Residual - mean range: [-18.58, 13.12]
  Residual - std range:  [0.3473, 7.6617]

Saving normalization stats...

Saving preprocessed data...

=== Final shapes ===
hist_train_detrend: (732, 721, 1440)
hist_val_detrend:   (168, 721, 1440)
g6_test_detrend:    (1032, 721, 1440)
residual_train:     (732, 721, 1440)
residual_val:       (168, 721, 1440)
era5_train:         (732, 721, 1440)
era5_val:           (168, 721, 1440)

Done! All files saved.
