In [1]:
import numpy as np
import json
import re
import string

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, losses

## 0. Parameters <a name="parameters"></a>

In [2]:
VOCAB_SIZE = 10000
MAX_LEN = 200
EMBEDDING_DIM = 100
N_UNITS = 128
VALIDATION_SPLIT = 0.2
SEED = 42
LOAD_MODEL = False
BATCH_SIZE = 32
EPOCHS = 20

## 1. Load the data <a name="load"></a>

In [3]:
import requests

# List of URLs for additional texts
urls = [
    "https://www.gutenberg.org/files/1041/1041-0.txt",        # Hamlet
    "https://www.gutenberg.org/cache/epub/1517/pg1517.txt",   # Merry Wives
    "https://www.gutenberg.org/cache/epub/1533/pg1533.txt"    # Macbeth
]

# Initialize an empty string to hold all text
all_text = ""

# Download each text file and append to all_text
for url in urls:
    response = requests.get(url)
    if response.status_code == 200:
        text = response.text
        all_text += text + "\n\n"  # Separate texts by newlines
    else:
        print(f"Failed to retrieve {url} with status code {response.status_code}")

# Save combined text to a single file
with open("combined_shakespeare.txt", "w", encoding="utf-8") as file:
    file.write(all_text)


# Count the occurrences of "ACT" and "SCENE" as rough indicators.
n_acts = all_text.lower().count("act")
n_scenes = all_text.lower().count("scene")

print(f"Number of Acts: {n_acts}")
print(f"Number of Scenes: {n_scenes}")

Number of Acts: 82
Number of Scenes: 105


In [4]:
# Count the piece from Shakespear it loads
n_response = len(all_text)
print (f"{n_response} Stories loaded")

382373 Stories loaded


## 2. Tokenise the data

In [5]:
# Pad the punctuation, to treat them as separate 'words'
def pad_punctuation(s):
    s = re.sub(f"([{string.punctuation}])", r" \1 ", s)
    s = re.sub(" +", " ", s)
    return s

text_data = [pad_punctuation(x) for x in all_text]

In [6]:
# Display an example by Shakespearain
example_data = text_data[10]
example_data

'O'

In [7]:
# Convert to a Tensorflow Dataset
text_ds = (
    tf.data.Dataset.from_tensor_slices(text_data)
    .batch(BATCH_SIZE)
    .shuffle(1000)
)

In [8]:
# Create a vectorisation layer
vectorize_layer = layers.TextVectorization(
    standardize="lower",
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=MAX_LEN + 1,
)

In [9]:
# Adapt the layer to the training set
vectorize_layer.adapt(text_ds)
vocab = vectorize_layer.get_vocabulary()

In [10]:
# Display some token:word mappings
for i, word in enumerate(vocab[:10]):
    print(f"{i}: {word}")

0: 
1: [UNK]
2: e
3: t
4: o
5: a
6: s
7: i
8: n
9: h


In [11]:
# Display the same example converted to ints
example_data_tensor = tf.constant([example_data])
example_tokenised = vectorize_layer(example_data)
print(example_tokenised.numpy())

[4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


## 3. Create the Training Set

In [12]:
# Create the training set of recipes and the same text shifted by one word
def prepare_inputs(text):
    text = tf.expand_dims(text, -1)
    tokenized_sentences = vectorize_layer(text)
    x = tokenized_sentences[:, :-1]
    y = tokenized_sentences[:, 1:]
    return x, y


train_ds = text_ds.map(prepare_inputs)

## 4. Build the LSTM <a name="build"></a>

In [13]:
inputs = layers.Input(shape=(None,), dtype="int32")
x = layers.Embedding(VOCAB_SIZE, EMBEDDING_DIM)(inputs)
x = layers.LSTM(N_UNITS, return_sequences=True)(x)
outputs = layers.Dense(VOCAB_SIZE, activation="softmax")(x)
lstm = models.Model(inputs, outputs)
lstm.summary()

## 5. Train the LSTM <a name="train"></a>

In [14]:
loss_fn = losses.SparseCategoricalCrossentropy()
lstm.compile("adam", loss_fn)

In [35]:
# Create a TextGenerator checkpoint
class TextGenerator(callbacks.Callback):
    def __init__(self, index_to_word, top_k=10):
        self.index_to_word = index_to_word
        self.word_to_index = {
            word: index for index, word in enumerate(index_to_word)
        }  # <1>

    def sample_from(self, probs, temperature):  # <2>
        probs = probs ** (1 / temperature)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs), probs

    def generate(self, start_prompt, max_tokens, temperature):
        start_tokens = [
            self.word_to_index.get(x, 1) for x in start_prompt.split()
        ]  # <3>
        sample_token = None
        info = []
        while len(start_tokens) < max_tokens and sample_token != 0:  # <4>
            x = np.array([start_tokens])
            y = self.model.predict(x, verbose=0)  # <5>
            sample_token, probs = self.sample_from(y[0][-1], temperature)  # <6>
            info.append({"prompt": start_prompt, "word_probs": probs})
            start_tokens.append(sample_token)  # <7>
            start_prompt = start_prompt + " " + self.index_to_word[sample_token]
        print(f"\ngenerated text:\n{start_prompt}\n")
        return info

    def on_epoch_end(self, epoch, logs=None):
        self.generate("To be or not to be", max_tokens=100, temperature=1.0)

In [39]:
# Tokenize starting prompt
text_generator = TextGenerator(vocab)


In [17]:
lstm.fit(
    train_ds,
    epochs=EPOCHS,
    callbacks=[text_generator],
)

Epoch 1/20
[1m11950/11950[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - loss: 0.1105
generated text:
To be or not to be 

[1m11950/11950[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m505s[0m 42ms/step - loss: 0.1105
Epoch 2/20
[1m11950/11950[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - loss: 8.6366e-07
generated text:
To be or not to be 

[1m11950/11950[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m501s[0m 42ms/step - loss: 8.6364e-07
Epoch 3/20
[1m11949/11950[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 42ms/step - loss: 9.0060e-09
generated text:
To be or not to be 

[1m11950/11950[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m498s[0m 42ms/step - loss: 9.0056e-09
Epoch 4/20
[1m11950/11950[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - loss: 5.4751e-11
generated text:
To be or not to be 

[1m11950/11950[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m499s[0m 42ms/step - loss: 5.4752e-11
Epoch 5/20
[1m1

<keras.src.callbacks.history.History at 0x7a665a72fca0>

## 6. Generate text using the LSTM

In [23]:
def print_probs(info, vocab, top_k=5):
    for i in info:
        print(f"\nPROMPT: {i['prompt']}")
        word_probs = i["word_probs"]

        # Get sorted probabilities and indices within the vocabulary size
        p_sorted = np.sort(word_probs)[::-1]
        i_sorted = np.argsort(word_probs)[::-1]

        # Filter indices to be within the vocabulary size
        valid_indices = [idx for idx in i_sorted if 0 <= idx < len(vocab)]

        # Select top_k valid indices and probabilities
        p_sorted = [p_sorted[i_sorted.tolist().index(idx)] for idx in valid_indices[:top_k]]
        i_sorted = valid_indices[:top_k]

        for p, i in zip(p_sorted, i_sorted):
            print(f"{vocab[i]}:   \t{np.round(100*p,2)}%")
        print("--------\n")

In [36]:
info = text_generator.generate(
    "To be, or not to | ", max_tokens=10, temperature = 0.1
)
print_probs(info, vocab)


generated text:
To be, or not to |  


PROMPT: To be, or not to | 
:   	100.0%
à:   	0.0%
ê:   	0.0%
î:   	0.0%
#:   	0.0%
--------



In [37]:
info = text_generator.generate(
    "O powerful | ", max_tokens=10, temperature=0.5
)
print_probs(info, vocab)


generated text:
O powerful |  


PROMPT: O powerful | 
:   	100.0%
]:   	0.0%
l:   	0.0%
b:   	0.0%
î:   	0.0%
--------



In [38]:
info = text_generator.generate(
    "Fair is foul, and foul is | ", max_tokens=10, temperature=1.0
)
print_probs(info, vocab)


generated text:
Fair is foul, and foul is |  


PROMPT: Fair is foul, and foul is | 
:   	100.0%
]:   	0.0%
_:   	0.0%
l:   	0.0%
ê:   	0.0%
--------

