# Virtual Try-On

This notebook contains the code for installing and jointly running the [FrankMocap](https://github.com/facebookresearch/frankmocap) and [MultiGarmentNetwork](https://github.com/bharat-b7/MultiGarmentNetwork) models, which together implement virtual try-on of a 3D garment model on video.
- FrankMocap captures human motion from video and estimates the person's 3D pose.
- MultiGarmentNetwork is used to fit a garment mesh, frame by frame, onto the SMPL body model produced by FrankMocap.
- The 3D garment models come from the Digital Wardrobe in the MultiGarmentNetwork repository.

# Setting up FrankMocap

## Basic FrankMocap installation

I changed parts of the code in the FrankMocap repository (described in more detail below), so we clone my copy of the repository. The changes are in the `virtual-try-on` branch.

In [None]:
! git clone https://github.com/LukashevichIlya/frankmocap.git

In [None]:
%cd frankmocap/

In [None]:
! git checkout virtual-try-on

In [None]:
! pip install -r docs/requirements.txt

## Installing auxiliary libraries

In [None]:
! apt-get install -y ffmpeg xvfb

In [None]:
! pip install torchgeometry

## Installing PyTorch3D for rendering

In [None]:
# these specific versions had to be installed for compatibility
! pip install 'torch==1.6.0+cu101' -f https://download.pytorch.org/whl/torch_stable.html
! pip install 'torchvision==0.7.0+cu101' -f https://download.pytorch.org/whl/torch_stable.html
! pip install 'pytorch3d==0.2.5'
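
Because the pins above matter for compatibility, it can help to confirm that they were actually installed before going further. The helper below is a minimal sketch and not part of the original notebook: `EXPECTED_PINS`, `installed_version` and `find_mismatches` are hypothetical names, and it assumes Python 3.8 or newer for the standard-library `importlib.metadata`.

```python
from importlib import metadata

# Versions pinned in the cell above (hypothetical helper, not part of FrankMocap).
EXPECTED_PINS = {
    "torch": "1.6.0+cu101",
    "torchvision": "0.7.0+cu101",
    "pytorch3d": "0.2.5",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(expected, lookup=installed_version):
    """Map each package whose installed version differs to (expected, found)."""
    mismatches = {}
    for package, wanted in expected.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, found) in find_mismatches(EXPECTED_PINS).items():
        print(f"{name}: expected {wanted}, found {found}")
```

An empty result from `find_mismatches(EXPECTED_PINS)` means all three pins are in place. The `lookup` parameter exists only so the check can be exercised without the packages installed.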

## Installing the [2D keypoint detector](https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch)

In [None]:
! sh scripts/install_pose2d.sh

## Downloading the pretrained model and other additional data

In [None]:
! sh scripts/download_data_body_module.sh

## Downloading sample data

In [None]:
! sh scripts/download_sample_data.sh

## Downloading the SMPL model

Downloading the SMPL model requires registering on the [website](https://smpl.is.tue.mpg.de/). In my case, the models are stored on my Google Drive.

In [11]:
! mkdir -p extra_data/smpl

In [None]:
%cd ..

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [14]:
! cp /content/drive/MyDrive/SMPL-models/1.0.0/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl /content/frankmocap/extra_data/smpl/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl
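
The `cp` above fails with only a terse message if the Drive path is wrong or the target folder is missing. As a minimal sketch, assuming nothing beyond the standard library, the same copy can be done in Python with an explicit check; `install_smpl_model` is a hypothetical helper name, not something FrankMocap provides.

```python
import shutil
from pathlib import Path


def install_smpl_model(source, target):
    """Copy an SMPL .pkl file to `target`, creating parent directories."""
    source, target = Path(source), Path(target)
    if not source.is_file():
        # Fail early with a readable message instead of a later load error.
        raise FileNotFoundError(f"SMPL model not found: {source}")
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(source, target)
    return target
```

In this notebook the call would take the Drive path from the cell above as `source` and `/content/frankmocap/extra_data/smpl/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl` as `target`.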

# Setting up MultiGarmentNetwork

## Cloning the repository

In [None]:
! git clone https://github.com/bharat-b7/MultiGarmentNetwork.git

## Downloading the Digital Wardrobe

In [None]:
! wget https://datasets.d2.mpi-inf.mpg.de/MultiGarmentNetwork/Multi-Garmentdataset.zip

In [None]:
! wget https://datasets.d2.mpi-inf.mpg.de/MultiGarmentNetwork/Multi-Garmentdataset_02.zip

In [18]:
! unzip -qn Multi-Garmentdataset.zip

In [19]:
! unzip -qn Multi-Garmentdataset_02.zip

In [20]:
! rm Multi-Garmentdataset.zip Multi-Garmentdataset_02.zip
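
The `-n` flag in the `unzip` cells tells it never to overwrite files that already exist, which is what lets the two archives be unpacked into the same folder. A minimal Python sketch of that behavior, using the standard `zipfile` module, could look like this; `extract_missing` is a hypothetical name used only for illustration.

```python
import zipfile
from pathlib import Path


def extract_missing(archive_path, dest_dir="."):
    """Extract only members that do not exist yet, similar to `unzip -n`."""
    dest = Path(dest_dir)
    extracted = []
    with zipfile.ZipFile(archive_path) as archive:
        for member in archive.namelist():
            if not (dest / member).exists():
                archive.extract(member, dest)
                extracted.append(member)
    return extracted
```

The function returns the list of members it actually wrote, so a second run on the same archive returns an empty list.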

## Downloading the SMPL model

In [21]:
! cp /content/drive/MyDrive/SMPL-models/1.0.0/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl /content/MultiGarmentNetwork/assets/neutral_smpl.pkl

## Installing [DIRT](https://github.com/pmh47/dirt)

In [None]:
! git clone https://github.com/pmh47/dirt.git

In [23]:
! sed -i 's|set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=sm_30 --expt-relaxed-constexpr -DNDEBUG")|set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -arch=sm_60 --expt-relaxed-constexpr -DNDEBUG")|' dirt/csrc/CMakeLists.txt
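
The `sed` command above replaces the hard-coded `-arch=sm_30` flag in DIRT's `CMakeLists.txt` with `sm_60`. The right value depends on the GPU the runtime provides, so it may need adjusting. The sketch below shows the same substitution in Python with the architecture as a parameter; `set_cuda_arch` is a hypothetical helper, and it operates on the file's text rather than on the file itself.

```python
import re


def set_cuda_arch(cmake_text, arch="sm_60"):
    """Replace every -arch=sm_XX flag in CMake text with the requested arch."""
    if not re.fullmatch(r"sm_\d+", arch):
        raise ValueError(f"unexpected arch string: {arch!r}")
    return re.sub(r"-arch=sm_\d+", f"-arch={arch}", cmake_text)
```

To apply it, read `dirt/csrc/CMakeLists.txt`, pass its contents through `set_cuda_arch`, and write the result back before running `pip install .`.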

In [None]:
! cd dirt && pip install .

## Installing the [Mesh Package](https://github.com/MPI-IS/mesh)

In [None]:
! git clone https://github.com/MPI-IS/mesh.git

In [None]:
! apt-get install -y libboost-dev xvfb libosmesa6-dev

In [None]:
! cd mesh && make all

Running the cell below produces a warning that the runtime must be restarted. Do so, and resume execution from the cell that follows it.

In [None]:
! pip install numpy -I

In [None]:
! pip install -q opendr

# Running the model

## Modifying the MultiGarmentNetwork code

In [None]:
%cd /content/MultiGarmentNetwork

In [3]:
! find ./ -type f -name "*.py" -exec sed -i 's|import cPickle as pkl|import _pickle as pkl|' {} \;
! find ./ -type f -name "*.py" -exec sed -i 's|import cPickle as pickle|import _pickle as pickle|' {} \;
! find ./ -type f -name "*.py" -exec sed -i -r s/pkl\.load[\(]open[\(]\(.+\)[\)][\)]/pkl.load\(open\(\\1,\ \'rb\'\)\ ,\ encoding=\'latin1\'\)/g  {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|/BS/bharat/work/MGN_release/||" {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|/BS/bharat/work/MGN_final_release/||" {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|from posemapper|from .posemapper|" {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|from serialization|from .serialization|" {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|from verts|from .verts|" {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|^import lbs$|from . import lbs|" {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|/BS/RVH/work/data/smpl_models/neutral/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl|/content/MultiGarmentNetwork/assets/neutral_smpl.pkl|" {} \;
! find ./ -type f -name "*.py" -exec sed -i "s|assets/smpl_vt_ft.pkl|/content/MultiGarmentNetwork/assets/smpl_vt_ft.pkl|" {} \;
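
The `sed` commands above port the Python 2 code base to Python 3: `cPickle` no longer exists, pickle files must be opened in binary mode, and pickles written by Python 2 are read here with `encoding='latin1'`. The sketch below mirrors three of those rewrites with Python's `re` module to make the patterns easier to read. `PY3_REWRITES` and `port_line` are hypothetical names, and the output matches the `sed` result only up to whitespace.

```python
import re

# (pattern, replacement) pairs mirroring a subset of the sed commands above.
PY3_REWRITES = [
    # cPickle was merged into pickle in Python 3; _pickle is its C backend.
    (re.compile(r"import cPickle as (pkl|pickle)"), r"import _pickle as \1"),
    # Open pickles in binary mode and decode Python 2 strings as latin1.
    (re.compile(r"pkl\.load\(open\((.+)\)\)"),
     r"pkl.load(open(\1, 'rb'), encoding='latin1')"),
    # Implicit relative imports are gone in Python 3.
    (re.compile(r"^import lbs$"), "from . import lbs"),
]


def port_line(line):
    """Apply every rewrite to one line of source code."""
    for pattern, replacement in PY3_REWRITES:
        line = pattern.sub(replacement, line)
    return line
```

For example, `port_line("dat = pkl.load(open(file))")` yields the same call shape that appears in `load_smpl_from_file` later in this notebook.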

We include the code of the `dress` function from `dress_SMPL.py` explicitly. In addition, a `model_data` argument was added to the `dress` and `pose_garment` functions to speed up the model. `model_data` is the result of `get_hres_smpl_model_data()`, which loads the SMPL model from disk. Calling it inside other functions slows the whole pipeline down considerably: passing the preloaded data in gives roughly a 3x speedup, from 7.5 seconds to 2.5 seconds per video frame with one Digital Wardrobe item.
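
The idea behind the `model_data` argument is ordinary hoisting of an expensive load out of a per-frame loop. The toy sketch below illustrates it with a stand-in loader that counts its calls; `CountingLoader` and both `process_frames_*` functions are hypothetical and do not touch SMPL at all.

```python
class CountingLoader:
    """Stand-in for an expensive disk load such as get_hres_smpl_model_data()."""

    def __init__(self):
        self.calls = 0

    def load(self):
        self.calls += 1
        return {"v_template": [0.0, 0.0, 0.0]}  # placeholder payload


def process_frames_reloading(frames, loader):
    """Slow variant: the model data is loaded again for every frame."""
    return [(frame, loader.load()) for frame in frames]


def process_frames_cached(frames, loader):
    """Fast variant: load once, then pass the same object to every frame."""
    model_data = loader.load()
    return [(frame, model_data) for frame in frames]
```

With five frames the first variant calls `load` five times and the second calls it once, which is the effect the extra argument has on `dress` and `pose_garment`.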

In [4]:
'''
Code to dress SMPL with registered garments.
Set the "path" variable in this code to the downloaded Multi-Garment Dataset

If you use this code please cite:
"Multi-Garment Net: Learning to Dress 3D People from Images", ICCV 2019

Code author: Bharat
Shout out to Chaitanya for intersection removal code
'''

from psbody.mesh import Mesh, MeshViewers
import numpy as np
import _pickle as pkl
from utils.smpl_paths import SmplPaths
from lib.ch_smpl import Smpl
from utils.interpenetration_ind import remove_interpenetration_fast
from os.path import join, split
from glob import glob

def load_smpl_from_file(file):
    dat = pkl.load(open(file, 'rb') , encoding='latin1')
    dp = SmplPaths(gender=dat['gender'])
    smpl_h = Smpl(dp.get_hres_smpl_model_data())

    smpl_h.pose[:] = dat['pose']
    smpl_h.betas[:] = dat['betas']
    smpl_h.trans[:] = dat['trans']

    return smpl_h

def pose_garment(garment, vert_indices, smpl_params, model_data):
    '''
    :param smpl_params: dict with pose, betas, v_template, trans, gender
    '''
    # dp = SmplPaths(gender=smpl_params['gender'])
    smpl = Smpl(model_data)
    smpl.pose[:] = 0
    smpl.betas[:] = smpl_params['betas']
    # smpl.v_template[:] = smpl_params['v_template']

    offsets = np.zeros_like(smpl.r)
    offsets[vert_indices] = garment.v - smpl.r[vert_indices]
    smpl.v_personal[:] = offsets
    smpl.pose[:] = smpl_params['pose']
    smpl.trans[:] = smpl_params['trans']

    mesh = Mesh(smpl.r, smpl.f).keep_vertices(vert_indices)
    return mesh

def retarget(garment_mesh, src, tgt):
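    '''
    For each garment vertex, finds the closest vertex on the source body and
    moves the garment vertex by that body vertex's displacement from src to tgt.
    :return: retargeted garment mesh (same faces as garment_mesh)
    '''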
    from psbody.mesh import Mesh
    verts, _ = src.closest_vertices(garment_mesh.v)
    verts = np.array(verts)
    tgt_garment = garment_mesh.v - src.v[verts] + tgt.v[verts]
    return Mesh(tgt_garment, garment_mesh.f)

def dress(smpl_tgt, model_data, body_src, garment, vert_inds, garment_tex = None):
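    '''
    :param smpl_tgt: SMPL in the output pose
    :param model_data: preloaded SMPL model data (result of get_hres_smpl_model_data())
    :param garment: garment mesh in t-pose
    :param body_src: garment body in t-pose
    :param garment_tex: texture file
    :param vert_inds: vertex association b/w smpl and garment
    :return: posed garment mesh with interpenetrations removed
    To use texture files, garments must have vt, ft
    '''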
    tgt_params = {'pose': np.array(smpl_tgt.pose.r), 'trans': np.array(smpl_tgt.trans.r), 'betas': np.array(smpl_tgt.betas.r), 'gender': 'neutral'}
    smpl_tgt.pose[:] = 0
    body_tgt = Mesh(smpl_tgt.r, smpl_tgt.f)

    ## Re-target
    ret = retarget(garment, body_src, body_tgt)

    ## Re-pose
    ret_posed = pose_garment(ret, vert_inds, tgt_params, model_data)
    body_tgt_posed = pose_garment(body_tgt, range(len(body_tgt.v)), tgt_params, model_data)

    ## Remove intersections
    ret_posed_interp = remove_interpenetration_fast(ret_posed, body_tgt_posed)
    ret_posed_interp.vt = garment.vt
    ret_posed_interp.ft = garment.ft
    ret_posed_interp.set_texture_image(garment_tex)

    return ret_posed_interp
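
The retargeting step in `retarget` is the core of the fitting: each garment vertex follows its nearest vertex on the source body to wherever that vertex sits on the target body. The sketch below restates that in plain Python on lists of `(x, y, z)` tuples so it can be read without the `psbody.mesh` package. `closest_vertex_index` and `retarget_vertices` are hypothetical names, and the brute-force search is for illustration only.

```python
def closest_vertex_index(point, vertices):
    """Index of the vertex nearest to `point` (brute force, squared distance)."""
    return min(
        range(len(vertices)),
        key=lambda i: sum((p - v) ** 2 for p, v in zip(point, vertices[i])),
    )


def retarget_vertices(garment, src_body, tgt_body):
    """Move each garment vertex by the displacement of its nearest body vertex.

    Assumes src_body and tgt_body share vertex ordering, as two SMPL meshes do.
    """
    result = []
    for point in garment:
        i = closest_vertex_index(point, src_body)
        result.append(tuple(
            g - s + t for g, s, t in zip(point, src_body[i], tgt_body[i])
        ))
    return result
```

The garment keeps its offset from the body surface, so a loose garment stays loose on the new body shape.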

## Modifying the FrankMocap code

In [None]:
%cd /content/frankmocap

Adding textures required changes to `screen_free_visualizer.py` and `p3d_renderer.py`. These changes are in the `virtual-try-on` branch of my fork of the FrankMocap repository.

`demo_bodymocap.py` also has to change, specifically the `run_body_mocap` function, so that it can work with garments from MGN. The next cell does this.

In [6]:
# Copyright (c) Facebook, Inc. and its affiliates.

import os
import sys
import os.path as osp
import torch
from torchvision.transforms import Normalize
import numpy as np
import cv2
import argparse
import json
import pickle
from datetime import datetime

from demo.demo_options import DemoOptions
from bodymocap.body_mocap_api import BodyMocap
from bodymocap.body_bbox_detector import BodyPoseEstimator
import mocap_utils.demo_utils as demo_utils
import mocap_utils.general_utils as gnu
from mocap_utils.timer import Timer
from mocap_utils.coordconv import convert_smpl_to_bbox, convert_bbox_to_oriIm
from renderer.p3d_renderer import Pytorch3dRenderer

import renderer.image_utils as imu
from renderer.viewer2D import ImShow


def dress_garment(img_original, smpl, model_data, pred_output, 
                  garment_org_body_unposed, garment_unposed, 
                  vert_inds, garment_tex):
    smpl.pose[:] = pred_output['pred_body_pose'].reshape(72)
    smpl.betas[:] = pred_output['pred_betas'].reshape(10)
    smpl.trans[:] = 0
    garment_unposed.set_texture_image(garment_tex)

    new_garment = dress(smpl, model_data, garment_org_body_unposed, garment_unposed, vert_inds, garment_tex)
    pred_vertices = new_garment.v
    camScale = pred_output['pred_camera'][0]
    camTrans = pred_output['pred_camera'][1:]
    bboxTopLeft = pred_output['bbox_top_left']
    boxScale_o2n = pred_output['bbox_scale_ratio']
    pred_vertices_bbox = convert_smpl_to_bbox(pred_vertices, camScale, camTrans)
    pred_vertices_img = convert_bbox_to_oriIm(pred_vertices_bbox, boxScale_o2n, bboxTopLeft, 
                                              img_original.shape[1], img_original.shape[0])
    
    pred_mesh_dict = dict(vertices=pred_vertices_img,
                          faces=new_garment.f.astype(np.int32),
                          vertices_texture=new_garment.vt,
                          faces_texture=new_garment.ft,
                          texture=new_garment.texture_image[:, :, [2, 1, 0]])
    return pred_mesh_dict

    

def run_body_mocap(args, body_bbox_detector, body_mocap, visualizer, smpl, model_data, 
                   garment_org_body_unposed_1, garment_unposed_1, vert_inds_1, garment_tex_1,
                   garment_org_body_unposed_2, garment_unposed_2, vert_inds_2, garment_tex_2):
    #Setup input data to handle different types of inputs
    input_type, input_data = demo_utils.setup_input(args)

    cur_frame = args.start_frame
    video_frame = 0
    timer = Timer()
    while True:
        timer.tic()
        # load data
        load_bbox = False

        if input_type =='image_dir':
            if cur_frame < len(input_data):
                image_path = input_data[cur_frame]
                img_original_bgr  = cv2.imread(image_path)
            else:
                img_original_bgr = None

        elif input_type == 'bbox_dir':
            if cur_frame < len(input_data):
                print("Use pre-computed bounding boxes")
                image_path = input_data[cur_frame]['image_path']
                hand_bbox_list = input_data[cur_frame]['hand_bbox_list']
                body_bbox_list = input_data[cur_frame]['body_bbox_list']
                img_original_bgr  = cv2.imread(image_path)
                load_bbox = True
            else:
                img_original_bgr = None

        elif input_type == 'video':      
            _, img_original_bgr = input_data.read()
            if video_frame < cur_frame:
                video_frame += 1
                continue
            # save the obtained video frames
            image_path = osp.join(args.out_dir, "frames", f"{cur_frame:05d}.jpg")
            if img_original_bgr is not None:
                video_frame += 1
                if args.save_frame:
                    gnu.make_subdir(image_path)
                    cv2.imwrite(image_path, img_original_bgr)

        elif input_type == 'webcam':    
            _, img_original_bgr = input_data.read()

            if video_frame < cur_frame:
                video_frame += 1
                continue
            # save the obtained video frames
            image_path = osp.join(args.out_dir, "frames", f"scene_{cur_frame:05d}.jpg")
            if img_original_bgr is not None:
                video_frame += 1
                if args.save_frame:
                    gnu.make_subdir(image_path)
                    cv2.imwrite(image_path, img_original_bgr)
        else:
            assert False, "Unknown input_type"

        cur_frame +=1
        if img_original_bgr is None or cur_frame > args.end_frame:
            break   
        print("--------------------------------------")

        if load_bbox:
            body_pose_list = None
        else:
            body_pose_list, body_bbox_list = body_bbox_detector.detect_body_pose(
                img_original_bgr)
        hand_bbox_list = [None, ] * len(body_bbox_list)

        # save the obtained body & hand bbox to json file
        if args.save_bbox_output: 
            demo_utils.save_info_to_json(args, image_path, body_bbox_list, hand_bbox_list)

        if len(body_bbox_list) < 1: 
            print(f"No body detected: {image_path}")
            continue

        #Sort the bbox using bbox size 
        # (to make the order as consistent as possible without tracking)
        bbox_size =  [ (x[2] * x[3]) for x in body_bbox_list]
        idx_big2small = np.argsort(bbox_size)[::-1]
        body_bbox_list = [ body_bbox_list[i] for i in idx_big2small ]
        if args.single_person and len(body_bbox_list)>0:
            body_bbox_list = [body_bbox_list[0], ]       

        # Body Pose Regression
        pred_output_list = body_mocap.regress(img_original_bgr, body_bbox_list)
        assert len(body_bbox_list) == len(pred_output_list)

        # extract mesh for rendering (vertices in image space and faces) from pred_output_list
        pred_mesh_list = demo_utils.extract_mesh_from_output(pred_output_list)

        # Dress only the first prediction, i.e. the person with the largest bounding box
        pred_output = pred_output_list[0]

        pred_mesh_dict_1 = dress_garment(img_original_bgr, smpl, model_data, pred_output, 
                                         garment_org_body_unposed_1, garment_unposed_1, 
                                         vert_inds_1, garment_tex_1)
        
        pred_mesh_dict_2 = dress_garment(img_original_bgr, smpl, model_data, pred_output, 
                                         garment_org_body_unposed_2, garment_unposed_2, 
                                         vert_inds_2, garment_tex_2)
        
        pred_mesh_list = [pred_mesh_dict_1, pred_mesh_dict_2]

        # visualization
        res_img = visualizer.visualize(
            img_original_bgr,
            pred_mesh_list=pred_mesh_list)
        
        # show result in the screen
        if not args.no_display:
            res_img = res_img.astype(np.uint8)
            ImShow(res_img)

        # save result image
        if args.out_dir is not None:
            demo_utils.save_res_img(args.out_dir, image_path, res_img)

        # save predictions to pkl
        if args.save_pred_pkl:
            demo_type = 'body'
            demo_utils.save_pred_to_pkl(
                args, demo_type, image_path, body_bbox_list, hand_bbox_list, pred_output_list)

        timer.toc(bPrint=True,title="Time")
        print(f"Processed : {image_path}")

    #save images as a video
    if not args.no_video_out and input_type in ['video', 'webcam']:
        demo_utils.gen_video_out(args.out_dir, args.seq_name)

    if input_type =='webcam' and input_data is not None:
        input_data.release()
    cv2.destroyAllWindows()
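
Note that `run_body_mocap` dresses only `pred_output_list[0]`. Because the bounding boxes are sorted by area beforehand, that is the person with the largest box in the frame. A minimal sketch of this selection rule, assuming boxes are `(x, y, w, h)` tuples as the `x[2] * x[3]` expression suggests, is shown below; `order_by_area` and `select_subject` are hypothetical names.

```python
def order_by_area(bboxes):
    """Sort (x, y, w, h) boxes from largest to smallest area."""
    return sorted(bboxes, key=lambda box: box[2] * box[3], reverse=True)


def select_subject(bboxes):
    """Return the largest box, or None when nothing was detected."""
    ordered = order_by_area(bboxes)
    return ordered[0] if ordered else None
```

Without tracking, the subject can change between frames whenever a different person's box becomes the largest one.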

## Running on a single video

Copy a few sample videos into the folder the model reads from.

In [7]:
! cp -a /content/drive/MyDrive/Videos/.  /content/frankmocap/sample_data

Initialize the FrankMocap and MGN models and choose garments for the upper and lower body.

In [None]:
args = DemoOptions().parser.parse_args('--input_path ./sample_data/1022653957-preview.mp4 --out_dir ./mocap_output --renderer_type pytorch3d --no_display'.split())
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
assert torch.cuda.is_available(), "Current version only supports GPU"

# Set bbox detector
body_bbox_detector = BodyPoseEstimator()

# Set mocap regressor
use_smplx = args.use_smplx
checkpoint_path = args.checkpoint_body_smplx if use_smplx else args.checkpoint_body_smpl
print("use_smplx", use_smplx)
body_mocap = BodyMocap(checkpoint_path, args.smpl_dir, device, use_smplx)

# Set Visualizer
if args.renderer_type in ['pytorch3d', 'opendr']:
    from renderer.screen_free_visualizer import Visualizer
else:
    from renderer.visualizer import Visualizer
visualizer = Visualizer(args.renderer_type)

# MGN part
path = '/content/Multi-Garment_dataset/'
all_scans = glob(path + '*')
garment_classes = ['Pants', 'ShortPants', 'ShirtNoCoat', 'TShirtNoCoat', 'LongCoat']
gar_dict = {}
for gar in garment_classes:
    gar_dict[gar] = glob(join(path, '*', gar + '.obj'))

dp = SmplPaths()
vt, ft = dp.get_vt_ft_hres()
model_data = dp.get_hres_smpl_model_data()
smpl = Smpl(model_data)

## This file contains correspondances between garment vertices and smpl body
fts_file = '/content/MultiGarmentNetwork/assets/garment_fts.pkl'
vert_indices, fts = pkl.load(open(fts_file, 'rb') , encoding='latin1')
fts['naked'] = ft

## Choose the upper-body garment (a LongCoat here) and set it up
garment_type_tshirt = 'LongCoat'
index = np.random.randint(0, len(gar_dict[garment_type_tshirt]))   ## Randomly pick from the digital wardrobe
path_tshirt = split(gar_dict[garment_type_tshirt][index])[0]

garment_tshirt_org_body_unposed = load_smpl_from_file(join(path_tshirt, 'registration.pkl'))
garment_tshirt_org_body_unposed.pose[:] = 0
garment_tshirt_org_body_unposed.trans[:] = 0
garment_tshirt_org_body_unposed = Mesh(garment_tshirt_org_body_unposed.v, garment_tshirt_org_body_unposed.f)

garment_tshirt_unposed = Mesh(filename=join(path_tshirt, garment_type_tshirt + '.obj'))
garment_tshirt_tex = join(path_tshirt, 'multi_tex.jpg')

vert_inds_tshirt = vert_indices[garment_type_tshirt]

## Choose Pants and set up garment
garment_type_pants = 'Pants'
index = np.random.randint(0, len(gar_dict[garment_type_pants]))   ## Randomly pick from the digital wardrobe
path_pants = split(gar_dict[garment_type_pants][index])[0]

garment_pants_org_body_unposed = load_smpl_from_file(join(path_pants, 'registration.pkl'))
garment_pants_org_body_unposed.pose[:] = 0
garment_pants_org_body_unposed.trans[:] = 0
garment_pants_org_body_unposed = Mesh(garment_pants_org_body_unposed.v, garment_pants_org_body_unposed.f)

garment_pants_unposed = Mesh(filename=join(path_pants, garment_type_pants + '.obj'))
garment_pants_tex = join(path_pants, 'multi_tex.jpg')

vert_inds_pants = vert_indices[garment_type_pants]
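
The cell above picks garments with an unseeded `np.random.randint`, so every run may dress the subject differently. Where a repeatable choice is preferable, the selection could be wrapped as in the sketch below. `index_wardrobe` and `pick_garment_dir` are hypothetical helpers built on the standard library, and the folder layout (one `.obj` per garment class inside each scan folder) is taken from the `glob` pattern used above.

```python
import random
from glob import glob
from os.path import dirname, join


def index_wardrobe(root, garment_classes):
    """Map each garment class to the .obj files found one level below `root`."""
    return {
        name: sorted(glob(join(root, "*", name + ".obj")))
        for name in garment_classes
    }


def pick_garment_dir(wardrobe, garment_class, seed=None):
    """Pick one scan directory containing `garment_class`; seedable for reruns."""
    candidates = wardrobe.get(garment_class, [])
    if not candidates:
        raise ValueError(f"no garments of class {garment_class!r} found")
    return dirname(random.Random(seed).choice(candidates))
```

Passing the same `seed` returns the same scan folder on every run, and an empty class raises a readable error instead of failing inside the random call.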

In [31]:
! rm -rf /content/frankmocap/mocap_output

Run the model on the selected video.

In [None]:
run_body_mocap(args, body_bbox_detector, body_mocap,
               visualizer, smpl, model_data,
               garment_tshirt_org_body_unposed, 
               garment_tshirt_unposed, vert_inds_tshirt, garment_tshirt_tex,
               garment_pants_org_body_unposed,
               garment_pants_unposed, vert_inds_pants, garment_pants_tex)

## Final video

Display the resulting video.

In [33]:
# Source: https://stackoverflow.com/a/65273831

from IPython.display import HTML
from base64 import b64encode
import os

# Input video path
save_path = "mocap_output/1022653957-preview.mp4"

# Compressed video path
compressed_path = "mocap_output/dancing_man_compressed.mp4"

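# -y overwrites an existing output file, so rerunning the cell does not stall on a prompt
os.system(f"ffmpeg -y -i {save_path} -vcodec libx264 {compressed_path}")

# Show video
with open(compressed_path, 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()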
HTML("""
<video width=800 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)