In [1]:
## https://www.geeksforgeeks.org/implement-your-own-word2vecskip-gram-model-in-python/
import numpy as np 
import string 
from nltk.corpus import stopwords

In [2]:
class word2vec(object): 
	def __init__(self): 
		self.N = 10
		self.X_train = [] 
		self.y_train = [] 
		self.window_size = 2
		self.alpha = 0.001
		self.words = [] 
		self.word_index = {} 

	def initialize(self,V,data): 
		self.V = V 
		self.W = np.random.uniform(-0.8, 0.8, (self.V, self.N)) 
		self.W1 = np.random.uniform(-0.8, 0.8, (self.N, self.V)) 
		
		self.words = data 
		for i in range(len(data)): 
			self.word_index[data[i]] = i 

	
	def feed_forward(self,X): 
		self.h = np.dot(self.W.T,X).reshape(self.N,1) 
		self.u = np.dot(self.W1.T,self.h) 
		#print(self.u) 
		self.y = softmax(self.u) 
		return self.y 
		
	def backpropagate(self,x,t): 
		e = self.y - np.asarray(t).reshape(self.V,1) 
		# e.shape is V x 1 
		dLdW1 = np.dot(self.h,e.T) 
		X = np.array(x).reshape(self.V,1) 
		dLdW = np.dot(X, np.dot(self.W1,e).T) 
		self.W1 -=  self.alpha*dLdW1 
		self.W -= self.alpha*dLdW 
		
	def train(self,epochs): 
		for x in range(1,epochs+1):		 
			self.loss = 0
			for j in range(len(self.X_train)): 
				self.feed_forward(self.X_train[j]) 
				self.backpropagate(self.X_train[j],self.y_train[j]) 
				C = 0
				for m in range(self.V): 
					if(self.y_train[j][m]): 
						self.loss += -1*self.u[m][0] 
						C += 1
				self.loss += C*np.log(np.sum(np.exp(self.u))) 
			print("epoch ",x, " loss = ",self.loss) 
			self.alpha *= 1/( (1+self.alpha*x) ) 
			
	def predict(self,word,number_of_predictions): 
		if word in self.words: 
			index = self.word_index[word] 
			X = [0 for i in range(self.V)] 
			X[index] = 1
			prediction = self.feed_forward(X) 
			output = {} 
			for i in range(self.V): 
				output[prediction[i][0]] = i 
			
			top_context_words = [] 
			for k in sorted(output,reverse=True): 
				top_context_words.append(self.words[output[k]]) 
				if(len(top_context_words)>=number_of_predictions): 
					break
	
			return top_context_words 
		else: 
			print("Word not found in dicitonary") 

In [3]:
def softmax(x): 
	"""Compute softmax values for each sets of scores in x."""
	e_x = np.exp(x - np.max(x)) 
	return e_x / e_x.sum() 

In [4]:
def preprocessing(corpus): 
	stop_words = set(stopwords.words('english'))	 
	training_data = [] 
	sentences = corpus.split(".") 
	for i in range(len(sentences)): 
		sentences[i] = sentences[i].strip() 
		sentence = sentences[i].split() 
		x = [word.strip(string.punctuation) for word in sentence if word not in stop_words] 
		x = [word.lower() for word in x] 
		training_data.append(x) 
	return training_data 

In [5]:
def prepare_data_for_training(sentences,w2v): 
	data = {} 
	for sentence in sentences: 
		for word in sentence: 
			if word not in data: 
				data[word] = 1
			else: 
				data[word] += 1
	V = len(data) 
	data = sorted(list(data.keys())) 
	vocab = {} 
	for i in range(len(data)): 
		vocab[data[i]] = i 
	
	for sentence in sentences: 
		for i in range(len(sentence)): 
			center_word = [0 for x in range(V)] 
			center_word[vocab[sentence[i]]] = 1
			context = [0 for x in range(V)] 
			
			for j in range(i-w2v.window_size,i+w2v.window_size): 
				if i!=j and j>=0 and j<len(sentence): 
					context[vocab[sentence[j]]] += 1
			w2v.X_train.append(center_word) 
			w2v.y_train.append(context) 
	w2v.initialize(V,data) 

	return w2v.X_train,w2v.y_train 

In [6]:
corpus = "The earth revolves around the sun. The moon revolves around the earth"
epochs = 1000

In [7]:
training_data = preprocessing(corpus) 
training_data

[['the', 'earth', 'revolves', 'around', 'sun'],
 ['the', 'moon', 'revolves', 'around', 'earth']]

In [8]:
w2v = word2vec() 
prepare_data_for_training(training_data,w2v) 

([[0, 0, 0, 0, 0, 1],
  [0, 1, 0, 0, 0, 0],
  [0, 0, 0, 1, 0, 0],
  [1, 0, 0, 0, 0, 0],
  [0, 0, 0, 0, 1, 0],
  [0, 0, 0, 0, 0, 1],
  [0, 0, 1, 0, 0, 0],
  [0, 0, 0, 1, 0, 0],
  [1, 0, 0, 0, 0, 0],
  [0, 1, 0, 0, 0, 0]],
 [[0, 1, 0, 0, 0, 0],
  [0, 0, 0, 1, 0, 1],
  [1, 1, 0, 0, 0, 1],
  [0, 1, 0, 1, 1, 0],
  [1, 0, 0, 1, 0, 0],
  [0, 0, 1, 0, 0, 0],
  [0, 0, 0, 1, 0, 1],
  [1, 0, 1, 0, 0, 1],
  [0, 1, 1, 1, 0, 0],
  [1, 0, 0, 1, 0, 0]])

In [9]:
w2v.train(epochs) 

epoch  1  loss =  41.0211647242128
epoch  2  loss =  40.960421075990986
epoch  3  loss =  40.9000558289789
epoch  4  loss =  40.840124302842874
epoch  5  loss =  40.780679944717406
epoch  6  loss =  40.721774046689255
epoch  7  loss =  40.663455489636995
epoch  8  loss =  40.60577051645015
epoch  9  loss =  40.548762536966954
epoch  10  loss =  40.49247196624824
epoch  11  loss =  40.43693609707445
epoch  12  loss =  40.38218900683914
epoch  13  loss =  40.32826149833563
epoch  14  loss =  40.275181073316816
epoch  15  loss =  40.22297193716351
epoch  16  loss =  40.17165503253815
epoch  17  loss =  40.12124809953262
epoch  18  loss =  40.071765759546395
epoch  19  loss =  40.02321961994924
epoch  20  loss =  39.97561839649061
epoch  21  loss =  39.92896805040506
epoch  22  loss =  39.883271937220904
epoch  23  loss =  39.838530964398416
epoch  24  loss =  39.79474375509135
epoch  25  loss =  39.75190681553031
epoch  26  loss =  39.710014703759
epoch  27  loss =  39.669060197701945
epo

epoch  237  loss =  37.76041834427195
epoch  238  loss =  37.75883483047827
epoch  239  loss =  37.75726438678846
epoch  240  loss =  37.755706853861874
epoch  241  loss =  37.75416207490815
epoch  242  loss =  37.75262989563705
epoch  243  loss =  37.75111016420933
epoch  244  loss =  37.74960273118883
epoch  245  loss =  37.74810744949574
epoch  246  loss =  37.7466241743607
epoch  247  loss =  37.74515276328031
epoch  248  loss =  37.74369307597324
epoch  249  loss =  37.742244974337744
epoch  250  loss =  37.74080832240983
epoch  251  loss =  37.739382986322575
epoch  252  loss =  37.737968834266205
epoch  253  loss =  37.73656573644925
epoch  254  loss =  37.73517356506032
epoch  255  loss =  37.73379219423105
epoch  256  loss =  37.73242149999954
epoch  257  loss =  37.73106136027488
epoch  258  loss =  37.72971165480226
epoch  259  loss =  37.72837226512896
epoch  260  loss =  37.72704307457111
epoch  261  loss =  37.72572396818103
epoch  262  loss =  37.72441483271539
epoch  26

epoch  487  loss =  37.56615498228638
epoch  488  loss =  37.56577717765567
epoch  489  loss =  37.56540092108792
epoch  490  loss =  37.565026203113725
epoch  491  loss =  37.5646530143405
epoch  492  loss =  37.56428134545181
epoch  493  loss =  37.56391118720652
epoch  494  loss =  37.563542530438035
epoch  495  loss =  37.56317536605361
epoch  496  loss =  37.56280968503356
epoch  497  loss =  37.56244547843053
epoch  498  loss =  37.56208273736884
epoch  499  loss =  37.56172145304365
epoch  500  loss =  37.5613616167204
epoch  501  loss =  37.56100321973394
epoch  502  loss =  37.56064625348805
epoch  503  loss =  37.560290709454584
epoch  504  loss =  37.55993657917288
epoch  505  loss =  37.55958385424907
epoch  506  loss =  37.559232526355466
epoch  507  loss =  37.558882587229846
epoch  508  loss =  37.558534028674835
epoch  509  loss =  37.55818684255736
epoch  510  loss =  37.557841020807814
epoch  511  loss =  37.55749655541972
epoch  512  loss =  37.557153438448836
epoch 

epoch  704  loss =  37.509372353186436
epoch  705  loss =  37.50919179450815
epoch  706  loss =  37.509011748886316
epoch  707  loss =  37.50883221414015
epoch  708  loss =  37.508653188101206
epoch  709  loss =  37.50847466861323
epoch  710  loss =  37.508296653532184
epoch  711  loss =  37.50811914072604
epoch  712  loss =  37.50794212807478
epoch  713  loss =  37.507765613470305
epoch  714  loss =  37.507589594816295
epoch  715  loss =  37.50741407002816
epoch  716  loss =  37.507239037033
epoch  717  loss =  37.507064493769455
epoch  718  loss =  37.50689043818766
epoch  719  loss =  37.50671686824917
epoch  720  loss =  37.50654378192686
epoch  721  loss =  37.50637117720485
epoch  722  loss =  37.50619905207844
epoch  723  loss =  37.506027404554054
epoch  724  loss =  37.505856232649094
epoch  725  loss =  37.50568553439194
epoch  726  loss =  37.5055153078218
epoch  727  loss =  37.50534555098873
epoch  728  loss =  37.50517626195347
epoch  729  loss =  37.5050074387874
epoch  

epoch  932  loss =  37.478257371273
epoch  933  loss =  37.478154491324226
epoch  934  loss =  37.47805183234494
epoch  935  loss =  37.477949393624534
epoch  936  loss =  37.47784717445544
epoch  937  loss =  37.477745174133105
epoch  938  loss =  37.47764339195601
epoch  939  loss =  37.477541827225615
epoch  940  loss =  37.47744047924636
epoch  941  loss =  37.477339347325625
epoch  942  loss =  37.47723843077377
epoch  943  loss =  37.477137728904054
epoch  944  loss =  37.47703724103267
epoch  945  loss =  37.47693696647872
epoch  946  loss =  37.47683690456414
epoch  947  loss =  37.47673705461382
epoch  948  loss =  37.476637415955395
epoch  949  loss =  37.476537987919436
epoch  950  loss =  37.47643876983928
epoch  951  loss =  37.47633976105108
epoch  952  loss =  37.4762409608938
epoch  953  loss =  37.47614236870918
epoch  954  loss =  37.47604398384171
epoch  955  loss =  37.475945805638645
epoch  956  loss =  37.47584783344995
epoch  957  loss =  37.47575006662836
epoch 

In [10]:
print( w2v.predict("around",3), '\n')
print( w2v.predict("around",5), '\n')
print( w2v.predict("sun",3), '\n')
print( w2v.predict("earth",3), '\n')
print( w2v.predict("jupiter",3))

['earth', 'the', 'sun'] 

['earth', 'the', 'sun', 'revolves', 'around'] 

['revolves', 'sun', 'the'] 

['earth', 'around', 'the'] 

Word not found in dicitonary
None
