In [None]:
# Copyright 2026 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#      https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# WeatherNext 2 Early Access Program
<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/weathernext/weathernext_2_early_access_program.ipynb">
      <img src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fweathernext%2Fweathernext_2_early_access_program.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/weathernext/weathernext_2_early_access_program.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/weathernext/weathernext_2_early_access_program.ipynb">
      <img width="32px" src="https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

## Overview

This notebook demonstrates running [WeatherNext 2 inference on Google Cloud Vertex AI](https://developers.google.com/weathernext/guides/access-vmg). WeatherNext 2 is Google's latest medium-range probabilistic forecasting model, principally an operational version of the FGN model ([published June 2025](https://arxiv.org/abs/2506.10772)). More information is available in the [WeatherNext documentation](https://developers.google.com/weathernext).

### Objective

- Configure the model inputs for distributed, multi-host inference on H100 or A100 GPUs.
- Run WeatherNext 2 model forecasts in parallel.
- Visualize forecast results.

### Costs

This uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.


## Before you begin

### Request For GPU Quota

**WARNING:** Make sure you have sufficient GPU quota allocated for the inference configuration (i.e. `num_samples`) before running Vertex jobs. Otherwise, some jobs may run while others fail,
producing incomplete results.


By default, the quota for GPUs is 0. You can request a higher quota by following the instructions at ["Request a higher quota"](https://cloud.google.com/docs/quota/view-manage#requesting_higher_quota).

You will need to request quota for either **NVIDIA H100 80GB GPUs** or **NVIDIA A100 80GB GPUs** in your selected region. The total number of GPUs you request must be sufficient for your largest planned forecast (i.e., `num_samples`).

You should request the following quota:

- Service: `Vertex AI API`
- Name: `Custom model training preemptible Nvidia A100 80GB GPUs per region` OR `Custom model training preemptible Nvidia H100 GPUs per region`

In [None]:
# @title Install python packages

# Note that you may need to restart the kernel after this step.
# If so, continue to the next cell after restarting.

print("Installing python packages.")

# Install the Vertex AI SDK at a pinned version together with xarray and its
# `complete` extra.
# NOTE(review): xarray is not version-pinned here, so the resolved version can
# differ between runs; the extra presumably supplies the zarr/plotting backends
# used by the visualization cell — confirm.
! pip3 install \
    google-cloud-aiplatform==1.129.0 \
    xarray[complete]

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set `BUCKET_URI` (it must start with `gs://`) for the experiment environment. The specified Cloud Storage bucket must be located in the region selected as `REGION` below; the setup code rejects a bucket whose location differs.


BUCKET_URI = "gs://my-bucket"  # @param {type:"string"}

# @markdown 3. Select a region that has the required GPUs available.

REGION = "us-central1" # @param {type:"string"}

# Import the necessary packages
import datetime
import importlib
import os
import uuid
import glob
from google.cloud import aiplatform, storage

import json
import math
import re

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI API and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])

if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    raise ValueError("GCS Bucket URI is invalid!")
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(f"Bucket region {bucket_region} is different from notebook region {REGION}")
print(f"Using this GCS Bucket: {BUCKET_URI}")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
# Initialize Vertex AI API.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Utility functions
def get_job_name_with_datetime(prefix: str) -> str:
    """Return `prefix` followed by `_YYYYMMDD_HHMMSS` for the current local time."""
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    return prefix + "_" + stamp

In [None]:
# @title Configure Model Parameters
# @markdown Configure the hardware and input parameters for the WeatherNext 2 forecast.

# @markdown ### Hardware Configuration for Distributed Inference
# @markdown - **`machine_type`**: Select a valid machine type. `a3-highgpu` series use NVIDIA H100 80GB GPUs. `a2-ultragpu` series use NVIDIA A100 80GB GPUs.
# @markdown - **`num_samples`**: The total number of ensemble members to generate.
# @markdown The samples are split evenly across the selected seeds, and the number of machine replicas per seed is calculated automatically (samples per seed / GPUs per machine). **Therefore, `num_samples` must be divisible by the number of seeds, and the samples per seed must be a multiple of the number of GPUs in your selected `machine_type`.**
# @markdown - **`scheduling_strategy`**: The [strategy](https://cloud.google.com/vertex-ai/docs/reference/rest/v1beta1/CustomJobSpec#Strategy) used to acquire machines for the job. Defaults to [Dynamic Workload Scheduler](https://docs.cloud.google.com/vertex-ai/docs/training/schedule-jobs-dws) (FLEX_START).
machine_type = "a3-highgpu-1g" #@param ["a3-highgpu-1g", "a3-highgpu-2g", "a3-highgpu-4g", "a3-highgpu-8g", "a2-ultragpu-1g", "a2-ultragpu-2g", "a2-ultragpu-4g", "a2-ultragpu-8g"]
num_samples = 8 #@param {type:"integer"}
scheduling_strategy = "FLEX_START" #@param ["FLEX_START", "SPOT", "STANDARD"]

# @markdown ### Forecast Configuration
# @markdown - **`forecast_init_time`**: The starting time for the forecast in ISO 8601 format (e.g., `2025-09-21T00:00:00Z`). Models are available for dates from 2024 onwards.
# @markdown - **`horizon_hrs`**: The desired length of the forecast in hours (e.g., 240 for a 10-day forecast).
# @markdown - **`model_seed`**: Choose a specific model seed (1-4) or select "all" to run inference with all four seeds in parallel for improved accuracy.
# @markdown - **`enable_hourly_prediction`**: If checked, the model will generate 1-hour predictions.
forecast_init_time = "2025-11-20T00:00:00Z" #@param {type:"string"}
horizon_hrs = 72 #@param {type:"integer"}
model_seed = "all" # @param ["1", "2", "3", "4", "all"]
enable_hourly_prediction = True # @param {type:"boolean"}

# --- Parameter Validation and Configuration ---

# Derive accelerator type from the chosen machine type
if machine_type.startswith('a3-highgpu'):
    accelerator_type = 'NVIDIA_H100_80GB'
elif machine_type.startswith('a2-ultragpu'):
    accelerator_type = 'NVIDIA_A100_80GB'
else:
    raise ValueError(f"Invalid machine type selected: {machine_type}.")

# Extract the number of GPUs from the machine type string, e.g., 'a3-highgpu-4g' -> 4.
# The match is checked explicitly so a malformed name is reported directly rather
# than via an AttributeError on `None`.
gpu_count_match = re.search(r'-(\d+)g$', machine_type)
if gpu_count_match is None or int(gpu_count_match.group(1)) < 1:
    raise ValueError(f"Could not determine accelerator count from machine type: {machine_type}")
accelerators_per_machine = int(gpu_count_match.group(1))

# One job is launched per seed; `num_samples` is the total across all of them.
seeds_to_run = [1, 2, 3, 4] if model_seed == "all" else [int(model_seed)]
num_seeds_to_run = len(seeds_to_run)

# Split the samples evenly across the seeds (a no-op when a single seed is selected).
if num_samples % num_seeds_to_run != 0:
    raise ValueError(f"`num_samples` ({num_samples}) is not divisible by the number of seeds to run ({num_seeds_to_run}).")
num_samples_per_seed = num_samples // num_seeds_to_run

# Validate that the per-seed sample count is a multiple of accelerators_per_machine
if num_samples_per_seed % accelerators_per_machine != 0:
    raise ValueError(f"`num_samples_per_seed` ({num_samples_per_seed}) must be a multiple of the GPUs per machine ({accelerators_per_machine} for {machine_type}).")

# Calculate the number of replicas per seed
replica_count_per_seed = num_samples_per_seed // accelerators_per_machine

# Ensure that there are enough samples: zero (or negative) samples would pass the
# modulo check above but yield a job with no replicas.
if replica_count_per_seed < 1:
    raise ValueError(f"`num_samples_per_seed` ({num_samples_per_seed}) must be at least the GPUs per machine ({accelerators_per_machine} for {machine_type}).")

# GPUs requested by a single job, and by all jobs together. Each job requests
# `replica_count_per_seed` machines with `accelerators_per_machine` GPUs each, so
# the grand total equals `num_samples` (it is NOT multiplied again by the seed count).
gpus_per_job = replica_count_per_seed * accelerators_per_machine
total_gpus_needed = gpus_per_job * num_seeds_to_run

# Set Docker URI
WEATHERNEXT2_DOCKER_URI = 'us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/weather-next-2-inference.gpu.0-1:latest'

print("--- Job Configuration Summary ---")
print(f"Total Samples: {num_samples}")
print(f"Machine Type: {machine_type}")
print(f"Accelerator Type: {accelerator_type}")
print(f"GPUs per Machine: {accelerators_per_machine}")
print(f"Total number seeds to run: {num_seeds_to_run}")
print(f"Total number samples per seed: {num_samples_per_seed}")
print(f"Calculated Machine Replicas Per Seed: {replica_count_per_seed}")
print(f"Total GPUs per Job: {gpus_per_job}")
print(f"Total GPUs across all Jobs (ensure sufficient quota): {total_gpus_needed}")
print(f"Docker Image: {WEATHERNEXT2_DOCKER_URI}")
print("---------------------------------")

In [None]:
# @title Run Forecasts
# @markdown This section creates and runs one or more Vertex AI Custom Training Jobs to generate the forecasts.
# @markdown **This operation is asynchronous.** The jobs will be submitted and this cell will complete quickly.
# @markdown You must monitor the job progress in the Google Cloud Console (https://console.cloud.google.com/vertex-ai/training/custom-jobs).

from google.cloud.aiplatform.compat.types import \
    custom_job as gca_custom_job_compat

print(f"Submitting {len(seeds_to_run)} job(s) to run in parallel.")

# Jobs submitted by this cell, and the output folder handed to each seed's job.
launched_jobs = []
output_dirs = {}

# Translate the form selection into the SDK scheduling enum. Any value other
# than FLEX_START or SPOT is treated as STANDARD.
strategy_member_name = (
    scheduling_strategy if scheduling_strategy in ("FLEX_START", "SPOT") else "STANDARD"
)
SCHEDULLING_STRATEGY = getattr(
    gca_custom_job_compat.Scheduling.Strategy, strategy_member_name
)

for seed in seeds_to_run:
    # The same root folder is passed to every seed's job.
    output_dir = os.path.join(BUCKET_URI, "weathernext2_outputs")
    output_dirs[seed] = output_dir

    # Command-line flags passed to the inference container, in a fixed order.
    docker_args_list = [
        f"--{flag_name}={flag_value}"
        for flag_name, flag_value in (
            ("pred_root_dir", output_dir),
            ("num_samples", num_samples_per_seed),
            ("horizon_hrs", horizon_hrs),
            ("forecast_init_time", forecast_init_time),
            ("model_seed", seed),
            ("enable_hourly_prediction", enable_hourly_prediction),
        )
    ]

    JOB_NAME = get_job_name_with_datetime(prefix=f"wn2-forecast-s{seed}-n{num_samples_per_seed}")
    print(f"\n--- Submitting Job for Seed {seed} ---")
    print(f"JOB_NAME: {JOB_NAME}")

    job = aiplatform.CustomContainerTrainingJob(
        display_name=JOB_NAME, container_uri=WEATHERNEXT2_DOCKER_URI
    )

    run_options = dict(
        args=docker_args_list,
        replica_count=replica_count_per_seed,
        machine_type=machine_type,
        accelerator_type=accelerator_type,
        accelerator_count=accelerators_per_machine,
        scheduling_strategy=SCHEDULLING_STRATEGY,
        # Change this to True if you need to debug why the job hasn't started
        sync=False,
    )
    job.run(**run_options)
    launched_jobs.append(job)
    print(f"--> Job submitted successfully. Monitor it in the Google Cloud Console at https://console.cloud.google.com/vertex-ai/training/custom-jobs")

print("\nAll forecast jobs have been submitted.")

In [None]:
# @title Visualize Forecasts (Unified)
# @markdown Select which forecast output you want to visualize. This single component
# @markdown can handle both the standard 6-hourly predictions and the datasets
# @markdown produced with 1-hour predictions (which have a 'subtime' dimension).
# @markdown If you run into `Error loading Zarr store: unrecognized engine 'zarr'...` try restarting the runtime session and rerunning this cell.

# @markdown ---
# @markdown ### Visualization Settings
# @markdown - **`model_seed_to_visualize`**: Choose a specific model seed (1-4) to visualize. This should be one of the model seeds selected in the **Forecast Configuration** above.
# @markdown - **`time_steps_to_visualize`**: Choose to visualize 1-hourly or 6-hourly forecasts. If 1-hourly is selected, ensure `enable_hourly_prediction` was selected in the **Forecast Configuration** above.
# @markdown - **`variable_to_visualize`**: Choose the weather variable to visualize. See the [WeatherNext documentation](https://developers.google.com/weathernext/guides/model-specs-vmg) for variable names and descriptions.
# @markdown - **`sample_to_visualize`**: Choose the sample (ensemble member) to visualize.
# @markdown - **`plot_size`**: Choose the size of the plot to generate.
model_seed_to_visualize = "4" # @param ["1", "2", "3", "4"]
time_steps_to_visualize = "6-Hourly" # @param ["6-Hourly", "1-Hourly"]
variable_to_visualize = "2m_temperature" # @param {type:"string"}
sample_to_visualize = 0 # @param {type:"integer"}
plot_size = 8 #@param {type:"number"}
# Optional level forwarded to `select_data`; with None no level selection is
# applied. Not exposed as a form field, so edit it here if needed.
level_to_visualize = None
# @markdown ---


# --- Helper Functions ---

def init_time_to_folder_path(init_time: str) -> str:
  """Convert a forecast init time to the folder-name prefix used in GCS.

  Args:
    init_time: Timestamp of the form `<YYYY-MM-DD>T<HH...>`, e.g.
      `2025-11-20T00:00:00Z`.

  Returns:
    The date with dashes removed, an underscore, the two-character hour and
    the suffix `hr`, e.g. `20251120_00hr`.

  Raises:
    ValueError: If `init_time` does not contain exactly one `T` separator.
  """
  init_date, time_of_day = init_time.split("T")
  # Single quotes inside the f-string: reusing the enclosing double quotes is a
  # SyntaxError on Python versions before 3.12.
  return f"{init_date.replace('-', '')}_{time_of_day[0:2]}hr"


# override these if you'd like to visualize a different set of forecasts
visualize_bucket = BUCKET_URI
visualize_init_date = forecast_init_time

# set paths based on chosen model seed, bucket, and init date
# The paths are built from `visualize_bucket` so that the override above takes
# effect. A trailing slash is stripped to avoid a `//` in the object path.
visualize_outputs_root = f"{visualize_bucket.rstrip('/')}/weathernext2_outputs"
visualize_init_folder = init_time_to_folder_path(visualize_init_date)
path_to_6hr_zarr = f"{visualize_outputs_root}/weathernext_2_seed_{model_seed_to_visualize}/{visualize_init_folder}_01_preds/predictions.zarr/"
path_to_1hr_zarr = f"{visualize_outputs_root}/weathernext_2_seed_{model_seed_to_visualize}_hourly/{visualize_init_folder}_01_preds/predictions.zarr/"


import matplotlib
import xarray
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import math
from IPython.display import HTML
import numpy as np
import datetime

# Raise matplotlib's `animation.embed_limit` setting to 500 so that longer
# animations can still be embedded in the notebook output.
# NOTE(review): matplotlib documents this limit in MB — confirm the unit.
matplotlib.rcParams['animation.embed_limit'] = 500


# --- Helper Functions ---

def select_data(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    ) -> xarray.Dataset:
  """Pick one variable out of `data`, dropping the batch axis and optionally a level.

  Args:
    data: Dataset to index with `variable`.
    variable: Name of the variable to extract.
    level: If given, and the extracted variable has a `level` coordinate, the
      entry matching this level is selected.

  Returns:
    The extracted variable, with the first `batch` entry taken when a `batch`
    dimension is present and the level selection applied when applicable.
  """
  selected = data[variable]
  if "batch" in selected.dims:
    selected = selected.isel(batch=0)
  # Nothing more to do when no level was requested or the variable has none.
  if level is None or "level" not in selected.coords:
    return selected
  return selected.sel(level=level)

def scale_data(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  """Work out the colour normalisation and colormap for plotting `data`.

  Args:
    data: Values to be plotted; returned unchanged.
    center: If given, the colour range is made symmetric around this value and
      the `RdBu_r` colormap is chosen; otherwise `viridis` is used.
    robust: If True the range spans the 2nd to 98th percentile (NaNs ignored)
      instead of the full minimum-to-maximum range.

  Returns:
    A `(data, norm, colormap_name)` tuple.
  """
  lower_pct, upper_pct = (2, 98) if robust else (0, 100)
  vmin = np.nanpercentile(data.values, lower_pct)
  vmax = np.nanpercentile(data.values, upper_pct)

  if center is None:
    return (data, matplotlib.colors.Normalize(vmin, vmax), "viridis")

  # Widen the range so that `center` sits exactly in the middle of it.
  half_range = max(vmax - center, center - vmin)
  symmetric_norm = matplotlib.colors.Normalize(center - half_range, center + half_range)
  return (data, symmetric_norm, "RdBu_r")

def create_forecast_animation(
    dataset: xarray.Dataset,
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    ) -> HTML:
  """Build an animation of a forecast field and return it as embeddable HTML.

  Data with a 'subtime' dimension is stacked with 'time' into a single
  animation axis; otherwise the 'time' dimension is used directly. Each frame
  is drawn with `imshow`, and the figure title shows the valid time and the
  forecast lead time in hours.

  Args:
    dataset: Data to animate. It must carry an 'init_time' coordinate and a
      'time' dimension. NOTE(review): annotated as a Dataset, but this notebook
      passes the output of `select_data` (a single variable) — confirm the
      intended type. The stacked branch also expects 'lat' and 'lon' dimensions.
    fig_title: Title text; each frame appends a second line with the valid time.
    plot_size: The figure is created with `figsize=(plot_size * 2, plot_size)`.
    robust: Forwarded to `scale_data`; when True the colorbar is also drawn
      with `extend="both"`.

  Returns:
    An `IPython.display.HTML` object wrapping `ani.to_html5_video()`.
  """
  # --- Data Preparation ---
  # Check if the data still has a 'subtime' dimension. If so, stack dimensions.
  # Otherwise, just rename the 'time' dimension for consistency.
  if 'subtime' in dataset.dims:
    print("Detected 'subtime' dimension. Stacking for hourly animation.")
    # Stack 'time' and 'subtime' into a single animation dimension
    plot_data = dataset.stack(
        animation_step=("time", "subtime")
    ).transpose("animation_step", "lat", "lon")
  else:
    print("No 'subtime' dimension found. Using 'time' for 6-hourly animation.")
    # Use 'time' as the animation dimension
    plot_data = dataset.rename({'time': 'animation_step'})

  # Now, the animation dimension is always called 'animation_step'
  max_steps = plot_data.sizes["animation_step"]
  init_time = plot_data.coords['init_time'].values

  # Scale the data for color mapping (one normalisation shared by all frames)
  scaled_data, norm, cmap = scale_data(plot_data, robust=robust)

  # --- Plotting Setup ---
  figure = plt.figure(figsize=(plot_size * 2, plot_size))
  ax = figure.add_subplot(1, 1, 1)
  ax.set_xticks([])
  ax.set_yticks([])
  figure.suptitle(fig_title, fontsize=16)
  figure.tight_layout(rect=[0, 0.03, 1, 0.95]) # Leave room at the top for the title

  # Draw the first frame; later frames only swap the image data.
  im = ax.imshow(
      scaled_data.isel(animation_step=0), norm=norm, origin="lower", cmap=cmap)

  plt.colorbar(
      mappable=im, ax=ax, orientation="vertical", pad=0.02,
      aspect=16, shrink=0.75, cmap=cmap,
      extend=("both" if robust else "neither"))

  # --- Animation Update Function ---
  def update(frame):
    """Redraw the image and title for animation frame index `frame`."""
    # Get the coordinates for the current frame
    step_coords = plot_data['animation_step'][frame].coords

    # Calculate total offset and valid time based on available coordinates
    if 'subtime' in step_coords: # Hourly data: offset is time + subtime
        total_offset = step_coords['time'].values + step_coords['subtime'].values
    else: # 6-hourly data: the renamed 'time' value is the offset
        total_offset = step_coords['animation_step'].values

    # Lead time expressed in hours, and the absolute time the frame is valid for.
    total_hours = total_offset / np.timedelta64(1, 'h')
    valid_time = init_time + total_offset
    valid_time_str = np.datetime_as_string(valid_time, unit='m').replace('T', ' ')

    new_title = (
        f"{fig_title}\n"
        f"Valid: {valid_time_str} UTC (Forecast: +{total_hours:.1f}h)"
    )
    figure.suptitle(new_title, fontsize=16)
    im.set_data(scaled_data.isel(animation_step=frame))

  # --- Create and Display Animation ---
  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  # Close the figure so only the returned video is shown, not a static plot too.
  plt.close(figure.number)
  return HTML(ani.to_html5_video())


# --- Main Visualization Logic ---

# 1. Select the correct path based on the user's dropdown choice
if time_steps_to_visualize == "6-Hourly":
    path_to_zarr = path_to_6hr_zarr
elif time_steps_to_visualize == "1-Hourly":
    path_to_zarr = path_to_1hr_zarr
else:
    raise ValueError("Invalid visualization target selected.")

print(f"Loading data from: {path_to_zarr}")

# 2. Load the dataset
try:
    full_dataset = xarray.open_zarr(path_to_zarr)
except Exception as e:
    # Loading is a common point of failure (e.g. the forecast job has not
    # finished yet), so the error is reported and the rest of the cell skipped.
    print(f"Error loading Zarr store: {e}")
else:
    # 3. Select the specific data slice for visualization
    data_for_vis = full_dataset.isel(sample=sample_to_visualize)
    variable_data = select_data(data_for_vis, variable_to_visualize, level_to_visualize)

    # 4. Generate the title
    title = f"{variable_to_visualize} (Sample {sample_to_visualize})"
    # Use the same `is not None` test as `select_data`, so the title and the
    # level selection agree even for a falsy level value such as 0.
    if level_to_visualize is not None:
      title += f" at {level_to_visualize} hPa"

    # 5. Create and display the animation
    display(create_forecast_animation(variable_data, title, plot_size, robust=True))