In [1]:
import math
from collections.abc import Iterator

import datasets
import tokenizers
import torch
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, decoders, trainers
from torch import nn

In [2]:
# Default number of rows per batch (default argument of both batch_iterator and iter_batches).
BATCH_SIZE = 64
# Default number of token positions per sequence in a batch produced by iter_batches.
CONTEXT_LENGTH = 512
# Target vocabulary size given to the BPE trainer; also the one-hot width used by iter_batches.
VOCAB_SIZE = 16_000

## Load The Dataset

For the purposes of our simple next-token-prediction transformer, we'll stick to a small, simple dataset and an equally small tokenizer so that the model size stays small.

I sought out this dataset for a couple of reasons:
1. It is almost entirely English, which largely removes the need to handle Unicode (although the byte-level BPE tokenizer would handle this already).
2. It is small and has a limited vocabulary, which helps limit the number of tokens and keeps the transformer's parameter count small and training quick.

In [3]:
# Load the WikiText-103 (v1) dataset; the resulting DatasetDict (train/validation/test) is displayed below.
dataset = datasets.load_dataset("wikitext", "wikitext-103-v1")
dataset

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [4]:
# Peek at the first ten training rows. Slice the rows first and only then pick
# the column, so just ten rows are read instead of the entire "text" column.
dataset['train'][:10]['text']

['',
 ' = Valkyria Chronicles III = \n',
 '',
 ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n',
 " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving

## Configure The Tokenizer

For this we'll be using a byte-level byte-pair-encoding (BPE) tokenizer.

In [5]:
# Byte-level BPE pipeline: NFKC-normalise the text, map it to byte-level
# symbols, and let the BPE model merge those symbols into sub-word tokens.
tokenizer = Tokenizer(models.BPE())
tokenizer.normalizer = normalizers.NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
tokenizer.decoder = decoders.ByteLevel()

# Seed the vocabulary with every byte-level symbol so that bytes which never
# occur in the training split can still be encoded after training.
trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

# NOTE: the vocabulary is empty until the tokenizer has been trained, so its
# size can only be compared against VOCAB_SIZE after `train_from_iterator` runs.

In [6]:
def batch_iterator(batch_size: int = BATCH_SIZE) -> Iterator[list[str]]:
    """Yield the training split's text in lists of up to ``batch_size`` strings.

    Reads from the notebook-level ``dataset`` and is intended as the input to
    ``tokenizer.train_from_iterator``.

    Args:
        batch_size: Number of rows requested per yielded list.

    Yields:
        The ``"text"`` values of one batch of rows from ``dataset["train"]``.
    """
    # Keep only the "text" column so each batch dict carries nothing else.
    text_only = dataset["train"].select_columns("text")
    for batch in text_only.iter(batch_size):
        yield batch["text"]

In [7]:
tokenizer.train_from_iterator(batch_iterator(), trainer)

# The vocabulary only exists after training, so validate its size here.
# One-hot batches are VOCAB_SIZE wide, so every token id must be < VOCAB_SIZE.
trained_vocab_size = tokenizer.get_vocab_size()
assert trained_vocab_size <= VOCAB_SIZE, (
    f"tokenizer has {trained_vocab_size} tokens but VOCAB_SIZE is {VOCAB_SIZE}"
)
if trained_vocab_size < VOCAB_SIZE:
    # Harmless for indexing, but some one-hot columns will never be used.
    print(f"Note: trained vocab size {trained_vocab_size} is below VOCAB_SIZE {VOCAB_SIZE}")






In [21]:
# Sanity check: tokenize training row 1 (the ' = Valkyria Chronicles III = ' heading) and show its token strings.
tokenizer.encode(dataset["train"][1]["text"]).tokens

['Ġ=', 'ĠV', 'alk', 'y', 'ria', 'ĠChronic', 'les', 'ĠIII', 'Ġ=', 'ĠĊ']

In [54]:
def iter_batches(
    dataset: datasets.Dataset,
    tokenizer: tokenizers.Tokenizer,
    batch_size: int = BATCH_SIZE,
    context_length: int = CONTEXT_LENGTH,
    vocab_size: int = VOCAB_SIZE,
) -> Iterator[torch.Tensor]:
    """Yield one-hot encoded token batches packed from ``dataset``.

    Every sample's ``"text"`` is tokenized and the resulting ids are written
    back to back (samples are not padded or separated) into rows of
    ``context_length`` positions. Once ``batch_size`` rows are full the batch
    is yielded. Tokens left over at the end that do not fill a whole batch are
    dropped.

    Args:
        dataset: Iterable of samples, each providing a ``"text"`` entry.
        tokenizer: Tokenizer whose ``encode(text).ids`` gives the token ids.
        batch_size: Number of sequences (rows) per batch.
        context_length: Number of token positions per sequence.
        vocab_size: Width of the one-hot axis; every token id must be smaller
            than this value.

    Yields:
        A fresh ``int16`` tensor of shape
        ``(batch_size, context_length, vocab_size)`` with exactly one ``1``
        per position.

    Raises:
        ValueError: If any of the size arguments is not positive.

    Note:
        Each batch holds ``batch_size * context_length * vocab_size`` int16
        values, which is large with the notebook defaults.
    """
    if batch_size <= 0 or context_length <= 0 or vocab_size <= 0:
        raise ValueError("batch_size, context_length and vocab_size must be positive")

    def _empty_batch() -> torch.Tensor:
        # A new all-zero one-hot buffer for the next batch.
        return torch.zeros(batch_size, context_length, vocab_size, dtype=torch.int16)

    buffer = _empty_batch()
    row_ix = 0  # sequence currently being filled
    col_ix = 0  # next free token position inside that sequence

    for sample in dataset:
        for token_id in tokenizer.encode(sample["text"]).ids:
            buffer[row_ix, col_ix, token_id] = 1
            col_ix += 1

            if col_ix == context_length:
                # Row complete: continue at the start of the next row.
                col_ix = 0
                row_ix += 1

                if row_ix == batch_size:
                    row_ix = 0
                    yield buffer
                    # Start from a clean buffer. Reusing the old one would
                    # leave the previous batch's 1s in place and would mutate
                    # a tensor the caller may still be holding.
                    buffer = _empty_batch()

In [61]:
# Smoke-test the batcher with a tiny configuration (6 sequences of 24 tokens) and display the first batch.
batch_iter = iter_batches(dataset['train'], tokenizer, batch_size=6, context_length=24)
batch = next(batch_iter)
batch

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0

In [66]:
def decode_batch(batch: torch.Tensor) -> list[str]:
    """Turn a batch of per-position scores back into text.

    For each sequence in ``batch`` the index of the largest value along the
    last axis is taken as the token id at that position, and the resulting id
    list is decoded with the notebook-level ``tokenizer``.

    Args:
        batch: Tensor iterated along its first axis; the last axis is the
            one reduced with ``argmax``.

    Returns:
        One decoded string per sequence, in batch order.
    """
    decoded: list[str] = []
    for sequence in batch:
        token_ids = sequence.argmax(dim=-1).tolist()
        decoded.append(tokenizer.decode(token_ids))
    return decoded