In [None]:
import os
import gradio as gr
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenetv2_preprocess
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import Model
import joblib
import warnings
import tensorflow as tf

# Keep console output readable: ignore UserWarnings that originate in tqdm.
warnings.filterwarnings('ignore', category=UserWarning, module='tqdm')

# Every artifact is expected in a "Trained Models" folder under the current
# working directory.
BASE_DIR = os.path.abspath(os.getcwd())
MODELS_DIR = os.path.join(BASE_DIR, "Trained Models")
print(f"Base directory: {BASE_DIR}")
print(f"Looking for models in: {MODELS_DIR}\n")

# Model artifacts offered in the UI dropdown (Keras ``.keras`` files and
# joblib ``.pkl`` files), plus the saved class names used to decode outputs.
model_filenames = [
    "resnet50.keras",
    "custom_cnn.keras",
    "mobilenet.keras",
    "lstm_cnn.keras",
    "logistic_regression.pkl",
    "svm.pkl",
]
label_file = "label_encoder_classes.npy"

# Fail fast on the first missing artifact, before anything is loaded; report
# the size of each file that is present.
for fname in (*model_filenames, label_file):
    full_path = os.path.join(MODELS_DIR, fname)
    if not os.path.exists(full_path):
        raise FileNotFoundError(f"❌ Missing file: {full_path}")
    size = os.path.getsize(full_path)
    print(f"✔ Found {fname} ({size/1e6:.2f} MB)")

print("\nAll files present. Loading label encoder…")

# Load the saved class names (presumably LabelEncoder.classes_ from training —
# confirm). Position i in ``classes`` is the class reported for output index i.
label_path = os.path.join(MODELS_DIR, label_file)
try:
    # allow_pickle is required for object-dtype arrays; it unpickles data, so
    # only ever point this at a trusted file.
    classes = np.load(label_path, allow_pickle=True)
except Exception as e:
    # RuntimeError (an Exception subclass) keeps existing handlers working;
    # chaining preserves the original traceback for debugging.
    raise RuntimeError(f"❌ Error loading {label_file}: {e}") from e

# Output index -> human-readable class name.
label_mapping = dict(enumerate(classes))

# Load every listed model once at startup, keyed by its file name (the same
# strings shown in the UI dropdown).
print("Loading models…")
model_dict = {}
for fname in model_filenames:
    full_path = os.path.join(MODELS_DIR, fname)
    try:
        if fname.endswith('.keras'):
            # compile=False: the models are only used for inference (predict).
            model_dict[fname] = load_model(full_path, compile=False)
            print(f"✔ Loaded Keras model: {fname}")
        elif fname.endswith('.pkl'):
            # joblib.load unpickles arbitrary objects -- trusted local files only.
            model_dict[fname] = joblib.load(full_path)
            print(f"✔ Loaded scikit-learn model: {fname}")
        else:
            # Previously such a file was skipped silently, yet still offered in
            # the dropdown, failing later with a cryptic KeyError.
            raise ValueError(f"unsupported model file extension: {fname}")
    except Exception as e:
        raise RuntimeError(f"❌ Failed to load {fname}: {e}") from e

def _scale_to_unit_range(x):
    """Divide pixel values by 255 (maps 8-bit 0-255 input onto 0-1)."""
    return x / 255.0


def _passthrough(x):
    """Return the input unchanged (this model receives raw pixel values)."""
    return x


# Input preprocessing per Keras model, keyed by model file name. The
# scikit-learn models are not listed: they go through extract_features.
preprocess_dict = {
    "resnet50.keras": resnet50_preprocess,
    "mobilenet.keras": mobilenetv2_preprocess,
    "custom_cnn.keras": _scale_to_unit_range,
    "lstm_cnn.keras": _passthrough,
}

# Feature extractor for the scikit-learn models: the ImageNet-pretrained
# ResNet50 convolutional base (include_top=False, no pooling), which outputs a
# spatial feature map that extract_features pools into a vector.
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# The base is already a Keras Model, so use it directly. Wrapping it in
# Model(base.input, base.output) would rebuild an identical graph, and
# compile() only configures training, which predict() does not need.
feature_extractor = base_model

def extract_features(image):
    """Extract a pooled ResNet50 feature vector for the scikit-learn models.

    The image is converted to RGB and resized to 224x224. It is then passed
    through ResNet50's ``preprocess_input`` and the convolutional
    ``feature_extractor``. The resulting feature map is global-average-pooled
    (mean over the two spatial axes) into a single vector.

    Args:
        image: A PIL image in any mode; it is converted to RGB here so
            grayscale/RGBA/palette uploads still yield three channels.

    Returns:
        A 1-D float array with one entry per feature-map channel.
    """
    img = image.convert("RGB").resize((224, 224))
    arr = preprocess_input(np.array(img, dtype=np.float32))
    feats = feature_extractor.predict(np.expand_dims(arr, axis=0), verbose=0)
    # Equivalent to GlobalAveragePooling2D on a channels-last map (the Keras
    # default -- assumed here), without building a new layer on every call.
    return feats.mean(axis=(1, 2)).ravel()

# Prediction function with model-specific preprocessing


def _to_class_index(label):
    """Map a classifier's output label (encoded int or class name) to its index in ``classes``."""
    if isinstance(label, (int, np.integer)):
        return int(label)
    # Classifiers fitted on the raw class names predict strings, not indices.
    matches = np.flatnonzero(np.asarray(classes) == label)
    if matches.size == 0:
        raise ValueError(f"model returned unknown class {label!r}")
    return int(matches[0])


def _keras_scores(selected_model, model, image):
    """Score an RGB PIL image with a Keras model, using that model's preprocessing.

    Returns the model's output row for the single image (presumably one score
    per class in label-encoder order -- confirm against training).
    """
    preprocess_fn = preprocess_dict[selected_model]
    arr = np.array(image.resize((224, 224)), dtype=np.float32)
    arr = np.expand_dims(preprocess_fn(arr), axis=0)
    return model.predict(arr, verbose=0)[0]


def _sklearn_scores(selected_model, model, image):
    """Score an image with a scikit-learn model on ResNet50 features.

    Returns a vector of length ``len(classes)`` in label-encoder order.
    """
    features = extract_features(image)
    scores = np.zeros(len(classes))
    if selected_model == "svm.pkl":
        # Only a hard label is requested from the SVM, so report it one-hot.
        scores[_to_class_index(model.predict([features])[0])] = 1.0
        return scores
    proba = model.predict_proba([features])[0]
    model_labels = getattr(model, "classes_", None)
    if model_labels is None:
        return proba
    # predict_proba columns follow model.classes_; re-index them so each score
    # lands on the matching label-encoder position.
    for label, p in zip(model_labels, proba):
        scores[_to_class_index(label)] = p
    return scores


def _format_prediction(preds):
    """Render the top class and every per-class score as an HTML snippet."""
    idx = int(np.argmax(preds))
    parts = [
        f"<b>Prediction:</b> {label_mapping[idx]}<br>",
        f"<b>Confidence:</b> {preds[idx]:.4f}<br><br>",
        "<h3>Class Probabilities:</h3>",
        "<ul>",
        *(f"<li>{label_mapping[i]}: {p:.4f}</li>" for i, p in enumerate(preds)),
        "</ul>",
    ]
    return "".join(parts)


def predict_image(selected_model, image):
    """Classify ``image`` with the model chosen in the UI and render the result.

    Args:
        selected_model: File name of the chosen model (a key of ``model_dict``),
            or None/empty when nothing is selected in the dropdown.
        image: PIL image from the upload widget, or None if nothing was uploaded.

    Returns:
        An HTML snippet with the top class, its score and all per-class scores,
        or a string starting with "❌ Error:". Errors are returned rather than
        raised so the web UI always shows a message.
    """
    try:
        if image is None:
            return "❌ Error: No image uploaded!"
        if not selected_model:
            return "❌ Error: No model selected!"
        if selected_model not in model_dict:
            return f"❌ Error: Unknown model '{selected_model}'!"

        model = model_dict[selected_model]
        # Normalise the mode once so both pipelines receive 3-channel input.
        image = image.convert("RGB")
        if selected_model.endswith('.keras'):
            preds = _keras_scores(selected_model, model, image)
        elif selected_model.endswith('.pkl'):
            preds = _sklearn_scores(selected_model, model, image)
        else:
            return "❌ Error: Unsupported model type!"
        return _format_prediction(preds)
    except Exception as e:
        return f"❌ Error: {str(e)}"

# Assemble the web UI: a model picker and an image upload feed predict_image,
# whose HTML result is shown in a Markdown component.
model_picker = gr.Dropdown(choices=model_filenames, label="Select Model")
image_upload = gr.Image(type="pil", label="Upload Image")
demo = gr.Interface(
    fn=predict_image,
    inputs=[model_picker, image_upload],
    outputs=gr.Markdown(),
    title="🏠 Room Classifier",
    description="Upload an image and choose a model to classify bath/bed/kitchen/living/dining.",
)
# Serve on the loopback interface (127.0.0.1), port 8000.
demo.launch(server_name="127.0.0.1", server_port=8000)