In [1]:
import numpy as np
from collections import defaultdict, Counter

In [2]:
corpus = [
    "apple", "banana", "grapes", "watermelon", "orange",
    "the cat", "the dog", "a cat", "a dog", "my dog",
    "the cat sleeps", "the dog barks", "a cat eats", "a dog runs",
    "my dog sleeps", "my cat meows", "a bird flies", "the bird sings",
    "the sun shines", "the wind blows"
]

In [3]:
sentences = [["<bos>"] + s.split() + ["<eos>"] for s in corpus]

In [4]:
# Collect n-gram counts
def get_ngram_counts(sentences, n):
    counts = defaultdict(Counter)
    for sent in sentences:
        for i in range(n - 1, len(sent)):
            context = tuple(sent[i - n + 1:i])
            word = sent[i]
            counts[context][word] += 1
    return counts


# Calculate experimental probability p*
def get_prob(counts):
    prob = {}
    for context in counts:
        total = sum(counts[context].values())
        prob[context] = {w: c / total for w, c in counts[context].items()}
    return prob

In [5]:
# EM algorithm to learn interpolation weights
def em_interpolation(sentences, probs, n_max, iterations=10):
    # Number of actual word
    M = sum(
        1 for sent in sentences for tok in sent
        if tok not in ("<bos>", "<eos>")
    )
    print(f"Number of words: {M}")

    # Initialization lambda with uniform distribution
    lambdas = np.array([1 / n_max] * n_max)

    for it in range(iterations):
        # responsibility matrix for posterior probability
        q = np.zeros((M, n_max))
        token_index = 0

        # E-step: estimate posterior probability q
        for sent in sentences:
            for i in range(0, len(sent)):
                word = sent[i]
                if word in ("<bos>", "<eos>"):
                    continue
                context_probs = []
                for z in range(1, n_max + 1):
                    context = tuple(sent[max(i - z, 0):i])
                    prob = probs[z - 1].get(context, {}).get(word, 1e-10)   # avoid zero prob
                    context_probs.append(prob * lambdas[z - 1])

                # Normalize(q)
                norm = sum(context_probs)
                if norm == 0:
                    norm = 1e-10
                q[token_index] = [cp / norm for cp in context_probs]
                token_index += 1

        # M-step: update lambda
        lambdas = np.mean(q, axis=0)
        lambdas /= np.sum(lambdas)
        print(f"Iteration {it + 1}: lambdas = {lambdas}")

    return lambdas

In [7]:
# Compute counts and probabilities for 1-gram, 2-gram, and 3-gram
counts = [get_ngram_counts(sentences, n) for n in [1, 2, 3]]
probs = [get_prob(c) for c in counts]
n_max = 3

# Run EM
lambdas = em_interpolation(sentences, probs, n_max)
print(f"\nFinal interpolated weights: {lambdas}")

Number of words: 45
Iteration 1: lambdas = [0.07407407 0.51851852 0.40740741]
Iteration 2: lambdas = [0.01646091 0.55967078 0.42386831]
Iteration 3: lambdas = [0.00365798 0.56881573 0.42752629]
Iteration 4: lambdas = [0.00081288 0.57084794 0.42833918]
Iteration 5: lambdas = [1.80640941e-04 5.71299542e-01 4.28519817e-01]
Iteration 6: lambdas = [4.01424315e-05 5.71399898e-01 4.28559959e-01]
Iteration 7: lambdas = [8.92054036e-06 5.71422199e-01 4.28568880e-01]
Iteration 8: lambdas = [1.98234231e-06 5.71427155e-01 4.28570862e-01]
Iteration 9: lambdas = [4.40520515e-07 5.71428257e-01 4.28571303e-01]
Iteration 10: lambdas = [9.78934482e-08 5.71428501e-01 4.28571401e-01]

Final interpolated weights: [9.78934482e-08 5.71428501e-01 4.28571401e-01]


In [27]:
# Compute interpolated probability of a word given context
def interpolated_prob(word, context_tokens, probs, lambdas):
    total_prob = 0.0
    return total_prob

# Examples
examples = [
    {'context': ["my", "cat"], 'word': "sleeps"},
    {'context': ["my", "dog"], 'word': "sleeps"}
]

for ex in examples:
    prob_on = interpolated_prob(ex['word'], ex['context'], probs, lambdas)
    print(f"Interpolated probability: P('{ex['word']}'|'{' '.join(ex['context'])}') = {prob_on:.4f}")

Interpolated probability: P('sleeps'|'my cat') = 0.0000
Interpolated probability: P('sleeps'|'my dog') = 0.0000
