In [1]:
import cv2
import numpy as np
from PIL import Image
import torch
import clip
import os
from tqdm import tqdm
import logging
import csv

In [3]:
def extract_obj_from_mask(img_path, mask_path, output_size=(244, 244)):
    """Crop the object outlined by a mask out of an image.

    The image is read in colour and the mask as a single-channel image. Mask
    pixels with a value greater than 1 count as foreground; the bounding box
    of all external foreground contours is cropped from the image, resized to
    ``output_size`` and returned as an RGB PIL image.

    Args:
        img_path: Path of the image file.
        mask_path: Path of the mask file.
        output_size: ``(width, height)`` passed to ``cv2.resize``.
            NOTE(review): the default is 244x244, not 224x224 — confirm that
            this is intentional.

    Returns:
        PIL.Image.Image: the resized RGB crop.

    Raises:
        ValueError: if either file cannot be read, or the mask has no
            foreground contour.
    """
    image = cv2.imread(img_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    if image is None or mask is None:
        raise ValueError(f"Mask: {mask_path} or image: {img_path} not found")

    # The bounding box is computed in mask coordinates but applied to the
    # image, so both must share one resolution. Nearest-neighbour keeps the
    # mask values untouched (no blended edge values).
    img_h, img_w = image.shape[:2]
    if mask.shape[:2] != (img_h, img_w):
        mask = cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST)

    _, thresh = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
    # [-2] selects the contour list on both OpenCV 3.x (three return values)
    # and OpenCV 4.x (two return values).
    contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    if not contours:
        raise ValueError(f"No contour found in {mask_path}")

    # One box around every contour, so multi-part masks are kept whole.
    x, y, w, h = cv2.boundingRect(np.concatenate(contours))
    cropped = image[y:y+h, x:x+w]
    cropped_resized = cv2.resize(cropped, output_size)
    # OpenCV arrays are BGR; PIL expects RGB.
    return Image.fromarray(cv2.cvtColor(cropped_resized, cv2.COLOR_BGR2RGB))

In [5]:
def extract_center_crop(frame, size_ratio=0.5):
    """Return the centred sub-region of ``frame``.

    Args:
        frame: Array whose first two axes are height and width.
        size_ratio: Fraction of each dimension to keep; must satisfy
            ``0 < size_ratio <= 1``.

    Returns:
        A slice (view) of ``frame`` of ``int(h * size_ratio)`` rows by
        ``int(w * size_ratio)`` columns, centred in the frame.

    Raises:
        ValueError: if ``size_ratio`` is outside ``(0, 1]`` or the resulting
            crop would have a zero-sized dimension.
    """
    # A ratio above 1 would produce negative slice starts, which numpy reads
    # as "count from the end" and silently returns an off-centre region.
    # The negated form also rejects NaN.
    if not 0 < size_ratio <= 1:
        raise ValueError(f"size_ratio must be in (0, 1], got {size_ratio!r}")

    h, w = frame.shape[:2]
    ch, cw = int(h * size_ratio), int(w * size_ratio)
    if ch == 0 or cw == 0:
        raise ValueError("Frame too small for cropping")
    x1, y1 = (w - cw) // 2, (h - ch) // 2
    return frame[y1:y1 + ch, x1:x1 + cw]

In [7]:
def classify_with_clip(pil_img, model, preprocess, text_inputs, labels, device):
    """Pick the label whose text embedding best matches the image.

    The image is preprocessed into a batch of one, embedded with
    ``model.encode_image`` and compared against ``model.encode_text(text_inputs)``
    by cosine similarity scaled by 100 and passed through a softmax.

    Returns:
        tuple: ``(labels[best_index], probability)`` where ``probability`` is
        a Python float.
    """
    batch = preprocess(pil_img).unsqueeze(0)
    batch = batch.to(device)

    with torch.no_grad():
        image_emb = model.encode_image(batch)
        text_emb = model.encode_text(text_inputs)

        # L2-normalise each row in place so the dot product below is a
        # cosine similarity.
        image_emb.div_(image_emb.norm(dim=1, keepdim=True))
        text_emb.div_(text_emb.norm(dim=1, keepdim=True))

        logits = (100.0 * image_emb) @ text_emb.T
        probs = logits.softmax(dim=-1)
        # Only one image is in the batch, so row 0 holds every label score.
        best_prob, best_idx = torch.max(probs[0], dim=0)

    return labels[best_idx], best_prob.item()

In [9]:
def setup_clip_model(labels):
    """Load CLIP ViT-B/32 and tokenise one prompt per label.

    Uses CUDA when ``torch.cuda.is_available()`` and the CPU otherwise. Each
    label is wrapped as ``"a photo of a <label>"`` before tokenising, and the
    stacked tokens are moved to the chosen device.

    Returns:
        tuple: ``(model, preprocess, text_inputs, device)``.
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    model, preprocess = clip.load("ViT-B/32", device=device)

    # One tokenised prompt per label, concatenated in label order.
    token_rows = []
    for label in labels:
        token_rows.append(clip.tokenize(f"a photo of a {label}"))
    text_inputs = torch.cat(token_rows).to(device)

    return model, preprocess, text_inputs, device

In [11]:
def run_batch_detection(image_root, mask_root, model, preprocess, text_inputs, labels, device):
    """Classify every image found in the sub-folders of ``image_root``.

    For each ``<image_root>/<folder>/<name>.<ext>`` with a ``.jpg``, ``.png``
    or ``.jpeg`` extension, the matching mask is looked up at
    ``<mask_root>/<folder>/<name>_mask.png``. When the mask exists the masked
    object is classified; otherwise the centre crop of the image is used.
    Folders and files are visited in sorted order so repeated runs yield the
    detections in the same order.

    Errors on a single image are logged and that image is skipped, so one bad
    file does not abort the batch.

    Returns:
        list: ``(img_file, label, prob)`` tuples, one per classified image.
        ``img_file`` is the bare file name without its folder.
    """
    image_extensions = (".jpg", ".png", ".jpeg")
    all_detections = []

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for class_folder in sorted(os.listdir(image_root)):
        image_folder = os.path.join(image_root, class_folder)
        mask_folder = os.path.join(mask_root, class_folder)

        if not os.path.isdir(image_folder):
            continue

        logging.info(f"üîç Scanning folder: {class_folder}")
        # Filter before handing the list to tqdm so the progress total counts
        # images only.
        image_files = sorted(
            name for name in os.listdir(image_folder)
            if name.lower().endswith(image_extensions)
        )
        for img_file in tqdm(image_files):
            image_path = os.path.join(image_folder, img_file)
            mask_path = os.path.join(mask_folder, os.path.splitext(img_file)[0] + "_mask.png")

            try:
                if os.path.exists(mask_path):
                    cropped = extract_obj_from_mask(image_path, mask_path)
                else:
                    frame = cv2.imread(image_path)
                    if frame is None:
                        logging.warning(f"‚ö†Ô∏è Skipping {img_file} ‚Äî image could not be loaded.")
                        continue
                    cropped = extract_center_crop(frame)
                    # OpenCV arrays are BGR; PIL expects RGB.
                    cropped = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))

                label, prob = classify_with_clip(cropped, model, preprocess, text_inputs, labels, device)
                logging.info(f"üì£ ALERT: {label.upper()} detected with {prob:.1%} confidence ‚Äî {img_file}")
                all_detections.append((img_file, label, prob))

            except Exception as e:
                # Deliberate best-effort: report the failure and keep going.
                logging.error(f"‚ùå Error processing {img_file}: {e}")

    return all_detections

In [13]:
def run_video_detection(model, preprocess, text_inputs, labels, device, video_source=0):
    """Classify the centre crop of each video frame and show it in a window.

    Frames are read from ``video_source`` (handed to ``cv2.VideoCapture``)
    until the stream ends or the user presses ``q`` in the "Live Detection"
    window. The predicted label and confidence are drawn onto each frame.
    A failure while processing one frame is logged and the unannotated frame
    is still displayed.

    The capture and any OpenCV windows are released even if the loop exits
    through an exception or a keyboard interrupt.

    Returns:
        None.
    """
    cap = cv2.VideoCapture(video_source)
    logging.info("üé• Starting video stream...")

    try:
        if not cap.isOpened():
            # Without this, a missing file or busy camera ends silently.
            logging.error(f"Could not open video source: {video_source!r}")
            return

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            try:
                crop = extract_center_crop(frame)
                # OpenCV arrays are BGR; PIL expects RGB.
                crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
                label, prob = classify_with_clip(crop_pil, model, preprocess, text_inputs, labels, device)
                alert_text = f"{label.upper()} ({prob*100:.1f}%)"
                cv2.putText(frame, alert_text, (30, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 2)
            except Exception as e:
                logging.warning(f"Frame processing error: {e}")

            cv2.imshow("Live Detection", frame)
            # waitKey also services the GUI event loop; mask to the low byte
            # before comparing with the key code.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                logging.info("üëã Exiting...")
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

In [15]:
def save_detections(detections, output_path="detections.csv"):
    """Write detections to a CSV file with an ``Image,Label,Confidence`` header.

    Args:
        detections: Iterable of rows, each an iterable of three values
            (image name, label, confidence).
        output_path: Destination file; overwritten if it already exists.

    The file is always written as UTF-8 so image names containing non-ASCII
    characters are stored identically on every platform, rather than relying
    on the locale's default encoding.
    """
    # newline='' lets the csv module control line endings itself.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["Image", "Label", "Confidence"])
        writer.writerows(detections)

In [39]:
# Emit INFO-level records as "LEVEL: message" so the detection logs are shown.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

if __name__ == "__main__":
    # Candidate classes; setup_clip_model wraps each one in the prompt
    # "a photo of a <label>" before tokenising.
    labels = ["human", "butterfly", "dog", "cat", "horse", "elephant", "squirrel"]
    model, preprocess, text_inputs, device = setup_clip_model(labels)

    # Optional single-image smoke test (disabled).
    # NOTE(review): the sample image path is under "gatto" while the sample
    # mask is a horse mask — confirm this pairing is intentional before
    # re-enabling.
    # sample_img = "animals10/raw-img/gatto/1014.jpeg"
    # sample_mask = "animals10/renamed_masks/horse/horse_0001_mask.png"
    # try:
    #     cropped = extract_obj_from_mask(sample_img, sample_mask)
    #     cropped.show()
    #     label, prob = classify_with_clip(cropped, model, preprocess, text_inputs, labels, device)
    #     print(f"üîî Detected: {label} (confidence: {prob:.2%})")
    # except Exception as e:
    #     logging.error(f"Sample test failed: {e}")

    # Batch mode (disabled): classify every image under the image root and
    # write the results to detections.csv.
    # detections = run_batch_detection("Test_img/", "test_mask/All", model, preprocess, text_inputs, labels, device)
    # save_detections(detections)

    # Video mode (active): classify frames of the given file; press "q" in the
    # display window to stop. The path is relative to the working directory.
    run_video_detection(model, preprocess, text_inputs, labels, device, "Untitled design.mp4")


INFO: üé• Starting video stream...
INFO: üëã Exiting...
