In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
from tqdm.auto import tqdm

# Show which devices (accelerators or CPU) JAX can place arrays on.
available_devices = jax.devices()
print(available_devices)

In [None]:
a = jnp.arange(10)
print(f"a: {a}")
print(f"a.shape: {a.shape}")

try:
    # cursor tried to add this line (try/except added by me): 
    print(f"a.device: {a.device()}")
except TypeError as x:
    print(TypeError(x))
    print("In jax, we generally don't have to worry about moving data between gpu and cpu!")

print("most vector/matrix operations are just like numpy/torch")
print(f"a.T @ a: {a.T @ a}")
print(f"a @ a: {a @ a}")

In [None]:
b = jnp.zeros(10)
print(f"b: {b}")

# JAX arrays are immutable: in-place item assignment raises TypeError.
# Catch only that, so unrelated failures are not silently swallowed.
try:
    b[0] = 1
except TypeError as e:
    print(e)

# The functional update returns a modified copy; rebinding `b` keeps the new value.
b = b.at[0].set(1)
print(f"b: {b}")

print("this can be a bit annoying, but restrictions like these allow for enormous benefits (e.g. jit!)")

In [None]:
print("But before jit, lets learn how rng works")
shape = (3, 2)
key = jr.PRNGKey(0)

# Sampling twice from the very same key yields identical arrays.
first_draw = jr.normal(key, shape)
print(f"random normal array:\n{first_draw}")

second_draw = jr.normal(key, shape)
print(f"lets do another one!\n{second_draw}")

print(f"They're the same! This is because we used the same key, which is not a complicated object, it's just an array. The key: {key}")

In [None]:
print("so when we want to generate some randomness, and have the next call be different, 'split' the key")
key, key2 = jr.split(key)
print(f"our new keys are:{key}, and {key2}")
# Consume the fresh subkey and keep `key` only for further splitting: a key that
# has been used to sample should never be split or sampled from again.
random_array = jr.normal(key2, shape)
print(f"a random array:\n{random_array}")
key, key2 = jr.split(key)
random_array = jr.normal(key2, shape)
print(f"and another one!\n{random_array}")
print("In general, its a good idea to explicitly keep track of rng keys, this helps with reproducibility, and to avoid subtle bugs")


In [None]:
print("Heres an implemtntation with {-1, 1}")
class ExampleSparseParity:
    """Sparse parity over {-1, 1}: the label is the product of k fixed coordinates.

    Args:
        n: input dimension.
        k: number of coordinates the label depends on.
        key: PRNG key used to choose the k coordinates (without replacement).
    """

    def __init__(self, n, k, key):
        self.n = n
        self.inds = jr.choice(key, n, shape=(k,), replace=False)

    def get_example(self, key) -> tuple[jax.Array, jax.Array]:
        """Sample one input of shape (n,) with entries in {-1, 1} and its scalar label."""
        bits = jr.bernoulli(key, shape=(self.n,))
        signs = 1 - 2 * bits.astype(jnp.float32)  # True -> -1.0, False -> +1.0
        label = jnp.prod(signs[self.inds])
        return signs, label

    @eqx.filter_jit
    def get_batch(self, batch_size, key) -> tuple[jax.Array, jax.Array]:
        """Sample a batch: X has shape (batch_size, n) and y has shape (batch_size,)."""
        bits = jr.bernoulli(key, shape=(batch_size, self.n))
        signs = 1 - 2 * bits.astype(jnp.float32)
        labels = jnp.prod(signs[..., self.inds], axis=-1)
        return signs, labels


# Give every random consumer its own key; reusing one key for get_example and
# get_batch goes against JAX's rule of never using a key twice.
key = jr.PRNGKey(0)
key, init_key, example_key, batch_key = jr.split(key, 4)
test = ExampleSparseParity(4, 2, init_key)

X, y = test.get_example(example_key)
print(f"X: {X}")
print(f"y: {y}")

X, y = test.get_batch(2, batch_key)
print(f"X: {X}")
print(f"y: {y}")


In [None]:
print("Now you do it with {0, 1}!")

class SparseParity:
    """Exercise: the sparse-parity task of ExampleSparseParity, but with inputs in {0, 1}.

    Later cells construct it as SparseParity(n, k, key) and call
    get_batch(batch_size, key), so those two methods are the interface to implement.
    """

    def __init__(self, n, k, key):
        # TODO(exercise): choose the k relevant coordinates out of n using `key`.
        raise NotImplementedError("Not implemented")

    def get_batch(self, batch_size, key) -> tuple[jax.Array, jax.Array]:
        # TODO(exercise): return (X, y) for `batch_size` examples; presumably X has
        # shape (batch_size, n) as in ExampleSparseParity.get_batch — decide and
        # document how the parity label is encoded for {0, 1} inputs.
        raise NotImplementedError("Not implemented")

Now let's train a model to learn the parity!
To make things friendlier for those familiar with PyTorch, we will use Equinox, a neural-network library for JAX.

In [None]:
key = jr.PRNGKey(0)
key, dataset_key, model_init_key = jr.split(key, 3)
n = 10  # input dimension
k = 2   # number of coordinates the parity label depends on

dataset = SparseParity(n, k, dataset_key)

hidden_dim = 32
num_hidden_layers = 2


# we'll use the builtin MLP from Equinox for now
model = eqx.nn.MLP(in_size=n, out_size=1, width_size=hidden_dim, depth=num_hidden_layers, key=model_init_key)

In [None]:
X, y = dataset.get_batch(10, key)

# Feed the whole batch straight to the per-example model and show any error it raises.
try:
    print(model(X))
except Exception as err:
    print(err)

print("One quirk about equinox, is that it doesn't handle batching automatically, but this easy for us to fix with vmap")
print("when using equinox, sometimes you need to use filter versions of jax functions (for example: vmap, jit) to make things work, for now just always do so")

# filter_vmap maps the single-example model over the leading (batch) axis of X.
batched_model = eqx.filter_vmap(model)
print(batched_model(X))
print(y)

In [None]:
print("we use optax for optimizers")
lr = 1e-3  # Adam step size
optimizer = optax.adam(lr)

# the core of the training loop is a function that updates the model
# it's helpful that this is its own function, for reasons we will discuss later!
def train_step(model, opt_state, batch):
    """Apply one optimizer update to `model` on `batch` using mean-squared error.

    Args:
        model: an Equinox module that maps a single input vector to one output
            (assumed — e.g. the eqx.nn.MLP with out_size=1 built above).
        opt_state: optax state for the module-level `optimizer`.
        batch: (X, y); X has a leading batch axis, y holds one target per example.

    Returns:
        (updated model, updated opt_state, loss of `model` on `batch` before the update).
    """
    X, y = batch

    # Named loss_fn so it is not shadowed by the scalar `loss` computed below.
    def loss_fn(model):
        # Equinox modules act on single examples, so map over the batch axis;
        # flatten turns the (batch, 1) predictions into (batch,) to match y.
        y_preds = eqx.filter_vmap(model)(X).flatten()
        return jnp.mean((y_preds - y) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    # Only the array leaves are parameters; non-array fields are filtered out.
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

# write a function to record test metrics!
def eval_step(model, batch):
    """Exercise stub: evaluate `model` on a held-out `batch` without updating it.

    Expected to return (loss, accuracy), per the placeholder return below.
    Presumably loss mirrors train_step's mean-squared error and accuracy compares
    the model's predicted labels with `y` — TODO decide given the dataset's label encoding.
    """
    raise NotImplementedError("Not implemented")
    # Placeholder: unreachable, and `loss`/`accuracy` must be computed above it.
    return loss, accuracy

In [None]:
batch_size = 64
num_iters = 200
num_test_examples = 100
key = jr.PRNGKey(0)
key, test_data_key, train_data_key, model_init_key = jr.split(key, 4)

# Build the task before sampling the test set from it, so this cell does not read
# a `dataset` object left over from an earlier cell. n, k and dataset_key still
# come from the model-setup cell.
dataset = SparseParity(n, k, dataset_key)
test_batch = dataset.get_batch(num_test_examples, test_data_key)

hidden_dim = 32
num_hidden_layers = 2

model = eqx.nn.MLP(in_size=n, out_size=1, width_size=hidden_dim, depth=num_hidden_layers, key=model_init_key)

def dataloader(key):
    """Exercise stub: endless generator of training batches driven by `key`.

    Because the body contains `yield`, calling dataloader(...) only creates a
    generator; the NotImplementedError surfaces on the first next().
    Presumably each iteration should split `key`, draw
    dataset.get_batch(batch_size, subkey), and yield it — TODO implement.
    """
    while True:
        raise NotImplementedError("Not implemented")
        yield batch  # placeholder; unreachable until the line above is replaced

# Training loop (exercise): replace the two NotImplementedError placeholders.
train_loader = dataloader(train_data_key)
# The optimizer state is initialised from the model's array leaves only.
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
# One list per metric; the progress bar below reads the latest entry of each,
# so every list must be non-empty by the time set_postfix runs.
metrics = {"train_loss":[], "test_loss":[], "test_accuracy":[]}
pbar = tqdm(range(num_iters))
for i in pbar:
    # update the model
    # presumably: take the next batch from train_loader and call train_step — TODO
    raise NotImplementedError("Model update not implemented")
    
    # record train loss and other metrics
    # presumably: append train_step's loss and eval_step's results on test_batch — TODO
    raise NotImplementedError("Record metrics not implemented")
    pbar.set_postfix(test_loss=f"{metrics['test_loss'][-1]:.3f}", test_accuracy=f"{metrics['test_accuracy'][-1]:.3f}", train_loss=f"{metrics['train_loss'][-1]:.3f}")


In [None]:
# Training loss, one point per recorded iteration.
fig, ax = plt.subplots()
ax.plot(metrics["train_loss"], label="train loss")
ax.set(title=f"Training Loss for MLP on Sparse Parity n={n}, k={k}", xlabel="Iteration", ylabel="Loss")
ax.legend()
plt.show()

# Test accuracy and test loss share one axis, so the y-label names both.
fig, ax = plt.subplots()
ax.plot(metrics["test_accuracy"], label="test accuracy")
ax.plot(metrics["test_loss"], label="test loss")  # might have to adjust if you don't take test loss every iteration
ax.set(title=f"Test Metrics for MLP on Sparse Parity n={n}, k={k}", xlabel="Iteration", ylabel="Accuracy / Loss")
ax.legend()
plt.show()



In [None]:
class MLP(eqx.Module):
    """Hand-written multilayer perceptron (exercise: implement ``__init__``).

    ``layers`` holds the linear layers in order. ``__call__`` applies ReLU after
    every layer except the last, matching eqx.nn.MLP's defaults (ReLU hidden
    activation, identity final activation).
    """

    # Builtin generic, consistent with the tuple[...] hints used elsewhere; the
    # previous ``List`` was never imported and raised NameError at class creation.
    layers: list[eqx.nn.Linear]

    def __init__(self, in_size, out_size, width_size, depth, key):
        # TODO(exercise): build self.layers, presumably mirroring eqx.nn.MLP's arguments.
        raise NotImplementedError("Not implemented")

    def __call__(self, x, key=None):
        """Run a single (unbatched) input through the network.

        ``key`` is unused. It defaults to None so the module can be called as
        ``model(x)``, like eqx.nn.MLP, e.g. when vmapped over a batch.
        """
        # Without a nonlinearity between them, stacked Linear layers collapse into a
        # single affine map, which cannot represent a parity of k >= 2 coordinates.
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        if self.layers:
            x = self.layers[-1](x)
        return x