In [None]:
import xarray as xr
import numpy as np
import sys
from pathlib import Path
import matplotlib.pyplot as plt

# Make the project's `src` package importable. The membership guard keeps the cell
# idempotent: re-running it does not append the same entry to sys.path again.
PROJECT_ROOT = Path("/Net/Groups/BGI/people/ecathain/TRENDY_Emulator_Scripts/NewModel")
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))
from src.utils.tools import slurm_shard


### Find all NetCDF files and build the task list

In [3]:
# Input: NetCDF prediction files of the stabilised 32-year emulator run for one scenario.
SCENARIO = "S3"
in_dir = (
    PROJECT_ROOT / "data" / "predictions" / "stabilised_32_year"
    / "32_year_scenarios" / "netcdf" / SCENARIO
)

# Sort so the file list is identical in every SLURM array task before sharding;
# the order produced by rglob depends on the filesystem and is not guaranteed.
nc_files = sorted(in_dir.rglob("*.nc"))
if not nc_files:
    raise FileNotFoundError(f"No .nc files found under {in_dir}")

# Presumably returns this array task's share of nc_files (it reports processing all
# items when no SLURM array variables are set) — TODO confirm in src.utils.tools.
tasks = slurm_shard(nc_files)

# Root directory for every metric written below.
out_dir = PROJECT_ROOT / "data" / "analysis" / "metrics" / "Stable-Emulator" / SCENARIO

[INFO] No SLURM array vars; processing all items.


In [4]:
# Open every assigned file and key its payload DataArray by variable name.
# Datasets are intentionally left open: the DataArrays are lazy and are read
# from disk by the cells below.
data = {}
for f in tasks:
    ds = xr.open_dataset(f)
    # Ignore coordinate-bounds variables so they are never taken as the payload.
    candidates = [v for v in ds.data_vars if not str(v).endswith(("_bnds", "_bounds"))]
    if not candidates:
        raise ValueError(f"No data variable found in {f}")
    var = candidates[0]
    if var in data:
        # Fail loudly instead of silently overwriting an earlier file's variable.
        raise ValueError(f"Variable {var!r} appears in more than one file (latest: {f})")
    print(var)
    data[var] = ds[var]

# Variables treated as annual (non-monthly): resampled to yearly means before spatial
# averaging and excluded from the seasonality metrics.
annual_vars = ['cVeg', 'cSoil', 'cLitter']

mrro
npp
cTotal
fLuc
cVeg
nbp
ra
evapotrans
rh
fFire
cSoil
cLitter
gpp
lai
mrso


### Means

In [28]:
def save_netcdf(da, path, overwrite=True):
    """Write ``da`` to a NetCDF file at ``path``, creating parent directories.

    Parameters
    ----------
    da : object with a ``to_netcdf(path)`` method (an xarray DataArray/Dataset)
        Data to write.
    path : str or pathlib.Path
        Destination file.
    overwrite : bool, default True
        If True, an existing file is deleted before writing. If False and the
        file exists, nothing is written.

    Raises
    ------
    FileExistsError
        If ``path`` exists and ``overwrite`` is False. Previously the flag had no
        effect, because ``to_netcdf`` opens in ``mode="w"`` and overwrites anyway.
    """
    path = Path(path)
    if path.exists():
        if not overwrite:
            raise FileExistsError(f"{path} already exists and overwrite=False")
        # Remove first rather than relying on to_netcdf to truncate the existing file.
        path.unlink()
    path.parent.mkdir(parents=True, exist_ok=True)
    da.to_netcdf(path)

In [6]:
# Per-pixel mean over the whole time axis, one map per variable, also written to disk.
time_means = {}
time_mean_dir = out_dir / "mean" / "time_mean"
for name, arr in data.items():
    time_means[name] = arr.mean(dim="time")
    print(time_means[name].shape)
    save_netcdf(time_means[name], time_mean_dir / f"{name}_time_mean.nc")

(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)
(360, 720)


In [9]:
# Area-weighted spatial mean time series per variable.
# Assumed: the (360, 720) grid is a regular lat/lon grid, where cell area scales with
# cos(latitude). A plain mean over lat/lon would over-weight high-latitude cells.
space_mean = {}
for var, da in data.items():
    if var in annual_vars:
        # Annual variables: collapse to one value per year first.
        da = da.resample(time="YE").mean()

    # assumes `lat` is a 1-D coordinate in degrees — TODO confirm for every input file
    area_weights = np.cos(np.deg2rad(da["lat"]))
    # weighted().mean skips NaN cells and renormalises by the weights of valid cells.
    mean = da.weighted(area_weights).mean(dim=["lat", "lon"])
    space_mean[var] = mean
    print(mean.shape)
    out_path = out_dir / "mean" / "space_mean" / f"{var}_space_mean.nc"
    save_netcdf(mean, out_path)

(1476,)
(1476,)
(1476,)
(1476,)
(123,)
(1476,)
(1476,)
(1476,)
(1476,)
(1476,)
(123,)
(123,)
(1476,)
(1476,)
(1476,)


### Seasonality

In [29]:
# Mean seasonal cycle (monthly climatology) broadcast back onto the full time axis:
# each time step holds the multi-year mean of its calendar month, dims (time, lat, lon).
seasonality_3D = {}

for var, da in data.items():
    if var in annual_vars:
        continue  # annual variables have no seasonal cycle to resolve

    # Multi-year mean for each calendar month: dims (month=12, lat, lon)
    monthly_mean = da.groupby("time.month").mean(dim="time")

    # Vectorised selection with each time step's month number -> dims (time, lat, lon)
    monthly_mean_repeated = monthly_mean.sel(month=da["time"].dt.month)

    seasonality_3D[var] = monthly_mean_repeated
    print(var, monthly_mean_repeated.shape)
    out_path = out_dir / "seasonality_3D" / f"{var}_seasonality_3D.nc"
    save_netcdf(monthly_mean_repeated, out_path)

mrro (1476, 360, 720)
npp (1476, 360, 720)
cTotal (1476, 360, 720)
fLuc (1476, 360, 720)
nbp (1476, 360, 720)
ra (1476, 360, 720)
evapotrans (1476, 360, 720)
rh (1476, 360, 720)
fFire (1476, 360, 720)
gpp (1476, 360, 720)
lai (1476, 360, 720)
mrso (1476, 360, 720)


In [30]:
# Area-weighted spatial-mean seasonal cycle: one value per calendar month for each variable.
seasonality_1D = {}

for var, da in data.items():
    if var in annual_vars:
        continue  # annual variables have no seasonal cycle to resolve

    # assumes `lat` is a 1-D coordinate in degrees on a regular grid — TODO confirm
    area_weights = np.cos(np.deg2rad(da["lat"]))
    # Area-weighted spatial mean per time step, then average each calendar month over the years.
    spatial_series = da.weighted(area_weights).mean(dim=["lat", "lon"])
    monthly_mean = spatial_series.groupby("time.month").mean(dim="time")  # dims: (month=12,)

    seasonality_1D[var] = monthly_mean
    print(var, monthly_mean.shape)
    out_path = out_dir / "seasonality_1D" / f"{var}_seasonality_1D.nc"
    save_netcdf(monthly_mean, out_path)

mrro (12,)
npp (12,)
cTotal (12,)
fLuc (12,)
nbp (12,)
ra (12,)
evapotrans (12,)
rh (12,)
fFire (12,)
gpp (12,)
lai (12,)
mrso (12,)


### Trend

In [40]:
# Per-pixel linear trend of the annual-mean series, stored as polyfit coefficients with
# dims (degree, lat, lon): degree=1 is the slope per year, degree=0 the intercept at year 0.
trend_2D = {}

for var, da in data.items():
    # 1. Aggregate to annual means
    annual_da = da.resample(time="YE").mean()

    # 2. Replace the time axis with 0..N-1 ("years since the first year of the record")
    start_year = int(annual_da["time"].dt.year[0])
    n_years = annual_da.sizes["time"]
    annual_da = annual_da.assign_coords(time=np.arange(n_years))

    # 3. Least-squares linear fit. Some pixels trigger numpy's rank-deficiency warning
    #    (presumably NaN/masked cells); the fit still completes.
    coeff = annual_da.polyfit(dim="time", deg=1)["polyfit_coefficients"]

    # 4. Metadata: the intercept keeps the original units, the slope gains "yr-1".
    orig_units = da.attrs.get("units", "")
    coeff.attrs["units"] = orig_units
    coeff.attrs["slope_units"] = f"{orig_units} yr-1" if orig_units else "yr-1"
    # Origin of the fitted time axis, derived from the data rather than hard-coded.
    coeff.attrs["time_units"] = f"years since {start_year}-01-01"

    trend_2D[var] = coeff
    print(var, coeff.shape)

    out_path = out_dir / "trend_2D" / f"{var}_trend_2D.nc"
    save_netcdf(coeff, out_path)  # save_netcdf creates the parent directory

mrro (2, 360, 720)
npp (2, 360, 720)


  warn_on_deficient_rank(rank, x.shape[1])


cTotal (2, 360, 720)
fLuc (2, 360, 720)


  warn_on_deficient_rank(rank, x.shape[1])


cVeg (2, 360, 720)
nbp (2, 360, 720)
ra (2, 360, 720)
evapotrans (2, 360, 720)
rh (2, 360, 720)
fFire (2, 360, 720)


  warn_on_deficient_rank(rank, x.shape[1])


cSoil (2, 360, 720)


  warn_on_deficient_rank(rank, x.shape[1])


cLitter (2, 360, 720)
gpp (2, 360, 720)


  warn_on_deficient_rank(rank, x.shape[1])


lai (2, 360, 720)


  warn_on_deficient_rank(rank, x.shape[1])


mrso (2, 360, 720)


### Spatial-mean (1D) trend

In [37]:
# Area-weighted spatial mean of the per-pixel trend coefficients: dims (degree=2,)
trend_1d = {}

for var, trend in trend_2D.items():
    # assumes `lat` is a 1-D coordinate in degrees on a regular grid — TODO confirm
    area_weights = np.cos(np.deg2rad(trend["lat"]))
    # keep_attrs so units / slope_units / time_units survive into the saved file
    mean_trend = trend.weighted(area_weights).mean(dim=["lat", "lon"], keep_attrs=True)
    trend_1d[var] = mean_trend
    print(var, mean_trend.shape)
    out_path = out_dir / "trend_1D" / f"{var}_trend_1D.nc"
    save_netcdf(mean_trend, out_path)

mrro (2,)
npp (2,)
cTotal (2,)
fLuc (2,)
cVeg (2,)
nbp (2,)
ra (2,)
evapotrans (2,)
rh (2,)
fFire (2,)
cSoil (2,)
cLitter (2,)
gpp (2,)
lai (2,)
mrso (2,)


### Interannual variability (detrended annual anomalies)

In [None]:
# Interannual variability: annual means minus the per-pixel linear trend from trend_2D,
# i.e. detrended annual anomalies with dims (year index, lat, lon).
# Requires the trend cell to have been run first.
iav_3D = {}

for var, da in data.items():
    # 1. Annual mean
    annual_da = da.resample(time="YE").mean()   # (time_years, lat, lon)

    # 2. Same 0..N-1 year index that was used when fitting the trend
    n_years = annual_da.sizes["time"]
    annual_da = annual_da.assign_coords(time=np.arange(n_years))

    # 3. Slope & intercept for this variable
    trend = trend_2D[var]
    intercept = trend.sel(degree=0)
    slope = trend.sel(degree=1)

    # 4. Fitted trend value for each year
    fitted = intercept + slope * annual_da["time"]

    # 5. Detrend -> annual anomalies. The operands have different names, so xarray
    #    drops the name; rename so the file stores `var` instead of a placeholder name.
    iav_da = (annual_da - fitted).rename(var)

    # Metadata
    iav_da.attrs["long_name"] = f"Interannual variability of {var}"
    iav_da.attrs["trend_time_units"] = trend.attrs.get("time_units", "years since 1901-01-01")
    if da.attrs.get("units"):
        iav_da.attrs["units"] = da.attrs["units"]  # anomalies share the original units

    iav_3D[var] = iav_da
    print(var, iav_da.shape)

    # Save
    out_path = out_dir / "iav_3D" / f"{var}_iav_3D.nc"
    save_netcdf(iav_da, out_path)

mrro (123, 360, 720)
npp (123, 360, 720)
cTotal (123, 360, 720)
fLuc (123, 360, 720)
cVeg (123, 360, 720)
nbp (123, 360, 720)
ra (123, 360, 720)
evapotrans (123, 360, 720)
rh (123, 360, 720)
fFire (123, 360, 720)
cSoil (123, 360, 720)
cLitter (123, 360, 720)
gpp (123, 360, 720)
lai (123, 360, 720)
mrso (123, 360, 720)
