In [1]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.functional as F
import torch.nn.functional as F

In [2]:

# Toy training corpus: seven short, lowercase, whitespace-separated sentences,
# written one per line and split into a list of strings.
corpus = """\
he is a king
she is a queen
he is a man
she is a woman
warsaw is poland capital
berlin is germany capital
paris is france capital""".splitlines()

In [3]:
def tokenize_corpus(corpus):
    """Split each sentence of ``corpus`` on whitespace.

    Returns a new list holding one list of word tokens per input sentence,
    in the same order as ``corpus``.
    """
    tokenized = []
    for sentence in corpus:
        tokenized.append(sentence.split())
    return tokenized

# Tokenize the toy corpus: one list of word tokens per sentence.
tokenized_corpus = tokenize_corpus(corpus)
print(tokenized_corpus)

[['he', 'is', 'a', 'king'], ['she', 'is', 'a', 'queen'], ['he', 'is', 'a', 'man'], ['she', 'is', 'a', 'woman'], ['warsaw', 'is', 'poland', 'capital'], ['berlin', 'is', 'germany', 'capital'], ['paris', 'is', 'france', 'capital']]


In [4]:
# Vocabulary = every distinct token, in order of first appearance in the corpus.
# dict.fromkeys de-duplicates while preserving insertion order in O(n),
# replacing the O(n^2) `token not in <list>` scan.
vocabulary = list(dict.fromkeys(
    token for sentence in tokenized_corpus for token in sentence
))

# Bidirectional lookup tables between words and their integer indices.
word2idx = {w: idx for idx, w in enumerate(vocabulary)}
idx2word = dict(enumerate(vocabulary))

vocabulary_size = len(vocabulary)

# Cheap invariant check: the two lookup tables are inverses of each other.
assert all(idx2word[word2idx[w]] == w for w in vocabulary)

In [5]:
# Display the word -> index mapping (indices follow first appearance in the corpus).
word2idx

{'he': 0,
 'is': 1,
 'a': 2,
 'king': 3,
 'she': 4,
 'queen': 5,
 'man': 6,
 'woman': 7,
 'warsaw': 8,
 'poland': 9,
 'capital': 10,
 'berlin': 11,
 'germany': 12,
 'paris': 13,
 'france': 14}

In [6]:
# Build skip-gram (center, context) index pairs: each word is paired with every
# other word at most `window_size` positions away inside the same sentence.
window_size = 2
idx_pairs = []
for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    n_words = len(indices)
    for center_pos, center_idx in enumerate(indices):
        # Clamp the window to the sentence boundaries rather than skipping
        # offsets that fall outside it; positions are visited left to right.
        first = max(center_pos - window_size, 0)
        last = min(center_pos + window_size + 1, n_words)
        for context_pos in range(first, last):
            if context_pos == center_pos:
                continue  # a word is not its own context
            idx_pairs.append((center_idx, indices[context_pos]))

idx_pairs = np.array(idx_pairs)  # numpy array: easy slicing; len() is the number of pairs

In [7]:
# Peek at the first ten (center, context) index pairs.
idx_pairs[:10]

array([[0, 1],
       [0, 2],
       [1, 0],
       [1, 2],
       [1, 3],
       [2, 0],
       [2, 1],
       [2, 3],
       [3, 1],
       [3, 2]])

![alt text](https://miro.medium.com/max/377/1*uYiqfNrUIzkdMrmkBWGMPw.png)

In [8]:
def get_input_layer(word_idx, size=None):
    """Return the one-hot input vector for a center word.

    The skip-gram input layer is just the center word encoded one-hot.

    Args:
        word_idx: index of the word in the vocabulary (a Python int or a
            numpy integer, e.g. an entry of ``idx_pairs``).
        size: length of the vector. Defaults to the global
            ``vocabulary_size``, so existing one-argument calls are unchanged.

    Returns:
        A 1-D float32 tensor of shape ``(size,)`` (not ``[1, size]``) that is
        1.0 at ``word_idx`` and 0.0 everywhere else.
    """
    if size is None:
        size = vocabulary_size
    x = torch.zeros(size, dtype=torch.float32)
    x[word_idx] = 1.0
    return x
  

In [9]:
# Train the two-layer skip-gram network with plain SGD, one (center, context)
# pair at a time. W1 maps the one-hot center word to its embedding (the word
# vectors we want to learn); W2 maps that embedding to one score per
# vocabulary word, which log_softmax + nll_loss compare to the context word.
embedding_dims = 5
num_epochs = 1010
learning_rate = 0.001

torch.manual_seed(0)  # make the random weight init reproducible on Restart & Run All
# Plain tensors with requires_grad replace the deprecated autograd.Variable wrapper.
W1 = torch.randn(embedding_dims, vocabulary_size, dtype=torch.float32, requires_grad=True)
W2 = torch.randn(vocabulary_size, embedding_dims, dtype=torch.float32, requires_grad=True)

for epo in range(num_epochs):
    loss_val = 0.0
    for data, target in idx_pairs:
        x = get_input_layer(data)                          # one-hot center word
        y_true = torch.tensor([target], dtype=torch.long)  # context word index, shape (1,)

        z1 = torch.matmul(W1, x)   # embedding of the center word, (embedding_dims,)
        z2 = torch.matmul(W2, z1)  # unnormalized score per vocabulary word, (vocabulary_size,)

        log_softmax = F.log_softmax(z2, dim=0)
        # nll_loss expects a (batch, classes) input, hence the view to (1, vocabulary_size).
        loss = F.nll_loss(log_softmax.view(1, -1), y_true)
        loss_val += loss.item()
        loss.backward()

        # Manual SGD step. Updating inside no_grad keeps it out of the autograd
        # graph; this replaces the discouraged in-place `.data` mutation.
        with torch.no_grad():
            W1 -= learning_rate * W1.grad
            W2 -= learning_rate * W2.grad
            W1.grad.zero_()
            W2.grad.zero_()
    if epo % 10 == 0:
        print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')

Loss at epo 0: 4.53072595340865
Loss at epo 10: 4.0549723642213005
Loss at epo 20: 3.7252400296075003
Loss at epo 30: 3.4815810373851233
Loss at epo 40: 3.292556229659489
Loss at epo 50: 3.1397426298686435
Loss at epo 60: 3.012433751991817
Loss at epo 70: 2.9044359173093524
Loss at epo 80: 2.81195342029844
Loss at epo 90: 2.7323426961898805
Loss at epo 100: 2.663464447430202
Loss at epo 110: 2.603437714917319
Loss at epo 120: 2.5506201131003245
Loss at epo 130: 2.50364739213671
Loss at epo 140: 2.4614409719194685
Loss at epo 150: 2.4231791981628965
Loss at epo 160: 2.3882445079939707
Loss at epo 170: 2.356171919618334
Loss at epo 180: 2.326606409038816
Loss at epo 190: 2.2992689217839923
Loss at epo 200: 2.2739326545170377
Loss at epo 210: 2.25040830884661
Loss at epo 220: 2.2285300740173883
Loss at epo 230: 2.2081513336726597
Loss at epo 240: 2.1891385316848755
Loss at epo 250: 2.1713689046246665
Loss at epo 260: 2.154729962348938
Loss at epo 270: 2.139118781260082
Loss at epo 280: 2.