In [1]:
import numpy as np

In [11]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import os

### Model

In [75]:
class RNN(nn.Module):
    """Character-level recurrent model: embedding -> GRU -> linear projection.

    Args:
        input_size: number of rows in the embedding table (vocabulary size).
        hidden_size: width of the embedding and of the GRU hidden state.
        output_size: number of scores produced per step.
        n_layers: number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super().__init__()
        # Keep the constructor arguments around; init_hidden() reads two of them.
        self.input_size, self.hidden_size = input_size, hidden_size
        self.output_size, self.n_layers = output_size, n_layers

        # Sub-modules are registered in a fixed order (encoder, gru, decoder),
        # which also fixes the order in which their weights are initialised.
        self.encoder = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        """Advance the model by one step.

        The reshapes below collapse everything into a single time step with a
        batch of one, so `input` is expected to hold a single index.

        Returns:
            (scores, hidden): scores reshaped to (1, output_size) and the
            hidden state returned by the GRU.
        """
        embedded = self.encoder(input.view(1, -1))
        gru_out, next_hidden = self.gru(embedded.view(1, 1, -1), hidden)
        scores = self.decoder(gru_out.view(1, -1))
        return scores, next_hidden

    def init_hidden(self):
        """Return an all-zero hidden state of shape (n_layers, 1, hidden_size)."""
        state_shape = (self.n_layers, 1, self.hidden_size)
        return Variable(torch.zeros(*state_shape))

### Utils

In [117]:
import unidecode
import string
import random
import time
import math
import torch
from torch.autograd import Variable

# Vocabulary: the characters of string.printable. char_tensor() encodes a
# character as its position in this string and evaluate() decodes with the
# same indexing, so the order must not change between training and sampling.
all_characters = string.printable
# Vocabulary size; used as both input_size and output_size when the RNN is built.
n_characters = len(all_characters)

def read_file(filename):
    """Read a text file and return its unidecode-processed content and length.

    Args:
        filename: path of the text file to read (opened in text mode with the
            platform's default encoding).

    Returns:
        (text, length): the file content after `unidecode.unidecode`, and
        `len(text)`.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    # The context manager closes the handle deterministically; a bare
    # open(...).read() leaves that to the garbage collector.
    with open(filename) as handle:
        raw_text = handle.read()

    # NOTE(review): presumably this transliterates non-ASCII characters so the
    # text stays within string.printable — confirm against the unidecode docs.
    text = unidecode.unidecode(raw_text)
    return text, len(text)

def char_tensor(string):
    """Encode a string as a 1-D long tensor of vocabulary indices.

    Each character is mapped to its position in `all_characters`; a character
    that is not part of the vocabulary raises ValueError (from `str.index`).
    """
    indices = [all_characters.index(symbol) for symbol in string]
    encoded = torch.tensor(indices, dtype=torch.long)
    return Variable(encoded)

def elapsed(start):
    """Format the wall-clock time since `start` as '<minutes>m <seconds>s'.

    `start` is a timestamp as returned by time.time(). Minutes are whole
    numbers; the remaining seconds are printed unrounded.
    """
    total_seconds = time.time() - start
    whole_minutes = math.floor(total_seconds / 60)
    leftover_seconds = total_seconds - whole_minutes * 60
    return '{}m {}s'.format(whole_minutes, leftover_seconds)

def random_chunk(size):
    """Return a random window of exactly `size + 1` characters from `file`.

    The extra character lets the caller build an input sequence and a target
    sequence (shifted by one) that both have length `size`.

    Args:
        size: desired length of the input/target sequences.

    Returns:
        A substring of the module-level `file` with `size + 1` characters.

    Raises:
        ValueError: if the corpus holds fewer than `size + 1` characters.
    """
    # random.randint is inclusive on both ends, so the last valid start is
    # file_len - (size + 1). Allowing file_len - size (as before) produced a
    # window one character short whenever the maximum start was drawn.
    last_start = file_len - size - 1
    if last_start < 0:
        raise ValueError(
            'corpus of {} characters is too short for a chunk of size {}'.format(file_len, size)
        )
    start_index = random.randint(0, last_start)
    end_index = start_index + size + 1
    return file[start_index:end_index]

def random_training_set(size=200, verbose=False):
    """Draw one random (input, target) pair of encoded character sequences.

    A chunk is taken from random_chunk(size); the input is the chunk without
    its last character and the target is the chunk without its first, so the
    target at each position is the character that follows the input there.
    With verbose=True the raw chunk is printed first.
    """
    text = random_chunk(size)
    if verbose:
        print(text)
    source_chars, shifted_chars = text[:-1], text[1:]
    return char_tensor(source_chars), char_tensor(shifted_chars)

In [13]:
# Location of the training corpus. The relative path is resolved against the
# current working directory at the time this cell runs.
data_path = os.path.abspath('_data/tiny-shakespeare.txt')

In [88]:
# Load the whole corpus into memory. `file` and `file_len` are module-level
# names that random_chunk() reads directly.
file, file_len = read_file(data_path)
print('file_len =', file_len)
# Show the first 100 characters as a quick sanity check of the loaded text.
print(file[:100])

file_len = 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [89]:
# Spot-check: display one random window of the corpus requested with size=200.
random_chunk(200)

"d learn this lesson, draw thy sword in right.\n\nPRINCE:\nMy gracious father, by your kingly leave,\nI'll draw it as apparent to the crown,\nAnd in that quarrel use it to the death.\n\nCLIFFORD:\nWhy, that is "

In [91]:
# Spot-check: display one encoded (input, target) pair; the target is the
# input shifted by one character.
random_training_set(50)

(tensor([ 18,  35,  14,  94,  29,  32,  24,  94,  22,  14,  10,  23,
          18,  23,  16,  28,  94,  18,  23,  94,  24,  23,  14,  94,
          32,  24,  27,  13,  75,  96,  96,  51,  53,  44,  49,  38,
          40,  94,  40,  39,  58,  36,  53,  39,  77,  96,  55,  17,
          10,  29]),
 tensor([ 35,  14,  94,  29,  32,  24,  94,  22,  14,  10,  23,  18,
          23,  16,  28,  94,  18,  23,  94,  24,  23,  14,  94,  32,
          24,  27,  13,  75,  96,  96,  51,  53,  44,  49,  38,  40,
          94,  40,  39,  58,  36,  53,  39,  77,  96,  55,  17,  10,
          29,  94]))

In [105]:
def evaluate(prime_str='A', predict_len=100, temperature=0.8):
    """Generate text from the module-level `decoder`, one character at a time.

    Args:
        prime_str: non-empty seed text. All but its last character are fed
            through the model to build up the hidden state; the last character
            is the first input of the sampling loop.
        predict_len: number of characters to sample after `prime_str`.
        temperature: divisor applied to the output scores before
            exponentiation; lower values sharpen the sampling distribution.

    Returns:
        `prime_str` followed by `predict_len` sampled characters.

    Raises:
        ValueError: if `prime_str` is empty (there is no character to start
            sampling from).
    """
    if not prime_str:
        raise ValueError('prime_str must contain at least one character')

    prime_input = char_tensor(prime_str)
    predicted = prime_str

    # Sampling never backpropagates. Without no_grad() every step would extend
    # an autograd graph through the carried hidden state for no benefit.
    with torch.no_grad():
        hidden = decoder.init_hidden()

        # Use priming string to "build up" hidden state
        for p in range(len(prime_str) - 1):
            _, hidden = decoder(prime_input[p], hidden)
        inp = prime_input[-1]

        for _ in range(predict_len):
            output, hidden = decoder(inp, hidden)

            # Unnormalised weights exp(score / temperature); torch.multinomial
            # normalises them itself when drawing the sample.
            output_dist = output.view(-1).div(temperature).exp()
            top_i = torch.multinomial(output_dist, 1).item()

            # Add predicted character to string and use as next input
            predicted_char = all_characters[top_i]
            predicted += predicted_char
            inp = char_tensor(predicted_char)

    return predicted

In [112]:
def train(inp, target):
    """Run one optimisation step on a single (input, target) chunk.

    Uses the module-level `decoder`, `decoder_optimizer`, `criterion` and
    `chunk_len`. The chunk is fed one character at a time; the per-character
    losses are summed, backpropagated once, and the optimizer takes one step.

    Args:
        inp: 1-D tensor of input character indices with at least `chunk_len`
            entries.
        target: 1-D tensor of target character indices, aligned with `inp`.

    Returns:
        The mean loss per character for this chunk, as a Python float.
    """
    hidden = decoder.init_hidden()
    decoder_optimizer.zero_grad()
    loss = 0

    for c in range(chunk_len):
        output, hidden = decoder(inp[c], hidden)
        # A one-element slice gives the (1,)-shaped class-index batch that
        # matches the (1, n_classes) output, with no numpy round-trip.
        loss += criterion(output, target[c:c + 1])

    loss.backward()
    decoder_optimizer.step()

    # The summed loss is a 0-dim tensor; indexing it with [0] raises on
    # current PyTorch, so extract the scalar with .item().
    return loss.item() / chunk_len

In [118]:
# Hyperparameters for the run. "hidden_size" used to appear twice in this
# literal (50, then 100); Python keeps only the last value, so the effective
# setting was always 100 and the dead entry has been removed.
args = {
    "hidden_size": 100,   # width of the embedding and GRU hidden state
    "n_layers": 2,        # stacked GRU layers
    "lr": 0.005,          # Adam learning rate
    "n_epochs": 2000,     # number of training steps (one random chunk each)
    "print_every": 100,   # steps between progress reports and text samples
    "plot_every": 10,     # steps averaged into each all_losses entry
    "chunk_len": 200,     # characters per training chunk
}

# Module-level aliases: train() reads chunk_len, the training loop reads both.
n_epochs = args["n_epochs"]
chunk_len = args["chunk_len"]

# The same vocabulary is used on both ends: character ids in, one score per
# character out.
decoder = RNN(n_characters, args["hidden_size"], n_characters, args["n_layers"])
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args["lr"])
criterion = nn.CrossEntropyLoss()

# Training loop. Each "epoch" here is a single optimisation step on one
# randomly drawn chunk, not a full pass over the corpus.
start = time.time()
all_losses = []  # one averaged loss per `plot_every` steps
loss_avg = 0     # running sum of losses since the last all_losses entry

for epoch in range(1, n_epochs + 1):
    loss = train(*random_training_set(chunk_len))
    loss_avg += loss

    # Progress report: elapsed time, step number, percent done and the loss of
    # the latest step (not the running average), followed by a 100-character
    # sample primed with 'Wh'.
    if epoch % args["print_every"] == 0:
        print('[%s (%d %d%%) %.4f]' % (elapsed(start), epoch, epoch / n_epochs * 100, loss))
        print(evaluate('Wh', 100), '\n')

    # Record the mean loss over the last `plot_every` steps and reset the sum.
    if epoch % args["plot_every"] == 0:
        all_losses.append(loss_avg / args["plot_every"])
        loss_avg = 0

  


[0m 12.960622549057007s (100 5%) 2.2937]
Wh, gin aut gerfert Lor,.

 lis ligh reto the soneen.

IS
DTESEELNENIESSE::
Hort shhe enout to thard b 

[0m 26.145856618881226s (200 10%) 2.1413]
Wh's and of hous in my krille ante for rone, the ofrave and and sidist nord shorlene; him cret me beng 

[0m 39.15329909324646s (300 15%) 2.2238]
Wh nut gathy for a the as busirding and mis ar you sests sand!

LULUTENTARD IINTIO Dento the an nony b 

[0m 52.09091377258301s (400 20%) 2.0840]
Whous theere'd
We thime the hear,
Whor by his ther hares
Wither thereing she fiir, im-Buth gitt
Thaf m 

[1m 5.035867691040039s (500 25%) 1.8629]
What the liettlerd,
And thou shis caignomy Il beards! 'tet yous buders.

MENCENTIO:
And thy you loft o 

[1m 18.007661819458008s (600 30%) 2.0541]
Wh encerd'd as the rujdor myre gruntle:
Wherjod hin now Thy my spollal.

KENANIUSS:
Wourth faul god; h 

[1m 31.101768732070923s (700 35%) 1.9145]
Where and mans;
Brate thave in besif not lade the his nace
Be for be erp is and 