In [None]:
import os
import sys
import time
import math
import argparse
from dataclasses import dataclass
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
@dataclass
class ModelConfig:
    """Hyper-parameters shared by the language models in this notebook.

    `block_size` and `vocab_size` default to None; the models use them to size
    their embedding tables and masks, so set them before constructing a model.
    The remaining fields control model size, and each model reads them
    slightly differently.
    """

    block_size: int = None  # length of the input sequences of integers
    vocab_size: int = None  # the input integers are in range [0 .. vocab_size - 1]

    # parameters below control the size of each model slightly differently
    n_layers: int = 4
    n_embd: int = 64
    n_embd2: int = 64
    n_head: int = 4

    @property
    def n_layer(self) -> int:
        """Alias for `n_layers`, for code that expects the name `n_layer`."""
        return self.n_layers

#### Transformer Language Model as used in GPT-2

In [None]:
class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Uses the tanh approximation

        GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))

    Gaussian Error Linear Units (GELU): https://arxiv.org/abs/1606.08415
    """

    def forward(self, x):
        # sqrt(2/pi) *scales* the cubic-corrected input; it is not added to it
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

In [None]:
class CausalSelfAttention(nn.Module):
    """
    A simple multi-head masked self-attention layer with a projection at the end.

    Similar to torch.nn.MultiheadAttention, but with a fixed lower-triangular mask so
    position t can only attend to positions <= t.

    Reads from `config`: n_embd (must be divisible by n_head), n_head, and block_size
    (the longest sequence the causal mask covers).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        hs = C // self.n_head  # per-head size

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)
        return y

In [None]:
class Block(nn.Module):
    """
    Unassuming Transformer Block.

    Pre-LayerNorm layout: x + attn(ln_1(x)), then + mlp(ln_2(.)); the MLP widens
    n_embd to 4 * n_embd, applies NewGELU, and projects back down.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        # submodules are created in a fixed order so parameter initialisation is reproducible
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = nn.ModuleDict({
            'c_fc': nn.Linear(width, hidden),
            'c_proj': nn.Linear(hidden, width),
            'act': NewGELU(),
        })
        layers = self.mlp
        self.mlpf = lambda h: layers.c_proj(layers.act(layers.c_fc(h)))

    def forward(self, x):
        attended = self.attn(self.ln_1(x))
        x = x + attended
        transformed = self.mlpf(self.ln_2(x))
        return x + transformed
    

In [None]:
class Transformer(nn.Module):
    """
    Transformer Language Model, similar to GPT-2.

    Token + learned position embeddings, a stack of `Block`s, a final LayerNorm and a
    bias-free linear head producing next-token logits of shape (b, t, vocab_size).
    """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        # ModelConfig names the depth `n_layers`; also accept configs that spell it `n_layer`
        n_layer = getattr(config, 'n_layers', None)
        if n_layer is None:
            n_layer = config.n_layer

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print(f"number of parameters: {n_params / 1e6:.2f}M")

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        """
        idx: LongTensor (b, t) of token ids with t <= block_size.
        targets: optional LongTensor (b, t); entries equal to -1 are ignored by the loss.
        Returns (logits, loss) where loss is None when no targets are given.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        # If we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

#### Bag of Words (BoW) Language Model

In [None]:
class CausalBoW(nn.Module):
    """
    Causal bag of words. Averages the preceding elements and looks suspiciously like a CausalAttention module found in a transformer.

    Output position t is the uniform average of the input vectors at positions 0..t.
    """
    def __init__(self, config):
        super().__init__()

        # used to mask out vectors and preserve autoregressive property
        self.block_size = config.block_size
        # lower-triangular mask of shape (1, block_size, block_size); the leading 1 broadcasts over the batch
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, n_embd

        # do the weighted average of all preceding token features:
        # equal (zero) scores, -inf above the diagonal, so the softmax is uniform over positions 0..t
        att = torch.zeros((B, T, T), device=x.device, dtype=x.dtype)
        att = att.masked_fill(self.bias[:, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ x  # (B, T, T) x (B, T, C) -> (B, T, C)

        return y

In [None]:
class BowBlock(nn.Module):
    """
    Collects BoW features and adds an MLP.

    forward(x): x' = x + CausalBoW(x), then returns x' + mlp(x') (both residual).
    The MLP maps n_embd -> n_embd2 -> n_embd with a tanh non-linearity.
    """

    def __init__(self, config):
        super().__init__()

        # Causal BoW module
        self.cbow = CausalBoW(config)
        # MLP assembler
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(config.n_embd, config.n_embd2),
            c_proj=nn.Linear(config.n_embd2, config.n_embd),
        ))

        m = self.mlp
        self.mlpf = lambda x: m.c_proj(torch.tanh(m.c_fc(x)))  # MLP forward

    def forward(self, x):
        x = x + self.cbow(x)
        x = x + self.mlpf(x)
        return x


# Alternate spelling with the `BoW` capitalisation, so `BoWBlock` references resolve too.
BoWBlock = BowBlock
    
    

In [None]:
class BoW(nn.Module):
    """
    Bag-of-Words language model: token + position embeddings, one BowBlock that mixes in
    the running average of all preceding positions, then a linear decoder to next-token logits.
    """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        # token embedding
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        # position embedding
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        # context block
        self.context_block = BowBlock(config)
        # language model head decoder layer
        self.lm_head = nn.Linear(config.n_embd, self.vocab_size)

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        """
        idx: LongTensor (b, t) of token ids with t <= block_size.
        targets: optional LongTensor (b, t); entries equal to -1 are ignored by the loss.
        Returns (logits, loss) where loss is None when no targets are given.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the token and position embedding layers
        tok_emb = self.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.wpe(pos)  # position embeddings of shape (1, t, n_embd); indexed by position, not token id
        # add and run through the decoder MLP
        x = tok_emb + pos_emb
        # run the bag of words context module
        x = self.context_block(x)
        # decode to next token_probability
        logits = self.lm_head(x)

        # if target is given
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

### Recurrent Neural Net Language Model
Both a vanilla RNN cell and a GRU cell are implemented. A GRU is used rather than an LSTM because it is similar, easier to implement, and works just as well.

In [None]:
class RNNCell(nn.Module):
    """
    Plain RNN cell: h_t = tanh(Linear([x_t ; h_{t-1}])).

    Takes the input at the current time step x_t and the hidden state from the previous
    time step h_{t-1}, and returns the resulting hidden state h_t at the current step.
    """
    def __init__(self, config):
        super().__init__()
        in_features = config.n_embd + config.n_embd2
        self.xh_to_h = nn.Linear(in_features, config.n_embd2)

    def forward(self, xt, hprev):
        joint = torch.cat((xt, hprev), dim=1)  # (b, n_embd + n_embd2)
        pre_activation = self.xh_to_h(joint)
        return torch.tanh(pre_activation)

In [None]:
class GRUcell(nn.Module):
    """
    Similar to RNNCell but with a recurrence formula that makes the GRU more expressive and easier to optimize.

        r_t    = sigmoid(Linear_r([x_t ; h_{t-1}]))          reset gate
        hbar_t = tanh(Linear_h([x_t ; r_t * h_{t-1}]))       candidate hidden state
        z_t    = sigmoid(Linear_z([x_t ; h_{t-1}]))          update ("switch") gate
        h_t    = (1 - z_t) * h_{t-1} + z_t * hbar_t
    """
    def __init__(self, config):
        super().__init__()
        # update gate, reset gate and candidate-state projections, all acting on [x_t ; h]
        self.xh_to_z = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
        self.xh_to_r = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
        self.xh_to_hbar = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)

    def forward(self, xt, hprev):
        # first use the reset gate to wipe some channels of the hidden state to zero
        xh = torch.cat([xt, hprev], dim=1)
        r = torch.sigmoid(self.xh_to_r(xh))
        hprev_reset = r * hprev
        # calculate the candidate new hidden state hbar
        xhr = torch.cat([xt, hprev_reset], dim=1)
        hbar = torch.tanh(self.xh_to_hbar(xhr))
        # calculate the switch gate that determines if each channel should be updated at all
        z = torch.sigmoid(self.xh_to_z(xh))
        # blend the previous hidden state and the new candidate hidden state
        ht = (1 - z) * hprev + z * hbar
        return ht


# CamelCase spelling matching `RNNCell`, so `GRUCell` references resolve too.
GRUCell = GRUcell

In [None]:
class RNN(nn.Module):
    """
    Recurrent language model: embeds every token, runs a recurrent cell over the sequence one
    time step at a time (starting from a learned hidden state), and decodes each hidden state
    into next-token logits.

    cell_type: 'rnn' for RNNCell or 'gru' for the GRU cell; anything else raises ValueError.
    """

    def __init__(self, config, cell_type):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.start = nn.Parameter(torch.zeros(1, config.n_embd2))  # the starting hidden state
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)  # token embedding table
        if cell_type == 'rnn':
            self.cell = RNNCell(config)
        elif cell_type == 'gru':
            self.cell = GRUcell(config)
        else:
            raise ValueError(f"unknown cell_type {cell_type!r}; expected 'rnn' or 'gru'")
        self.lm_head = nn.Linear(config.n_embd2, self.vocab_size)

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        """
        idx: LongTensor (b, t) of token ids.
        targets: optional LongTensor (b, t); entries equal to -1 are ignored by the loss.
        Returns (logits, loss) where loss is None when no targets are given.
        """
        b, t = idx.size()

        # embed all the integers up front and all at once for efficiency
        emb = self.wte(idx)  # (b, t, n_embd)

        # sequentially iterate over the inputs and update the RNN state each tick
        hprev = self.start.expand((b, -1))  # expand out the batch dimension
        hiddens = []
        for i in range(t):
            xt = emb[:, i, :]  # (b, n_embd): the input at time step i
            hprev = self.cell(xt, hprev)  # (b, n_embd2)
            hiddens.append(hprev)

        # decode the outputs
        hidden = torch.stack(hiddens, 1)  # (b, t, n_embd2)
        logits = self.lm_head(hidden)

        # if target is given
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss
    

### MLP Language model

In [None]:
class MLP(nn.Module):
    """
    Takes the previous block_size tokens, encodes them with a lookup table,
    concatenates the vectors and predicts the next token with an MLP.
    """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd)  # token embeddings table
        # +1 in the line above for a special <BLANK> token that gets inserted if encoding a token before the beginning of the input sequence
        self.mlp = nn.Sequential(
            nn.Linear(self.block_size * config.n_embd, config.n_embd2),
            nn.Tanh(),
            nn.Linear(config.n_embd2, self.vocab_size)
        )

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):
        """
        idx: LongTensor (b, t) of token ids.
        targets: optional LongTensor (b, t); entries equal to -1 are ignored by the loss.
        Returns (logits, loss) where loss is None when no targets are given.
        """
        # gather the embeddings of the previous block_size tokens (current token first, then further back)
        embs = []
        for _ in range(self.block_size):
            tok_emb = self.wte(idx)  # token embeddings of shape (b, t, n_embd)
            # shift one step back in time; torch.roll returns a copy, so the caller's idx is not modified
            idx = torch.roll(idx, 1, 1)
            idx[:, 0] = self.vocab_size  # special <BLANK> token
            embs.append(tok_emb)

        # concat all of the embeddings together and pass through an MLP
        x = torch.cat(embs, -1)  # (b, t, n_embd * block_size)
        logits = self.mlp(x)

        # if given targets
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

## Bigram Language Model

In [None]:
class Bigram(nn.Module):
    """
    Bigram Language Model 'neural net', essentially a lookup table of logits for the next character
    given a previous character.
    """
    def __init__(self, config):
        super().__init__()
        n = config.vocab_size
        # row i holds the next-token logits after token i; zero init means a uniform prediction
        self.logits = nn.Parameter(torch.zeros(n, n))

    def get_block_size(self):
        return 1  # only predicting the next character based on the 1 previous char

    def forward(self, idx, targets=None):
        """
        idx: LongTensor of token ids, e.g. (b, t).
        targets: optional LongTensor of the same shape; entries equal to -1 are ignored by the loss.
        Returns (logits, loss): logits has shape idx.shape + (vocab_size,); loss is None without targets.
        """
        # 'forward pass': one row of the table per input token
        logits = self.logits[idx]

        # if we are given targets
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

Helper Functions

In [None]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b, t)) and complete the sequence
    `max_new_tokens` times, feeding each prediction back into the model.
    Best to call with the model in eval() mode.

    model: callable returning (logits, _) with logits of shape (b, t, vocab_size); must expose get_block_size().
    temperature: divides the final-step logits before the softmax.
    do_sample: sample from the distribution if True, otherwise take the most likely token.
    top_k: if given, keep only the top_k logits (clamped to the vocabulary size).

    Returns a LongTensor of shape (b, t + max_new_tokens).
    """
    block_size = model.get_block_size()
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            k = min(top_k, logits.size(-1))  # topk raises if k exceeds the vocabulary size
            v, _ = torch.topk(logits, k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

def print_samples(num=10):
    """
    Sample `num` words from the model and print them, grouped by whether each word
    appears in the training set, the test set, or neither.

    Relies on notebook globals defined elsewhere: `args` (device, top_k), `model`,
    `train_dataset` and `test_dataset`.
    """
    def in_dataset(dataset, word):
        # prefer `contains`; fall back to the `contrains` spelling for datasets that only define that
        check = getattr(dataset, 'contains', None) or dataset.contrains
        return check(word)

    X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
    top_k = args.top_k if args.top_k != -1 else None
    steps = train_dataset.get_output_length() - 1  # -1 because we already start with <START> token (index 0)
    X_samp = generate(model, X_init, steps, top_k=top_k, do_sample=True).to('cpu')
    train_samples, test_samples, new_samples = [], [], []
    for i in range(X_samp.size(0)):
        # get the i'th row of sampled integers, as a python list
        row = X_samp[i, 1:].tolist()  # cropping out the first <START> token
        # token 0 is the <STOP> token, so we crop the output sequence at that point
        crop_index = row.index(0) if 0 in row else len(row)
        row = row[:crop_index]
        word_samp = train_dataset.decode(row)
        # separately track samples that we have and have not seen before
        if in_dataset(train_dataset, word_samp):
            train_samples.append(word_samp)
        elif in_dataset(test_dataset, word_samp):
            test_samples.append(word_samp)
        else:
            new_samples.append(word_samp)

    print('-' * 80)
    for lst, desc in [(train_samples, 'in train'), (test_samples, 'in test'), (new_samples, 'new')]:
        print(f"{len(lst)} samples that are {desc}:")
        for word in lst:
            print(word)
    print('-' * 80)
    
@torch.inference_mode()
def evaluate(model, dataset, batch_size=50, max_batches=None, device=None):
    """
    Return the mean per-batch loss of `model` over (a shuffled subset of) `dataset`.

    model: called as model(X, Y) and expected to return (logits, loss).
    dataset: yields (X, Y) pairs.
    batch_size: examples per batch.
    max_batches: if not None, evaluate at most this many batches.
    device: where each batch is moved; defaults to the notebook-global `args.device`.

    Returns a Python float (nan if no batch was evaluated). The model is switched to
    eval mode for the duration and then restored to whichever mode it was in.
    """
    if device is None:
        device = args.device
    was_training = model.training
    model.eval()
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)
    losses = []
    try:
        for i, (X, Y) in enumerate(loader):
            # check before evaluating so exactly max_batches batches are used, not max_batches + 1
            if max_batches is not None and i >= max_batches:
                break
            X, Y = X.to(device), Y.to(device)
            _, loss = model(X, Y)
            losses.append(loss.item())
    finally:
        model.train(was_training)  # reset model back to its previous mode
    mean_loss = torch.tensor(losses).mean().item()
    return mean_loss

Helper functions for creating the training and test datasets that emit words

In [None]:
class CharDataset(Dataset):
    """
    Character-level dataset of words.

    Characters are encoded as 1 + their position in `chars`; index 0 is reserved for the
    special <START>/<STOP> token. Each item is an (x, y) pair of LongTensors of length
    max_word_length + 1:
        x = [0, c_1, ..., c_n, 0, ..., 0]   (<START> followed by the word)
        y = [c_1, ..., c_n, 0, -1, ..., -1] (the word, <STOP>, then -1 = ignored by the loss)
    """

    def __init__(self, words, chars, max_word_length):
        self.words = words
        self.chars = chars
        self.max_word_length = max_word_length
        self.stoi = {ch: i + 1 for i, ch in enumerate(chars)}
        self.itos = {i: s for s, i in self.stoi.items()}  # inverse mapping

    def __len__(self):
        return len(self.words)

    def contains(self, word):
        """Return True if `word` is one of this dataset's words."""
        return word in self.words

    # backward-compatible alias for the misspelled method name
    contrains = contains

    def get_vocab_size(self):
        return len(self.chars) + 1  # all the possible characters and special 0 token

    def get_output_length(self):
        return self.max_word_length + 1  # <START> token followed by words

    def encode(self, word):
        ix = torch.tensor([self.stoi[w] for w in word], dtype=torch.long)
        return ix

    def decode(self, ix):
        word = ''.join(self.itos[i] for i in ix)
        return word

    def __getitem__(self, idx):
        word = self.words[idx]
        ix = self.encode(word)
        x = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        y = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        x[1:1 + len(ix)] = ix  # inputs: <START> (0) then the word
        y[:len(ix)] = ix       # targets: the word, followed by <STOP> (0)
        y[len(ix) + 1:] = -1  # index -1 will mask the loss at the inactive locations
        return x, y
    