In [5]:
import os

try:
    # Only the import is guarded: an ImportError here is the signal that we are
    # not running on Colab. Errors raised by drive.mount() itself should surface
    # instead of being mistaken for "not on Colab".
    from google.colab import drive
except ImportError:
    # Not on Colab—fall back to local folder
    DATA_DIR = 'data'
else:
    drive.mount('/content/drive')
    DATA_DIR = '/content/drive/MyDrive/assignment3/data'

# Create the directory if it is missing (no-op when it already exists).
os.makedirs(DATA_DIR, exist_ok=True)
print(f"🏷️ Using DATA_DIR = {DATA_DIR}")


Mounted at /content/drive
🏷️ Using DATA_DIR = /content/drive/MyDrive/assignment3/data


In [6]:
from datasets import load_from_disk
# Reuse the DATA_DIR chosen in the setup cell (Drive on Colab, ./data
# otherwise). Re-assigning the Colab path here would silently defeat the
# local fallback and make the notebook fail outside Colab.
tokenized_ds = load_from_disk(f"{DATA_DIR}/wikitext_tokens")
train_data = tokenized_ds["train"]
valid_data = tokenized_ds["validation"]
test_data = tokenized_ds["test"]

In [13]:
from collections import Counter
from nltk import ngrams
from datasets import load_dataset

# Tokens removed outright: section markers ('=') and punctuation.
_DROPPED_TOKENS = frozenset(['=', ',', '.', '(', ')', ';', ':', '?', '!', '"', "'"])


def clean_tokens(tokens):
    """
    Clean a list of tokens by lowercasing, removing punctuation, and handling special tokens.

    Rules, applied left to right:
      * '@.@' becomes a standalone '.' token (it is NOT merged with its neighbours).
      * '@-@' joins the previous cleaned token and the next raw token (lowercased)
        with '-'. If there is no previous cleaned token or no next token, the
        marker is dropped and processing continues normally.
      * '=' and the punctuation in _DROPPED_TOKENS are removed.
      * Every other token is lowercased and kept.

    Args:
        tokens (list): List of string tokens (e.g., ['weston', '@-@', 'super'])

    Returns:
        list: Cleaned list of tokens (e.g., ['weston-super'])
    """
    cleaned = []
    n_tokens = len(tokens)
    i = 0
    while i < n_tokens:
        token = tokens[i].lower()
        if token == '@.@':
            cleaned.append('.')  # Replace '@.@' with decimal point
        elif token == '@-@':
            # Require something on both sides. Checking `cleaned` (rather than
            # i > 0) avoids an IndexError from pop() when every preceding token
            # was dropped, e.g. [',', '@-@', 'x'].
            if cleaned and i + 1 < n_tokens:
                prev_token = cleaned.pop()
                next_token = tokens[i + 1].lower()
                cleaned.append(prev_token + '-' + next_token)
                i += 1  # The next token was consumed by the merge
        elif token not in _DROPPED_TOKENS:
            cleaned.append(token)
        i += 1
    return cleaned

# Clean the test split. train_data / valid_data / test_data were already
# loaded from tokenized_ds in the loading cell, so they are not re-read here.
# Items with an empty token list, or whose tokens are all removed by
# clean_tokens (e.g. lines of only '=' markers and punctuation), are skipped.
cleaned_test = []
for item in test_data:
    tokens = item['tokens']
    if not tokens:
        continue
    cleaned = clean_tokens(tokens)
    if cleaned:
        cleaned_test.append(cleaned)






In [12]:
from collections import Counter

# Restore the pre-computed n-gram frequency tables by name.
# NOTE(review): load_counter is not defined in this part of the notebook —
# confirm the cell that defines it runs before this one on a fresh kernel.
unigram_freq, bigram_freq, trigram_freq = (
    load_counter(table_name)
    for table_name in ("unigram_freq", "bigram_freq", "trigram_freq")
)

print(len(unigram_freq), len(bigram_freq), len(trigram_freq))


380063 13437868 43001383


In [14]:
# Vocabulary = every distinct key of the unigram table; its size V is the
# smoothing term in the add-one denominators. Words that never appear in
# unigram_freq are not counted in V.
vocab = set(unigram_freq)
vocab_size = len(vocab)


In [16]:
import math
# Function to compute N-gram probability with add-k smoothing (Laplace when k=1)
def ngram_probability(ngram, n, unigram_freq, bigram_freq, trigram_freq, vocab_size, k=1):
    """
    Compute the smoothed conditional probability of the last word of an N-gram.

        P(w | prefix) = (count(prefix + w) + k) / (count(prefix) + k * vocab_size)

    With the default k=1 this is Laplace (add-one) smoothing, identical to the
    previous behavior.

    Args:
        ngram (tuple): The N-gram (e.g., ('the', 'sun') for bigram, ('the', 'sun', 'rises') for trigram).
            A list is also accepted and converted to a tuple.
        n (int): Order of N-gram (2 for bigram, 3 for trigram); must equal len(ngram).
        unigram_freq (Mapping): Frequency of unigrams, keyed by word (Counter or dict;
            missing keys count as 0)
        bigram_freq (Mapping): Frequency of bigrams, keyed by 2-tuples
        trigram_freq (Mapping): Frequency of trigrams, keyed by 3-tuples
        vocab_size (int): Size of vocabulary
        k (float): Pseudo-count added to every N-gram; must be > 0. Defaults to 1.

    Returns:
        float: Smoothed probability of the N-gram

    Raises:
        ValueError: if n is not 2 or 3, if len(ngram) != n, or if k <= 0.
    """
    if n not in (2, 3):
        raise ValueError("Only bigrams (n=2) and trigrams (n=3) are supported")
    ngram = tuple(ngram)  # Frequency tables are keyed by tuples
    if len(ngram) != n:
        raise ValueError(f"Expected an n-gram of length {n}, got {len(ngram)}: {ngram!r}")
    if k <= 0:
        raise ValueError(f"k must be positive, got {k!r}")

    if n == 2:
        # A bigram's prefix is a single word, keyed as a plain string.
        prefix_count = unigram_freq.get(ngram[0], 0)
        ngram_count = bigram_freq.get(ngram, 0)
    else:
        # A trigram's prefix is a word pair, keyed as a 2-tuple.
        prefix_count = bigram_freq.get(ngram[:2], 0)
        ngram_count = trigram_freq.get(ngram, 0)

    return (ngram_count + k) / (prefix_count + k * vocab_size)

# Perplexity of a bigram/trigram model over pre-cleaned sentences
def compute_perplexity(test_data, n, unigram_freq, bigram_freq, trigram_freq, vocab_size):
    """
    Evaluate an N-gram model's perplexity on tokenized test sentences.

    Every N-gram of every sentence is scored with ngram_probability, and the
    result is 2 ** (-(1/N) * sum(log2 P)), where N is the number of scored
    N-grams. No sentence-boundary padding is added, so a sentence of length L
    contributes L - n + 1 predictions; sentences shorter than n are skipped.
    A zero probability would be floored at 1e-10 before taking the log.

    Args:
        test_data (list): List of cleaned token lists
        n (int): Order of N-gram (2 for bigram, 3 for trigram)
        unigram_freq (Counter): Frequency of unigrams
        bigram_freq (Counter): Frequency of bigrams
        trigram_freq (Counter): Frequency of trigrams
        vocab_size (int): Size of vocabulary

    Returns:
        float: Perplexity score, or float('inf') when no N-gram was scored.
    """
    zero_prob_log2 = math.log2(1e-10)
    total_log2 = 0.0
    scored = 0

    for tokens in test_data:
        if len(tokens) < n:
            continue
        for gram in ngrams(tokens, n):
            p = ngram_probability(gram, n, unigram_freq, bigram_freq, trigram_freq, vocab_size)
            total_log2 += math.log2(p) if p > 0 else zero_prob_log2
            scored += 1

    if not scored:
        return float('inf')

    return 2 ** (-(total_log2 / scored))

# Score both model orders on the same cleaned test sentences.
bigram_perplexity = compute_perplexity(
    cleaned_test, 2, unigram_freq, bigram_freq, trigram_freq, vocab_size
)
trigram_perplexity = compute_perplexity(
    cleaned_test, 3, unigram_freq, bigram_freq, trigram_freq, vocab_size
)

# Summary report. With add-one smoothing and a large V, sparse trigram contexts
# are dominated by the +V term in the denominator, which is a likely reason the
# trigram perplexity comes out higher than the bigram one.
print("\n".join([
    f"Number of test sentences: {len(cleaned_test)}",
    f"Vocabulary size: {vocab_size}",
    f"Bigram Perplexity: {bigram_perplexity:.2f}",
    f"Trigram Perplexity: {trigram_perplexity:.2f}",
]))

# Lookup example: bigram_freq[('the', 'sun')] is the count of "the sun".

Number of test sentences: 175450
Vocabulary size: 380063
Bigram Perplexity: 3621.65
Trigram Perplexity: 36298.54
