# Install Libraries

In [None]:
!pip install tensorflow==2.18.1 mediapipe numpy matplotlib opencv-python
!pip install neurite@git+https://github.com/adalca/neurite.git@40c6d0e277b12dc9dddb6e76f2dbdd373b7d22b1
!pip install voxelmorph@git+https://github.com/voxelmorph/voxelmorph.git@923a37d51b0c8d93eb576156c07ecb25c2a4e730

# Import Libraries

In [None]:
import tensorflow as tf
import voxelmorph as vxm
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Load Images

In [None]:
# Inputs: a photo of the person and a photo of the garment to try on.
# REPLACE WITH YOUR IMAGE PATH
image_path = 'IMAGE_PATH'
# REPLACE WITH YOUR CLOTH IMAGE PATH
cloth_image_path = 'CLOTH_IMAGE_PATH'


def _read_rgb(path):
    """Read an image file with OpenCV and convert it from BGR to RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread``
            returns None instead of raising, which would otherwise surface as
            an opaque assertion error inside ``cv2.cvtColor``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f'Could not read image at {path!r}; check the path set above.')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


image = _read_rgb(image_path)
cloth_image = _read_rgb(cloth_image_path)

# Utility Functions for Inference

In [None]:
# Download the MediaPipe pose landmarker (heavy, float16) to pose_landmarker.task,
# the file the PoseLandmarker detector is created from. -O overwrites it on every run.
!wget -O pose_landmarker.task -q https://storage.googleapis.com/mediapipe-models/pose_landmarker/pose_landmarker_heavy/float16/1/pose_landmarker_heavy.task

In [None]:
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
import numpy as np


def draw_landmarks_on_image(image_size, detection_result, image=None):
    """Draw every detected pose skeleton onto a canvas and return it.

    Args:
        image_size: shape given to ``np.zeros`` to build a blank (float64)
            canvas; only used when ``image`` is None.
        detection_result: pose landmarker result; only its ``pose_landmarks``
            attribute (one sequence of landmarks per detected pose) is read.
        image: optional array to draw on. A copy is drawn on, so the caller's
            array is left untouched.

    Returns:
        The canvas with all poses drawn in MediaPipe's default pose style.
        If no pose was detected, the canvas is returned as created.
    """
    poses = detection_result.pose_landmarks
    canvas = np.zeros(image_size) if image is None else np.copy(image)

    for landmarks in poses:
        # drawing_utils takes a NormalizedLandmarkList proto, so one is rebuilt
        # from the x/y/z of each landmark in the detection result.
        proto = landmark_pb2.NormalizedLandmarkList()
        proto.landmark.extend(
            [landmark_pb2.NormalizedLandmark(x=lm.x, y=lm.y, z=lm.z) for lm in landmarks]
        )
        solutions.drawing_utils.draw_landmarks(
            canvas,
            proto,
            solutions.pose.POSE_CONNECTIONS,
            solutions.drawing_styles.get_default_pose_landmarks_style(),
        )
    return canvas

In [None]:
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# A single PoseLandmarker, built from the downloaded model bundle and reused by
# get_mediapipe_skeleton for every image.
POSE_MODEL_PATH = 'pose_landmarker.task'
base_options = python.BaseOptions(model_asset_path=POSE_MODEL_PATH)
options = vision.PoseLandmarkerOptions(base_options=base_options)
detector = vision.PoseLandmarker.create_from_options(options)

In [None]:
def get_mediapipe_skeleton(image) :
    """Render a single-channel MediaPipe pose skeleton for one image.

    Args:
        image: H x W x 3 array, presumably RGB with float values in [0, 1]
            (callers pass a resized image divided by 255).

    Returns:
        float32 array of shape (H, W, 1) with values in [0, 1]: the detected
        pose drawn on a black canvas and averaged over colour channels.
        All zeros when no pose is detected.
    """
    # Round rather than truncate when going back to 8-bit: a float round-trip
    # such as k / 255 * 255 can land just below the integer k, and a bare
    # astype(np.uint8) would then give k - 1. Clipping stops out-of-range
    # values from wrapping around in the uint8 cast.
    input_int = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=input_int)
    detection_result = detector.detect(mp_image)
    landmark = draw_landmarks_on_image(input_int.shape, detection_result)
    # Collapse the coloured drawing to one channel by averaging over RGB.
    landmark = np.expand_dims(np.mean(landmark, axis=-1), axis=-1)
    return np.array(landmark, dtype=np.float32) / 255

In [None]:
def get_person_representations(image, size='small') :
    """Build the cloth-agnostic person image and the pose skeleton for try-on.

    Args:
        image: person photo of any size; presumably an RGB uint8 array (it is
            resized and divided by 255 here) - TODO confirm for other dtypes.
        size: 'small' for 256x192 (H x W) outputs, 'large' for 512x384.

    Returns:
        Tuple ``(agnostic_representation, pose_skeleton)``, each with a leading
        batch axis of 1. The agnostic representation is the normalised image
        with every pixel that the segmentation model labels as class 0 set to
        zero.

    Raises:
        ValueError: if ``size`` is neither 'small' nor 'large'.
    """
    if size not in ['small', 'large'] :
        raise ValueError('Argument size must be either "small" or "large"')

    # Normalised batches at both resolutions; cv2.resize takes (width, height).
    small_batch = np.expand_dims(cv2.resize(image, (192, 256)), axis=0) / 255
    large_batch = np.expand_dims(cv2.resize(image, (384, 512)), axis=0) / 255

    # The skeleton fed to the segmentation model and the segmentation itself are
    # always computed at the small resolution.
    skeleton = np.expand_dims(get_mediapipe_skeleton(small_batch[0]), axis=0)
    class_map = np.argmax(agnostic_seg_model((small_batch, skeleton)), axis=-1)
    # Binary keep-mask: class 0 -> 0, any other class -> 1.
    keep_mask = np.expand_dims(np.where(class_map == 0, 0, 1), axis=-1).astype(np.float32)

    if size == 'small' :
        return small_batch * keep_mask, skeleton

    # Nearest-neighbour upsampling keeps the mask binary. cv2.resize drops the
    # trailing singleton channel, so the channel and batch axes are re-added.
    large_mask = cv2.resize(keep_mask[0], (384, 512), interpolation=cv2.INTER_NEAREST)
    large_mask = large_mask[np.newaxis, ..., np.newaxis]
    # Pose detection is redone on the large image for a full-resolution skeleton.
    large_skeleton = np.expand_dims(get_mediapipe_skeleton(large_batch[0]), axis=0)
    return large_batch * large_mask, large_skeleton

In [None]:
def get_cloth_segmentation(cloth_image, size='small') :
    """Predict a binary garment mask for a cloth image.

    Args:
        cloth_image: garment photo of any size; presumably an RGB uint8 array
            (it is resized and divided by 255 here) - TODO confirm.
        size: 'small' for a 256x192 (H x W) mask, 'large' for 512x384.

    Returns:
        float32 mask with a leading batch axis of 1: 1 where the model's output
        exceeds 0.5, else 0.

    Raises:
        ValueError: if ``size`` is neither 'small' nor 'large'.
    """
    if size not in ['small', 'large'] :
        raise ValueError('Argument size must be either "small" or "large"')

    # The segmentation model always runs at 256x192; cv2.resize takes (width, height).
    model_input = np.expand_dims(cv2.resize(cloth_image, (192, 256)), axis=0) / 255
    mask = np.where(clothes_seg_model(model_input) > 0.5, 1, 0).astype(np.float32)

    if size == 'small' :
        return mask

    # Nearest-neighbour upsampling keeps the mask binary; the channel axis that
    # cv2.resize drops and the batch axis are added back.
    upsampled = cv2.resize(mask[0], (384, 512), interpolation=cv2.INTER_NEAREST)
    return upsampled[np.newaxis, ..., np.newaxis]

In [None]:
def generate_tryon(image, cloth_image, size='small') :
    """Run the full virtual try-on pipeline for one person/cloth pair.

    Steps: build the agnostic person representation and pose skeleton, segment
    the garment, predict deformation fields with the warp U-Net, warp the cloth
    with a VoxelMorph SpatialTransformer, and synthesise the final image with
    the try-on generator.

    Args:
        image: person photo (see get_person_representations).
        cloth_image: garment photo (see get_cloth_segmentation).
        size: 'small' (256x192 models) or 'large' (512x384 models).

    Returns:
        The first (and only) item of the generator's output batch.

    Raises:
        ValueError: from get_person_representations if ``size`` is invalid.
    """
    person_agnostic, pose_skeleton = get_person_representations(image, size=size)
    cloth_segmentation = get_cloth_segmentation(cloth_image, size=size)

    if size == 'small' :
        target_wh, warp_model, generator = (192, 256), warp_unet_small, tryon_generator_small
    else :
        target_wh, warp_model, generator = (384, 512), warp_unet_large, tryon_generator_large

    # cv2.resize defaults to INTER_LINEAR, so both sizes use bilinear resampling.
    resized_cloth = cv2.resize(cloth_image, target_wh, interpolation=cv2.INTER_LINEAR)
    cloth_batch = np.expand_dims(resized_cloth, axis=0) / 255

    deformation_fields = warp_model((person_agnostic, pose_skeleton, cloth_batch, cloth_segmentation))
    warped_cloth = vxm.layers.SpatialTransformer()([cloth_batch, deformation_fields]).numpy()
    tryon = generator((person_agnostic, pose_skeleton, warped_cloth))
    return tryon[0]

# Load Models

In [None]:
# Root directory of the model files. The default is Kaggle's dataset mount
# point; change it to run the notebook elsewhere (sub-paths stay the same).
MODEL_ROOT = '/kaggle/input'

# Segmentation models, run at 256x192 for both output sizes.
agnostic_seg_model = tf.keras.models.load_model(f'{MODEL_ROOT}/viton-agnostic-segmentation/tensorflow2/v1/1/agnostic_segmentation.keras')
clothes_seg_model = tf.keras.models.load_model(f'{MODEL_ROOT}/viton-clothes-segmentation/tensorflow2/v1/1/clothes_segmentation.keras')
# Warp U-Net and try-on generator for the 256x192 pipeline.
warp_unet_small = tf.keras.models.load_model(f'{MODEL_ROOT}/viton-small/tensorflow2/small/1/warp_unet.keras')
tryon_generator_small = tf.keras.models.load_model(f'{MODEL_ROOT}/viton-small/tensorflow2/small/1/tryon_generator.keras')
# Warp U-Net and try-on generator for the 512x384 pipeline.
warp_unet_large = tf.keras.models.load_model(f'{MODEL_ROOT}/viton-large/tensorflow2/large/1/warp_unet_large.keras')
tryon_generator_large = tf.keras.models.load_model(f'{MODEL_ROOT}/viton-large/tensorflow2/large/1/tryon_generator_large.keras')

# Inference

In [None]:
# Run the same person/cloth pair through both pipelines for a side-by-side comparison.
tryon = generate_tryon(image, cloth_image, size='small')        # 256x192 models
tryon_large = generate_tryon(image, cloth_image, size='large')  # 512x384 models

In [None]:
# Show both try-on results side by side; panel titles identify the resolution.
fig, (ax_small, ax_large) = plt.subplots(1, 2, figsize=(18, 8))

ax_small.imshow(tryon)
ax_small.set_title('Small model (256x192)')
ax_small.axis('off')

ax_large.imshow(tryon_large)
ax_large.set_title('Large model (512x384)')
ax_large.axis('off')

fig.tight_layout()
plt.show()