# pytorchでGANしてみる

## torchのインストール

In [None]:
# %pip installs into the environment of the running kernel (a `%%bash` + `pip3`
# call may target a different Python interpreter than the notebook uses).
%pip install torch torchvision

## 画像保存でエラーが出たので Pillow を入れ直す

In [None]:
# Reinstall Pillow at a fixed version to work around an image-saving error.
# `pip install PIL` always fails (there is no "PIL" distribution on PyPI; Pillow
# provides the `PIL` module) and `image` is a separate, unrelated PyPI project,
# so both of those installs were dropped.
# NOTE(review): this pin is very old; relax it if no wheel exists for your Python.
%pip install Pillow==4.0.0

## import群

In [21]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import pickle
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

import matplotlib.pyplot as plt
%matplotlib inline

# True when PyTorch can see a CUDA GPU; later cells check this flag to decide
# whether to move models and tensors onto the GPU.
cuda = torch.cuda.is_available()
if cuda:
    print("cuda available")

## pytorchのクラスについて
- 畳み込み層
```py
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
```
- バッチ正規化
```py
class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
```
- 転置畳み込み (畳み込み層とは逆方向の操作で、特徴マップを拡大する)
```py
class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)
```

## ネットワークの重みを初期化

In [3]:
def initialize_weights(model):
    """Initialise the convolutional and linear layers of ``model`` in place.

    Every ``nn.Conv2d``, ``nn.ConvTranspose2d`` and ``nn.Linear`` module reached
    through ``model.modules()`` gets its weight drawn from N(0, 0.02) and its bias
    (when the layer has one) set to zero. Other module types, e.g. BatchNorm, keep
    PyTorch's default initialisation.

    Args:
        model: an ``nn.Module``; its sub-modules are modified in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            m.weight.data.normal_(0, 0.02)
            # Layers created with bias=False have m.bias set to None; skip them
            # instead of failing with AttributeError.
            if m.bias is not None:
                m.bias.data.zero_()
            

## Generatorクラス

In [8]:
class Generator(nn.Module):
    """Generator mapping a latent vector to a single-channel 28x28 image.

    Args:
        input_dim: length of the latent vector passed to ``forward``. Defaults to
            62, the value previously hard-coded here; it must equal the ``z_dim``
            used when sampling noise.
    """

    def __init__(self, input_dim=62):
        super(Generator, self).__init__()
        # Fully connected stem: latent vector -> 128 * 7 * 7 features, reshaped
        # to 128 maps of 7x7 in forward().
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU()
        )
        # Transposed convolutions (kernel 4, stride 2, padding 1) double the
        # spatial size each time: 7x7 -> 14x14 -> 28x28. Sigmoid keeps output
        # pixel values in (0, 1).
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid())

        initialize_weights(self)

    def forward(self, input):
        """Map latent vectors of shape (N, input_dim) to images of shape (N, 1, 28, 28)."""
        x = self.fc(input)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv(x)
        return x

## Discriminatorクラス

In [11]:
class Discriminator(nn.Module):
    """Discriminator giving each single-channel image a Sigmoid score in (0, 1).

    The flatten size 128 * 7 * 7 assumes 28x28 inputs: two convolutions with
    kernel 4, stride 2, padding 1 shrink 28 -> 14 -> 7.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        conv_layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        ]
        self.conv = nn.Sequential(*conv_layers)
        # Fully connected head producing one score per image.
        head_layers = [
            nn.Linear(128 * 7 * 7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*head_layers)
        initialize_weights(self)

    def forward(self, input):
        """Return scores of shape (N, 1) for images of shape (N, 1, 28, 28)."""
        features = self.conv(input)
        flattened = features.view(-1, 128 * 7 * 7)
        return self.fc(flattened)

## 確認

In [12]:
# Show the layer structure of freshly built networks (Generator first).
for network_class in (Generator, Discriminator):
    print(network_class())

Generator(
  (fc): Sequential(
    (0): Linear(in_features=62, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=6272, bias=True)
    (4): BatchNorm1d(6272, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (deconv): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Sigmoid()
  )
)
Discriminator(
  (conv): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running

## ハイパーパラメータ

In [15]:
# Training hyperparameters
batch_size = 128    # images per mini-batch
lr = 2e-4           # Adam learning rate, shared by G and D
z_dim = 62          # latent noise size; must match the Generator's first Linear layer (62 inputs)
num_epochs = 25
sample_num = 16     # NOTE(review): not referenced by the other cells of this notebook
log_dir = './logs'  # output directory for sample images, checkpoints and history

## ネットワークの生成
Generator と Discriminator では別々の最適化器を用いることに注意
```py
G.cuda()
D.cuda()
```
でモデルのパラメータをGPUに転送する

In [16]:
G = Generator()
D = Discriminator()

# Move both networks' parameters onto the GPU when one is available.
if cuda:
    for network in (G, D):
        network.cuda()

# Each network gets its own Adam optimizer (same settings), so stepping one
# never touches the other's parameters.
adam_betas = (0.5, 0.999)
G_optim = optim.Adam(G.parameters(), lr=lr, betas=adam_betas)
D_optim = optim.Adam(D.parameters(), lr=lr, betas=adam_betas)

# Binary cross-entropy between the discriminator's score and a 0/1 target.
criterion = nn.BCELoss()

## データのロード

In [17]:
# ToTensor converts each PIL image to a float tensor with values in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, downloaded into data/mnist on first use.
dataset = datasets.MNIST(
    'data/mnist',
    train=True,
    download=True,
    transform=transform,
)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


## 損失関数
### Discriminatorの目的関数
\begin{align*}
L_D = E[ \log{D(x)}] + E[\log\left(1 - D(G(z))\right) ] \rightarrow \max
\end{align*}
- $D(\cdot) \in [0,1]$ は $1$ に近いほど本物と判定したことを意味する。出力が確率なので、損失には Binary Cross Entropy Loss を用いる
- $G(z)$は乱数$z$から生成した偽物画像
- 第1項は本物$x$をちゃんと本物だと認識できれば大きくなる
- 第2項は偽物$G(z)$をちゃんと偽物だと見抜ければ大きくなる
### Generatorの目的関数
\begin{align*}
L_G = E[ \log D(G(z)) ] \rightarrow \max
\end{align*}
- Gの生成する偽物画像$G(z)$をDが本物だと思うほど値が大きくなる

In [20]:
def train(D, G, criterion, D_optim, G_optim, data_loader):
    """Run one epoch of alternating Discriminator / Generator updates.

    Reads the notebook-level globals ``batch_size``, ``z_dim`` and ``cuda``.

    Args:
        D: discriminator network (switched to train mode here).
        G: generator network (switched to train mode here).
        criterion: loss comparing D's outputs with 0/1 targets
            (``nn.BCELoss`` in this notebook).
        D_optim: optimizer over D's parameters.
        G_optim: optimizer over G's parameters.
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.

    Returns:
        tuple: ``(D_loss, G_loss)``, the mean per-batch losses over the batches
        actually trained on, as Python floats (0.0 if no full batch was seen).
    """
    # Training mode (e.g. BatchNorm uses batch statistics).
    D.train()
    G.train()

    # Targets: 1 for real images, 0 for generated ones.
    y_real = Variable(torch.ones(batch_size, 1))
    y_fake = Variable(torch.zeros(batch_size, 1))
    if cuda:
        y_real = y_real.cuda()
        y_fake = y_fake.cuda()

    D_running_loss = G_running_loss = 0.0
    num_batches = 0
    for real_images, _ in data_loader:
        # The targets above have exactly batch_size rows, so a short batch
        # (a DataLoader without drop_last yields at most one, at the end)
        # cannot be scored against them; stop the epoch there.
        if real_images.size(0) != batch_size:
            break

        # Latent noise, uniform in [0, 1).
        z = torch.rand((batch_size, z_dim))
        if cuda:
            real_images, z = real_images.cuda(), z.cuda()
        real_images, z = Variable(real_images), Variable(z)

        # ------------------------------------------------------------
        # Discriminator update: maximise E[log D(x)] + E[log(1 - D(G(z)))]
        # ------------------------------------------------------------
        D_optim.zero_grad()

        D_real = D(real_images)
        D_real_loss = criterion(D_real, y_real)

        fake_images = G(z)
        # detach() keeps this loss from backpropagating into G.
        D_fake = D(fake_images.detach())
        D_fake_loss = criterion(D_fake, y_fake)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optim.step()  # only D's parameters are stepped here
        # `.data[0]` raises IndexError on 0-dim loss tensors in current PyTorch;
        # `.item()` is the supported way to read the scalar value.
        D_running_loss += D_loss.item()

        # ------------------------------------------------------------
        # Generator update: maximise E[log D(G(z))] by labelling fakes as real
        # ------------------------------------------------------------
        z = torch.rand((batch_size, z_dim))
        if cuda:
            z = z.cuda()
        z = Variable(z)

        G_optim.zero_grad()

        # Fresh samples from new noise for the generator step.
        fake_images = G(z)
        D_fake = D(fake_images)
        G_loss = criterion(D_fake, y_real)
        G_loss.backward()
        G_optim.step()
        G_running_loss += G_loss.item()

        num_batches += 1

    # Average over the batches actually trained on (the skipped short batch does
    # not count); guard against an epoch with no full batch.
    if num_batches:
        D_running_loss /= num_batches
        G_running_loss /= num_batches
    return D_running_loss, G_running_loss
        

## 画像生成関数

In [23]:
def generate(epoch, G, log_dir="logs", num_samples=64):
    """Sample images from ``G`` and save them as one grid image ``epoch_XXX.png``.

    Reads the notebook-level globals ``z_dim`` and ``cuda``. Leaves ``G`` in eval
    mode; ``train()`` switches it back to train mode at the start of each epoch.

    Args:
        epoch: number embedded in the file name, zero-padded to 3 digits.
        G: generator network.
        log_dir: output directory, created if it does not exist.
        num_samples: number of latent vectors to sample (default 64, as before).
    """
    G.eval()
    os.makedirs(log_dir, exist_ok=True)

    sample_z = torch.rand((num_samples, z_dim))
    if cuda:
        sample_z = sample_z.cuda()

    # `Variable(..., volatile=True)` no longer disables autograd (it only emits a
    # warning); torch.no_grad() is the supported way to skip building a graph.
    with torch.no_grad():
        samples = G(sample_z).cpu()
    save_image(samples, os.path.join(log_dir, "epoch_%03d.png" % epoch))

## main

In [None]:
from tqdm import tqdm

history = {'D_loss': [], 'G_loss': []}

# Epochs (1-based) after which sample images and checkpoints are written: the
# first, the 10th and the last. Deriving the last one from num_epochs keeps the
# final model saved if num_epochs is changed.
save_epochs = {1, 10, num_epochs}

for epoch in tqdm(range(num_epochs)):
    D_loss, G_loss = train(D, G, criterion, D_optim, G_optim, data_loader)
    # tqdm.write prints above the progress bar instead of breaking it up.
    tqdm.write('epoch %d, D_loss: %.4f G_loss: %.4f' % (epoch + 1, D_loss, G_loss))
    history['D_loss'].append(D_loss)
    history['G_loss'].append(G_loss)

    # At selected epochs, generate sample images and save both models.
    if epoch + 1 in save_epochs:
        generate(epoch + 1, G, log_dir)
        torch.save(G.state_dict(), os.path.join(log_dir, 'G_%03d.pth' % (epoch + 1)))
        torch.save(D.state_dict(), os.path.join(log_dir, 'D_%03d.pth' % (epoch + 1)))

# Save the training history; make sure log_dir exists even if nothing was generated.
os.makedirs(log_dir, exist_ok=True)
with open(os.path.join(log_dir, 'history.pkl'), 'wb') as f:
    pickle.dump(history, f)

## 画像の表示

In [None]:
from IPython.display import Image
# Sample grid saved after the first epoch; read from the same log_dir the training loop writes to.
Image(os.path.join(log_dir, 'epoch_001.png'))

In [None]:
Image(os.path.join(log_dir, 'epoch_010.png'))  # sample grid after epoch 10

In [None]:
Image(os.path.join(log_dir, 'epoch_%03d.png' % num_epochs))  # sample grid after the final epoch