In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File              : test_gan.py
# Author            : none <none>
# Date              : 14.04.2022
# Last Modified Date: 15.04.2022
# Last Modified By  : none <none>
"""Implement a generative adversarial network (GAN) trained on MNIST."""

import torch
import torchvision
import torch.nn as nn
import numpy as np

# Whether PyTorch reports an available CUDA device.
use_gpu = torch.cuda.is_available()

# Number of samples per mini-batch.
batch_size = 64

# Length of the latent vector fed to the generator.
latent_dim = 96

# Per-sample image shape as [channels, height, width].
image_size = [1, 28, 28]

class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network is four ``Linear -> BatchNorm1d -> GELU`` stages of width
    128, 256, 512 and 1024, followed by a ``Linear`` projection to the
    flattened image and a ``Sigmoid``, so every output value lies in (0, 1).

    Args:
        noise_dim: Length of each latent vector. When ``None`` the
            module-level ``latent_dim`` is used.
        output_shape: Per-sample image shape, e.g. ``[1, 28, 28]``. When
            ``None`` the module-level ``image_size`` is used.
    """

    def __init__(self, noise_dim=None, output_shape=None):
        super().__init__()

        # Fall back to the module-level configuration so that a plain
        # ``Generator()`` call keeps its original meaning.
        self.latent_dim = latent_dim if noise_dim is None else int(noise_dim)
        shape = image_size if output_shape is None else output_shape
        self.image_size = tuple(int(dim) for dim in shape)

        # np.prod returns a NumPy scalar; cast it so nn.Linear stores a
        # built-in int as its feature count.
        num_pixels = int(np.prod(self.image_size))

        self.model = nn.Sequential(
            nn.Linear(self.latent_dim, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.GELU(),
            nn.Linear(1024, num_pixels),
            # Sigmoid bounds every generated pixel to (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Generate one image per latent vector.

        Args:
            z: Latent batch of shape ``[batch_size, latent_dim]``.

        Returns:
            Tensor of shape ``[batch_size, *image_size]``.
        """
        flat = self.model(z)
        return flat.reshape(z.shape[0], *self.image_size)


class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    Each image is flattened and passed through ``Linear -> GELU`` stages of
    width 512, 256, 128, 64 and 32, then a final ``Linear`` to one unit and
    a ``Sigmoid``, so the output is a probability-like value in (0, 1).

    Args:
        input_shape: Per-sample image shape, e.g. ``[1, 28, 28]``. When
            ``None`` the module-level ``image_size`` is used.
    """

    def __init__(self, input_shape=None):
        super().__init__()

        # Fall back to the module-level configuration so that a plain
        # ``Discriminator()`` call keeps its original meaning.
        shape = image_size if input_shape is None else input_shape
        self.image_size = tuple(int(dim) for dim in shape)

        # np.prod returns a NumPy scalar; cast it so nn.Linear stores a
        # built-in int as its feature count.
        num_pixels = int(np.prod(self.image_size))

        self.model = nn.Sequential(
            nn.Linear(num_pixels, 512),
            nn.GELU(),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )

    def forward(self, image):
        """Score a batch of images.

        Args:
            image: Batch whose first dimension is the batch size; all
                remaining dimensions are flattened before the first layer.

        Returns:
            Tensor of shape ``[batch_size, 1]`` with values in (0, 1).
        """
        flat = image.reshape(image.shape[0], -1)
        return self.model(flat)

# Training
# MNIST is read from ./mnist_data; download=False means the files must
# already be present there.
dataset = torchvision.datasets.MNIST(
    "./mnist_data",
    train=True,
    download=False,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            # No Normalize step is applied after ToTensor (the original
            # Normalize([0.5], [0.5]) line was disabled).
            torchvision.transforms.ToTensor(),
        ]
    ),
)
# drop_last=True keeps every batch at exactly batch_size samples, which the
# fixed-size label tensors below rely on.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# A single device object replaces the scattered `if use_gpu` branches.
device = torch.device("cuda" if use_gpu else "cpu")
if use_gpu:
    print("use gpu for training")

# Move the models to their final device BEFORE building the optimizers, as
# the torch.optim documentation recommends.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

loss_fn = nn.BCELoss()
# Targets for "real" (1) and "fake" (0), one row per sample in a batch.
labels_one = torch.ones(batch_size, 1, device=device)
labels_zero = torch.zeros(batch_size, 1, device=device)

num_epoch = 10
for epoch in range(num_epoch):
    for i, mini_batch in enumerate(dataloader):
        # Only the images are needed; the class labels are discarded.
        # gt_images shape: (batch, channels, h, w)
        gt_images, _ = mini_batch
        gt_images = gt_images.to(device)

        # Each element of torch.randn() is drawn from N(0, 1); that does not
        # mean the drawn elements themselves form an exact N(0, 1) sample.
        # Sampling happens on the CPU and is then moved, which keeps the
        # random stream the same as before this refactor.
        z = torch.randn(batch_size, latent_dim).to(device)

        pred_images = generator(z)

        # ---- Update the generator ----
        g_optimizer.zero_grad()
        # Mean absolute difference to the current real batch. torch.abs is
        # differentiable everywhere except at 0, where autograd uses a zero
        # gradient, so it can be used in a loss.
        recons_loss = torch.abs(pred_images - gt_images).mean()
        # Adversarial term: the generator wants its images labelled as real.
        g_loss = recons_loss * 0.05 + loss_fn(discriminator(pred_images), labels_one)

        # This backward pass also accumulates gradients in the discriminator;
        # they are cleared by d_optimizer.zero_grad() below.
        g_loss.backward()
        g_optimizer.step()

        # ---- Update the discriminator ----
        d_optimizer.zero_grad()

        real_loss = loss_fn(discriminator(gt_images), labels_one)
        # detach() stops discriminator gradients from reaching the generator.
        fake_loss = loss_fn(discriminator(pred_images.detach()), labels_zero)
        d_loss = real_loss + fake_loss

        # Author's note: watch real_loss and fake_loss; when both fall
        # together, bottom out together and are about equal, the
        # discriminator has stabilised.

        d_loss.backward()
        d_optimizer.step()

        step = len(dataloader) * epoch + i

        if i % 50 == 0:
            print(f"step:{step}, recons_loss:{recons_loss.item()}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()}, real_loss:{real_loss.item()}, fake_loss:{fake_loss.item()}")

        if i % 400 == 0:
            # detach() is the supported replacement for the legacy `.data`.
            image = pred_images[:16].detach()
            torchvision.utils.save_image(image, f"image_{step}.png", nrow=4)



use gpu for training
step:0, recons_loss:0.47394564747810364, g_loss:0.6707997918128967, d_loss:1.3890595436096191, real_loss:0.6476448774337769, fake_loss:0.7414146661758423
step:50, recons_loss:0.14291170239448547, g_loss:1.244321584701538, d_loss:0.9287236928939819, real_loss:0.5431104898452759, fake_loss:0.38561323285102844
step:100, recons_loss:0.16045722365379333, g_loss:0.5868602991104126, d_loss:1.497227668762207, real_loss:0.6713259220123291, fake_loss:0.8259017467498779
step:150, recons_loss:0.1554144024848938, g_loss:0.8444570899009705, d_loss:1.3537132740020752, real_loss:0.7647496461868286, fake_loss:0.5889636278152466
step:200, recons_loss:0.15453919768333435, g_loss:0.748405933380127, d_loss:1.3231477737426758, real_loss:0.6702013611793518, fake_loss:0.6529464721679688
step:250, recons_loss:0.147323340177536, g_loss:0.6857427358627319, d_loss:1.374295949935913, real_loss:0.6647493839263916, fake_loss:0.7095465064048767
step:300, recons_loss:0.1505972146987915, g_loss:0.7