[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AndReGeist/wandb_cluster_neuralode/blob/main/colab/basic_example.ipynb#scrollTo=qXAp1blGadTl)

# Basic example: Use "Weights & Biases" with JAX

In this notebook, we track the training process of a neural ODE regression using [Weights & Biases](https://wandb.ai/site) (wandb).

⚠️ The NeuralODE code below is adapted from the example in the [Diffrax documentation](https://docs.kidger.site/diffrax/examples/neural_ode/). If you use Diffrax in your academic research, please consider [citing the library](https://docs.kidger.site/diffrax/).

The code below was inspired by these blog posts:
- [Weights and biases: Quickstart](https://docs.wandb.ai/quickstart)
- [A complete Weights and Biases tutorial](https://theaisummer.com/weights-and-biases-tutorial/)
- [Writing a training loop in JAX and Flax](https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-and-Flax--VmlldzoyMzA4ODEy)

In [206]:
!pip install -q equinox diffrax optax

In [207]:
import time

import diffrax  # https://docs.kidger.site/diffrax/
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

We use...
- JAX, providing linear algebra with automatic differentiation and GPU acceleration
- Equinox to build neural networks
- Optax for optimisers
- Diffrax for ODE solvers

Define a neural network (an MLP) that models the right-hand side of the ODE...

In [208]:
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

Wrap the ODE solver into a model...

In [209]:
class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

Toy dataset of nonlinear oscillators. Sample paths look like deformed sines and cosines.

In [210]:
def _get_data(ts, *, key):
    y0 = jrandom.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jrandom.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys

In [211]:
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

Before we set up the dataset creation and the main training loop, we install the Weights & Biases library and log in with our account...

In [212]:
!pip install -q wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mrene-geist[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [213]:
import wandb
import pickle
import os

Datasets (as well as models) can be stored in wandb as "artifacts". An [artifact](https://docs.wandb.ai/guides/data-and-model-versioning) is essentially a folder of files on the wandb server. Wandb checksums the artifact's contents to detect changes and track new versions.
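
To illustrate the idea of change detection by checksum, here is a minimal sketch using only the standard library. It is not wandb's implementation: the choice of SHA-256 and the helper name `file_digest` are assumptions made for this example.

```python
import hashlib
import os
import tempfile


def file_digest(path):
    """Return the SHA-256 hex digest of a file's contents."""
    with open(path, "rb") as handle:
        return hashlib.sha256(handle.read()).hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.bin")

    with open(path, "wb") as handle:
        handle.write(b"version one")
    first = file_digest(path)

    # Writing the same bytes again leaves the digest unchanged
    with open(path, "wb") as handle:
        handle.write(b"version one")
    unchanged = file_digest(path)

    # Any change to the contents yields a different digest
    with open(path, "wb") as handle:
        handle.write(b"version two")
    changed = file_digest(path)

print(first == unchanged)
print(first == changed)
```

The practical consequence is that a new artifact version is tied to the file contents changing, not to how often the logging code runs.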

In [214]:
def create_dataset(seed=42, dataset_size=256):

    # Initialize a new W&B run to track this job
    run = wandb.init(project="basic_example", job_type="dataset-creation")

    key = jrandom.PRNGKey(seed)
    ts, ys = get_data(dataset_size, key=key)
    _, length_size, data_size = ys.shape

    data = {"ts": ts, 
            "ys": ys, 
            "length_size": length_size, 
            "data_size": data_size}

    with open('data.pickle', 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    dataset = wandb.Artifact('my-dataset', type='dataset')
    dataset.add_file("data.pickle")
    run.log_artifact(dataset)

    # ============= wandb - finish session ============= #
    wandb.finish()
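
Reading the file back is the mirror image of the `pickle.dump` call above. The following is a minimal sketch that uses plain Python lists as stand-ins for the JAX arrays, so it runs without extra packages. The placeholder values are made up for illustration.

```python
import os
import pickle
import tempfile

# Stand-in for the dictionary built in create_dataset(); the values are placeholders
data = {
    "ts": [0.0, 0.5, 1.0],
    "ys": [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]],
    "length_size": 3,
    "data_size": 2,
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.pickle")
    with open(path, "wb") as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(path, "rb") as handle:
        loaded = pickle.load(handle)

# Unpack by key, so each name is an ordinary, visible assignment
ts, ys = loaded["ts"], loaded["ys"]
length_size, data_size = loaded["length_size"], loaded["data_size"]
print(length_size, data_size)
```

Unpacking by key keeps the variable names explicit and makes a missing entry fail with a clear `KeyError`.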

In the **main training loop**, we use wandb to...

1. load the dataset
2. log the optimization configuration
3. log optimization data
4. log model gradients

In [215]:
def main(
        # Assumed artifact path "<project>/<name>:<alias>"; the dataset was
        # logged as 'my-dataset' in the project "basic_example"
        data_artifact_path="basic_example/my-dataset:latest",
        batch_size=32,
        lr_strategy=(3e-3, 3e-3),
        steps_strategy=(500, 500),
        length_strategy=(0.1, 1),
        width_size=64,
        depth=2,
        seed=5678,
        plot=True,
        print_every=100
):

    # ============= wandb - log optimization configurations ============= #
    # Define a config dictionary object
    config = {
      "batch_size": batch_size,
      "lr_strategy": lr_strategy,
      "steps_strategy": steps_strategy,
      "length_strategy": length_strategy,
      "width_size": width_size,
      "depth": depth,
      "seed": seed,
      "print_every": print_every,
      "data_artifact_path": data_artifact_path
    }

    run = wandb.init(
              project="wandb_cluster_neuralode",
              job_type="basic_example",
              config=config
    ) 
    # You can explicitly state to which team wandb will save data by adding the option entity='<Team name>'

    key = jrandom.PRNGKey(seed)
    model_key, loader_key = jrandom.split(key, 2)

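    # ============= wandb - load data set "artifact" ============= #
    artifact = run.use_artifact(data_artifact_path, type='dataset')
    artifact_dir = artifact.download()
    
    # create_dataset() added the file to the artifact as "data.pickle"
    data_path = os.path.join(artifact_dir, 'data.pickle')
    with open(data_path, 'rb') as handle:
      data = pickle.load(handle)

    # Unpack explicitly: exec() cannot create local variables inside a function
    ts, ys = data["ts"], data["ys"]
    length_size, data_size = data["length_size"], data["data_size"]
    run.config.update({"dataset_size": ys.shape[0]})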

    model = NeuralODE(data_size, width_size, depth, key=model_key)


    # ============= Training loop ============= #
    # Until step 500 we train on only the first 10% of each time series.
    # This is a standard trick to avoid getting caught in a local minimum.

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
                range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()

            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

                # ============= wandb - log optimization data ============= #
                wandb.log({
                  'step': step, 
                  'loss': loss, 
                  'computation time': end - start
                })

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        model_y = model(ts, ys[0, 0])
        plt.plot(ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    # ============= wandb - finish session ============= #
    wandb.finish()

    return ts, ys, model
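
The outer loop of `main` runs one training phase per entry of the three strategy tuples. The sketch below reproduces only that bookkeeping in plain Python, assuming `length_size=100` (the number of time stamps produced by `get_data`). `training_phases` and the key `first_global_step` are hypothetical names introduced for illustration.

```python
def training_phases(lr_strategy, steps_strategy, length_strategy, length_size):
    """List the learning rate, step count and sequence length of each phase."""
    phases = []
    first_global_step = 0
    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        phases.append({
            "lr": lr,
            "steps": steps,
            # same truncation as `ts[: int(length_size * length)]` in main()
            "n_timestamps": int(length_size * length),
            "first_global_step": first_global_step,
        })
        first_global_step += steps
    return phases


phases = training_phases((3e-3, 3e-3), (500, 500), (0.1, 1), length_size=100)
for phase in phases:
    print(phase)
```

With the defaults, the first 500 steps see only the first 10 time stamps of each trajectory and the next 500 steps see all 100. Note that `step` in `main` restarts at 0 in every phase, so the logged `'step'` values of both phases overlap. Adding an offset such as `first_global_step` would give a single increasing counter.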

In [216]:
create_dataset()
#ts, ys, model = main(seed=42)


Problem at: <ipython-input-214-0700f9833e43> 4 create_dataset


Traceback (most recent call last):
  File "/usr/local/lib/python3.9/dist-packages/wandb/sdk/wandb_init.py", line 1144, in init
    run = wi.init()
  File "/usr/local/lib/python3.9/dist-packages/wandb/sdk/wandb_init.py", line 773, in init
    raise error
wandb.errors.CommError: Error communicating with wandb process, exiting...
For more info see: https://docs.wandb.ai/guides/track/tracking-faq#initstarterror-error-communicating-with-wandb-process-
[34m[1mwandb[0m: [32m[41mERROR[0m Abnormal program exit


Exception: ignored