In [128]:
import torch
from torch import Tensor
from torch.utils.data import dataset
from torchtext import datasets
from grok.transformer import Transformer

from torchtext.datasets import PennTreebank
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer

from typing import Tuple, Dict
import math

In [129]:
train_iter = datasets.PennTreebank(root="../data", split="train")
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
vocab.set_default_index(vocab['<unk>'])

def data_process(raw_text_iter: dataset.IterableDataset) -> torch.Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))


In [130]:
train_data = data_process(train_iter)
train_data.train = True
train_data.shape

torch.Size([924412])

In [131]:
def data_process(raw_text_iter: dataset.IterableDataset) -> torch.Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))


def get_ptb_dataset(train_pct: float, split:str = "train", data_dir:str ="../data"):
        data_iter = datasets.PennTreebank(root=data_dir, split=split)
        tokenizer = get_tokenizer('basic_english')
        
        if split != "train":
            train_iter = datasets.PennTreebank(root=data_dir, split="train")
            vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
        else:
            vocab = build_vocab_from_iterator(map(tokenizer, data_iter), specials=['<unk>'])
        
        vocab.set_default_index(vocab['<unk>'])
        
        data = data_process(data_iter)
        data = data[:int(train_pct*len(data))]
        if split =="train":
            data.train = True
        data.tokenizer = tokenizer
        data.vocab = vocab
        return data


        

In [132]:
def batchify(data: torch.Tensor, bsz: int) -> Tensor:
    """Divides the data into bsz separate sequences, removing extra elements
    that wouldn't cleanly fit.

    Args:
        data: Tensor, shape [N]
        bsz: int, batch size

    Returns:
        Tensor of shape [N // bsz, bsz]
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).contiguous()
    return data

bptt = 5
def get_batch(source: torch.Tensor, i: int) -> Tuple[Tensor, Tensor]:
    """
    Args:
        source: Tensor, shape [full_seq_len, batch_size]
        i: int

    Returns:
        tuple (data, target), where data has shape [seq_len, batch_size] and
        target has shape [seq_len * batch_size]
    """
    print(source.shape)
    seq_len = min(bptt, source.size(1) - 1 - i)
    data = source[:,i:i+seq_len]
    target = source[:,i+1:i+1+seq_len]
    return data, target


class PTBIterator(torch.utils.data.IterableDataset):
    """
    An iterator over batches of data in an ArithmeticDataset
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        device: torch.device,
        batchsize: float = 2,
        shuffle: bool = True,
    ) -> None:
        """
        :param dataset: the dataset to iterate over
        :param device: the torch device to send batches to
        :param batchsize_hint: * 0 means we use a default batchsize
                               * -1 means the entire dataset
                               * float between 0 and 1 means each batch is
                                 that fraction of the DS
                               * int > 1 means that specific batch size
        :param shuffle: whether or not to randomly shuffle the dataset
        """
        self.dataset = dataset
        self.data = batchify(self.dataset, batchsize)
        self.device = device
        self.reset_iteration(shuffle=shuffle)
        self.batchsize = batchsize

    def reset_iteration(self, shuffle=True):
        self.index = 0
        if shuffle and self.dataset.train:
            self.permutation = torch.randperm(len(self.dataset))
        else:
            self.permutation = torch.arange(len(self.dataset))

    
    def __iter__(self):
        """
        :returns: this iterator
        """
        return self

    def __next__(self) -> Dict[str, Tensor]:
        """
        Returns one batch of data.

        :raises: StopIteration when we're out of data
        :returns: batch tensor of shape (self.batchsize, tokens_per_eq)
        """
        self.index += 1
        return get_batch(self.data, self.index)
        
        
    def __len__(self) -> int:
        """
        :returns: the total number of batches
        """
        return math.ceil(len(self.dataset) / self.batchsize)

    

In [133]:
train_iter = PTBIterator(get_ptb_dataset(0.2), torch.device("cpu"))

In [134]:
len(train_iter)

92441

In [135]:
data,target = next(iter(train_iter))
data, target

torch.Size([2, 92441])


(tensor([[9893, 9894, 9896, 9897, 9898],
         [   0, 1192,    2,    2,    3]]),
 tensor([[9894, 9896, 9897, 9898, 9902],
         [1192,    2,    2,    3,    1]]))