https://stats.stackexchange.com/questions/253244/gradients-for-skipgram-word2vec                                      
https://rguigoures.github.io/word2vec_pytorch/

In [1]:
import re
import nltk
nltk.download('brown')
from nltk.corpus import brown
import itertools
from collections import Counter
import random, math
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

[nltk_data] Error loading brown: <urlopen error [Errno 11001]
[nltk_data]     getaddrinfo failed>


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [3]:
corpus = []

for cat in ['news']:
    for text_id in brown.fileids(cat):
        raw_text = list(itertools.chain.from_iterable(brown.sents(text_id)))
        text = ' '.join(raw_text)
        text = text.lower()
        text.replace('\n', ' ')
        text = re.sub('[^a-z ]+', '', text)
        corpus.append([w for w in text.split() if w != ''])

In [4]:
def subsample_frequent_words(corpus):
    filtered_corpus = []
    word_counts = dict(Counter(list(itertools.chain.from_iterable(corpus))))
    sum_word_counts = sum(list(word_counts.values()))
    word_counts = {word: word_counts[word]/float(sum_word_counts) for word in word_counts}
    for text in corpus:
        filtered_corpus.append([])
        for word in text:
            if random.random() < (1+math.sqrt(word_counts[word] * 1e3)) * 1e-3 / float(word_counts[word]):
                filtered_corpus[-1].append(word)
    return filtered_corpus

In [5]:
corpus = subsample_frequent_words(corpus)
vocabulary = set(itertools.chain.from_iterable(corpus))

word_to_index = {w: idx for (idx, w) in enumerate(vocabulary)}
index_to_word = {idx: w for (idx, w) in enumerate(vocabulary)}

In [6]:
context_tuple_list = []
w = 4

for text in corpus:
    for i, word in enumerate(text):
        first_context_word_index = max(0, i-w)
        last_context_word_index = min(i+w, len(text))
        for j in range(first_context_word_index, last_context_word_index):
            if i!=j:
                context_tuple_list.append((word, text[j]))
print("There are {} pairs of target and context words".format(len(context_tuple_list)))

There are 473231 pairs of target and context words


In [7]:
context_tuple_list[:10]

[('fulton', 'county'),
 ('fulton', 'grand'),
 ('fulton', 'jury'),
 ('county', 'fulton'),
 ('county', 'grand'),
 ('county', 'jury'),
 ('county', 'friday'),
 ('grand', 'fulton'),
 ('grand', 'county'),
 ('grand', 'jury')]

In [8]:
class Word2Vec(nn.Module):
    
    def __init__(self, embedding_size, vocab_size):
        super(Word2Vec, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, vocab_size)
    
    def forward(self, context_word):
        emb = self.embeddings(context_word)
        hidden = self.linear(emb)
        out = F.log_softmax(hidden)
        return out

In [9]:
vocabulary_size = len(vocabulary)

net = Word2Vec(embedding_size=100, vocab_size=vocabulary_size).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())
context_tensor_list = []

In [10]:
net

Word2Vec(
  (embeddings): Embedding(12132, 100)
  (linear): Linear(in_features=100, out_features=12132, bias=True)
)

In [11]:
for target, context in context_tuple_list:
    target_tensor = torch.LongTensor([word_to_index[target]])
    context_tensor = torch.LongTensor([word_to_index[context]])
    context_tensor_list.append((target_tensor.to(device), context_tensor.to(device)))

In [12]:
def get_batches(context_tuple_list, batch_size=64):
    random.shuffle(context_tuple_list)
    batches = []
    batch_target, batch_context = [], []
    for i in range(len(context_tuple_list)):
        batch_target.append(word_to_index[context_tuple_list[i][0]])
        batch_context.append(word_to_index[context_tuple_list[i][1]])
        if (i+1) % batch_size == 0 or i == len(context_tuple_list)-1:
            tensor_target = torch.LongTensor(torch.from_numpy(np.array(batch_target)).long()).to(device)
            tensor_context = torch.LongTensor(torch.from_numpy(np.array(batch_context)).long()).to(device)
            batches.append((tensor_target, tensor_context))
            batch_target, batch_context = [], []
    return batches

In [13]:
context_tuple_batches = get_batches(context_tuple_list, batch_size=2000)

In [14]:
len(context_tuple_batches), len(context_tuple_batches[0][0])

(237, 2000)

In [15]:
context_tuple_batches[:5]

[(tensor([3346,  378, 8288,  ..., 6968, 9338,  818], device='cuda:0'),
  tensor([3637, 6559, 2353,  ..., 3960, 2525,  472], device='cuda:0')),
 (tensor([ 8412, 10035,  3384,  ...,  8804,  5699,  4840], device='cuda:0'),
  tensor([ 6544, 11515,  1585,  ...,  2360,  6553,  4006], device='cuda:0')),
 (tensor([12098,  6655, 10820,  ...,  6296, 11056,  5454], device='cuda:0'),
  tensor([7646, 7468, 9525,  ..., 8095, 5006, 2711], device='cuda:0')),
 (tensor([9213, 5680, 9117,  ..., 3325, 9214, 3264], device='cuda:0'),
  tensor([ 7131, 10644,  7708,  ...,  3481,  8864,  4894], device='cuda:0')),
 (tensor([ 5495, 10505,  1724,  ...,  9960,  1520,  7436], device='cuda:0'),
  tensor([10606,  3865,  9404,  ...,  4503,  7975,  5211], device='cuda:0'))]

In [16]:
del context_tensor_list

In [17]:
epochs = 1000
for j in range(epochs):
    losses = []
    for i in range(len(context_tuple_batches)):
        net.zero_grad()
        target_tensor, context_tensor = context_tuple_batches[i]
        
        log_probs = net(context_tensor)
        loss = loss_function(log_probs, target_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
        if i%100 == 0:
            print(j, i, "Loss: ", np.mean(losses))

  # This is added back by InteractiveShellApp.init_path()


0 0 Loss:  9.573467254638672
0 100 Loss:  9.453773498535156
0 200 Loss:  9.38239717246288
1 0 Loss:  8.908522605895996
1 100 Loss:  8.913830067851755
1 200 Loss:  8.806369027095055
2 0 Loss:  8.284957885742188
2 100 Loss:  8.302826702004612
2 200 Loss:  8.23763742019881
3 0 Loss:  7.883403778076172
3 100 Loss:  7.931308614145411
3 200 Loss:  7.898256188008323
4 0 Loss:  7.642251014709473
4 100 Loss:  7.699016221679083
4 200 Loss:  7.67992416543154
5 0 Loss:  7.47418737411499
5 100 Loss:  7.533843474813027
5 200 Loss:  7.520986851175033
6 0 Loss:  7.343972682952881
6 100 Loss:  7.40339448192332
6 200 Loss:  7.393574560459573
7 0 Loss:  7.23533296585083
7 100 Loss:  7.292978522801163
7 200 Loss:  7.2848779645132185
8 0 Loss:  7.140255451202393
8 100 Loss:  7.195584084727977
8 200 Loss:  7.188632455038194
9 0 Loss:  7.054688453674316
9 100 Loss:  7.107595037705828
9 200 Loss:  7.101524865449364
10 0 Loss:  6.976510524749756
10 100 Loss:  7.027001258170251
10 200 Loss:  7.021672286797519
1

In [18]:
def get_closest_word(word, topn=5):
    word_distance = []
    emb = net.embeddings
    pdist = nn.PairwiseDistance()
    i = word_to_index[word]
    lookup_tensor_i = torch.tensor([i], dtype=torch.long).to(device)
    v_i = emb(lookup_tensor_i)
    for j in range(len(vocabulary)):
        if j != i:
            lookup_tensor_j = torch.tensor([j], dtype=torch.long).to(device)
            v_j = emb(lookup_tensor_j)
            word_distance.append((index_to_word[j], float(pdist(v_i, v_j))))
    word_distance.sort(key=lambda x: x[1])
    return word_distance[:topn]

In [20]:
get_closest_word('dog')

[('be', 13.522345542907715),
 ('expected', 13.62894344329834),
 ('and', 13.876252174377441),
 ('females', 13.88402271270752),
 ('wagons', 13.906343460083008)]