### PyTorch SkipGram

In [None]:
!pip install torchdata

In [None]:
!pip install -U torchtext

In [None]:
!pip install portalocker

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, vocab
from torchtext.datasets import WikiText2

import numpy as np

In [None]:
# Notebook-wide hyperparameters.
WINDOW_SIZE = 5  # intended context radius (tokens on each side of the center word)
BATCH_SIZE = 64  # number of (center, context) pairs per DataLoader batch
EMB_DIM = 128  # width of each word embedding vector
EPOCHS = 4  # number of full passes over the training pairs

In [None]:
# Prefer the GPU when CUDA is usable; otherwise keep everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Download the WikiText-2 archive from the fast-ai-nlp S3 bucket and unpack it.
# Extraction creates wikitext-2/train.csv and wikitext-2/test.csv.
!wget https://s3.amazonaws.com/fast-ai-nlp/wikitext-2.tgz -O wikitext-2.tar.gz
!tar -xvzf wikitext-2.tar.gz

--2023-04-14 21:14:32--  https://s3.amazonaws.com/fast-ai-nlp/wikitext-2.tgz
Resolving s3.amazonaws.com (s3.amazonaws.com)... 54.231.132.80, 3.5.17.145, 52.217.91.166, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|54.231.132.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4070055 (3.9M) [application/x-tar]
Saving to: ‘wikitext-2.tar.gz’


2023-04-14 21:14:34 (2.20 MB/s) - ‘wikitext-2.tar.gz’ saved [4070055/4070055]

wikitext-2/
wikitext-2/train.csv
wikitext-2/test.csv


In [None]:
def load_data(filepath):
    """Read a text file and return its lines.

    Args:
        filepath: Path of the file to read.

    Returns:
        list[str]: One string per line, exactly as produced by
        ``readlines`` (trailing newlines are kept).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(filepath, encoding="utf-8") as f:
        return f.readlines()

In [None]:
# Read both splits, then pool them into one corpus: the vocabulary and the
# training pairs below are built from train + test together.
train, test = map(load_data, ("wikitext-2/train.csv", "wikitext-2/test.csv"))
data = train + test

In [None]:
# torchtext's "basic_english" tokenizer; called below on each raw line to
# split it into a list of tokens.
tokenizer = get_tokenizer("basic_english", language="en")

In [None]:
def yield_tokens(data_obj):
    """Lazily tokenize every text in ``data_obj``.

    Yields ``tokenizer(text)`` for each item of ``data_obj``, in order.
    """
    yield from map(tokenizer, data_obj)

# Build the vocabulary from the tokenized corpus. Tokens occurring fewer than
# min_freq (20) times are left out; "<unk>" is registered as a special token.
# NOTE: this rebinds the name `vocab` that was imported from torchtext.vocab.
vocab = build_vocab_from_iterator(yield_tokens(data), specials=["<unk>"], min_freq=20)
# Lookups of tokens that are not in the vocabulary return the "<unk>" index.
vocab.set_default_index(vocab["<unk>"])

In [None]:
# Number of entries in the vocabulary (the "<unk>" special is counted too).
len(vocab)

8627

In [None]:
# Sanity check: a token that is not in the vocabulary maps to the default
# ("<unk>") index set above.
vocab['asdasdasd']

0

In [None]:
def text_pipeline(x):
    """Tokenize the string ``x`` and map its tokens to vocabulary indices."""
    return vocab(tokenizer(x))

In [None]:
def build_contexts(row, window_size=3):
    """Enumerate skip-gram ``(center, context)`` pairs for a token sequence.

    Every position in ``row`` is paired with each neighbour that lies at most
    ``window_size`` positions to its left or right and inside the sequence.

    Args:
        row: Indexable sequence of tokens (or token ids).
        window_size: Maximum distance between a center token and its context.

    Returns:
        list[tuple]: Pairs ordered by center position, then by context
        position from left to right.
    """
    length = len(row)
    # Non-zero offsets, from the farthest-left neighbour to the farthest-right.
    offsets = [step for step in range(-window_size, window_size + 1) if step != 0]
    return [
        (center, row[pos + step])
        for pos, center in enumerate(row)
        for step in offsets
        if 0 <= pos + step < length
    ]

In [None]:
class Word2VecDataset(Dataset):
    """Skip-gram training pairs built from a corpus of raw text lines.

    All lines are encoded to token ids and concatenated into one stream, so
    context windows may span line boundaries. Each item is a
    ``(central_word_id, context_word_id)`` tuple.

    Args:
        data: Iterable of raw text strings.
        vocab: Vocabulary that is callable on a list of tokens and returns the
            matching list of ids; ``len(vocab)`` is stored as ``vocab_size``.
        wsize: Context radius forwarded to ``build_contexts``.
    """

    def __init__(self, data, vocab, wsize=3):
        self.vocab_size = len(vocab)
        # Encode with the vocabulary that was passed in (not a module-level
        # pipeline) so the ids always agree with ``vocab_size``.
        token_ids = [idx for text in data for idx in vocab(tokenizer(text))]
        self.data = build_contexts(token_ids, window_size=wsize)

    def __len__(self):
        """Return the number of (center, context) pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th (center_id, context_id) pair."""
        return self.data[idx]

In [None]:
# Build the (center, context) pairs with the configured window size instead of
# silently falling back to the dataset's default.
dataset = Word2VecDataset(data, vocab, wsize=WINDOW_SIZE)
# Shuffle every epoch: consecutive pairs share the same center word, so
# unshuffled batches would be highly correlated, which is harmful for SGD.
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Peek at the first training pair: (center word id, context word id).
central_word, context = dataset[0]
central_word, context 

(9, 435)

In [None]:
class SkipGram_Model(nn.Module):
    """Skip-gram network: an embedding lookup followed by a linear layer that
    produces one score per vocabulary word.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output scores.
        emb_dim: Width of the embedding vectors. When omitted, the
            notebook-level ``EMB_DIM`` constant is used.
    """

    def __init__(self, vocab_size: int, emb_dim: "int | None" = None):
        super().__init__()
        if emb_dim is None:
            emb_dim = EMB_DIM
        # The embedding layer is registered first, so it stays the first entry
        # of ``model.parameters()``.
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            max_norm=1,  # looked-up rows with a larger norm are rescaled to norm 1
        )
        self.linear = nn.Linear(
            in_features=emb_dim,
            out_features=vocab_size,
        )

    def forward(self, inputs_):
        """Return unnormalized vocabulary scores for each input word id."""
        return self.linear(self.embeddings(inputs_))

In [None]:
# One embedding row and one output score per vocabulary entry; the model is
# moved to the device selected earlier.
vocab_size = len(vocab)
model = SkipGram_Model(vocab_size).to(device)

In [None]:
# Cross-entropy between the model's raw vocabulary scores and the context-word id.
criterion = torch.nn.CrossEntropyLoss()
# Plain SGD over all model parameters with a fixed learning rate of 0.03.
optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

In [None]:
def train_model(dataloader, model, optimizer, criterion):
    """Run one training epoch and return the mean batch loss.

    Args:
        dataloader: Iterable of batches where ``batch[0]`` holds the inputs
            and ``batch[1]`` the target labels (both tensors).
        model: Module to train; it is switched to train mode.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.

    Returns:
        The unweighted mean of the per-batch loss values, or NaN when the
        dataloader yields no batches.
    """
    model.train()
    # Send batches to wherever the model lives instead of depending on a
    # module-level `device` variable.
    model_device = next(model.parameters()).device

    losses = []
    for batch_data in dataloader:
        inputs = batch_data[0].to(model_device)
        labels = batch_data[1].to(model_device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    if not losses:
        # Nothing was trained on; avoid numpy's "mean of empty slice" warning.
        return float("nan")
    return np.mean(losses)

In [None]:
# Train for EPOCHS passes, reporting the mean loss on even-numbered epochs only.
for epoch in range(EPOCHS):
    loss = train_model(train_dataloader, model, optimizer, criterion)
    if epoch % 2:
        continue
    print(f'Epoch {epoch}: train loss {loss}')


Epoch 0: train loss 6.262152531981089
Epoch 2: train loss 6.070154088096149


In [None]:
# Take the embedding table by name rather than by its position in
# model.parameters(), which would silently break if layers were reordered.
embeddings = model.embeddings.weight.detach().cpu().numpy()

# L2-normalize every row so that dot products between rows are cosine
# similarities. keepdims=True keeps norms as a column for broadcasting.
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings_norm = embeddings / norms
embeddings_norm.shape

(8627, 128)

In [None]:
def get_top_similar(word, n=10):
    """Return the ``n`` vocabulary words most similar to ``word``.

    Similarity is the dot product between rows of ``embeddings_norm``.

    Args:
        word: Query token.
        n: Number of neighbours to return.

    Returns:
        dict mapping each neighbour word to its similarity score, inserted in
        descending order of similarity. Returns ``None`` (after printing a
        message) when ``word`` maps to the unknown-token index.
    """
    word_id = vocab[word]
    # Unknown words resolve to the "<unk>" index via the vocab's default index.
    if word_id == vocab["<unk>"]:
        print("Out of vocabulary word")
        return

    # Matrix-vector product: one similarity score per vocabulary row.
    dists = embeddings_norm @ embeddings_norm[word_id]
    ranked = np.argsort(-dists)
    # Drop the query word by id instead of assuming it always sorts first
    # (ties or rounding could place another row ahead of it).
    top_ids = ranked[ranked != word_id][:n]

    # lookup_token expects a plain int, not a numpy integer.
    return {vocab.lookup_token(int(idx)): dists[idx] for idx in top_ids}

In [None]:
# Qualitative spot check: nearest neighbours of "hero" under the trained embeddings.
get_top_similar('hero')

{'video': 0.40826625,
 'seven': 0.40180534,
 'world': 0.3799724,
 'during': 0.3766271,
 'despite': 0.37006557,
 'along': 0.35462686,
 'the': 0.3502586,
 'fifth': 0.3493992,
 'consecutive': 0.34836656,
 'recorded': 0.34149763}