In [1]:
import sys
# Put the parent directory on the import path; presumably that is where the
# local `lra` package lives — confirm against the repository layout.
sys.path.append("..")

import itertools

import torch
import numpy as np

# NOTE(review): ParityDataset and BinaryMarkovDataset are re-defined further
# down in this notebook (cell In [65]); once that cell has run, the local
# classes shadow the ones imported here.
from lra import temp_seed, ParityDataset, BinaryMarkovDataset

In [4]:
# Smoke test: build the dataset and display one collated training batch of 8.
# NOTE(review): this cell ran as In [4], i.e. before the local class
# definitions in In [65], so the output below presumably comes from the
# BinaryMarkovDataset imported from `lra` — confirm on a fresh kernel.
dataset = BinaryMarkovDataset()
dataset.setup()
next(iter(dataset.train_dataloader(batch_size=8)))

(tensor([[1, 1, 0, 0, 1, 1, 1, 0, 0, 0],
         [1, 1, 0, 0, 0, 1, 1, 1, 0, 0],
         [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
         [1, 0, 0, 1, 0, 1, 0, 0, 0, 0],
         [0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]),
 tensor([1, 1, 1, 1, 0, 0, 1, 1]),
 {'lengths': tensor([ 7,  8,  9, 10,  6,  3,  3,  5])})

In [65]:
class CustomSequenceDataset(torch.utils.data.Dataset):
    """Base class for datasets of variable-length 1-D integer sequences.

    Subclasses implement `setup`, which is expected to fill `self.data` with
    a list of ``(example, label, length)`` tuples. `self.data` is not created
    by `__init__`, so every accessor below raises ``AttributeError`` until
    `setup` has run.

    Args:
        ndata: number of examples the subclass should generate.
        train_split: fraction of `self.data` (taken from the front) that
            forms the training split; the remainder is the validation split.
        vocab_size: number of distinct token ids (stored only; unused here).
    """

    def __init__(self, ndata, train_split, vocab_size):
        self.ndata = ndata
        self.train_split = train_split
        self.vocab_size = vocab_size

    def _dataloader(self, examples, args, kwargs):
        """Wrap `examples` in a DataLoader that pads batches via `collate_fn`."""
        # setdefault rather than a hard-coded keyword: a caller-supplied
        # collate_fn now overrides the default instead of raising
        # "got multiple values for keyword argument 'collate_fn'".
        kwargs.setdefault('collate_fn', self.collate_fn)
        return torch.utils.data.DataLoader(examples, *args, **kwargs)

    def train_dataloader(self, *args, **kwargs):
        """Return a DataLoader over the first `train_ind` examples.

        Extra positional/keyword arguments are forwarded to
        ``torch.utils.data.DataLoader``.
        """
        return self._dataloader(self.data[:self.train_ind], args, kwargs)

    def val_dataloader(self, *args, **kwargs):
        """Return a DataLoader over the examples from `train_ind` onwards.

        Extra positional/keyword arguments are forwarded to
        ``torch.utils.data.DataLoader``.
        """
        return self._dataloader(self.data[self.train_ind:], args, kwargs)

    def collate_fn(self, data):
        """Pad a list of examples into a single batch.

        Args:
            data: list of ``(example, label, length)`` tuples, where
                `example` is a 1-D integer tensor and `label` / `length`
                are scalars.

        Returns:
            ``(features, labels, {'lengths': lengths})``. `features` is a
            ``(len(data), max(lengths))`` long tensor in which each row is
            an example right-padded with zeros; `labels` and `lengths` are
            1-D long tensors.
        """
        _, labels, lengths = zip(*data)
        max_len = max(lengths)
        features = torch.zeros((len(data), max_len), dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.long)
        lengths = torch.tensor(lengths, dtype=torch.long)
        for i, (example, _, _) in enumerate(data):
            # Only the leading len(example) slots are written; the rest of
            # the row keeps the zero padding.
            features[i, :len(example)] = example
        return features, labels, {'lengths': lengths}

    @property
    def train_ind(self):
        """Index of the first validation example (truncated toward zero)."""
        return int(self.train_split * len(self.data))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def setup(self, seed=42):
        """Generate `self.data`; must be overridden by subclasses."""
        raise NotImplementedError
    


class BinarySequenceOpsDataset(CustomSequenceDataset):
    """Binary strings labelled by a subclass-defined `sequence_operation`.

    The candidate pool is every binary string whose length lies in
    ``[minsize, maxsize]``. `setup` draws `ndata` distinct strings from that
    pool (without replacement) and shuffles them.

    Args:
        maxsize: largest sequence length (inclusive).
        minsize: smallest sequence length (inclusive).
        ndata: number of examples to draw; must not exceed the pool size.
        train_split: fraction of the data used for the training split.
    """

    def __init__(self, maxsize=10, minsize=3, ndata=1000,
                 train_split=0.8):
        max_possible_data = sum(2**i for i in range(minsize, maxsize+1))
        assert ndata <= max_possible_data, (
            f"ndata={ndata} exceeds the {max_possible_data} distinct binary "
            f"strings with length in [{minsize}, {maxsize}]"
        )
        super().__init__(ndata, train_split, 2)
        self.max_possible_data = max_possible_data
        self.maxsize = maxsize
        self.minsize = minsize

    def sequence_operation(self, sequences):
        """Map a 2-D tensor of binary strings (one per row) to per-row labels.

        Subclasses must override this, e.g. parity is
        ``sequences.sum(dim=1) % 2``.
        """
        raise NotImplementedError

    def setup(self, seed=42):
        """Populate `self.data` with `ndata` shuffled ``(seq, label, len)`` tuples.

        Args:
            seed: forwarded to `temp_seed` (imported from `lra`; presumably
                it seeds NumPy's global RNG for the duration of the block —
                confirm).
        """
        with temp_seed(seed):
            picked = np.random.choice(self.max_possible_data, self.ndata,
                                      replace=False)
            # A set gives O(1) membership tests; scanning the NumPy array for
            # every candidate made this loop O(pool_size * ndata).
            picked = set(picked.tolist())
            data = []
            offset = 0  # global index of the first string of the current length
            for n in range(self.minsize, self.maxsize+1):
                sequences, labels, lengths = self.list_of_binary_strings_n(n)
                for k in range(len(sequences)):
                    if offset + k in picked:
                        data.append((sequences[k], labels[k], lengths[k]))
                offset += len(sequences)
            # Drawn inside the seeded block so the shuffle is reproducible too.
            order = np.random.permutation(len(data))
        self.data = [data[i] for i in order]

    def list_of_binary_strings_n(self, n):
        """Enumerate all binary strings of length `n` with labels and lengths.

        Returns:
            ``(sequences, labels, lengths)``: `sequences` is a ``(2**n, n)``
            long tensor in `itertools.product` order, `labels` is whatever
            `sequence_operation` returns for it, and `lengths` is a 1-D long
            tensor filled with `n`.
        """
        sequences = list(map(list, itertools.product(range(2), repeat=n)))
        sequences = torch.tensor(sequences, dtype=torch.long)
        lengths = torch.full((len(sequences),), n, dtype=torch.long)
        labels = self.sequence_operation(sequences)
        return sequences, labels, lengths


class ParityDataset(BinarySequenceOpsDataset):
    """Binary strings labelled with the parity of their bits.

    NOTE(review): this local definition shadows the `ParityDataset` imported
    from `lra` at the top of the notebook.
    """

    def sequence_operation(self, sequences):
        """Return 1 for rows with an odd number of ones, 0 otherwise."""
        ones_per_row = torch.sum(sequences, dim=1)
        return ones_per_row.remainder(2)


class BinaryMarkovDataset(CustomSequenceDataset):
    """Binary two-state chains labelled by which probability generated them.

    Example ``i`` is generated with ``probability_retain[i % vocab_size]`` and
    that index is its label, so the classes alternate through `self.data`.

    NOTE(review): despite its name, each `probability_retain` entry is used
    by `make_markov_chain` as the per-step probability of *flipping* the
    state, not of keeping it. Behaviour is left as-is; confirm the intent.
    NOTE(review): `vocab_size` is set to ``len(probability_retain)`` (the
    number of classes), while the emitted tokens are always 0 or 1.
    NOTE(review): this local definition shadows the `BinaryMarkovDataset`
    imported from `lra` at the top of the notebook.

    Args:
        ndata: number of chains to generate.
        probability_retain: one probability per class (see notes above).
        maxsize: largest chain length (inclusive).
        minsize: smallest chain length (inclusive).
        train_split: fraction of the data used for the training split.
    """

    def __init__(self, ndata=1000, probability_retain=[0.8, 0.2],
                 maxsize=10, minsize=3,
                 train_split=0.7):
        # Copy the argument: the default is a single list object shared by
        # every call, so storing it directly would alias it across instances.
        self.probability_retain = list(probability_retain)
        vocab_size = len(self.probability_retain)
        super().__init__(ndata, train_split, vocab_size)
        max_possible_data = sum(self.vocab_size**i
                                for i in range(minsize, maxsize+1))
        assert ndata <= max_possible_data
        self.maxsize = maxsize
        self.minsize = minsize

    def setup(self, seed=42):
        """Populate `self.data` with `ndata` ``(sequence, label, length)`` tuples.

        Lengths are drawn uniformly from ``[minsize, maxsize]``. The data is
        not shuffled.

        Args:
            seed: forwarded to `temp_seed` (imported from `lra`; presumably
                it seeds NumPy's global RNG for the duration of the block —
                confirm).
        """
        with temp_seed(seed):
            self.data = []
            for i in range(self.ndata):
                n = np.random.randint(self.minsize, self.maxsize+1)
                sequence, label, length = self.make_markov_chain(i, n)
                self.data.append((sequence, label, length))

    def make_markov_chain(self, i, n):
        """Sample one chain of length `n` for example index `i`.

        Args:
            i: example index; ``i % vocab_size`` selects the probability and
                is returned as the label.
            n: chain length.

        Returns:
            ``(sequence, label, length)``: a 1-D long tensor of `n` bits and
            two 0-d long tensors.
        """
        which = i%self.vocab_size
        p = self.probability_retain[which]
        sequence = torch.zeros(n, dtype=torch.long)
        val = 0  # state before the first step; the first step may already flip it
        for step in range(n):
            change = np.random.rand() < p
            if change:
                val = 1 - val
            sequence[step] = val
        length = torch.tensor(n, dtype=torch.long)
        label = torch.tensor(which, dtype=torch.long)
        return sequence, label, length


In [66]:
# Smoke test: build the parity dataset and display one collated training
# batch of 4 (features, labels, {'lengths': ...}).
# NOTE(review): run after In [65], so ParityDataset here is presumably the
# class defined in this notebook rather than the one imported from `lra`.
dataset = ParityDataset()
dataset.setup()
next(iter(dataset.train_dataloader(batch_size=4)))

(tensor([[0, 0, 1, 0, 1, 1, 0, 0, 0],
         [0, 0, 1, 1, 1, 1, 0, 1, 0],
         [0, 1, 0, 1, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 1, 0, 0, 1, 0]]),
 tensor([1, 1, 0, 1]),
 {'lengths': tensor([6, 9, 9, 9])})

In [67]:
# Smoke test: build the Markov dataset and display one collated training
# batch of 4 (features, labels, {'lengths': ...}).
# NOTE(review): run after In [65], so BinaryMarkovDataset here is presumably
# the class defined in this notebook rather than the one imported from `lra`.
# This rebinds `dataset`, replacing the ParityDataset instance from In [66].
dataset = BinaryMarkovDataset()
dataset.setup()
next(iter(dataset.train_dataloader(batch_size=4)))

(tensor([[1, 0, 1, 0, 1, 0, 1, 0, 1],
         [1, 1, 1, 1, 0, 0, 0, 0, 0],
         [1, 0, 1, 0, 1, 0, 1, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0]]),
 tensor([0, 1, 0, 1]),
 {'lengths': tensor([9, 5, 7, 6])})