<a href="https://colab.research.google.com/github/aburkov/theLMbook/blob/main/count_language_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Import required libraries
import re  # For regular expressions (text tokenization)
import requests  # For downloading the corpus
import gzip  # For decompressing the downloaded corpus
import io  # For handling byte streams
import math  # For mathematical operations (log, exp)
import random  # For random number generation
from collections import defaultdict  # For efficient dictionary operations

def set_seed(seed):
    """
    Seeds Python's built-in random number generator so runs are repeatable.

    Only the standard-library `random` module is seeded by this function.

    Args:
        seed (int): Value used to initialize the generator state
    """
    random.seed(a=seed)

def download_corpus(url, timeout=30):
    """
    Downloads and decompresses a gzipped corpus file from the given URL.

    Args:
        url (str): URL of the gzipped corpus file
        timeout (float or tuple): Seconds to wait for the server before giving
            up; passed straight to `requests.get`. Without a timeout a stalled
            server would block this call indefinitely.

    Returns:
        str: Decoded text content of the corpus

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error status
        requests.exceptions.Timeout: If the server does not respond in time
    """
    print(f"Downloading corpus from {url}...")
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # Turn 4xx/5xx responses into exceptions

    print("Decompressing and reading the corpus...")
    # The payload is gzip-compressed UTF-8 text; decompress it fully in memory
    with gzip.GzipFile(fileobj=io.BytesIO(response.content)) as f:
        corpus = f.read().decode('utf-8')

    print(f"Corpus size: {len(corpus)} characters")
    return corpus

class CountLanguageModel:
    """
    Count-based n-gram language model with backoff.

    For every order k from 1 to n the model keeps a table that maps a context
    tuple of k-1 tokens to the counts of the tokens observed after it.
    """
    def __init__(self, n):
        """
        Creates an empty model.

        Args:
            n (int): Highest n-gram order the model stores
        """
        self.n = n  # Highest n-gram order
        # ngram_counts[k - 1] is the table for n-grams of order k
        self.ngram_counts = [dict() for _ in range(n)]
        self.total_unigrams = 0  # Denominator used for unigram probabilities

    def _matching_counts(self, context):
        """
        Yields follower-count tables for a context, longest history first.

        Only orders from n down to 2 are visited; orders that need more
        history than the context provides, or whose history has no recorded
        followers, are skipped. The unigram table is left to the caller.

        Args:
            context (list or tuple): Tokens preceding the position of interest

        Yields:
            dict: Non-empty mapping of follower token to count
        """
        for order in range(self.n, 1, -1):
            history_len = order - 1
            if len(context) < history_len:
                continue
            history = tuple(context[-history_len:])
            followers = self.ngram_counts[order - 1].get(history)
            if followers:
                yield followers

    def predict_next_token(self, context):
        """
        Returns the most frequent follower of the longest matching context.

        The longest stored history that matches the end of the context is
        used; if none matches, the most frequent unigram is returned.

        Args:
            context (list): Tokens preceding the token to predict

        Returns:
            str: Most likely next token, or None if no prediction can be made
        """
        best_table = next(self._matching_counts(context), None)
        if best_table is None:
            # No higher-order history matched: fall back to unigram counts
            best_table = self.ngram_counts[0].get(())
        if not best_table:
            return None
        # On ties, the token that was inserted first wins
        return max(best_table, key=best_table.get)

    def get_probability(self, token, context):
        """
        Estimates the probability of a token following a context.

        Histories are tried from longest to shortest; the first one under
        which the token was actually observed supplies the maximum likelihood
        estimate count(history, token) / count(history). Failing that, the
        unigram estimate is used, and a token never seen at all receives a
        small constant.

        Args:
            token (str): Token whose probability is requested
            context (tuple): Tokens preceding the token

        Returns:
            float: Probability of the token given the context
        """
        for followers in self._matching_counts(context):
            hits = followers.get(token, 0)
            if hits > 0:
                return hits / sum(followers.values())

        # Unigram fallback: P(token) = count(token) / total_tokens
        unigrams = self.ngram_counts[0].get(())
        if unigrams:
            hits = unigrams.get(token, 0)
            if hits > 0:
                return hits / self.total_unigrams

        # Unseen token: small non-zero value so that log() stays defined
        return 1e-6

def train(model, tokens):
    """
    Trains the language model by counting n-grams in the training data.

    Counts are added to whatever the model already holds, so the function can
    be called more than once to train on additional data.

    Args:
        model (CountLanguageModel): Model to train
        tokens (list): List of tokens from the training corpus
    """
    # Count n-grams of every order from 1 to n
    for n in range(1, model.n + 1):
        counts = model.ngram_counts[n - 1]
        # Slide a window of size n over the corpus
        for i in range(len(tokens) - n + 1):
            # Split the window into context (n-1 tokens) and next token;
            # for n == 1 the context is the empty tuple
            context = tuple(tokens[i:i + n - 1])
            next_token = tokens[i + n - 1]

            followers = counts.get(context)
            if followers is None:
                followers = counts[context] = defaultdict(int)
            followers[next_token] += 1

    # Derive the unigram total from the accumulated counts rather than from
    # len(tokens), so it stays consistent when train() is called repeatedly
    unigram_table = model.ngram_counts[0] if model.ngram_counts else {}
    model.total_unigrams = sum(unigram_table.get((), {}).values())

def generate_text(model, context, num_tokens):
    """
    Generates text by repeatedly appending the model's most likely next token.

    Decoding is greedy: each step asks the model for its single best
    prediction given the most recent tokens. Generation stops after
    num_tokens tokens, or earlier if the model cannot make a prediction.

    Args:
        model (CountLanguageModel): Trained language model
        context (list): Initial context tokens
        num_tokens (int): Number of tokens to generate

    Returns:
        str: Generated text including initial context
    """
    # Start with the provided context
    generated = list(context)
    prompt_len = len(generated)
    history_len = model.n - 1

    while len(generated) - prompt_len < num_tokens:
        # Use the last n-1 tokens as context; a slice with -0 would return the
        # whole list, so a unigram model gets an explicit empty context
        recent = generated[-history_len:] if history_len > 0 else []
        next_token = model.predict_next_token(recent)
        if next_token is None:
            # The model has nothing to predict from (e.g. it is untrained);
            # stop instead of appending None, which would break the join below
            break
        generated.append(next_token)

    # Join tokens with spaces to create readable text
    return ' '.join(generated)

def compute_perplexity(model, tokens, context_size):
    """
    Measures how well the model predicts a token sequence.

    Each token is scored given up to context_size preceding tokens, and the
    result is exp of the negated mean log probability.

    Args:
        model (CountLanguageModel): Trained language model
        tokens (list): Tokens to evaluate on
        context_size (int): Largest number of preceding tokens passed as context

    Returns:
        float: Perplexity score (lower is better); infinity for an empty list
    """
    # Nothing to score
    if not tokens:
        return float('inf')

    log_likelihood = 0
    for position, current in enumerate(tokens):
        # Clamp the window at the start of the sequence
        window_start = position - context_size
        if window_start < 0:
            window_start = 0
        history = tuple(tokens[window_start:position])

        # Accumulate in log space for numerical stability
        log_likelihood += math.log(model.get_probability(current, history))

    mean_log_likelihood = log_likelihood / len(tokens)
    return math.exp(-mean_log_likelihood)

def tokenize(text):
    """
    Splits text into lowercase word tokens and period tokens.

    A word is a run of ASCII letters and digits delimited by word boundaries;
    every period is a token of its own. Everything else is dropped.

    Args:
        text (str): Input text to tokenize

    Returns:
        list: Lowercase tokens in order of appearance
    """
    token_pattern = re.compile(r"\b[a-zA-Z0-9]+\b|[.]")
    lowered = text.lower()
    return token_pattern.findall(lowered)

def download_and_prepare_data(data_url):
    """
    Fetches the corpus, tokenizes it, and splits it into train and test parts.

    Args:
        data_url (str): URL of the corpus to download

    Returns:
        tuple: (training_tokens, test_tokens), where the first 90% of the
            tokens form the training part and the remaining 10% the test part
    """
    # Download, decompress, and tokenize in one go
    all_tokens = tokenize(download_corpus(data_url))

    # The first 90% of tokens are for training, the rest for testing
    cutoff = int(len(all_tokens) * 0.9)
    return all_tokens[:cutoff], all_tokens[cutoff:]

def get_hyperparameters():
    """
    Provides the hyperparameters of the model.

    Returns:
        int: Maximum n-gram order used by the model
    """
    max_ngram_order = 5
    return max_ngram_order

# Script entry point: train on the Brown corpus, evaluate, and show samples
if __name__ == "__main__":
    # Pick the n-gram order and fix the random seed for reproducibility
    n = get_hyperparameters()
    set_seed(42)

    # Fetch the Brown corpus as train/test token lists
    data_url = "https://www.thelmbook.com/data/brown"
    train_corpus, test_corpus = download_and_prepare_data(data_url)

    # Build the model from n-gram counts of the training split
    model = CountLanguageModel(n)
    print("Training the model...")
    train(model, train_corpus)

    # Report perplexity on the held-out split
    perplexity = compute_perplexity(model, test_corpus, n)
    print(f"\nPerplexity on test corpus: {perplexity:.2f}")

    # Prompts used to demonstrate prediction and generation
    contexts = [
        "i will build a",
        "the best place to",
        "she was riding a",
    ]

    # Show the single best next token and a 10-token continuation per prompt
    for context in contexts:
        tokens = tokenize(context)
        next_token = model.predict_next_token(tokens)
        print(f"\nContext: {context}")
        print(f"Next token: {next_token}")
        print(f"Generated text: {generate_text(model, tokens, 10)}")

Downloading corpus from https://www.thelmbook.com/data/brown...
Decompressing and reading the corpus...
Corpus size: 6185606 characters
Training the model...

Perplexity on test corpus: 302.08

Context: i will build a
Next token: wall
Generated text: i will build a wall to keep the people in and added so long

Context: the best place to
Next token: live
Generated text: the best place to live in 30 per cent to get happiness for yourself

Context: she was riding a
Next token: horse
Generated text: she was riding a horse and showing a dog are very similar your aids
