# Anna KaRNNa

In this notebook, I'll build a character-wise RNN trained on Anna Karenina, one of my all-time favorite books. It'll be able to generate new text based on the text from the book.

This network is based on Andrej Karpathy's [post on RNNs](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) and [implementation in Torch](https://github.com/karpathy/char-rnn). I also drew on [this post at r2rt](http://r2rt.com/recurrent-neural-networks-in-tensorflow-ii.html) and on [Sherjil Ozair's implementation](https://github.com/sherjilozair/char-rnn-tensorflow) on GitHub. Below is the general architecture of the character-wise RNN.

<img src="assets/charseq.jpeg" width="500">

In [1]:
import time
from collections import namedtuple

import numpy as np
import tensorflow as tf

First we'll load the text file and convert it into integers for our network to use.

In [2]:
with open('anna.txt', 'r') as f:
    text = f.read()
vocab = set(text)
vocab_to_int = {c: i for i, c in enumerate(vocab)}
int_to_vocab = dict(enumerate(vocab))
chars = np.array([vocab_to_int[c] for c in text], dtype=np.int32)

In [3]:
text[:100]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverythin'

In [4]:
chars[:100]

array([81,  7, 47, 54, 43,  0,  1, 58, 60, 53, 53, 53, 26, 47, 54, 54, 42,
       58, 61, 47, 49, 52, 66, 52,  0, 67, 58, 47,  1,  0, 58, 47, 66, 66,
       58, 47, 66, 52, 57,  0, 40, 58,  0, 68,  0,  1, 42, 58, 27, 23,  7,
       47, 54, 54, 42, 58, 61, 47, 49, 52, 66, 42, 58, 52, 67, 58, 27, 23,
        7, 47, 54, 54, 42, 58, 52, 23, 58, 52, 43, 67, 58, 50, 80, 23, 53,
       80, 47, 42, 33, 53, 53, 11, 68,  0,  1, 42, 43,  7, 52, 23], dtype=int32)
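
As a quick sanity check on the encoding step, here is a minimal sketch on a toy string (the string and the `toy_` names are placeholders, not part of the notebook). It also sorts the character set, which is an optional tweak I'm assuming you might want: the iteration order of a `set` of strings can differ between Python sessions, so a sorted vocabulary keeps the integer mapping stable if you reload a checkpoint later.

```python
# Toy text standing in for anna.txt (assumption: any string works here).
toy_text = "happy families"

# Sorting the character set makes the mapping reproducible between runs.
toy_vocab = sorted(set(toy_text))
toy_vocab_to_int = {c: i for i, c in enumerate(toy_vocab)}
toy_int_to_vocab = dict(enumerate(toy_vocab))

# Encode to integers, then decode back to characters.
encoded = [toy_vocab_to_int[c] for c in toy_text]
decoded = ''.join(toy_int_to_vocab[i] for i in encoded)

print(encoded)
print(decoded == toy_text)
```

If the round trip returns the original string, the two dictionaries are consistent with each other.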

Now I need to split up the data into batches, and into training and validation sets. I should be making a test set here, but I'm not going to worry about that. My test will be if the network can generate new text.

Here I'll make both input and target arrays. The targets are the same as the inputs, except shifted one character over. I'll also drop the last bit of data so that I'll only have completely full batches.

The idea here is to make a 2D matrix where the number of rows is equal to the batch size (the number of sequences in each batch). Each row will be one long concatenated string from the character data. We'll split this data into a training set and validation set using the `split_frac` keyword. With the default of 0.9, this keeps the first 90% of the batches in the training set and the remaining 10% in the validation set.
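
To make the "targets are the inputs shifted by one" idea concrete, here is a small sketch in plain Python on a toy sequence. `toy_split` is a hypothetical helper for illustration only; it ignores `num_steps` and the train/validation split and just shows the row layout.

```python
def toy_split(seq, batch_size):
    """Return (x_rows, y_rows): batch_size rows of inputs and shifted targets."""
    # Keep one item in reserve so the last input still has a target.
    row_len = (len(seq) - 1) // batch_size
    n = row_len * batch_size
    x, y = seq[:n], seq[1:n + 1]
    x_rows = [x[r * row_len:(r + 1) * row_len] for r in range(batch_size)]
    y_rows = [y[r * row_len:(r + 1) * row_len] for r in range(batch_size)]
    return x_rows, y_rows

x_rows, y_rows = toy_split(list(range(13)), batch_size=3)
for xr, yr in zip(x_rows, y_rows):
    print(xr, '->', yr)
```

Each target row is the matching input row moved one position to the right, and there are exactly `batch_size` rows.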

In [5]:
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """ 
    Split character data into training and validation sets, inputs and targets for each set.
    
    Arguments
    ---------
    chars: character array
    batch_size: Number of sequences in each batch
    num_steps: Number of sequence steps to keep in the input and pass to the network
    split_frac: Fraction of batches to keep in the training set
    
    
    Returns train_x, train_y, val_x, val_y
    """
    
    slice_size = batch_size * num_steps
    n_batches = len(chars) // slice_size
    
    # Drop the last few characters to make only full batches
    x = chars[: n_batches*slice_size]
    y = chars[1: n_batches*slice_size + 1]
    
    # Split the data into batch_size slices, then stack them into a 2D matrix 
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))
    
    # Now x and y are arrays with dimensions batch_size x n_batches*num_steps
    
    # Split into training and validation sets, keep the first split_frac batches for training
    split_idx = int(n_batches*split_frac)
    train_x, train_y = x[:, :split_idx*num_steps], y[:, :split_idx*num_steps]
    val_x, val_y = x[:, split_idx*num_steps:], y[:, split_idx*num_steps:]
    
    return train_x, train_y, val_x, val_y

In [6]:
train_x, train_y, val_x, val_y = split_data(chars, 10, 200)

In [7]:
train_x.shape

(10, 178400)

In [8]:
train_x[:,:10]

array([[81,  7, 47, 54, 43,  0,  1, 58, 60, 53],
       [ 3, 23, 71, 58,  7,  0, 58, 49, 50, 68],
       [58, 59, 47, 43, 59,  7, 52, 23, 29, 58],
       [50, 43,  7,  0,  1, 58, 80, 50, 27, 66],
       [58, 43,  7,  0, 58, 66, 47, 23, 71, 41],
       [58, 78,  7,  1, 50, 27, 29,  7, 58, 66],
       [43, 58, 43, 50, 53, 71, 50, 33, 53, 53],
       [50, 58,  7,  0,  1, 67,  0, 66, 61, 25],
       [ 7, 47, 43, 58, 52, 67, 58, 43,  7,  0],
       [ 0,  1, 67,  0, 66, 61, 58, 47, 23, 71]], dtype=int32)

I'll write another function to grab batches out of the arrays made by `split_data`. Here each batch will be a sliding window on these arrays with size `batch_size x num_steps`. For example, if we want our network to train on a sequence of 100 characters, `num_steps = 100`. For the next batch, we'll shift this window to the next sequence of `num_steps` characters. In this way we can feed batches to the network, and the cell states carry over from one batch to the next.

In [9]:
def get_batch(arrs, num_steps):
    batch_size, slice_size = arrs[0].shape
    
    n_batches = int(slice_size/num_steps)
    for b in range(n_batches):
        yield [x[:, b*num_steps: (b+1)*num_steps] for x in arrs]
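
Here is a sketch of the same windowing on plain Python lists, so the layout is easy to see. `toy_get_batch` is a hypothetical stand-in for `get_batch`, and the two rows are made-up data.

```python
def toy_get_batch(rows, num_steps):
    """Yield consecutive num_steps-wide windows across all rows."""
    n_batches = len(rows[0]) // num_steps
    for b in range(n_batches):
        yield [row[b * num_steps:(b + 1) * num_steps] for row in rows]

rows = [[0, 1, 2, 3, 4, 5],
        [10, 11, 12, 13, 14, 15]]
batches = list(toy_get_batch(rows, num_steps=2))
for batch in batches:
    print(batch)
```

Row 0 of each batch picks up exactly where row 0 of the previous batch stopped. That is why it makes sense to pass the final LSTM state of one batch in as the initial state of the next.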

In [16]:
def build_rnn(num_classes, batch_size=50, num_steps=50, lstm_size=128, num_layers=2,
              learning_rate=0.001, grad_clip=5, sampling=False):
        
    if sampling:
        batch_size, num_steps = 1, 1

    tf.reset_default_graph()
    
    # Declare placeholders we'll feed into the graph
    with tf.name_scope('inputs'):
        inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')
        x_one_hot = tf.one_hot(inputs, num_classes, name='x_one_hot')
    
    with tf.name_scope('targets'):
        targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')
        y_one_hot = tf.one_hot(targets, num_classes, name='y_one_hot')
        y_reshaped = tf.reshape(y_one_hot, [-1, num_classes])
    
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
    
    # Build the RNN layers
    with tf.name_scope("RNN_cells"):
        # Each layer gets its own LSTM cell, wrapped with dropout on its outputs
        cell = tf.contrib.rnn.MultiRNNCell(
            [tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(lstm_size),
                                           output_keep_prob=keep_prob)
             for _ in range(num_layers)])
    with tf.name_scope("RNN_init_state"):
        initial_state = cell.zero_state(batch_size, tf.float32)

    # Run the data through the RNN layers
    with tf.name_scope("RNN_forward"):
        outputs, state = tf.nn.dynamic_rnn(cell, x_one_hot, initial_state=initial_state)
    
    final_state = state
    
    # Reshape output so it's a bunch of rows, one row for each cell output
    with tf.name_scope('sequence_reshape'):
        seq_output = tf.concat(outputs, axis=1,name='seq_output')
        output = tf.reshape(seq_output, [-1, lstm_size], name='graph_output')
    
    # Now connect the RNN outputs to a softmax layer and calculate the cost
    with tf.name_scope('logits'):
        softmax_w = tf.Variable(tf.truncated_normal((lstm_size, num_classes), stddev=0.1),
                               name='softmax_w')
        softmax_b = tf.Variable(tf.zeros(num_classes), name='softmax_b')
        logits = tf.matmul(output, softmax_w) + softmax_b
        tf.summary.histogram('softmax_w', softmax_w)
        tf.summary.histogram('softmax_b', softmax_b)

    with tf.name_scope('predictions'):
        preds = tf.nn.softmax(logits, name='predictions')
        tf.summary.histogram('predictions', preds)
    
    with tf.name_scope('cost'):
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped, name='loss')
        cost = tf.reduce_mean(loss, name='cost')
        tf.summary.scalar('cost', cost)

    # Optimizer for training, using gradient clipping to control exploding gradients
    with tf.name_scope('train'):
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)
        train_op = tf.train.AdamOptimizer(learning_rate)
        optimizer = train_op.apply_gradients(zip(grads, tvars))
    
    merged = tf.summary.merge_all()
    
    # Export the nodes 
    export_nodes = ['inputs', 'targets', 'initial_state', 'final_state',
                    'keep_prob', 'cost', 'preds', 'optimizer', 'merged']
    Graph = namedtuple('Graph', export_nodes)
    local_dict = locals()
    graph = Graph(*[local_dict[each] for each in export_nodes])
    
    return graph
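
The optimizer above clips gradients by their global norm. As a rough illustration of what that means (a plain-Python sketch of the idea, not TensorFlow's implementation), all gradients are treated as one long vector, and if its L2 norm exceeds `clip_norm` every entry is scaled down by the same factor. The gradient values below are made up.

```python
import math

def clip_by_global_norm(grads, clip_norm):
    """Scale a list of gradient vectors so their joint L2 norm is <= clip_norm."""
    global_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # If the norm is already small enough, the scale is exactly 1.
    scale = clip_norm / max(global_norm, clip_norm)
    return [[g * scale for g in vec] for vec in grads], global_norm

grads = [[3.0, 4.0], [12.0]]   # joint norm = sqrt(9 + 16 + 144) = 13
clipped, norm = clip_by_global_norm(grads, clip_norm=5.0)
print(norm)
print(clipped)
```

Because every gradient is scaled by the same factor, the update direction is preserved and only its length is capped.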

## Hyperparameters

Here I'm defining the hyperparameters for the network. The two you probably haven't seen before are `lstm_size` and `num_layers`. These set the number of hidden units in each LSTM layer and the number of LSTM layers, respectively. Making these bigger gives the network more capacity, which usually lowers the training loss, but training gets slower and you'll have to watch out for overfitting. If your validation loss is much larger than the training loss, you're probably overfitting. In that case, decrease the size of the network or decrease the dropout keep probability (that is, apply more dropout).

In [17]:
batch_size = 100
num_steps = 100
lstm_size = 512
num_layers = 2
learning_rate = 0.001
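
To get a feel for how `lstm_size` and `num_layers` affect model size, here is a back-of-the-envelope parameter count. It assumes a standard LSTM cell with four gates and no peephole connections, which is my understanding of `BasicLSTMCell`, plus the softmax layer from `build_rnn`. The vocabulary size of 83 is an assumption for illustration; use `len(vocab)` for the real number.

```python
def lstm_layer_params(input_size, lstm_size):
    # Four gates, each with weights over [input, previous hidden state] plus a bias.
    return 4 * ((input_size + lstm_size) * lstm_size + lstm_size)

def network_params(num_classes, lstm_size, num_layers):
    total, input_size = 0, num_classes        # first layer sees one-hot characters
    for _ in range(num_layers):
        total += lstm_layer_params(input_size, lstm_size)
        input_size = lstm_size                # deeper layers see the layer below
    total += lstm_size * num_classes + num_classes   # softmax weights and biases
    return total

# num_classes=83 is an assumed vocabulary size; use len(vocab) in the notebook.
print(network_params(num_classes=83, lstm_size=512, num_layers=2))
print(network_params(num_classes=83, lstm_size=128, num_layers=2))
```

The count grows roughly with the square of `lstm_size`, so going from 128 to 512 units makes the network more than ten times larger.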

## Training

Time for training, which is pretty straightforward. Here I pass in a batch of data and get the LSTM state back. Then I pass that state back into the network so the next batch continues from the state of the previous batch. Every so often (set by `save_every_n`), I calculate the validation loss and save a checkpoint.
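
The bookkeeping for when validation runs can be sketched on its own. `validation_iterations` is a hypothetical helper that mirrors the condition used in the training loop; the value of 178 batches per epoch is taken from the 1780 total iterations over 10 epochs reported in the training output.

```python
def validation_iterations(n_batches, epochs, save_every_n):
    """Iterations at which the loop's validation condition is true."""
    iterations = n_batches * epochs
    return [i for i in range(1, iterations + 1)
            if i % save_every_n == 0 or i == iterations]

# 178 batches per epoch and 10 epochs give 1780 iterations in total.
points = validation_iterations(n_batches=178, epochs=10, save_every_n=100)
print(points)
```

Validation runs at every multiple of `save_every_n` and once more at the very last iteration, even if that is not a multiple.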

In [18]:
!mkdir -p checkpoints/anna

In [19]:
epochs = 10
save_every_n = 100
train_x, train_y, val_x, val_y = split_data(chars, batch_size, num_steps)

model = build_rnn(len(vocab), 
                  batch_size=batch_size,
                  num_steps=num_steps,
                  learning_rate=learning_rate,
                  lstm_size=lstm_size,
                  num_layers=num_layers)

saver = tf.train.Saver(max_to_keep=100)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter('./logs/2/train', sess.graph)
    test_writer = tf.summary.FileWriter('./logs/2/test')
    
    # Use the line below to load a checkpoint and resume training
    #saver.restore(sess, 'checkpoints/anna20.ckpt')
    
    n_batches = int(train_x.shape[1]/num_steps)
    iterations = n_batches * epochs
    for e in range(epochs):
        
        # Train network
        new_state = sess.run(model.initial_state)
        loss = 0
        for b, (x, y) in enumerate(get_batch([train_x, train_y], num_steps), 1):
            iteration = e*n_batches + b
            start = time.time()
            feed = {model.inputs: x,
                    model.targets: y,
                    model.keep_prob: 0.5,
                    model.initial_state: new_state}
            summary, batch_loss, new_state, _ = sess.run([model.merged, model.cost, 
                                                          model.final_state, model.optimizer], 
                                                          feed_dict=feed)
            loss += batch_loss
            end = time.time()
            print('Epoch {}/{} '.format(e+1, epochs),
                  'Iteration {}/{}'.format(iteration, iterations),
                  'Training loss: {:.4f}'.format(loss/b),
                  '{:.4f} sec/batch'.format((end-start)))
            
            train_writer.add_summary(summary, iteration)
        
            if (iteration%save_every_n == 0) or (iteration == iterations):
                # Check performance; note that keep_prob is set to 1 (no dropout).
                # Use a separate state so the training state in new_state is preserved.
                val_loss = []
                val_state = sess.run(model.initial_state)
                for x, y in get_batch([val_x, val_y], num_steps):
                    feed = {model.inputs: x,
                            model.targets: y,
                            model.keep_prob: 1.,
                            model.initial_state: val_state}
                    summary, batch_loss, val_state = sess.run([model.merged, model.cost, 
                                                               model.final_state], feed_dict=feed)
                    val_loss.append(batch_loss)
                    
                test_writer.add_summary(summary, iteration)

                print('Validation loss:', np.mean(val_loss),
                      'Saving checkpoint!')
                #saver.save(sess, "checkpoints/anna/i{}_l{}_{:.3f}.ckpt".format(iteration, lstm_size, np.mean(val_loss)))

Epoch 1/10  Iteration 1/1780 Training loss: 4.4201 13.0641 sec/batch
Epoch 1/10  Iteration 2/1780 Training loss: 4.3806 11.9964 sec/batch
Epoch 1/10  Iteration 3/1780 Training loss: 4.2447 10.1318 sec/batch
Epoch 1/10  Iteration 4/1780 Training loss: 4.4902 9.6179 sec/batch
Epoch 1/10  Iteration 5/1780 Training loss: 4.4368 9.4478 sec/batch
Epoch 1/10  Iteration 6/1780 Training loss: 4.3105 9.2586 sec/batch
Epoch 1/10  Iteration 7/1780 Training loss: 4.1999 9.5176 sec/batch
Epoch 1/10  Iteration 8/1780 Training loss: 4.1125 9.7781 sec/batch
Epoch 1/10  Iteration 9/1780 Training loss: 4.0384 10.1461 sec/batch
Epoch 1/10  Iteration 10/1780 Training loss: 3.9767 10.4247 sec/batch
Epoch 1/10  Iteration 11/1780 Training loss: 3.9197 10.7697 sec/batch
Epoch 1/10  Iteration 12/1780 Training loss: 3.8711 10.8080 sec/batch
Epoch 1/10  Iteration 13/1780 Training loss: 3.8286 11.0134 sec/batch
Epoch 1/10  Iteration 14/1780 Training loss: 3.7907 11.0898 sec/batch
...

Epoch 1/10  Iteration 118/1780 Training loss: 3.2053 11.1554 sec/batch
Epoch 1/10  Iteration 119/1780 Training loss: 3.2035 11.2530 sec/batch
Epoch 1/10  Iteration 120/1780 Training loss: 3.2015 11.3239 sec/batch
Epoch 1/10  Iteration 121/1780 Training loss: 3.2008 11.2427 sec/batch
Epoch 1/10  Iteration 122/1780 Training loss: 3.2002 11.2076 sec/batch
Epoch 1/10  Iteration 123/1780 Training loss: 3.1983 11.1959 sec/batch
Epoch 1/10  Iteration 124/1780 Training loss: 3.1966 11.1456 sec/batch
Epoch 1/10  Iteration 125/1780 Training loss: 3.1945 11.2481 sec/batch
Epoch 1/10  Iteration 126/1780 Training loss: 3.1924 11.2236 sec/batch
Epoch 1/10  Iteration 127/1780 Training loss: 3.1906 11.2288 sec/batch
Epoch 1/10  Iteration 128/1780 Training loss: 3.1887 11.2023 sec/batch
Epoch 1/10  Iteration 129/1780 Training loss: 3.1865 11.1457 sec/batch
Epoch 1/10  Iteration 130/1780 Training loss: 3.1844 11.1483 sec/batch
Epoch 1/10  Iteration 131/1780 Training loss: 3.1825 11.1763 sec/batch
...

Epoch 2/10  Iteration 233/1780 Training loss: 2.4448 11.2180 sec/batch
Epoch 2/10  Iteration 234/1780 Training loss: 2.4437 11.2415 sec/batch
Epoch 2/10  Iteration 235/1780 Training loss: 2.4424 11.2729 sec/batch
Epoch 2/10  Iteration 236/1780 Training loss: 2.4408 11.2408 sec/batch
Epoch 2/10  Iteration 237/1780 Training loss: 2.4393 11.2036 sec/batch
Epoch 2/10  Iteration 238/1780 Training loss: 2.4382 11.1518 sec/batch
Epoch 2/10  Iteration 239/1780 Training loss: 2.4367 11.1604 sec/batch
Epoch 2/10  Iteration 240/1780 Training loss: 2.4356 11.2002 sec/batch
Epoch 2/10  Iteration 241/1780 Training loss: 2.4345 11.3227 sec/batch
Epoch 2/10  Iteration 242/1780 Training loss: 2.4331 11.1783 sec/batch
Epoch 2/10  Iteration 243/1780 Training loss: 2.4316 11.3378 sec/batch
Epoch 2/10  Iteration 244/1780 Training loss: 2.4306 11.1362 sec/batch
Epoch 2/10  Iteration 245/1780 Training loss: 2.4292 11.1976 sec/batch
Epoch 2/10  Iteration 246/1780 Training loss: 2.4274 11.2108 sec/batch
...

Epoch 2/10  Iteration 348/1780 Training loss: 2.3097 11.2025 sec/batch
Epoch 2/10  Iteration 349/1780 Training loss: 2.3086 11.1708 sec/batch
Epoch 2/10  Iteration 350/1780 Training loss: 2.3076 11.2172 sec/batch
Epoch 2/10  Iteration 351/1780 Training loss: 2.3068 11.1469 sec/batch
Epoch 2/10  Iteration 352/1780 Training loss: 2.3059 11.1627 sec/batch
Epoch 2/10  Iteration 353/1780 Training loss: 2.3050 11.1968 sec/batch
Epoch 2/10  Iteration 354/1780 Training loss: 2.3040 11.2171 sec/batch
Epoch 2/10  Iteration 355/1780 Training loss: 2.3030 11.1398 sec/batch
Epoch 2/10  Iteration 356/1780 Training loss: 2.3020 11.1335 sec/batch
Epoch 3/10  Iteration 357/1780 Training loss: 2.2043 11.2139 sec/batch
Epoch 3/10  Iteration 358/1780 Training loss: 2.1489 11.1358 sec/batch
Epoch 3/10  Iteration 359/1780 Training loss: 2.1344 11.1822 sec/batch
Epoch 3/10  Iteration 360/1780 Training loss: 2.1286 11.1412 sec/batch
Epoch 3/10  Iteration 361/1780 Training loss: 2.1268 11.1707 sec/batch
...

Epoch 3/10  Iteration 463/1780 Training loss: 2.0436 11.2055 sec/batch
Epoch 3/10  Iteration 464/1780 Training loss: 2.0430 11.2056 sec/batch
Epoch 3/10  Iteration 465/1780 Training loss: 2.0424 11.2392 sec/batch
Epoch 3/10  Iteration 466/1780 Training loss: 2.0418 11.0696 sec/batch
Epoch 3/10  Iteration 467/1780 Training loss: 2.0411 11.1347 sec/batch
Epoch 3/10  Iteration 468/1780 Training loss: 2.0404 11.1286 sec/batch
Epoch 3/10  Iteration 469/1780 Training loss: 2.0397 11.2194 sec/batch
Epoch 3/10  Iteration 470/1780 Training loss: 2.0390 11.1435 sec/batch
Epoch 3/10  Iteration 471/1780 Training loss: 2.0382 11.1461 sec/batch
Epoch 3/10  Iteration 472/1780 Training loss: 2.0372 11.1583 sec/batch
Epoch 3/10  Iteration 473/1780 Training loss: 2.0366 11.2216 sec/batch
Epoch 3/10  Iteration 474/1780 Training loss: 2.0359 11.3673 sec/batch
Epoch 3/10  Iteration 475/1780 Training loss: 2.0352 11.1979 sec/batch
Epoch 3/10  Iteration 476/1780 Training loss: 2.0346 11.3509 sec/batch
...

Epoch 4/10  Iteration 578/1780 Training loss: 1.8743 11.1993 sec/batch
Epoch 4/10  Iteration 579/1780 Training loss: 1.8740 11.2876 sec/batch
Epoch 4/10  Iteration 580/1780 Training loss: 1.8727 11.2934 sec/batch
Epoch 4/10  Iteration 581/1780 Training loss: 1.8721 11.1821 sec/batch
Epoch 4/10  Iteration 582/1780 Training loss: 1.8712 11.1947 sec/batch
Epoch 4/10  Iteration 583/1780 Training loss: 1.8706 11.1150 sec/batch
Epoch 4/10  Iteration 584/1780 Training loss: 1.8708 11.1868 sec/batch
Epoch 4/10  Iteration 585/1780 Training loss: 1.8696 11.2933 sec/batch
Epoch 4/10  Iteration 586/1780 Training loss: 1.8700 11.2534 sec/batch
Epoch 4/10  Iteration 587/1780 Training loss: 1.8693 11.4094 sec/batch
Epoch 4/10  Iteration 588/1780 Training loss: 1.8689 11.2936 sec/batch
Epoch 4/10  Iteration 589/1780 Training loss: 1.8684 11.3324 sec/batch
Epoch 4/10  Iteration 590/1780 Training loss: 1.8681 11.1549 sec/batch
Epoch 4/10  Iteration 591/1780 Training loss: 1.8682 11.1539 sec/batch
...

Epoch 4/10  Iteration 693/1780 Training loss: 1.8219 14.4908 sec/batch
Epoch 4/10  Iteration 694/1780 Training loss: 1.8217 14.2156 sec/batch
Epoch 4/10  Iteration 695/1780 Training loss: 1.8215 15.7206 sec/batch
Epoch 4/10  Iteration 696/1780 Training loss: 1.8211 17.3379 sec/batch
Epoch 4/10  Iteration 697/1780 Training loss: 1.8208 17.8525 sec/batch
Epoch 4/10  Iteration 698/1780 Training loss: 1.8205 19.6999 sec/batch
Epoch 4/10  Iteration 699/1780 Training loss: 1.8201 18.0227 sec/batch
Epoch 4/10  Iteration 700/1780 Training loss: 1.8196 18.3051 sec/batch
Validation loss: 1.64238 Saving checkpoint!
Epoch 4/10  Iteration 701/1780 Training loss: 1.8196 15.6898 sec/batch
Epoch 4/10  Iteration 702/1780 Training loss: 1.8197 17.0044 sec/batch
Epoch 4/10  Iteration 703/1780 Training loss: 1.8193 16.3866 sec/batch
Epoch 4/10  Iteration 704/1780 Training loss: 1.8190 16.1168 sec/batch
Epoch 4/10  Iteration 705/1780 Training loss: 1.8185 17.6150 sec/batch
...

Epoch 5/10  Iteration 808/1780 Training loss: 1.7125 17.6544 sec/batch
Epoch 5/10  Iteration 809/1780 Training loss: 1.7122 15.2909 sec/batch
Epoch 5/10  Iteration 810/1780 Training loss: 1.7116 19.0527 sec/batch
Epoch 5/10  Iteration 811/1780 Training loss: 1.7111 15.2456 sec/batch
Epoch 5/10  Iteration 812/1780 Training loss: 1.7103 16.3975 sec/batch
Epoch 5/10  Iteration 813/1780 Training loss: 1.7100 17.8573 sec/batch
Epoch 5/10  Iteration 814/1780 Training loss: 1.7096 16.6621 sec/batch
Epoch 5/10  Iteration 815/1780 Training loss: 1.7092 18.9152 sec/batch
Epoch 5/10  Iteration 816/1780 Training loss: 1.7086 17.5868 sec/batch
Epoch 5/10  Iteration 817/1780 Training loss: 1.7081 18.6115 sec/batch
Epoch 5/10  Iteration 818/1780 Training loss: 1.7077 20.2528 sec/batch
Epoch 5/10  Iteration 819/1780 Training loss: 1.7074 20.2837 sec/batch
Epoch 5/10  Iteration 820/1780 Training loss: 1.7070 14.5191 sec/batch
Epoch 5/10  Iteration 821/1780 Training loss: 1.7068 14.4185 sec/batch
...

Epoch 6/10  Iteration 923/1780 Training loss: 1.6317 13.1841 sec/batch
Epoch 6/10  Iteration 924/1780 Training loss: 1.6318 18.2084 sec/batch
Epoch 6/10  Iteration 925/1780 Training loss: 1.6312 13.7539 sec/batch
Epoch 6/10  Iteration 926/1780 Training loss: 1.6309 16.6957 sec/batch
Epoch 6/10  Iteration 927/1780 Training loss: 1.6298 15.0020 sec/batch
Epoch 6/10  Iteration 928/1780 Training loss: 1.6284 13.7993 sec/batch
Epoch 6/10  Iteration 929/1780 Training loss: 1.6268 12.9770 sec/batch
Epoch 6/10  Iteration 930/1780 Training loss: 1.6261 15.3195 sec/batch
Epoch 6/10  Iteration 931/1780 Training loss: 1.6254 15.3128 sec/batch
Epoch 6/10  Iteration 932/1780 Training loss: 1.6256 14.5837 sec/batch
Epoch 6/10  Iteration 933/1780 Training loss: 1.6248 16.2066 sec/batch
Epoch 6/10  Iteration 934/1780 Training loss: 1.6235 13.2182 sec/batch
Epoch 6/10  Iteration 935/1780 Training loss: 1.6234 13.8582 sec/batch
Epoch 6/10  Iteration 936/1780 Training loss: 1.6222 17.9208 sec/batch
...

Epoch 6/10  Iteration 1038/1780 Training loss: 1.5965 11.7159 sec/batch
Epoch 6/10  Iteration 1039/1780 Training loss: 1.5964 11.7980 sec/batch
Epoch 6/10  Iteration 1040/1780 Training loss: 1.5961 11.7983 sec/batch
Epoch 6/10  Iteration 1041/1780 Training loss: 1.5956 11.8575 sec/batch
Epoch 6/10  Iteration 1042/1780 Training loss: 1.5955 11.9526 sec/batch
Epoch 6/10  Iteration 1043/1780 Training loss: 1.5953 11.8460 sec/batch
Epoch 6/10  Iteration 1044/1780 Training loss: 1.5951 11.9011 sec/batch
Epoch 6/10  Iteration 1045/1780 Training loss: 1.5949 12.1588 sec/batch
Epoch 6/10  Iteration 1046/1780 Training loss: 1.5947 12.1617 sec/batch
Epoch 6/10  Iteration 1047/1780 Training loss: 1.5946 12.2505 sec/batch
Epoch 6/10  Iteration 1048/1780 Training loss: 1.5944 11.7990 sec/batch
Epoch 6/10  Iteration 1049/1780 Training loss: 1.5939 11.8103 sec/batch
Epoch 6/10  Iteration 1050/1780 Training loss: 1.5938 12.1580 sec/batch
Epoch 6/10  Iteration 1051/1780 Training loss: 1.5938 11.8137 sec/batch
...

Epoch 7/10  Iteration 1152/1780 Training loss: 1.5389 16.4911 sec/batch
Epoch 7/10  Iteration 1153/1780 Training loss: 1.5385 19.6034 sec/batch
Epoch 7/10  Iteration 1154/1780 Training loss: 1.5381 17.6582 sec/batch
Epoch 7/10  Iteration 1155/1780 Training loss: 1.5378 18.2780 sec/batch
Epoch 7/10  Iteration 1156/1780 Training loss: 1.5374 17.6941 sec/batch
Epoch 7/10  Iteration 1157/1780 Training loss: 1.5368 15.3814 sec/batch
Epoch 7/10  Iteration 1158/1780 Training loss: 1.5368 17.7352 sec/batch
Epoch 7/10  Iteration 1159/1780 Training loss: 1.5365 16.3193 sec/batch
Epoch 7/10  Iteration 1160/1780 Training loss: 1.5362 17.5200 sec/batch
Epoch 7/10  Iteration 1161/1780 Training loss: 1.5357 18.4273 sec/batch
Epoch 7/10  Iteration 1162/1780 Training loss: 1.5353 19.8708 sec/batch
Epoch 7/10  Iteration 1163/1780 Training loss: 1.5349 20.1008 sec/batch
Epoch 7/10  Iteration 1164/1780 Training loss: 1.5349 18.7514 sec/batch
Epoch 7/10  Iteration 1165/1780 Training loss: 1.5348 18.0782 sec/batch
...

Epoch 8/10  Iteration 1266/1780 Training loss: 1.4943 11.4174 sec/batch
Epoch 8/10  Iteration 1267/1780 Training loss: 1.4938 11.6123 sec/batch
Epoch 8/10  Iteration 1268/1780 Training loss: 1.4944 11.5812 sec/batch
Epoch 8/10  Iteration 1269/1780 Training loss: 1.4930 12.7674 sec/batch
Epoch 8/10  Iteration 1270/1780 Training loss: 1.4927 18.0746 sec/batch
Epoch 8/10  Iteration 1271/1780 Training loss: 1.4925 13.4075 sec/batch
Epoch 8/10  Iteration 1272/1780 Training loss: 1.4905 12.6254 sec/batch
Epoch 8/10  Iteration 1273/1780 Training loss: 1.4889 12.1290 sec/batch
Epoch 8/10  Iteration 1274/1780 Training loss: 1.4890 13.3938 sec/batch
Epoch 8/10  Iteration 1275/1780 Training loss: 1.4891 14.5568 sec/batch
Epoch 8/10  Iteration 1276/1780 Training loss: 1.4891 16.6197 sec/batch
Epoch 8/10  Iteration 1277/1780 Training loss: 1.4887 16.1951 sec/batch
Epoch 8/10  Iteration 1278/1780 Training loss: 1.4874 12.8316 sec/batch
Epoch 8/10  Iteration 1279/1780 Training loss: 1.4876 12.2721 sec/batch
...

Epoch 8/10  Iteration 1380/1780 Training loss: 1.4665 11.9996 sec/batch
Epoch 8/10  Iteration 1381/1780 Training loss: 1.4664 12.7044 sec/batch
Epoch 8/10  Iteration 1382/1780 Training loss: 1.4664 11.9522 sec/batch
Epoch 8/10  Iteration 1383/1780 Training loss: 1.4664 12.0232 sec/batch
Epoch 8/10  Iteration 1384/1780 Training loss: 1.4664 12.1048 sec/batch
Epoch 8/10  Iteration 1385/1780 Training loss: 1.4664 11.9824 sec/batch
Epoch 8/10  Iteration 1386/1780 Training loss: 1.4662 12.1910 sec/batch
Epoch 8/10  Iteration 1387/1780 Training loss: 1.4665 11.9942 sec/batch
Epoch 8/10  Iteration 1388/1780 Training loss: 1.4663 11.9496 sec/batch
Epoch 8/10  Iteration 1389/1780 Training loss: 1.4662 12.1403 sec/batch
Epoch 8/10  Iteration 1390/1780 Training loss: 1.4663 11.9541 sec/batch
Epoch 8/10  Iteration 1391/1780 Training loss: 1.4660 12.7419 sec/batch
Epoch 8/10  Iteration 1392/1780 Training loss: 1.4660 12.1174 sec/batch
Epoch 8/10  Iteration 1393/1780 Training loss: 1.4660 11.8978 sec/batch
...

Epoch 9/10  Iteration 1494/1780 Training loss: 1.4271 11.5230 sec/batch
Epoch 9/10  Iteration 1495/1780 Training loss: 1.4277 11.4836 sec/batch
Epoch 9/10  Iteration 1496/1780 Training loss: 1.4279 11.7889 sec/batch
Epoch 9/10  Iteration 1497/1780 Training loss: 1.4284 11.4643 sec/batch
Epoch 9/10  Iteration 1498/1780 Training loss: 1.4279 11.5346 sec/batch
Epoch 9/10  Iteration 1499/1780 Training loss: 1.4279 11.5558 sec/batch
Epoch 9/10  Iteration 1500/1780 Training loss: 1.4279 11.4174 sec/batch
Validation loss: 1.31137 Saving checkpoint!
Epoch 9/10  Iteration 1501/1780 Training loss: 1.4290 10.9291 sec/batch
Epoch 9/10  Iteration 1502/1780 Training loss: 1.4289 10.8391 sec/batch
Epoch 9/10  Iteration 1503/1780 Training loss: 1.4283 11.7488 sec/batch
Epoch 9/10  Iteration 1504/1780 Training loss: 1.4281 12.1697 sec/batch
Epoch 9/10  Iteration 1505/1780 Training loss: 1.4276 11.9085 sec/batch
Epoch 9/10  Iteration 1506/1780 Training loss: 1.4275 11.6099 sec/batch
...

Epoch 10/10  Iteration 1607/1780 Training loss: 1.4154 10.9156 sec/batch
Epoch 10/10  Iteration 1608/1780 Training loss: 1.4038 10.8536 sec/batch
Epoch 10/10  Iteration 1609/1780 Training loss: 1.4040 10.8155 sec/batch
Epoch 10/10  Iteration 1610/1780 Training loss: 1.4007 10.8245 sec/batch
Epoch 10/10  Iteration 1611/1780 Training loss: 1.3998 10.8677 sec/batch
Epoch 10/10  Iteration 1612/1780 Training loss: 1.3986 10.9325 sec/batch
Epoch 10/10  Iteration 1613/1780 Training loss: 1.3943 10.9843 sec/batch
Epoch 10/10  Iteration 1614/1780 Training loss: 1.3933 10.9180 sec/batch
Epoch 10/10  Iteration 1615/1780 Training loss: 1.3927 10.8571 sec/batch
Epoch 10/10  Iteration 1616/1780 Training loss: 1.3939 10.8295 sec/batch
Epoch 10/10  Iteration 1617/1780 Training loss: 1.3922 10.9348 sec/batch
Epoch 10/10  Iteration 1618/1780 Training loss: 1.3900 10.7568 sec/batch
Epoch 10/10  Iteration 1619/1780 Training loss: 1.3903 10.8105 sec/batch
...

Epoch 10/10  Iteration 1719/1780 Training loss: 1.3733 10.7932 sec/batch
Epoch 10/10  Iteration 1720/1780 Training loss: 1.3732 10.8362 sec/batch
Epoch 10/10  Iteration 1721/1780 Training loss: 1.3729 10.9534 sec/batch
Epoch 10/10  Iteration 1722/1780 Training loss: 1.3729 10.6819 sec/batch
Epoch 10/10  Iteration 1723/1780 Training loss: 1.3727 10.6743 sec/batch
Epoch 10/10  Iteration 1724/1780 Training loss: 1.3722 10.9389 sec/batch
Epoch 10/10  Iteration 1725/1780 Training loss: 1.3718 11.0053 sec/batch
Epoch 10/10  Iteration 1726/1780 Training loss: 1.3718 10.8767 sec/batch
Epoch 10/10  Iteration 1727/1780 Training loss: 1.3716 10.7155 sec/batch
Epoch 10/10  Iteration 1728/1780 Training loss: 1.3712 11.5818 sec/batch
Epoch 10/10  Iteration 1729/1780 Training loss: 1.3712 10.8739 sec/batch
Epoch 10/10  Iteration 1730/1780 Training loss: 1.3711 10.8728 sec/batch
Epoch 10/10  Iteration 1731/1780 Training loss: 1.3709 10.7315 sec/batch
...
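
As a reference point for the numbers in this log: a network that guesses every character with equal probability has a cross-entropy of `ln(vocab_size)`. The sketch below assumes a vocabulary of 83 characters (an assumption; use `len(vocab)`), which gives a value close to the first logged loss of 4.4201. It also converts a loss into perplexity, which can be read as the effective number of characters the network is choosing between.

```python
import math

def uniform_cross_entropy(vocab_size):
    """Cross-entropy (in nats) of guessing every character with equal probability."""
    return -math.log(1.0 / vocab_size)

def perplexity(loss):
    """Effective number of equally likely choices implied by a loss in nats."""
    return math.exp(loss)

# 83 is an assumed vocabulary size; substitute len(vocab).
print(round(uniform_cross_entropy(83), 4))
print(round(perplexity(1.3709), 2))
```

Seen this way, the drop in training loss over the ten epochs corresponds to going from roughly the whole vocabulary down to a handful of plausible next characters.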

In [20]:
tf.train.get_checkpoint_state('checkpoints/anna')

model_checkpoint_path: "checkpoints/anna/i178_l512_2.580.ckpt"
all_model_checkpoint_paths: "checkpoints/anna/i178_l512_2.580.ckpt"

## Sampling

Now that the network is trained, we can use it to generate new text. The idea is that we pass in a character, and the network predicts the next character. We then feed that new character back in to predict the one after it, and keep doing this to generate entirely new text. I also included some functionality to prime the network with some text by passing in a string and building up a state from that.

The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters.



In [17]:
def pick_top_n(preds, vocab_size, top_n=5):
    p = np.squeeze(preds)
    p[np.argsort(p)[:-top_n]] = 0
    p = p / np.sum(p)
    c = np.random.choice(vocab_size, 1, p=p)[0]
    return c
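
Here is the same top-N idea as a plain-Python sketch, which may be easier to step through than the NumPy version. `top_n_sample` is a hypothetical name and the probabilities are made up. Unlike `pick_top_n`, it does not modify its input.

```python
import random

def top_n_sample(probs, top_n=5, rng=random):
    """Sample an index from the top_n most probable entries of probs."""
    # Indices of the top_n largest probabilities.
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_n]
    weights = [probs[i] for i in keep]
    # random.choices treats the weights as relative, so no renormalizing is needed.
    return rng.choices(keep, weights=weights, k=1)[0]

probs = [0.05, 0.40, 0.02, 0.30, 0.03, 0.20]
rng = random.Random(0)
draws = [top_n_sample(probs, top_n=3, rng=rng) for _ in range(1000)]
print(sorted(set(draws)))
```

Only the three most likely indices can ever be drawn here. With `top_n=1` this reduces to always picking the single most likely character.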

In [21]:
def sample(checkpoint, n_samples, lstm_size, vocab_size, prime="The "):
    samples = [c for c in prime]
    model = build_rnn(vocab_size, lstm_size=lstm_size, sampling=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        new_state = sess.run(model.initial_state)
        for c in prime:
            x = np.zeros((1, 1))
            x[0,0] = vocab_to_int[c]
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

        c = pick_top_n(preds, vocab_size)
        samples.append(int_to_vocab[c])

        for i in range(n_samples):
            x[0,0] = c
            feed = {model.inputs: x,
                    model.keep_prob: 1.,
                    model.initial_state: new_state}
            preds, new_state = sess.run([model.preds, model.final_state], 
                                         feed_dict=feed)

            c = pick_top_n(preds, vocab_size)
            samples.append(int_to_vocab[c])
        
    return ''.join(samples)
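
The shape of this loop (prime first, then repeatedly feed the sampled character back in) can be tried without a trained network. The sketch below swaps in a much simpler stand-in model, bigram counts over a toy sentence, so the "state" is only the last character rather than an LSTM state. All names and the text are placeholders.

```python
import random
from collections import Counter, defaultdict

toy_text = "the cat sat on the mat. the cat ate the rat."

# Stand-in for the trained network: counts of which character follows which.
counts = defaultdict(Counter)
for a, b in zip(toy_text, toy_text[1:]):
    counts[a][b] += 1

def toy_step(char):
    """Return candidate next characters and their counts after char."""
    nxt = counts[char]
    return list(nxt.keys()), list(nxt.values())

def toy_sample(n_samples, prime="the ", seed=0):
    rng = random.Random(seed)
    samples = list(prime)
    c = prime[-1]                 # priming: only the last character matters here
    for _ in range(n_samples):
        chars, weights = toy_step(c)
        c = rng.choices(chars, weights=weights, k=1)[0]
        samples.append(c)         # the sampled character becomes the next input
    return ''.join(samples)

print(toy_sample(40))
```

The real `sample` function follows the same pattern, except that priming runs every character of `prime` through the network so the LSTM state reflects the whole string.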

In [44]:
checkpoint = "checkpoints/anna/i3560_l512_1.122.ckpt"
samp = sample(checkpoint, 2000, lstm_size, len(vocab), prime="Far")
print(samp)

Farlathit that if had so
like it that it were. He could not trouble to his wife, and there was
anything in them of the side of his weaky in the creature at his forteren
to him.

"What is it? I can't bread to those," said Stepan Arkadyevitch. "It's not
my children, and there is an almost this arm, true it mays already,
and tell you what I have say to you, and was not looking at the peasant,
why is, I don't know him out, and she doesn't speak to me immediately, as
you would say the countess and the more frest an angelembre, and time and
things's silent, but I was not in my stand that is in my head. But if he
say, and was so feeling with his soul. A child--in his soul of his
soul of his soul. He should not see that any of that sense of. Here he
had not been so composed and to speak for as in a whole picture, but
all the setting and her excellent and society, who had been delighted
and see to anywing had been being troed to thousand words on them,
we liked him.

That set in her money at th

In [43]:
checkpoint = "checkpoints/anna/i200_l512_2.432.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Farnt him oste wha sorind thans tout thint asd an sesand an hires on thime sind thit aled, ban thand and out hore as the ter hos ton ho te that, was tis tart al the hand sostint him sore an tit an son thes, win he se ther san ther hher tas tarereng,.

Anl at an ades in ond hesiln, ad hhe torers teans, wast tar arering tho this sos alten sorer has hhas an siton ther him he had sin he ard ate te anling the sosin her ans and
arins asd and ther ale te tot an tand tanginge wath and ho ald, so sot th asend sat hare sother horesinnd, he hesense wing ante her so tith tir sherinn, anded and to the toul anderin he sorit he torsith she se atere an ting ot hand and thit hhe so the te wile har
ens ont in the sersise, and we he seres tar aterer, to ato tat or has he he wan ton here won and sen heren he sosering, to to theer oo adent har herere the wosh oute, was serild ward tous hed astend..

I's sint on alt in har tor tit her asd hade shithans ored he talereng an soredendere tim tot hees. Tise sor 

In [46]:
checkpoint = "checkpoints/anna/i600_l512_1.750.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Fard as astice her said he celatice of to seress in the raice, and to be the some and sere allats to that said to that the sark and a cast a the wither ald the pacinesse of her had astition, he said to the sount as she west at hissele. Af the cond it he was a fact onthis astisarianing.


"Or a ton to to be that's a more at aspestale as the sont of anstiring as
thours and trey.

The same wo dangring the
raterst, who sore and somethy had ast out an of his book. "We had's beane were that, and a morted a thay he had to tere. Then to
her homent andertersed his his ancouted to the pirsted, the soution for of the pirsice inthirgest and stenciol, with the hard and and
a colrice of to be oneres,
the song to this anderssad.
The could ounterss the said to serom of
soment a carsed of sheres of she
torded
har and want in their of hould, but
her told in that in he tad a the same to her. Serghing an her has and with the seed, and the camt ont his about of the
sail, the her then all houg ant or to hus

In [47]:
checkpoint = "checkpoints/anna/i1000_l512_1.484.ckpt"
samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime="Far")
print(samp)

Farrat, his felt has at it.

"When the pose ther hor exceed
to his sheant was," weat a sime of his sounsed. The coment and the facily that which had began terede a marilicaly whice whether the pose of his hand, at she was alligated herself the same on she had to
taiking to his forthing and streath how to hand
began in a lang at some at it, this he cholded not set all her. "Wo love that is setthing. Him anstering as seen that."

"Yes in the man that say the mare a crances is it?" said Sergazy Ivancatching. "You doon think were somether is ifficult of a mone of
though the most at the countes that the
mean on the come to say the most, to
his feesing of
a man she, whilo he
sained and well, that he would still at to said. He wind at his for the sore in the most
of hoss and almoved to see him. They have betine the sumper into at he his stire, and what he was that at the so steate of the
sound, and shin should have a geest of shall feet on the conderation to she had been at that imporsing the