# Setup

In [None]:
%matplotlib inline
!pip install face-alignment

In [None]:
import torch, torchvision
import os
import numpy as np
from skimage.io import imread, imsave
from skimage.transform import estimate_transform, warp, resize, rescale
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import face_alignment

In [None]:
# connect to drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(f"Cuda availability : {device}")

Cuda availability : cuda


Repository: https://github.com/srinath2022/DECA_FFHQ  

Datasets: https://drive.google.com/drive/folders/1VMB9SdgJmfYxaMlGzKC5P2Jojmc37Rvb?usp=sharing 

# Dataset Setup Trials

In [None]:
# Constants
# Local copy of the FFHQ test images (extracted from FFHQ-10K.zip) to run FAN on.
dataset_path = '/content/FFHQ-test-10K'
# Drive folder that receives one landmark .npy file per image.
FAN_landmarks_path = '/content/drive/MyDrive/CS275_Graphics/FFHQ-Datasets/FFHQ-10K-landmarks'

In [None]:
!cp /content/drive/MyDrive/CS275_Graphics/FFHQ-Datasets/FFHQ-10K.zip /content/

In [None]:
!ls

drive  FFHQ-10K.zip  FFHQ-test-10K  FFHQ-train-10K  sample_data


In [None]:
!unzip FFHQ-10K.zip

## FAN (Face Alignment Network) Landmarks

In [None]:
def FAN_Landmarks(image, fa=None):
  """Detect 2D facial landmarks in `image` with FAN.

  Args:
    image: image array passed straight to `FaceAlignment.get_landmarks`.
    fa: optional pre-built `face_alignment.FaceAlignment`. When omitted, a
      shared 2D detector is built on first use and reused on later calls
      instead of being rebuilt for every image.

  Returns:
    Whatever `fa.get_landmarks(image)` returns: a list with one landmark
    array per detected face, or None when no face is found.
  """
  if fa is None:
    fa = get_fan_model()
  return fa.get_landmarks(image)


def get_fan_model():
  """Return the shared 2D FAN detector, constructing it on the first call.

  The model is cached as an attribute on this function, so constructing
  FaceAlignment (which sets up the detector and landmark networks) happens
  once per kernel rather than once per image.
  """
  model = getattr(get_fan_model, "model", None)
  if model is None:
    model = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
    get_fan_model.model = model
  return model

def Extract_FAN_Landmarks(path, dst_path):
  """Run FAN on every file in `path` and save landmarks as `<stem>.npy` in `dst_path`.

  Files whose output already exists are skipped, so an interrupted run can be
  resumed by calling this again. Images in which no face is detected are
  reported and not saved: saving None would write a pickled object array that
  `np.load` refuses without `allow_pickle=True`.

  Args:
    path: directory of input images. Sub-directories (e.g. .ipynb_checkpoints)
      are ignored.
    dst_path: output directory; created if it does not exist.
  """
  os.makedirs(dst_path, exist_ok=True)
  # sorted so the progress counter refers to the same files on every (resumed) run
  for i, imgname in enumerate(sorted(os.listdir(path)), start=1):
    src = os.path.join(path, imgname)
    if not os.path.isfile(src):
      continue
    stem = os.path.splitext(imgname)[0]
    outfile = os.path.join(dst_path, stem + '.npy')
    if os.path.exists(outfile):
      print(i, " exists")
      continue
    image = imread(src)
    landmarks = FAN_Landmarks(image)
    if landmarks is None:
      print(i, " no face found in", imgname)
      continue
    np.save(outfile, landmarks)
    print(i, " saved")

In [None]:
# Writes one landmark .npy per image in dataset_path into FAN_landmarks_path;
# images whose .npy already exists are skipped, so this cell can be re-run to resume.
Extract_FAN_Landmarks(dataset_path, FAN_landmarks_path)

1  saved
2  saved
3  saved
4  saved
5  saved
6  saved
7  saved
8  saved
9  saved
10  saved
11  saved
12  saved
13  saved
14  saved
15  saved
16  saved
17  saved
18  saved
19  saved
20  saved
21  saved
22  saved
23  saved
24  saved
25  saved
26  saved
27  saved
28  saved
29  saved
30  saved
31  saved
32  saved
33  saved
34  saved
35  saved
36  saved
37  saved
38  saved
39  saved
40  saved
41  saved
42  saved
43  saved
44  saved
45  saved
46  saved
47  saved
48  saved
49  saved
50  saved
51  saved
52  saved
53  saved
54  saved
55  saved
56  saved
57  saved
58  saved
59  saved
60  saved
61  saved
62  saved
63  saved
64  saved
65  saved
66  saved
67  saved
68  saved
69  saved
70  saved
71  saved
72  saved
73  saved
74  saved
75  saved
76  saved
77  saved
78  saved
79  saved
80  saved
81  saved
82  saved
83  saved
84  saved
85  saved
86  saved
87  saved
88  saved
89  saved
90  saved
91  saved
92  saved
93  saved
94  saved
95  saved
96  saved
97  saved
98  saved
99  saved
100  saved
101  sav

# Dataset

In [None]:
# FFHQ Dataset
class FFHQDataset(Dataset):
    """FFHQ images paired with FAN landmarks detected on the fly.

    Each item is a dict with:
      'image':    float32 tensor (3, image_size, image_size), the image scaled by
                  1/255 and warped onto the output grid by `crop`.
      'landmark': float32 tensor (n_landmarks, 3). Columns 0-1 are x/y mapped to
                  [-1, 1] on the output grid. Column 2 is the homogeneous
                  coordinate from the affine transform, which is 1 for a
                  similarity transform.

    Images where FAN finds no face are skipped by retrying with a random index.

    NOTE(review): FAN runs inside __getitem__ with the detector built in __init__.
    With DataLoader num_workers > 0, that model is shared into worker processes,
    which can fail if it runs on CUDA. Confirm this works, or load landmarks
    precomputed by Extract_FAN_Landmarks instead.
    """

    def __init__(self, image_size, scale, trans_scale = 0, isEval=False, imagefolder='/content/resized'):
        self.image_size  = image_size
        self.imagefolder = imagefolder
        # sorted so a given idx maps to the same file across runs and workers
        self.images_list = sorted(os.listdir(self.imagefolder))
        self.scale = scale              # [scale_min, scale_max]
        self.trans_scale = trans_scale  # max random shift of the crop centre, as a fraction of face size
        self.isEval = isEval            # stored for callers; not used by this class
        self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False)

    def __len__(self):
        return len(self.images_list)

    def __getitem__(self, idx):
        # Retry until an image with a detected face is found.
        while True:
            imgname = self.images_list[idx]
            image_path = os.path.join(self.imagefolder, imgname)
            image = imread(image_path)
            kpt = self._select_kpt(self.fan_landmarks(image))
            if kpt is None:
                idx = np.random.randint(low=0, high=len(self.images_list))
                continue

            image = image/255.
            if image.ndim < 3:
                image = np.tile(image[:,:,None], 3)   # grayscale -> 3 channels
            elif image.shape[2] > 3:
                image = image[:, :, :3]               # drop alpha channel
            ### crop information
            tform = self.crop(image, kpt)
            ## crop
            cropped_image = warp(image, tform.inverse, output_shape=(self.image_size, self.image_size))
            # apply the same transform to the (x, y, 1) landmarks
            cropped_kpt = np.dot(tform.params, np.hstack([kpt, np.ones([kpt.shape[0],1])]).T).T

            # normalized kpt: pixel coordinates -> [-1, 1]
            cropped_kpt[:,:2] = cropped_kpt[:,:2]/self.image_size * 2  - 1

            images_array = torch.from_numpy(cropped_image.transpose(2,0,1)).type(dtype = torch.float32)  # (3, H, W)
            kpt_array = torch.from_numpy(cropped_kpt).type(dtype = torch.float32)  # (n_landmarks, 3)

            data_dict = {
                'image': images_array,
                'landmark': kpt_array,
            }

            return data_dict

    @staticmethod
    def _select_kpt(preds):
        """Return the (x, y) landmarks of the first detected face, or None.

        `preds` is the value returned by FaceAlignment.get_landmarks: None, or a
        list with one landmark array per detected face. Only the first two
        columns are kept, so the homogeneous transform in __getitem__ gets an
        (n, 3) matrix even when the detector returns 3D landmarks.
        """
        if preds is None or len(preds) == 0:
            return None
        kpt = np.asarray(preds[0])
        if kpt.ndim != 2 or kpt.shape[1] < 2:
            return None
        return kpt[:, :2]

    def crop(self, image, kpt):
        """Return the similarity transform that maps `image` onto the output grid.

        NOTE(review): src_pts are the full image corners, so the transform is a
        plain resize to (image_size, image_size). The landmark-centred box
        (center/size, with random trans_scale/scale jitter) is computed but not
        used. The commented-out src_pts line is the landmark-based crop.
        """
        left = np.min(kpt[:,0]); right = np.max(kpt[:,0])
        top = np.min(kpt[:,1]); bottom = np.max(kpt[:,1])

        h, w, _ = image.shape
        old_size = (right - left + bottom - top)/2
        center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 ])
        trans_scale = (np.random.rand(2)*2 -1) * self.trans_scale
        center = center + trans_scale*old_size

        scale = np.random.rand() * (self.scale[1] - self.scale[0]) + self.scale[0]
        size = int(old_size*scale)

        # landmark-based crop (currently disabled):
        # src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
        src_pts = np.array([[0,0], [0,h - 1], [w - 1, 0]])
        DST_PTS = np.array([[0,0], [0,self.image_size - 1], [self.image_size - 1, 0]])
        tform = estimate_transform('similarity', src_pts, DST_PTS)
        return tform

    def load_mask(self, maskpath, h, w):
        """Load a face-parsing array and binarise it (values > 0.5 become 1).

        Returns an all-ones (h, w) mask when `maskpath` does not exist.
        """
        if os.path.isfile(maskpath):
            vis_parsing_anno = np.load(maskpath)
            mask = np.zeros_like(vis_parsing_anno)
            mask[vis_parsing_anno>0.5] = 1.
        else:
            mask = np.ones((h, w))
        return mask

    def fan_landmarks(self, image):
        """Run the FAN detector; returns a list of per-face landmark arrays or None."""
        preds = self.fa.get_landmarks(image)
        return preds

# Dataset Split

In [None]:
# Source folder of images to split, and the Drive folders that receive the train / test copies.
dset_folder   = '/content/resized'
train_folder  = '/content/drive/MyDrive/CS275_Graphics/FFHQ-Datasets/FFHQ-train'
test_folder   = '/content/drive/MyDrive/CS275_Graphics/FFHQ-Datasets/FFHQ-test'

In [None]:
import random
import shutil

# Reproducible train/test split of the images in dset_folder.
SPLIT_SEED = 0        # fixes the shuffle so re-running yields the same split
TRAIN_FRACTION = 0.7  # share of images copied to train_folder; the rest go to test_folder

# sorted() first: os.listdir order is filesystem-dependent, so a seed alone
# would not make the split reproducible.
images = sorted(os.listdir(dset_folder))
print(len(images))
# NUM = 10000  # set to split only a random subset
NUM = len(images)
imgs_10k = random.Random(SPLIT_SEED).sample(images, NUM)
n_train = int(NUM * TRAIN_FRACTION)


def copy_images(names, src_dir, dst_dir):
  """Copy each file in `names` from `src_dir` to `dst_dir`, creating `dst_dir` if needed."""
  os.makedirs(dst_dir, exist_ok=True)
  for name in names:
    shutil.copyfile(os.path.join(src_dir, name), os.path.join(dst_dir, name))


copy_images(imgs_10k[:n_train], dset_folder, train_folder)
copy_images(imgs_10k[n_train:], dset_folder, test_folder)

70000


In [None]:
!ls /content/drive/MyDrive/CS275_Graphics/FFHQ-Datasets/FFHQ-train

# Training

In [None]:
!git clone https://github.com/srinath2022/DECA_FFHQ.git

Cloning into 'DECA_FFHQ'...
remote: Enumerating objects: 499, done.[K
remote: Counting objects: 100% (114/114), done.[K
remote: Compressing objects: 100% (74/74), done.[K
remote: Total 499 (delta 69), reused 69 (delta 37), pack-reused 385[K
Receiving objects: 100% (499/499), 23.01 MiB | 33.81 MiB/s, done.
Resolving deltas: 100% (222/222), done.


In [None]:
!ls

DECA_FFHQ  drive  FFHQ-10K.zip	FFHQ-test-10K  FFHQ-train-10K  sample_data


In [None]:
!cd DECA_FFHQ && git pull

remote: Enumerating objects: 5, done.[K
remote: Counting objects:  20% (1/5)[Kremote: Counting objects:  40% (2/5)[Kremote: Counting objects:  60% (3/5)[Kremote: Counting objects:  80% (4/5)[Kremote: Counting objects: 100% (5/5)[Kremote: Counting objects: 100% (5/5), done.[K
remote: Total 5 (delta 4), reused 5 (delta 4), pack-reused 0[K
Unpacking objects:  20% (1/5)   Unpacking objects:  40% (2/5)   Unpacking objects:  60% (3/5)   Unpacking objects:  80% (4/5)   Unpacking objects: 100% (5/5)   Unpacking objects: 100% (5/5), done.
From https://github.com/srinath2022/DECA_FFHQ
   7eff14d..e5642b7  master     -> origin/master
Updating 7eff14d..e5642b7
Fast-forward
 decalib/datasets/ffhq.py | 2 [32m+[m[31m-[m
 1 file changed, 1 insertion(+), 1 deletion(-)


In [None]:
!cp /content/drive/MyDrive/CS275_Graphics/requirements/generic_model.pkl /content/DECA_FFHQ/data/

In [None]:
!cp /content/drive/MyDrive/CS275_Graphics/requirements/FLAME_albedo_from_BFM.npz /content/DECA_FFHQ/data/

In [None]:
!cp /content/drive/MyDrive/CS275_Graphics/requirements/resnet50_ft_weight.pkl /content/DECA_FFHQ/data/

In [None]:
!cp -r /content/drive/MyDrive/CS275_Graphics/FFHQ-Datasets/FFHQ-10K-landmarks/ /content/FFHQ-10K-landmarks

In [None]:
!ls

DECA_FFHQ  FFHQ-10K-landmarks  FFHQ-test-10K   sample_data
drive	   FFHQ-10K.zip        FFHQ-train-10K


In [None]:
!pip install -r DECA_FFHQ/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting chumpy>=0.69
  Downloading chumpy-0.70.tar.gz (50 kB)
[K     |████████████████████████████████| 50 kB 3.2 MB/s 
Collecting PyYAML==5.1.1
  Downloading PyYAML-5.1.1.tar.gz (274 kB)
[K     |████████████████████████████████| 274 kB 9.8 MB/s 
[?25hCollecting torch==1.6.0
  Downloading torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (748.8 MB)
[K     |████████████████████████████████| 748.8 MB 17 kB/s 
[?25hCollecting torchvision==0.7.0
  Downloading torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.9 MB)
[K     |████████████████████████████████| 5.9 MB 22.9 MB/s 
Collecting yacs==0.1.8
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting kornia==0.4.0
  Downloading kornia-0.4.0-py2.py3-none-any.whl (195 kB)
[K     |████████████████████████████████| 195 kB 73.6 MB/s 
[?25hCollecting ninja
  Downloading ninja-1.10.2.3-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_

In [None]:
!pip install loguru torchfile pytorch3d

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting loguru
  Downloading loguru-0.6.0-py3-none-any.whl (58 kB)
[?25l[K     |█████▋                          | 10 kB 27.7 MB/s eta 0:00:01[K     |███████████▎                    | 20 kB 19.0 MB/s eta 0:00:01[K     |████████████████▉               | 30 kB 10.5 MB/s eta 0:00:01[K     |██████████████████████▌         | 40 kB 8.6 MB/s eta 0:00:01[K     |████████████████████████████    | 51 kB 5.4 MB/s eta 0:00:01[K     |████████████████████████████████| 58 kB 3.5 MB/s 
[?25hCollecting torchfile
  Downloading torchfile-0.1.0.tar.gz (5.2 kB)
Collecting pytorch3d
  Downloading pytorch3d-0.3.0-cp37-cp37m-manylinux1_x86_64.whl (30.0 MB)
[K     |████████████████████████████████| 30.0 MB 95.5 MB/s 
Building wheels for collected packages: torchfile
  Building wheel for torchfile (setup.py) ... [?25l[?25hdone
  Created wheel for torchfile: filename=torchfile-0.1.0-py3-none-any.w

In [None]:
!python DECA_FFHQ/main_train.py --cfg DECA_FFHQ/configs/release_version/deca_pretrain.yml 

Namespace(cfg='DECA_FFHQ/configs/release_version/deca_pretrain.yml', mode='train')

creating the FLAME Decoder
tcmalloc: large alloc 1251999744 bytes == 0x6229e000 @  0x7f12ec23b1e7 0x7f12e9bcc0ce 0x7f12e9c22cf5 0x7f12e9bcf948 0x5947d6 0x548cc1 0x5127f1 0x549576 0x593fce 0x548ae9 0x5127f1 0x4bc98a 0x533274 0x4d3969 0x512147 0x549e0e 0x4bcb19 0x532b86 0x594a96 0x515600 0x549e0e 0x593fce 0x548ae9 0x51566f 0x549e0e 0x4bcb19 0x532b86 0x594a96 0x515600 0x593dd7 0x5118f8
please check model path: 
[32m2022-06-18 05:49:42.542[0m | [1mINFO    [0m | [36mdecalib.trainer[0m:[36mload_checkpoint[0m:[36m107[0m - [1mmodel path not found, start training from scratch[0m
Configuration  K: 4
batch_size: 2
eval_data: ['ffhq']
image_size: 224
isSingle: False
num_workers: 2
scale_max: 1.8
scale_min: 1.4
test_data: ['']
training_data: ['ffhq']
trans_scale: 0.0
[32m2022-06-18 05:49:46.414[0m | [1mINFO    [0m | [36mdecalib.trainer[0m:[36mprepare_data[0m:[36m365[0m - [1m---- training data 