# `GAN` principles

- `Generative adversarial networks` (GANs): neural networks compete against each other, and the competition improves their performance.
![](../images/gan_schema.png)
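- A `GAN` consists of two parts:
    - `Generator`: takes a sample from a random distribution (e.g. Gaussian) as input and outputs data such as an image. The random input can be seen as the encoding, or latent representation, of the image to be generated.
    - `Discriminator`: takes fake images produced by the generator and real images as input, and predicts whether each input is real or fake.

- During training, the `generator` and the `discriminator` have opposite goals: the `generator` tries to produce images realistic enough to fool the `discriminator`, while the `discriminator` tries to tell real images from fake ones. Because the two parts pursue different objectives, a `GAN` is not trained like an ordinary network; each iteration has two phases:
    - First, train the `discriminator` on real and fake images labelled 1 and 0 respectively, with binary cross-entropy as the loss. In this phase, backpropagation only updates the weights of the `discriminator`.
    - Then train the `generator`: the fake images it produces are fed to the `discriminator`, with no real images in the batch, and every label is set to 1, so the `generator` is pushed to produce images the `discriminator` accepts as real. In this phase the weights of the `discriminator` are frozen and backpropagation only updates the weights of the `generator`.

- The `generator` never sees a real image. It only receives the gradients flowing back through the `discriminator`, yet it gradually learns to produce convincing fake images.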
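
The two training phases above can be illustrated with a minimal sketch of their losses on single scalar predictions. The helper names (`bce`, `discriminator_loss`, `generator_loss`) are illustrative and not taken from any library; real training averages these values over a batch.

```python
import math


def bce(prediction, label):
    """Binary cross-entropy for one prediction in (0, 1) and a 0/1 label."""
    return -(label * math.log(prediction) + (1 - label) * math.log(1 - prediction))


def discriminator_loss(d_real, d_fake):
    # Phase 1: real images are labelled 1, fake images are labelled 0.
    return bce(d_real, 1) + bce(d_fake, 0)


def generator_loss(d_fake):
    # Phase 2: fake images are labelled 1, so the loss falls as D is fooled.
    return bce(d_fake, 1)


# A confident and correct discriminator: small d_loss, large g_loss.
print(discriminator_loss(d_real=0.9, d_fake=0.1))
print(generator_loss(d_fake=0.1))

# A discriminator that outputs 0.5 for every input.
print(discriminator_loss(d_real=0.5, d_fake=0.5))  # equals 2 * ln 2
print(generator_loss(d_fake=0.5))                  # equals ln 2
```

The same fake image contributes to both losses with opposite labels, which is why the two networks pull the discriminator's output in opposite directions.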

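**Difficulties of training a `GAN`:**
- In theory, given enough training time, a `GAN` eventually reaches an equilibrium: the `generator` produces perfect images and the `discriminator` is reduced to guessing at random (50% real, 50% fake).

- In practice, the biggest difficulty is `mode collapse`, where the generator's output becomes less and less diverse. Suppose the generator's "shoe" images are more convincing than its images of other classes. The `discriminator` is then fooled more often by shoes, which encourages the generator to produce even more shoes. Eventually the generator can only produce shoes, and because the only fake images the `discriminator` sees are shoes, it forgets how to detect fake images of other classes.
- In addition, the parameters of both networks may end up oscillating and become unstable, and training can suddenly diverge. As a result, a `GAN` is very sensitive to hyperparameters and tuning them takes a lot of time.

- Current remedies:
    - `experience replay`: store the fake images produced at each iteration, then train the `discriminator` on real images plus fake images drawn from this store, rather than only on freshly generated fakes.
    - `mini-batch discrimination`: measure how similar the images in a batch are to each other and pass this statistic to the `discriminator`, so that it can reject batches of fake images that lack diversity. This encourages the generator to produce varied images.
    - Some specific architectures that happen to train well.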
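
The experience replay idea from the list above can be sketched with a small buffer. This is a minimal sketch using only the standard library; `ReplayBuffer`, its capacity, and its random eviction policy are assumptions chosen for illustration, and strings stand in for generated images.

```python
import random


class ReplayBuffer:
    """Keeps up to `capacity` previously generated samples (hypothetical helper)."""

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.items = []
        self.rng = random.Random(seed)

    def add(self, batch):
        # Store new fakes; once full, overwrite a randomly chosen old item.
        for item in batch:
            if len(self.items) < self.capacity:
                self.items.append(item)
            else:
                self.items[self.rng.randrange(self.capacity)] = item

    def sample(self, n):
        # Draw fakes from the stored history, not only from the latest generator.
        return self.rng.sample(self.items, min(n, len(self.items)))


buffer = ReplayBuffer(capacity=6, seed=0)
for step in range(5):
    # Stand-in for generator output: strings tagged with the step that made them.
    fakes = [f"fake-step{step}-{k}" for k in range(3)]
    buffer.add(fakes)
    replayed = buffer.sample(3)  # these would be fed to the discriminator

print(len(buffer.items), replayed)
```

Because the discriminator keeps seeing fakes from earlier generator states, it is less likely to overfit to whatever the generator produces at the current step.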

# `PyTorch` implementation

In [18]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

# Use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'images'

# Directory where generated images are saved
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

# Normalize images from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Dataset
mnist = torchvision.datasets.MNIST(
    root='datasets',
    train=True,
    transform=transform,
    download=True,
)
# Data loader
data_loader = torch.utils.data.DataLoader(
    dataset=mnist,
    batch_size=batch_size,
    shuffle=True,
)

# discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid(),
)

# generator
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh(),
)

# Move both networks to the selected device
D = D.to(device)
G = G.to(device)

# Binary cross-entropy loss
criterion = nn.BCELoss()

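# Two optimizers, one per training phase, each updating its own set of parameters
# NOTE: with lr=0.002 the log below shows the discriminator saturating
# (d_loss 0.0000, D(G(z)) 0.00); a smaller rate such as 0.0002 may train more stably.
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.002)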


# Map generator output from [-1, 1] back to [0, 1] before saving it as an image
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

# Zero the gradients of both sets of parameters
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

    
    
# Train the model
total_step = len(data_loader)

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        ########## Train the discriminator ##########
        #######################################
        # Discriminator output and loss on real images
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Random input for the generator and the fake images it produces
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        
        # Discriminator output and loss on fake images
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Total loss
        d_loss = d_loss_real + d_loss_fake
        
        # Update the discriminator's parameters
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        ########## Train the generator ##############
        #######################################
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        # Backpropagate and update the generator's parameters
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            print(
                'Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                .format(
                    epoch,
                    num_epochs,
                    i + 1,
                    total_step,
                    d_loss.item(),
                    g_loss.item(),
                    real_score.mean().item(),
                    fake_score.mean().item(),
                ))

    # Save a batch of real images (first epoch only)
    if (epoch + 1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save the generated images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(
        denorm(fake_images),
        os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))
    

# # Save the model checkpoints
# torch.save(G.state_dict(), 'G.ckpt')
# torch.save(D.state_dict(), 'D.ckpt')    

Epoch [0/200], Step [200/600], d_loss: 0.0000, g_loss: 68.1338, D(x): 1.00, D(G(z)): 0.00
Epoch [0/200], Step [400/600], d_loss: 0.0000, g_loss: 68.4252, D(x): 1.00, D(G(z)): 0.00
Epoch [0/200], Step [600/600], d_loss: 0.0000, g_loss: 69.4815, D(x): 1.00, D(G(z)): 0.00
Epoch [1/200], Step [200/600], d_loss: 0.0000, g_loss: 64.8541, D(x): 1.00, D(G(z)): 0.00
Epoch [1/200], Step [400/600], d_loss: 0.0000, g_loss: 65.4295, D(x): 1.00, D(G(z)): 0.00
Epoch [1/200], Step [600/600], d_loss: 0.0000, g_loss: 68.2684, D(x): 1.00, D(G(z)): 0.00
Epoch [2/200], Step [200/600], d_loss: 0.0000, g_loss: 69.2777, D(x): 1.00, D(G(z)): 0.00
Epoch [2/200], Step [400/600], d_loss: 0.0000, g_loss: 64.2088, D(x): 1.00, D(G(z)): 0.00
Epoch [2/200], Step [600/600], d_loss: 0.0000, g_loss: 67.7681, D(x): 1.00, D(G(z)): 0.00
Epoch [3/200], Step [200/600], d_loss: 0.0000, g_loss: 66.1722, D(x): 1.00, D(G(z)): 0.00
Epoch [3/200], Step [400/600], d_loss: 0.0000, g_loss: 70.4333, D(x): 1.00, D(G(z)): 0.00
Epoch [3/2

Epoch [30/200], Step [400/600], d_loss: 0.0000, g_loss: 63.6769, D(x): 1.00, D(G(z)): 0.00
Epoch [30/200], Step [600/600], d_loss: 0.0000, g_loss: 66.3915, D(x): 1.00, D(G(z)): 0.00
Epoch [31/200], Step [200/600], d_loss: 0.0000, g_loss: 68.7309, D(x): 1.00, D(G(z)): 0.00
Epoch [31/200], Step [400/600], d_loss: 0.0000, g_loss: 65.2301, D(x): 1.00, D(G(z)): 0.00
Epoch [31/200], Step [600/600], d_loss: 0.0000, g_loss: 66.3003, D(x): 1.00, D(G(z)): 0.00
Epoch [32/200], Step [200/600], d_loss: 0.0000, g_loss: 66.3971, D(x): 1.00, D(G(z)): 0.00
Epoch [32/200], Step [400/600], d_loss: 0.0000, g_loss: 65.7016, D(x): 1.00, D(G(z)): 0.00
Epoch [32/200], Step [600/600], d_loss: 0.0000, g_loss: 67.0278, D(x): 1.00, D(G(z)): 0.00
Epoch [33/200], Step [200/600], d_loss: 0.0000, g_loss: 65.3028, D(x): 1.00, D(G(z)): 0.00
Epoch [33/200], Step [400/600], d_loss: 0.0000, g_loss: 69.7041, D(x): 1.00, D(G(z)): 0.00
Epoch [33/200], Step [600/600], d_loss: 0.0000, g_loss: 70.3848, D(x): 1.00, D(G(z)): 0.00

Epoch [60/200], Step [600/600], d_loss: 0.0000, g_loss: 65.9708, D(x): 1.00, D(G(z)): 0.00
Epoch [61/200], Step [200/600], d_loss: 0.0000, g_loss: 67.8512, D(x): 1.00, D(G(z)): 0.00
Epoch [61/200], Step [400/600], d_loss: 0.0000, g_loss: 68.0023, D(x): 1.00, D(G(z)): 0.00
Epoch [61/200], Step [600/600], d_loss: 0.0000, g_loss: 66.3177, D(x): 1.00, D(G(z)): 0.00
Epoch [62/200], Step [200/600], d_loss: 0.0000, g_loss: 66.1169, D(x): 1.00, D(G(z)): 0.00
Epoch [62/200], Step [400/600], d_loss: 0.0000, g_loss: 59.9455, D(x): 1.00, D(G(z)): 0.00
Epoch [62/200], Step [600/600], d_loss: 0.0000, g_loss: 65.6878, D(x): 1.00, D(G(z)): 0.00
Epoch [63/200], Step [200/600], d_loss: 0.0000, g_loss: 61.3779, D(x): 1.00, D(G(z)): 0.00
Epoch [63/200], Step [400/600], d_loss: 0.0000, g_loss: 66.8400, D(x): 1.00, D(G(z)): 0.00
Epoch [63/200], Step [600/600], d_loss: 0.0000, g_loss: 63.1457, D(x): 1.00, D(G(z)): 0.00
Epoch [64/200], Step [200/600], d_loss: 0.0000, g_loss: 67.1520, D(x): 1.00, D(G(z)): 0.00

Epoch [91/200], Step [200/600], d_loss: 0.0000, g_loss: 65.1562, D(x): 1.00, D(G(z)): 0.00
Epoch [91/200], Step [400/600], d_loss: 0.0000, g_loss: 67.1012, D(x): 1.00, D(G(z)): 0.00
Epoch [91/200], Step [600/600], d_loss: 0.0000, g_loss: 65.0608, D(x): 1.00, D(G(z)): 0.00
Epoch [92/200], Step [200/600], d_loss: 0.0000, g_loss: 66.7409, D(x): 1.00, D(G(z)): 0.00
Epoch [92/200], Step [400/600], d_loss: 0.0000, g_loss: 66.4711, D(x): 1.00, D(G(z)): 0.00
Epoch [92/200], Step [600/600], d_loss: 0.0000, g_loss: 63.5650, D(x): 1.00, D(G(z)): 0.00
Epoch [93/200], Step [200/600], d_loss: 0.0000, g_loss: 60.7452, D(x): 1.00, D(G(z)): 0.00
Epoch [93/200], Step [400/600], d_loss: 0.0000, g_loss: 67.6977, D(x): 1.00, D(G(z)): 0.00
Epoch [93/200], Step [600/600], d_loss: 0.0000, g_loss: 64.9749, D(x): 1.00, D(G(z)): 0.00
Epoch [94/200], Step [200/600], d_loss: 0.0000, g_loss: 66.6016, D(x): 1.00, D(G(z)): 0.00
Epoch [94/200], Step [400/600], d_loss: 0.0000, g_loss: 65.8730, D(x): 1.00, D(G(z)): 0.00

Epoch [121/200], Step [200/600], d_loss: 0.0000, g_loss: 59.4375, D(x): 1.00, D(G(z)): 0.00
Epoch [121/200], Step [400/600], d_loss: 0.0000, g_loss: 64.5835, D(x): 1.00, D(G(z)): 0.00
Epoch [121/200], Step [600/600], d_loss: 0.0000, g_loss: 69.1081, D(x): 1.00, D(G(z)): 0.00
Epoch [122/200], Step [200/600], d_loss: 0.0000, g_loss: 63.9074, D(x): 1.00, D(G(z)): 0.00
Epoch [122/200], Step [400/600], d_loss: 0.0000, g_loss: 64.6445, D(x): 1.00, D(G(z)): 0.00
Epoch [122/200], Step [600/600], d_loss: 0.0000, g_loss: 67.3396, D(x): 1.00, D(G(z)): 0.00
Epoch [123/200], Step [200/600], d_loss: 0.0000, g_loss: 64.9800, D(x): 1.00, D(G(z)): 0.00
Epoch [123/200], Step [400/600], d_loss: 0.0000, g_loss: 65.2558, D(x): 1.00, D(G(z)): 0.00
Epoch [123/200], Step [600/600], d_loss: 0.0000, g_loss: 62.5886, D(x): 1.00, D(G(z)): 0.00
Epoch [124/200], Step [200/600], d_loss: 0.0000, g_loss: 64.8200, D(x): 1.00, D(G(z)): 0.00
Epoch [124/200], Step [400/600], d_loss: 0.0000, g_loss: 69.7756, D(x): 1.00, D(

Epoch [151/200], Step [200/600], d_loss: 0.0000, g_loss: 65.3152, D(x): 1.00, D(G(z)): 0.00
Epoch [151/200], Step [400/600], d_loss: 0.0000, g_loss: 62.7729, D(x): 1.00, D(G(z)): 0.00
Epoch [151/200], Step [600/600], d_loss: 0.0000, g_loss: 65.9539, D(x): 1.00, D(G(z)): 0.00
Epoch [152/200], Step [200/600], d_loss: 0.0000, g_loss: 65.8359, D(x): 1.00, D(G(z)): 0.00
Epoch [152/200], Step [400/600], d_loss: 0.0000, g_loss: 65.6470, D(x): 1.00, D(G(z)): 0.00
Epoch [152/200], Step [600/600], d_loss: 0.0000, g_loss: 66.5141, D(x): 1.00, D(G(z)): 0.00
Epoch [153/200], Step [200/600], d_loss: 0.0000, g_loss: 61.7989, D(x): 1.00, D(G(z)): 0.00
Epoch [153/200], Step [400/600], d_loss: 0.0000, g_loss: 66.3275, D(x): 1.00, D(G(z)): 0.00
Epoch [153/200], Step [600/600], d_loss: 0.0000, g_loss: 68.1462, D(x): 1.00, D(G(z)): 0.00
Epoch [154/200], Step [200/600], d_loss: 0.0000, g_loss: 68.1567, D(x): 1.00, D(G(z)): 0.00
Epoch [154/200], Step [400/600], d_loss: 0.0000, g_loss: 63.8428, D(x): 1.00, D(

Epoch [181/200], Step [200/600], d_loss: 0.0000, g_loss: 68.4217, D(x): 1.00, D(G(z)): 0.00
Epoch [181/200], Step [400/600], d_loss: 0.0000, g_loss: 62.4892, D(x): 1.00, D(G(z)): 0.00
Epoch [181/200], Step [600/600], d_loss: 0.0000, g_loss: 65.3117, D(x): 1.00, D(G(z)): 0.00
Epoch [182/200], Step [200/600], d_loss: 0.0000, g_loss: 65.4091, D(x): 1.00, D(G(z)): 0.00
Epoch [182/200], Step [400/600], d_loss: 0.0000, g_loss: 64.6535, D(x): 1.00, D(G(z)): 0.00
Epoch [182/200], Step [600/600], d_loss: 0.0000, g_loss: 65.2029, D(x): 1.00, D(G(z)): 0.00
Epoch [183/200], Step [200/600], d_loss: 0.0000, g_loss: 64.9353, D(x): 1.00, D(G(z)): 0.00
Epoch [183/200], Step [400/600], d_loss: 0.0000, g_loss: 67.2258, D(x): 1.00, D(G(z)): 0.00
Epoch [183/200], Step [600/600], d_loss: 0.0000, g_loss: 63.6162, D(x): 1.00, D(G(z)): 0.00
Epoch [184/200], Step [200/600], d_loss: 0.0000, g_loss: 67.8352, D(x): 1.00, D(G(z)): 0.00
Epoch [184/200], Step [400/600], d_loss: 0.0000, g_loss: 61.7676, D(x): 1.00, D(
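
The log above can be read more precisely by inverting the generator loss. Since `g_loss` is the batch mean of `-log(D(G(z)))`, `exp(-g_loss)` gives the geometric mean of the discriminator's outputs on fake images. The sketch below applies this to two values copied from the log; `implied_probability` is an illustrative helper name.

```python
import math


def implied_probability(g_loss):
    """Geometric-mean D(G(z)) implied by g_loss = mean(-log(D(G(z))))."""
    return math.exp(-g_loss)


# Two g_loss values taken from the log, plus ln 2 as a reference point.
for g_loss in (68.1338, 59.4375, math.log(2)):
    print(f"g_loss={g_loss:.4f} -> D(G(z)) ~ {implied_probability(g_loss):.3e}")
```

A `g_loss` near `ln 2` corresponds to a discriminator that outputs 0.5 on fakes. The logged values imply outputs below `1e-25` with `d_loss` at 0.0000 from the first epoch to the last, which suggests the discriminator overpowered the generator in this run rather than the two reaching a balance.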

# `TensorFlow` implementation

In [29]:
import tensorflow as tf

In [None]:
################# Define the models ###################
#############################################

coding_size = 30
generator = tf.keras.models.Sequential([
    tf.keras.layers.Dense(100, activation='selu', input_shape=[coding_size]),
    tf.keras.layers.Dense(150, activation='selu'),
    tf.keras.layers.Dense(28 * 28, activation='sigmoid'),
    tf.keras.layers.Reshape([28, 28]),
])
discriminator = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=[28, 28]),
    tf.keras.layers.Dense(150, activation='selu'),
    tf.keras.layers.Dense(100, activation='selu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

gan = tf.keras.models.Sequential([generator, discriminator])

discriminator.compile(loss='binary_crossentropy', optimizer='rmsprop')

# Freeze the discriminator before compiling the combined model,
# so that training `gan` only updates the generator
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='rmsprop')

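################# Data pipeline ###################
#############################################
# Assumption: x_train is not defined in the original notes; MNIST scaled to
# [0, 1] is loaded here to match the generator's sigmoid output.
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.

batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)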


################# Train the model ###################
#############################################
def train_gan(gan, dataset, batch_size, coding_size, n_epochs=50):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        for x_batch in dataset:
            # Phase 1: train the discriminator
            noise = tf.random.normal(shape=[batch_size, coding_size])
            generated_images = generator(noise)

            x_fake_and_real = tf.concat([generated_images, x_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)

            discriminator.trainable = True
            discriminator.train_on_batch(x_fake_and_real, y1)

            # Phase 2: train the generator
            noise = tf.random.normal(shape=[batch_size, coding_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False  # freeze the discriminator's parameters
            gan.train_on_batch(noise, y2)
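
The training loop above builds its labels from the fixed `batch_size`, which is one reason the data pipeline uses `drop_remainder=True`. The sketch below, in plain Python, counts inputs and labels per step for a hypothetical dataset of 70 samples; `batch_sizes` is an illustrative helper and not a TensorFlow function.

```python
def batch_sizes(n_samples, batch_size, drop_remainder):
    """Sizes of the batches that a dataset of n_samples would yield."""
    full, rest = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)
    return sizes


batch_size = 32
labels_per_step = 2 * batch_size  # y1 holds batch_size zeros and batch_size ones

for drop in (False, True):
    for n_real in batch_sizes(70, batch_size, drop_remainder=drop):
        inputs_per_step = batch_size + n_real  # generated images + real images
        status = "ok" if inputs_per_step == labels_per_step else "MISMATCH"
        print(f"drop_remainder={drop}: {inputs_per_step} inputs, "
              f"{labels_per_step} labels -> {status}")
```

Without `drop_remainder=True`, the final short batch would pair 38 inputs with 64 labels, so the discriminator step would fail on a shape mismatch.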

# `GAN` architectures

## `Deep Convolutional GANs(DCGANs)`
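- Replace the `pooling` layers of a `CNN` with `strided convolutions` (in the `discriminator`) and `transposed convolutions` (in the `generator`).
- Use `Batch Normalization` in both networks, except in the generator's output layer and the discriminator's input layer.
- Remove fully connected hidden layers in deeper architectures.
- Use the `ReLU` activation in the generator, except for the output layer, which uses `tanh`.
- Use the `Leaky ReLU` activation in the discriminator.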

In [None]:
codings_size = 100
generator = tf.keras.models.Sequential([
    tf.keras.layers.Dense(7 * 7 * 128, input_shape=[codings_size]),
    tf.keras.layers.Reshape([7, 7, 128]),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2DTranspose(64,
                                    kernel_size=5,
                                    strides=2,
                                    padding='same',
                                    activation='relu'),
    tf.keras.layers.Conv2DTranspose(1,
                                    kernel_size=5,
                                    strides=2,
                                    padding='same',
                                    activation='tanh')
])
discriminator = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64,
                           kernel_size=5,
                           strides=2,
                           padding='same',
                           activation=tf.keras.layers.LeakyReLU(0.2),
                           input_shape=[28, 28, 1]),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Conv2D(128,
                           kernel_size=5,
                           strides=2,
                           padding='same',
                           activation=tf.keras.layers.LeakyReLU(0.2)),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
gan = tf.keras.models.Sequential([generator, discriminator])

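# Add a channel axis and rescale to [-1, 1] to match the generator's tanh output
# (assumes x_train holds 28x28 images with values in [0, 1])
x_train = x_train.reshape(-1, 28, 28, 1) * 2. - 1.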
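
The spatial sizes in the DCGAN above can be checked with simple arithmetic. With `padding='same'`, a Keras `Conv2DTranspose` multiplies the spatial size by the stride and a `Conv2D` divides it by the stride, rounding up. The helper names below are illustrative.

```python
import math


def conv_same(size, stride):
    # Conv2D with padding='same': output = ceil(input / stride)
    return math.ceil(size / stride)


def conv_transpose_same(size, stride):
    # Conv2DTranspose with padding='same': output = input * stride
    return size * stride


# Generator: 7x7 feature map, then two stride-2 transposed convolutions.
generator_sizes = [7]
for stride in (2, 2):
    generator_sizes.append(conv_transpose_same(generator_sizes[-1], stride))

# Discriminator: 28x28 image, then two stride-2 convolutions.
discriminator_sizes = [28]
for stride in (2, 2):
    discriminator_sizes.append(conv_same(discriminator_sizes[-1], stride))

print(generator_sizes)      # [7, 14, 28]
print(discriminator_sizes)  # [28, 14, 7]
print(discriminator_sizes[-1] ** 2 * 128)  # 6272 features after Flatten
```

This is why the generator starts from a `7 * 7 * 128` dense layer: two stride-2 upsampling steps bring 7 up to the 28 pixels of the target images.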

## `Progressive Growing of GANs`
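- Start training by generating small images, then gradually add convolutional layers at the end of the generator and at the beginning of the discriminator to produce larger and larger images. The layers trained earlier remain trainable.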
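
The growth schedule described above can be sketched in a few lines. The starting size of 4, the final size of 1024, and the linear fade-in blend are assumptions taken from the commonly described progressive-growing setup and are not stated in these notes; both function names are illustrative.

```python
def resolution_schedule(start=4, final=1024):
    """Image sizes visited as layers are added, doubling at each stage."""
    sizes = [start]
    while sizes[-1] < final:
        sizes.append(sizes[-1] * 2)
    return sizes


def fade_in(alpha, upsampled_old, new):
    """Blend the upsampled old output with the new layer's output.

    alpha grows from 0 to 1 while a new layer is phased in (assumed scheme).
    """
    return [(1 - alpha) * o + alpha * n for o, n in zip(upsampled_old, new)]


print(resolution_schedule())  # [4, 8, 16, 32, 64, 128, 256, 512, 1024]
print(fade_in(0.25, [0.0, 1.0], [1.0, 1.0]))  # [0.25, 1.0]
```

Blending in a new layer gradually, instead of switching to it at once, is meant to avoid disturbing the layers that are already trained.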

## `StyleGANs`

## `CycleGAN`