In [None]:
import torch
import torch.nn as nn
from torch import distributions

class bayesian_skipgram(nn.Module):
    """Bayesian skip-gram: a VAE-style word embedding model.

    An inference network (the "encoder") maps a centre word together with its
    context words to a diagonal Gaussian posterior over a latent vector ``z``.
    A decoder maps ``z`` to a categorical distribution over the vocabulary,
    and every word additionally owns a learned diagonal Gaussian prior.
    ``forward`` returns the batch-averaged negative ELBO
    (mean KL(posterior || prior) minus mean context log-likelihood).

    Args:
        num_words: vocabulary size (rows of every embedding table and the
            width of the decoder output).
        emb_dim: dimensionality of the embeddings and of the latent ``z``.
    """

    def __init__(self, num_words, emb_dim):
        super(bayesian_skipgram, self).__init__()

        self.num_words = num_words
        self.emb_dim = emb_dim

        # Encoder-side word embeddings plus per-word prior parameters.
        # sigma_prior stores unconstrained values; softplus is applied on use.
        self.R = nn.Embedding(num_words, emb_dim)
        self.mu_prior = nn.Embedding(num_words, emb_dim)
        self.sigma_prior = nn.Embedding(num_words, emb_dim)

        # Inference network: (word, context) pair -> hidden -> (mu, sigma).
        self.M = nn.Linear(2 * emb_dim, 2 * emb_dim)
        self.affine_lambda_mu = nn.Linear(2 * emb_dim, emb_dim)
        self.affine_lambda_sigma = nn.Linear(2 * emb_dim, emb_dim)
        # Decoder: latent vector -> vocabulary logits.
        self.affine_theta = nn.Linear(emb_dim, num_words)

    def forward(self, word_idx, context_idx):
        """Compute the batch-averaged negative ELBO.

        Args:
            word_idx: integer tensor with ``batch`` centre-word indices.
            context_idx: integer tensor of shape ``(batch, n_context)`` with
                the context-word indices of each centre word.

        Returns:
            Scalar tensor: ``mean(KL) - mean(log-likelihood)`` over the batch.
        """
        batch_size = word_idx.shape[0]
        n_context = context_idx.shape[1]

        # ********** Encoder ************

        # Pair the centre-word embedding with each of its context embeddings.
        R_w = self.R(word_idx).view(batch_size, 1, self.emb_dim)
        R_w = R_w.expand(-1, n_context, -1)
        R_cj = self.R(context_idx)
        RcRw = torch.cat((R_w, R_cj), dim=2)

        # Sum the per-pair hidden states over the context positions.
        h = torch.relu(self.M(RcRw)).sum(dim=1)

        mu = self.affine_lambda_mu(h)
        sigma = nn.functional.softplus(self.affine_lambda_sigma(h))

        # Reparametrization trick.  The noise is drawn independently for every
        # batch row and on the same device/dtype as mu; drawing one
        # emb_dim-sized CPU sample and broadcasting it would share the noise
        # across the batch and break when the model lives on another device.
        eps = torch.randn_like(mu)
        z = mu + sigma * eps

        # ********* Decoder ***********

        f_i = nn.functional.softmax(self.affine_theta(z), dim=1)

        mu_prior = self.mu_prior(word_idx)
        sigma_prior = nn.functional.softplus(self.sigma_prior(word_idx))

        # ********** Loss ************

        # Probability the decoder assigns to each observed context word,
        # gathered in one shot instead of a Python loop over batch x context.
        context_probs = f_i.gather(1, context_idx.long())
        likelihood_terms = torch.log(context_probs + 1e-8).sum(dim=1)

        # Per-row KL between posterior and the centre word's prior.
        KL_div_terms = self._kl_elementwise(mu_prior, sigma_prior, mu, sigma).sum(dim=-1)

        total_loss = torch.mean(KL_div_terms) - torch.mean(likelihood_terms)

        return total_loss

    @staticmethod
    def _kl_elementwise(mu_p, sigma_p, mu, sigma):
        """Element-wise KL(N(mu, sigma^2) || N(mu_p, sigma_p^2)); sigmas are std-devs."""
        return (
            torch.log(sigma_p + 1e-8)
            - torch.log(sigma + 1e-8)
            + (sigma ** 2 + (mu - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5
        )

    def KL_div(self, mu_p, sigma_p, mu, sigma):
        """KL(N(mu, sigma^2) || N(mu_p, sigma_p^2)) summed over all elements.

        Args:
            mu_p, sigma_p: mean and standard deviation of the second (prior)
                Gaussian.
            mu, sigma: mean and standard deviation of the first Gaussian.

        Returns:
            Scalar tensor holding the sum of the element-wise divergences.
        """
        return self._kl_elementwise(mu_p, sigma_p, mu, sigma).sum()

In [2]:
# read in all needed data

def get_test_data(filename):
    """Read the target word and sentence id from each line of a test file.

    Every line is split on tab characters; the first field is taken as the
    test word and the second as the sentence id.  Any further fields on the
    line are ignored.

    Args:
        filename: path of the tab-separated test file.

    Returns:
        A list of ``[test_word, sentence_id]`` pairs in file order.
    """
    with open(filename) as handle:
        split_lines = (raw.split('\t') for raw in handle)
        # Explicit indexing (not slicing) so a line with fewer than two
        # fields still raises IndexError.
        rows = [[fields[0], fields[1]] for fields in split_lines]
    return rows

def get_candidates(filename):
    """Read a ``word::cand1;cand2;...`` candidates file into a dict.

    Each non-blank line is split on ``'::'``; the part before it is the key
    and the part after it is split on ``';'`` into the candidate list.
    Surrounding whitespace (including the line's trailing newline) is
    stripped and empty candidates are dropped, so the last candidate on a
    line is a clean token rather than ``'cand\\n'``.

    Args:
        filename: path of the candidates file.

    Returns:
        Dict mapping each word to its list of candidate strings.

    Raises:
        IndexError: if a non-blank line contains no ``'::'`` separator.
    """
    candidate_dict = {}
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            fields = line.split('::')
            word = fields[0]
            candidate_dict[word] = [
                candidate.strip()
                for candidate in fields[1].split(';')
                if candidate.strip()
            ]
    return candidate_dict

In [3]:
#candidates = get_candidates('lst/lst.gold.candidates')

In [4]:
import pickle
import torch

# Test items and the candidate substitutes for each target word.
test_data = get_test_data('lst/lst_test.preprocessed')
candidates = get_candidates('lst/lst.gold.candidates')

# Vocabulary lookups (word -> index and index -> word) pickled alongside the
# model.  pickle executes arbitrary code on load: only use trusted files.
with open('w2i_bayesian.pkl', 'rb') as w2i_file:
    w2i = pickle.load(w2i_file)
with open('i2w_bayesian.pkl', 'rb') as i2w_file:
    i2w = pickle.load(i2w_file)

In [5]:
use_cuda = torch.cuda.is_available()

# With a GPU available the checkpoint is restored as saved (map_location=None
# is torch.load's default); otherwise every tensor is remapped onto the CPU.
model_location = None if use_cuda else 'cpu'
bayesian_skipgram_model = torch.load('bayesian_skipgram.pt', map_location=model_location)

In [6]:
# Preprocessing: drop every test item whose target word (with its two-character
# POS suffix removed) is missing from w2i, and drop every candidate that is
# missing from w2i.  The unfiltered item list is kept in orig_test_data.
orig_test_data = test_data[:]

kept_items = []
for item in orig_test_data:
    target = item[0]
    if target[:-2] in w2i:
        kept_items.append(item)
    # Candidates are filtered for every target, including dropped ones.
    candidates[target] = [cand for cand in candidates[target] if cand in w2i]
test_data = kept_items

In [7]:
# Convert target words and their candidates to vocabulary indexes.

# One index per remaining test item, in test_data order.
word_is = torch.LongTensor([w2i[target[:-2]] for target, _ in test_data])

# Candidate indexes, keyed by the index of the POS-stripped target word.
# NOTE(review): targets that differ only in their POS suffix map to the same
# key here, so a later one replaces an earlier one -- confirm this is intended.
can_is = {}
for target, _ in test_data:
    candidate_indexes = [w2i[cand] for cand in candidates[target]]
    can_is[w2i[target[:-2]]] = torch.LongTensor(candidate_indexes)

In [36]:
def bsg_step(model, word, context):
    """Run the model's inference network for one word and its context.

    The word embedding is paired with every context embedding, each pair is
    passed through ``model.M`` and a ReLU, and the results are summed over the
    context before being mapped to the posterior parameters.

    Args:
        model: object exposing ``R``, ``M``, ``affine_lambda_mu``,
            ``affine_lambda_sigma`` and ``emb_dim``.
        word: index tensor holding a single word index.
        context: index tensor with one entry per context word.

    Returns:
        ``(mu, sigma)``: the posterior mean and the softplus-transformed
        scale, each of size ``emb_dim``.
    """
    dim = model.emb_dim
    centre_vec = model.R(word).view(1, dim)
    ctx_vecs = model.R(context)

    # Tile the centre-word embedding so it lines up with every context row.
    tiled_centre = centre_vec.repeat(ctx_vecs.shape[0], 1)
    pairs = torch.cat((tiled_centre, ctx_vecs), dim=-1)

    hidden = torch.relu(model.M(pairs)).sum(dim=0)

    posterior_mu = model.affine_lambda_mu(hidden)
    posterior_sigma = nn.functional.softplus(model.affine_lambda_sigma(hidden))
    return posterior_mu, posterior_sigma

In [42]:
def get_ranking(word, candidates):
    """Score each candidate by a KL divergence against the posterior of ``word``.

    The posterior ``(mu, sigma)`` comes from ``bsg_step`` on the global
    ``bayesian_skipgram_model``; each candidate's score is the model's
    ``KL_div`` between that posterior and the candidate's learned prior.

    NOTE(review): the candidate indexes are what is passed to ``bsg_step`` as
    the context -- no sentence context is involved here; confirm this is
    intended.  The returned list is in candidate order, not sorted by score.

    Args:
        word: index tensor holding the target word index.
        candidates: index tensor of candidate word indexes.

    Returns:
        List of ``[candidate, score]`` pairs, where ``candidate`` is the
        element obtained by iterating ``candidates`` and ``score`` is a float.
    """
    model = bayesian_skipgram_model
    post_mu, post_sigma = bsg_step(model, word, candidates)

    scored = []
    for cand in candidates:
        prior_mu = model.mu_prior(cand)
        prior_sigma = nn.functional.softplus(model.sigma_prior(cand))
        divergence = model.KL_div(prior_mu, prior_sigma, post_mu, post_sigma)
        scored.append([cand, divergence.sum().data.item()])

    return scored

In [43]:
# Score the candidates of every remaining test item.  Results are keyed by the
# target word, so items sharing a target overwrite each other's entry.
results = {}
for position, (target, _sentence_id) in enumerate(test_data):
    target_index = word_is[position]
    candidate_indexes = can_is[target_index.item()]
    if len(candidate_indexes) == 0:
        # Nothing to rank for this target.
        continue
    results[target] = [
        [i2w[int(cand)], cand_score]
        for cand, cand_score in get_ranking(target_index, candidate_indexes)
    ]

In [44]:
results

{'about.r': [['around', 810.9678344726562],
  ['of', 517.5604858398438],
  ['concerning', 296.2214050292969],
  ['arise', 416.9548034667969],
  ['approximately', 1132.5484619140625],
  ['roughly', 254.0074920654297],
  ['nearly', 377.0080871582031],
  ['consider', 482.1925048828125]],
 'account.n': [['balance', 277.5555114746094],
  ['description', 291.7359313964844],
  ['explanation', 158.94105529785156],
  ['finance', 272.7778625488281],
  ['statement', 684.1188354492188],
  ['report', 437.89532470703125]],
 'around.r': [['about', 1695.53271484375],
  ['approximately', 821.4795532226562],
  ['over', 602.875244140625],
  ['there', 687.3256225585938],
  ['roughly', 403.4389953613281],
  ['here', 403.3437805175781]],
 'away.r': [['depart', 612.0407104492188],
  ['go', 516.4744262695312],
  ['out', 1308.1768798828125],
  ['along', 1763.1768798828125],
  ['ahead', 2364.591552734375]],
 'board.n': [['executive', 271.61334228515625],
  ['committee', 678.3038940429688],
  ['management', 377.

In [45]:
# Write one "#RANKED" line per original test item.  Items without results
# (filtered out earlier) still get a line, just without any candidates,
# because lst_gap needs all words.
with open('bayesian_skipgram_predictions', 'w') as out_file:
    for target, sent_id in orig_test_data:
        ranked = ''.join(
            '\t' + cand + ' ' + str(cand_score)
            for cand, cand_score in results.get(target, ())
        )
        out_file.write('#RANKED\t' + target + ' ' + sent_id + ranked + '\n')

In [46]:
# Run the LST scoring script on the predictions file; the recorded output of
# this cell is a MEAN_GAP value.  The arguments are presumably: gold file,
# predictions file, output name, and a 'no-mwe' flag -- confirm against
# lst/lst_gap.py.
%run lst/lst_gap.py lst/lst_test.gold bayesian_skipgram_predictions bayesian_skipgram_out no-mwe


MEAN_GAP	0.06320463650142717

