# Training Demo
This is a simple example of training SimSwap at 224×224 resolution on the VGGFace2-224 dataset.

Code path: https://github.com/neuralchen/SimSwap
If you like the SimSwap project, please star it!
Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630

Installation
All file changes made by this notebook are temporary. If you want to keep files, you can mount your own Google Drive and store them there.

# Get Scripts

# Install Blocks

# Download the Training Dataset
We employ the cropped VGGFace2-224 dataset for this toy training demo.

You can download the dataset from our Google Drive: https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing

***Please check that the dataset is in the directory /content/TrainingData.***

***If the dataset already exists in /content/TrainingData, please do not run the scripts below!***


# Training
The batch size must be larger than 1!

In [1]:
import numpy as np
import torch
from torch import nn
from models.custom_network import DeformConv
from models.fs_networks_fix import ResnetBlock_Adain
# from models.fs_networks_fix import Generator_Adain_Upsample
from pg_modules.projected_discriminator import ProjectedDiscriminator

# Samples per batch; the notes above require this to be larger than 1.
batch_size = 4
# Discriminator instance with both the `diffaug` and `interp224` flags disabled.
D = ProjectedDiscriminator(interp224=False, diffaug=False)

In [8]:
class IdDeformConv(nn.Module):
    """Deformable convolution whose output is shifted by an identity-dependent bias.

    The latent vector is projected by a linear layer to one value per output
    channel; that value is broadcast over the spatial dimensions and added to
    the deformable-convolution output.

    Args:
        latent_size: dimensionality of the identity latent vector.
        input_channels: channels of the feature map fed to the convolution.
        output_channels: channels produced by the convolution (and the projection).
        kernel_size, stride, padding, bias: forwarded positionally to ``DeformConv``.
    """

    def __init__(self, latent_size, input_channels, output_channels, kernel_size, stride=1, padding=0, bias=False) -> None:
        super().__init__()
        # Creation order matters: it fixes state_dict ordering and the order in
        # which parameter initialisation consumes the global RNG.
        self.dconv = DeformConv(input_channels, output_channels, kernel_size, stride, padding, bias)
        self.latent_injection = nn.Linear(latent_size, output_channels)

    def forward(self, input, latent):
        """Apply the deformable conv to ``input`` and add the projected ``latent``.

        ``latent`` is presumably (batch, latent_size) -- TODO confirm. The
        projection is viewed as (dim0, dim1, 1, 1) so it broadcasts over H x W.
        """
        shift = self.latent_injection(latent)
        n, c = shift.size(0), shift.size(1)
        return self.dconv(input) + shift.view(n, c, 1, 1)
    

class DeformConvDownSample(nn.Module):
    """Identity-conditioned deformable conv -> 1x1 conv -> BatchNorm -> LeakyReLU(0.2).

    The only layer here that takes a caller-chosen ``stride`` is the
    ``IdDeformConv``; the 1x1 convolution always uses stride 1.
    """

    def __init__(self, latent_size, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super().__init__()
        # Attribute names double as state_dict keys, so they stay fixed.
        self.dconv = IdDeformConv(latent_size, in_channels, out_channels, kernel_size, stride, padding, bias)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=False)

    def forward(self, x, latent_id):
        """Run the conditioned conv on ``x`` with ``latent_id``, then conv/norm/activation."""
        conditioned = self.dconv(x, latent_id)
        return self.relu(self.norm(self.conv(conditioned)))
class DeformConvUpSample(nn.Module):
    """Bilinear upsampling followed by an identity-conditioned deformable conv block.

    Pipeline: Upsample(scaleFactor, bilinear) -> IdDeformConv -> 1x1 Conv2d
    -> BatchNorm2d -> LeakyReLU(0.2).

    Extra ``*args``/``**kwargs`` are passed straight to ``nn.Module.__init__``.
    """

    def __init__(self, scaleFactor, latent_size, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.upsample = nn.Upsample(scale_factor=scaleFactor, mode='bilinear', align_corners=False)
        # NOTE: ``IdDeformConv`` and ``rl`` are state_dict keys; renaming them
        # would break existing checkpoints.
        self.IdDeformConv = IdDeformConv(latent_size, in_channels, out_channels, kernel_size, stride, padding, bias)
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.norm = nn.BatchNorm2d(out_channels)
        self.rl = nn.LeakyReLU(0.2, inplace=False)

    def forward(self, x, latent_id):
        """Upsample ``x``, inject ``latent_id``, then apply conv/norm/activation."""
        h = self.IdDeformConv(self.upsample(x), latent_id)
        for layer in (self.conv, self.norm, self.rl):
            h = layer(h)
        return h
class IDBlocks(nn.Module):
    """Placeholder module: registers no layers and defines no ``forward`` yet.

    Any positional/keyword arguments are handed to ``nn.Module.__init__``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super(IDBlocks, self).__init__(*args, **kwargs)

In [7]:

class Generator_Adain_Upsample(nn.Module):
    """Encoder / AdaIN-bottleneck / decoder generator built from ``DeformConv`` layers.

    Encoder: reflection-padded 7x7 ``DeformConv`` stem (input_nc -> 64), then
    three stride-2 ``DeformConv`` stages (64 -> 128 -> 256 -> 512), plus an
    extra 512 -> 512 stride-2 stage when ``deep`` is True.
    Bottleneck: ``n_blocks`` ``ResnetBlock_Adain`` blocks at 512 channels, each
    conditioned on ``dlatents``.
    Decoder: mirrored (bilinear x2 upsample + ``DeformConv``) stages back to 64
    channels, then a reflection-padded 7x7 ``DeformConv`` to ``output_nc``
    channels with no output activation.

    Args:
        input_nc: channels of the input image.
        output_nc: channels of the generated image.
        latent_size: size of the identity latent given to each AdaIN block.
        n_blocks: number of bottleneck blocks; must be >= 0.
        deep: add the extra down/up stage pair.
        norm_layer: normalisation class applied after every encoder and
            decoder ``DeformConv`` (default ``nn.BatchNorm2d``).
        padding_type: forwarded to ``ResnetBlock_Adain``.
    """

    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert n_blocks >= 0, "n_blocks must be non-negative"
        super(Generator_Adain_Upsample, self).__init__()

        # One ReLU instance is shared by every stage; it holds no parameters.
        activation = nn.ReLU(True)

        self.deep = deep

        # Module creation order below is preserved on purpose: it determines
        # state_dict key order and the order parameter init consumes the RNG.
        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), DeformConv(input_nc, 64, kernel_size=7, padding=0),
                                         norm_layer(64), activation)
        # --- encoder ---
        self.down1 = nn.Sequential(DeformConv(64, 128, kernel_size=3, stride=2, padding=1),
                                   norm_layer(128), activation)
        self.down2 = nn.Sequential(DeformConv(128, 256, kernel_size=3, stride=2, padding=1),
                                   norm_layer(256), activation)
        self.down3 = nn.Sequential(DeformConv(256, 512, kernel_size=3, stride=2, padding=1),
                                   norm_layer(512), activation)
        if self.deep:
            self.down4 = nn.Sequential(DeformConv(512, 512, kernel_size=3, stride=2, padding=1),
                                       norm_layer(512), activation)

        # --- identity-conditioned bottleneck ---
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)
            for _ in range(n_blocks)
        ])

        # --- decoder ---
        # Use ``norm_layer`` here too (not a hard-coded BatchNorm2d) so the
        # constructor argument applies to the whole network; identical for the default.
        if self.deep:
            self.up4 = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                DeformConv(512, 512, kernel_size=3, stride=1, padding=1),
                norm_layer(512), activation
            )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            DeformConv(512, 256, kernel_size=3, stride=1, padding=1),
            norm_layer(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            DeformConv(256, 128, kernel_size=3, stride=1, padding=1),
            norm_layer(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            DeformConv(128, 64, kernel_size=3, stride=1, padding=1),
            norm_layer(64), activation
        )

        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), DeformConv(64, output_nc, kernel_size=7, padding=0))

    def forward(self, input, dlatents):
        """Generate an image from ``input`` conditioned on the identity ``dlatents``.

        Args:
            input: image batch, presumably (N, input_nc, H, W) with H and W
                divisible by 8 (16 when ``deep``) -- TODO confirm.
            dlatents: identity latent passed unchanged to every bottleneck
                block; presumably (N, latent_size).

        Returns:
            The output of ``last_layer`` (no activation applied).
        """
        # Shapes used to be printed on every call; dropped so training loops
        # are not flooded with stdout output.
        x = self.first_layer(input)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        if self.deep:
            x = self.down4(x)

        for res_block in self.BottleNeck:
            # ResnetBlock_Adain takes the latent as a second argument, so the
            # Sequential cannot simply be called as a whole.
            x = res_block(x, dlatents)

        if self.deep:
            x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        return self.last_layer(x)


In [11]:
# Smoke test: push a random feature map through one DeformConvUpSample stage.
latent_id = torch.randn(batch_size, 512)          # one 512-d identity latent per sample
src_image = torch.randn(batch_size, 256, 56, 56)  # random 256-channel, 56x56 feature map
up = DeformConvUpSample(
    2, 512, 256, 128,
    kernel_size=3, stride=1, padding=1,
)
# Last expression of the cell, so the notebook displays the output shape.
up(src_image, latent_id).shape

torch.Size([4, 128, 112, 112])