<a href="https://colab.research.google.com/github/KavuriGreeshma2326/tm/blob/main/Modified_Kneser_Ney_Smoothing_of_n_gram_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Cell 0 - imports
from collections import Counter, defaultdict
import math
import itertools
import random

In [None]:
# Cell 1 - toy corpus (replace with your own file in Colab)
# If you have a large corpus, upload it to Colab (Files -> upload) or mount Google Drive.

toy_corpus = [
    "the quick brown fox jumps over the lazy dog",
    "the quick brown fox",
    "the lazy dog sleeps",
    "the dog barks at the fox",
    "a quick dog outpaces the fox",
    "dogs and foxes are different",
    "the quick dog and the quick fox"
]

# Example: how to load from uploaded text file (one sentence per line)
# from google.colab import files
# uploaded = files.upload()   # then read the uploaded filename into `lines`

print("Toy corpus sentences:", len(toy_corpus))
for s in toy_corpus[:5]:
    print("  ", s)

Toy corpus sentences: 7
   the quick brown fox jumps over the lazy dog
   the quick brown fox
   the lazy dog sleeps
   the dog barks at the fox
   a quick dog outpaces the fox


In [None]:
# Cell 2 - tokenize, build vocabulary with <UNK> for rare words
def simple_tokenize(s):
    return s.lower().strip().split()

def replace_rare_with_unk(sentences, min_freq=1):
    # Build raw word frequencies
    freq = Counter()
    tokenized = []
    for s in sentences:
        toks = simple_tokenize(s)
        tokenized.append(toks)
        freq.update(toks)
    vocab = set(w for w,c in freq.items() if c > min_freq)  # > min_freq keep more words
    # Always include <UNK>, <s>, </s>
    vocab.add("<UNK>")
    vocab.add("<s>")
    vocab.add("</s>")
    # replace rare words in tokenized sentences
    replaced = []
    for toks in tokenized:
        replaced.append([w if w in vocab else "<UNK>" for w in toks])
    return replaced, vocab, freq

# use the toy corpus:
tokenized, vocab, freq = replace_rare_with_unk(toy_corpus, min_freq=0)  # min_freq=0 -> keep all
print("Vocab size:", len(vocab))
print("Example tokenized:", tokenized[0])

Vocab size: 21
Example tokenized: ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']


In [None]:
# Cell 3 - build ngram counts
def build_ngram_counts(tokenized_sentences, N):
    """Return dict: counts[k] is Counter of k-grams (k=1..N) as tuples."""
    counts = {k: Counter() for k in range(1, N+1)}
    contexts = defaultdict(set)  # for continuation counts if needed
    for toks in tokenized_sentences:
        # pad with (N-1) start symbols
        padded = ["<s>"]*(N-1) + toks + ["</s>"]
        L = len(padded)
        for k in range(1, N+1):
            for i in range(L - k + 1):
                ngram = tuple(padded[i:i+k])
                counts[k][ngram] += 1
    return counts

N = 3   # example: trigram model
counts = build_ngram_counts(tokenized, N)
print("Unigram types:", len(counts[1]), "Bigram types:", len(counts[2]), "Trigram types:", len(counts[3]))
# show a few counts
print("Example bigrams:", list(counts[2].items())[:10])

Unigram types: 20 Bigram types: 33 Trigram types: 38
Example bigrams: [(('<s>', '<s>'), 7), (('<s>', 'the'), 5), (('the', 'quick'), 4), (('quick', 'brown'), 2), (('brown', 'fox'), 2), (('fox', 'jumps'), 1), (('jumps', 'over'), 1), (('over', 'the'), 1), (('the', 'lazy'), 2), (('lazy', 'dog'), 2)]


In [None]:
# Cell 4 - compute D1,D2,D3+ for a given order using Chen & Goodman formulas
def compute_D_discounts(ngram_counter):
    # n1..n4 are the counts of ngram types that occur exactly 1,2,3,4 times
    counts = list(ngram_counter.values())
    n1 = sum(1 for c in counts if c == 1)
    n2 = sum(1 for c in counts if c == 2)
    n3 = sum(1 for c in counts if c == 3)
    n4 = sum(1 for c in counts if c == 4)
    # Y = n1 / (n1 + 2*n2)
    denom = (n1 + 2*n2)
    Y = (n1 / denom) if denom > 0 else 0.0
    # fallback handling: if we don't have enough counts, use a typical small discount
    if n1 == 0 or n2 == 0:
        # safe fallback (common practice)
        D1 = D2 = D3 = 0.75
    else:
        D1 = 1.0 - 2.0 * Y * (n2 / n1) if n1 > 0 else 0.0
        D2 = 2.0 - 3.0 * Y * (n3 / n2) if n2 > 0 else D1
        D3 = 3.0 - 4.0 * Y * (n4 / n3) if n3 > 0 else D2
        # ensure non-negative and reasonable
        D1 = max(0.0, min(D1, 2.0))
        D2 = max(0.0, min(D2, 3.0))
        D3 = max(0.0, min(D3, 4.0))
    return (D1, D2, D3)

discounts = {}
for k in range(1, N+1):
    discounts[k] = compute_D_discounts(counts[k])
    print(f"Order {k}: D1,D2,D3+ = {discounts[k]}")

Order 1: D1,D2,D3+ = (0.6470588235294118, 2.0, 2.0)
Order 2: D1,D2,D3+ = (0.6571428571428571, 2.0, 2.0)
Order 3: D1,D2,D3+ = (0.8, 1.4, 3.0)


In [None]:
# Cell 5 - prepare context-level N1,N2,N3+ precomputation and a model class
def precompute_context_stats(counts, N):
    """
    For each context (length L = k-1) in order k,
    compute N1,N2,N3plus for the continuations.
    Returns: contexts_stats[k][context] = (N1,N2,N3plus, total_count)
    where k ranges 2..N (because contexts relate to n-grams of length k)
    """
    contexts_stats = {k: {} for k in range(2, N+1)}
    for k in range(2, N+1):
        ctr = counts[k]
        # group by prefix (context)
        by_context = defaultdict(list)
        for ngram, c in ctr.items():
            prefix = ngram[:-1]
            by_context[prefix].append(c)
        # compute stats
        for prefix, cont_counts in by_context.items():
            N1 = sum(1 for x in cont_counts if x == 1)
            N2 = sum(1 for x in cont_counts if x == 2)
            N3 = sum(1 for x in cont_counts if x >= 3)
            total = sum(cont_counts)
            contexts_stats[k][prefix] = (N1, N2, N3, total)
    return contexts_stats

contexts_stats = precompute_context_stats(counts, N)
print("Example context stats (some bigram contexts):")
for k in [2,3]:
    for i, (ctx, stats) in enumerate(contexts_stats[k].items()):
        print(f" order {k} context {ctx}: N1,N2,N3+,total = {stats}")
        if i>4: break

Example context stats (some bigram contexts):
 order 2 context ('<s>',): N1,N2,N3+,total = (2, 0, 2, 14)
 order 2 context ('the',): N1,N2,N3+,total = (1, 2, 1, 9)
 order 2 context ('quick',): N1,N2,N3+,total = (1, 2, 0, 5)
 order 2 context ('brown',): N1,N2,N3+,total = (0, 1, 0, 2)
 order 2 context ('fox',): N1,N2,N3+,total = (1, 0, 1, 5)
 order 2 context ('jumps',): N1,N2,N3+,total = (1, 0, 0, 1)
 order 3 context ('<s>', '<s>'): N1,N2,N3+,total = (2, 0, 1, 7)
 order 3 context ('<s>', 'the'): N1,N2,N3+,total = (2, 0, 1, 5)
 order 3 context ('the', 'quick'): N1,N2,N3+,total = (2, 1, 0, 4)
 order 3 context ('quick', 'brown'): N1,N2,N3+,total = (0, 1, 0, 2)
 order 3 context ('brown', 'fox'): N1,N2,N3+,total = (2, 0, 0, 2)
 order 3 context ('fox', 'jumps'): N1,N2,N3+,total = (1, 0, 0, 1)


In [None]:
# Cell 6 - continuation counts used for the unigram base distribution
def compute_unigram_continuation(counts):
    # For each unigram token w, continuation_count[w] = # distinct left contexts that precede w in bigrams
    bigrams = counts[2]
    left_contexts = defaultdict(set)
    for (w_prev, w), c in bigrams.items():
        left_contexts[w].add(w_prev)
    cont_count = {w: len(left_contexts[w]) for w in left_contexts}
    total_types = len(bigrams)  # total distinct bigram types
    # If some word never appears as second element, it will not be in cont_count -> 0
    return cont_count, total_types

cont_count, total_bigram_types = compute_unigram_continuation(counts)
print("Unique bigram-types (for continuation):", total_bigram_types)
print("Example continuation counts:", list(cont_count.items())[:10])

Unique bigram-types (for continuation): 33
Example continuation counts: [('<s>', 1), ('the', 5), ('quick', 2), ('brown', 1), ('fox', 3), ('jumps', 1), ('over', 1), ('lazy', 1), ('dog', 3), ('</s>', 4)]


In [None]:
# Cell 7 - probability function using recursion (top-down)
class ModKneserNey:
    def __init__(self, counts, discounts, contexts_stats, cont_count, total_bigram_types, N):
        self.counts = counts
        self.discounts = discounts
        self.contexts_stats = contexts_stats
        self.cont_count = cont_count
        self.total_bigram_types = total_bigram_types
        self.N = N

    def discount_for_count(self, order, c):
        # use D1 for c==1, D2 for c==2, D3+ for c>=3
        if c == 0:
            return 0.0
        D1, D2, D3 = self.discounts[order]
        if c == 1:
            return D1
        elif c == 2:
            return D2
        else:
            return D3

    def p_continuation_unigram(self, w):
        # base Kneser-Ney unigram = continuation_count(w) / total number of bigram types
        return (self.cont_count.get(w, 0) / self.total_bigram_types) if self.total_bigram_types>0 else 0.0

    def prob(self, w, context):
        """
        w: token (string)
        context: tuple of tokens length L where L <= N-1. If length L < N-1,
                 we still apply same recursion on smaller contexts.
        Returns: smoothed probability p(w | context)
        """
        L = len(context)
        if L == 0:
            return self.p_continuation_unigram(w)
        order = L + 1  # we are looking up ngrams of length L+1
        c_context = 0
        # total count of the context (sum of counts of full ngrams starting with context)
        # contexts_stats[order] stores this as 4th element if the context exists
        c_context = self.contexts_stats.get(order, {}).get(context, (0,0,0,0))[3]

        # count of the exact ngram
        ngram = tuple(context) + (w,)
        c_ngram = self.counts[order].get(ngram, 0)

        if c_context == 0:
            # If context never seen, backoff to shorter context
            return self.prob(w, context[1:])  # shorten left-most token
        # discounted mass numerator
        D_used = self.discount_for_count(order, c_ngram if c_ngram>0 else 0)
        numerator = max(c_ngram - D_used, 0.0) / c_context

        # compute gamma(context)
        N1, N2, N3, _ = self.contexts_stats.get(order, {}).get(context, (0,0,0,0))
        D1, D2, D3 = self.discounts[order]
        gamma = (D1 * N1 + D2 * N2 + D3 * N3) / c_context

        # recursively get lower-order probability
        lower_prob = self.prob(w, context[1:])  # drop leftmost token
        return numerator + gamma * lower_prob

# instantiate model
model = ModKneserNey(counts, discounts, contexts_stats, cont_count, total_bigram_types, N)

# quick test: get probability of some tokens
test_ctx = ("the","quick")   # trigram context (N-1 length)
print("P('fox' | 'the quick') =", model.prob("fox", test_ctx))
print("P('dog' | 'quick dog') =", model.prob("dog", ("quick","dog")) )
print("P('the' | empty) unigram continuation:", model.prob("the", ()))

P('fox' | 'the quick') = 0.16493506493506493
P('dog' | 'quick dog') = 0.047792207792207796
P('the' | empty) unigram continuation: 0.15151515151515152


In [None]:
# Cell 8 - sentence logprob and perplexity
def sentence_logprob(model, tokens, N):
    padded = ["<s>"]*(N-1) + tokens + ["</s>"]
    L = len(padded)
    total_logp = 0.0
    total_words = 0
    for i in range(N-1, L):
        context = tuple(padded[i-(N-1):i])
        w = padded[i]
        p = model.prob(w, context)
        if p <= 0:
            # extremely small probability to avoid -inf; indicates OOV/unseen pattern
            # in practice, ensure <UNK> is present in vocab and trained on
            p = 1e-12
        total_logp += math.log(p)
        total_words += 1
    return total_logp, total_words

def perplexity(model, tokenized_sentences, N):
    tot_logp = 0.0
    tot_words = 0
    for toks in tokenized_sentences:
        logp, nw = sentence_logprob(model, toks, N)
        tot_logp += logp
        tot_words += nw
    return math.exp(-tot_logp / tot_words) if tot_words>0 else float('inf')

pp = perplexity(model, tokenized, N)
print("Perplexity on toy corpus (training set):", pp)

Perplexity on toy corpus (training set): 3.346631994544345


In [None]:
# Cell 9 - next word distribution for a given context
def top_k_predictions(model, context, k=10):
    # consider vocabulary from unigram counts
    vocab = set([w[0] if isinstance(w, tuple) else w for w in [ng[0:] for ng in model.counts[1].keys()]])
    # but counts[1].keys() are tuples, so extract the unigram strings:
    vocab = set([t[0] for t in model.counts[1].keys()])
    scores = []
    for w in vocab:
        p = model.prob(w, context)
        scores.append((p,w))
    scores.sort(reverse=True)
    return scores[:k]

ctx = ("the","quick")
print("Top predictions for context", ctx)
for p,w in top_k_predictions(model, ctx, k=10):
    print(f"  {w:12}  p={p:.6f}")

Top predictions for context ('the', 'quick')
  brown         p=0.171169
  fox           p=0.164935
  dog           p=0.113506
  the           p=0.105844
  </s>          p=0.084675
  quick         p=0.042338
  and           p=0.042338
  sleeps        p=0.021169
  over          p=0.021169
  outpaces      p=0.021169


In [None]:
# Cell 10 - 10-fold cross validation splits
def k_fold_split(data, k=10, seed=42):
    random.Random(seed).shuffle(data)
    fold_size = len(data) // k
    folds = []
    for i in range(k):
        test = data[i*fold_size : (i+1)*fold_size]
        train = data[:i*fold_size] + data[(i+1)*fold_size:]
        folds.append((train, test))
    return folds

In [None]:
# Cell 11 - different discount selection strategies

def compute_discounts_count(ngram_counter):
    # base modkn-count (already implemented)
    return compute_D_discounts(ngram_counter)

def compute_discounts_extend(ngram_counter, contexts_stats):
    # like count, but discounts depend on extended contexts
    # for simplicity, use same D formula, but based on continuation counts
    counts = []
    for ctx, (N1,N2,N3,Ntotal) in contexts_stats.items():
        counts.extend([1]*N1 + [2]*N2 + [3]*N3)
    fake_counter = Counter(counts)
    return compute_D_discounts(fake_counter)

def compute_discounts_diffd(ngram_counter, contexts_stats):
    # like extend, but based on number of extended contexts with counts 1,2,3,4
    counts = []
    for ctx, (N1,N2,N3,Ntotal) in contexts_stats.items():
        counts.extend([1]*N1 + [2]*N2 + [3]*N3)
    fake_counter = Counter(counts)
    return compute_D_discounts(fake_counter)  # placeholder (diffd same as extend here)

In [None]:
# helper: tokenize test data using train vocab
def tokenize_with_vocab(sentences, vocab):
    processed = []
    for s in sentences:
        toks = simple_tokenize(s)
        toks = [w if w in vocab else "<UNK>" for w in toks]
        processed.append(toks)
    return processed

In [None]:
# Cell 12 - train and evaluate one fold
def train_and_eval(train_corpus, test_corpus, N, variant="count"):
    # tokenize with UNK handling from training set
    tokenized_train, vocab, freq = replace_rare_with_unk(train_corpus, min_freq=1)
    counts = build_ngram_counts(tokenized_train, N)
    contexts_stats = precompute_context_stats(counts, N)
    cont_count, total_bigram_types = compute_unigram_continuation(counts)

    # choose discounts
    discounts = {}
    for k in range(1, N+1):
        if variant == "count":
            discounts[k] = compute_discounts_count(counts[k])
        elif variant in ("extend","diffd"):
            discounts[k] = compute_discounts_extend(counts[k], contexts_stats.get(k,{}))
        else:
            discounts[k] = compute_discounts_count(counts[k])  # fallback

    model = ModKneserNey(counts, discounts, contexts_stats, cont_count, total_bigram_types, N)

    # tokenize test using train vocab
    test_tokenized = tokenize_with_vocab(test_corpus, vocab)
    return perplexity(model, test_tokenized, N)

In [None]:
# Cell 13 - run experiments
def run_experiments(corpus, N_values=[3,4,5], k=10):
    folds = k_fold_split(corpus, k=k)
    results = {variant: {N: [] for N in N_values} for variant in ["count","extend","diffd"]}

    for train, test in folds:
        for variant in results.keys():
            for N in N_values:
                pp = train_and_eval(train, test, N, variant=variant)
                results[variant][N].append(pp)
    return results

# Example run on toy corpus (small, just demo)
exp_results = run_experiments(toy_corpus, N_values=[3], k=3)
print(exp_results)

{'count': {3: [5.712232000720976, 8.724294061653397, 4.685855580346316]}, 'extend': {3: [5.674508613012864, 7.252088517311338, 5.7408763207223705]}, 'diffd': {3: [5.674508613012864, 7.252088517311338, 5.7408763207223705]}}


In [None]:
# Cell 14 - pretty table output
import numpy as np
from tabulate import tabulate

def summarize_results(results):
    rows = []
    for variant, N_dict in results.items():
        row = [variant]
        for N, vals in N_dict.items():
            arr = np.array(vals)
            mean, std = arr.mean(), arr.std()
            row.append(f"{mean:.3f} ({std:.3f})")
        rows.append(row)
    headers = ["Model"] + [f"N={N}" for N in list(next(iter(results.values())).keys())]
    print(tabulate(rows, headers=headers, tablefmt="grid"))

# Example summary
summarize_results(exp_results)

+---------+---------------+
| Model   | N=3           |
| count   | 6.374 (1.714) |
+---------+---------------+
| extend  | 6.222 (0.729) |
+---------+---------------+
| diffd   | 6.222 (0.729) |
+---------+---------------+
