In [1]:
# Turn off XLA's up-front GPU memory preallocation (see JAX "GPU memory allocation" docs)
%env XLA_PYTHON_CLIENT_PREALLOCATE=false

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


In [2]:
#!/usr/bin/env python
# coding: utf-8

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))

# Set JAX_TRACEBACK_FILTERING to off for detailed traceback
#os.environ['JAX_TRACEBACK_FILTERING'] = 'on'


# @title Imports
import dataclasses
import datetime
import functools
import math
import re
from typing import Optional
from glob import glob
import gc
import datetime

import cartopy.crs as ccrs
#from google.cloud import storage
from wofscast import autoregressive #_lam as autoregressive
from wofscast import casting
from wofscast import checkpoint
from wofscast import data_utils
from wofscast import graphcast_lam as graphcast
from wofscast import normalization
from wofscast import rollout
from wofscast import xarray_jax
from wofscast import xarray_tree
from wofscast.data_generator import (ZarrDataGenerator, 
                                     add_local_solar_time, 
                                     to_static_vars, 
                                     open_zarr,
                                     dataset_to_input
                                    )

from wofscast.wofscast_task_config import DBZ_TASK_CONFIG, WOFS_TASK_CONFIG
from wofscast.model import shard_xarray_dataset, replicate_for_devices

import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import xarray #as xr

from wofscast.utils import count_total_parameters, save_model_params, load_model_params 

# For training the weights!
import optax
import jax
import numpy as np
import jax.numpy as jnp

from jax import device_put
from jax import pmap, device_put, local_device_count
# Check available devices
print("Available GPU devices:", jax.devices())
from jax import tree_util

import time 
import wandb


Available GPU devices: [cuda(id=0), cuda(id=1)]


In [3]:
# Load the pretrained 36.7M-parameter GraphCast checkpoint (weights + model config).

name = 'params_GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz'
graphcast_path = os.path.join('/work/mflora/wofs-cast-data/graphcast_models', name)

# Read from the open binary handle instead of re-opening the file by path.
with open(graphcast_path, 'rb') as f:
    data = checkpoint.load(f, dict)


In [4]:
# Separate the pretrained weights from the config they were trained with.
# NOTE: `model_config` is redefined for the WoFSCast model in a later cell.
graphcast_params, model_config = data['params'], data['model_config']

In [5]:
def print_params(params, indent=0):
    """Pretty-print a nested parameter dict, showing the shape of each leaf.

    Nested dicts are printed as a ``name:`` header, and their children are
    indented by two extra spaces. Leaves must expose a ``.shape`` attribute.
    """
    pad = ' ' * indent
    for name, leaf in params.items():
        if not isinstance(leaf, dict):
            print(pad + f"{name}: {leaf.shape}")
            continue
        print(pad + f"{name}:")
        print_params(leaf, indent=indent + 2)

In [6]:
# ---- WoFSCast GraphCast hyperparameters ----
mesh_size = 5
latent_size = 512
gnn_msg_steps = 16
hidden_layers = 1
grid_to_mesh_node_dist = 5

task_config = WOFS_TASK_CONFIG

# Replaces the model_config read from the GraphCast checkpoint with the WoFSCast one.
model_config = graphcast.ModelConfig(
    resolution=0,
    mesh_size=mesh_size,
    latent_size=latent_size,
    gnn_msg_steps=gnn_msg_steps,
    hidden_layers=hidden_layers,
    grid_to_mesh_node_dist=grid_to_mesh_node_dist,
)

In [7]:
# Precomputed per-level normalization statistics consumed by normalization.InputsAndResiduals.
path = '/work/mflora/wofs-cast-data/full_normalization_stats'

mean_by_level, stddev_by_level, diffs_stddev_by_level = (
    xarray.load_dataset(os.path.join(path, fname))
    for fname in ('mean_by_level.nc', 'stddev_by_level.nc', 'diffs_stddev_by_level.nc')
)

norm_stats = {
    'mean_by_level': mean_by_level,
    'stddev_by_level': stddev_by_level,
    'diffs_stddev_by_level': diffs_stddev_by_level,
}

In [8]:
%%time
import os
from os.path import join
from concurrent.futures import ThreadPoolExecutor

# Root directory holding one sub-directory per year, each containing *.zarr stores.
base_path = '/work/mflora/wofs-cast-data/datasets_zarr'
# Years to load; append e.g. '2020' to include more data.
years = ['2019']

def get_files_for_year(year, base=None):
    """Return the sorted paths of every ``*.zarr`` directory for one year.

    Args:
      year: sub-directory name under ``base`` (e.g. ``'2019'``).
      base: root directory to scan; defaults to the notebook-level ``base_path``.

    Returns:
      List of ``<base>/<year>/<name>.zarr`` paths. Plain files ending in
      ``.zarr`` are skipped. The list is sorted because ``os.scandir`` yields
      entries in arbitrary order, and slices such as ``paths[:batch_size]``
      should select the same stores on every run.
    """
    year_path = join(base_path if base is None else base, year)
    with os.scandir(year_path) as entries:
        return sorted(
            join(year_path, entry.name)
            for entry in entries
            if entry.is_dir() and entry.name.endswith('.zarr')
        )
    
# Scan every requested year in parallel and flatten into one list of store paths
# (executor.map preserves the order of `years`).
with ThreadPoolExecutor() as executor:
    per_year = executor.map(get_files_for_year, years)
    paths = [store for year_stores in per_year for store in year_stores]

print(len(paths))

8730
CPU times: user 4.05 ms, sys: 7.25 ms, total: 11.3 ms
Wall time: 10.6 ms


In [9]:
%%time 

batch_size = 2

generator = ZarrDataGenerator(task_config,
                              cpu_batch_size=batch_size,
                              gpu_batch_size=batch_size, n_workers=16)

# Pull exactly one batch to use as the sample for model init and the test step.
# next() fails loudly if nothing is produced, so stale inputs/targets/forcings
# from an earlier kernel run can never be silently reused.
try:
    inputs, targets, forcings = next(iter(generator(paths[:batch_size])))
except StopIteration:
    raise RuntimeError(
        'ZarrDataGenerator produced no batch for paths[:batch_size]') from None
print('Batch : 0')
j = 1  # number of batches pulled (kept for parity with the former loop counter)

  self.pid = os.fork()
  self.pid = os.fork()


Batch : 0
CPU times: user 208 ms, sys: 336 ms, total: 544 ms
Wall time: 636 ms


In [10]:
import subprocess as sp
import os

def get_gpu_memory():
    """Print the full ``nvidia-smi`` report (per-GPU memory use and utilization)."""
    report = sp.check_output("nvidia-smi")
    print(report.decode('ascii'))

get_gpu_memory()

Fri Jun  7 16:23:47 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:21:00.0 Off |                    0 |
| N/A   37C    P0             67W /  300W |   62134MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          On  |   00

In [11]:
def construct_wrapped_graphcast(model_config: graphcast.ModelConfig, 
                                task_config: graphcast.TaskConfig,
                                norm_stats: dict
                               ):
    """Build the GraphCast predictor with all wrappers used for WoFSCast.

    Wrapping order, innermost first:
      1. ``graphcast.GraphCast``: the one-step predictor.
      2. ``casting.Bfloat16Cast``: float32 <-> BFloat16 conversion at the model boundary.
      3. ``normalization.InputsAndResiduals``: normalizes inputs and residual targets with
         ``norm_stats``, so the BFloat16 cast applies to normalized values.
      4. ``autoregressive.Predictor``: rolls the one-step model out into trajectories,
         with gradient checkpointing enabled.
    """
    one_step = graphcast.GraphCast(model_config, task_config)
    bf16_model = casting.Bfloat16Cast(one_step)
    normalized_model = normalization.InputsAndResiduals(
        bf16_model,
        diffs_stddev_by_level=norm_stats['diffs_stddev_by_level'],
        mean_by_level=norm_stats['mean_by_level'],
        stddev_by_level=norm_stats['stddev_by_level'],
    )
    return autoregressive.Predictor(normalized_model, gradient_checkpointing=True)

def train_step(params, 
               state, 
               opt_state, 
               optimizer, 
               inputs, 
               targets, 
               forcings, 
               model_config, 
               task_config, 
               norm_stats,
               clip_norm=32.0):
    """Run one optimization step with global-norm gradient clipping.

    Args:
      params, state: Haiku parameters and state for ``loss_fn``.
      opt_state: current optax optimizer state.
      optimizer: optax transformation; its ``update`` is called with ``params=``.
      inputs, targets, forcings: one training batch.
      model_config, task_config, norm_stats: forwarded to ``loss_fn``.
      clip_norm: if the global L2 norm of all gradients exceeds this value, every
        gradient leaf is rescaled by the same factor so the global norm equals
        ``clip_norm``.

    Returns:
      ``(new_params, new_opt_state)``. The loss value is not returned.
    """
    def compute_loss(params, state, inputs, targets, forcings):
        (loss, diagnostics), next_state = loss_fn.apply(params, state, 
                                                        jax.random.PRNGKey(0), 
                                                        model_config, 
                                                        task_config, norm_stats, 
                                                        inputs, targets, forcings)
        return loss, (diagnostics, next_state)

    (_loss, _aux), grads = jax.value_and_grad(compute_loss, has_aux=True)(
        params, state, inputs, targets, forcings)

    # For multi-device (pmap) training, average the gradients first:
    # grads = jax.lax.pmean(grads, axis_name='devices')

    # Global-norm clipping: compute one shared scale factor for every leaf.
    total_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in tree_util.tree_leaves(grads)))
    scale = jnp.where(total_norm > clip_norm, clip_norm / total_norm, 1.0)
    clipped_grads = tree_util.tree_map(lambda g: (g * scale).astype(g.dtype), grads)

    # The optimizer must consume the *clipped* gradients.
    updates, new_opt_state = optimizer.update(clipped_grads, opt_state, params=params)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state

# Forward pass used to initialize parameters, make predictions on new data, and roll out.
@hk.transform_with_state
def run_forward(model_config, task_config, norm_stats, inputs, targets_template, forcings):
    """Apply the fully wrapped predictor and return its predictions."""
    return construct_wrapped_graphcast(model_config, task_config, norm_stats)(
        inputs, targets_template, forcings)

@hk.transform_with_state
def loss_fn(model_config, task_config, norm_stats, inputs, targets_template, forcings):
    """Return the mean training loss and mean diagnostics for one batch.

    ``targets_template`` carries the actual target values here; it is passed as
    the targets to ``predictor.loss``. Each returned xarray value is averaged and
    unwrapped to a JAX array.
    """
    predictor = construct_wrapped_graphcast(model_config, task_config, norm_stats)
    # Use the argument, not the notebook-global `targets`. Under jax.jit a global
    # would be captured as a compile-time constant (the first batch's targets)
    # rather than the batch actually passed in.
    loss, diagnostics = predictor.loss(inputs, targets_template, forcings)
    return xarray_tree.map_structure(
        lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
        (loss, diagnostics))

def with_configs(fn):
    """Return ``fn`` with the notebook's model/task configs and norm stats pre-bound as keywords."""
    bound_kwargs = dict(model_config=model_config,
                        task_config=task_config,
                        norm_stats=norm_stats)
    return functools.partial(fn, **bound_kwargs)

def with_optimizer(fn, optimizer):
    """Return ``fn`` with ``optimizer`` pre-bound as a keyword argument."""
    bound = functools.partial(fn, optimizer=optimizer)
    return bound

# Pre-bind params and state so call sites below stay short.
def with_model_params(fn):
    """Return ``fn`` with the notebook-global ``model_params``/``state`` bound as ``params``/``state``.

    The globals are read when this wrapper is created, not when the result is called.
    """
    bound_kwargs = dict(params=model_params, state=state)
    return functools.partial(fn, **bound_kwargs)

# The models here are stateless (empty Haiku state), so the rollout code only needs
# the predictions; this adapter discards the state element.
def drop_state(fn):
    """Wrap ``fn`` (returning ``(output, state)``) so callers receive only ``output``.

    The wrapper accepts keyword arguments only.
    """
    def _output_only(**kwargs):
        return fn(**kwargs)[0]
    return _output_only

# Initialize parameters/state by tracing the forward pass on the sample batch.
init_jitted = jax.jit(with_configs(run_forward.init))
model_params, state = init_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=inputs,
    targets_template=targets,
    forcings=forcings,
)

# AdamW: lr 1e-4, betas (0.9, 0.95), eps 1e-8, weight decay 0.1.
optimizer = optax.adamw(1e-4, b1=0.9, b2=0.95, eps=1e-8, weight_decay=0.1)
opt_state = optimizer.init(model_params)

train_step_jitted = jax.jit(with_optimizer(with_configs(train_step), optimizer))

count = count_total_parameters(model_params)
print(f'Number of Model Parameters: {count}')



Number of Model Parameters: 36015209


In [12]:
def update_params_with_graphcast(model_params, graphcast_params):
    """Copy pretrained GraphCast weights into ``model_params`` wherever they fit.

    Walks both nested dicts in parallel. A leaf in ``model_params`` is replaced
    by the GraphCast value only when:
      * the same key exists at the same position in ``graphcast_params``,
      * neither side is a dict at that key (structure mismatches are skipped
        instead of raising AttributeError), and
      * both leaves expose equal ``.shape`` values.
    All other entries keep their current values.

    ``model_params`` is modified in place, and the same object is returned.
    """
    def _transfer(model_tree, source_tree):
        for key, value in model_tree.items():
            if key not in source_tree:
                continue
            source = source_tree[key]
            model_is_dict = isinstance(value, dict)
            source_is_dict = isinstance(source, dict)
            if model_is_dict and source_is_dict:
                _transfer(value, source)
            elif not model_is_dict and not source_is_dict:
                model_shape = getattr(value, 'shape', None)
                if model_shape is not None and model_shape == getattr(source, 'shape', None):
                    # Overwriting an existing key's value is safe while iterating.
                    model_tree[key] = source

    _transfer(model_params, graphcast_params)
    return model_params

# Seed the WoFSCast model with every GraphCast weight whose shape matches.
updated_model_params = update_params_with_graphcast(model_params=model_params,
                                                    graphcast_params=graphcast_params)


In [13]:
# GPU memory snapshot after parameter initialization and the GraphCast weight transfer.
get_gpu_memory()

Fri Jun  7 16:23:54 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:21:00.0 Off |                    0 |
| N/A   37C    P0             69W /  300W |   62642MiB /  81920MiB |      1%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          On  |   00

In [14]:
#%load_ext memory_profiler

In [15]:
# Table row: batch size | parameter count | latent size | GNN message-passing steps | hidden layers
summary_row = f'| {batch_size}  |  {count:.3e}  |  {latent_size}  | {gnn_msg_steps}    |   {hidden_layers}  '
print(summary_row)

| 2  |  3.602e+07  |  512  | 16    |   1  


In [16]:
%%time

# One jitted optimization step on the sample batch, starting from the GraphCast-seeded
# weights, followed by a GPU memory snapshot.
# NOTE: re-running this cell reuses `updated_model_params` but the already-advanced `opt_state`.
model_params, opt_state = train_step_jitted(
    inputs=inputs,
    targets=targets,
    forcings=forcings,
    params=updated_model_params,
    state=state,
    opt_state=opt_state,
)
get_gpu_memory()


Fri Jun  7 16:24:15 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:21:00.0 Off |                    0 |
| N/A   38C    P0             67W /  300W |   74478MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          On  |   00

# GenCast 

In [17]:
def add_noise(values: xarray.Dataset,
              noise: xarray.Dataset,
              ) -> xarray.Dataset:
    """Add per-variable noise to ``values``.

    For every named variable in ``values`` that also appears in ``noise``, the
    matching noise field (cast to the variable's dtype) is added. Variables
    without a matching noise entry are returned unchanged, and a warning is logged.

    Raises:
      ValueError: if a variable in ``values`` has no name, so its noise field
        cannot be looked up.
    """
    # `logging` is not imported at notebook level; it is needed for the warning path.
    import logging

    def add_noise_to_array(array):
        if array.name is None:
            raise ValueError(
                "Can't look up noise because array has no name.")
        if array.name not in noise:
            logging.warning('No noise field found for %s', array.name)
            return array
        return array + noise[array.name].astype(array.dtype)

    return xarray_tree.map_structure(add_noise_to_array, values)

In [18]:
import jax
import jax.numpy as jnp
import numpy as np
import xarray as xr

def get_random_dataset(dataset, scale, rng=None, dim='time'):
    """Return a copy of ``dataset`` in which every ``dim``-dependent variable is Gaussian noise.

    Args:
      dataset: source dataset; it is shallow-copied and not modified.
      scale: standard deviation of the zero-mean normal noise.
      rng: optional ``numpy.random.Generator`` for reproducible draws. When None,
        the legacy global ``np.random`` state is used.
      dim: only data variables that have this dimension are replaced; all other
        variables are kept as-is.

    Returns:
      The new dataset. Floating-point variables keep their original dtype (for
      example float32 is not silently upcast to float64). Non-float variables
      receive the raw float64 draws.
    """
    sampler = np.random if rng is None else rng
    new_dataset = dataset.copy()

    for var_name in new_dataset.data_vars:
        var = new_dataset[var_name]
        if dim not in var.dims:
            continue
        noise = sampler.normal(loc=0.0, scale=scale, size=var.shape)
        if np.issubdtype(var.dtype, np.floating):
            noise = noise.astype(var.dtype)
        new_dataset[var_name].data = noise

    return new_dataset



class GenCastSampler:
    """Stochastic EDM-style sampler (Karras et al. 2022, Algorithm 2) for a GenCast-like denoiser.

    NOTE(review): work in progress. The denoiser call does not yet receive the
    noisy state (``x_hat`` / ``x_next``), because the interface for passing it to
    ``gencast`` is not visible in this notebook. TODO wire it in.

    Args:
      gencast: callable denoiser, invoked as ``gencast(params=..., state=..., rng=...,
        inputs=..., targets_template=..., forcings=..., sigma=...)``. Its return value
        is combined arithmetically with the sampled datasets.
      num_steps: number of nonzero noise levels in the sampling schedule.
      sigma_min, sigma_max, rho: noise-schedule parameters (see ``inverse_cdf``).
      S_churn, S_min, S_max, S_noise: stochastic "churn" settings. Extra noise is
        injected only while ``S_min <= sigma <= S_max``.
    """

    def __init__(self, gencast, num_steps=18, 
                 sigma_min=0.03, sigma_max=80, rho=7,
                 S_churn=2.5, S_min=0.75, S_max=80, S_noise=1.05):
        self.gencast = gencast 
        self.num_steps = num_steps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho 
        self.S_churn = S_churn 
        self.S_min = S_min 
        self.S_max = S_max 
        self.S_noise = S_noise 

    def inverse_cdf(self, u):
        """Map ``u`` in [0, 1] to a noise level: u=0 -> sigma_max, u=1 -> sigma_min.

        Interpolates linearly in ``sigma ** (1 / rho)`` space. Works elementwise on arrays.
        """
        exp = 1 / self.rho
        return (self.sigma_max ** exp + u * (self.sigma_min ** exp - self.sigma_max ** exp)) ** self.rho

    def noise_schedule(self):
        """Return the decreasing levels t_0 > ... > t_{N-1} followed by t_N = 0.

        The levels are evenly spaced in ``u`` (``u_i = i / (N - 1)``) and mapped
        through ``inverse_cdf``, so t_0 = sigma_max and t_{N-1} = sigma_min.
        Drawing ``u`` at random is suitable for sampling training noise levels,
        but it yields an unordered sequence that cannot be used as a sampling sweep.
        """
        u = np.arange(self.num_steps) / max(self.num_steps - 1, 1)
        return np.append(self.inverse_cdf(u), 0.0)

    def sample(self, 
               inputs, 
               targets_template, 
               forcings,
               params=None,
               model_state=None,
               rng=None,
              ):
        """Draw one sample by integrating from ``sigma_max`` noise down to 0.

        Args:
          inputs, forcings: conditioning data, passed straight to ``gencast``.
          targets_template: dataset whose layout the sample takes; also passed to ``gencast``.
          params, model_state: denoiser weights and state. They default to the
            notebook-global ``model_params`` and ``state``.
          rng: JAX PRNG key for the denoiser; defaults to ``jax.random.PRNGKey(0)``.

        Returns:
          The final ``x_next`` dataset.
        """
        if params is None:
            params = model_params
        if model_state is None:
            model_state = state
        if rng is None:
            rng = jax.random.PRNGKey(0)

        def denoise(sigma):
            # TODO(review): also pass the current noisy sample once the denoiser accepts it.
            return self.gencast(params=params,
                                state=model_state,
                                rng=rng,
                                inputs=inputs, targets_template=targets_template,
                                forcings=forcings, sigma=sigma)

        t_steps = self.noise_schedule()

        # x_0 ~ N(0, t_0^2) on the layout of the targets.
        x_next = get_random_dataset(targets_template, scale=t_steps[0])

        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            x_cur = x_next

            # Temporarily raise the noise level ("churn") while inside [S_min, S_max].
            if self.S_min <= t_cur <= self.S_max:
                gamma = min(self.S_churn / self.num_steps, np.sqrt(2) - 1)
            else:
                gamma = 0.0
            t_hat = t_cur + gamma * t_cur
            eps = get_random_dataset(targets_template, scale=self.S_noise)
            x_hat = x_cur + np.sqrt(t_hat ** 2 - t_cur ** 2) * eps

            # Euler step from t_hat to t_next, with the denoiser evaluated at the current level.
            denoised = denoise(t_hat)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Heun (second-order) correction; skipped on the final step, where t_next == 0.
            if t_next > 0:
                denoised = denoise(t_next)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        return x_next


In [19]:
import gc

from numba import cuda
# NOTE(review): this closes the CUDA context on device 0 through numba to release GPU memory.
# JAX arrays and compiled functions created earlier in this kernel presumably become unusable
# afterwards, so restart the kernel before running more JAX cells.
# Only device 0 is selected here. TODO: confirm whether device 1 also needs releasing.
cuda.select_device(0)
cuda.close()
# The value shown as the cell output is gc.collect()'s return value.
gc.collect()

14