# In [None]:
# predict_emotion.py

import tensorflow as tf
import numpy as np
import os
from transformers import BertTokenizer, TFBertModel # Only if using BERT model

# Import configurations and helper functions
from config import (
    TEXT_COLUMN, LABEL_COLUMNS, MODEL_TYPE,
    MAX_SEQUENCE_LENGTH, BERT_MAX_LEN,
    SAVE_DIR, MODEL_SAVE_PATH, TOKENIZER_SAVE_PATH
)
from data_handler import clean_text, tokenize_and_pad_sequences, load_tokenizer
from model_architectures import BahdanauAttention # Required for loading BiGRU_Attention model

def load_emotion_model():
    """
    Load the saved Keras emotion classification model.

    The model location and architecture are taken from ``MODEL_SAVE_PATH`` and
    ``MODEL_TYPE`` (imported from ``config``). For the ``'BiGRU_Attention'``
    architecture the custom ``BahdanauAttention`` layer is passed to Keras via
    ``custom_objects`` so the saved model can be deserialized.

    Returns:
        tf.keras.Model: The loaded model.

    Raises:
        FileNotFoundError: If nothing exists at ``MODEL_SAVE_PATH``.
        ValueError: If Keras fails to load the model (for example because
            ``MODEL_TYPE`` does not match the saved model or a custom layer is
            not registered). The underlying exception is chained as
            ``__cause__`` so its traceback is preserved.
    """
    if not os.path.exists(MODEL_SAVE_PATH):
        raise FileNotFoundError(f"Error: Model not found at '{MODEL_SAVE_PATH}'. Please train the model first.")

    print(f"Loading {MODEL_TYPE} model from {MODEL_SAVE_PATH}...")

    # Supply the custom attention layer explicitly for the architecture that uses it.
    custom_objects = {}
    if MODEL_TYPE == 'BiGRU_Attention':
        custom_objects['BahdanauAttention'] = BahdanauAttention

    try:
        model = tf.keras.models.load_model(MODEL_SAVE_PATH, custom_objects=custom_objects)
    except Exception as e:
        # Re-raise as ValueError (the documented contract) but keep the original
        # error chained so the real Keras traceback is not lost.
        raise ValueError(f"Error loading model: {e}. Make sure MODEL_TYPE in config.py matches the saved model and custom objects are correctly defined.") from e

    print("Model loaded successfully.")
    return model

# Default 'bert-base-uncased' components that predict_emotion creates when a
# caller does not supply them. They are cached here so the pretrained weights
# are loaded at most once per process, not on every prediction call.
_DEFAULT_BERT_COMPONENTS = {}


def _get_default_bert_components():
    """Return the cached default BERT (tokenizer, model) pair, loading it on first use."""
    if 'model' not in _DEFAULT_BERT_COMPONENTS:
        print("Initializing BERT Tokenizer and Model for prediction...")
        default_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        default_model = TFBertModel.from_pretrained('bert-base-uncased')
        # Store only after both loads succeed, so a failure leaves the cache empty
        # and the next call retries instead of hitting a half-filled cache.
        _DEFAULT_BERT_COMPONENTS['tokenizer'] = default_tokenizer
        _DEFAULT_BERT_COMPONENTS['model'] = default_model
        print("BERT components initialized.")
    return _DEFAULT_BERT_COMPONENTS['tokenizer'], _DEFAULT_BERT_COMPONENTS['model']


def predict_emotion(lyrics, model, tokenizer=None, bert_tokenizer=None, bert_model=None):
    """
    Predict emotion probabilities for a single lyric text.

    The text is passed through ``clean_text`` and then preprocessed according to
    ``MODEL_TYPE``:

    * ``'BiGRU_Attention'``: converted to a token-id sequence with the fitted
      Keras ``tokenizer`` and post-padded/truncated to ``MAX_SEQUENCE_LENGTH``.
    * ``'BERT_XLSTM'``: tokenized with the BERT tokenizer (padded/truncated to
      ``BERT_MAX_LEN``) and encoded with the BERT model. Only the embedding at
      position 0 (the CLS token) of ``last_hidden_state`` is fed to ``model``.
      Any BERT component the caller omits is filled from a process-wide cached
      ``'bert-base-uncased'`` default.

    Args:
        lyrics (str): The song lyric text to predict.
        model (tf.keras.Model): The loaded Keras classification model.
        tokenizer (tf.keras.preprocessing.text.Tokenizer, optional): Pre-fitted
            Keras tokenizer. Required for ``'BiGRU_Attention'``.
        bert_tokenizer (transformers.BertTokenizer, optional): BERT tokenizer
            for ``'BERT_XLSTM'``.
        bert_model (transformers.TFBertModel, optional): BERT encoder for
            ``'BERT_XLSTM'``.

    Returns:
        dict: Maps each label in ``LABEL_COLUMNS`` to its predicted probability
        as a Python ``float``.

    Raises:
        ValueError: If ``tokenizer`` is missing for ``'BiGRU_Attention'``, if
            ``MODEL_TYPE`` is unknown, or if the number of model outputs does
            not match ``len(LABEL_COLUMNS)``.
    """
    cleaned_lyrics = clean_text(lyrics)

    if MODEL_TYPE == 'BiGRU_Attention':
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided for BiGRU_Attention model prediction.")
        sequences = tokenizer.texts_to_sequences([cleaned_lyrics])
        processed_input = tf.keras.preprocessing.sequence.pad_sequences(
            sequences, maxlen=MAX_SEQUENCE_LENGTH, padding='post', truncating='post'
        )
    elif MODEL_TYPE == 'BERT_XLSTM':
        if bert_tokenizer is None or bert_model is None:
            default_tokenizer, default_model = _get_default_bert_components()
            # Fill in only the component(s) the caller left out; never replace
            # one the caller explicitly supplied.
            if bert_tokenizer is None:
                bert_tokenizer = default_tokenizer
            if bert_model is None:
                bert_model = default_model
        tokens = bert_tokenizer([cleaned_lyrics], padding='max_length', truncation=True,
                                max_length=BERT_MAX_LEN, return_tensors='tf')
        # Use only the CLS token embedding (position 0 of the sequence axis).
        outputs = bert_model(tokens)['last_hidden_state'][:, 0, :]
        processed_input = outputs.numpy()
    else:
        raise ValueError(f"Unknown MODEL_TYPE '{MODEL_TYPE}'. Cannot preprocess for prediction.")

    # verbose=0 suppresses the Keras progress bar that would otherwise be
    # printed for every single-sample prediction.
    prediction_proba = model.predict(processed_input, verbose=0)[0]  # first (and only) sample

    # Fail loudly on a config/model mismatch rather than raising IndexError
    # (too few outputs) or silently dropping outputs (too many).
    if len(prediction_proba) != len(LABEL_COLUMNS):
        raise ValueError(
            f"Model produced {len(prediction_proba)} outputs but LABEL_COLUMNS has "
            f"{len(LABEL_COLUMNS)} labels; check config.py against the trained model."
        )

    # float() converts numpy.float32 to a plain Python float.
    return {label: float(prob) for label, prob in zip(LABEL_COLUMNS, prediction_proba)}

if __name__ == "__main__":
    # Demo entry point: load the configured model and its preprocessing
    # components, then print per-label probabilities for a few sample lyrics.
    import traceback  # only needed here, to report unexpected failures

    model = None
    tokenizer = None
    bert_tokenizer_pred = None
    bert_model_pred = None

    sample_lyrics = [
        "I feel so happy today, the sun is shining, and everything is going my way. Full of joy!",
        "A shadow falls upon my heart, a heavy weight, tearing me apart. Full of sorrow and despair.",
        "My blood boils, a fiery rage, I will not stand for this injustice anymore. Unleash the fury!",
        "I dream of you endlessly, your presence consumes my thoughts. A longing that never fades, an undeniable pull.",
    ]

    try:
        model = load_emotion_model()

        if MODEL_TYPE == 'BiGRU_Attention':
            tokenizer = load_tokenizer(TOKENIZER_SAVE_PATH)
        elif MODEL_TYPE == 'BERT_XLSTM':
            # Initialize BERT tokenizer and model once and reuse them for every sample.
            print("Initializing BERT Tokenizer and Model for prediction...")
            bert_tokenizer_pred = BertTokenizer.from_pretrained('bert-base-uncased')
            bert_model_pred = TFBertModel.from_pretrained('bert-base-uncased')
            print("BERT components initialized.")

        for sample_number, lyrics in enumerate(sample_lyrics, start=1):
            print(f"\n--- Prediction for Sample {sample_number} ---")
            predictions = predict_emotion(lyrics, model, tokenizer, bert_tokenizer_pred, bert_model_pred)
            for label, prob in predictions.items():
                print(f"- {label}: {prob:.4f}")

    except (FileNotFoundError, ValueError) as e:
        print(f"Error during prediction setup: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        # Unexpected failures are bugs: keep the full traceback for diagnosis.
        traceback.print_exc()