In [1]:
import tensorflow as tf


# Stateful RNN

In [3]:
# Getting the Shakespeare text

shakespeare_url = "https://homl.info/shakespeare"
filepath = tf.keras.utils.get_file("shakespeare.txt", shakespeare_url)
# Decode explicitly as UTF-8 so the text does not depend on the platform's default encoding
with open(filepath, encoding="utf-8") as f:
  shakespeare_text = f.read()

# Processing the text

# Vectorizing by character - each (lowercased) character is mapped to an integer
text_vec_layer = tf.keras.layers.TextVectorization(split="character", standardize="lower")
text_vec_layer.adapt([shakespeare_text])
encoded_text = text_vec_layer([shakespeare_text])[0]

# The TextVectorization layer uses 0 for padding and 1 for unknown chars. We don't need them
# in this case, so we subtract 2 from all character ids so that they start at 0
encoded_text -= 2
n_tokens = text_vec_layer.vocabulary_size() - 2
dataset_size = len(encoded_text)

print("Total number of characters: ", dataset_size)
print("Number of unique characters: ", n_tokens)

Total number of characters:  1115394
Number of unique characters:  39


In [4]:
# Processing the dataset for a stateful RNN

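# No shuffling, batches of size 1, and consecutive input sequences that do not overlap.
# Each target is the input sequence shifted one character ahead. For example, with length=12,
# Text: "First citizen before sunrise"
# Sequence 1: input "First citize", target "irst citizen"
# Sequence 2: input "n before sun", target " before sunr"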

def to_dataset_for_stateful_rnn(sequence, length):
  """Turn a sequence of token ids into an ordered dataset of (input, target) batches.

  The sequence is cut into windows of `length + 1` ids, each window starting `length`
  ids after the previous one, so consecutive windows share exactly one id. Every window
  is placed alone in a batch of size 1 and split into an input (all ids but the last)
  and a target (all ids but the first). Nothing is shuffled, so batch k+1 continues
  where batch k stopped.
  """
  window_size = length + 1

  def split_input_and_target(batch):
    # The target is the input shifted one step ahead
    return batch[:, :-1], batch[:, 1:]

  # Incomplete trailing windows are dropped
  windows = tf.data.Dataset.from_tensor_slices(sequence).window(
    window_size, shift=length, drop_remainder=True)

  # Each window is itself a small dataset: collapse it into one tensor, then wrap it
  # in a batch of size 1
  single_window_batches = windows.flat_map(lambda w: w.batch(window_size)).batch(1)

  return single_window_batches.map(split_input_and_target).prefetch(1)

# Number of characters in every input sequence
length = 100
tf.random.set_seed(42)

# Three consecutive slices of the encoded text: train, validation and test
stateful_train_set, stateful_valid_set, stateful_test_set = [
  to_dataset_for_stateful_rnn(encoded_text[start:stop], length)
  for start, stop in [(None, 1_000_000), (1_000_000, 1_060_000), (1_060_000, None)]
]


In [None]:
# Stateful RNN model

# Layers are stacked one at a time: embedding -> GRU -> softmax over the characters
model = tf.keras.Sequential()

# A stateful RNN needs a fixed batch size (1 here). The sequence length stays None
# because inputs can be of any length.
model.add(tf.keras.layers.Embedding(input_dim=n_tokens, output_dim=16, batch_input_shape=[1, None]))

# With stateful=True the hidden state reached at the end of one batch is reused as the
# starting state for the next batch
model.add(tf.keras.layers.GRU(128, return_sequences=True, stateful=True))

# One probability per character at every time step
model.add(tf.keras.layers.Dense(n_tokens, activation="softmax"))

model.compile(optimizer="nadam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Reset the state between epochs (each epoch is a pass over the entire training dataset)
class ResetStateCallback(tf.keras.callbacks.Callback):
  """Keras callback that resets the model's recurrent state at the start of every epoch."""

  def on_epoch_begin(self, epoch, logs=None):
    # Each epoch is a new pass over the training dataset, so the state left over from
    # the previous pass is cleared before it begins.
    self.model.reset_states()

# Only the model with the best validation accuracy so far is kept on disk
model_ckpt = tf.keras.callbacks.ModelCheckpoint(
  filepath="my_shakespeare_model",
  monitor="val_accuracy",
  save_best_only=True,
)

# The state-reset callback runs before the checkpoint callback, as in the list below
history = model.fit(
  stateful_train_set,
  validation_data=stateful_valid_set,
  epochs=10,
  callbacks=[ResetStateCallback(), model_ckpt],
)



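The model was trained on Kaggle (the training cell above was not run here); the trained model is loaded in the next cell.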

In [8]:
# Load the trained model
loaded_model = tf.keras.models.load_model("models/my_shakespeare_model_stateful")

# A stateful GRU keeps its hidden state between predict() calls and only accepts the batch
# size it was trained with. Text generation feeds the whole text again at every step, so
# each prediction has to start from a fresh state: copy the trained weights into a
# stateless model with the same layers.
# NOTE: this assumes the saved model has the Embedding(16) -> GRU(128) -> Dense architecture
# used for training; set_weights() raises if the shapes do not match.
stateless_model = tf.keras.Sequential([
  tf.keras.layers.Embedding(input_dim=n_tokens, output_dim=16),
  tf.keras.layers.GRU(128, return_sequences=True),
  tf.keras.layers.Dense(n_tokens, activation="softmax")
])
stateless_model.build(tf.TensorShape([None, None]))
stateless_model.set_weights(loaded_model.get_weights())

# Wrap it with the preprocessing step
shakespeare_model = tf.keras.Sequential([
  text_vec_layer,
  tf.keras.layers.Lambda(lambda X: X - 2), # no PAD or UNKNOWN tokens
  stateless_model
])

# Generating text

# Since this model predicts the next character, we can append the predicted character to the
# seed text and send the result back to the model in a loop. Always picking the most likely
# character is called "greedy decoding"; in practice it tends to repeat the same words over and over.

# Instead we'll take the probabilities of all possible next characters and sample one of them,
# after rescaling with a parameter called "temperature", which can be any positive number.
# Values close to 0 favor the most likely characters, a value of 1 samples from the model's
# probabilities as they are, and higher values make all characters closer to equally likely,
# adding to the randomness.

def next_char(text_model, text, temperature=1):
  """Sample the character that should follow `text`.

  Args:
    text_model: model whose predict() accepts a list of strings and returns, for each
      string, one row of character probabilities per time step.
    text: the text generated so far.
    temperature: strictly positive number. Values below 1 favor the most likely
      characters, values above 1 flatten the distribution.

  Returns:
    The sampled character, looked up in the vocabulary of `text_vec_layer`.

  Raises:
    ValueError: if `temperature` is not strictly positive (the logits are divided by it).
  """
  if temperature <= 0:
    raise ValueError(f"temperature must be > 0, got {temperature}")

  # Keep only the probabilities predicted for the last time step, as a single-row batch
  y_proba = text_model.predict([text])[0, -1:]
  # tf.random.categorical expects log-probabilities; dividing them by the temperature
  # sharpens or flattens the distribution before sampling
  rescaled_logits = tf.math.log(y_proba) / temperature
  char_id = tf.random.categorical(rescaled_logits, num_samples=1)[0, 0]

  # +2 skips the padding and unknown entries at the start of the vocabulary
  return text_vec_layer.get_vocabulary()[char_id + 2]

def extend_text(text_model, text, n_chars=50, temperature=1):
  """Return `text` followed by `n_chars` characters sampled one at a time.

  Each new character is drawn by next_char() from the text generated so far, using
  the given temperature.
  """
  generated = text
  for _step in range(n_chars):
    sampled = next_char(text_model, generated, temperature)
    generated = generated + sampled

  return generated

# Generate 200 characters after the seed; a low temperature keeps the sampling conservative
print(extend_text(
  shakespeare_model,
  "What about boar spears?",
  n_chars=200,
  temperature=0.2,
))




What about boar spears?

gremio:
have i will be so must in a state and here.

gremio:
and i will be the gremio, and i will the great and life
and her seems and so heard and man and still and
the great and the great and will
