# INSTALL

In [None]:
# -*- coding: utf-8 -*-
# --- Setup: Google auth + Drive + NanoBanana Pro deps ---

import sys, os, subprocess, textwrap, importlib

from google.colab import auth, drive
# Authenticate the Colab user and mount Google Drive under /content/drive.
auth.authenticate_user()
drive.mount('/content/drive')

# Shell-level installs for this runtime. Only google-genai carries a minimum
# version; every other package is installed unpinned.
!pip install fal-client requests
!pip -q install --upgrade pip
!pip install "google-genai>=1.40.0" pillow numpy opencv-python-headless matplotlib gspread google-auth google-auth-oauthlib google-api-python-client piexif tqdm

print("✅ Setup done for NanoBanana Pro.")


# DETAILER INSTALLS

# INSTALLS (restart and reinstall again after this)

In [None]:
# SAM3 via Hugging Face transformers
# NOTE(review): installs straight from the repository's default branch (no tag/commit pinned).
!pip install -q "git+https://github.com/huggingface/transformers.git"


In [None]:
# GroundingDINO setup removed — SAM3 now handles detection + segmentation in one model.


In [None]:
# Detailer dependencies (all unpinned).
%pip -q install open_clip_torch ninja wheel transformers accelerate \
                 sentencepiece protobuf huggingface_hub opencv-python
# Force-reinstall diffusers from the main branch without touching its dependencies.
!pip install -U --no-deps --force-reinstall "git+https://github.com/huggingface/diffusers.git@main"
#%pip -q install 'git+https://github.com/facebookresearch/detectron2.git'
!pip install --upgrade open_clip_torch

In [None]:
# Shallow-clone the insert-anything repository into /content.
%cd /content/
!git clone --depth 1 https://github.com/song-wensong/insert-anything.git

In [None]:
# Prebuilt nunchaku wheel; its filename targets torch 2.6 / CPython 3.12 / linux x86_64.
!pip install https://huggingface.co/mit-han-lab/nunchaku/resolve/main/nunchaku-0.2.0+torch2.6-cp312-cp312-linux_x86_64.whl
# Pin the torch stack to the version named in the wheel above.
!pip install torch==2.6 torchvision==0.21 torchaudio==2.6
!pip install ninja wheel diffusers transformers accelerate sentencepiece protobuf huggingface_hub
# LoRA weights repository, cloned into the current working directory.
!git clone https://huggingface.co/aha2023/insert-anything-lora-for-nunchaku

In [None]:

import sys, subprocess, textwrap, importlib
import os, re, fnmatch, math, uuid, pytz, random, gc, tempfile, traceback
from datetime import datetime
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter, ImageDraw, ImageFont


# Select angles

In [None]:
# One Colab form checkbox per angle prefix; True means the angle is enabled.
fr_lft = True #@param {type:"boolean"}
fr_rght = True #@param {type:"boolean"}
fr_cl = True #@param {type:"boolean"}
bc_lft = True #@param {type:"boolean"}
bc_rght = True #@param {type:"boolean"}
lft = True #@param {type:"boolean"}
rght = True #@param {type:"boolean"}
bc_ = True #@param {type:"boolean"}
fr_ = True #@param {type:"boolean"}
fr_cl_btm = False #@param {type:"boolean"}
fr_cl_tp = False #@param {type:"boolean"}


# `names` must list exactly the checkbox variables above; each entry is looked
# up by name in this cell's namespace and kept only if its checkbox is True.
names = ["fr_lft","fr_rght","fr_cl","bc_lft","bc_rght","lft","rght","bc_","fr_","fr_cl_btm","fr_cl_tp"]
ALLOWED_BASES = [n for n in names if locals()[n]]

# CONFIG

In [None]:
import re
# --- Unified CONFIG ---
from google.colab import userdata
# fal.ai key read from Colab user secrets.
FAL_KEY = userdata.get('FAL_KEY')

# Selection mode: list only
RUN_MODE = "sku_list"     #@param ["sku_list"]

# For RUN_MODE == "sku_list"
# Comma-separated SKUs; rewritten to the "SS-<digits>" form at the end of this cell.
SKU_CSV = "28920"  #@param {type:"string"}

# Paths
BASE_PHOTOS_ROOT  = "/content/drive/MyDrive/Dazzl/SikSilk/SKSLK_MODELS/"
GARMENTS_ROOT     = "/content/drive/MyDrive/Dazzl/SikSilk/AlexGens/SikSilk/"


# Filename/dir policy
VALID_EXTENSIONS  = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")
IGNORE_DIRS       = {"old", "__MACOSX", ".ds_store", "Ricardo", "toweling"}
SKIP_FILENAME_TOKENS_CSV   = "mask, generated, _close, _open, freelance, _backup"   # substrings to skip
SKIP_BASENAME_SUFFIXES_CSV = "_sec"                             # stem endings to skip
REQUIRE_CUT_IN_FILENAME    = False   #@param {type:"boolean"}
PREFER_AGNOSTIC_MASKS = True #@param {type:"boolean"}
secondary_garment = True #@param {type:"boolean"}
# Upper-case alias of the form field above.
SECONDARY_GARMENT = secondary_garment

# Cropping / paste-back (square 1:1, generous garment margin)
CROP_PADDING      = 300        # px padding around garment when building crop
UPPER_PADDING     = 200        # extra padding above garment
HORIZ_PADDING     = 150        # horizontal padding
MASK_EXPAND_PX    = 100        # outward growth before feather
MASK_FEATHER_PX   = 30         # Gaussian sigma for feathering
CROP_MIN_MARGIN   = 20        # minimum margin even if mask touches edge

TARGET_ASPECT = (1, 1)         # enforce square crops for 1:1 generations

# NanoBanana Pro (Google GenAI)
NANOBANANA_MODEL_ID = "gemini-3-pro-image-preview"
GEN_ASPECT_RATIO    = "1:1"
GEN_IMAGE_SIZE      = "4K" #@param ["1K", "2K", "4K"]

# Opening paragraph of the primary try-on prompt.
MAIN_PROMPT_INTRO = """You are an expert virtual try-on AI. You will be given a 'model image' and a 'garment image'. Your task is to create a new photorealistic image where the person from the 'model image' is wearing the clothing from the 'garment image'."""

# Opening paragraph of the secondary-garment prompt (contains no format placeholders).
SECONDARY_PROMPT_INTRO = """You are an expert virtual try-on AI. You will be given a 'model image' and a 'garment image'. Your task is to create a new photorealistic image where the person from the 'model image' is wearing the clothing from the 'garment image' as a complementary garment to their main."""

# Rules for the primary prompt; they are numbered in list order when the prompt is built.
MAIN_PROMPT_RULES = [
    "**Complete Garment Replacement:** You MUST completely REMOVE and REPLACE the clothing item worn by the person in the 'model image' with the new garment. No part of the original clothing (e.g., collars, sleeves, patterns) should be visible in the final image.",
    "**Preserve the Model:** The person's face, hair, body shape, and pose from the 'model image' MUST remain unchanged, pixel-for-pixel.",
    "**Preserve the Background:** The entire background from the 'model image' MUST be preserved perfectly, pixel-for-pixel.",
    "**Apply the Garment:** Realistically fit the new garment onto the person. It should adapt to their pose with natural folds, shadows, and lighting consistent with the original scene.",
    "**Output:** Return ONLY the final, edited image. Do not include any text.",
    "**Bespoke quality:** The garment should be ironed (if applicable), pretty and sit perfectly well — this is a professional fashion product photoshoot.",
]

# Optional rule inserted by build_main_prompt when a texture reference is supplied.
MAIN_TEXTURE_RULE = "**Fabric Texture:** Use the provided 'Main texture reference' image to match the fabric texture, print, and sheen perfectly on the garment."

# Rules for the secondary prompt; every "{sec_type}" is filled in via str.format.
SECONDARY_PROMPT_RULES = [
    "**Complete {sec_type} Garment Replacement:** You MUST completely REMOVE and REPLACE the {sec_type} clothing item worn by the person in the 'model image' with the new {sec_type} garment. No part of the original cloth (e.g., collars, sleeves, patterns) should be visible in the final image.",
    "**Preserve the Model:** The person's face, hair, body shape, and pose from the 'model image' MUST remain unchanged, pixel-for-pixel.",
    "**Preserve the Background:** The entire background from the 'model image' MUST be preserved perfectly, pixel-for-pixel.",
    "**Apply the {sec_type} Garment:** Realistically fit the new {sec_type} garment onto the person. It should adapt to their pose with natural folds, shadows, and lighting consistent with the original scene.",
    "**Bespoke quality:** the {sec_type} garment should be ironed (if applicable), pretty and sit perfectly well — this is a professional fashion product photoshoot.",
    "**Preserve EXACT composition:** the source composition shouldn't change even slightly.",
    "**Inpaint only what is needed:** Carefully consider, which part of the provided garment would be visible in the given angle. If target garment is seen only partly — replace that part with corresponding part of the new garment, but nothing else.",
    "**NEVER ZOOM IN or OUT**",
    "**Output:** Return ONLY the final, edited image. Do not include any text.",
]

# Optional rule inserted by build_secondary_prompt when a texture reference is supplied.
SECONDARY_TEXTURE_RULE = "**Fabric Texture:** Use the provided texture reference image to match the {sec_type} garment's fabric texture, print, and sheen perfectly."

def _numbered_rules(rules):
    return "\n".join(f"{i+1}. {rule}" for i, rule in enumerate(rules))

def build_main_prompt(include_texture: bool = False):
    """Assemble the primary try-on prompt: intro, a heading, then the numbered rules.

    When ``include_texture`` is true, MAIN_TEXTURE_RULE is spliced in at list
    index 5 (i.e. it becomes rule #6) without mutating MAIN_PROMPT_RULES.
    """
    rule_list = [*MAIN_PROMPT_RULES]
    if include_texture:
        # Slice assignment at 5:5 inserts at index 5, exactly like list.insert(5, ...).
        rule_list[5:5] = [MAIN_TEXTURE_RULE]
    header = f"{MAIN_PROMPT_INTRO}\n\n**Crucial Rules:**\n"
    return header + _numbered_rules(rule_list)

def build_secondary_prompt(sec_type: str, include_texture: bool = False):
    """Assemble the secondary-garment prompt for the given garment label.

    ``sec_type`` falls back to "secondary" when empty/None and is substituted
    into every "{sec_type}" placeholder. When ``include_texture`` is true the
    formatted SECONDARY_TEXTURE_RULE is inserted at list index 5.
    """
    label = sec_type or "secondary"
    rule_list = []
    for template in SECONDARY_PROMPT_RULES:
        rule_list.append(template.format(sec_type=label))
    if include_texture:
        rule_list[5:5] = [SECONDARY_TEXTURE_RULE.format(sec_type=label)]
    heading = SECONDARY_PROMPT_INTRO.format(sec_type=label)
    return f"{heading}\n\n**Crucial Rules:**\n" + _numbered_rules(rule_list)

# Default prompts, built once without the texture rule.
TRYON_PROMPT = build_main_prompt(include_texture=False)
SECONDARY_TRYON_PROMPT = build_secondary_prompt(sec_type="secondary", include_texture=False)

# SAM3 segmentation (fal.ai)
FAL_SAM_MODEL_ID = "fal-ai/sam-3/image"
MASK_MAX_SIZE     = 1024   # px max side sent for segmentation
MASK_PROMPT_TEMPLATE = "{category}"


# Sheet-related (Ops removed) kept only for Gen Log appends
SPREADSHEET_ID = "1Kbq9__sEUQiuDPuza5Xy_hRyIn8pUvmfFj6vhPBrp8Y"
GEN_LOG_SHEET  = "Gen Log"

# Misc
SHOW_VISUALS = True
TIMEZONE     = "Europe/Lisbon"
OPERATOR     = "Ivan"
OUTPUT_DIR   = "/content/drive/MyDrive/Dazzl/SikSilk/SS_OUTPUT_FOLDER/v1-5/" #@param {type:"string"}


# === Garment/type taxonomy (kept) ===
# NOTE(review): "shirts" appears twice in this list, and the entries mix
# singular and plural forms relative to TOP_GARMENTS / BOTTOM_GARMENTS below.
ALLOWED_GARMENT_TYPES = [
    "hoodie","jeans","joggers","shorts","sweater","swimwear",
    "t-shirt","shirts","track top","trousers","twinset","polo","vests","shirts"
]
TOP_GARMENTS    = ["t-shirts","shirts","sweaters","hoodies","polos","vests"]
BOTTOM_GARMENTS = ["shorts","joggers","trousers","jeans", "pants", "swimwear"]
TWINSET_TYPES   = ["twinset"]

# === Details tokens ===
ALLOWED_DETAIL_TYPES = ["crest","logo","patch"]

# Angle sheet tokens (kept for compatibility, not used in v1.3)
ANGLE_NEEDS_REGENERATE_TOKEN = "Regenerate"
ENFORCE_BAN_SUBSTRINGS     = True
BANNED_SUBSTRINGS_CSV      = "wrong, pair, combo"
ENFORCE_REQUIRE_SUBSTRINGS = False
REQUIRED_SUBSTRINGS_CSV    = ""
REQUIRED_SUBSTRINGS_MODE   = "ANY"   # "ANY" | "ALL"

def normalize_sku_list(sku_csv: str) -> str:
    """Rewrite a comma-separated SKU string into canonical "SS-<digits>" entries.

    Each comma-separated piece contributes its first run of digits; pieces
    without any digit are dropped. The result is joined with ", ".
    """
    digit_hits = (
        re.search(r'(\d+)', piece.strip().upper())
        for piece in sku_csv.split(',')
    )
    return ", ".join(f"SS-{hit.group(1)}" for hit in digit_hits if hit)

# Replace the raw form value with its canonical "SS-<digits>" CSV form.
SKU_CSV = normalize_sku_list(SKU_CSV)

print("✅ Config ready for NanoBanana Pro v1.3")


# DETAILER CONFIG (shared)

In [None]:

# --- Detailer (logo cleanup) CONFIG ---
RUN_DETAILER_AFTER_TRYON = True  # disable if you only want try-on
DETAILER_VISUALIZE = False       # set True for debug plots
DETAILER_ONLY_TYPES = ["logo"]  # fallback prompts when metadata has no details

DETAILER_QUEUE_FOLDER = "/content/drive/MyDrive/DETAILER_TODO"
TARGET_DIR = OUTPUT_DIR          # reuse try-on outputs
WORKING_DIR = GARMENTS_ROOT      # source search root for detailer
MASKS_ROOT = BASE_PHOTOS_ROOT    # mask lookup root (category/subcategory)

# Detailer model/runtime knobs
DEVICE_STR = "cuda"
INPAINT_GENEROUS_PAD = 150
INPAINT_TINY_PAD = 6
INPAINT_SEED = 2025
SKIP_IF_ALREADY_INPAINTED = False
USE_BF16_INFERENCE = True
VISUALIZE = False

# Shared constants
VALID_EXTS = VALID_EXTENSIONS
# NOTE(review): this list uses "fr"/"bc" and includes "bc_cl", whereas the angle
# checkbox cell uses "fr_"/"bc_" and has no "bc_cl" — confirm which set is canonical.
BASE_NAMES = ["fr_rght", "fr_lft", "fr_cl", "fr", "lft", "rght", "bc_lft", "bc_rght", "bc", "bc_cl"]
ACCEPTABLE_SUFFIXES = ["cut"]

# Expand garment/detail taxonomies so both flows agree
# Each line rebinds a name from the CONFIG cell to the sorted union of its
# previous contents and the extra literals listed here.
TOP_GARMENTS = sorted(set(list(TOP_GARMENTS) + ["t-shirt", "shirt", "sweater", "hoodie", "track top", "vest"]))
BOTTOM_GARMENTS = sorted(set(list(BOTTOM_GARMENTS) + ["shorts", "jogger-trousers", "trousers", "jeans", "swimwear"]))
TWINSET_TYPES = sorted(set(list(TWINSET_TYPES) + ["twinset"]))
ALLOWED_DETAIL_TYPES = sorted(set(list(ALLOWED_DETAIL_TYPES) + ["waist text", "sleeve text"]))



# UTILS

In [None]:
# --- Core utilities: normalization, angles, walking, masks (agnostic-first) ---

import os, re, fnmatch, math, uuid, pytz, random, gc, tempfile, traceback
from datetime import datetime
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter



# Parsers
def _parse_csv_list(s):  return [x.strip().casefold() for x in (s or "").split(",") if x.strip()]
# Parsed (trimmed, case-folded) views of the CSV settings from the CONFIG cell.
BANNED_SUBSTRINGS       = _parse_csv_list(BANNED_SUBSTRINGS_CSV)
REQUIRED_SUBSTRINGS     = _parse_csv_list(REQUIRED_SUBSTRINGS_CSV)
SKIP_FILENAME_TOKENS    = set(_parse_csv_list(SKIP_FILENAME_TOKENS_CSV))
# Stored as a tuple, the form str.endswith accepts for multiple suffixes.
SKIP_BASENAME_SUFFIXES  = tuple(_parse_csv_list(SKIP_BASENAME_SUFFIXES_CSV))

# Normalizers
def _norm_sku(s):
    if s is None: return ""
    s = str(s).replace("\u00A0"," ")
    s = " ".join(s.split())
    return s.casefold()

def _norm_angle(s):
    s = (s or "").strip().lower()
    return s.strip("_ ").replace("-", "_")

# Angle aliases
ANGLE_ALIASES = {
    "fr_cl":   ["fr", "fr_"],
    "fr":      ["fr_cl"],
    "bc_lft":  ["bc", "bc_"],
    "bc_rght": ["bc", "bc_"],
}


# --- Helpers to keep outputs strict, sources flexible ---
def expand_as_list(angles):
    exp = list(expand_allowed_angles(angles))
    exp = [_norm_angle(a) for a in exp]
    exp.sort(key=len, reverse=True)  # prefer 'fr_cl' over 'fr'
    return exp

def pick_target_angle(source_angle: str, allowed_outputs: set) -> str | None:
    s = _norm_angle(source_angle)
    for target in allowed_outputs:
        fam = {_norm_angle(x) for x in expand_allowed_angles([target])}
        if s in fam:
            return _norm_angle(target)
    return None


def expand_allowed_angles(angles):
    expanded = set()
    for a in (angles or []):
        a_norm = _norm_angle(a)
        expanded.add(a_norm)
        for alt in ANGLE_ALIASES.get(a_norm, []):
            expanded.add(_norm_angle(alt))
    return expanded

# Ignore set
# Lower-cased so the walkers below can compare against lower-cased folder names.
IGNORE_DIRS = {d.lower() for d in IGNORE_DIRS}

# Walkers
def _is_sku_folder(path: str) -> bool:
    """Tell whether ``path`` is a non-ignored folder that directly contains an image.

    Returns False for folders whose name is in IGNORE_DIRS, for folders with
    no file ending in one of VALID_EXTENSIONS (case-insensitive), and when
    the folder cannot be listed.
    """
    folder_name = os.path.basename(os.path.normpath(path)).lower()
    if folder_name in IGNORE_DIRS:
        return False
    try:
        image_exts = tuple(ext.lower() for ext in VALID_EXTENSIONS)
        return any(
            os.path.isfile(os.path.join(path, entry)) and entry.lower().endswith(image_exts)
            for entry in os.listdir(path)
        )
    except Exception:
        # Unreadable or missing folder: treat as "not a SKU folder".
        return False

def iter_sku_folders(root: str):
    """Yield every directory under ``root`` that directly holds an image file.

    Sub-directories whose lower-cased name is in IGNORE_DIRS are pruned from
    the walk, so nothing beneath them is visited.
    """
    for current_dir, subdirs, files in os.walk(root):
        # In-place slice assignment is what makes os.walk skip the pruned folders.
        subdirs[:] = [name for name in subdirs if name.lower() not in IGNORE_DIRS]
        for fname in files:
            if fname.lower().endswith(tuple(e.lower() for e in VALID_EXTENSIONS)):
                yield current_dir
                break

def resolve_targets(idents_csv: str, garments_root: str):
    """
    Resolve user-supplied identifiers to SKU folders under ``garments_root``.

    Accepts:
      • Plain SKU names, relative paths (Category/Subcategory/SKU), or absolute dirs
      • Glob patterns (e.g., 'Hoodies/*' or 'SKSLK_12*')
      • Directories that are NOT SKU leaves → expand to all descendant SKU leaves

    Identifiers may be separated by commas or newlines.

    Returns:
      (paths, unmatched): ``paths`` is the sorted, de-duplicated list of
      absolute SKU folder paths; ``unmatched`` lists the identifiers that
      resolved to no SKU folder at all. An identifier that resolves only to
      folders already collected by an earlier identifier (duplicates or
      overlapping patterns) counts as matched.
    """
    idents = [s.strip() for s in idents_csv.replace("\n", ",").split(",") if s.strip()]
    if not idents:
        return [], []

    all_sku_dirs = list(iter_sku_folders(garments_root))
    rel_map = {p: os.path.relpath(p, garments_root) for p in all_sku_dirs}
    base_map = {p: os.path.basename(p) for p in all_sku_dirs}

    seen, out, unmatched = set(), [], []

    def remember(path):
        # Collect each absolute path once, preserving first-seen order until the final sort.
        if path not in seen:
            seen.add(path)
            out.append(path)

    def add_path(p) -> bool:
        """Collect the SKU folder(s) behind ``p``; True if ``p`` resolved to any."""
        ap = os.path.abspath(p)
        if not os.path.isdir(ap):
            return False
        if _is_sku_folder(ap):
            remember(ap)
            return True
        # Expand directory to all descendant SKU leaves
        resolved = False
        for leaf in iter_sku_folders(ap):
            remember(os.path.abspath(leaf))
            resolved = True
        return resolved

    for ident in idents:
        matched = False

        # Absolute directory or SKU path
        if os.path.isabs(ident) and os.path.isdir(ident):
            matched = add_path(ident) or matched

        # Relative under garments root (dir or SKU)
        rel_candidate = os.path.join(garments_root, ident)
        if os.path.exists(rel_candidate):
            matched = add_path(rel_candidate) or matched

        # Glob/pattern over known SKU leaves (by basename or relative path)
        for p in all_sku_dirs:
            if fnmatch.fnmatch(base_map[p], ident) or fnmatch.fnmatch(rel_map[p], ident):
                matched = add_path(p) or matched

        # Judge by "did it resolve", not by "did the output grow": the latter
        # mislabels repeated/overlapping identifiers as unmatched.
        if not matched:
            unmatched.append(ident)

    out.sort()
    return out, unmatched

# Base/mask location resolution
def resolve_base_mask_dir(sku_folder: str,
                          garments_root: str = GARMENTS_ROOT,
                          base_root: str = BASE_PHOTOS_ROOT):
    """
    Find the base-photo/mask directory that corresponds to a SKU folder.

    Tried in order, returning the first existing directory:
      1. base_root + the SKU folder's parent path relative to garments_root
         (only when the SKU folder is located under garments_root)
      2. base_root/<grandparent folder name>/<parent folder name>
      3. base_root/<parent folder name>
    Returns None when none of them exists.
    """
    sku_abs = os.path.abspath(sku_folder)
    garments_abs = os.path.abspath(garments_root)
    try:
        relative = os.path.relpath(sku_abs, garments_abs)
    except Exception:
        relative = None

    # A relative path starting with ".." means the SKU sits outside garments_root.
    if relative and not relative.startswith(".."):
        mirrored = os.path.join(base_root, os.path.dirname(relative))
        if os.path.isdir(mirrored):
            return mirrored

    parent_dir = os.path.dirname(sku_abs)
    subcategory = os.path.basename(parent_dir)
    category = os.path.basename(os.path.dirname(parent_dir))

    by_category = os.path.join(base_root, category, subcategory)
    if os.path.isdir(by_category):
        return by_category

    by_subcategory = os.path.join(base_root, subcategory)
    return by_subcategory if os.path.isdir(by_subcategory) else None

def _valid_ext(fname):
    """Return True when ``fname`` ends with one of VALID_EXTENSIONS, ignoring case."""
    lowered_name = fname.lower()
    lowered_exts = tuple(ext.lower() for ext in VALID_EXTENSIONS)
    return lowered_name.endswith(lowered_exts)

def _file_prefix_or_none(filename: str):
    """Return the enabled angle prefix that ``filename`` starts with, else None.

    The comparison is case-insensitive on the filename. Candidates from
    ALLOWED_BASES are tried longest first so that a more specific prefix
    (e.g. 'fr_cl_btm') is not shadowed by a shorter one it begins with
    ('fr_cl', 'fr_'). The sort is stable, so equal-length prefixes keep
    their ALLOWED_BASES order.
    """
    low = filename.lower()
    for base in sorted(ALLOWED_BASES, key=len, reverse=True):
        if low.startswith(base):
            return base
    return None

def _find_image_with_stem_and_suffix(directory, stem, suffix=""):
    if not directory or not os.path.isdir(directory):
        return None
    stem = stem.lower()
    for file in os.listdir(directory):
        fname, fext = os.path.splitext(file)
        if fext.lower() in (".png",".jpg",".jpeg") and fname.lower() == f"{stem}{suffix}":
            return os.path.join(directory, file)
    return None

# --- Existence check in Google Drive by Colab-style path ---
def drive_file_exists_any_ext_at_colab_path(target_colab_path: str,
                                            exts=(".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")) -> bool:
    """
    Check whether the folder of ``target_colab_path`` already holds a file with
    the same stem and any of the allowed extensions.

    The stem comparison is exact (case-sensitive); the extension comparison is
    case-insensitive. Any failure while resolving or listing the folder is
    printed as a warning and reported as "does not exist" (False).
    """
    try:
        parent_id, desired_name = _resolve_parent_id_and_filename_from_colab_path(target_colab_path)
        wanted_stem = os.path.splitext(desired_name)[0]
        children = _list_children(parent_id, q_extra="")  # list once; filter locally
        allowed = {ext.lower() for ext in exts}

        def _matches(child) -> bool:
            child_stem, child_ext = os.path.splitext(child.get("name", ""))
            return child_stem == wanted_stem and child_ext.lower() in allowed

        return any(_matches(child) for child in children)
    except Exception as err:
        print(f"⚠️ Ext-agnostic existence check failed for {target_colab_path}: {err}")
        return False


# === NEW: mask finding with AGNOSTIC priority ===
def find_mask_path(base_subcat_dir: str, stem_no_cut: str):
    """
    Locate the mask file for ``stem_no_cut`` inside ``base_subcat_dir``.

    Priority:
      1) {stem}_mask_agnostic.(png|jpg|jpeg)   (only when PREFER_AGNOSTIC_MASKS is set)
      2) {stem}_mask.(png|jpg|jpeg)
    Each extension is tried in lower case first, then upper case. Returns the
    first existing path, or None.
    """
    if not base_subcat_dir or not os.path.isdir(base_subcat_dir):
        return None

    mask_exts = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")
    mask_suffixes = ("_mask_agnostic", "_mask") if PREFER_AGNOSTIC_MASKS else ("_mask",)

    for mask_suffix in mask_suffixes:
        for ext in mask_exts:
            candidate = os.path.join(base_subcat_dir, f"{stem_no_cut}{mask_suffix}{ext}")
            if os.path.exists(candidate):
                return candidate
    return None


def find_secondary_garment_path(folder_path: str, main_filename: str):
    """
    Find the secondary-garment image that accompanies a primary garment file.

    For a main file such as 'bc_lft_cut.png' the stems tried are, in order:
    'bc_lft_sec_cut', then 'bc_lft_sec' (only when REQUIRE_CUT_IN_FILENAME is
    False). Each stem is combined with every entry of VALID_EXTENSIONS; the
    first existing path wins. Returns None when nothing is found.
    """
    stem = os.path.splitext(main_filename)[0]
    core = stem.removesuffix("_cut")
    has_cut = core != stem

    stems_to_try = [f"{core}_sec_cut"]
    if not REQUIRE_CUT_IN_FILENAME:
        stems_to_try.append(f"{core}_sec")
    if not has_cut:
        stems_to_try.append(f"{stem}_sec_cut")

    # dict.fromkeys drops repeated stems while keeping the priority order.
    for candidate_stem in dict.fromkeys(stems_to_try):
        for ext in VALID_EXTENSIONS:
            candidate_path = os.path.join(folder_path, f"{candidate_stem}{ext}")
            if os.path.exists(candidate_path):
                return candidate_path
    return None

# ──────────────────────────────────────────
# --- Aspect-ratio bbox (replaces square bbox usage) ---



def find_aspect_bbox(
    mask: Image.Image,
    aspect: tuple[int,int] = (1,1),   # width:height, e.g. (1280,1600)
    padding: int = 40,
    upper_padding: int | None = None,
    horiz_padding: int = 0,
    min_margin: int | None = None,
    allow_padding: bool = True,
):
    """
    Return a rectangular bbox [x0, y0, x1, y1] that fully contains the mask + padding
    and matches the requested aspect ratio. When allow_padding is True the box may
    extend outside the image; callers should pad when cropping to preserve aspect.
    When allow_padding is False the bbox is kept inside the image while expanding
    other directions to honor the aspect ratio (minimal in-frame crop).

    Args:
        mask: image whose pixels brighter than 128 (after conversion to "L")
            mark the region of interest.
        aspect: (width, height) ratio of the returned box.
        padding: extra pixels added below the mask; also used above it when
            ``upper_padding`` is None.
        upper_padding: extra pixels added above the mask.
        horiz_padding: extra pixels added to the left and right of the mask.
        min_margin: additional margin applied on all four sides. When None it
            is derived from MASK_EXPAND_PX and MASK_FEATHER_PX, falling back
            to 40 if that computation fails.
        allow_padding: see above.

    Returns:
        A list of four ints [x0, y0, x1, y1]. In the in-frame branch the whole
        image box [0, 0, w, h] is returned if the clamped box is empty.

    Raises:
        ValueError: if the mask has no pixel above the threshold.
    """
    if min_margin is None:
        try:
            min_margin = int(MASK_EXPAND_PX + 3 * MASK_FEATHER_PX + 5)
        except Exception:
            min_margin = 40

    m = np.array(mask.convert("L"))
    h, w = m.shape
    ys, xs = np.where(m > 128)
    if xs.size == 0:
        raise ValueError("Mask has no white pixels!")

    # Tight bounds of the white region (inclusive pixel indices).
    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())

    if upper_padding is None:
        upper_padding = padding

    # Initial padded bbox (can go outside image bounds; padding applied later)
    x0 = x_min - horiz_padding - min_margin
    x1 = x_max + horiz_padding + min_margin
    y0 = y_min - upper_padding - min_margin
    y1 = y_max + padding + min_margin

    # NOTE(review): this bw/bh pair is not read before being reassigned further down.
    bw, bh = (x1 - x0), (y1 - y0)
    aw, ah = aspect
    target_ar = float(aw) / float(max(1, ah))

    def expand_to_aspect(x0, y0, x1, y1):
        # Grow the shorter dimension symmetrically until the box reaches the
        # target ratio; the box is never shrunk.
        bw = x1 - x0; bh = y1 - y0
        cur_ar = bw / float(max(1, bh))
        if cur_ar < target_ar:
            need_w = int(np.ceil(target_ar * bh))
            grow = max(0, need_w - bw)
            x0 -= grow // 2
            x1 += grow - grow // 2
        elif cur_ar > target_ar:
            need_h = int(np.ceil(bw / target_ar))
            grow = max(0, need_h - bh)
            y0 -= grow // 2
            y1 += grow - grow // 2
        return x0, y0, x1, y1

    x0, y0, x1, y1 = expand_to_aspect(x0, y0, x1, y1)

    if allow_padding:
        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
        return [x0, y0, x1, y1]

    # Constrain inside the frame without padding (used for secondary flow)
    x0 = max(0.0, float(x0))
    y0 = max(0.0, float(y0))
    x1 = min(w, float(x1))
    y1 = min(h, float(y1))

    bw, bh = x1 - x0, y1 - y0
    if bw <= 0 or bh <= 0:
        return [0, 0, w, h]

    def place_span(b0, b1, size, limit):
        # Position a span of length `size` (capped at `limit`) along one axis:
        # aim to centre it on [b0, b1], clamp the start between min_start and
        # max_start, then shift it back inside [0, limit] if it overflows.
        size = min(size, limit)
        span = b1 - b0
        min_start = max(0.0, b1 - size)
        max_start = min(limit - size, b0)
        desired = b0 - (size - span) / 2.0
        start = min(max(desired, min_start), max_start)
        end = start + size
        if end > limit:
            start -= (end - limit)
            end = limit
        if start < 0:
            end -= start
            start = 0
        return start, end

    if aw == ah:
        # Square target: one side length, capped by the smaller image dimension.
        target_size = max(bw, bh)
        target_size = min(target_size, w, h)
        x0, x1 = place_span(x0, x1, target_size, float(w))
        y0, y1 = place_span(y0, y1, target_size, float(h))
    else:
        # Non-square target: width and height are capped by the image
        # independently of each other.
        target_w = max(bw, float(np.ceil(bh * target_ar)))
        target_h = max(bh, float(np.ceil(target_w / target_ar)))
        target_w = min(target_w, w)
        target_h = min(target_h, h)
        x0, x1 = place_span(x0, x1, target_w, float(w))
        y0, y1 = place_span(y0, y1, target_h, float(h))

    return [int(round(x0)), int(round(y0)), int(round(x1)), int(round(y1))]




def crop_with_padding(img: Image.Image, bbox, fill):
    """Crop ``img`` to ``bbox``, filling any part of the box outside the image with ``fill``.

    The result always has the size of ``bbox`` and the mode of ``img``.
    """
    left, top, right, bottom = (int(v) for v in bbox)
    img_w, img_h = img.size
    canvas = Image.new(img.mode, (right - left, bottom - top), fill)

    # Part of the box that actually overlaps the image.
    clip_left, clip_top = max(0, left), max(0, top)
    clip_right, clip_bottom = min(img_w, right), min(img_h, bottom)

    if clip_right > clip_left and clip_bottom > clip_top:
        visible = img.crop((clip_left, clip_top, clip_right, clip_bottom))
        # A negative left/top means the box starts outside the image, so the
        # visible part is shifted by that amount on the canvas.
        canvas.paste(visible, (max(0, -left), max(0, -top)))
    return canvas

WHITE_RGB = (255,255,255)

def flatten_alpha_to_white(img: Image.Image) -> Image.Image:
    """Return an RGB copy of ``img`` with transparent areas composited onto white.

    Images without an alpha band or a "transparency" entry in ``img.info``
    are simply converted to RGB.
    """
    if img.mode in ("RGBA","LA") or ("transparency" in img.info):
        # Convert to RGBA first so there is a real alpha band in every case.
        # For images that only carry info["transparency"] (e.g. palette or RGB
        # PNGs) the last band of the original is colour/index data, not alpha,
        # and must not be used as the paste mask.
        rgba = img.convert("RGBA")
        bg = Image.new("RGB", img.size, WHITE_RGB)
        bg.paste(rgba, mask=rgba.getchannel("A"))
        return bg
    return img.convert("RGB")

def _tight_bbox_nonwhite_or_opaque(img: Image.Image):
    if img.mode in ("RGBA","LA") or ("transparency" in img.info):
        arr = np.asarray(img.convert("RGBA"))
        alpha = arr[...,3]
        fg = alpha > 0
    else:
        arr = np.asarray(img.convert("RGB"))
        fg = ~((arr[...,0]==255)&(arr[...,1]==255)&(arr[...,2]==255))
    if not np.any(fg): return None
    ys, xs = np.where(fg)
    x0, x1 = int(xs.min()), int(xs.max())+1
    y0, y1 = int(ys.min()), int(ys.max())+1
    return (x0,y0,x1,y1)

def crop_garment_keep_aspect(img: Image.Image) -> Image.Image:
    """Flatten ``img`` onto white and crop it to its tight foreground box.

    The flattened image is returned uncropped when no foreground is found or
    when the foreground already spans the whole frame.
    """
    tight = _tight_bbox_nonwhite_or_opaque(img)
    flattened = flatten_alpha_to_white(img)
    if tight is None or tight == (0, 0, flattened.width, flattened.height):
        return flattened
    return flattened.crop(tight)

def to_centered_square(gar: Image.Image, fill=WHITE_RGB) -> Image.Image:
    """Centre ``gar`` on a new RGB square whose side equals its longer edge."""
    width, height = gar.size
    edge = max(width, height)
    canvas = Image.new("RGB", (edge, edge), fill)
    offset = ((edge - width) // 2, (edge - height) // 2)
    canvas.paste(gar, offset)
    return canvas




def _add_caption_above_square(square_img: Image.Image, heading: str) -> Image.Image:
    """Return a white canvas with ``heading`` centred in a banner above ``square_img``.

    The banner height is max(60, height // 5) and the requested font size is
    max(64, width // 9).
    """
    pad_top = max(60, square_img.height // 5)
    canvas = Image.new("RGB", (square_img.width, square_img.height + pad_top), WHITE_RGB)
    draw = ImageDraw.Draw(canvas)
    font_size = max(64, square_img.width // 9)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", font_size)
    except Exception:
        # The TTF is not guaranteed to be installed. Ask for the built-in font
        # at the requested size so the caption stays legible; Pillow releases
        # without the `size` argument raise TypeError and get the fixed-size
        # default instead.
        try:
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            font = ImageFont.load_default()
    bbox = draw.textbbox((0, 0), heading, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]
    text_x = max(0, (canvas.width - text_w) // 2)
    text_y = max(0, (pad_top - text_h) // 2)
    draw.text((text_x, text_y), heading, fill=(0, 0, 0), font=font)
    canvas.paste(square_img, (0, pad_top))
    return canvas

def build_no_texture_card(size):
    """Build a white placeholder card reading "No texture file found".

    Args:
        size: (width, height); anything that is falsy or not of length 2
            falls back to (512, 512). Each side is raised to at least 128.
    """
    w, h = size if size and len(size) == 2 else (512, 512)
    w = max(128, int(w))
    h = max(128, int(h))
    canvas = Image.new("RGB", (w, h), WHITE_RGB)
    draw = ImageDraw.Draw(canvas)
    text = "No texture file found"
    font_size = max(64, min(w, h) // 9)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", font_size)
    except Exception:
        # Keep the requested size when the TTF is missing; Pillow releases
        # whose load_default() takes no `size` raise TypeError and get the
        # fixed-size default font.
        try:
            font = ImageFont.load_default(size=font_size)
        except TypeError:
            font = ImageFont.load_default()
    bbox = draw.textbbox((0, 0), text, font=font)
    tw, th = bbox[2] - bbox[0], bbox[3] - bbox[1]
    draw.text(((w - tw) // 2, (h - th) // 2), text, fill=(0, 0, 0), font=font)
    return canvas


def load_texture_reference(folder_path: str, *, secondary: bool = False, heading: str = "Main texture reference") -> Image.Image | None:
    """
    Load the texture reference image from ``folder_path``, if there is one.

    Looks for texture.<ext> (primary) or texture_sec.<ext> (secondary) using
    each entry of VALID_EXTENSIONS in turn. The first file found is
    centre-cropped to a square with side equal to its shorter edge and
    returned with ``heading`` captioned above it on white.

    Returns None when no file exists, or when the first existing file cannot
    be processed (a warning is printed in that case).
    """
    base_name = "texture_sec" if secondary else "texture"
    for extension in VALID_EXTENSIONS:
        candidate = os.path.join(folder_path, f"{base_name}{extension}")
        if not os.path.exists(candidate):
            continue
        try:
            source = open_upright(candidate).convert("RGB")
            shortest = min(source.size)
            if shortest > 0:
                cropped = ImageOps.fit(source, (shortest, shortest), method=Image.Resampling.LANCZOS, centering=(0.5, 0.5))
                return _add_caption_above_square(cropped, heading)
            # Degenerate (zero-sized) image: move on to the next extension.
        except Exception as tex_err:
            print(f"      ⚠️ Unable to use texture reference '{candidate}': {tex_err}")
            return None
    return None


In [None]:
# --- Visualisation helpers (restored) ---

import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps

def open_upright(path) -> Image.Image:
    """Open the image at ``path`` and return the result of applying its EXIF orientation."""
    with Image.open(path) as source:
        upright = ImageOps.exif_transpose(source)
    return upright

def show_gallery(img_list, titles=None, cols=3, w=4):
    """
    Show images in a grid with ``cols`` columns; each cell is ``w`` x ``w`` inches.

    Does nothing unless the global SHOW_VISUALS is truthy. Entries may be PIL
    images, numpy arrays (a 4-D array contributes its first element) or
    objects whose type name contains "Tensor" (moved to CPU and permuted with
    (1, 2, 0) before display). ``titles`` is optional and may be shorter than
    ``img_list``.
    """
    if not globals().get("SHOW_VISUALS", False):
        return

    rows = math.ceil(len(img_list) / cols)
    plt.figure(figsize=(cols * w, rows * w))

    for slot, item in enumerate(img_list, start=1):
        plt.subplot(rows, cols, slot)
        # 4-D numpy batch: show only the first element.
        if isinstance(item, np.ndarray) and item.ndim == 4:
            item = item[0]
        # Duck-typed tensor check avoids importing torch here.
        if "Tensor" in str(type(item)):
            item = item.detach().cpu().permute(1, 2, 0).numpy()
        plt.imshow(item)
        if titles and slot <= len(titles):
            plt.title(titles[slot - 1])
        plt.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
import numpy as np
import cv2
from PIL import Image

def paste_crop_back_debug(
    full_img: Image.Image,
    edited_crop: Image.Image,
    crop_box,               # (x0, y0, x1, y1) in full_img coords
    crop_mask,              # H×W uint8/bool, garment=white in crop coords
    solid_expand_px: int = 8,   # grow the 100% opaque region
    halo_px: int = 40,          # thickness of the soft halo OUTSIDE solid
    feather_px: int = 20,       # Gaussian sigma for halo
    *,
    bin_thresh: int = 127,
    edge_feather_px: int = 15,  # clamp width at crop borders
):
    """Blend an edited crop back into the full image through a feathered garment mask.

    The mask is binarised, dilated by ``solid_expand_px`` to form a fully
    opaque core, dilated again by ``halo_px`` and Gaussian-blurred to form a
    soft halo, and faded towards any crop border the halo reaches. The edited
    crop (resized to the crop box) is then composited over the matching
    region of ``full_img``.

    Args:
        full_img: image to paste into (not modified; a copy is returned).
        edited_crop: edited version of the crop; resized to the crop box size.
        crop_box: (x0, y0, x1, y1) in ``full_img`` coordinates.
        crop_mask: PIL image or array (uint8 or bool; a 3-D array contributes
            its first channel) in crop coordinates, garment = white/True.
        solid_expand_px, halo_px, feather_px, bin_thresh, edge_feather_px:
            see the inline comments in the signature.

    Returns:
        (out_img, alpha): the composited full image and the uint8 alpha map
        used for the blend (crop-box sized).
    """
    x0, y0, x1, y1 = map(int, crop_box)
    tgt_w, tgt_h   = (x1 - x0), (y1 - y0)

    # --- resize edited crop ---
    edit_rs = edited_crop.resize((tgt_w, tgt_h), Image.Resampling.LANCZOS)

    # --- 1) binary silhouette mask in crop coords ---
    mask_np = crop_mask
    if isinstance(mask_np, Image.Image):
        mask_np = np.array(mask_np.convert("L"))
    mask_np = np.asarray(mask_np)
    if mask_np.dtype == np.bool_:
        # cv2.resize does not accept boolean arrays; map True to 255 so the
        # threshold below keeps exactly the True pixels.
        mask_np = mask_np.astype(np.uint8) * 255
    if mask_np.ndim == 3:
        mask_np = mask_np[..., 0]

    mask_np = cv2.resize(mask_np, (tgt_w, tgt_h), interpolation=cv2.INTER_NEAREST)
    mask_bin = (mask_np > bin_thresh).astype(np.uint8)

    # --- 2) solid = expanded garment, outer = solid + halo -------------------
    def dilate(mask, r):
        if r <= 0:
            return mask.copy()
        ksize = max(1, r * 2 + 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        return cv2.dilate(mask, kernel, iterations=1)

    solid = dilate(mask_bin, solid_expand_px)               # fully opaque region
    outer = dilate(solid, halo_px)                          # solid + halo shell

    # band where we want partial alpha
    band = outer.clip(0, 1).astype(np.float32) * 255.0

    # --- 3) blur the band to get a smooth halo --------------------------------
    if feather_px > 0:
        band = cv2.GaussianBlur(
            band, (0, 0),
            sigmaX=feather_px,
            sigmaY=feather_px,
        )

    # --- 4) clamp halo near crop borders (no recursion) -----------------------
    H, W = band.shape
    ef = max(1, int(edge_feather_px))
    ef = min(ef, H // 2, W // 2)

    leaks_top    = band[0, :].max() > 0
    leaks_bottom = band[-1, :].max() > 0
    leaks_left   = band[:, 0].max() > 0
    leaks_right  = band[:, -1].max() > 0

    # Each ramp fades the halo linearly to zero at the border it touches.
    if leaks_top and ef > 0:
        ramp = np.linspace(0.0, 1.0, ef, endpoint=True).reshape(-1, 1)
        band[:ef, :] *= ramp
    if leaks_bottom and ef > 0:
        ramp = np.linspace(1.0, 0.0, ef, endpoint=True).reshape(-1, 1)
        band[-ef:, :] *= ramp
    if leaks_left and ef > 0:
        ramp = np.linspace(0.0, 1.0, ef, endpoint=True).reshape(1, -1)
        band[:, :ef] *= ramp
    if leaks_right and ef > 0:
        ramp = np.linspace(1.0, 0.0, ef, endpoint=True).reshape(1, -1)
        band[:, -ef:] *= ramp

    # --- 5) final alpha: 255 inside solid, halo in the band only --------------
    # NOTE(review): this also overrides the border fade wherever the solid
    # region itself touches a crop border — confirm that is intended.
    alpha = band.copy()
    alpha[solid > 0] = 255.0
    alpha = np.clip(alpha, 0, 255).astype(np.uint8)

    # --- 6) composite back into full image ------------------------------------
    # A 2-D uint8 array maps to mode "L" without the deprecated `mode` argument.
    mask_img = Image.fromarray(alpha)
    region   = full_img.crop((x0, y0, x1, y1))
    if edit_rs.mode != region.mode:
        # Image.composite requires both images to share a mode.
        edit_rs = edit_rs.convert(region.mode)
    comp     = Image.composite(edit_rs, region, mask_img)
    out_img  = full_img.copy()
    out_img.paste(comp, (x0, y0))

    return out_img, alpha


In [None]:

# --- NanoBanana Pro try-on (Google Cloud GenAI) ---

import os
from google import genai
from google.genai import types
from PIL import Image

def _load_gemini_api_key():
    key = os.getenv("GEMINI_API_KEY") or os.getenv("GEMINI_APIKEY")
    try:
        from google.colab import userdata
        key = key or userdata.get("GEMINI_API_KEY")
    except Exception:
        pass
    if not key:
        raise ValueError("Set GEMINI_API_KEY in environment or Colab userdata.")
    return key

# Module-level GenAI client used by run_nanobanana_tryon; built when the cell
# runs, so a missing API key raises ValueError here rather than at first use.
genai_client = genai.Client(api_key=_load_gemini_api_key())

def _extract_first_image(resp):
    import io
    parts = []
    if hasattr(resp, "parts"):
        parts.extend(resp.parts)
    for cand in getattr(resp, "candidates", []):
        parts.extend(getattr(getattr(cand, "content", None), "parts", []) or [])

    for part in parts:
        if isinstance(part, Image.Image):
            return part
        as_image = getattr(part, "as_image", None)
        if callable(as_image):
            img = as_image()
            if isinstance(img, Image.Image):
                return img
        inline = getattr(part, "inline_data", None)
        if inline and getattr(inline, "data", None):
            return Image.open(io.BytesIO(inline.data)).convert("RGB")
    raise ValueError("No image returned from NanoBanana Pro response.")


def run_nanobanana_tryon(model_image: Image.Image, garment_image: Image.Image,
                         *, aspect_ratio: str = GEN_ASPECT_RATIO,
                         image_size: str = GEN_IMAGE_SIZE,
                         prompt: str | None = None,
                         extra_images=None):
    """Generate a try-on image from a model photo and a garment photo.

    Builds a multimodal request (prompt text, labelled model image, labelled
    garment image, then any ``extra_images``), asks ``NANOBANANA_MODEL_ID`` for
    an IMAGE response, and returns the first image found, converted to RGB.

    Args:
        model_image: photo of the model; sent as RGB.
        garment_image: photo of the garment; sent as RGB.
        aspect_ratio: forwarded to ``types.ImageConfig``.
        image_size: forwarded to ``types.ImageConfig``.
        prompt: overrides ``TRYON_PROMPT`` when non-empty.
        extra_images: optional iterable of extra content items appended as-is.

    Raises:
        ValueError: propagated from ``_extract_first_image`` when the response
            carries no image.
    """
    request_parts = [
        prompt if prompt else TRYON_PROMPT,
        "Model image:",
        model_image.convert("RGB"),
        "Garment image:",
        garment_image.convert("RGB"),
    ]
    if extra_images:
        request_parts.extend(extra_images)

    image_cfg = types.ImageConfig(aspect_ratio=aspect_ratio, image_size=image_size)
    generation_cfg = types.GenerateContentConfig(
        response_modalities=["IMAGE"],
        image_config=image_cfg,
    )
    response = genai_client.models.generate_content(
        model=NANOBANANA_MODEL_ID,
        contents=request_parts,
        config=generation_cfg,
    )
    return _extract_first_image(response).convert("RGB")

print("✅ NanoBanana Pro client ready.")


In [None]:

# --- SAM3 segmentation via fal.ai for on-the-fly garment masks ---

import base64, io, json
import fal_client
import requests

# In-memory cache of generated masks, keyed by
# (base image path, garment folder, mask variant) in generate_mask_with_gemini.
MASK_CACHE = {}
# Case-folded copies of the configured garment lists, used for
# case-insensitive membership tests in _classify_garment_category.
TOP_GARMENTS_SET = {g.casefold() for g in TOP_GARMENTS}
BOTTOM_GARMENTS_SET = {g.casefold() for g in BOTTOM_GARMENTS}

def _classify_garment_category(category: str) -> str | None:
    """Classify a garment category name as ``"top"``, ``"bottom"`` or ``None``.

    The name is trimmed, inner whitespace runs are collapsed to one space, and
    the result is case-folded before being looked up in the top set first and
    then the bottom set.
    """
    trimmed = str(category or "").strip()
    normalized = re.sub(r"\s+", " ", trimmed).casefold()
    for label, members in (("top", TOP_GARMENTS_SET), ("bottom", BOTTOM_GARMENTS_SET)):
        if normalized in members:
            return label
    return None


def _garment_category_from_path(path: str) -> str:
    parts = os.path.normpath(path).split(os.sep)
    lowers = [p.casefold() for p in parts]
    cat = None
    if "siksilk" in lowers:
        last_idx = len(lowers) - 1 - lowers[::-1].index("siksilk")
        if last_idx + 1 < len(parts):
            cat = parts[last_idx + 1]
    if not cat:
        try:
            rel = os.path.relpath(path, GARMENTS_ROOT)
            if not rel.startswith(".."):  # path is under garments root
                rel_parts = rel.split(os.sep)
                if rel_parts:
                    cat = rel_parts[0]
        except Exception:
            pass
    if not cat:
        parent = os.path.basename(os.path.dirname(path))
        cat = parent or "garment"
    cat_clean = re.sub(r"[_]+", " ", str(cat)).strip()
    cat_clean = re.sub(r"\s+", " ", cat_clean)
    return cat_clean or "garment"

def _build_mask_prompt(category: str, variant: str | None = None) -> str:
    """Render ``MASK_PROMPT_TEMPLATE`` for a garment category.

    The category is trimmed and naively singularised by dropping ONE trailing
    "s". Words ending in "ss" are left alone, so "dress" stays "dress" instead
    of being stripped to "dre" as ``rstrip("s")`` would do.

    Args:
        category: garment category text.
        variant: optional variant label; ``None`` is rendered as "".

    Returns:
        The stripped, formatted prompt, or the cleaned category when the
        template renders to an empty string.
    """
    clean_category = str(category).strip()
    if clean_category.endswith("s") and not clean_category.endswith("ss"):
        clean_category = clean_category[:-1]
    prompt = MASK_PROMPT_TEMPLATE.format(
        category=clean_category,
        variant=(variant or "").strip(),
    ).strip()
    return prompt or clean_category

def _strip_json_block(text: str) -> str:
    if "```json" in text:
        return text.split("```json", 1)[1].split("```", 1)[0].strip()
    if "```" in text:
        return text.split("```", 1)[1].split("```", 1)[0].strip()
    return text.strip()

def _mask_data_to_canvas(data_url: str, target_size):
    w, h = target_size
    if data_url.startswith("http"):
        resp = requests.get(data_url)
        resp.raise_for_status()
        raw = resp.content
    else:
        encoded = data_url.split(",", 1)[1] if "," in data_url else data_url
        encoded = encoded.strip()
        if not encoded:
            raise ValueError("Mask data missing from SAM response.")
        pad_len = (-len(encoded)) % 4
        if pad_len:
            encoded = encoded + "=" * pad_len
        try:
            raw = base64.b64decode(encoded)
        except Exception as e:
            raise ValueError(f"Invalid mask image data: {e}")
    mask_img = Image.open(io.BytesIO(raw)).convert("L")
    if mask_img.size != (w, h):
        mask_img = mask_img.resize((w, h), Image.Resampling.NEAREST)
    return mask_img.point(lambda v: 255 if v > 128 else 0)

def _image_to_data_url(img: Image.Image) -> str:
    """Encode ``img`` as PNG and return it as a ``data:image/png;base64,`` URL."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return "data:image/png;base64," + encoded.decode()

def _extract_masks_from_result(result):
    masks = []
    for m in result.get("masks") or []:
        if isinstance(m, dict):
            data = m.get("file_data") or m.get("data")
            ct = m.get("content_type") or "image/png"
            if data:
                if not str(data).startswith("data:"):
                    data = f"data:{ct};base64,{data}"
                masks.append(data)
            elif m.get("url"):
                masks.append(m["url"])
        elif isinstance(m, str):
            masks.append(m)
    img = result.get("image")
    if img:
        if isinstance(img, dict):
            data = img.get("file_data") or img.get("data")
            ct = img.get("content_type") or "image/png"
            if data:
                if not str(data).startswith("data:"):
                    data = f"data:{ct};base64,{data}"
                masks.append(data)
            elif img.get("url"):
                masks.append(img["url"])
        elif isinstance(img, str):
            masks.append(img)
    return masks

def _fetch_sam_masks(image_data_url: str, category: str):
    """Run the ``FAL_SAM_MODEL_ID`` endpoint and return its mask URLs.

    Args:
        image_data_url: image passed as the ``image_url`` argument.
        category: text prompt passed as ``text_prompt``.

    Returns:
        The list produced by ``_extract_masks_from_result`` for the response.
        Progress log messages from the queue are printed as they arrive.
    """
    def _echo_progress(update):
        # Only in-progress updates are inspected for log lines.
        if not isinstance(update, fal_client.InProgress):
            return
        for entry in update.logs:
            text = entry.get("message")
            if text:
                print(text)

    request_args = {
        "image_url": image_data_url,
        "text_prompt": category,
        "apply_mask": False,
        "return_multiple_masks": False,
        "max_masks": 1,
        "output_format": "png",
        "sync_mode": True,
    }
    response = fal_client.subscribe(
        FAL_SAM_MODEL_ID,
        arguments=request_args,
        with_logs=True,
        on_queue_update=_echo_progress,
    )
    return _extract_masks_from_result(response)

def generate_mask_with_gemini(base_img_path: str, garment_folder: str, *, mask_variant: str | None = None):
    """Segment a garment in ``base_img_path`` and return ``(mask, label)``.

    Despite the name, segmentation goes through ``_fetch_sam_masks`` (fal.ai
    SAM3). Results are cached in ``MASK_CACHE`` per
    ``(base_img_path, garment_folder, mask_variant or "primary")``.

    Args:
        base_img_path: path of the image to segment.
        garment_folder: garment path used to derive the category prompt.
        mask_variant: optional variant label. When it contains "sec" and the
            category is a known top/bottom, the prompt targets the opposite
            garment ("bottom garment" / "upper clothes").

    Returns:
        ``(mask_image, "sam3:<prompt_category>")`` where ``mask_image`` is a
        binary "L" image at the original image size.

    Raises:
        ValueError: if ``FAL_KEY`` is unset, no masks are returned, or none of
            the returned masks can be decoded.
    """
    cache_key = (base_img_path, garment_folder, mask_variant or "primary")
    if cache_key in MASK_CACHE:
        return MASK_CACHE[cache_key]
    if not FAL_KEY:
        raise ValueError("Set FAL_KEY variable for fal.ai SAM3 segmentation.")
    category = _garment_category_from_path(garment_folder)
    prompt_category = str(category).strip()
    main_class = _classify_garment_category(category)
    if mask_variant:
        prompt_category = f"{prompt_category} ({mask_variant})"
        if "sec" in str(mask_variant).lower() and main_class:
            if main_class == "top":
                prompt_category = "bottom garment"
            elif main_class == "bottom":
                prompt_category = "upper clothes"
    # Naive singularisation: drop ONE trailing "s", but keep words ending in
    # "ss" intact ("dress" must not become "dre" as rstrip("s") would make it).
    if prompt_category.endswith("s") and not prompt_category.endswith("ss"):
        prompt_category = prompt_category[:-1]

    # convert() returns an in-memory copy, so the file can be closed right away.
    with Image.open(base_img_path) as opened:
        base_img = opened.convert("RGB")
    orig_w, orig_h = base_img.size
    inf_img = base_img
    scale = 1.0
    max_side = max(orig_w, orig_h)
    if max_side > MASK_MAX_SIZE:
        # Downscale for inference; the mask is scaled back up at the end.
        scale = MASK_MAX_SIZE / float(max_side)
        inf_img = base_img.resize((max(1, int(orig_w * scale)), max(1, int(orig_h * scale))), Image.Resampling.LANCZOS)

    print(f"Segmentation started for {os.path.basename(base_img_path)} [{prompt_category}] via fal.ai SAM3")
    prompt = _build_mask_prompt(prompt_category, mask_variant)
    masks = _fetch_sam_masks(_image_to_data_url(inf_img), prompt)
    if not masks:
        raise ValueError("SAM3 did not return any masks.")

    # Use the first mask that decodes; remember why the others failed.
    mask_canvas = None
    errors = []
    for m in masks:
        if not isinstance(m, str):
            continue
        try:
            mask_canvas = _mask_data_to_canvas(m, inf_img.size)
            break
        except Exception as e:
            errors.append(str(e))
            continue

    if mask_canvas is None:
        raise ValueError(f"SAM3 returned masks but none were usable: {errors}")

    if scale != 1.0:
        mask_canvas = mask_canvas.resize((orig_w, orig_h), Image.Resampling.NEAREST)
    result = (mask_canvas, f"sam3:{prompt_category}")
    MASK_CACHE[cache_key] = result
    return result

print("✅ SAM3 mask generation ready (fal.ai)")


In [None]:
# Colab cell — output routing + metadata helpers
import os, json
from PIL import PngImagePlugin

def ensure_dir(p):
    """Create directory ``p`` (with parents) if it is missing and return ``p``."""
    os.makedirs(p, exist_ok=True)
    return p

ensure_dir(OUTPUT_DIR)

def build_output_filename(sku_name: str, angle_code: str, ext=".png", suffix="") -> str:
    """Build ``<sku>-<angle><suffix><ext>``, e.g. ``SS-12345-fr_rght.png``.

    The angle is normalised with ``_norm_angle``; a falsy ``suffix`` is treated
    as an empty string.
    """
    normalized_angle = _norm_angle(angle_code)
    tail = suffix if suffix else ""
    return "{}-{}{}{}".format(sku_name, normalized_angle, tail, ext)


import json, piexif
from PIL import Image

def save_png_with_metadata(img, out_path, details_payload=None, quality=95):
    """Save ``img`` as PNG, optionally embedding JSON in the EXIF UserComment.

    Args:
        img: image object exposing ``save(path, format=..., ...)``.
        out_path: destination file path.
        details_payload: JSON-serialisable object stored in the EXIF
            ``UserComment`` tag; falsy means no EXIF block is written.
        quality: accepted for call compatibility; not used by this function.
    """
    if details_payload:
        # The 8-byte header declares the ASCII character code, so the JSON is
        # emitted ASCII-only (non-ASCII characters become \uXXXX escapes).
        # Readers that decode the bytes as UTF-8 and json.loads() them still
        # recover the identical object.
        payload = json.dumps(details_payload, ensure_ascii=True).encode("ascii")
        user_comment = b"ASCII\x00\x00\x00" + payload
        exif_dict = {"0th": {}, "Exif": {piexif.ExifIFD.UserComment: user_comment}, "1st": {}, "GPS": {}, "Interop": {}}
        exif_bytes = piexif.dump(exif_dict)
        img.save(out_path, format="PNG", exif=exif_bytes)
    else:
        img.save(out_path, format="PNG")

import json, piexif
from PIL import Image, ImageOps


In [None]:
# --- Google APIs: gspread + Drive upload (Operations sync removed) ---

import google.auth
# OAuth scopes requested for the default credentials: Drive + Sheets.
SCOPES = ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/spreadsheets"]
creds, _ = google.auth.default(scopes=SCOPES)

import gspread
# gspread client used to append rows in upload_file_and_append_to_sheet.
gs = gspread.authorize(creds)

from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
# Drive v3 service object shared by the folder-lookup and upload helpers.
drive_svc = build("drive", "v3", credentials=creds)

# Drive MIME types for folders and shortcuts, used when resolving paths.
FOLDER_MIME   = "application/vnd.google-apps.folder"
SHORTCUT_MIME = "application/vnd.google-apps.shortcut"
# Only paths under this mounted-Drive prefix can be resolved to Drive IDs.
PATH_PREFIX   = "/content/drive/MyDrive/"

def _escape_name(name: str) -> str: return name.replace("'", r"'")

def _maybe_follow_shortcut(file_obj):
    """Return ``(id, mimeType)`` for a Drive file dict, resolving shortcuts.

    For a shortcut the pair comes from ``shortcutDetails`` (``targetId``,
    ``targetMimeType``); for anything else it is the file's own id and type.
    """
    if file_obj.get("mimeType") != SHORTCUT_MIME:
        return file_obj.get("id"), file_obj.get("mimeType")
    details = file_obj.get("shortcutDetails", {}) or {}
    return details.get("targetId"), details.get("targetMimeType")

def _list_children(parent_id: str, q_extra: str, page_size: int = 1000):
    """List non-trashed children of a Drive folder, following pagination.

    Args:
        parent_id: Drive ID of the parent folder.
        q_extra: extra query clause AND-ed (in parentheses) with the parent
            filter; falsy means no extra clause.
        page_size: ``pageSize`` requested per call.

    Returns:
        List of file dicts with ``id``, ``name``, ``mimeType`` and
        ``shortcutDetails``, gathered across all result pages.
    """
    q = f"'{parent_id}' in parents and trashed = false"
    if q_extra:
        q += f" and ({q_extra})"

    files = []
    page_token = None
    while True:
        resp = drive_svc.files().list(
            q=q, spaces="drive", pageSize=page_size,
            # nextPageToken must be requested explicitly when `fields` is set.
            fields="nextPageToken, files(id,name,mimeType,shortcutDetails)",
            pageToken=page_token,
            includeItemsFromAllDrives=True, supportsAllDrives=True,
        ).execute()
        files.extend(resp.get("files", []))
        page_token = resp.get("nextPageToken")
        if not page_token:
            return files

def _find_folder_id(parent_id: str, name: str):
    """Find the Drive ID of the child folder ``name`` under ``parent_id``.

    Lookup order:
      1. exact-name query: a real folder wins;
      2. exact-name query: a shortcut whose target is a folder;
      3. all folders/shortcuts under the parent, matched by stripped,
         case-folded name (shortcuts resolved to their target).

    Returns:
        The folder ID, or ``None`` if nothing matches.
    """
    type_filter = f"(mimeType = '{FOLDER_MIME}' or mimeType = '{SHORTCUT_MIME}')"

    exact_hits = _list_children(
        parent_id,
        q_extra=f"name = '{_escape_name(name)}' and {type_filter}",
    )

    # Pass 1: a real folder with the exact name.
    for entry in exact_hits:
        if entry["mimeType"] == FOLDER_MIME:
            return entry["id"]

    # Pass 2: a shortcut with the exact name that points at a folder.
    for entry in exact_hits:
        if entry["mimeType"] != SHORTCUT_MIME:
            continue
        target_id, target_mime = _maybe_follow_shortcut(entry)
        if target_mime == FOLDER_MIME:
            return target_id

    # Pass 3: case-insensitive comparison over every folder/shortcut child.
    wanted = name.strip().casefold()
    for entry in _list_children(parent_id, q_extra=type_filter):
        if entry.get("name", "").strip().casefold() != wanted:
            continue
        target_id, target_mime = _maybe_follow_shortcut(entry)
        if target_mime == FOLDER_MIME:
            return target_id

    return None


def _resolve_parent_id_and_filename_from_colab_path(colab_path: str):
    """Resolve a mounted-Drive path to ``(parent_folder_id, file_name)``.

    Args:
        colab_path: path starting with ``PATH_PREFIX``.

    Returns:
        ``(parent_id, desired_name)``: the Drive ID of the folder that should
        contain the file (``"root"`` when the file sits directly under the
        prefix) and the final path component.

    Raises:
        ValueError: if the path is outside ``PATH_PREFIX`` or has no file name.
        FileNotFoundError: if an intermediate folder cannot be found.
    """
    if not colab_path.startswith(PATH_PREFIX):
        raise ValueError(
            f"This helper supports only '{PATH_PREFIX}...'. Got: {colab_path}"
        )

    # Drop empty segments: "".split("/") is [""], which made the emptiness
    # check below unreachable, and "a//b" produced a lookup for a folder
    # named "".
    parts = [seg for seg in colab_path[len(PATH_PREFIX):].split("/") if seg]
    if not parts:
        raise ValueError("Path must include a file name.")

    parent_id = "root"

    # Walk every folder component; the last component is the file name.
    for part in parts[:-1]:
        next_id = _find_folder_id(parent_id, part)
        if not next_id:
            raise FileNotFoundError(f"Folder not found in path: '{part}'")
        parent_id = next_id

    desired_name = parts[-1]
    return parent_id, desired_name


def upload_to_drive_folder(
    local_path: str,
    parent_folder_id: str,
    desired_name: str | None = None
):
    """Upload a local file into a Drive folder and make it link-readable.

    Args:
        local_path: file to upload (resumable upload).
        parent_folder_id: Drive ID of the destination folder.
        desired_name: name to give the file; defaults to the local base name.

    Returns:
        The created file resource with ``id``, ``webViewLink``, ``name`` and
        ``parents``. After creation an ``anyone`` / ``reader`` permission is
        added to the file.
    """
    metadata = {
        "name": desired_name or os.path.basename(local_path),
        "parents": [parent_folder_id],
    }
    upload = MediaFileUpload(local_path, resumable=True)

    create_request = drive_svc.files().create(
        body=metadata,
        media_body=upload,
        fields="id, webViewLink, name, parents",
        supportsAllDrives=True,
    )
    created = create_request.execute()

    public_reader = {"type": "anyone", "role": "reader"}
    share_request = drive_svc.permissions().create(
        fileId=created["id"],
        body=public_reader,
        fields="id",
        supportsAllDrives=True,
    )
    share_request.execute()

    return created


def upload_file_and_append_to_sheet(
    local_path: str,
    target_colab_path: str,
    sku_name: str,
    angle: str,
    spreadsheet_id: str | None,
    worksheet_name: str | None,
):
    """Upload a file to Drive and, optionally, log it as a row in a sheet.

    Args:
        local_path: file to upload.
        target_colab_path: destination path under the mounted Drive prefix;
            its folders must already exist.
        sku_name: label shown in the first column (hyperlinked to the folder).
        angle: angle code written to the second column.
        spreadsheet_id: target spreadsheet key; no row is written when falsy.
        worksheet_name: target worksheet; no row is written when falsy.

    Returns:
        ``{"file_url": <link to the uploaded file>}``.
    """
    parent_id, desired_name = _resolve_parent_id_and_filename_from_colab_path(
        target_colab_path
    )

    uploaded = upload_to_drive_folder(local_path, parent_id, desired_name)
    file_id = uploaded["id"]

    file_url = (
        uploaded.get("webViewLink")
        or f"https://drive.google.com/file/d/{file_id}/view?usp=sharing"
    )
    folder_url = f"https://drive.google.com/drive/folders/{parent_id}"

    if spreadsheet_id and worksheet_name:
        ts = datetime.now(pytz.timezone(TIMEZONE)).strftime("%m-%d %H:%M:%S")
        uid = str(uuid.uuid4())

        sh = gs.open_by_key(spreadsheet_id)
        ws = sh.worksheet(worksheet_name)

        # Double any embedded quotes so the label cannot terminate the formula's
        # string literal (the row is written with USER_ENTERED parsing).
        safe_label = str(sku_name).replace('"', '""')
        sku_cell = f'=HYPERLINK("{folder_url}"; "{safe_label}")'

        ws.append_row(
            [
                sku_cell,
                angle,
                ts,
                file_url,
                uid,
                "Girls need to check",
                OPERATOR,
            ],
            value_input_option="USER_ENTERED",
        )

    return {"file_url": file_url}


# DETAILER SETUP + UTILS

# SETUP

In [None]:
import os, sys, torch, numpy as np, cv2, base64, gc, json
from pathlib import Path
from io import BytesIO
from PIL import Image, ImageOps
import piexif

# Canonical device handles used by the VRAM-juggling helpers.
CPU_DEVICE = torch.device("cpu")
# Falls back to the CPU device when CUDA is not available.
GPU_DEVICE = torch.device("cuda") if torch.cuda.is_available() else CPU_DEVICE

# `device` follows the configured DEVICE_STR, but only when CUDA is available.
device = torch.device(DEVICE_STR if torch.cuda.is_available() else "cpu")
print("✅ Torch device:", device)


In [None]:
import torch
from transformers import Sam3Processor, Sam3Model

HF_SAM3_ID = "facebook/sam3"
SAM3_CONFIDENCE = 0.05   # permissive to catch small logos; raise if predictions get noisy
SAM3_RESOLUTION = 1024
# Device SAM3 is moved to when it is about to run (see _prepare_for_sam3).
SAM3_PREFERRED_DEVICE = GPU_DEVICE
# Device SAM3 currently lives on; updated by _move_sam3_to.
SAM3_DEVICE = CPU_DEVICE  # keep SAM3 on CPU until needed to leave VRAM for insert-anything

sam3_processor = Sam3Processor.from_pretrained(HF_SAM3_ID)
sam3_model = Sam3Model.from_pretrained(HF_SAM3_ID).to(SAM3_DEVICE)
sam3_model.eval()
print(f"✅ HF SAM3 ready (current device: {SAM3_DEVICE}, preferred: {SAM3_PREFERRED_DEVICE})")



In [None]:
#@title Insert_anything on nunchaku
%cd /content/insert-anything
from PIL import Image
import torch
import os
import numpy as np
import cv2
from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, expand_image_mask
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from datetime import datetime

# Allow TF32 matmul/cuDNN kernels and let cuDNN benchmark algorithms.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True

# Logical device of the insert-anything pipelines; updated by the
# _move_insert_anything_to / _prepare_for_* helpers.
IA_DEVICE = GPU_DEVICE
IA_CPU_OFFLOAD = True  # sequential/model CPU offload for ~10GB GPUs; set False to keep models on GPU
IA_OFFLOAD_FOLDER = "/content/ia_offload"
dtype = torch.bfloat16
size = (1024, 1024)



# Load the pre-trained model and LoRA-for-nunchaku weights
# Please replace the paths with your own paths
transformer = NunchakuFluxTransformer2dModel.from_pretrained("mit-han-lab/svdq-int4-flux.1-fill-dev")

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    transformer=transformer,
    torch_dtype=dtype
)


# LoRA weights are read from a fixed path on the mounted Drive.
transformer.update_lora_params(
    path_or_state_dict="/content/drive/MyDrive/insert-anything-lora/insert-anything_extracted_lora_rank_256-bf16.safetensors"
)


# Adjust the LoRA strength
transformer.set_lora_strength(1)

redux = FluxPriorReduxPipeline.from_pretrained("black-forest-labs/FLUX.1-Redux-dev").to(dtype=dtype)



# NOTE(review): an earlier comment here said this code reduces GPU memory usage
# to 26GB at the cost of inference time; it is unclear which statement that
# referred to. Confirm whether it still applies.
os.environ["NNCF_GROUP_SIZE"] = "-1"      # disable token merging

if IA_CPU_OFFLOAD:
    # Prefer offloading with an explicit offload folder; fall back to the plain
    # call when this diffusers version rejects the keyword (TypeError).
    try:
        pipe.enable_model_cpu_offload(gpu_id=0, offload_folder=IA_OFFLOAD_FOLDER)
        redux.enable_model_cpu_offload(gpu_id=0, offload_folder=IA_OFFLOAD_FOLDER)
    except TypeError:
        pipe.enable_model_cpu_offload(gpu_id=0)
        redux.enable_model_cpu_offload(gpu_id=0)
    IA_DEVICE = GPU_DEVICE
    print("✅ Insert-anything pipelines using model CPU offload (10GB-friendly)")
else:
    # No offload: both pipelines live on IA_DEVICE permanently.
    pipe.to(IA_DEVICE)
    redux.to(IA_DEVICE)
    print(f"✅ Insert-anything pipelines ready on {IA_DEVICE}")





# UTILS

In [None]:
import matplotlib.pyplot as plt

In [None]:
def open_upright(path) -> Image.Image:
    """Open an image, apply its EXIF orientation, and return it as RGB."""
    with Image.open(path) as raw:
        upright = ImageOps.exif_transpose(raw)
        return upright.convert("RGB")

def open_source_with_black_bg(path: str) -> Image.Image:
    """Open a source image as upright RGB, flattening cut-outs onto black.

    When the file name contains ``_cut`` (case-insensitive) and the image has
    an alpha channel (mode ``RGBA`` or ``LA``), the image is composited over a
    black background using that alpha. Every other image is converted to RGB.

    The file is opened in a ``with`` block, matching ``open_upright``, so the
    handle is released as soon as the pixel data has been copied.
    """
    name_low = os.path.basename(path).lower()
    with Image.open(path) as raw:
        im = ImageOps.exif_transpose(raw)
        if "_cut" in name_low and im.mode in ("RGBA", "LA"):
            rgb = im.convert("RGB")
            alpha = im.getchannel("A")
            black = Image.new("RGB", im.size, (0, 0, 0))
            return Image.composite(rgb, black, alpha)
        return im.convert("RGB")


# NEW — root of subcategory-wide garment masks
# Self-assignment: a no-op when MASKS_ROOT is already configured, and a
# NameError at this line when it is not.
MASKS_ROOT = MASKS_ROOT  # reuse configured masks root
# File extensions tried, in this order, when looking up a mask file.
MASK_EXTS = ('.png', '.jpg', '.jpeg', '.webp', '.PNG', '.JPG', '.JPEG', '.WEBP')




import re
from pathlib import Path
from PIL import Image, ImageOps, ImageDraw

# --- Helper: get <Category>/<Subcategory> from the *source* path ------------
_SKU_DIR_RE = re.compile(r"SS-\d{3,7}", re.IGNORECASE)

def _category_subcategory_from_source(src_path: str) -> tuple[str, str] | None:
    """
    Resolve (Category, Subcategory) for a garment *source* path.

    Strategies, in order:
      1. path relative to WORKING_DIR with at least three components
         (Category/Subcategory/SKU/...) -> first two components;
      2. first path component matching the SKU pattern -> its two parents;
      3. the two components following a literal 'SikSilk' component.
    Returns None when none of them applies.
    """
    source = Path(src_path).resolve()
    working_root = Path(WORKING_DIR).resolve()

    # 1) relative to the working directory
    try:
        relative_parts = source.relative_to(working_root).parts
    except Exception:
        relative_parts = ()
    if len(relative_parts) >= 3:
        return relative_parts[0], relative_parts[1]

    # 2) anchor on the first SKU-looking directory
    segments = source.parts
    sku_at = next(
        (i for i, seg in enumerate(segments) if _SKU_DIR_RE.fullmatch(seg or "")),
        None,
    )
    if sku_at is not None and sku_at >= 2:
        return segments[sku_at - 2], segments[sku_at - 1]

    # 3) anchor on an explicit 'SikSilk' component
    if "SikSilk" in segments:
        anchor = segments.index("SikSilk")
        if anchor + 3 <= len(segments):
            return segments[anchor + 1], segments[anchor + 2]

    return None

# --- New: derive mask basename from *angle*, not from filename heuristics ----
def _mask_basename_from_angle(angle_code: str | None) -> str | None:
    """
    Map 'fr' -> 'fr_mask', 'fr_lft' -> 'fr_lft_mask', 'bc_cl' -> 'bc_cl_mask', etc.
    If angle_code is missing, return None (→ no mask).
    """
    if not angle_code:
        return None
    angle = angle_code.strip().lower()
    return f"{angle}_mask"

# --- Exact-only mask finder ---------------------------------------------------

def find_mask_for_generated_exact(gen_path: str, source_path: str) -> Path | None:
    """
    Exact-name mask lookup (no fuzzy matching).

    The angle is parsed from ``gen_path`` and (category, subcategory) from
    ``source_path``. The directory MASKS_ROOT/category/subcategory is then
    searched for ``<angle>_mask_agnostic.<ext>`` first and ``<angle>_mask.<ext>``
    second, trying each extension in MASK_EXTS. Returns the first existing
    path, or None (with a printed warning) when nothing is found.
    """
    _, angle = extract_sku_and_angle_from_path(gen_path)
    if not angle:
        print("⚠️  No angle parsed — proceeding without a mask.")
        return None
    angle = angle.strip().lower()

    cat_sub = _category_subcategory_from_source(source_path)
    if not cat_sub:
        print("⚠️  Could not resolve Category/Subcategory from source path — no mask.")
        return None
    category, subcategory = cat_sub
    mask_dir = Path(MASKS_ROOT) / category / subcategory

    # Agnostic masks take priority over regular ones.
    candidates = [f"{angle}_mask_agnostic", f"{angle}_mask"]
    lookups = ((stem, mask_dir / f"{stem}{ext}") for stem in candidates for ext in MASK_EXTS)
    for stem, candidate_path in lookups:
        if not candidate_path.exists():
            continue
        which = "agnostic" if stem.endswith("_agnostic") else "regular"
        print(f"✅ Found {which} mask: {candidate_path}")
        return candidate_path

    print(f"⚠️  No exact mask found in {mask_dir} for angle '{angle}' "
          f"(tried {candidates} with MASK_EXTS). Proceeding without mask.")
    return None

def load_binary_mask_for_generated(gen_path: str, source_path: str, gen_img: Image.Image) -> np.ndarray | None:
    """Load the exact-match mask for a generated image, aligned to ``gen_img``.

    Returns the aligned binary mask array, or None when no mask file is found.
    """
    mask_path = find_mask_for_generated_exact(gen_path, source_path)
    if mask_path is not None:
        with Image.open(mask_path) as stored:
            upright = ImageOps.exif_transpose(stored)
            return align_mask_to_image(upright, gen_img)
    return None

# --- Align (unchanged) --------------------------------------------------------
def align_mask_to_image(mask_img: Image.Image, target_img: Image.Image) -> np.ndarray:
    """Fit ``mask_img`` onto the canvas of ``target_img`` as a 0/255 mask.

    Two cases:
      * Same height and the target is exactly 2x or 3x as wide as the mask:
        the mask is tiled horizontally.
      * Otherwise the mask is scaled (nearest neighbour, aspect preserved) to
        fit inside the target and centred on a zero canvas.

    Returns:
        ``uint8`` array of shape ``(target_h, target_w)``: 255 where the
        aligned mask value is > 127, else 0. A zero-size mask or target yields
        an all-zero array of the target shape.
    """
    mw, mh = mask_img.size
    tw, th = target_img.size
    # Degenerate sizes: nothing to align (also avoids division by zero below).
    if tw == 0 or th == 0 or mw == 0 or mh == 0:
        return np.zeros((th, tw), np.uint8)
    if mh == th and (tw % mw) == 0 and 1 < (tw // mw) <= 3:
        k = tw // mw
        tiled = Image.new('L', (tw, th), 0)
        src = mask_img.convert('L')
        for i in range(k):
            tiled.paste(src, (i * mw, 0))
        M = np.array(tiled, dtype=np.uint8)
    else:
        scale = max(mw / tw, mh / th)
        # Clamp to [1, target]: extreme aspect ratios can round a side to 0,
        # which resize() rejects.
        new_w = min(tw, max(1, int(round(mw / scale))))
        new_h = min(th, max(1, int(round(mh / scale))))
        m_resized = mask_img.convert('L').resize((new_w, new_h), Image.NEAREST)
        M = np.zeros((th, tw), np.uint8)
        x0 = (tw - new_w) // 2; y0 = (th - new_h) // 2
        M[y0:y0+new_h, x0:x0+new_w] = np.array(m_resized, dtype=np.uint8)
    return ((M > 127).astype(np.uint8) * 255)


def enlarge_mask(mask_np: np.ndarray, scale: float = 1.05) -> np.ndarray:
    """Grow a mask outward in proportion to the size of its foreground.

    The dilation radius is ``(scale - 1)`` times the longer side of the
    foreground bounding box, at least 1 px, applied with an elliptical kernel.

    Returns:
        ``None`` for a ``None`` input; otherwise a 0/255 ``uint8`` mask. An
        empty mask, or ``scale <= 1``, returns the binarised input unchanged.
    """
    if mask_np is None:
        return None

    foreground = (mask_np > 0).astype(np.uint8)
    rows, cols = np.nonzero(foreground)
    if rows.size == 0 or scale <= 1.0:
        return foreground * 255

    # Longer side of the foreground bounding box, in pixels.
    extent = max(rows.max() - rows.min(), cols.max() - cols.min()) + 1
    radius = max(1, int(round(extent * (scale - 1.0))))
    side = 2 * radius + 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))
    grown = cv2.dilate(foreground, kernel, iterations=1)
    return (grown > 0).astype(np.uint8) * 255



# ---- visuals
def _draw_bbox(img: Image.Image, bb_xyxy, color="lime", width=4):
    """Return a copy of ``img`` with the box ``bb_xyxy`` outlined.

    A ``None`` box returns an unmodified copy.
    """
    canvas = img.copy()
    if bb_xyxy is not None:
        ImageDraw.Draw(canvas).rectangle(bb_xyxy, outline=color, width=width)
    return canvas

def _show_images(pairs, cols=3, figsize=(16,12)):
    """Show ``(title, image)`` pairs in a grid with ``cols`` columns.

    Unused grid cells are left blank (axes switched off).
    """
    n_rows = int(np.ceil(len(pairs) / cols))
    fig, grid = plt.subplots(n_rows, cols, figsize=figsize)
    # subplots() returns a bare Axes (not an array) for a 1x1 grid.
    panels = grid.flatten() if n_rows * cols > 1 else [grid]

    for panel, (title, picture) in zip(panels, pairs):
        panel.imshow(picture)
        panel.set_title(title, fontsize=10)
        panel.axis("off")
    for spare in panels[len(pairs):]:
        spare.axis("off")

    plt.tight_layout()
    plt.show()

def resize_and_pad(image, target_size=1024):
    """Letterbox ``image`` into a ``target_size`` x ``target_size`` square.

    The image is scaled (LANCZOS) so its longer side equals ``target_size``,
    then centred with a border. The border fill matches the image mode: 0 for
    single-channel modes, transparent black for RGBA, black otherwise.
    """
    src_w, src_h = image.size
    factor = target_size / max(1, max(src_w, src_h))
    fit_w = int(round(src_w * factor))
    fit_h = int(round(src_h * factor))
    fitted = image.resize((fit_w, fit_h), Image.LANCZOS)

    # Any odd leftover pixel goes to the right/bottom border.
    left = (target_size - fit_w) // 2
    top = (target_size - fit_h) // 2
    border = (left, top, target_size - fit_w - left, target_size - fit_h - top)

    fill_by_mode = {"L": 0, "1": 0, "I": 0, "F": 0, "RGBA": (0, 0, 0, 0)}
    fill_color = fill_by_mode.get(fitted.mode, (0, 0, 0))
    return ImageOps.expand(fitted, border, fill=fill_color)

def box_1024_to_original(box_xyxy_1024, original_w, original_h, target_size=1024):
    """Map a box from the letterboxed square back to original pixel coordinates.

    Inverts the geometry of ``resize_and_pad``: removes the centring padding,
    undoes the scale, rounds, and clamps to the original image bounds.

    Args:
        box_xyxy_1024: ``(x1, y1, x2, y2)`` in the padded square's coordinates.
        original_w: width of the original image.
        original_h: height of the original image.
        target_size: side of the padded square; defaults to 1024 to match the
            previous hard-coded value.

    Returns:
        ``[x1, y1, x2, y2]`` as ints within ``[0, original_w] x [0, original_h]``.
    """
    x1_sq, y1_sq, x2_sq, y2_sq = [float(v) for v in box_xyxy_1024]
    w, h = original_w, original_h
    # Same zero-size guard as resize_and_pad, so the two stay exact inverses.
    scale = target_size / max(1, max(w, h))
    new_w, new_h = int(round(w * scale)), int(round(h * scale))
    pad_w = (target_size - new_w) // 2
    pad_h = (target_size - new_h) // 2

    def _to_original(value, pad, upper):
        # Remove padding, undo scaling, then clamp into [0, upper].
        return min(max(int(round((value - pad) / scale)), 0), upper)

    return [
        _to_original(x1_sq, pad_w, w),
        _to_original(y1_sq, pad_h, h),
        _to_original(x2_sq, pad_w, w),
        _to_original(y2_sq, pad_h, h),
    ]



In [None]:
def apply_binary_mask(img_rgb: Image.Image, mask_np: np.ndarray | None, outside_color=(5,5,5)) -> Image.Image:
    """Keep ``img_rgb`` where the mask is set and paint the rest a flat colour.

    Args:
        img_rgb: source image; modes other than RGB/RGBA/L are converted to RGB.
        mask_np: mask array used as the composite mask (cast to ``uint8``);
            ``None`` returns the image untouched.
        outside_color: fill for the area outside the mask. An int is expanded
            to a grey colour; for RGBA a 3-tuple gets alpha 255; for L a tuple
            is averaged to a single int.

    Returns:
        The composited image in the (possibly converted) mode of ``img_rgb``.
    """
    if mask_np is None:
        return img_rgb

    stencil = Image.fromarray(mask_np.astype(np.uint8))
    if img_rgb.mode not in ("RGB", "RGBA", "L"):
        img_rgb = img_rgb.convert("RGB")
    mode = img_rgb.mode

    # Adapt the fill colour to the image mode.
    fill = outside_color
    if mode == "L":
        if isinstance(fill, tuple):
            fill = int(np.mean(fill))
        fill = int(fill)
    elif isinstance(fill, int):
        fill = (fill,) * 3 if mode == "RGB" else (fill,) * 3 + (255,)
    elif mode == "RGBA" and len(fill) == 3:
        fill = (*fill, 255)

    background = Image.new(mode, img_rgb.size, fill)
    return Image.composite(img_rgb, background, stencil)


# Dynamic, perimeter-based mask gating for detect_detail (logo-friendly)
import torch
import numpy as np
import cv2
from PIL import Image, ImageDraw



# --- VRAM juggling to avoid SAM3 + IA sharing GPU at the same time ----------------
# Define IA_DEVICE if no earlier cell has set it, so the helpers below can
# always read and assign it.
if "IA_DEVICE" not in globals():
    IA_DEVICE = GPU_DEVICE

def _move_sam3_to(target: torch.device):
    """Move the SAM3 model to ``target`` and record it in ``SAM3_DEVICE``.

    A CUDA target silently becomes the CPU device when CUDA is unavailable.
    Nothing happens if the model is already there. After a move to CPU the
    CUDA cache is emptied and garbage collection is run.
    """
    global SAM3_DEVICE
    destination = torch.device(target)
    if destination.type == "cuda" and not torch.cuda.is_available():
        destination = CPU_DEVICE

    if SAM3_DEVICE != destination:
        sam3_model.to(destination)
        SAM3_DEVICE = destination
        if destination.type == "cpu":
            torch.cuda.empty_cache()
            gc.collect()


def _move_insert_anything_to(target: torch.device):
    """Move the insert-anything pipelines to ``target`` and update ``IA_DEVICE``.

    A CUDA target becomes the CPU device when CUDA is unavailable. With
    ``IA_CPU_OFFLOAD`` enabled no pipeline is moved; only ``IA_DEVICE`` is
    updated. Otherwise ``pipe`` and ``redux`` are moved unless already on the
    target, and a move to CPU is followed by a CUDA cache flush and garbage
    collection.
    """
    global IA_DEVICE
    destination = torch.device(target)
    if destination.type == "cuda" and not torch.cuda.is_available():
        destination = CPU_DEVICE

    if IA_CPU_OFFLOAD:
        IA_DEVICE = GPU_DEVICE if destination.type == "cuda" else CPU_DEVICE
        return

    if IA_DEVICE != destination:
        pipe.to(destination)
        redux.to(destination)
        IA_DEVICE = destination
        if destination.type == "cpu":
            torch.cuda.empty_cache()
            gc.collect()


def _prepare_for_sam3():
    """Put SAM3 on its preferred device before a SAM3 inference.

    Without CPU offload the insert-anything pipelines are first moved to CPU.
    With CPU offload they are left alone and ``IA_DEVICE`` is set to the GPU
    device.
    """
    global IA_DEVICE
    if not IA_CPU_OFFLOAD:
        _move_insert_anything_to(CPU_DEVICE)
    _move_sam3_to(SAM3_PREFERRED_DEVICE)
    if IA_CPU_OFFLOAD:
        IA_DEVICE = GPU_DEVICE


def _prepare_for_insert_anything():
    """Move SAM3 to CPU before an insert-anything run.

    With CPU offload only ``IA_DEVICE`` is set to the GPU device; otherwise
    the insert-anything pipelines are moved to the GPU device.
    """
    global IA_DEVICE
    _move_sam3_to(CPU_DEVICE)
    if IA_CPU_OFFLOAD:
        IA_DEVICE = GPU_DEVICE
    else:
        _move_insert_anything_to(GPU_DEVICE)


# --- SAM3 helpers -----------------------------------------------------------
def _clip_box_to_image(box_xyxy, w: int, h: int):
    x1, y1, x2, y2 = box_xyxy
    x1 = max(0, min(w, int(round(x1))))
    y1 = max(0, min(h, int(round(y1))))
    x2 = max(0, min(w, int(round(x2))))
    y2 = max(0, min(h, int(round(y2))))
    return [x1, y1, x2, y2]


def _sam3_predict_text(image_pil: Image.Image, prompt: str, *, max_dets: int = 12, score_threshold: float = SAM3_CONFIDENCE):
    """Run SAM3 (HF) with a text prompt and return predictions sorted by score.

    Args:
        image_pil: image to segment.
        prompt: text prompt; an empty prompt returns ``[]`` without running
            the model.
        max_dets: maximum number of predictions returned.
        score_threshold: threshold passed to the processor's post-processing.

    Returns:
        List of dicts, highest score first, each with ``"box"`` (list taken
        from the processed boxes), ``"score"`` (float) and ``"mask"`` (0/1
        ``uint8`` array, or ``None`` when no masks were produced).

    Raises:
        RuntimeError: if ``sam3_processor`` / ``sam3_model`` are not defined.
    """
    if not prompt:
        return []
    if "sam3_processor" not in globals() or "sam3_model" not in globals():
        raise RuntimeError("SAM3 is not initialized. Run the SAM3 setup cell first.")

    inputs = sam3_processor(images=image_pil, text=prompt, return_tensors="pt")
    # Move tensor inputs to wherever the model currently lives.
    inputs = {k: v.to(SAM3_DEVICE) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = sam3_model(**inputs)

    # PIL size is (W, H); target_sizes is given reversed, i.e. (H, W).
    processed = sam3_processor.post_process_instance_segmentation(
        outputs,
        threshold=score_threshold,
        target_sizes=[image_pil.size[::-1]],
    )[0]

    boxes = processed.get("boxes")
    scores = processed.get("scores")
    masks = processed.get("masks")
    if boxes is None or scores is None or boxes.numel() == 0:
        return []

    boxes_np = boxes.detach().cpu().numpy()
    scores_np = scores.detach().cpu().numpy()
    masks_np = masks.detach().cpu().numpy() if masks is not None else None

    # Indices in descending score order.
    order = scores_np.argsort()[::-1]
    preds = []
    for idx in order[:max_dets]:
        mask_np = None
        if masks_np is not None:
            # Binarise at 0.5.
            mask_np = (masks_np[idx] > 0.5).astype(np.uint8)
        preds.append({
            "box": boxes_np[idx].tolist(),
            "score": float(scores_np[idx]),
            "mask": mask_np,
        })
    return preds

def _mask_crop_to_full(mask_crop: np.ndarray | None, crop_box_on_full, full_size):
    """
    Place a mask (aligned to a crop image) back onto the full canvas.
    crop_box_on_full = (lx, ty, rx, by) used to produce the crop.
    full_size = (W, H) of the destination image.
    """
    if mask_crop is None or crop_box_on_full is None:
        return mask_crop

    full_w, full_h = full_size
    lx, ty, rx, by = [int(round(v)) for v in crop_box_on_full]
    x0, y0 = max(0, lx), max(0, ty)
    x1, y1 = min(rx, full_w), min(by, full_h)
    if x1 <= x0 or y1 <= y0:
        return np.zeros((full_h, full_w), np.uint8)

    mx0, my0 = max(0, -lx), max(0, -ty)
    mx1, my1 = mx0 + (x1 - x0), my0 + (y1 - y0)

    patch = mask_crop[my0:my1, mx0:mx1]
    if patch.shape[1] != (x1 - x0) or patch.shape[0] != (y1 - y0):
        patch = cv2.resize(patch, (x1 - x0, y1 - y0), interpolation=cv2.INTER_NEAREST)

    full_mask = np.zeros((full_h, full_w), np.uint8)
    full_mask[y0:y1, x0:x1] = (patch > 0).astype(np.uint8) * 255
    return full_mask


# Spatially guided detect_detail:
# - perimeter-based mask polarity (as before)
# - strict+tolerant mask gates
# - spatial prior from source garment (normalized center+area)
# - combined score = w_score * norm_score + w_spatial * spatial_affinity


# ---- helpers ---------------------------------------------------------------
# Top-7 spatial re-ranking for detail detection
# - Keeps perimeter-based mask polarity & light gates
# - Takes SAM3's highest-scoring proposals, filters with mask gates, then
#   re-ranks TOP_K using spatial prior from the SOURCE garment
# - Normalizes SAM3 scores within those K and slightly down-weights rank-1

import matplotlib.pyplot as plt

# ---------- helpers you already use elsewhere ----------
def make_spatial_prior_from_box(bb_xyxy, img_size):
    """
    Turn a SOURCE detail box (on its garment crop) into a normalized prior.

    The box is clamped to the image, then expressed as a dict with the
    relative center ("cx", "cy") and relative area ("area"), each in [0, 1].
    Returns None when no box is given or the clamped box is empty.
    """
    if bb_xyxy is None:
        return None
    width, height = img_size

    def _clamp(value, upper):
        return max(0.0, min(upper, value))

    left, top, right, bottom = (float(v) for v in bb_xyxy)
    left, top = _clamp(left, width), _clamp(top, height)
    right, bottom = _clamp(right, width), _clamp(bottom, height)
    if right <= left or bottom <= top:
        return None

    # max(1.0, ...) keeps the divisions safe for zero-sized images.
    center_x = ((left + right) / 2.0) / max(1.0, width)
    center_y = ((top + bottom) / 2.0) / max(1.0, height)
    rel_area = ((right - left) * (bottom - top)) / max(1.0, (width * height))
    return {"cx": float(center_x), "cy": float(center_y), "area": float(rel_area)}

def _spatial_affinity(cx_n, cy_n, area_n, prior, mirror_ok=True,
                      sigma_center=0.16, sigma_area=0.50):
    """Gaussian affinity in [0,1] for center & (log)area; mirror-aware."""
    def _aff(cx_p):
        dc2 = (cx_n - cx_p)**2 + (cy_n - prior["cy"])**2
        s_center = np.exp(- dc2 / (2.0 * (sigma_center**2)))
        a = max(1e-6, area_n); ap = max(1e-6, prior["area"])
        dlog = np.log(a / ap)
        s_area = np.exp(- (dlog**2) / (2.0 * (sigma_area**2)))
        return float(s_center * s_area)
    base = _aff(prior["cx"])
    if mirror_ok:
        return max(base, _aff(1.0 - prior["cx"]))
    return base

# ---------- mask alignment helpers + main: top-7 re-ranking ----------


def _ensure_mask_for_image(mask_input, image_pil, *, crop_box_on_full=None):
    """
    Align a mask to image_pil.

    mask_input:
      • np.ndarray aligned to image_pil (H×W)  OR
      • (mask_full_np, "FULL") + crop_box_on_full=(lx,ty,rx,by) from crop_to_square

    Returns: uint8 mask (0/255) aligned to image_pil.size, with the same padding
    behavior as crop_to_square (i.e., if the crop went outside, we pad zeros).
    """
    if mask_input is None or (isinstance(mask_input, tuple) and len(mask_input)==2 and mask_input[0] is None):
        return None

    # Case 1: already aligned to this image
    if not (isinstance(mask_input, tuple) and len(mask_input) == 2 and isinstance(mask_input[0], np.ndarray) and mask_input[1] == "FULL"):
        m = mask_input
        if m.ndim == 3:
            m = m[...,0] if m.shape[2] > 1 else m.squeeze(-1)
        if m.dtype != np.uint8:
            m = (m > 0).astype(np.uint8) * 255
        if (m.shape[1], m.shape[0]) != image_pil.size:
            m = cv2.resize(m, image_pil.size, interpolation=cv2.INTER_NEAREST)
        return m

    # Case 2: FULL mask + crop box from crop_to_square
    mask_full, _ = mask_input
    assert crop_box_on_full is not None, "crop_box_on_full is required for FULL mask."

    Hf, Wf = mask_full.shape[:2]
    lx, ty, rx, by = crop_box_on_full  # exactly what crop_to_square returned

    # Target canvas (the square side used by crop_to_square)
    tgt_w = int(round(rx - lx))
    tgt_h = int(round(by - ty))

    # Source window (clamped to the full image bounds)
    sx1 = int(np.floor(max(0, lx)))
    sy1 = int(np.floor(max(0, ty)))
    sx2 = int(np.ceil(min(Wf, rx)))
    sy2 = int(np.ceil(min(Hf, by)))

    # Offsets where the source window lands on the target canvas
    dx = int(np.floor(max(0, -lx)))   # same as crop_to_square's dx
    dy = int(np.floor(max(0, -ty)))   # same as crop_to_square's dy

    # Build canvas and paste the clipped region at (dx,dy)
    canvas = np.zeros((tgt_h, tgt_w), dtype=np.uint8)
    if sx2 > sx1 and sy2 > sy1:
        patch = mask_full[sy1:sy2, sx1:sx2]
        if patch.ndim == 3:
            patch = patch[...,0] if patch.shape[2] > 1 else patch.squeeze(-1)
        ph, pw = patch.shape[:2]
        canvas[dy:dy+ph, dx:dx+pw] = (patch > 0).astype(np.uint8) * 255

    # If image_pil size differs by a pixel due to rounding, align by resize
    if (canvas.shape[1], canvas.shape[0]) != image_pil.size:
        canvas = cv2.resize(canvas, image_pil.size, interpolation=cv2.INTER_NEAREST)

    return canvas


def _build_inside_mask_1024(mask_aligned_np, image_pil, *,
                            border_sample_px=2, erode_px=1, dilate_px=2,
                            debug=False):
    """
    Build 1024×1024 INSIDE mask with perimeter-based polarity, using the
    SAME resize_and_pad as the image to guarantee geometric alignment.
    """
    if mask_aligned_np is None:
        return None

    # 1) pad the mask to 1024 with the SAME routine as the image
    mL = Image.fromarray(mask_aligned_np, mode="L")
    m1024L = resize_and_pad(mL, target_size=1024).convert("L")
    m1024 = (np.array(m1024L) > 0)

    # 2) Perimeter-majority: which value dominates the border?
    h, w = m1024.shape
    b = max(1, int(border_sample_px))
    perim = np.concatenate([m1024[0:b,:].ravel(), m1024[h-b:h,:].ravel(),
                            m1024[:,0:b].ravel(), m1024[:,w-b:w].ravel()])
    ones = int(perim.sum()); zeros = int(perim.size - perim.sum())
    background_is_true = (ones >= zeros)   # majority on border = background
    inside = (~m1024) if background_is_true else m1024

    # 3) Moprhology for robust gating
    if erode_px > 0:
        k_e = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px*2+1, erode_px*2+1))
        inside = cv2.erode(inside.astype(np.uint8), k_e, 1).astype(bool)
    if dilate_px > 0:
        k_d = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_px*2+1, dilate_px*2+1))
        inside = cv2.dilate(inside.astype(np.uint8), k_d, 1).astype(bool)

    if debug:
        bg_txt = "white/True" if background_is_true else "black/False"
        cov = float(inside.mean())
        print(f"[mask1024] perimeter True={ones} False={zeros} → background={bg_txt}; inside_cov={cov:.3f}")

    return inside


def detect_detail_topk7(image_pil: Image.Image,
                        detail_type: str,
                        *,
                        source_prior: dict | None,
                        restrict_mask,                 # EITHER aligned np.ndarray OR (mask_full_np, "FULL")
                        crop_box_on_full=None,         # required if restrict_mask is ("FULL")
                        threshold: float = 0.05,
                        TOP_K: int = 7,
                        mirror_ok: bool = True,
                        # light mask gates
                        min_inside_frac: float = 0.30,
                        center_must_be_inside: bool = True,
                        erode_px: int = 1,
                        dilate_px: int = 2,
                        border_sample_px: int = 2,
                        # scoring weights
                        w_spatial: float = 0.65,
                        w_score: float = 0.35,
                        rank_weights: list[float] | None = None,
                        debug: bool = False,
                        viz: bool = False,
                        viz_overlay_mask: bool = True):
    """
    Top-K re-ranking of SAM3 detections with optional mask gates.

    Pipeline:
      1. Ask SAM3 for `detail_type` (up to max(3*TOP_K, 12) proposals).
      2. Keep, in score order, the first TOP_K proposals that pass the gates:
         score >= threshold, non-empty box, at least `min_inside_frac` of the
         box inside the restrict mask and (optionally) the box center inside it.
         If nothing passes, the top SAM3 proposal is used un-gated.
      3. Min-max normalise the kept scores, multiply by a per-rank weight
         (`rank_weights`, indexed by the proposal's original SAM3 rank) and
         combine with the spatial affinity to `source_prior`:
         combo = w_spatial * spatial + w_score * normalised_score.
         Without a prior the spatial term is 0.

    `erode_px`, `dilate_px`, `border_sample_px`, `debug`, `viz` and
    `viz_overlay_mask` are accepted for call compatibility but are not used
    by this implementation.

    Returns (xyxy_on_image, raw_score, mask_on_image or None), or
    (None, None, None) when `detail_type` is blank or SAM3 finds nothing.
    """
    # A blank detail type would otherwise be sent to SAM3 as the prompt ".",
    # bypassing the empty-prompt guard in _sam3_predict_text.
    detail_name = (detail_type or "").strip()
    if not detail_name:
        return (None, None, None)

    # `not rank_weights` also covers an empty list, which would raise on [-1] below.
    if not rank_weights:
        rank_weights = [0.92, 1.00, 0.98, 0.97, 0.96, 0.955, 0.95]

    W0, H0 = image_pil.size
    prompt = detail_name + "."

    mask_aligned = _ensure_mask_for_image(restrict_mask, image_pil, crop_box_on_full=crop_box_on_full)
    mask_bool = (mask_aligned > 0) if mask_aligned is not None else None

    preds = _sam3_predict_text(image_pil, prompt, max_dets=max(TOP_K * 3, 12), score_threshold=threshold)
    if not preds:
        return (None, None, None)

    picked = []
    for rank, p in enumerate(preds):
        if threshold is not None and float(p["score"]) < threshold:
            continue
        box = _clip_box_to_image(p["box"], W0, H0)
        x1, y1, x2, y2 = box
        if x2 <= x1 or y2 <= y1:
            continue

        inside_frac = 1.0
        center_ok = True
        if mask_bool is not None:
            crop = mask_bool[y1:y2, x1:x2]
            area = max(1, (x2 - x1) * (y2 - y1))
            inside_frac = float(crop.sum()) / float(area)
            cxp, cyp = (x1 + x2) // 2, (y1 + y2) // 2
            center_ok = (0 <= cxp < W0 and 0 <= cyp < H0 and bool(mask_bool[cyp, cxp]))
            if inside_frac < min_inside_frac or (center_must_be_inside and not center_ok):
                continue

        picked.append({
            "box": box,
            "score": float(p["score"]),
            "rank": rank,
            "mask": p["mask"],
            "inside_frac": inside_frac,
        })
        if len(picked) >= TOP_K:
            break

    if not picked:
        # Nothing survived the gates: fall back to SAM3's best proposal as-is.
        base = preds[0]
        picked = [{"box": _clip_box_to_image(base["box"], W0, H0),
                   "score": float(base["score"]),
                   "rank": 0,
                   "mask": base["mask"],
                   "inside_frac": 0.0}]

    # Min-max normalise scores within the kept set (all-equal scores map to 0.5).
    s = np.array([p["score"] for p in picked], dtype=np.float32)
    s_min, s_max = float(s.min()), float(s.max())
    s_norm = np.ones_like(s) * 0.5 if s_max == s_min else (s - s_min) / (s_max - s_min)

    best = None
    for j, p in enumerate(picked):
        rw = rank_weights[p["rank"]] if p["rank"] < len(rank_weights) else rank_weights[-1]
        score_normed = float(s_norm[j] * rw)
        x1, y1, x2, y2 = p["box"]
        area = max(1, (x2 - x1) * (y2 - y1))
        cx_n = ((x1 + x2) / 2.0) / max(1.0, W0)
        cy_n = ((y1 + y2) / 2.0) / max(1.0, H0)
        area_n = area / float(max(1, W0 * H0))
        spatial = _spatial_affinity(cx_n, cy_n, area_n, source_prior, mirror_ok=mirror_ok) if source_prior else 0.0
        combo = w_spatial * spatial + w_score * score_normed
        if best is None or combo > best["combo"]:
            best = {**p, "combo": float(combo), "spatial": float(spatial), "score_norm": score_normed}

    return best["box"], best["score"], best["mask"]


def detect_detail(image_pil: Image.Image,
                  detail_type: str,
                  threshold: float = 0.05,
                  used_boxes=None,
                  keep_best: bool = False,
                  iou_thr: float = 0.35,
                  restrict_mask: np.ndarray | None = None,
                  min_inside_frac: float = 0.40,
                  max_outside_frac: float = 0.70,
                  center_must_be_inside: bool = True,
                  erode_px: int = 1,
                  dilate_px: int = 2,
                  border_sample_px: int = 2,
                  debug: bool = False,
                  debug_topk: int = 5,
                  crop_box_on_full=None):
    """
    Simplified detail locator using SAM3 text grounding.

    Takes up to 10 SAM3 proposals for `detail_type` and returns the
    highest-scoring one that
      • scores at least `threshold`,
      • has a non-empty box after clipping to the image,
      • does not overlap any box in `used_boxes` by more than `iou_thr` (IoU),
      • passes the restrict-mask gates when a mask is given: box center inside
        the mask (if `center_must_be_inside`), inside fraction >=
        `min_inside_frac` and outside fraction <= `max_outside_frac`.
    When nothing qualifies, `keep_best=True` falls back to SAM3's top proposal.

    `erode_px`, `dilate_px` and `border_sample_px` are accepted for call
    compatibility but are not used by this implementation. With `debug=True`
    the first `debug_topk` candidates that reached the mask gates are printed
    with their gate values, whether or not they passed.

    Returns: (xyxy_on_image, score, mask_on_image), or (None, None, None).
    """
    used_boxes = used_boxes or []
    # A blank detail type would otherwise be sent to SAM3 as the prompt ".",
    # bypassing the empty-prompt guard in _sam3_predict_text.
    detail_name = (detail_type or "").strip()
    if not detail_name:
        return (None, None, None)
    prompt = detail_name + "."
    W, H = image_pil.size

    mask_aligned = _ensure_mask_for_image(restrict_mask, image_pil, crop_box_on_full=crop_box_on_full)
    mask_bool = (mask_aligned > 0) if mask_aligned is not None else None

    preds = _sam3_predict_text(image_pil, prompt, max_dets=10, score_threshold=threshold)
    if not preds:
        return (None, None, None)

    def _iou(a, b):
        """Intersection-over-union of two xyxy boxes."""
        ax1, ay1, ax2, ay2 = a; bx1, by1, bx2, by2 = b
        xi1, yi1 = max(ax1, bx1), max(ay1, by1)
        xi2, yi2 = min(ax2, bx2), min(ay2, by2)
        iw, ih = max(0, xi2 - xi1), max(0, yi2 - yi1)
        inter = iw * ih
        if inter == 0:
            return 0.0
        area_a = max(1, (ax2 - ax1) * (ay2 - ay1))
        area_b = max(1, (bx2 - bx1) * (by2 - by1))
        union = area_a + area_b - inter
        return inter / union

    best = None
    debug_rows = []
    for p in preds:
        if threshold is not None and float(p["score"]) < threshold:
            continue
        box = _clip_box_to_image(p["box"], W, H)
        x1, y1, x2, y2 = box
        if x2 <= x1 or y2 <= y1:
            continue

        if any(_iou(box, ub) > iou_thr for ub in used_boxes):
            continue

        inside_frac = 1.0
        outside_frac = 0.0
        center_ok = True
        gated_out = False
        if mask_bool is not None:
            crop = mask_bool[y1:y2, x1:x2]
            area = max(1, (x2 - x1) * (y2 - y1))
            inside_frac = float(crop.sum()) / float(area)
            outside_frac = 1.0 - inside_frac
            cx_i, cy_i = (x1 + x2) // 2, (y1 + y2) // 2
            center_ok = (0 <= cx_i < W and 0 <= cy_i < H and bool(mask_bool[cy_i, cx_i])) if center_must_be_inside else True
            gated_out = (not center_ok) or inside_frac < min_inside_frac or outside_frac > max_outside_frac

        # Record the diagnostics BEFORE applying the gates, so rejected
        # candidates show up in the debug listing too (previously only
        # accepted ones were logged, which made the gate columns useless).
        if debug and len(debug_rows) < debug_topk:
            debug_rows.append({
                "score": float(p["score"]),
                "box": box,
                "inside": inside_frac,
                "outside": outside_frac,
                "center": center_ok,
            })

        if gated_out:
            continue

        if best is None or p["score"] > best["score"]:
            best = {"box": box, "score": float(p["score"]), "mask": p["mask"], "inside_frac": inside_frac}

    if best is None:
        if keep_best:
            # Un-gated fallback: SAM3's top proposal.
            base = preds[0]
            best = {"box": _clip_box_to_image(base["box"], W, H), "score": float(base["score"]), "mask": base["mask"], "inside_frac": 0.0}
        else:
            if debug:
                print("[detect_detail] No candidate satisfied mask gates.")
            return (None, None, None)

    if debug and debug_rows:
        print("[detect_detail/debug] top candidates (after score>thr):")
        for row in sorted(debug_rows, key=lambda r: r["score"], reverse=True):
            print(f"  score={row['score']:.3f} inside={row['inside']:.2f} outside={row['outside']:.2f} center={row['center']} box={row['box']}")

    return best["box"], best["score"], best["mask"]


def detect_garment_box(img: Image.Image, garment_tag: str, threshold=0.25, restrict_mask: np.ndarray | None = None):
    """
    Locate the garment box on `img`.

    With `restrict_mask`: SAM3 is not queried. The mask is padded to 1024 via
    `resize_and_pad`, thresholded at >127, and the extent of its foreground
    pixels is mapped back with `box_1024_to_original`. Returns None when the
    mask has no foreground.

    Without `restrict_mask`: SAM3 is prompted with the garment tag and the
    highest-scoring proposal with score >= `threshold` and a non-empty clipped
    box is returned as [x1, y1, x2, y2]; None if there is no such proposal.
    """
    O_W, O_H = img.size
    if restrict_mask is not None:
        m1024 = resize_and_pad(Image.fromarray(restrict_mask, 'L'), 1024).convert('L')
        mask_1024_np = (np.array(m1024) > 127)
        ys, xs = np.where(mask_1024_np)
        if xs.size == 0 or ys.size == 0:
            return None
        # NOTE(review): xs.max()/ys.max() are inclusive pixel indices; confirm
        # whether box_1024_to_original expects an exclusive right/bottom edge.
        x1, y1, x2, y2 = [float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())]
        return box_1024_to_original([x1, y1, x2, y2], O_W, O_H)

    # From here on restrict_mask is always None (the branch above returns), so
    # no mask gating applies; the former gating code was unreachable and has
    # been removed.
    preds = _sam3_predict_text(img, f"{garment_tag.strip()} .", max_dets=6, score_threshold=threshold)
    if not preds:
        return None

    best = None
    for p in preds:
        if threshold is not None and float(p["score"]) < threshold:
            continue
        box = _clip_box_to_image(p["box"], O_W, O_H)
        if box[2] <= box[0] or box[3] <= box[1]:
            continue
        if best is None or p["score"] > best["score"]:
            best = {"box": box, "score": float(p["score"])}

    return best["box"] if best else None


def bbox_to_mask(bb, img_size, pad_px=10):
    """
    Rasterise an xyxy box into a uint8 mask (255 inside, 0 outside).

    Args:
        bb: (x1, y1, x2, y2); x2/y2 are exclusive, matching how boxes are
            sliced elsewhere in this file. Float coordinates are rounded.
        img_size: (W, H) of the mask to create.
        pad_px: margin added on every side before clamping to the image.

    Returns:
        (H, W) uint8 array. A box that is empty after clamping gives all zeros.
    """
    W, H = img_size
    pad = int(round(pad_px))
    x1, y1, x2, y2 = (int(round(v)) for v in bb)
    # Slice ends are exclusive, so the right/bottom edge is clamped to W/H.
    # Clamping to W-1/H-1 (as before) dropped the last column/row of any box
    # touching the image border.
    x1 = min(W, max(0, x1 - pad)); y1 = min(H, max(0, y1 - pad))
    x2 = max(0, min(W, x2 + pad)); y2 = max(0, min(H, y2 + pad))
    m = np.zeros((H, W), np.uint8)
    if x2 > x1 and y2 > y1:
        m[y1:y2, x1:x2] = 255
    return m


def crop_detail(image_pil, mask_np, bb_xyxy, out_size=1024, pad_px=20):
    """
    Cut a square crop around a box from both the image and its mask.

    The box is padded by `pad_px`, clamped to the image, and grown to a square
    (side = longer edge) that is shifted to stay inside the image where
    possible. When the square is larger than the image in one direction, the
    crop box extends past the top/left edge (negative coordinates); PIL fills
    that area with padding and the mask is zero-padded the same way so image
    and mask stay aligned.

    Args:
        image_pil: source image (PIL).
        mask_np: mask array indexed as [row, col]; assumed to be aligned to
            `image_pil` — TODO confirm at call sites.
        bb_xyxy: (x1, y1, x2, y2) box; float coordinates are rounded.
        out_size: edge length of the returned crop and mask.
        pad_px: margin added around the box.

    Returns:
        (image_crop, mask_crop, crop_box) where crop_box = (lx, ty, rx, by) in
        full-image coordinates.
    """
    W, H = image_pil.size
    x1, y1, x2, y2 = [int(round(v)) for v in bb_xyxy]
    x1 = max(0, x1 - pad_px); y1 = max(0, y1 - pad_px)
    x2 = min(W, x2 + pad_px); y2 = min(H, y2 + pad_px)
    side = max(x2 - x1, y2 - y1)
    cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
    lx = max(0, cx - side // 2); rx = lx + side
    ty = max(0, cy - side // 2); by = ty + side
    if rx > W:
        lx -= (rx - W); rx = W
    if by > H:
        ty -= (by - H); by = H
    crop_box = (lx, ty, rx, by)
    img_c = image_pil.crop(crop_box).resize((out_size, out_size), Image.Resampling.LANCZOS)

    # Build the mask crop on a zero canvas instead of slicing directly: a
    # negative lx/ty would make numpy wrap around (or truncate) and the mask
    # would no longer match the padded image crop.
    m_c = np.zeros((by - ty, rx - lx) + mask_np.shape[2:], dtype=mask_np.dtype)
    sx0, sy0 = max(0, lx), max(0, ty)
    src = mask_np[sy0:max(sy0, by), sx0:max(sx0, rx)]
    oy, ox = sy0 - ty, sx0 - lx
    m_c[oy:oy + src.shape[0], ox:ox + src.shape[1]] = src

    m_c = cv2.resize(m_c, (out_size, out_size), interpolation=cv2.INTER_NEAREST)
    return img_c, m_c, crop_box


def adaptive_brightness(img, strength_dark=0.15, strength_light=0.03, clip=(0, 245)):
    """
    Scale an image's pixel values by a factor derived from its mean luminance.

    Mean luminance (Rec. 709 weights on channels 0..2, normalized to 0..1) is
    compared with 0.5: below it the factor is 1 - strength_dark * 2 * distance,
    otherwise 1 + strength_light * 2 * distance. The scaled values are clipped
    to `clip` and returned as a new uint8 PIL image.
    """
    pixels = np.asarray(img).astype(np.float32)
    red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]
    luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    mean_lum = float(luminance.mean() / 255.0)

    is_dark = mean_lum < 0.5
    factor = (
        1 + (-strength_dark) * (0.5 - mean_lum) * 2
        if is_dark
        else 1 + (strength_light) * (mean_lum - 0.5) * 2
    )

    scaled = np.clip(pixels * factor, *clip)
    return Image.fromarray(scaled.astype(np.uint8))


def paste_crop_back(full_img: Image.Image, edited_crop: Image.Image, crop_box, crop_mask: np.ndarray,
                    expand_px=20, feather_px=10) -> Image.Image:
    """
    Blend an edited crop back into `full_img` at `crop_box` and return it.

    `full_img` is modified in place (and also returned). The edited crop is
    first passed through `adaptive_brightness`, then resized to the crop box.
    The blend alpha comes from `crop_mask`: the mask is resized to the box,
    dilated by `expand_px`, Gaussian-feathered with sigma `feather_px`, and
    forced to fully opaque wherever the un-dilated mask is set.
    """
    toned = adaptive_brightness(edited_crop, strength_dark=0.15, strength_light=0.03)

    left, top, right, bottom = crop_box
    region_w = right - left
    region_h = bottom - top
    patch = toned.resize((region_w, region_h), Image.Resampling.LANCZOS)

    mask_rs = cv2.resize(crop_mask, (region_w, region_h), interpolation=cv2.INTER_NEAREST)
    core = mask_rs > 0
    grown = core.astype(np.uint8)
    if expand_px > 0:
        ksize = expand_px * 2 + 1
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        grown = cv2.dilate(grown, kernel, iterations=1)

    # Soft falloff around the grown mask; the original mask area stays opaque.
    alpha = cv2.GaussianBlur(grown.astype(np.float32) * 255, (0, 0), sigmaX=feather_px, sigmaY=feather_px)
    alpha[core] = 255
    alpha_img = Image.fromarray(alpha.clip(0, 255).astype(np.uint8))

    backdrop = full_img.crop((left, top, right, bottom))
    blended = Image.composite(patch, backdrop, alpha_img)
    full_img.paste(blended, (left, top))
    return full_img






In [None]:
# Normalization for detail types from legacy names
def _normalize_detail_type(t: str) -> str:
    t = (t or "").strip().lower()
    mapping = {
        "waistband lettering": "waist text",
        "sleeve lettering": "sleeve text",
        "sleeve_text": "sleeve text",
        "waist_text": "waist text",
    }
    return mapping.get(t, t)

def _postprocess_details(payload: dict) -> dict:
    details = payload.get("details", [])
    fixed = []
    for d in details:
        typ = _normalize_detail_type(d.get("type"))
        col = d.get("color")
        if typ in ALLOWED_DETAIL_TYPES:
            ent = {"type": typ}
            if typ != "sleeve text" and isinstance(col, str) and col.strip():
                ent["color"] = col.strip()
            fixed.append(ent)
    return {"details": fixed}

def _try_parse_json(s: str) -> dict | None:
    try:
        obj = json.loads(s)
        if isinstance(obj, dict) and "details" in obj:
            return obj
    except Exception:
        pass
    # try to extract first {...}
    m = re.search(r"\{[\s\S]*\}", s)
    if m:
        try:
            obj = json.loads(m.group(0))
            if isinstance(obj, dict) and "details" in obj:
                return obj
        except Exception:
            pass
    return None

def read_details_from_metadata(img_path: str) -> dict:
    """
    Read the detail list embedded in an image's metadata.

    Sources are tried in this order; the first one yielding a JSON object with
    a "details" key wins and is passed through `_postprocess_details`:
      1. any string value in the image's `info` dict,
      2. EXIF UserComment, then EXIF XPComment (tag 0x9C9C, UTF-16LE),
      3. the "XML:com.adobe.xmp" info entry,
      4. a sidecar file with the same name and a ".json" suffix.

    Returns {'details': [...]}, or the fallback {'details': [{'type': 'logo'}]}
    when nothing is found or reading fails (a warning is printed on failure).
    """
    def _decode_uc(x):
        """Decode an EXIF UserComment value, dropping known charset headers."""
        if isinstance(x, bytes):
            for head in [b"ASCII\0\0\0", b"UNICODE\0", b"JIS\0\0\0"]:
                if x.startswith(head):
                    x = x[len(head):]
            try:
                return x.decode("utf-8", "ignore")
            except Exception:
                return x.decode("latin-1", "ignore")
        if isinstance(x, str):
            return x
        return None

    try:
        # Context manager: the file handle is released on every exit path
        # (the previous version never closed the image).
        with Image.open(img_path) as im:
            # 1) PNG/JPEG info dict
            for k, v in (im.info or {}).items():
                if isinstance(v, str):
                    obj = _try_parse_json(v)
                    if obj:
                        return _postprocess_details(obj)

            # 2) EXIF: UserComment / XPComment
            exif_dict = None
            try:
                exif_bytes = im.info.get("exif", b"")
                if exif_bytes:
                    exif_dict = piexif.load(exif_bytes)
                else:
                    # The previous version handed im.tobytes() (raw pixels) to
                    # piexif here, which cannot contain EXIF. The decode itself
                    # is kept: presumably Pillow only exposes PNG text chunks
                    # stored after the pixel data once the image is loaded, and
                    # step 3 may rely on that — TODO confirm, then drop if not.
                    im.load()
            except Exception:
                exif_dict = None

            if exif_dict:
                uc = exif_dict.get("Exif", {}).get(piexif.ExifIFD.UserComment, None)
                s = _decode_uc(uc)
                if s:
                    obj = _try_parse_json(s)
                    if obj:
                        return _postprocess_details(obj)
                xp = exif_dict.get("0th", {}).get(0x9C9C, None)
                if xp:
                    try:
                        s = bytes(xp).decode("utf-16le", "ignore").rstrip("\x00")
                        obj = _try_parse_json(s)
                        if obj:
                            return _postprocess_details(obj)
                    except Exception:
                        pass

            # 3) XMP packet embedded in the image info
            if "XML:com.adobe.xmp" in (im.info or {}):
                obj = _try_parse_json(im.info["XML:com.adobe.xmp"])
                if obj:
                    return _postprocess_details(obj)

        # 4) Optional sidecar .json next to image
        side = Path(img_path).with_suffix(".json")
        if side.exists():
            try:
                obj = json.loads(side.read_text())
                if "details" in obj:
                    return _postprocess_details(obj)
            except Exception:
                pass
    except Exception as e:
        print(f"⚠️ metadata read failed for {img_path}: {e}")

    # 👇 Fallback when nothing found
    return {"details": [{"type": "logo"}]}


# Garment type inference (from path)
def extract_garment_type_from_path(image_path: str, allowed_types=ALLOWED_GARMENT_TYPES) -> str:
    """
    Guess the garment type from an image path.

    Each allowed type is registered under a normalised singular key plus one
    plural/singular variant. The file stem (letters only) is searched for any
    key as a substring first; then the parent folders are checked from the
    innermost outwards, token by token (hidden folders starting with "." are
    skipped). Returns the matching entry of `allowed_types`, or "" if none.
    """
    from pathlib import Path as _P
    import re

    def _to_singular(word):
        # Only words longer than 4 characters are de-pluralised.
        if len(word) <= 4:
            return word
        if word.endswith("es"):
            return word[:-2]
        if word.endswith("s"):
            return word[:-1]
        return word

    lookup = {}
    for garment in allowed_types:
        key = _to_singular(garment.replace("-", "").replace("_", "").lower().strip())
        lookup[key] = garment
        if key.endswith("es"):
            variant = key[:-2]
        elif key.endswith("s"):
            variant = key[:-1]
        else:
            variant = key + "s"
        lookup[variant] = garment

    path = _P(image_path)

    # 1) file name: substring match against every registered key, in insertion order
    stem_letters = re.sub(r"[^a-z]+", "", path.stem.lower())
    for key, garment in lookup.items():
        if key and key in stem_letters:
            return garment

    # 2) parent folders, innermost first: exact token match
    for folder in reversed(path.parts[:-1]):
        if folder.startswith("."):
            continue
        for piece in re.split(r"[^a-z]+", folder.lower()):
            if not piece:
                continue
            candidate = _to_singular(piece)
            if candidate in lookup:
                return lookup[candidate]
    return ""


In [None]:
# === Robust SKU+angle parsing (handles: "SS-28623_fr (1).png", "Copy of SS-12345_bc_lft_v2.png", "SS-55555_fr_cl.png") ===
import os, re
from pathlib import Path
from functools import lru_cache

# Relies on the notebook-level config defined in earlier cells: BASE_NAMES, ACCEPTABLE_SUFFIXES, VALID_EXTS, WORKING_DIR

_SKU_RE = re.compile(r"(SS-\d{3,7})", re.IGNORECASE)
_COPY_RE = re.compile(r"^(?:copy of\s+)+", re.IGNORECASE)

def _strip_copy_prefix(s: str) -> str:
    return _COPY_RE.sub("", s).strip()

def _angle_tokens_desc() -> list[str]:
    # Longest-first to prefer 'fr_rght' over 'fr'
    return sorted(list(set(BASE_NAMES)), key=len, reverse=True)

def _token_delim_search(token: str, text: str) -> re.Match | None:
    """
    Find token delimited by non-alphanumerics (underscore is allowed as a delimiter).
    We treat [A-Za-z0-9] as 'wordy'; underscores/spaces/()/- etc. are delimiters.
    """
    # Escape underscores in token for regex
    tok = re.escape(token)
    pattern = rf"(?<![A-Za-z0-9]){tok}(?![A-Za-z0-9])"
    return re.search(pattern, text, flags=re.IGNORECASE)

def extract_sku_and_angle_from_path(path_like: str) -> tuple[str | None, str | None]:
    """
    Parse (SKU, angle_base) out of a path, e.g. ('SS-12345', 'fr_lft').

    SKU: taken from the file name (after dropping "Copy of" prefixes); if the
    name has none, the path components are searched from the last one
    backwards. The SKU is returned upper-cased.

    Angle: the first token of `_angle_tokens_desc()` (longest first) found as
    a delimited token in, in order of preference:
      1. the part of the file name after the SKU,
      2. the whole file name,
      3. the parent folders, innermost first.

    Either element is None when it cannot be found.
    """
    p = Path(path_like)
    filename = _strip_copy_prefix(p.name)

    # --- SKU: file name first, then path components from the end
    sku_match = _SKU_RE.search(filename)
    if sku_match:
        sku = sku_match.group(1).upper()
    else:
        sku = None
        for component in reversed(p.parts):
            hit = _SKU_RE.search(component)
            if hit:
                sku = hit.group(1).upper()
                break

    tokens = _angle_tokens_desc()

    def _first_token_in(text):
        # Tokens are longest-first, so the first hit is the most specific one.
        for tok in tokens:
            if _token_delim_search(tok, text):
                return tok
        return None

    # --- Angle: prefer the text that follows the SKU in the file name
    angle = None
    if sku_match:
        angle = _first_token_in(filename[sku_match.end():])

    # --- Fallbacks: whole file name, then parent folders
    if angle is None:
        angle = _first_token_in(filename)
    if angle is None:
        for folder in reversed(p.parts[:-1]):
            found = _first_token_in(_strip_copy_prefix(folder))
            if found is not None:
                angle = found
            if angle:
                break

    return sku, angle

# ===================== Source finding via SKU folder anywhere =====================
@lru_cache(maxsize=1024)
def _find_sku_folder_anywhere(working_root: str, sku_name: str) -> Path | None:
    wr = Path(working_root)
    if not wr.exists():
        return None
    sku_low = sku_name.lower()
    best: tuple[int, Path] | None = None
    for dirpath, dirnames, _ in os.walk(wr):
        leaf = os.path.basename(dirpath)
        if leaf.lower() == sku_low:
            depth = len(Path(dirpath).parts)
            cand = Path(dirpath)
            if best is None or depth < best[0]:
                best = (depth, cand)
    return best[1] if best else None


def _list_valid_images(folder: Path) -> list[Path]:
    """
    Return candidate source images in `folder`, excluding:
      - any with 'generated', 'inpainted', '_nd', '_no_details', '_processed_by_detailer_'
      - any with '_sec' anywhere in the filename (case-insensitive)
    """
    deny_substrings = (
        "generated",
        "inpainted",
        "_nd",
        "_no_details",
        "_processed_by_detailer_",
        "_sec",   # ← NEW: ignore secondary variants
    )
    out = []
    for p in folder.iterdir():
        if not (p.is_file() and p.suffix in VALID_EXTS):
            continue
        name_low = p.name.lower()
        if any(s in name_low for s in deny_substrings):
            continue
        out.append(p)
    return out


def _rank_exact_angle(norm_stem: str, base: str, acceptable_suffixes: set[str]) -> int | None:
    if norm_stem == f"{base}_cut": return 1
    if norm_stem.startswith(base + "_") and norm_stem.endswith("_cut"): return 2
    if norm_stem == base: return 3
    if norm_stem.startswith(base + "_"):
        suf = norm_stem[len(base)+1:]
        if suf in acceptable_suffixes: return 4
    return None

def _is_fr_family(base: str | None) -> bool:
    if not base: return False
    return base in ("fr","fr_cl","fr_lft","fr_rght") or base.startswith("fr")

def _pick_source_in_dir(angle_base: str, directory: Path) -> Path | None:
    """
    Choose the best source image for an angle inside one directory.

    Candidates come from `_list_valid_images`; their stems are compared after
    removing "Copy of" prefixes and lower-casing. For front-family angles the
    plain "fr" image is tried first and the exact angle only if no "fr"
    variant exists; other angles are matched exactly. Within one base the
    winner has the lowest `_rank_exact_angle` rank, then the shortest file
    name, then the alphabetically first name. Returns None if nothing fits.
    """
    candidates = _list_valid_images(directory)
    if not candidates:
        return None
    allowed_suffixes = set(ACCEPTABLE_SUFFIXES)

    def _best_for(base: str) -> Path | None:
        scored = []
        for path in candidates:
            stem = _strip_copy_prefix(path.stem).lower()
            rank = _rank_exact_angle(stem, base, allowed_suffixes)
            if rank is not None:
                scored.append((rank, len(path.name), path))
        if not scored:
            return None
        return min(scored, key=lambda t: (t[0], t[1], t[2].name))[2]

    bases = ("fr", angle_base) if _is_fr_family(angle_base) else (angle_base,)
    for base in bases:
        hit = _best_for(base)
        if hit is not None:
            return hit
    return None

def find_source_via_sku(gen_path: Path | str, working_root: Path | str) -> Path | None:
    """
    Locate the source image that belongs to a generated image.

    The SKU and angle are parsed from `gen_path` (angle defaults to "fr" when
    absent). The SKU's folder is searched anywhere under `working_root`; inside
    it the "Ricardo" subfolder is tried before the SKU folder itself.
    Prints a message and returns None when the SKU, its folder, or a suitable
    image cannot be found.
    """
    gen_path = Path(gen_path)
    sku, angle_base = extract_sku_and_angle_from_path(str(gen_path))

    if not sku:
        print(f"❌ Could not extract SKU from: {gen_path.name}")
        return None

    # A missing angle silently falls back to the front view.
    angle_base = angle_base or "fr"

    sku_dir = _find_sku_folder_anywhere(str(working_root), sku)
    if not sku_dir:
        print(f"❌ SKU folder '{sku}' not found anywhere under {working_root}")
        return None

    for folder in (sku_dir / "Ricardo", sku_dir):
        if not (folder.exists() and folder.is_dir()):
            continue
        found = _pick_source_in_dir(angle_base, folder)
        if found:
            return found

    print(f"⚠️ No suitable source found in '{sku_dir}' (Ricardo or root) for angle '{angle_base}'")
    return None

def build_inpaint_suffix(details: list[dict]) -> str:
    """
    Build the file-name suffix describing the inpainted details.

    Each detail becomes a slug ("<color>-<type>", or just "<type>" for
    "sleeve text" or when no color is set): lower-case, spaces turned into
    dashes, everything except a-z, 0-9 and "-" removed. Slugs are joined with
    "_"; an empty detail list gives "none".
    """
    def _slugify(text: str) -> str:
        dashed = text.lower().replace(" ", "-")
        return re.sub(r"[^a-z0-9\-]+", "", dashed).strip("-")

    labels = []
    for item in details:
        kind = item["type"]
        color = item.get("color", "")
        label = f"{color}-{kind}" if (kind != "sleeve text" and color) else kind
        labels.append(_slugify(label))

    if not labels:
        return "none"
    return "_".join(labels)

# --- Build the required base "SS-12345-bc_lft" from the queued filename ---
def build_out_base_from_gen(gen_path: str) -> tuple[str, str, str]:
    """
    Derive the output base name from a queued generated file.

    Returns ``(sku_upper, angle_lower, out_base)`` where ``out_base`` is the
    SKU and angle joined by a hyphen, e.g. ``'SS-12345-bc_lft'``. When no angle
    can be parsed, ``"fr"`` is used.

    Raises:
        ValueError: if no SKU can be extracted from ``gen_path``.
    """
    sku, angle = extract_sku_and_angle_from_path(gen_path)
    if not sku:
        raise ValueError(f"Cannot derive SKU from: {gen_path}")

    sku_up = sku.upper()
    angle_lo = (angle or "fr").lower()
    out_base = f"{sku_up}-{angle_lo}"
    return sku_up, angle_lo, out_base

def target_already_has_inpainted(target_dir: str, sku: str, angle: str) -> bool:
    """
    Report whether ``target_dir`` already holds an inpainted image for a SKU+angle.

    A file counts when its stem starts with ``'<SKU>-<angle>_inpainted'``
    (e.g. ``'SS-12345-bc_lft_inpainted'``), compared case-insensitively, and
    its extension is one of ``VALID_EXTS``. The extension is matched regardless
    of case, so ``.PNG`` is accepted wherever ``.png`` is.

    Returns ``False`` when ``target_dir`` does not exist or is not a directory.
    """
    td = Path(target_dir)
    # is_dir() also covers the "path exists but is a file" case, where
    # iterdir() would otherwise raise.
    if not td.is_dir():
        return False

    prefix_low = f"{sku.upper()}-{angle.lower()}_inpainted".lower()
    for p in td.iterdir():
        if not p.is_file():
            continue
        # Try the suffix as-is first, then lowercased, so the check never
        # rejects anything the original case-sensitive test accepted.
        if p.suffix not in VALID_EXTS and p.suffix.lower() not in VALID_EXTS:
            continue
        if p.stem.lower().startswith(prefix_low):
            return True
    return False

In [None]:
def _inpaint_one_detail(gen_full: Image.Image,
                        src_full: Image.Image,
                        detail_prompt: str,
                        *,
                        garment_tag: str,
                        restrict_mask_full: np.ndarray | None,
                        generous_pad_px: int,
                        tiny_pad_px: int,
                        seed: int,
                        visualize: bool) -> tuple[Image.Image, bool]:
    """
    Re-draw one garment detail on the generated image, guided by the source photo.

    Steps:
      1. Detect the garment box on the source and on the (optionally masked)
         generated image and cut a square crop around each.
      2. Locate the detail on the source crop, turn that box into a spatial
         prior, and use it to locate the same detail on the generated crop.
      3. Build a side-by-side diptych ``[source crop | generated crop]`` with a
         mask that covers only the detail on the generated half.
      4. Run the module-level ``pipe`` on the diptych and paste the right half
         back into the generated image.

    Args:
        gen_full: Generated image to be corrected.
        src_full: Source garment photo used as the reference.
        detail_prompt: Text prompt naming the detail to find (e.g. "logo").
        garment_tag: Text prompt naming the garment to detect.
        restrict_mask_full: Optional full-size mask limiting where detection
            may look on the generated image; ``None`` disables the restriction.
        generous_pad_px: Padding used when cropping the generated detail for
            the inpainting pass.
        tiny_pad_px: Padding used when a detail mask has to be synthesised
            from its bounding box because detection returned no mask.
        seed: Seed for the ``torch.Generator`` handed to ``pipe``.
        visualize: When true, intermediate images are shown via ``_show_images``.

    Returns:
        ``(image, modified)``. ``modified`` is ``False`` and ``image`` is the
        untouched ``gen_full`` when the garment or the detail cannot be found.
    """
    _prepare_for_sam3()
    modified = False

    gen_view_for_sam3 = apply_binary_mask(gen_full, restrict_mask_full) if restrict_mask_full is not None else gen_full

    gar_src_bb = detect_garment_box(src_full, garment_tag)
    gar_gen_bb = detect_garment_box(gen_view_for_sam3, garment_tag, restrict_mask=restrict_mask_full)
    if gar_src_bb is None or gar_gen_bb is None:
        print("❌ garment detection failed"); return gen_full, modified

    def crop_to_square(image: Image.Image, bbox, pad_px=0):
        """Cut a square around ``bbox`` (white-filled where it leaves the image).

        Returns ``(square_image, (lx, ty, rx, by))``; the box is expressed in
        full-image coordinates and may have a negative left/top when the
        square is larger than the image.
        """
        x1,y1,x2,y2 = bbox
        w,h = x2-x1, y2-y1
        side = max(w,h) + 2*pad_px
        cx,cy = (x1+x2)//2, (y1+y2)//2
        lx=max(0,cx-side//2); ty=max(0,cy-side//2)
        rx=lx+side; by=ty+side
        W,H=image.size
        # Slide the window back inside the image when it overshoots right/bottom.
        if rx>W: lx -= (rx-W); rx=W
        if by>H: ty -= (by-H); by=H
        crop = image.crop((max(lx,0),max(ty,0),min(rx,W),min(by,H)))
        out  = Image.new("RGB",(side,side),(255,255,255))
        dx=max(0,-lx); dy=max(0,-ty)
        out.paste(crop,(dx,dy))
        return out, (lx,ty,rx,by)

    # One square crop per image is enough: it is used for detection and for
    # mapping crop coordinates back to the full image.
    src_sq, sq_src = crop_to_square(src_full, gar_src_bb)
    gen_sq, sq_gen = crop_to_square(gen_view_for_sam3, gar_gen_bb)

    det_src_bb, _, det_src_mask_crop = detect_detail(src_sq, detail_prompt, crop_box_on_full=sq_src)
    # Without a source detail there is nothing to build a prior from, so stop
    # before handing a None box to the prior builder.
    if det_src_bb is None:
        print(f"❌ detail not found: {detail_prompt}"); return gen_full, modified
    spatial_prior = make_spatial_prior_from_box(det_src_bb, src_sq.size)

    det_gen_bb, _, det_gen_mask_crop = detect_detail_topk7(
        gen_sq,
        detail_prompt,
        source_prior=spatial_prior,
        restrict_mask=(restrict_mask_full, "FULL"),  # pass FULL mask
        crop_box_on_full=sq_gen,                     # the (x1,y1,x2,y2) used to make gen_sq
        viz=False, debug=False
    )
    if det_gen_bb is None:
        print(f"❌ detail not found: {detail_prompt}"); return gen_full, modified

    # back to full coords
    lx_s, ty_s, _, _ = sq_src
    lx_g, ty_g, _, _ = sq_gen
    src_det_bb = [det_src_bb[0]+lx_s, det_src_bb[1]+ty_s, det_src_bb[2]+lx_s, det_src_bb[3]+ty_s]
    gen_det_bb = [det_gen_bb[0]+lx_g, det_gen_bb[1]+ty_g, det_gen_bb[2]+lx_g, det_gen_bb[3]+ty_g]

    src_mask_full = _mask_crop_to_full(det_src_mask_crop, sq_src, src_full.size) if det_src_mask_crop is not None else None
    gen_mask_full = _mask_crop_to_full(det_gen_mask_crop, sq_gen, gen_full.size) if det_gen_mask_crop is not None else None
    # Fall back to a padded box mask when detection produced no mask. The
    # padding comes from the caller's tiny_pad_px argument.
    if src_mask_full is None:
        src_mask_full = bbox_to_mask(src_det_bb, src_full.size, tiny_pad_px)
    if gen_mask_full is None:
        gen_mask_full = bbox_to_mask(gen_det_bb, gen_full.size, tiny_pad_px)
    else:
        gen_mask_full = enlarge_mask(gen_mask_full, scale=1.05)

    if visualize:
        _show_images([
            ("detail on source", _draw_bbox(src_full, src_det_bb)),
            ("detail on generated (masked)", _draw_bbox(gen_view_for_sam3, gen_det_bb))
        ], cols=2, figsize=(12,8))

    # crops for IA (1024 px; the source crop uses a fixed 20 px pad)
    src_crop, src_mask, _   = crop_detail(src_full, src_mask_full, src_det_bb, 1024, 20)
    gen_crop, gen_mask, box = crop_detail(gen_full, gen_mask_full, gen_det_bb, 1024, generous_pad_px)

    # diptych: left = unmasked source reference, right = generated crop
    src_arr = np.array(src_crop)
    masked_src = src_arr  # keep source unmasked for IA

    gen_arr = np.array(gen_crop)
    gen_msk3 = np.stack([gen_mask]*3, -1)  # assumes gen_mask is a 2-D array — TODO confirm
    zeros = np.zeros_like(masked_src)

    diptych = np.concatenate([masked_src, gen_arr], axis=1).astype(np.uint8)
    dip_mask = np.concatenate([zeros, gen_msk3], axis=1).astype(np.uint8)
    dip_mask[dip_mask>0]=255

    if visualize:
        _show_images([
            ("diptych", Image.fromarray(diptych)),
            ("diptych mask", Image.fromarray(dip_mask).convert("RGB"))
        ], cols=2, figsize=(12,8))

    _prepare_for_insert_anything()

    # redux() output is splatted into pipe() as keyword arguments.
    redux_kwargs = redux(Image.fromarray(masked_src))
    gen_obj = torch.Generator(IA_DEVICE).manual_seed(seed)
    ia_out = pipe(
        image=Image.fromarray(diptych),
        mask_image=Image.fromarray(dip_mask),
        height=1024,
        width=2048,
        max_sequence_length=512,
        num_inference_steps=60,
        guidance_scale=30,
        generator=gen_obj,
        **redux_kwargs
    ).images[0]

    # Only the right half of the diptych (the generated side) is pasted back.
    right_crop = ia_out.crop((1024,0,2048,1024))
    gen_full = paste_crop_back(gen_full, right_crop, box, gen_mask)
    modified = True

    if visualize:
        _show_images([
            ("IA result (2048×1024)", ia_out),
            ("after this detail", gen_full)
        ], cols=2, figsize=(14,8))

    return gen_full, modified

def inpaint_with_details_list(generated_path: str,
                              source_path: str,
                              details: list[dict],
                              garment_type: str | None,
                              visualize: bool = True) -> tuple[Image.Image, bool]:
    """
    Inpaint every requested detail of a generated image from its source garment.

    Args:
        generated_path: Path of the generated image to correct.
        source_path: Path of the source garment photo used as reference.
        details: Detail dicts; each must carry a ``"type"`` key, which is used
            as the detection prompt.
        garment_type: Garment prompt. When empty/``None`` it is derived from
            ``source_path`` and finally defaults to ``"t-shirt"``.
        visualize: Forwarded to ``_inpaint_one_detail``.

    Returns:
        ``(image, any_modified)`` — the resulting image and whether at least
        one detail was actually applied.

    GPU cache and garbage collection are run even when a detail raises, so a
    failing item inside a batch does not leave memory pinned for the next one.
    """
    gen_full = open_upright(generated_path)
    src_full = open_source_with_black_bg(source_path)

    restrict_mask_full = load_binary_mask_for_generated(generated_path, source_path, gen_full)

    if restrict_mask_full is None:
        print("⚠️  No garment mask found — proceeding without restriction.")
    else:
        print("✅ Garment mask loaded & aligned for", os.path.basename(generated_path))

    if garment_type is None or not garment_type.strip():
        garment_type = extract_garment_type_from_path(source_path)
    if not garment_type:
        garment_type = "t-shirt"  # conservative default prompt

    # twinset (optional, keep simple): run once for the top and once for the bottom
    garment_tags = [garment_type.lower()]
    if garment_type.lower() in TWINSET_TYPES:
        garment_tags = [TOP_GARMENTS[0], BOTTOM_GARMENTS[0]]

    out_img = gen_full.copy()
    any_modified = False
    try:
        for gtag in garment_tags:
            for d in details:
                d_type = d["type"]
                prompt_str = f"{d_type}".strip()
                print(f"🔄 Inpainting detail: {prompt_str}  (garment={gtag})")
                # Each pass works on the output of the previous one.
                out_img, did_modify = _inpaint_one_detail(
                    out_img, src_full, prompt_str,
                    garment_tag=gtag,
                    restrict_mask_full=restrict_mask_full,
                    generous_pad_px=INPAINT_GENEROUS_PAD,
                    tiny_pad_px=INPAINT_TINY_PAD,
                    seed=INPAINT_SEED,
                    visualize=visualize
                )
                any_modified = any_modified or did_modify
    finally:
        # Collect first so tensors that just became unreachable are released
        # before the CUDA caching allocator is asked to give memory back.
        gc.collect()
        torch.cuda.empty_cache()

    return out_img, any_modified
    






In [None]:
import shutil
def process_detailer_queue():
    """
    Inpaint every image found under ``DETAILER_QUEUE_FOLDER`` and save it to ``TARGET_DIR``.

    For each queued image (recursively, extensions from ``VALID_EXTS``):
      1. read the detail list from the image metadata and keep only details
         whose ``"type"`` is in ``ALLOWED_DETAIL_TYPES``;
      2. locate the source garment through the SKU parsed from the filename;
      3. optionally skip when an inpainted file for that SKU+angle already
         exists (``SKIP_IF_ALREADY_INPAINTED``);
      4. run the inpainting;
      5. save the result as ``<SKU>-<angle>_inpainted_<suffix>.png``.

    Any exception on one item is reported and counted as failed; the loop
    continues with the next item. A summary is printed at the end.
    """
    queue_root = Path(DETAILER_QUEUE_FOLDER)
    if not queue_root.exists():
        print(f"❌ Queue folder does not exist: {queue_root}")
        return

    gen_files = [p for p in queue_root.rglob("*") if p.is_file() and p.suffix in VALID_EXTS]
    if not gen_files:
        print(f"ℹ️ No images found in {queue_root}")
        return

    target_root = Path(TARGET_DIR)
    processed = skipped = failed = 0

    for gen_path in sorted(gen_files, key=lambda p: (str(p.parent), p.name)):
        try:
            print("\n" + "_"*80)
            print(f"🎯 Queue item: {gen_path}")

            # 1) Read details from metadata
            meta = read_details_from_metadata(str(gen_path))
            if not meta or not meta.get("details"):
                print("⏭️  No details found in metadata — skipping")
                skipped += 1
                continue

            # Malformed entries (non-dicts, or dicts without "type") are
            # dropped here instead of failing the whole queue item.
            details = [
                d for d in meta["details"]
                if isinstance(d, dict) and d.get("type") in ALLOWED_DETAIL_TYPES
            ]
            if not details:
                print("⏭️  Details list empty after normalization — skipping")
                skipped += 1
                continue

            # 2) Find source garment near this item
            src_p = find_source_via_sku(gen_path, Path(WORKING_DIR))
            if not src_p:
                print("⏭️  Source garment not found — skipping")
                skipped += 1
                continue

            sku_up, angle_lo, out_base = build_out_base_from_gen(str(gen_path))
            out_ext = ".png"

            # 3) Skip guard: any prior inpainted for this SKU+angle?
            if SKIP_IF_ALREADY_INPAINTED and target_already_has_inpainted(TARGET_DIR, sku_up, angle_lo):
                print(f"⏭️  Already have inpainted for {out_base} in TARGET_DIR — skipping")
                skipped += 1
                continue

            # 4) Inpaint
            garment_type = extract_garment_type_from_path(str(src_p))
            out_img, modified = inpaint_with_details_list(
                str(gen_path),
                str(src_p),
                details=details,
                garment_type=garment_type,
                visualize=VISUALIZE
            )
            if not modified:
                print("⏭️  No details applied — skipping save")
                skipped += 1
                continue

            # 5) Build output name, e.g. SS-12345-bc_lft_inpainted_red-logo.png
            suffix = build_inpaint_suffix(details)
            dst_ia = target_root / f"{out_base}_inpainted_{suffix}{out_ext}"

            # 6) Save inpainted result. The skip guard tolerates a missing
            #    TARGET_DIR, so make sure it exists before writing into it.
            target_root.mkdir(parents=True, exist_ok=True)
            out_img.save(str(dst_ia))
            print(f"✅ Saved inpainted → {dst_ia.name}")
            processed += 1

        except Exception as e:
            # Best-effort batch: report and move on to the next queue item.
            print(f"❌ Failed on {gen_path.name}: {e}")
            failed += 1

    print("\n==== SUMMARY ====")
    print(f"Processed: {processed}  |  Skipped: {skipped}  |  Failed: {failed}")





In [None]:

# --- Try-on → Detailer bridge ---
from PIL import Image

def _normalize_detail_payload(details_payload):
    """
    Turn a metadata-style payload into a clean list of detail dicts.

    ``details_payload`` is expected to be a dict with a ``"details"`` list.
    When the payload is empty, is not a dict, or has no details, one entry per
    type in ``DETAILER_ONLY_TYPES`` is returned (``["logo"]`` when that is
    empty). Otherwise every dict entry is reduced to ``{"type": ...}`` plus
    ``"color"`` when present; a missing/empty type becomes ``"logo"`` and
    non-dict entries are dropped. If nothing survives, ``[{"type": "logo"}]``
    is returned.
    """
    is_usable = details_payload and isinstance(details_payload, dict)
    raw_details = details_payload.get("details") if is_usable else None
    if not raw_details:
        return [{"type": t} for t in (DETAILER_ONLY_TYPES or ["logo"])]

    def _clean(item):
        entry = {"type": item.get("type") or "logo"}
        if item.get("color"):
            entry["color"] = item["color"]
        return entry

    cleaned = [_clean(item) for item in raw_details if isinstance(item, dict)]
    if cleaned:
        return cleaned
    return [{"type": "logo"}]

def run_detailer_logo_cleanup(gen_path: str, source_path: str, angle_code: str, *, details_payload=None):
    """
    Run the detailer on a freshly generated try-on image.

    Returns ``(image, modified)``. When ``RUN_DETAILER_AFTER_TRYON`` is off the
    image at ``gen_path`` is simply reloaded as RGB with ``modified=False``.
    If the detailer raises, the error is printed and the unmodified image is
    reloaded instead; if even that fails, ``(None, False)`` is returned.

    ``angle_code`` is accepted but not used by this function.
    """
    def _reload_unmodified():
        return Image.open(gen_path).convert("RGB")

    if not RUN_DETAILER_AFTER_TRYON:
        return _reload_unmodified(), False

    try:
        out_img, modified = inpaint_with_details_list(
            generated_path=gen_path,
            source_path=source_path,
            details=_normalize_detail_payload(details_payload),
            garment_type=extract_garment_type_from_path(source_path),
            visualize=DETAILER_VISUALIZE,
        )
        if not modified:
            print(f"      ℹ️ Detailer found no logo to inpaint on {gen_path}")
        else:
            print(f"      🧼 Detailer removed logo(s) on {gen_path}")
        return out_img, modified
    except Exception as e:
        print(f"      ⚠️ Detailer failed on {gen_path}: {e}")

    # Detailer failed: hand back the original image if it can still be read.
    try:
        return _reload_unmodified(), False
    except Exception:
        return None, False



# BATCH HELPERS

In [None]:

# --- Batch processor  ---


def process_one_garment_folder(folder_path: str, allowed_angles=None):
    """
    Run the try-on pipeline for the selected angles of one SKU garment folder.

    Phases:
      1. Scan ``folder_path`` and, for every target angle, pick the single
         best source garment file (lowest ``source_priority`` value).
      2. For each (source file, target angle) pair: skip when the output
         already exists, otherwise generate the main try-on and — when
         ``SECONDARY_GARMENT`` is set — a second pass on top of it.
      3. Each stage saves to ``/tmp``, optionally runs the detailer, and
         uploads the result via ``upload_file_and_append_to_sheet``.

    Args:
        folder_path: Path of the SKU folder holding the garment images.
        allowed_angles: Optional iterable of angle codes to restrict both the
            source files considered and the outputs produced. ``None``/empty
            means no restriction.

    Returns:
        None. Progress, skips and errors are reported with ``print``; an
        exception in one worklist item is caught and does not stop the rest.
    """
    # Normalized target angles, and the filename prefixes accepted as sources.
    allowed_outputs_set = {_norm_angle(a) for a in (allowed_angles or [])}
    allowed_outputs = sorted(allowed_outputs_set)
    allowed_sources = expand_as_list(allowed_angles) if allowed_angles else None

    # The secondary garment is assumed to be the opposite half of the main one.
    main_category = _garment_category_from_path(folder_path)
    main_class = _classify_garment_category(main_category)
    secondary_type = "bottom" if main_class == "top" else ("top" if main_class == "bottom" else "secondary")

    base_subcat_dir = resolve_base_mask_dir(folder_path)
    if not base_subcat_dir:
        # NOTE(review): only a warning — execution continues with a falsy
        # base_subcat_dir; confirm the lookup helpers below tolerate that.
        print(f"⚠️ Cannot resolve base/mask dir for SKU: {folder_path}")
    files_sorted = sorted(os.listdir(folder_path))

    main_texture_card = load_texture_reference(folder_path, secondary=False, heading="Main texture reference")
    # NOTE(review): the secondary card also uses heading "Main texture reference" —
    # possibly a copy-paste leftover; confirm the intended heading.
    secondary_texture_card = load_texture_reference(folder_path, secondary=True, heading="Main texture reference") if SECONDARY_GARMENT else None

    main_prompt_text = build_main_prompt(include_texture=main_texture_card is not None)
    secondary_prompt_text = build_secondary_prompt(sec_type=secondary_type, include_texture=secondary_texture_card is not None)

    def source_priority(target_angle: str, filename: str) -> int:
        """Rank how well ``filename`` suits ``target_angle``; lower is better.

        For target ``fr_cl``: an ``fr_cl*`` file is 0, any other ``fr_*`` file
        that is not ``fr_rght``/``fr_lft`` is 1, and ``fr_rght``/``fr_lft`` is 2.
        For every target: a file whose stem starts with the target is 0,
        anything else is 3.
        """
        stem = os.path.splitext(filename.lower())[0]
        if target_angle == "fr_cl":
            if stem.startswith("fr_cl"):
                return 0
            if stem.startswith("fr_") and not (stem.startswith("fr_rght") or stem.startswith("fr_lft")):
                return 1
            if stem.startswith("fr_rght") or stem.startswith("fr_lft"):
                return 2
        if stem.startswith(target_angle):
            return 0
        return 3

    # target angle -> best candidate found so far (priority, file, base, mask)
    best_for_target = {}

    for file in files_sorted:
        low = file.lower()

        if "_sec" in low:
          # Ignore secondary garments in primary detection
          continue

        # Filename filters: allowed prefix, skip tokens, mandatory "cut", extension.
        if allowed_sources and not any(low.startswith(src) for src in allowed_sources):
            continue
        if SKIP_FILENAME_TOKENS and any(tok in low for tok in SKIP_FILENAME_TOKENS):
            continue
        if REQUIRE_CUT_IN_FILENAME and ("cut" not in low):
            continue
        if not _valid_ext(file):
            continue

        # With a restriction: every allowed prefix this file starts with.
        # Without one: the file's own normalized stem.
        matching_sources = [src for src in (allowed_sources or []) if low.startswith(src)] if allowed_sources else [_norm_angle(os.path.splitext(file)[0])]
        if allowed_sources and not matching_sources:
            continue

        # A target is a candidate when the source angle belongs to the target's
        # expanded family (expand_allowed_angles) or equals the target itself.
        target_candidates = set()
        for src in matching_sources:
            norm_src = _norm_angle(src)
            for target in (allowed_outputs or [norm_src]):
                fam = {_norm_angle(x) for x in expand_allowed_angles([target])}
                fam.add(_norm_angle(target))
                if norm_src in fam:
                    target_candidates.add(_norm_angle(target))

        if allowed_outputs and not target_candidates:
            continue

        for target_angle in sorted(target_candidates):
            base_img_path = _find_image_with_stem_and_suffix(base_subcat_dir, target_angle)
            if not base_img_path:
                print(f"⚠️ Missing BASE for target '{target_angle}' → skipping that target for {file}")
                continue

            # A missing local mask is not fatal; it only matters later as a fallback.
            mask_path = find_mask_path(base_subcat_dir, target_angle)
            if not mask_path:
                print(f"ℹ️ No local mask for target '{target_angle}', will request SAM3 segmentation.")

            # Keep the existing candidate unless this file ranks strictly better;
            # on ties the earlier file (sorted order) wins.
            priority = source_priority(target_angle, low)
            current = best_for_target.get(target_angle)
            if current and priority >= current["priority"]:
                continue

            best_for_target[target_angle] = {
                "priority": priority,
                "file": file,
                "base_img_path": base_img_path,
                "mask_path": mask_path,
            }

    # One work item per target angle, processed in angle order.
    worklist = [
        (data["file"], target_angle, data["base_img_path"], data["mask_path"])
        for target_angle, data in best_for_target.items()
    ]
    worklist.sort(key=lambda x: x[1])

    def find_existing_output_local(target_colab_path: str):
        """Return the first existing local file sharing the path's stem, trying
        png/jpg/jpeg in lower and upper case; ``None`` when none exists."""
        stem, _ = os.path.splitext(target_colab_path)
        for ext in (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG"):
            cand = f"{stem}{ext}"
            if os.path.exists(cand):
                return cand
        return None

    sku_name = os.path.basename(folder_path)
    if allowed_sources:
        print(f"▶️  {sku_name}: {len(worklist)} image(s) to generate (outputs={sorted(list(allowed_outputs_set))}, sources={sorted(list(set(allowed_sources)))})")
    else:
        print(f"▶️  {sku_name}: {len(worklist)} image(s) to generate")

    if not worklist:
        return

    for idx, (file, target_angle, base_img_path, mask_path) in enumerate(worklist, start=1):
        print(f"   {idx:>3}/{len(worklist):<3}  {file}  | USING STRICT base/mask='{target_angle}'")
        garment_path = os.path.join(folder_path, file)

        # Output names: main is "<...>_onlymain" when a secondary pass follows,
        # and the combined result is "<...>_both".
        sku_name = os.path.basename(folder_path)
        angle_code = _norm_angle(target_angle)
        main_suffix = "_onlymain" if SECONDARY_GARMENT else ""
        main_out_name = build_output_filename(sku_name, angle_code, ext=".png", suffix=main_suffix)
        final_out_name = build_output_filename(sku_name, angle_code, ext=".png", suffix="_both") if SECONDARY_GARMENT else None

        main_colab_path = os.path.join(OUTPUT_DIR, main_out_name)
        final_colab_path = os.path.join(OUTPUT_DIR, final_out_name) if final_out_name else None

        main_exists = drive_file_exists_any_ext_at_colab_path(main_colab_path)
        final_exists = final_colab_path and drive_file_exists_any_ext_at_colab_path(final_colab_path)

        reuse_main_img = None
        sec_garment_override = None

        if not SECONDARY_GARMENT:
            if main_exists:
                print(f"      ⏭️  Skip: {main_out_name} already exists in {OUTPUT_DIR} (main target)")
                continue
        else:
            if final_exists:
                print(f"      ⏭️  Skip: {final_out_name} already exists in {OUTPUT_DIR} (secondary target)")
                continue

            # Main output exists but the combined one does not: reuse the main
            # image from disk and run only the secondary stage.
            if main_exists:
                sec_garment_override = find_secondary_garment_path(folder_path, file)
                if not sec_garment_override:
                    print(f"      ⏭️  Skip: {main_out_name} exists but no secondary garment found in {folder_path}")
                    continue

                existing_main_path = find_existing_output_local(main_colab_path)
                if not existing_main_path:
                    print(f"      ⏭️  Skip: {main_out_name} exists in Drive but not on disk → cannot reuse without regenerating main")
                    continue

                try:
                    reuse_main_img = Image.open(existing_main_path).convert("RGB")
                    print(f"      ▶️  Reusing existing main output '{os.path.basename(existing_main_path)}' for secondary stage")
                except Exception as reuse_err:
                    print(f"      ❌ Unable to reuse existing main output '{existing_main_path}': {reuse_err}")
                    continue

        try:
            # Directory searched for the secondary mask; replaced below when the
            # main stage falls back to a mask file on disk.
            mask_dir = base_subcat_dir

            def perform_tryon_stage(stage_base_full, stage_mask_img, stage_garment_img, suffix, stage_label, mask_label, prompt_text=None, texture_card=None, texture_label=None, detailer_source_path=None):
                """Run one try-on stage ("main" or "secondary") and upload it.

                Crops the base image around the mask, calls the generator,
                pastes the result back into the full image, saves it under
                ``/tmp``, optionally runs the detailer, then uploads and logs
                it. Returns the final full-size image.
                """
                show_gallery(
                    [stage_garment_img, stage_base_full, stage_mask_img.convert("RGB")],
                    [f"Source garment (white BG) [{stage_label}]", f"Base photo [{angle_code}] [{stage_label}]", f"Mask [{mask_label}]"]
                )

                # Every stage other than "secondary" gets 100 px of extra upper padding.
                upper_padding = UPPER_PADDING if (stage_label == "secondary") else (UPPER_PADDING + 100)

                bbox = find_aspect_bbox(
                    stage_mask_img,
                    aspect=TARGET_ASPECT,
                    padding=CROP_PADDING,
                    upper_padding=upper_padding,
                    horiz_padding=HORIZ_PADDING,
                    min_margin=CROP_MIN_MARGIN,
                    allow_padding=(stage_label != "secondary"),
                )
                base_crop = crop_with_padding(stage_base_full, bbox, fill=WHITE_RGB)
                mask_crop = crop_with_padding(stage_mask_img.convert("L"), bbox, fill=0)
                # Third gallery tile: the texture card, or a placeholder card.
                if texture_card is not None:
                    crop_thumb = texture_card
                    crop_title = texture_label or "Texture reference"
                else:
                    crop_thumb = build_no_texture_card(stage_garment_img.size)
                    crop_title = "No texture file found"
                show_gallery(
                    [base_crop, mask_crop.convert("RGB"), crop_thumb],
                    [f"Cropped base (1:1) [{stage_label}]", "Cropped mask", crop_title],
                )

                # The texture card is passed to the generator as a labelled extra image.
                extra_images = None
                if texture_card is not None:
                    label = texture_label or "Texture reference"
                    extra_images = [f"{label}:", texture_card.convert("RGB")]


                print("🍌 started...")
                tryon_gen = run_nanobanana_tryon(
                    model_image=base_crop,
                    garment_image=stage_garment_img,
                    aspect_ratio=GEN_ASPECT_RATIO,
                    image_size=GEN_IMAGE_SIZE,
                    prompt=prompt_text or TRYON_PROMPT,
                    extra_images=extra_images,
                )

                # Bring the generator output to the crop size so it can be pasted back.
                if tryon_gen.size != base_crop.size:
                    if tryon_gen.width != tryon_gen.height:
                        print(f"      ⚠️ Generator returned non-square image {tryon_gen.size}; resizing to {base_crop.size}.")
                    tryon_sq = tryon_gen.resize(base_crop.size, Image.Resampling.LANCZOS)
                else:
                    tryon_sq = tryon_gen

                final_img_local, alpha_dbg = paste_crop_back_debug(
                    full_img   = stage_base_full.copy(),
                    edited_crop= tryon_sq,
                    crop_box   = bbox,
                    crop_mask  = np.array(mask_crop),
                    solid_expand_px = max(5, MASK_EXPAND_PX // 4),
                    halo_px         = MASK_EXPAND_PX,
                    feather_px      = MASK_FEATHER_PX,
                    edge_feather_px = 15,   # tweak to taste
                )

                show_gallery(
                    [tryon_sq, alpha_dbg, final_img_local],
                    ["Try-on crop", "Alpha (debug)", "Final paste-back"],
                )

                out_name_local = build_output_filename(sku_name, angle_code, ext=".png", suffix=suffix)
                tmp_path_local = os.path.join("/tmp", out_name_local)

                # The image is written to /tmp first (with the detail payload
                # attached) because the detailer works from a file path.
                detailer_source = detailer_source_path or garment_path
                detailer_payload = {"details": [{"type": t} for t in (DETAILER_ONLY_TYPES or ["logo"])]}
                save_png_with_metadata(final_img_local, tmp_path_local, details_payload=detailer_payload)

                if RUN_DETAILER_AFTER_TRYON:
                    cleaned_img, cleaned = run_detailer_logo_cleanup(
                        tmp_path_local,
                        detailer_source,
                        angle_code,
                        details_payload=detailer_payload,
                    )
                    # Overwrite the temp file only when the detailer changed something.
                    if cleaned and cleaned_img is not None:
                        final_img_local = cleaned_img
                        save_png_with_metadata(final_img_local, tmp_path_local, details_payload=detailer_payload)

                target_path_for_drive = os.path.join(OUTPUT_DIR, out_name_local)

                info = upload_file_and_append_to_sheet(
                    local_path       = tmp_path_local,
                    target_colab_path= target_path_for_drive,
                    sku_name         = sku_name,
                    angle            = angle_code,
                    spreadsheet_id   = SPREADSHEET_ID,
                    worksheet_name   = GEN_LOG_SHEET,
                )
                print(f"      ✅ Uploaded [{stage_label}] → {info['file_url']}")
                return final_img_local

            # A reused main image means the main stage is skipped entirely.
            main_result = reuse_main_img
            skipped_main = reuse_main_img is not None

            if not skipped_main:
                garment_img = flatten_alpha_to_white(open_upright(garment_path))
                base_full   = Image.open(base_img_path).convert("RGB")
                mask_label  = None

                # Generated mask first; the mask file on disk is only a fallback.
                # NOTE(review): the helper is named generate_mask_with_gemini while
                # the messages say "SAM3" — confirm which backend is in use.
                try:
                    mask_full, mask_label = generate_mask_with_gemini(base_img_path, folder_path)
                except Exception as mask_err:
                    if mask_path:
                        print(f"      ⚠️ SAM3 mask failed for '{target_angle}', using disk mask instead: {mask_err}")
                        mask_full = Image.open(mask_path).convert("L")
                        mask_label = os.path.basename(mask_path)
                        mask_dir = os.path.dirname(mask_path)
                    else:
                        print(f"      ❌ SAM3 mask failed and no local mask for '{target_angle}': {mask_err}")
                        continue

                main_result = perform_tryon_stage(
                    stage_base_full=base_full,
                    stage_mask_img=mask_full,
                    stage_garment_img=garment_img,
                    suffix=main_suffix,
                    prompt_text=main_prompt_text,
                    stage_label="main",
                    mask_label=mask_label or "gemini-mask",
                    texture_card=main_texture_card,
                    texture_label="Main texture reference",
                    detailer_source_path=garment_path,
                )

            if SECONDARY_GARMENT:
                sec_mask_path = find_mask_path(mask_dir, f"{target_angle}_sec")

                sec_garment_path = sec_garment_override or find_secondary_garment_path(folder_path, file)
                if not sec_garment_path:
                    print(f"      ⚠️ Secondary garment missing for '{target_angle}' → kept main-only output.")
                    continue

                if main_result is None:
                    print(f"      ❌ Cannot run secondary stage for '{target_angle}' without a main result.")
                    continue

                # Same strategy as the main stage: generated mask, then disk mask.
                sec_mask_full = None
                sec_mask_label = None
                try:
                    sec_mask_full, sec_mask_label = generate_mask_with_gemini(
                        base_img_path,
                        folder_path,
                        mask_variant=f"{target_angle}_sec",
                    )
                except Exception as sec_mask_err:
                    if sec_mask_path:
                        print(f"      ⚠️ Secondary SAM3 mask failed for '{target_angle}', using disk mask instead: {sec_mask_err}")
                        sec_mask_full = Image.open(sec_mask_path).convert("L")
                        sec_mask_label = os.path.basename(sec_mask_path)
                    else:
                        print(f"      ❌ Secondary SAM3 mask failed and no local mask for '{target_angle}': {sec_mask_err}")
                        continue

                # The generator may also return None without raising.
                if sec_mask_full is None:
                    if sec_mask_path:
                        sec_mask_full = Image.open(sec_mask_path).convert("L")
                        sec_mask_label = os.path.basename(sec_mask_path)
                    else:
                        print(f"      ⚠️ Secondary mask missing for '{target_angle}' → kept main-only output.")
                        continue

                sec_garment_img = flatten_alpha_to_white(open_upright(sec_garment_path))

                # The secondary garment is applied on top of the main result.
                # NOTE(review): texture_label is "Main texture reference" here as
                # well — confirm that is intended for the secondary stage.
                perform_tryon_stage(
                    stage_base_full=main_result,
                    prompt_text=secondary_prompt_text,
                    stage_mask_img=sec_mask_full,
                    stage_garment_img=sec_garment_img,
                    suffix="_both",
                    stage_label="secondary",
                    mask_label=sec_mask_label or (os.path.basename(sec_mask_path) if sec_mask_path else "gemini-mask"),
                    texture_card=secondary_texture_card,
                    texture_label="Main texture reference",
                    detailer_source_path=sec_garment_path,
                )

        except Exception as e:
            # Best-effort batch: report and continue with the next work item.
            print(f"      ❌ Error: {e}")




In [None]:
# --- Sheet-driven angle selection + runners ---

def build_sku_folder_index(garments_root: str):
    """
    Index the SKU folders found under ``garments_root``.

    Returns a dict mapping the normalized folder name (``_norm_sku`` of the
    basename) to the folder path. When two folders normalize to the same key,
    the one yielded later by ``iter_sku_folders`` wins.
    """
    index = {}
    for folder in iter_sku_folders(garments_root):
        key = _norm_sku(os.path.basename(folder))
        index[key] = folder
    return index

def run_list():
    """
    Process every SKU folder resolved from ``SKU_CSV`` under ``GARMENTS_ROOT``.

    Each folder is handed to ``process_one_garment_folder`` restricted to
    ``ALLOWED_BASES``. A failure in one SKU is printed and does not stop the
    others. Identifiers that matched no folder are listed at the end.
    """
    targets, unmatched = resolve_targets(SKU_CSV, GARMENTS_ROOT)

    if not targets:
        print("⚠️ No matching SKU folders found.")
        if unmatched:
            print("Unmatched:", ", ".join(unmatched))
        return

    total = len(targets)
    print(f"➡️  Will process {total} SKU(s).")

    for position, folder in enumerate(targets, start=1):
        name = os.path.basename(folder)
        print(f"\nSKU {position}/{total} ▶️  {name}")
        try:
            process_one_garment_folder(folder, allowed_angles=ALLOWED_BASES)
            print(f"✅ Finished: {name}")
        except Exception as e:
            print(f"❌ Error in {name}: {e}")

    if unmatched:
        print("\nℹ️  Unmatched identifiers:")
        for identifier in unmatched:
            print("   -", identifier)

    print("\n🏁 List run complete.")


# DISPATCH

In [None]:
# Dispatch: start the batch run over every SKU resolved by run_list().
run_list()


# UNASSIGN

In [None]:
# Release the Colab runtime once everything above has finished.
# NOTE(review): this presumably ends the session — confirm all outputs are
# uploaded before this cell runs.
from google.colab import runtime
runtime.unassign()