https://stats.stackexchange.com/questions/253244/gradients-for-skipgram-word2vec                                      
https://rguigoures.github.io/word2vec_pytorch/

In [4]:
import re
import nltk
nltk.download('brown')
from nltk.corpus import brown
import itertools

[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\davidt\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!


In [5]:
corpus = []

for cat in ['news']:
    for text_id in brown.fileids(cat):
        raw_text = list(itertools.chain.from_iterable(brown.sents(text_id)))
        text = ' '.join(raw_text)
        text = text.lower()
        text.replace('\n', ' ')
        text = re.sub('[^a-z ]+', '', text)
        corpus.append([w for w in text.split() if w != ''])

In [9]:
from collections import Counter
import random, math

def subsample_frequent_words(corpus):
    filtered_corpus = []
    word_counts = dict(Counter(list(itertools.chain.from_iterable(corpus))))
    sum_word_counts = sum(list(word_counts.values()))
    word_counts = {word: word_counts[word]/float(sum_word_counts) for word in word_counts}
    for text in corpus:
        filtered_corpus.append([])
        for word in text:
            if random.random() < (1+math.sqrt(word_counts[word] * 1e3)) * 1e-3 / float(word_counts[word]):
                filtered_corpus[-1].append(word)
    return filtered_corpus

In [10]:
corpus = subsample_frequent_words(corpus)
vocabulary = set(itertools.chain.from_iterable(corpus))

word_to_index = {w: idx for (idx, w) in enumerate(vocabulary)}
index_to_word = {idx: w for (idx, w) in enumerate(vocabulary)}

In [14]:
import numpy as np

context_tuple_list = []
w = 4

for text in corpus:
    for i, word in enumerate(text):
        first_context_word_index = max(0, i-w)
        last_context_word_index = min(i+w, len(text))
        for j in range(first_context_word_index, last_context_word_index):
            if i!=j:
                context_tuple_list.append((word, text[j]))
print("There are {} pairs of target and context words".format(len(context_tuple_list)))

There are 474008 pairs of target and context words


In [15]:
context_tuple_list

[('fulton', 'county'),
 ('fulton', 'grand'),
 ('fulton', 'jury'),
 ('county', 'fulton'),
 ('county', 'grand'),
 ('county', 'jury'),
 ('county', 'said'),
 ('grand', 'fulton'),
 ('grand', 'county'),
 ('grand', 'jury'),
 ('grand', 'said'),
 ('grand', 'friday'),
 ('jury', 'fulton'),
 ('jury', 'county'),
 ('jury', 'grand'),
 ('jury', 'said'),
 ('jury', 'friday'),
 ('jury', 'an'),
 ('said', 'fulton'),
 ('said', 'county'),
 ('said', 'grand'),
 ('said', 'jury'),
 ('said', 'friday'),
 ('said', 'an'),
 ('said', 'investigation'),
 ('friday', 'county'),
 ('friday', 'grand'),
 ('friday', 'jury'),
 ('friday', 'said'),
 ('friday', 'an'),
 ('friday', 'investigation'),
 ('friday', 'atlantas'),
 ('an', 'grand'),
 ('an', 'jury'),
 ('an', 'said'),
 ('an', 'friday'),
 ('an', 'investigation'),
 ('an', 'atlantas'),
 ('an', 'recent'),
 ('investigation', 'jury'),
 ('investigation', 'said'),
 ('investigation', 'friday'),
 ('investigation', 'an'),
 ('investigation', 'atlantas'),
 ('investigation', 'recent'),
 ('

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [31]:
class Word2Vec(nn.Module):
    
    def __init__(self, embedding_size, vocab_size):
        super(Word2Vec, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, vocab_size)
    
    def forward(self, context_word):
        emb = self.embeddings(context_word)
        hidden = self.linear(emb)
        out = F.log_softmax(hidden)
        return out

In [27]:
class EarlyStopping():
    def __init__(self, patience=5, min_percent_gain=0.1):
        self.patience = patience
        self.loss_list = []
        self.min_percent_gain = min_percent_gain/100.
    
    def update_loss(self, loss):
        self.loss_list.append(loss)
        if len(self.loss_list) > self.patience:
            del self.loss_list[0]
    
    def stop_training(self):
        if len(self.loss_list) == 1:
            return False
        gain = (max(self.loss_list)-min(self.loss_list))/max(self.loss_list)
        print("Loss gain: {}%".format(round(100*gain,2)))
        if gain < self.min_percent_gain:
            return True
        else:
            return False

In [36]:
vocabulary_size = len(vocabulary)

net = Word2Vec(embedding_size=2, vocab_size=vocabulary_size)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())
early_stopping = EarlyStopping()
context_tensor_list = []

for target, context in context_tuple_list:
    target_tensor = torch.LongTensor([word_to_index[target]])
    context_tensor = torch.LongTensor([word_to_index[context]])
    context_tensor_list.append((target_tensor, context_tensor))

In [37]:
len(context_tensor_list)

474008

In [38]:
while True:
    losses = []
    for target_tensor, context_tensor in context_tensor_list:
        net.zero_grad()
        log_probs = net(context_tensor)
        loss = loss_function(log_probs, target_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.data)ji
    print("Loss: ", np.mean(losses))
    early_stopping.update_loss(np.mean(losses))
    if early_stopping.stop_training():
        break

  # This is added back by InteractiveShellApp.init_path()


KeyboardInterrupt: 