In [None]:
import json
from pathlib import Path
import sys
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from IPython.display import display

# The notebook lives one directory below the project root; the parent of the
# current working directory is put on sys.path so the `utils` package imports.
project_root = Path(".").resolve().parent
# Remove any earlier copy first: re-running this cell must not pile up
# duplicate sys.path entries, yet the root still has to end up at the front.
while str(project_root) in sys.path:
    sys.path.remove(str(project_root))
sys.path.insert(0, str(project_root))

from utils.clip import load_clip, classify_image
from utils.owl import load_owl, detect_with_owl
from utils.sam import load_sam, segment_with_sam, post_process_sam_masks
from utils.mask import create_black_white_mask
from utils.box_mask import enlarge_box, draw_annotated
from utils.opencv_canny import analyze_segmentation_edges

In [None]:
# Global detector score cut-off passed to OWL; also the fallback for labels
# that have no entry in the per-label threshold table.
owl_thresh  = 0.18
# Scale factor handed to enlarge_box before the boxes are passed to SAM.
box_scale   = 1.1

# Annotation font. The original macOS path is tried first so behaviour on the
# author's machine is unchanged; the bare file names let Pillow search the
# platform's font directories on other systems.
font_size = 50
font_candidates = ("/Library/fonts/Arial.ttf", "Arial.ttf", "DejaVuSans.ttf")
font = None
for font_candidate in font_candidates:
    try:
        font = ImageFont.truetype(font_candidate, size=font_size)
    except OSError:
        # Pillow raises OSError when the font file cannot be opened.
        continue
    break
if font is None:
    raise OSError(f"None of the candidate fonts could be loaded: {font_candidates}")

In [None]:
# Load the three models once up front. Each loader's return value is unpacked
# into a (processor, model) pair that the later cells pass around.
processor_clip, model_clip = load_clip()  # scene classification
processor_owl, model_owl = load_owl()     # text-prompted box detection
processor_sam, model_sam = load_sam()     # box-prompted segmentation
print("Models loaded")

In [None]:
image_path = Path("../pipeline_optimization_dataset/1SIka2FSC_tE6_94GW3GsRvb-Gi5CA8wa__Küche_Wohnung1.jpg")
# Image.open is lazy and keeps the underlying file open. convert() decodes the
# pixels and returns a new image, so the source handle can be closed right away.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

In [None]:
# Text queries for scenes that are not classified as indoor.
# The extra nesting level is kept as-is: a single list of queries wrapped in an
# outer list (presumably one query list per image — confirm in detect_with_owl).
text_labels_outside = [
    [
        "a house number",
        "a license plate",
        "person",
        "a face",
        "a religious symbol",
        "a political symbol",
        "a cat",
        "a dog",
    ]
]

# Text queries for scenes classified as indoor.
text_labels_inside = [
    [
        "a calendar",
        "a license plate",
        "a paper",
        "person",
        "a framed picture",
        "a picture",
        "a poster board",
        "a name",
        "a face",
        "a religious symbol",
        "a political symbol",
        "a sex toy",
        "a nude image",
        "a cat",
        "a dog",
        "a mirror",
        "a window",
        "a television",
    ]
]

# Minimum detection score per label, applied after the global OWL threshold.
per_label_thresh = {
    "a calendar": 0.20,
    "a paper": 0.20,
    "a house number": 0.21,
    "a license plate": 0.19,
    "person": 0.20,
    "a framed picture": 0.22,
    "a picture": 0.22,
    "a poster board": 0.30,
    "a name": 0.20,
    "a face": 0.20,
    "a religious symbol": 0.24,
    "a political symbol": 0.20,
    "a sex toy": 0.23,
    "a nude image": 0.30,
    "a cat": 0.28,
    "a dog": 0.28,
    "a mirror": 0.30,
    "a window": 0.30,
    "a television": 0.50,
}

# Labels missing from the table fall back to the global OWL threshold.
default_thresh = owl_thresh

In [None]:
# Image width/height; passed to enlarge_box as img_w / img_h further down.
w, h = image.size

# CLIP decides which query list is used: the indoor list only when the scene
# string is exactly "an indoor scene", the outdoor list for anything else.
inout = classify_image(image, processor_clip, model_clip)
if inout == "an indoor scene":
    labs = text_labels_inside
else:
    labs = text_labels_outside
print("Scene classification:", inout)

In [None]:
# Run OWL with the global threshold; per-label filtering happens in a later cell.
boxes_p, scores, labels = detect_with_owl(
    image,
    labs,
    processor_owl,
    model_owl,
    threshold=owl_thresh,
)

# Boxes become plain Python lists and scores plain floats for drawing.
raw_boxes = [box.tolist() for box in boxes_p]
raw_scores = [float(score.item()) for score in scores]

# Show every raw detection before any filtering is applied.
annotated_pre = draw_annotated(image.copy(), raw_boxes, raw_scores, labels, font=font)
display(annotated_pre)

In [None]:
# Score filter: keep a detection only if its score reaches the threshold for
# its label (labels without a table entry use default_thresh).
kept = []
for box, score, label in zip(raw_boxes, scores, labels):
    if score.item() >= per_label_thresh.get(label, default_thresh):
        kept.append((box, score, label))

if not kept:
    boxes_f, scores_f, labels_f = [], [], []
    print("No boxes after score filter.")
else:
    # Transpose the list of (box, score, label) triples into three lists.
    boxes_f, scores_f, labels_f = (list(column) for column in zip(*kept))

# Overlap logic
if boxes_f:
    person_or_nude = {"person", "a nude image"}
    always_drop    = {"a television", "a window", "a mirror"}

    # Pass 1: detections carrying one of these labels are never kept themselves.
    remove_idx = {idx for idx, lab in enumerate(labels_f) if lab in always_drop}

    # Pass 2: every television/window box suppresses other detections that
    # cover at least 20% of *its own* area. "person" / "a nude image"
    # detections are exempt and always survive this pass.
    indexed = list(enumerate(zip(boxes_f, labels_f)))
    for i, (box_i, lab_i) in indexed:
        if lab_i != "a television" and lab_i != "a window":
            continue
        x0, y0, x1, y1 = box_i
        Ai = max(0, x1 - x0) * max(0, y1 - y0)
        if Ai == 0:
            # Degenerate box: no area to measure overlap against.
            continue
        for j, (box_j, lab_j) in indexed:
            if j == i:
                continue
            x0j, y0j, x1j, y1j = box_j
            iw = max(0, min(x1, x1j) - max(x0, x0j))
            ih = max(0, min(y1, y1j) - max(y0, y0j))
            inter = iw * ih
            if inter == 0:
                continue
            overlap_ratio = inter / Ai
            if lab_j not in person_or_nude and overlap_ratio >= 0.20:
                remove_idx.add(j)

    # Rebuild the three parallel lists without the removed indices.
    filtered = [
        entry
        for k, entry in enumerate(zip(boxes_f, scores_f, labels_f))
        if k not in remove_idx
    ]
    if not filtered:
        boxes_f, scores_f, labels_f = [], [], []
    else:
        boxes_f, scores_f, labels_f = (list(column) for column in zip(*filtered))

# Enlarge every surviving box by box_scale; the image size is handed to
# enlarge_box alongside each box.
post_enl = []
for kept_box in boxes_f:
    post_enl.append(enlarge_box(box=kept_box, scale=box_scale, img_w=w, img_h=h))

# Show what is left after the score and overlap filters, drawn with the enlarged boxes.
post_scores = [score.item() for score in scores_f]
annotated_post = draw_annotated(image.copy(), post_enl, post_scores, labels_f, font=font)
display(annotated_post)

In [None]:
if post_enl:
    # Box prompts as a float32 tensor with one extra leading axis of size 1.
    tb = torch.tensor(post_enl, dtype=torch.float32)[None]
    outs, ins = segment_with_sam(image, tb, processor_sam, model_sam)
    # Only the first entry of the post-processed result is used
    # (presumably the masks for the single input image — confirm in utils.sam).
    masks_from_sam = post_process_sam_masks(outs, processor_sam, ins)[0]
else:
    print("No boxes to run SAM on. Stop here.")

In [None]:
if post_enl:
    # One combined black/white mask built from all SAM masks, requested with a
    # 10 px dilation.
    mask10 = create_black_white_mask(
        masks_from_sam,
        threshold=0.5,
        combine=True,
        dilation_px=10,
    )
    display(mask10)


In [None]:
# Canny edge map of the whole image, expanded to three channels so coloured
# outlines can be drawn on top of it.
gray       = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
edges      = cv2.Canny(gray, 100, 200)
# The array is shown through PIL, which reads channels as RGB. GRAY2RGB yields
# the same pixel values as GRAY2BGR for a gray image but matches how the
# colour tuples below are actually displayed.
edge_color = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

# Each mask region is grown by `margin` pixels using an elliptical kernel.
margin = 60
kernel = cv2.getStructuringElement(
    cv2.MORPH_ELLIPSE,
    (2*margin + 1, 2*margin + 1)
)

# mask10 only exists when at least one box survived filtering, so the overlay
# is guarded the same way as the other cells; without the guard this cell
# raises NameError on a fresh run with no detections.
if post_enl:
    mask_np = (np.array(mask10.convert("L")) > 127).astype(np.uint8)
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    for cnt in contours:
        # Rasterise this one contour, filled, into its own mask.
        single_mask = np.zeros_like(mask_np, dtype=np.uint8)
        cv2.drawContours(single_mask, [cnt], -1, color=1, thickness=-1)

        dilated_mask = cv2.dilate(single_mask, kernel, iterations=1)

        dilated_cnts, _ = cv2.findContours(
            dilated_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # Original mask outline: (255, 0, 0), shown as red.
        cv2.drawContours(edge_color, [cnt], -1, (255, 0, 0), 2)

        # Outline of the dilated region: (0, 255, 0), shown as green.
        cv2.drawContours(edge_color, dilated_cnts, -1, (0, 255, 0), 2)
else:
    print("No mask available; showing the plain edge map.")

display(Image.fromarray(edge_color))

In [None]:
if post_enl:
    is_outdoor = (inout != "an indoor scene")
    # Indoor labels that are inpainted without running the edge analysis.
    always_inpaint_indoor = {"person", "a religious symbol", "a political symbol"}

    # Parameters forwarded to analyze_segmentation_edges.
    edge_margin_px = 60
    edge_ratio_threshold = 0.015

    # Binarise every SAM mask at 0.5.
    # Assumes masks_from_sam is ordered like labels_f (one mask per box that
    # was passed to SAM) — TODO confirm against utils.sam.
    sam_bool_masks = [
        (np.array(m) > 0.5).astype(np.uint8) for m in masks_from_sam
    ]

    # External contours of the combined mask: one entry per connected region.
    bw10_cv = np.array(mask10.convert("L"))
    mask10_binary = (bw10_cv > 127).astype(np.uint8)
    contours, _ = cv2.findContours(
        mask10_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    print(f"Objects in 10px mask: {len(contours)}")

    classified_contours = []
    image_for_canny = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    for cnt in contours:
        # Filled mask covering only this contour.
        single_mask = np.zeros_like(mask10_binary)
        cv2.drawContours(single_mask, [cnt], -1, color=1, thickness=-1)

        # Attribute the region to the SAM mask it shares the most pixels with.
        # np.argmax returns the first maximum, so a region that overlaps no
        # mask at all is attributed to index 0.
        overlaps = [
            np.logical_and(single_mask, sam_mask).sum()
            for sam_mask in sam_bool_masks
        ]
        obj_idx = int(np.argmax(overlaps))
        lbl     = labels_f[obj_idx]

        if is_outdoor:
            # Outdoors: license plates are blurred, everything else is inpainted.
            action     = "blur" if lbl == "a license plate" else "inpaint"
            edge_ratio = float("nan")
        elif lbl in always_inpaint_indoor:
            action, edge_ratio = "inpaint", float("nan")
        else:
            # Indoors, the edge analysis decides; only its first result is used.
            result = analyze_segmentation_edges(
                image_for_canny,
                single_mask.astype(bool),
                margin=edge_margin_px,
                edge_threshold=edge_ratio_threshold
            )[0]
            action     = result["action"]
            edge_ratio = result["edge_ratio"]

        # edge_ratio is NaN whenever the edge analysis was not run.
        classified_contours.append({
            "contour": cnt,
            "action": action,
            "edge_ratio": edge_ratio
        })

In [None]:
if post_enl:
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)
    canvas_w = annotated_image.width

    for item in classified_contours:
        cnt = item['contour']
        action = item['action']
        edge_ratio = item['edge_ratio']

        # Red outline for "blur", blue for every other action.
        color = "red" if action == "blur" else "blue"
        # Each contour point is indexed with [0] to get its (x, y) pair; the
        # first point is appended again so the outline is closed.
        pts = [tuple(pt[0]) for pt in cnt]
        draw.line(pts + [pts[0]], fill=color, width=3)

        x, y = cv2.boundingRect(cnt)[:2]
        # NaN never equals itself, so the ratio is shown only when it was computed.
        label_txt = f"{action}, {edge_ratio:.3f}" if edge_ratio == edge_ratio else action
        text_w = font.getlength(label_txt)
        text_h = font.size

        # The tag normally sits directly above the bounding box. Clamp it to
        # the canvas so labels of regions near the top or right edge are not
        # drawn off-image and lost.
        tag_w = text_w + 10
        tag_h = text_h + 6
        tag_x = max(0, min(x, canvas_w - tag_w))
        tag_y = max(0, y - tag_h)

        draw.rectangle([tag_x, tag_y, tag_x + tag_w, tag_y + tag_h], fill="black")
        draw.text((tag_x + 5, tag_y + 3), label_txt, font=font, fill="white")

    display(annotated_image)
