In [16]:
import torch.nn as nn
import torch.distributions as dist
import torch.utils.data
from torch.utils.data import sampler
from torch.distributions import kl
class embed_align(nn.Module):
    def __init__(self, vocab_size1, vocab_size2, emb_dimension):
        super(embed_align, self).__init__()
        self.emb_dimension = emb_dimension

        self.embedding = nn.Embedding(vocab_size1, emb_dimension, padding_idx = 0)
        self.BiLSTM = nn.LSTM(emb_dimension, emb_dimension, bidirectional=True, batch_first=True)
        
        self.affine1_mu = nn.Linear(emb_dimension, emb_dimension)
        self.affine2_mu = nn.Linear(emb_dimension, emb_dimension)
        
        self.affine1_sig = nn.Linear(emb_dimension, emb_dimension)
        self.affine2_sig = nn.Linear(emb_dimension, emb_dimension)
        
        self.affine1_L1 = nn.Linear(emb_dimension, emb_dimension)
        self.affine2_L1 = nn.Linear(emb_dimension, vocab_size1)
        self.affine1_L2 = nn.Linear(emb_dimension, emb_dimension)
        self.affine2_L2 = nn.Linear(emb_dimension, vocab_size2)
        
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()
        self.log_softmax = nn.Softmax(dim=0)
        
    def forward(self, sentence1, sentence2, use_cuda=False):
        # sentence1 & sentence2 are (batches of) list of all ints in a sentence
        # encoder
        sen1_emb = self.embedding(sentence1)
        if len(sen1_emb.shape) == 2: # not a batch
            sen1_emb = sen1_emb.unsqueeze(0)
        h, _ = self.BiLSTM(sen1_emb)
        h1, h2 = torch.split(h, split_size_or_sections=self.emb_dimension, dim =2)
        h = h1 + h2
        mu = self.affine2_mu(self.relu(self.affine1_mu(h)))
        sig = self.relu(self.affine2_sig(self.relu(self.affine1_sig(h))))
        
        sample_norm = dist.multivariate_normal.MultivariateNormal(torch.zeros(self.emb_dimension), torch.eye(self.emb_dimension))
        e = sample_norm.sample()
        if use_cuda:
            z = mu + e.cuda() * sig
        else:
            z = mu + e * sig
    
        # likelihood language 1
        dist_1 = self.log_softmax(self.affine2_L1(self.relu(self.affine1_L1(z))))
        # sum over batch
        sum_1 = torch.sum(dist_1, dim=0)
        likelihood_1 = torch.mean(sum_1, dim=1)
        total_likelihood1 = 0
        sen_len = 0
        for i, likelihood in enumerate(likelihood_1):
            # no batches:
            if len(sentence1) == longest1:
                if sentence1[i].item() == 0:
                    continue
                total_likelihood1 += likelihood
                sen_len +=1
            else:
                for j in range(len(sentence1)):
                    if sentence1[j][i].item() == 0:
                        continue
                    total_likelihood1 += likelihood
                sen_len +=1
        likelihood1 = total_likelihood1/sen_len
        
        # likelihood language 2
        dist_2 = self.log_softmax(self.affine2_L2(self.relu(self.affine1_L2(z))))
        sum_2 = torch.sum(dist_2, dim=0)
        likelihood_2 = torch.mean(sum_2, dim=1)
        total_likelihood2 = 0
        sen_len = 0
        for i, likelihood in enumerate(likelihood_2):
            # no batches:
            if len(sentence1) == longest1:
                if sentence1[i].item() == 0:
                    continue
                total_likelihood2 += likelihood
                sen_len +=1
            else:
                for j in range(len(sentence1)):
                    if sentence1[j][i].item() == 0:
                        continue
                    total_likelihood2 += likelihood
                sen_len +=1
        likelihood2 = total_likelihood2/sen_len
        
        # KL
        # to prevent log returning infinity
        sig = sig+1e-8
        KL =  -0.5 * torch.sum(1 + torch.log(sig) - mu.pow(2) - sig)
        return - ((likelihood1 + likelihood2) - KL)

In [17]:
# read in all needed data

def get_test_data(filename):
    test_data = []
    with open(filename) as f:
        for line in f:
            line = line.split('\t')
            test_word = line[0]
            sentence_id = line[1]
            test_data.append([test_word, sentence_id])
    return(test_data)

def get_candidates(filename):
    with open(filename) as f:
        candidate_dict = {}
        for line in f:
            line = line.split('::')
            word = line[0]
            candidates = line[1].split(';')
            candidate_dict[word] = candidates
        return candidate_dict

In [18]:
import dill
import torch

test_data = get_test_data('lst/lst_test.preprocessed')

candidates = get_candidates('lst/lst.gold.candidates')

with open('w2i_en_embedalign.pkl', 'rb') as f:
    w2i_en = dill.load(f)

with open('i2w_en_embedalign.pkl', 'rb') as f:
    i2w_en = dill.load(f)
    
with open('w2i_fr_embedalign.pkl', 'rb') as f:
    w2i_fr = dill.load(f)

with open('i2w_fr_embedalign.pkl', 'rb') as f:
    i2w_fr = dill.load(f)


In [19]:
import string

id_to_sentence = {}

translator = str.maketrans('', '', string.punctuation)

with open('lst/lst_test.preprocessed') as f:
    for line in f:
        line = line.split('\t')
        test_word = line[0]
        sentence_id = line[1]
        sentence = line[3].translate(translator).lower().split()
        #print(sentence)
        id_to_sentence[int(sentence_id)] = sentence

In [20]:
id_to_sentence[1105]

['although',
 'the',
 'mah',
 'look',
 'family',
 'became',
 'a',
 'strong',
 'part',
 'of',
 'community',
 'life',
 'in',
 'the',
 'king',
 'valley',
 'they',
 'had',
 'to',
 'suffer',
 'the',
 'occasional',
 'racist',
 'comments',
 'which',
 'must',
 'have',
 'made',
 'life',
 'unpleasant',
 'at',
 'times']

In [21]:
use_cuda = torch.cuda.is_available()

if use_cuda: # load model with gpu (as it was trained)
    embedalign_model = torch.load('embedalign.pt')
else: # convert to cpu:
    embedalign_model = torch.load('embedalign.pt',  map_location='cpu')

In [22]:
print(len(test_data))

tot_cans = 0

for cans in candidates.values():
    tot_cans += len(cans)
    
print(tot_cans)

1710
4207


In [23]:
orig_test_data = test_data[:]
for [word, sentence_id] in test_data[:]:
    word_nopos = word[:-2]
    if word_nopos not in w2i_en:
        test_data.remove([word, sentence_id])
    for candidate in candidates[word][:]:
        if candidate not in w2i_en:
            candidates[word].remove(candidate)

In [24]:
print(len(test_data))

tot_cans = 0

for cans in candidates.values():
    tot_cans += len(cans)
    
print(tot_cans)

540
1142


In [25]:
word_is = [w2i_en[word[:-2]] for word, _ in test_data]
if use_cuda:
    word_is = torch.cuda.LongTensor(word_is)
else:
    word_is = torch.LongTensor(word_is)


can_is = {}
for [word, _] in test_data[:]:
    can_i = [w2i_en[can] for can in candidates[word]]
    word_i = w2i_en[word[:-2]]
    if use_cuda:
        can_is[word_i] = torch.cuda.LongTensor(can_i)
    else:
        can_is[word_i] = torch.LongTensor(can_i)

In [30]:
sentence_is = {}
for [word, sentence_id] in test_data[:]:
    
    sen_i = [w2i_en[sen] for sen in id_to_sentence[int(sentence_id)]]
    #print(sentence_id,id_to_sentence[int(sentence_id)], sen_i)
    if use_cuda:
        sentence_is[int(sentence_id)] = torch.cuda.LongTensor(sen_i)
    else:
        sentence_is[int(sentence_id)] = torch.LongTensor(sen_i)

In [31]:
def get_ranking(word, candidates, sentence, word_pos):
    mu, sigma = embedalign_step(embedalign_model, sentence)

    can_sims=[]
    for can in candidates:
        candidate_sentence = sentence
        candidate_sentence[word_pos] = can
        
        mu_p, sigma_p = embedalign_step(embedalign_model, candidate_sentence)
        sigma_p = sigma_p + 1e-8
        sigma = sigma + 1e-8
        score =  (torch.log(sigma_p) - torch.log(sigma) + (sigma**2 + (mu - mu_p)**2) / (2*sigma_p**2) - 0.5)
        score = score.sum().data.item()
        can_sims.append([can, score])
            
    return(can_sims)

In [32]:
def embedalign_step(model, sentence):

    # Get embeddings for words in sentence
    sen_emb = model.embedding(sentence)
    # - Encoder -
    if len(sen_emb.shape) == 2: # not a batch
        sen_emb = sen_emb.unsqueeze(0)
    h, _ = model.BiLSTM(sen_emb)
        
    h1, h2 = torch.split(h, split_size_or_sections=model.emb_dimension, dim =2)
    h = h1 + h2
    
    
    mu = model.affine2_mu(model.relu(model.affine1_mu(h)))
    sigma = model.relu(model.affine2_sig(model.relu(model.affine1_sig(h))))

    return mu, sigma

In [33]:
results = {}
for i, [word, sentence_id] in enumerate(test_data):
    word_i = word_is[i]
    cans = can_is[word_i.item()]
    
    sentence_ids = sentence_is[int(sentence_id)]
    
    word_pos = -1
    for e, el in enumerate(sentence_ids):
        if el == word_i:
            word_pos = e
        
    
    if (len(cans)>0):
        rank = get_ranking(word_i, cans, sentence_ids, word_pos)
        words_scores = []
        for [candidate, score] in rank:
            can_word = i2w_en[int(candidate)]
            score = score
            words_scores.append([can_word, score])
        results[word] = words_scores

  


In [34]:
with open('embedalign_predictions', 'w') as f:
    for [word, sentence_id] in orig_test_data: # lst_gap needs all words, also once we do not have data for
        f.write('#RANKED\t')
        f.write(word + ' ')
        f.write(sentence_id)
        if word in results:
            for [candidate, score] in results[word]:
                f.write('\t' + candidate + ' ' + str(score))
        f.write('\n')

In [35]:
%run lst/lst_gap.py lst/lst_test.gold embedalign_predictions embed_align_out no-mwe


MEAN_GAP	0.033760420625893135



In [36]:
%run lst/lst_gap.py lst/lst_test.gold bayesian_skipgram_predictions bayesian_skipgram_out no-mwe


MEAN_GAP	0.06320463650142717



# AER

In [54]:
def get_alignments(model, sentence1, sentence2):
    mu, sigma = embedalign_step(model, sentence1)

    softmax = torch.nn.LogSoftmax(dim=-1)
    sample_norm = dist.multivariate_normal.MultivariateNormal(torch.zeros(model.emb_dimension), torch.eye(model.emb_dimension))
    e = sample_norm.sample()
    if use_cuda:
        z = mu + e.cuda() * sigma
    else:
        z = mu + e * sigma
    g = model.log_softmax(model.affine2_L2(model.relu(model.affine1_L2(z))))
    g = g.squeeze(0)
    x = model.affine2_L2(model.relu(model.affine1_L2(z)))

    translations = []
    for i, word1 in enumerate(sentence1):
        highest_prob = -100000
        highest_prob_word = -1
        for j, word2 in enumerate(sentence2):
            prob = g[i][word2]
            if prob> highest_prob:
                highest_prob = prob
                highest_prob_word = j + 1 # because test.naacl starts counting at 1
        translations.append([i+ 1, highest_prob_word])
        
    return(translations)

In [55]:
def read_data(filename):
    data = []
    with open(filename) as f:
        for s in f:
            data.append(s.split())
    return data

In [56]:
data_en = read_data('wa/test.en')
data_fr = read_data('wa/test.fr')

In [61]:
translations = []
for i, sen_en in enumerate(data_en):
    if use_cuda:
        sen_en_i = torch.cuda.LongTensor([w2i_en[word] for word in sen_en])
        sen_fr_i = torch.cuda.LongTensor([w2i_fr[word] for word in data_fr[i]])
    else:
        sen_en_i = torch.LongTensor([w2i_en[word] for word in sen_en])
        sen_fr_i = torch.LongTensor([w2i_fr[word] for word in data_fr[i]])

    translation = get_alignments(embedalign_model, sen_en_i, sen_fr_i)
    translations.append(translation)

  


In [74]:
with open('aer_predictions.naacl', 'w') as f:
    for i in range(len(data_en)):
        for pair in translations[i]:
            f.write(str(i) + ' ' + str(pair[0]) + ' ' + str(pair[1]) + ' S \n')


In [76]:
%run wa/aer.py aer_predictions.naacl wa/test.naacl

0.9206803499877341
