In [None]:
from pysr import PySRRegressor

# Search space: two binary operators plus a few unary ones. `inv` is a custom
# operator written in Julia syntax, so SymPy also needs a Python definition of
# it to convert discovered equations.
binary_ops = ["+", "*"]
unary_ops = ["cos", "exp", "sin", "inv(x) = 1/x"]
sympy_mappings = {"inv": lambda x: 1 / x}

# Custom per-sample loss, also in Julia syntax: the squared error.
squared_error_loss = "loss(prediction, target) = (prediction - target)^2"

model = PySRRegressor(
    maxsize=20,
    niterations=40,  # raise this for better results
    binary_operators=binary_ops,
    unary_operators=unary_ops,
    extra_sympy_mappings=sympy_mappings,
    elementwise_loss=squared_error_loss,
)

Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython


In [3]:
import numpy as np

# Fixed seed so "Restart & Run All" regenerates identical synthetic data and
# therefore comparable PySR runs.
SEED = 0
rng = np.random.default_rng(SEED)

# 100 samples of 5 features drawn from N(0, 2^2); only x0 and x3 enter the target.
X = 2 * rng.standard_normal((100, 5))
# Ground-truth relationship the symbolic regression should rediscover.
y = 2.5382 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 0.5

In [4]:
# Run the symbolic-regression search on the synthetic data. The first call
# compiles the Julia backend, and the hall of fame is saved under an
# `outputs/<run>/` directory (see the log below).
model.fit(X, y)

Compiling Julia backend...


[ Info: Started!



Expressions evaluated per second: 1.020e+02
Progress: 1 / 1240 total iterations (0.081%)
════════════════════════════════════════════════════════════════════════════════════════════════════
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
2           9.276e+01  0.000e+00  y = inv(-0.19246)
3           3.045e+01  1.114e+00  y = inv(cos(-0.19246))
4           2.526e+01  1.869e-01  y = cos(x₃) + 2.2271
7           2.485e+01  5.436e-03  y = cos(x₃) + (2.2271 + cos(0.97859))
8           2.361e+01  5.141e-02  y = sin((x₀ + -5.0124) * -0.88503) + 2.0254
12          1.924e+01  5.117e-02  y = (cos(inv(x₀)) + 2.9442) + sin((x₀ + 10.257) * -0.75884...
                                      )
───────────────────────────────────────────────────────────────────────────────────────────────────
════════════════════════════════════════════════════════════════════════════════════════════════════
Press 'q' and t

[ Info: Final population:
[ Info: Results saved to:


───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
1           2.682e+01  0.000e+00  y = 2.9243
3           3.098e+00  1.079e+00  y = x₀ * x₀
5           2.963e+00  2.234e-02  y = (x₀ * x₀) + -0.36806
6           1.265e+00  8.515e-01  y = (x₀ * x₀) + cos(x₃)
8           2.485e-01  8.135e-01  y = (x₀ * x₀) + (cos(x₃) * 2.482)
10          9.876e-14  1.428e+01  y = (cos(x₃) * 2.5382) + ((x₀ * x₀) + -0.5)
12          7.076e-14  1.667e-01  y = ((cos(x₃) * 2.5382) + -0.15571) + ((x₀ * x₀) + -0.3442...
                                      9)
15          6.818e-14  1.238e-02  y = (((x₀ * x₀) + 0.23184) + ((cos(x₃) + -0.47577) * 1.538...
                                      2)) + cos(x₃)
17          5.726e-14  8.730e-02  y = ((cos(x₃) + -0.47577) * 1.5382) + ((cos(x₃) + 0.13098)...
                                       + ((x₀ * x₀) + 0.10085))
──────────────────────────────────────────────────────────

0,1,2
,model_selection,'best'
,binary_operators,"['+', '*']"
,unary_operators,"['cos', 'exp', ...]"
,expression_spec,
,niterations,40
,populations,31
,population_size,27
,max_evals,
,maxsize,20
,maxdepth,


  - outputs/20250902_175247_HikCt6/hall_of_fame.csv


In [5]:
# Evaluate the selected equation on the training inputs (left as the last
# expression so the array is displayed).
model.predict(X)

array([ 4.47729852,  2.78642053,  1.74733907, -1.9312365 , -1.57530461,
       -1.24139381, 13.0107179 ,  3.08724434, -2.72198741,  2.0981466 ,
       -0.21824326,  1.29639098,  1.29659926,  2.19508346,  4.98519675,
       -2.19480198, 13.12935384,  4.95843753,  0.50747679, -1.82470013,
        3.44137442,  0.65295236,  4.0651504 ,  1.90214316, -1.6480471 ,
        0.95748693,  1.53222094,  4.85537858,  3.79724522, -2.75796795,
        2.06282375,  1.06084474, -1.59026745,  5.86382525,  1.32769494,
        2.40792131,  0.198897  ,  7.98189258,  2.02076533, -2.07578854,
       -0.58318002, 29.0451265 , 11.085375  , -0.58234741,  1.85960329,
       -2.10544314, 13.80732442,  2.24209418, 12.71833664,  1.88399864,
       -1.69675951,  3.37415439,  9.87943115, -2.60027233,  5.48618573,
        1.83716347,  4.26441318,  2.03607359,  0.55870556, -2.7350578 ,
       -0.05600902,  0.22756228,  3.84701081, -0.65990045,  2.2479534 ,
       14.63906183,  0.37489241,  0.28753309, 10.59552612, 13.94

In [6]:
# Text summary of the hall of fame: one row per complexity with its score,
# equation and loss; `>>>>` in the `pick` column marks the selected equation.
print(model)

PySRRegressor.equations_ = [
	   pick      score                                           equation  \
	0         0.000000                                          2.9243407   
	1         1.079113                                            x0 * x0   
	2         0.022341                            (x0 * x0) + -0.36806437   
	3         0.851465                                (x0 * x0) + cos(x3)   
	4         0.813458                  (x0 * x0) + (cos(x3) * 2.4819562)   
	5        14.276945         (cos(x3) * 2.5382001) + ((x0 * x0) + -0.5)   
	6  >>>>   0.166723  ((cos(x3) * 2.5382001) + -0.15570687) + ((x0 *...   
	7         0.012382  (((x0 * x0) + 0.23183653) + ((cos(x3) + -0.475...   
	8         0.087297  ((cos(x3) + -0.47577462) * 1.5382) + ((cos(x3)...   
	
	           loss  complexity  
	0  2.681955e+01           1  
	1  3.098457e+00           3  
	2  2.963057e+00           5  
	3  1.264601e+00           6  
	4  2.485377e-01           8  
	5  9.876336e-14          10  
	6  7.075909

In [None]:
#!/usr/bin/env python3
"""
pysr_lai.py

Small PySR symbolic regression workflow to predict LAI from ssrd, t2m, tp.
Usage:
    python pysr_lai.py --lai /path/to/lai.nc --ssrd /path/to/ssrd.nc --t2m /path/to/t2m.nc --tp /path/to/tp.nc
"""

import argparse
import os
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from pysr import PySRRegressor  # make sure pysr is installed

def open_primary_var(path):
    """Open a NetCDF file and return its primary data variable, loaded into memory.

    The first data variable is used; a warning is printed when the file holds
    more than one.

    Parameters
    ----------
    path : str
        Path to a file readable by ``xarray.open_dataset``.

    Returns
    -------
    xarray.DataArray
        The selected variable with its values loaded into memory.

    Raises
    ------
    ValueError
        If the file contains no data variables.
    """
    # Context manager releases the file handle; .load() runs before the close,
    # so the returned array no longer depends on the file.
    with xr.open_dataset(path) as ds:
        data_vars = list(ds.data_vars)
        if not data_vars:
            raise ValueError(f"No data variables found in {path}")
        if len(data_vars) > 1:
            print(f"Warning: {path} contains multiple data variables. Using '{data_vars[0]}' by default.")
        # NOTE: the full variable is loaded before any subsampling; large files
        # may need chunked/lazy handling instead.
        return ds[data_vars[0]].load()

def _spatial_dim(da, candidates, fallback_pos):
    """Return the first dim of `da` named in `candidates`, else ``da.dims[fallback_pos]``."""
    for name in candidates:
        if name in da.dims:
            return name
    return da.dims[fallback_pos]


def coarsen_all(arrays_dict, lat_step=10, lon_step=10, time_slice=None):
    """Thin every array on the same index stride in latitude/longitude, optionally restricting time.

    This is index-based subsampling (every ``lat_step``-th / ``lon_step``-th
    grid point is kept); no averaging is performed.

    The spatial dimensions are looked up by name ("latitude"/"lat" and
    "longitude"/"lon"). When neither name is present, the last two dimensions
    are used, matching ``stack_to_samples``.

    Parameters
    ----------
    arrays_dict : dict[str, xarray.DataArray]
        Arrays to subsample, keyed by name.
    lat_step, lon_step : int
        Stride along the latitude / longitude dimension.
    time_slice : slice or list of int, optional
        Passed unchanged as the ``time`` indexer to ``.isel``; ``None`` keeps
        all timesteps.

    Returns
    -------
    dict[str, xarray.DataArray]
        Subsampled arrays under the same keys.
    """
    out = {}
    for key, da in arrays_dict.items():
        lat_dim = _spatial_dim(da, ("latitude", "lat"), -2)
        lon_dim = _spatial_dim(da, ("longitude", "lon"), -1)
        indexers = {
            lat_dim: slice(None, None, lat_step),
            lon_dim: slice(None, None, lon_step),
        }
        if time_slice is not None:
            indexers["time"] = time_slice
        out[key] = da.isel(indexers)
    return out

def stack_to_samples(da):
    """Flatten ``time`` and the two trailing (spatial) dims of `da` into one ``sample`` dim.

    Returns
    -------
    tuple
        ``(values, stacked)``. ``stacked`` is the stacked DataArray, which keeps
        its MultiIndex so it can be unstacked later. ``values`` is the
        underlying array of ``stacked``.
    """
    spatial_dims = (da.dims[-2], da.dims[-1])
    stacked = da.stack(sample=("time",) + spatial_dims)
    values = stacked.values
    return values, stacked

def _build_pysr_model(args):
    """Construct the PySRRegressor from parsed CLI options.

    Keyword names follow the current PySR API (``population_size``,
    ``ncycles_per_iteration``, ``elementwise_loss``, ``timeout_in_seconds``,
    ``output_directory``, ``parallelism``/``procs``). Recent PySR releases do
    not accept the spellings used previously here (``populationsize``,
    ``timeout``, ...).
    """
    return PySRRegressor(
        niterations=args.niterations,
        binary_operators=["+", "-", "*", "/"],
        unary_operators=["sin", "cos", "exp", "log", "sqrt"],
        population_size=args.population_size,
        maxsize=args.maxsize,
        ncycles_per_iteration=60,
        # Mean-absolute-error loss (LossFunctions.jl syntax); use
        # "L2DistLoss()" for squared error.
        elementwise_loss="L1DistLoss()",
        model_selection="best",
        timeout_in_seconds=args.timeout,  # seconds (None -> no timeout)
        output_directory=args.output_dir,  # PySR run files (hall of fame) go here
        parallelism="multithreading",
        # NOTE(review): with multithreading the Julia thread count is fixed when
        # Julia starts (e.g. PYTHON_JULIACALL_THREADS); confirm how `procs` is
        # used by the installed PySR version.
        procs=args.n_jobs,
        verbosity=1,
        progress=args.show_progress,
    )


def _r2_score(y_true, y_pred):
    """Coefficient of determination ``1 - SS_res / SS_tot`` (NaN if `y_true` is constant)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float("nan") if ss_tot == 0 else float(1.0 - ss_res / ss_tot)


def _predict_full_grid(model, scaler, X, mask, y_stack):
    """Predict every valid sample and unstack the result back onto the (time, lat, lon) grid.

    Rows where `mask` is False are filled with NaN. `y_stack` is the stacked
    LAI DataArray; its ``sample`` MultiIndex is reused so ``unstack`` restores
    the original grid.
    """
    y_full = np.full(X.shape[0], np.nan, dtype=float)
    y_full[mask] = model.predict(scaler.transform(X[mask]))
    # copy(data=...) keeps y_stack's coordinates (including the MultiIndex).
    pred_da = y_stack.copy(data=y_full).unstack("sample")
    pred_da.name = "lai_pred"
    pred_da.attrs = {"note": "Predicted LAI from PySR symbolic regression"}
    return pred_da


def _plot_obs_vs_pred(y_obs, y_pred, out_path, seed):
    """Save a scatter of observed vs predicted test LAI (at most 5000 random points)."""
    rng = np.random.default_rng(seed)
    nplot = min(5000, len(y_obs))
    sel = rng.choice(len(y_obs), size=nplot, replace=False)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(y_obs[sel], y_pred[sel], s=2, alpha=0.6)
    ax.set_xlabel("observed LAI (test)")
    ax.set_ylabel("predicted LAI")
    ax.set_title("PySR LAI: observed vs predicted (test set)")
    ax.grid(True)
    fig.savefig(out_path, dpi=150)
    plt.close(fig)


def main(args):
    """Fit a PySR symbolic-regression model that predicts LAI from ssrd, t2m and tp.

    Steps:

    1. Open the four inputs.
    2. Thin the grid.
    3. Flatten to samples.
    4. Drop non-finite rows.
    5. Optionally subsample.
    6. Standardize the features.
    7. Split into train/test sets.
    8. Fit PySR.
    9. Report test R^2.

    Outputs written to ``args.output_dir``:

    * ``pysr_equations.csv``   -- the PySR equations table
    * ``lai_pred_pysr.nc``     -- predictions on the thinned grid
    * ``obs_vs_pred_test.png`` -- observed-vs-predicted scatter (test set)

    Equations are expressed in terms of the *standardized* features.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options of this script.

    Raises
    ------
    ValueError
        If a predictor's stacked sample count differs from LAI's, or no sample
        is finite in every variable.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    # 1. Open datasets (first data variable of each file)
    print("Opening datasets...")
    lai_da = open_primary_var(args.lai)
    ssrd_da = open_primary_var(args.ssrd)
    t2m_da = open_primary_var(args.t2m)
    tp_da = open_primary_var(args.tp)
    print("LAI variable dims:", lai_da.dims, "shape:", lai_da.shape)

    # 2. Spatial thinning and optional time restriction to bound memory/runtime
    print("Coarsening with lat_step=", args.lat_step, " lon_step=", args.lon_step)
    arrays = coarsen_all(
        {"lai": lai_da, "ssrd": ssrd_da, "t2m": t2m_da, "tp": tp_da},
        lat_step=args.lat_step,
        lon_step=args.lon_step,
        time_slice=args.time_slice,
    )

    # 3. Flatten (time, lat, lon) -> samples; one feature column per predictor
    print("Stacking arrays to samples...")
    y_vals, y_stack = stack_to_samples(arrays["lai"])
    names = ["ssrd", "t2m", "tp"]
    columns = []
    for name in names:
        vals, _ = stack_to_samples(arrays[name])
        if vals.shape != y_vals.shape:
            raise ValueError(
                f"'{name}' stacks to shape {vals.shape} but LAI stacks to {y_vals.shape}; "
                "all inputs must share the same (time, lat, lon) grid."
            )
        columns.append(vals)
    X = np.column_stack(columns)  # shape (n_samples, 3)
    print("X shape:", X.shape, "y shape:", y_vals.shape)

    # 4. Drop rows with NaN (or +/-inf) in any variable
    print("Cleaning NaNs...")
    mask = np.isfinite(X).all(axis=1) & np.isfinite(y_vals)
    n_valid = int(mask.sum())
    print(f"Samples before: {X.shape[0]}, after removing NaNs: {n_valid}")
    if n_valid == 0:
        raise ValueError("No samples are finite in all variables; check grid alignment and --time-slice.")
    X_clean = X[mask]
    y_clean = y_vals[mask]

    # 5. Optionally subsample for PySR speed
    if args.max_samples and n_valid > args.max_samples:
        rng = np.random.default_rng(args.random_seed)
        idx = rng.choice(n_valid, size=args.max_samples, replace=False)
        X_clean = X_clean[idx]
        y_clean = y_clean[idx]
        print(f"Subsampled to {args.max_samples} samples for PySR speed.")

    # 6. Standardize features (helps numeric stability of the search)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_clean)

    # 7. Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X_scaled, y_clean, test_size=args.test_size, random_state=args.random_seed
    )
    print("Train/test sizes:", X_train.shape[0], X_test.shape[0])

    # 8. Run PySR
    print("Running PySR symbolic regression...")
    model = _build_pysr_model(args)
    # PySR's fit() names features via `variable_names` (it has no `feature_names`).
    model.fit(X_train, y_train, variable_names=names)
    print("PySR finished. Equations:")
    print(model)

    # 9. Evaluate on the held-out test set (true coefficient of determination)
    y_pred_test = model.predict(X_test)
    r2_test = _r2_score(y_test, y_pred_test)
    print(f"R^2 on test set: {r2_test:.4f}")

    # 10. Save the equations table
    eqs_path = os.path.join(args.output_dir, "pysr_equations.csv")
    model.equations_.to_csv(eqs_path, index=False)
    print("Saved equations to:", eqs_path)

    # 11. Predict every valid sample and write it back onto the grid
    print("Predicting full grid (where data available)...")
    pred_da = _predict_full_grid(model, scaler, X, mask, y_stack)
    pred_nc_path = os.path.join(args.output_dir, "lai_pred_pysr.nc")
    pred_da.to_dataset().to_netcdf(pred_nc_path)
    print("Saved prediction netCDF to:", pred_nc_path)

    # 12. Diagnostic scatter plot (test set)
    scatter_path = os.path.join(args.output_dir, "obs_vs_pred_test.png")
    _plot_obs_vs_pred(y_test, y_pred_test, scatter_path, args.random_seed)
    print("Saved scatter plot to:", scatter_path)

    print("Done. Check output folder:", args.output_dir)


def build_arg_parser():
    """Build the command-line parser for the LAI PySR workflow."""
    p = argparse.ArgumentParser(description="Small PySR workflow to predict LAI from ssrd, t2m, tp")
    p.add_argument("--lai", required=True, help="Path to LAI netcdf")
    p.add_argument("--ssrd", required=True, help="Path to SSRD netcdf")
    p.add_argument("--t2m", required=True, help="Path to T2M netcdf")
    p.add_argument("--tp", required=True, help="Path to TP netcdf")
    p.add_argument("--lat-step", type=int, default=10, help="Spatial subsampling step for latitude (default 10)")
    p.add_argument("--lon-step", type=int, default=10, help="Spatial subsampling step for longitude (default 10)")
    p.add_argument("--time-slice", type=int, nargs='+', default=None,
                   help="Optional time slice indices for .isel(time=...) e.g. --time-slice 0 23 (use first 24 timesteps).")
    p.add_argument("--max-samples", type=int, default=200000, help="Max number of samples to keep for PySR (subsamples randomly).")
    p.add_argument("--niterations", type=int, default=40, help="PySR niterations (increase for better results).")
    p.add_argument("--population-size", type=int, dest="population_size", default=100, help="PySR population size.")
    p.add_argument("--maxsize", type=int, default=20, help="Max expression size for PySR.")
    p.add_argument("--timeout", type=int, default=None, help="Timeout in seconds for PySR (optional).")
    p.add_argument("--n-jobs", type=int, default=4, dest="n_jobs", help="Number of threads for PySR.")
    p.add_argument("--test-size", type=float, default=0.2, help="Test set proportion.")
    p.add_argument("--random-seed", type=int, default=0, help="Random seed.")
    p.add_argument("--output-dir", default="pysr_output", help="Directory to store outputs.")
    p.add_argument("--show-progress", action="store_true", help="Show PySR progress if supported.")
    return p


def parse_cli_args(argv=None):
    """Parse CLI arguments and normalize ``time_slice`` for ``.isel(time=...)``.

    Exactly two ``--time-slice`` values ``START STOP`` are an inclusive index
    range and become ``slice(START, STOP + 1)``. Any other count is kept as a
    list of explicit time indices.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse; ``None`` means ``sys.argv[1:]``. An explicit list
        lets the workflow run from a notebook, e.g.
        ``main(parse_cli_args(["--lai", "lai.nc", ...]))``.

    Returns
    -------
    argparse.Namespace
        The parsed options.

    Raises
    ------
    SystemExit
        On invalid arguments, including a reversed ``--time-slice`` range
        (via ``ArgumentParser.error``).
    """
    p = build_arg_parser()
    args = p.parse_args(argv)
    if args.time_slice is not None and len(args.time_slice) == 2:
        start, stop = args.time_slice
        if stop < start:
            # A reversed range would silently select zero timesteps.
            p.error(f"--time-slice STOP ({stop}) must be >= START ({start})")
        args.time_slice = slice(start, stop + 1)
    return args


if __name__ == "__main__":
    # NOTE: inside a Jupyter kernel __name__ is also "__main__", but sys.argv
    # holds the kernel's own arguments. Run this as a script, or call
    # main(parse_cli_args([...])) with explicit arguments.
    args = parse_cli_args()
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)
