<a href="https://colab.research.google.com/github/abduyea/AD325-DS/blob/main/ST_Demo_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Setup environment and mount Google Drive

from google.colab import drive
# Colab only: mount Google Drive here; the model files loaded below live
# under this mount point.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)


Mounted at /content/drive


In [2]:
# Load pretrained models for inference only

import os
import pickle
import numpy as np
import tensorflow as tf
import tensorflow as tf
tf.get_logger().setLevel("ERROR")


# Project folder containing the exported models. Defaults to the Drive copy;
# set the STORYTELLING_ENGINE_ROOT environment variable to point at another
# copy without editing the notebook.
ROOT_DIR = os.environ.get(
    "STORYTELLING_ENGINE_ROOT",
    "/content/drive/MyDrive/Final_Storytelling_engine/storytelling_engine",
)
MODELS_DIR = os.path.join(ROOT_DIR, "models")

# One sub-folder per model family.
NLP_DIR = os.path.join(MODELS_DIR, "nlp")
GAN_DIR = os.path.join(MODELS_DIR, "gan")
AE_DIR = os.path.join(MODELS_DIR, "ae")

# Seq2seq story-ending model plus the tokenizer and metadata it was saved with.
NLP_MODEL_PATH = os.path.join(NLP_DIR, "rocstories_seq2seq.keras")
NLP_TOKENIZER_PATH = os.path.join(NLP_DIR, "tokenizer.pkl")
NLP_META_PATH = os.path.join(NLP_DIR, "metadata.npz")

# Image models: GAN generator, full autoencoder, and its encoder half.
GAN_GEN_PATH = os.path.join(GAN_DIR, "generator_model.keras")
AE_MODEL_PATH = os.path.join(AE_DIR, "conv_autoencoder.keras")
AE_ENCODER_PATH = os.path.join(AE_DIR, "ae_encoder.keras")

# Fail fast, before any model is loaded, if an exported artifact is missing.
# Every missing file is reported together, so a partially copied models
# folder can be fixed in one pass.
required_files = [
    NLP_MODEL_PATH,
    NLP_TOKENIZER_PATH,
    NLP_META_PATH,
    GAN_GEN_PATH,
    AE_MODEL_PATH,
    AE_ENCODER_PATH,
]

missing_files = [p for p in required_files if not os.path.exists(p)]
if missing_files:
    raise FileNotFoundError(
        "Missing model file(s):\n" + "\n".join(f"  {p}" for p in missing_files)
    )

# compile=False: this notebook only runs inference, so optimizer/loss state
# need not be restored (same as the GAN/AE models loaded below).
nlp_model = tf.keras.models.load_model(NLP_MODEL_PATH, compile=False)

# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load tokenizers from a location you trust (e.g. your own Drive).
with open(NLP_TOKENIZER_PATH, "rb") as f:
    nlp_tokenizer = pickle.load(f)

# Decoding hyper-parameters saved alongside the model. The NpzFile is used
# as a context manager so its file handle is closed once the values are read.
# allow_pickle=True carries the same trust caveat as pickle.load above.
with np.load(NLP_META_PATH, allow_pickle=True) as meta:
    vocab_size = int(meta["vocab_size"])
    max_encoder_len = int(meta["max_encoder_len"])
    max_decoder_len = int(meta["max_decoder_len"])
    start_token = str(meta["start_token"])
    end_token = str(meta["end_token"])

gan_generator = tf.keras.models.load_model(GAN_GEN_PATH, compile=False)
autoencoder = tf.keras.models.load_model(AE_MODEL_PATH, compile=False)
encoder = tf.keras.models.load_model(AE_ENCODER_PATH, compile=False)

print(" All pretrained models loaded (inference mode)")


 All pretrained models loaded (inference mode)


In [3]:
# Text preprocessing and sentence-5 prediction (NLP inference)

# Vocabulary lookups from the loaded tokenizer (empty dicts if absent).
word_index = getattr(nlp_tokenizer, "word_index", {})
index_word = getattr(nlp_tokenizer, "index_word", {})
# Without index_word every decoded id would map to "" and predictions would
# silently come out empty, so rebuild the reverse map from word_index.
if not index_word:
    index_word = {idx: word for word, idx in word_index.items()}

# Resolve the decoder's control-token ids. Try the saved token strings
# first, then the bare words "start"/"end" (presumably how the tokenizer
# stored them after its own normalization; confirm against training code).
start_id = word_index.get(start_token) or word_index.get("start")
end_id = word_index.get(end_token) or word_index.get("end")

if start_id is None or end_id is None:
    raise ValueError(
        "Start or end token not found in tokenizer "
        f"(start_token={start_token!r} -> {start_id}, "
        f"end_token={end_token!r} -> {end_id})."
    )


def prep_encoder_text(s1, s2, s3, s4):
    """Turn four story sentences into a single padded encoder input row.

    The sentences are joined with spaces, stripped and lower-cased, then
    tokenized with ``nlp_tokenizer``. The id sequence is truncated to
    ``max_encoder_len`` and right-padded with 0 up to that length.

    Returns:
        An int32 array of shape (1, max_encoder_len).
    """
    story = " ".join((s1, s2, s3, s4)).strip().lower()
    token_ids = nlp_tokenizer.texts_to_sequences([story])[0]
    kept = token_ids[:max_encoder_len]

    pad_count = max_encoder_len - len(kept)
    if pad_count > 0:
        kept = kept + [0] * pad_count

    return np.array([kept], dtype=np.int32)


def predict_sentence5(s1, s2, s3, s4):
    """Greedily decode the fifth sentence of a story from its first four.

    The four sentences are encoded once. The decoder then runs step by step,
    starting from ``start_id``. At each step the arg-max token at the last
    filled decoder position is appended. Decoding stops when the model emits
    padding (0) or ``end_id``, or when the decoder input is full
    (``max_decoder_len`` tokens, including the start token).

    Args:
        s1, s2, s3, s4: The first four story sentences.

    Returns:
        The decoded words joined by spaces (ids missing from ``index_word`` are
        skipped), or ``"(empty prediction)"`` if nothing was decoded.
    """
    enc_input = prep_encoder_text(s1, s2, s3, s4)
    dec_tokens = [start_id]

    # One zero-padded decoder input reused across steps: slot i holds
    # dec_tokens[i], so each step writes one id instead of rebuilding it.
    dec_input = np.zeros((1, max_decoder_len), dtype=np.int32)
    dec_input[0, :1] = start_id

    for _ in range(max_decoder_len - 1):
        # Call the model directly: Model.predict has large per-call overhead
        # and is meant for big batches, not one sample per generated token.
        preds = np.asarray(nlp_model([enc_input, dec_input], training=False))
        # Presumably preds is (1, decoder_len, vocab); read the distribution
        # at the last filled decoder position.
        t = min(len(dec_tokens) - 1, preds.shape[1] - 1)
        next_id = int(np.argmax(preds[0, t]))

        if next_id in (0, end_id):
            break

        dec_input[0, len(dec_tokens)] = next_id
        dec_tokens.append(next_id)

    words = [index_word.get(i, "") for i in dec_tokens[1:]]
    sentence = " ".join(word for word in words if word).strip()

    return sentence if sentence else "(empty prediction)"


In [4]:
#  Interactive demo UI (self-contained, inference-only)

import ipywidgets as w
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import numpy as np

# Length of the latent noise vector fed to the GAN generator; must match the
# generator's input size.
LATENT_DIM = 100


def generate_image(seed=None):
    """Sample one image from the GAN generator.

    Args:
        seed: Optional integer seed. If given, the latent vector comes from a
            fresh ``np.random.default_rng(seed)``, so the same seed always
            yields the same image. If ``None`` (default), NumPy's global RNG
            is used, exactly as before.

    Returns:
        The first generated image, with values clipped to [0, 1].
    """
    if seed is None:
        z = np.random.normal(0.0, 1.0, (1, LATENT_DIM))
    else:
        z = np.random.default_rng(seed).normal(0.0, 1.0, (1, LATENT_DIM))
    z = z.astype(np.float32)

    img = gan_generator.predict(z, verbose=0)[0]
    # NOTE(review): clipping assumes the generator emits pixels in [0, 1];
    # if its output layer is tanh ([-1, 1]) this discards the negative half.
    return np.clip(img, 0.0, 1.0)


def enhance_image(img):
    """Pass one image through the autoencoder and clip the result to [0, 1].

    A leading batch axis of size 1 is added before prediction, and the first
    (only) reconstruction is returned.
    """
    batch = np.expand_dims(img, axis=0)
    reconstructed = autoencoder.predict(batch, verbose=0)
    return np.clip(reconstructed[0], 0.0, 1.0)


def _sentence_box(label):
    """Build one full-width, multi-line text input for a story sentence."""
    return w.Textarea(description=label, layout=w.Layout(width="100%", height="60px"))


# Inputs for the four given story sentences; the model predicts the fifth.
s1, s2, s3, s4 = [_sentence_box(f"S{n}:") for n in range(1, 5)]

# Image options: sample from the GAN, optionally refine with the autoencoder.
gen_img = w.Checkbox(value=True, description="Generate image (GAN)")
use_ae = w.Checkbox(value=True, description="Enhance with AE")

run_btn = w.Button(description="Run Storytelling Engine", button_style="primary")
# Output area that run_demo writes the prediction and figure into.
out = w.Output()


def run_demo(_):
    """Button callback: predict sentence 5 and optionally show a generated scene.

    Args:
        _: The clicked button passed by ``on_click`` (unused).
    """
    with out:
        clear_output(wait=True)

        sentences = (s1.value, s2.value, s3.value, s4.value)
        # An all-blank story gives the encoder nothing but padding; ask for
        # input instead of printing a meaningless prediction.
        if not any(s.strip() for s in sentences):
            print("Please enter at least one of the sentences S1-S4.")
            return

        sent5 = predict_sentence5(*sentences)
        print("Predicted sentence 5:")
        print(sent5)
        print()

        if gen_img.value:
            img = generate_image()
            img_out = enhance_image(img) if use_ae.value else img

            # A single-channel (H, W, 1) image would otherwise be drawn with
            # the default colormap and auto-scaled contrast; show it as
            # grayscale on the fixed [0, 1] range the images are clipped to.
            imshow_kwargs = {}
            if img_out.ndim == 3 and img_out.shape[-1] == 1:
                img_out = img_out[..., 0]
                imshow_kwargs = {"cmap": "gray", "vmin": 0.0, "vmax": 1.0}

            fig, ax = plt.subplots(figsize=(6, 4))
            ax.imshow(img_out, **imshow_kwargs)
            ax.axis("off")
            ax.set_title("Generated Scene")
            plt.show()


run_btn.on_click(run_demo)

# Lay the demo out top to bottom: title, the four sentence inputs, the
# options/run row, then the output area.
demo_header = w.HTML("<h3>Storytelling Engine — Inference Demo</h3>")
demo_controls = w.HBox([gen_img, use_ae, run_btn])
demo_ui = w.VBox([demo_header, s1, s2, s3, s4, demo_controls, out])
display(demo_ui)


VBox(children=(HTML(value='<h3>Storytelling Engine — Inference Demo</h3>'), Textarea(value='', description='S1…