diff --git a/.gitignore b/.gitignore
index 597dc5c0..7b2841d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,7 +30,6 @@ pyrealm.egg-info
 # Data
 pyrealm_build_data/inputs_data_24.25.nc
-pyrealm_build_data/eda.py
 
 # Profiling
 prof/
 
diff --git a/poetry.lock b/poetry.lock
index b7876eff..a0f572a7 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3431,4 +3431,4 @@ xarray = [
 zipp = [
     {file = "zipp-3.18.1-py3-none-any.whl", hash = "sha256:206f5a15f2af3dbaee80769fb7dc6f249695e940acca08dfb2a4769fe61e538b"},
     {file = "zipp-3.18.1.tar.gz", hash = "sha256:2884ed22e7d8961de1c9a05142eb69a247f120291bc0206a00a7642f09b5b715"},
-]
+]
\ No newline at end of file
diff --git a/pyrealm_build_data/data_model.nc b/pyrealm_build_data/data_model.nc
new file mode 100644
index 00000000..27e46de6
Binary files /dev/null and b/pyrealm_build_data/data_model.nc differ
diff --git a/pyrealm_build_data/synth_data.py b/pyrealm_build_data/synth_data.py
new file mode 100644
index 00000000..4aff6334
--- /dev/null
+++ b/pyrealm_build_data/synth_data.py
@@ -0,0 +1,119 @@
+"""This script uses a parametrized model to compress the input dataset.
+
+It fits a time series model to the input data and stores the model parameters.
+The dataset can then be reconstructed from the model parameters using the
+`reconstruct` function, provided with a custom time index.
+"""
+from typing import Sequence
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+
+def make_time_features(t: Sequence[float]) -> pd.DataFrame:
+    """Make time features for a given time index.
+
+    The model can be written as
+    g(t) = a₀ + a₁ t + ∑_i [b_i sin(2π f_i t) + c_i cos(2π f_i t)],
+    where t is the time in years since 2000-01-01, f_i are the frequencies, and
+    a₀, a₁, b_i, c_i are the model parameters.
+
+    Args:
+        t: An array of datetime values.
+    """
+    dt = pd.to_datetime(t).rename("time")
+    df = pd.DataFrame(index=dt).assign(const=1.0)
+
+    df["linear"] = (dt - pd.Timestamp("2000-01-01")) / pd.Timedelta("365.25d")
+
+    # Frequencies in cycles per year, from twice daily (730.5) to six-yearly (1/6).
+    for f in [730.5, 365.25, 12, 6, 4, 3, 2, 1, 1 / 2, 1 / 3, 1 / 4, 1 / 6]:
+        df[f"freq_{f:.2f}_sin"] = np.sin(2 * np.pi * f * df["linear"])
+        df[f"freq_{f:.2f}_cos"] = np.cos(2 * np.pi * f * df["linear"])
+
+    return df
+
+
+def fit_ts_model(da: xr.DataArray, fs: pd.DataFrame) -> xr.DataArray:
+    """Fit a time series model to the data.
+
+    Args:
+        da: A DataArray with the input data.
+        fs: A DataFrame with the time features.
+    """
+    print("Fitting", da.name)
+
+    da = da.isel(time=slice(None, None, 4))  # downsample along time
+    da = da.dropna("time", how="all")
+    da = da.fillna(da.mean("time"))
+    df = da.to_series().unstack("time").T  # rows: times, columns: locations
+
+    Y = df.values  # (times, locs)
+    X = fs.loc[df.index].values  # (times, feats)
+    A, res, *_ = np.linalg.lstsq(X, Y, rcond=None)  # (feats, locs)
+
+    loss = np.mean(res) / len(X) / np.var(Y)  # mean squared residual / variance
+    pars = pd.DataFrame(A.T, index=df.columns, columns=fs.columns)
+
+    print("Loss:", loss)
+    return pars.to_xarray().to_dataarray()
+
+
+def reconstruct(
+    ds: xr.Dataset, dt: Sequence[float], bounds: dict | None = None
+) -> xr.Dataset:
+    """Reconstruct the full dataset from the model parameters.
+
+    Args:
+        ds: A Dataset with the model parameters.
+        dt: An array of datetime values.
+        bounds: A dictionary with bounds for the reconstructed variables.
+            Defaults to physically plausible ranges for each variable.
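+
+    Example:
+        A minimal sketch mirroring the regression test, assuming the model
+        file has already been written by running this script:
+
+        >>> model = xr.open_dataset("pyrealm_build_data/data_model.nc")
+        >>> times = xr.date_range("2012-01-01", "2013-01-01", freq="12h")
+        >>> synthetic = reconstruct(model, times)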
+ """ + if bounds is None: + bounds = dict( + temp=(-25, 80), + patm=(3e4, 11e4), + vpd=(0, 1e4), + co2=(0, 1e3), + fapar=(0, 1), + ppfd=(0, 1e4), + ) + x = make_time_features(dt).to_xarray().to_dataarray() + ds = xr.Dataset({k: a @ x for k, a in ds.items()}) + ds = xr.Dataset({k: a.clip(*bounds[k]) for k, a in ds.items()}) + return ds + + +if __name__ == "__main__": + ds = xr.open_dataset("pyrealm_build_data/inputs_data_24.25.nc") + + # drop locations with all NaNs (for any variable) + mask = ~ds.isnull().all("time").to_dataarray().any("variable") + ds = ds.where(mask, drop=True) + + special_time_features = dict( + patm=["const"], + co2=["const", "linear"], + ) + + features = make_time_features(ds.time) + + model = xr.Dataset() + for k in ds.data_vars: + cols = special_time_features.get(k, features.columns) + model[k] = fit_ts_model(ds[k], features[cols]) + + model = model.fillna(0.0) + model.to_netcdf("pyrealm_build_data/data_model.nc") diff --git a/tests/regression/data/test_synth_data.py b/tests/regression/data/test_synth_data.py new file mode 100644 index 00000000..6e582689 --- /dev/null +++ b/tests/regression/data/test_synth_data.py @@ -0,0 +1,46 @@ +"""Test the quality of the synthetic data generated from the model parameters.""" + +import numpy as np +import pytest +import xarray as xr + +try: + DATASET = xr.open_dataset("pyrealm_build_data/inputs_data_24.25.nc") + VARS = DATASET.data_vars +except ValueError: + pytest.skip("Original LFS dataset not checked out.", allow_module_level=True) + + +def r2_score(y_true: xr.DataArray, y_pred: xr.DataArray) -> float: + """Compute the R2 score.""" + SSE = ((y_true - y_pred) ** 2).sum() + SST = ((y_true - y_true.mean()) ** 2).sum() + return 1 - SSE / SST + + +@pytest.fixture +def syndata(modelpath="pyrealm_build_data/data_model.nc"): + """The synthetic dataset.""" + from pyrealm_build_data.synth_data import reconstruct + + model = xr.open_dataset(modelpath) + ts = xr.date_range("2012-01-01", "2018-01-01", freq="12h") + return reconstruct(model, ts) + + +@pytest.fixture +def dataset(syndata): + """The original dataset.""" + return DATASET.sel(time=syndata.time) + + +@pytest.mark.parametrize("var", VARS) +def test_synth_data_quality(dataset, syndata, var): + """Test the quality of the synthetic data.""" + times = syndata.time[np.random.choice(syndata.time.size, 1000, replace=False)] + lats = syndata.lat[np.random.choice(syndata.lat.size, 100, replace=False)] + t = dataset[var].sel(lat=lats, time=times) + p = syndata[var].sel(lat=lats, time=times) + s = r2_score(t, p) + print(f"R2 score for {var} is {s:.2f}") + assert s > 0.85