In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import tensorflow as tf
from tensorflow import keras

shakespeare_url = "https://homl.info/shakespeare"

file_path = keras.utils.get_file('shakespeare.txt', shakespeare_url)

with open(file_path) as f:
    shakespeare_text = f.read()

print(shakespeare_text[:80])

2025-04-26 11:22:54.368557: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745666574.700099      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745666574.783278      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Downloading data from https://homl.info/shakespeare
[1m1115394/1115394[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.


In [4]:
text_vectorization = keras.layers.TextVectorization(split='character', standardize='lower')
text_vectorization.adapt([shakespeare_text])
encoded = text_vectorization([shakespeare_text])[0]
encoded -= 2

n_tokens = text_vectorization.vocabulary_size() - 2 
print('Number of distinct tokens: {}'.format(n_tokens))
dataset_size = len(encoded) 
print('Number of total tokens: {}'.format(dataset_size))

Number of distinct tokens: 39
Number of total tokens: 1115394


In [5]:
def to_dataset(sequence, length, shuffle=False, seed=None, batch_size=32):
    ds = tf.data.Dataset.from_tensor_slices(sequence)
    ds = ds.window(length + 1, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda window_ds: window_ds.batch(length + 1))
    if shuffle:
        ds = ds.shuffle(buffer_size=100_000, seed=seed)
    ds = ds.batch(batch_size)
    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)

In [6]:
tf.random.set_seed(42)

train_ds = to_dataset(encoded[:1000000], 128, True, 42)
valid_ds = to_dataset(encoded[1000000:1060000], 128, True, 42)
test_ds = to_dataset(encoded[1060000:], 128, True, 42)

In [10]:
inputs = keras.layers.Input(shape=(None, ), dtype='int64')
x = keras.layers.Embedding(input_dim = n_tokens, output_dim = 16)(inputs)
x = keras.layers.GRU(128, return_sequences=True)(x)
outputs = keras.layers.Dense(n_tokens, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(loss="sparse_categorical_crossentropy", optimizer="nadam", metrics=["accuracy"])
model_ckpt = keras.callbacks.ModelCheckpoint("my_shakespeare_model.keras", monitor="val_accuracy", save_best_only=True)

history = model.fit(train_ds, validation_data=valid_ds, epochs=4, callbacks=[model_ckpt])

Epoch 1/4
  31243/Unknown [1m267s[0m 8ms/step - accuracy: 0.5546 - loss: 1.4748

  self.gen.throw(typ, value, traceback)


[1m31246/31246[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m283s[0m 8ms/step - accuracy: 0.5546 - loss: 1.4748 - val_accuracy: 0.5361 - val_loss: 1.5985
Epoch 2/4
[1m31246/31246[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m281s[0m 8ms/step - accuracy: 0.6023 - loss: 1.2798 - val_accuracy: 0.5434 - val_loss: 1.5650
Epoch 3/4
[1m31246/31246[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m282s[0m 8ms/step - accuracy: 0.6066 - loss: 1.2602 - val_accuracy: 0.5454 - val_loss: 1.5602
Epoch 4/4
[1m31246/31246[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m281s[0m 8ms/step - accuracy: 0.6089 - loss: 1.2498 - val_accuracy: 0.5476 - val_loss: 1.5618


In [23]:
encoded_predictions = text_vectorization(['To be or not to b'])[0]
encoded_predictions -=2

encoded_predictions = encoded_predictions[tf.newaxis, :]
print(encoded_predictions.shape)

predictions = model.predict([encoded_predictions])
print(predictions.shape)

y_pred = tf.argmax(predictions[0,-1])  # choose the most probable character ID
print('Next character is: {}'.format(text_vectorization.get_vocabulary()[y_pred + 2]))

(1, 17)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step
(1, 17, 39)
Next character is: e


In [28]:
def next_char(text, temperature=1):
    text_predictions = text_vectorization([text])[0]
    text_predictions -=2
    text_predictions = text_predictions[tf.newaxis, :]
    
    y_proba = model.predict([text_predictions])[0, -1:]
    rescaled_logits = tf.math.log(y_proba) / temperature
    char_id = tf.random.categorical(rescaled_logits, num_samples=1)[0, 0]
    return text_vectorization.get_vocabulary()[char_id + 2]

In [29]:
def extend_text(text, n_chars=50, temperature=1):
    for _ in range(n_chars):
        text += next_char(text, temperature)
    return text

In [30]:
print(extend_text("To be or not to be", temperature=0.01))

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18