<a href="https://colab.research.google.com/github/TrevorIkky/CharRNN/blob/main/CharRNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
class StopCallback(keras.callbacks.Callback):
  """
  Stop training as soon as the reported accuracy reaches a threshold.

  The metric key read from ``logs`` must match what was passed to
  ``model.compile(metrics=...)``; this callback looks for ``'acc'`` first and
  falls back to ``'accuracy'``. If neither is reported for an epoch the
  callback does nothing for that epoch.

  Args:
    metrics: accuracy value at (or above) which training is stopped.
  """
  def __init__(self, metrics=0.97):
    super(StopCallback, self).__init__()
    self.metrics = metrics

  def on_epoch_end(self, epoch, logs=None):
    # `logs` defaults to None and the metric may be absent, so never compare
    # a missing value against the threshold.
    logs = logs or {}
    acc = logs.get('acc')
    if acc is None:
      acc = logs.get('accuracy')
    if acc is None:
      return
    if acc >= self.metrics:
      print('Stopping model with {} accuracy'.format(acc))
      self.model.stop_training = True

In [3]:
class ResetStateCallback(keras.callbacks.Callback):
  """Reset the model's recurrent states at the end of every epoch."""
  def on_epoch_end(self, epoch, logs=None):
    # `reset_states` is a method: it has to be called. Assigning to it would
    # only shadow the method with a bool and leave the states untouched.
    self.model.reset_states()

In [4]:
class CharRNN:
  """
  Character-level GRU language model.

  The constructor downloads a text file, fits a character-level tokenizer on
  it and builds the training dataset from the first 80% of the encoded text.
  Call ``compile_model()`` before ``fit()``, ``predict()`` or ``build_text()``.

  Args:
    url: location of the training text, fetched with ``keras.utils.get_file``.
    accuracy_threshold: default accuracy at which ``fit()`` stops training.
  """
  def __init__(self, url, accuracy_threshold=0.97):
    self.batch_size = 32
    # 100 input characters plus one extra so targets can be shifted by one.
    self.window_size = 100 + 1
    self.accuracy_threshold = accuracy_threshold
    path = keras.utils.get_file("shakespear.txt", url)
    with open(path) as f:
      shakespear_text = f.read()
    # Kept on the instance: preprocess_text() and predict() need the same
    # vocabulary that was used to encode the training data.
    self.tokenizer = Tokenizer(char_level=True, oov_token="<OOV>")
    self.tokenizer.fit_on_texts([shakespear_text])
    char_indices = self.tokenizer.word_index
    # Tokenizer IDs start at 1; shift them so class IDs start at 0.
    [encoded] = np.array(self.tokenizer.texts_to_sequences([shakespear_text])) - 1
    train_size = len(encoded) * 80 // 100
    self.depth = len(char_indices)
    self.train = self.seq2seq_windows(encoded[:train_size], self.depth, self.window_size, self.batch_size)

  def seq2seq_windows(self, encoded_chars, max_depth, window_size=32, batch_size=32):
    """
    Build a shuffled, batched dataset of (input, target) character windows.

    Every window of ``window_size`` consecutive IDs (stride 1) is split into
    its first ``window_size - 1`` IDs, one-hot encoded with ``max_depth``
    classes, and its last ``window_size - 1`` IDs, kept as integer targets.
    """
    ds = tf.data.Dataset.from_tensor_slices(encoded_chars)
    ds = ds.window(window_size, shift=1, drop_remainder=True)
    # window() yields nested datasets; flatten each one into a single tensor.
    ds = ds.flat_map(lambda w: w.batch(window_size))
    ds = ds.map(lambda w: (w[:-1], w[1:]))
    ds = ds.map(lambda X, Y: (tf.one_hot(X, max_depth), Y))
    return ds.shuffle(1000).batch(batch_size).prefetch(1)

  def seq2seq_windows_fstateful(self, encoded_chars, max_depth, window_size=32, batch_size=32):
    """Not implemented yet; currently returns None."""
    #TODO: build the dataset for the stateful model
    pass

  def compile_stateful_model(self):
    """Not implemented yet; currently does nothing."""
    #TODO make a stateful RNN that preserves the RNN's state between each iteration
    pass

  def compile_model(self):
    """Create and compile the two-layer GRU model, stored as ``self.model``."""
    self.model = keras.models.Sequential([
      keras.layers.GRU(100, return_sequences=True,
                       dropout=0.2, input_shape=[None, self.depth],
                       recurrent_dropout=0.3),
      keras.layers.GRU(100, return_sequences=True,
                       dropout=0.2, recurrent_dropout=0.3),
      keras.layers.TimeDistributed(keras.layers.Dense(self.depth, activation="softmax"))
      ])
    self.model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),
              optimizer=keras.optimizers.Adam(), metrics=["acc"])

  def fit(self, accuracy=None):
    """
    Train ``self.model`` on ``self.train`` and store the result in ``self.history``.

    Args:
      accuracy: accuracy at which training stops; defaults to the
        ``accuracy_threshold`` given to the constructor.
    """
    threshold = self.accuracy_threshold if accuracy is None else accuracy
    # No validation data is passed to fit(), so the default `val_loss` monitor
    # would never be available; watch the training loss instead.
    early_stopping = keras.callbacks.EarlyStopping(monitor="loss", patience=4)
    acc_stopping = StopCallback(threshold)
    self.history = self.model.fit(self.train, epochs=1000, callbacks=[acc_stopping, early_stopping])

  def preprocess_text(self, text):
    """
    One-hot encode ``text`` with the training vocabulary.

    Returns a tensor with a leading batch axis of size 1, followed by one row
    per character and ``self.depth`` columns.

    Raises:
      ValueError: if ``text`` is empty.
    """
    if not text:
      raise ValueError("text must contain at least one character")
    # Keep the outer list so the result carries the batch axis the model expects.
    encoded_text = np.array(self.tokenizer.texts_to_sequences([text])) - 1
    return tf.one_hot(encoded_text, self.depth)

  def predict(self, text, temperature=1):
    """
    Sample the character that follows ``text``.

    Args:
      text: non-empty seed string.
      temperature: divides the log-probabilities before sampling; must be > 0.

    Returns:
      The sampled character as a string.
    """
    x = self.preprocess_text(text)
    # Probabilities for the last time step only, kept 2-D for categorical().
    y_proba = self.model.predict(x, verbose=0)[0, -1:, :]
    rescaled_logits = tf.math.log(y_proba) / temperature
    # +1 undoes the shift applied when encoding (tokenizer IDs start at 1).
    char_id = tf.random.categorical(rescaled_logits, num_samples=1) + 1
    return self.tokenizer.sequences_to_texts(char_id.numpy())[0]

  def build_text(self, text, n_chars=40, temperature=1):
    """Extend ``text`` by ``n_chars`` sampled characters and return the result."""
    for _ in range(n_chars):
      text += self.predict(text, temperature)
    return text

In [None]:
# URL the training text is downloaded from.
shakespear_url = "http://homl.info/shakespeare"

# Build the dataset and the (untrained) model.
ch_rnn = CharRNN(url=shakespear_url, accuracy_threshold=0.98)
ch_rnn.compile_model()

# Training is left disabled; uncomment to run it.
#ch_rnn.fit(0.98)