# Word2Vec (Skipgram )

In [49]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [50]:
from nltk.corpus import brown
brown.categories()
news_corpus = brown.sents(categories=['news'])
  

In [51]:
news_corpus

[['The', 'Fulton', 'County', 'Grand', 'Jury', 'said', 'Friday', 'an', 'investigation', 'of', "Atlanta's", 'recent', 'primary', 'election', 'produced', '``', 'no', 'evidence', "''", 'that', 'any', 'irregularities', 'took', 'place', '.'], ['The', 'jury', 'further', 'said', 'in', 'term-end', 'presentments', 'that', 'the', 'City', 'Executive', 'Committee', ',', 'which', 'had', 'over-all', 'charge', 'of', 'the', 'election', ',', '``', 'deserves', 'the', 'praise', 'and', 'thanks', 'of', 'the', 'City', 'of', 'Atlanta', "''", 'for', 'the', 'manner', 'in', 'which', 'the', 'election', 'was', 'conducted', '.'], ...]

In [52]:
np.__version__, torch.__version__

('1.26.4', '2.4.1')

In [53]:
import matplotlib
matplotlib.__version__

'3.9.1.post1'

## 1. Load data

In [54]:
#2. numeralization
#find unique words
flatten = lambda l: [item for sublist in l for item in sublist]
#assign unique integer
vocabs = list(set(flatten(news_corpus))) #all the words we have in the system - <UNK>

In [55]:
#create handy mapping between integer and word
word2index = {v:idx for idx, v in enumerate(vocabs)}
word2index['dog']

4695

In [56]:
vocabs.append('<UNK>')
word2index['<UNK>'] = 6

In [57]:
index2word = {v:k for k, v in word2index.items()}
index2word[5]

'meritorious'

In [58]:
templist = [i for i in range(10)]
templist[len(templist) - 10]

for i in range(1):
    print(i)


0


## 2. Prepare train data

In [59]:
#create pairs of center word, and outside word

def random_batch(batch_size, news_corpus,windows_size=2):

    skipgrams = []

    #loop each corpus
    for doc in news_corpus:
        #look from the 2nd word until second last word
        for i in range(windows_size, len(doc)-windows_size):
            #center word
            center = word2index[doc[i]]
            #outside words = 2 words
            outside = []
            for j in range(windows_size):
                outside.append(word2index[doc[i-j-1]])
                outside.append(word2index[doc[i+j+1]]) 

            #for each of these two outside words, we gonna append to a list
            for _,each_out in enumerate(outside):
                skipgrams.append([center, each_out])
                    #center, outside1;   center, outside2

                
    random_index = np.random.choice(range(len(skipgrams)), batch_size, replace=False)
    
    inputs, labels = [], []
    for index in random_index:
        inputs.append([skipgrams[index][0]])
        labels.append([skipgrams[index][1]])
        
    return np.array(inputs), np.array(labels)
            
x, y = random_batch(2, news_corpus,2)

In [60]:
x.shape  #batch_size, 1

(2, 1)

In [61]:
x

array([[9435],
       [8333]])

In [62]:
y.shape  #batch_size 1

(2, 1)

In [63]:
y

array([[1931],
       [6081]])

## Negative Sampling

In [64]:
z = 0.001

#count
from collections import Counter

word_count = Counter(flatten(news_corpus))
word_count

#get the total number of words
num_total_words = sum([c for w, c in word_count.items()])
num_total_words

unigram_table = []

for v in vocabs:
    uw = word_count[v] / num_total_words
    uw_alpha = int((uw ** 0.75) / z)
    unigram_table.extend([v] * uw_alpha)
    
Counter(unigram_table)

Counter({'the': 114,
         ',': 108,
         '.': 89,
         'of': 69,
         'and': 55,
         'to': 55,
         'a': 52,
         'in': 50,
         'for': 30,
         'The': 26,
         'that': 26,
         '``': 24,
         "''": 24,
         'is': 24,
         'was': 24,
         'on': 22,
         'at': 21,
         'be': 19,
         'with': 19,
         'by': 18,
         'as': 18,
         'he': 17,
         'will': 15,
         'his': 15,
         'said': 15,
         'it': 14,
         'from': 14,
         'are': 13,
         ';': 13,
         'has': 12,
         'had': 12,
         'an': 12,
         '--': 12,
         'Mrs.': 11,
         'this': 11,
         'not': 11,
         'were': 11,
         'who': 11,
         'have': 11,
         'their': 10,
         'which': 10,
         'would': 10,
         'been': 9,
         'they': 9,
         'He': 9,
         'or': 8,
         'I': 8,
         'more': 8,
         'but': 8,
         'last': 8,
         'out'

## Skipgram Model


In [65]:
len(vocabs)

14395

In [66]:
embedding = nn.Embedding(7, 2)

In [67]:
# x_tensor = torch.LongTensor(x)
# embedding(x_tensor).shape  #(batch_size, 1, emb_size)

$$P(o|c)=\frac{\exp(\mathbf{u_o^{\top}v_c})}{\sum_{w=1}^V\exp(\mathbf{u_w^{\top}v_c})}$$

In [68]:
class Skipgram(nn.Module):
    
    def __init__(self, voc_size, emb_size):
        super(Skipgram, self).__init__()
        self.embedding_center  = nn.Embedding(voc_size, emb_size)
        self.embedding_outside = nn.Embedding(voc_size, emb_size)
    
    def forward(self, center, outside, all_vocabs):
        center_embedding     = self.embedding_center(center)  #(batch_size, 1, emb_size)
        outside_embedding    = self.embedding_center(outside) #(batch_size, 1, emb_size)
        all_vocabs_embedding = self.embedding_center(all_vocabs) #(batch_size, voc_size, emb_size)
        
        top_term = torch.exp(outside_embedding.bmm(center_embedding.transpose(1, 2)).squeeze(2))
        #batch_size, 1, emb_size) @ (batch_size, emb_size, 1) = (batch_size, 1, 1) = (batch_size, 1) 

        lower_term = all_vocabs_embedding.bmm(center_embedding.transpose(1, 2)).squeeze(2)
        #batch_size, voc_size, emb_size) @ (batch_size, emb_size, 1) = (batch_size, voc_size, 1) = (batch_size, voc_size) 
        
        lower_term_sum = torch.sum(torch.exp(lower_term), 1)  #(batch_size, 1)
        
        loss = -torch.mean(torch.log(top_term / lower_term_sum))  #scalar
        
        return loss
        

In [69]:
#prepare all vocabs

batch_size = 2
voc_size   = len(vocabs)

def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index["<UNK>"], seq))
    return torch.LongTensor(idxs)

all_vocabs = prepare_sequence(list(vocabs), word2index).expand(batch_size, voc_size)
all_vocabs

tensor([[    0,     1,     2,  ..., 14392, 14393,     6],
        [    0,     1,     2,  ..., 14392, 14393,     6]])

In [70]:
model = Skipgram(voc_size, 2)
model

Skipgram(
  (embedding_center): Embedding(14395, 2)
  (embedding_outside): Embedding(14395, 2)
)

In [71]:
input_tensor = torch.LongTensor(x)
label_tensor = torch.LongTensor(y)

In [72]:
loss = model(input_tensor, label_tensor, all_vocabs)

In [73]:
loss

tensor(9.9786, grad_fn=<NegBackward0>)

## Skipgram Model (Neg Sampling)

In [74]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index['<UNK>'], seq))
    return torch.LongTensor(idxs)


import random

def negative_sampling(targets, unigram_table, k):
    batch_size = targets.shape[0]
    neg_samples = []
    for i in range(batch_size):  #(1, k)
        target_index = targets[i].item()
        nsample      = []
        while (len(nsample) < k):
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index:
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).reshape(1, -1))
        
    return torch.cat(neg_samples) #batch_size, k

batch_size = 2
x, y = random_batch(batch_size, news_corpus,2)
x_tensor = torch.LongTensor(x)
y_tensor = torch.LongTensor(y)

k = 5
neg_samples = negative_sampling(y_tensor, unigram_table, k)

In [75]:
class SkipgramNeg(nn.Module):
    
    def __init__(self, voc_size, emb_size):
        super(SkipgramNeg, self).__init__()
        self.embedding_center  = nn.Embedding(voc_size, emb_size)
        self.embedding_outside = nn.Embedding(voc_size, emb_size)
        self.logsigmoid        = nn.LogSigmoid()
    
    def forward(self, center, outside, negative):
        #center, outside:  (bs, 1)
        #negative       :  (bs, k)
        
        center_embed   = self.embedding_center(center) #(bs, 1, emb_size)
        outside_embed  = self.embedding_outside(outside) #(bs, 1, emb_size)
        negative_embed = self.embedding_outside(negative) #(bs, k, emb_size)
        
        uovc           = outside_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, 1)
        ukvc           = -negative_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, k)
        ukvc_sum       = torch.sum(ukvc, 1).reshape(-1, 1) #(bs, 1)
        
        loss           = self.logsigmoid(uovc) + self.logsigmoid(ukvc_sum)
        
        return -torch.mean(loss)

In [76]:
#test your model
emb_size = 2
voc_size = len(vocabs)
model_neg = SkipgramNeg(voc_size, emb_size)

In [77]:
loss = model_neg(x_tensor, y_tensor, neg_samples)

## 4. Training

In [78]:
batch_size = 2
emb_size   = 2
model      = Skipgram(voc_size, emb_size)
optimizer1  = optim.Adam(model.parameters(), lr=0.001)

In [79]:
optimizer2 = optim.Adam(model_neg.parameters(), lr=0.001)

In [80]:
num_epochs = 100

for epoch in range(num_epochs):
    
    #get batch
    input_batch, label_batch = random_batch(batch_size, news_corpus)
    input_tensor = torch.LongTensor(input_batch)
    label_tensor = torch.LongTensor(label_batch)
    
    #predict
    neg_samples = negative_sampling(label_tensor, unigram_table, k)
    loss2 = model_neg(x_tensor, y_tensor, neg_samples)
    loss1 = model(input_tensor, label_tensor, all_vocabs)
    
    #backprogate
    optimizer1.zero_grad()
    loss1.backward()
    optimizer2.zero_grad()
    loss2.backward()
    
    #update alpha
    optimizer1.step()
    optimizer2.step()
    
    #print the loss
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:6.0f} | Loss: {loss1:2.6f}")

Epoch     10 | Loss: 11.175952
Epoch     20 | Loss: 9.994502
Epoch     30 | Loss: 12.152360
Epoch     40 | Loss: 8.781167
Epoch     50 | Loss: 9.951160
Epoch     60 | Loss: 9.849049
Epoch     70 | Loss: 9.648080
Epoch     80 | Loss: 9.856827
Epoch     90 | Loss: 10.164693
Epoch    100 | Loss: 9.761404


## 5. Plot the embeddings

Is fruit really near to banana?
Is fruit really far from cat?

In [81]:
vocabs

['chandeliers',
 'Aeronautics',
 'La',
 'Prince',
 'holdings',
 'meritorious',
 '2,700,877',
 'stopovers',
 'carefully',
 'legitimate',
 'Hester',
 'swim',
 'lifetime',
 'comprehensive',
 'Philip',
 'fortune',
 'Music',
 'Ilona',
 'Gardner',
 'inadequacy',
 'Disarmament',
 'seconds',
 'speech',
 '11-3',
 'surprises',
 '2-3/4',
 'Super',
 'caused',
 'summitry',
 'reconstruction',
 'Times-Picayune',
 '67',
 'hampered',
 'exhibit',
 'Republicans',
 'wed',
 'expenditures',
 'Several',
 'enemies',
 'Refuses',
 'ant',
 'Interested',
 'choreography',
 'builtin',
 'tries',
 'embezzling',
 "Hall's",
 'decadence',
 'Turandot',
 'journalist',
 'Natalie',
 'decisive',
 'raising',
 'believing',
 'tube',
 'jury-tampering',
 'Stratton',
 'small',
 'Waveland',
 'alternatives',
 'Theatre',
 'lettered',
 'southpaw',
 'reckonings',
 '29-5',
 'understood',
 '420',
 'Ernie',
 'Life',
 'cornering',
 'Majesties',
 'cigaret',
 'at-bats',
 '$85',
 '11,744',
 'Revolutionary',
 'Winthrop',
 'however',
 'enrichin

In [82]:
tourists = torch.LongTensor([word2index['tourists']])
# banana = torch.LongTensor([word2index['banana']])
tourists

tensor([1965])

In [83]:
tourists_embed_c = model.embedding_center(tourists)
tourists_embed_o = model.embedding_outside(tourists)
tourists_embed   = (tourists_embed_c + tourists_embed_o) / 2
tourists_embed

tensor([[-1.2236, -0.0211]], grad_fn=<DivBackward0>)

In [84]:
tourists_embed_o

tensor([[-1.1424,  0.1548]], grad_fn=<EmbeddingBackward0>)

In [85]:
def get_embed(word):
    try:
        index = word2index[word]
    except:
        index = word2index['<UNK>']
        
    word = torch.LongTensor([word2index[word]])
    
    embed_c = model.embedding_center(word)
    embed_o = model.embedding_outside(word)
    embed   = (embed_c + embed_o) / 2
    
    return embed[0][0].item(), embed[0][1].item()

In [86]:
# get_embed('fruit')
get_embed('tourists')

(-1.2235623598098755, -0.021121598780155182)

In [87]:
get_embed('jury')

(-1.297802448272705, 0.8251262903213501)

In [88]:
get_embed('bedroom')

(-0.320132315158844, 0.9588786363601685)

In [89]:
get_embed('Gin')

(0.42304155230522156, 0.9374563694000244)

In [90]:
# x, y = get_embed(word)

In [91]:
# plt.figure(figsize=(6, 3))
# for i, word in enumerate(vocabs):
#     x, y = get_embed(word)
#     plt.scatter(x, y)
#     plt.annotate(word, xy=(x, y), xytext=(5, 2), textcoords='offset points')
# plt.show()

## 6. Cosine similarity

In [98]:
bedroom = get_embed('bedroom')
bedroom

(-0.320132315158844, 0.9588786363601685)

In [99]:
tourists = get_embed('tourists')
tourists

(-1.2235623598098755, -0.021121598780155182)

In [104]:
jury = get_embed('jury')
jury

(-1.297802448272705, 0.8251262903213501)

In [105]:
np.array(bedroom) @ np.array(jury)

1.2066644744726176

In [106]:
#more formally is to divide by its norm
def cosine_similarity(A, B):
    dot_product = np.dot(A, B)
    norm_a = np.linalg.norm(A)
    norm_b = np.linalg.norm(B)
    similarity = dot_product / (norm_a * norm_b)
    return similarity

print(cosine_similarity(np.array(tourists), np.array(jury)))
print(cosine_similarity(np.array(tourists), np.array(bedroom)))

0.83449561769138
0.30025963746734957
