<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/debugging/nano_GPT_debugging_problems.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Problem 1
# Difficulty Medium

import torch
import torch.nn as nn
import torch.optim as optim

# MultiHeadSelfAttention
# Bugs: 3
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

        nn.init.zeros_(self.q_linear.weight)
        nn.init.zeros_(self.k_linear.weight)
        nn.init.zeros_(self.v_linear.weight)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        q = self.q_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('inf'))

        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        out = self.out_linear(out)
        return out

# FeedForward
# Bugs: 1
class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

# DecoderLayer
# Bugs: 2
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_out = self.self_attn(x, mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        ff_out = self.ff(x)
        x = x + self.dropout(ff_out)
        x = self.norm2(x)
        return x

# PositionalEncoding
# Bugs: 1
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

# TransformerLM
# Bugs: 2
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_dim, max_len=5000, dropout=0.1):
        super(TransformerLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_enc = PositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList([DecoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])
        self.out_linear = nn.Linear(embed_dim, vocab_size + 1)  # Off-by-one vocab size
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, seq_len):
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) == 0
        return mask.unsqueeze(0).unsqueeze(1)  # For batch and heads

    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_enc(x)
        x = self.dropout(x)
        seq_len = x.size(1)
        mask = self.generate_mask(seq_len).to(x.device)
        for layer in self.layers:
            x = layer(x, mask)
        out = self.out_linear(x)
        return out

# Data preparation
vocab_size = 100
batch_size = 32
seq_len = 20
data = torch.randint(0, vocab_size, (1000, seq_len))  # Simple random data

# Model instantiation
embed_dim = 256
num_heads = 4
num_layers = 2
ff_dim = 512
model = TransformerLM(vocab_size, embed_dim, num_heads, num_layers, ff_dim)

# Train loop
# Bugs: 2
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        inputs = batch
        targets = batch  # No shift for next-token

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, vocab_size + 1), targets.view(-1))
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

In [None]:
# Problem 2
# Difficulty Medium

import torch
import torch.nn as nn
import torch.optim as optim

# MultiHeadSelfAttention
# Bugs: 2
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv_linear = nn.Linear(embed_dim, embed_dim * 3)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        qkv = self.qkv_linear(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:,:,0].transpose(1, 2), qkv[:,:,1].transpose(1, 2), qkv[:,:,2].transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            scores += mask * -1e9  # Add instead of masked_fill for broadcasting test

        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        out = self.out_linear(out)
        return out

# FeedForward
# Bugs: 1
class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.linear2(self.gelu(self.linear1(x)))

# DecoderLayer
# Bugs: 2
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_out = self.self_attn(x, mask)
        x = self.norm1(attn_out)  # No residual add
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x

# LearnablePositionalEncoding
# Bugs: 1
class LearnablePositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(LearnablePositionalEncoding, self).__init__()
        self.pe = nn.Parameter(torch.randn(1, max_len, embed_dim))

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

# TransformerLM
# Bugs: 2
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_dim, max_len=5000, dropout=0.1):
        super(TransformerLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_enc = LearnablePositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList([DecoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])
        self.out_linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, seq_len):
        mask = torch.triu(torch.ones(seq_len, seq_len))  # No ==0, broadcasts as is
        return mask.unsqueeze(0).unsqueeze(0).expand(-1, self.layers[0].self_attn.num_heads, -1, -1)

    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_enc(x)
        x = self.dropout(x)
        seq_len = x.size(1)
        mask = self.generate_mask(seq_len).to(x.device)
        for layer in self.layers:
            x = layer(x, mask)
        out = self.out_linear(x)
        return out

# Data preparation
vocab_size = 50
batch_size = 16
seq_len = 15
data = torch.randint(0, vocab_size, (500, seq_len))  # Random data

# Model instantiation
embed_dim = 128
num_heads = 4
num_layers = 3
ff_dim = 256
model = TransformerLM(vocab_size, embed_dim, num_heads, num_layers, ff_dim)

# Train loop
# Bugs: 3
optimizer = optim.AdamW(model.parameters(), lr=0.0005)
criterion = nn.CrossEntropyLoss(ignore_index=-1)  # Wrong ignore

for epoch in range(5):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        inputs = batch[:, :-1]
        targets = batch  # No shift, wrong shape

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Generation loop (autoregressive)
# Bugs: 2
def generate(model, start_token, max_len=10, device='cpu'):
    model.eval()
    input = torch.tensor([[start_token]]).to(device)
    for _ in range(max_len):
        with torch.no_grad():
            output = model(input)
            next_token = output.argmax(dim=-1)  # No [:, -1]
            input = torch.cat([input, next_token], dim=1)
    return input.squeeze(0).tolist()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
generated = generate(model, 0, seq_len, device)
print("Generated:", generated)

In [None]:
# Problem 3
# Difficult Easy

import torch
import torch.nn as nn
import torch.optim as optim

# MultiHeadSelfAttention
# Bugs: 3
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        q = self.q_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(~mask, -1e9)  # ~mask inverts, potential broadcast issue

        attn = torch.softmax(scores, dim=-2)  # Wrong dim
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.embed_dim)
        out = self.out_linear(out)
        return out

# FeedForward
# Bugs: 1
class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, embed_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

# DecoderLayer
# Bugs: 2
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, ff_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        x = self.norm1(x)
        attn_out = self.self_attn(x, mask)
        x = x + self.dropout(attn_out)  # Residual after norm
        x = self.norm2(x)
        ff_out = self.ff(x)
        x = x + self.dropout(ff_out)
        return x

# SinusoidalPositionalEncoding
# Bugs: 1
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(SinusoidalPositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / embed_dim))
        pe[:, 0::2] = torch.sin(position / div_term)  # / instead of *
        pe[:, 1::2] = torch.cos(position / div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# TransformerLM
# Bugs: 2
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_dim, max_len=5000, dropout=0.1):
        super(TransformerLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len)
        self.layers = nn.ModuleList([DecoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])
        self.out_linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, seq_len):
        mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)  # tril instead of triu
        return mask

    def forward(self, x):
        x = self.embedding(x)
        x = self.pos_enc(x)
        x = self.dropout(x)
        seq_len = x.size(1)
        mask = self.generate_mask(seq_len).to(x.device)
        for layer in self.layers:
            x = layer(x, mask)
        out = self.out_linear(x)
        return out

# Data preparation
vocab_size = 200
batch_size = 64
seq_len = 30
data = torch.randint(0, vocab_size, (2000, seq_len))  # Random data

# Model instantiation
embed_dim = 512
num_heads = 8
num_layers = 4
ff_dim = 1024
model = TransformerLM(vocab_size, embed_dim, num_heads, num_layers, ff_dim)

# Train loop
# Bugs: 3
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(15):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        inputs = batch
        targets = batch[:, 1:]  # Shift but no pad for last token

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs[:, :-1].reshape(-1, vocab_size), targets.view(-1))  # reshape instead of contiguous.view
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Generation loop (autoregressive)
# Bugs: 2
def generate(model, start_token, max_len=20, device='cpu'):
    model.eval()
    input = torch.tensor([start_token]).unsqueeze(0).to(device)  # No batch dim properly
    for _ in range(max_len):
        with torch.no_grad():
            output = model(input)
            next_token = output[:, -1, :].argmax(dim=-1).unsqueeze(1)
            input = torch.cat([input, next_token], dim=1)
    return input.squeeze(0).tolist()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
generated = generate(model, torch.tensor([0]), seq_len, device)
print("Generated:", generated)

In [None]:
# Problem 4
# Difficult Hard

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.qkv.weight)
        nn.init.constant_(self.out.bias, 0.)

    def forward(self, x, mask=None, return_attn=False):
        B, T, C = x.shape

        qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        if mask is not None:
            mask = mask.unsqueeze(1)
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = attn @ v
        out = out.transpose(1, 2).reshape(B, T, C)
        out = self.out(out)

        if return_attn:
            return out, attn
        return out

class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.ff = FeedForward(embed_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        residual = x
        x = self.norm1(x)
        x = self.attn(x, mask)
        x = self.dropout(x) + residual

        residual = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.dropout(x) + residual

        return x

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000, base=10000):
        super().__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() *
                           (-math.log(base) / embed_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[:pe[:, 1::2].shape[1]])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x * math.sqrt(x.size(-1)) + self.pe[:, :x.size(1)]

class TransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ff_dim,
                 max_len=512, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_enc = SinusoidalPositionalEncoding(embed_dim, max_len)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            DecoderLayer(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size, bias=False)

        self.register_buffer('mask_cache', torch.empty(0))
        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.fc.weight, mean=0.0, std=0.02)

    def _generate_square_subsequent_mask(self, sz):
        if self.mask_cache.size(0) >= sz:
            return self.mask_cache[:sz, :sz]

        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        mask = mask.masked_fill(mask == 0, float(0.0))
        self.mask_cache = mask
        return mask

    def forward(self, x, mask=None):
        seq_len = x.size(1)

        if mask is None:
            device = x.device
            mask = self._generate_square_subsequent_mask(seq_len).to(device)

        x = self.embed(x) * math.sqrt(self.embed_dim)
        x = self.pos_enc(x)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.norm(x)
        x = self.fc(x)

        return x

def create_batch_mask(lengths, max_len):
    batch_size = len(lengths)
    mask = torch.zeros(batch_size, max_len, dtype=torch.bool)
    for i, length in enumerate(lengths):
        mask[i, :length] = 1
    return mask

vocab_size = 1000
batch_size = 32
seq_len = 128
num_epochs = 10

data = torch.randint(0, vocab_size, (500, seq_len))
lengths = torch.randint(seq_len//2, seq_len, (500,))

model = TransformerLM(vocab_size, 256, 8, 4, 1024, max_len=256, dropout=0.1)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    total_loss = 0
    num_batches = 0

    for i in range(0, len(data), batch_size):
        batch_data = data[i:min(i+batch_size, len(data))]
        batch_lengths = lengths[i:min(i+batch_size, len(lengths))]

        inputs = batch_data.to(device)
        targets = inputs.clone()
        targets[:, :-1] = inputs[:, 1:]
        targets[:, -1] = -100

        for j, length in enumerate(batch_lengths):
            targets[j, length:] = -100

        optimizer.zero_grad()
        outputs = model(inputs)

        loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    scheduler.step()
    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

@torch.no_grad()
def generate(model, start_tokens, max_len=50, temperature=1.0):
    model.eval()
    device = next(model.parameters()).device
    tokens = start_tokens.to(device)

    for _ in range(max_len):
        outputs = model(tokens)
        logits = outputs[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)
        tokens = torch.cat([tokens, next_token], dim=1)

        if next_token.item() == 0:
            break

    return tokens

start = torch.randint(1, vocab_size, (1, 1))
generated = generate(model, start, max_len=30)
print(f"Generated: {generated.tolist()[0]}")

In [None]:
# Problem 5
# Difficulty Easy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# MHA
# Bugs: 3
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim=embed_dim; self.num_heads=num_heads
        self.head_dim=embed_dim//num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim*3)
        self.proj = nn.Linear(embed_dim, embed_dim)
    def forward(self,x,mask=None):
        B,T,C=x.shape
        qkv=self.qkv(x).view(B,T,3,self.num_heads,self.head_dim)
        q,k,v=qkv[:,:,0].transpose(1,2), qkv[:,:,1].transpose(1,2), qkv[:,:,2].transpose(1,2)
        scores=q@k.transpose(-2,-1)/(self.head_dim**0.5)
        if mask is not None:
            scores = scores.masked_fill(mask==0,0)   # should be -inf
        attn=F.softmax(scores,dim=-2)   # wrong dim
        out=attn@v
        out=out.transpose(1,2).reshape(B,T,C)
        return self.proj(out)

# FFN
# Bugs: 1
class FeedForward(nn.Module):
    def __init__(self,embed_dim,ff_dim):
        super().__init__()
        self.l1=nn.Linear(embed_dim,ff_dim)
        self.l2=nn.Linear(ff_dim,embed_dim)
    def forward(self,x):
        return self.l2(self.l1(x))   # missing nonlinearity

# Decoder
# Bugs: 2
class DecoderLayer(nn.Module):
    def __init__(self,embed_dim,num_heads,ff_dim,dropout=0.1):
        super().__init__()
        self.attn=MultiHeadSelfAttention(embed_dim,num_heads)
        self.ff=FeedForward(embed_dim,ff_dim)
        self.norm1=nn.LayerNorm(embed_dim)
        self.norm2=nn.LayerNorm(embed_dim)
        self.drop=nn.Dropout(dropout)
    def forward(self,x,mask):
        x=self.norm1(x)    # pre-norm but missing residual
        x=self.attn(x,mask)+x   # residual in wrong spot
        ff=self.ff(x)
        x=self.norm2(x)+self.drop(ff)   # wrong order
        return x

# Sinusoidal PE
# Bugs: 1
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self,embed_dim,max_len=5000):
        super().__init__()
        pe=torch.zeros(max_len,embed_dim)
        pos=torch.arange(0,max_len).unsqueeze(1).float()
        div=torch.exp(torch.arange(0,embed_dim,2).float()*-(torch.log(torch.tensor(10000.0))/embed_dim))
        pe[:,0::2]=torch.sin(pos/div)   # wrong formula (should be pos*div)
        pe[:,1::2]=torch.cos(pos/div)
        pe=pe.unsqueeze(0)
        self.register_buffer("pe",pe)
    def forward(self,x): return x+self.pe[:,:x.size(1)]

# TransformerLM
# Bugs: 2
class TransformerLM(nn.Module):
    def __init__(self,vocab_size,embed_dim,num_heads,num_layers,ff_dim,max_len=512):
        super().__init__()
        self.embed=nn.Embedding(vocab_size,embed_dim)
        self.pos=SinusoidalPositionalEncoding(embed_dim,max_len)
        self.layers=nn.ModuleList([DecoderLayer(embed_dim,num_heads,ff_dim) for _ in range(num_layers)])
        self.out=nn.Linear(embed_dim,vocab_size)
    def mask(self,T):
        return torch.tril(torch.ones(T,T))==1   # correct direction but bool only
    def forward(self,x):
        x=self.embed(x)
        x=self.pos(x)
        m=self.mask(x.size(1)).to(x.device)
        for l in self.layers: x=l(x,m)
        return self.out(x)

# Data & training
vocab=80; seq=25; batch=16
data=torch.randint(0,vocab,(400,seq))
model=TransformerLM(vocab,128,4,2,256)
opt=optim.Adam(model.parameters(),lr=1e-3)
loss_fn=nn.CrossEntropyLoss()

for e in range(3):
    for i in range(0,len(data),batch):
        inp=data[i:i+batch]
        tgt=inp[:,1:]   # shift but input not shifted
        out=model(inp)
        loss=loss_fn(out[:,:-1].reshape(-1,vocab),tgt.reshape(-1))
        opt.zero_grad(); loss.backward(); opt.step()
    print(e,loss.item())

# Generation
def generate(model,start,max_len=15):
    model.eval(); x=torch.tensor([[start]])
    for _ in range(max_len):
        with torch.no_grad():
            out=model(x)
            nxt=out.argmax(-1)   # no [:,-1]
            x=torch.cat([x,nxt],dim=1)
    return x
print("Sample:",generate(model,0,10))