In [40]:
# Fetch the data from the Hugging Face Hub dataset `Salesforce/wikitext`.
# This is the same corpus used by Graphcore/gpt2-wikitext-103.

from datasets import load_dataset

# Load the raw WikiText-103 training split, then report the container type and row count.
ds = load_dataset(
    path="Salesforce/wikitext",
    name="wikitext-103-raw-v1",
    split="train",
)
print(f"{type(ds)} {len(ds)}")

<class 'datasets.arrow_dataset.Dataset'> 1801350


In [42]:
# Peek at a single raw row; indexing the dataset returns a dict with a 'text' key.
ds[100]

{'text': ' 96 ammunition packing boxes \n'}

In [32]:
import tiktoken

# Tokenizer: tiktoken's GPT-2 byte-pair encoding.
# This is BPE, but not the variant LLaMA uses; assumed to be good enough here.
enc = tiktoken.get_encoding("gpt2")

# Round-trip a sample string to sanity-check encode/decode.
encoded = enc.encode("hello world, let's goo")
decoded = enc.decode(encoded)
print(f"{encoded} \n {decoded}")

[31373, 995, 11, 1309, 338, 467, 78] 
 hello world, let's goo


In [46]:
# Number of GPT-2 token ids produced for one sample row's text.
len(enc.encode(ds[3]['text']))

166

In [55]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import tiktoken

class WikiText103Dataset(Dataset):
    """Map-style dataset of fixed-length GPT-2 token ids for WikiText-103 rows.

    Each item is one row of ``Salesforce/wikitext`` (``wikitext-103-raw-v1``),
    tokenized with tiktoken's ``gpt2`` encoding and then truncated or
    right-padded to exactly ``max_length`` ids, so every item has the same
    shape and can be stacked into a batch.

    Args:
        split: Dataset split handed to ``load_dataset``.
        max_length: Length of every returned tensor; must be >= 1.
        pad_token_id: Id appended to rows shorter than ``max_length``.
            Defaults to 0, which matches the previous hard-coded behaviour.
            Note that 0 is also an ordinary id in the GPT-2 vocabulary, so
            padding cannot be told apart from real text by value alone; pass
            a dedicated id (e.g. the encoding's ``eot_token``) or track
            lengths separately if padding must be masked out.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """

    def __init__(self, split='train', max_length=1024, pad_token_id=0):
        # Fail before the (potentially slow) dataset load: a negative
        # max_length would silently produce wrongly truncated rows.
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        self.dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split=split)
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.max_length = max_length
        self.pad_token_id = pad_token_id

    def __len__(self):
        """Return the number of rows in the underlying split."""
        return len(self.dataset)

    # TODO: not the most optimized, but simple and works
    # can use pad_sequence, do later
    def __getitem__(self, idx):
        """Return row ``idx`` as a ``torch.long`` tensor of shape ``(max_length,)``."""
        text = self.dataset[idx]['text']
        # encode_ordinary treats special-token strings (e.g. "<|endoftext|>")
        # as plain text; plain encode() raises ValueError when it meets one,
        # which would abort a whole epoch because of a single row.
        tokens = self.tokenizer.encode_ordinary(text)[:self.max_length]

        # Right-pad rows that are shorter than the target length.
        missing = self.max_length - len(tokens)
        if missing > 0:
            tokens = tokens + [self.pad_token_id] * missing

        return torch.tensor(tokens, dtype=torch.long)

# Build the training dataset and wrap it in a shuffling loader of 16-row batches.
dataset = WikiText103Dataset(split="train")
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=16)

# Smoke test: take only the first batch and show its shape plus one row of ids.
for batch in dataloader:
    print(batch.shape, batch[2], sep="\n")
    break

torch.Size([16, 1024])
tensor([ 311, 1076,  786,  ...,    0,    0,    0])
