# Using jbar with Diffrax

This notebook shows how `jbar`'s `TqdmProgressMeter` can be passed to `diffrax.diffeqsolve` as its `progress_meter`, in place of Diffrax's built-in progress meters.

In [1]:
import jax
import jax.numpy as jnp
import diffrax
from jbar import TqdmProgressMeter



## Set Up the ODE Problem

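We set up a deliberately expensive ODE to simulate a long-running computation: the vector field `tanh(W @ y)` multiplies a dense `SIZE × SIZE` random matrix (with `SIZE = 5000`) by the state at every step, so the solve runs long enough for a progress bar to be useful.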

In [2]:
SIZE = 5000
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

W = jax.random.normal(k1, (SIZE, SIZE)) / jnp.sqrt(SIZE)
y0 = jax.random.normal(k2, (SIZE,))

def vector_field(t, y, args):
    matrix = args
    return jnp.tanh(jnp.dot(matrix, y))

term = diffrax.ODETerm(vector_field)
solver = diffrax.Euler()
stepsize_controller = diffrax.ConstantStepSize()
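To put "expensive" in numbers, here is a rough back-of-the-envelope estimate using `SIZE` from the cell above and the `t0`, `t1`, `dt0` values passed to `diffeqsolve` below. It assumes float32 storage (JAX's default) and counts one multiply and one add per matrix entry, ignoring the cheaper `tanh`; Euler evaluates the vector field once per step:

```python
# Back-of-the-envelope cost of the solve (rough estimate, not a benchmark).
SIZE = 5000
t0, t1, dt0 = 0.0, 10.0, 0.01

n_steps = round((t1 - t0) / dt0)        # constant step size: about 1000 steps
flops_per_step = 2 * SIZE * SIZE        # one dense mat-vec: a multiply and an add per entry
total_flops = n_steps * flops_per_step  # Euler evaluates the vector field once per step
matrix_bytes = SIZE * SIZE * 4          # W stored as float32

print(f"{n_steps} steps, {total_flops:.1e} FLOPs, W is {matrix_bytes / 1e6:.0f} MB")
```

The exact step count Diffrax takes may differ by one because of floating-point accumulation in `t`, which does not change the estimate meaningfully.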

## Solve with Progress Meter

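We pass `jbar`'s `TqdmProgressMeter` to `diffeqsolve` through the `progress_meter` argument. With `percent_progress=True` (Diffrax mode), progress is reported as a percentage of the integration interval `[t0, t1]`, hence `total=100`. The `refresh_steps=10` argument presumably mirrors the parameter of the same name in Diffrax's own `TqdmProgressMeter`, which sets the number of solver steps between bar refreshes.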
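The idea behind percentage mode can be sketched as mapping the solver's current time onto `[0, 100]`. This is a hypothetical helper (`percent_from_time` is not part of `jbar` or Diffrax), assuming progress is measured linearly in `t` between `t0` and `t1`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def percent_from_time(t, t0, t1):
    """Hypothetical helper: map solver time t onto a 0-100 progress value."""
    frac = (t - t0) / (t1 - t0)
    # Clip so that roundoff past t1 (or before t0) never shows as >100% or <0%.
    return 100.0 * jnp.clip(frac, 0.0, 1.0)

# With t0=0.0 and t1=10.0, as in the solve below, t=2.5 is a quarter of the way through.
print(percent_from_time(2.5, 0.0, 10.0))
```

Because the helper is jitted, it works on traced values, which is the situation a meter is in when it runs inside Diffrax's compiled solve loop.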

In [3]:
p_meter = TqdmProgressMeter(total=100, percent_progress=True, refresh_steps=10)

print("Solving...")
sol = diffrax.diffeqsolve(
    term, 
    solver, 
    t0=0.0, 
    t1=10.0, 
    dt0=0.01, 
    y0=y0, 
    args=W, 
    stepsize_controller=stepsize_controller,
    progress_meter=p_meter
)
p_meter.terminate()

print("Done.")
print(f"Final state norm: {jnp.linalg.norm(sol.ys[-1]):.2f}")

Solving...


0%|          | 0.0/100.0 [00:00<?]

Done.
Final state norm: 530.52
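Updating a tqdm bar requires a host callback out of the compiled solve loop, so progress meters throttle how often they refresh. Assuming `refresh_steps` behaves as in Diffrax's own `TqdmProgressMeter` (one refresh every `refresh_steps` solver steps), the solve above would refresh the bar about 100 times over its roughly 1000 Euler steps. A minimal sketch of that counting logic inside a compiled loop; `count_refreshes` is a hypothetical helper, not part of `jbar`:

```python
import jax
import jax.numpy as jnp

def count_refreshes(n_steps, refresh_steps):
    """Hypothetical helper: count the refreshes a throttled meter would issue."""
    def body(step, count):
        # `step` is 0-based; refresh after every `refresh_steps`-th completed step.
        do_refresh = (step + 1) % refresh_steps == 0
        return count + jnp.where(do_refresh, jnp.int32(1), jnp.int32(0))

    # lax.fori_loop traces `body` once and runs it as a compiled XLA loop,
    # much like the step loop inside diffeqsolve.
    return jax.lax.fori_loop(0, n_steps, body, jnp.int32(0))

print(count_refreshes(1000, 10))
```

A larger `refresh_steps` means fewer callbacks and less overhead, at the cost of a coarser-grained bar.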
