<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/transformer_debugging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Crack these 10 transformer debugging exercises and you shall pass any transformer debugging coding interview.
# Note: Each block, once run, will error out. Try not to look at the solution (commented out).

# Goal: Solve each problem in <10m

# Bug tier:
# Tier 1 - runtime bugs
# Tier 2 - structural omissions, not runtime but missing features in a GPT-like model. Includes best-practices


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Problem 1 (Easy)
# Tier 1 errors: 2
# Tier 2 errors: 3
# Time taken: 5m

class SimpleSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention (no mask).

    Queries, keys and values are separate linear projections of the same
    input; scores are divided by ``sqrt(embed_dim)`` before the softmax.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Projections are created in q, k, v order.
        self.q_lin = nn.Linear(embed_dim, embed_dim)
        self.k_lin = nn.Linear(embed_dim, embed_dim)
        self.v_lin = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x):
        """Return the attention-weighted values for ``x`` (same shape as ``x``)."""
        queries, keys, values = self.q_lin(x), self.k_lin(x), self.v_lin(x)
        # Error 1 (exercise): torch.matmul(Q, K) multiplies two (..., T, D)
        # tensors without transposing the keys.
        # Solution 1: transpose the last two key axes so the product is (..., T, T).
        logits = queries @ keys.transpose(-2, -1)
        weights = (logits / self.scale).softmax(dim=-1)
        return weights @ values

class SimpleDecoder(nn.Module):
    """Embedding -> SimpleSelfAttention -> linear projection to vocabulary logits."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # Sub-modules are created in the order they are applied in forward().
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.attn = SimpleSelfAttention(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to per-position logits over the vocabulary."""
        hidden = self.embed(x)
        attended = self.attn(hidden)
        return self.fc(attended)

# Problem 1 driver: build the model and take two optimisation steps on one random batch.
vocab_size = 2
embed_dim = 32
batch_size = 8
seq_len = 10
model = SimpleDecoder(vocab_size, embed_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))
for epoch in range(2):
    optimizer.zero_grad()
    output = model(inp)
    # Error 2 (exercise): logits and targets are not shifted, so position t is scored against token t itself.
    # loss = criterion(output.view(-1, vocab_size), inp.view(-1)) # Error 2
    # Solution 2: drop the last logit and the first target so position t is scored against token t + 1.
    loss = criterion(output[:, :-1, :].contiguous().view(-1, vocab_size), inp[:, 1:].contiguous().view(-1)) # Solution 2
    loss.backward()
    optimizer.step()

# Tier 2 errors:
# 1. Missing causal mask
# 2. Missing positional encodings
# 3. Missing Layernorm

In [None]:
# Problem 2 (Easy)
# Tier 1 errors: 2
# Tier 2 errors: 3
# Time taken: 8m 12s

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (no mask).

    Args:
        embed_dim: size of the input/output feature axis.
        num_heads: number of heads; must divide ``embed_dim`` evenly.

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, f"{embed_dim} needs to be divisible by {num_heads}"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads  # per-head feature size
        self.q_lin = nn.Linear(embed_dim, embed_dim)
        self.k_lin = nn.Linear(embed_dim, embed_dim)
        self.v_lin = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, T, E) and return (B, T, E)."""
        B, T, E = x.shape
        # Split the feature axis into heads: (B, T, E) -> (B, T, H, d_k).
        Q = self.q_lin(x).view(B, T, self.num_heads, self.d_k)
        K = self.k_lin(x).view(B, T, self.num_heads, self.d_k)
        V = self.v_lin(x).view(B, T, self.num_heads, self.d_k)

        # Move the head axis next to the batch axis so matmul runs per head.
        Q = Q.transpose(1, 2)  # (B, H, T, d_k)
        K = K.transpose(1, 2)  # (B, H, T, d_k)
        V = V.transpose(1, 2)  # (B, H, T, d_k)

        # (B, H, Tq, d_k) x (B, H, d_k, Tk) -> (B, H, Tq, Tk).
        # Scale by sqrt(d_k), as the other attention layers in this notebook
        # do, so the softmax input does not grow with the head size.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # (B, H, Tq, d_k)

        # Merge the heads back: (B, H, T, d_k) -> (B, T, H, d_k) -> (B, T, E).
        context = context.transpose(1, 2).contiguous().view(B, T, E)
        return self.fc(context)

class Decoder(nn.Module):
    """Embedding -> MultiHeadSelfAttention -> linear projection to vocabulary logits."""

    def __init__(self, vocab_size, embed_dim, heads):
        super().__init__()
        # Sub-modules are created in the order they are applied in forward().
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, heads)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to per-position logits over the vocabulary."""
        hidden = self.embed(x)
        attended = self.attn(hidden)
        return self.fc(attended)

# Problem 2 driver. Both exercise errors are left active on purpose.
vocab_size = 1000
embed_dim = 32
heads = 6 # Error 1 (exercise): 32 is not divisible by 6, so the divisibility assert in MultiHeadSelfAttention fires
# heads = 4 # Solution 1
model = Decoder(vocab_size, embed_dim, heads)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

batch_size = 4
seq_len = 10
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))
for epoch in range(3):
    optimizer.zero_grad()
    out = model(inp)
    loss = criterion(out.view(-1, vocab_size), inp.contiguous().view(-1)) # Error 2 (exercise): logits and targets are not shifted
    # Solution 2: score the logits at position t against the token at position t + 1.
    # loss = criterion(out[:, :-1, :].contiguous().view(-1, vocab_size), inp[:, 1:].contiguous().view(-1)) # Solution 2
    loss.backward()
    optimizer.step()

# Tier 2 Errors:
# 1. Missing Causal mask
# 2. Missing positional encodings
# 3. Missing Layernorm

In [None]:
# Problem 3 (Medium)
# Tier 1 bugs - 3
# Tier 2 bugs - 2
# Time taken: 10m

class PositionalDecoder(nn.Module):
    """Token embedding plus an additive positional term, projected to vocabulary logits.

    The class is deliberately left in its broken exercise state (Errors 1 and 2
    are active); the commented lines show two possible fixes.
    """
    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len)) # Error 1
        # Error 1 (exercise): pos_emb has shape (1, max_len), i.e. one scalar per
        # position and no embed_dim axis, so it cannot be added to the
        # (batch, seq, embed_dim) token embeddings. Zero initialisation on its own
        # would be fine for a learned parameter; the shape is the problem.
        # self.pos_emb = nn.Parameter(torch.zeros(1, max_len)) # Error 1
        # Solution #1 -> Learned positional embeddings, one embed_dim vector per position
        # self.pos_emb = nn.Embedding(max_len, embed_dim) # Learned positional embeddings

        # Solution #2 -> Fixed sinusoidal positional embeddings.
        # The frequency exponent is 2i / embed_dim for i = 0 .. embed_dim/2 - 1,
        # which is what arange(0, embed_dim, 2) / embed_dim produces.
        # div_term = 1/(10_000 ** (torch.arange(0, embed_dim, 2) / embed_dim)).unsqueeze(0) # [1, embed_dim/2]
        # positions = torch.zeros(max_len, embed_dim)
        # positions[:, 0::2] = torch.sin(torch.arange(max_len)[:, None] * div_term)
        # positions[:, 1::2] = torch.cos(torch.arange(max_len)[:, None] * div_term)
        # self.register_buffer('pos_emb', positions)

        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return vocabulary logits for token ids ``x`` (assumed to be (batch, seq))."""
        x = self.embed(x)
        # Error 2 (exercise): the slice has shape (1, seq), which only broadcasts
        # against (batch, seq, embed_dim) when seq happens to match embed_dim.
        x = x + self.pos_emb[:, :x.size(1)] # Error 2
        # x = x + self.pos_emb(torch.arange(0, x.size(1), device=x.device)) # Solution 2 (using Learned positional encodings)
        # x = x + self.pos_emb[None, :x.size(1), :] # Solution 2 (using Fixed positional encodings)
        return self.fc(x)

# Problem 3 driver. Error 3 is left active on purpose.
vocab_size = 500
embed_dim = 32
max_len = 50
model = PositionalDecoder(vocab_size, embed_dim, max_len)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

batch_size = 5
seq_len = 20
# Random token ids of shape (batch_size, seq_len); seq_len must not exceed max_len.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))
for epoch in range(3):
    optimizer.zero_grad()
    out = model(inp)
    loss = criterion(out.view(-1, vocab_size), inp.view(-1)) # Error 3 (exercise): logits and targets are not shifted
    # Solution 3: score the logits at position t against the token at position t + 1.
    # loss = criterion(out[:, :-1, :].contiguous().view(-1, vocab_size), inp[:, 1:].contiguous().view(-1)) # Solution 3
    loss.backward()
    optimizer.step()

# Tier 2 Bugs:
# 1. Missing causal mask
# 2. Missing Layernorm

In [None]:
# Problem 4 (Easy)
# Tier 1 bugs - 1
# Time taken: 4m 38s

class SimpleDecoder2(nn.Module):
    """Minimal decoder: token embedding followed by a linear projection to logits."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # embed is created before fc, matching the order they run in forward().
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary logits for token ids ``x``."""
        return self.fc(self.embed(x))

# Problem 4 driver: build the model and take two optimisation steps on one random batch.
vocab_size = 100
embed_dim = 16
model = SimpleDecoder2(vocab_size, embed_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

batch_size = 4
seq_len = 10
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))
for epoch in range(2):
    optimizer.zero_grad()
    # The forward pass must run inside the loop; without it `loss` is either
    # undefined or a stale tensor left over from an earlier cell.
    out = model(inp)
    # Error 1 (exercise): CrossEntropyLoss expects the class axis at index 1, so
    # (batch, seq, vocab) logits with (batch, seq) targets are rejected.
    # loss = criterion(out, inp) # Error 1
    # Solution 1: flatten to (batch * seq, vocab) logits and (batch * seq,) targets,
    # shifted so the logits at position t are scored against the token at t + 1.
    loss = criterion(out[:, :-1].contiguous().view(-1, vocab_size), inp[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()


In [None]:
# Problem 5 (Easy)
# Tier 1 bugs - 1
# Time taken: 2m 56s

class TeacherForcingDecoder(nn.Module):
    """Minimal decoder: token embedding followed by a linear projection to logits."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # embed is created before fc, matching the order they run in forward().
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary logits for token ids ``x``."""
        return self.fc(self.embed(x))

# Problem 5 driver. Error 1 is left active on purpose.
vocab_size = 200
embed_dim = 32
model = TeacherForcingDecoder(vocab_size, embed_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

batch_size = 3
seq_len = 7
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))

for epoch in range(3):
    optimizer.zero_grad()
    out = model(inp) # (batch, seq, vocab)
    # Error 1 (exercise): only the targets are shifted. The logits keep all
    # batch * seq = 21 rows while the targets have batch * (seq - 1) = 18 entries,
    # and the sliced target is non-contiguous, so .view(-1) on it is rejected as well.
    loss = criterion(out.view(-1, vocab_size), inp[:,1:].view(-1)) # Error 1
    # Solution 1: drop the last logit as well, and make both slices contiguous before flattening.
    # loss = criterion(out[:, :-1, :].contiguous().view(-1, vocab_size), inp[:, 1:].contiguous().view(-1)) # Solution 1
    loss.backward()
    optimizer.step()


In [None]:
# Problem 7 (Easy)

# Tier 1 bugs - 2
# Time taken - 2m 18s

class MismatchDecoder(nn.Module):
    """Minimal decoder: token embedding followed by a linear projection to logits."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        # Error 1 (exercise): an embedding width of hidden_size - 2 does not match
        # the in_features of self.fc below.
        # self.embed = nn.Embedding(vocab_size, hidden_size - 2) # Error 1
        # Solution: embed straight into hidden_size features.
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary logits for token ids ``x``."""
        return self.fc(self.embed(x))

# Problem 7 driver. Error 2 is left active on purpose.
vocab_size = 500
hidden_size = 20
model = MismatchDecoder(vocab_size, hidden_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

batch_size = 2
seq_len = 5
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))
for epoch in range(2):
    optimizer.zero_grad()
    out = model(inp)
    loss = criterion(out.view(-1, vocab_size), inp.view(-1)) # Error 2 (exercise): logits and targets are not shifted
    # Solution 2: feed every token except the last and score against every token except the first.
    # out = model(inp[:, :-1])
    # loss = criterion(out.view(-1, vocab_size), inp[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()


In [None]:
# Problem 8 (Medium)
# Tier 1 bugs - 2
# Tier 2 bugs - 2
# Time taken - 15m

class SlidingWindowDecoder(nn.Module):
    """Single-layer decoder with causal sliding-window self-attention.

    ``forward`` dispatches to ``forward_correct``. ``forward_buggy`` is kept
    unchanged as the debugging exercise.

    Args:
        vocab_size: number of token ids / output logits.
        embed_dim: width of the token embeddings and attention projections.
        window_size: window width; each query sees itself plus the
            ``window_size // 2`` positions immediately before it.
    """

    def __init__(self, vocab_size, embed_dim, window_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)
        self.window_size = window_size

    def forward_buggy(self, x):
        """Exercise version, intentionally broken; kept for reference only.

        It zero-pads the sequence on both sides and unfolds it into windows.
        ``Tensor.unfold`` appends the window axis last, so ``windows`` is
        (B, T, D, window) rather than (B, T, window, D), and the window is
        centred on each position, so it also covers a future token.
        """
        B, T = x.shape
        x = self.embed(x)
        pad = self.window_size // 2
        padded = torch.cat([x.new_zeros(B, pad, x.size(2)), x, x.new_zeros(B, pad, x.size(2))], dim=1)
        windows = padded.unfold(1, self.window_size, 1)   # (B, T, D, window)

        Q = self.q(windows)
        K = self.k(windows)
        V = self.v(windows)

        attn_scores = torch.einsum('b i j d, b i k d -> b i j k', Q, K)
        attn = F.softmax(attn_scores, dim=-1)
        context = torch.einsum('b i j k, b i k d -> b i j d', attn, V)
        return self.fc(context)

    def forward_correct(self, x):
        """Return logits of shape (B, T, vocab_size) using a banded causal mask.

        Instead of materialising windows, full (T, T) attention scores are
        computed and every key outside the causal window is masked out:
        query ``i`` attends to keys ``j`` with ``i - pad <= j <= i``, where
        ``pad = window_size // 2``. The diagonal is always visible, so no row
        is fully masked and the softmax cannot produce NaN.
        """
        _, T = x.shape
        x = self.embed(x)  # (B, T, D)
        pad = self.window_size // 2  # number of past positions a query may see

        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)

        # Scaled dot-product scores, (B, Tq, Tk).
        attn_scores = torch.einsum('b q d, b k d -> b q k', Q, K) / (x.size(-1) ** 0.5)

        # lookback[i, j] = i - j: negative for future keys, larger than pad for
        # keys that fell out of the window. Built on the input's device so the
        # mask can be combined with the scores wherever the model lives.
        positions = torch.arange(T, device=x.device)
        lookback = positions[:, None] - positions[None, :]
        hidden = (lookback < 0) | (lookback > pad)

        attn_scores = attn_scores.masked_fill(hidden, float('-inf'))
        attn = F.softmax(attn_scores, dim=-1)
        context = torch.einsum('b q k, b k d -> b q d', attn, V)
        return self.fc(context)

    def forward(self, x):
        """Run the corrected sliding-window attention path."""
        return self.forward_correct(x)

# Problem 8 driver: build the model and take two optimisation steps on one random batch.
vocab_size = 1000
embed_dim = 16
window_size = 3
model = SlidingWindowDecoder(vocab_size, embed_dim, window_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

batch_size = 2
seq_len = 5
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))
for epoch in range(2):
    optimizer.zero_grad()
    out = model(inp)
    # Error 1 (exercise): logits and targets are not shifted.
    # loss = criterion(out.view(-1, vocab_size), inp.view(-1)) # Error 1
    # Solution 1: score the logits at position t against the token at position t + 1.
    loss = criterion(out[:, :-1].contiguous().view(-1, vocab_size), inp[:, 1:].contiguous().view(-1)) # Solution 1
    loss.backward()
    optimizer.step()

# Tier 2 bugs
# 1. No positional encoding
# 2. No LayerNorm

In [None]:
# Problem 9 (Easy)
# Tier 1 bugs - 1

# Time Taken: 2m 2s

class DetachDecoder(nn.Module):
    """Minimal decoder: token embedding followed by a linear projection to logits."""

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        # embed is created before fc, matching the order they run in forward().
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return per-position vocabulary logits for token ids ``x``."""
        return self.fc(self.embed(x))

# Problem 9 driver: build the model and take three optimisation steps on one random batch.
vocab_size = 150
embed_dim = 16
model = DetachDecoder(vocab_size, embed_dim)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # fine to use SGD, but AdamW is usually preferable since it keeps momentum and second-moment estimates
criterion = nn.CrossEntropyLoss()

batch_size = 4
seq_len = 6
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))

for epoch in range(3):
    optimizer.zero_grad()
    # out = model(inp)
    # loss = criterion(out.view(-1, vocab_size), inp.view(-1))
    # Teacher forcing: feed every token except the last and score each position
    # against the NEXT token. Using inp[:, :-1] as the target would only train
    # the model to copy its own input.
    out = model(inp[:, :-1])
    loss = criterion(out.view(-1, vocab_size), inp[:, 1:].contiguous().view(-1))
    # Error 1 (exercise): detaching cuts the loss out of the autograd graph, so backward() has nothing to differentiate.
    # loss_det = loss.detach() # Error 1
    loss_det = loss # Solution 1: keep the loss attached to the computational graph
    loss_det.backward()
    optimizer.step()


In [None]:
# Problem 10 (Medium)
# Tier 1 bugs - 2
# Tier 2 bugs - 3
# Time Taken: 5m

class GlobalQueryDecoder(nn.Module):
    """Decoder whose attention scores add a first-token ("global") query term.

    For every query position the score against key ``j`` is the sum of the
    position's own query-key product and the first token's query-key product.
    ``forward`` dispatches to ``forward_correct``.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward_buggy(self, x):
        """Exercise version: unscaled scores, global term added via broadcasting."""
        hidden = self.embed(x)
        first_query = self.q(hidden[:, :1, :])   # (B, 1, D)
        queries = self.q(hidden)                 # (B, T, D)
        keys = self.k(hidden)                    # (B, T, D)
        values = self.v(hidden)                  # (B, T, D)

        # The (B, 1, T) global scores broadcast over the query axis of the
        # (B, T, T) local scores.
        global_scores = first_query @ keys.transpose(-2, -1)
        local_scores = queries @ keys.transpose(-2, -1)

        weights = (global_scores + local_scores).softmax(dim=-1)
        return self.fc(weights @ values)

    def forward_correct(self, x):
        """Scaled version with the global query expanded to every position.

        Expanding the first-token query to (B, T, D) gives each query row the
        same global scores that broadcasting would; unlike ``forward_buggy``,
        both score terms are divided by sqrt(D) before the softmax.
        """
        hidden = self.embed(x)
        _, seq_len, dim = hidden.shape
        scale = dim ** 0.5

        first_query = self.q(hidden[:, :1, :]).expand(-1, seq_len, -1)  # (B, T, D)
        queries = self.q(hidden)                 # (B, T, D)
        keys = self.k(hidden)                    # (B, T, D)
        values = self.v(hidden)                  # (B, T, D)

        global_scores = first_query @ keys.transpose(-2, -1)  # (B, T, T)
        local_scores = queries @ keys.transpose(-2, -1)       # (B, T, T)

        logits = global_scores / scale + local_scores / scale
        weights = F.softmax(logits, dim=-1)
        return self.fc(weights @ values)

    def forward(self, x):
        """Run the scaled attention path."""
        return self.forward_correct(x)

# Problem 10 driver: build the model and take two optimisation steps on one random batch.
vocab_size = 300
embed_dim = 16
model = GlobalQueryDecoder(vocab_size, embed_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

batch_size = 2
seq_len = 5
# Random token ids of shape (batch_size, seq_len); the same batch is reused every epoch.
inp = torch.randint(0, vocab_size, (batch_size, seq_len))

for epoch in range(2):
    optimizer.zero_grad()
    out = model(inp)
    # Error 2 (exercise): logits and targets are not shifted.
    # loss = criterion(out.view(-1, vocab_size), inp.view(-1)) # Error 2
    # Solution 2: score the logits at position t against the token at t + 1.
    # The shift must drop the LAST logit and the FIRST target; the opposite
    # pairing would train the model to predict the previous token.
    loss = criterion(out[:, :-1].contiguous().view(-1, vocab_size), inp[:, 1:].contiguous().view(-1)) # Solution 2
    loss.backward()
    optimizer.step()


# Tier 2 bugs:
# 1. Missing LayerNorm
# 2. Missing CausalMask
# 3. Missing positional encoding