In [None]:
import os
import subprocess

# Clone the CatVTON repository once and make it the working directory.
# Later cells import `model` / `utils` and read `resource/...` relative to it.
# Guarded so that re-running this cell neither clones twice (git refuses to clone
# into an existing directory) nor descends into CatVTON/CatVTON.
REPO_DIR = "CatVTON"
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        subprocess.run(
            ["git", "clone", "https://github.com/Zheng-Chong/CatVTON", REPO_DIR],
            check=True,  # fail the cell loudly if the clone fails
        )
    os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")


In [None]:
!pip install -q \
 "torch>=2.1.0,<2.5.0" \
 "torchvision>=0.16.0,<0.20.0" \
 accelerate>=0.30.0 \
 diffusers>=0.27.0 \
 matplotlib>=3.8.0 \
 numpy>=1.25.0 \
 opencv_python>=4.9.0 \
 pillow>=10.0.0 \
 PyYAML>=6.0 \
 scipy>=1.11.0 \
 scikit-image>=0.22.0 \
 tqdm>=4.66.0 \
 transformers>=4.40.0 \
 fvcore>=0.1.5 \
 cloudpickle>=3.0.0 \
 omegaconf>=2.3.0 \
 pycocotools>=2.0.7 \
 av>=12.0.0 \
 gradio>=4.25.0 \
 peft>=0.11.0 \
 huggingface_hub>=0.20.0

In [None]:
!pip install pyngrok
from pyngrok import ngrok

# Check the output of the previous cell to confirm the Gradio port
gradio_port = 7860  # Or the port your Gradio app is running on

# Optional: Set your ngrok authtoken for more stable tunnels
ngrok.set_auth_token("YOUR_NGROQ_API_KEY")

try:
    public_url = ngrok.connect(gradio_port)
    print(f"Gradio interface is available at: {public_url}")
except Exception as e:
    print(f"Error connecting to ngrok: {e}")

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# --- Authentication ---
# Gated repositories such as black-forest-labs/FLUX.1-Fill-dev need an authenticated
# token to download. Read it from HF_TOKEN, or prompt, instead of hardcoding it.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")

try:
    login(token=hf_token)
    print("Hugging Face login successful!")
except Exception as e:
    print(f"Hugging Face login failed: {e}")
    # Raise rather than exit(): in a Jupyter kernel exit() asks the kernel to shut
    # down instead of just stopping this cell.
    raise RuntimeError("Hugging Face login failed; gated model downloads will not work.") from e

In [None]:
# --- Imports ---
import os
import gradio as gr
from datetime import datetime
import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from huggingface_hub import snapshot_download
from PIL import Image
import traceback # For printing full errors

# --- Check and Import Local Files ---
# `model` and `utils` are packages inside the CatVTON checkout, so these imports
# resolve only when the repository root is the working directory (or on sys.path).
# Import failures are reported but deliberately not fatal here; any later code that
# uses these names will then fail with NameError.
try:
    from model.cloth_masker import AutoMasker, vis_mask
    from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
    from utils import resize_and_crop, resize_and_padding
except ModuleNotFoundError as exc:
    print(
        f"ERROR: Failed to import local modules: {exc}\n"
        "Please ensure the 'model' and 'utils' directories containing the necessary .py files\n"
        "are uploaded to the same directory as this script in your Kaggle environment."
    )
except Exception as exc:
    print(f"An unexpected error occurred during local module import: {exc}")
else:
    print("Successfully imported local modules: cloth_masker, pipeline_flux_tryon, utils")

# --- Configuration (Replaces argparse) ---
class AppConfig:
    """Static run settings, read as attributes of the single ``args`` instance.

    Every value is a class attribute evaluated once at definition time. That includes
    ``device`` (CUDA when available, else CPU) and ``torch_dtype``, which is derived
    from ``mixed_precision``: "fp16" -> float16, "bf16" -> bfloat16, otherwise float32.
    """

    # Hugging Face repos: the FLUX fill base model, and the CatVTON repo that
    # provides the LoRA plus the DensePose/SCHP masker weights.
    base_model_path = "black-forest-labs/FLUX.1-Fill-dev"
    resume_path = "zhengchong/CatVTON"
    output_dir = "resource/demo/output"

    # fp16 for P100/T4 compatibility and lower memory use.
    # NOTE(review): FLUX is commonly run in bf16; consider "bf16" on GPUs that support it.
    mixed_precision = "fp16"
    allow_tf32 = False  # TF32 needs Ampere+, so it is irrelevant on P100/T4

    # Target generation resolution.
    width = 768
    height = 1024

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)


args = AppConfig()  # shared config object read by the functions below

print(
    "\n".join(
        (
            "--- Configuration ---",
            f"Base Model: {args.base_model_path}",
            f"Resume Path: {args.resume_path}",
            f"Target Device: {args.device}",
            f"Target dtype: {args.torch_dtype}",
            f"Target Resolution: {args.width}x{args.height}",
            "--------------------",
        )
    )
)

# --- Global Variables for Models (Load Once) ---
# Filled in by load_all_models_and_tools(); None means "not loaded yet".
pipeline_flux = automasker = mask_processor = None

# --- Utility Functions ---
def image_grid(imgs, rows, cols):
    """Paste PIL images into a ``rows`` x ``cols`` grid, filled row by row.

    Every cell has the size of ``imgs[0]``.

    Returns:
        None for an empty list. If ``len(imgs) != rows * cols``, a warning is printed
        and ``imgs[0]`` is returned unchanged. Otherwise returns a new RGB image.
        Items that are not PIL images are skipped, leaving their cell black.
    """
    if not imgs:
        return None

    expected = rows * cols
    if len(imgs) != expected:
        print(f"Warning: image_grid expected {expected} images, got {len(imgs)}")
        return imgs[0]

    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for idx, tile in enumerate(imgs):
        if not isinstance(tile, Image.Image):
            print(f"Warning: Item {idx} is not a PIL image, skipping.")
            continue
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas

# --- Model Loading Function (Called Once) ---
def load_all_models_and_tools():
    """Download CatVTON assets and build the global FLUX pipeline, masker and mask processor.

    Side effects:
        Assigns the module globals ``pipeline_flux``, ``automasker`` and ``mask_processor``.

    Raises:
        Re-raises any exception after logging it, so the calling cell stops.
    """
    global pipeline_flux, automasker, mask_processor
    print("--- Starting Model Loading ---")
    try:
        # CatVTON repo: FLUX LoRA plus the DensePose / SCHP checkpoints used by AutoMasker.
        print(f"Downloading CatVTON resume files from: {args.resume_path}")
        repo_path = snapshot_download(repo_id=args.resume_path)
        print(f"CatVTON files downloaded to: {repo_path}")

        # low_cpu_mem_usage reduces the host-RAM spike while the weights are loaded.
        print(f"Loading FLUX pipeline from: {args.base_model_path}")
        pipeline_flux = FluxTryOnPipeline.from_pretrained(
            args.base_model_path,
            torch_dtype=args.torch_dtype,
            low_cpu_mem_usage=True,
        )
        print("FLUX pipeline loaded.")

        print("Loading CatVTON LoRA weights for FLUX...")
        lora_path = os.path.join(repo_path, "flux-lora")
        lora_weight_name = "pytorch_lora_weights.safetensors"
        if os.path.isfile(os.path.join(lora_path, lora_weight_name)):
            pipeline_flux.load_lora_weights(lora_path, weight_name=lora_weight_name)
            print("LoRA weights loaded.")
        else:
            print(f"Warning: LoRA path or weight file not found at {lora_path}. Skipping LoRA.")

        # Placement. With model CPU offload each sub-model is moved to the GPU only
        # while it runs. Do NOT call .to("cuda") first: that would put the entire FLUX
        # pipeline in VRAM at once (likely OOM on 16 GB cards such as T4/P100) before
        # offloading moves it back to the CPU. Offload also needs a GPU, so on CPU we
        # simply place the pipeline there.
        if args.device == "cuda":
            print("Enabling CPU model offload...")
            pipeline_flux.enable_model_cpu_offload()
            print("CPU offload enabled.")
        else:
            print(f"Moving pipeline to {args.device} with dtype {args.torch_dtype}...")
            pipeline_flux.to(args.device)
            print("Pipeline moved to device.")

        print("Initializing AutoMasker...")
        # Use the pipeline's VAE scale factor when it exposes one, else 8.
        mask_processor = VaeImageProcessor(
            vae_scale_factor=getattr(pipeline_flux, "vae_scale_factor", 8),
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True,
        )
        densepose_path = os.path.join(repo_path, "DensePose")
        schp_path = os.path.join(repo_path, "SCHP")
        for label, ckpt_path in (("DensePose", densepose_path), ("SCHP", schp_path)):
            if not os.path.exists(ckpt_path):
                print(f"Warning: {label} path not found at {ckpt_path}")

        automasker = AutoMasker(
            densepose_ckpt=densepose_path,
            schp_ckpt=schp_path,
            device=args.device,
        )
        print("AutoMasker initialized.")
        print("--- Model Loading Complete ---")

    except ImportError as e:
        print(f"ERROR during loading: Failed to import a required module: {e}")
        print("This likely means 'model' or 'utils' directories are missing or incorrect.")
        raise
    except Exception:
        print("ERROR during loading: An unexpected error occurred.")
        traceback.print_exc()
        raise

# --- Gradio Submit Function ---
def _resolve_person_inputs(person_image_input):
    """Split the person input into ``(image_path, mask_path_or_None)``.

    Accepts the ImageEditor value (a dict with "background"/"layers" filepaths, since
    the editor is configured with ``type="filepath"``) or a plain filepath string.

    Raises:
        ValueError: if no person image was provided or the format is not recognised.
    """
    if person_image_input is None:
        raise ValueError("Please upload a person image.")
    if isinstance(person_image_input, dict) and "background" in person_image_input:
        background = person_image_input["background"]
        if not background:
            # An editor with nothing uploaded has no background image.
            raise ValueError("Please upload a person image.")
        layers = person_image_input.get("layers") or []
        # Only the first drawn layer is used as the mask.
        return background, (layers[0] if layers else None)
    if isinstance(person_image_input, str):
        return person_image_input, None
    raise ValueError("Invalid person image input format")


def _load_user_mask(mask_path):
    """Load a user-drawn mask layer as a binary "L" image, or None when absent or blank."""
    if mask_path is None:
        return None
    print("Processing provided mask...")
    with Image.open(mask_path) as layer:
        mask_np = np.array(layer.convert("L"))
    # A blank layer has a single value; also require at least one clearly bright stroke.
    if len(np.unique(mask_np)) > 1 and np.any(mask_np > 128):
        mask_np[mask_np > 0] = 255  # binarize
        print("Using user-drawn mask.")
        return Image.fromarray(mask_np)
    print("User-drawn mask is empty or invalid, will generate automatically.")
    return None


def _compose_output(result_image, person_image, masked_person_vis, cloth_image, show_type):
    """Return the result alone, or with the condition images stacked in a column on its left."""
    if show_type == "result only":
        print("Displaying result only.")
        return result_image

    output_width, output_height = result_image.size
    if show_type == "input & result":
        print("Displaying input & result.")
        tiles = [person_image, cloth_image]
    else:  # "input & mask & result"
        print("Displaying input & mask & result.")
        tiles = [person_image, masked_person_vis, cloth_image]

    # Stack the N tiles vertically (N rows x 1 column) and shrink that column to 1/N of
    # the output width at full output height, so each tile keeps its aspect ratio.
    conditions = image_grid(tiles, len(tiles), 1)
    if conditions is None:
        print("Warning: Condition grid failed, returning result only.")
        return result_image
    condition_width = output_width // len(tiles)

    gap = 5  # pixels between the condition column and the result
    conditions_resized = conditions.resize((condition_width, output_height), Image.LANCZOS)
    new_result_image = Image.new("RGB", (output_width + condition_width + gap, output_height))
    new_result_image.paste(conditions_resized, (0, 0))
    new_result_image.paste(result_image, (condition_width + gap, 0))
    print("Combined visualization created.")
    return new_result_image


def submit_function_flux(
    person_image_input,
    cloth_image_input,
    cloth_type,
    num_inference_steps,
    guidance_scale,
    seed,
    show_type
):
    """Run one CatVTON-FLUX try-on and return the image to display in Gradio.

    Args:
        person_image_input: ImageEditor value (dict of filepaths) or a filepath string.
        cloth_image_input: filepath of the garment/condition image.
        cloth_type: "upper" / "lower" / "overall"; used only when no mask was drawn.
        num_inference_steps: number of denoising steps (cast to int).
        guidance_scale: guidance value passed to the pipeline (cast to float).
        seed: RNG seed, where -1 means unseeded. Cast to int because
            ``torch.Generator.manual_seed`` rejects floats, which a slider may send.
        show_type: "result only", "input & result" or "input & mask & result".

    Returns:
        PIL.Image.Image: the try-on result, optionally with the inputs shown on its left.

    Raises:
        gr.Error: when models are not loaded or inference fails. Gradio shows the
            message in the UI; the single image output cannot display a message tuple.
    """
    if pipeline_flux is None or automasker is None or mask_processor is None:
        raise gr.Error("ERROR: Models not loaded. Please check loading logs.")

    print("--- Starting Inference ---")
    try:
        person_image_path, mask_path = _resolve_person_inputs(person_image_input)
        if cloth_image_input is None:
            raise ValueError("Please upload a garment/condition image.")
        if not isinstance(cloth_image_input, str):
            raise ValueError("Invalid cloth image input format")

        # `with` closes the source files; convert() returns an independent copy.
        with Image.open(person_image_path) as img:
            person_image = img.convert("RGB")
        with Image.open(cloth_image_input) as img:
            cloth_image = img.convert("RGB")

        mask = _load_user_mask(mask_path)

        size = (args.width, args.height)
        print(f"Resizing images to {args.width}x{args.height}")
        person_image_resized = resize_and_crop(person_image, size)
        cloth_image_resized = resize_and_padding(cloth_image, size)

        if mask is not None:
            print("Resizing provided mask...")
            # Same crop as the person image, so the mask stays aligned with it.
            mask = resize_and_crop(mask, size)
        else:
            print(f"Generating mask for type: {cloth_type}")
            mask_data = automasker(person_image_resized, cloth_type)
            if mask_data and 'mask' in mask_data:
                mask = mask_data['mask']
                print("Auto-mask generated.")
            else:
                print("Warning: Auto-mask generation failed. Using blank mask.")
                mask = Image.new("L", size, 0)

        print("Blurring mask...")
        mask = mask_processor.blur(mask, blur_factor=9)

        seed = int(seed)
        generator = None
        if seed != -1:
            print(f"Using seed: {seed}")
            generator = torch.Generator(device=args.device).manual_seed(seed)
        else:
            print("Using random seed.")

        print(f"Running pipeline inference with {num_inference_steps} steps, CFG {guidance_scale}...")
        with torch.inference_mode(), torch.autocast(
            args.device, dtype=args.torch_dtype, enabled=(args.mixed_precision != "no")
        ):
            result_image = pipeline_flux(
                image=person_image_resized,
                condition_image=cloth_image_resized,
                mask_image=mask,
                height=args.height,
                width=args.width,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=float(guidance_scale),
                generator=generator
            ).images[0]
        print("Inference complete.")

        print("Creating visualization...")
        masked_person_vis = vis_mask(person_image_resized, mask)
        return _compose_output(
            result_image, person_image_resized, masked_person_vis, cloth_image_resized, show_type
        )

    except Exception as e:
        print("ERROR during inference:")
        traceback.print_exc()
        raise gr.Error(f"Error during inference: {e}") from e


# --- Gradio UI Definition ---
def _list_example_images(directory):
    """Return the sorted .png/.jpg/.jpeg paths in ``directory``, or [] if it does not exist."""
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]


def _add_examples(directory, target, label):
    """Add a ``gr.Examples`` gallery of ``directory``'s images that fills ``target`` on click.

    Must be called inside an active ``gr.Blocks`` context. Nothing is added when the
    directory has no images.
    """
    examples = _list_example_images(directory)
    if examples:
        gr.Examples(
            examples=examples,
            examples_per_page=4,
            inputs=target,
            label=label,
        )


def app_gradio():
    """Build the CatVTON-with-FLUX Gradio UI and wire its events.

    Returns:
        gr.Blocks: the demo, not yet launched.
    """
    with gr.Blocks(title="CatVTON with FLUX") as demo:
        gr.Markdown("# CatVTON with FLUX")
        with gr.Row():
            with gr.Column(scale=1, min_width=350):
                with gr.Row():
                    # Hidden path holder. Person examples write their path here, and a
                    # change event below copies it into the editor, so example clicks
                    # never write straight into the editor's mask layers.
                    image_path_flux = gr.Image(
                        type="filepath", interactive=False, visible=False
                    )
                    # Person image plus optional hand-drawn mask layer.
                    person_image_flux = gr.ImageEditor(
                        label="Person Image (Draw Mask Here if Needed)",
                        sources=["upload", "clipboard"],
                        type="filepath",  # values arrive as temp-file paths
                        interactive=True,
                        # Fixed white brush: strokes are binarized into the mask at submit time.
                        brush=gr.Brush(default_size=20, colors=["#FFFFFF"], color_mode="fixed"),
                    )

                with gr.Row():
                    with gr.Column(scale=1, min_width=230):
                        cloth_image_flux = gr.Image(
                            interactive=True, label="Garment/Condition Image",
                            sources=["upload", "clipboard"], type="filepath"
                        )
                    with gr.Column(scale=1, min_width=120):
                        gr.Markdown(
                            '<span style="color: #808080; font-size: small;">Provide Mask via:<br>1. `🖌️` Draw on Person image<br>2. Auto-generate below</span>'
                        )
                        cloth_type = gr.Radio(
                            label="Auto-Mask Cloth Type",
                            info="Used if no mask is drawn above",
                            choices=["upper", "lower", "overall"],
                            value="upper",
                        )

                submit_flux = gr.Button("Submit Try-On")
                gr.Markdown(
                    '<center><span style="color: #FF0000">Wait few minutes after clicking Submit!</span></center>'
                )

                with gr.Accordion("Advanced Options", open=False):
                    num_inference_steps_flux = gr.Slider(
                        label="Inference Steps", minimum=10, maximum=100, step=1, value=30
                    )
                    guidance_scale_flux = gr.Slider(
                        label="Guidance Scale (CFG)", minimum=0.0, maximum=10.0, step=0.1, value=2.5
                    )
                    seed_flux = gr.Slider(
                        label="Seed (-1 for random)", minimum=-1, maximum=2147483647, step=1, value=42
                    )
                    show_type = gr.Radio(
                        label="Output Display",
                        choices=["result only", "input & result", "input & mask & result"],
                        value="input & mask & result",
                    )

            with gr.Column(scale=2, min_width=500):
                result_image_flux = gr.Image(interactive=False, label="Result", type="pil")
                with gr.Row():
                    # Relative to the working directory (the CatVTON repository root).
                    root_path = "resource/demo/example"
                    person_example_path = os.path.join(root_path, "person")
                    condition_example_path = os.path.join(root_path, "condition")

                    with gr.Column():
                        _add_examples(os.path.join(person_example_path, "men"), image_path_flux, "Person Examples (Men)")
                        _add_examples(os.path.join(person_example_path, "women"), image_path_flux, "Person Examples (Women)")
                        gr.Markdown('<span style="color: #808080; font-size: small;">*Example images may need to be uploaded to `resource/demo/example`</span>')

                    with gr.Column():
                        _add_examples(os.path.join(condition_example_path, "upper"), cloth_image_flux, "Condition Upper Examples")
                        _add_examples(os.path.join(condition_example_path, "overall"), cloth_image_flux, "Condition Overall Examples")

        # --- Event Handlers ---
        def update_editor_from_path(filepath):
            """Load an example image path into the editor as its background, with no mask layers."""
            print(f"Loading example image into editor: {filepath}")
            if filepath and os.path.exists(filepath):
                # No layers yet, so the composite is just the background image.
                return {"background": filepath, "layers": [], "composite": filepath}
            return None

        # gr.Blocks has no component-enumeration method, and gr.Examples has no .click.
        # Clicking a person example updates the hidden holder's value, which fires
        # its .change event; that event copies the path into the editor.
        image_path_flux.change(
            fn=update_editor_from_path,
            inputs=image_path_flux,
            outputs=person_image_flux,
        )

        submit_flux.click(
            fn=submit_function_flux,
            inputs=[person_image_flux, cloth_image_flux, cloth_type, num_inference_steps_flux, guidance_scale_flux, seed_flux, show_type],
            outputs=result_image_flux,
            api_name="catvton_flux_tryon",
        )

    return demo


# --- Main Execution Block ---
if __name__ == "__main__":
    # Load the heavy models only once per kernel. Re-running this cell reuses them
    # instead of re-instantiating the whole FLUX pipeline and masker.
    if pipeline_flux is None or automasker is None or mask_processor is None:
        load_all_models_and_tools()
    else:
        print("Models already loaded; skipping reload.")

    # Close demos from earlier runs of this cell. Otherwise a re-launch moves to the
    # next free port, and the ngrok tunnel (set up for port 7860) no longer matches.
    gr.close_all()

    gradio_app = app_gradio()
    # share=False: public access is provided by the ngrok tunnel instead.
    gradio_app.queue().launch(share=False, show_error=True)