In [2]:
import torch as th
import sys, os
import tqdm
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="1"
sys.path.insert(0, '../')
from decalib.deca import DECA
from decalib.datasets import datasets 
from decalib.utils import util
from decalib.utils.config import cfg as deca_cfg
from decalib.utils.tensor_cropper import transform_points
from sample_scripts.sample_utils.params_utils import get_params_set
from sample_scripts.sample_utils.vis_utils import plot_image


def load_face(set_, path):
    deca_params = get_params_set(set_, params_key = ['shape', 'pose', 'exp', 'cam', 'light', 'tform', 'detail'], path=path)

    shape = []; pose = []; exp = []; cam = []; light = []; tform = []; detail = []
    for k, v in deca_params.items():
        
        shape.append(th.tensor(v['shape']))
        pose.append(th.tensor(v['pose']))
        exp.append(th.tensor(v['exp']))
        cam.append(th.tensor(v['cam']))
        light.append(th.tensor(v['light']))
        tform.append(th.tensor(v['tform']))
        detail.append(th.tensor(v['detail']))
    shape = th.stack(shape)
    pose = th.stack(pose)
    exp = th.stack(exp)
    cam = th.stack(cam)
    light = th.stack(light)
    tform = th.stack(tform)
    detail = th.stack(detail)
    concat_dict = {'shape':shape, 'pose':pose, 'exp':exp, 'cam':cam, 'light':light, 'tform':tform, 'detail':detail}
    return concat_dict, deca_params

device = 'cuda'
deca_cfg.model.use_tex = False
deca_cfg.rasterizer_type = 'standard'
deca_cfg.model.extract_tex = True
deca = DECA(config = deca_cfg, device=device)
deca_concat_dict_train, deca_id_dict_train = load_face('train', path='/data/mint/DPM_Dataset/ffhq_256_with_anno/params/')
deca_concat_dict_val, deca_id_dict_val = load_face('valid', path='/data/mint/DPM_Dataset/ffhq_256_with_anno/params/')

creating the FLAME Decoder
trained model found. load /home/mint/Relighting_preprocessing/DECA/data/deca_model.tar
Key=> shape : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//train/ffhq-train-shape-anno.txt
Key=> pose : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//train/ffhq-train-pose-anno.txt
Key=> exp : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//train/ffhq-train-exp-anno.txt
Key=> cam : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//train/ffhq-train-cam-anno.txt
Key=> light : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//train/ffhq-train-light-anno.txt
Key=> tform : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//train/ffhq-train-tform-anno.txt
Key=> detail : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//train/ffhq-train-detail-anno.txt
Key=> shape : Filename=>/data/mint/DPM_Dataset/ffhq_256_with_anno/params//valid/ffhq-valid-shape-anno.txt
Key=> pose : Filename=>/data/mint/DPM_Dataset/

In [5]:
img_name = '60065.jpg'
data_path = "/data/mint/ffhq_256_with_anno/ffhq_256/valid/"
testdata = datasets.TestData(f"{data_path}/{img_name}", iscrop=True, face_detector='fan', sample_step=10)

name = testdata[0]['imagename']
images = testdata[0]['image'].to(device)[None,...]
codedict = deca_id_dict_val[img_name].copy()
codedict['images'] = images
for p in ['shape', 'pose', 'exp', 'cam', 'light', 'tform', 'detail']:
    if p == 'light':
        codedict['light'] = codedict['light'].reshape(9, 3)
    if p == 'tform':
        codedict['tform'] = codedict['tform'].reshape(3, 3)
    codedict[p] = th.tensor(codedict[p]).float().to(device)[None, ...]
    print(p, codedict[p].shape)
    
print(testdata[0]['tform'])
print(codedict['tform'])
tform_inv = th.inverse(codedict['tform']).transpose(1,2).to(device)
print(tform_inv)
original_image = testdata[0]['original_image'][None, ...].to(device)
_, orig_visdict = deca.decode(codedict, original_image=original_image, render_orig=True, tform=tform_inv, use_template=False, use_detail=False, mean_cam=None)
plot_image(orig_visdict['shape_images'], [3])

shape torch.Size([1, 100])
pose torch.Size([1, 6])
exp torch.Size([1, 50])
cam torch.Size([1, 3])
light torch.Size([1, 9, 3])
tform torch.Size([1, 3, 3])
detail torch.Size([1, 128])
tensor([[ 1.1436e+00, -4.8790e-17, -3.8882e+01],
        [ 7.8174e-17,  1.1436e+00, -7.9479e+01],
        [ 0.0000e+00,  0.0000e+00,  1.0000e+00]])
tensor([[[ 1.1436e+00, -4.8790e-17, -3.8882e+01],
         [ 7.8174e-17,  1.1436e+00, -7.9479e+01],
         [ 0.0000e+00,  0.0000e+00,  1.0000e+00]]], device='cuda:0')
tensor([[[ 8.7444e-01, -5.9775e-17,  0.0000e+00],
         [ 3.7307e-17,  8.7444e-01,  0.0000e+00],
         [ 3.4000e+01,  6.9500e+01,  1.0000e+00]]], device='cuda:0')
Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]


# Render with predicted shape

In [13]:
data_path = "/data/mint/ffhq_256_with_anno/ffhq_256/valid/"
testdata = datasets.TestData(f"{data_path}/", iscrop=True, face_detector='fan', sample_step=10)
savefolder = "/data/mint/ffhq_256_with_anno/for_verifying_shape_images/valid/"
os.makedirs(savefolder, exist_ok=True)

for i in tqdm.tqdm(range(len(testdata))):
    if i == 300:
        break
    name = testdata[i]['imagename']
    images = testdata[i]['image'].to(device)[None,...]
    # codedict = deca_id_dict_train[f"{name}.jpg"].copy()
    codedict = deca_id_dict_val[f"{name}.jpg"].copy()
    codedict['images'] = images
    for p in ['shape', 'pose', 'exp', 'cam', 'light', 'tform', 'detail']:
        if p == 'light':
            codedict['light'] = codedict['light'].reshape(9, 3)
        if p == 'tform':
            codedict['tform'] = codedict['tform'].reshape(3, 3)
        codedict[p] = th.tensor(codedict[p]).float().to(device)[None, ...]
    #     print(p, codedict[p].shape)
        
    tform_inv = th.inverse(codedict['tform']).transpose(1,2).to(device)
    original_image = testdata[0]['original_image'][None, ...].to(device)
    _, orig_visdict = deca.decode(codedict, original_image=original_image, render_orig=True, tform=tform_inv, use_template=False, mean_cam=None)    
    
    from torchvision.utils import save_image
    save_image(orig_visdict['shape_images'], fp=f"{savefolder}/{name}.png")

  3%|▎         | 300/10000 [00:45<24:29,  6.60it/s]


# Render with template shape

In [6]:
data_path = "/data/mint/DPM_Dataset/ffhq_256_with_anno/ffhq_256/valid/"
testdata = datasets.TestData(f"{data_path}/", iscrop=True, face_detector='fan', sample_step=10)
savefolder = "/data/mint/ffhq_256_with_anno/for_verifying_template_shape_images/valid/"

# Calculate the average cam/tform
mean_cam = th.mean(deca_concat_dict_train['cam'], dim=0, keepdims=True).float().to(device)
mean_tform = th.mean(deca_concat_dict_train['tform'], dim=0).float().to(device)
mean_tform = (mean_tform.reshape(3, 3))[None, ...]

os.makedirs(savefolder, exist_ok=True)
for i in tqdm.tqdm(range(len(testdata))):
    if i == 10:
        break
    name = testdata[i]['imagename']
    images = testdata[i]['image'].to(device)[None,...]
    codedict = deca_id_dict_val[f"{name}.jpg"].copy()
    codedict['images'] = images
    for p in ['shape', 'pose', 'exp', 'cam', 'light', 'tform', 'detail']:
        if p == 'light':
            codedict['light'] = codedict['light'].reshape(9, 3)
        if p == 'tform':
            codedict['tform'] = codedict['tform'].reshape(3, 3)
        codedict[p] = th.tensor(codedict[p]).float().to(device)[None, ...]
        
    tform_inv = th.inverse(mean_tform).transpose(1,2).to(device)
    original_image = testdata[0]['original_image'][None, ...].to(device)
    _, orig_visdict = deca.decode(codedict, original_image=original_image, render_orig=True, tform=tform_inv, use_template=True, mean_cam=mean_cam)    
    # import numpy as np
    # print(np.max(orig_visdict['shape_images'].cpu().numpy(), axis=(2, 3)))
    # print(np.min(orig_visdict['shape_images'].cpu().numpy(), axis=(2, 3)))
    
    from torchvision.utils import save_image
    save_image(orig_visdict['shape_images'], fp=f"{savefolder}/{name}.png")

  0%|          | 1/10000 [00:01<3:47:42,  1.37s/it]

[[0.8689835  0.87447834 0.87893146]]
[[0. 0. 0.]]


  0%|          | 2/10000 [00:02<2:52:13,  1.03s/it]

[[1.1004285 1.1045136 1.0924165]]
[[0. 0. 0.]]


  0%|          | 3/10000 [00:02<2:15:34,  1.23it/s]

[[1.2124145 1.2233125 1.2191988]]
[[0. 0. 0.]]


  0%|          | 4/10000 [00:03<1:53:56,  1.46it/s]

[[1.0952301 1.1152335 1.1188492]]
[[0. 0. 0.]]


  0%|          | 5/10000 [00:03<1:48:53,  1.53it/s]

[[0.9781778  0.9888724  0.99078053]]
[[0. 0. 0.]]


  0%|          | 6/10000 [00:04<1:39:53,  1.67it/s]

[[0.92282826 0.9259714  0.9242636 ]]
[[0. 0. 0.]]


  0%|          | 7/10000 [00:04<1:36:25,  1.73it/s]

[[0.8027229 0.7995497 0.7859293]]
[[0. 0. 0.]]


  0%|          | 8/10000 [00:05<1:37:33,  1.71it/s]

[[0.8889431  0.9120528  0.92131513]]
[[0. 0. 0.]]


  0%|          | 9/10000 [00:05<1:32:50,  1.79it/s]

[[0.8194454  0.82561123 0.833648  ]]
[[0. 0. 0.]]


  0%|          | 10/10000 [00:06<1:48:09,  1.54it/s]

[[1.0021014 1.0023755 1.0014464]]
[[0. 0. 0.]]



