In [5]:
import dataclasses
import functools
import haiku as hk
import matplotlib
from typing import Optional
import jax
import numpy as np
import xarray
import random

from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree

In [9]:
# Load an existing pretrained checkpoint: its params plus the configs it was built with.
# NOTE(review): the configuration cell below sets `params = None` and rebuilds
# `task_config` / `model_config`, so the values loaded here are overwritten
# before anything uses them.
params_file = (
    "params_graphCast-ERA5_1979-2017-resolution_"
    "0.25-pressure_levels_37-mesh_2to6-precipitation_input_and_output.npz"
)
with open(f"params/{params_file}", "rb") as ckpt_file:
    ckpt = checkpoint.load(ckpt_file, graphcast.CheckPoint)

params, model_config, task_config = (
    ckpt.params, ckpt.model_config, ckpt.task_config)

In [10]:
# Model hyper-parameters (fed to graphcast.ModelConfig).
GNN_MSG_STEPS = 16
HIDDEN_LAYERS = 1
LATENT_SIZE = 512
MESH_SIZE = 6
RESOLUTION = 0.25
RADIUS_QUERY_FRACTION_EDGE_LENGTH = 0.5999
MESH2GRID_NORMALIZATION_FACTOR = 0.6180

# Weather variables (fed to graphcast.TaskConfig).
PRESSURE_LEVELS = (1000,)

# Catalogue of atmospheric variable names. The task config in this notebook
# does not reference it; it is kept for reference only.
ALL_ATMOSPHERIC_VARS = (
    "potential_vorticity", "specific_rain_water_content",
    "specific_snow_water_content", "geopotential", "temperature",
    "u_component_of_wind", "v_component_of_wind", "specific_humidity",
    "vertical_velocity", "vorticity", "divergence", "relative_humidity",
    "ozone_mass_mixing_ratio", "specific_cloud_liquid_water_content",
    "specific_cloud_ice_water_content", "fraction_of_cloud_cover",
)

# NOTE(review): "temperature" and "specific_humidity" are grouped as surface
# inputs, but the extracted inputs carry a `level` dimension. They therefore
# look like pressure-level variables — confirm the grouping/naming.
INPUT_SURFACE_VARS = ("temperature", "specific_humidity", "total_precipitation_6hr")
TARGET_SURFACE_VARS = ("total_precipitation_6hr",)
TARGET_ATMOSPHERIC_VARS = ("specific_humidity", "temperature")

# Forcings. The external ones are read from the dataset; the "generated" ones
# are presumably derived from the timestamps — confirm in data_utils.
EXTERNAL_FORCING_VARS = ("toa_incident_solar_radiation",)
GENERATED_FORCING_VARS = (
    "year_progress_sin", "year_progress_cos",
    "day_progress_sin", "day_progress_cos",
)
FORCING_VARS = EXTERNAL_FORCING_VARS + GENERATED_FORCING_VARS

# Static fields, supplied as inputs only.
STATIC_VARS = ("geopotential_at_surface", "land_sea_mask")

# Build a fresh task and model config. Setting `params = None` makes the init
# cell further down draw random weights instead of using any loaded checkpoint.
params = None
state = {}

# noinspection PyArgumentList
task_config = graphcast.TaskConfig(
    input_variables=INPUT_SURFACE_VARS + FORCING_VARS + STATIC_VARS,
    target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS,
    # Yields 2 input time steps (see the extracted input shapes below).
    input_duration="12h",
)

# noinspection PyArgumentList
model_config = graphcast.ModelConfig(
    resolution=RESOLUTION,
    mesh_size=MESH_SIZE,
    latent_size=LATENT_SIZE,
    gnn_msg_steps=GNN_MSG_STEPS,
    hidden_layers=HIDDEN_LAYERS,
    radius_query_fraction_edge_length=RADIUS_QUERY_FRACTION_EDGE_LENGTH,
    mesh2grid_edge_normalization_factor=MESH2GRID_NORMALIZATION_FACTOR,
)

In [None]:
# Load the dataset
def parse_file_parts(file_name):
    """Turn a `key-value_key-value_...` file stem into a {key: value} dict.

    Parts are separated by "_". Each part is split on its FIRST "-" only, so
    values such as dates ("2022-01-01") keep their inner dashes. A part with
    no "-" raises ValueError.
    """
    key_value_pairs = (part.split("-", 1) for part in file_name.split("_"))
    return {key: value for key, value in key_value_pairs}

dataset_file = "source-era5_date-2022-01-01_res-0.25_levels-37_steps-01.nc"
with open(f"data/{dataset_file}", "rb") as dataset_fh:
    example_batch = xarray.load_dataset(dataset_fh).compute()

# Two time steps are consumed as inputs; at least one more is needed as a target.
assert example_batch.sizes["time"] >= 3

# Echo the metadata encoded in the file name, e.g. "source: era5, date: ...".
dataset_parts = parse_file_parts(dataset_file.removesuffix(".nc"))
print(", ".join(f"{key}: {value}" for key, value in dataset_parts.items()))

In [12]:
# Extract inputs, targets, and forcings for the train and eval splits.
train_steps = 1
eval_steps = 1


def extract_split(batch, num_steps):
    """Split `batch` into (inputs, targets, forcings) with `num_steps` targets.

    Thin wrapper over `data_utils.extract_inputs_targets_forcings` that uses the
    notebook-level `task_config`. Targets cover lead times 6h .. num_steps*6h.
    """
    return data_utils.extract_inputs_targets_forcings(
        batch,
        target_lead_times=slice("6h", f"{num_steps * 6}h"),
        **dataclasses.asdict(task_config))


train_inputs, train_targets, train_forcings = extract_split(example_batch, train_steps)
eval_inputs, eval_targets, eval_forcings = extract_split(example_batch, eval_steps)

# Using `Dataset.dims` as a mapping is deprecated in recent xarray releases.
# `.sizes` is the supported dim-name -> length mapping.
for label, dataset in [
        ("All Examples:", example_batch),
        ("Train Inputs:", train_inputs),
        ("Train Targets:", train_targets),
        ("Train Forcings:", train_forcings),
        ("Eval Inputs:", eval_inputs),
        ("Eval Targets:", eval_targets),
        ("Eval Forcings:", eval_forcings)]:
    print(f"{label:<15}", dict(dataset.sizes))

All Examples:   {'lon': 1440, 'lat': 721, 'level': 37, 'time': 3, 'batch': 1}
Train Inputs:   {'batch': 1, 'time': 2, 'level': 1, 'lat': 721, 'lon': 1440}
Train Targets:  {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 1}
Train Forcings: {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440}
Eval Inputs:    {'batch': 1, 'time': 2, 'level': 1, 'lat': 721, 'lon': 1440}
Eval Targets:   {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 1}
Eval Forcings:  {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440}


In [13]:
# Load the normalization statistics consumed by `normalization.InputsAndResiduals`.
def _load_stats(path):
    """Read the NetCDF file at `path` fully into memory as an xarray.Dataset."""
    with open(path, "rb") as stats_fh:
        return xarray.load_dataset(stats_fh).compute()


diffs_stddev_by_level = _load_stats("stats/diffs_stddev_by_level.nc")
mean_by_level = _load_stats("stats/mean_by_level.nc")
stddev_by_level = _load_stats("stats/stddev_by_level.nc")

In [14]:
# Build jitted functions, and possibly initialize random weights
def construct_wrapped_graphcast(
        model_config: graphcast.ModelConfig,
        task_config: graphcast.TaskConfig):
    """Build the full GraphCast predictor stack.

    The layers, innermost first:
      1. `graphcast.GraphCast`: the one-step predictor.
      2. `casting.Bfloat16Cast`: casts float32 inputs to BFloat16 and the
         outputs back to float32.
      3. `normalization.InputsAndResiduals`: normalizes inputs/targets with the
         module-level statistics, so the BFloat16 cast operates on normalized
         data.
      4. `autoregressive.Predictor`: rolls the one-step model out into
         trajectories, with gradient checkpointing enabled.

    Reads the global `diffs_stddev_by_level`, `mean_by_level` and
    `stddev_by_level`, so the stats-loading cell must have run first.
    """
    one_step = graphcast.GraphCast(model_config, task_config)
    bf16_model = casting.Bfloat16Cast(one_step)
    normalized_model = normalization.InputsAndResiduals(
        bf16_model,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)
    return autoregressive.Predictor(normalized_model, gradient_checkpointing=True)


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """Haiku-transformed forward pass of the wrapped GraphCast predictor."""
    wrapped_model = construct_wrapped_graphcast(model_config, task_config)
    outputs = wrapped_model(
        inputs, targets_template=targets_template, forcings=forcings)
    return outputs


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
    """Haiku-transformed loss of the wrapped predictor.

    Returns `(loss, diagnostics)`. Every leaf is reduced with `.mean()` and
    unwrapped to a raw JAX array, so `grads_fn` can differentiate it with
    `jax.value_and_grad`.
    """
    wrapped_model = construct_wrapped_graphcast(model_config, task_config)
    loss, diagnostics = wrapped_model.loss(inputs, targets, forcings)

    def _reduce_and_unwrap(leaf):
        return xarray_jax.unwrap_data(leaf.mean(), require_jax=True)

    return xarray_tree.map_structure(_reduce_and_unwrap, (loss, diagnostics))


def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
    """Return (loss, diagnostics, next_state, grads) for one batch.

    `grads` is d(loss)/d(params). The loss is evaluated with a fixed PRNG key
    (0), so repeated calls are deterministic.
    """
    def _loss_with_aux(p, s, batch_inputs, batch_targets, batch_forcings):
        (loss_value, diag), new_state = loss_fn.apply(
            p, s, jax.random.PRNGKey(0), model_config, task_config,
            batch_inputs, batch_targets, batch_forcings)
        # value_and_grad(has_aux=True) expects (value, aux).
        return loss_value, (diag, new_state)

    value_and_grad_fn = jax.value_and_grad(_loss_with_aux, has_aux=True)
    (loss, (diagnostics, next_state)), grads = value_and_grad_fn(
        params, state, inputs, targets, forcings)
    return loss, diagnostics, next_state, grads


# Jax does not cope well with the config objects as ordinary jit arguments.
# Binding them through functools.partial (rather than capturing them in a
# closure) makes jax invalidate its jit cache whenever the configs change.
def with_configs(fn):
    """Return `fn` with the global `model_config`/`task_config` bound as kwargs.

    The config values are captured when `with_configs` is called.
    """
    bound_fn = functools.partial(
        fn, model_config=model_config, task_config=task_config)
    return bound_fn


# Bind params and state once so the call sites below stay short.
def with_params(fn):
    """Return `fn` with the global `params` and `state` bound as kwargs.

    The values are captured when `with_params` is called, not when the result
    is invoked. Call it only after `params` has been initialized.
    """
    bound_fn = functools.partial(fn, params=params, state=state)
    return bound_fn


# The models here keep no state (it stays empty), so only the first element of
# the (output, state) result is useful. The rollout code requires this form,
# and it is simpler to work with in general.
def drop_state(fn):
    """Wrap `fn` so a keyword-only call returns just `fn(**kwargs)[0]`."""
    def _without_state(**kwargs):
        return fn(**kwargs)[0]
    return _without_state

In [15]:
def _jit_with_configs(fn):
    """Bind the current configs to `fn` and wrap the result in `jax.jit`."""
    return jax.jit(with_configs(fn))


init_jitted = _jit_with_configs(run_forward.init)

# No pretrained params: initialize random weights from a fixed key.
if params is None:
    print("\nInitializing the model params...")
    params, state = init_jitted(
        rng=jax.random.PRNGKey(0),
        inputs=train_inputs,
        targets_template=train_targets,
        forcings=train_forcings,
    )

# Build these wrappers only after the init above: `with_params` captures the
# values of `params`/`state` at the moment it is called.
loss_fn_jitted = drop_state(with_params(_jit_with_configs(loss_fn.apply)))
grads_fn_jitted = with_params(_jit_with_configs(grads_fn))
run_forward_jitted = drop_state(with_params(_jit_with_configs(run_forward.apply)))


Initializing the model params...


  scan_length = targets_template.dims['time']


  Converting all input data into flat vectors
  Transfer data for the grid to the mesh
  Run message passing in the multimesh
  Transfer data from the mesh to the grid
  Convert output flat vectors for the grid nodes to the format of the output


  num_inputs = inputs.dims['time']


In [None]:
# Loss, gradients and a rollout on the training split.
# NOTE: this cell only *evaluates* the loss and its gradients. It takes no
# optimizer step, so `params` are not updated here.
# @title Loss computation (autoregressive loss over multiple steps)
print("\nComputing the training loss...")
loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
print("Loss:", float(loss))

# @title Gradient computation (backprop through time)
print("\nGradient computation (backprop through time)...")
loss, diagnostics, next_state, grads = grads_fn_jitted(
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
# Average, over all parameter leaves, of each leaf's mean absolute gradient.
mean_grad = np.mean([np.abs(g).mean() for g in jax.tree_util.tree_leaves(grads)])
print(f"Loss: {float(loss):.4f}, Mean |grad|: {mean_grad:.6f}")

# @title Autoregressive rollout (keep the loop in JAX)
print("\nAutoregressive rollout...")
# `.sizes` replaces the deprecated mapping use of `Dataset.dims`.
print("Inputs:  ", dict(train_inputs.sizes))
print("Targets: ", dict(train_targets.sizes))
print("Forcings:", dict(train_forcings.sizes))

predictions = run_forward_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    # NaN-filled so no target values can leak into the prediction; presumably
    # only the template's structure is used — confirm in autoregressive.py.
    targets_template=train_targets * np.nan,
    forcings=train_forcings)

In [None]:
# Evaluation rollout, run in chunks by `rollout.chunked_prediction`.
# `.sizes` replaces the deprecated mapping use of `Dataset.dims`.
print("Inputs:  ", dict(eval_inputs.sizes))
print("Targets: ", dict(eval_targets.sizes))
print("Forcings:", dict(eval_forcings.sizes))

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)

In [None]:
def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.DataArray:
  """Extract one variable from `data` and trim it for display.

  Args:
    data: dataset containing `variable`.
    variable: name of the data variable to extract.
    level: if given and the variable has a `level` coordinate, select that level.
    max_steps: if given and smaller than the `time` length, keep only the first
      `max_steps` time steps.

  Returns:
    The selected `xarray.DataArray` (not a Dataset, as the old annotation
    claimed). When a `batch` dimension is present, only its first element is
    kept.
  """
  selected = data[variable]
  if "batch" in selected.dims:
    selected = selected.isel(batch=0)
  if (max_steps is not None and "time" in selected.sizes
      and max_steps < selected.sizes["time"]):
    selected = selected.isel(time=slice(0, max_steps))
  if level is not None and "level" in selected.coords:
    selected = selected.sel(level=level)
  return selected

def scale(
    data: xarray.DataArray,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.DataArray, matplotlib.colors.Normalize, str]:
  """Compute a colour normalization and a colormap name for `data`.

  Args:
    data: values to scale (the callers pass DataArrays); NaNs are ignored.
    center: if given, widen the range so it is symmetric around this value and
      use a diverging colormap.
    robust: use the 2nd-98th percentile range instead of the full min-max.

  Returns:
    `(data, norm, cmap_name)`, with `data` passed through unchanged.
  """
  low_pct, high_pct = (2, 98) if robust else (0, 100)
  vmin = np.nanpercentile(data, low_pct)
  vmax = np.nanpercentile(data, high_pct)
  if center is not None:
    half_range = max(vmax - center, center - vmin)
    vmin, vmax = center - half_range, center + half_range
  cmap_name = "RdBu_r" if center is not None else "viridis"
  return data, matplotlib.colors.Normalize(vmin, vmax), cmap_name

plot_pred_variable = "total_precipitation_6hr"
plot_pred_level = 1000
plot_pred_robust = False
plot_pred_max_steps = 1
# `plot_pred_max_steps` is a plain int: the old `.value` access raised
# AttributeError. `.sizes` replaces the deprecated mapping use of `Dataset.dims`.
plot_max_steps = min(predictions.sizes["time"], plot_pred_max_steps)

# Select each field once and reuse it (previously selected four times).
target_field = select(eval_targets, plot_pred_variable, plot_pred_level, plot_max_steps)
prediction_field = select(predictions, plot_pred_variable, plot_pred_level, plot_max_steps)

data = {
    "Targets": scale(target_field, robust=plot_pred_robust),
    "Predictions": scale(prediction_field, robust=plot_pred_robust),
    "Diff": scale(target_field - prediction_field, robust=plot_pred_robust, center=0),
}

# Pull the DataArrays out of the (data, norm, cmap) tuples. Use a new name for
# the predictions so the `predictions` Dataset is not overwritten; otherwise a
# re-run of this cell would fail.
targets = data["Targets"][0]
predictions_da = data["Predictions"][0]
diff = data["Diff"][0]

# Number of sets and values per set.
num_sets = 3
values_per_set = 10

# Fixed seed so the sampled points are reproducible across runs.
RANDOM_SEED = 0
sample_rng = random.Random(RANDOM_SEED)

num_lat = targets.sizes["lat"]
num_lon = targets.sizes["lon"]
# Random starting indices. Clamp at 0 so a grid smaller than `values_per_set`
# does not give randint a negative upper bound.
random_indices = [
    (sample_rng.randint(0, max(0, num_lat - values_per_set)),
     sample_rng.randint(0, max(0, num_lon - values_per_set)))
    for _ in range(num_sets)]

fig_title = plot_pred_variable
output = fig_title + "\n\n"
output += f"{'Lat':>8} {'Lon':>8} {'Target':>10} {'Prediction':>12} {'Diff':>10}\n"

for start_i, start_j in random_indices:
    output += f"\nStarting at index (lat, lon): ({start_i}, {start_j})\n"
    # Walk `values_per_set` points diagonally: lat and lon indices advance together.
    for k in range(values_per_set):
        i = start_i + k
        j = start_j + k
        if i >= num_lat or j >= num_lon:
            break
        # Index by dimension name (first time step) rather than by position.
        point = {"time": 0, "lat": i, "lon": j}
        lat = float(targets.lat[i])
        lon = float(targets.lon[j])
        target_value = float(targets.isel(point))
        prediction_value = float(predictions_da.isel(point))
        diff_value = float(diff.isel(point))

        # General format: small magnitudes (e.g. precipitation totals,
        # presumably in metres) would otherwise print as 0.00.
        output += (f"{lat:8.2f} {lon:8.2f} {target_value:10.4g} "
                   f"{prediction_value:12.4g} {diff_value:10.4g}\n")

output += "\n ----------------------------------------------------- \n"
print(output)

# Append to predictions.txt in the working directory. The old path
# "/predictions.txt" pointed at the filesystem root.
output_path = "predictions.txt"
with open(output_path, "a") as out_file:
    out_file.write(output)

print(f"Output saved to {output_path}")

Inputs:   {'batch': 1, 'time': 2, 'level': 1, 'lat': 721, 'lon': 1440}
Targets:  {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440, 'level': 1}
Forcings: {'batch': 1, 'time': 1, 'lat': 721, 'lon': 1440}


  num_target_steps = targets_template.dims["time"]
  scan_length = targets_template.dims['time']


  Converting all input data into flat vectors
  Transfer data for the grid to the mesh
  Run message passing in the multimesh
  Transfer data from the mesh to the grid
  Convert output flat vectors for the grid nodes to the format of the output


  num_inputs = inputs.dims['time']
