In [None]:
import math
import re
import operator

from collections import Counter

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def tokenize(text):
    """Split ``text`` into lowercase word tokens.

    The text is lowercased first; a token is then any maximal run of the ASCII
    letters a-z. Digits, punctuation, whitespace and non-ASCII characters all
    act as separators and never appear in the output.
    """
    word_pattern = re.compile('[a-z]+')
    lowered = text.lower()
    return word_pattern.findall(lowered)

def find_bigrams(text):
    """Count every adjacent pair of items in the token sequence ``text``.

    Returns a plain dict mapping ``(previous, current)`` tuples to the number
    of times that pair occurs. Keys appear in order of first occurrence.
    Sequences with fewer than two items yield an empty dict.
    """
    pair_counts = {}
    for idx in range(1, len(text)):
        pair = (text[idx - 1], text[idx])
        pair_counts[pair] = pair_counts.get(pair, 0) + 1
    return pair_counts

def find_unigrams(text):
    """Return a ``Counter`` of how often each token in ``text`` occurs."""
    token_counts = Counter()
    token_counts.update(text)
    return token_counts

In [None]:
def get_params(d, bigrams, unigrams, vocab):
    """Precompute the statistics of an interpolated absolute-discounting bigram model.

    Parameters
    ----------
    d : float
        Absolute discount subtracted from every observed count.
    bigrams : dict
        Maps ``(first, second)`` word pairs to counts. Only pairs whose count
        is > 0 contribute to the per-word statistics.
    unigrams : mapping
        Word -> count. ``unigrams[word]`` is read for every word in ``vocab``,
        so missing words must read as 0 (a ``Counter`` does this). The total
        count must be > 0.
    vocab : sequence
        Words for which per-word parameters are computed. Must be non-empty
        (``P_unif`` is ``1 / len(vocab)``).

    Returns
    -------
    tuple
        ``(P_unif, lambdadot, N1_plus_wi1, N_bigrams, lambda_wi, P_abs_wi)``
        where ``P_unif`` is the uniform probability and ``lambdadot`` is
        ``d * (#distinct unigrams) / (#unigram tokens)``. The remaining four
        are dicts keyed by vocabulary word:

        - number of distinct successors;
        - total count of bigrams starting with the word;
        - the bigram interpolation weight ``d * successors / total``
          (0 when the total is 0);
        - the discounted unigram probability interpolated with the uniform.
    """
    #params to return
    N1_plus_wi1 = dict() #eqn 10
    N_bigrams = dict() #denominator of 7, 9
    lambda_wi = dict() #eqn 9
    P_abs_wi = dict() #eqn 8

    P_unif = 1 / len(vocab)
    N_unis = sum(unigrams.values())
    N1_plus = len(unigrams)

    lambdadot = (d / N_unis) * N1_plus

    # Aggregate per-first-word bigram statistics in ONE pass over the bigrams.
    # Scanning all bigrams once per vocabulary word is O(|V| * |B|), which is
    # hundreds of millions of iterations on a realistic corpus.
    successor_counts = {}  # distinct bigrams starting with a word
    outgoing_totals = {}   # total count of bigrams starting with a word
    for key, val in bigrams.items():
        if val > 0:
            first = key[0]
            successor_counts[first] = successor_counts.get(first, 0) + 1
            outgoing_totals[first] = outgoing_totals.get(first, 0) + val

    for word in vocab:
        count = successor_counts.get(word, 0)  # unique bigrams starting with word
        N = outgoing_totals.get(word, 0)       # total number of bigrams starting with word

        N1_plus_wi1[word] = count
        N_bigrams[word] = N

        # A word with no outgoing bigrams gets weight 0 (avoids dividing by 0).
        lambda_wi[word] = (d / N) * count if N > 0 else 0
        P_abs_wi[word] = (max((unigrams[word] - d), 0) / N_unis) + (lambdadot * P_unif)

    return P_unif, lambdadot, N1_plus_wi1, N_bigrams, lambda_wi, P_abs_wi


def get_P_abs_wi1(wi, wi_1, d, unigrams, bigrams, N_bigrams, lambda_wi, P_abs_wi, lambdadot, P_unif):
    """Interpolated absolute-discounting estimate of P(wi | wi_1).

    Parameters
    ----------
    wi, wi_1 : str
        Target word and its preceding (context) word.
    d : float
        Absolute discount applied to the bigram count.
    unigrams : mapping
        Unused; kept for interface compatibility.
    bigrams : dict
        ``(wi_1, wi)`` -> count.
    N_bigrams, lambda_wi, P_abs_wi, lambdadot, P_unif
        Values produced by ``get_params``.

    Returns
    -------
    float
        - Out-of-vocabulary ``wi`` (not in ``P_abs_wi``): the uniform floor
          ``lambdadot * P_unif``.
        - Unknown context, or a context with no outgoing bigrams (e.g. all of
          them pruned): ``P_abs_wi[wi]``.
        - Otherwise: ``max(c(wi_1, wi) - d, 0) / N(wi_1)
          + lambda_wi[wi_1] * P_abs_wi[wi]``, with the first term dropped for
          an unseen bigram.
    """
    if wi not in P_abs_wi:
        # OOV target word: only the uniform floor is left. Explicit check
        # instead of a bare ``except:`` that used to hide every other error.
        return lambdadot * P_unif

    N = N_bigrams.get(wi_1, 0)
    if wi_1 not in lambda_wi or N == 0:
        # No bigram evidence for this context. Its lambda is 0, so
        # interpolating would give probability 0 (and log(0) downstream).
        # Back off fully to the unigram-level estimate instead.
        return P_abs_wi[wi]

    backoff = lambda_wi[wi_1] * P_abs_wi[wi]
    if (wi_1, wi) not in bigrams:
        return backoff
    return (max((bigrams[(wi_1, wi)] - d), 0) / N) + backoff

def calculate_probabilities(unigrams, bigrams, bgrams, vocab, d=0.7):
    """Fit the absolute-discounting model and score a set of bigrams.

    Parameters
    ----------
    unigrams, bigrams : mappings
        Training counts the model is estimated from.
    bgrams : mapping
        Bigrams to score; only its keys ``(w_prev, w)`` are used.
    vocab : sequence
        Vocabulary passed through to ``get_params``.
    d : float, optional
        Absolute discount (default 0.7, the value previously hard-coded twice).

    Returns
    -------
    tuple
        ``(P_abs_wi, bi_dict)``. ``P_abs_wi`` is the per-word unigram-level
        probability from ``get_params``. ``bi_dict`` maps each key of
        ``bgrams`` to its estimated P(w | w_prev).
    """
    P_unif, lambdadot, _N1_plus_wi1, N_bigrams, lambda_wi, P_abs_wi = get_params(d, bigrams, unigrams, vocab)
    bi_dict = {
        key: get_P_abs_wi1(key[1], key[0], d, unigrams, bigrams, N_bigrams, lambda_wi, P_abs_wi, lambdadot, P_unif)
        for key in bgrams
    }
    return P_abs_wi, bi_dict

In [None]:
def prune(uni_probs, bi_probs, epsilon, ugrams, bi_grams, bgrams, vocab):
    """Drop n-grams whose probability is below ``epsilon`` and refit the model.

    Parameters
    ----------
    uni_probs, bi_probs : dict
        Probabilities used to decide what to prune (keyed like the counts).
    epsilon : float
        Pruning threshold; entries with probability < epsilon are removed.
    ugrams, bi_grams : mapping
        Original counts. They are copied first, so the inputs are not modified.
    bgrams, vocab
        Passed through to ``calculate_probabilities`` for the refit.

    Returns
    -------
    tuple
        ``(uni_dict, bi_dict, unigrams, bigrams)``: the refit probabilities
        and the pruned count tables.
    """
    unigrams = ugrams.copy()
    bigrams = bi_grams.copy()

    # pop(..., None): a probability entry without a matching count is skipped
    # rather than raising KeyError.
    for word, prob in uni_probs.items():
        if prob < epsilon:
            unigrams.pop(word, None)
    for pair, prob in bi_probs.items():
        if prob < epsilon:
            bigrams.pop(pair, None)

    # Recalculate probabilities based on the pruned n-gram counts.
    uni_dict, bi_dict = calculate_probabilities(unigrams, bigrams, bgrams, vocab)

    return uni_dict, bi_dict, unigrams, bigrams

In [None]:
def find_perplexity(bgrams, bi_dict):
    """Perplexity of a bigram model on a test set.

    Computes ``exp(-sum_k f(k) * log P(k))``, where ``f(k)`` is the relative
    frequency of bigram ``k`` in ``bgrams`` and ``P(k)`` is ``bi_dict[k]``.

    Parameters
    ----------
    bgrams : dict
        Test bigram -> count.
    bi_dict : dict
        Test bigram -> model probability. Must contain every key of ``bgrams``.

    Returns
    -------
    float
        The perplexity. ``math.inf`` if any test bigram has probability 0,
        and 1.0 for an empty ``bgrams``.
    """
    tsum = 0
    s = sum(bgrams.values())

    for k, v in bgrams.items():
        rel_freq = v / s
        cond_prob = bi_dict[k]
        if cond_prob == 0:
            # log(0) is undefined (math.log raises ValueError). A test event
            # with zero probability makes the cross-entropy, and therefore
            # the perplexity, infinite.
            return math.inf
        tsum -= rel_freq * math.log(cond_prob)

    return math.exp(tsum)

In [None]:
# Training and evaluation corpora for this exercise (paths relative to the notebook).
trainfile = './ex7_materials/English_train.txt'
testfile = './ex7_materials/English_test.txt'

# `with` closes each file when the block exits; no explicit close() is needed.
with open(trainfile, encoding="utf-8") as f:
    traintext = f.read()
with open(testfile, encoding="utf-8") as f:
    testtext = f.read()

train_text = tokenize(traintext)
test_text = tokenize(testtext)

uni_train = find_unigrams(train_text)
bi_train = find_bigrams(train_text)

# The vocabulary comes from the training text only; words that occur only in
# the test text are not part of it.
vocab = list(uni_train.keys())
V = len(vocab)

uni_test = find_unigrams(test_text)
bi_test = find_bigrams(test_text)

In [None]:
# Fit the unpruned model on the training counts and score every test bigram.
uni_dict, bi_dict = calculate_probabilities(uni_train, bi_train, bi_test, vocab)
print(
    f"Prepruning lengths: unigrams - {len(uni_train)}, "
    f"bigrams - {len(bi_train)}"
)
pp = find_perplexity(bi_test, bi_dict)
print(f"Perplexity without pruning = {pp}")

Prepruning lengths: unigrams - 5765, bigrams - 45533
Perplexity without pruning = 225.61367810627084


In [None]:
# Sweep the pruning threshold epsilon = 1e-2 ... 1e-6. Every iteration prunes
# from the unpruned baseline (uni_dict / bi_dict, uni_train / bi_train), so the
# runs do not build on each other.
# NOTE(review): bi_dict only holds probabilities for test-set bigrams, so the
# test data decides which bigrams get pruned — confirm this is intended.
for i in range(2, 7):
    epsilon = math.pow(10, -i)
    udict, bdict, unis, bis = prune(
        uni_dict, bi_dict, epsilon, uni_train, bi_train, bi_test, vocab
    )
    pp = find_perplexity(bi_test, bdict)
    print(
        f"epsilon = {epsilon}, unigrams length = {len(unis)}, "
        f"bigrams length = {len(bis)}, perplexity = {pp}"
    )

epsilon = 0.01, unigrams length = 15, bigrams length = 41884, perplexity = 28087.90931123331
epsilon = 0.001, unigrams length = 132, bigrams length = 44465, perplexity = 887.5127072455467
epsilon = 0.0001, unigrams length = 1014, bigrams length = 45479, perplexity = 285.6268150247238
epsilon = 1e-05, unigrams length = 5765, bigrams length = 45533, perplexity = 225.61367810627084
epsilon = 1e-06, unigrams length = 5765, bigrams length = 45533, perplexity = 225.61367810627084
