In [3]:
import math

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
import optax  # https://github.com/deepmind/optax

In [4]:
def dataloader(arrays, batch_size, *, rng=None):
    """Yield shuffled mini-batches from ``arrays`` forever.

    Each epoch draws a fresh permutation of the sample indices and yields
    consecutive, non-overlapping batches of exactly ``batch_size`` samples.
    The trailing ``dataset_size % batch_size`` samples of each permutation are
    dropped, so every yielded batch has the same leading dimension.

    Args:
        arrays: Sequence of arrays that share the same leading (sample) size.
        batch_size: Samples per batch; must satisfy
            ``1 <= batch_size <= dataset_size``.
        rng: Optional object with a ``permutation`` method (e.g. a
            ``numpy.random.Generator``) used for shuffling. Defaults to NumPy's
            global RNG (``np.random``), as before.

    Yields:
        A tuple holding one batch per input array, in the order of ``arrays``.

    Raises:
        ValueError: if ``batch_size`` is outside ``[1, dataset_size]``. Without
            this check a too-large batch size makes the generator loop forever
            without yielding, and a non-positive one yields empty batches forever.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be in [1, {dataset_size}], got {batch_size}"
        )
    shuffler = np.random if rng is None else rng
    indices = np.arange(dataset_size)
    while True:
        perm = shuffler.permutation(indices)
        # Only full batches: the last start index is dataset_size - batch_size.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start : start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)


def get_data(dataset_size, *, key, seq_len=16):
    """Generate a toy binary-classification dataset of 2-D spiral trajectories.

    Each sample is a sequence of ``seq_len`` points ``(x1, x2)``:
    ``x1 = sin(t + offset) / (1 + t)`` and ``x2 = cos(t + offset) / (1 + t)``.
    Here ``t`` runs evenly over ``[0, 2π]`` and ``offset`` is a per-sample phase
    drawn uniformly from ``[0, 2π)``. The first ``dataset_size // 2`` samples
    have ``x1`` negated (a mirrored spiral) and label 0. The rest have label 1.
    Samples are ordered by class, so shuffle before batching.

    Args:
        dataset_size: Number of samples to generate.
        key: JAX PRNG key used to draw the phase offsets.
        seq_len: Number of time steps per trajectory (default 16).

    Returns:
        ``(x, y)`` where ``x`` has shape ``(dataset_size, seq_len, 2)`` and ``y``
        has shape ``(dataset_size, 1)`` with float labels 0.0 / 1.0.
    """
    t = jnp.linspace(0, 2 * math.pi, seq_len)
    # Shape (dataset_size, 1) so it broadcasts against t to (dataset_size, seq_len).
    offset = jrandom.uniform(key, (dataset_size, 1), minval=0, maxval=2 * math.pi)
    x1 = jnp.sin(t + offset) / (1 + t)
    x2 = jnp.cos(t + offset) / (1 + t)
    y = jnp.ones((dataset_size, 1))

    half_dataset_size = dataset_size // 2
    # Mirror the first half and label it as the negative class.
    x1 = x1.at[:half_dataset_size].multiply(-1)
    y = y.at[:half_dataset_size].set(0)
    x = jnp.stack([x1, x2], axis=-1)

    return x, y

In [5]:
class RNN(eqx.Module):
    """Single-layer GRU sequence classifier.

    A GRU cell is scanned over the leading (time) axis of one unbatched
    sequence. The final hidden state then goes through a bias-free linear
    layer, the separately stored ``bias`` is added, and a sigmoid is applied.
    """

    hidden_size: int
    cell: eqx.Module
    linear: eqx.nn.Linear
    bias: jax.Array

    def __init__(self, in_size, out_size, hidden_size, *, key):
        cell_key, linear_key = jrandom.split(key)
        self.hidden_size = hidden_size
        self.cell = eqx.nn.GRUCell(in_size, hidden_size, key=cell_key)
        # The linear map carries no bias; `self.bias` (zeros) is added after it.
        self.linear = eqx.nn.Linear(
            hidden_size, out_size, use_bias=False, key=linear_key
        )
        self.bias = jnp.zeros(out_size)

    def __call__(self, input):
        """Return sigmoid outputs for a single sequence.

        ``input`` is scanned over its leading axis, and each slice is fed to the
        GRU cell. In this notebook it is one sample of ``get_data`` output,
        i.e. shape ``(seq_len, 2)``; batches are handled by ``jax.vmap``.
        """

        def step(h, x_t):
            # Only the hidden state is carried; no per-step outputs are kept.
            return self.cell(x_t, h), None

        h0 = jnp.zeros((self.hidden_size,))
        h_final, _ = lax.scan(step, h0, input)
        logits = self.linear(h_final) + self.bias
        # sigmoid because we're performing binary classification
        return jax.nn.sigmoid(logits)

In [8]:
def main(
    dataset_size=10000,
    batch_size=32,
    learning_rate=3e-3,
    steps=200,
    hidden_size=16,
    depth=1,
    seed=5678,
):
    """Train the GRU ``RNN`` on the toy spiral dataset and report accuracy.

    Prints a header, then the batch loss (binary cross-entropy) after every
    optimisation step. Finally it prints the classification accuracy of the
    trained model on the same ``dataset_size`` samples it was trained on;
    there is no held-out split.

    Args:
        dataset_size: Number of samples generated by ``get_data``.
        batch_size: Samples per optimisation step.
        learning_rate: Adam learning rate.
        steps: Number of optimisation steps.
        hidden_size: Size of the GRU hidden state.
        depth: Not used by this function; the model is always a single GRU
            cell. Kept so existing calls that pass it keep working.
        seed: Seeds the JAX PRNG key that is split for data generation and
            model initialisation. NOTE: batch shuffling inside ``dataloader``
            is not controlled by this seed.
    """
    data_key, model_key = jrandom.split(jrandom.PRNGKey(seed), 2)
    xs, ys = get_data(dataset_size, key=data_key)
    iter_data = dataloader((xs, ys), batch_size)

    model = RNN(in_size=2, out_size=1, hidden_size=hidden_size, key=model_key)
    # Create the optimiser before `make_step`, which closes over it.
    optim = optax.adam(learning_rate)
    opt_state = optim.init(model)

    @eqx.filter_value_and_grad
    def compute_loss(model, x, y):
        pred_y = jax.vmap(model)(x)
        # The model outputs sigmoid probabilities. In float32 the sigmoid
        # saturates to exactly 0.0 or 1.0 for large logits. `log` would then
        # return -inf and the loss and gradients would become NaN, so keep the
        # probabilities strictly inside (0, 1).
        eps = jnp.finfo(pred_y.dtype).eps
        pred_y = jnp.clip(pred_y, eps, 1 - eps)
        # Binary cross-entropy, averaged over the batch.
        return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

    # Important for efficiency whenever you use JAX: wrap everything into a single JIT
    # region.
    @eqx.filter_jit
    def make_step(model, x, y, opt_state):
        loss, grads = compute_loss(model, x, y)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    print(f'{"Step":<6}|{"Loss":<10}')
    for step, (x, y) in zip(range(steps), iter_data):
        loss, model, opt_state = make_step(model, x, y, opt_state)
        loss = loss.item()
        print(f"{step:<6}|{loss:<10.4f}")

    # Accuracy on the full training set, thresholding probabilities at 0.5.
    pred_ys = jax.vmap(model)(xs)
    num_correct = jnp.sum((pred_ys > 0.5) == ys)
    final_accuracy = (num_correct / dataset_size).item()
    print(f"final_accuracy={final_accuracy:.4f}")

In [9]:
# Train with the default hyperparameters; prints per-step loss and final accuracy.
main()


Step  |Loss      
0     |0.7095    
1     |0.6952    
2     |0.6903    
3     |0.6986    
4     |0.6939    
5     |0.6977    
6     |0.6939    
7     |0.6912    
8     |0.6903    
9     |0.6912    
10    |0.6931    
11    |0.6920    
12    |0.6938    
13    |0.6929    
14    |0.6946    
15    |0.6903    
16    |0.6923    
17    |0.6930    
18    |0.6939    
19    |0.6901    
20    |0.6939    
21    |0.6920    
22    |0.6932    
23    |0.6924    
24    |0.6924    
25    |0.6943    
26    |0.6972    
27    |0.6920    
28    |0.6942    
29    |0.6915    
30    |0.6893    
31    |0.6895    
32    |0.6939    
33    |0.6882    
34    |0.6889    
35    |0.7056    
36    |0.6865    
37    |0.6842    
38    |0.6802    
39    |0.6990    
40    |0.6912    
41    |0.7013    
42    |0.6822    
43    |0.6898    
44    |0.6933    
45    |0.7014    
46    |0.6925    
47    |0.6903    
48    |0.6869    
49    |0.6979    
50    |0.6831    
51    |0.7050    
52    |0.6897    
53    |0.6932    
54    |0.7

In [10]:
# Build a fresh, untrained model just to inspect its parameter shapes.
hidden_size = 16
key = jrandom.PRNGKey(0)
data_key, model_key = jrandom.split(key, num=2)
model = RNN(
    in_size=2,
    out_size=1,
    hidden_size=hidden_size,
    key=model_key,
)

In [13]:
# Module tree, then the GRU's hidden-to-hidden and input-to-hidden weight shapes,
# one per line.
print(model, model.cell.weight_hh.shape, model.cell.weight_ih.shape, sep="\n")

RNN(
  hidden_size=16,
  cell=GRUCell(
    weight_ih=f32[48,2],
    weight_hh=f32[48,16],
    bias=f32[48],
    bias_n=f32[16],
    input_size=2,
    hidden_size=16,
    use_bias=True
  ),
  linear=Linear(
    weight=f32[1,16],
    bias=None,
    in_features=16,
    out_features=1,
    use_bias=False
  ),
  bias=f32[1]
)
(48, 16)
(48, 2)


```python
    @jax.named_scope("eqx.nn.GRUCell")
    def __call__(
        self, input: Array, hidden: Array, *, key: Optional[PRNGKeyArray] = None
    ):
        """**Arguments:**

        - `input`: The input, which should be a JAX array of shape `(input_size,)`.
        - `hidden`: The hidden state, which should be a JAX array of shape
            `(hidden_size,)`.
        - `key`: Ignored; provided for compatibility with the rest of the Equinox API.
            (Keyword only argument.)

        **Returns:**

        The updated hidden state, which is a JAX array of shape `(hidden_size,)`.
        """
        # Excerpt of the GRUCell forward pass (named scope "eqx.nn.GRUCell").
        # `weight_ih` / `weight_hh` stack three gate blocks row-wise
        # (3 * hidden_size rows, e.g. 48 = 3 * 16 in the printout above). The
        # indexing below fixes their order: reset gate r, update gate z,
        # candidate n.
        if self.use_bias:
            bias = self.bias
            bias_n = self.bias_n
        else:
            bias = 0
            bias_n = 0
        # igates = [W_ir @ x + b_r, W_iz @ x + b_z, W_in @ x + b_in]
        igates = jnp.split(self.weight_ih @ input + bias, 3)
        # hgates = [W_hr @ h, W_hz @ h, W_hn @ h]  (no bias added here)
        hgates = jnp.split(self.weight_hh @ hidden, 3)
        # Reset gate: r = σ(W_ir @ x + b_r + W_hr @ h)
        reset = jnn.sigmoid(igates[0] + hgates[0])
        # Update gate: z = σ(W_iz @ x + b_z + W_hz @ h)
        inp = jnn.sigmoid(igates[1] + hgates[1])
        # Candidate: n = tanh(W_in @ x + b_in + r * (W_hn @ h + b_hn)), with
        # b_hn = `bias_n`; r scales the hidden contribution including its bias.
        new = jnn.tanh(igates[2] + reset * (hgates[2] + bias_n))
        # h' = n + z * (h - n) = (1 - z) * n + z * h
        return new + inp * (hidden - new)
```