In [1]:
# Cell 1: Imports
# Basic OS, date/time, and numerical libraries
import os
from datetime import datetime
import numpy as np

# Libraries for data handling and scientific computing
import xarray as xr
import dask.array as da # For handling large arrays that don't fit in memory

# PyTorch libraries for deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F # For functions like F.pad
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# PyTorch Lightning for streamlining training
import lightning.pytorch as pl

# Plotting library
import matplotlib.pyplot as plt

# Pandas for data manipulation (e.g., creating submission CSV)
import pandas as pd

In [2]:
# --------------------------------
# Cell 2 – Configuration (ResNet)
# --------------------------------
config = {
    "data": {
        "path": "processed_data_cse151b_v2_corrupted_ssp245/"
                "processed_data_cse151b_v2_corrupted_ssp245.zarr",
        "input_vars":  ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "target_member_id": 0,
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp":    "ssp245",
        "test_months": 120,
        "batch_size":  64,
        "num_workers": 4,
    },

    # -------------  Model block -------------
    "model_resnet": {                 # ← change key
        "type": "resnet",             # ← will be parsed in your model-factory
        "arch": "resnet18",           # resnet18 / 34 / 50 / 101 …
        "in_channels": 5,             # len(input_vars)
        "out_channels": 2,            # len(output_vars)
        "pretrained": False,          # True if using ImageNet weights
        "replace_first_conv": True,   # set to True if in_channels ≠ 3
        "fc_hidden": 256,             # a small FC head before output layer
        "dropout": 0.1,               # dropout prob in the head
    },

    # --------  Optimizer / training ---------
    "training": {
        "optimizer": "AdamW",         # AdamW works well with ResNets
        "lr": 3e-4,                   # start a bit lower than 1e-3
        "weight_decay": 1e-4,
        "lr_scheduler": {
            "name": "CosineAnnealingLR",
            "T_max": 10,              # epochs for one cosine cycle
            "eta_min": 1e-5,          # minimum LR
        },
        "grad_clip": 1.0,             # clip to stabilise large grads
    },

    # -------------  Trainer block -----------
    "trainer": {
        "max_epochs": 10,
        "accelerator": "auto",
        "devices":     "auto",
        "precision":   16,            # try mixed precision for speed
        "deterministic": True,
        "num_sanity_val_steps": 0,
        # Example callbacks (uncomment if you have them in code):
        # "callbacks": [
        #     {"class_path": "pl.callbacks.ModelCheckpoint",
        #      "init_args": {"save_top_k": 2, "monitor": "val_loss"}},
        #     {"class_path": "pl.callbacks.EarlyStopping",
        #      "init_args": {"monitor": "val_loss", "patience": 5}}
        # ],
    },

    "seed": 42,
}

# Set seeds, torch float-32 matmul precision hint
pl.seed_everything(config["seed"], workers=True)

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7:
    torch.set_float32_matmul_precision("medium")
    print("Tensor Core utilisation enabled (‘medium’).")


Seed set to 42


Tensor Core utilisation enabled (‘medium’).


In [15]:
# --------------------------------
# Cell 2  – Updated Configuration
# --------------------------------
config = {
    # ---------- DATA ----------
    "data": {
        "path": "processed_data_cse151b_v2_corrupted_ssp245/"
                "processed_data_cse151b_v2_corrupted_ssp245.zarr",
        "input_vars":  ["CO2", "SO2", "CH4", "BC", "rsdt"],
        "output_vars": ["tas", "pr"],
        "target_member_id": 0,
        "train_ssps": ["ssp126", "ssp370", "ssp585"],
        "test_ssp":    "ssp245",
        "test_months": 120,
        "batch_size":  64,
        "num_workers": 4,
    },

    # ---------- MODEL ----------
    "model_resnet": {
        "type":              "resnet",
        "arch":              "resnet18",   # 18 / 34 / 50 / 101 / …
        "in_channels":        5,           # len(input_vars)
        "out_channels":       2,           # len(output_vars)
        "pretrained":         False,
        "replace_first_conv": True,
        "fc_hidden":         256,
        "dropout":           0.10,
    },

    # ---------- OPTIMISER & SCHEDULER ----------
    "training": {
        "optimizer":      "AdamW",
        "optim_args": {                  # ← anything here is passed to the optimiser
            "lr":            3e-4,       # 🡐 change learning-rate here
            "weight_decay":  1e-4,       # 🡐 change weight-decay here
            "betas":         (0.9, 0.999)
        },

        # Cosine-annealing schedule (easy to swap)
        "lr_scheduler": {
            "name":   "CosineAnnealingLR",
            "T_max":  10,                # one full cosine cycle over 10 epochs
            "eta_min": 1e-5,
        },

        "grad_clip": 1.0,                # global-norm clip
    },

    # ---------- TRAINER ----------
    "trainer": {
        "max_epochs": 20,
        "accelerator": "auto",
        "devices":     "auto",
        "precision":   16,               # AMP
        "deterministic": True,
        "num_sanity_val_steps": 0,
        "log_every_n_steps": 10,

        # Uncomment if you already import the callbacks elsewhere
        # "callbacks": [
        #     {
        #         "class_path": "lightning.pytorch.callbacks.ModelCheckpoint",
        #         "init_args":  {"save_top_k": 2, "monitor": "val/loss", "mode": "min"}
        #     },
        #     {
        #         "class_path": "lightning.pytorch.callbacks.EarlyStopping",
        #         "init_args":  {"monitor": "val/loss", "patience": 4, "mode": "min"}
        #     }
        # ],
    },

    "seed": 42,
}

# ------------- misc initialisation -------------
pl.seed_everything(config["seed"], workers=True)

if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7:
    torch.set_float32_matmul_precision("medium")
    print("Tensor Core utilisation enabled ('medium').")


Seed set to 42


Tensor Core utilisation enabled ('medium').


In [16]:
# Cell 3: Latitude Weights Utility

def get_lat_weights(latitude_values):
    """
    Computes cosine-based area weights for each latitude.
    This accounts for the Earth's curvature, giving more weight to
    grid cells near the equator for global metrics.

    Args:
        latitude_values (np.array): Array of latitude values in degrees.

    Returns:
        np.array: Normalized latitude weights.
    """
    lat_rad = np.deg2rad(latitude_values) # Convert degrees to radians
    weights = np.cos(lat_rad)             # Cosine of latitude
    return weights / np.mean(weights)     # Normalize by the mean weight


In [17]:
# Cell 4: Normalizer Class

class Normalizer:
    """
    Handles Z-score normalization for input and output data.
    (data - mean) / std
    """
    def __init__(self):
        self.mean_in, self.std_in = None, None   # Statistics for input data
        self.mean_out, self.std_out = None, None # Statistics for output data

    def set_input_statistics(self, mean, std):
        """Sets the mean and standard deviation for input data."""
        self.mean_in = mean
        self.std_in = std

    def set_output_statistics(self, mean, std):
        """Sets the mean and standard deviation for output data."""
        self.mean_out = mean
        self.std_out = std

    def normalize(self, data, data_type):
        """
        Normalizes the data using pre-computed statistics.

        Args:
            data (np.array or dask.array): Data to normalize.
            data_type (str): "input" or "output", to use appropriate statistics.

        Returns:
            Normalized data.
        
        Raises:
            ValueError: If statistics for the specified data_type are not set.
        """
        if data_type == "input":
            if self.mean_in is None or self.std_in is None:
                raise ValueError("Input statistics not set in Normalizer.")
            # Add a small epsilon to std to prevent division by zero if std is very small or zero
            return (data - self.mean_in) / (self.std_in + 1e-8) 
        elif data_type == "output":
            if self.mean_out is None or self.std_out is None:
                raise ValueError("Output statistics not set in Normalizer.")
            return (data - self.mean_out) / (self.std_out + 1e-8)
        else:
            raise ValueError(f"Invalid data_type '{data_type}'. Must be 'input' or 'output'.")

    def inverse_transform_output(self, data):
        """
        Applies inverse normalization to output data (predictions).

        Args:
            data (torch.Tensor or np.array): Normalized output data.

        Returns:
            Data in original physical units.

        Raises:
            ValueError: If output statistics are not set.
        """
        if self.mean_out is None or self.std_out is None:
            raise ValueError("Output statistics not set in Normalizer for inverse transform.")
        return data * (self.std_out + 1e-8) + self.mean_out


In [18]:
# ----- Dense-output ResNet-FCN ----------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

# ─── building blocks ─────────────────────────────────────────────────────────
class BasicBlock(nn.Module):
    expansion: int = 1
    def __init__(self, in_c: int, out_c: int, stride: int = 1,
                 downsample: nn.Module | None = None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 3, stride, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.down  = downsample
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x if self.down is None else self.down(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)

class Bottleneck(nn.Module):
    expansion: int = 4
    def __init__(self, in_c: int, out_c: int, stride: int = 1,
                 downsample: nn.Module | None = None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, stride, 1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_c)
        self.conv3 = nn.Conv2d(out_c, out_c * self.expansion, 1, bias=False)
        self.bn3   = nn.BatchNorm2d(out_c * self.expansion)
        self.down  = downsample
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x if self.down is None else self.down(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + identity)

# ─── ResNet-FCN backbone ─────────────────────────────────────────────────────
class ResNet(nn.Module):
    """
    Returns dense (B, C_out, H, W) maps.
    Works for input H, W divisible by 8 (48×72, 96×144, …).
    """
    def __init__(self,
                 depth: int              = 18,
                 n_input_channels: int   = 5,
                 n_output_classes: int   = 2):
        super().__init__()

        cfg = {18:  (BasicBlock,  [2, 2, 2, 2]),
               34:  (BasicBlock,  [3, 4, 6, 3]),
               50:  (Bottleneck,  [3, 4, 6, 3]),
               101: (Bottleneck,  [3, 4, 23, 3]),
               152: (Bottleneck,  [3, 8, 36, 3])}[depth]

        block, layers = cfg
        self.in_c = 64

        # ── stem (/2) ────────────────────────────────────────────────────────
        self.stem = nn.Sequential(
            nn.Conv2d(n_input_channels, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1)   # /2
        )

        # ── residual stages ──────────────────────────────────────────────────
        self.layer1 = self._make_layer(block,  64, layers[0], stride=1)  # /2
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)  # /4
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)  # /8
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)  # keep /8

        # ── prediction head ─────────────────────────────────────────────────
        self.head = nn.Conv2d(512 * block.expansion,
                              n_output_classes, kernel_size=1)

        # init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def _make_layer(self, block, out_c, blocks, stride):
        down = None
        if stride != 1 or self.in_c != out_c * block.expansion:
            down = nn.Sequential(
                nn.Conv2d(self.in_c, out_c * block.expansion,
                          1, stride, bias=False),
                nn.BatchNorm2d(out_c * block.expansion)
            )
        layers = [block(self.in_c, out_c, stride, down)]
        self.in_c = out_c * block.expansion
        layers += [block(self.in_c, out_c) for _ in range(1, blocks)]
        return nn.Sequential(*layers)

    # ── forward ─────────────────────────────────────────────────────────────
    def forward(self, x):
        H, W = x.shape[-2:]
        x = self.stem(x)
        x = self.layer1(x); x = self.layer2(x)
        x = self.layer3(x); x = self.layer4(x)
        x = self.head(x)                       # (B, C_out, H/8, W/8)
        return F.interpolate(x, (H, W),
                             mode='bilinear',
                             align_corners=False)


In [19]:
# Cell 6: ClimateDataset and ClimateDataModule

class ClimateDataset(Dataset):
    def __init__(self, inputs_dask, outputs_dask, output_is_normalized=True):
        """
        PyTorch Dataset for climate data.

        Args:
            inputs_dask (dask.array): Dask array of input features.
            outputs_dask (dask.array): Dask array of output targets.
            output_is_normalized (bool): Flag indicating if outputs_dask is already normalized.
                                         Used for the test set where targets are not pre-normalized.
        """
        self.size = inputs_dask.shape[0]
        print(f"Creating dataset with {self.size} samples...")

        inputs_np = inputs_dask.compute()
        outputs_np = outputs_dask.compute()

        self.inputs = torch.from_numpy(inputs_np).float()
        self.outputs = torch.from_numpy(outputs_np).float()

        if torch.isnan(self.inputs).any():
            raise ValueError("NaNs found in input dataset after converting to tensor.")
        if torch.isnan(self.outputs).any():
            raise ValueError("NaNs found in output dataset after converting to tensor.")
        
        print(f"Dataset created. Input shape: {self.inputs.shape}, Output shape: {self.outputs.shape}")


    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]


class ClimateDataModule(pl.LightningDataModule):
    def __init__(
        self,
        path,
        input_vars,
        output_vars,
        train_ssps,
        test_ssp,
        target_member_id,
        test_months=120,
        batch_size=32,
        num_workers=0,
        seed=42, 
    ):
        super().__init__()
        self.save_hyperparameters() 
        self.normalizer = Normalizer()

    def prepare_data(self):
        if not os.path.exists(self.hparams.path):
            raise FileNotFoundError(f"Data path not found: {self.hparams.path}. Please check config['data']['path'].")

    def setup(self, stage=None):
        ds = xr.open_zarr(self.hparams.path, consolidated=False, chunks={"time": 24})
        
        # --- FIX for spatial_template ---
        # The 'rsdt' variable might not have 'member_id'. Handle this conditionally.
        rsdt_var_for_template = ds["rsdt"]
        if "member_id" in rsdt_var_for_template.dims:
            spatial_template = rsdt_var_for_template.isel(time=0, ssp=0, member_id=0, drop=True)
        else:
            # If 'member_id' is not present, select without it.
            # This assumes 'rsdt' is consistent across members or doesn't have that dimension.
            spatial_template = rsdt_var_for_template.isel(time=0, ssp=0, drop=True)
        # --- END FIX ---

        def load_ssp(ssp_name):
            input_dask_list, output_dask_list = [], []
            
            for var_name in self.hparams.input_vars:
                da_var = ds[var_name].sel(ssp=ssp_name)
                if "latitude" in da_var.dims: 
                    da_var = da_var.rename({"latitude": "y", "longitude": "x"})
                # For input variables, if member_id exists, select the target_member_id.
                # If it doesn't exist (e.g. for some forcing data), this sel will be a no-op if strict=False,
                # or we can check existence. Xarray's sel is usually robust if the dim doesn't exist.
                # However, to be safe, let's check.
                if "member_id" in da_var.dims:
                    da_var = da_var.sel(member_id=self.hparams.target_member_id)
                
                if set(da_var.dims) == {"time"}: 
                    da_var = da_var.broadcast_like(spatial_template).transpose("time", "y", "x")
                input_dask_list.append(da_var.data)
            
            for var_name in self.hparams.output_vars:
                # Output variables are always selected by target_member_id
                da_out = ds[var_name].sel(ssp=ssp_name, member_id=self.hparams.target_member_id)
                if "latitude" in da_out.dims: 
                    da_out = da_out.rename({"latitude": "y", "longitude": "x"})
                output_dask_list.append(da_out.data)

            return da.stack(input_dask_list, axis=1), da.stack(output_dask_list, axis=1)

        train_input_list, train_output_list = [], []
        val_input_ssp370, val_output_ssp370 = None, None

        for ssp in self.hparams.train_ssps:
            x_ssp, y_ssp = load_ssp(ssp)
            if ssp == "ssp370": 
                val_input_ssp370 = x_ssp[-self.hparams.test_months:]
                val_output_ssp370 = y_ssp[-self.hparams.test_months:]
                train_input_list.append(x_ssp[:-self.hparams.test_months])
                train_output_list.append(y_ssp[:-self.hparams.test_months])
            else:
                train_input_list.append(x_ssp)
                train_output_list.append(y_ssp)
        
        train_input_all_ssp = da.concatenate(train_input_list, axis=0)
        train_output_all_ssp = da.concatenate(train_output_list, axis=0)

        input_mean = da.nanmean(train_input_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        input_std = da.nanstd(train_input_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        self.normalizer.set_input_statistics(mean=input_mean, std=input_std)

        output_mean = da.nanmean(train_output_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        output_std = da.nanstd(train_output_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        self.normalizer.set_output_statistics(mean=output_mean, std=output_std)

        train_input_norm = self.normalizer.normalize(train_input_all_ssp, "input")
        train_output_norm = self.normalizer.normalize(train_output_all_ssp, "output")
        
        val_input_norm = self.normalizer.normalize(val_input_ssp370, "input")
        val_output_norm = self.normalizer.normalize(val_output_ssp370, "output")

        test_input_ssp, test_output_ssp = load_ssp(self.hparams.test_ssp)
        test_input_ssp = test_input_ssp[-self.hparams.test_months:] 
        test_output_ssp = test_output_ssp[-self.hparams.test_months:]
        test_input_norm = self.normalizer.normalize(test_input_ssp, "input")

        if stage == "fit" or stage is None:
            self.train_dataset = ClimateDataset(train_input_norm, train_output_norm)
            self.val_dataset = ClimateDataset(val_input_norm, val_output_norm)
        if stage == "test" or stage is None:
            self.test_dataset = ClimateDataset(test_input_norm, test_output_ssp, output_is_normalized=False)
        
        self.lat = spatial_template.y.values
        self.lon = spatial_template.x.values
        self.area_weights = xr.DataArray(get_lat_weights(self.lat), dims=["y"], coords={"y": self.lat})

        ds.close()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True,
                          num_workers=self.hparams.num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=self.hparams.num_workers > 0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=self.hparams.num_workers > 0)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=self.hparams.num_workers > 0)

    def get_lat_weights(self):
        return self.area_weights

    def get_coords(self):
        return self.lat, self.lon


In [20]:
# Cell 7: ClimateEmulationModule (PyTorch Lightning)

class ClimateEmulationModule(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-4):
        super().__init__()
        self.model = model 
        self.save_hyperparameters(ignore=['model']) 
        
        self.criterion = nn.MSELoss() 
        self.normalizer = None 
        
        self.val_preds, self.val_targets = [], []
        self.test_preds, self.test_targets = [], []

    def forward(self, x):
        return self.model(x)

    def _get_normalizer_from_datamodule(self):
        """Helper to safely get normalizer from datamodule."""
        if hasattr(self.trainer, 'datamodule') and self.trainer.datamodule is not None and hasattr(self.trainer.datamodule, 'normalizer'):
            return self.trainer.datamodule.normalizer
        else:
            # Fallback if trainer.datamodule is not set up (e.g. direct call to test without fit)
            # This requires 'config' to be globally accessible or passed differently.
            print("Warning: Normalizer not found via self.trainer.datamodule. Attempting fallback initialization.")
            temp_dm = ClimateDataModule(**config["data"]) 
            temp_dm.prepare_data()
            temp_dm.setup(stage="test") # Or appropriate stage to ensure normalizer stats are computed
            return temp_dm.normalizer


    def on_fit_start(self):
        """Called at the beginning of training."""
        self.normalizer = self._get_normalizer_from_datamodule()

    def on_test_start(self):
        """Called at the beginning of testing."""
        if self.normalizer is None: # Ensure normalizer is available
            self.normalizer = self._get_normalizer_from_datamodule()


    def training_step(self, batch, batch_idx):
        x, y_norm = batch 
        y_hat_norm = self(x)   
        loss = self.criterion(y_hat_norm, y_norm)
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_norm = batch
        y_hat_norm = self(x)
        loss = self.criterion(y_hat_norm, y_norm)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()

        y_hat_denorm = self.normalizer.inverse_transform_output(y_hat_norm.detach().cpu().numpy())
        y_denorm = self.normalizer.inverse_transform_output(y_norm.detach().cpu().numpy())
        
        self.val_preds.append(y_hat_denorm)
        self.val_targets.append(y_denorm)
        return loss 

    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking: # Skip during sanity check
            if not self.val_preds or not self.val_targets: 
                return

            preds_epoch = np.concatenate(self.val_preds, axis=0)
            trues_epoch = np.concatenate(self.val_targets, axis=0)
            
            if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()
            
            self._evaluate(preds_epoch, trues_epoch, phase="val")
            
            np.save("val_preds.npy", preds_epoch)
            np.save("val_trues.npy", trues_epoch)
            
            self.val_preds.clear() 
            self.val_targets.clear()

    def test_step(self, batch, batch_idx):
        x, y_true_denorm = batch 
        y_hat_norm = self(x)    

        if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()
        
        y_hat_denorm = self.normalizer.inverse_transform_output(y_hat_norm.detach().cpu().numpy())
        
        self.test_preds.append(y_hat_denorm)
        self.test_targets.append(y_true_denorm.detach().cpu().numpy()) 

    def on_test_epoch_end(self):
        if not self.test_preds or not self.test_targets: 
            return

        preds_epoch = np.concatenate(self.test_preds, axis=0)
        trues_epoch = np.concatenate(self.test_targets, axis=0)

        if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()

        self._evaluate(preds_epoch, trues_epoch, phase="test")
        self._save_submission(preds_epoch) 
        
        self.test_preds.clear()
        self.test_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def _evaluate(self, preds_np, trues_np, phase="val"):
        """Calculates and logs evaluation metrics."""
        # This check is important for when _evaluate might be called outside trainer.fit/test context
        # or if datamodule is not correctly propagated.
        if self.trainer.datamodule is None or not hasattr(self.trainer.datamodule, 'get_lat_weights'):
            print("Warning: self.trainer.datamodule not fully available in _evaluate. Using fallback for coords/weights.")
            dm_eval = ClimateDataModule(**config["data"]) # Re-init for coords, assumes global config
            dm_eval.prepare_data()
            dm_eval.setup(stage=phase) # Setup for the correct stage
            area_weights = dm_eval.get_lat_weights()
            lat, lon = dm_eval.get_coords()
            output_vars = dm_eval.hparams.output_vars
        else:
            area_weights = self.trainer.datamodule.get_lat_weights()
            lat, lon = self.trainer.datamodule.get_coords()
            output_vars = self.trainer.datamodule.hparams.output_vars


        time_coords = np.arange(preds_np.shape[0])
        metrics_summary = {}

        for i, var_name in enumerate(output_vars):
            p_var = preds_np[:, i] 
            t_var = trues_np[:, i] 
            
            p_xr = xr.DataArray(p_var, dims=["time", "y", "x"], coords={"time": time_coords, "y": lat, "x": lon})
            t_xr = xr.DataArray(t_var, dims=["time", "y", "x"], coords={"time": time_coords, "y": lat, "x": lon})

            rmse = np.sqrt(((p_xr - t_xr) ** 2).weighted(area_weights).mean()).item()
            mean_rmse = np.sqrt(((p_xr.mean("time") - t_xr.mean("time")) ** 2).weighted(area_weights).mean()).item()
            std_mae = np.abs(p_xr.std("time") - t_xr.std("time")).weighted(area_weights).mean().item()

            print(f"[{phase.upper()}] {var_name}: RMSE={rmse:.4f}, Time-Mean RMSE={mean_rmse:.4f}, Time-Stddev MAE={std_mae:.4f}")
            
            metrics_summary[f"{phase}/{var_name}/rmse"] = rmse
            metrics_summary[f"{phase}/{var_name}/time_mean_rmse"] = mean_rmse
            metrics_summary[f"{phase}/{var_name}/time_std_mae"] = std_mae
        
        self.log_dict(metrics_summary, logger=True)

    def _save_submission(self, predictions_np):
        """Saves model predictions to a CSV file in Kaggle submission format."""
        if self.trainer.datamodule is None or not hasattr(self.trainer.datamodule, 'get_coords'):
            print("Warning: self.trainer.datamodule not fully available in _save_submission. Using fallback.")
            dm_submission = ClimateDataModule(**config["data"])
            dm_submission.prepare_data()
            dm_submission.setup(stage="test") # Ensure coords are loaded
            lat, lon = dm_submission.get_coords()
            output_vars = dm_submission.hparams.output_vars
        else:
            lat, lon = self.trainer.datamodule.get_coords()
            output_vars = self.trainer.datamodule.hparams.output_vars
            
        time_coords_submission = np.arange(predictions_np.shape[0])

        rows = []
        for t_idx, t_val in enumerate(time_coords_submission):
            for var_idx, var_name in enumerate(output_vars):
                for y_idx, y_val in enumerate(lat):
                    for x_idx, x_val in enumerate(lon):
                        row_id = f"t{t_idx:03d}_{var_name}_{y_val:.2f}_{x_val:.2f}"
                        pred_value = predictions_np[t_idx, var_idx, y_idx, x_idx]
                        rows.append({"ID": row_id, "Prediction": pred_value})

        submission_df = pd.DataFrame(rows)
        submission_dir = "submissions"
        os.makedirs(submission_dir, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filepath = os.path.join(submission_dir, f"kaggle_submission_unet_{timestamp}.csv")
        submission_df.to_csv(filepath, index=False)
        print(f"Submission saved to: {filepath}")


In [12]:
# Cell 8: Training and Evaluation Script

# --- Instantiate DataModule ---
datamodule = ClimateDataModule(**config["data"])
# datamodule.prepare_data() # Called by Trainer when .fit() or .test() is called
# datamodule.setup()      # Called by Trainer when .fit() or .test() is called

# --- Instantiate U-Net Model ---
n_inputs = len(config["data"]["input_vars"])
n_outputs = len(config["data"]["output_vars"])

unet_config_params = config.get("model_unet", {}) 
init_features = unet_config_params.get("init_features", 64)
bilinear_upsampling = unet_config_params.get("bilinear", True)

resnet_model = ResNet(depth        = 18,
                      n_input_channels = n_inputs,
                      n_output_classes = n_outputs)

# --- Instantiate Lightning Module ---
learning_rate = config["training"]["lr"]
lightning_module = ClimateEmulationModule(resnet_model, learning_rate=learning_rate)

# --- Setup Trainer ---
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_callback = ModelCheckpoint(
    monitor="val/loss", 
    mode="min",         
    filename="unet-best-{epoch:02d}-{val/loss:.2f}", 
    save_top_k=1,       
    verbose=True
)

early_stop_callback = EarlyStopping(
    monitor="val/loss",
    patience=5, 
    verbose=True,
    mode="min"
)

trainer_params = {**config["trainer"]} 
trainer_params["callbacks"] = [checkpoint_callback, early_stop_callback]
# Optional: Add logger
# from lightning.pytorch.loggers import TensorBoardLogger
# logger = TensorBoardLogger("tb_logs", name="unet_climate_emulation")
# trainer_params["logger"] = logger

trainer = pl.Trainer(**trainer_params)

# --- Train the Model ---
print("Starting U-Net model training...")
trainer.fit(lightning_module, datamodule=datamodule)
print("Training finished.")

# --- Test the Model ---
print("Starting U-Net model testing using the best checkpoint...")
# trainer.test will use the checkpoint_callback's best_model_path by default if available
# or you can specify ckpt_path="best"
test_results = trainer.test(lightning_module, datamodule=datamodule, ckpt_path="best") 
print("Testing finished.")
print("Test Results:", test_results)


/home/ruz039/.local/lib/python3.11/site-packages/lightning/fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Starting U-Net model training...


2025-05-17 04:55:57.909004: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-17 04:55:57.949198: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-17 04:55:57.949245: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-17 04:55:57.950294: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-17 04:55:57.957371: I tensorflow/core/platform/cpu_feature_guar

Creating dataset with 2943 samples...
Dataset created. Input shape: torch.Size([2943, 5, 48, 72]), Output shape: torch.Size([2943, 2, 48, 72])
Creating dataset with 120 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Dataset created. Input shape: torch.Size([120, 5, 48, 72]), Output shape: torch.Size([120, 2, 48, 72])



  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | model     | ResNet  | 11.2 M | train
1 | criterion | MSELoss | 0      | train
----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.684    Total estimated model params size (MB)
69        Modules in train mode
0         Modules in eval mode
/home/ruz039/.local/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (46) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved. New best score: 82.380
Epoch 0, global step 46: 'val/loss' reached 82.37978 (best 82.37978), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=00-val/loss=82.38.ckpt' as top 1


[VAL] tas: RMSE=175.5129, Time-Mean RMSE=96.2518, Time-Stddev MAE=128.9028
[VAL] pr: RMSE=31.8142, Time-Mean RMSE=26.1904, Time-Stddev MAE=14.3366


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 61.658 >= min_delta = 0.0. New best score: 20.722
Epoch 1, global step 92: 'val/loss' reached 20.72213 (best 20.72213), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=01-val/loss=20.72.ckpt' as top 1


[VAL] tas: RMSE=96.0668, Time-Mean RMSE=58.5966, Time-Stddev MAE=67.0241
[VAL] pr: RMSE=14.8811, Time-Mean RMSE=9.2109, Time-Stddev MAE=8.3847


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 6.616 >= min_delta = 0.0. New best score: 14.106
Epoch 2, global step 138: 'val/loss' reached 14.10643 (best 14.10643), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=02-val/loss=14.11.ckpt' as top 1


[VAL] tas: RMSE=64.6294, Time-Mean RMSE=25.7538, Time-Stddev MAE=53.9040
[VAL] pr: RMSE=14.2202, Time-Mean RMSE=9.9704, Time-Stddev MAE=6.9772


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.245 >= min_delta = 0.0. New best score: 13.861
Epoch 3, global step 184: 'val/loss' reached 13.86138 (best 13.86138), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=03-val/loss=13.86.ckpt' as top 1


[VAL] tas: RMSE=69.7610, Time-Mean RMSE=50.4383, Time-Stddev MAE=43.5849
[VAL] pr: RMSE=13.8794, Time-Mean RMSE=9.3984, Time-Stddev MAE=6.9822


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 6.226 >= min_delta = 0.0. New best score: 7.636
Epoch 4, global step 230: 'val/loss' reached 7.63560 (best 7.63560), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=04-val/loss=7.64.ckpt' as top 1


[VAL] tas: RMSE=51.4174, Time-Mean RMSE=28.6953, Time-Stddev MAE=38.1924
[VAL] pr: RMSE=9.8416, Time-Mean RMSE=4.7257, Time-Stddev MAE=5.5358


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 1.621 >= min_delta = 0.0. New best score: 6.014
Epoch 5, global step 276: 'val/loss' reached 6.01449 (best 6.01449), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=05-val/loss=6.01.ckpt' as top 1


[VAL] tas: RMSE=47.3929, Time-Mean RMSE=24.5678, Time-Stddev MAE=36.3817
[VAL] pr: RMSE=8.4873, Time-Mean RMSE=3.4949, Time-Stddev MAE=4.6664


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 1.082 >= min_delta = 0.0. New best score: 4.933
Epoch 6, global step 322: 'val/loss' reached 4.93287 (best 4.93287), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=06-val/loss=4.93.ckpt' as top 1


[VAL] tas: RMSE=43.0287, Time-Mean RMSE=16.1971, Time-Stddev MAE=35.4112
[VAL] pr: RMSE=7.2959, Time-Mean RMSE=2.8425, Time-Stddev MAE=3.7610


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 1.030 >= min_delta = 0.0. New best score: 3.903
Epoch 7, global step 368: 'val/loss' reached 3.90281 (best 3.90281), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=07-val/loss=3.90.ckpt' as top 1


[VAL] tas: RMSE=39.0149, Time-Mean RMSE=15.5968, Time-Stddev MAE=31.4980
[VAL] pr: RMSE=6.6639, Time-Mean RMSE=2.4375, Time-Stddev MAE=3.2626


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.292 >= min_delta = 0.0. New best score: 3.611
Epoch 8, global step 414: 'val/loss' reached 3.61073 (best 3.61073), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=08-val/loss=3.61.ckpt' as top 1


[VAL] tas: RMSE=38.4802, Time-Mean RMSE=22.0491, Time-Stddev MAE=27.3383
[VAL] pr: RMSE=6.3529, Time-Mean RMSE=2.3403, Time-Stddev MAE=3.0168


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.338 >= min_delta = 0.0. New best score: 3.272
Epoch 9, global step 460: 'val/loss' reached 3.27225 (best 3.27225), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=09-val/loss=3.27.ckpt' as top 1


[VAL] tas: RMSE=33.1929, Time-Mean RMSE=13.4367, Time-Stddev MAE=26.3828
[VAL] pr: RMSE=6.3840, Time-Mean RMSE=2.9815, Time-Stddev MAE=2.8440


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.526 >= min_delta = 0.0. New best score: 2.746
Epoch 10, global step 506: 'val/loss' reached 2.74612 (best 2.74612), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=10-val/loss=2.75.ckpt' as top 1


[VAL] tas: RMSE=30.3918, Time-Mean RMSE=15.8568, Time-Stddev MAE=22.1174
[VAL] pr: RMSE=5.8608, Time-Mean RMSE=3.0451, Time-Stddev MAE=2.4114


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 552: 'val/loss' was not in top 1


[VAL] tas: RMSE=32.3497, Time-Mean RMSE=14.6957, Time-Stddev MAE=24.5223
[VAL] pr: RMSE=5.6914, Time-Mean RMSE=2.1534, Time-Stddev MAE=2.5816


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 598: 'val/loss' was not in top 1


[VAL] tas: RMSE=32.1158, Time-Mean RMSE=16.8699, Time-Stddev MAE=23.1382
[VAL] pr: RMSE=6.1809, Time-Mean RMSE=3.0482, Time-Stddev MAE=2.5092


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.203 >= min_delta = 0.0. New best score: 2.543
Epoch 13, global step 644: 'val/loss' reached 2.54271 (best 2.54271), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=13-val/loss=2.54.ckpt' as top 1


[VAL] tas: RMSE=32.1272, Time-Mean RMSE=15.3960, Time-Stddev MAE=23.8279
[VAL] pr: RMSE=5.3746, Time-Mean RMSE=2.5733, Time-Stddev MAE=2.0959


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.168 >= min_delta = 0.0. New best score: 2.375
Epoch 14, global step 690: 'val/loss' reached 2.37461 (best 2.37461), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=14-val/loss=2.37.ckpt' as top 1


[VAL] tas: RMSE=27.1474, Time-Mean RMSE=10.8584, Time-Stddev MAE=21.2413
[VAL] pr: RMSE=5.6638, Time-Mean RMSE=3.3242, Time-Stddev MAE=1.9112


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 736: 'val/loss' was not in top 1


[VAL] tas: RMSE=33.0648, Time-Mean RMSE=18.6756, Time-Stddev MAE=23.2095
[VAL] pr: RMSE=5.0511, Time-Mean RMSE=2.4356, Time-Stddev MAE=1.9398


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.542 >= min_delta = 0.0. New best score: 1.833
Epoch 16, global step 782: 'val/loss' reached 1.83282 (best 1.83282), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=16-val/loss=1.83.ckpt' as top 1


[VAL] tas: RMSE=26.2800, Time-Mean RMSE=12.7070, Time-Stddev MAE=18.9646
[VAL] pr: RMSE=4.7693, Time-Mean RMSE=2.0581, Time-Stddev MAE=1.9058


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.027 >= min_delta = 0.0. New best score: 1.806
Epoch 17, global step 828: 'val/loss' reached 1.80576 (best 1.80576), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=17-val/loss=1.81.ckpt' as top 1


[VAL] tas: RMSE=24.1528, Time-Mean RMSE=10.6463, Time-Stddev MAE=17.7292
[VAL] pr: RMSE=4.7550, Time-Mean RMSE=1.9804, Time-Stddev MAE=1.8772


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.083 >= min_delta = 0.0. New best score: 1.723
Epoch 18, global step 874: 'val/loss' reached 1.72292 (best 1.72292), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=18-val/loss=1.72.ckpt' as top 1


[VAL] tas: RMSE=25.0657, Time-Mean RMSE=12.7579, Time-Stddev MAE=17.1890
[VAL] pr: RMSE=4.5996, Time-Mean RMSE=1.9392, Time-Stddev MAE=1.8118


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 920: 'val/loss' was not in top 1


[VAL] tas: RMSE=24.0107, Time-Mean RMSE=13.0061, Time-Stddev MAE=16.2700
[VAL] pr: RMSE=4.8181, Time-Mean RMSE=2.6075, Time-Stddev MAE=1.7232


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.209 >= min_delta = 0.0. New best score: 1.514
Epoch 20, global step 966: 'val/loss' reached 1.51357 (best 1.51357), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=20-val/loss=1.51.ckpt' as top 1


[VAL] tas: RMSE=22.9894, Time-Mean RMSE=12.4624, Time-Stddev MAE=15.6717
[VAL] pr: RMSE=4.3443, Time-Mean RMSE=2.0835, Time-Stddev MAE=1.6128


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 1012: 'val/loss' was not in top 1


[VAL] tas: RMSE=33.0565, Time-Mean RMSE=20.0313, Time-Stddev MAE=19.5996
[VAL] pr: RMSE=4.6948, Time-Mean RMSE=2.4165, Time-Stddev MAE=1.7833


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.234 >= min_delta = 0.0. New best score: 1.279
Epoch 22, global step 1058: 'val/loss' reached 1.27941 (best 1.27941), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=22-val/loss=1.28.ckpt' as top 1


[VAL] tas: RMSE=20.9029, Time-Mean RMSE=11.0936, Time-Stddev MAE=13.5406
[VAL] pr: RMSE=4.1901, Time-Mean RMSE=1.8432, Time-Stddev MAE=1.5721


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 1104: 'val/loss' was not in top 1


[VAL] tas: RMSE=22.4695, Time-Mean RMSE=12.0208, Time-Stddev MAE=14.7622
[VAL] pr: RMSE=4.4507, Time-Mean RMSE=2.2024, Time-Stddev MAE=1.6501


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 1150: 'val/loss' was not in top 1


[VAL] tas: RMSE=21.0052, Time-Mean RMSE=10.3462, Time-Stddev MAE=14.3447
[VAL] pr: RMSE=4.3527, Time-Mean RMSE=2.1096, Time-Stddev MAE=1.4988


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 1196: 'val/loss' was not in top 1


[VAL] tas: RMSE=19.3907, Time-Mean RMSE=9.0492, Time-Stddev MAE=12.8699
[VAL] pr: RMSE=4.4462, Time-Mean RMSE=2.2979, Time-Stddev MAE=1.6927


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 1242: 'val/loss' was not in top 1


[VAL] tas: RMSE=20.4186, Time-Mean RMSE=10.7820, Time-Stddev MAE=12.9301
[VAL] pr: RMSE=4.4944, Time-Mean RMSE=2.5007, Time-Stddev MAE=1.5346


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.099 >= min_delta = 0.0. New best score: 1.180
Epoch 27, global step 1288: 'val/loss' reached 1.17992 (best 1.17992), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=27-val/loss=1.18.ckpt' as top 1


[VAL] tas: RMSE=18.6279, Time-Mean RMSE=9.1363, Time-Stddev MAE=12.7494
[VAL] pr: RMSE=4.1363, Time-Mean RMSE=2.0824, Time-Stddev MAE=1.4988


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.011 >= min_delta = 0.0. New best score: 1.168
Epoch 28, global step 1334: 'val/loss' reached 1.16847 (best 1.16847), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=28-val/loss=1.17.ckpt' as top 1


[VAL] tas: RMSE=17.6856, Time-Mean RMSE=7.9444, Time-Stddev MAE=11.9060
[VAL] pr: RMSE=4.2563, Time-Mean RMSE=2.0761, Time-Stddev MAE=1.4392


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.013 >= min_delta = 0.0. New best score: 1.155
Epoch 29, global step 1380: 'val/loss' reached 1.15498 (best 1.15498), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=29-val/loss=1.15.ckpt' as top 1


[VAL] tas: RMSE=18.7086, Time-Mean RMSE=10.0719, Time-Stddev MAE=11.6510
[VAL] pr: RMSE=4.1701, Time-Mean RMSE=2.0551, Time-Stddev MAE=1.4993


`Trainer.fit` stopped: `max_epochs=30` reached.


Training finished.
Starting U-Net model testing using the best checkpoint...
Creating dataset with 120 samples...


Restoring states from the checkpoint path at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=29-val/loss=1.15.ckpt


Dataset created. Input shape: torch.Size([120, 5, 48, 72]), Output shape: torch.Size([120, 2, 48, 72])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/version_1/checkpoints/unet-best-epoch=29-val/loss=1.15.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[TEST] tas: RMSE=285.8640, Time-Mean RMSE=284.3571, Time-Stddev MAE=27.7675
[TEST] pr: RMSE=7.8771, Time-Mean RMSE=4.3515, Time-Stddev MAE=5.8766
Submission saved to: submissions/kaggle_submission_unet_20250517_045746.csv
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test/pr/rmse           7.877082347869873
 test/pr/time_mean_rmse      4.351463317871094
  test/pr/time_std_mae       5.876616477966309
      test/tas/rmse          285.8639831542969
 test/tas/time_mean_rmse     284.3570861816406
  test/tas/time_std_mae     27.767528533935547
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Testing finished.
Test Results: [{'test/tas/rmse': 285.8639831542969, 'test/tas/time_mean

In [13]:
# Cell 9: Plotting Utilities (Optional)
# Ensure matplotlib, numpy, and xarray are imported (usually in Cell 1)

def plot_comparison(true_xr, pred_xr, title, cmap='viridis', diff_cmap='RdBu_r', metric_val=None, metric_name="Metric"):
    """
    Plots a comparison between ground truth, prediction, and their difference.
    Includes calculation and display of a spatial mean metric (e.g., RMSE).
    """
    fig, axs = plt.subplots(1, 3, figsize=(18, 5)) 
    fig.suptitle(title, fontsize=16) 

    common_min = min(true_xr.min().item(), pred_xr.min().item())
    common_max = max(true_xr.max().item(), pred_xr.max().item())

    true_xr.plot(ax=axs[0], cmap=cmap, vmin=common_min, vmax=common_max, add_colorbar=True, cbar_kwargs={'label': true_xr.name or 'Value'})
    axs[0].set_title(f"Ground Truth")

    pred_xr.plot(ax=axs[1], cmap=cmap, vmin=common_min, vmax=common_max, add_colorbar=True, cbar_kwargs={'label': pred_xr.name or 'Value'})
    axs[1].set_title(f"Prediction")

    diff = pred_xr - true_xr
    abs_max_diff = np.max(np.abs(diff.data)) if diff.size > 0 else 0.1 
    
    diff_plot_params = {'cmap': diff_cmap, 'add_colorbar': True, 'cbar_kwargs': {'label': 'Difference'}}
    if abs_max_diff > 0: 
        diff_plot_params['vmin'] = -abs_max_diff
        diff_plot_params['vmax'] = abs_max_diff
        
    diff.plot(ax=axs[2], **diff_plot_params)
    
    title_suffix = ""
    if metric_val is not None:
        title_suffix = f" ({metric_name}: {metric_val:.4f})"
    axs[2].set_title(f"Difference (Pred - Truth){title_suffix}")

    plt.tight_layout(rect=[0, 0, 1, 0.96]) 
    plt.show()


In [None]:
# Cell 10: Visualization Script (Optional)

try:
    val_preds_loaded = np.load("val_preds.npy")
    val_trues_loaded = np.load("val_trues.npy")

    if not hasattr(datamodule, 'lat') or datamodule.lat is None:
        print("Datamodule not fully set up for visualization. Setting it up...")
        # datamodule.prepare_data() # Should have been called by trainer
        datamodule.setup(stage="fit") # Ensure lat, lon, etc. are available

    lat, lon = datamodule.get_coords()
    output_vars = config["data"]["output_vars"] 
    area_weights_vis = datamodule.get_lat_weights() 
    
    time_val_coords = np.arange(val_preds_loaded.shape[0])

    print(f"\n--- Visualizing Validation Predictions for U-Net ---")
    for i, var_name in enumerate(output_vars):
        pred_xr = xr.DataArray(val_preds_loaded[:, i], dims=["time", "y", "x"], 
                               coords={"time": time_val_coords, "y": lat, "x": lon}, name=var_name)
        true_xr = xr.DataArray(val_trues_loaded[:, i], dims=["time", "y", "x"], 
                               coords={"time": time_val_coords, "y": lat, "x": lon}, name=var_name)

        pred_mean = pred_xr.mean("time")
        true_mean = true_xr.mean("time")
        mean_rmse_var = np.sqrt(((pred_mean - true_mean) ** 2).weighted(area_weights_vis).mean()).item()
        plot_comparison(true_mean, pred_mean, 
                        f"U-Net: {var_name.upper()} - Validation Time-Mean",
                        metric_val=mean_rmse_var, metric_name="Time-Mean RMSE")

        pred_std = pred_xr.std("time")
        true_std = true_xr.std("time")
        std_mae_var = np.abs(pred_std - true_std).weighted(area_weights_vis).mean().item()
        plot_comparison(true_std, pred_std, 
                        f"U-Net: {var_name.upper()} - Validation Time-StdDev", cmap="plasma",
                        metric_val=std_mae_var, metric_name="Time-StdDev MAE")

        if len(time_val_coords) > 0:
            t_idx_random = np.random.randint(0, len(time_val_coords))
            pred_sample = pred_xr.isel(time=t_idx_random)
            true_sample = true_xr.isel(time=t_idx_random)
            sample_rmse_var = np.sqrt(((pred_sample - true_sample) ** 2).weighted(area_weights_vis).mean()).item()
            plot_comparison(true_sample, pred_sample, 
                            f"U-Net: {var_name.upper()} - Validation Sample (Timestep {t_idx_random})",
                            metric_val=sample_rmse_var, metric_name="RMSE")
        else:
            print(f"No time steps available in validation predictions for {var_name} to plot a random sample.")

except FileNotFoundError:
    print("val_preds.npy or val_trues.npy not found. "
          "Ensure that the training and validation loop (trainer.fit) has been run successfully, "
          "and the on_validation_epoch_end method in ClimateEmulationModule saved these files.")
except AttributeError as e:
    print(f"AttributeError during visualization: {e}. Ensure datamodule is correctly initialized and set up.")
    print("This might happen if 'datamodule' from the training cell is not in scope or wasn't fully set up.")
except Exception as e:
    print(f"An error occurred during visualization: {e}")

