Ref: https://github.com/NCAR/cesm-lens-aws/issues/34

In [None]:
import xarray as xr
import intake
from tqdm.auto import tqdm
import shutil 
import os
from functools import reduce
import pprint
import json
from operator import mul
import random
import yaml
from distributed.utils import format_bytes

import numpy as np
import pandas as pd

import cftime
from datetime import date

## Calendar Conversion functions

In [None]:
# Functions for converting single date objects from one type to another.

def convert_to_noleap(cftime360_obj, datemap):
    '''Map a date from the 360-day calendar onto the NoLeap calendar.

    ``datemap`` is indexed with the zero-based day of year of
    ``cftime360_obj``; the month and day of the entry found there are
    combined with the original year.  Only year, month and day are passed on,
    so the result has no time-of-day information from the input.
    '''
    zero_based_day = cftime360_obj.dayofyr - 1
    target = datemap[zero_based_day]
    return cftime.DatetimeNoLeap(
        year=cftime360_obj.year,
        month=target.month,
        day=target.day,
    )

def convert_to_gregorian(cftime_noleap_obj):
    '''Re-create a date as a ``cftime.DatetimeGregorian`` object.

    Only the year, month and day of the input are copied.
    '''
    source = cftime_noleap_obj
    return cftime.DatetimeGregorian(
        year=source.year,
        month=source.month,
        day=source.day,
    )

def convert_hour(time_obj, hour_of_day):
    '''Return the date of ``time_obj`` as a Gregorian datetime at a fixed hour.

    Year, month and day come from ``time_obj``; the hour is ``hour_of_day``
    and minutes and seconds are set to zero.
    '''
    fields = dict(
        year=time_obj.year,
        month=time_obj.month,
        day=time_obj.day,
        hour=hour_of_day,
        minute=0,
        second=0,
    )
    return cftime.DatetimeGregorian(**fields)

In [None]:
def get_datemap_360_to_noleap():
    '''Return the 360 calendar dates that the days of a 360-day year map to.

    Entry ``i`` of the returned ``pandas.DatetimeIndex`` is the date (in a
    365-day dummy year) assigned to day ``i + 1`` of the 360-day calendar.
    Five calendar days, listed in ``missing_dates`` below, are left out so
    that the remaining 360 dates cover the whole year.
    '''

    # Choose any year with 365 days.
    dummy_year = 1999

    # These are the days of the year that will be missing on the time axis for each year.
    # The goal is to spread missing dates out evenly over each year.
    #
    # Modify specific dates as desired.
    missing_dates = [
        date(dummy_year, 1, 31),
        date(dummy_year, 3, 31),
        date(dummy_year, 5, 31),
        date(dummy_year, 8, 31),
        date(dummy_year, 10, 31),
    ]

    # Zero-based positions of the skipped days within the year.  They index
    # directly into the 365-entry date range below, so no "+ 1" day-of-year
    # offset is applied (with it, the day *after* each listed date was
    # removed instead of the listed date itself).
    day_one = date(dummy_year, 1, 1)
    missing_dates_indexes = [(day - day_one).days for day in missing_dates]

    dates = pd.date_range(f'1/1/{dummy_year}', f'12/31/{dummy_year}')
    assert len(dates) == 365

    datemap_indexes = np.setdiff1d(np.arange(len(dates)), missing_dates_indexes)
    date_map = dates[datemap_indexes]
    assert len(date_map) == 360

    return date_map


# Create a global map for moving days of the year to other days of the year.
datemap_global = get_datemap_360_to_noleap()

In [None]:
# This code "pads out" data variables with missing values for Leap Days.  
# It's possible that xarray will do this automatically as long as one calendar being merged has Leap Days in it.

def convert_dataset_noleap_to_gregorian(ds):
    '''Move an xarray dataset from the NoLeap to the Gregorian calendar.

    One all-NaN time step is inserted for every February 29 that falls
    between the first and last time value.  Note that ``ds['time']`` is
    overwritten on the dataset that is passed in.
    '''
    # Re-type every time value as a Gregorian date.
    ds['time'] = [convert_to_gregorian(stamp) for stamp in ds.time.values]

    # Daily Gregorian axis spanning the same period.
    first_day = ds.time.values[0]
    last_day = ds.time.values[-1]
    gregorian_axis = xr.DataArray(
        xr.cftime_range(start=first_day, end=last_day, freq='D', calendar='gregorian', normalize=True),
        dims='time',
    )

    # Keep only the February 29 entries of that axis.
    on_feb_29 = (gregorian_axis.time.dt.month == 2) & (gregorian_axis.time.dt.day == 29)
    leap_days = gregorian_axis.where(on_feb_29, drop=True)

    # One NaN-filled copy of the first time step per leap day.
    template = ds.isel(time=slice(0, 1))
    pieces = [
        xr.full_like(template, fill_value=np.nan).assign_coords(time=[leap_day.data])
        for leap_day in leap_days
    ]

    # The real data goes last; sorting by time puts the fill steps in place.
    pieces.append(ds)
    return xr.concat(pieces, dim='time').sortby('time')

## Run These Cells for Dask Processing

In [None]:
import dask
from dask_jobqueue import SLURMCluster
from distributed import Client
dask.config.set({'distributed.dashboard.link': '/proxy/{port}/status'})
dask.config.get('distributed.dashboard')

In [None]:
# Start a SLURM-backed Dask cluster and connect a client to it.
#min_jobs = 20
#cluster = SLURMCluster(cores=4, memory="50GB", project="STDD0003")
# `min_jobs` is reused by the BIG_SAVE cell to decide how many workers to wait for.
min_jobs = 4
cluster = SLURMCluster(cores=4, memory="50GB")
# Let the cluster scale between `min_jobs` and 35 jobs on demand.
cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=35)
#cluster.scale(jobs=3)
client = Client(cluster)
cluster

In [None]:
# Set to True if saving large Zarr files is resulting in KilledWorker or Dask crashes.
# When enabled, block until `min_jobs` workers have joined the cluster.
BIG_SAVE = True
if BIG_SAVE:
    min_workers = min_jobs
    print(f'Waiting for {min_jobs} workers.')
    client.wait_for_workers(min_workers)

## Main Notebook Code

#### Prepare individual dataset for merge

In [None]:
def preprocess(ds):
    """Put one source dataset on a Gregorian time axis before concatenation.

    This function gets called on each original dataset before concatenation.

    * 360-day calendars are mapped to NoLeap dates via ``datemap_global`` and
      then re-typed as Gregorian; NoLeap calendars are re-typed as Gregorian.
    * Every time value is moved to 12:00 and every time-bounds value to 00:00.
    * The time bounds variable (named by ``ds.time.attrs['bounds']``) is
      rewritten with dimensions ``('time', 'bnds')`` and returned as a
      coordinate.
    * The original ``attrs`` and ``encoding`` of ``time`` are restored.

    No copy is made: the time axis and bounds are overwritten on the dataset
    passed in as well.

    Raises ``KeyError`` if the time coordinate has no ``bounds`` attribute.
    """
    # Assigning a new time axis drops these, so remember them for later.
    time_attrs = ds.time.attrs
    time_encoding = ds.time.encoding
    bounds_name = ds.time.attrs['bounds']

    ds_fixed = ds

    # Test for calendar type xarray found when it loaded the dataset.
    time_type = f'{type(ds.time.values[0])}'
    has_360_day_calendar = "Datetime360Day" in time_type
    has_noleap_calendar = "DatetimeNoLeap" in time_type

    # Extract the time bounds variable for conversion.
    bnds = ds_fixed[bounds_name].values

    if has_360_day_calendar:
        print('Found 360 day calendar; converting dates to NoLeap, then date types to Gregorian.\n')
        ds_fixed['time'] = [convert_to_noleap(t, datemap_global) for t in ds_fixed.time.values]
        ds_fixed['time'] = [convert_to_gregorian(t) for t in ds_fixed.time.values]

        bnds = [[convert_to_noleap(col, datemap_global) for col in row] for row in bnds]
        bnds = [[convert_to_gregorian(col) for col in row] for row in bnds]

    # Convert any NoLeap calendar to the Gregorian calendar.
    elif has_noleap_calendar:
        ds_fixed['time'] = [convert_to_gregorian(t) for t in ds_fixed.time.values]
        bnds = [[convert_to_gregorian(col) for col in row] for row in bnds]

    # Change time of day to noon for all time axis points, and to midnight
    # for all bounds.
    ds_fixed['time'] = [convert_hour(t, 12) for t in ds_fixed.time.values]
    bnds = [[convert_hour(col, 0) for col in row] for row in bnds]
    ds_fixed[bounds_name] = (('time', 'bnds'), bnds)

    # Make sure the time index is a CFTimeIndex.  The previous check,
    # `type(index == 'Index')`, was always true and the converted index was
    # thrown away, so it had no effect.  The time axis deliberately stays on
    # cftime objects because the later processing steps in this notebook
    # work with cftime values.
    time_index = ds_fixed.time.indexes['time']
    if not isinstance(time_index, xr.CFTimeIndex):
        print('found Index object; converting to CFTimeIndex object.\n')
        ds_fixed = ds_fixed.assign_coords(time=xr.CFTimeIndex(time_index))

    ds_fixed.time.attrs = time_attrs
    ds_fixed.time.encoding = time_encoding

    # set_coords returns a new dataset; return that one so the bounds
    # variable really is a coordinate of the result.
    return ds_fixed.set_coords([bounds_name])

#### Merged dataset processing functions

In [None]:
def fix_time(
    ds,
    start,
    end,
    freq,
    time_bounds_dim,
    calendar='standard',
    generate_bounds=True,
    instantaneous=False,
):
    '''Rebuild the time axis so that it agrees with the time bounds variable.

    The dataset is sorted by time and copied first.  When ``generate_bounds``
    is true, the bounds variable (named by ``ds.time.attrs['bounds']``) is
    overwritten with consecutive pairs taken from
    ``xr.cftime_range(start, end, freq, calendar)``.  The new time values are
    the minimum (``instantaneous=True``) or the mean of the bounds along
    ``time_bounds_dim``.  The ``attrs`` and ``encoding`` of ``time`` are
    carried over and the bounds variable is returned as a coordinate.
    '''
    ds = ds.sortby('time').copy()
    saved_attrs = ds.time.attrs
    saved_encoding = ds.time.encoding
    bounds_name = ds.time.attrs['bounds']
    ds[bounds_name].load()

    if generate_bounds:
        edges = xr.cftime_range(start=start, end=end, freq=freq, calendar=calendar)
        # Row i holds (edges[i], edges[i + 1]).
        lower, upper = edges[:-1], edges[1:]
        ds[bounds_name].data = np.vstack([lower, upper]).T

    reducer = 'min' if instantaneous else 'mean'
    new_time = getattr(ds[bounds_name], reducer)(time_bounds_dim)
    ds = ds.assign_coords(time=new_time)

    ds.time.attrs = saved_attrs
    ds.time.encoding = saved_encoding
    return ds.set_coords([bounds_name])

In [None]:
def enforce_chunking(datasets, chunks, data_var):
    """Enforce uniform chunking in the Zarr Store.

    Parameters
    ----------
    datasets : dict
        Maps a key to a dataset.  The dict itself is not modified.
    chunks : dict
        Maps dimension name to chunk size.  Dimensions a dataset does not
        have are ignored for that dataset.
    data_var : str
        Name of the variable whose chunking is reported by ``print_ds_info``.

    Returns
    -------
    dict
        A new dict with the same keys and the re-chunked datasets.  The
        attributes ``intake_esm_dataset_key`` and ``intake_esm_varname`` are
        removed from each dataset when present; missing ones are ignored, so
        the function can be applied to its own output.  An empty input gives
        an empty dict.
    """
    dsets = datasets.copy()
    # One randomly chosen dataset is printed in full; with no datasets there
    # is nothing to choose from.
    choice = random.randrange(len(dsets)) if dsets else None
    keys_to_delete = ('intake_esm_dataset_key', 'intake_esm_varname')

    for i, (key, ds) in enumerate(dsets.items()):
        print(f'key == {key}')
        # Only chunk along dimensions this dataset actually has.
        c = {dim: size for dim, size in chunks.items() if dim in ds.dims}
        ds = ds.chunk(c)
        for k in keys_to_delete:
            ds.attrs.pop(k, None)
        dsets[key] = ds
        print_ds_info(ds, data_var)
        if i == choice:
            print(ds)
        print('\n')
    return dsets

In [None]:
def print_ds_info(ds, var):
    """Print chunk shape, chunk size and overall size for one variable.

    ``ds[var].data`` must expose a ``chunksize`` attribute.  The reported
    dataset size is that of the whole dataset (``ds.nbytes``), not only of
    ``var``.
    """
    print(f'print_ds_info: var == {var}')

    variable = ds[var]
    bytes_per_item = variable.dtype.itemsize
    chunk_shape = variable.data.chunksize
    total_size = format_bytes(ds.nbytes)
    # Elements per chunk times bytes per element.
    chunk_bytes = format_bytes(reduce(mul, chunk_shape) * bytes_per_item)

    report = (
        f'Variable name: {var}',
        f'Dataset dimensions: {variable.dims}',
        f'Chunk shape: {chunk_shape}',
        f'Dataset shape: {variable.shape}',
        f'Chunk size: {chunk_bytes}',
        f'Dataset size: {total_size}',
    )
    for line in report:
        print(line)

# Default output directory for the Zarr stores (a module-level global for now).
dirout = './zarr-stores'


def zarr_store(var, exp, frequency, grid, biascorrection, write=False, dirout=dirout):
    """Build the path of a Zarr store, print it and return it.

    The path is ``{dirout}/{var}.{exp}.{frequency}.{grid}.{biascorrection}.zarr``.
    With ``write=True`` an existing store at that path is deleted first.
    """
    store_path = '{}/{}.{}.{}.{}.{}.zarr'.format(
        dirout, var, exp, frequency, grid, biascorrection
    )
    if write:
        # Start from a clean slate when the store is about to be rewritten.
        if os.path.exists(store_path):
            shutil.rmtree(store_path)
    print(store_path)
    return store_path


def save_data(ds, store):
    """Write ``ds`` to the Zarr store at ``store`` with consolidated metadata.

    A failure is reported with a printed message instead of being raised.
    """
    try:
        ds.to_zarr(store, consolidated=True)
    except Exception as err:
        print(f"Failed to write {store}: {err}")
        return
    # Drop this function's reference to the dataset once it has been written.
    del ds

#### Metadata preparation functions

In [None]:
def get_dataset_metadata(ds, member_id):
    '''Return the global attributes of ``ds`` keyed by attribute, then member.

    The result has the form ``{attribute_name: {member_id: attribute_value}}``.
    '''
    return {name: {member_id: val} for name, val in ds.attrs.items()}

def get_metadata_from_catalog_entries(catalog_entries):
    '''Take a catalog subset and combine all global dataset metadata into one dictionary.

    ``catalog_entries.df`` must provide ``path`` and ``member_id`` columns.
    Each file is opened, its global attributes are read, and the file is
    closed again.  The result has the form
    ``{attribute_name: {member_id: attribute_value}}``; when a member id
    occurs more than once, the attributes of the later row win.
    '''
    metadata = {}

    # Loop over catalog rows
    dataframe = catalog_entries.df
    for path, member_id in zip(dataframe['path'], dataframe['member_id']):

        # Only the global attributes are needed, so release the file handle
        # as soon as they have been read instead of leaving every file open.
        with xr.open_dataset(path) as ds:
            ds_metadata = get_dataset_metadata(ds, member_id)

        # Merge this dataset's entries into the combined dictionary.
        for key, value in ds_metadata.items():
            metadata.setdefault(key, {}).update(value)
    return metadata


def save_metadata_to_csv(metadata_dict, variable_name):
    '''Write a metadata dictionary to ``<variable_name>.csv``.

    The dictionary is turned into a DataFrame with ``pd.DataFrame.from_dict``
    and written, relative to the current working directory, with ``to_csv``.
    '''
    target = f'{variable_name}.csv'
    pd.DataFrame.from_dict(metadata_dict).to_csv(target)

In [None]:
# Quick check of the per-dataset metadata helper on the dataset bound to `ds`.
# (The helper defined in this notebook is get_dataset_metadata; there is no
# get_metadata.)
m = get_dataset_metadata(ds, 'test')
m

In [None]:
# Collect the global attributes of every file listed in the catalog `col` and
# write them to "tasmax.csv".  `col` is assigned in a later cell of this
# notebook (intake.open_esm_datastore), so that cell has to be run first.
global_metadata_dict = get_metadata_from_catalog_entries(col)
save_metadata_to_csv(global_metadata_dict, "tasmax")

## Create Zarr Stores Using the Catalog and Main Notebook Code

In [None]:
# It's safer to use an underscore separator, because NA-CORDEX grid names contain dashes.
field_separator = '_'
col = intake.open_esm_datastore("./toy-na-cordex.json", sep=field_separator)
col

In [None]:
# Example of isolating one entry from the catalog.
#ds = col['/Users/bonnland/GitRepos/cesm-lens-zarrification/notebooks/na-cordex/data-subsets/subset_tasmax.rcp85.CanESM2.CRCM5-UQAM.day.NAM-22i.raw.nc_tasmax_rcp85_CanESM2_CRCM5-UQAM_day_NAM-22i_raw_common_CanESM2.CRCM5-UQAM'].to_dask()
#dict(ds.dims)

In [None]:
# Hard-code the variable name in a global variable for now.
variables = ['tasmax']

In [None]:
# Consolidate datasets according to the catalog JSON metadata.
chunks = {'time': 200, 'lat': -1, 'lon': -1}
dsets = col.to_dataset_dict(cdf_kwargs={'chunks': chunks, 'use_cftime': True}, preprocess=preprocess, progressbar=False)
dset = dsets['rcp85_day_NAM-22i_raw']
dset

In [None]:
# A chunk size of 1 along 'member_id' stores each ensemble member in its own chunk.
# Comment out the following line to leave 'member_id' out of the chunking specification.
chunks['member_id'] = 1

# Take care of ragged edges in original datasets, to optimize chunking strategy.
# The first entry of `variables` is the variable whose chunking gets reported.
dsets = enforce_chunking(dsets, chunks, variables[0])
dsets

In [None]:
# Create/Overwrite the Zarr Stores.


In [None]:
# Write one Zarr store per catalog key.  Each iteration processes the dataset
# belonging to that key (`ds`); previously the single dataset `dset` was
# written to every store.
for key, ds in tqdm(dsets.items(), desc='Saving zarr store'):
    print('key: ' + key)
    exp, frequency, grid, biascorrection = key.split(field_separator)[:4]

    # Regenerate the time bounds variable to be consistent across all ensemble members.
    #
    # start:  Move the starting bound backward from noon to midnight of the first day.
    # end:    Create an extra day for the ending time bound of the last day, and set hour to midnight.
    start = convert_hour(ds.time.values[0], 0)
    end = convert_hour(pd.to_datetime(ds.time.values[-1].strftime()) + pd.DateOffset(1), 0)

    # fix_time() reduces the bounds variable along this dimension to obtain
    # one time value per step, so it must be the bounds dimension that
    # preprocess() creates ('bnds'), not 'time'.
    time_bounds_dim = 'bnds'
    ds_fixed = fix_time(ds, start=start, end=end, freq='D', time_bounds_dim=time_bounds_dim).chunk(chunks)

    var = variables[0]
    # zarr_store() prints the store path itself.
    store = zarr_store(var, exp, frequency, grid, biascorrection, write=True, dirout=dirout)
    save_data(ds_fixed, store)

In [None]:
# Make sure the zarr stores were properly written: try to open every *.zarr
# directory below `dirout` and print its contents.

from pathlib import Path
p = Path(dirout)
stores = list(p.rglob("*.zarr"))
for store in stores:
    try:
        ds = xr.open_zarr(store.as_posix(), consolidated=True)
    except Exception as e:
        # Keep checking the remaining stores, but say which one failed and why
        # instead of printing only its path.
        print('\n')
        print(f'Failed to open {store}: {e}')
    else:
        print('\n')
        print(store)
        print(ds)

### If Using Dask on HPC, release the workers.

In [None]:
!date

In [None]:
cluster.close()

In [None]:
# Use this to print out details about the conda environment.
# %load_ext watermark
# %watermark -d -iv -m -g -h

## Alternative to Using the Catalog for Preprocessing:  Load Datasets Directly

In [None]:
subset_folder = './data-subsets'
fileList = os.listdir(subset_folder)
fileList

In [None]:
# Open the files listed in `fileList` and print each path and dataset.
# NOTE(review): the unconditional `break` below ends the loop after the first
# file, before `datasets.append(ds)` is reached, so `datasets` stays empty.
# This looks like a debugging leftover -- confirm before relying on `datasets`.
datasets = []
for f in fileList:
    # Create xarray dataset from file.
    filePath = f'{subset_folder}/{f}'
    ds = xr.open_dataset(filePath, use_cftime=True)
    print(filePath)
    print(ds)
    break
    #preprocess(ds)

    datasets.append(ds)

In [None]:
datasets

## Test preprocessing for 360-day calendars

In [None]:
# Test conditions for 360 calendars
filePath = './data-subsets/subset_tasmax.rcp85.HadGEM2-ES.RegCM4.day.NAM-22i.raw.nc'
ds = xr.open_dataset(filePath, use_cftime=True)
ds

In [None]:
ds

In [None]:
ds_processed = preprocess(ds)
ds_processed

## Batch Processing Code Using the Configuration File "config.yaml"
### This code is not yet tested and is not known to work.

In [None]:
def process_variables(col, variable, scenario, frequency, grid, biascorrection, verbose=True):
    """Search the catalog for one combination of facets.

    Parameters
    ----------
    col
        Catalog object providing ``search(**query)``.
    variable, scenario, frequency, grid, biascorrection
        Values passed to ``col.search`` under the keyword of the same name.
    verbose : bool
        When true, print the unique values of the queried columns found in
        the search result.

    Returns
    -------
    tuple
        ``(subset, query)``: the search result and the query dictionary.
    """
    query = dict(variable=variable, scenario=scenario, frequency=frequency, grid=grid, biascorrection=biascorrection)
    subset = col.search(**query)
    if verbose:
        # Report exactly the columns that were queried.  The hand-written
        # list used before contained the misspelled name 'frequency,'.
        print(subset.unique(columns=list(query)))
    return subset, query

In [None]:
with open("config.yaml") as f:
    config = yaml.safe_load(f)
        
variables = config['variables']
frequencies = config['frequencies']
scenarios = config['scenarios']
biascorrections = config['biascorrections']
grid_categories = config['grid_categories']

In [None]:
run_config = []

for key, value in grid_categories.items():
    grid = value['grid']
    chunks = value['chunks']
    for scenario in scenarios:
        for frequency in frequencies:
            for biascorrection in biascorrections:
                for variable in variables:
                    col_subset, query = process_variables(col, variable, scenario, frequency, grid, biascorrection)
                    d = {'query': json.dumps(query), 'col': col_subset, 'chunks': chunks, 'frequency': frequency}
                    run_config.append(d)
                    
run_config

In [None]:
variables = []


variable_categories = list(config['variable_category'].keys())
grid = config['grid']
biascorrection = config['biascorrection']
frequency = config['frequency']

In [None]:
# Add one run_config entry per (variable category, scenario) pair, using the
# per-scenario chunking from the configuration file.
for v_cat in variable_categories:
    scenarios = list(config['variable_category'][v_cat]['scenario'].keys())
    for scenario in scenarios:
        print(scenario)
        chunks = config['variable_category'][v_cat]['scenario'][scenario]['chunks']
        variable = config['variable_category'][v_cat]['variable']
        variables.extend(variable)
        # process_variables() expects the frequency between scenario and grid;
        # without it the call fails with a TypeError for the missing argument.
        col_subset, query = process_variables(col, variable, scenario, frequency, grid, biascorrection)
        d = {'query': json.dumps(query), 'col': col_subset, 'chunks': chunks, 'frequency': frequency}
        run_config.append(d)
                
#print(variables)
#print(run_config)

In [None]:
run_config

In [None]:
# Load the datasets of every run_config entry and write each one to a Zarr store.
for run in run_config:
    print("*"*120)
    print(f"query = {run['query']}")
    frequency = run['frequency']
    chunks = run['chunks']
    # Try preprocessing, including calendar conversion.
    # NOTE(review): preprocess() reads .year/.month/.day from every time value;
    # with 'decode_times': False the time values are presumably still raw
    # numbers -- confirm whether 'use_cftime': True (as in the main notebook
    # code) is what is needed here.
    dsets = run['col'].to_dataset_dict(cdf_kwargs={'chunks': chunks, 'decode_times': False}, preprocess=preprocess, progressbar=False)
    # Turn off preprocessing for now.
    #dsets = run['col'].to_dataset_dict(cdf_kwargs={'chunks': chunks, 'decode_times': False}, preprocess=None, progressbar=False)

    # A chunk size of 1 along 'member_id' would store each ensemble member in its own chunk.
    #chunks['member_id'] = 1
    #dsets = enforce_chunking(dsets, chunks)
    
    for key, ds in tqdm(dsets.items(), desc='Saving zarr store'):
        print('key: ' + key)
        key = key.split(field_separator)
        # NOTE(review): the main notebook code splits these keys into
        # (exp, frequency, grid, biascorrection); the layout assumed here
        # (component first, variable last) differs -- confirm against the catalog.
        exp, cmp, var, frequency = key[1], key[0], key[-1], frequency
        # NOTE(review): zarr_store() as defined in this notebook takes five
        # positional arguments (var, exp, frequency, grid, biascorrection);
        # this call passes four, in a different order, so it fails with a TypeError.
        store = zarr_store(exp, cmp, frequency, var, write=True, dirout=dirout)
        save_data(ds, store)

### SANDBOX: Code Testing Area

In [None]:
ds = xr.open_dataset('/Users/bonnland/GitRepos/cesm-lens-zarrification/notebooks/na-cordex/data-subsets/subset_tasmax.rcp85.MPI-ESM-LR.WRF.day.NAM-22i.raw.nc')
ds

In [None]:
print(ds.data_vars)
print(ds.attrs['title'])

In [None]:
# Convert dates in the original dataset from the NoLeap to Gregorian calendar
ds['time'] = [convert_to_gregorian(t) for t in ds.time.values]

In [None]:
# Create a date range on the Gregorian calendar
start_date = ds.time.values[0]
end_date = ds.time.values[-1]

times = xr.DataArray(xr.cftime_range(start=start_date, end=end_date, freq='D', calendar='gregorian'), dims='time')
times

In [None]:
# Find the leap days in this date range.
is_leap_day = (times.time.dt.month == 2) & (times.time.dt.day == 29)
leap_days = times.where(is_leap_day, drop=True)
leap_days

In [None]:
# Create fill values for these days.
one_time_step = ds['tasmax'].isel(time=slice(0, 1))
fill_values = []
for leap_day in leap_days:
    d = xr.full_like(one_time_step,fill_value=np.nan)
    d = d.assign_coords(time=[leap_day.data])
    fill_values.append(d)

In [None]:
# Append the fill values to the dataset and then sort values by time.
fill_values.append(ds['tasmax'])

ds_fixed=xr.concat(fill_values, dim='time').sortby('time')
ds_fixed

In [None]:
col

In [None]:
#[dsets[key].get_index('time') for key in dsets][2][0]
[dsets[key].get_index('time') for key in dsets]

In [None]:
list(dsets.values())[:2]

In [None]:
list(dsets.values())[0].time.values[0]

In [None]:
xr.concat(list(dsets.values())[:3], dim='member_id', combine_attrs='drop', data_vars=['tasmax'])

In [None]:
ds.time.values[0].replace(hour=23)

In [None]:
preprocess(ds)

In [None]:
type(ds.time.indexes["time"].to_datetimeindex())


In [None]:
# Use the following query to gather all data for one variable.
#subset = col.search(variable='tasmax', scenario=['hist','rcp85'], grid='NAM-22i', frequency='day')
subset = col.search(variable='tasmax', scenario=['rcp85'], grid='NAM-22i', frequency='day')

# Use this to load some 360-day data for conversion to the Gregorian calendar.
#subset = col.search(variable='tasmax', scenario=['hist'], grid='NAM-22i', frequency='day', driver='HadGEM2-ES')


subset.unique(columns=['rcm', 'driver', 'biascorrection', 'common'])

In [None]:
subset.keys()

In [None]:
subset.df

In [None]:
for key in subset.keys():
    print(type(subset[key]))

In [None]:
# Look for strange metadata
for key in tqdm(subset.keys()):
    try:
        subset[key](cdf_kwargs={'chunks': {}, 'decode_times': False}).to_dask()
    except Exception as e:
        print(f'\tFile:{subset[key].df.path.tolist()} --- Exception: {e}', end="")