# Chapter 13 — Transformers (BERT + Question Answering)

**Primary reference:** *TensorFlow in Action* (Thushan Ganegedara), Chapter 13  
This notebook reproduces the main workflows from the chapter and adds a structured explanation of the ideas used in the code.

---

## 1) Summary

This chapter focuses on *Transformer-based* NLP workflows, especially when we do not train a model from scratch:

- A short recap of why **self-attention** is the core computation behind Transformers, and how **positional encoding** compensates for the lack of recurrence.
- **Fine-tuning a pretrained BERT encoder** for a downstream classification task (spam vs ham SMS messages).
- **Question answering (extractive QA)** using a pretrained Transformer (DistilBERT) with the SQuAD v1 dataset: predicting the *start* and *end* token positions of the answer span.

In practice, the chapter highlights a modern pattern:

1. Download a pretrained Transformer + matching tokenizer
2. Format inputs the way the model expects
3. Add a small task-specific head (classification head or span heads)
4. Fine-tune with a relatively small labeled dataset

---

## 2) Transformers recap (core concepts used in this chapter)

### 2.1 Self-attention in one paragraph

Given token representations \(X\in\mathbb{R}^{T\times d}\), self-attention learns three projections:

- \(Q=XW_Q\) (queries)
- \(K=XW_K\) (keys)
- \(V=XW_V\) (values)

Then each token attends to every other token by

\[
\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V.
\]

The softmax weights decide “which tokens matter” when building a new representation for each position.

### 2.2 Why positional encoding exists

Self-attention alone does **not** know token order. Transformers inject order information by adding a position-dependent vector to each token embedding (either sinusoidal or learned). The book discusses the classic sinusoidal version.

### 2.3 BERT vs the original Transformer

BERT is essentially the **encoder stack** of the original Transformer architecture, pretrained on large text corpora. During fine-tuning, we reuse the encoder weights and only learn a small head (plus optionally update the encoder).

### 2.4 Extractive question answering

In extractive QA (like SQuAD), the answer is a span inside the context paragraph. A Transformer model produces token-level representations, then we add two token-wise classifiers:

- one predicts the **start token index**
- one predicts the **end token index**

The final predicted answer is the substring between those token indices.


In [1]:
# Environment check (Colab-friendly)
import os, sys, platform
import tensorflow as tf

print("python:", sys.version.split()[0])
print("platform:", platform.platform())
print("tf:", tf.__version__)
print("GPU available:", bool(tf.config.list_physical_devices("GPU")))


python: 3.12.12
platform: Linux-6.6.105+-x86_64-with-glibc2.35
tf: 2.19.0
GPU available: True


## 3) Quick sanity-check: scaled dot-product attention (toy example)

This section is not meant to be a full Transformer implementation. The goal is to verify the *shape logic* behind attention:

- Inputs: a sequence of token vectors \([batch, T, d]\)
- Outputs: same shape \([batch, T, d_v]\) after attention


In [2]:
import tensorflow as tf
import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V"""
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(dk)  # [B, Tq, Tk]
    if mask is not None:
        scores += (mask * -1e9)
    weights = tf.nn.softmax(scores, axis=-1)
    output = tf.matmul(weights, v)  # [B, Tq, dv]
    return output, weights

# toy batch: 2 sequences, length 5, model dim 8
B, T, d = 2, 5, 8
x = tf.random.normal([B, T, d])

# simple linear projections
Wq = tf.keras.layers.Dense(d, use_bias=False)
Wk = tf.keras.layers.Dense(d, use_bias=False)
Wv = tf.keras.layers.Dense(d, use_bias=False)

q, k, v = Wq(x), Wk(x), Wv(x)
out, w = scaled_dot_product_attention(q, k, v)

print("x:", x.shape)
print("attention output:", out.shape)
print("attention weights:", w.shape)
print("weights row sums (should be 1):", tf.reduce_sum(w, axis=-1)[0].numpy())


x: (2, 5, 8)
attention output: (2, 5, 8)
attention weights: (2, 5, 5)
weights row sums (should be 1): [1.         1.         0.99999994 0.99999994 0.99999994]


## 4) Project 1 — Spam classification with a pretrained BERT encoder (TensorFlow Hub)

### 4.1 Goal

Build a spam classifier for SMS messages with minimal feature engineering by fine-tuning a pretrained BERT encoder.

### 4.2 Dataset

The book uses the **SMS Spam Collection** dataset (ham/spam labeled SMS). Each line contains a label and the message text.

### 4.3 Approach

1. Download + parse the dataset  
2. Create **balanced** validation and test sets (same number of ham and spam)  
3. Build a BERT-based model:
   - BERT preprocessing layer (tokenization + packing)
   - BERT encoder layer (produces pooled output)
   - A small classification head  
4. Train for a few epochs and evaluate on the test set


In [3]:
# If you run this on Colab, install the required libraries.
# (If you already have them, this cell is safe to re-run.)
!pip -q install -U tensorflow_hub tensorflow_text

# Aggressively uninstall existing scikit-learn and imbalanced-learn to avoid conflicts
!pip uninstall -y scikit-learn imbalanced-learn

# Specify compatible versions for scikit-learn and imbalanced-learn
# imbalanced-learn 0.13.0 is compatible with scikit-learn >= 1.6
!pip -q install scikit-learn==1.6.0 imbalanced-learn==0.13.0

Found existing installation: scikit-learn 1.6.0
Uninstalling scikit-learn-1.6.0:
  Successfully uninstalled scikit-learn-1.6.0
Found existing installation: imbalanced-learn 0.13.0
Uninstalling imbalanced-learn-0.13.0:
  Successfully uninstalled imbalanced-learn-0.13.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m

In [4]:
import re
import numpy as np
import pandas as pd
import tensorflow as tf

# Download the dataset (UCI mirror). If it fails, try changing the URL to another mirror.
DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"
zip_path = tf.keras.utils.get_file("smsspamcollection.zip", DATA_URL, extract=True)
# The file `SMSSpamCollection` is inside the extracted directory, which `zip_path` now points to.
# So, we should join `zip_path` directly with the filename.
txt_path = os.path.join(zip_path, "SMSSpamCollection")

texts, labels = [], []
with open(txt_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        parts = line.split("\t", 1)
        if len(parts) != 2:
            continue
        lab, msg = parts[0], parts[1]
        if lab == "ham":
            labels.append(0)
        elif lab == "spam":
            labels.append(1)
        else:
            continue
        texts.append(msg)

texts = np.array(texts, dtype=object)
labels = np.array(labels, dtype=np.int32)

df = pd.DataFrame({"text": texts, "label": labels})
display(df.head())
print("dataset size:", len(df))
print("class counts:\n", df["label"].value_counts())

Unnamed: 0,text,label
0,"Go until jurong point, crazy.. Available only ...",0
1,Ok lar... Joking wif u oni...,0
2,Free entry in 2 a wkly comp to win FA Cup fina...,1
3,U dun say so early hor... U c already then say...,0
4,"Nah I don't think he goes to usf, he lives aro...",0


dataset size: 5574
class counts:
 label
0    4827
1     747
Name: count, dtype: int64


### 4.4 Split data (balanced validation + test)

The chapter emphasizes evaluating on balanced subsets (same number of examples per class).  
Here, I create:

- test: `n_per_class` ham + `n_per_class` spam  
- validation: `n_per_class` ham + `n_per_class` spam (from remaining data)  
- training: everything else (can be imbalanced, then balanced for training)

For training, the book discusses using undersampling (including NearMiss).  
To keep this notebook stable across Colab environments, I implement:

- a **default random undersampling** for training (simple and reliable)
- an **optional NearMiss** path if imbalanced-learn supports extracting selected indices

You can switch between them with a flag.


In [5]:
rng = np.random.default_rng(4321)
n_per_class = 100  # same idea as the book; you can increase if you have more compute

idx_all = np.arange(len(labels))
idx_ham = idx_all[labels == 0]
idx_spam = idx_all[labels == 1]

test_idx = np.concatenate([
    rng.choice(idx_ham, n_per_class, replace=False),
    rng.choice(idx_spam, n_per_class, replace=False),
])
remaining = np.setdiff1d(idx_all, test_idx)

# validation from remaining
rem_ham = remaining[labels[remaining] == 0]
rem_spam = remaining[labels[remaining] == 1]
valid_idx = np.concatenate([
    rng.choice(rem_ham, n_per_class, replace=False),
    rng.choice(rem_spam, n_per_class, replace=False),
])
train_idx = np.setdiff1d(remaining, valid_idx)

train_texts, train_y = texts[train_idx], labels[train_idx]
valid_texts, valid_y = texts[valid_idx], labels[valid_idx]
test_texts, test_y = texts[test_idx], labels[test_idx]

print("train size:", len(train_texts), " | class counts:", np.bincount(train_y))
print("valid size:", len(valid_texts), " | class counts:", np.bincount(valid_y))
print("test  size:", len(test_texts),  " | class counts:", np.bincount(test_y))

# Optional: balance training set with undersampling
USE_NEARMISS = False

if USE_NEARMISS:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from imblearn.under_sampling import NearMiss

    vec = TfidfVectorizer(max_features=5000, ngram_range=(1,2))
    X_tfidf = vec.fit_transform(train_texts)

    nm = NearMiss()
    X_res, y_res = nm.fit_resample(X_tfidf, train_y)

    if hasattr(nm, "sample_indices_"):
        sel = nm.sample_indices_
        train_texts_bal = train_texts[sel]
        train_y_bal = train_y[sel]
        print("NearMiss selected:", len(sel), " | class counts:", np.bincount(train_y_bal))
    else:
        print("NearMiss does not expose sample indices in this environment. Falling back to random undersampling.")
        USE_NEARMISS = False

if not USE_NEARMISS:
    from imblearn.under_sampling import RandomUnderSampler
    rus = RandomUnderSampler(random_state=4321)
    # Use indices as dummy features so we can retrieve selected texts by index
    dummy_X = np.arange(len(train_texts)).reshape(-1, 1)
    X_res, y_res = rus.fit_resample(dummy_X, train_y)
    sel = X_res.flatten()
    train_texts_bal = train_texts[sel]
    train_y_bal = train_y[sel]
    print("Random undersampling:", len(train_texts_bal), " | class counts:", np.bincount(train_y_bal))

train size: 5174  | class counts: [4627  547]
valid size: 200  | class counts: [100 100]
test  size: 200  | class counts: [100 100]


Exception ignored on calling ctypes callback function: <function ThreadpoolController._find_libraries_with_dl_iterate_phdr.<locals>.match_library_callback at 0x7a0958273ce0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/threadpoolctl.py", line 1005, in match_library_callback
    self._make_controller_from_path(filepath)
  File "/usr/local/lib/python3.12/dist-packages/threadpoolctl.py", line 1187, in _make_controller_from_path
    lib_controller = controller_class(
                     ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/threadpoolctl.py", line 114, in __init__
    self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/ctypes/__init__.py", line 379, in __init__
    self._handle = _dlopen(self._name, mode)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
OSError: dlopen() error


Random undersampling: 1094  | class counts: [547 547]


### 4.5 Build tf.data pipelines

For BERT fine-tuning, it is convenient to feed raw strings and let the preprocessing layer handle tokenization and packing.

- Input: a batch of strings
- Output: a batch of labels (0/1)


In [6]:
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 16

def make_text_ds(x, y, shuffle=False):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        ds = ds.shuffle(buffer_size=min(len(x), 5000), seed=4321, reshuffle_each_iteration=True)
    ds = ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
    return ds

tr_ds = make_text_ds(train_texts_bal, train_y_bal, shuffle=True)
va_ds = make_text_ds(valid_texts, valid_y, shuffle=False)
te_ds = make_text_ds(test_texts, test_y, shuffle=False)

for batch_x, batch_y in tr_ds.take(1):
    print("batch text dtype:", batch_x.dtype, "shape:", batch_x.shape)
    print("batch labels:", batch_y[:10].numpy())


batch text dtype: <dtype: 'string'> shape: (16,)
batch labels: [1 0 1 1 0 0 1 1 1 1]


### 4.6 Define the BERT model (preprocess + encoder + classification head)

To avoid handling `vocab.txt` manually, I use the official TF Hub preprocessing module that matches the encoder:

- preprocess: `bert_en_uncased_preprocess`
- encoder: `bert_en_uncased_L-12_H-768_A-12`

The model produces a pooled representation for the whole sequence, then a small dense layer predicts spam/ham.


In [7]:
import tensorflow_hub as hub

preprocess_url = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
encoder_url = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"

bert_preprocess = hub.KerasLayer(preprocess_url, name="bert_preprocess")
bert_encoder = hub.KerasLayer(encoder_url, trainable=True, name="bert_encoder")

text_inp = tf.keras.layers.Input(shape=(), dtype=tf.string, name="text")
enc_inputs = bert_preprocess(text_inp)
enc_outputs = bert_encoder(enc_inputs)

# TF Hub BERT returns a dict-like output; pooled_output corresponds to [CLS]
pooled = enc_outputs["pooled_output"]
x = tf.keras.layers.Dropout(0.1)(pooled)
logits = tf.keras.layers.Dense(1, name="classifier")(x)

spam_model = tf.keras.Model(inputs=text_inp, outputs=logits)
spam_model.summary()


RuntimeError: Op type not registered 'CaseFoldUTF8' in binary running on 800ffd28f4d9. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib (e.g. `tf.contrib.resampler`), accessing should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.

### 4.7 Train and evaluate

The chapter uses a small learning rate and trains for a few epochs.  
Fine-tuning can be slow, so I keep the default batch size moderate.

If you need a quicker run, reduce `EPOCHS` or set `bert_encoder.trainable = False` (feature extraction instead of fine-tuning).


In [None]:
EPOCHS = 3
lr = 3e-5  # common fine-tuning LR; the book uses an even smaller LR

loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
metrics = [
    tf.keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.0),
    tf.keras.metrics.AUC(name="auc")
]

spam_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
    loss=loss_fn,
    metrics=metrics
)

history = spam_model.fit(tr_ds, validation_data=va_ds, epochs=EPOCHS)

print("\nTest evaluation:")
spam_model.evaluate(te_ds, verbose=1)


### 4.8 Quick inference check

I test a few hand-written messages to see if the model behaves reasonably.


In [None]:
samples = [
    "Congratulations! You have won a free ticket. Call now to claim.",
    "Are we still meeting at 7pm tonight?",
    "URGENT! Your account has been selected for a prize. Reply YES.",
    "Ok, I will be there in 10 minutes."
]

probs = tf.sigmoid(spam_model.predict(samples, verbose=0)).numpy().reshape(-1)
pred = (probs >= 0.5).astype(int)

for s, p, yhat in zip(samples, probs, pred):
    label = "spam" if yhat == 1 else "ham"
    print(f"{label:4s} | p(spam)={p:.3f} | {s}")


## 5) Project 2 — Extractive question answering with Hugging Face Transformers (SQuAD v1)

### 5.1 Goal

Given a **question** and a **context paragraph**, predict the answer span inside the context.

### 5.2 Main steps (mirroring the chapter)

1. Load SQuAD v1 with `datasets`
2. Fix known alignment issues in answer character indices
3. Tokenize (question + context) with a fast DistilBERT tokenizer
4. Convert character indices → token indices
5. Build a tf.data pipeline
6. Fine-tune `TFDistilBertForQuestionAnswering`
7. Run a small qualitative test: ask a question and decode the predicted span


In [None]:
# Hugging Face libraries used in the chapter
!pip -q install -U datasets transformers


In [None]:
from datasets import load_dataset
import numpy as np
import tensorflow as tf

dataset = load_dataset("squad")
print(dataset)

# For Colab runtime, it is usually better to start with a subset.
TRAIN_SAMPLES = 4000
TEST_SAMPLES = 800

train_subset = dataset["train"].select(range(TRAIN_SAMPLES))
test_subset = dataset["validation"].select(range(TEST_SAMPLES))

print("train subset:", len(train_subset), " | test subset:", len(test_subset))

# Inspect a sample
i = 0
print("Question:", train_subset[i]["question"])
print("Answer:", train_subset[i]["answers"]["text"][0])
print("Answer start (char):", train_subset[i]["answers"]["answer_start"][0])
print("Context snippet:", train_subset[i]["context"][:200], "...")


### 5.3 Fix alignment issues and compute answer end index

SQuAD provides `answer_start` (character index), but some records have small offsets.
The chapter fixes this by checking the substring in the context and shifting the index if needed.

I implement a correction that tries offsets of 0, -1, -2, +1, +2 (small local search).  
Then I store:

- `answer_start` (corrected)
- `answer_end` (exclusive end position)


In [None]:
def correct_indices_add_end_idx(answers, contexts):
    """Fix answer_start when it is slightly misaligned, and add answer_end."""
    n_ok, n_fix, n_fail = 0, 0, 0
    fixed = []
    for ans, ctx in zip(answers, contexts):
        gold_text = ans["text"][0]
        start = ans["answer_start"][0]
        # try small shifts around the given start
        candidates = [0, -1, -2, 1, 2]
        found = None
        for off in candidates:
            s = start + off
            e = s + len(gold_text)
            if s >= 0 and e <= len(ctx) and ctx[s:e] == gold_text:
                found = s
                break
        if found is None:
            # keep original start, but mark as failed (rare)
            found = start
            n_fail += 1
        elif found == start:
            n_ok += 1
        else:
            n_fix += 1

        fixed.append({
            "text": [gold_text],
            "answer_start": found,
            "answer_end": found + len(gold_text),
        })
    print(f"alignment ok: {n_ok} | fixed: {n_fix} | failed-to-fix: {n_fail}")
    return fixed

train_questions = list(train_subset["question"])
train_contexts  = list(train_subset["context"])
train_answers   = correct_indices_add_end_idx(list(train_subset["answers"]), train_contexts)

test_questions = list(test_subset["question"])
test_contexts  = list(test_subset["context"])
test_answers   = correct_indices_add_end_idx(list(test_subset["answers"]), test_contexts)


### 5.4 Tokenize (DistilBERT)

DistilBERT uses the same WordPiece-style tokenization family as BERT (sub-word tokens).  
The *fast* tokenizer provides the `char_to_token` helper, which is important for mapping answer character positions to token indices.


In [None]:
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Sanity check: how [CLS] + [SEP] are inserted
context_demo = "This is the context"
question_demo = "This is the question"
tok_demo = tokenizer(context_demo, question_demo, return_tensors="tf")
print("input_ids:", tok_demo["input_ids"].shape)
print("tokens:", tokenizer.convert_ids_to_tokens(tok_demo["input_ids"][0].numpy()))


### 5.5 Encode the subset and map character indices to token indices

Tokenization returns padded/truncated sequences.  
Next, we compute token-level start/end positions using `char_to_token`.


In [None]:
MAX_LEN = 384  # 512 is standard, but 384 is often enough and faster for Colab

train_enc = tokenizer(
    train_contexts,
    train_questions,
    truncation=True,
    padding="max_length",
    max_length=MAX_LEN,
    return_tensors="tf"
)
test_enc = tokenizer(
    test_contexts,
    test_questions,
    truncation=True,
    padding="max_length",
    max_length=MAX_LEN,
    return_tensors="tf"
)

def add_token_positions(encodings, answers, tokenizer, max_len):
    start_positions = []
    end_positions = []
    n_truncated = 0

    for i, ans in enumerate(answers):
        start_char = ans["answer_start"]
        end_char = ans["answer_end"] - 1  # inclusive end char for char_to_token

        start_tok = encodings.char_to_token(i, start_char)
        end_tok = encodings.char_to_token(i, end_char)

        # If the answer is truncated away, char_to_token returns None.
        # A stable fallback is to use 0 (the [CLS] position).
        if start_tok is None or end_tok is None:
            n_truncated += 1
            start_tok = 0
            end_tok = 0

        start_positions.append(start_tok)
        end_positions.append(end_tok)

    encodings.update({
        "start_positions": tf.convert_to_tensor(start_positions, dtype=tf.int32),
        "end_positions": tf.convert_to_tensor(end_positions, dtype=tf.int32),
    })
    print("answers truncated (mapped to [CLS]):", n_truncated)

add_token_positions(train_enc, train_answers, tokenizer, MAX_LEN)
add_token_positions(test_enc, test_answers, tokenizer, MAX_LEN)

print("train input_ids:", train_enc["input_ids"].shape)
print("train start_positions:", train_enc["start_positions"].shape)


### 5.6 tf.data pipeline

The model expects:

- inputs: `(input_ids, attention_mask)`
- outputs: `(start_positions, end_positions)`

I keep the pipeline simple with `from_tensor_slices`.


In [None]:
BATCH_SIZE_QA = 8
AUTOTUNE = tf.data.AUTOTUNE

def make_qa_ds(enc):
    x = (enc["input_ids"], enc["attention_mask"])
    y = (enc["start_positions"], enc["end_positions"])
    return tf.data.Dataset.from_tensor_slices((x, y))

full_train_ds = make_qa_ds(train_enc).shuffle(2048, seed=4321)

# split: last 10% as validation
n_train = int(0.9 * TRAIN_SAMPLES)
train_ds = full_train_ds.take(n_train).batch(BATCH_SIZE_QA).prefetch(AUTOTUNE)
valid_ds = full_train_ds.skip(n_train).batch(BATCH_SIZE_QA).prefetch(AUTOTUNE)

test_ds = make_qa_ds(test_enc).batch(BATCH_SIZE_QA).prefetch(AUTOTUNE)

for (x_ids, x_mask), (y_s, y_e) in train_ds.take(1):
    print("input_ids:", x_ids.shape, "attention_mask:", x_mask.shape)
    print("start/end:", y_s.shape, y_e.shape)


### 5.7 Define and train the QA model

The chapter uses `TFDistilBertForQuestionAnswering`.  
In TensorFlow/Keras training, it is convenient to make the model output a tuple of tensors (start_logits, end_logits).  
I wrap the Hugging Face model inside a small Keras Functional model so the outputs are plain tensors.


In [None]:
from transformers import TFDistilBertForQuestionAnswering

hf_qa = TFDistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
hf_qa.config.return_dict = False  # prefer tuple outputs where possible

def tf_wrap_model(model, max_len):
    input_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="input_ids")
    attention_mask = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="attention_mask")
    out = model([input_ids, attention_mask])  # (start_logits, end_logits, ...)
    start_logits, end_logits = out[0], out[1]
    return tf.keras.Model(inputs=[input_ids, attention_mask], outputs=[start_logits, end_logits])

qa_model = tf_wrap_model(hf_qa, MAX_LEN)
qa_model.summary()

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
opt = tf.keras.optimizers.Adam(learning_rate=3e-5)

qa_model.compile(
    optimizer=opt,
    loss=[loss_fn, loss_fn],
    metrics=[
        [tf.keras.metrics.SparseCategoricalAccuracy(name="start_acc")],
        [tf.keras.metrics.SparseCategoricalAccuracy(name="end_acc")]
    ]
)

EPOCHS_QA = 2
history_qa = qa_model.fit(train_ds, validation_data=valid_ds, epochs=EPOCHS_QA)


### 5.8 Evaluate and ask the model a question

Evaluation metrics here are token-index accuracies (start/end).  
A more realistic QA metric is Exact Match / F1, but the chapter focuses on the start/end heads directly.

Then I decode one sample prediction into text.


In [None]:
print("\nTest evaluation:")
qa_model.evaluate(test_ds, verbose=1)

def ask_bert(sample_input, tokenizer):
    """Decode predicted start/end token indices into a text answer."""
    start_logits, end_logits = qa_model.predict(sample_input, verbose=0)
    start_idx = int(np.argmax(start_logits[0]))
    end_idx = int(np.argmax(end_logits[0]))

    # ensure valid ordering
    if end_idx < start_idx:
        end_idx = start_idx

    input_ids = sample_input[0].numpy()[0]
    # decode span (inclusive)
    span_ids = input_ids[start_idx:end_idx+1]
    answer = tokenizer.decode(span_ids, skip_special_tokens=True).strip()
    return start_idx, end_idx, answer

i = 5
sample_q = test_questions[i]
sample_c = test_contexts[i]
gold = test_answers[i]["text"][0]

sample_input = (
    test_enc["input_ids"][i:i+1],
    test_enc["attention_mask"][i:i+1],
)

s_idx, e_idx, pred_answer = ask_bert(sample_input, tokenizer)

print("Question:", sample_q)
print("\nGold answer:", gold)
print("\nPredicted:", pred_answer)
print("\nContext snippet:", sample_c[:400], "...")


### 5.9 Saving the tokenizer and the pretrained QA model

The wrapper `qa_model` is a standard Keras model, but it contains a Hugging Face model as a layer.
To keep saving simple, I save the underlying Hugging Face model and tokenizer using `save_pretrained()`.


In [None]:
import os

save_dir = os.path.join("models", "distilbert_qa")
os.makedirs(save_dir, exist_ok=True)

tokenizer.save_pretrained(save_dir)
hf_qa.save_pretrained(save_dir)

print("Saved tokenizer + model to:", save_dir)


## 6) Takeaways

- Transformers replace recurrence with self-attention, so they can model long-range dependencies while staying highly parallelizable.
- Pretrained Transformer encoders (BERT / DistilBERT) drastically reduce the amount of task-specific training needed.
- For classification, a pooled sequence representation + a small dense head is usually enough.
- For extractive QA, the model predicts start/end positions. Even basic token-index accuracy can be informative, but qualitative checks are still necessary.
- Libraries matter: TF Hub makes BERT fine-tuning approachable in Keras; Hugging Face makes advanced Transformer workflows (like QA) accessible with minimal code.
