In [3]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import os
import time


# Fetch the Shakespeare corpus; get_file's result is used below as a local
# file path.
path_to_file = keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

# Read the raw bytes and decode them explicitly as UTF-8. The context manager
# guarantees the file handle is closed once the corpus is in memory.
with open(path_to_file, 'rb') as corpus_file:
  text = corpus_file.read().decode(encoding='utf-8')

# print('Length of text: {} characters'.format(len(text)))

# print(text[:250])

# Vocabulary: every distinct character of the corpus, in sorted order.
vocab = sorted(set(text))
# print('{} unique characters'.format(len(vocab)))


# Lookup tables between a character and its integer id (its position in vocab).
idx2char = np.array(vocab)
char2idx = dict(zip(vocab, range(len(vocab))))

# print(text)
# print(char2idx[text[1]])

# Encode the whole corpus as an array of character ids.
text_as_int = np.array([char2idx[ch] for ch in text])

# print('{')
# for char,_ in zip(char2idx, range(20)):
#   print(' {:4s}: {:3d},'.format(repr(char), char2idx[char]))
# print(' ...\n}')

# print('{} ---- characters mapped to int ----> {}'.format(repr(text[:13]), text_as_int[:13]))

# Number of characters in each input sequence.
seq_length = 100
# NOTE(review): this is the total character count of the corpus, not a count
# of training examples -- confirm that is what the name is meant to hold.
examples_per_epoch = len(text)

# Dataset that yields the corpus one character id at a time.
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)

# for i in char_dataset.take(10):
#   print(idx2char[i.numpy()])

# Group ids into chunks of seq_length + 1 so that each chunk can later be split
# into an input and a target of seq_length items each; a short trailing chunk
# is dropped.
sequences = char_dataset.batch(batch_size=seq_length + 1, drop_remainder=True)

# for item in sequences.take(5):
#   print(repr(''.join(idx2char[item.numpy()])))


def split_input_target(chunk):
  """Split a chunk into an (input, target) pair offset by one position.

  The input is every element except the last and the target is every element
  except the first, so target[i] is the element that follows input[i].
  """
  return chunk[:-1], chunk[1:]

# Turn every chunk into an (input, target) training pair.
dataset = sequences.map(split_input_target)

# for input_example, target_example in dataset.take(1):
#   print('Input data:', repr(''.join(idx2char[input_example.numpy()])))
#   print('Target data:', repr(''.join(idx2char[target_example.numpy()])))

# print(input_example)

# for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):
#   print('Step {:4d}'.format(i))
#   print(' input: {} ({:s})'.format(input_idx, repr(idx2char[input_idx])))
#   print(' expected output: {} ({:s})'.format(target_idx, repr(idx2char[target_idx])))

# Number of (input, target) pairs per training batch.
BATCH_SIZE = 64
# Buffer size handed to Dataset.shuffle below.
BUFFER_SIZE = 10000

# print(dataset)

# Shuffle the pairs, then group them into full batches only (a short final
# batch is dropped).
dataset = dataset.shuffle(BUFFER_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

# dataset

# Model hyperparameters, consumed by build_model.
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  """Build the character model: Embedding -> stateful GRU -> Dense.

  Args:
    vocab_size: Input dimension of the embedding and width of the Dense output.
    embedding_dim: Output dimension of the embedding layer.
    rnn_units: Number of units in the GRU layer.
    batch_size: Fixed batch dimension of the input shape [batch_size, None].

  Returns:
    An uncompiled keras.Sequential model. The Dense layer is created without
    an activation argument.
  """
  embedding = keras.layers.Embedding(
      vocab_size,
      embedding_dim,
      batch_input_shape=[batch_size, None],
  )
  recurrent = keras.layers.GRU(
      rnn_units,
      return_sequences=True,
      stateful=True,
      recurrent_initializer='glorot_uniform',
  )
  projection = keras.layers.Dense(vocab_size)
  return keras.Sequential([embedding, recurrent, projection])

# Instantiate the training model with the same batch size the dataset uses.
model = build_model(len(vocab), embedding_dim, rnn_units, BATCH_SIZE)

# for input_example_batch, target_example_batch in dataset.take(1):
#   example_batch_predictions = model(input_example_batch)
#   print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

# model.summary()

# sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
# sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

# sampled_indices

# print('Input: \n', repr(''.join(idx2char[input_example_batch[0]])))
# print()
# print('Next Char Predictions: \n', repr(''.join(idx2char[sampled_indices])))


# def loss(labels, logits):
#   return keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

# example_batch_loss = loss(target_example_batch, example_batch_predictions)
# print('Prediction shape:', example_batch_predictions.shape, '# sampled_indices')
# print('scalar_loss:     ', example_batch_loss.numpy().mean())


# model.compile(optimizer='adam', loss=loss)

# Directory that receives the saved weights.
checkpoint_dir = './sample_data/training_checkpoints'
# Path template for a checkpoint; '{epoch}' is filled in with str.format by
# the training loop when it saves.
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

# checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
#   filepath=checkpoint_prefix,
#   save_weights_only=True
# )

# model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
# model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
# model.build(tf.TensorShape([1, None]))

# model.summary()

# EPOCHS=10
# history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

# def generate_text(model, start_string):
#   num_generate = 1000

#   input_eval = [char2idx[s] for s in start_string]
#   input_eval = tf.expand_dims(input_eval, 0)

#   text_generated = []
#   temperature = 1

#   model.reset_states()
#   for i in range(num_generate):
#       predictions = model(input_eval)
#       predictions = tf.squeeze(predictions, 0)

#       predictions = predictions / temperature
#       predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()

#       input_eval = tf.expand_dims([predicted_id], 0)

#       text_generated.append(idx2char[predicted_id])

#   return (start_string + ''.join(text_generated))

# print(generate_text(model, start_string=u'ROMEO: '))

# Adam optimizer constructed with its default arguments; used by train_step.
optimizer = keras.optimizers.Adam()


@tf.function
def train_step(inp, target):
  """Run one optimization step on a batch and return its mean loss.

  Args:
    inp: Batch passed to the module-level `model`.
    target: Batch of labels the model output is scored against.

  Returns:
    The mean sparse categorical cross-entropy of the batch, with the model
    output treated as logits.
  """
  with tf.GradientTape() as grad_tape:
    logits = model(inp)
    per_item_loss = keras.losses.sparse_categorical_crossentropy(
        target, logits, from_logits=True)
    batch_loss = tf.reduce_mean(per_item_loss)

  # Differentiate outside the tape context, then update the weights.
  gradients = grad_tape.gradient(batch_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  return batch_loss


# Custom training loop: EPOCHS passes over `dataset`, logging the loss every
# 100 batches and saving weights every 5th epoch and once more at the end.
EPOCHS = 10
for epoch in range(EPOCHS):
  start = time.time()

  # Clear the stateful GRU's carried-over state at the start of every epoch.
  # The call is made for its effect only; its result is not needed.
  model.reset_states()

  for batch_n, (inp, target) in enumerate(dataset):
    loss = train_step(inp, target)

    if batch_n % 100 == 0:
      template = 'Epoch {} Batch {} Loss {}'
      print(template.format(epoch + 1, batch_n, loss))

  # Periodic checkpoint; the file name carries the zero-based epoch index.
  if (epoch + 1) % 5 == 0:
    model.save_weights(checkpoint_prefix.format(epoch=epoch))

  # '{:.4f}' prints four decimal places ('{:4f}' would only set a minimum
  # field width of 4 and keep the default six decimals).
  print('Epoch {} Loss {:.4f}'.format(epoch + 1, loss))
  print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

# Save the weights of the last epoch regardless of the every-5th-epoch rule.
model.save_weights(checkpoint_prefix.format(epoch=epoch))


Epoch 1 Batch 0 Loss 4.174200534820557
Epoch 1 Batch 100 Loss 2.332514762878418
Epoch 1 Loss 2.157851
Time taken for 1 epoch 25.76604175567627 sec

Epoch 2 Batch 0 Loss 2.1815009117126465
Epoch 2 Batch 100 Loss 1.9283359050750732
Epoch 2 Loss 1.776544
Time taken for 1 epoch 24.20400094985962 sec

Epoch 3 Batch 0 Loss 1.796547532081604
Epoch 3 Batch 100 Loss 1.7160378694534302
Epoch 3 Loss 1.608171
Time taken for 1 epoch 24.239969730377197 sec

Epoch 4 Batch 0 Loss 1.607374668121338
Epoch 4 Batch 100 Loss 1.5479381084442139
Epoch 4 Loss 1.527776
Time taken for 1 epoch 24.238996744155884 sec

Epoch 5 Batch 0 Loss 1.4365278482437134
Epoch 5 Batch 100 Loss 1.4929004907608032
Epoch 5 Loss 1.457773
Time taken for 1 epoch 24.296348094940186 sec

Epoch 6 Batch 0 Loss 1.382655382156372
Epoch 6 Batch 100 Loss 1.405850887298584
Epoch 6 Loss 1.392492
Time taken for 1 epoch 24.217530965805054 sec

Epoch 7 Batch 0 Loss 1.3137742280960083
Epoch 7 Batch 100 Loss 1.3746695518493652
Epoch 7 Loss 1.37293