In [None]:
from pathlib import Path
import sys
sys.path.insert(0, str(Path.cwd().parent))
from pathlib import Path
from typing import List
import torch
import numpy as np
from PIL import Image, ImageFont, ImageDraw
from IPython.display import display
import cv2  
from utils.clip import load_clip, classify_image
from utils.owl import load_owl, detect_with_owl
from utils.blip import load_blip, classify_boxes
from utils.box_mask import enlarge_box, draw_annotated

In [None]:
# Passed as `threshold` to detect_with_owl and reused as the fallback per-label threshold.
owl_thresh = 0.18
# Scale factor passed to enlarge_box for every detection that survives filtering.
box_scale = 1.1

# Annotation font. The path is the macOS location of Arial; it does not exist on
# Linux/Windows, so fall back to Pillow's bundled font instead of aborting the notebook.
font_path = "/Library/fonts/Arial.ttf"
font_size = 50
try:
    font = ImageFont.truetype(font_path, size=font_size)
except OSError:
    try:
        # Pillow >= 10.1 can render its bundled font at an arbitrary size.
        font = ImageFont.load_default(size=font_size)
    except TypeError:
        # Older Pillow: fixed-size bitmap default font (no `size` parameter).
        font = ImageFont.load_default()

In [None]:
# Load every model once; later cells only reuse these processor/model handles.
(processor_clip, model_clip) = load_clip()   # indoor/outdoor scene classification (classify_image)
(processor_owl, model_owl) = load_owl()      # text-prompted box detection (detect_with_owl)
(processor_blip, model_blip) = load_blip()   # per-box action choice for indoor scenes (classify_boxes)
print("Models loaded")

In [None]:
image_path = Path("../pipeline_optimization_dataset/1CzPt_pZhdvbAhtM32ImlzZxtqsaFQ76j__Gewerberaum_III_UG.jpg")
# Decode into an in-memory RGB image; the context manager closes the source file right away.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

In [None]:
# Text prompts for OWL. The extra outer list is what detect_with_owl receives
# (presumably one inner prompt list per image — TODO confirm in utils.owl).
text_labels_outside = [
    [
        "a house number",
        "a license plate",
        "person",
        "a face",
        "a religious symbol",
        "a political symbol",
        "a cat",
        "a dog",
    ]
]
text_labels_inside = [
    [
        "a calendar",
        "a license plate",
        "a paper",
        "person",
        "a framed picture",
        "a picture",
        "a poster board",
        "a name",
        "a face",
        "a religious symbol",
        "a political symbol",
        "a sex toy",
        "a nude image",
        "a cat",
        "a dog",
        "a mirror",
        "a window",
        "a television",
    ]
]

# Minimum OWL score a detection must reach to be kept, per prompt.
per_label_thresh = {
    "a calendar": 0.20,
    "a paper": 0.20,
    "a house number": 0.21,
    "a license plate": 0.19,
    "person": 0.20,
    "a framed picture": 0.22,
    "a picture": 0.22,
    "a poster board": 0.30,
    "a name": 0.20,
    "a face": 0.20,
    "a religious symbol": 0.24,
    "a political symbol": 0.20,
    "a sex toy": 0.23,
    "a nude image": 0.30,
    "a cat": 0.28,
    "a dog": 0.28,
    "a mirror": 0.30,
    "a window": 0.30,
    "a television": 0.50,
}
# Prompts missing from the table above fall back to the OWL detection threshold.
default_thresh = owl_thresh

In [None]:
# Image size, reused later to clamp enlarged boxes.
w, h = image.size
inout = classify_image(image, processor_clip, model_clip)
# Indoor scenes get the larger prompt list (documents, pictures, screens, ...).
if inout == "an indoor scene":
    labs = text_labels_inside
else:
    labs = text_labels_outside
print("Scene classification:", inout)

In [None]:
boxes_p, scores, labels = detect_with_owl(
    image, labs, processor_owl, model_owl, threshold=owl_thresh
)
# Convert each detected box to a plain Python list of (x0, y0, x1, y1).
raw_boxes = [box.tolist() for box in boxes_p]

# Preview every detection returned by OWL, before per-label filtering.
annotated_pre = draw_annotated(
    image.copy(),
    raw_boxes,
    [float(score.item()) for score in scores],
    labels,
    font=font,
)
display(annotated_pre)

In [None]:
def _meets_label_threshold(score, label) -> bool:
    """Return True if ``score`` (an object with ``.item()``) reaches the threshold for ``label``.

    Labels without an entry in ``per_label_thresh`` fall back to ``default_thresh``.
    """
    return score.item() >= per_label_thresh.get(label, default_thresh)


def _unzip_detections(triples):
    """Split a list of (box, score, label) triples into three parallel lists.

    An empty input yields three empty lists, because ``zip(*[])`` cannot be
    unpacked into three names.
    """
    if not triples:
        return [], [], []
    box_list, score_list, label_list = map(list, zip(*triples))
    return box_list, score_list, label_list


def _indices_to_drop(boxes, labels, always_drop, suppressors, protected, min_overlap=0.20):
    """Return the set of detection indices to discard.

    * Every detection whose label is in ``always_drop`` is discarded.
    * For each detection labelled with one of ``suppressors`` and with non-zero area,
      every *other* detection whose intersection with it is at least ``min_overlap``
      of the suppressor's own area is discarded, unless its label is in ``protected``.

    Boxes are (x0, y0, x1, y1) corners.

    NOTE(review): the overlap is normalised by the suppressor's (TV/window) area, so a
    small object lying entirely inside a large window is kept, while a large box that
    merely covers the window is dropped. Confirm this is intended rather than
    normalising by the other box's area.
    """
    drop = {k for k, lab in enumerate(labels) if lab in always_drop}
    for i, (box_i, lab_i) in enumerate(zip(boxes, labels)):
        if lab_i not in suppressors:
            continue
        x0, y0, x1, y1 = box_i
        area_i = max(0, x1 - x0) * max(0, y1 - y0)
        if area_i == 0:
            continue
        for j, (box_j, lab_j) in enumerate(zip(boxes, labels)):
            # Protected labels (people, nudity) are never suppressed by a TV/window.
            if j == i or lab_j in protected:
                continue
            x0j, y0j, x1j, y1j = box_j
            inter_w = max(0, min(x1, x1j) - max(x0, x0j))
            inter_h = max(0, min(y1, y1j) - max(y0, y0j))
            inter = inter_w * inter_h
            if inter > 0 and inter / area_i >= min_overlap:
                drop.add(j)
    return drop


# 1) Per-label score filter.
kept = [
    (box, score, lab)
    for box, score, lab in zip(raw_boxes, scores, labels)
    if _meets_label_threshold(score, lab)
]
boxes_f, scores_f, labels_f = _unzip_detections(kept)
if not kept:
    print("No boxes after score filter.")

# 2) Drop screens/windows/mirrors and whatever heavily overlaps a TV or window.
if boxes_f:
    person_or_nude = {"person", "a nude image"}
    always_drop    = {"a television", "a window", "a mirror"}
    remove_idx = _indices_to_drop(
        boxes_f,
        labels_f,
        always_drop=always_drop,
        suppressors=("a television", "a window"),
        protected=person_or_nude,
        min_overlap=0.20,
    )
    filtered = [
        det
        for k, det in enumerate(zip(boxes_f, scores_f, labels_f))
        if k not in remove_idx
    ]
    boxes_f, scores_f, labels_f = _unzip_detections(filtered)

# 3) Enlarge the surviving boxes by box_scale (image size w, h passed for clamping).
post_enl = [enlarge_box(box=b, scale=box_scale, img_w=w, img_h=h) for b in boxes_f]

annotated_post = draw_annotated(
    image.copy(), post_enl, [float(s.item()) for s in scores_f], labels_f, font=font
)
display(annotated_post)

In [None]:
def expand_box(box, scale, img_w, img_h, min_crop_size=None):
    """Scale ``box`` about its centre, optionally enforce a minimum size, and clip it.

    Args:
        box: (x0, y0, x1, y1) corners.
        scale: multiplier applied to the box width and height.
        img_w, img_h: image size; the result is clipped to [0, img_w] x [0, img_h].
        min_crop_size: if given, width and height are raised to at least this value
            *before* clipping, so boxes at the image border can still end up smaller.

    Returns:
        [x0, y0, x1, y1] as rounded ints.
    """
    x0, y0, x1, y1 = box
    w_box = x1 - x0
    h_box = y1 - y0
    cx = x0 + w_box / 2.0
    cy = y0 + h_box / 2.0

    new_w = w_box * scale
    new_h = h_box * scale
    if min_crop_size is not None:
        if new_w < min_crop_size:
            new_w = min_crop_size
        if new_h < min_crop_size:
            new_h = min_crop_size

    x0n = max(0, int(round(cx - new_w / 2.0)))
    y0n = max(0, int(round(cy - new_h / 2.0)))
    x1n = min(img_w, int(round(cx + new_w / 2.0)))
    y1n = min(img_h, int(round(cy + new_h / 2.0)))
    return [x0n, y0n, x1n, y1n]


# Crop parameters handed to classify_boxes and reused for the crop preview below.
expand_factor = 1.4
min_crop_size = 250

if not post_enl:
    print("No boxes to classify with BLIP.")
elif inout != "an indoor scene":
    # Outdoor heuristic: license plates are blurred, everything else is inpainted.
    actions: List[str] = [
        "blur" if lbl in ("a license plate", "license plate") else "inpaint"
        for lbl in labels_f
    ]

    print("Actions (outdoor heuristic):")
    for lbl, act in zip(labels_f, actions):
        print(f"{lbl:>18} -> {act}")
else:
    img_w, img_h = image.size

    # People are always inpainted; every other box is sent to BLIP.
    blip_indices = [i for i, lbl in enumerate(labels_f) if lbl != "person"]
    blip_boxes = [post_enl[i] for i in blip_indices]
    # Preview-only crops.
    # NOTE(review): assumes classify_boxes crops the same region for these
    # expand_factor/min_crop_size values — confirm in utils.blip.
    expanded_crops = [
        expand_box(box, expand_factor, img_w, img_h, min_crop_size=min_crop_size)
        for box in blip_boxes
    ]
    # Seed with the fallback so the list never holds placeholders: any box BLIP
    # leaves undecided (returns None for, or omits from a short result) is blurred.
    actions = ["inpaint" if lbl == "person" else "blur" for lbl in labels_f]

    if blip_boxes:
        blip_results = classify_boxes(
            image,
            blip_boxes,
            processor_blip,
            model_blip,
            expand_factor=expand_factor,
            min_crop_size=min_crop_size,
        )
        for idx_rel, act in enumerate(blip_results):
            idx_abs = blip_indices[idx_rel]
            actions[idx_abs] = "blur" if act is None else act

        print(f"Displaying {len(expanded_crops)} expanded (x{expand_factor}) BLIP crops:")
        for orig_idx, exp_box in zip(blip_indices, expanded_crops):
            act = actions[orig_idx]
            crop_img = image.crop(tuple(exp_box))
            print(f"Box #{orig_idx} label={labels_f[orig_idx]} action={act} expanded_box={exp_box}")
            display(crop_img)

    print("Final actions (indoor):")
    for lbl, act in zip(labels_f, actions):
        print(f"{lbl:>18} -> {act}")


In [None]:
if post_enl:
    class_annot = image.copy()
    draw_cls = ImageDraw.Draw(class_annot)
    canvas_w, _canvas_h = class_annot.size

    for (x0, y0, x1, y1), lbl, act in zip(post_enl, labels_f, actions):
        # Blue = inpaint/free, red = anything else (e.g. blur).
        color = "blue" if act in ("inpaint", "free") else "red"
        draw_cls.rectangle([x0, y0, x1, y1], outline=color, width=4)

        tag = f"{lbl}, {act}"
        tw, th = draw_cls.textbbox((0, 0), tag, font=font)[2:]
        # Put the tag just above the box when it fits, otherwise just inside the top
        # edge; enlarged boxes clamped to y0 == 0 would otherwise get an invisible tag.
        label_top = y0 - th - 4
        if label_top < 0:
            label_top = y0
        # Shift left if the tag would run past the right image border.
        label_left = max(0, min(x0, canvas_w - tw - 4))

        draw_cls.rectangle(
            [label_left, label_top, label_left + tw + 4, label_top + th + 4], fill="white"
        )
        draw_cls.text((label_left + 2, label_top + 2), tag, fill=color, font=font)

    display(class_annot)