# 深度卷积生成对抗网络

论文地址：https://arxiv.org/abs/1511.06434

在之前的章节中，我们看到了如何用GAN生成假样本：把服从某个简单分布的随机噪声转换成与某个数据集的分布相匹配的样本。

深度卷积网络具有更强大的特征提取和判别能力，并且已经在计算机视觉中得到了成功的应用。因此，在本章中，我们将CNN与GAN结合，讨论深度卷积生成对抗网络（DCGAN），并用它来生成逼真的图片。

我们使用的数据集是LFW（Labeled Faces in the Wild）人脸数据集：http://vis-www.cs.umass.edu/lfw/

In [1]:
import os

import mxnet as mx
import numpy as np

from mxnet import nd
from mxnet import gluon
from mxnet import image
from mxnet import autograd

%matplotlib inline
import tarfile
import matplotlib as mlt
mlt.rcParams['figure.dpi'] = 120
import matplotlib.pyplot as plt

In [5]:
# Training schedule
epochs = 2
batch_size = 64
# Length of the latent noise vector z fed to the generator.
latent_z_size = 100

# Device selection: flip use_gpu to False to run on the CPU.
use_gpu = True
if use_gpu:
    ctx = mx.gpu()
else:
    ctx = mx.cpu()

# Adam optimizer settings shared by the generator and discriminator trainers.
lr = 0.0002
beta1 = 0.5

In [6]:
lfw_url = 'http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz'
data_path = '../data/lfw_dataset/lfw-deepfunneled'
# Create the target directory if needed, then download only when it is empty.
# Testing mere existence of the directory would skip the download forever after
# a run whose download/extraction failed, leaving no images to load.
if not os.path.isdir(data_path):
    os.makedirs(data_path)
if not os.listdir(data_path):
    data_file = gluon.utils.download(lfw_url)
    # NOTE(review): extractall trusts the member paths inside an archive fetched
    # from the network; consider filtering members (e.g. Python 3.12+ filter='data').
    with tarfile.open(data_file) as tar:
        tar.extractall(path=data_path)

## 数据预处理

In [7]:
target_wd = 64
target_ht = 64
img_list = []


def transform(data, target_wd, target_ht):
    """Turn one decoded image into a normalized 4-D NDArray ready for batching.

    data: image NDArray as returned by ``image.imread`` (height x width x channels).
    Returns an NDArray of shape (1, C, target_ht, target_wd) with values in
    [-1, 1], where C is 3 (single-channel input is replicated to 3 channels).
    """
    resized = image.imresize(data, target_wd, target_ht)
    # HWC -> CHW (channel-first layout)
    chw = nd.transpose(resized, (2, 0, 1))
    # Map pixel values 0..255 onto [-1, 1]
    scaled = chw.astype(np.float32) / 127.5 - 1
    # Greyscale input: repeat the single channel three times to get RGB
    if scaled.shape[0] == 1:
        scaled = nd.tile(scaled, (3, 1, 1))
    # Prepend a batch axis so images can later be concatenated along axis 0
    batched_shape = (1,) + scaled.shape
    return scaled.reshape(batched_shape)

In [None]:
# Walk the dataset directory, preprocess every JPEG and collect the results.
for path, _, fnames in os.walk(data_path):
    for fname in fnames:
        # Case-insensitive so files ending in ".JPG" are not silently skipped.
        if not fname.lower().endswith('.jpg'):
            continue
        img_path = os.path.join(path, fname)
        img_arr = image.imread(img_path)
        img_arr = transform(img_arr, target_wd, target_ht)
        img_list.append(img_arr)

# Fail with a clear message instead of an opaque error from nd.concatenate([]).
if not img_list:
    raise RuntimeError('No .jpg images found under %s; check the dataset download.' % data_path)

# Stack the (1, 3, H, W) tensors into one (N, 3, H, W) array and batch it.
train_data = mx.io.NDArrayIter(data=nd.concatenate(img_list), batch_size=batch_size)

## 数据展示

In [None]:
def visualize(img_arr):
    """Draw one image tensor into the current matplotlib axes.

    img_arr: NDArray laid out as (channels, height, width) with values in
    [-1, 1], i.e. the format produced by ``transform`` and by the generator's
    tanh output.
    """
    # Undo the [-1, 1] normalization and convert CHW -> HWC for matplotlib.
    hwc = (img_arr.asnumpy().transpose((1, 2, 0)) + 1.0) * 127.5
    # Cast the array itself (imshow returns an AxesImage, not an array); clip so
    # values marginally outside [0, 255] cannot wrap around in the uint8 cast.
    plt.imshow(np.clip(hwc, 0, 255).astype(np.uint8))
    plt.axis('off')
    
# Preview four preprocessed training faces (img_list entries 10..13) side by side.
for panel, idx in enumerate(range(10, 14), start=1):
    plt.subplot(1, 4, panel)
    visualize(img_list[idx][0])

plt.show()

## DCGAN

DCGAN使用标准的CNN结构构建判别模型。对生成模型来说，卷积层被转置卷积层（transposed convolution，也称分数步长卷积）所取代，因此每一层输出的空间尺寸逐层增大。DCGAN的特点如下：


* Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
* Use batch normalization in both the generator and the discriminator.
* Remove fully connected hidden layers for deeper architectures.
* Use ReLU activation in generator for all layers except for the output, which uses Tanh.
* Use LeakyReLU activation in the discriminator for all layers.

<img src="http://gluon.mxnet.io/_images/dcgan.png">

## Generator

In [3]:
# DCGAN generator: upsamples a latent vector z of shape (batch, latent_z_size, 1, 1)
# to an nc x 64 x 64 image in [-1, 1] with a stack of transposed convolutions.
nc = 3    # number of output image channels
ngf = 64  # base number of generator feature maps
generator = gluon.nn.Sequential()
with generator.name_scope():
    # input Z: (latent_z_size) x 1 x 1, going into a transposed convolution
    generator.add(gluon.nn.Conv2DTranspose(channels=8 * ngf, kernel_size=4, strides=1, padding=0, use_bias=False))
    generator.add(gluon.nn.BatchNorm(axis=1))
    generator.add(gluon.nn.Activation('relu'))
    # output size (ngf*8) x 4 x 4
    generator.add(gluon.nn.Conv2DTranspose(channels=4 * ngf, kernel_size=4, strides=2, padding=1, use_bias=False))
    generator.add(gluon.nn.BatchNorm(axis=1))
    generator.add(gluon.nn.Activation('relu'))
    # output size (ngf*4) x 8 x 8
    generator.add(gluon.nn.Conv2DTranspose(channels=2 * ngf, kernel_size=4, strides=2, padding=1, use_bias=False))
    generator.add(gluon.nn.BatchNorm(axis=1))
    generator.add(gluon.nn.Activation('relu'))
    # output size (ngf*2) x 16 x 16
    generator.add(gluon.nn.Conv2DTranspose(channels=ngf, kernel_size=4, strides=2, padding=1, use_bias=False))
    generator.add(gluon.nn.BatchNorm(axis=1))
    generator.add(gluon.nn.Activation('relu'))
    # output size (ngf) x 32 x 32
    generator.add(gluon.nn.Conv2DTranspose(channels=nc, kernel_size=4, strides=2, padding=1, use_bias=False))
    # tanh keeps pixel values in [-1, 1], matching the normalization in transform()
    generator.add(gluon.nn.Activation('tanh'))
    # output size (nc) x 64 x 64
    # output size (nc) X 64 X 64

## Discriminator

In [4]:
# DCGAN discriminator: strided convolutions shrink an nc x 64 x 64 image down to
# one raw score per sample. There is no final sigmoid layer; the block ends with
# a bare Conv2D.
ndf = 64  # base number of discriminator feature maps
discriminator = gluon.nn.Sequential()
with discriminator.name_scope():
    # nc x 64 x 64 -> ndf x 32 x 32 (first layer has no BatchNorm)
    discriminator.add(gluon.nn.Conv2D(ndf, 4, strides=2, padding=1, use_bias=False))
    discriminator.add(gluon.nn.LeakyReLU(alpha=0.2))
    # -> (ndf*2) x 16 x 16
    discriminator.add(gluon.nn.Conv2D(ndf * 2, 4, strides=2, padding=1, use_bias=False))
    discriminator.add(gluon.nn.BatchNorm(axis=1))
    discriminator.add(gluon.nn.LeakyReLU(alpha=0.2))
    # -> (ndf*4) x 8 x 8
    discriminator.add(gluon.nn.Conv2D(ndf * 4, 4, strides=2, padding=1, use_bias=False))
    discriminator.add(gluon.nn.BatchNorm(axis=1))
    discriminator.add(gluon.nn.LeakyReLU(alpha=0.2))
    # -> (ndf*8) x 4 x 4
    discriminator.add(gluon.nn.Conv2D(ndf * 8, 4, strides=2, padding=1, use_bias=False))
    discriminator.add(gluon.nn.BatchNorm(axis=1))
    discriminator.add(gluon.nn.LeakyReLU(alpha=0.2))
    # -> 1 x 1 x 1 raw score
    discriminator.add(gluon.nn.Conv2D(1, 4, strides=1, padding=0, use_bias=False))

## 定义损失函数和优化器

In [8]:
# Binary cross-entropy on raw discriminator scores (the sigmoid is applied inside
# the loss with its default from_sigmoid=False).
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# Draw every initial weight from a zero-mean normal with standard deviation 0.02.
generator.initialize(init=mx.init.Normal(0.02), ctx=ctx)
discriminator.initialize(init=mx.init.Normal(0.02), ctx=ctx)

# One Adam optimizer per network, both using the shared lr / beta1 settings.
trainerG = gluon.Trainer(generator.collect_params(), optimizer='adam',
                         optimizer_params=dict(learning_rate=lr, beta1=beta1))
trainerD = gluon.Trainer(discriminator.collect_params(), optimizer='adam',
                         optimizer_params=dict(learning_rate=lr, beta1=beta1))

## 训练

In [None]:
from datetime import datetime
import time
import logging

# Per-sample BCE targets: 1 marks a real image, 0 a generated one.
real_label = nd.ones(shape=(batch_size,), ctx=ctx)
fake_label = nd.zeros(shape=(batch_size,), ctx=ctx)

def facc(label, pred):
    """Binary accuracy of discriminator outputs against 0/1 labels.

    label: array of 0/1 targets (any shape; flattened).
    pred: raw discriminator scores (logits; any shape; flattened). The
        discriminator has no final sigmoid, so a logit > 0 corresponds to a
        predicted probability > 0.5 of being real.
    Returns the fraction of positions where the thresholded prediction equals
    the label.
    """
    pred = pred.ravel()  # Return a contiguous flattened array.
    label = label.ravel()  # Return a contiguous flattened array.
    return ((pred > 0) == label).mean()

# Discriminator accuracy on real and generated samples, accumulated per epoch.
metric = mx.metric.CustomMetric(facc)

stamp = datetime.now().strftime("%Y_%m_%d-%H_%M")
logging.basicConfig(level=logging.DEBUG)

for epoch in range(epochs):
    btic = time.time()
    train_data.reset()
    # batch_idx (rather than a variable named `iter`) avoids shadowing the builtin.
    for batch_idx, batch in enumerate(train_data):
        ####################################################
        # (1) Update D network : maximize log(D(x)) + log(1 - D(G(z)))
        ####################################################
        data = batch.data[0].as_in_context(ctx)
        # Noise is sized by batch_size to match real_label / fake_label. This assumes
        # NDArrayIter's default last_batch_handle='pad', which keeps every batch full.
        latent_z = nd.random.normal(0, 1, shape=(batch_size, latent_z_size, 1, 1), ctx=ctx)

        with autograd.record():
            real_output = discriminator(data).reshape((-1, 1))
            err_discrim_real = loss(real_output, real_label)

            fake = generator(latent_z)
            # detach(): this loss must update D only, so no gradient flows into G.
            fake_output = discriminator(fake.detach()).reshape((-1, 1))
            err_discrim_fake = loss(fake_output, fake_label)

            err_discrim = err_discrim_fake + err_discrim_real
        err_discrim.backward()
        trainerD.step(batch.data[0].shape[0])

        metric.update([real_label, ], [real_output, ])
        metric.update([fake_label, ], [fake_output, ])

        ####################################################
        # (2) Update G network : maximize log(D(G(z)))
        ####################################################
        with autograd.record():
            fake = generator(latent_z)
            output = discriminator(fake).reshape((-1, 1))
            # Score generated images against the *real* label: minimizing this
            # BCE maximizes log(D(G(z))).
            err_generator = loss(output, real_label)
        err_generator.backward()
        trainerG.step(batch.data[0].shape[0])

        if batch_idx % 10 == 0:
            name, acc = metric.get()
            logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
            logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                         % (nd.mean(err_discrim).asscalar(),
                            nd.mean(err_generator).asscalar(), acc, batch_idx, epoch))

        btic = time.time()

    name, acc = metric.get()
    logging.info('epoch %d: %s = %f' % (epoch, name, acc))
    metric.reset()

## 测试

有了训练好的生成器，我们就可以从随机噪声生成一些人脸图片：

In [None]:
# Generate num_image faces from independent noise vectors and show them in a grid.
num_image = 8
n_cols = 4
# Ceil division: the grid grows with num_image instead of being fixed at 2 x 4,
# so plt.subplot never receives an index beyond the grid.
n_rows = (num_image + n_cols - 1) // n_cols

for i in range(num_image):
    latent_z = nd.random.normal(0, 1, shape=(1, latent_z_size, 1, 1), ctx=ctx)
    img = generator(latent_z)
    plt.subplot(n_rows, n_cols, i + 1)
    visualize(img[0])
plt.show()