In [143]:
import urllib
import urllib.request

import numpy as np
import tensorflow as tf

In [4]:
def reset_tf():
    """Discard the current TF graph and open a fresh interactive session.

    Rebinds the module-level ``sess``. Safe to call before any session has
    been created, so the model cell also works when run first on a fresh kernel.
    """
    global sess
    # Close the previous session only if one was created earlier in this kernel.
    if globals().get('sess') is not None:
        sess.close()
    tf.reset_default_graph()
    sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=True))

In [5]:
# Open the initial interactive session; reset_tf() closes and replaces it.
sess = tf.InteractiveSession(
    config=tf.ConfigProto(log_device_placement=True),
)

In [187]:
# Build the character-level RNN language-model graph in a fresh graph/session.
reset_tf()

# Model / batching hyperparameters.
seq_length = 50
batch_size = 64
embedding_size = 64
hidden_size = 100
# Token ids are character codes (ord() in the training cell); every id fed in must be < vocab_size.
vocab_size = 256
# Only used by the commented-out MultiRNNCell alternative below.
num_layers = 2

# Fixed-shape feeds: one row per sequence in the batch.
input_data = tf.placeholder(tf.int32, [batch_size, seq_length])
# Unpadded length of each row, passed to dynamic_rnn as sequence_length.
input_lengths = tf.placeholder(tf.int32, [batch_size])
target_data = tf.placeholder(tf.int32, [batch_size, seq_length])

rnn_cell = tf.contrib.rnn.BasicRNNCell(hidden_size)
# Alternative: stacked GRU cells (uses num_layers).
# rnn_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.GRUCell(hidden_size)] * num_layers)

# Zero state by default; the training and sampling loops feed their own value
# for this tensor to carry the RNN state from one sess.run call to the next.
initial_states = rnn_cell.zero_state(batch_size, tf.float32)

# Alternative input encoding: one-hot instead of a learned embedding.
# embedded_inputs = tf.one_hot(input_data, vocab_size)

# Learned character embedding table, indexed by the integer ids in input_data.
embedding = tf.get_variable('embedding', [vocab_size, embedding_size])
embedded_inputs = tf.nn.embedding_lookup(embedding, input_data)

# Output projection from hidden state to vocabulary logits.
softmax_w = tf.get_variable("softmax_w", [hidden_size, vocab_size])
softmax_b = tf.get_variable("softmax_b", [vocab_size])

outputs, final_states = tf.nn.dynamic_rnn(rnn_cell,
                                          embedded_inputs, 
                                          initial_state=initial_states, 
                                          sequence_length=input_lengths)

# Flatten to (batch_size*seq_length, hidden_size) rows so a single matmul projects
# every step; row order matches flat_targets assuming dynamic_rnn's default
# batch-major output layout (time_major is not set) — TODO confirm if that changes.
flat_outputs = tf.reshape(outputs, [-1, hidden_size])
flat_targets = tf.reshape(target_data, [-1])

flat_output_logits = tf.matmul(flat_outputs, softmax_w) + softmax_b
# Per-position next-character distribution; sample() reads row 0 of this.
flat_output_probs = tf.nn.softmax(flat_output_logits)

# Weight 1 for positive target ids, 0 for id 0 (the pad value generate_batches writes).
# NOTE(review): a genuine NUL character in the text would also be masked out.
flat_loss_mask = tf.sign(tf.to_float(flat_targets))
flat_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=flat_output_logits, labels=flat_targets) * flat_loss_mask
# Unmasked alternative:
# flat_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=flat_output_logits, labels=flat_targets)

# NOTE(review): reduce_mean divides by all batch_size*seq_length positions, padded
# ones included, so this is not a mean over real characters. It is only used as
# the optimisation objective; the training loop normalises total_loss itself.
mean_loss = tf.reduce_mean(flat_losses)
total_loss = tf.reduce_sum(flat_losses)

# Adam on mean_loss with gradients clipped to a global norm of 5.0.
optimizer = tf.train.AdamOptimizer(1e-3)
gradients, variables = zip(*optimizer.compute_gradients(mean_loss))
gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
train_op = optimizer.apply_gradients(zip(gradients, variables))

# Unclipped alternative with a larger learning rate:
# train_op = tf.train.AdamOptimizer(0.01).minimize(mean_loss)

sess.run(tf.global_variables_initializer())

In [166]:
def sample(initial_data, count):
    """Generate characters from the trained model one step at a time.

    The graph has a fixed batch shape, so only row 0 of the batch carries the
    seed. It is given length 1 and the remaining rows are given length 0. The
    returned state is fed back in on the next step.

    Args:
        initial_data: integer character code used to seed generation.
        count: number of further characters to sample.

    Returns:
        A list of ``count + 1`` character codes, starting with ``initial_data``.
    """
    curr_initial_states = np.zeros(initial_states.shape.as_list())
    # input_data is an int32 placeholder; build the feed buffer with a matching dtype.
    curr_input_data = np.zeros(input_data.shape.as_list(), dtype=np.int32)
    curr_input_lengths = [1] + [0] * (batch_size - 1)

    result = [initial_data]

    for _ in range(count):
        curr_input_data[0, 0] = result[-1]
        ps, curr_initial_states = sess.run((flat_output_probs, final_states), feed_dict={
            input_data: curr_input_data,
            input_lengths: curr_input_lengths,
            initial_states: curr_initial_states,
        })
        # Softmax output is float32; renormalise in float64 so np.random.choice's
        # sum-to-one check cannot trip on accumulated rounding error.
        probs = ps[0].astype(np.float64)
        probs /= probs.sum()
        result.append(np.random.choice(len(probs), p=probs))

    return result

In [26]:
def generate_batches(array, batch_size, seq_length):
    """Yield fixed-shape ``(batch, lengths)`` pairs that together cover ``array``.

    ``array`` is cut into consecutive chunks of ``seq_length`` items. Row ``r`` of
    batch ``s`` holds chunk ``r * n_steps + s``, so consecutive batches continue
    each row's stretch of text. Chunks that run short or lie past the end are
    right-padded with 0.

    Yields:
        ``np.stack`` of ``batch_size`` padded chunks (``(batch_size, seq_length)``
        for a 1-D ``array``), and a list of the unpadded length of each row.
    """
    n_chunks = (len(array) + seq_length - 1) // seq_length
    n_steps = (n_chunks + batch_size - 1) // batch_size

    for step in range(n_steps):
        starts = [(row * n_steps + step) * seq_length for row in range(batch_size)]
        chunks = [array[start:start + seq_length] for start in starts]
        lengths = [len(chunk) for chunk in chunks]
        padded = [
            np.pad(chunk, (0, seq_length - len(chunk)), 'constant', constant_values=0)
            for chunk in chunks
        ]
        yield np.stack(padded), lengths

In [147]:
# Source text for training (a plain-text story from textfiles.com).
TRAIN_URL = 'http://textfiles.com/stories/13chil.txt'

# Bounded timeout so a stalled server cannot hang the kernel indefinitely.
with urllib.request.urlopen(TRAIN_URL, timeout=30) as response:
    raw_text = response.read().decode("utf-8")

# Collapse every run of whitespace (newlines included) into a single space.
train_text = ' '.join(raw_text.split())

In [178]:
# Encode the text as integer character codes, which are the model's token ids.
train_array = np.array([ord(ch) for ch in train_text])

# The embedding table and softmax only cover ids in [0, vocab_size); out-of-range
# ids are not valid inputs to embedding_lookup / sparse_softmax_cross_entropy,
# so fail fast with a readable message instead of deep inside sess.run.
assert len(train_array) > 1, "need at least two characters to form input/target pairs"
assert train_array.max() < vocab_size, (
    f"character code {train_array.max()} is outside the model vocabulary (vocab_size={vocab_size})"
)

NUM_EPOCHS = 200
LOG_EVERY = 10

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    # Each epoch starts from the zero state; within an epoch the final state of
    # one batch is fed back as the initial state of the next.
    curr_initial_states = np.zeros(initial_states.shape)
    # Targets are the inputs shifted forward by one character.
    train_input_batches = generate_batches(train_array[:-1], batch_size, seq_length)
    train_target_batches = generate_batches(train_array[1:], batch_size, seq_length)

    for (curr_input_data, curr_input_lens), (curr_target_data, _) in zip(train_input_batches, train_target_batches):
        feed_dict = {
            input_data: curr_input_data,
            input_lengths: curr_input_lens,
            target_data: curr_target_data,
            initial_states: curr_initial_states,
        }
        _, curr_loss, curr_initial_states = sess.run((train_op, total_loss, final_states), feed_dict=feed_dict)
        epoch_loss += curr_loss

    # total_loss is a padding-masked sum, so this is the mean loss per target character.
    epoch_loss /= len(train_array) - 1

    if epoch % LOG_EVERY == 0:
        print(f'epoch {epoch}: loss={epoch_loss}')


epoch 0: loss=0.24252834709323182
epoch 10: loss=0.12895266474509726
epoch 20: loss=0.0946833357519033
epoch 30: loss=0.08441579390545281
epoch 40: loss=0.07864886497964664
epoch 50: loss=0.07478695305026307
epoch 60: loss=0.07171535686570771
epoch 70: loss=0.06939767915375379
epoch 80: loss=0.06701291726560009
epoch 90: loss=0.06783491640674824
epoch 100: loss=0.07884371037385901
epoch 110: loss=0.4817079038036113
epoch 120: loss=0.18474281077482263
epoch 130: loss=0.1124166761125837
epoch 140: loss=0.09135466205830477
epoch 150: loss=0.08045643203112544
epoch 160: loss=0.07393209690950354
epoch 170: loss=0.0693089485168457
epoch 180: loss=0.065573415950853
epoch 190: loss=0.062494316879583865


In [184]:
# Seed generation with 'R', sample 100 more characters, and decode the codes to text.
''.join(map(chr, sample(ord('R'), 100)))

'R¦abbe wals st and at wered over his splaws an then he looked over his spersend hor lon in eirey. But'