Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,6 @@ def __init__(
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy()

# Store the stats of the daily data before filling NaNs
self.daily_mean, self.daily_std = calc_stats(self.daily_np)

if land_mask is not None:
lm = land_mask.to_numpy().copy()
if lm.ndim == 3:
Expand All @@ -69,8 +66,11 @@ def __init__(
# daily_mask: True where NaN (i.e. missing ocean data, not land)
self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W)

# Fill NaNs with 0 in-place
np.nan_to_num(self.daily_np, copy=False, nan=0.0)
# Stats will be set later via set_stats() and NaNs will be filled with 0 in-place
self.daily_mean = None
self.daily_std = None
self._nans_filled = False
self._warned = False

# Precompute padded_days_mask as a tensor (same for all patches)
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()
Expand Down Expand Up @@ -111,6 +111,13 @@ def __len__(self):

def __getitem__(self, idx):
"""Get a spatiotemporal patch sample based on the index."""
if not self._nans_filled and not self._warned:
warnings.warn(
"NaNs have not been replaced. Call fill_nans_with_zero() before using the dataset.",
UserWarning
)
self._warned = True

if idx < 0 or idx >= len(self.patch_indices):
raise IndexError("Index out of range")

Expand Down Expand Up @@ -159,3 +166,41 @@ def __getitem__(self, idx):
"lat_patch": lat_patch, # (H,)
"lon_patch": lon_patch, # (W,)
}


def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-variable mean and std, store them, and zero-fill NaNs.

    Stats are computed while NaNs are still present (so ``calc_stats`` can
    exclude missing ocean data); afterwards NaNs are replaced with 0 in place
    via :meth:`fill_nans_with_zero`.

    Args:
        indices: Patch indices (into ``self.patch_indices``) to compute stats
            from, e.g. the training-split indices. If None, use the full
            ``daily_np`` array. Must be non-empty when given.

    Returns:
        Tuple of (mean, std) arrays as returned by ``calc_stats``
        (per-variable, shape (M,) per the comments elsewhere in this class).

    Raises:
        ValueError: If ``indices`` is an empty sequence.
    """
    # Guard: once NaNs have been replaced by zeros, the zeros would be
    # counted as real data and silently skew the statistics.
    if self._nans_filled:
        warnings.warn(
            "compute_stats() called after NaNs were already filled with 0; "
            "the resulting mean/std include those zeros.",
            UserWarning,
        )

    if indices is None:
        data = self.daily_np  # (M, T, H, W)
    else:
        if len(indices) == 0:
            raise ValueError("indices must be non-empty (or None to use all data)")
        # Stack the selected spatial patches along the last axis so that
        # calc_stats reduces over all of them at once.
        ph, pw = self.patch_size
        patches = []
        for idx in indices:
            i, j = self.patch_indices[idx]
            patches.append(self.daily_np[:, :, i:i + ph, j:j + pw])  # (M, T, ph, pw)
        data = np.concatenate(patches, axis=-1)  # (M, T, ph, pw * len(indices))

    mean, std = calc_stats(data)  # (M,)

    self.daily_mean = mean
    self.daily_std = std

    # Fill NaNs with 0 in-place only after stats are computed, so missing
    # values never contaminate the statistics.
    self.fill_nans_with_zero()

    return mean, std

def fill_nans_with_zero(self):
    """Replace every NaN in ``self.daily_np`` with 0.0, in place.

    Idempotent: the ``_nans_filled`` flag ensures the substitution runs at
    most once per dataset instance.
    """
    if self._nans_filled:
        return
    self._nans_filled = True
    np.nan_to_num(self.daily_np, copy=False, nan=0.0)
38 changes: 30 additions & 8 deletions climanet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from torch.utils.data import Dataset
from climanet.st_encoder_decoder import SpatioTemporalModel
from climanet.train import _compute_masked_loss
import xarray as xr
import torch
from torch.utils.data import DataLoader
Expand All @@ -19,12 +20,16 @@ def _save_netcdf(predictions: np.ndarray, dataset: Dataset, save_dir: str):
"""Helper function to convert predictions to xarray and save as netCDF."""
B, M, H, W = predictions.shape

lats = dataset.monthly_da.coords["lat"].values
lons = dataset.monthly_da.coords["lon"].values
times = dataset.monthly_da.coords["time"].values
base_dataset = dataset.dataset if hasattr(dataset, 'dataset') else dataset
indices = dataset.indices if hasattr(dataset, 'indices') else range(len(dataset))

full_predictions = np.empty((M, len(lats), len(lons)), dtype=predictions.dtype)
for i, (lat_start, lon_start) in enumerate(dataset.patch_indices):
lats = base_dataset.monthly_da.coords["lat"].values
lons = base_dataset.monthly_da.coords["lon"].values
times = base_dataset.monthly_da.coords["time"].values

full_predictions = np.full((M, len(lats), len(lons)), np.nan, dtype=predictions.dtype)
for i, patch_idx in enumerate(indices):
lat_start, lon_start = base_dataset.patch_indices[patch_idx]
full_predictions[:, lat_start : lat_start + H, lon_start : lon_start + W] = (
predictions[i]
)
Expand Down Expand Up @@ -95,15 +100,19 @@ def predict_monthly_var(
)

# Initialize an empty list to store predictions
M = dataset.monthly_np.shape[0]
H, W = dataset.patch_size
base_dataset = dataset.dataset if hasattr(dataset, 'dataset') else dataset
base_dataset.fill_nans_with_zero() # Ensure NaNs are filled before prediction

M = base_dataset.monthly_np.shape[0]
H, W = base_dataset.patch_size
all_predictions = torch.empty(len(dataset), M, H, W)

# Set up logging
writer = _setup_logging(run_dir)

with torch.no_grad():
idx = 0
average_loss = 0.0
for i, batch in enumerate(dataloader):
# Move batch to the appropriate device
predictions = model(
Expand All @@ -112,14 +121,27 @@ def predict_monthly_var(
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["padded_days_mask"].to(device, non_blocking=use_cuda),
)

# Compute masked loss
loss = _compute_masked_loss(
predictions, batch["monthly_patch"], batch["land_mask_patch"]
)
average_loss += loss.item()

all_predictions[idx : idx + predictions.size(0)] = predictions.cpu()
idx += predictions.size(0)

if verbose:
print(f"Processed batch {i + 1}/{len(dataloader)}")
print(f"Processed batch {i + 1}/{len(dataloader)}, with loss: {loss.item():.4f}")

writer.add_scalar("Progress/Batch", i + 1, idx)

average_loss = average_loss / len(dataloader)

if verbose:
print(f"Average loss over all batches: {average_loss:.4f}")
writer.add_scalar("Loss/Average", average_loss)

if return_numpy:
all_predictions = all_predictions.numpy()

Expand Down
9 changes: 7 additions & 2 deletions climanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,17 @@ def train_monthly_model(
verbose: whether to print training progress
"""

# check if dataset has indices attribute for stats calculation
base_dataset = dataset.dataset if hasattr(dataset, 'dataset') else dataset
indices = dataset.indices if hasattr(dataset, 'indices') else None
mean, std = base_dataset.compute_stats(indices)

# Initialize the model
model = model.to(device)
decoder = model.decoder
with torch.no_grad():
decoder.bias.copy_(torch.from_numpy(dataset.daily_mean))
decoder.scale.copy_(torch.from_numpy(dataset.daily_std) + 1e-6)
decoder.bias.copy_(torch.from_numpy(mean))
decoder.scale.copy_(torch.from_numpy(std) + 1e-6)

# Create data loader
dataloader = DataLoader(
Expand Down