# Chapter 10 — Natural Language Processing with TensorFlow: Language Modeling

In the previous chapter, I treated text as an input for a *classification* problem (sentiment analysis).  
This chapter changes the objective: instead of predicting a label for a whole text, I train a model to **predict the next piece of text**.

The core task is **language modeling**:
> Given a sequence of previous tokens, predict the next token.

Language models are important because they force the network to learn structure in text:
- frequent patterns (common phrases),
- grammar-like constraints (what tends to follow what),
- and longer dependencies (topics, entities, references across sentences).

In the book, the workflow is built around a practical dataset of children’s stories (CBTest) and a GRU-based model that learns to generate the next token.  
A key design choice in this chapter is to represent text using **character n-grams** (in particular, bigrams), which keeps vocabulary size manageable.

This notebook reproduces the end-to-end process:
1) download and inspect the dataset,
2) read stories into Python,
3) convert stories → n-grams → token IDs,
4) define a `tf.data` pipeline that creates fixed-length training windows,
5) implement and train a GRU language model,
6) evaluate using **perplexity**,
7) generate new text using greedy decoding and beam search.


## 1) Summary

### 1.1 What is a language model in practice?
A language model is trained to approximate:

\[
P(w_{t} \mid w_{1}, w_{2}, \dots, w_{t-1})
\]

In a neural setting, this means:
- take a sequence of tokens,
- compress context into a hidden state (RNN/GRU/LSTM),
- output a probability distribution over the vocabulary for the next token.
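
The last step in the list above, turning raw scores into a probability distribution, can be sketched in plain Python. This is a minimal illustration, not the notebook's model: the toy vocabulary and the score values are made up.

```python
import math

def softmax(logits):
    """Turn raw scores into a probability distribution (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy vocabulary and made-up scores for the next token
vocab = ["th", "e ", "ca", "t "]
logits = [2.0, 0.5, 1.0, -1.0]

probs = softmax(logits)
for tok, p in zip(vocab, probs):
    print(f"{tok!r}: {p:.3f}")

best = vocab[probs.index(max(probs))]
print("most likely next token:", repr(best))
```

The probabilities sum to 1, and the token with the largest score also gets the largest probability.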

### 1.2 Why character n-grams help here
If I tokenize at the word level, vocabulary can explode quickly:
- unusual spellings,
- rare names,
- punctuation and formatting artifacts.

The chapter uses **n-grams** as subword tokens. With small *n*, the set of possible tokens is limited and out-of-vocabulary (OOV) tokens become less frequent, because unseen words can often be constructed from known n-grams.

In this notebook I use the book’s `get_ngrams(text, n)` function that splits the string into chunks with stride `n`:
- for `n=2`, tokens are character bigrams like `"ch"`, `"ap"`, `"te"`, `"r "`, etc.

This reduces vocabulary size substantially compared with word tokens.
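
A small sketch of the OOV argument, assuming a made-up one-sentence corpus: the word `that` never appears in it, yet both of its bigrams do.

```python
def get_ngrams(text, n):
    # Same chunking as the book's helper: non-overlapping pieces of length n
    return [text[i:i + n] for i in range(0, len(text), n)]

# Hypothetical toy corpus (not from CBTest)
train_text = "the cat sat on the mat"

known_words = set(train_text.split())
known_bigrams = set(get_ngrams(train_text, 2))

new_word = "that"
new_bigrams = get_ngrams(new_word, 2)

print("word seen in training?   ", new_word in known_words)
print("bigrams of the new word: ", new_bigrams)
print("all bigrams seen before? ", all(g in known_bigrams for g in new_bigrams))
```

One caveat: because the chunks do not overlap, the bigrams a word produces depend on where it starts in the string. The same word at an odd offset is split differently.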

### 1.3 How the dataset becomes training examples
Stories are variable-length. A model needs fixed-length sequences, so I create training windows:

- take a long token ID sequence,
- create windows of length `n_seq + 1`,
- use first `n_seq` IDs as inputs,
- use the same sequence shifted by 1 as targets.

This matches the “next-token prediction” objective.
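
The windowing steps above can be sketched in plain Python before involving `tf.data`. The helper name `make_windows` and the ID values are illustrative assumptions.

```python
def make_windows(ids, n_seq, shift=1):
    """Return (inputs, targets) pairs cut from one token-ID sequence."""
    pairs = []
    # The last valid start leaves room for a full window of n_seq + 1 IDs
    for start in range(0, len(ids) - n_seq, shift):
        window = ids[start:start + n_seq + 1]
        pairs.append((window[:-1], window[1:]))  # targets = inputs shifted by 1
    return pairs

ids = [10, 11, 12, 13, 14, 15]  # made-up token IDs
pairs = make_windows(ids, n_seq=3)
for inputs, targets in pairs:
    print(inputs, "->", targets)
```

Each target sequence starts one position after its input sequence, so every input position has a "next token" label.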

### 1.4 How I judge model quality (perplexity)
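Accuracy can be misleading in language modeling because it only checks whether the top-ranked token is correct:
- a model that mostly predicts common tokens can reach a deceptively high accuracy,
- and accuracy ignores how much probability the model assigns to the correct token, which is what matters when sampling or ranking continuations.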

Perplexity is the standard metric:

\[
\text{perplexity} = \exp(\text{cross-entropy})
\]

Lower perplexity indicates the model is assigning higher probability to the correct next tokens on average.
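
A worked example may make the formula concrete. The sketch below computes perplexity from the probabilities a model assigned to the correct tokens. All probability values are invented for illustration.

```python
import math

def perplexity(probs_of_correct_tokens):
    """exp of the mean negative log-probability of the correct tokens."""
    ce = -sum(math.log(p) for p in probs_of_correct_tokens) / len(probs_of_correct_tokens)
    return math.exp(ce)

# A model that spreads probability uniformly over 4 tokens
uniform = perplexity([0.25, 0.25, 0.25, 0.25])

# Two hypothetical models that both rank the correct token first
# (so they have identical accuracy) but with different confidence
confident = perplexity([0.9, 0.9, 0.9, 0.9])
hesitant = perplexity([0.4, 0.4, 0.4, 0.4])

print(f"uniform over 4 tokens: {uniform:.3f}")
print(f"confident model:       {confident:.3f}")
print(f"hesitant model:        {hesitant:.3f}")
```

A uniform guess over 4 tokens gives a perplexity of 4, which is why perplexity is often read as an "effective number of choices". The two other models would score the same accuracy, yet their perplexities differ.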

### 1.5 How text is generated
After training, I convert the model into an inference setup that supports:
- taking an initial prompt,
- predicting the next token,
- feeding the prediction back recursively.

I implement two decoding methods from the chapter:
- **greedy decoding**: always pick the highest-probability next token,
- **beam search**: keep the top-*k* candidate sequences to reduce the chance of early greedy mistakes.


## 2) Setup

Imports, seeds, and a few helper utilities.

In [1]:
import os
import tarfile
import random
import pickle
from pathlib import Path
from collections import Counter

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

SEED = 4321
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

print("TensorFlow:", tf.__version__)


TensorFlow: 2.19.0


## 3) Download and extract the CBTest dataset

The chapter uses the CBTest dataset (a set of children’s book stories).

Download URL used in the book (tgz archive):
- `http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz`

This section:
- downloads the archive if it does not exist,
- extracts it to a local directory,
- and then lists the available files so I can pick train/valid/test files.


In [3]:
import requests

DATA_DIR = Path("data") / "lm"
DATA_DIR.mkdir(parents=True, exist_ok=True)

tgz_path = DATA_DIR / "CBTest.tgz"
cbtest_dir = DATA_DIR / "CBTest"

url = "http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz"

if not tgz_path.exists():
    print("Downloading:", url)
    r = requests.get(url, stream=True, timeout=60)  # fail instead of hanging if the server is unreachable
    r.raise_for_status()
    with open(tgz_path, "wb") as f:
        for chunk in r.iter_content(chunk_size=1024 * 1024):
            if chunk:
                f.write(chunk)
    print("Saved:", tgz_path)
else:
    print("Archive already exists:", tgz_path)

if not cbtest_dir.exists():
    print("Extracting archive...")
    with tarfile.open(tgz_path) as tarf:
        tarf.extractall(DATA_DIR)
    print("Extracted to:", cbtest_dir)
else:
    print("Extracted folder already exists:", cbtest_dir)

# List available text files
all_txt = sorted([p for p in cbtest_dir.rglob("*.txt")])
print("Number of .txt files:", len(all_txt))
for p in all_txt[:20]:
    print(" -", p.relative_to(cbtest_dir))


Downloading: http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz


HTTPError: 404 Client Error: Not Found for url: http://www.thespermwhale.com/jaseweston/babi/CBTest.tgz

## 4) Reading stories into Python

The CBTest files contain multiple stories.  
In the chapter, stories are separated by a line that begins with `_BOOK_TITLE_`.

I follow the same idea:
- maintain a list `current` for the current story lines,
- when `_BOOK_TITLE_` appears, close out the previous story (if any) and start a new one,
- join story lines into a single long string.

I also include a fallback: if `_BOOK_TITLE_` markers are not present (dataset variants can differ), the code treats the entire file as a single “story”.


In [None]:
def read_stories(path):
    stories = []
    current = []

    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        for row in f:
            row = row.rstrip("\n")

            if row.startswith("_BOOK_TITLE_"):
                if len(current) > 0:
                    stories.append(" ".join(current).strip())
                current = []
                continue

            # Remove leading line numbers if present (common in bAbI/CBTest formats)
            # Example: "1 Mary went to the bathroom."
            parts = row.split(" ", 1)
            if len(parts) == 2 and parts[0].isdigit():
                row = parts[1]

            if row.strip():
                current.append(row)

    if len(current) > 0:
        stories.append(" ".join(current).strip())

    # If the file has no _BOOK_TITLE_ markers, all lines end up in one story;
    # an empty file yields an empty list.
    return stories

def pick_split_files(cbtest_dir):
    txts = sorted(cbtest_dir.rglob("*.txt"))

    def first_match(*keys):
        # Compare against the "_"-separated parts of the file stem, because a plain
        # substring test for "test" would match every name containing "cbtest".
        for p in txts:
            parts = p.stem.lower().split("_")
            if any(k in parts for k in keys):
                return p
        return None

    # Heuristic: pick files with train/valid/test in the name if present
    train = first_match("train")
    valid = first_match("valid", "val")
    test  = first_match("test")

    # If not found, fall back to the first three text files
    if (train is None or valid is None or test is None) and len(txts) >= 3:
        train = train or txts[0]
        valid = valid or txts[1]
        test  = test  or txts[2]

    if train is None or valid is None or test is None:
        raise FileNotFoundError(f"Could not find train/valid/test .txt files under {cbtest_dir}")

    return train, valid, test

train_file, valid_file, test_file = pick_split_files(cbtest_dir)
print("Train file:", train_file)
print("Valid file:", valid_file)
print("Test file :", test_file)

train_stories = read_stories(train_file)
valid_stories = read_stories(valid_file)
test_stories  = read_stories(test_file)

print(f"Collected {len(train_stories)} stories (train)")
print(f"Collected {len(valid_stories)} stories (valid)")
print(f"Collected {len(test_stories)} stories (test)")

# Peek at one story
if len(train_stories) > 0:
    print("\nSample story snippet:")
    print(train_stories[min(10, len(train_stories)-1)][:500])


## 5) N-grams (character chunks)

The chapter defines a very simple n-gram function:

```python
def get_ngrams(text, n):
    return [text[i:i+n] for i in range(0, len(text), n)]
```

This is not a sliding window; it is a chunking operation with stride `n`.  
For `n=2`, it splits the string into non-overlapping bigrams.

I reproduce the function exactly and test it on a small example.


In [None]:
def get_ngrams(text, n):
    return [text[i:i+n] for i in range(0, len(text), n)]

test_string = "I like chocolates"
print("Original:", test_string)
for n in [1, 2, 3]:
    print(f"{n}-grams:", get_ngrams(test_string, n)[:30], "...")


### 5.1 Vocabulary comparison: words vs n-grams (quick sanity check)

This is a small measurement step to confirm the motivation:
- word vocabulary size can be very large,
- character n-gram vocabulary size is much smaller for small *n*.


In [None]:
def word_vocab_size(texts):
    vocab = set()
    for t in texts:
        for w in t.lower().split():
            vocab.add(w)
    return len(vocab)

def ngram_vocab_size(texts, n):
    vocab = set()
    for t in texts:
        for g in get_ngrams(t.lower(), n):
            vocab.add(g)
    return len(vocab)

ngrams = 2
wv = word_vocab_size(train_stories[:50])  # sample to keep this quick
ngv = ngram_vocab_size(train_stories[:50], ngrams)

print("Word vocab size (sample):", wv)
print(f"{ngrams}-gram vocab size (sample):", ngv)


## 6) Tokenization: n-grams → token IDs

The chapter uses `Tokenizer` from Keras.

A practical issue is that n-grams can contain spaces and punctuation, so default splitting (by space) can break tokens.  
To avoid that, I join tokens using a delimiter (tab) and configure the tokenizer with `split="\t"` and `filters=""`.

I also apply a minimum frequency cutoff (`MIN_FREQ=10`) similar to the chapter’s goal:
- n-grams that appear fewer than `MIN_FREQ` times are mapped to an OOV token.
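
The frequency cutoff can be sketched without Keras. This is a plain-Python illustration with `collections.Counter`, not the `Tokenizer` API. The token list is made up, and the ID layout (0 for padding, 1 for OOV, then tokens by descending frequency) is an assumption chosen to mirror the convention used in this notebook.

```python
from collections import Counter

OOV = "[UNK]"
MIN_FREQ = 2  # small cutoff so the toy example shows the effect

# Made-up bigram tokens; "zq" occurs only once
tokens = ["th", "e ", "th", "zq", "e ", "th"]
counts = Counter(tokens)

# ID 0 is left free for padding, ID 1 is the OOV token,
# and frequent tokens get IDs from 2 upward by descending count
vocab = {OOV: 1}
for tok, cnt in counts.most_common():
    if cnt >= MIN_FREQ:
        vocab[tok] = len(vocab) + 1

ids = [vocab.get(tok, vocab[OOV]) for tok in tokens]

print("counts:", dict(counts))
print("vocab :", vocab)
print("ids   :", ids)
```

The rare token is replaced by the OOV ID, so the model's output layer only needs one unit per frequent token plus the OOV and padding slots.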


In [None]:
from tensorflow.keras.preprocessing.text import Tokenizer

MIN_FREQ = 10
OOV_TOKEN = "[UNK]"
DELIM = "\t"

def stories_to_ngram_docs(stories, n):
    # Lowercase for consistency; keep punctuation as characters.
    docs = []
    for s in stories:
        grams = get_ngrams(s.lower(), n)
        docs.append(DELIM.join(grams))
    return docs

train_docs = stories_to_ngram_docs(train_stories, ngrams)
valid_docs = stories_to_ngram_docs(valid_stories, ngrams)
test_docs  = stories_to_ngram_docs(test_stories,  ngrams)

tokenizer = Tokenizer(filters="", lower=False, split=DELIM, oov_token=OOV_TOKEN)
tokenizer.fit_on_texts(train_docs)

# Build a frequency map
word_counts = tokenizer.word_counts  # token -> count

# Identify tokens that meet the minimum frequency (excluding OOV token)
kept_tokens = [tok for tok, cnt in word_counts.items() if cnt >= MIN_FREQ and tok != OOV_TOKEN]

print("Total unique n-grams:", len(word_counts))
print(f"Kept tokens (freq >= {MIN_FREQ}):", len(kept_tokens))

# Look up the OOV id. The Tokenizer gives the OOV token index 1 and numbers the
# remaining tokens from 2 in descending frequency; index 0 is reserved for padding.
oov_id = tokenizer.word_index[OOV_TOKEN]
print("OOV token id:", oov_id)

# Convert documents to sequences, then remap rare tokens to OOV id
def docs_to_sequences_with_minfreq(docs, tokenizer, word_counts, min_freq, oov_id):
    seqs = []
    for d in docs:
        ids = tokenizer.texts_to_sequences([d])[0]
        toks = d.split(DELIM)
        # toks and ids should align
        out = []
        for tok, idx in zip(toks, ids):
            if word_counts.get(tok, 0) < min_freq:
                out.append(oov_id)
            else:
                out.append(idx)
        seqs.append(out)
    return seqs

train_seqs = docs_to_sequences_with_minfreq(train_docs, tokenizer, word_counts, MIN_FREQ, oov_id)
valid_seqs = docs_to_sequences_with_minfreq(valid_docs, tokenizer, word_counts, MIN_FREQ, oov_id)
test_seqs  = docs_to_sequences_with_minfreq(test_docs,  tokenizer, word_counts, MIN_FREQ, oov_id)

# Vocabulary size for the model (max id + 1 to include padding=0)
vocab_size = max(max(s) for s in train_seqs if len(s) > 0) + 1
print("Model vocab_size (including padding):", vocab_size)

# Show a quick example
example = train_stories[min(10, len(train_stories)-1)][:120]
eg_grams = get_ngrams(example.lower(), ngrams)
eg_doc = DELIM.join(eg_grams)
eg_ids = docs_to_sequences_with_minfreq([eg_doc], tokenizer, word_counts, MIN_FREQ, oov_id)[0]

print("\nOriginal snippet:", example)
print("n-grams:", eg_grams[:30], "...")
print("IDs   :", eg_ids[:30], "...")


## 7) `tf.data` pipeline: fixed-length windows for next-token prediction

Stories have different lengths, so I create a dataset of fixed-length training windows.

For each story sequence:
- create sliding windows of length `n_seq + 1` (where the extra 1 is the next-token target),
- split window into `(inputs, targets)` where targets are inputs shifted by 1.

This reproduces the idea shown in the chapter listing.

I keep the shift at 1 (classic next-token prediction), and I use `drop_remainder=True` so every example has the same length.


In [None]:
AUTOTUNE = tf.data.AUTOTUNE

def get_tf_pipeline(data_seq, n_seq, batch_size=64, shift=1, shuffle=True):
    """Create a tf.data pipeline from a list of variable-length integer sequences."""
    ragged = tf.ragged.constant(data_seq, dtype=tf.int32)
    ds = tf.data.Dataset.from_tensor_slices(ragged)

    if shuffle:
        ds = ds.shuffle(buffer_size=min(len(data_seq), 1024), seed=SEED, reshuffle_each_iteration=True)

    # Convert each ragged sequence into many fixed-length windows
    ds = ds.flat_map(
        lambda x: tf.data.Dataset.from_tensor_slices(x)
            .window(n_seq + 1, shift=shift, drop_remainder=True)
            .flat_map(lambda w: w.batch(n_seq + 1, drop_remainder=True))
    )

    # Split into inputs and targets
    ds = ds.map(lambda w: (w[:-1], w[1:]), num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(AUTOTUNE)
    return ds

n_seq = 100
BATCH_SIZE = 64

train_ds = get_tf_pipeline(train_seqs, n_seq=n_seq, batch_size=BATCH_SIZE, shift=1, shuffle=True)
valid_ds = get_tf_pipeline(valid_seqs, n_seq=n_seq, batch_size=BATCH_SIZE, shift=1, shuffle=False)
test_ds  = get_tf_pipeline(test_seqs,  n_seq=n_seq, batch_size=BATCH_SIZE, shift=1, shuffle=False)

# Sanity check one batch
x0, y0 = next(iter(train_ds))
print("inputs :", x0.shape, x0.dtype)
print("targets:", y0.shape, y0.dtype)
print("First input sequence (first 20 ids):", x0[0, :20].numpy().tolist())
print("First target sequence (first 20 ids):", y0[0, :20].numpy().tolist())


### 7.1 Save hyperparameters (for reproducibility)

The chapter explicitly saves key preprocessing hyperparameters:
- n in n-grams,
- vocabulary size,
- sequence length.

I do the same so that generation notebooks or later reuse can stay consistent.


In [None]:
MODEL_DIR = Path("models")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

hyperparams = {
    "ngrams": ngrams,
    "vocab_size": vocab_size,
    "n_seq": n_seq,
    "min_freq": MIN_FREQ,
    "oov_token": OOV_TOKEN,
}

print("n_grams uses n={}".format(hyperparams["ngrams"]))
print("Vocabulary size: {}".format(hyperparams["vocab_size"]))
print("Sequence length for model: {}".format(hyperparams["n_seq"]))

with open(MODEL_DIR / "text_hyperparams.pkl", "wb") as f:
    pickle.dump(hyperparams, f)


## 8) Model: GRU-based language model

The chapter builds a GRU language model using:
- an embedding layer (learn token vectors),
- a GRU with `return_sequences=True` (predict at every time step),
- a dense head with a final softmax that outputs a probability distribution over the vocabulary.

The output has shape `(batch, n_seq, vocab_size)`.

Because this is a multi-class next-token problem, I use:
- `SparseCategoricalCrossentropy`,
- and track perplexity as an additional metric.


In [None]:
class PerplexityMetric(tf.keras.metrics.Mean):
    """Running perplexity: exp of the mean per-token cross-entropy."""

    def __init__(self, name="perplexity", **kwargs):
        super().__init__(name=name, **kwargs)
        self.cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=False, reduction="none"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true: (B, T) token ids, y_pred: (B, T, V) probabilities
        ce = self.cross_entropy(y_true, y_pred)  # (B, T): one loss value per token
        ce = tf.reduce_mean(ce, axis=-1)         # (B,): average over time steps
        return super().update_state(ce, sample_weight=sample_weight)

    def result(self):
        # Mean tracks the average cross-entropy; exponentiate to get perplexity
        return tf.exp(super().result())

def build_language_model(vocab_size):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=512, input_shape=(None,)),
        tf.keras.layers.GRU(1024, return_state=False, return_sequences=True),
        tf.keras.layers.Dense(512, activation="relu"),
        tf.keras.layers.Dense(vocab_size, name="final_out"),
        tf.keras.layers.Activation("softmax"),
    ])
    return model

lm = build_language_model(vocab_size)
lm.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    metrics=["accuracy", PerplexityMetric()],
)

lm.summary()


## 9) Training and evaluation

This is a compute-heavy model (embedding 512 + GRU 1024), and the dataset creates many windows.

To keep this notebook runnable in Colab:
- I compute an approximate number of training windows,
- then set `steps_per_epoch` and `validation_steps`.

If you want a stronger result, increasing epochs and steps will usually help.
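
The window count used for `steps_per_epoch` follows from simple arithmetic: a sequence of length `L` yields `1 + (L - (n_seq + 1)) // shift` full windows when `L >= n_seq + 1`, and none otherwise. The sketch below checks that closed form against a brute-force count. The function names and the example story length are illustrative assumptions.

```python
def count_windows_closed_form(length, n_seq, shift=1):
    if length < n_seq + 1:
        return 0
    return 1 + (length - (n_seq + 1)) // shift

def count_windows_brute_force(length, n_seq, shift=1):
    # Count every start position that still leaves room for n_seq + 1 tokens
    return sum(1 for start in range(0, length, shift) if start + n_seq + 1 <= length)

# Example: one story of 250 token IDs, windows of 100 inputs + 1 target
windows = count_windows_closed_form(250, n_seq=100, shift=1)
batch_size = 64
print("windows:", windows)
print("full batches:", windows // batch_size)
```

With `shift=1`, almost every position starts a window, which is why the number of training examples is close to the number of tokens.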


In [None]:
def count_windows(seqs, n_seq, shift=1):
    total = 0
    for s in seqs:
        L = len(s)
        if L >= n_seq + 1:
            total += 1 + (L - (n_seq + 1)) // shift
    return total

train_windows = count_windows(train_seqs, n_seq, shift=1)
valid_windows = count_windows(valid_seqs, n_seq, shift=1)
test_windows  = count_windows(test_seqs,  n_seq, shift=1)

print("Approx window counts")
print("train:", train_windows)
print("valid:", valid_windows)
print("test :", test_windows)

steps_per_epoch = max(1, train_windows // BATCH_SIZE)
validation_steps = max(1, valid_windows // BATCH_SIZE)
test_steps = max(1, test_windows // BATCH_SIZE)

print("steps_per_epoch:", steps_per_epoch)
print("validation_steps:", validation_steps)
print("test_steps:", test_steps)


In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

EPOCHS = 5
ckpt_path = MODEL_DIR / "ch10_language_model_best.keras"

callbacks = [
    ModelCheckpoint(str(ckpt_path), monitor="val_perplexity", mode="min", save_best_only=True),
    EarlyStopping(monitor="val_perplexity", mode="min", patience=2, restore_best_weights=True),
]

history = lm.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS,
    steps_per_epoch=min(steps_per_epoch, 2000),      # cap for runtime
    validation_steps=min(validation_steps, 500),
    callbacks=callbacks,
    verbose=1,
)


### 9.1 Plot learning curves

In [None]:
def plot_history(hist, keys):
    plt.figure(figsize=(10, 4))
    for k in keys:
        if k in hist.history:
            plt.plot(hist.history[k], label=k)
    plt.xlabel("Epoch")
    plt.legend()
    plt.grid(True)
    plt.show()

plot_history(history, ["loss", "val_loss"])
plot_history(history, ["perplexity", "val_perplexity"])
plot_history(history, ["accuracy", "val_accuracy"])


### 9.2 Evaluate on the test split

The test split is built from a separate CBTest file (selected automatically by filename heuristics).  
I report loss, accuracy, and perplexity.


In [None]:
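# return_dict=True keeps metric names and values paired; with Keras 3,
# lm.metrics_names no longer lists each compiled metric separately.
test_metrics = lm.evaluate(test_ds, steps=min(test_steps, 500), verbose=1, return_dict=True)
for name, value in test_metrics.items():
    print(f"{name:12s}: {value:.4f}")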


## 10) Inference model (stateful decoding)

During training, the GRU processes a whole sequence and returns predictions for every time step.

For generation, I want a model that supports:
- feeding the previous hidden state,
- predicting the next token,
- returning the updated state.

So I build a small inference graph that reuses the learned weights:
- `Embedding` weights,
- `GRU` weights,
- output head weights.

This is the model I will call recursively for greedy decoding and beam search.
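
Why carrying the state forward is enough can be shown with a toy recurrence. The sketch below uses a made-up scalar update rule, not a GRU, and checks that processing a sequence in one pass gives the same final state as priming on a prefix and then feeding the remaining tokens one at a time.

```python
import math

def rnn_step(state, x, w_h=0.5, w_x=0.1):
    """Hypothetical scalar recurrence standing in for a GRU cell."""
    return math.tanh(w_h * state + w_x * x)

def run(tokens, state=0.0):
    """Feed tokens one by one, returning the final state."""
    for x in tokens:
        state = rnn_step(state, x)
    return state

tokens = [3, 1, 4, 1, 5]

# Training-style: the whole sequence in one pass
full = run(tokens)

# Inference-style: prime on a prompt, then feed one token at a time
state = run(tokens[:3])
for x in tokens[3:]:
    state = run([x], state)

print("full pass :", full)
print("step-wise :", state)
```

Because the state summarizes everything seen so far, generation only needs the last token and the previous state at each step instead of re-reading the whole prefix.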


In [None]:
def build_inference_model_from_trained(lm_model, vocab_size):
    # Extract layers by order (as built in build_language_model)
    emb_layer, gru_layer, dense1, dense_out, act = lm_model.layers[:5]

    inp_tokens = tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name="inp_tokens")
    inp_state = tf.keras.layers.Input(shape=(gru_layer.units,), dtype=tf.float32, name="inp_state")

    # `return_state` is a constructor argument, not a call argument, and the
    # trained GRU was created with return_state=False. Build a twin GRU that
    # also returns its final state, then copy the trained weights into it.
    gru_infer = tf.keras.layers.GRU(
        gru_layer.units, return_sequences=True, return_state=True, name="gru_infer"
    )

    x = emb_layer(inp_tokens)
    x, out_state = gru_infer(x, initial_state=inp_state)
    # Weights exist only after the first call. This is a copy, so rebuild the
    # inference model if `lm_model` is trained further.
    gru_infer.set_weights(gru_layer.get_weights())
    x = dense1(x)
    x = dense_out(x)
    probs = act(x)

    inf_model = tf.keras.Model(inputs=[inp_tokens, inp_state], outputs=[probs, out_state], name="lm_inference")
    return inf_model

lm_infer = build_inference_model_from_trained(lm, vocab_size)
lm_infer.summary()


## 11) Text generation utilities

Because tokens are character n-grams, decoding back to text means:
- convert token IDs → token strings,
- join the token strings directly.

I keep a mapping from ID → token using the tokenizer’s `index_word`.  
Padding ID 0 is ignored in decoding.


In [None]:
index_word = tokenizer.index_word  # id -> token (includes OOV token)
word_index = tokenizer.word_index  # token -> id

def ids_to_text(ids):
    parts = []
    for i in ids:
        if i == 0:
            continue
        tok = index_word.get(int(i), OOV_TOKEN)
        parts.append(tok)
    return "".join(parts)

def text_to_ids(text, n=ngrams):
    grams = get_ngrams(text.lower(), n)
    doc = DELIM.join(grams)
    ids = docs_to_sequences_with_minfreq([doc], tokenizer, word_counts, MIN_FREQ, oov_id)[0]
    return ids

# Quick sanity check: encode -> decode
prompt = "chapter i. "
ids = text_to_ids(prompt)
print("Prompt:", prompt)
print("IDs:", ids[:20])
print("Decoded (approx):", ids_to_text(ids[:20]))


## 12) Greedy decoding (next-token recursion)

Greedy decoding is straightforward:
- start with a prompt,
- run the prompt through the inference model to get an initial state,
- then repeatedly:
  - predict next token distribution,
  - pick the highest probability token,
  - feed it back as the next input.

This is fast, but it can get stuck in repetitive loops if early choices are suboptimal.
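
The looping behavior is easy to reproduce on a toy example. The sketch below assumes a made-up table in which the next-token distribution depends only on the current token; the real model conditions on the whole prefix through the GRU state.

```python
# Hypothetical next-token table: current token -> {candidate: probability}
NEXT = {
    "th": {"e ": 0.6, "at": 0.4},
    "e ": {"th": 0.7, "ca": 0.3},
}

def greedy_decode(start, n_steps):
    """Always append the most probable next token."""
    out = [start]
    for _ in range(n_steps):
        dist = NEXT[out[-1]]
        out.append(max(dist, key=dist.get))
    return "".join(out)

text = greedy_decode("th", 5)
print(repr(text))
```

Since the most probable choices point at each other, greedy decoding cycles between the same two tokens and never explores the lower-probability alternatives.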


In [None]:
def greedy_generate(lm_infer, prompt_text, n_steps=200):
    ids = text_to_ids(prompt_text)
    if len(ids) == 0:
        ids = [oov_id]

    state = tf.zeros((1, 1024), dtype=tf.float32)

    # Prime the model with the whole prompt
    inp = tf.constant([ids], dtype=tf.int32)
    probs, state = lm_infer([inp, state])
    next_id = int(tf.argmax(probs[:, -1, :], axis=-1).numpy()[0])
    out_ids = ids + [next_id]

    for _ in range(n_steps - 1):
        inp = tf.constant([[out_ids[-1]]], dtype=tf.int32)
        probs, state = lm_infer([inp, state])
        next_id = int(tf.argmax(probs[:, -1, :], axis=-1).numpy()[0])
        out_ids.append(next_id)

    return ids_to_text(out_ids)

generated = greedy_generate(lm_infer, "chapter i. ", n_steps=250)
print(generated[:1500])


## 13) Beam search decoding

Beam search keeps several candidate continuations instead of only one.  
At each step, it expands each candidate by the top-*k* tokens, then keeps the best beams by log-probability.

This typically produces more coherent text than greedy decoding, especially when early steps are ambiguous.
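
A toy case shows how a wider beam can recover from a poor first choice. The sketch below is a simplified beam search over a made-up probability table. It expands every token of each beam rather than only the top-*k*, and the function names are illustrative.

```python
import math

def next_probs(prefix):
    """Hypothetical toy model: next-token probabilities given the prefix."""
    table = {
        (): {"a": 0.6, "b": 0.4},
        ("a",): {"x": 0.5, "y": 0.5},
        ("b",): {"z": 0.9, "w": 0.1},
    }
    return table[tuple(prefix)]

def beam_search(n_steps, beam_width):
    beams = [([], 0.0)]  # (token sequence, cumulative log-probability)
    for _ in range(n_steps):
        candidates = []
        for seq, logp in beams:
            for tok, p in next_probs(seq).items():
                candidates.append((seq + [tok], logp + math.log(p)))
        # Keep only the highest-scoring sequences
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams[0]

greedy_seq, greedy_logp = beam_search(n_steps=2, beam_width=1)  # width 1 = greedy
beam_seq, beam_logp = beam_search(n_steps=2, beam_width=2)

print("greedy:", greedy_seq, f"p={math.exp(greedy_logp):.2f}")
print("beam  :", beam_seq, f"p={math.exp(beam_logp):.2f}")
```

Greedy commits to `a` because it looks best at the first step, but the sequence starting with `b` has the higher joint probability, which only the wider beam keeps alive long enough to find.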


In [None]:
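def top_k_from_probs(prob_vec, k):
    # prob_vec: (V,) numpy array; returns the k most probable ids, best first
    k = min(k, prob_vec.shape[0])  # argpartition fails if k exceeds the vocabulary
    idx = np.argpartition(-prob_vec, k - 1)[:k]
    idx = idx[np.argsort(-prob_vec[idx])]
    return idx, prob_vec[idx]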

def beam_search_generate(lm_infer, prompt_text, n_steps=200, beam_width=3):
    prompt_ids = text_to_ids(prompt_text)
    if len(prompt_ids) == 0:
        prompt_ids = [oov_id]

    init_state = tf.zeros((1, 1024), dtype=tf.float32)

    # Prime the model with the prompt
    probs, state = lm_infer([tf.constant([prompt_ids], dtype=tf.int32), init_state])
    last_probs = probs[:, -1, :].numpy()[0]

    top_ids, top_ps = top_k_from_probs(last_probs, beam_width)

    beams = []
    for tid, tp in zip(top_ids, top_ps):
        beams.append({
            "ids": prompt_ids + [int(tid)],
            "state": state,  # same primed state for first step
            "logp": float(np.log(tp + 1e-12)),
        })

    for _ in range(n_steps - 1):
        candidates = []
        for b in beams:
            last_id = b["ids"][-1]
            probs, new_state = lm_infer([tf.constant([[last_id]], dtype=tf.int32), b["state"]])
            p = probs[:, -1, :].numpy()[0]

            top_ids, top_ps = top_k_from_probs(p, beam_width)
            for tid, tp in zip(top_ids, top_ps):
                candidates.append({
                    "ids": b["ids"] + [int(tid)],
                    "state": new_state,
                    "logp": b["logp"] + float(np.log(tp + 1e-12)),
                })

        # Keep best beams by average log prob (length-normalized); all candidates
        # have the same length here, so this ranks them exactly like total log prob
        candidates.sort(key=lambda d: d["logp"] / len(d["ids"]), reverse=True)
        beams = candidates[:beam_width]

    best = max(beams, key=lambda d: d["logp"] / len(d["ids"]))
    return ids_to_text(best["ids"])

beam_text = beam_search_generate(lm_infer, "chapter i. ", n_steps=250, beam_width=3)
print(beam_text[:1500])


## 14) Takeaways

- Language modeling reframes NLP as a next-token prediction task, which forces the network to learn patterns in sequences.
- Character n-grams keep vocabulary manageable and reduce OOV issues, at the cost of expressivity compared with word-level modeling.
- A `tf.data` pipeline that creates fixed-length windows is the practical bridge between raw stories and training tensors.
- Perplexity is a more informative metric than accuracy for language models because it measures probability quality.
- Greedy decoding is simple and fast but can be brittle; beam search often produces more coherent continuations.


## 15) References

- Thushan Ganegedara, *TensorFlow in Action* (Chapter 10).
- CBTest dataset archive used in the chapter.
- Keras/TensorFlow: `Tokenizer`, `tf.data`, `Embedding`, `GRU`.
