In [1]:
# Katz-path-simulator.py
from typing import Dict, Tuple, List

class KatzPathSimulator:
    def __init__(self, k:int=1):
        """
        k : cutoff for Katz. If count > k we use discounted MLE,
            otherwise we back off to lower order.
        """
        self.k = k

    def _get_count(self, counts:Dict, key):
        # safe getter: missing = 0
        return counts.get(key, 0)

    def explain_sentence(self, sentence: str,
                         trigram_counts: Dict[Tuple[str,str,str], int] = None,
                         bigram_counts: Dict[Tuple[str,str], int] = None,
                         unigram_counts: Dict[str,int] = None):
        """
        Prints a step-by-step call trace of Katz backoff decisions.
        If any counts dict is None it will be treated as empty (all zeros).
        """
        if trigram_counts is None: trigram_counts = {}
        if bigram_counts is None: bigram_counts = {}
        if unigram_counts is None: unigram_counts = {}

        tokens = sentence.split()
        n = len(tokens)
        if n == 0:
            print("Empty sentence.")
            return

        print(f"Sentence tokens: {tokens}\nCutoff k = {self.k}\n")

        # For first token, only unigram
        print(f"Predict P({tokens[0]}) -> unigram")
        self._explain_unigram(tokens[0], unigram_counts)

        if n >= 2:
            print("\nPredict P({} | {}) -> bigram".format(tokens[1], tokens[0]))
            self._explain_bigram(tokens[1], tokens[0], bigram_counts, unigram_counts)

        # from token index 2 onwards, use trigram history (h1,h2)
        for i in range(2, n):
            w = tokens[i]
            h1, h2 = tokens[i-2], tokens[i-1]
            print(f"\nPredict P({w} | {h1}, {h2}) -> trigram")
            self._explain_trigram(w, h1, h2, trigram_counts, bigram_counts, unigram_counts)

    def _explain_unigram(self, w, unigram_counts):
        c = self._get_count(unigram_counts, w)
        print(f"  Unigram count C({w}) = {c}")
        if c > self.k:
            print(f"  -> count > k ({self.k}) : use discounted unigram MLE (stop)")
        else:
            print(f"  -> count <= k ({self.k}) : unigram too rare/unseen -> use UNIFORM fallback (stop)")

    def _explain_bigram(self, w, h, bigram_counts, unigram_counts):
        c = self._get_count(bigram_counts, (h,w))
        c_h = self._get_count(unigram_counts, h)
        print(f"  Bigram count C({h},{w}) = {c}; history count C({h}) = {c_h}")
        if c > self.k and c_h > 0:
            print(f"  -> bigram count > k : use discounted bigram MLE (stop)")
        else:
            print(f"  -> bigram count <= k or history unseen -> back off to unigram:")
            self._explain_unigram(w, unigram_counts)

    def _explain_trigram(self, w, h1, h2, trigram_counts, bigram_counts, unigram_counts):
        c = self._get_count(trigram_counts, (h1,h2,w))
        c_h = self._get_count(bigram_counts, (h1,h2))
        print(f"  Trigram count C({h1},{h2},{w}) = {c}; trigram-history count C({h1},{h2}) = {c_h}")
        if c > self.k and c_h > 0:
            print(f"  -> trigram count > k : use discounted trigram MLE (stop)")
            return
        else:
            print(f"  -> trigram count <= k or trigram-history unseen -> back off to bigram:")
            self._explain_bigram(w, h2, bigram_counts, unigram_counts)


# ---------------------------
# Example runs (you can edit counts to test different paths)
# ---------------------------
if __name__ == "__main__":
    sim = KatzPathSimulator(k=1)

    sentence = "Natural Language Processing is Love"

    # Example A: all counts unknown (empty dicts) -> everything unseen
    print("=== Example A: All counts missing (simulate unseen) ===")
    sim.explain_sentence(sentence)

    # Example B: partial counts so some trigram/bigram exist
    print("\n\n=== Example B: Some counts provided ===")
    trigram_counts = {
        # Suppose only ("Language","Processing","is") seen 2 times
        ("Language","Processing","is"): 2,
        # ("Natural","Language","Processing") absent (0)
    }
    bigram_counts = {
        ("Natural","Language"): 1,     # small
        ("Language","Processing"): 3,  # bigger
        ("Processing","is"): 2,
        ("is","Love"): 0
    }
    unigram_counts = {
        "Natural": 1,
        "Language": 5,
        "Processing": 2,
        "is": 3,
        "Love": 0   # unseen unigram
    }
    sim.explain_sentence(sentence,
                         trigram_counts=trigram_counts,
                         bigram_counts=bigram_counts,
                         unigram_counts=unigram_counts)


=== Example A: All counts missing (simulate unseen) ===
Sentence tokens: ['Natural', 'Language', 'Processing', 'is', 'Love']
Cutoff k = 1

Predict P(Natural) -> unigram
  Unigram count C(Natural) = 0
  -> count <= k (1) : unigram too rare/unseen -> use UNIFORM fallback (stop)

Predict P(Language | Natural) -> bigram
  Bigram count C(Natural,Language) = 0; history count C(Natural) = 0
  -> bigram count <= k or history unseen -> back off to unigram:
  Unigram count C(Language) = 0
  -> count <= k (1) : unigram too rare/unseen -> use UNIFORM fallback (stop)

Predict P(Processing | Natural, Language) -> trigram
  Trigram count C(Natural,Language,Processing) = 0; trigram-history count C(Natural,Language) = 0
  -> trigram count <= k or trigram-history unseen -> back off to bigram:
  Bigram count C(Language,Processing) = 0; history count C(Language) = 0
  -> bigram count <= k or history unseen -> back off to unigram:
  Unigram count C(Processing) = 0
  -> count <= k (1) : unigram too rare/uns

In [1]:
from collections import Counter, defaultdict

def get_stats(vocab):
    """Count frequency of symbol pairs in the vocabulary."""
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[(symbols[i], symbols[i+1])] += freq
    return pairs

def merge_vocab(pair, vocab):
    """Merge the most frequent pair in vocabulary."""
    bigram = " ".join(pair)
    replacement = "".join(pair)
    new_vocab = {}
    for word in vocab:
        new_word = word.replace(bigram, replacement)
        new_vocab[new_word] = vocab[word]
    return new_vocab

# Example corpus
corpus = ["low", "lower", "newest", "widest"]

# Step 1: Build initial vocabulary with character-level tokens
vocab = {}
for word in corpus:
    chars = " ".join(list(word)) + " </w>"   # </w> marks word-end
    vocab[chars] = vocab.get(chars, 0) + 1

print("Initial Vocabulary:")
for k, v in vocab.items():
    print(k, ":", v)

# Step 2: Run BPE merges
num_merges = 10
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)   # most frequent pair
    vocab = merge_vocab(best, vocab)
    print(f"\nMerge {i+1}: {best}")
    for k, v in vocab.items():
        print(k, ":", v)


Initial Vocabulary:
l o w </w> : 1
l o w e r </w> : 1
n e w e s t </w> : 1
w i d e s t </w> : 1

Merge 1: ('l', 'o')
lo w </w> : 1
lo w e r </w> : 1
n e w e s t </w> : 1
w i d e s t </w> : 1

Merge 2: ('lo', 'w')
low </w> : 1
low e r </w> : 1
n e w e s t </w> : 1
w i d e s t </w> : 1

Merge 3: ('e', 's')
low </w> : 1
low e r </w> : 1
n e w es t </w> : 1
w i d es t </w> : 1

Merge 4: ('es', 't')
low </w> : 1
low e r </w> : 1
n e w est </w> : 1
w i d est </w> : 1

Merge 5: ('est', '</w>')
low </w> : 1
low e r </w> : 1
n e w est</w> : 1
w i d est</w> : 1

Merge 6: ('low', '</w>')
low</w> : 1
low e r </w> : 1
n e w est</w> : 1
w i d est</w> : 1

Merge 7: ('low', 'e')
low</w> : 1
lowe r </w> : 1
n e w est</w> : 1
w i d est</w> : 1

Merge 8: ('lowe', 'r')
low</w> : 1
lower </w> : 1
n e w est</w> : 1
w i d est</w> : 1

Merge 9: ('lower', '</w>')
low</w> : 1
lower</w> : 1
n e w est</w> : 1
w i d est</w> : 1

Merge 10: ('n', 'e')
low</w> : 1
lower</w> : 1
ne w est</w> : 1
w i d est</w> : 1


In [2]:
from collections import defaultdict

# -------------------------
# Step 1: Build initial vocab from corpus
# -------------------------
def get_vocab(corpus):
    vocab = {}
    for word in corpus:
        chars = " ".join(list(word)) + " </w>"
        vocab[chars] = vocab.get(chars, 0) + 1
    return vocab

# -------------------------
# Step 2: Compute pair scores (WordPiece style)
# -------------------------
def compute_scores(vocab):
    pair_counts = defaultdict(int)
    total_freq = sum(vocab.values())

    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pair = (symbols[i], symbols[i+1])
            pair_counts[pair] += freq

    scores = {}
    for pair, count in pair_counts.items():
        # simplified score = probability of this pair in corpus
        scores[pair] = count / total_freq
    return scores

# -------------------------
# Step 3: Merge vocab with best pair
# -------------------------
def merge_vocab(pair, vocab):
    bigram = " ".join(pair)
    replacement = "".join(pair)
    new_vocab = {}
    for word in vocab:
        new_word = word.replace(bigram, replacement)
        new_vocab[new_word] = vocab[word]
    return new_vocab

# -------------------------
# Step 4: Train WordPiece merges
# -------------------------
def train_wordpiece(corpus, num_merges=20):
    vocab = get_vocab(corpus)
    merges = []

    for i in range(num_merges):
        scores = compute_scores(vocab)
        if not scores:
            break
        best = max(scores, key=scores.get)
        vocab = merge_vocab(best, vocab)
        merges.append(best)
    return merges

# -------------------------
# Step 5: Build final subword vocab
# -------------------------
def build_subword_vocab(merges, corpus):
    # Start with characters
    subword_vocab = set(list("".join(corpus)))
    for (a, b) in merges:
        token = a + b
        subword_vocab.add(token)
    return subword_vocab

# -------------------------
# Step 6: WordPiece Tokenizer
# -------------------------
def tokenize_wordpiece(word, vocab):
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        subword = None
        while start < end:
            piece = word[start:end]
            if start > 0:  # continuation → add ##
                piece = "##" + piece
            if piece in vocab:
                subword = piece
                break
            end -= 1
        if subword is None:
            tokens.append("[UNK]")
            break
        tokens.append(subword)
        start = end if subword.startswith("##") else len(subword)
    return tokens

# -------------------------
# Example Run
# -------------------------
corpus = ["low", "lower", "newest", "widest"]

# Train
merges = train_wordpiece(corpus, num_merges=10)
print("Learned merges:", merges)

# Build vocab
subword_vocab = build_subword_vocab(merges, corpus)
# Add common prefixes with ##
extended_vocab = set(subword_vocab)
for token in list(subword_vocab):
    if len(token) > 1:
        extended_vocab.add("##" + token)

print("\nFinal Subword Vocabulary:")
print(extended_vocab)

# Tokenize words
print("\nTokenization examples:")
print("lower ->", tokenize_wordpiece("lower", extended_vocab))
print("newest ->", tokenize_wordpiece("newest", extended_vocab))
print("widest ->", tokenize_wordpiece("widest", extended_vocab))


Learned merges: [('l', 'o'), ('lo', 'w'), ('e', 's'), ('es', 't'), ('est', '</w>'), ('low', '</w>'), ('low', 'e'), ('lowe', 'r'), ('lower', '</w>'), ('n', 'e')]

Final Subword Vocabulary:
{'##low</w>', 'l', 'est', '##low', 'est</w>', '##ne', 'n', '##es', '##est</w>', 't', '##lo', 'i', 'es', 'low</w>', 'r', '##lower', 'lower</w>', 'low', 'w', 'lo', 'lower', '##lowe', 's', 'ne', '##est', 'o', '##lower</w>', 'e', 'd', 'lowe'}

Tokenization examples:
lower -> ['lower']
newest -> ['ne', '[UNK]']
widest -> ['w', '[UNK]']
