In [2]:
import numpy as np
import re
from collections import Counter

In [4]:
# Read the whole story into one string; the `with` block closes the file
# handle as soon as the read has finished.
with open(file='TheBlackCat.txt', mode='r') as f:
    text = f.read()

### Skip-gram word2vec in NumPy

In [475]:
class skip_gram():
    """
    Skip-gram word-embedding model with a full softmax output, in plain NumPy.

    Forward pass for a center word whose one-hot column is ``x``::

        projection = W_in @ x + b_in                      # (embedding_size, 1)
        probs      = softmax(W_out @ projection + b_out)  # (V, 1)

    Attributes created by ``__init__``:
        corpus          list of tokens, in text order
        vocab           dict: word -> one-hot column of shape (V, 1)
        V               vocabulary size
        embedding_size  dimensionality of the embeddings
        W_in, b_in      input projection, shapes (embedding_size, V) and (embedding_size, 1)
        W_out, b_out    output projection, shapes (V, embedding_size) and (V, 1)

    Word indices follow the alphabetical order of the vocabulary.
    """

    def preprocess(self, text):
        """
        Normalises raw text for tokenisation.

        Lower-cases the text and replaces every run of characters outside
        A-Z/a-z with one space, so the result contains only ASCII letters
        separated by single spaces, with no leading or trailing space.
        """
        text = text.lower()
        # `+` already collapses each run of non-letters (whitespace included)
        # into one space, so no separate whitespace pass is needed.
        text = re.sub(r"[^A-Za-z]+", " ", text)
        # Strip last: punctuation at either end has just become a space.
        return text.strip()

    def make_vocab(self, text):
        """
        Returns one-hot-encoded vectors for the words of the given text.

        The text is run through ``preprocess`` first. The result maps every
        distinct word to a (V, 1) float column with a single 1 at the word's
        position in the alphabetically sorted vocabulary.
        """
        words = sorted(set(self.preprocess(text).split()))
        one_hot_matrix = np.eye(len(words), dtype=float)
        # Row i of the identity matrix, viewed as a column, is the one-hot
        # vector of the i-th word.
        return {
            word: one_hot_matrix[i][:, np.newaxis]
            for i, word in enumerate(words)
        }

    def __init__(self, text, embedding_size=25, seed=None):
        """
        Builds the corpus and vocabulary and initialises the parameters.

        text            raw training text
        embedding_size  dimensionality of the word embeddings
        seed            optional seed for the weight initialisation; when None
                        the global ``np.random`` state is used
        """
        self.corpus = self.preprocess(text).split()
        self.vocab = self.make_vocab(text)
        self.V = len(self.vocab)
        self.embedding_size = embedding_size

        # Cache both directions of the word <-> index mapping once, so lookups
        # do not have to re-sort the vocabulary or scan a one-hot vector.
        self._words = sorted(self.vocab)
        self._index = {word: i for i, word in enumerate(self._words)}

        # Both the np.random module and RandomState expose `randn`.
        rng = np.random if seed is None else np.random.RandomState(seed)
        self.W_in = 0.01 * rng.randn(self.embedding_size, self.V)
        self.b_in = np.zeros((self.embedding_size, 1))
        self.W_out = 0.01 * rng.randn(self.V, self.embedding_size)
        self.b_out = np.zeros((self.V, 1))

    def softmax(self, x):
        """
        Computes the softmax over all elements of the given tensor.
        """
        x = np.asarray(x)
        if x.size == 0:
            return np.exp(x)
        # Subtracting the maximum leaves the result unchanged mathematically
        # but keeps np.exp from overflowing on large logits.
        exps = np.exp(x - x.max())
        return exps / exps.sum()

    def _forward(self, index):
        """
        Runs the forward pass for the word with the given vocabulary index.

        Returns (projection, probs) with shapes (embedding_size, 1) and (V, 1).
        """
        # Multiplying W_in by a one-hot column just selects one column.
        projection = self.W_in[:, [index]] + self.b_in
        probs = self.softmax(self.W_out @ projection + self.b_out)
        return projection, probs

    def predict(self, word):
        """
        Predicts the context-word distribution for the given center word.

        Returns a (V, 1) column of probabilities. Raises KeyError when the
        word is not in the vocabulary.
        """
        return self._forward(self._index[word])[1]

    def get_loss(self, pred, target):
        """
        Returns negative log(P(context_word|center_word)).

        pred is a (V, 1) column of scores, target the (V, 1) one-hot column of
        the context word. pred is renormalised by its sum before the log is
        taken. The result is a (1, 1) array.
        """
        return -np.log((target.T @ pred) / np.sum(pred))

    def fit(self, lr=0.02, window_size=3, epochs=10, return_history=False):
        """
        Fits the model with per-position gradient descent.

        For each corpus position the errors (probs - target) of all context
        words within ``window_size`` tokens on either side are summed, and one
        update is applied to every parameter.

        lr              learning rate
        window_size     number of context words taken on each side
        epochs          number of passes over the corpus
        return_history  when true, return the list of per-epoch summed losses

        Returns the loss history if ``return_history`` is set, otherwise None.
        Each history entry is the (1, 1) array accumulated from ``get_loss``,
        or the integer 0 for an epoch in which no (center, context) pair
        existed.
        """
        losses = []
        for epoch in range(epochs):
            print("Epoch: ", epoch)
            epoch_loss = 0
            for position, center_word in enumerate(self.corpus):
                # Slices clip at the end of the list; only the start needs an
                # explicit clamp so it cannot wrap around to the tail.
                left = self.corpus[max(0, position - window_size):position]
                right = self.corpus[position + 1:position + window_size + 1]
                context = left + right
                if not context:
                    # No neighbours (one-token corpus or window_size == 0):
                    # nothing to learn from this position.
                    continue

                center_index = self._index[center_word]
                projection, probs = self._forward(center_index)

                context_error = np.zeros_like(probs)
                for context_word in context:
                    target = self.vocab[context_word]
                    context_error += probs - target
                    epoch_loss += self.get_loss(probs, target)

                # Back-propagate through W_out before W_out itself is changed.
                hidden_grad = self.W_out.T @ context_error
                # Only the center word's column of W_in has a non-zero gradient.
                self.W_in[:, center_index] -= lr * hidden_grad[:, 0]
                self.b_in -= lr * hidden_grad
                self.W_out -= lr * (context_error @ projection.T)
                self.b_out -= lr * context_error
            losses.append(epoch_loss)
        if return_history:
            return losses

    def word2index(self, word):
        """
        Returns the index of the given word.

        Raises KeyError when the word is not in the vocabulary.
        """
        return self._index[word]

    def index2word(self, idx):
        """
        Returns the word that corresponds to the given index.
        """
        return self._words[idx]

    def get_embedding_dict(self):
        """
        Returns a dict with words as keys and their embeddings as values.

        Each value is the word's column of W_in, a 1-D array of length
        embedding_size that is a view into W_in rather than a copy.
        """
        return {word: self.W_in[:, i] for i, word in enumerate(self._words)}

In [476]:
# Build the model on the loaded corpus with the default embedding size.
sg = skip_gram(text=text)

In [491]:
# Train for 50 passes over the corpus; the loss history is not collected here.
sg.fit(epochs=50, return_history=False)

Epoch:  0
Losses:  [[83819.81111759]]
Epoch:  1
Losses:  [[83772.44676828]]
Epoch:  2
Losses:  [[83726.84701367]]
Epoch:  3
Losses:  [[83683.06443065]]
Epoch:  4
Losses:  [[83641.1451391]]
Epoch:  5
Losses:  [[83601.12304729]]
Epoch:  6
Losses:  [[83563.01615982]]
Epoch:  7
Losses:  [[83526.8246292]]
Epoch:  8
Losses:  [[83492.52940001]]
Epoch:  9
Losses:  [[83460.09058463]]
Epoch:  10
Losses:  [[83429.44573484]]
Epoch:  11
Losses:  [[83400.50902427]]
Epoch:  12
Losses:  [[83373.17233965]]
Epoch:  13
Losses:  [[83347.30850347]]
Epoch:  14
Losses:  [[83322.77603316]]
Epoch:  15
Losses:  [[83299.42454207]]
Epoch:  16
Losses:  [[83277.1000792]]
Epoch:  17
Losses:  [[83255.65002297]]
Epoch:  18
Losses:  [[83234.9273172]]
Epoch:  19
Losses:  [[83214.79386868]]
Epoch:  20
Losses:  [[83195.12297507]]
Epoch:  21
Losses:  [[83175.80082337]]
Epoch:  22
Losses:  [[83156.72733324]]
Epoch:  23
Losses:  [[83137.816781]]
Epoch:  24
Losses:  [[83118.99865407]]
Epoch:  25
Losses:  [[83100.21904613]]
Ep

KeyboardInterrupt: 