<h1>Use the "TinyStories" dataset, as we don't have the resources for a big and complicated dataset</h1>
https://www.kaggle.com/datasets/thedevastator/tinystories-narrative-classification?resource=download

<h3>This file produces the Byte Pair Encoding (BPE) we use. Our transformer does NOT produce individual characters, but rather fragments of sentences. Imagine monkeys on a typewriter: if they're given buttons with words, you're much more likely to get a Shakespeare text eventually.</h3>

<h1>1. Load data</h1>

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import time

class TextDataset(Dataset):
    """Dataset of individual stories parsed from a quoted CSV text file.

    The parser expects a leading header entry followed by stories that are
    each wrapped in double quotes, with quotes inside a story escaped as ""
    (CSV style). Escaped quotes are turned into a backtick so that the
    remaining double quotes delimit the stories.

    Attributes are documented inline; ``data`` is the list of story strings.
    """

    def __init__(self, file_path):
        """Read ``file_path`` and split its content into one string per story.

        Args:
            file_path: path of the CSV file to parse.
        """
        #read the file in one go so that line breaks inside stories are kept
        #exactly once (readlines() + "\n".join() would double every newline);
        #the encoding is explicit so the result does not depend on the platform
        #NOTE(review): assumes the CSV is UTF-8 encoded - confirm for other sources
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        #replace all "" with ` to make parsing easier (splitting by " then yields individual stories)
        text = text.replace('""', '`')
        #split text into chunks according to ", i.e. one story each:
        chunks = text.split('"')
        #drop the first chunk (csv header), then keep every second chunk;
        #the chunks in between are the separators found outside the quotes
        chunks = chunks[1::2]

        #list of story strings, one entry per story
        self.data = chunks

    def __len__(self):
        """Return the number of stories."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the story stored at position ``idx``."""
        return self.data[idx]

#The encoding is learned from validation.csv rather than from the
#(much larger) train file, which keeps this step manageable.
textfile = "validation.csv"

dataset = TextDataset(textfile)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

#smoke test: print the first batch the dataloader yields, then stop
for batch in dataloader:
    inputs = batch
    print(inputs)
    break

<h1>2. For ease of processing, just gather all data in one string, then take a subset of it</h1>

In [None]:
#concatenate all stories into one string, each one preceded by a single space
text_to_tokenise = "".join(" " + story for story in dataset.data)

#get unique characters; sorting makes the index assignment reproducible,
#because the iteration order of a set of strings can differ between
#interpreter runs (string hash randomisation)
chars = sorted(set(text_to_tokenise))

#create the mappings between unique characters and their indices
char_to_index = {char: index for index, char in enumerate(chars)}
index_to_char = {index: char for index, char in enumerate(chars)}

#convert the text to a list of character indices
text_as_indices = [char_to_index[char] for char in text_to_tokenise]
#only the first 4*1024*1024 entries are used to learn the BPE merges
text_subset_for_BPE = text_as_indices[:(1024*1024*4)]

<h3>Helpers to transcribe characters to their token and vice-versa</h3>

In [None]:
def transcribe_chars_to_index(chars):
    """Map every character of ``chars`` to its index in ``char_to_index``.

    Args:
        chars: iterable of single characters (e.g. a string).

    Returns:
        A new list with one index per character, in input order.

    Raises:
        KeyError: if a character has no entry in ``char_to_index``.
    """
    return [char_to_index[symbol] for symbol in chars]

def transcribe_indices_to_chars(indices):
    """Look up the text stored in ``index_to_char`` for every entry of ``indices``.

    Args:
        indices: indexable sequence of token indices.

    Returns:
        A new list holding the ``index_to_char`` value of each index, in order.

    Raises:
        KeyError: if an index has no entry in ``index_to_char``.
    """
    decoded = []
    for position in range(len(indices)):
        decoded.append(index_to_char[indices[position]])
    return decoded

#round-trip check: encode 'abc' to indices and print what decoding gives back
roundtrip_indices = transcribe_chars_to_index('abc')
print(transcribe_indices_to_chars(roundtrip_indices))

<h1>3. Find BPE:</h1>
-initially: tokens = characters; later: tokens = 1 to N many characters "in one"<br/>
-find most frequent pair of tokens next to each other; usually, something like "e " is very frequent<br/>
-replace most frequent pair with a new token everywhere<br/>
-repeat until we reach a desired dictionary size<br/>
<br/>
(Usually, you'd want to do this with some C++ code; Python takes forever for longer texts or more tokens!)

In [None]:
from collections import Counter

#simple BPE: use 2048 tokens, i.e. introduce 2048-EXISTING new tokens to the vocabulary
DICT_LENGTH = 512 * 4

#contains: ((tokenIDA, tokenIDB), new index); we need these rules to later extract the tokens from the model
rules = []
if not os.path.exists("BPE"): #an existing "BPE" directory means results were saved before: skip training
    while len(index_to_char) < DICT_LENGTH:
        #1. count all pairs of neighbouring tokens:
        count = Counter(zip(text_subset_for_BPE, text_subset_for_BPE[1:]))
        if not count:
            #fewer than two tokens are left, so there is nothing more to merge
            break
        #2. find most common pair (on ties, the pair that occurs first in the text wins):
        max_key = max(count, key=count.get)
        print(max_key)
        print("BPE replaces |"+index_to_char[max_key[0]]+index_to_char[max_key[1]]+"| with a new token; it has "+str(count[max_key])+" occurences")
        #3. register the new token and remember the merge rule:
        new_index = len(index_to_char)
        index_to_char[new_index] = index_to_char[max_key[0]]+index_to_char[max_key[1]]
        rules.append((max_key, new_index))
        #replace every occurrence of the pair, scanning left to right;
        #a fresh list is built because popping from the middle of a list
        #with millions of entries costs O(n) for every single replacement
        merged = []
        i = 0
        length = len(text_subset_for_BPE)
        while i < length:
            if i + 1 < length and text_subset_for_BPE[i] == max_key[0] and text_subset_for_BPE[i+1] == max_key[1]:
                merged.append(new_index)
                i += 2 #both tokens of the pair are consumed
            else:
                merged.append(text_subset_for_BPE[i])
                i += 1
        text_subset_for_BPE = merged
        #(repeat)

In [None]:
#objects that make up the encoding and are stored on disk:
#   -rules
#   -index_to_char
#   -char_to_index

#the presence of the directory "BPE" decides between loading and saving
if os.path.exists("BPE"):
    #a previous run already stored its encoding: use that one
    rules = torch.load("BPE/rules.pt")
    index_to_char = torch.load("BPE/index_to_char.pt")
    char_to_index = torch.load("BPE/char_to_index.pt")
else:
    #first run: create the directory and store the current encoding
    os.makedirs("BPE")
    torch.save(rules, "BPE/rules.pt")
    torch.save(index_to_char, "BPE/index_to_char.pt")
    torch.save(char_to_index, "BPE/char_to_index.pt")

<h3>Test with a demo from the dataset</h3>

In [None]:
#sample passage used to check the encoding; it uses backticks where the
#raw data has quotation marks, matching what TextDataset produces
test = "Spot. Spot saw the shiny car and said, `Wow, Kitty, your car is so bright and clean!` Kitty smiled and replied, `Thank you, Spot. I polish it every day.`"

def apply_BPE(text, rules):
    """Encode ``text`` by applying the BPE merge rules in order.

    Args:
        text: string to encode; every character must be known to
            ``transcribe_chars_to_index``.
        rules: sequence of ``((first, second), merged)`` entries; each rule
            replaces the adjacent tokens ``first, second`` by ``merged``.

    Returns:
        The list of token indices after all rules were applied.
    """
    tokens = transcribe_chars_to_index(text)
    for rule in rules:
        pos = 0
        #the bound is re-evaluated on every pass because merges shrink the list
        while pos + 1 < len(tokens):
            pair_found = tokens[pos] == rule[0][0] and tokens[pos + 1] == rule[0][1]
            if pair_found:
                #collapse the two matching tokens into the merged one
                tokens[pos:pos + 2] = [rule[1]]
            pos += 1

    return tokens

def decode_BPE(tokens):
    """Translate ``tokens`` back into their text fragments.

    The fragments are returned as a list instead of one joined string, so
    that the individual tokens stay visible, which is useful for debugging.
    """
    fragments = transcribe_indices_to_chars(tokens)
    return fragments

print("LENGTH BEFORE: ", len(test))
#encode once and reuse the result for both remaining outputs
encoded_test = apply_BPE(test, rules)
print("LENGTH AFTER: ", len(encoded_test))
print("DECODED: ",decode_BPE(encoded_test)) #do we get the same sentence out?

<h1>4. Apply BPE to the training data and validation data, then save the results as lists of torch tensors</h1>

<h3>Write out training data (split up into different files)</h3>

In [None]:
#create the output directory; an already existing one is fine, any other
#problem (e.g. missing permissions) is reported instead of being hidden
os.makedirs("data", exist_ok=True)

textfile = "train.csv"
dataset = TextDataset(textfile)
#split train data into multiple individual files to handle things better:
NO_CHUNKS = 128
for chunk in range(0, NO_CHUNKS): #<-- if you cancel this, it will continue from the last unfinished chunk
    chunk_path = "data/train_BPE_"+str(chunk)+".dat"
    #check if file already exists:
    if os.path.exists(chunk_path):
        continue
    fails = 0
    new_data = []
    start = time.time()
    #chunk k holds the stories k, k+NO_CHUNKS, k+2*NO_CHUNKS, ...
    for i in range(chunk, len(dataset.data), NO_CHUNKS):
        if i % 100 == 0 and i > 0:
            print("\tCHUNK ",chunk," - DONE WITH ", i/len(dataset.data)*100, "%, time left: ", (time.time()-start)/i*(len(dataset.data)-i)/60, " minutes")
        try:
            dat = torch.tensor(apply_BPE(dataset.data[i], rules)).to(torch.int32)
        except KeyError:
            #the story contains a character without an entry in char_to_index
            #(presumably one that never occurs in the file the encoding was
            #built from); such stories are skipped. Only KeyError is caught,
            #so cancelling with Ctrl-C really stops the loop.
            fails += 1
            if fails % 100 == 0:
                print("\tSo far, ", fails, "/", i, " failed")
        else:
            new_data.append(dat)
    #write to a temporary file first and rename afterwards: an interrupted
    #save must not leave a partial file that the check above would skip
    tmp_path = chunk_path + ".tmp"
    torch.save(new_data, tmp_path)
    os.replace(tmp_path, chunk_path)

<h3>Write out validation data</h3>

In [None]:
#encode the validation stories once; an existing output file is kept as is
if not os.path.exists("data/validation_BPE.dat"):
    #make sure the output directory exists even if this cell is run on its own
    os.makedirs("data", exist_ok=True)
    textfile = "validation.csv"
    dataset = TextDataset(textfile)
    start = time.time()
    new_data = []
    for i in range(0, len(dataset.data)):
        if i % 100 == 0 and i > 0:
            print("DONE WITH ", i/len(dataset.data)*100, "%, time left: ", (time.time()-start)/i*(len(dataset.data)-i)/60, " minutes")
        dat = torch.tensor(apply_BPE(dataset.data[i], rules)).to(torch.int32)
        new_data.append(dat)

    #write to a temporary file first and rename afterwards: an interrupted
    #save must not leave a partial file that the check above treats as complete
    torch.save(new_data, "data/validation_BPE.dat.tmp")
    os.replace("data/validation_BPE.dat.tmp", "data/validation_BPE.dat")