In [1]:
import string
import numpy as np
import random
import torch

# read in data, remove punctuation and make all lower case
def read_corpus(filename):
    corpus = []
    translator = str.maketrans('', '', string.punctuation)
    with open(filename) as f:
        for s in f:
            clean_s = s.translate(translator).lower()
            clean_words = clean_s.split()
            corpus.append(clean_words)
    return corpus


#Make data dictionaries and indexing 
def make_dictionaries(corpus):
    # create one set of all unique words
    flat_corpus = [w for s in corpus for w in s]
    corpus_set = set(flat_corpus)
    w_to_i = {}
    i_to_w = {}
    w_freq = []
    num_words = len(corpus_set)
    for i, w in enumerate(corpus_set):
        w_to_i[w] = i
        i_to_w[i] = w
        freq = flat_corpus.count(w)   #**0.75
        w_freq.append([i, freq])
    return w_to_i, i_to_w, np.array(w_freq), num_words



#SOME DECISIONS TO MAKE HERE: 
### 1st: Word window
### 2nd: What to do when sentence is smaller than window size? Disregard? Fill with 0s? 
### If the above, fix vocab index for -UNK-
### For now only keeping windows if sentences are large enough

def make_word_windows(corpus, window_size):
    windows = []
    
    for sentence in corpus:
        for i, word in enumerate(sentence):
            curr_window = np.zeros(shape=[5])
            
            #Check if index allows large enough window sampling
            if (i - WINDOW_SIZE > 0) and (i+ WINDOW_SIZE < len(sentence)):
                word_index = w_to_i[word]
                curr_window[0] = word_index
                n=1
                for j in range(i-WINDOW_SIZE, i+WINDOW_SIZE):
                    context_index=w_to_i[sentence[j]]
                    curr_window[n:] = context_index
                    n+=1
                windows.append(torch.LongTensor(curr_window))
    return windows

In [2]:
WINDOW_SIZE = 2
EMBEDDING_DIM = 128
BATCH_SIZE = 64
NUM_EPOCHS = 10

In [3]:
corpus = read_corpus('wa/test.en')
w_to_i, i_to_w, w_freq, num_words = make_dictionaries(corpus)
data = make_word_windows(corpus, WINDOW_SIZE)

In [19]:
import torch.nn as nn
from torch.distributions import MultivariateNormal

class bayesian_skipgram(nn.Module):
    def __init__(self, num_words, emb_dim):
        super(bayesian_skipgram, self).__init__()

        self.num_words = num_words
        self.emb_dim = emb_dim

        
        self.R = nn.Embedding(num_words, emb_dim)
        self.mu_prior = nn.Embedding(num_words, emb_dim)
        self.sigma_prior  = nn.Embedding(num_words, emb_dim)
        
        self.M = nn.Linear(2*emb_dim, 2*emb_dim)
        self.affine_lambda_mu = nn.Linear(2*emb_dim, emb_dim)
        self.affine_lambda_sigma = nn.Linear(2*emb_dim, emb_dim)
        self.affine_theta = nn.Linear(emb_dim, num_words)


    def forward(self, word_idx, context_idx):
        
        R_w = self.R(word_idx) 
        R_w = R_w.unsqueeze(1)  
        R_w = R_w.repeat(1, len(context_idx[1]), 1) 
        R_cj = self.R(context_idx)  
        
        RcRw = torch.cat((R_w, R_cj), dim=2)

        h = nn.ReLU()(self.M(RcRw)) 
        h = torch.sum(h, dim=1)  


        mu = self.affine_lambda_mu(h)
        sigma = nn.Softplus()(self.affine_lambda_sigma(h))

        # reparametrization trick
        eps = MultivariateNormal(torch.zeros(self.emb_dim), torch.eye(self.emb_dim)).sample()
        z = mu + torch.exp(sigma / 2.) * eps        
       
        #########################################
    
        affine_z = self.affine_theta(z)
        
        
        f_i = nn.Softmax(dim=1)(affine_z)  
        
        mu_prior = self.mu_prior(word_idx)
        sigma_prior = nn.Softplus()(self.sigma_prior(word_idx))
            

        ###################################################
        
        likelihood_term = []
        KL_div_term = []
        loss_term = []
        total_loss = 0.0
        
        for i, contexts in enumerate(context_idx):
            
            #likelihood term
            likelihood = 0
            for idx in contexts:
                likelihood += torch.log(f_i[i, idx])
            likelihood_term.append(likelihood)
            
            #KL divergence term             
            cov = torch.diagflat(sigma[i])
            cov_prior = torch.diagflat(sigma_prior[i])
            
            KL = 0.5 * ( torch.log(torch.det(cov_prior) / torch.det(cov)) - self.emb_dim 
                                   + torch.trace(torch.matmul(torch.inverse(cov_prior), cov)) 
                       # + torch.matmul(torch.matmul((mu_prior[i] - mu[i]), torch.inverse(sigma_prior[i])), (mu_prior[i] - mu[i])))
                        +  torch.matmul(torch.matmul((mu_prior[i] - mu[i]), torch.inverse(cov_prior)), (mu_prior[i] - mu[i])))
            
            KL_div_term.append(KL)
            
            
            print("Losses: ", likelihood, KL, likelihood-KL)
            
            #total loss term
            loss_term.append(likelihood - KL)
            total_loss += (likelihood - KL)
            
        return total_loss

In [20]:
# TRAIN NETWORK
%matplotlib inline
import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

dataloader = DataLoader(data, batch_size=64, shuffle=True)
use_cuda = torch.cuda.is_available()

BSG_model = bayesian_skipgram(num_words, EMBEDDING_DIM)
optimizer = torch.optim.Adam(BSG_model.parameters(), lr=0.01)


if use_cuda:
    BSG_model.cuda()

BSG_model.train()
loss_progress = []
iter_time = time.time()

for epoch in range(NUM_EPOCHS):
    for i, data_batch in enumerate(dataloader, 0):
        main_word = data_batch[:,0]
        context_word = data_batch[:,1:]
        
        
        optimizer.zero_grad()
        
        if use_cuda:
            loss = BSG_model.forward(main_word.cuda(), context_word.cuda())
        else:
            loss = BSG_model.forward(main_word, context_word)
        
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            loss_progress.append(loss.item())  
        if i % 25000 == 0:
            print(loss_progress[-1])
            print(time.time()-iter_time)
            iter_time = time.time()


print("total time taken:", time.time()-start_time)
plt.plot(loss_progress)
plt.show()

z torch.Size([64, 128])
Losses:  tensor(-29.1962) tensor(180.5776) tensor(-209.7738)
Losses:  tensor(-28.9059) tensor(230.4201) tensor(-259.3260)
Losses:  tensor(-30.4217) tensor(229.9865) tensor(-260.4082)
Losses:  tensor(-31.7219) tensor(210.1728) tensor(-241.8947)
Losses:  tensor(-29.6738) tensor(321.0730) tensor(-350.7467)
Losses:  tensor(-32.7720) tensor(228.1464) tensor(-260.9185)
Losses:  tensor(-32.5044) tensor(304.7249) tensor(-337.2293)
Losses:  tensor(-29.8450) tensor(329.7832) tensor(-359.6282)
Losses:  tensor(-27.9475) tensor(230.4641) tensor(-258.4116)
Losses:  tensor(-29.3218) tensor(354.1818) tensor(-383.5035)
Losses:  tensor(-31.3301) tensor(175.7688) tensor(-207.0989)
Losses:  tensor(-32.2393) tensor(333.5015) tensor(-365.7408)
Losses:  tensor(-31.5288) tensor(177.7138) tensor(-209.2426)
Losses:  tensor(-31.4446) tensor(204.5961) tensor(-236.0407)
Losses:  tensor(-31.1939) tensor(233.0972) tensor(-264.2910)
Losses:  tensor(-31.9311) tensor(173.2609) tensor(-205.1921)


RuntimeError: Lapack Error gesvd : 127 superdiagonals failed to converge. at c:\programdata\miniconda3\conda-bld\pytorch-cpu_1524541161962\work\aten\src\th\generic/THTensorLapack.c:470