#CONFIG

In [None]:
# === CONFIG ===
# Notebook-wide settings. Values carrying a trailing `# @param {...}` tag are
# rendered as Colab form fields; keep the tag on the same line when editing.

# Queued targets that REQUIRE the detailer (already pre-filtered by your generator)
DETAILER_QUEUE_FOLDER = "/content/drive/MyDrive/SS_OUTPUT_FOLDER/v1-5"  # @param {type:"string"}

# Where to save (1) the source garment and (2) the inpainted results
TARGET_DIR = "/content/drive/MyDrive/DETAILER_DONE/vtnon_v1_5"               # @param {type:"string"}

# Root that contains your SKU trees (used to locate the source)
WORKING_DIR = "/content/drive/MyDrive/SikSilk"                  # @param {type:"string"}

# Root for (subcategory-wide) garment masks to constrain the detector
# NOTE(review): the UTILS cell assigns MASKS_ROOT again with a hard-coded
# literal, so a value edited here is overwritten once that cell runs.
MASKS_ROOT = "/content/drive/MyDrive/SKSLK_MODELS"              # @param {type:"string"}

# Model/runtime knobs
# DEVICE_STR is only honoured when CUDA is available (the SETUP cell falls back to "cpu").
DEVICE_STR = "cuda"
INPAINT_GENEROUS_PAD = 150                                      # @param {type:"integer"}
INPAINT_TINY_PAD = 6                                            # @param {type:"integer"}
INPAINT_SEED = 2025                                             # @param {type:"integer"}
VISUALIZE = True                                                # @param {type:"boolean"}

# File patterns
# Both lower- and upper-case spellings are listed explicitly.
VALID_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".PNG", ".JPG", ".JPEG", ".WEBP")

# Allowed detail tokens (normalized)
ALLOWED_DETAIL_TYPES = ["crest", "logo", "patch", "waist text", "sleeve text"]

# Allowed garment tokens (for prompting DINO)
# NOTE(review): "shirts" is listed twice below. The tokens here are plural
# ("shirts", "vests", "joggers") while TOP_GARMENTS / BOTTOM_GARMENTS use
# "shirt", "vest" and "jogger-trousers" — confirm which spelling the consumers
# of these lists expect before relying on membership tests across them.
ALLOWED_GARMENT_TYPES = [
    "hoodie","jeans","joggers","shorts","sweater","swimwear","t-shirt","shirts",
    "track top","trousers","twinset","polo","vests","shirts"
]
TOP_GARMENTS = ["t-shirt", "shirt", "sweater", "hoodie", "track top", "vest"]
BOTTOM_GARMENTS = ["shorts", "jogger-trousers", "trousers", "jeans", "swimwear"]
TWINSET_TYPES = ["twinset"]

# Angle parsing / source lookup helpers
# Back-view codes ("bc_*") are currently commented out of the list.
BASE_NAMES = ["fr_rght", "fr_lft", "fr_cl",
              #"bc_rght", "bc_lft", "bc_cl", "bc",
              "fr", "lft", "rght"]
ACCEPTABLE_SUFFIXES = ["cut"]

# Skip if already have any inpainted output in TARGET_DIR for this SKU+angle
SKIP_IF_ALREADY_INPAINTED = True # @param {type:"boolean"}

# NOTE(review): not referenced anywhere in the portion of the notebook reviewed
# here — confirm it is consumed further down.
USE_BF16_INFERENCE = True  # global toggle


# Create target dir if missing
# NOTE(review): TARGET_DIR lives under /content/drive, but drive.mount() is
# only called in a later cell. If this cell runs first, the directory tree is
# created on the local disk under the future mountpoint — confirm the mount
# still succeeds in that order, or run the mount cell before this one.
import os, pathlib
pathlib.Path(TARGET_DIR).mkdir(parents=True, exist_ok=True)


#INSTALLS (restart & reinstall again after this)

In [None]:
# Mount Google Drive at the prefix that every path in the CONFIG cell uses.
# NOTE(review): the CONFIG cell calls mkdir() on a /content/drive/... path; if
# it ran before this mount, the mountpoint may already contain files — confirm
# drive.mount() accepts that.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# SAM3 via Hugging Face transformers
# Installed from the repository's default branch (no tag or commit is pinned).
!pip install -q "git+https://github.com/huggingface/transformers.git"


In [None]:
# GroundingDINO setup removed — SAM3 now handles detection + segmentation in one model.


In [None]:
# Supporting packages. `transformers` is named again here without a version.
# NOTE(review): confirm pip keeps the git install from the cell above rather
# than replacing it.
%pip -q install open_clip_torch ninja wheel transformers accelerate \
                 sentencepiece protobuf huggingface_hub opencv-python
# diffusers is force-reinstalled from the main branch without its dependencies.
!pip install -U --no-deps --force-reinstall "git+https://github.com/huggingface/diffusers.git@main"
#%pip -q install 'git+https://github.com/facebookresearch/detectron2.git'
# open_clip_torch was already installed above; this upgrades it to the latest release.
!pip install --upgrade open_clip_torch

In [None]:
# piexif is imported by the SETUP cell.
!pip -q install piexif

In [None]:
# Shallow-clone insert-anything into /content; a later cell does
# `%cd /content/insert-anything` and imports `utils.utils` from it.
%cd /content/
!git clone --depth 1 https://github.com/song-wensong/insert-anything.git

In [None]:
# The wheel file name targets torch 2.6 / CPython 3.12 / linux x86_64, and torch
# is pinned only on the following line, i.e. after the wheel is installed.
# NOTE(review): confirm the runtime's Python version matches cp312 and that the
# install order does not matter for this wheel.
!pip install https://huggingface.co/mit-han-lab/nunchaku/resolve/main/nunchaku-0.2.0+torch2.6-cp312-cp312-linux_x86_64.whl
!pip install torch==2.6 torchvision==0.21 torchaudio==2.6
# NOTE(review): diffusers/transformers are requested again without versions —
# confirm the git installs from the earlier cells are left in place.
!pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
# NOTE(review): this clones into the current directory, whereas the model cell
# loads the LoRA from /content/drive/MyDrive/insert-anything-lora/... — confirm
# which copy is meant to be used.
!git clone https://huggingface.co/aha2023/insert-anything-lora-for-nunchaku

#SETUP

In [None]:
import os, sys, torch, numpy as np, cv2, base64, gc, json
from pathlib import Path
from io import BytesIO
from PIL import Image, ImageOps
import piexif

# Use the configured DEVICE_STR only when CUDA is available; otherwise fall back to CPU.
device = torch.device(DEVICE_STR if torch.cuda.is_available() else "cpu")
print("✅ Torch device:", device)


In [None]:
import torch
from transformers import Sam3Processor, Sam3Model

# Hugging Face model id used for both the processor and the model below.
HF_SAM3_ID = "facebook/sam3"
# Default score threshold of _sam3_predict_text. It is bound as a default
# argument when that helper is defined, so changing it afterwards requires
# re-running the cell that defines the helper.
SAM3_CONFIDENCE = 0.05   # permissive to catch small logos; raise if predictions get noisy
# NOTE(review): not referenced by the helpers reviewed here — confirm whether
# it is still needed.
SAM3_RESOLUTION = 1024
# Captures the `device` chosen in the SETUP cell at the time this cell runs;
# later re-assignments of `device` do not change SAM3_DEVICE.
SAM3_DEVICE = device

# Load processor + weights, move the model to SAM3_DEVICE and switch to eval mode.
sam3_processor = Sam3Processor.from_pretrained(HF_SAM3_ID)
sam3_model = Sam3Model.from_pretrained(HF_SAM3_ID).to(SAM3_DEVICE)
sam3_model.eval()
print("✅ HF SAM3 ready (text → boxes → masks)")


In [None]:
# SAM3 replaces the old SAM2 predictor — no extra setup needed.


In [None]:
# SAM3 initialized above.


In [None]:
#@title Insert_anything on nunchaku
%cd /content/insert-anything
from PIL import Image
import torch
import os
import numpy as np
import cv2
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, expand_image_mask
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from datetime import datetime

# Global PyTorch performance switches: allow TF32 for matmul/cuDNN and enable
# cuDNN autotuning.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True

# NOTE: this rebinds the notebook-level `device` to CUDA unconditionally,
# replacing the CPU-fallback value chosen in the SETUP cell (SAM3_DEVICE keeps
# whatever it captured earlier).
device = torch.device(f"cuda")
dtype = torch.bfloat16
size = (1024, 1024)



# Load the pre-trained model and LoRA-for-nunchaku weights
# Please replace the paths with your own paths
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")

# FLUX Fill pipeline built around the nunchaku transformer loaded above.
pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    transformer=transformer,
    torch_dtype=dtype
)

# The LoRA file is read from Google Drive, so Drive must be mounted first.
transformer.update_lora_params("/content/drive/MyDrive/insert-anything-lora/insert-anything_extracted_lora_rank_256-bf16.safetensors")
# Adjust the LoRA strength
transformer.set_lora_strength(1)

# Redux prior pipeline, cast to the same dtype as the fill pipeline.
redux = FluxPriorReduxPipeline.from_pretrained("black-forest-labs/FLUX.1-Redux-dev").to(dtype=dtype)



# Move both pipelines to the GPU. No offloading is configured here.
# NOTE(review): an earlier comment on these lines described a memory-saving
# mode (~26GB); nothing in this cell enables one.
pipe.to("cuda")
redux.to("cuda")
# NOTE(review): confirm which library actually reads this environment variable.
os.environ["NNCF_GROUP_SIZE"] = "-1"      # disable token merging


# UTILS

In [None]:
import matplotlib.pyplot as plt

In [None]:
def open_upright(path) -> Image.Image:
    """Open ``path``, apply its EXIF orientation and return the result as RGB."""
    with Image.open(path) as handle:
        upright = ImageOps.exif_transpose(handle)
        rgb = upright.convert("RGB")
    return rgb

def open_source_with_black_bg(path: str) -> Image.Image:
    """
    Open a source garment image upright and return it as RGB.

    Files whose name contains ``_cut`` and that carry an alpha channel
    (modes RGBA / LA) are composited onto a black background, so transparent
    pixels become (0, 0, 0). Every other image is simply converted to RGB.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been copied into the returned image.
    """
    with Image.open(path) as raw:
        im = ImageOps.exif_transpose(raw)
        name_low = os.path.basename(path).lower()
        if "_cut" in name_low and im.mode in ("RGBA", "LA"):
            rgb = im.convert("RGB")
            alpha = im.getchannel("A")
            black = Image.new("RGB", im.size, (0, 0, 0))
            # Where alpha is 0 the black canvas shows through.
            return Image.composite(rgb, black, alpha)
        return im.convert("RGB")


# Root of subcategory-wide garment masks.
# Reuse the value chosen in the CONFIG cell when it is already defined, so that
# editing the form field there is not silently undone by this cell; the literal
# is only a fallback for running this cell on its own.
MASKS_ROOT = globals().get("MASKS_ROOT", '/content/drive/MyDrive/SKSLK_MODELS')
# Extensions probed, in this order, when looking for a mask file.
MASK_EXTS = ('.png', '.jpg', '.jpeg', '.webp', '.PNG', '.JPG', '.JPEG', '.WEBP')




import re
from pathlib import Path
from PIL import Image, ImageOps, ImageDraw

# --- Helper: get <Category>/<Subcategory> from the *source* path ------------
_SKU_DIR_RE = re.compile(r"SS-\d{3,7}", re.IGNORECASE)

def _category_subcategory_from_source(src_path: str) -> tuple[str, str] | None:
    """
    Resolve (Category, Subcategory) from the garment *source* path.
    Preferred: relative to WORKING_DIR → parts[0], parts[1].
    Fallback: find the SKU folder in the path and take the two parents.
    Returns None if not resolvable.
    """
    p = Path(src_path).resolve()
    wr = Path(WORKING_DIR).resolve()

    # Preferred: relative to WORKING_DIR
    try:
        rel = p.relative_to(wr)
        parts = rel.parts
        # Expect: Category/Subcategory/SKU/<file>
        if len(parts) >= 3:
            return parts[0], parts[1]
    except Exception:
        pass

    # Fallback: locate the SKU dir and take its two parents as Cat/Subcat
    parts = p.parts
    sku_idx = None
    for i, part in enumerate(parts):
        if _SKU_DIR_RE.fullmatch(part or ""):
            sku_idx = i
            break
    if sku_idx is not None and sku_idx >= 2:
        return parts[sku_idx - 2], parts[sku_idx - 1]

    # Last resort: try after an explicit 'SikSilk' anchor
    if "SikSilk" in parts:
        j = parts.index("SikSilk")
        if len(parts) >= j + 3:
            return parts[j + 1], parts[j + 2]

    return None

# --- New: derive mask basename from *angle*, not from filename heuristics ----
def _mask_basename_from_angle(angle_code: str | None) -> str | None:
    """
    Map 'fr' -> 'fr_mask', 'fr_lft' -> 'fr_lft_mask', 'bc_cl' -> 'bc_cl_mask', etc.
    If angle_code is missing, return None (→ no mask).
    """
    if not angle_code:
        return None
    angle = angle_code.strip().lower()
    return f"{angle}_mask"

# --- Exact-only mask finder ---------------------------------------------------

def find_mask_for_generated_exact(gen_path: str, source_path: str) -> Path | None:
    """
    Locate the garment mask for a queued image using exact file names only.

    Steps:
      * the angle is parsed from the queued path via
        ``extract_sku_and_angle_from_path`` and normalised (strip + lower);
      * (category, subcategory) is derived from ``source_path``;
      * inside MASKS_ROOT / category / subcategory the names
        ``<angle>_mask_agnostic`` and then ``<angle>_mask`` are probed with
        every extension in MASK_EXTS, and the first existing file wins.

    Returns the mask path, or None (after printing a warning) when the angle,
    the category pair or the file cannot be resolved.
    """
    # 1) angle from the queued filename (the SKU half of the pair is unused here)
    _sku, angle = extract_sku_and_angle_from_path(gen_path)
    if not angle:
        print("⚠️  No angle parsed — proceeding without a mask.")
        return None
    angle = angle.strip().lower()

    # 2) category/subcategory from the source path
    resolved = _category_subcategory_from_source(source_path)
    if not resolved:
        print("⚠️  Could not resolve Category/Subcategory from source path — no mask.")
        return None
    category, subcategory = resolved
    mask_dir = Path(MASKS_ROOT) / category / subcategory

    # 3) agnostic stem first, then the regular one; extensions in MASK_EXTS order
    candidates = [f"{angle}_mask_agnostic", f"{angle}_mask"]
    probes = ((stem, mask_dir / f"{stem}{ext}") for stem in candidates for ext in MASK_EXTS)
    for stem, cand in probes:
        if not cand.exists():
            continue
        which = "regular"
        if stem.endswith("_agnostic"):
            which = "agnostic"
        print(f"✅ Found {which} mask: {cand}")
        return cand

    print(f"⚠️  No exact mask found in {mask_dir} for angle '{angle}' "
          f"(tried {candidates} with MASK_EXTS). Proceeding without mask.")
    return None

def load_binary_mask_for_generated(gen_path: str, source_path: str, gen_img: Image.Image) -> np.ndarray | None:
    """
    Find the mask file for a queued image and return it aligned to ``gen_img``.

    Returns None when no mask file is found; otherwise the array produced by
    ``align_mask_to_image`` for the EXIF-corrected mask.
    """
    mask_path = find_mask_for_generated_exact(gen_path, source_path)
    if mask_path is not None:
        with Image.open(mask_path) as raw_mask:
            upright_mask = ImageOps.exif_transpose(raw_mask)
            return align_mask_to_image(upright_mask, gen_img)
    return None

# --- Align (unchanged) --------------------------------------------------------
def align_mask_to_image(mask_img: Image.Image, target_img: Image.Image) -> np.ndarray:
    """
    Fit ``mask_img`` onto the pixel grid of ``target_img`` and binarise it.

    Two layouts are handled:
      * same height and the target is exactly 2x or 3x as wide as the mask:
        the mask is repeated side by side across the target width;
      * anything else: the mask is scaled (nearest neighbour, aspect ratio
        kept) to fit inside the target and centred on a zero canvas.

    Returns an (H, W) uint8 array containing only 0 and 255 (threshold: > 127).
    A mask with a zero dimension yields an all-zero array in the second layout.
    """
    src_w, src_h = mask_img.size
    dst_w, dst_h = target_img.size

    # The src_w > 0 check guards the modulo / integer division below.
    same_height_multiple = src_h == dst_h and src_w > 0 and dst_w % src_w == 0
    copies = dst_w // src_w if same_height_multiple else 0

    if 1 < copies <= 3:
        strip = Image.new('L', (dst_w, dst_h), 0)
        gray = mask_img.convert('L')
        for offset_x in range(0, copies * src_w, src_w):
            strip.paste(gray, (offset_x, 0))
        grid = np.array(strip, dtype=np.uint8)
    else:
        if src_w == 0 or src_h == 0:
            return np.zeros((dst_h, dst_w), np.uint8)
        # Largest shrink factor so that both sides fit inside the target.
        shrink = max(src_w / dst_w, src_h / dst_h)
        fit_w = int(round(src_w / shrink))
        fit_h = int(round(src_h / shrink))
        fitted = mask_img.convert('L').resize((fit_w, fit_h), Image.NEAREST)
        left = (dst_w - fit_w) // 2
        top = (dst_h - fit_h) // 2
        grid = np.zeros((dst_h, dst_w), np.uint8)
        grid[top:top + fit_h, left:left + fit_w] = np.array(fitted, dtype=np.uint8)

    return np.where(grid > 127, 255, 0).astype(np.uint8)






# ---- visuals
def _draw_bbox(img: Image.Image, bb_xyxy, color="lime", width=4):
    """Return a copy of ``img`` with ``bb_xyxy`` outlined; a None box leaves the copy untouched."""
    canvas = img.copy()
    if bb_xyxy is not None:
        ImageDraw.Draw(canvas).rectangle(bb_xyxy, outline=color, width=width)
    return canvas

def _show_images(pairs, cols=3, figsize=(16,12)):
    """
    Show ``(title, image)`` pairs in a grid with ``cols`` columns.

    Unused grid cells are blanked. An empty ``pairs`` is a no-op: without the
    guard the grid would have zero rows and figure creation would fail.
    """
    if not pairs:
        return
    rows = int(np.ceil(len(pairs) / cols))
    fig, axes = plt.subplots(rows, cols, figsize=figsize)
    # plt.subplots returns a single Axes (not an array) for a 1x1 grid.
    axes = axes.flatten() if rows*cols>1 else [axes]
    for ax,(title,img) in zip(axes, pairs):
        ax.imshow(img); ax.set_title(title, fontsize=10); ax.axis("off")
    # Hide the leftover cells of the last row.
    for ax in axes[len(pairs):]: ax.axis("off")
    plt.tight_layout(); plt.show()

def resize_and_pad(image, target_size=1024):
    """
    Scale ``image`` so its longer side equals ``target_size`` (LANCZOS), then
    pad it, centred, to a ``target_size`` x ``target_size`` canvas.

    The padding value is zero in the image's own mode: 0 for single-channel
    modes, transparent black for RGBA, black otherwise.
    """
    src_w, src_h = image.size
    factor = target_size / max(1, max(src_w, src_h))
    fit_w = int(round(src_w * factor))
    fit_h = int(round(src_h * factor))
    fitted = image.resize((fit_w, fit_h), Image.LANCZOS)

    # Split the slack so any odd pixel goes to the right / bottom edge.
    left = (target_size - fit_w) // 2
    top = (target_size - fit_h) // 2
    border = (left, top, target_size - fit_w - left, target_size - fit_h - top)

    # The fill value has to match the number of channels of the image mode.
    fill_by_mode = {
        "L": 0,
        "1": 0,
        "I": 0,
        "F": 0,
        "RGBA": (0, 0, 0, 0),
    }
    fill_color = fill_by_mode.get(fitted.mode, (0, 0, 0))

    return ImageOps.expand(fitted, border, fill=fill_color)

def box_1024_to_original(box_xyxy_1024, original_w, original_h, target_size=1024):
    """
    Map an xyxy box from the padded square produced by ``resize_and_pad`` back
    to the coordinates of the original image.

    Args:
        box_xyxy_1024: [x1, y1, x2, y2] on the padded square canvas.
        original_w, original_h: size of the original image in pixels.
        target_size: side of the padded canvas (default 1024, matching the
            default of ``resize_and_pad``).

    Returns:
        [x1, y1, x2, y2] as ints, clamped to [0, original_w] x [0, original_h].

    The scale uses the same ``max(1, ...)`` guard as ``resize_and_pad`` so a
    zero-sized image yields a zero box instead of a ZeroDivisionError.
    """
    x1_p, y1_p, x2_p, y2_p = (float(v) for v in box_xyxy_1024)
    w, h = original_w, original_h
    scale = target_size / max(1, max(w, h))
    new_w, new_h = int(round(w * scale)), int(round(h * scale))
    pad_w = (target_size - new_w) // 2
    pad_h = (target_size - new_h) // 2

    def _unmap(value, pad, limit):
        # Undo the padding, undo the scaling, then clamp into the image.
        return min(max(int(round((value - pad) / scale)), 0), limit)

    return [
        _unmap(x1_p, pad_w, w),
        _unmap(y1_p, pad_h, h),
        _unmap(x2_p, pad_w, w),
        _unmap(y2_p, pad_h, h),
    ]


In [None]:
def apply_binary_mask(img_rgb: Image.Image, mask_np: np.ndarray | None, outside_color=(5,5,5)) -> Image.Image:
    """
    Composite ``img_rgb`` over a flat ``outside_color`` canvas using ``mask_np``
    as the compositing mask.

    A None mask returns the image unchanged. Images in modes other than
    RGB / RGBA / L are converted to RGB first. ``outside_color`` may be an int
    or a tuple and is adapted to the image mode (grey -> RGB triple, RGB ->
    opaque RGBA, tuple -> mean grey level for L).
    """
    if mask_np is None:
        return img_rgb

    alpha = Image.fromarray(mask_np.astype(np.uint8))
    if img_rgb.mode not in ("RGB", "RGBA", "L"):
        img_rgb = img_rgb.convert("RGB")
    mode = img_rgb.mode

    fill = outside_color
    if mode == "L":
        if isinstance(fill, tuple):
            fill = int(np.mean(fill))
        fill = int(fill)
    elif mode == "RGBA":
        if isinstance(fill, int):
            fill = (fill,) * 3 + (255,)
        elif len(fill) == 3:
            fill = (*fill, 255)
    elif isinstance(fill, int):
        # RGB image with a single grey level.
        fill = (fill,) * 3

    background = Image.new(mode, img_rgb.size, fill)
    return Image.composite(img_rgb, background, alpha)


# Dynamic, perimeter-based mask gating for detect_detail (logo-friendly)
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw


# --- SAM3 helpers -----------------------------------------------------------
def _clip_box_to_image(box_xyxy, w: int, h: int):
    """Round an xyxy box to ints and clamp x into [0, w] and y into [0, h]."""
    x1, y1, x2, y2 = box_xyxy
    bounded = ((x1, w), (y1, h), (x2, w), (y2, h))
    return [max(0, min(limit, int(round(value)))) for value, limit in bounded]


def _sam3_predict_text(image_pil: Image.Image, prompt: str, *, max_dets: int = 12, score_threshold: float = SAM3_CONFIDENCE):
    """
    Run the HF SAM3 model on ``image_pil`` with a text ``prompt``.

    Returns a list of dicts sorted by descending score, at most ``max_dets``
    long, each with:
      "box"   - box as a plain list (as returned by the post-processor),
      "score" - float,
      "mask"  - uint8 array of 0/1 (mask values > 0.5), or None when the
                post-processor returned no masks.

    An empty prompt or an empty detection set yields []. Raises RuntimeError
    when the SAM3 globals have not been created yet.
    """
    if not prompt:
        return []
    if any(name not in globals() for name in ("sam3_processor", "sam3_model")):
        raise RuntimeError("SAM3 is not initialized. Run the SAM3 setup cell first.")

    encoded = sam3_processor(images=image_pil, text=prompt, return_tensors="pt")
    # Only tensors are moved to the model device; other entries pass through.
    batch = {}
    for key, value in encoded.items():
        batch[key] = value.to(SAM3_DEVICE) if isinstance(value, torch.Tensor) else value

    with torch.inference_mode():
        raw_outputs = sam3_model(**batch)

    # PIL reports (W, H); the post-processor is given (H, W).
    result = sam3_processor.post_process_instance_segmentation(
        raw_outputs,
        threshold=score_threshold,
        target_sizes=[image_pil.size[::-1]],
    )[0]

    boxes = result.get("boxes")
    scores = result.get("scores")
    masks = result.get("masks")
    if boxes is None or scores is None or boxes.numel() == 0:
        return []

    boxes_np = boxes.detach().cpu().numpy()
    scores_np = scores.detach().cpu().numpy()
    masks_np = None if masks is None else masks.detach().cpu().numpy()

    def _binary_mask(i):
        if masks_np is None:
            return None
        return (masks_np[i] > 0.5).astype(np.uint8)

    best_first = scores_np.argsort()[::-1][:max_dets]
    return [
        {
            "box": boxes_np[i].tolist(),
            "score": float(scores_np[i]),
            "mask": _binary_mask(i),
        }
        for i in best_first
    ]

def _mask_crop_to_full(mask_crop: np.ndarray | None, crop_box_on_full, full_size):
    """
    Place a mask (aligned to a crop image) back onto the full canvas.
    crop_box_on_full = (lx, ty, rx, by) used to produce the crop.
    full_size = (W, H) of the destination image.
    """
    if mask_crop is None or crop_box_on_full is None:
        return mask_crop

    full_w, full_h = full_size
    lx, ty, rx, by = [int(round(v)) for v in crop_box_on_full]
    x0, y0 = max(0, lx), max(0, ty)
    x1, y1 = min(rx, full_w), min(by, full_h)
    if x1 <= x0 or y1 <= y0:
        return np.zeros((full_h, full_w), np.uint8)

    mx0, my0 = max(0, -lx), max(0, -ty)
    mx1, my1 = mx0 + (x1 - x0), my0 + (y1 - y0)

    patch = mask_crop[my0:my1, mx0:mx1]
    if patch.shape[1] != (x1 - x0) or patch.shape[0] != (y1 - y0):
        patch = cv2.resize(patch, (x1 - x0, y1 - y0), interpolation=cv2.INTER_NEAREST)

    full_mask = np.zeros((full_h, full_w), np.uint8)
    full_mask[y0:y1, x0:x1] = (patch > 0).astype(np.uint8) * 255
    return full_mask

def enlarge_mask(mask_np: np.ndarray, scale: float = 1.05) -> np.ndarray:
    """
    Grow the foreground of a mask in proportion to its own size.

    The dilation radius is ``max(height, width)`` of the foreground bounding
    box times ``scale - 1`` (at least one pixel), applied with an elliptical
    kernel. Returns a uint8 0/255 mask; None passes through. An empty mask or
    ``scale <= 1`` returns the binarised input without dilation.
    """
    if mask_np is None:
        return None

    foreground = (mask_np > 0).astype(np.uint8)
    rows, cols = np.nonzero(foreground)
    if rows.size == 0 or scale <= 1.0:
        return foreground * 255

    extent_h = rows.max() - rows.min() + 1
    extent_w = cols.max() - cols.min() + 1
    radius = max(1, int(round(max(extent_h, extent_w) * (scale - 1.0))))
    side = radius * 2 + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    grown = cv2.dilate(foreground, kernel, iterations=1)
    return (grown > 0).astype(np.uint8) * 255






# Spatially guided detect_detail:
# - perimeter-based mask polarity (as before)
# - strict+tolerant mask gates
# - spatial prior from source garment (normalized center+area)
# - combined score = w_score * norm_score + w_spatial * spatial_affinity


# ---- helpers ---------------------------------------------------------------
# Top-7 spatial re-ranking for detail detection
# - Keeps perimeter-based mask polarity & light gates
# - Takes SAM3's highest-scoring proposals, filters with mask gates, then
#   re-ranks TOP_K using spatial prior from the SOURCE garment
# - Normalizes SAM3 scores within those K and slightly down-weights rank-1

import matplotlib.pyplot as plt

# ---------- helpers you already use elsewhere ----------
def make_spatial_prior_from_box(bb_xyxy, img_size):
    """
    Describe a detail box as a spatial prior normalised to the image.

    The box is clamped to the image; the result holds the normalised centre
    ("cx", "cy") and the box area as a fraction of the image area ("area").
    Returns None for a missing box or one that is empty after clamping.
    """
    if bb_xyxy is None:
        return None

    W, H = img_size

    def _clamp(value, upper):
        return max(0.0, min(upper, value))

    raw_x1, raw_y1, raw_x2, raw_y2 = [float(v) for v in bb_xyxy]
    x1, y1 = _clamp(raw_x1, W), _clamp(raw_y1, H)
    x2, y2 = _clamp(raw_x2, W), _clamp(raw_y2, H)
    if x2 <= x1 or y2 <= y1:
        return None

    center_x = ((x1 + x2) / 2.0) / max(1.0, W)
    center_y = ((y1 + y2) / 2.0) / max(1.0, H)
    area_frac = ((x2 - x1) * (y2 - y1)) / max(1.0, (W * H))
    return {"cx": float(center_x), "cy": float(center_y), "area": float(area_frac)}

def _spatial_affinity(cx_n, cy_n, area_n, prior, mirror_ok=True,
                      sigma_center=0.16, sigma_area=0.50):
    """
    Score in [0, 1] for how well a candidate matches a spatial prior.

    The score is the product of a Gaussian on the centre distance and a
    Gaussian on the log area ratio. With ``mirror_ok`` the prior is also tried
    horizontally mirrored (cx -> 1 - cx) and the better score is returned.
    """
    # The area term does not depend on which centre is tested.
    log_ratio = np.log(max(1e-6, area_n) / max(1e-6, prior["area"]))
    area_term = np.exp(-(log_ratio ** 2) / (2.0 * (sigma_area ** 2)))
    dy_sq = (cy_n - prior["cy"]) ** 2

    def _score(ref_cx):
        dist_sq = (cx_n - ref_cx) ** 2 + dy_sq
        center_term = np.exp(-dist_sq / (2.0 * (sigma_center ** 2)))
        return float(center_term * area_term)

    direct = _score(prior["cx"])
    if not mirror_ok:
        return direct
    return max(direct, _score(1.0 - prior["cx"]))

# ---------- main: top-7 re-ranking ----------


def _ensure_mask_for_image(mask_input, image_pil, *, crop_box_on_full=None):
    """
    Align a mask to image_pil.

    mask_input:
      • np.ndarray aligned to image_pil (H×W)  OR
      • (mask_full_np, "FULL") + crop_box_on_full=(lx,ty,rx,by) from crop_to_square

    Returns: uint8 mask aligned to image_pil.size (None when mask_input is
    None). For the FULL form the result is 0/255 and any part of the crop box
    that lies outside the full mask is zero-padded. For the aligned form a
    non-uint8 array is binarised to 0/255, while a uint8 array keeps its values.

    Raises AssertionError when the FULL form is used without crop_box_on_full.
    """
    if mask_input is None:
        return None

    is_full_form = (
        isinstance(mask_input, tuple)
        and len(mask_input) == 2
        and isinstance(mask_input[0], np.ndarray)
        and mask_input[1] == "FULL"
    )

    # Case 1: already aligned to this image
    if not is_full_form:
        m = mask_input
        if m.ndim == 3:
            # Keep the first channel only.
            m = m[..., 0] if m.shape[2] > 1 else m.squeeze(-1)
        if m.dtype != np.uint8:
            m = (m > 0).astype(np.uint8) * 255
        if (m.shape[1], m.shape[0]) != image_pil.size:
            m = cv2.resize(m, image_pil.size, interpolation=cv2.INTER_NEAREST)
        return m

    # Case 2: FULL mask + crop box from crop_to_square
    mask_full, _ = mask_input
    assert crop_box_on_full is not None, "crop_box_on_full is required for FULL mask."

    Hf, Wf = mask_full.shape[:2]
    lx, ty, rx, by = crop_box_on_full

    # Target canvas: the size of the crop box
    tgt_w = int(round(rx - lx))
    tgt_h = int(round(by - ty))

    # Source window (clamped to the full image bounds)
    sx1 = int(np.floor(max(0, lx)))
    sy1 = int(np.floor(max(0, ty)))
    sx2 = int(np.ceil(min(Wf, rx)))
    sy2 = int(np.ceil(min(Hf, by)))

    # Offsets where the source window lands on the target canvas
    dx = int(np.floor(max(0, -lx)))
    dy = int(np.floor(max(0, -ty)))

    # Build canvas and paste the clipped region at (dx,dy)
    canvas = np.zeros((max(0, tgt_h), max(0, tgt_w)), dtype=np.uint8)
    if sx2 > sx1 and sy2 > sy1:
        patch = mask_full[sy1:sy2, sx1:sx2]
        if patch.ndim == 3:
            patch = patch[..., 0] if patch.shape[2] > 1 else patch.squeeze(-1)
        # With fractional crop boxes the floor/ceil window can be a pixel
        # larger than the rounded canvas; clip the patch so the paste always
        # fits instead of raising a broadcast error.
        ph = min(patch.shape[0], canvas.shape[0] - dy)
        pw = min(patch.shape[1], canvas.shape[1] - dx)
        if ph > 0 and pw > 0:
            canvas[dy:dy + ph, dx:dx + pw] = (patch[:ph, :pw] > 0).astype(np.uint8) * 255

    # If image_pil size differs by a pixel due to rounding, align by resize
    if (canvas.shape[1], canvas.shape[0]) != image_pil.size:
        canvas = cv2.resize(canvas, image_pil.size, interpolation=cv2.INTER_NEAREST)

    return canvas


def _build_inside_mask_1024(mask_aligned_np, image_pil, *,
                            border_sample_px=2, erode_px=1, dilate_px=2,
                            debug=False):
    """
    Build a 1024×1024 boolean INSIDE mask with perimeter-based polarity.

    The mask is padded with the SAME resize_and_pad routine as the image so
    both stay geometrically aligned. Whichever value dominates the outer
    ``border_sample_px`` pixels of the padded mask is treated as background
    and the opposite value as "inside". The result is then eroded by
    ``erode_px`` and dilated by ``dilate_px`` (elliptical kernels, one
    iteration each).

    ``image_pil`` is accepted for signature compatibility; it is not read
    here. Returns None when ``mask_aligned_np`` is None.
    """
    if mask_aligned_np is None:
        return None

    # 1) pad the mask to 1024 with the SAME routine as the image
    mL = Image.fromarray(mask_aligned_np, mode="L")
    m1024L = resize_and_pad(mL, target_size=1024).convert("L")
    m1024 = (np.array(m1024L) > 0)

    # 2) Perimeter-majority: which value dominates the border?
    h, w = m1024.shape
    b = max(1, int(border_sample_px))
    perim = np.concatenate([m1024[0:b,:].ravel(), m1024[h-b:h,:].ravel(),
                            m1024[:,0:b].ravel(), m1024[:,w-b:w].ravel()])
    ones = int(perim.sum()); zeros = int(perim.size - perim.sum())
    background_is_true = (ones >= zeros)   # majority on border = background
    inside = (~m1024) if background_is_true else m1024

    # 3) Morphology for robust gating.
    # The third positional parameter of cv2.erode / cv2.dilate is `dst`, not
    # the iteration count, so the count is passed by keyword.
    if erode_px > 0:
        k_e = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px*2+1, erode_px*2+1))
        inside = cv2.erode(inside.astype(np.uint8), k_e, iterations=1).astype(bool)
    if dilate_px > 0:
        k_d = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_px*2+1, dilate_px*2+1))
        inside = cv2.dilate(inside.astype(np.uint8), k_d, iterations=1).astype(bool)

    if debug:
        bg_txt = "white/True" if background_is_true else "black/False"
        cov = float(inside.mean())
        print(f"[mask1024] perimeter True={ones} False={zeros} → background={bg_txt}; inside_cov={cov:.3f}")

    return inside


def detect_detail_topk7(image_pil: Image.Image,
                        detail_type: str,
                        *,
                        source_prior: dict | None,
                        restrict_mask,                 # EITHER aligned np.ndarray OR (mask_full_np, "FULL")
                        crop_box_on_full=None,         # required if restrict_mask is ("FULL")
                        threshold: float = 0.05,
                        TOP_K: int = 7,
                        mirror_ok: bool = True,
                        # light mask gates
                        min_inside_frac: float = 0.30,
                        center_must_be_inside: bool = True,
                        erode_px: int = 1,
                        dilate_px: int = 2,
                        border_sample_px: int = 2,
                        # scoring weights
                        w_spatial: float = 0.65,
                        w_score: float = 0.35,
                        rank_weights: list[float] = None,
                        debug: bool = False,
                        viz: bool = False,
                        viz_overlay_mask: bool = True):
    """
    Pick a detail detection by re-ranking the top SAM3 proposals.

    Proposals are taken in descending SAM3 score. Each one must pass the mask
    gates (when a restrict mask is given): at least ``min_inside_frac`` of the
    box inside the mask and, if ``center_must_be_inside``, the box centre on
    the mask. The first ``TOP_K`` survivors are scored with
        w_spatial * spatial_affinity + w_score * (min-max normalised score * rank weight)
    where the spatial affinity comes from ``source_prior`` (0 when no prior is
    given). If nothing survives the gates, the top raw proposal is used.

    Returns (xyxy_on_image, raw_score, mask_on_image or None), or
    (None, None, None) when SAM3 returns no detections.

    erode_px, dilate_px, border_sample_px, debug, viz and viz_overlay_mask are
    accepted for signature compatibility and are not used in this body.
    """
    weights = [0.92, 1.00, 0.98, 0.97, 0.96, 0.955, 0.95] if rank_weights is None else rank_weights

    img_w, img_h = image_pil.size
    prompt = (detail_type or "").strip() + "."

    aligned = _ensure_mask_for_image(restrict_mask, image_pil, crop_box_on_full=crop_box_on_full)
    gate = None if aligned is None else (aligned > 0)

    preds = _sam3_predict_text(image_pil, prompt, max_dets=max(TOP_K * 3, 12), score_threshold=threshold)
    if not preds:
        return (None, None, None)

    def _inside_fraction(box):
        """Fraction of the box inside the gate mask, or None when a gate rejects it."""
        if gate is None:
            return 1.0
        bx1, by1, bx2, by2 = box
        box_area = max(1, (bx2 - bx1) * (by2 - by1))
        covered = float(gate[by1:by2, bx1:bx2].sum()) / float(box_area)
        mid_x, mid_y = (bx1 + bx2) // 2, (by1 + by2) // 2
        center_hit = (0 <= mid_x < img_w and 0 <= mid_y < img_h and bool(gate[mid_y, mid_x]))
        if covered < min_inside_frac or (center_must_be_inside and not center_hit):
            return None
        return covered

    # Collect up to TOP_K gate-passing proposals, best SAM3 score first.
    shortlist = []
    for rank, pred in enumerate(preds):
        if threshold is not None and float(pred["score"]) < threshold:
            continue
        box = _clip_box_to_image(pred["box"], img_w, img_h)
        if box[2] <= box[0] or box[3] <= box[1]:
            continue
        frac = _inside_fraction(box)
        if frac is None:
            continue
        shortlist.append({
            "box": box,
            "score": float(pred["score"]),
            "rank": rank,
            "mask": pred["mask"],
            "inside_frac": frac,
        })
        if len(shortlist) >= TOP_K:
            break

    if not shortlist:
        # Nothing passed the gates: fall back to the top raw proposal.
        top = preds[0]
        shortlist = [{"box": _clip_box_to_image(top["box"], img_w, img_h),
                      "score": float(top["score"]),
                      "rank": 0,
                      "mask": top["mask"],
                      "inside_frac": 0.0}]

    # Min-max normalise the raw scores within the shortlist.
    raw = np.array([cand["score"] for cand in shortlist], dtype=np.float32)
    lo, hi = float(raw.min()), float(raw.max())
    if hi == lo:
        normed = np.ones_like(raw) * 0.5
    else:
        normed = (raw - lo) / (hi - lo)

    def _combined(position, cand):
        rw = weights[cand["rank"]] if cand["rank"] < len(weights) else weights[-1]
        score_term = float(normed[position] * rw)
        bx1, by1, bx2, by2 = cand["box"]
        box_area = max(1, (bx2 - bx1) * (by2 - by1))
        cx_n = ((bx1 + bx2) / 2.0) / max(1.0, img_w)
        cy_n = ((by1 + by2) / 2.0) / max(1.0, img_h)
        area_n = box_area / float(max(1, img_w * img_h))
        spatial = _spatial_affinity(cx_n, cy_n, area_n, source_prior, mirror_ok=mirror_ok) if source_prior else 0.0
        return w_spatial * spatial + w_score * score_term

    # The first candidate wins ties (strict comparison).
    winner, winner_combo = None, None
    for position, cand in enumerate(shortlist):
        combo = _combined(position, cand)
        if winner is None or combo > winner_combo:
            winner, winner_combo = cand, float(combo)

    return winner["box"], winner["score"], winner["mask"]


def detect_detail(image_pil: Image.Image,
                  detail_type: str,
                  threshold: float = 0.05,
                  used_boxes=None,
                  keep_best: bool = False,
                  iou_thr: float = 0.35,
                  restrict_mask: np.ndarray | None = None,
                  min_inside_frac: float = 0.40,
                  max_outside_frac: float = 0.70,
                  center_must_be_inside: bool = True,
                  erode_px: int = 1,
                  dilate_px: int = 2,
                  border_sample_px: int = 2,
                  debug: bool = False,
                  debug_topk: int = 5,
                  crop_box_on_full=None):
    """
    Simplified detail locator using SAM3 text grounding.

    Among the SAM3 detections for ``detail_type``, returns the highest-scoring
    one that
      * does not overlap any box in ``used_boxes`` by more than ``iou_thr`` IoU,
      * and, when ``restrict_mask`` is given, passes the mask gates
        (``min_inside_frac``, ``max_outside_frac``, ``center_must_be_inside``).

    If nothing qualifies, the top raw detection is returned when ``keep_best``
    is set, otherwise (None, None, None).

    Returns: (xyxy_on_image, score, mask_on_image)

    With ``debug`` set, up to ``debug_topk`` candidates that reached the mask
    gates are printed together with whether they were kept, including when
    every candidate is rejected. erode_px, dilate_px and border_sample_px are
    accepted for signature compatibility and are not used in this body.
    """
    used_boxes = used_boxes or []
    prompt = (detail_type or "").strip() + "."
    W, H = image_pil.size

    mask_aligned = _ensure_mask_for_image(restrict_mask, image_pil, crop_box_on_full=crop_box_on_full)
    mask_bool = (mask_aligned > 0) if mask_aligned is not None else None

    preds = _sam3_predict_text(image_pil, prompt, max_dets=10, score_threshold=threshold)
    if not preds:
        return (None, None, None)

    def _iou(a, b):
        ax1, ay1, ax2, ay2 = a; bx1, by1, bx2, by2 = b
        xi1, yi1 = max(ax1, bx1), max(ay1, by1)
        xi2, yi2 = min(ax2, bx2), min(ay2, by2)
        iw, ih = max(0, xi2 - xi1), max(0, yi2 - yi1)
        inter = iw * ih
        if inter == 0:
            return 0.0
        area_a = max(1, (ax2 - ax1) * (ay2 - ay1))
        area_b = max(1, (bx2 - bx1) * (by2 - by1))
        union = area_a + area_b - inter
        return inter / union

    best = None
    debug_rows = []
    for p in preds:
        if threshold is not None and float(p["score"]) < threshold:
            continue
        box = _clip_box_to_image(p["box"], W, H)
        x1, y1, x2, y2 = box
        if x2 <= x1 or y2 <= y1:
            continue

        # Skip detections that coincide with an already-claimed box.
        if any(_iou(box, ub) > iou_thr for ub in used_boxes):
            continue

        inside_frac = 1.0
        outside_frac = 0.0
        center_ok = True
        rejected = False
        if mask_bool is not None:
            crop = mask_bool[y1:y2, x1:x2]
            area = max(1, (x2 - x1) * (y2 - y1))
            inside_frac = float(crop.sum()) / float(area)
            outside_frac = 1.0 - inside_frac
            cx_i, cy_i = (x1 + x2) // 2, (y1 + y2) // 2
            center_ok = (0 <= cx_i < W and 0 <= cy_i < H and bool(mask_bool[cy_i, cx_i])) if center_must_be_inside else True
            rejected = (not center_ok
                        or inside_frac < min_inside_frac
                        or outside_frac > max_outside_frac)

        # Record the candidate BEFORE applying the gates so the debug output
        # can also explain why candidates were dropped.
        if debug and len(debug_rows) < debug_topk:
            debug_rows.append({
                "score": float(p["score"]),
                "box": box,
                "inside": inside_frac,
                "outside": outside_frac,
                "center": center_ok,
                "kept": not rejected,
            })

        if rejected:
            continue

        if best is None or p["score"] > best["score"]:
            best = {"box": box, "score": float(p["score"]), "mask": p["mask"], "inside_frac": inside_frac}

    if debug and debug_rows:
        print("[detect_detail/debug] top candidates (after score>thr):")
        for row in sorted(debug_rows, key=lambda r: r["score"], reverse=True):
            print(f"  score={row['score']:.3f} inside={row['inside']:.2f} outside={row['outside']:.2f} center={row['center']} box={row['box']} kept={row['kept']}")

    if best is None:
        if not keep_best:
            if debug:
                print("[detect_detail] No candidate satisfied mask gates.")
            return (None, None, None)
        # keep_best: fall back to the top raw detection, ignoring the gates.
        base = preds[0]
        best = {"box": _clip_box_to_image(base["box"], W, H), "score": float(base["score"]), "mask": base["mask"], "inside_frac": 0.0}

    return best["box"], best["score"], best["mask"]


def detect_garment_box(img: Image.Image, garment_tag: str, threshold=0.25, restrict_mask: np.ndarray | None = None):
    """Return a bounding box for the garment in `img`, or None if none is found.

    Two mutually exclusive strategies are used:

    * `restrict_mask` is given: the box is the extent of the mask foreground
      (> 127) after `resize_and_pad(..., 1024)`, mapped back to the original
      image size with `box_1024_to_original`. The detector is not called and
      `garment_tag` / `threshold` are not used.
    * `restrict_mask` is None: `_sam3_predict_text` is prompted with
      "<garment_tag> ." and the highest-scoring non-degenerate box (clipped to
      the image) is returned.

    Args:
        img: Image to search; only its size is used in the mask strategy.
        garment_tag: Text prompt for the detector.
        threshold: Minimum detector score; None disables the local score check.
        restrict_mask: Optional single-channel mask; treated as an 8-bit 'L'
            image.

    Returns:
        The selected box, or None. In the mask strategy the value is whatever
        `box_1024_to_original` returns; in the detector strategy it is whatever
        `_clip_box_to_image` returns.
    """
    O_W, O_H = img.size

    if restrict_mask is not None:
        m1024 = resize_and_pad(Image.fromarray(restrict_mask, 'L'), 1024).convert('L')
        mask_1024_np = (np.array(m1024) > 127)
        ys, xs = np.where(mask_1024_np)
        if xs.size == 0 or ys.size == 0:
            return None
        x1, y1, x2, y2 = [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]
        return box_1024_to_original([x1, y1, x2, y2], O_W, O_H)

    # From here on restrict_mask is always None, so no mask-overlap filtering
    # applies (the original carried an unreachable branch for it).
    preds = _sam3_predict_text(img, f"{garment_tag.strip()} .", max_dets=6, score_threshold=threshold)
    if not preds:
        return None

    best_box = None
    best_score = None
    for p in preds:
        score = float(p["score"])
        if threshold is not None and score < threshold:
            continue
        box = _clip_box_to_image(p["box"], O_W, O_H)
        if box[2] <= box[0] or box[3] <= box[1]:
            continue  # degenerate after clipping
        if best_score is None or score > best_score:
            best_box, best_score = box, score

    return best_box


def bbox_to_mask(bb, img_size, pad_px=10):
    """Rasterize a padded bounding box into a uint8 mask.

    Args:
        bb: (x1, y1, x2, y2) box; x2/y2 are used as exclusive slice ends.
            Non-integer coordinates are rounded.
        img_size: (W, H) of the target image.
        pad_px: Pixels added on every side before clamping to the image.

    Returns:
        An (H, W) uint8 array that is 255 inside the padded box and 0 elsewhere.
    """
    W, H = img_size
    x1, y1, x2, y2 = (int(round(v)) for v in bb)
    # Clamp to [0, W] / [0, H]. The slice end is exclusive, so clamping it to
    # W - 1 / H - 1 (as before) made the last column/row unreachable. Clamping
    # both ends on both sides also stops a box lying entirely left of / above
    # the image from wrapping around via a negative slice index.
    x1 = min(W, max(0, x1 - pad_px))
    y1 = min(H, max(0, y1 - pad_px))
    x2 = max(0, min(W, x2 + pad_px))
    y2 = max(0, min(H, y2 + pad_px))
    m = np.zeros((H, W), np.uint8)
    m[y1:y2, x1:x2] = 255
    return m


def crop_detail(image_pil, mask_np, bb_xyxy, out_size=1024, pad_px=20):
    """Cut a square window around a box from an image and its mask.

    The box is padded by `pad_px`, clamped to the image, expanded to a square
    around its centre, and shifted back inside the image where it overflows.
    The same window is applied to `image_pil` and `mask_np`, and both are
    resized to `out_size` x `out_size`.

    Args:
        image_pil: Source PIL image.
        mask_np: Array indexed as [row, col] over the same pixel grid.
        bb_xyxy: (x1, y1, x2, y2) box in image coordinates.
        out_size: Side length of the returned crops.
        pad_px: Padding added around the box before squaring.

    Returns:
        (image_crop, mask_crop, crop_box) where crop_box is (lx, ty, rx, by)
        in full-image coordinates.
    """
    W, H = image_pil.size
    # Slicing below needs integers; detectors may hand back floats.
    x1, y1, x2, y2 = (int(round(v)) for v in bb_xyxy)
    x1 = max(0, x1 - pad_px); y1 = max(0, y1 - pad_px)
    x2 = min(W, x2 + pad_px); y2 = min(H, y2 + pad_px)

    # Keep the square inside the image. Without this cap a square larger than
    # the image pushed lx/ty negative: PIL pads such a crop, but the numpy
    # slice wraps around, so image and mask crops no longer lined up.
    side = max(1, min(max(x2 - x1, y2 - y1), W, H))

    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
    lx = max(0, cx - side // 2); rx = lx + side
    ty = max(0, cy - side // 2); by = ty + side
    if rx > W:
        lx -= (rx - W); rx = W
    if by > H:
        ty -= (by - H); by = H

    crop_box = (lx, ty, rx, by)
    img_c = image_pil.crop(crop_box).resize((out_size, out_size), Image.Resampling.LANCZOS)
    m_c = mask_np[ty:by, lx:rx]
    # Nearest-neighbour keeps the mask values unblended.
    m_c = cv2.resize(m_c, (out_size, out_size), interpolation=cv2.INTER_NEAREST)
    return img_c, m_c, crop_box


def adaptive_brightness(img, strength_dark=0.15, strength_light=0.03, clip=(0, 245)):
    """Scale pixel values by a factor derived from the image's mean luminance.

    Mean luminance (weights 0.2126 / 0.7152 / 0.0722 over the first three
    channels, normalised to 0..1) below 0.5 gives a factor below 1 controlled
    by `strength_dark`; otherwise a factor at or above 1 controlled by
    `strength_light`. The scaled values are clipped to `clip` and returned as
    a new uint8 PIL image.
    """
    pixels = np.asarray(img).astype(np.float32)
    first, second, third = pixels[..., 0], pixels[..., 1], pixels[..., 2]
    luminance = 0.2126 * first + 0.7152 * second + 0.0722 * third
    mean_lum = float(luminance.mean() / 255.0)

    factor = (
        1 + (-strength_dark) * (0.5 - mean_lum) * 2
        if mean_lum < 0.5
        else 1 + (strength_light) * (mean_lum - 0.5) * 2
    )

    scaled = pixels * factor
    return Image.fromarray(np.clip(scaled, *clip).astype(np.uint8))


def paste_crop_back(full_img: Image.Image, edited_crop: Image.Image, crop_box, crop_mask: np.ndarray,
                    expand_px=20, feather_px=10) -> Image.Image:
    """Blend an edited crop back into `full_img` at `crop_box` using a feathered mask.

    The crop is first passed through `adaptive_brightness`, then resized to the
    crop-box size. The blend alpha is the mask dilated by `expand_px` (elliptical
    kernel) and Gaussian-blurred with sigma `feather_px`; pixels under the
    original (undilated) mask are forced fully opaque.

    NOTE: `full_img` is modified in place and also returned.
    """
    toned = adaptive_brightness(edited_crop, strength_dark=0.15, strength_light=0.03)

    left, top, right, bottom = crop_box
    width, height = right - left, bottom - top
    patch = toned.resize((width, height), Image.Resampling.LANCZOS)

    mask_resized = cv2.resize(crop_mask, (width, height), interpolation=cv2.INTER_NEAREST)
    core = mask_resized > 0
    grown = core.astype(np.uint8)
    if expand_px > 0:
        diameter = expand_px * 2 + 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
        grown = cv2.dilate(grown, kernel, iterations=1)

    # Soft edge around the grown mask; the core stays at full opacity.
    alpha = cv2.GaussianBlur(grown.astype(np.float32) * 255, (0, 0), sigmaX=feather_px, sigmaY=feather_px)
    alpha[core] = 255
    alpha_img = Image.fromarray(alpha.clip(0, 255).astype(np.uint8))

    backdrop = full_img.crop((left, top, right, bottom))
    blended = Image.composite(patch, backdrop, alpha_img)
    full_img.paste(blended, (left, top))
    return full_img


In [None]:
# Normalization for detail types from legacy names
def _normalize_detail_type(t: str) -> str:
    """Lower-case and trim a detail type, translating legacy names.

    None / empty input yields "". Names without a legacy alias are returned
    trimmed and lower-cased, otherwise unchanged.
    """
    legacy_aliases = {
        "waistband lettering": "waist text",
        "sleeve lettering": "sleeve text",
        "sleeve_text": "sleeve text",
        "waist_text": "waist text",
    }
    cleaned = (t or "").strip().lower()
    if cleaned in legacy_aliases:
        return legacy_aliases[cleaned]
    return cleaned

def _postprocess_details(payload: dict) -> dict:
    """Reduce a raw metadata payload to the allowed, normalized detail entries.

    Each entry's type is normalized with `_normalize_detail_type`; entries whose
    type is not in ALLOWED_DETAIL_TYPES are dropped. A non-blank string color is
    kept (trimmed) for every type except "sleeve text".

    Returns:
        {"details": [{"type": ..., "color": ...?}, ...]}
    """
    cleaned = []
    for entry in payload.get("details", []):
        kind = _normalize_detail_type(entry.get("type"))
        colour = entry.get("color")
        if kind not in ALLOWED_DETAIL_TYPES:
            continue
        item = {"type": kind}
        has_colour = isinstance(colour, str) and bool(colour.strip())
        if kind != "sleeve text" and has_colour:
            item["color"] = colour.strip()
        cleaned.append(item)
    return {"details": cleaned}

def _try_parse_json(s: str) -> dict | None:
    try:
        obj = json.loads(s)
        if isinstance(obj, dict) and "details" in obj:
            return obj
    except Exception:
        pass
    # try to extract first {...}
    m = re.search(r"\{[\s\S]*\}", s)
    if m:
        try:
            obj = json.loads(m.group(0))
            if isinstance(obj, dict) and "details" in obj:
                return obj
        except Exception:
            pass
    return None

def read_details_from_metadata(img_path: str) -> dict:
    """Read the detail list embedded in (or next to) an image.

    Sources are tried in order; the first one that yields a JSON object with a
    "details" key wins and is passed through `_postprocess_details`:

      1. any string value in the image's `info` dict,
      2. EXIF UserComment, then the 0x9C9C tag (decoded as UTF-16LE),
      3. the "XML:com.adobe.xmp" info entry (str or bytes),
      4. a sidecar "<image stem>.json" file.

    Args:
        img_path: Path of the image to inspect.

    Returns:
        {'details': [...]} from the first matching source, or the fallback
        {'details': [{'type': 'logo'}]} when nothing is found or reading fails.
    """

    def _decode_uc(x):
        # Strip a leading character-code header, then decode the payload.
        if isinstance(x, bytes):
            for head in [b"ASCII\0\0\0", b"UNICODE\0", b"JIS\0\0\0"]:
                if x.startswith(head):
                    x = x[len(head):]
            try:
                return x.decode("utf-8", "ignore")
            except Exception:
                return x.decode("latin-1", "ignore")
        if isinstance(x, str):
            return x
        return None

    def _parsed(text):
        obj = _try_parse_json(text)
        return _postprocess_details(obj) if obj else None

    try:
        # Context manager so the file handle is released on every return path.
        with Image.open(img_path) as im:
            info = im.info or {}

            # 1) PNG/JPEG info dict
            for value in info.values():
                if isinstance(value, str):
                    found = _parsed(value)
                    if found is not None:
                        return found

            # 2) EXIF: UserComment / XPComment
            # Only attempted when the file actually carries EXIF bytes; raw
            # pixel data is not a container piexif can parse.
            exif_dict = None
            exif_bytes = info.get("exif")
            if exif_bytes:
                try:
                    exif_dict = piexif.load(exif_bytes)
                except Exception:
                    exif_dict = None

            if exif_dict:
                uc = exif_dict.get("Exif", {}).get(piexif.ExifIFD.UserComment, None)
                s = _decode_uc(uc)
                if s:
                    found = _parsed(s)
                    if found is not None:
                        return found
                xp = exif_dict.get("0th", {}).get(0x9C9C, None)
                if xp:
                    try:
                        s = bytes(xp).decode("utf-16le", "ignore").rstrip("\x00")
                        found = _parsed(s)
                        if found is not None:
                            return found
                    except Exception:
                        pass

            # 3) XMP packet embedded in the image
            xmp = info.get("XML:com.adobe.xmp")
            if isinstance(xmp, bytes):
                # A bytes payload would otherwise raise in the regex fallback
                # of _try_parse_json and abort the remaining lookups.
                xmp = xmp.decode("utf-8", "ignore")
            if isinstance(xmp, str):
                found = _parsed(xmp)
                if found is not None:
                    return found

        # 4) Optional sidecar .json next to image
        side = Path(img_path).with_suffix(".json")
        if side.exists():
            try:
                obj = json.loads(side.read_text(encoding="utf-8"))
                if isinstance(obj, dict) and "details" in obj:
                    return _postprocess_details(obj)
            except Exception:
                pass
    except Exception as e:
        print(f"⚠️ metadata read failed for {img_path}: {e}")

    # Fallback when nothing found
    return {"details": [{"type": "logo"}]}


# Garment type inference (from path)
def extract_garment_type_from_path(image_path: str, allowed_types=ALLOWED_GARMENT_TYPES) -> str:
    """Infer a garment type from a path, returning one of `allowed_types` or "".

    Lookup order:
      1. The file stem, reduced to letters only, is scanned for any known key
         as a substring (keys are tried in `allowed_types` order).
      2. Parent folders, nearest first (names starting with "." are skipped).
         Each folder name is split into alphabetic tokens; adjacent token pairs
         are tried before single tokens so that "t-shirt" or "track top"
         resolve as a whole, and every candidate is tried both as written and
         singularized.

    Args:
        image_path: Path of the image.
        allowed_types: Canonical garment names to map onto.

    Returns:
        The matching entry of `allowed_types`, or "" when nothing matches.
    """
    from pathlib import Path as _P
    import re

    def singularize(s):
        if len(s) > 4:
            if s.endswith("es"):
                return s[:-2]
            if s.endswith("s"):
                return s[:-1]
        return s

    def normalize_key(s):
        # Drop separators, including inner spaces: previously "track top" kept
        # its space and could never match a compacted name.
        return singularize(re.sub(r"[\s_-]+", "", s.lower()))

    norm_map = {}
    for t in allowed_types:
        base = normalize_key(t)
        norm_map[base] = t
        if not base.endswith("s"):
            norm_map[base + "s"] = t
        elif base.endswith("es"):
            norm_map[base[:-2]] = t
        else:
            norm_map[base[:-1]] = t

    def lookup(word):
        # Try the word as written first: singularize() turns e.g. "hoodies"
        # into "hoodi", which is not a key.
        for cand in (word, singularize(word)):
            if cand in norm_map:
                return norm_map[cand]
        return None

    p = _P(image_path)
    file_compact = re.sub(r"[^a-z]+", "", p.stem.lower())
    for k, v in norm_map.items():
        if k and k in file_compact:
            return v

    # parent folders
    for part in reversed(p.parts[:-1]):
        if part.startswith("."):
            continue
        toks = [t for t in re.split(r"[^a-z]+", part.lower()) if t]
        for i, tok in enumerate(toks):
            if i + 1 < len(toks):
                hit = lookup(tok + toks[i + 1])
                if hit is not None:
                    return hit
            hit = lookup(tok)
            if hit is not None:
                return hit
    return ""


In [None]:
# === Robust SKU+angle parsing (handles: "SS-28623_fr (1).png", "Copy of SS-12345_bc_lft_v2.png", "SS-55555_fr_cl.png") ===
import os, re
from pathlib import Path
from functools import lru_cache

# Reuse your global config: BASE_NAMES, ACCEPTABLE_SUFFIXES, VALID_EXTS, WORKING_DIR

# SKU token such as "SS-12345": "SS-" followed by 3 to 7 digits, any letter case.
_SKU_RE = re.compile(r"(SS-\d{3,7})", re.IGNORECASE)
# One or more leading "copy of " prefixes, any letter case.
_COPY_RE = re.compile(r"^(?:copy of\s+)+", re.IGNORECASE)

def _strip_copy_prefix(s: str) -> str:
    """Drop any leading "Copy of " prefixes from `s` and trim surrounding whitespace."""
    without_prefix = _COPY_RE.sub("", s)
    return without_prefix.strip()

def _angle_tokens_desc() -> list[str]:
    """Return the unique BASE_NAMES tokens, longest first.

    Longest-first lets callers prefer 'fr_rght' over 'fr'. Duplicates are
    removed with `dict.fromkeys`, which keeps first-seen order, and `sorted` is
    stable, so tokens of equal length keep their BASE_NAMES order. (De-duplicating
    through a `set` made that order depend on string hash randomization.)
    """
    return sorted(dict.fromkeys(BASE_NAMES), key=len, reverse=True)

def _token_delim_search(token: str, text: str) -> re.Match | None:
    """
    Find token delimited by non-alphanumerics (underscore is allowed as a delimiter).
    We treat [A-Za-z0-9] as 'wordy'; underscores/spaces/()/- etc. are delimiters.
    """
    # Escape underscores in token for regex
    tok = re.escape(token)
    pattern = rf"(?<![A-Za-z0-9]){tok}(?![A-Za-z0-9])"
    return re.search(pattern, text, flags=re.IGNORECASE)

def extract_sku_and_angle_from_path(path_like: str) -> tuple[str | None, str | None]:
    """
    Derive (SKU, angle token) from a path; either element may be None.

    SKU: taken from the file name (after removing "Copy of " prefixes), else
    from the nearest path component that contains one; always upper-cased.

    Angle: the first token from `_angle_tokens_desc()` (longest first) found
      1) in the part of the file name after the SKU, then
      2) anywhere in the file name, then
      3) in a parent folder name, nearest folder first.
    """
    path = Path(path_like)
    filename = _strip_copy_prefix(path.name)

    # --- SKU: file name first, then path components from the leaf upwards
    sku_hit = _SKU_RE.search(filename)
    sku = None
    if sku_hit:
        sku = sku_hit.group(1).upper()
    else:
        for segment in reversed(path.parts):
            seg_hit = _SKU_RE.search(segment)
            if seg_hit:
                sku = seg_hit.group(1).upper()
                break

    candidates = _angle_tokens_desc()

    def _first_angle_in(text):
        # First (i.e. longest) token that appears delimited in `text`.
        for tok in candidates:
            if _token_delim_search(tok, text):
                return tok
        return None

    # --- Angle: text after the SKU, then the whole file name
    angle = None
    if sku_hit:
        angle = _first_angle_in(filename[sku_hit.end():])
    if angle is None:
        angle = _first_angle_in(filename)

    # --- Angle: parent folders as a last resort
    if angle is None:
        for folder in reversed(path.parts[:-1]):
            found = _first_angle_in(_strip_copy_prefix(folder))
            if found is not None:
                angle = found
            if angle:
                break

    return sku, angle

# ===================== Source finding via SKU folder anywhere =====================
@lru_cache(maxsize=1024)
def _find_sku_folder_anywhere(working_root: str, sku_name: str) -> Path | None:
    wr = Path(working_root)
    if not wr.exists():
        return None
    sku_low = sku_name.lower()
    best: tuple[int, Path] | None = None
    for dirpath, dirnames, _ in os.walk(wr):
        leaf = os.path.basename(dirpath)
        if leaf.lower() == sku_low:
            depth = len(Path(dirpath).parts)
            cand = Path(dirpath)
            if best is None or depth < best[0]:
                best = (depth, cand)
    return best[1] if best else None


def _list_valid_images(folder: Path) -> list[Path]:
    """
    List the files directly inside `folder` that may serve as source images.

    A file qualifies when its suffix is in VALID_EXTS and its lower-cased name
    contains none of the deny markers below (generated / inpainted / no-detail /
    detailer outputs and '_sec' secondary variants).
    """
    deny_substrings = (
        "generated",
        "inpainted",
        "_nd",
        "_no_details",
        "_processed_by_detailer_",
        "_sec",   # secondary variants are ignored
    )

    def _is_candidate(entry: Path) -> bool:
        if not entry.is_file():
            return False
        if entry.suffix not in VALID_EXTS:
            return False
        lowered = entry.name.lower()
        return not any(marker in lowered for marker in deny_substrings)

    return [entry for entry in folder.iterdir() if _is_candidate(entry)]


def _rank_exact_angle(norm_stem: str, base: str, acceptable_suffixes: set[str]) -> int | None:
    if norm_stem == f"{base}_cut": return 1
    if norm_stem.startswith(base + "_") and norm_stem.endswith("_cut"): return 2
    if norm_stem == base: return 3
    if norm_stem.startswith(base + "_"):
        suf = norm_stem[len(base)+1:]
        if suf in acceptable_suffixes: return 4
    return None

def _is_fr_family(base: str | None) -> bool:
    if not base: return False
    return base in ("fr","fr_cl","fr_lft","fr_rght") or base.startswith("fr")

def _pick_source_in_dir(angle_base: str, directory: Path) -> Path | None:
    """Choose the best source image in `directory` for the given angle.

    Candidates come from `_list_valid_images`. Stems are normalized (leading
    "Copy of " removed, lower-cased) and ranked with `_rank_exact_angle`; ties
    are broken by shorter file name, then by name.

    For front-family angles (see `_is_fr_family`) the plain "fr" image is
    preferred, and the exact angle is only used when no "fr" candidate exists.
    Other angles are matched exactly.

    Args:
        angle_base: Angle token such as "fr", "fr_lft" or "lft".
        directory: Folder to search (not recursive).

    Returns:
        The chosen image path, or None when nothing matches.
    """
    entries = _list_valid_images(directory)
    if not entries:
        return None
    acceptable = set(ACCEPTABLE_SUFFIXES)

    def _best_for(base: str) -> Path | None:
        # Lowest (rank, name length, name) wins.
        ranked = []
        for p in entries:
            rank = _rank_exact_angle(_strip_copy_prefix(p.stem).lower(), base, acceptable)
            if rank is not None:
                ranked.append((rank, len(p.name), p.name, p))
        if not ranked:
            return None
        return min(ranked, key=lambda t: t[:3])[3]

    # dict.fromkeys drops the duplicate pass when angle_base is already "fr".
    bases = dict.fromkeys(("fr", angle_base)) if _is_fr_family(angle_base) else (angle_base,)
    for base in bases:
        hit = _best_for(base)
        if hit is not None:
            return hit
    return None

def find_source_via_sku(gen_path: Path | str, working_root: Path | str) -> Path | None:
    """Locate the source garment image that corresponds to a generated image.

    The SKU and angle are parsed from `gen_path` (a missing angle defaults to
    "fr"). The SKU folder is searched for under `working_root`, and a source is
    picked from its "Ricardo" subfolder first, then from the SKU folder itself.

    Returns:
        The source image path, or None (after printing the reason) when the
        SKU, its folder, or a suitable image cannot be found.
    """
    gen_path = Path(gen_path)
    sku, angle_base = extract_sku_and_angle_from_path(str(gen_path))

    if not sku:
        print(f"❌ Could not extract SKU from: {gen_path.name}")
        return None

    # A missing angle quietly falls back to the front view.
    angle_base = angle_base or "fr"

    sku_dir = _find_sku_folder_anywhere(str(working_root), sku)
    if not sku_dir:
        print(f"❌ SKU folder '{sku}' not found anywhere under {working_root}")
        return None

    search_dirs = [sku_dir / "Ricardo", sku_dir]
    for candidate_dir in search_dirs:
        if not candidate_dir.is_dir():
            continue
        found = _pick_source_in_dir(angle_base, candidate_dir)
        if found:
            return found

    print(f"⚠️ No suitable source found in '{sku_dir}' (Ricardo or root) for angle '{angle_base}'")
    return None

def build_inpaint_suffix(details: list[dict]) -> str:
    """Build a file-name suffix describing the given details.

    Each detail becomes a slug ("<color>-<type>" when a color is present and
    the type is not "sleeve text", otherwise just "<type>"); slugs are joined
    with "_". An empty list yields "none".
    """

    def slug(s: str) -> str:
        dashed = s.lower().replace(" ", "-")
        return re.sub(r"[^a-z0-9\-]+", "", dashed).strip("-")

    def label(detail: dict) -> str:
        kind = detail["type"]
        colour = detail.get("color", "")
        if kind == "sleeve text" or not colour:
            return slug(kind)
        return slug(f"{colour}-{kind}")

    labels = [label(d) for d in details]
    if not labels:
        return "none"
    return "_".join(labels)

# --- Build the required base "SS-12345-bc_lft" from the queued filename ---
def build_out_base_from_gen(gen_path: str) -> tuple[str, str, str]:
    """
    Derive the output naming parts for a queued file.

    Returns (sku_upper, angle_lower, out_base) where out_base is
    "<SKU>-<angle>", e.g. 'SS-12345-fr_lft'. A missing angle defaults to "fr".

    Raises:
        ValueError: when no SKU can be parsed from `gen_path`.
    """
    sku, angle = extract_sku_and_angle_from_path(gen_path)
    if not sku:
        raise ValueError(f"Cannot derive SKU from: {gen_path}")

    sku_upper = sku.upper()
    angle_lower = (angle or "fr").lower()
    out_base = "-".join((sku_upper, angle_lower))
    return sku_upper, angle_lower, out_base

def target_already_has_inpainted(target_dir: str, sku: str, angle: str) -> bool:
    """
    Report whether `target_dir` already holds an inpainted result for SKU+angle.

    True when some file directly in `target_dir` has a suffix listed in
    VALID_EXTS and a stem that starts (case-insensitively) with
    "<SKU>-<angle>_inpainted". A missing directory yields False.
    """
    folder = Path(target_dir)
    if not folder.exists():
        return False

    wanted = f"{sku.upper()}-{angle.lower()}_inpainted".lower()
    return any(
        entry.is_file()
        and entry.suffix in VALID_EXTS
        and entry.stem.lower().startswith(wanted)
        for entry in folder.iterdir()
    )

In [None]:
def _inpaint_one_detail(gen_full: Image.Image,
                        src_full: Image.Image,
                        detail_prompt: str,
                        *,
                        garment_tag: str,
                        restrict_mask_full: np.ndarray | None,
                        generous_pad_px: int,
                        tiny_pad_px: int,
                        seed: int,
                        visualize: bool) -> Image.Image:
    """Transfer one detail from the source garment onto the generated image.

    Steps: locate the garment in both images, locate the detail inside each
    square garment crop, build a [source | generated] diptych with a mask over
    the generated detail, run the inpainting pipeline, and blend the right half
    back into the generated image.

    Args:
        gen_full: Generated image. NOTE: modified in place on success
            (`paste_crop_back` pastes into it).
        src_full: Source garment image.
        detail_prompt: Text prompt for the detail detector.
        garment_tag: Text prompt for the garment detector.
        restrict_mask_full: Optional full-size garment mask for the generated image.
        generous_pad_px: Padding around the generated detail crop.
        tiny_pad_px: Padding used when a detail mask is synthesized from its box.
        seed: Seed for the inpainting generator.
        visualize: Show intermediate images via `_show_images`.

    Returns:
        The image after inpainting, or `gen_full` unchanged when the garment or
        the detail cannot be located.
    """
    gen_view_for_sam3 = apply_binary_mask(gen_full, restrict_mask_full) if restrict_mask_full is not None else gen_full

    gar_src_bb = detect_garment_box(src_full, garment_tag)
    gar_gen_bb = detect_garment_box(gen_view_for_sam3, garment_tag, restrict_mask=restrict_mask_full)
    if gar_src_bb is None or gar_gen_bb is None:
        print("❌ garment detection failed"); return gen_full

    # square garment crops
    def crop_to_square(image: Image.Image, bbox, pad_px=0):
        x1,y1,x2,y2 = bbox
        w,h = x2-x1, y2-y1
        side = max(w,h) + 2*pad_px
        cx,cy = (x1+x2)//2, (y1+y2)//2
        lx=max(0,cx-side//2); ty=max(0,cy-side//2)
        rx=lx+side; by=ty+side
        W,H=image.size
        if rx>W: lx -= (rx-W); rx=W
        if by>H: ty -= (by-H); by=H
        crop = image.crop((max(lx,0),max(ty,0),min(rx,W),min(by,H)))
        # White canvas absorbs any part of the square that falls outside the image.
        out  = Image.new("RGB",(side,side),(255,255,255))
        dx=max(0,-lx); dy=max(0,-ty)
        out.paste(crop,(dx,dy))
        return out, (lx,ty,rx,by)

    # One crop per image serves both detection and coordinate mapping.
    src_sq, sq_src = crop_to_square(src_full, gar_src_bb)
    gen_sq, sq_gen = crop_to_square(gen_view_for_sam3, gar_gen_bb)

    det_src_bb, _, det_src_mask_crop = detect_detail(src_sq, detail_prompt, crop_box_on_full=sq_src)
    if det_src_bb is None:
        # Bail out before building a spatial prior from a missing box.
        print(f"❌ detail not found: {detail_prompt}"); return gen_full
    spatial_prior = make_spatial_prior_from_box(det_src_bb, src_sq.size)

    det_gen_bb, _, det_gen_mask_crop = detect_detail_topk7(
        gen_sq,
        detail_prompt,
        source_prior=spatial_prior,
        restrict_mask=(restrict_mask_full, "FULL"),  # pass FULL mask
        crop_box_on_full=sq_gen,                     # the (x1,y1,x2,y2) used to make gen_sq
        viz=False, debug=False
    )
    if det_gen_bb is None:
        print(f"❌ detail not found: {detail_prompt}"); return gen_full

    # back to full coords
    lx_s, ty_s, _, _ = sq_src
    lx_g, ty_g, _, _ = sq_gen
    src_det_bb = [det_src_bb[0]+lx_s, det_src_bb[1]+ty_s, det_src_bb[2]+lx_s, det_src_bb[3]+ty_s]
    gen_det_bb = [det_gen_bb[0]+lx_g, det_gen_bb[1]+ty_g, det_gen_bb[2]+lx_g, det_gen_bb[3]+ty_g]

    src_mask_full = _mask_crop_to_full(det_src_mask_crop, sq_src, src_full.size) if det_src_mask_crop is not None else None
    gen_mask_full = _mask_crop_to_full(det_gen_mask_crop, sq_gen, gen_full.size) if det_gen_mask_crop is not None else None
    # Fall back to a box mask when the detector returned no mask; honour the
    # caller's padding instead of the module-level constants.
    if src_mask_full is None:
        src_mask_full = bbox_to_mask(src_det_bb, src_full.size, tiny_pad_px)
    if gen_mask_full is None:
        gen_mask_full = bbox_to_mask(gen_det_bb, gen_full.size, tiny_pad_px)
    else:
        gen_mask_full = enlarge_mask(gen_mask_full, scale=1.05)

    if visualize:
        _show_images([
            ("detail on source", _draw_bbox(src_full, src_det_bb)),
            ("detail on generated (masked)", _draw_bbox(gen_view_for_sam3, gen_det_bb))
        ], cols=2, figsize=(12,8))

    # crops for IA
    src_crop, src_mask, _   = crop_detail(src_full, src_mask_full, src_det_bb, 1024, 20)
    gen_crop, gen_mask, box = crop_detail(gen_full, gen_mask_full, gen_det_bb, 1024, generous_pad_px)

    # diptych: source on the left (never masked), generated on the right
    src_arr = np.array(src_crop)
    masked_src = src_arr  # keep source unmasked for IA

    gen_arr = np.array(gen_crop)
    gen_msk3 = np.stack([gen_mask]*3, -1)
    zeros = np.zeros_like(masked_src)

    diptych = np.concatenate([masked_src, gen_arr], axis=1).astype(np.uint8)
    dip_mask = np.concatenate([zeros, gen_msk3], axis=1).astype(np.uint8)
    dip_mask[dip_mask>0]=255

    if visualize:
        _show_images([
            ("diptych", Image.fromarray(diptych)),
            ("diptych mask", Image.fromarray(dip_mask).convert("RGB"))
        ], cols=2, figsize=(12,8))

    # The redux output is unpacked as keyword arguments for the pipeline call.
    redux_kwargs = redux(Image.fromarray(masked_src))
    gen_obj = torch.Generator(device).manual_seed(seed)
    ia_out = pipe(
        image=Image.fromarray(diptych),
        mask_image=Image.fromarray(dip_mask),
        height=1024,
        width=2048,
        max_sequence_length=512,
        num_inference_steps=60,
        guidance_scale=30,
        generator=gen_obj,
        **redux_kwargs
    ).images[0]

    # Right half of the diptych is the edited generated crop.
    right_crop = ia_out.crop((1024,0,2048,1024))
    gen_full = paste_crop_back(gen_full, right_crop, box, gen_mask)

    if visualize:
        _show_images([
            ("IA result (2048×1024)", ia_out),
            ("after this detail", gen_full)
        ], cols=2, figsize=(14,8))

    return gen_full

def inpaint_with_details_list(generated_path: str,
                              source_path: str,
                              details: list[dict],
                              garment_type: str | None,
                              visualize: bool = True) -> Image.Image:
    """Inpaint every listed detail from the source image onto the generated one.

    Loads both images and the optional garment mask, resolves the garment type
    (argument, else inferred from `source_path`, else "t-shirt"), then runs
    `_inpaint_one_detail` for each garment tag and each detail in turn. Twinset
    types are processed as the first top and the first bottom garment tag.

    Returns:
        The resulting image; the loaded generated image itself is left untouched
        because work is done on a copy.
    """
    gen_full = open_upright(generated_path)
    src_full = open_source_with_black_bg(source_path)

    restrict_mask_full = load_binary_mask_for_generated(generated_path, source_path, gen_full)
    if restrict_mask_full is not None:
        print("✅ Garment mask loaded & aligned for", os.path.basename(generated_path))
    else:
        print("⚠️  No garment mask found — proceeding without restriction.")

    if garment_type is None or not garment_type.strip():
        garment_type = extract_garment_type_from_path(source_path)
    if not garment_type:
        garment_type = "t-shirt"  # conservative default prompt

    # twinset (optional, keep simple)
    wanted = garment_type.lower()
    if wanted in TWINSET_TYPES:
        garment_tags = [TOP_GARMENTS[0], BOTTOM_GARMENTS[0]]
    else:
        garment_tags = [wanted]

    out_img = gen_full.copy()
    jobs = [(gtag, d) for gtag in garment_tags for d in details]
    for gtag, d in jobs:
        d_type = d["type"]
        prompt_str = f"{d_type}".strip()
        print(f"🔄 Inpainting detail: {prompt_str}  (garment={gtag})")
        out_img = _inpaint_one_detail(
            out_img, src_full, prompt_str,
            garment_tag=gtag,
            restrict_mask_full=restrict_mask_full,
            generous_pad_px=INPAINT_GENEROUS_PAD,
            tiny_pad_px=INPAINT_TINY_PAD,
            seed=INPAINT_SEED,
            visualize=visualize
        )

    torch.cuda.empty_cache()
    gc.collect()
    return out_img


In [None]:
import shutil
def process_detailer_queue():
    """Run the detailer over every image found under DETAILER_QUEUE_FOLDER.

    For each queued image (processed in folder/name order):
      1. read the detail list from its metadata,
      2. find the matching source garment under WORKING_DIR,
      3. skip if TARGET_DIR already has an inpainted result for this SKU+angle
         (when SKIP_IF_ALREADY_INPAINTED is set),
      4. inpaint all details,
      5. save the result to TARGET_DIR as "<SKU>-<angle>_inpainted_<suffix>.png".

    A failure on one item is printed and counted; the loop continues. A summary
    of processed / skipped / failed counts is printed at the end.
    """
    queue_root = Path(DETAILER_QUEUE_FOLDER)
    if not queue_root.exists():
        print(f"❌ Queue folder does not exist: {queue_root}")
        return

    gen_files = [p for p in queue_root.rglob("*") if p.is_file() and p.suffix in VALID_EXTS]
    if not gen_files:
        print(f"ℹ️ No images found in {queue_root}")
        return

    # Make sure the output folder exists up front; otherwise every save would
    # fail only after the expensive inpainting step has already run.
    target_root = Path(TARGET_DIR)
    target_root.mkdir(parents=True, exist_ok=True)

    out_ext = ".png"
    processed = skipped = failed = 0

    for gen_path in sorted(gen_files, key=lambda p: (str(p.parent), p.name)):
        try:
            print("\n" + "_"*80)
            print(f"🎯 Queue item: {gen_path}")

            # 1) Read details from metadata
            meta = read_details_from_metadata(str(gen_path))
            if not meta or not meta.get("details"):
                print("⏭️  No details found in metadata — skipping")
                skipped += 1
                continue

            details = [d for d in meta["details"] if d["type"] in ALLOWED_DETAIL_TYPES]
            if not details:
                print("⏭️  Details list empty after normalization — skipping")
                skipped += 1
                continue

            # 2) Find source garment for this item
            src_p = find_source_via_sku(gen_path, Path(WORKING_DIR))
            if not src_p:
                print("⏭️  Source garment not found — skipping")
                skipped += 1
                continue

            sku_up, angle_lo, out_base = build_out_base_from_gen(str(gen_path))

            # 3) Skip guard: any prior inpainted for this SKU+angle?
            if SKIP_IF_ALREADY_INPAINTED and target_already_has_inpainted(TARGET_DIR, sku_up, angle_lo):
                print(f"⏭️  Already have inpainted for {out_base} in TARGET_DIR — skipping")
                skipped += 1
                continue

            # 4) Inpaint
            garment_type = extract_garment_type_from_path(str(src_p))
            out_img = inpaint_with_details_list(
                str(gen_path),
                str(src_p),
                details=details,
                garment_type=garment_type,
                visualize=VISUALIZE
            )

            # 5) Build output name, e.g. SS-12345-fr_lft_inpainted_red-logo.png
            suffix = build_inpaint_suffix(details)
            dst_ia = target_root / f"{out_base}_inpainted_{suffix}{out_ext}"

            # 6) Save the inpainted result. Copying the source garment next to
            #    it is currently disabled.
            out_img.save(str(dst_ia))
            print(f"✅ Saved inpainted → {dst_ia.name}")
            processed += 1

        except Exception as e:
            # Per-item boundary: report and move on to the next queue entry.
            print(f"❌ Failed on {gen_path.name}: {e}")
            failed += 1

    print("\n==== SUMMARY ====")
    print(f"Processed: {processed}  |  Skipped: {skipped}  |  Failed: {failed}")




#RUN

In [None]:
# Run the detailer over the whole queue; a processed/skipped/failed summary
# is printed at the end.
process_detailer_queue()

#UNASSIGN

In [None]:
# Unassign this Colab runtime; intended to run after the queue cell above.
# NOTE(review): presumably this discards the VM and any unsaved local state — confirm.
from google.colab import runtime
runtime.unassign()