In [1]:
from preprocess import *

text_preprocessor = TextPreprocessor()

## Model Building

In [2]:
class InterpolationAddK:
    def __init__(
        self,
        k=0.1,  # Smoothing parameter
        lambda1=0.1,
        lambda2=0.2,
        lambda3=0.3,
        lambda4=0.4,
    ) -> None:
        # preprocess
        text_preprocessor = TextPreprocessor()
        
        # initialize necessary fields
        self.freq_uni = text_preprocessor.freq_uni
        self.freq_bi = text_preprocessor.freq_bi
        self.freq_tri = text_preprocessor.freq_tri
        self.freq_four = text_preprocessor.freq_four
        
        # k and lambda
        self.k = k
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.lambda3 = lambda3
        self.lambda4 = lambda4
    
    # ---------- interpolation with Add-k probability of unigram ----------
    def probability(
        self,
        word: str,
        given_tri_gram: tuple,
    ):
        """
        Estimate the probability of a word being the next word after given previous words using linear interpolation with add-k smoothing.

        Args:
            word: The word for which to calculate the next word probability.
            previous_words: A tuple containing the two previous words.
            unigram_counts: A dictionary with counts of unigrams.
            bigram_counts: A dictionary with counts of bigrams.
            trigram_counts: A dictionary with counts of trigrams.
            fourgram_counts: A dictionary with counts of fourgrams.
            k: Smoothing parameter (default is 1).
            lambda1: Weight for unigram model.
            lambda2: Weight for bigram model.
            lambda3: Weight for trigram model.
            lambda4: Weight for fourgram model.

        Returns:
            The estimated probability of 'word' being the next word after 'previous_words' using linear interpolation with add-k smoothing.
        """
        
        # Create the unigram, bigram, trigram, and fourgram tuples
        uni_gram = (word,)
        bi_gram = (given_tri_gram[2], word)
        tri_gram = (given_tri_gram[1], given_tri_gram[2], word)
        four_gram = (given_tri_gram[0], given_tri_gram[1], given_tri_gram[2], word)

        # Calculate probabilities for each n-gram model with add-k smoothing
        unigram_prob = self.unigram_addk_probability(
            current_uni=uni_gram, 
            k=self.k,
        )
        
        bigram_prob = self.n_gram_addk_probability(
            word=bi_gram[1], 
            given_gram=bi_gram[:1], 
            freq_previous=self.freq_uni, 
            freq_current=self.freq_bi,
            k=self.k,
        )
        
        trigram_prob = self.n_gram_addk_probability(
            word=tri_gram[2],
            given_gram=tri_gram[:2],
            freq_previous=self.freq_bi, 
            freq_current=self.freq_tri,
            k=self.k,
        )
        
        fourgram_prob = self.n_gram_addk_probability(
            word=four_gram[3], 
            given_gram=(four_gram[:3]), 
            freq_previous=self.freq_tri, 
            freq_current=self.freq_four,
            k=self.k,
        )

        # Calculate interpolated probability
        probability = (self.lambda1*unigram_prob) + (self.lambda2*bigram_prob) + (self.lambda3*trigram_prob) + (self.lambda4*fourgram_prob)
        
        # print(f'probability of {word}: {probability}')

        return probability
    
    # ---------- Add-k probability of unigram ----------
    def unigram_addk_probability(
        self,
        current_uni: tuple,
        k = 0.1
    ):
        uni_gram_count = self.freq_uni.get(current_uni, 0)
        n_total_words = len(text_preprocessor.training_data)
        n_unique_words = len(self.freq_uni)
        
        probability = (uni_gram_count + k) / (n_total_words + n_unique_words * k) 
        return probability
    
    # ---------- Add-k probability of n-gram, starting from bi-gram ----------
    def n_gram_addk_probability(
        self,
        word: str,
        given_gram: tuple,
        freq_previous: dict, 
        freq_current: dict, 
        k = 0.1
    ):
        # new n-gram
        n_gram = list(given_gram)
        n_gram.append(word)
        n_gram = tuple(n_gram)
        
        current_gram_count = freq_current.get(n_gram, 0)
        previous_gram_count = freq_previous.get(given_gram, 0)
        unique_word_count = len(self.freq_uni)
        
        probability = (current_gram_count + k) / (previous_gram_count + unique_word_count * k)
        
        return probability
    
    # ---------- Predict the next word ----------
    def predict(
        self,
        previous_word: tuple[str, str, str],
    ):
        predictions = []
        for word in self.freq_uni.keys():
            if (word[0] != '<s>' and word[0] != '</s>'):
                probability = self.probability(word[0], previous_word)
                predictions.append((word, probability)) 

        predictions.sort(key=lambda x: x[1], reverse=True)
        print(predictions)
        return predictions[0][0][0]
    

In [3]:
# text_preprocessor.tokenize_words(text_preprocessor.test_data)['sentences']

In [6]:
model = InterpolationAddK(k=pow(10, -5))
model.predict(('computer', 'science', 'is'))

[(('data',), 0.08784459767770207), (('s',), 0.013287831076607352), (('systems',), 0.012995451599740315), (('computer',), 0.010071656831069933), (('used',), 0.009486897877335856), (('business',), 0.009194518400468818), (('big',), 0.009194518400468818), (('intelligence',), 0.008609759446734743), (('machine',), 0.008609759446734743), (('often',), 0.008025000493000665), (('warehouse',), 0.008025000493000665), (('first',), 0.007732621016133627), (('model',), 0.006855482585532513), (('information',), 0.006563103108665475), (('processing',), 0.006563103108665475), (('could',), 0.006563103108665475), (('ai',), 0.006563103108665475), (('artificial',), 0.006270723631798437), (('database',), 0.005978344154931399), (('–',), 0.005978344154931399), (('would',), 0.005685964678064361), (('science',), 0.005393585201197323), (('software',), 0.005393585201197323), (('turing',), 0.005393585201197323), (('machines',), 0.005393585201197323), (('one',), 0.005101205724330284), (('system',), 0.0048088262474632

'data'

## Model Evaluation

In [5]:
def perplexity_interpolation():
    
    return

## Text Generation