# Combine History Files along the Time Dimension

We have to perform the sparse reindexing with Zarr files at a later step, since the source NetCDF files are very large and difficult to fit into memory.

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import fsspec

import dask.distributed
from dask.distributed import Client
from ncar_jobqueue import NCARCluster

### Configuration/Tuning Options

In [None]:
# Input and final target folders
# INPUT_FOLDER is globbed per ensemble member by get_file_list();
# TARGET_FOLDER receives one 'member_<id>.zarr' store per member from save_data().
INPUT_FOLDER = '/glade/scratch/bonnland/DART/ds345.0/lnd'
TARGET_FOLDER = '/glade/scratch/bonnland/DART/ds345.0/lnd_zarr/'

# Target folder for performance tuning
#TARGET_FOLDER = '/glade/scratch/bonnland/DART/ds345.0/ZARR-SCRATCH/'

#### For Land h1 variables.
# 
#VARS = ['TSA','GPP', 'GRAINC_TO_FOOD', 'GSSHALN', 'GSSUNLN', 
#            'NPP', 'NPP_NUPTAKE', 'PLANT_NDEMAND', 'QVEGT', 'TLAI']
#TARGET_CHUNKS = {'time': 100, 'pft': 100000}

#### For Land h0 variables 
#   Chunks are small, so increase time chunk to prevent fragmentation.
# VARS lists the data variables that preprocess() keeps; every other data
# variable is dropped from each input file.
VARS = ['TSA', 'ER', 'EFLX_LH_TOT', 'HR']
# TARGET_CHUNKS gives the number of elements *in* each chunk per dimension
# (not the number of chunks); it is applied in get_dataset().
TARGET_CHUNKS = {'time': 1000, 'lat': 32, 'lon': 32} 


In [None]:
# Try to keep metadata during Xarray operations.
# keep_attrs=True is set globally here rather than per operation.
xr.set_options(keep_attrs=True)

## Run These Cells for Dask CASPER

In [None]:
# Processes is processes PER CORE.
# NOTE(review): dask-jobqueue appears to treat `processes` as per job rather
# than per core -- confirm against the library documentation.
# This one works fine.
#cluster = NCARCluster(cores=15, processes=1, memory='100GB', project='STDD0003')
# This one also works, but occasionally hangs near the end.
#cluster = NCARCluster(cores=10, processes=1, memory='50GB', project='STDD0003')

# For Casper
# The trailing '#' values below appear to be previously tried settings.
num_cores = 1 #2 #1
num_jobs = 50 #25 #4
walltime = "3:00:00"
memory = '10GB'  #'100GB'

# Describe a single job (cores, one worker process, memory, walltime), then
# request `num_jobs` such jobs.
cluster = NCARCluster(cores=num_cores, processes=1, memory=memory, project='STDD0003', walltime=walltime)
cluster.scale(jobs=num_jobs)

# Connect this notebook session to the cluster; the bare `cluster` expression
# makes the cluster object the cell's displayed output.
client = Client(cluster)
cluster

## Run These Cells for Dask CHEYENNE

In [None]:
import dask
from ncar_jobqueue import NCARCluster

# Processes is processes PER CORE.
# NOTE(review): dask-jobqueue appears to treat `processes` as per job rather
# than per core -- confirm against the library documentation.
# This one works fine.
#cluster = NCARCluster(cores=15, processes=1, memory='100GB', project='STDD0003')
# This one also works, but occasionally hangs near the end.
#cluster = NCARCluster(cores=10, processes=1, memory='50GB', project='STDD0003')

# For Cheyenne

# Run small set of workers on each node to avoid RAM shortages and Dask crashes.  (tried 8,4)
# The trailing '#' values below appear to be previously tried settings.
num_nodes = 20 #40
num_workers_per_node = 3 #2 #1
num_cores_per_node = 4
walltime = "6:00:00" #"8:00:00"

# Each job asks for `num_cores_per_node` cores split across
# `num_workers_per_node` worker processes.  Unlike the Casper cell, no
# `project` argument is passed here.
cluster = NCARCluster(cores=num_cores_per_node, 
                      processes=num_workers_per_node, 
                      memory='109GB', 
                      walltime=walltime)

# One job is requested per node.
cluster.scale(jobs=num_nodes)

# `format_bytes` is imported but not used anywhere in the visible notebook.
from distributed import Client
from distributed.utils import format_bytes
# Connect this notebook session to the cluster; the bare `cluster` expression
# makes the cluster object the cell's displayed output.
client = Client(cluster)
cluster

In [None]:
# Shut down the Dask cluster created in one of the cells above.
# NOTE(review): presumably run by hand only when the cluster must be discarded;
# running it before the processing loop leaves no workers -- confirm.
cluster.close()

In [None]:
def compute_index(ds):
    """Build the six-level MultiIndex used to expand the 1D 'pft' dimension.

    The levels are read from the dataset's ``pfts1d_*`` variables and are
    named, in order: 'pftlat', 'pftlon', 'vegtype', 'coltype', 'lunittype'
    and 'active'.  Latitude and longitude are cast to float32 first.

    Returns the resulting ``pandas.MultiIndex``; the dataset is not modified.
    """
    # Level name -> source variable, in the order the levels must appear.
    level_sources = {
        'pftlat': ds.pfts1d_lat.astype('float32'),
        'pftlon': ds.pfts1d_lon.astype('float32'),
        'vegtype': ds.pfts1d_itype_veg,
        'coltype': ds.pfts1d_itype_col,
        'lunittype': ds.pfts1d_itype_lunit,
        'active': ds.pfts1d_active,
    }
    level_values = [list(source.data) for source in level_sources.values()]

    # Redefining 'pft' with this multi-index is what increases the number of dimensions.
    return pd.MultiIndex.from_arrays(level_values, names=tuple(level_sources))

In [None]:
def sparsify(chunk):
    """Return ``chunk`` unstacked with ``sparse=True``.

    Small enough to be handed to ``xr.map_blocks`` as the per-block function.
    """
    return chunk.unstack(sparse=True)

In [None]:
def preprocess_sparse(ds):
    """Per-file preprocessing that re-indexes 'pft' and sparsifies the PFT variables.

    Steps, in order:
      1. Replace the 'pft' coordinate with the multi-index from compute_index()
         (this assignment is made on the dataset that was passed in).
      2. Drop every data variable not listed in PFT_VARS.
      3. Load the remainder into memory and rechunk it to TARGET_CHUNKS.
      4. For each PFT variable, keep only the first time step and apply
         sparsify() block by block.

    NOTE(review): PFT_VARS is not defined anywhere in the visible notebook
    (only VARS is) -- confirm where it is meant to come from before using this.
    """
    ds['pft'] = compute_index(ds)

    # Drop unneeded variables as soon as possible.
    unwanted = [name for name in ds.data_vars if name not in PFT_VARS]
    ds = ds.drop_vars(unwanted)

    ds = ds.load().chunk(chunks=TARGET_CHUNKS)

    for name in PFT_VARS:
        # Try limiting to one timestep.
        ds[name] = ds[name].isel(time=0)
        ds[name] = xr.map_blocks(sparsify, ds[name])

    return ds

In [None]:
def preprocess(ds):
    """Reduce one input dataset to the data variables named in VARS.

    Passed to ``xr.open_mfdataset`` as its ``preprocess`` callback, so it runs
    on each source file's dataset before the files are concatenated.  Only
    data variables are filtered; nothing else is changed.
    """
    # Drop unneeded variables as soon as possible.
    unwanted = [name for name in ds.data_vars if name not in VARS]
    return ds.drop_vars(unwanted)

## Create a Zarr Store for each of 80 ensemble members.

In [None]:
def get_file_list(member_id, is_instantaneous):
    """Return the NetCDF history files of one ensemble member.

    The member id is zero-padded to four digits and used both as the
    sub-folder of INPUT_FOLDER and inside the file name pattern.  A truthy
    ``is_instantaneous`` selects the '.h0.' files, otherwise the '.h1.' files.
    The glob is run on the module-level ``fs`` filesystem object.
    """
    padded_id = str(member_id).zfill(4)
    data_filter = (
        f'{INPUT_FOLDER}/{padded_id}/*.clm2_{padded_id}.h0.*.nc'
        if is_instantaneous
        else f'{INPUT_FOLDER}/{padded_id}/*.clm2_{padded_id}.h1.*.nc'
    )

    kept = []
    for path in fs.glob(data_filter):
        # For now, remove 2011 files (any path containing '2011' is skipped).
        if '2011' in path:
            continue
        kept.append(path)
    return kept

In [None]:
def get_dataset(member_id, is_instantaneous=True):
    """Open all history files of one ensemble member as a single rechunked dataset.

    Args:
        member_id: Integer id of the ensemble member.
        is_instantaneous: Forwarded to get_file_list().  True (the default,
            and the value that used to be hard-coded here) selects the 'h0'
            files; False selects the 'h1' files.
            NOTE(review): VARS and TARGET_CHUNKS must match the selected
            stream -- confirm before passing False.

    Returns:
        An Xarray dataset made by concatenating the member's files along
        'time' (opened with decode_cf=False, each file filtered through
        preprocess()) and rechunked to TARGET_CHUNKS.
    """
    file_list = get_file_list(member_id, is_instantaneous)

    # Turn off dask's 'array.slicing.split_large_chunks' option while the
    # files are being opened and combined.
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        ds = xr.open_mfdataset(file_list, concat_dim='time', parallel=True,
                               preprocess=preprocess, decode_cf=False, combine="nested",
                               data_vars='minimal', coords='minimal', compat='override',
                              )

    # Rechunk after combining time steps, so we can chunk time.
    # Note that "chunks" specifies the number of elements *in* each chunk,
    # not the number of chunks.
    return ds.chunk(chunks=TARGET_CHUNKS)

In [None]:
def save_data(ds, member_id, save_folder=None):
    """Write ``ds`` to the Zarr store '<save_folder>/member_<member_id>.zarr'.

    Args:
        ds: Dataset to write; its ``to_zarr`` method is called with
            ``consolidated=True``.
        member_id: Ensemble member id used in the store name.
        save_folder: Destination folder.  Defaults to TARGET_FOLDER.  Trailing
            '/' characters are stripped so the store path never contains '//'.

    Returns:
        None.  A failed write is reported with ``print`` and not re-raised, so
        a loop over ensemble members carries on with the next member.
    """
    if save_folder is None:
        save_folder = TARGET_FOLDER
    folder = save_folder.rstrip('/')
    store = f'{folder}/member_{member_id}.zarr'
    try:
        ds.to_zarr(store, consolidated=True)
    except Exception as e:
        # Deliberate best effort: report the failure and let the caller continue.
        print(f"Failed to write {store}: {e}")

### Loop over ensemble members and create a Zarr store for each.

In [None]:
%%time

# Filesystem object that get_file_list() uses for globbing.
# NOTE(review): with protocol None this is presumably the local filesystem -- confirm.
fs = fsspec.filesystem(None)

# Ensemble members are numbered 1 through 80; each gets its own Zarr store.
for member_id in range(1, 81):
    print(f'  Creating store for member {member_id} ...')
    ds = get_dataset(member_id)
    save_data(ds, member_id)
    

In [None]:
!date

In [None]:
# Shut down the Dask cluster now that the processing loop above has run.
cluster.close()

### Verify details from one of the created stores.

In [None]:
# Re-open the store for member 1 (the path lies under TARGET_FOLDER) using its
# consolidated metadata; the bare `ds` expression displays it as the cell output.
store = '/glade/scratch/bonnland/DART/ds345.0/lnd_zarr/member_1.zarr'
ds = xr.open_zarr(store, consolidated=True)
ds

In [None]:
# Print the plain-text representation of the re-opened dataset.
print(ds)

In [None]:
# Scratch example: a small 2 x 3 array of ones with 'time' and 'pft' dimensions.
# NOTE(review): coordinate 'a' is declared on dimension 'x', which is not one
# of the array's dims -- confirm xarray accepts this as written.
arr = xr.DataArray(
    data=np.ones((2, 3)),
    dims=["time", "pft"],
    coords={
        "time": range(2),
        "pft": range(3),
        "a": ("x", [3, 4]),
    },
)

In [None]:
arr

## Unused/Non-working Code

In [None]:
# This naive approach to reshaping results in out-of-memory errors.

def reshape_pft_var(ds, var, pft):
    """Scatter a pft-indexed variable onto a dense (time, pft, lat, lon) grid.

    Args:
        ds: Dataset providing ``pfts1d_ixy``, ``pfts1d_jxy``,
            ``pfts1d_itype_veg``, ``time``, ``lat``, ``lon`` and ``ds[var]``.
        var: Name of the variable to reshape.
        pft: Unused; kept so existing callers keep working.

    Returns:
        A float32 NumPy array of shape
        ``(len(ds.time), len(pft_names), len(ds.lat), len(ds.lon))``.  Grid
        cells that no pft entry maps to are NaN.  (They used to be left as
        uninitialised memory from ``np.empty``.)

    Relies on the module-level ``pft_names``, which is not defined in the
    visible notebook.
    """
    # Each index has 1 subtracted before use, i.e. is treated as 1-based.
    # NOTE(review): confirm this also holds for the vegetation type; a value
    # of 0 would wrap around to the last slot.
    lon_idx = ds.pfts1d_ixy.values.astype(int) - 1
    lat_idx = ds.pfts1d_jxy.values.astype(int) - 1
    veg_idx = ds.pfts1d_itype_veg.values.astype(int) - 1

    shape = (len(ds.time), len(pft_names), len(ds.lat), len(ds.lon))
    # Start from NaN so cells without data are recognisable as missing.
    gridded = np.full(shape, np.nan, dtype=np.float32)
    gridded[:, veg_idx, lat_idx, lon_idx] = ds[var].values

    return gridded

In [None]:
def preprocess_nonsparse(ds):
    """Per-file preprocessing that adds one dense gridded variable per (variable, pft) pair.

    Data variables not in PFT_VARS are dropped from the returned dataset.
    For every kept variable and every entry of ``pft_names``, a new variable
    named '<var>__<pft>' is built with reshape_pft_var() and attached with
    dims ('time', 'lat', 'lon').  Each new name is printed as it is produced.

    NOTE(review): reshape_pft_var() builds a 4-D (time, pft, lat, lon) array,
    but it is attached here under three dims -- likely one reason this code is
    filed as non-working; confirm.  PFT_VARS and pft_names are not defined in
    the visible notebook.
    """
    def is_pft_var(name):
        return name in PFT_VARS

    # Drop unneeded variables as soon as possible.
    ds_fixed = ds.drop_vars([name for name in ds.data_vars if not is_pft_var(name)])

    # Reshape the dataset to use a grid using Pandas
    for var in list(filter(is_pft_var, ds.data_vars)):
        for pft in pft_names:
            new_var = f'{var}__{pft}'
            print(new_var)
            gridded = reshape_pft_var(ds, var, pft)
            grid_coords = [ds.time.values, ds.lat.values, ds.lon.values]
            ds_fixed[new_var] = xr.DataArray(gridded, dims=['time', 'lat', 'lon'],
                                             coords=grid_coords)

    return ds_fixed

In [None]:
# This approach keeps the pft-related data sparse, but latitudes without land are not represented in the 
# new dimension lat_pft.

# It may be possible to expand the latitude dimension later to include these missing values:
# [-62.67016, -61.72775, -60.78534, -59.842934, -58.900524, -57.958115, -57.015705, 
#  84.34555, 85.28796, 86.23037, 87.172775, 88.11518, 89.057594, 90.0]

def expand_pft_dataset_MYSTERY_FAIL(ds):
    """Re-index 'pft' with a six-level multi-index and unstack the dataset sparsely.

    The multi-index levels ('pftlat', 'pftlon', 'vegtype', 'coltype',
    'lunittype', 'active') come from the dataset's ``pfts1d_*`` variables,
    with latitude and longitude cast to float32.  The 'pft' assignment is made
    on the dataset that was passed in; the unstacked result is returned.
    """
    level_names = ('pftlat', 'pftlon', 'vegtype', 'coltype', 'lunittype', 'active')
    level_arrays = [
        ds.pfts1d_lat.astype('float32').data,
        ds.pfts1d_lon.astype('float32').data,
        ds.pfts1d_itype_veg.data,
        ds.pfts1d_itype_col.data,
        ds.pfts1d_itype_lunit.data,
        ds.pfts1d_active.data,
    ]

    # Redefine the 'pft' dimension as a multi-index, which will increase the number of dimensions.
    ds['pft'] = pd.MultiIndex.from_arrays(level_arrays, names=level_names)

    # Keep the data sparse if possible to avoid memory shortages.
    return ds.unstack(sparse=True)