# Word2Vec (Negative Sampling)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
import nltk
from nltk.corpus import semcor, stopwords

nltk.download('semcor')
nltk.download('stopwords')

[nltk_data] Downloading package semcor to
[nltk_data]     C:\Users\surya\AppData\Roaming\nltk_data...
[nltk_data]   Package semcor is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\surya\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [3]:
torch.cuda.is_available(), torch.cuda.get_device_name(0)

(True, 'NVIDIA GeForce RTX 2080 with Max-Q Design')

In [4]:
semcor.words()

['The', 'Fulton', 'County', 'Grand', 'Jury', 'said', ...]

## 1. Load data

In [5]:
#1. tokenization
corpus = semcor.sents()

In [6]:
#2. numeralization
#find unique words
flatten = lambda l: [item for sublist in l for item in sublist]
#assign unique integer
vocabs = list(set(flatten(corpus))) #all the words we have in the system - <UNK>

In [7]:
#create handy mapping between integer and word
word2index = {v:idx for idx, v in enumerate(vocabs)}

In [8]:
vocabs.append('<UNK>')
word2index['<UNK>'] = len(vocabs) - 1  

In [9]:
index2word = {v:k for k, v in word2index.items()}

## 2. Prepare train data

In [10]:
#create pairs of center word, and outside word

def random_batch(batch_size, corpus, window_size=2):

    skipgrams = []

    #loop each corpus
    for doc in corpus:
        #look from the 2nd word until second last word
        for i in range(window_size, len(doc)-window_size):
            #center word
            center = word2index[doc[i]]
            
            #outside words (window size on both sides of the center word)
            outside_start =  i - window_size
            outside_end =  i + window_size + 1 # +1 because the end index is exclusive

            # outside = []
            # loop through the outside words, append to the list 'outside'
            for j in range(outside_start, outside_end):
                if i != j:  # Skip the center word
                    outside= word2index[doc[j]]
                    skipgrams.append([center, outside])

                
    random_index = np.random.choice(range(len(skipgrams)), batch_size, replace=False)
    
    inputs, labels = [], []
    for index in random_index:
        inputs.append([skipgrams[index][0]])
        labels.append([skipgrams[index][1]])
        
    return np.array(inputs), np.array(labels)
            
x, y = random_batch(2, corpus)

In [11]:
x.shape  #batch_size, 1

(2, 1)

In [12]:
x

array([[27139],
       [31084]])

In [13]:
y.shape  #batch_size 1

(2, 1)

## 3. Negative Sampling

### Unigram distribution

$$P(w)=U(w)^{3/4}/Z$$

In [14]:
z = 0.001

In [15]:
#count
from collections import Counter

word_count = Counter(flatten(corpus))
word_count

#get the total number of words
num_total_words = sum([c for w, c in word_count.items()])
num_total_words

820411

In [16]:
vocabs

['servant',
 '',
 'collaborator',
 'Polls',
 'keyhole',
 'pellets',
 'Ahm',
 'Longhorn',
 'worries',
 'attentive',
 'Bavaria',
 'expressionless',
 'too-large',
 'Interpretation',
 '14.7',
 'leather-hard',
 'beneath',
 'keening',
 'Ships',
 'DePaul',
 'Casualty',
 'Clean',
 'brash',
 'banner',
 'That-a-way',
 'suddenly',
 'Scores',
 'brave',
 'assertion',
 'osteoporosis',
 'disapprobation',
 'Eight',
 'anesthetically',
 'galleys',
 'Cal.',
 'novelists',
 'Arcade',
 'McSorley',
 'stultifying',
 'epiphyseal',
 'narcotic',
 'recount',
 'vaginal',
 'irreparable',
 'Andrena',
 'mediating',
 'preconceptions',
 'barnsful',
 'spots',
 'wakes',
 'beginning',
 'barrier',
 'Christiana',
 'uncomforatble',
 'inhabit',
 'living-room',
 'woo',
 'rekindling',
 'celiac',
 'Sixteenth',
 'imposes',
 'turn',
 'Yang',
 'electrostatic',
 'two-part',
 '1.4',
 'syllables',
 'stewardess',
 'wavers',
 'steelmakers',
 'elation',
 'glib',
 'monographs',
 'ship',
 'Le',
 'Chico',
 'driver',
 'taxi',
 'grandsons',
 

$$P(w)=U(w)^{3/4}/Z$$

In [17]:
unigram_table = []

for v in vocabs:
    uw = word_count[v] / num_total_words
    uw_alpha = int((uw ** 0.75) / z)
    unigram_table.extend([v] * uw_alpha)
    
Counter(unigram_table)

Counter({'the': 111,
         ',': 104,
         '.': 93,
         'of': 72,
         'and': 60,
         'to': 56,
         'a': 50,
         'in': 46,
         'that': 28,
         'is': 28,
         'was': 27,
         "''": 25,
         '``': 25,
         'for': 25,
         'it': 21,
         'with': 21,
         'The': 21,
         'he': 20,
         'be': 20,
         'as': 20,
         'on': 20,
         'his': 19,
         "'s": 18,
         'had': 16,
         'I': 16,
         'by': 16,
         'at': 16,
         'are': 15,
         'not': 15,
         'or': 14,
         'this': 14,
         'from': 14,
         '-': 14,
         'have': 13,
         'an': 13,
         'which': 12,
         'were': 12,
         'He': 11,
         'but': 11,
         'they': 11,
         'you': 11,
         'one': 11,
         'would': 11,
         'all': 10,
         'their': 10,
         'has': 10,
         'her': 10,
         ';': 10,
         'It': 9,
         'him': 9,
         '?': 9,


## 4. Model

$$\mathbf{J}_{\text{neg-sample}}(\mathbf{v}_c,o,\mathbf{U})=-\log(\sigma(\mathbf{u}_o^T\mathbf{v}_c))-\sum_{k=1}^K\log(\sigma(-\mathbf{u}_k^T\mathbf{v}_c))$$

In [18]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index['<UNK>'], seq))
    return torch.LongTensor(idxs)

In [19]:
import random

def negative_sampling(targets, unigram_table, k):
    batch_size = targets.shape[0]
    neg_samples = []
    for i in range(batch_size):  #(1, k)
        target_index = targets[i].item()
        nsample      = []
        while (len(nsample) < k):
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index:
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).reshape(1, -1))
        
    return torch.cat(neg_samples) #batch_size, k

In [20]:
batch_size = 2
x, y = random_batch(batch_size, corpus)
x_tensor = torch.LongTensor(x)
y_tensor = torch.LongTensor(y)

In [21]:
k = 5
neg_samples = negative_sampling(y_tensor, unigram_table, k)

In [22]:
y_tensor[1]

tensor([1980])

In [23]:
neg_samples[1]

tensor([ 3404,  3390, 24163, 38623, 12435])

$$\mathbf{J}_{\text{neg-sample}}(\mathbf{v}_c,o,\mathbf{U})=-\log(\sigma(\mathbf{u}_o^T\mathbf{v}_c))-\sum_{k=1}^K\log(\sigma(-\mathbf{u}_k^T\mathbf{v}_c))$$

In [24]:
class SkipgramNeg(nn.Module):
    
    def __init__(self, voc_size, emb_size):
        super(SkipgramNeg, self).__init__()
        self.embedding_center  = nn.Embedding(voc_size, emb_size)
        self.embedding_outside = nn.Embedding(voc_size, emb_size)
        self.logsigmoid        = nn.LogSigmoid()
    
    def forward(self, center, outside, negative):
        #center, outside:  (bs, 1)
        #negative       :  (bs, k)
        
        center_embed   = self.embedding_center(center) #(bs, 1, emb_size)
        outside_embed  = self.embedding_outside(outside) #(bs, 1, emb_size)
        negative_embed = self.embedding_outside(negative) #(bs, k, emb_size)
        
        uovc           = outside_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, 1)
        ukvc           = -negative_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, k)
        ukvc_sum       = torch.sum(ukvc, 1).reshape(-1, 1) #(bs, 1)
        
        loss           = self.logsigmoid(uovc) + self.logsigmoid(ukvc_sum)
        
        return -torch.mean(loss)
        

In [25]:
#test your model
emb_size = 2
voc_size = len(vocabs)
model = SkipgramNeg(voc_size, emb_size)

In [26]:
loss = model(x_tensor, y_tensor, neg_samples)

In [27]:
loss

tensor(1.2834, grad_fn=<NegBackward0>)

## 5. Training

In [28]:
import time

def log_epoch(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [29]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [30]:
start = time.time()
num_epochs = 100
window_size = 5

for epoch in range(num_epochs):
    
    #get batch
    input_batch, label_batch = random_batch(batch_size, corpus, window_size)
    input_tensor = torch.LongTensor(input_batch)
    label_tensor = torch.LongTensor(label_batch)
    
    #predict
    neg_samples = negative_sampling(label_tensor, unigram_table, k)
    loss = model(input_tensor, label_tensor, neg_samples)
    
    #backprogate
    optimizer.zero_grad()
    loss.backward()
    
    #update alpha
    optimizer.step()
    # Log epoch time
    epoch_mins, epoch_secs = log_epoch(start, time.time())
    #print the loss
    if (epoch + 1) % 10 == 0:
        print(f"Epoch: {epoch + 1} | Loss: {loss:.6f} | Time: {epoch_mins}m {epoch_secs}s")

Epoch: 10 | Loss: 1.673143 | Time: 2m 48s
Epoch: 20 | Loss: 1.891698 | Time: 5m 37s
Epoch: 30 | Loss: 2.914092 | Time: 8m 27s
Epoch: 40 | Loss: 1.774965 | Time: 11m 16s
Epoch: 50 | Loss: 1.323492 | Time: 14m 4s
Epoch: 60 | Loss: 1.392781 | Time: 16m 57s
Epoch: 70 | Loss: 2.789881 | Time: 19m 48s
Epoch: 80 | Loss: 3.034571 | Time: 22m 38s
Epoch: 90 | Loss: 1.278655 | Time: 25m 31s
Epoch: 100 | Loss: 3.414258 | Time: 28m 27s


## 6. Plot the embeddings

In [31]:
vocabs

['servant',
 '',
 'collaborator',
 'Polls',
 'keyhole',
 'pellets',
 'Ahm',
 'Longhorn',
 'worries',
 'attentive',
 'Bavaria',
 'expressionless',
 'too-large',
 'Interpretation',
 '14.7',
 'leather-hard',
 'beneath',
 'keening',
 'Ships',
 'DePaul',
 'Casualty',
 'Clean',
 'brash',
 'banner',
 'That-a-way',
 'suddenly',
 'Scores',
 'brave',
 'assertion',
 'osteoporosis',
 'disapprobation',
 'Eight',
 'anesthetically',
 'galleys',
 'Cal.',
 'novelists',
 'Arcade',
 'McSorley',
 'stultifying',
 'epiphyseal',
 'narcotic',
 'recount',
 'vaginal',
 'irreparable',
 'Andrena',
 'mediating',
 'preconceptions',
 'barnsful',
 'spots',
 'wakes',
 'beginning',
 'barrier',
 'Christiana',
 'uncomforatble',
 'inhabit',
 'living-room',
 'woo',
 'rekindling',
 'celiac',
 'Sixteenth',
 'imposes',
 'turn',
 'Yang',
 'electrostatic',
 'two-part',
 '1.4',
 'syllables',
 'stewardess',
 'wavers',
 'steelmakers',
 'elation',
 'glib',
 'monographs',
 'ship',
 'Le',
 'Chico',
 'driver',
 'taxi',
 'grandsons',
 

In [32]:
banana = torch.LongTensor([word2index['banana']])
banana

tensor([20139])

In [33]:
banana_embed_c = model.embedding_center(banana)
banana_embed_o = model.embedding_outside(banana)
banana_embed   = (banana_embed_c + banana_embed_o) / 2
banana_embed

tensor([[0.5464, 0.3359]], grad_fn=<DivBackward0>)

In [34]:
banana_embed_o

tensor([[1.5425, 0.5178]], grad_fn=<EmbeddingBackward0>)

In [35]:
def get_embed(word):
    try:
        index = word2index[word]
    except:
        index = word2index['<UNK>']
        
    word = torch.LongTensor([word2index[word]])
    
    embed_c = model.embedding_center(word)
    embed_o = model.embedding_outside(word)
    embed   = (embed_c + embed_o) / 2
    
    return embed[0][0].item(), embed[0][1].item()

In [36]:
get_embed('fruit')

(-0.21022072434425354, -0.05544298142194748)

In [37]:
get_embed('cat')

(-0.07699589431285858, 0.5209966897964478)

In [38]:
get_embed('dog')

(-0.25635766983032227, -0.24425120651721954)

In [39]:
get_embed('banana')

(0.5463566780090332, 0.33587732911109924)

## 7. Cosine similarity

In [40]:
banana = get_embed('banana')
banana

(0.5463566780090332, 0.33587732911109924)

In [41]:
fruit = get_embed('fruit')
fruit

(-0.21022072434425354, -0.05544298142194748)

In [42]:
cat = get_embed('cat')
cat

(-0.07699589431285858, 0.5209966897964478)

In [43]:
np.array(banana) @ np.array(cat)

0.13292375560744674

In [44]:
#more formally is to divide by its norm
def cosine_similarity(A, B):
    dot_product = np.dot(A, B)
    norm_a = np.linalg.norm(A)
    norm_b = np.linalg.norm(B)
    similarity = dot_product / (norm_a * norm_b)
    return similarity

print(cosine_similarity(np.array(banana), np.array(cat)))
print(cosine_similarity(np.array(banana), np.array(fruit)))

0.3935379003701135
-0.9572847872156872


## Export Model

In [45]:
import pickle

In [46]:
filename = 'skipgram_neg_sampling_model.pkl'
pickle.dump(model, open(filename, 'wb'))