# SETUP (restart the runtime after running this cell)

In [None]:
# Install piexif (used by save_png_with_metadata for EXIF). The setup cell below
# installs it again, so this line is redundant but harmless.
!pip install piexif

In [None]:
# -*- coding: utf-8 -*-
# --- Setup: Google auth + Drive + NanoBanana Pro deps ---

import sys, os, subprocess, textwrap, importlib

from google.colab import auth, drive
# Authenticate the Colab user (later cells presumably obtain credentials from this
# via google.auth.default()) and mount Drive so the /content/drive/MyDrive/... paths
# used in CONFIG resolve.
auth.authenticate_user()
drive.mount('/content/drive')

# Install/upgrade dependencies. Packages already imported in this runtime keep their
# old versions until the runtime is restarted (see the section header).
!pip -q install --upgrade pip
!pip -q install -U "google-genai>=1.40.0" pillow numpy opencv-python-headless matplotlib gspread google-auth google-auth-oauthlib google-api-python-client piexif tqdm torch torchvision torchaudio transformers timm

print("✅ Setup done for NanoBanana Pro.")


# Select angles

In [None]:
fr_lft = True #@param {type:"boolean"}
fr_rght = True #@param {type:"boolean"}
fr_cl = True #@param {type:"boolean"}
bc_lft = True #@param {type:"boolean"}
bc_rght = True #@param {type:"boolean"}
lft = True #@param {type:"boolean"}
rght = True #@param {type:"boolean"}
bc_ = True #@param {type:"boolean"}
fr_ = True #@param {type:"boolean"}
fr_cl_btm = False #@param {type:"boolean"}
fr_cl_tp = False #@param {type:"boolean"}


# Angle-code prefixes in checkbox order; ALLOWED_BASES keeps the ticked ones.
names = ["fr_lft","fr_rght","fr_cl","bc_lft","bc_rght","lft","rght","bc_","fr_","fr_cl_btm","fr_cl_tp"]
# Look the checkbox values up in the module namespace via globals(). Before Python 3.12
# a list comprehension has its own scope, so `locals()` inside it only sees the loop
# variable and `locals()[n]` raised KeyError.
ALLOWED_BASES = [n for n in names if globals()[n]]

# CONFIG

In [None]:
# --- Unified CONFIG ---

# Selection mode: list only
RUN_MODE = "sku_list"     #@param ["sku_list"]

# For RUN_MODE == "sku_list"
# Comma-separated SKU numbers; the UTILS cell rewrites this to "SS-<digits>" form
# with normalize_sku_list().
SKU_CSV = "28920, 28747, 29018, 29095, 29094, 28746, 28745"  #@param {type:"string"}

# Paths
# Base model photos/masks are looked up under BASE_PHOTOS_ROOT using the
# Category/Subcategory part of a SKU folder's path below GARMENTS_ROOT
# (see resolve_base_mask_dir).
BASE_PHOTOS_ROOT  = "/content/drive/MyDrive/Dazzl/SikSilk/SKSLK_MODELS/"
GARMENTS_ROOT     = "/content/drive/MyDrive/Dazzl/SikSilk/AlexGens/SikSilk/"


# Filename/dir policy
VALID_EXTENSIONS  = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")
# Directory names skipped while walking; lower-cased in the UTILS cell, so matching
# is case-insensitive.
IGNORE_DIRS       = {"old", "__MACOSX", ".ds_store", "Ricardo", "toweling"}
SKIP_FILENAME_TOKENS_CSV   = "mask, generated, freelance, _sec, _backup"   # substrings to skip
SKIP_BASENAME_SUFFIXES_CSV = "_sec"                             # stem endings to skip
REQUIRE_CUT_IN_FILENAME    = False   #@param {type:"boolean"}
# When True, find_mask_path tries {stem}_mask_agnostic.* before {stem}_mask.*
PREFER_AGNOSTIC_MASKS = True #@param {type:"boolean"}
secondary_garment = True #@param {type:"boolean"}
SECONDARY_GARMENT = secondary_garment

# Cropping / paste-back (square 1:1, generous garment margin)
CROP_PADDING      = 100        # px padding around garment when building crop
UPPER_PADDING     = 100        # extra padding above garment
HORIZ_PADDING     = 100        # horizontal padding
MASK_EXPAND_PX    = 200        # outward growth before feather (legacy, left for compatibility)
MASK_FEATHER_PX   = 50         # Gaussian sigma for feathering / paste-back taper
CROP_MIN_MARGIN   = 0        # minimum margin even if mask touches edge
# NOTE(review): MASK_EXPAND_PX and MASK_FEATHER_PX also set find_aspect_bbox()'s default
# min_margin (MASK_EXPAND_PX + 3 * MASK_FEATHER_PX + 5 = 355 px with these values).

TARGET_ASPECT = (1, 1)         # enforce square crops for 1:1 generations

# DINOv3 garment localisation
DINO_MODEL_ID         = "facebook/dinov3-vits16-pretrain-lvd1689m"
DINO_BOX_PERCENTILE   = 82     # percentile for cls-attn heatmap threshold (lower = larger crop)

# NanoBanana Pro (Google GenAI)
# Model id plus the default aspect ratio / image size passed to generate_content()
# by run_nanobanana_tryon().
NANOBANANA_MODEL_ID = "gemini-3-pro-image-preview"
GEN_ASPECT_RATIO    = "1:1"
GEN_IMAGE_SIZE      = "4K"
# Default instruction text sent ahead of the model and garment images.
TRYON_PROMPT = """You are an expert virtual try-on AI. You will be given a 'model image' and a 'garment image'. Your task is to create a new photorealistic image where the person from the 'model image' is wearing the clothing from the 'garment image'.

**Crucial Rules:**
1.  **Complete Garment Replacement:** You MUST completely REMOVE and REPLACE the clothing item worn by the person in the 'model image' with the new garment. No part of the original clothing (e.g., collars, sleeves, patterns) should be visible in the final image.
2.  **Preserve the Model:** The person's face, hair, tattoos (if any), body shape, and pose from the 'model image' MUST remain unchanged, pixel-for-pixel.
3.  **Preserve the Background:** The entire background from the 'model image' MUST be preserved perfectly, pixel-for-pixel. Do not change the background color.
4.  **Apply the Garment:** Realistically fit the new garment onto the person. It should adapt to their pose with natural folds, shadows, and lighting consistent with the original scene.
5.  **Output:** Return ONLY the final, edited image. Do not include any text.
6.  **Bespoke quality:** the garment should be ironed (if applicable) and sit perfectly well — this is a professional fashion product photoshoot."""

# Sheet-related (Ops removed) kept only for Gen Log appends
# Presumably passed as spreadsheet_id / worksheet_name to
# upload_file_and_append_to_sheet() by the batch cell — confirm at the call site.
SPREADSHEET_ID = "1Kbq9__sEUQiuDPuza5Xy_hRyIn8pUvmfFj6vhPBrp8Y"
GEN_LOG_SHEET  = "Gen Log"

# Misc
# SHOW_VISUALS gates show_gallery(); TIMEZONE and OPERATOR go into each Gen Log row;
# OUTPUT_DIR is created by the output-routing cell.
SHOW_VISUALS = True
TIMEZONE     = "Europe/Lisbon"
OPERATOR     = "Ivan"
OUTPUT_DIR   = "/content/drive/MyDrive/Dazzl/SikSilk/SS_OUTPUT_FOLDER/v1-3/" #@param {type:"string"}

# === Garment/type taxonomy (kept) ===
# The duplicate trailing "shirts" entry has been removed.
# NOTE(review): spellings here ("shirts", "vests", "joggers") differ from TOP_GARMENTS /
# BOTTOM_GARMENTS ("shirt", "vest", "jogger-trousers"). Confirm which form the
# garment-type inference produces before relying on membership tests.
ALLOWED_GARMENT_TYPES = [
    "hoodie","jeans","joggers","shorts","sweater","swimwear",
    "t-shirt","shirts","track top","trousers","twinset","polo","vests",
]
TOP_GARMENTS    = ["t-shirt","shirt","sweater","hoodie","track top","vest"]
BOTTOM_GARMENTS = ["shorts","jogger-trousers","trousers","jeans","swimwear"]
TWINSET_TYPES   = ["twinset"]

# === Details tokens ===
ALLOWED_DETAIL_TYPES = ["crest","logo","patch"]

# Angle sheet tokens (kept for compatibility, not used in v1.3)
# (BANNED_/REQUIRED_SUBSTRINGS_CSV are still parsed into token lists in the UTILS cell.)
ANGLE_NEEDS_REGENERATE_TOKEN = "Regenerate"
ENFORCE_BAN_SUBSTRINGS     = True
BANNED_SUBSTRINGS_CSV      = "wrong, pair, combo"
ENFORCE_REQUIRE_SUBSTRINGS = False
REQUIRED_SUBSTRINGS_CSV    = ""
REQUIRED_SUBSTRINGS_MODE   = "ANY"   # "ANY" | "ALL"

print("✅ Config ready for NanoBanana Pro v1.3 (DINOv3 garment localisation)")


# UTILS

In [None]:
# --- Core utilities: normalization, angles, walking, DINOv3 garment localisation ---

import os, re, fnmatch, math, uuid, pytz, random, gc, tempfile, traceback
from datetime import datetime
import numpy as np
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter, ImageDraw
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel

def normalize_sku_list(sku_csv: str) -> str:
    """Rewrite a comma-separated SKU list into canonical "SS-<digits>" entries.

    Each entry contributes the first run of digits it contains; entries with no
    digits are dropped. The surviving entries are joined with ", ".

    >>> normalize_sku_list("28920, ss-28747, foo")
    'SS-28920, SS-28747'
    """
    digit_run = re.compile(r'(\d+)')
    normalized = []
    for entry in sku_csv.split(','):
        found = digit_run.search(entry.strip().upper())
        if found is None:
            continue
        normalized.append("SS-" + found.group(1))
    return ", ".join(normalized)

# Canonicalise the form value in place (e.g. "28920" -> "SS-28920") so later cells
# receive the normalised CSV.
SKU_CSV = normalize_sku_list(SKU_CSV)

# Parsers
def _parse_csv_list(s):  return [x.strip().casefold() for x in (s or "").split(",") if x.strip()]
# Casefolded token collections derived from the *_CSV strings in CONFIG.
BANNED_SUBSTRINGS       = _parse_csv_list(BANNED_SUBSTRINGS_CSV)
REQUIRED_SUBSTRINGS     = _parse_csv_list(REQUIRED_SUBSTRINGS_CSV)
SKIP_FILENAME_TOKENS    = set(_parse_csv_list(SKIP_FILENAME_TOKENS_CSV))
# A tuple, presumably so it can be passed directly to str.endswith().
SKIP_BASENAME_SUFFIXES  = tuple(_parse_csv_list(SKIP_BASENAME_SUFFIXES_CSV))

# Normalizers
def _norm_sku(s):
    if s is None: return ""
    s = str(s).replace("\u00A0"," ")
    s = " ".join(s.split())
    return s.casefold()

def _norm_angle(s):
    s = (s or "").strip().lower()
    return s.strip("_ ").replace("-", "_")

# Angle aliases
# Extra source-angle codes accepted for each (normalised) angle; expand_allowed_angles()
# adds these. Note "fr_" and "bc_" normalise to "fr"/"bc" via _norm_angle(), so they
# duplicate the bare entries after normalisation.
ANGLE_ALIASES = {
    "fr_cl":   ["fr", "fr_"],
    "fr":      ["fr_cl"],
    "bc_lft":  ["bc", "bc_"],
    "bc_rght": ["bc", "bc_"],
}


# --- Helpers to keep outputs strict, sources flexible ---
def expand_as_list(angles):
    """Return the alias-expanded, normalised angle codes as a list, longest first.

    Longer codes come first so prefix checks prefer e.g. 'fr_cl' over 'fr'.
    Equal-length codes are ordered alphabetically, which makes the result deterministic;
    previously their order followed set iteration, which varies between runs under
    string-hash randomisation. Duplicates are dropped.
    """
    expanded = {_norm_angle(a) for a in expand_allowed_angles(angles)}
    return sorted(expanded, key=lambda a: (-len(a), a))

def pick_target_angle(source_angle: str, allowed_outputs: set) -> str | None:
    """Map a source angle code to the output angle whose alias family contains it.

    Args:
        source_angle: angle code taken from a source file name (normalised here).
        allowed_outputs: iterable of output angle codes; None is treated as empty.

    Returns:
        The normalised output angle, or None if no family matches. If several outputs
        match (e.g. "bc" is an alias of both "bc_lft" and "bc_rght"), the
        alphabetically first one is returned. The old code iterated the set directly,
        so the winner varied between runs.
    """
    s = _norm_angle(source_angle)
    for target in sorted(allowed_outputs or ()):
        family = {_norm_angle(x) for x in expand_allowed_angles([target])}
        if s in family:
            return _norm_angle(target)
    return None


def expand_allowed_angles(angles):
    """Return the set of normalised angle codes together with all their ANGLE_ALIASES.

    None or an empty iterable yields an empty set.
    """
    result = set()
    for angle in angles or []:
        key = _norm_angle(angle)
        result.add(key)
        result.update(_norm_angle(alias) for alias in ANGLE_ALIASES.get(key, []))
    return result

# Ignore set
# Lower-case the configured names once so the walkers below can compare them
# against `name.lower()`, i.e. case-insensitively.
IGNORE_DIRS = {d.lower() for d in IGNORE_DIRS}

# Walkers
def _is_sku_folder(path: str) -> bool:
    """Decide whether `path` is a SKU leaf folder.

    A SKU folder is one whose lower-cased name is not in IGNORE_DIRS and that directly
    contains at least one regular file ending in a VALID_EXTENSIONS entry
    (case-insensitive). Any error while scanning counts as "not a SKU folder".
    """
    if os.path.basename(os.path.normpath(path)).lower() in IGNORE_DIRS:
        return False
    try:
        image_exts = tuple(ext.lower() for ext in VALID_EXTENSIONS)
        return any(
            os.path.isfile(os.path.join(path, entry)) and entry.lower().endswith(image_exts)
            for entry in os.listdir(path)
        )
    except Exception:
        return False

def iter_sku_folders(root: str):
    """Yield each directory under `root` (inclusive) that directly holds an image file.

    Subdirectories whose lower-cased name is in IGNORE_DIRS are pruned from the walk.
    """
    for current, subdirs, files in os.walk(root):
        # Prune in place so os.walk does not descend into ignored folders.
        subdirs[:] = [name for name in subdirs if name.lower() not in IGNORE_DIRS]
        has_image = any(
            name.lower().endswith(tuple(ext.lower() for ext in VALID_EXTENSIONS))
            for name in files
        )
        if has_image:
            yield current

def resolve_targets(idents_csv: str, garments_root: str):
    """
    Resolve user-supplied identifiers to garment SKU leaf directories.

    Accepts (comma- or newline-separated):
      • Plain SKU names, relative paths (Category/Subcategory/SKU), or absolute dirs
      • Glob patterns (e.g., 'Hoodies/*' or 'SKSLK_12*')
      • Directories that are NOT SKU leaves → expand to all descendant SKU leaves

    Returns:
        (paths, unmatched): sorted, de-duplicated absolute SKU directories, and the
        identifiers that resolved to nothing. An identifier whose matches were
        already contributed by an earlier identifier is no longer mis-reported as
        unmatched.
    """
    idents = [s.strip() for s in idents_csv.replace("\n", ",").split(",") if s.strip()]
    if not idents:
        return [], []

    all_sku_dirs = list(iter_sku_folders(garments_root))
    rel_map = {p: os.path.relpath(p, garments_root) for p in all_sku_dirs}
    base_map = {p: os.path.basename(p) for p in all_sku_dirs}

    seen, out, unmatched = set(), [], []

    def add_path(p) -> int:
        """Record the SKU dir(s) `p` resolves to; return how many it resolved to (new or not)."""
        ap = os.path.abspath(p)
        if not os.path.isdir(ap):
            return 0
        # A SKU leaf is taken as-is; any other directory expands to its descendant leaves.
        if _is_sku_folder(ap):
            leaves = [ap]
        else:
            leaves = [os.path.abspath(leaf) for leaf in iter_sku_folders(ap)]
        for leaf in leaves:
            if leaf not in seen:
                seen.add(leaf)
                out.append(leaf)
        return len(leaves)

    for ident in idents:
        hits = 0
        # Absolute directory or SKU path
        if os.path.isabs(ident) and os.path.isdir(ident):
            hits += add_path(ident)

        # Relative under garments root (dir or SKU). os.path.join returns an absolute
        # ident unchanged; the duplicate is harmless thanks to `seen`.
        rel_candidate = os.path.join(garments_root, ident)
        if os.path.exists(rel_candidate):
            hits += add_path(rel_candidate)

        # Glob/pattern over known SKU leaves (by basename or relative path)
        for p in all_sku_dirs:
            if fnmatch.fnmatch(base_map[p], ident) or fnmatch.fnmatch(rel_map[p], ident):
                hits += add_path(p)

        if hits == 0:
            unmatched.append(ident)

    out.sort()
    return out, unmatched

# Base/mask location resolution
def resolve_base_mask_dir(sku_folder: str,
                          garments_root: str = GARMENTS_ROOT,
                          base_root: str = BASE_PHOTOS_ROOT):
    """
    Find the base-photo/mask directory that corresponds to a garment SKU folder.

    Candidates, in order (the first that exists as a directory is returned):
      1. base_root / <parent of sku_folder relative to garments_root>
         (.../GARMENTS_ROOT/Category/Subcategory/SKU → .../BASE_ROOT/Category/Subcategory),
         only when sku_folder lies inside garments_root;
      2. base_root / <grandparent dir name> / <parent dir name>;
      3. base_root / <parent dir name>.
    Returns None when none of them exist.
    """
    sku_abs = os.path.abspath(sku_folder)
    garments_abs = os.path.abspath(garments_root)

    candidates = []
    try:
        rel = os.path.relpath(sku_abs, garments_abs)
    except Exception:
        rel = None
    if rel and not rel.startswith(".."):
        candidates.append(os.path.join(base_root, os.path.dirname(rel)))

    parent_dir = os.path.dirname(sku_abs)
    subcat = os.path.basename(parent_dir)
    cat = os.path.basename(os.path.dirname(parent_dir))
    candidates.append(os.path.join(base_root, cat, subcat))
    candidates.append(os.path.join(base_root, subcat))

    for cand in candidates:
        if os.path.isdir(cand):
            return cand
    return None

def _valid_ext(fname):
    """Return True if `fname` ends with one of VALID_EXTENSIONS (case-insensitive)."""
    allowed = tuple(ext.lower() for ext in VALID_EXTENSIONS)
    return fname.lower().endswith(allowed)

def _file_prefix_or_none(filename: str):
    low = filename.lower()
    for base in ALLOWED_BASES:
        if low.startswith(base): return base
    return None

def _find_image_with_stem_and_suffix(directory, stem, suffix=""):
    if not directory or not os.path.isdir(directory):
        return None
    stem = stem.lower()
    for file in os.listdir(directory):
        fname, fext = os.path.splitext(file)
        if fext.lower() in (".png",".jpg",".jpeg") and fname.lower() == f"{stem}{suffix}":
            return os.path.join(directory, file)
    return None

# --- Existence check in Google Drive by Colab-style path ---
def drive_file_exists_any_ext_at_colab_path(target_colab_path: str,
                                            exts=(".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")) -> bool:
    """
    Report whether Drive already has a file with the same stem as `target_colab_path`.

    The parent folder is resolved from the Colab-style path and its children are
    listed once. A child counts as a hit when its stem equals the target's stem
    (case-sensitive) and its extension is in `exts` (case-insensitive). Any failure
    is printed and treated as "does not exist", i.e. False.
    """
    try:
        parent_id, desired_name = _resolve_parent_id_and_filename_from_colab_path(target_colab_path)
        wanted_stem = os.path.splitext(desired_name)[0]
        children = _list_children(parent_id, q_extra="")  # list once; filter locally
        ok_exts = {ext.lower() for ext in exts}
        split_names = (os.path.splitext(child.get("name", "")) for child in children)
        return any(stem == wanted_stem and ext.lower() in ok_exts for stem, ext in split_names)
    except Exception as e:
        print(f"⚠️ Ext-agnostic existence check failed for {target_colab_path}: {e}")
        return False


# === NEW: mask finding with AGNOSTIC priority ===
def find_mask_path(base_subcat_dir: str, stem_no_cut: str):
    """
    Locate the mask image for `stem_no_cut` inside `base_subcat_dir`.

    Lookup order (extensions .png/.jpg/.jpeg in lower then upper case):
      1) {stem}_mask_agnostic.<ext>  — only when PREFER_AGNOSTIC_MASKS is True
      2) {stem}_mask.<ext>
    Returns the first existing path, or None (also when the directory is missing).
    """
    if not base_subcat_dir or not os.path.isdir(base_subcat_dir):
        return None

    mask_exts = (".png", ".jpg", ".jpeg", ".PNG", ".JPG", ".JPEG")
    suffixes = ["_mask_agnostic", "_mask"] if PREFER_AGNOSTIC_MASKS else ["_mask"]
    for suffix in suffixes:
        for ext in mask_exts:
            path = os.path.join(base_subcat_dir, f"{stem_no_cut}{suffix}{ext}")
            if os.path.exists(path):
                return path
    return None


def find_secondary_garment_path(folder_path: str, main_filename: str):
    """
    Return the secondary-garment file paired with `main_filename`, or None.

    For 'bc_lft_cut.png' the stems tried are 'bc_lft_sec_cut' and, unless
    REQUIRE_CUT_IN_FILENAME is True, 'bc_lft_sec'. Each is tried with every
    VALID_EXTENSIONS entry, in that order, inside `folder_path`.
    """
    stem = os.path.splitext(main_filename)[0]
    has_cut = stem.endswith("_cut")
    core = stem[:-len("_cut")] if has_cut else stem

    names = [f"{core}_sec_cut"]
    if not REQUIRE_CUT_IN_FILENAME:
        names.append(f"{core}_sec")
    if not has_cut:
        names.append(f"{stem}_sec_cut")

    # dict.fromkeys keeps first-seen order while dropping repeated stems
    for name in dict.fromkeys(names):
        for ext in VALID_EXTENSIONS:
            candidate = os.path.join(folder_path, f"{name}{ext}")
            if os.path.exists(candidate):
                return candidate
    return None

# ──────────────────────────────────────────
# --- Aspect-ratio bbox (legacy) ---


def find_aspect_bbox(
    mask: Image.Image,
    aspect: tuple[int,int] = (1,1),   # width:height, e.g. (1280,1600)
    padding: int = 40,
    upper_padding: int | None = None,
    horiz_padding: int = 0,
    min_margin: int | None = None,
):
    """
    Return a rectangular bbox [x0, y0, x1, y1] that fully contains the mask + padding
    and matches the requested aspect ratio by expanding outward only.

    Pixels with value > 128 (after conversion to "L") count as mask. The initial box
    is the mask extent grown by `horiz_padding` left/right, `upper_padding` above,
    `padding` below and `min_margin` on every side, clamped to the image. It is then
    grown along one axis toward `aspect`. Growth is capped at the image borders, so
    the result may not match `aspect` exactly.

    Args:
        mask: mask image.
        aspect: target (width, height) ratio.
        padding: extra pixels below the mask (also the default for upper_padding).
        upper_padding: extra pixels above the mask; None -> `padding`.
        horiz_padding: extra pixels left and right of the mask.
        min_margin: margin on all sides; None -> MASK_EXPAND_PX + 3*MASK_FEATHER_PX + 5
            when those globals exist, otherwise 40.

    Raises:
        ValueError: if no mask pixel is above 128.
    """
    if min_margin is None:
        try:
            min_margin = int(MASK_EXPAND_PX + 3 * MASK_FEATHER_PX + 5)
        except Exception:
            # Globals not defined (e.g. used outside this notebook) → fixed fallback
            min_margin = 40

    m = np.array(mask.convert("L"))
    h, w = m.shape
    ys, xs = np.where(m > 128)
    if xs.size == 0:
        raise ValueError("Mask has no white pixels!")

    # NOTE(review): x_max/y_max are inclusive pixel indices but x1/y1 are used as
    # exclusive bounds below; harmless whenever padding/margin > 0.
    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())

    if upper_padding is None:
        upper_padding = padding

    # Initial padded bbox
    x0 = max(0, x_min - horiz_padding - min_margin)
    x1 = min(w, x_max + horiz_padding + min_margin)
    y0 = max(0, y_min - upper_padding - min_margin)
    y1 = min(h, y_max + padding + min_margin)

    bw, bh = (x1 - x0), (y1 - y0)
    # Desired aspect as float
    aw, ah = aspect
    target_ar = float(aw) / float(max(1, ah))

    # First pass: try to match aspect by expanding one dimension only
    # (half the growth goes to each side; a side capped by the border passes the
    # remainder to the right/bottom side, but not the other way round)
    def expand_to_aspect(x0, y0, x1, y1):
        bw = x1 - x0; bh = y1 - y0
        cur_ar = bw / float(max(1, bh))
        if cur_ar < target_ar:
            # too tall → need wider
            need_w = int(np.ceil(target_ar * bh))
            grow = max(0, need_w - bw)
            left_grow  = min(x0, grow // 2)
            right_grow = min(w - x1, grow - left_grow)
            x0 -= left_grow; x1 += right_grow
        elif cur_ar > target_ar:
            # too wide → need taller
            need_h = int(np.ceil(bw / target_ar))
            grow = max(0, need_h - bh)
            top_grow    = min(y0, grow // 2)
            bottom_grow = min(h - y1, grow - top_grow)
            y0 -= top_grow; y1 += bottom_grow
        return max(0,x0), max(0,y0), min(w,x1), min(h,y1)

    x0, y0, x1, y1 = expand_to_aspect(x0, y0, x1, y1)

    # Second pass: if a border capped us, re-try by expanding the other dimension
    # NOTE(review): this pass appears to be a no-op for non-degenerate boxes. If
    # cur_ar < target_ar then bw / target_ar < bh, so need_h <= bh and grow == 0;
    # symmetrically need_w <= bw in the else branch. Growing outward cannot fix an
    # aspect that was capped by the border.
    bw, bh = (x1 - x0), (y1 - y0)
    cur_ar = bw / float(max(1, bh))
    if abs(cur_ar - target_ar) > 1e-3:
        if cur_ar < target_ar:
            # could not widen enough → try to grow height
            need_h = int(np.ceil(bw / target_ar))
            grow = max(0, need_h - bh)
            top_grow    = min(y0, grow // 2)
            bottom_grow = min(h - y1, grow - top_grow)
            y0 -= top_grow; y1 += bottom_grow
        else:
            # could not heighten enough → try to grow width
            need_w = int(np.ceil(target_ar * bh))
            grow = max(0, need_w - bw)
            left_grow  = min(x0, grow // 2)
            right_grow = min(w - x1, grow - left_grow)
            x0 -= left_grow; x1 += right_grow

    # Final clamp
    x0, y0 = max(0, int(x0)), max(0, int(y0))
    x1, y1 = min(w, int(x1)), min(h, int(y1))
    return [x0, y0, x1, y1]


# Alpha/white utilities for garment panel
WHITE_RGB = (255,255,255)
def flatten_alpha_to_white(img: Image.Image) -> Image.Image:
    """Composite `img` over an opaque white background and return an RGB image.

    Images with transparency (mode RGBA/LA, or any mode with a "transparency" entry
    in `img.info`, such as palette PNGs with tRNS) are converted to RGBA first, so
    the paste mask is a real alpha channel. Previously the last band was used
    directly: for a "P" image that is the palette-index band, and for an
    "RGB"+transparency image it is the blue channel. Other images are simply
    converted to RGB.
    """
    if img.mode in ("RGBA","LA") or ("transparency" in img.info):
        rgba = img.convert("RGBA")
        bg = Image.new("RGB", rgba.size, WHITE_RGB)
        bg.paste(rgba, mask=rgba.getchannel("A"))
        return bg
    return img.convert("RGB")

def _tight_bbox_nonwhite_or_opaque(img: Image.Image):
    """Return the tight (x0, y0, x1, y1) box around foreground pixels, or None.

    Foreground means alpha > 0 for images with transparency (RGBA/LA or a
    "transparency" entry in img.info), and "not pure white" otherwise. x1/y1 are
    exclusive, so the tuple can be passed directly to Image.crop().
    """
    has_alpha = img.mode in ("RGBA", "LA") or "transparency" in img.info
    if has_alpha:
        foreground = np.asarray(img.convert("RGBA"))[..., 3] > 0
    else:
        rgb = np.asarray(img.convert("RGB"))
        foreground = ~np.all(rgb == 255, axis=-1)
    if not foreground.any():
        return None
    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))
    return (int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1)

def crop_garment_keep_aspect(img: Image.Image) -> Image.Image:
    """Trim `img` to its foreground box (no resizing) after flattening it onto white RGB.

    Returns the whole flattened image when there is no foreground or the tight box
    already spans the full frame.
    """
    flattened = flatten_alpha_to_white(img)
    box = _tight_bbox_nonwhite_or_opaque(img)
    if box is None or box == (0, 0, flattened.width, flattened.height):
        return flattened
    return flattened.crop(box)

def to_centered_square(gar: Image.Image, fill=WHITE_RGB) -> Image.Image:
    """Centre `gar` on a square RGB canvas of side max(width, height) filled with `fill`."""
    width, height = gar.size
    side = max(width, height)
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new("RGB", (side, side), fill)
    canvas.paste(gar, offset)
    return canvas

# --- DINOv3 garment localisation
from dino import (
    detect_garment_region_with_dinov3,
    heatmap_to_pil,
    draw_bbox,
    infer_garment_type_from_path,
    garment_position_hint,
)


In [None]:
# --- Visualisation helpers (restored) ---

import math
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageOps

def open_upright(path) -> Image.Image:
    """Load the image at `path` with its EXIF orientation applied.

    The file is closed on return. The result of ImageOps.exif_transpose is a new
    image object (per Pillow's docs, a transposed image or a copy), so it remains
    usable after the file closes.
    """
    with Image.open(path) as source:
        upright = ImageOps.exif_transpose(source)
    return upright

def show_gallery(img_list, titles=None, cols=3, w=4):
    """
    Display images in a grid of `cols` columns, each cell about `w` inches square.

    Accepts PIL images, numpy arrays (a 4-D batch shows its first item) and
    torch-like tensors (anything whose type name contains "Tensor"; assumed CHW).
    (H, W, 1) arrays are squeezed to (H, W), since matplotlib rejects them. Single-
    channel images (PIL "1"/"L"/"I"/"I;16"/"F" or 2-D arrays) use a gray colormap
    instead of matplotlib's default false-colour map. Nothing is drawn when the global
    SHOW_VISUALS is falsy or `img_list` is empty.
    """
    if not globals().get("SHOW_VISUALS", False):
        return
    if not img_list:
        return

    n = len(img_list)
    rows = math.ceil(n / cols)
    plt.figure(figsize=(cols * w, rows * w))

    for i, img in enumerate(img_list):
        plt.subplot(rows, cols, i + 1)
        # Accept PIL, torch tensors or numpy arrays (4-D batch ⇒ pick first)
        if isinstance(img, np.ndarray) and img.ndim == 4:
            img = img[0]  # (B,H,W,C) → (H,W,C)
        # Torch tensors are detected via duck-typing to avoid a hard import
        if "Tensor" in str(type(img)):
            img = img.detach().cpu().permute(1, 2, 0).numpy()
        if isinstance(img, np.ndarray) and img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]  # (H,W,1) → (H,W)
        single_channel = (
            getattr(img, "mode", None) in ("1", "L", "I", "I;16", "F")
            or (isinstance(img, np.ndarray) and img.ndim == 2)
        )
        plt.imshow(img, cmap="gray" if single_channel else None)
        if titles and i < len(titles):
            plt.title(titles[i])
        plt.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# --- Paste-back (square with padding + feather) ---

import numpy as np
from PIL import Image


def build_square_feather_mask(size, feather_px: int):
    """Build an "L" mask of `size` (w, h): 255 inside, ramping linearly to 0 at the border.

    A pixel whose distance to the nearest image edge is d < feather_px gets
    255 * d / feather_px, so the outermost ring is 0. feather_px <= 0 gives a
    fully opaque mask.
    """
    w, h = size
    cols = np.arange(w)
    rows = np.arange(h)
    # Per-column / per-row distance to the nearest edge, combined via broadcasting
    col_dist = np.minimum(cols, w - 1 - cols)
    row_dist = np.minimum(rows, h - 1 - rows)
    dist_edge = np.minimum(row_dist[:, None], col_dist[None, :]).astype(np.float32)

    alpha = np.full((h, w), 255.0, dtype=np.float32)
    if feather_px > 0:
        ramp = dist_edge < float(feather_px)
        alpha[ramp] = dist_edge[ramp] / float(feather_px) * 255.0
    return Image.fromarray(alpha.clip(0, 255).astype(np.uint8), mode="L")


def make_square_crop_with_padding(full_img: Image.Image, bbox, fill=WHITE_RGB):
    """Cut a square around `bbox`, padding with `fill` wherever it leaves the image.

    The square has side max(bbox width, bbox height, 1) and is centred on the bbox
    (an odd leftover pixel goes to the right/bottom). Returns
    (square_crop, square_box): square_crop is an RGB image of that side, and
    square_box = [x0, y0, x1, y1] may extend beyond the image, for use with
    paste_crop_back_square.
    """
    left, top, right, bottom = (int(v) for v in bbox)
    box_w, box_h = right - left, bottom - top
    side = max(box_w, box_h, 1)

    sq_left = left - (side - box_w) // 2
    sq_top = top - (side - box_h) // 2
    sq_right, sq_bottom = sq_left + side, sq_top + side

    img_w, img_h = full_img.size
    # Part of the square that actually lies inside the source image
    src_box = (max(0, sq_left), max(0, sq_top), min(img_w, sq_right), min(img_h, sq_bottom))

    canvas = Image.new("RGB", (side, side), fill)
    canvas.paste(full_img.crop(src_box), (src_box[0] - sq_left, src_box[1] - sq_top))
    return canvas, [sq_left, sq_top, sq_right, sq_bottom]


def paste_crop_back_square(
    full_img: Image.Image,
    edited_crop: Image.Image,
    crop_box,
    feather_px: int = 30,
):
    """Blend an edited crop back into `full_img` at `crop_box` with a feathered border.

    `crop_box` = [x0, y0, x1, y1] may extend beyond the source image (as returned by
    make_square_crop_with_padding); only the overlapping region is blended. The edit
    is resized to the box size and composited through a mask that fades to 0 over
    `feather_px` pixels at the box edges.

    Non-square boxes are handled using the box's own width and height; previously
    the width was used for both, which mis-scaled rectangular boxes. The resized
    edit is converted to `full_img.mode` when the modes differ, because
    Image.composite() requires matching modes (e.g. an RGBA base with an RGB edit).

    Modifies `full_img` in place and returns it. The image is returned untouched
    when the box is empty or does not overlap it.
    """
    x0, y0, x1, y1 = map(int, crop_box)
    box_w, box_h = x1 - x0, y1 - y0
    if box_w <= 0 or box_h <= 0:
        return full_img

    edit_box = edited_crop.resize((box_w, box_h), Image.Resampling.LANCZOS)
    if edit_box.mode != full_img.mode:
        edit_box = edit_box.convert(full_img.mode)
    mask_box = build_square_feather_mask((box_w, box_h), feather_px)

    # Overlap region between crop_box and the base image
    img_w, img_h = full_img.size
    ix0, iy0 = max(0, x0), max(0, y0)
    ix1, iy1 = min(img_w, x1), min(img_h, y1)
    if ix0 >= ix1 or iy0 >= iy1:
        return full_img

    # The same region expressed in the crop's own coordinates
    region = (ix0 - x0, iy0 - y0, ix1 - x0, iy1 - y0)
    edit_region = edit_box.crop(region)
    mask_region = mask_box.crop(region)
    base_region = full_img.crop((ix0, iy0, ix1, iy1))

    comp = Image.composite(edit_region, base_region, mask_region)
    full_img.paste(comp, (ix0, iy0))
    return full_img


In [None]:
# --- NanoBanana Pro try-on (Google Cloud GenAI) ---

import os
from google import genai
from google.genai import types
from PIL import Image

def _load_gemini_api_key():
    key = os.getenv("GEMINI_API_KEY") or os.getenv("GEMINI_APIKEY")
    try:
        from google.colab import userdata
        key = key or userdata.get("GEMINI_API_KEY")
    except Exception:
        pass
    if not key:
        raise ValueError("Set GEMINI_API_KEY in environment or Colab userdata.")
    return key

# Module-level GenAI client used by run_nanobanana_tryon(). It is created once when
# this cell runs, which raises ValueError here if no API key is configured.
genai_client = genai.Client(api_key=_load_gemini_api_key())

def _extract_first_image(resp):
    import io
    parts = []
    if hasattr(resp, "parts"):
        parts.extend(resp.parts)
    for cand in getattr(resp, "candidates", []):
        parts.extend(getattr(getattr(cand, "content", None), "parts", []) or [])

    for part in parts:
        if isinstance(part, Image.Image):
            return part
        as_image = getattr(part, "as_image", None)
        if callable(as_image):
            img = as_image()
            if isinstance(img, Image.Image):
                return img
        inline = getattr(part, "inline_data", None)
        if inline and getattr(inline, "data", None):
            return Image.open(io.BytesIO(inline.data)).convert("RGB")
    raise ValueError("No image returned from NanoBanana Pro response.")


def run_nanobanana_tryon(model_image: Image.Image, garment_image: Image.Image,
                         *, aspect_ratio: str = GEN_ASPECT_RATIO,
                         image_size: str = GEN_IMAGE_SIZE,
                         prompt: str | None = None):
    """Ask NanoBanana Pro to dress the person in `model_image` in `garment_image`.

    Sends `prompt` (falling back to TRYON_PROMPT when None/empty), then the labelled
    model and garment images (both converted to RGB). It requests an image-only
    response with the given aspect ratio and image size.

    Returns:
        The first image in the response, converted to RGB.
    Raises:
        ValueError: if the response holds no image (raised by _extract_first_image).
    """
    request_contents = [
        prompt or TRYON_PROMPT,
        "Model image:",
        model_image.convert("RGB"),
        "Garment image:",
        garment_image.convert("RGB"),
    ]
    generation_config = types.GenerateContentConfig(
        response_modalities=["IMAGE"],
        image_config=types.ImageConfig(aspect_ratio=aspect_ratio, image_size=image_size),
    )
    response = genai_client.models.generate_content(
        model=NANOBANANA_MODEL_ID,
        contents=request_contents,
        config=generation_config,
    )
    return _extract_first_image(response).convert("RGB")

print("✅ NanoBanana Pro client ready.")


In [None]:
# Colab cell — output routing + metadata helpers
import os, json
from PIL import PngImagePlugin

def ensure_dir(p):
    """Create directory `p` (including parents) if it is missing, then return `p`."""
    os.makedirs(p, exist_ok=True)
    return p

# Create the output folder up front; this is a no-op if it already exists.
ensure_dir(OUTPUT_DIR)

def build_output_filename(sku_name: str, angle_code: str, ext=".png", suffix="") -> str:
    """Compose '<sku>-<normalised angle><suffix><ext>', e.g. 'SS-12345-fr_rght.png'.

    A None/empty `suffix` adds nothing; `ext` must include its leading dot.
    """
    return "{}-{}{}{}".format(sku_name, _norm_angle(angle_code), suffix or "", ext)


import json, piexif
from PIL import Image

def save_png_with_metadata(img, out_path, details_payload=None, quality=95):
    """Save `img` as PNG at `out_path`, optionally embedding `details_payload` as EXIF.

    When `details_payload` is truthy it is serialised to JSON and stored in the EXIF
    UserComment tag. The tag's 8-byte character-code prefix declares "ASCII", so the
    JSON is now written with ensure_ascii=True (non-ASCII characters become
    backslash-u escapes). Previously raw UTF-8 bytes were stored under the ASCII
    code, which readers honouring the prefix may mis-decode. json.loads() of the
    stored text still yields the same payload.

    `quality` is accepted for call-site compatibility but unused (PNG is lossless).
    """
    if details_payload:
        payload = json.dumps(details_payload, ensure_ascii=True).encode("ascii")
        user_comment = b"ASCII\x00\x00\x00" + payload  # EXIF character code + text
        exif_dict = {"0th": {}, "Exif": {piexif.ExifIFD.UserComment: user_comment}, "1st": {}, "GPS": {}, "Interop": {}}
        exif_bytes = piexif.dump(exif_dict)
        img.save(out_path, format="PNG", exif=exif_bytes)
    else:
        img.save(out_path, format="PNG")

import json, piexif
from PIL import Image, ImageOps


In [None]:
# --- Google APIs: gspread + Drive upload (Operations sync removed) ---

import google.auth
# OAuth scopes: Drive (folder lookup, upload, sharing) and Sheets (Gen Log appends).
SCOPES = ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/spreadsheets"]
# Application-default credentials; in Colab these presumably come from the
# auth.authenticate_user() call in the setup cell.
creds, _ = google.auth.default(scopes=SCOPES)

import gspread
# Sheets client used to append Gen Log rows.
gs = gspread.authorize(creds)

from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
# Drive v3 service used by the helpers below.
drive_svc = build("drive", "v3", credentials=creds)

# Drive MIME types for folders and shortcuts.
FOLDER_MIME   = "application/vnd.google-apps.folder"
SHORTCUT_MIME = "application/vnd.google-apps.shortcut"
# Colab mount prefix; path resolution below starts from the Drive folder id "root".
PATH_PREFIX   = "/content/drive/MyDrive/"

def _escape_name(name: str) -> str: return name.replace("'", r"'")

def _maybe_follow_shortcut(file_obj):
    """Return (id, mimeType) for a Drive file dict, resolving shortcuts to their target.

    For a shortcut, the target's id and MIME type are read from `shortcutDetails`;
    either may be None when absent.
    """
    if file_obj.get("mimeType") != SHORTCUT_MIME:
        return file_obj.get("id"), file_obj.get("mimeType")
    details = file_obj.get("shortcutDetails") or {}
    return details.get("targetId"), details.get("targetMimeType")

def _list_children(parent_id: str, q_extra: str, page_size: int = 1000):
    """List non-trashed children of Drive folder `parent_id`, following every result page.

    Args:
        parent_id: Drive folder id (or "root").
        q_extra: optional extra query clause, AND-ed with the parent filter.
        page_size: page size requested per files.list call.

    Returns:
        List of file dicts with id, name, mimeType and shortcutDetails.

    files.list returns at most one page per call plus a nextPageToken. The previous
    version read only the first page, so in large folders some files were invisible
    to existence checks and folder lookups.
    """
    q = f"'{parent_id}' in parents and trashed = false"
    if q_extra:
        q += f" and ({q_extra})"

    files, page_token = [], None
    while True:
        resp = drive_svc.files().list(
            q=q, spaces="drive", pageSize=page_size,
            fields="nextPageToken, files(id,name,mimeType,shortcutDetails)",
            includeItemsFromAllDrives=True, supportsAllDrives=True,
            pageToken=page_token,  # None on the first call is dropped by the client
        ).execute()
        files.extend(resp.get("files", []))
        page_token = resp.get("nextPageToken")
        if not page_token:
            return files

def _find_folder_id(parent_id: str, name: str):
    """Return the id of the folder called `name` inside Drive folder `parent_id`, or None.

    Resolution order:
      1. a real folder whose name matches exactly (via the Drive query);
      2. a shortcut with that exact name whose target is a folder (target id returned);
      3. any folder or shortcut-to-folder whose name matches after strip + casefold.
    """
    folder_or_shortcut = f"(mimeType = '{FOLDER_MIME}' or mimeType = '{SHORTCUT_MIME}')"

    exact = _list_children(
        parent_id,
        q_extra=f"name = '{_escape_name(name)}' and {folder_or_shortcut}",
    )
    for entry in exact:
        if entry["mimeType"] == FOLDER_MIME:
            return entry["id"]
    for entry in exact:
        if entry["mimeType"] != SHORTCUT_MIME:
            continue
        target_id, target_mime = _maybe_follow_shortcut(entry)
        if target_mime == FOLDER_MIME:
            return target_id

    # Fallback: whitespace/case-insensitive name match over all folders and shortcuts
    candidates = _list_children(parent_id, q_extra=folder_or_shortcut)
    needle = name.strip().casefold()
    for entry in candidates:
        if entry.get("name", "").strip().casefold() != needle:
            continue
        target_id, target_mime = _maybe_follow_shortcut(entry)
        if target_mime == FOLDER_MIME:
            return target_id

    return None


def _resolve_parent_id_and_filename_from_colab_path(colab_path: str):
    """Resolve a '/content/drive/MyDrive/...' file path to (parent_folder_id, file_name).

    Each intermediate folder name is looked up with _find_folder_id, starting from
    My Drive's "root". Empty path segments (leading/trailing or doubled slashes)
    are ignored.

    Raises:
        ValueError: if the path does not start with PATH_PREFIX or names no file.
        FileNotFoundError: if an intermediate folder cannot be found.
    """
    if not colab_path.startswith(PATH_PREFIX):
        raise ValueError(
            f"This helper supports only '{PATH_PREFIX}...'. Got: {colab_path}"
        )

    # str.split never returns [], so the old `if not parts` check could not fire:
    # PATH_PREFIX alone yielded an empty file name. Dropping empty segments fixes that.
    parts = [p for p in colab_path[len(PATH_PREFIX):].split("/") if p]
    if not parts:
        raise ValueError("Path must include a file name.")

    parent_id = "root"
    for part in parts[:-1]:
        next_id = _find_folder_id(parent_id, part)
        if not next_id:
            raise FileNotFoundError(f"Folder not found in path: '{part}'")
        parent_id = next_id

    desired_name = parts[-1]
    return parent_id, desired_name


def upload_to_drive_folder(
    local_path: str,
    parent_folder_id: str,
    desired_name: str | None = None
):
    """Upload `local_path` into Drive folder `parent_folder_id` and share it by link.

    The file is named `desired_name` (default: the local basename) and uploaded with
    a resumable upload. It then receives an "anyone" / "reader" permission.
    NOTE(review): that permission makes every uploaded image viewable by anyone with
    the link; confirm this is intended.

    Returns:
        The created file resource (fields: id, webViewLink, name, parents).
    """
    metadata = {
        "name": desired_name or os.path.basename(local_path),
        "parents": [parent_folder_id],
    }
    created = drive_svc.files().create(
        body=metadata,
        media_body=MediaFileUpload(local_path, resumable=True),
        fields="id, webViewLink, name, parents",
        supportsAllDrives=True,
    ).execute()

    public_read = {"type": "anyone", "role": "reader"}
    drive_svc.permissions().create(
        fileId=created["id"],
        body=public_read,
        fields="id",
        supportsAllDrives=True,
    ).execute()
    return created


def upload_file_and_append_to_sheet(
    local_path: str,
    target_colab_path: str,
    sku_name: str,
    angle: str,
    spreadsheet_id: str | None,
    worksheet_name: str | None,
):
    """Upload a generated image to Drive and optionally log it in a Google Sheet.

    The Drive parent folder and file name come from `target_colab_path` (a
    '/content/drive/MyDrive/...' path). When both `spreadsheet_id` and
    `worksheet_name` are given, one row is appended:
    [SKU hyperlink to the folder, angle, timestamp in TIMEZONE ("%m-%d %H:%M:%S"),
     file URL, random UUID, "Girls need to check", OPERATOR].

    Returns:
        {"file_url": the upload's webViewLink, or a constructed sharing URL}
    """
    parent_id, desired_name = _resolve_parent_id_and_filename_from_colab_path(
        target_colab_path
    )

    uploaded = upload_to_drive_folder(local_path, parent_id, desired_name)
    file_id = uploaded["id"]

    file_url = (
        uploaded.get("webViewLink")
        or f"https://drive.google.com/file/d/{file_id}/view?usp=sharing"
    )
    folder_url = f"https://drive.google.com/drive/folders/{parent_id}"

    if spreadsheet_id and worksheet_name:
        ts = datetime.now(pytz.timezone(TIMEZONE)).strftime("%m-%d %H:%M:%S")
        uid = str(uuid.uuid4())

        sh = gs.open_by_key(spreadsheet_id)
        ws = sh.worksheet(worksheet_name)

        # Sheets string literals escape '"' by doubling it. Without this, a quote in
        # the SKU name would end the literal and break the formula.
        label = str(sku_name).replace('"', '""')
        # NOTE(review): ';' as argument separator presumably matches the spreadsheet
        # locale (USER_ENTERED parses values as if typed) — confirm.
        sku_cell = f'=HYPERLINK("{folder_url}"; "{label}")'

        ws.append_row(
            [
                sku_cell,
                angle,
                ts,
                file_url,
                uid,
                "Girls need to check",
                OPERATOR,
            ],
            value_input_option="USER_ENTERED",
        )

    return {"file_url": file_url}


# BATCH HELPERS

In [None]:
# --- Batch processor  ---


def _run_tryon_stage(
    stage_base_full,
    stage_garment_img,
    *,
    sku_name: str,
    angle_code: str,
    garment_type: str,
    suffix: str,
    stage_label: str,
):
    """Run one crop → try-on → paste-back → upload pass and return the result.

    Steps:
      1. Locate the garment region on ``stage_base_full`` with DINOv3.
      2. Cut a padded square crop around that region.
      3. Have NanoBanana dress the crop in ``stage_garment_img``.
      4. Resize the generation back to the crop size.
      5. Paste it into a copy of the base image.
      6. Save the result as a PNG in ``/tmp`` and upload it to ``OUTPUT_DIR``,
         logging a row in ``GEN_LOG_SHEET``.

    Returns:
        The full-size composited image, so it can be the base of a follow-up
        stage (the secondary garment).
    """
    bbox, heatmap = detect_garment_region_with_dinov3(
        stage_base_full,
        garment_hint=garment_type,
        score_percentile=DINO_BOX_PERCENTILE,
        padding=CROP_PADDING,
    )
    bbox_preview = draw_bbox(stage_base_full, bbox, color=(255, 99, 71), width=12)
    heat_viz = heatmap_to_pil(heatmap)

    square_crop, square_box = make_square_crop_with_padding(stage_base_full, bbox, fill=WHITE_RGB)

    show_gallery(
        [bbox_preview, heat_viz, square_crop, stage_garment_img],
        [f"DINOv3 region [{stage_label}]", "DINO heatmap", "Square crop (padded)", "Garment (white BG)"]
    )

    tryon_gen = run_nanobanana_tryon(
        model_image=square_crop,
        garment_image=stage_garment_img,
        aspect_ratio="1:1",
        image_size=GEN_IMAGE_SIZE,
        prompt=TRYON_PROMPT,
    )

    # Ensure the generated image matches the square crop size (even if model
    # made it larger) so the paste-back lines up with ``square_box``.
    side = square_crop.size[0]
    tryon_sq = tryon_gen.resize((side, side), Image.Resampling.LANCZOS)

    final_img = paste_crop_back_square(
        full_img=stage_base_full.copy(),
        edited_crop=tryon_sq,
        crop_box=square_box,
        feather_px=MASK_FEATHER_PX,
    )
    show_gallery([tryon_gen, final_img], [f"Generated try-on [{stage_label}]", f"Final paste-back [{stage_label}]"])

    out_name = build_output_filename(sku_name, angle_code, ext=".png", suffix=suffix)
    tmp_path = os.path.join("/tmp", out_name)

    save_png_with_metadata(final_img, tmp_path, details_payload=None)
    target_path_for_drive = os.path.join(OUTPUT_DIR, out_name)

    info = upload_file_and_append_to_sheet(
        local_path=tmp_path,
        target_colab_path=target_path_for_drive,
        sku_name=sku_name,
        angle=angle_code,
        spreadsheet_id=SPREADSHEET_ID,
        worksheet_name=GEN_LOG_SHEET,
    )
    print(f"      ✅ Uploaded [{stage_label}] → {info['file_url']}")
    return final_img


def process_one_garment_folder(folder_path: str, allowed_angles=None):
    """Generate try-on images for the garment photos in one SKU folder.

    A file in ``folder_path`` is eligible when it:
      * starts with one of the source prefixes derived from
        ``allowed_angles``;
      * contains none of ``SKIP_FILENAME_TOKENS``;
      * contains ``"cut"``, if ``REQUIRE_CUT_IN_FILENAME`` is set;
      * has a valid extension.

    For each eligible file:
      * work out the target angle(s) it feeds;
      * find the matching base photo in the resolved base dir;
      * run ``_run_tryon_stage`` for the main garment;
      * when ``SECONDARY_GARMENT`` is set, run it a second time on top of
        that result for the secondary garment.

    Targets whose final output already exists in ``OUTPUT_DIR`` are skipped.

    Args:
        folder_path: Colab path of the SKU garment folder; its basename is the
            SKU name used for output files and sheet rows.
        allowed_angles: Angle codes to generate (e.g. ``ALLOWED_BASES``). When
            falsy, every eligible file is processed and its normalised stem is
            used as the source angle.

    Per-image failures are printed and do not stop the rest of the folder.
    """
    allowed_outputs_set = {_norm_angle(a) for a in (allowed_angles or [])}
    allowed_outputs = sorted(allowed_outputs_set)
    # Filename prefixes accepted as sources for the requested outputs
    # (None means "accept every file").
    allowed_sources = expand_as_list(allowed_angles) if allowed_angles else None

    base_subcat_dir = resolve_base_mask_dir(folder_path)
    if not base_subcat_dir:
        print(f"⚠️ Cannot resolve base dir for SKU: {folder_path}")
        # Every base-image lookup below would be made against None; stop here.
        return
    files_sorted = sorted(os.listdir(folder_path))

    garment_type = infer_garment_type_from_path(folder_path) or "garment"
    sku_name = os.path.basename(folder_path)

    worklist = []
    for file in files_sorted:
        low = file.lower()

        if allowed_sources and not any(low.startswith(src) for src in allowed_sources):
            continue
        if SKIP_FILENAME_TOKENS and any(tok in low for tok in SKIP_FILENAME_TOKENS):
            continue
        if REQUIRE_CUT_IN_FILENAME and ("cut" not in low):
            continue
        if not _valid_ext(file):
            continue

        # The first filter guarantees at least one match when sources are set.
        if allowed_sources:
            matching_sources = [src for src in allowed_sources if low.startswith(src)]
        else:
            matching_sources = [_norm_angle(os.path.splitext(file)[0])]

        # A source feeds a target when it belongs to that target's angle family.
        target_candidates = set()
        for src in matching_sources:
            norm_src = _norm_angle(src)
            for target in (allowed_outputs or [norm_src]):
                fam = {_norm_angle(x) for x in expand_allowed_angles([target])}
                fam.add(_norm_angle(target))
                if norm_src in fam:
                    target_candidates.add(_norm_angle(target))

        if allowed_outputs and not target_candidates:
            continue

        for target_angle in sorted(target_candidates):
            base_img_path = _find_image_with_stem_and_suffix(base_subcat_dir, target_angle)
            if not base_img_path:
                print(f"⚠️ Missing BASE for target '{target_angle}' → skipping {file}")
                continue

            worklist.append((file, target_angle, base_img_path))

    if allowed_sources:
        print(f"▶️  {sku_name}: {len(worklist)} image(s) to generate (outputs={allowed_outputs}, sources={sorted(set(allowed_sources))})")
    else:
        print(f"▶️  {sku_name}: {len(worklist)} image(s) to generate")
    print(f"   Garment type guess (from path): {garment_type}")

    if not worklist:
        return

    main_suffix = "_onlymain" if SECONDARY_GARMENT else ""

    for idx, (file, target_angle, base_img_path) in enumerate(worklist, start=1):
        print(f"   {idx:>3}/{len(worklist):<3}  {file}  | base='{target_angle}' via DINOv3")
        garment_path = os.path.join(folder_path, file)

        angle_code = _norm_angle(target_angle)
        final_out_name = build_output_filename(sku_name, angle_code, ext=".png", suffix="_both") if SECONDARY_GARMENT else None
        main_out_name = build_output_filename(sku_name, angle_code, ext=".png", suffix=main_suffix)
        # Skip on the *last* output of the pipeline: "_both" when a secondary
        # garment is layered on, otherwise the main output.
        skip_name = final_out_name or main_out_name
        dest_check = os.path.join(OUTPUT_DIR, skip_name)

        if drive_file_exists_any_ext_at_colab_path(dest_check):
            stage_label = "secondary" if final_out_name else "main"
            print(f"      ⏭️  Skip: {skip_name} already exists in {OUTPUT_DIR} ({stage_label} target)")
            continue

        try:
            garment_img = flatten_alpha_to_white(open_upright(garment_path))
            base_full = Image.open(base_img_path).convert("RGB")

            main_result = _run_tryon_stage(
                base_full,
                garment_img,
                sku_name=sku_name,
                angle_code=angle_code,
                garment_type=garment_type,
                suffix=main_suffix,
                stage_label="main",
            )

            if SECONDARY_GARMENT:
                sec_garment_path = find_secondary_garment_path(folder_path, file)
                if not sec_garment_path:
                    print(f"      ⚠️ Secondary garment missing for '{target_angle}' → kept main-only output.")
                    continue

                sec_garment_img = flatten_alpha_to_white(open_upright(sec_garment_path))

                # The secondary garment is applied on top of the main result.
                _run_tryon_stage(
                    main_result,
                    sec_garment_img,
                    sku_name=sku_name,
                    angle_code=angle_code,
                    garment_type=garment_type,
                    suffix="_both",
                    stage_label="secondary",
                )

        except Exception as e:
            print(f"      ❌ Error: {e}")


In [None]:
# --- SKU folder index + list runner (angles come from ALLOWED_BASES) ---

def build_sku_folder_index(garments_root: str):
    """Return ``{normalised SKU folder name: folder path}`` for ``garments_root``.

    Folders come from ``iter_sku_folders``; if two folders normalise to the
    same key, the later one wins.
    """
    index = {}
    for folder in iter_sku_folders(garments_root):
        key = _norm_sku(os.path.basename(folder))
        index[key] = folder
    return index

def run_list():
    """Process every SKU listed in ``SKU_CSV`` under ``GARMENTS_ROOT``.

    Each resolved folder is passed to ``process_one_garment_folder`` with
    ``ALLOWED_BASES``. Errors are reported per SKU and do not abort the run.
    Identifiers that matched no folder are listed at the end.
    """
    targets, unmatched = resolve_targets(SKU_CSV, GARMENTS_ROOT)

    if not targets:
        print("⚠️ No matching SKU folders found.")
        if unmatched:
            print("Unmatched:", ", ".join(unmatched))
        return

    total = len(targets)
    print(f"➡️  Will process {total} SKU(s).")

    for position, sku_path in enumerate(targets, start=1):
        sku_label = os.path.basename(sku_path)
        print(f"\nSKU {position}/{total} ▶️  {sku_label}")
        try:
            process_one_garment_folder(sku_path, allowed_angles=ALLOWED_BASES)
            print(f"✅ Finished: {sku_label}")
        except Exception as exc:
            print(f"❌ Error in {sku_label}: {exc}")

    if unmatched:
        print("\nℹ️  Unmatched identifiers:")
        for identifier in unmatched:
            print("   -", identifier)
    print("\n🏁 List run complete.")


# DISPATCH

In [None]:
# Dispatch: RUN_MODE only offers "sku_list", so always run the SKU-list batch.
run_list()


#UNASSIGN

In [None]:
# Release this Colab runtime once the batch has finished.
# NOTE: anything kept only on the VM (e.g. files in /tmp) is lost afterwards.
from google.colab import runtime
runtime.unassign()