In [4]:
# Generation settings: number of characters to sample, and the seed text the
# sampler starts from (every character must appear in the training corpus).
numgen, starting_word = 800, "Romeo"

In [6]:
import tensorflow as tf
import numpy as np
import os

In [7]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt',
                                       'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
# Read raw bytes and decode explicitly (no newline translation). The `with`
# block guarantees the file handle is closed; the bare open(...).read() did not.
with open(path_to_file, 'rb') as corpus_file:
    text = corpus_file.read().decode(encoding='utf-8')
# The vocabulary is the sorted set of distinct characters, so the
# char <-> index mapping is deterministic across runs.
vocab = sorted(set(text))
char2idx = {u: i for i, u in enumerate(vocab)}   # character -> integer id
idx2char = np.array(vocab)                         # integer id -> character
text_as_int = np.array([char2idx[c] for c in text])

In [8]:
# Cut the integer-encoded corpus into fixed-length chunks. Each chunk holds
# seq_length + 1 ids so it can later be split into an input sequence and a
# target sequence shifted by one; a trailing partial chunk is dropped.
seq_length = 100
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

In [9]:
def split_input_target(chunk):
  """Split a chunk of n+1 ids into an (input, target) pair, each of length n.

  The target is the input shifted one position to the left, so target[i] is
  the element that follows input[i] in the original chunk.
  """
  return chunk[:-1], chunk[1:]

# Turn every (seq_length + 1)-id chunk into an (input, target) pair.
dataset = sequences.map(map_func=split_input_target)

In [10]:
BUFFER_SIZE = 10000   # number of elements held in the shuffle buffer
BATCH_SIZE = 64

# Shuffle the (input, target) pairs, then group them into batches.
# drop_remainder=True keeps every batch exactly BATCH_SIZE long, matching the
# fixed batch dimension the training model is built with.
dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

# print(dataset)

In [11]:
# Model hyperparameters: vocabulary size (one embedding row and one output
# logit per distinct character), embedding width, and GRU width.
vocab_size, embedding_dim, rnn_units = len(vocab), 256, 1024

In [12]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  """Build the character-level Embedding -> GRU -> Dense network.

  Args:
    vocab_size: number of distinct characters; sets both the Embedding input
      dimension and the number of output logits per position.
    embedding_dim: length of each character embedding vector.
    rnn_units: number of GRU units.
    batch_size: batch dimension fixed into the input shape (the GRU is
      stateful, so the batch size cannot vary between calls).

  Returns:
    An uncompiled tf.keras.Sequential model that emits one logit vector per
    input position (return_sequences=True).
  """
  # Layers are created in the same order as before (Embedding, GRU, Dense)
  # so saved checkpoints keep matching.
  embedding_layer = tf.keras.layers.Embedding(
      vocab_size, embedding_dim, batch_input_shape=[batch_size, None])
  # stateful=True keeps the hidden state between calls; generate_text clears
  # it with model.reset_states().
  recurrent_layer = tf.keras.layers.GRU(
      rnn_units,
      return_sequences=True,
      stateful=True,
      recurrent_initializer='glorot_uniform')
  # No activation: outputs are raw logits (the loss uses from_logits=True).
  output_layer = tf.keras.layers.Dense(vocab_size)
  return tf.keras.Sequential([embedding_layer, recurrent_layer, output_layer])

In [13]:
# Training-time model: the batch dimension is fixed to BATCH_SIZE to match
# the batched dataset (built with drop_remainder=True).
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE)

# model.summary()

In [14]:
def loss(labels, logits):
  """Sparse categorical cross-entropy computed on raw logits.

  labels are integer character ids; logits are the model's unnormalized
  Dense outputs, so from_logits=True applies the softmax inside the loss.
  """
  return tf.keras.losses.sparse_categorical_crossentropy(
      y_true=labels, y_pred=logits, from_logits=True)

model.compile(loss=loss, optimizer='adam')

In [15]:
# Save weights only (not the full model) under ./training_checkpoints; Keras
# substitutes the epoch number for {epoch} in the file prefix at save time.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
)

In [None]:
# EPOCHS = 53
# history = model.fit(dataset, 
#                     epochs=EPOCHS, 
#                     callbacks=[checkpoint_callback])

In [16]:
# tf.train.latest_checkpoint(checkpoint_dir)

In [17]:
# Generation-time model: rebuilt with batch_size=1 (the stateful GRU fixes the
# batch size into the model), then restored from the newest training checkpoint.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    # latest_checkpoint returns None when no checkpoint exists, e.g. when the
    # training cell has not been run. Fail here with a clear message rather
    # than passing None into load_weights.
    raise FileNotFoundError(
        f"No checkpoint found in {checkpoint_dir!r}; run the training cell first.")
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(latest_checkpoint)
model.build(tf.TensorShape([1, None]))

In [18]:
def generate_text(model, num_generate, temperature, start_string):
    """Sample `num_generate` characters from `model`, seeded with `start_string`.

    Characters are produced one at a time. Each sampled id is fed back as the
    next input, and the stateful GRU carries the earlier context between calls.

    Args:
        model: a model from build_model(..., batch_size=1) with weights loaded.
        num_generate: number of characters to sample after start_string.
        temperature: positive divisor applied to the logits before sampling.
            Values below 1 make the output more predictable; values above 1
            make it more random.
        start_string: non-empty seed text. Every character must be a key of
            char2idx (a KeyError is raised otherwise).

    Returns:
        start_string followed by the generated characters.

    Raises:
        ValueError: if temperature is not positive or start_string is empty.
    """
    # Validate before touching the model. temperature <= 0 would divide the
    # logits by zero or flip their sign. An empty seed would give the model a
    # zero-length input, leaving no position to sample from.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")
    if not start_string:
        raise ValueError("start_string must contain at least one character")

    input_eval = tf.expand_dims([char2idx[s] for s in start_string], 0)
    text_generated = []
    # Clear any GRU state left over from a previous call.
    model.reset_states()

    for _ in range(num_generate):
        predictions = model(input_eval)            # assumes (1, seq_len, vocab_size) logits
        predictions = tf.squeeze(predictions, 0)   # drop the batch dimension
        predictions = predictions / temperature
        # Sample one id per position; keep only the sample for the last position.
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        # Feed back only the new character; earlier context lives in the GRU state.
        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return start_string + ''.join(text_generated)

In [19]:
# Sample numgen characters at temperature 1 (logits used unscaled), seeded
# with starting_word.
generated_text = generate_text(
    model,
    start_string=starting_word,
    num_generate=numgen,
    temperature=1,
)
print(generated_text)  # print() so newlines render rather than a quoted repr

Romeo lamber-age;
Or I shall remain!

SEBASTIAN:
Ay, and for a king, and am to e, she will be corness as a man
Than fool it so, see that he does behove my husband's dower.

KING RICHARD III:
I hope I hear him cure it in the peace of worms,
When woo'd thee, of all degrees
A rancourt in the tortured safety of Lancaster.

VIRGILIA:
O, good my false fruit--

Ghy, alas! I hope thou wilt be crown'd Edw;
And, in the wars, shall I be tee this neighbour kindness seques,
With overy-horour'd Tranio, in the precious a treacherous censure.
If any servant to his Aufidius a horsed succes?
My noble masters, hear me quickly, your followers;
And mountion-one hath so plove, signior HArd butchers point at.

QUEEN MARGARET:
O, bid my sovereign liege, a kind of every days to thee,
That with a bloody days as false, a
