<a href="https://colab.research.google.com/github/WSLINMSAI/MSAI-531-B01/blob/main/Assignment_11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# Colab Ready Alice Text Generator - Character Level
# - Non stateful GRU for training with shuffled sequences
# - Separate stateful sampler for generation
# - Safe ASCII charset cleaning to avoid odd symbols
# - Last timestep logits with temperature sampling
# - Keras 3 compatible checkpoints ending in .weights.h5
# - NEW: sample text every N epochs with a fixed seed and length

import os, io, string, numpy as np, tensorflow as tf

print("TensorFlow version:", tf.__version__)
try:
    devs = tf.config.list_physical_devices("GPU")
    print("GPU available:", bool(devs), devs)
except Exception as e:
    print("GPU check error:", e)

# -----------------------------
# Config
# -----------------------------
EPOCHS = 50
SEQ_LENGTH = 100
BATCH_SIZE = 64
BUFFER_SIZE = 10000
EMBED_DIM = 256
RNN_UNITS = 512
LEARNING_RATE = 1e-3
CLIPNORM = 1.0

TEMPERATURE = 0.8
GEN_LEN = 1000              # final sample length after training
SEED_TEXT = "Alice "

SAMPLE_EVERY = 10           # generate during training every N epochs
SAMPLE_LEN = 400            # sample length for the periodic samples
SAMPLE_SEED = "Alice "      # standard prefix used each time

CHECKPOINT_DIR = "./checkpoints_alice_clean"
LOCAL_TEXT_PATH = "/content/alice_in_wonderland.txt"  # upload a file here to override

# -----------------------------
# Text loading and cleaning
# -----------------------------
def load_text():
    # Use local file if present, else download from Gutenberg
    if os.path.exists(LOCAL_TEXT_PATH):
        with io.open(LOCAL_TEXT_PATH, "r", encoding="utf-8", errors="ignore") as f:
            return f.read()
    url = "https://www.gutenberg.org/files/11/11-0.txt"
    path = tf.keras.utils.get_file("alice_in_wonderland.txt", origin=url)
    with io.open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()

def clean_to_ascii(text):
    # Keep ASCII letters, digits, punctuation, space, newline, tab
    allowed = set(string.ascii_letters + string.digits + string.punctuation + " \n\t")
    cleaned = "".join(c if c in allowed else " " for c in text)
    # Collapse whitespace to avoid long runs
    cleaned = " ".join(cleaned.split())
    # Add simple newlines after periods for readability
    cleaned = cleaned.replace(". ", ".\n")
    return cleaned

# -----------------------------
# Vectorization
# -----------------------------
def build_vocab(text):
    vocab = sorted(list(set(text)))
    char2idx = {u: i for i, u in enumerate(vocab)}
    idx2char = np.array(vocab)
    return vocab, char2idx, idx2char

def text_to_ids(text, char2idx):
    return np.array([char2idx[c] for c in text], dtype=np.int32)

def make_dataset(ids, seq_length, batch_size, buffer_size):
    char_ds = tf.data.Dataset.from_tensor_slices(ids)
    sequences = char_ds.batch(seq_length + 1, drop_remainder=True)

    def split_input_target(chunk):
        return chunk[:-1], chunk[1:]

    ds = sequences.map(split_input_target, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(buffer_size).batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
    return ds

# -----------------------------
# Model
# -----------------------------
def build_model(vocab_size, embedding_dim, rnn_units, batch_size, stateful):
    if stateful:
        inputs = tf.keras.Input(batch_shape=[batch_size, None], dtype=tf.int32)
    else:
        inputs = tf.keras.Input(shape=(None,), dtype=tf.int32)
    x = tf.keras.layers.Embedding(vocab_size, embedding_dim)(inputs)
    x = tf.keras.layers.GRU(
        rnn_units,
        return_sequences=True,
        stateful=stateful,
        recurrent_initializer="glorot_uniform",
        name="rnn"  # name the GRU so we can reset its state later
    )(x)
    outputs = tf.keras.layers.Dense(vocab_size)(x)
    return tf.keras.Model(inputs, outputs, name="CharRNN")

# -----------------------------
# Training
# -----------------------------
def loss_fn(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

def train_model(train_ds, vocab_size, extra_callbacks=None):
    model = build_model(
        vocab_size=vocab_size,
        embedding_dim=EMBED_DIM,
        rnn_units=RNN_UNITS,
        batch_size=None,   # flexible batch for non stateful training
        stateful=False
    )
    opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, clipnorm=CLIPNORM)
    model.compile(optimizer=opt, loss=loss_fn)

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CHECKPOINT_DIR, "epoch{epoch:02d}.weights.h5"),
        save_weights_only=True,
        save_best_only=False,
        verbose=1,
    )
    es_cb = tf.keras.callbacks.EarlyStopping(monitor="loss", patience=5, restore_best_weights=True)

    callbacks = [ckpt_cb, es_cb]
    if extra_callbacks:
        callbacks += list(extra_callbacks)

    print("Training...")
    history = model.fit(train_ds, epochs=EPOCHS, callbacks=callbacks)
    return model, history

# -----------------------------
# Sampling helpers
# -----------------------------
@tf.function
def sample_step(model, input_id, temperature=1.0):
    logits = model(input_id)            # shape [1, 1, vocab]
    logits = logits[:, -1, :] / temperature
    next_id = tf.random.categorical(logits, num_samples=1)[0, 0]
    return next_id

def build_stateful_sampler(vocab_size, trained_model):
    sampler = build_model(
        vocab_size=vocab_size,
        embedding_dim=EMBED_DIM,
        rnn_units=RNN_UNITS,
        batch_size=1,
        stateful=True
    )
    sampler.set_weights(trained_model.get_weights())
    return sampler

def reset_rnn_states(model):
    # Reset states of any stateful RNN layers
    for layer in model.layers:
        if hasattr(layer, "reset_states"):
            layer.reset_states()

def generate_text(sampler, start_string, char2idx, idx2char, gen_len, temperature):
    # Convert seed to ids and prime the state with the whole seed
    input_ids = tf.expand_dims([char2idx[c] for c in start_string], 0)  # [1, T]
    reset_rnn_states(sampler)
    _ = sampler(input_ids)                    # prime with seed
    next_input = input_ids[:, -1:]            # start from last seed char

    out_chars = []
    for _ in range(gen_len):
        nid = sample_step(sampler, next_input, temperature=temperature)
        nid_val = nid.numpy()
        out_chars.append(idx2char[nid_val])
        next_input = tf.expand_dims([nid_val], 0)
    return start_string + "".join(out_chars)

# -----------------------------
# Callback to sample during training every N epochs
# -----------------------------
class TextSamplerCallback(tf.keras.callbacks.Callback):
    def __init__(self, char2idx, idx2char, seed_text, gen_len, temperature, vocab_size, every=10):
        super().__init__()
        self.char2idx = char2idx
        self.idx2char = idx2char
        self.seed_text = seed_text
        self.gen_len = gen_len
        self.temperature = temperature
        self.vocab_size = vocab_size
        self.every = every

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every == 0:
            sampler = build_stateful_sampler(self.vocab_size, self.model)
            sample = generate_text(
                sampler,
                start_string=self.seed_text,
                char2idx=self.char2idx,
                idx2char=self.idx2char,
                gen_len=self.gen_len,
                temperature=self.temperature
            )
            print("\n" + "=" * 80)
            print(f"Sample after epoch {epoch + 1}:")
            print(sample)
            print("=" * 80)

# -----------------------------
# Main
# -----------------------------
raw_text = load_text()
print("Raw length:", len(raw_text))

text = clean_to_ascii(raw_text)
print("Clean length:", len(text))
print("\nSample of cleaned text:\n", text[:400])

vocab, char2idx, idx2char = build_vocab(text)
print("Vocab size:", len(vocab))

ids = text_to_ids(text, char2idx)
ds = make_dataset(ids, SEQ_LENGTH, BATCH_SIZE, BUFFER_SIZE)

sampler_cb = TextSamplerCallback(
    char2idx=char2idx,
    idx2char=idx2char,
    seed_text=SAMPLE_SEED,
    gen_len=SAMPLE_LEN,
    temperature=TEMPERATURE,
    vocab_size=len(vocab),
    every=SAMPLE_EVERY
)

model, history = train_model(ds, vocab_size=len(vocab), extra_callbacks=[sampler_cb])

sampler = build_stateful_sampler(len(vocab), model)
final_sample = generate_text(
    sampler,
    start_string=SEED_TEXT,
    char2idx=char2idx,
    idx2char=idx2char,
    gen_len=GEN_LEN,
    temperature=TEMPERATURE
)

print("\n" + "=" * 80)
print("Final sample after training:")
print(final_sample)
print("=" * 80)


TensorFlow version: 2.19.0
GPU available: True [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Raw length: 144696
Clean length: 140805

Sample of cleaned text:
 *** START OF THE PROJECT GUTENBERG EBOOK 11 *** [Illustration] Alice s Adventures in Wonderland by Lewis Carroll THE MILLENNIUM FULCRUM EDITION 3.0 Contents CHAPTER I.
Down the Rabbit-Hole CHAPTER II.
The Pool of Tears CHAPTER III.
A Caucus-Race and a Long Tale CHAPTER IV.
The Rabbit Sends in a Little Bill CHAPTER V.
Advice from a Caterpillar CHAPTER VI.
Pig and Pepper CHAPTER VII.
A Mad Tea-Party
Vocab size: 70
Training...
Epoch 1/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - loss: 3.7022
Epoch 1: saving model to ./checkpoints_alice_clean/epoch01.weights.h5
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 31ms/step - loss: 3.6829
Epoch 2/50
[1m19/21[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m0s[0m 23ms/step - loss: 2.5884
Epoch 2: saving model to ./checkp