In [None]:
import argparse
import math
import os
import sys
import time
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
@dataclass
class ModelConfig:
    """Hyperparameters read by the model classes in this notebook.

    ``block_size`` and ``vocab_size`` default to ``None``, i.e. they have no
    usable default and are expected to be filled in before a model is built
    (``CausalSelfAttention`` reads ``block_size``, ``n_embd`` and ``n_head``).
    """

    block_size: Optional[int] = None  # length of the input sequences of integers
    vocab_size: Optional[int] = None  # the input integers are in range [0 .. vocab_size - 1]

    # parameters below control the size of each model slightly differently
    n_layers: int = 4
    n_embd: int = 64   # must be divisible by n_head (asserted in CausalSelfAttention)
    n_embd2: int = 64
    n_head: int = 4

#### Transformer Language Model as used in GPT-2

In [None]:
class NewGELU(nn.Module):

    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Gaussian Error Linear Units (GELU):    https://arxiv.org/abs/1606.08415

    Uses the tanh approximation:
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    """

    def forward(self, x):
        """Apply the tanh-approximated GELU element-wise and return a tensor shaped like ``x``."""
        # sqrt(2/pi) scales the cubic polynomial inside tanh; it is a
        # multiplicative factor, not an additive offset.
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

In [None]:
class CausalSelfAttention(nn.Module):
    """
    A simple multi-head masked self-attention layer with a projection at the end.

    Similar to torch.nn.MultiheadAttention

    Args:
        config: object exposing ``n_embd``, ``n_head`` and ``block_size``.
            ``n_embd`` must be divisible by ``n_head``; ``block_size`` fixes
            the side length of the pre-computed causal mask.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size))
            .view(1, 1, config.block_size, config.block_size),
        )

        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Run masked multi-head self-attention.

        Args:
            x: tensor of shape (B, T, C) where C equals ``n_embd``. The mask
                is sliced to ``[:T, :T]``, so T is presumably expected to be
                at most ``block_size``.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim.
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # positions above the diagonal (the "future") get -inf so softmax assigns them zero weight
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)
        return y