In [74]:
import math
import os

import cv2, dlib
import facenet_pytorch as fp
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.nn.functional import interpolate
from torchvision import models
from torchvision import transforms

from stylegan import get_style_gan

# Per-channel mean / standard deviation handed to the normaliser. These are
# the statistics torchvision documents for its ImageNet-pretrained models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

def build_resnet_model(latent_space=512):
    """Create the image-to-latent regressor used by this notebook.

    Loads torchvision's ResNet-18 with pretrained weights and swaps its ``fc``
    layer for a two-layer MLP (512 -> 1024 -> ``latent_space``).

    Args:
        latent_space: Width of the vector produced by the replacement head.

    Returns:
        The modified ResNet-18 after it has been moved to the CUDA device.
    """
    backbone = models.resnet18(pretrained=True)

    head = nn.Sequential(
        nn.Linear(512, 1024, bias=True),
        nn.ReLU(),
        nn.Linear(1024, latent_space, bias=True),
    )
    backbone.fc = head

    # Explicitly mark every parameter of the new head as trainable.
    for weight in head.parameters():
        weight.requires_grad = True

    return backbone.cuda()

def crop_image(image, det):
    """Cut the region described by a detection rectangle out of an image.

    Args:
        image: Array indexed as ``image[row, column]`` that supports 2-D slicing.
        det: Rectangle exposing ``left()``, ``top()``, ``right()`` and
            ``bottom()`` accessors.

    Returns:
        ``image[top:bottom, left:right]`` with every edge clipped to be
        non-negative. A rectangle lying entirely above or left of the image
        therefore yields an empty crop.
    """
    left, top, right, bottom = rect_to_tuple(det)
    # A box that extends past the top/left border has negative coordinates.
    # Used directly as slice bounds they would be counted from the far end of
    # the axis and silently produce an empty or unrelated crop, so clamp them.
    left, top = max(left, 0), max(top, 0)
    right, bottom = max(right, 0), max(bottom, 0)
    return image[top:bottom, left:right]

def rect_to_tuple(rect):
    """Unpack a rectangle object into plain edge coordinates.

    Args:
        rect: Object exposing ``left()``, ``top()``, ``right()`` and
            ``bottom()`` accessors.

    Returns:
        The tuple ``(left, top, right, bottom)`` as reported by those accessors.
    """
    return rect.left(), rect.top(), rect.right(), rect.bottom()

In [None]:
# Build face swap model
import face_swap_py as fspy
import numpy as np
import cv2

# --- Face-swap engine configuration -----------------------------------------
# Runtime switches forwarded to the engine constructor.
generic = False
with_expr = True
with_gpu = False
gpu_device_id = 0

# Asset files, given relative to the notebook's working directory.
landmarks_path = "data/shape_predictor_68_face_landmarks.dat"
model_3dmm_h5_path = "data/BaselFaceModel_mod_wForehead_noEars.h5"
model_3dmm_dat_path = "data/BaselFace.dat"
reg_model_path = "data/3dmm_cnn_resnet_101.caffemodel"
reg_deploy_path = "data/3dmm_cnn_resnet_101_deploy.prototxt"
reg_mean_path = "data/3dmm_cnn_resnet_101_mean.binaryproto"
# For lower resolution, switch the two segmentation files to
# 'data/face_seg_fcn8s_300.caffemodel' and 'data/face_seg_fcn8s_300_deploy.prototxt'.
seg_model_path = "data/face_seg_fcn8s.caffemodel"
seg_deploy_path = "data/face_seg_fcn8s_deploy.prototxt"

# Build the engine; the arguments are positional, so their order matters.
fs = fspy.FaceSwap(
    landmarks_path,
    model_3dmm_h5_path,
    model_3dmm_dat_path,
    reg_model_path,
    reg_deploy_path,
    reg_mean_path,
    seg_model_path,
    seg_deploy_path,
    generic,
    with_expr,
    with_gpu,
    gpu_device_id,
)

In [75]:
# scale = 4
# detector = dlib.get_frontal_face_detector()

# --- Hyper-parameters -------------------------------------------------------
num_eval = 5
batch_size = 32
fc_lr = 0.00005

# --- Models and objective ---------------------------------------------------
resnet = build_resnet_model()
anonymizer = get_style_gan()
loss_fn = torch.nn.MSELoss()

# Freeze every ResNet parameter first ...
for weight in resnet.parameters():
    weight.requires_grad = False

# ... then re-enable gradients for the replacement head only.
for weight in resnet.fc.parameters():
    weight.requires_grad = True

# Put the whole network in eval mode, switch just the head back to train
# mode, and keep the generator in eval mode.
resnet.eval()
resnet.fc.train()
anonymizer.eval()

# The optimiser only receives the parameters that still require gradients.
trainable_params = [w for w in resnet.parameters() if w.requires_grad]
optimizer = optim.Adam(trainable_params, lr=fc_lr)

In [78]:
# Train the ResNet head to recover the latent vector that produced each
# generated face, and every EVAL_EVERY iterations save a grid of
# reconstructions for visual inspection.

EVAL_EVERY = 20            # iterations between logging / visualisation passes
IMAGE_SIZE = (224, 224)    # resolution fed to the ResNet and used in the grid
OUTPUT_DIR = "input_output"


def _render_faces(latent_batch):
    """Run the generator on ``latent_batch`` and return a CPU batch rescaled to [0, 1]."""
    faces = anonymizer(latent_batch)
    faces = (faces.clamp(-1, 1) + 1) / 2.0
    return interpolate(faces, size=IMAGE_SIZE).cpu()


def _to_resnet_input(faces):
    """Apply ``normalize`` image by image and move the resulting batch to the GPU."""
    return torch.stack([normalize(face) for face in faces]).detach().cuda()


def _tensor_to_bgr(face):
    """Convert a CHW float image in [0, 1] into an HWC uint8 BGR array for OpenCV."""
    hwc = (face.permute(1, 2, 0) * 255.0).round().byte().contiguous().numpy()
    return cv2.cvtColor(hwc, cv2.COLOR_RGB2BGR)


def _bgr_to_tensor(image):
    """Convert an HWC uint8 BGR array into a CHW float tensor in [0, 1] at IMAGE_SIZE."""
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    face = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
    # Resize so the result can always be concatenated with the other grid rows.
    return interpolate(face.unsqueeze(0), size=IMAGE_SIZE).squeeze(0)


def _swap_face(source_face, anon_face):
    """Combine one source image and one anonymised image through the face-swap engine."""
    # NOTE(review): assumes fspy.FaceData accepts a BGR uint8 array, that
    # fs.swap(a, b) renders the face of `a` onto image `b`, and that it returns
    # a BGR uint8 image -- confirm against the face_swap_py bindings.
    donor = fspy.FaceData(_tensor_to_bgr(anon_face))
    target = fspy.FaceData(_tensor_to_bgr(source_face))
    return _bgr_to_tensor(fs.swap(donor, target))


# save_image does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

for j in range(0, 100000, 1):
    # Only the head is being optimised: keep the frozen backbone (including its
    # BatchNorm running statistics) in eval mode, as configured in the setup cell.
    resnet.eval()
    resnet.fc.train()

    latents = torch.randn(batch_size, 512).cuda()
    # No gradient is needed through the generator or the preprocessing.
    with torch.no_grad():
        generated_image = _to_resnet_input(_render_faces(latents))

    predicted_features = resnet(generated_image)

    # Regress the latent that generated each image.
    loss = loss_fn(predicted_features, latents)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if j % EVAL_EVERY != 0:
        continue

    print(f"Iteration: {j} \t\t Loss {loss.item()}")
    resnet.eval()
    with torch.no_grad():
        latents = torch.randn(num_eval, 512).cuda()
        generated_image = _render_faces(latents)
        predicted_features = resnet(_to_resnet_input(generated_image))

        # Faces regenerated from the predicted latents, plain and perturbed.
        resnet_based_images = _render_faces(predicted_features)
        uniform_rand = _render_faces(
            predicted_features + (torch.rand(num_eval, 512) / 10.0).cuda())

        noise = torch.randn(num_eval, 512).cuda()
        lightly_p = _render_faces(predicted_features + (noise / 100.0))
        highly_p = _render_faces(predicted_features + (noise / 15.0))

        # Face-swap each generated image with its lightly perturbed counterpart.
        swapped_imgs = torch.stack([
            _swap_face(source_img, anon_img)
            for source_img, anon_img in zip(generated_image, lightly_p)
        ])

        # One grid row per variant, num_eval images per row. All tensors are
        # already in [0, 1], so no value-range normalisation is requested.
        grid_rows = torch.cat([generated_image, uniform_rand, resnet_based_images,
                               lightly_p, highly_p, swapped_imgs])
        torchvision.utils.save_image(grid_rows, f"{OUTPUT_DIR}/{j}.png", nrow=num_eval)

Iteration: 0 		 Loss 1.0115655660629272
Iteration: 20 		 Loss 0.9995635747909546
Iteration: 40 		 Loss 1.0077414512634277
Iteration: 60 		 Loss 1.0188370943069458
Iteration: 80 		 Loss 1.0001838207244873
Iteration: 100 		 Loss 0.9952874779701233
Iteration: 120 		 Loss 1.0121955871582031
Iteration: 140 		 Loss 1.0079536437988281
Iteration: 160 		 Loss 0.9994035363197327
Iteration: 180 		 Loss 1.0184910297393799
Iteration: 200 		 Loss 0.9860864877700806
Iteration: 220 		 Loss 1.01272714138031
Iteration: 240 		 Loss 0.9952960014343262
Iteration: 260 		 Loss 1.0131299495697021
Iteration: 280 		 Loss 1.0127089023590088
Iteration: 300 		 Loss 0.9963364601135254
Iteration: 320 		 Loss 0.9804158210754395
Iteration: 340 		 Loss 0.9991058111190796
Iteration: 360 		 Loss 1.003040075302124
Iteration: 380 		 Loss 0.9932920336723328
Iteration: 400 		 Loss 1.0018343925476074
Iteration: 420 		 Loss 1.0003981590270996
Iteration: 440 		 Loss 1.0013539791107178
Iteration: 460 		 Loss 0.9833095073699951
I

KeyboardInterrupt: 

In [None]:
# def build_facenet_model(latent_space=512):
#     mtcnn = fp.MTCNN(device=torch.device("cuda"))
#     facenet = fp.InceptionResnetV1(pretrained='vggface2').eval().cuda()
#     for p in facenet.parameters():
#         p.requires_grad = False


#     fc = nn.Sequential(*[
#        nn.Linear(512, 512, bias=True),
#        nn.ReLU(),
#     ])

#     fc.train()
#     fc = fc.cuda()

#     return mtcnn, facenet, fc