In [3]:
from models.decalib.deca import DECA
from models.decalib.utils.config import cfg as deca_cfg
from models.decalib.datasets.detectors import FAN

import cv2 as cv
import os
import numpy as np
from skimage.transform import estimate_transform, warp
import torch

# Point the DECA config at the locally downloaded assets in ./pretrained and
# switch the texture option off (use_tex=False), then build the DECA model and
# the FAN detector that the next step uses to locate the face for cropping.
flame_model_overrides = {
    'flame_model_path': './pretrained/generic_model.pkl',
    'flame_lmk_embedding_path': './pretrained/landmark_embedding.npy',
    'use_tex': False,
}
for cfg_key, cfg_value in flame_model_overrides.items():
    deca_cfg['model'][cfg_key] = cfg_value
deca_cfg['pretrained_modelpath'] = './pretrained/deca_model.tar'

deca = DECA(config=deca_cfg)
face_detector = FAN()

def bbox2point(left, right, top, bottom, type='bbox'):
    '''Turn a face box into a crop size and a crop centre.

    A box spanning the 68 landmarks and a box from a face detector frame the
    face differently, so each kind gets its own adjustment.

    Args:
        left, right, top, bottom: box edges in pixel coordinates.
        type: 'kpt68' for a landmark-extent box, 'bbox' for a detector box.

    Returns:
        (old_size, center): old_size is the mean of box width and height
        (scaled by 1.1 for 'kpt68'); center is a 2-element torch tensor
        [cx, cy] (cy shifted down by 0.12 * old_size for 'bbox').

    Raises:
        NotImplementedError: if ``type`` is neither 'kpt68' nor 'bbox'.
    '''
    if type not in ('kpt68', 'bbox'):
        raise NotImplementedError

    center_x = right - (right - left) / 2.0
    center_y = bottom - (bottom - top) / 2.0
    mean_extent = (right - left + bottom - top) / 2

    if type == 'kpt68':
        return mean_extent * 1.1, torch.tensor([center_x, center_y])

    # Detector boxes sit slightly high on the face; nudge the centre down.
    return mean_extent, torch.tensor([center_x, center_y + mean_extent * 0.12])

# Reconstruct one portrait with DECA: load -> RGB -> detect face -> square crop
# -> encode -> decode (rendering back onto the full-resolution image).
IMAGE_PATH = '01464.png'
CROP_SCALE = 1.25       # enlarge the face box by 25% before cropping
RESOLUTION_INP = 224    # side length of the square crop fed to deca.encode
BACKGROUND_VALUE = 255  # transparent pixels are composited onto white

# IMREAD_UNCHANGED keeps an alpha channel if the PNG has one.
image = cv.imread(IMAGE_PATH, cv.IMREAD_UNCHANGED)
if image is None:
    raise FileNotFoundError(f'Could not read image: {IMAGE_PATH}')
if image.dtype != np.uint8:
    # Everything below (alpha / 255., image / 255.) assumes 8-bit samples.
    raise ValueError(f'Expected an 8-bit image, got dtype {image.dtype}')

if image.ndim == 2:
    # Greyscale: replicate to three channels so later code always sees H x W x 3.
    image = cv.cvtColor(image, cv.COLOR_GRAY2RGB)
elif image.shape[2] == 4:
    # Alpha-composite the RGB channels over a solid BACKGROUND_VALUE background.
    # alpha_factor is H x W x 1 and broadcasts over the three colour channels.
    alpha_factor = image[..., 3:4].astype(np.float32) / 255.
    rgb_channels = cv.cvtColor(image[..., :3], cv.COLOR_BGR2RGB).astype(np.float32)
    background = np.full_like(rgb_channels, BACKGROUND_VALUE)
    composited = rgb_channels * alpha_factor + background * (1. - alpha_factor)
    # Back to uint8 so every branch hands the detector the same dtype.
    image = np.clip(np.rint(composited), 0, 255).astype(np.uint8)
else:
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

# Face detection; the returned box is indexed as [left, top, right, bottom].
bbox, bbox_type = face_detector.run(image)
if len(bbox) < 4:
    # NOTE(review): FAN.run appears to return a short placeholder box when no
    # face is found - confirm against models.decalib.datasets.detectors.
    raise RuntimeError(f'No face detected in {IMAGE_PATH}')
left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]

# Square crop around the face, mapped onto a RESOLUTION_INP x RESOLUTION_INP grid
# via a similarity transform fitted to three corners.
old_size, center = bbox2point(left, right, top, bottom, type=bbox_type)
size = int(old_size * CROP_SCALE)
center_x, center_y = float(center[0]), float(center[1])
half = size / 2
src_pts = np.array([[center_x - half, center_y - half],
                    [center_x - half, center_y + half],
                    [center_x + half, center_y - half]])
DST_PTS = np.array([[0, 0], [0, RESOLUTION_INP - 1], [RESOLUTION_INP - 1, 0]])
tform = estimate_transform('similarity', src_pts, DST_PTS)

image = image / 255.
dst_image = warp(image, tform.inverse, output_shape=(RESOLUTION_INP, RESOLUTION_INP))
dst_image = dst_image.transpose(2, 0, 1)  # HWC -> CHW for the encoder

with torch.no_grad():
    codedict = deca.encode(torch.tensor(dst_image).float().cuda()[None])

    # decode(render_orig=True) receives the crop transform inverted and
    # transposed, plus the full image as NCHW.
    tform_torch = torch.tensor(tform.params).float()[None, ...]
    tform_torch = torch.inverse(tform_torch).transpose(1, 2).cuda()
    original_image = torch.tensor(image[None]).float().cuda()
    opdict, visdict = deca.decode(codedict, render_orig=True,
                                  original_image=original_image.permute(0, 3, 1, 2),
                                  tform=tform_torch)





In [None]:
# Save DECA's visualisation grid for the decoded result. cv.imwrite reports
# failure (bad path, unsupported extension) by returning False instead of
# raising, so check the result explicitly.
vis_grid = deca.visualize(visdict)
if not cv.imwrite('vis.png', vis_grid):
    raise IOError('Failed to write vis.png')

In [4]:
# Export the decoded result (opdict) as a Wavefront OBJ via DECA's helper.
obj_path = '01464.obj'
deca.save_obj(obj_path, opdict)

In [None]:
import torch

import cv2 as cv
import numpy as np

from pytorch3d.io import load_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.structures import Meshes
from pytorch3d.renderer.mesh.textures import TexturesUV

from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings

import torchvision.transforms.functional as tvf

import utils.io as io

# Root of the test asset folder (note the trailing '/'); paths below append to it.
base_path = 'dataset/test_data/'

# Load the head mesh directly onto the GPU. The faces and aux results are used
# later for verts_idx / textures_idx and verts_uvs respectively.
head_obj_path = base_path + 'OBJs/Head.obj'
verts, faces, aux = load_obj(head_obj_path, device='cuda')

In [None]:
def load_sdr(image_name, size=4096, interpolation=cv.INTER_NEAREST):
    '''Load an SDR texture as an RGB float tensor, resized to a square.

    Args:
        image_name: path to the image file.
        size: output side length in pixels (default 4096, the previous
            hard-coded value).
        interpolation: OpenCV interpolation flag used by cv.resize
            (default nearest-neighbour, as before).

    Returns:
        A float32 torch tensor of shape (size, size, 3). uint8 and uint16
        inputs are scaled to [0, 1]; other dtypes are returned unscaled.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_name``.
    '''
    image = cv.imread(image_name, cv.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f'Could not read image: {image_name}')

    if image.ndim == 2:
        # Single-channel map: replicate to 3 channels (BGR2RGB rejects 1-channel input).
        image = cv.cvtColor(image, cv.COLOR_GRAY2RGB)
    else:
        # BGR(A) -> RGB; for a 4-channel input this also drops alpha.
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

    # Normalise in float32: at 4096^2 a float64 intermediate doubles memory for
    # no benefit, since the result is returned as float32 anyway.
    if image.dtype == np.uint8:
        image = image.astype(np.float32) / 255.
    elif image.dtype == np.uint16:
        image = image.astype(np.float32) / 65535.

    image = cv.resize(image, (size, size), interpolation=interpolation)

    return torch.from_numpy(image).float()

# load textures
# Texture maps live under <base_path>Textures/. base_path already ends with '/',
# so the relative part must not start with another one.
texture_dir = base_path + 'Textures/'
RENDER_SIZE = 512  # square render resolution shared by camera and rasterizer

basecolor = load_sdr(texture_dir + 'Albedo.png').cuda()
roughness = load_sdr(texture_dir + 'Roughness_CR.png').cuda()
specular = load_sdr(texture_dir + 'Specular_CC.png').cuda()
# Normal map currently unused:
# normal = load_sdr(texture_dir + 'Normal.png').cuda()
# normal = torch.nn.functional.normalize(normal, dim=-1)

verts_uvs = aux.verts_uvs.cuda()
faces_uvs = faces.textures_idx.cuda()

# Pack the material maps into one UV map: albedo RGB, then channel 0 of the
# roughness map, then channel 0 of the specular map; [None] adds the batch dim.
material_maps = torch.cat([basecolor, roughness[..., :1], specular[..., :1]], dim=-1)[None]
texture = TexturesUV(verts_uvs=[verts_uvs], faces_uvs=[faces_uvs], maps=material_maps)

mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=texture).to('cuda')

# Screen-space camera (in_ndc=False); its R, T and K are overwritten before rasterizing.
camera = PerspectiveCameras(in_ndc=False, image_size=[(RENDER_SIZE, RENDER_SIZE)], device='cuda')
rasterizer_settings = RasterizationSettings(image_size=RENDER_SIZE)
rasterizer = MeshRasterizer(raster_settings=rasterizer_settings, cameras=camera).to('cuda')

In [None]:
from utils.transforms import B2P

# Camera extrinsics as a 4x4 [R | t] matrix, converted by the project helper B2P
# into the R, T used below (presumably Blender -> PyTorch3D; see utils.transforms).
# The 3x3 rotation block must be orthonormal; a previous version had an all-zero
# first row, which collapses the x axis and makes the camera degenerate.
RT = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                   [0.0, -0.052335940301418304, -0.9986295104026794, -0.699999988079071],
                   [0.0, 0.9986295104026794, -0.052335940301418304, 0.800000011920929],
                   [0.0, 0.0, 0.0, 1.0]]).cuda()
assert torch.allclose(RT[:3, :3] @ RT[:3, :3].T, torch.eye(3, device=RT.device), atol=1e-5), \
    'RT rotation block is not orthonormal'

R, T, RT_4x4 = B2P(RT)

# Alternative: R, T = look_at_view_transform(dist=1.5, elev=-2, azim=3.1415, device='cuda')

# Intrinsics: fx = fy = 512 px, principal point (256, 256), in PyTorch3D's 4x4
# perspective K layout (used as-is because the camera is in screen space).
K = torch.tensor([[512., 0., 256., 0.],
                  [0., 512., 256., 0.],
                  [0., 0., 0., 1.],
                  [0., 0., 1., 0.]]).cuda()

with torch.no_grad():
    rasterizer.cameras.R = R[None]
    rasterizer.cameras.T = T[None]
    rasterizer.cameras.K = K[None]

    fragments = rasterizer(mesh)

    # squeeze(3) drops the faces-per-pixel axis; assumes faces_per_pixel == 1.
    textures = mesh.textures.sample_textures(fragments).squeeze(3)

    # Pixels not covered by any face (pix_to_face == -1) still receive a
    # sampled texel; zero them so the background renders black.
    background_mask = fragments.pix_to_face[..., 0] < 0
    textures[background_mask] = 0.0

# Clamp before to_pil_image's byte conversion so out-of-range values cannot wrap.
rendered_rgb = textures[0][..., :3].clamp(0.0, 1.0).permute(2, 0, 1)
tvf.to_pil_image(rendered_rgb).save('test.png')