Ref: https://github.com/NCAR/cesm-lens-aws/issues/34


In [None]:
%load_ext autoreload
%autoreload 2

import json
import os
import pprint
import random
import shutil
from functools import reduce, partial
from operator import mul

import xarray as xr
import yaml
from distributed import Client
from distributed.utils import format_bytes
from tqdm.auto import tqdm
import pandas as pd
from collections import Counter

import dask
import intake
from dask_jobqueue import SLURMCluster
from helpers import (create_grid_dataset, enforce_chunking, get_grid_vars,
                     print_ds_info, process_variables, save_data, zarr_store, fix_time, inspect_written_stores)

#dask.config.set({"distributed.dashboard.link": "/proxy/{port}/status"})
xr.set_options(keep_attrs=True)
import numpy as np

In [None]:
# Run dask workers as SLURM batch jobs. Each job is requested with 8 cores
# and 200GB of memory and is split into 4 worker processes.
cluster = SLURMCluster(cores=8, memory="200GB", processes=4)
# Let the cluster grow and shrink on demand, between 1 and 35 jobs.
cluster.adapt(minimum_jobs=1, maximum_jobs=35)
# cluster.scale(jobs=3)
client = Client(cluster)
# Bare expression: shows the cluster as the notebook cell output.
cluster

In [None]:
# # Set to True if saving large Zarr files is resulting in KilledWorker or Dask crashes.
# BIG_SAVE = False
# if BIG_SAVE:
#     min_workers = 10
#     client.wait_for_workers(min_workers)

In [None]:
# It's safer to use a dash '-' to separate fields, not underscores, because CESM variables have underscores.
field_separator = "-"
# Open the catalog with that separator; the same separator is used later to
# split the keys of the dictionary returned by `to_dataset_dict`.
col = intake.open_esm_datastore(
    "../catalogs/glade-campaign-cesm1-le.json", sep=field_separator,
)
# Bare expression: shows the catalog as the notebook cell output.
col

In [None]:
# Root output directory: passed as `dirout` to `zarr_store` in the main loop
# and as the first argument to `inspect_written_stores` afterwards.
dirout = "/glade/scratch/abanihi/data/lens-aws"

In [None]:
def _preprocess(ds):
    """Strip a freshly opened dataset down to what is needed.

    Two groups of names are removed:

    * data variables that are not listed in the notebook-level ``variables``
      list, and
    * coordinates whose name is not a dimension of any time-dependent,
      non-bounds data variable.

    ``time`` and the time-bounds names (``time_bound``, ``time_bnds``,
    ``time_bounds``) are never removed.  The global ``history`` attribute is
    deleted when present.  Returns the trimmed dataset with its remaining
    non-index coordinates turned back into data variables.
    """
    protected = {"time", "time_bound", "time_bnds", "time_bounds"}

    # Data variables nobody asked for.
    unwanted_vars = [name for name in ds.data_vars if name not in variables]

    def _is_coord_like(name):
        # Time-independent fields and bounds variables are treated as
        # coordinates rather than data.
        return "time" not in ds[name].dims or "bound" in name or "bnds" in name

    promoted = ds.set_coords(
        [name for name in ds.data_vars if _is_coord_like(name)]
    )

    # Every dimension that a remaining data variable actually uses.
    used_dims = set()
    for name in promoted.data_vars:
        used_dims.update(promoted[name].dims)

    unused_coords = [name for name in promoted.coords if name not in used_dims]

    to_drop = (set(unwanted_vars) | set(unused_coords)) - protected
    trimmed = promoted.drop(list(to_drop)).reset_coords()
    trimmed.attrs.pop("history", None)
    return trimmed


# The `hi_d_nh` file of member 031: `_preprocess_ice_nh` copies its `time`
# and `time_bounds` onto member 035. Opened lazily with dask (`chunks={}`).
member_31_nh = (
    "/glade/campaign/cesm/collections/cesmLE/CESM-CAM5-BGC-LE/ice/proc/tseries/daily/hi_d/"
    "b.e11.B20TRC5CNBDRD.f09_g16.031.cice.h1.hi_d_nh.19200101-20051231.nc"
)
m_31_nh = xr.open_dataset(member_31_nh, chunks={})


def _preprocess_ice_nh(ds):
    """Repair member 035's time axis, then apply `_preprocess`.

    When the dataset's ``title`` attribute identifies member 035, its
    ``time`` coordinate and ``time_bounds`` values are replaced with those of
    the module-level ``m_31_nh`` dataset.  The original ``attrs`` and
    ``encoding`` of both variables are saved first and restored afterwards.
    Any other dataset is passed to `_preprocess` unchanged.

    Raises ``KeyError`` if the dataset has no ``title`` attribute.
    """
    # Fix time duplication issues in member_35_ice_nh
    if ds.attrs["title"] == "b.e11.B20TRC5CNBDRD.f09_g16.035":
        # Save member 035's own metadata before the coordinate is swapped.
        attrs = ds.time.attrs
        encoding = ds.time.encoding
        bounds_attrs = ds.time_bounds.attrs
        # Fixed: this used to read `.attrs`, so the attributes dict was
        # written back as the *encoding* of `time_bounds`.
        bounds_encoding = ds.time_bounds.encoding

        ds = ds.assign_coords(time=m_31_nh.time)
        ds.time_bounds.data = m_31_nh.time_bounds.data
        # Restore the saved metadata on the replaced variables.
        ds.time.attrs, ds.time.encoding = attrs, encoding
        ds.time_bounds.attrs, ds.time_bounds.encoding = (
            bounds_attrs,
            bounds_encoding,
        )
    return _preprocess(ds)


# The `hi_d_sh` file of member 031: `_preprocess_ice_sh` copies its `time`
# and `time_bounds` onto member 035. Opened lazily with dask (`chunks={}`).
member_31_sh = (
    "/glade/campaign/cesm/collections/cesmLE/CESM-CAM5-BGC-LE/ice/proc/tseries/daily/hi_d/"
    "b.e11.B20TRC5CNBDRD.f09_g16.031.cice.h1.hi_d_sh.19200101-20051231.nc"
)
m_31_sh = xr.open_dataset(member_31_sh, chunks={})


def _preprocess_ice_sh(ds):
    """Repair member 035's time axis, then apply `_preprocess`.

    When the dataset's ``title`` attribute identifies member 035, its
    ``time`` coordinate and ``time_bounds`` values are replaced with those of
    the module-level ``m_31_sh`` dataset.  The original ``attrs`` and
    ``encoding`` of both variables are saved first and restored afterwards.
    Any other dataset is passed to `_preprocess` unchanged.

    Raises ``KeyError`` if the dataset has no ``title`` attribute.
    """
    # Fix time duplication issues in member_35_ice_sh
    if ds.attrs["title"] == "b.e11.B20TRC5CNBDRD.f09_g16.035":
        # Save member 035's own metadata before the coordinate is swapped.
        attrs = ds.time.attrs
        encoding = ds.time.encoding
        bounds_attrs = ds.time_bounds.attrs
        # Fixed: this used to read `.attrs`, so the attributes dict was
        # written back as the *encoding* of `time_bounds`.
        bounds_encoding = ds.time_bounds.encoding

        ds = ds.assign_coords(time=m_31_sh.time)
        ds.time_bounds.data = m_31_sh.time_bounds.data
        # Restore the saved metadata on the replaced variables.
        ds.time.attrs, ds.time.encoding = attrs, encoding
        ds.time_bounds.attrs, ds.time_bounds.encoding = (
            bounds_attrs,
            bounds_encoding,
        )
    return _preprocess(ds)


def _preprocess_lnd(ds):
    """Replace ``lat`` with the land grid's latitudes, then run `_preprocess`.

    The reference latitudes are read from the static land grid Zarr store on
    every call.
    """
    # NOTE(review): presumably done so all files share identical latitude
    # values — confirm.
    grid_store = "/glade/scratch/abanihi/data/lens-aws/lnd/static/grid.zarr"
    reference_lat = xr.open_zarr(grid_store)["lat"]
    return _preprocess(ds.assign_coords(lat=reference_lat))


def _preprocess_atm(ds):
    """Replace ``lat`` with the atmosphere grid's latitudes, then run `_preprocess`.

    The reference latitudes are read from the static atmosphere grid Zarr
    store on every call.
    """
    # NOTE(review): presumably done so all files share identical latitude
    # values — confirm.
    grid_store = "/glade/scratch/abanihi/data/lens-aws/atm/static/grid.zarr"
    reference_lat = xr.open_zarr(grid_store)["lat"]
    return _preprocess(ds.assign_coords(lat=reference_lat))

In [None]:
# Build the list of runs from config.yaml. `variables` accumulates every
# selected variable name and is read by `_preprocess`.
with open("config.yaml") as f:
    config = yaml.safe_load(f)

run_config = []
variables = []

for component, streams in config.items():
    for stream, stream_cfg in streams.items():
        frequency = stream_cfg["frequency"]
        freq = stream_cfg["freq"]
        time_bounds_dim = stream_cfg["time_bounds_dim"]
        for category in stream_cfg["variable_category"].values():
            for exp, exp_cfg in category["experiment"].items():
                # This run is restricted to daily atmosphere output of the
                # 20C experiment; everything else is skipped.
                if not (
                    frequency == "daily" and exp == "20C" and component == "atm"
                ):
                    continue
                # date_ranges = ['1990010100Z-2005123118Z', '2026010100Z-2035123118Z', '2071010100Z-2080123118Z']
                chunks = exp_cfg["chunks"]
                variable = category["variable"]
                variables.extend(variable)
                col_subset, query = process_variables(
                    col, variable, component, stream, exp
                )
                # Nothing in the catalog matched this query.
                if col_subset.df.empty:
                    continue
                run_config.append(
                    {
                        "query": query,
                        "col": col_subset,
                        "chunks": chunks,
                        "frequency": frequency,
                        "freq": freq,
                        "time_bounds_dim": time_bounds_dim,
                    }
                )
# Bare expression: shows the collected runs as the notebook cell output.
run_config

In [None]:
# Main conversion loop: for every configured run, open its datasets through
# the catalog subset, pass them through `fix_time`, and hand the result to
# `save_data` together with a store obtained from `zarr_store`.
# NOTE(review): `fix_time`, `zarr_store` and `save_data` come from `helpers`;
# their exact behaviour is not visible here.
for run in tqdm(run_config, desc="runs"):
    print("*" * 120)
    query = run["query"]
    print(f"query = {query}")
    frequency = run["frequency"]
    chunks = run["chunks"]
    freq = run["freq"]
    time_bounds_dim = run["time_bounds_dim"]
    # Choose the preprocess hook for this query; `_preprocess` is the default.
    # The ice-specific hooks are only used for the 20C "cice.h1" stream.
    preprocess = _preprocess
    if query["experiment"] == "20C" and query["stream"] == "cice.h1":
        if query["component"] == "ice_sh":
            preprocess = _preprocess_ice_sh
        elif query["component"] == "ice_nh":
            preprocess = _preprocess_ice_nh
    elif query["component"] == "lnd":
        preprocess = _preprocess_lnd
    elif query["component"] == "atm":
        preprocess = _preprocess_atm

    print(preprocess.__name__)

    dsets = run["col"].to_dataset_dict(
        cdf_kwargs={"chunks": chunks, "decode_times": True, "use_cftime": True},
        preprocess=preprocess,
        progressbar=True,
    )
    dsets = enforce_chunking(dsets, chunks, field_separator)
    for key, ds in tqdm(dsets.items(), desc="Saving zarr store"):
        # Diagnostic: report whether the time index is sorted.
        print(ds.get_index("time").is_monotonic_increasing)
        # Split the dataset key into its fields: component is first,
        # experiment second, variable last. `frequency` is assigned to
        # itself here, i.e. it keeps the value taken from `run`.
        key = key.split(field_separator)
        exp, cmp, var, frequency = key[1], key[0], key[-1], frequency
        if frequency != "hourly6":
            if exp == "20C":
                start = "1850-01"
                end = "2006-01"
                # Only monthly data passes explicit start/end/freq to
                # `fix_time`; other frequencies call it with
                # `generate_bounds=False` and no date range.
                if frequency == "monthly":
                    ds = fix_time(
                        ds,
                        start=start,
                        end=end,
                        freq=freq,
                        time_bounds_dim=time_bounds_dim,
                    )
                else:
                    ds = fix_time(
                        ds,
                        start=None,
                        end=None,
                        freq=None,
                        time_bounds_dim=time_bounds_dim,
                        generate_bounds=False,
                    )
                # Split at 1920: everything from 1920 onwards is saved under
                # the "20C" experiment ...
                ds_20c = ds.sel(time=slice("1920", None)).chunk(chunks)
                store = zarr_store(
                    exp, cmp, frequency, var, write=False, dirout=dirout
                )
                save_data(ds_20c, store)
                # ... and everything up to 1919, for member_id 1 only, is
                # saved under a separate "HIST" experiment.
                ds_hist = ds.sel(time=slice(None, "1919"), member_id=1).chunk(
                    chunks
                )
                exp = "HIST"
                store = zarr_store(
                    exp, cmp, frequency, var, write=False, dirout=dirout
                )
                save_data(ds_hist, store)

            elif exp == "RCP85":
                start = "2006-01"
                end = "2101-01"
                # Same monthly / non-monthly distinction as for 20C.
                if frequency == "monthly":
                    ds = fix_time(
                        ds,
                        start=start,
                        end=end,
                        freq=freq,
                        time_bounds_dim=time_bounds_dim,
                    )
                else:
                    ds = fix_time(
                        ds,
                        start=None,
                        end=None,
                        freq=None,
                        time_bounds_dim=time_bounds_dim,
                        generate_bounds=False,
                    )
                store = zarr_store(
                    exp, cmp, frequency, var, write=False, dirout=dirout
                )
                save_data(ds, store)

            elif exp == "CTRL":
                # CTRL never passes a date range to `fix_time`, whatever the
                # frequency.
                ds = fix_time(
                    ds,
                    start=None,
                    end=None,
                    freq=None,
                    time_bounds_dim=time_bounds_dim,
                    generate_bounds=False,
                )
                store = zarr_store(
                    exp, cmp, frequency, var, write=False, dirout=dirout
                )
                save_data(ds, store)
        else:
            # "hourly6" data: `fix_time` is called with `instantaneous=True`
            # and the covered year range is appended to the frequency passed
            # to `zarr_store`.
            if exp == "20C":
                start = "1990"
                end = "2006-01-01T06:00"
                ds = fix_time(
                    ds,
                    start=start,
                    end=end,
                    freq=freq,
                    time_bounds_dim=time_bounds_dim,
                    instantaneous=True,
                )
                frequency_x = f"{frequency}-1990-2005"
                store = zarr_store(
                    exp, cmp, frequency_x, var, write=False, dirout=dirout
                )
                save_data(ds, store)

            elif exp == "RCP85":
                # RCP85 is split into two periods, each with its own store:
                # first everything up to and including 2036 ...
                frequency_x = f"{frequency}-2026-2035"
                ds_1 = ds.sel(time=slice(None, "2036"))
                start = "2026"
                end = "2036-01-01T06:00"
                ds_1 = fix_time(
                    ds_1,
                    start=start,
                    end=end,
                    freq=freq,
                    time_bounds_dim=time_bounds_dim,
                    instantaneous=True,
                ).chunk(chunks)
                store_1 = zarr_store(
                    exp, cmp, frequency_x, var, write=False, dirout=dirout
                )

                # ... then everything from 2071 onwards.
                frequency_x = f"{frequency}-2071-2080"
                start = "2071"
                end = "2081-01-01T06:00"
                ds_2 = ds.sel(time=slice("2071", None))
                ds_2 = fix_time(
                    ds_2,
                    start=start,
                    end=end,
                    freq=freq,
                    time_bounds_dim=time_bounds_dim,
                    instantaneous=True,
                ).chunk(chunks)
                store_2 = zarr_store(
                    exp, cmp, frequency_x, var, write=False, dirout=dirout
                )

                # Sanity check before writing: the two pieces must together
                # have as many time steps as the unsplit dataset.
                assert ds.time.size == (ds_1.time.size + ds_2.time.size)
                save_data(ds_1, store_1)
                save_data(ds_2, store_2)

In [None]:
# Show the time coordinates of the last HIST / 20C subsets left behind by the
# loop above. Both names exist only if at least one 20C dataset with a
# frequency other than "hourly6" was processed.
ds_hist.time, ds_20c.time

In [None]:
# Make sure the zarr stores were properly written
# NOTE(review): presumably checks a random sample of 10 stores found under
# `dirout` — confirm against `helpers.inspect_written_stores`.

inspect_written_stores(dirout, random_sample_size=10)

In [None]:
# %load_ext watermark
# %watermark -d -iv -m -g -h