In [1]:
import os
import glob
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
# --- Configuration: every value a reader might want to tune lives in this cell ---
DATA_URL = "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt"
SEQ_LENGTH = 10             # context tokens fed to the model per example
BATCH_SIZE = 64             # examples per training batch
BUFFER_SIZE = 10000         # size of the tf.data shuffle buffer
EMBEDDING_DIM = 128         # width of the token embedding vectors
RNN_UNITS = 256             # units per recurrent layer (per direction when bidirectional)
EPOCHS = 5                  # passes over the training dataset
MODEL_TYPE = "bidir_lstm"   # one of: "lstm", "gru", "bidir_lstm", "bidir_gru"
CHECKPOINT_DIR = "./nw_checkpoints"
SEED = 42

# Seed Python's `random`, NumPy and TensorFlow in one call so that shuffling,
# weight initialisation and sampling are repeatable after Restart & Run All.
tf.keras.utils.set_random_seed(SEED)

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [3]:
# Download the corpus (Keras keeps a local cached copy and reuses it on later runs).
# The cache file is named after the URL's basename instead of a generic
# "dataset.txt": get_file reuses any existing file with the same name, so a
# generic name would silently serve stale text if DATA_URL were changed.
path_to_file = tf.keras.utils.get_file(os.path.basename(DATA_URL), DATA_URL)
with open(path_to_file, "r", encoding="utf-8") as f:
    text = f.read()

print(f"Loaded text length: {len(text)} chars")

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt
[1m1115394/1115394[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Loaded text length: 1115394 chars


In [4]:
# Word-level tokenizer fitted on the whole corpus; words outside the fitted
# vocabulary are mapped to the "<OOV>" token.
tokenizer = Tokenizer(oov_token="<OOV>")
tokenizer.fit_on_texts([text])

# One id list comes back per input text; exactly one text is passed in.
(token_sequence,) = tokenizer.texts_to_sequences([text])

# One slot more than the number of known words, leaving room for id 0
# (presumably the padding id, which `word_index` does not assign).
vocab_size = 1 + len(tokenizer.word_index)
print(f"Vocab size (including OOV): {vocab_size}")

Vocab size (including OOV): 12634


In [5]:
# Stream the token ids and cut them into consecutive windows of
# SEQ_LENGTH + 1 ids (context plus one target); a short trailing window is dropped.
# NOTE(review): the windows do not overlap, so only about one position in every
# SEQ_LENGTH + 1 is ever used as a prediction target. A sliding
# `window(..., shift=1)` would yield more examples at a higher training cost.
tokens_ds = tf.data.Dataset.from_tensor_slices(token_sequence)
windows = tokens_ds.batch(batch_size=SEQ_LENGTH + 1, drop_remainder=True)

def split_input_target(chunk):
    """Split one window into a (context, target) pair.

    Everything except the final element of `chunk` becomes the model input;
    the final element is the value the model should predict.
    """
    return chunk[:-1], chunk[-1]

# Pair every window with its target, then shuffle, batch and prefetch,
# all expressed as a single pipeline.
dataset = (
    windows
    .map(split_input_target, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Peek at one batch to confirm the shapes the model will receive.
print("Dataset prepared. Sample batch shapes:")
for x, y in dataset.take(1):
    print("x:", x.shape, "y:", y.shape)

Dataset prepared. Sample batch shapes:
x: (64, 10) y: (64,)


In [6]:
def build_training_model(vocab_size, embedding_dim, rnn_units, seq_length, model_type="bidir_lstm"):
    """Build an uncompiled next-token model over fixed-length token windows.

    Args:
        vocab_size: Number of token ids; used as the embedding input size and
            as the width of the output layer.
        embedding_dim: Size of each token embedding vector.
        rnn_units: Units in the recurrent layer (per direction when the
            layer is bidirectional).
        seq_length: Number of token ids in each input window.
        model_type: One of "lstm", "gru", "bidir_lstm" or "bidir_gru".

    Returns:
        A `tf.keras.Model` that maps an int32 input of shape (seq_length,) per
        example to `vocab_size` unnormalised logits (no softmax is applied).

    Raises:
        ValueError: If `model_type` is not one of the supported names.
    """
    supported_types = ("lstm", "gru", "bidir_lstm", "bidir_gru")
    # Validate before any layer is created so a typo fails fast and the
    # message names the offending value.
    if model_type not in supported_types:
        raise ValueError(
            f"Unsupported model_type: {model_type!r} "
            f"(expected one of: {', '.join(supported_types)})"
        )

    inputs = tf.keras.Input(shape=(seq_length,), dtype="int32")
    # `input_length` is intentionally not passed: the Input layer already fixes
    # the sequence length, and the argument is deprecated in Keras 3.
    x = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim)(inputs)

    # Pick the recurrent layer class, then optionally wrap it to read the
    # sequence in both directions.
    rnn_class = tf.keras.layers.GRU if model_type.endswith("gru") else tf.keras.layers.LSTM
    rnn_layer = rnn_class(rnn_units)
    if model_type.startswith("bidir_"):
        rnn_layer = tf.keras.layers.Bidirectional(rnn_layer)
    x = rnn_layer(x)

    logits = tf.keras.layers.Dense(vocab_size)(x)
    return tf.keras.Model(inputs=inputs, outputs=logits)

# Instantiate the network described by the configuration cell and list its layers.
model = build_training_model(
    vocab_size=vocab_size,
    embedding_dim=EMBEDDING_DIM,
    rnn_units=RNN_UNITS,
    seq_length=SEQ_LENGTH,
    model_type=MODEL_TYPE,
)
model.summary()



In [8]:
def loss_fn(labels, logits):
    """Sparse categorical cross-entropy computed directly from raw logits.

    `labels` are integer class ids and `logits` are unnormalised scores;
    `from_logits=True` tells Keras that no softmax has been applied yet.
    """
    return tf.keras.losses.sparse_categorical_crossentropy(
        y_true=labels,
        y_pred=logits,
        from_logits=True,
    )

# Write a weights-only checkpoint at the end of every epoch; the epoch number
# is substituted into the file name by the callback.
checkpoint_pattern = os.path.join(CHECKPOINT_DIR, "nw_ckpt_epoch_{epoch:02d}.weights.h5")
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_pattern,
    save_freq="epoch",
    save_weights_only=True,
)

# Adam optimiser with the logits-based cross-entropy defined above.
model.compile(loss=loss_fn, optimizer="adam")

In [9]:

# Train for EPOCHS passes; the callback saves the weights after each one.
history = model.fit(
    dataset,
    epochs=EPOCHS,
    callbacks=[checkpoint_cb],
)

Epoch 1/5
[1m289/289[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 251ms/step - loss: 7.2551
Epoch 2/5
[1m289/289[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 253ms/step - loss: 6.5378
Epoch 3/5
[1m289/289[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 261ms/step - loss: 6.3044
Epoch 4/5
[1m289/289[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 256ms/step - loss: 6.0134
Epoch 5/5
[1m289/289[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 214ms/step - loss: 5.6354


In [10]:
# Reverse lookup (token id -> word). Id 0 is added by hand and mapped to the
# empty string (presumably because it is the padding id).
index_word = {
    **{index: word for word, index in tokenizer.word_index.items()},
    0: "",
}

import math
import numpy as np

def predict_next_words(model, tokenizer, seed_text, seq_length=SEQ_LENGTH, top_k=5, temperature=1.0):
    """Return the most likely next words after `seed_text`.

    Args:
        model: Model whose `predict` returns one row of vocabulary logits per
            input window.
        tokenizer: Fitted tokenizer; used both to encode `seed_text` and to
            decode predicted ids (via its `index_word` mapping).
        seed_text: Text whose continuation is predicted. Only its last
            `seq_length` tokens are used; shorter seeds are left-padded with 0.
        seq_length: Window length the model expects.
        top_k: Maximum number of candidates to return.
        temperature: Divides the logits before the softmax. Values below 1
            sharpen the distribution, values above 1 flatten it; 0 is treated
            as a very small positive number.

    Returns:
        A list of `(word, probability)` tuples sorted by descending
        probability. Empty when `top_k` is not positive.

    Raises:
        ValueError: If `temperature` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    # A non-positive top_k would otherwise slice from index -0 (or a positive
    # index) and return almost the whole vocabulary.
    if top_k <= 0:
        return []

    seq = tokenizer.texts_to_sequences([seed_text])[0]
    seq = seq[-seq_length:]
    input_seq = pad_sequences([seq], maxlen=seq_length, padding="pre")

    # float64 copy so the padding slot can be masked and the softmax is precise.
    logits = np.array(model.predict(input_seq, verbose=0)[0], dtype="float64")
    # Id 0 is the padding slot, never a real word: give it zero probability.
    logits[0] = -np.inf
    # Always scale, with the same guard that generation uses, so both helpers
    # treat temperature identically.
    logits = logits / max(temperature, 1e-8)
    probs = tf.nn.softmax(logits).numpy()

    # Rank real words only (ids 1..); shift back by one to recover token ids.
    word_probs = probs[1:]
    top_k = min(top_k, len(word_probs))
    if top_k == 0:
        return []
    candidates = np.argpartition(word_probs, -top_k)[-top_k:]
    candidates = candidates[np.argsort(-word_probs[candidates])] + 1

    # Decode with the tokenizer that was passed in rather than a notebook global.
    id_to_word = tokenizer.index_word
    return [(id_to_word.get(int(i), "<UNK>"), float(probs[i])) for i in candidates]

In [11]:
def generate_continuation(model, tokenizer, seed_text, num_words=20, seq_length=SEQ_LENGTH,
                          temperature=1.0, sample=True, rng=None):
    """Extend `seed_text` by `num_words` words predicted one at a time.

    Args:
        model: Model whose `predict` returns one row of vocabulary logits per
            input window.
        tokenizer: Fitted tokenizer; used to encode the seed and to decode
            generated ids (via its `index_word` mapping).
        seed_text: Starting text; surrounding whitespace is stripped.
        num_words: Number of words to append.
        seq_length: Window length the model expects; only the most recent
            `seq_length` tokens are fed in, left-padded with 0 when shorter.
        temperature: Divides the logits before sampling; 0 is treated as a
            very small positive number. Has no effect when `sample` is False.
        sample: If True, draw the next word from the predicted distribution;
            if False, always take the highest-scoring word.
        rng: Optional NumPy random generator (anything with a compatible
            `choice` method) used for sampling. Defaults to the global
            `np.random` state.

    Returns:
        The stripped seed followed by the generated words, space-separated.

    Raises:
        ValueError: If `temperature` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    draw = np.random.choice if rng is None else rng.choice

    current = seed_text.strip()
    # Tokenize the seed once and extend the id list directly; re-tokenizing the
    # growing text on every step is quadratic and yields the same ids.
    token_ids = list(tokenizer.texts_to_sequences([current])[0])
    id_to_word = tokenizer.index_word
    generated = []

    for _ in range(num_words):
        context = token_ids[-seq_length:]
        input_seq = pad_sequences([context], maxlen=seq_length, padding="pre")
        logits = np.array(model.predict(input_seq, verbose=0)[0], dtype="float64")
        # Id 0 is the padding slot, never a real word. Masking it keeps it from
        # being drawn, which previously cut the output short without notice.
        logits[0] = -np.inf

        if sample:
            probs = tf.nn.softmax(logits / max(temperature, 1e-8)).numpy()
            # Renormalise in float64 so NumPy's sum-to-one check cannot trip
            # on rounding error.
            probs = probs / probs.sum()
            next_id = int(draw(len(probs), p=probs))
        else:
            # Positive scaling and softmax are monotonic, so the argmax of the
            # raw logits is the greedy choice.
            next_id = int(np.argmax(logits))

        token_ids.append(next_id)
        generated.append(id_to_word.get(next_id, "<UNK>"))

    return " ".join([current, *generated])

In [12]:
# Restore the most recently written checkpoint, if any exist.
weights_files = glob.glob(os.path.join(CHECKPOINT_DIR, "*.weights.h5"))
# Order by modification time rather than by file name: a name sort would pick
# a stale higher-numbered checkpoint left over from an earlier, longer run,
# and would misorder epoch numbers once they exceed two digits.
weights_files.sort(key=os.path.getmtime)
if weights_files:
    latest = weights_files[-1]
    print("Loading weights from:", latest)
    model.load_weights(latest)
else:
    print("No checkpoint weights found (you can skip loading if you just trained).")

Loading weights from: ./nw_checkpoints\nw_ckpt_epoch_05.weights.h5


In [13]:
seed = "to be or not"

# Most probable next words for the seed.
print("\nTop 5 next word predictions (seed: '%s'):" % seed)
top_predictions = predict_next_words(model, tokenizer, seed, top_k=5, temperature=0.8)
print(top_predictions)

# Continue the same seed twice: once by sampling, once greedily.
for header, use_sampling in (
    ("\nGenerated continuation (sample=true):", True),
    ("\nGenerated continuation (greedy):", False),
):
    print(header)
    print(generate_continuation(model, tokenizer, seed, num_words=30, temperature=0.8, sample=use_sampling))


Top 5 next word predictions (seed: 'to be or not'):
[('which', 0.01709016226232052), ('nothing', 0.014554977416992188), ('day', 0.011925100348889828), ('my', 0.010046116076409817), ('her', 0.009882175363600254)]

Generated continuation (sample=true):
to be or not her day which you by escalus put be himself it you may be were my great friend king vincentio not show'd leicestershire autolycus the admired as hastings with ourselves and

Generated continuation (greedy):
to be or not which the day of the day of his head of his head of his head of his head of his head of his head of his head of his head
