<a href="https://colab.research.google.com/github/ZvoneST/pytorch-labs/blob/master/lovro_w2v.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.functional as F
import torch.nn.functional as F

In [None]:
corpus = [
    'he is a king',
    'she is a queen',
    'he is a man',
    'she is a woman',
    'warsaw is poland capital',
    'berlin is germany capital',
    'paris is france capital',
    'algebra is in zagreb.',
    'zagreb is in croatia.' 
]

In [None]:
def tokenize_corpus(corpus):
    tokens = [x.split() for x in corpus]
    return tokens

tokenized_corpus = tokenize_corpus(corpus)
print(tokenized_corpus)

[['he', 'is', 'a', 'king'], ['she', 'is', 'a', 'queen'], ['he', 'is', 'a', 'man'], ['she', 'is', 'a', 'woman'], ['warsaw', 'is', 'poland', 'capital'], ['berlin', 'is', 'germany', 'capital'], ['paris', 'is', 'france', 'capital'], ['algebra', 'is', 'in', 'zagreb.'], ['zagreb', 'is', 'in', 'croatia.']]


In [None]:
vocabulary = []
for sentence in tokenized_corpus:
    for token in sentence:
        if token not in vocabulary:
            vocabulary.append(token)

word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}

vocabulary_size = len(vocabulary)

In [None]:
window_size = 2
idx_pairs = []
# for each sentence
for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    # for each word, threated as center word
    for center_word_pos in range(len(indices)):
        # for each window position
        for w in range(-window_size, window_size + 1):
            context_word_pos = center_word_pos + w
            # make soure not jump out sentence
            if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:
                continue
            context_word_idx = indices[context_word_pos]
            idx_pairs.append((indices[center_word_pos], context_word_idx))

idx_pairs = np.array(idx_pairs) # it will be useful to have this as numpy array

In [None]:
def get_input_layer(word_idx):
    x = torch.zeros(vocabulary_size).float()
    x[word_idx] = 1.0
    return x

In [None]:
embedding_dims = 5
W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(), requires_grad=True)
W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(), requires_grad=True)
num_epochs = 201
learning_rate = 0.001

for epo in range(num_epochs):
    loss_val = 0
    for data, target in idx_pairs:
        x = Variable(get_input_layer(data)).float()
        y_true = Variable(torch.from_numpy(np.array([target])).long())

        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)
    
        log_softmax = F.log_softmax(z2, dim=0)

        loss = F.nll_loss(log_softmax.view(1,-1), y_true)
        loss_val += loss.item()
        loss.backward()
        W1.data -= learning_rate * W1.grad.data
        W2.data -= learning_rate * W2.grad.data

        W1.grad.data.zero_()
        W2.grad.data.zero_()
    if epo % 10 == 0:    
        print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')

Loss at epo 0: 5.040851533909639
Loss at epo 10: 4.2932516568236885
Loss at epo 20: 4.0051409562428795
Loss at epo 30: 3.8186696502897473
Loss at epo 40: 3.6770777874522738
Loss at epo 50: 3.564348969194624
Loss at epo 60: 3.472456865840488
Loss at epo 70: 3.396070056491428
Loss at epo 80: 3.331357987721761
Loss at epo 90: 3.2755567769209546
Loss at epo 100: 3.226670405599806
Loss at epo 110: 3.183243382639355
Loss at epo 120: 3.1441958142651454
Loss at epo 130: 3.1087102830410003
Loss at epo 140: 3.076154669125875
Loss at epo 150: 3.0460299293200177
Loss at epo 160: 3.0179349369472925
Loss at epo 170: 2.9915415949291653
Loss at epo 180: 2.966576886177063
Loss at epo 190: 2.9428104552957746
Loss at epo 200: 2.9200454857614306


In [None]:
W1

tensor([[-7.6717e-01,  3.1300e-01, -7.8656e-02, -1.0394e+00,  1.6313e-01,
          5.7106e-01,  5.2958e-01, -6.4902e-01, -2.8531e-01, -3.9432e-01,
          4.4901e-01,  1.0763e+00,  1.8754e+00, -7.1248e-01,  1.9051e+00,
         -1.4991e-01, -4.5832e-01,  7.8424e-01,  2.1628e+00, -9.0415e-01],
        [-5.8844e-01, -4.2341e-03,  6.3293e-01,  1.2036e+00, -1.1293e+00,
         -5.5735e-01, -3.8308e-01, -1.1101e+00, -1.2589e-01,  8.8074e-01,
         -6.2040e-01,  9.6201e-01, -1.0301e+00,  2.5325e+00,  3.0698e-01,
          6.4078e-01, -7.5463e-01, -5.7654e-01,  1.0104e+00, -1.8471e-01],
        [ 1.8943e+00,  4.1081e-01, -6.7337e-01, -7.8644e-01,  8.8529e-01,
         -4.5183e-01,  1.1362e+00,  5.3620e-01, -3.3223e-01,  3.1357e-01,
         -9.2043e-01,  6.5207e-01, -2.6594e-01, -3.0699e-01, -5.4055e-04,
         -1.2905e+00,  8.1839e-01, -1.4937e+00, -4.3169e-02, -8.9843e-01],
        [-2.6818e-01,  1.6009e-02, -1.7224e-01,  2.3696e-02, -1.1585e+00,
         -3.4978e-02,  1.6333e-01, 