In [1]:
import math

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax  # https://github.com/deepmind/optax

In [2]:
def dataloader(arrays, batch_size):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = np.arange(dataset_size)
    while True:
        perm = np.random.permutation(indices)
        start = 0
        end = batch_size
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


def get_data(dataset_size, *, key):
    t = jnp.linspace(0, 2 * math.pi, 16)
    offset = jrandom.uniform(key, (dataset_size, 1), minval=0, maxval=2 * math.pi)
    x1 = jnp.sin(t + offset) / (1 + t)
    x2 = jnp.cos(t + offset) / (1 + t)
    y = jnp.ones((dataset_size, 1))

    half_dataset_size = dataset_size // 2
    x1 = x1.at[:half_dataset_size].multiply(-1)
    y = y.at[:half_dataset_size].set(0)
    x = jnp.stack([x1, x2], axis=-1)

    return x, y

In [3]:
class RNN(eqx.Module):
    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        ckey, lkey = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=ckey)
        self.linear = eqx.nn.Linear(hidden_size, out_size, use_bias=False, key=lkey)
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        hidden = jnp.zeros((self.hidden_size,))

        def f(carry, inp):
            return self.cell(inp, carry), None

        out, _ = lax.scan(f, hidden, input)
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(self.linear(out) + self.bias)

In [13]:
dataset_size=10000
batch_size=32
learning_rate=3e-3
steps=20
hidden_size=16
depth=1
seed=5678

data_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 2)
xs, ys = get_data(dataset_size, key=data_key)
iter_data = dataloader((xs, ys), batch_size)

model = RNN(in_size=2, out_size=1, hidden_size=hidden_size, key=model_key)

@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    # Trains with respect to binary cross-entropy
    return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

# Important for efficiency whenever you use JAX: wrap everything into a single JIT
# region.
@eqx.filter_jit
def make_step(model, x, y, opt_state):
    loss, grads = compute_loss(model, x, y)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

optim = optax.adam(learning_rate)
opt_state = optim.init(model)
for step, (x, y) in zip(range(steps), iter_data):
    loss, model, opt_state = make_step(model, x, y, opt_state)
    loss = loss.item()
    print(f"step={step}, loss={loss}")

pred_ys = jax.vmap(model)(xs)
num_correct = jnp.sum((pred_ys > 0.5) == ys)
final_accuracy = (num_correct / dataset_size).item()
print(f"final_accuracy={final_accuracy}")


step=0, loss=0.7216176986694336
step=1, loss=0.6902147531509399
step=2, loss=0.6979550123214722
step=3, loss=0.6814358234405518
step=4, loss=0.7044166326522827
step=5, loss=0.6944574117660522
step=6, loss=0.6910380125045776
step=7, loss=0.6976555585861206
step=8, loss=0.6890456080436707
step=9, loss=0.6951369047164917
step=10, loss=0.6963343620300293
step=11, loss=0.6905266046524048
step=12, loss=0.6929740309715271
step=13, loss=0.6984668970108032
step=14, loss=0.6953421235084534
step=15, loss=0.6912515163421631
step=16, loss=0.6906266212463379
step=17, loss=0.6928770542144775
step=18, loss=0.6971147060394287
step=19, loss=0.6919533014297485
final_accuracy=0.5


In [17]:
loss, grads = compute_loss(model, x, y)
updates, opt_state = optim.update(grads, opt_state)

In [23]:
compute_loss(model, x, y)[1]

RNN(
  hidden_size=None,
  cell=GRUCell(
    weight_ih=f32[48,2],
    weight_hh=f32[48,16],
    bias=f32[48],
    bias_n=f32[16],
    input_size=2,
    hidden_size=16,
    use_bias=True
  ),
  linear=Linear(
    weight=f32[1,16],
    bias=None,
    in_features=16,
    out_features=1,
    use_bias=False
  ),
  bias=f32[1]
)

In [28]:
model.cell.bias

Array([ 0.02099992, -0.22650257,  0.22024894, -0.14091924,  0.19867097,
        0.186878  ,  0.2590352 , -0.19021241,  0.10789654, -0.00654481,
        0.14347328, -0.12365455,  0.09235023,  0.01409564,  0.1799287 ,
       -0.18900716, -0.08414254, -0.00932486,  0.08162478, -0.1449526 ,
        0.20788766, -0.2012221 ,  0.05838855, -0.08605069, -0.15864289,
        0.2008853 ,  0.08769115,  0.09258799, -0.12722623,  0.01412126,
       -0.06657902,  0.03102791,  0.08074245,  0.21019915, -0.21532707,
        0.18304779, -0.11230482,  0.15396012,  0.13632606,  0.17642733,
       -0.00108753, -0.18435633,  0.0789031 ,  0.04088577,  0.04105607,
        0.14619192,  0.08829977,  0.0833812 ], dtype=float32)

In [27]:
eqx.apply_updates(model, updates).cell.bias

Array([ 0.02033575, -0.22588035,  0.22095573, -0.14154696,  0.1980837 ,
        0.18616064,  0.25889215, -0.1909398 ,  0.10857908, -0.00783069,
        0.14367156, -0.1237536 ,  0.09237452,  0.01348901,  0.18081193,
       -0.18829595, -0.08480729, -0.01026969,  0.08284629, -0.14506929,
        0.20644861, -0.19957067,  0.05704628, -0.084939  , -0.1598957 ,
        0.20243405,  0.08643414,  0.09179538, -0.12641516,  0.01319417,
       -0.06741141,  0.03020173,  0.08141845,  0.21092674, -0.21462524,
        0.18364033, -0.1130131 ,  0.15328297,  0.13565025,  0.17713504,
       -0.00179389, -0.18505491,  0.0782076 ,  0.04022322,  0.04039116,
        0.14551838,  0.08897749,  0.08271001], dtype=float32)