In [None]:
import os
import sys; sys.path.append('../lib')
from functools import partial

import matplotlib.pyplot as plt
import numpy as np

from data import Text
from gradients import compare_gradients_recurrent
from history import TrainHistoryRecurrent
from recurrent_network import RecurrentNetwork

# Constants

In [None]:
# Relative locations of the input corpus, pickled training histories and saved figures.
DATA_DIR, PICKLE_DIR, FIGURE_DIR = '../data', '../pickle', '../figures'

# Settings consumed by the RecurrentNetwork construction / training cells below.
HYPERPARAMS = dict(
    hidden_state_size=100,  # passed to RecurrentNetwork(hidden_state_size=...)
    sequence_length=25,     # passed to train(seq_length=...) and used to slice the gradient-check sequence
    eta=0.1,                # passed to train(eta=...); presumably the learning rate
    updates=100_000,        # passed to train(n_updates=...)
)

# Load data

In [None]:
# Build the Text corpus object from the book file in DATA_DIR; later cells use it
# for the character count, training sequences and decoding synthesized output.
BOOK_FILENAME = 'goblet_book.txt'
text = Text.from_file(DATA_DIR, BOOK_FILENAME)

# Compare analytical and numerical gradients

In [None]:
# Gradient check runs on a deliberately small network (presumably to keep the
# numerical gradient computation cheap); `h` is presumably the finite-difference step.
GRADCHECK_HIDDEN_SIZE = 5
GRADCHECK_STEP = 1e-4

# First `sequence_length` characters of the corpus, as labeled one-hot data.
ds = text.sequence(
    beg=0,
    end=HYPERPARAMS['sequence_length'],
    rep='indices_one_hot',
    labeled=True,
)

# Factory so the comparison helper can build fresh networks itself.
network_constructor = partial(
    RecurrentNetwork,
    input_size=text.num_characters,
    hidden_state_size=GRADCHECK_HIDDEN_SIZE,
)

compare_gradients_recurrent(
    network_constructor,
    ds,
    h=GRADCHECK_STEP,
    random_seed=0,
)

# Train network

In [None]:
# Create the pickle directory before training: the run performs
# HYPERPARAMS['updates'] parameter updates, and if `history.save` does not
# create missing directories itself, failing only at the end would discard
# all of that work.
os.makedirs(PICKLE_DIR, exist_ok=True)

network = RecurrentNetwork(
    input_size=text.num_characters,
    hidden_state_size=HYPERPARAMS['hidden_state_size'],
    random_seed=0)

history = network.train(
    text,
    seq_length=HYPERPARAMS['sequence_length'],
    eta=HYPERPARAMS['eta'],
    n_updates=HYPERPARAMS['updates'],
    verbose=True,
    verbose_show_loss=False,
    verbose_show_samples=True)

# Persist the training history so later cells can reload it without retraining.
history.save(PICKLE_DIR, postfix='rnn_goblet')

In [None]:
# Reload the training history saved by the training cell, so the cells below
# can run without re-executing that cell.
history = TrainHistoryRecurrent.load(
    PICKLE_DIR,
    postfix='rnn_goblet',
)

In [None]:
# Plot the recorded training history.
# NOTE(review): assumes visualize() leaves its figure as the current matplotlib
# figure (i.e. does not call plt.show() itself) — otherwise the SVG saved below
# would be blank. Confirm.
history.visualize()

# plt.savefig does not create missing parent directories, so ensure FIGURE_DIR exists.
os.makedirs(FIGURE_DIR, exist_ok=True)
plt.savefig(os.path.join(FIGURE_DIR, 'rnn_loss.svg'))

In [None]:
# Sample text from the trained network. NumPy's global RNG is seeded first,
# presumably because synthesize() draws from np.random — verify.
SYNTHESIS_LENGTH = 1000
np.random.seed(0)

network = history.final_network
synthesized = network.synthesize(length=SYNTHESIS_LENGTH)

# Decode the synthesized output (one-hot, per the flag) back into characters.
# print() rather than a bare expression so a string result shows with real
# newlines instead of its repr.
sequence = text.get_characters(synthesized, one_hot=True)
print(sequence)