[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/AndReGeist/wandb_cluster_neuralode/blob/main/colab/basic_example.ipynb#scrollTo=qXAp1blGadTl)

# Basic example: Using Weights & Biases with JAX

In this notebook, we track the training of a neural ODE regression model using [Weights & Biases](https://wandb.ai/site).

⚠️ The NeuralODE code below is adapted from the example in the [Diffrax documentation](https://docs.kidger.site/diffrax/examples/neural_ode/). If you use Diffrax in your academic research, please consider [citing the library](https://docs.kidger.site/diffrax/).

The code below was inspired by the following guides and blog posts:
- [Weights and biases: Quickstart](https://docs.wandb.ai/quickstart)
- [A complete Weights and Biases tutorial](https://theaisummer.com/weights-and-biases-tutorial/)
- [Writing a training loop in JAX and Flax](https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-and-Flax--VmlldzoyMzA4ODEy)

In [79]:
!pip install -q equinox diffrax optax

In [80]:
import time

import diffrax  # https://docs.kidger.site/diffrax/
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import numpy as np
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

## Setting up dynamics and NeuralODE
We use...
- JAX, providing linear algebra with automatic differentiation and GPU acceleration
- Equinox to build neural networks
- Optax for optimisers
- Diffrax for ODE solvers

Define a neural network that models the right-hand side of an ODE...

In [81]:
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

Wrap the ODE solver into a model...
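
As a sketch of what the model in the next cell computes (the symbols $f_\theta$ and $\hat{y}$ are notation introduced here, not names from the code): the MLP $f_\theta$ defines a vector field, and the solver returns the solution of the resulting initial value problem at the requested times $t_i$,

$$
\frac{\mathrm{d}\hat{y}}{\mathrm{d}t}(t) = f_\theta\big(\hat{y}(t)\big), \qquad \hat{y}(t_0) = y_0, \qquad \hat{y}(t_i) = y_0 + \int_{t_0}^{t_i} f_\theta\big(\hat{y}(s)\big)\,\mathrm{d}s .
$$

`Func.__call__` ignores `t` and `args`, which is why $f_\theta$ depends on the state only. The integral is approximated numerically by `Tsit5` with an adaptive step size.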

In [82]:
class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

Toy dataset of nonlinear oscillators. Sample paths look like deformed sines and cosines.
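
Written out, the vector field `f` in the cell below is as follows, with $y = (y_1, y_2)$ (this notation is introduced here for illustration):

$$
\dot{y}_1 = \frac{y_2}{1 + y_2}, \qquad \dot{y}_2 = -\frac{y_1}{1 + y_1}, \qquad y(0) \sim \mathcal{U}(-0.6,\, 1)^2 .
$$

Each trajectory is saved at 100 equally spaced times in $[0, 10]$.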

In [83]:
def _get_data(ts, *, key):
    y0 = jrandom.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jrandom.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys

In [84]:
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
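
One detail of this loader is easy to miss: because the inner loop uses the strict comparison `end < dataset_size`, the last slice of each permutation is never yielded, even when the dataset size is an exact multiple of the batch size. The sketch below reproduces only the index arithmetic in plain Python to count the batches produced per pass; the helper name `batch_slices` is hypothetical and not part of the notebook.

```python
def batch_slices(dataset_size, batch_size):
    """Return the (start, end) index pairs that `dataloader` yields in one pass."""
    slices = []
    start, end = 0, batch_size
    while end < dataset_size:  # same strict comparison as in `dataloader`
        slices.append((start, end))
        start = end
        end = start + batch_size
    return slices


slices = batch_slices(dataset_size=256, batch_size=32)
print(len(slices))  # number of batches per pass
print(slices[-1])   # last slice that is actually used
```

With the 256 training samples and batch size 32 used later in this notebook, each pass yields 7 batches and skips 32 entries of the permutation. Since the permutation is reshuffled on every pass, all samples are still visited over time.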

## Set up wandb
Before setting up the dataset and the training loop, we install the Weights & Biases library and log in to our account...



In [85]:
!pip install -q wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mrene-geist[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [86]:
import wandb
import pickle
import os

## Create data and log with wandb
Datasets (as well as models) can be stored in wandb as [artifacts](https://docs.wandb.ai/guides/data-and-model-versioning), which are simply folders of files on the wandb server. This [Google Colab](https://colab.research.google.com/drive/1GM22vkt1BXm3JVpTX8QeG1E2Rwl7xr_0?usp=sharing) illustrates working with artifacts. Wandb checksums the contents of an artifact to detect changes and track new versions.

Artifacts are useful for...
- [data versioning](https://docs.wandb.ai/guides/data-and-model-versioning/dataset-versioning)
- [model versioning](https://docs.wandb.ai/guides/data-and-model-versioning/model-versioning)
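
To make the checksum idea concrete, here is a minimal sketch, using only the standard library, of how hashing file contents can tell whether a file has changed. It illustrates the principle only and makes no claim about how W&B implements it; the helper name `sha256_of_file` and the file name `data.bin` are hypothetical.

```python
import hashlib
import os
import tempfile


def sha256_of_file(path, chunk_size=65536):
    """Return the SHA-256 hex digest of a file's contents."""
    h = hashlib.sha256()
    with open(path, 'rb') as handle:
        # Read in chunks so large files do not need to fit in memory.
        for chunk in iter(lambda: handle.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'data.bin')

    with open(path, 'wb') as handle:
        handle.write(b'version one')
    digest_a = sha256_of_file(path)

    # Rewriting identical bytes leaves the digest unchanged ...
    with open(path, 'wb') as handle:
        handle.write(b'version one')
    digest_b = sha256_of_file(path)

    # ... while changing the contents produces a different digest.
    with open(path, 'wb') as handle:
        handle.write(b'version two')
    digest_c = sha256_of_file(path)

print(digest_a == digest_b)  # True
print(digest_a == digest_c)  # False
```

A versioning system built on this idea only needs to create a new version when the digest differs from the one stored for the previous version.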

In [87]:
def create_dataset(seed=42, size_train=256, size_test=256):

    # ==== Initialize W&B run to track job ==== #
    config = {
        'size_train': size_train,
        'size_test': size_test,
        'seed': seed
    }
    run = wandb.init(project='wandb_cluster_neuralode', job_type='dataset-creation', config=config)

    key = jrandom.PRNGKey(seed)
    key_train, key_test = jrandom.split(key, 2)
    ts, ys = get_data(size_train, key=key_train)
    ts_test, ys_test = get_data(size_test, key=key_test)
    _, length_size, data_size = ys.shape

    data = {'ts': ts, 
            'ys': ys, 
            'length_size': length_size, 
            'data_size': data_size,
            'ts_test': ts_test, 
            'ys_test': ys_test}

    with open('data.pickle', 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    dataset = wandb.Artifact('my-dataset', type='dataset')  # Create new artifact
    dataset.add_file('data.pickle')  # Add files to artifact
    run.log_artifact(dataset)  # Log artifact to save it as an output of this run

    wandb.finish()  # Finish session

## Training loop
In the training loop, we use wandb to...

1. Load the dataset
2. Log the optimization configuration
3. Log optimization data
4. Optionally log the model's arrays and the optimizer updates (gradient logging is present in the code but commented out)

In [88]:
def main(        
        batch_size=32,
        lr_strategy=(3e-3, 3e-3),
        steps_strategy=(500, 500),
        length_strategy=(0.1, 1),
        width_size=64,
        depth=2,
        seed=5678,
        plot=False,
        print_every=100,
        watch_run = False
):

    # Define the config dictionary object
    config = {
        'batch_size': batch_size,
        'lr_strategy': lr_strategy,
        'steps_strategy': steps_strategy,
        'length_strategy': length_strategy,
        'width_size': width_size,
        'depth': depth,
        'seed': seed,
        'print_every': print_every
    }

    # ==== Initialize W&B run to track job ==== #
    run = wandb.init(
              project='wandb_cluster_neuralode',
              job_type='basic_example',
              config=config
    ) 
    # You can explicitly state to which team wandb will save data by adding the option entity='<Team name>'
    
    # When using sweep the default config gets overwritten
    config = wandb.config 


    # ==== W&B - load data artifact ==== #
    artifact = run.use_artifact('my-dataset' + ':latest')
    artifact_dir = artifact.download()
    
    data_path = os.path.join(artifact_dir, 'data.pickle')
    with open(data_path, 'rb') as handle:
      data = pickle.load(handle)

    # ==== JAX - init model ==== #
    key = jrandom.PRNGKey(config.seed)
    model_key, loader_key = jrandom.split(key, 2)
    model = NeuralODE(data['data_size'], config.width_size, config.depth, key=model_key)


    # === Training loop === #
    # Until step 500 we train on only the first 10% of each time series.
    # This is a standard trick to avoid getting caught in a local minimum.

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state, grads, updates

    def pytree_leaves_to_dict(pytree_obj, name='array'):
      # W&B - To log the arrays of a JAX pytree, we extract the array leaves,
      # convert them to NumPy arrays, and collect them in a new dict.
      # The keys are just the prefix plus the leaf index; more informative names are still to do.
      leaves = jax.tree_util.tree_leaves(pytree_obj)
      arrays = {}
      for k, leaf in enumerate(leaves):
          if isinstance(leaf, jnp.ndarray):
              arrays[name + str(k)] = np.array(leaf)
      return arrays

    for lr, steps, length in zip(config.lr_strategy, config.steps_strategy, config.length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = data['ts'][: int(data['length_size'] * length)]
        _ys = data['ys'][:, : int(data['length_size'] * length)]
        for step, (yi,) in zip(
                range(steps), dataloader((_ys,), config.batch_size, key=loader_key)
        ):
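            # Note: JAX dispatches work asynchronously, so this wall-clock time can
            # under-report the true step time; at step 0 it also includes JIT compilation.
            start = time.time()
            loss_train, model, opt_state, grads, updates = make_step(_ts, yi, model, opt_state)
            end = time.time()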

            if (step % config.print_every) == 0 or step == steps - 1:
                print(f'Step: {step}, Loss: {loss_train}, Computation time: {end - start}')

                # Test model
                loss_test, _ = grad_loss(model, data['ts_test'], data['ys_test'])

                # === W&B - log optimization === #
                log_dict = {
                  'step': step, 
                  'loss_train': loss_train,
                  'loss_test': loss_test, 
                  'computation time': end - start
                }
              
                if watch_run:
                    log_dict.update( pytree_leaves_to_dict(model, name='model_array') )
                    #log_dict.update( pytree_leaves_to_dict(grads, name='grad_array') )
                    log_dict.update( pytree_leaves_to_dict(updates, name='updates_array') )

                wandb.log( log_dict )
          
    if plot:
        plt.plot(data['ts'], data['ys'][0, :, 0], c='dodgerblue', label='Real')
        plt.plot(data['ts'], data['ys'][0, :, 1], c='dodgerblue')
        model_y = model(data['ts'], data['ys'][0, 0])
        plt.plot(data['ts'], model_y[:, 0], c='crimson', label='Model')
        plt.plot(data['ts'], model_y[:, 1], c='crimson')
        plt.legend()
        plt.tight_layout()
        plt.savefig('neural_ode.png')
        plt.show()

    wandb.finish()  # Finish W&B session
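
The `length_strategy` schedule used in `main` can be checked in isolation. The sketch below repeats only the slicing arithmetic from the training loop in plain Python, assuming the 100 time points per trajectory produced by `get_data`; the helper name `training_phases` is hypothetical.

```python
def training_phases(length_size, lr_strategy, steps_strategy, length_strategy):
    """List (learning rate, steps, number of time points used) for each phase."""
    phases = []
    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        n_points = int(length_size * length)  # same truncation as in `main`
        phases.append((lr, steps, n_points))
    return phases


phases = training_phases(
    length_size=100,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
)
for lr, steps, n_points in phases:
    print(f'lr={lr}, steps={steps}, time points per trajectory={n_points}')
```

With the default arguments, the first 500 steps see only the first 10 time points of each trajectory and the next 500 steps see all 100. The optimizer is re-created and the step counter restarts at the beginning of each phase, which is why the training log shows `Step: 0` twice.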

In [89]:
#create_dataset(seed=42)
main(seed=99, watch_run=False)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.11333475261926651, Computation time: 26.88011407852173
Step: 100, Loss: 0.007018388248980045, Computation time: 0.004915952682495117
Step: 200, Loss: 0.007753062527626753, Computation time: 0.004452705383300781
Step: 300, Loss: 0.003183433087542653, Computation time: 0.006302833557128906
Step: 400, Loss: 0.0008427510038018227, Computation time: 0.005587577819824219
Step: 499, Loss: 0.0011529180919751525, Computation time: 0.005561113357543945
Step: 0, Loss: 0.03783680498600006, Computation time: 17.338109493255615
Step: 100, Loss: 0.007608955726027489, Computation time: 0.03306317329406738
Step: 200, Loss: 0.008614838123321533, Computation time: 0.034104347229003906
Step: 300, Loss: 0.0058342753909528255, Computation time: 0.03450608253479004
Step: 400, Loss: 0.003502139588817954, Computation time: 0.03551506996154785
Step: 499, Loss: 0.009178807027637959, Computation time: 0.036896705627441406


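| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁▆▁▁▁▁▁ |
| loss_test | ▁▁▁▁▁▁█▁▁▁▁▁ |
| loss_train | █▁▁▁▁▁▃▁▁▁▁▂ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.0369 |
| loss_test | 0.00371 |
| loss_train | 0.00918 |
| step | 499.0 |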


## Automate hyperparameter search using Weights & Biases sweeps
Use Weights & Biases [Sweeps](https://docs.wandb.ai/guides/sweeps) to automate hyperparameter search and explore the space of possible models. Pick from popular search methods such as Bayesian, grid, and random search to explore the hyperparameter space. Sweep jobs can be scaled and parallelized across one or more machines.



In [90]:
# Define the search space
# You can specify...
# a range 'x': {'max': 0.1, 'min': 0.01},
# or values 'y': {'values': [1, 3, 7]},
sweep_configuration = {
    'method': 'random',
    'metric': {'goal': 'minimize', 'name': 'loss_test'},
    'parameters': 
    {       
        'batch_size': {'values': [32]}, # 32
        'lr_strategy': {'values': [(3e-3, 3e-3)]}, # (3e-3, 3e-3)
        'steps_strategy': {'values': [(500, 500)]}, # (500, 500)
        'length_strategy': {'values': [(0.1, 1)]}, # (0.1, 1)
        'width_size': {'values': [20, 64, 150]}, # 64
        'depth': {'values': [1, 2, 3]}, # 2
        'seed': {'values': [42]},
        'print_every': {'values': [100]}  # 100
     }
}

# Start the sweep
sweep_id = wandb.sweep(sweep=sweep_configuration, project='wandb_cluster_neuralode')
wandb.agent(sweep_id, function=main, count=10)

Create sweep with ID: us3tn8wt
Sweep URL: https://wandb.ai/rene-geist/wandb_cluster_neuralode/sweeps/us3tn8wt


[34m[1mwandb[0m: Agent Starting Run: z4tihwx3 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 3
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 20


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.07296563684940338, Computation time: 27.024245500564575
Step: 100, Loss: 0.013954832218587399, Computation time: 0.0031998157501220703
Step: 200, Loss: 0.004157245624810457, Computation time: 0.005562782287597656
Step: 300, Loss: 0.0006523383199237287, Computation time: 0.0037970542907714844
Step: 400, Loss: 0.00020883390970993787, Computation time: 0.003793954849243164
Step: 499, Loss: 0.0007606055005453527, Computation time: 0.0038187503814697266
Step: 0, Loss: 0.019614944234490395, Computation time: 17.28099513053894
Step: 100, Loss: 0.0057858191430568695, Computation time: 0.04332375526428223
Step: 200, Loss: 0.0027169983368366957, Computation time: 0.02740001678466797
Step: 300, Loss: 0.0010726527543738484, Computation time: 0.02580857276916504
Step: 400, Loss: 0.0014723350759595633, Computation time: 0.02711629867553711
Step: 499, Loss: 0.0006447380874305964, Computation time: 0.04514479637145996


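| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁▅▁▁▁▁▁ |
| loss_test | █▂▁▁▁▁▄▁▁▁▁▁ |
| loss_train | █▂▁▁▁▁▃▂▁▁▁▁ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.04514 |
| loss_test | 0.00046 |
| loss_train | 0.00064 |
| step | 499.0 |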


[34m[1mwandb[0m: Agent Starting Run: kj6n9g99 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 2
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 150


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.06524918228387833, Computation time: 18.394052982330322
Step: 100, Loss: 0.011897136457264423, Computation time: 0.012916803359985352
Step: 200, Loss: 0.006140392739325762, Computation time: 0.026012182235717773
Step: 300, Loss: 0.0009008643100969493, Computation time: 0.0168759822845459
Step: 400, Loss: 0.0006347507005557418, Computation time: 0.0164794921875
Step: 499, Loss: 0.0005457533989101648, Computation time: 0.016537904739379883
Step: 0, Loss: 0.025462059304118156, Computation time: 16.748396158218384
Step: 100, Loss: 0.003907904494553804, Computation time: 0.09642481803894043
Step: 200, Loss: 0.005904551595449448, Computation time: 0.10240936279296875
Step: 300, Loss: 0.004640038590878248, Computation time: 0.10671496391296387
Step: 400, Loss: 0.0004905553068965673, Computation time: 0.11063146591186523
Step: 499, Loss: 0.0007294952520169318, Computation time: 0.11672329902648926


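| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁▇▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▁▁▁▁▁▁ |
| loss_train | █▂▂▁▁▁▄▁▂▁▁▁ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.11672 |
| loss_test | 0.00118 |
| loss_train | 0.00073 |
| step | 499.0 |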


[34m[1mwandb[0m: Agent Starting Run: 384r7h32 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 2
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 64


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.059469472616910934, Computation time: 11.45964527130127
Step: 100, Loss: 0.010909945704042912, Computation time: 0.0043408870697021484
Step: 200, Loss: 0.006555596832185984, Computation time: 0.004612922668457031
Step: 300, Loss: 0.0019303553272038698, Computation time: 0.004456281661987305
Step: 400, Loss: 0.0011630762601271272, Computation time: 0.005420207977294922
Step: 499, Loss: 0.0012586781522259116, Computation time: 0.008536577224731445
Step: 0, Loss: 0.03719727322459221, Computation time: 11.532605648040771
Step: 100, Loss: 0.010364141315221786, Computation time: 0.03394365310668945
Step: 200, Loss: 0.006951243616640568, Computation time: 0.03400063514709473
Step: 300, Loss: 0.008381043560802937, Computation time: 0.03309345245361328
Step: 400, Loss: 0.0034200826194137335, Computation time: 0.033284664154052734
Step: 499, Loss: 0.006466153543442488, Computation time: 0.037371158599853516


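| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁█▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▁▁▁▁▁▁ |
| loss_train | █▂▂▁▁▁▅▂▂▂▁▂ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.03737 |
| loss_test | 0.00298 |
| loss_train | 0.00647 |
| step | 499.0 |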


[34m[1mwandb[0m: Agent Starting Run: noc11mp6 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 1
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 150


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.10406722128391266, Computation time: 17.677746057510376
Step: 100, Loss: 0.014937231317162514, Computation time: 0.003751993179321289
Step: 200, Loss: 0.011628729291260242, Computation time: 0.0037336349487304688
Step: 300, Loss: 0.004836494103074074, Computation time: 0.003786325454711914
Step: 400, Loss: 0.002841165056452155, Computation time: 0.0038938522338867188
Step: 499, Loss: 0.0041902936063706875, Computation time: 0.004709959030151367
Step: 0, Loss: 0.024687549099326134, Computation time: 17.067394733428955
Step: 100, Loss: 0.03232789412140846, Computation time: 0.04089474678039551
Step: 200, Loss: 0.01704997941851616, Computation time: 0.028095483779907227
Step: 300, Loss: 0.01393081620335579, Computation time: 0.028810739517211914
Step: 400, Loss: 0.021362867206335068, Computation time: 0.029102087020874023
Step: 499, Loss: 0.016379041597247124, Computation time: 0.04916501045227051


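| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁█▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▄▁▁▁▁▁ |
| loss_train | █▂▂▁▁▁▃▃▂▂▂▂ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.04917 |
| loss_test | 0.01306 |
| loss_train | 0.01638 |
| step | 499.0 |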


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: jxisy4zn with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 1
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 20


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.08794393390417099, Computation time: 17.028018712997437
Step: 100, Loss: 0.015755670145154, Computation time: 0.0017540454864501953
Step: 200, Loss: 0.01255581434816122, Computation time: 0.0017879009246826172
Step: 300, Loss: 0.005584015045315027, Computation time: 0.0018165111541748047
Step: 400, Loss: 0.0037837866693735123, Computation time: 0.0017547607421875
Step: 499, Loss: 0.0050829811953008175, Computation time: 0.001760721206665039
Step: 0, Loss: 0.04585922136902809, Computation time: 15.682780504226685
Step: 100, Loss: 0.022741448134183884, Computation time: 0.01436924934387207
Step: 200, Loss: 0.01646026037633419, Computation time: 0.024793148040771484
Step: 300, Loss: 0.010066828690469265, Computation time: 0.01489710807800293
Step: 400, Loss: 0.010081266984343529, Computation time: 0.014748573303222656
Step: 499, Loss: 0.01246845256537199, Computation time: 0.015107870101928711


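| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁▇▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▁▁▁▁▁▁ |
| loss_train | █▂▂▁▁▁▄▃▂▂▂▂ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.01511 |
| loss_test | 0.01035 |
| loss_train | 0.01247 |
| step | 499.0 |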


[34m[1mwandb[0m: Agent Starting Run: 6w4tdm04 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 2
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 150


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.06524918228387833, Computation time: 12.831528186798096
Step: 100, Loss: 0.011897136457264423, Computation time: 0.01277017593383789
Step: 200, Loss: 0.006140392739325762, Computation time: 0.01655292510986328
Step: 300, Loss: 0.0009008643100969493, Computation time: 0.02743077278137207
Step: 400, Loss: 0.0006347507005557418, Computation time: 0.016727924346923828
Step: 499, Loss: 0.0005457533989101648, Computation time: 0.016835689544677734
Step: 0, Loss: 0.025462059304118156, Computation time: 12.298889636993408
Step: 100, Loss: 0.003907904494553804, Computation time: 0.09573197364807129
Step: 200, Loss: 0.005904551595449448, Computation time: 0.16304850578308105
Step: 300, Loss: 0.004640038590878248, Computation time: 0.17379283905029297
Step: 400, Loss: 0.0004905553068965673, Computation time: 0.1780834197998047
Step: 499, Loss: 0.0007294952520169318, Computation time: 0.16985321044921875


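| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁█▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▁▁▁▁▁▁ |
| loss_train | █▂▂▁▁▁▄▁▂▁▁▁ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.16985 |
| loss_test | 0.00118 |
| loss_train | 0.00073 |
| step | 499.0 |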


[34m[1mwandb[0m: Agent Starting Run: 1tjwdz06 with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 2
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 150


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.06524918228387833, Computation time: 12.615247249603271
Step: 100, Loss: 0.011897136457264423, Computation time: 0.01354217529296875
Step: 200, Loss: 0.006140392739325762, Computation time: 0.017592191696166992
Step: 300, Loss: 0.0009008643100969493, Computation time: 0.016311168670654297
Step: 400, Loss: 0.0006347507005557418, Computation time: 0.016498565673828125
Step: 499, Loss: 0.0005457533989101648, Computation time: 0.01659679412841797
Step: 0, Loss: 0.025462059304118156, Computation time: 12.184600830078125
Step: 100, Loss: 0.003907904494553804, Computation time: 0.12304902076721191
Step: 200, Loss: 0.005904551595449448, Computation time: 0.10210776329040527
Step: 300, Loss: 0.004640038590878248, Computation time: 0.10882139205932617
Step: 400, Loss: 0.0004905553068965673, Computation time: 0.10840916633605957
Step: 499, Loss: 0.0007294952520169318, Computation time: 0.10710859298706055


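| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁█▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▁▁▁▁▁▁ |
| loss_train | █▂▂▁▁▁▄▁▂▁▁▁ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.10711 |
| loss_test | 0.00118 |
| loss_train | 0.00073 |
| step | 499.0 |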


[34m[1mwandb[0m: Agent Starting Run: nc6mgver with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 2
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 150


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.06524918228387833, Computation time: 12.56035041809082
Step: 100, Loss: 0.011897136457264423, Computation time: 0.013067483901977539
Step: 200, Loss: 0.006140392739325762, Computation time: 0.026799917221069336
Step: 300, Loss: 0.0009008643100969493, Computation time: 0.01639533042907715
Step: 400, Loss: 0.0006347507005557418, Computation time: 0.016874313354492188
Step: 499, Loss: 0.0005457533989101648, Computation time: 0.01636528968811035
Step: 0, Loss: 0.025462059304118156, Computation time: 12.1502046585083
Step: 100, Loss: 0.003907904494553804, Computation time: 0.09461712837219238
Step: 200, Loss: 0.005904551595449448, Computation time: 0.10222315788269043
Step: 300, Loss: 0.004640038590878248, Computation time: 0.10666632652282715
Step: 400, Loss: 0.0004905553068965673, Computation time: 0.1066126823425293
Step: 499, Loss: 0.0007294952520169318, Computation time: 0.17625951766967773


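| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁█▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▁▁▁▁▁▁ |
| loss_train | █▂▂▁▁▁▄▁▂▁▁▁ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.17626 |
| loss_test | 0.00118 |
| loss_train | 0.00073 |
| step | 499.0 |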


[34m[1mwandb[0m: Agent Starting Run: kv26na9l with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 3
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 64


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.06317592412233353, Computation time: 16.472474336624146
Step: 100, Loss: 0.007824818603694439, Computation time: 0.006460666656494141
Step: 200, Loss: 0.003035200061276555, Computation time: 0.012872695922851562
Step: 300, Loss: 0.0005349136190488935, Computation time: 0.008085489273071289
Step: 400, Loss: 0.00015244490350596607, Computation time: 0.008014440536499023
Step: 499, Loss: 0.00046238559298217297, Computation time: 0.008188009262084961
Step: 0, Loss: 0.04044495150446892, Computation time: 16.08171558380127
Step: 100, Loss: 0.00717288488522172, Computation time: 0.07592296600341797
Step: 200, Loss: 0.003122155088931322, Computation time: 0.05110931396484375
Step: 300, Loss: 0.006925210822373629, Computation time: 0.09041118621826172
Step: 400, Loss: 0.0005940109258517623, Computation time: 0.051848411560058594
Step: 499, Loss: 0.0009731792379170656, Computation time: 0.08678650856018066


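| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁█▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▂▁▁▁▁▁ |
| loss_train | █▂▁▁▁▁▅▂▁▂▁▁ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.08679 |
| loss_test | 0.00205 |
| loss_train | 0.00097 |
| step | 499.0 |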


[34m[1mwandb[0m: Agent Starting Run: ck0f9atq with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	depth: 2
[34m[1mwandb[0m: 	length_strategy: [0.1, 1]
[34m[1mwandb[0m: 	lr_strategy: [0.003, 0.003]
[34m[1mwandb[0m: 	print_every: 100
[34m[1mwandb[0m: 	seed: 42
[34m[1mwandb[0m: 	steps_strategy: [500, 500]
[34m[1mwandb[0m: 	width_size: 20


[34m[1mwandb[0m:   1 of 1 files downloaded.  


Step: 0, Loss: 0.15245793759822845, Computation time: 19.079686880111694
Step: 100, Loss: 0.01665269397199154, Computation time: 0.0030117034912109375
Step: 200, Loss: 0.007385714445263147, Computation time: 0.0023856163024902344
Step: 300, Loss: 0.0017961787525564432, Computation time: 0.002489805221557617
Step: 400, Loss: 0.0008132901857607067, Computation time: 0.003305673599243164
Step: 499, Loss: 0.0010794280096888542, Computation time: 0.0028791427612304688
Step: 0, Loss: 0.01662066951394081, Computation time: 16.153714418411255
Step: 100, Loss: 0.010402476415038109, Computation time: 0.020516633987426758
Step: 200, Loss: 0.005784120410680771, Computation time: 0.02380084991455078
Step: 300, Loss: 0.002810853300616145, Computation time: 0.02010798454284668
Step: 400, Loss: 0.002062638755887747, Computation time: 0.021400928497314453
Step: 499, Loss: 0.002252595964819193, Computation time: 0.03543734550476074


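| Metric | History |
|---|---|
| computation time | █▁▁▁▁▁▇▁▁▁▁▁ |
| loss_test | █▁▁▁▁▁▁▁▁▁▁▁ |
| loss_train | █▂▁▁▁▁▂▁▁▁▁▁ |
| step | ▁▂▄▅▇█▁▂▄▅▇█ |

| Metric | Final value |
|---|---|
| computation time | 0.03544 |
| loss_test | 0.00196 |
| loss_train | 0.00225 |
| step | 499.0 |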
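
The agent log above shows that random search revisits configurations when the search space is small. As a rough check, the sketch below enumerates the space spanned by the two parameters that actually vary (`width_size` and `depth`) and compares it with the ten sampled runs. The `sampled` list is copied by hand from the log above, and the variable names are chosen here for illustration.

```python
from collections import Counter
from itertools import product

# Only these two parameters have more than one candidate value in the sweep.
width_sizes = [20, 64, 150]
depths = [1, 2, 3]
search_space = list(product(width_sizes, depths))

# (width_size, depth) of the ten runs, in the order they appear in the agent log.
sampled = [
    (20, 3), (150, 2), (64, 2), (150, 1), (20, 1),
    (150, 2), (150, 2), (150, 2), (64, 3), (20, 2),
]

counts = Counter(sampled)
never_tried = sorted(set(search_space) - set(sampled))

print(len(search_space))      # size of the search space
print(len(counts))            # distinct configurations actually run
print(counts.most_common(1))  # most frequently repeated configuration
print(never_tried)            # configurations the random search missed
```

Of the 9 possible configurations, the ten runs cover 7, and `(150, 2)` was drawn four times. Because the seed is fixed at 42, those repeated runs report identical losses and add no information. For a discrete space this small, setting `'method': 'grid'` in the sweep configuration would visit each configuration exactly once.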
