## 1. renamed
## 2. resize_and_crop
## 3. segment person or cloth with sam2-small
## 4. make masked imgs

## run this code with sam2 env

In [1]:
# pip install accelerate
# !pip install diffusers
# !pip install onnxruntime scipy jsonschema

In [2]:
import sys, os
current_dir = os.getcwd()
project_root_dir = os.path.abspath(os.path.join(current_dir, ".."))
catvton_dir = os.path.join(project_root_dir, "CatVTON")

if catvton_dir not in sys.path:
    sys.path.insert(0, catvton_dir)

In [3]:
import os
import shutil

# Define source directories
renamed_person_dir = os.path.join(project_root_dir, "data", "renamed", "renamed_person_images")
renamed_cloth_dir  = os.path.join(project_root_dir, "data", "renamed", "renamed_cloth_images")

# Define destination directories for viton-hd style dataset
dest_person_dir = os.path.join(project_root_dir, "data","dataset", "image")
dest_cloth_upper = os.path.join(project_root_dir, "data","dataset", "cloth", "upper_img_sam2")
dest_cloth_lower = os.path.join(project_root_dir, "data","dataset", "cloth", "lower_img_sam2")
dest_mask_upper = os.path.join(project_root_dir, "data","dataset", "cloth", "upper_mask_sam2")
dest_mask_lower = os.path.join(project_root_dir, "data","dataset", "cloth", "lower_mask_sam2")

os.makedirs(dest_person_dir, exist_ok=True)
os.makedirs(dest_cloth_upper, exist_ok=True)
os.makedirs(dest_cloth_lower, exist_ok=True)
os.makedirs(dest_mask_upper, exist_ok=True)
os.makedirs(dest_mask_lower, exist_ok=True)

# 페어 만들기
person_cloth_pairs = {
    "person_image_path" : "",
    "upper_cloth_path" : "",
    "lower_cloth_path" : ""
}
person_cloth_pairs_list = []
#/Users/coldbrew/Documents/VTON-project/data/renamed/renamed_person_images/Jonghyeon_manA_mana_30.jpg
for person_file_name in sorted(os.listdir(renamed_person_dir)):
    person_cloth_pairs["person_image_path"] = os.path.join(renamed_person_dir, person_file_name)
    person_cloth_pairs["upper_cloth_path"] = os.path.join(renamed_cloth_dir, f"{person_file_name.split('_')[1]}00.jpg").lower()
    person_cloth_pairs["lower_cloth_path"] = os.path.join(renamed_cloth_dir, f"{person_file_name.split('_')[2]}01.jpg").lower()
    person_cloth_pairs_list.append(person_cloth_pairs.copy())
    

person_cloth_pairs_list[:5]

[{'person_image_path': 'c:\\Users\\coldbrew\\VTON-project\\data\\renamed\\renamed_person_images\\Jonghyeon_manA_mana_120.jpg',
  'upper_cloth_path': 'c:\\users\\coldbrew\\vton-project\\data\\renamed\\renamed_cloth_images\\mana00.jpg',
  'lower_cloth_path': 'c:\\users\\coldbrew\\vton-project\\data\\renamed\\renamed_cloth_images\\mana01.jpg'},
 {'person_image_path': 'c:\\Users\\coldbrew\\VTON-project\\data\\renamed\\renamed_person_images\\Jonghyeon_manA_mana_150.jpg',
  'upper_cloth_path': 'c:\\users\\coldbrew\\vton-project\\data\\renamed\\renamed_cloth_images\\mana00.jpg',
  'lower_cloth_path': 'c:\\users\\coldbrew\\vton-project\\data\\renamed\\renamed_cloth_images\\mana01.jpg'},
 {'person_image_path': 'c:\\Users\\coldbrew\\VTON-project\\data\\renamed\\renamed_person_images\\Jonghyeon_manA_mana_30.jpg',
  'upper_cloth_path': 'c:\\users\\coldbrew\\vton-project\\data\\renamed\\renamed_cloth_images\\mana00.jpg',
  'lower_cloth_path': 'c:\\users\\coldbrew\\vton-project\\data\\renamed\\renam

### sam2 정의

In [4]:
# 여러 마스크에서 의상 마스크 찾는 함수 만들기
# 하나만 찾아야하고 여러개면 오류 발생
def find_cloth_mask(cloth_image, masks):
    img_h, img_w = cloth_image.shape[:2]
    x_min, x_max = int(img_w * 0.4), int(img_w * 0.6)  # 이미지 가로 40~60%
    y_min, y_max = int(img_h * 0.2), int(img_h * 0.5)  # 이미지 세로 20~50%

    # 박스 안에 가장 많이 포함되는 마스크 찾기
    best_mask = None
    max_overlap = 0

    for mask in masks:
        x, y, w, h = mask['bbox']
        
        # 마스크의 좌표 (bbox)
        mask_x_min, mask_x_max = x, x + w
        mask_y_min, mask_y_max = y, y + h
        
        # 중앙 박스와의 겹치는 부분 계산
        overlap_x_min = max(x_min, mask_x_min)
        overlap_x_max = min(x_max, mask_x_max)
        overlap_y_min = max(y_min, mask_y_min)
        overlap_y_max = min(y_max, mask_y_max)
        
        # 겹치는 영역의 크기 계산
        overlap_width = max(0, overlap_x_max - overlap_x_min)
        overlap_height = max(0, overlap_y_max - overlap_y_min)
        overlap_area = overlap_width * overlap_height
        
        # 가장 많이 겹치는 마스크 선택
        if overlap_area > max_overlap:
            best_mask = mask
            max_overlap = overlap_area

    # 마스크가 없으면 종료
    if best_mask is None:
        print("⚠️ 중앙 박스 안에 포함되는 마스크(상의/바지)를 찾을 수 없습니다.")
    else:
        return best_mask['segmentation']

In [5]:
import sys, os
project_root_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sam_root = os.path.join(project_root_dir,"sam2")
sys.path.insert(0, sam_root)

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

model_cfg = os.path.join(sam_root,"checkpoints\sam2.1_hiera_s.yaml")
sam2_checkpoint = os.path.join(sam_root,"checkpoints\sam2.1_hiera_small.pt")

sam2 = build_sam2(model_cfg, sam2_checkpoint, device="cuda", apply_postprocessing=True)

mask_generator = SAM2AutomaticMaskGenerator(sam2)

In [7]:
from utils import resize_and_crop, resize_and_padding
from tqdm import tqdm
import rembg
import PIL
import numpy as np

for idx, pcp_dict in tqdm(enumerate(person_cloth_pairs_list), total=len(person_cloth_pairs_list)):
    new_name = f"{idx:05d}.jpg"
    # 파일 열기
    person_img = PIL.Image.open(pcp_dict["person_image_path"]).convert("RGB")
    upper_img = PIL.Image.open(pcp_dict["upper_cloth_path"]).convert("RGB")
    lower_img = PIL.Image.open(pcp_dict["lower_cloth_path"]).convert("RGB")

    # 사람 이미지는 resize_and_crop 함수로, 옷 이미지는 resize_and_padding 함수로 크기 조정
    person_img = resize_and_crop(person_img, (1024, 768))
    upper_img = resize_and_crop(upper_img, (1024, 768))
    lower_img = resize_and_crop(lower_img, (1024, 768))

    # 파일 회전
    person_img_rotated = person_img.rotate(0, expand=True)
    upper_img_rotated = upper_img.rotate(270, expand=True)
    lower_img_rotated = lower_img.rotate(270, expand=True)
    
    # to numpy
    # person_img_rotated = np.array(person_img_rotated)
    upper_img_rotated = np.array(upper_img_rotated)
    lower_img_rotated = np.array(lower_img_rotated)

    # 마스크 생성
    upper_masks = mask_generator.generate(upper_img_rotated)
    lower_masks = mask_generator.generate(lower_img_rotated)
    segmented_upper_mask = find_cloth_mask(upper_img_rotated, upper_masks)
    segmented_lower_mask = find_cloth_mask(lower_img_rotated, lower_masks)
    
    # numpy 배열로 변환 (segmentation 연산을 위해)
    upper_img_np = np.array(upper_img_rotated)
    lower_img_np = np.array(lower_img_rotated)

    # --- 3. 출력 이미지 생성 ---
    # (a) 옷 영역은 원본, 나머지 영역은 흰색 처리 (cloth extracted image)
    # segmented_upper_mask[..., None]를 사용해 채널 차원을 맞춰줍니다.
    upper_cloth_extracted_np = np.where(segmented_upper_mask[..., None], upper_img_np, 255)
    lower_cloth_extracted_np = np.where(segmented_lower_mask[..., None], lower_img_np, 255)
    upper_cloth_extracted_img = PIL.Image.fromarray(upper_cloth_extracted_np.astype(np.uint8))
    lower_cloth_extracted_img = PIL.Image.fromarray(lower_cloth_extracted_np.astype(np.uint8))
    
    # (b) 바이너리 마스크 이미지: 옷 영역은 흰색(255), 배경은 검정(0)
    upper_binary_mask_np = (segmented_upper_mask.astype(np.uint8)) * 255
    lower_binary_mask_np = (segmented_lower_mask.astype(np.uint8)) * 255
    upper_binary_mask_img = PIL.Image.fromarray(upper_binary_mask_np, mode="L")
    lower_binary_mask_img = PIL.Image.fromarray(lower_binary_mask_np, mode="L")

    # --- 4. 저장 ---
    upper_cloth_extracted_img.save(os.path.join(dest_cloth_upper, new_name))
    lower_cloth_extracted_img.save(os.path.join(dest_cloth_lower, new_name))
    upper_binary_mask_img.save(os.path.join(dest_mask_upper, new_name))
    lower_binary_mask_img.save(os.path.join(dest_mask_lower, new_name))

    # # 흰 배경 생성
    # person_white_bg = PIL.Image.new("RGB", (768, 1024), (255, 255, 255))
    # upper_white_bg = PIL.Image.new("RGB", (768, 1024), (255, 255, 255))
    # lower_white_bg = PIL.Image.new("RGB", (768, 1024), (255, 255, 255))

    # person_white_bg.paste(person_img, (0, 0), person_img)
    # upper_white_bg.paste(segmented_upper_mask, (0, 0), segmented_upper_mask)
    # lower_white_bg.paste(segmented_lower_mask, (0, 0), segmented_lower_mask)

    # # 파일 저장
    # person_white_bg.save(os.path.join(dest_person_dir, new_name))
    # upper_white_bg.save(os.path.join(dest_cloth_upper, new_name))
    # lower_white_bg.save(os.path.join(dest_cloth_lower, new_name))

 15%|█▌        | 30/195 [02:16<12:30,  4.55s/it]


KeyboardInterrupt: 