In [1]:
%matplotlib inline

import matplotlib.pyplot as plt
import time
import numpy as np
import math
'''
import sys
sys.path.append('Classes')
from arpy import *
'''
import arpa
from tqdm import tqdm

import torch
import torch.nn as nn



# Prep CUDA

In [2]:
# Pick the compute device once: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# get_device_name() raises when no CUDA device is present, so only query it
# on machines that actually have one.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
print(torch.cuda.is_available())


GeForce GTX 1060 6GB
True


# Prep dict

In [6]:
# Build the n-gram -> count lookup from a tab-separated counts file.
# For every line, the first tab-separated field becomes the key and the last
# field is parsed as the integer count.
file = '../../rsc/train_counts.txt'

# Count the lines up front so tqdm can show a percentage; `with` makes sure
# the handle is closed again.
with open(file, 'r') as count_handle:
    num_lines = sum(1 for _ in count_handle)

ngram_dict = {}
with open(file, 'r') as first_read:
    for x in tqdm(first_read, total=num_lines, position=0, leave=True):
        if not x.strip():
            # A blank (e.g. trailing) line has no count field and would make int() raise.
            continue
        line = x.split('\t')
        ngram_dict[line[0]] = int(line[-1])


100%|██████████| 45918515/45918515 [00:47<00:00, 958725.32it/s] 



# Prep NN

In [4]:
class Net(nn.Module):
    """Small fully connected network: 4 inputs -> 4 -> 3 -> 1 output.

    Every linear layer is followed by a sigmoid. The layer attribute names
    (fc1, fc3, fc4) determine the keys of the module's state dict, so they
    are kept exactly as they are.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=4, out_features=4)
        # self.fc2 = nn.Linear(4, 4)  # extra hidden layer, currently disabled
        self.fc3 = nn.Linear(in_features=4, out_features=3)
        self.fc4 = nn.Linear(in_features=3, out_features=1)

    def forward(self, x):
        """Feed x through fc1, fc3 and fc4 in turn, applying a sigmoid after each."""
        for layer in (self.fc1, self.fc3, self.fc4):
            x = torch.sigmoid(layer(x))
        return x
'''
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(3, 3)
        #self.fc2 = nn.Linear(4, 4)
        self.fc3 = nn.Linear(3, 3)
        self.fc4 = nn.Linear(3, 1)

        
    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        #x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x)) 
        x = torch.sigmoid(self.fc4(x)) 
        return x
'''
# Put the network on the device selected earlier instead of hard-coding
# .cuda(), which raises on a machine without a CUDA GPU.
net = Net().to(device)
# map_location lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine as well.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
net.load_state_dict(torch.load('NN saves/4431_lr=0.03_MAE', map_location=device))
# Switch to inference mode; as the cell's last expression this also displays
# the module summary.
net.eval()

Net(
  (fc1): Linear(in_features=4, out_features=4, bias=True)
  (fc3): Linear(in_features=4, out_features=3, bias=True)
  (fc4): Linear(in_features=3, out_features=1, bias=True)
)

# Load and Write ARPA
## NN with 4 inputs

In [8]:
# Rewrite the ARPA model with the 4-input network.
# Every trigram whose training count r satisfies 1 < r < 8 gets its
# probability replaced by the network's estimate, capped at the
# maximum-likelihood estimate (trigram count / prefix count). Every other
# line is copied through unchanged.
#start = time.time()
file = '../../rsc/unsmoothedLM.arpa'
#file = '../../temp/small.arpa'
out_path = "../../rsc/nnLM.arpa"
#out_path = '../../temp/nn_small.arpa'

# Count the lines up front so tqdm can show a percentage; `with` closes the handle.
with open(file, 'r') as count_handle:
    num_lines = sum(1 for _ in count_handle)

current_ngram_len = 0  # 0 = before the first section, 1/2/3 = inside that section, -1 = after \end\
error = 0              # network estimates that exceeded the MLE and were capped
count = 0              # trigrams that were re-estimated

# The context managers close both files even if a dictionary lookup or the
# network call raises; no_grad avoids building an autograd graph for what is
# pure inference.
with open(file, 'r') as first_read, open(out_path, "w+") as new_file, torch.no_grad():
    for x in tqdm(first_read, total=num_lines, position=0, leave=True):
        if x == '\\end\\\n':
            current_ngram_len = -1
            new_file.write(x)
        elif x == '\n':
            new_file.write(x)
        elif current_ngram_len < 3:
            # Header, unigram and bigram sections (and the "\3-grams:" marker
            # itself, which is seen while the state is still 2) are copied as-is.
            new_file.write(x)
        elif current_ngram_len == 3:
            line = x.split('\t')
            # rstrip instead of [:-1] so a final line without a newline is not truncated.
            trigram = line[1].rstrip('\n')
            r = ngram_dict[trigram]
            if r == 1:
                print('oops')

            if r > 1 and r < 8:  # only smooth values for r < 8
                ngram = trigram.split(' ')
                count += 1

                # Network input: reciprocals of the four counts, built in a
                # single tensor instead of four element-wise device writes.
                nn_input = 1 / torch.tensor(
                    [[ngram_dict[ngram[0] + ' ' + ngram[1]],   # prefix count
                      r,                                       # trigram count
                      ngram_dict[ngram[1] + ' ' + ngram[2]],   # backoff bigram count
                      ngram_dict[ngram[2]]]],                  # unigram count
                    dtype=torch.float, device=device)

                # (1/prefix) / (1/trigram) == trigram / prefix, i.e. the MLE.
                MLE = (nn_input[0][0] / nn_input[0][1]).item()
                smoothed_prob = net(nn_input).item()

                if smoothed_prob > MLE:
                    # Never assign more probability than the MLE.
                    smoothed_prob = MLE
                    error += 1

                logbase = math.log(smoothed_prob, 10)
                new_file.write('{:.7f}\t{}\n'.format(logbase, trigram))
            else:
                new_file.write(x)

        # Section markers update the state only after the line itself was handled.
        if x == '\\1-grams:\n':
            current_ngram_len = 1
        if x == '\\2-grams:\n':
            current_ngram_len = 2
        if x == '\\3-grams:\n':
            current_ngram_len = 3

# Share of re-estimated trigrams whose estimate had to be capped; guarded so
# an input without any qualifying trigram does not divide by zero.
print('{:.2f}%'.format((error / count) * 100 if count else 0.0))

100%|██████████| 19630106/19630106 [1:11:26<00:00, 4579.13it/s]

20.56%





## NN with 3 inputs (needs the 3-input `Net` variant and matching weights, not the 4-input model loaded above)

In [None]:
# Rewrite an ARPA model with the 3-input network.
#   * header / unigram lines: copied unchanged
#   * bigram lines: probability multiplied by BIGRAM_DISCOUNT when the count
#     r < 8, and the back-off weight column written as BIGRAM_BACKOFF
#   * trigram lines with 1 < r < 8: probability replaced by the network
#     estimate minus NN_BIAS; all other trigram lines are copied unchanged
#start = time.time()
file = '../../rsc/old/nn_12-14.arpa'
#file = '../../temp/small.arpa'
out_path = "../../rsc/nn_LM_r!=1.arpa"
#out_path = '../../temp/nn_small.arpa'

BIGRAM_DISCOUNT = 0.95  # factor applied to bigram probabilities with r < 8
BIGRAM_BACKOFF = -99    # value written in the back-off column of every bigram
                        # NOTE(review): this replaces any back-off weight in the input - confirm intended
NN_BIAS = 0.00195       # constant subtracted from the network output

# The network in this cell is fed 3 features. Fail before the output file is
# truncated if the currently loaded `net` expects a different input width.
first_layer = getattr(net, 'fc1', None)
if first_layer is not None and first_layer.in_features != 3:
    raise ValueError(
        'this cell needs a Net with 3 inputs, but the loaded net expects {}'.format(
            first_layer.in_features))

# Count the lines up front so tqdm can show a percentage; `with` closes the handle.
with open(file, 'r') as count_handle:
    num_lines = sum(1 for _ in count_handle)

current_ngram_len = 0  # 0 = before the first section, 1/2/3 = inside that section, -1 = after \end\
error = 0              # network estimates larger than the probability read from the input file
count = 0              # trigrams handed to the network
fuck = 0               # estimates rejected as unusable (non-positive or above the MLE)

# The context managers close both files even when a lookup or the network
# call raises; no_grad avoids building autograd graphs during inference.
with open(file, 'r') as first_read, open(out_path, "w+") as new_file, torch.no_grad():
    for x in tqdm(first_read, total=num_lines, position=0, leave=True):
        if x == '\\end\\\n':
            current_ngram_len = -1
            new_file.write(x)
        elif x == '\n':
            new_file.write(x)
        elif current_ngram_len < 2:
            # Writes \data\, the "\1-grams:" / "\2-grams:" markers and all
            # unigram lines. (The markers used to be written a second time by
            # the state updates at the bottom of the loop.)
            new_file.write(x)
        elif x == '\\3-grams:\n':
            new_file.write(x)
        elif current_ngram_len == 2:  # bigrams
            # Strip the newline before splitting so the n-gram text is intact
            # whether or not a back-off column follows it. Previously lines
            # without a back-off column were dropped, and lines with one lost
            # the last character of the n-gram.
            # NOTE(review): lines without a back-off column are now treated
            # like the others - confirm that is the intended handling.
            line = x.rstrip('\n').split('\t')
            bigram = line[1]
            r = ngram_dict[bigram]

            if r < 8:
                logbase = math.log((10 ** float(line[0])) * BIGRAM_DISCOUNT, 10)
            else:
                logbase = float(line[0])
            new_file.write('{:.7f}\t{}\t{:.8f}\n'.format(logbase, bigram, BIGRAM_BACKOFF))

        elif current_ngram_len == 3:  # trigrams
            line = x.rstrip('\n').split('\t')
            trigram = line[1]
            r = ngram_dict[trigram]

            if r < 8 and r > 1:  # only smooth values for r < 8, r = 1 does not exist for optim reasons
                prob = 10 ** float(line[0])
                ngram = trigram.split(' ')
                count += 1

                # Maximum-likelihood estimate: trigram count / prefix count.
                # NOTE(review): the previous code inverted this value one time
                # too many, so both the first network input and the bound used
                # below ended up as 1/MLE - confirm against how the net was trained.
                MLE = r / ngram_dict[ngram[0] + ' ' + ngram[1]]

                nn_input = torch.tensor(
                    [[MLE,                                          # prob
                      1 / ngram_dict[ngram[1] + ' ' + ngram[2]],    # 1 / backoff bigram count
                      1 / ngram_dict[ngram[2]]]],                   # 1 / unigram count
                    dtype=torch.float, device=device)

                # p = pnn - bias value
                smoothed_prob = net(nn_input).item() - NN_BIAS

                if smoothed_prob > prob:
                    error += 1

                if smoothed_prob <= 0 or smoothed_prob > MLE:
                    # The bias can push the estimate to zero or below (no
                    # logarithm exists) and the net can overshoot the MLE. In
                    # both cases keep the input line instead of writing a
                    # stale or undefined log value.
                    fuck += 1
                    new_file.write(x)
                else:
                    new_file.write('{:.7f}\t{}\n'.format(math.log(smoothed_prob, 10), trigram))
            else:
                # Trigrams outside the smoothing range are kept as they are
                # (they used to be dropped from the output).
                new_file.write(x)

        # Section markers only update the state; the marker lines themselves
        # were already written by the branches above.
        if x == '\\1-grams:\n':
            current_ngram_len = 1
        if x == '\\2-grams:\n':
            current_ngram_len = 2
        if x == '\\3-grams:\n':
            current_ngram_len = 3

# Guarded so an input without any qualifying trigram does not divide by zero.
print(error / count if count else 0.0)
print('fucks: ' + str(fuck))

# Debug

In [9]:
# Debug helper: scan the ARPA file and print every line whose second
# tab-separated field is exactly '<s> AND BENNI' followed by a newline.
#start = time.time()
#file = '../../temp/test/12-14_bigrams_discounted.arpa'
file = '../../rsc/unsmoothedLM.arpa'

# Count the lines up front so tqdm can show a percentage; `with` closes the handle.
with open(file, 'r') as count_handle:
    num_lines = sum(1 for _ in count_handle)

# These accumulators are reset here but not read anywhere in this cell.
prob = 0
tot_prob = 0
prob2 = 0
tot_prob2 = 0
prob3 = 0
tot_prob3 = 0

with open(file, 'r') as first_read:
    for x in tqdm(first_read, total=num_lines, position=0, leave=True):
        line = x.split('\t')
        if len(line) > 1:
            if line[1] == '<s> AND BENNI\n':
                print(x)

 52%|█████▏    | 10125136/19630106 [00:05<00:04, 2002186.90it/s]

-4.727726	<s> AND BENNI



100%|██████████| 19630106/19630106 [00:10<00:00, 1962939.51it/s]
