# GraphCast updated

In [1]:
import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from graphcast import model_utils
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray
import warnings

warnings.filterwarnings('ignore')

In [2]:
def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))


gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")
dir_prefix = "graphcast/"

with gcs_bucket.blob(f"{dir_prefix}params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

params = ckpt.params
state = {}
model_config = ckpt.model_config
task_config = ckpt.task_config

name = 'source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc'
with gcs_bucket.blob(f"{dir_prefix}dataset/{name}").open("rb") as f:
    example_batch = xarray.load_dataset(f).compute()
example_batch

In [3]:
eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, input_duration='24h', target_lead_times='6h',
    **{k: v for k, v in task_config.items() if k != "input_duration"})

with gcs_bucket.blob(dir_prefix+"stats/diffs_stddev_by_level.nc").open("rb") as f:
    diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/mean_by_level.nc").open("rb") as f:
    mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob(dir_prefix+"stats/stddev_by_level.nc").open("rb") as f:
    stddev_by_level = xarray.load_dataset(f).compute()

In [4]:
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor


@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  global predictor
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)


@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below are simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is requiredy by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))

In [5]:
predictions, embeddings, latent = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)



In [None]:
import torch
import torch.nn as nn
import numpy as np

from tqdm import tqdm


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [17]:
net = nn.Sequential(nn.Linear(1024, 2048),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(2048, 2048),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(2048, 1024),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(1024, 512),
                    nn.ReLU(),
                    nn.Dropout(0.3),
                    nn.Linear(512, 83)).to(device)
criterion = nn.MSELoss()
optim = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=1e-4, weight_decay=1e-3)

In [28]:
x[:,:512].shape

torch.Size([65160, 512])

In [3]:
for k, v in torch.load('data.pt').items():
    globals()[k] = v.to(device)
targets = targets.reshape(-1, 83)

In [18]:
bs = 512
net.train()
for i in tqdm(range(100)):
    indices = np.random.permutation(np.arange(targets.shape[0]))
    for j in range(0, len(targets), bs):
        idx = indices[i:i+bs]
        y = net(x[idx]) + old[idx]
        loss = criterion(y, targets[idx])
        loss.backward()
        optim.step()
        loss = loss.detach().cpu()
    if i % 5 == 0:
        print(loss)

  1%|▊                                                                                 | 1/100 [00:03<05:56,  3.60s/it]

tensor(7390.6572, device='cuda:0', grad_fn=<MseLossBackward0>)


  6%|████▉                                                                             | 6/100 [00:20<05:27,  3.49s/it]

tensor(4802.6069, device='cuda:0', grad_fn=<MseLossBackward0>)


 11%|████████▉                                                                        | 11/100 [00:38<05:08,  3.46s/it]

tensor(4094.9150, device='cuda:0', grad_fn=<MseLossBackward0>)


 16%|████████████▉                                                                    | 16/100 [00:55<04:50,  3.46s/it]

tensor(3357.9509, device='cuda:0', grad_fn=<MseLossBackward0>)


 21%|█████████████████                                                                | 21/100 [01:12<04:34,  3.47s/it]

tensor(3140.5391, device='cuda:0', grad_fn=<MseLossBackward0>)


 26%|█████████████████████                                                            | 26/100 [01:30<04:18,  3.49s/it]

tensor(3141.5195, device='cuda:0', grad_fn=<MseLossBackward0>)


 31%|█████████████████████████                                                        | 31/100 [01:47<04:01,  3.51s/it]

tensor(2975.2051, device='cuda:0', grad_fn=<MseLossBackward0>)


 36%|█████████████████████████████▏                                                   | 36/100 [02:05<03:45,  3.52s/it]

tensor(2734.1064, device='cuda:0', grad_fn=<MseLossBackward0>)


 41%|█████████████████████████████████▏                                               | 41/100 [02:23<03:28,  3.53s/it]

tensor(2896.1292, device='cuda:0', grad_fn=<MseLossBackward0>)


 46%|█████████████████████████████████████▎                                           | 46/100 [02:40<03:10,  3.53s/it]

tensor(2920.1614, device='cuda:0', grad_fn=<MseLossBackward0>)


 51%|█████████████████████████████████████████▎                                       | 51/100 [02:58<02:53,  3.54s/it]

tensor(3006.1431, device='cuda:0', grad_fn=<MseLossBackward0>)


 56%|█████████████████████████████████████████████▎                                   | 56/100 [03:15<02:35,  3.54s/it]

tensor(3315.0933, device='cuda:0', grad_fn=<MseLossBackward0>)


 61%|█████████████████████████████████████████████████▍                               | 61/100 [03:33<02:17,  3.53s/it]

tensor(3093.5342, device='cuda:0', grad_fn=<MseLossBackward0>)


 66%|█████████████████████████████████████████████████████▍                           | 66/100 [03:51<02:00,  3.54s/it]

tensor(3481.1863, device='cuda:0', grad_fn=<MseLossBackward0>)


 71%|█████████████████████████████████████████████████████████▌                       | 71/100 [04:08<01:42,  3.53s/it]

tensor(3494.3889, device='cuda:0', grad_fn=<MseLossBackward0>)


 76%|█████████████████████████████████████████████████████████████▌                   | 76/100 [04:26<01:24,  3.54s/it]

tensor(3446.3674, device='cuda:0', grad_fn=<MseLossBackward0>)


 81%|█████████████████████████████████████████████████████████████████▌               | 81/100 [04:44<01:07,  3.55s/it]

tensor(3229.7300, device='cuda:0', grad_fn=<MseLossBackward0>)


 86%|█████████████████████████████████████████████████████████████████████▋           | 86/100 [05:02<00:49,  3.55s/it]

tensor(3984.8906, device='cuda:0', grad_fn=<MseLossBackward0>)


 91%|█████████████████████████████████████████████████████████████████████████▋       | 91/100 [05:19<00:32,  3.58s/it]

tensor(4307.8457, device='cuda:0', grad_fn=<MseLossBackward0>)


 96%|█████████████████████████████████████████████████████████████████████████████▊   | 96/100 [05:37<00:14,  3.54s/it]

tensor(3604.7363, device='cuda:0', grad_fn=<MseLossBackward0>)


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [05:51<00:00,  3.52s/it]
