In [21]:
import xarray as xr
import numpy as np

# Define data directory
data_dir = "data"

# Load the original datasets (without log transform)
print("Loading original datasets...")
ds_hist = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc")
ds_ssp126 = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp245 = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc")
ds_ssp585 = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")

def compute_normalization_stats(target_train_np, input_train_np):
    """
    Compute all normalization statistics from training data using numpy arrays
    
    Parameters:
    -----------
    target_train_np : numpy.ndarray - Training data for target (HR), shape (time, lat, lon)
    input_train_np : numpy.ndarray - Training data for input (LR), shape (time, lat, lon)
    
    Returns:
    --------
    stats : dict with structure:
        stats['hr']['global_mean'], stats['hr']['global_std'], etc.
        stats['lr_interp']['global_mean'], stats['lr_interp']['global_std'], etc.
    """
    stats = {
        'hr': {},
        'lr_interp': {}
    }
    
    # HR statistics - all using numpy
    stats['hr']['global_mean'] = float(np.mean(target_train_np))
    stats['hr']['global_std'] = float(np.std(target_train_np))
    stats['hr']['global_min'] = float(np.min(target_train_np))
    stats['hr']['global_max'] = float(np.max(target_train_np))
    stats['hr']['pixel_mean'] = np.mean(target_train_np, axis=0)  # Average over time
    stats['hr']['pixel_std'] = np.std(target_train_np, axis=0)    # Std over time
    stats['hr']['pixel_min'] = np.min(target_train_np, axis=0)    # Min over time
    stats['hr']['pixel_max'] = np.max(target_train_np, axis=0)    # Max over time
    
    # LR_interp statistics - all using numpy
    stats['lr_interp']['global_mean'] = float(np.mean(input_train_np))
    stats['lr_interp']['global_std'] = float(np.std(input_train_np))
    stats['lr_interp']['global_min'] = float(np.min(input_train_np))
    stats['lr_interp']['global_max'] = float(np.max(input_train_np))
    stats['lr_interp']['pixel_mean'] = np.mean(input_train_np, axis=0)  # Average over time
    stats['lr_interp']['pixel_std'] = np.std(input_train_np, axis=0)    # Std over time
    stats['lr_interp']['pixel_min'] = np.min(input_train_np, axis=0)    # Min over time
    stats['lr_interp']['pixel_max'] = np.max(input_train_np, axis=0)    # Max over time
    
    return stats

# Compute normalization statistics for all variables
variables = ['pr_hr', 'tas_hr', 'hurs_hr', 'sfcWind_hr']
train_start = '1850'
train_end = '1980'

print(f"Computing normalization statistics for training period: {train_start}-{train_end}\n")

norm_stats = {}
for var in variables:
    print(f"Processing {var}...")
    
    # Get base variable name (without _hr)
    var_base = var.replace('_hr', '')
    
    # Get corresponding LR variable name
    var_lr = var.replace('_hr', '_lr_interp')
    
    # Extract training data and IMMEDIATELY convert to numpy
    target_train_xr = ds_hist[var].sel(time=slice(train_start, train_end))
    input_train_xr = ds_hist[var_lr].sel(time=slice(train_start, train_end))
    
    # Convert to numpy arrays
    target_train_np = target_train_xr.values
    input_train_np = input_train_xr.values
    
    print(f"  Shape check - HR: {target_train_np.shape}, LR: {input_train_np.shape}")
    
    # Compute statistics using numpy arrays
    stats = compute_normalization_stats(target_train_np, input_train_np)
    
    # Store with base variable name
    norm_stats[var_base] = stats
    
    print(f"  {var_base} HR - Global mean: {stats['hr']['global_mean']:.6f}, std: {stats['hr']['global_std']:.6f}")
    print(f"  {var_base} LR - Global mean: {stats['lr_interp']['global_mean']:.6f}, std: {stats['lr_interp']['global_std']:.6f}")
    print()

print("Normalization statistics computed successfully using numpy arrays!")
print("\nSummary of global statistics:")
for var_base in norm_stats.keys():
    hr_std = norm_stats[var_base]['hr']['global_std']
    lr_std = norm_stats[var_base]['lr_interp']['global_std']
    ratio = hr_std / lr_std if lr_std > 0 else 0
    print(f"{var_base:10s} - HR std: {hr_std:8.4f}, LR std: {lr_std:8.4f}, Ratio: {ratio:.2f}")

Loading original datasets...
Computing normalization statistics for training period: 1850-1980

Processing pr_hr...
  Shape check - HR: (1572, 192, 384), LR: (1572, 192, 384)
  pr HR - Global mean: 2.373345, std: 2.978923
  pr LR - Global mean: 2.328078, std: 2.809266

Processing tas_hr...
  Shape check - HR: (1572, 192, 384), LR: (1572, 192, 384)
  tas HR - Global mean: 4.989362, std: 21.373560
  tas LR - Global mean: 4.677408, std: 21.151808

Processing hurs_hr...
  Shape check - HR: (1572, 192, 384), LR: (1572, 192, 384)
  hurs HR - Global mean: 81.109680, std: 18.955059
  hurs LR - Global mean: 81.973465, std: 17.777142

Processing sfcWind_hr...
  Shape check - HR: (1572, 192, 384), LR: (1572, 192, 384)
  sfcWind HR - Global mean: 6.178439, std: 2.727153
  sfcWind LR - Global mean: 6.312948, std: 2.687389

Normalization statistics computed successfully using numpy arrays!

Summary of global statistics:
pr         - HR std:   2.9789, LR std:   2.8093, Ratio: 1.06
tas        - HR std

In [23]:
for var in norm_stats:
    print(f"\n{var}:")
    for res in norm_stats[var]:
        print(f"  {res}:")
        for stat in norm_stats[var][res]:
            val = norm_stats[var][res][stat]
            if hasattr(val, 'shape'):
                print(f"    {stat}: shape{val.shape}")
            else:
                print(f"    {stat}: {val:.4f}")


pr:
  hr:
    global_mean: 2.3733
    global_std: 2.9789
    global_min: 0.0000
    global_max: 72.3470
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)
  lr_interp:
    global_mean: 2.3281
    global_std: 2.8093
    global_min: 0.0000
    global_max: 48.3795
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)

tas:
  hr:
    global_mean: 4.9894
    global_std: 21.3736
    global_min: -76.4118
    global_max: 44.7484
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)
  lr_interp:
    global_mean: 4.6774
    global_std: 21.1518
    global_min: -75.3701
    global_max: 44.2455
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)

hurs:
  hr:
    global_mean: 81.1097
    global_std: 18.9551
    glo

In [9]:
for var in norm_stats:
    print(f"\n{var}:")
    for res in norm_stats[var]:
        print(f"  {res}:")
        for stat in norm_stats[var][res]:
            val = norm_stats[var][res][stat]
            if hasattr(val, 'shape'):
                print(f"    {stat}: shape{val.shape}")
            else:
                print(f"    {stat}: {val:.4f}")


pr:
  hr:
    global_mean: 2.3733
    global_std: 2.9789
    global_min: 0.0000
    global_max: 72.3470
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)
  lr_interp:
    global_mean: 2.3281
    global_std: 2.8093
    global_min: 0.0000
    global_max: 48.3795
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)

tas:
  hr:
    global_mean: 4.9894
    global_std: 17.4904
    global_min: -76.4118
    global_max: 44.7484
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)
  lr_interp:
    global_mean: 4.6774
    global_std: 21.1518
    global_min: -75.3701
    global_max: 44.2455
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)

hurs:
  hr:
    global_mean: 81.1097
    global_std: 42.0933
    glo

In [24]:
import pickle
import numpy as np

# Save normalization statistics
print("Saving normalization statistics...")

# Check what we're saving
print("\nData types in norm_stats:")
for var in norm_stats:
    for res in ['hr', 'lr_interp']:
        for stat in norm_stats[var][res]:
            val = norm_stats[var][res][stat]
            print(f"  {var}.{res}.{stat}: {type(val)}")
            break  # Just check first one
        break
    break

# Save with pickle (pickle handles numpy arrays fine)
with open('data/norm_stats.pkl', 'wb') as f:
    pickle.dump(norm_stats, f)

print("✓ Saved to data/norm_stats.pkl")

# Verify by loading it back
print("\nVerifying saved file...")
with open('data/norm_stats.pkl', 'rb') as f:
    loaded_stats = pickle.load(f)

# Check that it loaded correctly
print("Loaded successfully!")
print(f"Variables: {list(loaded_stats.keys())}")
print(f"\nExample - pr HR global_mean: {loaded_stats['pr']['hr']['global_mean']:.4f}")
print(f"Example - pr HR pixel_mean shape: {loaded_stats['pr']['hr']['pixel_mean'].shape}")

# Verify numpy arrays are preserved
assert isinstance(loaded_stats['pr']['hr']['pixel_mean'], np.ndarray), "pixel_mean should be numpy array"
assert loaded_stats['pr']['hr']['pixel_mean'].shape == (192, 384), "Shape should be preserved"
print("\n✓ All checks passed!")

Saving normalization statistics...

Data types in norm_stats:
  pr.hr.global_mean: <class 'float'>
✓ Saved to data/norm_stats.pkl

Verifying saved file...
Loaded successfully!
Variables: ['pr', 'tas', 'hurs', 'sfcWind']

Example - pr HR global_mean: 2.3733
Example - pr HR pixel_mean shape: (192, 384)

✓ All checks passed!


# old draft

In [3]:
import xarray as xr
import numpy as np

# Define data directory
data_dir = "data"

# Apply log transform to precipitation variables
eps = 1e-4

print("Applying log transform to precipitation variables...")

file_info = [
    ('historical', f"{data_dir}/MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars.nc"),
    ('ssp126', f"{data_dir}/MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars.nc"),
    ('ssp245', f"{data_dir}/MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars.nc"),
    ('ssp585', f"{data_dir}/MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars.nc")
]

for ds_name, input_filepath in file_info:
    print(f"  Processing {ds_name}...")
    
    # Open dataset
    ds = xr.open_dataset(input_filepath)
    
    # Transform pr_hr
    ds['pr_hr'] = np.log(ds['pr_hr'] + eps) - np.log(eps)
    ds['pr_hr'].attrs['log_transform_formula'] = f'log(x + {eps}) - log({eps})'
    
    # Transform pr_lr_interp
    ds['pr_lr_interp'] = np.log(ds['pr_lr_interp'] + eps) - np.log(eps)
    ds['pr_lr_interp'].attrs['log_transform_formula'] = f'log(x + {eps}) - log({eps})'
    
    # Create output filepath
    output_filepath = input_filepath.replace('.nc', '_logpr.nc')
    
    # Save to new file
    print(f"    Saving to {output_filepath}...")
    ds.to_netcdf(output_filepath)
    ds.close()
    print(f"    ✓ Saved")

print("\nLog transform complete! New files saved with '_logpr' suffix.")
print(f"To inverse transform: x_original = np.exp(x_transformed + np.log({eps})) - {eps}")

# # Reload the new datasets with log-transformed precipitation
ds_hist = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_historical_r1i1p1f1_1850_2014_allvars_logpr.nc")
ds_ssp126 = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_ssp126_r1i1p1f1_2015_2100_allvars_logpr.nc")
ds_ssp245 = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_ssp245_r1i1p1f1_2015_2100_allvars_logpr.nc")
ds_ssp585 = xr.open_dataset(f"{data_dir}/MPI-ESM1-2-HR-LR_ssp585_r1i1p1f1_2015_2100_allvars_logpr.nc")

In [43]:
ds_hist

In [4]:
def compute_normalization_stats(target_train, input_train):
    """
    Compute all normalization statistics from training data
    
    Parameters:
    -----------
    target_train : xarray.DataArray - Training data for target (HR)
    input_train : xarray.DataArray - Training data for input (LR)
    
    Returns:
    --------
    stats : dict with structure:
        stats['hr']['global_mean'], stats['hr']['global_std'], etc.
        stats['lr_interp']['global_mean'], stats['lr_interp']['global_std'], etc.
    """
    stats = {
        'hr': {},
        'lr_interp': {}
    }
    
    # HR statistics
    stats['hr']['global_mean'] = float(target_train.mean().values)
    stats['hr']['global_std'] = float(target_train.std().values)
    stats['hr']['global_min'] = float(target_train.min().values)
    stats['hr']['global_max'] = float(target_train.max().values)
    stats['hr']['pixel_mean'] = target_train.mean(dim='time')
    stats['hr']['pixel_std'] = target_train.std(dim='time')
    stats['hr']['pixel_min'] = target_train.min(dim='time')
    stats['hr']['pixel_max'] = target_train.max(dim='time')
    
    # LR_interp statistics
    stats['lr_interp']['global_mean'] = float(input_train.mean().values)
    stats['lr_interp']['global_std'] = float(input_train.std().values)
    stats['lr_interp']['global_min'] = float(input_train.min().values)
    stats['lr_interp']['global_max'] = float(input_train.max().values)
    stats['lr_interp']['pixel_mean'] = input_train.mean(dim='time')
    stats['lr_interp']['pixel_std'] = input_train.std(dim='time')
    stats['lr_interp']['pixel_min'] = input_train.min(dim='time')
    stats['lr_interp']['pixel_max'] = input_train.max(dim='time')
    
    return stats


# Compute normalization statistics for all variables
variables = ['pr_hr', 'tas_hr', 'hurs_hr', 'sfcWind_hr']
train_start = '1850'
train_end = '1980'

print(f"Computing normalization statistics for training period: {train_start}-{train_end}\n")

norm_stats = {}

for var in variables:
    print(f"Processing {var}...")
    
    # Get base variable name (without _hr)
    var_base = var.replace('_hr', '')
    
    # Get corresponding LR variable name
    var_lr = var.replace('_hr', '_lr_interp')
    
    # Extract training data
    target_train = ds_hist[var].sel(time=slice(train_start, train_end))
    input_train = ds_hist[var_lr].sel(time=slice(train_start, train_end))
    
    # Compute statistics
    stats = compute_normalization_stats(target_train, input_train)
    
    # Store with base variable name
    norm_stats[var_base] = stats
    

Computing normalization statistics for training period: 1850-1980

Processing pr_hr...
Processing tas_hr...
Processing hurs_hr...
Processing sfcWind_hr...


In [5]:
for var in norm_stats:
    print(f"\n{var}:")
    for res in norm_stats[var]:
        print(f"  {res}:")
        for stat in norm_stats[var][res]:
            val = norm_stats[var][res][stat]
            if hasattr(val, 'shape'):
                print(f"    {stat}: shape{val.shape}")
            else:
                print(f"    {stat}: {val:.4f}")


pr:
  hr:
    global_mean: 9.0621
    global_std: 2.0507
    global_min: 0.0000
    global_max: 13.4918
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)
  lr_interp:
    global_mean: 9.1318
    global_std: 1.9459
    global_min: 0.0000
    global_max: 13.0894
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)

tas:
  hr:
    global_mean: 4.9894
    global_std: 17.4904
    global_min: -76.4118
    global_max: 44.7484
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)
  lr_interp:
    global_mean: 4.6774
    global_std: 21.1518
    global_min: -75.3701
    global_max: 44.2455
    pixel_mean: shape(192, 384)
    pixel_std: shape(192, 384)
    pixel_min: shape(192, 384)
    pixel_max: shape(192, 384)

hurs:
  hr:
    global_mean: 81.1097
    global_std: 42.0933
    glo

In [6]:
import pickle

# Save
with open('data/norm_stats.pkl', 'wb') as f:
    pickle.dump(norm_stats, f)


In [7]:
with open('data/norm_stats.pkl', 'rb') as f:
    norm_stats = pickle.load(f)