<h1>1. Build a loader to get the data</h1>
<h3>Note that we use the "Schedule Free" Adam from Facebook research: https://github.com/facebookresearch/schedule_free<br/>
This relieves us from the annoying task of providing a learning rate schedule for our transformer (i.e. adapt the learning rate in some pattern throughout training);<br/>
However, it still needs warmup.
</h3>

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import time
import schedulefree
import matplotlib.pyplot as plt

#some global variables to make our life easier
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") #move model to GPU if available
STEPS_WARMUP        = 2500 #warmup steps for the model
STEPS_WARMUP_ROUTER =  500 #warmup steps for the router
STEPS_ROUTER        = STEPS_WARMUP_ROUTER * 2 #how long do we train the router at all; after that, we just use the most likely cluster (=better for training)
BATCH_SIZE = 64+16
TEXTFILES_TO_USE = 1 #####max: 128, then the whole "tiny stories" dataset is used; 8/128 data chunks is a good value for trying stuff out
NO_CLUSTERS = 10 #number of clusters for our single model mixture of experts

In [2]:
rules = torch.load("BPE/rules.pt")
index_to_char = torch.load("BPE/index_to_char.pt")
char_to_index = torch.load("BPE/char_to_index.pt")

PADDING_TOKEN = len(index_to_char) #padding token is the last token in the index_to_char dictionary; we use it to shorter pad sequences to the max length
LARGEST_SEQUENCE_LENGTH = 250

#helper functions - the same as for our BPE
def apply_BPE(text, rules):
    text_as_indices = transcribe_chars_to_index(text)
    for rule in rules:
        i = 0
        while i < len(text_as_indices)-1:
            if text_as_indices[i] == rule[0][0] and text_as_indices[i+1] == rule[0][1]:
                text_as_indices[i] = rule[1]
                text_as_indices.pop(i+1)
            i += 1
            
    return text_as_indices

def transcribe_indices_to_chars(indices):
    return [index_to_char[indices[i]] for i in range(0, len(indices))]

def decode_BPE(tokens):
    #EXPECTS input to be ONE item, not a batch!
    #cut off every token >= PADDING_TOKEN!
    tokens_ = tokens
    tokens = []
    for token in tokens_:
        if token < PADDING_TOKEN:
            #if token is a tensor, append item:
            if isinstance(token, torch.Tensor):
                tokens.append(token.item())
            else:
                tokens.append(token)
        else:
            break
    #return "".join(transcribe_indices_to_chars(tokens)) #show full text
    return str(transcribe_indices_to_chars(tokens)) #show individual text fragments, in an array

def transcribe_chars_to_index(chars):
    indices = []
    for char in chars:
        indices.append(char_to_index[char])
    return indices

  rules = torch.load("BPE/rules.pt")
  index_to_char = torch.load("BPE/index_to_char.pt")
  char_to_index = torch.load("BPE/char_to_index.pt")


<h3>Load from multiple files, pad to a fixed length / filter out longer ones</h3>

In [3]:
class EncodedDataset(Dataset):
    def __init__(self, file_paths):
        self.data = []
        for file_path in file_paths:
            loaded = torch.load(file_path)
            self.data.extend(loaded)
        print("Loaded ", len(self.data), " samples from ", len(file_paths), " files.")
        
        if False: #print some statsitics; usually, throwing away samples with 250+ doesn't hurt much, but speeds up computation considerably
            largest = 0 #find the largest sequence length
            largest_sequence = None
            lengths = []
            for sample in self.data:
                lengths.append(len(sample))
                if len(sample) > largest:
                    largest = len(sample)
                    largest_sequence = sample
            
            print("Average sequence length: ", sum(lengths)/len(lengths))
            print("Median sequence length: ", sorted(lengths)[len(lengths)//2])
            print("90 percent length: ", sorted(lengths)[int(len(lengths)/10*8)])

            print("Largest sequence length: ", largest)
            print("Example of largest sequence: ", decode_BPE(largest_sequence))
        
        #throw away samples with LARGEST_SEQUENCE_LENGTH+ tokens
        #   -> without this, attention computation is very slow (quadratic scaling) with little benefit (very few samples are actually that long)
        self.data = [torch.cat((sample, torch.ones((LARGEST_SEQUENCE_LENGTH-len(sample),), dtype=torch.long) * PADDING_TOKEN)) for sample in self.data if len(sample) < LARGEST_SEQUENCE_LENGTH]
        #stack into one big tensor; for tinystories, we can fit everything into memory at the same time!
        self.data = torch.stack(self.data)
        print("Stored ", len(self.data), " tensors.")
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

#list all textfiles we want to load
textfiles = []
for i in range(0, TEXTFILES_TO_USE):
    textfiles.append("data/train_BPE_"+str(i)+".dat")

train_dataset = EncodedDataset(textfiles)
test_dataset = EncodedDataset(["data/validation_BPE.dat"])

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

  loaded = torch.load(file_path)


Loaded  132269  samples from  8  files.
Stored  107065  tensors.
Loaded  21953  samples from  1  files.
Stored  18280  tensors.


<h1>2. Transformer helpers</h2>

In [4]:
def positional_encoding(embed_dims, SEQUENCE_LENGTH):
    enc = torch.arange(SEQUENCE_LENGTH).unsqueeze(1).float()  # Use arange instead of ones
    denominator = torch.pow(10000, torch.arange(0, embed_dims, 2).float() / embed_dims)
    
    angle_rads = enc / denominator
    sin_vals = torch.sin(angle_rads)
    cos_vals = torch.cos(angle_rads)
    
    pos_enc = torch.zeros(SEQUENCE_LENGTH, embed_dims)
    pos_enc[:, 0::2] = sin_vals
    pos_enc[:, 1::2] = cos_vals
    return pos_enc

<h1>3. Build actual Transformer</h2>

<h3>Rough rundown of what a transformer does:</h3>
https://arxiv.org/abs/1706.03762 is the original idea; https://jalammar.github.io/illustrated-transformer/ explains it somewhat nicely:<br/>
For attention, you compute pairwise scores between all tokens, then use these scores to<br/>
mix your tokens together to new tokens. Exemplary, for "A black cat sat on the wall", the word "black" will "attend" to "cat", i.e. have a lot of attention on black;<br/>
meaning the tokens will be mixed such that we have a hybrid token thingy that says "black cat" (very crude).<br/><br/>
Transformers are just build out of stacked blocks ("layers"); each block consists of:<br/>
-an attention layer that computes pairwise scores, then re-mixes tokens accordingly<br/>
-normalisations & residuals<br/>
-a fully connected network part that is applied to EACH token after attention; meaning this does the heavy lifting,<br/>
while the attention is the only part where tokens get to know each other. The fully connected network is also the part where mixture of experts (MoE)<br/>
usually happens, i.e. where we apply a different network according to some routing process ("for math, we use net A, for french, net D, for german, net F, [...]")

<h3>Our MoE idea:</h3>
<b><u>Observation A:</u></b><br/>
Wisdom of the crowd effects benefit most tasks; If you ask 1000 people to estimate something, the average will come pretty close;<br/>
 an ensemble ofestimators will usually be better than a single one<br/>
<b><u>Observation B:</u></b><br/>
Applying dropout (randomly disabling some neurons) essentially brings such an ensemble effect: the network learns to work with many different combinations of neurons<br/>
<b><u>Observation C:</u></b><br/>
Mixture of Experts is, in a way, a selective ensemble, but fails to use knowledge that everyone has; general knowledge has to be in each expert (e.g. how grammar works, even if it's the history node or the math node, they need to individually learn grammar)<br/><br/>
We bring these observations together in ONE idea:
<h1>Single Mixture-of-Experts</h1>
We apply a simple routing to determine which expert gets asked; However, instead of physically different netowrks that, individually, don't know any information from the other ones, <br/>
we instead use dropout to determine an expert: For a given transformer, we apply an additional, deterministic dropout procedure. This dropout ("which neurons are turned off") are determined<br/>
by the router, i.e. are content-dependant.<br/>
In result, we don't have multiple experts, but a mixture of experts <b>in a single network</b>, as a result of selectively dropping out neurons according to which "expert", we'd want.<br/>
We still get to keep "one" network to not lose any knowledge when branching into an expert, but still allow some degree of specialisation.

<h3>Code blocks here:</h3>
<b>AutoregressiveDecoderTransformer</b> is a decoder-only transformer (i.e. predicts away token after token); it stacks multiple transformer decoder blocks,<br/>
then applies some head that does classification (i.e. gives us a pseudo probability distributuon over tokens) that we sample from.<br/><br/>

<b>TransformerDecoderBlock</b> is a transformer block: attention computation between all tokens, then apply the fully connected network to each token. Also contains normalisation and residuals<br/><br/>

<b>FeedForward</b> is the fully connected network that processes each token after "mixing" it in the attention layer<br/><br/>

<b>CausalSelfAttention</b> is there to speed up training: When we process a full sentence, we can just "cover up" some part of the attention matrix to get a subset that describes part of a sentence;<br/>
if we have "The| |black| |cat", we can compute the whole 5-by-5 attention matrix (whole sentence) and just cover up pieces of it for e.g. "The| |black". This is what makes transformers so fast -<br/>
they somewhat learn on each prefix in parallel.<br/><br/>

<b>LinearPlusDropout</b> is a regular linear layer, but can also receive a class label; if it receives a class label, it performs a deterministic dropout.<br/>
Meaning: If we give it e.g. label "5" multiple times, it will always zero out the same neurons, but a different dropout than if we give it the label "3"<br/><br/>

<b>AutoregressiveRouterTransformer</b> is our router; also stacks transformer blocks, but then proceeds to output a label to which "expert" we branch out for the current part of the sentence (=which dropout we perform)<br/><br/>

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

#set to 0.0 for no dropout
DROPOUT_RATE = 0.1
sMoE_DROPOUT_RATE = 0.2 #make it large enough so each expert is significantly different from the others

class LinearPlusDropout(nn.Module):
    #our special way of applying dropout is built in here:
    #a regular linear layer, but with a dropout mask that is applied to the output
    #   (i.e. we don't just zero out some weights, but a deterministic set of neurons)
    def __init__(self, in_features, out_features, no_classes=NO_CLUSTERS):
        super(LinearPlusDropout, self).__init__()
        self.linear = nn.Linear(in_features, out_features)

        self.out = out_features

        #make mask a parameter:
        self.mask = nn.Parameter((torch.rand(no_classes, self.out) > sMoE_DROPOUT_RATE).float(), requires_grad=False)
        
    def forward(self, x, expert_labels=None):
        x = self.linear(x)
        #should be [b]-sized
        if expert_labels != None:
            shape_b4 = x.size()
            #flatten to [-1, out_features]
            expert_labels = expert_labels.view(-1)
            x = x.view(-1, x.size()[-1])
            #find individual mask for each class
            mask = self.mask[expert_labels]
            x = x * mask

            scale = mask.sum(dim=-1) / self.out
            x = x / scale[:,None]
            #reshape back to original shape
            x = x.view(shape_b4)
            
        return x

class FeedForward(nn.Module):
    #feed forward with ReLU; is applied to each token individually
    #   (i.e. we don't care about the order of tokens here, just apply the same operation to each token)
    #   this is basically where the knowledge is at
    def __init__(self, dim):
        super().__init__()
        #two linear layers with ReLU in between
        #note how size expands and contracts:
        #   imagine a puzzle you try to solve - you want the table you do it on
        #   to be big enough to lay out all the pieces instead of just big enough 
        #   to hold the final result
        self.lin_1 = LinearPlusDropout(dim, dim * 4)
        self.lin_2 = LinearPlusDropout(dim * 4, dim)
        #I'd suggest leaky ReLU, but ReLU is the standard used in TFs
        self.relu = nn.ReLU()
        
    def forward(self, x, expert_labels=None):
        x = self.relu(self.lin_1(x, expert_labels))
        x = self.lin_2(x, expert_labels)
        return x

class CausalSelfAttention(nn.Module):
    #>>Causal<< self-attention means that the model can only look at previous tokens:
    #   this is important for autoregressive models, as we can then recycle a lot
    #   of the computation for attention & train all prefixes at once
    #   (e.g. for ABCDEFG, we train the next token after A, after AB, after ABC, ... in one go)
    def __init__(self, dim, num_heads):
        super().__init__()
        #ensure that we can split the dimension into num_heads
        #each head will then have dim/num_heads dimensions, i.e.
        #we divide the input into num_heads parts to 
        #   a) keep matrix sizes tame
        #   b) ensure that different heads can focus on different tasks
        #      (softmax focuses the attention largely on one part of the input,
        #       then different heads can focus on different parts)
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        #scale factor for attention - as the dot product grows with the dimension, 
        #   we scale it down to prevent the softmax from getting too extreme / sharp
        self.scale = 1.0 / (math.sqrt(self.head_dim))

        #linear layer for query, key, value (=apply query, key, value matrix to input)
        #here, we make life easy for us and apply one linear layer (=also multiplies the input by a matrix)
        #   (instead of having separate matrices for each)
        #   --> same number of parameters, but less code / all in one go
        #   (don't forget the bias=False, as we don't want to add a bias here; we just want a matrix multiplication, essentially)
        #   (linear layer is just W * x + b)

        self.qkv_proj = nn.Linear(dim, 3 * dim, bias=False)
        
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, mask=None):
        batch_size, T, C = x.shape
        #produce key, query, value from input
        qkv = self.qkv_proj(x).chunk(3, dim=-1)
        #divide into num_heads parts (=split the dimension up)
        q, k, v = map(lambda t: t.view(batch_size, T, self.num_heads, self.head_dim).transpose(1, 2), qkv)

        #compute scaled dot product attention:
        #   dot product of query and key, then scale it down to prevent softmax from getting too extreme
        attn_weights = (q @ k.transpose(-2, -1)) * self.scale
        #to prevent the model from looking at future tokens (for autoregressive training),
        #   we mask the attention weights for tokens that are in the future;
        #   this is done by setting the attention weights for future tokens to -inf
        #   (as softmax(-inf) = 0, i.e. the model will ignore these tokens)
        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_weights, dim=-1)
        
        out = (attn_weights @ v).transpose(1, 2).contiguous().view(batch_size, T, C)
        return self.out_proj(out)

class TransformerDecoderBlock(nn.Module):
    #transformer decoder block:
    #   self attention, dropout
    #   residual & layer norm
    #   feed forward, dropout
    #   residual & layer norm

    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.attn = CausalSelfAttention(dim, num_heads)
        self.ffn = FeedForward(dim)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        
        #dropout just to prevent overfitting, i.e. memorising stuff:
        #   we want the model to learn the structure of the data, not the data itself!
        #   dropout is a simple way to prevent the model from memorising the data
        #   by randomly setting some weights to zero;
        #   i.e. the model can't rely on just memorising individual aspects,
        #   but has to learn the structure of the data in a general and redundant (=robust) way
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None, expert_labels=None):
        #self attention and dropout, then residual
        x = x + self.dropout(self.attn(x, mask))
        #layer norm
        x = self.norm1(x)
        
        #feed forward and dropout, then residual
        x = x + self.dropout(self.ffn(x, expert_labels))
        #layer norm
        x = self.norm2(x)
        return x

class AutoregressiveDecoderTransformer(nn.Module):
    def __init__(self, vocab_size, max_seq_len, dim, num_layers, num_heads, dropout=DROPOUT_RATE):
        super().__init__()
        #embed tokens with something learnable
        self.token_embedding = nn.Embedding(vocab_size, dim)
        #embed positions with something computed
        self.pos_embedding = torch.nn.Parameter(positional_encoding(dim, LARGEST_SEQUENCE_LENGTH+1)[None], requires_grad=False)

        self.layers = nn.ModuleList([
            TransformerDecoderBlock(dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        self.decoder = nn.Linear(dim, vocab_size, bias=False)

        self.register_buffer("mask", torch.tril(torch.ones(max_seq_len, max_seq_len)).unsqueeze(0).unsqueeze(0))

    def forward(self, x, expert_labels=None):
        batch_size, tokens = x.shape
        token_emb = self.token_embedding(x)
        pos_emb = self.pos_embedding[:,:x.size()[1]]
        
        x = token_emb + pos_emb
        
        mask = self.mask[:, :, :tokens, :tokens]
        for layer in self.layers:
            x = layer(x, mask, expert_labels)

        return self.decoder(x)

    #maximum sampling - useful for debugging, but will always generate the same sequence ("always pick the most likely token as next token")
    def generate_max(self, tokens, max_new_tokens, router):
        for _ in range(max_new_tokens):
            indices = router(tokens).argmax(dim=-1)
            logits = self.forward(tokens, indices)
            next_token = torch.argmax(logits, dim=-1, keepdim=False)[:, -1:]
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens
    #sample just randomly according to probability - has a chance to pick some really messed up
    def generate_mul(self, tokens, max_new_tokens, router):
        for _ in range(max_new_tokens):
            indices = router(tokens).argmax(dim=-1)
            logits = self.forward(tokens, indices)
            next_token = torch.multinomial(F.softmax(logits, dim=-1)[:, -1], 1)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens
    def generate_nuc(self, tokens, max_new_tokens, router):
        for _ in range(max_new_tokens):
            #1. get logits
            indices = router(tokens).argmax(dim=-1)
            logits = self.forward(tokens, indices)
            #2. turn into probabilities
            probs = F.softmax(logits, dim=-1)[:, -1]
            #3. sort & cumsum to get the cumulative probability to cut off everything beyond 90%
            sorted, indices = torch.sort(probs, descending=True)
            cumulative = torch.cumsum(sorted, dim=-1)
            #find the first index where the cumulative probability is larger than 0.9
            cutoff = torch.argmax((cumulative > 0.9).long(), dim=-1)
            #4. null out everything beyond 90%
            for b in range(0, probs.size()[0]):
                cutoff_index = cutoff[b] + 1
                probs[b, indices[b, cutoff_index:]] = 0.0
            #5. sample from the modified probabilities
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
        return tokens
    
class AutoregressiveRouterTransformer(nn.Module):
    def __init__(self, vocab_size_in, max_seq_len, dim, num_layers, num_heads, dropout=DROPOUT_RATE):
        super().__init__()
        #embed tokens with something learnable
        self.token_embedding = nn.Embedding(vocab_size_in, dim)
        #embed positions with something computed
        self.pos_embedding = torch.nn.Parameter(positional_encoding(dim, LARGEST_SEQUENCE_LENGTH+1)[None], requires_grad=False)

        self.layers = nn.ModuleList([
            TransformerDecoderBlock(dim, num_heads, dropout)
            for _ in range(num_layers)
        ])
        self.decoder = nn.Linear(dim, NO_CLUSTERS, bias=False)

        self.register_buffer("mask", torch.tril(torch.ones(max_seq_len, max_seq_len)).unsqueeze(0).unsqueeze(0))

    def forward(self, x):
        batch_size, tokens = x.shape
        token_emb = self.token_embedding(x)
        pos_emb = self.pos_embedding[:,:x.size()[1]]
        
        x = token_emb + pos_emb
        
        mask = self.mask[:, :, :tokens, :tokens]
        for layer in self.layers:
            x = layer(x, mask)

        return self.decoder(x)

<h3>Training</h3>
We always train the router ALONGSIDE the full transformer<br/>
Rough idea: Router produces a probability distribution from which we sample;<br/>
We then find the tokens that produces higher-than-average error; Those are the ones our router<br/>
should assign a new probability to. We do so by reducing the probability of those chosen tokens.

In [None]:
CE_LOSS = nn.CrossEntropyLoss(reduction="none")

def test_model(optimiser, model, optimiser_router, router):
    optimiser.eval()
    model.eval()
    optimiser_router.eval()
    router.eval()

    its = 0
    start = time.time()
    last = start
    
    losses_test = []
    for batch in test_dataloader:
        #pre-pad initial empty token
        batch = torch.cat((torch.ones(batch.size()[0], 1).long() * PADDING_TOKEN, batch), dim=1)
        batch = batch.to(DEVICE)
        #for evaluation, we don't need to sample from the router; sampling is only a trick to train the router, so we just pick the most likely one!
        indices = router(batch).argmax(dim=-1)
        logits = model(batch, indices)
        
        target = batch[:,1:] #shifted left by one; we want to predict the next token
        output = logits[:,:-1] #remove the last prediction; we don't want to predict anything at the last token
        
        loss = CE_LOSS(output.reshape(-1, output.size(-1)), target.reshape(-1)).mean()
        losses_test.append(loss.item())
        its += 1
        if time.time() - last > 30: 
            print("\t\tTime left for TEST epoch: ", (time.time()-start)/its*(len(train_dataloader)-its), " seconds.")
            last = time.time()
    return losses_test

def train_router(optimiser_router, routing, indices, loss_terms, losses_router_correctness, losses_router_distribution):
    #train router:
    optimiser_router.zero_grad()

    bsize = routing.size()[0]
    routing = routing.reshape(-1, NO_CLUSTERS)
    indices = indices.reshape(-1)
    probabilities_of_chosen = routing[torch.arange(indices.size()[0]), indices].view(bsize, -1)
    
    #FIRST : identify which losses are performing BELOW average,
    #       i.e. error is HIGHER than average error (loss - avg is GREATER 0)
    #       (ignore all losses that are performing better than average, i.e. are BELOW zero)
    #       To further improve: Find the average loss PER POSITION; some positions are naturally more difficult than others, e.g. first vs last token in a story!
    losses_below_average = (loss_terms.detach() - loss_terms.detach().mean(dim=0)[None]).clamp(min=0.0)
    
    #SECOND: punish those by punishing the probability of choosing this cluster; low probability samples don't matter as much here!
    loss_router_correctness = ((losses_below_average * probabilities_of_chosen)).square().mean() * 10.0
    
    loss_router_distribution = 0.0
    max_indices = routing.argmax(dim=-1)
    for index in range(0, NO_CLUSTERS):
        #get all indices where the model chose this cluster
        indices_for_cluster = (indices == index).nonzero().squeeze()
        
        desired = max_indices.size()[0] / NO_CLUSTERS
        actual = indices_for_cluster.size()[0]

        edge_index = (desired - actual) / (max_indices.size()[0] / NO_CLUSTERS)

        loss_router_distribution = loss_router_distribution + (routing[indices_for_cluster, index] - (routing[indices_for_cluster, index].detach() + edge_index).clamp(0.0, 1.0)).square().sum()
    loss_router_distribution = loss_router_distribution / max_indices.size()[0] * 5000.0
    
    loss_router = loss_router_distribution + loss_router_correctness
    
    losses_router_distribution.append(loss_router_distribution.item())
    losses_router_correctness.append(loss_router_correctness.item())

    loss_router.backward()
    optimiser_router.step()

    return losses_router_correctness, losses_router_distribution

def train_model(optimiser, model, optimiser_router, router, steps):
    its = 0
    start = time.time()
    last = start
    losses_train, losses_router_correctness, losses_router_distribution = [], [], []
    
    optimiser.train()
    model.train()
    optimiser_router.train()
    router.train()
    
    if steps >= STEPS_ROUTER:
        router.eval()
        optimiser_router.eval()
    
    for batch in train_dataloader:
        steps += 1
        #pre-pad initial empty token
        batch = torch.cat((torch.ones(batch.size()[0], 1).long() * PADDING_TOKEN, batch), dim=1)
        batch = batch.to(DEVICE)

        routing = F.softmax(router(batch), dim=-1) #get routing probabilities
        routing_probabilities = routing + 0.2 / NO_CLUSTERS #add some epsilon noise to prevent 0 probabilities
        routing_probabilities = routing_probabilities / routing_probabilities.sum(dim=-1, keepdim=True) #normalise routing probabilities

        if steps < STEPS_ROUTER:
            #during training, do SOFT PROBABILISTIC routing
            #sample from routing probabilities
            indices = torch.multinomial(routing_probabilities.reshape(routing_probabilities.size()[0] * routing_probabilities.size()[1], -1), 1).view(routing_probabilities.size()[0], routing_probabilities.size()[1])
            max_indices = routing.argmax(dim=-1) #just to sanity check how we WOULD distribute
        else:
            #after router initialisation (first few thousand steps or so), do HARD routing: pick the best, we don't train it anymore
            indices = routing.argmax(dim=-1)
            max_indices = indices
        
        #remove last token from routing probabilities and routing --> we don't need to make a prediction after obtaining the last token
        routing_probabilities = routing_probabilities[:,:-1,:]
        routing = routing[:,:-1,:]

        optimiser.zero_grad()

        logits = model(batch, expert_labels=indices)
        
        indices = indices[:,:-1]
        target = batch[:,1:] #shifted left by one; we want to predict the next token
        output = logits[:,:-1] #remove the last prediction; we don't want to predict anything at the last token
        
        loss = CE_LOSS(output.reshape(-1, output.size(-1)), target.reshape(-1))
        loss_terms = loss.clone().detach().view(output.size()[0], -1) #test if we need the clone even; we do need the detach
        loss = loss.mean()
        loss.backward()
        
        optimiser.step()
        losses_train.append(loss.item())

        #For training the router (only first few steps), we use two objectives:
        #       a) distribute the samples evenly among the clusters
        #       b) punish the router for choosing clusters that perform worse than average, i.e. get "smart" in choosing the right cluster
        if steps < STEPS_ROUTER:
            losses_router_correctness, losses_router_distribution = train_router(optimiser_router, routing, indices, loss_terms, losses_router_correctness, losses_router_distribution)
        
        its += 1
        if time.time() - last > 30: 
            #print moving average over last 20 items:
            print("Losses Tokens: ", sum(losses_train[-20:])/len(losses_train[-20:]))
            if steps < STEPS_ROUTER:
                #output loss terms for router to track it's progress
                print("Router  Correctness: ", sum(losses_router_correctness[-20:])/len(losses_router_correctness[-20:]))
                print("Router Distribution: ", sum(losses_router_distribution[-20:])/len(losses_router_distribution[-20:]))
                print("Router Distribution LOCAL: ", losses_router_distribution[-1])
            #output indices - how many of which ones? helps us to see if the router is doing his job
            print("Index distribution: ", torch.unique(indices, return_counts=True)[1])
            print("Index distribution MAX: ", torch.unique(max_indices, return_counts=True)[1])
            print("Index distribution MAX LOSS: ", (torch.unique(max_indices, return_counts=True)[1]-2000.0).square().mean())
            print("\n")
            print("\t\tTime left for TRAIN epoch: ", (time.time()-start)/its*(len(train_dataloader)-its)/60, " minutes; Currently, ",steps," steps in.")
            last = time.time()
    return losses_train, steps

def plot_store_losses(total_losses_train, total_losses_test, identifier, epoch):
    #save losses:
    torch.save(total_losses_train, "stored/"+identifier+"_total_losses_train_"+str(epoch)+".pt")
    torch.save(total_losses_test, "stored/"+identifier+"_total_losses_test_"+str(epoch)+".pt")
    
    plt.plot(total_losses_train, label="train")
    plt.plot(total_losses_test, label="test")
    plt.title("Losses")
    plt.legend()
    #save plot & show:
    plt.savefig("stored/"+identifier+"_losses_"+str(epoch)+".png")
    plt.show()

def train():
    identifier = "sMoE"

    model = AutoregressiveDecoderTransformer(PADDING_TOKEN + 1, LARGEST_SEQUENCE_LENGTH+1, dim=512, num_layers=8, num_heads=8, dropout=DROPOUT_RATE).to(DEVICE)
    print("Model has ", sum(p.numel() for p in model.parameters()), " parameters.")
    #router model is about a quarter of the parameters, should suffice
    router = AutoregressiveRouterTransformer(PADDING_TOKEN + 1, LARGEST_SEQUENCE_LENGTH + 1, dim=256, num_layers=8, num_heads=8, dropout=DROPOUT_RATE).to(DEVICE)
    print("Router has ", sum(p.numel() for p in router.parameters()), " parameters.")

    LR = 0.001 #works best for this model & dataset
    SAMPLES_TO_GENERATE = 4

    optimiser = schedulefree.AdamWScheduleFree(model.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=0.01, warmup_steps=STEPS_WARMUP) #also check 0.1; before: 0.01
    optimiser_router = schedulefree.AdamWScheduleFree(router.parameters(), lr=LR, betas=(0.9, 0.999), weight_decay=0.01, warmup_steps=STEPS_WARMUP_ROUTER) #also check 0.1; before: 0.01

    steps = 0 #how many steps we've done so far; used for switching off router training at some point

    total_losses_train = []
    total_losses_test  = []

    for epoch in range(0, 100):
        #1. Train:
        losses_train, steps = train_model(optimiser, model, optimiser_router, router, steps)
        #2. Evaluate:
        with torch.no_grad():
            losses_test = test_model(optimiser, model, optimiser_router, router)
            
            print("*** DONE WITH EPOCH ", epoch, " - TRAIN LOSS: ", sum(losses_train)/len(losses_train), " - TEST LOSS: ", sum(losses_test)/len(losses_test)," ***")
            total_losses_train.append(sum(losses_train)/len(losses_train))
            total_losses_test.append(sum(losses_test)/len(losses_test))

            plot_store_losses(total_losses_train, total_losses_test, identifier, epoch)

        #3. Inference / Generate:
        with torch.no_grad():
            #remember to always skip the first character befpre outputting, that's just the empty token (serving as the start of sequence token):
            try:
                sampled_max = model.generate_max((torch.ones(1, 1).long() * PADDING_TOKEN).to(DEVICE), 251, router)
                sampled_mul = model.generate_mul((torch.ones(4, 1).long() * PADDING_TOKEN).to(DEVICE), 251, router)
                sampled_nuc = model.generate_nuc((torch.ones(4, 1).long() * PADDING_TOKEN).to(DEVICE), 251, router)
                print("\tGENERATED SENTENCE  MAX: ")
                print("\t\tStory 0: ",decode_BPE(sampled_max[0, 1:]))
                print("\tGENERATED SENTENCE  MUL: ")
                for i in range(0, SAMPLES_TO_GENERATE):
                    print("\t\tStory "+str(i)+": ",decode_BPE(sampled_mul[i, 1:]))
                print("\tGENERATED SENTENCE  NUC: ")
                for i in range(0, SAMPLES_TO_GENERATE):
                    print("\t\tStory "+str(i)+": ",decode_BPE(sampled_nuc[i, 1:]))
                
            except:
                print("FAILED TO GENERATE")
                continue

        #4. Store model & optimiser:
        if True:
            #(important for SF AdamW: only store stuff when in eval mode!)
            #save model:
            torch.save(model.state_dict(), "stored/"+identifier+"_model_"+str(epoch)+".pt")
            #save optimiser:
            torch.save(optimiser.state_dict(), "stored/"+identifier+"_optimiser_"+str(epoch)+".pt")

train()

Model has  27634176  parameters.
Router has  7003648  parameters.
Losses Tokens:  5.009021067619324
Router  Correctness:  0.3907765462994576
Router Distribution:  23.067965030670166
Router Distribution LOCAL:  33.40509033203125
Index distribution:  tensor([2005, 2016, 1763, 1704, 1842, 2372, 2024, 2045, 2078, 2151],
       device='cuda:0')
Index distribution MAX:  tensor([1739, 1548,  652,  405, 1128, 4824, 2448, 1853, 2815, 2668],
       device='cuda:0')
Index distribution MAX LOSS:  tensor(1470167.6250, device='cuda:0')


		Time left for TRAIN epoch:  8.732993670359049  minutes; Currently,  73  steps in.
Losses Tokens:  4.7302128076553345
Router  Correctness:  0.30810931921005247
Router Distribution:  15.431269145011902
Router Distribution LOCAL:  10.69532585144043
Index distribution:  tensor([1927, 1968, 1991, 1910, 2160, 2019, 2026, 2116, 2052, 1831],
       device='cuda:0')
Index distribution MAX:  tensor([1475, 1463, 1906, 1383, 3553, 1466, 1406, 2916, 3490, 1022],
       device=