In [1]:
import glob
import re

# `plt` conventionally names the pyplot interface; the top-level `matplotlib`
# package does not expose plotting functions such as `plt.plot` / `plt.subplots`.
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import xarray as xr

In [2]:
import dask
from dask.distributed import Client, performance_report
from dask_jobqueue import PBSCluster

In [3]:
# File paths
# Scratch area; the Dask spill directory and the Zarr output directory are both built from it
rda_scratch = "/gpfs/csfs1/collections/rda/scratch/harshah"
# Root of the data collection (the trailing slash is relied on by the concatenation below)
rda_data    = "/gpfs/csfs1/collections/rda/data/"
# Directory tree (ds633.0, e5.oper.an.sfc) that is searched for the input NetCDF files
era5_path   = rda_data + "ds633.0/e5.oper.an.sfc/"
# Directory holding the tas Zarr stores; ends with a slash because file names are appended directly
zarr_path   = rda_scratch + "/tas_zarr/"

In [4]:
# Collect every NetCDF file holding tas (2 m surface air temperature).
# "**" together with recursive=True lets the pattern descend into the
# sub-directories found underneath era5_path.
tas_pattern = era5_path + "**/e5.oper.an.sfc.128_167_2t.*.nc"
tas_ncfiles = list(glob.iglob(tas_pattern, recursive=True))

In [5]:
# Number of NetCDF files matched by the glob pattern
len(tas_ncfiles)

1009

In [6]:
# Inspect one matched path (glob order is not chronological; the list is sorted later)
tas_ncfiles[0]

'/gpfs/csfs1/collections/rda/data/ds633.0/e5.oper.an.sfc/202207/e5.oper.an.sfc.128_167_2t.ll025sc.2022070100_2022073123.nc'

In [7]:
# Settings for the PBS-backed Dask cluster: every job is one single-core
# worker process with a 32GiB Dask memory limit, submitted to the "casper"
# queue for at most five hours.
pbs_options = dict(
    job_name="dask-wk23-hpc",
    cores=1,
    memory="32GiB",
    processes=1,
    # Workers spill to scratch space when they run short of memory
    local_directory=rda_scratch + "/dask/spill",
    resource_spec="select=1:ncpus=1:mem=32GB",
    queue="casper",
    walltime="5:00:00",
    # Alternative network interface that was tried: interface = 'ib0'
    interface="ext",
)

# Create the PBS cluster object from the settings above
cluster = PBSCluster(**pbs_options)

In [8]:
# Attach a Dask client to the PBS cluster; evaluating `client` displays its summary
# (dashboard link, worker count, threads and memory)
client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/8787/status,

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://128.117.208.119:46095,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [9]:
# Sort key: pull the YYYYMM directory component out of a file path
def extract_date(filename):
    """Return the six-digit directory name embedded in ``filename``.

    The first path component consisting of exactly six digits (``/YYYYMM/``)
    is returned as a string. When the path contains no such component,
    ``filename`` itself is returned unchanged so it can still act as a sort key.
    """
    found = re.search(r"/(\d{6})/", filename)
    return found.group(1) if found else filename


# Sort the list chronologically using the YYYYMM key. The full path is used as
# a tie-breaker so that files sharing the same YYYYMM always come out in the
# same order, instead of inheriting the arbitrary order returned by glob.
sorted_ncfiles = sorted(tas_ncfiles, key=lambda path: (extract_date(path), path))
# Spot-check one entry of the sorted list
sorted_ncfiles[1]

'/gpfs/csfs1/collections/rda/data/ds633.0/e5.oper.an.sfc/194002/e5.oper.an.sfc.128_167_2t.ll025sc.1940020100_1940022923.nc'

In [10]:
# Default chunk shape (number of elements along each dimension) used when
# rewriting the NetCDF files as Zarr stores.
lat_chunksize = 139
lon_chunksize = 277
lon_chuksize = lon_chunksize  # original (misspelled) name, kept as an alias for backward compatibility
time_chunksize = 240  # 10 days' worth of time steps per chunk

##########


def process_batch(start_index, end_index, ncfile_list, zfile_path, chunks=None):
    """Rewrite one batch of NetCDF files as a single, rechunked Zarr store.

    The files ``ncfile_list[start_index:end_index]`` are opened lazily and
    concatenated along ``time``; their ``VAR_2T`` variable is rechunked and
    written to ``<zfile_path>tas2m_<start>_<end>.zarr``, where ``<start>`` and
    ``<end>`` are the ``extract_date`` keys of the first and last file of the
    batch. An existing store with the same name is overwritten (``mode="w"``).

    Parameters
    ----------
    start_index : int
        Index of the first file of the batch (inclusive).
    end_index : int
        Index one past the last file of the batch (exclusive).
    ncfile_list : list of str
        NetCDF file names, expected to be in chronological order.
    zfile_path : str
        Prefix of the output store; the store name is appended verbatim, so
        a directory must end with a path separator.
    chunks : dict, optional
        Mapping of dimension name to chunk size. Defaults to the module-level
        ``time_chunksize`` / ``lat_chunksize`` / ``lon_chunksize`` values.

    Raises
    ------
    ValueError
        If the batch is empty (``start_index >= end_index``).
    """
    # An empty batch would otherwise derive the end date from index
    # ``end_index - 1``, which is either the wrong file or a negative index
    # that silently wraps around to the end of the list.
    if start_index >= end_index:
        raise ValueError(
            f"Empty batch: start_index={start_index} must be smaller than end_index={end_index}"
        )

    if chunks is None:
        chunks = {
            "time": time_chunksize,
            "latitude": lat_chunksize,
            "longitude": lon_chunksize,
        }

    # Generate a Zarr store name based on the dates of the batch boundaries
    start_date = extract_date(ncfile_list[start_index])
    end_date = extract_date(ncfile_list[end_index - 1])
    zarr_store_name = zfile_path + f"tas2m_{start_date}_{end_date}.zarr"
    print(zarr_store_name)

    # Read the files of the current batch into a single xarray dataset. The
    # context manager closes the underlying file handles once the batch has
    # been written, so handles do not accumulate from one batch to the next.
    with xr.open_mfdataset(
        ncfile_list[start_index:end_index], combine="nested", concat_dim="time"
    ) as batch_dataset:
        # Rechunk so that (by default) one chunk holds 10 days' worth of data
        rechunked_dataset = batch_dataset.VAR_2T.chunk(chunks)

        # Save the combined data to a Zarr store; this must run while the
        # source files are still open.
        rechunked_dataset.to_zarr(zarr_store_name, mode="w")

## Let us rewrite these files into a zarr store

In [11]:
# Scale the cluster to n workers (here n = 10).
# NOTE(review): the cluster summary displayed right after this call still reports
# 0 workers, presumably because the PBS jobs had not started yet -- confirm that
# workers have joined before timing anything.
cluster.scale(10)

In [12]:
# Display the cluster summary (worker count, threads, memory)
cluster

0,1
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://128.117.208.119:46095,Workers: 0
Dashboard: https://jupyterhub.hpc.ucar.edu/stable/user/harshah/proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [13]:
# %%time
# # Specify batch size
# batch_size = 120
# for start_index in np.arange(0, len(sorted_ncfiles), batch_size):
#     # for start_index in np.arange(0,249,batch_size):
#     end_index = min(start_index + batch_size, len(sorted_ncfiles))
#     print(end_index)
#     # Process the current batch
#     process_batch(start_index, end_index, sorted_ncfiles, zarr_path)
#     print(f"Processed files {start_index} to {end_index}")

# print("All files have been processed.")

In [14]:
%%time
# Open all batch stores matching tas2m*.zarr as one dataset concatenated
# along time, then select the VAR_2T variable.
tas = xr.open_mfdataset(
    zarr_path + "tas2m*.zarr",
    engine="zarr",
    combine="nested",
    concat_dim="time",
)["VAR_2T"]
tas

CPU times: user 298 ms, sys: 89.7 ms, total: 387 ms
Wall time: 930 ms


Unnamed: 0,Array,Chunk
Bytes,2.78 TiB,35.25 MiB
Shape,"(737088, 721, 1440)","(240, 139, 277)"
Dask graph,110808 chunks in 19 graph layers,110808 chunks in 19 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 2.78 TiB 35.25 MiB Shape (737088, 721, 1440) (240, 139, 277) Dask graph 110808 chunks in 19 graph layers Data type float32 numpy.ndarray",1440  721  737088,

Unnamed: 0,Array,Chunk
Bytes,2.78 TiB,35.25 MiB
Shape,"(737088, 721, 1440)","(240, 139, 277)"
Dask graph,110808 chunks in 19 graph layers,110808 chunks in 19 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
%%time
# Aggregate the series into daily ("1D") means along the time dimension.
# NOTE(review): `tas` is Dask-backed (see its repr above), so this presumably only
# builds a task graph; the actual computation would happen when the result is written.
tas_daily = tas.resample(time="1D").mean()
tas_daily

In [None]:
%%time
# Rechunk the daily means before writing them out: chunks of 1000 time steps
# by 139 latitude points by 544 longitude points.
tas_daily = tas_daily.chunk(time=1000, latitude=139, longitude=544)
tas_daily

### Let us now compare the writing speed for NetCDF vs zarr formats

In [None]:
# %%time
# ## Generate performance report
# with performance_report(filename="e5_zarr_report.html"):
#     tas_daily.to_dataset().to_zarr(zarr_path + "e5_tas2m_daily_1940_2023.zarr", mode="w" )

In [None]:
%%time
## Write the daily means to NetCDF while recording a Dask performance report
with performance_report(filename="e5_nc_report.html"):
    # A DataArray is wrapped into a Dataset before being written to disk
    daily_dataset = tas_daily.to_dataset()
    daily_dataset.to_netcdf(zarr_path + "e5_tas2m_daily_1940_2023.nc", mode="w")