In [None]:
# !apt update
# !apt-get install git-lfs
# !git lfs install
# !mkdir ../vton_origin
# !git clone https://huggingface.co/spaces/yisol/IDM-VTON ../vton_origin
# !rm -rf ../vton_origin/.git/
# !rm -rf ../vton_origin/example/
# !apt-get install rsync
# !rsync -av --ignore-existing vton_origin/ VTO_demo/
# !pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
# !pip install -r requirements.txt
# !pip install ipympl

In [None]:
# Show GPU model, driver/CUDA version and current GPU memory use before the models are loaded.
!nvidia-smi

In [None]:
%matplotlib inline
%matplotlib notebook
import sys
import cv2
import json
import matplotlib.pyplot as plt
import gradio as gr
from PIL import Image
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
    CLIPTextModel,
    CLIPTextModelWithProjection,
)
from diffusers import DDPMScheduler,AutoencoderKL
from typing import List
from scipy.ndimage import binary_erosion

import torch
import os
from transformers import AutoTokenizer
import numpy as np
from utils_mask import get_mask_location
from torchvision import transforms
import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
from torchvision.transforms.functional import to_pil_image

# Short alias used by display cells to render numpy arrays (e.g. boolean masks) as PIL images.
Im = Image.fromarray

In [None]:
def pi(img_like):
    """Render ``img_like`` with matplotlib ``imshow`` and return the figure as a PIL image.

    Handy for viewing arrays (e.g. label maps) with axes and a colormap as the displayed
    result of a cell.

    Args:
        img_like: anything ``Axes.imshow`` accepts (2-D array, RGB(A) array, PIL image).

    Returns:
        PIL.Image.Image: the rendered figure, decoded from PNG and fully loaded in memory.
    """
    import io

    fig, ax = plt.subplots()
    try:
        ax.imshow(img_like)
        # Render into memory rather than a shared 'tmp_plot.png' in the CWD, and load the
        # pixels eagerly: Image.open is lazy, so an image backed by a file that the next
        # call overwrites could be decoded from the wrong data.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
        buffer.seek(0)
        img = Image.open(buffer)
        img.load()
    finally:
        # Release the figure even on error so repeated calls don't accumulate figures.
        plt.close(fig)
    return img

In [None]:
from PIL import Image, ImageChops, ImageFilter

import torch
import numpy as np

from utils_mask import get_mask_location
from torchvision import transforms
import apply_net

from detectron2.data.detection_utils import convert_PIL_to_numpy,_apply_exif_orientation
from torchvision.transforms.functional import to_pil_image
from segmentation_processor import request_segmentation_results, extract_submask

def correct_masking(preserve_mask, org_image, mask, mask_gray):
    """Take the preserved regions out of the inpainting mask and its gray preview.

    Args:
        preserve_mask: numpy array of pixels to keep from the original photo. It is converted
            to an 8-bit 'L' image and used as a per-pixel weight: 255 (or True) means fully
            preserved, 0 means not preserved.
        org_image: original person image (PIL), same size as the masks.
        mask: inpainting mask (PIL 'L').
        mask_gray: gray preview of the masked person (PIL).

    Returns:
        tuple: ``(corrected_mask, corrected_mask_gray)``.
    """
    keep = Image.fromarray(preserve_mask).convert('L')
    # Scale the inpainting mask by the inverse of `keep` so preserved pixels drop to 0.
    corrected_mask = ImageChops.multiply(mask, ImageChops.invert(keep))
    # Show original pixels in the preview wherever `keep` is set.
    corrected_mask_gray = Image.composite(org_image, mask_gray, keep)
    return corrected_mask, corrected_mask_gray


def erose_mask(mask, kernal_size=11):
    """Binary-erode ``mask`` with a square ``kernal_size`` x ``kernal_size`` structuring element.

    (The misspelled function and parameter names are kept for existing callers.)

    Any nonzero element counts as foreground. Elements that survive the erosion keep their
    original value; all others become 0. The result has ``mask.dtype``.

    Args:
        mask (np.ndarray): mask to erode (bool, 0/1, 0/255, ...).
        kernal_size (int): side length of the square structuring element.

    Returns:
        np.ndarray: eroded mask with the same dtype and value scale as ``mask``.
    """
    eroded = binary_erosion(
        mask,
        structure=np.ones((kernal_size, kernal_size), dtype=np.uint8),
        iterations=1,
    )
    # Keep the surviving original values instead of casting the boolean result directly:
    # the direct cast turns a 0/255 uint8 mask into 0/1, which a later
    # Image.fromarray(...).convert('L') treats as almost fully black.
    return np.where(eroded, mask, 0).astype(mask.dtype)


def blend_image_with_color(image, mask, color, alpha=0.5):
    """Tint the masked region of ``image`` with ``color`` (for visual inspection of masks).

    Args:
        image: PIL image or array accepted by ``np.array``; its channel count must match
            ``len(color)`` so the color broadcasts over it.
        mask: 2-D mask (bool, 0/1 or 0/255; a PIL 'L' image also works). It is scaled to
            [0, 1]: masks whose maximum exceeds 1 are treated as 8-bit and divided by 255.
        color: tint color, e.g. ``(0, 255, 0)``.
        alpha (float): tint strength where the mask is 1 (0 = no tint, 1 = solid color).

    Returns:
        PIL.Image.Image: the blended uint8 image.
    """
    mask_array = np.asarray(mask, dtype=float)
    # Accept 8-bit masks too: otherwise a 0/255 mask yields weights up to 127.5 and the
    # blended values wrap around when cast to uint8.
    if mask_array.size and mask_array.max() > 1.0:
        mask_array = mask_array / 255.0

    image_array = np.array(image)
    # Solid color image with the same shape/dtype as the input.
    color_array = np.full_like(image_array, color)

    weight = mask_array[..., None] * alpha
    blended_array = image_array * (1 - weight) + color_array * weight

    return Image.fromarray(np.uint8(blended_array))

In [None]:
import yaml
from pipeline_loader import PipelineLoader

# Load pipeline settings from YAML; this notebook reads the 'base_path', 'device' and
# 'segmentaion' entries from it.
with open('pipeline_config.yaml', 'r') as config_file:
    pipeline_config = yaml.safe_load(config_file)

# The loader provides the models that TryOnProcessor fetches later.
pipeline_loader = PipelineLoader(base_path=pipeline_config['base_path'], device=pipeline_config['device'])

In [None]:
def erode_mask(org_mask, erosion_size=3):
    """Binarize a mask and erode it once with a square ``erosion_size`` kernel (OpenCV).

    Args:
        org_mask: PIL image (converted to 'L') or numpy array; any value > 0 is foreground.
        erosion_size (int): side length of the square erosion kernel, in pixels.

    Returns:
        Same kind as the input: a PIL 'L' image for PIL input, otherwise a uint8 numpy array.
        In both cases foreground is 255 and background is 0.
    """
    is_pil = isinstance(org_mask, Image.Image)
    if is_pil:
        org_mask = org_mask.convert('L')

    binary = np.where(np.array(org_mask) > 0, 255, 0).astype(np.uint8)
    kernel = np.ones((erosion_size,) * 2, np.uint8)
    eroded = cv2.erode(binary, kernel, iterations=1)
    _, result = cv2.threshold(eroded, 127, 255, cv2.THRESH_BINARY)

    return Image.fromarray(result).convert('L') if is_pil else result


def fill_holes(org_mask, min_size):
    """Remove connected foreground components smaller than ``min_size`` pixels.

    Despite its name, this drops small foreground islands. ``remove_small_clusters_np``
    also applies it to the inverted mask, which fills small holes.

    Args:
        org_mask: PIL image (converted to 'L') or numpy array; any value > 0 is foreground.
        min_size (int): minimum component area in pixels (8-connectivity) to keep.

    Returns:
        For PIL input, a PIL 'L' image with kept pixels at 255; otherwise a boolean numpy
        array.
    """
    is_pil = isinstance(org_mask, Image.Image)
    if is_pil:
        org_mask = org_mask.convert('L')

    binary_mask = (np.array(org_mask) > 0).astype(np.uint8)

    # Label the connected components and get their areas.
    _, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)

    # Decide once per label, then look up the decision for every pixel. This is O(pixels),
    # instead of one full-image comparison per component.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_size
    keep[0] = False  # label 0 is the background
    cleaned_mask = keep[labels]

    if is_pil:
        return Image.fromarray(cleaned_mask).convert('L')
    return cleaned_mask


def remove_small_clusters_np(org_mask, min_size):
    """Drop foreground islands and fill background holes smaller than ``min_size`` pixels.

    Runs ``fill_holes`` on the mask, then on its complement, and complements the result
    back. Intended for numpy masks: ``fill_holes`` returns a boolean array for numpy input,
    and the ``~`` complement relies on that.

    Returns:
        np.ndarray: boolean mask.
    """
    without_islands = fill_holes(org_mask, min_size)
    filled_complement = fill_holes(~without_islands, min_size)
    return ~filled_complement

In [None]:
class TryOnProcessor:
    """Step-by-step virtual try-on: masks, pose conditioning, diffusion and post-compositing.

    Wraps the models provided by ``pipeline_loader``: the diffusion pipeline, the OpenPose
    model, the human-parsing model and the tensor transform. Also wraps a segmentation
    service configured under ``pipeline_config['segmentaion']`` (the key really is spelled
    this way in the config). The notebook cells call the methods one step at a time.
    """

    # Working resolution of the diffusion pipeline, as a PIL ``(width, height)`` size.
    TARGET_SIZE = (768, 1024)
    # Reduced resolution handed to OpenPose, human parsing and DensePose.
    POSE_SIZE = (384, 512)

    def __init__(self, pipeline_config, pipeline_loader):
        """Read the config and fetch the models from ``pipeline_loader``.

        Args:
            pipeline_config (dict): parsed config. ``'device'`` and ``'segmentaion'`` are
                read; the latter holds ``'service_url'`` and the ``*_classes`` lists.
            pipeline_loader: object exposing ``get_pipeline``, ``get_openpose_model``,
                ``get_parsing_model`` and ``get_tensor_transform``.
        """
        self.pipeline_config = pipeline_config
        self.segmentaion_config = self.pipeline_config['segmentaion']
        self.device = pipeline_config['device']
        self.pipe = pipeline_loader.get_pipeline()
        self.openpose_model = pipeline_loader.get_openpose_model()
        self.parsing_model = pipeline_loader.get_parsing_model()
        self.tensor_transform = pipeline_loader.get_tensor_transform()

    def to(self, device):
        """Move the diffusion pipeline to ``device`` and return ``self`` for chaining.

        Only ``self.pipe`` moves. The pose/parsing models are not moved, and neither is
        ``self.device``, which is still used for input tensors and DensePose.
        """
        self.pipe.to(device)
        return self

    def preprocess_submasks(self, init_image):
        """Segment the input photo and build the mask of regions to keep untouched.

        Returns:
            tuple: ``(pre_preservation_mask, init_segmentation_map, init_classes_mapping)``.
            The mask covers the config's ``pre_preservation_classes`` and is eroded with
            ``erose_mask`` (default kernel size).
        """
        init_segmentation_map, init_classes_mapping = request_segmentation_results(
            url=self.segmentaion_config['service_url'],
            image=init_image,
        )

        pre_preservation_mask = extract_submask(
            segmentation_map=init_segmentation_map,
            submask_classes=self.segmentaion_config['pre_preservation_classes'],
            classes_mapping=init_classes_mapping,
        )
        pre_preservation_mask = erose_mask(pre_preservation_mask)

        return pre_preservation_mask, init_segmentation_map, init_classes_mapping

    def postprocess_submasks(
        self,
        init_image,
        init_segmentation_map,
        init_classes_mapping,
        result_image,
        erosion_size=3,
        min_cluster_size=1000,
        clothing_mask_path='clothing_submask.png',
    ):
        """Copy "soft-preserved" regions of the input photo back over the generated image.

        The copied region is built in three steps:
          * start from the input's ``soft_preservation_classes``;
          * remove whatever the segmentation of ``result_image`` labels as ``clothing_classes``;
          * remove clusters and holes smaller than ``min_cluster_size`` pixels, then erode
            by ``erosion_size``.

        Args:
            init_image, result_image: PIL images of the same size (required by
                ``Image.composite``).
            init_segmentation_map, init_classes_mapping: output of ``preprocess_submasks``.
            erosion_size (int): erosion kernel size applied to the final mask.
            min_cluster_size (int): ``min_size`` passed to ``remove_small_clusters_np``.
            clothing_mask_path (str | None): the result's clothing mask is saved here as a
                debugging aid. ``None`` skips writing the file.

        Returns:
            PIL.Image.Image: ``init_image`` inside the final mask, ``result_image`` elsewhere.
        """
        config = self.segmentaion_config

        soft_preservation_submask = extract_submask(
            segmentation_map=init_segmentation_map,
            submask_classes=config['soft_preservation_classes'],
            classes_mapping=init_classes_mapping,
        )

        result_segmentation_map, result_classes_mapping = request_segmentation_results(
            url=config['service_url'],
            image=result_image,
        )

        clothing_submask = extract_submask(
            segmentation_map=result_segmentation_map,
            submask_classes=config['clothing_classes'],
            classes_mapping=result_classes_mapping,
        )

        if clothing_mask_path is not None:
            Image.fromarray(clothing_submask).convert("L").save(clothing_mask_path)

        # Never paste the input photo over pixels where the result now shows clothing.
        soft_mask = np.logical_and(
            soft_preservation_submask,
            np.logical_not(clothing_submask),
        )
        soft_mask = remove_small_clusters_np(soft_mask, min_size=min_cluster_size)
        soft_mask_pil = erode_mask(
            Image.fromarray(soft_mask).convert("L"),
            erosion_size=erosion_size,
        )

        return Image.composite(init_image, result_image, soft_mask_pil)

    def preprocess_images(self, human_canva, garm_img):
        """Convert both inputs to RGB and resize them to ``TARGET_SIZE``.

        Args:
            human_canva (dict): holds the person photo under ``'background'``.
            garm_img (PIL.Image.Image): garment photo.

        Returns:
            tuple: ``(garm_img, human_img, human_img_orig)``, i.e. the resized garment, the
            resized person, and the person photo at its original size (RGB).
        """
        garm_img = garm_img.convert("RGB").resize(self.TARGET_SIZE)
        human_img_orig = human_canva["background"].convert("RGB")
        human_img = human_img_orig.resize(self.TARGET_SIZE)
        return garm_img, human_img, human_img_orig

    def generate_keypoints_and_parse_model(self, human_img):
        """Run OpenPose and human parsing on a ``POSE_SIZE`` copy of ``human_img``.

        Returns:
            tuple: ``(keypoints, model_parse)``.
        """
        resized_human_img = human_img.resize(self.POSE_SIZE)
        keypoints = self.openpose_model(resized_human_img)
        model_parse, _ = self.parsing_model(resized_human_img)
        return keypoints, model_parse

    def generate_mask_and_mask_gray(self, model_parse, keypoints, human_img):
        """Build the upper-body inpainting mask and a gray preview of the masked person.

        Returns:
            tuple: ``(mask, mask_gray)``. ``mask`` is resized to ``TARGET_SIZE``;
            ``mask_gray`` is the person image blanked to gray under the mask.
        """
        # get_mask_location also returns a gray preview. It is rebuilt below from the
        # resized mask and the transformed person image, so the returned one is discarded.
        mask, _ = get_mask_location('hd', "upper_body", model_parse, keypoints)
        mask = mask.resize(self.TARGET_SIZE)
        mask_gray = (1 - transforms.ToTensor()(mask)) * self.tensor_transform(human_img)
        # Map back from tensor_transform's output range (presumably [-1, 1], TODO confirm)
        # to [0, 1] for to_pil_image.
        mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
        return mask, mask_gray

    def prepare_human_image_for_pose_estimation(self, human_img):
        """Return a ``POSE_SIZE``, EXIF-oriented BGR numpy copy of ``human_img`` for DensePose."""
        human_img_arg = _apply_exif_orientation(human_img.resize(self.POSE_SIZE))
        return convert_PIL_to_numpy(human_img_arg, format="BGR")

    def generate_pose_image(self, human_img_arg):
        """Run DensePose (``dp_segm`` visualization) and return it as a ``TARGET_SIZE`` PIL image."""
        argument_parser = apply_net.create_argument_parser()
        args = argument_parser.parse_args(
            (
                'show',
                './configs/densepose_rcnn_R_50_FPN_s1x.yaml',
                './ckpt/densepose/model_final_162be9.pkl',
                'dp_segm', '-v', '--opts', 'MODEL.DEVICE', self.device
            )
        )
        pose_img = args.func(args, human_img_arg)
        # Reverse the channel axis. The input was BGR, so this presumably converts BGR -> RGB.
        pose_img = pose_img[:, :, ::-1]
        return Image.fromarray(pose_img).resize(self.TARGET_SIZE)

    def encode_prompts(self, garment_des):
        """Encode the person prompt (with CFG) and the garment prompt (without CFG).

        Returns:
            tuple: ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
            negative_pooled_prompt_embeds, prompt_embeds_c)``. The last item is the
            garment-branch text embedding.
        """
        negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
        with torch.inference_mode():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                "model is wearing " + garment_des,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

            # The garment branch receives list-wrapped prompts and no CFG.
            (prompt_embeds_c, _, _, _) = self.pipe.encode_prompt(
                ["a photo of " + garment_des],
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=[negative_prompt],
            )
        return (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            prompt_embeds_c,
        )

    def prepare_images_for_model(self, pose_img, garm_img):
        """Transform the pose and garment images into batched float16 tensors on ``self.device``."""
        pose_img_tensor = self.tensor_transform(pose_img).unsqueeze(0).to(self.device, torch.float16)
        garm_tensor = self.tensor_transform(garm_img).unsqueeze(0).to(self.device, torch.float16)
        return pose_img_tensor, garm_tensor

    def generate_images_with_model(
        self,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        denoise_steps,
        generator,
        pose_img_tensor,
        prompt_embeds_c,
        garm_tensor,
        mask,
        human_img,
        garm_img,
    ):
        """Run the try-on diffusion pipeline at ``TARGET_SIZE`` and return its first output.

        All embeddings are moved to ``self.device`` as float16. ``strength=1.0`` and
        ``guidance_scale=2.0`` are fixed.
        """
        width, height = self.TARGET_SIZE
        images = self.pipe(
            prompt_embeds=prompt_embeds.to(self.device, torch.float16),
            negative_prompt_embeds=negative_prompt_embeds.to(self.device, torch.float16),
            pooled_prompt_embeds=pooled_prompt_embeds.to(self.device, torch.float16),
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(self.device, torch.float16),
            num_inference_steps=denoise_steps,
            generator=generator,
            strength=1.0,
            pose_img=pose_img_tensor,
            text_embeds_cloth=prompt_embeds_c.to(self.device, torch.float16),
            cloth=garm_tensor,
            mask_image=mask,
            image=human_img,
            height=height,
            width=width,
            ip_adapter_image=garm_img.resize(self.TARGET_SIZE),
            guidance_scale=2.0,
        )[0]
        return images

In [None]:
# Build the processor from the loaded config and models.
# NOTE(review): it is bound to the global name `self`, and the cells below call
# `self.<method>(...)` directly. Consider renaming it (e.g. `processor`) in every cell,
# since `self` reads like a method receiver.
self = TryOnProcessor(
    pipeline_config, pipeline_loader
)

In [None]:
# Demo inputs. To use a different copy of the example images, set the VTO_EXAMPLE_DIR
# environment variable instead of editing absolute paths.
EXAMPLE_DIR = os.environ.get('VTO_EXAMPLE_DIR', '/workspace/VTO_demo/example')

# preprocess_images() reads the person photo from the 'background' key of this dict.
human_canva = {
    'background': Image.open(os.path.join(EXAMPLE_DIR, 'human', 'carlos-costa-beard-1.jpg')),
}
garm_img = Image.open(os.path.join(EXAMPLE_DIR, 'cloth', 'kimono.jpg'))
garment_des = 'Short Sleeve Open Front Kimono in Yellow with Tropical Floral Print'
denoise_steps = 30  # number of diffusion sampling steps
seed = 997  # fixed seed for a reproducible result; set to None for an unseeded run

In [None]:
from contextlib import contextmanager


@contextmanager
def pipeline_offloaded(processor):
    """Temporarily move ``processor``'s diffusion pipeline to CPU and empty the CUDA cache.

    On exit the pipeline moves back to ``processor.device``, even if the body raises, so a
    failed segmentation request does not leave the model stranded on CPU for later cells.
    """
    processor.to('cpu')
    torch.cuda.empty_cache()
    try:
        yield
    finally:
        processor.to(processor.device)


garm_img, human_img, human_img_orig = self.preprocess_images(human_canva, garm_img)
org_size = human_img_orig.size

# Segment the input photo while the diffusion pipeline is parked on CPU.
with pipeline_offloaded(self):
    (
        pre_preservation_classes,
        init_segmentation_map,
        init_classes_mapping,
    ) = self.preprocess_submasks(init_image=human_img)

# Keypoints + human parsing -> inpainting mask, minus the regions that must be preserved.
keypoints, model_parse = self.generate_keypoints_and_parse_model(human_img)
mask, mask_gray = self.generate_mask_and_mask_gray(model_parse, keypoints, human_img)
mask, mask_gray = correct_masking(
    preserve_mask=pre_preservation_classes,
    org_image=human_img,
    mask=mask,
    mask_gray=mask_gray,
)

# DensePose conditioning image.
human_img_arg = self.prepare_human_image_for_pose_estimation(human_img)
pose_img = self.generate_pose_image(human_img_arg)

# torch.autocast('cuda') is the device-generic spelling; torch.cuda.amp.autocast() is
# deprecated in newer PyTorch releases.
with torch.no_grad(), torch.autocast(device_type='cuda'):
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        prompt_embeds_c,
    ) = self.encode_prompts(garment_des)

    pose_img_tensor, garm_tensor = self.prepare_images_for_model(pose_img, garm_img)

    generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None

    images = self.generate_images_with_model(
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        denoise_steps,
        generator,
        pose_img_tensor,
        prompt_embeds_c,
        garm_tensor,
        mask,
        human_img,
        garm_img,
    )

result_image = images[0]

# Copy the preserved (non-clothing) input regions back over the generated image.
with pipeline_offloaded(self):
    compose_result = self.postprocess_submasks(
        init_image=human_img,
        init_segmentation_map=init_segmentation_map,
        init_classes_mapping=init_classes_mapping,
        result_image=result_image,
    )
compose_result_res = compose_result.resize(org_size)

In [None]:
# postprocess_submasks() builds compose_result with Image.composite from human_img, and
# Image.composite requires equal sizes, so it normally matches already. Resize only if not,
# which keeps this cell idempotent.
if compose_result.size != human_img.size:
    compose_result = compose_result.resize(human_img.size)

In [None]:
# Debug: recompute the intermediate masks that postprocess_submasks() builds internally, so
# the cells below can inspect them. This calls request_segmentation_results on result_image
# again, i.e. a second request to the configured service_url.
segmentaion_config = self.segmentaion_config

# Regions of the *input* photo that postprocess_submasks may copy back over the result.
soft_preservation_submask = extract_submask(
    segmentation_map=init_segmentation_map,
    submask_classes=segmentaion_config['soft_preservation_classes'],
    classes_mapping=init_classes_mapping
)

# Segmentation of the generated image (used by the per-class submask cells below).
result_segmentation_map, result_classes_mapping = request_segmentation_results(
    url=segmentaion_config['service_url'], 
    image=result_image
)

# Clothing in the generated image; postprocess_submasks excludes it from the copy-back region.
clothing_submask = extract_submask(
    segmentation_map=result_segmentation_map,
    submask_classes=segmentaion_config['clothing_classes'],
    classes_mapping=result_classes_mapping
)

In [None]:
# Visual check: tint the input photo's soft-preservation region green on the composed result.
blend_image_with_color(compose_result, soft_preservation_submask, (0,255,0))

In [None]:
def get_all_submasks(segmentation_map, classes_mapping, min_size=1000):
    """Extract one cleaned mask per class name in ``classes_mapping``.

    Args:
        segmentation_map: segmentation result as returned by ``request_segmentation_results``.
        classes_mapping: mapping whose keys are class names accepted by ``extract_submask``.
        min_size (int): clusters and holes smaller than this many pixels are removed from
            each mask (via ``remove_small_clusters_np``).

    Returns:
        dict: class name -> cleaned mask.
    """
    # Iterate over the mapping that was passed in. Using the global `result_classes_mapping`
    # would apply the result image's class list to every segmentation.
    return {
        name: remove_small_clusters_np(
            extract_submask(
                segmentation_map=segmentation_map,
                submask_classes=[name],
                classes_mapping=classes_mapping,
            ),
            min_size=min_size,
        )
        for name in classes_mapping
    }

In [None]:
# Per-class masks of the input person photo.
init_submasks = get_all_submasks(init_segmentation_map, init_classes_mapping)

In [None]:
# Per-class masks of the generated try-on image.
res_submasks = get_all_submasks(result_segmentation_map, result_classes_mapping)

In [None]:
def get_mask_center(mask):
    """Return the centroid of a 2-D mask's foreground as integer ``(x, y)``.

    Foreground is every element ``> 0``. The coordinates are the mean column and row
    indices, truncated with ``int``. Returns ``None`` for an empty mask.
    """
    rows, cols = np.nonzero(mask > 0)
    if cols.size == 0:
        return None  # no foreground pixels
    return (int(cols.mean()), int(rows.mean()))

In [None]:
def bresenham_line(x1, y1, x2, y2, sizes):
    """Walk a Bresenham ray from ``(x1, y1)`` towards ``(x2, y2)`` until it hits the image border.

    Unlike textbook Bresenham, the walk does not stop at ``(x2, y2)``. The second point only
    sets the direction, and stepping continues until a coordinate is on (or beyond) the border.

    Args:
        x1, y1: start pixel (column, row).
        x2, y2: second pixel; only the direction from the start point matters.
        sizes: ``(width, height)`` of the image, e.g. ``mask.shape[::-1]`` for a 2-D mask.

    Returns:
        list of ``(x, y)`` integer points, starting with ``(x1, y1)``. If the two points
        coincide there is no direction, and ``[(x1, y1)]`` is returned.
    """
    width, height = sizes[0], sizes[1]
    dx = abs(x2 - x1)
    dy = abs(y2 - y1)
    if dx == 0 and dy == 0:
        # With no direction the error terms never change and the loop would never end.
        return [(x1, y1)]

    sx = 1 if x1 < x2 else -1
    sy = 1 if y1 < y2 else -1
    err = dx - dy

    points = []
    while True:
        points.append((x1, y1))
        # Inclusive comparisons also stop a walk that starts outside the image, instead of
        # letting it run away from the border forever.
        if x1 <= 0 or x1 >= width - 1 or y1 <= 0 or y1 >= height - 1:
            break
        e2 = 2 * err
        if e2 > -dy:
            err -= dy
            x1 += sx
        if e2 < dx:
            err += dx
            y1 += sy

    return points

In [None]:
def perpendicular_line(x, y, slope, length):
    """Sample pixels on the line through ``(x, y)`` perpendicular to a line of slope ``slope``.

    Points are spaced about one pixel apart along the perpendicular direction. The segment is
    therefore roughly ``length`` pixels long at any angle and has no gaps: consecutive points
    are 8-connected neighbours, and rounded duplicates are dropped.

    Args:
        x, y: integer pixel at the middle of the perpendicular segment.
        slope: slope dy/dx of the original line. ``0`` means horizontal (the perpendicular is
            vertical); ``np.inf`` means vertical (the perpendicular is horizontal).
        length (int): ``length // 2`` one-pixel steps are taken on each side of ``(x, y)``.

    Returns:
        list of ``(x, y)`` integer points ordered along the perpendicular.
    """
    if np.isinf(slope):
        step_x, step_y = 1.0, 0.0
    elif slope == 0:
        step_x, step_y = 0.0, 1.0
    else:
        # Unit vector along the perpendicular. Stepping x by whole pixels instead would leave
        # gaps (and overshoot `length`) whenever |perp_slope| > 1.
        perp_slope = -1.0 / slope
        norm = np.hypot(1.0, perp_slope)
        step_x, step_y = 1.0 / norm, perp_slope / norm

    half = length // 2
    perp_points = []
    for i in range(-half, half + 1):
        point = (int(round(x + i * step_x)), int(round(y + i * step_y)))
        if not perp_points or perp_points[-1] != point:
            perp_points.append(point)

    return perp_points

def find_intersections(mask2, points):
    """Keep the points that lie inside ``mask2`` and land on a foreground (``> 0``) pixel.

    Args:
        mask2 (np.ndarray): 2-D mask, indexed as ``mask2[y, x]``.
        points: iterable of ``(x, y)`` pixel coordinates.

    Returns:
        list of the qualifying ``(x, y)`` points, in input order. Out-of-bounds points,
        including negative coordinates, are skipped rather than wrapped.
    """
    return [
        (x, y)
        for x, y in points
        if 0 <= x < mask2.shape[1] and 0 <= y < mask2.shape[0] and mask2[y, x] > 0
    ]

def process_masks(center1, center2, mask1, mask2, perp_length=30):
    """Find where perpendiculars to the centre-to-centre line, taken on ``mask1``, hit ``mask2``.

    The method has three steps:
      * extend the line through ``center1`` and ``center2`` to the image border in both
        directions;
      * for every point on that line that lies on ``mask1``, sample a perpendicular segment
        of ``perp_length`` pixels;
      * collect the segment pixels that land on ``mask2``.

    Args:
        center1, center2 (tuple): ``(x, y)`` pixels defining the line (e.g. from
            ``get_mask_center``). ``None`` (empty mask) or identical centres define no line.
        mask1 (np.ndarray): 2-D mask whose pixels on the line spawn perpendiculars.
        mask2 (np.ndarray): 2-D mask the perpendiculars are tested against.
        perp_length (int): length of each perpendicular segment.

    Returns:
        list of unique ``(x, y)`` points on ``mask2``, in scan order. Empty when there is no
        line.
    """
    if center1 is None or center2 is None or tuple(center1) == tuple(center2):
        return []

    x1, y1 = center1
    x2, y2 = center2
    sizes = mask1.shape[::-1]  # (width, height)

    # One ray from each centre through the other, merged into the full line across the image.
    # Sorting makes the scan order, and therefore the output order, deterministic.
    line_points = sorted(
        set(bresenham_line(x1, y1, x2, y2, sizes=sizes))
        | set(bresenham_line(x2, y2, x1, y1, sizes=sizes))
    )

    # A vertical centre line has infinite slope, and its perpendicular is horizontal.
    # Using 0 here would make the "perpendicular" vertical, i.e. parallel to the line.
    slope = (y2 - y1) / (x2 - x1) if x1 != x2 else np.inf

    intersections = {}  # dict used as an insertion-ordered set
    for x, y in line_points:
        if mask1[y, x] > 0:
            perp_points = perpendicular_line(x, y, slope, perp_length)
            for point in find_intersections(mask2, perp_points):
                intersections[point] = None

    return list(intersections)

In [None]:
# The same arm region in the generated result (mask1) and in the input photo (mask2).
mask1, mask2 = res_submasks['Left_Lower_Arm'], init_submasks['Left_Lower_Arm']

In [None]:
# Centroids of the result's left lower arm and left hand; the line through them presumably
# approximates the forearm axis.
center1, center2 = (get_mask_center(res_submasks[part]) for part in ('Left_Lower_Arm', 'Left_Hand'))

In [None]:
# Input-photo lower-arm pixels hit by perpendiculars to the arm->hand line that start on the
# result's lower arm (default perpendicular length).
intersections = process_masks(
    center1,
    center2,
    res_submasks['Left_Lower_Arm'],
    init_submasks['Left_Lower_Arm'],
)

In [None]:
# Empty canvas with the same shape/dtype as the lower-arm mask.
intersection_mask = np.zeros_like(res_submasks['Left_Lower_Arm'])
# Convert (x, y) pairs to (row, col). The reshape keeps a (0, 2) shape when nothing
# intersected, so the next cell's indexing becomes a no-op instead of raising IndexError.
intersections_np = np.asarray(intersections, dtype=int).reshape(-1, 2)[:, ::-1]

In [None]:
# Mark every (row, col) intersection pixel.
intersection_mask[tuple(intersections_np.T)] = True

In [None]:
# Grow the intersection pixels by one pixel (3x3 dilation) so they are easier to see.
dilation_kernel = np.ones((3, 3), np.uint8)
dilated_mask = cv2.dilate(intersection_mask.astype(np.uint8), kernel=dilation_kernel, iterations=1).astype(bool)

In [None]:
# Overlay both lower-arm masks as one label image. Assuming boolean masks, the labels are:
# 0 = neither, 1 = result only, 2 = input photo only, 3 = both.
pi(init_submasks['Left_Lower_Arm']*2 + res_submasks['Left_Lower_Arm'])

In [None]:
# Intersection pixels after the 3x3 dilation.
Im(dilated_mask)

In [None]:
# Raw intersection pixels, before dilation.
Im(intersection_mask)