In [4]:
import torch 
import torch.nn as nn 
import math 
import numpy as np 

In [5]:
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else "cpu"
# Bare expression: echoes the selected device as the cell output.
device

'cpu'

In [6]:
# Dimensions of the toy batch.
batch_size, seq_len = 2, 10
vocab_size = 1000

# Random token ids drawn from [0, vocab_size); every token belongs to segment 0.
input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
segment_ids = torch.zeros(batch_size, seq_len, dtype=input_ids.dtype)

In [7]:
# Show the randomly generated token ids.
print(input_ids)

tensor([[ 50, 716, 794, 493, 784, 572, 800, 875, 506, 124],
        [ 12,   2, 792, 652, 713, 960, 863, 369, 309, 218]])


In [8]:
# Show the segment ids (all zeros, same shape as input_ids).
print(segment_ids)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])


### **Positional Encoding**

In [9]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    The table is computed once in ``__init__``: even feature columns hold
    ``sin(position * div_term)`` and odd columns hold ``cos(position * div_term)``.

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # one frequency per even column: [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to the number of cosine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Non-persistent buffer: follows module.to()/.cuda() like a parameter,
        # but is not a parameter and is not written to the state_dict.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # [1, max_len, d_model]

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor indexed as [batch, seq_len, d_model].

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Follow the input's device instead of a notebook-level global.
        return x + self.pe[:, :x.size(1)].to(x.device)

In [10]:
# Token embedding table: vocab_size rows of 128-dimensional vectors.
embedding = nn.Embedding(vocab_size, 128)

In [11]:
# Look up one embedding vector per token id in the batch.
embedded = embedding(input_ids)

In [12]:
# Shape check: [batch_size, seq_len, 128].
embedded.shape

torch.Size([2, 10, 128])

In [13]:
# Sinusoidal encodings for 128-dim embeddings, precomputed for 512 positions.
pe = PositionalEncoding(128, 512)

In [14]:
# Add the position encodings to the token embeddings.
emb_pe = pe(embedded)

In [15]:
# Shape check: adding position encodings leaves the shape unchanged.
emb_pe.shape

torch.Size([2, 10, 128])

### **Multi-Head Attention**

In [16]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused Q/K/V projection.

    Args:
        d_model: Feature size of the input and output; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads; each head works on
            ``d_model // num_heads`` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # A single projection produces queries, keys and values together.
        self.qkv_linear = nn.Linear(d_model, 3 * d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x, mask = None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape [B, T, C] (batch, sequence length, channels).
            mask: Optional tensor that broadcasts against the heads-first score
                tensor [num_heads, B, T, T]. Positions where ``mask == 0`` are
                filled with -1e9 before the softmax.

        Returns:
            Tensor of shape [B, T, C].
        """
        batch, steps, channels = x.size()

        fused = self.qkv_linear(x)  # [B, T, 3*C]
        # Split channels into heads, then move the head axis to the front.
        per_head = fused.reshape(batch, steps, self.num_heads, 3 * self.d_k)
        per_head = per_head.permute(2, 0, 1, 3)
        q, k, v = torch.chunk(per_head, 3, dim=-1)  # each [num_heads, B, T, d_k]

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)

        context = weights @ v  # [num_heads, B, T, d_k]
        # Back to batch-first and merge the heads into the channel axis.
        merged = context.permute(1, 2, 0, 3).reshape(batch, steps, channels)

        return self.out_linear(merged)

In [17]:
# Self-attention over 128-dim inputs split across 8 heads (16 features per head).
attention = MultiHeadAttention(128, 8)

In [18]:
# No mask is passed, so no attention scores are masked out.
attn_output = attention(emb_pe)

In [19]:
# Compare shapes at each stage: token ids -> position-encoded embeddings -> attention output.
print(f"Input shape: {input_ids.shape}")
print(f"\nPositional Encoding output shape: {emb_pe.shape}")
print(f"\nAttention output shape: {attn_output.shape}")
# Prints the attention output tensor itself, not just its shape.
print(f"\nAttention output: {attn_output}")

Input shape: torch.Size([2, 10])

Positional Encoding output shape: torch.Size([2, 10, 128])

Attention output shape: torch.Size([2, 10, 128])

Attention output: tensor([[[-0.2319,  0.0961,  0.0783,  ...,  0.8825,  0.2695, -0.0159],
         [-0.1978,  0.1329,  0.0320,  ...,  0.8739,  0.2227, -0.0390],
         [-0.3533,  0.0251,  0.1252,  ...,  0.8176,  0.3017,  0.0048],
         ...,
         [-0.3245, -0.0046,  0.1592,  ...,  0.8660,  0.3185, -0.0040],
         [-0.2440,  0.0314,  0.1336,  ...,  0.8777,  0.3325,  0.0070],
         [-0.3664, -0.0396,  0.1701,  ...,  0.8739,  0.3205,  0.0310]],

        [[-0.4096, -0.0070,  0.3474,  ...,  0.5412,  0.1486,  0.1452],
         [-0.4538,  0.0604,  0.3740,  ...,  0.5406,  0.0938,  0.2153],
         [-0.3544,  0.0782,  0.3751,  ...,  0.5758,  0.1517,  0.1273],
         ...,
         [-0.4722, -0.0419,  0.4355,  ...,  0.5345,  0.1429,  0.1463],
         [-0.4467,  0.0085,  0.3513,  ...,  0.6026,  0.1257,  0.1807],
         [-0.3962,  0.0030,

### **Feed Forward Network**

In [20]:
class FeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size between the two linear layers.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        layers = [
            nn.Linear(d_model, d_ff),   # d_model -> d_ff
            nn.ReLU(),
            nn.Linear(d_ff, d_model),   # d_ff -> d_model
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network; the last dimension must be ``d_model``."""
        out = self.ff(x)
        return out

In [21]:
# 128 -> 512 -> 128 feed-forward network, applied to the attention output.
ff = FeedForward(128, 512)
ff_output = ff(attn_output)

In [22]:
# The feed-forward network keeps the [batch, seq_len, 128] shape of its input.
print(f"Feed Forward output shape: {ff_output.shape}")
# Prints the output tensor itself, not just its shape.
print(f"\nFeed Forward output: {ff_output}")

Feed Forward output shape: torch.Size([2, 10, 128])

Feed Forward output: tensor([[[-0.0416, -0.0623,  0.0645,  ..., -0.1878,  0.0410,  0.0408],
         [-0.0318, -0.0546,  0.0479,  ..., -0.1852,  0.0341,  0.0345],
         [-0.0352, -0.0527,  0.0350,  ..., -0.1869,  0.0435,  0.0709],
         ...,
         [-0.0364, -0.0540,  0.0355,  ..., -0.1757,  0.0371,  0.0700],
         [-0.0190, -0.0608,  0.0447,  ..., -0.1836,  0.0386,  0.0676],
         [-0.0147, -0.0628,  0.0462,  ..., -0.1825,  0.0373,  0.0671]],

        [[-0.0023, -0.0301,  0.0290,  ..., -0.1556,  0.0416,  0.0933],
         [ 0.0009, -0.0422,  0.0375,  ..., -0.1356,  0.0376,  0.0953],
         [-0.0076, -0.0373,  0.0165,  ..., -0.1505,  0.0478,  0.0827],
         ...,
         [ 0.0216, -0.0225,  0.0266,  ..., -0.1263,  0.0303,  0.0976],
         [-0.0020, -0.0524,  0.0256,  ..., -0.1433,  0.0227,  0.0753],
         [-0.0191, -0.0401,  0.0355,  ..., -0.1493,  0.0266,  0.0894]]],
       grad_fn=<ViewBackward0>)


### **Transformer Encoder Block**

In [23]:
class TransformerBlock(nn.Module):
    """Encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer's output is added to its input (residual connection) and the
    sum is layer-normalised, i.e. normalisation is applied after the residual.

    Args:
        d_model: Feature size of the inputs and outputs.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask = None):
        """Apply the block to ``x``; ``mask`` is passed unchanged to the attention layer."""
        # Sub-layer 1: attention + residual, then normalise.
        attended = self.norm1(x + self.attn(x, mask))
        # Sub-layer 2: feed-forward + residual, then normalise.
        return self.norm2(attended + self.ff(attended))

In [24]:
# NOTE: despite its name, `t_out` holds the TransformerBlock module itself,
# not an output tensor.
t_out = TransformerBlock(128, 8, 512)
# Shape check: run the block on the feed-forward output and show the result's shape.
t_out(ff_output).shape

torch.Size([2, 10, 128])

## **BERT ENCODER**

In [44]:
class BERT(nn.Module):
    """Stack of transformer blocks over token, segment and position embeddings.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / hidden feature size.
        num_heads: Attention heads per transformer block.
        d_ff: Hidden size of each block's feed-forward sub-layer.
        num_layers: Number of stacked transformer blocks.
        max_len: Number of positions precomputed by the positional encoding.
    """

    def __init__(self, vocab_size, d_model = 128, num_heads = 4, d_ff = 512, num_layers = 2, max_len = 512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Two segment ids are supported (0 and 1).
        self.segment_emb = nn.Embedding(2, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len)

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(d_model, num_heads, d_ff))
        self.encoder_layer = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, segment_ids, mask = None):
        """Encode a batch of token ids.

        Args:
            input_ids: Integer token ids used to index the token embedding table.
            segment_ids: Integer segment ids (0 or 1) used to index the segment
                embedding table; summed element-wise with the token embeddings.
            mask: Optional attention mask, passed unchanged to every block.

        Returns:
            The layer-normalised output of the last block (final embeddings).
        """
        hidden = self.pos_emb(self.token_emb(input_ids) + self.segment_emb(segment_ids))

        for block in self.encoder_layer:
            hidden = block(hidden, mask)

        return self.norm(hidden)

In [45]:
# Build the encoder with its default hyperparameters (d_model=128, 4 heads, 2 layers).
model = BERT(vocab_size)

In [46]:
# Print a per-layer parameter count for the model.
from torchinfo import summary
summary(model)

Layer (type:depth-idx)                   Param #
BERT                                     --
├─Embedding: 1-1                         128,000
├─Embedding: 1-2                         256
├─PositionalEncoding: 1-3                --
├─ModuleList: 1-4                        --
│    └─TransformerBlock: 2-1             --
│    │    └─MultiHeadAttention: 3-1      66,048
│    │    └─FeedForward: 3-2             131,712
│    │    └─LayerNorm: 3-3               256
│    │    └─LayerNorm: 3-4               256
│    └─TransformerBlock: 2-2             --
│    │    └─MultiHeadAttention: 3-5      66,048
│    │    └─FeedForward: 3-6             131,712
│    │    └─LayerNorm: 3-7               256
│    │    └─LayerNorm: 3-8               256
├─LayerNorm: 1-5                         256
Total params: 525,056
Trainable params: 525,056
Non-trainable params: 0

In [47]:
# Forward pass on the toy batch; no attention mask is supplied.
output = model(input_ids, segment_ids)

In [48]:
# Shape check: one d_model-sized vector per input token.
print(output.shape)

torch.Size([2, 10, 128])


### **Masked Language Modeling (MLM) Training**

In [49]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from datasets import load_dataset
import random

  from .autonotebook import tqdm as notebook_tqdm


In [50]:
# Load the tokenizer published under the "bert-base-uncased" checkpoint name.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [51]:
# Load the first 1% of the WikiText-2 (raw) training split to keep the demo small.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")  # small portion for demo

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 205728.56 examples/s]
Generating train split: 100%|██████████| 36718/36718 [00:00<00:00, 1660160.56 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 231420.06 examples/s]


In [52]:
# Bare expression: display the dataset's features and row count.
dataset

Dataset({
    features: ['text'],
    num_rows: 367
})

In [54]:
# Tokenize dataset
def tokenize(example):
    """Tokenize the ``'text'`` field of ``example`` with the notebook's ``tokenizer``.

    Sequences are truncated and padded to a fixed length of 128.

    Args:
        example: Mapping with a ``'text'`` entry (a single example or a batch).

    Returns:
        Whatever the tokenizer call returns for that text.
    """
    encoding_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 128,
    }
    return tokenizer(example["text"], **encoding_options)

In [55]:
# Tokenize the dataset; batched=True passes batches of examples to `tokenize`.
tokenized = dataset.map(tokenize, batched=True)
# Request 'input_ids' and 'attention_mask' in torch format when items are read.
tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])

Map: 100%|██████████| 367/367 [00:00<00:00, 2016.49 examples/s]


In [79]:
# Mask tokens for MLM
def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    """Prepare masked-language-modelling inputs and labels.

    Each non-special token is selected with probability ``mlm_probability``.
    Of the selected tokens, about 80% are replaced by the mask token, about 10%
    by a random token id, and the rest are left unchanged.

    Args:
        inputs: 1-D (single sequence) or 2-D (batch) tensor of token ids.
            **Modified in place**; pass a clone to keep the original.
        tokenizer: Object providing ``get_special_tokens_mask``,
            ``convert_tokens_to_ids``, ``mask_token`` and ``__len__``.
        mlm_probability: Selection probability for each non-special token.

    Returns:
        ``(inputs, labels)``: the corrupted ids (same object as ``inputs``), and
        a tensor holding the original id at selected positions and -100 elsewhere.
    """
    labels = inputs.clone()
    is_single_sequence = inputs.dim() == 1

    # get_special_tokens_mask handles one sequence at a time, so present a
    # single sequence as a one-row batch.
    rows = [labels.tolist()] if is_single_sequence else labels.tolist()
    special_tokens_mask = torch.tensor(
        [
            tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
            for row in rows
        ],
        dtype=torch.bool,
    )
    if is_single_sequence:
        # Drop the artificial batch axis again.
        special_tokens_mask = special_tokens_mask.squeeze(0)

    # Special tokens get selection probability 0, so they are never chosen.
    selection_prob = torch.full(labels.shape, mlm_probability)
    selection_prob.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(selection_prob).bool()
    labels[~masked_indices] = -100  # only compute loss on masked tokens

    # 80% of the selected positions become the mask token.
    use_mask_token = masked_indices & torch.bernoulli(torch.full(labels.shape, 0.8)).bool()
    inputs[use_mask_token] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # Half of the remaining selected positions (10% overall) get a random token.
    use_random_token = (
        masked_indices
        & ~use_mask_token
        & torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
    )
    replacement_ids = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[use_random_token] = replacement_ids[use_random_token]

    # The remaining selected positions keep their original token.
    return inputs, labels

In [80]:
from torch.utils.data import Dataset, DataLoader

class MLMDataset(torch.utils.data.Dataset):
    """Wrap a tokenized dataset so that MLM masking is applied on every access.

    Args:
        tokenized_dataset: Indexable collection whose items provide
            ``'input_ids'`` and ``'attention_mask'`` tensors.
    """

    def __init__(self, tokenized_dataset):
        self.dataset = tokenized_dataset

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one example with masking freshly applied.

        Masking runs on a clone of the stored ids, so the wrapped dataset is not
        modified. It uses the notebook-level ``tokenizer`` and ``mask_tokens``.
        """
        example = self.dataset[idx]
        masked_ids, targets = mask_tokens(example["input_ids"].clone(), tokenizer)
        return dict(
            input_ids=masked_ids,
            attention_mask=example["attention_mask"],
            labels=targets,
        )

In [81]:
# Wrap the tokenized dataset so each item is masked when it is read.
mlm_dataset = MLMDataset(tokenized)
# Shuffled mini-batches of 16 examples.
dataloader = DataLoader(mlm_dataset, batch_size=16, shuffle=True)

In [82]:
from transformers import BertConfig, BertForMaskedLM

# Define the configuration for a BERT model built from a config rather than
# loaded with from_pretrained: 4 layers, 4 heads, 256-dim hidden states.
# NOTE(review): intermediate_size is not set, so the library default applies.
# In Hugging Face's BertConfig that default is presumably 3072 (12x hidden_size
# here) - confirm, and consider setting it explicitly (e.g. 4 * hidden_size).
# NOTE(review): vocab_size uses tokenizer.vocab_size while mask_tokens samples
# random ids from len(tokenizer); these presumably agree for this tokenizer but
# could differ if tokens are added - confirm.
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=512,
    type_vocab_size=2
)

In [83]:
# NOTE: this rebinds `model`, replacing the BERT instance created earlier in the notebook.
model = BertForMaskedLM(config)
# NOTE: this rebinds `device` from a plain string to a torch.device object.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Trailing semicolon suppresses the module repr in the cell output.
model.to(device);

#### **Train the model**

In [89]:
from torch.optim import AdamW
from tqdm import tqdm

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
epochs = 50

for epoch in range(epochs):
    # Label the bar when it is created so it is named before the first batch finishes.
    loop = tqdm(dataloader, desc=f"Epoch {epoch}", leave=True)
    running_loss = 0.0
    for step, batch in enumerate(loop, start=1):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Passing labels makes the model return the loss as outputs.loss.
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Clear gradients *before* backward: if a previous run of this cell was
        # interrupted between backward() and zero_grad(), stale gradients would
        # otherwise be added to the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The per-batch loss is noisy; also show the mean over the epoch so far.
        batch_loss = loss.item()
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss, avg_loss=running_loss / step)

  0%|          | 0/23 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 23/23 [00:52<00:00,  2.26s/it, loss=6.82]
Epoch 1: 100%|██████████| 23/23 [00:52<00:00,  2.30s/it, loss=6.51]
Epoch 2: 100%|██████████| 23/23 [00:51<00:00,  2.25s/it, loss=6.61]
Epoch 3: 100%|██████████| 23/23 [00:52<00:00,  2.28s/it, loss=6.58]
Epoch 4: 100%|██████████| 23/23 [00:51<00:00,  2.26s/it, loss=6.15]
Epoch 5: 100%|██████████| 23/23 [00:54<00:00,  2.35s/it, loss=6.13]
Epoch 6: 100%|██████████| 23/23 [00:52<00:00,  2.29s/it, loss=6.36]
Epoch 7: 100%|██████████| 23/23 [00:54<00:00,  2.37s/it, loss=5.93]
Epoch 8: 100%|██████████| 23/23 [00:54<00:00,  2.37s/it, loss=6.32]
Epoch 9: 100%|██████████| 23/23 [00:51<00:00,  2.26s/it, loss=5.97]
Epoch 10: 100%|██████████| 23/23 [00:52<00:00,  2.30s/it, loss=7.01]
Epoch 11: 100%|██████████| 23/23 [00:53<00:00,  2.32s/it, loss=6.09]
Epoch 12: 100%|██████████| 23/23 [00:52<00:00,  2.28s/it, loss=6.96]
Epoch 13: 100%|██████████| 23/23 [00:52<00:00,  2.28s/it, loss=5.81]
Epoch 14: 100%|██████████| 23/23 [00:52<00:0

KeyboardInterrupt: 

In [None]:
def predict_masked(model, tokenizer, text):
    """Predict the most likely token for every mask placeholder in ``text``.

    Args:
        model: Masked-language model; calling it with the tokenizer's output
            must return an object with a ``logits`` attribute indexed as
            [batch, position, vocabulary]. The model is switched to eval mode.
        tokenizer: Tokenizer providing ``mask_token_id`` and ``decode``.
        text: Input string containing at least one mask token.

    Returns:
        The decoded prediction as a string. With several mask tokens, the
        predictions are decoded together in order of appearance.

    Raises:
        ValueError: If ``text`` contains no mask token.
    """
    model.eval()
    inputs = tokenizer(text, return_tensors="pt")

    # Run on the device the model lives on rather than a notebook-level global,
    # so the call works wherever the model was moved.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        inputs = inputs.to(first_param.device)

    mask_positions = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
    if mask_positions.numel() == 0:
        raise ValueError("text must contain at least one mask token")

    with torch.no_grad():
        logits = model(**inputs).logits

    # One argmax per masked position; .tolist() handles one or many masks,
    # whereas .item() only works for exactly one.
    predicted_ids = logits[0, mask_positions].argmax(dim=-1).tolist()
    return tokenizer.decode(predicted_ids)

# Qualitative check: ask the model to fill in a single masked word.
masked_text = "Paris is the [MASK] of France."
predicted = predict_masked(model, tokenizer, masked_text)
print(f"Prediction: {predicted}")

Prediction: the
