# **Notebook Description**

7/22/24

Authors:<br>
Jacob Radford<br>
Jebb Stewart
<br><br>
This notebook runs a coarse (1°) version of GraphCast on free Google Colaboratory resources and plots the output. You will need to log in with a Google account to access these resources. Note that lines beginning with an exclamation point (`!`) are shell commands rather than native Python code.
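
Outside a notebook, the same kind of shell call can be made from Python with the standard-library `subprocess` module. The sketch below is illustrative only; `run_shell` is a hypothetical helper, not part of any package used in this notebook.

```python
import subprocess
import sys


def run_shell(args: list[str]) -> str:
    """Run a command (like a notebook `!` line) and return its stdout.

    Hypothetical helper: raises subprocess.CalledProcessError on a non-zero exit code.
    """
    result = subprocess.run(args, capture_output=True, text=True, check=True)
    return result.stdout


# Roughly equivalent to `!python --version` in a notebook cell.
print(run_shell([sys.executable, "--version"]))
```

Passing the command as a list (rather than one string) avoids shell quoting issues.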

# **Step 0: Connect to a runtime**

At the top right of this page (next to "Connect"), click the dropdown arrow, then "Change runtime type." Select "Python 3" and "T4 GPU".


# **Step 1: Install all of the required packages**

*   ai-models-gfs: Extension of ai-models, a package for easily running AI weather prediction (AIWP) models
*   ai-models-graphcast-gfs: Extension of ai-models-graphcast, the GraphCast plug-in for ai-models
*   basemap: Package for plotting data on maps
*   git+https://github.com/deepmind/graphcast.git: The GraphCast repository
*   jax and jaxlib: Machine learning framework for running GraphCast with a GPU

In [2]:
#@title Install packages

%%capture
!pip install ai-models-gfs==0.0.8 ai-models-graphcast-gfs==0.0.7
!pip install basemap basemap-data-hires
!pip install git+https://github.com/deepmind/graphcast.git
!pip install jax==0.4.23 jaxlib==0.4.23+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
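
Because `%%capture` hides the install output, a quick sanity check can confirm the pinned versions and that JAX actually sees the T4 GPU. This is a minimal sketch using `importlib.metadata` and `jax.default_backend()`; the helper names are hypothetical. If the backend reports `cpu`, the runtime type from Step 0 was probably not applied.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def jax_backend() -> Optional[str]:
    """Return JAX's default backend (e.g. 'gpu' or 'cpu'), or None if JAX is absent."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


for pkg in ["ai-models-gfs", "ai-models-graphcast-gfs", "jax", "jaxlib"]:
    print(pkg, installed_version(pkg))
print("JAX backend:", jax_backend())
```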

# **Step 2: Run GraphCast (for Hurricane Beryl)**

Here is where we run GraphCast. Again, note that this is just a command-line call with the following arguments:


*   --input (cds, gfs, gdas): Input data source
*   --download-assets: Needed the first time, to download the model weights
*   --date: Date of initialization (YYYYMMDD)
*   --time: Time of initialization (HHMM, in 6-hour increments)
*   --lead-time: How far to run the model forward, in hours
*   --onedeg: If included, run the 1° version of GraphCast; otherwise run the 0.25° version
*   --nc-or-grib: Output format: grib (g), netCDF4 (n), or both (ng)
*   --path: Output file path
*   Model name, given as the final positional argument (graphcast, fourcastnetv2-small, panguweather)

Assuming this runs without error, a netCDF and/or GRIB file is written to the location given by `--path` (for the run below, `20240702_12.nc`).
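
To rerun for a different date or storm without retyping flags, the command can be assembled in Python. The helper `graphcast_command` below is hypothetical and uses only the flags listed above; with the Beryl settings it reproduces the command in the next cell, and the resulting list could be passed to `subprocess.run`.

```python
import shlex


def graphcast_command(
    date: str,
    time: str,
    lead_time: int,
    path: str,
    model: str = "graphcast",
    input_source: str = "gfs",
    onedeg: bool = True,
    nc_or_grib: str = "n",
    download_assets: bool = True,
) -> list[str]:
    """Assemble an ai-models-gfs command line from the documented arguments."""
    if nc_or_grib not in ("g", "n", "ng"):
        raise ValueError("nc_or_grib must be 'g', 'n', or 'ng'")
    cmd = ["ai-models-gfs", "--input", input_source]
    if download_assets:
        cmd.append("--download-assets")
    cmd += ["--date", date, "--time", time, "--lead-time", str(lead_time)]
    if onedeg:
        cmd.append("--onedeg")
    cmd += ["--nc-or-grib", nc_or_grib, "--path", path, model]  # model name goes last
    return cmd


print(shlex.join(graphcast_command("20240702", "1200", 240, "20240702_12")))
```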


In [94]:
!ai-models-gfs --input gfs --download-assets --date 20240702 --time 1200 --lead-time 240 --onedeg --nc-or-grib n --path 20240702_12 graphcast

2024-07-18 04:11:14,452 INFO NumExpr defaulting to 2 threads.
2024-07-18 04:11:17,168 INFO Writing results to 20240702_12
2024-07-18 04:11:17,169 INFO Loading surface fields from GFS
2024-07-18 04:11:17,281 INFO Downloading https://noaa-gfs-bdp-pds.s3.amazonaws.com/gfs.20240702/06/atmos/gfs.t06z.pgrb2.0p25.f000
2024-07-18 04:11:32,035 INFO Downloading https://noaa-gfs-bdp-pds.s3.amazonaws.com/gfs.20240702/12/atmos/gfs.t12z.pgrb2.0p25.f000
2024-07-18 04:11:46,723 INFO Loading pressure fields from GFS
2024-07-18 04:11:56,406 INFO Model description: 
Low resolution version of the GraphCast model (1deg, smaller mesh), with 37
pressure levels. This model is trained on ERA5 data from 1979 to 2015, and can
be causally evaluated on 2016 and later years. This model takes as inputs
`total_precipitation_6hr`. This model has much lower memory requirements.

2024-07-18 04:11:56,407 INFO Model license: 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.
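
The log shows that a 12 UTC initialization downloads two GFS analyses, at 06 UTC and 12 UTC, consistent with GraphCast taking two input states 6 hours apart. The sketch below rebuilds those URLs from the pattern visible in the log; `gfs_input_urls` is a hypothetical helper, and the URL layout is inferred from these two log lines only.

```python
from datetime import datetime, timedelta

# URL layout copied from the log above (0.25-degree GFS, forecast hour 000).
BASE = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"


def gfs_input_urls(init: datetime) -> list[str]:
    """Return GFS analysis URLs for init - 6 h and init (GraphCast's two inputs)."""
    urls = []
    for t in (init - timedelta(hours=6), init):
        ymd, hh = t.strftime("%Y%m%d"), t.strftime("%H")
        urls.append(f"{BASE}/gfs.{ymd}/{hh}/atmos/gfs.t{hh}z.pgrb2.0p25.f000")
    return urls


for url in gfs_input_urls(datetime(2024, 7, 2, 12)):
    print(url)
```

For a 00 UTC initialization, the first file rolls back to 18 UTC of the previous day, which the `timedelta` arithmetic handles.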

# **Step 3: Plot the data**

Now that the model has finished, we can plot the output using plotting functions adapted from those provided by the Google DeepMind GraphCast team.

In [95]:
#@title Load the netCDF file we just produced into an xarray Dataset
import xarray
example_batch = xarray.open_dataset('20240702_12.nc')
example_batch
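
GraphCast advances in 6-hour steps, so a 240-hour run from 12 UTC on 2 July should yield 40 forecast times. The sketch below lists the expected valid times as a cross-check against `example_batch.time`; it is an assumption about the output, not a statement of how `ai-models-gfs` labels its time axis (for example, whether the initial time is also stored). `valid_times` is a hypothetical helper.

```python
from datetime import datetime, timedelta


def valid_times(init: datetime, lead_time_h: int, step_h: int = 6) -> list[datetime]:
    """Forecast valid times from init + step_h up to init + lead_time_h."""
    if lead_time_h % step_h:
        raise ValueError("lead time must be a multiple of the step")
    return [init + timedelta(hours=h) for h in range(step_h, lead_time_h + 1, step_h)]


times = valid_times(datetime(2024, 7, 2, 12), 240)
print(len(times), times[0], times[-1])  # 40 2024-07-02 18:00:00 2024-07-12 12:00:00
```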

In [96]:
# @title Define plotting functions courtesy of Google DeepMind team
from typing import Optional
import matplotlib
import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
import math
import datetime
from IPython.display import HTML
from matplotlib import animation
from mpl_toolkits.basemap import Basemap, shiftgrid

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.DataArray:
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        data = data.sel(level=level)
    return data

def scale(
    data: xarray.DataArray,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.DataArray, matplotlib.colors.Normalize, str]:
    vmin = np.nanpercentile(data, (0.1 if robust else 0))
    vmax = np.nanpercentile(data, (99.9 if robust else 100))
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return (data, matplotlib.colors.Normalize(vmin, vmax),
            ("RdBu_r" if center is not None else "viridis"))

def convert_longitudes(lon):
    lon = np.asarray(lon)
    lon = ((lon + 180) % 360) - 180
    return lon

def plot_data(
    data: dict[str, tuple[xarray.DataArray, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> HTML:

    first_data = next(iter(data.values()))[0]
    max_steps = first_data.sizes.get("time", 1)
    assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

    cols = min(cols, len(data))
    rows = math.ceil(len(data) / cols)
    figure = plt.figure(figsize=(plot_size * 2 * cols,
                                 plot_size * rows))
    figure.suptitle(fig_title, fontsize=16)
    figure.subplots_adjust(wspace=0, hspace=0)
    figure.tight_layout()

    images = []
    for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
        ax = figure.add_subplot(rows, cols, i+1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)

        # Convert longitudes
        lon = plot_data.coords['longitude'].values
        lat = plot_data.coords['latitude'].values
        plot_data_shifted, lon_shifted = shiftgrid(180, plot_data.values, lon, start=False)

        # Create basemap
        m = Basemap(projection='cyl', resolution='c', ax=ax)
        m.drawcoastlines()
        m.drawcountries()

        lon_shifted, lat_shifted = np.meshgrid(lon_shifted, lat)
        x, y = m(lon_shifted, lat_shifted)

        im = m.pcolormesh(x, y, plot_data_shifted[0], norm=norm, cmap=cmap)
        plt.colorbar(
            mappable=im,
            ax=ax,
            orientation="vertical",
            pad=0.02,
            aspect=16,
            shrink=0.75,
            cmap=cmap,
            extend=("both" if robust else "neither"))
        images.append(im)

    def update(frame):
        if "time" in first_data.dims:
            # Format the frame's valid time (assumes a datetime64 time coordinate)
            td = np.datetime_as_string(first_data["time"].values[frame], unit='s').replace('T', ' ')

            figure.suptitle(f"{fig_title}, {td}", fontsize=16)
        else:
            figure.suptitle(fig_title, fontsize=16)
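        for im, (plot_data, _, _) in zip(images, data.values()):
            # Shift this frame to -180..180 longitude, matching the initial pcolormesh
            frame_data = plot_data.isel(time=frame, missing_dims="ignore").values
            shifted, _ = shiftgrid(180, frame_data, plot_data.coords['longitude'].values, start=False)
            im.set_array(shifted)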

    ani = animation.FuncAnimation(
        fig=figure, func=update, frames=max_steps, interval=250)
    plt.close(figure.number)
    return HTML(ani.to_jshtml())
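
`plot_data` uses Basemap's `shiftgrid` to move the 0 to 360° longitude grid to -180 to 180° before drawing. The same reordering can be done with plain NumPy, which may help if Basemap is unavailable. This is a swapped-in alternative, not what the notebook uses; `shift_to_pm180` is a hypothetical name, and it assumes longitude is the last array axis.

```python
import numpy as np


def shift_to_pm180(data: np.ndarray, lon: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Reorder a [0, 360) longitude grid to [-180, 180) along the last axis."""
    lon = np.asarray(lon)
    lon_pm = ((lon + 180) % 360) - 180      # same mapping as convert_longitudes
    order = np.argsort(lon_pm)              # indices that sort longitudes -180 .. 179
    return np.take(data, order, axis=-1), lon_pm[order]


lon = np.arange(0, 360, 1.0)                   # 1-degree longitudes
field = np.broadcast_to(lon, (2, 181, 360))    # (time, lat, lon); value = original longitude
shifted, lon_shifted = shift_to_pm180(field, lon)
print(lon_shifted[0], lon_shifted[-1], shifted[0, 0, 0])  # -180.0 179.0 180.0
```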


In [97]:
# @title Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=example_batch.data_vars.keys(),
    value="t",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=example_batch.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=example_batch.sizes["time"], value=example_batch.sizes["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])


In [99]:
#@title Make the plot
plot_size = 6

# These variables need a diverging colormap centered on zero
if plot_example_variable.value in ['u10', 'v10', 'u', 'v', 'w']:
    center = 0
else:
    center = None
data = {
    " ": scale(
        select(example_batch, plot_example_variable.value,
               plot_example_level.value, plot_example_max_steps.value),
        robust=plot_example_robust.value, center=center),
}
fig_title = plot_example_variable.value
if "level" in example_batch[plot_example_variable.value].coords:
    fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value)

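
The `center` argument of `scale` widens the color limits symmetrically so that the center value sits at the midpoint of the diverging `RdBu_r` colormap; centering wind components at 0 puts calm winds at the white midpoint. A quick NumPy check of that arithmetic, using made-up wind values (hypothetical data, not model output); `symmetric_limits` is a hypothetical helper that mirrors the `vmin`/`vmax` logic in `scale`.

```python
import numpy as np


def symmetric_limits(values, center=0.0, robust=False):
    """Mirror the vmin/vmax logic of `scale` for a diverging colormap."""
    vmin = np.nanpercentile(values, 0.1 if robust else 0)
    vmax = np.nanpercentile(values, 99.9 if robust else 100)
    diff = max(vmax - center, center - vmin)   # farthest distance from the center
    return float(center - diff), float(center + diff)


u = np.array([-12.0, -3.0, 0.5, 4.0, 20.0])    # hypothetical u-wind samples (m/s)
print(symmetric_limits(u))                     # (-20.0, 20.0)
```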