# Word2Vec (Negative Sampling)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import nltk


In [2]:
np.__version__, torch.__version__

('2.4.1', '2.9.1+cpu')

In [3]:
import matplotlib
matplotlib.__version__

'3.10.8'

## Import datasets using nltk

In [5]:
from nltk.corpus import brown
nltk.download('brown')
print(brown.categories())

['adventure', 'belles_lettres', 'editorial', 'fiction', 'government', 'hobbies', 'humor', 'learned', 'lore', 'mystery', 'news', 'religion', 'reviews', 'romance', 'science_fiction']


[nltk_data] Downloading package brown to
[nltk_data]     C:\Users\svrat\AppData\Roaming\nltk_data...
[nltk_data]   Package brown is already up-to-date!


In [6]:
#1. tokenization
corpus = brown.sents(categories='news')
print(corpus[:5])  # Print first 5 sentences from the corpus

[['The', 'Fulton', 'County', 'Grand', 'Jury', 'said', 'Friday', 'an', 'investigation', 'of', "Atlanta's", 'recent', 'primary', 'election', 'produced', '``', 'no', 'evidence', "''", 'that', 'any', 'irregularities', 'took', 'place', '.'], ['The', 'jury', 'further', 'said', 'in', 'term-end', 'presentments', 'that', 'the', 'City', 'Executive', 'Committee', ',', 'which', 'had', 'over-all', 'charge', 'of', 'the', 'election', ',', '``', 'deserves', 'the', 'praise', 'and', 'thanks', 'of', 'the', 'City', 'of', 'Atlanta', "''", 'for', 'the', 'manner', 'in', 'which', 'the', 'election', 'was', 'conducted', '.'], ['The', 'September-October', 'term', 'jury', 'had', 'been', 'charged', 'by', 'Fulton', 'Superior', 'Court', 'Judge', 'Durwood', 'Pye', 'to', 'investigate', 'reports', 'of', 'possible', '``', 'irregularities', "''", 'in', 'the', 'hard-fought', 'primary', 'which', 'was', 'won', 'by', 'Mayor-nominate', 'Ivan', 'Allen', 'Jr.', '.'], ['``', 'Only', 'a', 'relative', 'handful', 'of', 'such', 'rep

In [7]:
#2. numeralization
#find unique words
flatten = lambda l: [item for sublist in l for item in sublist]
#assign unique integer
vocabs = list(set(flatten(corpus))) #all the words we have in the system - <UNK>

In [8]:
#create handy mapping between integer and word
word2index = {v:idx for idx, v in enumerate(vocabs)}
word2index['dog']

6186

In [9]:
vocabs.append('<UNK>')
word2index['<UNK>'] = 6

In [10]:
index2word = {v:k for k, v in word2index.items()}
index2word[5]

'Ratto'

## 2. Prepare train data

In [None]:
# #create pairs of center word, and outside word

# def random_batch(batch_size, corpus, window_size=2):

#     skipgrams = []

#     #loop each corpus
#     for doc in corpus:
#         #look from the 2nd word until second last word
#         for i in range(1, len(doc)-1):
#             #center word
#             center = word2index[doc[i]]
#             #outside words = 2 words
#             outside = (word2index[doc[i-1]], word2index[doc[i+1]])
#             #for each of these two outside words, we gonna append to a list
#             for each_out in outside:
#                 skipgrams.append([center, each_out])
#                 #center, outside1;   center, outside2
                
#     random_index = np.random.choice(range(len(skipgrams)), batch_size, replace=False)
    
#     inputs, labels = [], []
#     for index in random_index:
#         inputs.append([skipgrams[index][0]])
#         labels.append([skipgrams[index][1]])
        
#     return np.array(inputs), np.array(labels)
            
# x, y = random_batch(2, corpus)

In [34]:
def random_batch(batch_size, corpus, window_size=2):
    """
    Create skip-gram (center, outside) pairs using a dynamic window size.
    Default window_size = 2
    """
    skipgrams = []

    for doc in corpus:
        doc_len = len(doc)
        for i in range(doc_len):
            # center word index (fallback to <UNK> if missing)
            center = word2index.get(doc[i], word2index['<UNK>'])

            # dynamic context range
            start = max(0, i - window_size)
            end   = min(doc_len, i + window_size + 1)

            for j in range(start, end):
                if j == i:
                    continue
                outside = word2index.get(doc[j], word2index['<UNK>'])
                skipgrams.append([center, outside])

    random_index = np.random.choice(range(len(skipgrams)), batch_size, replace=False)

    inputs, labels = [], []
    for index in random_index:
        inputs.append([skipgrams[index][0]])
        labels.append([skipgrams[index][1]])

    return np.array(inputs), np.array(labels)

x, y = random_batch(2, corpus)


In [35]:
x.shape  #batch_size, 1

(2, 1)

In [36]:
x

array([[10895],
       [   32]])

In [37]:
y.shape  #batch_size 1

(2, 1)

## 3. Negative Sampling

### Unigram distribution

$$P(w)=U(w)^{3/4}/Z$$

In [39]:
z = 0.001

In [40]:
#count
from collections import Counter

word_count = Counter(flatten(corpus))
word_count

#get the total number of words
num_total_words = sum([c for w, c in word_count.items()])
num_total_words

100554

In [41]:
vocabs

['Wilbur',
 'subtitled',
 'violent',
 'trading',
 'Anniston',
 'Ratto',
 'declaring',
 'Sinfonica',
 'ordered',
 'mistakes',
 'mirror',
 'Marella',
 'tougher',
 '1926',
 'global',
 'Hugh',
 'yearbook',
 'Burgess',
 'bicycle-auto',
 'Rolnick',
 'deadline',
 'seven',
 'conductors',
 'lad',
 'Majesties',
 'Kong',
 'detachment',
 'Lawford',
 '23',
 'employ',
 '3-0',
 'action',
 'of',
 'invasion',
 'documentary',
 'during',
 'eventually',
 'sphynxes',
 'Churchill',
 'card',
 'sank',
 'attendants',
 'senator',
 'antiques',
 '108',
 'service',
 'dot',
 'Lawrenceville',
 'Wilkinson',
 'Q.',
 'wed',
 'Regional',
 'Beadle',
 'Amateur',
 '$4',
 'Martin',
 'three-fifths',
 'This',
 'robbing',
 'attempted',
 'lacking',
 'directors',
 "Center's",
 'potatoes',
 'Sixty-seven',
 'chain',
 'Norway',
 'collectors',
 'three-hour',
 'Outstanding',
 'keeping',
 'excitement',
 'SWC',
 'fueled',
 'pry',
 'acquire',
 'front',
 'franker',
 'masterful',
 'aware',
 'Bellows',
 'Ankara',
 'Nunes',
 'stations',
 'm

$$P(w)=U(w)^{3/4}/Z$$

In [42]:
unigram_table = []

for v in vocabs:
    uw = word_count[v] / num_total_words
    uw_alpha = int((uw ** 0.75) / z)
    unigram_table.extend([v] * uw_alpha)
    
Counter(unigram_table)

Counter({'the': 114,
         ',': 108,
         '.': 89,
         'of': 69,
         'to': 55,
         'and': 55,
         'a': 52,
         'in': 50,
         'for': 30,
         'The': 26,
         'that': 26,
         'was': 24,
         '``': 24,
         "''": 24,
         'is': 24,
         'on': 22,
         'at': 21,
         'be': 19,
         'with': 19,
         'as': 18,
         'by': 18,
         'he': 17,
         'will': 15,
         'his': 15,
         'said': 15,
         'it': 14,
         'from': 14,
         'are': 13,
         ';': 13,
         '--': 12,
         'has': 12,
         'an': 12,
         'had': 12,
         'were': 11,
         'this': 11,
         'who': 11,
         'Mrs.': 11,
         'have': 11,
         'not': 11,
         'their': 10,
         'which': 10,
         'would': 10,
         'been': 9,
         'they': 9,
         'He': 9,
         'out': 8,
         'up': 8,
         'last': 8,
         'but': 8,
         '(': 8,
         'its':

## 4. Model

$$\mathbf{J}_{\text{neg-sample}}(\mathbf{v}_c,o,\mathbf{U})=-\log(\sigma(\mathbf{u}_o^T\mathbf{v}_c))-\sum_{k=1}^K\log(\sigma(-\mathbf{u}_k^T\mathbf{v}_c))$$

In [43]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index['<UNK>'], seq))
    return torch.LongTensor(idxs)

In [44]:
import random

def negative_sampling(targets, unigram_table, k):
    batch_size = targets.shape[0]
    neg_samples = []
    for i in range(batch_size):  #(1, k)
        target_index = targets[i].item()
        nsample      = []
        while (len(nsample) < k):
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index:
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).reshape(1, -1))
        
    return torch.cat(neg_samples) #batch_size, k

In [45]:
batch_size = 2
x, y = random_batch(batch_size, corpus)
x_tensor = torch.LongTensor(x)
y_tensor = torch.LongTensor(y)

In [46]:
k = 5
neg_samples = negative_sampling(y_tensor, unigram_table, k)

In [47]:
y_tensor[1]

tensor([9605])

In [48]:
neg_samples[1]

tensor([14289, 11395,  5494, 11423,   970])

$$\mathbf{J}_{\text{neg-sample}}(\mathbf{v}_c,o,\mathbf{U})=-\log(\sigma(\mathbf{u}_o^T\mathbf{v}_c))-\sum_{k=1}^K\log(\sigma(-\mathbf{u}_k^T\mathbf{v}_c))$$

In [49]:
class SkipgramNeg(nn.Module):
    
    def __init__(self, voc_size, emb_size):
        super(SkipgramNeg, self).__init__()
        self.embedding_center  = nn.Embedding(voc_size, emb_size)
        self.embedding_outside = nn.Embedding(voc_size, emb_size)
        self.logsigmoid        = nn.LogSigmoid()
    
    def forward(self, center, outside, negative):
        #center, outside:  (bs, 1)
        #negative       :  (bs, k)
        
        center_embed   = self.embedding_center(center) #(bs, 1, emb_size)
        outside_embed  = self.embedding_outside(outside) #(bs, 1, emb_size)
        negative_embed = self.embedding_outside(negative) #(bs, k, emb_size)
        
        uovc           = outside_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, 1)
        ukvc           = -negative_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, k)
        ukvc_sum       = torch.sum(ukvc, 1).reshape(-1, 1) #(bs, 1)
        
        loss           = self.logsigmoid(uovc) + self.logsigmoid(ukvc_sum)
        
        return -torch.mean(loss)

In [50]:
#test your model
emb_size = 2
voc_size = len(vocabs)
model = SkipgramNeg(voc_size, emb_size)

In [51]:
loss = model(x_tensor, y_tensor, neg_samples)

In [52]:
loss

tensor(4.5837, grad_fn=<NegBackward0>)

## 5. Training

In [53]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [54]:
num_epochs = 10000

for epoch in range(num_epochs):
    
    #get batch
    # Example schedule (you can change this however you like)
    if epoch < 3000:
        window_size = 2
    elif epoch < 7000:
        window_size = 3
    else:
        window_size = 4

    input_batch, label_batch = random_batch(batch_size, corpus, window_size=window_size)
    
    input_tensor = torch.LongTensor(input_batch)
    label_tensor = torch.LongTensor(label_batch)
    
    #predict
    neg_samples = negative_sampling(label_tensor, unigram_table, k)
    loss = model(input_tensor, label_tensor, neg_samples)
    
    #backprogate
    optimizer.zero_grad()
    loss.backward()
    
    #update alpha
    optimizer.step()
    
    #print the loss
    if (epoch + 1) % 1000 == 0:
        print(f"Epoch {epoch+1:6.0f} | Loss: {loss:2.6f} | Window: {window_size}")

Epoch   1000 | Loss: 1.957078 | Window: 2
Epoch   2000 | Loss: 0.594019 | Window: 2
Epoch   3000 | Loss: 1.684340 | Window: 2
Epoch   4000 | Loss: 1.719735 | Window: 3
Epoch   5000 | Loss: 1.875870 | Window: 3
Epoch   6000 | Loss: 0.361528 | Window: 3
Epoch   7000 | Loss: 0.838764 | Window: 3
Epoch   8000 | Loss: 0.883517 | Window: 4
Epoch   9000 | Loss: 1.013314 | Window: 4
Epoch  10000 | Loss: 1.259255 | Window: 4


## 6. Plot the embeddings

In [55]:
vocabs

['Wilbur',
 'subtitled',
 'violent',
 'trading',
 'Anniston',
 'Ratto',
 'declaring',
 'Sinfonica',
 'ordered',
 'mistakes',
 'mirror',
 'Marella',
 'tougher',
 '1926',
 'global',
 'Hugh',
 'yearbook',
 'Burgess',
 'bicycle-auto',
 'Rolnick',
 'deadline',
 'seven',
 'conductors',
 'lad',
 'Majesties',
 'Kong',
 'detachment',
 'Lawford',
 '23',
 'employ',
 '3-0',
 'action',
 'of',
 'invasion',
 'documentary',
 'during',
 'eventually',
 'sphynxes',
 'Churchill',
 'card',
 'sank',
 'attendants',
 'senator',
 'antiques',
 '108',
 'service',
 'dot',
 'Lawrenceville',
 'Wilkinson',
 'Q.',
 'wed',
 'Regional',
 'Beadle',
 'Amateur',
 '$4',
 'Martin',
 'three-fifths',
 'This',
 'robbing',
 'attempted',
 'lacking',
 'directors',
 "Center's",
 'potatoes',
 'Sixty-seven',
 'chain',
 'Norway',
 'collectors',
 'three-hour',
 'Outstanding',
 'keeping',
 'excitement',
 'SWC',
 'fueled',
 'pry',
 'acquire',
 'front',
 'franker',
 'masterful',
 'aware',
 'Bellows',
 'Ankara',
 'Nunes',
 'stations',
 'm

In [56]:
Hugh = torch.LongTensor([word2index['Hugh']])
Hugh

tensor([15])

In [57]:
Hugh_embed_c = model.embedding_center(Hugh)
Hugh_embed_o = model.embedding_outside(Hugh)
Hugh_embed   = (Hugh_embed_c + Hugh_embed_o) / 2
Hugh_embed

tensor([[-0.1737, -0.2575]], grad_fn=<DivBackward0>)

In [58]:
Hugh_embed_o

tensor([[-0.1798,  1.0187]], grad_fn=<EmbeddingBackward0>)

In [59]:
def get_embed(word):
    try:
        index = word2index[word]
    except:
        index = word2index['<UNK>']
        
    word = torch.LongTensor([word2index[word]])
    
    embed_c = model.embedding_center(word)
    embed_o = model.embedding_outside(word)
    embed   = (embed_c + embed_o) / 2
    
    return embed[0][0].item(), embed[0][1].item()

In [50]:
get_embed('fruit')

(-1.6248319149017334, 0.41915708780288696)

In [51]:
get_embed('cat')

(1.3170452117919922, -0.7658721208572388)

In [52]:
get_embed('dog')

(1.7207331657409668, -0.2711341679096222)

In [53]:
get_embed('banana')

(-1.2168848514556885, 0.26804348826408386)

In [None]:
plt.figure(figsize=(6, 3))
for i, word in enumerate(vocabs):
    x, y = get_embed(word)
    plt.scatter(x, y)
    plt.annotate(word, xy=(x, y), xytext=(5, 2), textcoords='offset points')
plt.show()

KeyboardInterrupt: 

## 7. Cosine similarity

In [61]:
banana = get_embed('banana')
banana

(-1.2168848514556885, 0.26804348826408386)

In [62]:
fruit = get_embed('fruit')
fruit

(-1.6248319149017334, 0.41915708780288696)

In [63]:
cat = get_embed('cat')
cat

(1.3170452117919922, -0.7658721208572388)

In [66]:
np.array(banana) @ np.array(cat)

-1.8079794017507105

In [65]:
#more formally is to divide by its norm
def cosine_similarity(A, B):
    dot_product = np.dot(A, B)
    norm_a = np.linalg.norm(A)
    norm_b = np.linalg.norm(B)
    similarity = dot_product / (norm_a * norm_b)
    return similarity

print(cosine_similarity(np.array(banana), np.array(cat)))
print(cosine_similarity(np.array(banana), np.array(fruit)))

-0.9523630610360275
0.9993643506202321
