Ref: https://github.com/NCAR/cesm-lens-aws/issues/34

In [None]:
%load_ext autoreload
%autoreload 2

import json
import os
import pprint
import random
import shutil
from functools import reduce
from operator import mul

import xarray as xr
import yaml
from distributed import Client
from distributed.utils import format_bytes
from tqdm.auto import tqdm

import dask
import intake
from dask_jobqueue import SLURMCluster
from helpers import (create_grid_dataset, enforce_chunking, get_grid_vars,
                     print_ds_info, process_variables, save_data, zarr_store)

# Build dashboard links as a relative "/proxy/<port>/status" path instead of the
# scheduler's host:port. NOTE(review): presumably so the dashboard is reachable
# through the notebook server's proxy (e.g. jupyter-server-proxy) — confirm.
dask.config.set({"distributed.dashboard.link": "/proxy/{port}/status"})

In [None]:
# Adaptive SLURM-backed Dask cluster: every job requests 4 cores and 200GB of
# memory, and the number of jobs scales between 0 and 35 with the workload.
cluster = SLURMCluster(memory="200GB", cores=4)
cluster.adapt(maximum_jobs=35, minimum_jobs=0)
# Fixed-size alternative to adaptive scaling: cluster.scale(jobs=3)
client = Client(cluster)
cluster  # last expression of the cell: shows the cluster widget

In [None]:
# Flip BIG_SAVE to True when writing large Zarr stores ends in KilledWorker errors
# or Dask crashes: the cell then blocks until `min_workers` workers have joined.
BIG_SAVE = False
if BIG_SAVE:
    min_workers = 10
    client.wait_for_workers(n_workers=min_workers)

In [None]:
# Separator used to join/split the fields of catalog keys. A dash is used
# rather than an underscore because CESM variable names contain underscores,
# which would make underscore-separated keys ambiguous.
field_separator = "-"
col = intake.open_esm_datastore("../catalogs/glade-campaign-cesm1-le.json", sep=field_separator)
col  # last expression of the cell: shows a catalog summary

In [None]:
def _preprocess(ds):
    """Drop grid variables/coordinates and the ``history`` attribute from ``ds``.

    Used as the per-file ``preprocess`` hook passed to ``to_dataset_dict``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened from a single file.

    Returns
    -------
    xarray.Dataset
        A new dataset without the names returned by ``get_grid_vars`` and
        without a ``history`` global attribute.

    Notes
    -----
    Reads the module-level ``variables`` list (built by the config-parsing
    cell), so that cell must have run before this hook is invoked.
    """
    grid_vars = get_grid_vars(ds, variables)
    # ``Dataset.drop`` is deprecated for removing variables; ``drop_vars`` is
    # the supported equivalent and, like ``drop``, raises on unknown names.
    ds_fixed = ds.drop_vars(grid_vars)
    # Remove the ``history`` attribute when present; no-op otherwise.
    ds_fixed.attrs.pop("history", None)
    return ds_fixed


def _assign_reference_time(ds, reference_path):
    """Replace ``ds.time`` with the time axis of ``reference_path`` if it starts at 0.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened with ``decode_times=False`` (as the run loop does), so
        ``time`` holds raw numeric values.
    reference_path : str
        File whose ``time`` coordinate is used as the replacement. It is opened
        with ``decode_times=False`` so its raw values match ``ds``.

    Returns
    -------
    xarray.Dataset
        ``ds`` unchanged when ``ds.time.min() != 0``; otherwise a copy of ``ds``
        with ``time`` taken from the reference file.
    """
    if ds.time.min() != 0:
        return ds
    # Only open the reference file when a fix is actually needed, and close it
    # once its time values are in memory rather than leaking one open file
    # handle per preprocessed file.
    with xr.open_dataset(reference_path, chunks={"time": 2}, decode_times=False) as ref:
        ref_time = ref.time.load()
    # NOTE(review): assign() requires the reference time axis to have the same
    # length as ds.time — presumably true for the files this targets; confirm.
    return ds.assign(time=ref_time)


def _preprocess_ice_nh(ds):
    """Preprocess hook for 20C ``cice.h1`` northern-hemisphere files.

    Replaces a time axis that starts at 0 (written to patch ensemble member
    035's NH files) with member 031's NH time axis, then applies
    ``_preprocess``.
    """
    member_31 = "/glade/campaign/cesm/collections/cesmLE/CESM-CAM5-BGC-LE/ice/proc/tseries/daily/hi_d/b.e11.B20TRC5CNBDRD.f09_g16.031.cice.h1.hi_d_nh.19200101-20051231.nc"
    return _preprocess(_assign_reference_time(ds, member_31))


def _preprocess_ice_sh(ds):
    """Preprocess hook for 20C ``cice.h1`` southern-hemisphere files.

    Replaces a time axis that starts at 0 (written to patch ensemble member
    035's SH files) with member 031's SH time axis, then applies
    ``_preprocess``.
    """
    member_31 = "/glade/campaign/cesm/collections/cesmLE/CESM-CAM5-BGC-LE/ice/proc/tseries/daily/hi_d/b.e11.B20TRC5CNBDRD.f09_g16.031.cice.h1.hi_d_sh.19200101-20051231.nc"
    return _preprocess(_assign_reference_time(ds, member_31))


def _preprocess_lnd(ds):
    """Overwrite ``lat`` with the land-grid latitude, then apply ``_preprocess``.

    The latitude comes from ``create_grid_dataset`` built on member 001's RCP85
    ``clm2.h1`` FSNO file. The run loop uses this hook for the RCP85
    ``clm2.h1`` stream.
    """
    grid_source = "/glade/campaign/cesm/collections/cesmLE/CESM-CAM5-BGC-LE/lnd/proc/tseries/daily/FSNO/b.e11.BRCP85C5CNBDRD.f09_g16.001.clm2.h1.FSNO.20060101-20801231.nc"
    grid = create_grid_dataset(grid_source, variables=["FSNO"])
    with_lat = ds.assign({"lat": grid["lat"]})
    return _preprocess(with_lat)

In [None]:
# Walk config.yaml (component -> stream -> variable_category -> experiment) and
# record one run per combination whose catalog subset is non-empty.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

run_config = []
variables = []  # every configured variable name; read by _preprocess

for component, streams in config.items():
    for stream, stream_cfg in streams.items():
        frequency = stream_cfg["frequency"]
        for category_cfg in stream_cfg["variable_category"].values():
            for exp, exp_cfg in category_cfg["experiment"].items():
                chunks = exp_cfg["chunks"]
                variable = category_cfg["variable"]
                # Extended once per experiment, so names may repeat here.
                variables.extend(variable)
                col_subset, query = process_variables(
                    col, variable, component, stream, exp
                )
                if col_subset.df.empty:
                    continue
                run_config.append(
                    {
                        "query": query,
                        "col": col_subset,
                        "chunks": chunks,
                        "frequency": frequency,
                    }
                )

In [None]:
dirout = "/glade/scratch/abanihi/lens-aws"

# For every configured run: open the matching files, re-chunk them, and write
# one Zarr store per resulting dataset under `dirout`.
for run in run_config:
    print("*" * 120)
    query = run["query"]
    print(f"query = {query}")
    frequency = run["frequency"]
    chunks = run["chunks"]

    # Choose the per-file preprocess hook. Only 20C cice.h1 (per hemisphere)
    # and RCP85 clm2.h1 need special handling; everything else uses the default.
    preprocess = _preprocess
    if query["experiment"] == "20C" and query["stream"] == "cice.h1":
        if query["component"] == "ice_sh":
            preprocess = _preprocess_ice_sh
        elif query["component"] == "ice_nh":
            # NH files must be patched with the NH reference time axis.
            preprocess = _preprocess_ice_nh
    elif query["experiment"] == "RCP85" and query["stream"] == "clm2.h1":
        preprocess = _preprocess_lnd

    print(f"using preprocess={preprocess.__name__}")
    dsets = run["col"].to_dataset_dict(
        cdf_kwargs={"chunks": chunks, "decode_times": False},
        preprocess=preprocess,
        progressbar=False,
    )
    dsets = enforce_chunking(dsets, chunks, field_separator)
    for key, ds in tqdm(dsets.items(), desc="Saving zarr store"):
        # Keys are read as: field 0 = component, field 1 = experiment,
        # last field = variable name.
        key_fields = key.split(field_separator)
        cmp, exp, var = key_fields[0], key_fields[1], key_fields[-1]
        store = zarr_store(exp, cmp, frequency, var, write=True, dirout=dirout)
        save_data(ds, store)

In [None]:
# Make sure the zarr stores were properly written: try to open every Zarr store
# under <dirout>/<component_to_check>, print the ones that open, and report
# each failure together with its error.

from pathlib import Path

component_to_check = "lnd"
p = Path(dirout) / component_to_check
stores = sorted(p.rglob("*.zarr"))  # sorted for a deterministic report order
failed = []
for store in stores:
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
    except Exception as e:  # top-level check: record and keep going
        failed.append(store)
        print(f"FAILED to open {store}: {e!r}")
        continue
    print("\n")
    print(store)
    print(ds)

print(f"\n{len(stores) - len(failed)}/{len(stores)} stores opened successfully")

In [None]:
# Disconnect the Dask client first so it does not keep trying to reach a
# scheduler that is being shut down, then release the cluster's SLURM jobs.
client.close()
cluster.close()

In [None]:
# %load_ext watermark
# %watermark -d -iv -m -g -h