In [1]:
import os
import glob
import xarray as xr
import numpy as np
import importlib.util
from datetime import datetime

In [2]:
# --- Experiment configuration ---
# expt_name doubles as the filename prefix of every increment file globbed below.
expt_name = 'DAv7_M36_MULTI_type_13_comb_fp_scaled'

# Directory whose Y*/M* subfolders hold this experiment's increment files.
# Local alternative for offline testing:
# root_directory = f'/Users/amfox/Desktop/GEOSldas_diagnostics/test_data/fp_scaled/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg'
root_directory = (
    '/discover/nobackup/projects/land_da/Experiment_archive/'
    f'{expt_name}/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg'
)

# --- Fixed analysis window (inclusive on both ends), 6 years long ---
start_date, end_date = datetime(2015, 4, 1), datetime(2021, 4, 1)

# YYYYMMDD strings used to build output file names.
start_date_str, end_date_str = (d.strftime('%Y%m%d') for d in (start_date, end_date))

In [3]:

# --- Check if dask is available and show debug info ---
try:
    import dask
    dask_available = True
    print(f"Dask version: {dask.__version__}")
except ImportError:
    dask_available = False
    print("Dask not available.")

# --- List the I/O engines xarray can use to open files ---
print("\nAvailable xarray engines:")
print(xr.backends.list_engines())

print("\nChecking if 'dask.array' is importable:")
# find_spec on a dotted name imports the parent package first, so it raises
# ModuleNotFoundError (rather than returning None) when dask itself is missing.
# Treat that case as "not available" instead of crashing the cell.
try:
    dask_array_spec = importlib.util.find_spec("dask.array")
except ModuleNotFoundError:
    dask_array_spec = None

if dask_array_spec is not None:
    print("dask.array is available.")
else:
    print("dask.array is NOT available.")


Dask version: 2025.2.0

Available xarray engines:
{'netcdf4': <NetCDF4BackendEntrypoint>
  Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray
  Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html, 'h5netcdf': <H5netcdfBackendEntrypoint>
  Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray
  Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.H5netcdfBackendEntrypoint.html, 'store': <StoreBackendEntrypoint>
  Open AbstractDataStore instances in Xarray
  Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html}

Checking if 'dask.array' is importable:
dask.array is available.


In [None]:
%%time

# --- Find all matching files using glob ---
file_pattern = os.path.join(
    root_directory,
    'Y*',
    'M*',
    f'{expt_name}.catch_progn_incr.*.nc4'
)

# Sorted path order; presumably chronological given the Y*/M*/<date> layout.
all_files = sorted(glob.glob(file_pattern))

# --- Keep files whose date stamp lies in [start_date, end_date] (inclusive).
# Filenames look like:
# DAv7_M36_MULTI_type_13_comb_fp_scaled.catch_progn_incr.20160227.nc4
selected_files = []
skipped_files = []
for file in all_files:
    basename = os.path.basename(file)
    # The glob guarantees '.catch_progn_incr.<stamp>.nc4', so [-2] always exists.
    date_str = basename.split('.')[-2]  # '20160227' just before the '.nc4'
    try:
        file_date = datetime.strptime(date_str, '%Y%m%d')
    except ValueError:
        # The '*' in the glob can match stamps that are not YYYYMMDD; record
        # them instead of silently discarding them.
        skipped_files.append(file)
        continue
    if start_date <= file_date <= end_date:
        selected_files.append(file)

if skipped_files:
    print(f"Skipped {len(skipped_files)} files with unparseable dates, e.g. {skipped_files[0]}")

# Fail with an explicit message rather than open_mfdataset's generic
# "no files to open" error.
if not selected_files:
    raise FileNotFoundError(
        f"No files matching {file_pattern} with dates between "
        f"{start_date_str} and {end_date_str}"
    )

# --- Load all selected datasets using nested combine with explicit concat_dim ---
print(f"Loading {len(selected_files)} files")
combined_ds = xr.open_mfdataset(
    selected_files,
    combine='nested',
    concat_dim='time',
    parallel=dask_available,
    engine='netcdf4',
    chunks={}
)

print('Done loading files.')


Loading 2202 files


In [None]:
%%time

# Rechunk the lazily loaded dataset into larger chunks:
# 800 time steps x 112573 tiles per chunk.
# NOTE(review): 112573 presumably equals the full tile count of the M36 grid — TODO confirm.
desired_chunks = dict(time=800, tile=112573)
print(f"Rechunking to desired chunks: {desired_chunks}")
combined_ds = combined_ds.chunk(desired_chunks)

In [None]:
%%time

from dask import compute

thresholds = [0.0, 10.0e-7, 0.00005, 0.0001, 0.00015, 0.0002, 0.00025, 0.0003, 0.00035, 0.0004, 0.00045, 0.0005]


def _threshold_key(threshold, decimals=7):
    """Return the digits after the decimal point of ``threshold``, trailing zeros stripped.

    Examples: 0.00015 -> '00015', 0.0001 -> '0001', 1e-6 -> '000001', 0.0 -> ''.
    Seven decimals are required: with the former 5-decimal format 1e-6 rendered
    as '0.00000' and collapsed onto the key of 0.0 (''), silently overwriting the
    threshold-0 statistics. Keys for thresholds >= 1e-5 are unchanged.
    """
    return f'{threshold:.{decimals}f}'.split('.')[1].rstrip('0')


threshold_keys = [_threshold_key(t) for t in thresholds]
if len(set(threshold_keys)) != len(threshold_keys):
    raise ValueError(f"Threshold keys are not unique: {dict(zip(thresholds, threshold_keys))}")

srfexc_incr = combined_ds['SRFEXC_INCR']
computations = {}

for threshold, key in zip(thresholds, threshold_keys):
    # True where |increment| > threshold (NaN compares False on both sides).
    incremented_values = (srfexc_incr < -threshold) | (srfexc_incr > threshold)
    # Reuse one masked array for both mean and std.
    srfexc_masked = srfexc_incr.where(incremented_values)
    computations[f'cnt_{key}'] = incremented_values.sum(dim='time')
    computations[f'mean_{key}'] = srfexc_masked.mean(dim='time', skipna=True)
    computations[f'std_{key}'] = srfexc_masked.std(dim='time', skipna=True)

# Compute all results in one dask pass so shared inputs are read once.
results_computed = compute(*computations.values())
results_keys = list(computations.keys())

# Organize results back into dictionaries keyed by threshold key
cnt_srfexc_increment = {k.replace('cnt_', ''): v for k, v in zip(results_keys, results_computed) if k.startswith('cnt_')}
mean_srfexc_increment = {k.replace('mean_', ''): v for k, v in zip(results_keys, results_computed) if k.startswith('mean_')}
std_srfexc_increment = {k.replace('std_', ''): v for k, v in zip(results_keys, results_computed) if k.startswith('std_')}

# %% [markdown]
# ## Save results to file

In [None]:
%%time
output_file_srfexc = f'{expt_name}_{start_date_str}_{end_date_str}_catch_progn_incr_stats_optimized.npz'

# Remove any previous archive before writing a fresh one.
if os.path.exists(output_file_srfexc):
    os.remove(output_file_srfexc)

# Flatten the three per-threshold dicts into archive entries named
# '<stat>_srfexc_increment_<threshold key>', counts first, then means, then stds.
arrays_to_save = {}
for stat_name, per_threshold in (
    ('cnt', cnt_srfexc_increment),
    ('mean', mean_srfexc_increment),
    ('std', std_srfexc_increment),
):
    for key in per_threshold:
        arrays_to_save[f'{stat_name}_srfexc_increment_{key}'] = per_threshold[key]

np.savez(output_file_srfexc, **arrays_to_save)

print(f"Data successfully saved to {output_file_srfexc}")

In [None]:
# Define the thresholds and their corresponding labels
thresholds = [0.0, 10.0e-7, 0.00005, 0.0001, 0.00015, 0.0002, 0.00025, 0.0003, 0.00035, 0.0004, 0.00045, 0.0005]
labels = [f'Threshold: {threshold}\n Number of srfexc increments' for threshold in thresholds]

# Must match the file name written by the save cell; the previous
# '..._stats_test_dask.npz' name is not produced anywhere in this notebook,
# so a fresh Restart & Run All would load a stale or missing file.
output_file = f'{expt_name}_{start_date_str}_{end_date_str}_catch_progn_incr_stats_optimized.npz'

# Load the data from the .npz file
# (kept open so later cells can still index `data`).
data = np.load(output_file)

# Count array stored under the '' threshold-key suffix; print its shape as a sanity check.
test = data['cnt_srfexc_increment_']

print(test.shape)