In [1]:
import torch
import torch.nn as nn
import string
import random
import time
import math
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker


In [2]:
# Training corpus: random slices of this text are used for next-character prediction.
lorem_ipsum = (
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
    "Duis quis eros tincidunt, dapibus mi in, tincidunt sapien. "
    "Vivamus massa nunc, finibus condimentum mattis id, iaculis a nibh. "
    "Aenean quis convallis diam. "
    "In hac habitasse platea dictumst. "
    "Quisque lacinia convallis."
)

# Vocabulary: every character of string.printable; a character's position in
# this string is its one-hot index.
all_characters = string.printable
n_characters = len(all_characters)

# Helpers to one-hot encode single characters and whole strings as tensors

def char_to_tensor(char):
    """One-hot encode a single character as a float tensor of shape (1, n_characters).

    Raises ValueError (from ``str.index``) if ``char`` is not in ``all_characters``.
    """
    position = all_characters.index(char)
    encoded = torch.zeros(1, n_characters)
    encoded[0, position] = 1
    return encoded

def line_to_tensor(line):
    """One-hot encode ``line`` as a float tensor of shape (len(line), 1, n_characters).

    Step ``t`` of the result, ``line_to_tensor(line)[t]``, is a (1, n_characters)
    slice matching what ``char_to_tensor(line[t])`` would produce.
    Raises ValueError if any character is outside ``all_characters``.
    """
    encoded = torch.zeros(len(line), 1, n_characters)
    for step, symbol in enumerate(line):
        encoded[step, 0, all_characters.index(symbol)] = 1
    return encoded


In [3]:
class RNN(nn.Module):
    """Single-step recurrent cell emitting log-probabilities over output classes.

    Each call consumes one input step plus the previous hidden state and returns
    ``(log_probs, new_hidden)``. The hidden update is a plain linear map of the
    concatenated ``[input, hidden]`` (no nonlinearity is applied).

    Args:
        input_size: width of each input vector (``n_characters`` in this notebook).
        hidden_size: width of the hidden state.
        output_size: number of output classes.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        joint_size = input_size + hidden_size
        # Sub-modules are created in this exact order so seeded weight
        # initialisation and state_dict keys stay the same.
        self.i2h = nn.Linear(joint_size, hidden_size)
        self.i2o = nn.Linear(joint_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance the cell by one step.

        Args:
            input: concatenated with ``hidden`` along dim 1, so presumably
                (1, input_size) as produced by ``char_to_tensor`` — TODO confirm.
            hidden: previous hidden state, (1, hidden_size) from ``initHidden``.

        Returns:
            (log_probs, new_hidden): log_probs is (1, output_size) after dropout
            and LogSoftmax; new_hidden is (1, hidden_size).
        """
        stacked = torch.cat([input, hidden], dim=1)
        new_hidden = self.i2h(stacked)
        logits = self.o2o(torch.cat([new_hidden, self.i2o(stacked)], dim=1))
        return self.softmax(self.dropout(logits)), new_hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape (1, hidden_size)."""
        return torch.zeros(1, self.hidden_size)


In [4]:
def random_training_set():
    """Draw a random (input, target) pair from ``lorem_ipsum`` for next-char prediction.

    The input is a random non-empty slice ``lorem_ipsum[start:end]``; the target
    is the same slice shifted one character to the right, so input step ``i``
    should predict target step ``i``.

    Returns:
        input_line_tensor: one-hot float tensor, shape (seq_len, 1, n_characters).
        target_line_tensor: ``torch.long`` indices into ``all_characters``, shape
            (seq_len,). ``train`` adds a trailing axis so each step's target is
            shape (1,), the class-index form ``nn.NLLLoss`` expects. (One-hot
            float targets are rejected by NLLLoss.)
    """
    # end_index >= start_index + 1 guarantees at least one step (an empty slice
    # would leave train() with nothing to back-propagate), and
    # end_index <= len - 1 keeps the shifted target slice in range.
    start_index = random.randint(0, len(lorem_ipsum) - 2)
    end_index = random.randint(start_index + 1, len(lorem_ipsum) - 1)
    line = lorem_ipsum[start_index:end_index]
    input_line_tensor = line_to_tensor(line)
    target_chars = lorem_ipsum[start_index + 1:end_index + 1]
    target_line_tensor = torch.tensor(
        [all_characters.index(char) for char in target_chars], dtype=torch.long
    )
    return input_line_tensor, target_line_tensor


# Function to perform one training step

def train(input_line_tensor, target_line_tensor, model=None, loss_fn=None, opt=None):
    """Run one optimisation step over a single character sequence.

    Args:
        input_line_tensor: per-step model inputs, indexed as ``input_line_tensor[i]``
            (``line_to_tensor`` yields shape (seq_len, 1, n_characters)).
        target_line_tensor: per-step targets; a trailing axis is added so step
            ``i`` is ``target[i].unsqueeze(-1)``. For ``nn.NLLLoss`` this should be
            long class indices of shape (seq_len,) — TODO confirm for other losses.
        model: recurrent model with ``initHidden()``; defaults to global ``rnn``.
        loss_fn: per-step loss; defaults to global ``criterion``.
        opt: optimizer over ``model``'s parameters; defaults to global ``optimizer``.

    Returns:
        (output, avg_loss): the model output of the final step and the summed
        step losses divided by the sequence length.

    Raises:
        ValueError: if the input sequence is empty (nothing to back-propagate).
    """
    model = rnn if model is None else model
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    seq_len = input_line_tensor.size(0)
    if seq_len == 0:
        raise ValueError("cannot train on an empty sequence")

    # Out-of-place unsqueeze: give each step's target a batch axis of size 1
    # without mutating the caller's tensor.
    targets = target_line_tensor.unsqueeze(-1)

    hidden = model.initHidden()
    model.zero_grad()
    loss = 0
    for i in range(seq_len):
        output, hidden = model(input_line_tensor[i], hidden)
        loss = loss + loss_fn(output, targets[i])
    loss.backward()
    opt.step()
    return output, loss.item() / seq_len


NameError: name 'rnn' is not defined

In [None]:
def generate(decoder, prime_str='A', predict_len=100, temperature=0.8):
    """Sample ``predict_len`` characters from ``decoder`` after priming it.

    Args:
        decoder: model with ``initHidden()`` and ``forward(input, hidden)`` returning
            ``(log_probs, hidden)`` (as the ``RNN`` class in this notebook does).
        prime_str: non-empty seed text. All but its last character only warm up
            the hidden state; the last one is the first input sampled from.
        predict_len: number of characters to sample.
        temperature: positive divisor applied to the log-probabilities before
            sampling; lower is greedier, higher is more random.

    Returns:
        ``prime_str`` followed by the sampled characters.

    Raises:
        ValueError: if ``prime_str`` is empty or ``temperature`` is not positive.
    """
    if not prime_str:
        raise ValueError("prime_str must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Sample in eval mode (dropout off) without recording autograd history,
    # then restore whatever mode the caller had (training continues afterwards).
    was_training = decoder.training
    decoder.eval()
    try:
        with torch.no_grad():
            hidden = decoder.initHidden()
            prime_input = line_to_tensor(prime_str)
            predicted = prime_str

            for p in range(len(prime_str) - 1):
                _, hidden = decoder(prime_input[p], hidden)
            inp = prime_input[-1]

            for _ in range(predict_len):
                output, hidden = decoder(inp, hidden)

                # exp(log_p / T) gives unnormalised weights, which
                # torch.multinomial accepts. Subtracting the max first keeps
                # the largest weight at 1, so a small T cannot underflow every
                # weight to zero.
                scaled = output.view(-1) / temperature
                output_dist = (scaled - scaled.max()).exp()
                top_i = torch.multinomial(output_dist, 1).item()

                # Add predicted character to string and use as next input
                predicted_char = all_characters[top_i]
                predicted += predicted_char
                inp = char_to_tensor(predicted_char)
    finally:
        decoder.train(was_training)

    return predicted

# Training process
n_epochs = 5000
print_every = 500
plot_every = 10
hidden_size = 100
n_layers = 1  # not used: the RNN class above has a single recurrent layer
lr = 0.005
seed = 0

# Seed both RNGs used below: `random` picks the training slices, torch
# initialises weights and drives dropout and sampling.
random.seed(seed)
torch.manual_seed(seed)

rnn = RNN(n_characters, hidden_size, n_characters)
# `train` reads the globals `criterion` and `optimizer`. The model ends in
# LogSoftmax, so the matching loss is NLLLoss (which takes class-index targets).
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)
start = time.time()
all_losses = []
current_loss = 0

for epoch in range(1, n_epochs + 1):
    output, loss = train(*random_training_set())
    current_loss += loss

    if epoch % print_every == 0:
        print(f'time: {time.time() - start} | epoch: {epoch} ({epoch / n_epochs * 100}%) | loss: {loss}')
        print(generate(rnn, 'Li', 200), '\n')

    # Record the mean loss over each window of `plot_every` iterations.
    if epoch % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0


In [None]:
# Mean training loss per window of `plot_every` iterations.
fig, ax = plt.subplots()
ax.plot(all_losses)
ax.set(title='Training loss', xlabel=f'iteration (x{plot_every})', ylabel='average loss')
plt.show()
