## config

In [None]:
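# face GAN (CelebA) config

# Hyperparameter configuration
# (kept as a plain Python class; NOTE(review): the original "yaml" note presumably means this could live in a YAML file instead)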
class Hyperparameter:
    """Static settings for the face GAN: data, model shape and training schedule."""

    # ---------------------------- Data ----------------------------
    device: str = 'cuda'  # device string handed to .to() and tensor factories
    data_root: str = '../input/celeba-dataset/img_align_celeba'  # image folder root
    image_size: int = 64  # used for both the resize and the centre crop
    seed: int = 1234  # random seed

    # ---------------------- Model structure -----------------------
    z_dim: int = 100  # length of the latent vector fed to the generator
    data_channels: int = 3  # channels of generated / discriminated images

    # ------------------------- Experiment -------------------------
    batch_size: int = 64
    n_workers: int = 4  # DataLoader worker count
    beta: float = 0.5  # first Adam beta coefficient
    init_lr: float = 0.0002  # Adam learning rate for both networks
    epochs: int = 30
    verbose_step: int = 250  # log a preview image grid every this many steps
    save_step: int = 1000  # write checkpoints every this many steps


# Shared settings instance read by the rest of the file.
HP = Hyperparameter()


## 生成必要的目录

In [None]:
import os
import random

# Seed Python's built-in RNG with the configured seed.
random.seed(HP.seed)

# Create the output folders used for data, TensorBoard logs and checkpoints.
# makedirs(exist_ok=True) replaces the exists()-then-mkdir() pair: it has no
# check-then-create race and also creates missing parent folders. If a path
# exists but is a regular file it raises FileExistsError instead of being
# silently skipped.
for foldername in ['data', 'log', 'model_save', 'model_save/Generator', 'model_save/Discriminator']:
    os.makedirs(foldername, exist_ok=True)


## dataset_face

In [None]:
from torchvision import transforms as T
import torchvision.datasets as TD
from torch.utils.data import DataLoader
import os

# NOTE(review): presumably set to stop the duplicate-OpenMP-runtime abort seen
# on some platforms -- confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Preprocessing: resize, centre-crop to image_size, convert to a tensor, then
# apply (x - 0.5) / 0.5 to each of the three channels.
_face_transform = T.Compose([
    T.Resize(HP.image_size),
    T.CenterCrop(HP.image_size),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Images are read from HP.data_root; the ImageFolder labels are not used by training.
data_face = TD.ImageFolder(root=HP.data_root, transform=_face_transform)

# Shuffled mini-batches of preprocessed images.
face_loader = DataLoader(
    data_face,
    batch_size=HP.batch_size,
    shuffle=True,
    num_workers=HP.n_workers,
)

# Inverse of the Normalize step above: first x / 2, then x + 0.5.
invTrans = T.Compose([
    T.Normalize(mean=(0., 0., 0.), std=(2, 2, 2)),
    T.Normalize(mean=(-0.5, -0.5, -0.5), std=(1., 1., 1.)),
])


## generator

In [None]:
import torch
from torch import nn


class Generator(nn.Module):
    """Turn latent vectors into images.

    A linear layer projects each latent vector (length ``HP.z_dim``) to a
    1024-channel 4x4 feature map. Four transposed convolutions
    (kernel 4, stride 2, padding 1) then double the spatial size each time
    (4 -> 8 -> 16 -> 32 -> 64) while reducing the channels to
    ``HP.data_channels``. The final ``Tanh`` bounds the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        self.projection_layer = nn.Linear(HP.z_dim, 4 * 4 * 1024)

        # Layer order matches the hand-written version exactly, so the
        # nn.Sequential indices (and therefore state_dict keys) are unchanged.
        channels = (1024, 512, 256, 128)
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers += [self._upsample(c_in, c_out), nn.BatchNorm2d(c_out), nn.ReLU()]
        layers += [self._upsample(channels[-1], HP.data_channels), nn.Tanh()]
        self.generator = nn.Sequential(*layers)

    @staticmethod
    def _upsample(in_channels, out_channels):
        """Return the bias-free transposed convolution that doubles H and W."""
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(4, 4),
            stride=(2, 2),
            padding=(1, 1),
            bias=False
        )

    def forward(self, latent_z):
        """Map ``latent_z`` of shape (batch, HP.z_dim) to an image batch."""
        z = self.projection_layer(latent_z)
        # Unflatten the projection into (batch, 1024, 4, 4) feature maps.
        z_projected = z.view(-1, 1024, 4, 4)
        return self.generator(z_projected)

    @staticmethod
    def weights_init(layer):
        """Initialise one sub-module in place; meant for ``module.apply``.

        Layers whose class name contains ``Conv`` get weights ~ N(0, 0.02).
        Layers whose class name contains ``BatchNorm`` get weights
        ~ N(1, 0.02) and a zero bias. Other layers are left untouched.
        """
        layer_class_name = layer.__class__.__name__
        if 'Conv' in layer_class_name:
            nn.init.normal_(layer.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in layer_class_name:
            nn.init.normal_(layer.weight.data, 1.0, 0.02)
            # Was normal_(bias, 0.0), which samples N(0, 1) because std
            # defaults to 1; the bias is meant to start at exactly zero.
            nn.init.constant_(layer.bias.data, 0.0)


## discriminator

In [None]:
import torch
from torch import nn


class Discriminator(nn.Module):
    """Score images with a probability-like value in (0, 1).

    Five stride-2 3x3 convolutions (padding 1) halve the spatial size each
    time, so a 64x64 input becomes a 256-channel 2x2 map
    (64 -> 32 -> 16 -> 8 -> 4 -> 2). That map is flattened and passed through
    a linear layer and a sigmoid. The first convolution has no BatchNorm;
    all later ones do.
    """

    def __init__(self):
        super().__init__()

        # Layer order matches the hand-written version exactly, so the
        # nn.Sequential indices (and therefore state_dict keys) are unchanged.
        channels = (16, 32, 64, 128, 256)
        layers = [self._downsample(HP.data_channels, channels[0]), nn.LeakyReLU(0.2)]
        for c_in, c_out in zip(channels, channels[1:]):
            layers += [self._downsample(c_in, c_out), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2)]
        self.discriminator = nn.Sequential(*layers)

        self.linear_layer = nn.Linear(256 * 2 * 2, 1)
        self.out_ac = nn.Sigmoid()

    @staticmethod
    def _downsample(in_channels, out_channels):
        """Return the bias-free stride-2 convolution that halves H and W."""
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            bias=False
        )

    def forward(self, image):
        """Return a (batch, 1) tensor of sigmoid scores for ``image``."""
        out_d = self.discriminator(image)
        # Flatten per sample. Keeping the batch dimension explicit means a
        # wrongly sized input fails at the linear layer instead of being
        # silently reshaped into a different batch size by view(-1, 1024).
        out_d = out_d.view(out_d.size(0), -1)
        return self.out_ac(self.linear_layer(out_d))

    @staticmethod
    def weights_init(layer):
        """Initialise one sub-module in place; meant for ``module.apply``.

        Layers whose class name contains ``Conv`` get weights ~ N(0, 0.02).
        Layers whose class name contains ``BatchNorm`` get weights
        ~ N(1, 0.02) and a zero bias. Other layers are left untouched.
        """
        layer_class_name = layer.__class__.__name__
        if 'Conv' in layer_class_name:
            nn.init.normal_(layer.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in layer_class_name:
            nn.init.normal_(layer.weight.data, 1.0, 0.02)
            # Was normal_(bias, 0.0), which samples N(0, 1) because std
            # defaults to 1; the bias is meant to start at exactly zero.
            nn.init.constant_(layer.bias.data, 0.0)


## trainer

In [None]:
import os
from torch import optim
import torch
import random
import numpy as np
from torch import nn
from tensorboardX import SummaryWriter
import torchvision.utils as vutils

# TensorBoard event writer; event files are written under ./log.
logger = SummaryWriter('./log')

# Seed init for reproducibility: feed the same configured seed to every RNG
# entry point used here (torch, Python's random, NumPy, torch CUDA).
for _seed_fn in (
    torch.manual_seed,
    torch.random.manual_seed,
    random.seed,
    np.random.seed,
    torch.cuda.manual_seed,
):
    _seed_fn(HP.seed)


def save_checkpoint(model, epoch, opt, save_path):
    """Write a training checkpoint to ``save_path`` with ``torch.save``.

    The saved dict holds the epoch number, the model's ``state_dict()`` and
    the optimizer's ``state_dict()`` under the keys ``epoch``,
    ``model_state_dict`` and ``optimizer_state_dict``.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=opt.state_dict(),
    )
    torch.save(checkpoint, save_path)


def train():
    """Train the Generator and Discriminator adversarially on ``face_loader``.

    Each step first updates D on a real batch (target 0.9) and a generated
    batch (target 0.1), then updates G so that D scores fresh samples as 0.9.
    Both losses are logged to TensorBoard every step. A grid of images
    generated from a fixed latent batch is logged every ``HP.verbose_step``
    steps, and G/D checkpoints are written every ``HP.save_step`` steps
    under ``model_save/``. The module-level ``logger`` is closed when
    training finishes or raises.
    """
    # Instantiate Generator and Discriminator and initialise their parameters.
    G = Generator()
    G.apply(G.weights_init)
    D = Discriminator()
    D.apply(D.weights_init)
    G.to(HP.device)
    D.to(HP.device)

    # loss criterion
    criterion = nn.BCELoss()

    # optimizer
    optimizer_g = optim.Adam(G.parameters(), lr=HP.init_lr, betas=(HP.beta, 0.999))
    optimizer_d = optim.Adam(D.parameters(), lr=HP.init_lr, betas=(HP.beta, 0.999))

    start_epoch, step = 0, 0

    G.train()
    D.train()

    # Fixed latent batch (64 samples -> 8x8 grid) so preview images are
    # comparable across steps. Uses HP.z_dim rather than a literal 100 so it
    # keeps matching the Generator input if the config changes.
    fixed_latent_z = torch.randn(size=(64, HP.z_dim), device=HP.device)

    try:
        for epoch in range(start_epoch, HP.epochs):
            print('Start Epoch: %d, Step: %d' % (epoch, len(face_loader)))
            for batch, _ in face_loader:
                b_size = batch.size(0)

                # ---- Train the Discriminator ----
                optimizer_d.zero_grad()

                labels_gt = torch.full(size=(b_size,), fill_value=0.9, dtype=torch.float32, device=HP.device)
                # view(-1) instead of squeeze(): squeeze() would collapse a
                # batch of one to a 0-d tensor and mismatch the (1,) labels.
                predict_labels_gt = D(batch.to(HP.device)).view(-1)
                loss_d_of_gt = criterion(predict_labels_gt, labels_gt)

                labels_fake = torch.full(size=(b_size,), fill_value=0.1, dtype=torch.float32, device=HP.device)
                latent_z = torch.randn(size=(b_size, HP.z_dim), device=HP.device)
                # detach(): the D update does not need gradients through G.
                predict_labels_fake = D(G(latent_z).detach()).view(-1)
                loss_d_of_fake = criterion(predict_labels_fake, labels_fake)

                loss_D = loss_d_of_gt + loss_d_of_fake
                loss_D.backward()
                optimizer_d.step()
                loss_d_value = loss_D.item()
                logger.add_scalar('Loss/Discriminator', loss_d_value, step)

                # ---- Train the Generator ----
                optimizer_g.zero_grad()

                latent_z = torch.randn(size=(b_size, HP.z_dim), device=HP.device)
                labels_for_g = torch.full(size=(b_size,), fill_value=0.9, dtype=torch.float32, device=HP.device)
                predict_labels_from_g = D(G(latent_z)).view(-1)

                loss_G = criterion(predict_labels_from_g, labels_for_g)
                loss_G.backward()
                optimizer_g.step()
                loss_g_value = loss_G.item()
                logger.add_scalar('Loss/Generator', loss_g_value, step)

                if not step % HP.verbose_step:
                    with torch.no_grad():
                        fake_image_dev = G(fixed_latent_z)
                        # invTrans undoes the input normalisation before logging.
                        logger.add_image('Generator Face', invTrans(vutils.make_grid(fake_image_dev.detach().cpu(), nrow=8)), step)

                if not step % HP.save_step:
                    model_path_g = 'model_g_%d_%d.pth' % (epoch, step)
                    save_checkpoint(G, epoch, optimizer_g, os.path.join('model_save','Generator', model_path_g))
                    model_path_d = 'model_d_%d_%d.pth' % (epoch, step)
                    save_checkpoint(D, epoch, optimizer_d, os.path.join('model_save','Discriminator', model_path_d))

                step += 1
                logger.flush()
                print('Epoch:[%d/%d], step:%d, Discriminator Loss:%.5f, Generator Loss:%.5f' % (
                    epoch, HP.epochs, step, loss_d_value, loss_g_value))
    finally:
        # Close the writer even if training is interrupted or fails.
        logger.close()


## 训练

In [None]:
# Run the full training loop defined by train().
train()