In [131]:
import string
import numpy as np
import random
import torch
import time
# Wall-clock reference captured when this cell is executed.
start_time = time.time()

import nltk
# Make sure the NLTK stop-word corpus is available locally before it is read.
nltk.download("stopwords")
from nltk.corpus import stopwords

# English stop words as a set; make_dictionaries excludes these words from
# the vocabulary it builds.
stop_words = set(stopwords.words('english'))


# read in data, remove punctuation and make all lower case
def read_corpus(filename):
    """Load a text file as a list of tokenised, normalised sentences.

    Every line of the file becomes one entry of the result: ASCII
    punctuation (``string.punctuation``) is deleted, the text is lower-cased
    and then split on whitespace. Blank lines yield an empty list.

    Args:
        filename: path of the text file to read, one sentence per line.

    Returns:
        A list with one list of word strings per line of the file.
    """
    strip_punctuation = str.maketrans('', '', string.punctuation)
    with open(filename) as handle:
        return [
            line.translate(strip_punctuation).lower().split()
            for line in handle
        ]

# Build the word/index dictionaries and the word-frequency table
def make_dictionaries(corpus):
    """Build the word<->index lookup tables and frequency counts for a corpus.

    Words found in the module-level ``stop_words`` set are left out of the
    vocabulary. Index 0 is reserved for the unknown-word token ``'<unk>'``;
    vocabulary words receive the indices ``1..len(vocabulary)``. Indices are
    assigned in sorted word order so that the mapping (and anything pickled
    from it) is identical from one run to the next, instead of following the
    arbitrary iteration order of a set.

    Args:
        corpus: iterable of sentences, each an iterable of word strings.

    Returns:
        A tuple ``(w_to_i, i_to_w, w_freq, vocab_size)`` where
          * ``w_to_i`` maps word -> index and contains ``'<unk>' -> 0``;
          * ``i_to_w`` maps index -> word and contains ``0 -> '<unk>'``;
          * ``w_freq`` maps every vocabulary word to its corpus count;
          * ``vocab_size`` is the number of vocabulary words, NOT counting
            ``'<unk>'`` -- the largest index in use equals ``vocab_size``,
            so index-addressed tables need ``vocab_size + 1`` rows.
    """
    from collections import Counter

    # Single pass over the corpus: counts for every word, stop words included.
    w_freq_all = Counter(w for s in corpus for w in s)
    print("Num unique words: ", len(w_freq_all))

    vocabulary = sorted(w for w in w_freq_all if w not in stop_words)

    w_to_i = {}
    i_to_w = {}
    w_freq = {}
    for i, w in enumerate(vocabulary, start=1):
        w_to_i[w] = i
        i_to_w[i] = w
        w_freq[w] = w_freq_all[w]

    # Both directions must use the same spelling of the unknown token
    # (the forward mapping used to be stored under the truncated key '<unk').
    w_to_i['<unk>'] = 0
    i_to_w[0] = '<unk>'

    print("Num filtered vocabulary: ", len(vocabulary))

    return w_to_i, i_to_w, w_freq, len(vocabulary)

def make_word_windows(corpus, window_size):
    """Turn a tokenised corpus into fixed-size (centre, context) index windows.

    One window is produced for every word that has an entry in the
    module-level ``w_to_i`` mapping; words without an entry are never used as
    a centre word. Each window is a 1-D ``torch.long`` tensor of length
    ``2 * window_size + 1``:

        [centre, left_{window_size}, ..., left_1, right_1, ..., right_{window_size}]

    A context slot holds index 0 when the position falls outside the sentence
    or the word at that position has no entry in ``w_to_i``.

    Args:
        corpus: iterable of sentences, each a list of word strings.
        window_size: number of context positions taken on EACH side of the
            centre word.

    Returns:
        A list of ``torch.long`` tensors, one per usable centre word.
    """
    windows = []
    start = time.time()
    for sentence in corpus:
        n_words = len(sentence)
        for i, center_word in enumerate(sentence):
            if center_word not in w_to_i:
                continue

            window = [w_to_i[center_word]]  # centre word is stored first
            # Inclusive on both ends so the right-hand side gets as many
            # positions as the left-hand side.
            for pos in range(i - window_size, i + window_size + 1):
                if pos == i:  # skip the centre word itself
                    continue
                # Position 0 is a valid context position (first word).
                if 0 <= pos < n_words:
                    window.append(w_to_i.get(sentence[pos], 0))
                else:
                    window.append(0)

            # Build the tensor from Python ints so the dtype is integral
            # from the start rather than converted from a float buffer.
            windows.append(torch.tensor(window, dtype=torch.long))
    print('Context windows took %d seconds' % int(time.time()-start))
    return windows
    

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Sindi\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [135]:
# Hyperparameters shared by the cells below.
# Context half-width: number of positions considered on each side of a centre word.
WINDOW_SIZE = 4
# Size of the embedding vectors handed to bayesian_skipgram.
EMBEDDING_DIM = 128
# Mini-batch size hyperparameter.
BATCH_SIZE = 64
# Number-of-epochs hyperparameter.
NUM_EPOCHS = 10

In [None]:
# Read and tokenise the training corpus (one list of words per line).
corpus = read_corpus('hansards/training.en')
# Index mappings, per-word counts and the vocabulary size (V excludes the unknown token).
w_to_i, i_to_w, w_freq, V = make_dictionaries(corpus)
# Live view over the keys of w_to_i, i.e. every word that has an index.
vocabulary = w_to_i.keys()

Num unique words:  31602
Num filtered vocabulary:  31453


In [None]:

# One index tensor per usable centre word: centre index first, context indices after.
data = make_word_windows(corpus, WINDOW_SIZE)

In [None]:
import pickle

# Persist both lookup tables so the indices can be mapped back to words
# after this session has ended.
for pkl_path, mapping in (('w2i_bayesian.pkl', w_to_i),
                          ('i2w_bayesian.pkl', i_to_w)):
    with open(pkl_path, 'wb') as f:
        pickle.dump(mapping, f)

In [None]:
import torch.nn as nn
from torch import distributions

class bayesian_skipgram(nn.Module):
    """Bayesian skip-gram model with a Gaussian latent vector per occurrence.

    Encoder: the centre-word embedding is concatenated with each context-word
    embedding, passed through a linear layer + ReLU and summed over the
    context positions; two linear heads turn the result into the mean and
    (softplus-positive) standard deviation of a diagonal Gaussian posterior.
    Decoder: a latent sample is mapped to a categorical distribution over the
    vocabulary. A second pair of embedding tables provides a per-word
    Gaussian prior. ``forward`` returns the negative ELBO averaged over the
    batch.

    Args:
        num_words: number of rows in every embedding table and number of
            output classes; must be larger than the largest word index used.
        emb_dim: dimensionality of the embeddings and of the latent vector.
    """

    def __init__(self, num_words, emb_dim):
        super(bayesian_skipgram, self).__init__()

        self.num_words = num_words
        self.emb_dim = emb_dim

        # Deterministic input embeddings shared by centre and context words.
        self.R = nn.Embedding(num_words, emb_dim)
        # Per-word prior parameters (sigma is passed through softplus in forward).
        self.mu_prior = nn.Embedding(num_words, emb_dim)
        self.sigma_prior = nn.Embedding(num_words, emb_dim)

        # Encoder layers.
        self.M = nn.Linear(2*emb_dim, 2*emb_dim)
        self.affine_lambda_mu = nn.Linear(2*emb_dim, emb_dim)
        self.affine_lambda_sigma = nn.Linear(2*emb_dim, emb_dim)
        # Decoder: latent vector -> vocabulary logits.
        self.affine_theta = nn.Linear(emb_dim, num_words)

    def forward(self, word_idx, context_idx, use_cuda=None):
        """Compute the batch-averaged negative ELBO.

        Args:
            word_idx: long tensor with one centre-word index per batch item.
            context_idx: long tensor of shape (batch, n_context) holding the
                context-word indices of every batch item.
            use_cuda: ignored. Accepted only so that existing callers which
                pass this keyword keep working; every tensor created here
                lives on the device of the model outputs.

        Returns:
            Scalar tensor: mean KL(posterior || prior) minus the mean summed
            log-likelihood of the context words.
        """
        batch_size = word_idx.shape[0]
        n_context = context_idx.shape[1]

        # ********** Encoder ************
        # Pair the centre embedding with every context embedding.
        R_w = self.R(word_idx).view(batch_size, 1, self.emb_dim)
        R_w = R_w.expand(-1, n_context, -1)
        R_cj = self.R(context_idx)
        RcRw = torch.cat((R_w, R_cj), dim=2)

        h = torch.relu(self.M(RcRw)).sum(dim=1)

        mu = self.affine_lambda_mu(h)
        sigma = nn.functional.softplus(self.affine_lambda_sigma(h))

        # Reparametrization trick. randn_like draws independent noise for
        # every batch item and allocates it on the same device as mu.
        eps = torch.randn_like(mu)
        z = mu + sigma * eps

        # ********* Decoder ***********
        # log_softmax is the numerically stable form of log(softmax(.)).
        log_probs = nn.functional.log_softmax(self.affine_theta(z), dim=1)

        mu_prior = self.mu_prior(word_idx).view(batch_size, self.emb_dim)
        sigma_prior = nn.functional.softplus(
            self.sigma_prior(word_idx)).view(batch_size, self.emb_dim)

        # ********** Loss ************
        # Pick the log-probability of each context word and sum per item.
        likelihood_terms = log_probs.gather(1, context_idx).sum(dim=1)
        KL_div_terms = self.KL_div(mu_prior, sigma_prior, mu, sigma)

        total_loss = torch.mean(KL_div_terms) - torch.mean(likelihood_terms)

        return total_loss

    def KL_div(self, mu_p, sigma_p, mu, sigma):
        """KL divergence between diagonal Gaussians N(mu, sigma) and N(mu_p, sigma_p).

        The divergence is summed over the last dimension, so 1-D inputs give
        a scalar and (batch, dim) inputs give one value per batch item.

        Args:
            mu_p, sigma_p: mean and standard deviation of the prior.
            mu, sigma: mean and standard deviation of the posterior.

        Returns:
            Tensor with the last dimension reduced away.
        """
        div = torch.log(sigma_p + 1e-8) - torch.log(sigma+1e-8) + (sigma**2 + (mu - mu_p)**2) / (2*sigma_p**2) - 0.5
        return div.sum(dim=-1)

In [None]:
# TRAIN NETWORK
%matplotlib inline
import time
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

start_time = time.time()

use_cuda = torch.cuda.is_available()

BSG_model = bayesian_skipgram(V, EMBEDDING_DIM)

if use_cuda:
    BSG_model.cuda()
optimizer = torch.optim.Adam(BSG_model.parameters(), lr=0.005)
BSG_model.train()
loss_progress = []
iter_time = time.time()

dataloader = DataLoader(data, batch_size=16, shuffle=True)

print("N_batches", len(dataloader))

for epoch in range(1):
    for i, batch in enumerate(dataloader):
        batch_start = time.time()
        main_word = batch[:,0]
        context_word = batch[:,1:]

        #print("Main word:,", main_word.shape, context_word.shape)


        optimizer.zero_grad()

        if use_cuda:
            loss = BSG_model.forward(main_word.cuda(), context_word.cuda(), use_cuda=True)
        else:
            loss = BSG_model.forward(main_word, context_word)

        loss.backward()
        optimizer.step()
        batch_end = time.time() - batch_start
        if i % 10 == 0:
            print("epoch, batch ", epoch, i)
            loss_progress.append(loss.item())  
            print(time.time()-iter_time)
            print(loss)
            iter_time = time.time()
print("total time taken:", time.time()-start_time)
plt.plot(loss_progress)
plt.savefig('plot.png')
plt.show()

In [None]:
# Saves the model object itself (not only its state_dict) to disk.
torch.save(BSG_model, 'bayesian_skipgram.pt')