## skipgram

In [1]:
import string
# read in data, remove punctuation and make all lower case
def read_words(filename):
    words = []
    translator = str.maketrans('', '', string.punctuation)
    with open(filename) as f:
        for s in f:
            clean_s = s.translate(translator).lower()
            words.append(clean_s.split())
    return words
        

In [2]:
# create dict that returns the index of onehot encoding for a word (and other way around)
# also create a frequency dict + set size, usuable for negative sampling
import numpy as np

def get_onehot_dicts(corpus):
    # create one set of all unique words
    flat_corpus = [w for s in corpus for w in s]
    corpus_set = set(flat_corpus)
    w_to_i = {}
    i_to_w = {}
    w_freq = []
    num_words = len(corpus_set)
    for i, w in enumerate(corpus_set):
        w_to_i[w] = i
        i_to_w[i] = w
        freq = flat_corpus.count(w)**0.75
        w_freq.append([i, freq])
    return w_to_i, i_to_w, np.array(w_freq), num_words

In [3]:
corpus = read_words('wa/test.en')
w_to_i, i_to_w, w_freq, num_words = get_onehot_dicts(corpus)

In [4]:
# create all positive samples
import torch
def get_pos_pairs(corpus, window_size):
    pairs = [] 
    for s in corpus:
        for i, word in enumerate(s):
            word_index = w_to_i[word]
            for j in range(i-window_size, i+window_size+1):
                if j < 0 or j >= len(s) or j == i:
                    continue
                context = s[j]
                context_index = w_to_i[context]
                pairs.append([word_index, context_index])
    return torch.LongTensor(pairs)

In [5]:
pairs = get_pos_pairs(corpus, 5)

In [6]:
def get_neg_pairs(num_samples, w_freq, pos_pairs):
    neg_dict = {} # pos-pair : neg-pairs
    total = w_freq[:,1].sum()
    prob = [freq/total for freq in w_freq[:, 1]]
    for pair in pos_pairs:
        word = pair[0].item()
        # get negative context words based on their frequencies
        neg_contexts = np.random.choice(w_freq[:,0], p=prob, size=num_samples) 
        neg_dict[tuple(pair.numpy())] = torch.LongTensor(neg_contexts)
    return neg_dict

In [7]:
neg_dict = get_neg_pairs(5, w_freq, pairs)

In [90]:
import torch.nn as nn
class skipgram(nn.Module):
    def __init__(self, vocab_size, emb_dimension):
        super(skipgram, self).__init__()
        self.vocab_size = vocab_size
        self.emb_dimension = emb_dimension
        # start with random embeddings
        self.W_embeddings = torch.randn((vocab_size, emb_dimension), requires_grad=True)
        self.C_embeddings = torch.randn((vocab_size, emb_dimension), requires_grad=True)
        
    def forward(self, word, pos_context, neg_contexts):
        # word, pos_context and neg_contexts are integers, so can just pick that row from W and C matrices
        word_embed = self.W_embeddings[word]
        pos_embed = self.C_embeddings[pos_context]
        neg_embeds = self.C_embeddings.index_select(0, torch.LongTensor(neg_contexts))
        
        pos_similarity = torch.matmul(word_embed, pos_embed).squeeze()
        pos_logsig = nn.functional.logsigmoid(pos_similarity)
        pos_score = pos_logsig
        
        neg_similarity = torch.matmul(word_embed, neg_embeds.transpose(0,1)).squeeze()
        neg_logsig = nn.functional.logsigmoid(-1 * neg_similarity)
        neg_score = sum(neg_logsig)
        
        loss = -(pos_score+neg_score)
        
        return loss
        

In [91]:
print(len(pairs))

49652


In [101]:
# train skipgram
import time
start_time = time.time()
iter_time = time.time()
sg_model = skipgram(num_words, 200)
optimizer = torch.optim.SGD([sg_model.W_embeddings, sg_model.C_embeddings], lr=0.01)

sg_model.train()
loss_sum = 0
for epoch in range(5):
    for i, pos_pair in enumerate(pairs):
        neg_pairs = neg_dict[tuple(pos_pair.numpy())]
        optimizer.zero_grad()
        loss = sg_model.forward(pos_pair[0], pos_pair[1], neg_pairs)
        loss.backward()
        optimizer.step()
        loss_sum += loss
        if i % 10000 == 0 and i!=0:
            print("loss at iteration ", i+len(pairs)*epoch, " is: ", loss_sum.item()/10000)
            print("this took ", time.time()-iter_time, " seconds")
            iter_time = time.time()
            loss_sum = 0
            
print("total time taken:", time.time()-start_time)

loss at iteration  10000  is:  26.768859375
this took  30.541802883148193  seconds
loss at iteration  20000  is:  21.4815078125
this took  30.86644697189331  seconds
loss at iteration  30000  is:  19.5975296875
this took  30.806231021881104  seconds
loss at iteration  40000  is:  18.66006875
this took  30.609694957733154  seconds
loss at iteration  59652  is:  27.0302
this took  60.07502889633179  seconds
loss at iteration  69652  is:  10.70293984375
this took  30.82400107383728  seconds
loss at iteration  79652  is:  10.6948125
this took  30.625030994415283  seconds
loss at iteration  89652  is:  10.64741171875
this took  30.868729829788208  seconds
loss at iteration  109304  is:  16.407815625
this took  60.18449592590332  seconds
loss at iteration  119304  is:  6.84880078125
this took  30.83335304260254  seconds
loss at iteration  129304  is:  7.0220234375
this took  30.854877948760986  seconds
loss at iteration  139304  is:  6.966021875
this took  30.714722633361816  seconds
loss at