In [None]:
import dask.array as da
import dask
import xarray as xr
import numpy as np
import os
from dateutil.relativedelta import relativedelta
from datetime import datetime

expt_name = 'LS_DAv8_M36'

# Months are processed over the half-open range [start_date, end_date): with the
# values below that is Jan-Dec 2005. end_date itself is excluded, so the
# Y2006/M01 directory is not read.
start_date = datetime(2005, 1, 1)
end_date = datetime(2006, 1, 1)

start_date_str = start_date.strftime('%Y%m%d')
end_date_str = end_date.strftime('%Y%m%d')

root_directory = f'/discover/nobackup/projects/land_da/Experiment_archive/M21C_land_sweeper_DAv8_M36/{expt_name}/output/SMAP_EASEv2_M36_GLOBAL/cat/ens_avg'

output_file = f'{expt_name}_{start_date_str}_{end_date_str}_catch_progn_snow_incr.nc4'


def month_starts(start, end):
    """Yield the first day of every calendar month that overlaps [start, end).

    Whole months are yielded: a start or end falling mid-month still yields the
    month containing it, while an end on the 1st excludes that month. An empty
    range (start >= end) yields nothing.
    """
    if start >= end:
        return
    current = datetime(start.year, start.month, 1)
    while current < end:
        yield current
        # Months 1..11 give (0, m), so the next month is m + 1; month 12 gives
        # (1, 0), i.e. January of the following year.
        carry, month_index = divmod(current.month, 12)
        current = datetime(current.year + carry, month_index + 1, 1)


def is_increment_file(filename, expt):
    """Return True if *filename* is a catch_progn_incr .nc4 file of *expt* to merge.

    Names ending in 'z.nc4' are excluded, as in the original selection rule.
    """
    return (
        filename.startswith(f'{expt}.catch_progn_incr.2')
        and filename.endswith('.nc4')
        and not filename.endswith('z.nc4')
    )


def parse_time_stamps(raw_stamps):
    """Convert 'YYYYMMDD_HHMM...' time stamps to a datetime64[ns] array.

    Elements may be bytes or str (xarray can return either for a string
    variable); only the first 13 characters ('YYYYMMDD_HHMM') are parsed.

    Raises:
        ValueError: if a stamp is not valid UTF-8 (UnicodeDecodeError) or does
            not match the '%Y%m%d_%H%M' format.
    """
    parsed = []
    for stamp in raw_stamps:
        text = stamp.decode('utf-8') if isinstance(stamp, bytes) else str(stamp)
        parsed.append(np.datetime64(datetime.strptime(text[:13], "%Y%m%d_%H%M"), 'ns'))
    # Explicit dtype so an empty input still yields a datetime64 array, not float64.
    return np.array(parsed, dtype='datetime64[ns]')


def read_snow_increment(file_path):
    """Return WESNN1_INCR + WESNN2_INCR + WESNN3_INCR from one file, indexed by time.

    The file is opened lazily (dask-backed via chunks={}) and left open, because
    the returned array still reads from it. If the string 'time_stamp' values
    cannot be parsed, a warning is printed, the file is closed, and None is
    returned so the caller can skip it.
    """
    ds = xr.open_dataset(file_path, chunks={})  # Dask-aware open

    # Convert or standardize time coordinate
    if np.issubdtype(ds['time_stamp'].dtype, np.datetime64):
        time_coord = ds['time_stamp']
    else:
        try:
            parsed_times = parse_time_stamps(ds['time_stamp'].values)
        except ValueError as err:
            ds.close()
            print(f'WARNING: skipping {file_path}: cannot parse time_stamp ({err})')
            return None
        time_coord = xr.DataArray(parsed_times, dims='time', name='time')

    snow_incr = ds['WESNN1_INCR'] + ds['WESNN2_INCR'] + ds['WESNN3_INCR']
    return snow_incr.assign_coords(time=time_coord)


# IPython/Jupyter executes cells with __name__ == '__main__', so this still runs
# in the notebook; the guard only prevents the extraction when imported.
if __name__ == '__main__':
    snow_list = []
    files_found = 0

    for current_date in month_starts(start_date, end_date):
        year_month_directory = os.path.join(root_directory,
                                            f"Y{current_date.year}",
                                            f"M{current_date.month:02d}")
        # os.listdir raises FileNotFoundError if a month directory is missing,
        # so an incomplete archive fails loudly rather than producing a short series.
        for filename in sorted(os.listdir(year_month_directory)):
            if not is_increment_file(filename, expt_name):
                continue
            snow_incr = read_snow_increment(os.path.join(year_month_directory, filename))
            if snow_incr is None:
                continue
            snow_list.append(snow_incr)
            files_found += 1

    if files_found == 0:
        raise RuntimeError("No valid input files found. Aborting.")

    # Concatenate all time slices and save once
    snow_concat = xr.concat(snow_list, dim='time')
    out_ds = xr.Dataset({'SNOW_INCR': snow_concat})
    out_ds.to_netcdf(output_file)

    print(f'Finished writing merged dataset to {output_file}')