In [2]:
# @title Upgrade packages (kernel needs to be restarted after running this cell).

%pip install -U importlib_metadata

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
# @title Pip install repo and dependencies

%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

Collecting https://github.com/deepmind/graphcast/archive/master.zip
  Using cached https://github.com/deepmind/graphcast/archive/master.zip
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.0 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [6]:
# @title Reconfigure jax if running on TPU.

# This is required due to outdated jax and libtpu versions in Colab TPU images.
%pip uninstall -y libtpu libtpu-nightly
%pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Found existing installation: libtpu 0.0.2
Uninstalling libtpu-0.0.2:
  Successfully uninstalled libtpu-0.0.2
Found existing installation: libtpu_nightly 0.1.dev20241010+nightly.cleanup
Uninstalling libtpu_nightly-0.1.dev20241010+nightly.cleanup:
  Successfully uninstalled libtpu_nightly-0.1.dev20241010+nightly.cleanup
Note: you may need to restart the kernel to use updated packages.
Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting jax[tpu]
  Using cached jax-0.5.3-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.3,>=0.5.3 (from jax[tpu])
  Using cached jaxlib-0.5.3-cp312-cp312-win_amd64.whl.metadata (1.2 kB)
INFO: pip is looking at multiple versions of jax[tpu] to determine which version is compatible with other requirements. This could take a while.
Collecting jax[tpu]
  Using cached jax-0.5.2-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.5.2,>=0.5.1 (from jax[tpu])
  Using cached jaxlib-0.5.1-cp312-cp312-win_amd64.whl.meta


[notice] A new release of pip is available: 25.0 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [8]:
# @title Imports
#!conda create -n graphcast_fix python=3.12 numpy=1.26.4 scikit-learn=1.5.0
#!conda activate graphcast_fix
#!pip install graphcast dinosaur
import dataclasses
import datetime
import math
from google.cloud import storage
from typing import Optional
import haiku as hk
from IPython.display import HTML
from IPython import display
import ipywidgets as widgets
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray

from graphcast import rollout
from graphcast import xarray_jax
from graphcast import normalization
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import xarray_tree
from graphcast import gencast
from graphcast import denoiser
from graphcast import nan_cleaning



In [10]:
# @title Plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  data = data[variable]
  if "batch" in data.dims:
    data = data.isel(batch=0)
  if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
    data = data.isel(time=range(0, max_steps))
  if level is not None and "level" in data.coords:
    data = data.sel(level=level)
  return data

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  vmin = np.nanpercentile(data, (2 if robust else 0))
  vmax = np.nanpercentile(data, (98 if robust else 100))
  if center is not None:
    diff = max(vmax - center, center - vmin)
    vmin = center - diff
    vmax = center + diff
  return (data, matplotlib.colors.Normalize(vmin, vmax),
          ("RdBu_r" if center is not None else "viridis"))

def plot_data(
    data: dict[str, xarray.Dataset],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    if "time" in first_data.dims:
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (plot_data, norm, cmap) in zip(images, data.values()):
      im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  plt.close(figure.number)
  return HTML(ani.to_jshtml())


In [12]:
# @title Authenticate with Google Cloud Storage

# Gives you an authenticated client, in case you want to use a private bucket.
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
dir_prefix = "gencast/"

In [14]:
def select_model_parameters():
    print("Select model parameters:")
    
    # Available options
    attention_types = ["splash_mha", "triblockdiag_mha", "mha"]
    latent_options = [int(2**i) for i in range(4, 10)]
    head_options = [int(2**i) for i in range(0, 3)]
    k_hop_options = [int(2**i) for i in range(2, 5)]
    resolution_options = ["1p0", "0p25"]
    
    # Get user input
    print("\nAttention types available:")
    for i, attn in enumerate(attention_types, 1):
        print(f"{i}. {attn}")
    attn_choice = int(input("Select attention type (1-3): ")) - 1
    attention_type = attention_types[attn_choice]
    
    mesh_size = int(input("Enter mesh size (4-6): "))
    
    print("\nNumber of heads options:")
    for i, heads in enumerate(head_options, 1):
        print(f"{i}. {heads}")
    heads_choice = int(input("Select number of heads (1-3): ")) - 1
    num_heads = head_options[heads_choice]
    
    # Filter latent sizes based on attention type and num heads
    def is_valid_latent(latent):
        head_dim, rem = divmod(latent, num_heads)
        if rem != 0:
            return False
        if head_dim % 128 != 0:
            return attention_type != "splash_mha"
        return True
    
    valid_latents = [latent for latent in latent_options if is_valid_latent(latent)]
    
    print("\nValid latent size options:")
    for i, latent in enumerate(valid_latents, 1):
        print(f"{i}. {latent}")
    latent_choice = int(input(f"Select latent size (1-{len(valid_latents)}): ")) - 1
    latent_size = valid_latents[latent_choice]
    
    print("\nAttention k-hop options:")
    for i, k_hop in enumerate(k_hop_options, 1):
        print(f"{i}. {k_hop}")
    k_hop_choice = int(input("Select attention k-hop (1-3): ")) - 1
    attention_k_hop = k_hop_options[k_hop_choice]
    
    print("\nResolution options:")
    for i, res in enumerate(resolution_options, 1):
        print(f"{i}. {res}")
    res_choice = int(input("Select resolution (1-2): ")) - 1
    resolution = resolution_options[res_choice]
    
    # Return selected parameters
    return {
        "attention_type": attention_type,
        "mesh_size": mesh_size,
        "num_heads": num_heads,
        "latent_size": latent_size,
        "attention_k_hop": attention_k_hop,
        "resolution": resolution
    }

# Example usage:
params = select_model_parameters()

Select model parameters:

Attention types available:
1. splash_mha
2. triblockdiag_mha
3. mha


Select attention type (1-3):  1
Enter mesh size (4-6):  4



Number of heads options:
1. 1
2. 2
3. 4


Select number of heads (1-3):  3



Valid latent size options:
1. 512


Select latent size (1-1):  1



Attention k-hop options:
1. 4
2. 8
3. 16


Select attention k-hop (1-3):  3



Resolution options:
1. 1p0
2. 0p25


Select resolution (1-2):  2


In [16]:
def load_model():
    # Get user selection for model source
    print("Select model source:")
    print("1. Checkpoint (pre-trained model)")
    print("2. Random (initialize new model)")
    source_choice = int(input("Enter your choice (1 or 2): "))
    
    if source_choice == 2:  # Random
        params = None
        state = {}
        task_config = gencast.TASK
        
        # Use default values for these configs
        sampler_config = gencast.SamplerConfig()
        noise_config = gencast.NoiseConfig()
        noise_encoder_config = denoiser.NoiseEncoderConfig()
        
        # Get user input for model architecture
        model_params = select_model_parameters()  # Using our previous function
        
        # Configure denoiser architecture
        denoiser_architecture_config = denoiser.DenoiserArchitectureConfig(
            sparse_transformer_config=denoiser.SparseTransformerConfig(
                attention_k_hop=model_params["attention_k_hop"],
                attention_type=model_params["attention_type"],
                d_model=model_params["latent_size"],
                num_heads=model_params["num_heads"]
            ),
            mesh_size=model_params["mesh_size"],
            latent_size=model_params["latent_size"]
        )
        
    else:  # Checkpoint
        assert source_choice == 1
        # List available checkpoints
        params_file_options = [
            name for blob in gcs_bucket.list_blobs(prefix=(dir_prefix+"params/"))
            if (name := blob.name.removeprefix(dir_prefix+"params/"))]
        
        print("\nAvailable checkpoints:")
        for i, f in enumerate(params_file_options, 1):
            print(f"{i}. {f}")
        
        file_choice = int(input(f"Select checkpoint (1-{len(params_file_options)}): ")) - 1
        selected_file = params_file_options[file_choice]
        
        # Load the selected checkpoint
        with gcs_bucket.blob(dir_prefix + f"params/{selected_file}").open("rb") as f:
            ckpt = checkpoint.load(f, gencast.CheckPoint)
        
        params = ckpt.params
        state = {}
        task_config = ckpt.task_config
        sampler_config = ckpt.sampler_config
        noise_config = ckpt.noise_config
        noise_encoder_config = ckpt.noise_encoder_config
        denoiser_architecture_config = ckpt.denoiser_architecture_config
        
        print("\nModel description:\n", ckpt.description, "\n")
        print("Model license:\n", ckpt.license, "\n")
    
    return {
        "params": params,
        "state": state,
        "task_config": task_config,
        "sampler_config": sampler_config,
        "noise_config": noise_config,
        "noise_encoder_config": noise_encoder_config,
        "denoiser_architecture_config": denoiser_architecture_config
    }

# Load the model
model_config = load_model()

# Now you can access all the configuration through model_config dictionary
# For example:
# params = model_config["params"]
#task_config = model_config["task_config"]
# etc.

Select model source:
1. Checkpoint (pre-trained model)
2. Random (initialize new model)


Enter your choice (1 or 2):  1



Available checkpoints:
1. GenCast 0p25deg <2019.npz
2. GenCast 0p25deg Operational <2022.npz
3. GenCast 1p0deg <2019.npz
4. GenCast 1p0deg Mini <2019.npz


Select checkpoint (1-4):  1



Model description:
 
        GenCast model at 0.25deg resolution with 13 pressure levels and a 6 times
        refined icosahedral mesh. This model is trained on ERA5 data from 1979
        to 2018 (inclusive), and can be causally evaluated on 2019
        and later years. This model was described in the paper
        `GenCast: Diffusion-based ensemble forecasting for medium-range weather`
        (https://arxiv.org/abs/2312.15796).
         

Model license:
 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 



In [18]:
def get_and_filter_datasets(model_config, params_file_name=None):
    """Get and filter available datasets based on model configuration."""
    
    # List all available datasets
    dataset_file_options = [
        name for blob in gcs_bucket.list_blobs(prefix=(dir_prefix + "dataset/"))
        if (name := blob.name.removeprefix(dir_prefix+"dataset/"))
    ]
    
    def parse_file_parts(file_name):
        return dict(part.split("-", 1) for part in file_name.split("_"))
    
    def data_valid_for_model(file_name, model_config, params_file_name=None):
        """Check if data matches model requirements."""
        data_file_parts = parse_file_parts(file_name.removesuffix(".nc"))
        data_res = data_file_parts["res"].replace(".", "p")
        
        if model_config.get("source") == "Random":
            # For random models, check against the random_resolution parameter
            return model_config["resolution"] == data_res
        else:
            # For checkpoint models, check resolution and source
            res_matches = data_res in params_file_name.lower()
            source_matches = "Operational" in params_file_name
            if data_file_parts["source"] == "era5":
                source_matches = not source_matches
            return res_matches and source_matches
    
    # Filter datasets based on model configuration
    valid_datasets = [
        option for option in dataset_file_options
        if data_valid_for_model(option, model_config, params_file_name)
    ]
    
    if not valid_datasets:
        print("No datasets available that match your model configuration.")
        return None
    
    # Display available datasets to user
    print("\nAvailable datasets that match your model:")
    for i, dataset in enumerate(valid_datasets, 1):
        parts = parse_file_parts(dataset.removesuffix(".nc"))
        print(f"{i}. {', '.join([f'{k}: {v}' for k, v in parts.items()])}")
    
    # Get user selection
    selection = int(input(f"Select dataset (1-{len(valid_datasets)}): ")) - 1
    selected_dataset = valid_datasets[selection]
    
    print(f"\nSelected dataset: {selected_dataset}")
    return selected_dataset

# # Example usage:
# if __name__ == "__main__":
#     # Assuming you have your model configuration from the previous step
#     model_config = {
#         "source": "Random",  # or "Checkpoint"
#         "resolution": "0p25",  # example value
#         # ... other model config parameters
#     }
    
#     # For checkpoint models, you might have a params_file_name
#     params_file_name = "some_checkpoint_file_res0p25"  # example
    
#     selected_dataset = get_and_filter_datasets(model_config, params_file_name)
    
#     if selected_dataset:
#         # Load the dataset using the selected filename
#         print(f"\nProceeding to load dataset: {selected_dataset}")
#         # Your dataset loading code would go here

In [20]:
def parse_file_parts(file_name):
    """Parse dataset filename into components"""
    parts = file_name.removesuffix(".nc").split("_")
    return dict(part.split("-", 1) for part in parts)

def load_weather_data(selected_dataset):
    try:
        with gcs_bucket.blob(dir_prefix + f"dataset/{selected_dataset}").open("rb") as f:
            example_batch = xarray.load_dataset(f, decode_timedelta=False)
            
            try:
                example_batch = example_batch.compute()
            except ValueError as e:
                print(f"Shape mismatch error: {str(e)}")
                print("Attempting to load with modified chunks...")
                example_batch = xarray.load_dataset(f, chunks={'time': 1}, decode_timedelta=False).compute()

        if 'time' in example_batch.sizes and example_batch.sizes["time"] < 3:
            raise ValueError(f"Dataset must have at least 3 time steps, found {example_batch.sizes['time']}")

        file_parts = parse_file_parts(selected_dataset.removesuffix(".nc"))
        print("\nDataset metadata:")
        print(", ".join([f"{k}: {v}" for k, v in file_parts.items()]))
        print("\nDataset summary:")
        print(example_batch)

        return example_batch

    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        return None


if __name__ == "__main__":
    model_config = {
        "source": "Random",
        "resolution": "0p25",
    }
    
    selected_dataset = get_and_filter_datasets(model_config)
    
    if selected_dataset:
        weather_data = load_weather_data(selected_dataset)
        
        if weather_data:
            print("\nSuccessfully loaded weather data!")


Available datasets that match your model:
1. source: era5, date: 2019-03-29, res: 0.25, levels: 13, steps: 01
2. source: era5, date: 2019-03-29, res: 0.25, levels: 13, steps: 04
3. source: era5, date: 2019-03-29, res: 0.25, levels: 13, steps: 12
4. source: era5, date: 2019-03-29, res: 0.25, levels: 13, steps: 20
5. source: era5, date: 2019-03-29, res: 0.25, levels: 13, steps: 30
6. source: hres, date: 2022-03-29, res: 0.25, levels: 13, steps: 01
7. source: hres, date: 2022-03-29, res: 0.25, levels: 13, steps: 04
8. source: hres, date: 2022-03-29, res: 0.25, levels: 13, steps: 12
9. source: hres, date: 2022-03-29, res: 0.25, levels: 13, steps: 20
10. source: hres, date: 2022-03-29, res: 0.25, levels: 13, steps: 30


Select dataset (1-10):  1



Selected dataset: source-era5_date-2019-03-29_res-0.25_levels-13_steps-01.nc

Dataset metadata:
source: era5, date: 2019-03-29, res: 0.25, levels: 13, steps: 01

Dataset summary:
<xarray.Dataset> Size: 1GB
Dimensions:                   (lon: 1440, lat: 721, level: 13, time: 3, batch: 1)
Coordinates:
  * lon                       (lon) float32 6kB 0.0 0.25 0.5 ... 359.5 359.8
  * lat                       (lat) float32 3kB -90.0 -89.75 ... 89.75 90.0
  * level                     (level) int32 52B 50 100 150 200 ... 850 925 1000
  * time                      (time) int32 12B 0 12 24
    datetime                  (batch, time) datetime64[ns] 24B 2019-03-29 ......
Dimensions without coordinates: batch
Data variables: (12/18)
    land_sea_mask             (lat, lon) float32 4MB 1.0 1.0 1.0 ... 0.0 0.0 0.0
    geopotential_at_surface   (lat, lon) float32 4MB 2.735e+04 ... -0.07617
    day_progress_cos          (batch, time, lon) float32 17kB 1.0 1.0 ... 1.0
    day_progress_sin          (b

In [22]:
def select_plot_parameters(example_batch):
    """Alternative to widgets for selecting plot parameters"""
    print("\nAvailable variables:")
    for i, var in enumerate(example_batch.data_vars.keys(), 1):
        print(f"{i}. {var}")
    var_choice = int(input("Select variable (number): ")) - 1
    variable = list(example_batch.data_vars.keys())[var_choice]
    
    print("\nAvailable levels:")
    levels = example_batch.coords["level"].values
    for i, level in enumerate(levels, 1):
        print(f"{i}. {level}")
    level_choice = int(input("Select level (number): ")) - 1
    level = levels[level_choice]
    
    robust = input("Use robust scaling (True/False)? ").strip().lower() == 'true'
    
    max_time = example_batch.dims["time"]
    max_steps = int(input(f"Enter max time steps to plot (1-{max_time}): "))
    max_steps = max(1, min(max_steps, max_time))
    
    print("\nSelected parameters:")
    print(f"Variable: {variable}")
    print(f"Level: {level}")
    print(f"Robust scaling: {robust}")
    print(f"Max time steps: {max_steps}")
    
    return {
        'variable': variable,
        'level': level,
        'robust': robust,
        'max_steps': max_steps
    }

# Example usage:

    # Assuming example_batch is your loaded dataset
plot_params = select_plot_parameters(weather_data)
    
    # Then use these parameters with your plotting code
    # Example:
    # plot_data(example_batch, 
    #           variable=plot_params['variable'],
    #           level=plot_params['level'],
    #           robust=plot_params['robust'],
    #           max_steps=plot_params['max_steps'])


Available variables:
1. land_sea_mask
2. geopotential_at_surface
3. day_progress_cos
4. day_progress_sin
5. 2m_temperature
6. sea_surface_temperature
7. mean_sea_level_pressure
8. 10m_v_component_of_wind
9. total_precipitation_12hr
10. 10m_u_component_of_wind
11. u_component_of_wind
12. specific_humidity
13. temperature
14. vertical_velocity
15. v_component_of_wind
16. geopotential
17. year_progress_cos
18. year_progress_sin


Select variable (number):  1



Available levels:
1. 50
2. 100
3. 150
4. 200
5. 250
6. 300
7. 400
8. 500
9. 600
10. 700
11. 850
12. 925
13. 1000


Select level (number):  11
Use robust scaling (True/False)?  True


  max_time = example_batch.dims["time"]


Enter max time steps to plot (1-3):  3



Selected parameters:
Variable: land_sea_mask
Level: 850
Robust scaling: True
Max time steps: 3


In [24]:
# @title Plot example data
example_batch = weather_data
plot_example_variable,plot_example_level,plot_example_max_steps,plot_example_robust = tuple(plot_params.values())

plot_size = 7

data = {
    " ": scale(select(example_batch, plot_example_variable, plot_example_level, plot_example_max_steps),
              robust=plot_example_robust),
}
fig_title = plot_example_variable
if "level" in example_batch[plot_example_variable].coords:
  fig_title += f" at {plot_example_level} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust)


In [26]:
# @title Load normalization data

with gcs_bucket.blob(dir_prefix+"stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/min_by_level.nc").open("rb") as f:
  min_by_level = xarray.load_dataset(f).compute()

In [28]:
params_file_options = [
            name for blob in gcs_bucket.list_blobs(prefix=(dir_prefix+"params/"))
            if (name := blob.name.removeprefix(dir_prefix+"params/"))]
selected_file = params_file_options[0]
with gcs_bucket.blob(dir_prefix + f"params/{selected_file}").open("rb") as f:
    ckpt = checkpoint.load(f, gencast.CheckPoint)

In [34]:
from dataclasses import is_dataclass, asdict
import pandas as pd
import numpy as np

# Define default task_config_dict if not present in model_config
task_config1 = model_config.get('task_config')


if is_dataclass(task_config1):
    task_config_dict = asdict(task_config1)
elif isinstance(task_config1, dict):
    task_config_dict = task_config1
else:
    task_config_dict = {
        "input_variables": ['2m_temperature', 'mean_sea_level_pressure', '10m_v_component_of_wind', '10m_u_component_of_wind', 'sea_surface_temperature', 'temperature', 'geopotential', 'u_component_of_wind', 'v_component_of_wind', 'vertical_velocity', 'specific_humidity', 'year_progress_sin', 'year_progress_cos', 'day_progress_sin', 'day_progress_cos', 'geopotential_at_surface', 'land_sea_mask'],
        "target_variables": ['2m_temperature', 'mean_sea_level_pressure', '10m_v_component_of_wind', '10m_u_component_of_wind', 'total_precipitation_12hr', 'sea_surface_temperature', 'temperature', 'geopotential', 'u_component_of_wind', 'v_component_of_wind', 'vertical_velocity', 'specific_humidity'],
        "forcing_variables": ['year_progress_sin', 'year_progress_cos', 'day_progress_sin', 'day_progress_cos'],
        "pressure_levels": [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000],
        "input_duration": "24h"
    }

# Use Pandas timedelta for lead times
train_lead_time = pd.Timedelta("12h")
train_lead_times = slice(train_lead_time, train_lead_time)

# First ensure time coordinate is properly formatted
if 'time' in example_batch.coords:
    # Convert time to pandas Timedelta if it's numeric
    if np.issubdtype(example_batch.coords['time'].dtype, np.number):
        example_batch.coords['time'] = pd.to_timedelta(example_batch.coords['time'], unit='h')
    # Ensure time is in nanoseconds for arithmetic operations
    example_batch.coords['time'] = example_batch.coords['time'].astype('timedelta64[ns]')

# Create a modified version of the extraction function
def safe_extract_inputs_targets_forcings(dataset, **kwargs):
    """Wrapper to handle time coordinate arithmetic safely"""
    # Make a copy to avoid modifying the original
    dataset = dataset.copy()
    
    # Convert time to numpy timedelta64 if needed
    if 'time' in dataset.coords:
        time_values = dataset.coords['time'].values
        if not isinstance(time_values, np.ndarray):
            time_values = np.array(time_values)
        dataset.coords['time'] = time_values.astype('timedelta64[ns]')
    
    return data_utils.extract_inputs_targets_forcings(dataset, **kwargs)

# Extract training data
try:
    train_inputs, train_targets, train_forcings = safe_extract_inputs_targets_forcings(
        example_batch, 
        target_lead_times=train_lead_times,
        **task_config_dict
    )
except Exception as e:
    print(f"Error during extraction: {e}")
    # Fallback to manual extraction if needed
    print("Attempting manual extraction...")
    train_inputs = example_batch.isel(time=0)
    train_targets = example_batch.isel(time=1)
    train_forcings = xarray.Dataset()  # Empty dataset if no forcings

# Use remaining timesteps for evaluation
max_lead_time = (example_batch.dims['time'] - 2) * 12
eval_lead_times = slice(pd.Timedelta("12h"), pd.Timedelta(f"{max_lead_time}h"))

# Extract evaluation data
eval_inputs, eval_targets, eval_forcings = safe_extract_inputs_targets_forcings(
    example_batch,
    target_lead_times=eval_lead_times,
    **task_config_dict
)

# Print shapes/dims for verification
print("\nData Shapes:")
print("All Examples:  ", example_batch.dims.mapping)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)

  max_lead_time = (example_batch.dims['time'] - 2) * 12



Data Shapes:
All Examples:   {'lon': 1440, 'lat': 721, 'level': 13, 'batch': 1, 'time': 3}
Train Inputs:   {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Train Targets:  {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 13}
Train Forcings: {'batch': 1, 'time': 1, 'lon': 1440}
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Targets:   {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 13}
Eval Forcings:  {'batch': 1, 'time': 1, 'lon': 1440}


In [36]:
ckpt

CheckPoint(description='\n        GenCast model at 0.25deg resolution with 13 pressure levels and a 6 times\n        refined icosahedral mesh. This model is trained on ERA5 data from 1979\n        to 2018 (inclusive), and can be causally evaluated on 2019\n        and later years. This model was described in the paper\n        `GenCast: Diffusion-based ensemble forecasting for medium-range weather`\n        (https://arxiv.org/abs/2312.15796).\n        ', license='\nThe model weights are licensed under the Creative Commons\nAttribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You\nmay obtain a copy of the License at:\nhttps://creativecommons.org/licenses/by-nc-sa/4.0/.\nThe weights were trained on ERA5 data, see README for attribution statement.\n', params={'fourier_features_mlp/~/mlp/~/linear_0': {'b': array([0.7249849 , 0.98678654, 0.87844795, 0.35235855, 0.563566  ,
       1.0055103 , 0.9701163 , 1.2390671 , 0.8369621 , 0.8546867 ,
       0.5343235 , 0.9168723 , 

In [40]:
# @title Build jitted functions with proper configuration handling




def construct_wrapped_gencast():
    """Constructs and wraps the GenCast Predictor with proper config handling."""
   
    
    predictor = gencast.GenCast(
        sampler_config=ckpt.sampler_config,
        task_config=ckpt.task_config,
        denoiser_architecture_config=ckpt.denoiser_architecture_config,
        noise_config=ckpt.noise_config,
        noise_encoder_config=ckpt.noise_encoder_config,
    )

    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level,
    )

    predictor = nan_cleaning.NaNCleaner(
        predictor=predictor,
        reintroduce_nans=True,
        fill_value=min_by_level,
        var_to_clean='sea_surface_temperature',
    )

    return predictor


@hk.transform_with_state
def run_forward(inputs, targets_template, forcings):
    predictor = construct_wrapped_gencast()
    return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(inputs, targets, forcings):
    predictor = construct_wrapped_gencast()
    loss, diagnostics = predictor.loss(inputs, targets, forcings)
    return xarray_tree.map_structure(
        lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
        (loss, diagnostics),
    )

# Initialize model parameters and state
init_jitted = jax.jit(loss_fn.init)
params, state = init_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings,
)

# Create jitted functions
loss_fn_jitted = jax.jit(
    lambda rng, i, t, f: loss_fn.apply(params, state, rng, i, t, f)[0]
)
grads_fn_jitted = jax.jit(grads_fn)
run_forward_jitted = jax.jit(
    lambda rng, i, t, f: run_forward.apply(params, state, rng, i, t, f)[0]
)

# Create pmapped version
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")
print(f"Number of local devices: {len(jax.local_devices())}")

# @title Autoregressive rollout with error handling

try:
    num_ensemble_members = 1
    rng = jax.random.PRNGKey(0)
    rngs = np.stack(
        [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)
    
    chunks = []
    for chunk in rollout.chunked_prediction_generator_multiple_runs(
        predictor_fn=run_forward_pmap,
        rngs=rngs,
        inputs=eval_inputs,
        targets_template=eval_targets * np.nan,
        forcings=eval_forcings,
        num_steps_per_chunk=1,
        num_samples=num_ensemble_members,
        pmap_devices=jax.local_devices()
    ):
        chunks.append(chunk)
    
    predictions = xarray.combine_by_coords(chunks)
    print("Rollout completed successfully!")
except Exception as e:
    print(f"Error during rollout: {str(e)}")
    predictions = None

KeyError: "No variable named slice(None, 4, None). Variables on the dataset include ['2m_temperature', 'mean_sea_level_pressure', '10m_v_component_of_wind', '10m_u_component_of_wind', 'sea_surface_temperature', ..., 'land_sea_mask', 'lon', 'lat', 'level', 'time']"