In [1]:
import os, json
import numpy as np
import torch as th
import pandas as pd
import tqdm, glob
from collections import defaultdict
import copy

def load_deca_params(deca_dir):
    """Load every per-parameter annotation file found in ``deca_dir``.

    For each known parameter name, files matching ``*{name}-anno.txt`` inside
    ``deca_dir`` are parsed with ``read_params``. If several files match the
    same name, the one returned last by ``glob`` overwrites the others; if
    none match, that parameter is simply absent from the result.

    Args:
        deca_dir: Directory that holds the ``*-anno.txt`` files.

    Returns:
        The result of ``swap_key``: a mapping keyed by image name whose values
        map parameter name -> float64 numpy array.
    """
    # Face parameter names to look for.
    # Alternative key set kept from the original for reference:
    #   ['shape', 'pose', 'exp', 'cam', 'light', 'faceemb', 'shadow']
    param_names = ('shape', 'pose', 'exp', 'cam', 'light', 'tform', 'albedo', 'detail')

    per_param = {}
    for name in tqdm.tqdm(param_names, desc="Loading deca params..."):
        for anno_file in glob.glob(f"{deca_dir}/*{name}-anno.txt"):
            per_param[name] = read_params(path=anno_file)

    # Re-key from {param: {image: values}} to {image: {param: values}}.
    return swap_key(per_param)

def read_params(path):
    """Parse one space-separated annotation file.

    The file has no header; the first field of each line is the image name
    and the remaining fields are that image's values.

    Args:
        path: Path of the annotation text file.

    Returns:
        dict mapping image name -> list of the remaining fields on its line.
    """
    table = pd.read_csv(path, header=None, sep=" ", index_col=False, lineterminator='\n')
    # Column 0 carries the image name; use it as the index so that, after the
    # transpose, every image becomes one column that to_dict('list') collects.
    by_image = table.rename(columns={0: 'img_name'}).set_index('img_name')
    return by_image.T.to_dict('list')

def swap_key(params):
    """Invert a two-level mapping from param-major to image-major.

    Args:
        params: ``{param_name: {img_name: values}}``.

    Returns:
        ``defaultdict(dict)`` shaped ``{img_name: {param_name: ndarray}}``,
        where each ``values`` has been converted to a float64 numpy array.
    """
    per_image = defaultdict(dict)
    # Flatten the nested mapping into (image, param, values) triples first,
    # preserving the original traversal order.
    triples = (
        (img_name, param_name, raw_values)
        for param_name, by_image in params.items()
        for img_name, raw_values in by_image.items()
    )
    for img_name, param_name, raw_values in triples:
        per_image[img_name][param_name] = np.array(raw_values).astype(np.float64)

    return per_image

In [4]:

def swap_params(set_,
                params_root='/data/mint/DPM_Dataset/ffhq_256_with_anno/params',
                replace_root='/data/mint/DPM_Dataset/Fixing_DECA_detector/params'):
    """Overwrite annotation values with replacements and rewrite the files.

    Loads the annotations of split ``set_`` from ``params_root`` and from
    ``replace_root``. For every image present in both, each parameter that the
    replacement provides takes the replacement's value. The merged result is
    then written back as ``ffhq-{set_}-{param}-anno.txt`` files.

    NOTE: the output directory is the same directory the original annotations
    were loaded from, so the original files are overwritten in place.

    Args:
        set_: Split name used as the sub-directory and in the file names
            (the notebook calls this with ``'valid'`` and ``'train'``).
        params_root: Directory containing one sub-directory per split with
            the annotations to be updated. Defaults to the original path.
        replace_root: Directory containing one sub-directory per split with
            the replacement annotations. Defaults to the original path.

    Returns:
        None. The result is the set of rewritten annotation files.
    """
    params_path = f'{params_root}/{set_}'
    params_replace_path = f'{replace_root}/{set_}'

    p = load_deca_params(params_path)
    p_rep = load_deca_params(params_replace_path)

    for img_name, replacement in p_rep.items():
        # Test membership in `p` (the original tested `p_rep`, which is always
        # true). Using `in` also avoids indexing `p`, a defaultdict, which
        # would silently insert an empty entry for an unknown image.
        if img_name not in p:
            continue
        target = p[img_name]
        for p_name in target:
            # Keep the existing value when the replacement set does not
            # provide this parameter instead of failing with a KeyError.
            if p_name in replacement:
                target[p_name] = copy.deepcopy(replacement[p_name])

    # op = f'/data/mint/DPM_Dataset/ffhq_256_with_anno/params/{set_}_fix/'
    op = params_path
    os.makedirs(op, exist_ok=True)

    param_names = ('shape', 'exp', 'pose', 'light', 'cam', 'detail', 'tform', 'albedo')
    fo_dict = {}
    try:
        # Opened only after both loads finished: "w" truncates the very files
        # that were just read.
        for p_name in param_names:
            fo_dict[p_name] = open(f"{op}/ffhq-{set_}-{p_name}-anno.txt", "w")

        for img_name in tqdm.tqdm(p.keys()):      # Per image
            for p_name, p_val in p[img_name].items():      # Per param
                values = " ".join(str(x) for x in p_val.flatten())
                fo_dict[p_name].write(f"{img_name} {values}\n")
    finally:
        # Always close (and thereby flush) every handle that was opened, even
        # when opening a later file or writing fails.
        for fo in fo_dict.values():
            fo.close()


In [5]:
# Rewrite the annotation files of both splits: validation first, then train.
for split in ('valid', 'train'):
    swap_params(split)

Loading deca params...: 100%|██████████| 8/8 [00:03<00:00,  2.12it/s]
Loading deca params...: 100%|██████████| 8/8 [00:00<00:00, 218.02it/s]
100%|██████████| 10000/10000 [00:01<00:00, 5641.96it/s]
Loading deca params...: 100%|██████████| 8/8 [00:30<00:00,  3.84s/it]
Loading deca params...: 100%|██████████| 8/8 [00:00<00:00, 195.72it/s]
100%|██████████| 60000/60000 [00:11<00:00, 5420.17it/s]
