# A toy implementation of BERT from scratch to build understanding of encoder transformer architecture and it's comparison to GPT-1.
### Papers referenced: BERT, Devlin et. al & Attention is all you need Vaswani et. al

In [35]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizerFast, BertForPreTraining

**Data Manipulation for Model Training**

Here we create the textual and positional embeddings of the data we need to train our BERT on.

In [36]:
#Use cuda if available or else stay on CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [37]:
#A pre-trained tokenizer to tokenize the input given to the model
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [38]:

"""
Total Model Parameters: 33.67M
"""
n_layers = 6
n_heads = 6
block_size = 32
batch_size = 8
n_embd = 384
dropout = 0.1
max_iters = 4500
MAX_SEQUENCE_LENGTH = 128
vocab_size = len(tokenizer.get_vocab())
print(vocab_size)

30522


In [39]:
#Dataset preparation
PATH = "/kaggle/input/shakespeare6/shakespeare.txt"
WIKI_PATH = "/kaggle/input/wikitext2-data/train.txt"
with open(WIKI_PATH, 'r', encoding='utf-8') as f:
    text = f.read()

In [40]:
import torch
import random

def create_pretraining_batch(sentences, tokenizer, batch_size=32, max_length=64, mask_prob=0.15, device="cpu"):
    """
    Create a single batch for BERT pretraining using BertTokenizerFast + BertForPreTraining.
    
    Returns a dict with:
        - input_ids
        - attention_mask
        - token_type_ids
        - labels_mlm
        - labels_nsp
    """
    
    sentence_a_batch = []
    sentence_b_batch = []
    nsp_labels = []
    i = 0
    while len(sentence_a_batch) < batch_size and i < len(sentences) - 1:
        a = sentences[i]

        if random.random() < 0.5 and i < len(sentences) - 1:
            # IsNext
            b = sentences[i + 1]
            label = 0
        else:
            # NotNext
            rand_idx = random.randint(0, len(sentences) - 1)
            while rand_idx in [i, i + 1]:
                rand_idx = random.randint(0, len(sentences) - 1)
            b = sentences[rand_idx]
            label = 1

        sentence_a_batch.append(a)
        sentence_b_batch.append(b)
        nsp_labels.append(label)
        i += 1

    labels_nsp = torch.tensor(nsp_labels, dtype=torch.long, device=device)

    # Step 2: Tokenize batch
    encoding = tokenizer(
        list(zip(sentence_a_batch, sentence_b_batch)),
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )

    input_ids = encoding["input_ids"].to(device)
    token_type_ids = encoding["token_type_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    # Step 3: Create MLM labels
    labels_mlm = input_ids.clone()
    masked_input = input_ids.clone()

    vocab_size = tokenizer.vocab_size
    mask_token_id = tokenizer.mask_token_id
    cls_id = tokenizer.cls_token_id
    sep_id = tokenizer.sep_token_id
    pad_id = tokenizer.pad_token_id

    # Initialize labels: -100 for all positions first
    labels_mlm[:] = -100

    for i in range(masked_input.size(0)):
        tokens = masked_input[i]

        # Candidate positions for masking (exclude special tokens)
        candidate_mask = (tokens != cls_id) & (tokens != sep_id) & (tokens != pad_id)
        candidate_indices = candidate_mask.nonzero(as_tuple=True)[0]

        if len(candidate_indices) == 0:
            continue

        # Number of tokens to mask
        num_to_mask = max(1, int(mask_prob * len(candidate_indices)))
        mask_indices = candidate_indices[torch.randperm(len(candidate_indices))[:num_to_mask]]

        # Fill MLM labels at masked positions
        labels_mlm[i, mask_indices] = input_ids[i, mask_indices]

        # Apply masking 80% [MASK], 10% random token, 10% original
        for idx in mask_indices:
            prob = random.random()
            if prob < 0.8:
                tokens[idx] = mask_token_id
            elif prob < 0.9:
                tokens[idx] = random.randint(0, vocab_size - 1)
            # else: leave original

        masked_input[i] = tokens

    return {
        "input_ids": masked_input,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "labels_mlm": labels_mlm,
        "labels_nsp": labels_nsp
    }


**Creating the Attention Mechanism for the BERT Model**

In [41]:
class AttentionHead(nn.Module):
        """
        The Self-Attention Head:
        Pass query, key and value through the weights initialized.
        self-attention = softmax([Q @ K^T/root(d_k)]) @ V
        Multiply Query and Key values, normalize them by the size of key, apply softmax
        to get attention weights and then perform matmul
        between the obtained attn_weights and the values.
        The output is a weighted sum of values.
        """
        def __init__(self, head_size):
            super(AttentionHead, self).__init__()
            self.w_q = nn.Linear(n_embd, head_size)
            self.w_k = nn.Linear(n_embd, head_size)
            self.w_v = nn.Linear(n_embd, head_size)
            self.dropout = nn.Dropout(dropout)
        def forward(self,x):
            q = self.w_q(x) 
            k = self.w_k(x) 
            v = self.w_v(x) 

            scores = q @ k.transpose(-2,-1)
            scores = scores * (k.size(-1)**-0.5)
            attention_weights = F.softmax(scores, dim=-1) 
            attention_weights = self.dropout(attention_weights)
            output = attention_weights @ v
            return output

In [42]:
class MultiHeadAttention(nn.Module):
    """
    Apply Self-Attention using multiple heads over the input x.
    """
    def __init__(self, num_heads, head_size):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList([AttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
    def forward(self,x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [43]:
class FeedForward(nn.Module):
    """
    The FeedForward Network takes the input x, it scales it upto 4 times the embedding dimensions.
    Applies reLU activation
    Downscale the activated input back to its original embedding.
    """
    def __init__(self, n_embd):
        super(FeedForward, self).__init__()
        self.FF = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd),
            nn.Dropout(dropout)
        )
    def forward(self,x):
        out = self.FF(x)
        return out

In [44]:
class EncoderBlock(nn.Module):
    """
    
    """
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd) 
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(x) 
        x = self.ln1(x)    

        x = x + self.ffwd(x) 
        x = self.ln2(x)    
        
        return x

In [45]:
class BERT(nn.Module):
    def __init__(self, vocab_size, n_embd, block_size, n_head, n_layer):
        super(BERT, self).__init__()
        self.token_embeddings = nn.Embedding(vocab_size, n_embd)
        self.pos_embeddings = nn.Embedding(block_size, n_embd)
        self.seg_embeddings = nn.Embedding(2, n_embd)
        self.enc_layers = nn.Sequential(
            EncoderBlock(n_embd, n_head),
            EncoderBlock(n_embd, n_head),
            EncoderBlock(n_embd, n_head),
            EncoderBlock(n_embd, n_head),
            nn.LayerNorm(n_embd)
        )
        self.mlm_head = nn.Linear(n_embd, vocab_size)
        self.nsp_head = nn.Linear(n_embd, 2)
    def forward(self,idx, segment_ids, targets=None, nsp_labels=None):
        B,T=idx.shape
        token_emb = self.token_embeddings(idx)
        pos_emb = self.pos_embeddings(torch.arange(T, device=device))
        seg_emb = self.seg_embeddings(segment_ids)
        x = token_emb+pos_emb+seg_emb
        x = self.enc_layers(x)
        mlm_logits = self.mlm_head(x)
        cls_token_hidden = x[:, 0] 
        nsp_logits = self.nsp_head(cls_token_hidden)
        total_loss, mlm_loss, nsp_loss = None, None, None
        if targets is not None:
            mlm_loss = F.cross_entropy(
                mlm_logits.view(-1, mlm_logits.size(-1)),
                targets.view(-1),
                ignore_index=-100 
            )
        if nsp_labels is not None:
            nsp_loss = F.cross_entropy(nsp_logits, nsp_labels)

        if mlm_loss is not None and nsp_loss is not None:
            total_loss = mlm_loss + nsp_loss

        return mlm_logits, nsp_logits, total_loss
        

In [46]:
model = BERT(vocab_size=vocab_size, n_embd=n_embd, block_size=MAX_SEQUENCE_LENGTH,n_head=n_heads, n_layer=n_layers)
m = model.to(device)

In [47]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

In [48]:
device = "cuda" if torch.cuda.is_available() else "cpu"

for step in range(max_iters):
    batch = create_pretraining_batch(text.splitlines(), tokenizer, batch_size=32, max_length=MAX_SEQUENCE_LENGTH)

    input_ids = batch["input_ids"].to(device)
    token_type_ids = batch["token_type_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels_mlm = batch["labels_mlm"].to(device)
    labels_nsp = batch["labels_nsp"].to(device)

    mlm_logits, nsp_logits, loss = model(
        idx=input_ids,
        segment_ids=token_type_ids,
        targets=labels_mlm,
        nsp_labels=labels_nsp
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(f"Step {step:04d} | Loss: {loss.item():.4f}")


Step 0000 | Loss: 11.1309
Step 0200 | Loss: 5.8234
Step 0400 | Loss: 4.6646
Step 0600 | Loss: 3.8559
Step 0800 | Loss: 3.0014
Step 1000 | Loss: 2.2054
Step 1200 | Loss: 1.4703
Step 1400 | Loss: 2.0974
Step 1600 | Loss: 2.5161
Step 1800 | Loss: 2.5515
Step 2000 | Loss: 2.0625
Step 2200 | Loss: 2.9385
Step 2400 | Loss: 1.3653
Step 2600 | Loss: 2.0359
Step 2800 | Loss: 1.5512
Step 3000 | Loss: 2.1846
Step 3200 | Loss: 2.4339
Step 3400 | Loss: 1.9960
Step 3600 | Loss: 1.2103
Step 3800 | Loss: 1.6484
Step 4000 | Loss: 1.9599
Step 4200 | Loss: 1.4479
Step 4400 | Loss: 1.7401


In [49]:
"""
n_layers = 4
n_heads = 4
block_size = 8
batch_size = 8
n_embd = 384
dropout = 0.1
max_iters = 4500
MAX_SEQUENCE_LENGTH = 64
fINAL LOSS: Step 4400 | Loss: 1.4863
## Trained on : Shakespeare
lr = 5e-4

PART 2:
n_layers = 6
n_heads = 6
block_size = 32
batch_size = 8
n_embd = 384
dropout = 0.1
max_iters = 4500
MAX_SEQUENCE_LENGTH = 128
vocab_size = len(tokenizer.get_vocab())
print(vocab_size)
Step 4400 | Loss: 2.2908
LR = 1e-4
## Trained on : Shakespeare

PART 3:
n_layers = 6
n_heads = 6
block_size = 32
batch_size = 8
n_embd = 384
dropout = 0.1
max_iters = 4500
MAX_SEQUENCE_LENGTH = 128
vocab_size = len(tokenizer.get_vocab())
print(vocab_size)
Step 4400 | Loss: 2.2908
LR = 1e-4
## Trained on: wiki-text2

"""

'\nn_layers = 4\nn_heads = 4\nblock_size = 8\nbatch_size = 8\nn_embd = 384\ndropout = 0.1\nmax_iters = 4500\nMAX_SEQUENCE_LENGTH = 64\nfINAL LOSS: Step 4400 | Loss: 1.4863\n## Trained on : Shakespeare\nlr = 5e-4\n\nPART 2:\nn_layers = 6\nn_heads = 6\nblock_size = 32\nbatch_size = 8\nn_embd = 384\ndropout = 0.1\nmax_iters = 4500\nMAX_SEQUENCE_LENGTH = 128\nvocab_size = len(tokenizer.get_vocab())\nprint(vocab_size)\nStep 4400 | Loss: 2.2908\nLR = 1e-4\n'

In [52]:
torch.save(model.state_dict(), "bert__wiki_weights.pth")