# ⌚ Let's be quick.

1. Define the functions
2. Roll in!

> Why?
>
> We need to re-define the classes because these are pickle files, and unpickling has to look up the class definitions by name. For mature saving, we would need to do some *official* stuff (for example, saving only the model's `state_dict`), but for us, for now, this will work just fine.
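
To see why, here is a minimal sketch using only the standard library. Pickle stores a *reference* to an object's class (module plus name) together with the instance's attributes, never the class code itself. The class `Tiny` and the module `demo_defs` are hypothetical stand-ins for the real model and tokenizer classes.

```python
import pickle
import sys
import types

# A throwaway module stands in for "the notebook where the class is defined"
defs = types.ModuleType("demo_defs")
sys.modules["demo_defs"] = defs

class Tiny:  # hypothetical stand-in for BigramLM / RegexTokenizer
    def __init__(self, w):
        self.w = w

    def forward(self, x):
        return self.w * x

# make pickle record the class as demo_defs.Tiny
Tiny.__module__ = "demo_defs"
Tiny.__qualname__ = "Tiny"
defs.Tiny = Tiny

blob = pickle.dumps(Tiny(3))  # stores the reference "demo_defs.Tiny" plus {'w': 3}, not the code

del defs.Tiny  # simulate a fresh session where the class was never defined
try:
    pickle.loads(blob)
    loaded_without_class = True
except AttributeError:  # pickle cannot find demo_defs.Tiny
    loaded_without_class = False

defs.Tiny = Tiny  # "re-define" the class, as the cells below do
restored = pickle.loads(blob)
print(loaded_without_class, restored.forward(2))  # False 6
```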

So below, I will be re-defining the classes. The model logic is unchanged, so the saved weights behave exactly as they did during training.

# `0.` Imports

```python
import pickle
import torch
import torch.nn as nn # for layers and stuff
from torch.nn import functional as F # for the loss function and softmax
torch.manual_seed(1337) # same as in the lecture
```


# `1.` Re-defining everything 

```python
class BigramLM(nn.Module):
    """
    The final BigramLM model does the following:
    
    ## Has:
    1. Token embedding layer
    2. Position embedding layer
    3. Nx Blocks, each with multi-head attention and a feed-forward layer
    4. Finally, the LM head
    5. The shapes written in comments
    
    ## Does:
    1. Takes the input, which will be in the (B, T) format
    2. Converts it into (B, T, C), starting with the token embedding layer
    3. The rest is history... do you really want me to say more!?
    """
    
    def __init__(self):
        super().__init__()
        # relies on the globals vocab_size, n_embd, block_size, n_head, n_layers
        self.embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positions_embeddings = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape                                   # (B, T) token ids
        tok_emb = self.embedding_table(idx)                # (B, T, n_embd)
        positions_emb = self.positions_embeddings(torch.arange(T, device=device))  # (T, n_embd)
        x = tok_emb + positions_emb                        # (B, T, n_embd), broadcast over B
        x = self.blocks(x)                                 # (B, T, n_embd)
        x = self.ln_f(x)                                   # (B, T, n_embd)
        logits = self.lm_head(x)                           # (B, T, vocab_size)
    
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
        
    def generate(self, idx, max_new_tokens):
        # idx is (B, T); append max_new_tokens sampled tokens to it
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]        # crop to the last block_size tokens
            logits, loss = self(idx_cond)          # (B, T, vocab_size)
            logits = logits[:, -1, :]              # last time step only: (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, next_idx), dim=1)             # (B, T + 1)

        return idx
```
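
`generate` crops the context with `idx[:, -block_size:]` because the position-embedding table has exactly `block_size` rows, so asking for position `block_size` would be out of range. A minimal sketch of the same sampling loop, assuming hypothetical toy sizes and a stand-in `toy_logits` function in place of the real model:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
vocab, ctx = 10, 4  # hypothetical toy vocab size and block_size

pos_table = torch.nn.Embedding(ctx, 8)  # only `ctx` positions exist
try:
    pos_table(torch.arange(ctx + 1))  # one position too many
    overflow_ok = True
except IndexError:
    overflow_ok = False

def toy_logits(idx):
    # stand-in for model(idx): uniform scores with shape (B, T, vocab)
    B, T = idx.shape
    assert T <= ctx, "context longer than the position table"
    return torch.zeros(B, T, vocab)

idx = torch.zeros((1, 1), dtype=torch.long)
for _ in range(10):
    idx_cond = idx[:, -ctx:]                      # crop, as in BigramLM.generate
    logits = toy_logits(idx_cond)[:, -1, :]       # last time step only: (B, vocab)
    probs = F.softmax(logits, dim=-1)
    next_idx = torch.multinomial(probs, num_samples=1)  # (B, 1)
    idx = torch.cat((idx, next_idx), dim=1)

print(overflow_ok, idx.shape)  # False torch.Size([1, 11])
```

Without the crop, the eleventh step would hand 11 positions to a 4-row position table.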

```python
class Block(nn.Module):
    """
    The block is basically a collection of self-attention layers (multi-head) and
    feed-forward layers, with residual connections and layer norms.
    
    All we want to do is isolate them so that we can stack as many as we want
    and get better results!
    """
    
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa_heads = MultiHeadAttention(n_head, head_size)
        self.add_norm_1 = nn.LayerNorm(n_embd)
        self.ffwd = FeedForward(n_embd)
        self.add_norm_2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa_heads(self.add_norm_1(x))  # (B, T, n_embd): heads are concatenated and projected back
        x = x + self.ffwd(self.add_norm_2(x))      # (B, T, n_embd)
        return x
```
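
Written out, `Block.forward` is the pre-norm residual form: each LayerNorm (`add_norm_1`, `add_norm_2`) is applied *before* its sub-layer, and the sub-layer output is added back to the un-normalized input. (The original Transformer paper placed the LayerNorm after the residual addition instead.) Writing $\mathrm{LN}_1$ and $\mathrm{LN}_2$ for the two norms:

$$
\begin{aligned}
x' &= x + \mathrm{MultiHead}\big(\mathrm{LN}_1(x)\big) \\
y &= x' + \mathrm{FFN}\big(\mathrm{LN}_2(x')\big)
\end{aligned}
$$

Both sub-layers map $(B, T, n_{\text{embd}})$ to the same shape, which is what makes the residual additions possible.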

```python
class MultiHeadAttention(nn.Module):
    """Runs `num_heads` heads in parallel and projects their concatenated output back to n_embd."""
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout) ###

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size) == (B, T, n_embd)
        out = self.dropout(self.proj(out)) ###
        return out
```
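
The concatenation only lines up because `n_embd` is divisible by `n_head`: 6 heads of size 64 give back 384 channels, which `proj` can then mix. A shape-only sketch, using random tensors as stand-ins for the `Head` outputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, n_embd, n_head = 2, 5, 384, 6
head_size = n_embd // n_head  # 64, as computed in Block.__init__

# stand-ins for Head outputs: each head returns (B, T, head_size)
head_outputs = [torch.randn(B, T, head_size) for _ in range(n_head)]
out = torch.cat(head_outputs, dim=-1)  # concatenate along the channel axis
proj = nn.Linear(n_embd, n_embd)       # mixes information across heads
y = proj(out)

print(head_size, out.shape, y.shape)  # 64 torch.Size([2, 5, 384]) torch.Size([2, 5, 384])
```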

```python
class Head(nn.Module):
    """
    This class will simply create the Q, K, V projections
    and also use `register_buffer` to create the causal mask.
    
    Then, in `forward`, it will pass the input through
    Q, K, V and return `out`.
    """
    
    def __init__(self, head_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size, device=device)))
        
        self.dropout = nn.Dropout(dropout) ###

    def forward(self, x):
        '''
        Take the input `x`: the token embeddings plus position embeddings
        (or the output of the previous block). The shape will be (B, T, C), meaning:
        "For each batch, there will be T tokens, each embedded in a C-dimensional space."
        
        We will use that and work ourselves forward.
        '''
        B, T, C = x.shape
        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)    # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # C**-0.5 keeps the variance of the scores in check. Standard scaled dot-product
        # attention scales by head_size**-0.5; the saved weights were trained with
        # C (= n_embd), so it is kept here to match them.
        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # the causal mask
        wei = F.softmax(wei, dim=-1)  # the final wei
        
        wei = self.dropout(wei) ###
        out = wei @ v  # (B, T, head_size): this is what we will use further
        return out
```
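
The mask works by setting every score above the diagonal to `-inf` before the softmax, so those positions get a weight of exactly `exp(-inf) = 0`, and each row still sums to 1 over the positions it may see. A small sketch with random stand-in scores:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T = 4
scores = torch.randn(1, T, T)  # stand-in for q @ k^T * scale, shape (B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = scores.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)

# row t attends only to positions 0..t; future positions get exactly zero weight
print(wei[0])
print(torch.triu(wei[0], diagonal=1).abs().sum())  # tensor(0.)
print(wei.sum(dim=-1))  # every row sums to 1
```

The first token can only attend to itself, so `wei[0, 0, 0]` is exactly 1.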

```python
class FeedForward(nn.Module):
    """A position-wise MLP: expand to 4 * n_embd, apply ReLU, project back to n_embd."""
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd),
            nn.Dropout(dropout) ###
        )
        
    def forward(self, x):
        return self.net(x)
```

```python
import itertools
import regex as re
from typing import List, Tuple, Dict
from collections import defaultdict
from tqdm import tqdm

class RegexTokenizer:
    '''
    This is supposed to get a little crazy.
    
    Step 1: Split the text based on the regex pattern.
    Step 2: Now, we have the cleaned words.
    Step 3: Get their raw tokens individually.
    Step 4: Don't merge them yet, because that would nullify steps 1-3.
    Step 5: Find pairs (stats) for each of the words, while keeping "common" stats across all of them.
    Step 6: Find the most repeated pair.
    Step 7: Replace that pair in each token group.
    '''
    def __init__(self):
        # initialize the default vocab: one token per byte value
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        self.trained = False
        self.GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        self.GPT4_PATTERN_COMPILED = re.compile(self.GPT4_SPLIT_PATTERN)
    
    def find_most_repeated_pair(self, tokens, counter=None) -> Dict[Tuple[int, int], int]:
        '''
        Despite the name, this function only counts the pairs; the max is
        calculated by the caller after this function returns.
        
        The `counter` can be passed in, updated in place, and returned.
        This lets us keep one global counter across all word chunks.
        '''
        counter = counter if counter is not None else defaultdict(int)
        for pair in zip(tokens, tokens[1:]):
            counter[pair] += 1
        return counter # useful when counter=None is passed

    def replace_pair_with_new_token(self, tokens, pair, new_idx) -> List:
        new_tokens = [] # this will hold the copy for the new tokens
        idx = 0
        while idx < len(tokens):
            if idx < len(tokens) - 1 and (tokens[idx] == pair[0]) and (tokens[idx + 1] == pair[1]): # this is a match!
                new_tokens.append(new_idx)
                idx += 2
            else: # this is not a match
                new_tokens.append(tokens[idx])
                idx += 1
        return new_tokens
        
    def train(self, blob, vocab_size=None) -> None:
        '''
        This function will train the tokenizer on the
        training data given as text.
        
        1. blob: The data in text format that will be used to train
            the tokenizer.
        
        2. vocab_size: This is "how many new tokens you want to generate"
            - `None` means indefinite: keep merging while some pair repeats.
            - `int` means the maximum number of merges.
        '''
        self.vocab_size = vocab_size
        
        # First split
        cleaned_text = self.GPT4_PATTERN_COMPILED.findall(blob)
        # Then create the tokens
        self.tokens = [list(map(int, word.encode("utf-8"))) for word in cleaned_text]
        
        new_idx = 255
        merges = {}
        # range(None) would fail, so count forever when vocab_size is None
        steps = range(vocab_size) if vocab_size is not None else itertools.count()
        for _ in tqdm(steps):
            stats = defaultdict(int)
            for token_group in self.tokens:
                # pass the stats, which will be updated in place
                self.find_most_repeated_pair(token_group, stats)
            if not stats: # every chunk is a single token; nothing left to merge
                break
            
            max_pair = max(stats, key=stats.get)
            max_count = stats[max_pair]
            
            if max_count > 1:
                new_idx += 1
                self.tokens = [self.replace_pair_with_new_token(token_group, max_pair, new_idx) for token_group in self.tokens]
                merges[max_pair] = new_idx
            else: # every pair occurs only once
                break
        self.total_merges = len(merges) # number of merges actually performed
        
        ## The training is done; now build the vocab from the merges
        for pair, idx in merges.items():
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
        self.merges = merges
        self.trained = True
        
    def encode(self, text):
        '''
        The goal of this function is to encode the given text into
        tokens that are acceptable to our `vocab`.
        
        So, we will need to keep applying merges from the start (top)
        to the bottom, i.e. in the order they were learned.
        
        The order of dict keys **is not guaranteed** in older versions
        of Python (before 3.7), so we rely on the `idx` instead. The lower
        the idx is, the older that token is!
        '''
        
        if not self.trained:
            raise NotImplementedError("Please first train the tokenizer!")
        
        split_words = self.GPT4_PATTERN_COMPILED.findall(text)
        split_tokens = [list(word.encode("utf-8")) for word in split_words]
        
        final_tokens = []
        for chunk in split_tokens:
            while len(chunk) >= 2:
                stats = self.find_most_repeated_pair(chunk)
                # we are not interested in the counts, just the pairs:
                # pick the pair in this chunk that was merged earliest (lowest idx)
                pair_replace = min(stats, key=lambda p: self.merges.get(p, float("inf")))
                if pair_replace in self.merges:
                    chunk = self.replace_pair_with_new_token(chunk,
                                                             pair_replace,
                                                             self.merges[pair_replace])
                else: # no pair in this chunk was ever merged
                    break
            final_tokens.extend(chunk)
        return final_tokens
    
    def decode(self, tokens):
        decoded_stream = [self.vocab[idx] for idx in tokens]
        text = b"".join(decoded_stream)
        # sampled tokens may not form valid UTF-8; replace bad bytes instead of raising
        return text.decode("utf-8", errors="replace")
```
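
The rule `encode` relies on (apply merges in the order they were learned, i.e. lowest new token id first) can be tried out on a toy merge table. The helpers `count_pairs`, `merge`, and `encode_chunk` are hypothetical and mirror `find_most_repeated_pair`, `replace_pair_with_new_token`, and the inner loop of `encode`; the regex pre-splitting is skipped.

```python
from collections import defaultdict

def count_pairs(ids):
    counts = defaultdict(int)
    for pair in zip(ids, ids[1:]):
        counts[pair] += 1
    return counts

def merge(ids, pair, new_id):
    # replace every non-overlapping occurrence of `pair`, left to right
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def encode_chunk(ids, merges):
    # keep applying the earliest-learned merge (lowest new id) present in the chunk
    while len(ids) >= 2:
        pair = min(count_pairs(ids), key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # no pair in this chunk was ever merged
        ids = merge(ids, pair, merges[pair])
    return ids

# hypothetical merge table: b"ab" -> 256 learned first, then (256, b"c") -> 257
merges = {(97, 98): 256, (256, 99): 257}

vocab = {i: bytes([i]) for i in range(256)}
for (a, b), new_id in sorted(merges.items(), key=lambda kv: kv[1]):
    vocab[new_id] = vocab[a] + vocab[b]

encoded = encode_chunk(list("abcab".encode("utf-8")), merges)
decoded = b"".join(vocab[i] for i in encoded).decode("utf-8")
print(encoded, decoded)  # [257, 256] abcab
```

Merge 257 can only fire after merge 256 has produced the token it depends on, which is why the lowest id, not the pair count inside the chunk, decides what to merge next.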

# `2.` Also the training parameters

```python
batch_size = 64      # sequences processed in a single forward pass
block_size = 256     # the context window (significantly bigger than in our toy examples)
max_iters = 5000     # total forward-backward passes

eval_interval = 500  # print the loss every this many steps
learning_rate = 3e-4 # learning rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'

eval_iters = 200    # when printing the loss, how many batches to evaluate for the estimate
n_embd = 384        # embedding size of each token
n_head = 6          # `n` heads for multi-head self-attention
n_layers = 6        # `n` for `Nx`, i.e. how many blocks to stack
dropout = 0.2       # fraction of activations randomly zeroed during training (off in eval mode)
```
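
Of these, only `device` (read in `BigramLM.forward`) and `block_size` (read in `BigramLM.generate`) are used after the model is unpickled; the rest only matter when building a fresh model. As a sanity check, the size of one `Block` follows from `n_embd`, `n_head`, and the layer shapes above. A sketch in plain Python; the embeddings, `ln_f`, and `lm_head` are left out because they depend on `vocab_size`, which comes from the tokenizer and is not listed here.

```python
n_embd, n_head, n_layers = 384, 6, 6
head_size = n_embd // n_head  # 64

def linear_params(fan_in, fan_out, bias=True):
    # weight matrix plus optional bias vector
    return fan_in * fan_out + (fan_out if bias else 0)

heads = n_head * 3 * linear_params(n_embd, head_size, bias=False)  # q, k, v per head
proj = linear_params(n_embd, n_embd)
ffwd = linear_params(n_embd, 4 * n_embd) + linear_params(4 * n_embd, n_embd)
norms = 2 * (2 * n_embd)  # two LayerNorms, each with weight + bias

block = heads + proj + ffwd + norms
print(head_size, block, n_layers * block)  # 64 1773312 10639872
```

Roughly two thirds of each block sits in the feed-forward layer, because of its 4x expansion.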

# `3.` Finally, loading the model 🎉  

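```python
# map_location lets a GPU-saved model load on a CPU-only machine;
# weights_only=False is needed on newer PyTorch (2.6+ defaults to True) to unpickle a whole model
model = torch.load("./model/ShakeGPT_BPE.pt", map_location=device, weights_only=False)
model = model.to(device)
model.eval()  # disable dropout for generation
```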
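
The *official* route mentioned at the top is to save only the `state_dict` (a dictionary of tensors) and rebuild the model from its class. Below is a minimal sketch with a hypothetical `TinyNet` standing in for `BigramLM`; for the real model you would rebuild with `BigramLM()`, which additionally needs the global `vocab_size` defined, since `__init__` reads it.

```python
import os
import tempfile
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # hypothetical stand-in for BigramLM
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 2)

    def forward(self, x):
        return self.lin(x)

torch.manual_seed(0)
src = TinyNet()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_state.pt")
    torch.save(src.state_dict(), path)  # only tensors, no class code

    dst = TinyNet()  # the class must still exist to build the skeleton
    state = torch.load(path, map_location="cpu", weights_only=True)  # tensors only
    dst.load_state_dict(state)
    dst.eval()

x = torch.randn(3, 4)
same = torch.equal(src(x), dst(x))
print(same)  # True
```

The file no longer depends on how the class was pickled, and `weights_only=True` refuses to unpickle arbitrary objects.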

# `4.` Loading the tokenizer 🎉🎉  

```python
with open("./model/tokenizer.pkl", "rb") as file:
    tokenizer = pickle.load(file)
```
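
The tokenizer pickle has the same dependency on the `RegexTokenizer` class. Everything a trained instance needs for `encode`/`decode` can be rebuilt from `merges` (the `vocab` is the 256 base bytes plus one entry per merge, and the split pattern is a constant), so a class-independent alternative is to store just the merges, for example as JSON. A sketch with a hypothetical two-entry merge table:

```python
import json

# hypothetical trained merges: (left_id, right_id) -> new_id
merges = {(104, 105): 256, (256, 33): 257}

# JSON keys must be strings, so store the merges as an ordered list of triples
payload = json.dumps([[a, b, idx] for (a, b), idx in merges.items()])

restored = {(a, b): idx for a, b, idx in json.loads(payload)}
vocab = {i: bytes([i]) for i in range(256)}
for (a, b), idx in sorted(restored.items(), key=lambda kv: kv[1]):  # rebuild in merge order
    vocab[idx] = vocab[a] + vocab[b]

print(restored == merges, vocab[257])  # True b'hi!'
```

Loading would then mean creating a fresh `RegexTokenizer()` and setting `merges`, `vocab`, and `trained = True` on it; that wiring is an assumption, not something this notebook does.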

# `5.` Generating stuff!!! 🎉🎉🎉 

```python
with torch.no_grad():  # no gradients needed for sampling
    start = torch.zeros((1, 1), dtype=torch.long, device=device)  ### 🗽 Transfer to device 🗽 ###
    out = model.generate(idx=start, max_new_tokens=2048)
print(tokenizer.decode(out[0].tolist()))
```

```text
 ed friend
Till health for her by his brother:, who, forgive her
Shall rebellion too.

GLOUCESTER:
There straight have made us, you'll show me wail the city.

GLOUCESTER:
But I spit, my brother is careless to you,
If he stand for being all, that he is no less
Abid our kindred title at Bontona.

HAMBENBROKE:
O, that I uncle Your brother, your humour be spirit of sorrow;
Who stands but in the mouth as free,
It were well deserved with Mowbray's forefellow.

QUEEN MARGARET:
Pithee of him, Lord Hastings, and you come.
What, will you beg of this, that ill, or you are of gone?

Boy:
Ay, but not up for yourself; for, my good lord.

GLOUCESTER:
And we will not hold your subject.

PRINCE EDWARD:
Thanks, farewell: O, what news!
See that will not despair of that?

GLOUCESTER:
Now, lords, Edward, after man,
Grant us ye not our friends; Sir Wednesee
King Edward's love, and not as of pause
As if that we should second yield them both;
Away to rcourage it rather
Than those targed as I have my CARGARET:
```

## Here's the output image, in case GitHub doesn't render the cell above 👆🏻
<img src="./images/bpe-output.png">
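
`tokenizer.decode` joins the byte strings of the sampled tokens and decodes the result as UTF-8. With 2048 freely sampled byte-level tokens, nothing guarantees that a multi-byte character arrives complete, and a strict decode then raises `UnicodeDecodeError`. A small sketch of the difference, using a hand-made byte string rather than real model output:

```python
# byte-level tokens: the model may emit only the first byte of a two-byte character
complete = b"caf\xc3\xa9"   # "café" in UTF-8
truncated = b"caf\xc3"      # missing the continuation byte

try:
    truncated.decode("utf-8")
    strict_ok = True
except UnicodeDecodeError:
    strict_ok = False

print(complete.decode("utf-8"))                                # café
print(strict_ok, truncated.decode("utf-8", errors="replace"))  # False caf\ufffd (shown as the replacement character)
```

`errors="replace"` swaps the broken bytes for U+FFFD, so a long sample still prints even if one character is cut in half.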

# Looks great, doesn't it?
Let's take a pause, and meet in the next adventure! 🥦