In [6]:
# ============================================================
# Quadgram LM + Katz Backoff + Kneser–Ney + Sentence Generation
# (Greedy + Beam search) with progress bar and file saving
# ============================================================

import pickle
from collections import Counter
import math
from tqdm import tqdm   # <-- for progress bar

# -------------------------------
# 1️⃣ Load dataset & build counts
# -------------------------------
# Each document in the pickle is read as a dict whose "sentences" entry is a
# list of sentence records carrying a "tokens" list (see the inspection cell).
# NOTE(review): pickle.load can run arbitrary code -- only load trusted files.
with open("/content/tokenized_data_complete.pkl", "rb") as f:   # <-- change to your file
    data = pickle.load(f)

# One item per document: that document's list of sentence records.
sentences = [doc["sentences"] for doc in data]


def _count_ngrams(documents):
    """Count 1- to 4-grams over all sentences.

    Every sentence is padded with three "<s>" tokens in front and one "</s>"
    at the end. Unigram keys are plain tokens; higher orders use tuples.
    """
    tables = [Counter(), Counter(), Counter(), Counter()]
    for sentence_records in documents:
        for record in sentence_records:
            padded = ["<s>"] * 3 + record["tokens"] + ["</s>"]
            for end, token in enumerate(padded):
                tables[0][token] += 1
                # An n-gram ending at `end` exists once n-1 tokens precede it.
                for order in (2, 3, 4):
                    if end >= order - 1:
                        tables[order - 1][tuple(padded[end - order + 1:end + 1])] += 1
    return tables


uni_counts, bi_counts, tri_counts, quad_counts = _count_ngrams(sentences)
counts_all = [uni_counts, bi_counts, tri_counts, quad_counts]
vocab = list(uni_counts.keys())

# ----------------------------
# 2️⃣ Katz Backoff (Quadrigram)
# ----------------------------
class KatzBackoff:
    """Quadgram language model with Katz-style back-off and absolute discounting.

    For a context ``h`` (up to three preceding tokens) and a word ``w``::

        P(w | h) = max(c(h, w) - d, 0) / c(h)   if c(h, w) > 0
                 = alpha(h) * P(w | h[1:])      otherwise

    ``c(h)`` is the total count of n-grams that start with ``h``. ``alpha(h)``
    spreads exactly the mass removed by discounting over the words *not*
    seen after ``h``. The recursion ends at the maximum-likelihood unigram
    distribution. Every conditional distribution therefore sums to one,
    except in the degenerate case where a context was followed by every
    vocabulary word (then the left-over mass has nowhere to go).

    Args:
        counts: ``[uni, bi, tri, quad]`` count mappings. ``uni`` is keyed by
            token; the others by token tuples of length 2, 3 and 4.
        discount: Absolute discount ``d`` subtracted from each seen count.
    """

    def __init__(self, counts, discount=0.5):
        self.uni, self.bi, self.tri, self.quad = counts
        self.d = discount
        self.total = sum(self.uni.values())
        # Context length -> table whose keys are (*context, word).
        self._tables = {1: self.bi, 2: self.tri, 3: self.quad}
        # Context tuple -> [sum of follower counts, list of follower words].
        # Contexts of different orders have different tuple lengths, so they
        # cannot collide in one dict.
        self._contexts = {}
        for table in self._tables.values():
            for ngram, count in table.items():
                if count > 0:
                    entry = self._contexts.setdefault(ngram[:-1], [0, []])
                    entry[0] += count
                    entry[1].append(ngram[-1])
        # alpha(h) only depends on the counts, which are fixed after __init__.
        self._alpha_cache = {}

    def prob(self, ngram):
        """Return P(w4 | w1, w2, w3) for a 4-tuple ``(w1, w2, w3, w4)``."""
        w1, w2, w3, w4 = ngram
        return self._prob((w1, w2, w3), w4)

    def _backoff_tri(self, tri):
        """Return the back-off estimate P(w4 | w2, w3) for ``(w2, w3, w4)``."""
        w2, w3, w4 = tri
        return self._prob((w2, w3), w4)

    def _backoff_bi(self, bi):
        """Return the back-off estimate P(w4 | w3) for ``(w3, w4)``."""
        w3, w4 = bi
        return self._prob((w3,), w4)

    def _prob(self, context, word):
        """Return P_katz(word | context) for a context tuple of length 0..3."""
        if not context:
            return self.uni.get(word, 0) / self.total if self.total else 0.0
        count = self._tables[len(context)].get(context + (word,), 0)
        if count > 0:
            # A positive count means the context was registered in __init__.
            return max(count - self.d, 0) / self._contexts[context][0]
        return self._alpha(context) * self._prob(context[1:], word)

    def _alpha(self, context):
        """Return the back-off weight that renormalises the unseen-word mass."""
        cached = self._alpha_cache.get(context)
        if cached is not None:
            return cached
        entry = self._contexts.get(context)
        if entry is None:
            alpha = 1.0  # unseen context: all mass goes to the lower order
        else:
            total, followers = entry
            table = self._tables[len(context)]
            kept = sum(max(table[context + (w,)] - self.d, 0) for w in followers) / total
            lower_seen = sum(self._prob(context[1:], w) for w in followers)
            denom = 1.0 - lower_seen
            # denom ~ 0 means no unseen words remain to receive the mass.
            alpha = (1.0 - kept) / denom if denom > 1e-12 else 0.0
        self._alpha_cache[context] = alpha
        return alpha

# -----------------------------------
# 3️⃣ Kneser–Ney Smoothing (Quadrigram)
# -----------------------------------
class KneserNey:
    """Interpolated Kneser–Ney quadgram language model.

    For a context ``h`` and word ``w``::

        P(w | h) = max(c(h, w) - D, 0) / c(h)
                   + D * N1+(h .) / c(h) * P(w | h[1:])

    ``c(h)`` is the total count of n-grams that start with ``h``, and
    ``N1+(h .)`` is the number of distinct words seen after ``h``.
    The top (4-gram) order uses raw counts. The 3- and 2-gram orders use
    continuation counts ``N1+(. g)``: the number of distinct words seen
    immediately before ``g``. The unigram order is ``continuation_prob``.
    An unseen context defers entirely to the next lower order. With
    ``0 < D <= 1`` every conditional distribution sums to one.

    Args:
        counts: ``[uni, bi, tri, quad]`` count mappings. ``uni`` is keyed by
            token; the others by token tuples of length 2, 3 and 4.
        discount: Absolute discount ``D``.
    """

    def __init__(self, counts, discount=0.75):
        self.uni, self.bi, self.tri, self.quad = counts
        self.D = discount
        # All continuation counts are precomputed once. The old code rescanned
        # every bigram on every continuation_prob() call.
        self._cont_uni = self._continuation_counts(self.bi)  # keys: (w,)
        self._bigram_types = sum(self._cont_uni.values())
        # Context length -> table whose keys are (*context, word).
        self._levels = {
            3: self.quad,                              # raw 4-gram counts
            2: self._continuation_counts(self.quad),   # (a, b, w) -> N1+(. a b w)
            1: self._continuation_counts(self.tri),    # (b, w) -> N1+(. b w)
        }
        # Context tuple -> [sum of counts after it, number of distinct followers].
        self._context_stats = {}
        for table in self._levels.values():
            for ngram, count in table.items():
                if count > 0:
                    stats = self._context_stats.setdefault(ngram[:-1], [0, 0])
                    stats[0] += count
                    stats[1] += 1

    @staticmethod
    def _continuation_counts(higher):
        """Map each suffix ``g`` of a higher-order n-gram to ``N1+(. g)``."""
        cont = {}
        for ngram, count in higher.items():
            if count > 0:
                suffix = ngram[1:]
                cont[suffix] = cont.get(suffix, 0) + 1
        return cont

    def continuation_prob(self, word):
        """Return the share of distinct bigram types that end in ``word``."""
        if not self._bigram_types:
            return 0.0
        return self._cont_uni.get((word,), 0) / self._bigram_types

    def prob(self, ngram):
        """Return P(w4 | w1, w2, w3) for a 4-tuple ``(w1, w2, w3, w4)``."""
        w1, w2, w3, w4 = ngram
        return self._interp((w1, w2, w3), w4)

    def _lower(self, bi, w):
        """Return the interpolated estimate P(w | bi) for a 2-token context."""
        return self._interp(tuple(bi), w)

    def _interp(self, context, word):
        """Return the interpolated KN estimate P(word | context), len 0..3."""
        if not context:
            return self.continuation_prob(word)
        lower = self._interp(context[1:], word)
        stats = self._context_stats.get(context)
        if stats is None:
            return lower
        total, types = stats
        count = self._levels[len(context)].get(context + (word,), 0)
        return max(count - self.D, 0) / total + self.D * types / total * lower

# ---------------------------
# 4️⃣ Sentence Generation
# ---------------------------
class SentenceGenerator:
    """Decode sentences from a model exposing ``prob((w1, w2, w3, w4))``.

    Sequences start with three "<s>" padding tokens and stop at "</s>".
    The padding symbol "<s>" is never proposed as a continuation.
    """

    def __init__(self, model, vocab):
        self.model = model
        self.vocab = vocab

    def _candidates(self):
        """Return the vocabulary without the "<s>" padding symbol, in order."""
        return [w for w in self.vocab if w != "<s>"]

    def greedy(self, max_len=30):
        """Return the greedy (argmax at every step) token sequence.

        Args:
            max_len: Upper bound on the total length, *including* the three
                "<s>" padding tokens.

        Returns:
            The token list, with padding and, if reached, the final "</s>".
            Ties go to the word that appears first in ``vocab``.
        """
        prob = self.model.prob
        candidates = self._candidates()
        sent = ["<s>", "<s>", "<s>"]
        if not candidates:
            return sent  # nothing but padding in the vocabulary
        while len(sent) < max_len:
            h1, h2, h3 = sent[-3:]
            nxt = max(candidates, key=lambda w: prob((h1, h2, h3, w)))
            sent.append(nxt)
            if nxt == "</s>":
                break
        return sent

    def beam_search(self, beam_size=20, max_len=30):
        """Return the highest-scoring sequence found by beam search.

        Args:
            beam_size: Number of hypotheses kept after each step.
            max_len: Maximum number of expansion steps. Unlike ``greedy``, the
                three "<s>" padding tokens are not counted.

        Returns:
            The best hypothesis as a token list, with padding and, if reached,
            the final "</s>".

        A score is the sum of ``log(p + 1e-12)`` over steps. Finished hypotheses
        (ending in "</s>") are carried forward with an unchanged score.
        """
        import heapq  # local import: the notebook's import cell lacks heapq

        prob = self.model.prob
        candidates = self._candidates()
        beams = [(["<s>", "<s>", "<s>"], 0.0)]
        for _ in range(max_len):
            # Once every hypothesis is finished, further steps cannot change them.
            if all(seq[-1] == "</s>" for seq, _score in beams):
                break
            # Score expansions as cheap (score, beam index, word) triples.
            # Token lists are only built for the beam_size survivors, instead
            # of copying a list for every (beam, word) pair.
            scored = []
            for idx, (seq, score) in enumerate(beams):
                if seq[-1] == "</s>":
                    scored.append((score, idx, None))
                    continue
                context = tuple(seq[-3:])
                for w in candidates:
                    scored.append((score + math.log(prob(context + (w,)) + 1e-12), idx, w))
            if not scored:
                break  # no expandable candidates (vocab holds only "<s>")
            # nlargest(k, key=...) equals sorted(..., reverse=True)[:k], so ties
            # resolve in the same order as a stable full sort.
            best = heapq.nlargest(beam_size, scored, key=lambda item: item[0])
            beams = [
                (beams[idx][0] if w is None else beams[idx][0] + [w], score)
                for score, idx, w in best
            ]
        return beams[0][0]

# --------------------------
# 5️⃣ Run & save sentences
# --------------------------
import os
os.makedirs("/content/generated_sentences", exist_ok=True)


def generate_and_save(generator, method, n=100, out_file="out.txt"):
    """Generate ``n`` sentences with ``generator`` and write them to ``out_file``.

    ``method == "Greedy"`` decodes with ``generator.greedy()``; any other
    value uses ``generator.beam_search()``. Each sentence's tokens (padding
    included) are space-joined. Sentences are newline-separated, with no
    trailing newline. A tqdm bar labelled with ``method`` shows progress.
    """
    decode = generator.greedy if method == "Greedy" else generator.beam_search
    lines = []
    for _ in tqdm(range(n), total=n, desc=f"{method}"):
        lines.append(" ".join(decode()))
    with open(out_file, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"✅ Saved {n} sentences to {out_file}")

# NOTE: greedy() and beam_search() contain no random choice, so with these
# deterministic models all 100 lines written to a file are the same sentence.

# Katz back-off: build the model, then decode greedily and with beam search.
katz = KatzBackoff(counts_all)
gen_katz = SentenceGenerator(katz, vocab)
generate_and_save(gen_katz, "Greedy", n=100,
                  out_file="/content/generated_sentences/katz_greedy.txt")
generate_and_save(gen_katz, "Beam", n=100,
                  out_file="/content/generated_sentences/katz_beam.txt")

# Kneser–Ney: same two decoders over the KN-smoothed model.
kn = KneserNey(counts_all)
gen_kn = SentenceGenerator(kn, vocab)
generate_and_save(gen_kn, "Greedy", n=100,
                  out_file="/content/generated_sentences/kn_greedy.txt")
generate_and_save(gen_kn, "Beam", n=100,
                  out_file="/content/generated_sentences/kn_beam.txt")

Greedy: 100%|██████████| 100/100 [04:53<00:00,  2.94s/it]


✅ Saved 100 sentences to /content/generated_sentences/katz_greedy.txt


Beam:   0%|          | 0/100 [01:39<?, ?it/s]


KeyboardInterrupt: 

In [3]:
# Show the container type of `data` and, when it is a non-empty list,
# the type (and dict keys, if any) of its first element.
print(type(data))
if isinstance(data, list) and data:
    sample = data[0]
    print(type(sample))
    if isinstance(sample, dict):
        print(sample.keys())

<class 'list'>
<class 'dict'>
dict_keys(['document_id', 'original_text', 'sentences', 'document_stats'])


In [None]:
# ============================================================
# Quadgram LM + Katz Backoff + Kneser–Ney + Sentence Generation
# (Greedy + Beam search) with progress bar and file saving
# ============================================================

import pickle
from collections import Counter
import math
from tqdm import tqdm   # <-- for progress bar

# -------------------------------
# 1️⃣ Load dataset & build counts
# -------------------------------
# Each document in the pickle is read as a dict whose "sentences" entry is a
# list of sentence records carrying a "tokens" list (see the inspection cell).
# NOTE(review): pickle.load can run arbitrary code -- only load trusted files.
with open("/content/tokenized_data_complete.pkl", "rb") as f:   # <-- change to your file
    data = pickle.load(f)

# One item per document: that document's list of sentence records.
sentences = [doc["sentences"] for doc in data]


def _count_ngrams(documents):
    """Count 1- to 4-grams over all sentences.

    Every sentence is padded with three "<s>" tokens in front and one "</s>"
    at the end. Unigram keys are plain tokens; higher orders use tuples.
    """
    tables = [Counter(), Counter(), Counter(), Counter()]
    for sentence_records in documents:
        for record in sentence_records:
            padded = ["<s>"] * 3 + record["tokens"] + ["</s>"]
            for end, token in enumerate(padded):
                tables[0][token] += 1
                # An n-gram ending at `end` exists once n-1 tokens precede it.
                for order in (2, 3, 4):
                    if end >= order - 1:
                        tables[order - 1][tuple(padded[end - order + 1:end + 1])] += 1
    return tables


uni_counts, bi_counts, tri_counts, quad_counts = _count_ngrams(sentences)
counts_all = [uni_counts, bi_counts, tri_counts, quad_counts]
vocab = list(uni_counts.keys())

# ----------------------------
# 2️⃣ Katz Backoff (Quadrigram)
# ----------------------------
class KatzBackoff:
    """Quadgram language model with Katz-style back-off and absolute discounting.

    For a context ``h`` (up to three preceding tokens) and a word ``w``::

        P(w | h) = max(c(h, w) - d, 0) / c(h)   if c(h, w) > 0
                 = alpha(h) * P(w | h[1:])      otherwise

    ``c(h)`` is the total count of n-grams that start with ``h``. ``alpha(h)``
    spreads exactly the mass removed by discounting over the words *not*
    seen after ``h``. The recursion ends at the maximum-likelihood unigram
    distribution. Every conditional distribution therefore sums to one,
    except in the degenerate case where a context was followed by every
    vocabulary word (then the left-over mass has nowhere to go).

    Args:
        counts: ``[uni, bi, tri, quad]`` count mappings. ``uni`` is keyed by
            token; the others by token tuples of length 2, 3 and 4.
        discount: Absolute discount ``d`` subtracted from each seen count.
    """

    def __init__(self, counts, discount=0.5):
        self.uni, self.bi, self.tri, self.quad = counts
        self.d = discount
        self.total = sum(self.uni.values())
        # Context length -> table whose keys are (*context, word).
        self._tables = {1: self.bi, 2: self.tri, 3: self.quad}
        # Context tuple -> [sum of follower counts, list of follower words].
        # Contexts of different orders have different tuple lengths, so they
        # cannot collide in one dict.
        self._contexts = {}
        for table in self._tables.values():
            for ngram, count in table.items():
                if count > 0:
                    entry = self._contexts.setdefault(ngram[:-1], [0, []])
                    entry[0] += count
                    entry[1].append(ngram[-1])
        # alpha(h) only depends on the counts, which are fixed after __init__.
        self._alpha_cache = {}

    def prob(self, ngram):
        """Return P(w4 | w1, w2, w3) for a 4-tuple ``(w1, w2, w3, w4)``."""
        w1, w2, w3, w4 = ngram
        return self._prob((w1, w2, w3), w4)

    def _backoff_tri(self, tri):
        """Return the back-off estimate P(w4 | w2, w3) for ``(w2, w3, w4)``."""
        w2, w3, w4 = tri
        return self._prob((w2, w3), w4)

    def _backoff_bi(self, bi):
        """Return the back-off estimate P(w4 | w3) for ``(w3, w4)``."""
        w3, w4 = bi
        return self._prob((w3,), w4)

    def _prob(self, context, word):
        """Return P_katz(word | context) for a context tuple of length 0..3."""
        if not context:
            return self.uni.get(word, 0) / self.total if self.total else 0.0
        count = self._tables[len(context)].get(context + (word,), 0)
        if count > 0:
            # A positive count means the context was registered in __init__.
            return max(count - self.d, 0) / self._contexts[context][0]
        return self._alpha(context) * self._prob(context[1:], word)

    def _alpha(self, context):
        """Return the back-off weight that renormalises the unseen-word mass."""
        cached = self._alpha_cache.get(context)
        if cached is not None:
            return cached
        entry = self._contexts.get(context)
        if entry is None:
            alpha = 1.0  # unseen context: all mass goes to the lower order
        else:
            total, followers = entry
            table = self._tables[len(context)]
            kept = sum(max(table[context + (w,)] - self.d, 0) for w in followers) / total
            lower_seen = sum(self._prob(context[1:], w) for w in followers)
            denom = 1.0 - lower_seen
            # denom ~ 0 means no unseen words remain to receive the mass.
            alpha = (1.0 - kept) / denom if denom > 1e-12 else 0.0
        self._alpha_cache[context] = alpha
        return alpha

# -----------------------------------
# 3️⃣ Kneser–Ney Smoothing (Quadrigram)
# -----------------------------------
class KneserNey:
    """Interpolated Kneser–Ney quadgram language model.

    For a context ``h`` and word ``w``::

        P(w | h) = max(c(h, w) - D, 0) / c(h)
                   + D * N1+(h .) / c(h) * P(w | h[1:])

    ``c(h)`` is the total count of n-grams that start with ``h``, and
    ``N1+(h .)`` is the number of distinct words seen after ``h``.
    The top (4-gram) order uses raw counts. The 3- and 2-gram orders use
    continuation counts ``N1+(. g)``: the number of distinct words seen
    immediately before ``g``. The unigram order is ``continuation_prob``.
    An unseen context defers entirely to the next lower order. With
    ``0 < D <= 1`` every conditional distribution sums to one.

    Args:
        counts: ``[uni, bi, tri, quad]`` count mappings. ``uni`` is keyed by
            token; the others by token tuples of length 2, 3 and 4.
        discount: Absolute discount ``D``.
    """

    def __init__(self, counts, discount=0.75):
        self.uni, self.bi, self.tri, self.quad = counts
        self.D = discount
        # All continuation counts are precomputed once. The old code rescanned
        # every bigram on every continuation_prob() call.
        self._cont_uni = self._continuation_counts(self.bi)  # keys: (w,)
        self._bigram_types = sum(self._cont_uni.values())
        # Context length -> table whose keys are (*context, word).
        self._levels = {
            3: self.quad,                              # raw 4-gram counts
            2: self._continuation_counts(self.quad),   # (a, b, w) -> N1+(. a b w)
            1: self._continuation_counts(self.tri),    # (b, w) -> N1+(. b w)
        }
        # Context tuple -> [sum of counts after it, number of distinct followers].
        self._context_stats = {}
        for table in self._levels.values():
            for ngram, count in table.items():
                if count > 0:
                    stats = self._context_stats.setdefault(ngram[:-1], [0, 0])
                    stats[0] += count
                    stats[1] += 1

    @staticmethod
    def _continuation_counts(higher):
        """Map each suffix ``g`` of a higher-order n-gram to ``N1+(. g)``."""
        cont = {}
        for ngram, count in higher.items():
            if count > 0:
                suffix = ngram[1:]
                cont[suffix] = cont.get(suffix, 0) + 1
        return cont

    def continuation_prob(self, word):
        """Return the share of distinct bigram types that end in ``word``."""
        if not self._bigram_types:
            return 0.0
        return self._cont_uni.get((word,), 0) / self._bigram_types

    def prob(self, ngram):
        """Return P(w4 | w1, w2, w3) for a 4-tuple ``(w1, w2, w3, w4)``."""
        w1, w2, w3, w4 = ngram
        return self._interp((w1, w2, w3), w4)

    def _lower(self, bi, w):
        """Return the interpolated estimate P(w | bi) for a 2-token context."""
        return self._interp(tuple(bi), w)

    def _interp(self, context, word):
        """Return the interpolated KN estimate P(word | context), len 0..3."""
        if not context:
            return self.continuation_prob(word)
        lower = self._interp(context[1:], word)
        stats = self._context_stats.get(context)
        if stats is None:
            return lower
        total, types = stats
        count = self._levels[len(context)].get(context + (word,), 0)
        return max(count - self.D, 0) / total + self.D * types / total * lower

# ---------------------------
# 4️⃣ Sentence Generation
# ---------------------------
class SentenceGenerator:
    """Decode sentences from a model exposing ``prob((w1, w2, w3, w4))``.

    Sequences start with three "<s>" padding tokens and stop at "</s>".
    The padding symbol "<s>" is never proposed as a continuation.
    """

    def __init__(self, model, vocab):
        self.model = model
        self.vocab = vocab

    def _candidates(self):
        """Return the vocabulary without the "<s>" padding symbol, in order."""
        return [w for w in self.vocab if w != "<s>"]

    def greedy(self, max_len=30):
        """Return the greedy (argmax at every step) token sequence.

        Args:
            max_len: Upper bound on the total length, *including* the three
                "<s>" padding tokens.

        Returns:
            The token list, with padding and, if reached, the final "</s>".
            Ties go to the word that appears first in ``vocab``.
        """
        prob = self.model.prob
        candidates = self._candidates()
        sent = ["<s>", "<s>", "<s>"]
        if not candidates:
            return sent  # nothing but padding in the vocabulary
        while len(sent) < max_len:
            h1, h2, h3 = sent[-3:]
            nxt = max(candidates, key=lambda w: prob((h1, h2, h3, w)))
            sent.append(nxt)
            if nxt == "</s>":
                break
        return sent

    def beam_search(self, beam_size=20, max_len=30):
        """Return the highest-scoring sequence found by beam search.

        Args:
            beam_size: Number of hypotheses kept after each step.
            max_len: Maximum number of expansion steps. Unlike ``greedy``, the
                three "<s>" padding tokens are not counted.

        Returns:
            The best hypothesis as a token list, with padding and, if reached,
            the final "</s>".

        A score is the sum of ``log(p + 1e-12)`` over steps. Finished hypotheses
        (ending in "</s>") are carried forward with an unchanged score.
        """
        import heapq  # local import: the notebook's import cell lacks heapq

        prob = self.model.prob
        candidates = self._candidates()
        beams = [(["<s>", "<s>", "<s>"], 0.0)]
        for _ in range(max_len):
            # Once every hypothesis is finished, further steps cannot change them.
            if all(seq[-1] == "</s>" for seq, _score in beams):
                break
            # Score expansions as cheap (score, beam index, word) triples.
            # Token lists are only built for the beam_size survivors, instead
            # of copying a list for every (beam, word) pair.
            scored = []
            for idx, (seq, score) in enumerate(beams):
                if seq[-1] == "</s>":
                    scored.append((score, idx, None))
                    continue
                context = tuple(seq[-3:])
                for w in candidates:
                    scored.append((score + math.log(prob(context + (w,)) + 1e-12), idx, w))
            if not scored:
                break  # no expandable candidates (vocab holds only "<s>")
            # nlargest(k, key=...) equals sorted(..., reverse=True)[:k], so ties
            # resolve in the same order as a stable full sort.
            best = heapq.nlargest(beam_size, scored, key=lambda item: item[0])
            beams = [
                (beams[idx][0] if w is None else beams[idx][0] + [w], score)
                for score, idx, w in best
            ]
        return beams[0][0]

# --------------------------
# 5️⃣ Run & save sentences
# --------------------------
import os
os.makedirs("/content/generated_sentences", exist_ok=True)


def generate_and_save(generator, method, n=100, out_file="out.txt"):
    """Generate ``n`` sentences with ``generator`` and write them to ``out_file``.

    ``method == "Greedy"`` decodes with ``generator.greedy()``; any other
    value uses ``generator.beam_search()``. Each sentence's tokens (padding
    included) are space-joined. Sentences are newline-separated, with no
    trailing newline. A tqdm bar labelled with ``method`` shows progress.
    """
    decode = generator.greedy if method == "Greedy" else generator.beam_search
    lines = []
    for _ in tqdm(range(n), total=n, desc=f"{method}"):
        lines.append(" ".join(decode()))
    with open(out_file, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    print(f"✅ Saved {n} sentences to {out_file}")

# NOTE: greedy() and beam_search() contain no random choice, so with these
# deterministic models all 100 lines written to a file are the same sentence.

# Katz back-off: build the model, then decode greedily and with beam search.
katz = KatzBackoff(counts_all)
gen_katz = SentenceGenerator(katz, vocab)
generate_and_save(gen_katz, "Greedy", n=100,
                  out_file="/content/generated_sentences/katz_greedy.txt")
generate_and_save(gen_katz, "Beam", n=100,
                  out_file="/content/generated_sentences/katz_beam.txt")

# Kneser–Ney: same two decoders over the KN-smoothed model.
kn = KneserNey(counts_all)
gen_kn = SentenceGenerator(kn, vocab)
generate_and_save(gen_kn, "Greedy", n=100,
                  out_file="/content/generated_sentences/kn_greedy.txt")
generate_and_save(gen_kn, "Beam", n=100,
                  out_file="/content/generated_sentences/kn_beam.txt")

Greedy: 100%|██████████| 100/100 [04:55<00:00,  2.95s/it]


✅ Saved 100 sentences to /content/generated_sentences/katz_greedy.txt


Beam: 100%|██████████| 100/100 [5:28:02<00:00, 196.82s/it]


✅ Saved 100 sentences to /content/generated_sentences/katz_beam.txt


Greedy:   0%|          | 0/100 [00:00<?, ?it/s]