# WRF model wind speed calculations and visualization

- [High Resolution WRF Simulations of the Current and Future Climate of North America](https://rda.ucar.edu/datasets/ds612.0/index.html#sfol-wl-/data/ds612.0?g=33200406) (DOI: 10.5065/D6V40SXP)
- For example, Research Data Archive (RDA) data set https://rda.ucar.edu/data/ds612.0/PGW3D/2004/wrf3d_d01_PGW_U_20040602.nc (careful to click, big download!)
- Calculations can take a long time because of the volume of data. **How can we speed up calculations in this Jupyter Notebook?**

# Summary / Outline

- What is Dask?
- Set up our cluster
- Run "positive control experiment" to ensure cluster is running
- Run notebooks analyzing and visualizing wind speed data from UCAR RDA
- Future work

# Dask

- Dask is a Python library for parallel and distributed computing
- "Lazy" loading of larger-than-memory datasets
- "HPC in the client or HPC in the notebook"
- Tries to hide the messy details of parallel computation (though you may still have to think about chunking, etc.)
- Beyond Jupyter, we'll use the following technologies for our Dask cluster on the NSF Jetstream Cloud:

![image.png](attachment:1fa0a5a5-10e7-4f97-8dea-a631c9302f77.png) ![image.png](attachment:1c2c298c-07df-4cef-a230-5a35b3d65cc5.png) ![image.png](attachment:4813949f-8d49-468a-a90c-74d5d4dcc57a.png)

![image.png](attachment:66dd003a-7df1-4e66-b910-5de0aae83d2e.png)

# Set up 

- Xarray / Dask / Kubernetes allow for fast parallelization
- [Must be very careful to coordinate worker / client / scheduler conda environments](https://github.com/julienchastang/jupyter-classroom/tree/master/rut-spring-2022)
- Setup for Dask cluster:
  - 4 workers
  - 4 cores per worker
  - 8 GB per worker
  - 16 “task streams”
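
As a rough sanity check of the setup above, the cluster totals follow from simple arithmetic. This is only a sketch: the variable names are illustrative, and it assumes one task stream per core, which is what the figure of 16 implies.

```python
# Cluster sizing taken from the list above (variable names are illustrative).
workers = 4
cores_per_worker = 4
memory_per_worker_gb = 8

# Assumption: the dashboard shows one task stream row per core.
task_streams = workers * cores_per_worker
total_memory_gb = workers * memory_per_worker_gb

print(f"{task_streams} task streams, {total_memory_gb} GB total memory")
```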

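Our imports. `ProgressBar` is meant to display the progress of our Dask calculations. Note that `dask.diagnostics.ProgressBar` targets Dask's single-machine schedulers, so with a distributed cluster the Dask dashboard is the more reliable place to watch progress.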

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import dask_gateway, dask, distributed, time
from dask.diagnostics import ProgressBar

Record a baseline time so we can determine how long it takes for this notebook to run.

In [None]:
time0 = time.time()

Set lengthy communication timeouts. These are apparently important.

In [None]:
dask.config.set({"distributed.comm.timeouts.tcp": "300s"})
dask.config.set({"distributed.comm.timeouts.connect": "300s"})

Create the gateway object that will give us access to our cluster.

In [None]:
from dask_gateway import Gateway
gateway = Gateway(
    address="http://traefik-dask-gateway/services/dask-gateway/",
    public_address="https://dtest-1.ees220002.projects.jetstream-cloud.org/services/dask-gateway/",
    auth="jupyterhub")
gateway

Let's get the cluster options, set the cores and memory per worker, and examine the result:

In [None]:
options = gateway.cluster_options()
options.worker_cores = 2
options.worker_memory = 8
options

Launching cluster. This might take a few minutes. Make a note of the Dask Dashboard URL. Open it in a separate tab. The URL should look something like:

```
https://dtest-1.ees220002.projects.jetstream-cloud.org/services/dask-gateway/clusters/jhub.231b3db17f834c2e92130c5dde31c346/status
```

Make sure the first part of the URL matches what you see for the URL of this page. **IMPORTANT**: Only call this cell once or else you will have multiple clusters on your hands that you will have to sort through.

In [None]:
cluster = gateway.new_cluster(options)
cluster.scale(4)
cluster

You only have **one** cluster, right?

In [None]:
clusters = gateway.list_clusters()
clusters

Grab the first cluster. (There should not be more than one anyway.)

In [None]:
cluster = gateway.connect(clusters[0].name)

Don't forget to call `get_client()`, or else the cluster will apparently die.

In [None]:
client = cluster.get_client()
client

This is apparently important if you want to see the Dask Dashboard update.

In [None]:
client = gateway.connect(cluster.name).get_client()

## Control "Experiment"

This is a "positive control experiment" to ensure the cluster is working with this "embarrassingly parallel" calculation. Make sure the dashboard "task stream" is also working.

In [None]:
import dask.array as da
a = da.random.normal(size=(40000, 40000), chunks=(500, 500))
a.mean().compute()
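
To get a feel for the size of this control calculation, the arithmetic below estimates the array and chunk sizes. It is a back-of-the-envelope sketch that assumes 8-byte `float64` elements (the default for `da.random.normal`); the variable names are illustrative.

```python
import math

shape = (40000, 40000)
chunks = (500, 500)
itemsize = 8  # bytes per element, assuming float64

# Total size of the full array and of a single chunk, in bytes.
total_bytes = shape[0] * shape[1] * itemsize
chunk_bytes = chunks[0] * chunks[1] * itemsize

# Number of chunks, which is also the number of independent pieces of work.
n_chunks = math.ceil(shape[0] / chunks[0]) * math.ceil(shape[1] / chunks[1])

print(f"array: {total_bytes / 1e9:.1f} GB in {n_chunks} chunks "
      f"of {chunk_bytes / 1e6:.1f} MB each")
```

Under these assumptions the full array is larger than the memory of any single 8 GB worker, while each chunk is small enough to be processed independently.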

-------------

## Now run notebook in earnest

Now you are ready to run your notebook in earnest. Let's first define an `unstagger` function to unstagger the WRF grid.

In [None]:
def unstagger(ds, var, coord, new_coord):
    var1 = ds[var].isel({coord: slice(None, -1)})
    var2 = ds[var].isel({coord: slice(1, None)})
    return ((var1 + var2) / 2).rename({coord: new_coord})
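
To illustrate what `unstagger` does, here is the same averaging of neighboring points written in plain NumPy on a tiny array. The values are made up for illustration; only the slicing pattern mirrors the function above.

```python
import numpy as np

# Hypothetical staggered values on 5 cell faces along one row.
u_stag = np.array([[0.0, 2.0, 4.0, 6.0, 8.0]])

# Average each pair of neighboring faces to estimate the value at the
# cell center. The staggered dimension shrinks by one point.
u_center = 0.5 * (u_stag[:, :-1] + u_stag[:, 1:])

print(u_stag.shape, "->", u_center.shape)
print(u_center)
```

The result has one fewer point along the staggered dimension, which is why the function renames that dimension to its unstaggered counterpart.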

Let's start grabbing data. Open the U dataset from the RDA THREDDS catalog.

In [None]:
ds = xr.open_dataset('https://thredds.rda.ucar.edu/thredds/dodsC/files/g/ds612.0/PGW3D/2004/wrf3d_d01_PGW_U_20040601.nc',
                     chunks={'bottom_top': 10})

Plot unstaggered surface **U** winds. In theory, you should be able to examine the progress in the Dask Dashboard, below the task stream.

In [None]:
with ProgressBar():
    ds.U.sel(Time='2004-06-01T00:00').isel(bottom_top=0).plot()

Unstagger U grid

In [None]:
with ProgressBar():
    ds['U_unstaggered'] = unstagger(ds, 'U', 'west_east_stag', 'west_east')

Open V dataset

In [None]:
ds2 = xr.open_dataset('https://thredds.rda.ucar.edu/thredds/dodsC/files/g/ds612.0/PGW3D/2004/wrf3d_d01_PGW_V_20040601.nc',
                      chunks={'bottom_top': 10})
ds2

In [None]:
with ProgressBar():
    ds2['V_unstaggered'] = unstagger(ds2, 'V', 'south_north_stag', 'south_north')

Merge U and V winds

In [None]:
ds = xr.merge((ds, ds2))
ds

Calculate wind speed.

In [None]:
ds['speed'] = np.sqrt(ds.U_unstaggered**2 + ds.V_unstaggered**2)
ds
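
The wind speed is the magnitude of the horizontal wind vector, `sqrt(U**2 + V**2)`. The small NumPy sketch below uses made-up component values to show the calculation, along with `np.hypot`, which computes the same quantity.

```python
import numpy as np

# Made-up wind components in m/s, for illustration only.
u = np.array([3.0, -6.0, 0.0])
v = np.array([4.0, 8.0, -2.5])

# Magnitude of the horizontal wind vector, as in the cell above.
speed = np.sqrt(u**2 + v**2)

# np.hypot computes the same magnitude in a single call.
speed_hypot = np.hypot(u, v)

print(speed)
```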

Plot some wind speeds at a certain height and time.

In [None]:
ds.speed.sel(Time='2004-06-01T18:00').isel(bottom_top=10).plot()

Grab U and V data for a number of time steps from the THREDDS RDA.

In [None]:
prefix = 'https://thredds.rda.ucar.edu/thredds/dodsC/files/g/ds612.0/PGW3D/2004/wrf3d_d01_PGW_'
list_of_files = []
for var in ('U', 'V'):
#    for day in range(1,15):
    for day in range(1,4):
        filename = prefix + var + f'_200406{day:02d}.nc'
        list_of_files.append(filename)
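
For reference, the loop above can also be written as a list comprehension. The sketch below builds the same six URLs and prints them so the naming pattern is easy to check; the name `urls` is illustrative.

```python
prefix = ('https://thredds.rda.ucar.edu/thredds/dodsC/files/g/ds612.0/'
          'PGW3D/2004/wrf3d_d01_PGW_')

# One URL per variable and day: U and V for 1-3 June 2004.
urls = [f'{prefix}{var}_200406{day:02d}.nc'
        for var in ('U', 'V')
        for day in range(1, 4)]

for url in urls:
    print(url)
```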

Create a multifile dataset.

In [None]:
with ProgressBar():
    ds = xr.open_mfdataset(list_of_files, parallel=True, chunks={'bottom_top': 10})

Calculate wind speed for entire dataset.

In [None]:
with ProgressBar():
    ds['U_unstaggered'] = unstagger(ds, 'U', 'west_east_stag', 'west_east')
    ds['V_unstaggered'] = unstagger(ds, 'V', 'south_north_stag', 'south_north')
    ds['speed'] = np.sqrt(ds.U_unstaggered**2 + ds.V_unstaggered**2)

Again, plot some wind speeds at a certain height and time.

In [None]:
with ProgressBar():
    ds.speed.sel(Time='2004-06-03T18:00').isel(bottom_top=10).plot()

Mean winds over time.

In [None]:
mean_speed_lev5 = ds.speed.isel(bottom_top=5).mean(dim='Time')

In [None]:
with ProgressBar():
    mean_speed_lev5.plot()

Let's calculate the jet stream for our time range, as the wind speed averaged zonally and over time.

In [None]:
zonal_avg_mean_wind = ds.speed.mean(dim='west_east').mean(dim='Time')
zonal_avg_mean_wind

Let's plot the jet stream.

In [None]:
fig, ax = plt.subplots(figsize=(15,12))
zonal_avg_mean_wind.plot.contourf(ax=ax, levels=10)
ax.set_title('Mean Wind Speed (Zonally Averaged)')
plt.show()
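
The zonal and time averaging can be illustrated with plain NumPy. This sketch uses random data and assumes the dimension order `(Time, bottom_top, south_north, west_east)`; it shows that averaging over `west_east` and `Time` leaves a height-by-latitude field, which is what the contour plot displays.

```python
import numpy as np

# Hypothetical speed array with dims (Time, bottom_top, south_north, west_east).
rng = np.random.default_rng(0)
speed = rng.uniform(0.0, 30.0, size=(3, 4, 5, 6))

# Average over west_east (last axis), then over Time (first axis).
zonal_time_mean = speed.mean(axis=-1).mean(axis=0)

# What remains is a (bottom_top, south_north) cross section.
print(zonal_time_mean.shape)
```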

How long did it take for the notebook to run?

In [None]:
# Use a separate name so the `time` module is not overwritten.
elapsed = time.time() - time0
elapsed
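
As a side note, here is a minimal, self-contained timing sketch. It uses `time.perf_counter()`, a monotonic clock intended for measuring intervals, and stores the result under a name other than `time` so the module stays usable. The workload is a stand-in, not part of the notebook.

```python
import time

start = time.perf_counter()

# Stand-in workload; in the notebook this would be the cells being timed.
total = sum(i * i for i in range(100_000))

elapsed = time.perf_counter() - start
print(f"Elapsed: {elapsed:.3f} s")
```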

Don't forget to shut down your cluster.

In [None]:
for cluster in gateway.list_clusters():
    try:
        c = gateway.connect(cluster.name)
        c.shutdown()
    except Exception as e:
        print(f"Failed to shut down cluster {cluster.name}: {e}")

Stop cell execution. This tries to shut down the kernel programmatically, but the snippet does not seem to work. If the kernel is left running, you will see a number of error messages after a while.

In [None]:
class StopExecution(Exception):
    def _render_traceback_(self):
        pass

raise StopExecution

## Future work

- Experiment with cores, chunking strategies, etc.
- Experiment with different data stores. The current notebook will probably become I/O bound quickly.