Ref: https://github.com/NCAR/cesm-lens-aws/issues/34

In [1]:
import xarray as xr
import intake
from tqdm.auto import tqdm
import shutil 
import os
from functools import reduce
import pprint
import json
from operator import mul
import random
import yaml
from distributed.utils import format_bytes
import dask
from dask_jobqueue import SLURMCluster
from distributed import Client
# Make dashboard links use the '/proxy/{port}/status' URL template
# (presumably so the dashboard is reachable through a Jupyter proxy — TODO confirm).
dask.config.set({'distributed.dashboard.link': '/proxy/{port}/status'})
# Echo the resulting 'distributed.dashboard' settings as the cell output.
dask.config.get('distributed.dashboard')

{'link': '/proxy/{port}/status', 'export-tool': False}

In [2]:
# Describe a single SLURM job (cores, memory, charging project); the cluster
# is then allowed to scale adaptively up to 200 such jobs.
cluster = SLURMCluster(
    cores=4,
    memory="120GB",
    project="NTDD0005",
)
cluster.adapt(maximum_jobs=200)

# Attach a client so subsequent dask work is submitted to this cluster.
client = Client(cluster)

# Display the cluster widget as the cell output.
cluster

VBox(children=(HTML(value='<h2>SLURMCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    …

In [3]:
def print_ds_info(ds, var):
    """Print a short shape/chunking/size summary for one variable of a dataset.

    Parameters
    ----------
    ds : dataset-like
        Object supporting ``ds[var]`` and ``ds.nbytes`` (an xarray Dataset in
        this notebook).
    var : str
        Name of the variable to summarise. ``ds[var].data`` must expose a
        ``chunksize`` attribute, as dask arrays do.

    Notes
    -----
    "Dataset size" is the size of the whole dataset (``ds.nbytes``), not of
    the single variable.
    """
    da = ds[var]
    chunk_shape = da.data.chunksize
    # Seed the product with 1 so that a zero-dimensional variable (empty chunk
    # shape) counts as one element instead of making reduce() raise TypeError.
    chunk_nbytes = reduce(mul, chunk_shape, 1) * da.dtype.itemsize
    chunk_size_bytes = format_bytes(chunk_nbytes)
    size = format_bytes(ds.nbytes)

    print(f'Variable name: {var}')
    print(f'Dataset dimensions: {da.dims}')
    print(f'Chunk shape: {chunk_shape}')
    print(f'Dataset shape: {da.shape}')
    print(f'Chunk size: {chunk_size_bytes}')
    print(f'Dataset size: {size}')
    

dirout = "/glade/scratch/abanihi/lens-aws"


def zarr_store(exp, cmp, frequency, var, write=False, dirout=dirout):
    """Build the zarr store path for one experiment/component/variable.

    The path has the form
    ``{dirout}/{cmp}/{frequency}/cesmLE-{exp}-{var}.zarr``. It is printed and
    then returned.

    When ``write`` is true and something already exists at that path, the
    existing directory tree is removed first.
    """
    store_path = f'{dirout}/{cmp}/{frequency}/cesmLE-{exp}-{var}.zarr'
    if write:
        # Clear out a previous store at the same location before it is rewritten.
        if os.path.exists(store_path):
            shutil.rmtree(store_path)
    print(store_path)
    return store_path


def save_data(ds, store):
    """Write ``ds`` to a consolidated zarr store without raising on failure.

    Parameters
    ----------
    ds : object with a ``to_zarr(store, consolidated=...)`` method
        The dataset to write (an xarray Dataset in this notebook).
    store : str
        Destination passed straight through to ``ds.to_zarr``.

    Returns
    -------
    bool
        True if the write completed, False if it raised. On failure a
        "Failed to write ..." message is printed, so a long batch of writes
        keeps going while the caller can still detect what went wrong.
    """
    try:
        ds.to_zarr(store, consolidated=True)
    except Exception as e:
        # Deliberately best-effort: report and carry on with the next store.
        print(f"Failed to write {store}: {e}")
        return False
    return True

In [4]:
# Open the intake-esm catalog from its JSON description. sep='_' matches how
# dataset keys are split on '_' further down in this notebook.
col = intake.open_esm_datastore("../catalogs/glade-campaign-cesm1-le.json", sep='_',)
# Display the catalog summary as the cell output.
col

<Intake-esm catalog with 7035 dataset(s) from 191066 asset(s)>

In [5]:
def process_variables(col, variables, experiment, verbose=False, component=None, stream=None):
    """Search the catalog for a set of variables in one experiment.

    Parameters
    ----------
    col : catalog
        Object with a ``search(**query)`` method (the intake-esm catalog).
    variables : list of str
        Variable names to search for.
    experiment : str
        Experiment to search in.
    verbose : bool, optional
        If true, print the unique variable/component/stream/experiment values
        of the search result.
    component, stream : str, optional
        Component and stream to search in. When omitted, the notebook-level
        globals ``component`` and ``stream`` are used, which keeps existing
        calls working; passing them explicitly avoids depending on whichever
        loop iteration last set those globals.

    Returns
    -------
    (subset, query)
        The search result and the query dict that produced it.

    Raises
    ------
    NameError
        If ``component`` or ``stream`` is neither passed nor defined globally.
    """
    if component is None or stream is None:
        scope = globals()
        try:
            if component is None:
                component = scope['component']
            if stream is None:
                stream = scope['stream']
        except KeyError as err:
            raise NameError(
                f"name {err.args[0]!r} is not defined; pass it to process_variables explicitly"
            ) from None

    query = dict(component=component, stream=stream, variable=variables, experiment=experiment)
    subset = col.search(**query)
    if verbose:
        print(subset.unique(columns=['variable', 'component', 'stream', 'experiment']))
    return subset, query

In [6]:
# Load the conversion plan. Layout, as read below:
# component -> stream -> {frequency, variable_category -> {variable, experiment -> {chunks}}}
with open("config.yaml") as f:
    config = yaml.safe_load(f)

run_config = []  # one entry per (stream, variable category, experiment)
variables = []   # every requested variable name; read later by preprocess()

for component, stream_val in config.items():
    for stream, v in stream_val.items():
        frequency = v['frequency']
        for v_cat, category in v['variable_category'].items():
            for exp, exp_settings in category['experiment'].items():
                chunks = exp_settings['chunks']
                variable = category['variable']
                variables.extend(variable)
                # `component` and `stream` reach process_variables through
                # the notebook's global namespace, so these loop names matter.
                col_subset, query = process_variables(col, variable, exp)
                run_config.append(
                    {
                        'query': json.dumps(query),
                        'col': col_subset,
                        'chunks': chunks,
                        'frequency': frequency,
                    }
                )

In [7]:
def enforce_chunking(datasets, chunks):
    """Rechunk every dataset and strip intake-esm bookkeeping attributes.

    Parameters
    ----------
    datasets : dict
        Mapping of dataset key -> dataset. The mapping itself is not modified;
        a shallow copy is filled with the rechunked datasets and returned.
    chunks : dict
        Mapping of dimension name -> chunk size. Dimensions that a given
        dataset does not have are ignored for that dataset.

    Returns
    -------
    dict
        Same keys as ``datasets``, with rechunked datasets as values. An empty
        input yields an empty dict.

    Notes
    -----
    A summary is printed for each dataset via ``print_ds_info``; the variable
    name is taken to be the last '_'-separated piece of the key. One randomly
    chosen dataset is additionally printed in full as a spot check.
    """
    dsets = datasets.copy()
    if not dsets:
        # Nothing to do; random.choice() would raise IndexError on an empty range.
        return dsets

    choice = random.choice(range(0, len(dsets)))
    for i, (key, ds) in enumerate(dsets.items()):
        # Only pass chunk sizes for dimensions this dataset actually has.
        c = {dim: size for dim, size in chunks.items() if dim in ds.dims}
        ds = ds.chunk(c)
        # Tolerate datasets that lack these attributes instead of raising KeyError.
        for k in ('intake_esm_dataset_key', 'intake_esm_varname'):
            ds.attrs.pop(k, None)
        dsets[key] = ds
        variable = key.split('_')[-1]
        print_ds_info(ds, variable)
        if i == choice:
            print(ds)
        print('\n')
    return dsets

In [8]:
def preprocess(ds):
    """Reduce a freshly opened dataset to the requested variables.

    Steps:
    1. Data variables that have no 'time' dimension, or whose name contains
       'bound', are turned into coordinates.
    2. Data variables not listed in the notebook-level ``variables`` list, and
       coordinates that are not a dimension of any remaining data variable,
       are dropped. 'time' and 'time_bound' are always kept.
    3. The 'history' global attribute is removed if present.

    Returns the reduced dataset.
    """
    unwanted_vars = [name for name in ds.data_vars if name not in variables]
    promote_to_coords = [
        name
        for name in ds.data_vars
        if 'time' not in ds[name].dims or 'bound' in name
    ]
    ds_fixed = ds.set_coords(promote_to_coords)

    # Every dimension name still used by a data variable after the promotion.
    used_dims = {
        dim
        for name in ds_fixed.data_vars
        for dim in ds_fixed[name].dims
    }
    unused_coords = [name for name in ds_fixed.coords if name not in used_dims]

    always_keep = {'time', 'time_bound'}
    grid_vars = list(set(unwanted_vars + unused_coords) - always_keep)
    ds_fixed = ds_fixed.drop(grid_vars)

    ds_fixed.attrs.pop('history', None)
    return ds_fixed

In [9]:
# Convert every configured query: load its datasets, rechunk them, and write
# one zarr store per dataset.
for run in run_config:
    print("*"*120)
    print(f"query = {run['query']}")
    frequency, chunks = run['frequency'], run['chunks']
    dsets = run['col'].to_dataset_dict(
        cdf_kwargs={'chunks': chunks, 'decode_times': False},
        preprocess=preprocess,
        progressbar=False,
    )
    dsets = enforce_chunking(dsets, chunks)
    for key, ds in tqdm(dsets.items(), desc='Saving zarr store'):
        # Keys are '_'-separated: component first, experiment second,
        # variable name last.
        key_parts = key.split('_')
        cmp, exp, var = key_parts[0], key_parts[1], key_parts[-1]
        # write=True clears any store already present at the target path.
        store = zarr_store(exp, cmp, frequency, var, write=True, dirout=dirout)
        save_data(ds, store)

************************************************************************************************************************
query = {"component": "ocn", "stream": "pop.h", "variable": ["SST", "SSH", "SFWF", "SHF"], "experiment": "CTRL"}
Variable name: SST
Dataset dimensions: ('member_id', 'time', 'z_t', 'nlat', 'nlon')
Chunk shape: (1, 300, 1, 384, 320)
Dataset shape: (1, 21612, 1, 384, 320)
Chunk size: 147.46 MB
Dataset size: 10.62 GB
<xarray.Dataset>
Dimensions:     (d2: 2, member_id: 1, nlat: 384, nlon: 320, time: 21612, z_t: 1)
Coordinates:
    time_bound  (time, d2) float64 dask.array<chunksize=(300, 2), meta=np.ndarray>
  * z_t         (z_t) float32 500.0
  * time        (time) float64 1.46e+05 1.461e+05 ... 8.033e+05 8.034e+05
  * member_id   (member_id) int64 1
Dimensions without coordinates: d2, nlat, nlon
Data variables:
    SST         (member_id, time, z_t, nlat, nlon) float32 dask.array<chunksize=(1, 300, 1, 384, 320), meta=np.ndarray>
Attributes:
    contents:               

HBox(children=(FloatProgress(value=0.0, description='Saving zarr store', max=4.0, style=ProgressStyle(descript…

/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-SST.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-SSH.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-SHF.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-SFWF.zarr

************************************************************************************************************************
query = {"component": "ocn", "stream": "pop.h", "variable": ["SALT", "TEMP", "UVEL", "VNS", "VNT", "VVEL", "WVEL"], "experiment": "CTRL"}
Variable name: WVEL
Dataset dimensions: ('member_id', 'time', 'z_w_top', 'nlat', 'nlon')
Chunk shape: (1, 6, 60, 384, 320)
Dataset shape: (1, 21612, 60, 384, 320)
Chunk size: 176.95 MB
Dataset size: 637.36 GB


Variable name: SALT
Dataset dimensions: ('member_id', 'time', 'z_t', 'nlat', 'nlon')
Chunk shape: (1, 6, 60, 384, 320)
Dataset shape: (1, 21612, 60, 384, 320)
Chunk size: 176.95 MB
Dataset size: 637.36 GB


Variable name: VNS
Dataset dimensions: ('member_id', 'time', 

HBox(children=(FloatProgress(value=0.0, description='Saving zarr store', max=7.0, style=ProgressStyle(descript…

/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-WVEL.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-SALT.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-VNS.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-VVEL.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-TEMP.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-VNT.zarr
/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-CTRL-UVEL.zarr



In [10]:
from pathlib import Path

# Recursively collect every '*.zarr' entry under the 'ocn' output directory.
p = Path(dirout).joinpath('ocn')
stores = [*p.rglob("*.zarr")]

# stores = ["/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-SALT.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-VVEL.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-VNT.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-UVEL.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-TEMP.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-WVEL.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-VNS.zarr",
#          "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-RCP85-SALT.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-RCP85-VVEL.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-RCP85-VNT.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-RCP85-UVEL.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-RCP85-TEMP.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-RCP85-WVEL.zarr",
# "/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-RCP85-VNS.zarr"]
# Sanity check: try to open each store and print its summary.
for store in stores:
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
        print('\n')
        print(store)
        print(ds)
    except Exception as e:
        # Keep scanning the remaining stores, but report which one failed and
        # why, so a failure is not mistaken for a store that opened cleanly.
        print(f"Failed to open {store}: {type(e).__name__}: {e}")



/glade/scratch/abanihi/lens-aws/ocn/monthly/cesmLE-20C-SSH.zarr
<xarray.Dataset>
Dimensions:     (d2: 2, member_id: 40, nlat: 384, nlon: 320, time: 1872)
Coordinates:
  * member_id   (member_id) int64 1 2 3 4 5 6 7 8 ... 34 35 101 102 103 104 105
  * time        (time) object 1850-02-01 00:00:00 ... 2006-01-01 00:00:00
    time_bound  (time, d2) object dask.array<chunksize=(300, 2), meta=np.ndarray>
Dimensions without coordinates: d2, nlat, nlon
Data variables:
    SSH         (member_id, time, nlat, nlon) float32 dask.array<chunksize=(1, 300, 384, 320), meta=np.ndarray>
Attributes:
    Conventions:               CF-1.0; http://www.cgd.ucar.edu/cms/eaton/netc...
    calendar:                  All years have exactly  365 days.
    cell_methods:              cell_methods = time: mean ==> the variable val...
    contents:                  Diagnostic and Prognostic Variables
    nco_openmp_thread_number:  1
    nsteps_total:              750
    revision:                  $Id: tavg.F90 4

In [11]:
# cluster.close()

In [14]:
# %load_ext watermark
# %watermark -d -iv -m -g -h