## Simple GPT-2 Implementation

This notebook implements a simple GPT-2 model based on the architecture from:
    
Improving Language Understanding by Generative Pre-Training (OpenAI)

Also useful for transformers:

Attention Is All You Need (Google Brain)

Weight tying between the token embeddings and the output (softmax) layer:

Using the Output Embedding to Improve Language Models


In [45]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from sklearn.model_selection import train_test_split
from typing import Tuple

# Fix PyTorch's global random seed so runs are reproducible.
torch.manual_seed(seed=1)

<torch._C.Generator at 0x11294fd8290>

In [211]:
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model.

    The defaults are the GPT-2 hyperparameters; override individual fields
    (e.g. for a tiny character-level model) by keyword.
    """

    context_length: int = 1024  # maximum sequence length; also the size of the causal mask
    vocab_size: int = 50304     # presumably GPT-2's 50257 BPE tokens padded to a multiple of 64 — confirm
    n_layer: int = 12           # number of transformer blocks
    n_head: int = 12            # attention heads per block; must divide n_embd evenly
    n_embd: int = 768           # embedding width of every token
    dropout: float = 0.1        # dropout probability
    bias: bool = False          # passed as `bias=` to the projection Linear layers
        
class Text_Handler:
    """Load a text file, build a character-level vocabulary and make training examples.

    Each example pairs a window of ``context_length`` consecutive token ids (X)
    with the single token id that immediately follows that window (Y). The
    encoded text is split in order (no shuffling) into the first 80% for
    training and the remaining 20% for testing.
    """

    def __init__(self, file_name, context_length):
        self.context_length = context_length
        # A non-positive window cannot produce examples (negative values used to
        # slip past the old `!= 0` check and fail later with an obscure error).
        assert self.context_length > 0, "context_length must be a positive integer."
        self.raw_text = self.load_text(file_name)
        self.vocab_size, self.stoi, self.itos = self.pre_process_text()
        # Character-level tokenizer: string -> list of ids, and back.
        self.encode = lambda s: [self.stoi[c] for c in s]
        self.decode = lambda l: "".join([self.itos[i] for i in l])

        # Encode the whole corpus, split it without shuffling (so the test set is
        # contiguous held-out text) and cut both parts into context-length examples.
        self.text = torch.tensor(self.encode(self.raw_text), dtype=torch.long)
        self.train_text, self.test_text = train_test_split(self.text, test_size=0.2, shuffle=False)
        self.X_train, self.Y_train = self.make_examples(self.train_text)
        self.X_test, self.Y_test = self.make_examples(self.test_text)

    def load_text(self, file_name):
        """Return the full contents of ``file_name``, decoded as UTF-8.

        UTF-8 is requested explicitly because the platform default encoding
        varies (e.g. on Windows it is the locale code page).

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        try:
            with open(file_name, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File {file_name} does not exist within the working directory.") from err

    def pre_process_text(self):
        """Build the character vocabulary of ``self.raw_text``.

        Returns:
            (vocab_size, stoi, itos): the number of distinct characters, a
            char -> id mapping and its id -> char inverse. Ids follow sorted
            character order, so they are deterministic for a given text.
        """
        chars = sorted(set(self.raw_text))
        stoi = {char: i for i, char in enumerate(chars)}
        itos = {i: char for char, i in stoi.items()}
        return len(chars), stoi, itos

    def make_examples(self, text: torch.Tensor, verbose=False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Cut a 1-D tensor of token ids into (window, next-token) pairs.

        Args:
            text: 1-D tensor of token ids.
            verbose: if True, print every example.

        Returns:
            X: (n_examples, context_length) long tensor; row i is
               ``text[i : i + context_length]``.
            Y: (n_examples,) long tensor; ``Y[i] == text[i + context_length]``.
            Here ``n_examples = len(text) - context_length``. When the text is
            not longer than ``context_length`` both tensors are empty.
        """
        n_examples = text.shape[0] - self.context_length
        if n_examples <= 0:
            return (torch.empty(0, self.context_length, dtype=torch.long),
                    torch.empty(0, dtype=torch.long))

        # unfold yields every sliding window at once as a strided view (no Python
        # loop). The last window has no following token, so it is dropped. The
        # copy gives X its own memory instead of aliasing `text`.
        X = text.unfold(0, self.context_length, 1)[:n_examples].to(torch.long, copy=True)
        Y = text[self.context_length:].to(torch.long, copy=True)

        if verbose:
            for i in range(n_examples):
                print(f"Example {i+1:2d}: {X[i].tolist()} --> {Y[i].item()}")

        return X, Y


In [267]:
# GPT2 implementation

class Causal_Self_Attention(nn.Module):
    """Multi-head masked (causal) self-attention.

    Terminology used in the comments: each batch element is a "graph" whose
    "nodes" are the sequence positions. The lower-triangular mask lets each
    node attend only to itself and to earlier nodes.
    """

    def __init__(self, config):
        super().__init__()
        # The embedding must split evenly across the attention heads.
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # One Linear produces queries, keys and values in a single pass. Its bias
        # follows config.bias, like the output projection and the MLP layers.
        self.attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)

        # Registered as a buffer so the causal mask is saved in the state dict and
        # moves with the model between devices.
        self.register_buffer("causal_mask", torch.tril(torch.ones(config.context_length, config.context_length))
                                    .view(1, 1, config.context_length, config.context_length))

        # Final linear transformation after multi-head attention.
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.proj_dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: (n_graphs, n_nodes, n_embd) tensor with n_nodes <= context_length
               (the causal mask is sliced down to n_nodes).

        Returns:
            Tensor with the same shape as ``x``, so it can be added back as a residual.
        """
        # n_graphs = batch size (sequences processed at once)
        # n_nodes = sequence length (positions per sequence)
        # attention_component_size = embedding width of each node
        n_graphs, n_nodes, attention_component_size = x.size()

        # Queries, keys and values for every node; each keeps width n_embd.
        queries, keys, values = self.attn(x).split(self.n_embd, dim=2)  # each: (n_graphs, n_nodes, n_embd)

        # Split every node's q/k/v across the heads:
        # (n_graphs, n_nodes, n_embd) -> (n_graphs, n_nodes, n_head, head_size)
        # -> transpose(1, 2) -> (n_graphs, n_head, n_nodes, head_size), head_size = n_embd // n_head.
        head_size = attention_component_size // self.n_head
        queries = queries.view(n_graphs, n_nodes, self.n_head, head_size).transpose(1, 2)
        keys = keys.view(n_graphs, n_nodes, self.n_head, head_size).transpose(1, 2)
        values = values.view(n_graphs, n_nodes, self.n_head, head_size).transpose(1, 2)

        # Scaled dot-product scores between every pair of nodes within each head:
        # (.., n_nodes, head_size) @ (.., head_size, n_nodes) -> (n_graphs, n_head, n_nodes, n_nodes),
        # normalised by 1/sqrt(head_size).
        att = (queries @ keys.transpose(-2, -1)) * (1.0 / math.sqrt(keys.size(-1)))

        # Block attention to future nodes. Slicing the mask supports sequences
        # shorter than the full context length.
        att = att.masked_fill(self.causal_mask[:, :, :n_nodes, :n_nodes] == 0, float('-inf'))

        # Softmax over the causally allowed nodes. For each node this gives the
        # fraction of attention it pays to every other node's value vector.
        attention_fractions = self.attn_dropout(F.softmax(att, dim=-1))

        attention_weighted_value_vectors = attention_fractions @ values  # (n_graphs, n_head, n_nodes, head_size)

        # Re-merge the heads. The transpose makes the tensor non-contiguous, so
        # copy it into contiguous memory before view().
        attention_weighted_value_vectors = attention_weighted_value_vectors.transpose(1, 2).contiguous().view(n_graphs, n_nodes, attention_component_size)

        # The output projection keeps the input shape, ready for the residual connection.
        out = self.proj_dropout(self.proj(attention_weighted_value_vectors))

        return out

In [None]:
class Multi_Layer_Perceptron(nn.Module):
    """Feed-forward network applied to each position: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        # Expand to 4x the embedding width, as in the paper, then project back.
        # GELU is smooth and does not hard-zero negative inputs the way ReLU does.
        self.fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.nonlin = nn.GELU()
        # Residual dropout on the branch output, matching proj_dropout in
        # Causal_Self_Attention (Block adds this output onto the residual stream).
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the MLP output; same shape as ``x`` (last dim n_embd)."""
        x = self.fc(x)
        x = self.nonlin(x)
        x = self.proj(x)
        return self.dropout(x)

In [None]:
class Block(nn.Module):
    """One transformer layer: causal attention followed by an MLP.

    Pre-norm layout: LayerNorm is applied to the *input* of each sub-layer
    (before attention and before the MLP). Each sub-layer's output is then
    added back onto its input as a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = Causal_Self_Attention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = Multi_Layer_Perceptron(config)

    def forward(self, x):
        # Positions exchange information through (causal) attention...
        attended = self.attn(self.ln_1(x))
        residual = x + attended
        # ...then each position is transformed on its own by the MLP.
        return residual + self.mlp(self.ln_2(residual))

In [None]:
class GPT(nn.Module):
    """Decoder-only transformer language model (GPT-2 style).

    Maps a batch of token ids to logits over the vocabulary at every position.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.context_length is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            # Learned token embeddings: one n_embd vector per vocabulary entry.
            token_embd = nn.Embedding(config.vocab_size, config.n_embd),
            # Learned position embeddings: one n_embd vector per position in the context.
            positional_embd = nn.Embedding(config.context_length, config.n_embd),
            # Dropout applied to the summed token + position embeddings.
            drop = nn.Dropout(config.dropout),
            attention_stack = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            # Final LayerNorm before the output head.
            ln = nn.LayerNorm(config.n_embd)))

        self.output_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the output head weight (vocab_size, n_embd) has the same
        # shape as the token embedding table, so both share one tensor and learn a
        # single vector representation per token.
        # Paper: "We study the topmost weight matrix of neural network language models.
        # We show that this matrix constitutes a valid word embedding."
        # PyTorch accumulates (adds) the gradients from both uses in the backward pass.
        self.transformer.token_embd.weight = self.output_head.weight

        self.apply(self._init_weights)

        # Apply a special scaled init to the residual projections, per the GPT-2 paper.
        # This matches `proj.weight` in both the attention and the MLP sub-layers.
        for name, param in self.named_parameters():
            if name.endswith('proj.weight'):
                torch.nn.init.normal_(param, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # parameters() de-duplicates shared Parameters, so the tied weight is counted once.
        print("number of parameters: %d" % (sum(p.nelement() for p in self.parameters()),))

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases.

        Paper: "Since layernorm is used extensively throughout the model, a simple
        weight initialization of N(0, 0.02) was sufficient."
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, t):
        """Return logits for a batch of token ids.

        Args:
            t: (batch, seq_len) long tensor of token ids, seq_len <= config.context_length.

        Returns:
            (batch, seq_len, vocab_size) tensor of unnormalised logits. Because of
            the causal mask, position i depends only on t[:, :i + 1], so it can be
            used to score the token that follows t[:, i].
        """
        _, n_nodes = t.size()
        assert n_nodes <= self.config.context_length, (
            f"Sequence length {n_nodes} exceeds context_length {self.config.context_length}.")

        positions = torch.arange(n_nodes, dtype=torch.long, device=t.device)
        # (batch, seq_len, n_embd) + (seq_len, n_embd) broadcasts over the batch.
        x = self.transformer.token_embd(t) + self.transformer.positional_embd(positions)
        x = self.transformer.drop(x)
        for block in self.transformer.attention_stack:
            x = block(x)
        x = self.transformer.ln(x)
        return self.output_head(x)
    

In [153]:
# Tokens per training example, i.e. the number of nodes in each attention "graph".
context_length = 5
d = Text_Handler("input_shakespeare.txt", context_length=context_length)

In [268]:
# Tiny model for experimentation; the vocabulary size comes from the loaded text.
config = GPTConfig(
    context_length=context_length,
    vocab_size=d.vocab_size,
    n_layer=4,
    n_head=3,  # must divide n_embd
    n_embd=6,
    bias=False,
)

In [269]:
c = Causal_Self_Attention(config)
n_batch = 4
# B = batch, T = time (context length), C = channels (embedding size)
B = n_batch
T = context_length
C = config.n_embd
x = torch.randn(B, T, C)