In [1]:
# dataset_era5_lai_whole.py
import os
import numpy as np
import xarray as xr
import torch
from torch.utils.data import Dataset

In [2]:
class ERA5LAIWholeWorld(Dataset):
    """
    One YEAR of ERA5 drivers and LAI, exposed as one item per 15-day sample.

    The year is expected to have 24 samples along ``time`` (checked in ``_open_year``).
    Each ``__getitem__`` returns:
        X: (3, H, W) float32 tensor -> [ssrd, t2m, tp]  (raw / anom / z per ``era5_mode``),
           NaNs replaced by 0
        y: (1, H, W) float32 tensor -> LAI (raw or anomaly per ``lai_mode``), NaNs replaced by 0
        mask: (1, H, W) bool tensor, True where the LAI value was not NaN
        meta: dict with 'year' and 'sample'

    Args:
        year: year whose files are opened.
        era5_mode: "raw" reads ``{era5_root}/{var}/{var}.15daily.*.{year}.nc``;
            "anom" / "z" read ``{era5_anom_dir}/{var}_{anom|z}_{year}.nc``.
        lai_mode: "raw" reads ``{lai_root}/{lai_tmpl}``; "anom" reads
            ``{lai_anom_dir}/LAI_anom_{year}.nc`` (variable ``LAI_anom``) and does
            not touch the raw LAI file.
        paths: optional dict; each key overrides the matching default path entry.
        engine: backend passed to ``xr.open_dataset``.
        robust_nan: additionally mask values equal to a ``_FillValue`` attribute.
        sample_indices: which samples (time indices) to expose; default all 0..23.

    Raises:
        FileNotFoundError: ``lai_mode="anom"`` and the LAI anomaly file is missing.
        AssertionError: invalid mode, or arrays without 24 samples / mismatched grids.
    """
    def __init__(
        self,
        year: int,
        era5_mode="anom",   # "raw" | "anom" | "z"
        lai_mode="raw",     # "raw" | "anom"
        paths=None,
        engine="netcdf4",
        robust_nan=True,
        sample_indices=None,  # default: all 0..23
    ):
        assert era5_mode in ("raw", "anom", "z"), f"unknown era5_mode: {era5_mode!r}"
        assert lai_mode in ("raw", "anom"), f"unknown lai_mode: {lai_mode!r}"
        self.year = int(year)
        self.era5_mode = era5_mode
        self.lai_mode = lai_mode
        self.engine = engine
        self.robust_nan = robust_nan
        self.sample_indices = list(range(24)) if sample_indices is None else list(sample_indices)

        # default paths (any key in 'paths' overrides the matching default)
        default_paths = {
            "era5_root": "/ptmp/mp002/ellis/lai",
            "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",   # your anomalies/z-scores
            "lai_root": "/ptmp/mp002/ellis/lai/lai",
            "lai_tmpl": "LAI.1440.720.{year}.nc",
            "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",         # change if different
        }
        self.paths = default_paths if paths is None else {**default_paths, **paths}

        # open all arrays for that year
        self._open_year()

    @staticmethod
    def _coord_name(da, short, long):
        """Return ``short`` if it is a coordinate of ``da``, otherwise ``long``."""
        return short if short in da.coords else long

    def _open_da(self, path, varname):
        """Open ``varname`` from the netCDF file at ``path`` and normalise it.

        The underlying dataset is not closed here; the returned DataArray is
        indexed later in ``__getitem__``.
        """
        ds = xr.open_dataset(path, engine=self.engine)
        da = ds[varname]
        if self.robust_nan:
            # xarray's default CF decoding usually masks fill values already and
            # moves _FillValue into .encoding; this only acts if it is still in attrs.
            fv = da.attrs.get("_FillValue", None)
            if fv is not None:
                da = da.where(da != fv)
        # Sort latitude ascending (south -> north) so all arrays share one
        # orientation. NOTE: this is *not* north-up when plotted top-down.
        da = da.sortby(self._coord_name(da, "lat", "latitude"))
        # Attach a 0..23 'sample' coordinate along time (not used for indexing here).
        if "time" in da.dims and da.sizes["time"] == 24:
            da = da.assign_coords(sample=("time", np.arange(24)))
        return da

    def _open_year(self):
        """Open the ERA5 inputs and the LAI target for ``self.year`` and validate them."""
        y = self.year

        # ERA5 inputs
        if self.era5_mode == "raw":
            f_ssrd = os.path.join(self.paths["era5_root"], "ssrd", f"ssrd.15daily.fc.era5.1440.720.{y}.nc")
            f_t2m  = os.path.join(self.paths["era5_root"], "t2m",  f"t2m.15daily.an.era5.1440.720.{y}.nc")
            f_tp   = os.path.join(self.paths["era5_root"], "tp",   f"tp.15daily.fc.era5.1440.720.{y}.nc")
            self.ssrd = self._open_da(f_ssrd, "ssrd")
            self.t2m  = self._open_da(f_t2m,  "t2m")
            self.tp   = self._open_da(f_tp,   "tp")
        else:
            suffix = "anom" if self.era5_mode == "anom" else "z"
            base = self.paths["era5_anom_dir"]
            self.ssrd = self._open_da(os.path.join(base, f"ssrd_{suffix}_{y}.nc"), f"ssrd_{suffix}")
            self.t2m  = self._open_da(os.path.join(base, f"t2m_{suffix}_{y}.nc"),  f"t2m_{suffix}")
            self.tp   = self._open_da(os.path.join(base, f"tp_{suffix}_{y}.nc"),   f"tp_{suffix}")

        # LAI target: open only the file that is actually used.
        if self.lai_mode == "anom":
            lai_anom_dir = self.paths.get("lai_anom_dir", self.paths["lai_root"])
            lai_anom_file = os.path.join(lai_anom_dir, f"LAI_anom_{y}.nc")
            if not os.path.exists(lai_anom_file):
                raise FileNotFoundError(f"LAI anomaly file not found: {lai_anom_file}")
            self.lai = self._open_da(lai_anom_file, "LAI_anom")
        else:
            lai_file = os.path.join(self.paths["lai_root"], self.paths["lai_tmpl"].format(year=y))
            self.lai = self._open_da(lai_file, self._infer_var(lai_file))

        # dims/coords of the target (kept as public attributes)
        self.lat_name = self._coord_name(self.lai, "lat", "latitude")
        self.lon_name = self._coord_name(self.lai, "lon", "longitude")
        lai_grid = (self.lai.sizes[self.lat_name], self.lai.sizes[self.lon_name])

        # sanity checks; ERA5 and LAI files may name their spatial coords differently
        for da, name in ((self.ssrd, "ssrd"), (self.t2m, "t2m"), (self.tp, "tp")):
            assert "time" in da.dims and da.sizes["time"] == 24, f"{name} expects 24 samples"
            grid = (
                da.sizes[self._coord_name(da, "lat", "latitude")],
                da.sizes[self._coord_name(da, "lon", "longitude")],
            )
            assert grid == lai_grid, f"{name} grid {grid} does not match LAI grid {lai_grid}"
        assert "time" in self.lai.dims and self.lai.sizes["time"] == 24, "LAI expects 24 samples"

    def _infer_var(self, nc_path):
        """Return the name of the first data variable in ``nc_path`` with at least 2 dims."""
        with xr.open_dataset(nc_path, engine=self.engine) as ds:
            for v in ds.data_vars:
                if ds[v].ndim >= 2:
                    return v
        raise RuntimeError(f"No data variable in {nc_path}")

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, i):
        s = self.sample_indices[i]
        # features: one slice per variable at time index s
        x_ssrd = self.ssrd.isel(time=s).values
        x_t2m = self.t2m.isel(time=s).values
        x_tp = self.tp.isel(time=s).values
        # target
        y_lai = self.lai.isel(time=s).values  # may contain NaNs (e.g. over ocean)

        # stack channel-first
        X = np.stack([x_ssrd, x_t2m, x_tp], axis=0)   # (3, H, W)
        y = np.expand_dims(y_lai, axis=0)            # (1, H, W)
        mask = ~np.isnan(y)                          # valid target pixels

        # features: NaN -> 0 (0 may not be a neutral value for raw, non-anomaly inputs)
        X = np.nan_to_num(X, nan=0.0)

        X = torch.from_numpy(X.astype(np.float32))
        y = torch.from_numpy(np.nan_to_num(y, nan=0.0).astype(np.float32))
        mask = torch.from_numpy(mask.astype(np.bool_))
        meta = {"year": self.year, "sample": int(s)}
        return X, y, mask, meta

# Loss helper (same as before)
def masked_mse_loss(pred, target, mask):
    """Mean squared error restricted to positions where ``mask`` is True.

    ``mask`` is converted with ``.float()`` (so a torch tensor is expected).
    Masked-out positions contribute nothing to the numerator; the denominator
    is the number of valid positions, floored at 1 so an all-False mask
    yields 0 instead of dividing by zero.
    """
    weights = mask.float()
    n_valid = weights.sum().clamp_min(1.0)
    weighted_sq_err = (pred - target) ** 2 * weights
    return weighted_sq_err.sum() / n_valid

def collate_keep_meta(batch):
    """Collate ``(X, y, mask, meta)`` samples into a batch.

    The three tensors are stacked along a new leading batch dimension; the
    ``meta`` dicts are returned untouched as a plain list (the default
    collate would merge them into a dict of tensors).
    """
    features, targets, masks, metas = zip(*batch)
    stacked = tuple(torch.stack(part) for part in (features, targets, masks))
    return stacked + (list(metas),)

In [3]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch

# Dataset for a single year: raw ERA5 drivers and raw LAI target.
ds = ERA5LAIWholeWorld(
    1990,
    paths={
        "era5_root": "/ptmp/mp002/ellis/lai",
        "era5_anom_dir": "/ptmp/mp040/outputdir/era5/anom",
        "lai_root": "/ptmp/mp002/ellis/lai/lai",
        "lai_tmpl": "LAI.1440.720.{year}.nc",
        # "lai_anom_dir": "/ptmp/mp002/ellis/lai/anom",  # only needed with lai_mode="anom"
    },
    era5_mode="raw",   # alternatives: "anom" | "z"
    lai_mode="raw",    # alternative: "anom"
)

# Every item is a whole global field, so load exactly one item per batch.
loader = DataLoader(
    ds,
    shuffle=False,
    batch_size=1,
)

In [4]:
from pysr import PySRRegressor

# Symbolic-regression search configuration.
model = PySRRegressor(
    niterations=40,   # raise for a longer search and usually better equations
    maxsize=20,       # maximum expression complexity
    binary_operators=["+", "*"],
    # "inv" is a custom operator written in Julia syntax.
    unary_operators=["cos", "exp", "sin", "inv(x) = 1/x"],
    # SymPy counterpart of the custom "inv" operator.
    extra_sympy_mappings={"inv": lambda x: 1 / x},
    # Per-point squared-error loss, also in Julia syntax.
    elementwise_loss="loss(prediction, target) = (prediction - target)^2",
)



Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


In [None]:
# Fit PySR on a random subsample of valid pixels from the first usable global sample.
for X, y, mask, meta in loader:
    X = X.squeeze(0).numpy()        # (3, H, W)
    y = y.squeeze(0).numpy()        # (1, H, W)
    mask = mask.squeeze(0).numpy()  # (1, H, W)

    # Flatten to one row per pixel
    X_flat = X.reshape(X.shape[0], -1).T   # (n_pixels, 3)
    y_flat = y.reshape(-1)                 # (n_pixels,)
    mask_flat = mask.reshape(-1)           # (n_pixels,)

    # Keep only pixels with a valid target
    X_flat = X_flat[mask_flat]
    y_flat = y_flat[mask_flat]

    # Random subsample (at most 20,000 pixels); need >= 2 for a train/val split
    n_samples = min(20000, len(y_flat))
    if n_samples < 2:
        print(f"year={meta['year'][0].item()}, sample={meta['sample'][0].item()}: "
              f"only {n_samples} valid pixel(s), skipping")
        continue
    # choice(..., replace=False) returns indices in random order, so the
    # contiguous 80/20 split below is itself a random split.
    idx = np.random.choice(len(y_flat), n_samples, replace=False)
    X_flat = X_flat[idx]
    y_flat = y_flat[idx]

    # Train/val split
    split = int(0.8 * n_samples)
    X_train, X_val = X_flat[:split], X_flat[split:]
    y_train, y_val = y_flat[:split], y_flat[split:]

    # Fit PySR
    model.fit(X_train, y_train)

    print("PySR finished. Discovered equations:")
    print(model)

    # Validation MSE. Pixels were already filtered to valid targets, so no mask
    # is needed here (masked_mse_loss expects torch tensors and a mask shaped
    # like pred/target, not the full-image mask).
    y_pred_val = model.predict(X_val)
    loss = float(np.mean((np.asarray(y_pred_val, dtype=np.float64) - y_val) ** 2))
    print(f"year={meta['year'][0].item()}, sample={meta['sample'][0].item()}, loss={loss:.4f}")
    break

Compiling Julia backend...
[ Info: Note: you are running with more than 10,000 datapoints. You should consider turning on batching (`options.batching`), and also if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form.
[ Info: Started!



Expressions evaluated per second: 0.000e+00
Progress: 0 / 1240 total iterations (0.000%)
════════════════════════════════════════════════════════════════════════════════════════════════════
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
───────────────────────────────────────────────────────────────────────────────────────────────────
════════════════════════════════════════════════════════════════════════════════════════════════════
Press 'q' and then <enter> to stop execution early.

Expressions evaluated per second: 2.310e+00
Progress: 1 / 1240 total iterations (0.081%)
════════════════════════════════════════════════════════════════════════════════════════════════════
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
1           2.170e+00  0.000e+00  y = 1.1143
──────────────────────────────────

In [None]:
import xarray as xr
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
from pysr import PySRRegressor

# ----------------------
# 1. Load data
# ----------------------
# ds_lai / ds_ssrd / ds_t2m / ds_tp are not defined in any visible cell, so reuse
# the arrays already opened by the ERA5LAIWholeWorld dataset `ds` from the earlier
# cell. Put time first so every array ravels in the same order.
variables = {
    "lai": ds.lai.transpose("time", ...),
    "ssrd": ds.ssrd.transpose("time", ...),
    "t2m": ds.t2m.transpose("time", ...),
    "tp": ds.tp.transpose("time", ...),
}
shapes = {name: da.shape for name, da in variables.items()}
if len(set(shapes.values())) != 1:
    raise ValueError(f"Input arrays have different shapes, cannot pair them pixel-wise: {shapes}")

# Stack spatial/temporal dimensions into samples (one row per time/pixel)
X = np.column_stack([variables[name].values.ravel() for name in ("ssrd", "t2m", "tp")])
y = variables["lai"].values.ravel()

# Remove non-finite rows (NaN and +/-inf)
mask = np.isfinite(X).all(axis=1) & np.isfinite(y)
X, y = X[mask], y[mask]
if len(y) < 2:
    raise ValueError(f"Only {len(y)} valid samples after filtering; cannot split/fit.")

# ----------------------
# 2. Downsample (optional)
# ----------------------
# Randomly select a subset (at most 10k samples) for speed; seeded like the split below
rng = np.random.default_rng(42)
n_samples = min(10000, len(y))
idx = rng.choice(len(y), size=n_samples, replace=False)
X, y = X[idx], y[idx]

# ----------------------
# 3. Train/Validation split
# ----------------------
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# ----------------------
# 4. Define and fit PySR model
# ----------------------
model = PySRRegressor(
    niterations=100,  # increase later
    binary_operators=["+", "-", "*", "/"],
    unary_operators=["log", "exp", "sin", "cos"],
    model_selection="best",  # pick best equation
    progress=True,
    verbosity=1,
)

model.fit(X_train, y_train)

# ----------------------
# 5. Evaluate
# ----------------------
y_val_pred = model.predict(X_val)

r2 = r2_score(y_val, y_val_pred)
rmse = np.sqrt(mean_squared_error(y_val, y_val_pred))

print("\nBest equation found:")
print(model.get_best())

print(f"\nValidation R²: {r2:.3f}, RMSE: {rmse:.3f}")


In [None]:
# #!/usr/bin/env python3
# """
# pysr_lai.py

# Small PySR symbolic regression workflow to predict LAI from ssrd, t2m, tp.
# Usage:
#     python pysr_lai.py --lai /path/to/lai.nc --ssrd /path/to/ssrd.nc --t2m /path/to/t2m.nc --tp /path/to/tp.nc
# """

# import argparse
# import os
# import numpy as np
# import xarray as xr
# import pandas as pd
# import matplotlib.pyplot as plt
# from sklearn.preprocessing import StandardScaler
# from sklearn.model_selection import train_test_split
# from pysr import PySRRegressor  # make sure pysr is installed

# def open_primary_var(path):
#     ds = xr.open_dataset(path)
#     # If dataset contains exactly 1 data variable, return it; otherwise pick the first
#     data_vars = list(ds.data_vars)
#     if len(data_vars) == 0:
#         raise ValueError(f"No data variables found in {path}")
#     if len(data_vars) > 1:
#         print(f"Warning: {path} contains multiple data variables. Using '{data_vars[0]}' by default.")
#     var = data_vars[0]
#     return ds[var].load()  # load into memory (careful with size) - we coarsen later if needed

# def coarsen_all(arrays_dict, lat_step=10, lon_step=10, time_slice=None):
#     """Coarsen all arrays consistently by slicing every lat_step/lon_step and optionally restricting time."""
#     out = {}
#     for k, da in arrays_dict.items():
#         # ensure coordinates named 'latitude' / 'longitude' or accept variants
#         # We'll use index-based slicing to be robust
#         lat_slice = slice(None, None, lat_step)
#         lon_slice = slice(None, None, lon_step)
#         if time_slice is not None:
#             dsliced = da.isel(time=time_slice, latitude=lat_slice, longitude=lon_slice)
#         else:
#             dsliced = da.isel(latitude=lat_slice, longitude=lon_slice)
#         out[k] = dsliced
#     return out

# def stack_to_samples(da):
#     """Stack (time, latitude, longitude) into one axis named 'sample' and return values and coords."""
#     stacked = da.stack(sample=("time", da.dims[-2], da.dims[-1]))
#     return stacked.values, stacked

# def main(args):
#     # 1. Open datasets (auto variable selection)
#     print("Opening datasets...")
#     lai_da = open_primary_var(args.lai)
#     ssrd_da = open_primary_var(args.ssrd)
#     t2m_da = open_primary_var(args.t2m)
#     tp_da = open_primary_var(args.tp)

#     print("LAI variable dims:", lai_da.dims, "shape:", lai_da.shape)
#     # 2. Coarsen / subsample to keep memory reasonable
#     print("Coarsening with lat_step=", args.lat_step, " lon_step=", args.lon_step)
#     arrays = {"lai": lai_da, "ssrd": ssrd_da, "t2m": t2m_da, "tp": tp_da}
#     arrays = coarsen_all(arrays, lat_step=args.lat_step, lon_step=args.lon_step, time_slice=args.time_slice)

#     # 3. Optionally further restrict time range
#     # (time_slice taken in coarsen_all above)

#     # 4. Stack to (samples,)
#     print("Stacking arrays to samples...")
#     y_vals, y_stack = stack_to_samples(arrays["lai"])
#     X_list = []
#     names = []
#     for name in ["ssrd", "t2m", "tp"]:
#         vals, _ = stack_to_samples(arrays[name])
#         X_list.append(vals)
#         names.append(name)
#     X = np.vstack(X_list).T  # shape (n_samples, 3)
#     print("X shape:", X.shape, "y shape:", y_vals.shape)

#     # 5. Clean NaNs
#     print("Cleaning NaNs...")
#     mask = ~np.isnan(X).any(axis=1) & ~np.isnan(y_vals)
#     print(f"Samples before: {X.shape[0]}, after removing NaNs: {mask.sum()}")
#     X_clean = X[mask]
#     y_clean = y_vals[mask]

#     # 6. Optionally subsample for speed
#     if args.max_samples and mask.sum() > args.max_samples:
#         rng = np.random.default_rng(args.random_seed)
#         idx = rng.choice(np.arange(mask.sum()), size=args.max_samples, replace=False)
#         X_clean = X_clean[idx]
#         y_clean = y_clean[idx]
#         print(f"Subsampled to {args.max_samples} samples for PySR speed.")

#     # 7. Standardize features (helps symbolic regression numeric stability)
#     scaler = StandardScaler()
#     X_scaled = scaler.fit_transform(X_clean)

#     # 8. Train/test split
#     X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_clean, test_size=args.test_size, random_state=args.random_seed)
#     print("Train/test sizes:", X_train.shape[0], X_test.shape[0])

#     # 9. Run PySR (small config by default; tune niterations for better results)
#     print("Running PySR symbolic regression...")
#     model = PySRRegressor(
#         niterations=args.niterations,
#         binary_operators=["+", "-", "*", "/"],
#         unary_operators=["sin", "cos", "exp", "log", "sqrt"],
#         populationsize=args.population_size,
#         maxsize=args.maxsize,
#         ncyclesperiteration=60,
#         loss="loss01",  # default: mean absolute error-like; you can change
#         model_selection="best",
#         timeout=args.timeout,  # seconds (None -> no timeout)
#         tempdir=args.output_dir,
#         multithreading=args.n_jobs,
#         verbosity=1,
#         progress=args.show_progress
#     )

#     # Provide feature names for clearer equations
#     feature_names = names
#     model.fit(X_train, y_train, feature_names=feature_names)

#     print("PySR finished. Equations:")
#     print(model)

#     # 10. Evaluate on test set
#     y_pred_test = model.predict(X_test)
#     r2_test = np.corrcoef(y_test, y_pred_test)[0,1]**2  # approx R^2
#     print(f"Approx R^2 on test set (corr-based): {r2_test:.4f}")

#     # 11. Save best equations table
#     eqs_df = model.equations_
#     eqs_path = os.path.join(args.output_dir, "pysr_equations.csv")
#     eqs_df.to_csv(eqs_path, index=False)
#     print("Saved equations to:", eqs_path)

#     # 12. Create predictions for the full stacked sample set (where mask true)
#     print("Predicting full grid (where data available)...")
#     # Scale all non-NaN X to feed into model:
#     X_all = X[mask]  # corresponds to y_clean rows
#     X_all_scaled = scaler.transform(X_all if len(X_all.shape) == 2 else X_all.reshape(-1, X_all.shape[-1]))
#     y_all_pred = model.predict(X_all_scaled)

#     # Reconstruct y_pred into full sample-shaped array (fill with NaNs where mask was False)
#     y_full = np.full(y_vals.shape, np.nan, dtype=float)
#     y_full[mask] = y_all_pred

#     # Put back into DataArray with original stacked coords, then unstack back to (time, lat, lon)
#     pred_da = xr.DataArray(y_full, coords=[y_stack.sample], dims=["sample"])
#     pred_da = pred_da.unstack("sample")
#     pred_da.name = "lai_pred"
#     pred_da.attrs["note"] = "Predicted LAI from PySR symbolic regression"

#     # Save predicted netcdf
#     pred_nc_path = os.path.join(args.output_dir, "lai_pred_pysr.nc")
#     pred_da.to_dataset().to_netcdf(pred_nc_path)
#     print("Saved prediction netCDF to:", pred_nc_path)

#     # 13. Quick diagnostic scatter plot test vs observed (random subset)
#     plt.figure(figsize=(6,6))
#     rng = np.random.default_rng(args.random_seed)
#     nplot = min(5000, len(y_test))
#     sel = rng.choice(len(y_test), size=nplot, replace=False)
#     plt.scatter(y_test[sel], y_pred_test[sel], s=2, alpha=0.6)
#     plt.xlabel("observed LAI (test)")
#     plt.ylabel("predicted LAI")
#     plt.title("PySR LAI: observed vs predicted (test set)")
#     plt.grid(True)
#     scatter_path = os.path.join(args.output_dir, "obs_vs_pred_test.png")
#     plt.savefig(scatter_path, dpi=150)
#     plt.close()
#     print("Saved scatter plot to:", scatter_path)

#     print("Done. Check output folder:", args.output_dir)


# if __name__ == "__main__":
#     p = argparse.ArgumentParser(description="Small PySR workflow to predict LAI from ssrd, t2m, tp")
#     p.add_argument("--lai", required=True, help="Path to LAI netcdf")
#     p.add_argument("--ssrd", required=True, help="Path to SSRD netcdf")
#     p.add_argument("--t2m", required=True, help="Path to T2M netcdf")
#     p.add_argument("--tp", required=True, help="Path to TP netcdf")
#     p.add_argument("--lat-step", type=int, default=10, help="Spatial subsampling step for latitude (default 10)")
#     p.add_argument("--lon-step", type=int, default=10, help="Spatial subsampling step for longitude (default 10)")
#     p.add_argument("--time-slice", type=int, nargs='+', default=None,
#                    help="Optional time slice indices for .isel(time=...) e.g. --time-slice 0 23 (use first 24 timesteps).")
#     p.add_argument("--max-samples", type=int, default=200000, help="Max number of samples to keep for PySR (subsamples randomly).")
#     p.add_argument("--niterations", type=int, default=40, help="PySR niterations (increase for better results).")
#     p.add_argument("--population-size", type=int, dest="population_size", default=100, help="PySR population size.")
#     p.add_argument("--maxsize", type=int, default=20, help="Max expression size for PySR.")
#     p.add_argument("--timeout", type=int, default=None, help="Timeout in seconds for PySR (optional).")
#     p.add_argument("--n-jobs", type=int, default=4, dest="n_jobs", help="Number of threads for PySR.")
#     p.add_argument("--test-size", type=float, default=0.2, help="Test set proportion.")
#     p.add_argument("--random-seed", type=int, default=0, help="Random seed.")
#     p.add_argument("--output-dir", default="pysr_output", help="Directory to store outputs.")
#     p.add_argument("--show-progress", action="store_true", help="Show PySR progress if supported.")
#     args = p.parse_args()

#     # normalize time_slice format for isel usage
#     if args.time_slice is not None:
#         # if user provided two ints like "0 23", we make slice(0,24)
#         if len(args.time_slice) == 2:
#             start, stop = args.time_slice
#             args.time_slice = slice(start, stop + 1)
#         else:
#             # if list of indices, keep as list
#             args.time_slice = args.time_slice

#     os.makedirs(args.output_dir, exist_ok=True)
#     main(args)
