<a href="https://colab.research.google.com/github/amberyliang/Cat-Dog-Rabbit-Image-Classifier-with-Grad-CAM-Gradio-/blob/main/cat_dog_rabbit_dectector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ✅ 貓狗兔三分類模型 - Colab 版（使用 MobileNetV2 + 自定義分類層）
# 資料夾格式：dataset/train/[cat|dog|rabbit] 以及 dataset/val/[cat|dog|rabbit]

# ⚠️ 請先掛載 Google Drive
from google.colab import drive
import os

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install gradio


Collecting gradio
  Downloading gradio-5.29.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.0 (from gradio)
  Downloading gradio_client-1.10.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6

In [None]:
# ✅ 導入套件
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
print("Num GPUs Available:", len(tf.config.list_physical_devices('GPU')))

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import seaborn as sns
import cv2
import gradio as gr



TensorFlow version: 2.18.0
Num GPUs Available: 0


In [None]:
# 1️⃣ 指定掛載後的資料路徑
base_dir = '/content/drive/MyDrive/colab_data/image_process_dataset - simple'
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')

In [None]:
# ✅ 檢查是否有已儲存的模型
model_path = '/content/drive/MyDrive/colab_data/my_model.h5'
if os.path.exists(model_path):
    model = load_model(model_path)
    print("✅ 載入現有模型成功")
else:
    # 2️⃣ 影像預處理與資料增強
    train_datagen = ImageDataGenerator(rescale=1./255,
                                       rotation_range=20,
                                       zoom_range=0.2,
                                       horizontal_flip=True)
    val_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(224, 224),
        batch_size=32,
        class_mode='categorical'
    )

    val_generator = val_datagen.flow_from_directory(
        val_dir,
        target_size=(224, 224),
        batch_size=32,
        class_mode='categorical',
        shuffle=False
    )

    # 3️⃣ 建立模型（使用預訓練 MobileNetV2）
    base_model = MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.3)(x)
    predictions = Dense(3, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    for layer in base_model.layers:
        layer.trainable = False

    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=10
    )

    for layer in base_model.layers[-30:]:
        layer.trainable = True

    model.compile(optimizer=Adam(learning_rate=1e-5),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    fine_tune_history = model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=10
    )

    model.save(model_path)
    print("💾 模型已儲存至 my_model.h5")

    # 混淆矩陣與報告可自行啟用：
    Y_pred = model.predict(val_generator)
    y_pred = np.argmax(Y_pred, axis=1)
    cm = confusion_matrix(val_generator.classes, y_pred)
    plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=val_generator.class_indices.keys(), yticklabels=val_generator.class_indices.keys())
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
    print(classification_report(val_generator.classes, y_pred, target_names=list(val_generator.class_indices.keys())))




✅ 載入現有模型成功


In [None]:
# 你只要建一個正確順序的 list（照原模型訓練順序）
np.save('/content/drive/MyDrive/colab_data/labels.npy', ['cat', 'dog', 'rabbit'])  # 順序照你原本的訓練資料夾

In [10]:
# 🔍 Grad-CAM 函數
from tensorflow.keras.preprocessing import image
import matplotlib.cm as cm

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(predictions[0])
        class_channel = predictions[:, pred_index]
    grads = tape.gradient(class_channel, conv_outputs)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
    return heatmap.numpy()

labels = list(np.load('/content/drive/MyDrive/colab_data/labels.npy', allow_pickle=True))

CONFIDENCE_THRESHOLD = 0.6

def classify_and_gradcam(uploaded_img):
    try:
        if uploaded_img is None:
            raise ValueError("⚠️ 上傳圖片失敗，請重新選擇圖片。")

        if not isinstance(uploaded_img, np.ndarray):
            uploaded_img = np.array(uploaded_img)

        print("✅ model object:", model)
        print("⚠️ Received type:", type(uploaded_img))
        print("⚠️ Shape:", getattr(uploaded_img, 'shape', 'No shape'))

        img = cv2.cvtColor(uploaded_img, cv2.COLOR_BGR2RGB)
        resized = cv2.resize(img, (224, 224))
        x = resized.astype(np.float32) / 255.0
        x = np.expand_dims(x, axis=0)

        print("Running prediction...")
        print("Input shape:", x.shape)
        pred = model.predict(x)[0]
        print("Prediction result:", pred)

        heatmap = make_gradcam_heatmap(x, model, 'Conv_1')
        heatmap = cv2.resize(heatmap, (224, 224))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(resized, 0.6, heatmap, 0.4, 0)

        result_dict = {labels[i]: float(pred[i]) for i in range(len(labels))}

        if max(result_dict.values()) < CONFIDENCE_THRESHOLD:
            return {"Unknown / 非貓狗兔": 1.0}, overlay

        return result_dict, overlay

    except Exception as e:
        import traceback
        traceback.print_exc()
        return {"Error": str(e)}, np.zeros((224, 224, 3), dtype=np.uint8)

def batch_predict(images):
    results = []
    for img in images:
        try:
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            resized = cv2.resize(img_rgb, (224, 224))
            x = resized.astype(np.float32) / 255.0
            x = np.expand_dims(x, axis=0)
            pred = model.predict(x)[0]
            label = labels[np.argmax(pred)]
            confidence = float(np.max(pred))
            if confidence < CONFIDENCE_THRESHOLD:
                results.append({"label": "Unknown / 非貓狗兔", "confidence": confidence})
            else:
                results.append({"label": label, "confidence": confidence})
        except Exception as e:
            results.append({"label": "Error", "confidence": str(e)})
    return results

# ✅ Gradio UI 修正相機上傳：不使用 source="camera"，改為 default Image + 手機支援
camera_interface = gr.Interface(
    fn=classify_and_gradcam,
    inputs=gr.Image(type="numpy", image_mode="RGB", label="用相機拍照（可直接拍照或選圖）"),
    outputs=[gr.Label(num_top_classes=3), gr.Image(label="Grad-CAM")],
    title="📷 拍照即時辨識",
    description="使用手機或筆電拍照 / 上傳，即時判斷是貓、狗還是兔子"
)

gr.TabbedInterface(
    [
        gr.Interface(
            fn=classify_and_gradcam,
            inputs=gr.Image(type="numpy", image_mode="RGB", label="Upload Image"),
            outputs=[gr.Label(num_top_classes=3), gr.Image(label="Grad-CAM")],
            title="🐱🐶🐰 單張圖片分類",
            description="上傳一張圖片，我會告訴你是貓、狗還是兔子，並顯示模型看哪裡判斷的！"
        ),
        gr.Interface(
            fn=batch_predict,
            inputs=gr.File(file_types=['.jpg', '.png'], label="上傳多張圖片", file_count="multiple"),
            outputs=gr.Dataframe(headers=["label", "confidence"], label="預測結果"),
            title="📦 批次圖片分類",
            description="一次上傳多張圖片，快速進行分類預測"
        ),
        camera_interface
    ],
    tab_names=["單張圖片預測", "多圖批次預測", "拍照即時辨識"]
).launch(share=True, debug=True)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://adbd9aa7d5f427e864.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


✅ model object: <Functional name=functional_4, built=True>
⚠️ Received type: <class 'numpy.ndarray'>
⚠️ Shape: (4568, 3045, 3)
Running prediction...
Input shape: (1, 224, 224, 3)
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 103ms/step
Prediction result: [0.32399756 0.5935661  0.08243628]


Expected: [['input_layer_1']]
Received: inputs=Tensor(shape=(1, 224, 224, 3))


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://adbd9aa7d5f427e864.gradio.live


