In [2]:
import ast
import re
import unicodedata
from collections import Counter  # count hashable objects

import pandas as pd

In [3]:
# Load the training split; the relative path is resolved against the
# directory the notebook kernel is running in.
df = pd.read_csv(filepath_or_buffer="../ass5/train.csv")

In [4]:
def extract_texts(cell):
    """Return the ``"text"`` values held in one cell of the sentence column.

    Args:
        cell: Either a list of dicts, or the string repr of such a list
            (which is parsed with ``ast.literal_eval``).

    Returns:
        A list with one entry per dict element: its ``"text"`` value, or
        ``""`` when the key is missing. Non-dict elements are skipped.
        An empty list is returned when ``cell`` is a string that cannot be
        parsed as a Python literal, or when it is not a list at all.
    """
    if isinstance(cell, str):
        try:
            cell = ast.literal_eval(cell)  # parse string repr of list/dict
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Only the failures literal_eval raises for malformed input are
            # treated as "no sentences here"; KeyboardInterrupt and genuine
            # programming errors now propagate instead of being swallowed.
            # NOTE(review): a column of plain (non-literal) text ends up
            # entirely empty through this path - confirm the CSV format.
            return []
    if isinstance(cell, list):
        return [d.get("text", "") for d in cell if isinstance(d, dict)]
    return []

# Flatten the per-row text lists into one list of sentences, keeping row order.
all_sentences = [
    text
    for cell in df["sentence"]
    for text in extract_texts(cell)
]

In [5]:
def simple_tokenize(text):
    """Split ``text`` into whitespace-separated tokens with punctuation removed.

    Args:
        text: A ``str``; ``bytes`` are decoded as UTF-8 (undecodable bytes are
            dropped) and any other object is converted with ``str()``.

    Returns:
        A list of tokens. Letters, digits, underscores and Unicode combining
        marks are kept; every other non-whitespace character is deleted.
    """
    if isinstance(text, bytes):
        text = text.decode("utf-8", errors="ignore")
    elif not isinstance(text, str):
        text = str(text)

    def is_kept(ch):
        # ``\w`` in ``re`` only covers str.isalnum() characters and "_", which
        # excludes combining marks (Unicode categories Mn/Mc/Me).  Filtering on
        # ``[^\w\s]`` therefore deleted the vowel signs, virama and anusvara
        # inside Gujarati (and other Indic) words.  Marks are kept explicitly.
        return (
            ch.isalnum()
            or ch == "_"
            or ch.isspace()
            or unicodedata.category(ch).startswith("M")
        )

    cleaned = "".join(ch for ch in text if is_kept(ch))  # remove punctuation
    return cleaned.split()

# One token list per sentence, in the same order as all_sentences.
corpus = list(map(simple_tokenize, all_sentences))

In [6]:
def ngrams(tokens, n):
    """Return every run of ``n`` consecutive tokens as a tuple, in order.

    The result is empty when ``tokens`` holds fewer than ``n`` items.
    """
    windows = []
    window_count = len(tokens) - n + 1
    for start in range(window_count):
        stop = start + n
        windows.append(tuple(tokens[start:stop]))
    return windows

# One frequency table per n-gram order (1 to 4); keys are token tuples.
unigrams, bigrams, trigrams, quadgrams = (Counter() for _ in range(4))

# Sentences are counted as tokenized: no start-of-sentence marker is added,
# so n-grams never contain the "<s>" padding that KatzBackoffQuadgram.prob
# puts in front of short histories.
# NOTE(review): confirm that mismatch is intended.
for tokens in corpus:
    for order, table in enumerate((unigrams, bigrams, trigrams, quadgrams), start=1):
        table.update(ngrams(tokens, order))

In [7]:
class KatzBackoffQuadgram:
    """Back-off language model over quadgram, trigram, bigram and unigram counts.

    A seen n-gram is scored as ``(count - discount) / count(history)``.  An
    unseen one is scored as ``alpha(history)`` times the estimate of the next
    lower order, ending at the unigram relative frequency.

    Every table must map n-gram tuples to counts and return 0 for missing
    keys (e.g. ``collections.Counter``); unigram keys are 1-tuples.

    NOTE(review): ``alpha`` is the probability mass left over after
    discounting; it is not divided by the lower-order mass of the unseen
    words, so the scores for one history generally sum to less than 1.
    Confirm whether a fully normalised Katz weight is required.
    """

    def __init__(self, unig, big, trig, quad, discount=0.75):
        """Store the count tables and precompute per-history summaries.

        Args:
            unig, big, trig, quad: Count tables for orders 1 to 4.
            discount: Amount subtracted from every positive count.

        Like ``total_unigrams``, the summaries are a snapshot: counts added
        to the tables after construction are not reflected in them.
        """
        self.unigrams = unig
        self.bigrams = big
        self.trigrams = trig
        self.quadgrams = quad
        self.discount = discount
        self.total_unigrams = sum(unig.values())
        # (table, {history: [distinct continuations, summed counts]}) so that
        # alpha() does not have to scan a whole table on every query.
        self._follower_stats = [
            (table, self._summarise_followers(table)) for table in (big, trig, quad)
        ]

    @staticmethod
    def _summarise_followers(table):
        """Map each history to [number of seen continuations, their total count]."""
        stats = {}
        for ngram, count in table.items():
            if count > 0:
                entry = stats.setdefault(ngram[:-1], [0, 0])
                entry[0] += 1
                entry[1] += count
        return stats

    def _estimate(self, word, hist, higher_order, lower_order, back_off):
        """Discounted estimate if ``hist + (word,)`` was seen, else back off.

        ``back_off`` is a zero-argument callable returning the lower-order
        probability; it is only invoked when backing off.
        """
        count = higher_order[hist + (word,)]
        hist_count = lower_order[hist]
        # Requiring a positive history count avoids a ZeroDivisionError when
        # the tables are inconsistent (n-gram counted, history not).
        if count > 0 and hist_count > 0:
            return (count - self.discount) / hist_count
        return self.alpha(hist, higher_order, lower_order) * back_off()

    def prob(self, word, history):
        """Return P(word | last three tokens of ``history``).

        ``history`` may be any sequence of tokens (list, tuple, ...); it is
        copied, never modified.  Histories shorter than three tokens are
        left-padded with ``"<s>"``.
        """
        history = list(history)  # accept tuples too; list + tuple would raise
        if len(history) < 3:
            history = (["<s>"] * (3 - len(history))) + history
        hist3 = tuple(history[-3:])
        return self._estimate(
            word, hist3, self.quadgrams, self.trigrams,
            lambda: self.prob_trigram(word, history[-2:]),
        )

    def prob_trigram(self, word, history):
        """Return P(word | last two tokens of ``history``), backing off to bigrams."""
        hist2 = tuple(history[-2:])
        return self._estimate(
            word, hist2, self.trigrams, self.bigrams,
            lambda: self.prob_bigram(word, history[-1:]),
        )

    def prob_bigram(self, word, history):
        """Return P(word | last token of ``history``), backing off to unigrams."""
        hist1 = tuple(history[-1:])
        return self._estimate(
            word, hist1, self.bigrams, self.unigrams,
            lambda: self.prob_unigram(word),
        )

    def prob_unigram(self, word):
        """Return the relative frequency of ``word``; 1e-6 if the table is empty."""
        return self.unigrams[(word,)] / self.total_unigrams if self.total_unigrams > 0 else 1e-6

    def alpha(self, history, higher_order, lower_order):
        """Return the back-off weight for ``history``.

        Args:
            history: Tuple of tokens forming the conditioning context.
            higher_order: Table whose keys are ``history`` plus one word.
            lower_order: Table holding the count of ``history`` itself.

        Returns:
            1.0 when ``history`` was never seen; otherwise one minus the
            summed discounted probabilities of its seen continuations,
            floored at 1e-6.
        """
        hist_count = lower_order[history]
        if hist_count <= 0:
            return 1.0
        stats = None
        # getattr keeps instances created before the summaries existed usable.
        for table, summary in getattr(self, "_follower_stats", ()):
            if table is higher_order:
                stats = summary
                break
        if stats is not None:
            types, tokens = stats.get(history, (0, 0))
            # Sum of (count - discount) over the seen continuations.
            discounted_mass = (tokens - self.discount * types) / hist_count
        else:
            # A table this model was not built with: fall back to a full scan.
            discounted_mass = sum(
                (count - self.discount) / hist_count
                for ngram, count in higher_order.items()
                if count > 0 and ngram[:-1] == history
            )
        return max(1.0 - discounted_mass, 1e-6)

In [8]:
# Build the model from the four count tables collected above.
katz_quad_model = KatzBackoffQuadgram(
    unig=unigrams,
    big=bigrams,
    trig=trigrams,
    quad=quadgrams,
)

# Spot check with a two-word history (shorter than the three the model uses).
print(
    "P(ગાંધીનગર | મ્યુનિસિપલ કોર્પોરેશન) =",
    katz_quad_model.prob(word="ગાંધીનગર", history=["મ્યુનિસિપલ", "કોર્પોરેશન"]),
)

P(ગાંધીનગર | મ્યુનિસિપલ કોર્પોરેશન) = 1e-06


In [9]:
import pickle

# Serialize the trained model next to the notebook.
# pickle records the instance's class by name, so KatzBackoffQuadgram must be
# defined (or importable) wherever this file is loaded.
with open("katz_quad_model.pkl", mode="wb") as model_file:
    pickle.dump(katz_quad_model, model_file)