In [None]:
import argparse
import math
import os
import sys
import time
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
@dataclass
class ModelConfig:
    """Hyperparameters shared by the language models in this notebook.

    ``block_size`` and ``vocab_size`` default to ``None``; they are expected to be
    set before a model is built from this config (the models pass them straight
    into ``nn.Embedding`` / ``nn.Linear`` / ``torch.ones``).
    """

    block_size: Optional[int] = None  # length of the input sequences of integers
    vocab_size: Optional[int] = None  # the input integers are in range [0 .. vocab_size - 1]
    # parameters below control the size of each model slightly differently
    n_layers: int = 4
    n_embd: int = 64
    n_embd2: int = 64
    n_head: int = 4

    # Not a dataclass field (no annotation): ``Transformer`` reads the layer count
    # as ``config.n_layer``, so expose that name as an alias of ``n_layers``.
    @property
    def n_layer(self) -> int:
        """Alias for ``n_layers``."""
        return self.n_layers

    @n_layer.setter
    def n_layer(self, value: int) -> None:
        self.n_layers = value

#### Transformer Language Model as used in GPT-2

In [None]:
class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Gaussian Error Linear Units (GELU): https://arxiv.org/abs/1606.08415

    Uses the tanh approximation:
        GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    """

    def forward(self, x):
        # sqrt(2/pi) *multiplies* the cubic polynomial inside tanh; it is not added to it
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

In [None]:
class CausalSelfAttention(nn.Module):
    """
    A simple multi-head masked self-attention layer with a projection at the end.

    Similar to torch.nn.MultiheadAttention, but with a fixed lower-triangular mask so
    each position can only attend to itself and earlier positions.

    Args:
        config: object exposing ``n_embd``, ``n_head`` and ``block_size``;
            ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence;
        # shaped (1, 1, block_size, block_size) so it broadcasts over batch and heads
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )

        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (B, T, C) with C == n_embd and T <= block_size
                (the mask is sliced to [:T, :T]).

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        hs = C // self.n_head  # per-head size

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(hs))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)
        return y

In [None]:
class Block(nn.Module):
    """Unassuming Transformer block.

    Pre-LayerNorm layout: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    Args:
        config: object exposing ``n_embd`` (plus whatever ``CausalSelfAttention`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.ModuleDict(dict(
            c_fc=nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj=nn.Linear(4 * config.n_embd, config.n_embd),
            act=NewGELU(),
        ))

    def mlpf(self, x):
        """Feed-forward branch: expand to 4 * n_embd, apply GELU, project back to n_embd.

        A bound method (rather than a lambda stored on the instance) so the module
        can be pickled and so ``copy.deepcopy`` yields a block that uses its own MLP.
        """
        m = self.mlp
        return m.c_proj(m.act(m.c_fc(x)))

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlpf(self.ln_2(x))
        return x
    

In [None]:
class Transformer(nn.Module):
    """Transformer language model, similar to GPT-2.

    Token embeddings and learned position embeddings are summed, passed through
    ``config.n_layer`` ``Block`` modules and a final LayerNorm, then projected to
    vocabulary logits by ``lm_head``.

    Args:
        config: object exposing ``vocab_size``, ``block_size``, ``n_embd`` and
            ``n_layer`` (plus whatever ``Block`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print(f"number of parameters: {n_params / 1e6:.2f}M")

    def get_block_size(self):
        """Return the maximum sequence length ``forward`` accepts."""
        return self.block_size

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token indices of shape (b, t) with t <= block_size.
            targets: optional token indices with b * t elements; entries equal to
                -1 are ignored by the loss.

        Returns:
            (logits, loss): logits of shape (b, t, vocab_size); loss is a scalar
            tensor, or None when ``targets`` is None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model
        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = tok_emb + pos_emb
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        # if we are given some desired targets also calculate the loss;
        # reshape (not view) so non-contiguous targets are accepted as well
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,
            )

        return logits, loss

#### Bag of Words (BoW) Language Model

In [None]:
class CausalBoW(nn.Module):
    """
    Causal bag of words. Averages the preceding elements and looks suspiciously like a
    CausalAttention module found in a transformer.

    Output position t is the uniform average of input positions 0..t (inclusive).

    Args:
        config: object exposing ``block_size`` (maximum sequence length).
    """

    def __init__(self, config):
        super().__init__()

        # used to mask out vectors and preserve autoregressive property;
        # shaped (1, block_size, block_size) so it broadcasts over the batch dimension
        self.block_size = config.block_size
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, config.block_size, config.block_size),
        )

    def forward(self, x):
        """Average each position's features with those of all earlier positions.

        Args:
            x: tensor of shape (B, T, C) with T <= block_size.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()  # batch size, sequence length, n_embd

        # do the weighted average of all preceding token features:
        # equal (zero) scores, future positions masked to -inf, softmax -> uniform weights.
        # Match x's dtype so the matmul below works for non-float32 inputs.
        att = torch.zeros((B, T, T), device=x.device, dtype=x.dtype)
        att = att.masked_fill(self.bias[:, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ x  # (B, T, T) x (B, T, C) -> (B, T, C)

        return y