# SkipGram Implementation

- It does not support batched input.
- That is, each batch contains exactly one sample.
- It was written to understand the mechanics of `SkipGram` in full.
- To train on a large corpus, the training pairs would need to be split into batches.

In [1]:
def tokenize_corpus(corpus):
    """Split every sentence of `corpus` on whitespace.

    Args:
        corpus: iterable of sentences.

    Returns:
        A list with one token list per sentence, in the same order.
    """
    tokenized = []
    for sentence in corpus:
        tokenized.append(sentence.split())
    return tokenized

# Toy corpus: one sentence per line, tokens separated by single spaces.
corpus = """
he is a king
she is a queen
he is a man
she is a woman
warsaw is poland capital
berlin is germany capital
paris is france capital
""".strip().splitlines()

tok_corpus = tokenize_corpus(corpus)

In [2]:
def make_vocab(tok_corpus):
    """Build index lookups for every distinct token in a tokenized corpus.

    Tokens are numbered in order of first appearance, scanning sentences and
    the tokens within them left to right.

    Args:
        tok_corpus: iterable of sentences, each an iterable of hashable tokens.

    Returns:
        (stoi, itos): a dict mapping token -> index and a dict mapping
        index -> token.
    """
    # dict.fromkeys deduplicates in O(1) per token while preserving first-seen
    # order; the previous `token not in list` check cost O(vocab) per token.
    vocab = list(dict.fromkeys(token for sent in tok_corpus for token in sent))

    stoi = {w: idx for idx, w in enumerate(vocab)}
    itos = dict(enumerate(vocab))

    return stoi, itos

# Build word -> index (stoi) and index -> word (itos) lookups; indices follow first-occurrence order.
stoi, itos = make_vocab(tok_corpus)

In [3]:
# Inspect the word -> index mapping produced by make_vocab.
stoi

{'he': 0,
 'is': 1,
 'a': 2,
 'king': 3,
 'she': 4,
 'queen': 5,
 'man': 6,
 'woman': 7,
 'warsaw': 8,
 'poland': 9,
 'capital': 10,
 'berlin': 11,
 'germany': 12,
 'paris': 13,
 'france': 14}

In [4]:
import numpy as np


def make_idx_pairs(tok_corpus, window_size, word2idx=None):
    """Generate (center, context) index pairs for skip-gram training.

    Every token of every sentence is treated as a center word. Each other
    token at most `window_size` positions away in the same sentence becomes
    one of its context words. Pairs never cross sentence boundaries.

    Args:
        tok_corpus: list of tokenized sentences (lists of tokens).
        window_size: maximum distance between a center and a context word.
        word2idx: token -> index mapping. Defaults to the notebook-level
            `stoi` (built by make_vocab), which the original version always used.

    Returns:
        np.ndarray of shape (num_pairs, 2) whose rows are
        (center_idx, context_idx). The shape is (0, 2) when no pairs exist.
    """
    if word2idx is None:
        # Backward-compatible fallback to the global vocabulary.
        word2idx = stoi

    idx_pairs = []
    for sentence in tok_corpus:
        indices = [word2idx[word] for word in sentence]
        n_tokens = len(indices)
        # Each word is treated as the center word in turn.
        for center_pos, center_idx in enumerate(indices):
            # Clip the window to the sentence so we never jump outside it.
            start = max(0, center_pos - window_size)
            stop = min(n_tokens, center_pos + window_size + 1)
            for context_pos in range(start, stop):
                if context_pos == center_pos:
                    continue
                idx_pairs.append((center_idx, indices[context_pos]))

    # reshape keeps the result 2-D (shape (0, 2)) even when no pairs were produced.
    return np.array(idx_pairs, dtype=int).reshape(-1, 2)

# Window size 2: pair each center word with up to 2 words on either side within its sentence.
idx_pairs = make_idx_pairs(tok_corpus, 2)

In [6]:
# Total number of (center, context) training pairs.
len(idx_pairs)

70

In [27]:
# First few pairs; each row is (center word index, context word index).
idx_pairs[:5]

array([[0, 1],
       [0, 2],
       [1, 0],
       [1, 2],
       [1, 3]])

In [15]:
import torch

# one-hot encoding
def get_input_layer(word_idx, vocab_size):
    """Return a float32 one-hot vector of length `vocab_size` with a 1 at `word_idx`."""
    one_hot = torch.zeros(vocab_size, dtype=torch.float32)
    one_hot[word_idx] = 1.0
    return one_hot

In [30]:
import torch
import torch.nn as nn

word_vec_size = 5
vocab_size = len(stoi)
lr = .001
n_epochs = 100

# Seed the RNG so the random initialization, and therefore the loss curve, is reproducible.
torch.manual_seed(0)

## Initialize model

# Embedding layer: W1 @ one_hot(center) selects column `center` of W1.
# The dtype is passed at creation so W1 stays a leaf tensor that receives .grad
# (a trailing .float() would produce a non-leaf whenever the default dtype is not float32).
W1 = torch.randn(word_vec_size, vocab_size, dtype=torch.float32, requires_grad=True)
# Output layer: projects the embedding back to one score per vocabulary word.
W2 = torch.randn(vocab_size, word_vec_size, dtype=torch.float32, requires_grad=True)

# criterion: LogSoftmax + NLLLoss (equivalent to CrossEntropyLoss on raw scores)
crit = nn.NLLLoss()
log_softmax = nn.LogSoftmax(dim=0)

for epoch in range(n_epochs):
    total_loss = 0
    for center_idx, context_idx in idx_pairs:
        x = get_input_layer(center_idx, vocab_size)
        y = torch.tensor([int(context_idx)], dtype=torch.long)

        # feed forward
        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)
        y_hat = log_softmax(z2)

        # calculate loss and back-prop
        loss = crit(y_hat.view(1, -1), y)
        loss.backward()

        # Plain SGD step. Updating under no_grad replaces the deprecated .data
        # access and keeps the update out of the autograd graph.
        with torch.no_grad():
            W1 -= lr * W1.grad
            W2 -= lr * W2.grad
            # Gradients accumulate by default, so clear them before the next pair.
            W1.grad.zero_()
            W2.grad.zero_()

        total_loss += loss.item()

    if (epoch + 1) % 10 == 0:
        print(f'Loss at epo {epoch+1}: {total_loss/len(idx_pairs)}')

Loss at epo 10: 4.245465894256319
Loss at epo 20: 3.8014162080628533
Loss at epo 30: 3.512472382613591
Loss at epo 40: 3.2980784041540963
Loss at epo 50: 3.1274685910769873
Loss at epo 60: 2.9866258433886936
Loss at epo 70: 2.8684525012969972
Loss at epo 80: 2.7687465531485422
Loss at epo 90: 2.6844719426972525
Loss at epo 100: 2.613019234793527


# Refactor of the implementation above
- use the `Adam` optimizer
- wrap the training loop in a `Trainer` class

In [34]:
class SkipGram(nn.Module):
    """Skip-gram model: a one-hot center word -> log-probabilities over context words.

    Args:
        vocab_size: number of vocabulary words (length of the one-hot input
            and of the output distribution).
        word_vec_size: dimensionality of the learned word embeddings.
    """

    def __init__(
        self,
        vocab_size,
        word_vec_size
    ):
        super().__init__()

        # One-hot -> embedding. For a one-hot input of word i, the output is
        # column i of emb.weight plus emb.bias.
        self.emb = nn.Linear(vocab_size, word_vec_size)
        # Embedding -> one score per vocabulary word.
        self.generator = nn.Linear(word_vec_size, vocab_size)
        # LogSoftmax + NLLLoss is equivalent to passing raw scores to CrossEntropyLoss.
        # dim=-1 normalizes over the vocabulary axis. For the 1-D inputs used in this
        # notebook it equals dim=0, and it stays correct for (batch, vocab_size) inputs,
        # where dim=0 would wrongly normalize across the batch.
        self.activation = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for one-hot input(s) `x`."""
        z1 = self.emb(x)
        z2 = self.generator(z1)
        y_hat = self.activation(z2)
        return y_hat

class Trainer(object):
    """Train a model one (center, context) index pair at a time.

    Args:
        model: nn.Module mapping a one-hot float vector to per-word scores;
            with nn.NLLLoss these are expected to be log-probabilities.
        optimizer: torch optimizer over `model.parameters()`.
        crit: loss called as crit(y_hat.view(1, -1), target) with a (1,) long target.
        vocab_size: length of the one-hot input vectors. If None, it is inferred
            from `model.emb.in_features` when `model.emb` is an nn.Linear.
            Otherwise the notebook-level `vocab_size` global is used at training
            time, which was the original behavior.
    """

    def __init__(
        self,
        model,
        optimizer,
        crit,
        vocab_size=None
    ):
        self.model = model
        self.optimizer = optimizer
        self.crit = crit

        # Device of the model's parameters; inputs and targets are moved here.
        self.device = next(model.parameters()).device

        if vocab_size is None:
            emb = getattr(model, 'emb', None)
            if isinstance(emb, nn.Linear):
                vocab_size = emb.in_features
        self.vocab_size = vocab_size

    def _train(self, pairs):
        """Run one epoch over `pairs` and return the mean loss (NaN if `pairs` is empty)."""
        self.model.train()

        if len(pairs) == 0:
            # Nothing to train on; avoid dividing by zero below.
            return float('nan')

        # Fall back to the global `vocab_size` only when it could not be determined in __init__.
        n_vocab = self.vocab_size if self.vocab_size is not None else vocab_size

        total_loss = 0
        for center_idx, context_idx in pairs:
            x = get_input_layer(center_idx, n_vocab).to(self.device)
            y = torch.tensor([int(context_idx)], dtype=torch.long, device=self.device)

            # zero_grad optimizer
            self.optimizer.zero_grad()

            # feed forward
            y_hat = self.model(x)

            # calculate loss and back-prop
            loss = self.crit(y_hat.view(1, -1), y)
            loss.backward()

            # update parameters
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(pairs)

    def train(self, pairs, n_epochs, print_every=10):
        """Train for `n_epochs` epochs and return the model.

        The mean epoch loss is printed every `print_every` epochs. Pass 0 or
        None to disable printing.
        """
        for epoch in range(n_epochs):
            train_loss = self._train(pairs)

            if print_every and (epoch + 1) % print_every == 0:
                print(f'Loss at epoch {epoch+1}: {train_loss}')

        return self.model
    
vocab_size = len(stoi)
word_vec_size = 5
n_epochs = 100

# Seed so the initial weights, and thus the printed losses, are reproducible across runs.
torch.manual_seed(0)

# Initialize the model from the config values above instead of repeating literals.
model = SkipGram(
    vocab_size=vocab_size,
    word_vec_size=word_vec_size
)
optimizer = torch.optim.Adam(model.parameters())
crit = nn.NLLLoss()
trainer = Trainer(model, optimizer, crit)
model = trainer.train(idx_pairs, n_epochs)

Loss at epoch 10: 2.392548942565918
Loss at epoch 20: 2.2273186266422274
Loss at epoch 30: 2.1401266276836397
Loss at epoch 40: 2.073872801235744
Loss at epoch 50: 2.0164452391011376
Loss at epoch 60: 1.9619420647621155
Loss at epoch 70: 1.9087958778653826
Loss at epoch 80: 1.8587679088115692
Loss at epoch 90: 1.8144734740257262
Loss at epoch 100: 1.7769894080502646
