# Using Kerchunk to improve NetCDF processing efficiency
This notebook contains some example steps to build a `kerchunk` index file for a set of NetCDF files in the ECMWF ERA5 reanalysis data available as part of the AWS Public Dataset Program (https://registry.opendata.aws/ecmwf-era5/).

## Python imports

In [None]:
%matplotlib inline
import boto3
import botocore
import fsspec
import matplotlib.pyplot as plt
import matplotlib
import xarray as xr
import numpy as np
#import hvplot.xarray
import ujson
import os
import dask
from dask.distributed import performance_report, Client, progress
from pathlib import Path

font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size'   : 18}
matplotlib.rc('font', **font)

In [None]:
import kerchunk
from kerchunk.hdf import SingleHdf5ToZarr
from kerchunk.combine import MultiZarrToZarr

## ECS Cluster Initialisation
This notebook expects Dask to be running in an ECS cluster.  There is an example AWS CloudFormation template available at https://github.com/awslabs/amazon-asdi/tree/main/examples/dask for quickly creating this environment in your own AWS account to run this notebook.

**Update the variables below to identify the name of the ECS cluster in your environment.**

In [None]:
stackname="dask-environment"

Identify the Dask scheduler and worker ECS services

In [None]:
# Retrieve stack outputs
cfn = boto3.client('cloudformation')
resp = cfn.describe_stacks(StackName=stackname)
outputs = {}
for output in resp['Stacks'][0]['Outputs']:
    outputs[output['OutputKey']] = output['OutputValue']
cluster = outputs['DaskECSClusterName']
schedulerservice = outputs['DaskSchedulerServiceName']
workerservice = outputs['DaskWorkerServiceName']
outputs
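
The loop above flattens the `Outputs` list returned by `describe_stacks` into a plain dictionary. Below is a minimal offline sketch of that step. It uses a hand-written response: the key names come from this notebook, but the values and the `stack_outputs` helper are invented for illustration.

```python
# Hypothetical stand-in for a describe_stacks response.
# The output values are made up; only the key names appear in this notebook.
sample_resp = {
    "Stacks": [{
        "Outputs": [
            {"OutputKey": "DaskECSClusterName", "OutputValue": "example-cluster"},
            {"OutputKey": "DaskSchedulerServiceName", "OutputValue": "example-scheduler"},
            {"OutputKey": "DaskWorkerServiceName", "OutputValue": "example-worker"},
        ]
    }]
}


def stack_outputs(resp):
    """Map each OutputKey to its OutputValue for the first stack in the response."""
    return {o["OutputKey"]: o["OutputValue"] for o in resp["Stacks"][0]["Outputs"]}


sample_outputs = stack_outputs(sample_resp)
print(sample_outputs["DaskECSClusterName"])
```

If a key such as `DaskECSClusterName` is missing from your own stack, the dictionary lookup raises `KeyError`, which usually means the stack name is wrong or the template differs from the example.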

Start the Dask scheduler service

In [None]:
ecs = boto3.client('ecs')
ecs.update_service(cluster=cluster, service=schedulerservice, desiredCount=1)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[schedulerservice])

The following cell looks up the public DNS name of the Dask scheduler task (based on security group membership) and outputs the dashboard URL:

In [None]:
ec2 = boto3.client('ec2')
resp = ec2.describe_network_interfaces(
  Filters=[{
      'Name': 'group-id',
      'Values': [outputs['DaskSchedulerSecurityGroup']]
  }])
schedulerurl = 'http://' + resp['NetworkInterfaces'][0]['Association']['PublicDnsName'] + '/status'
from IPython.display import display,HTML
display(HTML('Dask scheduler URL: <a href=\'' + schedulerurl + '\'>' + schedulerurl + '</a>'))

### Scale out Dask workers and connect

In [None]:
numWorkers=12
ecs.update_service(cluster=cluster, service=workerservice, desiredCount=numWorkers)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[workerservice])

In [None]:
client = Client('Dask-Scheduler.local-dask:8786')
client

Enable `fsspec` debugging if desired (this will increase the log output).

In [None]:
#client.run(fsspec.utils.setup_logging, logger_name="fsspec", level="DEBUG")

## Build the Kerchunk Index

We are now going to open a dataset locally and extract metadata into a JSON file using Kerchunk.  This step only needs to be done once!  After the index is created you can re-use it whenever processing the same dataset.  In this example we're going to build an index for a full year of the `air_temperature_at_2_metres` variable.

First, create a list of files to target in our S3 bucket.

In [None]:
start_year = 2020
end_year = 2020
years = list(np.arange(start_year, end_year+1, 1))
months = ["01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12"]
file_pattern = 'era5-pds/{year}/{month}/data/air_temperature_at_2_metres.nc'
flist = [file_pattern.format(year=year, month=month) for year in years for month in months]
flist
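
The same list can be built with the standard library only, which makes the key layout easier to see. This is a sketch: the `era5_keys` helper and its `variable` parameter are hypothetical names, and the path pattern is the one used above.

```python
def era5_keys(start_year, end_year, variable="air_temperature_at_2_metres"):
    """Return one S3 key per month for every year in the inclusive range."""
    pattern = "era5-pds/{year}/{month:02d}/data/{variable}.nc"
    return [
        pattern.format(year=year, month=month, variable=variable)
        for year in range(start_year, end_year + 1)
        for month in range(1, 13)  # months are zero-padded by the :02d format
    ]


keys = era5_keys(2020, 2020)
print(len(keys))
print(keys[0])
```

Widening the year range multiplies the number of files, and therefore the time needed to build the index, by the number of years.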

Create a local temporary folder to hold the index data for each NetCDF file

In [None]:
json_dir = 'jsons/'
localfs = fsspec.filesystem('file')
!rm -rf jsons
!mkdir jsons

The function below extracts the metadata from a single file using the Kerchunk module, then writes it to a local JSON file.  For now we are only defining the function; it is executed for each file in the next step.

In [None]:
fs = fsspec.filesystem('s3', anon=True)
so = dict(mode='rb', default_fill_cache=False, default_cache_type='first')
def gen_json(u):
    with fs.open(u, **so) as infile:
        h5chunks = SingleHdf5ToZarr(infile, u, inline_threshold=300, error="pdb")
        tchunks = h5chunks.translate()
        # Also write to a file
        parts = u.split('/')
        year = parts[1]
        month = parts[2]
        fstem = Path(u).stem
        outf = f'{json_dir}{year}{month}{fstem}.json'
        print(outf)
        with localfs.open(outf, 'wb') as f:
            f.write(ujson.dumps(tchunks).encode())
        return tchunks
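
The output file name is derived from the year and month segments of the S3 key. The sketch below isolates that naming logic so it can be checked without touching S3. The `index_filename` helper is a hypothetical name introduced here.

```python
from pathlib import PurePosixPath


def index_filename(key, json_dir="jsons/"):
    """Build the local JSON file name for an S3 key of the form bucket/year/month/data/name.nc."""
    parts = key.split("/")
    year, month = parts[1], parts[2]
    stem = PurePosixPath(key).stem  # file name without the .nc suffix
    return f"{json_dir}{year}{month}{stem}.json"


names = [
    index_filename(f"era5-pds/2020/{month:02d}/data/air_temperature_at_2_metres.nc")
    for month in (10, 2, 1)
]
print(sorted(names))
```

Because the year and zero-padded month come first in the file name, sorting the names alphabetically also sorts them chronologically.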

The next step will build the index files - it will take some time so please be patient!  The code prints out each file as it is written.

In [None]:
%%time
singles = []
for f in flist:
    singles.append(gen_json(f))

In [None]:
indexfilelist = sorted(localfs.glob(f'{json_dir}*.json'))
indexfilelist

Now that we have the index data, we can combine it into a single JSON file for the whole dataset using `MultiZarrToZarr`.  The postprocessing step below sets a default fill value on the `lon` and `lat` arrays; without it, Zarr will give us NaN coordinates.

In [None]:
import zarr
def modify_fill_value(out):
    out_ = zarr.open(out)
    out_.lon.fill_value = -999
    out_.lat.fill_value = -999
    return out

def postprocess(out):
    out = modify_fill_value(out)
    return out

In [None]:
mzz = MultiZarrToZarr(
    indexfilelist,
    remote_protocol='s3',
    remote_options={'anon':True},
    concat_dims=['time0'],
    postprocess=postprocess
)

In [None]:
%%time
out = mzz.translate()

The combined index data is now in memory; write it out to a JSON file.

In [None]:
singleindexfile = f'era5-{start_year}-{end_year}.json'
with localfs.open(singleindexfile, 'wb') as f:
    f.write(ujson.dumps(out).encode())

## Processing ERA5 Data Using a Kerchunk Index

In the previous step we created a Kerchunk index on the ERA5 dataset for the year 2020 with a single variable.  Now we can use that index to open and process the dataset using Dask.

In [None]:
# These should match your index file name created above
start_year = 2020
end_year = 2020
index_file = f'era5-{start_year}-{end_year}.json'
print(f'Loading index from {index_file}')

with open(index_file) as f:
    idx = ujson.load(f)

Create the S3 connection based on the Kerchunk index.  This uses the `reference` file system from the `fsspec` module, an implementation designed for reading data through reference indexes such as the one Kerchunk produces: https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.implementations.reference.ReferenceFileSystem

In [None]:
s_opts = {'skip_instance_cache':True}
r_opts = {'anon':True}
fs = fsspec.filesystem("reference", fo=idx, ref_storage_args=s_opts,
                       remote_protocol='s3', remote_options=r_opts)
zarrmap = fs.get_mapper("")
list(zarrmap.keys())[0:10]
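
Conceptually, the index maps each Zarr key either to inline content or to a `[url, offset, length]` triple that points at a byte range in the original file. The toy resolver below sketches that idea against an in-memory byte string. The index contents, the `example-bucket` URL, and the `resolve` helper are all invented for illustration, and this is not how `ReferenceFileSystem` is implemented.

```python
import json

# Pretend contents of a remote NetCDF file: a 10-byte header followed by two 6-byte chunks.
fake_store = {"s3://example-bucket/example.nc": b"HEADERxxxxCHUNK0CHUNK1"}

# A toy index in the same general shape as a Kerchunk reference file.
toy_index = {
    "version": 1,
    "refs": {
        ".zgroup": json.dumps({"zarr_format": 2}),          # inline metadata
        "temp/0": ["s3://example-bucket/example.nc", 10, 6],  # [url, offset, length]
        "temp/1": ["s3://example-bucket/example.nc", 16, 6],
    },
}


def resolve(index, key, store):
    """Return the bytes for a key, from inline content or a byte range in the store."""
    ref = index["refs"][key]
    if isinstance(ref, str):
        return ref.encode()
    url, offset, length = ref
    return store[url][offset:offset + length]


print(resolve(toy_index, "temp/0", fake_store))
print(resolve(toy_index, ".zgroup", fake_store))
```

The point of the design is that only the requested byte ranges are read, so the original NetCDF files never need to be converted or copied.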

Now open the dataset in xarray with the Zarr engine.

In [None]:
%%time
ds = xr.open_dataset(zarrmap, engine="zarr", backend_kwargs={'consolidated':False}, 
                     chunks={'time0':384})
ds
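
To get a feel for what `chunks={'time0':384}` means, the sketch below counts how many chunks a year splits into along the time dimension. It assumes hourly time steps, which is an assumption about the dataset rather than something this notebook states, and `time_chunks` is a hypothetical helper.

```python
import calendar
import math


def time_chunks(year, chunk_len=384):
    """Return (hours in year, number of chunks, length of last chunk), assuming hourly steps."""
    hours = (366 if calendar.isleap(year) else 365) * 24
    n_chunks = math.ceil(hours / chunk_len)
    last_chunk = hours - (n_chunks - 1) * chunk_len
    return hours, n_chunks, last_chunk


hours, n_chunks, last_chunk = time_chunks(2020)
print(hours, n_chunks, last_chunk)
```

A larger chunk length means fewer, bigger tasks for Dask to schedule; a smaller one means more parallelism at the cost of more scheduling overhead.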

Let's check that the air temperature dataset looks like what we expect!

In [None]:
ds.air_temperature_at_2_metres

## Optionally specify a region
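This reduces the amount of data we are working with by slicing to a specific latitude/longitude region.  The latitude slice is written from north to south, which assumes the `lat` coordinate is stored in descending order.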

In [None]:
dssubset = ds['air_temperature_at_2_metres'].sel(lat=slice(-10,-50),lon=slice(110,180)) - 273.15
dssubset.attrs['units'] = 'C'
dssubset

In [None]:
subset_mean = dssubset.mean(dim='time0')
subset_mean = client.persist(subset_mean)
progress(subset_mean)

In [None]:
subset_mean.compute()
subset_mean.plot(figsize=(12,6), cmap='magma')
plt.title(f'Mean 2-m Air Temperature {start_year} - {end_year}')

## Calculations on the global dataset
The calculations below take us back to the global dataset which is held in the `ds` reference.

### Convert units from K to C
This performs a simple subtraction to convert the temperature unit from kelvin to degrees Celsius.

In [None]:
ds['air_temperature_at_2_metres'] = (ds.air_temperature_at_2_metres - 273.15)
ds.air_temperature_at_2_metres.attrs['units'] = 'C'
ds.air_temperature_at_2_metres

Persist the dataset so that the conversion is computed now and the result is held in worker memory for the steps that follow.

In [None]:
ds = client.persist(ds)
progress(ds)
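
Persisting the global dataset holds the whole variable in worker memory, so it is worth estimating its size first. The sketch below does the arithmetic under assumptions that are not stated in this notebook: a 0.25 degree global grid of 721 x 1440 points, hourly steps for the leap year 2020 (8784 steps), and 4-byte floating point values. Check the dimensions and dtype shown for `ds` and adjust the numbers to match.

```python
def estimate_gb(n_time, n_lat=721, n_lon=1440, bytes_per_value=4):
    """Estimate the in-memory size of one variable in gigabytes (1 GB = 1e9 bytes)."""
    return n_time * n_lat * n_lon * bytes_per_value / 1e9


size_gb = estimate_gb(8784)
per_worker_gb = size_gb / 12  # 12 matches numWorkers used earlier in this notebook
print(f"total: {size_gb:.1f} GB, per worker: {per_worker_gb:.1f} GB")
```

If the per-worker figure is close to the memory available to each worker task, consider working on a regional subset or increasing the number of workers before persisting.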

### Calculate the mean 2-m air temperature for all times

In [None]:
# calculates the mean along the time dimension
temp_mean = ds['air_temperature_at_2_metres'].mean(dim='time0')

In [None]:
temp_mean

The expressions above didn't actually compute anything; they just build the Dask task graph. To run the computation, we call the `persist` method below.

In [None]:
temp_mean = temp_mean.persist()
progress(temp_mean)
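
The difference between describing a computation and running it can be sketched with a toy task graph. The example below is loosely modelled on the idea of a graph whose tasks are `(function, *argument_keys)` tuples; it is an illustration only, not Dask's implementation, and the `graph`, `get`, and `executed` names are invented here.

```python
from operator import add, truediv

# Keys map either to a literal value or to a (function, *argument_keys) task.
# Defining the graph runs nothing.
graph = {
    "jan": 280.0,
    "feb": 290.0,
    "n": 2,
    "total": (add, "jan", "feb"),
    "mean": (truediv, "total", "n"),
}

executed = []  # records which tasks actually ran


def get(graph, key):
    """Evaluate a key by recursively evaluating the keys it depends on."""
    task = graph[key]
    if isinstance(task, tuple):
        func, *arg_keys = task
        executed.append(key)
        return func(*(get(graph, k) for k in arg_keys))
    return task


result = get(graph, "mean")
print(result, executed)
```

Only the call to `get` triggers any work, in the same spirit as `persist` or `compute` triggering execution of the graph that the earlier xarray expressions built up.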

In [None]:
temp_mean.compute()
xpl = temp_mean.sortby('lon')
xpl.plot(figsize=(30, 15))
plt.title(f'{start_year} - {end_year} Mean 2-m Air Temperature')

## Dask Memory management

Executing code in these cells can help you recover memory in the worker processes if things are getting tight.

First, clear up all known datasets.

In [None]:
client.cancel(ds)
client.cancel(temp_mean)
client.cancel(dssubset)
client.cancel(subset_mean)

This snippet of code reduces the workers' memory footprint, which can be useful when debugging memory use.  It should get rid of most of the "unmanaged" memory reported in the Dask dashboard.  It relies on glibc's `malloc_trim`, so it only applies to Linux workers.

In [None]:
import ctypes

def trim_memory() -> int:
    libc = ctypes.CDLL("libc.so.6")
    return libc.malloc_trim(0)

client.run(trim_memory)

If memory still isn't coming down, this is a last resort. It will terminate all workers and restart them fresh.

In [None]:
client.restart()

## Cluster Scale Down

When we are temporarily done with the cluster, we can scale it down to save on costs.

In [None]:
# Shut down workers
ecs.update_service(cluster=cluster, service=workerservice, desiredCount=0)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[workerservice])

In [None]:
client.close()

In [None]:
# Shut down scheduler
ecs.update_service(cluster=cluster, service=schedulerservice, desiredCount=0)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[schedulerservice])