In [1]:
import jax
import jax.numpy as jnp
import jax.random as jr
from tqdm.auto import tqdm
import optax
from functools import partial
print(jax.devices())

[CudaDevice(id=0)]


In [2]:
a = jnp.arange(10)
print(f"a: {a}")
print(f"a.shape: {a.shape}")

try:
    # cursor tried to add this line (try/except added by me): 
    print(f"a.device: {a.device()}")
except TypeError as x:
    print(TypeError(x))
    print("In jax, we generally don't have to worry about moving data between gpu and cpu!")

print("most vector/matrix operations are just like numpy/torch")
print(f"a.T @ a: {a.T @ a}")
print(f"a @ a: {a @ a}")

a: [0 1 2 3 4 5 6 7 8 9]
a.shape: (10,)
'jaxlib._jax.Device' object is not callable
In jax, we generally don't have to worry about moving data between gpu and cpu!
most vector/matrix operations are just like numpy/torch
a.T @ a: 285
a @ a: 285


In [3]:
b = jnp.zeros(10)
print(f"b: {b}")

try:
    b[0] = 1
except Exception as e:
    print(e)

b = b.at[0].set(1)
print(f"b: {b}")

print("this can be a bit annoying, but restrictions like these allow for enourmous benefits (i.e. jit!)")

b: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html


b: [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
this can be a bit annoying, but restrictions like these allow for enourmous benefits (i.e. jit!)


In [4]:
print("But before jit, lets learn how rng works")
key = jr.PRNGKey(0)

shape = (3,2)
print(f"random normal array:\n{jr.normal(key, shape)}")

print(f"lets do another one!\n{jr.normal(key, shape)}")

print(f"They're the same! This is because we used the same key, which is not a complicated object, it's just an array. The key: {key}")

But before jit, lets learn how rng works
random normal array:
[[ 1.6226422   2.0252647 ]
 [-0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923]]
lets do another one!
[[ 1.6226422   2.0252647 ]
 [-0.43359444 -0.07861735]
 [ 0.1760909  -0.97208923]]
They're the same! This is because we used the same key, which is not a complicated object, it's just an array. The key: [0 0]


In [5]:
print(f"so when we want to generate some randomness, and have the next call be different, 'split' the key")
key, key2 = jr.split(key)
print(f"our new keys are:{key}, and {key2}")
random_array = jr.normal(key, shape)
print(f"a random array:\n{random_array}")
key, key2 = jr.split(key)
random_array = jr.normal(key, shape)
print(f"and another one!\n{random_array}")
print("In general, its a good idea to explicitly keep track of rng keys, this helps with reproducibility, and to avoid subtle bugs")


so when we want to generate some randomness, and have the next call be different, 'split' the key
our new keys are:[1797259609 2579123966], and [ 928981903 3453687069]
a random array:
[[ 1.0040143  -0.9063372 ]
 [-0.7481722  -1.1713669 ]
 [-0.8712328   0.58883816]]
and another one!
[[-0.57478017  0.79983664]
 [-0.25960687  1.429873  ]
 [-0.52380246 -1.7450135 ]]
In general, its a good idea to explicitly keep track of rng keys, this helps with reproducibility, and to avoid subtle bugs


In [None]:
print("Heres an implemtntation with {0,1}")
class SparseParity:
    def __init__(self, n, k, key):
        self.n = n
        self.inds = jr.choice(key, n, shape=(k,), replace = False)
    
    # @eqx.filter_jit
    def get_batch(self, batch_size, key) -> tuple[jax.Array, jax.Array]:
        X = jr.bernoulli(key, shape=(batch_size, self.n)).astype(jnp.float32)
        y = jnp.sum(X[..., self.inds], axis=-1) % 2
        return X, y


key = jr.PRNGKey(0)
key, init_key = jr.split(key)
dataset = SparseParityExample(4, 2, init_key)
X, y = dataset.get_example(key)
print(f"X: {X}")
print(f"y: {y}")


X, y = dataset.get_batch(2, key)
print(f"X: {X}")
print(f"y: {y}")


Heres an implemtntation with {0,1}
X: [0. 1. 1. 1.]
y: 1.0
X: [[0. 1. 1. 1.]
 [1. 0. 0. 1.]]
y: [1. 1.]


In [None]:
# Now you do it with {1,-1} parity!

class SparseParity:
    def __init__(self, n, k, key):
        raise NotImplementedError("Not implemented")

    def get_example(self, key) -> tuple[jax.Array, jax.Array]:
        raise NotImplementedError("Not implemented")

Now lets train a model to learn the parity!
To make things more friendly to those familiar with torch, we will use Equinox, a nn package for jax

In [51]:
import equinox as eqx
key = jr.PRNGKey(0)
key, dataset_key, model_init_key = jr.split(key, 3)
n = 10
k = 1

dataset = SparseParity(n, k, dataset_key)

hidden_dim = 32
num_hidden_layers = 2

model = eqx.nn.MLP(in_size=n, out_size=1, width_size=hidden_dim, depth=num_hidden_layers, key=model_init_key)

In [None]:
X,y = dataset.get_batch(10, key)
try:
    print(model(X))
except Exception as e:
    print(e)

print("One quirk about equinox, is that it doesn't handle batching automatically, but this easy for us to fix with vmap")
print("when using equinox, sometimes you need to use filter versions of jax functions (for example: filter_vmap, filter_jit)")

print(eqx.filter_vmap(model)(X))
print(y)

Incompatible shapes for broadcasting: shapes=[(32, 10), (32,)]
One quirk about equinox, is that it doesn't handle batching automatically, but this easy for us to fix with vmap
when using equinox, sometimes you need to use filter versions of jax functions (for example: vmap, jit) to make things work, for now just always do so
[[0.05863924]
 [0.08096123]
 [0.04649485]
 [0.06503659]
 [0.07639153]
 [0.06271922]
 [0.05976832]
 [0.04839244]
 [0.07761116]
 [0.05253061]]
[0. 1. 1. 0. 0. 1. 0. 1. 1. 1.]


In [None]:
lr = 1e-3
optimizer = optax.adam(learning_rate=lr)

@eqx.filter_jit
def train_step(model, opt_state, batch):
    X,y = batch
    def loss(model):
        y_preds = jax.vmap(model)(X).flatten()
        return jnp.mean((y_preds - y)**2)

    loss, grads = eqx.filter_value_and_grad(loss)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

@eqx.filter_jit
def eval_step(model, batch):
    X,y = batch
    y_preds = eqx.filter_vmap(model)(X).flatten()
    loss = jnp.mean((y_preds - y)**2)

    accuracy = jnp.mean((y_preds > .5) == (y > .5))
    return loss, accuracy

we use optax for optimizers


In [None]:
# from mlp import MLP
batch_size = 64
num_iters = 1000
key = jr.PRNGKey(0)
key, test_data_key, model_init_key = jr.split(key,3)
test_batch = dataset.get_batch(64, test_data_key)
key, datakey = jr.split(key)
keys = jr.split(datakey, num_iters)
batches = (dataset.get_batch(batch_size, key) for key in keys)
dataset = SparseParity(n, k, dataset_key)

hidden_dim = 16
num_hidden_layers = 2

model = eqx.nn.MLP(in_size=n, out_size=1, width_size=hidden_dim, depth=num_hidden_layers, key=model_init_key)
def dataloader(key):
    while True:
        yield from batches
        # key, datakey = jr.split(key)
        # yield dataset.get_batch(batch_size, datakey)

train_loader = dataloader(key)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

pbar = tqdm(range(num_iters))
for i in pbar:
    model, opt_state, train_loss = train_step(model, opt_state, next(train_loader))
    if i % 10 == 0:
        test_loss, test_accuracy = eval_step(model, test_batch)
        pbar.set_postfix(test_loss=f"{test_loss:.3f}", test_accuracy=f"{test_accuracy:.3f}", train_loss=f"{train_loss:.3f}")


  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
class MLP(eqx.Module):
    layers: List[eqx.nn.Linear]

    def __init__(self, in_size, out_size, width_size, depth, key):
        raise NotImplementedError("Not implemented")

    def __call__(self, x, key):
        for layer in self.layers:
            x = layer(x)
        return x