# Ch 4 - Simple GPT
This chapter focuses on implementing a mini-GPT NN.

## GPT-2 skeleton

In [19]:
import torch
import torch.nn as nn

In [20]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,     # Vocabulary size
    "context_length": 1024,  # Context length
    "emb_dim": 768, # Number of embedding dimensions
    "n_heads": 12, # Number of attention heads
    "n_layers": 12, # Number of layers
    "drop_rate": 0.1, # Dropout rate
    "qkv_bias": False# Query-Key-Value bias
}

In [21]:
class MultiHeadAttention(nn.Module):
    """Implementation of multihead attention w/ parallel matrix processing"""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        # validate input dimensions
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.attention_dim = d_out // num_heads

        # setup attention matrices
        self.W_q = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_k = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_v = nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

        # setup dropout
        self.dropout = nn.Dropout(dropout)

    
    def forward(self, x):
        n, seq_length, _ = x.shape

        # compute Q, K, V matrices
        x_query = self.W_q(x)
        x_key = self.W_k(x)
        x_value = self.W_v(x)

        # reshape to separate into Q = [Q1, Q2, ...], K = [K1, K2, ...]
        x_query = x_query.view(n, seq_length, self.num_heads, self.attention_dim)
        x_query = x_query.transpose(1, 2) # (n, num_heads, seq_length, attention_dim)

        x_key = x_key.view(n, seq_length, self.num_heads, self.attention_dim)
        x_key = x_key.transpose(1, 2) # (n, num_heads, seq_length, attention_dim)
        x_key = x_key.transpose(2, 3) # (n, num_heads, attention_dim, seq_length)

        x_value = x_value.view(n, seq_length, self.num_heads, self.attention_dim)
        x_value = x_value.transpose(1, 2) # (n, num_heads, seq_length, attention_dim)

        # compute attention scores (per-head)
        dk_constant = x_key.shape[-1] ** -0.5
        mask_context = self.mask.bool()[:seq_length, :seq_length] 
        attention_scores = (x_query @ x_key)
        attention_scores.masked_fill_(mask_context, -torch.inf)

        # compute attention weights 
        # note : no dropout on scores (b/c dropout on -inf is not well-defined)
        attention_weights = torch.softmax(attention_scores * dk_constant, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # compute context
        context = attention_weights @ x_value

        # reshape back to (n, seq_length, d_out)
        context = context.contiguous().view(n, seq_length, self.d_out)

        return context


In [22]:
import torch
import torch.nn as nn

class DummyGPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.trf_blocks = nn.Sequential(
            *[DummyTransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):

        # read input X
        batch_size, seq_len = in_idx.shape

        # map X to embedding space
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds

        # apply dropout
        x = self.drop_emb(x)

        # apply transformer blocks
        x = self.trf_blocks(x)

        # apply layer norm
        x = self.final_norm(x)

        # apply FF network 
        logits = self.out_head(x)
        return logits

class DummyTransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
    def forward(self, x):
        return x

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
    
class GELU(nn.Module):
    """Activation function that's smoother than RELU"""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))
        ))
    
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
    def forward(self, x):
        return self.layers(x)
    
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut
        return x


# Examples - Scratchpad 
A list of tiny examples to demonstrate how each class works

In [23]:
# example : tokenizing text
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")

batch = []
txt1 = "Every effort moves you"
txt2 = "Every day holds a"
batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)
print("shape", batch.shape)
print(batch)

shape torch.Size([2, 4])
tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])


In [24]:
# example : running the dummy network
torch.manual_seed(123)
model = DummyGPTModel(GPT_CONFIG_124M)
logits = model(batch)
print("Output shape:", logits.shape)
print(logits)

Output shape: torch.Size([2, 4, 50257])
tensor([[[-0.5998,  0.1788, -0.4855,  ..., -1.0338,  0.1756, -0.3791],
         [-0.3048,  0.1200,  0.3618,  ..., -0.2633,  1.0311,  0.5753],
         [ 0.4025,  1.0698, -0.1543,  ...,  0.7827,  0.0786,  0.5037],
         [ 0.0938,  1.6024, -0.6251,  ...,  0.8712, -0.1122, -0.3924]],

        [[-0.9718, -0.0670, -0.6994,  ..., -1.1791, -0.3173, -0.4363],
         [-0.6117,  0.6015, -0.0475,  ..., -0.0795,  0.1833,  0.7967],
         [ 0.1158,  0.7813, -0.1099,  ...,  0.0522, -0.1281,  0.1061],
         [ 0.1734, -0.6275,  0.1344,  ...,  1.7331, -0.2619, -0.2139]]],
       grad_fn=<UnsafeViewBackward0>)


In [25]:
class ExampleDeepNeuralNetwork(nn.Module):
    def __init__(self, layer_sizes, use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5]), GELU())
        ])

    def forward(self, x):
        for layer in self.layers:

            # run layer 
            layer_output = layer(x)

            # if skip connection, then you sum the original (x) + layer-output and feed to next network
            if self.use_shortcut and x.shape == layer_output.shape:
                x = x + layer_output

            # if no skip connection, then just layer_output = x, the input to next layer
            else:
                x = layer_output

        return x


In [26]:
layer_sizes = [3, 3, 3, 3, 3, 1]
sample_input = torch.tensor([[1., 0., -1.]])
torch.manual_seed(123)

# Model without shortcut
model_without_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=False)
model_with_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=True)

# Function to compute gradients in the model's backward pass
def print_gradients(model, x):
    output = model(x)  # Forward pass
    target = torch.tensor([[0.]])
    loss_fn = nn.MSELoss()
    loss = loss_fn(output, target)  # Calculates loss based on how close the target and output are
    loss.backward()  # Backward pass to calculate the gradients

    for name, param in model.named_parameters():
        if 'weight' in name:
            print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")

# Call the function to print gradients
print("--model w/out shortcut--")
print_gradients(model_without_shortcut, sample_input)
print("\n\n--model w/ shortcut--")
print_gradients(model_with_shortcut, sample_input)


--model w/out shortcut--
layers.0.0.weight has gradient mean of 0.00020173587836325169
layers.1.0.weight has gradient mean of 0.0001201116101583466
layers.2.0.weight has gradient mean of 0.0007152041653171182
layers.3.0.weight has gradient mean of 0.001398873864673078
layers.4.0.weight has gradient mean of 0.005049646366387606


--model w/ shortcut--
layers.0.0.weight has gradient mean of 0.0014432319439947605
layers.1.0.weight has gradient mean of 0.004846962168812752
layers.2.0.weight has gradient mean of 0.0041389018297195435
layers.3.0.weight has gradient mean of 0.00591512955725193
layers.4.0.weight has gradient mean of 0.03265950828790665


In [28]:
torch.manual_seed(123)
x = torch.rand(2, 4, 768)
block = TransformerBlock(GPT_CONFIG_124M)
output = block(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])
