In [164]:
import inspect
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

from torch.utils.data import DataLoader
import torchtext.datasets as datasets

import tiktoken

In [179]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the
# CPU so that `model.to(device)` and tensor creation still work.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [165]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input of shape (B, T, C) is projected to keys, queries and values,
    split across ``num_heads`` heads of size ``C // num_heads``, attended,
    re-assembled to (B, T, C) and passed through an output projection.

    Which positions may attend to which is decided by ``get_attn_mask()``.
    The base class returns ``self.attn_mask`` (``None`` by default), which
    means every position attends to every other position. Subclasses
    override ``get_attn_mask()`` to restrict attention.

    Args:
        context_len: maximum sequence length the module is built for.
        embedding_dim: size C of each token embedding; must be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability for the attention weights and for
            the output of the final projection.
    """

    def __init__(self, context_len, embedding_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embedding_dim % num_heads == 0

        self.context_len = context_len
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.dropout = dropout

        self.kqv_proj = nn.Linear(embedding_dim, 3*embedding_dim)
        self.out_proj = nn.Linear(embedding_dim, embedding_dim)

        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)

        # Use the fused kernel when this PyTorch build provides it.
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        # No mask: unrestricted (bidirectional) attention.
        self.attn_mask = None

    def get_attn_mask(self):
        """Return the attention mask, or ``None`` for unrestricted attention.

        A returned mask must be indexable as ``mask[:, :, :T, :T]``, be
        broadcastable against (B, num_heads, T, T), and hold non-zero /
        ``True`` wherever a query position is allowed to attend to a key
        position.
        """
        return self.attn_mask

    def _mask_for(self, T):
        """Return the mask cropped to (…, T, T) as a bool tensor, or ``None``."""
        mask = self.get_attn_mask()
        if mask is None:
            return None
        if T > mask.size(-1):
            raise ValueError(
                f"sequence length {T} exceeds attention mask size {mask.size(-1)}"
            )
        # A boolean mask is required here: scaled_dot_product_attention
        # *adds* a floating-point mask to the scores instead of masking.
        return mask[:, :, :T, :T].to(torch.bool)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        # batch size, sequence length, embedding dimensionality
        B, T, C = x.size()

        # we get the k, q, v projection of each embedding, each
        # matrix will have dimension (B, T, C)
        k, q, v = self.kqv_proj(x).split(self.embedding_dim, dim=2)

        # next we split the projected embeddings across the number
        # of heads we have, allowing each head to gain a different
        # interpretation.
        # (B, num_heads, T, head_size)
        head_size = C // self.num_heads
        k = k.view(B, T, self.num_heads, head_size).transpose(1, 2)
        q = q.view(B, T, self.num_heads, head_size).transpose(1, 2)
        v = v.view(B, T, self.num_heads, head_size).transpose(1, 2)

        # (1, 1, T, T) boolean mask, True where attention is allowed
        mask = self._mask_for(T)

        if self.flash:
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.0
            )
        else:
            # (B, num_heads, T, head_size) x (B, num_heads, head_size, T) -> (B, num_heads, T, T)
            attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

            # disallowed positions get -inf so softmax gives them zero weight
            if mask is not None:
                attn = attn.masked_fill(~mask, float('-inf'))

            attn = F.softmax(attn, dim = -1)
            attn = self.attn_dropout(attn)

            # (B, num_heads, T, T) x (B, num_heads, T, head_size) -> (B, num_heads, T, head_size)
            out = attn @ v

        # merge the heads back: (B, num_heads, T, head_size) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.resid_dropout(self.out_proj(out))

        return out


class CausalMultiHeadAttention(MultiHeadAttention):
    """Multi-head self-attention in which token t_i attends only to t_0..t_i."""

    def __init__(self, context_len, embedding_dim, num_heads, dropout=0.1):
        super().__init__(context_len, embedding_dim, num_heads, dropout)

        # Lower-triangular boolean mask: entry (i, j) is True iff j <= i.
        # It is registered as a buffer so that it is not a trainable
        # parameter but still follows the module across devices.
        # register_buffer() returns None, so the mask is read back through
        # the attribute it creates (self.causal_mask), never from its
        # return value.
        self.register_buffer(
            "causal_mask",
            torch.ones(context_len, context_len)
                .tril()
                .bool()
                .view(1, 1, context_len, context_len)
        )

    def get_attn_mask(self):
        """Return the (1, 1, context_len, context_len) causal mask buffer."""
        return self.causal_mask
    

In [167]:
class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        embedding_dim: size of the input and output features.
        ff_dim: size of the hidden (intermediate) representation.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, embedding_dim, ff_dim, dropout=0.10):
        super().__init__()

        self.inter_rep = nn.Linear(embedding_dim, ff_dim)
        self.out_proj = nn.Linear(ff_dim, embedding_dim)

        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (..., embedding_dim) to the same shape."""
        x = self.inter_rep(x)
        # The non-linearity has to sit between the two linear layers;
        # applied after out_proj, the two layers compose into a single
        # affine map and the hidden width adds no capacity.
        x = self.gelu(x)
        x = self.out_proj(x)
        x = self.dropout(x)

        return x


In [168]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention and feed-forward sub-layers,
    each wrapped in a residual connection.

    Args:
        embedding_dim: feature size normalised by the two LayerNorms.
        attn_mechanism: module applied in the first sub-layer; it must
            return a tensor of the same shape as its input.
        ffn: module applied in the second sub-layer; same shape contract.
        dropout: accepted for interface compatibility; not used by this
            class (the sub-modules carry their own dropout).
    """

    def __init__(self, embedding_dim, attn_mechanism, ffn, dropout=0.1):
        super().__init__()

        self.ln_1 = nn.LayerNorm(embedding_dim)
        self.attn = attn_mechanism

        self.ln_2 = nn.LayerNorm(embedding_dim)
        self.ffn = ffn

    def forward(self, x):
        """Return ``x`` after both residual sub-layers (same shape as ``x``)."""
        # The skip path must carry the *un-normalised* input. Normalising
        # first and then adding to the normalised value would discard the
        # residual stream at every sub-layer.
        x = x + self.attn(self.ln_1(x))
        x = x + self.ffn(self.ln_2(x))

        return x


In [222]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    The architecture is fixed by the class-level constants below; a
    subclass may override them to build a differently sized model.
    """

    CONTEXT_LENGTH = 1024   # maximum number of positions in one forward pass
    VOCAB_SIZE = 50304      # size of the token embedding table / output layer
    EMBEDDING_DIM = 768     # width of the token representations
    INTER_DIM = 2048        # hidden width of each block's feed-forward network
    NUM_HEADS = 12          # attention heads per block
    NUM_LAYERS = 12         # number of transformer blocks
    DROPOUT = 0.0           # dropout probability used throughout

    def __init__(self):
        super().__init__()

        blocks = nn.ModuleList([
            Block(
                self.EMBEDDING_DIM,
                CausalMultiHeadAttention(
                    self.CONTEXT_LENGTH,
                    self.EMBEDDING_DIM,
                    self.NUM_HEADS,
                    self.DROPOUT
                ),
                FFN(self.EMBEDDING_DIM, self.INTER_DIM, self.DROPOUT),
                self.DROPOUT
            ) for _ in range(self.NUM_LAYERS)
        ])

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(self.VOCAB_SIZE, self.EMBEDDING_DIM),
            wpe = nn.Embedding(self.CONTEXT_LENGTH, self.EMBEDDING_DIM),
            dropout = nn.Dropout(self.DROPOUT),
            blocks = blocks,
            ln = nn.LayerNorm(self.EMBEDDING_DIM)
        ))
        self.lm_head = nn.Linear(self.EMBEDDING_DIM, self.VOCAB_SIZE, bias = False)

        # Weight tying: the token embedding and the output layer share a
        # single parameter tensor.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
        # Every projection named `out_proj` gets a smaller
        # initialisation, scaled by 1/sqrt(2 * NUM_LAYERS).
        for param_name, params in self.named_parameters():
            if param_name.endswith('out_proj.weight'):
                torch.nn.init.normal_(
                    params,
                    mean = 0.0,
                    std = 0.02/math.sqrt(2 * self.NUM_LAYERS)
                )

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, targets = None):
        """Run the model on a batch of token ids.

        Args:
            input_ids: integer tensor of shape (B, T), T <= CONTEXT_LENGTH.
            targets: optional integer tensor of shape (B, T); entries
                equal to -1 are ignored by the loss.

        Returns:
            ``(logits, loss)``. With ``targets`` the logits have shape
            (B, T, VOCAB_SIZE) and ``loss`` is the cross-entropy. Without
            ``targets`` only the last position is projected, so the
            logits have shape (B, 1, VOCAB_SIZE) and ``loss`` is ``None``.
        """
        device = input_ids.device

        B, T = input_ids.size()
        assert T <= self.CONTEXT_LENGTH, (
            f"sequence length {T} exceeds CONTEXT_LENGTH {self.CONTEXT_LENGTH}"
        )

        pos = torch.arange(0, T, dtype=torch.long, device=device)

        token_embeddings = self.transformer.wte(input_ids)
        pos_embeddings = self.transformer.wpe(pos)

        # (B, T, embedding_dim); the (T, embedding_dim) position
        # embeddings broadcast over the batch dimension
        x = token_embeddings + pos_embeddings
        x = self.transformer.dropout(x)

        for block in self.transformer.blocks:
            x = block(x)
        x = self.transformer.ln(x)

        if targets is not None:
            # (B, T, vocab_size)
            logits = self.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index = -1
            )
        else:
            # (B, 1, vocab_size): indexing with the list [-1] keeps the
            # time dimension
            logits = self.lm_head(x[:, [-1], :])
            loss = None

        return logits, loss

    def configure_optimizer(self, weight_decay, learning_rate, betas, device_type):
        """Build an AdamW optimizer with weight decay on matrices only.

        Trainable parameters with two or more dimensions are decayed with
        ``weight_decay``; all other trainable parameters use a decay of
        0.0. Parameters with ``requires_grad=False`` are left out.

        Args:
            weight_decay: decay coefficient for the decayed group.
            learning_rate: AdamW learning rate.
            betas: AdamW ``betas`` tuple.
            device_type: the fused AdamW implementation is requested only
                when this is ``'cuda'`` and the installed AdamW exposes a
                ``fused`` argument.

        Returns:
            The configured ``torch.optim.AdamW`` instance.
        """
        trainable = [
            param for _, param in self.named_parameters() if param.requires_grad
        ]

        decay_params = [param for param in trainable if param.dim() >= 2]
        nodecay_params = [param for param in trainable if param.dim() < 2]

        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0}
        ]

        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)

        return optimizer

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively.

        Args:
            input_ids: integer tensor of shape (B, T) holding the prompt.
            max_new_tokens: number of tokens to append.
            temperature: positive divisor applied to the logits before
                the softmax.
            top_k: if given, sample only from the ``top_k`` highest
                logits.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            # Dividing by zero or a negative temperature yields inf/NaN or
            # inverted logits, which torch.multinomial cannot sample from.
            raise ValueError(f"temperature must be > 0, got {temperature}")

        for _ in range(max_new_tokens):
            # Keep only the most recent CONTEXT_LENGTH tokens as context.
            context_window = input_ids
            if context_window.size(1) > self.CONTEXT_LENGTH:
                context_window = input_ids[:, -self.CONTEXT_LENGTH:]

            logits, _ = self(context_window)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # v[:, [-1]] is the k-th largest logit of each row;
                # everything below it is excluded from sampling.
                logits = logits.masked_fill(logits < v[:, [-1]], -float('Inf'))

            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat((input_ids, next_id), dim=1)

        return input_ids

In [223]:
# Build the randomly initialised model, put it in inference mode and move it
# to the selected device (both calls return the module itself).
model = GPT().eval().to(device)

# GPT-2 byte-pair-encoding tokenizer.
enc = tiktoken.get_encoding("gpt2")

In [224]:
# Prompt for the model to continue.
text = "hello my name is not"
# Token ids of the prompt, produced by the `enc` tokenizer from the cell above.
input_ids = enc.encode(text)
# Display the ids as the cell output.
input_ids

[31373, 616, 1438, 318, 407]

In [233]:
# Wrap the ids in an outer list to get a batch of one: shape (1, T).
x = torch.tensor([input_ids], dtype=torch.long, device=device)
# Sample two more tokens at temperature 0.6 from the 10 most likely candidates.
out = model.generate(x, max_new_tokens=2, temperature=0.6, top_k=10)

In [234]:
# The model is untrained, so the decoded continuation is gibberish.
enc.decode(out.reshape(-1).tolist())

'hello my name is not not not'