In [238]:
import torch
from torch import nn

import numpy as np

# raw_text = "A Town Mouse once visited a relative who lived in the country. \
# For lunch the Country Mouse served wheat stalks, roots, and \
# acorns, with a dash of cold water for drink. The Town Mouse ate \
# very sparingly, nibbling a little of this and a little of that, \
# and by her manner making it very plain that she ate the simple \
# food only to be polite."

with open("alice.txt", "r") as f:
    raw_text = f.readlines()
    
raw_text = [line.strip() for line in raw_text]
raw_text = " ".join(raw_text).lower()

# charsdot = set(""abcdefghijklmnopqrstuvwxyz. "")
# text = ""
# for char in raw_text.lower():
#     if char in charsdot:
#         text+= char

text = raw_text.split('.')

print(text)

chars = set(raw_text)

# create dictionary, map integer to character
int2char = dict(enumerate(chars))

# create dictionary, map character to integer
char2int = {char: ind for ind, char in int2char.items()}


["alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do:  once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, `and what is the use of a book,' thought alice `without pictures or conversation?'  so she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a white rabbit with pink eyes ran close by her", "  there was nothing so very remarkable in that; nor did alice think it so very much out of the way to hear the rabbit say to itself, `oh dear!  oh dear!  i shall be late!'  (when she thought it over afterwards, it occurred to her that she ought to have wondered at this, but at the time it all seemed quite natural); but when the rabbit actually took a watch out of its waistcoat- pocket, and looked at

In [239]:
# use largest sentence as maxlen
maxlen = 300#len(max(text, key=len))

# pad each sentence to maxlen with ' '
for i in range(len(text)):
    while len(text[i])<maxlen:
        text[i] += ' '

In [240]:
#lists to hold input and target sequences
input_seq = []
target_seq = []

for i in range(len(text)):
    # remove last char of each input
    input_seq.append(text[i][:maxlen-1])
    
    # remove first char of each target
    target_seq.append(text[i][1:maxlen])
    print("Input Sequence: {}\nTarget Sequence: {}".format(input_seq[i], target_seq[i]))

Input Sequence: alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do:  once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, `and what is the use of a book,' thought alice `without pictures or conversatio
Target Sequence: lice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do:  once or twice she had peeped into the book her sister was reading, but it had no pictures or conversations in it, `and what is the use of a book,' thought alice `without pictures or conversation
Input Sequence:   there was nothing so very remarkable in that; nor did alice think it so very much out of the way to hear the rabbit say to itself, `oh dear!  oh dear!  i shall be late!'  (when she thought it over afterwards, it occurred to her that she ought to have wondered at this, but at the time it all seeme
Target Sequence:  there was nothing so very remarka

In [241]:
for i in range(len(text)):
    input_seq[i] = [char2int[character] for character in input_seq[i]]
    target_seq[i] = [char2int[character] for character in target_seq[i]]

In [242]:
dict_size = len(char2int)
seq_len = maxlen - 1
batch_size = len(text)

def one_hot_encode(sequence, dict_size, seq_len, batch_size):
    # Creating a multi-dimensional array of zeros with the desired output shape
    features = np.zeros((batch_size, seq_len, dict_size), dtype=np.float32)
    
    # Replacing the 0 at the relevant character index with a 1 to represent that character
    for i in range(batch_size):
        for u in range(seq_len):
            features[i, u, sequence[i][u]] = 1
    return features

In [243]:
input_seq = one_hot_encode(input_seq, dict_size, seq_len, batch_size)

In [244]:
input_seq = torch.from_numpy(input_seq)
target_seq = torch.Tensor(target_seq)

In [245]:
# torch.cuda.is_available() checks and returns a Boolean True if a GPU is available, else it'll return False
is_cuda = torch.cuda.is_available()

# If we have a GPU available, we'll set our device to GPU. We'll use this device variable later in our code.
if is_cuda:
    device = torch.device("cuda")
    print("GPU is available")
else:
    device = torch.device("cpu")
    print("GPU not available, CPU used")

GPU not available, CPU used


In [246]:
class Model(nn.Module):
    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super(Model, self).__init__()
        
        #params
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        #Degine layers
        #  RNN Layer
        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)
        #  Fully connected layer
        self.fc = nn.Linear(hidden_dim, output_size)
        
    def forward(self, x):
        batch_size = x.size(0)
        
        #init hidden state for first input
        hidden = self.init_hidden(batch_size)
        
        out, hidden = self.rnn(x,hidden)
        
        out = out.contiguous().view(-1, self.hidden_dim)
        out = self.fc(out)
        
        return out, hidden
    
    def init_hidden(self, batch_size):
        #generates first hidden state of zeros for forward pass
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim)
        return hidden

In [247]:
model = Model(input_size=dict_size, output_size=dict_size, hidden_dim=12, n_layers=1)
model.to(device)

n_epochs = 500
lr=0.01

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [248]:
#training
for epoch in range(1, n_epochs + 1):
    optimizer.zero_grad()
    input_seq.to(device)
    output, hidden = model(input_seq)
    loss = criterion(output, target_seq.view(-1).long())
    loss.backward()
    optimizer.step()
    
    if epoch%100 == 0:
        print('Epoch: {}/{}.............'.format(epoch, n_epochs), end=' ')
        print("Loss: {:.4f}".format(loss.item()))

Epoch: 100/500............. Loss: 1.8671
Epoch: 200/500............. Loss: 1.3932
Epoch: 300/500............. Loss: 1.2511
Epoch: 400/500............. Loss: 1.1877
Epoch: 500/500............. Loss: 1.1501


In [249]:
# This function takes in the model and character as arguments and returns the next character prediction and hidden state
def predict(model, character):
    # One-hot encoding our input to fit into the model
    character = np.array([[char2int[c] for c in character]])
    character = one_hot_encode(character, dict_size, character.shape[1], 1)
    character = torch.from_numpy(character)
    character.to(device)
    
    out, hidden = model(character)

    prob = nn.functional.softmax(out[-1], dim=0).data
    # Taking the class with the highest probability score from the output
    char_ind = torch.max(prob, dim=0)[1].item()

    return int2char[char_ind], hidden


In [250]:
# This function takes the desired output length and input characters as arguments, returning the produced sentence
def sample(model, out_len, start='hey'):
    model.eval() # eval mode
    start = start.lower()
    # First off, run through the starting characters
    chars = [ch for ch in start]
    size = out_len - len(chars)
    # Now pass in the previous characters and get a new one
    for ii in range(size):
        char, h = predict(model, chars)
        chars.append(char)
    
    return ''.join(chars)

In [261]:
sample(model, 300, 'e')

'er all the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the t'