In [1]:
# dataset_era5_lai_whole.py
import os
import numpy as np
import xarray as xr
import torch
from torch.utils.data import Dataset

In [3]:
# dataset_era5_lai_whole.py
import os
import numpy as np
import xarray as xr
import torch
from torch.utils.data import Dataset

class ERA5LAIWholeWorld(Dataset):
    """
    One YEAR, 24 items (one per 15-day sample). Each __getitem__ returns:
        X: (3, H, W)  -> [ssrd, t2m, tp]  (raw/anom/z as chosen)
        y: (1, H, W)  -> LAI (raw or anomaly)
        mask: (1, H, W) boolean (True where y is valid)
        meta: dict with 'year' and 'sample'
    """
    def __init__(
        self,
        year: int,
        era5_mode="anom",   # "raw" | "anom" | "z"
        lai_mode="raw",     # "raw" | "anom"
        paths=None,
        engine="netcdf4",
        robust_nan=True,
        sample_indices=None,  # default: all 0..23
    ):
        assert era5_mode in ("raw","anom","z")
        assert lai_mode in ("raw","anom")
        self.year = int(year)
        self.era5_mode = era5_mode
        self.lai_mode = lai_mode
        self.engine = engine
        self.robust_nan = robust_nan
        self.sample_indices = list(range(24)) if sample_indices is None else list(sample_indices)

        # default paths (override with 'paths' dict)
        default_paths = {
            "era5_root": "/ptmp/mp002/ellis/lai",
            "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",   # your anomalies/z-scores
            "lai_root": "/ptmp/mp002/ellis/lai/lai",
            "lai_tmpl": "LAI.1440.720.{year}.nc",
            "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",         # change if different
        }
        self.paths = default_paths if paths is None else {**default_paths, **paths}

        # open all arrays for that year
        self._open_year()

    def _open_da(self, path, varname):
        ds = xr.open_dataset(path, engine=self.engine)
        da = ds[varname]
        if self.robust_nan:
            fv = da.attrs.get("_FillValue", None)
            if fv is not None:
                da = da.where(da != fv)
        # north-up
        lat_name = "lat" if "lat" in da.coords else "latitude"
        da = da.sortby(lat_name)
        # attach sample coord
        if "time" in da.dims and da.sizes["time"] == 24:
            da = da.assign_coords(sample=("time", np.arange(24)))
        return da

    def _open_year(self):
        y = self.year

        # ERA5 inputs
        if self.era5_mode == "raw":
            f_ssrd = os.path.join(self.paths["era5_root"], "ssrd", f"ssrd.15daily.fc.era5.1440.720.{y}.nc")
            f_t2m  = os.path.join(self.paths["era5_root"], "t2m",  f"t2m.15daily.an.era5.1440.720.{y}.nc")
            f_tp   = os.path.join(self.paths["era5_root"], "tp",   f"tp.15daily.fc.era5.1440.720.{y}.nc")
            self.ssrd = self._open_da(f_ssrd, "ssrd")
            self.t2m  = self._open_da(f_t2m,  "t2m")
            self.tp   = self._open_da(f_tp,   "tp")
        else:
            suffix = "anom" if self.era5_mode == "anom" else "z"
            base = self.paths["era5_anom_dir"]
            self.ssrd = self._open_da(os.path.join(base, f"ssrd_{suffix}_{y}.nc"), f"ssrd_{suffix}")
            self.t2m  = self._open_da(os.path.join(base, f"t2m_{suffix}_{y}.nc"),  f"t2m_{suffix}")
            self.tp   = self._open_da(os.path.join(base, f"tp_{suffix}_{y}.nc"),   f"tp_{suffix}")

        # LAI target
        lai_file = os.path.join(self.paths["lai_root"], self.paths["lai_tmpl"].format(year=y))
        lai_var = self._infer_var(lai_file)
        self.lai = self._open_da(lai_file, lai_var)

        if self.lai_mode == "anom":
            # use your saved LAI anomalies per year if available
            lai_anom_dir = self.paths.get("lai_anom_dir", self.paths["lai_root"])
            lai_anom_file = os.path.join(lai_anom_dir, f"LAI_anom_{y}.nc")
            if os.path.exists(lai_anom_file):
                self.lai = self._open_da(lai_anom_file, "LAI_anom")
            else:
                raise FileNotFoundError(f"LAI anomaly file not found: {lai_anom_file}")

        # dims/coords
        self.lat_name = "lat" if "lat" in self.lai.coords else "latitude"
        self.lon_name = "lon" if "lon" in self.lai.coords else "longitude"

        # sanity checks
        for da, name in [(self.ssrd,"ssrd"),(self.t2m,"t2m"),(self.tp,"tp")]:
            assert "time" in da.dims and da.sizes["time"] == 24, f"{name} expects 24 samples"
            assert da.sizes[self.lat_name] == self.lai.sizes[self.lat_name]
            assert da.sizes[self.lon_name] == self.lai.sizes[self.lon_name]
        assert "time" in self.lai.dims and self.lai.sizes["time"] == 24

    def _infer_var(self, nc_path):
        with xr.open_dataset(nc_path, engine=self.engine) as ds:
            for v in ds.data_vars:
                if ds[v].ndim >= 2:
                    return v
        raise RuntimeError(f"No data variable in {nc_path}")

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, i):
        s = self.sample_indices[i]
        # features (H,W) each
        x_ssrd = self.ssrd.isel(time=s).values
        x_t2m  = self.t2m .isel(time=s).values
        x_tp   = self.tp  .isel(time=s).values
        # target
        y_lai  = self.lai .isel(time=s).values  # may contain NaNs over ocean

        # stack to tensors, channel-first
        X = np.stack([x_ssrd, x_t2m, x_tp], axis=0)              # (3,H,W)
        y = np.expand_dims(y_lai, axis=0)                       # (1,H,W)
        mask = ~np.isnan(y)                                     # valid target pixels

        # features: fill NaNs with 0 (or choose another fill)
        X = np.nan_to_num(X, nan=0.0)

        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(np.nan_to_num(y, nan=0.0).astype(np.float32))
        mask = torch.from_numpy(mask.astype(np.bool_))
        meta = {"year": self.year, "sample": int(s)}
        return X, y, mask, meta

# Loss helper (same as before)
def masked_mse_loss(pred, target, mask):
    diff2 = (pred - target) ** 2
    diff2 = diff2 * mask.float()
    denom = mask.float().sum().clamp_min(1.0)
    return diff2.sum() / denom

def collate_keep_meta(batch):
    """Custom collate so that 'meta' stays a list of dicts (not collated into tensors)."""
    Xs, ys, masks, metas = zip(*batch)
    return torch.stack(Xs), torch.stack(ys), torch.stack(masks), list(metas)

In [5]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

ds = ERA5LAIWholeWorld(
    year=1990,
    era5_mode="raw",   # or "raw"/"z"
    lai_mode="raw",     # or "anom"
    paths={
        "era5_root": "/ptmp/mp002/ellis/lai",
        "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",
        "lai_root": "/ptmp/mp002/ellis/lai/lai",
        "lai_tmpl": "LAI.1440.720.{year}.nc",
        # "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",  # if using LAI anomalies
    },
)

loader = DataLoader(ds, batch_size=1, shuffle=False)  # batch_size=1 since each item is full globe

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 1, 3, padding=1)
)

for X, y, mask, meta in loader:
    pred = model(X)
    loss = masked_mse_loss(pred, y, mask)
    print(f"year={meta['year'][0].item()}, sample={meta['sample'][0].item()}, loss={float(loss):.4f}")
    break

year=1990, sample=0, loss=166.4673


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  print(f"year={meta['year'][0].item()}, sample={meta['sample'][0].item()}, loss={float(loss):.4f}")
