In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import time
import numpy as np
import math
'''
import sys
sys.path.append('Classes')
from arpy import *
'''
import arpa
from tqdm import tqdm

import torch
import torch.nn as nn



# Select the compute device (CUDA if available)

In [2]:
# Pick the first CUDA device when one is present, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# get_device_name() raises when no CUDA device exists, so only query it when
# CUDA is actually available; the availability flag is always printed.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
print(torch.cuda.is_available())


GeForce GTX 1060 6GB
True


# Build the n-gram count dictionary

In [3]:
# Load the training n-gram counts into a dict: n-gram text -> integer count.
# Each line is split on tabs; the first field is used as the key and the last
# field is parsed as the count.
file = '../../rsc/train_counts.txt'

# Count lines up front so tqdm can show a total; the handle is closed again
# as soon as the count is known.
with open(file, 'r') as counting_read:
    num_lines = sum(1 for line in counting_read)

ngram_dict = {}
with open(file, 'r') as first_read:
    for x in tqdm(first_read, total=num_lines, position=0, leave=True):
        if not x.strip():
            # A blank line has no count field; int() would raise on it.
            continue
        line = x.split('\t')
        ngram_dict[line[0]] = int(line[-1])


100%|██████████| 45918515/45918515 [00:47<00:00, 971822.91it/s] 



# Prep NN

In [4]:
class Net(nn.Module):
    """Fully connected network mapping 4 input features to a single output.

    Layer sizes are 4 -> 4 -> 3 -> 1 and a sigmoid follows every linear
    layer, so the output lies strictly between 0 and 1.
    """

    def __init__(self):
        super().__init__()
        # There is no fc2: the attribute names fc1/fc3/fc4 are kept as they
        # are because they are the parameter keys of the module's state_dict.
        self.fc1 = nn.Linear(in_features=4, out_features=4)
        self.fc3 = nn.Linear(in_features=4, out_features=3)
        self.fc4 = nn.Linear(in_features=3, out_features=1)

    def forward(self, x):
        """Apply fc1, fc3 and fc4 in turn, each followed by a sigmoid."""
        activation = x
        for layer in (self.fc1, self.fc3, self.fc4):
            activation = torch.sigmoid(layer(activation))
        return activation
'''
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        #self.fc2 = nn.Linear(4, 4)
        self.fc3 = nn.Linear(3, 3)
        self.fc4 = nn.Linear(3, 1)

        
    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        #x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x)) 
        x = torch.sigmoid(self.fc4(x)) 
        return x
'''
# Build the model on the device selected earlier (`device`) instead of calling
# .cuda() unconditionally, so the cell also runs on CPU-only machines.
net = Net().to(device)
# map_location remaps the stored tensors onto `device`; without it a
# checkpoint saved from a GPU cannot be loaded on a host without CUDA.
net.load_state_dict(torch.load('NN saves/temp_50-50', map_location=device))
# Switch to inference mode; kept as the last expression so the module
# summary is still displayed as the cell output.
net.eval()

Net(
  (fc1): Linear(in_features=4, out_features=4, bias=True)
  (fc3): Linear(in_features=4, out_features=3, bias=True)
  (fc4): Linear(in_features=3, out_features=1, bias=True)
)

# Load and Write ARPA
## NN with 4 inputs

In [None]:
# Copy the unsmoothed ARPA model to a new file, replacing the probability of
# every trigram whose training count lies strictly between the two bounds
# below with the network's estimate, capped at the maximum-likelihood
# estimate (MLE). All other lines are copied through unchanged.
file = '../../rsc/unsmoothedLM.arpa'

# Trigrams are only re-estimated when MIN < count < MAX (both exclusive).
MIN_COUNT_EXCLUSIVE = 1
MAX_COUNT_EXCLUSIVE = 8

# Count lines up front so tqdm can show a total; close the handle right away.
with open(file, 'r') as counting_read:
    num_lines = sum(1 for line in counting_read)

# 0 while in the header, 1..3 inside the matching n-gram section,
# -1 once the \end\ marker has been seen.
current_ngram_len = 0
TH_exceeded = 0   # trigrams whose network estimate was capped at the MLE
count = 0         # trigrams that were re-estimated

# Context managers close both files even if a lookup raises mid-way, and
# no_grad() avoids building an autograd graph for pure inference.
with open(file, 'r') as first_read, \
        open("../../rsc/output_LM.arpa", "w+") as new_file, \
        torch.no_grad():
    for x in tqdm(first_read, total=num_lines, position=0, leave=True):
        if x == '\\end\\\n':
            current_ngram_len = -1
            new_file.write(x)
        elif x == '\n':
            new_file.write(x)
        elif current_ngram_len < 3:
            # Header, unigram and bigram sections are copied verbatim.
            new_file.write(x)
        elif current_ngram_len == 3:  # trigrams
            line = x.split('\t')
            # Strip only the line terminator rather than blindly dropping the
            # last character of the n-gram field.
            trigram = line[1].rstrip('\n')
            r = ngram_dict[trigram]  # trigram occurrence count

            if MIN_COUNT_EXCLUSIVE < r < MAX_COUNT_EXCLUSIVE:
                ngram = trigram.split(' ')
                count += 1

                prefix_count = ngram_dict[ngram[0] + ' ' + ngram[1]]
                backoff_count = ngram_dict[ngram[1] + ' ' + ngram[2]]
                unigram_count = ngram_dict[ngram[2]]

                # The network is fed the reciprocals of the four counts.
                nn_input = 1 / torch.tensor(
                    [[prefix_count, r, backoff_count, unigram_count]],
                    dtype=torch.float, device=device)

                MLE = r / prefix_count                 # threshold: count(trigram) / count(prefix)
                smoothed_prob = net(nn_input).item()   # network estimate as a Python float

                if smoothed_prob > MLE:
                    # Cap at the MLE. Only this one line is written for the
                    # trigram: emitting the original line as well would put a
                    # duplicate entry for the same trigram into the output.
                    TH_exceeded += 1
                    smoothed_prob = MLE

                logbase = math.log10(smoothed_prob)
                new_file.write('{:.7f}\t{}\n'.format(logbase, trigram))
            else:
                new_file.write(x)

        # Section markers switch the state for the lines that follow them.
        if x == '\\1-grams:\n':
            current_ngram_len = 1
        elif x == '\\2-grams:\n':
            current_ngram_len = 2
        elif x == '\\3-grams:\n':
            current_ngram_len = 3

if count:
    print('MLE estimates exceeded: {:.2f}%'.format((TH_exceeded/count)*100))
else:
    # Avoid a ZeroDivisionError when no trigram fell inside the count range.
    print('No trigrams fell inside the smoothing range; nothing was re-estimated.')

# inject GT smoothed bigrams into NN LM

In [None]:
# Merge two ARPA files line by line: bigram entries are taken from file2,
# every other line (headers, unigrams, trigrams) is taken from file1.
# The two files are read in lock-step, so they must list the same entries
# in the same order.
file1, file2 = '../../rsc/nnLM.arpa', '../../rsc/smoothedLM.arpa'

# Count lines up front so tqdm can show a total; close the handle right away.
with open(file1, 'r') as counting_read:
    num_lines = sum(1 for line in counting_read)

# Context managers close all three files even if the loop raises.
with open(file1, 'r') as first_read, \
        open(file2, 'r') as second_read, \
        open('../../rsc/output_LM.arpa', "w+") as new_file:
    for line_no, raw1 in enumerate(
            tqdm(first_read, total=num_lines, position=0, leave=True), start=1):
        raw2 = second_read.readline()  # '' once file2 is exhausted
        line1 = raw1.split('\t')

        # An n-gram entry has at least two tab-separated fields and is not a
        # section marker ('\...'), a blank line, or an 'ngram N=...' header.
        is_entry = (len(line1) > 1
                    and raw1[0] != '\\'
                    and raw1 != '\n'
                    and raw1[0] != 'n')

        if is_entry and len(line1[1].split(' ')) == 2:
            # Bigram: take the line from file2, but only if it really is the
            # same bigram. Otherwise the files are misaligned and the output
            # would silently receive a wrong (or empty) entry.
            line2 = raw2.split('\t')
            if len(line2) < 2 or line2[1].strip() != line1[1].strip():
                raise ValueError(
                    'line {}: {!r} and {!r} are not aligned (bigram {!r})'.format(
                        line_no, file1, file2, line1[1].strip()))
            new_file.write(raw2)
        else:
            # Headers, unigrams and trigrams are copied from file1 verbatim.
            new_file.write(raw1)

# debug

In [9]:
#start = time.time()
#file = '../../temp/test/12-14_bigrams_discounted.arpa'
# Debug helper: scan the ARPA file and print every line whose second
# tab-separated field is exactly the trigram below.
file = '../../rsc/unsmoothedLM.arpa'

# Count lines up front so tqdm can show a total; close the handle right away.
with open(file, 'r') as counting_read:
    num_lines = sum(1 for line in counting_read)

# Scratch accumulators: initialised here but never updated in this cell.
prob = 0
tot_prob = 0
prob2 = 0
tot_prob2 = 0
prob3 = 0
tot_prob3 = 0

# The context manager closes the file even if the scan is interrupted.
with open(file, 'r') as first_read:
    for x in tqdm(first_read, total=num_lines, position=0, leave=True):
        line = x.split('\t')
        if len(line) > 1 and line[1] == '<s> AND BENNI\n':
            print(x)

 52%|█████▏    | 10125136/19630106 [00:05<00:04, 2002186.90it/s]

-4.727726	<s> AND BENNI



100%|██████████| 19630106/19630106 [00:10<00:00, 1962939.51it/s]


In [5]:
# Scratch cell: bind two throwaway integer names.
hos = 1
cheerss = 2
