<a href="https://colab.research.google.com/github/StillWork/AIDD-2208-add/blob/main/c_85_12_GAN_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GAN MNIST
- pytorch로 구현
- torchvision이 제공하는 MNIST 사용
- [pytorch lightening 소개](https://baeseongsu.github.io/posts/pytorch-lightning-introduction/)

In [44]:
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [45]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import pytorch_lightning as pl

class Generator(nn.Module):
  '''
  생성자는 입력으로 크기 100인 노이지를 받고 출력으로 784 크기의 벡터를 생성하고
  이것이 MNIST를 모방하도록 한다. MLP로 구현
  '''
  
  def __init__(self):
    super().__init__()
    self.layer1 = nn.Sequential(nn.Linear(in_features=100, out_features=256),
                                nn.LeakyReLU())
    self.layer2 = nn.Sequential(nn.Linear(in_features=256, out_features=512),
                                nn.LeakyReLU())
    self.layer3 = nn.Sequential(nn.Linear(in_features=512, out_features=1024),
                                nn.LeakyReLU())
    self.output = nn.Sequential(nn.Linear(in_features=1024, out_features=28*28),
                                nn.Tanh())

  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.output(x)
    return x


class Discriminator(nn.Module):
  '''
  판별기는 784 크기의 벡터를 입력으로 받고 이진 분류 확률을 출력한다. 
  생성자의 역과정으로 모델을 구현한다
  '''

  def __init__(self):
    super().__init__()
    self.layer1 = nn.Sequential(nn.Linear(in_features=28*28, out_features=1024),
                                nn.LeakyReLU())
    self.layer2 = nn.Sequential(nn.Linear(in_features=1024, out_features=512),
                                nn.LeakyReLU())
    self.layer3 = nn.Sequential(nn.Linear(in_features=512, out_features=256),
                                nn.LeakyReLU())
    self.output = nn.Sequential(nn.Linear(in_features=256, out_features=1),
                                nn.Sigmoid())
    
  def forward(self, x):
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.output(x)
    return x


class GAN(pl.LightningModule):

  def __init__(self):
    super().__init__()
    self.generator = Generator()
    self.discriminator = Discriminator()
    self.test_noises = torch.randn(100,1,100, device=device)
    self.test_progression = []

  def forward(self, z):
    """
    노이지 입력 z를 받고 이미지를 생성
    """
    return self.generator(z)

  def generator_step(self, x):
    """
    생성자의 훈련과정
    1. 랜덤 노이지 샘플링
    2. 노이지로부터 이미지 생성
    3. 판별자 구동
    4. Backprop loss
    """
    
    # Sample noise
    # z = torch.randn(x.shape, device=device)
    z = torch.randn(x.shape[0], 1, 100, device=device)

    # Generate images
    generated_imgs = self(z)

    # Classify generated images
    # using the discriminator
    d_output = torch.squeeze(self.discriminator(generated_imgs))

    # Backprop loss
    g_loss = nn.BCELoss()(d_output, torch.ones(x.shape[0], device=device))

    return g_loss

  def discriminator_step(self, x):
    """
    판별자 학습 단계
    1. 실제 이미지를 받는다
    2. 생성자로부터 생성된 가짜 이미지를 받는다
    3. 실제 이미지의 확률을 예측한다
    4. 가짜 이미지의 확률을 예측한다
    5. 손실을 계산하고 역전파한다
    """
    
    # Real images
    d_output = torch.squeeze(self.discriminator(x))
    loss_real = nn.BCELoss()(d_output, torch.ones(x.shape[0], device=device))

    # Fake images
    z = torch.randn(x.shape[0], 1, 100, device=device)
    generated_imgs = self(z)
    d_output = torch.squeeze(self.discriminator(generated_imgs))
    loss_fake = nn.BCELoss()(d_output, torch.zeros(x.shape[0], device=device))

    return loss_real + loss_fake

  def training_step(self, batch, batch_idx, optimizer_idx):
    X, _ = batch

    # train generator
    if optimizer_idx == 0:
      loss = self.generator_step(X)
    
    # train discriminator
    if optimizer_idx == 1:
      loss = self.discriminator_step(X)

    return loss

  def configure_optimizers(self):
    g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=0.0002)
    d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=0.0002)
    return [g_optimizer, d_optimizer], []

  def training_epoch_end(self, training_step_outputs):
    epoch_test_images = self(self.test_noises)
    self.test_progression.append(epoch_test_images)

In [46]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data = torchvision.datasets.MNIST(root='../data/MNIST', download=True,
                                  transform=transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5]),
                                    transforms.Lambda(lambda x: x.view(-1, 784))]))

mnist_dataloader = DataLoader(data, batch_size=128, shuffle=True, num_workers=0)

model = GAN()

trainer = pl.Trainer(max_epochs=100, gpus=1 if torch.cuda.is_available() else 0)


  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [47]:
trainer.fit(model, mnist_dataloader)
torch.save(model.state_dict(), './molgan_torch')

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 1.5 M 
1 | discriminator | Discriminator | 1.5 M 
------------------------------------------------
2.9 M     Trainable params
0         Non-trainable params
2.9 M     Total params
11.786    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


# 저장된 모델 사용하기

In [48]:
# from google.colab import drive
# drive.mount('/content/drive')

In [49]:
# model = GAN()
# model.load_state_dict(torch.load('./molgan_torch'))

In [50]:
import numpy as np
from matplotlib import pyplot as plt, gridspec

# Convert images from torch tensor to numpy array
images = [i.detach().cpu().numpy() for i in model.test_progression]

epoch_to_plot = 100
nrow = 3
ncol = 8

# randomly select 10 images for plotting
indexes = np.random.choice(range(100), nrow*ncol, replace=False)

fig = plt.figure(figsize=((ncol+1)*2, (nrow+1)*2)) 
fig.suptitle('Epoch {}'.format(epoch_to_plot), fontsize=30)

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i in range(nrow):
    for j in range(ncol):
        idx = i*ncol + j
        img = np.reshape(images[epoch_to_plot-1][indexes[idx]], (28,28))
        ax = plt.subplot(gs[i,j])
        ax.imshow(img, cmap='gray')
        ax.axis('off')

IndexError: ignored

<Figure size 1296x576 with 0 Axes>

In [None]:
import numpy as np
from matplotlib import pyplot as plt, gridspec

# Convert images from torch tensor to numpy array
images = [i.detach().cpu().numpy() for i in model.test_progression]

epoch_to_plot = 100
nrow = 3
ncol = 8

# randomly select 10 images for plotting
indexes = np.random.choice(range(100), nrow*ncol, replace=False)

fig = plt.figure(figsize=((ncol+1)*2, (nrow+1)*2)) 
fig.suptitle('Epoch {}'.format(epoch_to_plot), fontsize=30)

gs = gridspec.GridSpec(nrow, ncol,
         wspace=0.0, hspace=0.0, 
         top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
         left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 

for i in range(nrow):
    for j in range(ncol):
        idx = i*ncol + j
        img = np.reshape(images[epoch_to_plot-1][indexes[idx]], (28,28))
        ax = plt.subplot(gs[i,j])
        ax.imshow(img, cmap='gray')
        ax.axis('off')