# VTON v2 — Nanobanana Pro (Google Cloud API)

This notebook rebuilds the v1 workflow using Google's nanobanana pro image model. It keeps the SKU + angle list flow, Google Drive / gspread ops sync, source search, and the crop/pasteback logic with square 1:1 crops. Masks, LoRA, GPT scoring, and HYPIR are removed.


## Setup
Install the Google Gen AI SDK (`google-genai`) plus the Drive/Sheets client libraries, then authenticate the Colab user and mount Drive. Store `GEMINI_API_KEY` as a Colab secret; it is read with `userdata.get`.


In [None]:
# --- Setup: Google auth + Drive + deps ---
# `!pip` is IPython shell magic (Colab only): installs the Gen AI SDK plus the gspread / Google API clients.
!pip install -q -U "google-genai>=1.40.0" gspread google-auth google-auth-oauthlib google-auth-httplib2 google-api-python-client

import os, sys, textwrap, json, pathlib, typing
from google.colab import auth, drive, userdata

# Authenticate the Colab user; the Google APIs cell later obtains credentials via google.auth.default().
auth.authenticate_user()
drive.mount('/content/drive')

# API key from Colab secrets. `or ''` guards against a None value, since os.environ values must be str.
GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY or ''


## Select angles
Toggle which base angles should be produced for each SKU.


In [None]:
fr_lft = True #@param {type:"boolean"}
fr_rght = True #@param {type:"boolean"}
fr_cl = True #@param {type:"boolean"}
bc_lft = True #@param {type:"boolean"}
bc_rght = True #@param {type:"boolean"}
lft = True #@param {type:"boolean"}
rght = True #@param {type:"boolean"}
bc_ = True #@param {type:"boolean"}
fr_ = True #@param {type:"boolean"}
fr_cl_btm = False #@param {type:"boolean"}
fr_cl_tp = False #@param {type:"boolean"}

# Names of the form flags above, in processing order; each must match a variable defined above.
names = ["fr_lft","fr_rght","fr_cl","bc_lft","bc_rght","lft","rght","bc_","fr_","fr_cl_btm","fr_cl_tp"]
# Look the flags up in globals(), not locals(): on Python < 3.12 a list comprehension
# runs in its own scope, so locals() there only holds the loop variable and the lookup
# raised KeyError. At notebook top level globals() is the namespace holding the flags.
ALLOWED_BASES = [n for n in names if globals()[n]]


## Config
Only SKU list mode is kept. Adjust roots for garments, base/model photos, and output destinations.


In [None]:
# --- Unified CONFIG ---
# Only the SKU-list mode exists; the Dispatch cell raises for any other value.
RUN_MODE = "sku_list"
# Comma-separated SKUs; rewritten to "SS-#####" form in the Utilities cell.
SKU_CSV = "28748, 28920"  #@param {type:"string"}

# Garment folders are searched recursively under GARMENTS_ROOT; base/model photos
# are looked up in BASE_ROOT/<garment folder name>.
GARMENTS_ROOT = "/content/drive/MyDrive/Dazzl/Garments"  #@param {type:"string"}
BASE_ROOT     = "/content/drive/MyDrive/Dazzl/BasePhotos"  #@param {type:"string"}
OUTPUT_DIR    = "/content/drive/MyDrive/Dazzl/vton_v2_outputs"  #@param {type:"string"}

# Google Sheets / Drive
# Leave GOOGLE_SHEET_ID empty to skip operations logging, and
# DRIVE_UPLOAD_PARENT_ID empty to skip Drive uploads.
GOOGLE_SHEET_ID = ""  #@param {type:"string"}
OPERATIONS_SHEET_NAME = "operations"
DRIVE_UPLOAD_PARENT_ID = ""  #@param {type:"string"}

# File filters + cropping
VALID_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]
IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints", "thumbnails"}  # compared case-insensitively
CROP_MARGIN = 200   # pixels away from garment when cropping
TARGET_ASPECT = "1:1"  # passed as ImageConfig.aspect_ratio
IMAGE_SIZE = "4K"  # passed as ImageConfig.image_size

# Prompt used for nanobanana pro try-on (sent before the model and garment images)
TRYON_PROMPT = textwrap.dedent("""\
You are an expert virtual try-on AI. You will be given a 'model image' and a 'garment image'. Your task is to create a new photorealistic image where the person from the 'model image' is wearing the clothing from the 'garment image'.

Crucial Rules:
1.  Complete Garment Replacement: You MUST completely REMOVE and REPLACE the clothing item worn by the person in the 'model image' with the new garment. No part of the original clothing (e.g., collars, sleeves, patterns) should be visible in the final image.
2.  Preserve the Model: The person's face, hair, body shape, and pose from the 'model image' MUST remain unchanged.
3.  Preserve the Background: The entire background from the 'model image' MUST be preserved perfectly.
4.  Apply the Garment: Realistically fit the new garment onto the person. It should adapt to their pose with natural folds, shadows, and lighting consistent with the original scene.
5.  Output: Return ONLY the final, edited image. Do not include any text.""")


## Utilities
SKU normalization, angle helpers, source searching, and crop helpers.


In [None]:
import os, re, fnmatch, math, uuid, pytz, random, gc, tempfile, traceback
from datetime import datetime
import numpy as np
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

# Normalize SKU list to SS-##### format
def normalize_sku_list(sku_csv: str) -> str:
    """Rewrite a comma-separated SKU list as "SS-<digits>, SS-<digits>, ...".

    Each entry contributes only its first run of digits (so "ss-28748" and
    "28748" both become "SS-28748"); entries with no digits are dropped.
    """
    digit_run = re.compile(r'(\d+)')
    found = (digit_run.search(entry.strip().upper()) for entry in sku_csv.split(','))
    return ", ".join(f"SS-{m.group(1)}" for m in found if m)

# Replace the form value with its canonical "SS-#####" form; entries without digits are dropped.
SKU_CSV = normalize_sku_list(SKU_CSV)

# Parsing helpers
def _parse_csv_list(s):
    return [x.strip().casefold() for x in (s or "").split(',') if x.strip()]

def _norm_sku(s):
    if s is None: return ""
    s = str(s).replace(" "," ")
    s = " ".join(s.split())
    return s.casefold()

def _norm_angle(s):
    s = (s or "").strip().lower()
    return s.strip("_ ").replace("-", "_")

# Extra angle codes added whenever the key angle is requested (see
# expand_allowed_angles): selecting "fr_cl" also processes "fr" ("fr_" normalizes to "fr").
ANGLE_ALIASES = {
    "fr_cl": ["fr", "fr_"],
}

def expand_allowed_angles(angles):
    """Return the set of normalized angle codes in *angles* plus their ANGLE_ALIASES entries.

    Aliases are looked up by the normalized code and normalized themselves.
    None or an empty iterable yields an empty set.
    """
    result = set()
    for raw in angles or []:
        key = _norm_angle(raw)
        result.add(key)
        result.update(_norm_angle(alias) for alias in ANGLE_ALIASES.get(key, []))
    return result

# Lower-case once so the directory checks below can compare names case-insensitively.
IGNORE_DIRS = {d.lower() for d in IGNORE_DIRS}

def _is_image_file(name: str) -> bool:
    """True when *name* ends with one of VALID_EXTENSIONS, compared case-insensitively."""
    lowered = name.lower()
    return any(lowered.endswith(ext.lower()) for ext in VALID_EXTENSIONS)

def _is_sku_folder(path: str) -> bool:
    """True if *path* directly contains at least one image file and is not an ignored directory.

    Any error while listing or inspecting the folder is treated as "not a SKU folder".
    """
    folder_name = os.path.basename(os.path.normpath(path)).lower()
    if folder_name in IGNORE_DIRS:
        return False
    try:
        return any(
            os.path.isfile(os.path.join(path, entry)) and _is_image_file(entry)
            for entry in os.listdir(path)
        )
    except Exception:
        return False

# Walkers
def iter_sku_folders(root: str):
    """Yield every directory under *root* (including *root*) that directly contains an image file.

    Ignored directories are pruned in place, so os.walk never descends into them.
    """
    for current, subdirs, files in os.walk(root):
        subdirs[:] = [name for name in subdirs if name.lower() not in IGNORE_DIRS]
        has_image = any(_is_image_file(name) for name in files)
        if has_image:
            yield current

def resolve_targets(idents_csv: str, garments_root: str):
    """Resolve SKU identifiers to garment folders that contain images.

    *idents_csv* is split on ',' or ';'. Each identifier may be:
      * an absolute path (used as-is);
      * an fnmatch glob, matched against each image folder's path relative to
        *garments_root* and against its basename;
      * a folder name, matched case-insensitively against every folder under
        *garments_root* (this also finds SKU folders whose images live only in
        nested sub-folders).
    A matched folder without direct images is expanded to the image folders below it.

    Returns:
        (targets, unmatched): de-duplicated absolute folder paths in match order,
        and the identifiers (or absolute paths) that matched nothing.
    """
    idents = [s.strip() for s in idents_csv.replace(';', ',').split(',') if s.strip()]
    if not idents:
        return [], []

    # One walk builds both the list of image-bearing folders and a basename index
    # (walking a mounted Drive is slow, so avoid walking the tree twice).
    sku_dirs = []
    name_index = {}
    for dirpath, dirnames, filenames in os.walk(garments_root):
        dirnames[:] = [d for d in dirnames if d.lower() not in IGNORE_DIRS]
        name_index.setdefault(os.path.basename(dirpath).lower(), []).append(dirpath)
        if any(_is_image_file(f) for f in filenames):
            sku_dirs.append(dirpath)

    seen, out, unmatched = set(), [], []

    def add_path(p):
        # Record an image folder once, or expand a non-image folder to its image leaves.
        ap = os.path.abspath(p)
        if os.path.isdir(ap):
            if _is_sku_folder(ap):
                if ap not in seen:
                    seen.add(ap); out.append(ap)
            else:
                for leaf in iter_sku_folders(ap):
                    a = os.path.abspath(leaf)
                    if a not in seen:
                        seen.add(a); out.append(a)
        else:
            unmatched.append(p)

    for ident in idents:
        if os.path.isabs(ident):
            add_path(ident)
            continue

        matched_any = False
        # Glob against the real folder paths. Re-joining a matched basename onto
        # garments_root would point at a non-existent folder for nested SKUs and
        # wrongly report it as unmatched.
        for p in sku_dirs:
            rel = os.path.relpath(p, garments_root)
            if fnmatch.fnmatch(rel, ident) or fnmatch.fnmatch(os.path.basename(p), ident):
                add_path(p); matched_any = True

        for dirpath in name_index.get(ident.lower(), []):
            add_path(dirpath); matched_any = True

        if not matched_any:
            unmatched.append(ident)

    return out, unmatched

# Image helpers
# Background colour used when flattening transparency and when padding squares.
WHITE_RGB = (255,255,255)

def flatten_alpha_to_white(img: Image.Image) -> Image.Image:
    """Return an RGB version of *img* with any transparency composited over white.

    Images with an alpha band (RGBA, LA, PA) or a "transparency" entry in
    ``img.info`` (palette or colour-key transparency) are first converted to RGBA,
    so the paste mask really is an alpha channel. Other images are simply
    converted to RGB.
    """
    if "A" in img.getbands() or "transparency" in img.info:
        # The last band of a P/L/RGB image is not alpha; converting to RGBA turns
        # palette/colour-key transparency into a real alpha band first.
        rgba = img.convert("RGBA")
        bg = Image.new("RGB", rgba.size, WHITE_RGB)
        bg.paste(rgba, mask=rgba.getchannel("A"))
        return bg
    return img.convert("RGB")


def _tight_bbox_nonwhite_or_opaque(img: Image.Image):
    """Tight foreground box (x0, y0, x1, y1) of *img*, with exclusive x1/y1.

    Foreground means alpha > 0 when the image is RGBA/LA or carries a
    "transparency" entry; otherwise any pixel that is not pure white
    (255, 255, 255). Returns None when no foreground pixel exists.
    """
    has_alpha = img.mode in ("RGBA", "LA") or "transparency" in img.info
    if has_alpha:
        mask = np.asarray(img.convert("RGBA"))[..., 3] > 0
    else:
        rgb = np.asarray(img.convert("RGB"))
        mask = (rgb != 255).any(axis=-1)
    if not mask.any():
        return None
    rows = np.flatnonzero(mask.any(axis=1))
    cols = np.flatnonzero(mask.any(axis=0))
    return (int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1)


def crop_square_with_margin(img: Image.Image, margin: int = CROP_MARGIN):
    """Crop *img* to a (near-)square box around its foreground, padded by *margin* px.

    The crop is taken from a white-flattened copy of *img*; the foreground box
    comes from _tight_bbox_nonwhite_or_opaque on the original *img*.
    *margin* defaults to the CROP_MARGIN value at definition time.

    Returns:
        (crop, box) with box = (x0, y0, x1, y1) in *img* pixel coordinates.
        The box is clamped to the image bounds, so it is only square when the
        image is large enough; process_one_garment_folder pads the crop with
        to_centered_square afterwards.
    """
    base = flatten_alpha_to_white(img)
    bbox = _tight_bbox_nonwhite_or_opaque(img)
    if bbox is None:
        # No foreground found: fall back to the largest centred square.
        w,h = base.size
        side = min(w,h)
        box = ((w-side)//2, (h-side)//2, (w+side)//2, (h+side)//2)
        return base.crop(box), box
    x0,y0,x1,y1 = bbox
    # Grow the tight box by the margin, clamped to the image.
    x0 = max(0, x0 - margin); y0 = max(0, y0 - margin)
    x1 = min(base.width, x1 + margin); y1 = min(base.height, y1 + margin)
    bw, bh = x1 - x0, y1 - y0
    # Square side = longer side of the padded box, centred on the padded box's centre.
    side = max(bw, bh)
    cx, cy = x0 + bw//2, y0 + bh//2
    half = side//2
    x0 = max(0, cx - half); y0 = max(0, cy - half)
    x1 = min(base.width, x0 + side); y1 = min(base.height, y0 + side)
    # If the right/bottom edge was clamped, shift the box back left/up to keep its size.
    x0 = max(0, x1 - side); y0 = max(0, y1 - side)
    box = (int(x0), int(y0), int(x1), int(y1))
    return base.crop(box), box


def to_centered_square(gar: Image.Image, fill=WHITE_RGB) -> Image.Image:
    """Pad *gar* to a square RGB canvas of side max(width, height), centred, filled with *fill*."""
    width, height = gar.size
    edge = width if width >= height else height
    canvas = Image.new("RGB", (edge, edge), fill)
    offset = ((edge - width) // 2, (edge - height) // 2)
    canvas.paste(gar, offset)
    return canvas


def expand_as_list(angles):
    """Expand *angles* (plus ANGLE_ALIASES) into a de-duplicated list of normalized codes.

    Longest codes come first, so specific angles such as "fr_cl" are listed before
    generic ones such as "fr". Equal-length codes are ordered alphabetically, so the
    processing order is the same on every run. Iterating a set of str is not stable
    across runs, because string hashing is randomized per process.
    """
    codes = {_norm_angle(a) for a in expand_allowed_angles(angles)}
    return sorted(codes, key=lambda code: (-len(code), code))


def pick_target_angle(source_angle: str, allowed_outputs: set) -> str | None:
    """Map a source angle to the allowed output angle whose alias family contains it.

    A target equal to the normalized source angle wins. Otherwise targets are tried
    in alphabetical order, so the answer no longer depends on set iteration order.
    Returns the normalized target, or None when nothing matches (including an
    empty or None *allowed_outputs*).
    """
    s = _norm_angle(source_angle)
    ordered = sorted(allowed_outputs or (), key=lambda t: (_norm_angle(t) != s, _norm_angle(t)))
    for target in ordered:
        family = {_norm_angle(x) for x in expand_allowed_angles([target])}
        if s in family:
            return _norm_angle(target)
    return None


def open_upright(path) -> Image.Image:
    """Load *path* as an upright RGB image.

    EXIF orientation is applied first, while the original image and its metadata
    are intact. Any transparency (alpha band or ``info["transparency"]``) is then
    composited over white instead of being dropped. A bare ``convert("RGB")``
    would turn transparent PNG cut-outs into black/garbage backgrounds, which
    defeats the white-background bbox used for cropping.
    """
    with Image.open(path) as im:
        # exif_transpose returns a new, fully loaded image, so it survives closing the file.
        upright = ImageOps.exif_transpose(im)
        if "A" in upright.getbands() or "transparency" in upright.info:
            rgba = upright.convert("RGBA")
            flat = Image.new("RGB", rgba.size, WHITE_RGB)
            flat.paste(rgba, mask=rgba.getchannel("A"))
            return flat
        return upright.convert("RGB")


def show_gallery(img_list, titles=None, cols=3, w=4):
    """
    Display PIL images in a grid with *cols* columns (same style as v1).

    Each cell is roughly *w* x *w* inches; images are EXIF-transposed for display,
    axes are hidden, and a title is drawn only when the matching entry in *titles*
    is truthy. An empty/None *titles* means no titles.
    """
    labels = titles or [None] * len(img_list)
    n_rows = math.ceil(len(img_list) / cols)
    plt.figure(figsize=(cols * w, n_rows * w))
    for slot, (picture, label) in enumerate(zip(img_list, labels), start=1):
        plt.subplot(n_rows, cols, slot)
        plt.imshow(ImageOps.exif_transpose(picture))
        if label:
            plt.title(label)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


## Paste-back helper
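`paste_crop_back` resizes an edited crop to its crop box and pastes it into a copy of the full image, with no mask. Crop boxes come from `crop_square_with_margin`, which pads the garment by `CROP_MARGIN` (200 px by default).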


In [None]:
from PIL import ImageChops

def paste_crop_back(full_img: Image.Image, edited_crop: Image.Image, crop_box):
    """Paste *edited_crop* into a copy of *full_img* at *crop_box* = (x0, y0, x1, y1).

    The crop is LANCZOS-resized to the box size first when the sizes differ. No mask
    is used, so the whole box is overwritten; *full_img* itself is not modified.
    """
    left, top, right, bottom = crop_box
    box_size = (right - left, bottom - top)
    if edited_crop.size == box_size:
        patch = edited_crop
    else:
        patch = edited_crop.resize(box_size, resample=Image.Resampling.LANCZOS)
    result = full_img.copy()
    result.paste(patch, (left, top))
    return result


## Google APIs
Authorize gspread and Drive upload + operations logging.


In [None]:
import google.auth
# Drive + Sheets scopes for the Colab user credentials (authenticated in the Setup cell).
SCOPES = ["https://www.googleapis.com/auth/drive", "https://www.googleapis.com/auth/spreadsheets"]
creds, _ = google.auth.default(scopes=SCOPES)

import gspread
# Sheets client used by log_operation.
gs = gspread.authorize(creds)

from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
# Drive v3 service used by upload_to_drive.
drive_svc = build("drive", "v3", credentials=creds)

from pathlib import Path

def ensure_dir(p):
    """Create directory *p* (and any missing parents) if needed, then hand *p* back unchanged."""
    target = p
    os.makedirs(target, exist_ok=True)
    return target

# Create the local output folder up front (no-op if it already exists).
ensure_dir(OUTPUT_DIR)

def build_output_filename(sku_name: str, angle_code: str, ext=".png") -> str:
    """Return "<sku_name>-<normalized angle><ext>", e.g. ("SS-1", "bc_") -> "SS-1-bc.png"."""
    return "{}-{}{}".format(sku_name, _norm_angle(angle_code), ext)

def upload_to_drive(local_path: str, parent_id: str | None = None):
    if not parent_id:
        return None
    file_metadata = {"name": os.path.basename(local_path)}
    if parent_id:
        file_metadata["parents"] = [parent_id]
    media = MediaFileUpload(local_path, mimetype="image/png")
    created = drive_svc.files().create(body=file_metadata, media_body=media, fields="id,webViewLink").execute()
    return created

def log_operation(row):
    """Append *row* to the OPERATIONS_SHEET_NAME worksheet of GOOGLE_SHEET_ID.

    No-op (returns None) when GOOGLE_SHEET_ID is empty. The worksheet is created
    (1000 rows x 10 cols) only if it does not exist. Values are written with
    USER_ENTERED, so Sheets parses them as if typed by a user.
    """
    if not GOOGLE_SHEET_ID:
        return None
    sh = gs.open_by_key(GOOGLE_SHEET_ID)
    try:
        ws = sh.worksheet(OPERATIONS_SHEET_NAME)
    except gspread.exceptions.WorksheetNotFound:
        # Only a genuinely missing tab is created. Auth/quota/network errors now
        # propagate instead of being misread as "missing" and triggering add_worksheet.
        ws = sh.add_worksheet(title=OPERATIONS_SHEET_NAME, rows=1000, cols=10)
    ws.append_row(row, value_input_option="USER_ENTERED")


## Nanobanana pro client and generation
Call the Gemini API through the `google-genai` client (authenticated with `GEMINI_API_KEY`), requesting a 1:1 aspect ratio (`TARGET_ASPECT`) and 4K output (`IMAGE_SIZE`).


In [None]:
from google import genai
from google.genai import types
from PIL import Image as PILImage

# Gemini API client; the key was placed in the environment by the Setup cell.
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

# "Nanobanana pro" model id used for every try-on request.
PRO_MODEL_ID = "gemini-3-pro-image-preview"

def run_nanobanana_tryon(model_img: PILImage.Image, garment_img: PILImage.Image, *, extra_prompt: str = "") -> PILImage.Image:
    """Run one virtual try-on request and return the generated image.

    Sends TRYON_PROMPT (plus *extra_prompt*, if given), then the model image and
    the garment image, to PRO_MODEL_ID. The request asks for an image-only
    response at TARGET_ASPECT / IMAGE_SIZE.

    Returns:
        The first non-thought image part of the response, decoded to an RGB PIL
        image, so callers can display it with matplotlib and save it by extension.

    Raises:
        RuntimeError: the response holds no final image (e.g. it was blocked). The
            message includes the finish reason and prompt feedback when available.
    """
    import io  # local: only needed to decode the returned image bytes

    contents = [TRYON_PROMPT]
    if extra_prompt:
        contents.append(extra_prompt)
    contents.extend([model_img, garment_img])
    response = client.models.generate_content(
        model=PRO_MODEL_ID,
        contents=contents,
        config=types.GenerateContentConfig(
            response_modalities=["IMAGE"],
            image_config=types.ImageConfig(
                aspect_ratio=TARGET_ASPECT,
                image_size=IMAGE_SIZE
            )
        )
    )
    # response.parts is None when no candidate content came back (e.g. a safety block).
    for part in response.parts or []:
        # Thinking image models can emit interim "thought" images; return the final render only.
        if getattr(part, "thought", False):
            continue
        blob = getattr(part, "inline_data", None)
        if blob is None or not blob.data:
            continue
        if blob.mime_type and not blob.mime_type.startswith("image/"):
            continue
        # Decode the bytes ourselves: part.as_image() may return the SDK's own Image
        # wrapper rather than a PIL image, which ImageOps/matplotlib cannot handle.
        img = PILImage.open(io.BytesIO(blob.data))
        img.load()
        return img.convert("RGB")

    candidates = getattr(response, "candidates", None) or []
    finish_reason = getattr(candidates[0], "finish_reason", None) if candidates else None
    feedback = getattr(response, "prompt_feedback", None)
    raise RuntimeError(
        f"No image returned from nanobanana pro (finish_reason={finish_reason}, prompt_feedback={feedback})"
    )


## Batch helpers
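For each angle, find the garment and base/model source images, crop the garment to a square with a `CROP_MARGIN` (200 px) margin, generate a 1:1 try-on result, save it locally, and sync it to Drive/Sheets. The paste-back helper is defined above but is not called in this batch flow.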


In [None]:
# Every angle code this notebook uses, in _norm_angle form ("bc_" -> "bc",
# "fr_" -> "fr"). Used to recognise filenames that belong to a more specific
# angle (e.g. "fr_lft") which merely contains the requested one ("lft" or "fr").
KNOWN_ANGLE_CODES = (
    "fr_lft", "fr_rght", "fr_cl", "bc_lft", "bc_rght", "lft", "rght",
    "bc", "fr", "fr_cl_btm", "fr_cl_tp",
)


def _contains_token_run(tokens, run):
    """True if list *run* occurs as a contiguous slice of list *tokens*."""
    width = len(run)
    return any(tokens[i:i + width] == run for i in range(len(tokens) - width + 1))


def filename_matches_angle(name: str, angle_norm: str, known_angles=KNOWN_ANGLE_CODES) -> bool:
    """Decide whether file stem *name* is an image for the normalized angle *angle_norm*.

    A name is a candidate when one of its '_'-separated tokens equals *angle_norm*
    or it starts with *angle_norm*. A candidate is rejected when it also carries a
    longer known angle code that refines *angle_norm*. For example, "bc_lft" is not
    an "lft" image, "fr_rght" is not an "fr" image, and "fr_cl_btm" is not an
    "fr_cl" image.
    """
    name = name.lower()
    tokens = name.split('_')
    if not (angle_norm in tokens or name.startswith(angle_norm)):
        return False
    run = angle_norm.split('_')
    for other in known_angles or ():
        if len(other) <= len(angle_norm):
            continue
        other_tokens = other.split('_')
        if not _contains_token_run(other_tokens, run):
            continue  # unrelated code, e.g. "bc_rght" when looking for "lft"
        if _contains_token_run(tokens, other_tokens) or name.startswith(other):
            return False
    return True


def list_images_for_angle(folder: str, angle_code: str, known_angles=KNOWN_ANGLE_CODES):
    """Sorted paths of the image files directly inside *folder* that belong to *angle_code*.

    Matching uses the lower-cased file stem (see filename_matches_angle).
    *known_angles* is the vocabulary used to exclude more specific angles.
    """
    angle_norm = _norm_angle(angle_code)
    files = []
    for f in os.listdir(folder):
        if not _is_image_file(f):
            continue
        name = os.path.splitext(f)[0].lower()
        if filename_matches_angle(name, angle_norm, known_angles):
            files.append(os.path.join(folder, f))
    return sorted(files)

def find_first_image(folder: str, angle_code: str):
    """Path of the alphabetically first image in *folder* for *angle_code*, or None if there is none."""
    candidates = list_images_for_angle(folder, angle_code)
    if not candidates:
        return None
    return candidates[0]

def load_image(path: str):
    """Load *path* through open_upright (EXIF-aware loader); kept as the batch code's single entry point."""
    image = open_upright(path)
    return image

def process_one_garment_folder(folder_path: str, allowed_angles=None):
    """Generate, save and sync try-on images for one garment SKU folder.

    For each requested angle, the first matching garment image in *folder_path* is
    paired with the first matching base/model image in BASE_ROOT/<folder name>.
    The garment is cropped to a square with a CROP_MARGIN margin, padded to 1:1,
    and sent to nanobanana pro with the model image. The result is saved to
    OUTPUT_DIR, then uploaded to Drive and logged to Sheets when those are
    configured.

    Args:
        folder_path: garment folder; its basename is used as the SKU name.
        allowed_angles: angle codes to produce. ALLOWED_BASES is used when this is
            empty or None. ANGLE_ALIASES expansions are included.

    Returns:
        Local paths of the images that were saved. Angles with missing sources or
        a failed generation are skipped and reported via print. Drive/Sheets sync
        errors are reported but do not abort the run.
    """
    from datetime import timezone  # local: only needed for the UTC timestamp

    allowed_outputs = {_norm_angle(a) for a in (allowed_angles or [])}
    sku_name = os.path.basename(folder_path)
    base_folder = os.path.join(BASE_ROOT, sku_name)
    if not os.path.isdir(base_folder):
        print(f"⚠️  {sku_name}: no base folder for {base_folder}")
        return []

    angles_to_process = expand_as_list(allowed_outputs or ALLOWED_BASES)
    worklist = []
    for ang in angles_to_process:
        garment_path = find_first_image(folder_path, ang)
        base_path = find_first_image(base_folder, ang)
        if not garment_path or not base_path:
            print(f"⏭️  {sku_name} {ang}: garment or base missing (garment={bool(garment_path)}, base={bool(base_path)})")
            continue
        worklist.append((ang, garment_path, base_path))

    print(f"▶️  {sku_name}: {len(worklist)} image(s) to generate (angles={angles_to_process})")
    if not worklist:
        return []

    results = []
    for idx, (ang, garment_path, base_path) in enumerate(worklist, start=1):
        print(f"   {idx:>3}/{len(worklist):<3} {ang} → garment={os.path.basename(garment_path)} | base={os.path.basename(base_path)}")
        garment_img_raw = load_image(garment_path)
        garment_crop, crop_box = crop_square_with_margin(garment_img_raw, CROP_MARGIN)
        garment_sq = to_centered_square(garment_crop)
        model_img = load_image(base_path)

        show_gallery(
            [garment_img_raw, garment_crop, garment_sq],
            ["Source garment", f"Crop ({CROP_MARGIN}px margin) {crop_box}", "Square garment for 1:1"]
        )
        show_gallery(
            [model_img, garment_sq],
            [f"Base/model [{ang}]", "Garment input to nanobanana pro"],
            cols=2,
            w=5
        )

        try:
            print("      🚀 Sending to nanobanana pro…")
            generated = run_nanobanana_tryon(model_img, garment_sq)
        except Exception as ex:
            print(f"      ❌ {sku_name} {ang}: generation failed → {ex}")
            continue

        show_gallery([generated], ["Nanobanana pro output"], cols=1, w=6)

        fname = build_output_filename(sku_name, ang)
        out_path = os.path.join(OUTPUT_DIR, fname)
        ensure_dir(os.path.dirname(out_path))
        generated.save(out_path)
        print(f"      💾 saved → {out_path}")
        # The local file exists from here on, so record it even if syncing fails.
        results.append(out_path)

        # Drive/Sheets sync is best-effort: a transient API error must not abort the
        # remaining angles and SKUs of the batch.
        try:
            drive_info = upload_to_drive(out_path, DRIVE_UPLOAD_PARENT_ID)
        except Exception as ex:
            print(f"      ⚠️ {sku_name} {ang}: Drive upload failed → {ex}")
            drive_info = None
        if drive_info:
            print(f"      ☁️ uploaded → {drive_info.get('webViewLink')}")

        # Naive-UTC ISO text, identical to what the deprecated datetime.utcnow().isoformat() produced.
        logged_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        link = (drive_info or {}).get('webViewLink', "")
        try:
            log_operation([sku_name, ang, os.path.relpath(out_path, OUTPUT_DIR), logged_at, link])
        except Exception as ex:
            print(f"      ⚠️ {sku_name} {ang}: operations log failed → {ex}")
    return results

def run_list():
    """Resolve SKU_CSV under GARMENTS_ROOT and run the try-on pipeline on every matched folder."""
    resolved, missing = resolve_targets(SKU_CSV, GARMENTS_ROOT)
    if missing:
        print("Unmatched identifiers:", missing)
    print(f"Resolved {len(resolved)} garment folders")
    for folder in resolved:
        process_one_garment_folder(folder, allowed_angles=ALLOWED_BASES)


## Dispatch
Trigger batch processing for the normalized SKU list.


In [None]:
# Entry point: reject unsupported modes, then process every SKU listed in SKU_CSV.
if RUN_MODE != "sku_list":
    raise ValueError("Only sku_list mode is supported in v2")
run_list()
