In [3]:
import os
import cv2
import numpy as np
import onnxruntime as ort
from tqdm import tqdm

# ---------------------------
# Paths
# ---------------------------
# Dataset side: folder the eye images are read from, and folder the masks go to.
INPUT_DIR = r"C:\Users\awais\OneDrive\Desktop\Thesis\Iran\IranIris"
OUTPUT_DIR = r"C:\Users\awais\OneDrive\Desktop\Iran_Iris_Segmented_Masks"

# Model side: folder holding the exported ONNX file, then the file itself.
MODEL_DIR = r"C:\Users\awais\OneDrive\Desktop\Thesis\U-Net_Iris_Segmentation_Model"
MODEL_PATH = os.path.join(MODEL_DIR, "iris_semseg_upp_scse_mobilenetv2.onnx")

# Make sure the destination exists before any mask is written;
# exist_ok=True keeps re-runs of the cell from failing.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---------------------------
# Model setup
# ---------------------------
# Build the ONNX Runtime session once, restricted to the CPU execution
# provider, so the same session object is reused for every image.
ort_sess = ort.InferenceSession(
    MODEL_PATH,
    providers=["CPUExecutionProvider"],
)

# ---------------------------
# Normalization parameters
# ---------------------------
# Per-channel mean / standard deviation used to z-score pixels after they
# have been scaled to [0, 1]; one entry per channel, applied in RGB order.
# NOTE(review): these look like the usual ImageNet statistics — confirm
# they match what the model was trained with.
MEAN = np.array((0.485, 0.456, 0.406))
STD = np.array((0.229, 0.224, 0.225))

# ---------------------------
# Processing loop
# ---------------------------
# Size every frame is resized to before inference, in cv2's (width, height)
# order. The model output is indexed below as (class, 480, 640).
MODEL_INPUT_SIZE = (640, 480)

# Channels of the model output used to build the mask, and the cut-off
# applied to each of them.
# NOTE(review): assumes the model emits per-class scores in [0, 1]
# (probabilities, not raw logits) — confirm against the exported model.
IRIS_CHANNEL = 1
PUPIL_CHANNEL = 2
MASK_THRESHOLD = 0.5


def _iris_only_mask(preds, threshold=MASK_THRESHOLD):
    """Return a uint8 mask that is 255 on iris-but-not-pupil pixels, else 0.

    Args:
        preds: Per-class score maps for one image, indexed as
            ``preds[class, row, col]``.
        threshold: Score above which a pixel counts as belonging to a class.

    Returns:
        A 2-D ``uint8`` array holding only the values 0 and 255.
    """
    iris = preds[IRIS_CHANNEL] > threshold
    pupil = preds[PUPIL_CHANNEL] > threshold
    # Combine as booleans. Subtracting the uint8 masks wraps 0 - 1 around to
    # 255, which would paint every pupil pixel outside the iris mask white.
    return (iris & ~pupil).astype(np.uint8) * 255


for img_name in tqdm(os.listdir(INPUT_DIR), desc="Processing images"):
    img_path = os.path.join(INPUT_DIR, img_name)
    img = cv2.imread(img_path)

    # cv2.imread signals an unreadable / non-image file by returning None
    # rather than raising, so skip those entries explicitly.
    if img is None:
        print(f"❌ Error reading {img_name}")
        continue

    # Remember the original size so the mask can be mapped back onto it.
    orig_height, orig_width = img.shape[:2]

    # Convert to RGB and resize to model input
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_resized = cv2.resize(img_rgb, MODEL_INPUT_SIZE)

    # Normalize: scale to [0, 1], then z-score per channel
    img_norm = (img_resized / 255.0 - MEAN) / STD

    # HWC -> CHW, add the batch dimension (NCHW), and cast to float32
    img_input = np.transpose(img_norm, (2, 0, 1))[np.newaxis, :].astype(np.float32)

    # Run inference; take the first output and drop the batch dimension
    outputs = ort_sess.run(None, {"input": img_input})
    preds = outputs[0][0]  # shape: (4, 480, 640)

    # Iris = white (255), pupil = black (0), everything else = black (0)
    final_mask = _iris_only_mask(preds)

    # Resize mask to the original image size; nearest-neighbour keeps the
    # mask strictly binary (no interpolated grey values).
    final_mask_resized = cv2.resize(
        final_mask, (orig_width, orig_height), interpolation=cv2.INTER_NEAREST
    )

    # Save as PNG under the same base name; report a failed write instead of
    # silently losing the mask.
    save_path = os.path.join(OUTPUT_DIR, os.path.splitext(img_name)[0] + ".png")
    if not cv2.imwrite(save_path, final_mask_resized):
        print(f"❌ Error writing {save_path}")

print("✅ Done — pure binary masks created (iris = white, pupil = black, background = black).")


Processing images: 100%|██████████| 792/792 [02:53<00:00,  4.57it/s]

✅ Done — pure binary masks created (iris = white, pupil = black, background = black).



