In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [2]:
# Fix the NumPy and TensorFlow global seeds before building the shuffled pipeline.
np.random.seed(42)
tf.random.set_seed(42)

n_steps = 5

# Toy pipeline over the integers 0..14: overlapping windows of n_steps values
# (a new window every 2 elements, incomplete windows dropped), each flattened
# to a single tensor, shuffled, split into (all-but-last, all-but-first)
# pairs, then grouped three pairs per batch.
dataset = (
    tf.data.Dataset.from_tensor_slices(tf.range(15))
    .window(n_steps, shift=2, drop_remainder=True)
    .flat_map(lambda win: win.batch(n_steps))
    .shuffle(10)
    .map(lambda seq: (seq[:-1], seq[1:]))
    .batch(3)
    .prefetch(1)
)

# Print every batch so the input/target alignment can be inspected by eye.
for index, (x_batch, y_batch) in enumerate(dataset):
    print('_' * 20, 'Batch', index, '\nx_batch')
    print(x_batch.numpy())
    print('=' * 5, '\ny_batch')
    print(y_batch.numpy())

____________________ Batch 0 
x_batch
[[6 7 8 9]
 [2 3 4 5]
 [4 5 6 7]]
===== 
y_batch
[[ 7  8  9 10]
 [ 3  4  5  6]
 [ 5  6  7  8]]
____________________ Batch 1 
x_batch
[[ 0  1  2  3]
 [ 8  9 10 11]
 [10 11 12 13]]
===== 
y_batch
[[ 1  2  3  4]
 [ 9 10 11 12]
 [11 12 13 14]]


In [3]:
# Download the corpus (keras.utils.get_file returns the local file path) and
# read it fully into memory as one string.
shakespeare_url = 'https://homl.info/shakespeare'
filepath = keras.utils.get_file('shakespeare.txt', shakespeare_url)

# Pass the encoding explicitly: without it, open() uses the platform's locale
# encoding, so the decoded text could differ between machines.
# NOTE(review): UTF-8 is assumed; the distinct characters listed later in the
# notebook are all ASCII, which decodes identically under UTF-8.
with open(filepath, encoding='utf-8') as f:
    shakespeare_text = f.read()

In [4]:
# Peek at the first 148 characters of the corpus.
print(shakespeare_text[:148])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?



In [5]:
# Distinct characters of the lowercased corpus, sorted and joined into one string.
''.join(sorted(set(shakespeare_text.lower())))

"\n !$&',-.3:;?abcdefghijklmnopqrstuvwxyz"

In [6]:
# Character-level tokenizer: IDs are assigned per character rather than per word.
tokenizer = keras.preprocessing.text.Tokenizer(char_level=True)
# NOTE(review): the raw string (not a one-element list) is passed here. The
# later `dataset_size = tokenizer.document_count` presumably relies on each
# character being counted as its own document -- confirm before changing this.
tokenizer.fit_on_texts(shakespeare_text)

In [7]:
# Encode a sample word: the result holds one list of integer IDs per input text.
tokenizer.texts_to_sequences(['First'])

[[20, 6, 9, 8, 3]]

In [8]:
# Decode the same IDs back to text to check the round trip.
tokenizer.sequences_to_texts([[20, 6, 9, 8, 3]])

['f i r s t']

In [9]:
# Number of distinct tokens the tokenizer learned; used below as the one-hot
# depth and as the size of the model's output layer.
max_id = len(tokenizer.word_index)
# NOTE(review): presumably the total number of characters, because
# fit_on_texts() was given a plain string and so counted one "document" per
# character -- verify if the fitting call ever changes.
dataset_size = tokenizer.document_count

In [10]:
# Encode the whole corpus as a single sequence of IDs and shift every ID down
# by one (the decoding cells add the 1 back before calling the tokenizer).
encoded = np.array(tokenizer.texts_to_sequences([shakespeare_text])[0]) - 1

In [11]:
# Use the first 90% of the encoded sequence for training; the multiply-then-
# floor-divide keeps the computation in integers.
train_size = dataset_size * 90 // 100
dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])

In [12]:
# Each window is one element longer than n_steps so that it can later be split
# into an input (all but the last ID) and a target (all but the first ID).
n_steps = 100
window_length = n_steps + 1
# shift=1: a new window starts at every position; incomplete tail windows are dropped.
dataset = dataset.window(window_length, shift=1, drop_remainder=True)

In [13]:
# Each window is itself a dataset; batching it by its full length and
# flattening yields one tensor of window_length IDs per window.
dataset = dataset.flat_map(lambda window: window.batch(window_length))

In [14]:
# Reset both global seeds right before the shuffle is added to the pipeline.
np.random.seed(42)
tf.random.set_seed(42)

batch_size = 32

# Shuffle the windows, group them into batches, then split every batch into
# inputs (all but the last step) and targets (all but the first step).
dataset = (
    dataset
    .shuffle(10000)
    .batch(batch_size)
    .map(lambda batch: (batch[:, :-1], batch[:, 1:]))
)

In [15]:
def _one_hot_inputs(x_batch, y_batch):
    """One-hot encode the input IDs (depth max_id); pass the targets through unchanged."""
    return tf.one_hot(x_batch, depth=max_id), y_batch


dataset = dataset.map(_one_hot_inputs)

In [16]:
# Add a prefetch stage with a buffer of one batch at the end of the pipeline.
dataset = dataset.prefetch(1)

In [17]:
# Sanity check on a single batch: inputs carry a trailing one-hot axis,
# targets remain plain integer IDs.
for x_batch, y_batch in dataset.take(1):
    print(x_batch.shape, y_batch.shape)

(32, 100, 39) (32, 100)


In [18]:
# Two stacked 128-unit GRU layers followed by a per-time-step softmax over the
# max_id character classes. return_sequences=True keeps an output for every
# time step, matching the per-step targets produced by the input pipeline.
# input_shape=[None, max_id] leaves the sequence length unspecified.
model = keras.models.Sequential([
    keras.layers.GRU(128, return_sequences=True, 
                     input_shape=[None, max_id], 
                     dropout=0.2), 
    keras.layers.GRU(128, return_sequences=True, 
                     dropout=0.2), 
    keras.layers.TimeDistributed(
        keras.layers.Dense(max_id, activation='softmax'))
])

# The "sparse" loss is used because the targets are integer IDs, not one-hot vectors.
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

# Train for 5 passes over the windowed dataset; `history` keeps the fit() result.
history = model.fit(dataset, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [19]:
def preprocess(texts):
    """Turn a list of strings into the one-hot tensor the model takes as input.

    The texts are encoded with the notebook's `tokenizer`, every ID is shifted
    down by one (the same shift used when the training data was encoded), and
    the result is one-hot encoded with depth `max_id`.
    """
    token_ids = tokenizer.texts_to_sequences(texts)
    shifted_ids = np.array(token_ids) - 1
    return tf.one_hot(shifted_ids, max_id)

In [20]:
# Greedy prediction: take the most probable class at every time step.
x_new = preprocess(['How are yo'])
y_pred = np.argmax(model(x_new), axis=-1)
# +1 undoes the -1 shift applied in preprocess(); [0][-1] picks the last
# character of the first decoded text, i.e. the prediction for the final step.
tokenizer.sequences_to_texts(y_pred + 1)[0][-1]

'u'

In [21]:
def next_char(text, temperature=1):
    """Sample the character that follows `text` from the model's prediction.

    Args:
        text: Seed string; it is encoded with `preprocess` before being fed
            to the model.
        temperature: Positive number that divides the log-probabilities
            before sampling. Values below 1 sharpen the distribution (the
            most likely characters dominate); values above 1 flatten it.

    Returns:
        The decoded text for the single sampled ID, as returned by
        `tokenizer.sequences_to_texts`.

    Raises:
        ValueError: If `temperature` is not strictly positive. Zero would
            divide the log-probabilities by zero and a negative value would
            invert the distribution, so both are rejected up front instead
            of silently producing meaningless samples.
    """
    # `not (t > 0)` rather than `t <= 0` so that NaN is rejected as well.
    if not temperature > 0:
        raise ValueError(
            'temperature must be strictly positive, got %r' % (temperature,))
    x_new = preprocess([text])
    # Keep only the prediction for the last time step; the `-1:` slice (rather
    # than `-1`) keeps the result 2-D, the layout passed to tf.random.categorical.
    y_proba = model(x_new)[0, -1:, :]
    rescaled_logits = tf.math.log(y_proba) / temperature
    # +1 undoes the -1 shift applied to the IDs in preprocess().
    char_id = tf.random.categorical(rescaled_logits, num_samples=1) + 1
    return tokenizer.sequences_to_texts(char_id.numpy())[0]

In [22]:
# Sample (rather than argmax) the next character for the same seed text.
next_char('How are yo', temperature=1)

'u'

In [23]:
def complete_text(text, n_chars=50, temperature=1):
    """Extend `text` by `n_chars` sampled characters and return the result.

    Every new character is drawn with `next_char` from the text generated so
    far, using the given `temperature`. The text is returned unchanged when
    `n_chars` is zero or negative.
    """
    generated = text
    for _step in range(n_chars):
        # Feed the whole text produced so far back in to get the next character.
        generated = generated + next_char(generated, temperature)
    return generated

In [24]:
# Low temperature: the log-probabilities are divided by 0.2, so sampling
# concentrates on the most likely characters.
complete_text('t', temperature=0.2)

'ther for the sword of sea,\nthat she is not a scoldi'

In [25]:
# Temperature 1: sample directly from the model's predicted distribution.
complete_text('t', temperature=1)

'tod sea:\ni promised hor he will beside\nyour penile '

In [26]:
# High temperature: the log-probabilities are halved, flattening the
# distribution so that unlikely characters are sampled more often.
complete_text('t', temperature=2)

"tvely!\nwell maze: yel.' all heart ruch-jaaema'trds,"