# Training Demo
This is a simple example of training SimSwap (224×224) on the VGGFace2-224 dataset.

Code path: https://github.com/neuralchen/SimSwap
If you like the SimSwap project, please star it!
Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630

Installation
All file changes made by this notebook are temporary. Mount your own Google Drive if you want to keep the files.

# Get Scripts

# Install Blocks

# Download the Training Dataset
We employ the cropped VGGFace2-224 dataset for this toy training demo.

You can download the dataset from our Google Drive: https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing

***Please check the dataset in dir /content/TrainingData***

***If the dataset already exists in /content/TrainingData, please do not run the scripts below!***


# Training
The batch size must be larger than 1!

In [9]:
import numpy as np
import torch
from torch import nn
# from models.AeiNet import MultilevelAttributesEncoder,AAD,AADGenerator,AADResBlock
# from models.fs_networks_fix import ApplyStyle as AdaIn
# Number of samples in the random input batch used to exercise the model below.
batch_size = 4

# model = MultilevelAttributesEncoder()

# t = torch.randn(batch_size,3,224,224)
# model(t)
# model(t,torch.randn(batch_size,128)).shape

In [42]:
class MultilevelAttributesEncoder(nn.Module):
    """Encoder/decoder with skip connections that returns multi-scale features.

    The encoder is a stack of six ``Conv2d -> InstanceNorm2d -> LeakyReLU(0.2)``
    stages; the first five halve the spatial size, the sixth keeps it.  The
    decoder is a stack of five ``ConvTranspose2d -> InstanceNorm2d ->
    LeakyReLU(0.1)`` stages; the first keeps the spatial size, the remaining
    four double it.  After every decoder stage the result is concatenated
    (along the channel axis) with the encoder output of matching resolution.

    For a ``(N, 3, 224, 224)`` input, ``forward`` returns seven tensors with
    channel/size ``512@7, 512@7, 256@14, 128@28, 64@56, 32@112, 32@224``.
    """

    def __init__(self):
        super().__init__()

        # Channel widths at the boundaries of consecutive encoder stages.
        self.Encoder_channel = [3, 16, 32, 64, 128, 256, 512]
        self.Encoder = nn.ModuleDict()
        encoder_io = zip(self.Encoder_channel[:-1], self.Encoder_channel[1:])
        for idx, (c_in, c_out) in enumerate(encoder_io):
            # Stages 0..4 downsample by 2; any later stage preserves resolution.
            halves = idx <= 4
            self.Encoder[f'layer_{idx}'] = nn.Sequential(
                nn.Conv2d(
                    c_in,
                    c_out,
                    kernel_size=4 if halves else 3,
                    stride=2 if halves else 1,
                    padding=1,
                ),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2),
            )

        # Decoder input widths already account for the skip concatenation
        # performed in ``forward`` (stage output + matching encoder output).
        self.Decoder_inchannel = [512, 512, 256, 128, 64]
        self.Decoder_outchannel = [256, 128, 64, 32, 16]
        self.Decoder = nn.ModuleDict()
        decoder_io = zip(self.Decoder_inchannel, self.Decoder_outchannel)
        for idx, (c_in, c_out) in enumerate(decoder_io):
            # Stage 0 preserves resolution; every later stage upsamples by 2.
            doubles = idx > 0
            self.Decoder[f'layer_{idx}'] = nn.Sequential(
                nn.ConvTranspose2d(
                    c_in,
                    c_out,
                    kernel_size=4 if doubles else 3,
                    stride=2 if doubles else 1,
                    padding=1,
                ),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.1),
            )

        self.Upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        """Run the encoder and decoder and collect the multi-scale features.

        Args:
            x: Image batch with 3 channels (first encoder stage expects 3).

        Returns:
            A list of tensors: the last encoder output, then the skip-fused
            output of each decoder stage in order, then a 2x bilinear
            upsampling of the final fused output.
        """
        # Encoder pass: remember every stage output for the skip connections.
        skips = []
        feat = x
        for stage in self.Encoder.values():
            feat = stage(feat)
            skips.append(feat)

        bottleneck = skips[-1]
        features = [bottleneck]

        # Decoder pass: stage k is fused with encoder output (n_dec - 1 - k),
        # i.e. the encoder outputs below the bottleneck, deepest first.
        n_dec = len(self.Decoder)
        fused = bottleneck
        for stage, skip in zip(self.Decoder.values(), reversed(skips[:n_dec])):
            fused = torch.cat((stage(fused), skip), 1)
            features.append(fused)

        features.append(self.Upsample(fused))
        return features
# Smoke test: build the encoder and push one random 224x224 RGB batch through
# it; as the last expression of the cell, the returned list is displayed.
model = MultilevelAttributesEncoder()
t = torch.randn(batch_size,3,224,224)
model(t)
# model.Encoder

X: torch.Size([4, 16, 112, 112])
X: torch.Size([4, 32, 56, 56])
X: torch.Size([4, 64, 28, 28])
X: torch.Size([4, 128, 14, 14])
X: torch.Size([4, 256, 7, 7])
X: torch.Size([4, 512, 7, 7])
Y: torch.Size([4, 512, 7, 7])
Y: torch.Size([4, 512, 7, 7])
Y: torch.Size([4, 256, 14, 14])
Y: torch.Size([4, 128, 28, 28])
Y: torch.Size([4, 64, 56, 56])
Y: torch.Size([4, 32, 112, 112])


[tensor([[[[ 6.0822e-01,  2.4267e-01,  7.5740e-01,  ..., -2.4879e-02,
            -2.3917e-01, -1.1832e-01],
           [-2.9212e-01,  2.3444e-02,  1.7451e+00,  ...,  3.9582e-01,
            -1.5404e-01,  6.1847e-02],
           [ 1.4417e+00, -2.8378e-02,  2.0655e-02,  ..., -2.9327e-01,
            -1.1133e-01,  1.9031e-01],
           ...,
           [ 1.1675e+00,  2.5806e-01,  5.6328e-01,  ..., -2.7711e-01,
             2.1132e+00, -3.7891e-02],
           [ 1.8286e+00,  1.9320e+00, -1.1738e-02,  ...,  6.3708e-01,
            -1.6121e-01, -3.5227e-01],
           [-1.2244e-01,  1.1004e+00, -4.0797e-01,  ..., -5.8012e-02,
             3.7254e-01, -6.8588e-02]],
 
          [[ 1.0929e+00,  1.2302e+00,  1.0089e+00,  ...,  8.0722e-01,
            -1.3695e-01, -1.5951e-01],
           [-2.3181e-01, -2.4490e-03,  3.0253e-01,  ...,  7.9889e-01,
             1.1724e+00, -4.5605e-02],
           [ 4.3962e-01, -5.1985e-02, -1.7301e-01,  ..., -8.0042e-02,
            -2.9639e-02,  1.4747e+00],
