# In [2]:
# ---------------------------------
# Import Required Libraries
# ---------------------------------
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model

# ---------------------------------
# Load the Saved UNet Model
# ---------------------------------
# Update with the correct path to the trained UNet weights file.
unet_model_path = "/content/drive/MyDrive/Segmentation Dataset/1/unet_model.h5"
# unet_predict() reads the global `unet_model`, so the loaded network must be
# bound to that name and loaded from `unet_model_path` defined just above.
unet_model = load_model(unet_model_path)
# Alias kept for any later cell that refers to the network as `model`.
model = unet_model

print("✅ UNet model loaded successfully from:", unet_model_path)

# ---------------------------------
# Preprocessing Function for UNet
# ---------------------------------
def preprocess_for_unet(img, target_size=(256,256)):
    """Resize *img*, divide pixel values by 255 and prepend a batch axis.

    Args:
        img: Image array accepted by ``cv2.resize`` (presumably 8-bit, so the
            result lies in [0, 1] -- confirm for other dtypes).
        target_size: ``(width, height)`` handed straight to ``cv2.resize``.

    Returns:
        Float array with a leading batch axis of length 1.
    """
    resized = cv2.resize(img, target_size)
    batch = resized[np.newaxis, ...]
    return batch / 255.0

# ---------------------------------
# UNet Inference Function
# ---------------------------------
def unet_predict(image):
    """Segment *image* with the global ``unet_model`` and return a binary mask.

    The image is passed through ``preprocess_for_unet`` (resize to 256x256,
    scale by 1/255, add a batch axis). The first prediction in the batch is
    thresholded at 0.5.

    Returns:
        uint8 array of 0/1 values with all singleton axes removed. This is
        (256, 256) assuming the model emits a single-channel (256, 256, 1)
        map -- confirm against the model's output shape.
    """
    batch = preprocess_for_unet(image)
    probabilities = unet_model.predict(batch)[0]
    foreground = probabilities > 0.5  # strict: exactly 0.5 counts as background
    return foreground.astype(np.uint8).squeeze()

# ---------------------------------
# Traditional Segmentation Methods
# ---------------------------------
def traditional_method_1(image):
    """Global Otsu threshold on the grayscale version of an RGB *image*.

    Returns a uint8 mask: 1 where the gray value exceeds Otsu's threshold, else 0.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # The explicit threshold (0) is ignored because THRESH_OTSU picks it.
    _otsu_level, binary = cv2.threshold(
        grayscale, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU
    )
    # cv2 writes only 0 or 255 here; map 255 -> 1.
    return (binary > 0).astype(np.uint8)

def traditional_method_2(image):
    """Gaussian adaptive threshold on the grayscale version of an RGB *image*.

    Each pixel is compared with a Gaussian-weighted sum of its 11x11
    neighbourhood minus 2. Returns a uint8 mask of 0/1 values.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    block_size, offset = 11, 2
    binary = cv2.adaptiveThreshold(
        grayscale,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        block_size,
        offset,
    )
    # Output is 0 or 255; map 255 -> 1.
    return (binary > 0).astype(np.uint8)

def traditional_method_3(image):
    """Canny edges (hysteresis thresholds 100/200) thickened by one 3x3 dilation.

    Returns a uint8 mask with 1 on the dilated edge pixels and 0 elsewhere.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    edge_map = cv2.Canny(grayscale, 100, 200)
    thick_edges = cv2.dilate(edge_map, np.ones((3, 3), dtype=np.uint8), iterations=1)
    # Re-binarise: any non-zero pixel becomes 255, then map 255 -> 1.
    _, binary = cv2.threshold(thick_edges, 0, 255, cv2.THRESH_BINARY)
    return (binary > 0).astype(np.uint8)

def traditional_method_4(image):
    """Two-cluster k-means on pixel colours; the darker cluster becomes the mask.

    Args:
        image: H x W x 3 colour image (pixels are reshaped into 3-vectors).

    Returns:
        uint8 H x W mask with 1 for pixels assigned to the cluster whose centre
        has the lower mean channel value. Seeding uses KMEANS_RANDOM_CENTERS,
        so results may differ between runs.
    """
    height, width = image.shape[0], image.shape[1]
    pixels = np.float32(image.reshape((-1, 3)))
    stop_rule = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _compactness, assignment, centres = cv2.kmeans(
        pixels, 2, None, stop_rule, 10, cv2.KMEANS_RANDOM_CENTERS
    )
    label_image = assignment.flatten().reshape((height, width))
    # Assumes the background is the brighter cluster; on a tie cluster 1 wins.
    dark_label = 0 if np.mean(centres[0]) < np.mean(centres[1]) else 1
    return (label_image == dark_label).astype(np.uint8)

def traditional_method_5(image):
    """Marker-based watershed segmentation of dark objects in an RGB *image*.

    Pipeline: inverted Otsu threshold -> 2x morphological opening -> sure
    background (3x dilation) and sure foreground (distance transform above
    0.7 * its max) -> markers from connected components -> ``cv2.watershed``.

    Returns:
        uint8 mask: 1 for pixels whose watershed label is > 1 (regions grown
        from sure-foreground components), 0 for the background region
        (label 1) and watershed boundaries (label -1).
    """
    kernel = np.ones((3, 3), np.uint8)

    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    _, inverted = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    cleaned = cv2.morphologyEx(inverted, cv2.MORPH_OPEN, kernel, iterations=2)

    background = cv2.dilate(cleaned, kernel, iterations=3)
    distance = cv2.distanceTransform(cleaned, cv2.DIST_L2, 5)
    _, foreground = cv2.threshold(distance, 0.7 * distance.max(), 255, cv2.THRESH_BINARY)
    foreground = np.uint8(foreground)
    undecided = cv2.subtract(background, foreground)

    _, seeds = cv2.connectedComponents(foreground)
    # Shift labels so the background component becomes 1; 0 then marks "unknown".
    seeds = seeds + 1
    seeds[undecided == 255] = 0
    seeds = cv2.watershed(image, seeds)
    return (seeds > 1).astype(np.uint8)

def traditional_method_6(image):
    """Stand-in for a sixth method: returns the UNet mask from ``unet_predict``.

    Not part of ``traditional_methods``; it simply delegates to the UNet.
    """
    unet_mask = unet_predict(image)
    return unet_mask

# Classical (non-learned) segmentation methods, one per grid row.
# traditional_method_6 is intentionally left out: it only wraps unet_predict.
traditional_methods = [traditional_method_1, traditional_method_2, traditional_method_3,
                       traditional_method_4, traditional_method_5]

# ---------------------------------
# Inference and Visualization for a Single Photo
# ---------------------------------
def run_inference_single_photo(image_path):
    """Segment one image with every method and display a comparison grid.

    The image is loaded, converted BGR -> RGB and resized to 256x256. Each
    function in ``traditional_methods`` runs on the resized image, and
    ``unet_predict`` supplies the UNet mask. The grid has one row per
    traditional method plus a final row in which the UNet mask fills the
    method column. Columns: original image | method result | UNet result.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    bgr = cv2.imread(image_path)
    # cv2.imread reports unreadable/undecodable files by returning None,
    # which would otherwise surface as an opaque error inside cvtColor.
    if bgr is None:
        raise ValueError(f"Could not decode image: {image_path}")
    orig = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    orig_resized = cv2.resize(orig, (256, 256))

    # UNet segmentation (unet_predict resizes internally).
    unet_mask = unet_predict(orig)

    # (title, mask) pairs for column 2: each classical method in order,
    # followed by the UNet mask in the last row.
    rows = [
        (f"Traditional Method {i}", method(orig_resized))
        for i, method in enumerate(traditional_methods, start=1)
    ]
    rows.append(("UNet (Traditional Column)", unet_mask))

    # Row count follows the method list; squeeze=False keeps `axes` 2-D
    # even when there is a single row.
    _, axes = plt.subplots(len(rows), 3, figsize=(15, 5 * len(rows)), squeeze=False)

    for row_axes, (title, mask) in zip(axes, rows):
        ax_orig, ax_method, ax_unet = row_axes

        ax_orig.imshow(orig_resized)
        ax_orig.set_title("Original")

        ax_method.imshow(mask, cmap="gray")
        ax_method.set_title(title)

        # The UNet result is repeated in every row for side-by-side comparison.
        ax_unet.imshow(unet_mask, cmap="gray")
        ax_unet.set_title("UNet")

        for ax in row_axes:
            ax.axis("off")

    plt.tight_layout()
    plt.show()

# ---------------------------------
# Main: Run Inference on a Single Photo and Display Grid
# ---------------------------------
# Guarded so that importing this file as a module does not start inference on
# the placeholder path; inside a notebook cell __name__ is "__main__", so the
# cell still runs as before.
if __name__ == "__main__":
    # Update with your test image path
    test_image_path = "/content/drive/MyDrive/Segmentation Dataset/1/face_crop/your_test_image.jpg"
    run_inference_single_photo(test_image_path)