In [5]:
"""
IMDB Sentiment — Enhanced BiLSTM Trainer & Inference

Trains a stacked bidirectional-LSTM binary sentiment classifier on the Keras
IMDB review dataset, saves it as ``imdb_lstm.keras`` (plus a legacy
``imdb_lstm.h5`` copy), and scores free-text reviews with the saved model.

Command line (unknown arguments, e.g. Jupyter's ``-f <path>``, are ignored):
    --train              train, evaluate on the test split, and save
    --predict "TEXT"     score TEXT with the saved model
    --epochs N           override the number of epochs
    --batch_size N       override the batch size
With neither --train nor --predict, a short (3-epoch by default) in-memory
demo model is trained and used to score one sample review; no saved files
are needed for that path.
"""

from __future__ import annotations
import os
import re
import random
from pathlib import Path
from typing import List, Tuple

# Quieter TF logs: default TF_CPP_MIN_LOG_LEVEL to "2" unless the environment
# already provides a value. Placed ahead of the TensorFlow import below.
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# -----------------------
# Reproducibility
# -----------------------
SEED = 42
# Seed Python's `random`, NumPy's global RNG and Keras/TensorFlow, in that
# order, with the same SEED. NOTE: this does not by itself force
# deterministic op kernels on every device.
for _seed_rng in (random.seed, np.random.seed, tf.keras.utils.set_random_seed):
    _seed_rng(SEED)
del _seed_rng  # keep the module namespace clean

# -----------------------
# Config
# Only EPOCHS and BATCH_SIZE can be overridden from the CLI
# (--epochs / --batch_size); change the others here.
# -----------------------
NUM_WORDS   = 20_000  # vocabulary size: keep only the top-N most frequent words
MAXLEN      = 250     # every review is padded/truncated to this many token ids
EMBED_DIM   = 128     # embedding vector width
LSTM_UNITS1 = 128     # units of the first (sequence-returning) BiLSTM layer
LSTM_UNITS2 = 64      # units of the second BiLSTM layer
DROPOUT     = 0.30    # dropout rate after each BiLSTM and after the Dense(64) layer
BATCH_SIZE  = 64
EPOCHS      = 10

# Artifacts live next to this file when run as a script; a notebook cell has
# no __file__, so fall back to the current working directory there.
ROOT = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()

MODEL_KERAS = ROOT / "imdb_lstm.keras"  # preferred: native Keras format
MODEL_H5 = ROOT / "imdb_lstm.h5"        # legacy HDF5 copy

# -----------------------
# Data loading
# -----------------------
def load_imdb(num_words: int = NUM_WORDS, maxlen: int = MAXLEN):
    """Return the padded IMDB splits as ``((x_train, y_train), (x_test, y_test))``.

    The Keras loader is asked for the ``num_words`` most frequent words only;
    every review is then padded/truncated to exactly ``maxlen`` ids using the
    ``pad_sequences`` defaults. The split sizes are printed.
    """
    pad = keras.preprocessing.sequence.pad_sequences
    train_split, test_split = keras.datasets.imdb.load_data(num_words=num_words)
    x_train = pad(train_split[0], maxlen=maxlen)
    x_test = pad(test_split[0], maxlen=maxlen)
    print(f"Train samples: {len(x_train)}  Test samples: {len(x_test)}")
    return (x_train, train_split[1]), (x_test, test_split[1])

def get_word_index(num_words: int = NUM_WORDS):
    """Return a word -> id map intended to match the ids in the IMDB arrays.

    Raw Keras ranks are shifted up by 3 to leave room for the reserved ids
    0 <PAD>, 1 <START>, 2 <UNK> and 3 <UNUSED> (this mirrors the loader's
    default start/OOV/index_from settings). Only ids below ``num_words`` are
    kept, reserved ids included.
    """
    raw_ranks = keras.datasets.imdb.get_word_index()
    vocab = {word: rank + 3 for word, rank in raw_ranks.items()}
    vocab.update({"<PAD>": 0, "<START>": 1, "<UNK>": 2, "<UNUSED>": 3})
    return {word: idx for word, idx in vocab.items() if idx < num_words}

# -----------------------
# Model
# -----------------------
def build_model(vocab_size: int = NUM_WORDS, maxlen: int = MAXLEN) -> keras.Model:
    """Build and compile the stacked BiLSTM sentiment classifier.

    Architecture: Embedding(vocab_size, EMBED_DIM) -> BiLSTM(LSTM_UNITS1,
    return_sequences) -> Dropout -> BiLSTM(LSTM_UNITS2) -> Dropout ->
    Dense(64, relu) -> Dropout -> Dense(1, sigmoid).

    Args:
        vocab_size: number of distinct token ids the embedding accepts; input
            ids must lie in ``[0, vocab_size)``.
        maxlen: fixed length of each integer input sequence.

    Returns:
        A model compiled with Adam(1e-3), binary cross-entropy, and the
        ``accuracy`` and ``auc`` metrics (``auc`` is what the training
        callbacks monitor as ``val_auc``).
    """
    inputs = layers.Input(shape=(maxlen,), dtype="int32")
    # `input_length` is deprecated in Keras 3; the Input shape already fixes it.
    x = layers.Embedding(vocab_size, EMBED_DIM)(inputs)
    x = layers.Bidirectional(layers.LSTM(LSTM_UNITS1, return_sequences=True))(x)
    x = layers.Dropout(DROPOUT)(x)
    x = layers.Bidirectional(layers.LSTM(LSTM_UNITS2))(x)
    x = layers.Dropout(DROPOUT)(x)
    x = layers.Dense(64, activation="relu")(x)
    x = layers.Dropout(DROPOUT)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(inputs, outputs, name="imdb_bilstm")
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss="binary_crossentropy",
        # Same `keras` namespace as the rest of the file (was tf.keras.metrics).
        metrics=["accuracy", keras.metrics.AUC(name="auc")],
    )
    return model

# -----------------------
# Training
# -----------------------
def _print_extra_metrics(model: keras.Model, x_test, y_test) -> None:
    """Print an optional scikit-learn report: classification report, confusion matrix, ROC-AUC.

    A missing scikit-learn install and a failure while computing the metrics
    are reported with different messages, so a real error is not mislabelled
    as a missing dependency. Ordinary errors are printed, not raised.
    """
    try:
        from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
    except ImportError as e:
        print("(Extra metrics skipped; install scikit-learn to enable.)", e)
        return
    try:
        y_prob = model.predict(x_test, batch_size=512, verbose=0).ravel()
        y_pred = (y_prob >= 0.5).astype("int32")  # same 0.5 threshold as predict_text
        print("\nClassification report:\n", classification_report(y_test, y_pred, digits=4))
        print("Confusion matrix:\n", confusion_matrix(y_test, y_pred))
        print("ROC-AUC:", roc_auc_score(y_test, y_prob))
    except Exception as e:  # optional extras: the model is already saved, so report and continue
        print(f"(Extra metrics failed: {type(e).__name__}: {e})")


def train_and_save(epochs: int | None = None, batch_size: int | None = None):
    """Train the BiLSTM, evaluate it on the IMDB test split, and save it.

    Training holds out 20% of the training data for validation. The epoch
    with the best ``val_auc`` is checkpointed to MODEL_KERAS. Early stopping
    (patience 3 on ``val_auc``) and LR halving (patience 2 on ``val_loss``,
    floor 1e-5) are enabled. After training, the best checkpoint written by
    this run is loaded back, so the test evaluation and both saved files
    (MODEL_KERAS and the legacy MODEL_H5) reflect the best epoch.

    Args:
        epochs: number of training epochs; None (default) uses the
            module-level EPOCHS as it is at call time (so CLI overrides apply).
        batch_size: training batch size; None (default) uses the module-level
            BATCH_SIZE as it is at call time.
    """
    epochs = EPOCHS if epochs is None else epochs
    batch_size = BATCH_SIZE if batch_size is None else batch_size

    (x_train, y_train), (x_test, y_test) = load_imdb()
    model = build_model(NUM_WORDS, MAXLEN)
    model.summary()

    callbacks = [
        keras.callbacks.ModelCheckpoint(
            filepath=str(MODEL_KERAS),
            monitor="val_auc",
            mode="max",
            save_best_only=True,
            verbose=1,
        ),
        keras.callbacks.EarlyStopping(
            monitor="val_auc",
            mode="max",
            patience=3,
            restore_best_weights=True,
            verbose=1,
        ),
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.5,
            patience=2,
            min_lr=1e-5,
            verbose=1,
        ),
    ]

    # Remember the checkpoint's timestamp so a stale file left by an earlier
    # run is never mistaken for this run's best epoch.
    ckpt_mtime = MODEL_KERAS.stat().st_mtime_ns if MODEL_KERAS.exists() else None

    model.fit(
        x_train, y_train,
        validation_split=0.2,
        epochs=epochs,
        batch_size=batch_size,
        verbose=1,
        callbacks=callbacks,
        shuffle=True,
    )

    # Depending on the Keras version, restore_best_weights may leave the LAST
    # epoch's weights in memory when training runs all epochs without stopping
    # early; saving those would overwrite the best checkpoint. Restore the
    # best val_auc weights from the checkpoint this run wrote.
    if MODEL_KERAS.exists() and MODEL_KERAS.stat().st_mtime_ns != ckpt_mtime:
        model.load_weights(str(MODEL_KERAS))

    print("\nEvaluating on test set...")
    # return_dict avoids depending on the positional order of compiled metrics.
    scores = model.evaluate(x_test, y_test, verbose=0, return_dict=True)
    print(
        f"Test — loss: {scores['loss']:.4f}  acc: {scores['accuracy']:.4f}  "
        f"auc: {scores['auc']:.4f}"
    )

    model.save(MODEL_KERAS)
    model.save(MODEL_H5)
    print(f"✅ Saved: {MODEL_KERAS.name} and {MODEL_H5.name}")

    _print_extra_metrics(model, x_test, y_test)

# -----------------------
# Inference on custom text
# -----------------------
# Word -> id map used by encode_review(). None until the first encode_review()
# call fills it from get_word_index(NUM_WORDS); later calls reuse it.
_WORD_INDEX_CACHE = None
def _simple_clean(s: str) -> List[str]:
    s = s.lower()
    s = re.sub(r"[^a-z0-9\s']", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s.split()

def encode_review(text: str, maxlen: int = MAXLEN) -> np.ndarray:
    """Turn raw review text into a ``(1, maxlen)`` id array in IMDB format.

    The sequence begins with the <START> id (1). Words absent from the
    top-NUM_WORDS vocabulary map to <UNK> (2). Padding/truncation uses the
    same ``pad_sequences`` defaults as ``load_imdb``.
    """
    global _WORD_INDEX_CACHE
    if _WORD_INDEX_CACHE is None:
        # Built once on first use, then reused by every later call.
        _WORD_INDEX_CACHE = get_word_index(NUM_WORDS)
    lookup = _WORD_INDEX_CACHE.get
    token_ids = [1]  # <START>
    token_ids.extend(lookup(word, 2) for word in _simple_clean(text))  # 2 = <UNK>
    return keras.preprocessing.sequence.pad_sequences([token_ids], maxlen=maxlen)

def load_model_for_inference() -> keras.Model:
    """Return the saved model, preferring the native .keras file over the legacy .h5.

    Raises:
        FileNotFoundError: if neither MODEL_KERAS nor MODEL_H5 exists
            (run with --train first).
    """
    candidates = (
        (MODEL_KERAS, ""),
        (MODEL_H5, " (legacy)"),
    )
    for path, suffix in candidates:
        if not path.exists():
            continue
        print(f"📦 Loading {path.name}{suffix}")
        return keras.models.load_model(path)
    raise FileNotFoundError(
        f"No saved model found at:\n  {MODEL_KERAS}\n  {MODEL_H5}\n"
        "Run this script with --train at least once to create them."
    )

def predict_text(text: str, model: keras.Model | None = None) -> Tuple[float, str]:
    """Score a raw review and return ``(probability_positive, label)``.

    Args:
        text: free-form review text, encoded with ``encode_review``.
        model: an already-loaded model to reuse. When None (the default), the
            saved model is loaded from disk via ``load_model_for_inference``.
            Passing a model avoids reloading it on every call.

    Returns:
        The sigmoid output as a float, and "Positive" if it is >= 0.5, else
        "Negative".

    Raises:
        FileNotFoundError: when ``model`` is None and no saved model exists.
    """
    if model is None:
        model = load_model_for_inference()
    x = encode_review(text)
    prob = float(model.predict(x, verbose=0).ravel()[0])
    label = "Positive" if prob >= 0.5 else "Negative"
    return prob, label

# -----------------------
# CLI (robust in Jupyter)
# -----------------------
if __name__ == "__main__":
    import argparse

    DEMO_EPOCHS = 3  # default length of the in-memory demo (no --train / --predict)

    parser = argparse.ArgumentParser(description="IMDB Sentiment — BiLSTM")
    parser.add_argument("--train", action="store_true", help="Train the model")
    parser.add_argument("--predict", type=str, help="Score a custom review")
    parser.add_argument(
        "--epochs", type=int, default=None,
        help=f"Override epochs (default: {EPOCHS} with --train, {DEMO_EPOCHS} for the demo)",
    )
    parser.add_argument("--batch_size", type=int, default=BATCH_SIZE, help="Override batch size")

    # Ignore unknown args that Jupyter injects (e.g., -f <path>)
    args, _ = parser.parse_known_args()
    if args.epochs is not None and args.epochs < 1:
        parser.error("--epochs must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")

    # Rebind the module-level settings that train_and_save() reads at call time.
    if args.epochs is not None:
        EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    if args.train:
        train_and_save()

    elif args.predict is not None:
        p, lab = predict_text(args.predict)
        print(f"Review: {args.predict}")
        print(f"Score: {p:.4f}  Prediction: {lab}")

    else:
        # ---- Quick demo that DOES NOT require saved files ----
        (x_train, y_train), _ = load_imdb()
        model = build_model()
        model.fit(
            x_train, y_train,
            validation_split=0.2,
            # Honour an explicit --epochs; otherwise keep the demo short.
            epochs=args.epochs if args.epochs is not None else DEMO_EPOCHS,
            batch_size=BATCH_SIZE,
            verbose=1,
        )
        print("Quick demo score on a sample review (in-memory model):")
        sample = "The movie was fantastic! I absolutely loved it and would watch it again."
        x = encode_review(sample)
        prob = float(model.predict(x, verbose=0).ravel()[0])
        lab = "Positive" if prob >= 0.5 else "Negative"
        print(f"Sample: {sample}\nScore: {prob:.4f}  Prediction: {lab}")


Train samples: 25000  Test samples: 25000
Epoch 1/3
313/313 ━━━━━━━━━━━━━━━━━━━━ 257s 794ms/step - accuracy: 0.7782 - auc: 0.8643 - loss: 0.4604 - val_accuracy: 0.8494 - val_auc: 0.9212 - val_loss: 0.3609
Epoch 2/3
313/313 ━━━━━━━━━━━━━━━━━━━━ 222s 710ms/step - accuracy: 0.8961 - auc: 0.9554 - loss: 0.2671 - val_accuracy: 0.8576 - val_auc: 0.9278 - val_loss: 0.3485
Epoch 3/3
313/313 ━━━━━━━━━━━━━━━━━━━━ 225s 720ms/step - accuracy: 0.9361 - auc: 0.9791 - loss: 0.1784 - val_accuracy: 0.8616 - val_auc: 0.9264 - val_loss: 0.4113
Quick demo score on a sample review (in-memory model):
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json
1641221/1641221 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
Sample: The movie was fantastic! I absolutely loved it and would watch it again.
Score: 0.9083  Prediction: Positive
