In [3]:
import os
import torch
import clip
import numpy as np
from PIL import Image
import cv2
from tqdm import tqdm
from segment_anything import sam_model_registry, SamPredictor
import torchvision.transforms as T

# ---- Configuration -------------------------------------------------------
# Both models below are placed on this device.
device = "cpu"

# Reference inputs: an image and its mask, plus the folder of candidate
# images that will be scored and segmented.
itemA = "A/2008_000336.jpg"
maskA = "A/2008_000336.png"
itemC = "C"

# Output folder for the saved images and masks; created if it is missing.
os.makedirs(name="itemCPrime", exist_ok=True)

# ---- Models --------------------------------------------------------------
# CLIP ViT-B/32 encoder together with its matching preprocessing transform.
clipModel, preprocess = clip.load("ViT-B/32", device=device)

# SAM ViT-B built from a local checkpoint, moved to `device`, then wrapped in
# a predictor that supports point prompts.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
sam.to(device)
samPredictor = SamPredictor(sam)

# ---- Reference embedding: CLIP features of the masked reference object ----
ref_image = Image.open(itemA).convert("RGB")
ref_mask = Image.open(maskA)

# Masks are applied per pixel, so they must match the image size. NEAREST
# keeps label values intact (no interpolated in-between ids).
if ref_mask.size != ref_image.size:
    ref_mask = ref_mask.resize(ref_image.size, Image.NEAREST)

if ref_mask.mode == "P":
    # Indexed (palette) masks store a class id per pixel. convert("L") would
    # map each id through its palette *colour* to a luminance value, so a
    # fixed brightness threshold would select the wrong pixels. Use the raw
    # indices instead: every id other than 0 and 255 counts as foreground.
    # NOTE(review): 0 = background / 255 = "void" border is the PASCAL VOC
    # convention the file name suggests; confirm for this dataset.
    ref_mask_idx = np.array(ref_mask)
    ref_mask_np = ((ref_mask_idx != 0) & (ref_mask_idx != 255)).astype(np.uint8)
else:
    # Grey-scale / RGB masks: bright pixels (> 128) are foreground.
    ref_mask_np = (np.array(ref_mask.convert("L")) > 128).astype(np.uint8)

if not ref_mask_np.any():
    print(f"Warning: reference mask {maskA} has no foreground pixels; "
          "the reference embedding will describe an all-black image.")

# Black out everything outside the object so CLIP embeds only the object.
ref_image_np = np.array(ref_image)
ref_image_np[ref_mask_np == 0] = 0
masked_ref_image = Image.fromarray(ref_image_np)

ref_tensor = preprocess(masked_ref_image).unsqueeze(0).to(device)
with torch.no_grad():
    ref_feat = clipModel.encode_image(ref_tensor)
    # Unit length, so later dot products are cosine similarities.
    ref_feat /= ref_feat.norm(dim=-1, keepdim=True)

# ---- Text embedding of the target description read from itemB.txt ----
# "utf-8-sig" decodes plain UTF-8 and also drops a leading BOM, which
# str.strip() would not remove.
with open("itemB.txt", "r", encoding="utf-8-sig") as f:
    text_description = f.read().strip()

# CLIP's text context is 77 tokens. Without truncate=True, clip.tokenize
# raises RuntimeError for a longer description instead of clipping it.
text_token = clip.tokenize([text_description], truncate=True).to(device)

with torch.no_grad():
    text_feat = clipModel.encode_text(text_token)
    # Unit length, so later dot products are cosine similarities.
    text_feat /= text_feat.norm(dim=-1, keepdim=True)


# ---- Score every image in itemC, segment it with SAM, save image + mask ----
# Images whose averaged (image, text) cosine similarity falls below this value
# get an all-zero mask and are not passed through SAM.
SIM_THRESHOLD = 0.35
output_dir = "itemCPrime"

# sorted() gives a deterministic processing order; os.listdir order is arbitrary.
for img_name in tqdm(sorted(os.listdir(itemC))):
    if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue

    img_path = os.path.join(itemC, img_name)
    output_img_path = os.path.join(output_dir, img_name)
    output_mask_path = os.path.join(output_dir, img_name.rsplit(".", 1)[0] + "_mask.png")

    try:
        # ---- Load and preprocess image ----
        image_pil = Image.open(img_path).convert("RGB")
        image_tensor = preprocess(image_pil).unsqueeze(0).to(device)

        with torch.no_grad():
            image_feat = clipModel.encode_image(image_tensor)
            image_feat /= image_feat.norm(dim=-1, keepdim=True)

        # ---- Cosine similarity to the reference object and to the text ----
        sim_img = (image_feat @ ref_feat.T).item()
        sim_txt = (image_feat @ text_feat.T).item()
        avg_sim = (sim_img + sim_txt) / 2

        w, h = image_pil.size

        if avg_sim >= SIM_THRESHOLD:
            # Give SAM exactly the pixels that are saved below. Re-reading the
            # file with cv2.imread would apply EXIF orientation, which PIL
            # does not, so the mask could end up rotated relative to the saved
            # image. It would also decode the file a second time.
            samPredictor.set_image(np.array(image_pil))  # H x W x 3 uint8, RGB

            # Single positive click at the image centre.
            # NOTE(review): assumes the object of interest is roughly centred.
            point = np.array([[w // 2, h // 2]])
            label = np.array([1])

            masks, scores, _ = samPredictor.predict(
                point_coords=point,
                point_labels=label,
                multimask_output=True,
            )
            # SAM returns several candidate masks for an ambiguous prompt;
            # keep the one with the highest predicted quality score.
            final_mask = (masks[int(np.argmax(scores))] * 255).astype(np.uint8)
        else:
            # Not similar enough: skip the expensive SAM pass, save a blank mask.
            final_mask = np.zeros((h, w), dtype=np.uint8)

        # ---- Save image and corresponding mask ----
        image_pil.save(output_img_path)
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(output_mask_path, final_mask):
            raise IOError(f"cv2.imwrite failed for {output_mask_path}")

    except Exception as e:
        # Per-image boundary: report the failure and continue with the next file.
        print(f"Error processing {img_name}: {type(e).__name__}: {e}")

100%|██████████| 786/786 [1:16:04<00:00,  5.81s/it]
