# BERT using PyTorch

Note: The foundation/background modules are implemented in the transformer here: [transformer-from-scratch.ipynb](https://github.com/aayush4vedi/ml-papers-to-code/blob/main/transformer-from-scratch.ipynb)

In [101]:
# 🧠 BERT from Scratch (Encoder-only Transformer + MLM Training)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
from torch.utils.data import Dataset, DataLoader

# ========== Vocabulary ==========

PAD_TOKEN = 0
MASK_TOKEN = 1
CLS_TOKEN = 2
SEP_TOKEN = 3
SPECIAL_TOKENS = ["<pad>", "<mask>", "<cls>", "<sep>"]

class Vocab:
    def __init__(self):
        self.token2id = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
        self.id2token = {i: tok for tok, i in self.token2id.items()}
        self.next_id = len(self.token2id)

    def add_sentence(self, sentence):
        for word in sentence.split():
            if word not in self.token2id:
                self.token2id[word] = self.next_id
                self.id2token[self.next_id] = word
                self.next_id += 1

    def encode(self, sentence):
        tokens = [self.token2id.get(w, PAD_TOKEN) for w in sentence.split()]
        return [CLS_TOKEN] + tokens + [SEP_TOKEN]

    def decode(self, ids):
        return " ".join([self.id2token.get(i, "<unk>") for i in ids])

    def __len__(self):
        return self.next_id

# ========== MLM Dataset ==========

def mask_tokens(tokens, mask_token=MASK_TOKEN, pad_token=PAD_TOKEN, vocab_size=None, mask_prob=0.15):
    input_ids = tokens.clone()
    labels = tokens.clone()
    mask = (torch.rand(tokens.shape) < mask_prob) & (tokens != pad_token)

    for i in range(tokens.size(0)):
        for j in range(tokens.size(1)):
            if mask[i, j]:
                prob = random.random()
                if prob < 0.8:
                    input_ids[i, j] = mask_token
                elif prob < 0.9 and vocab_size is not None:
                    input_ids[i, j] = random.randint(4, vocab_size - 1)
                # else: keep the original token
            else:
                labels[i, j] = -100  # ignore index

    return input_ids, labels

class MLMDataset(Dataset):
    def __init__(self, sentences, vocab):
        self.vocab = vocab
        self.data = [torch.tensor(vocab.encode(sent)) for sent in sentences]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_mlm_batch(batch):
    max_len = max(len(x) for x in batch)
    padded = [F.pad(x, (0, max_len - len(x)), value=PAD_TOKEN) for x in batch]
    batch_tensor = torch.stack(padded)
    return batch_tensor

# ========== Positional Encoding ==========

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# ========== BERT Encoder Block ==========

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B = Q.size(0)
        def split(x):
            return x.view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        Q, K, V = split(self.W_q(Q)), split(self.W_k(K)), split(self.W_v(V))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        x = torch.matmul(attn, V)
        x = x.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
        return self.W_o(x)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)

class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.attn(x, x, x, mask))
        x = self.norm2(x + self.ffn(x))
        return x

# ========== BERT Model ==========

class BERT(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.encoder = nn.ModuleList([
            EncoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        mask = (x != PAD_TOKEN).unsqueeze(1).unsqueeze(2)
        x = self.pos_enc(self.embed(x))
        for block in self.encoder:
            x = block(x, mask)
        return self.fc(x)

# ========== Train Loop ==========

def train_mlm_epoch(model, loader, loss_fn, optimizer, vocab_size, device):
    model.train()
    total_loss = 0
    for batch in loader:
        batch = batch.to(device)
        masked_input, labels = mask_tokens(batch.clone(), vocab_size=vocab_size)
        masked_input, labels = masked_input.to(device), labels.to(device)
        logits = model(masked_input)
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# ========== Run Training ==========

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
texts = ["i like to eat", "you are good", "hello world test"] * 300
vocab = Vocab()
for s in texts: vocab.add_sentence(s)
dataset = MLMDataset(texts, vocab)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_mlm_batch)

model = BERT(vocab_size=len(vocab), d_model=128, num_heads=4, d_ff=256, num_layers=2, max_len=64).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(1, 6):
    loss = train_mlm_epoch(model, loader, loss_fn, optimizer, len(vocab), device)
    print(f"Epoch {epoch} | Loss: {loss:.4f}")

Epoch 1 | Loss: 2.0747
Epoch 2 | Loss: 0.9880
Epoch 3 | Loss: 0.4279
Epoch 4 | Loss: 0.2949
Epoch 5 | Loss: nan
