In [1]:
import sys, os 
sys.path.insert(0, os.path.dirname(os.getcwd()))

# @title Imports
import dataclasses
import datetime
import functools
import math
import re
from typing import Optional
from glob import glob

import cartopy.crs as ccrs
#from google.cloud import storage
from wofscast import autoregressive
from wofscast import casting
from wofscast import checkpoint
from wofscast import data_utils
from wofscast import my_graphcast as graphcast
from wofscast import normalization
from wofscast import rollout
from wofscast import xarray_jax
from wofscast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray #as xr

# For training the weights!
import optax
import jax
import numpy as np
import jax.numpy as jnp

from jax import device_put

from jax import pmap, device_put, local_device_count

# Check available devices
print("Available devices:", jax.devices())

from jax import tree_util


Available devices: [cuda(id=0), cuda(id=1)]


In [2]:
# Notes on GraphCast Changes to run with WoFS data.

# 1. Introduced time dimension with timedeltas dataset
# 2. Introduce level dimension to the dataset 
# 3. Added try/excepts for xarray_jax to avoid PyTree errors about registry 
# 4. Cuda-enabled jaxlib error; had to install jax with this command for Cuda 11.8 
# pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 5. Need to add the forcing variables (radiation) to the dataset
# 6. Xarray 2024.1.1 raised TracerArrayConversionError, downgraded to 2023.7.0, the version 
#    used in colab in the demo notebook.


""" usage: stdbuf -oL python -u train_graphcast_with_wofs.py > & log_trainer_models & """

' usage: stdbuf -oL python -u train_graphcast_with_wofs.py > & log_trainer_models & '

In [3]:
def count_total_parameters(params_dict):
    """
    Sum the `size` of every leaf in a (possibly nested) dictionary of parameters.

    Args:
    - params_dict (dict): Nested dictionary whose non-dict values expose a
      `size` attribute (e.g. arrays).

    Returns:
    - int: The total number of parameters over all leaves.
    """
    total = 0
    # Explicit work-list instead of a recursive helper; every dict found is
    # queued and every non-dict value contributes its `size`.
    pending = [params_dict]
    while pending:
        current = pending.pop()
        for value in current.values():
            if isinstance(value, dict):
                pending.append(value)
            else:
                total += value.size
    return total

def flatten_dict(d, parent_key='', sep='//'):
    """Flatten a nested dict into a single-level dict.

    Keys of nested dictionaries are joined to their parents' keys with `sep`;
    when `parent_key` is empty the top-level key is used unchanged.
    """
    flat = {}
    for name, value in d.items():
        full_key = name if not parent_key else f"{parent_key}{sep}{name}"
        if isinstance(value, dict):
            # Merge the flattened sub-dictionary under the extended key prefix.
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat

def save_model_params(d, file_path):
    """Flatten the nested parameter dict `d` and write it to `file_path` with np.savez."""
    arrays_by_key = {}
    for key, value in flatten_dict(d).items():
        # JAX arrays are converted to NumPy before saving; other values are
        # handed to np.savez unchanged.
        arrays_by_key[key] = np.array(value) if isinstance(value, jnp.ndarray) else value
    np.savez(file_path, **arrays_by_key)


# Directory and file name for saved model parameters.
# Note: `base_path` and `params_path` are reassigned by later cells.
base_path = '/work/mflora/wofs-cast-data/model'
params_path = os.path.join(base_path, 'params.npz')

###save_model_params(params, params_path)


def unflatten_dict(d, sep='//'):
    """Rebuild a nested dict from a flat dict whose keys are joined with `sep`."""
    nested = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        cursor = nested
        # Walk (creating as needed) the chain of intermediate dictionaries.
        for part in parents:
            cursor = cursor.setdefault(part, {})
        cursor[leaf] = value
    return nested

def load_model_params(file_path):
    """Load an .npz file, convert each entry to a JAX array and rebuild the nested dict."""
    jax_dict = {}
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays;
    # only use this on files from a trusted source.
    with np.load(file_path, allow_pickle=True) as npz_file:
        for name, array in npz_file.items():
            # Convert NumPy arrays back to JAX arrays
            jax_dict[name] = jnp.array(array)
    return unflatten_dict(jax_dict)

###params = load_model_params(params_path)


In [4]:
# Model hyperparameters; each of these is passed to graphcast.ModelConfig
# when `model_config` is built.
mesh_size = 6
latent_size = 64
gnn_msg_steps = 8
hidden_layers = 2

In [5]:
# Input, target and forcing variable names; these lists are passed to
# graphcast.TaskConfig.
input_variables = ['U', 'V', 'W', 'T']#, 'P', 'REFL_10CM', 'UP_HELI_MAX']
target_variables = ['U', 'V', 'W', 'T']#, 'P', 'REFL_10CM', 'UP_HELI_MAX']
forcing_variables = ["XLAND"]

# Passed to graphcast.GraphCast alongside the loss weights.
# Presumably the variables without a 'level' dimension — TODO confirm.
vars_2D = [] #['UP_HELI_MAX']

# Weights used in the loss equation.
VARIABLE_WEIGHTS = {v : 1.0 for v in target_variables}
#VARIABLE_WEIGHTS['REFL_10CM'] = 2.0
#VARIABLE_WEIGHTS['UP_HELI_MAX'] = 2.0

# Not pressure levels, but just vertical array indices at the moment. 
pressure_levels = np.arange(0, 40) #list(np.arange(0,40,2))
# Passed through to graphcast.ModelConfig.
radius_query_fraction_edge_length=5

# Input window duration (passed to graphcast.TaskConfig) and training lead time.
# NOTE(review): this comment previously described "the past 10 minutes" and
# lead times of 5-30 min in 5-min steps, which does not match the values
# below — confirm which is intended. `train_lead_times` is not referenced
# in the cells visible here.
input_duration = '20min'
train_lead_times = '10min' 

In [6]:
# Model architecture configuration, assembled from the hyperparameters
# defined in the cells above.
model_config = graphcast.ModelConfig(
      resolution=0,
      mesh_size=mesh_size,
      latent_size=latent_size,
      gnn_msg_steps=gnn_msg_steps,
      hidden_layers=hidden_layers,
      radius_query_fraction_edge_length=radius_query_fraction_edge_length)

# Task configuration: variable lists, vertical levels and input duration.
task_config = graphcast.TaskConfig(
      input_variables=input_variables,
      target_variables=target_variables,
      forcing_variables=forcing_variables,
      pressure_levels=pressure_levels,
      input_duration=input_duration,
  )

In [7]:
# Load the data 
base_path = '/work/mflora/wofs-cast-data/train_datasets'

# load_dataset reads each file fully into memory.
train_inputs = xarray.load_dataset(os.path.join(base_path, 'train_inputs.nc'))
train_targets = xarray.load_dataset(os.path.join(base_path, 'train_targets.nc'))
train_forcings = xarray.load_dataset(os.path.join(base_path, 'train_forcings.nc'))

# Keep only the first time index of the targets and forcings
# (the list index [0] preserves the 'time' dimension with length 1).
train_targets = train_targets.isel(time = [0])
train_forcings= train_forcings.isel(time = [0])

# Put every variable into (batch, time, lat, lon[, level]) dimension order.
train_inputs = train_inputs.transpose("batch", 'time', 'lat', 'lon', 'level')
train_targets = train_targets.transpose("batch", 'time', 'lat', 'lon', 'level')
train_forcings = train_forcings.transpose("batch", 'time', 'lat', 'lon')

# The "eval" datasets are simply the first training sample (batch index 0),
# not an independent hold-out set.
eval_inputs = train_inputs.isel(batch=[0])
eval_targets = train_targets.isel(batch=[0])
eval_forcings = train_forcings.isel(batch=[0])

# Print the dimension sizes of every dataset as a sanity check.
print("*"*80)
print("Train Inputs:  ", train_inputs.dims.mapping)
print("Train Targets: ", train_targets.dims.mapping)
print("Train Forcings:", train_forcings.dims.mapping)
print("*"*80)
print("Eval Inputs:   ", eval_inputs.dims.mapping)
print("Eval Targets:  ", eval_targets.dims.mapping)
print("Eval Forcings: ", eval_forcings.dims.mapping)

********************************************************************************
Train Inputs:   {'batch': 204, 'time': 2, 'level': 40, 'lat': 150, 'lon': 150}
Train Targets:  {'batch': 204, 'time': 1, 'level': 40, 'lat': 150, 'lon': 150}
Train Forcings: {'batch': 204, 'time': 1, 'lat': 150, 'lon': 150}
********************************************************************************
Eval Inputs:    {'batch': 1, 'time': 2, 'lat': 150, 'lon': 150, 'level': 40}
Eval Targets:   {'batch': 1, 'time': 1, 'lat': 150, 'lon': 150, 'level': 40}
Eval Forcings:  {'batch': 1, 'time': 1, 'lat': 150, 'lon': 150}


In [8]:
# Load the normalization datasets
# These three datasets are passed to normalization.InputsAndResiduals when the
# predictor is constructed.
base_path = '/work/mflora/wofs-cast-data/normalization_stats'
mean_by_level = xarray.load_dataset(os.path.join(base_path, 'mean_by_level.nc'))
stddev_by_level = xarray.load_dataset(os.path.join(base_path, 'stddev_by_level.nc'))
diffs_stddev_by_level = xarray.load_dataset(os.path.join(base_path, 'diffs_stddev_by_level.nc'))

In [9]:
# @title Build jitted functions, and possibly initialize random weights

# Load saved model params and test! 
#model_params = None #load_model_params(params_path)
#state = {}

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
    """Builds the GraphCast one-step predictor and applies the wrapper stack.

    Wrappers are applied innermost-first: BFloat16 casting, then input/residual
    normalization, then the autoregressive trajectory predictor. The loss
    weights, 2-D variable list and normalization statistics are read from
    module-level variables.
    """
    # Deeper one-step predictor.
    one_step_model = graphcast.GraphCast(
        model_config, task_config, VARIABLE_WEIGHTS, vars_2D)

    # Handle conversion from/to float32 to/from BFloat16 around the model.
    bf16_model = casting.Bfloat16Cast(one_step_model)

    # Normalization wraps the cast, so casting to/from BFloat16 happens after
    # normalization has been applied to the inputs/targets.
    normalized_model = normalization.InputsAndResiduals(
        bf16_model,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)

    # Wraps everything so the one-step model can produce trajectories.
    return autoregressive.Predictor(normalized_model, gradient_checkpointing=True)


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  """Builds the wrapped predictor and runs it forward on the given inputs."""
  wrapped_model = construct_wrapped_graphcast(model_config, task_config)
  predictions = wrapped_model(
      inputs, targets_template=targets_template, forcings=forcings)
  return predictions


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
    """Builds the wrapped predictor and returns the mean loss and mean diagnostics.

    Each leaf of (loss, diagnostics) is reduced with `.mean()` and unwrapped to
    a raw JAX array.
    """
    def _mean_as_jax(x):
        return xarray_jax.unwrap_data(x.mean(), require_jax=True)

    wrapped_model = construct_wrapped_graphcast(model_config, task_config)
    loss_value, diagnostic_values = wrapped_model.loss(inputs, targets, forcings)
    return xarray_tree.map_structure(_mean_as_jax, (loss_value, diagnostic_values))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
    """Compute the loss and its gradients with respect to `params`.

    Args:
        params: Haiku parameters of the model.
        state: Haiku state of the model.
        model_config: Model configuration forwarded to `loss_fn.apply`.
        task_config: Task configuration forwarded to `loss_fn.apply`.
        inputs, targets, forcings: Datasets forwarded to `loss_fn.apply`.

    Returns:
        Tuple `(loss, diagnostics, next_state, grads)`.
    """
    def _aux(params, state, i, t, f):
        (loss, diagnostics), next_state = loss_fn.apply(
            params, state, jax.random.PRNGKey(0), model_config, task_config,
            i, t, f)
        # Only the first output is differentiated; the rest is auxiliary.
        return loss, (diagnostics, next_state)

    # This call must live at the level of grads_fn (not inside `_aux`, after
    # its return statement), otherwise it never runs and the names returned
    # below are undefined.
    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
        _aux, has_aux=True)(params, state, inputs, targets, forcings)
    return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  """Binds the module-level `model_config` and `task_config` to `fn` as keyword arguments."""
  bound_configs = dict(model_config=model_config, task_config=task_config)
  return functools.partial(fn, **bound_configs)

# Always pass params and state, so the usages below are simpler
def with_params(fn):
  """Binds the module-level `model_params` and `state` to `fn` as `params` and `state`."""
  bound = dict(params=model_params, state=state)
  return functools.partial(fn, **bound)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  """Wraps `fn` so that only the first element of its return value is returned.

  The wrapper accepts keyword arguments only.
  """
  def _first_output_only(**kw):
    outputs = fn(**kw)
    return outputs[0]
  return _first_output_only

# Jit-compile the Haiku init function, with the configs bound through
# functools.partial by `with_configs`.
init_jitted = jax.jit(with_configs(run_forward.init))

# Initialise the model parameters and Haiku state from the full training
# datasets, with a fixed PRNG key.
model_params, state = init_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets,
    forcings=train_forcings)

#loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
#grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
#run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
#    run_forward.apply))))



n_mesh_node=array([8321]), node_features.shape=(8321, 3)


In [10]:
# Report the size of the freshly initialised model.
num = count_total_parameters(model_params)
print(f'Num of Parameters: {num}')

# Write the initial parameters to a test file.
# Note: this rebinds `params_path` to 'test_params.npz'.
base_path = '/work/mflora/wofs-cast-data/model'
params_path = os.path.join(base_path, 'test_params.npz')
save_model_params(model_params, params_path)

Num of Parameters: 503264


### Training Function 

In [11]:
def shard_xarray_dataset(dataset, num_devices=None):
    """
    Splits an xarray.Dataset along 'batch' into equal shards, one per device.

    Parameters:
    - dataset: xarray.Dataset with a 'batch' dimension.
    - num_devices: Number of shards to create. If None, uses
      jax.local_device_count().

    Returns:
    A single xarray.Dataset in which the shards are concatenated along a new
    'devices' dimension (not a list of datasets).

    Raises:
    - ValueError: if the batch size is not evenly divisible by num_devices.
    """
    if num_devices is None:
        num_devices = jax.local_device_count()

    # Look the batch size up by dimension name. Reading axis 0 of whichever
    # data variable happens to come first gives the wrong size if that
    # variable is not laid out batch-first.
    batch_size = dataset.sizes['batch']

    if batch_size % num_devices != 0:
        raise ValueError(f"Batch size {batch_size} is not evenly divisible by the number of devices {num_devices}.")

    shard_size = batch_size // num_devices

    # One contiguous slice of the batch dimension per device.
    sharded_datasets = [
        dataset.isel(indexers={'batch': slice(i * shard_size, (i + 1) * shard_size)})
        for i in range(num_devices)
    ]

    return xarray.concat(sharded_datasets, dim='devices')

In [12]:
def replicate_for_devices(params, num_devices=None):
    """Replicate parameters for each device using jax.device_put_replicated.

    Args:
        params: Pytree of arrays to replicate.
        num_devices: Number of local devices to replicate onto. If None, uses
            jax.local_device_count().

    Returns:
        The result of jax.device_put_replicated(params, devices).
    """
    if num_devices is None:
        num_devices = jax.local_device_count()
    # Select from the *local* device list so the selection is consistent with
    # local_device_count(); jax.devices() lists the devices of every process.
    devices = jax.local_devices()[:num_devices]
    return jax.device_put_replicated(params, devices)

In [13]:
# Check available devices
print("Available devices:", jax.devices())
# Number of local devices; the multi-GPU training loop passes this to
# replicate_for_devices.
num_devices = jax.local_device_count()

def _loss_fn(params, state, inputs, targets, forcings):
    """Apply `loss_fn` with the module-level configs and a fixed PRNG key.

    Returns `(loss, (diagnostics, next_state))`, the layout expected by
    `jax.value_and_grad(..., has_aux=True)`.
    """
    rng = jax.random.PRNGKey(0)
    outputs, next_state = loss_fn.apply(
        params, state, rng, model_config, task_config, inputs, targets, forcings)
    loss, diagnostics = outputs
    return loss, (diagnostics, next_state)

# Define the gradient function
def train_step(params: dict, 
               state:dict, 
               inputs: xarray.Dataset, 
               targets : xarray.Dataset, 
               forcings : xarray.Dataset, 
               model_config, task_config):
    """Compute gradients, loss and diagnostics for one batch.

    Args:
        params: Haiku parameters of the model.
        state: Haiku state of the model.
        inputs, targets, forcings: Datasets for this batch.
        model_config, task_config: Configurations forwarded to `loss_fn.apply`.

    Returns:
        Tuple `(grads, loss, diagnostics)`; `grads` has the structure of `params`.
    """
    # Use the configs passed to this function rather than module-level
    # globals, so whatever the caller binds (e.g. via `with_configs`) is what
    # the loss is actually evaluated with.
    def _loss_with_aux(params, state, inputs, targets, forcings):
        (loss, diagnostics), next_state = loss_fn.apply(
            params, state, jax.random.PRNGKey(0),
            model_config, task_config,
            inputs, targets, forcings)
        return loss, (diagnostics, next_state)

    # Compute gradients and auxiliary outputs
    gradient_fn = jax.value_and_grad(_loss_with_aux, has_aux=True)
    (loss, (diagnostics, _next_state)), grads = gradient_fn(
        params, state, inputs, targets, forcings)

    # Gradient clipping is currently disabled. The sketch previously kept here
    # (as a string literal) was:
    #
    #   total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g))
    #                             for g in tree_util.tree_leaves(grads)))
    #   def clip_grads(g, clip_norm=32):
    #       return jnp.where(total_norm > clip_norm, g * clip_norm / total_norm, g)
    #   clipped_grads = tree_util.tree_map(clip_grads, grads)

    return grads, loss, diagnostics
    
def update_step(optimiser, params, grads, opt_state):
    """Apply one optimiser update to `params` and return `(new_params, new_opt_state)`."""
    param_updates, next_opt_state = optimiser.update(grads, opt_state)
    return optax.apply_updates(params, param_updates), next_opt_state

def get_random_batches(inputs, targets, forcings, batch_size):
    """Yield `(inputs, targets, forcings)` mini-batches in a random sample order.

    The sample indices are shuffled once with `np.random.shuffle`; the same
    indices are then applied to all three datasets. The final batch may hold
    fewer than `batch_size` samples.
    """
    n_samples = inputs.dims['batch']
    order = np.arange(n_samples)
    np.random.shuffle(order)  # Randomly shuffle the indices

    datasets = (inputs, targets, forcings)
    for start in range(0, n_samples, batch_size):
        # Slicing past the end simply yields the shorter final batch.
        chosen = order[start:start + batch_size]
        yield tuple(ds.isel(batch=chosen) for ds in datasets)

# Multi-device version of train_step (configs bound via with_configs); the
# 'devices' dimension of the dataset arguments is the one mapped over.
train_step_pmap = xarray_jax.pmap(with_configs(train_step), dim='devices')

Available devices: [cuda(id=0), cuda(id=1)]


In [14]:
# Training Parameters.
TOTAL_LINEAR_EPOCHS = 50
TOTAL_COSINE_EPOCHS = 10
BATCH_SIZE = 128

# Flag for saving checkpoints. It is deliberately NOT named `checkpoint`:
# that name is the module imported from wofscast in the imports cell, and
# rebinding it to a bool would make the module unreachable for later cells.
SAVE_CHECKPOINTS = False

# Setup the learning rate schedule: a linear ramp from 1e-6 up to 1e-3.
# `transition_steps` counts schedule steps, which is not the same as epochs
# when an epoch contains more than one batch.
# NOTE(review): the training cells build optax.adam with a constant `lr`, so
# this schedule is currently defined but not applied.
start_learning_rate = 1e-6  # Small, non-zero starting value
end_learning_rate = 1e-3  # Increase to 1e-3
schedule = optax.linear_schedule(init_value=start_learning_rate, 
                                 end_value=end_learning_rate, 
                                 transition_steps=TOTAL_LINEAR_EPOCHS)

In [15]:
%%time

# Re-initialise the parameters so this cell always trains from scratch.
model_params, state = init_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets,
    forcings=train_forcings)


# For multiple GPUs!!
base_path = '/work/mflora/wofs-cast-data/model'
params_path = os.path.join(base_path, 'params.npz')

# Adam with a constant learning rate (the linear `schedule` defined in the
# configuration cell is not used here).
lr = 1e-3
optimiser = optax.adam(lr, b1=0.9, b2=0.95, eps=1e-8)

# Initialize the optimizer state once, before the first epoch.
opt_state = optimiser.init(model_params)

# `state` is never updated inside the loop, so replicate it a single time.
state_sharded = replicate_for_devices(state, num_devices)

# Training loop
for epoch in range(TOTAL_LINEAR_EPOCHS):
    # Create mini-batches for the current epoch and compute gradients.
    losses_per_epoch = []
    for batch_inputs, batch_targets, batch_forcings in get_random_batches(train_inputs,
                                                                          train_targets,
                                                                          train_forcings,
                                                                          BATCH_SIZE):

        # Split each batch along 'batch' into one shard per device.
        batch_inputs_sharded = shard_xarray_dataset(batch_inputs)
        batch_targets_sharded = shard_xarray_dataset(batch_targets)
        batch_forcings_sharded = shard_xarray_dataset(batch_forcings)

        # The parameters change every step, so they are re-replicated per batch.
        model_params_sharded = replicate_for_devices(model_params, num_devices)

        grads, loss, diagnostics = train_step_pmap(model_params_sharded,
                                                   state_sharded,
                                                   batch_inputs_sharded,
                                                   batch_targets_sharded,
                                                   batch_forcings_sharded,
                                                  )

        # Average the per-device gradients over the leading (device) axis.
        # jax.tree_util.tree_map is used because the top-level jax.tree_map
        # alias is deprecated.
        grads_avg = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), grads)

        model_params, opt_state = update_step(optimiser, model_params, grads_avg, opt_state)

        losses_per_epoch.append(np.mean(loss))

    print(f"Epoch: {epoch}.....Loss: {np.mean(losses_per_epoch):.5f}")

n_mesh_node=array([8321]), node_features.shape=(8321, 3)


2024-02-29 11:17:07.299288: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %compare.6959 = pred[638518,1]{1,0} compare(s32[638518,1]{1,0} %constant.271, s32[638518,1]{1,0} %broadcast.231), direction=GE, metadata={op_name="pmap(fn_passed_to_pmap)/jit(main)/transpose(jvp(grid2mesh_gnn))/_process/grid2mesh_gnn/_process_step/grid2mesh_gnn/ge" source_file="/home/monte.flora/python_packages/frdd-wofs-cast/wofscast/deep_typed_graph_net.py" source_line=219}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2024-02-29 11:17:07.696770: E external/xla/xla/service/slow_operation_alarm.cc:133] Th

2024-02-29 11:17:20.820015: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 0: 6.91405e+06, expected 1.00925e+07
2024-02-29 11:17:20.820052: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 1: 7.04512e+06, expected 1.02236e+07
2024-02-29 11:17:20.820056: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 2: 6.91405e+06, expected 1.00925e+07
2024-02-29 11:17:20.820060: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 3: 6.97958e+06, expected 1.01581e+07
2024-02-29 11:17:20.820063: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 4: 6.94682e+06, expected 1.00925e+07
2024-02-29 11:17:20.820066: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 5: 6.94682e+06, expected 1.01581e+07
2024-02-29 11:17:20.820069: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 6: 7.01235e+06, expected 1.01581e+07
2024-02-29 11:17:20.820072: E external/xla/xla/s

2024-02-29 11:17:25.024853: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 0: 6.91405e+06, expected 1.00925e+07
2024-02-29 11:17:25.024889: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 1: 7.04512e+06, expected 1.02236e+07
2024-02-29 11:17:25.024893: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 2: 6.91405e+06, expected 1.00925e+07
2024-02-29 11:17:25.024896: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 3: 6.97958e+06, expected 1.01581e+07
2024-02-29 11:17:25.024899: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 4: 6.94682e+06, expected 1.00925e+07
2024-02-29 11:17:25.024901: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 5: 6.94682e+06, expected 1.01581e+07
2024-02-29 11:17:25.024904: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 6: 7.01235e+06, expected 1.01581e+07
2024-02-29 11:17:25.024907: E external/xla/xla/s

2024-02-29 11:17:29.044622: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 0: 8.38861e+06, expected 1.00925e+07
2024-02-29 11:17:29.044658: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 1: 8.45414e+06, expected 1.02236e+07
2024-02-29 11:17:29.044662: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 2: 8.38861e+06, expected 1.00925e+07
2024-02-29 11:17:29.044665: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 3: 8.45414e+06, expected 1.01581e+07
2024-02-29 11:17:29.044668: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 4: 8.38861e+06, expected 1.00925e+07
2024-02-29 11:17:29.044671: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 5: 8.45414e+06, expected 1.01581e+07
2024-02-29 11:17:29.044673: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 6: 8.45414e+06, expected 1.01581e+07
2024-02-29 11:17:29.044676: E external/xla/xla/s

XlaRuntimeError: INTERNAL: All algorithms tried for %convert.411 = f32[192,64]{1,0} convert(bf16[192,64]{1,0} %dot.409), metadata={op_name="pmap(fn_passed_to_pmap)/jit(main)/transpose(jvp(grid2mesh_gnn))/_process/grid2mesh_gnn/_process_step/grid2mesh_gnn/sequential_3/processor_edges_0_grid2mesh_mlp/linear_0/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/home/monte.flora/python_packages/frdd-wofs-cast/wofscast/casting.py" source_line=197} failed. Falling back to default algorithm.  Per-algorithm errors:
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.

In [16]:
%%time

# Re-initialise the parameters so this cell always trains from scratch.
model_params, state = init_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets,
    forcings=train_forcings)

# remove `with_params` from jitted grads function
# Note: despite its name, this wraps `train_step`, which returns
# (grads, loss, diagnostics).
grads_fn_jitted = jax.jit(with_configs(train_step))

# Adam with a constant learning rate.
lr = 1e-3
optimiser = optax.adam(lr, b1=0.9, b2=0.95, eps=1e-8)

# Training loop. NOTE(review): the learning rate is the constant `lr` above;
# the linear `schedule` defined in the configuration cell is not applied.
for epoch in range(TOTAL_LINEAR_EPOCHS): 
    if epoch == 0:
        # Initialize the optimizer state only at the beginning
        opt_state = optimiser.init(model_params)
    
    # Create mini-batches for the current epoch and compute gradients. 
    losses_per_epoch = []
    for batch_inputs, batch_targets, batch_forcings in get_random_batches(train_inputs, train_targets, 
                                                                          train_forcings, 
                                                                          BATCH_SIZE):
        grads, loss, diagnostics = grads_fn_jitted(model_params, state, batch_inputs, 
                                                           batch_targets, batch_forcings)
        losses_per_epoch.append(loss)
        
        # Update parameters
        model_params, opt_state = update_step(optimiser, model_params, grads, opt_state)

    # Report the mean loss over this epoch's mini-batches.
    print(f"Epoch: {epoch}.....Loss: {np.mean(losses_per_epoch):.5f}")

n_mesh_node=array([8321]), node_features.shape=(8321, 3)


2024-02-29 11:17:39.822227: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 2s:

  %compare.6845 = pred[638518,1]{1,0} compare(s32[638518,1]{1,0} %constant.265, s32[638518,1]{1,0} %broadcast.225), direction=GE, metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(grid2mesh_gnn))/_process/grid2mesh_gnn/_process_step/grid2mesh_gnn/ge" source_file="/home/monte.flora/python_packages/frdd-wofs-cast/wofscast/deep_typed_graph_net.py" source_line=219}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2024-02-29 11:17:39.972324: E external/xla/xla/service/slow_operation_alarm.cc

2024-02-29 11:18:00.188214: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 0: 9.89594e+06, expected 2.0054e+07
2024-02-29 11:18:00.188255: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 1: 9.8304e+06, expected 1.99229e+07
2024-02-29 11:18:00.188259: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 2: 9.76486e+06, expected 1.99229e+07
2024-02-29 11:18:00.188262: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 3: 9.69933e+06, expected 1.96608e+07
2024-02-29 11:18:00.188264: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 4: 9.89594e+06, expected 1.99229e+07
2024-02-29 11:18:00.188267: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 5: 9.8304e+06, expected 1.99229e+07
2024-02-29 11:18:00.188270: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 6: 9.8304e+06, expected 1.99229e+07
2024-02-29 11:18:00.188273: E external/xla/xla/servi

2024-02-29 11:18:09.071472: W external/xla/xla/service/gpu/triton_autotuner.cc:788] Slow kernel for triton_gemm_dot.247 took: 1.116436523s. config: {block_m:64,block_n:64,block_k:64,split_k:1,num_stages:2,num_warps:4}
2024-02-29 11:18:09.071953: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 0: 9.89594e+06, expected 2.0054e+07
2024-02-29 11:18:09.071960: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 1: 9.8304e+06, expected 1.99229e+07
2024-02-29 11:18:09.071963: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 2: 9.76486e+06, expected 1.99229e+07
2024-02-29 11:18:09.071966: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 3: 9.69933e+06, expected 1.96608e+07
2024-02-29 11:18:09.071969: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 4: 9.89594e+06, expected 1.99229e+07
2024-02-29 11:18:09.071972: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 5: 9.

2024-02-29 11:18:16.248283: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 0: 9.89594e+06, expected 2.0054e+07
2024-02-29 11:18:16.248325: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 1: 9.8304e+06, expected 1.99229e+07
2024-02-29 11:18:16.248328: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 2: 9.76486e+06, expected 1.99229e+07
2024-02-29 11:18:16.248331: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 3: 9.69933e+06, expected 1.96608e+07
2024-02-29 11:18:16.248334: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 4: 9.89594e+06, expected 1.99229e+07
2024-02-29 11:18:16.248337: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 5: 9.8304e+06, expected 1.99229e+07
2024-02-29 11:18:16.248339: E external/xla/xla/service/gpu/buffer_comparator.cc:149] Difference at 6: 9.8304e+06, expected 1.99229e+07
2024-02-29 11:18:16.248342: E external/xla/xla/servi

XlaRuntimeError: INTERNAL: All algorithms tried for %convert.411 = f32[192,64]{1,0} convert(bf16[192,64]{1,0} %dot.409), metadata={op_name="jit(<unnamed wrapped function>)/jit(main)/transpose(jvp(grid2mesh_gnn))/_process/grid2mesh_gnn/_process_step/grid2mesh_gnn/sequential_3/processor_edges_0_grid2mesh_mlp/linear_0/convert_element_type[new_dtype=float32 weak_type=False]" source_file="/home/monte.flora/python_packages/frdd-wofs-cast/wofscast/casting.py" source_line=197} failed. Falling back to default algorithm.  Per-algorithm errors:
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.
  Results do not match the reference. This is likely a bug/unexpected loss of precision.

In [17]:
diagnostics

NameError: name 'diagnostics' is not defined

In [None]:
# Persist the current parameters. `params_path` is whatever the most recently
# executed cell set it to ('params.npz' or 'test_params.npz').
save_model_params(model_params, params_path)

### Run the model forward based on the new model params

In [None]:
# Always pass params and state, so the usages below are simpler
def with_model_params(fn):
  """Binds the current module-level `model_params` and `state` to `fn` as `params` and `state`."""
  bound = dict(params=model_params, state=state)
  return functools.partial(fn, **bound)

# Jitted forward pass with configs and current params/state bound; drop_state
# keeps only the first element of the (output, state) result.
run_forward_jitted = drop_state(with_model_params(jax.jit(with_configs(
    run_forward.apply))))

# @title Autoregressive rollout (keep the loop in JAX)
print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

# The targets template is multiplied by NaN so that only its structure, not
# its values, is handed to the rollout.
predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)

predictions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Drop the length-1 'batch' dimension so the leading dimension of the targets
# is 'time' (the datasets were transposed to batch, time, lat, lon, level).
preds =  predictions.squeeze(dim='batch', drop=True)
targets = eval_targets.squeeze(dim='batch', drop=True)

#print(np.max(preds['T'][1] - preds['T'][1]))
#print(np.max(targets['T'][0] - targets['T'][0]))

# Variable to plot; read as a global by the plotting helpers below.
var = 'U'

def get_target_and_pred_pair(preds, targets, t, level=0):
    """Return `[target_field, prediction_field]` as 2-D (lat, lon) arrays.

    The variable is taken from the module-level `var`.

    Args:
        preds: Predictions dataset with a 'time' dimension.
        targets: Targets dataset with a 'time' dimension.
        t: Integer index along 'time'.
        level: 'max' or 'min' to reduce over 'level'; 'none' for variables
            without vertical selection; otherwise an integer index along
            'level'.

    Returns:
        Two-element list of NumPy arrays: targets first, predictions second.
    """
    def _field(ds):
        # Select by dimension *name*. The targets are laid out as
        # (time, lat, lon, level), so positional `[t, level]` indexing would
        # pick a latitude row instead of a vertical level.
        data = ds[var].isel(time=t)
        if level == 'max':
            data = data.max(dim='level')
        elif level == 'min':
            data = data.min(dim='level')
        elif level != 'none':
            data = data.isel(level=level)
        # Fix the axis order so image rows are latitude and columns longitude.
        return data.transpose('lat', 'lon').values

    return [_field(targets), _field(preds)]

# Two side-by-side panels: target on the left, prediction on the right.
fig, axes = plt.subplots(dpi=200, figsize=(10,6), ncols=2)
plt.tight_layout()

# Draw the first frame (time index 0), with one colorbar per panel.
zs = get_target_and_pred_pair(preds, targets, t=0)

titles = ['Target', 'Prediction']
for i, (ax, z) in enumerate(zip(axes, zs)):
    # Reserve a thin axis on the right of each panel for its colorbar.
    div = make_axes_locatable(ax)
    cax = div.append_axes('right', '5%', '5%')
    # For these two variables, values below 5 are masked out.
    if var in ['REFL_10CM', 'UP_HELI_MAX']:
        z = np.ma.masked_where(z<5,z)
    
    im = ax.imshow(z, origin='lower', aspect='equal', cmap='jet')
    cb = fig.colorbar(im, cax=cax)
    ax.set_title(titles[i])
    
# This function will update the content of the plot for each frame
def update(t):
    """Redraw both panels (target, prediction) for time index `t`."""
    # Clear the current axes
    for axis in axes:
        axis.clear()

    fields = get_target_and_pred_pair(preds, targets, t=t)

    for axis, field, label in zip(axes, fields, ['Target', 'Prediction']):
        # For these two variables, values below 5 are masked out.
        if var in ['REFL_10CM', 'UP_HELI_MAX']:
            field = np.ma.masked_where(field<5, field)
        axis.imshow(field, origin='lower', aspect='equal', cmap='jet')
        axis.set_title(label)


# Total number of frames: the length of the targets along their first axis.
N = targets[var].shape[0]  # First axis is 'time' once 'batch' has been squeezed out

# Create animation
anim = FuncAnimation(fig, update, frames=N, interval=200)  # Adjust interval for frame speed

# To display the animation in a Jupyter notebook
from IPython.display import HTML
HTML(anim.to_jshtml())