<a href="https://colab.research.google.com/github/Alokcoder/ML-Basics/blob/main/gencast_working_file.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install -U importlib_metadata



In [2]:
%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

Collecting https://github.com/deepmind/graphcast/archive/master.zip
  Using cached https://github.com/deepmind/graphcast/archive/master.zip
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [3]:
# This is required due to outdated jax and libtpu versions in Colab TPU images.
%pip uninstall -y libtpu libtpu-nightly
%pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Found existing installation: libtpu 0.0.13
Uninstalling libtpu-0.0.13:
  Successfully uninstalled libtpu-0.0.13
[0mLooking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting libtpu==0.0.13.* (from jax[tpu])
  Using cached libtpu-0.0.13-py3-none-manylinux_2_31_x86_64.whl.metadata (500 bytes)
Using cached libtpu-0.0.13-py3-none-manylinux_2_31_x86_64.whl (132.9 MB)
Installing collected packages: libtpu
Successfully installed libtpu-0.0.13


In [4]:
!pip install netCDF4



In [5]:
! pip install h5netcdf



In [6]:
import os

# Set before `jax` is imported in the next cell, presumably so the setting is
# seen when the backend initialises. Intent (per the original note): disable
# Triton's scalar prefetch, which isn't implemented.
# NOTE(review): confirm the installed jax/Triton actually reads this variable.
os.environ.update({"TRITON_DISABLE_PREFETCH": "1"})

In [7]:

import dataclasses
import datetime
import math
from google.cloud import storage
from typing import Optional
import haiku as hk
from IPython.display import HTML
from IPython import display
import ipywidgets as widgets
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray

from graphcast import rollout
from graphcast import xarray_jax
from graphcast import normalization
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import xarray_tree
from graphcast import gencast
from graphcast import denoiser
from graphcast import nan_cleaning

In [8]:
# Per-level statistics used to build the wrapped predictor: the three
# *_by_level stddev/mean datasets feed normalization.InputsAndResiduals, and
# `min_by_level` is the NaN fill value for nan_cleaning.NaNCleaner.
# Files are loaded eagerly, in this order.
(diffs_stddev_by_level,
 mean_by_level,
 stddev_by_level,
 min_by_level) = (
    xarray.load_dataset(stats_path).compute()
    for stats_path in (
        '/content/gencast_stats_diffs_stddev_by_level.nc',
        '/content/gencast_stats_mean_by_level.nc',
        '/content/gencast_stats_stddev_by_level.nc',
        '/content/gencast_stats_min_by_level.nc',
    )
)



In [9]:
# Show the `diffs_stddev_by_level` statistics loaded above (they are passed to
# `normalization.InputsAndResiduals` when the predictor is built) as a quick
# sanity check; as the cell's last expression it is rendered via its rich repr.
diffs_stddev_by_level

In [10]:
# Evaluation batch consumed by the rollout below: `eval_inputs` and
# `eval_forcings` are fed to the predictor, and `eval_targets` supplies the
# (NaN-filled) target template. Files are loaded eagerly, in this order.
(eval_inputs,
 eval_targets,
 eval_forcings) = (
    xarray.load_dataset(nc_path).compute()
    for nc_path in (
        '/content/dataset_input.nc',
        '/content/dataset_target.nc',
        '/content/dataset_forcing.nc',
    )
)

In [11]:
  # Load the pretrained GenCast checkpoint from a local .npz file.
  # NOTE(review): this cell is indented by two spaces; it runs in IPython,
  # which strips a cell's common leading indentation, but it would not be
  # valid as a plain Python script.
  ckpt = checkpoint.load(
      '/content/gencast_params_GenCast 1p0deg Mini _2019.npz',
      gencast.CheckPoint,
  )
  # Pretrained weights; the Haiku state starts out empty.
  params, state = ckpt.params, {}

  # Configurations stored with the weights; consumed when the predictor is
  # constructed in the next cell.
  (task_config,
   sampler_config,
   noise_config,
   noise_encoder_config,
   denoiser_architecture_config) = (
       ckpt.task_config,
       ckpt.sampler_config,
       ckpt.noise_config,
       ckpt.noise_encoder_config,
       ckpt.denoiser_architecture_config,
   )
  print("Model description:\n", ckpt.description, "\n")
  print("Model license:\n", ckpt.license, "\n")

Model description:
 
        GenCast model at lower, 1deg, resolution, with 13 pressure levels and a
        4 times refined icosahedral mesh. This model is trained on ERA5 data
        from 1979 to 2018, and can be causally evaluated on 2019 and later years.
        This model has the smallest memory footprint of those provided and has been provided
        to enable low cost demonstrations. It is not representative of GenCast's performance.
         

Model license:
 
The model weights are licensed under the Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You
may obtain a copy of the License at:
https://creativecommons.org/licenses/by-nc-sa/4.0/.
The weights were trained on ERA5 data, see README for attribution statement.
 



In [12]:
# @title Build jitted functions, and possibly initialize random weights


def construct_wrapped_gencast():
  """Builds the GenCast predictor together with its wrappers.

  Layers, innermost first:
    1. `gencast.GenCast`, configured from the checkpoint's configs.
    2. `normalization.InputsAndResiduals`, using the per-level mean, stddev and
       difference-stddev statistics.
    3. `nan_cleaning.NaNCleaner` for 'sea_surface_temperature', filling with
       `min_by_level` and constructed with `reintroduce_nans=True`.

  It is constructed inside the Haiku-transformed `run_forward` and `loss_fn`.

  Returns:
    The outermost (NaN-cleaning) predictor.
  """
  core_model = gencast.GenCast(
      task_config=task_config,
      sampler_config=sampler_config,
      noise_config=noise_config,
      noise_encoder_config=noise_encoder_config,
      denoiser_architecture_config=denoiser_architecture_config,
  )
  normalized_model = normalization.InputsAndResiduals(
      core_model,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level,
      diffs_stddev_by_level=diffs_stddev_by_level,
  )
  return nan_cleaning.NaNCleaner(
      predictor=normalized_model,
      var_to_clean='sea_surface_temperature',
      fill_value=min_by_level,
      reintroduce_nans=True,
  )


@hk.transform_with_state
def run_forward(inputs, targets_template, forcings):
  """Runs the wrapped GenCast predictor forward.

  Being Haiku-transformed, it is invoked as
  `run_forward.apply(params, state, rng, inputs, targets_template, forcings)`,
  which returns `(output, new_state)`.
  """
  wrapped_predictor = construct_wrapped_gencast()
  output = wrapped_predictor(
      inputs, targets_template=targets_template, forcings=forcings)
  return output


@hk.transform_with_state
def loss_fn(inputs, targets, forcings):
  """Computes the wrapped predictor's loss and diagnostics as JAX scalars.

  Every leaf of the `(loss, diagnostics)` structure is reduced with `.mean()`
  over all dimensions and unwrapped to a raw JAX array, so the result can be
  differentiated with `jax.value_and_grad` (see `grads_fn`).
  """
  def _mean_as_jax(value):
    return xarray_jax.unwrap_data(value.mean(), require_jax=True)

  wrapped_predictor = construct_wrapped_gencast()
  loss, diagnostics = wrapped_predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(_mean_as_jax, (loss, diagnostics))


def grads_fn(params, state, inputs, targets, forcings):
  """Evaluates the loss and its gradient with respect to `params`.

  `loss_fn.apply` is always called with the fixed key `jax.random.PRNGKey(0)`.

  Returns:
    A tuple `(loss, diagnostics, next_state, grads)`.
  """
  def _objective(p, s, inp, tgt, frc):
    # value_and_grad differentiates w.r.t. the first argument (`p`) and
    # passes `(diagnostics, new_state)` through as auxiliary output.
    (scalar_loss, diag), new_state = loss_fn.apply(
        p, s, jax.random.PRNGKey(0), inp, tgt, frc
    )
    return scalar_loss, (diag, new_state)

  loss_and_grad = jax.value_and_grad(_objective, has_aux=True)
  (loss, (diagnostics, next_state)), grads = loss_and_grad(
      params, state, inputs, targets, forcings
  )
  return loss, diagnostics, next_state, grads


# `params` comes from the checkpoint loaded earlier, so this branch only runs
# if `params` has been set to None to start from randomly initialised weights.
# No `train_*` datasets are defined in this notebook, so initialise from the
# evaluation batch instead, keeping only the first target/forcing time step.
# NOTE(review): this assumes `loss_fn.init` only needs inputs of the right
# structure and a single target step (as a one-step training example) — confirm.
if params is None:
  init_jitted = jax.jit(loss_fn.init)
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=eval_inputs,
      targets=eval_targets.isel(time=slice(0, 1)),
      forcings=eval_forcings.isel(time=slice(0, 1)),
  )


def _loss_with_params(rng, i, t, f):
  # Binds the global `params`/`state` (read when traced) and keeps only the
  # output of `apply`, dropping the returned state.
  return loss_fn.apply(params, state, rng, i, t, f)[0]


def _forward_with_params(rng, i, t, f):
  # Same binding as above, for the forward (prediction) pass.
  return run_forward.apply(params, state, rng, i, t, f)[0]


loss_fn_jitted = jax.jit(_loss_with_params)
grads_fn_jitted = jax.jit(grads_fn)
run_forward_jitted = jax.jit(_forward_with_params)
# We also produce a pmapped version for running in parallel.
run_forward_pmap = xarray_jax.pmap(run_forward_jitted, dim="sample")

In [14]:
# @title Autoregressive rollout (loop in python)

# `Dataset.sizes` is the supported dim-name -> length mapping; accessing it via
# `Dataset.dims` is deprecated in recent xarray releases.
print("Inputs:  ", dict(eval_inputs.sizes))
print("Targets: ", dict(eval_targets.sizes))
print("Forcings:", dict(eval_forcings.sizes))

num_ensemble_members = 8 # @param int
rng = jax.random.PRNGKey(0)
# We fold in the ensemble member index; this way the first N members should
# always match across different runs that take the same inputs, regardless of
# total ensemble size.
rngs = np.stack(
    [jax.random.fold_in(rng, i) for i in range(num_ensemble_members)], axis=0)

chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
    # Use the pmapped version to parallelise across devices.
    predictor_fn=run_forward_pmap,
    rngs=rngs,
    inputs=eval_inputs,
    # All-NaN copy of the targets: presumably only its structure/coords are
    # used as the template for the predictions.
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    num_samples=num_ensemble_members,
    pmap_devices=jax.local_devices(),
):
  chunks.append(chunk)
predictions = xarray.combine_by_coords(chunks)

Inputs:   {'time': 2, 'level': 13, 'lat': 181, 'lon': 360, 'batch': 1}
Targets:  {'lat': 181, 'lon': 360, 'level': 13, 'time': 4, 'batch': 1}
Forcings: {'time': 4, 'batch': 1, 'lon': 360}


  num_target_steps = targets_template.dims["time"]
  self._set_arrayXarray(i, j, x)
  self._set_arrayXarray(i, j, x)
  num_inputs = prev_inputs.dims["time"]


In [None]:
!pip install --upgrade jax jaxlib


In [17]:
# Write the ensemble predictions from the rollout to a netCDF file.
PREDICTIONS_PATH = '/content/predictions.nc'
predictions.to_netcdf(PREDICTIONS_PATH)