### Sampling with saved models

You can modify the file at "./save/checkpoint". The first line:

    model_checkpoint_path: "model.ckpt-???"
    
indicates the saved model you will use for sampling. The default value will be the last model you've saved.

Run the following two code block for sampling.

In [1]:
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.contrib import legacy_seq2seq

import numpy as np


class Model():
    def __init__(self, args, training=True):
        self.args = args
        if not training:
            args.batch_size = 1
            args.seq_length = 1

        # choose different rnn cell 
        if args.model == 'rnn':
            cell_fn = rnn.RNNCell
        elif args.model == 'gru':
            cell_fn = rnn.GRUCell
        elif args.model == 'lstm':
            cell_fn = rnn.LSTMCell
        elif args.model == 'nas':
            cell_fn = rnn.NASCell
        else:
            raise Exception("model type not supported: {}".format(args.model))

        # warp multi layered rnn cell into one cell with dropout
        cells = []
        for _ in range(args.num_layers):
            cell = cell_fn(args.rnn_size)
            if training and (args.output_keep_prob < 1.0 or args.input_keep_prob < 1.0):
                cell = rnn.DropoutWrapper(cell,
                                          input_keep_prob=args.input_keep_prob,
                                          output_keep_prob=args.output_keep_prob)
            cells.append(cell)
        self.cell = cell = rnn.MultiRNNCell(cells, state_is_tuple=True)

        # input/target data (int32 since input is char-level)
        self.input_data = tf.placeholder(
            tf.int32, [args.batch_size, args.seq_length])
        self.targets = tf.placeholder(
            tf.int32, [args.batch_size, args.seq_length])
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

        # softmax output layer, use softmax to classify
        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])

        # transform input to embedding
        embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
        inputs = tf.nn.embedding_lookup(embedding, self.input_data)

        # dropout beta testing: double check which one should affect next line
        if training and args.output_keep_prob:
            inputs = tf.nn.dropout(inputs, args.output_keep_prob)

        # unstack the input to fits in rnn model
        inputs = tf.split(inputs, args.seq_length, 1)
        inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

        # loop function for rnn_decoder, which take the previous i-th cell's output and generate the (i+1)-th cell's input
        def loop(prev, _):
            prev = tf.matmul(prev, softmax_w) + softmax_b
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            return tf.nn.embedding_lookup(embedding, prev_symbol)

        # rnn_decoder to generate the ouputs and final state. When we are not training the model, we use the loop function.
        outputs, last_state = legacy_seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if not training else None, scope='rnnlm')
        output = tf.reshape(tf.concat(outputs, 1), [-1, args.rnn_size])

        # output layer
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)

        # loss is calculate by the log loss and taking the average.
        loss = legacy_seq2seq.sequence_loss_by_example(
                [self.logits],
                [tf.reshape(self.targets, [-1])],
                [tf.ones([args.batch_size * args.seq_length])])
        with tf.name_scope('cost'):
            self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
        self.final_state = last_state
        self.lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()

        # calculate gradients
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                args.grad_clip)
        with tf.name_scope('optimizer'):
            optimizer = tf.train.AdamOptimizer(self.lr)

        # apply gradient change to the all the trainable variable.
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

        # instrument tensorboard
        tf.summary.histogram('logits', self.logits)
        tf.summary.histogram('loss', loss)
        tf.summary.scalar('train_loss', self.cost)

    def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
        state = sess.run(self.cell.zero_state(1, tf.float32))
        for char in prime[:-1]:
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [state] = sess.run([self.final_state], feed)

        def weighted_pick(weights):
            t = np.cumsum(weights)
            s = np.sum(weights)
            return(int(np.searchsorted(t, np.random.rand(1)*s)))

        ret = prime
        char = prime[-1]
        for _ in range(num):
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [probs, state] = sess.run([self.probs, self.final_state], feed)
            p = probs[0]

            if sampling_type == 0:
                sample = np.argmax(p)
            elif sampling_type == 2:
                if char == ' ':
                    sample = weighted_pick(p)
                else:
                    sample = np.argmax(p)
            else:  # sampling_type == 1 default:
                sample = weighted_pick(p)

            pred = chars[sample]
            ret += pred
            char = pred
        return ret

In [2]:
import os
from six.moves import cPickle
from six import text_type

tf.flags.DEFINE_string('data_dir', 'data/tinyshakespeare', 'data directory containing input.txt with training examples')
tf.flags.DEFINE_string('save_dir', 'save', 'directory to store checkpointed models')
tf.flags.DEFINE_string('log_dir', 'logs', 'directory to store tensorboard logs')
tf.flags.DEFINE_integer('save_every', 1000, 'Save frequency. Number of passes between checkpoints of the model.')
tf.flags.DEFINE_string('init_from', None, "continue training from saved model at this path.")

# Model params

tf.flags.DEFINE_string('model', 'lstm', 'lstm, rnn, gru, or nas')
tf.flags.DEFINE_integer('rnn_size', 128, 'size of RNN hidden state')
tf.flags.DEFINE_integer('num_layers', 2, 'number of layers in the RNN')

# Optimization

tf.flags.DEFINE_integer('seq_length', 50, 'RNN sequence length. Number of timesteps to unroll for.')
tf.flags.DEFINE_integer('batch_size', 50, 'minibatch size.')
tf.flags.DEFINE_integer('num_epochs', 1, 'number of epochs. Number of full passes through the training examples.')
tf.flags.DEFINE_float('grad_clip', 5., 'clip gradients at this value')
tf.flags.DEFINE_float('learning_rate', 0.002, 'learning rate')
tf.flags.DEFINE_float('decay_rate', 0.97, 'decay rate for rmsprop')
tf.flags.DEFINE_float('output_keep_prob', 1.0, 'probability of keeping weights in the hidden layer')
tf.flags.DEFINE_float('input_keep_prob', 1.0, 'probability of keeping weights in the input layer')
tf.flags.DEFINE_integer('vocab_size', None, 'probability of keeping weights in the input layer')

# Sampling and Testing

tf.flags.DEFINE_integer('n', 500, 'number of characters to sample')
tf.flags.DEFINE_string('prime', u' ', 'prime text')
tf.flags.DEFINE_integer('sample', 1, '0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces')
tf.app.flags.DEFINE_string('f', '', 'kernel')
args = tf.flags.FLAGS

with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
    chars, vocab = cPickle.load(f)

args.vocab_size = 65
#Use most frequent char if no prime is given
if args.prime == '':
    args.prime = chars[0]
model = Model(args, training = False)
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver = tf.train.Saver(tf.global_variables())
    ckpt = tf.train.get_checkpoint_state(args.save_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample).encode('utf-8'))

INFO:tensorflow:Restoring parameters from save/model.ckpt-50
b" casdebiu st itty oa nlorwloeAw:lr\ne tew?OMe:ystKdlh aeu\nrarriTugprwlgercs u aoIK: IS hEt!w  nnh h; a imyerddot:  niia tr\nielTav peeiyahahf-d'e,E hvgtt b  ct?rwoi iSyCwt uwm\nfiwdg c\nankke fhd oH myvh Tso e:v. oaReeberdeI  orwk a nmeah,  ydmtreeerohoooeys ohcUib\n:o nbdsOwaanld dbc mie Tk\n. btieiRehDnetHIws nbopigi,e MCta,V  IiC\naLfg g:Tnuhny e tuf shudu' yi a?ntoe hWnIO b, ttittU u,hgyem Gt hak rMIKh gtr aoneh\n nnrg\nU?ssneiveMg h\nvab oulobioPatiaoTrnn,ko o tUtbnAec hkadedor c ,h uolhy \n Iueraowint"
