# Train a customized word2vec

Includes negative sampling with a configurable number of negatives per sample.<br>
Does not include subsampling yet; this is a TODO.<br>

### Imports

In [None]:
from pathlib import Path
import sys

import tabulate
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torch

import pandas as pd
import numpy as np

sys.path.insert(0, '../')

import pathvecs
from pathvecs.pytorch import WordContextDataset

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

### Config

In [None]:
# Root directory for pipeline artifacts (defaults to {project}/data)
data_path = Path('../').resolve() / 'data'

# Folder name of the input triples under data/triples
dataset_name = 'wikipedia_20220101'

# Number of training samples per forward pass / weight update
batch_size = 2048

# How many full passes to make over the dataset
num_epochs = 1

# TODO: subsampling is not implemented yet. Intended setting: words with
# frequency > ssr are downsampled; 1e-5 was used in the original paper, 0 for none.
# subsample_rate = 1e-5

# How many negative examples accompany each training sample
negative_samples = 10

### Model

In [None]:
class SkipGramModel(nn.Module):
    """Skip-gram with negative sampling.

    Keeps two embedding tables, one indexed by word id and one indexed by
    context id, both of width ``emb_dim`` and both created with
    ``sparse=True`` (so gradients on them are sparse).
    """

    def __init__(self, wvocab, cvocab, emb_dim):
        """
        wvocab:  dict mapping word -> row index in the word embedding table
        cvocab:  dict mapping context -> row index in the context embedding table
        emb_dim: width of every embedding vector
        """

        super().__init__()

        # Vocabulary maps (forward and inverse)
        self.w2i = wvocab
        self.i2w = {i: w for w, i in wvocab.items()}

        self.c2i = cvocab
        self.i2c = {i: c for c, i in cvocab.items()}

        # Model parameters
        self.emb_dim = emb_dim

        self.device = torch.device('cpu')

        self.w_embeddings = nn.Embedding(len(wvocab), emb_dim, sparse=True)
        self.c_embeddings = nn.Embedding(len(cvocab), emb_dim, sparse=True)

        nn.init.uniform_(self.w_embeddings.weight, -1.0, 1.0)
        nn.init.uniform_(self.c_embeddings.weight, -1.0, 1.0)

    def forward(self, w_pos, c_pos, c_neg):
        """
        With B = batch_size, N = negative_samples, D = emb_dim
        w_pos: B      (word indices)
        c_pos: B      (indices of the observed contexts)
        c_neg: B x N  (indices of the sampled negative contexts)

        Returns a pair of scalar tensors: the summed log-sigmoid score of the
        positive pairs and the summed log-sigmoid score of the negated
        negative pairs.
        """

        w_emb = self.w_embeddings(w_pos)        # B x D
        c_emb = self.c_embeddings(c_pos)        # B x D
        c_neg_emb = self.c_embeddings(c_neg)    # B x N x D

        # Dot product of every word with its positive context
        score = torch.sum(torch.mul(w_emb, c_emb), dim=1)
        score = F.logsigmoid(score)

        # (B x N x D) @ (B x D x 1) -> B x N x 1. Only the trailing axis is
        # squeezed: a bare squeeze() would also drop B or N when either is 1
        # and break the sum over dim=1 below.
        neg_score = torch.bmm(c_neg_emb, w_emb.unsqueeze(2)).squeeze(2)
        neg_score = F.logsigmoid(-neg_score)
        neg_score = torch.sum(neg_score, dim=1)

        return torch.sum(score), torch.sum(neg_score)

    def top_w_sims(self, word, k=5):
        """
        Yield (other_word, cosine_similarity) for the k word embeddings most
        similar to `word`, most similar first. The query word itself is part
        of the candidates. k is capped at the vocabulary size.

        Raises KeyError (on first iteration) if `word` is not in the vocabulary.
        """

        # Inspection only: no autograd graph is needed for the similarities
        with torch.no_grad():
            weights = self.w_embeddings.weight
            query = weights[self.w2i[word]].unsqueeze(0)    # 1 x D
            sims = F.cosine_similarity(query, weights, dim=1)
            topk_sims = sims.topk(min(k, sims.shape[0]))

        for wi, sim in zip(topk_sims.indices.tolist(), topk_sims.values.tolist()):
            yield self.i2w[wi], sim


def log_sample_neighbors(model, words, k=5):
    """
    Print a table of the k nearest neighbours of each word in `words`.

    Every word contributes two columns: one headed by the word itself that
    lists the neighbours yielded by model.top_w_sims, and one headed
    's<position>' that lists the matching similarities formatted to three
    decimals.
    """

    table = {}
    for position, word in enumerate(words):
        score_key = f's{position}'
        table[word], table[score_key] = [], []

        for neighbor, similarity in model.top_w_sims(word, k=k):
            table[word].append(neighbor)
            table[score_key].append(f'{similarity:2.3f}')

    rendered = tabulate.tabulate(table, headers='keys')
    print(rendered)

### Load

In [None]:
# Word vocabulary: each stripped line of wvocab.txt maps to its zero-based line number
with open(data_path / 'vocab' / dataset_name / 'wvocab.txt') as infile:
    wvocab = {entry.strip(): index for index, entry in enumerate(infile)}

In [None]:
# Context vocabulary: each stripped line of cvocab.txt maps to its zero-based line number
with open(data_path / 'vocab' / dataset_name / 'cvocab.txt') as infile:
    cvocab = {entry.strip(): index for index, entry in enumerate(infile)}

In [None]:
# Load the saved pairs, then shuffle them a single time before training
# NOTE(review): torch.load unpickles the file - only point it at trusted artifacts
dataset_fp = data_path / 'pairs' / dataset_name / 'pairs.pt'
word_context_pairs = torch.load(dataset_fp)
_shuffled_order = torch.randperm(len(word_context_pairs))
word_context_pairs = word_context_pairs[_shuffled_order]

In [None]:
# Wrap the shuffled pairs; negative_samples is forwarded to the dataset
dataset = WordContextDataset(pairs_data=word_context_pairs, negative_samples=negative_samples)

# Batches of batch_size samples, loaded by two worker processes
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=2)

### Train

In [None]:
model = SkipGramModel(
    wvocab=wvocab,
    cvocab=cvocab,
    emb_dim=128
)

# The embedding tables are created with sparse=True, so a sparse-aware
# optimizer is used.
optimizer = torch.optim.SparseAdam(model.parameters(), lr=1e-2)

# Entries whose nearest neighbours are printed periodically as a sanity check
sample_words = ['book', 'be_leader_of', 'be_author_of', 'poss_brother_appos', 'lead', 'write', 'move-to']

for epoch in range(num_epochs):

    for i, (pos_u, pos_v, neg_v) in enumerate(tqdm(dataloader)):

        optimizer.zero_grad()

        pos_score, neg_score = model(pos_u, pos_v, neg_v)

        # Negative log-likelihood averaged over the samples actually in this
        # batch; the final batch may hold fewer than batch_size samples.
        loss = -(pos_score + neg_score) / pos_u.shape[0]

        loss.backward()

        optimizer.step()

        if i % 2500 == 0:
            print('\nloss:', loss.item())
            print('-'*50)
            # Logging only: keep the similarity lookups out of autograd
            with torch.no_grad():
                log_sample_neighbors(model, sample_words, k=5)