In [None]:
import os
import re
import textwrap
import time

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers

In [None]:
# Plain-text "Alice in Wonderland"; get_file downloads it once and reuses the local cache afterwards.
path_to_file = tf.keras.utils.get_file(
    fname='alice_in_wonderland.txt',
    origin='https://gist.githubusercontent.com/phillipj/4944029/raw/75ba2243dd5ec2875f629bf5d79f6c1e4b5a8b46/alice_in_wonderland.txt',
)

Downloading data from https://gist.githubusercontent.com/phillipj/4944029/raw/75ba2243dd5ec2875f629bf5d79f6c1e4b5a8b46/alice_in_wonderland.txt


In [None]:
# Read raw bytes and decode explicitly: binary mode leaves line endings untouched.
# The context manager closes the file handle (the one-liner never closed it).
with open(path_to_file, 'rb') as fh:
    text = fh.read().decode(encoding='utf-8')

In [None]:
# Drop the first 180 characters (front matter) so the text begins at "CHAPTER I",
# as the printed preview of this cell shows.
# NOTE(review): this reassigns `text`, so re-running the cell without first re-running
# the file-reading cell trims another 180 characters.
text = text[180:]
print(text[:500])

                           CHAPTER I

                      Down the Rabbit-Hole


  Alice was beginning to get very tired of sitting by her sister
on the bank, and of having nothing to do:  once or twice she had
peeped into the book her sister was reading, but it had no
pictures or conversations in it, `and what is the use of a book,'
thought Alice `without pictures or conversation?'

  So she was considering in her own mind (as well as she could,
for the hot day made her feel very sleepy and s


In [None]:
def preprocess(text):
    """Normalise the raw book text into one lower-case line of prose.

    Steps, in order:
      * lower-case everything;
      * drop chapter headings ("chapter <n> <title>" followed by newlines);
      * strip leading whitespace and turn newlines into spaces, collapsing runs of spaces;
      * remove '*' separator runs, quotes, brackets, backticks and underscores;
      * map ? ! ; : to '.', and parentheses and '--' to ','.

    Args:
        text: the raw book text.

    Returns:
        The cleaned string.
    """
    # Raw strings: '\s', '\w', '\*' ... are invalid escape sequences in normal
    # string literals (SyntaxWarning on Python 3.12+).
    res = text.lower()
    res = re.sub(r"chapter\s+\w+\s+[\w\-' ?]+\n+", '', res)
    res = re.sub(r'^\s+', '', res)  # only the very start of the text (no MULTILINE)
    res = res.replace('\n', ' ')
    res = re.sub(r' +', ' ', res)
    res = re.sub(r'(\*\s*)+', '', res)  # "* * *" section separators
    res = re.sub(r"""['"\[\]`_]""", '', res)
    res = re.sub(r'[?!;:]', '.', res)  # every sentence-ending mark becomes '.'
    res = re.sub(r'[()]', ',', res)
    res = res.replace('--', ',')
    return res

In [None]:
# Clean the book once; everything downstream works on this string.
processed_text = preprocess(text)
print(
    processed_text[:100],
    f'\nLength of preprocessed text: {len(processed_text)} characters',
    sep='\n',
)

alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to

Length of preprocessed text: 138632 characters


In [None]:
# Maximum characters per training line, including its terminating '.';
# the model's input length is maxlen - 1 (see input_len).
maxlen = 132

In [None]:
def _pack_clauses(sentence, budget):
    """Greedily join the comma-separated clauses of `sentence` into chunks of <= budget chars.

    A clause that alone exceeds `budget` becomes its own (over-long) chunk.
    """
    chunks = []
    chunk = None
    for part in sentence.split(','):
        if chunk is not None and len(chunk) + 1 + len(part) <= budget:
            chunk += ',' + part
        else:
            if chunk is not None:
                chunks.append(chunk)
            chunk = part  # start a new chunk without a leading ','
    chunks.append(chunk)  # flush the last chunk (it was previously dropped)
    return chunks


def split_into_lines(text, maxlen):
    """Split `text` into '.'-terminated training lines of at most `maxlen` characters.

    Sentences (split on '.') that fit are kept whole. Longer ones are packed
    clause by clause (split on ','), and any clause still too long is
    word-wrapped. Every returned line therefore fits in `maxlen` and keeps its
    terminating '.'. Surrounding whitespace is stripped and empty fragments
    (from '..', a trailing '.', etc.) are dropped.

    Args:
        text: preprocessed text using '.' as the only sentence terminator.
        maxlen: maximum length of a returned line, including the '.'.

    Returns:
        List of lines, each ending in '.'.
    """
    budget = maxlen - 1  # reserve one character for the appended '.'
    result = []
    for sentence in text.split('.'):
        sentence = sentence.strip()
        if not sentence:
            continue
        if len(sentence) <= budget:
            result.append(sentence + '.')
            continue
        for chunk in _pack_clauses(sentence, budget):
            # textwrap breaks at spaces and splits a word only if it alone exceeds budget
            result.extend(piece + '.' for piece in textwrap.wrap(chunk.strip(), budget))
    return result


lines = split_into_lines(processed_text, maxlen)

In [None]:
# Index 0 is the padding token, shown as '*'. preprocess() strips every '*',
# so the pad character cannot collide with a real character of the text.
assert '*' not in processed_text, "padding character '*' must not occur in the text"
vocab = ['*'] + sorted(set(processed_text))
print(f'{len(vocab)} unique characters')

31 unique characters


In [None]:
# Hyper-parameters derived from the vocabulary and line length.
vocab_size = len(vocab)          # includes the padding token at index 0
input_len = maxlen - 1           # each line minus its last char is one model input
end_char_num = vocab.index('.')  # token id that terminates a sentence
batch_size = 128

In [None]:
encode = {ch: idx for idx, ch in enumerate(vocab)}

# Lines longer than maxlen cannot be represented and get truncated (losing their
# terminating '.'); say so instead of doing it silently.
n_truncated = sum(len(line) > maxlen for line in lines)
if n_truncated:
    print(f'warning: {n_truncated} lines longer than {maxlen} characters were truncated')

# One row of token ids per line, right-padded with 0 (the padding token).
# int32 matches the model's Input dtype.
encoded_lines = np.zeros((len(lines), maxlen), dtype=np.int32)
for row, line in enumerate(lines):
    ids = [encode[ch] for ch in line[:maxlen]]
    encoded_lines[row, :len(ids)] = ids

# Next-character prediction: inputs are chars [0, maxlen-1), targets chars [1, maxlen).
x = encoded_lines[:, :-1]
y = encoded_lines[:, 1:]

# Shuffle (reshuffled every epoch by default) so batches are not always in book order.
ds = tf.data.Dataset.from_tensor_slices((x, y))
ds = ds.shuffle(buffer_size=len(encoded_lines)).batch(batch_size)

In [None]:
def get_angles(pos, i, d_model):
    """Sinusoid angles pos * 1 / 10000^(2*floor(i/2) / d_model).

    `pos` and `i` are broadcast against each other; passing a column of
    positions and a row of dimension indices yields a (positions, dims) grid.
    """
    pair_index = i // 2  # dims 2k and 2k+1 share one frequency
    angle_rates = 1 / np.power(10000, (2 * pair_index) / np.float32(d_model))
    return pos * angle_rates

In [None]:
def positional_encoding(position, d_model):
    """Fixed sinusoidal positional-encoding table.

    Returns a float32 tensor of shape (1, position, d_model): even columns hold
    sin of the angles from get_angles, odd columns hold cos.
    """
    positions = np.arange(position)[:, np.newaxis]
    dims = np.arange(d_model)[np.newaxis, :]
    table = get_angles(positions, dims, d_model)

    table[:, ::2] = np.sin(table[:, ::2])
    table[:, 1::2] = np.cos(table[:, 1::2])

    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)

In [None]:
class OneHotAndPositionEmbedding(layers.Layer):
    """One-hot encodes token ids and adds a fixed sinusoidal positional encoding.

    The layer has no learned weights; its output feature size is `vocab_size`.

    Args:
        max_len: longest sequence the layer will see (rows of the encoding table).
        vocab_size: number of token ids; also the output feature dimension.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, max_len, vocab_size, **kwargs):
        super(OneHotAndPositionEmbedding, self).__init__(**kwargs)
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.pos_enc = positional_encoding(max_len, vocab_size)  # (1, max_len, vocab_size)

    def call(self, x):
        # x holds integer token ids, presumably (batch, seq_len) with seq_len <= max_len — TODO confirm
        seq_len = tf.shape(x)[-1]
        x = tf.one_hot(x, self.vocab_size)
        positions = self.pos_enc[:, :seq_len, :]
        return x + positions

    def get_config(self):
        """Constructor arguments, so models containing this layer can be saved and reloaded."""
        config = super(OneHotAndPositionEmbedding, self).get_config()
        config.update({"max_len": self.max_len, "vocab_size": self.vocab_size})
        return config

In [None]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """Lower-triangular attention mask tiled across the batch.

    Entry [b, i, j] is true/1 when destination position i may attend to source
    position j, i.e. j <= i + (n_src - n_dest). Shape: (batch_size, n_dest, n_src).
    """
    dest_idx = tf.range(n_dest)[:, None]
    src_idx = tf.range(n_src)
    allowed = tf.cast(dest_idx >= src_idx - n_src + n_dest, dtype)
    allowed = tf.reshape(allowed, [1, n_dest, n_src])
    repeats = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
    )
    return tf.tile(allowed, repeats)


class TransformerBlock(layers.Layer):
    """Decoder-style block: causal multi-head self-attention, then a feed-forward net.

    Each sub-layer is followed by dropout, a residual connection and layer
    normalisation (post-norm). The causal mask keeps position i from attending
    to later positions, so the block suits next-token prediction.

    Args:
        embed_dim: feature size of inputs/outputs (also passed as the attention key size).
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward sub-layer.
        rate: dropout rate after attention and after the feed-forward net.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super(TransformerBlock, self).__init__(**kwargs)
        # Constructor arguments, kept so get_config() can rebuild the layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = tf.keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        # inputs presumably (batch, seq_len, embed_dim) — TODO confirm for other callers
        input_shape = tf.shape(inputs)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
        attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
        attention_output = self.dropout1(attention_output)
        out1 = self.layernorm1(inputs + attention_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Constructor arguments, so models containing this layer can be saved and reloaded."""
        config = super(TransformerBlock, self).get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        })
        return config

In [None]:
num_heads = 8
feed_forward_dim = 256


def create_model(num_blocks=4, num_heads=num_heads, ff_dim=feed_forward_dim):
    """Build and compile the character-level decoder-only transformer.

    token ids -> one-hot + positional encoding -> `num_blocks` causal
    TransformerBlocks -> Dense logits over the vocabulary.

    Args:
        num_blocks: number of stacked TransformerBlocks (default 4, as before).
        num_heads: attention heads per block (defaults to the global above).
        ff_dim: hidden width of each block's feed-forward net.

    Returns:
        A compiled `tf.keras.Model` with two outputs: next-character logits of
        shape (batch, input_len, vocab_size), and the last block's activations.
        Only the logits are trained on.
    """
    inputs = layers.Input(shape=(input_len,), dtype=tf.int32)
    x = OneHotAndPositionEmbedding(input_len, vocab_size)(inputs)
    # The model width is the one-hot width, so every block uses embed_dim=vocab_size.
    for _ in range(num_blocks):
        x = TransformerBlock(vocab_size, num_heads, ff_dim)(x)
    outputs = layers.Dense(vocab_size)(x)
    model = tf.keras.Model(inputs=inputs, outputs=[outputs, x])
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # None: no loss on the second output (transformer activations), it is only exposed.
    model.compile("adam", loss=[loss_fn, None])
    return model

In [None]:
class TextGenerator(tf.keras.callbacks.Callback):
    """Keras callback that prints a sampled line of text every few epochs.

    Starting from `start_tokens`, it repeatedly feeds the model the most recent
    window of tokens and samples the next token from the top-k logits. It stops
    at the sentence terminator ('.'), at the padding token (id 0), or after
    `max_tokens` new tokens.

    Args:
        max_tokens: maximum number of tokens to generate per sample.
        start_tokens: prompt, as a list of token ids.
        vocab: id -> character list; id 0 is padding and `vocab.index('.')`
            ends a sentence.
        top_k: sample only among the k most likely next tokens.
        print_every: generate after every `print_every`-th epoch.
    """

    def __init__(
        self, max_tokens, start_tokens, vocab, top_k=5, print_every=1
    ):
        super(TextGenerator, self).__init__()  # let the base Callback set up its state
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.vocab = vocab
        self.print_every = print_every
        self.k = top_k
        self.pad_token = 0                 # vocab[0] is the padding character
        self.end_token = vocab.index('.')  # generating this ends the sentence

    def sample_from(self, logits):
        """Sample a token id among the `self.k` highest entries of a 1-D logits vector."""
        logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = tf.keras.activations.softmax(tf.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        """Map a token id back to its character."""
        return self.vocab[number]

    def _generate(self):
        """Autoregressively sample a continuation of `self.start_tokens`.

        Returns only the newly generated token ids (not the prompt).
        """
        window = self.model.input_shape[1]  # fixed sequence length of the model's Input
        tokens = list(self.start_tokens)
        generated = []
        while len(generated) < self.max_tokens:
            # Use the LAST `window` tokens so generation keeps moving forward
            # once the history outgrows the model's context length.
            context = tokens[-window:]
            sample_index = len(context) - 1
            x = np.array(
                [context + [self.pad_token] * (window - len(context))], dtype=np.int32
            )
            # Direct call rather than predict(): far cheaper for one tiny batch per step.
            logits, _ = self.model(x, training=False)
            token = int(self.sample_from(logits[0, sample_index]))
            generated.append(token)
            tokens.append(token)
            if token == self.end_token or token == self.pad_token:
                break
        return generated

    def on_epoch_end(self, epoch, logs=None):
        """Every `print_every` epochs, generate and print one sample."""
        if (epoch + 1) % self.print_every != 0:
            return
        generated = self._generate()
        txt = "".join(self.detokenize(t) for t in self.start_tokens + generated)
        print(f"generated text:\n{txt}\n")

In [None]:
# Prompt used by the first two training runs.
line_start = "after a time "
start_tokens = list(map(encode.__getitem__, line_start))
text_gen_callback = TextGenerator(maxlen, start_tokens, vocab, print_every=10)

In [None]:
model = create_model()

# Architecture is whatever create_model() builds: 4 transformer blocks, 8 attention heads.
# NOTE(review): the previous comment said "one decoder, 3 attention heads"; the logged
# output of this cell may come from that earlier configuration — re-run to confirm.
model.fit(ds, verbose=1, epochs=500, callbacks=[text_gen_callback])

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
generated text:
after a time stharend e e int  a het i hedo o t   ot se  and  t asngert o inet ond be eeee*

Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
generated text:
after a time in a in s shed thedo tor atese t so an the shan thedr that thes the athe arone s wailer went owaon aoule t s t nd.

Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
generated text:
after a time he tout and the wher are hesthing tha whe t o we t ithe thisth aryere s tinge she wou onert o storyour thisha ti t  .

Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
generated text:
after a time wated sat the shin thare she ar sho toous, whe.

Epoch 41/

<tensorflow.python.keras.callbacks.History at 0x7f1c3965f650>

In [None]:
# Same architecture as create_model() builds (4 transformer blocks, 8 attention heads),
# trained longer and sampled less often.
# NOTE(review): the previous comment said "3 decoders, 3 attention heads"; the logged
# output of this cell may come from that earlier configuration — re-run to confirm.
model2 = create_model()
text_gen_callback2 = TextGenerator(maxlen, start_tokens, vocab, print_every=50)
model2.fit(ds, verbose=0, epochs=1000, callbacks=[text_gen_callback2])

generated text:
after a time the shitt he treast saichte said the wathing her tis, a theres a whad.

generated text:
after a time tont the was as way, saw in what throw a would thre shingss tulle.

generated text:
after a time to the hattence all had hight a little be to befure.

generated text:
after a time of any, some sare though, which i march.

generated text:
after a time anythe was spil, the mouse.

generated text:
after a time of the griess, when a book in the word of him her her show whis was on one sigh to to simpeas.

generated text:
after a time in a little three way more, i though and seem or the work.

generated text:
after a time of make, i this mouse or way, ant i was ould to the duches went of beathing that with a litted whice.

generated text:
after a time of minute one shart, and, thos was going to be, as the cand off the timing it was the moutter, or with off.

generated text:
after a time it mean as the cormouse shouse.

generated text:
after a time you dance to, t

<tensorflow.python.keras.callbacks.History at 0x7f1c37884a50>

In [None]:
# 4 decoders, 8 attention heads
line_start3 = "alice "
# Separate name so the shared `start_tokens` prompt used by the earlier cells is not
# overwritten (re-running those cells afterwards would otherwise use this prompt).
start_tokens3 = [encode[char] for char in line_start3]
text_gen_callback3 = TextGenerator(maxlen, start_tokens3, vocab, top_k=3, print_every=10)

model3 = create_model()
model3.fit(ds, verbose=0, epochs=1000, callbacks=[text_gen_callback3])

generated text:
alice the  theth   antou as s o t thee a the the   the   a t   s an t s   t s t  athe   the e  t te hait he thee he  heatou rar  .

generated text:
alice as t soure ano and in the ase the thenger the she ang thed an thou s se t ishar teris the thither terer the t   ere.

generated text:
alice she thon the shat to the te the an the wat the wery to in ther the oure wo wond s are o to the the thes thas t reasere t ar .

generated text:
alice sha thase sa wan son the that sher to sour wast this to bis *

generated text:
alice the to so aid there tait she wat thit wat was sand, the he to the sousthe whouters the the seat and to wis wen the then t.

generated text:
alice alice.

generated text:
alice the weren the do greantly.

generated text:
alice the mane had a got in the go the was of the said the mistenerse hered whas the drook shat the say sain.

generated text:
alice andid the matter toice the kis what wit so tremough the dind they was they do to the wat with way the t

<tensorflow.python.keras.callbacks.History at 0x7f56802dc890>