In [72]:
import numpy as np

import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

import plotly.express as px
from sklearn.manifold import TSNE

from tqdm import tqdm


In [73]:
path = '/content/nn_text.txt'

with open(path) as text:
  text = text.read()

for i in [',', '\n', "’",  "‘", '(', ')', '-', '/']:
  unprepared_corpus = text.replace(i, ' ').split('.')
  
  corpus = []
  for i in unprepared_corpus:
    i = i.lower()
    for j in [',', '\n', "’",  "‘", '(', ')', '-']:
      i = i.replace(j, ' ')

    current_sentece = i.split(' ')
    if '' in current_sentece:
      current_sentece.remove('')
    corpus.append(current_sentece)

    vocabulary = []
    for sent in corpus:
      for word in sent:
        if word.isdigit() == True or word == '':
          sent.remove(word) 
        if word not in vocabulary:
          vocabulary.append(word)

    vocabulary_size = len(vocabulary)

    word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
    idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}

In [74]:
window_size = 2
idx_pairs = []

for sentence in corpus:
    idx = [word2idx[word] for word in sentence]
    
    for center_word_pos in range(len(idx)):
        for w in range(-window_size, window_size + 1):
            context_word_pos = center_word_pos + w
            
            if context_word_pos < 0 or context_word_pos >= len(idx) or center_word_pos == context_word_pos:
                continue
            context_word_idx = idx[context_word_pos]
            idx_pairs.append((idx[center_word_pos], context_word_idx))

idx_pairs = np.array(idx_pairs)

In [75]:
def get_input_layer(word_idx):
    x = torch.zeros(vocabulary_size, dtype=torch.float64)
    x[word_idx] = 1.0
    return x

In [76]:
#v_embedding - центральное слово
#u_embedding - контекстное слово

In [77]:
embedding_dim = 100
v_embedding = torch.rand([embedding_dim, vocabulary_size], dtype=torch.float64, requires_grad = True)
u_embedding = torch.rand([embedding_dim, vocabulary_size], dtype=torch.float64, requires_grad = True)
learning_rate = 0.01
epochs_cnt = 100

In [78]:
for epoch in tqdm(range(epochs_cnt)):
  loss_value = 0
  
  for center_word_idx, context_word_idx in idx_pairs:

    input_vec = get_input_layer(center_word_idx)
    output_vec = torch.from_numpy(np.array([context_word_idx]))
 
    y1 = torch.matmul(v_embedding, input_vec)
    y2 = torch.matmul(y1, u_embedding)   

    log_softmax = F.log_softmax(y2, dim=0)

    loss = F.nll_loss(log_softmax.view(1,-1), output_vec)
    loss_value += loss.data
    loss.backward()

    v_embedding.data -= learning_rate * v_embedding.grad.data
    u_embedding.data -= learning_rate * u_embedding.grad.data

    v_embedding.grad.data.zero_()
    u_embedding.grad.data.zero_()
  if epoch % 10 == 0:    
        print(f'Loss at epoch {epoch}: {loss_value/len(idx_pairs)}')

  1%|          | 1/100 [00:04<07:13,  4.38s/it]

Loss at epoch 0: 6.879439665771121


 11%|█         | 11/100 [00:46<06:24,  4.32s/it]

Loss at epoch 10: 5.027844827681459


 21%|██        | 21/100 [01:28<05:32,  4.21s/it]

Loss at epoch 20: 3.807448251248995


 31%|███       | 31/100 [02:10<04:48,  4.18s/it]

Loss at epoch 30: 2.942213525885196


 41%|████      | 41/100 [02:52<04:06,  4.18s/it]

Loss at epoch 40: 2.5435397054302253


 51%|█████     | 51/100 [03:35<03:26,  4.22s/it]

Loss at epoch 50: 2.4026638117414048


 61%|██████    | 61/100 [04:16<02:43,  4.20s/it]

Loss at epoch 60: 2.348320366416317


 71%|███████   | 71/100 [04:59<02:02,  4.24s/it]

Loss at epoch 70: 2.3219544754579537


 81%|████████  | 81/100 [05:41<01:19,  4.19s/it]

Loss at epoch 80: 2.306766351036186


 91%|█████████ | 91/100 [06:23<00:37,  4.20s/it]

Loss at epoch 90: 2.2969075849271388


100%|██████████| 100/100 [07:01<00:00,  4.22s/it]


In [None]:
#тест поиск контекстных слов
cnt_context_words = 4
input_word = 1
 
y1 = torch.matmul(v_embedding, get_input_layer(input_word))
y2 = torch.matmul(y1, u_embedding)   

log_softmax = F.log_softmax(y2, dim=0)

sorted_softmax = log_softmax.sort(descending=True)
print("Target word: " + idx2word[input_word])
for i in range(cnt_context_words):
  print(idx2word[sorted_softmax.indices[i].item()])


In [80]:
numpy_v_emb = v_embedding.detach().numpy()
numpy_v_emb = numpy_v_emb.transpose()

resized_v_emb = TSNE(n_components=2, learning_rate='auto',
                   init='random').fit_transform(numpy_v_emb)

In [81]:
x_embed = []
y_embed = []

for i in range(len(resized_v_emb)):
  x_embed.append(resized_v_emb[i][0])
  y_embed.append(resized_v_emb[i][1])

In [82]:
fig = px.scatter(x=x_embed, y=y_embed, hover_name=vocabulary)
fig.show()