In [35]:
import numpy as np
from scipy.sparse import dok_matrix
from Data import Dicts, Data
import itertools
import time

In [36]:
dicts = Dicts('dicts.pkl')
data = Data(('training/hansards.36.2.e',
             'training/hansards.36.2.f',
             'validation/dev.e',
             'validation/dev.f'), dicts)

In [37]:
e_len = len(dicts.e_idx2word) + 1
f_len = len(dicts.f_idx2word)
N = 100000

In [38]:
theta = {we: {} for we in range(e_len)}
for i, (se, sf) in enumerate(zip(data.file['training/hansards.36.2.e'][:N], data.file['training/hansards.36.2.f'][:N])):
    
    # NULL word
    se.append(len(dicts.e_idx2word))
    
    for i, j in list(itertools.product(se, sf)):
        theta[i][j] = 1

In [41]:
def epoch(epoch):
    count = {we: {wf: 0 for wf in theta[we].keys()} for we in range(e_len)}
    total = {we: 0  for we in range(e_len)}
    for i, (se, sf) in enumerate(zip(data.file['training/hansards.36.2.e'][:N], data.file['training/hansards.36.2.f'][:N])):
        
        # NULL word
        se.append(len(dicts.e_idx2word))
        
        # compute normalization
        s_total = {}
        for wf in sf:
            s_total[wf] = 0
            
            for we in se:
                s_total[wf] += theta[we][wf]
    
        # collect counts
        for we in se:
            for wf in sf:
                v = theta[we][wf] / s_total[wf]
                count[we][wf] += v
                total[we] += v
    
        print('\repoch: {}\t{:.0%}'.format(epoch, (i + 1)/ N), end='')
    
    # estimate probabilities
    for we in range(e_len):
        for wf in theta[we].keys():
            theta[we][wf] = count[we][wf] / total[we]

In [None]:
epochs = 10

for i in range(epochs):
    s = time.time()
    epoch(i)
    print('\ttook: {:.2f} seconds'.format(time.time() - s))

epoch: 0	100%	took: 77.11 seconds
epoch: 1	99%

In [31]:
d = sorted(theta[dicts.e_word2idx['near']].items(),key=lambda x: x[1], reverse=True)

In [32]:
sum([e[1] for e in d])

0.9999999999999993

In [33]:
for i in range(10):
    print(dicts.f_idx2word[d[i][0]], d[i][1])

de 0.8208696573231715
. 0.08095938911361239
le 0.07164508003314833
un 0.013276723204235787
, 0.008998306017725562
pour 0.002638284208114501
bientôt 0.0007365693941441231
se 0.0005469806169386398
? 0.00012684295769575178
annonce 5.4295854224516265e-05


In [47]:
theta[dicts.e_word2idx['the']][dicts.f_word2idx['CONSTITUTION']]

0.46182291718585683