# Processing ERA5 data in Zarr Format

This notebook demonstrates how to work with the ECMWF ERA5 reanalysis available as part of the AWS Public Dataset Program (https://registry.opendata.aws/ecmwf-era5/).

This notebook uses Amazon SageMaker and AWS Fargate to provide a Jupyter notebook environment and a Dask cluster. An example AWS CloudFormation template that quickly creates this environment in your own AWS account is available at https://github.com/awslabs/amazon-asdi/tree/main/examples/dask.

## Python Imports

In [None]:
%matplotlib inline
import boto3
import botocore
import datetime
import matplotlib.pyplot as plt
import matplotlib
import xarray as xr
import numpy as np
import s3fs
import fsspec
import dask
from dask.distributed import performance_report, Client, progress

font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size'   : 18}
matplotlib.rc('font', **font)

Install extra software here, if necessary

In [None]:
#import sys
#!{sys.executable} -m pip install graphviz
#import graphviz

## Set up the Dask Client to talk to our Fargate Dask Distributed Cluster

This notebook expects Dask to be running in an Amazon ECS cluster, such as the one created by the example CloudFormation template at https://github.com/awslabs/amazon-asdi/tree/main/examples/dask. The code assumes this environment and will need adjusting if you are using a different Dask setup.

**Update the stackname variable below to identify the name of your CloudFormation stack**

In [None]:
stackname="dask-environment"

Retrieve details of the ECS cluster from the CloudFormation stack outputs

In [None]:
# Retrieve stack outputs
cfn = boto3.client('cloudformation')
resp = cfn.describe_stacks(StackName=stackname)
outputs = {}
for output in resp['Stacks'][0]['Outputs']:
    outputs[output['OutputKey']] = output['OutputValue']
cluster = outputs['DaskECSClusterName']
schedulerservice = outputs['DaskSchedulerServiceName']
workerservice = outputs['DaskWorkerServiceName']
outputs

Start the Dask scheduler container through ECS and wait for the service to become stable. The scheduler's own dashboard address is a private address that you won't be able to connect to; the public address is revealed in the following step.

In [None]:
ecs = boto3.client('ecs')
ecs.update_service(cluster=cluster, service=schedulerservice, desiredCount=1)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[schedulerservice])

The following will identify the public IP address of the Dask-Scheduler task (based on security group membership) and output the dashboard URL:

In [None]:
ec2 = boto3.client('ec2')
resp = ec2.describe_network_interfaces(
  Filters=[{
      'Name': 'group-id',
      'Values': [outputs['DaskSchedulerSecurityGroup']]
  }])
schedulerurl = 'http://' + resp['NetworkInterfaces'][0]['Association']['PublicDnsName'] + '/status'
from IPython.display import display,HTML
display(HTML('Dask scheduler URL: <a href=\'' + schedulerurl + '\'>' + schedulerurl + '</a>'))
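The cell above assumes that the first matching network interface has a public IP association. If the scheduler task has no public IP, `resp['NetworkInterfaces'][0]['Association']` raises a `KeyError`. The sketch below is a slightly more defensive version. It assumes the same `describe_network_interfaces` response shape used above, and `scheduler_dashboard_url` is a hypothetical helper that is not part of boto3 or the template.

```python
def scheduler_dashboard_url(resp: dict) -> str:
    """Return the dashboard URL for the first interface that has a public DNS name.

    `resp` is assumed to be the dict returned by
    ec2.describe_network_interfaces(...) in the cell above.
    """
    for eni in resp.get('NetworkInterfaces', []):
        # 'Association' is only present when a public IP is attached to the interface
        dns = eni.get('Association', {}).get('PublicDnsName')
        if dns:
            return 'http://' + dns + '/status'
    raise RuntimeError('No scheduler network interface with a public DNS name was found')

# Hypothetical usage in place of the URL line above:
# schedulerurl = scheduler_dashboard_url(resp)
```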

### Scale out Dask Workers and connect
Start the Dask worker tasks, then connect to the scheduler.

In [None]:
numWorkers=12
ecs.update_service(cluster=cluster, service=workerservice, desiredCount=numWorkers)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[workerservice])
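This notebook repeats the same two calls several times: `update_service` followed by the `services_stable` waiter. They start the scheduler, scale the workers out, and scale them back down. A small helper can keep these calls in one place. The sketch assumes `ecs` is the boto3 ECS client created earlier, and `scale_service` is a hypothetical name that is not part of boto3.

```python
def scale_service(ecs, cluster: str, service: str, count: int) -> None:
    """Set an ECS service's desired task count and block until the service is stable.

    `ecs` is assumed to be a boto3 ECS client, e.g. boto3.client('ecs').
    """
    ecs.update_service(cluster=cluster, service=service, desiredCount=count)
    ecs.get_waiter('services_stable').wait(cluster=cluster, services=[service])

# Hypothetical usage, equivalent to the scaling cells in this notebook:
# scale_service(ecs, cluster, schedulerservice, 1)
# scale_service(ecs, cluster, workerservice, numWorkers)
# scale_service(ecs, cluster, workerservice, 0)  # scale down when finished
```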

In [None]:
client = Client('Dask-Scheduler.local-dask:8786')
client

## Open 2-m air temperature as a single dataset
This is where the real work begins. We start by defining the set of S3 objects to process, using the s3fs library and a file pattern.

In [None]:
def fix_accum_var_dims(ds, var):
    # Some variables, like precipitation, have extra time-bounds variables; we drop them here so they can be merged with other variables

    # Select variable of interest (drops dims that are not linked to current variable)
    ds = ds[[var]]

    if var in ['air_temperature_at_2_metres',
               'dew_point_temperature_at_2_metres',
               'air_pressure_at_mean_sea_level',
               'northward_wind_at_10_metres',
               'eastward_wind_at_10_metres']:
        ds = ds.rename({'time0':'valid_time_end_utc'})

    elif var in ['precipitation_amount_1hour_Accumulation',
                 'integral_wrt_time_of_surface_direct_downwelling_shortwave_flux_in_air_1hour_Accumulation']:
        ds = ds.rename({'time1':'valid_time_end_utc'})

    else:
        print("Warning: haven't seen variable {var} yet! Time renaming might not work.".format(var=var))

    return ds

@dask.delayed
def s3open(path):
    fs = s3fs.S3FileSystem(anon=True, default_fill_cache=False, 
                           config_kwargs = {'max_pool_connections': 20})
    return s3fs.S3Map(path, s3=fs)

def open_era5_range(start_year, end_year, variables):
    ''' Opens ERA5 monthly Zarr files in S3, given a start and end year (all months loaded) and a list of variables'''

    file_pattern = 'era5-pds/zarr/{year}/{month}/data/{var}.zarr/'
    years = list(np.arange(start_year, end_year+1, 1))
    months = ["01", "02", "03", "04", "05", "06", "07", "08", "09", "10", "11", "12"]

    l = []
    for var in variables:
        print(var)

        # Get files
        files_mapper = [s3open(file_pattern.format(year=year, month=month, var=var)) for year in years for month in months]

        # Look up correct time dimension by variable name
        if var in ['precipitation_amount_1hour_Accumulation',
                   'integral_wrt_time_of_surface_direct_downwelling_shortwave_flux_in_air_1hour_Accumulation']:
            concat_dim='time1'
        else:
            concat_dim='time0'

        # Lazy load
        ds = xr.open_mfdataset(files_mapper, engine='zarr',
                               concat_dim=concat_dim, combine='nested',
                               coords='minimal', compat='override', parallel=True)

        # Fix dimension names
        ds = fix_accum_var_dims(ds, var)
        l.append(ds)

    ds_out = xr.merge(l)
    return ds_out
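`open_era5_range` builds one S3 prefix per variable, year and month. Checking which stores will be opened before launching the lazy load can therefore save time. The stdlib-only sketch below reproduces the same `file_pattern` expansion; the helper name `era5_zarr_keys` is hypothetical.

```python
def era5_zarr_keys(start_year: int, end_year: int, var: str) -> list:
    """List the monthly Zarr store prefixes that open_era5_range reads for one variable."""
    pattern = 'era5-pds/zarr/{year}/{month:02d}/data/{var}.zarr/'
    return [pattern.format(year=year, month=month, var=var)
            for year in range(start_year, end_year + 1)
            for month in range(1, 13)]

keys = era5_zarr_keys(2021, 2021, 'air_temperature_at_2_metres')
print(len(keys))  # 12: one store per month
print(keys[0])    # era5-pds/zarr/2021/01/data/air_temperature_at_2_metres.zarr/
```

Suppose you have an `s3fs.S3FileSystem(anon=True)` instance named `fs`. The expression `[k for k in keys if not fs.exists(k)]` would then list any stores that are missing before `open_mfdataset` fails on them.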

Now initialise the xarray dataset

In [None]:
%%time

start_year = 2021
end_year = 2021
ds = open_era5_range(start_year, end_year, ["air_temperature_at_2_metres"])

In [None]:
print('ds size in GB {:0.2f}\n'.format(ds.nbytes / 1e9))
ds.info()

The output above shows that the data has three dimensions: lat, lon, and valid_time_end_utc (renamed from time0 by `fix_accum_var_dims`). It has a single data variable, air_temperature_at_2_metres.
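The source time dimension depends on the variable. Instantaneous fields use `time0`, while the 1-hour accumulations use `time1`. `open_era5_range` and `fix_accum_var_dims` encode this in two separate `if` chains. A single lookup table keeps the concatenation and renaming logic consistent. The sketch below takes its variable names from `fix_accum_var_dims`; `TIME_DIM` and `source_time_dim` are hypothetical names.

```python
# Source time dimension for each ERA5 variable handled by fix_accum_var_dims
TIME_DIM = {
    'air_temperature_at_2_metres': 'time0',
    'dew_point_temperature_at_2_metres': 'time0',
    'air_pressure_at_mean_sea_level': 'time0',
    'northward_wind_at_10_metres': 'time0',
    'eastward_wind_at_10_metres': 'time0',
    'precipitation_amount_1hour_Accumulation': 'time1',
    'integral_wrt_time_of_surface_direct_downwelling_shortwave_flux_in_air_1hour_Accumulation': 'time1',
}

def source_time_dim(var: str) -> str:
    """Return the time dimension to concatenate along, failing loudly for unknown variables."""
    try:
        return TIME_DIM[var]
    except KeyError:
        raise KeyError(f'No time dimension recorded for {var!r}; inspect one store and add it') from None

# Sketch of use:
#   in open_era5_range:     concat_dim = source_time_dim(var)
#   in fix_accum_var_dims:  ds = ds.rename({source_time_dim(var): 'valid_time_end_utc'})
```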

In [None]:
ds.air_temperature_at_2_metres

## Convert units from K to C
This performs a simple subtraction to convert the temperature from Kelvin to Celsius. The operation is not executed at this stage. It runs only when we access the result or make the explicit call to `persist` below.

In [None]:
ds['air_temperature_at_2_metres'] = (ds.air_temperature_at_2_metres - 273.15)
ds.air_temperature_at_2_metres.attrs['units'] = 'C'
ds.air_temperature_at_2_metres
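Re-running the conversion cell subtracts 273.15 a second time, because nothing records that the data was already converted. The sketch below guards on the `units` attribute. It assumes the source variable is labelled `'K'`, and `kelvin_to_celsius` is a hypothetical helper. It works on a lazy, dask-backed xarray DataArray just as well as on an in-memory one.

```python
def kelvin_to_celsius(da):
    """Return `da` in degrees Celsius, converting only if its units attribute is 'K'.

    Re-running is safe: an array already labelled 'C' is returned unchanged
    instead of being shifted by another 273.15.
    """
    units = da.attrs.get('units')
    if units == 'C':
        return da
    if units != 'K':
        raise ValueError(f'Expected units "K" or "C", got {units!r}')
    out = da - 273.15
    # Whether attrs survive arithmetic depends on xarray settings, so set them explicitly
    out.attrs = {**da.attrs, 'units': 'C'}
    return out

# Hypothetical usage, replacing the two assignment lines above:
# ds['air_temperature_at_2_metres'] = kelvin_to_celsius(ds.air_temperature_at_2_metres)
```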

## Read all data into dask worker memory
The following line reads the entire dataset into worker memory. This makes subsequent calculations much faster and is a useful illustration of how Dask works. Without it, calculations run without holding all the data in worker memory at once. Each calculation then has to read the data back in from S3, which takes much longer.

The subtraction calculation we queued up above will also be executed during this step.

In [None]:
ds = client.persist(ds)
progress(ds)

Sometimes data isn't evenly distributed, depending on the dataset and chunk size that we selected.  Here we rebalance the data across workers so that future tasks will make best use of cluster resources.

In [None]:
client.rebalance()

## Calculate the mean 2-m air temperature for all times

In [None]:
# calculates the mean along the time dimension
temp_mean = ds['air_temperature_at_2_metres'].mean(dim='valid_time_end_utc')

The expression above doesn't compute anything; it only builds the Dask task graph. Displaying `temp_mean` shows the lazy result. The `persist` call that follows runs the computation on the workers.

In [None]:
temp_mean

In [None]:
temp_mean = temp_mean.persist()
progress(temp_mean)

## Plot Average Surface Temperature
To plot data, we need to bring it back into the local notebook Python environment. This is done with the `compute` method. Once the data is in local memory, we can use matplotlib to display it. For more information refer to: https://distributed.dask.org/en/latest/manage-computation.html

In [None]:
temp_mean = temp_mean.compute()  # bring the result into local memory
temp_mean.plot(figsize=(30, 15))
plt.title(f'Mean 2-m Air Temperature {start_year} - {end_year}')

That's the mean of the hourly samples in the source dataset. Let's down-sample the data by taking the daily maximum and recalculating the mean from that. This is one line of code:

In [None]:
daily_max_mean = ds['air_temperature_at_2_metres'].resample(indexer={"valid_time_end_utc":'D'}).max().mean(dim='valid_time_end_utc')
daily_max_mean
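At each grid point, the one-liner does three things. It groups the hourly series into calendar days (in UTC, since the time coordinate is UTC), takes the maximum of each day, and averages those daily maxima. The stdlib sketch below performs the same reduction for a single point, using made-up illustrative values rather than ERA5 data.

```python
from collections import defaultdict
from datetime import datetime, timedelta
from statistics import mean

def mean_of_daily_max(times, values):
    """Group (time, value) pairs by calendar day, take each day's max, then average."""
    daily_max = defaultdict(lambda: float('-inf'))
    for t, v in zip(times, values):
        daily_max[t.date()] = max(daily_max[t.date()], v)
    return mean(daily_max.values())

# Two days of hourly samples (illustrative only): day 1 peaks at 23, day 2 at 47
start = datetime(2021, 1, 1)
times = [start + timedelta(hours=h) for h in range(48)]
values = [float(h) for h in range(48)]
print(mean_of_daily_max(times, values))  # 35.0
```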

We don't strictly need to call `persist` here, since the `compute` call below triggers the computation anyway. Persisting first, however, lets us follow the progress in the notebook.

In [None]:
daily_max_mean = daily_max_mean.persist()
progress(daily_max_mean)

In [None]:
daily_max_mean = daily_max_mean.compute()  # bring the result into local memory
daily_max_mean.plot(figsize=(30, 15))
plt.title(f'Average daily maximum temperature {start_year} - {end_year}')

## Repeat for standard deviation
The data is in memory, so let's do another calculation!

In [None]:
temp_std = ds['air_temperature_at_2_metres'].std(dim='valid_time_end_utc')
temp_std
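xarray's `.std()` defaults to `ddof=0`, the population standard deviation, which divides by N. pandas' `DataFrame.std()` defaults to `ddof=1`, which divides by N - 1. With a year of hourly samples the difference is negligible, but it matters for short series or when comparing against pandas results. Pass `ddof=1` to xarray's `std` if you want the sample estimate. The stdlib `statistics` module exposes both, shown here on illustrative values.

```python
from statistics import pstdev, stdev

samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]  # illustrative values
print(pstdev(samples))  # 2.0 (ddof=0, like xarray's default std)
print(stdev(samples))   # ddof=1, like pandas' default std
```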

In [None]:
temp_std = temp_std.persist()
progress(temp_std)

In [None]:
temp_std = temp_std.compute()  # bring the result into local memory
temp_std.plot(figsize=(30, 15), cmap='inferno')
plt.title(f'Standard Deviation 2-m Air Temperature {start_year} - {end_year}')

## Plot temperature time series for points
This example builds a table of temperature time series for the specific locations defined in the list below.

In [None]:
# location coordinates
locs = [
    {'name': 'Wellington', 'lon': 174.78, 'lat': -41.28},
    {'name': 'Honolulu', 'lon': -157.84, 'lat': 21.29},
    {'name': 'Seattle', 'lon': -122.33, 'lat': 47.61},
    {'name': 'Melbourne', 'lon': 144.95, 'lat': -37.84}
]

# convert westward longitudes to degrees east
for l in locs:
    if l['lon'] < 0:
        l['lon'] = 360 + l['lon']
locs
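The `if` in the loop above converts negative longitudes to degrees east. Python's modulo operator does the same in one expression and leaves values already between 0 and 360 unchanged. The sketch below assumes that the dataset's longitudes run from 0 to 360 degrees east, as the cell above implies. `to_degrees_east` is a hypothetical helper name.

```python
def to_degrees_east(lon: float) -> float:
    """Map any longitude (e.g. -180..180) onto the 0..360 degrees-east convention."""
    return lon % 360.0  # Python's % is non-negative for a positive divisor

print(to_degrees_east(-122.33))  # approximately 237.67
print(to_degrees_east(144.95))   # 144.95
```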

In [None]:
ds_locs = xr.Dataset()
air_temp_ds = ds

# iterate through the locations and build a dataset
# containing the temperature time series for each location
for l in locs:
    name = l['name']
    lon = l['lon']
    lat = l['lat']
    var_name = name

    ds2 = air_temp_ds.sel(lon=lon, lat=lat, method='nearest')

    lon_attr = '%s_lon' % name
    lat_attr = '%s_lat' % name

    ds2.attrs[lon_attr] = ds2.lon.values.tolist()
    ds2.attrs[lat_attr] = ds2.lat.values.tolist()
    ds2 = ds2.rename({'air_temperature_at_2_metres' : var_name}).drop_vars(['lat', 'lon'])

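    # keep every location's lat/lon attrs; the default combine_attrs='override' keeps only the first object's attrs
    ds_locs = xr.merge([ds_locs, ds2], combine_attrs='drop_conflicts')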

ds_locs.data_vars
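`sel(..., method='nearest')` snaps each location to the closest grid point, which is why the cell above records the actual coordinates in the attributes. The selected point can be predicted by rounding, assuming ERA5's regular 0.25 degree grid with latitudes at multiples of 0.25 and longitudes from 0 to 359.75. The sketch below does this with a hypothetical helper, `nearest_grid_point`. Unlike `sel`, it wraps longitude at 360, so the two can differ for points within 0.125 degrees of the 360 meridian.

```python
def nearest_grid_point(lat: float, lon: float, step: float = 0.25):
    """Snap (lat, lon in degrees east) to the nearest point of a regular `step`-degree grid.

    Assumes a global grid with longitudes 0..360-step; longitude wraps at 360.
    """
    snapped_lat = round(lat / step) * step
    snapped_lon = (round(lon / step) * step) % 360.0
    return snapped_lat, snapped_lon

print(nearest_grid_point(47.61, 237.67))  # Seattle: (47.5, 237.75)
```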

In [None]:
ds_locs

In [None]:
ds_locs = client.persist(ds_locs)
progress(ds_locs)

### Convert to dataframe
Convert the xarray Dataset into a pandas DataFrame (table), with one time-series column per location.

In [None]:
df_f = ds_locs.to_dataframe()
df_f

In [None]:
df_f.describe()

In [None]:
df_f.info()

### Plot temperature timeseries
We'll first re-sample the data from hourly to daily maximums.

In [None]:
rs = df_f.resample('D').max()
rs.info()

In [None]:
matplotlib.rcParams['lines.linewidth'] = 1.0
matplotlib.rcParams['lines.linestyle'] = 'solid'
ax = rs.plot(figsize=(30, 15), title=f"ERA5 Daily Maximums {start_year} - {end_year}", grid=1)
ax.set(xlabel='Date', ylabel='2-m Air Temperature (deg C)')
plt.show()

## Dask Memory management

Executing code in these cells can help you recover memory in the worker processes if things are getting tight.

First, clear up all known datasets.

In [None]:
client.cancel(ds)
client.cancel(temp_mean)
client.cancel(daily_max_mean)
client.cancel(temp_std)
client.cancel(ds_locs)

This snippet reduces the workers' memory footprint, which can be useful when debugging memory use. It should release most of the "unmanaged" memory reported in the Dask dashboard.

In [None]:
import ctypes

def trim_memory() -> int:
    libc = ctypes.CDLL("libc.so.6")
    return libc.malloc_trim(0)

client.run(trim_memory)
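`malloc_trim` exists only in glibc. On a worker image built on a different C library, the function above raises instead of trimming. The sketch below is a slightly more defensive variant. It trims where glibc is present and degrades to a no-op elsewhere, for example on macOS or musl-based images.

```python
import ctypes

def trim_memory() -> int:
    """Ask glibc to return freed heap memory to the OS; a no-op where glibc isn't available.

    Returns malloc_trim's result (1 if memory was released, 0 otherwise), or 0 without glibc.
    """
    try:
        libc = ctypes.CDLL("libc.so.6")
        return libc.malloc_trim(0)
    except (OSError, AttributeError):
        # libc.so.6 could not be loaded, or it has no malloc_trim symbol
        return 0

# client.run(trim_memory) returns a dict mapping each worker address to this return value
```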

If memory still isn't coming down, this is a last resort. It will terminate all workers and restart them fresh.

In [None]:
client.restart()

## Cluster scale down

When we are temporarily done with the cluster, we can scale the workers down to zero to save on costs.

In [None]:
numWorkers=0
ecs.update_service(cluster=cluster, service=workerservice, desiredCount=numWorkers)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[workerservice])

Optional - stop the scheduler

In [None]:
client.close()
ecs.update_service(cluster=cluster, service=schedulerservice, desiredCount=0)
ecs.get_waiter('services_stable').wait(cluster=cluster, services=[schedulerservice])