In [2]:
# @title Build jitted functions, and possibly initialize random weights

import xarray as xr 
import numpy as np

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

import wofscast.graphcast_lam as graphcast 
from wofscast.data_generator import (load_wofscast_data, 
                                    wofscast_data_generator, 
                                    wofscast_batch_generator, 
                                    to_static_vars,
                                    add_local_solar_time
                                    
                                    )
from wofscast.wofscast_task_config import WOFS_TASK_CONFIG, train_lead_times, TARGET_VARS
from wofscast import (data_utils, 
                      casting, 
                      normalization,
                      autoregressive,
                      xarray_tree,
                      xarray_jax
                     )
import dataclasses


import haiku as hk
import jax
import functools



task_config = WOFS_TASK_CONFIG
model_config = graphcast.ModelConfig(
              resolution=0,
              mesh_size=5,
              latent_size=16,
              gnn_msg_steps=4,
              hidden_layers=1,
              grid_to_mesh_node_dist=5, 
              loss_weights = None,
              k_hop = 8,
              use_transformer = False,
              num_attn_heads = 4
        )

dataset = xr.load_dataset(
    '/work/mflora/wofs-cast-data/datasets/2019/wrfwof_2019-05-18_213000_to_2019-05-18_220000__10min__ens_mem_06.nc')

dataset = add_local_solar_time(dataset)

example_batch = dataset.expand_dims('batch', axis=0)


path = '/work/mflora/wofs-cast-data/normalization_stats'

mean_by_level = xr.load_dataset(os.path.join(path, 'mean_by_level.nc'))
stddev_by_level = xr.load_dataset(os.path.join(path, 'stddev_by_level.nc'))
diffs_stddev_by_level = xr.load_dataset(os.path.join(path, 'diffs_stddev_by_level.nc'))

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times="10min",
    **dataclasses.asdict(task_config))

train_inputs = to_static_vars(train_inputs)
train_inputs = train_inputs.transpose('batch', 'time', 'lat', 'lon', 'level')
train_targets = train_targets.transpose('batch', 'time', 'lat', 'lon', 'level')
train_forcings = train_forcings.transpose('batch', 'time', 'lat', 'lon')



def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)

  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)

  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  
  # Ensure loss is a concrete value
  concrete_loss = jax.device_get(loss)
  print(f"Concrete loss value: {concrete_loss}") 


  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is requiredy by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

params = None

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets_template=train_targets,
      forcings=train_forcings)

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))



# @title Loss computation (autoregressive loss over multiple steps)
loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)

print("Loss:", float(loss))





per_variable_weights=None


Loss: 19.5703125


In [4]:
import optax

# modify the gradients function signature
def grads_fn(params, state, inputs, targets, forcings, model_config, task_config):
    def _aux(params, state, i, t, f):
        (loss, diagnostics), next_state = loss_fn.apply(params, state, jax.random.PRNGKey(0), 
                                                        model_config, task_config, i, t, f)
        return loss, (diagnostics, next_state)
    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(_aux, has_aux=True)(params, 
                                                                                      state, inputs, 
                                                                                      targets, forcings)
    
    # Ensure loss is a concrete value
    concrete_loss = jax.device_get(loss)
    print(f"Concrete loss value: {concrete_loss}") 
    
    return loss, diagnostics, next_state, grads

# remove `with_params` from jitted grads function
grads_fn_jitted = jax.jit(with_configs(grads_fn))

# setup optimiser
lr = 1e-3
optimiser = optax.adam(lr, b1=0.9, b2=0.999, eps=1e-8)
opt_state = optimiser.init(params)

# calculate loss and gradients
loss, diagnostics, next_state, grads = grads_fn_jitted(params, state, 
                                                       train_inputs, train_targets, train_forcings)

# update
updates, opt_state = optimiser.update(grads, opt_state)
params = optax.apply_updates(params, updates)


per_variable_weights=None


Concrete loss value: Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>


In [None]:
def update_fn(params, state, opt_state, inputs, targets, forcings, model_config, task_config, norm_stats):
    # Clip gradients if the total norm exceeds the threshold
    def clip_grads(g, clip_norm=32):
        return jnp.where(total_norm > clip_norm, g * clip_norm / total_norm, g)
    
    def compute_loss(params, state, inputs, targets, forcings):
        print('This step inside compute_loss happened....')
        (loss, diagnostics), next_state = loss_fn.apply(params, state, 
                                                        jax.random.PRNGKey(0), 
                                                        model_config, 
                                                        task_config, norm_stats, 
                                                        inputs, targets, forcings)
        return loss, (diagnostics, next_state)
    
    # Compute gradients and auxiliary outputs
    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(compute_loss, has_aux=True)(params, state, 
                                                                                              inputs, targets, 
                                                                                              forcings)
  
    # Combine the gradient across all devices (by taking their mean).
    grads = jax.lax.pmean(grads, axis_name='devices')

    # Compute the global norm of all gradients and clip them. 
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in tree_util.tree_leaves(grads)))
    clipped_grads = tree_util.tree_map(clip_grads, grads)
    
    #update 
    updates, new_opt_state = optimizer.update(clipped_grads, opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss, diagnostics