# Word2Vec (Negative Sampling)

In [16]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [17]:
np.__version__, torch.__version__

('1.26.4', '2.5.1+cu121')

In [18]:
import matplotlib
matplotlib.__version__

'3.10.0'

In [19]:
# Select the GPU with the most free memory
def get_free_gpu():
    # Check if CUDA is available before getting device count
    if torch.cuda.is_available():
        free_mem = [torch.cuda.memory_reserved(i) for i in range(torch.cuda.device_count())]
        return free_mem.index(min(free_mem))
    # If CUDA is not available, return -1 (or another appropriate value)
    else:
        return -1

best_gpu = get_free_gpu()

# Use the best GPU if available, otherwise use CPU
if best_gpu != -1:
    torch.cuda.set_device(best_gpu)
    print(f"Using GPU: {best_gpu}")
else:
    print("No CUDA-enabled GPUs found. Using CPU.")

Using GPU: 0


In [20]:
# Set the device variable
device = torch.device(f'cuda:{best_gpu}' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

## 1. Load data

In [21]:
import nltk
nltk.download('reuters')
nltk.download('punkt_tab')

[nltk_data] Downloading package reuters to /root/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [22]:
from nltk.corpus import reuters

In [23]:
reuters.words()

['ASIAN', 'EXPORTERS', 'FEAR', 'DAMAGE', 'FROM', 'U', ...]

In [24]:
# corpus = ["apple banana fruit", "banana apple fruit", "banana fruit apple", "grape apple apple",
#         "dog cat animal", "cat dog animal", "cat animal dog", "fish dog dog"]
corpus = reuters.sents()
corpus = corpus[:10000]
len(corpus)

10000

In [25]:
# #1. tokenization
# corpus = [sent.split(" ") for sent in corpus]
# corpus

In [26]:
#2. numeralization
#find unique words
flatten = lambda l: [item for sublist in l for item in sublist]
#assign unique integer
vocabs = list(set(flatten(corpus))) #all the words we have in the system - <UNK>

In [27]:
#create handy mapping between integer and word
word2index = {v:idx for idx, v in enumerate(vocabs)}
word2index['pattern']

13185

In [28]:
last_vocab_index = len(vocabs)
last_vocab_index

18045

In [29]:
vocabs.append('<UNK>')
word2index['<UNK>'] = last_vocab_index

In [30]:
index2word = {v:k for k, v in word2index.items()}
index2word[5]

'McGill'

## 2. Prepare train data

In [31]:
#create pairs of center word, and outside word

def random_batch(batch_size, corpus, window_size=2):

    skipgrams = []

    #loop each corpus
    for doc in corpus:
        #look from the 2nd word until second last word
        for i in range(window_size, len(doc)-window_size):
            #center word
            center = word2index[doc[i]]
            #outside words = 2 words
            outside = tuple(word2index[doc[j]] for j in range(i - window_size, i + window_size + 1) if j != i)
            # outside = (word2index[doc[i-2]], word2index[doc[i-1]], word2index[doc[i+1]], word2index[doc[i+2]])
            #for each of these two outside words, we gonna append to a list
            for each_out in outside:
                skipgrams.append([center, each_out])
                #center, outside1;   center, outside2

    random_index = np.random.choice(range(len(skipgrams)), batch_size, replace=False)

    inputs, labels = [], []
    for index in random_index:
        inputs.append([skipgrams[index][0]])
        labels.append([skipgrams[index][1]])

    return np.array(inputs), np.array(labels)

x, y = random_batch(2, corpus)

In [32]:
x.shape  #batch_size, 1

(2, 1)

In [33]:
x

array([[ 1555],
       [12522]])

In [34]:
y.shape  #batch_size 1

(2, 1)

## 3. Negative Sampling

### Unigram distribution

$$P(w)=U(w)^{3/4}/Z$$

In [35]:
z = 0.001

In [36]:
#count
from collections import Counter

word_count = Counter(flatten(corpus))
word_count

#get the total number of words
num_total_words = sum([c for w, c in word_count.items()])
num_total_words

315187

In [37]:
vocabs[:100]

['TLX',
 'Bond',
 'POSITIVE',
 'Greyhound',
 'Following',
 'McGill',
 'Stock',
 'proposals',
 'noteholders',
 'Contact',
 'Mailers',
 'EXPORTERS',
 'Dreyer',
 'confidence',
 'MIDEAST',
 'keen',
 'recession',
 'FEARS',
 'stern',
 'PACKAGE',
 'Drew',
 'Du',
 'speculators',
 'UNPROFITABLE',
 'an',
 'gestures',
 'uptrend',
 'companyies',
 'page',
 'Wilf',
 'Espinosa',
 '810',
 'According',
 'Interface',
 'measurement',
 'powers',
 'MOSCOW',
 'finally',
 '409p',
 'purposes',
 'GHW',
 'initiatives',
 'Sarney',
 'exercisable',
 'FUNDING',
 'processes',
 'withdrawing',
 'eliminate',
 'stockpiled',
 'Chit',
 'CAJAMARQUILLA',
 'creditors',
 'MANAGE',
 'revoked',
 'deposited',
 'CRITICAL',
 'plannned',
 'Ishihara',
 'lumber',
 'Bentsen',
 'happier',
 'pressuring',
 'Industrielle',
 'Lyle',
 '06',
 '869',
 'Deposits',
 '6p',
 'preferential',
 'redeem',
 'Discount',
 'stimulatory',
 'Boakes',
 'restocking',
 'KO',
 'INFO',
 'REPLACES',
 'Copperbelt',
 'Gary',
 'Minority',
 'suspend',
 'Buffer',
 'E

$$P(w)=U(w)^{3/4}/Z$$

In [38]:
unigram_table = []

for v in vocabs:
    uw = word_count[v] / num_total_words
    uw_alpha = int((uw ** 0.75) / z)
    unigram_table.extend([v] * uw_alpha)

Counter(unigram_table)

Counter({'an': 11,
         'Deposits': 1,
         '24': 3,
         'cash': 3,
         'If': 1,
         'TRADE': 1,
         'exports': 3,
         'heavy': 1,
         'pressure': 1,
         'want': 1,
         'Africa': 1,
         'Harcourt': 1,
         '03': 3,
         'francs': 1,
         'PRICE': 1,
         'shareholders': 2,
         'quarter': 6,
         'LDP': 1,
         'selling': 1,
         'Western': 1,
         'sharply': 1,
         '37': 2,
         'rate': 5,
         'STAKE': 1,
         'will': 12,
         'so': 2,
         'tomorrow': 1,
         'seen': 1,
         'early': 1,
         'due': 3,
         'growth': 4,
         'added': 4,
         'up': 6,
         'although': 1,
         'DLRS': 2,
         'CANADA': 1,
         'out': 3,
         'currently': 1,
         'manager': 1,
         'predicted': 1,
         'not': 11,
         'SETS': 1,
         'inflation': 2,
         'accord': 1,
         'approved': 1,
         'part': 2,
         '--':

## 4. Model

$$\mathbf{J}_{\text{neg-sample}}(\mathbf{v}_c,o,\mathbf{U})=-\log(\sigma(\mathbf{u}_o^T\mathbf{v}_c))-\sum_{k=1}^K\log(\sigma(-\mathbf{u}_k^T\mathbf{v}_c))$$

In [39]:
def prepare_sequence(seq, word2index):
    idxs = list(map(lambda w: word2index[w] if word2index.get(w) is not None else word2index['<UNK>'], seq))
    return torch.LongTensor(idxs)

In [40]:
import random

def negative_sampling(targets, unigram_table, k):
    batch_size = targets.shape[0]
    neg_samples = []
    for i in range(batch_size):  #(1, k)
        target_index = targets[i].item()
        nsample      = []
        while (len(nsample) < k):
            neg = random.choice(unigram_table)
            if word2index[neg] == target_index:
                continue
            nsample.append(neg)
        neg_samples.append(prepare_sequence(nsample, word2index).reshape(1, -1))

    return torch.cat(neg_samples) #batch_size, k

In [41]:
batch_size = 2
x, y = random_batch(batch_size, corpus)
x_tensor = torch.LongTensor(x)
y_tensor = torch.LongTensor(y)

In [42]:
k = 5
neg_samples = negative_sampling(y_tensor, unigram_table, k)

In [43]:
y_tensor[1]

tensor([11394])

In [44]:
neg_samples[1]

tensor([ 2802, 10719,  6500,  8090,  5713])

$$\mathbf{J}_{\text{neg-sample}}(\mathbf{v}_c,o,\mathbf{U})=-\log(\sigma(\mathbf{u}_o^T\mathbf{v}_c))-\sum_{k=1}^K\log(\sigma(-\mathbf{u}_k^T\mathbf{v}_c))$$

In [45]:
class SkipgramNeg(nn.Module):

    def __init__(self, voc_size, emb_size, word2index):
        super(SkipgramNeg, self).__init__()
        self.embedding_center  = nn.Embedding(voc_size, emb_size)
        self.embedding_outside = nn.Embedding(voc_size, emb_size)
        self.logsigmoid        = nn.LogSigmoid()
        self.word2index        = word2index

    def forward(self, center, outside, negative):
        #center, outside:  (bs, 1)
        #negative       :  (bs, k)

        center_embed   = self.embedding_center(center) #(bs, 1, emb_size)
        outside_embed  = self.embedding_outside(outside) #(bs, 1, emb_size)
        negative_embed = self.embedding_outside(negative) #(bs, k, emb_size)

        uovc           = outside_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, 1)
        ukvc           = -negative_embed.bmm(center_embed.transpose(1, 2)).squeeze(2) #(bs, k)
        ukvc_sum       = torch.sum(ukvc, 1).reshape(-1, 1) #(bs, 1)

        loss           = self.logsigmoid(uovc) + self.logsigmoid(ukvc_sum)

        return -torch.mean(loss)

    def get_embed(self, word):
      word2index = self.word2index
      try:
        index = word2index[word]
      except:
        index = word2index['<UNK>']

      word = torch.LongTensor([word2index[word]]).to(device)

      embed_c = self.embedding_center(word)
      embed_o = self.embedding_outside(word)
      embed   = (embed_c + embed_o) / 2

      return embed[0][0].item(), embed[0][1].item()

In [46]:
#test your model
emb_size = 2
voc_size = len(vocabs)
model = SkipgramNeg(voc_size, emb_size, word2index)

In [47]:
loss = model(x_tensor, y_tensor, neg_samples)

In [48]:
loss

tensor(1.6778, grad_fn=<NegBackward0>)

## 5. Training

In [49]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [50]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_time, elapsed_mins, elapsed_secs

In [51]:
import time

num_epochs = 100
total_time = 0
start_time = time.time()

for epoch in range(num_epochs):

    #get batch
    input_batch, label_batch = random_batch(batch_size, corpus)
    input_tensor = torch.LongTensor(input_batch)
    label_tensor = torch.LongTensor(label_batch)

    #predict
    neg_samples = negative_sampling(label_tensor, unigram_table, k)
    loss = model(input_tensor, label_tensor, neg_samples)

    #backprogate
    optimizer.zero_grad()
    loss.backward()

    #update alpha
    optimizer.step()

    #print the loss
    if (epoch + 1) % 10 == 0:
        end_time = time.time()
        total, epoch_mins, epoch_secs = epoch_time(start_time, end_time)
        total_time += total
        print(f"Epoch {epoch+1:6.0f} | Loss: {loss:2.6f} | time: {epoch_mins}m {epoch_secs}s")
        start_time = time.time()

Epoch     10 | Loss: 4.542861 | time: 0m 24s
Epoch     20 | Loss: 2.501101 | time: 0m 24s
Epoch     30 | Loss: 5.943749 | time: 0m 24s
Epoch     40 | Loss: 0.549180 | time: 0m 25s
Epoch     50 | Loss: 0.997851 | time: 0m 24s
Epoch     60 | Loss: 1.585015 | time: 0m 25s
Epoch     70 | Loss: 2.395042 | time: 0m 26s
Epoch     80 | Loss: 1.166544 | time: 0m 25s
Epoch     90 | Loss: 0.924157 | time: 0m 25s
Epoch    100 | Loss: 1.296758 | time: 0m 25s


In [52]:
# print total train loss and total training time
print(f"Total train loss: {loss:.6f}")
print(f"Total training time: {total_time:.2f} seconds")

Total train loss: 1.296758
Total training time: 253.03 seconds


## 6. Plot the embeddings

In [53]:
vocabs[:10]

['TLX',
 'Bond',
 'POSITIVE',
 'Greyhound',
 'Following',
 'McGill',
 'Stock',
 'proposals',
 'noteholders',
 'Contact']

In [54]:
frame = torch.LongTensor([word2index['frame']])
frame

tensor([16169])

In [55]:
frame_embed_c = model.embedding_center(frame)
frame_embed_o = model.embedding_outside(frame)
frame_embed   = (frame_embed_c + frame_embed_o) / 2
frame_embed

tensor([[0.5181, 1.2027]], grad_fn=<DivBackward0>)

In [56]:
frame_embed_o

tensor([[ 0.4644, -0.0138]], grad_fn=<EmbeddingBackward0>)

In [57]:
# def get_embed(word):
#     try:
#         index = word2index[word]
#     except:
#         index = word2index['<UNK>']

#     word = torch.LongTensor([word2index[word]])

#     embed_c = model.embedding_center(word)
#     embed_o = model.embedding_outside(word)
#     embed   = (embed_c + embed_o) / 2

#     return embed[0][0].item(), embed[0][1].item()

In [58]:
# get_embed('fruit')

In [59]:
# get_embed('cat')

In [60]:
# get_embed('dog')

In [61]:
# get_embed('banana')

In [62]:
# plt.figure(figsize=(6, 3))
# for i, word in enumerate(vocabs):
#     x, y = get_embed(word)
#     plt.scatter(x, y)
#     plt.annotate(word, xy=(x, y), xytext=(5, 2), textcoords='offset points')
# plt.show()

## 7. Save Model

In [63]:
torch.save(model.state_dict(), 'app/code/models/skipgram-neg.pt')

In [64]:
import pickle

skipgram_neg_args = {
    'word2index': word2index,
    'voc_size': voc_size,
    'emb_size': emb_size
}

pickle.dump(skipgram_neg_args, open('app/code/models/skipgrams-neg.pkl', 'wb'))

In [65]:
load_skipgram_neg_args = pickle.load(open('app/code/models/skipgrams-neg.pkl', 'rb'))
load_model = SkipgramNeg(**load_skipgram_neg_args).to(device)
load_model.load_state_dict(torch.load('app/code/models/skipgram-neg.pt'))

  load_model.load_state_dict(torch.load('app/code/models/skipgram-neg.pt'))


<All keys matched successfully>

In [66]:
load_model.get_embed('frame')

(0.5180826783180237, 1.202650547027588)

## 7. Cosine similarity

In [67]:
# banana = get_embed('banana')
# banana

In [68]:
# fruit = get_embed('fruit')
# fruit

In [69]:
# cat = get_embed('cat')
# cat

In [70]:
# np.array(banana) @ np.array(cat)

In [71]:
# #more formally is to divide by its norm
# def cosine_similarity(A, B):
#     dot_product = np.dot(A, B)
#     norm_a = np.linalg.norm(A)
#     norm_b = np.linalg.norm(B)
#     similarity = dot_product / (norm_a * norm_b)
#     return similarity

# print(cosine_similarity(np.array(banana), np.array(cat)))
# print(cosine_similarity(np.array(banana), np.array(fruit)))