# Synthetic Dataset Generation (Floorplans)

## Folder Structure (Output)
We generate modified floorplan images from:
- `synthetic_dataset/images_before/`

And save results to:
- `synthetic_dataset/images_after/substantial/<change_name>/`
- `synthetic_dataset/images_after/non-substantial/<change_name>/`

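Each output file keeps the source file's stem and extension, with this suffix added to the stem:
- `_<index>_synthetic`, where `<index>` is the image's 0-based position in the sorted input list

Example: `plan_001_12_synthetic.png`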
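
The naming rule above can be sketched as follows. This is a minimal illustration that mirrors the notebook's `safe_stem` helper; `synthetic_output_path` is a hypothetical name used only here, and the root folder in the example is taken from the folder structure listed above.

```python
import re
from pathlib import Path


def safe_stem(s: str) -> str:
    # Replace every run of characters outside [a-zA-Z0-9_.-] with one underscore.
    return re.sub(r"[^a-zA-Z0-9_\-\.]+", "_", s)


def synthetic_output_path(root, group, change_name, src, index):
    # <root>/<group>/<change_name>/<stem>_<index>_synthetic<ext>
    src = Path(src)
    name = f"{safe_stem(src.stem)}_{index}_synthetic{src.suffix.lower()}"
    return Path(root) / group / change_name / name


out = synthetic_output_path(
    "synthetic_dataset/images_after", "substantial", "remove_stove", "plan_001.png", 12
)
print(out.as_posix())
# synthetic_dataset/images_after/substantial/remove_stove/plan_001_12_synthetic.png
```

Because every change writes into its own folder, the same file name can appear once per change without collisions.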

## Labels
We assume the detection model supports the following labels:
- `stove`, `sink`, `toilet`, `1stdoor`, `2sdoor`, `closet`
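
Since every change depends on these labels, it may be worth checking them against the model before a run. The sketch below assumes the label map is a plain `dict` of id to name, as exposed by `model.config.id2label` in the notebook; `missing_labels` and `EXPECTED_LABELS` are hypothetical names, and the example map is made up for illustration.

```python
EXPECTED_LABELS = {"stove", "sink", "toilet", "1stdoor", "2sdoor", "closet"}


def missing_labels(id2label: dict, expected=EXPECTED_LABELS) -> set:
    # Normalize the same way the notebook compares labels: lowercase and stripped.
    available = {str(name).lower().strip() for name in id2label.values()}
    return set(expected) - available


# Illustrative label map with one expected label absent.
example = {0: "Stove", 1: "sink", 2: "toilet", 3: "1stdoor", 4: "closet "}
print(sorted(missing_labels(example)))  # ['2sdoor']
```

A change whose label is missing from the model can never produce a detection, so it would fail on every image.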

## Changes to Generate

### Substantial (Hebrew: מהותי)
- Remove: `stove`
- Remove: `sink`
- Remove: `toilet`
- Replace: `toilet -> stove`
- Replace: `sink -> stove`

### Non-Substantial (Hebrew: לא מהותי)
- Remove: `1stdoor`
- Remove: `2sdoor`
- Remove: `closet`
- Add: random `closet`

## Run Policy
- Run only on the first **50** images initially.
- If a change fails on an image (no detection or an error), skip it and continue with the next change or the next image.
- Outputs are written into the matching change folder.
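
The policy above amounts to a nested loop with a per-change `try`/`except`. The following is a minimal sketch with a stand-in for the real edit step; `run_changes` and `fake_apply` are hypothetical names, not functions from the notebook.

```python
def run_changes(images, changes, apply_change, limit=50):
    """Apply every change to the first `limit` images, skipping failures."""
    written, skipped = [], []
    for index, image in enumerate(images[:limit]):
        for change in changes:
            try:
                apply_change(image, change)
            except Exception as exc:
                # Policy: record the failure and move on to the next change or image.
                skipped.append((image, change, str(exc)))
                continue
            written.append((image, change, index))
    return written, skipped


def fake_apply(image, change):
    # Stand-in for the real edit: pretend no toilet is ever detected.
    if "toilet" in change:
        raise RuntimeError("No detection for label=toilet")


written, skipped = run_changes(
    ["image_003.jpg", "image_004.jpg"],
    ["remove_stove", "remove_toilet"],
    fake_apply,
)
print(len(written), len(skipped))  # 2 2
```

Catching the exception inside the inner loop is what lets one failed change leave the remaining changes for the same image unaffected.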

## Requirements
- DETR fine-tuned model path: `..\Models\detr-finetuned-floorplans`
- (Optional) SAM checkpoint for better masks: `..\Models\Sam_Checkpoint\sam_vit_h_4b8939.pth`
- Stove patch library for replacement: `..\Data\Label_Pics_for_Synthetic_Data_Generation\stove` (images)
- Closet template library for adding closets: generated automatically from detected closets in the first N images.
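
A quick pre-flight check of these requirements could look like the sketch below. It only tests that the paths exist; `check_requirements` is a hypothetical helper, and the temporary folders stand in for the real model and patch locations. In the notebook, SAM is also disabled when the `segment_anything` import fails, which this sketch does not cover.

```python
import tempfile
from pathlib import Path


def check_requirements(detr_dir, stove_dir, sam_checkpoint=None):
    """Return (problems, use_sam) for the inputs listed under Requirements."""
    problems = []
    if not Path(detr_dir).is_dir():
        problems.append(f"DETR model folder not found: {detr_dir}")
    stove = Path(stove_dir)
    if not stove.is_dir() or not any(stove.iterdir()):
        problems.append(f"No stove patches found in: {stove_dir}")
    # SAM is optional: without the checkpoint, masks fall back to the bounding box.
    use_sam = sam_checkpoint is not None and Path(sam_checkpoint).is_file()
    return problems, use_sam


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "detr").mkdir()
    (root / "stove").mkdir()
    (root / "stove" / "patch.png").write_bytes(b"")
    problems, use_sam = check_requirements(
        root / "detr", root / "stove", root / "missing.pth"
    )
    print(problems, use_sam)  # [] False
```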

---


In [None]:
import os, gc, random, re
from pathlib import Path

import numpy as np
import cv2
import torch
from PIL import Image

from transformers import DetrImageProcessor, DetrForObjectDetection

# Optional SAM (better masks). If not installed / checkpoint missing -> fallback to bbox mask
USE_SAM = True
try:
    from segment_anything import sam_model_registry, SamPredictor
except Exception:
    USE_SAM = False


# =========================================================
# CONFIG
# =========================================================
# Define paths relative to the 'Code' folder.
# Built from path components (not "..\\Data") so they also resolve on Linux/macOS.
BASE_DIR = Path("..") / "Data"

IMAGES_BEFORE_DIR = BASE_DIR / "images_before"
IMAGES_AFTER_DIR  = BASE_DIR / "images_after"

MODELS_DIR = Path("..") / "Models"
DETR_MODEL_PATH = MODELS_DIR / "detr-finetuned-floorplans"
SAM_CHECKPOINT  = MODELS_DIR / "Sam_Checkpoint" / "sam_vit_h_4b8939.pth"

STOVE_PATCH_DIR = BASE_DIR / "Label_Pics_for_Synthetic_Data_Generation" / "stove"

RUN_FIRST_N_IMAGES = 50

DETR_THRESH = 0.25
BOX_PAD = 2

# mask feather for remove
FEATHER_SIGMA_REMOVE = 0.8

# replacement patch mask
INK_THRESH = 180
PATCH_DILATE = 0
FEATHER_SIGMA_PATCH = 0.6
MIN_PATCH_SIZE = 64

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

# =========================================================
# CHANGES
# =========================================================
CHANGES = [
    # substantial
    {"name": "remove_stove",         "group": "substantial",     "type": "remove",  "label": "stove"},
    {"name": "remove_sink",          "group": "substantial",     "type": "remove",  "label": "sink"},
    {"name": "remove_toilet",        "group": "substantial",     "type": "remove",  "label": "toilet"},
    {"name": "replace_toilet2stove", "group": "substantial",     "type": "replace", "src": "toilet", "dst": "stove"},
    {"name": "replace_sink2stove",   "group": "substantial",     "type": "replace", "src": "sink",   "dst": "stove"},

    # non-substantial
    {"name": "remove_1stdoor",       "group": "non-substantial", "type": "remove",  "label": "1stdoor"},
    {"name": "remove_2sdoor",        "group": "non-substantial", "type": "remove",  "label": "2sdoor"},
    {"name": "remove_closet",        "group": "non-substantial", "type": "remove",  "label": "closet"},
    {"name": "add_random_closet",    "group": "non-substantial", "type": "add_closet"},
]


# =========================================================
# HELPERS
# =========================================================
def list_images(folder: Path):
    exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
    if not folder.exists():
        print(f"Warning: Folder not found: {folder}")
        return []
    files = [p for p in folder.iterdir() if p.suffix.lower() in exts]
    files.sort()
    return files

def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def safe_stem(s: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_\-\.]+", "_", s)

def expand_box_xyxy(box, W, H, pad):
    x0, y0, x1, y1 = box
    x0 = max(0, int(x0) - pad)
    y0 = max(0, int(y0) - pad)
    x1 = min(W - 1, int(x1) + pad)
    y1 = min(H - 1, int(y1) + pad)
    return [x0, y0, x1, y1]

def box_mask(H, W, box):
    x0, y0, x1, y1 = box
    m = np.zeros((H, W), dtype=np.uint8)
    m[y0:y1+1, x0:x1+1] = 255
    return m

def feather_mask(mask_uint8, sigma=1.0):
    if sigma <= 0:
        return mask_uint8
    k = int(max(3, round(sigma * 6)))
    if k % 2 == 0:
        k += 1
    return cv2.GaussianBlur(mask_uint8, (k, k), sigmaX=sigma, sigmaY=sigma)

def run_detr_get_best_box(image_pil, target_label: str, processor, model, thresh=0.25):
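    """
    Return (box, score) for the highest-scoring detection of target_label,
    where box is [x0, y0, x1, y1] in pixel coordinates.
    Returns (None, -1.0) if no detection of that label passes the threshold.
    """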
    inputs = processor(images=image_pil, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    W, H = image_pil.size
    target_sizes = torch.tensor([[H, W]], device=DEVICE)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=thresh)[0]

    # label map
    id2label = model.config.id2label

    best = None
    best_score = -1.0

    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_name = id2label[int(label_id)].lower().strip()
        if label_name == target_label.lower().strip():
            sc = float(score)
            if sc > best_score:
                best_score = sc
                b = box.detach().cpu().numpy().tolist()
                best = [int(round(b[0])), int(round(b[1])), int(round(b[2])), int(round(b[3]))]

    return best, best_score

def build_sam_predictor_if_possible(image_np_rgb):
    """
    Returns predictor or None
    """
    if not USE_SAM:
        return None
    if not SAM_CHECKPOINT.exists():
        return None
    try:
        sam = sam_model_registry["vit_h"](checkpoint=str(SAM_CHECKPOINT))
        sam.to(device=DEVICE)
        predictor = SamPredictor(sam)
        predictor.set_image(image_np_rgb, image_format="RGB")
        return predictor
    except Exception:
        return None

def sam_mask_from_box(predictor, box_xyxy):
    """
    SAM mask (uint8 0/255) or None
    """
    try:
        x0,y0,x1,y1 = box_xyxy
        box = np.array([x0,y0,x1,y1], dtype=np.float32)
        masks, _, _ = predictor.predict(box=box, multimask_output=False)
        m = (masks[0].astype(np.uint8) * 255)
        return m
    except Exception:
        return None

def apply_remove_white(image_pil, target_label, processor, model):
    """
    Detect target label and fill it with white (using SAM if possible, else bbox).
    Returns (out_pil, used_box) or raises RuntimeError if not found.
    """
    box, sc = run_detr_get_best_box(image_pil, target_label, processor, model, thresh=DETR_THRESH)
    if box is None:
        raise RuntimeError(f"No detection for label={target_label}")

    W, H = image_pil.size
    box = expand_box_xyxy(box, W, H, BOX_PAD)

    img_np = np.array(image_pil.convert("RGB"))
    # NOTE: this loads the SAM checkpoint and re-embeds the image on every call
    predictor = build_sam_predictor_if_possible(img_np)

    if predictor is not None:
        m = sam_mask_from_box(predictor, box)
        if m is None:
            m = box_mask(H, W, box)
    else:
        m = box_mask(H, W, box)

    m = feather_mask(m, sigma=FEATHER_SIGMA_REMOVE)

    out = img_np.copy()
    # blend to white
    alpha = (m.astype(np.float32) / 255.0)[..., None]
    white = np.full_like(out, 255)
    out = (out.astype(np.float32) * (1.0 - alpha) + white.astype(np.float32) * alpha).astype(np.uint8)

    # cleanup: drop the per-call SAM predictor and release cached GPU memory
    if predictor is not None:
        del predictor
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    return Image.fromarray(out), box


# ---------------------------
# Stove replacement helpers
# ---------------------------
def make_ink_mask(rgb_uint8, ink_thresh=180, dilate=0):
    """
    Binary mask (uint8) where "ink"/lines are 255 and background is 0.
    For floorplans: dark pixels (gray < ink_thresh) count as ink.
    """
    gray = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2GRAY)
    m = (gray < ink_thresh).astype(np.uint8) * 255
    if dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*dilate+1, 2*dilate+1))
        m = cv2.dilate(m, k, iterations=1)
    return m

def tight_crop_to_mask(rgb, mask, margin=2):
    ys, xs = np.where(mask > 0)
    if len(xs) == 0 or len(ys) == 0:
        return rgb, mask
    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    x0 = max(0, x0 - margin); y0 = max(0, y0 - margin)
    x1 = min(rgb.shape[1]-1, x1 + margin); y1 = min(rgb.shape[0]-1, y1 + margin)
    return rgb[y0:y1+1, x0:x1+1].copy(), mask[y0:y1+1, x0:x1+1].copy()

def resize_keep_aspect(rgb, mask, target_w, target_h):
    h, w = rgb.shape[:2]
    if w <= 0 or h <= 0:
        return rgb, mask
    scale = min(target_w / w, target_h / h)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    rgb_r = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
    mask_r = cv2.resize(mask, (nw, nh), interpolation=cv2.INTER_NEAREST)
    return rgb_r, mask_r

def pick_random_patch(folder: Path, min_size=64):
    if not folder.exists():
        raise RuntimeError(f"Patch folder not found: {folder}")
    files = list_images(folder)
    if not files:
        raise RuntimeError(f"No patches in: {folder}")
    random.shuffle(files)
    for f in files:
        im = Image.open(f).convert("RGB")
        if min(im.size) >= min_size:
            return im
    # fallback: no patch meets min_size, so use the first (shuffled) one anyway
    return Image.open(files[0]).convert("RGB")

def paste_centered(base_rgb, patch_rgb, patch_mask, box_xyxy, feather_sigma=0.6):
    """
    Paste patch into box region, keeping white bg and placing only ink mask area.
    """
    x0,y0,x1,y1 = box_xyxy
    bw = max(1, x1-x0+1)
    bh = max(1, y1-y0+1)

    # resize patch to fit
    pr, mr = resize_keep_aspect(patch_rgb, patch_mask, bw, bh)

    # create soft mask
    m_soft = feather_mask(mr, sigma=feather_sigma)
    alpha = (m_soft.astype(np.float32) / 255.0)[..., None]

    out = base_rgb.copy()
    # position centered in the box
    ph, pw = pr.shape[:2]
    cx = x0 + bw//2
    cy = y0 + bh//2
    px0 = int(cx - pw//2)
    py0 = int(cy - ph//2)

    # clip paste ROI
    H,W = out.shape[:2]
    rx0 = max(0, px0); ry0 = max(0, py0)
    rx1 = min(W, px0+pw); ry1 = min(H, py0+ph)
    if rx1 <= rx0 or ry1 <= ry0:
        return out

    sx0 = rx0 - px0; sy0 = ry0 - py0
    sx1 = sx0 + (rx1-rx0); sy1 = sy0 + (ry1-ry0)

    roi = out[ry0:ry1, rx0:rx1].astype(np.float32)
    pr_roi = pr[sy0:sy1, sx0:sx1].astype(np.float32)
    a_roi = alpha[sy0:sy1, sx0:sx1].astype(np.float32)

    # blend only where mask is
    out[ry0:ry1, rx0:rx1] = (roi*(1-a_roi) + pr_roi*a_roi).astype(np.uint8)
    return out

def apply_replace_with_stove(image_pil, source_label, processor, model):
    """
    Detect source_label, whiten its bounding box, and paste a random stove patch into it.
    Returns (out_pil, used_box) or raises RuntimeError if source_label is not found.
    """
    # detect source
    box, sc = run_detr_get_best_box(image_pil, source_label, processor, model, thresh=DETR_THRESH)
    if box is None:
        raise RuntimeError(f"No detection for label={source_label}")

    W,H = image_pil.size
    box = expand_box_xyxy(box, W, H, BOX_PAD)

    base_rgb = np.array(image_pil.convert("RGB"))
    # remove -> white in bbox (simple, stable)
    base_rgb_removed = base_rgb.copy()
    x0,y0,x1,y1 = box
    base_rgb_removed[y0:y1+1, x0:x1+1] = 255

    # load stove patch + mask
    stove_pil = pick_random_patch(STOVE_PATCH_DIR, min_size=MIN_PATCH_SIZE)
    stove_rgb = np.array(stove_pil)
    m = make_ink_mask(stove_rgb, ink_thresh=INK_THRESH, dilate=PATCH_DILATE)
    stove_rgb_c, m_c = tight_crop_to_mask(stove_rgb, m, margin=2)

    out_rgb = paste_centered(
        base_rgb_removed,
        stove_rgb_c,
        m_c,
        box,
        feather_sigma=FEATHER_SIGMA_PATCH
    )

    return Image.fromarray(out_rgb), box


# ---------------------------
# Add random closet (template pool from detected closets)
# ---------------------------
def extract_closet_template(image_pil, processor, model):
    """
    Detect the best closet and return {"rgb": crop, "mask": ink_mask}, both tightly cropped.
    Returns None if no closet is detected or the crop is under 10 px on either side.
    """
    box, sc = run_detr_get_best_box(image_pil, "closet", processor, model, thresh=DETR_THRESH)
    if box is None:
        return None
    W,H = image_pil.size
    box = expand_box_xyxy(box, W, H, BOX_PAD)
    rgb = np.array(image_pil.convert("RGB"))
    x0,y0,x1,y1 = box
    crop = rgb[y0:y1+1, x0:x1+1].copy()
    m = make_ink_mask(crop, ink_thresh=INK_THRESH, dilate=0)
    crop_c, m_c = tight_crop_to_mask(crop, m, margin=1)
    if crop_c.shape[0] < 10 or crop_c.shape[1] < 10:
        return None
    return {"rgb": crop_c, "mask": m_c}

def build_interior_mask_simple(rgb):
    """
    Very simple interior mask:
    treat pixels that are "almost white" as interior.
    """
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    interior = (gray > 245).astype(np.uint8) * 255
    # close holes a bit
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25,25))
    interior = cv2.morphologyEx(interior, cv2.MORPH_CLOSE, k, iterations=1)
    return interior

def apply_add_random_closet(image_pil, template_pool):
    """
    Paste a random closet template in a random interior spot (tries multiple times).
    """
    if not template_pool:
        raise RuntimeError("No closet templates available")

    base_rgb = np.array(image_pil.convert("RGB"))
    H,W = base_rgb.shape[:2]
    interior = build_interior_mask_simple(base_rgb)

    tpl = random.choice(template_pool)
    patch = tpl["rgb"]
    pmask = tpl["mask"]

    ph,pw = patch.shape[:2]
    # scale patch randomly a bit
    scale = random.choice([0.7, 0.85, 1.0, 1.15])
    nw = max(8, int(pw*scale)); nh = max(8, int(ph*scale))
    patch_r = cv2.resize(patch, (nw, nh), interpolation=cv2.INTER_AREA)
    pmask_r = cv2.resize(pmask, (nw, nh), interpolation=cv2.INTER_NEAREST)

    # try random placements
    for _ in range(80):
        x = random.randint(0, max(0, W-nw))
        y = random.randint(0, max(0, H-nh))

        # must be mostly interior
        roi = interior[y:y+nh, x:x+nw]
        if roi.size == 0:
            continue
        if (roi > 0).mean() < 0.95:
            continue

        # avoid pasting on busy (too many black pixels already)
        base_roi = base_rgb[y:y+nh, x:x+nw]
        base_black = (cv2.cvtColor(base_roi, cv2.COLOR_RGB2GRAY) < 200).mean()
        if base_black > 0.06:
            continue

        out = paste_centered(
            base_rgb,
            patch_r,
            pmask_r,
            [x, y, x+nw-1, y+nh-1],
            feather_sigma=0.7
        )
        return Image.fromarray(out), [x,y,x+nw-1,y+nh-1]

    raise RuntimeError("No valid placement found for closet")


# =========================================================
# MAIN
# =========================================================
def main():
    # Create output folders for each change
    for ch in CHANGES:
        out_dir = IMAGES_AFTER_DIR / ch["group"] / ch["name"]
        ensure_dir(out_dir)

    # Load DETR
    print(f"Loading DETR from: {DETR_MODEL_PATH}")
    processor = DetrImageProcessor.from_pretrained(str(DETR_MODEL_PATH))
    model = DetrForObjectDetection.from_pretrained(str(DETR_MODEL_PATH)).to(DEVICE)
    model.eval()

    # Load first N images
    imgs = list_images(IMAGES_BEFORE_DIR)[:RUN_FIRST_N_IMAGES]
    print(f"Found {len(imgs)} images to process (first {RUN_FIRST_N_IMAGES}).")

    if len(imgs) == 0:
        print(f"Warning: No images found in {IMAGES_BEFORE_DIR}")

    # Build closet template pool from first N images (best-effort)
    template_pool = []
    print("Building closet template pool...")
    for p in imgs:
        try:
            pil = Image.open(p).convert("RGB")
            tpl = extract_closet_template(pil, processor, model)
            if tpl is not None:
                template_pool.append(tpl)
        except Exception:
            # best-effort: skip images that cannot be read or processed
            pass
    print(f"Closet template pool size: {len(template_pool)}")

    # Process each image x each change
    for idx, img_path in enumerate(imgs):
        try:
            image_pil = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"[SKIP] cannot read {img_path.name}: {e}")
            continue

        for ch in CHANGES:
            out_dir = IMAGES_AFTER_DIR / ch["group"] / ch["name"]
            stem = safe_stem(img_path.stem)
            out_name = f"{stem}_{idx}_synthetic{img_path.suffix.lower()}"
            out_path = out_dir / out_name

            try:
                if ch["type"] == "remove":
                    out_pil, _ = apply_remove_white(image_pil, ch["label"], processor, model)
                elif ch["type"] == "replace":
                    # currently only dst=stove supported by this pipeline
                    out_pil, _ = apply_replace_with_stove(image_pil, ch["src"], processor, model)
                elif ch["type"] == "add_closet":
                    out_pil, _ = apply_add_random_closet(image_pil, template_pool)
                else:
                    raise RuntimeError(f"Unknown change type: {ch['type']}")

                out_pil.save(out_path)
                print(f"[OK] {img_path.name} -> {ch['name']} -> {out_path.as_posix()}")
            except Exception as e:
                # required behavior: if failed -> continue
                print(f"[FAIL] {img_path.name} -> {ch['name']}: {e}")
                continue

        # free memory a bit
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    print("\nDONE")

if __name__ == "__main__":
    main()

DEVICE: cuda




Found 50 images to process (first 50).
Closet template pool size: 43
[OK] image_003.jpg -> remove_stove -> synthetic_dataset/images_after/substantial/remove_stove/image_003_0_synthetic.jpg
[OK] image_003.jpg -> remove_sink -> synthetic_dataset/images_after/substantial/remove_sink/image_003_0_synthetic.jpg
[FAIL] image_003.jpg -> remove_toilet: No detection for label=toilet
[FAIL] image_003.jpg -> replace_toilet2stove: No detection for label=toilet
[OK] image_003.jpg -> replace_sink2stove -> synthetic_dataset/images_after/substantial/replace_sink2stove/image_003_0_synthetic.jpg
[OK] image_003.jpg -> remove_1stdoor -> synthetic_dataset/images_after/non-substantial/remove_1stdoor/image_003_0_synthetic.jpg
[OK] image_003.jpg -> remove_2sdoor -> synthetic_dataset/images_after/non-substantial/remove_2sdoor/image_003_0_synthetic.jpg
[OK] image_003.jpg -> remove_closet -> synthetic_dataset/images_after/non-substantial/remove_closet/image_003_0_synthetic.jpg
[OK] image_003.jpg -> add_random_clo
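
The removal step in the cell above whitens a detected object by alpha-blending the image toward white through a Gaussian-feathered mask. The sketch below isolates the two small pieces of arithmetic involved, per pixel and without OpenCV: the kernel width that `feather_mask` derives from `sigma`, and the blend itself. `feather_kernel_size` and `blend_to_white` are hypothetical names for illustration, and the Gaussian blur is left out.

```python
def feather_kernel_size(sigma: float) -> int:
    """Odd Gaussian kernel width as in feather_mask: about 6*sigma, at least 3."""
    k = int(max(3, round(sigma * 6)))
    return k if k % 2 == 1 else k + 1


def blend_to_white(pixel: int, mask_value: int) -> int:
    """Blend one 8-bit channel value toward white, using an 8-bit mask value as alpha."""
    alpha = mask_value / 255.0
    return int(pixel * (1.0 - alpha) + 255.0 * alpha)


print(feather_kernel_size(0.8))  # 5
print(blend_to_white(40, 0))     # 40  (outside the mask: unchanged)
print(blend_to_white(40, 255))   # 255 (inside the mask: fully white)
```

Mask values between 0 and 255, which the blur produces along the mask edge, give a partial blend, so the whitened region fades into its surroundings instead of ending in a hard rectangle.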

In [None]:
import os, gc, random, re, json, time
from pathlib import Path
from datetime import datetime

import numpy as np
import cv2
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection

# Optional SAM (better masks). If not installed / checkpoint missing -> fallback to bbox mask
USE_SAM = True
try:
    from segment_anything import sam_model_registry, SamPredictor
except Exception:
    USE_SAM = False


# =========================================================
# CONFIGURATION
# =========================================================
# Define root directories relative to the 'Code' folder.
# Built from path components (not "..\\Data") so they also resolve on Linux/macOS.
BASE_DATA_DIR = Path("..") / "Data"
MODELS_DIR    = Path("..") / "Models"

IMAGES_BEFORE_DIR = BASE_DATA_DIR / "images_before"
IMAGES_AFTER_DIR  = BASE_DATA_DIR / "images_after"
META_PATH         = BASE_DATA_DIR / "metadata.jsonl"

DETR_MODEL_PATH = MODELS_DIR / "detr-finetuned-floorplans"
SAM_CHECKPOINT  = MODELS_DIR / "Sam_Checkpoint" / "sam_vit_h_4b8939.pth"

# Path to stove patches for replacement
STOVE_PATCH_DIR = BASE_DATA_DIR / "Label_Pics_for_Synthetic_Data_Generation" / "stove"

RUN_FIRST_N_IMAGES = 50

DETR_THRESH = 0.25
BOX_PAD = 2

# Mask feathering for removal
FEATHER_SIGMA_REMOVE = 0.8

# Replacement patch mask settings
INK_THRESH = 180
PATCH_DILATE = 0
FEATHER_SIGMA_PATCH = 0.6
MIN_PATCH_SIZE = 64

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")


# =========================================================
# CHANGES CONFIGURATION
# =========================================================
CHANGES = [
    # Substantial changes
    {"name": "remove_stove",         "group": "substantial",     "type": "remove",  "label": "stove"},
    {"name": "remove_sink",          "group": "substantial",     "type": "remove",  "label": "sink"},
    {"name": "remove_toilet",        "group": "substantial",     "type": "remove",  "label": "toilet"},
    {"name": "replace_toilet2stove", "group": "substantial",     "type": "replace", "src": "toilet", "dst": "stove"},
    {"name": "replace_sink2stove",   "group": "substantial",     "type": "replace", "src": "sink",   "dst": "stove"},

    # Non-substantial changes
    {"name": "remove_1stdoor",       "group": "non-substantial", "type": "remove",  "label": "1stdoor"},
    {"name": "remove_2sdoor",        "group": "non-substantial", "type": "remove",  "label": "2sdoor"},
    {"name": "remove_closet",        "group": "non-substantial", "type": "remove",  "label": "closet"},

    # Add closet via wall-only logic
    {"name": "add_random_closet",    "group": "non-substantial", "type": "add_closet_wall_only_v2"},
]


# =========================================================
# UTILITIES
# =========================================================
def list_images(folder: Path):
    exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
    if not folder.exists():
        print(f"Warning: Input folder not found: {folder}")
        return []
    files = [p for p in folder.iterdir() if p.suffix.lower() in exts]
    files.sort()
    return files

def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def safe_stem(s: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_\-\.]+", "_", s)

def expand_box_xyxy(box, W, H, pad):
    x0, y0, x1, y1 = box
    x0 = max(0, int(x0) - pad)
    y0 = max(0, int(y0) - pad)
    x1 = min(W - 1, int(x1) + pad)
    y1 = min(H - 1, int(y1) + pad)
    return [x0, y0, x1, y1]

def format_hhmmss(seconds: float) -> str:
    seconds = max(0, int(seconds))
    h = seconds // 3600
    m = (seconds % 3600) // 60
    s = seconds % 60
    return f"{h:02d}:{m:02d}:{s:02d}"

def append_metadata(record: dict):
    ensure_dir(META_PATH.parent)
    with open(META_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")


# =========================================================
# DETR HELPERS
# =========================================================
def run_detr_get_best_box(image_pil, target_label: str, processor, model, thresh=0.25):
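    """
    Return (box, score) for the highest-scoring detection of target_label,
    where box is [x0, y0, x1, y1] in pixel coordinates.
    Returns (None, -1.0) if no detection of that label passes the threshold.
    """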
    inputs = processor(images=image_pil, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    W, H = image_pil.size
    target_sizes = torch.tensor([[H, W]], device=DEVICE)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=thresh)[0]

    id2label = model.config.id2label
    best = None
    best_score = -1.0

    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_name = id2label[int(label_id)].lower().strip()
        if label_name == target_label.lower().strip():
            sc = float(score)
            if sc > best_score:
                best_score = sc
                b = box.detach().cpu().numpy().tolist()
                best = [int(round(b[0])), int(round(b[1])), int(round(b[2])), int(round(b[3]))]

    return best, best_score


# =========================================================
# SAM HELPERS (optional)
# =========================================================
def box_mask(H, W, box):
    x0, y0, x1, y1 = box
    m = np.zeros((H, W), dtype=np.uint8)
    m[y0:y1+1, x0:x1+1] = 255
    return m

def feather_mask(mask_uint8, sigma=1.0):
    if sigma <= 0:
        return mask_uint8
    k = int(max(3, round(sigma * 6)))
    if k % 2 == 0:
        k += 1
    return cv2.GaussianBlur(mask_uint8, (k, k), sigmaX=sigma, sigmaY=sigma)

def build_sam_predictor_if_possible(image_np_rgb):
    if not USE_SAM:
        return None
    if not SAM_CHECKPOINT.exists():
        return None
    try:
        sam = sam_model_registry["vit_h"](checkpoint=str(SAM_CHECKPOINT))
        sam.to(device=DEVICE)
        predictor = SamPredictor(sam)
        predictor.set_image(image_np_rgb, image_format="RGB")
        return predictor
    except Exception:
        return None

def sam_mask_from_box(predictor, box_xyxy):
    try:
        x0,y0,x1,y1 = box_xyxy
        box = np.array([x0,y0,x1,y1], dtype=np.float32)
        masks, _, _ = predictor.predict(box=box, multimask_output=False)
        return (masks[0].astype(np.uint8) * 255)
    except Exception:
        return None


# =========================================================
# LOGIC: REMOVE (fill white)
# =========================================================
def apply_remove_white(image_pil, target_label, processor, model):
    box, sc = run_detr_get_best_box(image_pil, target_label, processor, model, thresh=DETR_THRESH)
    if box is None:
        raise RuntimeError(f"No detection for label={target_label}")

    W, H = image_pil.size
    box = expand_box_xyxy(box, W, H, BOX_PAD)

    img_np = np.array(image_pil.convert("RGB"))
    predictor = build_sam_predictor_if_possible(img_np)

    if predictor is not None:
        m = sam_mask_from_box(predictor, box)
        if m is None:
            m = box_mask(H, W, box)
    else:
        m = box_mask(H, W, box)

    m = feather_mask(m, sigma=FEATHER_SIGMA_REMOVE)

    out = img_np.copy()
    alpha = (m.astype(np.float32) / 255.0)[..., None]
    white = np.full_like(out, 255)
    out = (out.astype(np.float32) * (1.0 - alpha) + white.astype(np.float32) * alpha).astype(np.uint8)

    # cleanup: drop the per-call SAM predictor and release cached GPU memory
    if predictor is not None:
        del predictor
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    return Image.fromarray(out), box, {"score": sc}


# =========================================================
# LOGIC: REPLACE (sink/toilet -> stove)
# =========================================================
def make_ink_mask(rgb_uint8, ink_thresh=180, dilate=0):
    gray = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2GRAY)
    m = (gray < ink_thresh).astype(np.uint8) * 255
    if dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*dilate+1, 2*dilate+1))
        m = cv2.dilate(m, k, iterations=1)
    return m

def tight_crop_to_mask(rgb, mask, margin=2):
    ys, xs = np.where(mask > 0)
    if len(xs) == 0 or len(ys) == 0:
        return rgb, mask
    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    x0 = max(0, x0 - margin); y0 = max(0, y0 - margin)
    x1 = min(rgb.shape[1]-1, x1 + margin); y1 = min(rgb.shape[0]-1, y1 + margin)
    return rgb[y0:y1+1, x0:x1+1].copy(), mask[y0:y1+1, x0:x1+1].copy()

def resize_keep_aspect(rgb, mask, target_w, target_h):
    h, w = rgb.shape[:2]
    if w <= 0 or h <= 0:
        return rgb, mask
    scale = min(target_w / w, target_h / h)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    rgb_r = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
    mask_r = cv2.resize(mask, (nw, nh), interpolation=cv2.INTER_NEAREST)
    return rgb_r, mask_r

def pick_random_patch(folder: Path, min_size=64):
    if not folder.exists():
        raise RuntimeError(f"Patch folder not found: {folder}")
    files = list_images(folder)
    if not files:
        raise RuntimeError(f"No patches in: {folder}")
    random.shuffle(files)
    for f in files:
        im = Image.open(f).convert("RGB")
        if min(im.size) >= min_size:
            return im
    return Image.open(files[0]).convert("RGB")

def paste_centered(base_rgb, patch_rgb, patch_mask, box_xyxy, feather_sigma=0.6):
    x0,y0,x1,y1 = box_xyxy
    bw = max(1, x1-x0+1)
    bh = max(1, y1-y0+1)

    pr, mr = resize_keep_aspect(patch_rgb, patch_mask, bw, bh)
    m_soft = feather_mask(mr, sigma=feather_sigma)
    alpha = (m_soft.astype(np.float32) / 255.0)[..., None]

    out = base_rgb.copy()
    ph, pw = pr.shape[:2]
    cx = x0 + bw//2
    cy = y0 + bh//2
    px0 = int(cx - pw//2)
    py0 = int(cy - ph//2)

    H,W = out.shape[:2]
    rx0 = max(0, px0); ry0 = max(0, py0)
    rx1 = min(W, px0+pw); ry1 = min(H, py0+ph)
    if rx1 <= rx0 or ry1 <= ry0:
        return out

    sx0 = rx0 - px0; sy0 = ry0 - py0
    sx1 = sx0 + (rx1-rx0); sy1 = sy0 + (ry1-ry0)

    roi = out[ry0:ry1, rx0:rx1].astype(np.float32)
    pr_roi = pr[sy0:sy1, sx0:sx1].astype(np.float32)
    a_roi = alpha[sy0:sy1, sx0:sx1].astype(np.float32)

    out[ry0:ry1, rx0:rx1] = (roi*(1-a_roi) + pr_roi*a_roi).astype(np.uint8)
    return out

def apply_replace_with_stove(image_pil, source_label, processor, model):
    box, sc = run_detr_get_best_box(image_pil, source_label, processor, model, thresh=DETR_THRESH)
    if box is None:
        raise RuntimeError(f"No detection for label={source_label}")

    W,H = image_pil.size
    box = expand_box_xyxy(box, W, H, BOX_PAD)

    base_rgb = np.array(image_pil.convert("RGB"))
    base_rgb_removed = base_rgb.copy()
    x0,y0,x1,y1 = box
    base_rgb_removed[y0:y1+1, x0:x1+1] = 255

    stove_pil = pick_random_patch(STOVE_PATCH_DIR, min_size=MIN_PATCH_SIZE)
    stove_rgb = np.array(stove_pil)
    m = make_ink_mask(stove_rgb, ink_thresh=INK_THRESH, dilate=PATCH_DILATE)
    stove_rgb_c, m_c = tight_crop_to_mask(stove_rgb, m, margin=2)

    out_rgb = paste_centered(
        base_rgb_removed,
        stove_rgb_c,
        m_c,
        box,
        feather_sigma=FEATHER_SIGMA_PATCH
    )

    # NOTE: "patch_src" records the patch's (width, height) as a string, not its file path
    return Image.fromarray(out_rgb), box, {"score": sc, "patch_src": str(stove_pil.size)}


# =========================================================
# LOGIC: ADD CLOSET (WALL-ONLY v2)
# =========================================================
TARGET_LABEL = "closet"
DETR_THRESH_CLOSET = 0.20

# Walls
WALL_KERNEL = 11

# Scan
MARGIN = 10
STRIDE = 4

# Interior constraint
MIN_INK_INSIDE_FRAC = 0.99

# Coarse "busy" filter
MAX_NONWALL_BLACK_FRAC = 0.06

# Overlap avoidance
EXISTING_DILATE = 3
MAX_OVERLAP_PIXELS_BASE = 10
MAX_OVERLAP_FRAC_BASE   = 0.03

# Keep-away ring
RING_DILATE = 10
MAX_RING_OVERLAP_PIXELS = 25
MAX_RING_OVERLAP_FRAC   = 0.03

# Mask blending
FEATHER_SIGMA_CLOSET = 0.6

# Template filtering
MIN_TEMPLATE_SIZE = 18
MAX_TEMPLATE_REL = 0.45
BOX_PAD_CLOSET = 2

# Near-wall bands
BANDS = [10, 14, 18, 24, 30, 40, 55]

# Adaptive overlap relaxation
OV_PIX_LIST  = [MAX_OVERLAP_PIXELS_BASE, 16, 24, 36]
OV_FRAC_LIST = [MAX_OVERLAP_FRAC_BASE,   0.04, 0.06, 0.08]

# Must be near wall
WALL_TOUCH_MAX_DIST = 3.0
MIN_WALL_TOUCH_FRAC = 0.03

# Thick-wall cleanup
THICK_MIN_AREA       = 2500
THICK_MIN_LONG_SIDE  = 160
THICK_MIN_SHORT_SIDE = 6

# Wall/Non-wall separation
WALL_INK_DILATE = 3


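def filter_thick_map_components(thick_map, min_area=1500, min_long_side=120, min_short_side=6):
    """
    Keep only connected components that are large enough to be walls:
    area >= min_area and bounding-box sides >= min_long_side / min_short_side.
    Returns a uint8 mask with kept pixels set to 255.
    """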
    bw = (thick_map > 0).astype(np.uint8)
    num, labels, stats, _ = cv2.connectedComponentsWithStats(bw, connectivity=8)
    out = np.zeros_like(bw)
    for i in range(1, num):
        x, y, w, h, area = stats[i]
        long_side = max(w, h)
        short_side = min(w, h)
        if area < min_area:
            continue
        if long_side < min_long_side:
            continue
        if short_side < min_short_side:
            continue
        out[labels == i] = 1
    return (out * 255).astype(np.uint8)

def build_line_maps(img_bgr, wall_kernel):
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    _, bin_inv = cv2.threshold(gray, 235, 255, cv2.THRESH_BINARY_INV) 

    k = cv2.getStructuringElement(cv2.MORPH_RECT, (wall_kernel, wall_kernel))
    thick = cv2.morphologyEx(bin_inv, cv2.MORPH_CLOSE, k, iterations=1)
    thick = cv2.erode(thick, np.ones((3,3), np.uint8), iterations=1)

    thick_clean = filter_thick_map_components(
        thick,
        min_area=THICK_MIN_AREA,
        min_long_side=THICK_MIN_LONG_SIDE,
        min_short_side=THICK_MIN_SHORT_SIDE
    )
    return bin_inv, thick_clean

def build_interior_from_walls(thick_map, close_gaps=13, wall_dilate=2):
    H, W = thick_map.shape[:2]
    walls = (thick_map > 0).astype(np.uint8) * 255

    if wall_dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_RECT, (2*wall_dilate+1, 2*wall_dilate+1))
        walls = cv2.dilate(walls, k, iterations=1)

    k2 = cv2.getStructuringElement(cv2.MORPH_RECT, (close_gaps, close_gaps))
    walls_closed = cv2.morphologyEx(walls, cv2.MORPH_CLOSE, k2, iterations=1)

    free = cv2.bitwise_not(walls_closed)
    ff = free.copy()
    mask = np.zeros((H+2, W+2), dtype=np.uint8)

    for seed in [(0,0), (W-1,0), (0,H-1), (W-1,H-1)]:
        if ff[seed[1], seed[0]] == 255:
            cv2.floodFill(ff, mask, seedPoint=seed, newVal=0)

    interior = ff
    return interior, walls_closed

def build_placement_mask(interior_mask, thick_map, wall_band):
    thick_bin = (thick_map > 0).astype(np.uint8)
    dist = cv2.distanceTransform(1 - thick_bin, cv2.DIST_L2, 3)
    near_wall = (dist <= wall_band).astype(np.uint8) * 255
    placement = cv2.bitwise_and(interior_mask, near_wall)
    return placement

def mask_from_crop_nonwhite(crop_bgr, feather_sigma):
    gray = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2GRAY)
    hard = cv2.inRange(gray, 0, 245)
    hard = cv2.morphologyEx(hard, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
    soft = hard.copy()
    if feather_sigma and feather_sigma > 0:
        soft = cv2.GaussianBlur(soft, (0,0), sigmaX=float(feather_sigma))
    return hard, soft

def make_ring_mask(mh_uint8, ring_dilate):
    mh = (mh_uint8 > 0).astype(np.uint8)
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*ring_dilate+1, 2*ring_dilate+1))
    dil = cv2.dilate(mh, k, iterations=1)
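    # Ring = dilated mask minus the mask itself. Boolean logic avoids uint8
    # wrap-around (a `ring < 0` check can never fire on an unsigned array).
    ring = ((dil > 0) & (mh == 0)).astype(np.uint8)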
    return ring

def find_best_spot_wall_only(
    interior_mask, placement_mask,
    thick_map,
    all_ink, nonwall_ink,
    template_mask_hard, roi_w, roi_h,
    margin, stride,
    max_nonwall_black_frac,
    existing_dilate,
    max_overlap_pixels, max_overlap_frac,
    wall_band,
    min_ink_inside_frac,
    wall_touch_max_dist, min_wall_touch_frac,
    ring_dilate, max_ring_overlap_pixels, max_ring_overlap_frac
):
    H, W = interior_mask.shape[:2]

    if existing_dilate and existing_dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*existing_dilate+1, 2*existing_dilate+1))
        nonwall_ink_dil = cv2.dilate(nonwall_ink, k, iterations=1)
    else:
        nonwall_ink_dil = nonwall_ink

    thick_bin = (thick_map > 0).astype(np.uint8)
    dist_to_wall = cv2.distanceTransform(1 - thick_bin, cv2.DIST_L2, 3)

    mh = (template_mask_hard > 0).astype(np.uint8)
    mh_sum = int(mh.sum())
    if mh_sum == 0:
        return None

    ring = make_ring_mask(template_mask_hard, ring_dilate)
    ring_sum = int(ring.sum())  # make_ring_mask always returns an array

    best = None
    best_score = -1e18

    for y in range(margin, H - roi_h - margin, stride):
        for x in range(margin, W - roi_w - margin, stride):

            cx = x + roi_w // 2
            cy = y + roi_h // 2
            if placement_mask[cy, cx] == 0:
                continue

            roi_nonwall = nonwall_ink[y:y+roi_h, x:x+roi_w]
            nonwall_black_frac = float((roi_nonwall > 0).mean())
            if nonwall_black_frac > max_nonwall_black_frac:
                continue

            roi_int = (interior_mask[y:y+roi_h, x:x+roi_w] > 0).astype(np.uint8)
            inside_on_ink = roi_int[mh > 0]
            if inside_on_ink.size == 0:
                continue
            ink_inside_frac = float(inside_on_ink.mean())
            if ink_inside_frac < min_ink_inside_frac:
                continue

            roi_dist = dist_to_wall[y:y+roi_h, x:x+roi_w]
            ink_dist = roi_dist[mh > 0]
            if ink_dist.size == 0:
                continue
            if float(ink_dist.mean()) > float(wall_band):
                continue
            touch_frac = float((ink_dist <= wall_touch_max_dist).mean())
            if touch_frac < min_wall_touch_frac:
                continue

            roi_exist = (nonwall_ink_dil[y:y+roi_h, x:x+roi_w] > 0).astype(np.uint8)
            overlap_pixels = int((roi_exist * mh).sum())
            if overlap_pixels > max_overlap_pixels:
                continue
            overlap_frac = overlap_pixels / float(mh_sum)
            if overlap_frac > max_overlap_frac:
                continue

            ring_overlap = 0
            ring_overlap_frac = 0.0
            if ring_sum > 0:
                ring_overlap = int((roi_exist * ring).sum())
                if ring_overlap > max_ring_overlap_pixels:
                    continue
                ring_overlap_frac = ring_overlap / float(ring_sum)
                if ring_overlap_frac > max_ring_overlap_frac:
                    continue

            score = (
                (1.0 - nonwall_black_frac) * 1.5 +
                (1.0 - overlap_frac) * 4.0 +
                (1.0 - ring_overlap_frac) * 2.5 +
                (1.0 - min(1.0, float(ink_dist.mean()) / float(wall_band))) * 1.5 +
                touch_frac * 1.0 +
                ink_inside_frac * 0.5
            )

            if score > best_score:
                best_score = score
                best = (x, y, nonwall_black_frac, ink_inside_frac, overlap_frac, overlap_pixels,
                        float(ink_dist.mean()), touch_frac, ring_overlap)

    return best

def apply_add_closet_wall_only_v2(image_pil, processor, model):
    """
    Returns:
      out_pil,
      placed_box [x0,y0,x1,y1],
      template_box [x0,y0,x1,y1],
      extra_params dict,
      viz_rgb (for result_box)
    """
    img_bgr = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
    H, W = img_bgr.shape[:2]

    # DETR detections (all closets) from current image
    inputs = processor(images=image_pil, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    target_sizes = torch.tensor([image_pil.size[::-1]], device=DEVICE)
    res = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=DETR_THRESH_CLOSET)[0]
    id2label = model.config.id2label

    templates = []
    for sc, lab, bx in zip(res["scores"], res["labels"], res["boxes"]):
        if id2label[int(lab)] != TARGET_LABEL:
            continue
        b = expand_box_xyxy(bx.detach().cpu().numpy(), W, H, BOX_PAD_CLOSET)
        x0,y0,x1,y1 = map(int, b)
        crop = img_bgr[y0:y1, x0:x1].copy()
        if crop.size == 0:
            continue
        th, tw = crop.shape[:2]
        if th < MIN_TEMPLATE_SIZE or tw < MIN_TEMPLATE_SIZE:
            continue
        if tw > W * MAX_TEMPLATE_REL or th > H * MAX_TEMPLATE_REL:
            continue
        mh, ms = mask_from_crop_nonwhite(crop, FEATHER_SIGMA_CLOSET)
        templates.append({"crop": crop, "mask_hard": mh, "mask_soft": ms, "box": (x0,y0,x1,y1), "score": float(sc)})

    if not templates:
        raise RuntimeError("No closets detected for template (try lower DETR_THRESH_CLOSET).")

    canon = choose_template_highest_score(templates)
    template = canon["crop"]
    template_mask_hard = canon["mask_hard"]
    template_mask_soft = canon["mask_soft"]
    sx0,sy0,sx1,sy1 = canon["box"]

    # proportional scaling from median closets in this plan
    ws = [t["box"][2] - t["box"][0] for t in templates]
    hs = [t["box"][3] - t["box"][1] for t in templates]
    med_w, med_h = float(np.median(ws)), float(np.median(hs))
    base_h, base_w = template.shape[:2]
    target_scale = min(med_w / max(1.0, base_w), med_h / max(1.0, base_h))
    target_scale = max(0.5, min(1.2, float(target_scale)))
    SCALES = [target_scale * s for s in [0.9, 1.0, 1.1]]

    # build maps
    bin_inv_lines, thick_map = build_line_maps(img_bgr, WALL_KERNEL)
    interior_mask, _walls_closed = build_interior_from_walls(thick_map, close_gaps=13, wall_dilate=2)

    all_ink = (bin_inv_lines > 0).astype(np.uint8) * 255
    wall_ink = (thick_map > 0).astype(np.uint8) * 255
    if WALL_INK_DILATE > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*WALL_INK_DILATE+1, 2*WALL_INK_DILATE+1))
        wall_ink = cv2.dilate(wall_ink, k, iterations=1)
    nonwall_ink = cv2.bitwise_and(all_ink, cv2.bitwise_not(wall_ink))

    # search
    best = None
    chosen = {"scale": None, "band": None, "ov_pix": None, "ov_frac": None}

    for band in BANDS:
        placement_mask = build_placement_mask(interior_mask, thick_map, wall_band=band)
        if int((placement_mask > 0).sum()) < 500:
            continue

        for ov_pix in OV_PIX_LIST:
            for ov_frac in OV_FRAC_LIST:
                for sc in SCALES:
                    t_try, mh_try, ms_try = resize_template(template, template_mask_hard, template_mask_soft, sc)
                    roi_h, roi_w = t_try.shape[:2]

                    best_try = find_best_spot_wall_only(
                        interior_mask=interior_mask,
                        placement_mask=placement_mask,
                        thick_map=thick_map,
                        all_ink=all_ink,
                        nonwall_ink=nonwall_ink,
                        template_mask_hard=mh_try,
                        roi_w=roi_w, roi_h=roi_h,
                        margin=MARGIN, stride=STRIDE,
                        max_nonwall_black_frac=MAX_NONWALL_BLACK_FRAC,
                        existing_dilate=EXISTING_DILATE,
                        max_overlap_pixels=ov_pix,
                        max_overlap_frac=ov_frac,
                        wall_band=band,
                        min_ink_inside_frac=MIN_INK_INSIDE_FRAC,
                        wall_touch_max_dist=WALL_TOUCH_MAX_DIST,
                        min_wall_touch_frac=MIN_WALL_TOUCH_FRAC,
                        ring_dilate=RING_DILATE,
                        max_ring_overlap_pixels=MAX_RING_OVERLAP_PIXELS,
                        max_ring_overlap_frac=MAX_RING_OVERLAP_FRAC
                    )

                    if best_try is not None:
                        best = best_try
                        template, template_mask_hard, template_mask_soft = t_try, mh_try, ms_try
                        chosen = {"scale": float(sc), "band": int(band), "ov_pix": int(ov_pix), "ov_frac": float(ov_frac)}
                        break
                if best is not None:
                    break
            if best is not None:
                break
        if best is not None:
            break

    if best is None:
        raise RuntimeError("No placement found for add_closet_wall_only_v2.")

    px, py, nonwall_black_frac, ink_inside_frac, ov_f, ov_p, mean_dist, touch_frac, ring_overlap = best
    roi_h, roi_w = template.shape[:2]

    out = img_bgr.copy()
    out = alpha_paste(out, template, template_mask_soft, px, py)

    viz = out.copy()
    cv2.rectangle(viz, (sx0,sy0), (sx1,sy1), (255,0,0), 2)                 # template source
    cv2.rectangle(viz, (px,py), (px+roi_w, py+roi_h), (0,0,255), 2)        # placement

    placed_box = [int(px), int(py), int(px+roi_w), int(py+roi_h)]
    template_box = [int(sx0), int(sy0), int(sx1), int(sy1)]

    extra = {
        "template_score": float(canon["score"]),
        "placement": {
            "nonwall_black_frac": float(nonwall_black_frac),
            "ink_inside_frac": float(ink_inside_frac),
            "overlap_frac": float(ov_f),
            "overlap_pixels": int(ov_p),
            "mean_wall_dist": float(mean_dist),
            "touch_frac": float(touch_frac),
            "ring_overlap": int(ring_overlap),
        },
        "chosen": chosen,
    }

    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    viz_rgb = cv2.cvtColor(viz, cv2.COLOR_BGR2RGB)
    return Image.fromarray(out_rgb), placed_box, template_box, extra, viz_rgb


# =========================================================
# MAIN
# =========================================================
def main():
    # Create output folders for each change
    for ch in CHANGES:
        ensure_dir(IMAGES_AFTER_DIR / ch["group"] / ch["name"])

    # Load DETR
    print(f"Loading DETR from: {DETR_MODEL_PATH}")
    processor = DetrImageProcessor.from_pretrained(str(DETR_MODEL_PATH))
    model = DetrForObjectDetection.from_pretrained(str(DETR_MODEL_PATH)).to(DEVICE)
    model.eval()

    imgs = list_images(IMAGES_BEFORE_DIR)[:RUN_FIRST_N_IMAGES]
    print(f"Found {len(imgs)} images to process (first {RUN_FIRST_N_IMAGES}).")
    print(f"Metadata -> {META_PATH.as_posix()} (append) | run_id={RUN_ID}")

    total_tasks = len(imgs) * len(CHANGES)
    done_tasks = 0
    start_time = time.time()

    for idx, img_path in enumerate(imgs):
        try:
            image_pil = Image.open(img_path).convert("RGB")
        except Exception as e:
            # still advance tasks for this image (as failures per change)
            for ch in CHANGES:
                done_tasks += 1
                append_metadata({
                    "run_id": RUN_ID,
                    "status": "fail_read",
                    "src_image": str(img_path),
                    "change": ch["name"],
                    "group": ch["group"],
                    "error": str(e),
                    "ts": datetime.now().isoformat()
                })
            print(f"[SKIP] cannot read {img_path.name}: {e}")
            continue

        for ch in CHANGES:
            t0 = time.time()
            out_dir = IMAGES_AFTER_DIR / ch["group"] / ch["name"]
            stem = safe_stem(img_path.stem)
            out_name = f"{stem}_{idx}_synthetic{img_path.suffix.lower()}"
            out_path = out_dir / out_name

            rec = {
                "run_id": RUN_ID,
                "ts": datetime.now().isoformat(),
                "status": "fail",
                "src_image": str(img_path),
                "src_stem": stem,
                "src_index": idx,
                "change": ch["name"],
                "group": ch["group"],
                "out_image": str(out_path),
                "bbox": None,
                "extra": {},
            }

            try:
                if ch["type"] == "remove":
                    out_pil, bbox, extra = apply_remove_white(image_pil, ch["label"], processor, model)
                    rec["bbox"] = {"target": ch["label"], "box_xyxy": bbox}
                    rec["extra"] = extra

                    out_pil.save(out_path)
                    rec["status"] = "ok"

                elif ch["type"] == "replace":
                    out_pil, bbox, extra = apply_replace_with_stove(image_pil, ch["src"], processor, model)
                    rec["bbox"] = {"target": ch["src"], "box_xyxy": bbox}
                    rec["extra"] = extra

                    out_pil.save(out_path)
                    rec["status"] = "ok"

                elif ch["type"] == "add_closet_wall_only_v2":
                    out_pil, placed_box, template_box, extra, viz_rgb = apply_add_closet_wall_only_v2(image_pil, processor, model)

                    # save result image
                    out_pil.save(out_path)

                    # save result_box next to it
                    box_path = out_dir / f"{stem}_{idx}_synthetic_box.png"
                    Image.fromarray(viz_rgb).save(box_path)

                    rec["status"] = "ok"
                    rec["bbox"] = {
                        "template_box_xyxy": template_box,
                        "placed_box_xyxy": placed_box
                    }
                    rec["extra"] = extra
                    rec["out_result_box"] = str(box_path)

                else:
                    raise RuntimeError(f"Unknown change type: {ch['type']}")

            except Exception as e:
                rec["error"] = str(e)

            # metadata + progress
            rec["elapsed_sec"] = float(time.time() - t0)
            append_metadata(rec)

            done_tasks += 1
            elapsed = time.time() - start_time
            avg = elapsed / max(1, done_tasks)
            eta = avg * (total_tasks - done_tasks)

            pct = 100.0 * done_tasks / max(1, total_tasks)
            print(
                f"[{done_tasks}/{total_tasks} | {pct:5.1f}%] "
                f"elapsed={format_hhmmss(elapsed)} ETA={format_hhmmss(eta)} | "
                f"{img_path.name} -> {ch['name']} -> {rec['status']}"
            )

        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    print("\nDONE")


if __name__ == "__main__":
    main()

DEVICE: cuda
Found 50 images to process (first 50).
Metadata -> synthetic_dataset/metadata.jsonl (append) | run_id=20260113_125307
[1/450 |   0.2%] elapsed=00:00:07 ETA=00:54:16 | image_003.jpg -> remove_stove -> ok
[2/450 |   0.4%] elapsed=00:00:12 ETA=00:46:35 | image_003.jpg -> remove_sink -> ok
[3/450 |   0.7%] elapsed=00:00:12 ETA=00:31:02 | image_003.jpg -> remove_toilet -> fail
[4/450 |   0.9%] elapsed=00:00:12 ETA=00:23:16 | image_003.jpg -> replace_toilet2stove -> fail
[5/450 |   1.1%] elapsed=00:00:12 ETA=00:18:37 | image_003.jpg -> replace_sink2stove -> ok
[6/450 |   1.3%] elapsed=00:00:16 ETA=00:20:57 | image_003.jpg -> remove_1stdoor -> ok
[7/450 |   1.6%] elapsed=00:00:21 ETA=00:22:37 | image_003.jpg -> remove_2sdoor -> ok
[8/450 |   1.8%] elapsed=00:00:25 ETA=00:23:49 | image_003.jpg -> remove_closet -> ok
[9/450 |   2.0%] elapsed=00:00:29 ETA=00:23:41 | image_003.jpg -> add_random_closet -> ok
[10/450 |   2.2%] elapsed=00:00:29 ETA=00:21:28 | image_004.jpg -> remove_sto

# Synthetic Data Generation — Floorplans (Before → After)

## What this script does
Generates **synthetic “after” floorplan images** from `images_before/` by applying predefined changes (substantial / non-substantial).  
Each operation is logged into `metadata.jsonl`, and the run prints **progress + ETA + estimated total runtime**.

---

## Inputs
- Source images: `..\Data\images_before\`
- DETR model: `..\Models\detr-finetuned-floorplans\`
- (Optional) SAM checkpoint: `..\Models\Sam_Checkpoint\sam_vit_h_4b8939.pth`
- Stove patches for replacement: `..\Data\Label_Pics_for_Synthetic_Data_Generation\stove\`

---
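
## Outputs
- Generated images: `..\Data\images_after\<group>\<change_name>\`
- File name: `<source_stem>_<index>_synthetic<ext>`

Example:
`planA_12_synthetic.png`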
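
A minimal sketch of how such a name can be assembled, mirroring the `safe_stem` regex and the f-string used in `main()`. The helper name `synthetic_name` and the sample paths are hypothetical and exist only for illustration:

```python
import re
from pathlib import Path

def synthetic_name(src_path: Path, index: int) -> str:
    """Build '<stem>_<index>_synthetic<ext>' from a source image path."""
    # Replace every run of characters outside [a-zA-Z0-9_.-] with "_".
    stem = re.sub(r"[^a-zA-Z0-9_\-\.]+", "_", src_path.stem)
    # The extension is lower-cased, as in the script.
    return f"{stem}_{index}_synthetic{src_path.suffix.lower()}"

name = synthetic_name(Path("planA.PNG"), 12)
odd_name = synthetic_name(Path("my plan (v2).jpg"), 3)
```

Unsafe characters in the source stem collapse to underscores, so the output name stays usable on any filesystem.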

### Closet addition debug images
For `add_random_closet`, the script also saves:
`*_synthetic_box.png`  
(Shows **template source box** + **placement box**)

---

## Supported labels
`stove, sink, toilet, 1stdoor, 2sdoor, closet`

---

## Changes generated

### Substantial
- remove: `stove`
- remove: `sink`
- remove: `toilet`
- replace: `toilet → stove`
- replace: `sink → stove`

### Non-substantial
- remove: `1stdoor`
- remove: `2sdoor`
- remove: `closet`
- add: `closet` (wall-only v2 placement)

---

## Run control (ALL images / subset)
At the top of the script:
- `RUN_FIRST_N_IMAGES = None` → run on **all images**
- `RUN_FIRST_N_IMAGES = 50` → run on **first 50** (quick test)
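
As a rough illustration of why `None` means "all images": the script slices the image list with `[:RUN_FIRST_N_IMAGES]`, and a slice bound of `None` keeps the whole list. The helper `take_first_n` and the file names below are made up for this sketch:

```python
# Hypothetical file names, for illustration only.
images = ["image_001.jpg", "image_002.jpg", "image_003.jpg"]

def take_first_n(items, n):
    """Return the first n items, or all of them when n is None."""
    return items[:n]  # items[:None] is equivalent to items[:]

all_images = take_first_n(images, None)
subset = take_first_n(images, 2)
```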

---

## Clean re-run behavior
At the start of the run, the script:
- deletes `..\Data\images_after`
- deletes `..\Data\metadata.jsonl`

So you always get a **fresh run** without duplicates.

---

## Metadata (`metadata.jsonl`)
One JSON line per attempted operation (success or failure), including:
- `run_id`, `ts` (ISO 8601 timestamp)
- `change`, `group`
- `src_image`, `out_image`
- `status` = `ok` / `fail`
- `bbox` details (and for closet insertion: template + placement boxes)
- `elapsed_sec`
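
A small sketch of the JSON Lines round trip, assuming the key names used by the `rec` dict in `main()`. All values below are invented, and the temporary file stands in for the real `metadata.jsonl`:

```python
import json
import tempfile
from pathlib import Path

# Invented values; only the key names follow the script's records.
record = {
    "run_id": "20260113_125307",
    "ts": "2026-01-13T12:53:14",
    "status": "ok",
    "change": "remove_stove",
    "group": "substantial",
    "src_image": "image_003.jpg",
    "out_image": "image_003_0_synthetic.jpg",
    "bbox": {"target": "stove", "box_xyxy": [10, 20, 60, 80]},
    "elapsed_sec": 7.2,
}

with tempfile.TemporaryDirectory() as tmp:
    meta_path = Path(tmp) / "metadata.jsonl"
    # Append mode: each operation adds exactly one line.
    with open(meta_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    # Read back, skipping blank lines.
    with open(meta_path, "r", encoding="utf-8") as f:
        loaded = [json.loads(line) for line in f if line.strip()]
```

Because every line is an independent JSON object, a partially written run can still be parsed line by line.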

---

## End-of-run summary
After finishing, the script prints (computed from `metadata.jsonl`):
- successes/failures per change
- success rate per change
- average runtime per change
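
The summary code itself is not shown in this section, so the following is only a sketch of how such figures could be derived from the metadata records. The function name `summarize` and the sample records are assumptions, not the script's actual implementation:

```python
from collections import defaultdict

def summarize(records):
    """Aggregate per-change counts, success rate and mean elapsed time."""
    acc = defaultdict(lambda: {"ok": 0, "fail": 0, "elapsed": 0.0})
    for rec in records:
        bucket = acc[rec["change"]]
        if rec.get("status") == "ok":
            bucket["ok"] += 1
        else:
            bucket["fail"] += 1  # any non-ok status counts as a failure
        bucket["elapsed"] += float(rec.get("elapsed_sec", 0.0))

    summary = {}
    for change, b in acc.items():
        total = b["ok"] + b["fail"]  # always >= 1 for a seen change
        summary[change] = {
            "ok": b["ok"],
            "fail": b["fail"],
            "success_rate": b["ok"] / total,
            "avg_sec": b["elapsed"] / total,
        }
    return summary

# Invented sample records, for illustration only.
sample = [
    {"change": "remove_stove", "status": "ok", "elapsed_sec": 4.0},
    {"change": "remove_stove", "status": "fail", "elapsed_sec": 2.0},
    {"change": "remove_sink", "status": "ok", "elapsed_sec": 5.0},
]
report = summarize(sample)
```

In practice the records would come from iterating over `metadata.jsonl`, for example via the `iter_metadata` generator defined in the script.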


In [None]:
import os, gc, random, re, json, time, shutil
from pathlib import Path
from datetime import datetime

import numpy as np
import cv2
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection

# Optional SAM (better masks). If not installed / checkpoint missing -> fallback to bbox mask
USE_SAM = True
try:
    from segment_anything import sam_model_registry, SamPredictor
except Exception:
    USE_SAM = False


# =========================================================
# CONFIGURATION
# =========================================================
# Paths relative to the 'Code' folder
BASE_DATA_DIR = Path(r"..\Data")
MODELS_DIR    = Path(r"..\Models")

IMAGES_BEFORE_DIR = BASE_DATA_DIR / "images_before"
IMAGES_AFTER_DIR  = BASE_DATA_DIR / "images_after"
META_PATH         = BASE_DATA_DIR / "metadata.jsonl"

DETR_MODEL_PATH = MODELS_DIR / "detr-finetuned-floorplans"
SAM_CHECKPOINT  = MODELS_DIR / "Sam_Checkpoint" / "sam_vit_h_4b8939.pth"

# Stove patch library used by the replace operations
STOVE_PATCH_DIR = BASE_DATA_DIR / "Label_Pics_for_Synthetic_Data_Generation" / "stove"

# Run control: None = run all, Integer = run first N images
RUN_FIRST_N_IMAGES = None 

DETR_THRESH = 0.25
BOX_PAD = 2

# Mask feathering for remove
FEATHER_SIGMA_REMOVE = 0.8

# Replacement patch mask
INK_THRESH = 180
PATCH_DILATE = 0
FEATHER_SIGMA_PATCH = 0.6
MIN_PATCH_SIZE = 64

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using Device: {DEVICE}")

RUN_ID = datetime.now().strftime("%Y%m%d_%H%M%S")


# =========================================================
# CHANGES CONFIGURATION
# =========================================================
CHANGES = [
    # Substantial changes
    {"name": "remove_stove",         "group": "substantial",     "type": "remove",  "label": "stove"},
    {"name": "remove_sink",          "group": "substantial",     "type": "remove",  "label": "sink"},
    {"name": "remove_toilet",        "group": "substantial",     "type": "remove",  "label": "toilet"},
    {"name": "replace_toilet2stove", "group": "substantial",     "type": "replace", "src": "toilet", "dst": "stove"},
    {"name": "replace_sink2stove",   "group": "substantial",     "type": "replace", "src": "sink",   "dst": "stove"},

    # Non-substantial changes
    {"name": "remove_1stdoor",       "group": "non-substantial", "type": "remove",  "label": "1stdoor"},
    {"name": "remove_2sdoor",        "group": "non-substantial", "type": "remove",  "label": "2sdoor"},
    {"name": "remove_closet",        "group": "non-substantial", "type": "remove",  "label": "closet"},

    # Add closet via wall-only logic
    {"name": "add_random_closet",    "group": "non-substantial", "type": "add_closet_wall_only_v2"},
]


# =========================================================
# UTILITIES
# =========================================================
def list_images(folder: Path):
    exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
    if not folder.exists():
        print(f"Warning: Folder not found: {folder}")
        return []
    files = [p for p in folder.iterdir() if p.suffix.lower() in exts]
    files.sort()
    return files

def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def safe_stem(s: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_\-\.]+", "_", s)

def expand_box_xyxy(box, W, H, pad):
    x0, y0, x1, y1 = box
    x0 = max(0, int(x0) - pad)
    y0 = max(0, int(y0) - pad)
    x1 = min(W - 1, int(x1) + pad)
    y1 = min(H - 1, int(y1) + pad)
    return [x0, y0, x1, y1]

def format_hhmmss(seconds: float) -> str:
    seconds = max(0, int(seconds))
    h = seconds // 3600
    m = (seconds % 3600) // 60
    s = seconds % 60
    return f"{h:02d}:{m:02d}:{s:02d}"

def append_metadata(record: dict):
    ensure_dir(META_PATH.parent)
    with open(META_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

def wipe_outputs_and_metadata():
    """
    Clears images_after/* and metadata.jsonl to avoid duplicates on re-run.
    """
    print("Cleaning up previous outputs...")
    if IMAGES_AFTER_DIR.exists():
        shutil.rmtree(IMAGES_AFTER_DIR, ignore_errors=True)
    ensure_dir(IMAGES_AFTER_DIR)

    if META_PATH.exists():
        META_PATH.unlink()

def iter_metadata(path: Path):
    if not path.exists():
        return
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines instead of aborting the whole read.
                continue


# =========================================================
# DETR HELPERS
# =========================================================
def run_detr_get_best_box(image_pil, target_label: str, processor, model, thresh=0.25):
    """
    Return (best_box, best_score) for the given label, where best_box is
    [x0, y0, x1, y1]. If nothing is detected, returns (None, -1.0).
    """
    inputs = processor(images=image_pil, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)

    W, H = image_pil.size
    target_sizes = torch.tensor([[H, W]], device=DEVICE)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=thresh)[0]

    id2label = model.config.id2label
    best = None
    best_score = -1.0

    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label_name = id2label[int(label_id)].lower().strip()
        if label_name == target_label.lower().strip():
            sc = float(score)
            if sc > best_score:
                best_score = sc
                b = box.detach().cpu().numpy().tolist()
                best = [int(round(b[0])), int(round(b[1])), int(round(b[2])), int(round(b[3]))]

    return best, best_score


# =========================================================
# SAM HELPERS
# =========================================================
def box_mask(H, W, box):
    x0, y0, x1, y1 = box
    m = np.zeros((H, W), dtype=np.uint8)
    m[y0:y1+1, x0:x1+1] = 255
    return m

def feather_mask(mask_uint8, sigma=1.0):
    if sigma <= 0:
        return mask_uint8
    k = int(max(3, round(sigma * 6)))
    if k % 2 == 0:
        k += 1
    return cv2.GaussianBlur(mask_uint8, (k, k), sigmaX=sigma, sigmaY=sigma)

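def build_sam_predictor_if_possible(image_np_rgb):
    # Returns a SamPredictor bound to this image, or None if SAM is unavailable.
    # Note: the checkpoint is loaded on every call.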
    if not USE_SAM:
        return None
    if not SAM_CHECKPOINT.exists():
        return None
    try:
        sam = sam_model_registry["vit_h"](checkpoint=str(SAM_CHECKPOINT))
        sam.to(device=DEVICE)
        predictor = SamPredictor(sam)
        predictor.set_image(image_np_rgb, image_format="RGB")
        return predictor
    except Exception:
        return None

def sam_mask_from_box(predictor, box_xyxy):
    try:
        x0,y0,x1,y1 = box_xyxy
        box = np.array([x0,y0,x1,y1], dtype=np.float32)
        masks, _, _ = predictor.predict(box=box, multimask_output=False)
        return (masks[0].astype(np.uint8) * 255)
    except Exception:
        return None


# =========================================================
# LOGIC: REMOVE (fill white)
# =========================================================
def apply_remove_white(image_pil, target_label, processor, model):
    box, sc = run_detr_get_best_box(image_pil, target_label, processor, model, thresh=DETR_THRESH)
    if box is None:
        raise RuntimeError(f"No detection for label={target_label}")

    W, H = image_pil.size
    box = expand_box_xyxy(box, W, H, BOX_PAD)

    img_np = np.array(image_pil.convert("RGB"))
    predictor = build_sam_predictor_if_possible(img_np)

    if predictor is not None:
        m = sam_mask_from_box(predictor, box)
        if m is None:
            m = box_mask(H, W, box)
    else:
        m = box_mask(H, W, box)

    m = feather_mask(m, sigma=FEATHER_SIGMA_REMOVE)

    out = img_np.copy()
    alpha = (m.astype(np.float32) / 255.0)[..., None]
    white = np.full_like(out, 255)
    out = (out.astype(np.float32) * (1.0 - alpha) + white.astype(np.float32) * alpha).astype(np.uint8)

    if predictor is not None:
        # Release the SAM model before the next task to limit GPU memory use.
        del predictor
        gc.collect()
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    return Image.fromarray(out), box, {"score": sc}


# =========================================================
# LOGIC: REPLACE (sink/toilet -> stove)
# =========================================================
def make_ink_mask(rgb_uint8, ink_thresh=180, dilate=0):
    gray = cv2.cvtColor(rgb_uint8, cv2.COLOR_RGB2GRAY)
    m = (gray < ink_thresh).astype(np.uint8) * 255
    if dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*dilate+1, 2*dilate+1))
        m = cv2.dilate(m, k, iterations=1)
    return m

def tight_crop_to_mask(rgb, mask, margin=2):
    ys, xs = np.where(mask > 0)
    if len(xs) == 0 or len(ys) == 0:
        return rgb, mask
    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    x0 = max(0, x0 - margin); y0 = max(0, y0 - margin)
    x1 = min(rgb.shape[1]-1, x1 + margin); y1 = min(rgb.shape[0]-1, y1 + margin)
    return rgb[y0:y1+1, x0:x1+1].copy(), mask[y0:y1+1, x0:x1+1].copy()

def resize_keep_aspect(rgb, mask, target_w, target_h):
    h, w = rgb.shape[:2]
    if w <= 0 or h <= 0:
        return rgb, mask
    scale = min(target_w / w, target_h / h)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    rgb_r = cv2.resize(rgb, (nw, nh), interpolation=cv2.INTER_AREA)
    mask_r = cv2.resize(mask, (nw, nh), interpolation=cv2.INTER_NEAREST)
    return rgb_r, mask_r

def pick_random_patch(folder: Path, min_size=64):
    if not folder.exists():
        raise RuntimeError(f"Patch folder not found: {folder}")
    files = list_images(folder)
    if not files:
        raise RuntimeError(f"No patches in: {folder}")
    random.shuffle(files)
    for f in files:
        im = Image.open(f).convert("RGB")
        if min(im.size) >= min_size:
            return im
    return Image.open(files[0]).convert("RGB")

def paste_centered(base_rgb, patch_rgb, patch_mask, box_xyxy, feather_sigma=0.6):
    x0,y0,x1,y1 = box_xyxy
    bw = max(1, x1-x0+1)
    bh = max(1, y1-y0+1)

    pr, mr = resize_keep_aspect(patch_rgb, patch_mask, bw, bh)
    m_soft = feather_mask(mr, sigma=feather_sigma)
    alpha = (m_soft.astype(np.float32) / 255.0)[..., None]

    out = base_rgb.copy()
    ph, pw = pr.shape[:2]
    cx = x0 + bw//2
    cy = y0 + bh//2
    px0 = int(cx - pw//2)
    py0 = int(cy - ph//2)

    H,W = out.shape[:2]
    rx0 = max(0, px0); ry0 = max(0, py0)
    rx1 = min(W, px0+pw); ry1 = min(H, py0+ph)
    if rx1 <= rx0 or ry1 <= ry0:
        return out

    sx0 = rx0 - px0; sy0 = ry0 - py0
    sx1 = sx0 + (rx1-rx0); sy1 = sy0 + (ry1-ry0)

    roi = out[ry0:ry1, rx0:rx1].astype(np.float32)
    pr_roi = pr[sy0:sy1, sx0:sx1].astype(np.float32)
    a_roi = alpha[sy0:sy1, sx0:sx1].astype(np.float32)

    out[ry0:ry1, rx0:rx1] = (roi*(1-a_roi) + pr_roi*a_roi).astype(np.uint8)
    return out

def apply_replace_with_stove(image_pil, source_label, processor, model):
    box, sc = run_detr_get_best_box(image_pil, source_label, processor, model, thresh=DETR_THRESH)
    if box is None:
        raise RuntimeError(f"No detection for label={source_label}")

    W,H = image_pil.size
    box = expand_box_xyxy(box, W, H, BOX_PAD)

    base_rgb = np.array(image_pil.convert("RGB"))
    base_rgb_removed = base_rgb.copy()
    x0,y0,x1,y1 = box
    base_rgb_removed[y0:y1+1, x0:x1+1] = 255

    stove_pil = pick_random_patch(STOVE_PATCH_DIR, min_size=MIN_PATCH_SIZE)
    stove_rgb = np.array(stove_pil)
    m = make_ink_mask(stove_rgb, ink_thresh=INK_THRESH, dilate=PATCH_DILATE)
    stove_rgb_c, m_c = tight_crop_to_mask(stove_rgb, m, margin=2)

    out_rgb = paste_centered(
        base_rgb_removed,
        stove_rgb_c,
        m_c,
        box,
        feather_sigma=FEATHER_SIGMA_PATCH
    )

    # Note: "patch_src" records the patch size (w, h); pick_random_patch does not return the file path
    return Image.fromarray(out_rgb), box, {"score": sc, "patch_src": str(stove_pil.size)}


# =========================================================
# LOGIC: ADD CLOSET (WALL-ONLY v2)
# =========================================================
TARGET_LABEL = "closet"
DETR_THRESH_CLOSET = 0.20

WALL_KERNEL = 11
MARGIN = 10
STRIDE = 4
MIN_INK_INSIDE_FRAC = 0.99
MAX_NONWALL_BLACK_FRAC = 0.06
EXISTING_DILATE = 3
MAX_OVERLAP_PIXELS_BASE = 10
MAX_OVERLAP_FRAC_BASE   = 0.03
RING_DILATE = 10
MAX_RING_OVERLAP_PIXELS = 25
MAX_RING_OVERLAP_FRAC   = 0.03
FEATHER_SIGMA_CLOSET = 0.6
MIN_TEMPLATE_SIZE = 18
MAX_TEMPLATE_REL = 0.45
BOX_PAD_CLOSET = 2
BANDS = [10, 14, 18, 24, 30, 40, 55]
OV_PIX_LIST  = [MAX_OVERLAP_PIXELS_BASE, 16, 24, 36]
OV_FRAC_LIST = [MAX_OVERLAP_FRAC_BASE,   0.04, 0.06, 0.08]
WALL_TOUCH_MAX_DIST = 3.0
MIN_WALL_TOUCH_FRAC = 0.03
THICK_MIN_AREA       = 2500
THICK_MIN_LONG_SIDE  = 160
THICK_MIN_SHORT_SIDE = 6
WALL_INK_DILATE = 3


def filter_thick_map_components(thick_map, min_area=1500, min_long_side=120, min_short_side=6):
    """Keep only connected components large enough to be walls; returns a 0/255 uint8 map."""
    bw = (thick_map > 0).astype(np.uint8)
    num, labels, stats, _ = cv2.connectedComponentsWithStats(bw, connectivity=8)
    out = np.zeros_like(bw)
    for i in range(1, num):
        x, y, w, h, area = stats[i]
        long_side = max(w, h)
        short_side = min(w, h)
        if area < min_area:
            continue
        if long_side < min_long_side:
            continue
        if short_side < min_short_side:
            continue
        out[labels == i] = 1
    return (out * 255).astype(np.uint8)

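def build_line_maps(img_bgr, wall_kernel):
    """Return (bin_inv, thick_clean): the all-ink map and the wall-only map, both 0/255 uint8."""
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # Pixels with gray value <= 235 count as ink (255 in the inverted binary map)
    _, bin_inv = cv2.threshold(gray, 235, 255, cv2.THRESH_BINARY_INV)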

    k = cv2.getStructuringElement(cv2.MORPH_RECT, (wall_kernel, wall_kernel))
    thick = cv2.morphologyEx(bin_inv, cv2.MORPH_CLOSE, k, iterations=1)
    thick = cv2.erode(thick, np.ones((3,3), np.uint8), iterations=1)

    thick_clean = filter_thick_map_components(
        thick,
        min_area=THICK_MIN_AREA,
        min_long_side=THICK_MIN_LONG_SIDE,
        min_short_side=THICK_MIN_SHORT_SIDE
    )
    return bin_inv, thick_clean

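def build_interior_from_walls(thick_map, close_gaps=13, wall_dilate=2):
    """Return (interior, walls_closed): free space enclosed by walls and the gap-closed wall map, both 0/255 uint8."""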
    H, W = thick_map.shape[:2]
    walls = (thick_map > 0).astype(np.uint8) * 255

    if wall_dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_RECT, (2*wall_dilate+1, 2*wall_dilate+1))
        walls = cv2.dilate(walls, k, iterations=1)

    k2 = cv2.getStructuringElement(cv2.MORPH_RECT, (close_gaps, close_gaps))
    walls_closed = cv2.morphologyEx(walls, cv2.MORPH_CLOSE, k2, iterations=1)

    free = cv2.bitwise_not(walls_closed)
    ff = free.copy()
    mask = np.zeros((H+2, W+2), dtype=np.uint8)

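    # Flood-fill from the four image corners to erase the exterior free space;
    # whatever free space remains (255) is enclosed by walls, i.e. interior.
    for seed in [(0,0), (W-1,0), (0,H-1), (W-1,H-1)]:
        if ff[seed[1], seed[0]] == 255:
            cv2.floodFill(ff, mask, seedPoint=seed, newVal=0)

    interior = ff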
    return interior, walls_closed

def build_placement_mask(interior_mask, thick_map, wall_band):
    thick_bin = (thick_map > 0).astype(np.uint8)
    dist = cv2.distanceTransform(1 - thick_bin, cv2.DIST_L2, 3)
    near_wall = (dist <= wall_band).astype(np.uint8) * 255
    placement = cv2.bitwise_and(interior_mask, near_wall)
    return placement

def mask_from_crop_nonwhite(crop_bgr, feather_sigma):
    gray = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2GRAY)
    hard = cv2.inRange(gray, 0, 245)
    hard = cv2.morphologyEx(hard, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1)
    soft = hard.copy()
    if feather_sigma and feather_sigma > 0:
        soft = cv2.GaussianBlur(soft, (0,0), sigmaX=float(feather_sigma))
    return hard, soft

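def make_ring_mask(mh_uint8, ring_dilate):
    """Return a binary (0/1) uint8 ring of about ring_dilate px surrounding the template ink."""
    mh = (mh_uint8 > 0).astype(np.uint8)
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*ring_dilate+1, 2*ring_dilate+1))
    dil = cv2.dilate(mh, k, iterations=1)
    # cv2.subtract saturates at 0 for uint8, so no wrap-around can occur
    # (the old `ring[ring < 0] = 0` clamp was a no-op on an unsigned array)
    ring = cv2.subtract(dil, mh)
    return ring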

def find_best_spot_wall_only(
    interior_mask, placement_mask,
    thick_map,
    all_ink, nonwall_ink,
    template_mask_hard, roi_w, roi_h,
    margin, stride,
    max_nonwall_black_frac,
    existing_dilate,
    max_overlap_pixels, max_overlap_frac,
    wall_band,
    min_ink_inside_frac,
    wall_touch_max_dist, min_wall_touch_frac,
    ring_dilate, max_ring_overlap_pixels, max_ring_overlap_frac
):
    H, W = interior_mask.shape[:2]

    if existing_dilate and existing_dilate > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*existing_dilate+1, 2*existing_dilate+1))
        nonwall_ink_dil = cv2.dilate(nonwall_ink, k, iterations=1)
    else:
        nonwall_ink_dil = nonwall_ink

    thick_bin = (thick_map > 0).astype(np.uint8)
    dist_to_wall = cv2.distanceTransform(1 - thick_bin, cv2.DIST_L2, 3)

    mh = (template_mask_hard > 0).astype(np.uint8)
    mh_sum = int(mh.sum())
    if mh_sum == 0:
        return None

    ring = make_ring_mask(template_mask_hard, ring_dilate)
    ring_sum = int(ring.sum())  # make_ring_mask always returns an array, never None

    best = None
    best_score = -1e18

    for y in range(margin, H - roi_h - margin, stride):
        for x in range(margin, W - roi_w - margin, stride):

            cx = x + roi_w // 2
            cy = y + roi_h // 2
            if placement_mask[cy, cx] == 0:
                continue

            roi_nonwall = nonwall_ink[y:y+roi_h, x:x+roi_w]
            nonwall_black_frac = float((roi_nonwall > 0).mean())
            if nonwall_black_frac > max_nonwall_black_frac:
                continue

            roi_int = (interior_mask[y:y+roi_h, x:x+roi_w] > 0).astype(np.uint8)
            inside_on_ink = roi_int[mh > 0]
            if inside_on_ink.size == 0:
                continue
            ink_inside_frac = float(inside_on_ink.mean())
            if ink_inside_frac < min_ink_inside_frac:
                continue

            roi_dist = dist_to_wall[y:y+roi_h, x:x+roi_w]
            ink_dist = roi_dist[mh > 0]
            if ink_dist.size == 0:
                continue
            if float(ink_dist.mean()) > float(wall_band):
                continue
            touch_frac = float((ink_dist <= wall_touch_max_dist).mean())
            if touch_frac < min_wall_touch_frac:
                continue

            roi_exist = (nonwall_ink_dil[y:y+roi_h, x:x+roi_w] > 0).astype(np.uint8)
            overlap_pixels = int((roi_exist * mh).sum())
            if overlap_pixels > max_overlap_pixels:
                continue
            overlap_frac = overlap_pixels / float(mh_sum)
            if overlap_frac > max_overlap_frac:
                continue

            ring_overlap = 0
            ring_overlap_frac = 0.0
            if ring_sum > 0:
                ring_overlap = int((roi_exist * ring).sum())
                if ring_overlap > max_ring_overlap_pixels:
                    continue
                ring_overlap_frac = ring_overlap / float(ring_sum)
                if ring_overlap_frac > max_ring_overlap_frac:
                    continue

            score = (
                (1.0 - nonwall_black_frac) * 1.5 +
                (1.0 - overlap_frac) * 4.0 +
                (1.0 - ring_overlap_frac) * 2.5 +
                (1.0 - min(1.0, float(ink_dist.mean()) / float(wall_band))) * 1.5 +
                touch_frac * 1.0 +
                ink_inside_frac * 0.5
            )

            if score > best_score:
                best_score = score
                best = (x, y, nonwall_black_frac, ink_inside_frac, overlap_frac, overlap_pixels,
                        float(ink_dist.mean()), touch_frac, ring_overlap)

    return best

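def apply_add_closet_wall_only_v2(image_pil, processor, model):
    """
    Paste a copy of an existing closet next to a wall, inside the plan interior.

    Returns:
      out_pil: modified image (PIL, RGB),
      placed_box [x0,y0,x1,y1]: where the new closet was pasted,
      template_box [x0,y0,x1,y1]: the detected closet used as the template,
      extra_params dict: template score, placement metrics, chosen search parameters,
      viz_rgb: RGB array with both boxes drawn (saved as the *_box.png result)
    """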
    img_bgr = cv2.cvtColor(np.array(image_pil.convert("RGB")), cv2.COLOR_RGB2BGR)
    H, W = img_bgr.shape[:2]

    # DETR detections (all closets) from current image
    inputs = processor(images=image_pil, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    target_sizes = torch.tensor([image_pil.size[::-1]], device=DEVICE)
    res = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=DETR_THRESH_CLOSET)[0]
    id2label = model.config.id2label

    templates = []
    for sc, lab, bx in zip(res["scores"], res["labels"], res["boxes"]):
        if id2label[int(lab)] != TARGET_LABEL:
            continue
        b = expand_box_xyxy(bx.detach().cpu().numpy(), W, H, BOX_PAD_CLOSET)
        x0,y0,x1,y1 = map(int, b)
        crop = img_bgr[y0:y1, x0:x1].copy()
        if crop.size == 0:
            continue
        th, tw = crop.shape[:2]
        if th < MIN_TEMPLATE_SIZE or tw < MIN_TEMPLATE_SIZE:
            continue
        if tw > W * MAX_TEMPLATE_REL or th > H * MAX_TEMPLATE_REL:
            continue
        mh, ms = mask_from_crop_nonwhite(crop, FEATHER_SIGMA_CLOSET)
        templates.append({"crop": crop, "mask_hard": mh, "mask_soft": ms, "box": (x0,y0,x1,y1), "score": float(sc)})

    if not templates:
        raise RuntimeError("No closets detected for template (try lower DETR_THRESH_CLOSET).")

    canon = choose_template_highest_score(templates)
    template = canon["crop"]
    template_mask_hard = canon["mask_hard"]
    template_mask_soft = canon["mask_soft"]
    sx0,sy0,sx1,sy1 = canon["box"]

    # Scale the template toward the median closet size in this plan, clamped to [0.5, 1.2]
    ws = [t["box"][2] - t["box"][0] for t in templates]
    hs = [t["box"][3] - t["box"][1] for t in templates]
    med_w, med_h = float(np.median(ws)), float(np.median(hs))
    base_h, base_w = template.shape[:2]
    target_scale = min(med_w / max(1.0, base_w), med_h / max(1.0, base_h))
    target_scale = max(0.5, min(1.2, float(target_scale)))
    SCALES = [target_scale * s for s in [0.9, 1.0, 1.1]]

    # build maps
    bin_inv_lines, thick_map = build_line_maps(img_bgr, WALL_KERNEL)
    interior_mask, _walls_closed = build_interior_from_walls(thick_map, close_gaps=13, wall_dilate=2)

    all_ink = (bin_inv_lines > 0).astype(np.uint8) * 255
    wall_ink = (thick_map > 0).astype(np.uint8) * 255
    if WALL_INK_DILATE > 0:
        k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*WALL_INK_DILATE+1, 2*WALL_INK_DILATE+1))
        wall_ink = cv2.dilate(wall_ink, k, iterations=1)
    nonwall_ink = cv2.bitwise_and(all_ink, cv2.bitwise_not(wall_ink))

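    # Search: try scales first, then relax the overlap limits, then widen the wall band;
    # stop at the first combination that yields a valid placement.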
    best = None
    chosen = {"scale": None, "band": None, "ov_pix": None, "ov_frac": None}

    for band in BANDS:
        placement_mask = build_placement_mask(interior_mask, thick_map, wall_band=band)
        if int((placement_mask > 0).sum()) < 500:
            continue

        for ov_pix in OV_PIX_LIST:
            for ov_frac in OV_FRAC_LIST:
                for sc in SCALES:
                    t_try, mh_try, ms_try = resize_template(template, template_mask_hard, template_mask_soft, sc)
                    roi_h, roi_w = t_try.shape[:2]

                    best_try = find_best_spot_wall_only(
                        interior_mask=interior_mask,
                        placement_mask=placement_mask,
                        thick_map=thick_map,
                        all_ink=all_ink,
                        nonwall_ink=nonwall_ink,
                        template_mask_hard=mh_try,
                        roi_w=roi_w, roi_h=roi_h,
                        margin=MARGIN, stride=STRIDE,
                        max_nonwall_black_frac=MAX_NONWALL_BLACK_FRAC,
                        existing_dilate=EXISTING_DILATE,
                        max_overlap_pixels=ov_pix,
                        max_overlap_frac=ov_frac,
                        wall_band=band,
                        min_ink_inside_frac=MIN_INK_INSIDE_FRAC,
                        wall_touch_max_dist=WALL_TOUCH_MAX_DIST,
                        min_wall_touch_frac=MIN_WALL_TOUCH_FRAC,
                        ring_dilate=RING_DILATE,
                        max_ring_overlap_pixels=MAX_RING_OVERLAP_PIXELS,
                        max_ring_overlap_frac=MAX_RING_OVERLAP_FRAC
                    )

                    if best_try is not None:
                        best = best_try
                        template, template_mask_hard, template_mask_soft = t_try, mh_try, ms_try
                        chosen = {"scale": float(sc), "band": int(band), "ov_pix": int(ov_pix), "ov_frac": float(ov_frac)}
                        break
                if best is not None:
                    break
            if best is not None:
                break
        if best is not None:
            break

    if best is None:
        raise RuntimeError("No placement found for add_closet_wall_only_v2.")

    px, py, nonwall_black_frac, ink_inside_frac, ov_f, ov_p, mean_dist, touch_frac, ring_overlap = best
    roi_h, roi_w = template.shape[:2]

    out = img_bgr.copy()
    out = alpha_paste(out, template, template_mask_soft, px, py)

    viz = out.copy()
    # Colors are BGR here: template box in blue, placed box in red
    cv2.rectangle(viz, (sx0,sy0), (sx1,sy1), (255,0,0), 2)
    cv2.rectangle(viz, (px,py), (px+roi_w, py+roi_h), (0,0,255), 2)

    placed_box = [int(px), int(py), int(px+roi_w), int(py+roi_h)]
    template_box = [int(sx0), int(sy0), int(sx1), int(sy1)]

    extra = {
        "template_score": float(canon["score"]),
        "placement": {
            "nonwall_black_frac": float(nonwall_black_frac),
            "ink_inside_frac": float(ink_inside_frac),
            "overlap_frac": float(ov_f),
            "overlap_pixels": int(ov_p),
            "mean_wall_dist": float(mean_dist),
            "touch_frac": float(touch_frac),
            "ring_overlap": int(ring_overlap),
        },
        "chosen": chosen,
    }

    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    viz_rgb = cv2.cvtColor(viz, cv2.COLOR_BGR2RGB)
    return Image.fromarray(out_rgb), placed_box, template_box, extra, viz_rgb


# =========================================================
# SUMMARY (from metadata.jsonl)
# =========================================================
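def print_summary_from_metadata(meta_path: Path, changes: list):
    """Print per-change ok/fail counts and mean elapsed time; any status other than "ok" counts as a failure."""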
    change_names = [c["name"] for c in changes]

    stats = {
        name: {"ok": 0, "fail": 0, "time_sum": 0.0, "time_count": 0}
        for name in change_names
    }

    total_ok = 0
    total_fail = 0

    for rec in iter_metadata(meta_path):
        ch = rec.get("change")
        if ch not in stats:
            continue

        status = rec.get("status", "fail")
        if status == "ok":
            stats[ch]["ok"] += 1
            total_ok += 1
        else:
            stats[ch]["fail"] += 1
            total_fail += 1

        t = rec.get("elapsed_sec", None)
        if isinstance(t, (int, float)):
            stats[ch]["time_sum"] += float(t)
            stats[ch]["time_count"] += 1

    print("\n" + "="*70)
    print("SUMMARY (from metadata.jsonl)")
    print("="*70)

    grand_total = total_ok + total_fail
    print(f"Total records: {grand_total} | OK: {total_ok} | FAIL: {total_fail}")
    print("-"*70)

    for name in change_names:
        ok = stats[name]["ok"]
        fail = stats[name]["fail"]
        tot = ok + fail
        ok_pct = (100.0 * ok / tot) if tot > 0 else 0.0
        avg_t = (stats[name]["time_sum"] / stats[name]["time_count"]) if stats[name]["time_count"] > 0 else 0.0
        print(f"{name:24s} | ok={ok:5d} | fail={fail:5d} | ok%={ok_pct:6.2f}% | avg_time={avg_t:6.3f}s")

    print("="*70 + "\n")


# =========================================================
# MAIN
# =========================================================
def main():
    # Clear previous outputs + metadata to avoid duplicates
    wipe_outputs_and_metadata()

    # Re-create output folders
    for ch in CHANGES:
        ensure_dir(IMAGES_AFTER_DIR / ch["group"] / ch["name"])

    # Load DETR
    print(f"Loading DETR from: {DETR_MODEL_PATH}")
    processor = DetrImageProcessor.from_pretrained(str(DETR_MODEL_PATH))
    model = DetrForObjectDetection.from_pretrained(str(DETR_MODEL_PATH)).to(DEVICE)
    model.eval()

    imgs_all = list_images(IMAGES_BEFORE_DIR)
    if RUN_FIRST_N_IMAGES is None:
        imgs = imgs_all
    else:
        imgs = imgs_all[:int(RUN_FIRST_N_IMAGES)]

    print(f"Found {len(imgs_all)} total images.")
    print(f"Will process {len(imgs)} images. (RUN_FIRST_N_IMAGES={RUN_FIRST_N_IMAGES})")
    print(f"Metadata -> {META_PATH.as_posix()} (fresh) | run_id={RUN_ID}")

    total_tasks = len(imgs) * len(CHANGES)
    done_tasks = 0
    start_time = time.time()

    # No upfront estimate is computed; the rolling ETA printed per task becomes meaningful after a few tasks
    print("\nRough ETA will stabilize after some tasks...\n")

    for idx, img_path in enumerate(imgs):
        try:
            image_pil = Image.open(img_path).convert("RGB")
        except Exception as e:
            for ch in CHANGES:
                done_tasks += 1
                append_metadata({
                    "run_id": RUN_ID,
                    "status": "fail_read",
                    "src_image": str(img_path),
                    "change": ch["name"],
                    "group": ch["group"],
                    "error": str(e),
                    "ts": datetime.now().isoformat(),
                    "elapsed_sec": 0.0
                })
            print(f"[SKIP] cannot read {img_path.name}: {e}")
            continue

        for ch in CHANGES:
            t0 = time.time()

            out_dir = IMAGES_AFTER_DIR / ch["group"] / ch["name"]
            stem = safe_stem(img_path.stem)
            out_name = f"{stem}_{idx}_synthetic{img_path.suffix.lower()}"
            out_path = out_dir / out_name

            rec = {
                "run_id": RUN_ID,
                "ts": datetime.now().isoformat(),
                "status": "fail",
                "src_image": str(img_path),
                "src_stem": stem,
                "src_index": idx,
                "change": ch["name"],
                "group": ch["group"],
                "out_image": str(out_path),
                "bbox": None,
                "extra": {},
            }

            try:
                if ch["type"] == "remove":
                    out_pil, bbox, extra = apply_remove_white(image_pil, ch["label"], processor, model)
                    rec["bbox"] = {"target": ch["label"], "box_xyxy": bbox}
                    rec["extra"] = extra
                    out_pil.save(out_path)
                    rec["status"] = "ok"

                elif ch["type"] == "replace":
                    out_pil, bbox, extra = apply_replace_with_stove(image_pil, ch["src"], processor, model)
                    rec["bbox"] = {"target": ch["src"], "box_xyxy": bbox}
                    rec["extra"] = extra
                    out_pil.save(out_path)
                    rec["status"] = "ok"

                elif ch["type"] == "add_closet_wall_only_v2":
                    out_pil, placed_box, template_box, extra, viz_rgb = apply_add_closet_wall_only_v2(image_pil, processor, model)

                    out_pil.save(out_path)

                    box_path = out_dir / f"{stem}_{idx}_synthetic_box.png"
                    Image.fromarray(viz_rgb).save(box_path)

                    rec["status"] = "ok"
                    rec["bbox"] = {
                        "template_box_xyxy": template_box,
                        "placed_box_xyxy": placed_box
                    }
                    rec["extra"] = extra
                    rec["out_result_box"] = str(box_path)

                else:
                    raise RuntimeError(f"Unknown change type: {ch['type']}")

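            except Exception as e:
                # Any failure (no detection, no placement, I/O error) keeps status="fail";
                # record the reason and continue with the next change
                rec["error"] = str(e)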

            rec["elapsed_sec"] = float(time.time() - t0)
            append_metadata(rec)

            done_tasks += 1

            # Progress + rolling ETA
            elapsed = time.time() - start_time
            avg = elapsed / max(1, done_tasks)
            eta = avg * (total_tasks - done_tasks)

            pct = 100.0 * done_tasks / max(1, total_tasks)
            # estimated total runtime based on current average
            est_total = avg * total_tasks

            print(
                f"[{done_tasks}/{total_tasks} | {pct:5.1f}%] "
                f"elapsed={format_hhmmss(elapsed)} ETA={format_hhmmss(eta)} "
                f"(est_total~{format_hhmmss(est_total)}) | "
                f"{img_path.name} -> {ch['name']} -> {rec['status']}"
            )

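        # Free Python and CUDA memory after each image.
        # str() makes the check work whether DEVICE is a string or a torch.device
        gc.collect()
        if str(DEVICE).startswith("cuda"):
            torch.cuda.empty_cache()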

    print("\nDONE")
    print_summary_from_metadata(META_PATH, CHANGES)


if __name__ == "__main__":
    main()

DEVICE: cuda




Found 455 total images.
Will process 455 images. (RUN_FIRST_N_IMAGES=None)
Metadata -> synthetic_dataset/metadata.jsonl (fresh) | run_id=20260113_143558

Rough ETA will stabilize after some tasks...

[1/4095 |   0.0%] elapsed=00:00:06 ETA=06:49:48 (est_total~06:49:54) | image_003.jpg -> remove_stove -> ok
[2/4095 |   0.0%] elapsed=00:00:11 ETA=06:23:42 (est_total~06:23:53) | image_003.jpg -> remove_sink -> ok
[3/4095 |   0.1%] elapsed=00:00:11 ETA=04:16:17 (est_total~04:16:28) | image_003.jpg -> remove_toilet -> fail
[4/4095 |   0.1%] elapsed=00:00:11 ETA=03:12:33 (est_total~03:12:44) | image_003.jpg -> replace_toilet2stove -> fail
[5/4095 |   0.1%] elapsed=00:00:11 ETA=02:34:24 (est_total~02:34:35) | image_003.jpg -> replace_sink2stove -> ok
[6/4095 |   0.1%] elapsed=00:00:15 ETA=03:01:12 (est_total~03:01:28) | image_003.jpg -> remove_1stdoor -> ok
[7/4095 |   0.2%] elapsed=00:00:20 ETA=03:19:46 (est_total~03:20:06) | image_003.jpg -> remove_2sdoor -> ok
[8/4095 |   0.2%] elapsed=00:0