# Text generation using a RNN ✍️ 

### Imports

In [1]:
import numpy as np
import time
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.utils import get_file
from tensorflow.strings import unicode_split, reduce_join, join
from tensorflow.keras.layers import StringLookup, Dense, GRU, Embedding
from tensorflow.data import Dataset, experimental
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow import constant, squeeze, SparseTensor, sparse, function
from tensorflow.random import categorical
from tensorflow.keras.models import Model

ModuleNotFoundError: No module named 'tensorflow'

### Get the data 📕

In [2]:
path_to_data = get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

### Have a look at the data 🔎

In [3]:
text = open(path_to_data, 'rb').read().decode(encoding='utf-8')

In [4]:
# Take a look at the first 250 characters in text
print(text[:250])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.



## 2️⃣ Preprocessing

### Vectorize the text

In [5]:
example_texts = ['abcdefg', 'xyz']

chars = unicode_split(example_texts, input_encoding='UTF-8')
chars

<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

### Generate the vocab 📖

❓ Generate a list of **unique characters** in our text and save it in the variable **`vocab`**.

In [6]:
vocab = sorted(set(text))

❓ Now create the StringLookup layer and save it as ids_from_chars:

In [7]:
ids_from_chars = StringLookup(vocabulary=vocab, mask_token=None)




❓ Use the layer below 👇 and edit chars variable above and see what happens when you add characters outside the vocab. 

In [8]:
ids = ids_from_chars(chars)
ids




<tf.RaggedTensor [[40, 41, 42, 43, 44, 45, 46], [63, 64, 65]]>

❗️ Here instead of passing the original vocabulary generated with sorted(set(text)), use the get_vocabulary() method of the StringLookup to get the vocabulary assigned to the previous ids_from_chars layer. 

In [9]:
chars_from_ids = StringLookup(vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)

This layer recovers the characters from the vectors of IDs, and returns them as a RaggedTensor of characters:

In [10]:
chars = chars_from_ids(ids)
chars

<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>

Use reduce_join to join the characters back into strings.

In [11]:
reduce_join(chars, axis=-1).numpy()

array([b'abcdefg', b'xyz'], dtype=object)

❓ Define a function text_from_ids that takes a tensor of ids and returns the corresponding text.

In [12]:
def text_from_ids(ids):
    return reduce_join(chars_from_ids(ids), axis=-1)

### The dataset 🚚

❓ First split our whole text using unicode_split and convert them all with ids_from_chars, to get all of our text as a single continuous array saved as all_ids.

In [13]:
all_ids = ids_from_chars(unicode_split(text, 'UTF-8'))
all_ids

<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([19, 48, 57, ..., 46,  9,  1], dtype=int64)>

In [14]:
ids_dataset = Dataset.from_tensor_slices(all_ids)

The batch method allows us to set how many characters we should take at a time!

In [15]:
sequences = ids_dataset.batch(101, drop_remainder=True)

for seq in sequences.take(1):
    print(chars_from_ids(seq))

tf.Tensor(
[b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':'
 b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o'
 b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h'
 b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p'
 b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e'
 b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i'
 b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y'
 b'o' b'u' b' '], shape=(101,), dtype=string)


It's easier to see if we join the tokens back into strings 👇:

In [16]:
for seq in sequences.take(5):
    print(text_from_ids(seq).numpy())

b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'


❓ Write a function split_input_target which converts a sequence to a (input, label) pair.

In [17]:
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    print(input_text)
    print(target_text)
    return input_text, target_text

We map the function to our dataset.

In [18]:
dataset = sequences.map(split_input_target)
dataset

Tensor("strided_slice:0", shape=(100,), dtype=int64)
Tensor("strided_slice_1:0", shape=(100,), dtype=int64)


<_MapDataset element_spec=(TensorSpec(shape=(100,), dtype=tf.int64, name=None), TensorSpec(shape=(100,), dtype=tf.int64, name=None))>

Checkout what our dataset looks like now 👇

In [19]:
for input_example, target_example in dataset.take(1):
    print('Input :', text_from_ids(input_example).numpy())
    print('Target:', text_from_ids(target_example).numpy())

Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '


### Optimizing the dataset 🛠️

In [20]:
BATCH_SIZE = 64
BUFFER_SIZE = 10000

dataset = (
    dataset
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(experimental.AUTOTUNE))
dataset

<_PrefetchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>

## 3️⃣ Building the Model

In [21]:
class MyModel(Model):
    def __init__(self, vocab_size):
        super().__init__(self)
        self.embedding = Embedding(vocab_size, 256)
        self.gru = GRU(1024, return_sequences=True, return_state=True)
        self.dense = Dense(vocab_size)
        
    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        else:
            return x

In [22]:
# Length of the vocabulary in StringLookup Layer
vocab_size = len(ids_from_chars.get_vocabulary())
model = MyModel(vocab_size=vocab_size)

❗️ For each character the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character.

### Check the model 🔬

In [23]:
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, '# (batch_size, sequence_length, vocab_size)')

(64, 100, 66) # (batch_size, sequence_length, vocab_size)


Try it for the first example in the batch:

In [24]:
sampled_indices = categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = squeeze(sampled_indices, axis=-1).numpy()

In [25]:
sampled_indices

array([19,  6, 64, 55, 41,  2,  2, 31, 64, 23, 55,  7, 30, 58, 27, 38,  4,
        4,  4, 65,  6, 21, 39, 53, 60, 33, 24, 49,  0, 57, 28, 49, 46, 18,
       61, 59, 62, 31,  8, 23, 56, 25, 15, 24, 44, 37, 21, 26, 63, 56, 28,
       36, 63,  8, 42,  1, 22, 42, 53,  6, 58,  2, 46,  0, 36, 18, 22, 42,
        9, 39, 35, 30, 57,  5, 12, 60, 40, 50, 10, 42, 34, 38, 55,  2, 47,
       31, 37, 63,  5, 53, 29, 19, 41, 44, 46, 59, 36,  4, 55, 40],
      dtype=int64)

### Train the model 🏋️‍♂️

In [26]:
model.compile(optimizer='adam', loss=SparseCategoricalCrossentropy(from_logits=True))




In [27]:
%%time
history = model.fit(dataset, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
CPU times: total: 5h 9min 12s
Wall time: 19min 21s


## 4️⃣ Generate text 🧠

In [28]:
class OneStep(Model):
  def __init__(self, model, chars_from_ids, ids_from_chars):
    super().__init__()
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars

    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = SparseTensor(
        # Put a -inf at each bad index.
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape to the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = sparse.to_dense(sparse_mask)

  @function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs.
    input_chars = unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states,
                                          return_state=True)
    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = categorical(predicted_logits, num_samples=1)
    predicted_ids = squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)

    # Return the characters and model state.
    return predicted_chars, states

In [29]:
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

### Using the model 📝

In [30]:
%%time
states = None
next_char = constant(['Juliet: Where art thou, Romeo?'])
result = [next_char]

for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)

result = join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)

Juliet: Where art thou, Romeo?

First Citizen:
Ay, my lord, poor Lucentio,
Can you will we trop you this islay the miserarish
perficed, you might the greaver dreeminations
When the swretch that with muinted that therein our general, alivert
Shoulderly gentle; then in die full with his prinusen
young anjer: heaven shall.

SICINIUS:
Ay, fooliaded manter lord,
The Dukely pleased of him foul attary.

First CKifnd:
Dir coveryot business! speak like strong: come hit
dut-birsches boods, by so friend themselvess showfulmness of my bidd
A rankly lipty for joys
Of the werved of aly as countil'd forth
And blows country you out to knot go.

Sailo, I
besaie our eyes, and the sweet clargest corn.

KING EDWARD IV:
Your bidh full by those thing thought--words in the house of that mouth, not
A man parts afterirgh, to my councry,
And spent to the kingdral sound feelingly:--

MIRGULE
S:
Before occiding down,
So hurthring: Through not restoul fly,
And I have sweet scole to go;
But when he crodds rid strai

In [31]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.