In [None]:
#@title üéß Download Narration Audio & Play Introduction
import os as _os
if not _os.path.exists("/content/narration"):
    !pip install -q gdown
    import gdown
    gdown.download(id="1adzccVf_oMTz5XoyEIMJ41XEgZuWBzI6", output="/content/narration.zip", quiet=False)
    !unzip -q /content/narration.zip -d /content/narration
    !rm /content/narration.zip
    print(f"Loaded {len(_os.listdir('/content/narration'))} narration segments")
else:
    print("Narration audio already loaded.")

from IPython.display import Audio, display
display(Audio("/content/narration/seg_01_intro.mp3"))

In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and set random seeds for reproducibility

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# 🚀 BERT Architecture & Pre-training from Scratch

*Part 3 of the Vizuara series on Understanding BERT from Scratch*
*Estimated time: 75 minutes*

# 🤖 AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant**: it has already read this entire notebook and can help with concepts, code, and exercises.

**[👉 Open AI Teaching Assistant](https://pods.vizuara.ai/courses/understanding-bert-from-scratch/practice/3/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*


In [None]:
#@title üéß Listen: Seg 02 Why It Matters
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_02_why_it_matters.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 1. Why Does This Matter?

In the previous notebook, we built a Transformer encoder from scratch. Now we will turn it into **BERT**, the 2018 model that set new state-of-the-art results across a wide range of NLP tasks.

BERT's magic comes from two things:
1. A clever **input representation** that combines token, segment, and position information
2. Two simple but powerful **pre-training objectives**: Masked Language Modeling (MLM) and Next Sentence Prediction (NSP)

In this notebook, we will:
1. Build BERT's complete **input pipeline** (token + segment + position embeddings)
2. Implement **Masked Language Modeling** with the 80-10-10 masking strategy
3. Implement **Next Sentence Prediction**
4. Train a **mini-BERT** on a small corpus and watch it learn to fill in masked words

In [None]:
# 🔧 Setup: imports and random seeds used throughout this notebook
!pip install -q torch matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import random
from collections import Counter

%matplotlib inline

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
#@title üéß Listen: Seg 04 Intuition
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_04_intuition.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 2. Building Intuition

BERT needs to understand three things about every token in the input:
1. **What** is this token? (token embedding)
2. **Which sentence** does it belong to? (segment embedding)
3. **Where** does it sit in the sequence? (position embedding)

Think of it like a letter in the mail. It needs:
- The **content** (the letter itself: token embedding)
- The **envelope**, showing which of two bundles it belongs to (segment embedding)
- The **address**, telling you where it goes (position embedding)

BERT also has two special tokens:
- **[CLS]**: Added at the start. Its final representation is used for classification tasks.
- **[SEP]**: Added between sentences and at the end.
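
To make the layout concrete, here is a small pure-Python sketch that builds the three per-token labels for a sentence pair. The helper name `layout_pair` is our own, not part of BERT; it follows the convention used later in this notebook, where the first [SEP] belongs to segment 0 and positions count straight across both sentences.

```python
def layout_pair(sent_a, sent_b):
    """Return (tokens, segment_ids, position_ids) for [CLS] A [SEP] B [SEP]."""
    tokens = ["[CLS]"] + sent_a + ["[SEP]"] + sent_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A, and the first [SEP]; segment 1 covers the rest.
    len_a = len(sent_a) + 2
    segment_ids = [0] * len_a + [1] * (len(tokens) - len_a)
    # Positions count from 0 across the whole sequence, ignoring sentence boundaries.
    position_ids = list(range(len(tokens)))
    return tokens, segment_ids, position_ids

pair_tokens, pair_segments, pair_positions = layout_pair(
    "the cat sat".split(), "the dog barked".split()
)
for t, s, p in zip(pair_tokens, pair_segments, pair_positions):
    print(f"{t:8s} segment={s} position={p}")
```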

### 🤔 Think About This
Why does BERT use *learned* position embeddings instead of the sinusoidal encodings we built in the previous notebook? What are the trade-offs?

In [None]:
#@title üéß Listen: Seg 05 Math
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_05_math.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 3. The Mathematics

### Input Representation

For each token at position $i$, the input is:

$$\text{Input}_i = \text{TokenEmbed}(\text{token}_i) + \text{SegmentEmbed}(\text{segment}_i) + \text{PositionEmbed}(i)$$

Computationally: we look up three separate embedding vectors and add them element-wise. This gives us a single vector that encodes the token's identity, segment membership, and position, all in $d_{\text{model}}$ dimensions.
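
A tiny worked example makes the element-wise sum concrete. The lookup tables below are hand-written lists with $d_{\text{model}} = 3$ (invented values, not trained embeddings), and the sketch omits the LayerNorm and dropout that our `BERTEmbedding` layer applies to the sum.

```python
# Toy lookup tables with d_model = 3 (values are invented for illustration).
token_table = {"the": [0.1, 0.2, 0.3], "cat": [1.0, 0.0, -1.0]}
segment_table = [[0.01, 0.01, 0.01],      # segment 0 (sentence A)
                 [-0.01, -0.01, -0.01]]   # segment 1 (sentence B)
position_table = [[0.0, 0.5, 0.0],        # position 0
                  [0.0, 0.0, 0.5]]        # position 1

def input_vector(token, segment, position):
    """Element-wise sum of the token, segment, and position embeddings."""
    return [t + s + p for t, s, p in zip(token_table[token],
                                          segment_table[segment],
                                          position_table[position])]

# The same token gets a different input vector at a different position.
print(input_vector("cat", 0, 0))
print(input_vector("cat", 0, 1))
```

Because the three vectors are simply added, the model never sees them separately; it has to learn embeddings whose sum stays informative about all three properties.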

### Masked Language Modeling Loss

For the masked positions, we predict the original token using cross-entropy:

$$\mathcal{L}_{\text{MLM}} = -\sum_{i \in \text{masked}} \log P(w_i \mid \mathbf{w}_{\text{context}}; \theta)$$

Computationally: for each masked position, the model outputs a probability distribution over the entire vocabulary, and we want to maximize the probability of the correct token.
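
Here is a pure-Python sketch of this loss that also mirrors a detail of our training code. PyTorch's `F.cross_entropy` defaults to `ignore_index=-100` and `reduction='mean'`, so positions labelled -100 are skipped and the loss is averaged over the masked positions rather than summed as in the formula (a rescaling only). The helper name and the toy logits are made up for illustration.

```python
import math

def masked_lm_loss(logits, labels, ignore_index=-100):
    """Mean cross-entropy over positions whose label is not ignore_index.

    logits: per-position lists of vocabulary scores (seq_len x vocab_size)
    labels: original token ID at masked positions, ignore_index elsewhere
    Assumes at least one position is masked.
    """
    losses = []
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # unmasked positions contribute nothing
        # -log softmax(scores)[label], computed with the log-sum-exp trick
        top = max(scores)
        log_z = top + math.log(sum(math.exp(s - top) for s in scores))
        losses.append(log_z - scores[label])
    return sum(losses) / len(losses)

# 3 positions, vocabulary of 4; only position 1 was masked (original token ID 2).
toy_logits = [[5.0, 0.0, 0.0, 0.0],   # ignored, however confident
              [0.0, 0.0, 2.0, 0.0],   # masked: target ID 2
              [0.0, 9.0, 0.0, 0.0]]   # ignored
toy_labels = [-100, 2, -100]
print(masked_lm_loss(toy_logits, toy_labels))
```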

### Next Sentence Prediction Loss

Binary cross-entropy on the [CLS] token's prediction:

$$\mathcal{L}_{\text{NSP}} = -\left[y \log(p) + (1-y) \log(1-p)\right]$$

Computationally: $y=1$ if sentence B actually follows sentence A, $y=0$ otherwise, and $p$ is the model's predicted probability that B follows A. We want the model to correctly predict whether two sentences are consecutive.
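
Our NSP head actually outputs two logits $(z_0, z_1)$ for NotNext and IsNext and is trained with two-class softmax cross-entropy (`F.cross_entropy`). The sketch below checks numerically that this matches the binary cross-entropy above with $p = \sigma(z_1 - z_0)$, since $e^{z_1}/(e^{z_0} + e^{z_1}) = 1/(1 + e^{-(z_1 - z_0)})$. The logit values are arbitrary.

```python
import math

def two_class_ce(logit_not_next, logit_is_next, y):
    """Softmax cross-entropy over the two NSP logits (index 1 = IsNext)."""
    top = max(logit_not_next, logit_is_next)
    log_z = top + math.log(math.exp(logit_not_next - top) + math.exp(logit_is_next - top))
    target = logit_is_next if y == 1 else logit_not_next
    return log_z - target

def binary_ce(p, y):
    """Binary cross-entropy from the formula above."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

z0, z1 = 0.3, 1.7                       # example NotNext / IsNext logits
p = 1 / (1 + math.exp(-(z1 - z0)))      # softmax probability of IsNext = sigmoid(z1 - z0)
for y in (0, 1):
    print(y, two_class_ce(z0, z1, y), binary_ce(p, y))
```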

### Total Pre-training Loss

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{MLM}} + \mathcal{L}_{\text{NSP}}$$

In [None]:
#@title üéß Listen: Seg 03 Setup
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_03_setup.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
#@title üéß Listen: Seg 06 Tokenizer Corpus
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_06_tokenizer_corpus.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 4. Let's Build It: Component by Component

### 4.1 A Simple Tokenizer

Real BERT uses WordPiece tokenization, but for our mini-BERT we will use a word-level tokenizer.
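
For comparison, here is a simplified sketch of the greedy longest-match-first lookup that WordPiece tokenizers such as BERT's apply to each word. The function name and the toy vocabulary are our own inventions; real BERT first normalizes text and splits it on whitespace and punctuation, and uses a learned vocabulary of roughly 30,000 pieces. Like the original BERT tokenizer, this sketch maps the whole word to [UNK] if any part of it cannot be matched.

```python
def wordpiece_tokenize(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into WordPiece units.

    Continuation pieces carry a '##' prefix, following BERT's convention.
    Returns [unk] if the word cannot be fully covered by the vocabulary.
    """
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, shrinking until a vocab hit.
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]
        pieces.append(match)
        start = end
    return pieces

toy_vocab = {"play", "##ing", "##ed", "un", "##play", "##able"}  # made-up vocabulary
print(wordpiece_tokenize("playing", toy_vocab))     # ['play', '##ing']
print(wordpiece_tokenize("unplayable", toy_vocab))
print(wordpiece_tokenize("xyz", toy_vocab))
```

Subword pieces let a small vocabulary cover rare or unseen words; our word-level tokenizer avoids this machinery because our corpus is tiny and closed.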

In [None]:
# Small corpus for pre-training our mini-BERT
corpus_sentences = [
    "the cat sat on the mat",
    "the cat purred on the rug",
    "the dog chased the cat",
    "the dog barked at the cat",
    "she went to the bank to deposit money",
    "he walked to the bank to withdraw cash",
    "money was deposited at the bank",
    "cash was withdrawn from the bank",
    "the river bank was covered with grass",
    "the river bank had beautiful flowers",
    "he sat on the bank of the river",
    "the bank of the river was muddy",
    "she loves the beautiful flowers",
    "the flowers in the garden are beautiful",
    "he walked to the garden to see flowers",
    "the cat slept on the rug all day",
    "the dog played in the garden",
    "she deposited cash at the bank",
    "he withdrew money from the bank",
    "the mat was on the floor all day",
]

# Build vocabulary with special tokens
special_tokens = ["[PAD]", "[CLS]", "[SEP]", "[MASK]", "[UNK]"]
all_words = [w for s in corpus_sentences for w in s.split()]
word_counts = Counter(all_words)
vocab_words = sorted(word_counts.keys())
vocab = special_tokens + vocab_words

word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for w, i in word_to_id.items()}
vocab_size = len(vocab)

PAD_ID = word_to_id["[PAD]"]
CLS_ID = word_to_id["[CLS]"]
SEP_ID = word_to_id["[SEP]"]
MASK_ID = word_to_id["[MASK]"]

print(f"Vocabulary size: {vocab_size}")
print(f"Special tokens: [PAD]={PAD_ID}, [CLS]={CLS_ID}, [SEP]={SEP_ID}, [MASK]={MASK_ID}")
print(f"\nSample vocabulary: {list(word_to_id.items())[:10]}")

In [None]:
#@title üéß Listen: Seg 07 Bert Embedding
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_07_bert_embedding.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.2 BERT Input Embeddings

In [None]:
class BERTEmbedding(nn.Module):
    """
    BERT Input Embedding = Token + Segment + Position embeddings.

    All three are learned embedding tables. Their outputs are
    summed element-wise to produce the final input representation.
    """
    def __init__(self, vocab_size, d_model, max_len=128, num_segments=2, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.segment_embedding = nn.Embedding(num_segments, d_model)
        self.position_embedding = nn.Embedding(max_len, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, token_ids, segment_ids):
        """
        Args:
            token_ids: (batch, seq_len) - token indices
            segment_ids: (batch, seq_len) - 0 for sentence A, 1 for sentence B
        Returns:
            (batch, seq_len, d_model) - combined embedding
        """
        seq_len = token_ids.size(1)
        position_ids = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)

        # Look up each embedding
        tok_emb = self.token_embedding(token_ids)
        seg_emb = self.segment_embedding(segment_ids)
        pos_emb = self.position_embedding(position_ids)

        # Sum them
        combined = tok_emb + seg_emb + pos_emb

        # Apply layer norm and dropout
        return self.dropout(self.layer_norm(combined))

In [None]:
# Test the embedding layer
bert_emb = BERTEmbedding(vocab_size, d_model=64)

# Encode: [CLS] the cat sat [SEP] the dog barked [SEP]
sample_tokens = torch.tensor([[CLS_ID, word_to_id["the"], word_to_id["cat"],
                                word_to_id["sat"], SEP_ID, word_to_id["the"],
                                word_to_id["dog"], word_to_id["barked"], SEP_ID]])
sample_segments = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1]])

output = bert_emb(sample_tokens, sample_segments)
print(f"Token IDs shape:   {sample_tokens.shape}")
print(f"Segment IDs shape: {sample_segments.shape}")
print(f"Embedding output:  {output.shape}")

In [None]:
#@title üéß Listen: Seg 08 Embedding Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_08_embedding_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 Visualize the three embedding components
with torch.no_grad():
    tok = bert_emb.token_embedding(sample_tokens)[0].numpy()
    seg = bert_emb.segment_embedding(sample_segments)[0].numpy()
    pos_ids = torch.arange(sample_tokens.size(1)).unsqueeze(0)
    pos = bert_emb.position_embedding(pos_ids)[0].numpy()

fig, axes = plt.subplots(1, 3, figsize=(18, 4))
words_list = ["[CLS]", "the", "cat", "sat", "[SEP]", "the", "dog", "barked", "[SEP]"]

for ax, data, title, cmap in zip(
    axes,
    [tok, seg, pos],
    ["Token Embeddings", "Segment Embeddings", "Position Embeddings"],
    ["Blues", "Greens", "Oranges"]
):
    im = ax.imshow(data[:, :20].T, cmap=cmap, aspect='auto')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_xticks(range(len(words_list)))
    ax.set_xticklabels(words_list, rotation=45, fontsize=9)
    ax.set_ylabel("Dimension")
    plt.colorbar(im, ax=ax, fraction=0.046)

plt.suptitle("BERT Input = Token + Segment + Position", fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Seg 09 Mini Bert
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_09_mini_bert.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.3 Building Mini-BERT

Now let us assemble the full BERT model using the Transformer encoder from the previous notebook.
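
Before assembling it, a back-of-the-envelope parameter count shows where the weights live. This sketch assumes exactly the layer shapes in the `TransformerEncoderBlock` and `MiniBERT` code in the next cells (a bias on every `nn.Linear`, a gain and a bias per `nn.LayerNorm`); the function name `minibert_param_count` is hypothetical.

```python
def minibert_param_count(vocab_size, d_model, d_ff, num_layers, max_len, num_segments=2):
    """Count trainable parameters, assuming the layer shapes used by MiniBERT."""
    def linear(n_in, n_out):
        return n_in * n_out + n_out      # weight matrix + bias

    def layer_norm(d):
        return 2 * d                     # gain + bias

    # Token, segment, and position tables, plus the embedding LayerNorm.
    embeddings = (vocab_size + num_segments + max_len) * d_model + layer_norm(d_model)
    # One encoder block: Q/K/V/O projections, two-layer feed-forward, two LayerNorms.
    # num_heads is absent: heads split d_model among themselves without adding weights.
    block = (4 * linear(d_model, d_model)
             + linear(d_model, d_ff) + linear(d_ff, d_model)
             + 2 * layer_norm(d_model))
    # MLM head: Linear -> GELU -> LayerNorm -> Linear to vocabulary logits.
    mlm_head = linear(d_model, d_model) + layer_norm(d_model) + linear(d_model, vocab_size)
    # NSP head: Linear -> Tanh -> Linear to 2 logits.
    nsp_head = linear(d_model, d_model) + linear(d_model, 2)
    return embeddings + num_layers * block + mlm_head + nsp_head

print(minibert_param_count(vocab_size=10, d_model=4, d_ff=8, num_layers=1, max_len=6))  # → 360
```

If these shape assumptions hold, `minibert_param_count(vocab_size, 128, 512, 4, 64)` should agree with the `total_params` printed once `MiniBERT` is built. With `d_ff = 4 * d_model`, the feed-forward layers hold about twice as many weights as the attention projections in each block.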

In [None]:
# Re-define the building blocks (from Notebook 02)
def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, S, _ = x.shape
        Q = self.W_Q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        out, weights = scaled_dot_product_attention(Q, K, V, mask)
        out = out.transpose(1, 2).contiguous().view(B, S, self.d_model)
        return self.W_O(out), weights

class TransformerEncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        a, w = self.attn(x, mask)
        x = self.norm1(x + self.drop(a))
        x = self.norm2(x + self.drop(self.ff(x)))
        return x, w
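
In `scaled_dot_product_attention`, every score where `mask == 0` is replaced by `-inf` before the softmax, so that key receives exactly zero attention weight. The pure-Python sketch below (the helper `masked_softmax` is hypothetical) shows the effect on one query row whose last two keys are [PAD] tokens.

```python
import math

def masked_softmax(scores, mask):
    """Softmax over one row of attention scores; mask[j] == 0 means 'do not attend'.

    Mirrors scores.masked_fill(mask == 0, float('-inf')) followed by softmax.
    Assumes at least one key is unmasked.
    """
    filled = [s if m != 0 else float("-inf") for s, m in zip(scores, mask)]
    top = max(filled)
    exps = [math.exp(s - top) for s in filled]   # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

# One query attending over 5 keys; the last two keys are [PAD].
row_scores = [2.0, 1.0, 0.5, 3.0, 3.0]
pad_mask = [1, 1, 1, 0, 0]                 # 1 = real token, 0 = padding
row_weights = masked_softmax(row_scores, pad_mask)
print([round(w, 3) for w in row_weights])
```

Without the mask, the two padding keys would take the largest share of the weight here, even though they carry no information.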

In [None]:
class MiniBERT(nn.Module):
    """
    A mini BERT model for educational purposes.

    Architecture:
    - BERT Embeddings (token + segment + position)
    - N Transformer encoder blocks
    - MLM head (predict masked tokens)
    - NSP head (predict if sentence B follows A)
    """
    def __init__(self, vocab_size, d_model=128, num_heads=4, d_ff=512,
                 num_layers=4, max_len=128, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Input embeddings
        self.embedding = BERTEmbedding(vocab_size, d_model, max_len, dropout=dropout)

        # Transformer encoder stack
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # MLM head: predict masked tokens
        self.mlm_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, vocab_size)
        )

        # NSP head: binary classification from [CLS]
        self.nsp_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Tanh(),
            nn.Linear(d_model, 2)
        )

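    def forward(self, token_ids, segment_ids, attention_mask=None):
        # Default mask: every query may attend only to non-[PAD] keys.
        # Shape (batch, 1, 1, seq_len) broadcasts over heads and query positions.
        if attention_mask is None:
            attention_mask = (token_ids != PAD_ID).unsqueeze(1).unsqueeze(2)

        # Get embeddings
        x = self.embedding(token_ids, segment_ids)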

        # Pass through encoder layers
        all_attn_weights = []
        for layer in self.encoder_layers:
            x, attn_w = layer(x, attention_mask)
            all_attn_weights.append(attn_w)

        # MLM predictions (for all positions)
        mlm_logits = self.mlm_head(x)  # (batch, seq_len, vocab_size)

        # NSP prediction (from [CLS] token at position 0)
        cls_output = x[:, 0, :]
        nsp_logits = self.nsp_head(cls_output)  # (batch, 2)

        return mlm_logits, nsp_logits, all_attn_weights

# Create our mini-BERT
model = MiniBERT(
    vocab_size=vocab_size,
    d_model=128,
    num_heads=4,
    d_ff=512,
    num_layers=4,
    max_len=64
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Mini-BERT parameters: {total_params:,}")
print("(Real BERT-Base has about 110 million parameters)")

In [None]:
#@title üéß Listen: Seg 10 Mlm Masking
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_10_mlm_masking.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.4 Data Preparation: MLM Masking

In [None]:
def create_mlm_data(token_ids, mask_id, vocab_size, mask_prob=0.15,
                    special_token_ids=None):
    """
    Apply BERT's masking strategy to a token sequence.

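    The 80-10-10 rule (applied to tokens selected for prediction):
    - 80% of selected tokens → replaced with [MASK]
    - 10% of selected tokens → replaced with a random word
    - 10% of selected tokens → kept unchanged

    Args:
        token_ids: list of token IDs
        mask_id: ID of the [MASK] token
        vocab_size: vocabulary size (random replacements are drawn from it)
        mask_prob: probability of selecting each token for prediction (default: 15%)
        special_token_ids: IDs that are never selected (default: [PAD], [CLS], [SEP])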

    Returns:
        masked_ids: token IDs with masking applied
        mlm_labels: -100 for non-masked positions (ignored in loss),
                    original token ID for masked positions
    """
    if special_token_ids is None:
        special_token_ids = {PAD_ID, CLS_ID, SEP_ID}

    masked_ids = list(token_ids)
    mlm_labels = [-100] * len(token_ids)  # -100 = ignore in loss

    for i, token_id in enumerate(token_ids):
        if token_id in special_token_ids:
            continue  # Never mask special tokens

        if random.random() < mask_prob:
            mlm_labels[i] = token_id  # Store original for loss computation

            r = random.random()
            if r < 0.8:
                masked_ids[i] = mask_id      # 80% → [MASK]
            elif r < 0.9:
                # 10% → random regular word: IDs after all special tokens,
                # so [MASK] and [UNK] are never used as "random words"
                masked_ids[i] = random.randint(len(special_tokens), vocab_size - 1)
            # else: 10% → unchanged (already copied)

    return masked_ids, mlm_labels

# Demo the masking
demo_sentence = "the cat sat on the mat"
demo_ids = [CLS_ID] + [word_to_id[w] for w in demo_sentence.split()] + [SEP_ID]
demo_words = ["[CLS]"] + demo_sentence.split() + ["[SEP]"]

print("Original:    ", " ".join(demo_words))
for trial in range(3):
    masked_ids, labels = create_mlm_data(demo_ids, MASK_ID, vocab_size)
    masked_words = [id_to_word[i] for i in masked_ids]
    label_words = [id_to_word[l] if l != -100 else "_" for l in labels]
    print(f"Masked ({trial+1}):  ", " ".join(masked_words))
    print(f"Labels ({trial+1}):  ", " ".join(label_words))
    print()
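
The thresholds in `create_mlm_data` apply only to tokens that were already selected (with probability `mask_prob`). A quick standalone simulation confirms the resulting split. It uses its own seeded `random.Random` so it does not disturb the notebook's global seed, and the helper name `simulate_80_10_10` is hypothetical.

```python
import random

def simulate_80_10_10(num_selected, seed=0):
    """Fraction of selected tokens that the 80-10-10 rule masks, randomizes, or keeps."""
    rng = random.Random(seed)
    counts = {"mask": 0, "random": 0, "keep": 0}
    for _ in range(num_selected):
        r = rng.random()          # same thresholds as create_mlm_data
        if r < 0.8:
            counts["mask"] += 1
        elif r < 0.9:
            counts["random"] += 1
        else:
            counts["keep"] += 1
    return {k: v / num_selected for k, v in counts.items()}

fractions = simulate_80_10_10(100_000)
print(fractions)
# With mask_prob = 0.15, about 0.15 * 0.8 = 12% of eligible tokens become [MASK],
# and about 1.5% each are randomized or left unchanged.
```

One subtlety: a "random" replacement can occasionally draw the original word itself, so slightly more than 10% of selected tokens end up unchanged.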

In [None]:
#@title üéß Listen: Seg 11 Nsp Data
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_11_nsp_data.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### 4.5 Data Preparation: NSP

In [None]:
def create_nsp_data(corpus_sentences, word_to_id):
    """
    Create Next Sentence Prediction training pairs.

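    For each consecutive pair of sentences (A, B) in the corpus, emit:
    - one positive pair where B actually follows A (IsNext = 1)
    - one negative pair where B is a random sentence other than the true next one (NotNext = 0)
    so the pairs are exactly 50% IsNext and 50% NotNext.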
    """
    pairs = []

    for i in range(len(corpus_sentences) - 1):
        # Positive pair: actual consecutive sentences
        sent_a = corpus_sentences[i].split()
        sent_b = corpus_sentences[i + 1].split()
        pairs.append((sent_a, sent_b, 1))  # IsNext

        # Negative pair: random sentence for B
        random_idx = random.choice([j for j in range(len(corpus_sentences)) if j != i + 1])
        sent_b_random = corpus_sentences[random_idx].split()
        pairs.append((sent_a, sent_b_random, 0))  # NotNext

    return pairs

def prepare_bert_input(sent_a, sent_b, word_to_id, max_len=64):
    """
    Prepare a single BERT input from two sentences.

    Format: [CLS] sent_a [SEP] sent_b [SEP] [PAD]...
    """
    tokens = [CLS_ID] + [word_to_id.get(w, word_to_id["[UNK]"]) for w in sent_a] + [SEP_ID]
    segments = [0] * len(tokens)

    tokens += [word_to_id.get(w, word_to_id["[UNK]"]) for w in sent_b] + [SEP_ID]
    segments += [1] * (len(tokens) - len(segments))

    # Pad to max_len
    pad_len = max_len - len(tokens)
    tokens += [PAD_ID] * pad_len
    segments += [0] * pad_len

    return tokens[:max_len], segments[:max_len]

# Create NSP training data
nsp_pairs = create_nsp_data(corpus_sentences, word_to_id)
print(f"Total NSP training pairs: {len(nsp_pairs)}")
print(f"  IsNext pairs: {sum(1 for _, _, l in nsp_pairs if l == 1)}")
print(f"  NotNext pairs: {sum(1 for _, _, l in nsp_pairs if l == 0)}")
print(f"\nExample pair (IsNext):")
print(f"  A: {' '.join(nsp_pairs[0][0])}")
print(f"  B: {' '.join(nsp_pairs[0][1])}")
print(f"  Label: {'IsNext' if nsp_pairs[0][2] else 'NotNext'}")

In [None]:
#@title üéß Listen: Seg 12 Todo Training Step
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_12_todo_training_step.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 5. 🔧 Your Turn

### TODO: Implement the Training Step

Implement the combined MLM + NSP training step.

In [None]:
def train_step(model, optimizer, token_ids, segment_ids, mlm_labels, nsp_labels):
    """
    Perform one BERT pre-training step.

    Args:
        token_ids: (batch, seq_len) - masked token IDs
        segment_ids: (batch, seq_len) - segment IDs
        mlm_labels: (batch, seq_len) - -100 for non-masked, original ID for masked
        nsp_labels: (batch,) - 1 for IsNext, 0 for NotNext

    Returns:
        total_loss, mlm_loss, nsp_loss
    """
    model.train()
    optimizer.zero_grad()

    # Forward pass
    mlm_logits, nsp_logits, _ = model(token_ids, segment_ids)

    # ============ TODO ============
    # Step 1: Compute MLM loss using F.cross_entropy
    #         Hint: reshape mlm_logits to (batch*seq_len, vocab_size)
    #         and mlm_labels to (batch*seq_len)
    #         F.cross_entropy ignores labels == -100 automatically
    #
    # Step 2: Compute NSP loss using F.cross_entropy
    #         nsp_logits shape: (batch, 2), nsp_labels shape: (batch,)
    #
    # Step 3: Total loss = mlm_loss + nsp_loss
    # ==============================

    mlm_loss = ???  # YOUR CODE HERE
    nsp_loss = ???  # YOUR CODE HERE
    total_loss = ???  # YOUR CODE HERE

    total_loss.backward()
    optimizer.step()

    return total_loss.item(), mlm_loss.item(), nsp_loss.item()

In [None]:
#@title üéß Listen: Seg 13 Todo Training Verify
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_13_todo_training_verify.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# ✅ Verification: test with a single batch
model_test = MiniBERT(vocab_size=vocab_size, d_model=64, num_heads=2, d_ff=128, num_layers=2).to(device)
opt_test = torch.optim.Adam(model_test.parameters(), lr=1e-3)

# Create a small batch
test_tokens = torch.randint(5, vocab_size, (2, 16)).to(device)
test_segments = torch.zeros(2, 16, dtype=torch.long).to(device)
test_mlm_labels = torch.full((2, 16), -100, dtype=torch.long).to(device)
test_mlm_labels[0, 3] = 10  # One masked position
test_mlm_labels[1, 5] = 15  # One masked position
test_nsp_labels = torch.tensor([1, 0]).to(device)

total, mlm, nsp = train_step(model_test, opt_test, test_tokens, test_segments,
                              test_mlm_labels, test_nsp_labels)
assert total > 0, "❌ Loss should be positive"
assert mlm > 0, "❌ MLM loss should be positive"
assert nsp > 0, "❌ NSP loss should be positive"
print(f"✅ Training step works! Total: {total:.3f}, MLM: {mlm:.3f}, NSP: {nsp:.3f}")

In [None]:
#@title üéß Listen: Seg 14 Todo Nsp
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_14_todo_nsp.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

### TODO: Implement the NSP Prediction Check

Implement a function that uses our trained model to check whether sentence B follows sentence A.

In [None]:
def predict_nsp(model, sent_a, sent_b, word_to_id, device, max_len=32):
    """
    Predict whether sent_b follows sent_a (Next Sentence Prediction).

    Args:
        sent_a: list of words (sentence A)
        sent_b: list of words (sentence B)

    Returns:
        is_next_prob: probability that B follows A
        prediction: "IsNext" or "NotNext"
    """
    model.eval()

    # ============ TODO ============
    # Step 1: Build token IDs: [CLS] sent_a [SEP] sent_b [SEP] [PAD]...
    # Step 2: Build segment IDs: 0 for sent_a tokens, 1 for sent_b tokens
    #         (Hint: prepare_bert_input from Section 4.5 already does both.)
    # Step 3: Forward pass through model to get nsp_logits
    # Step 4: Apply softmax to nsp_logits to get probabilities
    # Step 5: Return probability and prediction label
    # ==============================

    tokens, segments = ???, ???  # YOUR CODE HERE

    token_tensor = torch.tensor([tokens], dtype=torch.long).to(device)
    segment_tensor = torch.tensor([segments], dtype=torch.long).to(device)

    with torch.no_grad():
        _, nsp_logits, _ = model(token_tensor, segment_tensor)
        probs = ???  # YOUR CODE HERE: softmax over nsp_logits

    is_next_prob = ???  # YOUR CODE HERE: probability of IsNext (index 1)
    prediction = ???  # YOUR CODE HERE: "IsNext" if prob > 0.5 else "NotNext"

    return is_next_prob, prediction

In [None]:
#@title üéß Listen: Seg 15 Todo Nsp Verify
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_15_todo_nsp_verify.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# ✅ Verification (run AFTER training)
prob, pred = predict_nsp(
    model,
    "the cat sat on the mat".split(),
    "the cat purred on the rug".split(),
    word_to_id, device
)
assert isinstance(prob, float), "❌ Should return a float probability"
assert pred in ["IsNext", "NotNext"], f"❌ Prediction should be 'IsNext' or 'NotNext', got '{pred}'"
print(f"✅ NSP prediction works! Prob(IsNext) = {prob:.3f}, Prediction = {pred}")

In [None]:
#@title üéß Listen: Seg 16 Training
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_16_training.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 6. Training Mini-BERT

In [None]:
def create_training_batch(nsp_pairs, word_to_id, batch_size=8, max_len=32):
    """Create a batch of training examples with MLM masking and NSP labels."""
    batch = random.sample(nsp_pairs, min(batch_size, len(nsp_pairs)))

    all_tokens = []
    all_segments = []
    all_mlm_labels = []
    all_nsp_labels = []

    for sent_a, sent_b, nsp_label in batch:
        tokens, segments = prepare_bert_input(sent_a, sent_b, word_to_id, max_len)
        masked_tokens, mlm_labels = create_mlm_data(tokens, MASK_ID, vocab_size)

        all_tokens.append(masked_tokens)
        all_segments.append(segments)
        all_mlm_labels.append(mlm_labels)
        all_nsp_labels.append(nsp_label)

    return (
        torch.tensor(all_tokens, dtype=torch.long).to(device),
        torch.tensor(all_segments, dtype=torch.long).to(device),
        torch.tensor(all_mlm_labels, dtype=torch.long).to(device),
        torch.tensor(all_nsp_labels, dtype=torch.long).to(device),
    )

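# Training
# Note: each "epoch" here is one optimizer step on a freshly sampled, freshly masked
# batch, not a full pass over all NSP pairs.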
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
EPOCHS = 100
BATCH_SIZE = 16

mlm_losses = []
nsp_losses = []
total_losses = []

for epoch in range(EPOCHS):
    tokens, segments, mlm_labels, nsp_labels = create_training_batch(
        nsp_pairs, word_to_id, BATCH_SIZE, max_len=32
    )
    total, mlm, nsp = train_step(model, optimizer, tokens, segments, mlm_labels, nsp_labels)

    total_losses.append(total)
    mlm_losses.append(mlm)
    nsp_losses.append(nsp)

    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}/{EPOCHS} | Total: {total:.3f} | MLM: {mlm:.3f} | NSP: {nsp:.3f}")

In [None]:
#@title üéß Listen: Seg 17 Training Curves
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_17_training_curves.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 Training curves
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

axes[0].plot(total_losses, color='steelblue', alpha=0.7)
axes[0].set_title("Total Loss", fontsize=13, fontweight='bold')
axes[0].set_xlabel("Epoch")
axes[0].grid(alpha=0.3)

axes[1].plot(mlm_losses, color='coral', alpha=0.7)
axes[1].set_title("MLM Loss", fontsize=13, fontweight='bold')
axes[1].set_xlabel("Epoch")
axes[1].grid(alpha=0.3)

axes[2].plot(nsp_losses, color='forestgreen', alpha=0.7)
axes[2].set_title("NSP Loss", fontsize=13, fontweight='bold')
axes[2].set_xlabel("Epoch")
axes[2].grid(alpha=0.3)

plt.suptitle("Mini-BERT Pre-training Progress", fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
#@title üéß Listen: Seg 18 Final Predictions
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_18_final_predictions.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 7. 🎯 Final Output: Mini-BERT Fills in the Blanks

In [None]:
def predict_masked_word(model, sentence, mask_position, word_to_id, id_to_word, top_k=5):
    """
    Use our trained mini-BERT to predict a masked word.

    Args:
        sentence: string like "the cat sat on the mat"
        mask_position: index of the word to mask (0-indexed, in original sentence)
        top_k: number of predictions to show
    """
    model.eval()
    words = sentence.split()
    original_word = words[mask_position]

    # Build input: [CLS] word1 word2 ... [SEP]
    token_ids = [CLS_ID] + [word_to_id.get(w, word_to_id["[UNK]"]) for w in words] + [SEP_ID]
    # Mask the target position (+1 for [CLS])
    token_ids[mask_position + 1] = MASK_ID
    segment_ids = [0] * len(token_ids)

    # Pad
    max_len = 32
    pad_len = max_len - len(token_ids)
    token_ids += [PAD_ID] * pad_len
    segment_ids += [0] * pad_len

    token_tensor = torch.tensor([token_ids], dtype=torch.long).to(device)
    segment_tensor = torch.tensor([segment_ids], dtype=torch.long).to(device)

    with torch.no_grad():
        mlm_logits, _, _ = model(token_tensor, segment_tensor)

    # Get predictions for the masked position
    masked_logits = mlm_logits[0, mask_position + 1]  # +1 for [CLS]
    probs = F.softmax(masked_logits, dim=-1)
    top_probs, top_ids = probs.topk(top_k)

    return original_word, [(id_to_word[idx.item()], prob.item()) for idx, prob in zip(top_ids, top_probs)]

# Test sentences
test_sentences = [
    ("the cat sat on the mat", 0),      # Mask "the"
    ("the cat sat on the mat", 1),      # Mask "cat"
    ("she went to the bank to deposit money", 4),  # Mask "bank"
    ("the river bank was covered with grass", 2),   # Mask "bank"
    ("the dog barked at the cat", 1),   # Mask "dog"
]

print("=" * 60)
print("🎯 Mini-BERT Masked Word Predictions")
print("=" * 60)

for sentence, mask_pos in test_sentences:
    words = sentence.split()
    masked_display = words.copy()
    masked_display[mask_pos] = "[MASK]"

    original, predictions = predict_masked_word(
        model, sentence, mask_pos, word_to_id, id_to_word
    )

    print(f"\nInput:    {' '.join(masked_display)}")
    print(f"Original: {original}")
    print(f"Top predictions:")
    for word, prob in predictions:
        marker = "✅" if word == original else "  "
        print(f"  {marker} {word:12s} ({prob:.3f})")

In [None]:
#@title üéß Listen: Seg 19 Final Viz
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_19_final_viz.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

In [None]:
# 📊 Visualize predictions as bar charts
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

test_cases = [
    ("the cat sat on the mat", 1, "Predict 'cat'"),
    ("she went to the bank to deposit money", 4, "Predict 'bank' (financial)"),
    ("the river bank was covered with grass", 2, "Predict 'bank' (river)"),
]

for ax, (sentence, mask_pos, title) in zip(axes, test_cases):
    original, predictions = predict_masked_word(model, sentence, mask_pos, word_to_id, id_to_word)
    words_pred = [w for w, p in predictions[:5]]
    probs_pred = [p for w, p in predictions[:5]]
    colors = ['forestgreen' if w == original else 'steelblue' for w in words_pred]

    ax.barh(range(len(words_pred)), probs_pred, color=colors)
    ax.set_yticks(range(len(words_pred)))
    ax.set_yticklabels(words_pred, fontsize=11)
    ax.set_xlabel("Probability")
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.invert_yaxis()

plt.suptitle("Mini-BERT Predictions (green = correct)", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n🎉 Congratulations! You've built and trained BERT from scratch!")
print("   Our mini-BERT learned to predict masked words using bidirectional context.")
print("   Next up: fine-tuning a real BERT model for downstream tasks.")

In [None]:
#@title üéß Listen: Seg 20 Closing
from IPython.display import Audio, display
import os as _os
_f = "/content/narration/seg_20_closing.mp3"
if _os.path.exists(_f):
    display(Audio(_f))
else:
    print("Run the first cell to download narration audio.")

## 8. Reflection and Next Steps

### 🤔 Reflection Questions
1. Why does BERT mask only 15% of tokens, not 50% or 100%? What would happen if we masked too many tokens?
2. The 80-10-10 rule replaces some masked tokens with random words. Why does this help? (Hint: think about the mismatch between pre-training and fine-tuning.)
3. Later research (RoBERTa) showed that NSP does not actually help much. Why do you think that might be?

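### 🏆 Optional Challenges
1. **Static vs. Dynamic Masking**: `create_training_batch` calls `create_mlm_data` on every batch, so our training loop already re-masks randomly each time (dynamic masking, as in RoBERTa). Try static masking instead: mask each pair once up front and reuse those masks at every step. Does performance change?
2. **Whole Word Masking**: Instead of masking individual subword tokens, mask entire words at once. Our tokenizer is word-level, so this first requires splitting words into subword pieces. Implement this variant.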
3. **Scale Up**: Increase d_model to 256 and num_layers to 6. Train for longer. Does the model get noticeably better at predicting masked words?