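We train a word2vec model (`SkipGramModel`) with negative sampling. The original plan was to use the Hugging Face tokenizer 
and datasets to simplify data processing; those imports are currently commented out and the text is tokenized with NLTK's `TreebankWordTokenizer` instead.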

The model is adapted, almost verbatim, from https://github.com/Andras7/word2vec-pytorch/blob/master/word2vec/model.py

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn import init

# huggingface
# import datasets
# from transformers import AutoTokenizer
# from tokenizers import 

# from fastai.text.data import Numericalize
import pytorch_lightning as pl
from nltk.tokenize.treebank import TreebankWordTokenizer
from typing import List
import numpy as np
import random
from argparse import ArgumentParser

Create a dataloader from the dataset

In [2]:
# !wget https://s3.amazonaws.com/fast-ai-nlp/wikitext-2.tgz
# !tar -zxvf wikitext-2.tgz
# !head wikitext-2/train.csv

In [3]:
class Vocab():
    """Two-way mapping between tokens and consecutive integer ids.

    ``word2idx`` maps a token to its id and ``idx2word`` maps the id back to
    the token. Ids are handed out in first-seen order, starting at 0.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}

    def update_vocab(self, tokens: List):
        """Register every token in ``tokens`` that is not yet in the vocab."""
        for tok in tokens:
            if tok in self.word2idx:
                # already known: keep its existing id
                continue
            new_id = len(self.idx2word)
            self.idx2word[new_id] = tok
            self.word2idx[tok] = new_id

    def __len__(self):
        """Return the number of distinct tokens registered so far."""
        return len(self.word2idx)
            

In [4]:
class NGramDataSet(Dataset):
    '''
    Sliding-window dataset of (center, context, negative) word indices.

    The file at ``filepath`` is read, lower-cased, tokenized with
    ``TreebankWordTokenizer`` and concatenated into a single token stream.
    Item ``i`` covers the window ``tokens[i : i + 2*n_half + 1]``: the middle
    token is the center word and the ``n_half`` tokens on each side of it are
    the context words.

    Args:
        filepath: path of the text file to read.
        n_half: number of context tokens taken on each side of the center.
        neg_size: number of negative vocabulary indices drawn per item.
    '''

    def __init__(self, filepath, n_half=2, neg_size=5):
        self.n_half = n_half
        self.neg_size = neg_size

        # read the text line by line
        tokenizer = TreebankWordTokenizer()
        with open(filepath, 'r') as f:
            lines = f.readlines()
        print(f'num_lines is {len(lines)}')

        # tokenize and combine the text
        self.tokenized_text = []
        for line in lines:
            self.tokenized_text.extend(tokenizer.tokenize(line.lower()))

        # build vocab
        self.vocab = Vocab()
        self.vocab.update_vocab(self.tokenized_text)

        # One item per full window. Clamp at 0 so that a text shorter than a
        # single window yields an empty dataset instead of a negative length
        # (which would make len() raise).
        self.len = max(0, len(self.tokenized_text) - 2*self.n_half)

    def __getitem__(self, center_idx):
        '''
        Return the index tensors for the window starting at ``center_idx``.

        Returns:
            A tuple ``(center, context, negatives)`` of tensors: the scalar
            vocab index of the center word, the ``2*n_half`` vocab indices of
            the surrounding words, and ``neg_size`` vocab indices drawn
            uniformly at random with replacement. The negatives are not
            filtered, so they may coincide with the center or context words.

        Raises:
            IndexError: if ``center_idx`` is outside ``[0, len(self))``.
        '''
        # Reject out-of-range windows explicitly; slicing past the end would
        # otherwise silently return a truncated context.
        if not 0 <= center_idx < self.len:
            raise IndexError(
                f'index {center_idx} is out of range for dataset of size {self.len}')

        mid = center_idx + self.n_half
        window = (self.tokenized_text[center_idx:mid] +
                  self.tokenized_text[mid+1:mid+self.n_half+1])
        context_idxs = [self.vocab.word2idx[word] for word in window]
        center = self.vocab.word2idx[self.tokenized_text[mid]]

        # Vocab ids are the consecutive integers 0..len-1, so sampling from a
        # range is equivalent to sampling from the idx2word keys, without
        # materialising a vocabulary-sized list for every item.
        neg_idxs = random.choices(range(len(self.vocab.idx2word)), k=self.neg_size)
        return torch.tensor(center), torch.tensor(context_idxs), torch.tensor(neg_idxs)

    def __len__(self):
        '''Return the number of full windows in the token stream.'''
        return self.len

### Create the neural network

In [35]:
class SkipGramModel(pl.LightningModule):
    '''
    Skip-gram word2vec model trained with negative sampling.

    Two embedding tables are kept: ``u_embeddings`` for center words and
    ``v_embeddings`` for context (and negative) words.

    Args:
        vocab_size: number of rows in each embedding table.
        emb_size: dimension of each embedding vector.
        hparams: hyperparameter namespace; ``hparams.lr`` is read by
            ``configure_optimizers``.
    '''

    def __init__(self, vocab_size, emb_size, hparams):
        super(SkipGramModel, self).__init__()
        self.hparams = hparams
        self.vocab_size = vocab_size
        self.emb_size = emb_size

        # center word embeddings
        self.u_embeddings = nn.Embedding(vocab_size, emb_size, sparse=True)
        # context word embeddings
        self.v_embeddings = nn.Embedding(vocab_size, emb_size, sparse=True)
        # NOTE: not used by forward(); the loss is computed there directly.
        self.loss_func = nn.NLLLoss()

        # initialization: small uniform center vectors, zero context vectors
        initrange = 1.0/emb_size
        init.uniform_(self.u_embeddings.weight.data, -initrange, initrange)
        init.constant_(self.v_embeddings.weight.data, 0)

    def train_dataloader(self):
        '''Return the module-level ``train_dl`` DataLoader.'''
        return train_dl

    def forward(self, pos_u, pos_v, neg_v):
        '''
        Compute the mean negative-sampling loss of a batch.

        Args:
            pos_u: center word indices, shape (batch,).
            pos_v: context word indices, shape (batch, n_context).
            neg_v: negative word indices, shape (batch, n_negative).

        Returns:
            Scalar tensor: the batch mean of
            ``-sum(log sigmoid(u.v_pos)) - sum(log sigmoid(-u.v_neg))``.
        '''
        # (batch, 1, emb): broadcasts against every context / negative word
        emb_u = self.u_embeddings(pos_u).unsqueeze(1)

        # similarity score with positive samples, one dot product per pair
        emb_v = self.v_embeddings(pos_v)
        pos_similarity = torch.sum(torch.mul(emb_u, emb_v), dim=2)
        pos_score = -torch.sum(F.logsigmoid(pos_similarity), dim=1)

        # similarity score with negative samples; the center word comes from
        # the center table, the negatives from the context table
        emb_neg_v = self.v_embeddings(neg_v)
        neg_similarity = torch.sum(torch.mul(emb_u, emb_neg_v), dim=2)
        neg_score = -torch.sum(F.logsigmoid(-neg_similarity), dim=1)

        return torch.mean(pos_score + neg_score)

    def training_step(self, batch, batch_idx):
        '''Unpack a ``(pos_u, pos_v, neg_v)`` batch and return its loss to minimize.'''
        pos_u, pos_v, neg_v = batch
        score = self.forward(pos_u, pos_v, neg_v)
        return pl.TrainResult(minimize=score)

    def configure_optimizers(self):
        '''Return the optimizer, using ``hparams.lr`` as the learning rate.'''
        # Both embedding tables are created with sparse=True and therefore
        # produce sparse gradients, which dense Adam rejects; SparseAdam is
        # the variant that accepts them.
        optimizer = torch.optim.SparseAdam(list(self.parameters()), lr=self.hparams.lr)
        return optimizer

In [26]:
# Parse the hyperparameters from a fixed argument string so the notebook can
# be driven the same way a command-line script would be.
parser = ArgumentParser()
for flag, caster in (('--bsz', int), ('--max-epochs', int), ('--lr', float)):
    parser.add_argument(flag, type=caster)
hparams = parser.parse_args(args='--bsz 512 --max-epochs 10 --lr 0.0001'.split())
hparams

Namespace(bsz=512, lr=0.0001, max_epochs=10)

In [30]:
# Build the windowed dataset from the training text and batch it with the
# batch size taken from the parsed hyperparameters.
filepath = 'wikitext-2/train.csv'
ds = NGramDataSet(filepath=filepath)
train_dl = DataLoader(dataset=ds, batch_size=hparams.bsz)

num_lines is 37333


In [36]:
# Size the embedding tables from the vocabulary built by the dataset.
vocab_size = len(ds.vocab)
emb_size = 300

# Instantiate the model and train it for the configured number of epochs.
model = SkipGramModel(vocab_size=vocab_size, emb_size=emb_size, hparams=hparams)
trainer = pl.Trainer(max_epochs=hparams.max_epochs)
trainer.fit(model)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name         | Type      | Params
-------------------------------------------
0 | u_embeddings | Embedding | 8 M   
1 | v_embeddings | Embedding | 8 M   
2 | loss_func    | NLLLoss   | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..





1

### Playground