In [1]:
import numpy as np

In [2]:
def softmax(array):
    array = np.exp(array)
    return array / np.sum(array)

class Word2Vec():
    def __init__(self, vocab, embedding_size, learning_rate=1.):
        self.vocab = vocab
        self.embedding_size = embedding_size
        self.learning_rate = learning_rate
        
        self.vocab_size = len(self.vocab)
        self.word_to_index = { word: i for i, word in enumerate(self.vocab)}
        self.index_to_word = { i: word for word, i in self.word_to_index.items()}
        
        self.embedding = np.random.randn(self.vocab_size, self.embedding_size)
        self.linear = np.random.randn(self.embedding_size, self.vocab_size)
        
    def one_hot_encod(self, index):
        one_hot = np.zeros(self.vocab_size)
        one_hot[index] = 1
        return one_hot
    
    def forward(self, index):
        emb = self.embedding[index,:]
        return emb, softmax(emb @ self.linear)
    
    def criterion(self, y, output):
        return -1/self.vocab_size*np.inner(y, np.log(output))
    
    def step(self, input_index, output_index):
        y = self.one_hot_encod(output_index)
        emb, output = self.forward(input_index)

        loss = self.criterion(y, output)
        print(f"Loss - {loss:.3f}")

        dL_dx = 1/self.vocab_size*(output-y)

        grad_linear = np.outer(emb, dL_dx)
        self.linear -= self.learning_rate*grad_linear

        grad_embedding = self.linear @ dL_dx
        self.embedding[input_index,:] -= self.learning_rate*grad_embedding

In [3]:
words = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']
vocab = set(words)
embedding_size = 10

word2vec = Word2Vec(vocab, embedding_size)

## Backward
This is the forward graph
```
embedding --(@ linear)--> x --(softmax)--> o --(criterion)--> loss
```

Gradient descent step

In [4]:
input_index, output_index = 4, 1

for _ in range(20):
    word2vec.step(input_index, output_index)
_, output = word2vec.forward(input_index)
output

Loss - 0.882
Loss - 0.518
Loss - 0.237
Loss - 0.078
Loss - 0.035
Loss - 0.022
Loss - 0.016
Loss - 0.012
Loss - 0.010
Loss - 0.008
Loss - 0.007
Loss - 0.006
Loss - 0.006
Loss - 0.005
Loss - 0.005
Loss - 0.004
Loss - 0.004
Loss - 0.004
Loss - 0.003
Loss - 0.003


array([4.00139550e-03, 9.75727517e-01, 1.08158542e-03, 3.72763238e-03,
       6.02102315e-03, 7.69673623e-03, 1.68083013e-03, 6.32796903e-05])

It learns very well!