In [2]:
# Cell 1: Imports
# Basic OS, date/time, and numerical libraries
import os
from datetime import datetime
import numpy as np

# Libraries for data handling and scientific computing
import xarray as xr
import dask.array as da # For handling large arrays that don't fit in memory

# PyTorch libraries for deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F # For functions like F.pad
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# PyTorch Lightning for streamlining training
import lightning.pytorch as pl

# Plotting library
import matplotlib.pyplot as plt

# Pandas for data manipulation (e.g., creating submission CSV)
import pandas as pd


In [3]:
# Cell 2: Configuration
# Main configuration dictionary for the experiment

# NOTE: You MUST change the 'path' in config['data'] to the correct location 
# of your 'processed_data_cse151b_v2_corrupted_ssp245.zarr' file.
config = {
    "data": {
        "path": "processed_data_cse151b_v2_corrupted_ssp245/processed_data_cse151b_v2_corrupted_ssp245.zarr",
        "input_vars": ["CO2", "SO2", "CH4", "BC", "rsdt"], # Input climate forcing variables
        "output_vars": ["tas", "pr"], # Target variables: surface air temperature and precipitation
        "target_member_id": 0, # Ensemble member to use for target variables
        "train_ssps": ["ssp126", "ssp370", "ssp585"], # SSP scenarios for training
        "test_ssp": "ssp245",  # SSP scenario for testing (held-out)
        "test_months": 360,   # Number of months for the test split (last 10 years)
        "batch_size": 64,     # Batch size for training and evaluation
        "num_workers": 4,     # Number of workers for DataLoader
    },
    "model_unet": { # Configuration specific to U-Net
        "type": "unet",
        "init_features": 64, # Number of features in the first convolutional layer of U-Net
        "bilinear": True,    # Whether to use bilinear upsampling in U-Net's decoder
    },
    "training": {
        "lr": 1e-3, # Learning rate
        # Add other training parameters like weight_decay if needed
    },
    "trainer": {
        "max_epochs": 20,         # Maximum number of training epochs
        "accelerator": "auto",    # Auto-detect accelerator (CPU, GPU, TPU)
        "devices": "auto",        # Auto-detect number of devices
        "precision": 32,          # Training precision (e.g., 16 for mixed-precision)
        "deterministic": True,    # For reproducibility
        "num_sanity_val_steps": 0,# Number of sanity check validation steps before training
        # "logger": True, # Example: WandbLogger(...) or TensorBoardLogger(...)
        # "callbacks": [] # Example: [ModelCheckpoint(...), EarlyStopping(...)]
    },
    "seed": 42, # Seed for reproducibility
}

# Set seed for PyTorch Lightning, PyTorch, NumPy, and Python's random module
pl.seed_everything(config["seed"], workers=True) 

# Suggestion from PyTorch for Tensor Core utilization on compatible GPUs
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7: # Check for Volta architecture or newer
    torch.set_float32_matmul_precision('medium') # or 'high'
    print("Set torch.set_float32_matmul_precision('medium') for Tensor Core utilization.")


Seed set to 42


In [4]:
# Cell 3: Latitude Weights Utility

def get_lat_weights(latitude_values):
    """
    Computes cosine-based area weights for each latitude.
    This accounts for the Earth's curvature, giving more weight to
    grid cells near the equator for global metrics.

    Args:
        latitude_values (np.array): Array of latitude values in degrees.

    Returns:
        np.array: Normalized latitude weights.
    """
    lat_rad = np.deg2rad(latitude_values) # Convert degrees to radians
    weights = np.cos(lat_rad)             # Cosine of latitude
    return weights / np.mean(weights)     # Normalize by the mean weight


In [5]:
# Cell 4: Normalizer Class

class Normalizer:
    """
    Handles Z-score normalization for input and output data.
    (data - mean) / std
    """
    def __init__(self):
        self.mean_in, self.std_in = None, None   # Statistics for input data
        self.mean_out, self.std_out = None, None # Statistics for output data

    def set_input_statistics(self, mean, std):
        """Sets the mean and standard deviation for input data."""
        self.mean_in = mean
        self.std_in = std

    def set_output_statistics(self, mean, std):
        """Sets the mean and standard deviation for output data."""
        self.mean_out = mean
        self.std_out = std

    def normalize(self, data, data_type):
        """
        Normalizes the data using pre-computed statistics.

        Args:
            data (np.array or dask.array): Data to normalize.
            data_type (str): "input" or "output", to use appropriate statistics.

        Returns:
            Normalized data.
        
        Raises:
            ValueError: If statistics for the specified data_type are not set.
        """
        if data_type == "input":
            if self.mean_in is None or self.std_in is None:
                raise ValueError("Input statistics not set in Normalizer.")
            # Add a small epsilon to std to prevent division by zero if std is very small or zero
            return (data - self.mean_in) / (self.std_in + 1e-8) 
        elif data_type == "output":
            if self.mean_out is None or self.std_out is None:
                raise ValueError("Output statistics not set in Normalizer.")
            return (data - self.mean_out) / (self.std_out + 1e-8)
        else:
            raise ValueError(f"Invalid data_type '{data_type}'. Must be 'input' or 'output'.")

    def inverse_transform_output(self, data):
        """
        Applies inverse normalization to output data (predictions).

        Args:
            data (torch.Tensor or np.array): Normalized output data.

        Returns:
            Data in original physical units.

        Raises:
            ValueError: If output statistics are not set.
        """
        if self.mean_out is None or self.std_out is None:
            raise ValueError("Output statistics not set in Normalizer for inverse transform.")
        return data * (self.std_out + 1e-8) + self.mean_out


In [6]:
# Cell 5: U-Net Model Architecture
# (DoubleConv, Down, Up, OutConv, UNet classes)

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels_x1, in_channels_x2, out_channels, bilinear=True):
        super().__init__()
        # in_channels_x1: channels of the feature map from the upsampling path (lower layer in decoder)
        # in_channels_x2: channels of the feature map from the skip connection (encoder)
        # out_channels: channels produced by this Up block

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # After upsampling x1, its channel count (in_channels_x1) remains the same.
            # It's then concatenated with x2 (in_channels_x2).
            # So, the input to DoubleConv is (in_channels_x1 + in_channels_x2).
            self.conv = DoubleConv(in_channels_x1 + in_channels_x2, out_channels)
        else:
            # ConvTranspose2d halves the channels of x1 (in_channels_x1 -> in_channels_x1 // 2)
            self.up = nn.ConvTranspose2d(in_channels_x1, in_channels_x1 // 2, kernel_size=2, stride=2)
            # Input to DoubleConv is (in_channels_x1 // 2 + in_channels_x2)
            self.conv = DoubleConv(in_channels_x1 // 2 + in_channels_x2, out_channels)


    def forward(self, x1, x2):
        # x1: feature map from upsampling path (e.g., from x_bottleneck or previous Up layer)
        # x2: feature map from skip connection (encoder path, e.g., x4_skip)
        x1 = self.up(x1)
        
        # Pad x1 to match x2's spatial dimensions if they differ
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        
        x = torch.cat([x2, x1], dim=1) # Concatenate along channel dimension
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    def __init__(self, n_input_channels, n_output_channels, bilinear=True, init_features=64):
        super(UNet, self).__init__()
        self.n_input_channels = n_input_channels
        self.n_output_channels = n_output_channels
        self.bilinear = bilinear
        self.init_features = init_features 
        
        f = init_features # Short alias

        # Encoder
        self.inc = DoubleConv(n_input_channels, f)
        self.down1 = Down(f, f * 2)
        self.down2 = Down(f * 2, f * 4)
        self.down3 = Down(f * 4, f * 8)
        self.down4 = Down(f * 8, f * 16) # Bottleneck input features

        # Decoder
        # Arguments for Up: (channels_from_lower_upsampled_layer, channels_from_skip_connection, output_channels_for_this_Up_block)
        self.up1 = Up(f * 16, f * 8,  f * 8, bilinear)
        self.up2 = Up(f * 8,  f * 4,  f * 4, bilinear)
        self.up3 = Up(f * 4,  f * 2,  f * 2, bilinear)
        self.up4 = Up(f * 2,  f,      f,     bilinear)
        
        self.outc = OutConv(f, n_output_channels)

    def forward(self, x):
        x1_skip = self.inc(x)           # -> f
        x2_skip = self.down1(x1_skip)   # -> f * 2
        x3_skip = self.down2(x2_skip)   # -> f * 4
        x4_skip = self.down3(x3_skip)   # -> f * 8
        x_bottleneck = self.down4(x4_skip) # -> f * 16

        # Decoder
        x = self.up1(x_bottleneck, x4_skip) # x_bottleneck (f*16), x4_skip (f*8). Up outputs f*8
        x = self.up2(x, x3_skip)            # x (f*8), x3_skip (f*4). Up outputs f*4
        x = self.up3(x, x2_skip)            # x (f*4), x2_skip (f*2). Up outputs f*2
        x = self.up4(x, x1_skip)            # x (f*2), x1_skip (f). Up outputs f
        
        logits = self.outc(x)
        return logits


In [7]:
# Cell 6: ClimateDataset and ClimateDataModule

class ClimateDataset(Dataset):
    def __init__(self, inputs_dask, outputs_dask, output_is_normalized=True):
        """
        PyTorch Dataset for climate data.

        Args:
            inputs_dask (dask.array): Dask array of input features.
            outputs_dask (dask.array): Dask array of output targets.
            output_is_normalized (bool): Flag indicating if outputs_dask is already normalized.
                                         Used for the test set where targets are not pre-normalized.
        """
        self.size = inputs_dask.shape[0]
        print(f"Creating dataset with {self.size} samples...")

        inputs_np = inputs_dask.compute()
        outputs_np = outputs_dask.compute()

        self.inputs = torch.from_numpy(inputs_np).float()
        self.outputs = torch.from_numpy(outputs_np).float()

        if torch.isnan(self.inputs).any():
            raise ValueError("NaNs found in input dataset after converting to tensor.")
        if torch.isnan(self.outputs).any():
            raise ValueError("NaNs found in output dataset after converting to tensor.")
        
        print(f"Dataset created. Input shape: {self.inputs.shape}, Output shape: {self.outputs.shape}")


    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.inputs[idx], self.outputs[idx]


class ClimateDataModule(pl.LightningDataModule):
    def __init__(
        self,
        path,
        input_vars,
        output_vars,
        train_ssps,
        test_ssp,
        target_member_id,
        test_months=120,
        batch_size=32,
        num_workers=0,
        seed=42, 
    ):
        super().__init__()
        self.save_hyperparameters() 
        self.normalizer = Normalizer()

    def prepare_data(self):
        if not os.path.exists(self.hparams.path):
            raise FileNotFoundError(f"Data path not found: {self.hparams.path}. Please check config['data']['path'].")

    def setup(self, stage=None):
        ds = xr.open_zarr(self.hparams.path, consolidated=False, chunks={"time": 24})
        
        # --- FIX for spatial_template ---
        # The 'rsdt' variable might not have 'member_id'. Handle this conditionally.
        rsdt_var_for_template = ds["rsdt"]
        if "member_id" in rsdt_var_for_template.dims:
            spatial_template = rsdt_var_for_template.isel(time=0, ssp=0, member_id=0, drop=True)
        else:
            # If 'member_id' is not present, select without it.
            # This assumes 'rsdt' is consistent across members or doesn't have that dimension.
            spatial_template = rsdt_var_for_template.isel(time=0, ssp=0, drop=True)
        # --- END FIX ---

        def load_ssp(ssp_name):
            input_dask_list, output_dask_list = [], []
            
            for var_name in self.hparams.input_vars:
                da_var = ds[var_name].sel(ssp=ssp_name)
                if "latitude" in da_var.dims: 
                    da_var = da_var.rename({"latitude": "y", "longitude": "x"})
                # For input variables, if member_id exists, select the target_member_id.
                # If it doesn't exist (e.g. for some forcing data), this sel will be a no-op if strict=False,
                # or we can check existence. Xarray's sel is usually robust if the dim doesn't exist.
                # However, to be safe, let's check.
                if "member_id" in da_var.dims:
                    da_var = da_var.sel(member_id=self.hparams.target_member_id)
                
                if set(da_var.dims) == {"time"}: 
                    da_var = da_var.broadcast_like(spatial_template).transpose("time", "y", "x")
                input_dask_list.append(da_var.data)
            
            for var_name in self.hparams.output_vars:
                # Output variables are always selected by target_member_id
                da_out = ds[var_name].sel(ssp=ssp_name, member_id=self.hparams.target_member_id)
                if "latitude" in da_out.dims: 
                    da_out = da_out.rename({"latitude": "y", "longitude": "x"})
                output_dask_list.append(da_out.data)

            return da.stack(input_dask_list, axis=1), da.stack(output_dask_list, axis=1)

        train_input_list, train_output_list = [], []
        val_input_ssp370, val_output_ssp370 = None, None

        for ssp in self.hparams.train_ssps:
            x_ssp, y_ssp = load_ssp(ssp)
            if ssp == "ssp370": 
                val_input_ssp370 = x_ssp[-self.hparams.test_months:]
                val_output_ssp370 = y_ssp[-self.hparams.test_months:]
                train_input_list.append(x_ssp[:-self.hparams.test_months])
                train_output_list.append(y_ssp[:-self.hparams.test_months])
            else:
                train_input_list.append(x_ssp)
                train_output_list.append(y_ssp)
        
        train_input_all_ssp = da.concatenate(train_input_list, axis=0)
        train_output_all_ssp = da.concatenate(train_output_list, axis=0)

        input_mean = da.nanmean(train_input_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        input_std = da.nanstd(train_input_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        self.normalizer.set_input_statistics(mean=input_mean, std=input_std)

        output_mean = da.nanmean(train_output_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        output_std = da.nanstd(train_output_all_ssp, axis=(0, 2, 3), keepdims=True).compute()
        self.normalizer.set_output_statistics(mean=output_mean, std=output_std)

        train_input_norm = self.normalizer.normalize(train_input_all_ssp, "input")
        train_output_norm = self.normalizer.normalize(train_output_all_ssp, "output")
        
        val_input_norm = self.normalizer.normalize(val_input_ssp370, "input")
        val_output_norm = self.normalizer.normalize(val_output_ssp370, "output")

        test_input_ssp, test_output_ssp = load_ssp(self.hparams.test_ssp)
        test_input_ssp = test_input_ssp[-self.hparams.test_months:] 
        test_output_ssp = test_output_ssp[-self.hparams.test_months:]
        test_input_norm = self.normalizer.normalize(test_input_ssp, "input")

        if stage == "fit" or stage is None:
            self.train_dataset = ClimateDataset(train_input_norm, train_output_norm)
            self.val_dataset = ClimateDataset(val_input_norm, val_output_norm)
        if stage == "test" or stage is None:
            self.test_dataset = ClimateDataset(test_input_norm, test_output_ssp, output_is_normalized=False)
        
        self.lat = spatial_template.y.values
        self.lon = spatial_template.x.values
        self.area_weights = xr.DataArray(get_lat_weights(self.lat), dims=["y"], coords={"y": self.lat})

        ds.close()

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True,
                          num_workers=self.hparams.num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=self.hparams.num_workers > 0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=self.hparams.num_workers > 0)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, shuffle=False,
                          num_workers=self.hparams.num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=self.hparams.num_workers > 0)

    def get_lat_weights(self):
        return self.area_weights

    def get_coords(self):
        return self.lat, self.lon


In [8]:
# Cell 7: ClimateEmulationModule (PyTorch Lightning)

class ClimateEmulationModule(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-4):
        super().__init__()
        self.model = model 
        self.save_hyperparameters(ignore=['model']) 
        
        self.criterion = nn.MSELoss() 
        self.normalizer = None 
        
        self.val_preds, self.val_targets = [], []
        self.test_preds, self.test_targets = [], []

    def forward(self, x):
        return self.model(x)

    def _get_normalizer_from_datamodule(self):
        """Helper to safely get normalizer from datamodule."""
        if hasattr(self.trainer, 'datamodule') and self.trainer.datamodule is not None and hasattr(self.trainer.datamodule, 'normalizer'):
            return self.trainer.datamodule.normalizer
        else:
            # Fallback if trainer.datamodule is not set up (e.g. direct call to test without fit)
            # This requires 'config' to be globally accessible or passed differently.
            print("Warning: Normalizer not found via self.trainer.datamodule. Attempting fallback initialization.")
            temp_dm = ClimateDataModule(**config["data"]) 
            temp_dm.prepare_data()
            temp_dm.setup(stage="test") # Or appropriate stage to ensure normalizer stats are computed
            return temp_dm.normalizer


    def on_fit_start(self):
        """Called at the beginning of training."""
        self.normalizer = self._get_normalizer_from_datamodule()

    def on_test_start(self):
        """Called at the beginning of testing."""
        if self.normalizer is None: # Ensure normalizer is available
            self.normalizer = self._get_normalizer_from_datamodule()


    def training_step(self, batch, batch_idx):
        x, y_norm = batch 
        y_hat_norm = self(x)   
        loss = self.criterion(y_hat_norm, y_norm)
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_norm = batch
        y_hat_norm = self(x)
        loss = self.criterion(y_hat_norm, y_norm)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()

        y_hat_denorm = self.normalizer.inverse_transform_output(y_hat_norm.detach().cpu().numpy())
        y_denorm = self.normalizer.inverse_transform_output(y_norm.detach().cpu().numpy())
        
        self.val_preds.append(y_hat_denorm)
        self.val_targets.append(y_denorm)
        return loss 

    def on_validation_epoch_end(self):
        if not self.trainer.sanity_checking: # Skip during sanity check
            if not self.val_preds or not self.val_targets: 
                return

            preds_epoch = np.concatenate(self.val_preds, axis=0)
            trues_epoch = np.concatenate(self.val_targets, axis=0)
            
            if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()
            
            self._evaluate(preds_epoch, trues_epoch, phase="val")
            
            np.save("val_preds.npy", preds_epoch)
            np.save("val_trues.npy", trues_epoch)
            
            self.val_preds.clear() 
            self.val_targets.clear()

    def test_step(self, batch, batch_idx):
        x, y_true_denorm = batch 
        y_hat_norm = self(x)    

        if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()
        
        y_hat_denorm = self.normalizer.inverse_transform_output(y_hat_norm.detach().cpu().numpy())
        
        self.test_preds.append(y_hat_denorm)
        self.test_targets.append(y_true_denorm.detach().cpu().numpy()) 

    def on_test_epoch_end(self):
        if not self.test_preds or not self.test_targets: 
            return

        preds_epoch = np.concatenate(self.test_preds, axis=0)
        trues_epoch = np.concatenate(self.test_targets, axis=0)

        if self.normalizer is None: self.normalizer = self._get_normalizer_from_datamodule()

        self._evaluate(preds_epoch, trues_epoch, phase="test")
        self._save_submission(preds_epoch) 
        
        self.test_preds.clear()
        self.test_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer

    def _evaluate(self, preds_np, trues_np, phase="val"):
        """Calculates and logs evaluation metrics."""
        # This check is important for when _evaluate might be called outside trainer.fit/test context
        # or if datamodule is not correctly propagated.
        if self.trainer.datamodule is None or not hasattr(self.trainer.datamodule, 'get_lat_weights'):
            print("Warning: self.trainer.datamodule not fully available in _evaluate. Using fallback for coords/weights.")
            dm_eval = ClimateDataModule(**config["data"]) # Re-init for coords, assumes global config
            dm_eval.prepare_data()
            dm_eval.setup(stage=phase) # Setup for the correct stage
            area_weights = dm_eval.get_lat_weights()
            lat, lon = dm_eval.get_coords()
            output_vars = dm_eval.hparams.output_vars
        else:
            area_weights = self.trainer.datamodule.get_lat_weights()
            lat, lon = self.trainer.datamodule.get_coords()
            output_vars = self.trainer.datamodule.hparams.output_vars


        time_coords = np.arange(preds_np.shape[0])
        metrics_summary = {}

        for i, var_name in enumerate(output_vars):
            p_var = preds_np[:, i] 
            t_var = trues_np[:, i] 
            
            p_xr = xr.DataArray(p_var, dims=["time", "y", "x"], coords={"time": time_coords, "y": lat, "x": lon})
            t_xr = xr.DataArray(t_var, dims=["time", "y", "x"], coords={"time": time_coords, "y": lat, "x": lon})

            rmse = np.sqrt(((p_xr - t_xr) ** 2).weighted(area_weights).mean()).item()
            mean_rmse = np.sqrt(((p_xr.mean("time") - t_xr.mean("time")) ** 2).weighted(area_weights).mean()).item()
            std_mae = np.abs(p_xr.std("time") - t_xr.std("time")).weighted(area_weights).mean().item()

            print(f"[{phase.upper()}] {var_name}: RMSE={rmse:.4f}, Time-Mean RMSE={mean_rmse:.4f}, Time-Stddev MAE={std_mae:.4f}")
            
            metrics_summary[f"{phase}/{var_name}/rmse"] = rmse
            metrics_summary[f"{phase}/{var_name}/time_mean_rmse"] = mean_rmse
            metrics_summary[f"{phase}/{var_name}/time_std_mae"] = std_mae
        
        self.log_dict(metrics_summary, logger=True)

    def _save_submission(self, predictions_np):
        """Saves model predictions to a CSV file in Kaggle submission format."""
        if self.trainer.datamodule is None or not hasattr(self.trainer.datamodule, 'get_coords'):
            print("Warning: self.trainer.datamodule not fully available in _save_submission. Using fallback.")
            dm_submission = ClimateDataModule(**config["data"])
            dm_submission.prepare_data()
            dm_submission.setup(stage="test") # Ensure coords are loaded
            lat, lon = dm_submission.get_coords()
            output_vars = dm_submission.hparams.output_vars
        else:
            lat, lon = self.trainer.datamodule.get_coords()
            output_vars = self.trainer.datamodule.hparams.output_vars
            
        time_coords_submission = np.arange(predictions_np.shape[0])

        rows = []
        for t_idx, t_val in enumerate(time_coords_submission):
            for var_idx, var_name in enumerate(output_vars):
                for y_idx, y_val in enumerate(lat):
                    for x_idx, x_val in enumerate(lon):
                        row_id = f"t{t_idx:03d}_{var_name}_{y_val:.2f}_{x_val:.2f}"
                        pred_value = predictions_np[t_idx, var_idx, y_idx, x_idx]
                        rows.append({"ID": row_id, "Prediction": pred_value})

        submission_df = pd.DataFrame(rows)
        submission_dir = "submissions"
        os.makedirs(submission_dir, exist_ok=True)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filepath = os.path.join(submission_dir, f"kaggle_submission_unet_{timestamp}.csv")
        submission_df.to_csv(filepath, index=False)
        print(f"✅ Submission saved to: {filepath}")


In [9]:
# Cell 8: Training and Evaluation Script

# --- Instantiate DataModule ---
datamodule = ClimateDataModule(**config["data"])
# datamodule.prepare_data() # Called by Trainer when .fit() or .test() is called
# datamodule.setup()      # Called by Trainer when .fit() or .test() is called

# --- Instantiate U-Net Model ---
n_inputs = len(config["data"]["input_vars"])
n_outputs = len(config["data"]["output_vars"])

unet_config_params = config.get("model_unet", {}) 
init_features = unet_config_params.get("init_features", 64)
bilinear_upsampling = unet_config_params.get("bilinear", True)

unet_model = UNet(n_input_channels=n_inputs, 
                  n_output_channels=n_outputs, 
                  init_features=init_features,
                  bilinear=bilinear_upsampling)

# --- Instantiate Lightning Module ---
learning_rate = config["training"]["lr"]
lightning_module = ClimateEmulationModule(unet_model, learning_rate=learning_rate)

# --- Setup Trainer ---
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_callback = ModelCheckpoint(
    monitor="val/loss", 
    mode="min",         
    filename="unet-best-{epoch:02d}-{val/loss:.2f}", 
    save_top_k=1,       
    verbose=True
)

early_stop_callback = EarlyStopping(
    monitor="val/loss",
    patience=5, 
    verbose=True,
    mode="min"
)

trainer_params = {**config["trainer"]} 
trainer_params["callbacks"] = [checkpoint_callback, early_stop_callback]
# Optional: Add logger
# from lightning.pytorch.loggers import TensorBoardLogger
# logger = TensorBoardLogger("tb_logs", name="unet_climate_emulation")
# trainer_params["logger"] = logger

trainer = pl.Trainer(**trainer_params)

# # --- Train the Model ---
# print("Starting U-Net model training...")
# trainer.fit(lightning_module, datamodule=datamodule)
# print("Training finished.")

# # --- Test the Model ---
# print("Starting U-Net model testing using the best checkpoint...")
# # trainer.test will use the checkpoint_callback's best_model_path by default if available
# # or you can specify ckpt_path="best"
# test_results = trainer.test(lightning_module, datamodule=datamodule, ckpt_path="best") 
# print("Testing finished.")
# print("Test Results:", test_results)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
# # Cell 9: Plotting Utilities (Optional)
# # Ensure matplotlib, numpy, and xarray are imported (usually in Cell 1)

# def plot_comparison(true_xr, pred_xr, title, cmap='viridis', diff_cmap='RdBu_r', metric_val=None, metric_name="Metric"):
#     """
#     Plots a comparison between ground truth, prediction, and their difference.
#     Includes calculation and display of a spatial mean metric (e.g., RMSE).
#     """
#     fig, axs = plt.subplots(1, 3, figsize=(18, 5)) 
#     fig.suptitle(title, fontsize=16) 

#     common_min = min(true_xr.min().item(), pred_xr.min().item())
#     common_max = max(true_xr.max().item(), pred_xr.max().item())

#     true_xr.plot(ax=axs[0], cmap=cmap, vmin=common_min, vmax=common_max, add_colorbar=True, cbar_kwargs={'label': true_xr.name or 'Value'})
#     axs[0].set_title(f"Ground Truth")

#     pred_xr.plot(ax=axs[1], cmap=cmap, vmin=common_min, vmax=common_max, add_colorbar=True, cbar_kwargs={'label': pred_xr.name or 'Value'})
#     axs[1].set_title(f"Prediction")

#     diff = pred_xr - true_xr
#     abs_max_diff = np.max(np.abs(diff.data)) if diff.size > 0 else 0.1 
    
#     diff_plot_params = {'cmap': diff_cmap, 'add_colorbar': True, 'cbar_kwargs': {'label': 'Difference'}}
#     if abs_max_diff > 0: 
#         diff_plot_params['vmin'] = -abs_max_diff
#         diff_plot_params['vmax'] = abs_max_diff
        
#     diff.plot(ax=axs[2], **diff_plot_params)
    
#     title_suffix = ""
#     if metric_val is not None:
#         title_suffix = f" ({metric_name}: {metric_val:.4f})"
#     axs[2].set_title(f"Difference (Pred - Truth){title_suffix}")

#     plt.tight_layout(rect=[0, 0, 1, 0.96]) 
#     plt.show()


In [11]:
# # Cell 10: Visualization Script (Optional)

# try:
#     val_preds_loaded = np.load("val_preds.npy")
#     val_trues_loaded = np.load("val_trues.npy")

#     if not hasattr(datamodule, 'lat') or datamodule.lat is None:
#         print("Datamodule not fully set up for visualization. Setting it up...")
#         # datamodule.prepare_data() # Should have been called by trainer
#         datamodule.setup(stage="fit") # Ensure lat, lon, etc. are available

#     lat, lon = datamodule.get_coords()
#     output_vars = config["data"]["output_vars"] 
#     area_weights_vis = datamodule.get_lat_weights() 
    
#     time_val_coords = np.arange(val_preds_loaded.shape[0])

#     print(f"\n--- Visualizing Validation Predictions for U-Net ---")
#     for i, var_name in enumerate(output_vars):
#         pred_xr = xr.DataArray(val_preds_loaded[:, i], dims=["time", "y", "x"], 
#                                coords={"time": time_val_coords, "y": lat, "x": lon}, name=var_name)
#         true_xr = xr.DataArray(val_trues_loaded[:, i], dims=["time", "y", "x"], 
#                                coords={"time": time_val_coords, "y": lat, "x": lon}, name=var_name)

#         pred_mean = pred_xr.mean("time")
#         true_mean = true_xr.mean("time")
#         mean_rmse_var = np.sqrt(((pred_mean - true_mean) ** 2).weighted(area_weights_vis).mean()).item()
#         plot_comparison(true_mean, pred_mean, 
#                         f"U-Net: {var_name.upper()} - Validation Time-Mean",
#                         metric_val=mean_rmse_var, metric_name="Time-Mean RMSE")

#         pred_std = pred_xr.std("time")
#         true_std = true_xr.std("time")
#         std_mae_var = np.abs(pred_std - true_std).weighted(area_weights_vis).mean().item()
#         plot_comparison(true_std, pred_std, 
#                         f"U-Net: {var_name.upper()} - Validation Time-StdDev", cmap="plasma",
#                         metric_val=std_mae_var, metric_name="Time-StdDev MAE")

#         if len(time_val_coords) > 0:
#             t_idx_random = np.random.randint(0, len(time_val_coords))
#             pred_sample = pred_xr.isel(time=t_idx_random)
#             true_sample = true_xr.isel(time=t_idx_random)
#             sample_rmse_var = np.sqrt(((pred_sample - true_sample) ** 2).weighted(area_weights_vis).mean()).item()
#             plot_comparison(true_sample, pred_sample, 
#                             f"U-Net: {var_name.upper()} - Validation Sample (Timestep {t_idx_random})",
#                             metric_val=sample_rmse_var, metric_name="RMSE")
#         else:
#             print(f"No time steps available in validation predictions for {var_name} to plot a random sample.")

# except FileNotFoundError:
#     print("val_preds.npy or val_trues.npy not found. "
#           "Ensure that the training and validation loop (trainer.fit) has been run successfully, "
#           "and the on_validation_epoch_end method in ClimateEmulationModule saved these files.")
# except AttributeError as e:
#     print(f"AttributeError during visualization: {e}. Ensure datamodule is correctly initialized and set up.")
#     print("This might happen if 'datamodule' from the training cell is not in scope or wasn't fully set up.")
# except Exception as e:
#     print(f"An error occurred during visualization: {e}")



In [16]:
# Cell 11: U-Net Fine-Tuning Loop

from copy import deepcopy
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, TQDMProgressBar

# hyperparameter_sets = [
#     {"unet_init_features": 32, "lr": 1e-4, "optimizer_type": "AdamW", "scheduler_type": "CosineAnnealingLR", "batch_size": 32, "max_epochs_ft": 5},
#     {"unet_init_features": 64, "lr": 1e-4, "optimizer_type": "AdamW", "scheduler_type": "CosineAnnealingLR", "batch_size": 32, "max_epochs_ft": 5},
#     {"unet_init_features": 32, "lr": 5e-4, "optimizer_type": "Adam",  "scheduler_type": None,              "batch_size": 64, "max_epochs_ft": 5},
#     {"unet_init_features": 64, "lr": 1e-3, "optimizer_type": "AdamW", "scheduler_type": "CosineAnnealingLR", "batch_size": 64, "max_epochs_ft": 8},
# ]

# best: {"unet_init_features": 32, "lr": 1e-4, "optimizer_type": "AdamW", "scheduler_type": "CosineAnnealingLR", "batch_size": 32, "max_epochs_ft": 5}

hyperparameter_sets = [
    # original four (gets 1)
    {"unet_init_features": 32,  "lr": 1e-4, "optimizer_type": "AdamW",   "scheduler_type": "CosineAnnealingLR", "batch_size": 32, "max_epochs_ft": 20},
    {"unet_init_features": 128, "lr": 1e-4, "optimizer_type": "AdamW",   "scheduler_type": "StepLR",             "batch_size": 32, "max_epochs_ft": 20},
    {"unet_init_features": 32,  "lr": 1e-3, "optimizer_type": "Adam",    "scheduler_type": "CosineAnnealingLR", "batch_size": 32, "max_epochs_ft": 20},
    {"unet_init_features": 64,  "lr": 5e-4, "optimizer_type": "Adam",    "scheduler_type": "StepLR",             "batch_size": 16, "max_epochs_ft": 20},
    {"unet_init_features": 128, "lr": 5e-4, "optimizer_type": "AdamW",   "scheduler_type": None,                "batch_size": 64, "max_epochs_ft": 20},
    {"unet_init_features": 32,  "lr": 2e-4, "optimizer_type": "AdamW",   "scheduler_type": "CosineAnnealingLR", "batch_size": 16, "max_epochs_ft": 20},
    {"unet_init_features": 128, "lr": 1e-3, "optimizer_type": "Adam",    "scheduler_type": None,                "batch_size": 32, "max_epochs_ft": 20},

    # four new combos with different optimizers
    {"unet_init_features": 32,  "lr": 1e-2, "optimizer_type": "SGD",     "scheduler_type": "StepLR",             "batch_size": 32, "max_epochs_ft": 20},
    {"unet_init_features": 64,  "lr": 1e-2, "optimizer_type": "SGD",     "scheduler_type": None,                "batch_size": 64, "max_epochs_ft": 20},
    {"unet_init_features": 32,  "lr": 5e-4, "optimizer_type": "RMSprop", "scheduler_type": "CosineAnnealingLR", "batch_size": 16, "max_epochs_ft": 20},
    {"unet_init_features": 64,  "lr": 1e-3, "optimizer_type": "Adagrad", "scheduler_type": "StepLR",             "batch_size": 32, "max_epochs_ft": 20},
]


fine_tuning_results = []

# Original config for data path and other fixed settings
base_config_data_path = config["data"]["path"] 
base_config_trainer_fixed = {k: v for k, v in config["trainer"].items() if k not in ["max_epochs", "callbacks", "logger", "default_root_dir"]}


for i, params in enumerate(hyperparameter_sets):
    print(f"\n--- Fine-Tuning Run {i+1}/{len(hyperparameter_sets)} ---")
    print(f"Parameters: {params}")

    # 1. Create a deep copy of the base config and update it
    current_config = deepcopy(config) # Start with the global config
    current_config["data"]["batch_size"] = params["batch_size"]
    current_config["model_unet"]["init_features"] = params["unet_init_features"]
    # We'll handle lr, optimizer, scheduler directly in LightningModule or Trainer setup
    
    # 2. Re-instantiate DataModule (important if batch_size changes)
    # The datamodule setup (normalization stats) should ideally be based on the full training set once.
    # For fine-tuning, if only batch_size changes, re-instantiating is okay.
    # If other data aspects change (like SSPs), ensure setup() is appropriate.
    datamodule_ft = ClimateDataModule(**current_config["data"])
    # datamodule_ft.prepare_data() # Called by Trainer
    # datamodule_ft.setup(stage="fit") # Called by Trainer

    # 3. Re-instantiate U-Net model
    n_inputs_ft = len(current_config["data"]["input_vars"])
    n_outputs_ft = len(current_config["data"]["output_vars"])
    
    unet_model_ft = UNet(
        n_input_channels=n_inputs_ft, 
        n_output_channels=n_outputs_ft, 
        init_features=current_config["model_unet"]["init_features"],
        bilinear=current_config["model_unet"].get("bilinear", True) # Use bilinear from config or default
    )

    # 4. Re-instantiate LightningModule
    # We will override configure_optimizers for this run
    lightning_module_ft = ClimateEmulationModule(
        unet_model_ft, 
        learning_rate=params["lr"] # Pass the current learning rate
    )

    # Override configure_optimizers for the current fine-tuning run
    def custom_configure_optimizers(self_lm):
        if params["optimizer_type"] == "AdamW":
            optimizer = optim.AdamW(self_lm.parameters(), lr=self_lm.hparams.learning_rate, weight_decay=0.01)
        else: # Default to Adam
            optimizer = optim.Adam(self_lm.parameters(), lr=self_lm.hparams.learning_rate)
        
        if params["scheduler_type"] == "CosineAnnealingLR":
            # T_max could be params["max_epochs_ft"] or total training steps
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params["max_epochs_ft"], eta_min=1e-6)
            return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}
        return optimizer

    # Bind the custom method to this instance of the LightningModule
    import types
    lightning_module_ft.configure_optimizers = types.MethodType(custom_configure_optimizers, lightning_module_ft)



    # 5. Re-instantiate Trainer
    # Callbacks for this fine-tuning run
    # Using a unique directory for each run's checkpoints/logs can be helpful
    run_specific_dir = f"ft_unet_run_{i+1}"
    
    ft_checkpoint_callback = ModelCheckpoint(
        monitor="val/loss",
        mode="min",
        dirpath=os.path.join("lightning_logs", run_specific_dir, "checkpoints"), # Save checkpoints in run-specific dir
        filename="best-unet-{epoch:02d}-{val/loss:.3f}",
        save_top_k=1,
        verbose=False # Less verbose for multiple runs
    )
    ft_early_stop_callback = EarlyStopping(
        monitor="val/loss",
        patience=3, # Shorter patience for fine-tuning runs
        verbose=False,
        mode="min"
    )

    ft_progress_bar = TQDMProgressBar(refresh_rate=10)

    trainer_ft_config = {
        **base_config_trainer_fixed, # Use fixed parts of trainer config
        "max_epochs": params["max_epochs_ft"],
        "callbacks": [ft_checkpoint_callback, ft_early_stop_callback, ft_progress_bar],
        "logger": pl.loggers.TensorBoardLogger("tb_logs", name=f"unet_ft_run_{i+1}"), # Log each run separately
        "default_root_dir": os.path.join("lightning_logs", run_specific_dir) # Root dir for this run
    }
    trainer_ft = pl.Trainer(**trainer_ft_config)

    # 6. Run training
    print(f"Fitting model for run {i+1} with params: {params}")
    trainer_ft.fit(lightning_module_ft, datamodule=datamodule_ft)
    
    # 7. Evaluate on the validation set using the best checkpoint from this run
    print(f"Evaluating model from run {i+1} on validation set...")
    # The test method here is used for validation set evaluation for fine-tuning
    # Ensure your ClimateEmulationModule's test_step and on_test_epoch_end
    # are suitable for this (e.g., they log "test/..." metrics which you'd interpret as "finetune_val/...")
    # Or, better, use trainer.validate if you only need validation metrics.
    # For simplicity, let's assume trainer.test() is okay and we look at its output.
    # The checkpoint_callback saves the best model based on "val/loss".
    
    val_results = trainer_ft.validate(lightning_module_ft, datamodule=datamodule_ft, ckpt_path="best")
    best_val_loss = val_results[0].get('val/loss', float('inf')) # Get the final validation loss

    # You might also want to run .test() on the actual test set if you want to see
    # how each hyperparameter set performs on the final test data, but typically
    # hyperparameter tuning is done based on validation set performance.
    # For now, we'll store validation results.
    
    current_run_results = {
        "params": params,
        "best_val_loss": best_val_loss,
        "val_metrics": val_results[0] # Store all metrics from validation
        # "test_metrics_on_val_split": test_on_val_results[0] # If you used .test on val_dataloader
    }
    fine_tuning_results.append(current_run_results)
    print(f"Run {i+1} Validation Loss: {best_val_loss:.4f}")


# 8. Analyze fine-tuning results
print("\n--- Fine-Tuning Complete ---")
best_run = None
best_loss = float('inf')

for result in fine_tuning_results:
    print(f"Params: {result['params']}, Best Val Loss: {result['best_val_loss']:.4f}")
    if result['best_val_loss'] < best_loss:
        best_loss = result['best_val_loss']
        best_run = result

if best_run:
    print(f"\nBest performing set of parameters:")
    print(f"Params: {best_run['params']}")
    print(f"Validation Loss: {best_run['best_val_loss']:.4f}")
    print(f"Full Validation Metrics: {best_run['val_metrics']}")
else:
    print("No fine-tuning runs completed or no results to analyze.")

# After identifying the best hyperparameters, you would typically retrain your model
# on the full training data (including the part of ssp370 used for validation here)
# for a larger number of epochs, and then evaluate on the final held-out test set (ssp245).



--- Fine-Tuning Run 1/11 ---
Parameters: {'unet_init_features': 32, 'lr': 0.0001, 'optimizer_type': 'AdamW', 'scheduler_type': 'CosineAnnealingLR', 'batch_size': 32, 'max_epochs_ft': 20}


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Fitting model for run 1 with params: {'unet_init_features': 32, 'lr': 0.0001, 'optimizer_type': 'AdamW', 'scheduler_type': 'CosineAnnealingLR', 'batch_size': 32, 'max_epochs_ft': 20}
Creating dataset with 2703 samples...
Dataset created. Input shape: torch.Size([2703, 5, 48, 72]), Output shape: torch.Size([2703, 2, 48, 72])
Creating dataset with 360 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | model     | UNet    | 7.9 M  | train
1 | criterion | MSELoss | 0      | train
----------------------------------------------
7.9 M     Trainable params
0         Non-trainable params
7.9 M     Total params
31.401    Total estimated model params size (MB)
96        Modules in train mode
0         Modules in eval mode


Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=6.3136, Time-Mean RMSE=4.8641, Time-Stddev MAE=1.6958
[VAL] pr: RMSE=2.8983, Time-Mean RMSE=1.1522, Time-Stddev MAE=1.6099


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.8153, Time-Mean RMSE=3.2095, Time-Stddev MAE=1.1445
[VAL] pr: RMSE=2.6841, Time-Mean RMSE=0.8863, Time-Stddev MAE=1.3645


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.8066, Time-Mean RMSE=2.7326, Time-Stddev MAE=0.9541
[VAL] pr: RMSE=2.4140, Time-Mean RMSE=0.7449, Time-Stddev MAE=1.1659


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.1148, Time-Mean RMSE=1.9308, Time-Stddev MAE=0.7901
[VAL] pr: RMSE=2.3256, Time-Mean RMSE=0.6328, Time-Stddev MAE=1.0233


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.8506, Time-Mean RMSE=1.7266, Time-Stddev MAE=0.7011
[VAL] pr: RMSE=2.2382, Time-Mean RMSE=0.6019, Time-Stddev MAE=0.8468


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.6291, Time-Mean RMSE=1.6324, Time-Stddev MAE=0.6565
[VAL] pr: RMSE=2.1203, Time-Mean RMSE=0.5392, Time-Stddev MAE=0.8495


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.4801, Time-Mean RMSE=1.5376, Time-Stddev MAE=0.6083
[VAL] pr: RMSE=2.1160, Time-Mean RMSE=0.5443, Time-Stddev MAE=0.9160


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.5046, Time-Mean RMSE=1.5660, Time-Stddev MAE=0.5711
[VAL] pr: RMSE=2.1241, Time-Mean RMSE=0.5000, Time-Stddev MAE=0.8379


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.4789, Time-Mean RMSE=1.3203, Time-Stddev MAE=0.6064
[VAL] pr: RMSE=2.1529, Time-Mean RMSE=0.4955, Time-Stddev MAE=0.8584


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2729, Time-Mean RMSE=1.3327, Time-Stddev MAE=0.5677
[VAL] pr: RMSE=2.0773, Time-Mean RMSE=0.4738, Time-Stddev MAE=0.8068


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2438, Time-Mean RMSE=1.3015, Time-Stddev MAE=0.5563
[VAL] pr: RMSE=2.0671, Time-Mean RMSE=0.4502, Time-Stddev MAE=0.8474


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2877, Time-Mean RMSE=1.2881, Time-Stddev MAE=0.5278
[VAL] pr: RMSE=2.0856, Time-Mean RMSE=0.4728, Time-Stddev MAE=0.8490


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2204, Time-Mean RMSE=1.2438, Time-Stddev MAE=0.5436
[VAL] pr: RMSE=2.0807, Time-Mean RMSE=0.4590, Time-Stddev MAE=0.8541


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1734, Time-Mean RMSE=1.2142, Time-Stddev MAE=0.5333
[VAL] pr: RMSE=2.0558, Time-Mean RMSE=0.4388, Time-Stddev MAE=0.8303


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1551, Time-Mean RMSE=1.1987, Time-Stddev MAE=0.5101
[VAL] pr: RMSE=2.0582, Time-Mean RMSE=0.4420, Time-Stddev MAE=0.8277


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1361, Time-Mean RMSE=1.1745, Time-Stddev MAE=0.5173
[VAL] pr: RMSE=2.0502, Time-Mean RMSE=0.4171, Time-Stddev MAE=0.8274


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1116, Time-Mean RMSE=1.1567, Time-Stddev MAE=0.5073
[VAL] pr: RMSE=2.0466, Time-Mean RMSE=0.4230, Time-Stddev MAE=0.8029


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1185, Time-Mean RMSE=1.1686, Time-Stddev MAE=0.4939
[VAL] pr: RMSE=2.0478, Time-Mean RMSE=0.4211, Time-Stddev MAE=0.8258


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0737, Time-Mean RMSE=1.1592, Time-Stddev MAE=0.4876
[VAL] pr: RMSE=2.0324, Time-Mean RMSE=0.4089, Time-Stddev MAE=0.8104


Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


[VAL] tas: RMSE=2.1229, Time-Mean RMSE=1.1734, Time-Stddev MAE=0.5044
[VAL] pr: RMSE=2.0435, Time-Mean RMSE=0.4139, Time-Stddev MAE=0.8151
Evaluating model from run 1 on validation set...


Restoring states from the checkpoint path at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_1/checkpoints/best-unet-epoch=18-val/loss=0.170.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_1/checkpoints/best-unet-epoch=18-val/loss=0.170.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0737, Time-Mean RMSE=1.1592, Time-Stddev MAE=0.4876
[VAL] pr: RMSE=2.0324, Time-Mean RMSE=0.4089, Time-Stddev MAE=0.8104
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val/loss            0.17025740444660187
       val/pr/rmse          2.0323846340179443
  val/pr/time_mean_rmse     0.4089001715183258
   val/pr/time_std_mae      0.8104011416435242
      val/tas/rmse          2.0737040042877197
 val/tas/time_mean_rmse     1.1592297554016113
  val/tas/time_std_mae      0.4876391291618347
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Run 1 Validation Loss: 0.1703

--- Fine-Tuning Run 2/11 ---
Parameters: {'unet_init_features': 128, 'lr': 0.0001, 'optimizer

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Fitting model for run 2 with params: {'unet_init_features': 128, 'lr': 0.0001, 'optimizer_type': 'AdamW', 'scheduler_type': 'StepLR', 'batch_size': 32, 'max_epochs_ft': 20}
Creating dataset with 2703 samples...
Dataset created. Input shape: torch.Size([2703, 5, 48, 72]), Output shape: torch.Size([2703, 2, 48, 72])
Creating dataset with 360 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])



  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | model     | UNet    | 125 M  | train
1 | criterion | MSELoss | 0      | train
----------------------------------------------
125 M     Trainable params
0         Non-trainable params
125 M     Total params
502.059   Total estimated model params size (MB)
96        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=5.2094, Time-Mean RMSE=3.7463, Time-Stddev MAE=1.2224
[VAL] pr: RMSE=2.7432, Time-Mean RMSE=0.8626, Time-Stddev MAE=1.5619


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.2537, Time-Mean RMSE=1.9809, Time-Stddev MAE=0.9290
[VAL] pr: RMSE=2.3885, Time-Mean RMSE=0.7687, Time-Stddev MAE=1.1608


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.9165, Time-Mean RMSE=1.8312, Time-Stddev MAE=0.8352
[VAL] pr: RMSE=2.1631, Time-Mean RMSE=0.6031, Time-Stddev MAE=0.9031


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2012, Time-Mean RMSE=1.3924, Time-Stddev MAE=0.5370
[VAL] pr: RMSE=2.0777, Time-Mean RMSE=0.5831, Time-Stddev MAE=0.8052


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0698, Time-Mean RMSE=1.2886, Time-Stddev MAE=0.5041
[VAL] pr: RMSE=2.0480, Time-Mean RMSE=0.5123, Time-Stddev MAE=0.8150


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2661, Time-Mean RMSE=1.5868, Time-Stddev MAE=0.4636
[VAL] pr: RMSE=2.0397, Time-Mean RMSE=0.4809, Time-Stddev MAE=0.8414


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2115, Time-Mean RMSE=1.5222, Time-Stddev MAE=0.5168
[VAL] pr: RMSE=2.0741, Time-Mean RMSE=0.5343, Time-Stddev MAE=0.8088


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0471, Time-Mean RMSE=1.3239, Time-Stddev MAE=0.5105
[VAL] pr: RMSE=2.0313, Time-Mean RMSE=0.4758, Time-Stddev MAE=0.8267


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7457, Time-Mean RMSE=0.9119, Time-Stddev MAE=0.3843
[VAL] pr: RMSE=1.9907, Time-Mean RMSE=0.3733, Time-Stddev MAE=0.7521


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8037, Time-Mean RMSE=1.0523, Time-Stddev MAE=0.4274
[VAL] pr: RMSE=2.0016, Time-Mean RMSE=0.4186, Time-Stddev MAE=0.7758


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8387, Time-Mean RMSE=1.0859, Time-Stddev MAE=0.4428
[VAL] pr: RMSE=1.9865, Time-Mean RMSE=0.3833, Time-Stddev MAE=0.7358


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6943, Time-Mean RMSE=0.8980, Time-Stddev MAE=0.3924
[VAL] pr: RMSE=2.0007, Time-Mean RMSE=0.4369, Time-Stddev MAE=0.7555
Evaluating model from run 2 on validation set...


Restoring states from the checkpoint path at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_2/checkpoints/best-unet-epoch=08-val/loss=0.161.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_2/checkpoints/best-unet-epoch=08-val/loss=0.161.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7457, Time-Mean RMSE=0.9119, Time-Stddev MAE=0.3843
[VAL] pr: RMSE=1.9907, Time-Mean RMSE=0.3733, Time-Stddev MAE=0.7521
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val/loss            0.16142010688781738
       val/pr/rmse           1.990682601928711
  val/pr/time_mean_rmse     0.3733293414115906
   val/pr/time_std_mae      0.7521287798881531
      val/tas/rmse          1.7456622123718262
 val/tas/time_mean_rmse      0.911907434463501
  val/tas/time_std_mae      0.38432568311691284
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Run 2 Validation Loss: 0.1614

--- Fine-Tuning Run 3/11 ---
Parameters: {'unet_init_features': 32, 'lr': 0.001, 'optimizer_type': 'Adam', 'scheduler_type': 'CosineAnnealingLR', 'batch_size': 32, 'max_epochs_ft': 20}
Fitting model for run 3 with params: {'unet_init_features': 32, 'lr': 0.001, 'optimizer_type': 'Adam', 'scheduler_type': 'CosineAnnealingLR', 'batch_size': 32, 'max_epochs_ft': 20}
Creating dataset with 2703 samples...
Dataset created. Input shape: torch.Size([2703, 5, 48, 72]), Output shape: torch.Size([2703, 2, 48, 72])
Creating dataset with 360 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | model     | UNet    | 7.9 M  | train
1 | criterion | MSELoss | 0      | train
----------------------------------------------
7.9 M     Trainable params
0         Non-trainable params
7.9 M     Total params
31.401    Total estimated model params size (MB)
96        Modules in train mode
0         Modules in eval mode


Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.4118, Time-Mean RMSE=2.7856, Time-Stddev MAE=1.4107
[VAL] pr: RMSE=2.7762, Time-Mean RMSE=0.9568, Time-Stddev MAE=1.5287


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.8275, Time-Mean RMSE=2.8249, Time-Stddev MAE=1.1180
[VAL] pr: RMSE=2.5182, Time-Mean RMSE=1.0303, Time-Stddev MAE=1.0598


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.6856, Time-Mean RMSE=1.8764, Time-Stddev MAE=0.7557
[VAL] pr: RMSE=2.2337, Time-Mean RMSE=0.7645, Time-Stddev MAE=0.8388


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.3852, Time-Mean RMSE=1.6536, Time-Stddev MAE=0.6358
[VAL] pr: RMSE=2.0996, Time-Mean RMSE=0.5976, Time-Stddev MAE=0.8060


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.4516, Time-Mean RMSE=1.6640, Time-Stddev MAE=0.6460
[VAL] pr: RMSE=2.1259, Time-Mean RMSE=0.6085, Time-Stddev MAE=0.8332


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1873, Time-Mean RMSE=1.4732, Time-Stddev MAE=0.5840
[VAL] pr: RMSE=2.0542, Time-Mean RMSE=0.4995, Time-Stddev MAE=0.7998


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1139, Time-Mean RMSE=1.4122, Time-Stddev MAE=0.5294
[VAL] pr: RMSE=2.0751, Time-Mean RMSE=0.5808, Time-Stddev MAE=0.8006


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0692, Time-Mean RMSE=1.3672, Time-Stddev MAE=0.5404
[VAL] pr: RMSE=2.0413, Time-Mean RMSE=0.5132, Time-Stddev MAE=0.7840


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.9657, Time-Mean RMSE=1.2290, Time-Stddev MAE=0.4867
[VAL] pr: RMSE=2.0211, Time-Mean RMSE=0.4669, Time-Stddev MAE=0.7402


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.9109, Time-Mean RMSE=1.1875, Time-Stddev MAE=0.4600
[VAL] pr: RMSE=2.0094, Time-Mean RMSE=0.4316, Time-Stddev MAE=0.7726


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.9786, Time-Mean RMSE=1.2560, Time-Stddev MAE=0.4993
[VAL] pr: RMSE=2.0314, Time-Mean RMSE=0.4880, Time-Stddev MAE=0.7706


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8528, Time-Mean RMSE=1.1237, Time-Stddev MAE=0.4485
[VAL] pr: RMSE=1.9957, Time-Mean RMSE=0.3896, Time-Stddev MAE=0.7308


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8240, Time-Mean RMSE=1.1053, Time-Stddev MAE=0.4238
[VAL] pr: RMSE=1.9812, Time-Mean RMSE=0.3663, Time-Stddev MAE=0.7559


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8249, Time-Mean RMSE=1.1094, Time-Stddev MAE=0.3973
[VAL] pr: RMSE=1.9959, Time-Mean RMSE=0.4091, Time-Stddev MAE=0.7439


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7819, Time-Mean RMSE=1.0602, Time-Stddev MAE=0.3890
[VAL] pr: RMSE=1.9752, Time-Mean RMSE=0.3428, Time-Stddev MAE=0.7472


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7793, Time-Mean RMSE=1.0531, Time-Stddev MAE=0.3960
[VAL] pr: RMSE=1.9828, Time-Mean RMSE=0.3748, Time-Stddev MAE=0.7603


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7193, Time-Mean RMSE=0.9899, Time-Stddev MAE=0.3851
[VAL] pr: RMSE=1.9698, Time-Mean RMSE=0.3470, Time-Stddev MAE=0.7629


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7226, Time-Mean RMSE=0.9821, Time-Stddev MAE=0.3720
[VAL] pr: RMSE=1.9693, Time-Mean RMSE=0.3399, Time-Stddev MAE=0.7453


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7206, Time-Mean RMSE=0.9953, Time-Stddev MAE=0.3673
[VAL] pr: RMSE=1.9638, Time-Mean RMSE=0.3261, Time-Stddev MAE=0.7559


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7206, Time-Mean RMSE=0.9930, Time-Stddev MAE=0.3697
[VAL] pr: RMSE=1.9630, Time-Mean RMSE=0.3227, Time-Stddev MAE=0.7485


`Trainer.fit` stopped: `max_epochs=20` reached.


Evaluating model from run 3 on validation set...


Restoring states from the checkpoint path at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_3/checkpoints/best-unet-epoch=19-val/loss=0.157.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_3/checkpoints/best-unet-epoch=19-val/loss=0.157.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7206, Time-Mean RMSE=0.9930, Time-Stddev MAE=0.3697
[VAL] pr: RMSE=1.9630, Time-Mean RMSE=0.3227, Time-Stddev MAE=0.7485
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val/loss            0.15701662003993988
       val/pr/rmse          1.9630497694015503
  val/pr/time_mean_rmse     0.32268384099006653
   val/pr/time_std_mae       0.748530387878418
      val/tas/rmse          1.7206227779388428
 val/tas/time_mean_rmse     0.9930409789085388
  val/tas/time_std_mae      0.3696863055229187
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Run 3 Validation Loss: 0.1570

--- Fine-Tuning Run 4/11 ---
Parameters: {'unet_init_features': 64, 'lr': 0.0005, 'optimizer

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Fitting model for run 4 with params: {'unet_init_features': 64, 'lr': 0.0005, 'optimizer_type': 'Adam', 'scheduler_type': 'StepLR', 'batch_size': 16, 'max_epochs_ft': 20}
Creating dataset with 2703 samples...
Dataset created. Input shape: torch.Size([2703, 5, 48, 72]), Output shape: torch.Size([2703, 2, 48, 72])
Creating dataset with 360 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | model     | UNet    | 31.4 M | train
1 | criterion | MSELoss | 0      | train
----------------------------------------------
31.4 M    Trainable params
0         Non-trainable params
31.4 M    Total params
125.544   Total estimated model params size (MB)
96        Modules in train mode
0         Modules in eval mode


Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.0178, Time-Mean RMSE=2.4745, Time-Stddev MAE=1.1351
[VAL] pr: RMSE=2.7169, Time-Mean RMSE=0.8944, Time-Stddev MAE=1.5494


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=3.7058, Time-Mean RMSE=2.2512, Time-Stddev MAE=1.5496
[VAL] pr: RMSE=2.4221, Time-Mean RMSE=0.9656, Time-Stddev MAE=1.1389


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2521, Time-Mean RMSE=1.4402, Time-Stddev MAE=0.5554
[VAL] pr: RMSE=2.1207, Time-Mean RMSE=0.5698, Time-Stddev MAE=0.8596


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.1615, Time-Mean RMSE=1.3547, Time-Stddev MAE=0.5437
[VAL] pr: RMSE=2.1026, Time-Mean RMSE=0.6106, Time-Stddev MAE=0.8403


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.0050, Time-Mean RMSE=1.2896, Time-Stddev MAE=0.4955
[VAL] pr: RMSE=2.0421, Time-Mean RMSE=0.4866, Time-Stddev MAE=0.8378


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.9077, Time-Mean RMSE=1.1355, Time-Stddev MAE=0.4668
[VAL] pr: RMSE=2.0328, Time-Mean RMSE=0.4873, Time-Stddev MAE=0.8052


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=2.2074, Time-Mean RMSE=1.4563, Time-Stddev MAE=0.6229
[VAL] pr: RMSE=2.0871, Time-Mean RMSE=0.5825, Time-Stddev MAE=0.8820


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7871, Time-Mean RMSE=1.0294, Time-Stddev MAE=0.3932
[VAL] pr: RMSE=1.9972, Time-Mean RMSE=0.4105, Time-Stddev MAE=0.7811


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8052, Time-Mean RMSE=1.0993, Time-Stddev MAE=0.4175
[VAL] pr: RMSE=2.0421, Time-Mean RMSE=0.5560, Time-Stddev MAE=0.8047


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.7744, Time-Mean RMSE=1.0581, Time-Stddev MAE=0.4013
[VAL] pr: RMSE=2.0562, Time-Mean RMSE=0.5648, Time-Stddev MAE=0.8048


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.8645, Time-Mean RMSE=1.1786, Time-Stddev MAE=0.3935
[VAL] pr: RMSE=1.9866, Time-Mean RMSE=0.3879, Time-Stddev MAE=0.7411


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6912, Time-Mean RMSE=0.9216, Time-Stddev MAE=0.3657
[VAL] pr: RMSE=1.9771, Time-Mean RMSE=0.3519, Time-Stddev MAE=0.7823


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6686, Time-Mean RMSE=0.9317, Time-Stddev MAE=0.3799
[VAL] pr: RMSE=1.9949, Time-Mean RMSE=0.4091, Time-Stddev MAE=0.8194


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6860, Time-Mean RMSE=0.9505, Time-Stddev MAE=0.3908
[VAL] pr: RMSE=2.0142, Time-Mean RMSE=0.4759, Time-Stddev MAE=0.7501


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6520, Time-Mean RMSE=0.9145, Time-Stddev MAE=0.3565
[VAL] pr: RMSE=1.9801, Time-Mean RMSE=0.3687, Time-Stddev MAE=0.7549
Evaluating model from run 4 on validation set...


Restoring states from the checkpoint path at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_4/checkpoints/best-unet-epoch=11-val/loss=0.159.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/ft_unet_run_4/checkpoints/best-unet-epoch=11-val/loss=0.159.ckpt


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=1.6912, Time-Mean RMSE=0.9216, Time-Stddev MAE=0.3657
[VAL] pr: RMSE=1.9771, Time-Mean RMSE=0.3519, Time-Stddev MAE=0.7823
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        val/loss            0.15948902070522308
       val/pr/rmse          1.9770545959472656
  val/pr/time_mean_rmse     0.35193729400634766
   val/pr/time_std_mae       0.782333254814148
      val/tas/rmse          1.6911522150039673
 val/tas/time_mean_rmse     0.9215661287307739
  val/tas/time_std_mae      0.3656538128852844
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Run 4 Validation Loss: 0.1595

--- Fine-Tuning Run 5/11 ---
Parameters: {'unet_init_features': 128, 'lr': 0.0005, 'optimize

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Fitting model for run 5 with params: {'unet_init_features': 128, 'lr': 0.0005, 'optimizer_type': 'AdamW', 'scheduler_type': None, 'batch_size': 64, 'max_epochs_ft': 20}
Creating dataset with 2703 samples...
Dataset created. Input shape: torch.Size([2703, 5, 48, 72]), Output shape: torch.Size([2703, 2, 48, 72])
Creating dataset with 360 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | model     | UNet    | 125 M  | train
1 | criterion | MSELoss | 0      | train
----------------------------------------------
125 M     Trainable params
0         Non-trainable params
125 M     Total params
502.059   Total estimated model params size (MB)
96        Modules in train mode
0         Modules in eval mode


Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])


/home/ruz039/.local/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (43) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.7638, Time-Mean RMSE=3.1174, Time-Stddev MAE=1.5548
[VAL] pr: RMSE=3.0080, Time-Mean RMSE=1.3159, Time-Stddev MAE=1.7295


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.4820, Time-Mean RMSE=2.9405, Time-Stddev MAE=1.2795
[VAL] pr: RMSE=3.0581, Time-Mean RMSE=1.5464, Time-Stddev MAE=1.6282


Validation: |          | 0/? [00:00<?, ?it/s]

[VAL] tas: RMSE=4.1962, Time-Mean RMSE=2.4267, Time-Stddev MAE=1.0823
[VAL] pr: RMSE=2.8195, Time-Mean RMSE=1.0269, Time-Stddev MAE=1.5137


OSError: [Errno 122] Disk quota exceeded

best_params_unet: {'unet_init_features': 64,
 'lr': 0.0001,
 'optimizer_type': 'AdamW',
 'scheduler_type': 'CosineAnnealingLR',
 'batch_size': 32,
 'max_epochs_ft': 5}

{'unet_init_features': 64,
 'lr': 0.0001,
 'optimizer_type': 'AdamW',
 'scheduler_type': 'CosineAnnealingLR',
 'batch_size': 32,
 'max_epochs_ft': 5}


--- Starting U-Net Final Training with Best Hyperparameters ---
Using best U-Net parameters: {'unet_init_features': 64, 'lr': 0.0001, 'optimizer_type': 'AdamW', 'scheduler_type': 'CosineAnnealingLR', 'batch_size': 32, 'max_epochs_ft': 5}


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Fitting final U-Net model with best params: {'unet_init_features': 64, 'lr': 0.0001, 'optimizer_type': 'AdamW', 'scheduler_type': 'CosineAnnealingLR', 'batch_size': 32, 'max_epochs_ft': 5}
Creating dataset with 2703 samples...
Dataset created. Input shape: torch.Size([2703, 5, 48, 72]), Output shape: torch.Size([2703, 2, 48, 72])
Creating dataset with 360 samples...


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params | Mode 
----------------------------------------------
0 | model     | UNet    | 31.4 M | train
1 | criterion | MSELoss | 0      | train
----------------------------------------------
31.4 M    Trainable params
0         Non-trainable params
31.4 M    Total params
125.544   Total estimated model params size (MB)
96        Modules in train mode
0         Modules in eval mode


Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved. New best score: 0.354
Epoch 0, global step 85: 'val/loss' reached 0.35412 (best 0.35412), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=00-val/loss=0.354.ckpt' as top 1


[VAL] tas: RMSE=4.6756, Time-Mean RMSE=2.9997, Time-Stddev MAE=1.4171
[VAL] pr: RMSE=2.8718, Time-Mean RMSE=1.1280, Time-Stddev MAE=1.6364


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.104 >= min_delta = 0.0. New best score: 0.250
Epoch 1, global step 170: 'val/loss' reached 0.25040 (best 0.25040), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=01-val/loss=0.250.ckpt' as top 1


[VAL] tas: RMSE=3.3871, Time-Mean RMSE=2.0922, Time-Stddev MAE=0.8684
[VAL] pr: RMSE=2.4418, Time-Mean RMSE=0.7451, Time-Stddev MAE=1.1813


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.032 >= min_delta = 0.0. New best score: 0.218
Epoch 2, global step 255: 'val/loss' reached 0.21838 (best 0.21838), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=02-val/loss=0.218.ckpt' as top 1


[VAL] tas: RMSE=3.1998, Time-Mean RMSE=2.1231, Time-Stddev MAE=0.8184
[VAL] pr: RMSE=2.2566, Time-Mean RMSE=0.6638, Time-Stddev MAE=0.9735


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.039 >= min_delta = 0.0. New best score: 0.179
Epoch 3, global step 340: 'val/loss' reached 0.17946 (best 0.17946), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=03-val/loss=0.179.ckpt' as top 1


[VAL] tas: RMSE=2.3045, Time-Mean RMSE=1.5245, Time-Stddev MAE=0.5885
[VAL] pr: RMSE=2.0808, Time-Mean RMSE=0.5461, Time-Stddev MAE=0.8473


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.004 >= min_delta = 0.0. New best score: 0.175
Epoch 4, global step 425: 'val/loss' reached 0.17514 (best 0.17514), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=04-val/loss=0.175.ckpt' as top 1


[VAL] tas: RMSE=2.2324, Time-Mean RMSE=1.4714, Time-Stddev MAE=0.5525
[VAL] pr: RMSE=2.0589, Time-Mean RMSE=0.5192, Time-Stddev MAE=0.8460


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.000 >= min_delta = 0.0. New best score: 0.175
Epoch 5, global step 510: 'val/loss' reached 0.17507 (best 0.17507), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=05-val/loss=0.175.ckpt' as top 1


[VAL] tas: RMSE=2.1329, Time-Mean RMSE=1.3450, Time-Stddev MAE=0.5243
[VAL] pr: RMSE=2.0633, Time-Mean RMSE=0.5476, Time-Stddev MAE=0.8276


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.003 >= min_delta = 0.0. New best score: 0.172
Epoch 6, global step 595: 'val/loss' reached 0.17226 (best 0.17226), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=06-val/loss=0.172.ckpt' as top 1


[VAL] tas: RMSE=2.1397, Time-Mean RMSE=1.3708, Time-Stddev MAE=0.5376
[VAL] pr: RMSE=2.0468, Time-Mean RMSE=0.4955, Time-Stddev MAE=0.8550


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.005 >= min_delta = 0.0. New best score: 0.167
Epoch 7, global step 680: 'val/loss' reached 0.16720 (best 0.16720), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=07-val/loss=0.167.ckpt' as top 1


[VAL] tas: RMSE=1.9983, Time-Mean RMSE=1.2085, Time-Stddev MAE=0.4942
[VAL] pr: RMSE=2.0186, Time-Mean RMSE=0.4482, Time-Stddev MAE=0.8197


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.001 >= min_delta = 0.0. New best score: 0.166
Epoch 8, global step 765: 'val/loss' reached 0.16615 (best 0.16615), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=08-val/loss=0.166.ckpt' as top 1


[VAL] tas: RMSE=1.9510, Time-Mean RMSE=1.1619, Time-Stddev MAE=0.4505
[VAL] pr: RMSE=2.0146, Time-Mean RMSE=0.4161, Time-Stddev MAE=0.8049


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.002 >= min_delta = 0.0. New best score: 0.164
Epoch 9, global step 850: 'val/loss' reached 0.16437 (best 0.16437), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=09-val/loss=0.164.ckpt' as top 1


[VAL] tas: RMSE=1.9372, Time-Mean RMSE=1.1764, Time-Stddev MAE=0.4260
[VAL] pr: RMSE=2.0040, Time-Mean RMSE=0.4110, Time-Stddev MAE=0.8258


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 935: 'val/loss' was not in top 1


[VAL] tas: RMSE=1.8836, Time-Mean RMSE=1.0863, Time-Stddev MAE=0.4396
[VAL] pr: RMSE=2.0131, Time-Mean RMSE=0.4628, Time-Stddev MAE=0.8137


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.002 >= min_delta = 0.0. New best score: 0.163
Epoch 11, global step 1020: 'val/loss' reached 0.16282 (best 0.16282), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=11-val/loss=0.163.ckpt' as top 1


[VAL] tas: RMSE=1.8246, Time-Mean RMSE=1.0323, Time-Stddev MAE=0.4154
[VAL] pr: RMSE=1.9970, Time-Mean RMSE=0.4027, Time-Stddev MAE=0.7839


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.001 >= min_delta = 0.0. New best score: 0.162
Epoch 12, global step 1105: 'val/loss' reached 0.16156 (best 0.16156), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=12-val/loss=0.162.ckpt' as top 1


[VAL] tas: RMSE=1.8301, Time-Mean RMSE=1.0540, Time-Stddev MAE=0.4153
[VAL] pr: RMSE=1.9882, Time-Mean RMSE=0.3885, Time-Stddev MAE=0.7937


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.000 >= min_delta = 0.0. New best score: 0.161
Epoch 13, global step 1190: 'val/loss' reached 0.16123 (best 0.16123), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=13-val/loss=0.161.ckpt' as top 1


[VAL] tas: RMSE=1.7929, Time-Mean RMSE=1.0121, Time-Stddev MAE=0.4022
[VAL] pr: RMSE=1.9881, Time-Mean RMSE=0.3910, Time-Stddev MAE=0.7936


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.000 >= min_delta = 0.0. New best score: 0.161
Epoch 14, global step 1275: 'val/loss' reached 0.16096 (best 0.16096), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=14-val/loss=0.161.ckpt' as top 1


[VAL] tas: RMSE=1.7799, Time-Mean RMSE=0.9767, Time-Stddev MAE=0.4033
[VAL] pr: RMSE=1.9862, Time-Mean RMSE=0.3856, Time-Stddev MAE=0.8032


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.001 >= min_delta = 0.0. New best score: 0.160
Epoch 15, global step 1360: 'val/loss' reached 0.16038 (best 0.16038), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=15-val/loss=0.160.ckpt' as top 1


[VAL] tas: RMSE=1.7839, Time-Mean RMSE=0.9797, Time-Stddev MAE=0.4022
[VAL] pr: RMSE=1.9827, Time-Mean RMSE=0.3641, Time-Stddev MAE=0.8098


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.001 >= min_delta = 0.0. New best score: 0.159
Epoch 16, global step 1445: 'val/loss' reached 0.15911 (best 0.15911), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=16-val/loss=0.159.ckpt' as top 1


[VAL] tas: RMSE=1.7470, Time-Mean RMSE=0.9589, Time-Stddev MAE=0.3870
[VAL] pr: RMSE=1.9752, Time-Mean RMSE=0.3516, Time-Stddev MAE=0.7673


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 1530: 'val/loss' was not in top 1


[VAL] tas: RMSE=1.7566, Time-Mean RMSE=0.9641, Time-Stddev MAE=0.3989
[VAL] pr: RMSE=1.9796, Time-Mean RMSE=0.3682, Time-Stddev MAE=0.7846


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.000 >= min_delta = 0.0. New best score: 0.159
Epoch 18, global step 1615: 'val/loss' reached 0.15865 (best 0.15865), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=18-val/loss=0.159.ckpt' as top 1


[VAL] tas: RMSE=1.7433, Time-Mean RMSE=0.9430, Time-Stddev MAE=0.3912
[VAL] pr: RMSE=1.9724, Time-Mean RMSE=0.3459, Time-Stddev MAE=0.7824


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val/loss improved by 0.000 >= min_delta = 0.0. New best score: 0.158
Epoch 19, global step 1700: 'val/loss' reached 0.15841 (best 0.15841), saving model to '/home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=19-val/loss=0.158.ckpt' as top 1


[VAL] tas: RMSE=1.7502, Time-Mean RMSE=0.9523, Time-Stddev MAE=0.3925
[VAL] pr: RMSE=1.9702, Time-Mean RMSE=0.3376, Time-Stddev MAE=0.7781


`Trainer.fit` stopped: `max_epochs=20` reached.


Testing final U-Net model on test set (ssp245)...
Creating dataset with 360 samples...


Restoring states from the checkpoint path at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=19-val/loss=0.158.ckpt


Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/ruz039/private/cse151b/DL_for_Climate_Emulation/lightning_logs/final_unet_run/checkpoints/final-best-unet-epoch=19-val/loss=0.158.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

[TEST] tas: RMSE=290.5800, Time-Mean RMSE=290.5367, Time-Stddev MAE=3.7093
[TEST] pr: RMSE=4.3301, Time-Mean RMSE=3.8674, Time-Stddev MAE=1.3886
✅ Submission saved to: submissions/kaggle_submission_unet_20250522_193549.csv
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test/pr/rmse           4.330072402954102
 test/pr/time_mean_rmse     3.8674283027648926
  test/pr/time_std_mae      1.3885520696640015
      test/tas/rmse          290.5799865722656
 test/tas/time_mean_rmse     290.5366516113281
  test/tas/time_std_mae     3.7092957496643066
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Final U-Net Test Results (on ssp245): [{'test/tas/rmse': 290.5799865722656, 'test/tas/ti

In [17]:
# Cell 13: U-Net - Final Visualization
# This cell loads the validation predictions saved during the FINAL U-Net training run and plots them.

print(f"\n--- Visualizing Validation Predictions for Final U-Net ---")
try:
    # Load the saved predictions and true values from the final U-Net run
    val_preds_loaded_final_unet = np.load("val_preds.npy") 
    val_trues_loaded_final_unet = np.load("val_trues.npy")

    # Ensure datamodule_final_unet is available from the training cell (Cell 18)
    if 'datamodule_final_unet' not in globals() or not hasattr(datamodule_final_unet, 'lat') or datamodule_final_unet.lat is None:
        print("datamodule_final_unet not fully set up for visualization. Setting it up...")
        datamodule_final_unet_viz = ClimateDataModule(**config["data"]) # Use global config for safety
        datamodule_final_unet_viz.setup(stage="fit") 
        lat, lon = datamodule_final_unet_viz.get_coords()
        output_vars = config["data"]["output_vars"] 
        area_weights_vis = datamodule_final_unet_viz.get_lat_weights()
    else:
        lat, lon = datamodule_final_unet.get_coords()
        output_vars = config["data"]["output_vars"] 
        area_weights_vis = datamodule_final_unet.get_lat_weights() 
    
    time_val_coords = np.arange(val_preds_loaded_final_unet.shape[0])

    for i, var_name in enumerate(output_vars):
        pred_xr_final_unet = xr.DataArray(val_preds_loaded_final_unet[:, i], dims=["time", "y", "x"], 
                                   coords={"time": time_val_coords, "y": lat, "x": lon}, name=var_name)
        true_xr_final_unet = xr.DataArray(val_trues_loaded_final_unet[:, i], dims=["time", "y", "x"], 
                                   coords={"time": time_val_coords, "y": lat, "x": lon}, name=var_name)

        pred_mean_final_unet = pred_xr_final_unet.mean("time")
        true_mean_final_unet = true_xr_final_unet.mean("time")
        mean_rmse_var_final_unet = np.sqrt(((pred_mean_final_unet - true_mean_final_unet) ** 2).weighted(area_weights_vis).mean()).item()
        plot_comparison(true_mean_final_unet, pred_mean_final_unet, 
                        f"Final U-Net: {var_name.upper()} - Validation Time-Mean",
                        metric_val=mean_rmse_var_final_unet, metric_name="Time-Mean RMSE")

        pred_std_final_unet = pred_xr_final_unet.std("time")
        true_std_final_unet = true_xr_final_unet.std("time")
        std_mae_var_final_unet = np.abs(pred_std_final_unet - true_std_final_unet).weighted(area_weights_vis).mean().item()
        plot_comparison(true_std_final_unet, pred_std_final_unet, 
                        f"Final U-Net: {var_name.upper()} - Validation Time-StdDev", cmap="plasma",
                        metric_val=std_mae_var_final_unet, metric_name="Time-StdDev MAE")

        if len(time_val_coords) > 0:
            t_idx_random_final_unet = np.random.randint(0, len(time_val_coords))
            pred_sample_final_unet = pred_xr_final_unet.isel(time=t_idx_random_final_unet)
            true_sample_final_unet = true_xr_final_unet.isel(time=t_idx_random_final_unet)
            sample_rmse_var_final_unet = np.sqrt(((pred_sample_final_unet - true_sample_final_unet) ** 2).weighted(area_weights_vis).mean()).item()
            plot_comparison(true_sample_final_unet, pred_sample_final_unet, 
                            f"Final U-Net: {var_name.upper()} - Validation Sample (Timestep {t_idx_random_final_unet})",
                            metric_val=sample_rmse_var_final_unet, metric_name="RMSE")
        else:
            print(f"No time steps available in validation predictions for {var_name} to plot a random sample.")

except NameError as e:
    print(f"NameError during U-Net final visualization: {e}. Ensure 'best_run_unet' was defined and the final training cell (Cell 18) ran successfully.")
except FileNotFoundError:
    print("val_preds.npy or val_trues.npy not found for final U-Net run. "
          "Ensure that the final U-Net training and validation (Cell 18) has run successfully.")
except AttributeError as e:
    print(f"AttributeError during final U-Net visualization: {e}. Ensure datamodule_final_unet is correctly initialized.")
except Exception as e:
    print(f"An error occurred during final U-Net visualization: {e}")



--- Visualizing Validation Predictions for Final U-Net ---
datamodule_final_unet not fully set up for visualization. Setting it up...
Creating dataset with 2703 samples...
Dataset created. Input shape: torch.Size([2703, 5, 48, 72]), Output shape: torch.Size([2703, 2, 48, 72])
Creating dataset with 360 samples...
Dataset created. Input shape: torch.Size([360, 5, 48, 72]), Output shape: torch.Size([360, 2, 48, 72])
NameError during U-Net final visualization: name 'plot_comparison' is not defined. Ensure 'best_run_unet' was defined and the final training cell (Cell 18) ran successfully.
