In [3]:
import os
import cv2
import csv
import numpy as np
import torch
import onnxruntime
from glob import glob
from tqdm import tqdm
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# ---------- Helper Functions ----------

def preprocess_for_u2net(image, target_size=(320, 320), *, mean=None, std=None):
    """
    Convert a BGR image into a batched U2-Net input tensor.

    Steps:
      - Resize to ``target_size`` (given as ``(width, height)``, the order
        ``cv2.resize`` expects)
      - Convert BGR to RGB
      - Scale to [0, 1] by dividing by 255
      - Optionally standardise each channel as ``(x - mean) / std``
      - Rearrange to CHW and add a leading batch dimension

    Args:
        image: 3-channel BGR image, e.g. as returned by ``cv2.imread``
            (assumed 8-bit, since it is scaled by 1/255).
        target_size: network input size as ``(width, height)``.
        mean: optional per-channel (R, G, B) mean subtracted after the [0, 1]
            scaling. Must be given together with ``std``.
        std: optional per-channel (R, G, B) divisor applied after ``mean``.
            With the default ``None`` for both, no standardisation is applied.

    Returns:
        ``(input_tensor, (orig_w, orig_h))``. ``input_tensor`` is a float32
        array of shape (1, 3, target_h, target_w). The size tuple is in the
        ``(width, height)`` order ``cv2.resize`` uses, so it can be passed
        straight to ``postprocess_u2net``.

    Raises:
        ValueError: if only one of ``mean`` / ``std`` is provided.

    NOTE(review): the reference U2-Net pipeline standardises inputs with
    ImageNet statistics (mean=(0.485, 0.456, 0.406), std=(0.229, 0.224,
    0.225)). Confirm which input range this ONNX export expects, and pass
    mean/std here if it needs them.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be provided together")

    orig_h, orig_w = image.shape[:2]
    image_resized = cv2.resize(image, target_size)
    image_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    if mean is not None:
        # Broadcast per-channel stats over the trailing (RGB) axis of HWC data.
        image_norm = (image_norm - np.asarray(mean, dtype=np.float32)) / np.asarray(
            std, dtype=np.float32
        )
    image_input = np.transpose(image_norm, (2, 0, 1))  # HWC -> CHW
    image_input = np.expand_dims(image_input, axis=0)  # add batch dim
    return image_input, (orig_w, orig_h)

def postprocess_u2net(prediction, orig_size, threshold=0.5):
    """
    Turn a raw U2-Net saliency map into a full-resolution binary mask.

    Args:
        prediction: model output, indexed as ``prediction[0, 0, :, :]``, so it
            must be 4-D. Presumably shaped (1, 1, H, W); confirm for this
            ONNX export.
        orig_size: ``(width, height)`` target passed directly to ``cv2.resize``.
        threshold: pixels strictly greater than this become foreground.

    Returns:
        uint8 array holding only 0 and 255, resized to ``orig_size``.

    NOTE(review): the raw output is thresholded as-is. The reference U2-Net
    test script min-max normalises its prediction first, so confirm whether
    this export's output needs that before a fixed 0.5 threshold.
    """
    saliency = cv2.resize(prediction[0, 0, :, :], orig_size)
    foreground = saliency > threshold
    return foreground.astype(np.uint8) * 255

def select_largest_mask(masks):
    """
    Pick the SAM proposal covering the most pixels and return it as an image.

    Args:
        masks: list of dicts, each carrying a boolean ``'segmentation'`` array
            (the records produced by ``SamAutomaticMaskGenerator.generate``).

    Returns:
        The segmentation with the largest pixel count, converted to a uint8
        image of 0/255. When several masks tie, the earliest one wins.
        Returns ``None`` if ``masks`` is empty.
    """
    if not masks:
        return None
    # max() returns the first element with the greatest key, matching argmax.
    biggest = max(masks, key=lambda record: record['segmentation'].sum())
    return biggest['segmentation'].astype(np.uint8) * 255

def compute_iou(mask_pred, mask_gt, empty_value=0.0):
    """
    Compute the Intersection over Union (IoU) of two binary masks.

    Any non-zero pixel counts as foreground, so 0/255 masks and boolean masks
    are both accepted.

    Args:
        mask_pred: predicted mask.
        mask_gt: ground-truth mask; must have the same shape as ``mask_pred``.
        empty_value: score returned when both masks are empty (union is 0).
            Defaults to 0.0. Some benchmarks use 1.0 here instead.

    Returns:
        IoU as a Python float in [0, 1].

    Raises:
        ValueError: if the two masks have different shapes. This avoids
            silently broadcasting, e.g., an (H, W) mask against a (1, W) one.
    """
    mask_pred_bool = np.asarray(mask_pred).astype(bool)
    mask_gt_bool = np.asarray(mask_gt).astype(bool)
    if mask_pred_bool.shape != mask_gt_bool.shape:
        raise ValueError(
            f"Mask shapes differ: {mask_pred_bool.shape} vs {mask_gt_bool.shape}"
        )
    intersection = np.logical_and(mask_pred_bool, mask_gt_bool).sum()
    union = np.logical_or(mask_pred_bool, mask_gt_bool).sum()
    if union == 0:
        return float(empty_value)
    return float(intersection) / float(union)

def ensure_dir(dir_path):
    """
    Create ``dir_path`` (including missing parents) unless it already exists.

    ``exist_ok=True`` makes the check-and-create a single call, so there is no
    race window between testing for the directory and creating it.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
            Failing here is preferable to later mask writes failing silently.
    """
    os.makedirs(dir_path, exist_ok=True)

# ---------- Device Setup ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ---------- Load Models ----------

# 1. Load U2-Net (ONNX) with GPU support if available.
#    The provider list is always passed explicitly: recent onnxruntime releases
#    refuse to create a session without `providers` on builds that ship more
#    than one execution provider. CPUExecutionProvider is listed after CUDA as
#    the lower-priority fallback.
u2net_path = os.path.join("models", "u2net.onnx")
providers = onnxruntime.get_available_providers()
if "CUDAExecutionProvider" in providers:
    u2net_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
    u2net_providers = ["CPUExecutionProvider"]
u2net_session = onnxruntime.InferenceSession(u2net_path, providers=u2net_providers)
# Report the highest-priority provider the session actually registered,
# not just the one that was requested.
print(f"U2-Net using {u2net_session.get_providers()[0]}")

# 2. Load the SAM "vit_h" model from its checkpoint and move it to `device`.
sam_checkpoint = os.path.join("models", "sam_vit_h_4b8939.pth")
sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
sam.to(device=device)
# Automatic mask generator with its default settings (no keyword overrides).
mask_generator = SamAutomaticMaskGenerator(sam)

# ---------- Setup Directories for Saving Results ----------

# Inputs: source images and their ground-truth masks.
image_dir = "DUT-OMRON-image"
mask_dir = "DUT-OMRON-mask"

# Outputs: one folder per model, both under "models/". Each saved mask keeps
# the filename of the image it was computed from.
u2net_result_dir = os.path.join("models", "u2net_result")
sam_result_dir = os.path.join("models", "sam_result")
for result_dir in (u2net_result_dir, sam_result_dir):
    ensure_dir(result_dir)

# Destination of the per-image IoU comparison table.
csv_filename = "compare_scores.csv"

# ---------- Processing and Evaluation ----------

# Get list of images. "*.*" matches every file that has an extension; files
# OpenCV cannot decode are reported and skipped inside the loop.
image_paths = sorted(glob(os.path.join(image_dir, "*.*")))
iou_u2net_list = []
iou_sam_list = []
results = []  # CSV rows: [filename, U2-Net IoU, SAM IoU]

# The session's input name does not change between images, so look it up once.
input_name = u2net_session.get_inputs()[0].name

# Create a tqdm progress bar. Messages inside the loop go through tqdm.write
# so they do not break the bar's rendering.
pbar = tqdm(image_paths, total=len(image_paths), desc="Processing images")

for i, image_path in enumerate(pbar):
    # Update the progress bar postfix to show the count
    pbar.set_postfix_str(f"{i+1}/{len(image_paths)} complete")

    # Results are saved under the original filename. The ground-truth mask is
    # looked up as <basename>.png (adjust the extension if needed).
    filename = os.path.basename(image_path)
    basename = os.path.splitext(filename)[0]
    gt_mask_path = os.path.join(mask_dir, basename + ".png")

    # cv2.imread returns None (rather than raising) for missing or unreadable
    # files. `image` is in OpenCV's BGR channel order.
    image = cv2.imread(image_path)
    gt_mask = cv2.imread(gt_mask_path, cv2.IMREAD_GRAYSCALE)
    if image is None or gt_mask is None:
        tqdm.write(f"Skipping {basename}: could not read image or mask.")
        continue

    # --- U2-Net Inference ---
    # preprocess_for_u2net performs its own BGR -> RGB conversion.
    input_tensor, orig_size = preprocess_for_u2net(image)
    pred = u2net_session.run(None, {input_name: input_tensor})[0]
    u2net_mask = postprocess_u2net(pred, orig_size)

    # Save U2-Net result using the original filename.
    # NOTE(review): cv2.imwrite chooses the format from the extension, so if
    # the source is a .jpg the saved mask is JPEG-compressed and no longer
    # strictly 0/255.
    u2net_save_path = os.path.join(u2net_result_dir, filename)
    if not cv2.imwrite(u2net_save_path, u2net_mask):
        tqdm.write(f"Warning: could not write {u2net_save_path}")

    # --- SAM Inference ---
    # SAM expects an RGB image; passing the BGR array would swap red and blue.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image_rgb)
    sam_mask = select_largest_mask(masks)
    if sam_mask is None:
        # Score an empty mask instead of dropping the image, so both models
        # are averaged over exactly the same set of images.
        tqdm.write(f"No mask generated by SAM for {basename}; using an empty mask.")
        sam_mask = np.zeros(image.shape[:2], dtype=np.uint8)

    # Save SAM result using the original filename
    sam_save_path = os.path.join(sam_result_dir, filename)
    if not cv2.imwrite(sam_save_path, sam_mask):
        tqdm.write(f"Warning: could not write {sam_save_path}")

    # --- Evaluation (IoU) ---
    # Any non-zero ground-truth pixel counts as foreground.
    gt_mask_bin = (gt_mask > 0).astype(np.uint8) * 255
    iou_u2net = compute_iou(u2net_mask, gt_mask_bin)
    iou_sam = compute_iou(sam_mask, gt_mask_bin)
    iou_u2net_list.append(iou_u2net)
    iou_sam_list.append(iou_sam)

    results.append([filename, iou_u2net, iou_sam])
    tqdm.write(f"Image: {basename} | U2-Net IoU: {iou_u2net:.4f} | SAM IoU: {iou_sam:.4f}")

# Write CSV file with header: Image, U2-Net IoU, SAM IoU.
# Explicit UTF-8 so non-ASCII filenames do not depend on the platform encoding.
with open(csv_filename, mode="w", newline="", encoding="utf-8") as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(["Image", "U2-Net IoU", "SAM IoU"])
    writer.writerows(results)

# ---------- Summary of Results ----------
# Summarize only when both score lists are non-empty; np.mean of an empty list
# would emit a RuntimeWarning and yield nan.
if not iou_u2net_list or not iou_sam_list:
    print("No valid results to summarize.")
else:
    avg_iou_u2net = np.mean(iou_u2net_list)
    avg_iou_sam = np.mean(iou_sam_list)
    print("\n--- Overall Performance ---")
    for model_label, model_avg in (("U2-Net", avg_iou_u2net), ("SAM", avg_iou_sam)):
        print(f"Average {model_label} IoU: {model_avg:.4f}")


No numbers missing in the sequence.
