In [1]:
import torch
import torch.nn as nn

In [2]:
# Toy training corpus: every sentence is three whitespace-separated words.
# The stray trailing spaces in the third entry are harmless because the
# sentences are later tokenised with str.split().
sentences = [
    'We ate apple',
    'They went cafe',
    'She cut pears  ',
    'They went school',
]

In [3]:
# Assign every distinct word an integer id, numbered in order of first appearance.

index = {}   # word -> integer id
inputs = []  # one list of word ids per sentence, in corpus order
for sentence in sentences:
  # setdefault() stores len(index) only when the word has not been seen yet,
  # and in either case returns the id that is now registered for it.
  inputs.append([index.setdefault(token, len(index)) for token in sentence.split()])

print (index)
print (inputs)

{'We': 0, 'ate': 1, 'apple': 2, 'They': 3, 'went': 4, 'cafe': 5, 'She': 6, 'cut': 7, 'pears': 8, 'school': 9}
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [3, 4, 9]]


In [4]:
# Split each sentence's ids into context (everything but the last word) and
# target (the last word). Both become int64 tensors; this requires every
# sentence to contain the same number of words so the context rows line up.
context_ids = [ids[:-1] for ids in inputs]
target_ids = [ids[-1] for ids in inputs]
prefixes = torch.tensor(context_ids, dtype=torch.long)
labels = torch.tensor(target_ids, dtype=torch.long)
prefixes,labels

(tensor([[0, 1],
         [3, 4],
         [6, 7],
         [3, 4]]), tensor([2, 5, 8, 9]))

In [13]:
class NLM(nn.Module):
  """Fixed-window neural language model.

  Each word id in the prefix is looked up in an embedding table, the
  embeddings of the window are concatenated, and the result goes through two
  linear layers (no non-linearity in between) that yield one unnormalised
  score per vocabulary word.
  """

  def __init__ (self,d_embedding, d_hidden, window_size, len_vocab):
    """Create the network's parameters.

    d_embedding -- size of each word embedding
    d_hidden    -- number of units in the hidden layer
    window_size -- number of prefix words fed to the model
    len_vocab   -- number of distinct words (embedding rows and output scores)
    """
    super().__init__()
    self.d_emb = d_embedding
    # word id -> embedding vector
    self.embeddings = nn.Embedding(len_vocab, d_embedding)
    # concatenated window embeddings -> hidden layer
    self.W_hidden = nn.Linear(d_embedding*window_size, d_hidden)
    # hidden layer -> one score per vocabulary word
    self.W_out = nn.Linear(d_hidden, len_vocab)

  def forward (self, input):
    """Score every vocabulary word for a batch of prefixes.

    input -- 2-D tensor of word ids, one row per prefix
    Returns a (rows, len_vocab) tensor of unnormalised scores (logits).
    """
    n_rows, n_context = input.size()
    flat_width = n_context * self.d_emb
    # The lookup yields (rows, window, d_emb); flatten the last two axes so
    # each row holds its window's embeddings laid end to end.
    concatenated = self.embeddings(input).view(n_rows, flat_width)
    # Project to the hidden layer, then on to vocabulary space.
    return self.W_out(self.W_hidden(concatenated))

# Seed PyTorch's RNG before the model is built so that the randomly
# initialised weights -- and therefore the loss values printed below -- are
# identical on every Restart & Run All.
torch.manual_seed(0)

model = NLM(d_embedding = 5, d_hidden=12, window_size=2, len_vocab=len(index))

num_epochs = 150
lr = 0.1
loss_fn = nn.CrossEntropyLoss()   # takes raw logits plus integer class labels
optimizer = torch.optim.SGD(params = model.parameters(), lr =lr)

# Full-batch gradient descent: every epoch uses all prefixes at once.
for i in range(num_epochs):
  logits = model(prefixes)
  loss = loss_fn(logits,labels)

  # step 1: compute gradients (partial derivatives of the loss w.r.t. the parameters)
  loss.backward()
  # step 2: update the parameters with one gradient-descent step
  optimizer.step()
  # step 3: zero the gradients so they do not accumulate into the next epoch
  optimizer.zero_grad()

  # .item() logs a plain Python float instead of the graph-attached tensor
  print (f'epoch {i}, loss {loss.item()}')

epoch 0, loss 2.2702059745788574
epoch 1, loss 2.044121503829956
epoch 2, loss 1.8511338233947754
epoch 3, loss 1.6780940294265747
epoch 4, loss 1.520195484161377
epoch 5, loss 1.376814603805542
epoch 6, loss 1.2485883235931396
epoch 7, loss 1.135582447052002
epoch 8, loss 1.0367567539215088
epoch 9, loss 0.9503534436225891
epoch 10, loss 0.8745297193527222
epoch 11, loss 0.8077791929244995
epoch 12, loss 0.7490330934524536
epoch 13, loss 0.6975516080856323
epoch 14, loss 0.6527555584907532
epoch 15, loss 0.6140956878662109
epoch 16, loss 0.5809864401817322
epoch 17, loss 0.552798867225647
epoch 18, loss 0.5288877487182617
epoch 19, loss 0.5086283683776855
epoch 20, loss 0.49144554138183594
epoch 21, loss 0.4768317937850952
epoch 22, loss 0.464353084564209
epoch 23, loss 0.4536454677581787
epoch 24, loss 0.4444082975387573
epoch 25, loss 0.43639498949050903
epoch 26, loss 0.4294043183326721
epoch 27, loss 0.4232720732688904
epoch 28, loss 0.4178638458251953
epoch 29, loss 0.41306972503

In [14]:
# Ask the trained model which word it expects after the first training prefix.
rev_vocab = {idx: word for word, idx in index.items()}   # id -> word, inverse of `index`
weate = prefixes[0].unsqueeze(0)   # add a leading batch axis: (window,) -> (1, window)
# Prediction only: disable autograd so no computation graph is built for this
# forward pass.
with torch.no_grad():
  logits = model(weate)
probs = nn.functional.softmax(logits, dim=1).squeeze()   # scores -> probabilities over the vocabulary
arg_max = torch.argmax(probs).item()                     # id of the most probable next word
print ('Given "we ate", the model predicts "%s" with %0.4f probability' % (rev_vocab[arg_max],probs[arg_max].item()))

Given "we ate", the model predicts "apple" with 0.9931 probability
