In [40]:
# fetching data from hf datasets: Salesforce/wikitext
# same as the one used by Graphcore/gpt2-wikitext-103

from datasets import load_dataset

ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
print(type(ds), len(ds))

<class 'datasets.arrow_dataset.Dataset'> 1801350


In [42]:
ds[100]

{'text': ' 96 ammunition packing boxes \n'}

In [32]:
import tiktoken

# yes, i'm using gpt2's tokenizer here 
# it is bpe, but not the version llama used, but should be okay ig
enc = tiktoken.get_encoding("gpt2")

encoded = enc.encode("hello world, let's goo")
decoded = enc.decode(encoded)
print(encoded, '\n', decoded)

[31373, 995, 11, 1309, 338, 467, 78] 
 hello world, let's goo


In [46]:
len(enc.encode(ds[3]['text']))

166

In [102]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import tiktoken

class WikiText103Dataset(Dataset):
    def __init__(self, split='train', seq_len=128):
        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1", split=split)
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.seq_len = seq_len
        
        # Filter out empty entries
        self.dataset = [item for item in dataset if item['text'].strip()]

    def __len__(self):
        return len(self.dataset)

    # TODO: not the most optimized, but simple and works, can use pad_sequence
    # TODO: also needs attention mask
    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        tokens = self.tokenizer.encode(text)
        
        # we need to pull 1 more token above seq_len 
        # reason: [a,b,c] -> [b,c,d] for seq [a,b,c,d]

        # clip tokens if exceeding seq_lim + 1
        if len(tokens) > (self.seq_len + 1):
            tokens = tokens[:(self.seq_len + 1)] 
        # pad with zeros if short of seq_len + 1
        else:
            tokens = tokens + [0] * ((self.seq_len + 1) - len(tokens)) 

        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        
        return x, y

# create dataset and dataloader
dataset = WikiText103Dataset(split='train')
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [111]:
# sample
for batch in dataloader:
    x, y = batch
    print(x.shape, y.shape)
    print()
    for input, output in zip(x.tolist()[:5], y.tolist()[:5]):
        print("[INPUT]", enc.decode(input[:10]))
        print("[OUTPUT]", enc.decode(output[:10]))
        print("-"*10)
    break

torch.Size([16, 128]) torch.Size([16, 128])

[INPUT]  The rivalry between Borden and Angier dominates the
[OUTPUT]  rivalry between Borden and Angier dominates the film
----------
[INPUT]  This was followed by Operation Drežnica
[OUTPUT]  was followed by Operation Drežnica ,
----------
[INPUT]  Victor Yarros and Rachelle Yarros 

[OUTPUT]  Yarros and Rachelle Yarros 
!
----------
[INPUT]  As a result , he and his producers spent the
[OUTPUT]  a result , he and his producers spent the time
----------
[INPUT]  Several countries have produced postage stamps which depict the northern
[OUTPUT]  countries have produced postage stamps which depict the northern bald
----------
