<a href="https://colab.research.google.com/github/Richardhzj/ML/blob/main/GAN/GAN_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import cv2
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from keras.datasets import mnist
from torchvision.utils import save_image

In [None]:
class Generator(nn.Module):
    """Map latent noise vectors to single-channel 28x28 images.

    Pipeline:
    1. A linear projection to one 56x56 map.
    2. BatchNorm + ReLU.
    3. Two padded 3x3 convolutions (1 -> 50 -> 25 channels; size stays 56x56).
    4. A 2x2 stride-2 convolution down to 1x28x28, squashed by tanh into [-1, 1].

    Args:
        input_dim: length of each latent noise vector.

    Submodule attribute names (fc1, br, conv1, conv2, conv3) must stay stable:
    the training script pickles whole modules with ``torch.save``.
    """

    _SIDE = 56  # side length of the intermediate feature map

    def __init__(self, input_dim):
        super().__init__()
        # Project e.g. a (128, 100) noise batch to 128 flat 56*56 vectors.
        self.fc1 = nn.Linear(input_dim, self._SIDE * self._SIDE)
        # inplace=True lets ReLU overwrite its input instead of allocating a new tensor.
        self.br = nn.Sequential(nn.BatchNorm2d(1), nn.ReLU(True))
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),
            nn.BatchNorm2d(50),
            nn.ReLU(True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),
            nn.BatchNorm2d(25),
            nn.ReLU(True),
        )
        # Kernel 2 with stride 2 halves the spatial size: 56x56 -> 28x28.
        self.conv3 = nn.Sequential(nn.Conv2d(25, 1, 2, stride=2), nn.Tanh())

    def forward(self, x):
        """Return a (batch, 1, 28, 28) tensor for a (batch, input_dim) input."""
        h = self.fc1(x).view(-1, 1, self._SIDE, self._SIDE)
        for stage in (self.br, self.conv1, self.conv2, self.conv3):
            h = stage(h)
        return h

In [None]:
class Discriminator(nn.Module):
    """Score 1x28x28 images with the probability that each one is real.

    The network runs two 5x5 conv + LeakyReLU(0.2) stages, each followed by
    2x2 average pooling (28 -> 14 -> 7). A 1024-unit hidden layer and a
    sigmoid output follow. The hard-coded 64 * 7 * 7 flatten size means
    inputs must be 28x28.

    Submodule attribute names are kept stable because whole modules are
    pickled as checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 32, 5, stride=1, padding=2), nn.LeakyReLU(0.2, True))
        self.pl1 = nn.AvgPool2d(2, stride=2)
        self.conv2 = nn.Sequential(nn.Conv2d(32, 64, 5, stride=1, padding=2), nn.LeakyReLU(0.2, True))
        self.pl2 = nn.AvgPool2d(2, stride=2)
        self.fc1 = nn.Sequential(nn.Linear(64 * 7 * 7, 1024), nn.LeakyReLU(0.2, True))
        # Sigmoid turns the single logit into a probability in [0, 1].
        self.fc2 = nn.Sequential(nn.Linear(1024, 1), nn.Sigmoid())

    def forward(self, x):
        """Return a (batch, 1) tensor of real-image probabilities."""
        features = self.pl2(self.conv2(self.pl1(self.conv1(x))))
        flat = features.view(len(features), -1)
        return self.fc2(self.fc1(flat))


In [None]:
def G_train(input_dim):
    """Run one generator update step and return its scalar loss.

    The step samples ``batch_size`` latent vectors, generates fakes and
    scores them with D. The loss is BCE against a "real" (1) target, so G is
    rewarded for fooling D.

    Relies on module-level globals created by the training script:
    ``G``, ``D``, ``G_optimizer``, ``criterion``, ``batch_size``, ``device``.

    Args:
        input_dim: length of each latent noise vector fed to G.

    Returns:
        The generator loss as a Python float.
    """
    G_optimizer.zero_grad()

    noise = torch.randn(batch_size, input_dim).to(device)
    # The generator wants its fakes classified as real, so the target is 1.
    real_label = torch.ones(batch_size, 1, device=device)
    G_loss = criterion(D(G(noise)), real_label)

    # backward() also fills D's .grad buffers. Only G_optimizer steps here,
    # and D_train zeroes D's grads before its own update.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [None]:
def D_train(real_img, input_dim):
    """Run one discriminator update step and return its scalar loss.

    The loss is BCE(D(real), 1) + BCE(D(G(noise)), 0). The fake half uses
    the same batch size as ``real_img``, so a short final batch still gets
    balanced real/fake terms.

    Relies on module-level globals created by the training script:
    ``G``, ``D``, ``D_optimizer``, ``criterion``, ``device``.

    Args:
        real_img: batch of real images already on ``device``. For MNIST this
            is presumably (N, 1, 28, 28), e.g. torch.Size([100, 1, 28, 28]).
        input_dim: length of each latent noise vector fed to G.

    Returns:
        The combined discriminator loss as a Python float.
    """
    D_optimizer.zero_grad()
    n = real_img.shape[0]

    real_label = torch.ones(n, 1, device=device)
    D_real_loss = criterion(D(real_img), real_label)

    noise = torch.randn(n, input_dim).to(device)
    fake_label = torch.zeros(n, 1, device=device)
    # G must not receive gradients from D's loss. Running G under no_grad
    # cuts the graph, which is equivalent to fake_img.detach() but never
    # builds the graph in the first place.
    with torch.no_grad():
        fake_img = G(noise)
    D_fake_loss = criterion(D(fake_img), fake_label)

    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [None]:
def save_img(img, img_name, out_dir="./imgs"):
    """Save a batch of generator outputs as one image grid.

    Args:
        img: generator output in [-1, 1] (the generator ends in tanh).
        img_name: file name, e.g. "img_1_0.png".
        out_dir: directory to write into. Defaults to "./imgs", which the
            training script creates.
    """
    # Map [-1, 1] back to [0, 1] for save_image. Detach so this works even
    # if the caller produced ``img`` with autograd enabled.
    img = (0.5 * (img.detach() + 1)).clamp(0, 1)
    save_image(img, os.path.join(out_dir, img_name))


In [None]:
# Training hyperparameters, read as module-level globals by the training helpers.
batch_size, epoch_num = 100, 30  # images per batch; passes over the dataset
lr = 2e-4  # Adam learning rate shared by G and D
input_dim = 100  # length of the generator's latent noise vector

In [None]:
# for batch, (x, _) in enumerate(train_loader):
#   print(x.shape)
#   print(batch)
#   break
# print(len(train_loader))

In [None]:
# fake_img = torch.randn(128, input_dim)
# fake_img = G(fake_img.to(device))
# print(fake_img.shape)

# a = fake_img = 0.5 * (fake_img + 1)
# print(a.shape)

In [None]:
if __name__ == "__main__":
    # This block runs at module scope on purpose: G_train / D_train read
    # G, D, criterion, the optimizers and device as module-level globals.

    os.makedirs("./checkpoint", exist_ok=True)
    os.makedirs("./imgs", exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the data. ToTensor scales pixels to [0, 1]. Normalize(0.5, 0.5)
    # maps them to [-1, 1] to match the generator's tanh output range;
    # otherwise D could tell real from fake by sign alone.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ])
    train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform,
                                   download=True)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    # Build the generator and discriminator, resuming from checkpoints if both exist.
    g_ckpt = "./checkpoint/Generator.pkl"
    d_ckpt = "./checkpoint/Discriminator.pkl"
    if os.path.exists(g_ckpt) and os.path.exists(d_ckpt):
        # Checkpoints are whole pickled modules written by this script.
        # torch>=2.6 defaults to weights_only=True, which refuses them.
        # Only load checkpoint files you trust (pickle can execute code).
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        G = torch.load(g_ckpt, map_location=device, weights_only=False).to(device)
        D = torch.load(d_ckpt, map_location=device, weights_only=False).to(device)
    else:
        G = Generator(input_dim).to(device)
        D = Discriminator().to(device)

    # Loss function and optimizers.
    criterion = nn.BCELoss()
    G_optimizer = optim.Adam(G.parameters(), lr=lr)
    D_optimizer = optim.Adam(D.parameters(), lr=lr)

    epoch_D_loss = []
    epoch_G_loss = []
    count = len(train_loader)

    print("Training...........")
    for epoch in range(1, epoch_num + 1):
        d_epoch_loss = 0.0
        g_epoch_loss = 0.0
        print("epoch: ", epoch)
        for batch, (x, _) in enumerate(train_loader):
            # Train the discriminator first, then the generator; the order matters.
            D_loss = D_train(x.to(device), input_dim)
            G_loss = G_train(input_dim)

            print("[ %d / %d ]  g_loss: %.6f  d_loss: %.6f" % (batch, count, float(G_loss), float(D_loss)))

            # The helpers return plain floats, so no autograd context is needed.
            d_epoch_loss += D_loss
            g_epoch_loss += G_loss

            if batch % 50 == 0:
                # Sample a preview grid without recording an autograd graph.
                with torch.no_grad():
                    fake_img = G(torch.randn(128, input_dim).to(device))
                save_img(fake_img, "img_" + str(epoch) + "_" + str(batch) + ".png")
                # Save model checkpoints.
                torch.save(G, g_ckpt)
                torch.save(D, d_ckpt)

        # Inside the epoch loop so every epoch records its mean batch loss.
        epoch_D_loss.append(d_epoch_loss / count)
        epoch_G_loss.append(g_epoch_loss / count)

Training...........
epoch:  1
[ 0 / 600 ]  g_loss: 5.261096  d_loss: 0.432289
[ 1 / 600 ]  g_loss: 1.316787  d_loss: 0.758210
[ 2 / 600 ]  g_loss: 1.608086  d_loss: 0.629391
[ 3 / 600 ]  g_loss: 3.068643  d_loss: 0.623034
[ 4 / 600 ]  g_loss: 3.878047  d_loss: 0.430954
[ 5 / 600 ]  g_loss: 4.065329  d_loss: 0.455633
[ 6 / 600 ]  g_loss: 2.880516  d_loss: 0.461076
[ 7 / 600 ]  g_loss: 1.947001  d_loss: 0.338868
[ 8 / 600 ]  g_loss: 1.991178  d_loss: 0.480254
[ 9 / 600 ]  g_loss: 2.413177  d_loss: 0.478868
[ 10 / 600 ]  g_loss: 2.997023  d_loss: 0.512879
[ 11 / 600 ]  g_loss: 3.250618  d_loss: 0.525221
[ 12 / 600 ]  g_loss: 2.680915  d_loss: 0.428168
[ 13 / 600 ]  g_loss: 2.276666  d_loss: 0.391689
[ 14 / 600 ]  g_loss: 2.456187  d_loss: 0.480449
[ 15 / 600 ]  g_loss: 2.726231  d_loss: 0.316564
[ 16 / 600 ]  g_loss: 2.594761  d_loss: 0.419789
[ 17 / 600 ]  g_loss: 2.908499  d_loss: 0.334996
[ 18 / 600 ]  g_loss: 2.908427  d_loss: 0.388803
[ 19 / 600 ]  g_loss: 2.418208  d_loss: 0.410772
