In [30]:
import numpy as np
import string
from nltk.corpus import stopwords

In [31]:
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x/e_x.sum()

In [38]:
class word2vec(object): 
    def __init__(self): 
        self.N = 10
        self.X_train = [] 
        self.y_train = [] 
        self.window_size = 2
        self.alpha = 0.001
        self.words = [] 
        self.word_index = {} 
    
    def initialize(self,V,data): 
        self.V = V 
        self.W = np.random.uniform(-0.8, 0.8, (self.V, self.N)) 
        self.W1 = np.random.uniform(-0.8, 0.8, (self.N, self.V)) 
           
        self.words = data 
        for i in range(len(data)): 
            self.word_index[data[i]] = i 
    
    def feed_forward(self,X): 
        self.h = np.dot(self.W.T,X).reshape(self.N,1) 
        self.u = np.dot(self.W1.T,self.h) 
        #print(self.u) 
        self.y = softmax(self.u)   
        return self.y 
    
    def backpropagate(self,x,t): 
        e = self.y - np.asarray(t).reshape(self.V,1) 
        # e.shape is V x 1 
        dLdW1 = np.dot(self.h,e.T) 
        X = np.array(x).reshape(self.V,1) 
        dLdW = np.dot(X, np.dot(self.W1,e).T) 
        self.W1 = self.W1 - self.alpha*dLdW1 
        self.W = self.W - self.alpha*dLdW 
        
    def train(self,epochs): 
        for x in range(1,epochs):         
            self.loss = 0
            for j in range(len(self.X_train)): 
                self.feed_forward(self.X_train[j]) 
                self.backpropagate(self.X_train[j],self.y_train[j]) 
                C = 0
                for m in range(self.V): 
                    if(self.y_train[j][m]): 
                        self.loss += -1*self.u[m][0] 
                        C += 1
                self.loss += C*np.log(np.sum(np.exp(self.u))) 
            print("epoch ",x, " loss = ",self.loss) 
            self.alpha *= 1/( (1+self.alpha*x) )
    
    def predict(self,word,number_of_predictions): 
        if word in self.words: 
            index = self.word_index[word] 
            X = [0 for i in range(self.V)] 
            X[index] = 1
            prediction = self.feed_forward(X) 
            output = {} 
            for i in range(self.V): 
                output[prediction[i][0]] = i 
               
            top_context_words = [] 
            for k in sorted(output,reverse=True): 
                top_context_words.append(self.words[output[k]]) 
                if(len(top_context_words)>=number_of_predictions): 
                    break
       
            return top_context_words 
        else: 
            print("Word not found in dicitonary") 

In [43]:
def preprocessing(corpus):
    stop_words = set(stopwords.words('english'))     
    training_data = [] 
    sentences = corpus.split(".") 
    for i in range(len(sentences)): 
        sentences[i] = sentences[i].strip() 
        sentence = sentences[i].split() 
        x = [word.strip(string.punctuation) for word in sentence if word not in stop_words] 
        x = [word.lower() for word in x] 
        training_data.append(x) 
    return training_data 
    
def prepare_data_for_training(sentences,w2v): 
    data = {} 
    for sentence in sentences: 
        for word in sentence: 
            if word not in data: 
                data[word] = 1
            else: 
                data[word] += 1
    V = len(data) 
    data = sorted(list(data.keys())) 
    vocab = {} 
    for i in range(len(data)): 
        vocab[data[i]] = i 

      #for i in range(len(words)): 
    for sentence in sentences: 
        for i in range(len(sentence)): 
            center_word = [0 for x in range(V)] 
            center_word[vocab[sentence[i]]] = 1
            context = [0 for x in range(V)] 

            for j in range(i-w2v.window_size,i+w2v.window_size): 
                if i!=j and j>=0 and j<len(sentence): 
                    context[vocab[sentence[j]]] += 1
            w2v.X_train.append(center_word) 
            w2v.y_train.append(context) 
    w2v.initialize(V,data) 

    return w2v.X_train,w2v.y_train 
    

In [50]:
corpus = "" 
corpus += "The earth revolves around the sun. The moon revolves around the earth"
epochs = 1000
  
training_data = preprocessing(corpus) 
w2v = word2vec() 
  
prepare_data_for_training(training_data,w2v) 
w2v.train(epochs)  
  
print(w2v.predict("revolves",3))   

epoch  1  loss =  41.021115496969564
epoch  2  loss =  40.96636333270881
epoch  3  loss =  40.911935789836946
epoch  4  loss =  40.857883185885164
epoch  5  loss =  40.80425420845885
epoch  6  loss =  40.751095655252215
epoch  7  loss =  40.698452197352516
epoch  8  loss =  40.64636616867452
epoch  9  loss =  40.5948773837509
epoch  10  loss =  40.54402298545418
epoch  11  loss =  40.49383732356343
epoch  12  loss =  40.44435186443909
epoch  13  loss =  40.39559513144959
epoch  14  loss =  40.34759267522644
epoch  15  loss =  40.30036707232042
epoch  16  loss =  40.253937950404186
epoch  17  loss =  40.20832203782116
epoch  18  loss =  40.16353323502035
epoch  19  loss =  40.11958270523995
epoch  20  loss =  40.07647898170714
epoch  21  loss =  40.03422808859878
epoch  22  loss =  39.99283367305098
epoch  23  loss =  39.95229714560437
epoch  24  loss =  39.912617826616724
epoch  25  loss =  39.87379309635502
epoch  26  loss =  39.83581854668418
epoch  27  loss =  39.79868813249204
epoc

epoch  282  loss =  38.00926323736726
epoch  283  loss =  38.00824169056581
epoch  284  loss =  38.00722728877843
epoch  285  loss =  38.0062199579078
epoch  286  loss =  38.0052196248693
epoch  287  loss =  38.00422621757401
epoch  288  loss =  38.003239664911916
epoch  289  loss =  38.00225989673551
epoch  290  loss =  38.00128684384363
epoch  291  loss =  38.000320437965726
epoch  292  loss =  37.999360611746425
epoch  293  loss =  37.99840729873028
epoch  294  loss =  37.99746043334697
epoch  295  loss =  37.996519950896726
epoch  296  loss =  37.995585787535994
epoch  297  loss =  37.99465788026352
epoch  298  loss =  37.99373616690646
epoch  299  loss =  37.99282058610707
epoch  300  loss =  37.9919110773094
epoch  301  loss =  37.99100758074636
epoch  302  loss =  37.99011003742702
epoch  303  loss =  37.989218389124105
epoch  304  loss =  37.9883325783618
epoch  305  loss =  37.98745254840377
epoch  306  loss =  37.98657824324137
epoch  307  loss =  37.98570960758207
epoch  308

epoch  531  loss =  37.873503813157186
epoch  532  loss =  37.87321491315903
epoch  533  loss =  37.87292709991303
epoch  534  loss =  37.87264036731292
epoch  535  loss =  37.87235470929795
epoch  536  loss =  37.87207011985251
epoch  537  loss =  37.87178659300571
epoch  538  loss =  37.871504122830935
epoch  539  loss =  37.87122270344548
epoch  540  loss =  37.87094232901017
epoch  541  loss =  37.87066299372886
epoch  542  loss =  37.870384691848095
epoch  543  loss =  37.87010741765674
epoch  544  loss =  37.86983116548559
epoch  545  loss =  37.869555929706934
epoch  546  loss =  37.86928170473423
epoch  547  loss =  37.86900848502171
epoch  548  loss =  37.868736265064
epoch  549  loss =  37.868465039395815
epoch  550  loss =  37.86819480259145
epoch  551  loss =  37.86792554926461
epoch  552  loss =  37.86765727406794
epoch  553  loss =  37.86738997169267
epoch  554  loss =  37.867123636868286
epoch  555  loss =  37.866858264362236
epoch  556  loss =  37.86659384897958
epoch  

epoch  781  loss =  37.824356535951104
epoch  782  loss =  37.824223217805866
epoch  783  loss =  37.82409024129237
epoch  784  loss =  37.82395760510058
epoch  785  loss =  37.823825307927116
epoch  786  loss =  37.823693348475224
epoch  787  loss =  37.823561725454766
epoch  788  loss =  37.82343043758216
epoch  789  loss =  37.82329948358035
epoch  790  loss =  37.823168862178704
epoch  791  loss =  37.82303857211306
epoch  792  loss =  37.82290861212569
epoch  793  loss =  37.822778980965126
epoch  794  loss =  37.82264967738625
epoch  795  loss =  37.82252070015027
epoch  796  loss =  37.82239204802453
epoch  797  loss =  37.822263719782654
epoch  798  loss =  37.82213571420434
epoch  799  loss =  37.82200803007547
epoch  800  loss =  37.821880666187994
epoch  801  loss =  37.82175362133985
epoch  802  loss =  37.82162689433503
epoch  803  loss =  37.821500483983485
epoch  804  loss =  37.821374389101045
epoch  805  loss =  37.82124860850952
epoch  806  loss =  37.821123141036466
