# **Notebook Description**

7/22/24

Authors:<br>
Jacob Radford<br>
Jebb Stewart
<br><br>
This notebook runs a coarse (1°) version of GraphCast on free Google Colaboratory resources and plots the output. You will need to log in with a Google account to access these resources. Note that lines beginning with an exclamation point (`!`) are shell commands run on the command line rather than native Python code.

# **Step 0: Connect to a runtime**

At the top right of this page (next to "Connect"), click the dropdown arrow, then "Change runtime type." Select "Python 3" and "TPU". Depending on demand, the TPU runtime may not be available at any given time (unless you purchase Colab compute units).


# **Step 1: Install all of the required packages**

*   ai-models-gfs: Extension of ai-models, a package for easily running AI weather prediction (AIWP) models
*   ai-models-graphcast-gfs: Extension of ai-models-graphcast, the GraphCast plug-in for ai-models
*   basemap: Package for plotting data on maps
*   git+https://github.com/deepmind/graphcast.git: The GraphCast repository
*   jax and jaxlib: Machine learning framework used to run GraphCast; the cell below installs the TPU build (`jax[tpu]`).



In [1]:
#@title Install packages
!pip install ai-models-gfs==0.0.10 ai-models-graphcast-gfs==0.0.12
!pip install basemap basemap-data-hires
!pip install git+https://github.com/deepmind/graphcast.git
!pip install --upgrade "jax[tpu]<0.4.24" -f  https://storage.googleapis.com/jax-releases/libtpu_releases.html

Collecting git+https://github.com/deepmind/graphcast.git
  Cloning https://github.com/deepmind/graphcast.git to /tmp/pip-req-build-o0myeazu
  Running command git clone --filter=blob:none --quiet https://github.com/deepmind/graphcast.git /tmp/pip-req-build-o0myeazu
  Resolved https://github.com/deepmind/graphcast.git to commit 97d1ad50b0b7af4aaed7790167dffa769bae1f2c
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting cartopy (from graphcast==0.1.1)
  Using cached Cartopy-0.24.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Collecting colabtools (from graphcast==0.1.1)
  Using cached colabtools-0.0.1-py3-none-any.whl.metadata (511 bytes)
Collecting jraph (from graphcast==0.1.1)
  Using cached jraph-0.0.6.dev0-py3-none-any.whl.metadata (9.7 kB)
Collecting rtree (from graphcast==0.1.1)
  Using cached Rtree-1.3.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.1 kB)
Collecting trimesh (from graphcast==0.1.1)
  Using cached trim
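The pip output above is long and easy to skim past, so it can help to confirm the installs actually succeeded before running the model. The sketch below uses only the standard library's `importlib.metadata`. The distribution names are taken from the pip commands in the install cell (`graphcast` matches the `graphcast==0.1.1` seen in the log). The helper name `installed_versions` is hypothetical and not part of any of these packages.

```python
from importlib import metadata

# Distribution names taken from the pip commands in the install cell.
PACKAGES = ["ai-models-gfs", "ai-models-graphcast-gfs", "graphcast", "basemap", "jax", "jaxlib"]

def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

for name, version in installed_versions(PACKAGES).items():
    print(f"{name:25s} {version or 'NOT INSTALLED'}")
```

Any line reading `NOT INSTALLED` points to the pip command that needs to be rerun.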

# **Step 2: Run GraphCast for Hurricane Beryl**

This is where we run GraphCast. As noted above, this is a command-line call (note the leading `!`) with the following arguments:


*   --input (cds, gfs, gdas): Input data source
*   --download-assets: Needed the first time, to download the model weights and related assets
*   --date: Date of initialization (YYYYMMDD)
*   --time: Time of initialization (HHMM, in 6-hour increments: 0000, 0600, 1200, 1800)
*   --lead-time: How far to run the model forward, in hours
*   --onedeg: If included, run the 1° version of GraphCast; otherwise, run the 0.25° version
*   --nc-or-grib: Output format: GRIB, netCDF4, or both (g, n, ng)
*   --path: Output file path
*   Model name (positional, last argument): graphcast, fourcastnetv2-small, or panguweather

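If the run finishes without error, it writes a netCDF and/or GRIB file (depending on `--nc-or-grib`) to the output path. With `n`, the file is `<path>.nc`, which is what the next step opens.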
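To make the argument list above concrete, here is a minimal sketch that assembles the same command from named settings and checks them against the documented options before anything is run. `build_command` is a hypothetical helper, not part of ai-models-gfs. Its checks only encode what the list above states (input sources, output codes, model names, 6-hourly initialization times, and a YYYYMMDD date).

```python
import shlex

VALID_INPUTS = {"cds", "gfs", "gdas"}
VALID_OUTPUTS = {"g", "n", "ng"}
VALID_MODELS = {"graphcast", "fourcastnetv2-small", "panguweather"}
VALID_TIMES = {"0000", "0600", "1200", "1800"}  # HHMM, 6-hour increments

def build_command(date, time, lead_time, path, model="graphcast",
                  source="gfs", output="n", onedeg=True, download_assets=True):
    """Hypothetical helper: build an ai-models-gfs command line as a single string."""
    if source not in VALID_INPUTS:
        raise ValueError(f"unknown input source: {source}")
    if output not in VALID_OUTPUTS:
        raise ValueError(f"--nc-or-grib must be one of {sorted(VALID_OUTPUTS)}")
    if model not in VALID_MODELS:
        raise ValueError(f"unknown model: {model}")
    if time not in VALID_TIMES:
        raise ValueError("--time must be 0000, 0600, 1200 or 1800")
    if len(date) != 8 or not date.isdigit():
        raise ValueError("--date must be YYYYMMDD")
    args = ["ai-models-gfs", "--input", source]
    if download_assets:
        args.append("--download-assets")
    args += ["--date", date, "--time", time, "--lead-time", str(lead_time)]
    if onedeg:
        args.append("--onedeg")
    args += ["--nc-or-grib", output, "--path", path, model]  # model name goes last
    return shlex.join(args)  # quotes any argument that contains spaces

command = build_command("20240702", "1200", 240, "20240702_12_graphcast")
print(command)
```

The printed string is identical to the command in the next cell, so in a notebook it could be run with `!{command}`. Building the command in Python means a typo such as `--time 1300` is caught before a long download starts.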


In [2]:
#@title Run GraphCast with command line
!ai-models-gfs --input gfs --download-assets --date 20240702 --time 1200 --lead-time 240 --onedeg --nc-or-grib n --path 20240702_12_graphcast graphcast

2024-11-08 18:07:09,897 INFO Writing results to 20240702_12_graphcast
2024-11-08 18:07:09,898 INFO Downloading /content/params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
2024-11-08 18:07:09,898 INFO Downloading https://storage.googleapis.com/dm_graphcast/params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
2024-11-08 18:07:20,466 INFO Downloading /content/stats/diffs_stddev_by_level.nc
2024-11-08 18:07:20,467 INFO Downloading https://storage.googleapis.com/dm_graphcast/stats/diffs_stddev_by_level.nc
2024-11-08 18:07:20,603 INFO Downloading /content/stats/mean_by_level.nc
2024-11-08 18:07:20,603 INFO Downloading https://storage.googleapis.com/dm_graphcast/stats/mean_by_level.nc
2024-11-08 18:07:20,714 INFO Downloading /content/stats/stddev_by_level.nc
2024-11-08 18:07:20,714 INFO Downloading https://storage.googleapis.com/dm_graphcast/sta

# **Step 3: Plot the data**

Now that the model has finished, we can plot the output. We will do so using plotting functions adapted from those provided by the Google DeepMind GraphCast team.
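The plotting helpers below use Basemap's `shiftgrid` to move the model's longitude axis onto the -180 to 180° map extent set by `lon_bounds`. To illustrate what that reordering does (as an explanation, not a replacement for the notebook's code), here is a NumPy-only sketch. It assumes a regular, ascending longitude axis running from 0 to 359° on the last array axis, as on a 1° grid starting at 0°; this is not checked against the output file. The endpoint convention differs slightly: `shiftgrid(180, ..., start=False)` ends the grid at 180°, while this sketch produces -180 to 179°.

```python
import numpy as np

def to_minus180_180(values, lon):
    """Reorder a field whose last axis is longitude in [0, 360) onto [-180, 180)."""
    lon = np.asarray(lon, dtype=float)
    wrapped = ((lon + 180.0) % 360.0) - 180.0  # same formula as convert_longitudes()
    order = np.argsort(wrapped)                # indices that sort -180 ... 179
    return np.take(values, order, axis=-1), wrapped[order]

lon = np.arange(0.0, 360.0, 1.0)   # 1-degree grid: 0, 1, ..., 359
field = np.tile(lon, (3, 1))       # toy field whose value equals its own longitude
shifted, lon_shifted = to_minus180_180(field, lon)
print(lon_shifted[0], lon_shifted[-1])  # -180.0 179.0
```

Because the data are reordered together with the coordinates, each value stays attached to its original longitude; only the storage order changes.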

In [3]:
#@title Load the netCDF file data that we just produced into xarray dataset
import xarray
graphcast_20240702_12 = xarray.open_dataset('20240702_12_graphcast.nc')
graphcast_20240702_12

In [4]:
# @title Define plotting functions courtesy of Google DeepMind team
from typing import Optional
import matplotlib
import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
import math
import datetime
from IPython.display import HTML
from matplotlib import animation
from mpl_toolkits.basemap import Basemap, shiftgrid

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        data = data.sel(level=level)
    return data

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    lat_bounds: Optional[tuple[float, float]] = None,
    lon_bounds: Optional[tuple[float, float]] = None,
    vminpercent: float = 5,
    vmaxpercent: float = 95
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:

    if lat_bounds and lon_bounds:
        data = data.sel(latitude=slice(lat_bounds[0], lat_bounds[1]),
                        longitude=slice(lon_bounds[0], lon_bounds[1]))

    vmin = np.nanpercentile(data, (vminpercent if robust else 0))
    vmax = np.nanpercentile(data, (vmaxpercent if robust else 100))
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return (data, matplotlib.colors.Normalize(vmin, vmax),
            ("RdBu_r" if center is not None else "viridis"))

def convert_longitudes(lon):
    lon = np.asarray(lon)
    lon = ((lon + 180) % 360) - 180
    return lon

def plot_data(
    data: dict[str, xarray.Dataset],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4,
    lat_bounds: tuple[float, float] = (-90, 90),
    lon_bounds: tuple[float, float] = (-180, 180),
    nlevels=21
    ) -> HTML:

    first_data = next(iter(data.values()))[0]
    max_steps = first_data.sizes.get("time", 1)
    assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

    cols = min(cols, len(data))
    rows = math.ceil(len(data) / cols)
    figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
    figure.suptitle(fig_title, fontsize=16)
    figure.subplots_adjust(wspace=0.3, hspace=0.3)  # Adjust these values to control spacing

    images = []
    for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
        ax = figure.add_subplot(rows, cols, i+1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)

        # Convert longitudes
        lon = plot_data.coords['longitude'].values
        lat = plot_data.coords['latitude'].values
        plot_data_shifted, lon_shifted = shiftgrid(180, plot_data.values, lon, start=False)

        # Create basemap with specified bounds
        m = Basemap(projection='cyl', resolution='c', ax=ax,
                    llcrnrlat=lat_bounds[0], urcrnrlat=lat_bounds[1],
                    llcrnrlon=lon_bounds[0], urcrnrlon=lon_bounds[1])
        m.drawcoastlines()
        m.drawcountries()

        lon_shifted, lat_shifted = np.meshgrid(lon_shifted, lat)
        x, y = m(lon_shifted, lat_shifted)

        # Define levels and BoundaryNorm
        levels = np.linspace(norm.vmin, norm.vmax, nlevels)
        norm = matplotlib.colors.BoundaryNorm(levels, ncolors=256)

        im = m.pcolormesh(x, y, plot_data_shifted[0], norm=norm, cmap=cmap)
        plt.colorbar(
            mappable=im,
            ax=ax,
            orientation="vertical",
            pad=0.02,
            aspect=16,
            shrink=0.75,
            cmap=cmap,
            extend=("both" if robust else "neither"))
        images.append(im)

    def update(frame):
        if "time" in first_data.dims:
            td = datetime.datetime.utcfromtimestamp(first_data["time"][frame].item() / 1000000000).strftime('%Y-%m-%d %H:%M:%S')
            figure.suptitle(f"{fig_title}, {td}", fontsize=16)
        else:
            figure.suptitle(fig_title, fontsize=16)
        for im, (plot_data, norm, cmap) in zip(images, data.values()):
            im.set_array(shiftgrid(180, plot_data.isel(time=frame, missing_dims="ignore").values, plot_data.coords['longitude'].values, start=False)[0])

    ani = animation.FuncAnimation(
        fig=figure, func=update, frames=max_steps, interval=250)
    plt.close(figure.number)
    return HTML(ani.to_jshtml())
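The colour range used by `plot_data` comes from `scale()`. With `robust=True`, the limits are the 5th and 95th percentiles instead of the full minimum and maximum, so a few extreme grid points do not wash out the rest of the map. When `center` is given, the range is widened to be symmetric about it so that a diverging colormap puts its neutral colour there. The sketch below isolates that logic on a plain NumPy array; `color_limits` is a hypothetical name and the toy wind values are illustrative only.

```python
import numpy as np

def color_limits(data, center=None, robust=False, vminpercent=5, vmaxpercent=95):
    """Colour limits computed the same way as scale(), without the xarray slicing."""
    vmin = np.nanpercentile(data, vminpercent if robust else 0)
    vmax = np.nanpercentile(data, vmaxpercent if robust else 100)
    if center is not None:
        # Widen the range so that it is symmetric about `center`.
        half_width = max(vmax - center, center - vmin)
        vmin, vmax = center - half_width, center + half_width
    return float(vmin), float(vmax)

winds = np.arange(-10.0, 31.0)  # toy u-wind values: -10, -9, ..., 30 m/s (41 values)
print(color_limits(winds))                           # (-10.0, 30.0)
print(color_limits(winds, robust=True))              # (-8.0, 28.0)
print(color_limits(winds, center=0.0, robust=True))  # (-28.0, 28.0)
```

`plot_data` then splits `[vmin, vmax]` into `nlevels` evenly spaced boundaries for a `BoundaryNorm`, which is why the colour bar shows discrete bands.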


In [5]:
# @title Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=graphcast_20240702_12.data_vars.keys(),
    value="t",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=graphcast_20240702_12.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=graphcast_20240702_12.dims["time"], value=graphcast_20240702_12.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])


In [6]:
#@title Make the plot
plot_size = 6

#These are the variables that need a diverging color map centered on zero
if plot_example_variable.value in ['u10','v10','u','v','w']:
  center = 0
else:
  center = None
data = {
  " ": scale(select(graphcast_20240702_12, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
              robust=plot_example_robust.value,center=center,vminpercent=5,vmaxpercent=95),
}
fig_title = plot_example_variable.value
if "level" in graphcast_20240702_12[plot_example_variable.value].coords:
  fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value,lat_bounds=(0, 40),lon_bounds=(-110,-45),nlevels=20)
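The animation title in `plot_data` is built with `datetime.datetime.utcfromtimestamp`, which is deprecated as of Python 3.12. A minimal sketch of a timezone-aware alternative, assuming (as the original code does) that the `time` coordinate holds `datetime64` values. Inside `update()` it could be called as `valid_time_label(first_data["time"][frame].values)`; `valid_time_label` is a hypothetical helper name.

```python
import datetime
import numpy as np

def valid_time_label(t):
    """Format a numpy datetime64 value (scalar or 0-d array) as 'YYYY-MM-DD HH:MM:SS' UTC."""
    seconds = int(np.asarray(t).astype("datetime64[s]").astype("int64"))  # seconds since 1970-01-01
    dt = datetime.datetime.fromtimestamp(seconds, tz=datetime.timezone.utc)
    return dt.strftime("%Y-%m-%d %H:%M:%S")

print(valid_time_label(np.datetime64("2024-07-02T18:00:00.000000000")))  # 2024-07-02 18:00:00
```

Converting to whole seconds in NumPy first also avoids the manual division by 1,000,000,000, which only works when the coordinate is stored in nanoseconds.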

# **Bonus Application 1: Large ensembles**

Now that we've run GraphCast, let's see what else we can do with it. A commonly discussed application is developing large ensembles. In the interest of time and compute resources we won't make a *large* ensemble, but we'll at least try a small one.

In [None]:
#@title Run GraphCast 4 times with random perturbations
for i in range(0,4):
  command = f"ai-models-gfs --input gfs --download-assets --date 20240702 --time 1200 --lead-time 120 --onedeg --nc-or-grib n --path 20240702_12_mem{str(i).zfill(2)} --ensemble 0.0050 0.20 graphcast"
  !{command}

In [None]:
#@title Load the netCDF file data that we just produced into xarray dataset
import xarray
file_paths = ['20240702_12_mem00.nc', '20240702_12_mem01.nc', '20240702_12_mem02.nc', '20240702_12_mem03.nc']
example_batch = xarray.open_dataset(file_paths[0])
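Each ensemble member comes from a separate `!{command}` call, and a failing shell command inside the loop does not stop the loop, so a member file can be silently missing. The small guard below reports any member that was not written before the plotting cells try to open it. It rebuilds the same file names as `file_paths` (the `.nc` suffix follows the paths used above); `missing_members` is a hypothetical helper.

```python
from pathlib import Path

def missing_members(paths):
    """Return the ensemble output files that do not exist on disk."""
    return [p for p in paths if not Path(p).is_file()]

# Same names as file_paths: 20240702_12_mem00.nc ... 20240702_12_mem03.nc
member_paths = [f"20240702_12_mem{i:02d}.nc" for i in range(4)]
missing = missing_members(member_paths)
if missing:
    print("Missing ensemble members:", ", ".join(missing))
else:
    print("All", len(member_paths), "ensemble members found.")
```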

In [None]:
# @title Define paneled plotting functions courtesy of Google DeepMind team
from typing import Optional
import matplotlib
import matplotlib.pyplot as plt
import ipywidgets as widgets
import numpy as np
import math
import datetime
from IPython.display import HTML
from matplotlib import animation
from mpl_toolkits.basemap import Basemap, shiftgrid
import xarray as xr

def select_ensemble(
    data: xr.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xr.Dataset:
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        data = data.sel(level=level)
    return data

def scale_ensemble(
    data: xr.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    lat_bounds: Optional[tuple[float, float]] = None,
    lon_bounds: Optional[tuple[float, float]] = None,
    vminpercent: float = 5,
    vmaxpercent: float = 95
    ) -> tuple[xr.Dataset, float, float, matplotlib.colors.Normalize, str]:

    if lat_bounds and lon_bounds:
        data = data.sel(latitude=slice(lat_bounds[0], lat_bounds[1]),
                        longitude=slice(lon_bounds[0], lon_bounds[1]))

    vmin = np.nanpercentile(data, (vminpercent if robust else 0))
    vmax = np.nanpercentile(data, (vmaxpercent if robust else 100))
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return (data, vmin, vmax, matplotlib.colors.Normalize(vmin, vmax),
            ("RdBu_r" if center is not None else "viridis"))

def convert_longitudes(lon):
    lon = np.asarray(lon)
    lon = ((lon + 180) % 360) - 180
    return lon

def plot_data_ensemble(
    data: dict[str, tuple[xr.Dataset, float, float, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    lat_bounds: tuple[float, float] = (-90, 90),
    lon_bounds: tuple[float, float] = (-180, 180),
    nlevels=21
    ) -> HTML:

    first_data = next(iter(data.values()))[0]
    max_steps = first_data.sizes.get("time", 1)
    assert all(max_steps == d[0].sizes.get("time", 1) for d in data.values())

    cols = 2
    rows = 2
    figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
    figure.suptitle(fig_title, fontsize=16)
    figure.subplots_adjust(wspace=0, hspace=0)
    figure.tight_layout()

    # Determine common vmin and vmax
    vmin = min(d[1] for d in data.values())
    vmax = max(d[2] for d in data.values())

    images = []
    for i, (title, (plot_data, _, _, norm, cmap)) in enumerate(data.items()):
        ax = figure.add_subplot(rows, cols, i+1)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)

        # Convert longitudes
        lon = plot_data.coords['longitude'].values
        lat = plot_data.coords['latitude'].values
        plot_data_shifted, lon_shifted = shiftgrid(180, plot_data.values, lon, start=False)

        # Create basemap with specified bounds
        m = Basemap(projection='cyl', resolution='c', ax=ax,
                    llcrnrlat=lat_bounds[0], urcrnrlat=lat_bounds[1],
                    llcrnrlon=lon_bounds[0], urcrnrlon=lon_bounds[1])
        m.drawcoastlines()
        m.drawcountries()

        lon_shifted, lat_shifted = np.meshgrid(lon_shifted, lat)
        x, y = m(lon_shifted, lat_shifted)

        # Define nlevels boundaries shared by all panels, and a BoundaryNorm
        levels = np.linspace(vmin, vmax, nlevels)
        norm = matplotlib.colors.BoundaryNorm(levels, ncolors=256)

        im = m.pcolormesh(x, y, plot_data_shifted[0], norm=norm, cmap=cmap)
        plt.colorbar(
            mappable=im,
            ax=ax,
            orientation="vertical",
            pad=0.02,
            aspect=16,
            shrink=0.75,
            cmap=cmap,
            extend=("both" if robust else "neither"))
        images.append(im)

    def update(frame):
        if "time" in first_data.dims:
            td = datetime.datetime.utcfromtimestamp(first_data["time"][frame].item() / 1000000000).strftime('%Y-%m-%d %H:%M:%S')

            figure.suptitle(f"{fig_title}, {td}", fontsize=16)
        else:
            figure.suptitle(fig_title, fontsize=16)
        for im, (plot_data, _, _, _, _) in zip(images, data.values()):
            im.set_array(shiftgrid(180, plot_data.isel(time=frame, missing_dims="ignore").values, plot_data.coords['longitude'].values, start=False)[0])

    ani = animation.FuncAnimation(
        fig=figure, func=update, frames=max_steps, interval=250)
    plt.close(figure.number)
    return HTML(ani.to_jshtml())

In [None]:
#@title Choose data to plot
plot_example_variable = widgets.Dropdown(
    options=example_batch.data_vars.keys(),
    value="t",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=example_batch.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=example_batch.dims["time"], value=example_batch.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

In [None]:
#@title Make the plot
variable_name = plot_example_variable.value
level = plot_example_level.value
robust = plot_example_robust.value
max_steps = plot_example_max_steps.value

data = {}
for i, file_path in enumerate(file_paths):
    ds = xr.open_dataset(file_path)
    scaled_data = scale_ensemble(select_ensemble(ds, variable_name, level, max_steps), robust=robust,vminpercent=5,vmaxpercent=95)
    data[f"File {i+1}"] = scaled_data

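# Only pressure-level variables get a level in the title (surface fields such as t2 do not)
has_level = "level" in example_batch[variable_name].coords
fig_title = f"{variable_name} at {level} hPa" if has_level else variable_name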
plot_data_ensemble(data, fig_title=fig_title, plot_size=4, robust=robust, lat_bounds=(5, 35), lon_bounds=(-105, -50), nlevels=20)
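Looking at the members side by side shows how they diverge; a common way to summarize an ensemble is its mean and its spread (the standard deviation across members) at each grid point and time. A minimal NumPy sketch, assuming the members are same-shaped arrays. With this notebook's data they could come from `select_ensemble(xr.open_dataset(p), variable_name, level, max_steps).values` for each `p` in `file_paths`. The helper name and the toy values are illustrative only.

```python
import numpy as np

def ensemble_mean_spread(members):
    """Ensemble mean and spread (sample standard deviation across members)."""
    stack = np.stack(members, axis=0)  # shape: (member, ...)
    return stack.mean(axis=0), stack.std(axis=0, ddof=1)

# Toy 4-member "ensemble" on a 2 x 3 grid, each member a constant field.
members = [np.full((2, 3), v) for v in (1.0, 2.0, 3.0, 4.0)]
mean, spread = ensemble_mean_spread(members)
print(mean[0, 0], spread[0, 0])
```

With xarray, the same result can be obtained by concatenating the members along a new dimension (`xr.concat(..., dim="member")`) and calling `.mean("member")` and `.std("member", ddof=1)`. With only four members, the spread is a rough estimate.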


# **Bonus Application 2: The What-If Machine**

![Professor Farnsworth from Futurama](https://static0.gamerantimages.com/wordpress/wp-content/uploads/2024/03/futurama-professor-farnsworth.jpg?q=50&fit=crop&w=1500&dpr=1.5)



In [None]:
#@title What if the Gulf of Mexico were 10 K warmer (2 m temperature)?
!ai-models-gfs --input gfs --download-assets --date 20240702 --time 1200 --lead-time 240 --onedeg --nc-or-grib n --path 20240702_12_perturb --perturbation 2m_temperature 10 25 -90 600 0 graphcast
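The six values after `--perturbation` are positional: variable, magnitude, latitude, longitude, radius, and level. This order follows the `perturbation_*` settings in Bonus Application 4. The sketch below gives each value a name and rebuilds the flag string. The variable list is copied from the options listed in Bonus Application 4. The units in the comments are assumptions: kelvin for temperature (from the "10 K warmer" example), kilometres for the radius, and hPa for the level, with 0 used for surface fields as in the example. `Perturbation` is a hypothetical class, not part of ai-models-gfs.

```python
from dataclasses import dataclass

# Perturbation variables listed as options in Bonus Application 4.
PERTURBABLE = {
    "temperature", "geopotential", "u_component_of_wind", "v_component_of_wind",
    "vertical_velocity", "specific_humidity", "2m_temperature",
    "mean_sea_level_pressure", "10m_v_component_of_wind", "10m_u_component_of_wind",
}

@dataclass
class Perturbation:
    variable: str
    magnitude: float  # amount added, in the variable's units (e.g. K for temperature)
    latitude: float   # centre latitude, degrees north
    longitude: float  # centre longitude, degrees east (negative = west)
    radius: float     # size of the perturbed region (units assumed to be km)
    level: int = 0    # pressure level (assumed hPa); 0 for surface fields, as in the example

    def to_flag(self) -> str:
        """Return the --perturbation argument string in the positional order used above."""
        if self.variable not in PERTURBABLE:
            raise ValueError(f"unsupported perturbation variable: {self.variable}")
        if not -90 <= self.latitude <= 90:
            raise ValueError("latitude must be within [-90, 90]")
        values = [self.magnitude, self.latitude, self.longitude, self.radius, self.level]
        # ':g' drops trailing zeros, so 10 prints as "10" rather than "10.0"
        return "--perturbation " + self.variable + " " + " ".join(f"{v:g}" for v in values)

gulf = Perturbation("2m_temperature", 10, 25, -90, 600, 0)
print(gulf.to_flag())  # --perturbation 2m_temperature 10 25 -90 600 0
```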

In [None]:
#@title Load the netCDF file data that we just produced into xarray dataset
import xarray
graphcast_20240702_12_perturb = xarray.open_dataset('20240702_12_perturb.nc')
graphcast_20240702_12_perturb

In [None]:
# @title Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=graphcast_20240702_12_perturb.data_vars.keys(),
    value="t",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=graphcast_20240702_12_perturb.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=graphcast_20240702_12_perturb.dims["time"], value=graphcast_20240702_12_perturb.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

In [None]:
#@title Make the plot
plot_size = 6

#These are the variables that need a diverging color map centered on zero
if plot_example_variable.value in ['u10','v10','u','v','w']:
  center = 0
else:
  center = None
data = {
  " ": scale(select(graphcast_20240702_12_perturb, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
              robust=plot_example_robust.value,center=center,vminpercent=5,vmaxpercent=95),
}
fig_title = plot_example_variable.value
if "level" in graphcast_20240702_12_perturb[plot_example_variable.value].coords:
  fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value,lat_bounds=(0, 40),lon_bounds=(-110,-45),nlevels=20)

# **Bonus Application 3: Other AIWP models**

With the ai-models package, we can run other AIWP models such as FourCastNetv2-small and PanguWeather. We'll give FourCastNetv2 a try next.

PanguWeather (and 0.25° GraphCast) require a more powerful GPU, so we won't be able to run them in this Colab notebook.

In [7]:
#@title Install FourCastNetv2 (and ai-models-gfs again because we switched runtime type)
%%capture
!pip install ai-models-fourcastnetv2-gfs

In [8]:
#@title Run FourCastNetv2-small for Hurricane Beryl
!ai-models-gfs --input gfs --download-assets --date 20240702 --time 1200 --lead-time 120 --nc-or-grib n --path 20240702_12_fourcastnet fourcastnetv2-small

2024-11-08 18:12:17,972 INFO Writing results to 20240702_12_fourcastnet
2024-11-08 18:12:17,973 INFO Downloading /content/weights.tar
2024-11-08 18:12:17,973 INFO Downloading https://get.ecmwf.int/repository/test-data/ai-models/fourcastnetv2/small/weights.tar
2024-11-08 18:19:50,729 INFO Downloading /content/global_means.npy
2024-11-08 18:19:50,730 INFO Downloading https://get.ecmwf.int/repository/test-data/ai-models/fourcastnetv2/small/global_means.npy
2024-11-08 18:19:51,647 INFO Downloading /content/global_stds.npy
2024-11-08 18:19:51,647 INFO Downloading https://get.ecmwf.int/repository/test-data/ai-models/fourcastnetv2/small/global_stds.npy
2024-11-08 18:19:52,424 INFO Loading ./global_means.npy
2024-11-08 18:19:52,425 INFO Loading ./global_stds.npy
2024-11-08 18:19:52,425 INFO Loading surface fields from GFS
2024-11-08 18:19:53,777 INFO Loading pressure fields from GFS
2024-11-08 18:20:02,139 INFO Using device 'CPU'. The speed of inference depends greatly on the device.
  checkpo

In [9]:
#@title Load the netCDF file data that we just produced into xarray dataset
import xarray
fourcastnet_20240702_12 = xarray.open_dataset('20240702_12_fourcastnet.nc')

In [10]:
# @title Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=fourcastnet_20240702_12.data_vars.keys(),
    value="t",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=fourcastnet_20240702_12.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=fourcastnet_20240702_12.dims["time"], value=fourcastnet_20240702_12.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])


In [11]:
#@title Make the plot
plot_size = 6

#These are the variables that need a diverging color map centered on zero
if plot_example_variable.value in ['u10','v10','u100','v100','u','v','w']:
  center = 0
else:
  center = None
data = {
  " ": scale(select(fourcastnet_20240702_12, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
              robust=plot_example_robust.value,center=center,vminpercent=5,vmaxpercent=95),
}
fig_title = plot_example_variable.value
if "level" in fourcastnet_20240702_12[plot_example_variable.value].coords:
  fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value,lat_bounds=(0, 40),lon_bounds=(-110,-45),nlevels=100)

# **Bonus Application 4: Choose your own GraphCast adventure**

In [None]:
#@title Install packages again for T4 runtime
!pip install ai-models-gfs==0.0.10 ai-models-graphcast-gfs==0.0.12
!pip install basemap basemap-data-hires
!pip install git+https://github.com/deepmind/graphcast.git
!pip install jax==0.4.23 jaxlib==0.4.23+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
# @title Run GraphCast with custom configuration (connect to T4)
# Variable options:  temperature,
#                    geopotential,
#                    u_component_of_wind,
#                    v_component_of_wind,
#                    vertical_velocity,
#                    specific_humidity,
#                    2m_temperature,
#                    mean_sea_level_pressure,
#                    10m_v_component_of_wind,
#                    10m_u_component_of_wind,

date = "20240616"
time = "0600"
perturb = True
perturbation_variable = "2m_temperature"
perturbation_magnitude = "10"
perturbation_latitude = "25"
perturbation_longitude = "-90"
perturbation_radius = "600"
perturbation_level = "0"

if perturb:
    command_parts = [
        "ai-models-gfs",
        "--input gfs",
        "--download-assets",
        f"--date {date}",
        f"--time {time}",
        "--lead-time 240",
        "--onedeg",
        "--nc-or-grib n",
        f"--perturbation {perturbation_variable} {perturbation_magnitude} {perturbation_latitude} {perturbation_longitude} {perturbation_radius} {perturbation_level}",
        f"--path {date}_{time}_graphcast_cyoa",
        "graphcast"
    ]
else:
    command_parts = [
        "ai-models-gfs",
        "--input gfs",
        "--download-assets",
        f"--date {date}",
        f"--time {time}",
        "--lead-time 240",
        "--onedeg",
        "--nc-or-grib n",
        f"--path {date}_{time}_graphcast_cyoa",
        "graphcast"
    ]

command = " ".join(command_parts)
!{command}
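The perturbation variable is given by its long name (for example `2m_temperature`), but the output file and the plotting dropdowns use short names (`t2`, `msl`, `t`, `u`, `v`, `z`, ...). The sketch below maps one to the other, so the plot can default to the field that was perturbed. The short names come from the dropdown options and plotting code in this notebook. The pairing of `specific_humidity` with `q` is an assumption, since `q` is not visible in the outputs shown. `plot_defaults` is a hypothetical helper.

```python
# Assumed correspondence between --perturbation long names and output variable names.
LONG_TO_SHORT = {
    "temperature": "t",
    "geopotential": "z",
    "u_component_of_wind": "u",
    "v_component_of_wind": "v",
    "vertical_velocity": "w",
    "specific_humidity": "q",  # assumption: not shown in this notebook's outputs
    "2m_temperature": "t2",
    "mean_sea_level_pressure": "msl",
    "10m_u_component_of_wind": "u10",
    "10m_v_component_of_wind": "v10",
}
SURFACE = {"t2", "msl", "u10", "v10"}  # single-level fields have no pressure level

def plot_defaults(perturbation_variable, perturbation_level):
    """Return (output variable name, pressure level or None) for the perturbed field."""
    short = LONG_TO_SHORT[perturbation_variable]
    level = None if short in SURFACE else int(perturbation_level)
    return short, level

print(plot_defaults("2m_temperature", "0"))  # ('t2', None)
print(plot_defaults("temperature", "500"))   # ('t', 500)
```

The returned name could then be passed as `value=` to the variable `Dropdown` in the next plotting cell instead of the hard-coded `"t"`.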

In [None]:
#@title Load the netCDF file data that we just produced into xarray dataset
import xarray
graphcast_cyoa = xarray.open_dataset(f'{date}_{time}_graphcast_cyoa.nc')
graphcast_cyoa

In [None]:
# @title Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=graphcast_cyoa.data_vars.keys(),
    value="t",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=graphcast_cyoa.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=graphcast_cyoa.dims["time"], value=graphcast_cyoa.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

In [None]:
#@title Make the plot
plot_size = 6

#These are the variables that need a diverging color map centered on zero
if plot_example_variable.value in ['u10','v10','u','v','w']:
  center = 0
else:
  center = None
data = {
  " ": scale(select(graphcast_cyoa, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
              robust=plot_example_robust.value,center=center,vminpercent=.1,vmaxpercent=99.9),
}
fig_title = plot_example_variable.value
if "level" in graphcast_cyoa[plot_example_variable.value].coords:
  fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value,lat_bounds=(-90, 90),lon_bounds=(-180,180),nlevels=100)

# **Bonus Application 5: Choose your own FourCastNetv2 adventure**

In [None]:
#@title Install FourCastNetv2 (and ai-models-gfs again because we switched runtime type)
!pip install basemap basemap-data-hires ai-models-gfs ai-models-fourcastnetv2-gfs ai-models-panguweather-gfs

In [None]:
#@title Run FourCastNetv2-small with custom configuration (connect to TPUv2)
date = "20240616"
time = "0600"
command_parts = [
    "ai-models-gfs",
    "--input gfs",
    "--download-assets",
    f"--date {date}",
    f"--time {time}",
    "--lead-time 120",
    "--nc-or-grib n",
    f"--path {date}_{time}_fourcastnet_cyoa",
    "fourcastnetv2-small"
]
command = " ".join(command_parts)
print(command)
!{command}

In [None]:
#@title Load the netCDF file data that we just produced into xarray dataset
import xarray
fourcastnet_cyoa = xarray.open_dataset(f'{date}_{time}_fourcastnet_cyoa.nc')
fourcastnet_cyoa

In [None]:
# @title Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=fourcastnet_cyoa.data_vars.keys(),
    value="t",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=fourcastnet_cyoa.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=fourcastnet_cyoa.dims["time"], value=fourcastnet_cyoa.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

In [None]:
#@title Make the plot
plot_size = 6

#These are the variables that need a diverging color map centered on zero
if plot_example_variable.value in ['u10','v10','u100','v100','u','v','w']:
  center = 0
else:
  center = None
data = {
  " ": scale(select(fourcastnet_cyoa, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
              robust=plot_example_robust.value,center=center,vminpercent=.1,vmaxpercent=99.9),
}
fig_title = plot_example_variable.value
if "level" in fourcastnet_cyoa[plot_example_variable.value].coords:
  fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value,lat_bounds=(-90, 90),lon_bounds=(-180,180),nlevels=100)