From 2a205fc83912cfb75b289f949f9f0a36cb737a87 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 13 Apr 2026 12:39:19 +0200 Subject: [PATCH 1/5] calculate stats on train/test split --- climanet/dataset.py | 50 ++++++++++++++++++++++++++++++++++++++++----- climanet/train.py | 9 ++++++-- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/climanet/dataset.py b/climanet/dataset.py index 7976297..f14b8d7 100644 --- a/climanet/dataset.py +++ b/climanet/dataset.py @@ -54,9 +54,6 @@ def __init__( self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy() self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy() - # Store the stats of the daily data before filling NaNs - self.daily_mean, self.daily_std = calc_stats(self.daily_np) - if land_mask is not None: lm = land_mask.to_numpy().copy() if lm.ndim == 3: @@ -69,8 +66,11 @@ def __init__( # daily_mask: True where NaN (i.e. missing ocean data, not land) self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W) - # Fill NaNs with 0 in-place - np.nan_to_num(self.daily_np, copy=False, nan=0.0) + # Stats will be set later via set_stats() and NaNs will be filled with 0 in-place + self.daily_mean = None + self.daily_std = None + self._nans_filled = False + self._warned = False # Precompute padded_days_mask as a tensor (same for all patches) self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool() @@ -111,6 +111,13 @@ def __len__(self): def __getitem__(self, idx): """Get a spatiotemporal patch sample based on the index.""" + if not self._nans_filled and not self._warned: + warnings.warn( + "NaNs have not been replaced. Call compute_stats() before using the dataset.", + UserWarning + ) + self._warned = True + if idx < 0 or idx >= len(self.patch_indices): raise IndexError("Index out of range") @@ -159,3 +166,36 @@ def __getitem__(self, idx): "lat_patch": lat_patch, # (H,) "lon_patch": lon_patch, # (W,) } + + + def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]: + """Compute mean and std from specified indices (or all data if None). + + Args: + indices: List of patch indices to compute stats from. If None, use all. + + Returns: + Tuple of (mean, std) arrays + """ + if indices is None: + data = self.daily_np # (M, T, H, W) + else: + # Stack selected spatial patches + ph, pw = self.patch_size + patches = [] + for idx in indices: + i, j = self.patch_indices[idx] + patch = self.daily_np[:, :, i:i+ph, j:j+pw] # (M, T, ph, pw) + patches.append(patch) + data = np.concatenate(patches, axis=-1) # (M, T, H, W_total) + + mean, std = calc_stats(data) # (M,) + + self.daily_mean = mean + self.daily_std = std + + if not self._nans_filled: + np.nan_to_num(self.daily_np, copy=False, nan=0.0) + self._nans_filled = True + + return mean, std diff --git a/climanet/train.py b/climanet/train.py index 4344d12..b656a8b 100644 --- a/climanet/train.py +++ b/climanet/train.py @@ -67,12 +67,17 @@ def train_monthly_model( verbose: whether to print training progress """ + # check if dataset has indices attribute for stats calculation + base_dataset = dataset.dataset if hasattr(dataset, 'dataset') else dataset + indices = dataset.indices if hasattr(dataset, 'indices') else None + mean, std = base_dataset.compute_stats(indices) + # Initialize the model model = model.to(device) decoder = model.decoder with torch.no_grad(): - decoder.bias.copy_(torch.from_numpy(dataset.daily_mean)) - decoder.scale.copy_(torch.from_numpy(dataset.daily_std) + 1e-6) + decoder.bias.copy_(torch.from_numpy(mean)) + decoder.scale.copy_(torch.from_numpy(std) + 1e-6) # Create data loader dataloader = DataLoader( From ceda34b7c15954681b36fa9928d35db58beb21f9 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 13 Apr 2026 14:00:49 +0200 Subject: [PATCH 2/5] fix predict for train/test split --- climanet/dataset.py | 13 +++++++++---- climanet/predict.py | 21 ++++++++++++++------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/climanet/dataset.py b/climanet/dataset.py index f14b8d7..f845c8c 100644 --- a/climanet/dataset.py +++ b/climanet/dataset.py @@ -113,7 +113,7 @@ def __getitem__(self, idx): """Get a spatiotemporal patch sample based on the index.""" if not self._nans_filled and not self._warned: warnings.warn( - "NaNs have not been replaced. Call compute_stats() before using the dataset.", + "NaNs have not been replaced. Call fill_nans_with_zero() before using the dataset.", UserWarning ) self._warned = True @@ -194,8 +194,13 @@ def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]: self.daily_mean = mean self.daily_std = std - if not self._nans_filled: - np.nan_to_num(self.daily_np, copy=False, nan=0.0) - self._nans_filled = True + # Fill NaNs with 0 in-place after stats are computed + self.fill_nans_with_zero() return mean, std + + def fill_nans_with_zero(self): + """Fill NaN values in daily_np with zero in-place.""" + if not self._nans_filled: + np.nan_to_num(self.daily_np, copy=False, nan=0.0) + self._nans_filled = True \ No newline at end of file diff --git a/climanet/predict.py b/climanet/predict.py index bb716d9..ea3b62e 100644 --- a/climanet/predict.py +++ b/climanet/predict.py @@ -19,12 +19,16 @@ def _save_netcdf(predictions: np.ndarray, dataset: Dataset, save_dir: str): """Helper function to convert predictions to xarray and save as netCDF.""" B, M, H, W = predictions.shape - lats = dataset.monthly_da.coords["lat"].values - lons = dataset.monthly_da.coords["lon"].values - times = dataset.monthly_da.coords["time"].values + base_dataset = dataset.dataset if hasattr(dataset, 'dataset') else dataset + indices = dataset.indices if hasattr(dataset, 'indices') else range(len(dataset)) - full_predictions = np.empty((M, len(lats), len(lons)), dtype=predictions.dtype) - for i, (lat_start, lon_start) in enumerate(dataset.patch_indices): + lats = base_dataset.monthly_da.coords["lat"].values + lons = base_dataset.monthly_da.coords["lon"].values + times = base_dataset.monthly_da.coords["time"].values + + full_predictions = np.full((M, len(lats), len(lons)), np.nan, dtype=predictions.dtype) + for i, patch_idx in enumerate(indices): + lat_start, lon_start = base_dataset.patch_indices[patch_idx] full_predictions[:, lat_start : lat_start + H, lon_start : lon_start + W] = ( predictions[i] ) @@ -95,8 +99,11 @@ def predict_monthly_var( ) # Initialize an empty list to store predictions - M = dataset.monthly_np.shape[0] - H, W = dataset.patch_size + base_dataset = dataset.dataset if hasattr(dataset, 'dataset') else dataset + base_dataset.fill_nans_with_zero() # Ensure NaNs are filled before prediction + + M = base_dataset.monthly_np.shape[0] + H, W = base_dataset.patch_size all_predictions = torch.empty(len(dataset), M, H, W) # Set up logging From f4e2c6914f61d31ae777a12407e5d5a84506bb0e Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 13 Apr 2026 14:18:34 +0200 Subject: [PATCH 3/5] return loss in predict --- climanet/predict.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/climanet/predict.py b/climanet/predict.py index ea3b62e..452f7b4 100644 --- a/climanet/predict.py +++ b/climanet/predict.py @@ -3,6 +3,7 @@ import numpy as np from torch.utils.data import Dataset from climanet.st_encoder_decoder import SpatioTemporalModel +from climanet.train import _compute_masked_loss import xarray as xr import torch from torch.utils.data import DataLoader @@ -111,6 +112,7 @@ def predict_monthly_var( with torch.no_grad(): idx = 0 + average_loss = 0.0 for i, batch in enumerate(dataloader): # Move batch to the appropriate device predictions = model( @@ -119,14 +121,27 @@ def predict_monthly_var( batch["land_mask_patch"].to(device, non_blocking=use_cuda), batch["padded_days_mask"].to(device, non_blocking=use_cuda), ) + + # Compute masked loss + loss = _compute_masked_loss( + predictions, batch["monthly_patch"], batch["land_mask_patch"] + ) + average_loss += loss.item() + all_predictions[idx : idx + predictions.size(0)] = predictions.cpu() idx += predictions.size(0) if verbose: - print(f"Processed batch {i + 1}/{len(dataloader)}") + print(f"Processed batch {i + 1}/{len(dataloader)}, with loss: {loss.item():.4f}") writer.add_scalar("Progress/Batch", i + 1, idx) + average_loss = average_loss / len(dataloader) + + if verbose: + print(f"Average loss over all batches: {average_loss:.4f}") + writer.add_scalar("Loss/Average", average_loss) + if return_numpy: all_predictions = all_predictions.numpy() From dd1470263ba22f1052a64ea04c8bc3802b93b8b8 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Tue, 14 Apr 2026 11:32:55 +0200 Subject: [PATCH 4/5] fix minor things --- climanet/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/climanet/dataset.py b/climanet/dataset.py index f845c8c..2c23592 100644 --- a/climanet/dataset.py +++ b/climanet/dataset.py @@ -185,9 +185,9 @@ def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]: patches = [] for idx in indices: i, j = self.patch_indices[idx] - patch = self.daily_np[:, :, i:i+ph, j:j+pw] # (M, T, ph, pw) + patch = self.daily_np[:, :, i:i+ph, j:j+pw] patches.append(patch) - data = np.concatenate(patches, axis=-1) # (M, T, H, W_total) + data = np.concatenate(patches, axis=-1) mean, std = calc_stats(data) # (M,) From 8da363f3ed0798a0508e3938b6cd172711b3bb83 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Tue, 14 Apr 2026 11:33:51 +0200 Subject: [PATCH 5/5] add set_seed to utils --- climanet/utils.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/climanet/utils.py b/climanet/utils.py index 0d5688b..6279f0c 100644 --- a/climanet/utils.py +++ b/climanet/utils.py @@ -1,4 +1,4 @@ - +import random from typing import Tuple import numpy as np import xarray as xr @@ -153,9 +153,28 @@ def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None): def calc_stats(arr: np.ndarray, mean_axis: int = 0) -> Tuple[np.ndarray, np.ndarray]: - """Calculate mean and std along the specified axis, ignoring NaNs.""" + """Calculate mean and std along the specified axis, ignoring NaNs. + + Args: + arr: Input array containing NaNs to ignore. shape is (M, T, H, W) + mean_axis: Axis along which to compute mean and std (default is 0 for month) + Returns: + mean: Mean values along the specified axis, shape (M,) + std: Standard deviation along the specified axis, shape (M,) + """ axes_to_reduce = tuple(i for i in range(arr.ndim) if i != mean_axis) mean = np.nanmean(arr, axis=axes_to_reduce) # shape: (M,) std = np.nanstd(arr, axis=axes_to_reduce) # shape: (M,) return mean, std + + +def set_seed(seed: int = 42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # https://docs.pytorch.org/docs/stable/notes/randomness.html + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False