# 4. Inference / Deployment

This notebook demonstrates how to use the trained models (XGBoost on K-Means features and an MLP on histogram features) to predict the type(s) of a new Pokémon image.

In [None]:
import os
import pickle
import sys
from urllib.parse import urlparse

import cv2
import joblib
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import load_model

# Add src to path so the project-local `features` module can be imported
sys.path.append("../src")
from features import extract_kmeans_features, extract_histogram_features, download_image_from_url

# Constants — both paths are relative to the current working directory.
# MODEL_DIR holds the trained artefacts; TEMP_IMG_DIR is scratch space for
# images that are fetched from a URL before prediction.
MODEL_DIR, TEMP_IMG_DIR = "../models", "../data/temp_inference"

# Create the scratch folder up front; exist_ok makes re-running this cell safe.
os.makedirs(TEMP_IMG_DIR, exist_ok=True)

## Load Models

In [None]:
# Resolve every artefact path inside MODEL_DIR first, then load them in turn.
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code —
# only load .pkl files that come from a trusted source (e.g. your own training run).
xgb_path, mlp_path, mlb_path = (
    os.path.join(MODEL_DIR, artefact_name)
    for artefact_name in ("xgboost_model.pkl", "mlp_model.h5", "mlb.pkl")
)

xgb_model = joblib.load(xgb_path)   # XGBoost model, fed K-Means features below
mlp_model = load_model(mlp_path)    # Keras MLP, fed histogram features below
mlb = joblib.load(mlb_path)         # label binarizer: decodes model outputs back into type labels

print("Models loaded successfully.")

## Prediction Function

In [None]:
def _resolve_image_path(image_path_or_url):
    """Return a local file path for ``image_path_or_url``.

    ``http://`` / ``https://`` URLs are downloaded into ``TEMP_IMG_DIR`` first;
    anything else is treated as a local path and returned unchanged.
    Returns ``None`` (after printing a message) if the download fails.
    """
    parsed = urlparse(image_path_or_url)
    # Match on the parsed scheme rather than a "http" string prefix, so local
    # paths like "http_images/x.png" are not mistaken for URLs and upper-case
    # schemes ("HTTPS://...") are still recognised.
    if parsed.scheme not in ("http", "https"):
        return image_path_or_url

    # Take the filename from the URL *path* only, so query strings and
    # fragments ("?v=2", "#top") never end up in the local filename; fall back
    # to a fixed name when the path ends in "/".
    filename = os.path.basename(parsed.path) or "downloaded_image"
    local_path = os.path.join(TEMP_IMG_DIR, filename)
    if not download_image_from_url(image_path_or_url, local_path):
        print("Failed to download image.")
        return None
    return local_path


def _show_image(img_bgr):
    """Display an image loaded by OpenCV (BGR channel order) inline, without axes."""
    _, ax = plt.subplots()
    # OpenCV stores channels as BGR, matplotlib expects RGB.
    ax.imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    ax.axis("off")
    plt.show()


def predict_pokemon_type(image_path_or_url: str, mlp_threshold: float = 0.5, show_image: bool = True):
    """Predict the Pokémon type(s) of one image with both trained models.

    The image is (optionally) displayed, then classified twice:

    1. ``xgb_model`` on features from ``extract_kmeans_features``;
    2. ``mlp_model`` on features from ``extract_histogram_features``, keeping
       every label whose predicted score exceeds ``mlp_threshold``.

    Both outputs are decoded with ``mlb.inverse_transform``. The function
    relies on notebook globals defined in earlier cells: ``xgb_model``,
    ``mlp_model``, ``mlb``, ``TEMP_IMG_DIR`` and the ``features`` helpers.

    Parameters
    ----------
    image_path_or_url : str
        Local image path, or an ``http://`` / ``https://`` URL that is first
        downloaded into ``TEMP_IMG_DIR``.
    mlp_threshold : float, default 0.5
        Per-label score cut-off applied to the MLP output.
    show_image : bool, default True
        Whether to display the image before predicting.

    Returns
    -------
    dict or None
        ``{"xgboost": ..., "mlp": ...}`` holding this image's entry of
        ``mlb.inverse_transform`` for each model (presumably a tuple of type
        names — depends on the binarizer), or ``None`` if the image could not
        be downloaded, read, or featurized (a message is printed instead).
    """
    img_path = _resolve_image_path(image_path_or_url)
    if img_path is None:
        return None

    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising, so check explicitly before handing the array to cvtColor.
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        print(f"Could not read image: {img_path}")
        return None
    if show_image:
        _show_image(img_bgr)

    # 1. XGBoost prediction (K-Means features)
    feat_kmeans = extract_kmeans_features(img_path)
    if feat_kmeans is None:
        print("Could not extract K-Means features.")
        return None
    # Both models expect a 2-D batch; wrap the single sample as one row.
    pred_xgb = xgb_model.predict(np.array([feat_kmeans]))
    labels_xgb = mlb.inverse_transform(pred_xgb)

    # 2. MLP prediction (histogram features)
    feat_hist = extract_histogram_features(img_path)
    if feat_hist is None:
        print("Could not extract Histogram features.")
        return None
    # verbose=0 suppresses Keras' per-call progress bar in the notebook output.
    pred_probs_mlp = mlp_model.predict(np.array([feat_hist]), verbose=0)
    # Threshold each label independently to get the 0/1 indicator matrix mlb expects.
    pred_mlp = (pred_probs_mlp > mlp_threshold).astype(int)
    labels_mlp = mlb.inverse_transform(pred_mlp)

    print("--- Predictions ---")
    print(f"XGBoost predicts: {labels_xgb[0]}")
    print(f"MLP predicts:     {labels_mlp[0]}")
    return {"xgboost": labels_xgb[0], "mlp": labels_mlp[0]}

## Test it!

In [None]:
# Example: Charizard (Fire/Flying) — PokeAPI official artwork, file 6.png.
# The trailing ";" keeps any returned value from being echoed as cell output.
charizard_url = (
    "https://raw.githubusercontent.com/PokeAPI/sprites/master/"
    "sprites/pokemon/other/official-artwork/6.png"
)
predict_pokemon_type(charizard_url);

In [None]:
# Example: Squirtle (Water) — PokeAPI official artwork, file 7.png.
# The trailing ";" keeps any returned value from being echoed as cell output.
squirtle_url = (
    "https://raw.githubusercontent.com/PokeAPI/sprites/master/"
    "sprites/pokemon/other/official-artwork/7.png"
)
predict_pokemon_type(squirtle_url);