In [1]:
import xarray as xr
import numpy as np
import os
from datetime import datetime
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [2]:
def build_lsm_dataset(root_dir, file_prefix, varnames, start_year=2000, end_year=2024):
    """Open a multi-year collection of monthly tile-space files as one dataset.

    For every month of ``start_year`` .. ``end_year`` (both inclusive) the file
    ``<root_dir>/Y<yyyy>/M<mm>/<file_prefix>.tavg24_1d_lnd_Nt.monthly.<yyyymm>.nc4``
    is looked up; months whose file does not exist are silently skipped.

    Parameters
    ----------
    root_dir : str
        Directory that contains the ``Yyyyy/Mmm`` sub-directories.
    file_prefix : str
        Leading part of each file name (before ``.tavg24_1d_lnd_Nt``).
    varnames : sequence of str
        Variables to keep. Names that are absent from a file are ignored.
    start_year, end_year : int, optional
        First and last calendar year to scan.

    Returns
    -------
    xarray.Dataset
        The requested variables concatenated along ``time``. ``time`` is set
        to the first day of each month that was found, ``lat``/``lon`` from the
        first file are attached as coordinates on the ``tile`` dimension, and
        values that are not below 1e10 are replaced by NaN.

    Raises
    ------
    FileNotFoundError
        If not a single matching file exists under ``root_dir``.
    """
    all_files = []
    all_dates = []

    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            filename = f"{file_prefix}.tavg24_1d_lnd_Nt.monthly.{year:04d}{month:02d}.nc4"
            fpath = os.path.join(
                root_dir,
                f"Y{year:04d}",
                f"M{month:02d}",
                filename
            )
            if os.path.exists(fpath):
                all_files.append(fpath)
                # Each monthly file is stamped with the first day of its month.
                all_dates.append(np.datetime64(f"{year:04d}-{month:02d}-01"))

    print(f"Found {len(all_files)} files.")

    # Fail with an actionable message instead of an IndexError on all_files[0].
    if not all_files:
        raise FileNotFoundError(
            f"No files with prefix {file_prefix!r} found under {root_dir!r} "
            f"for years {start_year}-{end_year}."
        )

    # Read the static lat/lon from the first file. The values are pulled into
    # memory inside the ``with`` block so nothing refers to the file once it
    # has been closed.
    with xr.open_dataset(all_files[0]) as ds_static:
        lat_values = ds_static["lat"].values
        lon_values = ds_static["lon"].values

    # Reduce every file to the requested variables before concatenation.
    def _preprocess(ds):
        return ds[[v for v in varnames if v in ds.variables]]

    # Files are stacked in the order they were discovered (chronological).
    ds = xr.open_mfdataset(
        all_files,
        combine="nested",
        concat_dim="time",
        parallel=True,
        engine="netcdf4",
        preprocess=_preprocess
    )

    # Replace the concatenated time axis with the dates derived from the file
    # names and attach the static coordinates along the "tile" dimension.
    ds = ds.assign_coords({
        "time": ("time", np.array(all_dates, dtype="datetime64[ns]")),
        "lat": ("tile", lat_values),
        "lon": ("tile", lon_values)
    })

    # Mask invalid fill values: anything not below 1e10 becomes NaN.
    for var in varnames:
        if var in ds:
            ds[var] = ds[var].where(ds[var] < 1e10)

    return ds

In [3]:
# Load the five land variables from the LS_OLv8_M36 experiment archive.
ds_ol = build_lsm_dataset(
    "/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg",
    "LS_OLv8_M36",
    ["SFMC", "RZMC", "PRECTOTCORRLAND", "FRLANDSNO", "TSOIL1"],
)

Found 288 files.


In [4]:
# Load the same five land variables from the LS_DAv8_M36 experiment output.
ds_da = build_lsm_dataset(
    "/discover/nobackup/projects/land_da/M21C_land_sweeper/LS_DAv8_M36_v2/LS_DAv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg",
    "LS_DAv8_M36",
    ["SFMC", "RZMC", "PRECTOTCORRLAND", "FRLANDSNO", "TSOIL1"],
)

Found 288 files.


In [5]:
# Quick look at the root-zone soil moisture variable (dims, coords, attrs).
print(ds_ol["RZMC"])

<xarray.DataArray 'RZMC' (time: 288, tile: 112573)>
dask.array<where, shape=(288, 112573), dtype=float32, chunksize=(1, 112573), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 2000-06-01 2000-07-01 ... 2024-05-01
    lat      (tile) float32 71.81 70.94 69.29 68.52 ... -18.53 -17.05 70.1
    lon      (tile) float32 -179.8 -179.8 -179.8 -179.8 ... 179.8 179.8 179.8
Dimensions without coordinates: tile
Attributes:
    long_name:     soil_moisture_rootzone
    units:         m3 m-3
    cell_methods:  time: mean


In [None]:
def plot_tile_map(lon, lat, data, title="Tile-based variable", cmap="Blues", vmin=None, vmax=None):
    """Draw per-tile values as a scatter plot on a global Plate Carree map.

    Parameters
    ----------
    lon, lat : array-like
        Tile longitudes and latitudes, passed to ``scatter`` as x and y.
    data : array-like
        Value per tile, used for the marker colour.
    title : str, optional
        Used both as the axes title and as the colour-bar label.
    cmap : str, optional
        Matplotlib colormap name.
    vmin, vmax : float, optional
        Colour-scale limits; ``None`` lets matplotlib pick them.

    The figure is shown with ``plt.show()``; nothing is returned.
    """
    plate_carree = ccrs.PlateCarree()

    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(1, 1, 1, projection=plate_carree)

    # Base map: whole globe, coastlines, thin borders and labelled gridlines.
    ax.set_global()
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.gridlines(draw_labels=True, linewidth=0.2)

    # One small marker per tile, coloured by its value.
    points = ax.scatter(
        lon,
        lat,
        c=data,
        cmap=cmap,
        s=1,
        transform=plate_carree,
        vmin=vmin,
        vmax=vmax,
    )

    fig.colorbar(points, ax=ax, orientation="vertical", label=title)
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

In [None]:
# Map root-zone soil moisture for the first time step of ds_ol.
# lon/lat are static per-tile coordinates (dimension "tile" only), so they are
# used as-is; only the data variable is indexed along "time". Calling
# .isel(time=0) on them raises because they have no "time" dimension.
plot_tile_map(
    lon=ds_ol.lon.values,
    lat=ds_ol.lat.values,
    data=ds_ol.RZMC.isel(time=0).values,
    title="RZMC (m3/m3)"
)

In [None]:
# Summaries of both experiment datasets. The second dataset is bound to the
# name `ds_da` when it is loaded; `ds_an` is never defined in this notebook.
print(ds_ol)
print(ds_da)

In [None]:
# Inspect the time coordinate of the first dataset.
print(ds_ol["time"])

In [None]:
# Reference period for the monthly climatology.
clim_period = slice("2001-01-01", "2020-12-31")

# Mean RZMC per calendar month over the reference period, per tile.
# The `_an` result comes from `ds_da` (there is no `ds_an` in this notebook).
clim_ol = ds_ol.RZMC.sel(time=clim_period).groupby("time.month").mean("time")
clim_an = ds_da.RZMC.sel(time=clim_period).groupby("time.month").mean("time")


In [None]:
# Reference period for the monthly climatology.
clim_period = slice("2001-01-01", "2020-12-31")

# NOTE: variables suffixed `_an` hold results computed from `ds_da`; the
# suffix is kept because later cells read `z_var_an` / `z_var_diff`.

# Monthly climatology mean (2001–2020), per tile and calendar month
clim_ol = ds_ol.RZMC.sel(time=clim_period).groupby("time.month").mean("time")
clim_an = ds_da.RZMC.sel(time=clim_period).groupby("time.month").mean("time")

# Monthly climatology std dev (2001–2020)
std_ol = ds_ol.RZMC.sel(time=clim_period).groupby("time.month").std("time")
std_an = ds_da.RZMC.sel(time=clim_period).groupby("time.month").std("time")

months = np.arange(1, 13)
z_anom_ol_list = []
z_anom_an_list = []

# Standardize each calendar month against its own climatology:
#   z = (value - monthly mean) / monthly std
# Tiles whose reference-period std is zero produce non-finite z-scores.
for m in months:
    sel_ol = ds_ol.RZMC.sel(time=ds_ol.time.dt.month == m)
    sel_an = ds_da.RZMC.sel(time=ds_da.time.dt.month == m)

    z_ol = (sel_ol - clim_ol.sel(month=m)) / std_ol.sel(month=m)
    z_an = (sel_an - clim_an.sel(month=m)) / std_an.sel(month=m)

    z_anom_ol_list.append(z_ol)
    z_anom_an_list.append(z_an)

# Stitch the twelve per-month pieces back into chronological order.
z_anom_ol = xr.concat(z_anom_ol_list, dim="time").sortby("time")
z_anom_an = xr.concat(z_anom_an_list, dim="time").sortby("time")

# Resample to yearly means of z-anomalies.
# NOTE(review): the record shown above runs 2000-06 .. 2024-05, so the first
# and last yearly means cover partial years — confirm this is intended.
# NOTE(review): recent pandas releases spell the year-end alias "YE".
z_annual_ol = z_anom_ol.resample(time="1Y").mean("time")
z_annual_an = z_anom_an.resample(time="1Y").mean("time")

# Interannual variance of standardized anomalies (one value per tile)
z_var_ol = z_annual_ol.var("time")
z_var_an = z_annual_an.var("time")
z_var_diff = z_var_an - z_var_ol


In [None]:
# Shape and summary statistics of the variance difference.
# The array is materialised once (each `.values` access recomputes a lazy
# array), and NaN-aware reductions are used so tiles without a valid value do
# not turn every statistic into NaN.
z_var_diff_values = z_var_diff.values
print(z_var_diff.shape)
print(np.nanmax(z_var_diff_values))
print(np.nanmin(z_var_diff_values))
print(np.nanmean(z_var_diff_values))

In [None]:
# Maps of the interannual variance of standardized RZMC anomalies.
# lon/lat are static per-tile coordinates (dimension "tile" only), so they are
# taken directly; indexing them with .isel(time=0) raises because they have no
# "time" dimension. Titles name RZMC, the variable the variances were computed
# from.
tile_lon = ds_ol.lon.values
tile_lat = ds_ol.lat.values

plot_tile_map(
    lon=tile_lon,
    lat=tile_lat,
    data=z_var_ol.values,
    title="Monthly anomaly variance (standardized) (OL) RZMC",
    vmin=0,
    vmax=5
)

plot_tile_map(
    lon=tile_lon,
    lat=tile_lat,
    data=z_var_an.values,
    title="Monthly anomaly variance (standardized) (DA) RZMC",
    vmin=0,
    vmax=5
)

# Difference map uses a diverging colormap centred on zero.
plot_tile_map(
    lon=tile_lon,
    lat=tile_lat,
    data=z_var_diff.values,
    title="Δ Monthly anomaly variance (standardized) (DA - OL) RZMC",
    vmin=-2,
    vmax=2,
    cmap="RdBu"
)

In [None]:
# Pull the variance difference into NumPy and keep only finite entries
# (drops NaN and ±inf).
values = z_var_diff.values
data = values[np.isfinite(values)]

# Histogram normalised to a probability density, with a marker at zero.
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(data, bins=100, density=True, color='skyblue', edgecolor='black')
ax.axvline(0, color='red', linestyle='--', label='No Change')

ax.set_title("PDF of Δ Standardized Interannual Variance (AN - OL)")
ax.set_xlabel("Δ Variance (Z-score units²)")
ax.set_ylabel("Probability Density")
ax.legend()
ax.grid(True, linestyle="--", alpha=0.4)
fig.tight_layout()
plt.show()

In [6]:
# zlib compression applied to every data variable (level 4 = good balance).
comp = {"zlib": True, "complevel": 4}

# Write each dataset to its own compressed NetCDF file.
for dataset, out_path in (
    (ds_ol, "OLv8_land_variables_2000_2024_compressed.nc"),
    (ds_da, "DAv8_land_variables_2000_2024_compressed.nc"),
):
    # Same compression settings for all variables of this dataset.
    encoding = dict.fromkeys(dataset.data_vars, comp)
    dataset.to_netcdf(out_path, encoding=encoding)
