nomenclature:
- target - word in context window
- context - given word (input)

In [19]:
import torch
from torch import nn, optim
from torch.nn import functional as F

from tqdm import tqdm
from nltk.tokenize import RegexpTokenizer

In [20]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

#### Step 1. Data Preparation

In [21]:
tokenizer = RegexpTokenizer(r"\w+")
with open("data.txt", "r") as f:
  corpus = f.read().lower()
  corpus = tokenizer.tokenize(corpus)
  vocab = sorted(set(corpus))
  vocab_size = len(vocab)
  encoded_vocab = torch.tensor(range(vocab_size))  # used for one hot encoding

In [22]:
# Embedding Preparation
EMB_DIM = 150
context_emb = torch.randn(vocab_size, EMB_DIM).to(device)
target_emb = torch.randn(vocab_size, EMB_DIM).to(device)
vocab_ohe = F.one_hot(encoded_vocab, vocab_size).to(device).float()

context_emb.shape, target_emb.shape, vocab_ohe.shape

(torch.Size([11456, 150]),
 torch.Size([11456, 150]),
 torch.Size([11456, 11456]))

In [23]:
context_emb.requires_grad_(True)
target_emb.requires_grad_(True)

tensor([[ 3.7779e-01, -2.6384e-01, -5.7726e-01,  ...,  1.3826e+00,
         -6.3111e-01, -1.2829e+00],
        [-1.3845e+00,  1.0692e+00, -3.8071e-01,  ..., -1.6524e+00,
         -1.2421e+00, -6.8716e-01],
        [ 7.9428e-02, -1.8727e-03, -1.2963e+00,  ...,  6.6184e-01,
          1.2117e+00, -1.8842e-01],
        ...,
        [ 7.5691e-01,  7.6872e-01,  1.2065e-01,  ...,  1.0301e+00,
          1.2501e+00,  1.8313e+00],
        [-2.9517e-02,  2.0976e-01, -4.2028e-01,  ...,  4.1318e-01,
          7.7671e-01, -5.0845e-01],
        [ 2.1846e+00, -2.0764e+00, -1.4999e-01,  ..., -7.8781e-02,
          4.4944e-01, -5.5593e-01]], device='cuda:0', requires_grad=True)

#### Step 2. Skip-gram

In [24]:
loss_fn = nn.NLLLoss()
optimizer = optim.Adam([context_emb, target_emb], lr=0.001)

In [25]:
WINDOW_SIZE = 2

for epoch in range(1):
  for context_i in tqdm(range(WINDOW_SIZE, len(corpus)-WINDOW_SIZE-1)):
    prev_targets = [vocab.index(w) for w in corpus[context_i-WINDOW_SIZE:context_i]]
    past_targets = [vocab.index(w) for w in corpus[context_i+1:context_i+WINDOW_SIZE+1]]

    context = torch.tensor(vocab.index(corpus[context_i]))
    targets = torch.tensor(prev_targets + past_targets).to(device)

    ohe_context = vocab_ohe[context].unsqueeze(0)
    context_word_emb = ohe_context @ context_emb
    probability = torch.softmax(context_word_emb @ target_emb.T, dim=1)  # Over all vocab
    loss = 0
    for target in targets:
      loss += loss_fn(probability, target.unsqueeze(0))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

100%|██████████| 208525/208525 [08:34<00:00, 405.17it/s]
