In [None]:
import json
from pathlib import Path
import sys
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display

project_root = Path(".").resolve().parent
sys.path.insert(0, str(project_root))

from utils.clip import load_clip, classify_image
from utils.owl import load_owl, detect_with_owl
from utils.box_mask import enlarge_box, draw_annotated


In [None]:
# Global OWL score threshold: passed to detect_with_owl() and reused as the
# fallback for labels that have no entry in per_label_thresh.
owl_thresh  = 0.18
# Scale factor handed to enlarge_box() before drawing the filtered boxes.
box_scale   = 1.1

# Font used for the box annotations. The Arial path only exists on macOS, so
# fall back to Pillow's built-in font instead of aborting the whole notebook.
try:
    font = ImageFont.truetype("/Library/fonts/Arial.ttf", size=50)
except OSError:
    font = ImageFont.load_default()
    print("Arial not found; falling back to Pillow's default font.")

In [None]:
# Load both models once; every later cell reuses these four globals.
# CLIP is used for scene classification and per-crop verification,
# OWL for open-vocabulary detection.
processor_clip, model_clip = load_clip()
processor_owl,  model_owl  = load_owl()
print("Models loaded")

In [None]:
# Image under test. The path is relative to the notebook's working directory.
image_path = Path("../pipeline_optimization_dataset/1SIka2FSC_tE6_94GW3GsRvb-Gi5CA8wa__Küche_Wohnung1.jpg")
# Normalise to 3-channel RGB regardless of the file's original mode.
image = Image.open(image_path).convert("RGB")

In [None]:
# Text queries handed to the OWL detector when the scene is NOT classified as
# indoor. Nested list: presumably one query list per image in the batch
# -- TODO confirm against detect_with_owl.
text_labels_outside = [[
    "a house number", "a license plate", "person", "a face",
    "a religious symbol", "a political symbol", "a cat", "a dog",
]]
# Text queries used when the scene is classified as "an indoor scene".
text_labels_inside = [[
    "a calendar", "a license plate", "a paper", "person",
    "a framed picture", "a picture", "a poster board",
    "a name", "a face", "a religious symbol", "a political symbol",
    "a sex toy", "a nude image", "a cat", "a dog",
    "a mirror", "a window", "a television"
]]
# Minimum detection score per label, applied after detection in the score
# filter. Keys must match the query strings above exactly.
per_label_thresh = {
    "a calendar": 0.20, "a paper": 0.20, "a house number": 0.21,
    "a license plate": 0.19, "person": 0.20, "a framed picture": 0.22,
    "a picture": 0.22, "a poster board": 0.30, "a name": 0.20,
    "a face": 0.20, "a religious symbol": 0.24, "a political symbol": 0.20,
    "a sex toy": 0.23, "a nude image": 0.30, "a cat": 0.28, "a dog": 0.28,
    "a mirror": 0.30, "a window": 0.30, "a television": 0.50
}
# Labels without an entry above fall back to the global OWL threshold.
default_thresh = owl_thresh

In [None]:
# Image dimensions, later passed to enlarge_box() as img_w / img_h.
w, h  = image.size
# Scene label from CLIP. NOTE(review): only the exact string "an indoor scene"
# is tested here; a later cell compares against "an outdoor scene" -- confirm
# these are the only two values classify_image can return.
inout = classify_image(image, processor_clip, model_clip)
# Any result other than "an indoor scene" selects the outdoor query list.
labs  = text_labels_inside if inout == "an indoor scene" else text_labels_outside
print("Scene classification:", inout)

In [None]:
# Run OWL with the scene-specific queries and the global score threshold.
boxes_p, scores, labels = detect_with_owl(image, labs, processor_owl, model_owl, threshold=owl_thresh)
# Convert each box to a plain Python list (later unpacked as x0, y0, x1, y1).
raw_boxes = [b.tolist() for b in boxes_p]

# Preview of the raw detections, drawn on a copy so `image` stays untouched.
# Scores are converted to floats via .item() -- presumably 0-dim tensors.
annotated_pre = draw_annotated(image.copy(), raw_boxes, [float(s.item()) for s in scores], labels, font=font)
display(annotated_pre)

In [None]:
# --- Stage 1: per-label score filter -------------------------------------
# Keep a detection only if its score reaches the label-specific threshold
# (or the global default when the label has no dedicated entry).
kept = []
for det_box, det_score, det_label in zip(raw_boxes, scores, labels):
    min_score = per_label_thresh.get(det_label, default_thresh)
    if det_score.item() >= min_score:
        kept.append((det_box, det_score, det_label))

if not kept:
    boxes_f, scores_f, labels_f = [], [], []
    print("No boxes after score filter.")
else:
    boxes_f, scores_f, labels_f = (list(column) for column in zip(*kept))

# --- Stage 2: drop background objects and what overlaps them -------------
if boxes_f:
    protected = {"person", "a nude image"}
    background = {"a television", "a window", "a mirror"}
    occluders = ("a television", "a window")

    # Background objects themselves are never kept.
    drop = {pos for pos, name in enumerate(labels_f) if name in background}

    # Anything covering at least 20% of a television/window box is dropped
    # too, except people and nude images. The ratio is measured against the
    # television/window area, not against the overlapping box.
    for src, (src_box, src_label) in enumerate(zip(boxes_f, labels_f)):
        if src_label not in occluders:
            continue
        left, top, right, bottom = src_box
        src_area = max(0, right - left) * max(0, bottom - top)
        if src_area == 0:
            continue
        for dst, (dst_box, dst_label) in enumerate(zip(boxes_f, labels_f)):
            if dst == src:
                continue
            o_left, o_top, o_right, o_bottom = dst_box
            inter_w = max(0, min(right, o_right) - max(left, o_left))
            inter_h = max(0, min(bottom, o_bottom) - max(top, o_top))
            inter_area = inter_w * inter_h
            if (
                inter_area != 0
                and dst_label not in protected
                and inter_area / src_area >= 0.20
            ):
                drop.add(dst)

    # Select the surviving positions in their original order.
    keep_pos = [pos for pos in range(len(boxes_f)) if pos not in drop]
    boxes_f = [boxes_f[pos] for pos in keep_pos]
    scores_f = [scores_f[pos] for pos in keep_pos]
    labels_f = [labels_f[pos] for pos in keep_pos]

# Enlarge the surviving boxes for display and draw them on a fresh copy.
post_enl = [
    enlarge_box(box=kept_box, scale=box_scale, img_w=w, img_h=h)
    for kept_box in boxes_f
]
score_values = [kept_score.item() for kept_score in scores_f]

annotated_post = draw_annotated(image.copy(), post_enl, score_values, labels_f, font=font)
display(annotated_post)

In [None]:
# Detector labels whose crops are re-checked with CLIP before being kept.
# "a framed picture" is emitted by the detector and has dedicated prompts in
# get_clip_prompts(), so it must be listed here to actually be verified.
# NOTE(review): "a family picture", "a picture frame" and "a canvas" are not
# among the detector queries and have no prompts; they are kept only for
# backward compatibility.
clip_labels = {
    "a house number", "a political symbol", "a family picture", "a picture frame",
    "a canvas", "a picture", "a framed picture", "a paper",
}

In [None]:
def clip_multi(crop, prompts):
    """Score one image crop against several text prompts with CLIP.

    Uses the notebook-level ``processor_clip`` and ``model_clip``.

    Args:
        crop: The image region to classify.
        prompts: List of text prompts to compare the crop against.

    Returns:
        ``(best_index, probabilities)``: the index of the highest-scoring
        prompt and the softmax probability of every prompt as a list.
    """
    batch = processor_clip(text=prompts, images=crop, return_tensors="pt", padding=True)
    # Inference only: no gradients needed.
    with torch.no_grad():
        prompt_logits = model_clip(**batch).logits_per_image.squeeze(0)
        distribution = torch.softmax(prompt_logits, dim=0)
    best_index = int(distribution.argmax().item())
    return best_index, distribution.tolist()

In [None]:
def get_clip_prompts(label: str):
    """Return the CLIP verification prompts for a detector label.

    Args:
        label: Detector label, e.g. ``"a house number"``.

    Returns:
        ``(positive_prompts, negative_prompts)`` as two fresh lists. A crop is
        accepted when the best-matching prompt is one of the positives.
        Labels without dedicated prompts yield ``([], [])``.
    """
    # Negatives shared by both picture-like labels.
    not_a_picture = [
        "a painting", "an artwork", "a wooden board", "a blank wall",
        "a door", "a window", "a television",
    ]
    # Built per call so callers can never mutate a shared prompt list.
    prompt_table = {
        "a house number": (
            ["digits on a wall", "a house number", "a number"],
            ["a blank wall without digits", "a door or window", "decoration", "a street sign with text"],
        ),
        "a political symbol": (
            ["a flag", "a political symbol"],
            ["a traffic sign", "an advertisement poster"],
        ),
        "a framed picture": (
            ["a framed picture", "a photograph"],
            list(not_a_picture),
        ),
        "a picture": (
            ["a picture", "a photograph", "a framed picture"],
            list(not_a_picture),
        ),
        "a paper": (
            ["a paper"],
            ["a book", "a folder"],
        ),
    }
    return prompt_table.get(label, ([], []))

In [None]:
# One record per detection that survived BOTH the score filter and the
# overlap / always-drop filter. Building this from `kept` would bring back the
# televisions, windows, mirrors and overlapping boxes removed in that cell.
# Scores are stored as plain floats, matching what the other draw_annotated
# calls receive.
kept_dicts = [
    {"box": box, "label": label, "score": float(score)}
    for box, score, label in zip(boxes_f, scores_f, labels_f)
]

In [None]:
# Labels that are only CLIP-verified when the scene is outdoors; indoors they
# pass through unchecked.
outdoor_only_labels = {"a house number", "a political symbol"}

kept_clip = []
for det in kept_dicts:
    lbl = det["label"]
    det["clip_pass"] = True

    # Crop once per detection. Previously `crop` was never reset, so a
    # detection that skipped CLIP either raised NameError (first iteration)
    # or displayed the crop of the previous detection.
    x0, y0, x1, y1 = map(int, det["box"])
    crop = image.crop((x0, y0, x1, y1))

    needs_clip = lbl in clip_labels and (
        lbl not in outdoor_only_labels or inout == "an outdoor scene"
    )
    if needs_clip:
        pos, neg = get_clip_prompts(lbl)
        prompts = pos + neg
        # Labels without prompts are accepted as-is.
        if prompts:
            win_idx, _ = clip_multi(crop, prompts)
            # Positives come first, so a winning index below len(pos)
            # means the crop matched a positive prompt.
            det["clip_pass"] = win_idx < len(pos)

    if det["clip_pass"]:
        display(crop)
        kept_clip.append(det)

In [None]:
# Final result: only detections that passed CLIP verification.
if kept_clip:
    # Draw on a copy, like every earlier draw_annotated call, so that `image`
    # stays clean if draw_annotated draws in place and cells are re-run.
    annotated_clip = draw_annotated(
        image.copy(),
        [d["box"] for d in kept_clip],
        [d["score"] for d in kept_clip],
        [d["label"] for d in kept_clip],
        font=font,
    )
    display(annotated_clip)