# Zarrify DART Reanalysis History Files

In [None]:
import numpy as np
import xarray as xr
import intake
import ast

import dask.distributed
from dask.distributed import Client
from ncar_jobqueue import NCARCluster

import fsspec
from pathlib import Path
import shutil 
import os
from functools import reduce
from operator import mul

### Configuration Parameters

In [None]:
# Path handed to intake.open_esm_datastore() to locate the input catalog.
catalog_path = './dart-zarr-input.json'

# Directory under which each "<variable>.zarr" output store is created.
output_folder = '/glade/scratch/bonnland/DART/ds345.0/zarr-publish'

# Variables converted by this run.  Alternative (disabled) selection:
#     output_variables = ['T', 'PS', 'Q', 'US', 'VS', 'CLDICE', 'CLDLIQ']
output_variables = ['CLDICE', 'CLDLIQ', 'VS']

# Number of elements per chunk in the target stores, keyed by dimension name.
# A negative value means "don't chunk this dimension".
target_chunks = dict(
    lat=32,
    slat=32,
    lon=32,
    slon=32,
    lev=-1,
    time=30,
    member_id=10,
)

# Alternative (disabled) chunking without the 'slat'/'slon' entries:
#     target_chunks = dict(lat=32, lon=32, lev=-1, time=30, member_id=10)

## Run These Cells for Dask Processing

In [None]:
import dask
# format_bytes is used by print_ds_info().  It is imported from dask.utils
# because the distributed.utils alias is deprecated in favour of this one.
from dask.utils import format_bytes
from distributed import Client
from ncar_jobqueue import NCARCluster

# For Cheyenne

# These are "per node", and then .scale() selects the number of nodes.
#walltime = "1:00:00"
#walltime = "00:30:00"
walltime = "00:45:00"

#  This results in about 20% maximum memory usage.
#cluster = NCARCluster(cores=1, processes=1, memory='109GB', walltime=walltime)
#num_nodes = 16


# Run 16 workers on 4 nodes, giving each worker around 25GB RAM.  
#cluster = NCARCluster(cores=4, processes=4, memory='109GB', walltime=walltime)

# # Run 4 workers on each node, giving each worker around 25GB RAM.  
# cluster = NCARCluster(cores=16, processes=4, memory='109GB', walltime=walltime)
# num_nodes = 2

# Run <= 4 workers on each node to avoid crashes.
cluster = NCARCluster(cores=10, processes=4, memory='109GB', walltime=walltime)
num_nodes = 8

# Request one batch job per node.
cluster.scale(jobs=num_nodes)

# Connect this notebook session to the cluster's scheduler.
client = Client(cluster)

# Bare expression so the notebook displays the cluster.
cluster

### Zarr-related Helper Functions

In [None]:
### Preprocessing Steps for each input dataset before merge
def preprocess(ds):
    """Return ``ds`` with every data variable except ``TARGET_VAR`` dropped.

    Only data variables are removed here; coordinates are left untouched.

    The function takes the dataset as its single argument, so the name of
    the variable to keep is read from the global ``TARGET_VAR``, which must
    be assigned before this function is called.
    """
    unwanted = []
    for name in ds.data_vars:
        if name == TARGET_VAR:
            continue
        unwanted.append(name)

    return ds.drop_vars(unwanted)

In [None]:
def print_ds_info(ds, var):
    """Print chunking and size information for one variable of a dataset.

    Args:
        ds: Dataset-like object supporting ``ds[var]`` and ``ds.nbytes``.
        var: Name of the variable to report on.

    A variable whose backing array has no ``chunksize`` attribute (i.e. it
    is not chunked) is reported as a single chunk spanning its whole shape.
    A zero-dimensional variable is reported as a one-element chunk.

    Relies on ``format_bytes``, ``reduce`` and ``mul`` being available at
    module level.
    """
    print(f'print_ds_info: var == {var}')
    variable = ds[var]

    # Fall back to the full shape when the backing array is not chunked.
    chunk_shape = getattr(variable.data, 'chunksize', variable.shape)

    # The initial value 1 keeps reduce() from raising on an empty shape.
    chunk_bytes = reduce(mul, chunk_shape, 1) * variable.dtype.itemsize

    print(f'Variable name: {var}')
    print(f'Dataset dimensions: {variable.dims}')
    print(f'Chunk shape: {chunk_shape}')
    print(f'Dataset shape: {variable.shape}')
    print(f'Chunk size: {format_bytes(chunk_bytes)}')
    # Size of the whole dataset object, not only of ``var``.
    print(f'Dataset size: {format_bytes(ds.nbytes)}')

    
def zarr_store(var, dirout, write=False):
    """Build the path ``<dirout>/<var>.zarr``, print it and return it.

    When ``write`` is true and something already exists at that path, the
    existing directory tree is deleted first.
    """
    store_path = "{}/{}.zarr".format(dirout, var)

    # Only touch the filesystem when the caller intends to (re)write.
    stale = os.path.exists(store_path) if write else False
    if stale:
        shutil.rmtree(store_path)

    print(store_path)
    return store_path


def save_data(ds, store):
    """Write ``ds`` to the Zarr store at ``store`` with consolidated metadata.

    Any exception raised by the write is caught and reported on stdout
    instead of propagating.

    Args:
        ds: Object providing ``to_zarr(store=..., consolidated=...)``.
        store: Target store (path) passed through to ``to_zarr``.

    Returns:
        bool: True if the write completed, False if it raised.
    """
    try:
        ds.to_zarr(store=store, consolidated=True)
    except Exception as e:
        print(f"Failed to write {store}: {e}")
        return False
    return True

        
def zarr_check(directory=None):
    '''Make sure the zarr stores were properly written.

    Recursively finds every ``*.zarr`` path under ``directory`` and tries to
    open it with consolidated metadata.

    Args:
        directory: Folder to search.  Defaults to the global
            ``output_folder`` when omitted.

    For each store that opens, its path and dataset summary are printed.
    For each store that fails to open, its path and the error are printed,
    and checking continues with the next store.
    '''
    root = Path(output_folder if directory is None else directory)

    # Sorted so the report order is stable between runs.
    for store in sorted(root.rglob("*.zarr")):
        try:
            ds = xr.open_zarr(store.as_posix(), consolidated=True)
        except Exception as e:
            # Report the failure explicitly rather than printing only the path.
            print(f'FAILED to open {store}: {e}')
            continue
        print('\n')
        print(store)
        print(ds)

In [None]:
# Open the intake-esm catalog located at `catalog_path`.
#
# NOTE(review): no converter is passed to open_esm_datastore here, so the
# "variable" column is read as-is (one value per row).  List-valued entries
# would need an explicit converter -- confirm against the catalog file.
col = intake.open_esm_datastore(catalog_path)
# Bare expression so the notebook displays the catalog.
col


In [None]:
# Show the eventual output store base names, i.e. the keys the catalog
# reports for its datasets.
print("Eventual store base names:")
print(col.keys())

In [None]:
# Set to False for a dry run: datasets are assembled and their chunking is
# printed, but nothing is deleted from or written to the output folder.
REALLY_SAVE = True

for variable in output_variables:
    # This variable gets used in the "preprocess" function and must be defined now in the global scope.
    TARGET_VAR = variable

    col_subset = col.search(variable=variable)
    # Produce var-based stores.  The catalog will determine how many stores and their base names.
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        dsets = col_subset.to_dataset_dict(zarr_kwargs={'consolidated': True}, preprocess=preprocess)

    ds_out = dsets[variable]

    # Specify final chunking.  Dataset.chunk() raises ValueError for keys
    # that are not dimensions of the dataset, so request chunk sizes only
    # for the dimensions this dataset actually has (e.g. 'slat'/'slon' may
    # be absent for some variables).
    chunks = {dim: size for dim, size in target_chunks.items() if dim in ds_out.dims}
    ds_out = ds_out.chunk(chunks)

    # Confirm output contents.
    print_ds_info(ds_out, variable)

    # With write=True an existing store at this path is removed first.
    store = zarr_store(variable, dirout=output_folder, write=REALLY_SAVE)
    if REALLY_SAVE:
        save_data(ds_out, store=store)
        print("     ... Done.")
    else:
        print("     ... (Skipping)")
        del ds_out


In [None]:
# Open each output dataset to confirm it was created properly.
# Every "*.zarr" store found under `output_folder` is opened and printed.

zarr_check()

In [None]:
# Disconnect the client before shutting the cluster down, so the client is
# not left connected to a scheduler that no longer exists.
client.close()
cluster.close()