# Chapter 13: Transformers

## 1️⃣ Chapter Overview

In Chapter 5, we introduced the Transformer architecture. In this chapter, we dive deeper into the ecosystem that Transformers have created. We move beyond building them from scratch to leveraging **Pretrained Models**, specifically **BERT** (Bidirectional Encoder Representations from Transformers).

We will cover two major practical applications:
1.  **Spam Classification:** Using a pretrained **BERT** model from TensorFlow Hub to classify SMS messages.
2.  **Question Answering (QA):** Using **DistilBERT** (a lighter version of BERT) via the **Hugging Face Transformers** library to answer questions based on a context paragraph (SQuAD dataset).

**Key Machine Learning Concepts:**
* **Transfer Learning in NLP:** Using models pretrained on massive corpora (like Wikipedia) for downstream tasks.
* **BERT Architecture:** Masked Language Modeling (MLM) and Next Sentence Prediction (NSP).
* **Tokenization:** WordPiece tokenization and special tokens (`[CLS]`, `[SEP]`).
* **Span Prediction:** How QA models predict start and end indices of an answer.

**Practical Skills:**
* Using **TensorFlow Hub** to load pretrained models.
* Using **Hugging Face Transformers** and **Datasets** libraries.
* Handling class imbalance with undersampling techniques.
* Preprocessing data for QA tasks (aligning character indices to token indices).

## 2️⃣ Theoretical Explanation

### 2.1 BERT (Bidirectional Encoder Representations from Transformers)
BERT is essentially the **Encoder** stack of the original Transformer architecture, but deeper and trained on massive amounts of text. 

**Pretraining Tasks:**
Instead of predicting the next word (like standard Language Modeling), BERT uses two tasks:
1.  **Masked Language Modeling (MLM):** Randomly mask 15% of tokens in the input and ask the model to predict them based on the context from *both* left and right directions.
2.  **Next Sentence Prediction (NSP):** Given two sentences A and B, predict if B naturally follows A.
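
The MLM corruption step can be sketched in a few lines of plain Python. This is a minimal illustration, not BERT's actual data pipeline: the 80/10/10 replacement rule follows the original BERT paper, while `mask_tokens`, the toy `VOCAB`, and the seed are hypothetical names chosen for this example.

```python
import random

MASK = "[MASK]"
VOCAB = ["cat", "dog", "tree", "car", "sky"]  # toy vocabulary (hypothetical)

def mask_tokens(tokens, mask_prob=0.15, seed=0):
    """Return (corrupted tokens, labels); labels are None where no prediction is required."""
    rng = random.Random(seed)
    corrupted, labels = list(tokens), [None] * len(tokens)
    for i, tok in enumerate(tokens):
        # Special tokens are never selected; other tokens are selected with prob. mask_prob
        if tok in ("[CLS]", "[SEP]") or rng.random() >= mask_prob:
            continue
        labels[i] = tok                        # the model must recover the original token
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = MASK                # 80%: replace with [MASK]
        elif roll < 0.9:
            corrupted[i] = rng.choice(VOCAB)   # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return corrupted, labels

tokens = ["[CLS]"] + "the quick brown fox jumps over the lazy dog".split() + ["[SEP]"]
corrupted, labels = mask_tokens(tokens, mask_prob=0.15, seed=1)
print(corrupted)
print(labels)
```

Only positions with a non-`None` label contribute to the MLM loss; every other position is passed through untouched.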

**Embeddings:**
BERT combines three types of embeddings:
1.  **Token Embeddings:** A learned vector looked up by the ID of the word/sub-word.
2.  **Segment Embeddings:** Distinguishes between Sentence A and Sentence B.
3.  **Position Embeddings:** Learned vectors indicating the position of tokens.
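
As a rough sketch of how the three embeddings combine, the snippet below sums one vector from each lookup table per token. The tables are random stand-ins for learned weights, the sizes are toy values, and `embed` is a hypothetical helper; real BERT also applies layer normalization and dropout after the sum, which is omitted here.

```python
import random

rng = random.Random(0)
HIDDEN = 4  # toy hidden size (BERT-base uses 768)

def table(rows):
    """A randomly initialised lookup table standing in for learned weights."""
    return [[rng.uniform(-1, 1) for _ in range(HIDDEN)] for _ in range(rows)]

token_table = table(10)     # one row per vocabulary entry
segment_table = table(2)    # one row per segment (A or B)
position_table = table(8)   # one row per position in the sequence

def embed(token_ids, segment_ids):
    """Sum token, segment and position vectors element-wise for each token."""
    out = []
    for pos, (tok, seg) in enumerate(zip(token_ids, segment_ids)):
        rows = zip(token_table[tok], segment_table[seg], position_table[pos])
        out.append([t + s + p for t, s, p in rows])
    return out

token_ids = [1, 5, 2, 7, 2]      # toy IDs; ID 2 appears twice
segment_ids = [0, 0, 0, 1, 1]    # first three tokens in segment A, the rest in B
vectors = embed(token_ids, segment_ids)
print(len(vectors), len(vectors[0]))
```

Note that the two occurrences of token ID 2 end up with different vectors, because their segment and position differ.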

### 2.2 DistilBERT
DistilBERT is a smaller, faster, and cheaper version of BERT. It is trained using **Knowledge Distillation**, where a small student model (DistilBERT) is trained to reproduce the output distribution of a large teacher model (BERT). Its authors report that it retains about 97% of BERT's language-understanding performance while being 40% smaller and 60% faster.
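
To make "reproduce the behavior of the teacher" concrete, here is a minimal sketch of a soft-target distillation loss: the cross-entropy between temperature-softened teacher and student distributions. The temperature value and the function names are assumptions for illustration; DistilBERT's actual training objective combines a term like this with additional losses.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over logits divided by a temperature (higher = softer distribution)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Cross-entropy between the softened teacher and student distributions."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(p_teacher, p_student))

teacher = [4.0, 1.0, 0.5]
close = [3.5, 1.2, 0.4]   # student that roughly agrees with the teacher
far = [0.5, 4.0, 1.0]     # student that prefers a different class

print(distillation_loss(close, teacher))
print(distillation_loss(far, teacher))
```

The student that agrees with the teacher gets the lower loss, which is the signal that pushes the small model toward the large model's predictions.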

### 2.3 Question Answering (Span Prediction)
In Extractive QA, the model does not generate text. Instead, given a Context (Paragraph) and a Question, it predicts:
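* **Start Logits:** An unnormalized score for each token being the *start* of the answer.
* **End Logits:** An unnormalized score for each token being the *end* of the answer.

Applying a softmax over the sequence turns each set of logits into a probability distribution over token positions.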
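
Taking the argmax of the start and end logits independently can produce an end index that comes before the start index. One common remedy, sketched below under the assumption that a span's score is the sum of its start and end logits, is to search only over valid spans. `best_span` and `max_answer_len` are hypothetical names for this example.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end) maximising start_logit + end_logit with start <= end."""
    best, best_score = (0, 0), float("-inf")
    for s, s_score in enumerate(start_logits):
        # Only consider end positions at or after the start, within the length limit
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best

start_logits = [0.1, 0.2, 3.0, 0.3, 0.1]
end_logits = [0.0, 2.5, 0.1, 2.0, 0.2]

# Independent argmax would give start=2, end=1 (an empty span).
print(best_span(start_logits, end_logits))  # (2, 3)
```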


## 3️⃣ Setup

This chapter requires specific libraries. We need `tensorflow-text` for BERT preprocessing and `transformers` + `datasets` for the Hugging Face section.

In [None]:
!pip install -q tensorflow-text tensorflow-hub transformers datasets imbalanced-learn

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # Required for BERT preprocessing
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Ensure reproducibility
np.random.seed(42)
tf.random.set_seed(42)

## 4️⃣ Part 1: Spam Classification with BERT

We will build a binary classifier to detect Spam SMS messages. We will use a pretrained BERT encoder from TensorFlow Hub and attach a simple classification head.

In [None]:
# 1. Download Data
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"
zip_path = tf.keras.utils.get_file("smsspamcollection.zip", origin=url, extract=True)
data_path = os.path.join(os.path.dirname(zip_path), "SMSSpamCollection")

# 2. Load Data
inputs = []
labels = []

with open(data_path, 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split('\t')
        if len(parts) == 2:
            label_str, message = parts  # avoid shadowing the tensorflow_text import named `text`
            inputs.append(message)
            labels.append(1 if label_str == 'spam' else 0)

inputs = np.array(inputs)
labels = np.array(labels)

print(f"Total samples: {len(inputs)}")
print(f"Spam count: {sum(labels)}")
print(f"Ham count: {len(labels) - sum(labels)}")

### 4.1 Handling Class Imbalance
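The SMS dataset is heavily imbalanced: ham messages far outnumber spam. We use `RandomUnderSampler` to balance it by randomly discarding Ham (non-spam) examples until both classes are the same size. Because we undersample before splitting, the test set is balanced too, so the accuracy reported later reflects a 50/50 class mix rather than real-world SMS traffic.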

In [None]:
from imblearn.under_sampling import RandomUnderSampler

# Reshape inputs for sampling
inputs_reshaped = inputs.reshape(-1, 1)

# Undersample majority class
rus = RandomUnderSampler(random_state=42)
inputs_res, labels_res = rus.fit_resample(inputs_reshaped, labels)

inputs_res = inputs_res.flatten()

# Split Data
X_train, X_test, y_train, y_test = train_test_split(
    inputs_res, labels_res, test_size=0.2, random_state=42, stratify=labels_res
)

print(f"Balanced Train Size: {len(X_train)}")
print(f"Balanced Test Size: {len(X_test)}")
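
For intuition, random undersampling can be approximated in plain Python as follows. This is only a sketch of the idea, not the `imbalanced-learn` implementation; `random_undersample` is a hypothetical helper and the toy messages are made up.

```python
import random

def random_undersample(samples, labels, seed=42):
    """Keep every minority example and an equally sized random subset of each other class."""
    rng = random.Random(seed)
    by_class = {}
    for sample, label in zip(samples, labels):
        by_class.setdefault(label, []).append(sample)
    n_min = min(len(group) for group in by_class.values())  # size of the smallest class
    out_samples, out_labels = [], []
    for label, group in by_class.items():
        kept = group if len(group) == n_min else rng.sample(group, n_min)
        out_samples.extend(kept)
        out_labels.extend([label] * n_min)
    return out_samples, out_labels

msgs = [f"ham {i}" for i in range(8)] + ["spam 0", "spam 1"]
lbls = [0] * 8 + [1] * 2
bal_msgs, bal_lbls = random_undersample(msgs, lbls)
print(bal_lbls.count(0), bal_lbls.count(1))  # 2 2
```

The trade-off is that most of the majority class is thrown away, which is acceptable here because the pretrained encoder needs relatively few labeled examples.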

### 4.2 Building the BERT Model
We use TF Hub to load:
1.  **BERT Preprocessor:** Handles tokenization and packing inputs (creating `input_word_ids`, `input_mask`, `input_type_ids`).
2.  **BERT Encoder:** The actual model that outputs pooled and sequence representations.
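
To show what "packing" means, the sketch below builds the three input arrays for a single sentence that has already been converted to token IDs. It is a simplified stand-in for the TF Hub preprocessor, not its implementation: `pack_inputs` and the tiny `seq_length` are assumptions, and the two body IDs are illustrative. The special-token IDs (101, 102, 0) are those of the uncased BERT vocabulary.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # [CLS], [SEP], [PAD] in the uncased BERT vocabulary

def pack_inputs(token_ids, seq_length=8):
    """Wrap one tokenised sentence in [CLS]/[SEP], then truncate and pad to seq_length."""
    body = token_ids[: seq_length - 2]          # leave room for the two special tokens
    word_ids = [CLS_ID] + body + [SEP_ID]
    n_real = len(word_ids)
    padding = seq_length - n_real
    return {
        "input_word_ids": word_ids + [PAD_ID] * padding,
        "input_mask": [1] * n_real + [0] * padding,   # 1 = real token, 0 = padding
        "input_type_ids": [0] * seq_length,           # single segment, so all zeros
    }

packed = pack_inputs([7592, 2088], seq_length=8)  # illustrative token IDs
for name, values in packed.items():
    print(name, values)
```

For a sentence pair, `input_type_ids` would be 1 for the tokens of the second segment; spam classification only ever uses one segment.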

We select the "Small BERT" to keep training fast.

In [None]:
# TF Hub URLs for BERT Preprocessor and Encoder
# Using Small BERT for speed
tfhub_handle_preprocess = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
tfhub_handle_encoder = "https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1"

def build_classifier_model():
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
    
    # 1. Preprocessing Layer
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    
    # 2. Encoder Layer (BERT)
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    
    # 3. Classification Head
    # We use 'pooled_output': the final [CLS] hidden state passed through a dense + tanh layer,
    # which serves as a fixed-size summary of the whole input
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(0.1)(net)
    net = tf.keras.layers.Dense(1, activation=None, name='classifier')(net)
    
    return tf.keras.Model(text_input, net)

classifier_model = build_classifier_model()
classifier_model.summary()

### 4.3 Training
We use `BinaryCrossentropy(from_logits=True)` because our final layer has no activation.

In [None]:
loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
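# Predictions are logits, so the decision boundary is 0.0 (equivalent to probability 0.5)
metrics = tf.keras.metrics.BinaryAccuracy(threshold=0.0)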

# Using AdamW optimizer is standard for BERT, but standard Adam works for this demo
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)

classifier_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

history = classifier_model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_test, y_test),
    epochs=3
)
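
The effect of `from_logits=True` can be checked by hand. The sketch below, which uses plain Python rather than Keras internals, compares binary cross-entropy computed from a raw logit (using the standard numerically stable formula) with the same loss computed from the sigmoid probability. It also shows that thresholding a logit at 0 is equivalent to thresholding the probability at 0.5. The function names are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_from_probability(y, p):
    """Binary cross-entropy given a probability p in (0, 1)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_from_logit(y, z):
    """Same loss from a raw logit z: max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))

for y, z in [(1, 2.0), (0, 2.0), (1, -3.0)]:
    from_logit = bce_from_logit(y, z)
    from_prob = bce_from_probability(y, sigmoid(z))
    print(y, z, round(from_logit, 6), round(from_prob, 6), z > 0, sigmoid(z) > 0.5)
```

Working from logits avoids computing `log(sigmoid(z))` directly, which can lose precision for large positive or negative `z`.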

### 4.4 Inference
Let's test the model on some examples.

In [None]:
examples = [
    "Reply to this message to win a free vacation!",
    "Hey man, are we still meeting for lunch?",
    "Urgent! Your bank account has been compromised. Click here."
]

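# Wrap the list in a string tensor so Keras treats it as a single batch of inputs
results = classifier_model.predict(tf.constant(examples))
results = tf.sigmoid(results).numpy()  # convert logits to probabilities

for message, score in zip(examples, results):
    prob = score[0]
    label = "SPAM" if prob > 0.5 else "HAM"
    print(f"'{message}' -> {label} ({prob:.4f})")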

## 5️⃣ Part 2: Question Answering with Hugging Face

We will now use the **Hugging Face** ecosystem to build a Question Answering system using **DistilBERT**.

In [None]:
from transformers import DistilBertTokenizerFast, TFDistilBertForQuestionAnswering
from datasets import load_dataset

# 1. Load SQuAD Dataset
# SQuAD (Stanford Question Answering Dataset) contains Context-Question-Answer triplets
dataset = load_dataset("squad")

# Inspect a sample
sample = dataset['train'][0]
print("Context:", sample['context'])
print("Question:", sample['question'])
print("Answer:", sample['answers'])

### 5.1 Preprocessing (The Tricky Part)
The dataset provides the **character start index** of the answer. However, the model works with **token indices**. We need to map character indices to token indices.

We use `DistilBertTokenizerFast` because it provides methods to map between chars and tokens.
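
Before looking at the full preprocessing function, it may help to see the character-to-token mapping on a toy example. The sketch below uses a whitespace "tokenizer" instead of WordPiece, and `char_span_to_token_span` is a hypothetical helper; the `offsets` list plays the role of the tokenizer's offset mapping, where each entry is the `(start, end)` character range of one token.

```python
def char_span_to_token_span(offsets, start_char, end_char):
    """offsets[i] = (start, end) character positions of token i; end is exclusive."""
    start_token = end_token = None
    for i, (tok_start, tok_end) in enumerate(offsets):
        if start_token is None and tok_end > start_char:
            start_token = i   # first token that reaches past the answer start
        if tok_start < end_char:
            end_token = i     # last token that begins before the answer end
    return start_token, end_token

context = "Transformers were introduced in 2017."
words = context.rstrip(".").split(" ")  # toy whitespace tokenizer, not WordPiece

# Build the character offsets of each word inside the context
offsets, cursor = [], 0
for w in words:
    begin = context.index(w, cursor)
    offsets.append((begin, begin + len(w)))
    cursor = begin + len(w)

answer = "introduced in 2017"
start_char = context.index(answer)
end_char = start_char + len(answer)

span = char_span_to_token_span(offsets, start_char, end_char)
print(span)                               # (2, 4)
print(words[span[0] : span[1] + 1])
```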

In [None]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

def preprocess_function(examples):
    # Tokenize question/context pairs.
    # truncation="only_second" truncates only the context (second sequence), never the question.
    # Contexts longer than max_length are split into overlapping windows (stride=128),
    # so one example can produce several features.
    encodings = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=384,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Simplification: only the first annotated answer is used. Features whose window
    # does not fully contain the answer are labeled with index 0 (the [CLS] token).
    # The loop below computes start_positions and end_positions as token indices.
    
    start_positions = []
    end_positions = []

    # Iterate over batch
    for i, offsets in enumerate(encodings.pop("offset_mapping")):
        # Find the index of the original example this feature belongs to
        sample_index = encodings["overflow_to_sample_mapping"][i]
        answers = examples["answers"][sample_index]
        
        # If no answer, set indices to 0 (CLS token)
        if len(answers["answer_start"]) == 0:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Start/End character index
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Find sequence tokens (where context tokens are)
            sequence_ids = encodings.sequence_ids(i)
            
            # Find the start and end of the context in tokens
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # Check if answer is fully inside the context span
            if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Map char to token
                idx = context_start
                while idx <= context_end and offsets[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offsets[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

    encodings["start_positions"] = start_positions
    encodings["end_positions"] = end_positions
    return encodings

# Apply preprocessing (Subset for speed)
tokenized_squad = dataset['train'].select(range(1000)).map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names)
tokenized_val = dataset['validation'].select(range(100)).map(preprocess_function, batched=True, remove_columns=dataset['validation'].column_names)

### 5.2 Convert to TensorFlow Dataset
Hugging Face datasets can be easily converted to `tf.data.Dataset`.

In [None]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator(return_tensors="tf")

tf_train_set = tokenized_squad.to_tf_dataset(
    columns=["attention_mask", "input_ids", "start_positions", "end_positions"],
    shuffle=True,
    batch_size=8,
    collate_fn=data_collator,
)

tf_val_set = tokenized_val.to_tf_dataset(
    columns=["attention_mask", "input_ids", "start_positions", "end_positions"],
    shuffle=False,
    batch_size=8,
    collate_fn=data_collator,
)

### 5.3 Model Training
We load `TFDistilBertForQuestionAnswering`. This model has the DistilBERT backbone plus a span classification head (two outputs: start logits and end logits).

In [None]:
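model = TFDistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
# No loss argument: because start_positions/end_positions are part of the input features,
# the Hugging Face model computes its own span loss internally.
model.compile(optimizer=optimizer)

model.fit(tf_train_set, validation_data=tf_val_set, epochs=1)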
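
As a rough illustration of the objective a span prediction head is typically trained with, the sketch below averages two softmax cross-entropy terms: one for the start position and one for the end position. This is a plain-Python approximation for intuition, not the library's code, and the function names are hypothetical.

```python
import math

def cross_entropy(logits, target_index):
    """Negative log-probability of target_index under softmax(logits)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target_index]

def span_loss(start_logits, end_logits, start_position, end_position):
    """Average of the start-position and end-position cross-entropies."""
    return 0.5 * (cross_entropy(start_logits, start_position)
                  + cross_entropy(end_logits, end_position))

start_logits = [0.1, 4.0, 0.2, 0.1]
end_logits = [0.0, 0.3, 3.5, 0.2]

print(span_loss(start_logits, end_logits, 1, 2))  # logits agree with the labels: small loss
print(span_loss(start_logits, end_logits, 3, 0))  # logits disagree with the labels: large loss
```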

### 5.4 Asking DistilBERT a Question
Let's run inference manually to see the result. Since we fine-tuned on only 1,000 examples for one epoch, expect rough answers.

In [None]:
def ask_question(question, context):
    inputs = tokenizer(question, context, return_tensors="tf")
    outputs = model(**inputs)

    # Get the token index with the highest score for start and end
    answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
    answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])

    # Convert tokens back to string
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    return tokenizer.decode(predict_answer_tokens)

context = """
Transformers are a type of deep learning model introduced in 2017. 
They are primarily used in natural language processing tasks. 
BERT and GPT are famous examples of Transformers.
"""

q1 = "When were Transformers introduced?"
q2 = "What fields are they used in?"

print(f"Q: {q1}\nA: {ask_question(q1, context)}\n")
print(f"Q: {q2}\nA: {ask_question(q2, context)}")

## 6️⃣ Chapter Summary

* **BERT** revolutionized NLP by making Transfer Learning practical. We can take a generic pretrained model and fine-tune it on a small dataset (like our SMS spam collection of roughly 5,500 messages, undersampled to a balanced subset) to get high performance.
* **Architecture:** BERT uses the Encoder stack. It requires specific inputs: Token IDs, Mask, and Segment IDs. It outputs a sequence of vectors and a pooled vector (used for classification).
* **Hugging Face:** The `transformers` library abstracts away much of the complexity of loading models and tokenizers. The `datasets` library handles downloading datasets and preprocessing them in batches with `map`.
* **Question Answering:** This is a span prediction task. The model predicts the start and end tokens of the answer within the context text.