# Generative adversarial network

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transform
import matplotlib.pyplot as plt

%matplotlib inline

## 模型定义

### 判别器
利用线性层对输入进行线性变换，隐藏层的激活函数使用 `LeakyReLU`（negative_slope = 0.2），输出层使用 `Sigmoid`  
判别器输入：形状为 \[batch_size, in_features\] 的向量  
判别器输出：形状为 \[batch_size, 1\] 的得分，取值在 (0, 1) 之间，表示输入为真实图片的概率

$\mathrm{LeakyReLU}(x) = \begin{cases} x, & x \ge 0 \\ \text{negative\_slope} \times x, & \text{otherwise} \end{cases}$  
shape:  \[batch_size, in_features\]  $\Rightarrow$  \[batch_size, hidden_size\]  $\Rightarrow$  \[batch_size, hidden_size\]  $\Rightarrow$  \[batch_size, 1\]

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator that scores flattened images.

    Maps ``[batch, in_features]`` -> ``[batch, 1]``. The final Sigmoid squashes
    each score into (0, 1) so it can be fed directly to ``nn.BCELoss``.

    Args:
        in_features: Size of each flattened input vector.
        hidden_size: Width of the two hidden layers.
    """

    def __init__(self, in_features, hidden_size):
        super().__init__()
        widths = [in_features, hidden_size, hidden_size]
        layers = []
        # Two Linear -> LeakyReLU(0.2) hidden stages.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)])
        # Single-logit head turned into a probability.
        layers.extend([nn.Linear(hidden_size, 1), nn.Sigmoid()])
        # The attribute name ``disc`` appears in the state_dict keys
        # (``disc.0.weight`` ...), so keep it stable for saved checkpoints.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``[batch, 1]`` tensor of real-vs-fake probabilities."""
        return self.disc(x)

### 生成器
利用线性层对输入进行线性变换，前两个线性层之后使用 `ReLU` 激活函数，最后一层使用 `Tanh`，将输出压缩到 \[-1, 1\]，与归一化后的图片像素范围一致    
生成器输入：形状为 \[batch_size, z_dim\] 的随机向量  
生成器输出：形状为 \[batch_size, img_dim\] 的向量

$Tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$  
shape: \[batch_size, z_dim\]  $\Rightarrow$  \[batch_size, hidden_size\]  $\Rightarrow$  \[batch_size, hidden_size\]  $\Rightarrow$  \[batch_size, img_dim\]

In [None]:
class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to flattened images.

    Maps ``[batch, z_dim]`` -> ``[batch, img_dim]``. The final Tanh keeps every
    output in [-1, 1], the same range as the normalized training images.

    Args:
        z_dim: Size of the input noise vector.
        hidden_size: Width of the two hidden layers.
        img_dim: Size of each flattened output image.
    """

    def __init__(self, z_dim, hidden_size, img_dim):
        super().__init__()
        widths = [z_dim, hidden_size, hidden_size]
        layers = []
        # Two Linear -> ReLU hidden stages.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        # Output projection squashed into [-1, 1].
        layers.extend([nn.Linear(hidden_size, img_dim), nn.Tanh()])
        # The attribute name ``gen`` appears in the state_dict keys
        # (``gen.0.weight`` ...), so keep it stable for saved checkpoints.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images as a ``[batch, img_dim]`` tensor in [-1, 1]."""
        return self.gen(x)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the global torch RNG so weight init, noise sampling and DataLoader
# shuffling are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

img_dim = 28*28*1        # flattened MNIST image: 1 channel x 28 x 28
in_features = img_dim    # the discriminator consumes flattened images, so these must agree
hidden_size = 256        # width of the hidden layers in both networks
lr = 2e-4                # Adam learning rate, shared by both optimizers
z_dim = 64               # size of the latent noise vector fed to the generator
batch_size = 256
num_epochs = 300

In [None]:
# Build both networks and move them to the selected device. Both draw their
# initial weights from the global torch RNG, so keep this construction order
# (discriminator first) if reproducibility matters.
disc = Discriminator(in_features=in_features, hidden_size=hidden_size).to(device)
gen = Generator(z_dim=z_dim, hidden_size=hidden_size, img_dim=img_dim).to(device)

## 数据集准备
本次实验使用 `MNIST` 数据集  
预处理流程先用 `ToTensor` 将图片转换为形状为 \[C, H, W\] = (1, 28, 28)、像素值在 \[0, 1\] 之间的 tensor，再用 `Normalize(mean=0.5, std=0.5)` 将其归一化到 \[-1, 1\]

In [None]:
# Fixed latent vectors reused after every epoch, so the saved sample grids show
# the same 100 inputs evolving as training progresses.
fixed_noise = torch.randn((100, z_dim)).to(device)

# Do NOT assign the result to `transform`: that name is the
# `torchvision.transforms` module alias imported at the top, and shadowing it
# makes this cell raise AttributeError when it is re-run.
mnist_transform = transform.Compose(
    [
        transform.ToTensor(),                          # PIL image -> float tensor [1, 28, 28] in [0, 1]
        transform.Normalize(mean=(0.5,), std=(0.5,)),  # [0, 1] -> [-1, 1], matching the generator's Tanh range
    ]
)

# download=True only fetches the files when they are not already present under
# `root`, so the notebook also works on a fresh machine.
dataset = datasets.MNIST(root='dataset/mnist/', transform=mnist_transform, download=True)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

## 优化器（Optimizer）和损失函数（Criterion）
**BCELoss**  
$-\frac{1}{N}\sum^N_{i=1}\left[y_i\log(x_i)+(1-y_i)\log(1-x_i)\right]$  
判别器本质上是一个二分类器，`BCELoss` 衡量它对真实图片（标签 1）与生成图片（标签 0）的分类误差；训练生成器时，把生成图片的标签设为 1 再计算 `BCELoss`，即促使判别器把假图片误判为真。两者交替优化的整体目标，是让生成器的分布逼近未知的真实数据分布。

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# One Adam optimizer per network, so each step only updates that player's weights.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

由于归一化后的图片张量的像素值在 \[-1, 1\] 之间，为了便于图片的可视化，定义 `denorm()` 函数将其反归一化到 \[0, 1\]。  
`save_fake_img()` 用于保存图片。

In [None]:
import os
from IPython.display import Image
from torchvision.utils import save_image

# Output folder for the per-epoch sample grids. exist_ok=True makes the cell
# idempotent and avoids the check-then-create race of os.path.exists + makedirs.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)
    
def denorm(x):
    """Map a tensor from the normalized range [-1, 1] back to [0, 1].

    This inverts ``Normalize(mean=0.5, std=0.5)``. Out-of-range values are
    clipped, so the result is always a valid image intensity.
    """
    return x.add(1).div(2).clamp(0, 1)

def save_fake_img(index):
    """Render the generator's output for ``fixed_noise`` and save it as a PNG grid.

    Args:
        index: Integer embedded zero-padded in the file name
            (``fake_images-0001.png`` ...). The training loop passes the
            1-based epoch number.

    Reads the notebook globals ``gen``, ``fixed_noise`` and ``sample_dir``.
    """
    # Sampling only: no_grad avoids building an autograd graph for these images.
    with torch.no_grad():
        fake_images = gen(fixed_noise)
    # Flat [N, 784] vectors -> [N, 1, 28, 28] images, the layout save_image expects.
    fake_images = fake_images.reshape(-1, 1, 28, 28)
    fname = 'fake_images-{0:0=4d}.png'.format(index)
    print('Saving', fname)
    # nrow=10 lays out 10 samples per row of the grid.
    save_image(denorm(fake_images), os.path.join(sample_dir, fname), nrow=10)

## 模型训练
### 判别器训练
1. 我们希望判别器对于真实图片输出 1，对生成的图片输出 0  
2. 首先向判别器输入一批真实图片，并计算 loss，该步的标签设置为 1
3. 接着，用生成器生成一批假图片并输入判别器，计算 loss，该步的标签设置为 0
4. 最后，求两个 loss 的均值，利用梯度下降调整判别器的参数  

### 生成器训练
1. 我们使用生成器生成一批假图片，并传入判别器进行打分
2. 该步的标签设置为 1， 根据判别器的输出计算 loss
3. 根据上步的 loss 进行梯度下降调整生成器的参数

In [None]:
import time

d_losses, g_losses = [], []  # one entry per training iteration (mini-batch), not per epoch
for epoch in range(num_epochs):
    since = time.time()
    for idx, (real, _) in enumerate(loader):
        # Flatten [B, 1, 28, 28] images to [B, 784] vectors for the MLP discriminator.
        real = real.reshape(-1, 784).to(device)
        train_batch_size = real.shape[0]  # the last batch can be smaller than batch_size

        # ---- Discriminator step: real images -> label 1, generated images -> label 0 ----
        noise = torch.randn((train_batch_size, z_dim)).to(device)
        fake = gen(noise)
        disc_real = disc(real).reshape(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator loss must not back-propagate into the generator.
        # The generator graph is then traversed only once (by lossG below), so
        # retain_graph=True is no longer needed.
        disc_fake = disc(fake.detach()).reshape(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # ---- Generator step: push the (just updated) discriminator to label fakes as real ----
        output = disc(fake).reshape(-1)
        lossG = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        d_loss_val, g_loss_val = lossD.item(), lossG.item()
        d_losses.append(d_loss_val)
        g_losses.append(g_loss_val)

        if (idx + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Step [{idx+1}/{len(loader)}], Loss D: {d_loss_val:.4f}, loss G: {g_loss_val:.4f}, time: {time.time()-since:.4f}"
            )

    # Snapshot the generator on the fixed noise once per epoch.
    save_fake_img(epoch + 1)

In [None]:
# Persist only the weights (state_dict) of both networks for reuse in later cells.
for net, ckpt_path in ((disc, 'D.ckpt'), (gen, 'G.ckpt')):
    torch.save(net.state_dict(), ckpt_path)

In [None]:
# Losses were recorded once per training iteration (mini-batch), so the x-axis is
# iterations, not epochs.
fig, ax = plt.subplots()
ax.plot(d_losses, '-', label='Discriminator')
ax.plot(g_losses, '-', label='Generator')
ax.set(xlabel='iteration', ylabel='loss', title='BCELoss')
ax.legend()
plt.show()

In [None]:
from IPython.display import Image

## 加载模型参数

In [None]:
# Reload the trained generator weights. map_location lets a checkpoint saved on a
# GPU machine be loaded on a CPU-only one (and vice versa).
# NOTE: torch.load unpickles the file, so only load checkpoints you created or trust.
# To restore the discriminator as well:
#   disc.load_state_dict(torch.load('D.ckpt', map_location=device))
#   disc.eval()
gen.load_state_dict(torch.load('G.ckpt', map_location=device))
gen.eval()  # switch to inference mode before sampling

print(gen)

## 测试生成器效果

In [None]:
# Draw a fresh batch of latent vectors (not fixed_noise) to see new samples.
# Inference only, so no autograd graph is needed. A distinct name avoids
# clobbering the training loop's `fake` variable.
with torch.no_grad():
    randn_noise = torch.randn((batch_size, z_dim)).to(device)
    test_images = gen(randn_noise).reshape(-1, 1, 28, 28)
# nrow=20 lays out 20 samples per row of the grid.
save_image(denorm(test_images), os.path.join(sample_dir, 'test_result.png'), nrow=20)

Image(os.path.join(sample_dir, 'test_result.png'))

In [None]:
# Optional: install OpenCV, which the video-export cell needs. Prefer %pip over !pip
# so the package lands in the running kernel's environment.
# %pip install opencv-python

In [None]:
import cv2
import os
from IPython.display import FileLink

vid_fname = 'gan_traning.avi'
# Per-epoch sample grids. The zero-padded epoch number in each name makes
# lexicographic order equal chronological order (up to 9999 epochs).
files = sorted(
    os.path.join(sample_dir, f) for f in os.listdir(sample_dir) if 'fake_images' in f
)
if not files:
    raise FileNotFoundError(
        f"No 'fake_images' PNGs found in {sample_dir!r}; run the training cell first."
    )

# Take the frame size from the first image instead of hard-coding it. OpenCV does
# not raise on a size mismatch; mismatched frames are silently not written.
first_frame = cv2.imread(files[0])
frame_height, frame_width = first_frame.shape[:2]

out = cv2.VideoWriter(vid_fname, cv2.VideoWriter_fourcc(*'MJPG'), 8, (frame_width, frame_height))
try:
    for fname in files:
        out.write(cv2.imread(fname))
finally:
    out.release()  # always finalize the container, even if a frame fails to load
FileLink(vid_fname)
