In [1]:
import os
import numpy as np
import re
import shutil
import tensorflow as tf

In [2]:
DATA_DIR = "./data"
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")

In [6]:
def download_and_read(urls):
    texts = []
    for i, url in enumerate(urls):
        p = tf.keras.utils.get_file("ex1-{:d}.txt".format(i), url,
            cache_dir=".")
        text = open(p, mode="r", encoding="utf-8").read()
        # remove byte order mark
        text = text.replace("\ufeff", "")
        # remove newlines
        text = text.replace('\n', ' ')
        text = re.sub(r'\s+', " ", text)
        # add it to the list
        texts.extend(text)
    return texts

In [7]:
texts = download_and_read([
    "http://www.gutenberg.org/cache/epub/28885/pg28885.txt",
    "https://www.gutenberg.org/files/12/12-0.txt"
])

In [8]:
vocab = sorted(set(texts))
print("vocab size: {:d}".format(len(vocab)))

vocab size: 94


In [9]:
char2idx = {c:i for i, c in enumerate(vocab)}
idx2char = {i:c for c, i in char2idx.items()}

In [10]:
char2idx

{' ': 0,
 '!': 1,
 '"': 2,
 '#': 3,
 '$': 4,
 '%': 5,
 '&': 6,
 "'": 7,
 '(': 8,
 ')': 9,
 '*': 10,
 ',': 11,
 '-': 12,
 '.': 13,
 '/': 14,
 '0': 15,
 '1': 16,
 '2': 17,
 '3': 18,
 '4': 19,
 '5': 20,
 '6': 21,
 '7': 22,
 '8': 23,
 '9': 24,
 ':': 25,
 ';': 26,
 '?': 27,
 '@': 28,
 'A': 29,
 'B': 30,
 'C': 31,
 'D': 32,
 'E': 33,
 'F': 34,
 'G': 35,
 'H': 36,
 'I': 37,
 'J': 38,
 'K': 39,
 'L': 40,
 'M': 41,
 'N': 42,
 'O': 43,
 'P': 44,
 'Q': 45,
 'R': 46,
 'S': 47,
 'T': 48,
 'U': 49,
 'V': 50,
 'W': 51,
 'X': 52,
 'Y': 53,
 'Z': 54,
 '[': 55,
 ']': 56,
 '_': 57,
 'a': 58,
 'b': 59,
 'c': 60,
 'd': 61,
 'e': 62,
 'f': 63,
 'g': 64,
 'h': 65,
 'i': 66,
 'j': 67,
 'k': 68,
 'l': 69,
 'm': 70,
 'n': 71,
 'o': 72,
 'p': 73,
 'q': 74,
 'r': 75,
 's': 76,
 't': 77,
 'u': 78,
 'v': 79,
 'w': 80,
 'x': 81,
 'y': 82,
 'z': 83,
 '·': 84,
 'Æ': 85,
 'ù': 86,
 '—': 87,
 '‘': 88,
 '’': 89,
 '“': 90,
 '”': 91,
 '•': 92,
 '™': 93}

In [11]:
idx2char

{0: ' ',
 1: '!',
 2: '"',
 3: '#',
 4: '$',
 5: '%',
 6: '&',
 7: "'",
 8: '(',
 9: ')',
 10: '*',
 11: ',',
 12: '-',
 13: '.',
 14: '/',
 15: '0',
 16: '1',
 17: '2',
 18: '3',
 19: '4',
 20: '5',
 21: '6',
 22: '7',
 23: '8',
 24: '9',
 25: ':',
 26: ';',
 27: '?',
 28: '@',
 29: 'A',
 30: 'B',
 31: 'C',
 32: 'D',
 33: 'E',
 34: 'F',
 35: 'G',
 36: 'H',
 37: 'I',
 38: 'J',
 39: 'K',
 40: 'L',
 41: 'M',
 42: 'N',
 43: 'O',
 44: 'P',
 45: 'Q',
 46: 'R',
 47: 'S',
 48: 'T',
 49: 'U',
 50: 'V',
 51: 'W',
 52: 'X',
 53: 'Y',
 54: 'Z',
 55: '[',
 56: ']',
 57: '_',
 58: 'a',
 59: 'b',
 60: 'c',
 61: 'd',
 62: 'e',
 63: 'f',
 64: 'g',
 65: 'h',
 66: 'i',
 67: 'j',
 68: 'k',
 69: 'l',
 70: 'm',
 71: 'n',
 72: 'o',
 73: 'p',
 74: 'q',
 75: 'r',
 76: 's',
 77: 't',
 78: 'u',
 79: 'v',
 80: 'w',
 81: 'x',
 82: 'y',
 83: 'z',
 84: '·',
 85: 'Æ',
 86: 'ù',
 87: '—',
 88: '‘',
 89: '’',
 90: '“',
 91: '”',
 92: '•',
 93: '™'}

In [12]:
texts_as_ints = np.array([char2idx[c] for c in texts])
data = tf.data.Dataset.from_tensor_slices(texts_as_ints)

In [13]:
seq_length = 100
sequences = data.batch(seq_length + 1, drop_remainder=True)
def split_train_labels(sequence):
    input_seq = sequence[0:-1]
    output_seq = sequence[1:]
    return input_seq, output_seq

In [14]:
sequences = sequences.map(split_train_labels)

In [15]:
sequences

<MapDataset element_spec=(TensorSpec(shape=(100,), dtype=tf.int32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None))>

In [16]:
batch_size = 64
steps_per_epoch = len(texts) // seq_length // batch_size
dataset = sequences.shuffle(10000).batch(
    batch_size, drop_remainder=True)

In [17]:
class CharGenModel(tf.keras.Model):
    def __init__(self, vocab_size, num_timesteps,
            embedding_dim, **kwargs):
        super(CharGenModel, self).__init__(**kwargs)
        self.embedding_layer = tf.keras.layers.Embedding(
            vocab_size,
            embedding_dim
        )
        self.rnn_layer = tf.keras.layers.GRU(
            num_timesteps,
            recurrent_initializer="glorot_uniform",
            recurrent_activation="sigmoid",
            stateful=True,
            return_sequences=True)
        self.dense_layer = tf.keras.layers.Dense(vocab_size)
    def call(self, x):
        x = self.embedding_layer(x)
        x = self.rnn_layer(x)
        x = self.dense_layer(x)
        return x

In [18]:
vocab_size = len(vocab)
embedding_dim = 256
model = CharGenModel(vocab_size, seq_length, embedding_dim)
model.build(input_shape=(batch_size, seq_length))

In [19]:
def loss(labels, predictions):
    return tf.losses.sparse_categorical_crossentropy(
        labels,
        predictions,
        from_logits=True
    )
model.compile(optimizer=tf.optimizers.Adam(), loss=loss)

In [20]:
def generate_text(model, prefix_string, char2idx, idx2char,
        num_chars_to_generate=1000, temperature=1.0):
    input = [char2idx[s] for s in prefix_string]
    input = tf.expand_dims(input, 0)
    text_generated = []
    model.reset_states()
    for i in range(num_chars_to_generate):
        preds = model(input)
        preds = tf.squeeze(preds, 0) / temperature
        # predict char returned by model
        pred_id = tf.random.categorical(preds, num_samples=1)[-1, 0].numpy()
        text_generated.append(idx2char[pred_id])
        # pass the prediction as the next input to the model
        input = tf.expand_dims([pred_id], 0)

    return prefix_string + "".join(text_generated)

In [21]:
num_epochs = 50
for i in range(num_epochs // 10):
    model.fit(
        dataset.repeat(),
        epochs=10,
        steps_per_epoch=steps_per_epoch
        # callbacks=[checkpoint_callback, tensorboard_callback]
    )
    checkpoint_file = os.path.join(
        CHECKPOINT_DIR, "model_epoch_{:d}".format(i+1))
    model.save_weights(checkpoint_file)
    # create generative model using the trained model so far
    gen_model = CharGenModel(vocab_size, seq_length, embedding_dim)
    gen_model.load_weights(checkpoint_file)
    gen_model.build(input_shape=(1, seq_length))
    print("after epoch: {:d}".format(i+1)*10)
    print(generate_text(gen_model, "Alice ", char2idx, idx2char))
    print("---")

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
after epoch: 1after epoch: 1after epoch: 1after epoch: 1after epoch: 1after epoch: 1after epoch: 1after epoch: 1after epoch: 1after epoch: 1
Alice vaidly doy%s find timederaded, thou ding of the cours unanber. The said that inain. "You hat murs.” “She soor was hith evorse,'r dot toll, mose tich, and on an of thor Alice Hainl. Do?" "Ot'to the cut tike if _cmechmor orght brutroone™ tE qaG--woga sice coof-thing to mee tremants. EBmecr, thing Maing, abin. The ase or ham: you was, andibiguts wight Rizs, twaed of the Maice myY on: Nove in with Liglawee auns ghonh are reapl ard to coll," said than’ty suth Thech you wonty, in't gned ne every Whet her anbohe theted not’l? h, the recak retmon of Adise, and Thet as ormenter: quit do this noch, Tay tron's nece: “Do Umo hin, I jutse, re—turthed in welled upllent of cas to chas tas upiived bact, and as I bors on hie?” “Litelf] "E What (she 