# Generating text inspired from Blizzard's Warcraft Franchise

In [28]:
import numpy as np
import pandas as pd
import tensorflow as tf
import os
import re

text = open('WC3.txt', 'rb').read().decode(encoding='utf-8')
print(f"Length of text: {len(text)} characters")

#replace all non alphabetic characters with space
non_alphabetic = [',', '.', '!', '?', '-', ':', ';', '(', ')', '[', ']', '{', '}', '\n', '\r', '\t', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '\x19']
for i in non_alphabetic:
    text = text.replace(i, ' ')
    
def remove_nonprintable_characters(text):
    # Define the regular expression pattern to match non-printable characters
    pattern = r'[^\x00-\x7F]|\x1B|\x00-\x08|\x0B\x0C|\x0E-\x1F|\x7F|\x80-\xFF'

    # Use re.sub to replace the matched pattern with an empty string
    cleaned_text = re.sub(pattern, '', text)

    return cleaned_text

text = remove_nonprintable_characters(text)

text = text.lower()

Length of text: 747168 characters


In [29]:
# Inspect how many unique characters appear throughout the supplied script/text data
unique_chars = sorted(set(text))
print(f"{len(unique_chars)} unique characters")
print(unique_chars)

31 unique characters
[' ', '"', '&', "'", '/', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [30]:
# Test Characters to demonstrate ids_from_chars
test_chars = ['arthas', 'illidan']
test_chars = tf.strings.unicode_split(test_chars, input_encoding='UTF-8')

# Create a mapping from unique characters to indices
ids_from_chars = tf.keras.layers.StringLookup(vocabulary=list(unique_chars), mask_token=None)

# Check that the ids were assigned correctly
ids = ids_from_chars(test_chars)
ids


<tf.RaggedTensor [[6, 23, 25, 13, 6, 24], [14, 17, 17, 14, 9, 6, 19]]>

In [31]:
# Create a mapping from indices to characters
chars_from_ids = tf.keras.layers.StringLookup(vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)

# Test on encoded examples
test_chars = chars_from_ids(ids)
test_chars

<tf.RaggedTensor [[b'a', b'r', b't', b'h', b'a', b's'],
 [b'i', b'l', b'l', b'i', b'd', b'a', b'n']]>

In [32]:
# Function to convert ids back to human readable text
def text_from_ids(ids):
    return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)

In [33]:
#  Create a dataset of the encoded text
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids

<tf.Tensor: shape=(740775,), dtype=int64, numpy=array([14, 19, 25, ..., 12,  1,  1], dtype=int64)>

In [34]:
# Covert the text vector into a stream of character indices
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
ids_dataset

# Denote the sequence length for each input sequence
seq_length = 50
examples_per_epoch = len(text) // (seq_length + 1)

# Use batch method to convert the individual characters into sequences of the desired size
sequences = ids_dataset.batch(seq_length + 1, drop_remainder=True)

# Test the batch method
for seq in sequences.take(2):
    print(chars_from_ids(seq))
    
# Covert back to human readable text
for seqq in sequences.take(2):
    print(text_from_ids(seqq).numpy())


tf.Tensor(
[b'i' b'n' b't' b'r' b'o' b'd' b'u' b'c' b't' b'i' b'o' b'n' b' ' b'm'
 b'o' b'v' b'i' b'e' b' ' b' ' b' ' b' ' b'n' b'a' b'r' b'r' b'a' b't'
 b'o' b'r' b' ' b' ' b't' b'h' b'e' b' ' b's' b'a' b'n' b'd' b's' b' '
 b'o' b'f' b' ' b't' b'i' b'm' b'e' b' ' b'h'], shape=(51,), dtype=string)
tf.Tensor(
[b'a' b'v' b'e' b' ' b'r' b'u' b'n' b' ' b'o' b'u' b't' b' ' b' ' b's'
 b'o' b'n' b' ' b'o' b'f' b' ' b'd' b'u' b'r' b'o' b't' b'a' b'n' b' '
 b' ' b' ' b'c' b'r' b'i' b'e' b's' b' ' b'o' b'f' b' ' b'w' b'a' b'r'
 b' ' b' ' b'e' b'c' b'h' b'o' b' ' b' ' b' '], shape=(51,), dtype=string)
b'introduction movie    narrator  the sands of time h'
b'ave run out  son of durotan   cries of war  echo   '


In [35]:
# Split the sequences into input and target offsetting by one character
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

# Test the split_input_target function
split_input_target(list("Arthas my boy"))

(['A', 'r', 't', 'h', 'a', 's', ' ', 'm', 'y', ' ', 'b', 'o'],
 ['r', 't', 'h', 'a', 's', ' ', 'm', 'y', ' ', 'b', 'o', 'y'])

In [36]:
# Preprocess the text data to generate input and target text strings
dataset = sequences.map(split_input_target)

# Text the dataset
for x,y in dataset.take(1):
    print("Input: ", text_from_ids(x).numpy())
    print("Target: ", text_from_ids(y).numpy())

Input:  b'introduction movie    narrator  the sands of time '
Target:  b'ntroduction movie    narrator  the sands of time h'


In [37]:
# Split the data into managable sequences, assigning batch size and shuffling the data.
BATCH_SIZE = 64
BUFFER_SIZE = 100000

dataset = (
    dataset.shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

dataset

<PrefetchDataset element_spec=(TensorSpec(shape=(64, 50), dtype=tf.int64, name=None), TensorSpec(shape=(64, 50), dtype=tf.int64, name=None))>

In [38]:
# Assign initial parameters
# Length of the vocabulary in chars
vocab_size = len(unique_chars)
print(vocab_size)

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 32

31


In [39]:
class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values, mask=None):
        # query hidden state shape == (batch_size, hidden_size)
        # query_with_time_axis shape == (batch_size, 1, hidden_size)
        # values shape == (batch_size, max_len, hidden_size)
        # we are doing this to broadcast addition along the time axis to calculate the score
        print("Query shape:", query.shape)
        print("Values shape:", values.shape)
        query_with_time_axis = tf.expand_dims(query, axis=1)

        # score shape == (batch_size, max_length, 1)
        # we get 1 at the last axis because we are applying score to self.V
        # the shape of the tensor before applying self.V is (batch_size, max_length, units)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))

        # Attention mask to handle variable-length sequences
        if mask is not None:
            score = score + (1.0 - mask) * -1e9  # Add a large negative value to masked positions

        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        print(context_vector)
        return context_vector, attention_weights

In [40]:
# Build the model using a tf.keras.Model class
class MyModel(tf.keras.Model):
    def __init__(self, vocab_size,embedding_dim, rnn_units):
        super(MyModel,self).__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRUCell(rnn_units)
        # Add attention layer
        self.attention = BahdanauAttention(rnn_units)
        self.dense = tf.keras.layers.Dense(vocab_size)
        
    def call(self,inputs,states=None, return_state=False, training=False, mask=None):
        x = self.embedding(inputs, training=training)
        # If no previous state, initialise the state
        gru_layer = tf.keras.layers.RNN(self.gru, return_sequences=True, return_state=True)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = gru_layer(x, initial_state=states, training=training)
        print (x)
        query = states[0]
        query = tf.tile(tf.expand_dims(states[0], axis=0), [tf.shape(x)[0], 1])
        
        # Update to include the output of the attention layer
        context_vector, attention_weights = self.attention(query, x, mask=mask)
        context_vector_expanded = tf.expand_dims(context_vector, axis=1)
        context_vector_tiled = tf.tile(context_vector_expanded, [1, tf.shape(x)[1], 1])
        x = tf.concat([context_vector_tiled, x], axis=-1)
        
        x = self.dense(x,training=training)
        
        
        if return_state:
            return x, states
        else:
            return x

In [41]:
model = MyModel(
    # Assure the vocabulary size matches the StringLookup layers
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
)

# model.build(input_shape=(BATCH_SIZE, seq_length))
# model.summary()

In [42]:
for input_example_batch, target_example_batch in dataset.take(1):
    print("Input shape:", input_example_batch.shape)
    print(input_example_batch)
    print("Target shape:", target_example_batch.shape)
    print(target_example_batch)
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

Input shape: (64, 50)
tf.Tensor(
[[25 13 10 ... 20 20  9]
 [18 10  1 ... 10  6  9]
 [29 21 10 ... 20 26  1]
 ...
 [ 9  1 24 ... 10  1 20]
 [17 14 11 ... 10  6 23]
 [ 1 23 14 ...  9  1 26]], shape=(64, 50), dtype=int64)
Target shape: (64, 50)
tf.Tensor(
[[13 10 30 ... 20  9 30]
 [10  1 24 ...  6  9  1]
 [21 10  8 ... 26  1  8]
 ...
 [ 1 24 25 ...  1 20 11]
 [14 11 25 ...  6 23  1]
 [23 14 27 ...  1 26 19]], shape=(64, 50), dtype=int64)
tf.Tensor(
[[[-0.00689923  0.01031401  0.00427329 ...  0.00740257 -0.00573162
    0.02067883]
  [ 0.00332355 -0.006019    0.02316161 ...  0.00618907  0.01129002
    0.03404337]
  [-0.0136784  -0.02238584  0.00464262 ...  0.02635776 -0.02813916
    0.01868805]
  ...
  [-0.01756865  0.00298588 -0.00360701 ...  0.00843727  0.00567691
    0.00899155]
  [-0.02133823  0.00315036  0.00450381 ...  0.01464416 -0.00212759
    0.02327525]
  [-0.00779084  0.00125413  0.02001133 ... -0.01762514  0.0154162
    0.00989593]]

 [[-0.03677533  0.00213846  0.00037428 ... -0

In [43]:
sampled_indices = tf.random.categorical(
    example_batch_predictions[0], num_samples=1
)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

sampled_indices

print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())

Input:
 b'they would be wrong  i may take part in your blood'

Next Char Predictions:
 b'qomagjubzsxe&fwxmvys&s/nuu&ujajlmy[UNK]yigvjgnbth&wiyc'


In [44]:
# Model Summary to check the model architecture
model.build(input_shape=(BATCH_SIZE,seq_length))
model.summary()

Tensor("rnn/transpose_1:0", shape=(64, 50, 32), dtype=float32)
Query shape: (64, 32)
Values shape: (64, 50, 32)
Tensor("bahdanau_attention_1/Sum:0", shape=(64, 32), dtype=float32)
Model: "my_model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_1 (Embedding)     multiple                  8192      
                                                                 
 gru_cell_1 (GRUCell)        multiple                  27840     
                                                                 
 bahdanau_attention_1 (Bahda  multiple                 2145      
 nauAttention)                                                   
                                                                 
 dense_7 (Dense)             multiple                  2080      
                                                                 
Total params: 40,257
Trainable params: 40,257
Non-trainable params: 0
____

In [45]:
# Assign a loss function to the model
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

In [46]:
# # Compile the model with the loss function
# optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# model.compile(optimizer=optimizer, loss=loss)

In [47]:
# Create a directory to save the model checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Only save every 10th epoch
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
    period=10
)

# Define the early stopping callback
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='loss',
    patience=3,
    restore_best_weights=True,
    min_delta = 0.001
)




In [48]:
# Set the epochs
EPOCHS = 15

# Define the loss function
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Define the optimizer
optimizer = tf.keras.optimizers.Adam()

# Define metrics to track during training
train_loss = tf.keras.metrics.Mean(name='train_loss')

@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        # Forward pass
        predictions= model(inputs, training=True)
        # Compute loss
        loss = loss_object(targets, predictions)
    
    # Compute gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    # Update weights
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    # Update metrics
    train_loss(loss)

# Training loop
for epoch in range(EPOCHS):
    # Reset metrics at the beginning of each epoch
    train_loss.reset_states()
    
    # Iterate over the dataset
    for inputs, targets in dataset:
        print(inputs)
        # Perform a training step
        train_step(inputs, targets)
    
    # Print the training loss for each epoch
    print(f'Epoch {epoch+1}, Loss: {train_loss.result()}')
    
    # Save the model checkpoints
    if (epoch + 1) % 10 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch+1))
    

tf.Tensor(
[[ 9  1 25 ...  9  1 24]
 [12 24  1 ...  8 10  1]
 [11  1 25 ... 11 23 20]
 ...
 [13 10  1 ...  1 21 17]
 [14 19  1 ...  1 13 26]
 [ 6 18  7 ...  1  1  1]], shape=(64, 50), dtype=int64)
Tensor("my_model_1/rnn/transpose_1:0", shape=(64, 50, 32), dtype=float32)
Query shape: (64, 32)
Values shape: (64, 50, 32)
Tensor("my_model_1/bahdanau_attention_1/Sum:0", shape=(64, 32), dtype=float32)
Tensor("my_model_1/rnn/transpose_1:0", shape=(64, 50, 32), dtype=float32)
Query shape: (64, 32)
Values shape: (64, 50, 32)
Tensor("my_model_1/bahdanau_attention_1/Sum:0", shape=(64, 32), dtype=float32)
tf.Tensor(
[[ 7 10 13 ... 14 18  1]
 [ 1 20 11 ... 23  8 13]
 [ 1  1 25 ... 26 18 18]
 ...
 [10  1 23 ... 20 26 23]
 [ 1 17 20 ... 19  1 24]
 [ 1 30 20 ...  1 20 11]], shape=(64, 50), dtype=int64)
tf.Tensor(
[[ 1 26 25 ... 19 24 26]
 [26  1 24 ... 30 20 26]
 [24 25 20 ... 10 23 24]
 ...
 [ 6 23 10 ... 23 30  1]
 [ 1  1  1 ... 10 30  4]
 [14 25 13 ... 25 13 10]], shape=(64, 50), dtype=int64)
tf.Te

In [49]:
for input_example_batch, target_example_batch in dataset.take(1):
    print("Input shape:", input_example_batch.shape)
    print(input_example_batch)
    print("Target shape:", target_example_batch.shape)
    print(target_example_batch)
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
    
sampled_indices = tf.random.categorical(
    example_batch_predictions[0], num_samples=1
)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()

sampled_indices

print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())

Input shape: (64, 50)
tf.Tensor(
[[26  8 13 ... 14 18  1]
 [13  6 24 ...  6 25  1]
 [23 14 19 ... 10  6 23]
 ...
 [14 25 10 ...  1 17 10]
 [23 24  1 ...  6  7 24]
 [13  6 24 ... 10 20 21]], shape=(64, 50), dtype=int64)
Target shape: (64, 50)
tf.Tensor(
[[ 8 13 10 ... 18  1 20]
 [ 6 24 24 ... 25  1 25]
 [14 19 12 ...  6 23  1]
 ...
 [25 10 18 ... 17 10 11]
 [24  1 25 ...  7 24 20]
 [ 6 24  1 ... 20 21 17]], shape=(64, 50), dtype=int64)
tf.Tensor(
[[[ 0.26245147  0.9491979  -0.9966795  ... -0.3863609  -0.94914407
   -0.48145986]
  [ 0.36254     0.9682671  -0.26519695 ... -0.21509308  0.8287088
    0.6968209 ]
  [ 0.9082583   0.97162676  0.1909474  ... -0.8828169   0.9174859
    0.76811147]
  ...
  [-0.66190404  0.5043178  -0.9980331  ... -0.47394598 -0.9766598
    0.4960285 ]
  [-0.5701056   0.85145617 -0.84447265 ... -0.29164857 -0.9792345
    0.75737584]
  [-0.9801542  -0.43469983 -0.99961305 ... -0.9881368   0.9361885
   -0.99962085]]

 [[ 0.9796486   0.05897833  0.01573379 ... -0.990

In [50]:
class OneStep(tf.keras.Model):
    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0, attention=None):
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars
        self.attention = attention

        # Create a mask to prevent "[UNK]" from being generated.
        skip_ids = self.ids_from_chars(["[UNK]"])[:, None]
        sparse_mask = tf.SparseTensor(
            values=[-float("inf")] * len(skip_ids),
            indices=skip_ids,
            dense_shape=[len(ids_from_chars.get_vocabulary())]
        )
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)

    def generate_one_step(self, inputs, states=None):
        input_chars = tf.strings.unicode_split(inputs, "UTF-8")
        input_ids = self.ids_from_chars(input_chars).to_tensor()

        if states is None:
            batch_size = tf.shape(input_ids)[0]
            initial_state = self.model.gru.get_initial_state(batch_size=batch_size, dtype=tf.float32)
            states = [tf.expand_dims(state, axis=0) for state in initial_state]

        predicted_logits, states = self.model(inputs=input_ids, states=states, return_state=True)
        predicted_logits = predicted_logits + self.prediction_mask

        # Reshape the logits to a 2D matrix for sampling
        logits_shape = tf.shape(predicted_logits)
        reshaped_logits = tf.reshape(predicted_logits, [-1, logits_shape[-1]])

        # Sample the output logits to generate token IDs.
        predicted_ids = tf.random.categorical(reshaped_logits / self.temperature, num_samples=1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)

        # Convert from token ids to characters
        predicted_chars = self.chars_from_ids(predicted_ids)

        # Convert the predicted_chars tensor to a string
        generated_text = tf.strings.reduce_join(predicted_chars, axis=-1).numpy().decode("utf-8")
        return generated_text, states

In [51]:
# Initialize the simplified OneStep model
one_step_model = OneStep(model, chars_from_ids, ids_from_chars, attention=model.attention)

# Test function to generate 100 characters on top of an input sequence
def generate_text(input_sequence, num_extra_characters=2):
    input_sequence = tf.constant([input_sequence])
    states = None

    # Generate one character at a time and append it to the input sequence
    for _ in range(num_extra_characters):
        next_char, states = one_step_model.generate_one_step(input_sequence, states=states)
        input_sequence = tf.strings.join([input_sequence, next_char])

    generated_text = input_sequence.numpy()[0].decode("utf-8")
    return generated_text

# Test with different input sequences
input_sequence1 = "the quick brown"


generated_text1 = generate_text(input_sequence1)

print("Input Sequence 1:", input_sequence1)
print("Generated Text 1:", generated_text1)


tf.Tensor(
[[[ 1.99231848e-01  9.17238593e-01 -8.83358300e-01 -3.84582616e-02
    3.19592297e-01  9.95657265e-01  7.43118823e-01 -9.85413492e-01
   -9.81087506e-01  1.21921003e-01  1.44248575e-01 -1.38114557e-01
   -9.90934968e-01  9.45132375e-01 -6.13461912e-01  9.87355351e-01
    9.27164108e-02 -1.43218011e-01 -1.72380842e-02 -7.16544747e-01
    9.97774422e-01 -7.16303229e-01 -9.64756191e-01 -7.75756082e-03
   -4.85464297e-02 -9.83053327e-01 -4.66056019e-01 -1.62890762e-01
    8.38764906e-01  6.46624267e-01  9.46289003e-01  9.27484035e-01]
  [ 9.69275296e-01  9.24614072e-01  3.61047864e-01 -9.18547630e-01
   -4.57058065e-02  9.80622411e-01  9.82977092e-01  4.28869814e-01
    9.16078031e-01  3.74205500e-01 -9.57497180e-01  9.99643505e-01
   -9.88617897e-01  9.95509863e-01 -7.42879152e-01  6.48110271e-01
    9.33019042e-01 -9.67133284e-01 -1.28135800e-01 -1.62508041e-01
   -4.32070106e-01 -7.89727986e-01 -9.91225064e-01  8.44996691e-01
    1.28087392e-02 -8.54940057e-01  5.77731848e-01

In [52]:
class OneStep(tf.keras.Model):
    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0, attention=None):
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars
        self.attention = attention

        # Create a mask to prevent "[UNK]" from being generated.
        skip_ids = self.ids_from_chars(["[UNK]"])[:, None]
        sparse_mask = tf.SparseTensor(
        # Put a -inf at each bad index.
            values=[-float("inf")] * len(skip_ids),
            indices=skip_ids,
            # Match the shape to the vocabulary
            dense_shape = [len(ids_from_chars.get_vocabulary())]
        )
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)
    
    def generate_one_step(self, inputs, states=None, dtype=tf.float32, batch_size=None):
            # Convert strings to token IDs.
            input_chars = tf.strings.unicode_split(inputs, "UTF-8")
            input_ids = self.ids_from_chars(input_chars).to_tensor()
            print(input_ids)
            # Use the states and predicted_logits returned from the model call
            if states is None:
            # Create a custom function to get the initial state with the desired batch size
                if batch_size is None:
                    batch_size = tf.shape(inputs)[0]  # Use the batch size of the input if not provided
            initial_state = self.model.gru.get_initial_state(batch_size=batch_size, dtype=dtype)
            states = [tf.expand_dims(state, axis=0) for state in initial_state]
        

            # Run the model and get the predicted logits and new states.
            predicted_logits, states = self.model(inputs=input_ids, states=states, return_state=True)

            # Apply the prediction mask: prevent "[UNK]" from being generated.
            predicted_logits = predicted_logits + self.prediction_mask

            # Compute context_vector using the attention mechanism
            query = states[-1]
            query = tf.tile(tf.expand_dims(query, axis=0), [tf.shape(predicted_logits)[0], 1])
            context_vector, _ = self.attention(query, predicted_logits)  # Pass 'predicted_logits' instead of 'x'

            # Expand dims to have shape (batch_size, 1, rnn_units) for concatenation
            context_vector_expanded = tf.expand_dims(context_vector, axis=1)
            
            # Tile the context_vector_expanded to match the sequence length of predicted_logits
            context_vector_tiled = tf.tile(context_vector_expanded, [1, tf.shape(predicted_logits)[1], 1])
            combined_logits = tf.concat([context_vector_tiled, predicted_logits], axis=-1)

            # Reduce the last axis using the specified temperature
            combined_logits = combined_logits / self.temperature

            # Sample the output logits to generate token IDs.
            batch_size, sequence_length, vocab_size = tf.shape(combined_logits)[0], tf.shape(combined_logits)[1], tf.shape(combined_logits)[2]
            reshaped_logits = tf.reshape(combined_logits, [batch_size * sequence_length, vocab_size])
            predicted_ids = tf.random.categorical(reshaped_logits, num_samples=1)
            predicted_ids = tf.squeeze(predicted_ids, axis=-1)


            # Convert from token ids to characters
            predicted_chars = self.chars_from_ids(predicted_ids)
            
            # Return the characters, model state, and attention weights.
            return predicted_chars, states



In [53]:
one_step_model = OneStep(model, chars_from_ids, ids_from_chars, temperature=1.0, attention=model.attention)

In [54]:
import time
# Generate text using a constant prompt
start = time.time()
states = None
next_char = tf.constant(["a"])
result = [next_char]
attention_weights_list = []

for n in range(1000):
    next_char, states = one_step_model.generate_one_step(
        next_char, states=states, dtype=tf.float32, batch_size=tf.shape(next_char)[0]
    )
    result.append(next_char)
    
    
print(result)
result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode("utf-8"), "\n\n" + "_" * 80)
print("\nRun time:", end - start)


tf.Tensor([[6]], shape=(1, 1), dtype=int64)
tf.Tensor(
[[[-0.00118333  0.06714608  0.9966999  -0.03799908 -0.19284694
   -0.8888666  -0.9368937  -0.9226308  -0.9898257   0.26582134
   -0.24350671  0.0316296   0.20497872 -0.9760124   0.28674403
   -0.83089626 -0.988633    0.9355321   0.96661365 -0.7568737
   -0.01206933  0.3679451   0.09348057 -0.99872285  0.8425257
    0.9971923  -0.91225225  0.5447845   0.9793912  -0.41879404
   -0.44880858 -0.99861175]]], shape=(1, 1, 32), dtype=float32)
Query shape: (1, 32)
Values shape: (1, 1, 32)
tf.Tensor(
[[-0.00118333  0.06714608  0.9966999  -0.03799908 -0.19284694 -0.8888666
  -0.9368937  -0.9226308  -0.9898257   0.26582134 -0.24350671  0.0316296
   0.20497872 -0.9760124   0.28674403 -0.83089626 -0.988633    0.9355321
   0.96661365 -0.7568737  -0.01206933  0.3679451   0.09348057 -0.99872285
   0.8425257   0.9971923  -0.91225225  0.5447845   0.9793912  -0.41879404
  -0.44880858 -0.99861175]], shape=(1, 32), dtype=float32)
Query shape: (1, 32)
V

ValueError: Exception encountered when calling layer "my_model_1" "                 f"(type MyModel).

An `initial_state` was passed that is not compatible with `cell.state_size`. Received `state_spec`=ListWrapper([InputSpec(shape=(1, 32), ndim=2), InputSpec(shape=(1, 32), ndim=2), InputSpec(shape=(1, 32), ndim=2), InputSpec(shape=(1, 32), ndim=2), InputSpec(shape=(1, 32), ndim=2)]); however `cell.state_size` is [32]

Call arguments received by layer "my_model_1" "                 f"(type MyModel):
  • inputs=tf.Tensor(shape=(5, 5), dtype=int64)
  • states=['tf.Tensor(shape=(1, 32), dtype=float32)', 'tf.Tensor(shape=(1, 32), dtype=float32)', 'tf.Tensor(shape=(1, 32), dtype=float32)', 'tf.Tensor(shape=(1, 32), dtype=float32)', 'tf.Tensor(shape=(1, 32), dtype=float32)']
  • return_state=True
  • training=False
  • mask=None