In [104]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
import os
import itertools
import random
import contextlib

import numpy as np
import torch
import lightning

from retnet import GPTR, GPTRConfig, GPTRClassifier
from lra import ListOps

In [2]:
# ListOps data from the project's `lra` module. Named distinctly so the
# ParityDataset created further down does not silently rebind `dataset`.
listops_dataset = ListOps("listops-1000")
listops_dataset.setup()

In [3]:
# Scratch check: which 10 of 1000 indices does a seed-42 draw without
# replacement select? NOTE: this reseeds numpy's *global* RNG as a side effect.
np.random.seed(seed=42)
np.random.choice(a=1000, size=10, replace=False)

array([521, 737, 740, 660, 411, 678, 626, 513, 859, 136])

In [13]:
@contextlib.contextmanager
def temp_seed(seed):
    """Run the ``with`` body with numpy's global RNG seeded to ``seed``.

    The global RNG state captured on entry is restored on exit, whether the
    body finishes normally or raises, so surrounding code sees no change.
    """
    with contextlib.ExitStack() as cleanup:
        # Register the restore first, using the state as it is right now.
        cleanup.callback(np.random.set_state, np.random.get_state())
        np.random.seed(seed)
        yield

class ParityDataset(torch.utils.data.Dataset):
    """Binary strings of lengths ``minsize..maxsize`` labelled by their parity.

    ``ndata`` distinct strings are drawn without replacement from the pool of
    every binary string whose length lies in ``[minsize, maxsize]``, shuffled,
    and split into a training prefix and a validation suffix. Each item is a
    ``(sequence, label, length)`` tuple of long tensors, where ``sequence`` is
    1-D and ``label == sequence.sum() % 2``.

    Sampling and shuffling run under ``temp_seed(self.seed)``, so the result is
    reproducible for a fixed seed and numpy's global RNG state is restored.

    Args:
        maxsize: longest string length (inclusive).
        minsize: shortest string length (inclusive).
        ndata: number of strings to sample; must not exceed the pool size.
        seed: numpy seed used for sampling and shuffling.
        train_split: fraction of the shuffled data assigned to training.
    """

    def __init__(self, maxsize=10, minsize=3, ndata=1000, seed=42,
                 train_split=0.8):
        # Pool size: sum_{n=minsize}^{maxsize} 2**n == 2**(maxsize+1) - 2**minsize.
        max_possible_data = sum(2**n for n in range(minsize, maxsize + 1))
        assert ndata <= max_possible_data, (
            f"ndata={ndata} exceeds the {max_possible_data} distinct binary "
            f"strings of length {minsize}..{maxsize}")
        self.ndata = ndata
        self.max_possible_data = max_possible_data
        self.maxsize = maxsize
        self.minsize = minsize
        self.seed = seed
        self.train_split = train_split
        self.make_full_dataset()

    def make_full_dataset(self):
        """Sample ``ndata`` strings, shuffle them, and set ``data``/``train_ind``."""
        with temp_seed(self.seed):
            # Global indices into the concatenation of all strings, ordered by
            # length and then lexicographically. A set gives O(1) membership;
            # ``in`` on the numpy array was a linear scan for every string.
            chosen = set(np.random.choice(self.max_possible_data, self.ndata,
                                          replace=False).tolist())
            data = []
            offset = 0
            for n in range(self.minsize, self.maxsize + 1):
                sequences, labels, lengths = self.list_of_binary_strings_n(n)
                num_strings = 2 ** n
                for k in range(num_strings):
                    if offset + k in chosen:
                        data.append((sequences[k], labels[k], lengths[k]))
                offset += num_strings
            order = np.random.permutation(len(data))
        self.data = [data[i] for i in order]
        # Items [0, train_ind) are training data; the remainder is validation.
        self.train_ind = int(self.train_split * len(self.data))

    def _dataloader(self, items, *args, **kwargs):
        """Wrap ``items`` in a DataLoader, defaulting ``collate_fn`` to ours."""
        # setdefault lets callers override collate_fn instead of hitting a
        # "got multiple values for keyword argument" TypeError.
        kwargs.setdefault("collate_fn", self.collate_fn)
        return torch.utils.data.DataLoader(items, *args, **kwargs)

    def train_dataloader(self, *args, **kwargs):
        """DataLoader over the training split; extra args go to ``DataLoader``."""
        return self._dataloader(self.data[:self.train_ind], *args, **kwargs)

    def val_dataloader(self, *args, **kwargs):
        """DataLoader over the validation split; extra args go to ``DataLoader``."""
        return self._dataloader(self.data[self.train_ind:], *args, **kwargs)

    def collate_fn(self, data):
        """Right-pad a list of variable-length examples into one batch.

        Args:
            data: list of ``(example, label, length)`` tuples, where ``example``
                is a 1-D long tensor and ``label``/``length`` are scalars.

        Returns:
            ``(features, labels, {'lengths': lengths})``, with ``features`` of
            shape ``(batch, max_len)``. Padding value is 0, which is also a
            valid token here, so consumers must use ``lengths`` to mask it
            (padding does not change the parity of a row).
        """
        _, labels, lengths = zip(*data)
        max_len = int(max(lengths))
        features = torch.zeros((len(data), max_len), dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.long)
        lengths = torch.tensor(lengths, dtype=torch.long)
        for i, (example, _, _) in enumerate(data):
            features[i, :len(example)] = example
        return features, labels, {'lengths': lengths}

    def list_of_binary_strings_n(self, n):
        """All ``2**n`` binary strings of length ``n`` in lexicographic order.

        Returns:
            ``(sequences, labels, lengths)``: a ``(2**n, n)`` long tensor, each
            row's parity (sum mod 2), and a tensor filled with ``n``.
        """
        sequences = torch.tensor(
            list(map(list, itertools.product(range(2), repeat=n))),
            dtype=torch.long)
        lengths = torch.full((len(sequences),), n, dtype=torch.long)
        labels = sequences.sum(dim=-1) % 2
        return sequences, labels, lengths

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [14]:
# The parity data is a small in-memory list, so load it in the main process.
# num_workers=0 avoids worker start-up overhead, and it avoids pickling the
# notebook-defined ParityDataset/collate_fn into workers, which fails under the
# "spawn" multiprocessing start method (the default on macOS and Windows).
dataset = ParityDataset()
train_dataloader = dataset.train_dataloader(batch_size=4, num_workers=0)
valid_dataloader = dataset.val_dataloader(batch_size=4, num_workers=0)

In [16]:
def cumulative_parity_function(x):
    """Running parity along the last axis: out[..., i] == sum(x[..., :i+1]) mod 2."""
    running_total = x.cumsum(dim=-1)
    return running_total.remainder(2)

In [17]:
# Grab the first training batch: (features, labels, {'lengths': ...}) as
# returned by ParityDataset.collate_fn.
batch = next(iter(train_dataloader))

In [19]:
# Running parity of each padded row. collate_fn pads with 0, so entries past a
# row's length repeat its final parity, and the last column equals batch[1].
cumulative_parity_function(batch[0])

tensor([[0, 0, 1, 1, 0, 1, 1, 1, 1],
        [0, 0, 1, 0, 1, 0, 0, 1, 1],
        [0, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0, 0, 0, 1, 1]])

In [20]:
# Parity label of each sequence in the batch (sum of its bits mod 2).
batch[1]

tensor([1, 1, 0, 1])

In [21]:
# Extra-info dict; 'lengths' holds each row's unpadded sequence length.
batch[2]

{'lengths': tensor([6, 9, 9, 9])}