In [67]:
from io import open
import os
import torch
import torch.nn as nn
import torch.optim as optim
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

In [73]:
def dataset(filename, m_common=3050):
    """Build one-hot next-word training pairs from a UTF-8 text file.

    The text is split into sentences with ``sent_tokenize`` and each sentence
    into words with ``word_tokenize``. Every sentence is wrapped as
    ``['<START>'] + words + ['<END>']``. The vocabulary, in this order, is:

    * ``'<START>'``
    * the ``m_common`` most frequent corpus words
    * ``'<UNKNOWN>'``
    * ``'<END>'``

    Any word outside the vocabulary is mapped to ``'<UNKNOWN>'``. As a side
    effect, the number of distinct corpus tokens is printed.

    Args:
        filename: Path of the text file (read as UTF-8).
        m_common: How many of the most frequent corpus words to keep.

    Returns:
        A tuple ``(input_data, output_data, n_words, n_tokens)``:

        * ``input_data``: one float tensor per sentence, of shape
          ``(len(words) + 1, n_tokens)``. Row i one-hot encodes token i of
          the wrapped sentence (the final ``'<END>'`` is excluded).
        * ``output_data``: one list per sentence, of ``[index]`` single-element
          lists. Entry i is the vocabulary index of token i + 1, i.e. the
          next-word target for row i.
        * ``n_words``: the vocabulary list (position == index).
        * ``n_tokens``: ``len(n_words)``.
    """
    with open(filename, encoding='utf-8') as f:
        text = f.read()

    corpus_words = []
    pro_sentences = []
    for sentence in sent_tokenize(text):
        words = word_tokenize(sentence)
        corpus_words += words
        pro_sentences.append(['<START>'] + words + ['<END>'])

    fdist1 = nltk.FreqDist(corpus_words)
    print(len(fdist1))  # number of distinct tokens in the corpus
    n_words = (['<START>']
               + [word for (word, _freq) in fdist1.most_common(m_common)]
               + ['<UNKNOWN>', '<END>'])
    n_tokens = len(n_words)

    # O(1) word -> index lookup. setdefault keeps the FIRST index of a word,
    # matching what list.index returned.
    word_to_index = {}
    for idx, word in enumerate(n_words):
        word_to_index.setdefault(word, idx)
    unknown_index = word_to_index['<UNKNOWN>']

    input_data = []
    output_data = []
    for sentence in pro_sentences:
        indices = [word_to_index.get(word, unknown_index) for word in sentence]
        # Always >= 1, because every sentence carries <START> and <END>.
        steps = len(indices) - 1
        indata = torch.zeros(steps, n_tokens)
        # One-hot encode tokens 0..steps-1 in a single vectorised assignment.
        indata[torch.arange(steps), torch.tensor(indices[:-1], dtype=torch.long)] = 1
        input_data.append(indata)
        output_data.append([[idx] for idx in indices[1:]])
    return input_data, output_data, n_words, n_tokens


input_tensors, target_tensors, word_list, n_tokens = dataset('./shakespeare.txt')

3063


In [74]:
class RNN(nn.Module):
    """Simple recurrent cell that emits log-probabilities over the vocabulary.

    At each step the input is concatenated with the previous hidden state.
    Two linear layers produce a new hidden state and a preliminary output.
    A third linear layer mixes those two, followed by log-softmax.

    Shapes accepted:

    * Unbatched step: ``input`` of shape ``(input_size,)`` and ``hidden`` of
      shape ``(hidden_size,)``.
    * Batch: ``(B, input_size)`` and ``(B, hidden_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        # NOTE(review): defined but never applied in forward(). It is kept so
        # that code accessing `rnn.dropout` keeps working.
        self.dropout = nn.Dropout(0.2)
        # Explicit dim: the implicit-dim form is deprecated and warns on every
        # call. For 1-D input, dim=-1 equals the old implicit dim 0.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input, hidden):
        """Run one recurrence step.

        Args:
            input: Input features, shape ``(input_size,)`` or ``(B, input_size)``.
            hidden: Previous hidden state, shape ``(hidden_size,)`` or
                ``(B, hidden_size)``.

        Returns:
            ``(log_probs, hidden)``:

            * ``log_probs`` always has shape ``(batch, output_size)``; an
              unbatched step yields batch == 1.
            * ``hidden`` has the same shape as the ``hidden`` argument.
        """
        input_combined = torch.cat((input, hidden), dim=-1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        output = self.o2o(torch.cat((hidden, output), dim=-1))
        output = self.softmax(output)
        # Keep a leading batch dimension so NLLLoss receives (N, C).
        return output.view(-1, output.size(-1)), hidden

    def initHidden(self, batch_size=None):
        """Return a zero hidden state.

        The shape is ``(hidden_size,)`` when ``batch_size`` is None, otherwise
        ``(batch_size, hidden_size)``.
        """
        if batch_size is None:
            return torch.zeros(self.hidden_size)
        return torch.zeros(batch_size, self.hidden_size)


In [75]:
# Training configuration.
SEED = 0             # reproducible weight initialisation
HIDDEN_SIZE = 200
LEARNING_RATE = 1e-4

torch.manual_seed(SEED)
# Use the GPU when one is available instead of hard-requiring CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

criterion = nn.NLLLoss()
rnn = RNN(n_tokens, HIDDEN_SIZE, n_tokens).to(device)
optimizer = optim.Adam(rnn.parameters(), lr=LEARNING_RATE)
n_epochs = 120
n_samples = len(input_tensors)

In [76]:
# Train one sentence at a time:
#   1. unroll the RNN word by word;
#   2. sum the per-step NLL losses;
#   3. take one optimiser step per sentence.
# The reported epoch loss is the mean over sentences of the per-word loss.
model_device = next(rnn.parameters()).device  # follow wherever the model lives
rnn.train()
for epoch in range(n_epochs):
    epoch_loss = 0.0
    for sentence_tensor, sentence_targets in zip(input_tensors, target_tensors):
        optimizer.zero_grad()
        sen_length = len(sentence_tensor)
        inputs = sentence_tensor.to(model_device)
        # Shape (sen_length, 1): row j holds the vocabulary index of word j + 1.
        targets = torch.tensor(sentence_targets, dtype=torch.long, device=model_device)
        hidden = rnn.initHidden().to(model_device)
        sen_loss = 0
        for word_tensor, target in zip(inputs, targets):
            output, hidden = rnn(word_tensor, hidden)
            sen_loss = sen_loss + criterion(output, target)
        sen_loss.backward()
        optimizer.step()
        # .item() detaches. Accumulating the tensor itself would keep every
        # sentence's autograd graph alive for the whole epoch.
        epoch_loss += sen_loss.item() / sen_length
    epoch_loss = epoch_loss / max(n_samples, 1)
    print("Loss at epoch {0} is ".format(epoch), float(epoch_loss))



Loss at epoch 0 is  6.688699245452881
Loss at epoch 1 is  6.0889058113098145
Loss at epoch 2 is  5.928800582885742
Loss at epoch 3 is  5.781315326690674
Loss at epoch 4 is  5.635859489440918
Loss at epoch 5 is  5.485524654388428
Loss at epoch 6 is  5.32078218460083
Loss at epoch 7 is  5.146284103393555
Loss at epoch 8 is  4.970675468444824
Loss at epoch 9 is  4.793716907501221
Loss at epoch 10 is  4.612541198730469
Loss at epoch 11 is  4.430445671081543
Loss at epoch 12 is  4.249372959136963
Loss at epoch 13 is  4.0687079429626465
Loss at epoch 14 is  3.8952548503875732
Loss at epoch 15 is  3.729579210281372
Loss at epoch 16 is  3.5515048503875732
Loss at epoch 17 is  3.373509645462036
Loss at epoch 18 is  3.2023913860321045
Loss at epoch 19 is  3.0350754261016846
Loss at epoch 20 is  2.8728411197662354
Loss at epoch 21 is  2.7147603034973145
Loss at epoch 22 is  2.5625600814819336
Loss at epoch 23 is  2.4194436073303223
Loss at epoch 24 is  2.287484884262085
Loss at epoch 25 is  2.173

In [84]:
SEED_WORD = 'which'
MAX_GEN_LEN = 100  # hard stop: greedy decoding is not guaranteed to emit <END>


def generate(model, vocab, seed_word, max_len=MAX_GEN_LEN):
    """Greedily decode words from ``model``, starting after ``seed_word``.

    The model is run on whatever device its parameters live on. No gradients
    are tracked, and the model's train/eval mode is restored afterwards.

    Args:
        model: Module with ``initHidden()`` and ``forward(one_hot, hidden)``
            returning ``(log_probs[1, V], hidden)``.
        vocab: Vocabulary list; position == token index.
        seed_word: First input word. It is not included in the result.
        max_len: Maximum number of words to generate.

    Returns:
        The generated words. The list ends with ``'<END>'`` unless
        ``max_len`` was reached first.

    Raises:
        ValueError: If ``seed_word`` or ``'<END>'`` is not in ``vocab``.
    """
    end_index = vocab.index('<END>')
    index = vocab.index(seed_word)
    if index == end_index:
        return []
    device = next(model.parameters()).device
    n_vocab = len(vocab)
    words = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            hidden = model.initHidden().to(device)
            for _ in range(max_len):
                token = torch.zeros(n_vocab, device=device)
                token[index] = 1
                output, hidden = model(token, hidden)
                index = int(torch.argmax(output[0]))
                words.append(vocab[index])
                if index == end_index:
                    break
    finally:
        model.train(was_training)
    return words


output_sentence = generate(rnn, word_list, SEED_WORD)
print(" ".join(output_sentence))

thou art too dear for To possessing , And like enough thou know'st thy estimate , The charter of thy worth gives thee releasing : My bonds in thee are all determinate . <END>


