In [27]:
# @title Imports

import dataclasses
import datetime
import math
from typing import Optional

from IPython.display import HTML
from IPython import display
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1" 
import ipywidgets as widgets
import jax
# from jax.extend.core import JaxprEqn
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray
import haiku as hk
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import normalization
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import xarray_tree
from graphcast import gencast
from graphcast import denoiser
from graphcast import nan_cleaning



In [26]:
# @title Reconfigure jax if running on TPU.

# This is required due to outdated jax and libtpu versions in Colab TPU images.
%pip uninstall -y libtpu libtpu-nightly
%pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

/bin/bash: /home/vatsal/miniconda3/envs/jaxcuda/lib/libtinfo.so.6: no version information available (required by /bin/bash)
[0mNote: you may need to restart the kernel to use updated packages.
/bin/bash: /home/vatsal/miniconda3/envs/jaxcuda/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting libtpu==0.0.13.* (from jax[tpu])
  Downloading libtpu-0.0.13-py3-none-manylinux_2_31_x86_64.whl.metadata (500 bytes)
Collecting requests (from jax[tpu])
  Using cached requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting charset-normalizer<4,>=2 (from requests->jax[tpu])
  Downloading charset_normalizer-3.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (35 kB)
Collecting idna<4,>=2.5 (from requests->jax[tpu])
  Using cached idna-3.10-py3-none-any.whl.metadata (10 kB)
Downloading libtpu-0.0.13-py3-none-manylinux_2_31_x86_64.whl (132.9 MB)
[2K   [90

# Plotting functions

In [28]:
# @title Plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  data = data[variable]
  if "batch" in data.dims:
    data = data.isel(batch=0)
  if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
    data = data.isel(time=range(0, max_steps))
  if level is not None and "level" in data.coords:
    data = data.sel(level=level)
  return data

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  vmin = np.nanpercentile(data, (2 if robust else 0))
  vmax = np.nanpercentile(data, (98 if robust else 100))
  if center is not None:
    diff = max(vmax - center, center - vmin)
    vmin = center - diff
    vmax = center + diff
  return (data, matplotlib.colors.Normalize(vmin, vmax),
          ("RdBu_r" if center is not None else "viridis"))

def plot_data(
    data: dict[str, xarray.Dataset],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    if "time" in first_data.dims:
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (plot_data, norm, cmap) in zip(images, data.values()):
      im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  plt.close(figure.number)
  return HTML(ani.to_jshtml())




LOAD THE DATA AND INITIALIZE THE MODEL

In [29]:
# @title Authenticate with Google Cloud Storage

# Gives you an authenticated client, in case you want to use a private bucket.
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
dir_prefix = "gencast/"

NameError: name 'storage' is not defined

In [4]:
MODEL_PATH = ""  "gencast_params_GenCast 1p0deg Mini _2019.npz"
DATA_PATH = ""  "source-era5_date-2019-03-29_res-1.0_levels-13_steps-04.nc"
STATS_DIR = ""  "stats/"

Load the Model

In [5]:
# @title Load the model

with open(MODEL_PATH, "rb") as f:
  ckpt = checkpoint.load(f, gencast.CheckPoint)
params = ckpt.params
state = {}

task_config = ckpt.task_config
sampler_config = ckpt.sampler_config
noise_config = ckpt.noise_config
noise_encoder_config = ckpt.noise_encoder_config
denoiser_architecture_config = ckpt.denoiser_architecture_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")

Model description:
 
        GenCast model at lower, 1deg, resolution, with 13 pressure levels and a
        4 times refined icosahedral mesh. This model is trained on ERA5 data
        from 1979 to 2018, and can be causally evaluated on 2019 and later years.
        This model has the smallest memory footprint of those provided and has been provided
        to enable low cost demonstrations. It is not representative of GenCast's performance.
         

Model license:
 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 



Check example dataset matches model

In [6]:
# @title Check example dataset matches model

def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_")[:3])

def data_valid_for_model(file_name: str, params_file_name: str):
  """Check data type and resolution matches."""
  data_file_parts = parse_file_parts(file_name.removesuffix(".nc"))
  res_matches = data_file_parts["res"].replace(".", "p") in params_file_name.lower()
  source_matches = "Operational" in params_file_name
  if data_file_parts["source"] == "era5":
    source_matches = not source_matches
  return res_matches and source_matches

assert data_valid_for_model(DATA_PATH, MODEL_PATH)


Load weather data

In [7]:
# @title Load weather data

with open(DATA_PATH, "rb") as f:
  example_batch = xarray.load_dataset(f).compute()

assert example_batch.dims["time"] >= 3  # 2 for input, >=1 for targets

print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(DATA_PATH.removesuffix(".nc")).items()]))

example_batch

source: era5, date: 2019-03-29, res: 1.0


  example_batch = xarray.load_dataset(f).compute()
  assert example_batch.dims["time"] >= 3  # 2 for input, >=1 for targets


Plot example data

In [8]:
# @title Plot example data

plot_size = 7
variable = "geopotential"
level = 500
steps = example_batch.dims["time"]


data = {
    " ": scale(select(example_batch, variable, level, steps), robust=True),
}
fig_title = variable
if "level" in example_batch[variable].coords:
  fig_title += f" at {level} hPa"

plot_data(data, fig_title, plot_size, robust=True)


  steps = example_batch.dims["time"]


Extract training and eval data

In [9]:
# @title Extract training and eval data

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", "12h"), # Only 1AR training.
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("12h", f"{(example_batch.dims['time']-2)*12}h"), # All but 2 input frames.
    **dataclasses.asdict(task_config))

print("All Examples:  ", example_batch.dims.mapping)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)


All Examples:   {'lon': 360, 'lat': 181, 'level': 13, 'time': 6, 'batch': 1}
Train Inputs:   {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Train Targets:  {'batch': 1, 'time': 1, 'lat': 181, 'lon': 360, 'level': 13}
Train Forcings: {'batch': 1, 'time': 1, 'lon': 360}
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Eval Targets:   {'batch': 1, 'time': 4, 'lat': 181, 'lon': 360, 'level': 13}
Eval Forcings:  {'batch': 1, 'time': 4, 'lon': 360}


  example_batch, target_lead_times=slice("12h", f"{(example_batch.dims['time']-2)*12}h"), # All but 2 input frames.


Load normalization data

In [10]:
# @title Load normalization data

with open(STATS_DIR +"gencast_stats_diffs_stddev_by_level.nc", "rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR +"gencast_stats_mean_by_level.nc", "rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR +"gencast_stats_stddev_by_level.nc", "rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()
with open(STATS_DIR +"gencast_stats_min_by_level.nc", "rb") as f:
  min_by_level = xarray.load_dataset(f).compute()

Build jitted functions, and possibly initialize random weights

In [11]:
# @title Build jitted functions, and possibly initialize random weights


def construct_wrapped_gencast():
  """Constructs and wraps the GenCast Predictor."""
  predictor = gencast.GenCast(
      sampler_config=sampler_config,
      task_config=task_config,
      denoiser_architecture_config=denoiser_architecture_config,
      noise_config=noise_config,
      noise_encoder_config=noise_encoder_config,
  )

  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
  )

  predictor = nan_cleaning.NaNCleaner(
      predictor=predictor,
      reintroduce_nans=True,
      fill_value=min_by_level,
      var_to_clean='sea_surface_temperature',
  )

  return predictor


@hk.transform_with_state
def run_forward(inputs, targets_template, forcings):
  predictor = construct_wrapped_gencast()
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(inputs, targets, forcings):
  predictor = construct_wrapped_gencast()
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics),
  )


def grads_fn(params, state, inputs, targets, forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), i, t, f
    )
    return loss, (diagnostics, next_state)

  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True
  )(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads


if params is None:
  init_jitted = jax.jit(loss_fn.init)
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets=train_targets,
      forcings=train_forcings,
  )


loss_fn_jitted = jax.jit(
    lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]
)
grads_fn_jitted = jax.jit(grads_fn)
run_forward_jitted = jax.jit(
    lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]
)
# We also produce a pmapped version for running in parallel.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

Run the model

In [12]:
# The number of ensemble members should be a multiple of the number of devices.
print(f"Number of local devices {len(jax.local_devices())}")

Number of local devices 1


In [22]:
# @title Autoregressive rollout (loop in python)

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

num_ensemble_members = 1 # @param int
rng = jax.random.PRNGKey(0)
# We fold-in the ensemble member, this way the first N members should always
# match across different runs which use take the same inputs
# regardless of total ensemble size.
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    # Use pmapped version to parallelise across devices.
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk = 1,
    num_samples = num_ensemble_members,
    pmap_devices=jax.local_devices()
    ):
    chunks.append(chunk)
predictions = xarray.combine_by_coords(chunks)

Inputs:   {'batch': 1, 'time': 2, 'lat': 181, 'lon': 360, 'level': 13}
Targets:  {'batch': 1, 'time': 4, 'lat': 181, 'lon': 360, 'level': 13}
Forcings: {'batch': 1, 'time': 4, 'lon': 360}


NotImplementedError: scalar prefetch not implemented in the Triton backend

In [None]:
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

In [None]:
rngs

array([[1797259609, 2579123966]], dtype=uint32)

In [None]:
for chunk in rollout.chunked_prediction_generator_multiple_runs:
    print(chunk)

TypeError: 'function' object is not iterable

Plot prediction samples and diffs

In [None]:
# @title Plot prediction samples and diffs

plot_size = 5
variable = "2m_temperature"
level = None
steps = predictions.dims["time"]

fig_title = variable
if "level" in predictions[variable].coords:
  fig_title += f" at {level} hPa"

for sample_idx in range(num_ensemble_members):
  data = {
      "Targets": scale(select(eval_targets, variable, level, steps), robust=True),
      "Predictions": scale(select(predictions.isel(sample=sample_idx), variable, level, steps), robust=True),
      "Diff": scale((select(eval_targets, variable, level, steps) -
                          select(predictions.isel(sample=sample_idx), variable, level, steps)),
                        robust=True, center=0),
  }
  display.display(plot_data(data, fig_title + f", Sample {sample_idx}", plot_size, robust=True))


NameError: name 'predictions' is not defined