In [None]:
import xarray as xr
import intake
from tqdm.auto import tqdm
import shutil 
from pathlib import Path

import os
from functools import reduce
import pprint
import json

from distributed.utils import format_bytes
from operator import mul

## Run These Cells for Dask Processing (NEW)

In [None]:
# Start a Dask cluster on PBS and attach a client to it, so that the xarray/zarr
# work in the cells below can be distributed across batch jobs.
import dask
from dask_jobqueue import PBSCluster

# Number of PBS jobs to request, and the walltime / memory settings handed to PBSCluster.
num_jobs = 15
walltime = '1:00:00'
memory='60GB' 

# This line makes the dashboard link work on JupyterHub.
dask.config.set({'distributed.dashboard.link': '/proxy/{port}/status'})

# NOTE(review): `memory` is '60GB' while `resource_spec` asks for mem=10GB; the two
# values look inconsistent -- confirm which one is intended and align them.
cluster = PBSCluster(cores=1, processes=1, walltime=walltime, memory=memory, queue='casper', 
                     resource_spec='select=1:ncpus=1:mem=10GB',)
cluster.scale(jobs=num_jobs)


# Connect a client to the cluster; the bare `cluster` expression displays the
# cluster object as the cell output.
from distributed import Client
client = Client(cluster)
cluster

In [None]:
# Shut down the Dask cluster created in the previous cell.
# NOTE(review): in top-to-bottom execution this runs before any of the processing
# cells below -- presumably it is meant to be run manually once the work is done; confirm.
cluster.close()

## Code for Creating Combined Zarr Metadata

In [None]:
from json import JSONDecodeError

def get_deserialized_json(field):
    '''Attempt to deserialize a JSON string; return the input unchanged on failure.

    Parameters
    ----------
    field : object
        Usually a string holding a JSON document. Any other value (number,
        None, list, ...) is accepted and passed through untouched.

    Returns
    -------
    object
        The decoded JSON value when `field` is valid JSON text, otherwise
        `field` itself.
    '''
    try:
        return json.loads(field)
    except (JSONDecodeError, UnicodeDecodeError, TypeError):
        # JSONDecodeError: text that is not valid JSON.
        # UnicodeDecodeError: bytes that cannot be decoded as text.
        # TypeError: json.loads only accepts str/bytes/bytearray, so non-string
        # attribute values (ints, floats, None, ...) end up here.
        return field


def reduce_metadata(metadata, member_ids):
    '''Limit metadata to the given member_id fields.

    Each value in `metadata` is passed through `get_deserialized_json`. Values
    that decode to a dictionary are cut down to the entries whose key is in
    `member_ids`; every other value is kept as is.

    Parameters
    ----------
    metadata : mapping
        Attribute name -> attribute value (possibly a JSON-encoded string).
    member_ids : iterable
        Member ids to keep inside dictionary-valued attributes.

    Returns
    -------
    dict
        New dictionary, in the key order of `metadata`, without the attributes
        whose dictionary became empty after the reduction.
    '''
    reduced_metadata = {}
    # Iterate over the mapping that was passed in, not over a notebook-level
    # dataset, so the function works for any caller.
    for key, raw_value in metadata.items():
        field = get_deserialized_json(raw_value)

        if isinstance(field, dict):
            reduced = {member: field[member] for member in member_ids if member in field}
            # Filter out any empty dictionaries.
            if not reduced:
                continue
        else:
            # Non-dictionary values are kept even when falsy ('' or 0); only
            # emptied dictionaries are dropped.
            reduced = field

        reduced_metadata[key] = reduced

    return reduced_metadata


def combine_metadata(ds_hist, ds_fut, scenario):
    '''Take two Xarray datasets, combine their metadata, and add Zarr-specific metadata.

    Parameters
    ----------
    ds_hist : dataset-like
        Historical dataset; only its `attrs` mapping is read.
    ds_fut : dataset-like
        Future-scenario dataset; its `attrs` mapping and
        `coords['member_id'].values` are read.
    scenario : str
        Label of the future scenario. Used as the dictionary key for
        future-only or differing values and inside the time note.

    Returns
    -------
    dict
        Combined attributes. Values that agree in both datasets are stored
        once. Differing values, and values found on one side only, are wrapped
        in a dictionary keyed by 'hist' and/or `scenario`. Every dictionary
        value is serialized to a JSON string.
    '''

    # Drop metadata member ids from ds_hist metadata that are not present in ds_fut
    member_ids = ds_fut.coords['member_id'].values
    hist_attrs = reduce_metadata(ds_hist.attrs, member_ids)
    fut_attrs = ds_fut.attrs

    # Historical keys first, then future-only keys, so the attribute order is
    # reproducible from run to run.
    keys = list(dict.fromkeys([*hist_attrs, *fut_attrs]))

    metadata = {}
    for key in keys:
        in_hist = key in hist_attrs
        in_fut = key in fut_attrs

        if in_hist and in_fut:
            hist_value = hist_attrs[key]
            fut_value = get_deserialized_json(fut_attrs[key])

            # If both stores have identical metadata, assign the metadata unchanged.
            if hist_value == fut_value:
                metadata[key] = hist_value
            else:
                # Otherwise, place both versions in a new dictionary.
                metadata[key] = {'hist': hist_value, scenario: fut_value}

        elif in_hist:
            # Look the value up for this key; a value left over from an earlier
            # loop iteration must not be reused.
            metadata[key] = {'hist': hist_attrs[key]}

        else:
            metadata[key] = {scenario: get_deserialized_json(fut_attrs[key])}

        # serialize any metadata dictionary to string.
        if isinstance(metadata[key], dict):
            metadata[key] = json.dumps(metadata[key])

    #metadata['zarr-dataset-reference'] = 'For dataset documentation, see DOI https://doi.org/10.5065/D6SJ1JCH'
    metadata['zarr-note-time'] = f'Historical data runs 1950 to 2005, future data ({scenario}) runs 2006 to 2100.'
    #metadata['zarr-version'] = '1.0'
    return metadata


## Print Dataset Diagnostic Information

In [None]:
def print_ds_info(ds, var):
    """Print chunking diagnostics for one variable of a dataset.

    Reports the variable's dimensions, chunk shape, full shape, the size in
    bytes of a single chunk, and the size in bytes of the whole dataset `ds`.
    """

    print(f'print_ds_info: var == {var}')

    # Bytes in one chunk = number of elements in the chunk * bytes per element.
    bytes_per_element = ds[var].dtype.itemsize
    chunk_shape = ds[var].data.chunksize
    dataset_size = format_bytes(ds.nbytes)
    elements_per_chunk = reduce(mul, chunk_shape)
    chunk_bytes = format_bytes(elements_per_chunk * bytes_per_element)

    print(f'Variable name: {var}')
    print(f'Dataset dimensions: {ds[var].dims}')
    print(f'Chunk shape: {chunk_shape}')
    print(f'Dataset shape: {ds[var].shape}')
    print(f'Chunk size: {chunk_bytes}')
    print(f'Dataset size: {dataset_size}')


## Zarr Save Utility Functions

In [None]:
def save_data(ds, chunks, store):
    """Write `ds` to the Zarr store `store` with consolidated metadata.

    The write is best-effort: any exception raised by `ds.to_zarr` is caught
    and reported with a printed message instead of being propagated. The
    `chunks` argument is accepted but not used by this function.
    """
    try:
        ds.to_zarr(store, consolidated=True)
    except Exception as err:
        print(f"Failed to write {store}: {err}")
    else:
        # Drop the local reference once the write has succeeded.
        del ds
        
def zarr_check(directory):
    '''Make sure the zarr stores were properly written.

    Recursively finds every `*.zarr` entry below `directory`, in sorted order,
    and tries to open each one with consolidated metadata. A store that opens
    is printed together with its dataset summary. A store that cannot be opened
    or displayed is reported with the error that was raised, and the check
    continues with the next store.

    Parameters
    ----------
    directory : str or path-like
        Root directory to search.
    '''

    stores = sorted(Path(directory).rglob("*.zarr"))
    for store in stores:
        try:
            ds = xr.open_zarr(store.as_posix(), consolidated=True)
            print('\n')
            print(store)
            print(ds)
        except Exception as e:
            # Report the reason: printing only the path gives no way to tell
            # what is wrong with the store.
            print(f"Failed to open {store}: {e}")

## Find and Process Zarr Stores

In [None]:
# Locations of the per-scenario input stores, the concatenated output stores, and
# the scratch area for historical stores reduced to the future run's members.
input_directory = '/glade/scratch/bonnland/na-cordex/zarr/'
output_directory = '/glade/scratch/bonnland/na-cordex/zarr-concat/'
temp_directory = '/glade/scratch/bonnland/na-cordex/zarr-temp/'

# Create concatenations for hist+rcp runs.
scenario = 'rcp45'
#scenario = 'rcp85'

p = Path(input_directory)
input_stores = list(p.glob(f'*.{scenario}.*.zarr'))

# Set to False for a dry run: datasets are opened and the metadata is combined,
# but no directory is created and no store is written.
WRITE_OUTPUT = True

for store in input_stores:
    future_store = store.as_posix()

    # Swap the scenario label in the store NAME only, so a directory component
    # that happens to contain the scenario string is left alone.
    historical_store = (store.parent / store.name.replace(scenario, 'hist')).as_posix()

    # Determine the output store name and location.
    output_store_name = store.name.replace(scenario, 'hist-' + scenario)
    output_store = output_directory + output_store_name

    print(f"\n\nCreating store {output_store_name}")
    if WRITE_OUTPUT:
        # Produce output store if it does not exist yet
        # NOTE(review): save_data() swallows write errors, so a failed run leaves
        # this directory behind and the store is skipped on the next run.
        if not os.path.exists(output_store):
            os.makedirs(output_store)
        else:
            # Store exists; skip to the next case.
            continue

    ds_hist = xr.open_zarr(historical_store, consolidated=True)
    ds_fut = xr.open_zarr(future_store, consolidated=True)

    hist_vars = list(ds_hist.data_vars.keys())
    fut_vars = list(ds_fut.data_vars.keys())

    # Verify the data variables are the same for both datasets, and there is only one variable.
    assert hist_vars == fut_vars, f'data variables differ: {hist_vars} vs {fut_vars}'
    assert len(hist_vars) == 1, f'expected exactly one data variable, found {hist_vars}'
    data_var = hist_vars[0]

    # Determine final chunk sizes: the future store's chunk shape is reused for the output.
    chunks = dict(zip(ds_fut[data_var].dims, ds_fut[data_var].data.chunksize))
    print(chunks)

    # Drop member_ids from ds_hist that are not in ds_fut.
    member_ids = ds_fut.coords['member_id'].values

    # NOTE(review): only the member COUNTS are compared; two stores with equally
    # many but different members are not reduced here -- confirm this cannot happen.
    if len(member_ids) != len(ds_hist.coords['member_id'].values):
        ds_hist = ds_hist.sel(member_id = member_ids)
        ds_hist = ds_hist.chunk(chunks)
        temp_store_name = historical_store.split('/')[-1]
        temp_store = temp_directory + temp_store_name
        if WRITE_OUTPUT and not os.path.exists(temp_store):
            os.makedirs(temp_store)
            print(f'\n\n  Writing temporary store: {temp_store}...')
            save_data(ds_hist, chunks, temp_store)
        # Re-open the reduced store only when it is actually on disk. In a dry
        # run it may never have been written; the lazily reduced dataset is then
        # used as is instead of failing on a missing store.
        if os.path.exists(temp_store):
            ds_hist = xr.open_zarr(temp_store, consolidated=True)

    metadata = combine_metadata(ds_hist, ds_fut, scenario)

    if WRITE_OUTPUT:
        # Combine stores
        ds_out = xr.concat([ds_hist, ds_fut], dim='time', coords='minimal').sortby('time')
        print(ds_out.coords)

        # Delete the existing encoding to avoid later errors.
        ds_out[data_var].encoding = {}

        # De-fragment chunks along the time dimension.
        print(chunks)
        ds_out = ds_out.chunk(chunks)

        print(ds_out[data_var].encoding)

        # Print diagnostic info.
        print_ds_info(ds_out, data_var)

        # Assign final metadata
        ds_out.attrs = metadata

        # Write the store.
        print(f'\n\n  Writing store: {output_store}...')
        save_data(ds_out, chunks, output_store)
    

## Publish "eval" scenario stores with added metadata

In [None]:
# Copy every "eval" scenario store to the publish directory, adding the
# Zarr-specific metadata attributes on the way.
input_directory = '/glade/scratch/bonnland/na-cordex/zarr/'
output_directory = '/glade/scratch/bonnland/na-cordex/zarr-publish/'

scenario = 'eval'

p = Path(input_directory)
input_stores = list(p.glob(f'*.{scenario}.*.zarr'))

# Set to False for a dry run: stores are opened and annotated, but no directory
# is created and nothing is written.
WRITE_OUTPUT = True

for store in input_stores:
    store = store.as_posix()

    # Determine the output store name and location.
    output_store_name = store.split('/')[-1]
    output_store = output_directory + output_store_name

    print(f"\n\nCreating store {output_store_name}")
    if WRITE_OUTPUT:
        # Produce output store if it does not exist yet
        if not os.path.exists(output_store):
            os.makedirs(output_store)
        else:
            # Store exists; skip to the next case.
            continue

    ds = xr.open_zarr(store, consolidated=True)
    ds.attrs['zarr-dataset-reference'] = 'For dataset documentation, see DOI https://doi.org/10.5065/D6SJ1JCH'
    ds.attrs['zarr-note-time'] = 'ERA-Interim data runs from 1980 to 2014.'
    ds.attrs['zarr-version'] = '2.0'

    # Honour the dry-run flag for the write as well, as the hist+scenario
    # concatenation cell does.
    if WRITE_OUTPUT:
        save_data(ds, None, output_store)

In [None]:
# Try to open every *.zarr store found below the publish directory and print
# what was found (see zarr_check).
# Note: this rebinds the notebook-level `output_directory`, which later cells read.
output_directory = '/glade/scratch/bonnland/na-cordex/zarr-publish'
zarr_check(output_directory)

### Explore Zarr Stores for Learning Purposes

In [None]:
# Open one store so the cells below can inspect it interactively.
#p = Path(input_directory)
#input_stores = list(p.glob(f'*.{scenario}.*.zarr'))
# NOTE(review): when the notebook is run top to bottom, `output_directory` is the
# zarr-publish path at this point, whereas the hist-<scenario> stores are written to
# the zarr-concat directory -- this glob may match nothing there; confirm the directory.
p = Path(output_directory)
input_stores = list(p.glob(f'prec.hist-rcp85.day.NAM-22i.*.zarr'))
    
# Takes the first match; raises IndexError when the glob above found no store.
ds = xr.open_zarr(input_stores[0].as_posix(), consolidated=True)


In [None]:
# Display the summary of the dataset opened above.
ds

In [None]:
# List the data variables contained in the dataset.
ds.data_vars

In [None]:
# List the coordinate names attached to the 'uas' variable.
# NOTE(review): `ds` is opened from a store matched by a 'prec.*' glob; presumably it
# holds no 'uas' variable, in which case this lookup raises KeyError -- confirm.
ds['uas'].coords.keys()

In [None]:
# Remove the 'height' entry from the coordinates of ds['uas'].
# NOTE(review): ds['uas'] is looked up afresh on every access; presumably the deletion
# acts on a temporary object and leaves `ds` itself unchanged -- verify with the next cell.
del ds['uas'].coords['height']

In [None]:
# Display the coordinates of the 'uas' variable (e.g. to see whether 'height' is still listed).
ds['uas'].coords

In [None]:
# Display the chunk shape of the array backing the 'uas' variable.
ds['uas'].data.chunksize

In [None]:
# Display the variables that make up the dataset's coordinates.
ds.coords.variables

In [None]:
# Pretty-print the dataset's global attributes (the metadata written by the cells above).
pprint.pprint(ds.attrs)