In [2]:
import sys, os
# Make the directory two levels above the current working directory importable,
# presumably the repository root that contains `src/` (imported below) — confirm.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
root_already_on_path = project_root in sys.path
if not root_already_on_path:
    # Prepend so the project's modules win over any same-named installed packages.
    sys.path.insert(0, project_root)
    print(f"Added project root to sys.path: {project_root}")

In [3]:
# Window lengths, in time steps: how many steps the model sees, and how many it predicts.
INPUT_LEN = 12
FORECAST_LEN = 6

# Zarr store the dataset is read from.
PATH = "gs://weatherbench2/datasets/era5/1959-2023_01_10-wb13-6h-1440x721_with_derived_variables.zarr"

# Input variables, in feature order (the scaler cells below iterate over this list,
# so the order defines the column order of the scaling arrays).
VARIABLES = [
    'total_precipitation_6hr',
    '2m_temperature',
    '2m_dewpoint_temperature',
    'surface_pressure',
    'mean_sea_level_pressure',
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
    '10m_wind_speed',
    'u_component_of_wind',
    'v_component_of_wind',
    'total_column_water_vapour',
    'integrated_vapor_transport',
    'boundary_layer_height',
    'specific_humidity',
    'total_cloud_cover',
    'mean_surface_net_short_wave_radiation_flux',
    'mean_surface_latent_heat_flux',
    'mean_surface_sensible_heat_flux',
    'snow_depth',
    'sea_surface_temperature',
    'volumetric_soil_water_layer_1',
    'mean_vertically_integrated_moisture_divergence',
    'eddy_kinetic_energy',
    'land_sea_mask',
]

# Variable to forecast; note it is also the first entry of VARIABLES.
TARGET_VARIABLE = 'total_precipitation_6hr'

In [10]:
from src.GriddedClimateDataset import GriddedClimateDataset
# Windowed dataset over the configured store and variables.
# NOTE(review): the saved output shows this cell was interrupted with the full
# 1959–2023 range; consider a shorter time_slice while iterating.
dataset_kwargs = dict(
    input_len=INPUT_LEN,
    forecast_len=FORECAST_LEN,
    variables=VARIABLES,
    target_variable=TARGET_VARIABLE,
    time_slice=slice("1959", "2023"),
)
dataset = GriddedClimateDataset(PATH, **dataset_kwargs)

KeyboardInterrupt: 

In [7]:
# Underlying data handle exposed by the dataset object; later cells index it by
# variable name (ds[var]) and read ds.time. Presumably an xarray.Dataset — TODO confirm.
ds = dataset.ds

In [9]:
from sklearn.preprocessing import MinMaxScaler
import numpy as np
from tqdm.auto import tqdm

# Per-variable min/max over the loaded period, used to build an input scaler that
# maps every variable to [-1, 1]. Feature order follows `input_vars`.
input_vars = VARIABLES  # ← your config['variables']
mins, maxs = [], []
for var in tqdm(input_vars, desc="Calculating min/max for scaling"):
    da = ds[var]
    if "level" in da.dims:
        # Collapse level. NOTE(review): assumes the model input also averages over
        # levels; otherwise these bounds won't match the scaled data — confirm.
        da = da.mean(dim="level")
    # Static (latitude, longitude)-only fields are reduced as-is: broadcasting them
    # along time cannot change their min/max, it only multiplies the work by len(time).
    # Both reductions go into one Dataset so a single .compute() evaluates them in one
    # graph, rather than scanning the source data once for min and again for max.
    stats = da.min().to_dataset(name="min").assign(max=da.max()).compute()
    mins.append(stats["min"])
    maxs.append(stats["max"])

x_data_min = np.array([float(m) for m in mins])
x_data_max = np.array([float(m) for m in maxs])
assert np.isfinite(x_data_min).all() and np.isfinite(x_data_max).all(), \
    "non-finite min/max — check for all-NaN variables"

# Populate a MinMaxScaler by hand (no .fit on the full array). sklearn's transform is
# X * scale_ + min_ with min_ = feature_range[0] - data_min_ * scale_; the
# feature_range[0] offset is required for the output to land in [-1, 1].
x_range_lo, x_range_hi = -1.0, 1.0
x_data_range = x_data_max - x_data_min
scaler_x = MinMaxScaler(feature_range=(x_range_lo, x_range_hi))
scaler_x.scale_ = (x_range_hi - x_range_lo) / (x_data_range + 1e-8)  # 1e-8 guards constant fields
scaler_x.min_ = x_range_lo - x_data_min * scaler_x.scale_
scaler_x.data_min_ = x_data_min
scaler_x.data_max_ = x_data_max
scaler_x.data_range_ = x_data_range
scaler_x.n_features_in_ = len(input_vars)


Calculating min/max for scaling:   0%|          | 0/24 [03:33<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Min/max of the forecast target over the loaded period, used to build an output
# scaler that maps the target to [-1, 1].
target_var = TARGET_VARIABLE  # or your config['target_variable']
target = ds[target_var]
if "level" in target.dims:
    target = target.mean(dim="level")
# One .compute() for both reductions so the target is scanned once rather than twice.
target_stats = target.min().to_dataset(name="min").assign(max=target.max()).compute()
target_min = float(target_stats["min"])
target_max = float(target_stats["max"])
assert np.isfinite([target_min, target_max]).all(), "non-finite target min/max"

# sklearn's MinMaxScaler.transform computes X * scale_ + min_ with
# min_ = feature_range[0] - data_min_ * scale_; without the feature_range[0] (-1)
# offset the target would map to [0, 2] instead of [-1, 1].
y_range_lo, y_range_hi = -1.0, 1.0
scaler_y = MinMaxScaler(feature_range=(y_range_lo, y_range_hi))
scaler_y.scale_ = np.array([(y_range_hi - y_range_lo) / (target_max - target_min + 1e-8)])
scaler_y.min_ = y_range_lo - np.array([target_min]) * scaler_y.scale_
scaler_y.data_min_ = np.array([target_min])
scaler_y.data_max_ = np.array([target_max])
scaler_y.data_range_ = np.array([target_max - target_min])
scaler_y.n_features_in_ = 1
