In [1]:
import torch
import torch.nn as nn
import math
import numpy as np

In [2]:
device = 'cuda' if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
batch_size = 2
seq_len = 10
vocab_size = 1000

input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
segment_ids = torch.zeros_like(input_ids)

In [4]:
print(input_ids)

tensor([[541, 491, 792, 511, 165, 300, 219, 116, 372,  47],
        [772,  23,  15, 847, 558,  44, 452, 171, 598, 578]])


In [5]:
print(segment_ids)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])


### **Positional Encoding**

In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len = 512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        postion = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(postion * div_term)
        pe[:, 1::2] = torch.cos(postion * div_term)
        self.pe = pe.unsqueeze(0) # [1, max_len, d_model]

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].to(device)
        return x

In [7]:
embedding = nn.Embedding(vocab_size, 128)

In [13]:
embedded = embedding(input_ids).to(device)

In [14]:
embedded.shape

torch.Size([2, 10, 128])

In [15]:
pe = PositionalEncoding(128, 512)

In [16]:
emb_pe = pe(embedded)

In [17]:
emb_pe.shape

torch.Size([2, 10, 128])

### **Multi-Head Attention**

In [18]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.qkv_linear = nn.Linear(d_model, d_model * 3)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, x, mask = None):
        # B = Batch size, T = Seq len (Time steps), C = Channel Dimension, Eg: (32, 512, 768)
        B, T, C = x.size()
        qkv = self.qkv_linear(x) # [B, T, 3*C]
        qkv = qkv.reshape(B, T, self.num_heads, 3 * self.d_k).permute(2, 0, 1, 3)
        q, k, v = qkv.chunk(3, dim = -1) # Each: [num_heads, B, T, d_k]

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim = -1)

        attn_output = torch.matmul(attn_probs, v) # [num_heads, B, T, d_k]
        attn_output = attn_output.permute(1, 2, 0, 3).reshape(B, T, C)

        return self.out_linear(attn_output)

In [23]:
attention = MultiHeadAttention(128, 8).to(device)

In [24]:
attn_output = attention(emb_pe)

In [25]:
print(f"Input shape: {input_ids.shape}")
print(f"\nPositional Encoding output shape: {emb_pe.shape}")
print(f"\nAttention output shape: {attn_output.shape}")
print(f"\nAttention output: {attn_output}")

Input shape: torch.Size([2, 10])

Positional Encoding output shape: torch.Size([2, 10, 128])

Attention output shape: torch.Size([2, 10, 128])

Attention output: tensor([[[-6.5324e-02, -7.7998e-02,  2.5061e-01,  ...,  5.8333e-02,
           3.1619e-02,  2.1055e-01],
         [-2.0257e-03, -5.8401e-02,  2.5697e-01,  ...,  1.9608e-02,
           1.1479e-01,  1.5606e-01],
         [-9.2151e-02, -5.0112e-02,  1.8911e-01,  ...,  2.7027e-02,
           8.1677e-02,  1.5180e-01],
         ...,
         [-1.3789e-04, -3.2700e-02,  2.1351e-01,  ...,  9.3581e-02,
           7.5123e-02,  1.8809e-01],
         [-2.4918e-02, -4.4362e-02,  2.1374e-01,  ...,  1.0541e-01,
           6.4649e-02,  1.3687e-01],
         [-1.7939e-02, -2.8720e-02,  1.9165e-01,  ...,  5.9281e-02,
          -2.3819e-02,  1.7880e-01]],

        [[-5.7902e-02,  1.2457e-01,  1.4111e-01,  ...,  2.0201e-01,
          -2.6840e-01,  1.9771e-01],
         [-6.5448e-02,  9.0877e-02,  1.5637e-01,  ...,  2.4717e-01,
          -1.3947e-

### **Feed Forward Network**

In [26]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.ff(x)

In [28]:
ff = FeedForward(128, 512).to(device)
ff_output = ff(attn_output)

In [29]:
print(f"Feed Forward output shape: {ff_output.shape}")
print(f"\nFeed Forward output: {ff_output}")

Feed Forward output shape: torch.Size([2, 10, 128])

Feed Forward output: tensor([[[ 0.1877, -0.0428, -0.0164,  ..., -0.0086, -0.0311,  0.0375],
         [ 0.1833, -0.0252,  0.0047,  ...,  0.0034, -0.0245,  0.0504],
         [ 0.1900, -0.0169, -0.0023,  ..., -0.0194, -0.0372,  0.0259],
         ...,
         [ 0.1930, -0.0229, -0.0199,  ..., -0.0052, -0.0354,  0.0431],
         [ 0.1814, -0.0200, -0.0168,  ..., -0.0088, -0.0326,  0.0217],
         [ 0.1758, -0.0322, -0.0145,  ..., -0.0144, -0.0503,  0.0306]],

        [[ 0.1706, -0.0529, -0.0055,  ..., -0.0255, -0.1066,  0.0440],
         [ 0.1818, -0.0448, -0.0061,  ..., -0.0303, -0.0932,  0.0502],
         [ 0.1761, -0.0322, -0.0058,  ..., -0.0444, -0.1125,  0.0459],
         ...,
         [ 0.1649, -0.0473, -0.0067,  ..., -0.0308, -0.1094,  0.0317],
         [ 0.1559, -0.0470, -0.0181,  ..., -0.0253, -0.1057,  0.0284],
         [ 0.1610, -0.0459, -0.0148,  ..., -0.0230, -0.1073,  0.0448]]],
       device='cuda:0', grad_fn=<ViewBackw

### **Transformer Encoder Block**

In [30]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask = None):
        attn_out = self.attn(x, mask)

        x = self.norm1(x + attn_out)

        ff_out = self.ff(x)

        x = self.norm2(x + ff_out)

        return x

In [32]:
t_out = TransformerBlock(128, 8, 512).to(device)
t_out(ff_output).shape

torch.Size([2, 10, 128])

## **BERT ENCODER**

In [33]:
class BERT(nn.Module):
    def __init__(self, vocab_size, d_model = 128, num_heads = 4, d_ff = 512, num_layers = 2, max_len = 512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.segment_emb = nn.Embedding(2, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len)
        self.encoder_layer = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, segment_ids, mask = None):
        x = self.token_emb(input_ids) + self.segment_emb(segment_ids)

        x = self.pos_emb(x)

        for layer in self.encoder_layer:
            x = layer(x, mask)

        return self.norm(x) # Final Embeddings

In [39]:
model = BERT(vocab_size).to(device)

In [36]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [40]:
from torchinfo import summary
summary(model)

Layer (type:depth-idx)                   Param #
BERT                                     --
├─Embedding: 1-1                         128,000
├─Embedding: 1-2                         256
├─PositionalEncoding: 1-3                --
├─ModuleList: 1-4                        --
│    └─TransformerBlock: 2-1             --
│    │    └─MultiHeadAttention: 3-1      66,048
│    │    └─FeedForward: 3-2             131,712
│    │    └─LayerNorm: 3-3               256
│    │    └─LayerNorm: 3-4               256
│    └─TransformerBlock: 2-2             --
│    │    └─MultiHeadAttention: 3-5      66,048
│    │    └─FeedForward: 3-6             131,712
│    │    └─LayerNorm: 3-7               256
│    │    └─LayerNorm: 3-8               256
├─LayerNorm: 1-5                         256
Total params: 525,056
Trainable params: 525,056
Non-trainable params: 0

In [43]:
input_ids = input_ids.to(device)
segment_ids = segment_ids.to(device)

In [44]:
output = model(input_ids, segment_ids)

In [45]:
print(output.shape)

torch.Size([2, 10, 128])


### **Masked Language Modeling (MLM) Training**

In [46]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
from datasets import load_dataset
import random

In [47]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [48]:
# Load dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")  # small portion for demo

README.md: 0.00B [00:00, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [49]:
dataset

Dataset({
    features: ['text'],
    num_rows: 367
})

In [50]:
# Tokenize dataset
def tokenize(example):
    return tokenizer(example['text'], truncation=True, padding="max_length", max_length=128)

In [51]:
tokenized = dataset.map(tokenize, batched=True)
tokenized.set_format(type='torch', columns=['input_ids', 'attention_mask'])

Map:   0%|          | 0/367 [00:00<?, ? examples/s]

In [52]:
# Mask tokens for MLM
def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability)

    # Handle both single sequences and batches
    if inputs.dim() == 1:
        # Single sequence - convert to list of lists
        labels_list = [labels.tolist()]
    else:
        # Batch of sequences
        labels_list = labels.tolist()

    special_tokens_mask = [
        tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels_list
    ]
    special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)

    # If it was a single sequence, squeeze the mask back to 1D
    if inputs.dim() == 1:
        special_tokens_mask = special_tokens_mask.squeeze(0)

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # only compute loss on masked tokens

    # Replace 80% with [MASK]
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # Replace 10% with random token
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    return inputs, labels

In [53]:
from torch.utils.data import Dataset, DataLoader

class MLMDataset(torch.utils.data.Dataset):
    def __init__(self, tokenized_dataset):
        self.dataset = tokenized_dataset

    def __getitem__(self, idx):
        item = self.dataset[idx]
        input_ids, labels = mask_tokens(item['input_ids'].clone(), tokenizer)
        return {
            "input_ids": input_ids,
            "attention_mask": item["attention_mask"],
            "labels": labels
        }

    def __len__(self):
        return len(self.dataset)

In [54]:
mlm_dataset = MLMDataset(tokenized)
dataloader = DataLoader(mlm_dataset, batch_size=16, shuffle=True)

In [55]:
from transformers import BertConfig, BertForMaskedLM

# Define BERT model from scratch
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=512,
    type_vocab_size=2
)

In [56]:
model = BertForMaskedLM(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device);

#### **Train the model**

In [58]:
from torch.optim import AdamW
from tqdm import tqdm

optimizer = AdamW(model.parameters(), lr=5e-5)

model.train()
epochs = 100

for epoch in range(epochs):
    loop = tqdm(dataloader, leave=True)
    for batch in loop:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())

Epoch 0: 100%|██████████| 23/23 [00:04<00:00,  5.59it/s, loss=9.39]
Epoch 1: 100%|██████████| 23/23 [00:05<00:00,  3.89it/s, loss=8.9]
Epoch 2: 100%|██████████| 23/23 [00:03<00:00,  6.23it/s, loss=8.33]
Epoch 3: 100%|██████████| 23/23 [00:04<00:00,  5.63it/s, loss=8.03]
Epoch 4: 100%|██████████| 23/23 [00:03<00:00,  6.59it/s, loss=7.83]
Epoch 5: 100%|██████████| 23/23 [00:03<00:00,  6.56it/s, loss=7.74]
Epoch 6: 100%|██████████| 23/23 [00:03<00:00,  6.25it/s, loss=7.41]
Epoch 7: 100%|██████████| 23/23 [00:03<00:00,  6.22it/s, loss=7.02]
Epoch 8: 100%|██████████| 23/23 [00:03<00:00,  6.75it/s, loss=6.73]
Epoch 9: 100%|██████████| 23/23 [00:03<00:00,  6.79it/s, loss=6.91]
Epoch 10: 100%|██████████| 23/23 [00:03<00:00,  5.88it/s, loss=7.27]
Epoch 11: 100%|██████████| 23/23 [00:03<00:00,  6.78it/s, loss=6.93]
Epoch 12: 100%|██████████| 23/23 [00:03<00:00,  6.80it/s, loss=6.82]
Epoch 13: 100%|██████████| 23/23 [00:03<00:00,  5.94it/s, loss=6.79]
Epoch 14: 100%|██████████| 23/23 [00:03<00:00

In [59]:
def predict_masked(model, tokenizer, text):
    model.eval()
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = outputs.logits

    masked_index = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
    predicted_index = predictions[0, masked_index].argmax(dim=-1).item()
    predicted_token = tokenizer.decode([predicted_index])
    return predicted_token

masked_text = "Paris is the [MASK] of France."
predicted = predict_masked(model, tokenizer, masked_text)
print(f"Prediction: {predicted}")

Prediction: the
