In [None]:
import sys, tempfile
from pathlib import Path
import json
import numpy as np
import torch
import cv2
from PIL import Image, ImageFont, ImageDraw
from IPython.display import display

# Assumes the kernel's working directory is one level below the project root,
# which holds the `utils` and `lama` packages imported below. The membership
# guard keeps re-runs of this cell from piling duplicate entries onto sys.path.
project_root = Path(".").resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from utils.clip import load_clip, classify_image
from utils.owl  import load_owl, detect_with_owl
from utils.sam  import load_sam, segment_with_sam, post_process_sam_masks
from utils.mask import create_black_white_mask
from utils.box_mask import enlarge_box, draw_annotated
from utils.opencv_canny import analyze_segmentation_edges
from utils.bluring import adaptive_blur
from utils.resize   import resize_long_side
from lama.runner   import run_lama

In [None]:
# LaMa checkpoint directory, located under the project root.
lama_model_dir = project_root.joinpath("lama", "big-lama")

In [None]:
# Source photo to anonymize (path is relative to the kernel's working directory).
image_path = Path("../pipeline_optimization_dataset/1-_Lxb7M5NPRSYioQ0f8KlvA_Gqr7HW1i___Wohnzimmer1.jpg")
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")
# Original (width, height); passed to enlarge_box as the image bounds.
w, h = image.size

In [None]:
# --- Detection / mask configuration ---------------------------------------
owl_thresh = 0.18            # score floor passed to detect_with_owl
target_long = 2048           # long side (px) of the final output image
box_scale = 1.1              # enlargement factor passed to enlarge_box before SAM
dilations = (0, 10, 30, 55)  # mask dilations (px) rendered for inspection

# OWL text queries; the outer list is the per-image batch dimension.
text_labels_outside = [[
    "a house number", "a license plate", "person", "a face",
    "a religious symbol", "a political symbol", "a cat", "a dog",
]]
text_labels_inside = [[
    "a calendar", "a license plate", "a paper", "person",
    "a framed picture", "a picture", "a poster board",
    "a name", "a face", "a religious symbol", "a political symbol",
    "a sex toy", "a nude image", "a cat", "a dog",
    "a mirror", "a window", "a television"
]]
# Per-label acceptance thresholds applied after detection; labels without an
# entry fall back to default_thresh.
per_label_thresh = {
    "a calendar": 0.20, "a paper": 0.20, "a house number": 0.21,
    "a license plate": 0.19, "person": 0.20, "a framed picture": 0.22,
    "a picture": 0.22, "a poster board": 0.30, "a name": 0.20,
    "a face": 0.20, "a religious symbol": 0.24, "a political symbol": 0.20,
    "a sex toy": 0.23, "a nude image": 0.30, "a cat": 0.28, "a dog": 0.28,
    "a mirror": 0.30, "a window": 0.30, "a television": 0.50
}
default_thresh = owl_thresh
# Indoor labels that are always inpainted without running the edge analysis.
always_inpaint_indoor = {"person", "a religious symbol", "a political symbol"}

# Annotation font. The original macOS path is tried first, then names Pillow can
# resolve from the platform font directories, then Pillow's built-in font.
# NOTE(review): on older Pillow the built-in fallback is a bitmap font without a
# `.size` attribute, which the overlay cell reads — install a TrueType font there.
font = None
for font_candidate in ("/Library/Fonts/Arial.ttf", "Arial.ttf", "arial.ttf", "DejaVuSans.ttf"):
    try:
        font = ImageFont.truetype(font_candidate, size=50)
        break
    except OSError:
        continue
if font is None:
    font = ImageFont.load_default()

In [None]:
# Load CLIP (scene classifier), OWL (open-vocabulary detector) and SAM (segmenter).
# Each loader returns a (processor, model) pair; tuple elements evaluate left to right.
(processor_clip, model_clip), (processor_owl, model_owl), (processor_sam, model_sam) = (
    load_clip(),
    load_owl(),
    load_sam(),
)
print("Models loaded.\n")

In [None]:
# Indoor vs. outdoor decides which label vocabulary OWL is queried with.
inout = classify_image(image, processor_clip, model_clip)
if inout == "an indoor scene":
    labs = text_labels_inside
else:
    labs = text_labels_outside
print(f"Scene classification: {inout}")

In [None]:
# Open-vocabulary detection with the scene-appropriate label set.
boxes_p, scores, labels = detect_with_owl(
    image, labs, processor_owl, model_owl, threshold=owl_thresh
)
# Boxes as plain Python lists, unpacked later as (x0, y0, x1, y1).
raw_boxes = [box.tolist() for box in boxes_p]

In [None]:
# Pre-filter view: every raw detection with its score and label.
annot_pre = draw_annotated(
    image.copy(),
    raw_boxes,
    [float(score.item()) for score in scores],
    labels,
    font=font,
)
display(annot_pre)

In [None]:
def _intersection_area(box_a, box_b):
    """Area shared by two (x0, y0, x1, y1) boxes; 0 when they do not overlap.

    Passing the same box twice yields its own (non-negative) area.
    """
    ax0, ay0, ax1, ay1 = box_a
    bx0, by0, bx1, by1 = box_b
    overlap_w = max(0, min(ax1, bx1) - max(ax0, bx0))
    overlap_h = max(0, min(ay1, by1) - max(ay0, by0))
    return overlap_w * overlap_h


# 1) Per-label score filter: each detection must clear its label's threshold.
kept = []
for box, score, label in zip(raw_boxes, scores, labels):
    label_floor = per_label_thresh.get(label, default_thresh)
    if score.item() >= label_floor:
        kept.append((box, score, label))

if not kept:
    boxes_f, scores_f, labels_f = [], [], []
    print("no boxes survive score filter; skipping.\n")
else:
    boxes_f, scores_f, labels_f = (list(column) for column in zip(*kept))

# 2) Screen/window suppression. TV, window and mirror detections are removed
#    from the list. Any other detection (except person / nude image) whose box
#    covers at least 20% of a TV or window box's area is removed as well.
#    NOTE(review): the ratio is normalized by the TV/window area, not by the
#    other box's area, so a small object fully inside a large window is kept;
#    mirrors are dropped but do not suppress overlapping boxes. Confirm intended.
if boxes_f:
    person_or_nude = {"person", "a nude image", "nude image"}
    always_drop    = {"a television", "a window", "a mirror"}
    screen_labels  = ("a television", "a window")
    remove_idx = {idx for idx, lab in enumerate(labels_f) if lab in always_drop}

    for i, (box_i, lab_i) in enumerate(zip(boxes_f, labels_f)):
        if lab_i not in screen_labels:
            continue
        screen_area = _intersection_area(box_i, box_i)
        if screen_area == 0:
            continue
        for j, (box_j, lab_j) in enumerate(zip(boxes_f, labels_f)):
            if j == i or lab_j in person_or_nude:
                continue
            if _intersection_area(box_i, box_j) / screen_area >= 0.20:
                remove_idx.add(j)

    filtered = [
        triple
        for k, triple in enumerate(zip(boxes_f, scores_f, labels_f))
        if k not in remove_idx
    ]
    if not filtered:
        boxes_f, scores_f, labels_f = [], [], []
    else:
        boxes_f, scores_f, labels_f = (list(column) for column in zip(*filtered))

# 3) Enlarge the surviving boxes (bounded by the image size) before SAM.
post_enl = [enlarge_box(box=box, scale=box_scale, img_w=w, img_h=h) for box in boxes_f]

In [None]:
# Post-filter view: surviving detections after box enlargement.
post_scores = [float(score.item()) for score in scores_f]
annot_post = draw_annotated(image.copy(), post_enl, post_scores, labels_f, font=font)
display(annot_post)

In [None]:
# Segment every surviving (enlarged) box with SAM.
# With zero boxes, torch.tensor([]).unsqueeze(0) would be a degenerate (1, 0)
# prompt tensor, so skip SAM and continue with empty mask lists instead.
if post_enl:
    # Presumably (1, num_boxes, 4): one image with all of its box prompts.
    tb = torch.tensor(post_enl, dtype=torch.float32).unsqueeze(0)
    outs, ins = segment_with_sam(image, tb, processor_sam, model_sam)
    masks_from_sam = post_process_sam_masks(outs, processor_sam, ins)[0]
    sam_bool_masks = [(np.array(m) > 0.5).astype(np.uint8) for m in masks_from_sam]
else:
    outs, ins = None, None
    masks_from_sam = []
    sam_bool_masks = []
    print("No boxes to segment; skipping SAM.")

In [None]:
# Render the combined SAM mask at each configured dilation for inspection.
# When there are no SAM masks (nothing detected), emit an all-black mask at the
# image size explicitly rather than relying on create_black_white_mask to cope
# with an empty list. `len` is used because the masks may be a tensor, whose
# truth value is ambiguous.
bw_per_d = {}
for dpx in dilations:
    if len(masks_from_sam) == 0:
        bw = Image.new("L", (w, h), 0)
    else:
        bw = create_black_white_mask(masks_from_sam, threshold=0.5, combine=True, dilation_px=dpx)
    bw_per_d[dpx] = bw
    print(f"Dilation {dpx}px")
    display(bw)

In [None]:
# Binarize the 10px-dilated mask and trace its outer contours; each external
# contour is treated as one region in the classification step.
bw10_cv = np.array(bw_per_d[10].convert("L"))
mask10_binary = np.where(bw10_cv > 127, 1, 0).astype(np.uint8)
contours, _hierarchy = cv2.findContours(
    mask10_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)

In [None]:
# Decide per region whether to blur or inpaint:
#  - outdoor: license plates are blurred, everything else is inpainted;
#  - indoor: labels in always_inpaint_indoor are inpainted outright, the rest go
#    through analyze_segmentation_edges, which returns the action.
# A region takes the label of the SAM mask it overlaps most.

def _label_for_region(region_mask):
    """Label of the detection whose SAM mask overlaps `region_mask` most."""
    if not labels_f:
        return "unknown"
    overlap_px = [np.logical_and(region_mask, sam_mask).sum() for sam_mask in sam_bool_masks]
    best = int(np.argmax(overlap_px)) if overlap_px else 0
    return labels_f[best]


def _decide_action(region_mask, region_label):
    """Return (action, edge_ratio); edge_ratio is NaN when no edge analysis ran."""
    if is_outdoor:
        return ("blur" if region_label == "a license plate" else "inpaint"), float("nan")
    if region_label in always_inpaint_indoor:
        return "inpaint", float("nan")
    analysis = analyze_segmentation_edges(
        image_for_canny, region_mask.astype(bool), margin=60, edge_threshold=0.015
    )[0]
    return analysis["action"], analysis["edge_ratio"]


classified_contours = []
image_for_canny = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
is_outdoor = (inout != "an indoor scene")

for i, cnt in enumerate(contours):
    # Filled single-region mask rebuilt from the contour.
    single_mask = np.zeros_like(mask10_binary)
    cv2.drawContours(single_mask, [cnt], -1, color=1, thickness=-1)

    lbl = _label_for_region(single_mask)
    action, edge_ratio = _decide_action(single_mask, lbl)

    classified_contours.append({"contour": cnt, "action": action, "edge_ratio": edge_ratio})
    er_str = "—" if np.isnan(edge_ratio) else f"{edge_ratio:.3f}"
    print(f"Object {i+1}/{len(contours)}: {action} (edge_ratio={er_str})")

In [None]:
# Overlay every classified region on the original image: red outline for
# "inpaint", blue for any other action, plus an "action, edge_ratio" tag.
if post_enl:
    annotated_image = image.copy()
    draw = ImageDraw.Draw(annotated_image)
    pad_x, pad_y = 5, 3

    for item in classified_contours:
        cnt        = item['contour']
        action     = item['action']
        edge_ratio = item['edge_ratio']

        color = "red" if action == "inpaint" else "blue"

        # OpenCV contour points are indexed as pt[0] -> (x, y); PIL wants a list of tuples.
        pts = [tuple(pt[0]) for pt in cnt]
        if len(pts) >= 2:
            draw.line(pts + [pts[0]], fill=color, width=3)  # repeat first point to close the outline

        x, y = cv2.boundingRect(cnt)[:2]
        # NaN != NaN, so this omits the ratio when no edge analysis ran.
        label_txt = f"{action}, {edge_ratio:.3f}" if edge_ratio == edge_ratio else action
        text_w = font.getlength(label_txt)
        text_h = font.size
        box_h = text_h + 2 * pad_y
        # Put the tag just above the region's bounding box, but never above the
        # image's top edge, where it would be drawn off-canvas and be invisible.
        tag_top = max(0, y - box_h)
        draw.rectangle([x, tag_top, x + text_w + 2 * pad_x, tag_top + box_h], fill="black")
        draw.text((x + pad_x, tag_top + pad_y), label_txt, font=font, fill="white")

    print("Action overlay (red=inpaint, blue=blur):")
    display(annotated_image)
else:
    print("No post_enl boxes; skipping overlay.")


In [None]:
# Square erosion kernel whose radius matches the 10px dilation of the mask the
# contours were traced on (bw_per_d[10]). Eroding a region by that radius
# approximates its undilated (0px) shape. The per-dilation output masks are
# created in the next cell.
# NOTE(review): the recovery is approximate — holes filled when the contour is
# redrawn stay filled, and it assumes create_black_white_mask dilates with a
# square kernel; confirm.
erosion_radius_px = 10
erosion_kernel = np.ones((2 * erosion_radius_px + 1, 2 * erosion_radius_px + 1), np.uint8)

In [None]:
# Build the final blur / inpaint masks at full image resolution.
# Each region is redrawn from its contour and eroded back to (approximately)
# its undilated shape. Then:
#   - "blur" regions contribute their 0px shape to final_blur_masks[0];
#   - every other action contributes its shape at each needed dilation
#     (0px, plus 30px outdoors / 55px indoors) to final_inpnt_masks.
needed_dilations = {0, (30 if is_outdoor else 55)}
final_blur_masks  = {d: np.zeros_like(mask10_binary, dtype=np.uint8) for d in needed_dilations}
final_inpnt_masks = {d: np.zeros_like(mask10_binary, dtype=np.uint8) for d in needed_dilations}
# Dilation kernels depend only on the radius, so build them once rather than per region.
dilation_kernels = {
    dpx: np.ones((2 * dpx + 1, 2 * dpx + 1), np.uint8)
    for dpx in needed_dilations
    if dpx != 0
}

for item in classified_contours:
    component_10px = np.zeros_like(mask10_binary)
    cv2.drawContours(component_10px, [item['contour']], -1, color=1, thickness=-1)
    component_0px = cv2.erode(component_10px, erosion_kernel, iterations=1)

    if item['action'] == "blur":
        # Blur only ever uses the 0px shape; skip the (large-kernel) dilations.
        final_blur_masks[0] = np.maximum(final_blur_masks[0], component_0px)
        continue

    for dpx in needed_dilations:
        if dpx == 0:
            final_shape = component_0px
        else:
            final_shape = cv2.dilate(component_0px, dilation_kernels[dpx], iterations=1)
        final_inpnt_masks[dpx] = np.maximum(final_inpnt_masks[dpx], final_shape)


In [None]:
# Which dilation of each final mask is used downstream: the inpaint mask sent
# to LaMa gets a wider margin indoors; blur uses the undilated (0px) regions.
blur_dpx = 0
paint_dpx = 30 if inout != "an indoor scene" else 55

In [None]:
# Working canvas at the output resolution, in OpenCV's BGR channel order.
img_rs = resize_long_side(image, target_long)
rgb_canvas = np.array(img_rs)
final_bgr = cv2.cvtColor(rgb_canvas, cv2.COLOR_RGB2BGR)

In [None]:
# Inpaint the selected regions with LaMa. This is best effort: if LaMa raises or
# produces no readable output, the resized original canvas is kept.
with tempfile.TemporaryDirectory() as tmp_dir_name:
    tmp_dir  = Path(tmp_dir_name)
    tmp_img  = tmp_dir / "img.png"
    tmp_mask = tmp_dir / "mask.png"
    cv2.imwrite(str(tmp_img), final_bgr)

    pm = final_inpnt_masks.get(paint_dpx)
    if pm is None or not pm.any():
        print("No inpaint mask to apply.")
    else:
        # 0/1 mask -> 0/255 image resized like the canvas; `nearest=True`
        # presumably keeps the mask binary.
        pm_rs = resize_long_side((pm * 255).astype(np.uint8), target_long, nearest=True)
        cv2.imwrite(str(tmp_mask), pm_rs)
        try:
            result_png = run_lama(
                image_path=tmp_img, mask_path=tmp_mask, model_dir=lama_model_dir, out_dir=tmp_dir
            )
            result_bgr = cv2.imread(str(result_png), cv2.IMREAD_COLOR)
            if result_bgr is None:
                print("LaMa output missing; using original resized image.")
            else:
                final_bgr = result_bgr
        except Exception as e:
            print(f"LaMa failed: {e}; using original resized image.")

In [None]:
# Blur the "blur" regions (0px mask) on top of the possibly inpainted canvas.
bm = final_blur_masks.get(blur_dpx)
if bm is None or not bm.any():
    print("No blur mask to apply.")
else:
    bm_rs = resize_long_side((bm * 255).astype(np.uint8), target_long, nearest=True)
    final_bgr = adaptive_blur(final_bgr, bm_rs, strength=0.6)

In [None]:
# Show the anonymized result (converted back from OpenCV BGR to RGB for PIL).
final_rgb = cv2.cvtColor(final_bgr, cv2.COLOR_BGR2RGB)
print("Final image:")
display(Image.fromarray(final_rgb))