## Word2Vec Skip Gram Model using Pytorch

* Checkout this repo for more details: https://github.com/OlgaChernytska/word2vec-pytorch
* Well explained article: https://towardsdatascience.com/word2vec-with-pytorch-implementing-original-paper-2cd7040120b0

In [162]:
# Import
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.data import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator
import torch
from torch.utils.data import DataLoader
from functools import partial
import torch.nn as nn

In [163]:
# Constants
SKIPGRAM_N_WORDS = 4
MAX_SEQUENCE_LENGTH = 258
BATCH_SIZE = 32
EMBED_DIMENSION = 128
EMBED_NORM = 1
EPOCHS = 5

In [164]:
# Preprocessing Function
def get_eng_tokenizer():
    tokenizer = get_tokenizer('basic_english', language = "en")
    return tokenizer

In [165]:
# Processing the data
#data_iter = WikiText2(root='/Users/vijayravichander/Code/CS224N/data', split = 'train')
data_iter = WikiText2(root='data', split = 'train')
data_iter = to_map_style_dataset(data_iter)

In [166]:
tokenizer = get_eng_tokenizer()
# To create the vocab from the dataset 
vocab = build_vocab_from_iterator(map(tokenizer, data_iter), specials = ["<unk>"], min_freq= 50)
# All news words will be have "unk" token
vocab.set_default_index(vocab["<unk>"])

In [167]:
#Example of tokenizer
example_text = "Machine Learning is quite good"
print(tokenizer(example_text))
print(vocab(tokenizer(example_text)))

['machine', 'learning', 'is', 'quite', 'good']
[1016, 2849, 23, 2189, 423]


In [168]:
# To mapping processed words into
text_pipeline = lambda x: vocab(tokenizer(x))

In [169]:
# Run down of the pipeline
for data in data_iter[:5]:
    print(f"Current data: {data}")
    print(f" Length: {len(data)}")
    print(f" Tokenized: {tokenizer(data)}")
    print(f"Formatted Text: {text_pipeline(data)}")
    print(f"Length of tokens: {len(text_pipeline(data))}")
    print("-----------------")

Current data:  

 Length: 2
 Tokenized: []
Formatted Text: []
Length of tokens: 0
-----------------
Current data:  = Valkyria Chronicles III = 

 Length: 30
 Tokenized: ['=', 'valkyria', 'chronicles', 'iii', '=']
Formatted Text: [9, 3849, 3869, 881, 9]
Length of tokens: 5
-----------------
Current data:  

 Length: 2
 Tokenized: []
Formatted Text: []
Length of tokens: 0
-----------------
Current data:  Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan W

In [170]:
# Run down of th collate function
import numpy as np
SNW = 4
text_tokens_ids = list(np.random.randint(0, 400, (14)))
batch_input, batch_output = [], []
print(f"Input Tokens: {text_tokens_ids}")
for idx in range(len(text_tokens_ids) - SNW * 2):
    token_id_sequence = text_tokens_ids[idx : (idx + SNW * 2 + 1)]
    print(f"SNW tokens: {token_id_sequence}")
    input_ = token_id_sequence.pop(SNW)
    outputs = token_id_sequence
    print(f"Input Tokens: {input_}")
    print(f"Output Tokens: {outputs}")

    for output in outputs: 
        batch_input.append(input_)
        batch_output.append(output)

    print("Batched Input")
    print(batch_input)
    print("Batched Output")
    print(batch_output)
    break

Input Tokens: [281, 144, 88, 181, 245, 172, 30, 48, 89, 231, 298, 365, 17, 320]
SNW tokens: [281, 144, 88, 181, 245, 172, 30, 48, 89]
Input Tokens: 245
Output Tokens: [281, 144, 88, 181, 172, 30, 48, 89]
Batched Input
[245, 245, 245, 245, 245, 245, 245, 245]
Batched Output
[281, 144, 88, 181, 172, 30, 48, 89]


In [171]:
# Formatting the inputs and outputs for training
def collate_skipgram(batch, text_pipeline):
    batch_input, batch_output = [], []
    for text in batch:
        text_tokens_ids = text_pipeline(text)

        if len(text_tokens_ids) < SKIPGRAM_N_WORDS * 2 + 1:
            continue

        if MAX_SEQUENCE_LENGTH:
            text_tokens_ids = text_tokens_ids[:MAX_SEQUENCE_LENGTH]

        for idx in range(len(text_tokens_ids) - SKIPGRAM_N_WORDS * 2):
            token_id_sequence = text_tokens_ids[idx : (idx + SKIPGRAM_N_WORDS * 2 + 1)]
            input_ = token_id_sequence.pop(SKIPGRAM_N_WORDS)
            outputs = token_id_sequence

            for output in outputs:
                batch_input.append(input_)
                batch_output.append(output)
    
    batch_input = torch.tensor(batch_input, dtype=torch.long)
    batch_output = torch.tensor(batch_output, dtype=torch.long)
    return batch_input, batch_output

In [172]:
# The data loading part
dataloader = DataLoader(data_iter, batch_size = BATCH_SIZE, shuffle = True, collate_fn = partial(collate_skipgram, text_pipeline = text_pipeline))

In [173]:
# Skip Gram Model
class SkipGram_Model(nn.Module):
    """
    Implementation of Skip-Gram model described in paper:
    https://arxiv.org/abs/1301.3781
    """
    def __init__(self, vocab_size: int):
        super(SkipGram_Model, self).__init__()
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=EMBED_DIMENSION,
            max_norm=EMBED_DIMENSION
        )
        self.linear = nn.Linear(
            in_features=EMBED_DIMENSION,
            out_features=vocab_size,
        )

    def forward(self, inputs_):
        x = self.embeddings(inputs_)
        x = self.linear(x)
        return x

In [174]:
# Init
model = SkipGram_Model(len(vocab))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

In [175]:
# Training
running_loss = []
for i, batch_data in enumerate(dataloader, 1):
    
    inputs = batch_data[0]
    labels = batch_data[1]

    optimizer.zero_grad()
    outputs = model(inputs)

    loss = loss_fn(outputs,labels)
    loss.backward()
    optimizer.step()

    running_loss.append(loss.item())
    if i % 10 == 0:
        print(f"Training Step: {i} Loss: {loss.item()}")
    
    if i == 1000:
        break


Training Step: 10 Loss: 7.26474142074585
Training Step: 20 Loss: 6.620956897735596
Training Step: 30 Loss: 6.205772399902344
Training Step: 40 Loss: 6.165273189544678
Training Step: 50 Loss: 6.0563435554504395
Training Step: 60 Loss: 5.704689025878906
Training Step: 70 Loss: 5.651385307312012
Training Step: 80 Loss: 5.83539342880249
Training Step: 90 Loss: 5.865797996520996
Training Step: 100 Loss: 5.726323127746582
Training Step: 110 Loss: 5.774597644805908
Training Step: 120 Loss: 5.651203632354736
Training Step: 130 Loss: 5.747607707977295
Training Step: 140 Loss: 5.60974645614624
Training Step: 150 Loss: 5.746978759765625
Training Step: 160 Loss: 5.562025547027588
Training Step: 170 Loss: 5.397178649902344
Training Step: 180 Loss: 5.892734050750732
Training Step: 190 Loss: 5.66327428817749
Training Step: 200 Loss: 5.532326698303223
Training Step: 210 Loss: 5.579980373382568
Training Step: 220 Loss: 5.450323581695557
Training Step: 230 Loss: 5.710562229156494
Training Step: 240 Loss

In [176]:
# Inference NOTE: The model was not trained much. More training leads to better performance
embeddings = list(model.parameters())[0]
embeddings = embeddings.cpu().detach().numpy()
tokens = vocab.get_itos()

In [177]:
# To get similar words
def get_top_similar(word: str, topN: int = 10):
    word_id = vocab[word]
    if word_id == 0:
        print("Out of vocabulary word")
        return

    word_vec = embeddings[word_id]
    word_vec = np.reshape(word_vec, (len(word_vec), 1))
    dists = np.matmul(embeddings, word_vec).flatten()
    topN_ids = np.argsort(-dists)[1 : topN + 1]

    topN_dict = {}
    for sim_word_id in topN_ids:
        sim_word = vocab.lookup_token(sim_word_id)
        dist = dists[sim_word_id]
        topN_dict[sim_word] = np.round(dist, 3)
    return topN_dict

In [178]:
get_top_similar("uk")

{'zealand': 37.545,
 'vietnam': 35.851,
 '130': 35.637,
 'peach': 35.626,
 'onto': 32.696,
 'bay': 32.297,
 'siege': 32.133,
 '1940': 32.068,
 'focus': 32.03,
 'operating': 31.58}

In [179]:
emb1 = embeddings[vocab["king"]]
emb2 = embeddings[vocab["man"]]
emb3 = embeddings[vocab["woman"]]

emb4 = emb1 - emb2 + emb3

emb4 = np.reshape(emb4, (len(emb4), 1))
dists = np.matmul(embeddings, emb4).flatten()

top5 = np.argsort(-dists)[:5]

for word_id in top5:
    print("{}: {:.3f}".format(vocab.lookup_token(word_id), dists[word_id]))

king: 103.393
woman: 100.410
egypt: 53.854
contributed: 51.014
2015: 49.980
