In [1]:
import torch
import numpy as np
import pandas as pd

In [2]:
corpus = ['king is a strong man', 
          'queen is a wise woman', 
          'boy is a young man',
          'girl is a young woman',
          'prince is a young king',
          'princess is a young queen',
          'man is strong', 
          'woman is pretty',
          'prince is a boy will be king',
          'princess is a girl will be queen']

In [3]:
def remove_stop_words(corpus):
    stop_words = ["is", "a", "will", "be"]
    results = []
    for text in corpus:
        tmp = text.split(' ')
        for stop_word in stop_words:
            if stop_word in tmp:
                tmp.remove(stop_word)
        results.append(" ".join(tmp))
    return results

In [4]:
corpus = remove_stop_words(corpus)

In [5]:
def get_set(corpus):
    words = []
    for line in corpus:
        for word in line.split(' '):
            words.append(word)
    return set(words)

In [6]:
words = get_set(corpus)
words

{'boy',
 'girl',
 'king',
 'man',
 'pretty',
 'prince',
 'princess',
 'queen',
 'strong',
 'wise',
 'woman',
 'young'}

## Skip-gram part

In [7]:
word2int = {}
for i, word in enumerate(words):
    word2int[word] = i
word2int

{'prince': 0,
 'girl': 1,
 'queen': 2,
 'man': 3,
 'boy': 4,
 'young': 5,
 'pretty': 6,
 'woman': 7,
 'princess': 8,
 'strong': 9,
 'king': 10,
 'wise': 11}

In [8]:
sentences = []
for sentence in corpus:
    sentences.append(sentence.split(' '))
    
sentences

[['king', 'strong', 'man'],
 ['queen', 'wise', 'woman'],
 ['boy', 'young', 'man'],
 ['girl', 'young', 'woman'],
 ['prince', 'young', 'king'],
 ['princess', 'young', 'queen'],
 ['man', 'strong'],
 ['woman', 'pretty'],
 ['prince', 'boy', 'king'],
 ['princess', 'girl', 'queen']]

In [9]:
data = []
window_size = 2
for sentence in sentences:
    for i, target in enumerate(sentence):
        for word in sentence[max(0, i-window_size):min(len(sentence), i + window_size)+1]:
            if target != word:
                data.append([target, word])

In [10]:
df = pd.DataFrame(data, columns=['input', 'label'])
df.head(10)

Unnamed: 0,input,label
0,king,strong
1,king,man
2,strong,king
3,strong,man
4,man,king
5,man,strong
6,queen,wise
7,queen,woman
8,wise,queen
9,wise,woman


In [11]:
df.shape

(52, 2)

In [12]:
word2int

{'prince': 0,
 'girl': 1,
 'queen': 2,
 'man': 3,
 'boy': 4,
 'young': 5,
 'pretty': 6,
 'woman': 7,
 'princess': 8,
 'strong': 9,
 'king': 10,
 'wise': 11}

## PyTorch things

In [13]:
class Word2Vec(torch.nn.Module):
    def __init__(self, ONE_HOT_DIM, EMBEDDING_DIM):
        super(Word2Vec, self).__init__()
        self.linear1 = torch.nn.Linear(ONE_HOT_DIM, EMBEDDING_DIM)
        self.linear2 = torch.nn.Linear(EMBEDDING_DIM, ONE_HOT_DIM)
        self.softmax = torch.nn.Softmax()
        
    def forward(self, x):
        z1 = self.linear1(x)
        z2 = self.linear2(z1)
        y_pred = self.softmax(z2)
        return y_pred

In [14]:
ONE_HOT_DIM = len(words)

# convert numbers to one hot vectors
def to_one_hot_encoding(index):
    one_hot_encoding = np.zeros(ONE_HOT_DIM)
    one_hot_encoding[index] = 1
    return one_hot_encoding

X = [] # input word
Y = [] # target word

for x, y in zip(df['input'], df['label']):
    X.append(to_one_hot_encoding(word2int[x]))
    Y.append(to_one_hot_encoding(word2int[y]))
    
X_train = np.float32(X)
Y_train = np.float32(Y)

X_train = torch.from_numpy(X_train)
Y_train = torch.from_numpy(Y_train)

print(X_train.shape)
print(Y_train.shape)

print(X_train[0])
print(Y_train[0])

EMBEDDING_DIM = 2

# model = torch.nn.Sequential(
#     torch.nn.Linear(ONE_HOT_DIM, EMBEDDING_DIM),
#     torch.nn.Linear(EMBEDDING_DIM, ONE_HOT_DIM),
#     torch.nn.Softmax()
# )

model = Word2Vec(ONE_HOT_DIM, EMBEDDING_DIM)

loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 0.05

torch.Size([52, 12])
torch.Size([52, 12])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])


In [15]:
for t in range(500):
    
    y_pred = model(X_train)
    loss = loss_fn(y_pred, Y_train)
    print(t, loss.item())
    
    model.zero_grad()
    
    loss.backward()
    
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

0 48.39484405517578
1 48.118446350097656
2 47.92127990722656
3 47.7667350769043
4 47.64070129394531
5 47.53601837158203
6 47.44791793823242
7 47.372676849365234
8 47.30727005004883
9 47.24929428100586
10 47.19692611694336
11 47.148826599121094
12 47.10403823852539
13 47.061885833740234
14 47.02190017700195
15 46.98372268676758
16 46.94709777832031
17 46.91181182861328
18 46.877685546875
19 46.84457015991211
20 46.812320709228516
21 46.78081130981445
22 46.749916076660156
23 46.71951675415039
24 46.68950271606445
25 46.659759521484375
26 46.63018035888672
27 46.60065841674805
28 46.57109069824219
29 46.54137420654297
30 46.51140594482422
31 46.481082916259766
32 46.45030975341797
33 46.418983459472656
34 46.38700866699219
35 46.354286193847656
36 46.320716857910156
37 46.28620529174805
38 46.25065612792969
39 46.21398162841797
40 46.17608642578125
41 46.13688659667969
42 46.09630584716797
43 46.054264068603516
44 46.01070785522461
45 45.965576171875
46 45.91883850097656
47 45.8704719543

  # This is added back by InteractiveShellApp.init_path()


354 39.292633056640625
355 39.28978729248047
356 39.28697204589844
357 39.28419494628906
358 39.28145217895508
359 39.27873992919922
360 39.27606201171875
361 39.273414611816406
362 39.27079772949219
363 39.268211364746094
364 39.265655517578125
365 39.263126373291016
366 39.26062774658203
367 39.25815200805664
368 39.255706787109375
369 39.25328826904297
370 39.250892639160156
371 39.2485237121582
372 39.24618148803711
373 39.24386215209961
374 39.2415657043457
375 39.23929214477539
376 39.23704147338867
377 39.23481750488281
378 39.232608795166016
379 39.23042678833008
380 39.22826385498047
381 39.22612380981445
382 39.224002838134766
383 39.221900939941406
384 39.219818115234375
385 39.21775817871094
386 39.21571350097656
387 39.21369171142578
388 39.21168518066406
389 39.20969772338867
390 39.20772933959961
391 39.205780029296875
392 39.2038459777832
393 39.201927185058594
394 39.20002746582031
395 39.198143005371094
396 39.1962776184082
397 39.19442367553711
398 39.19259262084961


In [16]:
print(model.linear1.weight.shape)

torch.Size([2, 12])


In [17]:
print(model.linear1.weight.detach().numpy())

[[ 1.022875   -0.8754933  -0.11785867  0.21894585  1.100892    0.53299946
  -1.715776   -0.6139958  -1.2047893   3.0685062   0.9126718  -2.7486134 ]
 [-1.2745206   0.33084053  0.5941604  -2.7677171   0.33171302  0.6297283
   2.1866744  -0.3355597  -0.9603005  -0.43434876  0.68866915  0.44242463]]
