In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - WeatherNext Forecasting
<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_weather_prediction_on_vertex.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_weather_prediction_on_vertex.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates running forecasts with [WeatherNext Graph](https://www.science.org/doi/10.1126/science.adi2336) and [WeatherNext Gen](https://arxiv.org/abs/2312.15796) models on TPU using Vertex Model Garden.

### Objective

- Configure WeatherNext Gen/Graph models with example data
- Run Gen/Graph model forecasts
- Visualize forecasting results

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Before you begin

### Request For TPU Quota

By default, the quota for TPU training ([Custom model training TPU v5e cores per region](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_training_tpu_v5e)) is 0. TPU quota is only available in `us-west1`, `us-west4`, and `us-central1`. You can request a higher TPU quota by following the instructions at ["Request a higher quota"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota). We suggest requesting at least 4 TPU v5e cores to run this notebook (8 if you plan to run `gen_small` with more than 16 forecast steps).

In [None]:
# @title Setup Google Cloud project


# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. "us") is not considered a match for a single region covered by the multi-region range (eg. "us-central1"). If not set, a unique GCS bucket will be created instead.

# Leaving the "gs://" placeholder unchanged makes the setup code further down
# create a uniquely named bucket in REGION.
BUCKET_URI = "gs://"  # @param {type:"string"}

# @markdown 3. **[Optional]** TPU is only available in `us-west1`, `us-west4`, `us-central1`.

# Free-text variant of the REGION field, kept commented out; the dropdown below
# is the active one and only offers the regions named in the note above.
# REGION = ""  # @param {type:"string"}
REGION = "us-central1" # @param ["us-central1", "us-west1", "us-west4"]

# Import the necessary packages
import datetime
import importlib
import os
import uuid
from typing import Tuple, List
import glob
from google.cloud import aiplatform, storage

# Upgrade Vertex AI SDK.
! pip3 install --upgrade --quiet 'google-cloud-aiplatform>=1.64.0'
if not os.path.exists("./vertex-ai-samples"):
  ! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git
! pip3 uninstall --quiet -y xarray
! pip3 install --quiet xarray[complete]

# Import model garden utils.
common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
if not REGION:
    REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")

# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Gets the default SERVICE_ACCOUNT.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this default Service Account:", SERVICE_ACCOUNT)


# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME
! gcloud config set project $PROJECT_ID
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/storage.admin"
! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role="roles/aiplatform.user"


# Utility functions for vertex jobs.
def get_job_name_with_datetime(prefix: str) -> str:
    """Builds a job name by appending the current local timestamp to `prefix`.

    Args:
        prefix: Leading part of the job name.

    Returns:
        `prefix` followed by `_YYYYMMDD_HHMMSS`.
    """
    timestamp_suffix = datetime.datetime.now().strftime("_%Y%m%d_%H%M%S")
    return prefix + timestamp_suffix

def get_bucket_and_blob_name(filepath):
    """Splits a GCS path of the form `gs://<bucket>/<blob>` into its two parts.

    Args:
        filepath: A string containing a `gs://` URI.

    Returns:
        A `(bucket_name, blob_name)` tuple. `blob_name` is the empty string when
        the path names only a bucket, so the result can always be unpacked into
        two values.

    Raises:
        ValueError: If `filepath` does not contain `gs://`.
    """
    scheme = "gs://"
    if scheme not in filepath:
        raise ValueError(f"Expected a GCS path starting with `gs://`, got: {filepath!r}")
    gs_suffix = filepath.split(scheme, 1)[1]
    # partition() always yields three parts, so a bucket-only path produces an
    # empty blob name rather than a shorter tuple.
    bucket_name, _, blob_name = gs_suffix.partition("/")
    return bucket_name, blob_name

def upload_local_dir_to_gcs(local_dir_path, gcs_dir_path):
    """Uploads the files under a local directory to a GCS directory.

    Files in sub-directories are uploaded as well, and keep their path relative
    to `local_dir_path` below `gcs_dir_path`.

    Args:
        local_dir_path: Local directory to upload; a trailing slash is accepted.
        gcs_dir_path: Destination of the form `gs://<bucket>/<optional/prefix>`.
    """
    client = storage.Client()
    bucket_name = gcs_dir_path.split("/")[2]
    bucket = client.get_bucket(bucket_name)

    # Normalizing removes any trailing separator, so relative names are computed
    # from a stable root instead of by slicing at a fixed character offset.
    local_root = os.path.normpath(local_dir_path)
    gcs_root = gcs_dir_path.rstrip("/")

    # "**" only descends into sub-directories when recursive=True is passed.
    pattern = os.path.join(glob.escape(local_root), "**")
    for local_file in glob.glob(pattern, recursive=True):
        if not os.path.isfile(local_file):
            continue
        # GCS object names always use "/", whatever the local separator is.
        relative_name = os.path.relpath(local_file, local_root).replace(os.sep, "/")
        gcs_file_path = f"{gcs_root}/{relative_name}"
        _, blob_name = get_bucket_and_blob_name(gcs_file_path)
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(local_file)
        print("Copied {} to {}.".format(local_file, gcs_file_path))

def download_gcs_blob_as_json(gcs_file_path):
    """Downloads a GCS object and decodes its whole content as one JSON value.

    Args:
        gcs_file_path: Object location of the form `gs://<bucket>/<blob>`.

    Returns:
        The result of `json.loads` on the downloaded bytes.
    """
    gcs_client = storage.Client()
    bucket_name, blob_name = get_bucket_and_blob_name(gcs_file_path)
    raw_bytes = gcs_client.get_bucket(bucket_name).blob(blob_name).download_as_bytes()
    return json.loads(raw_bytes)

# Utility functions for prediction visualization.
import matplotlib
import xarray
from typing import Optional
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import math
from IPython.display import HTML
import json
import numpy as np

# Print the I/O backends xarray can use in this runtime. The visualization cell
# reads predictions with xarray.open_zarr, so this is a quick sanity check after
# the reinstall above — presumably a zarr engine should be listed; TODO confirm.
print(xarray.backends.list_engines())

def get_existing_demo_step(
    num_forecast_steps: int,
    model_type: str = "gen_small",
    ) -> int:
  """Picks the demo dataset step count that can serve a forecast request.

  The demo data only exists for a fixed set of step counts per model, so the
  smallest supported count that is at least `num_forecast_steps` is chosen.

  Args:
    num_forecast_steps: Number of forecast steps the caller wants to run.
    model_type: One of "gen_small", "graph_small" or "graph_operational".

  Returns:
    The supported step count, as an int, of the demo file to use.

  Raises:
    ValueError: If `model_type` is unknown, or if `num_forecast_steps` is
      larger than the largest step count available for that model.
  """
  # The demo data only supports some steps, and we can only run the predictions
  # if num_forecast_steps is smaller than or equal to the maximum supported steps.
  supported_steps_by_model = {
      # The max supported steps are obtained from gs://dm_graphcast/gencast/dataset.
      "gen_small": [1, 4, 12, 20, 30],
      # The max supported steps are obtained from gs://dm_graphcast/graphcast/dataset.
      "graph_small": [1, 4, 12, 20, 40],
      "graph_operational": [1, 4, 12],
  }
  if model_type not in supported_steps_by_model:
    raise ValueError("Invalid model_type.")
  supported_demo_steps = supported_steps_by_model[model_type]

  # The lists are ascending, so the first match is the smallest sufficient one.
  for supported_step in supported_demo_steps:
    if num_forecast_steps <= supported_step:
      return supported_step
  raise ValueError(f"Supported demo steps for {model_type} in gs://dm_graphcast are {supported_demo_steps}. {num_forecast_steps} is too large, and could not find proper demo data.")

def get_suggested_machines(
    num_forecast_steps: int,
    model_type: str = "gen_small",) -> Tuple[str, str, int]:
    """Suggests a TPU machine configuration for a forecast run.

    Args:
        num_forecast_steps: Number of forecast steps to run.
        model_type: Model variant; "gen_small" gets larger machines than any
            other value.

    Returns:
        A `(machine_type, tpu_topology, accelerator_count)` tuple.
    """
    one_chip = ("ct5lp-hightpu-1t", "1x1", 1)
    four_chips = ("ct5lp-hightpu-4t", "2x2", 4)
    eight_chips = ("ct5lp-hightpu-8t", "2x4", 8)

    # Each model family has a smaller option for runs of up to 16 steps and a
    # larger one beyond that.
    if model_type == "gen_small":
        small_option, large_option = four_chips, eight_chips
    else:
        small_option, large_option = one_chip, four_chips
    return small_option if num_forecast_steps <= 16 else large_option

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  """Extracts one variable from `data` and narrows it down for plotting.

  Args:
    data: Container indexed by variable name.
    variable: Name of the variable to extract.
    level: If given and the variable has a "level" coordinate, the level to
      select.
    max_steps: If given and smaller than the size of the "time" dimension,
      only the first `max_steps` time entries are kept.

  Returns:
    The selected variable, with the first "batch" entry taken when a "batch"
    dimension is present.
  """
  selected = data[variable]
  if "batch" in selected.dims:
    selected = selected.isel(batch=0)

  if max_steps is not None and "time" in selected.sizes:
    if max_steps < selected.sizes["time"]:
      selected = selected.isel(time=range(0, max_steps))

  if level is not None:
    if "level" in selected.coords:
      selected = selected.sel(level=level)
  return selected

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  """Computes the color normalization and colormap name for `data`.

  Args:
    data: Values to be plotted; NaNs are ignored when computing the range.
    center: If given, the color range is made symmetric around this value.
    robust: If True, the range spans the 2nd to 98th percentile instead of
      the full minimum to maximum.

  Returns:
    A `(data, norm, cmap_name)` tuple; `data` is returned unchanged.
  """
  lower_pct = 2 if robust else 0
  upper_pct = 98 if robust else 100
  vmin = np.nanpercentile(data, lower_pct)
  vmax = np.nanpercentile(data, upper_pct)

  if center is None:
    return (data, matplotlib.colors.Normalize(vmin, vmax), "viridis")

  # Widen the shorter side so that `center` sits in the middle of the range.
  half_range = max(vmax - center, center - vmin)
  vmin = center - half_range
  vmax = center + half_range
  return (data, matplotlib.colors.Normalize(vmin, vmax), "RdBu_r")

def plot_data(
    data: dict[str, tuple[xarray.Dataset, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> HTML:
  """Renders the given panels as an animation over the "time" dimension.

  Args:
    data: Maps a panel title to a `(values, norm, cmap)` tuple, as produced by
      `scale`. All entries must have the same number of time steps.
    fig_title: Title of the whole figure; the current time offset is appended
      on each frame when the data has a "time" dimension.
    plot_size: Height of one panel; each panel is twice as wide as it is tall.
    robust: If True, colorbars are drawn with extensions at both ends.
    cols: Maximum number of panels per row.

  Returns:
    An `IPython.display.HTML` object wrapping the JavaScript animation.
  """

  # The first panel defines the number of frames; every other panel must agree.
  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  # NOTE(review): tight_layout() runs before any subplot has been added, so it
  # may have no effect on the panels created below — confirm intended.
  figure.tight_layout()

  images = []
  for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
    # Note: the loop variable `plot_data` shadows this function's own name
    # inside the body.
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    # Draw the first frame; missing_dims="ignore" lets data without a "time"
    # dimension pass through unchanged.
    im = ax.imshow(
        plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    # Per-frame callback: refresh the title and swap in the frame's values.
    if "time" in first_data.dims:
      # Assumes the "time" values are integer nanosecond offsets, hence the
      # division by 1000 to get microseconds — TODO confirm.
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (plot_data, norm, cmap) in zip(images, data.values()):
      im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  # Close the figure so that only the returned HTML animation is displayed
  # (presumably to avoid an extra static rendering in the notebook).
  plt.close(figure.number)
  return HTML(ani.to_jshtml())


## Forecasts and Visualizations

In [None]:
# @title Configure Models
# @markdown You can configure WeatherNext models with *model_type* and *num_forecast_steps*.

# Predictions and the job's input/output JSONL files are written under this prefix.
output_dir = f"{BUCKET_URI}/science"
accelerator_type = "TPU_V5e"

# @markdown All demo data are from the public GCS bucket `gs://dm_graphcast`.
# @markdown You can prepare your own data similarly to the demo data for forecasting.
# @markdown The demo data for gen_small is from the date 2019-03-29 with resolution 1.0.
# @markdown The demo data for graph_small is from the date 2022-01-01 with resolution 1.0.
# @markdown The demo data for graph_operational is from the date 2022-01-01 with resolution 0.25.

model_type = "graph_operational" # @param ["gen_small", "graph_small", "graph_operational"]


num_forecast_steps = 10 # @param {type:"integer"}
# @markdown *num_forecast_steps* specifies the number of forecast steps; combined with the model lead time, it determines how far ahead the forecast reaches.
# @markdown For example, with num_forecast_steps=4 and a lead time of 6 hours, the results contain forecasts for 6, 12, 18 and 24 hours ahead.
# @markdown WeatherNext Gen and Graph models use lead times of 12 hours and 6 hours respectively.
# @markdown An error is raised if num_forecast_steps exceeds what the demo data supports: 30 for gen_small, 40 for graph_small and 12 for graph_operational.
machine_type, tpu_topology, accelerator_count = get_suggested_machines(num_forecast_steps, model_type)

SCIENCE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/science-serve.tpu.0-1.debian12.py310:20250331_0715_RC03"
data_storage_dir = "gs://dm_graphcast"
# Demo files exist only for certain step counts; use the smallest one that
# covers the requested number of steps.
existing_demo_step = get_existing_demo_step(num_forecast_steps, model_type)

# Every model takes the number of forecast steps; model-specific parameters are
# added in the matching branch below.
parameters = {"num_forecast_steps": num_forecast_steps}
if model_type == "gen_small":
  input_file = f"{data_storage_dir}/gencast/dataset/source-era5_date-2019-03-29_res-1.0_levels-13_steps-{existing_demo_step:02d}.nc"
  # num_ensemble_samples (WeatherNext Gen models only) specifies the number of ensemble samples per step.
  num_ensemble_samples = 8
  parameters["num_ensemble_samples"] = num_ensemble_samples
elif model_type == "graph_small":
  input_file = f"{data_storage_dir}/graphcast/dataset/source-era5_date-2022-01-01_res-1.0_levels-13_steps-{existing_demo_step:02d}.nc"
elif model_type == "graph_operational":
  input_file = f"{data_storage_dir}/graphcast/dataset/source-hres_date-2022-01-01_res-0.25_levels-13_steps-{existing_demo_step:02d}.nc"
else:
  raise ValueError("Invalid model_type.")

# One prediction instance; each instance becomes one line of the input JSONL.
instances = [
    {
        "input_file": input_file,
        "output_dir": output_dir,
        "parameters": parameters,
    }
]

# @markdown Refer to more details in https://github.com/google-deepmind/graphcast.

print(f"machine_type is {machine_type}.")
print(f"tpu_topology is {tpu_topology}.")
print(f"SCIENCE_DOCKER_URI is {SCIENCE_DOCKER_URI}.")
print(f"The prediction instances: {instances}")


In [None]:
# @title Run Forecasts
# @markdown This section will create vertex jobs to run forecasts.
# @markdown It usually takes a couple of minutes to finish.
# @markdown Click on the generated link in the output to see your run in the Cloud Console.

# Check quota before submitting. The training quota is the one checked because
# the forecast runs as a CustomContainerTrainingJob (see below).
print("Check if there are enough quota.")
common_util.check_quota(
    project_id=PROJECT_ID,
    region=REGION,
    accelerator_type=accelerator_type,
    accelerator_count=accelerator_count,
    is_for_training=True,
)

print("Generate custom job inputs and outputs.")
input_jsonl_name = f"custom_{model_type}_input.jsonl"
output_jsonl_name = f"custom_{model_type}_output.jsonl"

# Convert and write JSON object to file.
# NOTE(review): "bath_prediction_input" looks like a typo for "batch"; it is
# only a local scratch directory name, used consistently below.
os.makedirs("bath_prediction_input", exist_ok=True)

# Write one JSON document per line (JSONL), one line per prediction instance.
with open(f"bath_prediction_input/{input_jsonl_name}", "w") as outfile:
    for item in instances:
        json_str = json.dumps(item)
        outfile.write(json_str)
        outfile.write("\n")

# Uploads every file in the scratch directory, which can include input files
# left over from earlier runs with a different model_type.
upload_local_dir_to_gcs(
    "bath_prediction_input", output_dir
)

JOB_NAME = get_job_name_with_datetime(prefix=f"jax_{model_type}")

input_jsonl = f"{output_dir}/{input_jsonl_name}"
output_jsonl = f"{output_dir}/{output_jsonl_name}"

# Command line passed to the container: the script reads `input_jsonl` and is
# told to write its results to `output_jsonl`.
docker_args_list = [
    "python3",
    "./gdm_science/batch_prediction.py",
    f"--model_type={model_type}",
    f"--input_jsonl={input_jsonl}",
    f"--output_jsonl={output_jsonl}"
]

print(f"The input json file will be {input_jsonl}.")
print(f"The output json file will be {output_jsonl}.")
print(f"The docker args list is {docker_args_list}.")
print(f"JOB_NAME is {JOB_NAME}.")

# Labels attached to the job; the notebook name is stored without its extension.
labels = {
    "mg-source": "notebook",
    "mg-notebook-name": "model_garden_weather_prediction_on_vertex.ipynb".split(".")[0],
    "mg-tune": f"publishers-google-models-{model_type}".lower(),
    "versioned-mg-tune": f"publishers-google-models-{model_type}".lower(),
}

job = aiplatform.CustomContainerTrainingJob(
    display_name=JOB_NAME,
    container_uri=SCIENCE_DOCKER_URI,
    labels=labels,
)

# The TPU slice is described by machine_type and tpu_topology only;
# accelerator_type and accelerator_count are used for the quota check above and
# are not passed to the job.
job.run(
    args=docker_args_list,
    base_output_dir=f"{BUCKET_URI}",
    replica_count=1,
    machine_type=machine_type,
    tpu_topology=tpu_topology,
    service_account=SERVICE_ACCOUNT,
)

In [None]:
# @title  Visualize Forecasts
# @markdown For simplicity, we only pick the one forecast result for visualization.
# @markdown You can also visualize previous forecasts by setting `prediction_data_path`.
# The downloaded output is indexed and its first element decoded a second time,
# i.e. each element is expected to be a JSON-encoded string itself.
prediction_json = download_gcs_blob_as_json(output_jsonl)
prediction_data_path = json.loads(prediction_json[0])["predictions"] # @param

# @markdown  (Optional) The sample to visualize if there are multiple samples.
sample = 0 # @param

print(prediction_data_path)
predictions = xarray.open_zarr(prediction_data_path)

plot_size = 7
variable = "2m_temperature"
level = None
# `sizes` is the mapping from dimension name to length; looking lengths up
# through `Dataset.dims` is deprecated in recent xarray releases.
steps = predictions.sizes["time"]
print("steps=", steps)
# Test the dimensions rather than the variables/coordinates, so that a "sample"
# dimension without an index coordinate is still detected.
if "sample" in predictions.dims:
  print("sample=", predictions.sizes["sample"])
  # Visualize one sample if there are many.
  visualized_data = predictions.isel(sample=sample)
else:
  visualized_data = predictions

data = {
    " ": scale(select(visualized_data, variable, level, steps), robust=True),
}

fig_title = variable
# Only mention a pressure level when one was actually selected; otherwise the
# title would read "at None hPa".
if level is not None and "level" in predictions[variable].coords:
  fig_title += f" at {level} hPa"

plot_data(data, fig_title, plot_size, robust=True)

## Clean Up Resources

In [None]:
# @title Delete Temporal Buckets

# NOTE(review): this recursively removes everything under BUCKET_URI. That is
# the auto-created bucket when none was supplied in the setup cell, but it is
# also the user's own bucket (or path) when one was supplied there — confirm
# this is intended before running with the default of True.
delete_bucket = True  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_URI

