In [1]:
import os
import glob
import xarray as xr
import numpy as np
from datetime import datetime
from dask import compute


In [2]:

# Experiment to analyse; the name is also the prefix of every daily output file
# and of the .npz files written by this notebook.
expt_name = 'LS_OLv8_M36'

# Archived experiment; daily files are read from its `ens_avg` output directory.
# (Alternative earlier location, kept for reference:
#  /discover/nobackup/amfox/Experiments/snow_M21C_test/{expt_name}/output/SMAP_EASEv2_M36_GLOBAL/cat/ens0000)
archive_dir = '/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36'
root_directory = f'{archive_dir}/{expt_name}/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg'

# Analysis period. The file filter keeps dates with start_date <= date <= end_date,
# i.e. BOTH ends inclusive, so the 2023-10-01 file is included (8401 daily files).
# NOTE(review): if whole water years Oct 2000 - Sep 2023 were intended, the end
# should be exclusive; confirm.
start_date = datetime(2000, 10, 1)
end_date = datetime(2023, 10, 1)


In [3]:
from dask_jobqueue import SLURMCluster
from dask.distributed import Client

# Create a SLURM-backed Dask cluster: every worker is launched as its own SLURM
# batch job and connects back to the scheduler running in this notebook.
# NOTE(review): dask-jobqueue >= 0.8 renamed `job_extra` -> `job_extra_directives`
# and `env_extra` -> `job_script_prologue` (old names deprecated). The old names
# are kept so the cell still runs on older dask-jobqueue releases.
cluster = SLURMCluster(
    cores=4,  # Cores per job; with processes=1 this is 4 threads in one worker
    memory="16GB",  # Memory per job (= per worker here)
    processes=1,  # One worker process per SLURM job
    walltime="01:00:00",  # Maximum runtime of each worker job
    job_extra=["--export=ALL"],  # Extra #SBATCH directive: export environment variables
    env_extra=[
        # Shell lines run in the job script before the worker starts.
        "module load anaconda",  # Load the Anaconda module (was misspelled "anconda")
        # NOTE(review): `conda activate` in a non-interactive job script needs
        # conda's shell hook to be initialised; confirm workers start in `diag`.
        "conda activate diag",  # Activate the Conda environment
    ],
)

# Request 10 SLURM jobs; with processes=1 this gives 10 workers.
cluster.scale(jobs=10)

# Connect the client so subsequent dask computations run on the cluster.
client = Client(cluster)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 44241 instead


In [4]:
%%time
# --- Find all candidate daily files using glob ---
# Layout: <root_directory>/Y<yyyy>/M<mm>/<expt_name>.tavg24_1d_lnd_Nt.<yyyymmdd>_<hhmm>z.nc4
file_pattern = os.path.join(
    root_directory,
    'Y*',
    'M*',
    f'{expt_name}.tavg24_1d_lnd_Nt.2*.nc4'
)

# Sorted so that files are concatenated chronologically: combine='nested'
# below concatenates in list order, and the zero-padded Y/M directories and
# yyyymmdd filenames sort lexicographically in date order.
all_files = sorted(glob.glob(file_pattern))

# print the first 5 files, one per line (sanity check of the glob pattern)
for file in all_files[:5]:
    print(file)

# --- Keep only files dated within [start_date, end_date] (both ends inclusive).
# Filenames look like:
#   LS_OLv8_M36.tavg24_1d_lnd_Nt.20030101_1200z.nc4
# The date is the part before '_' in the second-to-last '.'-separated field.
selected_files = []
for file in all_files:
    basename = os.path.basename(file)
    date_str = basename.split('.')[-2].split('_')[0]  # e.g. '20030101'
    try:
        file_date = datetime.strptime(date_str, '%Y%m%d')
    except ValueError:
        # Name does not carry a yyyymmdd date; skip it instead of aborting.
        continue
    if start_date <= file_date <= end_date:
        selected_files.append(file)

# Fail early with an explicit message; open_mfdataset would otherwise raise a
# generic "no files to open" error.
if not selected_files:
    raise FileNotFoundError(
        f"No files matching {file_pattern} dated between "
        f"{start_date:%Y-%m-%d} and {end_date:%Y-%m-%d}"
    )

# --- Open all selected files lazily as one dataset, concatenated along 'time'.
# chunks={} makes every variable a dask array (backend-preferred chunking), so
# nothing is read until a computation is triggered.
print(f"Loading {len(selected_files)} files")
combined_ds = xr.open_mfdataset(
    selected_files,
    combine='nested',
    concat_dim='time',
    parallel=True,  # Open/preprocess files in parallel with Dask
    engine='netcdf4',
    chunks={}
)

print('Done loading files.')


/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg/Y2000/M06/LS_OLv8_M36.tavg24_1d_lnd_Nt.20000601_1200z.nc4
/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg/Y2000/M06/LS_OLv8_M36.tavg24_1d_lnd_Nt.20000602_1200z.nc4
/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg/Y2000/M06/LS_OLv8_M36.tavg24_1d_lnd_Nt.20000603_1200z.nc4
/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg/Y2000/M06/LS_OLv8_M36.tavg24_1d_lnd_Nt.20000604_1200z.nc4
/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_OLv8_M36/LS_OLv8_M36/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg/Y2000/M06/LS_OLv8_M36.tavg24_1d_lnd_Nt.20000605_1200z.nc4
Loading 8401 files
Done loadin

In [5]:
%%time
# Variables to summarise: output key -> variable name in the daily files.
variables = {
    'precipitation_total_surface_flux': 'PRECTOTCORRLAND',
    'snowfall_land': 'PRECSNOCORRLAND',
}

# Build the temporal mean and standard deviation (reduced over 'time', NaNs
# skipped) for each variable. These are still lazy dask-backed arrays; the
# actual reading happens when they are computed.
results = {}
for key, name in variables.items():
    field = combined_ds[name]
    results[key] = {
        'mean': field.mean(dim='time', skipna=True),
        'std': field.std(dim='time', skipna=True),
    }

CPU times: user 409 ms, sys: 3.48 ms, total: 413 ms
Wall time: 411 ms


In [6]:
%%time
# Evaluate every lazy mean/std with a single dask compute() call. The order is
# all means first, then all stds, each following the key order of `variables`.
lazy_stats = [results[v][stat] for stat in ('mean', 'std') for v in variables]
computed_results = compute(*lazy_stats)

# Replace the lazy entries with their computed counterparts.
n_vars = len(variables)
for idx, v in enumerate(variables):
    results[v]['mean'] = computed_results[idx]
    results[v]['std'] = computed_results[n_vars + idx]

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


CPU times: user 1min 25s, sys: 2.7 s, total: 1min 27s
Wall time: 3min 35s


In [12]:
%%time
# Write the temporal mean and std of each variable to one .npz file, under the
# keys 'mean_<var>' and 'std_<var>'.
# NOTE: the later all-variables stats cell writes to this same filename, so
# running it overwrites this file (with a superset of these keys).
period = f'{start_date.strftime("%Y%m%d")}_{end_date.strftime("%Y%m%d")}'
stats_arrays = {f'mean_{v}': results[v]['mean'].values for v in variables}
stats_arrays.update({f'std_{v}': results[v]['std'].values for v in variables})
np.savez(f'{expt_name}_{period}_tavg24_1d_lnd_Nt_stats.npz', **stats_arrays)

CPU times: user 2.99 ms, sys: 0 ns, total: 2.99 ms
Wall time: 84.1 ms


In [None]:
%%time
# Full set of land diagnostics to summarise: output key -> variable name in the
# tavg24_1d_lnd_Nt files. The key order fixes the order of every result below.
variables = {
    'sm_surface': 'SFMC',
    'sm_rootzone': 'RZMC',
    'sm_profile': 'PRMC',
    'precipitation_total_surface_flux': 'PRECTOTCORRLAND',
    'vegetation_greenness_fraction': 'GRN',
    'leaf_area_index': 'LAI',
    'snow_mass': 'SNOMASLAND',
    'surface_temperature_of_land_incl_snow': 'TSURFLAND',
    'soil_temperature_layer_1': 'TSOIL1',
    'snowfall_land': 'PRECSNOCORRLAND',
    'snow_depth_within_snow_covered_area_fraction_on_land': 'SNODPLAND',
    'snowpack_evaporation_latent_heat_flux_on_land': 'LHLANDSBLN',
    'overland_runoff_including_throughflow': 'RUNSURFLAND',
    'baseflow_flux_land': 'BASEFLOWLAND',
    'snowmelt_flux_land': 'SMLAND',
    'total_evaporation_land': 'EVLAND',
    'net_shortwave_flux_land': 'SWLAND',
    'total_water_storage_land': 'TWLAND',
    'fractional_area_of_snow_on_land': 'FRLANDSNO',
}

# For each variable keep the lazy full time series ('concat') and build lazy
# temporal mean/std (reduced over 'time', NaNs skipped). Only mean/std are
# evaluated in this cell; 'concat' stays lazy.
results = {}
for key, name in variables.items():
    field = combined_ds[name]
    results[key] = {
        'concat': field,
        'mean': field.mean(dim='time', skipna=True),
        'std': field.std(dim='time', skipna=True),
    }

# One compute() call for all reductions so the shared file-reading tasks run
# once. Results come back ordered [all means..., all stds...].
lazy_stats = [results[key][stat] for stat in ('mean', 'std') for key in variables]
computed_results = compute(*lazy_stats)

# Swap the lazy mean/std entries for the computed ones.
n_vars = len(variables)
for idx, key in enumerate(variables):
    results[key]['mean'] = computed_results[idx]
    results[key]['std'] = computed_results[n_vars + idx]


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


In [None]:

# Save the full concatenated time series of every variable to one .npz file,
# under the keys '<var>_concat'.
# All arrays are evaluated in ONE dask compute() call. Separate `.values` calls
# would each start their own computation and re-scan every input file once per
# variable.
# NOTE(review): every variable for the whole period is materialised in this
# process's memory before writing (as before); check that it fits.
concat_arrays = compute(*[results[var]['concat'] for var in variables])
np.savez(f'{expt_name}_{start_date.strftime("%Y%m%d")}_{end_date.strftime("%Y%m%d")}_tavg24_1d_lnd_Nt_concat.npz',
         **{f'{var}_concat': arr.values for var, arr in zip(variables, concat_arrays)})


In [None]:

# Save the temporal mean and std of every variable to one .npz file. Keys are
# 'mean_<var>' for all variables, followed by 'std_<var>' for all variables.
stats_file = (f'{expt_name}_{start_date.strftime("%Y%m%d")}_'
              f'{end_date.strftime("%Y%m%d")}_tavg24_1d_lnd_Nt_stats.npz')
stats_payload = {}
for prefix in ('mean', 'std'):
    for v in variables:
        stats_payload[f'{prefix}_{v}'] = results[v][prefix].values
np.savez(stats_file, **stats_payload)


In [None]:
%%time
# Spatial-summary time series: for every time step, the mean and std of each
# variable across the 'tile' dimension (NaNs skipped). The reduction is
# unweighted, so every tile counts equally.
# NOTE(review): if tiles differ in area, an area-weighted mean may be intended.
ts_results = {var: {} for var in variables}
for var, ds_var in variables.items():
    ts_results[var]['mean'] = combined_ds[ds_var].mean(dim='tile', skipna=True)
    ts_results[var]['std'] = combined_ds[ds_var].std(dim='tile', skipna=True)

# Compute all time series in a single dask call so the input files are read
# once. The returned tuple is ordered [all means..., all stds...].
ts_computed_results = compute(*[ts_results[var]['mean'] for var in variables] +
                              [ts_results[var]['std'] for var in variables])

# Organize time series results back into dictionaries
n_vars = len(variables)
for i, var in enumerate(variables):
    ts_results[var]['mean'] = ts_computed_results[i]
    ts_results[var]['std'] = ts_computed_results[i + n_vars]

# Save the time series to a new .npz file. The 'time' values are stored too,
# because the ts_* arrays alone carry no time information and could not be
# aligned with dates after reloading.
ts_file = (f'{expt_name}_{start_date.strftime("%Y%m%d")}_{end_date.strftime("%Y%m%d")}'
           '_tavg24_1d_lnd_Nt_timeseries.npz')
np.savez(ts_file,
         time=combined_ds['time'].values,
         **{f'ts_mean_{var}': ts_results[var]['mean'].values for var in variables},
         **{f'ts_std_{var}': ts_results[var]['std'].values for var in variables})