In [None]:
# Based on https://github.com/wmn7/ML_Practice/blob/master/2019_07_15/GAN%20MNIST.ipynb
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary

In [None]:
# Device configuration: run on the GPU when CUDA is usable, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## 模型超参数

In [None]:
# Hyper-parameters
# Network sizes: noise-vector length fed to the generator, and the width of
# the hidden layers used by both networks.
latent_size, hidden_size = 64, 256
# Number of values in one flattened 28x28 image.
image_size = 28 * 28
# Training schedule.
num_epochs, batch_size = 500, 100
# Folder that receives the image grids written during training.
sample_dir = 'samples'

In [None]:
# Create the sample output directory if it does not exist yet.
# exist_ok=True makes the call idempotent (safe to re-run the cell) and avoids
# the race between a separate "exists?" check and the actual creation.
os.makedirs(sample_dir, exist_ok=True)

In [None]:
# Image processing: convert each image to a tensor, then normalise the single
# channel with mean 0.5 and std 0.5.
# (For RGB data the statistics would need three entries instead:
#  mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5).)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST training split, downloaded into the current directory when missing.
mnist = torchvision.datasets.MNIST('./', train=True, transform=transform, download=True)

# Shuffled mini-batch loader; pinned memory is requested for the batches.
data_loader = torch.utils.data.DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Inspect image samples: create an iterator over the data loader so that
# individual batches can be pulled with next().
it = iter(data_loader)

In [None]:
# Pull one batch from the iterator (the second element is discarded) and
# display the shape of the first element as the cell output.
x,_ = next(it)
x.shape

## 定义网络

- 问题: generator最后是否要过激活层(tanh)

In [None]:
class Discriminator(nn.Module):
    """Fully connected classifier that scores each input sample.

    The input is flattened per sample, passed through two hidden linear
    layers with LeakyReLU(0.2), and finished with a linear layer followed by
    a sigmoid, so every output value lies in (0, 1).

    Args:
        input_size: number of features per flattened sample.
        hidden_size: width of both hidden layers.
        output_size: number of output values per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in forward order; the attribute names are the
        # keys used in the saved state_dict, so they must not change.
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid score(s) for a batch ``x`` (first dim = batch)."""
        # Collapse everything after the batch dimension into one feature axis.
        flat = x.reshape(x.size(0), -1)
        hidden = self.leakyrelu(self.map1(flat))
        hidden = self.leakyrelu(self.map2(hidden))
        # The final sigmoid turns the score into a probability.
        return self.sigmoid(self.map3(hidden))

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to an output vector.

    Two hidden linear layers with ReLU are followed by a linear output layer
    and a tanh, so every output value lies in (-1, 1).

    Args:
        input_size: length of the input (noise) vector.
        hidden_size: width of both hidden layers.
        output_size: number of values produced per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept as-is.
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()  # output activation

    def forward(self, x):
        """Return the generated batch for the input batch ``x``."""
        hidden = self.relu(self.map1(x))
        hidden = self.relu(self.map2(hidden))
        return self.tanh(self.map3(hidden))

In [None]:
# -----------------------
# Initialise the networks
# -----------------------
# Discriminator: takes a flattened image and produces a single score.
D = Discriminator(image_size, hidden_size, 1)
D = D.to(device)

# Generator: takes a latent vector and produces a flattened image.
G = Generator(image_size if False else latent_size, hidden_size, image_size)
G = G.to(device)

# Print a layer-by-layer summary of D for a 1x28x28 input.
summary(D, (1, 28, 28))

In [None]:
# Helper functions
def denorm(x):
    """Shift and scale ``x`` with ``(x + 1) / 2`` and clamp the result to [0, 1]."""
    rescaled = (x + 1) / 2
    return rescaled.clamp(min=0, max=1)

def reset_grad():
    """Clear accumulated gradients on the discriminator and generator optimizers."""
    for optimizer in (d_optimizer, g_optimizer):
        optimizer.zero_grad()

In [None]:
# Loss function and optimizers
learning_rate = 0.0003
# Binary cross-entropy between the discriminator's output and the 0/1 labels.
criterion = nn.BCELoss()

# One Adam optimizer per network, both starting from the same learning rate.
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=learning_rate) for net in (D, G)
)

# Step schedule: multiply the learning rate by 0.9 every 50 scheduler steps.
d_exp_lr_scheduler, g_exp_lr_scheduler = (
    torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.9)
    for opt in (d_optimizer, g_optimizer)
)

## 开始训练

- 查看fake_score, real_score的值

In [None]:
total_step = len(data_loader)
# The first half of training uses the step schedule, the second half the
# cosine schedule (250 + 250 epochs with num_epochs = 500).
warm_epochs = num_epochs // 2

# Create the checkpoint folder up front: torch.save does not create missing
# directories, and failing here is far cheaper than failing after training.
model_dir = './models'
os.makedirs(model_dir, exist_ok=True)


def train_one_epoch(epoch):
    """Run one pass over ``data_loader``, updating D and then G on every batch.

    For each batch the discriminator is trained on real images (target 1) and
    on generated images (target 0); the generator is then trained so that the
    discriminator scores a fresh set of generated images as real (target 1).

    Side effects:
        * On the very first batch of epoch 0, writes the real batch to
          ``real_images.png`` in ``sample_dir``.
        * After the last batch of the epoch, prints the losses, mean scores
          and current learning rates, and writes the last generated batch to
          ``fake_images-<epoch+1>.png`` in ``sample_dir``.

    Args:
        epoch: zero-based index of the epoch being run.
    """
    for i, (images, _) in enumerate(data_loader):
        # Size everything from the batch actually received so that a short
        # final batch (dataset size not divisible by batch_size) still works.
        n = images.size(0)
        images = images.reshape(n, -1).to(device)
        # Real / fake target labels.
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # -----------------------
        # Train the discriminator
        # -----------------------

        # Score the real images; real_score should move towards 1.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # Score generated images; fake_score should move towards 0 for D.
        # detach() stops back-propagation into G here: only D is updated in
        # this step, so gradients for G would be computed and then discarded.
        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z).detach()
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # Sum both losses, back-propagate and update D.
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # -------------------
        # Train the generator
        # -------------------

        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)

        # G is rewarded when D classifies its images as real.
        g_loss = criterion(outputs, real_labels)
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        # Save one batch of real images, once, for visual reference.
        if epoch == 0 and i == 0:
            save_image(denorm(images.reshape(n, 1, 28, 28)),
                       os.path.join(sample_dir, 'real_images.png'))

        # Report and sample once per epoch, after the last batch. Keying on
        # total_step instead of a literal 600 keeps this working when the
        # batch size (and therefore the number of steps) changes.
        if (i + 1) == total_step:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}, d_lr={:.6f},g_lr={:.6f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item(),
                          d_optimizer.param_groups[0]['lr'], g_optimizer.param_groups[0]['lr']))
            save_image(denorm(fake_images.reshape(n, 1, 28, 28)),
                       os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))


# -----------------------------------------
# Phase 1: larger learning rate (step decay)
# -----------------------------------------
for epoch in range(warm_epochs):
    train_one_epoch(epoch)
    # Schedulers are stepped AFTER the epoch's optimizer steps; stepping them
    # first skips the initial value of the schedule.
    d_exp_lr_scheduler.step()
    g_exp_lr_scheduler.step()

# ----------------------------------------------
# Phase 2: smaller learning rate (cosine schedule)
# ----------------------------------------------
# The param groups still carry the 'initial_lr' recorded by the phase-1
# schedulers; overwrite it so the cosine schedule starts from the learning
# rate reached at the end of phase 1.
for optimizer in (d_optimizer, g_optimizer):
    for group in optimizer.param_groups:
        group['initial_lr'] = group['lr']

d_cos_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(d_optimizer, T_max=10, eta_min=0.00001)
g_cos_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(g_optimizer, T_max=10, eta_min=0.00001)
for epoch in range(warm_epochs, num_epochs):
    train_one_epoch(epoch)
    d_cos_lr_scheduler.step()
    g_cos_lr_scheduler.step()

# Save the model checkpoints
torch.save(G.state_dict(), os.path.join(model_dir, 'G.ckpt'))
torch.save(D.state_dict(), os.path.join(model_dir, 'D.ckpt'))