In [2]:
import torch
from torch import nn
import os
import random
import math
import collections
import time
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

# Check device: prefer CUDA when PyTorch can see a GPU, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed Python's ``random`` module and PyTorch's generators with ``seed``.

    When CUDA is available, ``torch.cuda.manual_seed_all`` is also called so
    every visible GPU generator starts from the same seed.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)

  cpu = _conversion_method_template(device=torch.device("cpu"))


ModuleNotFoundError: No module named 'matplotlib'

The code block below handles the **Tokenizer** and **Vocabulary** logic. This is the first step in the NLP pipeline: converting raw human-readable text (strings) into machine-readable integers (indices).

#### **1. Purpose and Function**

1. **Tokenization:** The `tokenize` function splits a raw string ("The cat sat.") into a list of discrete units or tokens (`["the", "cat", "sat", "."]`). This defines the fundamental unit of meaning for the model.
2. **Vocabulary Building:** The `Vocab` class constructs a bijection (two-way mapping) between tokens and integer IDs.
* **Forward Mapping (`token_to_idx`):** Converts "apple" → 452. This is used to encode input text.
* **Reverse Mapping (`idx_to_token`):** Converts 452 → "apple". This is used to decode model predictions back into text.


3. **Special Token Handling:** It automatically manages reserved tokens that are critical for BERT and other Transformer models:
* `<pad>`: Used to make all sequences the same length.
* `<mask>`: Used for the Masked Language Modeling task.
* `<cls>`: The "Classification" token, used as the aggregate representation of the sequence.
* `<sep>`: The "Separator" token, used to delimit sentences in pairs.

---

#### **2. Detailed Theoretical Breakdown**

**A. Tokenization (`tokenize` function)**
The code uses a simple whitespace-based tokenizer: `text.strip().lower().split()`.

* **Lowercasing:** `lower()` converts "The" and "the" to the same token. This reduces the vocabulary size significantly, making the model more efficient, though it loses the distinction between proper nouns (e.g., "Apple" the company vs. "apple" the fruit).
* **Splitting:** `split()` breaks the text on spaces.
* *Limitation:* This simple tokenizer treats punctuation attached to words as part of the word (e.g., "end." is different from "end"). Real-world BERT uses **WordPiece** tokenization, which splits "playing" into "play" + "##ing" to handle rare words and punctuation better.


**B. The Vocabulary (`Vocab` class)**
The vocabulary is the set of all unique tokens the model knows.

* **Frequency Filtering (`min_freq=5`):** In any corpus, many words appear only once or twice (Zipf's Law). Learning embeddings for these extremely rare words is difficult because there are too few examples to update their weights meaningfully. Furthermore, they bloat the embedding matrix size.
* **Theory:** By filtering out words whose frequency is below `min_freq` (default 5), we reduce the parameter count and force the model to focus on generalizable patterns.
* **The `<unk>` Token:** Words that are filtered out (or unseen during testing) are mapped to the special `<unk>` (Unknown) token. This ensures the model doesn't crash when it encounters a word it hasn't seen before; it just treats it as "some generic unknown word."


**C. The Padding Token (`<pad>`)**
Deep learning models process data in batches (e.g., 32 sentences at a time). To stack 32 sentences into a single matrix, they must all have the same length.

* **Theory:** If Sentence A has 10 words and Sentence B has 15, we append 5 `<pad>` tokens to Sentence A.
* **Masking:** During training, we use a "padding mask" to tell the attention mechanism to *ignore* these `<pad>` tokens so they don't affect the meaning of the sentence.

---

#### **3. Key Code Lines and Their Roles**

**1. Reserved Tokens Initialization**

```python
self.idx_to_token = list(reserved_tokens)
self.token_to_idx = {token: idx for idx, token in enumerate(reserved_tokens)}

```

* **Role:** This ensures that the special tokens (`<pad>`, `<mask>`, etc.) are always assigned the first few indices (0, 1, 2, 3). This is crucial because many models hardcode the index `0` for padding. If `<pad>` were assigned index 543, it would complicate the masking logic later.

**2. Counting Token Frequencies**

```python
counter = collections.Counter()
# ...
counter.update(tokenize(sentence))

```

* **Role:** `collections.Counter` is a `dict` subclass specialized for counting hashable elements. The constructor loops over every sentence in the dataset and feeds its tokens to the counter, building the frequency distribution needed for filtering.

**3. Frequency Filtering Logic**

```python
if freq >= min_freq and token not in self.token_to_idx:
    self.idx_to_token.append(token)
    self.token_to_idx[token] = len(self.idx_to_token) - 1

```

* **Role:** This is the gatekeeper. It checks two conditions:
1. Is the word common enough? (`freq >= min_freq`)
2. Is it already in the vocab? (To prevent duplicates if it was in `reserved_tokens`)
Only if both pass is the word added to the official vocabulary.



**4. Handling Unknown Words (`__getitem__`)**

```python
return self.token_to_idx.get(tokens, self.unk)

```

* **Role:** This is the runtime lookup. When converting text to indices, `dict.get(key, default)` tries to find the word. If the word isn't found (because it was rare and filtered out), it returns `self.unk` (the index for `<unk>`). This makes the tokenizer robust to out-of-vocabulary (OOV) words.

In [None]:
def tokenize(text):
    """Simple word-level tokenization: lower-case ``text`` and split on whitespace."""
    normalized = text.strip().lower()
    return normalized.split()

class Vocab:
    """Vocabulary class for token-to-index mapping.

    Reserved tokens occupy the first ids in the order given, so with the
    default ``<pad>`` is id 0. An ``<unk>`` entry is always present: if it is
    not among ``reserved_tokens``, it is appended right after them. Corpus
    tokens seen at least ``min_freq`` times follow, in first-occurrence order.

    Args:
        sentences: Iterable of raw strings (split with ``tokenize``) or of
            pre-tokenized ``list``s of tokens.
        min_freq: Minimum corpus count for a token to receive its own id.
        reserved_tokens: Special tokens placed first. Must include ``'<pad>'``.
    """
    def __init__(self, sentences, min_freq=5,
                 reserved_tokens=('<pad>', '<mask>', '<cls>', '<sep>')):
        self.idx_to_token = list(reserved_tokens)
        # Out-of-vocabulary lookups need a dedicated id. Without this entry,
        # `unk` fell back to 0 -- the <pad> id -- silently turning every rare or
        # unseen word into padding.
        if '<unk>' not in self.idx_to_token:
            self.idx_to_token.append('<unk>')
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

        # Build counter for all tokens
        counter = collections.Counter()
        for sentence in sentences:
            # Pre-tokenized input (a list) is counted as-is; strings are tokenized first.
            counter.update(sentence if isinstance(sentence, list) else tokenize(sentence))

        # Add tokens above min_freq (reserved tokens, incl. <unk>, are never duplicated)
        for token, freq in counter.items():
            if freq >= min_freq and token not in self.token_to_idx:
                self.token_to_idx[token] = len(self.idx_to_token)
                self.idx_to_token.append(token)

        self.pad = self.token_to_idx['<pad>']
        self.unk = self.token_to_idx['<unk>']

    def __getitem__(self, tokens):
        """Map a token to its id (unknown -> ``self.unk``); lists/tuples map element-wise."""
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def __len__(self):
        """Number of distinct ids (reserved tokens + <unk> + kept corpus tokens)."""
        return len(self.idx_to_token)

The code block below implements the **BERT Architecture**, which is fundamentally a bidirectional Transformer Encoder. It translates the theoretical design of BERT (Input Embeddings + Stacked Encoder Layers) into a PyTorch model.

#### **1. Purpose and Function**

1. **`EncoderBlock`:** Represents a single layer of the Transformer. It contains the **Self-Attention** mechanism (to mix information between tokens) and a **Feed-Forward Network** (to process information individually), wrapped in residual connections and layer normalization.
2. **`BERTEncoder`:** Stacks `num_blks` of these blocks on top of each other. It also handles the complex input representation (Token + Segment + Position Embeddings) that is unique to BERT.

---

#### **2. Detailed Theoretical Breakdown**

**A. The Encoder Block (`EncoderBlock`)**
This is the standard building block of the Transformer (Vaswani et al., 2017).

1. **Multi-Head Self-Attention:** Allows every token to look at every other token in the sequence. "Multi-head" means it runs `num_heads` attention mechanisms in parallel, allowing the model to focus on different types of relationships (e.g., one head might specialize in syntax, another in semantics).
2. **Add & Norm:**
* **Residual Connection (Add):** $x + \text{Sublayer}(x)$, which is then normalized as $\text{LayerNorm}(x + \text{Sublayer}(x))$. This gives gradients a direct path through the network, mitigating the vanishing gradient problem in deep networks.
* **Layer Normalization (Norm):** Normalizes the features for each token to have mean 0 and variance 1. This stabilizes training.


3. **Position-wise Feed-Forward Network (FFN):** A simple MLP applied to every token independently. It acts as a key-value memory that stores knowledge learned from the pre-training data.

**B. BERT's Input Embeddings**
BERT's power comes from its ability to handle pairs of sentences.

* **Token Embedding:** The standard vector lookup for words.
* **Segment Embedding:** A learned vector that signals "This token belongs to Sentence A" vs "Sentence B". This allows BERT to distinguish between the premise and hypothesis in NLI tasks.
* **Positional Embedding:** Unlike the original Transformer (which used fixed sine/cosine waves), BERT uses **learnable** positional embeddings. The model *learns* the best vector representation for "Position 1", "Position 2", etc.

**C. The Padding Mask**
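Self-attention compares every position with every other position, so its cost is $O(L^2)$ in the sequence length $L$. If we have a batch with sentences of length 10 and 50, we pad the short one with 40 `<pad>` tokens (id 0).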

* If we don't mask these zeros, the attention mechanism will treat them as valid context (like a period or a stop word).
* **Masking Logic:** We create a boolean mask where `True` indicates padding. The attention mechanism sets the attention score for these positions to $-\infty$ (negative infinity). When softmax is applied, $e^{-\infty} = 0$, effectively forcing the model to ignore the padding.

---

#### **3. Key Code Lines and Their Roles**

**1. Multi-Head Attention**

```python
self.attention = nn.MultiheadAttention(..., batch_first=True)

```

* **Role:** This is the core engine. `batch_first=True` ensures the input tensors are `(Batch, Seq_Len, Dim)` rather than `(Seq_Len, Batch, Dim)`, which is more intuitive for NLP.

**2. Learnable Positional Embeddings**

```python
self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

```

* **Role:** `nn.Parameter` tells PyTorch "treat this tensor as a trainable weight." Unlike `nn.Embedding`, this isn't a lookup table; it's a raw tensor added directly to the input. The `1` in the shape allows broadcasting across the batch dimension.

**3. Summing Embeddings**

```python
X = self.token_embedding(tokens) + self.segment_embedding(segments)
X = X + self.pos_embedding[:, :X.shape[1], :]

```

* **Role:** This implements BERT's input formula: $X = E_{\text{token}} + E_{\text{segment}} + E_{\text{position}}$. The slicing `[:, :X.shape[1], :]` ensures that if the current batch has sequence length 20, we only add the first 20 positional vectors.

**4. Padding Mask Construction**

```python
key_padding_mask[i, length:] = True

```

* **Role:** This loop builds the attention mask dynamically.
* `valid_lens` tells us the real length of each sentence (e.g., `[5, 8]`).
* For sentence `i`, indices from `length` to the end are set to `True` (ignore).
* This mask is passed to `self.attention`, preventing the model from cheating or getting confused by empty padding tokens.

In [3]:
class EncoderBlock(nn.Module):
    """Transformer Encoder Block (post-norm) built from PyTorch primitives.

    Sub-layer 1: multi-head self-attention -> dropout -> residual add -> LayerNorm.
    Sub-layer 2: Linear -> ReLU -> Dropout -> Linear feed-forward network
    -> dropout -> residual add -> LayerNorm.

    Args:
        num_hiddens: Model width (attention ``embed_dim`` and FFN in/out size).
        ffn_num_hiddens: Inner width of the feed-forward network.
        num_heads: Number of attention heads (``nn.MultiheadAttention`` requires
            it to divide ``num_hiddens``).
        dropout: Dropout probability used on the attention weights, inside the
            FFN, and on both residual branches.
    """
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, **kwargs):
        super().__init__(**kwargs)
        # batch_first=True -> the attention layer consumes (batch, seq, feature).
        self.attention = nn.MultiheadAttention(
            embed_dim=num_hiddens,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.addnorm1 = nn.LayerNorm(num_hiddens)
        self.dropout1 = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(num_hiddens, ffn_num_hiddens),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_num_hiddens, num_hiddens),
        )
        self.addnorm2 = nn.LayerNorm(num_hiddens)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, X, key_padding_mask=None):
        """Apply both sub-layers to ``X``.

        ``key_padding_mask`` is forwarded to ``nn.MultiheadAttention``: a
        (batch, seq_len) boolean tensor where True marks key positions to ignore.
        The output has the same shape as ``X``.
        """
        attended, _ = self.attention(X, X, X, key_padding_mask=key_padding_mask)
        after_attn = self.addnorm1(X + self.dropout1(attended))
        after_ffn = self.addnorm2(after_attn + self.dropout2(self.ffn(after_attn)))
        return after_ffn

class BERTEncoder(nn.Module):
    """BERT Encoder: Embeddings + Stack of Transformer Blocks.

    The input representation is token embedding + segment embedding (2
    segments) + a learnable positional embedding. It is followed by
    ``num_blks`` ``EncoderBlock`` layers.

    Args:
        vocab_size: Number of rows in the token embedding table.
        num_hiddens: Model width (embedding / hidden size).
        ffn_num_hiddens: Inner width of each block's feed-forward network.
        num_heads: Attention heads per block.
        num_blks: Number of stacked encoder blocks.
        dropout: Dropout probability passed to every block.
        max_len: Longest sequence the positional embedding can cover.
    """
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout, max_len=1000, **kwargs):
        super(BERTEncoder, self).__init__(**kwargs)
        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        self.segment_embedding = nn.Embedding(2, num_hiddens)
        self.blks = nn.ModuleList()
        for _ in range(num_blks):
            self.blks.append(EncoderBlock(num_hiddens, ffn_num_hiddens, num_heads, dropout))

        # Positional embedding is learnable in BERT
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        """Encode a batch of token ids.

        Args:
            tokens: Token ids indexed as (batch, seq_len).
            segments: Segment ids (0 or 1), same shape as ``tokens``.
            valid_lens: Number of real (non-padding) tokens per example, as a
                1-D tensor or a sequence of ints of length ``batch``. Use
                ``None`` when nothing is padded.

        Returns:
            Tensor of shape (batch, seq_len, num_hiddens).

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = tokens.shape[1]
        max_len = self.pos_embedding.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional embedding max_len={max_len}")

        # Add embeddings
        X = self.token_embedding(tokens) + self.segment_embedding(segments)
        X = X + self.pos_embedding[:, :seq_len, :]

        # key_padding_mask[b, j] is True where position j of example b is padding,
        # i.e. j >= valid_lens[b]. Built in one vectorized comparison instead of a
        # per-example Python loop (which forces a device sync per row on GPU).
        if valid_lens is None:
            key_padding_mask = torch.zeros(tokens.shape, dtype=torch.bool, device=tokens.device)
        else:
            lens = torch.as_tensor(valid_lens, device=tokens.device).reshape(-1, 1)
            positions = torch.arange(seq_len, device=tokens.device).unsqueeze(0)
            key_padding_mask = positions >= lens

        for blk in self.blks:
            X = blk(X, key_padding_mask=key_padding_mask)
        return X

The code block below implements the **Prediction Heads** that are attached to the core BERT Encoder. These heads are responsible for transforming the encoder's hidden states into actual predictions for the two self-supervised learning tasks: **Masked Language Modeling (MLM)** and **Next Sentence Prediction (NSP)**.

#### **1. Purpose and Function**

1. **`MaskLM`:** Solves the "Cloze" task. It takes the hidden vectors corresponding to the masked positions and projects them back to the vocabulary size to predict the original word.
2. **`NextSentencePred`:** Solves the binary classification task. It takes the hidden vector of the special `<cls>` token and predicts whether the second sentence follows the first.
3. **`BERTModel`:** The master container. It combines the `BERTEncoder` (from the previous block) with these two task-specific heads into a single trainable `nn.Module`.

---

#### **2. Detailed Theoretical Breakdown**

**A. Masked Language Modeling (MLM) Head**
The MLM head is essentially a classifier.

* **Input:** Hidden state $h_i$ for a masked token at position $i$.
* **Transformation:** In the original BERT implementation, the authors apply a non-linear transformation before the final classifier:

  $$\hat{y}_i = \mathrm{softmax}\Big(W_2\,\mathrm{LayerNorm}\big(\mathrm{GELU}(W_1 h_i + b_1)\big) + b_2\Big)$$

* **Output:** A probability distribution over the vocabulary $V$ (e.g., $|V| \approx 30{,}000$). The code below returns the pre-softmax logits.

* **Note:** The output weights $W_2$ are often tied (shared) with the input embedding matrix to save parameters, though in this simplified implementation, they are separate.

**B. Next Sentence Prediction (NSP) Head**
The NSP head is a binary classifier trained on the `[CLS]` token.

* **The `[CLS]` Token:** BERT is designed such that the hidden state of the very first token ($h_{\texttt{[CLS]}}$, position 0) aggregates information from the entire input sequence.
* **Pooling:** We often apply a `Tanh` activation to this pooled representation before the classifier.
* **Output:** A binary score (IsNext vs. NotNext).

**C. Gathering Strategy**
BERT processes a batch of $B$ sequences, each of length $L$.

* The output of the encoder is a tensor of shape `(Batch, L, Hidden)`.
* We only need to make predictions for the 15% of tokens that were masked.
* Instead of running the classifier on *all* $B \times L$ tokens (which is wasteful), we use index selection (`masked_X = X[batch_idx, pred_positions]`) to extract only the vectors we care about.

---

#### **3. Key Code Lines and Their Roles**

**1. Advanced Indexing for MLM**

```python
batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
masked_X = X[batch_idx, pred_positions]

```
* **Role:** This extracts the relevant hidden states.
* `pred_positions` contains the indices of the masked tokens (e.g., `[3, 8]` for row 0).
* `X` is the full sequence output.
* This advanced indexing grabs `X[0, 3]` and `X[0, 8]`, collapsing the 3D tensor into a 2D batch of features to be fed into the MLP.

**2. The MLM MLP Architecture**

```python
self.mlp = nn.Sequential(nn.Linear(...), nn.ReLU(), nn.LayerNorm(...), nn.Linear(...))

```
* **Role:** This implements the transform-then-classify head used by the original BERT implementation. The `LayerNorm` here is crucial for training stability. Note that while the original BERT used GELU, this implementation uses `ReLU` for simplicity; `nn.GELU` is available in PyTorch and would be a drop-in replacement.

**3. The NSP Pooling Layer**

```python
self.hidden = nn.Sequential(nn.Linear(..., ...), nn.Tanh())

```
* **Role:** This is the "pooler" layer. It takes the raw `<cls>` vector and processes it. The `Tanh` activation is a specific design choice from the original BERT implementation, likely inherited from older LSTM-based classification heads.

**4. Forward Pass Logic**

```python
if pred_positions is not None:
    mlm_Y_hat = self.mlm(encoded_X, pred_positions)

```
* **Role:** This conditional logic allows the model to be flexible. During **Pretraining**, we provide `pred_positions` to compute the MLM loss. During **Fine-tuning** (downstream tasks like Sentiment Analysis), we might only care about the NSP head (or a new custom head) and can skip the MLM computation entirely.

In [4]:
class MaskLM(nn.Module):
    """MLM Head: Predicts masked tokens.

    Gathers encoder states at the requested positions and maps each one through
    Linear -> ReLU -> LayerNorm -> Linear to ``vocab_size`` logits.
    """
    def __init__(self, vocab_size, num_hiddens, **kwargs):
        super().__init__(**kwargs)
        self.mlp = nn.Sequential(
            nn.Linear(num_hiddens, num_hiddens),
            nn.ReLU(),
            nn.LayerNorm(num_hiddens),
            nn.Linear(num_hiddens, vocab_size),
        )

    def forward(self, X, pred_positions):
        """Return MLM logits for selected positions.

        Args:
            X: Encoder output indexed as (batch, seq_len, num_hiddens).
            pred_positions: (batch, num_preds) integer positions into ``seq_len``.

        Returns:
            Logits of shape (batch, num_preds, vocab_size).
        """
        batch_size, num_preds = X.shape[0], pred_positions.shape[1]
        # Row id for every flattened prediction slot: 0,0,...,1,1,... (num_preds each).
        row_ids = torch.arange(0, batch_size, device=X.device).repeat_interleave(num_preds)
        col_ids = pred_positions.reshape(-1)
        gathered = X[row_ids, col_ids].reshape((batch_size, num_preds, -1))
        return self.mlp(gathered)

class NextSentencePred(nn.Module):
    """NSP Head: Binary classification (IsNext vs NotNext).

    A single linear layer mapping a ``num_hiddens``-wide feature vector to two logits.
    """
    def __init__(self, num_hiddens, **kwargs):
        super().__init__(**kwargs)
        self.output = nn.Linear(num_hiddens, 2)

    def forward(self, X):
        """Return (..., 2) logits for features ``X`` (e.g. the pooled <cls> state)."""
        logits = self.output(X)
        return logits

class BERTModel(nn.Module):
    """Full BERT Model: shared encoder plus MLM and NSP pretraining heads.

    ``forward`` returns ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``:
    ``mlm_Y_hat`` is ``None`` unless ``pred_positions`` is given, and
    ``nsp_Y_hat`` is computed from position 0 after a Linear + Tanh pooler.
    """
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_blks, dropout, max_len=1000):
        super().__init__()
        # Creation order determines the RNG draws used for parameter
        # initialization and the state_dict key order -- keep it stable.
        self.encoder = BERTEncoder(vocab_size, num_hiddens, ffn_num_hiddens,
                                   num_heads, num_blks, dropout, max_len)
        self.hidden = nn.Sequential(nn.Linear(num_hiddens, num_hiddens), nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens)
        self.nsp = NextSentencePred(num_hiddens)

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        mlm_Y_hat = (self.mlm(encoded_X, pred_positions)
                     if pred_positions is not None else None)
        # Position 0 holds the <cls> token; its pooled state feeds the NSP head.
        cls_state = encoded_X[:, 0, :]
        nsp_Y_hat = self.nsp(self.hidden(cls_state))
        return encoded_X, mlm_Y_hat, nsp_Y_hat

The code block below implements the complex logic required to generate training data for BERT's two pretraining tasks: **Next Sentence Prediction (NSP)** and **Masked Language Modeling (MLM)**. Unlike Word2Vec, which simply slides a window over text, BERT requires structured sentence pairs and a carefully orchestrated masking strategy.

#### **1. Purpose and Function**

1. **NSP Data Generation:** Creates pairs of sentences `(A, B)` where `B` is either the actual next sentence (IsNext) or a random sentence (NotNext). This teaches BERT to understand long-range discourse and relationships between segments.
2. **MLM Data Generation:** Takes a sequence of tokens and applies the "Cloze" task logic: randomly masking 15% of tokens so the model can learn bidirectional context.
3. **Special Token Management:** Inserts the critical `<cls>` (start) and `<sep>` (separator) tokens that define BERT's input structure.

---

#### **2. Detailed Theoretical Breakdown**

**A. The NSP Task (Inter-Sentence Coherence)**
Standard language models (like GPT) treat text as a continuous stream. BERT, however, is often used for tasks like Question Answering (SQuAD) or Natural Language Inference (MNLI), where understanding the relationship between *two* pieces of text is crucial.

* **Positive Example:** "The man went to the store." → "He bought a gallon of milk." (Label: IsNext)
* **Negative Example:** "The man went to the store." → "Penguins are flightless birds." (Label: NotNext)
* **Sampling:** Each pair is labeled IsNext or NotNext with probability 0.5 (a fair coin flip). The two classes are therefore balanced in expectation, which keeps the model from learning a biased prior.

**B. The MLM Task (Bidirectional Context)**
The `_replace_mlm_tokens` function implements the core innovation of BERT.

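* **The Problem:** In a deep bidirectional network, each word $w_i$ could indirectly "see itself" through the other positions in later layers if we trained it to predict words from the full, uncorrupted input. That would make the prediction trivial.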
* **The Solution:** We physically remove or corrupt the input token so the model *must* rely on context.
* **The 80-10-10 Rule:**
* **80% `<mask>`:** The standard Cloze task.
* **10% Random:** Forces the model to mistrust the input embedding at any position, ensuring it always checks the context to verify if the word "makes sense".
* **10% Original:** Biases the model toward the correct word. Without this, the model might learn that "everything masked is garbage" and drift away from meaningful representations for non-masked words.



**C. Structural Formatting**
BERT inputs require a rigid structure: `[CLS] Sentence A [SEP] Sentence B [SEP]`.

* **Segment IDs:** A vector of 0s and 1s `[0, 0, ..., 0, 1, 1, ..., 1]` is generated alongside the tokens. This tells the self-attention mechanism which sentence a token belongs to, allowing it to differentiate between the "Premise" and "Hypothesis".

---

#### **3. Key Code Lines and Their Roles**

**1. Paragraph-Level Sampling**

```python
if random.random() < 0.5:
    is_next = True
else:
    next_sentence = random.choice(random.choice(paragraphs))
    is_next = False

```

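* **Role:** This is the NSP logic. It flips a coin. If Heads, it keeps `paragraph[i+1]` (the true next sentence). If Tails, it picks a random paragraph and then a random sentence from it. This makes the "NotNext" examples very likely to be unrelated, but it does not guarantee it: the random draw can occasionally land on the same paragraph, or even on the true next sentence, producing a mislabeled pair.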

**2. The 80-10-10 Masking Implementation**

```python
if random.random() < 0.8:
    masked_token = '<mask>'
else:
    if random.random() < 0.5:
        masked_token = tokens[mlm_pred_position] # 10% Original
    else:
        masked_token = random.choice(vocab.idx_to_token) # 10% Random

```

* **Role:** This nested if-else block strictly enforces the probability distribution required for robust BERT training. Note that `0.5` in the `else` block corresponds to 10% because it is 50% of the remaining 20%.

**3. Special Token Exclusion**

```python
if token in ['<cls>', '<sep>']:
    continue

```

* **Role:** We must never mask the special structural tokens. If we masked `<sep>`, the model would lose track of where sentence A ends and B begins, destroying the NSP task.

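**4. Filtering Short Lines**

```python
if len(line) < 5: continue

```

* **Role:** Lines shorter than 5 characters after stripping (mostly blank lines and tiny fragments such as "Yes." or "No.") provide almost no context for learning. Training on them is inefficient and potentially noisy. Note that this check counts *characters* of a whole *line* (one paragraph); it does not filter individual short sentences inside longer lines. It acts as a simple heuristic data cleaner.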

In [5]:
def get_tokens_and_segments(tokens_a, tokens_b=None):
    """Wrap one or two token lists as ``<cls> A <sep> [B <sep>]``.

    Returns the combined token list and a parallel segment-id list: 0 for
    ``<cls>``, A and the first ``<sep>``; 1 for B and its trailing ``<sep>``.
    The input lists are not modified.
    """
    first = ['<cls>'] + tokens_a + ['<sep>']
    if tokens_b is None:
        return first, [0] * len(first)
    second = tokens_b + ['<sep>']
    return first + second, [0] * len(first) + [1] * len(second)

def _read_wiki(data_dir):
    """Reads wikitext.train.tokens and splits into paragraphs/sentences."""
    file_name = os.path.join(data_dir, 'wikitext.train.tokens') 
    # NOTE: If your file is named 'wiki.train.tokens', change the line above.
    
    if not os.path.exists(file_name):
        print(f"File not found: {file_name}")
        return []

    with open(file_name, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    
    paragraphs = []
    for line in lines:
        line = line.strip()
        if len(line) < 5: 
            continue
        # Split by ' . ' is specific to WikiText formatting
        sentences = line.split(' . ')
        paragraph = [s.strip().lower().split() for s in sentences if len(s.strip()) > 0]
        if len(paragraph) >= 1:
            paragraphs.append(paragraph)
            
    random.shuffle(paragraphs)
    return paragraphs

def _get_next_sentence(sentence, next_sentence, paragraphs):
    """Generates NSP label and pair."""
    if random.random() < 0.5:
        is_next = True
    else:
        # Randomly select a different paragraph and sentence
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    """Process one paragraph to get NSP samples.

    Every adjacent sentence pair yields at most one ``(tokens, segments,
    is_next)`` sample. Pairs whose wrapped length (sentences + <cls> + 2x<sep>)
    exceeds ``max_len`` are dropped. ``vocab`` is accepted but not used here.
    """
    samples = []
    for current, following in zip(paragraph, paragraph[1:]):
        tokens_a, tokens_b, is_next = _get_next_sentence(current, following, paragraphs)
        # +3 for <cls>, <sep>, <sep>
        if len(tokens_a) + len(tokens_b) + 3 <= max_len:
            tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
            samples.append((tokens, segments, is_next))
    return samples

def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab):
    """Masking logic for MLM."""
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_labels = []
    random.shuffle(candidate_pred_positions)
    
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        
        masked_token = None
        # 80%: replace with <mask>
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            # 10%: keep original
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            # 10%: replace with random word
            else:
                masked_token = random.choice(vocab.idx_to_token)
                
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))
        
    return mlm_input_tokens, pred_positions_and_labels

def _get_mlm_data_from_tokens(tokens, vocab):
    """Generate MLM inputs and labels for a token sequence.

    Every position except ``<cls>``/``<sep>`` is a candidate. About 15% of the
    full sequence length (at least 1) is selected for prediction.

    Returns:
        ``(input_ids, pred_positions, label_ids)``. ``input_ids`` is the
        corrupted sequence as vocabulary ids. ``pred_positions`` is sorted
        ascending, and ``label_ids`` holds the original token ids at those
        positions.
    """
    structural = ('<cls>', '<sep>')
    candidates = [idx for idx, tok in enumerate(tokens) if tok not in structural]

    budget = max(1, round(len(tokens) * 0.15))
    masked_tokens, picks = _replace_mlm_tokens(tokens, candidates, budget, vocab)

    picks.sort(key=lambda pair: pair[0])
    positions = [pos for pos, _ in picks]
    labels = [label for _, label in picks]
    return vocab[masked_tokens], positions, vocab[labels]

The code block below implements the `WikiTextDataset` class, which is a PyTorch `Dataset` subclass. This class serves as the final bridge between the raw processed data (NSP pairs and MLM masking) and the model training loop. It ensures data is properly formatted, padded, and converted into tensors for batch processing.

#### **1. Purpose and Function**

1. **Orchestration:** It coordinates the entire pipeline. It calls the sentence splitting logic (`_read_wiki`), builds the vocabulary (`Vocab`), generates NSP pairs (`_get_nsp_data_from_paragraph`), and applies MLM masking (`_get_mlm_data_from_tokens`).
2. **Standardization (Padding):** Neural networks require fixed-size inputs. This class ensures every sequence is padded to `max_len`, adjusting all corresponding label and segment arrays to match.
3. **Tensor Conversion:** It converts raw Python lists (integers) into PyTorch Tensors (`torch.long` or `torch.float32`), making them ready for GPU acceleration.

---

#### **2. Detailed Theoretical Breakdown**

**A. Flattening for Vocabulary Construction**
Before we can process any examples, we need a complete vocabulary of the dataset.

* **Logic:** The dataset is initially a list of paragraphs, where each paragraph is a list of sentences. The list comprehension `[s for p in paragraphs for s in p]` flattens this nested structure into a single long list of sentences.
* **Why:** The `Vocab` class needs to see every word in the corpus to calculate frequencies and assign indices correctly.

**B. The Data Processing Pipeline**
The `__init__` method runs the full preprocessing pipeline once during initialization:

1. **NSP Generation:** It iterates through paragraphs to create `(Sentence A, Sentence B)` pairs with IsNext/NotNext labels.
2. **MLM Masking:** For each pair, it tokenizes the text into integers and applies the 80-10-10 masking rule.
3. **Result:** A list of `examples` where each example contains all the necessary inputs and labels for one training step.

**C. Padding Strategy**
The `max_len` parameter dictates the static size of the input tensors (e.g., 64 or 512).

* **Token IDs:** Padded with `<pad>`.
* **Segment IDs:** Padded with `0`.
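* **MLM Positions:** Padded with `0`. This is safe for two reasons. First, position 0 always holds `<cls>`, so index 0 is always a valid position to gather from. Second, the matching `mlm_weights` entries are 0, so these dummy predictions contribute nothing to the loss.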
* **MLM Weights:** We create a weight vector `[1, 1, ..., 0, 0]`. During loss calculation, we multiply the loss by this weight vector. The `0` weights ensure that the model is not penalized for predicting the padded "fake" masks.

---

#### **3. Key Code Lines and Their Roles**

**1. Flattening Logic**

```python
sentences = [s for p in paragraphs for s in p]
self.vocab = Vocab(sentences)

```
* **Role:** This prepares the raw material for the vocabulary builder. Without flattening, the `Vocab` class (if not designed to handle nested lists) might fail or produce incorrect counts.

**2. Dynamic Padding Calculation**

```python
valid_len = len(token_ids)
pad_len = max_len - valid_len
token_ids = token_ids + [self.vocab.pad] * pad_len

```
* **Role:** This ensures structural integrity. If a sentence has 10 tokens and `max_len` is 64, it appends 54 padding tokens. This allows us to stack thousands of diverse sentences into a single rectangular matrix.

**3. MLM Weighting for Padding**

```python
mlm_weights = [1.0] * num_mlm + [0.0] * mlm_pad

```
* **Role:** This is critical for the loss function.
* We padded the `pred_positions` and `mlm_pred_label_ids` arrays to a fixed size (e.g., 10 predictions per sequence).
* If a short sentence only has 3 masked tokens, the remaining 7 slots are garbage.
* The `mlm_weights` vector tells the loss function: "Count the error for the first 3 predictions, but multiply the error for the last 7 by zero."

**4. Tensor Conversion**

```python
self.all_data.append((torch.tensor(...), ...))

```
* **Role:** By converting to tensors immediately and storing them in RAM (`self.all_data`), we save CPU time during training. When the `DataLoader` requests batch 5, the tensors are already ready to be shipped to the GPU; we don't need to convert them on the fly every millisecond.

In [6]:
class WikiTextDataset(Dataset):
    """In-memory BERT pretraining dataset built from WikiText paragraphs.

    All preprocessing (NSP pairing, MLM masking via ``_get_mlm_data_from_tokens``,
    padding, tensor conversion) happens once in ``__init__``; ``__getitem__`` is
    a plain list lookup. Because of that, each example keeps the same MLM
    positions for every epoch.

    Args:
        paragraphs: List of paragraphs, each a list of sentences
            (presumably token lists — TODO confirm against ``_read_wiki``).
        max_len: Fixed length that token ids and segment ids are padded to.

    Each item is a 7-tuple of tensors:
        (token_ids, segments, valid_len, pred_positions, mlm_weights,
         mlm_pred_label_ids, is_next)
    where the three MLM tensors have ``round(max_len * 0.15)`` slots and
    ``mlm_weights`` is 1.0 for real predictions, 0.0 for padding slots.
    """

    def __init__(self, paragraphs, max_len):
        # Flatten paragraphs -> sentences so Vocab counts over the whole corpus.
        sentences = [s for p in paragraphs for s in p]
        self.vocab = Vocab(sentences)
        self.max_len = max_len

        # Next-sentence-prediction examples: (tokens, segments, is_next).
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))

        # Fixed number of MLM prediction slots per example (loop-invariant).
        max_mlm = round(max_len * 0.15)

        self.all_data = []
        for tokens, segments, is_next in examples:
            token_ids, pred_positions, mlm_pred_label_ids = _get_mlm_data_from_tokens(
                tokens, self.vocab)
            self.all_data.append(self._pad_example(
                token_ids, segments, pred_positions, mlm_pred_label_ids,
                is_next, max_len, max_mlm, self.vocab.pad))

    @staticmethod
    def _pad_example(token_ids, segments, pred_positions, mlm_pred_label_ids,
                     is_next, max_len, max_mlm, pad_id):
        """Pad one example to fixed sizes and return it as a tuple of tensors.

        Raises:
            ValueError: if the sequence is longer than ``max_len`` or has more
                MLM predictions than ``max_mlm``. Without this check the pad
                count would be negative, ``[x] * negative`` would silently give
                ``[]``, and the resulting ragged tensors would only fail later
                when the DataLoader tries to stack a batch.
        """
        valid_len = len(token_ids)
        pad_len = max_len - valid_len
        if pad_len < 0:
            raise ValueError(
                f"sequence of length {valid_len} exceeds max_len={max_len}")
        token_ids = token_ids + [pad_id] * pad_len
        # Padding positions belong to segment 0.
        segments = segments + [0] * pad_len

        num_mlm = len(pred_positions)
        mlm_pad = max_mlm - num_mlm
        if mlm_pad < 0:
            raise ValueError(
                f"{num_mlm} MLM predictions exceed the {max_mlm} slots "
                f"allowed for max_len={max_len}")
        # Padded slots point at position 0 with label 0; their weight of 0.0
        # removes them from the MLM loss.
        pred_positions = pred_positions + [0] * mlm_pad
        mlm_weights = [1.0] * num_mlm + [0.0] * mlm_pad
        mlm_pred_label_ids = mlm_pred_label_ids + [0] * mlm_pad

        return (
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(segments, dtype=torch.long),
            torch.tensor(valid_len, dtype=torch.long),
            torch.tensor(pred_positions, dtype=torch.long),
            torch.tensor(mlm_weights, dtype=torch.float32),
            torch.tensor(mlm_pred_label_ids, dtype=torch.long),
            torch.tensor(is_next, dtype=torch.long),
        )

    def __getitem__(self, idx):
        return self.all_data[idx]

    def __len__(self):
        return len(self.all_data)

### **Training Loop** 

The code block below demonstrates the actual training of the BERT model. It combines all the components defined previously (data loading, model architecture, loss calculation) and runs the optimization process for a fixed number of steps.

#### **1. Purpose and Function**

1. **Orchestration:** It loads the data into memory, initializes the model with specific hyperparameters (like hidden size and dropout), and sets up the optimizer.
2. **Optimization Loop:** It iterates through the dataset in batches. For each batch, it computes the forward pass (predictions), calculates the composite loss (MLM + NSP), and performs backpropagation to update the model weights.
3. **Visualization:** It tracks the loss over time and plots the learning curve, allowing us to verify if the model is actually learning or diverging.

---

#### **2. Detailed Theoretical Breakdown**

**A. Hyperparameters and Architecture Config**
The code initializes a "Small BERT".

* **`num_hiddens=128`:** The hidden (embedding) dimension $d_{\text{model}}$ of each token representation. Standard BERT-Base is 768.
* **`num_heads=4`:** The number of attention heads. Standard BERT-Base is 12.
* **`num_blks=2`:** The depth of the network (layers). Standard BERT-Base is 12.
* **Theory:** We use a smaller model here so it can train on a CPU or a single GPU in a reasonable time (seconds vs days) for demonstration purposes. The underlying math remains identical to the full-scale BERT.

**B. The Optimization Step**
The core of learning happens inside the `while step < NUM_STEPS` loop.

1. **Zero Grad:** `optimizer.zero_grad()` clears old gradients.
2. **Forward:** The model processes the batch.
3. **Backward:** `total_loss.backward()` computes the gradients $\nabla_\theta \mathcal{L}$ of the total loss with respect to every model parameter $\theta$.
4. **Step:** `optimizer.step()` updates the weights $\theta$ from those gradients, using the Adam rule with learning rate $\eta = 10^{-3}$.

**C. The Composite Loss Function**
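BERT optimizes two objectives simultaneously. The code simply adds them (`total_loss = mlm_l + nsp_l`):

$$\mathcal{L} = \mathcal{L}_{\text{MLM}} + \mathcal{L}_{\text{NSP}}$$

* **NSP Loss:** Cross-entropy over two classes (is-next vs. not-next), i.e. a binary classification loss, averaged over the batch.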
* **MLM Loss:** Weighted cross-entropy.
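* The raw output `mlm_Y_hat` has shape `(Batch, Num_Preds, Vocab)`.
  * `Num_Preds = round(0.15 * max_len)` is the fixed number of (padded) prediction slots per sequence.
  * So there is one row per prediction slot, not one row per token.
* We reshape it to `(Batch * Num_Preds, Vocab)` to treat every prediction slot as an independent classification problem. This matches the flattened `mlm_labels`.
* Crucially, we multiply by `mlm_weights` so that only the real prediction slots (the ~15% of tokens selected for MLM) contribute to the loss, and the padded slots are ignored.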

---

#### **3. Key Code Lines and Their Roles**

**1. MLM Loss Normalization**

```python
mlm_l = (mlm_l * mlm_weights.reshape(-1)).sum() / (mlm_weights.sum() + 1e-8)

```
* **Role:** This line is mathematically dense.
* `mlm_l`: Contains the loss for *every* prediction slot, including the padded ones.
* `mlm_weights`: Is 1.0 for real prediction slots and 0.0 for padding slots.
* **Operation:**
  * We zero out the loss of the padding slots and sum what remains.
  * We then divide by the *number of real predictions* (`mlm_weights.sum()`); the `1e-8` guards against division by zero.
  * The result is the average loss per masked prediction, which is the correct metric to minimize.

**2. Unpacking the Batch**

```python
(token_ids, segments, valid_lens, pred_positions, ...) = [x.to(device) for x in batch]

```
* **Role:** The `DataLoader` returns a tuple of tensors on the CPU. This list comprehension moves every tensor to `device` in one line. That is the GPU when one is available; on a CPU-only machine the move is a no-op. Without it, a GPU-resident model would receive CPU tensors and PyTorch would raise a device-mismatch runtime error.

**3. Tracking Metrics**

```python
losses_mlm.append(mlm_l.item())

```
* **Role:** `.item()` converts the single-element loss tensor into a standard Python float that is detached from the computation graph. If we appended `mlm_l` itself, each stored tensor would keep a reference to its autograd graph. Memory usage would then grow with every step and could eventually exhaust RAM or GPU memory.

In [7]:
BATCH_SIZE = 64
MAX_LEN = 128
NUM_STEPS = 500
LOG_EVERY = 10  # print the current losses every LOG_EVERY optimizer steps
DATA_DIR = './wikitext-2' # Ensure this points to where your wikitext.train.tokens is

# Load Data
print("Reading and processing data...")
paragraphs = _read_wiki(DATA_DIR)
if not paragraphs:
    print("No data found. Please check the path.")
else:
    dataset = WikiTextDataset(paragraphs, MAX_LEN)
    # Guard: with an empty dataset the inner `for batch in train_loader` never
    # runs, `step` never advances, and `while step < NUM_STEPS` spins forever.
    if len(dataset) == 0:
        raise ValueError(
            f"WikiTextDataset produced no examples (MAX_LEN={MAX_LEN}); "
            "check the data in DATA_DIR or increase MAX_LEN.")
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    vocab_size = len(dataset.vocab)

    print(f"Vocab Size: {vocab_size}")
    print(f"Dataset Size: {len(dataset)}")

    # Initialize a small BERT (BERT-Base: 768 hidden units, 12 heads, 12 blocks).
    net = BERTModel(vocab_size=vocab_size,
                    num_hiddens=128,
                    ffn_num_hiddens=256,
                    num_heads=4,
                    num_blks=2,
                    dropout=0.1,
                    max_len=MAX_LEN)

    net.to(device)
    # reduction='none' keeps per-row losses so MLM padding slots can be
    # weighted out with mlm_weights before averaging.
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    # Training Loop: cycle over the loader until NUM_STEPS optimizer steps are done.
    net.train()
    step = 0
    losses_mlm = []
    losses_nsp = []

    print(f"Starting training on {device}...")
    start_time = time.perf_counter()

    while step < NUM_STEPS:
        for batch in train_loader:
            if step >= NUM_STEPS:
                break

            # Unpack batch and move every tensor to the training device.
            (token_ids, segments, valid_lens, pred_positions,
             mlm_weights, mlm_labels, nsp_labels) = [x.to(device) for x in batch]

            optimizer.zero_grad()

            # Forward pass
            _, mlm_Y_hat, nsp_Y_hat = net(token_ids, segments, valid_lens.reshape(-1), pred_positions)

            # MLM loss: one row per prediction slot; padded slots carry weight 0.
            mlm_l = loss_fn(mlm_Y_hat.reshape(-1, vocab_size), mlm_labels.reshape(-1))
            mlm_l = (mlm_l * mlm_weights.reshape(-1)).sum() / (mlm_weights.sum() + 1e-8)

            # NSP loss: 2-class cross-entropy averaged over the batch.
            nsp_l = loss_fn(nsp_Y_hat, nsp_labels).mean()

            # Total loss
            total_loss = mlm_l + nsp_l
            total_loss.backward()
            optimizer.step()

            # Store plain floats so no autograd graph is kept alive.
            mlm_value, nsp_value = mlm_l.item(), nsp_l.item()
            losses_mlm.append(mlm_value)
            losses_nsp.append(nsp_value)

            if (step + 1) % LOG_EVERY == 0:
                print(f"Step {step+1}/{NUM_STEPS} | MLM Loss: {mlm_value:.4f} | NSP Loss: {nsp_value:.4f}")

            step += 1

    print(f"Training finished in {time.perf_counter() - start_time:.2f} seconds.")

    # Plotting Results
    fig, ax = plt.subplots()
    ax.plot(losses_mlm, label='MLM Loss')
    ax.plot(losses_nsp, label='NSP Loss')
    ax.set(xlabel='Step', ylabel='Loss', title='BERT pretraining losses')
    ax.legend()
    plt.show()

Reading and processing data...
File not found: ./wikitext-2\wikitext.train.tokens
No data found. Please check the path.
