# **Text generation with an RNN**

Following the [TensorFlow tutorial](https://www.tensorflow.org/text/tutorials/text_generation).

This tutorial trains an RNN model to predict the next character in a sequence. The model is trained on text written by Shakespeare.

Something to keep in mind: the text generation process demonstrated in the Udacity course works as follows.
- Tokenize the text, then convert it into sequences of tokens.

- Each sequence has a fixed length. From each sequence, the last token is used as the label and the remaining tokens as the feature vector.

- Convert the labels into one-hot vectors whose length is the vocabulary size.

- Train a classification model on the sequences and one-hot labels.

- Generate text by providing a seed word and then running inference repeatedly, appending each prediction to the input.
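
A minimal plain-Python sketch of that preparation. The helper `make_training_pairs` is hypothetical, and as an assumption it uses a fixed-length sliding window over word tokens; the course's exact preprocessing (for example padded n-gram sequences) may differ:

```python
def make_training_pairs(tokens, seq_len):
    """Hypothetical helper: build (feature, one-hot label) pairs from a token list.

    Each window holds seq_len + 1 token ids; the last id is the label and
    the first seq_len ids are the feature vector.
    """
    vocab = sorted(set(tokens))
    index = {tok: i for i, tok in enumerate(vocab)}
    ids = [index[tok] for tok in tokens]

    features, labels = [], []
    for start in range(len(ids) - seq_len):
        window = ids[start:start + seq_len + 1]
        features.append(window[:-1])
        one_hot = [0] * len(vocab)  # one-hot label, length = vocabulary size
        one_hot[window[-1]] = 1
        labels.append(one_hot)
    return vocab, features, labels


vocab, features, labels = make_training_pairs("to be or not to be".split(), seq_len=2)
print(vocab)        # ['be', 'not', 'or', 'to']
print(features[0])  # [3, 0]        ("to be")
print(labels[0])    # [0, 0, 1, 0]  (next word: "or")
```

A classifier trained on these pairs predicts one word per window; generation then appends each prediction to the seed and repeats.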


## **Import dependencies**

In [1]:
import tensorflow as tf
import numpy as np
import os
import time

print(tf.__version__)

2.8.2


## **Get the dataset**

In [2]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt


In [None]:
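# read the raw bytes and decode them as UTF-8
with open(path_to_file, 'rb') as f:
  text = f.read().decode(encoding='utf-8')
# print only the beginning; the full text is over a million characters long
print(text[:250])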

In [4]:
print(f"Number of characters in the text: {len(text)}")

Number of characters in the text: 1115394


In [5]:
# collect the sorted list of unique characters in the file
vocab = sorted(set(text))
print(vocab)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [6]:
print(f"Number of unique characters in the text: {len(vocab)}")

Number of unique characters in the text: 65


The approach used here downloads the data as a file, reads the raw bytes and decodes them into a single Python string. I find this much simpler than loading the text through TensorFlow Datasets.

An interesting note: the vocabulary size is determined by the number of unique characters in the dataset. This matches the task the model is trained for, which is to predict the most probable next character given the preceding characters.


## **Process the text**

In [7]:
# split the sample strings into individual UTF-8 characters
example_texts = ["megasxlr", "theweekend"]

example_text_chars = tf.strings.unicode_split(example_texts, input_encoding="UTF-8")
print(example_text_chars)

<tf.RaggedTensor [[b'm', b'e', b'g', b'a', b's', b'x', b'l', b'r'],
 [b't', b'h', b'e', b'w', b'e', b'e', b'k', b'e', b'n', b'd']]>


In [None]:
for index, char_ in enumerate(list(vocab)):
  print(f"index:{index}, character:{char_}")

**Define a text encoder**

In [9]:
# create an encoder that converts characters into token ids
ids_from_chars = tf.keras.layers.StringLookup(vocabulary=list(vocab),
                                              mask_token=None)


In [10]:
ids = ids_from_chars(example_text_chars)
print(ids)

<tf.RaggedTensor [[52, 44, 46, 40, 58, 63, 51, 57],
 [59, 47, 44, 62, 44, 44, 50, 44, 53, 43]]>


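It seems to encode each character using its index in the original `vocab`, offset by one. The offset comes from `StringLookup` reserving index 0 for the out-of-vocabulary token `[UNK]` (with `mask_token=None` and the default single OOV index), which is also why the vocabulary size used later is 66 rather than 65.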
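
To make the offset concrete, here is a plain-Python sketch of a lookup table with one out-of-vocabulary slot at index 0. It mimics the behaviour observed above and is not `StringLookup`'s actual implementation; `build_lookup`, `encode` and `decode` are hypothetical helpers. With the same 65-character vocabulary printed earlier, it reproduces the ids shown for `megasxlr`:

```python
import string

def build_lookup(vocab):
    """Hypothetical helper: id 0 is reserved for unknown chars, like "[UNK]"."""
    table = ["[UNK]"] + list(vocab)
    char_to_id = {ch: i for i, ch in enumerate(table)}
    return table, char_to_id

def encode(text, char_to_id):
    # characters missing from the vocabulary map to the OOV id 0
    return [char_to_id.get(ch, 0) for ch in text]

def decode(ids, table):
    return "".join(table[i] for i in ids)

# the same 65 characters as the `vocab` printed earlier
vocab = sorted(set("\n !$&',-.3:;?" + string.ascii_letters))
table, char_to_id = build_lookup(vocab)

print(len(table))                                     # 66
print(encode("megasxlr", char_to_id))                 # [52, 44, 46, 40, 58, 63, 51, 57]
print(decode(encode("megasxlr", char_to_id), table))  # megasxlr
```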

It looks like [**`tf.keras.layers.StringLookup`**](https://www.tensorflow.org/api_docs/python/tf/keras/layers/StringLookup) is an alternative to the `Tokenizer` class, which is deprecated in v2.9. `StringLookup` can also output one-hot vectors (`output_mode='one_hot'`) instead of integer ids.

Another alternative for tokenization is [TextVectorization](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization).

**Define a text decoder**

In [11]:
# create a text decoder
chars_from_ids = tf.keras.layers.StringLookup(vocabulary=ids_from_chars.get_vocabulary(),
                                              invert=True,
                                              mask_token=None)


In [12]:
decoded_chars = chars_from_ids(ids)
print(decoded_chars)

<tf.RaggedTensor [[b'm', b'e', b'g', b'a', b's', b'x', b'l', b'r'],
 [b't', b'h', b'e', b'w', b'e', b'e', b'k', b'e', b'n', b'd']]>


In [13]:
# the decoded chars are returned as a ragged tensor of individual chars;
# we can join them back into one string per example
tf.strings.reduce_join(decoded_chars, axis=-1).numpy()

array([b'megasxlr', b'theweekend'], dtype=object)


In [15]:
def text_from_ids(ids):
  """Convert token ids back into strings by decoding them to chars and joining along the last axis."""
  return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)

The trained model should be able to predict the most probable next character given an initial character or sequence of characters.

To that end:
- the training dataset needs input/label pairs where the label is the input text shifted one character ahead, so that at each position the label is the character that comes next.
- Example: for the sequence `Megasxlr`, the input is `Megasxl` and the label is `egasxlr`.


In [16]:
# split the full text into individual characters
text_split_into_individual_chars = tf.strings.unicode_split(text, 'UTF-8')
print(text_split_into_individual_chars)


tf.Tensor([b'F' b'i' b'r' ... b'g' b'.' b'\n'], shape=(1115394,), dtype=string)


In [17]:
# convert each character into a token id
text_ids = ids_from_chars(text_split_into_individual_chars)
print(text_ids)


tf.Tensor([19 48 57 ... 46  9  1], shape=(1115394,), dtype=int64)


In [18]:
# create a dataset that yields one token id at a time
ids_dataset = tf.data.Dataset.from_tensor_slices(text_ids)
print(ids_dataset)

<TensorSliceDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>


In [19]:
for ids in ids_dataset.take(10):
  print(chars_from_ids(ids).numpy().decode('utf-8'))

F
i
r
s
t
 
C
i
t
i


In [20]:
# define the max_length of the sequences
seq_length = 100


In [21]:
# group the token stream into sequences of seq_length + 1 = 101 tokens
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)

# display a single batch
for seq in sequences.take(5):
  print(text_from_ids(seq).numpy())


b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'


In [22]:
for seq in sequences.take(5):
  print(f"{seq}")

[19 48 57 58 59  2 16 48 59 48 65 44 53 11  1 15 44 45 54 57 44  2 62 44
  2 55 57 54 42 44 44 43  2 40 53 64  2 45 60 57 59 47 44 57  7  2 47 44
 40 57  2 52 44  2 58 55 44 40 50  9  1  1 14 51 51 11  1 32 55 44 40 50
  7  2 58 55 44 40 50  9  1  1 19 48 57 58 59  2 16 48 59 48 65 44 53 11
  1 38 54 60  2]
[40 57 44  2 40 51 51  2 57 44 58 54 51 61 44 43  2 57 40 59 47 44 57  2
 59 54  2 43 48 44  2 59 47 40 53  2 59 54  2 45 40 52 48 58 47 13  1  1
 14 51 51 11  1 31 44 58 54 51 61 44 43  9  2 57 44 58 54 51 61 44 43  9
  1  1 19 48 57 58 59  2 16 48 59 48 65 44 53 11  1 19 48 57 58 59  7  2
 64 54 60  2 50]
[53 54 62  2 16 40 48 60 58  2 26 40 57 42 48 60 58  2 48 58  2 42 47 48
 44 45  2 44 53 44 52 64  2 59 54  2 59 47 44  2 55 44 54 55 51 44  9  1
  1 14 51 51 11  1 36 44  2 50 53 54 62  6 59  7  2 62 44  2 50 53 54 62
  6 59  9  1  1 19 48 57 58 59  2 16 48 59 48 65 44 53 11  1 25 44 59  2
 60 58  2 50 48]
[51 51  2 47 48 52  7  2 40 53 43  2 62 44  6 51 51  2 47 40 61 44  2 42


In [23]:
# define a function to split a sequence into a feature vector and a label
def split_input_target(sequence):
  input_text = sequence[:-1]
  target_text = sequence[1:]
  return input_text, target_text

# example
split_input_target(list("Megasxlr"))

(['M', 'e', 'g', 'a', 's', 'x', 'l'], ['e', 'g', 'a', 's', 'x', 'l', 'r'])

In [24]:
# apply the function to each 101-token sequence
dataset = sequences.map(split_input_target)
print(dataset)

<MapDataset element_spec=(TensorSpec(shape=(100,), dtype=tf.int64, name=None), TensorSpec(shape=(100,), dtype=tf.int64, name=None))>


In [25]:
for input_example, target_example in dataset.take(1):
  print(f"input: {text_from_ids(input_example).numpy()}")
  print(f"label: {text_from_ids(target_example).numpy()}")

input: b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'
label: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '


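The feature and label are nearly identical: the label is the feature shifted by one character, so at every position the target is the character that follows the input character.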
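
The chunking and shifting steps can be sketched in plain Python on a list of characters. `make_examples` is a hypothetical helper that imitates `batch(seq_length + 1, drop_remainder=True)` followed by `split_input_target`:

```python
def make_examples(ids, seq_length):
    """Hypothetical helper mirroring the tf.data pipeline above.

    Splits `ids` into non-overlapping chunks of seq_length + 1 items (a partial
    final chunk is dropped, like drop_remainder=True), then turns each chunk
    into an (input, target) pair where target is input shifted by one.
    """
    chunk_size = seq_length + 1
    examples = []
    for i in range(len(ids) // chunk_size):
        chunk = ids[i * chunk_size:(i + 1) * chunk_size]
        examples.append((chunk[:-1], chunk[1:]))
    return examples


for inp, target in make_examples(list("Megasxlr!"), seq_length=3):
    # at each position, the target is the character that follows the input character
    print(list(zip(inp, target)))
```

Here the trailing `!` is dropped because it does not fill a whole chunk of 4 characters.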

**Create training batches from the mapped dataset**

In [26]:
BATCH_SIZE = 64
BUFFER_SIZE = 10000

dataset = (dataset
           .shuffle(BUFFER_SIZE)
           .batch(BATCH_SIZE, drop_remainder=True)
           .prefetch(tf.data.experimental.AUTOTUNE))

dataset

<PrefetchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>

Summary of the text processing pipeline that generates features and labels:
1. Split the text into individual chars.
2. Convert the individual chars into tokens.
3. Group the tokens into sequences of 101 tokens.
4. Split each sequence into a feature vector and a label:
  - the feature is the sequence excluding the last token (first 100 tokens)
  - the label is the sequence excluding the first token (last 100 tokens)
5. Shuffle the (feature, label) pairs and group them into batches of 64 samples.
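
A quick arithmetic check of how many sequences and batches this pipeline yields, using the character count printed earlier (1,115,394) and the same `seq_length` and `BATCH_SIZE`:

```python
num_chars = 1115394   # len(text), printed earlier
seq_length = 100
BATCH_SIZE = 64

# batch(seq_length + 1, drop_remainder=True): only full 101-token sequences
num_sequences = num_chars // (seq_length + 1)
# batch(BATCH_SIZE, drop_remainder=True): only full batches of 64
num_batches = num_sequences // BATCH_SIZE

print(num_sequences, num_batches)  # 11043 172
```

So one epoch of training should take 172 steps, with 51 trailing characters and 35 leftover sequences dropped.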

## **Define the text generation model**

In [27]:
# define the model parameters
vocab_size = len(ids_from_chars.get_vocabulary())
print(vocab_size)

embedding_dim = 256
rnn_units = 1024

66


In [28]:
# define the model
class MyModel(tf.keras.Model):

  # define class constructor to initialise the layers with the parameters
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True,
                                   return_state=True)
    self.dense = tf.keras.layers.Dense(vocab_size)
  

  # call() defines the forward pass; a subclassed Keras model implements its logic here
  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs

    # pass input to the embedding layer
    x = self.embedding(x, training=training)

    # set initial state 
    if states is None:
      states = self.gru.get_initial_state(x)
    
    # call the GRU layer with the embedding, initial state and training
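    # `training` tells layers such as dropout whether they run in training or inference mode;
    # this GRU has no dropout, so the flag makes no difference here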
    x, states = self.gru(x, initial_state=states, training=training)

    # Call the dense layer with the output of the GRU layer
    x = self.dense(x, training=training)

    # return the dense output, plus the GRU state if requested
    if return_state:
      return x, states
    else:
      return x
     

In [29]:
# define an instance of the model
model = MyModel(vocab_size, embedding_dim, rnn_units)

Try the model on a single batch from the dataset:

In [None]:
for batch_feature_vector, batch_label in dataset.take(1):
  example_batch_predictions = model(batch_feature_vector)
  print(batch_feature_vector.shape)
  print(example_batch_predictions.shape, "(batch size, sequence_length, vocab_size)\n\n")
  #print(example_batch_predictions)

There are 64 samples in a batch.

For each sample the model predicts a sequence of length 100. Recall that the label is a sequence of tokens, not an individual char.
Also recall that `return_sequences=True` is set on the GRU layer, so it produces an output for each token in the sequence.

Finally, the dense layer has 66 units, so it produces 66 values for each input position.


Another explanation of the shape of the model output:
As the model iterates through each token in the sequence of length 100, it produces 66 values, one score (logit) for each possible next char. These are unnormalized scores, not yet a probability distribution.

So each sample yields a 100 x 66 array: one row of 66 logits per position in the sequence.

In [None]:
for i in range(100):
  # index of the largest logit at position i, and the char it maps to
  predicted_id = np.argmax(example_batch_predictions[0][i])
  predicted_char = chars_from_ids(predicted_id).numpy().decode('utf-8')
  print(f"predicted class for the next char: {predicted_id}   predicted char: {predicted_char!r}")

In [43]:
# the dense layer has one unit per vocabulary entry, so each row of 66 values
# scores every possible next char; argmax picks the highest-scoring one
predicted_token = [np.argmax(distribution) for distribution in example_batch_predictions[0].numpy()]
predicted_chars = [chars_from_ids(token).numpy().decode('utf-8') for token in predicted_token]

predicted_string = "".join(predicted_chars)
print(predicted_string)

 lulSSolhzhivD..;loloolj
hSl3;;X&.i,..Ussih  hSSo:l..llj;lj;Sol.cDo.ojSSo.rrPS;:XllMarDvD[UNK]Dmh-wrwlfN


The most probable next character is determined from the output of the dense layer.

In [49]:
# Is the sum of the assumed distribution equal to 1
sum_of_assumed_distribution = [np.sum(assumed_distribution) for assumed_distribution in example_batch_predictions[0].numpy()]
print(sum_of_assumed_distribution)

[-0.036192082, -0.100264624, 0.019551916, 0.10647095, 0.13143021, 0.036574095, 0.116993666, 0.15830699, 0.030974502, 0.025985744, -0.054978117, -0.11925046, 0.015699547, 0.10477143, -0.018283477, -0.042313967, -0.08930658, 0.0482501, 0.13610755, 0.04593085, 0.122342065, 0.1647057, 0.023054179, -0.009273745, -0.09761684, -0.10142118, 0.01766396, -0.066134945, -0.00088376366, -0.14412998, -0.15900391, -0.0009611242, -0.06818372, -0.12073738, -0.1522885, -0.15655327, -0.14487398, -0.2590501, -0.21714349, -0.17405833, 0.006345008, 0.020538159, -0.0562195, -0.06491099, -0.22549288, -0.1968332, -0.046858937, -0.063140176, -0.03473142, -0.07979047, 0.05666147, -0.039546542, -0.0500682, 0.06933912, -0.032980267, -0.04307981, -0.09222884, 0.04706017, 0.016584598, -0.053318217, 0.0377983, 0.121271715, 0.1634708, 0.016889958, -0.03086875, 0.06405425, -0.06029231, -0.060481455, -0.03330439, -0.033128005, 0.04792918, -0.061681267, 0.08374953, 0.019326497, 0.05886274, 0.0750042, -0.05857191, -0.0263

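It looks as if the model is not predicting a probability distribution for each input: the rows don't sum to 1. That is because the dense layer has no activation, so it outputs unnormalized logits. A softmax turns each row into a probability distribution, and the next char can then be picked with argmax or by sampling.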
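
A minimal plain-Python sketch of that conversion. The `softmax` and `argmax` helpers are illustrative, not the notebook's code (in TensorFlow the equivalent of the first is `tf.nn.softmax`), and the logits are made up. Because softmax preserves the ordering of the logits, argmax picks the same char whether it is applied before or after the softmax:

```python
import math

def softmax(logits):
    """Convert a row of logits into probabilities that sum to 1."""
    m = max(logits)                           # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)

# made-up logits, of the same small magnitude as the untrained model's output
logits = [-0.03, 0.12, 0.05, -0.2]
probs = softmax(logits)

print(f"sum of logits: {sum(logits):.2f}")         # -0.06
print(f"sum of probabilities: {sum(probs):.6f}")  # 1.000000
print(argmax(logits), argmax(probs))               # 1 1
```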

In [32]:
model.summary()

Model: "my_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       multiple                  16896     
                                                                 
 gru (GRU)                   multiple                  3938304   
                                                                 
 dense (Dense)               multiple                  67650     
                                                                 
Total params: 4,022,850
Trainable params: 4,022,850
Non-trainable params: 0
_________________________________________________________________
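
The parameter counts in the summary can be reproduced by hand. This sketch assumes the Keras GRU default `reset_after=True` (TF 2.x), under which each of the GRU's three blocks (update gate, reset gate, candidate state) has an input kernel, a recurrent kernel and two bias vectors:

```python
vocab_size, embedding_dim, rnn_units = 66, 256, 1024

# Embedding: one embedding_dim vector per vocabulary entry
embedding_params = vocab_size * embedding_dim
# GRU (reset_after=True): 3 blocks x (input kernel + recurrent kernel + 2 biases)
gru_params = 3 * (embedding_dim * rnn_units + rnn_units * rnn_units + 2 * rnn_units)
# Dense: weight matrix plus one bias per output unit
dense_params = rnn_units * vocab_size + vocab_size

print(embedding_params, gru_params, dense_params)    # 16896 3938304 67650
print(embedding_params + gru_params + dense_params)  # 4022850
```

Almost all of the roughly 4 million parameters sit in the GRU's recurrent kernel.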


In [33]:
# Take the first sample in the batch of 64
print(example_batch_predictions[0])

tf.Tensor(
[[-1.76655850e-03 -2.36624526e-03  1.48365814e-02 ... -8.10714625e-03
   2.89362529e-03  1.59120071e-03]
 [-2.67120870e-03  7.79632665e-03  1.29185822e-02 ...  1.02399159e-02
  -8.53871275e-03 -9.98578034e-03]
 [ 4.10384359e-03 -1.14062941e-03  7.91937299e-03 ...  8.41901638e-05
   3.16182151e-04 -1.42250005e-02]
 ...
 [ 1.13328472e-02  8.69241729e-03  2.03810260e-03 ... -2.08790274e-03
  -5.54079190e-03 -3.73866968e-03]
 [ 9.16255545e-03  7.48849940e-03 -8.99519678e-03 ... -5.08212904e-03
  -1.10344915e-02  4.23197169e-04]
 [ 1.13810971e-02 -4.56483290e-03  8.98127630e-03 ...  4.93140658e-03
  -4.98686358e-03 -6.07391447e-03]], shape=(100, 66), dtype=float32)


In [None]:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
print(sampled_indices)

According to the docs for [tf.random.categorical](https://www.tensorflow.org/api_docs/python/tf/random/categorical), it draws samples from a *categorical distribution* defined by the logits passed in: each row is treated as one distribution over classes, and `num_samples` class indices are drawn from each row.

In [35]:
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()
print(sampled_indices)

[16 56 44 64 37 50  4  5  2 15  7 62 15 61 26 43 41 62 20 38  5 51  3 49
 43 39 14 55 59  3 26 13 42 57 51  0  9  2 47 43 41 27 59 47 11 59 14 40
 32 15 26 18 30 41 38 23 34 24 43  5 56 35 42 23 21  1 33 32 24 37 32 40
 22 58 23 51  7  4 33 11 56 29 49 10 23 19 13 18 38 49 12 55 65 42 43  2
 36 36 63 37]


In [36]:
# display the input and model prediction
print("Input:\n", text_from_ids(batch_feature_vector[0]).numpy())
print("Expected label: \n", text_from_ids(batch_label[0]).numpy())
print("Predicted label: \n", text_from_ids(sampled_indices).numpy())

Input:
 b"nt; and every one did threat\nTo-morrow's vengeance on the head of Richard.\n\nRATCLIFF:\nMy lord!\n\nKING"
Expected label: 
 b"t; and every one did threat\nTo-morrow's vengeance on the head of Richard.\n\nRATCLIFF:\nMy lord!\n\nKING "
Predicted label: 
 b'CqeyXk$& B,wBvMdbwGY&l!jdZApt!M?crl[UNK]. hdbNth:tAaSBMEQbYJUKd&qVcJH\nTSKXSaIsJl,$T:qPj3JF?EYj;pzcd WWxX'


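I wasn't sure why `tf.random.categorical` is used here instead of simply taking the argmax. The TensorFlow tutorial notes that sampling from the distribution matters, because always taking the argmax can easily get the model stuck in a loop.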

In [67]:
# try out tf.random.categorical: equal logits, so each of the 3 classes has probability 1/3
logits = tf.math.log([[0.75, 0.75, 0.75]])
samples = tf.random.categorical(logits, 100)

print(logits)
print(samples)


tf.Tensor([[-0.28768206 -0.28768206 -0.28768206]], shape=(1, 3), dtype=float32)
tf.Tensor(
[[1 2 0 1 0 0 2 2 2 2 0 1 1 0 1 1 0 0 1 0 0 1 1 0 0 1 2 2 1 2 0 1 0 1 2 0
  2 0 1 1 1 1 2 2 0 0 0 2 0 0 0 1 1 1 1 1 1 2 1 1 2 2 0 2 2 2 0 2 0 2 0 2
  0 2 1 0 2 1 1 2 2 1 0 2 0 2 2 1 0 1 0 1 2 2 2 2 2 2 2 2]], shape=(1, 100), dtype=int64)


Looking at the output, `tf.random.categorical` samples from the categorical distribution and returns class indices rather than values.

Given logits of shape `[batch_size, num_classes]`, it draws `num_samples` class indices for each row and returns a tensor of shape `[batch_size, num_samples]`, here `(1, 100)`.

In [69]:
from collections import Counter

samples_counts = Counter(samples[0].numpy())
print(f"Number of occurrences of 0 in samples: {samples_counts[0]}")
print(f"Number of occurrences of 1 in samples: {samples_counts[1]}")
print(f"Number of occurrences of 2 in samples: {samples_counts[2]}")

Number of occurrences of 0 in samples: 31
Number of occurrences of 1 in samples: 32
Number of occurrences of 2 in samples: 37


The counts are roughly equal (31, 32 and 37 out of 100), which is expected: the three logits are identical, so each class has probability 1/3. In general, how often a class is drawn is governed by its probability after a softmax over the logits.

Recall that logits are *"unnormalized log-probabilities for all classes"*.

To summarise:
- `tf.random.categorical` draws `num_samples` class indices from each row of unnormalized log-probabilities. It does not always pick the class with the highest probability; instead, each class's probability is its likelihood of being drawn.
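
To see that frequency behaviour with unequal logits, here is a plain-Python sketch of categorical sampling. `sample_categorical` is a hypothetical stand-in for `tf.random.categorical`, built on `random.choices`; since `random.choices` accepts unnormalized weights, exponentiating the logits is enough:

```python
import math
import random
from collections import Counter

def sample_categorical(logits, num_samples, rng=random):
    """Draw class indices with probability proportional to exp(logit)."""
    m = max(logits)
    weights = [math.exp(z - m) for z in logits]  # unnormalized; random.choices normalizes
    return rng.choices(range(len(logits)), weights=weights, k=num_samples)

rng = random.Random(0)                           # fixed seed for reproducibility
logits = [math.log(p) for p in (0.7, 0.2, 0.1)]
counts = Counter(sample_categorical(logits, 10_000, rng))
print(sorted(counts.items()))  # roughly 7000, 2000 and 1000 draws for classes 0, 1, 2
```

The most likely class wins most often, but the other classes are still drawn in proportion to their probabilities.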


In [45]:
print(batch_label[0])

tf.Tensor(
[59 12  2 40 53 43  2 44 61 44 57 64  2 54 53 44  2 43 48 43  2 59 47 57
 44 40 59  1 33 54  8 52 54 57 57 54 62  6 58  2 61 44 53 46 44 40 53 42
 44  2 54 53  2 59 47 44  2 47 44 40 43  2 54 45  2 31 48 42 47 40 57 43
  9  1  1 31 14 33 16 25 22 19 19 11  1 26 64  2 51 54 57 43  3  1  1 24
 22 27 20  2], shape=(100,), dtype=int64)


In [46]:
# define a loss function for the model
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

# get the loss for the first batch in the dataset
example_batch_mean_loss = loss(batch_label, example_batch_predictions)
print(f"Prediction shape: {example_batch_predictions.shape}")
print(f"Mean loss: {example_batch_mean_loss}")

Prediction shape: (64, 100, 66)
Mean loss: 4.187921524047852


In [47]:
tf.exp(example_batch_mean_loss).numpy()

65.88571

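A quick sanity check:
- *The exponential of the mean loss should be approximately equal to the vocabulary size* for a freshly initialized model. Before training, the logits from the dense layer are all close to zero, so the softmax assigns roughly 1/66 to every class, and the cross-entropy is about -log(1/66) = log(66) ≈ 4.19, which matches the mean loss above.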
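
The same check in plain Python. `sparse_cross_entropy` is a hypothetical helper that computes sparse categorical cross-entropy from logits for a single position; with all-zero logits (a uniform prediction over 66 classes) the loss is exactly log(66), whatever the label:

```python
import math

def sparse_cross_entropy(logits, label):
    """-log(softmax(logits)[label]), computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

vocab_size = 66
loss = sparse_cross_entropy([0.0] * vocab_size, label=19)
print(f"{loss:.4f}")            # 4.1897
print(f"{math.exp(loss):.1f}")  # 66.0
```

The measured mean loss of 4.1879 is close to this because the untrained logits are small but not exactly zero.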

## **Train the model**

In [48]:
# define the loss and optimizer for the model
model.compile(optimizer='adam', loss=loss)


**Define model callbacks**

ModelCheckpoint
- Saves the model or its weights at a defined frequency (by default at the end of every epoch), so snapshots are kept from different points during training.

In [None]:
# define a directory to store checkpoints of the model during training
checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# define the model checkpoint
# Saves the model weight at the end of each epoch to the training_checkpoints dir
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                                         save_weights_only=True)


In [None]:
# Define the training epochs 
EPOCHS = 20

In [None]:
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

## **Generating Text**

To generate text, we provide a seed string and run inference to predict the next character. Feeding each prediction (together with the GRU state) back into the model repeatedly produces longer pieces of text.
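
The loop structure can be sketched without TensorFlow. Everything below is hypothetical: a toy bigram "model" replaces the trained GRU, it ignores the state it is handed (in the notebook that is the GRU state), and it always picks the single most likely next char (greedy decoding) rather than sampling. The output shows how greedy choice settles into a repeating cycle:

```python
from collections import Counter

def make_greedy_bigram_step(corpus):
    """Hypothetical toy stand-in for OneStep: predicts the char that most often
    followed the last input char in `corpus`; the state is passed through unchanged."""
    followers = {}
    for prev, nxt in zip(corpus, corpus[1:]):
        followers.setdefault(prev, Counter())[nxt] += 1

    def step(inputs, state):
        next_char = followers[inputs[-1]].most_common(1)[0][0]
        return next_char, state

    return step

def generate(step_fn, seed, num_chars):
    """Autoregressive loop: feed each prediction (and the state) back in."""
    state = None
    next_input = seed
    result = [seed]
    for _ in range(num_chars):
        next_input, state = step_fn(next_input, state)
        result.append(next_input)
    return "".join(result)

print(generate(make_greedy_bigram_step("abcabcabd"), seed="a", num_chars=6))  # abcabca
```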

In [None]:
class OneStep(tf.keras.Model):
  def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
    super().__init__()
    self.temperature = temperature
    self.model = model
    self.chars_from_ids = chars_from_ids
    self.ids_from_chars = ids_from_chars
  
    # Create a mask to prevent "[UNK]" from being generated.
    skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
    sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index
        values=[-float('inf')]*len(skip_ids),
        indices=skip_ids,
        # Match the shape of the vocabulary
        dense_shape=[len(ids_from_chars.get_vocabulary())])
    self.prediction_mask = tf.sparse.to_dense(sparse_mask)
  

  @tf.function
  def generate_one_step(self, inputs, states=None):
    # Convert strings to token IDs
    input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
    input_ids = self.ids_from_chars(input_chars).to_tensor()

    # Run the model.
    # predicted_logits.shape is [batch, char, next_char_logits]
    predicted_logits, states = self.model(inputs=input_ids, states=states, return_state=True)

    # Only use the last prediction.
    predicted_logits = predicted_logits[:, -1, :]
    predicted_logits = predicted_logits/self.temperature
    # Apply the prediction mask: prevent "[UNK]" from being generated.
    predicted_logits = predicted_logits + self.prediction_mask

    # Sample the output logits to generate token IDs.
    predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
    predicted_ids = tf.squeeze(predicted_ids, axis=-1)

    # Convert from token ids to characters
    predicted_chars = self.chars_from_ids(predicted_ids)
    
    # Return the characters and model state.
    return predicted_chars, states
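
How `temperature` and `prediction_mask` reshape the logits can be sketched in plain Python. The helper names are hypothetical; `OneStep` does the same arithmetic on tensors before calling `tf.random.categorical`. Lower temperatures sharpen the distribution toward the top char, higher ones flatten it, and a `-inf` entry gives that id zero probability:

```python
import math

def softmax(logits):
    m = max(logits)                           # assumes at least one finite logit
    exps = [math.exp(z - m) for z in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def next_char_probs(logits, temperature=1.0, mask=None):
    """Mirror of the OneStep post-processing: divide by temperature, then add the mask."""
    scaled = [z / temperature for z in logits]
    if mask is not None:
        scaled = [z + m for z, m in zip(scaled, mask)]
    return softmax(scaled)

logits = [0.0, 2.0, 1.0]       # id 0 plays the role of "[UNK]"
mask = [-math.inf, 0.0, 0.0]   # like prediction_mask: -inf at the [UNK] id

for t in (0.5, 1.0, 2.0):
    probs = next_char_probs(logits, temperature=t, mask=mask)
    print(t, [round(p, 3) for p in probs])
```

With these logits, the probability of the top char is about 0.88 at temperature 0.5, 0.73 at 1.0 and 0.62 at 2.0, and the masked id always gets 0.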

In [None]:
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)

In [None]:
start = time.time()
states = None
next_char = tf.constant(['Post MALONE:'])
result = [next_char]

for n in range(1000):
  next_char, states = one_step_model.generate_one_step(next_char, states=states)
  result.append(next_char)


result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n', '_'*80)
print('\nRun time:', end - start)

## **Save the text generator**

In [None]:
# Save the model
tf.saved_model.save(one_step_model, 'one_step')


In [None]:
# load the model
one_step_model_reloaded = tf.saved_model.load('one_step')

## **Advanced: Customized Training**

In [None]:
class CustomTraining(MyModel):
  @tf.function
  def train_step(self, inputs):
    inputs, labels = inputs
    with tf.GradientTape() as tape:
      predictions = self(inputs, training=True)
      loss = self.loss(labels, predictions)
    grads = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

    return {'loss': loss}
  


In [None]:
model = CustomTraining(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)

In [None]:
model.compile(optimizer = tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

In [None]:
model.fit(dataset, epochs=1)

In [None]:
EPOCHS = 10

mean = tf.metrics.Mean()

for epoch in range(EPOCHS):
  start = time.time()

  mean.reset_states()
  for (batch_n, (inp, target)) in enumerate(dataset):
    logs = model.train_step([inp, target])
    mean.update_state(logs['loss'])

    if batch_n % 50 == 0:
      template = f"Epoch {epoch+1} Batch {batch_n} Loss {logs['loss']:.4f}"
      print(template)

  # save (checkpoint) the model weights every 5 epochs
  if (epoch + 1) % 5 == 0:
    model.save_weights(checkpoint_prefix.format(epoch=epoch))

  print()
  print(f"Epoch {epoch+1}, loss: {mean.result().numpy():.4f}")
  print(f"Time taken for 1 epoch {time.time() - start:.2f} sec")
  print("_"*80)

## **Define and train a model**