In [1]:
# dataset_era5_lai_whole.py
import os
import numpy as np
import xarray as xr
import torch
from torch.utils.data import Dataset

In [3]:
# dataset_era5_lai_whole.py
import os
import numpy as np
import xarray as xr
import torch
from torch.utils.data import Dataset

class ERA5LAIWholeWorld(Dataset):
    """
    One YEAR, 24 items (one per 15-day sample). Each __getitem__ returns:
        X: (3, H, W)  -> [ssrd, t2m, tp]  (raw/anom/z as chosen)
        y: (1, H, W)  -> LAI (raw or anomaly)
        mask: (1, H, W) boolean (True where y is valid)
        meta: dict with 'year' and 'sample'
    """
    def __init__(
        self,
        year: int,
        era5_mode="anom",   # "raw" | "anom" | "z"
        lai_mode="raw",     # "raw" | "anom"
        paths=None,
        engine="netcdf4",
        robust_nan=True,
        sample_indices=None,  # default: all 0..23
    ):
        assert era5_mode in ("raw","anom","z")
        assert lai_mode in ("raw","anom")
        self.year = int(year)
        self.era5_mode = era5_mode
        self.lai_mode = lai_mode
        self.engine = engine
        self.robust_nan = robust_nan
        self.sample_indices = list(range(24)) if sample_indices is None else list(sample_indices)

        # default paths (override with 'paths' dict)
        default_paths = {
            "era5_root": "/ptmp/mp002/ellis/lai",
            "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",   # your anomalies/z-scores
            "lai_root": "/ptmp/mp002/ellis/lai/lai",
            "lai_tmpl": "LAI.1440.720.{year}.nc",
            "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",         # change if different
        }
        self.paths = default_paths if paths is None else {**default_paths, **paths}

        # open all arrays for that year
        self._open_year()

    def _open_da(self, path, varname):
        ds = xr.open_dataset(path, engine=self.engine)
        da = ds[varname]
        if self.robust_nan:
            fv = da.attrs.get("_FillValue", None)
            if fv is not None:
                da = da.where(da != fv)
        # north-up
        lat_name = "lat" if "lat" in da.coords else "latitude"
        da = da.sortby(lat_name)
        # attach sample coord
        if "time" in da.dims and da.sizes["time"] == 24:
            da = da.assign_coords(sample=("time", np.arange(24)))
        return da

    def _open_year(self):
        y = self.year

        # ERA5 inputs
        if self.era5_mode == "raw":
            f_ssrd = os.path.join(self.paths["era5_root"], "ssrd", f"ssrd.15daily.fc.era5.1440.720.{y}.nc")
            f_t2m  = os.path.join(self.paths["era5_root"], "t2m",  f"t2m.15daily.an.era5.1440.720.{y}.nc")
            f_tp   = os.path.join(self.paths["era5_root"], "tp",   f"tp.15daily.fc.era5.1440.720.{y}.nc")
            self.ssrd = self._open_da(f_ssrd, "ssrd")
            self.t2m  = self._open_da(f_t2m,  "t2m")
            self.tp   = self._open_da(f_tp,   "tp")
        else:
            suffix = "anom" if self.era5_mode == "anom" else "z"
            base = self.paths["era5_anom_dir"]
            self.ssrd = self._open_da(os.path.join(base, f"ssrd_{suffix}_{y}.nc"), f"ssrd_{suffix}")
            self.t2m  = self._open_da(os.path.join(base, f"t2m_{suffix}_{y}.nc"),  f"t2m_{suffix}")
            self.tp   = self._open_da(os.path.join(base, f"tp_{suffix}_{y}.nc"),   f"tp_{suffix}")

        # LAI target
        lai_file = os.path.join(self.paths["lai_root"], self.paths["lai_tmpl"].format(year=y))
        lai_var = self._infer_var(lai_file)
        self.lai = self._open_da(lai_file, lai_var)

        if self.lai_mode == "anom":
            # use your saved LAI anomalies per year if available
            lai_anom_dir = self.paths.get("lai_anom_dir", self.paths["lai_root"])
            lai_anom_file = os.path.join(lai_anom_dir, f"LAI_anom_{y}.nc")
            if os.path.exists(lai_anom_file):
                self.lai = self._open_da(lai_anom_file, "LAI_anom")
            else:
                raise FileNotFoundError(f"LAI anomaly file not found: {lai_anom_file}")

        # dims/coords
        self.lat_name = "lat" if "lat" in self.lai.coords else "latitude"
        self.lon_name = "lon" if "lon" in self.lai.coords else "longitude"

        # sanity checks
        for da, name in [(self.ssrd,"ssrd"),(self.t2m,"t2m"),(self.tp,"tp")]:
            assert "time" in da.dims and da.sizes["time"] == 24, f"{name} expects 24 samples"
            assert da.sizes[self.lat_name] == self.lai.sizes[self.lat_name]
            assert da.sizes[self.lon_name] == self.lai.sizes[self.lon_name]
        assert "time" in self.lai.dims and self.lai.sizes["time"] == 24

    def _infer_var(self, nc_path):
        with xr.open_dataset(nc_path, engine=self.engine) as ds:
            for v in ds.data_vars:
                if ds[v].ndim >= 2:
                    return v
        raise RuntimeError(f"No data variable in {nc_path}")

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, i):
        s = self.sample_indices[i]
        # features (H,W) each
        x_ssrd = self.ssrd.isel(time=s).values
        x_t2m  = self.t2m .isel(time=s).values
        x_tp   = self.tp  .isel(time=s).values
        # target
        y_lai  = self.lai .isel(time=s).values  # may contain NaNs over ocean

        # stack to tensors, channel-first
        X = np.stack([x_ssrd, x_t2m, x_tp], axis=0)              # (3,H,W)
        y = np.expand_dims(y_lai, axis=0)                       # (1,H,W)
        mask = ~np.isnan(y)                                     # valid target pixels

        # features: fill NaNs with 0 (or choose another fill)
        X = np.nan_to_num(X, nan=0.0)

        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(np.nan_to_num(y, nan=0.0).astype(np.float32))
        mask = torch.from_numpy(mask.astype(np.bool_))
        meta = {"year": self.year, "sample": int(s)}
        return X, y, mask, meta

# Loss helper (same as before)
def masked_mse_loss(pred, target, mask):
    diff2 = (pred - target) ** 2
    diff2 = diff2 * mask.float()
    denom = mask.float().sum().clamp_min(1.0)
    return diff2.sum() / denom

def collate_keep_meta(batch):
    """Custom collate so that 'meta' stays a list of dicts (not collated into tensors)."""
    Xs, ys, masks, metas = zip(*batch)
    return torch.stack(Xs), torch.stack(ys), torch.stack(masks), list(metas)

In [5]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

ds = ERA5LAIWholeWorld(
    year=1990,
    era5_mode="raw",   # or "raw"/"z"
    lai_mode="raw",     # or "anom"
    paths={
        "era5_root": "/ptmp/mp002/ellis/lai",
        "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",
        "lai_root": "/ptmp/mp002/ellis/lai/lai",
        "lai_tmpl": "LAI.1440.720.{year}.nc",
        # "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",  # if using LAI anomalies
    },
)

loader = DataLoader(ds, batch_size=1, shuffle=False)  # batch_size=1 since each item is full globe

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 1, 3, padding=1)
)

for X, y, mask, meta in loader:
    pred = model(X)
    loss = masked_mse_loss(pred, y, mask)
    print(f"year={meta['year'][0].item()}, sample={meta['sample'][0].item()}, loss={float(loss):.4f}")
    break

year=1990, sample=0, loss=166.4673


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  print(f"year={meta['year'][0].item()}, sample={meta['sample'][0].item()}, loss={float(loss):.4f}")


In [1]:
# dataset_era5_lai_sequence.py
import os, bisect
import numpy as np
import xarray as xr
import torch
from torch.utils.data import Dataset

def _open_da(path, varname, engine="netcdf4", robust_nan=True):
    ds = xr.open_dataset(path, engine=engine)
    da = ds[varname]
    if robust_nan:
        fv = da.attrs.get("_FillValue", None)
        if fv is not None:
            da = da.where(da != fv)
    # north-up
    lat_name = "lat" if "lat" in da.coords else "latitude"
    da = da.sortby(lat_name)
    # attach sample coord
    if "time" in da.dims and da.sizes["time"] == 24:
        da = da.assign_coords(sample=("time", np.arange(24)))
    return da

def _infer_var(nc_path, engine="netcdf4"):
    with xr.open_dataset(nc_path, engine=engine) as ds:
        for v in ds.data_vars:
            if ds[v].ndim >= 2:
                return v
    raise RuntimeError(f"No data variable in {nc_path}")

class ERA5LAISequenceWorld(Dataset):
    """
    Multi-year, whole-world, sliding-window dataset.

    Each item:
      X: shape depends on feature_layout:
         - "time_channels": (T*C, H, W)  e.g., (48*3, H, W) for 2 years @ 24/yr
         - "time_first":    (T, C, H, W) e.g., (48, 3, H, W)
      y:
         - if target_mode="last": (1, H, W)  (LAI at last timestep in window)
         - if target_mode="all":  (T, 1, H, W) (LAI for each timestep)
      mask:
         - same shape as y; True where target is valid (non-NaN)
      meta: dict with {"start_index": int, "years": [list years spanned], "t_indices": [global t indices]}

    Parameters
    ----------
    years: list[int]         years to include (e.g., [1982, 1983, ..., 2010])
    seq_len: int             window length in samples (e.g., 48 for 2 years)
    seq_stride: int          step between window starts (e.g., 12 for quarterly steps)
    era5_mode: "raw"|"anom"|"z"
    lai_mode:  "raw"|"anom"
    feature_layout: "time_channels"|"time_first"
    target_mode: "last"|"all"
    paths: dict              era5_root, era5_anom_dir, lai_root, lai_tmpl, (optional) lai_anom_dir
    engine: str
    robust_nan: bool
    """
    def __init__(
        self,
        years,
        seq_len=48,
        seq_stride=24,
        era5_mode="anom",
        lai_mode="raw",
        feature_layout="time_channels",
        target_mode="last",
        paths=None,
        engine="netcdf4",
        robust_nan=True,
    ):
        super().__init__()
        assert era5_mode in ("raw","anom","z")
        assert lai_mode in ("raw","anom")
        assert feature_layout in ("time_channels","time_first")
        assert target_mode in ("last","all")
        self.years = list(years)
        self.seq_len = int(seq_len)
        self.seq_stride = int(seq_stride)
        self.era5_mode = era5_mode
        self.lai_mode = lai_mode
        self.feature_layout = feature_layout
        self.target_mode = target_mode
        self.engine = engine
        self.robust_nan = robust_nan

        defaults = {
            "era5_root": "/ptmp/mp002/ellis/lai",
            "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",
            "lai_root": "/ptmp/mp002/ellis/lai/lai",
            "lai_tmpl": "LAI.1440.720.{year}.nc",
            "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",
        }
        self.paths = defaults if paths is None else {**defaults, **paths}

        # Open per-year arrays; keep per-year handles to avoid loading everything at once
        self._open_years()

        # Build global time index over concatenated years: 24 samples per year
        self.samples_per_year = 24
        self.year_offsets = [i*self.samples_per_year for i in range(len(self.years))]
        self.total_samples = self.samples_per_year * len(self.years)

        # Build list of window start indices
        self.starts = list(range(0, self.total_samples - self.seq_len + 1, self.seq_stride))
        if len(self.starts) == 0:
            raise ValueError("seq_len is longer than total samples. Reduce seq_len or add years.")

    def _open_years(self):
        self.ssrd_y = []
        self.t2m_y  = []
        self.tp_y   = []
        self.lai_y  = []
        # open one year to fix dims
        example_y = self.years[0]

        for y in self.years:
            # ERA5 features
            if self.era5_mode == "raw":
                f_ssrd = os.path.join(self.paths["era5_root"], "ssrd", f"ssrd.15daily.fc.era5.1440.720.{y}.nc")
                f_t2m  = os.path.join(self.paths["era5_root"], "t2m",  f"t2m.15daily.an.era5.1440.720.{y}.nc")
                f_tp   = os.path.join(self.paths["era5_root"], "tp",   f"tp.15daily.fc.era5.1440.720.{y}.nc")
                ssrd = _open_da(f_ssrd, "ssrd", self.engine, self.robust_nan)
                t2m  = _open_da(f_t2m,  "t2m",  self.engine, self.robust_nan)
                tp   = _open_da(f_tp,   "tp",   self.engine, self.robust_nan)
            else:
                suffix = "anom" if self.era5_mode == "anom" else "z"
                base = self.paths["era5_anom_dir"]
                ssrd = _open_da(os.path.join(base, f"ssrd_{suffix}_{y}.nc"), f"ssrd_{suffix}", self.engine, self.robust_nan)
                t2m  = _open_da(os.path.join(base, f"t2m_{suffix}_{y}.nc"),  f"t2m_{suffix}",  self.engine, self.robust_nan)
                tp   = _open_da(os.path.join(base, f"tp_{suffix}_{y}.nc"),   f"tp_{suffix}",   self.engine, self.robust_nan)

            # LAI
            lai_file = os.path.join(self.paths["lai_root"], self.paths["lai_tmpl"].format(year=y))
            lai_var  = _infer_var(lai_file, self.engine)
            lai = _open_da(lai_file, lai_var, self.engine, self.robust_nan)
            if self.lai_mode == "anom":
                lai_anom_dir = self.paths.get("lai_anom_dir", self.paths["lai_root"])
                lai_anom_file = os.path.join(lai_anom_dir, f"LAI_anom_{y}.nc")
                if os.path.exists(lai_anom_file):
                    lai = _open_da(lai_anom_file, "LAI_anom", self.engine, self.robust_nan)
                else:
                    raise FileNotFoundError(f"LAI anomaly file not found: {lai_anom_file}")

            # basic checks
            for da, name in [(ssrd,"ssrd"),(t2m,"t2m"),(tp,"tp")]:
                assert "time" in da.dims and da.sizes["time"] == 24, f"{name} {y} expects 24 samples"

            self.ssrd_y.append(ssrd)
            self.t2m_y.append(t2m)
            self.tp_y.append(tp)
            self.lai_y.append(lai)

        # dims
        self.lat_name = "lat" if "lat" in self.lai_y[0].coords else "latitude"
        self.lon_name = "lon" if "lon" in self.lai_y[0].coords else "longitude"
        H = int(self.lai_y[0].sizes[self.lat_name]); W = int(self.lai_y[0].sizes[self.lon_name])
        for j in range(len(self.years)):
            for da, name in [(self.ssrd_y[j],"ssrd"),(self.t2m_y[j],"t2m"),(self.tp_y[j],"tp")]:
                assert da.sizes[self.lat_name] == H and da.sizes[self.lon_name] == W, f"{name} grid mismatch in year {self.years[j]}"

    def __len__(self):
        return len(self.starts)

    def _slice_from_global_t(self, global_t):
        """Map global sample index to (year_idx, local_time_index)."""
        year_idx = min(len(self.year_offsets)-1, bisect.bisect_right(self.year_offsets, global_t) - 1)
        local_t  = global_t - self.year_offsets[year_idx]
        return year_idx, local_t

    def __getitem__(self, idx):
        start = self.starts[idx]
        Ts = self.seq_len

        # Collect sequence across year boundaries
        X_list = []   # each entry shape (3, H, W)
        y_list = []   # each entry shape (1, H, W) for target if target_mode="all"

        for t in range(Ts):
            g = start + t
            yi, lt = self._slice_from_global_t(g)

            # features (H,W)
            x_ssrd = self.ssrd_y[yi].isel(time=lt).values
            x_t2m  = self.t2m_y [yi].isel(time=lt).values
            x_tp   = self.tp_y  [yi].isel(time=lt).values
            X_list.append(np.stack([x_ssrd, x_t2m, x_tp], axis=0))  # (3,H,W)

            if self.target_mode == "all":
                y_map = self.lai_y[yi].isel(time=lt).values
                y_list.append(np.expand_dims(y_map, axis=0))         # (1,H,W)

        # Stack over time
        X_seq = np.stack(X_list, axis=0)  # (T, 3, H, W)

        # Target
        if self.target_mode == "last":
            glast = start + Ts - 1
            yi, lt = self._slice_from_global_t(glast)
            y_map = self.lai_y[yi].isel(time=lt).values
            y = np.expand_dims(y_map, axis=0)              # (1,H,W)
            mask = ~np.isnan(y)
        else:
            y = np.stack(y_list, axis=0)                   # (T,1,H,W)
            mask = ~np.isnan(y)

        # Fill NaNs in features, keep y NaNs masked
        X_seq = np.nan_to_num(X_seq, nan=0.0)

        # Layout
        if self.feature_layout == "time_channels":
            X = X_seq.transpose(1,0,2,3).reshape(-1, X_seq.shape[2], X_seq.shape[3])  # (T*C, H, W)
        else:
            X = X_seq  # (T, C, H, W)

        X_t = torch.from_numpy(X.astype(np.float32))
        y_t = torch.from_numpy(np.nan_to_num(y, nan=0.0).astype(np.float32))
        m_t = torch.from_numpy(mask.astype(np.bool_))

        # meta
        # list of global time indices and years covered
        t_indices = list(range(start, start+Ts))
        years_cov = sorted(set(self.years[self._slice_from_global_t(t)[0]] for t in t_indices))
        meta = {"start_index": int(start), "t_indices": t_indices, "years": years_cov}
        return X_t, y_t, m_t, meta

# Loss (same idea; supports last or all target)
def masked_mse_loss(pred, target, mask):
    diff2 = (pred - target) ** 2
    diff2 = diff2 * mask.float()
    denom = mask.float().sum().clamp_min(1.0)
    return diff2.sum() / denom

In [2]:
from torch.utils.data import DataLoader

ds_seq = ERA5LAISequenceWorld(
    years=range(1988, 1990),     # training years
    seq_len=48,                  # 2 years
    seq_stride=24,               # step one year
    era5_mode="raw",
    lai_mode="raw",
    feature_layout="time_channels",  # model input shape (48*3, H, W)
    target_mode="last",
    paths={
        "era5_root": "/ptmp/mp002/ellis/lai",
        "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",
        "lai_root": "/ptmp/mp002/ellis/lai/lai",
        "lai_tmpl": "LAI.1440.720.{year}.nc",
    },
    engine="netcdf4",
)

def collate_keep_meta(batch):
    Xs, ys, ms, metas = zip(*batch)
    return torch.stack(Xs), torch.stack(ys), torch.stack(ms), list(metas)

loader = DataLoader(ds_seq, batch_size=1, shuffle=True, collate_fn=collate_keep_meta)

In [3]:
import os
import json
import torch
from torch.utils.data import DataLoader

# Resolve relative to the script location
script_dir = os.getcwd()
config_path = os.path.join(script_dir, "..", "inputs", "training.json")

with open(config_path, "r") as f:
    cfg = json.load(f)

train_ds = ERA5LAISequenceWorld(
    years=cfg["splits"]["train_years"],
    seq_len=cfg["training"]["seq_len_samples"],
    seq_stride=cfg["training"]["seq_stride"],
    era5_mode=cfg["data"]["era5_mode"],
    lai_mode=cfg["data"]["lai_mode"],
    feature_layout=cfg["training"]["feature_layout"],
    target_mode=cfg["training"]["target_mode"],
    paths=cfg["data"]["paths"],
    engine=cfg["data"]["engine"],
)

val_ds = ERA5LAISequenceWorld(
    years=cfg["splits"]["val_years"],
    seq_len=cfg["training"]["seq_len_samples"],
    seq_stride=cfg["training"]["seq_stride"],
    era5_mode=cfg["data"]["era5_mode"],
    lai_mode=cfg["data"]["lai_mode"],
    feature_layout=cfg["training"]["feature_layout"],
    target_mode=cfg["training"]["target_mode"],
    paths=cfg["data"]["paths"],
    engine=cfg["data"]["engine"],
)

def collate_keep_meta(batch):
    Xs, ys, ms, metas = zip(*batch)
    return torch.stack(Xs), torch.stack(ys), torch.stack(ms), list(metas)

train_loader = DataLoader(
    train_ds,
    batch_size=cfg["training"]["batch_size"],
    shuffle=cfg["training"]["shuffle"],
    collate_fn=collate_keep_meta,
)

val_loader = DataLoader(
    val_ds,
    batch_size=cfg["training"]["batch_size"],
    shuffle=False,
    collate_fn=collate_keep_meta,
)

In [None]:
# train_optimized.py
import os, bisect, json, time
import numpy as np
import xarray as xr
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler

# ---------------- Dataset Helper Functions (No changes here) ----------------
def _open_da(path, varname, engine="netcdf4", robust_nan=True):
    ds = xr.open_dataset(path, engine=engine)
    da = ds[varname]
    if robust_nan:
        fv = da.attrs.get("_FillValue", None)
        if fv is not None:
            da = da.where(da != fv)
    lat_name = "lat" if "lat" in da.coords else "latitude"
    da = da.sortby(lat_name)
    if "time" in da.dims and da.sizes["time"] == 24:
        da = da.assign_coords(sample=("time", np.arange(24)))
    return da

def _infer_var(nc_path, engine="netcdf4"):
    with xr.open_dataset(nc_path, engine=engine) as ds:
        for v in ds.data_vars:
            if ds[v].ndim >= 2:
                return v
    raise RuntimeError(f"No suitable data variable found in {nc_path}")

# ---------------- PyTorch Dataset Class (MODIFIED FOR PERFORMANCE) ----------------
class ERA5LAISequenceWorld(Dataset):
    def __init__(
        self, years, seq_len=48, seq_stride=24,
        era5_mode="anom", lai_mode="raw",
        feature_layout="time_channels", target_mode="last",
        paths=None, engine="netcdf4", robust_nan=True,
    ):
        super().__init__()
        # ... (initial parameters are the same) ...
        self.years = list(years)
        self.seq_len = int(seq_len)
        self.seq_stride = int(seq_stride)
        self.era5_mode = era5_mode
        self.lai_mode = lai_mode
        self.feature_layout = feature_layout
        self.target_mode = target_mode
        self.engine = engine
        self.robust_nan = robust_nan
        defaults = {
            "era5_root": "/ptmp/mp002/ellis/lai",
            "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",
            "lai_root": "/ptmp/mp002/ellis/lai/lai",
            "lai_tmpl": "LAI.1440.720.{year}.nc",
            "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",
        }
        self.paths = defaults if paths is None else {**defaults, **paths}
        
        # This method now loads all data into RAM for speed
        self._open_and_load_years() # <-- MODIFIED

        self.samples_per_year = 24
        self.total_samples = self.lai_data.shape[0] # <-- MODIFIED to use the loaded data shape

        self.starts = list(range(0, self.total_samples - self.seq_len + 1, self.seq_stride))
        if not self.starts:
            raise ValueError("seq_len is longer than total samples. Reduce seq_len or add years.")

    def _open_and_load_years(self): # <-- MODIFIED: New optimized loading function
        """
        Loads all necessary data for the given years, concatenates them,
        and then loads everything into memory as NumPy arrays to prevent I/O bottlenecks.
        """
        ssrd_y, t2m_y, tp_y, lai_y = [], [], [], []
        print(f"    - Loading data for years: {self.years}...") # <-- NEW progress indicator
        for y in self.years:
            # ... (the logic to find and open individual files is the same) ...
            if self.era5_mode == "raw":
                f_ssrd = os.path.join(self.paths["era5_root"], "ssrd", f"ssrd.15daily.fc.era5.1440.720.{y}.nc")
                f_t2m  = os.path.join(self.paths["era5_root"], "t2m",  f"t2m.15daily.an.era5.1440.720.{y}.nc")
                f_tp   = os.path.join(self.paths["era5_root"], "tp",   f"tp.15daily.fc.era5.1440.720.{y}.nc")
                ssrd = _open_da(f_ssrd, "ssrd", self.engine, self.robust_nan)
                t2m  = _open_da(f_t2m,  "t2m",  self.engine, self.robust_nan)
                tp   = _open_da(f_tp,   "tp",   self.engine, self.robust_nan)
            else:
                suffix = "anom" if self.era5_mode == "anom" else "z"
                base = self.paths["era5_anom_dir"]
                ssrd = _open_da(os.path.join(base, f"ssrd_{suffix}_{y}.nc"), f"ssrd_{suffix}", self.engine, self.robust_nan)
                t2m  = _open_da(os.path.join(base, f"t2m_{suffix}_{y}.nc"),  f"t2m_{suffix}",  self.engine, self.robust_nan)
                tp   = _open_da(os.path.join(base, f"tp_{suffix}_{y}.nc"),   f"tp_{suffix}",   self.engine, self.robust_nan)
            lai_file = os.path.join(self.paths["lai_root"], self.paths["lai_tmpl"].format(year=y))
            lai_var  = _infer_var(lai_file, self.engine)
            lai = _open_da(lai_file, lai_var, self.engine, self.robust_nan)
            if self.lai_mode == "anom":
                lai_anom_file = os.path.join(self.paths.get("lai_anom_dir", self.paths["lai_root"]), f"LAI_anom_{y}.nc")
                lai = _open_da(lai_anom_file, "LAI_anom", self.engine, self.robust_nan)
            
            ssrd_y.append(ssrd); t2m_y.append(t2m); tp_y.append(tp); lai_y.append(lai)
        
        # --- NEW OPTIMIZATION STEP ---
        # Concatenate all years into single xarray DataArrays
        print("    - Concatenating yearly data...")
        full_ssrd = xr.concat(ssrd_y, dim="time")
        full_t2m  = xr.concat(t2m_y, dim="time")
        full_tp   = xr.concat(tp_y, dim="time")
        full_lai  = xr.concat(lai_y, dim="time")

        # Now, load all the data from disk into RAM. This is the key performance gain.
        # This will be slow once, but makes training much faster.
        print("    - Loading all data into memory (this may take a moment)...")
        self.ssrd_data = full_ssrd.load().values
        self.t2m_data  = full_t2m.load().values
        self.tp_data   = full_tp.load().values
        self.lai_data  = full_lai.load().values
        print("    - ✅ Data loaded.")

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx): # <-- MODIFIED: New optimized __getitem__
        """Fetches a single training sequence by slicing the pre-loaded NumPy arrays."""
        start = self.starts[idx]
        end = start + self.seq_len

        # --- SLICE pre-loaded NumPy arrays (extremely fast) ---
        # Shape of each is (T, H, W) where T=seq_len
        x_ssrd_seq = self.ssrd_data[start:end, :, :]
        x_t2m_seq  = self.t2m_data[start:end, :, :]
        x_tp_seq   = self.tp_data[start:end, :, :]

        # Stack the features to get (T, 3, H, W)
        X_seq = np.stack([x_ssrd_seq, x_t2m_seq, x_tp_seq], axis=1)

        # --- Prepare the target (y) and mask from pre-loaded data ---
        if self.target_mode == "last":
            # Target is the LAI at the final time step of the sequence
            y = self.lai_data[end - 1, :, :] # Get the last slice
            y = np.expand_dims(y, axis=0)    # Add channel dim -> (1, H, W)
            mask = ~np.isnan(y)
        else: # target_mode == "all"
            y = self.lai_data[start:end, :, :]
            y = np.expand_dims(y, axis=1) # Add channel dim -> (T, 1, H, W)
            mask = ~np.isnan(y)
        
        X_seq = np.nan_to_num(X_seq, nan=0.0)

        # --- Reshape features based on layout (no changes here) ---
        if self.feature_layout == "time_channels":
            X = X_seq.transpose(0, 2, 3, 1).reshape(X_seq.shape[2], X_seq.shape[3], -1)
            X = X.transpose(2, 0, 1) # to get C,H,W -> (3T, H, W)
        else:
            X = X_seq

        # --- Convert to PyTorch tensors (no changes here) ---
        X_t = torch.from_numpy(X.astype(np.float32))
        y_t = torch.from_numpy(np.nan_to_num(y, nan=0.0).astype(np.float32))
        m_t = torch.from_numpy(mask.astype(np.bool_))

        # We can simplify the metadata as we don't need to track years anymore
        meta = {"start_index": int(start)}
        return X_t, y_t, m_t, meta

# ... (The rest of the script, including TinyCNN, collate_fn, and the main training loop,
#      remains exactly the same as the previous version with all the print statements.) ...

def masked_mse_loss(pred, target, mask):
    diff2 = (pred - target) ** 2
    diff2 = diff2 * mask.float()
    denom = mask.float().sum().clamp_min(1.0)
    return diff2.sum() / denom

def collate_keep_meta(batch):
    Xs, ys, ms, metas = zip(*batch)
    return torch.stack(Xs), torch.stack(ys), torch.stack(ms), list(metas)

class TinyCNN(nn.Module):
    def __init__(self, in_ch, mid=16, out_ch=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, mid, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(mid, mid, kernel_size=3, padding=1),   nn.ReLU(inplace=True),
            nn.Conv2d(mid, out_ch, kernel_size=1)
        )
    def forward(self, x):
        return self.net(x)

if __name__ == "__main__":
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        script_dir = os.getcwd()
    config_path = os.path.join(script_dir, "..", "inputs", "training.json")
    with open(config_path, "r") as f:
        cfg = json.load(f)
    print("--- Configuration Loaded ---")
    print(json.dumps(cfg, indent=2))
    print("-" * 28)
    print("\n--- Initializing Datasets ---")
    print("Initializing training dataset...")
    train_ds = ERA5LAISequenceWorld(
        years=cfg["splits"]["train_years"],
        seq_len=cfg["training"]["seq_len_samples"],
        seq_stride=cfg["training"]["seq_stride"],
        era5_mode=cfg["data"]["era5_mode"],
        lai_mode=cfg["data"]["lai_mode"],
        feature_layout=cfg["training"]["feature_layout"],
        target_mode=cfg["training"]["target_mode"],
        paths=cfg["data"]["paths"],
        engine=cfg["data"]["engine"],
    )
    print("Initializing validation dataset...")
    val_ds = ERA5LAISequenceWorld(
        years=cfg["splits"]["val_years"],
        seq_len=cfg["training"]["seq_len_samples"],
        seq_stride=cfg["training"]["seq_stride"],
        era5_mode=cfg["data"]["era5_mode"],
        lai_mode=cfg["data"]["lai_mode"],
        feature_layout=cfg["training"]["feature_layout"],
        target_mode=cfg["training"]["target_mode"],
        paths=cfg["data"]["paths"],
        engine=cfg["data"]["engine"],
    )
    print(f"✅ Training set has {len(train_ds)} samples.")
    print(f"✅ Validation set has {len(val_ds)} samples.")
    print("-" * 29)
    train_loader = DataLoader(
        train_ds, batch_size=cfg["training"]["batch_size"],
        shuffle=cfg["training"]["shuffle"], num_workers=0,
        collate_fn=collate_keep_meta, pin_memory=True # Can set pin_memory=True now
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg["training"]["batch_size"],
        shuffle=False, num_workers=0,
        collate_fn=collate_keep_meta, pin_memory=True # Can set pin_memory=True now
    )
    print("\n--- Initializing Model & Device ---")
    DEVICE_ID = 1
    if torch.cuda.is_available():
        if DEVICE_ID >= torch.cuda.device_count():
            print(f"⚠️ WARNING: Device ID {DEVICE_ID} is not available. Found {torch.cuda.device_count()} devices.")
            DEVICE_ID = 0
        device = torch.device(f"cuda:{DEVICE_ID}")
        torch.backends.cudnn.benchmark = True
        gpu_name = torch.cuda.get_device_name(device)
        print(f"✅ Using GPU: {gpu_name} (Device {DEVICE_ID})")
    else:
        device = torch.device("cpu")
        print("⚠️ WARNING: CUDA not available. Using CPU.")
    X0, _, _, _ = next(iter(train_loader))
    in_channels = X0.shape[1]
    model = TinyCNN(in_ch=in_channels)
    model.to(device)
    print(f"✅ Model initialized with {in_channels} input channels.")
    print("-" * 33)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler() if device.type == "cuda" else None
    EPOCHS = 2
    ckpt_path = os.path.join(script_dir, "tinycnn_ckpt.pt")
    loss_history = {"train": [], "val": []}
    def epoch_loop(loader, train=True, epoch_num=0, total_epochs=0):
        model.train(train)
        total_loss, n_batches = 0.0, 0
        mode = "Train" if train else "Val"
        loader_len = len(loader)
        for batch_idx, (X, y, mask, metas) in enumerate(loader):
            non_blocking = (device.type == "cuda")
            X = X.to(device, non_blocking=non_blocking)
            y = y.to(device, non_blocking=non_blocking)
            mask = mask.to(device, non_blocking=non_blocking)
            with torch.set_grad_enabled(train):
                with autocast(device_type=device.type, enabled=(scaler is not None)):
                    pred = model(X)
                    loss = masked_mse_loss(pred, y, mask)
                if train:
                    opt.zero_grad(set_to_none=True)
                    if scaler is not None:
                        scaler.scale(loss).backward()
                        scaler.step(opt)
                        scaler.update()
                    else:
                        loss.backward()
                        opt.step()
            total_loss += float(loss.detach())
            n_batches += 1
            if (batch_idx + 1) % 20 == 0 or (batch_idx + 1) == loader_len:
                progress = (batch_idx + 1) / loader_len
                print(f"\rEpoch {epoch_num}/{total_epochs} [{mode}] | Batch {batch_idx+1}/{loader_len} ({progress:.0%})", end="")
        print()
        return total_loss / max(n_batches, 1)
    print("\n--- Starting Training ---")
    start_time = time.time()
    for epoch in range(1, EPOCHS + 1):
        train_loss = epoch_loop(train_loader, train=True, epoch_num=epoch, total_epochs=EPOCHS)
        loss_history["train"].append(train_loss)
        val_loss = epoch_loop(val_loader, train=False, epoch_num=epoch, total_epochs=EPOCHS)
        loss_history["val"].append(val_loss)
        print(f"Epoch {epoch}/{EPOCHS} Summary | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")
        print("-" * 25)
    total_time = time.time() - start_time
    print(f"\n--- ✅ Training Complete in {total_time:.2f}s ---")
    torch.save({"model": model.state_dict(), "in_channels": in_channels}, ckpt_path)
    print(f"💾 Saved final model checkpoint to: {ckpt_path}")
    print("\n--- Loss Progression ---")
    print("Epoch | Train Loss | Val Loss")
    print("----------------------------")
    for i in range(EPOCHS):
        print(f"{i+1:^5} | {loss_history['train'][i]:<10.6f} | {loss_history['val'][i]:<10.6f}")
    print("----------------------------")


--- Configuration Loaded ---
{
  "splits": {
    "train_years": [
      1985,
      1986,
      1987,
      1988,
      1989,
      1990,
      1991,
      1992,
      1993,
      1994,
      1995,
      1996,
      1997,
      1998,
      1999,
      2000,
      2001,
      2002,
      2003,
      2004,
      2005,
      2006,
      2007,
      2008,
      2009
    ],
    "val_years": [
      2010,
      2011,
      2012,
      2013,
      2014,
      2015,
      2016,
      2017
    ]
  },
  "data": {
    "era5_mode": "raw",
    "lai_mode": "raw",
    "paths": {
      "era5_root": "/ptmp/mp002/ellis/lai",
      "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",
      "lai_root": "/ptmp/mp002/ellis/lai/lai",
      "lai_tmpl": "LAI.1440.720.{year}.nc"
    },
    "engine": "netcdf4"
  },
  "training": {
    "seq_len_samples": 48,
    "seq_stride": 24,
    "feature_layout": "time_channels",
    "target_mode": "last",
    "batch_size": 1,
    "shuffle": true
  }
}
-----------------------