In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import trimesh
from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from utils.multi_garment_dataset import Multi_Garment_Dataset
from utils.util import load_obj, load_pickle_file, write_pickle_file, get_f2vts
from models.networks.smpl import SMPL
from models.networks.render import SMPLRenderer

In [2]:
# Dataset / rendering configuration used throughout the notebook.
data_root = 'data/Multi-Garment_dataset'  # one sub-directory per subject ID
image_size = 256  # side length of rendered images (fim/wim are 256x256 in the outputs)
tex_size = 3      # forwarded to SMPLRenderer; presumably per-face texture resolution — TODO confirm
batch_size = 2
num_frame = 2     # frames per sample (second dim of `poses` / `cams`)

# The notebook is stochastic (DataLoader shuffling, random camera-scale jitter),
# so seed torch (CPU and CUDA) and numpy up front to make Restart & Run All reproducible.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [3]:
# Multi-Garment dataset wrapper, served in shuffled mini-batches.
train_dataset = Multi_Garment_Dataset(
    data_root=data_root,
    num_frame=num_frame,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
# High-resolution SMPL body model, plus a renderer built on its high-res faces;
# both are moved to the current CUDA device.
smpl = SMPL(
    pkl_path='pretrains/smpl_model.pkl',
    isHres=True,
).cuda()
smpl_renderer = SMPLRenderer(
    smpl.faces_hres,
    image_size=image_size,
    tex_size=tex_size,
).cuda()

In [5]:
# Inspect a single batch. `next(iter(...))` takes exactly one batch, which is all
# a loop guarded by an always-true `if i >= 0: break` ever did.
data = next(iter(train_loader))
for key in data:
    print(key, data[key].size())

with torch.no_grad():
    shape = data['shape'].cuda()
    poses = data['poses'].cuda()
    cams = data['cams'].cuda()
    v_personal = data['v_personal'].cuda()

    # Keep only the first frame of each sample (dim 1 indexes the num_frame frames).
    pose = poses[:, 0, :].contiguous()
    cam = cams[:, 0, :].contiguous()

    # Naked SMPL body vs. the body displaced by the per-vertex garment offsets.
    verts = smpl(shape, pose)
    verts_personal = smpl(shape, pose, v_personal)

    # Project both meshes into image space with the first-frame camera
    # (note: this overwrites the 3-D vertices with projected ones).
    verts = smpl_renderer.project_to_image(verts, cam, flip=True, withz=True)
    verts_personal = smpl_renderer.project_to_image(verts_personal, cam, flip=True, withz=True)

    # Sample a per-face texture from the UV image.
    uv_img = data['uv_image'].cuda()
    f2vts = data['f2vts'].cuda()
    tex_gt = smpl_renderer.extract_tex(uv_img, smpl_renderer.points_to_sampler(f2vts))

    img_masked, mask = smpl_renderer(verts, tex_gt)
    img_masked_personal, mask_personal = smpl_renderer(verts_personal, tex_gt)

    # fim: presumably a per-pixel face-index map; wim: presumably per-pixel
    # barycentric weights (B, H, W, 3) — TODO confirm against SMPLRenderer.
    fim, wim = smpl_renderer.render_fim_wim(verts_personal)

shape torch.Size([2, 10])
poses torch.Size([2, 2, 72])
cams torch.Size([2, 2, 3])
v_personal torch.Size([2, 27554, 3])
uv_image torch.Size([2, 3, 2048, 2048])
f2vts torch.Size([2, 55104, 3, 2])


In [6]:
# Shapes of the face-index map and weight map rendered from the personal mesh.
print(fim.shape, wim.shape)

torch.Size([2, 256, 256]) torch.Size([2, 256, 256, 3])


In [7]:
# Value range of the face-index map. In the recorded output the maximum (55103) is one
# less than the number of high-res faces (55104); -1 presumably marks background pixels.
print(torch.min(fim), torch.max(fim))

tensor(-1, device='cuda:0', dtype=torch.int32) tensor(55103, device='cuda:0', dtype=torch.int32)


In [None]:
# View the (projected) naked SMPL mesh of the first sample in the batch.
mesh = trimesh.Trimesh(
    vertices=verts[0].cpu(),
    faces=smpl.faces_hres.cpu(),
    process=False,  # keep SMPL's vertex order and count untouched
)
mesh.show()

In [None]:
# View the (projected) clothed/personal mesh of the first sample in the batch.
mesh = trimesh.Trimesh(
    vertices=verts_personal[0].cpu(),
    faces=smpl.faces_hres.cpu(),
    process=False,  # keep SMPL's vertex order and count untouched
)
mesh.show()

In [None]:
# First rendered image: (C, H, W) floats (presumably in [0, 1]) -> (H, W, C) uint8 for imshow.
img_masked_vis = np.transpose(img_masked[0].cpu().numpy() * 255, (1, 2, 0)).astype(np.uint8)
plt.imshow(img_masked_vis)

In [None]:
# First rendered personal image: (C, H, W) floats -> (H, W, C) uint8 for imshow.
img_masked_personal_vis = np.transpose(
    img_masked_personal[0].cpu().numpy() * 255, (1, 2, 0)
).astype(np.uint8)
plt.imshow(img_masked_personal_vis)

In [None]:
# Recover a per-face texture from the rendered personal image
# (presumably by sampling image colours at the projected faces — TODO confirm).
tex = smpl_renderer.extract_tex_from_image(img_masked_personal, verts_personal)
print(tex.shape)

In [None]:
# Re-render the personal mesh using the texture recovered from the image.
(img_masked_personal_tex,
 mask_personal_tex) = smpl_renderer(verts_personal, tex)

In [None]:
# Re-rendered image: (C, H, W) floats -> (H, W, C) uint8 for imshow.
img_masked_personal_tex_vis = np.transpose(
    img_masked_personal_tex[0].cpu().numpy() * 255, (1, 2, 0)
).astype(np.uint8)
plt.imshow(img_masked_personal_tex_vis)

In [None]:
# All-zero pose parameters shaped like `pose` (the rest pose, per the `_T` naming).
pose_T = torch.zeros(pose.shape, dtype=torch.float32, device='cuda')
print(pose_T.shape)

In [None]:
# Clothed/personal vertices for the same body shape, in the zero pose.
verts_T_personal = smpl(shape, pose_T, v_personal)
print(verts_T_personal.shape)

In [None]:
# View the zero-pose personal mesh of the first sample.
mesh = trimesh.Trimesh(
    vertices=verts_T_personal[0].cpu(),
    faces=smpl.faces_hres.cpu(),
    process=False,  # keep SMPL's vertex order and count untouched
)
mesh.show()

In [None]:
# Subject IDs are the sub-directory names of data_root (each holds registration.pkl,
# garment .obj files, textures, ...). os.listdir returns entries in arbitrary order,
# so sort to make `people_IDs_list[0]` deterministic, and skip stray files that
# would otherwise be treated as subjects.
people_IDs_list = sorted(
    entry for entry in os.listdir(data_root)
    if os.path.isdir(os.path.join(data_root, entry))
)
print(len(people_IDs_list))

In [None]:
# Garment types looked up as `<name>.obj` inside each subject directory.
garment_classes = [
    'Pants',
    'ShortPants',
    'ShirtNoCoat',
    'TShirtNoCoat',
    'LongCoat',
]

In [None]:
def get_hres(v, f):
    """Loop-subdivide a mesh once to obtain an upsampled version.

    Args:
        v: vertex array; it is flattened, and the subdivided result is reshaped to (-1, 3).
        f: face index array.

    Returns:
        Tuple ``(nv, nf, mapping)``:
            - nv: vertices of the upsampled mesh, reshaped to (-1, 3)
            - nf: faces of the upsampled mesh
            - mapping: operator from flattened low-res vertices to flattened high-res vertices
    """
    # Imported lazily so opendr is only needed when subdivision is actually requested.
    from opendr.topology import loop_subdivider

    mapping, nf = loop_subdivider(v, f)
    flat_hres = mapping.dot(v.ravel())
    nv = flat_hres.reshape(-1, 3)
    return nv, nf, mapping

def get_vt_ft():
    """Load the SMPL UV data stored in ``pretrains/smpl_vt_ft.pkl``.

    Returns:
        ``(vt, ft)``: the two entries of the pickle, presumably UV coordinates
        and UV face indices (per the names).
    """
    uv_coords, uv_faces = load_pickle_file('pretrains/smpl_vt_ft.pkl')
    return uv_coords, uv_faces

def get_vt_ft_hres():
    """Return loop-subdivided (high-res) UV coordinates and UV faces.

    The 2-D UV coordinates are padded with a constant third column of ones so they
    can go through ``get_hres`` (which reshapes to (-1, 3)); the padding column is
    dropped again afterwards.
    """
    vt, ft = get_vt_ft()
    vt_padded = np.concatenate([vt, np.ones((vt.shape[0], 1))], axis=1)
    vt_hres, ft_hres, _ = get_hres(vt_padded, ft)
    return vt_hres[:, :2], ft_hres

In [None]:
# Low-res SMPL UV coordinates and UV faces.
vt, ft = get_vt_ft()
print(vt.shape, ft.shape, sep='\n')

In [None]:
# High-res (subdivided) UV coordinates and UV faces.
vt_hres, ft_hres = get_vt_ft_hres()
print(vt_hres.shape, ft_hres.shape, sep='\n')

In [None]:
# Per-garment vertex indices and per-garment UV faces; the naked body uses
# the high-res SMPL UV faces.
vert_indices, fts = load_pickle_file('pretrains/garment_fts.pkl')
fts['naked'] = ft_hres
for key, garment_vert_inds in vert_indices.items():
    print(key, garment_vert_inds.shape)
print('------------')
for key, garment_ft in fts.items():
    print(key, garment_ft.shape)

In [None]:
def get_shape_pose_cam_v_personal(smpl, people_ID, device='cuda:0'):
    """Build SMPL parameters, a camera and garment offsets for one subject.

    Reads ``<data_root>/<people_ID>/registration.pkl`` and every
    ``<garment>.obj`` from the global ``garment_classes`` that exists for the subject.
    Also relies on the notebook globals ``data_root`` and ``vert_indices``.

    Args:
        smpl: SMPL model, called as ``smpl(shape[None], pose[None])``.
        people_ID: name of the subject sub-directory under ``data_root``.
        device: torch device for all returned tensors.

    Returns:
        shape: ``betas`` from the registration, as a float tensor.
        pose: ``pose`` from the registration, as a float tensor.
        cam: 3-element tensor, presumably ``[scale, tx, ty]``; the scale is randomly
            shrunk by up to 20 %.
        v_personal: per-vertex offsets (garment vertex minus zero-pose body vertex);
            zero for vertices not covered by any garment found on disk.
    """
    smpl_registration_pkl = load_pickle_file(os.path.join(data_root, people_ID, 'registration.pkl'))

    shape = torch.from_numpy(smpl_registration_pkl['betas']).float().to(device)
    pose = torch.from_numpy(smpl_registration_pkl['pose']).float().to(device)

    # Body in the all-zero pose; garment offsets are measured against it.
    pose_T = torch.zeros_like(pose)
    verts_T = smpl(shape[None], pose_T[None])[0]

    cam = torch.zeros(3).float().to(device)
    # Draw the jitter as a Python float. Dividing a 1-element CPU tensor by a
    # 0-dim CUDA tensor raises a device-mismatch error.
    scale_jitter = 1.0 - 0.2 * torch.rand(1).item()
    cam[0] = scale_jitter / verts_T[:, 0].abs().max()
    # Shift so the x/y bounding-box centre of the zero-pose body lands at the origin.
    cam[1] = -(verts_T[:, 0].min() + verts_T[:, 0].max()) / 2
    cam[2] = -(verts_T[:, 1].min() + verts_T[:, 1].max()) / 2

    v_personal = torch.zeros(verts_T.shape, device=device)
    print("---" + people_ID + "---")
    for garment_type in garment_classes:
        garment_obj_path = os.path.join(data_root, people_ID, garment_type + '.obj')
        if not os.path.isfile(garment_obj_path):
            continue
        vert_inds = torch.from_numpy(vert_indices[garment_type]).to(device)
        garment_obj = load_obj(garment_obj_path)
        garment_v = torch.from_numpy(garment_obj['vertices']).float().to(device)
        # NOTE(review): assumes the garment .obj vertices are ordered like vert_inds — confirm.
        v_personal[vert_inds] = garment_v - verts_T[vert_inds]
        print(garment_type)
    return shape, pose, cam, v_personal

In [None]:
# Export SMPL parameters, camera and garment offsets for every subject to
# <data_root>/<people_ID>/smpl_registered.pkl (existing files are overwritten).
# Only values are exported, so skip autograd bookkeeping; this also avoids
# `.numpy()` failing if the SMPL module tracks gradients.
with torch.no_grad():
    for people_ID in people_IDs_list:
        shape, pose, cam, v_personal = get_shape_pose_cam_v_personal(smpl, people_ID)
        smpl_registered_pkl = {
            'betas': shape.cpu().numpy(),
            'pose': pose.cpu().numpy(),
            'camera': cam.cpu().numpy(),
            'v_personal': v_personal.cpu().numpy(),
        }
        write_pickle_file(os.path.join(data_root, people_ID, 'smpl_registered.pkl'), smpl_registered_pkl)

In [None]:
# The remaining cells inspect a single subject: the first one in the list.
people_ID = people_IDs_list[0]

In [None]:
# Parameters, camera and garment offsets of the chosen subject.
shape, pose, cam, v_personal = get_shape_pose_cam_v_personal(smpl, people_ID)
print(shape.shape, pose.shape, cam.shape, v_personal.shape)

In [None]:
# All-zero pose for the single (unbatched) subject.
pose_T = torch.zeros(pose.shape, dtype=torch.float32, device='cuda')
print(pose_T.shape)

In [None]:
# Zero-pose clothed body; `[None]` adds the batch dimension SMPL expects.
verts_T_personal = smpl(shape[None], pose_T[None], v_personal[None])
print(verts_T_personal.shape)

In [None]:
# View the zero-pose clothed mesh of the chosen subject.
mesh = trimesh.Trimesh(
    vertices=verts_T_personal[0].cpu(),
    faces=smpl.faces_hres.cpu(),
    process=False,  # keep SMPL's vertex order and count untouched
)
mesh.show()

In [None]:
# Clothed body in the subject's registered pose (batch dimension added with `[None]`).
verts_personal = smpl(shape[None], pose[None], v_personal[None])
print(verts_personal.shape)

In [None]:
# View the posed clothed mesh of the chosen subject.
mesh = trimesh.Trimesh(
    vertices=verts_personal[0].cpu(),
    faces=smpl.faces_hres.cpu(),
    process=False,  # keep SMPL's vertex order and count untouched
)
mesh.show()

In [None]:
# Registered texture image of the subject, forced to 3-channel RGB.
uv_img = Image.open(
    os.path.join(data_root, people_ID, 'registered_tex.jpg')
).convert('RGB')
plt.imshow(uv_img)

In [None]:
# Convert the PIL texture to a (C, H, W) float tensor in [0, 1] on the GPU.
# Guarded so re-running this cell does not feed the already-converted tensor
# back into ToTensor (which only accepts PIL images / ndarrays).
if isinstance(uv_img, Image.Image):
    uv_img = transforms.ToTensor()(uv_img).cuda()
print(uv_img.size())

In [None]:
# Per-face texture-coordinate array of the registered mesh (presumably face -> 3 UV pairs).
# NOTE(review): this reads smpl_registered.obj, but the export cell writes
# smpl_registered.pkl — confirm the .obj is produced elsewhere.
f2vts = torch.from_numpy(
    get_f2vts(os.path.join(data_root, people_ID, 'smpl_registered.obj'))
).float().cuda()
print(f2vts.shape)

In [None]:
# Sample a per-face texture from the UV image (batch dimension added with `[None]`).
# NOTE(review): the batch cell builds the sampler with `points_to_sampler`, this one
# with `f2vts_to_sampler` — confirm which method SMPLRenderer actually provides.
sampler = smpl_renderer.f2vts_to_sampler(f2vts[None])
tex = smpl_renderer.extract_tex(uv_img[None], sampler)
print(tex.shape)

In [None]:
# Render the zero-pose clothed body with the subject's camera and texture.
# NOTE(review): unlike the batch cell (project_to_image + smpl_renderer(verts, tex)),
# this uses `render(verts, cam, tex)` — confirm both APIs exist on SMPLRenderer.
img_T = smpl_renderer.render(verts_T_personal, cam[None], tex)
print(img_T.shape)

In [None]:
# Zero-pose render: detach from autograd, then (C, H, W) floats -> (H, W, C) uint8.
img_T_vis = np.transpose(img_T[0].detach().cpu().numpy() * 255, (1, 2, 0)).astype(np.uint8)
plt.imshow(img_T_vis)

In [None]:
# Render the posed clothed body with the same camera and texture.
img = smpl_renderer.render(verts_personal, cam[None], tex)
print(img.shape)

In [None]:
# Posed render: detach from autograd, then (C, H, W) floats -> (H, W, C) uint8.
img_vis = np.transpose(img[0].detach().cpu().numpy() * 255, (1, 2, 0)).astype(np.uint8)
plt.imshow(img_vis)