# 基于 MNIST 实现生成对抗网络（GAN）

In [93]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

# 超参数准备

In [94]:
# Shape of one sample: a single-channel 28x28 image.
image_size = [1, 28, 28]
# Size of the random noise vector that the generator takes as input.
latent_dim = 64

# Training schedule.
batch_size = 32
num_epochs = 100

# True when a CUDA device is available; decides where training runs.
use_gpu = torch.cuda.is_available()

## 生成器

In [95]:
class Generator(nn.Module):
    """Fully connected generator: latent noise vector -> image.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU`` stages that
    widen from ``latent_dim`` up to 1024 units, followed by a final
    ``Linear -> Tanh`` that emits one value per pixel in [-1, 1].
    BatchNorm is placed after every Linear layer except the last one to
    speed up convergence.
    """

    def __init__(self):
        super().__init__()
        # Number of values in one flattened image (product of image_size).
        flat_pixels = np.prod(image_size, dtype=np.int32)
        hidden_widths = (64, 128, 256, 512, 1024)

        stages = []
        in_width = latent_dim
        for out_width in hidden_widths:
            stages += [
                nn.Linear(in_width, out_width),
                nn.BatchNorm1d(out_width),
                # The notebook's notes suggest GELU may work better here.
                nn.ReLU(inplace=True),
            ]
            in_width = out_width
        # Output stage: no BatchNorm, Tanh squashes pixels into [-1, 1].
        stages += [nn.Linear(in_width, flat_pixels), nn.Tanh()]

        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: noise tensor of shape [batch_size, latent_dim].

        Returns:
            Tensor of shape [batch_size, *image_size].
        """
        flat = self.model(z)  # [batch_size, prod(image_size)]
        # Unpack image_size so the flat pixels are folded back into an image.
        return flat.reshape(flat.shape[0], *image_size)

## 判别器

In [96]:
class Discriminator(nn.Module):
    """Fully connected discriminator: image -> probability.

    Takes an image, flattens it, and passes it through narrowing
    ``Linear -> ReLU`` stages ending in ``Linear -> Sigmoid``, so the output
    is a single value in (0, 1) per sample. No BatchNorm is used here.
    """

    def __init__(self):
        super().__init__()
        # Layer widths from the flattened image down to the last hidden layer.
        widths = [np.prod(image_size, dtype=np.int32), 1024, 512, 256, 128]

        stages = []
        for in_width, out_width in zip(widths, widths[1:]):
            stages.append(nn.Linear(in_width, out_width))
            stages.append(nn.ReLU(inplace=True))
        # Final stage collapses to one unit and maps it to a probability.
        stages.append(nn.Linear(widths[-1], 1))
        stages.append(nn.Sigmoid())

        self.model = nn.Sequential(*stages)

    def forward(self, image):
        """Score a batch of images.

        Args:
            image: tensor of shape [batch_size, 1, 28, 28].

        Returns:
            Tensor of shape [batch_size, 1] holding the sigmoid output.
        """
        # Flatten everything after the batch dimension before the MLP.
        flat = image.reshape(image.shape[0], -1)
        return self.model(flat)

## 数据准备

In [97]:
# Load the MNIST training split without any transform (downloads it into
# datasets/mnist when missing), then show how many samples it contains.
dataset = torchvision.datasets.MNIST(
    root="datasets/mnist",
    train=True,
    download=True,
)
len(dataset)

60000

In [98]:
# Inspect the first five samples to see what the untransformed dataset yields.
for sample_idx in range(min(5, len(dataset))):
    print(dataset[sample_idx])

(<PIL.Image.Image image mode=L size=28x28 at 0x238E2F56520>, 5)
(<PIL.Image.Image image mode=L size=28x28 at 0x238BE7B7AF0>, 0)
(<PIL.Image.Image image mode=L size=28x28 at 0x238BE7B7AF0>, 4)
(<PIL.Image.Image image mode=L size=28x28 at 0x238BE7B7AF0>, 1)
(<PIL.Image.Image image mode=L size=28x28 at 0x238BE7B7AF0>, 9)


In [99]:
# The raw samples are PIL images, so reload the dataset with a transform
# pipeline that converts them into normalised tensors.
dataset = torchvision.datasets.MNIST(
    root="datasets/mnist",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            # Resize to 28x28.
            torchvision.transforms.Resize(28),
            # Convert the PIL image to a tensor.
            torchvision.transforms.ToTensor(),
            # Shift and scale pixel values with mean 0.5 and std 0.5.
            torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
        ]
    ),
)

In [100]:
# Print the tensor shape of the first five transformed samples.
for sample_idx in range(min(5, len(dataset))):
    print(dataset[sample_idx][0].shape)

torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([1, 28, 28])


# 训练

In [101]:
# ---------------------------------------------------------------------------
# Training setup: data pipeline, both networks, their optimizers and the loss.
# ---------------------------------------------------------------------------
# drop_last=True keeps every batch at exactly `batch_size` samples, so the
# fixed-size label tensors and the noise batch below always match the real
# batch. With 60000 samples and batch_size=32 no sample is actually dropped.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
generator = Generator()
discriminator = Discriminator()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

loss_func = nn.BCELoss()
# Alias kept so the name `loss_fn` exists on CPU as well as GPU; previously it
# was only assigned inside the GPU branch, which made CPU runs fail.
loss_fn = loss_func
labels_one = torch.ones(batch_size, 1)    # target "real" for BCELoss
labels_zero = torch.zeros(batch_size, 1)  # target "fake" for BCELoss

device = torch.device("cuda" if use_gpu else "cpu")
if use_gpu:
    print("use gpu for training")
generator = generator.to(device)
discriminator = discriminator.to(device)
labels_one = labels_one.to(device)
labels_zero = labels_zero.to(device)

# save_image does not create missing directories, so make sure it exists.
output_dir = "gan_mnist_output"
os.makedirs(output_dir, exist_ok=True)

# ---------------------------------------------------------------------------
# Training loop.
#
# D is trained to maximise:
#   1. the probability it assigns "real" to real images, and
#   2. the probability it assigns "fake" to generated images.
# G is trained so that D assigns "real" to generated images.
# ---------------------------------------------------------------------------
for epoch in range(num_epochs):
    # i is the batch index; mini_batch is an (images, labels) pair.
    for i, mini_batch in enumerate(dataloader):
        true_images, _ = mini_batch
        true_images = true_images.to(device)
        # Gaussian noise, one latent vector per image in the batch.
        z = torch.randn(batch_size, latent_dim).to(device)
        pred_images = generator(z)

        # --------------------------------------------------------------------
        # Generator step
        g_optimizer.zero_grad()

        # Small L1 reconstruction term between generated and real pixels.
        recons_loss = torch.abs(pred_images - true_images).mean()

        # discriminator(pred_images) is the probability that the generated
        # images are real; G wants it close to 1, hence the "real" target.
        g_loss = recons_loss * 0.05 + loss_func(discriminator(pred_images), labels_one)

        g_loss.backward()
        g_optimizer.step()

        # --------------------------------------------------------------------
        # Discriminator step
        d_optimizer.zero_grad()

        # real_loss: real images against the "real" target (ones).
        # fake_loss: generated images against the "fake" target (zeros).
        # pred_images is detached because only D is being updated here, so
        # the generator must be cut out of the computation graph.
        real_loss = loss_func(discriminator(true_images), labels_one)
        fake_loss = loss_func(discriminator(pred_images.detach()), labels_zero)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        step = len(dataloader) * epoch + i

        if i % 50 == 0:
            print(f"step:{step}, recons_loss:{recons_loss.item()}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}")

        if i % 400 == 0:
            # Save a 4x4 grid of generated samples. The generator ends in
            # Tanh, so pixels lie in [-1, 1]; normalize=True min-max rescales
            # them into [0, 1] instead of clamping negatives to black.
            samples = pred_images[:16].detach()
            torchvision.utils.save_image(samples, f"{output_dir}/image_{step}.png", nrow=4, normalize=True)




use gpu for training
step:0, recons_loss:0.9491025805473328, g_loss:0.7523645162582397, d_loss:1.3791230916976929, real_loss:0.6976013779640198, fake_loss:0.6815217137336731
step:50, recons_loss:0.5167753100395203, g_loss:0.777081310749054, d_loss:1.0720994472503662, real_loss:0.4336715638637543, fake_loss:0.6384278535842896
step:100, recons_loss:0.47661033272743225, g_loss:1.0258558988571167, d_loss:1.1103706359863281, real_loss:0.6527243852615356, fake_loss:0.4576461911201477
step:150, recons_loss:0.6542694568634033, g_loss:2.138892889022827, d_loss:0.38093554973602295, real_loss:0.250055193901062, fake_loss:0.13088034093379974
step:200, recons_loss:0.5037420988082886, g_loss:0.2634848654270172, d_loss:1.608109474182129, real_loss:0.051162030547857285, fake_loss:1.5569474697113037
step:250, recons_loss:0.6073771119117737, g_loss:2.094127655029297, d_loss:0.2906925678253174, real_loss:0.1546083241701126, fake_loss:0.13608425855636597
step:300, recons_loss:0.6044324040412903, g_loss:3.

KeyboardInterrupt: 

# 注意点

1. 引入batchnorm可以提高收敛速度，具体做法是在生成器的Linear层后面添加BatchNorm1d，最后一层除外，判别器不要加
2. 生成器可以直接预测 [0, 1] 之间的像素值，此时不需要做归一化的 transform；也可以预测 [-1, 1] 之间的像素值，此时用 mean=0.5、std=0.5 对真实图片做归一化 transform（本文采用后者，生成器最后一层为 Tanh）。两种做法都可以
3. 将激活函数ReLU换成GELU效果更好
4. real_loss基于真实图片，fake_loss基于生成图片：real_loss = loss_fn(discriminator(true_images), torch.ones(batch_size, 1))，fake_loss = loss_fn(discriminator(pred_images.detach()), torch.zeros(batch_size, 1))
5. 适当引入重构loss，计算像素值的L1误差
6. 建议引入loss打印语句，如：print(f"step:{len(dataloader)*epoch+i}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}")
7. 判别器模型容量不宜过大
8. save_image中的normalize设置成True，目的是将像素值min-max自动归一到【0,1】范围内，如果已经预测了【0,1】之间，则可以不用设置True
9. 判别器的学习率不能太小
10. Adam的一阶平滑系数和二阶平滑系数 betas 适当调小一点，可以帮助学习，设置一定比例的weight decay