In [1]:
import numpy as np
from pytorch_lightning import seed_everything
import torch as th
import os
import sys
sys.path.insert(0, '../../')
from sample_scripts.sample_utils.params_utils import get_params_set
from sample_scripts.sample_utils.inference_utils import to_tensor
from sample_scripts.sample_utils.vis_utils import plot_image
device = 'cuda'

In [2]:
def params_to_model(shape, exp, pose, cam, lights=None, i=0, uvdn=None):

    from model_3d.FLAME import FLAME
    from model_3d.FLAME.config import cfg as flame_cfg
    from model_3d.FLAME.utils.renderer import SRenderY
    import model_3d.FLAME.utils.util as util

    flame = FLAME.FLAME(flame_cfg.model).cuda()
    verts, landmarks2d, landmarks3d = flame(shape_params=shape, 
            expression_params=exp, 
            pose_params=pose)
    renderer = SRenderY(image_size=256, obj_filename=flame_cfg.model.topology_path, uv_size=flame_cfg.model.uv_size).to(device)

    ## projection
    landmarks2d = util.batch_orth_proj(landmarks2d, cam)[:,:,:2]; landmarks2d[:,:,1:] = -landmarks2d[:,:,1:]#; landmarks2d = landmarks2d*self.image_size/2 + self.image_size/2
    landmarks3d = util.batch_orth_proj(landmarks3d, cam); landmarks3d[:,:,1:] = -landmarks3d[:,:,1:] #; landmarks3d = landmarks3d*self.image_size/2 + self.image_size/2
    trans_verts = util.batch_orth_proj(verts, cam); trans_verts[:,:,1:] = -trans_verts[:,:,1:]

    ## rendering
    shape_images = renderer.render_shape(verts, trans_verts, lights=lights)

    # opdict = {'verts' : verts,}
    # os.makedirs('./rendered_obj', exist_ok=True)
    # save_obj(renderer=renderer, filename=(f'./rendered_obj/{i}.obj'), opdict=opdict)
    
    return {"shape_images":shape_images, "landmarks2d":landmarks2d, "landmarks3d":landmarks3d}


In [3]:
set_ = 'valid'
deca_params = get_params_set(set_)
deca_params_train = get_params_set('train')


Key=> shape : Filename=>/data/mint/ffhq_256_with_anno/params/valid/ffhq-valid-shape-anno.txt
Key=> pose : Filename=>/data/mint/ffhq_256_with_anno/params/valid/ffhq-valid-pose-anno.txt
Key=> exp : Filename=>/data/mint/ffhq_256_with_anno/params/valid/ffhq-valid-exp-anno.txt
Key=> cam : Filename=>/data/mint/ffhq_256_with_anno/params/valid/ffhq-valid-cam-anno.txt
Key=> light : Filename=>/data/mint/ffhq_256_with_anno/params/valid/ffhq-valid-light-anno.txt
Key=> faceemb : Filename=>/data/mint/ffhq_256_with_anno/params/valid/ffhq-valid-faceemb-anno.txt
Key=> shape : Filename=>/data/mint/ffhq_256_with_anno/params/train/ffhq-train-shape-anno.txt
Key=> pose : Filename=>/data/mint/ffhq_256_with_anno/params/train/ffhq-train-pose-anno.txt
Key=> exp : Filename=>/data/mint/ffhq_256_with_anno/params/train/ffhq-train-exp-anno.txt
Key=> cam : Filename=>/data/mint/ffhq_256_with_anno/params/train/ffhq-train-cam-anno.txt
Key=> light : Filename=>/data/mint/ffhq_256_with_anno/params/train/ffhq-train-light-an

In [4]:
img_name = '60000.jpg'
param = deca_params[img_name]
param = to_tensor(param, key=['shape', 'exp', 'pose', 'cam', 'light'], device=device)
out = params_to_model(shape=param['shape'][None, ...].to(th.float32), 
                exp=param['exp'][None, ...].to(th.float32), 
                pose=param['pose'][None, ...].to(th.float32), 
                cam=param['cam'][None, ...].to(th.float32),
                lights=param['light'][None, ...].reshape(-1, 9, 3).to(th.float32))

path = f'./render_face/{set_}'
os.makedirs(path, exist_ok=True)
plot_image(out['shape_images'], [3], fn=f"./{path}/{img_name.split('.')[0]}")

creating the FLAME Decoder




Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]


: 

In [39]:
shape = []; pose = []; exp = []; cam = []
for k, v in deca_params_train.items():
    shape.append(th.tensor(v['shape']))
    pose.append(th.tensor(v['pose']))
    exp.append(th.tensor(v['exp']))
    cam.append(th.tensor(v['cam']))
shape = th.stack(shape)
pose = th.stack(pose)
exp = th.stack(exp)
cam = th.stack(cam)

torch.Size([100])

In [45]:
seed_everything(47)
rand = np.random.randint(0, len(deca_params.keys()), size=10)
path = f'./render_face/{set_}'
for i in rand:
    img_name = list(deca_params.keys())[i]
    save_path = f'{path}/{img_name}'
    os.makedirs(save_path, exist_ok=True)
    param = deca_params[img_name]
    param = to_tensor(param, key=['shape', 'exp', 'pose', 'cam', 'light'], device=device)

    out = params_to_model(shape=param['shape'][None, ...].to(th.float32), 
                    exp=param['exp'][None, ...].to(th.float32), 
                    pose=param['pose'][None, ...].to(th.float32), 
                    cam=param['cam'][None, ...].to(th.float32),
                    lights=param['light'][None, ...].reshape(-1, 9, 3).to(th.float32)
    )
    os.makedirs(save_path + '/with_light/', exist_ok=True)
    plot_image(out['shape_images'], [3], fn=f"./{save_path}/with_light/{img_name.split('.')[0]}_with_light")
    print(th.amax(out['shape_images'], dim=(2, 3)))

    out = params_to_model(shape=param['shape'][None, ...].to(th.float32), 
                    exp=param['exp'][None, ...].to(th.float32), 
                    pose=param['pose'][None, ...].to(th.float32), 
                    cam=param['cam'][None, ...].to(th.float32),
    )
    os.makedirs(save_path + '/without_light/', exist_ok=True)
    plot_image(out['shape_images'], [3], fn=f"./{save_path}/without_light/{img_name.split('.')[0]}")
    print(th.amax(out['shape_images'], dim=(2, 3)))

    out = params_to_model(shape=param['shape'][None, ...].to(th.float32)*0, 
                    exp=param['exp'][None, ...].to(th.float32)*0, 
                    pose=param['pose'][None, ...].to(th.float32)*0,
                    cam=param['cam'][None, ...].to(th.float32),
                    lights=param['light'][None, ...].reshape(-1, 9, 3).to(th.float32)
    )
    os.makedirs(save_path + '/template_face_with_light/', exist_ok=True)
    plot_image(out['shape_images'], [3], fn=f"./{save_path}/template_face_with_light/{img_name.split('.')[0]}")
    print(th.amax(out['shape_images'], dim=(2, 3)))

    out = params_to_model(shape=param['shape'][None, ...].to(th.float32)*0, 
                    exp=param['exp'][None, ...].to(th.float32)*0, 
                    pose=param['pose'][None, ...].to(th.float32)*0,
                    cam=param['cam'][None, ...].to(th.float32),
    )
    os.makedirs(save_path + '/template_face_without_light/', exist_ok=True)
    plot_image(out['shape_images'], [3], fn=f"./{save_path}/template_face_without_light/{img_name.split('.')[0]}")
    print(th.amax(out['shape_images'], dim=(2, 3)))

    out = params_to_model(shape=th.mean(shape, dim=0)[None, ...].to(th.float32).to(device), 
                    exp=th.mean(exp, dim=0)[None, ...].to(th.float32).to(device), 
                    pose=th.mean(pose, dim=0)[None, ...].to(th.float32).to(device),
                    cam=th.mean(cam, dim=0)[None, ...].to(th.float32).to(device),
                    lights=param['light'][None, ...].reshape(-1, 9, 3).to(th.float32)
    )
    os.makedirs(save_path + '/mean_face_with_light/', exist_ok=True)
    plot_image(out['shape_images'], [3], fn=f"./{save_path}/mean_face_with_light/{img_name.split('.')[0]}")
    print(th.amax(out['shape_images'], dim=(2, 3)))

    out = params_to_model(shape=th.mean(shape, dim=0)[None, ...].to(th.float32).to(device), 
                    exp=th.mean(exp, dim=0)[None, ...].to(th.float32).to(device), 
                    pose=th.mean(pose, dim=0)[None, ...].to(th.float32).to(device),
                    cam=th.mean(cam, dim=0)[None, ...].to(th.float32).to(device),
    )
    os.makedirs(save_path + '/mean_face_without_light/', exist_ok=True)
    plot_image(out['shape_images'], [3], fn=f"./{save_path}/mean_face_without_light/{img_name.split('.')[0]}")
    print(th.amax(out['shape_images'], dim=(2, 3)))
    assert False



Global seed set to 47


creating the FLAME Decoder
Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]
tensor([[1.0644, 1.0644, 1.0659]], device='cuda:0')
creating the FLAME Decoder
Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]
tensor([[0.7938, 0.7938, 0.7938]], device='cuda:0')
creating the FLAME Decoder
Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]
tensor([[1.0231, 1.0465, 1.0391]], device='cuda:0')
creating the FLAME Decoder
Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]
tensor([[0.7934, 0.7934, 0.7934]], device='cuda:0')
creating the FLAME Decoder
Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]
tensor([[1.0851, 1.0976, 1.0950]], device='cuda:0')
creating the FLAME Decoder
Image shape :  torch.Size([1, 3, 256, 256])
Channel length :  [3]
tensor([[0.7939, 0.7939, 0.7939]], device='cuda:0')


AssertionError: 