# Training Demo
This is a simple example of training SimSwap at 224×224 resolution on the VGGFace2-224 dataset.

Code path: https://github.com/neuralchen/SimSwap
If you like the SimSwap project, please star it!
Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630

Installation
All file changes made by this notebook are temporary. If you want to keep files, you can mount your own Google Drive and store them there.

# Get Scripts

# Install Blocks

# Download the Training Dataset
We employ the cropped VGGFace2-224 dataset for this toy training demo.

You can download the dataset from our Google Drive: https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing

***Please check the dataset in dir /content/TrainingData***

***If the dataset already exists in /content/TrainingData, please do not run the scripts below!***


# Training
The batch size must be larger than 1!

In [1]:
import numpy as np
import torch
from torch import nn
from models.custom_network import DeformConvGenerator
# from models.fs_networks_fix import ApplyStyle as AdaIn
# Smoke test: build a DeformConvGenerator and run one forward pass on random tensors.
# `batch_size` is also read by later cells; keep it > 1 (see the note above).
batch_size = 4

# NOTE(review): the meaning of the positional args (4, 4, 128, 3) depends on
# DeformConvGenerator's signature in models.custom_network (not visible here) -- confirm.
model = DeformConvGenerator(4,4,128,3)
t = torch.randn(batch_size,3,224,224)
# Second argument is a random 128-wide latent; the recorded output shape is (4, 3, 224, 224).
model(t,torch.randn(batch_size,128)).shape

torch.Size([4, 3, 224, 224])

In [33]:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: fs_model_fix_idnorm_donggp_saveoptim copy.py
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified:  Saturday, 13th May 2023 9:56:35 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################


import torch
import torch.nn as nn
from torch.autograd import Variable
from models.base_model import BaseModel
from models.fs_networks_fix import Generator_Adain_Upsample
from models.custom_network import DeformConvGenerator,DancerGenerator
from pg_modules.projected_discriminator import ProjectedDiscriminator

def compute_grad2(d_out, x_in):
    """Return the per-sample squared L2 norm of the gradient of ``d_out.sum()`` w.r.t. ``x_in``.

    The gradient is computed with ``create_graph=True`` so the result can itself
    be back-propagated, e.g. as a gradient-penalty regularizer.

    Args:
        d_out: tensor computed from ``x_in`` (any shape; it is summed first).
        x_in: tensor with ``requires_grad=True`` whose dim 0 is the batch dimension.

    Returns:
        1-D tensor of length ``x_in.size(0)``: for every sample, the sum of the
        squared gradient entries.
    """
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert grad_dout2.size() == x_in.size()
    # reshape rather than view: the gradient can be non-contiguous (e.g. channels_last
    # or permuted inputs), and .view() raises on such tensors.
    return grad_dout2.reshape(batch_size, -1).sum(1)

class fsModel(BaseModel):
    """Face-swap model wrapper: generator, frozen identity network and discriminator.

    ``initialize(opt)`` builds the generator selected by ``opt.model_name``,
    loads the frozen identity network ``netArc`` from ``opt.Arc_path`` and, in
    training mode, a ``ProjectedDiscriminator`` plus L1 criteria and Adam
    optimizers for both networks.
    """

    def name(self):
        return 'fsModel'

    def initialize(self, opt):
        """Build networks, losses and optimizers from the option object ``opt``.

        Args:
            opt: option namespace. Always read: ``isTrain``, ``model_name``,
                ``n_blocks``, ``Gdeep``, ``Arc_path``. Read for ``"dancer"``:
                ``n_layers``, ``upsample_method``, ``kernel_type``. Inference
                only: ``checkpoints_dir``, ``which_epoch``. Training only:
                ``lr``, ``beta1``, ``continue_train`` and, when resuming,
                ``load_pretrain`` and ``which_epoch``.

        Raises:
            ValueError: if ``opt.model_name`` is not ``"simswap"``,
                ``"simswap+=+"`` or ``"dancer"``.
        """
        BaseModel.initialize(self, opt)
        # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
        self.isTrain = opt.isTrain

        generator_classes = {
            "simswap": Generator_Adain_Upsample,
            "simswap+=+": DeformConvGenerator,
            "dancer": DancerGenerator,
        }
        model_k = generator_classes.get(opt.model_name)
        if model_k is None:
            # Fail with a clear message instead of "'NoneType' object is not callable".
            raise ValueError(
                "Unknown model_name %r; expected one of: %s"
                % (opt.model_name, ", ".join(generator_classes))
            )

        # Generator network
        if opt.model_name != "dancer":
            self.netG = model_k(input_nc=3, output_nc=3, latent_size=512,
                                n_blocks=opt.n_blocks, deep=opt.Gdeep)
        else:
            self.netG = model_k(input_nc=3, output_nc=3, latent_size=512,
                                n_blocks=opt.n_blocks, n_layers=opt.n_layers, deep=opt.Gdeep,
                                upsample_method=opt.upsample_method, kernel_type=opt.kernel_type)
        self.netG.cuda()

        # Id network: the checkpoint file holds a whole module object, kept frozen in eval mode.
        # NOTE: torch.load unpickles arbitrary objects -- only point Arc_path at trusted files.
        self.netArc = torch.load(opt.Arc_path, map_location=torch.device("cpu"))
        self.netArc = self.netArc.cuda()
        self.netArc.eval()
        self.netArc.requires_grad_(False)

        if not self.isTrain:
            self.load_network(self.netG, 'G', opt.which_epoch, opt.checkpoints_dir)
            return

        # ---- Everything below runs in training mode only. ----
        self.netD = ProjectedDiscriminator(diffaug=False, interp224=False)
        # self.netD.feature_network.requires_grad_(False)
        self.netD.cuda()

        # define loss functions
        self.criterionFeat = nn.L1Loss()
        self.criterionRec = nn.L1Loss()

        # initialize optimizers
        self.optimizer_G = torch.optim.Adam(list(self.netG.parameters()), lr=opt.lr,
                                            betas=(opt.beta1, 0.99), eps=1e-8)
        self.optimizer_D = torch.optim.Adam(list(self.netD.parameters()), lr=opt.lr,
                                            betas=(opt.beta1, 0.99), eps=1e-8)

        # load networks
        if opt.continue_train:
            pretrained_path = opt.load_pretrain
            self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
            self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)

        # update_learning_rate() decays from self.old_lr, which was previously never set.
        # Read it back from the optimizer so a resumed run continues from the loaded LR.
        self.old_lr = self.optimizer_G.param_groups[0]['lr']
        torch.cuda.empty_cache()

    def cosin_metric(self, x1, x2):
        """Cosine similarity of ``x1`` and ``x2`` computed along dim 1 (one value per row)."""
        return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))

    def save(self, which_epoch, overlap = False):
        """Save G/D weights and optimizer states, labelled ``which_epoch`` (or 0 if ``overlap``)."""
        epoch_label = 0 if overlap else which_epoch
        self.save_network(self.netG, 'G', epoch_label)
        self.save_network(self.netD, 'D', epoch_label)
        self.save_optim(self.optimizer_G, 'G', epoch_label)
        self.save_optim(self.optimizer_D, 'D', epoch_label)

    def update_fixed_params(self):
        """Rebuild the G optimizer over all generator params (plus ``netE`` if ``gen_features``)."""
        # after fixing the global generator for a number of iterations, also start finetuning it
        params = list(self.netG.parameters())
        # gen_features / netE are never set by initialize(); treat a missing flag as False
        # instead of raising AttributeError.
        if getattr(self, 'gen_features', False):
            params += list(self.netE.parameters())
        # NOTE(review): beta2=0.999 here vs 0.99 in initialize() -- confirm which is intended.
        self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        if self.opt.verbose:
            print('------------ Now also finetuning global generator -----------')

    def update_learning_rate(self):
        """Linearly decay the LR of both optimizers by ``opt.lr / opt.niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        if self.opt.verbose:
            print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr

    def gradinet_penalty_D(self, netD, img_att, img_fake):
        """Gradient penalty mean((||dD/dx||_2 - 1)^2) on random interpolations of the inputs.

        One mixing weight is drawn per batch element. ``netD`` is evaluated on
        the mix with a ``None`` second argument, and element 0 of its output is
        used. The mix is detached from the graphs of ``img_att``/``img_fake``,
        so this penalty sends no gradient back into them.

        Returns:
            Scalar tensor (built with ``create_graph=True``, so it is differentiable
            w.r.t. ``netD``'s parameters).
        """
        bs = img_fake.shape[0]
        # Create alpha on the inputs' device (it was always on CPU, which breaks for CUDA inputs).
        alpha = torch.rand(bs, 1, 1, 1, device=img_fake.device)
        interpolated = (alpha * img_att + (1 - alpha) * img_fake).detach().requires_grad_(True)
        pred_interpolated = netD.forward(interpolated, None)[0]

        # compute gradients
        grad = torch.autograd.grad(outputs=pred_interpolated,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones_like(pred_interpolated),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        # penalize gradients
        grad = grad.reshape(grad.size(0), -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        return torch.mean((grad_l2norm - 1) ** 2)

# Bare instance for interactive experiments: initialize() is not called, so no networks,
# losses or optimizers are attached. gradinet_penalty_D only uses its arguments, so it
# can still be called on this instance.
m = fsModel()

In [34]:
# Sanity check of fsModel.gradinet_penalty_D: a freshly constructed discriminator and
# random CPU images of shape (4, 3, 224, 224); `m` is the bare fsModel from the previous cell.
t = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
# t.discriminator(torch.rand(4,3,224,224),).shape
m.gradinet_penalty_D(t, torch.rand(4,3,224,224), torch.rand(4,3,224,224))
# t.feature_network(torch.rand(4,3,224,224))

tensor(1035895.5000, grad_fn=<MeanBackward0>)

In [2]:
class DancerGenerator(nn.Module):
    """Identity-conditioned encoder / AdaIN bottleneck / AFFA decoder generator.

    Encoder: a 7x7 ``DeformConv`` stem, then three (four when ``deep``)
    stride-2 ``DeformConvDownSample`` stages, each conditioned on ``latent``.
    Bottleneck: ``n_blocks`` ``ResnetBlock_Adain`` blocks followed by a 3x3
    ``transition`` convolution. Decoder: ``AFFA_RB`` blocks with
    ``sample_method="up"``, each fused with the encoder feature of matching
    resolution, then a 7x7 ``DeformConv`` head with ``output_nc`` channels.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        latent_size: size of the identity latent used by every conditioned block.
        n_blocks: number of ``ResnetBlock_Adain`` bottleneck blocks (>= 0).
        deep: add a fourth down/up stage.
        norm_layer: normalization layer class used after the stem convolution.
        padding_type: padding mode forwarded to ``ResnetBlock_Adain``.
    """

    def __init__(self, input_nc=3, output_nc=3, latent_size=512, n_blocks=6, deep=False,norm_layer=nn.BatchNorm2d,padding_type='reflect') -> None:
        assert (n_blocks >= 0)
        super(DancerGenerator, self).__init__()
        self.deep = deep

        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), DeformConv(input_nc, 64, kernel_size=7, padding=0),
                                         norm_layer(64), nn.ReLU(True))

        self.down1 = DeformConvDownSample(latent_size, 64, 128, kernel_size=3, stride=2, padding=1)
        self.down2 = DeformConvDownSample(latent_size, 128, 256, kernel_size=3, stride=2, padding=1)
        self.down3 = DeformConvDownSample(latent_size, 256, 512, kernel_size=3, stride=2, padding=1)

        if self.deep:
            self.down4 = DeformConvDownSample(latent_size, 512, 512, kernel_size=3, stride=2, padding=1)

        activation = nn.LeakyReLU(0.2, True)
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)
            for _ in range(n_blocks)
        ])

        self.transition = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        # Decoder blocks use the same latent width as the encoder (previously hard-coded to 512).
        if self.deep:
            self.affa4 = AFFA_RB(latent_size=latent_size, in_channels=512, out_channels=512, sample_method="up")

        self.affa3 = AFFA_RB(latent_size=latent_size, in_channels=512, out_channels=256, sample_method="up")
        self.affa2 = AFFA_RB(latent_size=latent_size, in_channels=256, out_channels=128, sample_method="up")
        self.affa1 = AFFA_RB(latent_size=latent_size, in_channels=128, out_channels=64, sample_method="up")

        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), DeformConv(64, output_nc, kernel_size=7, padding=0))

    def forward(self, x, latent):
        """Generate an image from ``x`` conditioned on the identity vector ``latent``.

        Shape comments assume a (batch_size, 3, 224, 224) input.
        """
        skip = self.first_layer(x)  # (batch_size, 64, 224, 224)

        skip1 = self.down1(skip, latent)   # (batch_size, 128, 112, 112)
        skip2 = self.down2(skip1, latent)  # (batch_size, 256, 56, 56)
        skip3 = self.down3(skip2, latent)  # (batch_size, 512, 28, 28)
        skip4 = self.down4(skip3, latent) if self.deep else None  # (batch_size, 512, 14, 14)
        x = skip4 if self.deep else skip3

        for block in self.BottleNeck:
            x = block(x, latent)

        # The transition output feeds the decoder in both modes; previously it was
        # computed but silently dropped when deep=False.
        x = self.transition(x)  # (batch_size, 512, 14, 14) if self.deep else (batch_size, 512, 28, 28)

        if self.deep:
            x = self.affa4(x, skip4, latent)  # (batch_size, 512, 28, 28)
        x = self.affa3(x, skip3, latent)  # (batch_size, 256, 56, 56)
        x = self.affa2(x, skip2, latent)  # (batch_size, 128, 112, 112)
        x = self.affa1(x, skip1, latent)  # (batch_size, 64, 224, 224)
        return self.last_layer(x)  # (batch_size, 3, 224, 224)


In [3]:

# Shape check for AFFA_RB in "up" mode: x and skip share the same spatial size. The recorded
# output is (4, 256, 56, 56), i.e. spatial size doubled and 512 -> 256 channels.
k = AFFA_RB(latent_size = 512,in_channels = 512,out_channels= 256, sample_method="up")
k.forward(torch.randn(4,512,28,28),torch.randn(4,512,28,28),torch.randn(4,512)).shape

torch.Size([4, 256, 56, 56])

In [5]:
# End-to-end forward pass of the DancerGenerator defined in this notebook (deep=True,
# 6 bottleneck blocks) on random data; `batch_size` comes from the first code cell.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so the final
# output shape was never confirmed.
model = DancerGenerator(input_nc=3,output_nc=3,latent_size=512,n_blocks=6,deep=True)
x = torch.randn(batch_size,3,224,224)
latent = torch.randn(batch_size,512)
model.forward(x,latent).shape


torch.Size([4, 256, 56, 56]) torch.Size([4, 256, 56, 56]) torch.Size([4, 512])


KeyboardInterrupt: 

In [None]:
class IdDeformConv(nn.Module):
    """Deformable convolution plus an identity-dependent per-channel offset.

    A linear layer maps the identity vector to one scalar per output channel.
    That scalar is broadcast over the spatial dimensions and added to the
    ``DeformConv`` output.

    Args:
        latent_size: length of the identity vector passed to ``forward``.
        input_channels: channels of the feature map passed to ``forward``.
        output_channels: channels produced by the convolution.
        kernel_size, stride, padding, bias: forwarded positionally to ``DeformConv``.
    """

    def __init__(self, latent_size, input_channels, output_channels, kernel_size,
                 stride=1, padding=0, bias=False) -> None:
        super().__init__()
        self.dconv = DeformConv(input_channels, output_channels, kernel_size, stride, padding, bias)
        # One additive offset per output channel, predicted from the identity vector.
        self.latent_injection = nn.Linear(latent_size, output_channels)

    def forward(self, input, latent):
        """Return ``dconv(input)`` shifted channel-wise by the projected ``latent``."""
        channel_shift = self.latent_injection(latent)
        rows, cols = channel_shift.size(0), channel_shift.size(1)
        return self.dconv(input) + channel_shift.view(rows, cols, 1, 1)
    

class DeformConvDownSample(nn.Module):
    """Identity-conditioned conv stage: IdDeformConv -> 1x1 Conv2d -> BatchNorm2d -> LeakyReLU(0.2).

    Any spatial downsampling comes from ``stride``, which is forwarded to
    ``IdDeformConv``. The default is 1; this file's generators pass
    ``stride=2``.
    """

    def __init__(self, latent_size, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=False):
        super(DeformConvDownSample, self).__init__()
        # Registration order is kept fixed so parameter / state_dict ordering is unchanged.
        self.dconv = IdDeformConv(latent_size, in_channels, out_channels,
                                  kernel_size, stride, padding, bias)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=False)

    def forward(self, x, latent_id):
        """Apply the conditioned deformable conv, then mix, normalize and activate."""
        return self.relu(self.norm(self.conv(self.dconv(x, latent_id))))
class DeformConvUpSample(nn.Module):
    """Bilinear upsampling followed by an identity-conditioned deformable conv stage.

    Pipeline: ``Upsample(scale_factor=scaleFactor, bilinear)`` -> ``IdDeformConv``
    -> 1x1 ``Conv2d`` -> ``BatchNorm2d`` -> ``LeakyReLU(0.2)``. Extra
    ``*args`` / ``**kwargs`` are passed through to ``nn.Module.__init__``.
    """

    def __init__(self, scaleFactor, latent_size, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, bias=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Attribute names and order are part of the state_dict layout; keep them.
        self.upsample = nn.Upsample(scale_factor=scaleFactor, mode='bilinear', align_corners=False)
        self.IdDeformConv = IdDeformConv(latent_size, in_channels, out_channels,
                                         kernel_size, stride, padding, bias)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.rl = nn.LeakyReLU(0.2, inplace=False)

    def forward(self, x, latent_id):
        """Upsample ``x``, inject ``latent_id`` via the deformable conv, then mix/normalize/activate."""
        enlarged = self.upsample(x)
        conditioned = self.IdDeformConv(enlarged, latent_id)
        return self.rl(self.norm(self.conv(conditioned)))
class IDBlocks(nn.Module):
    """Placeholder module: it currently defines no layers, parameters or ``forward``."""

    def __init__(self, *args, **kwargs) -> None:
        nn.Module.__init__(self, *args, **kwargs)

In [None]:

class Generator_Adain_Upsample(nn.Module):
    """SimSwap-style generator (notebook variant) built from ``DeformConv`` layers.

    Encoder: a 7x7 stem plus three (four when ``deep``) stride-2 convs.
    Bottleneck: ``n_blocks`` ``ResnetBlock_Adain`` blocks conditioned on
    ``dlatents``. Decoder: matching x2 bilinear upsample + 3x3 conv stages and
    a 7x7 output conv. Encoder features are not fed to the decoder (there are
    no skip connections).

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        latent_size: size of the identity latent passed to ``forward``.
        n_blocks: number of bottleneck blocks (>= 0).
        deep: add a fourth down/up stage.
        norm_layer: normalization layer class used after every encoder and
            decoder conv (default ``nn.BatchNorm2d``).
        padding_type: padding mode forwarded to ``ResnetBlock_Adain``.
        verbose: if True, ``forward`` prints the shape of each intermediate
            feature map (debug aid). Defaults to False.
    """

    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', verbose=False):
        assert (n_blocks >= 0)
        super(Generator_Adain_Upsample, self).__init__()

        activation = nn.ReLU(True)

        self.deep = deep
        self.verbose = verbose

        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), DeformConv(input_nc, 64, kernel_size=7, padding=0),
                                         norm_layer(64), activation)
        ### downsample
        self.down1 = nn.Sequential(DeformConv(64, 128, kernel_size=3, stride=2, padding=1),
                                   norm_layer(128), activation)
        self.down2 = nn.Sequential(DeformConv(128, 256, kernel_size=3, stride=2, padding=1),
                                   norm_layer(256), activation)
        self.down3 = nn.Sequential(DeformConv(256, 512, kernel_size=3, stride=2, padding=1),
                                   norm_layer(512), activation)

        if self.deep:
            self.down4 = nn.Sequential(DeformConv(512, 512, kernel_size=3, stride=2, padding=1),
                                       norm_layer(512), activation)

        ### resnet blocks
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)
            for _ in range(n_blocks)
        ])

        ### upsample (uses norm_layer like the encoder; it was hard-coded to BatchNorm2d)
        if self.deep:
            self.up4 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                DeformConv(512, 512, kernel_size=3, stride=1, padding=1),
                norm_layer(512), activation
            )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            DeformConv(512, 256, kernel_size=3, stride=1, padding=1),
            norm_layer(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            DeformConv(256, 128, kernel_size=3, stride=1, padding=1),
            norm_layer(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            DeformConv(128, 64, kernel_size=3, stride=1, padding=1),
            norm_layer(64), activation
        )
        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), DeformConv(64, output_nc, kernel_size=7, padding=0))

    def _log_shape(self, tag, tensor):
        """Print ``tag`` and the tensor's shape when the module is in verbose mode."""
        if self.verbose:
            print(tag, tensor.shape)

    def forward(self, input, dlatents):
        """Generate an image from ``input`` conditioned on the identity latent ``dlatents``."""
        x = input  # 3*224*224
        self._log_shape("input", x)
        skip1 = self.first_layer(x)
        self._log_shape("x after first layer", skip1)
        skip2 = self.down1(skip1)
        self._log_shape("x after down1", skip2)
        skip3 = self.down2(skip2)
        self._log_shape("x after down2", skip3)
        if self.deep:
            skip4 = self.down3(skip3)
            self._log_shape("x after down3", skip4)
            x = self.down4(skip4)
            self._log_shape("x after down4", x)
        else:
            x = self.down3(skip3)
            self._log_shape("x after down3", x)

        for block in self.BottleNeck:
            x = block(x, dlatents)
        self._log_shape("x after bottleneck", x)

        if self.deep:
            x = self.up4(x)
            self._log_shape("x after up4", x)
        x = self.up3(x)
        self._log_shape("x after up3", x)
        x = self.up2(x)
        self._log_shape("x after up2", x)
        x = self.up1(x)
        self._log_shape("x after up1", x)
        x = self.last_layer(x)
        self._log_shape("x after last layer", x)
        return x


In [None]:
# G = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=True)

# Shape check for DeformConvUpSample: x2 bilinear upsampling of a 256-channel 56x56 feature map,
# conditioned on a 512-d latent. The expected output is (batch_size, 128, 112, 112), assuming the
# 3x3 / stride 1 / padding 1 DeformConv preserves spatial size -- TODO confirm.
latent_id = torch.randn(batch_size, 512)
# Despite its name, this is a random intermediate feature map, not an RGB image.
src_image = torch.randn(batch_size, 256, 56, 56)
# down = DeformConvDownSample(512, 3, 64, kernel_size=3, stride=2, padding=1)
# down(src_image,latent_id).shape
up = DeformConvUpSample(2,512, 256, 128, kernel_size=3, stride=1, padding=1)
up(src_image,latent_id).shape
# G(src_image, latent_id)
# D(src_image,None)[0].shape