In [1]:
import notebooks_path
notebooks_path.include_packages()
import settings


In [13]:
import numpy as np

from jax import numpy as jnp
from jax import random 
from torch.utils import data
from functools import partial

In [14]:
prng = random.PRNGKey(42)

In [15]:
class ReverseDataset(data.Dataset):

    def __init__(self, num_categories, seq_len, size, np_rng):
        super().__init__()
        self.num_categories = num_categories
        self.seq_len = seq_len
        self.size = size
        self.np_rng = np_rng

        self.data = self.np_rng.integers(self.num_categories, size=(self.size, self.seq_len))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        inp_data = self.data[idx]
        labels = np.flip(inp_data, axis=0)
        return inp_data, labels

In [16]:
# Combine batch elements (all numpy) by stacking
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

dataset = partial(ReverseDataset, 10, 16)
rev_train_loader = data.DataLoader(dataset(50000, np_rng=np.random.default_rng(42)),
                                   batch_size=128,
                                   shuffle=True,
                                   drop_last=True,
                                   collate_fn=numpy_collate)
rev_val_loader   = data.DataLoader(dataset(1000, np_rng=np.random.default_rng(43)),
                                   batch_size=128,
                                   collate_fn=numpy_collate)
rev_test_loader  = data.DataLoader(dataset(10000, np_rng=np.random.default_rng(44)),
                                   batch_size=128,
                                   collate_fn=numpy_collate)

In [19]:
for i, row in enumerate(rev_val_loader):
    print(row)
    print('-' * 10)
    if i > 10:
        break

[array([[5, 6, 4, ..., 7, 3, 2],
       [2, 4, 6, ..., 4, 4, 0],
       [4, 3, 5, ..., 5, 7, 9],
       ...,
       [1, 1, 4, ..., 9, 5, 1],
       [7, 9, 3, ..., 8, 9, 4],
       [9, 8, 8, ..., 6, 5, 0]]), array([[2, 3, 7, ..., 4, 6, 5],
       [0, 4, 4, ..., 6, 4, 2],
       [9, 7, 5, ..., 5, 3, 4],
       ...,
       [1, 5, 9, ..., 4, 1, 1],
       [4, 9, 8, ..., 3, 9, 7],
       [0, 5, 6, ..., 8, 8, 9]])]
----------
[array([[8, 3, 3, ..., 9, 0, 0],
       [4, 4, 0, ..., 5, 7, 4],
       [8, 2, 0, ..., 3, 3, 1],
       ...,
       [6, 4, 7, ..., 2, 5, 3],
       [5, 0, 8, ..., 4, 0, 2],
       [0, 3, 7, ..., 0, 3, 3]]), array([[0, 0, 9, ..., 3, 3, 8],
       [4, 7, 5, ..., 0, 4, 4],
       [1, 3, 3, ..., 0, 2, 8],
       ...,
       [3, 5, 2, ..., 7, 4, 6],
       [2, 0, 4, ..., 8, 0, 5],
       [3, 3, 0, ..., 7, 3, 0]])]
----------
[array([[9, 9, 0, ..., 8, 9, 9],
       [0, 4, 0, ..., 0, 8, 6],
       [7, 2, 6, ..., 5, 8, 3],
       ...,
       [3, 8, 8, ..., 6, 2, 9],
       [8, 

In [29]:
print(row[0][0])
print(np.array(list(reversed(row[1][0]))))
print(len(row[0][0]))

[7 2 9 7 5 6 0 5 5 7 6 4 8 3 2 8]
[7 2 9 7 5 6 0 5 5 7 6 4 8 3 2 8]
16


In [3]:
train_dataloader = dataloader.get_train_dataloader(8)

Files already downloaded and verified


In [1]:
import jax

In [10]:
import jax
import equinox as eqx
from jax import numpy as jnp
from equinox import nn

In [6]:
norm = nn.LayerNorm(10)

In [8]:
jax.vmap(norm)(jnp.ones((8, 10, 10)))

ValueError: `LayerNorm(shape)(x)` must satisfy the invariant `shape == x.shape`Received `shape=(10,) and `x.shape=(10, 10)`. You might need to replace `layer_norm(x)` with `jax.vmap(layer_norm)(x)`.

If this is a new error for you, it might be because this became stricter in Equinox v0.11.0. Previously all that was required is that `x.shape` ended with `shape`. However, this turned out to be a frequent source of bugs, so we made the check stricter!

In [11]:
eqx.__version__

'0.11.2'