Ref: https://github.com/NCAR/cesm-lens-aws/issues/34

In [None]:
import json
import os
import pprint
import random
import shutil
from functools import reduce
from operator import mul

import xarray as xr
import yaml
from tqdm.auto import tqdm

import dask
import intake
from dask_jobqueue import SLURMCluster
from distributed import Client
from distributed.utils import format_bytes

# Override the dashboard link template in the Dask configuration so links are
# built as "/proxy/<port>/status" (presumably for a notebook proxy setup —
# TODO confirm against the deployment).
dask.config.set({"distributed.dashboard.link": "/proxy/{port}/status"})

In [None]:
# Create a SLURM-backed Dask cluster; jobs are configured with 4 cores and
# 200GB of memory and are charged to project "STDD0003".
cluster = SLURMCluster(cores=4, memory="200GB", project="STDD0003")
# Let the cluster scale adaptively between 2 and 4 jobs.
cluster.adapt(minimum_jobs=2, maximum_jobs=4)
# cluster.scale(jobs=3)
# Connect a client to the cluster; `client` is used later to wait for workers.
client = Client(cluster)
# Bare expression: shows the cluster's representation when run as a notebook cell.
cluster

In [None]:
# Set to True if saving large Zarr files is resulting in KilledWorker or Dask crashes.
BIG_SAVE = False
if BIG_SAVE:
    # Block until at least `min_workers` workers have connected to the client.
    # NOTE(review): the cluster is adapted with maximum_jobs=4 — confirm that
    # 10 workers can actually be reached with that limit, otherwise this wait
    # may never return.
    min_workers = 10
    client.wait_for_workers(min_workers)

In [None]:
def print_ds_info(ds, var):
    """Print a chunking summary for variable ``var`` of dataset ``ds``.

    The per-chunk size in bytes is the number of elements in one chunk
    (product of the chunk shape) times the item size of the variable's dtype.
    The reported "Dataset size" is ``ds.nbytes``, i.e. the size of the whole
    dataset rather than of ``var`` alone.
    """
    itemsize = ds[var].dtype.itemsize
    chunk_size = ds[var].data.chunksize
    size = format_bytes(ds.nbytes)

    # Elements per chunk -> bytes per chunk -> human-readable string.
    elements_per_chunk = reduce(mul, chunk_size)
    chunk_size_bytes = format_bytes(elements_per_chunk * itemsize)

    print(f"Variable name: {var}")
    print(f"Dataset dimensions: {ds[var].dims}")
    print(f"Chunk shape: {chunk_size}")
    print(f"Dataset shape: {ds[var].shape}")
    print(f"Chunk size: {chunk_size_bytes}")
    print(f"Dataset size: {size}")


# Root directory below which the Zarr stores are written.
dirout = "/glade/work/bonnland/lens-aws"


def zarr_store(exp, cmp, frequency, var, write=False, dirout=dirout):
    """Build (and print) the path of the Zarr store for one variable.

    The path has the form ``{dirout}/{cmp}/{frequency}/cesmLE-{exp}-{var}.zarr``.
    When ``write`` is true and something already exists at that path, the
    existing directory tree is removed first. The path is returned as a string.
    """
    parent = f"{dirout}/{cmp}/{frequency}"
    path = f"{parent}/cesmLE-{exp}-{var}.zarr"

    if write:
        # Clear out a previous store; nothing to do when the path is absent.
        if os.path.exists(path):
            shutil.rmtree(path)

    print(path)
    return path


def save_data(ds, store):
    """Write ``ds`` to the Zarr store ``store`` with consolidated metadata.

    Any exception raised by the write is caught and reported with a printed
    message instead of being propagated, so a caller looping over many
    datasets keeps going after a failure. Returns ``None`` either way.
    """
    try:
        ds.to_zarr(store, consolidated=True)
    except Exception as err:
        print(f"Failed to write {store}: {err}")
        return
    # Drop this function's own reference to the dataset once it is written.
    del ds

In [None]:
# It's safer to use a dash '-' to separate fields, not underscores, because CESM variables have underscores.
# The same separator is used further down to split dataset keys back into fields.
field_separator = "-"
# Open the intake-esm catalog described by the JSON file, using the separator above.
col = intake.open_esm_datastore(
    "../catalogs/glade-campaign-cesm1-le.json", sep=field_separator,
)
# Bare expression: shows the catalog's representation when run as a notebook cell.
col

In [None]:
def process_variables(col, variables, experiment, verbose=False):
    """Search the catalog ``col`` for the given variables and experiment.

    NOTE: ``component`` and ``stream`` are not parameters; they are read from
    the module-level names ``component`` and ``stream`` at call time, so both
    must be defined before this function is called.

    When ``verbose`` is true, the unique values of the variable, component,
    stream and experiment columns of the search result are printed.

    Returns a ``(subset, query)`` tuple: the search result and the query
    dictionary that produced it.
    """
    query = {
        "component": component,
        "stream": stream,
        "variable": variables,
        "experiment": experiment,
    }
    subset = col.search(**query)
    if not verbose:
        return subset, query

    summary_columns = ["variable", "component", "stream", "experiment"]
    print(subset.unique(columns=summary_columns))
    return subset, query

In [None]:
# Load the processing configuration. The loops below read it as
# component -> stream -> {"frequency", "variable_category"}, where each
# variable category holds a "variable" list and an "experiment" mapping whose
# entries carry the "chunks" to use.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

# One entry per (component, stream, variable category, experiment).
run_config = []
# Every variable name requested in the configuration; entries are appended
# once per experiment, so names may repeat.
variables = []

for component, stream_val in config.items():
    for stream, v in stream_val.items():
        frequency = v["frequency"]
        for v_cat, category in v["variable_category"].items():
            for exp, exp_settings in category["experiment"].items():
                chunks = exp_settings["chunks"]
                variable = category["variable"]
                variables.extend(variable)
                # process_variables reads the current `component` and `stream`
                # loop variables from module scope.
                col_subset, query = process_variables(col, variable, exp)
                run_config.append(
                    {
                        "query": json.dumps(query),
                        "col": col_subset,
                        "chunks": chunks,
                        "frequency": frequency,
                    }
                )

In [None]:
def enforce_chunking(datasets, chunks):
    """Enforce uniform chunking

    Rechunk every dataset in ``datasets`` with ``chunks`` (restricted to the
    dimensions each dataset actually has), remove the intake-esm bookkeeping
    attributes when present, and print chunking information for each dataset.
    One dataset, picked at random, is additionally printed in full.

    Parameters
    ----------
    datasets : dict
        Mapping of dataset key -> dataset. The last ``field_separator``
        delimited field of each key is used as the variable name.
    chunks : dict
        Mapping of dimension name -> chunk size. It is not modified.

    Returns
    -------
    dict
        A new dict with the same keys holding the rechunked datasets. An
        empty ``datasets`` yields an empty dict.
    """
    dsets = datasets.copy()
    if not dsets:
        # No datasets to process; random.choice below would raise IndexError
        # on an empty range.
        return dsets

    # Index of the single dataset that gets printed in full.
    choice = random.choice(range(len(dsets)))
    keys_to_delete = ("intake_esm_dataset_key", "intake_esm_varname")
    for i, (key, ds) in enumerate(dsets.items()):
        # Only request chunk sizes for dimensions this dataset has.
        c = {dim: size for dim, size in chunks.items() if dim in ds.dims}
        ds = ds.chunk(c)
        for k in keys_to_delete:
            # Tolerate datasets that do not carry these attributes.
            ds.attrs.pop(k, None)
        dsets[key] = ds
        variable = key.split(field_separator)[-1]
        print_ds_info(ds, variable)
        if i == choice:
            print(ds)
        print("\n")
    return dsets

In [None]:
def preprocess(ds):
    """Drop all unnecessary variables and coordinates

    Keeps only the data variables listed in the module-level ``variables``
    plus the coordinates whose names are dimensions of the remaining data
    variables; ``time`` and ``time_bound`` are never dropped. The global
    ``history`` attribute is removed when present.
    """
    unwanted = [name for name in ds.data_vars if name not in variables]

    # Variables without a time dimension, and any "bound" variables, are
    # treated as coordinates rather than data.
    promoted = [
        name
        for name in ds.data_vars
        if "bound" in name or "time" not in ds[name].dims
    ]
    ds_fixed = ds.set_coords(promoted)

    # Dimension names still used by the remaining data variables.
    used_dims = set()
    for name in ds_fixed.data_vars:
        used_dims.update(ds_fixed[name].dims)

    unused_coords = [
        coord for coord in ds_fixed.coords if coord not in used_dims
    ]
    protected = {"time", "time_bound"}
    grid_vars = list(set(unwanted + unused_coords) - protected)
    ds_fixed = ds_fixed.drop(grid_vars)

    ds_fixed.attrs.pop("history", None)
    return ds_fixed

In [None]:
# For every configured run: load the matching datasets, rechunk them, and
# write each one to its own Zarr store.
for run in run_config:
    print("*" * 120)
    print(f"query = {run['query']}")
    frequency, chunks = run["frequency"], run["chunks"]
    open_kwargs = {"chunks": chunks, "decode_times": False}
    dsets = enforce_chunking(
        run["col"].to_dataset_dict(
            cdf_kwargs=open_kwargs,
            preprocess=preprocess,
            progressbar=False,
        ),
        chunks,
    )
    for key, ds in tqdm(dsets.items(), desc="Saving zarr store"):
        # The key is split on the field separator: the first field is used as
        # the component, the second as the experiment, the last as the variable.
        fields = key.split(field_separator)
        cmp, exp, var = fields[0], fields[1], fields[-1]
        # write=True removes any existing store at the target path first.
        store = zarr_store(exp, cmp, frequency, var, write=True, dirout=dirout)
        save_data(ds, store)

In [None]:
# Make sure the zarr stores were properly written

from pathlib import Path

# NOTE(review): only stores below the "ocn" sub-directory are checked —
# confirm whether the other components should be verified as well.
p = Path(dirout) / "ocn"
stores = list(p.rglob("*.zarr"))
for store in stores:
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
        print("\n")
        print(store)
        print(ds)
    except Exception as e:
        # Keep checking the remaining stores, but report why this one failed
        # so a broken store is not mistaken for a successful check.
        print(f"FAILED to open {store}: {type(e).__name__}: {e}")

In [None]:
# Shut down the Dask cluster created at the top of the notebook.
# NOTE(review): `client` is not closed explicitly here — confirm that is intended.
cluster.close()

In [None]:
# %load_ext watermark
# %watermark -d -iv -m -g -h