In [None]:
import sys
import os
import io
import torch
import numpy as np
from datetime import datetime
from PIL import Image
import argparse
from huggingface_hub import snapshot_download
from diffusers.image_processor import VaeImageProcessor

# Make the "CatVTON" folder that sits next to the working directory importable.
# It is placed at the front of sys.path so its modules win name lookups, and it
# is only added once even if this cell is executed repeatedly.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
catvton_dir = os.path.join(parent_dir, "CatVTON")
if catvton_dir not in sys.path:
    sys.path.insert(0, catvton_dir)

from model.cloth_masker import AutoMasker, vis_mask
from model.pipeline import CatVTONPipeline
from utils import init_weight_dtype, resize_and_crop, resize_and_padding

def image_grid(images, rows, cols):
    """Tile ``images`` onto a single RGB canvas of ``rows`` x ``cols`` cells.

    Every cell is as wide as the widest input and as tall as the tallest
    input. Images are placed left to right, top to bottom, each anchored at
    the top-left corner of its cell.

    Args:
        images: Sequence of PIL images to lay out.
        rows: Number of grid rows (sets the canvas height).
        cols: Number of grid columns (sets the canvas width and the wrap point).

    Returns:
        The composed PIL image, or ``None`` when ``images`` is empty.
    """
    if not images:
        return None

    sizes = [picture.size for picture in images]
    cell_w = max(w for w, _ in sizes)
    cell_h = max(h for _, h in sizes)

    canvas = Image.new('RGB', (cols * cell_w, rows * cell_h))
    for position, tile in enumerate(images):
        grid_row, grid_col = divmod(position, cols)
        canvas.paste(tile, (grid_col * cell_w, grid_row * cell_h))
    return canvas

# Default settings (stand-in for command-line args)
class Args:
    """Fixed defaults used in place of parsed command-line options."""

    # NOTE(review): width/height are not referenced anywhere else in the
    # visible code; presumably the target image size in pixels - confirm.
    width = 768
    height = 1024
    # Root folder under which run_tryon() writes its dated result folders.
    output_dir = "./output"

args_obj = Args()
# exist_ok=True replaces the check-then-create pair: it is a no-op when the
# folder is already there and cannot race with another process creating it.
os.makedirs(args_obj.output_dir, exist_ok=True)

# Resolve a local copy of the CatVTON Hub repository. The returned path is what
# the AutoMasker setup further down uses to locate its checkpoints; note that
# the pipeline below is handed the repo id string, not this path.
repo_path = snapshot_download(repo_id="zhengchong/CatVTON")

# Create the try-on pipeline instance, on the GPU when one is available.
_pipeline_weight_dtype = init_weight_dtype("bf16")
_pipeline_device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = CatVTONPipeline(
    base_ckpt="booksforcharlie/stable-diffusion-inpainting",
    attn_ckpt="zhengchong/CatVTON",  # trained try-on model checkpoint (used as attn_ckpt)
    attn_ckpt_version="mix",
    weight_dtype=_pipeline_weight_dtype,
    use_tf32=True,
    device=_pipeline_device,
)

# Mask post-processor: run_tryon() uses its blur() on the automatic mask.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)

# Automatic garment masker; both checkpoints are read from sub-folders of the
# downloaded repository snapshot.
_masker_device = 'cuda' if torch.cuda.is_available() else 'cpu'
_densepose_dir = os.path.join(repo_path, "DensePose")
_schp_dir = os.path.join(repo_path, "SCHP")
automasker = AutoMasker(
    densepose_ckpt=_densepose_dir,
    schp_ckpt=_schp_dir,
    device=_masker_device,
)

def _load_rgb(path):
    """Open the image at ``path`` and return an RGB copy, closing the file."""
    with Image.open(path) as source:
        return source.convert("RGB")


def _show_image(image):
    """Show ``image`` inline under IPython, otherwise in an external viewer."""
    try:
        # ``display`` is only provided by IPython/Jupyter; a plain
        # ``python script.py`` run has no such name.
        notebook_display = display  # noqa: F821
    except NameError:
        image.show()
    else:
        notebook_display(image)


def _compose_preview(person_img, cloth_img, mask, result_img, show_type):
    """Return the result image with a narrow strip of the inputs on its left."""
    # Lay out against the result's own size so both halves line up even if the
    # pipeline returns a different resolution than the input photo.
    result_w, result_h = result_img.size
    if show_type == "input & result":
        strip_w = result_w // 2
        strip = image_grid([person_img, cloth_img], 2, 1)
    else:
        # The mask overlay is only needed for the three-panel layout.
        masked_person = vis_mask(person_img, mask)
        strip_w = result_w // 3
        strip = image_grid([person_img, masked_person, cloth_img], 3, 1)
    strip = strip.resize((strip_w, result_h), Image.NEAREST)

    gap = 5  # pixels of empty space between the strip and the result
    canvas = Image.new("RGB", (result_w + strip_w + gap, result_h))
    canvas.paste(strip, (0, 0))
    canvas.paste(result_img, (strip_w + gap, 0))
    return canvas


def run_tryon(person_image_path, cloth_image_path, cloth_type,
              num_inference_steps=50, guidance_scale=2.5,
              seed=-1, show_type="result only"):
    """Run one try-on inference, save the result as PNG and show it.

    Uses the module-level ``automasker``, ``mask_processor``, ``pipeline`` and
    ``args_obj``. The result is written to
    ``<output_dir>/<YYYYMMDD>/<HHMMSS>.png``.

    Args:
        person_image_path: Path of the person photo.
        cloth_image_path: Path of the garment photo.
        cloth_type: Passed unchanged to ``automasker`` as its second argument.
        num_inference_steps: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline.
        seed: Seed for the random generator; ``-1`` means no fixed seed.
        show_type: ``"result only"`` shows just the result,
            ``"input & result"`` adds a person/cloth strip, any other value
            adds a person/masked-person/cloth strip.

    Returns:
        None. Failures while loading or generating are printed and the
        function returns early without raising.
    """
    try:
        # Load both images
        person_img = _load_rgb(person_image_path)
        cloth_img = _load_rgb(cloth_image_path)
    except Exception as e:
        print("유효하지 않은 이미지 파일입니다.", e)
        return

    try:
        # Automatic masking
        mask = automasker(person_img, cloth_type)['mask']
        mask = mask_processor.blur(mask, blur_factor=9)
        # Seed: create the generator on the same device the pipeline was
        # configured for, so seeding also works on CPU-only machines.
        if seed != -1:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            generator = torch.Generator(device=device).manual_seed(seed)
        else:
            generator = None
        # Inference call
        result_img = pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator
        )[0]
    except Exception as e:
        print("이미지 생성 중 오류가 발생했습니다.", e)
        return

    # Save the result: date part names the folder, time part names the file.
    date_str = datetime.now().strftime("%Y%m%d%H%M%S")
    folder_path = os.path.join(args_obj.output_dir, date_str[:8])
    os.makedirs(folder_path, exist_ok=True)
    result_save_path = os.path.join(folder_path, date_str[8:] + ".png")
    result_img.save(result_save_path)
    print(f"결과 이미지가 저장되었습니다: {result_save_path}")

    # Optionally show the result next to a strip of its inputs.
    if show_type != "result only":
        _show_image(_compose_preview(person_img, cloth_img, mask, result_img, show_type))
    else:
        _show_image(result_img)

def parse_arguments(argv=None):
    """Parse the try-on command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, as before. Passing
            an explicit list lets notebooks and tests call this without
            touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``person``, ``cloth``, ``cloth_type``,
        ``num_inference_steps``, ``guidance_scale``, ``seed`` and
        ``show_type``.
    """
    parser = argparse.ArgumentParser(description="Run Try-On Inference")
    parser.add_argument("--person", type=str, required=True, help="Path to the person image")
    parser.add_argument("--cloth", type=str, required=True, help="Path to the cloth image")
    parser.add_argument("--cloth_type", type=str, required=True, help="Type of the cloth")
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
    parser.add_argument("--guidance_scale", type=float, default=2.5, help="Guidance scale")
    parser.add_argument("--seed", type=int, default=-1, help="Seed value (-1 for random)")
    parser.add_argument("--show_type", type=str, default="result only",
                        choices=["result only", "input & result", "full"],
                        help="Display type for output image")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Under a notebook kernel argparse's parse_args() needs care: sys.argv may
    # have to be overridden first, since it holds the kernel's own arguments.
    args = parse_arguments()
    # The three required inputs go positionally; the tuning options by keyword.
    run_tryon(
        args.person,
        args.cloth,
        args.cloth_type,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        seed=args.seed,
        show_type=args.show_type,
    )