# 🌎 Welcome to the CSE151B Spring 2025 Climate Emulation Competition!

Thank you for participating in this exciting challenge focused on building machine learning models to emulate complex climate systems.  
This notebook is provided as a **starter template** to help you:

- Understand how to load and preprocess the dataset  
- Construct a baseline model  
- Train and evaluate predictions using a PyTorch Lightning pipeline  
- Format your predictions for submission to the leaderboard  

You're encouraged to:
- Build on this structure or replace it entirely
- Try more advanced models and training strategies
- Incorporate your own ideas to push the boundaries of what's possible

If you're interested in developing within a repository structure and/or use helpful tools like configuration management (based on Hydra) and logging (with Weights & Biases), we recommend checking out the following Github repo. Such a structure can be useful when running multiple experiments and trying various research ideas.

👉 [https://github.com/salvaRC/cse151b-spring2025-competition](https://github.com/salvaRC/cse151b-spring2025-competition)

Good luck, have fun, and we hope you learn a lot through this process!


### 📦 Install Required Libraries
We install the necessary Python packages for data loading, deep learning, and visualization.


In [1]:
!pip install xarray zarr dask lightning matplotlib wandb cftime einops --quiet

import os
from datetime import datetime
import numpy as np
import xarray as xr
import dask.array as da
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import lightning.pytorch as pl


### ⚙️ Configuration Setup  
Define all model, data, and training hyperparameters in one place for easy control and reproducibility.

### 📊 Data Configuration

We define the dataset settings used for training and evaluation. This includes:

- **`path`**: Path to the `.zarr` dataset containing monthly climate variables from CMIP6 simulations.
- **`input_vars`**: Climate forcing variables (e.g., CO₂, CH₄) used as model inputs.
- **`output_vars`**: Target variables to predict — surface air temperature (`tas`) and precipitation (`pr`).
- **`target_member_id`**: Ensemble member to use from the simulations (each SSP has 3) for target variables.
- **`train_ssps`**: SSP scenarios used for training (low to high emissions).
- **`test_ssp`**: Scenario held out for evaluation (Must be set to SSP245).
- **`test_months`**: Number of months to include in the test split (Must be set to 120).
- **`batch_size`** and **`num_workers`**: Data loading parameters for PyTorch training.

These settings reflect how the challenge is structured: models must learn from some emission scenarios and generalize to unseen ones.

> ⚠️ **Important:** Do **not modify** the following test settings:
>
> - `test_ssp` must remain **`ssp245`**, which is the held-out evaluation scenario.
> - `test_months` must be **`120`**, corresponding to the last 10 years (monthly resolution) of the scenario.



In [2]:
%pwd

'/home/etflores/teams/kaggle-group-30/CSE151B_Milestone'

In [3]:
#NOTE Change the data directory according to where you have your zarr files stored
config = {
    "data": {
        "path": "../processed_data_cse151b_v2_corrupted_ssp245/processed_data_cse151b_v2_corrupted_ssp245.zarr",
        # "path": "/kaggle/input/cse151b-spring2025-competition/processed_data_cse151b_v2_corrupted_ssp245/processed_data_cse151b_v2_corrupted_ssp245.zarr",
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "target_member_id": 0,
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp": "ssp245",
        "test_months": 360,
        "batch_size": 64,
        "num_workers": 4,
    },
    "model": {
        "type": "unet_cnn",
        # "kernel_size": 3,
        "init_dim": 64,
        # "depth": 4,
        "dropout_rate": 0.1,
    },
    "training": {
        "lr": 1e-3,
        "weight_decay": 1e-4,    # ← add this line
    },
    "trainer": {
        "max_epochs": 100,
        "accelerator": "auto",
        "devices": "auto",
        "precision": 32,
        "deterministic": True,
        "num_sanity_val_steps": 0,
    },
    "seed": 42,
}
pl.seed_everything(config["seed"])  # Set seed for reproducibility

Seed set to 42


42

### 📊 Data Configuration

In [4]:
data_path = "../processed_data_cse151b_v2_corrupted_ssp245/processed_data_cse151b_v2_corrupted_ssp245.zarr"

ds = xr.open_zarr(data_path, consolidated=True)
ds

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 27.00 kiB 27.00 kiB Shape (72, 48) (72, 48) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",48  72,

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 27.00 kiB 27.00 kiB Shape (72, 48) (72, 48) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",48  72,

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,107.68 MiB,648.00 kiB
Shape,"(4, 1021, 48, 72)","(1, 24, 48, 72)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 107.68 MiB 648.00 kiB Shape (4, 1021, 48, 72) (1, 24, 48, 72) Dask graph 172 chunks in 2 graph layers Data type float64 numpy.ndarray",4  1  72  48  1021,

Unnamed: 0,Array,Chunk
Bytes,107.68 MiB,648.00 kiB
Shape,"(4, 1021, 48, 72)","(1, 24, 48, 72)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,31.91 kiB,192 B
Shape,"(4, 1021)","(1, 24)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 31.91 kiB 192 B Shape (4, 1021) (1, 24) Dask graph 172 chunks in 2 graph layers Data type float64 numpy.ndarray",1021  4,

Unnamed: 0,Array,Chunk
Bytes,31.91 kiB,192 B
Shape,"(4, 1021)","(1, 24)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,31.91 kiB,192 B
Shape,"(4, 1021)","(1, 24)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 31.91 kiB 192 B Shape (4, 1021) (1, 24) Dask graph 172 chunks in 2 graph layers Data type float64 numpy.ndarray",1021  4,

Unnamed: 0,Array,Chunk
Bytes,31.91 kiB,192 B
Shape,"(4, 1021)","(1, 24)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,107.68 MiB,648.00 kiB
Shape,"(4, 1021, 48, 72)","(1, 24, 48, 72)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 107.68 MiB 648.00 kiB Shape (4, 1021, 48, 72) (1, 24, 48, 72) Dask graph 172 chunks in 2 graph layers Data type float64 numpy.ndarray",4  1  72  48  1021,

Unnamed: 0,Array,Chunk
Bytes,107.68 MiB,648.00 kiB
Shape,"(4, 1021, 48, 72)","(1, 24, 48, 72)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,161.53 MiB,324.00 kiB
Shape,"(4, 1021, 3, 48, 72)","(1, 24, 1, 48, 72)"
Dask graph,516 chunks in 2 graph layers,516 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 161.53 MiB 324.00 kiB Shape (4, 1021, 3, 48, 72) (1, 24, 1, 48, 72) Dask graph 516 chunks in 2 graph layers Data type float32 numpy.ndarray",1021  4  72  48  3,

Unnamed: 0,Array,Chunk
Bytes,161.53 MiB,324.00 kiB
Shape,"(4, 1021, 3, 48, 72)","(1, 24, 1, 48, 72)"
Dask graph,516 chunks in 2 graph layers,516 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,53.84 MiB,324.00 kiB
Shape,"(4, 1021, 48, 72)","(1, 24, 48, 72)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 53.84 MiB 324.00 kiB Shape (4, 1021, 48, 72) (1, 24, 48, 72) Dask graph 172 chunks in 2 graph layers Data type float32 numpy.ndarray",4  1  72  48  1021,

Unnamed: 0,Array,Chunk
Bytes,53.84 MiB,324.00 kiB
Shape,"(4, 1021, 48, 72)","(1, 24, 48, 72)"
Dask graph,172 chunks in 2 graph layers,172 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,161.53 MiB,324.00 kiB
Shape,"(4, 1021, 3, 48, 72)","(1, 24, 1, 48, 72)"
Dask graph,516 chunks in 2 graph layers,516 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 161.53 MiB 324.00 kiB Shape (4, 1021, 3, 48, 72) (1, 24, 1, 48, 72) Dask graph 516 chunks in 2 graph layers Data type float32 numpy.ndarray",1021  4  72  48  3,

Unnamed: 0,Array,Chunk
Bytes,161.53 MiB,324.00 kiB
Shape,"(4, 1021, 3, 48, 72)","(1, 24, 1, 48, 72)"
Dask graph,516 chunks in 2 graph layers,516 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


### Data Exploration

**Spatial Dimensions, Size of Train/Validate/Test Data**

We know there are 1021 time steps per SSP (1021 months).

In [5]:
ssp_train = ['ssp126', 'ssp585']
ssp_val = 'ssp370'
ssp_test = 'ssp245'

n_members = ds.sizes['member_id']
n_time_total = ds.sizes['time']
# print(n_time_total)
lat = ds.sizes['latitude']
lon = ds.sizes['longitude']

# TRAINING = full ssp126 + ssp585 + first 901 months of ssp370
# (since last 10 years / 120 months of SSP370 is the validation set)
n_train_ssp126 = n_time_total * n_members
n_train_ssp585 = n_time_total * n_members
n_train_ssp370 = (n_time_total - 120) * n_members  # time 0 to 900 inclusive
n_train_samples = n_train_ssp126 + n_train_ssp585 + n_train_ssp370

# VALIDATION = last 120 months of ssp370, member_id = 0
n_val_samples = 120

# TEST = last 360 months of ssp245, member_id = 0
n_test_samples = 360

print(f"Spatial Dimensions per sample: (lat, lon) =, {(lat, lon)} --> 3456 grid points")
print("Training samples:", n_train_samples)
print("Validation samples:", n_val_samples)
print("Test samples:", n_test_samples)
n_train_samples + n_val_samples + n_test_samples

Spatial Dimensions per sample: (lat, lon) =, (48, 72) --> 3456 grid points
Training samples: 8829
Validation samples: 120
Test samples: 360


9309

In [6]:
lat_dim = ds.sizes["latitude"]
lon_dim = ds.sizes["longitude"]
n_spatial = lat_dim * lon_dim
print(f"Spatial dimensions: {lat_dim} x {lon_dim} = {n_spatial} grid points")

Spatial dimensions: 48 x 72 = 3456 grid points


**1b)** The dataset consists of 8,829 training samples drawn from ssp126, ssp370 (first 901 months), and ssp585 across all 3 ensemble members. Validation includes 120 samples from the last 10 years (120 months) of ssp370 using only member_id = 0. The test set contains 360 samples from the last 360 months of ssp245, also using only member_id = 0. We only use member_id = 0 because the competition evaluation is performed on a single ensemble realization (member_id = 0), and using multiple ensemble members for validation or testing would misrepresent model performance (data leakage from internal climate variability across ensemble runs).

There are 2 spatial dimensions (latitude and longitude) forming a 48 × 72 global grid. Each sample covers the full grid, so each sample contains 3,456 spatial points.

**Distribution of Target Variables, Input Data**

In [7]:
ds["tas"]
ds["tas"].sel(ssp="ssp245", member_id=0)


Unnamed: 0,Array,Chunk
Bytes,13.46 MiB,324.00 kiB
Shape,"(1021, 48, 72)","(24, 48, 72)"
Dask graph,43 chunks in 3 graph layers,43 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 13.46 MiB 324.00 kiB Shape (1021, 48, 72) (24, 48, 72) Dask graph 43 chunks in 3 graph layers Data type float32 numpy.ndarray",72  48  1021,

Unnamed: 0,Array,Chunk
Bytes,13.46 MiB,324.00 kiB
Shape,"(1021, 48, 72)","(24, 48, 72)"
Dask graph,43 chunks in 3 graph layers,43 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 27.00 kiB 27.00 kiB Shape (72, 48) (72, 48) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",48  72,

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 27.00 kiB 27.00 kiB Shape (72, 48) (72, 48) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",48  72,

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [8]:
ds["tas"].sel(ssp="ssp245", member_id=0).isel(time=0)


Unnamed: 0,Array,Chunk
Bytes,13.50 kiB,13.50 kiB
Shape,"(48, 72)","(48, 72)"
Dask graph,1 chunks in 4 graph layers,1 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 13.50 kiB 13.50 kiB Shape (48, 72) (48, 72) Dask graph 1 chunks in 4 graph layers Data type float32 numpy.ndarray",72  48,

Unnamed: 0,Array,Chunk
Bytes,13.50 kiB,13.50 kiB
Shape,"(48, 72)","(48, 72)"
Dask graph,1 chunks in 4 graph layers,1 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 27.00 kiB 27.00 kiB Shape (72, 48) (72, 48) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",48  72,

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 27.00 kiB 27.00 kiB Shape (72, 48) (72, 48) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",48  72,

Unnamed: 0,Array,Chunk
Bytes,27.00 kiB,27.00 kiB
Shape,"(72, 48)","(72, 48)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [9]:
ssp_train = ["ssp126", "ssp370", "ssp585"]
input_vars = ["CO2", "CH4", "SO2", "BC", "rsdt"]
output_vars = ["tas", "pr"]

def clean_flat(arr):
    return arr.compute().values.flatten()[~np.isnan(arr.compute().values.flatten())]

def print_stats(values):
    zero_count = (values == 0).sum()
    print(
        f"count = {len(values):>6} | "
        f"mean = {np.mean(values)} | "
        f"std = {np.std(values)} | "
        f"median = {np.median(values)} | "
        f"min = {np.min(values)} | "
        f"max = {np.max(values)} | "
        f"5th–95th pct = {np.percentile(values, 5)} – {np.percentile(values, 95)} | "
        f"zeros = {zero_count} ({100 * zero_count / len(values)}%)"
    )

# input data
for var in input_vars:
    print(f"\n--------- INPUT VARIABLE: {var} ---------")
    for ssp in ssp_train:
        print(f"SSP: {ssp}")
        arr = ds[var].sel(ssp=ssp)
        if "member_id" in arr.dims:
            arr = arr.sel(member_id=0)
        values = clean_flat(arr)
        
        print_stats(values)

        plt.hist(values, bins=100)
        plt.title(f"{var} Distribution – {ssp}")
        plt.xlabel(var)
        plt.ylabel("Frequency")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f"figures/{var}_{ssp}_dist.png", dpi=300)
        plt.close()
        # plt.show()

# target vars
for var in output_vars:
    print(f"\n--------- TARGET VARIABLE: {var} ---------")
    for ssp in ssp_train:
        print(f"SSP: {ssp}")
        arr = ds[var].sel(ssp=ssp, member_id=0)
        values = clean_flat(arr)

        print_stats(values)

        plt.hist(values, bins=100)
        plt.title(f"{var} Distribution – {ssp}")
        plt.xlabel(var)
        plt.ylabel("Frequency")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f"figures/{var}_{ssp}_dist.png", dpi=300)
        plt.close()
        # plt.show()


--------- INPUT VARIABLE: CO2 ---------
SSP: ssp126
count =   1021 | mean = 2532.318713977056 | std = 401.69992064797714 | median = 2731.6626772312566 | min = 1536.0722224547292 | max = 2891.7034108052826 | 5th–95th pct = 1689.6271652350188 – 2889.948210871908 | zeros = 0 (0.0%)
SSP: ssp370
count =   1021 | mean = 3874.4328704822187 | std = 1528.9858993886194 | median = 3751.897249353455 | min = 1536.0722224547292 | max = 6763.858968521363 | 5th–95th pct = 1699.3051121573612 – 6426.770585143383 | zeros = 0 (0.0%)
SSP: ssp585
count =   1021 | mean = 4653.006907184912 | std = 2343.882160473454 | median = 4182.662019889618 | min = 1536.0722224547292 | max = 9362.59353129993 | 5th–95th pct = 1696.662145353775 – 8816.81862612959 | zeros = 0 (0.0%)

--------- INPUT VARIABLE: CH4 ---------
SSP: ssp126
count =   1021 | mean = 0.1990533398174867 | std = 0.07149415080379436 | median = 0.18533224375025475 | min = 0.11364517919757215 | max = 0.37373672410518455 | 5th–95th pct = 0.1164790331243431

In [10]:
def plot_distribution_across_ssps(ds, var, time_idx, ssp_list=["ssp126", "ssp370", "ssp585"], member_id=0, bins=100):
    os.makedirs("figures", exist_ok=True)
    plt.figure(figsize=(8, 8))

    for ssp in ssp_list:
        arr = ds[var].sel(ssp=ssp, member_id=member_id).isel(time=time_idx)
        values = arr.compute().values.flatten()
        values = values[~np.isnan(values)]

        plt.hist(values, bins=bins, alpha=0.3, label=ssp)

    try:
        time_str = str(ds["time"].isel(time=time_idx).values)[:10].replace("-", "")
    except:
        time_str = f"t{time_idx}"

    plt.title(f"Distribution of {var} at {time_str} Across SSPs")
    plt.xlabel(f"{var} value")
    plt.ylabel("Frequency")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    # Save instead of showing
    filename = f"figures/{var}_ssp_comparison_{time_str}.png"
    plt.savefig(filename, dpi=300)
    plt.close()
    print(f"Saved: {filename}")
    # plt.show()


In [11]:
plot_distribution_across_ssps(ds, var="tas", time_idx=0)
plot_distribution_across_ssps(ds, var="pr", time_idx=0)
plot_distribution_across_ssps(ds, var="tas", time_idx=1020)
plot_distribution_across_ssps(ds, var="pr", time_idx=1020)

Saved: figures/tas_ssp_comparison_20150115.png
Saved: figures/pr_ssp_comparison_20150115.png
Saved: figures/tas_ssp_comparison_21000115.png
Saved: figures/pr_ssp_comparison_21000115.png


In [12]:
def plot_distribution_over_time(ds, ssp, var, time_indices, member_id=0, bins=100):
    os.makedirs("figures", exist_ok=True)
    plt.figure(figsize=(8, 5))

    label_parts = []

    for time_idx in time_indices:
        arr = ds[var].sel(ssp=ssp, member_id=member_id).isel(time=time_idx)
        values = arr.compute().values.flatten()
        values = values[~np.isnan(values)]

        try:
            time_label = str(ds["time"].isel(time=time_idx).values)[:10]
            label_parts.append(time_label.replace("-", ""))
        except:
            time_label = f"t{time_idx}"
            label_parts.append(time_label)

        plt.hist(values, bins=bins, alpha=0.5, label=time_label)

    label_str = "_".join(label_parts)
    plt.title(f"{var.upper()} Distribution Over Time – {ssp}")
    plt.xlabel(f"{var} value")
    plt.ylabel("Frequency")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    filename = f"figures/{var}_{ssp}_overtime_{label_str}.png"
    plt.savefig(filename, dpi=300)
    plt.close()
    print(f"Saved: {filename}")
    # plt.show()


In [13]:
plot_distribution_over_time(ds, "ssp126", "tas", [0, 1020])
plot_distribution_over_time(ds, "ssp126", "pr", [0, 1020])
plot_distribution_over_time(ds, "ssp585", "tas", [0, 1020])
plot_distribution_over_time(ds, "ssp585", "pr", [0, 1020])

Saved: figures/tas_ssp126_overtime_20150115_21000115.png
Saved: figures/pr_ssp126_overtime_20150115_21000115.png
Saved: figures/tas_ssp585_overtime_20150115_21000115.png
Saved: figures/pr_ssp585_overtime_20150115_21000115.png


### 🔧 Spatial Weighting Utility Function

This cell sets up utility functions for reproducibility and spatial weighting:

- **`get_lat_weights(latitude_values)`**: Computes cosine-based area weights for each latitude, accounting for the Earth's curvature. This is critical for evaluating global climate metrics fairly — grid cells near the equator represent larger surface areas than those near the poles.


In [14]:
def get_lat_weights(latitude_values):
    lat_rad = np.deg2rad(latitude_values)
    weights = np.cos(lat_rad)
    return weights / np.mean(weights)

### 🧠 SimpleCNN: A Residual Convolutional Baseline

This is a lightweight baseline model designed to capture spatial patterns in global climate data using convolutional layers.

- The architecture starts with a **convolution + batch norm + ReLU** block to process the input channels.
- It then applies a series of **residual blocks** to extract increasingly abstract spatial features. These help preserve gradient flow during training.
- Finally, a few convolutional layers reduce the feature maps down to the desired number of output channels (`tas` and `pr`).

This model only serves as a **simple baseline for climate emulation**. 

We encourage you to build and experiment with your own models and ideas.


In [15]:
##### UNet CNN
class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNetCNN(nn.Module):
    def __init__(self, n_input_channels, init_dim=64, dropout_rate=0.2):
        super().__init__()

        # Encoder
        self.enc1 = UNetBlock(n_input_channels, init_dim)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = UNetBlock(init_dim, init_dim * 2)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = UNetBlock(init_dim * 2, init_dim * 4)

        # Decoder
        self.up2 = nn.ConvTranspose2d(init_dim * 4, init_dim * 2, kernel_size=2, stride=2)
        self.dec2 = UNetBlock(init_dim * 4, init_dim * 2)
        self.up1 = nn.ConvTranspose2d(init_dim * 2, init_dim, kernel_size=2, stride=2)
        self.dec1 = UNetBlock(init_dim * 2, init_dim)

        # Dropout
        self.dropout = nn.Dropout2d(dropout_rate)

        # Dual output heads: one for tas, one for pr
        self.tas_head = nn.Conv2d(init_dim, 1, kernel_size=1)
        self.pr_head  = nn.Conv2d(init_dim, 1, kernel_size=1)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        bottleneck = self.bottleneck(self.pool2(enc2))

        # Decoder
        dec2 = self.up2(bottleneck)
        dec2 = self.dec2(torch.cat([dec2, enc2], dim=1))
        dec1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat([dec1, enc1], dim=1))

        dec1 = self.dropout(dec1)

        # Forward through both heads
        tas_out = self.tas_head(dec1)
        pr_out  = self.pr_head(dec1)

        return torch.cat([tas_out, pr_out], dim=1)  # Output: [B, 2, 48, 72]


### 📐 Normalizer: Z-Score Scaling for Climate Inputs & Outputs

This class handles **Z-score normalization**, a crucial preprocessing step for stable and efficient neural network training:

- **`set_input_statistics(mean, std)` / `set_output_statistics(...)`**: Store the mean and standard deviation computed from the training data for later use.
- **`normalize(data, data_type)`**: Standardizes the data using `(x - mean) / std`. This is applied separately to inputs and outputs.
- **`inverse_transform_output(data)`**: Converts model predictions back to the original physical units (e.g., Kelvin for temperature, mm/day for precipitation).

Normalizing the data ensures the model sees inputs with similar dynamic ranges and avoids biases caused by different variable scales.


In [16]:
class Normalizer:
    def __init__(self):
        self.mean_in, self.std_in = None, None
        self.mean_out, self.std_out = None, None

    def set_input_statistics(self, mean, std):
        self.mean_in = mean
        self.std_in = std

    def set_output_statistics(self, mean, std):
        self.mean_out = mean
        self.std_out = std

    def normalize(self, data, data_type):
        if data_type == "input":
            return (data - self.mean_in) / self.std_in
        elif data_type == "output":
            return (data - self.mean_out) / self.std_out

    def inverse_transform_output(self, data):
        return data * self.std_out + self.mean_out


### 🌍 Data Module: Loading, Normalization, and Splitting

This section handles the entire data pipeline, from loading the `.zarr` dataset to preparing PyTorch-ready DataLoaders.

#### `ClimateDataset`
- A simple PyTorch `Dataset` wrapper that preloads the entire (normalized) dataset into memory using Dask.
- Converts the data to PyTorch tensors and handles any `NaN` checks up front.

#### `ClimateDataModule`
A PyTorch Lightning `DataModule` that handles:
- ✅ **Loading data** from different SSP scenarios and ensemble members
- ✅ **Broadcasting non-spatial inputs** (like CO₂) to match spatial grid size
- ✅ **Normalization** using mean/std computed from training data only
- ✅ **Splitting** into training, validation, and test sets:
  - Training: All months from selected SSPs (except last 10 years of SSP370)
  - Validation: Last 10 years (120 months) of SSP370
  - Test: Last 10 years of SSP245 (unseen scenario)
- ✅ **Batching** and parallelized data loading via PyTorch `DataLoader`s
- ✅ **Latitude-based area weighting** for fair climate metric evaluation
- Shape of the inputs are Batch_Size X 5 (num_input_variables) X 48 X 72
- Shape of ouputputs are Batch_Size X 2 (num_output_variables) X 48 X 72

> ℹ️ **Note:** You likely won’t need to modify this class but feel free to make modifications if you want to inlcude different ensemble mebers to feed more data to your models


In [17]:
class ClimateDataset(Dataset):
    def __init__(self, inputs_dask, outputs_dask, output_is_normalized=True):
        self.size = inputs_dask.shape[0]
        print(f"Creating dataset with {self.size} samples...")

        inputs_np = inputs_dask.compute()
        outputs_np = outputs_dask.compute()

        self.inputs = torch.from_numpy(inputs_np).float()
        self.outputs = torch.from_numpy(outputs_np).float()

        if torch.isnan(self.inputs).any() or torch.isnan(self.outputs).any():
            raise ValueError("NaNs found in dataset")

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]


class ClimateDataModule(pl.LightningDataModule):
    def __init__(
        self,
        path,
        input_vars,
        output_vars,
        train_ssps,
        test_ssp,
        target_member_id,
        val_split=0.1,
        test_months=360,
        batch_size=32,
        num_workers=0,
        seed=42,
    ):
        super().__init__()
        self.path = path
        self.input_vars = input_vars
        self.output_vars = output_vars
        self.train_ssps = train_ssps
        self.test_ssp = test_ssp
        self.target_member_id = target_member_id
        self.val_split = val_split
        self.test_months = test_months
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.normalizer = Normalizer()

    def prepare_data(self):
        assert os.path.exists(self.path), f"Data path not found: {self.path}"

    def setup(self, stage=None):
        ds = xr.open_zarr(self.path, consolidated=False, chunks={"time": 24})
        spatial_template = ds["rsdt"].isel(time=0, ssp=0, drop=True)

        def load_ssp(ssp):
            input_dask, output_dask = [], []
            for var in self.input_vars:
                da_var = ds[var].sel(ssp=ssp)
                if "latitude" in da_var.dims:
                    da_var = da_var.rename({"latitude": "y", "longitude": "x"})
                if "member_id" in da_var.dims:
                    da_var = da_var.sel(member_id=self.target_member_id)
                if set(da_var.dims) == {"time"}:
                    da_var = da_var.broadcast_like(spatial_template).transpose("time", "y", "x")
                input_dask.append(da_var.data)

            for var in self.output_vars:
                da_out = ds[var].sel(ssp=ssp, member_id=self.target_member_id)
                if "latitude" in da_out.dims:
                    da_out = da_out.rename({"latitude": "y", "longitude": "x"})
                output_dask.append(da_out.data)

            return da.stack(input_dask, axis=1), da.stack(output_dask, axis=1)

        train_input, train_output, val_input, val_output = [], [], None, None

        for ssp in self.train_ssps:
            x, y = load_ssp(ssp)
            if ssp == "ssp370":
                val_input = x[-self.test_months:]
                val_output = y[-self.test_months:]
                train_input.append(x[:-self.test_months])
                train_output.append(y[:-self.test_months])
            else:
                train_input.append(x)
                train_output.append(y)

        train_input = da.concatenate(train_input, axis=0)
        train_output = da.concatenate(train_output, axis=0)

        self.normalizer.set_input_statistics(
            mean=da.nanmean(train_input, axis=(0, 2, 3), keepdims=True).compute(),
            std=da.nanstd(train_input, axis=(0, 2, 3), keepdims=True).compute(),
        )
        self.normalizer.set_output_statistics(
            mean=da.nanmean(train_output, axis=(0, 2, 3), keepdims=True).compute(),
            std=da.nanstd(train_output, axis=(0, 2, 3), keepdims=True).compute(),
        )

        train_input_norm = self.normalizer.normalize(train_input, "input")
        train_output_norm = self.normalizer.normalize(train_output, "output")
        val_input_norm = self.normalizer.normalize(val_input, "input")
        val_output_norm = self.normalizer.normalize(val_output, "output")

        test_input, test_output = load_ssp(self.test_ssp)
        test_input = test_input[-self.test_months:]
        test_output = test_output[-self.test_months:]
        test_input_norm = self.normalizer.normalize(test_input, "input")

        self.train_dataset = ClimateDataset(train_input_norm, train_output_norm)
        self.val_dataset = ClimateDataset(val_input_norm, val_output_norm)
        self.test_dataset = ClimateDataset(test_input_norm, test_output, output_is_normalized=False)

        self.lat = spatial_template.y.values
        self.lon = spatial_template.x.values
        self.area_weights = xr.DataArray(get_lat_weights(self.lat), dims=["y"], coords={"y": self.lat})

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=True)

    def get_lat_weights(self):
        return self.area_weights

    def get_coords(self):
        return self.lat, self.lon

### ⚡ ClimateEmulationModule: Lightning Wrapper for Climate Model Emulation

This is the core model wrapper built with **PyTorch Lightning**, which organizes the training, validation, and testing logic for the climate emulation task. Lightning abstracts away much of the boilerplate code in PyTorch-based deep learning workflows, making it easier to scale models.

#### ✅ Key Features

- **`training_step` / `validation_step` / `test_step`**: Standard Lightning hooks for computing loss and predictions at each stage. The loss used is **Mean Squared Error (MSE)**.

- **Normalization-aware outputs**:
  - During validation and testing, predictions and targets are denormalized before evaluation using stored mean/std statistics.
  - This ensures evaluation is done in real-world units (Kelvin and mm/day).

- **Metric Evaluation** via `_evaluate()`:
  For each variable (`tas`, `pr`), it calculates:
  - **Monthly Area-Weighted RMSE**
  - **Time-Mean RMSE** (RMSE on 10-year average's)
  - **Time-Stddev MAE** (MAE on 10-year standard deviation; a measure of temporal variability)
    
  These metrics reflect the competition's evaluation criteria and are logged and printed.

- **Kaggle Submission Writer**:
  After testing, predictions are saved to a `.csv` file in the required Kaggle format via `_save_submission()`.

- **Saving Predictions for Visualization**:
  - Validation predictions are saved tao `val_preds.npy` and `val_trues.npy`
  - These can be loaded later for visual inspection of the model's performance.

 🔧 **Feel free to modify any part of this module** (loss functions, evaluation, training logic) to better suit your model or training pipeline / Use pure PyTorch etc.

⚠️ The **final submission `.csv` file must strictly follow the format and naming convention used in `_save_submission()`**, as these `ID`s are used to match predictions to the hidden test set during evaluation.



In [18]:
import pandas as pd

class ClimateEmulationModule(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-4, weight_decay=0.0):
        super().__init__()
        self.model = model
        # Save both lr & weight_decay into self.hparams
        self.save_hyperparameters("learning_rate", "weight_decay", ignore=["model"])
        # self.criterion = nn.MSELoss()
        self.criterion_tas = nn.MSELoss()
        self.criterion_pr  = nn.MSELoss()
        self.normalizer = None
        self.val_preds, self.val_targets = [], []
        self.test_preds, self.test_targets = [], []

    def forward(self, x):
        return self.model(x)

    def on_fit_start(self):
        self.normalizer = self.trainer.datamodule.normalizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # loss = self.criterion(y_hat, y)
        tas_hat, pr_hat = y_hat[:, 0:1], y_hat[:, 1:2]
        tas_true, pr_true = y[:, 0:1], y[:, 1:2]
        loss_tas = self.criterion_tas(tas_hat, tas_true)
        loss_pr  = self.criterion_pr(pr_hat, pr_true)
        loss = loss_tas + loss_pr
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # loss = self.criterion(y_hat, y)
        tas_hat, pr_hat = y_hat[:, 0:1], y_hat[:, 1:2]
        tas_true, pr_true = y[:, 0:1], y[:, 1:2]
        loss_tas = self.criterion_tas(tas_hat, tas_true)
        loss_pr  = self.criterion_pr(pr_hat, pr_true)
        loss = loss_tas + loss_pr
        self.log("val/loss", loss)

        y_hat_np = self.normalizer.inverse_transform_output(y_hat.detach().cpu().numpy())
        y_np     = self.normalizer.inverse_transform_output(y.detach().cpu().numpy())
        self.val_preds.append(y_hat_np)
        self.val_targets.append(y_np)
        return loss

    def on_validation_epoch_end(self):
        preds = np.concatenate(self.val_preds, axis=0)
        trues = np.concatenate(self.val_targets, axis=0)
        self._evaluate(preds, trues, phase="val")
        np.save("val_preds.npy", preds)
        np.save("val_trues.npy", trues)
        self.val_preds.clear()
        self.val_targets.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y_hat_np = self.normalizer.inverse_transform_output(y_hat.detach().cpu().numpy())
        y_np     = y.detach().cpu().numpy()
        self.test_preds.append(y_hat_np)
        self.test_targets.append(y_np)

    def on_test_epoch_end(self):
        preds = np.concatenate(self.test_preds, axis=0)
        trues = np.concatenate(self.test_targets, axis=0)
        self._evaluate(preds, trues, phase="test")
        self._save_submission(preds)
        self.test_preds.clear()
        self.test_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=self.hparams.learning_rate * 0.01,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "monitor": "val/tas/rmse",
            }
        }

    def _evaluate(self, preds, trues, phase="val"):
        datamodule = self.trainer.datamodule
        area_weights = datamodule.get_lat_weights()
        lat, lon = datamodule.get_coords()
        time = np.arange(preds.shape[0])
        output_vars = datamodule.output_vars

        for i, var in enumerate(output_vars):
            p = preds[:, i]
            t = trues[:, i]
            p_xr = xr.DataArray(p, dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
            t_xr = xr.DataArray(t, dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})

            # RMSE
            rmse = np.sqrt(((p_xr - t_xr) ** 2).weighted(area_weights).mean(("time", "y", "x")).item())
            # RMSE of time-mean
            mean_rmse = np.sqrt(((p_xr.mean("time") - t_xr.mean("time")) ** 2).weighted(area_weights).mean(("y", "x")).item())
            # MAE of time-stddev
            std_mae = np.abs(p_xr.std("time") - t_xr.std("time")).weighted(area_weights).mean(("y", "x")).item()

            print(f"[{phase.upper()}] {var}: RMSE={rmse:.4f}, Time-Mean RMSE={mean_rmse:.4f}, Time-Stddev MAE={std_mae:.4f}")
            self.log_dict({
                f"{phase}/{var}/rmse": rmse,
                f"{phase}/{var}/time_mean_rmse": mean_rmse,
                f"{phase}/{var}/time_std_mae": std_mae,
            })

    def _save_submission(self, predictions):
        datamodule = self.trainer.datamodule
        lat, lon = datamodule.get_coords()
        output_vars = datamodule.output_vars
        time = np.arange(predictions.shape[0])

        rows = []
        for t_idx, t in enumerate(time):
            for var_idx, var in enumerate(output_vars):
                for y_idx, y in enumerate(lat):
                    for x_idx, x in enumerate(lon):
                        row_id = f"t{t_idx:03d}_{var}_{y:.2f}_{x:.2f}"
                        pred = predictions[t_idx, var_idx, y_idx, x_idx]
                        rows.append({"ID": row_id, "Prediction": pred})

        df = pd.DataFrame(rows)
        os.makedirs("submissions", exist_ok=True)
        filepath = f"submissions/kaggle_submission_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        df.to_csv(filepath, index=False)
        print(f"✅ Submission saved to: {filepath}")

### ⚡ Training & Evaluation with PyTorch Lightning

This block sets up and runs the training and testing pipeline using **PyTorch Lightning’s `Trainer`**, which abstracts away much of the boilerplate in deep learning workflows.

- **Modular Setup**:
  - `datamodule`: Handles loading, normalization, and batching of climate data.
  - `model`: A convolutional neural network that maps climate forcings to predicted outputs.
  - `lightning_module`: Wraps the model with training/validation/test logic and metric evaluation.

- **Trainer Flexibility**:
  The `Trainer` accepts a wide range of configuration options from `config["trainer"]`, including:
  - Number of epochs
  - Precision (e.g., 16-bit or 32-bit)
  - Device configuration (CPU, GPU, or TPU)
  - Determinism, logging, callbacks, and more

In [19]:
datamodule = ClimateDataModule(**config["data"])

model = UNetCNN(
    n_input_channels=len(config["data"]["input_vars"]),
    init_dim=config["model"]["init_dim"],
    dropout_rate=config["model"]["dropout_rate"]
)

lightning_module = ClimateEmulationModule(
    model,
    learning_rate=config["training"]["lr"],
    weight_decay=config["training"]["weight_decay"],
)


In [20]:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Define callbacks
early_stop = EarlyStopping(
    monitor="val/tas/rmse",
    patience=20,
    mode="min"
)
checkpoint = ModelCheckpoint(
    monitor="val/tas/rmse",
    mode="min",
    save_top_k=1,
    filename="unet-best-{epoch:02d}-{val/tas/rmse:.4f}"
)

# Updated Trainer with callbacks
trainer = pl.Trainer(
    **config["trainer"],
    callbacks=[early_stop, checkpoint]
)

trainer.fit(lightning_module, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
2025-05-22 10:44:31.798631: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747910671.815810   15732 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747910671.821531   15732 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-22 10:44:31.839680: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Creating dataset with 2703 samples...
Creating dataset with 360 samples...
Creating dataset with 360 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type    | Params | Mode 
--------------------------------------------------
0 | model         | UNetCNN | 1.9 M  | train
1 | criterion_tas | MSELoss | 0      | train
2 | criterion_pr  | MSELoss | 0      | train
--------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.467     Total estimated model params size (MB)
50        Modules in train mode
0         Modules in eval mode
/home/etflores/.local/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (43) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.9082, Time-Mean RMSE=3.4892, Time-Stddev MAE=1.2025
[VAL] pr: RMSE=2.8018, Time-Mean RMSE=0.8934, Time-Stddev MAE=1.6556


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.2986, Time-Mean RMSE=2.7052, Time-Stddev MAE=1.2526
[VAL] pr: RMSE=2.6824, Time-Mean RMSE=0.7685, Time-Stddev MAE=1.4587


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.7273, Time-Mean RMSE=1.8557, Time-Stddev MAE=1.1723
[VAL] pr: RMSE=2.6067, Time-Mean RMSE=0.5999, Time-Stddev MAE=1.4597


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.6402, Time-Mean RMSE=1.6375, Time-Stddev MAE=1.0878
[VAL] pr: RMSE=2.6212, Time-Mean RMSE=0.6079, Time-Stddev MAE=1.4173


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.6265, Time-Mean RMSE=1.9915, Time-Stddev MAE=1.1629
[VAL] pr: RMSE=2.5206, Time-Mean RMSE=0.7579, Time-Stddev MAE=1.2521


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.3923, Time-Mean RMSE=2.0672, Time-Stddev MAE=1.0134
[VAL] pr: RMSE=2.2127, Time-Mean RMSE=0.5552, Time-Stddev MAE=1.0037


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.2943, Time-Mean RMSE=1.9825, Time-Stddev MAE=1.0434
[VAL] pr: RMSE=2.1666, Time-Mean RMSE=0.5565, Time-Stddev MAE=0.8906


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.5606, Time-Mean RMSE=1.4691, Time-Stddev MAE=0.6186
[VAL] pr: RMSE=2.0556, Time-Mean RMSE=0.4131, Time-Stddev MAE=0.8605


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.5265, Time-Mean RMSE=1.3306, Time-Stddev MAE=0.7768
[VAL] pr: RMSE=2.0649, Time-Mean RMSE=0.4835, Time-Stddev MAE=0.8979


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.4556, Time-Mean RMSE=1.3933, Time-Stddev MAE=0.6489
[VAL] pr: RMSE=2.0398, Time-Mean RMSE=0.3728, Time-Stddev MAE=0.8239


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.3921, Time-Mean RMSE=1.3850, Time-Stddev MAE=0.5872
[VAL] pr: RMSE=2.0039, Time-Mean RMSE=0.3765, Time-Stddev MAE=0.8060


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0270, Time-Mean RMSE=0.9551, Time-Stddev MAE=0.4856
[VAL] pr: RMSE=2.0080, Time-Mean RMSE=0.4355, Time-Stddev MAE=0.8277


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2397, Time-Mean RMSE=1.3438, Time-Stddev MAE=0.4911
[VAL] pr: RMSE=1.9817, Time-Mean RMSE=0.3357, Time-Stddev MAE=0.8200


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0535, Time-Mean RMSE=1.0637, Time-Stddev MAE=0.6612
[VAL] pr: RMSE=1.9758, Time-Mean RMSE=0.3437, Time-Stddev MAE=0.8140


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0534, Time-Mean RMSE=1.1214, Time-Stddev MAE=0.5560
[VAL] pr: RMSE=1.9739, Time-Mean RMSE=0.3286, Time-Stddev MAE=0.7597


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.4635, Time-Mean RMSE=1.3544, Time-Stddev MAE=0.5945
[VAL] pr: RMSE=2.1159, Time-Mean RMSE=0.4716, Time-Stddev MAE=0.8810


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2355, Time-Mean RMSE=1.3955, Time-Stddev MAE=0.5744
[VAL] pr: RMSE=1.9805, Time-Mean RMSE=0.3385, Time-Stddev MAE=0.7711


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0790, Time-Mean RMSE=1.1880, Time-Stddev MAE=0.5531
[VAL] pr: RMSE=1.9612, Time-Mean RMSE=0.3114, Time-Stddev MAE=0.8149


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8113, Time-Mean RMSE=0.9013, Time-Stddev MAE=0.4200
[VAL] pr: RMSE=1.9679, Time-Mean RMSE=0.3395, Time-Stddev MAE=0.7636


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8371, Time-Mean RMSE=0.9278, Time-Stddev MAE=0.5288
[VAL] pr: RMSE=1.9729, Time-Mean RMSE=0.3554, Time-Stddev MAE=0.8184


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.9465, Time-Mean RMSE=1.1256, Time-Stddev MAE=0.4363
[VAL] pr: RMSE=1.9899, Time-Mean RMSE=0.3566, Time-Stddev MAE=0.8149


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.9616, Time-Mean RMSE=1.1121, Time-Stddev MAE=0.4409
[VAL] pr: RMSE=1.9621, Time-Mean RMSE=0.3213, Time-Stddev MAE=0.8060


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7865, Time-Mean RMSE=0.9933, Time-Stddev MAE=0.4107
[VAL] pr: RMSE=1.9549, Time-Mean RMSE=0.3018, Time-Stddev MAE=0.7864


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7670, Time-Mean RMSE=0.9199, Time-Stddev MAE=0.4052
[VAL] pr: RMSE=1.9688, Time-Mean RMSE=0.3339, Time-Stddev MAE=0.8258


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8476, Time-Mean RMSE=1.0797, Time-Stddev MAE=0.4387
[VAL] pr: RMSE=1.9842, Time-Mean RMSE=0.3945, Time-Stddev MAE=0.7948


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8147, Time-Mean RMSE=1.0255, Time-Stddev MAE=0.4315
[VAL] pr: RMSE=1.9675, Time-Mean RMSE=0.3125, Time-Stddev MAE=0.7639


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7092, Time-Mean RMSE=0.7235, Time-Stddev MAE=0.4368
[VAL] pr: RMSE=1.9654, Time-Mean RMSE=0.3275, Time-Stddev MAE=0.8102


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6972, Time-Mean RMSE=0.8191, Time-Stddev MAE=0.3855
[VAL] pr: RMSE=1.9578, Time-Mean RMSE=0.2996, Time-Stddev MAE=0.7738


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7383, Time-Mean RMSE=0.8564, Time-Stddev MAE=0.4580
[VAL] pr: RMSE=1.9544, Time-Mean RMSE=0.2969, Time-Stddev MAE=0.7572


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7781, Time-Mean RMSE=0.9926, Time-Stddev MAE=0.4048
[VAL] pr: RMSE=1.9604, Time-Mean RMSE=0.3213, Time-Stddev MAE=0.8185


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6878, Time-Mean RMSE=0.8373, Time-Stddev MAE=0.3731
[VAL] pr: RMSE=1.9573, Time-Mean RMSE=0.3214, Time-Stddev MAE=0.7825


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6543, Time-Mean RMSE=0.8279, Time-Stddev MAE=0.3677
[VAL] pr: RMSE=1.9537, Time-Mean RMSE=0.2905, Time-Stddev MAE=0.7749


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6517, Time-Mean RMSE=0.7872, Time-Stddev MAE=0.3804
[VAL] pr: RMSE=1.9547, Time-Mean RMSE=0.2978, Time-Stddev MAE=0.7744


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6075, Time-Mean RMSE=0.6919, Time-Stddev MAE=0.3888
[VAL] pr: RMSE=1.9630, Time-Mean RMSE=0.2962, Time-Stddev MAE=0.7973


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7637, Time-Mean RMSE=0.8964, Time-Stddev MAE=0.4736
[VAL] pr: RMSE=1.9494, Time-Mean RMSE=0.2864, Time-Stddev MAE=0.7623


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6955, Time-Mean RMSE=0.8363, Time-Stddev MAE=0.3692
[VAL] pr: RMSE=1.9670, Time-Mean RMSE=0.3540, Time-Stddev MAE=0.7925


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7471, Time-Mean RMSE=0.8315, Time-Stddev MAE=0.4544
[VAL] pr: RMSE=1.9711, Time-Mean RMSE=0.3593, Time-Stddev MAE=0.7299


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6354, Time-Mean RMSE=0.8211, Time-Stddev MAE=0.3773
[VAL] pr: RMSE=1.9591, Time-Mean RMSE=0.2870, Time-Stddev MAE=0.7803


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6447, Time-Mean RMSE=0.7917, Time-Stddev MAE=0.4126
[VAL] pr: RMSE=1.9829, Time-Mean RMSE=0.2965, Time-Stddev MAE=0.8023


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7237, Time-Mean RMSE=0.8814, Time-Stddev MAE=0.4722
[VAL] pr: RMSE=1.9531, Time-Mean RMSE=0.2756, Time-Stddev MAE=0.7440


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7863, Time-Mean RMSE=0.9047, Time-Stddev MAE=0.5413
[VAL] pr: RMSE=1.9535, Time-Mean RMSE=0.2755, Time-Stddev MAE=0.7658


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6107, Time-Mean RMSE=0.7510, Time-Stddev MAE=0.3944
[VAL] pr: RMSE=1.9598, Time-Mean RMSE=0.2913, Time-Stddev MAE=0.8310


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6505, Time-Mean RMSE=0.7862, Time-Stddev MAE=0.3985
[VAL] pr: RMSE=1.9691, Time-Mean RMSE=0.3282, Time-Stddev MAE=0.8566


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6288, Time-Mean RMSE=0.7782, Time-Stddev MAE=0.3390
[VAL] pr: RMSE=1.9593, Time-Mean RMSE=0.3142, Time-Stddev MAE=0.7768


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5768, Time-Mean RMSE=0.7598, Time-Stddev MAE=0.3362
[VAL] pr: RMSE=1.9501, Time-Mean RMSE=0.2955, Time-Stddev MAE=0.7610


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5517, Time-Mean RMSE=0.7057, Time-Stddev MAE=0.3244
[VAL] pr: RMSE=1.9570, Time-Mean RMSE=0.3041, Time-Stddev MAE=0.7625


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6362, Time-Mean RMSE=0.7355, Time-Stddev MAE=0.3341
[VAL] pr: RMSE=1.9608, Time-Mean RMSE=0.3190, Time-Stddev MAE=0.8046


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6985, Time-Mean RMSE=0.9267, Time-Stddev MAE=0.4490
[VAL] pr: RMSE=1.9598, Time-Mean RMSE=0.3127, Time-Stddev MAE=0.7998


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6274, Time-Mean RMSE=0.8218, Time-Stddev MAE=0.3754
[VAL] pr: RMSE=1.9504, Time-Mean RMSE=0.2529, Time-Stddev MAE=0.8087


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5371, Time-Mean RMSE=0.7187, Time-Stddev MAE=0.3614
[VAL] pr: RMSE=1.9590, Time-Mean RMSE=0.3282, Time-Stddev MAE=0.8012


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5758, Time-Mean RMSE=0.7333, Time-Stddev MAE=0.3436
[VAL] pr: RMSE=1.9555, Time-Mean RMSE=0.3023, Time-Stddev MAE=0.7813


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5779, Time-Mean RMSE=0.6677, Time-Stddev MAE=0.4728
[VAL] pr: RMSE=1.9537, Time-Mean RMSE=0.2719, Time-Stddev MAE=0.8222


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5171, Time-Mean RMSE=0.7028, Time-Stddev MAE=0.3622
[VAL] pr: RMSE=1.9561, Time-Mean RMSE=0.3038, Time-Stddev MAE=0.8090


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4878, Time-Mean RMSE=0.6269, Time-Stddev MAE=0.3107
[VAL] pr: RMSE=1.9442, Time-Mean RMSE=0.2468, Time-Stddev MAE=0.8049


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7956, Time-Mean RMSE=1.1376, Time-Stddev MAE=0.4103
[VAL] pr: RMSE=1.9526, Time-Mean RMSE=0.2795, Time-Stddev MAE=0.8058


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5567, Time-Mean RMSE=0.8098, Time-Stddev MAE=0.3214
[VAL] pr: RMSE=1.9485, Time-Mean RMSE=0.2647, Time-Stddev MAE=0.8105


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5146, Time-Mean RMSE=0.6672, Time-Stddev MAE=0.3764
[VAL] pr: RMSE=1.9428, Time-Mean RMSE=0.2542, Time-Stddev MAE=0.7759


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4780, Time-Mean RMSE=0.6545, Time-Stddev MAE=0.4076
[VAL] pr: RMSE=1.9545, Time-Mean RMSE=0.3010, Time-Stddev MAE=0.7798


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5517, Time-Mean RMSE=0.7878, Time-Stddev MAE=0.4016
[VAL] pr: RMSE=1.9530, Time-Mean RMSE=0.2822, Time-Stddev MAE=0.8106


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4833, Time-Mean RMSE=0.6239, Time-Stddev MAE=0.3312
[VAL] pr: RMSE=1.9536, Time-Mean RMSE=0.2807, Time-Stddev MAE=0.8147


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6349, Time-Mean RMSE=0.8791, Time-Stddev MAE=0.4052
[VAL] pr: RMSE=1.9567, Time-Mean RMSE=0.2863, Time-Stddev MAE=0.7944


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5132, Time-Mean RMSE=0.6672, Time-Stddev MAE=0.3831
[VAL] pr: RMSE=1.9443, Time-Mean RMSE=0.2408, Time-Stddev MAE=0.7895


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4767, Time-Mean RMSE=0.6888, Time-Stddev MAE=0.3650
[VAL] pr: RMSE=1.9499, Time-Mean RMSE=0.2653, Time-Stddev MAE=0.8087


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5038, Time-Mean RMSE=0.6667, Time-Stddev MAE=0.3997
[VAL] pr: RMSE=1.9479, Time-Mean RMSE=0.2848, Time-Stddev MAE=0.7644


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5242, Time-Mean RMSE=0.7204, Time-Stddev MAE=0.3279
[VAL] pr: RMSE=1.9552, Time-Mean RMSE=0.3152, Time-Stddev MAE=0.7953


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4693, Time-Mean RMSE=0.6007, Time-Stddev MAE=0.3238
[VAL] pr: RMSE=1.9493, Time-Mean RMSE=0.2574, Time-Stddev MAE=0.8068


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4960, Time-Mean RMSE=0.6743, Time-Stddev MAE=0.3077
[VAL] pr: RMSE=1.9499, Time-Mean RMSE=0.2804, Time-Stddev MAE=0.7913


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4578, Time-Mean RMSE=0.6034, Time-Stddev MAE=0.3090
[VAL] pr: RMSE=1.9478, Time-Mean RMSE=0.2810, Time-Stddev MAE=0.7847


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.5068, Time-Mean RMSE=0.7775, Time-Stddev MAE=0.3089
[VAL] pr: RMSE=1.9564, Time-Mean RMSE=0.3019, Time-Stddev MAE=0.7543


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4179, Time-Mean RMSE=0.5937, Time-Stddev MAE=0.2921
[VAL] pr: RMSE=1.9536, Time-Mean RMSE=0.2893, Time-Stddev MAE=0.7977


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4739, Time-Mean RMSE=0.6184, Time-Stddev MAE=0.3607
[VAL] pr: RMSE=1.9611, Time-Mean RMSE=0.3263, Time-Stddev MAE=0.7936


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4501, Time-Mean RMSE=0.6357, Time-Stddev MAE=0.3269
[VAL] pr: RMSE=1.9528, Time-Mean RMSE=0.2791, Time-Stddev MAE=0.8194


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3733, Time-Mean RMSE=0.5008, Time-Stddev MAE=0.3343
[VAL] pr: RMSE=1.9486, Time-Mean RMSE=0.2710, Time-Stddev MAE=0.7876


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4331, Time-Mean RMSE=0.6116, Time-Stddev MAE=0.2772
[VAL] pr: RMSE=1.9599, Time-Mean RMSE=0.2957, Time-Stddev MAE=0.8240


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4648, Time-Mean RMSE=0.6033, Time-Stddev MAE=0.3088
[VAL] pr: RMSE=1.9576, Time-Mean RMSE=0.3101, Time-Stddev MAE=0.7862


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3944, Time-Mean RMSE=0.5210, Time-Stddev MAE=0.2892
[VAL] pr: RMSE=1.9550, Time-Mean RMSE=0.2863, Time-Stddev MAE=0.8163


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4491, Time-Mean RMSE=0.6683, Time-Stddev MAE=0.3358
[VAL] pr: RMSE=1.9564, Time-Mean RMSE=0.2988, Time-Stddev MAE=0.8156


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4103, Time-Mean RMSE=0.5630, Time-Stddev MAE=0.3096
[VAL] pr: RMSE=1.9562, Time-Mean RMSE=0.2999, Time-Stddev MAE=0.8038


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3955, Time-Mean RMSE=0.5067, Time-Stddev MAE=0.2697
[VAL] pr: RMSE=1.9531, Time-Mean RMSE=0.2825, Time-Stddev MAE=0.7905


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4219, Time-Mean RMSE=0.5782, Time-Stddev MAE=0.3033
[VAL] pr: RMSE=1.9628, Time-Mean RMSE=0.3174, Time-Stddev MAE=0.8153


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4062, Time-Mean RMSE=0.5799, Time-Stddev MAE=0.3065
[VAL] pr: RMSE=1.9533, Time-Mean RMSE=0.2873, Time-Stddev MAE=0.7794


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3712, Time-Mean RMSE=0.4769, Time-Stddev MAE=0.2965
[VAL] pr: RMSE=1.9619, Time-Mean RMSE=0.3103, Time-Stddev MAE=0.7990


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4638, Time-Mean RMSE=0.6672, Time-Stddev MAE=0.3168
[VAL] pr: RMSE=1.9657, Time-Mean RMSE=0.3233, Time-Stddev MAE=0.8195


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3913, Time-Mean RMSE=0.5283, Time-Stddev MAE=0.2958
[VAL] pr: RMSE=1.9576, Time-Mean RMSE=0.2952, Time-Stddev MAE=0.7961


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4291, Time-Mean RMSE=0.6160, Time-Stddev MAE=0.3216
[VAL] pr: RMSE=1.9575, Time-Mean RMSE=0.2929, Time-Stddev MAE=0.7940


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3639, Time-Mean RMSE=0.4894, Time-Stddev MAE=0.2926
[VAL] pr: RMSE=1.9569, Time-Mean RMSE=0.2946, Time-Stddev MAE=0.7907


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3831, Time-Mean RMSE=0.5222, Time-Stddev MAE=0.3046
[VAL] pr: RMSE=1.9533, Time-Mean RMSE=0.2748, Time-Stddev MAE=0.7913


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4207, Time-Mean RMSE=0.6040, Time-Stddev MAE=0.2781
[VAL] pr: RMSE=1.9612, Time-Mean RMSE=0.3045, Time-Stddev MAE=0.7921


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3949, Time-Mean RMSE=0.5316, Time-Stddev MAE=0.3019
[VAL] pr: RMSE=1.9610, Time-Mean RMSE=0.3058, Time-Stddev MAE=0.7885


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3967, Time-Mean RMSE=0.5106, Time-Stddev MAE=0.2840
[VAL] pr: RMSE=1.9576, Time-Mean RMSE=0.2910, Time-Stddev MAE=0.7973


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4150, Time-Mean RMSE=0.5913, Time-Stddev MAE=0.2946
[VAL] pr: RMSE=1.9636, Time-Mean RMSE=0.3168, Time-Stddev MAE=0.7925


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3960, Time-Mean RMSE=0.5485, Time-Stddev MAE=0.3146
[VAL] pr: RMSE=1.9626, Time-Mean RMSE=0.3119, Time-Stddev MAE=0.7996


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3862, Time-Mean RMSE=0.5288, Time-Stddev MAE=0.3053
[VAL] pr: RMSE=1.9625, Time-Mean RMSE=0.3112, Time-Stddev MAE=0.7892


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3779, Time-Mean RMSE=0.5255, Time-Stddev MAE=0.2835
[VAL] pr: RMSE=1.9592, Time-Mean RMSE=0.2985, Time-Stddev MAE=0.7931


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4139, Time-Mean RMSE=0.5846, Time-Stddev MAE=0.3000
[VAL] pr: RMSE=1.9622, Time-Mean RMSE=0.3122, Time-Stddev MAE=0.7892


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3987, Time-Mean RMSE=0.5615, Time-Stddev MAE=0.3014
[VAL] pr: RMSE=1.9618, Time-Mean RMSE=0.3109, Time-Stddev MAE=0.7908


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3876, Time-Mean RMSE=0.5264, Time-Stddev MAE=0.2971
[VAL] pr: RMSE=1.9645, Time-Mean RMSE=0.3157, Time-Stddev MAE=0.7924


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.3864, Time-Mean RMSE=0.5232, Time-Stddev MAE=0.3058
[VAL] pr: RMSE=1.9644, Time-Mean RMSE=0.3155, Time-Stddev MAE=0.7996


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.4028, Time-Mean RMSE=0.5713, Time-Stddev MAE=0.2953
[VAL] pr: RMSE=1.9651, Time-Mean RMSE=0.3193, Time-Stddev MAE=0.7912


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


[VAL] tas: RMSE=1.3870, Time-Mean RMSE=0.5087, Time-Stddev MAE=0.2986
[VAL] pr: RMSE=1.9626, Time-Mean RMSE=0.3074, Time-Stddev MAE=0.7975


# Test model

**IMPORTANT:** Please note that the test metrics will be bad because the test targets have been corrupted on the public Kaggle dataset.
The purpose of testing below is to generate the Kaggle submission file based on your model's predictions, which you can submit to the competition.

In [None]:
trainer.test(lightning_module, datamodule=datamodule) 

### Plotting Utils


In [None]:
def plot_comparison(true_xr, pred_xr, title, cmap='viridis', diff_cmap='RdBu_r', metric=None):
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    vmin = min(true_xr.min().item(), pred_xr.min().item())
    vmax = max(true_xr.max().item(), pred_xr.max().item())

    # Ground truth
    true_xr.plot(ax=axs[0], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=True)
    axs[0].set_title(f"{title} (Ground Truth)")

    # Prediction
    pred_xr.plot(ax=axs[1], cmap=cmap, vmin=vmin, vmax=vmax, add_colorbar=True)
    axs[1].set_title(f"{title} (Prediction)")

    # Difference
    diff = pred_xr - true_xr
    abs_max = np.max(np.abs(diff))
    diff.plot(ax=axs[2], cmap=diff_cmap, vmin=-abs_max, vmax=abs_max, add_colorbar=True)
    axs[2].set_title(f"{title} (Difference) {f'- {metric:.4f}' if metric else ''}")

    plt.tight_layout()
    plt.show()


### 🖼️ Visualizing Validation Predictions

This cell loads saved validation predictions and compares them to the ground truth using spatial plots. These visualizations help you qualitatively assess your model's performance.

For each output variable (`tas`, `pr`), we visualize:

- **📈 Time-Mean Map**: The 10-year average spatial pattern for both prediction and ground truth. Helps identify long-term biases or spatial shifts.
- **📊 Time-Stddev Map**: Shows the standard deviation across time for each grid cell — useful for assessing how well the model captures **temporal variability** at each location.
- **🕓 Random Timestep Sample**: Visual comparison of prediction vs ground truth for a single month. Useful for spotting fine-grained anomalies or errors in specific months.

> These plots provide intuition beyond metrics and are useful for debugging spatial or temporal model failures.


In [None]:
# Load validation predictions
# make sure to have run the validation loop at least once
val_preds = np.load("val_preds.npy")
val_trues = np.load("val_trues.npy")

lat, lon = datamodule.get_coords()
output_vars = datamodule.output_vars
time = np.arange(val_preds.shape[0])

for i, var in enumerate(output_vars):
    pred_xr = xr.DataArray(val_preds[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})
    true_xr = xr.DataArray(val_trues[:, i], dims=["time", "y", "x"], coords={"time": time, "y": lat, "x": lon})

    # --- Time Mean ---
    plot_comparison(true_xr.mean("time"), pred_xr.mean("time"), f"{var} Val Time-Mean")

    # --- Time Stddev ---
    plot_comparison(true_xr.std("time"), pred_xr.std("time"), f"{var} Val Time-Stddev", cmap="plasma")

    # --- Random timestep ---
    t_idx = np.random.randint(0, len(time))
    plot_comparison(true_xr.isel(time=t_idx), pred_xr.isel(time=t_idx), f"{var} Val Sample Timestep {t_idx}")


## 🧪 Final Notes

This notebook is meant to serve as a **baseline template** — a starting point to help you get up and running quickly with the climate emulation challenge.

You are **not** required to stick to this exact setup. In fact, we **encourage** you to:

- 🔁 Build on top of the provided `DataModule`. 
- 🧠 Use your own model architectures or training pipelines that you’re more comfortable with 
- ⚗️ Experiment with ideas  
- 🥇 Compete creatively to climb the Kaggle leaderboard  
- 🙌 Most importantly: **have fun** and **learn as much as you can** along the way

This challenge simulates a real-world scientific problem, and there’s no single "correct" approach — so be curious, experiment boldly, and make it your own!


In [None]:
pd.read_csv('submissions/kaggle_submission_20250522_062724.csv')