In [None]:
! git clone "https://github.com/AliJavaheriYekta/MIPGAN1_Pytorch"

In [None]:
%cd MIPGAN1_Pytorch/weights/
! wget https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/karras2019stylegan-ffhq-1024x1024.for_g_all.pt
%cd ..

In [3]:
from __future__ import print_function 
from __future__ import division
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision
from torchvision import models, transforms
import math
import time
import os
import copy
from stylegan_layers import  G_mapping,G_synthesis
from collections import OrderedDict
!pip install piq
from piq import MultiScaleSSIMLoss
from perceptual_model import VGG16_for_Perceptual
from torchvision.utils import save_image

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
#torch.manual_seed(0)


class FullyConnectedSparseLayer(nn.Module):
    """ Custom Linear layer but mimics a standard linear layer """
    def __init__(self, size_in, size_out, coef):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        weights = torch.Tensor(size_out, size_in)
        self.weights = nn.Parameter(weights).to('cuda:0')  # nn.Parameter is a Tensor that's a module parameter.
        inp_section = int(size_in/coef)
        out_section = int(size_out/coef)
        weight_canceler = [[0 for i in range(size_in)] for j in range(size_out)]     
        count = 0
        for i in range(size_out):
            for j in range(count*inp_section,count*inp_section+inp_section):
              weight_canceler[i][j] = 1
            if (i+1)%out_section==0:
              count = count + 1


        self.weight_canceler = torch.Tensor(weight_canceler).to('cuda:0')
        bias = torch.Tensor(size_out)
        self.bias = nn.Parameter(bias).to('cuda:0')

        # initialize weights and biases
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) # weight init
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)  # bias init

    def forward(self, x):
        weights = self.weights * self.weight_canceler
        w_times_x= torch.mm(x, weights.t())
        return torch.add(w_times_x, self.bias)  # w times x + b


class SparseLayer(nn.Module):
    """ Custom Linear layer but mimics a standard linear layer """
    def __init__(self, size_in, size_out, steps):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        weights = torch.Tensor(size_out, size_in)
        self.weights = nn.Parameter(weights)  # nn.Parameter is a Tensor that's a module parameter.
        try:
          if int(steps):
            steps = [steps]
        except:
          pass
        weight_canceler = [[0 for i in range(size_in)] for j in range(size_out)]  
        count = 0
        
        if len(steps) ==  2:
          inp_el_counts = int(size_in/steps[0])
          out_el_counts = int(size_out/steps[1])
          steps_ratio = int(steps[0]/steps[1])
          for i in range(steps[1]):
              for j in range(i*out_el_counts,(i+1)*out_el_counts):
                  for k in range((i+1)*steps_ratio):
                      weight_canceler[j][k*inp_el_counts + count]=1
                  count = count + 1
                  if count >= inp_el_counts:
                      count = 0
        else:     
          inp_el_counts = int(size_in/steps[0])   
          for i in range(0,size_out):
              for j in range(steps[0]):
                  weight_canceler[i][j*inp_el_counts + count]=1
              count = count + 1
              if count >= inp_el_counts:
                  count = 0
        # groups_el_counts = int(size_in/step)
        
        # for i in range(0,size_out):
        #     for j in range(step):
        #         weight_canceler[i][j*groups_el_counts + count]=1
        #     count = count + 1
        #     if count >= groups_el_counts:
        #         count = 0   


        self.weight_canceler = torch.Tensor(weight_canceler).to('cuda:0')
        bias = torch.Tensor(size_out)
        self.bias = nn.Parameter(bias).to('cuda:0')

        # initialize weights and biases
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) # weight init
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)  # bias init

    def forward(self, x):
        weights = self.weights * self.weight_canceler
        w_times_x= torch.mm(x, weights.t())
        return torch.add(w_times_x, self.bias)  # w times x + b


class TreeConnect(nn.Module):
    def __init__(self, input_dim, hidden_layers_dim , output_dim, div_coefs):
        # percentage_masked, **kwargs
        super(TreeConnect, self).__init__()
        self.output_dim = output_dim
        self.div_coefs = div_coefs 
        self.hidden_layers_dim = hidden_layers_dim
        layers_parts = []
        self.first_Layer = FullyConnectedSparseLayer(input_dim, hidden_layers_dim[0],div_coefs[0]).to('cuda:0')
        #self.first_Layer = nn.Linear(input_dim, hidden_layers_dim[0])
        hidden_layers = []
        if len(hidden_layers_dim)>1:
            for i in range(1,len(hidden_layers_dim)):
                hidden_layers.append(SparseLayer(hidden_layers_dim[i-1], hidden_layers_dim[i], [div_coefs[i-1], div_coefs[i]]).to('cuda:0'))
        self.hidden_layers = hidden_layers   
        self.relu = nn.ReLU()
        self.last_layer = SparseLayer(hidden_layers_dim[-1], output_dim, div_coefs[-1]).to('cuda:0')
        
    def forward(self, x):
        x = self.first_Layer(x)
        x = self.relu(x)
        if len(self.hidden_layers)>0:
            for hidden_layer in self.hidden_layers:
                x = hidden_layer(x)
                x = self.relu(x)
        x = self.last_layer(x)

        return x
        


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB, num_ftrs):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        # Remove last linear layer
        self.modelA.fc = TreeConnect(input_dim=num_ftrs, hidden_layers_dim=[1024,1024],  output_dim=512, div_coefs=[64,32]).to('cuda:0')
        self.modelB.fc = TreeConnect(input_dim=num_ftrs, hidden_layers_dim=[1024,1024],  output_dim=512, div_coefs=[64,32]).to('cuda:0')
        #self.modelA.fc = nn.Sequential(
        #    nn.Linear(in_features=num_ftrs, out_features=512),
        #    nn.ReLU(),
        #    # nn.Linear(in_features=1024, out_features=512),
        #    # nn.ReLU(),
        #    nn.Linear(512,512)
        #)
        #self.modelB.fc = nn.Sequential(
        #    nn.Linear(in_features=num_ftrs, out_features=512),
        #    nn.ReLU(),
        #    # nn.Linear(in_features=1024, out_features=512),
        #    # nn.ReLU(),
        #    nn.Linear(512,512)
        #) 

      
    def forward(self, im1,im2):
        x1 = self.modelA(im1.clone())
        m1 = x1.detach()  # clone to make sure x is not changed by inplace methods
        x2 = self.modelB(im2.clone())
        m2 = x2.detach()
        x = (x1 + x2) / 2.0
        return x, m1, m2


def image_preprocess(img_source):
    img = Image.open(img_source).convert("RGB")
    img = transforms.ToTensor()(img).unsqueeze_(0)
    #upsample2ds = torch.nn.Upsample(scale_factor=2, mode='bilinear')
    #img = upsample2ds(img)
    return img

def adjust_lr(optimizer, lr):
    for param in optimizer.param_groups:
        param['lr'] = lr
    return optimizer

def caluclate_loss(synth_img, images, perceptual_net, img_p, upsample2d):
    
    synth_img_t = (synth_img - torch.min(synth_img))/(torch.max(synth_img)-torch.min(synth_img)).detach()
    #synth_img_t = synth_img
    ms_ssim_loss = MultiScaleSSIMLoss(data_range=1., reduction='none')(images[0], synth_img_t)
    # ms_ssim_loss2 = MultiScaleSSIMLoss(data_range=1., reduction='none')(images[1], synth_img_t)
    ms_ssim_loss = (ms_ssim_loss + MultiScaleSSIMLoss(data_range=1., reduction='none')(images[1], synth_img_t))/2.
    MSE_Loss = nn.MSELoss(reduction="mean")
    #calculate Perceptual Loss
    real1_0,real1_1,real1_2,real1_3=perceptual_net(img_p[0])
    real2_0,real2_1,real2_2,real2_3=perceptual_net(img_p[1])
    synth_p=upsample2d(synth_img) #(1,3,256,256)
    synth_0,synth_1,synth_2,synth_3=perceptual_net(synth_p)

    perceptual_loss=0
    perceptual_loss = perceptual_loss + MSE_Loss(synth_0,real1_0) + MSE_Loss(synth_0,real2_0) 
    perceptual_loss = perceptual_loss + MSE_Loss(synth_1,real1_1) + MSE_Loss(synth_1,real2_1)
    perceptual_loss = perceptual_loss + MSE_Loss(synth_2,real1_2) + MSE_Loss(synth_2,real2_2)
    perceptual_loss = perceptual_loss + MSE_Loss(synth_3,real1_3) + MSE_Loss(synth_3,real2_3)

    return ms_ssim_loss, perceptual_loss

def identity_loss_calc(embedding1, embedding2):
    morph = (embedding1 + embedding2)/2.0
    identity_term1n = torch.mm(embedding1, torch.transpose(morph,0,1))
    identity_term1d = torch.norm(embedding1) * torch.norm(morph)
    identity_term2n = torch.mm(embedding2, torch.transpose(morph,0,1))
    identity_term2d = torch.norm(embedding2) * torch.norm(morph)
    identity_loss = ((1 - identity_term1n/identity_term1d) + (1 - identity_term2n/identity_term2d))/2.0
    identity_diff = torch.abs(((1 - identity_term1n/identity_term1d) + (1 - identity_term2n/identity_term2d)))
    return identity_loss, identity_diff

def train_model(model, perceptual_net, g_synthesis, inputs, optimizer, lr, num_epochs, device):
    global identity_loss
    since = time.time()

    #best_model1_wts = copy.deepcopy(model1.state_dict())
    #best_model2_wts = copy.deepcopy(model2.state_dict())
    image1 = inputs[0]
    image2 = inputs[1]
    loss_list = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        model.train()  # Set model to training mode
        

        # zero the parameter gradients
        optimizer.zero_grad()
        
        morph, out1, out2 = model(image1, image2)
        identity_loss, identity_diff = identity_loss_calc(out1, out2)

        morph = morph.unsqueeze(1).repeat(1, 18, 1)
        synth = g_synthesis(morph)
        

        img_p1=image1.clone() #Perceptual loss
        img_p2=image2.clone()
        upsample2d=torch.nn.Upsample(scale_factor=256/1024, mode='bilinear')
        img_p1=upsample2d(img_p1)
        img_p2=upsample2d(img_p2)
        img_p = [img_p1, img_p2]
        
        ms_ssim_loss, perceptual_loss = caluclate_loss(synth, inputs, perceptual_net, img_p, upsample2d)

        loss =  0.0002*perceptual_loss + ms_ssim_loss + 10*identity_loss.to(device) + identity_diff.to(device)

        # backward + optimize only if in training phase
        loss.backward()
        optimizer.step()

        loss_np = loss.detach().cpu().numpy()
        loss_p = perceptual_loss.detach().cpu().numpy()
        loss_m = ms_ssim_loss.detach().cpu().numpy()
        loss_idl = identity_loss.detach().cpu().numpy()
        loss_idd = identity_diff.detach().cpu().numpy()
        loss_list.append(loss_np)
        if epoch%6==0:
             lr = lr*0.95
             optimizer = adjust_lr(optimizer, lr)

        if epoch%10==0 or epoch==num_epochs-1:
             print("iter{}: loss -- {},  ms_ssim --{},  percep_loss --{}, identity_loss --{}, identity_diff --{}".format(epoch,loss_np,loss_m,loss_p,loss_idl,loss_idd))
             synth = (synth - torch.min(synth))/(torch.max(synth)-torch.min(synth))
             save_image(synth.clamp(0,1),"save_result/{}.png".format(epoch))
             #np.save("loss_list.npy",loss_list)
             

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    # model.load_state_dict(best_model_wts)
 
# We use pretrained torchvision models here
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'   
modelA = models.resnet50(pretrained=True).to(device)
modelB = models.resnet50(pretrained=True).to(device)

num_ftrs = modelA.fc.in_features

# Freeze these models
for param in modelA.parameters():
    param.requires_grad_(False)

for param in modelB.parameters():
    param.requires_grad_(False)

# Create ensemble model
model = MyEnsemble(modelA, modelB,num_ftrs).to(device)

img1 = image_preprocess("source_image/3.png").to(device)
img2 = image_preprocess("source_image/4.png").to(device)
inputs = [img1, img2]

prms_to_update = []
for name, param in model.named_parameters():
    if param.requires_grad == True:
        prms_to_update.append(param)

g_synthesis = G_synthesis(resolution=1024)
g_all = nn.Sequential(OrderedDict([
        ('g_mapping', G_mapping()),
        #('truncation', Truncation(avg_latent)),
        ('g_synthesis', G_synthesis(resolution=1024))    
        ]))

g_all.load_state_dict(torch.load("weights/karras2019stylegan-ffhq-1024x1024.for_g_all.pt", map_location=device))
g_all.eval()
g_all.to(device)

perceptual_net = VGG16_for_Perceptual(n_layers=[2,4,9,16]).to(device)

g_synthesis = g_all[1]
g_synthesis.eval()
g_synthesis.to(device)
del g_all
torch.cuda.empty_cache()
# Number of epochs to train for 
num_epochs = 150


# Observe that all parameters are being optimized
lr = 0.03
optimizer = optim.Adam(prms_to_update, lr=lr, betas=(0.9,0.999))

# Train and evaluate
train_model(model, perceptual_net, g_synthesis, inputs, optimizer, lr, num_epochs=num_epochs, device=device)


Collecting piq
[?25l  Downloading https://files.pythonhosted.org/packages/64/a2/c4ef48a8ed230ad1185de933ce466fb1204cba8cda6896a5c304a5e2db84/piq-0.5.4-py3-none-any.whl (102kB)
[K     |████████████████████████████████| 112kB 10.9MB/s 
[?25hCollecting gudhi>=3.2
[?25l  Downloading https://files.pythonhosted.org/packages/cb/71/e70015e0f547debe64901775202aa0e53231907a60850bb8d86ce8a31453/gudhi-3.4.1-cp37-cp37m-manylinux2014_x86_64.whl (28.1MB)
[K     |████████████████████████████████| 28.1MB 101kB/s 
Installing collected packages: gudhi, piq
Successfully installed gudhi-3.4.1 piq-0.5.4
PyTorch Version:  1.8.0+cu101
Torchvision Version:  0.9.0+cu101


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))




Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))


Epoch 0/149
----------


  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "


iter0: loss -- [[4.471236]],  ms_ssim --[0.5929652],  percep_loss --16.61325454711914, identity_loss --[[0.32291234]], identity_diff --[[0.6458247]]
Epoch 1/149
----------
Epoch 2/149
----------
Epoch 3/149
----------
Epoch 4/149
----------
Epoch 5/149
----------
Epoch 6/149
----------
Epoch 7/149
----------
Epoch 8/149
----------
Epoch 9/149
----------
Epoch 10/149
----------
iter10: loss -- [[1.4309639]],  ms_ssim --[0.45728964],  percep_loss --11.430696487426758, identity_loss --[[0.08094901]], identity_diff --[[0.16189802]]
Epoch 11/149
----------
Epoch 12/149
----------
Epoch 13/149
----------
Epoch 14/149
----------
Epoch 15/149
----------
Epoch 16/149
----------
Epoch 17/149
----------
Epoch 18/149
----------
Epoch 19/149
----------
Epoch 20/149
----------
iter20: loss -- [[0.9556558]],  ms_ssim --[0.40384266],  percep_loss --8.478487968444824, identity_loss --[[0.04584312]], identity_diff --[[0.09168625]]
Epoch 21/149
----------
Epoch 22/149
----------
Epoch 23/149
----------
E

In [None]:
! pip install deepface
from deepface import DeepFace
 
#face verification
obj = DeepFace.verify("save_result/149.png", "source_image/3.png", model_name = 'ArcFace')
print(obj)
#face verification
obj = DeepFace.verify("save_result/149.png", "source_image/4.png", model_name = 'ArcFace')
print(obj)