<h1>1. Build a loader to get the data</h1>
<h3>Note that we use the "Schedule Free" Adam from Facebook research: https://github.com/facebookresearch/schedule_free<br/>
This relieves us from the annoying task of providing a learning rate schedule for our transformer (i.e. adapting the learning rate in some pattern throughout training);<br/>
However, it still needs warmup.
</h3>

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import time
import schedulefree

#BPE artefacts loaded from disk: the merge rules plus both directions of the index <-> text-fragment lookup
rules = torch.load("BPE/rules.pt")
index_to_char = torch.load("BPE/index_to_char.pt")
char_to_index = torch.load("BPE/char_to_index.pt")

#padding token id = number of entries in index_to_char (presumably the first id after the regular vocabulary - verify that ids run 0..len-1);
#   we use it to pad shorter sequences to the max length; the training loop also prepends it as the "start of sequence" token
PADDING_TOKEN = len(index_to_char)
#samples with this many tokens or more are dropped by the dataset; shorter ones are padded up to exactly this length
LARGEST_SEQUENCE_LENGTH = 250
BATCH_SIZE = 16 #increase me - this is the number of sequences that are processed in parallel; should run on any decent GPU, I used 80 for my NVidia 4080
TEXTFILES_TO_USE = 2 #number of training files to load; if you want to use more than 2 textfiles, you need to run the first file to produce them from the original dataset

#helper functions - the same as for our BPE
def apply_BPE(text, rules):
    """Encode ``text`` into BPE token indices by applying ``rules`` in order.

    The text is first mapped to base indices via ``transcribe_chars_to_index``.
    Each rule is indexed as ``rule[0][0]``, ``rule[0][1]`` (the adjacent pair to
    look for) and ``rule[1]`` (the index that replaces the pair). For every rule
    the sequence is scanned left to right; a freshly merged token is never
    merged again within the same pass.

    Returns the resulting list of token indices.
    """
    text_as_indices = transcribe_chars_to_index(text)
    for rule in rules:
        first, second = rule[0][0], rule[0][1]
        replacement = rule[1]
        #build a new list instead of popping from the middle of the old one:
        #   list.pop(i) is O(n), which made every rule pass quadratic
        merged = []
        length = len(text_as_indices)
        i = 0
        while i < length:
            if i + 1 < length and text_as_indices[i] == first and text_as_indices[i + 1] == second:
                merged.append(replacement)
                i += 2 #skip both halves of the merged pair
            else:
                merged.append(text_as_indices[i])
                i += 1
        text_as_indices = merged

    return text_as_indices

def transcribe_indices_to_chars(indices):
    """Look up the text fragment of every token index, keeping the order.

    Uses the module-level ``index_to_char`` lookup; returns a list with one
    entry per index.
    """
    fragments = []
    for token_index in indices:
        fragments.append(index_to_char[token_index])
    return fragments

def decode_BPE(tokens):
    """Turn ONE token sequence (not a batch!) into a printable string.

    Everything from the first token >= PADDING_TOKEN onwards is cut off.
    Tensor elements are unwrapped to plain Python numbers before the lookup.
    The result is the ``str()`` of the list of text fragments, so individual
    fragments stay visible; join the fragments instead to see the plain text.
    """
    kept = []
    for token in tokens:
        if not token < PADDING_TOKEN:
            #padding (or anything beyond the vocabulary) ends the sequence
            break
        kept.append(token.item() if isinstance(token, torch.Tensor) else token)
    fragments = transcribe_indices_to_chars(kept)
    return str(fragments)

def transcribe_chars_to_index(chars):
    """Map every element of ``chars`` to its index via ``char_to_index``.

    Returns a new list in input order; an element missing from the lookup
    raises ``KeyError``.
    """
    return [char_to_index[char] for char in chars]

<h3>Load from multiple files, pad to a fixed length / filter out longer ones</h3>

In [None]:
class EncodedDataset(Dataset):
    """In-memory dataset of encoded samples, padded to a fixed length.

    Every file in ``file_paths`` is read with ``torch.load`` and its content is
    appended to one sample list (each file is expected to hold an iterable of
    1-D token tensors - confirm against the preprocessing notebook).
    Samples with LARGEST_SEQUENCE_LENGTH or more tokens are dropped; all others
    are right-padded with PADDING_TOKEN to exactly LARGEST_SEQUENCE_LENGTH
    tokens and stacked into a single tensor.

    Args:
        file_paths: paths of the files to load.
        print_statistics: if True, print length statistics of the loaded
            samples before filtering (off by default).
    """

    def __init__(self, file_paths, print_statistics=False):
        self.data = []
        for file_path in file_paths:
            #NOTE: torch.load may unpickle arbitrary objects (depends on the weights_only default of the installed version) - only load trusted files
            self.data.extend(torch.load(file_path))
        print("Loaded ", len(self.data), " samples from ", len(file_paths), " files.")

        #usually, throwing away samples with 250+ tokens doesn't hurt much, but speeds up computation considerably;
        #   the statistics help to check that for a new dataset
        if print_statistics and len(self.data) > 0:
            self._print_length_statistics(self.data)

        #throw away samples with LARGEST_SEQUENCE_LENGTH+ tokens
        #   -> without this, attention computation is very slow (quadratic scaling) with little benefit (very few samples are actually that long)
        padded = []
        for sample in self.data:
            missing = LARGEST_SEQUENCE_LENGTH - len(sample)
            if missing <= 0:
                continue
            padding = torch.full((missing,), PADDING_TOKEN, dtype=torch.long)
            padded.append(torch.cat((sample, padding)))
        #stack into one (num_samples, LARGEST_SEQUENCE_LENGTH) tensor:
        self.data = torch.stack(padded)
        print("Stored ", len(self.data), " tensors.")

    @staticmethod
    def _print_length_statistics(samples):
        """Print average / median / 90th percentile / maximum sample length."""
        lengths = sorted(len(sample) for sample in samples)
        #max() returns the first of several equally long samples
        longest_sample = max(samples, key=len)

        print("Average sequence length: ", sum(lengths)/len(lengths))
        print("Median sequence length: ", lengths[len(lengths)//2])
        #integer arithmetic: index of the 90th percentile (was accidentally the 80th)
        print("90 percent length: ", lengths[(len(lengths) * 9) // 10])

        print("Largest sequence length: ", lengths[-1])
        print("Example of largest sequence: ", decode_BPE(longest_sample))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

#collect the paths of all training files we want to load
textfiles = [f"data/train_BPE_{file_number}.dat" for file_number in range(TEXTFILES_TO_USE)]

#datasets: several training files, one validation file
train_dataset = EncodedDataset(textfiles)
test_dataset = EncodedDataset(["data/validation_BPE.dat"])

#both loaders hand out shuffled batches of BATCH_SIZE samples
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

<h1>2. Transformer helpers</h1>

In [None]:
def positional_encoding(embed_dims, SEQUENCE_LENGTH):
    """Build the fixed sinusoidal position table.

    For position ``p`` and frequency index ``k`` (k = 0, 2, 4, ...):
        table[p, k]     = sin(p / 10000^(k / embed_dims))
        table[p, k + 1] = cos(p / 10000^(k / embed_dims))

    Args:
        embed_dims: width of one embedding vector; may be odd, in which case
            the last column only holds the sine part.
        SEQUENCE_LENGTH: number of positions (rows) to generate.

    Returns:
        Float tensor of shape (SEQUENCE_LENGTH, embed_dims).
    """
    #column vector of positions 0..SEQUENCE_LENGTH-1, broadcast against the frequencies below
    positions = torch.arange(SEQUENCE_LENGTH).unsqueeze(1).float()
    denominator = torch.pow(10000, torch.arange(0, embed_dims, 2).float() / embed_dims)

    angle_rads = positions / denominator  #shape: (SEQUENCE_LENGTH, ceil(embed_dims / 2))

    #interleave: sine on even columns, cosine on odd columns
    pos_enc = torch.zeros(SEQUENCE_LENGTH, embed_dims)
    pos_enc[:, 0::2] = torch.sin(angle_rads)
    #for odd embed_dims there is one odd column less than even columns -> trim the cosine part
    pos_enc[:, 1::2] = torch.cos(angle_rads)[:, :embed_dims // 2]

    return pos_enc

<h1>3. Build actual Transformer</h1>

<h3>Rough rundown of what a transformer does:</h3>
https://arxiv.org/abs/1706.03762 is the original idea; https://jalammar.github.io/illustrated-transformer/ explains it somewhat nicely:<br/>
For attention, you compute pairwise scores between all tokens, then use these scores to<br/>
mix your tokens together to new tokens. For example, in "A black cat sat on the wall", the word "black" will "attend" to "cat", i.e. put a lot of attention on "cat";<br/>
meaning the tokens will be mixed such that we have a hybrid token thingy that says "black cat" (very crude).<br/><br/>
Transformers are just built out of stacked blocks ("layers"); each block consists of:<br/>
-an attention layer that computes pairwise scores, then re-mixes tokens accordingly<br/>
-normalisations & residuals<br/>
-a fully connected network part that is applied to EACH token after attention; meaning this does the heavy lifting,<br/>
while the attention is the only part where tokens get to know each other. The fully connected network is also the part where mixture of experts (MoE)<br/>
usually happens, i.e. where we apply a different network according to some routing process ("for math, we use net A, for french, net D, for german, net F, [...]")

<h3>Code blocks here:</h3>
<b>AutoregressiveDecoderTransformer</b> is a decoder-only transformer (i.e. predicts away token after token); it stacks multiple transformer decoder blocks,<br/>
then applies some head that does classification (i.e. gives us a pseudo probability distribution over tokens) that we sample from.<br/><br/>

<b>TransformerDecoderBlock</b> is a transformer block: attention computation between all tokens, then apply the fully connected network to each token. Also contains normalisation and residuals<br/><br/>

<b>FeedForward</b> is the fully connected network that processes each token after "mixing" it in the attention layer<br/><br/>

<b>CausalSelfAttention</b> is there to speed up training: When we process a full sentence, we can just "cover up" some part of the attention matrix to get a subset that describes part of a sentence;<br/>
if we have "The| |black| |cat", we can compute the whole 5-by-5 attention matrix (whole sentence) and just cover up pieces of it for e.g. "The| |black". This is what makes transformers so fast -<br/>
they somewhat learn on each prefix in parallel.<br/><br/>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

#dropout probability handed to every transformer block (also the default of the model constructor);
#set to 0.0 for no dropout
DROPOUT_RATE = 0.1

class FeedForward(nn.Module):
    """Two-layer fully connected network with a ReLU in between.

    The hidden layer is four times as wide as the input/output:
    think of solving a puzzle - the table should be big enough to spread out
    all the pieces, not just big enough to hold the finished picture.
    ``nn.Linear`` acts on the last dimension, so every token is processed
    on its own.
    """

    def __init__(self, dim):
        super().__init__()
        hidden_dim = dim * 4
        #expand ...
        self.lin_1 = nn.Linear(dim, hidden_dim)
        #... and contract again
        self.lin_2 = nn.Linear(hidden_dim, dim)
        #leaky ReLU would be an option, but plain ReLU is the transformer standard
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.lin_1(x)
        activated = self.relu(hidden)
        return self.lin_2(activated)

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with an optional (causal) mask.

    >>Causal<< means a token may only look at itself and at earlier tokens.
    For autoregressive models this lets us train all prefixes of a sequence in
    one go (for ABCDEFG: the next token after A, after AB, after ABC, ...).

    The embedding is split into ``num_heads`` slices of ``dim / num_heads``
    values each, which
        a) keeps the matrices small and
        b) lets different heads concentrate on different parts of the input
           (softmax tends to focus each head on few positions).
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        #the embedding must be divisible into equally sized heads
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        #dot products grow with the head size; scaling by 1/sqrt(head_dim)
        #   keeps the softmax from becoming too sharp
        self.scale = 1.0 / math.sqrt(self.head_dim)

        #one bias-free linear layer produces query, key and value at once
        #   (same parameter count as three separate matrices, less code)
        self.qkv_proj = nn.Linear(dim, 3 * dim, bias=False)
        #bias-free projection applied to the re-combined heads
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def _split_heads(self, projected, batch_size, T):
        #(batch, T, dim) -> (batch, heads, T, head_dim)
        return projected.view(batch_size, T, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        batch_size, T, C = x.shape
        #query, key, value are three equally sized chunks of one projection
        query, key, value = (
            self._split_heads(part, batch_size, T)
            for part in self.qkv_proj(x).chunk(3, dim=-1)
        )

        #pairwise scores between all tokens, scaled down
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if mask is not None:
            #positions where the mask is 0 get -inf, i.e. weight 0 after the softmax
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        #mix the values according to the weights, then glue the heads back together
        mixed = torch.matmul(weights, value)
        merged = mixed.transpose(1, 2).contiguous().view(batch_size, T, C)
        return self.out_proj(merged)

class TransformerDecoderBlock(nn.Module):
    """One transformer layer (normalisation applied after each residual).

    Order of operations:
        1. self attention -> dropout -> add to input (residual) -> layer norm
        2. feed forward   -> dropout -> add to input (residual) -> layer norm
    """

    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.attn = CausalSelfAttention(dim, num_heads)
        self.ffn = FeedForward(dim)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        #dropout randomly zeroes some values during training, so the model
        #   cannot rely on memorising individual details and has to learn the
        #   structure of the data in a redundant (=robust) way;
        #   the same dropout module is used after attention and after the feed forward part
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        #attention part: tokens exchange information
        attended = self.dropout(self.attn(x, mask))
        x = self.norm1(x + attended)

        #feed forward part: every token is processed on its own
        transformed = self.dropout(self.ffn(x))
        return self.norm2(x + transformed)

class AutoregressiveDecoderTransformer(nn.Module):
    """Decoder-only transformer that predicts the next token for every position.

    Args:
        vocab_size: number of distinct token ids (size of the embedding table
            and of the output layer).
        max_seq_len: longest sequence the model can process at once; sizes the
            positional table and the causal mask.
        dim: embedding width.
        num_layers: number of stacked TransformerDecoderBlocks.
        num_heads: attention heads per block.
        dropout: dropout probability inside the blocks.
    """

    def __init__(self, vocab_size, max_seq_len, dim, num_layers, num_heads, dropout=DROPOUT_RATE):
        super().__init__()
        self.max_seq_len = max_seq_len
        #embed tokens with something learnable
        self.token_embedding = nn.Embedding(vocab_size, dim)
        #embed positions with something computed (frozen, but kept as a parameter so
        #   existing checkpoints and optimiser states keep their layout);
        #   sized by max_seq_len so it always matches the mask below
        self.pos_embedding = torch.nn.Parameter(positional_encoding(dim, max_seq_len)[None], requires_grad=False)

        self.layers = nn.ModuleList([
            TransformerDecoderBlock(dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        #classification head: one logit per vocabulary entry
        self.decoder = nn.Linear(dim, vocab_size, bias=False)

        #lower triangular matrix: position i may attend to positions 0..i only
        self.register_buffer("mask", torch.tril(torch.ones(max_seq_len, max_seq_len)).unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        """Return logits of shape (batch, tokens, vocab_size) for token ids ``x`` of shape (batch, tokens).

        Raises:
            ValueError: if the sequence is longer than ``max_seq_len``.
        """
        _, tokens = x.shape
        if tokens > self.max_seq_len:
            raise ValueError(
                "Sequence of " + str(tokens) + " tokens exceeds max_seq_len=" + str(self.max_seq_len)
            )

        x = self.token_embedding(x) + self.pos_embedding[:, :tokens]

        mask = self.mask[:, :, :tokens, :tokens]
        for layer in self.layers:
            x = layer(x, mask)

        return self.decoder(x)

    def _next_token_logits(self, tokens):
        #logits for the token following ``tokens``; only the most recent
        #   max_seq_len tokens are fed in, so generation can run past the context size
        context = tokens[:, -self.max_seq_len:]
        return self.forward(context)[:, -1]

    #maximum sampling - useful for debugging, but will always generate the same sequence ("always pick the most likely token as next token")
    @torch.no_grad()
    def generate_max(self, tokens, max_new_tokens):
        """Append ``max_new_tokens`` greedily chosen tokens to ``tokens`` (batch, length)."""
        for _ in range(max_new_tokens):
            next_token = self._next_token_logits(tokens).argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens

    #sample just randomly according to probability - has a chance to pick some really messed up tokens
    @torch.no_grad()
    def generate_mul(self, tokens, max_new_tokens):
        """Append ``max_new_tokens`` tokens sampled from the full predicted distribution."""
        for _ in range(max_new_tokens):
            probs = F.softmax(self._next_token_logits(tokens), dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens

    #nucleus sampling: only sample from the top 90% of the probability distribution
    #                  we do so by taking only the largest probabilities up to gaining 90%,
    #                  then sampling from it
    #                  in result, this will a) prevent unlikely stuff from being sampled
    #                                   and b) still allows for some randomness, IMHO a bit nicer than top-k
    @torch.no_grad()
    def generate_nuc(self, tokens, max_new_tokens, top_p=0.9):
        """Append ``max_new_tokens`` tokens using nucleus (top-p) sampling.

        Per step, the most likely tokens are kept up to and including the first
        one that pushes the cumulative probability above ``top_p``; all other
        tokens get probability 0 before sampling.
        """
        for _ in range(max_new_tokens):
            #1. probabilities for the next token
            probs = F.softmax(self._next_token_logits(tokens), dim=-1)
            #2. sort & cumsum to find where the cumulative probability passes top_p
            sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            exceeds = cumulative > top_p
            #rank of the first token that pushes the sum above top_p ...
            cutoff = exceeds.long().argmax(dim=-1, keepdim=True)
            #... or the last rank if the sum never gets above top_p (e.g. top_p >= 1)
            last_rank = torch.full_like(cutoff, probs.size(-1) - 1)
            cutoff = torch.where(exceeds.any(dim=-1, keepdim=True), cutoff, last_rank)
            #3. null out everything behind the cutoff (for the whole batch at once)
            ranks = torch.arange(probs.size(-1), device=probs.device).unsqueeze(0)
            kept_sorted = sorted_probs * (ranks <= cutoff)
            #4. move the kept probabilities back to their original token positions
            nucleus = torch.zeros_like(probs).scatter(-1, sorted_indices, kept_sorted)
            #5. sample from the modified probabilities (multinomial does not need them to sum to 1)
            next_token = torch.multinomial(nucleus, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens


In [None]:
#move model to GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoregressiveDecoderTransformer(PADDING_TOKEN + 1, LARGEST_SEQUENCE_LENGTH+1, dim=512, num_layers=8, num_heads=8, dropout=DROPOUT_RATE).to(DEVICE)
print("Model has ", sum(p.numel() for p in model.parameters()), " parameters.")

LR = 0.001 #works best for this model & dataset

#linear warmup of learning rate; if you use AdamW instead of SFAdam, just apply that manually
WARMUP_STEPS = 2500

#how many (text) samples to generate after each epoch
SAMPLES_TO_GENERATE = 4
#how many tokens to generate per sample: starting from the single start token,
#   the longest input the model sees is exactly this many tokens (= its max_seq_len)
TOKENS_TO_GENERATE = LARGEST_SEQUENCE_LENGTH + 1

optimiser = schedulefree.AdamWScheduleFree(model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=0.01, warmup_steps=WARMUP_STEPS)

steps = 0

total_losses_train = []
total_losses_test  = []

import matplotlib.pyplot as plt

#padding is deliberately NOT ignored in the loss: the padding token doubles as
#   "end of text" marker (decode_BPE cuts at the first padding token), so the model has to learn to predict it
loss_fn = nn.CrossEntropyLoss()

def _prepend_start_token(batch):
    #pre-pad initial empty token (i.e. "start of sequence" token to also learn what the first "real" token should be), then move to DEVICE
    start = torch.full((batch.size(0), 1), PADDING_TOKEN, dtype=torch.long)
    return torch.cat((start, batch), dim=1).to(DEVICE)

def _next_token_loss(batch):
    #cross entropy between the prediction at every position and the token that actually follows
    batch = _prepend_start_token(batch)
    logits = model(batch)

    target = batch[:,1:] #shifted left by one; we want to predict the next token
    output = logits[:,:-1] #remove the last prediction; we don't want to predict anything at the last token

    return loss_fn(output.reshape(-1, output.size(-1)), target.reshape(-1))

for epoch in range(0, 100):
    losses_train = []
    losses_test = []

    #1. Train:
    optimiser.train()
    model.train()

    its = 0
    start = time.time()
    last = start

    for batch in train_dataloader:
        steps += 1

        optimiser.zero_grad()
        loss = _next_token_loss(batch)
        loss.backward()

        optimiser.step()
        losses_train.append(loss.item())
        its += 1
        if time.time() - last > 30:
            print("\t\tTime left for TRAIN epoch: ", (time.time()-start)/its*(len(train_dataloader)-its)/60, " minutes; Currently, ",steps," steps in.")
            last = time.time()

    #2. Evaluate:
    optimiser.eval()
    model.eval()

    #(important for SF AdamW: only store stuff when in eval mode!)
    #save model:
    os.makedirs("stored", exist_ok=True)
    torch.save(model.state_dict(), "stored/model_"+str(epoch)+".pt")
    #save optimiser:
    torch.save(optimiser.state_dict(), "stored/optimiser_"+str(epoch)+".pt")

    with torch.no_grad():
        its = 0
        start = time.time()
        last = start
        for batch in test_dataloader:
            loss = _next_token_loss(batch)
            losses_test.append(loss.item())
            its += 1
            if time.time() - last > 30:
                #the estimate has to be based on the TEST loader (was the train loader)
                print("\t\tTime left for TEST epoch: ", (time.time()-start)/its*(len(test_dataloader)-its), " seconds.")
                last = time.time()

        mean_loss_train = sum(losses_train)/len(losses_train)
        mean_loss_test = sum(losses_test)/len(losses_test)
        print("*** DONE WITH EPOCH ", epoch, " - TRAIN LOSS: ", mean_loss_train, " - TEST LOSS: ", mean_loss_test," ***")
        total_losses_train.append(mean_loss_train)
        total_losses_test.append(mean_loss_test)

        #save losses:
        torch.save(total_losses_train, "stored/total_losses_train_"+str(epoch)+".pt")
        torch.save(total_losses_test, "stored/total_losses_test_"+str(epoch)+".pt")

        plt.plot(total_losses_train, label="train")
        plt.plot(total_losses_test, label="test")
        plt.title("Losses at LR="+str(LR))
        plt.legend()
        #save plot & show:
        plt.savefig("losses_"+str(LR)+"_"+str(epoch)+".png")
        plt.show()
        #start the next epoch with a fresh figure, otherwise backends that keep the
        #   figure alive after show() draw the curves on top of each other
        plt.close()

        #3. Inference / Generate:
        #every sequence starts with the empty token only
        try:
            sampled_max = model.generate_max(torch.full((1, 1), PADDING_TOKEN, dtype=torch.long, device=DEVICE), TOKENS_TO_GENERATE)
            sampled_mul = model.generate_mul(torch.full((SAMPLES_TO_GENERATE, 1), PADDING_TOKEN, dtype=torch.long, device=DEVICE), TOKENS_TO_GENERATE)
            sampled_nuc = model.generate_nuc(torch.full((SAMPLES_TO_GENERATE, 1), PADDING_TOKEN, dtype=torch.long, device=DEVICE), TOKENS_TO_GENERATE)
            #skip first character when decoding, that's just the empty token:
            print("\tGENERATED SENTENCE  MAX: ")
            print("\t\tMost likely story: ",decode_BPE(sampled_max[0, 1:]))
            print("\tGENERATED SENTENCE  MUL: ")
            for i in range(0, SAMPLES_TO_GENERATE):
                print("\t\tMultinomial sampling story "+str(i)+": ",decode_BPE(sampled_mul[i, 1:]))
            print("\tGENERATED SENTENCE  NUC: ")
            for i in range(0, SAMPLES_TO_GENERATE):
                print("\t\tNucleus (probably best) story "+str(i)+": ",decode_BPE(sampled_nuc[i, 1:]))
        except Exception as error:
            #generation is only a progress report - keep training, but show what went wrong
            #   (KeyboardInterrupt is no longer swallowed here)
            print("FAILED TO GENERATE: ", repr(error))