In [None]:
import os
import glob
import xarray as xr
import numpy as np
from datetime import datetime
from dask import compute


In [None]:

# Experiment identifier. It is embedded in the input directory path, in the
# input filename pattern, and in the names of the .npz files written later.
expt_name = 'snow_LS_OLv8_M36'

# Directory that holds the Y*/M* sub-directories searched for input files.
root_directory = (
    '/discover/nobackup/amfox/Experiments/snow_M21C_test/'
    + expt_name
    + '/output/SMAP_EASEv2_M36_GLOBAL/cat/ens0000'
)

# Analysis window. The file filter compares with `<=` on both sides, so
# files dated exactly on start_date or end_date are included.
start_date, end_date = datetime(2003, 1, 1), datetime(2004, 1, 1)


In [None]:

# --- Discover candidate files under <root>/Y*/M*/ ---
# The filename part of the pattern leaves the date/time stamp as a wildcard.
daily_file_glob = f'{expt_name}.tavg24_1d_lnd_Nt.*.nc4'
file_pattern = os.path.join(root_directory, 'Y*', 'M*', daily_file_glob)

# glob returns matches in arbitrary order; sort the paths so the list is
# deterministic.
all_files = glob.glob(file_pattern)
all_files.sort()

# Show up to the first five matches, one per line, as a quick sanity check.
preview = all_files[:5]
if preview:
    print('\n'.join(preview))

# --- Keep only the files whose filename date falls in [start_date, end_date] ---
# Filenames look like:
#   snow_LS_OLv8_M36.tavg24_1d_lnd_Nt.20030101_1200z.nc4
# The date is the part before the first '_' in the second-to-last
# dot-separated field ('20030101' in the example above).
selected_files = []
skipped_files = []  # paths whose names did not yield a parseable date
for file in all_files:
    basename = os.path.basename(file)
    try:
        date_str = basename.split('.')[-2].split('_')[0]
        file_date = datetime.strptime(date_str, '%Y%m%d')
    except (IndexError, ValueError):
        # Only name-parsing failures are tolerated: IndexError when the name
        # has fewer than two dot-separated fields, ValueError when the token
        # is not a YYYYMMDD date. Such files are skipped, not fatal.
        skipped_files.append(file)
        continue

    # The range check sits outside the try so that a misconfigured
    # start_date/end_date raises instead of silently selecting nothing.
    if start_date <= file_date <= end_date:
        selected_files.append(file)

# Report skipped files rather than dropping them without a trace.
if skipped_files:
    print(f"Skipped {len(skipped_files)} file(s) with no parseable date in the name, "
          f"e.g. {skipped_files[0]}")

# --- Open every selected file as one dataset concatenated along 'time' ---
# Fail early with an actionable message when the search/filter found nothing.
# FileNotFoundError is an OSError subclass, so code that caught the OSError
# xarray raises for an empty file list still catches this.
if not selected_files:
    raise FileNotFoundError(
        f"No files matching {file_pattern!r} with dates between "
        f"{start_date:%Y-%m-%d} and {end_date:%Y-%m-%d}; nothing to load."
    )

print(f"Loading {len(selected_files)} files")
combined_ds = xr.open_mfdataset(
    selected_files,
    combine='nested',     # join the files in the order they appear in the list ...
    concat_dim='time',    # ... along the 'time' dimension
    parallel=True,        # open the individual files in parallel with Dask
    engine='netcdf4',
    chunks={}             # keep the variables as Dask arrays instead of loading eagerly
)

print('Done loading files.')


In [None]:

# Variables to extract. Each keyword is the descriptive label used to build
# the array names in the saved .npz files; each value is the variable name
# looked up in combined_ds.
variables = dict(
    sm_surface='SFMC',
    sm_rootzone='RZMC',
    sm_profile='PRMC',
    precipitation_total_surface_flux='PRECTOTCORRLAND',
    vegetation_greenness_fraction='GRN',
    leaf_area_index='LAI',
    snow_mass='SNOMASLAND',
    surface_temperature_of_land_incl_snow='TSURFLAND',
    soil_temperature_layer_1='TSOIL1',
    snowfall_land='PRECSNOCORRLAND',
    snow_depth_within_snow_covered_area_fraction_on_land='SNODPLAND',
    snowpack_evaporation_latent_heat_flux_on_land='LHLANDSBLN',
    overland_runoff_including_throughflow='RUNSURFLAND',
    baseflow_flux_land='BASEFLOWLAND',
    snowmelt_flux_land='SMLAND',
    total_evaporation_land='EVLAND',
    net_shortwave_flux_land='SWLAND',
    total_water_storage_land='TWLAND',
    fractional_area_of_snow_on_land='FRLANDSNO',
)

# For every requested variable keep three entries:
#   'concat' - the variable as stored in combined_ds (all time steps),
#   'mean'   - its mean over the 'time' dimension, ignoring NaNs,
#   'std'    - its standard deviation over 'time', ignoring NaNs.
results = {}
for label, ds_name in variables.items():
    field = combined_ds[ds_name]
    results[label] = {
        'concat': field,
        'mean': field.mean(dim='time', skipna=True),
        'std': field.std(dim='time', skipna=True),
    }

# Evaluate all means and standard deviations in a single dask compute() call.
# The argument order is: every mean first, then every std.
pending_means = [results[label]['mean'] for label in variables]
pending_stds = [results[label]['std'] for label in variables]
computed_results = compute(*pending_means, *pending_stds)

# Replace the pending entries with the computed ones. The first half of the
# returned tuple holds the means, the second half the stds, both in the
# iteration order of `variables`.
n_vars = len(variables)
for label, mean_done, std_done in zip(variables,
                                      computed_results[:n_vars],
                                      computed_results[n_vars:]):
    results[label]['mean'] = mean_done
    results[label]['std'] = std_done


In [None]:

# Write the full (all time steps) arrays to
# <expt>_<start>_<end>_tavg24_1d_lnd_Nt_concat.npz, one array per variable
# stored under the key '<label>_concat'.
concat_start_tag = start_date.strftime("%Y%m%d")
concat_end_tag = end_date.strftime("%Y%m%d")
concat_npz_path = f'{expt_name}_{concat_start_tag}_{concat_end_tag}_tavg24_1d_lnd_Nt_concat.npz'

# `.values` converts each variable to a NumPy array before it is handed to savez.
concat_arrays = {}
for label in variables:
    concat_arrays[f'{label}_concat'] = results[label]['concat'].values

np.savez(concat_npz_path, **concat_arrays)


In [None]:

# Write the per-variable time statistics to
# <expt>_<start>_<end>_tavg24_1d_lnd_Nt_stats.npz. Keys are 'mean_<label>'
# and 'std_<label>'; all means are added before all stds.
stats_start_tag = start_date.strftime("%Y%m%d")
stats_end_tag = end_date.strftime("%Y%m%d")
stats_npz_path = f'{expt_name}_{stats_start_tag}_{stats_end_tag}_tavg24_1d_lnd_Nt_stats.npz'

stats_arrays = {f'mean_{label}': results[label]['mean'].values for label in variables}
stats_arrays.update({f'std_{label}': results[label]['std'].values for label in variables})

np.savez(stats_npz_path, **stats_arrays)


In [None]:

# Reduce each variable over the 'tile' dimension (NaNs ignored), giving one
# mean and one standard deviation per remaining coordinate.
ts_results = {}
for label, ds_name in variables.items():
    field = combined_ds[ds_name]
    ts_results[label] = {
        'mean': field.mean(dim='tile', skipna=True),
        'std': field.std(dim='tile', skipna=True),
    }

# Evaluate every reduction in a single dask compute() call: means first,
# then stds.
pending_ts_means = [ts_results[label]['mean'] for label in variables]
pending_ts_stds = [ts_results[label]['std'] for label in variables]
ts_computed_results = compute(*pending_ts_means, *pending_ts_stds)

# The first half of the returned tuple holds the means, the second half the
# stds, both in the iteration order of `variables`.
n_ts_vars = len(variables)
for label, ts_mean_done, ts_std_done in zip(variables,
                                            ts_computed_results[:n_ts_vars],
                                            ts_computed_results[n_ts_vars:]):
    ts_results[label]['mean'] = ts_mean_done
    ts_results[label]['std'] = ts_std_done

# Write the tile-reduced series to
# <expt>_<start>_<end>_tavg24_1d_lnd_Nt_timeseries.npz. Keys are
# 'ts_mean_<label>' and 'ts_std_<label>'; all means are added before all stds.
ts_start_tag = start_date.strftime("%Y%m%d")
ts_end_tag = end_date.strftime("%Y%m%d")
ts_npz_path = f'{expt_name}_{ts_start_tag}_{ts_end_tag}_tavg24_1d_lnd_Nt_timeseries.npz'

ts_arrays = {f'ts_mean_{label}': ts_results[label]['mean'].values for label in variables}
ts_arrays.update({f'ts_std_{label}': ts_results[label]['std'].values for label in variables})

np.savez(ts_npz_path, **ts_arrays)