# SETUP (restart the runtime after running this section)

In [None]:
!pip install piexif

Collecting piexif
  Downloading piexif-1.1.3-py2.py3-none-any.whl.metadata (3.7 kB)
Downloading piexif-1.1.3-py2.py3-none-any.whl (20 kB)
Installing collected packages: piexif
Successfully installed piexif-1.1.3


In [None]:
# -*- coding: utf-8 -*-
# --- Setup: Google auth + Drive + repos + deps ---

import sys, os, subprocess, textwrap, importlib

# Colab auth + Drive
from google.colab import auth, drive
auth.authenticate_user()
drive.mount('/content/drive')

# Repos (only what we actually need)
!git clone https://github.com/nftblackmagic/catvton-flux || true
!git clone https://github.com/XPixelGroup/HYPIR.git || true

# Deps (pin modern, stable)
!pip -q install --upgrade pip
!pip -q install "diffusers>=0.31.0" "accelerate>=0.33.0" peft==0.17.0 \
                pillow torchvision safetensors einops numpy==1.26.4 \
                gspread google-auth google-auth-oauthlib google-api-python-client \
                pytz
!pip -q install --force-reinstall "transformers==4.46.2"

# Optional: HYPIR weight (kept small & lazy-used)
!wget -q https://huggingface.co/lxq007/HYPIR/resolve/main/HYPIR_sd2.pth -O /content/HYPIR_sd2.pth

# Colab cell — install OpenAI SDK with Responses API
!pip -q install --upgrade "openai>=1.52.0"


print("✅ Setup done.")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Cloning into 'catvton-flux'...
remote: Enumerating objects: 331, done.[K
remote: Counting objects: 100% (82/82), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 331 (delta 61), reused 57 (delta 57), pack-reused 249 (from 1)[K
Receiving objects: 100% (331/331), 17.24 MiB | 12.46 MiB/s, done.
Resolving deltas: 100% (144/144), done.
Cloning into 'HYPIR'...
remote: Enumerating objects: 175, done.[K
remote: Counting objects: 100% (84/84), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 175 (delta 57), reused 57 (delta 43), pack-reused 91 (from 1)[K
Receiving objects: 100% (175/175), 7.26 MiB | 25.54 MiB/s, done.
Resolving deltas: 100% (63/63), done.
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m91.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver

# Select angles

In [None]:
fr_lft = True #@param {type:"boolean"}
fr_rght = True #@param {type:"boolean"}
fr_cl = True #@param {type:"boolean"}
bc_lft = True #@param {type:"boolean"}
bc_rght = True #@param {type:"boolean"}
lft = True #@param {type:"boolean"}
rght = True #@param {type:"boolean"}
bc_ = True #@param {type:"boolean"}
fr_ = True #@param {type:"boolean"}
fr_cl_btm = False #@param {type:"boolean"}
fr_cl_tp = False #@param {type:"boolean"}


# Every angle toggle above, in a fixed order; ALLOWED_BASES keeps the enabled
# ones in that same order.
# The flags are looked up with globals(), not locals(): before Python 3.12 a
# list comprehension runs in its own scope, so locals() there only holds the
# loop variable and `locals()[n]` raises KeyError.
names = ["fr_lft","fr_rght","fr_cl","bc_lft","bc_rght","lft","rght","bc_","fr_","fr_cl_btm","fr_cl_tp"]
ALLOWED_BASES = [n for n in names if globals()[n]]

# CONFIG

In [None]:
# --- Unified CONFIG ---

# Selection mode: "sheet" | "sku_list" | "dir"
RUN_MODE = "sku_list"     #@param ["sheet", "sku_list", "dir"]

# For RUN_MODE == "sku_list"
SKU_CSV = "28748, 28920, 28747, 29018, 29095, 29094"  #@param {type:"string"}

# For RUN_MODE == "dir": an absolute or relative (under GARMENTS_ROOT) directory
TARGET_DIR = "/content/drive/MyDrive/Dazzl/SikSilk/AlexGens/SikSilk"  #@param {type:"string"}

# Paths
BASE_PHOTOS_ROOT  = "/content/drive/MyDrive/Dazzl/SikSilk/SKSLK_MODELS/"
GARMENTS_ROOT     = "/content/drive/MyDrive/Dazzl/SikSilk/AlexGens/SikSilk/"


# Filename/dir policy
VALID_EXTENSIONS  = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")
IGNORE_DIRS       = {"old", "__MACOSX", ".ds_store", "Ricardo", "toweling"}
SKIP_FILENAME_TOKENS_CSV   = "mask, generated, freelance, _sec, _backup"   # substrings to skip
SKIP_BASENAME_SUFFIXES_CSV = "_sec"                             # stem endings to skip
REQUIRE_CUT_IN_FILENAME    = True   #@param {type:"boolean"}

# Cropping / paste-back
CROP_PADDING      = 100        # px above selection
UPPER_PADDING     = 200        # extra padding below bbox
MASK_EXPAND_PX    = 100        # outward growth before feather
MASK_FEATHER_PX   = 40         # Gaussian sigma for feathering
BRIGHTNESS_FACTOR = 0.95

# Inference (CatVTON+LoRA over Flux Fill)
INFERENCE_MODE    = "lora"
WIDTH, HEIGHT     = 1280, 1600
STEPS             = 75 #@param {type:"number"}
GUIDANCE          = 47 #@param {type:"number"}
FILL_MODEL_ID     = "black-forest-labs/FLUX.1-dev"
CATVTON_XFM       = "xiaozaa/catvton-flux-beta"
#LORA_PATH         = "/content/drive/MyDrive/Dazzl/SikSilk/Jeans_LORA/LORA_models/jeans_LORA_4_lowcfg_step400" #@param {type: "string"}
#LORA_PATH         = "/content/drive/MyDrive/Dazzl/SikSilk/their_dataset_LORA/1280_their_ds_LORA_10_w_jitter/last" #@param {type: "string"}
LORA_PATH         = "/content/drive/MyDrive/Dazzl/SikSilk/their_dataset_LORA/4x5_1280_their_ds_LORA_13_w_jitter/4x5_1280_their_ds_LORA_13_w_jitter_best" #@param {type: "string"}


PREFER_AGNOSTIC_MASKS = True #@param {type:"boolean"}

# LoRA schedule
LORA_START = 1 #@param {type:"number"}
LORA_MID   = 1 #@param {type:"number"}
LORA_END   = 1 #@param {type:"number"}
SPIKE_AT_STEP      = 1 #@param {type:"number"}
TAIL_START_AT_STEP = 5 #@param {type:"number"}

TARGET_ASPECT = (WIDTH, HEIGHT)

# Prompt
CATVTON_PROMPT = (
    "The pair of images highlights a clothing and its styling on a model, high resolution, 4K, 8K; "
    "[IMAGE1] Detailed product shot of a clothing"
    "[IMAGE2] The same cloth is worn by STVBLDMN in a studio setting."
)

# HYPIR enhancement (optional overlay)
ENABLE_HYPIR_ENHANCE   = False   #@param {type:"boolean"}
HYPIR_OVERLAY_OPACITY  = 0.15    #@param {type:"number"}
HYPIR_PROMPT           = "macro, defined fabric texture, 4K, professional fashion photography" #@param {type:"string"}
HYPIR_UPSCALE          = 1
HYPIR_WEIGHT_PATH      = "/content/HYPIR_sd2.pth"
HYPIR_BASE_MODEL       = "stabilityai/stable-diffusion-2-1-base"

GENERATED_SUFFIX  = "" #@param {type:"string"}
if ENABLE_HYPIR_ENHANCE:
    GENERATED_SUFFIX += "_enhanced"

# Spreadsheet (Gen Log + Operations)
# Google Sheet ID plus the names of the tabs used for logging and operations.
SPREADSHEET_ID = "1Kbq9__sEUQiuDPuza5Xy_hRyIn8pUvmfFj6vhPBrp8Y"
GEN_LOG_SHEET  = "Gen Log"
OPS_SHEET_NAME = "Operations"

# Sheet-driven angle selection
# Derived from RUN_MODE above, so re-run this cell after changing RUN_MODE.
USE_SHEET_SELECTION = (RUN_MODE == "sheet")
ANGLE_NEEDS_REGENERATE_TOKEN = "Regenerate"
# Comma-separated; parsed into casefolded lists by _parse_csv_list() in the UTILS cell.
ENFORCE_BAN_SUBSTRINGS     = True
BANNED_SUBSTRINGS_CSV      = "wrong, pair, combo"

ENFORCE_REQUIRE_SUBSTRINGS = False
REQUIRED_SUBSTRINGS_CSV    = ""
REQUIRED_SUBSTRINGS_MODE   = "ANY"   # "ANY" | "ALL"

# Misc
# NOTE(review): the trailing comment reads as if this were already False;
# presumably False is the setting that reduces inline plots — confirm.
SHOW_VISUALS = True   # fewer inline plots in a batch notebook
TIMEZONE     = "Europe/Lisbon"
OPERATOR     = "Ivan"


# === Garment/type taxonomy (from you) ===
# Canonical garment type tokens, each listed exactly once.
ALLOWED_GARMENT_TYPES = [
    "hoodie","jeans","joggers","shorts","sweater","swimwear",
    "t-shirt","shirts","track top","trousers","twinset","polo","vests"
]
# NOTE(review): the groups below spell some types differently from
# ALLOWED_GARMENT_TYPES ("shirt"/"vest" vs "shirts"/"vests", "jogger-trousers"
# vs "joggers"), and "polo" is in neither group. Exact membership tests
# between these lists will miss those types — confirm how they are compared
# downstream.
TOP_GARMENTS    = ["t-shirt","shirt","sweater","hoodie","track top","vest"]
BOTTOM_GARMENTS = ["shorts","jogger-trousers","trousers","jeans","swimwear"]
TWINSET_TYPES   = ["twinset"]

# === Details tokens from you (canonicalized) ===
ALLOWED_DETAIL_TYPES = ["crest","logo","patch"]

# === GPT switches/models/threshold ===
GPT_EVAL_ENABLED   = False
GPT_EVAL_MODEL     = "gpt-4.1-mini"      # vision-capable, cost-effective
GPT_PASS_THRESHOLD = 7                   # stop retrying at ≥

# === Your DETAIL prompt (kept essence; only token names harmonized) ===
# f-string: {ALLOWED_DETAIL_TYPES} is interpolated when this cell runs, so
# re-run the cell after editing that list. Doubled braces {{ }} are literal
# JSON braces in the rendered prompt.
GPT_DETAIL_ANALYZE_PROMPT = f"""
You are a vision AI for fashion. Look at the main garment in the attached image.
Output STRICTLY the following JSON schema and nothing else:

{{
  "details": [
    {{"type": "logo"|"patch"|"crest", "color": "<color if appropriate>"}},
    ...
  ]
}}

RULES:
- "details": list EVERY decorative detail (logo, patch, crest).
- Allowed detail tokens are exactly: {ALLOWED_DETAIL_TYPES}
- For logo/patch: include a COLOR string (e.g., "white", "red").
- If there are NO details, return "details": [].
Return your answer STRICTLY as a JSON object; no commentary.
"""


# === LoRA retry candidates (first one mirrors your current global triplet) ===
# Each tuple is (start, mid, end) in the same order as LORA_START / LORA_MID /
# LORA_END. Presumably tried in order until a GPT score reaches
# GPT_PASS_THRESHOLD — TODO confirm in the generation loop.
LORA_SCHEDULE_CANDIDATES = [
    (LORA_START, LORA_MID, LORA_END),
    (2.7, 0.9, 1.3),
    (3.0, 0.5, 1.2),
    (2.2, 1.0, 1.8),
]

# === Output routing (single all-output folder + detailer queue) ===
# NOTE(review): DETAILER_QUEUE_DIR is commented out, so only OUTPUT_DIR is
# defined by this cell despite the header above.
OUTPUT_DIR          = "/content/drive/MyDrive/Dazzl/SikSilk/SS_OUTPUT_FOLDER/19oct" #@param {type:"string"}
#DETAILER_QUEUE_DIR  = "/content/drive/MyDrive/Dazzl/SikSilk/DETAILER_QUEUE_FOLDER/14oct" #@param {type:"string"}
print("✅ GPT config & folders ready.")

# ==== CONFIG: Few-shot eval messages (from your "Eval") ====
# Developer prompt for the structural-fidelity scorer.
# Written as explicit implicit-concatenation pieces. The single paragraph
# break is an explicit "\n", replacing a raw line break inside a one-line
# string literal (likely an invisible Unicode line separator from copy-paste),
# which editors and tools split into two physical lines.
# NOTE(review): the text says "Your output MUST be an integer", but it also
# asks for a segmented justification ending in "{score}". The few-shot answers
# follow the latter format.
EVAL_DEVELOPER_PROMPT = (
    "You are a professional stylist and fashion construction expert with 15 years at Vogue and multiple industry awards. "
    "Your task is to rate the structural faithfulness of an AI-generated fashion photo of a garment compared to the real flat-lay/ghost-mannequin image of that same garment. "
    "You are given a side-by-side composite (reference garment + AI generation). "
    "Your output MUST be an integer from 1–10, where: "
    "1 = the garment generated is not the source garment (e.g. a dress instead of a hoodie); "
    "2 = The provided garment is effectively not the same: major structural differences (missing/extra panels, different neckline/collar, wrong closure type) or obviously different pattern/structure.; "
    "4 = The generated garment is clearly the intended item, but has noticeable structural differences (e.g., added chest pocket, missing coin pocket, different fly/closure construction, visibly different seam placements, added drawstrings where none exist, extra/missing seams/panels); "
    "7 = A very good structural clone; the pattern and construction read correctly, with only small inconsistencies (e.g., slightly different sleeve length, minor hood volume difference, subtle seam placement drift, extra bar-tack impression, small hem shape change).; "
    "8 = Near-perfect structure; a trained expert can find tiny off-details in construction or proportions, but overall pattern and elements match.; "
    "9 = Essentially perfect structure even for a trained eye. at most a minute proportional quirk.; "
    "10 = Impeccable structural cloning; every element (seam placements, paneling, pockets, closures, hems, cuffs, collar/neckline, waistband shape, hood size, sleeve/leg lengths, pleat counts, slit positions) is exactly right. "
    "Critically focus ONLY on structure and construction. "
    "Evaluate: Presence/absence and placement of elements (pockets, zippers, buttons/snaps, drawstrings, pleats, slits, waistbands, yokes, panels, darts, gussets). "
    "\n"
    "Seam placement and count, stitch-line paths, hems/cuffs, collar/neckline shape and depth, hood size/shape. "
    "Proportions and lengths (sleeve/leg/hem), vent/slit positions, panel geometry. "
    "Overall fit category on the body: baggy/oversize vs. regular vs. slim (coarse classification only). "
    "Artifacts like random fabric pieces, phantom seams, warped or duplicated elements. "
    "Explicitly exclude any judgment about color, shade, fabric texture/hand, prints or logos/graphics. "
    "Ignore materials and surface finishes entirely. "
    "Provide a brief, technical justification broken into segments (e.g., “Neckline & Collar”, “Seams & Panels”, “Pockets & Closures”, “Sleeves & Lengths”, “Hood & Fit”, etc.). "
    "Do not give recommendations. "
    "End with the numeric score in curly braces (e.g., {7}) and nothing after that. "
    "Be precise and never hallucinate."
)



# Few-shot conversation for the structural-fidelity scorer, built from
# "input_text"/"output_text" content parts. It opens with the developer
# prompt. Each following user turn is a text-only placeholder (an example.com
# URL with no image part) and is paired with an assistant turn that shows the
# expected segmented justification ending in a "{score}" token.
# NOTE(review): the user turns carry no images, so these exemplars can only
# teach output format and score calibration, not visual comparison.
EVAL_FEWSHOT_MESSAGES = [
    {
      "role": "developer",
      "content": [
        {
          "type": "input_text",
          "text": EVAL_DEVELOPER_PROMPT
        }
      ]
    },

    {
      "role": "user",
      "content": [
        {"type": "input_text", "text": "Compare: https://example.com/flatlay_hoodie_neutral.png"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {
          "type": "output_text",
          "text": "Justification and Structural Analysis:\nHood & Opening: Reference hood has no drawstrings and a compact three-panel build; AI image adds drawstrings and shows a deeper hood. Added element and altered hood volume.\nNeckline & Collar: Reference shows a standard hood-to-body seam without placket; AI is consistent here.\nPockets & Closures: Single kangaroo pocket in both; pocket shape and entry placements align.\nSeams & Panels: Shoulder and side seams read correctly; cuff/hem bands present.\nSleeves & Lengths: Sleeve length equivalent; cuff proportion consistent.\nFit Category: Reference regular/relaxed; AI reads regular/relaxed.\nStructural Verdict: Added drawstrings and increased hood volume create meaningful structural deviation.\n{4}"
        }
      ]
    },

    {
      "role": "user",
      "content": [
        {"type": "input_text", "text": "Compare: https://example.com/flatlay_jeans_front.png"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {
          "type": "output_text",
          "text": "Justification and Structural Analysis:\nWaistband & Fly: Reference shows classic waistband with button and zipper fly; AI reproduces waistband height and fly construction.\nPockets: Five-pocket layout present in both (including coin pocket); pocket mouth shapes differ slightly, placements align.\nSeams & Panels: Side seams, inseams, and front rise seam match in path and count; no phantom seams.\nHem & Length: Full-length with straight hem; AI length and hem shape align.\nFit Category: Reference slim; AI on-body reads slim.\nStructural Verdict: Strong structural match with tiny pocket-mouth curvature drift.\n{8}"
        }
      ]
    },

    {
      "role": "user",
      "content": [
        {"type": "input_text", "text": "Compare: https://example.com/flatlay_dress.png"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {
          "type": "output_text",
          "text": "Justification and Structural Analysis:\nNeckline & Collar: Reference is a high crew neckline; AI shows a V-neck—major structural change.\nSeams & Panels: Reference bodice has princess seams; AI omits them, simplifying paneling.\nSleeves & Lengths: Reference is sleeveless; AI adds capped sleeves—added elements.\nSkirt & Slits: Reference A-line with no slit; AI suggests a front slit—extra element.\nFit Category: Reference regular; AI reads regular.\nStructural Verdict: Different neckline, added sleeves, altered paneling/slit—structurally a different garment.\n{2}"
        }
      ]
    },

    {
      "role": "user",
      "content": [
        {"type": "input_text", "text": "Compare: https://example.com/flatlay_denim_jacket.png"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {
          "type": "output_text",
          "text": "Justification and Structural Analysis:\nChest Pockets: Reference has two flap chest pockets with pointed flaps; AI shows both flap pockets with matching flap geometry and placements.\nBody Pockets: Reference includes vertical welt hand pockets; AI renders both welts in correct positions and angles.\nYoke & Panels: Reference features classic front yoke seam; AI preserves the yoke line and height precisely; panel divisions match.\nPlacket & Closures: Center-front button placket structure, spacing impression, and cuff plackets read correctly in AI; no extra placket or missing closure structure.\nSeams & Toplines: Side seams, armhole seams, and sleeve paneling align with reference; no phantom seams or missing panels.\nSleeves & Lengths: Sleeve length and cuff proportion match; body length and hem band shape are consistent with reference.\nFit Category: Reference regular; AI on-body reads regular.\nStructural Verdict: Perfect structural clone. All structural elements—including pocket count/type, yoke, placket, seams, and proportions—match exactly. No extra/missing elements or artifacts detected.\n{10}"
        }
      ]
    },

    {
      "role": "user",
      "content": [
        {"type": "input_text", "text": "Compare: https://example.com/flatlay_pleated_skirt.png"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {
          "type": "output_text",
          "text": "Justification and Structural Analysis:\nWaistband: Reference shows a straight waistband; AI reproduces same waistband height and structure.\nPleats: Reference uses dense micro-pleats; AI shows slightly wider spacing but consistent count across panels.\nHem & Length: Straight hem in both; length reads equivalent.\nClosures: Side-zip impression not visible in either—no conflict.\nFit Category: Reference regular; AI reads regular.\nStructural Verdict: Decent structural alignment with small pleat-spacing variance.\n{7}"
        }
      ]
    },

    {
      "role": "user",
      "content": [
        {"type": "input_text", "text": "Compare: https://example.com/flatlay_shirt.png"}
      ]
    },
    {
      "role": "assistant",
      "content": [
        {
          "type": "output_text",
          "text": "Justification and Structural Analysis:\nCollar & Stand: Reference has a point collar with stand; AI matches collar shape and stand height.\nPlacket: Reference features a standard front placket; AI placket present and aligned.\nChest Pocket: Reference has no chest pocket; AI adds a left chest patch pocket—added element.\nYoke & Seams: Back yoke implied in reference; AI suggests same; side seams align.\nSleeves & Cuffs: Long sleeves with single-button cuff in reference; AI keeps long sleeves with comparable cuff shape.\nFit Category: Reference regular; AI reads slim-regular.\nStructural Verdict: Added chest pocket changes structure despite other accurate elements.\n{5}"
        }
      ]
    },
]
print("✅ Few-shot eval messages loaded.")



✅ GPT config & folders ready.
✅ Few-shot eval messages loaded.


# UTILS

In [None]:
# --- Core utilities: normalization, angles, walking, masks (agnostic-first) ---

import os, re, fnmatch, math, uuid, pytz, random, gc, tempfile, traceback
from datetime import datetime
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter

def normalize_sku_list(sku_csv: str) -> str:
    """
    Canonicalize a comma-separated SKU list to "SS-<digits>" entries.

    The first run of digits in each entry is kept, so "28748", "ss-28748" and
    "SS 28748" all become "SS-28748". Entries without any digit are dropped.
    The cleaned entries are re-joined with ", ".
    """
    digit_run = re.compile(r'(\d+)')
    normalized = []
    for token in sku_csv.split(','):
        found = digit_run.search(token.strip().upper())
        if found is None:
            continue
        normalized.append("SS-" + found.group(1))
    return ", ".join(normalized)

# Canonicalize the CONFIG SKU list (e.g. "28748" -> "SS-28748") before any
# matching. Safe to re-run: an already-normalized "SS-28748" maps to itself.
SKU_CSV = normalize_sku_list(SKU_CSV)

# Parsers
def _parse_csv_list(s):  return [x.strip().casefold() for x in (s or "").split(",") if x.strip()]
# Casefolded filter lists derived from the *_CSV strings in CONFIG.
BANNED_SUBSTRINGS       = _parse_csv_list(BANNED_SUBSTRINGS_CSV)
REQUIRED_SUBSTRINGS     = _parse_csv_list(REQUIRED_SUBSTRINGS_CSV)
SKIP_FILENAME_TOKENS    = set(_parse_csv_list(SKIP_FILENAME_TOKENS_CSV))
# A tuple, presumably so it can be passed straight to str.endswith(...) — TODO confirm at the use site.
SKIP_BASENAME_SUFFIXES  = tuple(_parse_csv_list(SKIP_BASENAME_SUFFIXES_CSV))

# Normalizers
def _norm_sku(s):
    if s is None: return ""
    s = str(s).replace("\u00A0"," ")
    s = " ".join(s.split())
    return s.casefold()

def _norm_angle(s):
    s = (s or "").strip().lower()
    return s.strip("_ ").replace("-", "_")

# Angle aliases: output angle -> extra source angle tokens that count as that
# angle. expand_allowed_angles() normalizes both keys and values with
# _norm_angle(), so "fr_" and "fr" are the same token.
ANGLE_ALIASES = {
    "fr_cl": ["fr", "fr_"],
    #"lft":   ["fr_lft", "bc_lft"],
    #"rght":  ["fr_rght", "bc_rght"],
}

# --- Helpers to keep outputs strict, sources flexible ---
def expand_as_list(angles):
    """
    Expand `angles` with their ANGLE_ALIASES and return the normalized tokens,
    longest first, so a specific name such as 'fr_cl' is tried before a
    shorter one such as 'fr'.
    """
    normalized = [_norm_angle(token) for token in expand_allowed_angles(angles)]
    return sorted(normalized, key=len, reverse=True)

def pick_target_angle(source_angle: str, allowed_outputs: set) -> str | None:
    """
    Map a source-file angle token to the output angle it should be saved under.

    An exact (normalized) match with an allowed output wins. Otherwise the first
    allowed output, in sorted order, whose ANGLE_ALIASES family contains the
    source angle is used. Scanning in sorted order keeps the result independent
    of set iteration order, which for strings varies between Python processes
    (hash randomization). For example, with both "fr_cl" and "fr_" allowed, a
    "fr" source used to map to either one at random.

    Args:
        source_angle: Angle token parsed from a source filename.
        allowed_outputs: Iterable of permitted output angle tokens (None = none).

    Returns:
        The normalized target angle, or None when nothing matches.
    """
    s = _norm_angle(source_angle)
    targets = sorted({_norm_angle(t) for t in (allowed_outputs or [])})
    if s in targets:
        return s
    for target in targets:
        family = {_norm_angle(x) for x in expand_allowed_angles([target])}
        if s in family:
            return target
    return None


def expand_allowed_angles(angles):
    """
    Return the set of normalized tokens in `angles` plus every alias listed for
    them in ANGLE_ALIASES (aliases normalized too). `angles=None` gives an empty set.
    """
    result = set()
    for angle in angles or []:
        key = _norm_angle(angle)
        result.add(key)
        result.update(_norm_angle(alias) for alias in ANGLE_ALIASES.get(key, []))
    return result

# Ignore set: lower-case once so the walkers below can compare against name.lower().
IGNORE_DIRS = {d.lower() for d in IGNORE_DIRS}

# Walkers
def _is_sku_folder(path: str) -> bool:
    """
    True when `path` is not an ignored directory name and directly contains at
    least one file ending in a VALID_EXTENSIONS suffix (case-insensitive).
    Missing or unreadable paths count as not-a-SKU-folder.
    """
    folder_name = os.path.basename(os.path.normpath(path)).lower()
    if folder_name in IGNORE_DIRS:
        return False
    try:
        image_exts = tuple(ext.lower() for ext in VALID_EXTENSIONS)
        return any(
            os.path.isfile(os.path.join(path, entry)) and entry.lower().endswith(image_exts)
            for entry in os.listdir(path)
        )
    except Exception:
        return False

def iter_sku_folders(root: str):
    """
    Yield every directory under `root` (root included, top-down) that directly
    contains a VALID_EXTENSIONS image. Subtrees named in IGNORE_DIRS are pruned.
    """
    image_exts = tuple(ext.lower() for ext in VALID_EXTENSIONS)
    for current, subdirs, files in os.walk(root):
        # Prune in place so os.walk does not descend into ignored folders.
        subdirs[:] = [name for name in subdirs if name.lower() not in IGNORE_DIRS]
        if any(name.lower().endswith(image_exts) for name in files):
            yield current

def resolve_targets(idents_csv: str, garments_root: str):
    """
    Resolve user identifiers to absolute SKU folder paths.

    Accepts (comma- or newline-separated):
      • Plain SKU names, relative paths (Category/Subcategory/SKU), or absolute dirs
      • Glob patterns (e.g., 'Hoodies/*' or 'SKSLK_12*')
      • Directories that are NOT SKU leaves → expand to all descendant SKU leaves

    Returns:
        (targets, unmatched). `targets` is the sorted, de-duplicated list of
        absolute SKU folder paths. `unmatched` lists identifiers that resolved
        to no SKU folder at all. An identifier whose folders were already added
        by an earlier identifier (e.g. a repeated SKU) counts as matched.
    """
    idents = [s.strip() for s in idents_csv.replace("\n", ",").split(",") if s.strip()]
    if not idents:
        return [], []

    all_sku_dirs = list(iter_sku_folders(garments_root))
    rel_map = {p: os.path.relpath(p, garments_root) for p in all_sku_dirs}
    base_map = {p: os.path.basename(p) for p in all_sku_dirs}

    seen, out, unmatched = set(), [], []

    def add_leaf(path):
        """Append the absolute form of `path` to `out` once."""
        ap = os.path.abspath(path)
        if ap not in seen:
            seen.add(ap)
            out.append(ap)

    def add_path(p):
        """Add a SKU leaf, or every SKU leaf below a container dir; True if any SKU folder resolved."""
        ap = os.path.abspath(p)
        if not os.path.isdir(ap):
            return False
        if _is_sku_folder(ap):
            add_leaf(ap)
            return True
        # Expand directory to all descendant SKU leaves
        found = False
        for leaf in iter_sku_folders(ap):
            add_leaf(leaf)
            found = True
        return found

    for ident in idents:
        # Track "resolved to something" rather than "grew `out`": duplicates of
        # already-added folders are still matches.
        matched = False

        # Absolute directory or SKU path
        if os.path.isabs(ident) and os.path.isdir(ident):
            matched |= add_path(ident)

        # Relative under garments root (dir or SKU)
        rel_candidate = os.path.join(garments_root, ident)
        if os.path.exists(rel_candidate):
            matched |= add_path(rel_candidate)

        # Glob/pattern over known SKU leaves (by basename or relative path)
        for p in all_sku_dirs:
            if fnmatch.fnmatch(base_map[p], ident) or fnmatch.fnmatch(rel_map[p], ident):
                matched |= add_path(p)

        if not matched:
            unmatched.append(ident)

    out.sort()
    return out, unmatched

# Base/mask location resolution
def resolve_base_mask_dir(sku_folder: str,
                          garments_root: str = GARMENTS_ROOT,
                          base_root: str = BASE_PHOTOS_ROOT):
    """
    Map .../GARMENTS_ROOT/Category/Subcategory/SKU → .../BASE_ROOT/Category/Subcategory
    With robust fallbacks.
    """
    abs_sku = os.path.abspath(sku_folder)
    abs_gar = os.path.abspath(garments_root)
    try:
        rel = os.path.relpath(abs_sku, abs_gar)
    except Exception:
        rel = None

    if rel and not rel.startswith(".."):
        rel_parent = os.path.dirname(rel)
        cand = os.path.join(base_root, rel_parent)
        if os.path.isdir(cand): return cand

    subcat = os.path.basename(os.path.dirname(abs_sku))
    cat    = os.path.basename(os.path.dirname(os.path.dirname(abs_sku)))
    cand2  = os.path.join(base_root, cat, subcat)
    if os.path.isdir(cand2): return cand2

    cand3  = os.path.join(base_root, subcat)
    if os.path.isdir(cand3): return cand3
    return None

def _valid_ext(fname):
    """True when `fname` ends with any VALID_EXTENSIONS entry, ignoring case."""
    lowered = fname.lower()
    accepted = tuple(ext.lower() for ext in VALID_EXTENSIONS)
    return lowered.endswith(accepted)

def _file_prefix_or_none(filename: str):
    low = filename.lower()
    for base in ALLOWED_BASES:
        if low.startswith(base): return base
    return None

def _find_image_with_stem_and_suffix(directory, stem, suffix=""):
    if not directory or not os.path.isdir(directory):
        return None
    stem = stem.lower()
    for file in os.listdir(directory):
        fname, fext = os.path.splitext(file)
        if fext.lower() in (".png",".jpg",".jpeg") and fname.lower() == f"{stem}{suffix}":
            return os.path.join(directory, file)
    return None

# --- Existence check in Google Drive by Colab-style path ---
def drive_file_exists_any_ext_at_colab_path(target_colab_path: str,
                                            exts=(".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")) -> bool:
    """
    Given a Colab-style *file* path (incl. a filename with any extension),
    report whether the Drive folder it points into already holds a file with the
    SAME stem (case-sensitive) and any extension in `exts` (compared
    case-insensitively).

    Any error during the lookup is printed and treated as "does not exist".
    """
    try:
        parent_id, desired_name = _resolve_parent_id_and_filename_from_colab_path(target_colab_path)
        wanted_stem = os.path.splitext(desired_name)[0]
        children = _list_children(parent_id, q_extra="")  # list once; filter locally
        accepted = {ext.lower() for ext in exts}
        for child in children:
            child_stem, child_ext = os.path.splitext(child.get("name", ""))
            if child_stem == wanted_stem and child_ext.lower() in accepted:
                return True
        return False
    except Exception as e:
        print(f"⚠️ Ext-agnostic existence check failed for {target_colab_path}: {e}")
        return False


# === Mask lookup (agnostic masks first when enabled) ===
def find_mask_path(base_subcat_dir: str, stem_no_cut: str):
    """
    Locate the mask image for a base-photo stem inside `base_subcat_dir`.

    Search order (first existing file wins; extensions tried as .png, .jpg,
    .jpeg, then their upper-case forms):
      1) {stem}_mask_agnostic.<ext> — only when PREFER_AGNOSTIC_MASKS is True
      2) {stem}_mask.<ext>

    Returns the path, or None when the directory is missing or no mask exists.
    """
    if not base_subcat_dir or not os.path.isdir(base_subcat_dir):
        return None

    mask_exts = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")
    name_suffixes = (["_mask_agnostic"] if PREFER_AGNOSTIC_MASKS else []) + ["_mask"]
    for name_suffix in name_suffixes:
        for ext in mask_exts:
            path = os.path.join(base_subcat_dir, f"{stem_no_cut}{name_suffix}{ext}")
            if os.path.exists(path):
                return path
    return None

# ──────────────────────────────────────────
# --- Aspect-ratio bbox (replaces square bbox usage) ---


def find_aspect_bbox(
    mask: Image.Image,
    aspect: tuple[int,int] = (4,5),   # width:height, e.g. (1280,1600)
    padding: int = 40,
    upper_padding: int | None = None,
    horiz_padding: int = 0,
    min_margin: int | None = None,
):
    """
    Return a rectangular bbox [x0, y0, x1, y1] that fully contains the mask + padding
    and matches the requested aspect ratio by expanding outward only.
    """
    if min_margin is None:
        try:
            min_margin = int(MASK_EXPAND_PX + 3 * MASK_FEATHER_PX + 5)
        except Exception:
            min_margin = 40

    m = np.array(mask.convert("L"))
    h, w = m.shape
    ys, xs = np.where(m > 128)
    if xs.size == 0:
        raise ValueError("Mask has no white pixels!")

    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())

    if upper_padding is None:
        upper_padding = padding

    # Initial padded bbox
    x0 = max(0, x_min - horiz_padding - min_margin)
    x1 = min(w, x_max + horiz_padding + min_margin)
    y0 = max(0, y_min - upper_padding - min_margin)
    y1 = min(h, y_max + padding + min_margin)

    bw, bh = (x1 - x0), (y1 - y0)
    # Desired aspect as float
    aw, ah = aspect
    target_ar = float(aw) / float(max(1, ah))

    # First pass: try to match aspect by expanding one dimension only
    def expand_to_aspect(x0, y0, x1, y1):
        bw = x1 - x0; bh = y1 - y0
        cur_ar = bw / float(max(1, bh))
        if cur_ar < target_ar:
            # too tall → need wider
            need_w = int(np.ceil(target_ar * bh))
            grow = max(0, need_w - bw)
            left_grow  = min(x0, grow // 2)
            right_grow = min(w - x1, grow - left_grow)
            x0 -= left_grow; x1 += right_grow
        elif cur_ar > target_ar:
            # too wide → need taller
            need_h = int(np.ceil(bw / target_ar))
            grow = max(0, need_h - bh)
            top_grow    = min(y0, grow // 2)
            bottom_grow = min(h - y1, grow - top_grow)
            y0 -= top_grow; y1 += bottom_grow
        return max(0,x0), max(0,y0), min(w,x1), min(h,y1)

    x0, y0, x1, y1 = expand_to_aspect(x0, y0, x1, y1)

    # Second pass: if a border capped us, re-try by expanding the other dimension
    bw, bh = (x1 - x0), (y1 - y0)
    cur_ar = bw / float(max(1, bh))
    if abs(cur_ar - target_ar) > 1e-3:
        if cur_ar < target_ar:
            # could not widen enough → try to grow height
            need_h = int(np.ceil(bw / target_ar))
            grow = max(0, need_h - bh)
            top_grow    = min(y0, grow // 2)
            bottom_grow = min(h - y1, grow - top_grow)
            y0 -= top_grow; y1 += bottom_grow
        else:
            # could not heighten enough → try to grow width
            need_w = int(np.ceil(target_ar * bh))
            grow = max(0, need_w - bw)
            left_grow  = min(x0, grow // 2)
            right_grow = min(w - x1, grow - left_grow)
            x0 -= left_grow; x1 += right_grow

    # Final clamp
    x0, y0 = max(0, int(x0)), max(0, int(y0))
    x1, y1 = min(w, int(x1)), min(h, int(y1))
    return [x0, y0, x1, y1]


# Alpha/white utilities for garment panel
WHITE_RGB = (255,255,255)
def flatten_alpha_to_white(img: Image.Image) -> Image.Image:
    """
    Composite `img` over a white background and return an RGB image.

    Images with an alpha channel ("RGBA", "LA", "PA") or a "transparency" info
    entry (palette / greyscale transparency) are converted to RGBA first, so the
    paste mask really is the alpha channel. For "P"/"L" images, split()[-1]
    would be the palette/grey channel itself, not transparency. All other
    images are simply converted to RGB.
    """
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        rgba = img.convert("RGBA")
        bg = Image.new("RGB", rgba.size, WHITE_RGB)
        bg.paste(rgba, mask=rgba.getchannel("A"))
        return bg
    return img.convert("RGB")

def _tight_bbox_nonwhite_or_opaque(img: Image.Image):
    """
    Tight (x0, y0, x1, y1) box around the garment, right/bottom exclusive.

    Foreground is any pixel with alpha > 0 for images that carry transparency
    ("RGBA"/"LA" mode or a "transparency" info entry). Otherwise it is any pixel
    that is not pure white (255, 255, 255). Returns None if nothing is foreground.
    """
    has_alpha = img.mode in ("RGBA","LA") or ("transparency" in img.info)
    if has_alpha:
        fg = np.asarray(img.convert("RGBA"))[..., 3] > 0
    else:
        rgb = np.asarray(img.convert("RGB"))
        fg = (rgb != 255).any(axis=-1)
    if not fg.any():
        return None
    rows = np.flatnonzero(fg.any(axis=1))
    cols = np.flatnonzero(fg.any(axis=0))
    return (int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1)

def crop_garment_keep_aspect(img: Image.Image) -> Image.Image:
    """
    Flatten `img` onto white and crop it to the garment's tight bounding box.

    The crop uses the raw bbox (no padding, no resize), so the garment keeps its
    own aspect ratio. The flattened image is returned uncropped when there is no
    foreground or the bbox already spans the whole image.
    """
    box = _tight_bbox_nonwhite_or_opaque(img)
    flat = flatten_alpha_to_white(img)
    if box is None or box == (0, 0, flat.width, flat.height):
        return flat
    return flat.crop(box)

def to_centered_square(gar: Image.Image, fill=WHITE_RGB) -> Image.Image:
    """Pad `gar` with `fill` onto a square RGB canvas whose side is the longer edge, centered."""
    width, height = gar.size
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), fill)
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(gar, offset)
    return canvas


In [None]:
# ==== GPT scorer with few-shot pre-messages ====
import io, json, base64, re
from typing import Tuple, Dict, Any, Optional, List
from PIL import Image

from openai import OpenAI

from google.colab import userdata
# API key read from Colab user secrets via google.colab.userdata.
openai_api_key = userdata.get('openai_api_key')


_client: Optional[OpenAI] = None
def _client_once() -> OpenAI:
    """Return the module-level cached OpenAI client, building it from `openai_api_key` on first call."""
    global _client
    if _client is not None:
        return _client
    _client = OpenAI(api_key=openai_api_key)
    return _client

def _pil_to_data_url(pil: Image.Image) -> str:
    """Encode `pil` (converted to RGB) as a base64 PNG ``data:`` URL."""
    with io.BytesIO() as buffer:
        pil.convert("RGB").save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")


from typing import List

# JSON Schema for the detail-extraction response:
#   {"details": [{"type": "crest"|"logo"|"patch", "color"?: str}, ...]}
# NOTE(review): the "enum" repeats ALLOWED_DETAIL_TYPES from CONFIG by hand —
# keep the two in sync.
# NOTE(review): "color" is optional here. If this schema is sent to OpenAI
# structured outputs with strict mode, strict mode expects every property to
# be listed in "required" (optional values modelled as nullable) — confirm how
# the schema is used.
_DETAILS_SCHEMA = {
    "type": "object",
    "additionalProperties": False,
    "required": ["details"],
    "properties": {
        "details": {
            "type": "array",
            "items": {
                "type": "object",
                "additionalProperties": False,
                "required": ["type"],
                "properties": {
                    "type": {"type": "string",
                             "enum": ["crest","logo","patch"]},
                    "color": {"type": "string"}
                }
            }
        }
    }
}

def _normalize_detail_type(t: str) -> str:
    """Canonicalise a detail label: trim, lowercase, then fold known synonyms.

    Falsy input (None, "") yields "". Lettering/underscore variants of sleeve and
    waist text map to "sleeve text" / "waist text"; anything else is returned as-is
    after trimming and lowercasing.
    """
    key = "" if not t else t.strip().lower()
    aliases = {
        "waistband lettering": "waist text",
        "waist_text": "waist text",
        "sleeve lettering": "sleeve text",
        "sleeve_text": "sleeve text",
    }
    return aliases[key] if key in aliases else key

def _postprocess_details(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Clean a detail-analysis payload into {"details": [...]} with canonical types.

    Keeps only entries whose normalised type is "sleeve text" (color dropped) or one
    of crest/logo/patch (color kept when it is a non-blank string, stripped). Malformed
    input is tolerated rather than raising: a non-dict payload, a missing/None/non-list
    "details", non-dict entries and non-string "type" values are all skipped.

    Args:
        payload: parsed model output, expected to look like {"details": [...]}.

    Returns:
        {"details": [cleaned entries]} — always a dict with a list.
    """
    details = payload.get("details") if isinstance(payload, dict) else None
    if not isinstance(details, (list, tuple)):
        details = []
    fixed: List[Dict[str, Any]] = []
    for d in details:
        if not isinstance(d, dict):
            continue  # malformed entry (string, null, number, ...)
        raw_type = d.get("type")
        typ = _normalize_detail_type(raw_type if isinstance(raw_type, str) else None)
        col = d.get("color")
        if typ == "sleeve text":
            fixed.append({"type": "sleeve text"})
        elif typ in ("crest", "logo", "patch"):
            ent = {"type": typ}
            if isinstance(col, str) and col.strip():
                ent["color"] = col.strip()
            fixed.append(ent)
    return {"details": fixed}

def concat_side_by_side(left: Image.Image, right: Image.Image, pad: int = 16, bg=(255,255,255)) -> Image.Image:
    """Place `left` and `right` next to each other on one RGB canvas.

    Both images are converted to RGB and LANCZOS-resized (aspect preserved, width at
    least 1 px) to the taller of the two heights, then pasted with `pad` pixels of
    `bg` between them.
    """
    target_h = max(left.height, right.height)

    def _to_height(im: Image.Image) -> Image.Image:
        ratio = target_h / max(1, im.height)
        return im.resize((max(1, int(im.width * ratio)), target_h), Image.LANCZOS)

    left_rgb, right_rgb = (_to_height(im.convert("RGB")) for im in (left, right))
    canvas = Image.new("RGB", (left_rgb.width + pad + right_rgb.width, target_h), bg)
    for panel, x in ((left_rgb, 0), (right_rgb, left_rgb.width + pad)):
        canvas.paste(panel, (x, 0))
    return canvas




def _extract_score(text: str, default: int = 1) -> int:
    """Extract the 1–10 rating from a GPT review text.

    Tried in order:
      1. the last brace group holding only the number, e.g. "{8}";
      2. the last brace group containing a standalone 1–10 number, e.g. "{Score: 8/10}" → 8;
      3. the first "Rating: N" / "Rating: {N}".
    Digit look-arounds stop "{12}" being read as 1 or "{100}" as 10.
    Returns `default` (and prints a warning) when nothing matches.
    """
    bare = re.findall(r"\{\s*(10|[1-9])\s*\}", text)
    if bare:
        return int(bare[-1])
    braced = re.findall(r"\{[^{}]*?(?<!\d)(10|[1-9])(?!\d)[^{}]*?\}", text)
    if braced:
        return int(braced[-1])
    m = re.search(r"(?i)rating\s*[:\-]?\s*\{?\s*(10|[1-9])(?!\d)\s*\}?", text)
    if m:
        return int(m.group(1))
    # Make the fallback visible: a silent default is indistinguishable from a real score.
    print(f"⚠️ No 1-10 score found in GPT eval output; defaulting to {default}.")
    return default


def gpt_score_tryon(garment_img: Image.Image, generated_img: Image.Image):
    """
    Score a try-on result with GPT, comparing the garment and the generated image.

    Sends EVAL_DEVELOPER_PROMPT (developer role), then EVAL_FEWSHOT_MESSAGES, then a
    "Compare:" user turn with the garment|result pair as one side-by-side PNG, using
    GPT_EVAL_MODEL at temperature 0. The model is expected to end its review with the
    numeric score in curly braces; the score is parsed by _extract_score.

    Returns:
        (score, response_id): score is an int in 1..10 (1 when it cannot be parsed);
        response_id can be passed as `previous_response_id` to continue the thread.
    """
    pair = concat_side_by_side(garment_img, generated_img)
    data_url = _pil_to_data_url(pair)
    client = _client_once()

    messages = [
        # Evaluation instructions
        {"role": "developer",
         "content": [{"type": "input_text", "text": EVAL_DEVELOPER_PROMPT.strip()}]},
    ]
    # Few-shot exemplars, passed through unchanged
    messages.extend(EVAL_FEWSHOT_MESSAGES)
    # The actual comparison request for THIS pair
    messages.append({
        "role": "user",
        "content": [
            {"type": "input_text", "text": "Compare:"},
            {"type": "input_image", "image_url": data_url},
        ],
    })

    resp = client.responses.create(
        model=GPT_EVAL_MODEL,
        input=messages,
        temperature=0,
    )

    text = getattr(resp, "output_text", "") or ""
    print(f"GPT full eval: {text}")
    score = _extract_score(text)
    return score, resp.id


def _parse_details_json(raw: str) -> Dict[str, Any]:
    """Parse the detail-analysis reply into a dict; {"details": []} on any failure.

    Accepts bare JSON, or JSON wrapped in prose/code fences (the outermost {...}
    span is tried). Only a JSON object is accepted; lists, strings or numbers fall
    back to the empty result.
    """
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        # The model may wrap JSON in prose or code fences: try the outermost {...} span.
        m = re.search(r"\{[\s\S]*\}", raw)
        if m is None:
            return {"details": []}
        try:
            parsed = json.loads(m.group(0))
        except (ValueError, RecursionError):
            return {"details": []}
    return parsed if isinstance(parsed, dict) else {"details": []}


def gpt_detect_details(garment_img: Image.Image,
                       generated_img: Image.Image,
                       previous_response_id: str | None):
    """
    Ask GPT which garment details (crest/logo/patch/...) appear in the pair.

    Sends GPT_DETAIL_ANALYZE_PROMPT verbatim plus the garment|result side-by-side
    image, continuing the conversation identified by `previous_response_id`.

    Returns:
        The parsed JSON object (a dict), or {"details": []} when the reply is not a
        JSON object. No schema validation is done here.
    """
    pair = concat_side_by_side(garment_img, generated_img)
    data_url = _pil_to_data_url(pair)
    client = _client_once()

    resp = client.responses.create(
        model=GPT_EVAL_MODEL,
        previous_response_id=previous_response_id,  # keep same thread
        input=[{
            "role": "user",
            "content": [
                {"type": "input_text", "text": GPT_DETAIL_ANALYZE_PROMPT},
                {"type": "input_image", "image_url": data_url}
            ]
        }],
        temperature=0
    )

    raw = getattr(resp, "output_text", "") or ""
    return _parse_details_json(raw)



In [None]:
# --- Visualisation helpers (restored) ---

import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps

def open_upright(path) -> Image.Image:
    """Open the image at `path` and return it rotated/flipped per its EXIF orientation.

    The file handle is closed before returning; the returned image is the one produced
    by ImageOps.exif_transpose.
    """
    with Image.open(path) as raw:
        upright = ImageOps.exif_transpose(raw)
    return upright

def show_gallery(img_list, titles=None, cols=3, w=4):
    """
    Display images in a grid with `cols` columns, each cell `w` inches square.

    Accepts PIL images, numpy arrays (a 4-D batch shows its first item) and torch
    tensors (C,H,W or 2-D; a 4-D batch shows its first item). Only renders when the
    global SHOW_VISUALS is truthy; an empty `img_list` draws nothing.
    """
    if not globals().get("SHOW_VISUALS", False):
        return

    n = len(img_list)
    if n == 0:
        return  # nothing to draw (would otherwise create a zero-height figure)
    rows = math.ceil(n / cols)
    plt.figure(figsize=(cols * w, rows * w))

    for i, img in enumerate(img_list):
        plt.subplot(rows, cols, i + 1)
        if isinstance(img, np.ndarray) and img.ndim == 4:
            img = img[0]  # (B,H,W,C) → (H,W,C)
        # Torch tensors are detected by duck typing to avoid a hard torch import.
        if "Tensor" in str(type(img)):
            if img.ndim == 4:
                img = img[0]  # batch → first item, mirroring the numpy branch
            img = img.detach().cpu()
            if img.ndim == 3:
                img = img.permute(1, 2, 0)  # (C,H,W) → (H,W,C) for imshow
            img = img.numpy()
        plt.imshow(img)
        if titles and i < len(titles):
            plt.title(titles[i])
        plt.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# --- Paste-back (mask-aware) + pair/mask builders ---

import cv2
from torchvision import transforms
# Image → tensor for the inpaint panels: ToTensor yields [0, 1] for 8-bit images, then
# Normalize(mean=0.5, std=0.5) maps that to [-1, 1] via (x - 0.5) / 0.5.
_to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])
# Mask → tensor left in [0, 1] (no normalisation).
_to_tensor_mask = transforms.Compose([transforms.ToTensor()])

# ──────────────────────────────────────────
def paste_crop_back(
    full_img: Image.Image,
    edited_crop: Image.Image,
    crop_box,               # (x0,y0,x1,y1) in full_img coords
    crop_mask: np.ndarray,  # H×W uint8/bool, garment=white within crop_box
    expand_px: int = 20,    # outward dilation before feather (mask-aware)
    feather_px: int = 10,   # Gaussian σ for feathering (outside only)
    *,
    bin_thresh: int = 127,
    edge_kill_px: int | None = None,  # will be capped by actual margin
    retry_expand_px: int = 30,        # how much to enlarge bbox if needed
):
    """Composite `edited_crop` back into `full_img` at `crop_box` (in place).

    Alpha is solid inside the binarised mask and fades out only *outside* it: the
    mask is dilated by `expand_px`, the ring is blurred with sigma `feather_px`, and
    the ring is tapered to zero within `edge_kill_px` of the crop sides (capped by the
    box's distance to the canvas border). If alpha is still non-zero on a crop side
    that lies inside the canvas, the box is grown by `retry_expand_px` and the paste is
    retried with the edit kept at its true position.

    Args:
        full_img: canvas; modified in place and returned.
        edited_crop: edited content for `crop_box`; resized to the box size.
        crop_box: (x0, y0, x1, y1) in `full_img` coordinates.
        crop_mask: mask for the crop (ndarray H×W / H×W×C, bool or uint8, or a PIL
            image); values > `bin_thresh` (True for bool masks) count as garment.
        expand_px, feather_px, bin_thresh, edge_kill_px, retry_expand_px: see above.

    Returns:
        `full_img` after compositing.
    """
    x0, y0, x1, y1 = map(int, crop_box)
    tgt_w, tgt_h   = (x1 - x0), (y1 - y0)

    # Resize edit to crop size; match the canvas mode so composite/paste accept it.
    edit_rs = edited_crop.resize((tgt_w, tgt_h), Image.Resampling.LANCZOS)
    if edit_rs.mode != full_img.mode:
        edit_rs = edit_rs.convert(full_img.mode)

    # --- 1) Binary mask in crop coordinates
    mask_np = crop_mask
    if isinstance(mask_np, Image.Image):
        mask_np = np.array(mask_np.convert("L"))
    mask_np = np.asarray(mask_np)
    if mask_np.ndim == 3:
        mask_np = mask_np[..., 0]
    if mask_np.dtype == np.bool_:
        # cv2.resize rejects bool arrays, and True (== 1) would never exceed bin_thresh.
        mask_np = mask_np.astype(np.uint8) * 255
    mask_np = cv2.resize(mask_np, (tgt_w, tgt_h), interpolation=cv2.INTER_NEAREST)
    mask_bin = (mask_np > bin_thresh).astype(np.uint8)

    # --- 2) Build outside-only feather band
    if expand_px > 0:
        ksize = max(1, expand_px * 2 + 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        dil = cv2.dilate(mask_bin, kernel, iterations=1)
    else:
        dil = mask_bin.copy()

    outside = (dil - mask_bin).clip(0, 1).astype(np.float32) * 255.0
    if feather_px > 0:
        outside = cv2.GaussianBlur(outside, (0, 0), sigmaX=feather_px, sigmaY=feather_px)

    # --- 3) Edge-kill: CAP by actual margin so we only taper near rectangle borders
    margin_left   = x0
    margin_right  = full_img.width  - x1
    margin_top    = y0
    margin_bottom = full_img.height - y1
    max_safe_edgekill = max(2, min(margin_left, margin_right, margin_top, margin_bottom))

    if edge_kill_px is None:
        edge_kill_px = int(expand_px + 3 * feather_px)
    edge_kill_px = min(int(edge_kill_px), int(max_safe_edgekill))

    # Taper only within the edge_kill band near the crop edges (no global attenuation).
    yy, xx = np.mgrid[0:tgt_h, 0:tgt_w]
    dist_edge = np.minimum.reduce([xx, tgt_w - 1 - xx, yy, tgt_h - 1 - yy]).astype(np.float32)

    edge_factor = np.ones_like(dist_edge, np.float32)
    band = dist_edge < float(edge_kill_px)
    edge_factor[band] = dist_edge[band] / float(max(1.0, edge_kill_px))
    outside *= edge_factor

    # --- 4) Final alpha: solid interior + tapered outside (no inward feathering)
    alpha = np.zeros((tgt_h, tgt_w), np.float32)
    alpha[mask_bin > 0] = 255.0
    alpha += outside
    alpha = np.clip(alpha, 0, 255).astype(np.uint8)

    # --- 5) Leakage check: retry only if alpha > 0 on a crop edge that is NOT also
    # a full-image edge.
    leaks_top    = alpha[0, :].max() > 0
    leaks_bottom = alpha[-1, :].max() > 0
    leaks_left   = alpha[:, 0].max() > 0
    leaks_right  = alpha[:, -1].max() > 0

    needs_retry = (
        (leaks_top    and y0 > 0) or
        (leaks_bottom and y1 < full_img.height) or
        (leaks_left   and x0 > 0) or
        (leaks_right  and x1 < full_img.width)
    )

    if needs_retry:
        x0n = max(0, x0 - retry_expand_px)
        y0n = max(0, y0 - retry_expand_px)
        x1n = min(full_img.width,  x1 + retry_expand_px)
        y1n = min(full_img.height, y1 + retry_expand_px)
        new_box = [x0n, y0n, x1n, y1n]

        if new_box == [x0, y0, x1, y1]:
            # Box cannot grow (e.g. retry_expand_px <= 0): recursing would never end.
            print("⚠️  alpha touches crop border but bbox cannot grow → compositing as-is.")
        else:
            print("⚠️  alpha touches crop border (inside canvas) → retrying with expanded bbox...")

            # Full-canvas mask, only used to crop-align it to the new box.
            full_mask = np.zeros((full_img.height, full_img.width), np.uint8)
            full_mask[y0:y1, x0:x1] = (mask_bin * 255).astype(np.uint8)
            region_mask_crop = full_mask[y0n:y1n, x0n:x1n]  # mask stays aligned to new crop

            # Express the edit in the expanded box: original pixels around it and the
            # already-resized edit at its true offset. Passing `edited_crop` unchanged
            # would stretch it over the larger box and misalign it with the mask.
            expanded_edit = full_img.crop((x0n, y0n, x1n, y1n))
            expanded_edit.paste(edit_rs, (x0 - x0n, y0 - y0n))

            # Recurse with a slightly larger edge_kill to be conservative.
            return paste_crop_back(
                full_img   = full_img,
                edited_crop= expanded_edit,
                crop_box   = new_box,
                crop_mask  = region_mask_crop,
                expand_px  = expand_px,
                feather_px = feather_px,
                bin_thresh = bin_thresh,
                edge_kill_px = edge_kill_px + 10,
                retry_expand_px = retry_expand_px
            )

    # --- 6) Composite
    mask_img = Image.fromarray(alpha)  # 2-D uint8 → mode "L"
    region   = full_img.crop((x0, y0, x1, y1))
    comp     = Image.composite(edit_rs, region, mask_img)
    full_img.paste(comp, (x0, y0))
    return full_img


from diffusers.utils import load_image
# --- Pair & mask builder with letterboxing for the garment panel ---

from diffusers.utils import load_image
from PIL import Image

def _fit_to_canvas(im: Image.Image, size: tuple[int,int], *, fill=(255,255,255), resample=Image.LANCZOS) -> Image.Image:
    """Letterbox `im` into an RGB canvas of `size` = (W, H) without distortion.

    The image is scaled by the largest factor that fits both sides (rounded, at least
    1 px each) and centred on a canvas filled with `fill`. A zero-sized input yields a
    blank canvas.
    """
    canvas_w, canvas_h = size
    canvas = Image.new("RGB", (canvas_w, canvas_h), fill)
    if 0 in (im.width, im.height):
        return canvas
    ratio = min(canvas_w / im.width, canvas_h / im.height)
    fit_w = max(1, int(round(im.width * ratio)))
    fit_h = max(1, int(round(im.height * ratio)))
    offset = ((canvas_w - fit_w) // 2, (canvas_h - fit_h) // 2)
    canvas.paste(im.resize((fit_w, fit_h), resample), offset)
    return canvas

def make_pair_and_mask(steve_image_path, mask_path, garment_path, size=None):
    """Build the CatVTON side-by-side input (garment | person) and its inpaint mask.

    Left panel: the garment, cropped to its content and letterboxed (no stretch) onto
    a white W×H canvas. Right panel: the person image resized to W×H. The mask is zero
    over the garment panel and the person mask (first channel) over the right panel.

    Args:
        steve_image_path: person/base image (anything `load_image` accepts).
        mask_path: inpaint mask for the person image.
        garment_path: garment image; transparency is preserved for content cropping.
        size: (W, H); defaults to the global (WIDTH, HEIGHT).

    Returns:
        (inpaint_image, extended_mask, H, W): a (3, H, 2W) tensor normalised to
        [-1, 1] and a (1, H, 2W) mask tensor in [0, 1].
    """
    # default to global WIDTH, HEIGHT
    W = WIDTH if size is None else size[0]
    H = HEIGHT if size is None else size[1]

    # Right panel (model/base) & its mask are expected to be pre-sized already in main loop,
    # but keep safe resizing here:
    steve = load_image(steve_image_path).convert("RGB").resize((W,H), Image.BICUBIC)
    msk   = load_image(mask_path).convert("RGB").resize((W,H), Image.NEAREST)

    def _keep_transparency(im: Image.Image) -> Image.Image:
        # load_image() converts to plain RGB unless given a convert_method, which drops
        # the alpha that crop_garment_keep_aspect() hands to _tight_bbox_nonwhite_or_opaque /
        # flatten_alpha_to_white (helpers named for alpha-based background handling).
        if "A" in im.getbands() or "transparency" in im.info:
            return im.convert("RGBA")
        return im.convert("RGB")

    # Left panel (garment): crop to content, then LETTERBOX to W×H (no stretch)
    gar_raw     = load_image(garment_path, convert_method=_keep_transparency)
    gar_cropped = crop_garment_keep_aspect(gar_raw)
    gar_panel   = _fit_to_canvas(gar_cropped.convert("RGB"), (W,H), fill=WHITE_RGB, resample=Image.LANCZOS)

    steve_t = _to_tensor(steve)
    gar_t   = _to_tensor(gar_panel)
    msk_t   = _to_tensor_mask(msk)[:1]

    inpaint_image = torch.cat([gar_t, steve_t], dim=2)  # (3,H,2W)
    zeros = torch.zeros_like(msk_t)
    extended_mask = torch.cat([zeros, msk_t], dim=2)    # (1,H,2W)
    return inpaint_image, extended_mask, H, W



In [None]:
# --- LoRA pipeline + schedule ---

import torch
from diffusers import FluxFillPipeline
from diffusers.models import FluxTransformer2DModel
from peft import PeftModel
from peft.tuners.lora import LoraLayer

# GPU in bfloat16 when CUDA is available, otherwise CPU in float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.bfloat16
else:
    DEVICE, DTYPE = "cpu", torch.float32

def list_lora_adapters(model):
    """Return the sorted, de-duplicated adapter names found on any LoraLayer in `model`."""
    return sorted({
        adapter_name
        for module in model.modules()
        if isinstance(module, LoraLayer)
        for adapter_name in module.lora_A.keys()
    })

def set_lora_global_scale(model, scale: float, adapter: str | None = None):
    for m in model.modules():
        if isinstance(m, LoraLayer):
            adapters = [adapter] if adapter else list(m.lora_A.keys())
            for a in adapters:
                m.scaling[a] = (m.lora_alpha[a] / m.r[a]) * float(scale)

def lora_weight_schedule_params(step_idx, total_steps, w_start, w_mid, w_end,
                                spike_at_step=None, tail_start_at_step=None):
    """LoRA weight for denoising step `step_idx` under a three-phase schedule.

    Both phase boundaries are clamped to [0, total_steps-1], and the tail starts at
    least one step after the spike:
      * before the spike:          constant `w_start`;
      * spike → tail start:        linear ramp `w_start` → `w_mid`;
      * tail start → last step:    linear ramp `w_mid` → `w_end`, held at `w_end` after.

    Args:
        step_idx: current step (0-based).
        total_steps: number of denoising steps.
        w_start, w_mid, w_end: weights at the phase anchors.
        spike_at_step, tail_start_at_step: phase boundaries; None (default) reads the
            SPIKE_AT_STEP / TAIL_START_AT_STEP globals at call time.

    Returns:
        The weight as a float.
    """
    if spike_at_step is None:
        spike_at_step = SPIKE_AT_STEP
    if tail_start_at_step is None:
        tail_start_at_step = TAIL_START_AT_STEP
    last_step = max(0, total_steps - 1)
    s_spike = max(0, min(int(spike_at_step), last_step))
    s_tail = max(s_spike + 1, min(int(tail_start_at_step), last_step))
    if step_idx < s_spike:
        return float(w_start)
    if step_idx < s_tail:
        f = (step_idx - s_spike) / max(1, s_tail - s_spike)
        return float(w_start + (w_mid - w_start) * f)
    f = (step_idx - s_tail) / max(1, (total_steps - 1) - s_tail)
    f = max(0.0, min(1.0, f))
    return float(w_mid + (w_end - w_mid) * f)

class LoRAStepHook:
    """Forward pre-hook that re-weights LoRA adapters once per denoising step.

    Attached to a model, it reads the `timestep` of every forward call. Each time the
    value differs from the previous call, it advances its step counter and sets the
    LoRA scale to `lora_weight_schedule_params(step, total_steps, w_start, w_mid,
    w_end)`. Repeated calls with the same timestep keep the current scale.
    """

    def __init__(self, model, total_steps, w_start, w_mid, w_end):
        self.model = model
        self.total_steps = int(total_steps)
        self.step_idx = -1        # becomes 0 on the first timestep observed
        self.last_ts_val = None   # timestep of the previous forward call
        self._handle = None       # hook handle while attached
        self.w_start, self.w_mid, self.w_end = float(w_start), float(w_mid), float(w_end)

    def _extract_timestep_from(self, args, kwargs=None):
        """Return the call's timestep as a float, or None if absent or unreadable.

        Uses the `timestep` keyword when present, else the second positional argument;
        for tensors the first element is taken.
        """
        kwargs = kwargs or {}
        if "timestep" in kwargs:
            t = kwargs["timestep"]
        elif len(args) >= 2:
            t = args[1]
        else:
            return None
        if torch.is_tensor(t):
            if t.numel() == 0:
                return None  # empty tensor: nothing to read
            return float(t.flatten()[0].detach().item())
        try:
            return float(t)
        except (TypeError, ValueError, OverflowError):
            return None

    def _maybe_update_scale(self, ts_val):
        """Advance the step counter and re-scale LoRA when a new timestep is seen."""
        if ts_val is None:
            return
        if (self.last_ts_val is None) or (ts_val != self.last_ts_val):
            self.last_ts_val = ts_val
            self.step_idx += 1
            # Clamp so extra forward calls past the planned steps keep the final weight.
            idx = min(self.step_idx, self.total_steps - 1)
            w = lora_weight_schedule_params(idx, self.total_steps, self.w_start, self.w_mid, self.w_end)
            set_lora_global_scale(self.model, w)

    def _pre_hook_kwargs(self, module, args, kwargs):
        self._maybe_update_scale(self._extract_timestep_from(args, kwargs))

    def _pre_hook(self, module, args):
        self._maybe_update_scale(self._extract_timestep_from(args, None))

    def attach(self):
        """Register the pre-hook, preferring the kwargs-aware form."""
        try:
            self._handle = self.model.register_forward_pre_hook(self._pre_hook_kwargs, with_kwargs=True)
        except TypeError:
            # Presumably an older torch without `with_kwargs`: only positional args are seen.
            self._handle = self.model.register_forward_pre_hook(self._pre_hook)

    def detach(self):
        """Remove the hook if attached; safe to call repeatedly."""
        if self._handle is not None:
            self._handle.remove()
            self._handle = None

def _build_lora_pipe():
    """Build the CatVTON-Flux fill pipeline, optionally wrapping the transformer with a LoRA.

    The CatVTON transformer (CATVTON_XFM) is loaded first and handed to
    FluxFillPipeline.from_pretrained(FILL_MODEL_ID, transformer=...). This way the base
    Fill checkpoint's own transformer weights are not downloaded and loaded only to be
    replaced. If LORA_PATH is an existing directory, the transformer is wrapped with
    PeftModel.from_pretrained; otherwise the base CatVTON transformer is used.

    Returns:
        The pipeline on DEVICE with DTYPE weights and the transformer in eval mode.
    """
    catvton_transformer = FluxTransformer2DModel.from_pretrained(CATVTON_XFM, torch_dtype=DTYPE, use_safetensors=True)
    pipe = FluxFillPipeline.from_pretrained(
        FILL_MODEL_ID,
        transformer=catvton_transformer,
        torch_dtype=DTYPE,
        use_safetensors=True,
    ).to(DEVICE)
    if LORA_PATH and os.path.isdir(LORA_PATH):
        pipe.transformer = PeftModel.from_pretrained(pipe.transformer, LORA_PATH)
        print(f"Loaded LoRA from: {LORA_PATH}")
    else:
        print("⚠️ LORA_PATH not found — running base CatVTON transformer.")
    pipe.transformer.to(DEVICE, dtype=DTYPE).eval()
    try:
        pipe.vae.enable_slicing(); pipe.vae.enable_tiling()
    except Exception:
        pass  # best-effort memory saving; failure here is not fatal
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return pipe

def run_with_lora_schedule(pipe, steps, schedule_triplet=None, **pipe_kwargs):
    """Run `pipe` while a LoRAStepHook re-weights the LoRA adapters at every step.

    Args:
        pipe: diffusers pipeline whose `transformer` carries the LoRA layers.
        steps: number of inference steps (also the schedule length).
        schedule_triplet: (w_start, w_mid, w_end) for lora_weight_schedule_params.
            Optional; None means a constant weight of 1.0 (the adapters' configured
            strength), so callers that do not pass a schedule still work.
        **pipe_kwargs: forwarded to `pipe(...)`.

    Returns:
        `pipe(...).images`. The hook is detached and the LoRA scale reset to 1.0 even
        when the pipeline raises.
    """
    if schedule_triplet is None:
        # NOTE(review): neutral default — pass a real triplet to vary the LoRA weight.
        schedule_triplet = (1.0, 1.0, 1.0)
    try:
        pipe.scheduler.set_timesteps(steps, device=pipe._execution_device)
    except Exception:
        pass  # best-effort; the pipeline call presumably configures its own timesteps
    w_start, w_mid, w_end = schedule_triplet
    hook = LoRAStepHook(pipe.transformer, steps, w_start, w_mid, w_end)
    hook.attach()
    try:
        imgs = pipe(num_inference_steps=steps, **pipe_kwargs).images
    finally:
        hook.detach()
        set_lora_global_scale(pipe.transformer, 1.0)
    return imgs

def run_inference_lora(image_path, mask_path, garment_path, size=None, steps=None, guidance_scale=None, seed=777, prompt=None, pipe=None, schedule_triplet=None):
    """Run CatVTON try-on for one (person, mask, garment) triple.

    `size`, `steps`, `guidance_scale` and `prompt` default (None) to the current
    (WIDTH, HEIGHT), STEPS, GUIDANCE and CATVTON_PROMPT globals, read at call time, so
    edits to the config cell take effect without re-running this cell.

    Args:
        image_path, mask_path, garment_path: inputs for make_pair_and_mask.
        size: (W, H) of each panel.
        steps, guidance_scale, seed, prompt: diffusion settings.
        pipe: prebuilt pipeline; None builds a fresh one (expensive — pass one in batches).
        schedule_triplet: (w_start, w_mid, w_end) LoRA weights; None → constant 1.0.

    Returns:
        (garment_result, tryon_result): the left and right W×H halves of the output.
    """
    if size is None:
        size = (WIDTH, HEIGHT)
    if steps is None:
        steps = STEPS
    if guidance_scale is None:
        guidance_scale = GUIDANCE
    if prompt is None:
        prompt = CATVTON_PROMPT
    if schedule_triplet is None:
        schedule_triplet = (1.0, 1.0, 1.0)  # constant LoRA weight (no schedule)
    if pipe is None:
        pipe = _build_lora_pipe()
    inpaint_image, extended_mask, H, W = make_pair_and_mask(image_path, mask_path, garment_path, size=size)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    with torch.autocast(device_type=DEVICE, dtype=DTYPE, enabled=(DEVICE == "cuda")):
        imgs = run_with_lora_schedule(
            pipe, steps=steps, schedule_triplet=schedule_triplet,
            height=H, width=W * 2,
            image=inpaint_image, mask_image=extended_mask,
            generator=generator, max_sequence_length=512,
            guidance_scale=guidance_scale, prompt=prompt,
        )
    out = imgs[0]
    garment_result = out.crop((0, 0, W, H))
    tryon_result   = out.crop((W, 0, W * 2, H))
    return garment_result, tryon_result

def generate_tryon(base_crop: Image.Image, mask_crop: Image.Image, garment_path: str, seed: int, pipe=None):
    """Run the try-on model on a person crop and return the result at the crop's size.

    The crop and mask are resized to (WIDTH, HEIGHT) (LANCZOS / NEAREST), written as
    PNGs into a temporary directory for run_inference_lora(), and the try-on half of
    the output is resized back to `base_crop.size`.
    """
    import tempfile  # local import keeps this helper self-contained

    model_size = (WIDTH, HEIGHT)
    base_for_model = base_crop.resize(model_size, Image.LANCZOS)
    mask_for_model = mask_crop.resize(model_size, Image.NEAREST)
    # The directory (and both PNGs) is removed even if inference raises; cleanup
    # errors are ignored (best-effort).
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmp_dir:
        tmp_img = os.path.join(tmp_dir, "base.png")
        tmp_msk = os.path.join(tmp_dir, "mask.png")
        base_for_model.save(tmp_img)
        mask_for_model.save(tmp_msk)
        _, tryon = run_inference_lora(tmp_img, tmp_msk, garment_path, size=model_size, steps=STEPS, guidance_scale=GUIDANCE, seed=seed, prompt=CATVTON_PROMPT, pipe=pipe)
    return tryon.resize(base_crop.size, Image.LANCZOS)


In [None]:
# Colab cell — output routing + metadata helpers
import os, json
from PIL import PngImagePlugin

def ensure_dir(p):
    """Create directory `p` (with parents) if it does not exist, then return `p`."""
    if not os.path.isdir(p):
        os.makedirs(p, exist_ok=True)
    return p

# Create the output folder when this cell runs (OUTPUT_DIR is defined elsewhere in the notebook).
ensure_dir(OUTPUT_DIR)

def build_output_filename(sku_name: str, angle_code: str, ext=".png") -> str:
    """Return "<sku>-<normalised angle><ext>", e.g. SS-12345-fr_rght.png."""
    return "{}-{}{}".format(sku_name, _norm_angle(angle_code), ext)


import json, piexif
from PIL import Image

def save_png_with_metadata(img, out_path, details_payload=None, quality=95):
    """Save `img` to `out_path` as PNG, optionally embedding JSON metadata in EXIF.

    When `details_payload` is truthy it is serialised with json.dumps and stored in
    the EXIF UserComment tag, written into the PNG via Pillow's `exif=` option.

    Args:
        img: image with a PIL-style `save(path, format=..., exif=...)` method.
        out_path: destination path.
        details_payload: JSON-serialisable metadata; falsy values skip EXIF entirely.
        quality: unused for PNG; kept so existing call sites keep working.
    """
    if not details_payload:
        img.save(out_path, format="PNG")
        return
    # UserComment = 8-byte character-code header + text. The header used here declares
    # ASCII, so non-ASCII characters are escaped as \uXXXX (ensure_ascii=True) instead
    # of writing raw UTF-8 under an ASCII label; json.loads() restores the original text.
    payload = json.dumps(details_payload, ensure_ascii=True).encode("ascii")
    user_comment = b"ASCII\x00\x00\x00" + payload
    exif_dict = {"0th": {}, "Exif": {piexif.ExifIFD.UserComment: user_comment}, "1st": {}, "GPS": {}, "Interop": {}}
    img.save(out_path, format="PNG", exif=piexif.dump(exif_dict))

import json, piexif
from PIL import Image, ImageOps


In [None]:
# --- HYPIR (optional overlay) ---

import sys
# Make the cloned HYPIR repo importable; guarded so re-running the cell adds no duplicates.
if "/content/HYPIR" not in sys.path:
    sys.path.insert(0, "/content/HYPIR")
try:
    from HYPIR.enhancer.sd2 import SD2Enhancer
except Exception as exc:
    # HYPIR is optional: disable enhancement, but report why instead of failing silently.
    print(f"⚠️ HYPIR import failed ({exc!r}) — HYPIR enhancement disabled.")
    SD2Enhancer = None

# Cache for the lazily-built HYPIR enhancer.
_HYPIR = {"model": None}
def _init_hypir_if_needed():
    """Return the cached HYPIR SD2Enhancer, building it on first use.

    Returns None when ENABLE_HYPIR_ENHANCE is falsy or SD2Enhancer could not be
    imported. The enhancer runs on CUDA when available, otherwise CPU.
    """
    if not ENABLE_HYPIR_ENHANCE or SD2Enhancer is None:
        return None
    cached = _HYPIR["model"]
    if cached is not None:
        return cached

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    lora_targets = [
        "to_k", "to_q", "to_v", "to_out.0",
        "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
        "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj",
    ]
    enhancer = SD2Enhancer(
        base_model_path=HYPIR_BASE_MODEL,
        weight_path=HYPIR_WEIGHT_PATH,
        lora_modules=lora_targets,
        lora_rank=256, model_t=200, coeff_t=200, device=target_device,
    )
    enhancer.init_models()
    _HYPIR["model"] = enhancer
    return enhancer

import torchvision.transforms as T
from accelerate.utils import set_seed
# PIL → float tensor in [0, 1] (no normalisation) for the HYPIR enhancer input.
_to_tensor_vis = T.ToTensor()

def hypir_enhance_pil(img_pil, prompt=None, upscale=None, seed=-1):
    """Enhance `img_pil` with HYPIR; return it unchanged when HYPIR is unavailable.

    Args:
        img_pil: input image (converted to RGB).
        prompt: text prompt; None uses HYPIR_PROMPT.
        upscale: upscale factor; None uses HYPIR_UPSCALE.
        seed: RNG seed; -1 draws a random one. Note: accelerate's set_seed reseeds
            the global RNGs as a side effect.

    Returns:
        The enhanced image in RGB.
    """
    enhancer = _init_hypir_if_needed()
    if enhancer is None:
        return img_pil
    if seed == -1:
        seed = random.randint(0, 2**32-1)
    set_seed(seed)
    if prompt is None:
        prompt = HYPIR_PROMPT
    if upscale is None:
        upscale = HYPIR_UPSCALE
    batch = _to_tensor_vis(img_pil.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        outputs = enhancer.enhance(lq=batch, prompt=prompt, upscale=upscale, return_type="pil")
    return outputs[0].convert("RGB")


In [None]:
# --- Google APIs: gspread + Drive upload + Operations sync ---

import google.auth
# OAuth scopes: full Drive access (uploads, sharing permissions) and Sheets read/write.
SCOPES = ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/spreadsheets"]
# Application-default credentials (in Colab, presumably those from auth.authenticate_user()).
creds, _ = google.auth.default(scopes=SCOPES)

import gspread
gs = gspread.authorize(creds)  # Sheets client used by the sheet-update helpers

from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
drive_svc = build("drive", "v3", credentials=creds)  # Drive v3 REST client

FOLDER_MIME   = "application/vnd.google-apps.folder"    # Drive MIME type of folders
SHORTCUT_MIME = "application/vnd.google-apps.shortcut"  # Drive MIME type of shortcuts
PATH_PREFIX   = "/content/drive/MyDrive/"  # mounted "My Drive" prefix; paths below it map to Drive folders

def _escape_name(name: str) -> str:
    """Escape `name` for a single-quoted string literal in a Drive API v3 query.

    Backslashes are escaped first (so the escapes added for quotes are not doubled),
    then each single quote is prefixed with a backslash.
    """
    return name.replace("\\", "\\\\").replace("'", "\\'")

def _maybe_follow_shortcut(file_obj):
    """Return (id, mimeType) for a Drive file dict, resolving shortcuts to their target.

    For a shortcut, the target's id and MIME type come from `shortcutDetails`
    (None when missing); any other file returns its own id and MIME type.
    """
    if file_obj.get("mimeType") != SHORTCUT_MIME:
        return file_obj.get("id"), file_obj.get("mimeType")
    details = file_obj.get("shortcutDetails", {}) or {}
    return details.get("targetId"), details.get("targetMimeType")

def _list_children(parent_id: str, q_extra: str, page_size: int = 1000):
    """List the non-trashed direct children of Drive folder `parent_id`.

    Follows `nextPageToken` until exhausted. The Drive API may return fewer than
    `pageSize` items per page before the end of the results, so a single request can
    silently miss entries.

    Args:
        parent_id: folder id (or "root").
        q_extra: extra Drive query clause ANDed onto the parent filter; falsy for none.
        page_size: per-request page size.

    Returns:
        List of file dicts (id, name, mimeType, shortcutDetails) from all pages.
    """
    q = f"'{parent_id}' in parents and trashed = false"
    if q_extra:
        q += f" and ({q_extra})"
    files = []
    page_token = None
    while True:
        params = dict(
            q=q, spaces="drive", pageSize=page_size,
            # nextPageToken is only returned when requested in `fields`.
            fields="nextPageToken, files(id,name,mimeType,shortcutDetails)",
            includeItemsFromAllDrives=True, supportsAllDrives=True,
        )
        if page_token:
            params["pageToken"] = page_token
        resp = drive_svc.files().list(**params).execute()
        files.extend(resp.get("files", []))
        page_token = resp.get("nextPageToken")
        if not page_token:
            return files

def _find_folder_id(parent_id: str, name: str):
    """Return the id of child folder `name` under `parent_id`, or None.

    Lookup order:
      1. an exact-name real folder;
      2. an exact-name shortcut whose target is a folder (the target id is returned);
      3. a case-insensitive, whitespace-trimmed name match among all child folders
         and shortcuts.
    """
    kind_clause = f"(mimeType = '{FOLDER_MIME}' or mimeType = '{SHORTCUT_MIME}')"
    exact = _list_children(parent_id, q_extra=f"name = '{_escape_name(name)}' and {kind_clause}")

    for entry in exact:
        if entry["mimeType"] == FOLDER_MIME:
            return entry["id"]

    for entry in exact:
        if entry["mimeType"] != SHORTCUT_MIME:
            continue
        target_id, target_mime = _maybe_follow_shortcut(entry)
        if target_mime == FOLDER_MIME:
            return target_id

    candidates = _list_children(parent_id, q_extra=kind_clause)
    wanted = name.strip().casefold()
    for entry in candidates:
        if entry.get("name", "").strip().casefold() != wanted:
            continue
        target_id, target_mime = _maybe_follow_shortcut(entry)
        if target_mime == FOLDER_MIME:
            return target_id
    return None

def _resolve_parent_id_and_filename_from_colab_path(colab_path: str):
    """Map a mounted-Drive path to (parent folder id, file name).

    Walks each folder segment below PATH_PREFIX from "root" via _find_folder_id.

    Raises:
        ValueError: the path is not under PATH_PREFIX or has no file name.
        FileNotFoundError: a folder segment does not exist.
    """
    if not colab_path.startswith(PATH_PREFIX):
        raise ValueError(f"This helper supports only '{PATH_PREFIX}...'. Got: {colab_path}")
    # Drop empty segments: `str.split` never returns [], so trailing or duplicate
    # slashes would otherwise become '' folder/file names.
    parts = [p for p in colab_path[len(PATH_PREFIX):].split("/") if p]
    if not parts:
        raise ValueError("Path must include a file name.")
    parent_id = "root"
    for part in parts[:-1]:
        next_id = _find_folder_id(parent_id, part)
        if not next_id:
            raise FileNotFoundError(f"Folder not found in path: '{part}'")
        parent_id = next_id
    desired_name = parts[-1]
    return parent_id, desired_name

def upload_to_drive_folder(local_path: str, parent_folder_id: str, desired_name: str | None = None):
    """Upload `local_path` into Drive folder `parent_folder_id` and share it by link.

    Args:
        local_path: file to upload (resumable upload).
        parent_folder_id: destination folder id.
        desired_name: Drive file name; defaults to the local basename.

    Returns:
        The created Drive file resource (id, webViewLink, name, parents).
    """
    media = MediaFileUpload(local_path, resumable=True)
    metadata = {"name": desired_name or os.path.basename(local_path), "parents": [parent_folder_id]}
    created = (
        drive_svc.files()
        .create(body=metadata, media_body=media,
                fields="id, webViewLink, name, parents", supportsAllDrives=True)
        .execute()
    )
    # NOTE(review): grants read access to anyone holding the link.
    drive_svc.permissions().create(
        fileId=created["id"],
        body={"type": "anyone", "role": "reader"},
        fields="id",
        supportsAllDrives=True,
    ).execute()
    return created

def update_operations_status(spreadsheet_id: str, sku_name: str, angle_code: str,
                             ops_sheet_name: str = OPS_SHEET_NAME,
                             status_value: str = "Girls need to check",
                             update_all: bool = False):
    sh = gs.open_by_key(spreadsheet_id)
    ws_ops = sh.worksheet(ops_sheet_name)
    col_c = ws_ops.col_values(3)
    target_sku = _norm_sku(sku_name)
    target_rows = [idx for idx, val in enumerate(col_c, start=1) if idx>1 and _norm_sku(val)==target_sku]
    updated_rows = []
    for r in target_rows:
        angle_here = (ws_ops.cell(r,5).value or "").strip()
        if angle_here == (angle_code or "").strip():
            ws_ops.update_cell(r,9,status_value)
            updated_rows.append(r)
            if not update_all: break
    return {"updated": len(updated_rows), "rows": updated_rows}

def upload_file_and_append_to_sheet(local_path: str, target_colab_path: str,
                                    sku_name: str, angle: str,
                                    spreadsheet_id: str, worksheet_name: str):
    """Upload a result to Drive, log it in a worksheet, and flag the Operations row.

    The file goes to the Drive folder/name given by `target_colab_path`. One row is
    appended to `worksheet_name`: [SKU hyperlinked to the folder, angle, timestamp in
    TIMEZONE, file URL, random uuid, status, OPERATOR]. The matching Operations row is
    then updated best-effort; failures there are only printed.

    Returns:
        {"file_url": link to the uploaded file}.
    """
    parent_id, desired_name = _resolve_parent_id_and_filename_from_colab_path(target_colab_path)
    uploaded = upload_to_drive_folder(local_path, parent_id, desired_name)
    file_id = uploaded["id"]
    file_url = uploaded.get("webViewLink") or f"https://drive.google.com/file/d/{file_id}/view?usp=sharing"
    folder_url = f"https://drive.google.com/drive/folders/{parent_id}"

    # NOTE(review): HYPERLINK uses ';' as the argument separator — presumably matching
    # the spreadsheet's locale.
    row = [
        f'=HYPERLINK("{folder_url}"; "{sku_name}")',
        angle,
        datetime.now(pytz.timezone(TIMEZONE)).strftime("%m-%d %H:%M:%S"),
        file_url,
        str(uuid.uuid4()),
        "Girls need to check",
        OPERATOR,
    ]
    worksheet = gs.open_by_key(spreadsheet_id).worksheet(worksheet_name)
    worksheet.append_row(row, value_input_option="USER_ENTERED")

    try:
        update_operations_status(spreadsheet_id, sku_name, angle, OPS_SHEET_NAME, "Girls need to check", update_all=False)
    except Exception as e:
        print(f"⚠️ Ops update failed: {e}")

    return {"file_url": file_url}


# BUILD PIPELINE

In [None]:
# Build the pipeline once and pass it as `pipe=` to the inference helpers; with pipe=None,
# run_inference_lora builds a new pipeline on every call.
local_pipe = _build_lora_pipe()

model_index.json:   0%|          | 0.00/536 [00:00<?, ?B/s]

Fetching 23 files:   0%|          | 0/23 [00:00<?, ?it/s]

scheduler_config.json:   0%|          | 0.00/273 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/782 [00:00<?, ?B/s]

text_encoder_2/model-00002-of-00002.safe(…):   0%|          | 0.00/4.53G [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/19.9k [00:00<?, ?B/s]

text_encoder_2/model-00001-of-00002.safe(…):   0%|          | 0.00/4.99G [00:00<?, ?B/s]

text_encoder/model.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/588 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/705 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

tokenizer_2/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.8k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/378 [00:00<?, ?B/s]

transformer/diffusion_pytorch_model-0000(…):   0%|          | 0.00/9.98G [00:00<?, ?B/s]

transformer/diffusion_pytorch_model-0000(…):   0%|          | 0.00/9.95G [00:00<?, ?B/s]

transformer/diffusion_pytorch_model-0000(…):   0%|          | 0.00/3.87G [00:00<?, ?B/s]

(…)ion_pytorch_model.safetensors.index.json:   0%|          | 0.00/121k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/820 [00:00<?, ?B/s]

vae/diffusion_pytorch_model.safetensors:   0%|          | 0.00/168M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/442 [00:00<?, ?B/s]

(…)ion_pytorch_model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

diffusion_pytorch_model-00002-of-00003.s(…):   0%|          | 0.00/9.99G [00:00<?, ?B/s]

diffusion_pytorch_model-00001-of-00003.s(…):   0%|          | 0.00/9.99G [00:00<?, ?B/s]

diffusion_pytorch_model-00003-of-00003.s(…):   0%|          | 0.00/3.83G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded LoRA from: /content/drive/MyDrive/Dazzl/SikSilk/their_dataset_LORA/4x5_1280_their_ds_LORA_13_w_jitter/4x5_1280_their_ds_LORA_13_w_jitter_best


# BATCH HELPERS

In [None]:
# --- Batch processor  ---

# Colab cell — replace the whole function to insert GPT eval + retry + routing


def process_one_garment_folder(folder_path: str, pipe=None, allowed_angles=None):
    """Generate try-on images for the eligible garment files of one SKU folder.

    For each garment file whose (lower-cased) name starts with one of the
    requested source prefixes, the STRICT base photo and mask for the mapped
    target angle are loaded (no aliasing at this stage). The masked region is
    cropped at TARGET_ASPECT and the try-on pipeline is run once per entry of
    LORA_SCHEDULE_CANDIDATES. Attempts stop when the GPT score reaches
    GPT_PASS_THRESHOLD, or after the first attempt when GPT_EVAL_ENABLED is
    false. The best result is pasted back into the full base photo, saved as
    PNG (with detail metadata when GPT reports details), uploaded to
    OUTPUT_DIR and logged to GEN_LOG_SHEET.

    Args:
        folder_path: SKU folder containing the garment images.
        pipe: Pipeline to run. Falls back to the global ``local_pipe`` when None.
        allowed_angles: Angle codes to produce. NOTE(review): when None/empty,
            no file gets a matched source, so nothing is queued. Callers must
            pass the angles they want generated.

    Per-garment errors are printed and the loop continues with the next file.
    """
    active_pipe = pipe if pipe is not None else local_pipe

    # STRICT targets (what to PRODUCE), e.g. {'fr_cl'}
    allowed_outputs = {_norm_angle(a) for a in (allowed_angles or [])}
    # Flexible sources (what garment filenames may START WITH), e.g. ['fr_cl','fr','fr_']
    allowed_sources = expand_as_list(allowed_angles) if allowed_angles else None

    sku_name = os.path.basename(folder_path)
    base_subcat_dir = resolve_base_mask_dir(folder_path)
    if not base_subcat_dir:
        # Without a base/mask directory no garment can be processed, so stop here.
        print(f"⚠️ Cannot resolve base/mask dir for SKU: {folder_path}")
        return

    worklist = []
    for file in sorted(os.listdir(folder_path)):
        low = file.lower()

        # 1) Garment filename gates; prefix aliasing is allowed ONLY here.
        matched_source = (
            next((src for src in allowed_sources if low.startswith(src)), None)
            if allowed_sources else None
        )
        if not matched_source:
            continue
        if SKIP_FILENAME_TOKENS and any(tok in low for tok in SKIP_FILENAME_TOKENS):
            continue
        if REQUIRE_CUT_IN_FILENAME and ("cut" not in low):
            continue
        if not _valid_ext(file):
            continue

        # 2) Map the matched SOURCE to a STRICT TARGET angle.
        target_angle = pick_target_angle(matched_source, allowed_outputs) if allowed_outputs else _norm_angle(matched_source)
        if allowed_outputs and not target_angle:
            continue  # garment belongs to no requested target

        # 3) STRICT base/mask resolution: **NO aliasing, do not try stem_nocut**
        base_img_path = _find_image_with_stem_and_suffix(base_subcat_dir, target_angle)
        if not base_img_path:
            print(f"⚠️ Missing BASE for target '{target_angle}' → skipping {file}")
            continue

        mask_path = find_mask_path(base_subcat_dir, target_angle)
        if not mask_path:
            print(f"⚠️ Missing MASK for target '{target_angle}' → skipping {file}")
            continue

        # 4) Queue job; bind the OUTPUT ANGLE = target_angle (never the source)
        worklist.append((file, target_angle, base_img_path, mask_path))

    if allowed_sources:
        print(f"▶️  {sku_name}: {len(worklist)} image(s) to generate (outputs={sorted(list(allowed_outputs))}, sources={sorted(list(set(allowed_sources)))})")
    else:
        print(f"▶️  {sku_name}: {len(worklist)} image(s) to generate")

    if not worklist:
        return

    # === Run ===
    for idx, (file, target_angle, base_img_path, mask_path) in enumerate(worklist, start=1):
        print(f"   {idx:>3}/{len(worklist):<3}  {file}  | USING STRICT base/mask='{target_angle}'")
        garment_path = os.path.join(folder_path, file)

        angle_code = _norm_angle(target_angle)
        out_stem = f"{sku_name}_{angle_code}"
        dest_check = os.path.join(OUTPUT_DIR, out_stem + ".png")

        # Skip if the output already exists (the log message indicates png/jpg/jpeg are all checked).
        if drive_file_exists_any_ext_at_colab_path(dest_check):
            print(f"      ⏭️  Skip: {out_stem}.(png/jpg/jpeg) already exists in {OUTPUT_DIR}")
            continue

        try:
            garment_img = flatten_alpha_to_white(open_upright(garment_path))
            base_full = Image.open(base_img_path).convert("RGB")
            mask_full = Image.open(mask_path).convert("L")

            show_gallery(
                [garment_img, base_full, mask_full.convert("RGB")],
                ["Source garment (white BG)", f"Base photo [{angle_code}]", f"Mask [{os.path.basename(mask_path)}]"]
            )

            # Crop & prepare
            bbox = find_aspect_bbox(mask_full, aspect=TARGET_ASPECT, padding=CROP_PADDING, upper_padding=UPPER_PADDING, horiz_padding=CROP_PADDING, min_margin=10)
            base_crop = base_full.crop(tuple(bbox))
            mask_crop = mask_full.crop(tuple(bbox))
            show_gallery([base_crop, mask_crop.convert("RGB"), garment_img], ["Cropped base", "Cropped mask", "Garment (white BG)"])

            # Model-space inputs do not change between attempts, so build them once.
            model_size = (WIDTH, HEIGHT)
            W, H = model_size
            base_for_model = base_crop.resize(model_size, Image.LANCZOS)
            mask_for_model = mask_crop.resize(model_size, Image.NEAREST)
            mask_crop_arr = np.array(mask_crop)

            # === Try generation with schedule candidates until score >= threshold ===
            best_final = None
            best_score = -1
            last_resp_id = None

            for attempt, sched in enumerate(LORA_SCHEDULE_CANDIDATES, start=1):
                seed = random.randint(1, 2**32 - 1)

                # 1) Generate the cropped try-on with this schedule.
                tmp_img = tmp_msk = None
                try:
                    # make_pair_and_mask takes file paths, so stage the model-sized inputs on disk.
                    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_img:
                        tmp_img = f_img.name
                    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f_msk:
                        tmp_msk = f_msk.name
                    base_for_model.save(tmp_img)
                    mask_for_model.save(tmp_msk)
                    pair = make_pair_and_mask(tmp_img, tmp_msk, garment_path, size=model_size)

                    generator = torch.Generator(device=DEVICE).manual_seed(seed)
                    with torch.autocast(device_type=DEVICE, dtype=DTYPE, enabled=(DEVICE == "cuda")):
                        imgs = run_with_lora_schedule(
                            active_pipe, steps=STEPS, schedule_triplet=sched,
                            height=H, width=W * 2,
                            image=pair[0],
                            mask_image=pair[1],
                            generator=generator, max_sequence_length=512,
                            guidance_scale=GUIDANCE, prompt=CATVTON_PROMPT,
                        )
                    # The model output is double-width; the right half is the try-on.
                    tryon_sq = imgs[0].crop((W, 0, W * 2, H))
                finally:
                    # Remove each temp file independently so one failure does not leak the other.
                    for tmp in (tmp_img, tmp_msk):
                        if tmp:
                            try:
                                os.remove(tmp)
                            except OSError:
                                pass

                # Optional HYPIR blend
                if ENABLE_HYPIR_ENHANCE and float(HYPIR_OVERLAY_OPACITY) > 0.0:
                    hyp = hypir_enhance_pil(tryon_sq, prompt=HYPIR_PROMPT, upscale=HYPIR_UPSCALE).convert("RGB")
                    # Image.blend requires equal size and mode; HYPIR presumably upscales, so match the try-on.
                    if hyp.size != tryon_sq.size:
                        hyp = hyp.resize(tryon_sq.size, Image.LANCZOS)
                    tryon_sq = Image.blend(tryon_sq.convert("RGB"), hyp, float(HYPIR_OVERLAY_OPACITY))

                # Paste back to full frame
                final_img = paste_crop_back(
                    full_img=base_full.copy(),
                    edited_crop=tryon_sq,
                    crop_box=bbox,
                    crop_mask=mask_crop_arr,
                    expand_px=MASK_EXPAND_PX,
                    feather_px=MASK_FEATHER_PX,
                )
                show_gallery([final_img], [f"Attempt {attempt} with schedule={sched}"])

                if not GPT_EVAL_ENABLED:
                    best_final = final_img
                    best_score = 10
                    break  # no eval, accept first

                # 2) GPT scoring on [garment | final]
                try:
                    score, resp_id = gpt_score_tryon(garment_img, final_img)
                    print(f"      🔎 GPT score = {score} (threshold {GPT_PASS_THRESHOLD})")
                except Exception as e:
                    print(f"      ⚠️ GPT score failed: {e}")
                    score = 10  # fail-open to avoid endless loops
                    resp_id = None  # don't chain details onto another attempt's conversation
                last_resp_id = resp_id

                if score > best_score:
                    best_score = score
                    best_final = final_img.copy()

                if score >= GPT_PASS_THRESHOLD:
                    break

            if best_final is None:
                # Only reachable when no schedule candidates ran; never reuse a previous garment's image.
                print("      ⚠️ No LoRA schedule candidates ran; nothing to save.")
                continue

            # 3) If passing, optionally ask for details on *same* conversation
            details = {"details": []}
            if GPT_EVAL_ENABLED and best_score >= GPT_PASS_THRESHOLD:
                try:
                    details = gpt_detect_details(garment_img, best_final, previous_response_id=last_resp_id)
                except Exception as e:
                    print(f"      ⚠️ GPT details failed: {e}")
                    details = {"details": []}

            # 4) Route & save with required filename pattern: {SKU}_{angle}.png
            out_name = build_output_filename(sku_name, angle_code, ext=".png")
            tmp_path = os.path.join("/tmp", out_name)
            has_details = bool(details.get("details"))

            # Save to the consolidated output folder
            save_png_with_metadata(best_final, tmp_path, details_payload=details if has_details else None)
            target_path_for_drive = os.path.join(OUTPUT_DIR, out_name)

            # Drive upload + sheet append
            info = upload_file_and_append_to_sheet(
                local_path=tmp_path,
                target_colab_path=target_path_for_drive,
                sku_name=sku_name,
                angle=angle_code,
                spreadsheet_id=SPREADSHEET_ID,
                worksheet_name=GEN_LOG_SHEET,
            )
            print(f"      ✅ Uploaded → {info['file_url']}  (details: {has_details and details})")

        except Exception as e:
            print(f"      ❌ Error: {e}")


In [None]:
# --- Sheet-driven angle selection + runners ---

# Sheet columns (0-based)
COL_C_SKU, COL_E_ANGLE, COL_H_NOTE, COL_I_STATE, COL_J_FLAG = 2,4,7,8,9

def angle_row_satisfies_conditions(row, sku_rows, *, regen_token, enforce_j_bans, banned_list, enforce_j_req, required_list, required_mode):
    """Return True when an ops-sheet row is eligible for (re)generation.

    A row qualifies when all of the following hold:
      * column H contains ``regen_token`` (case-insensitive substring match);
      * column I is blank after stripping whitespace;
      * if ``enforce_j_bans``: no row of the SKU (``sku_rows``) has any
        ``banned_list`` term in column J (case-insensitive);
      * if ``enforce_j_req``: the SKU's column-J text, concatenated across
        ``sku_rows``, contains ALL (``required_mode`` == "ALL", any case) or
        ANY (otherwise) of ``required_list`` (case-insensitive);
      * column E (angle) is non-blank.

    Cells beyond the end of a short row, and None cells, are treated as "".
    Blank entries in ``banned_list`` / ``required_list`` are ignored. Such
    entries would otherwise match every row: ``"" in text`` is always True.
    """
    def _cell(r, idx):
        # Rows can be shorter than the sheet width; treat missing/None cells as empty.
        return (r[idx] if len(r) > idx else "") or ""

    def _terms(items):
        # Casefold the terms too: column J is casefolded, so mixed-case terms would never match.
        return [str(t).casefold() for t in (items or []) if t is not None and str(t).strip()]

    def _j_text(r):
        return _cell(r, COL_J_FLAG).casefold()

    # Column H must mention the regeneration token.
    if (regen_token or "").casefold() not in _cell(row, COL_H_NOTE).casefold():
        return False
    # Column I must still be empty (row not yet handled).
    if _cell(row, COL_I_STATE).strip() != "":
        return False

    if enforce_j_bans:
        banned = _terms(banned_list)
        if banned and any(b in _j_text(r) for r in sku_rows for b in banned):
            return False

    if enforce_j_req:
        reqs = _terms(required_list)
        if reqs:
            j_concat = " ".join(_j_text(r) for r in sku_rows)
            mode = str(required_mode or "ANY").upper()
            matcher = all if mode == "ALL" else any
            if not matcher(r in j_concat for r in reqs):
                return False

    return _cell(row, COL_E_ANGLE).strip() != ""

def fetch_sku_angle_pairs_from_ops(spreadsheet_id: str, ops_sheet_name: str,
                                   *, regen_token, enforce_j_bans, banned_list, enforce_j_requires, required_list, required_mode):
    """Read the ops worksheet and return the eligible ``{"sku", "angle"}`` pairs.

    Data rows (everything after the header row) are grouped by normalised SKU
    from column C, so the column-J ban/require checks see every row of a SKU.
    A row with a non-blank angle (column E) that passes
    ``angle_row_satisfies_conditions`` yields a pair. The pair's SKU is spelled
    as in the first row of its group. Pairs are de-duplicated on
    (normalised SKU, normalised angle), keeping first-seen order.
    """
    worksheet = gs.open_by_key(spreadsheet_id).worksheet(ops_sheet_name)
    sheet_values = worksheet.get_all_values() or []
    data_rows = sheet_values[1:]  # drop the header row

    # Group rows by normalised SKU, remembering the first raw spelling for display.
    groups = {}
    for sheet_row in data_rows:
        raw_sku = (sheet_row[COL_C_SKU] if len(sheet_row) > COL_C_SKU else "").strip()
        if not raw_sku:
            continue
        group = groups.setdefault(_norm_sku(raw_sku), {"sku_display": raw_sku, "rows": []})
        group["rows"].append(sheet_row)

    seen_keys = set()
    unique_pairs = []
    for group in groups.values():
        group_rows = group["rows"]
        label = group["sku_display"]
        for sheet_row in group_rows:
            angle = (sheet_row[COL_E_ANGLE] if len(sheet_row) > COL_E_ANGLE else "").strip()
            if not angle:
                continue
            eligible = angle_row_satisfies_conditions(
                sheet_row, group_rows,
                regen_token=regen_token,
                enforce_j_bans=enforce_j_bans,
                banned_list=banned_list,
                enforce_j_req=enforce_j_requires,
                required_list=required_list,
                required_mode=(required_mode or "ANY").upper(),
            )
            if not eligible:
                continue
            dedupe_key = (_norm_sku(label), _norm_angle(angle))
            if dedupe_key in seen_keys:
                continue
            seen_keys.add(dedupe_key)
            unique_pairs.append({"sku": label, "angle": angle})

    print(f"Found {len(unique_pairs)} eligible (SKU, angle) pair(s).")
    return unique_pairs

def build_sku_folder_index(garments_root: str):
    """Map each normalised SKU (``_norm_sku`` of the folder basename) to its folder path.

    Folders come from ``iter_sku_folders(garments_root)``. If two folders
    normalise to the same SKU, the one yielded last wins.
    """
    index = {}
    for folder in iter_sku_folders(garments_root):
        sku_key = _norm_sku(os.path.basename(folder))
        index[sku_key] = folder
    return index

# --- Entrypoints ---
def run_sheet():
    """Regenerate every (SKU, angle) pair flagged in the ops sheet.

    Eligible pairs come from ``fetch_sku_angle_pairs_from_ops`` using the
    module-level sheet filter settings. Angles are grouped per normalised SKU,
    and ``process_one_garment_folder`` runs once per SKU whose folder exists
    under GARMENTS_ROOT. Errors for one SKU are printed and do not stop the run.
    """
    eligible = fetch_sku_angle_pairs_from_ops(
        SPREADSHEET_ID, OPS_SHEET_NAME,
        regen_token=ANGLE_NEEDS_REGENERATE_TOKEN,
        enforce_j_bans=ENFORCE_BAN_SUBSTRINGS,
        banned_list=BANNED_SUBSTRINGS,
        enforce_j_requires=ENFORCE_REQUIRE_SUBSTRINGS,
        required_list=REQUIRED_SUBSTRINGS,
        required_mode=REQUIRED_SUBSTRINGS_MODE,
    )

    # Collect requested angles per SKU; the last seen spelling is used for display.
    angles_by_sku = {}
    label_by_sku = {}
    for pair in eligible:
        sku_key = _norm_sku(pair["sku"])
        if sku_key not in angles_by_sku:
            angles_by_sku[sku_key] = set()
        angles_by_sku[sku_key].add(_norm_angle(pair["angle"]))
        label_by_sku[sku_key] = pair["sku"]

    folder_by_sku = build_sku_folder_index(GARMENTS_ROOT)
    total = len(angles_by_sku)
    print(f"➡️  Will process {total} SKU(s) from sheet.")
    if not total:
        return

    shared_pipe = local_pipe
    for position, sku_key in enumerate(angles_by_sku, start=1):
        angles = angles_by_sku[sku_key]
        folder = folder_by_sku.get(sku_key)
        label = label_by_sku.get(sku_key, sku_key)
        if folder is None or not folder:
            print(f"⚠️  Missing folder for SKU '{label}' under GARMENTS_ROOT.")
            continue
        print(f"\nSKU {position}/{total} ▶️  {label}  angles={sorted(angles)}")
        try:
            process_one_garment_folder(folder, pipe=shared_pipe, allowed_angles=angles)
            print(f"✅ Finished: {label}")
        except Exception as err:
            print(f"❌ Error in {label}: {err}")
    print("\n🏁 Sheet run complete.")

def run_dir():
    """Process every SKU folder resolved from TARGET_DIR.

    ``process_one_garment_folder`` only queues garments whose filename matches
    one of ``allowed_angles``. With ``allowed_angles=None`` nothing is ever
    generated, so this runner passes the configured ALLOWED_BASES, the same
    angle set ``run_list`` uses. Errors for one SKU are printed and do not stop
    the run. Unmatched identifiers are reported at the end, as in ``run_list``.
    """
    if not TARGET_DIR:
        print("⚠️ TARGET_DIR is empty.")
        return
    targets, unmatched = resolve_targets(TARGET_DIR, GARMENTS_ROOT)
    if not targets:
        print(f"⚠️ No SKU leaves found under: {TARGET_DIR}")
        if unmatched: print("Unmatched:", unmatched)
        return
    print(f"➡️  Will process {len(targets)} SKU(s) from directory.")
    shared_pipe = local_pipe
    for i, p in enumerate(targets, start=1):
        name = os.path.basename(p)
        print(f"\nSKU {i}/{len(targets)} ▶️  {name}")
        try:
            # Without allowed_angles the worker matches no garment files at all.
            process_one_garment_folder(p, pipe=shared_pipe, allowed_angles=ALLOWED_BASES)
            print(f"✅ Finished: {name}")
        except Exception as e:
            print(f"❌ Error in {name}: {e}")
    if unmatched:
        print("\nℹ️  Unmatched identifiers:")
        for u in unmatched: print("   -", u)
    print("\n🏁 Directory run complete.")

def run_list():
    """Process the SKUs identified by SKU_CSV, restricted to the ALLOWED_BASES angles.

    Identifiers are resolved against GARMENTS_ROOT by ``resolve_targets``.
    Errors for one SKU are printed and do not stop the run. Unresolved
    identifiers are listed at the end.
    """
    targets, unmatched = resolve_targets(SKU_CSV, GARMENTS_ROOT)
    if not targets:
        print("⚠️ No matching SKU folders found.")
        if unmatched:
            print("Unmatched:", ", ".join(unmatched))
        return

    print(f"➡️  Will process {len(targets)} SKU(s).")
    shared_pipe = local_pipe
    for position, folder in enumerate(targets, 1):
        label = os.path.basename(folder)
        print(f"\nSKU {position}/{len(targets)} ▶️  {label}")
        try:
            process_one_garment_folder(folder, pipe=shared_pipe, allowed_angles=ALLOWED_BASES)
            print(f"✅ Finished: {label}")
        except Exception as err:
            print(f"❌ Error in {label}: {err}")

    if unmatched:
        print("\nℹ️  Unmatched identifiers:")
        for ident in unmatched:
            print("   -", ident)
    print("\n🏁 List run complete.")




# DISPATCH

In [None]:
# Dispatch on RUN_MODE: "sheet" and "dir" select their runners; any other value runs list mode.
# Normalise case/whitespace so a value like "Sheet " does not silently fall through to list mode.
_run_mode = str(RUN_MODE).strip().lower()
if _run_mode == "sheet":
    run_sheet()
elif _run_mode == "dir":
    run_dir()
else:
    run_list()

# UNASSIGN

In [None]:
# Release (unassign) the Colab runtime once the batch above has finished, so the VM is not left allocated.
# NOTE(review): this ends the session; anything not yet uploaded to Drive (e.g. files in /tmp) is presumably lost.
from google.colab import runtime
runtime.unassign()