# Processing ERA5 data in NetCDF Format

This notebook demonstrates how to work with the ECMWF ERA5 reanalysis available as part of the AWS Public Dataset Program (https://registry.opendata.aws/ecmwf-era5/).

This notebook uses Amazon SageMaker and AWS Fargate to provide an environment with a Jupyter notebook and a Dask cluster. There is an example AWS CloudFormation template available at https://github.com/aws-samples/aws-opendata-samples/blob/main/projects/aws-era5-dask/dask-environment.yaml for quickly creating this environment in your own AWS account to run this notebook.

## Python Imports

In [None]:
%matplotlib inline
import boto3
import botocore
import datetime
import matplotlib.pyplot as plt
import matplotlib
import xarray as xr
import numpy as np
import s3fs
import fsspec
import dask
from dask.distributed import performance_report, Client, progress

font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size'   : 18}
matplotlib.rc('font', **font)

Install extra software here, if necessary.

In [None]:
#import sys
#!{sys.executable} -m pip install graphviz
#import graphviz

## Set up the Dask Client to talk to our Fargate Dask Distributed Cluster

This notebook expects Dask to be running in an ECS cluster, such as the one created by the CloudFormation template linked above. The code in this notebook assumes you are running in this environment and will need adjusting if you are using a different Dask setup.

**Update the stackname variable below to identify the name of your CloudFormation stack**

In [None]:
stackname="dask-environment"

In [None]:
# Retrieve stack outputs
cfn = boto3.client('cloudformation')
resp = cfn.describe_stacks(StackName=stackname)
outputs = {}
for output in resp['Stacks'][0]['Outputs']:
    outputs[output['OutputKey']] = output['OutputValue']
cluster = outputs['DaskECSClusterName']
schedulerservice = outputs['DaskSchedulerServiceName']
workerservice = outputs['DaskWorkerServiceName']
outputs

Start the Dask scheduler service through ECS.

In [None]:
ecs = boto3.client('ecs')
ecs.update_service(cluster=cluster, service=schedulerservice, desiredCount=1)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[schedulerservice])

The following will identify the public IP address of the Dask-Scheduler task (based on security group membership) and output the dashboard URL:

In [None]:
ec2 = boto3.client('ec2')
resp = ec2.describe_network_interfaces(
  Filters=[{
      'Name': 'group-id',
      'Values': [outputs['DaskSchedulerSecurityGroup']]
  }])
schedulerurl = 'http://' + resp['NetworkInterfaces'][0]['Association']['PublicDnsName'] + '/status'
from IPython.display import display,HTML
display(HTML('Dask scheduler URL - click to open in new tab: <a href=\'' + schedulerurl + '\'>' + schedulerurl + '</a>'))
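The cell above assumes that the first network interface in the scheduler's security group has a public address. If it does not, the `'Association'` key may be missing and the lookup raises a `KeyError`. A slightly more defensive variant is sketched below; the helper name `dashboard_url` is hypothetical, and the sample list is a hand-written stand-in shaped like the `NetworkInterfaces` field of a `describe_network_interfaces` response, not real output.

```python
from typing import Optional


def dashboard_url(network_interfaces: list, path: str = "/status") -> Optional[str]:
    """Return the dashboard URL for the first interface with a public DNS name.

    `network_interfaces` is the 'NetworkInterfaces' list from an EC2
    describe_network_interfaces response.
    """
    for eni in network_interfaces:
        # 'Association' is only present when the interface has a public IP address
        dns = eni.get("Association", {}).get("PublicDnsName")
        if dns:
            return f"http://{dns}{path}"
    return None  # no interface with a public DNS name


# Hypothetical sample shaped like a describe_network_interfaces response
sample = [
    {"NetworkInterfaceId": "eni-private"},
    {"NetworkInterfaceId": "eni-public",
     "Association": {"PublicDnsName": "ec2-203-0-113-10.compute-1.amazonaws.com"}},
]
print(dashboard_url(sample))
```

In the notebook this would be called as `dashboard_url(resp['NetworkInterfaces'])`, and a `None` result signals that the scheduler task has no public address yet.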

### Scale out Dask Workers and connect
Start the dask worker tasks and connect to the scheduler.  This will take a minute or so.

In [None]:
numWorkers=12
ecs.update_service(cluster=cluster, service=workerservice, desiredCount=numWorkers)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[workerservice])

In [None]:
client = Client('dask-scheduler.local-dask:8786')
client

Enable fsspec debugging if desired. This will increase the worker log output but can be helpful to identify inefficient S3 access. Worker log output can be found in CloudWatch Logs.

In [None]:
#client.run(fsspec.utils.setup_logging, logger_name="fsspec", level="DEBUG")

## Open an Example File and Check the Native Chunking

Before we start processing, let's explore the dataset to discover its structure and chunk layout. For maximum performance, we want our Dask chunks to align with the native NetCDF chunks.

First list the NetCDF files for a single month in the ERA5 public S3 bucket, using `s3fs`.

In [None]:
bucketname="era5-workshop-data"

In [None]:
fs = s3fs.S3FileSystem()
fs.ls(f'{bucketname}/2021/01/data')

Now let's open one of the files as a dataset using `xarray`, then explore its size, shape and chunk layout.

In [None]:
url = f's3://{bucketname}/2021/05/data/air_temperature_at_2_metres.nc'
ncfile = fsspec.open(url)
ds = xr.open_dataset(ncfile.open())
ds.air_temperature_at_2_metres.encoding

In [None]:
ds.info()

In [None]:
# Note this causes the file to be read into memory
# print('file size in GB {:0.2f}\n'.format(ds.nbytes / 1e9))

### Explore the underlying HDF5 structure
NetCDF version 4 uses HDF5 as its underlying file format, so you can also use `h5py` directly to view information about the NetCDF data (this only works for NetCDF-4/HDF5 files, not classic NetCDF-3 files).

In [None]:
import h5py
h5f = h5py.File(fs.open(url))
list(h5f.keys())

Let's take a look at the chunk layout and size information.

In [None]:
# use a separate name so the xarray dataset `ds` is not overwritten by the h5py dataset
h5ds = h5f['air_temperature_at_2_metres']
print("Number of chunks in file dataset:", h5ds.id.get_num_chunks())
print("Dataset shape:", h5ds.shape)
print("Chunk shape:", h5ds.chunks)
print("Compression:", h5ds.compression)
print("Dataset storage size: {:0.2f} MB".format(h5ds.id.get_storage_size() / 1e6))
print("Dataset full size: {:0.2f} MB".format(h5ds.nbytes / 1e6))

Now let's have a look at the first few chunks to see how big they are on disk and get a rough idea of the file layout.

In [None]:
# print the storage info for the first 20 chunks, then for the last chunk
for i in range(20):
    print(h5ds.id.get_chunk_info(i))
print(':')
print(h5ds.id.get_chunk_info(h5ds.id.get_num_chunks() - 1))
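As a rough cross-check of the chunk count printed earlier, the number of chunks in a fully written chunked dataset is the product, over all dimensions, of the dataset length divided by the chunk length and rounded up (edge chunks are only partially filled). The sketch below assumes a May 2021 file on the 0.25° ERA5 grid (744 hourly steps x 721 latitudes x 1440 longitudes) and a 24 x 100 x 100 chunk shape; compare these assumed values with the `Dataset shape` and `Chunk shape` output above.

```python
import math


def chunk_grid(shape, chunks):
    """Number of chunks along each dimension (edge chunks may be partially filled)."""
    return tuple(math.ceil(n / c) for n, c in zip(shape, chunks))


# Assumed values: 31 days x 24 hours, 0.25-degree global grid, 24 x 100 x 100 chunks
shape = (744, 721, 1440)
chunks = (24, 100, 100)

grid = chunk_grid(shape, chunks)
total = math.prod(grid)
print("chunks per dimension:", grid)
print("total chunks:", total)
```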

## Open 2-m air temperature as a single dataset
This is where the real work begins.  We start by defining the set of S3 objects that we are going to process, using a file pattern.  We'll start with a full year of data - 12 files.

In [None]:
start_year = 2021
end_year = 2021
years = list(np.arange(start_year, end_year+1, 1))
months = ["01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12"]

file_pattern = 's3://{bucket}/{year}/{month}/data/air_temperature_at_2_metres.nc'
files_list = [file_pattern.format(bucket=bucketname,year=year,month=month) for year in years for month in months]
files_list

### Caching

We now open each object as a file-like object using s3fs, so that we can use them with xarray and Dask.

Depending on how the data is chunked, there are some parameters we can use to control `fsspec` caching that can improve performance. By default, `s3fs` uses the `'bytes'` cache type, a read-ahead cache that requests and caches an additional 5 MB beyond each read from S3. This turns out to be inefficient for the ERA5 dataset, because the compressed NetCDF chunks are small (~200 KB) and are not always stored sequentially.

The `default_block_size` parameter controls how much additional data will be requested from S3. We've found values between 256 KB and 512 KB to be optimal for the ERA5 dataset, but feel free to experiment.

You can also experiment with other cache types, as implemented here: https://github.com/fsspec/filesystem_spec/blob/master/fsspec/caching.py
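To see why a large read-ahead block hurts when chunks are small and scattered, here is a simplified model of a read-ahead cache. It is not fsspec's actual cache implementation: it keeps a single cached byte range and, on a miss, fetches the requested bytes plus `block_size` extra bytes. The file layout (100 chunks of 200 KB stored back to back) and both access patterns are hypothetical.

```python
def simulate_readahead(reads, block_size, file_size):
    """Count requests and bytes fetched for a list of (offset, length) reads.

    Simplified read-ahead model: one cached range [cache_start, cache_end).
    A read that is not fully inside it triggers a new request for the read
    itself plus `block_size` extra bytes (capped at the end of the file).
    """
    cache_start, cache_end = 0, 0  # empty cache
    requests, fetched = 0, 0
    for offset, length in reads:
        end = offset + length
        if cache_start <= offset and end <= cache_end:
            continue  # served from the cached range, no new request
        fetch_end = min(file_size, end + block_size)
        requests += 1
        fetched += fetch_end - offset
        cache_start, cache_end = offset, fetch_end
    return requests, fetched


CHUNK = 200_000  # hypothetical compressed chunk size on disk (bytes)
N_CHUNKS = 100
FILE_SIZE = CHUNK * N_CHUNKS

# Sequential access: chunks 0..19 in file order
sequential = [(i * CHUNK, CHUNK) for i in range(20)]
# Scattered access: alternate between two distant regions of the file
scattered = [(j * CHUNK, CHUNK) for i in range(10) for j in (i, 50 + i)]

for label, reads in (("sequential", sequential), ("scattered", scattered)):
    for block in (5 * 2**20, 512 * 1024):
        requests, fetched = simulate_readahead(reads, block, FILE_SIZE)
        print(f"{label:10s} block={block // 1024:5d} KiB  "
              f"requests={requests:3d}  MB fetched={fetched / 1e6:7.2f}")
```

Under this toy model, the 5 MiB block wins for purely sequential reads (a single request), but for the scattered pattern it fetches about 109 MB to read 4 MB of chunks, while a 512 KiB block fetches about 14.5 MB. Real access patterns will differ, which is why experimenting with the block size is worthwhile.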

In [None]:
# We'll use our cluster to open files for the dataset in parallel across the workers
@dask.delayed
def s3open(path):
    # Note the s3fs block size - this is the amount of data that will be read from s3 with each GetObject request.
    # By default it is 5MB but for ERA5 we have found a 512k block yields more efficient S3 requests and faster performance.
    fs = s3fs.S3FileSystem(default_fill_cache=False, default_block_size=512*1024)
    return fs.open(path)

files_mapper = [s3open(path) for path in files_list]

### Chunk sizing
Now initialise the xarray dataset, specifying a chunk size that is a multiple of the underlying NetCDF chunk (24 hourly time steps by 100 x 100 grid points).
We are aiming for chunks of roughly 100 MB, as Dask recommends. Uncompressed NetCDF chunks are 960 KB (24 x 100 x 100 values of 4 bytes each), so we set the `chunks` parameter to make each Dask chunk 128 times larger (about 123 MB).
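The chunk arithmetic can be checked with a few lines of Python. This sketch assumes 4-byte (float32) values, and the helper name `chunk_bytes` is hypothetical; the chunk lengths are the ones passed to `open_mfdataset` in the next cell.

```python
import math

ITEMSIZE = 4  # assumed bytes per value (float32)

native = {"time0": 24, "lat": 100, "lon": 100}      # underlying NetCDF chunk
requested = {"time0": 384, "lat": 200, "lon": 400}  # chunks passed to open_mfdataset


def chunk_bytes(chunks, itemsize=ITEMSIZE):
    """Uncompressed size of one chunk in bytes."""
    return math.prod(chunks.values()) * itemsize


# Every requested chunk length should be a whole multiple of the native one
aligned = all(requested[dim] % native[dim] == 0 for dim in native)
native_bytes = chunk_bytes(native)
dask_bytes = chunk_bytes(requested)

print("aligned with native chunks:", aligned)
print(f"native chunk: {native_bytes / 1e3:.0f} KB")
print(f"dask chunk:   {dask_bytes / 1e6:.1f} MB ({dask_bytes // native_bytes}x native)")
```

Keeping every Dask chunk length a whole multiple of the native one means each NetCDF chunk is read by exactly one Dask task, rather than being fetched and decompressed by several.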

In [None]:
%%time
ds = xr.open_mfdataset(
    files_mapper, 
    engine='h5netcdf', 
    chunks={'lon':400,'lat':200,'time0':384}, # 128x larger than underlying NetCDF chunk
    concat_dim='time0', 
    combine='nested', 
    coords='minimal', 
    compat='override',
    parallel=True
)

In [None]:
print('ds size in GB {:0.2f}\n'.format(ds.nbytes / 1e9))
ds.info()

The `ds.info()` output above lists the three dimensions of the data (lat, lon, and time0) and the data variables, including air_temperature_at_2_metres.

Let's check the chunking...

In [None]:
ds.air_temperature_at_2_metres

## Convert units from K to C
This performs a simple subtraction to convert the temperature from Kelvin to degrees Celsius. The operation is not performed at this stage; it only runs when we access the result or make the explicit call to `persist` below. Because xarray drops attributes during arithmetic by default, the `units` attribute is set again after the subtraction.

In [None]:
ds['air_temperature_at_2_metres'] = (ds.air_temperature_at_2_metres - 273.15)
ds.air_temperature_at_2_metres.attrs['units'] = 'C'
ds.air_temperature_at_2_metres

## Read all data into dask worker memory
The following line reads the entire dataset into worker memory. This step is optional, but it makes all subsequent calculations much faster (at the expense of memory usage!) and is a useful illustration of how Dask works. Without it, calculations are still performed without loading all the data into worker memory at once, but the data must be read back in from S3 for each calculation (taking much longer, but using less memory).

The subtraction calculation we queued up above will also be executed during this step.

In [None]:
ds = client.persist(ds)
progress(ds)

Sometimes data isn't evenly distributed, depending on the dataset and chunk size that we selected.  Here we rebalance the data across workers so that future tasks will make best use of cluster resources.

In [None]:
client.rebalance()

## Calculate the mean 2-m air temperature for the entire dataset
Now let's do some calculations across the entire data set, starting with calculation of the mean for every grid point.

In [None]:
# calculates the mean along the time dimension
temp_mean = ds['air_temperature_at_2_metres'].mean(dim='time0')

The expression above didn't actually compute anything; it just built the Dask task graph. To run the computation, we call the `persist` method below.

In [None]:
temp_mean

In [None]:
temp_mean = temp_mean.persist()
progress(temp_mean)

## Plot Average Surface Temperature
To plot the data, we need to read it back into the local notebook Python environment. This is done using the `compute` method. Once the data is back in local memory, we can use matplotlib to display it visually. For more information refer to: https://distributed.dask.org/en/latest/manage-computation.html

In [None]:
temp_mean = temp_mean.compute()  # bring the result into local memory
temp_mean.plot(figsize=(30, 15))
plt.title(f'Mean 2-m Air Temperature {start_year} - {end_year}')

That's the mean of the hourly samples in the source dataset. Let's down-sample the data by taking the daily maximum and recalculating the mean from that. This takes one line of code:
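In plain Python, the same two-step reduction looks like this: group hourly values by calendar day, take each day's maximum, then average those maxima. The two days of hourly temperatures below are made up for illustration.

```python
from collections import defaultdict
from datetime import datetime, timedelta
from statistics import mean

# Hypothetical hourly temperatures (deg C) for two days; day two is 2 degrees warmer
start = datetime(2021, 1, 1)
hourly = {start + timedelta(hours=h): 10 + (h % 24) / 2 + (h // 24) * 2
          for h in range(48)}

# Resample to daily maxima: group timestamps by calendar date, keep the largest value
by_day = defaultdict(list)
for ts, value in hourly.items():
    by_day[ts.date()].append(value)
daily_max = {day: max(values) for day, values in by_day.items()}

print("mean of hourly values:", mean(hourly.values()))
print("mean of daily maxima: ", mean(daily_max.values()))
```

With a full set of hours for every day, the mean of the daily maxima can never be lower than the hourly mean, since each day's maximum is at least as large as that day's average.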

In [None]:
daily_max_mean = ds['air_temperature_at_2_metres'].resample(indexer={"time0":'D'}).max().mean(dim='time0')
daily_max_mean

We don't necessarily need to call `persist` here, the `compute` call below will trigger this for us - but this lets us see the progress in the notebook.

In [None]:
daily_max_mean = daily_max_mean.persist()
progress(daily_max_mean)

In [None]:
daily_max_mean = daily_max_mean.compute()  # bring the result into local memory
daily_max_mean.plot(figsize=(30, 15))
plt.title(f'Average daily maximum temperature {start_year} - {end_year}')

## Repeat for standard deviation
The data is in memory so let's do another calculation - this time standard deviation!

In [None]:
temp_std = ds['air_temperature_at_2_metres'].std(dim='time0')
temp_std

In [None]:
temp_std = temp_std.persist()
progress(temp_std)

In [None]:
temp_std = temp_std.compute()  # bring the result into local memory
temp_std.plot(figsize=(30, 15),cmap='inferno')
plt.title(f'Standard Deviation 2-m Air Temperature {start_year} - {end_year}')

## Slicing a specific region
This reduces the amount of data we are working with by slicing to a specific region by latitude/longitude.

To do this, we use the xarray `sel` method and provide latitude/longitude slices. The slice in the code below covers the Australia/New Zealand (ANZ) region. Feel free to adjust it!
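Label-based slices follow the order of the coordinate values. The ERA5 latitudes run from north to south, which is why the latitude slice is written from 0 to -50; the reversed form `slice(-50, 0)` would select nothing. The stdlib sketch below mimics that behaviour on a coarse, hypothetical 10° grid; `label_slice` is an illustrative helper, not an xarray function.

```python
def label_slice(coord, start, stop):
    """Indices selected by an inclusive label slice on a monotonic coordinate.

    The bounds must be given in the same direction as the coordinate values.
    """
    if coord[0] > coord[-1]:  # descending coordinate, e.g. 90 -> -90
        return [i for i, v in enumerate(coord) if start >= v >= stop]
    return [i for i, v in enumerate(coord) if start <= v <= stop]


lat = list(range(90, -91, -10))  # hypothetical coarse latitudes: 90, 80, ..., -90
lon = list(range(0, 181, 10))    # hypothetical coarse longitudes: 0, 10, ..., 180

south = label_slice(lat, 0, -50)
print([lat[i] for i in south])
print(label_slice(lat, -50, 0))  # reversed bounds on a descending axis: empty
print([lon[i] for i in label_slice(lon, 110, 180)])
```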

In [None]:
dssubset = ds['air_temperature_at_2_metres'].sel(lat=slice(0,-50),lon=slice(110,180))
dssubset

Calculate the mean daily maximum temperature for the region. This should be very fast because the data is already in worker memory.

In [None]:
subset_mean = dssubset.resample(indexer={'time0':'D'}).max().mean(dim='time0')
subset_mean = client.persist(subset_mean)
progress(subset_mean)

In [None]:
subset_mean = subset_mean.compute()  # bring the result into local memory
subset_mean.plot(figsize=(12,8), cmap='inferno')
plt.title(f'ANZ Mean Daily Maximum 2-m Air Temperature {start_year} - {end_year}')

## Plot temperature time series for points
This example creates a table (a pandas DataFrame) of time-series data for the specific locations defined in the list below, extracting only the grid points we are interested in.

Feel free to change the cities and locations!

In [None]:
# location coordinates
locs = [
    {'name': 'Wellington', 'lon': 174.78, 'lat': -41.28},
    {'name': 'Honolulu', 'lon': -157.84, 'lat': 21.29},
    {'name': 'Seattle', 'lon': -122.33, 'lat': 47.61},
    {'name': 'Melbourne', 'lon': 144.95, 'lat': -37.84}
]

# convert westward longitudes to degrees east
for l in locs:
    if l['lon'] < 0:
        l['lon'] = 360 + l['lon']
locs
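The conversion above adds 360 to negative longitudes; Python's modulo operator gives the same result in one step. The sketch also shows roughly what `method='nearest'` does on a regular grid, assuming the 0.25° ERA5 grid spacing. Both helpers are hypothetical, and `to_grid_point` ignores exact ties and wrap-around at 360°.

```python
GRID_STEP = 0.25  # assumed ERA5 grid spacing in degrees


def to_degrees_east(lon):
    """Map a longitude in [-180, 180] to [0, 360)."""
    return lon % 360


def to_grid_point(value, step=GRID_STEP):
    """Nearest multiple of `step` (no tie or wrap-around handling)."""
    return round(value / step) * step


honolulu_lon = to_degrees_east(-157.84)
print(round(honolulu_lon, 2))
print(to_grid_point(honolulu_lon), to_grid_point(21.29))
```

The grid point that `sel(..., method='nearest')` picks can therefore be up to half a grid cell (about 0.125°) away from the requested location, which is why the loop below records the actual `lat`/`lon` it selected.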

In [None]:
ds_locs = xr.Dataset()

# iterate through the locations and build a dataset
# containing the temperature time series for each location
for l in locs:
    name = l['name']

    # select the grid point nearest to the requested location
    ds2 = ds.sel(lon=l['lon'], lat=l['lat'], method='nearest')

    # record the coordinates of the selected grid point as dataset attributes
    ds2.attrs[f'{name}_lon'] = ds2.lon.values.tolist()
    ds2.attrs[f'{name}_lat'] = ds2.lat.values.tolist()

    # name the variable after the location and drop the scalar lat/lon coordinates
    ds2 = ds2.rename({'air_temperature_at_2_metres': name}).drop_vars(['lat', 'lon'])

    # combine attributes from all inputs so the per-location coordinates are kept
    ds_locs = xr.merge([ds_locs, ds2], combine_attrs='no_conflicts')

ds_locs.data_vars

In [None]:
ds_locs

Now let's extract the data. This should be fast because everything is already in worker memory.

In [None]:
ds_locs = client.persist(ds_locs)
progress(ds_locs)

### Convert to dataframe
Convert the xarray Dataset into a pandas DataFrame (table) of time-series data, with one column per location.

In [None]:
df_f = ds_locs.to_dataframe()
df_f

In [None]:
df_f.describe()

In [None]:
df_f.info()

### Plot temperature timeseries

We'll first resample the data from hourly values to daily maximums. Note that the number of entries, and therefore the size of the dataset, is reduced.

In [None]:
rs = df_f.resample('D').max()
rs.info()

In [None]:
matplotlib.rcParams['lines.linewidth'] = 1.0
matplotlib.rcParams['lines.linestyle'] = 'solid'
ax = rs.plot(figsize=(30, 15), title=f"ERA5 Daily Maximums {start_year} - {end_year}", grid=True)
ax.set(xlabel='Date', ylabel='2-m Air Temperature (deg C)')
plt.show()

## Dask Memory Management

Executing code in these cells can help you recover memory in the worker processes if things are getting tight.

Cancel the Dask collections that might still be held in worker memory.

In [None]:
client.cancel(ds)
client.cancel(temp_mean)
client.cancel(daily_max_mean)
client.cancel(temp_std)
client.cancel(subset_mean)
client.cancel(ds_locs)

This snippet reduces the workers' memory footprint, which can be useful when debugging memory use. It should release most of the "unmanaged" memory reported in the Dask dashboard. It calls glibc's `malloc_trim`, so it only works on Linux workers that use glibc.

In [None]:
import ctypes

def trim_memory() -> int:
    libc = ctypes.CDLL("libc.so.6")
    return libc.malloc_trim(0)

client.run(trim_memory)

In extreme cases you might want to restart the Dask workers; this will take a couple of minutes.

In [None]:
#client.restart()

## Cluster scale down

When you are finished with the cluster for now, scale the workers down to zero to save on costs.

In [None]:
ecs.update_service(cluster=cluster, service=workerservice, desiredCount=0)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[workerservice])

Optional: stop the scheduler.

In [None]:
client.close()

In [None]:
ecs.update_service(cluster=cluster, service=schedulerservice, desiredCount=0)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[schedulerservice])