# Large Esembles (LEs)

In [6]:
def remove_duplicates_in_dim(ds, dim="member"):
    """
    Removes duplicate values in the specified coordinate (e.g., 'member') 
    while preserving only the first occurrence.

    Parameters:
    ds (xarray.Dataset): The input dataset with potentially duplicated coordinate values.
    dim (str): The coordinate to check for duplicates (default: "member").

    Returns:
    xarray.Dataset: A dataset with unique values in the specified coordinate.
    """
    if dim not in ds.coords:
        raise ValueError(f"Coordinate '{dim}' not found in dataset.")

    # Get coordinate values
    member_values = ds[dim].values

    # Find unique indices while keeping the first occurrence
    _, unique_indices = np.unique(member_values, return_index=True)

    # Select only unique members
    ds_unique = ds.isel({dim: np.sort(unique_indices)})

    return ds_unique

In [7]:
# model = 'canesm5_lens'
# experiment = 'ssp585'

model = 'cesm1_lens'
experiment = 'rcp85'

In [8]:
# LE_PATH = '/g/data/v45/nxm561/cesm2_lens/Amon/tas'
LE_PATH = f'/g/data/v45/nxm561/{model}/Amon/tas'

#'tas_Amon_CESM1-CAM5_historical_rcp45_r10i1p1_192001-208012_g025.nc')

In [9]:
os.listdir(LE_PATH)

['tas_Amon_CESM1-CAM5_historical_rcp85_r21i1p1_192001-210012_g025.nc',
 'tas_CESM1-CAM5_hist_rcp85_r26i1p1_192001-210012.nc',
 'tas_CESM1-CAM5_hist_rcp85_r32i1p1_192001-210012.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp85_r15i1p1_192001-210012_g025.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp85_r11i1p1_192001-210012_g025.nc',
 'tas_CESM1-CAM5_hist_rcp85_r31i1p1_192001-210012.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp85_r13i1p1_192001-210012_g025.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp85_r8i1p1_192001-210012_g025.nc',
 'tas_CESM1-CAM5_hist_rcp85_r23i1p1_192001-210012.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp85_r34i1p1_192001-210012_g025.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp85_r39i1p1_192001-210012_g025.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp45_r12i1p1_192001-208012_g025.nc',
 'tas_CESM1-CAM5_hist_rcp85_r37i1p1_192001-210012.nc',
 'tas_CESM1-CAM5_hist_rcp85_r1i1p1_192001-210012.nc',
 'tas_Amon_CESM1-CAM5_historical_rcp45_r13i1p1_192001-208012_g025.nc',
 'tas_CESM1-CAM5_hist_rcp85_r33i1

In [11]:
# file_list = glob(os.path.join(LE_PATH, '*rcp85*.nc'))
file_list = glob(os.path.join(LE_PATH, f'*{experiment}*.nc'))

In [12]:
len(file_list)

81

In [13]:
file_list[:-3]

['/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_Amon_CESM1-CAM5_historical_rcp85_r21i1p1_192001-210012_g025.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_CESM1-CAM5_hist_rcp85_r26i1p1_192001-210012.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_CESM1-CAM5_hist_rcp85_r32i1p1_192001-210012.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_Amon_CESM1-CAM5_historical_rcp85_r15i1p1_192001-210012_g025.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_Amon_CESM1-CAM5_historical_rcp85_r11i1p1_192001-210012_g025.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_CESM1-CAM5_hist_rcp85_r31i1p1_192001-210012.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_Amon_CESM1-CAM5_historical_rcp85_r13i1p1_192001-210012_g025.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_Amon_CESM1-CAM5_historical_rcp85_r8i1p1_192001-210012_g025.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_CESM1-CAM5_hist_rcp85_r23i1p1_192001-210012.nc',
 '/g/data/v45/nxm561/cesm1_lens/Amon/tas/tas_Amon_CESM1-CAM5_historical_rcp

In [14]:
# Full paths
file_paths = [os.path.join(LE_PATH, f) for f in file_list]

# Regex to extract ensemble member (e.g., r21i1p1)
def extract_member(filename):
    match = re.search(r'r\d+i\d+p\d+', filename)
    return match.group(0) if match else None

# Open datasets individually, add member coordinate, and store them in a list
ds_list = []
for file in file_paths:
    member = extract_member(file)
    ds = xr.open_dataset(file) #, chunks='auto')#={'time':-1, 'lon':144//12, 'lat':72//12})
    ds = ds.assign_coords(member=member).expand_dims('member')
    ds_list.append(ds['tas'])

# Concatenate along the 'member' dimension
ds_combined = xr.concat(ds_list, dim='member').chunk(
    {'member': -1, 'time':-1, 'lon':144//12, 'lat':72//12}) 

In [15]:
# ds_combined = ds_combined['tas']a

In [16]:
ds_combined

Unnamed: 0,Array,Chunk
Bytes,18.85 GiB,134.02 MiB
Shape,"(81, 3012, 72, 144)","(81, 3012, 6, 12)"
Dask graph,144 chunks in 1 graph layer,144 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 18.85 GiB 134.02 MiB Shape (81, 3012, 72, 144) (81, 3012, 6, 12) Dask graph 144 chunks in 1 graph layer Data type float64 numpy.ndarray",81  1  144  72  3012,

Unnamed: 0,Array,Chunk
Bytes,18.85 GiB,134.02 MiB
Shape,"(81, 3012, 72, 144)","(81, 3012, 6, 12)"
Dask graph,144 chunks in 1 graph layer,144 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [17]:
ds_combined = ds_combined.drop_vars('height', errors='ignore')

In [18]:
ds_combined

Unnamed: 0,Array,Chunk
Bytes,18.85 GiB,134.02 MiB
Shape,"(81, 3012, 72, 144)","(81, 3012, 6, 12)"
Dask graph,144 chunks in 1 graph layer,144 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 18.85 GiB 134.02 MiB Shape (81, 3012, 72, 144) (81, 3012, 6, 12) Dask graph 144 chunks in 1 graph layer Data type float64 numpy.ndarray",81  1  144  72  3012,

Unnamed: 0,Array,Chunk
Bytes,18.85 GiB,134.02 MiB
Shape,"(81, 3012, 72, 144)","(81, 3012, 6, 12)"
Dask graph,144 chunks in 1 graph layer,144 chunks in 1 graph layer
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [19]:
ds_combined = remove_duplicates_in_dim(ds_combined, 'member')

In [20]:
ds_combined

Unnamed: 0,Array,Chunk
Bytes,9.31 GiB,66.18 MiB
Shape,"(40, 3012, 72, 144)","(40, 3012, 6, 12)"
Dask graph,144 chunks in 2 graph layers,144 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 9.31 GiB 66.18 MiB Shape (40, 3012, 72, 144) (40, 3012, 6, 12) Dask graph 144 chunks in 2 graph layers Data type float64 numpy.ndarray",40  1  144  72  3012,

Unnamed: 0,Array,Chunk
Bytes,9.31 GiB,66.18 MiB
Shape,"(40, 3012, 72, 144)","(40, 3012, 6, 12)"
Dask graph,144 chunks in 2 graph layers,144 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [21]:
ds_combined = client.scatter(ds_combined).result().persist()

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


In [22]:
# ds_combined = ds_combined.chunk({'member': -1, 'time':-1, 'lon':144//12, 'lat':72//12}) 

In [23]:
# ds_combined = ds_combined.persist()
# wait(ds_combined)

In [24]:
ds_resamp_combined = ds_combined.resample(time='YE').mean()#.persist()
# wait(ds_resamp_combined)

In [25]:
ds_resamp_combined = client.scatter(ds_resamp_combined).result().persist()
wait(ds_resamp_combined)

DoneAndNotDoneFutures(done={<Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 3, 1)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 10, 7)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 4, 0)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 8, 10)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 4, 9)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 8, 0)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 4, 10)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 8, 9)>, <Future: finished, type: numpy.ndarray, key: ('transpose-b329562b90dfa79b7e20bfbfd9b937b8', 0, 0, 2, 10)>, <Future: finis

In [26]:
OUTPUT_FNAME = f'/g/data/w40/ab2313/time_of_emergence/{model}_{experiment}.zarr'

In [27]:
ds_resamp_combined.to_zarr(OUTPUT_FNAME, mode='w')

<xarray.backends.zarr.ZarrStore at 0x14c3e73fedc0>

# ERA5

In [25]:
# np.array(os.listdir('/g/data/rt52/era5/single-levels'))
# np.array(os.listdir('/g/data/rt52/era5/single-levels/monthly-averaged'))

In [26]:
# client.cluster.scale(100)
# sleep(6)

In [27]:
# # #'2t': 2m tempeature
# # # 'tp': total precipitation
# # var = '2t'
# # rsn: snow density
# # sd: snow depth
# variable = 'sd'
# original_name = variable#'2t'
# base_path = f'/g/data/rt52/era5/single-levels/reanalysis'
# path = os.path.join(base_path, original_name)
# path

In [28]:
# MY_ERA5_PATH = os.path.join(paths.DATA_DIR, 'era5')
# output_file = Path(os.path.join(MY_ERA5_PATH, f"{variable}.zarr"))
# output_file

In [29]:
# wild_tag = '*/*.nc'
# years = np.sort(os.listdir(path))
# files_to_open = np.array(glob(os.path.join(path, wild_tag), recursive=True))
# files_to_open.shape

In [30]:
# files_to_open[:2]

In [31]:
# td0 = xr.open_dataset(files_to_open[0])
# td = xr.open_dataset(files_to_open[-1])

In [32]:
# td[variable].attrs

In [33]:
# td0.sum(dim='time')[variable].plot()

In [34]:
# td.sum(dim='time')[variable].plot()

In [35]:
# climatology_files = np.array([
#     f for f in files_to_open 
#     if (
#         (year := int(f.split('/')[-2])) > base_period.start 
#         and year < base_period.end
#     )
# ])

In [36]:
# def __preprocess(ds):
#     return ds.to_array().rename({'latitude': 'lat', 'longitude': 'lon'})

In [37]:
# output_file_tmp = str(output_file).replace('.zarr', '_tmp.zarr')
# output_file_tmp

In [38]:
# clim_list = []
# for i, year in enumerate(np.arange(base_period.start, base_period.end)):
#     print(f'{year}, ', end='')
#     year_file = np.sort([f for f in files_to_open if str(year) in f.split('/')[-2]])
#     da_raw = xr.open_mfdataset(
#         year_file,
#         use_cftime=True,
#         chunks = CHUNKS_FOR_ERA5,
#         preprocess = __preprocess
#     )

#     # Annual averagte tempature
#     da = da_raw.resample(time="1D").mean()
    
#     da = da.squeeze("variable", drop=True)
#     da.name = original_name
#     clim_list.append(da)

In [39]:
# %%time
# base_period_ds = xr.concat(clim_list, dim='time').chunk(CHUNKS_FOR_ERA5)
# base_period_percentile_ds = base_period_ds.reduce(my_stats.dask_percentile, q=99.9, dim='time')
# base_period_percentile_ds = base_period_percentile_ds.persist()
# wait(base_period_percentile_ds)

In [40]:
# years_to_use = years #[years.astype(int) > 1981].astype(str)
# years_to_use

In [41]:
# for year in years_to_use: #years
#     print(f'{year}, ', end='')
#     year_file = np.sort([f for f in files_to_open if year in f.split('/')[-2]])
#     da_raw = xr.open_mfdataset(
#         year_file,
#         use_cftime=True,
#         chunks = chunks,
#         preprocess = __preprocess
#     )

#     # Annual averagte tempature OR mean snow density
#     # da = da_raw.resample(time='Y').mean().compute()

#     # Rx1d
#     # daily_resample = da_raw.resample(time="1D").sum()
#     # da = daily_resample.resample(time="1Y").max()

#     # Yearly total preci
#     # da = da_raw.resample(time='YE').sum()

#     # TX99Count
#     # daily_mean_da = da_raw.resample(time="1D").mean() # Daily mean tempeature
#     # da = (daily_mean_da > base_period_percentile_ds).resample(time='YE').sum()
#     #.groupby("time.year").sum()

#     da = da.chunk(chunks)
#     da = da.squeeze("variable", drop=True)
#     da.name = original_name
    
#     # # Save to Zarr
#     if year == years[0]:
#         # First year: Create the Zarr file
#         da.to_zarr(output_file_tmp, mode="w", consolidated=True)
#         print(f"Saved {year} to {output_file}")
#     else:
#         # Subsequent years: Append to the Zarr file
#         da.to_zarr(output_file_tmp, mode="a", append_dim="time")
#         print(f"Appended {year} to {output_file}")

In [42]:
# da.attrs

In [43]:
# (daily_mean_da > base_period_99p_ds).resample(time='YE').sum().compute().plot()

In [44]:
# output_file_tmp

In [45]:
# # Open all the files in the tmp file
# data_ds = (xr.open_zarr(output_file_tmp, use_cftime=True)
#            .to_array()
#            .compute()
#            .chunk(chunks=CHUNKS_FOR_ERA5))

In [46]:
# data_ds.time.dt.year

In [47]:
# data_ds = data_ds.squeeze('variable', drop=True)
# data_ds.name = variable

In [48]:
# data_ds = data_ds * 1000 # For precipitation

In [49]:
# data_ds.isel(time=16).plot(robust=True)

In [50]:
# data_ds.attrs = {
#     **data_ds.attrs,
#     "dataset_name": 'era5',
#     'variable': original_name,
#     'save_chunks' : CHUNKS_FOR_ERA5,
#     'info':("tp (total precipitation) reampled to daily"
#             "sum then resample to yearly max")
# }

In [51]:
# data_ds

In [52]:
# data_ds.to_zarr(output_file, mode='w')