### 1. 필요한 라이브러리 불러오기

In [1]:
import os
import torch.nn as nn
import torch.utils.data
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

### 2. 하이퍼파라미터 세팅

In [2]:
num_epoch = 200
batch_size = 100
learning_rate = 0.0002
img_size = 28 * 28
num_channel = 1
dir_name = "GAN_results"

noise_size = 100

# Device setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Now using {} devices".format(device))

Now using cuda devices


### 3. 데이터 세팅

In [3]:
# Create a directory for saving samples
if not os.path.exists(dir_name):
    os.makedirs(dir_name)

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])

In [5]:
MNIST_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True,
                                           transform=transform,
                                           download=True)

data_loader = torch.utils.data.DataLoader(dataset=MNIST_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 221740674.49it/s]

Extracting ../../data/MNIST/raw/train-images-idx3-ubyte.gz to ../../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 115367327.45it/s]


Extracting ../../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 65555337.09it/s]

Extracting ../../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../../data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 18370808.84it/s]


Extracting ../../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../../data/MNIST/raw



### 4. Discriminator(판별자)

In [10]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()

    self.linear1 = nn.Linear(img_size, 1024)
    self.linear2 = nn.Linear(1024, 512)
    self.linear3 = nn.Linear(512, 256)
    self.linear4 = nn.Linear(256, 1)
    self.leaky_relu = nn.LeakyReLU(0.2)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.leaky_relu(self.linear1(x))
    x = self.leaky_relu(self.linear2(x))
    x = self.leaky_relu(self.linear3(x))
    x = self.linear4(x)
    x = self.sigmoid(x)
    return x

### 5. Generator(생성자)

In [11]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()

    self.linear1 = nn.Linear(noise_size, 256)
    self.linear2 = nn.Linear(256, 512)
    self.linear3 = nn.Linear(512, 1024)
    self.linear4 = nn.Linear(1024, img_size)
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()

  def forward(self, x):
    x = self.relu(self.linear1(x))
    x = self.relu(self.linear2(x))
    x = self.relu(self.linear3(x))
    x = self.linear4(x)
    x = self.tanh(x)
    return x

### 6. Initialize G, D

In [12]:
discriminator = Discriminator()
generator = Generator()

discriminator = discriminator.to(device)
generator = generator.to(device)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)

### 7. Training

In [13]:
for epoch in range(num_epoch):
  for i, (images, label) in enumerate(data_loader):

    #make ground truth (labels) -> 1 for real, 0 for fake
    real_label = torch.full((batch_size, 1), 1, dtype=torch.float32).to(device)
    fake_label = torch.full((batch_size, 1), 0, dtype=torch.float32).to(device)

    #reshape real images from MNIST dataset
    real_images = images.reshape(batch_size, -1).to(device)

    #Train Generator

    g_optimizer.zero_grad()
    d_optimizer.zero_grad()

    #make fake images with generator & noise vector z
    z = torch.randn(batch_size, noise_size).to(device)
    fake_images = generator(z)

    # compare result of discriminator with fake images & real labels
    # 만약에, generator 가 discriminator 를 속이면, g_loss 감소
    g_loss = criterion(discriminator(fake_images), real_label)
    g_loss.backward()
    g_optimizer.step()

    #Train Discriminator

    # Initialize grad
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

    # make fake images with generator & noise vector 'z'
    z = torch.randn(batch_size, noise_size).to(device)
    fake_images = generator(z)

    # Calculate fake & real loss with generated images above & real images
    fake_loss = criterion(discriminator(fake_images), fake_label)
    real_loss = criterion(discriminator(real_images), real_label)
    d_loss = (fake_loss + real_loss) / 2

    # Train discriminator with backpropagation
    # In this part, we don't train generator
    d_loss.backward()
    d_optimizer.step()

    d_performance = discriminator(real_images).mean()
    g_performance = discriminator(fake_images).mean()

    if (i + 1) % 150 == 0:
        print("Epoch [ {}/{} ]  Step [ {}/{} ]  d_loss : {:.5f}  g_loss : {:.5f}"
              .format(epoch, num_epoch, i+1, len(data_loader), d_loss.item(), g_loss.item()))

# print discriminator & generator's performance
print(" Epock {}'s discriminator performance : {:.2f}  generator performance : {:.2f}"
      .format(epoch, d_performance, g_performance))

# Save fake images in each epoch
samples = fake_images.reshape(batch_size, 1, 28, 28)
save_image(samples, os.path.join(dir_name, 'GAN_fake_samples{}.png'.format(epoch + 1)))

Epoch [ 0/200 ]  Step [ 150/600 ]  d_loss : 0.03694  g_loss : 3.06692
Epoch [ 0/200 ]  Step [ 300/600 ]  d_loss : 0.05642  g_loss : 11.85699
Epoch [ 0/200 ]  Step [ 450/600 ]  d_loss : 0.00265  g_loss : 9.06607
Epoch [ 0/200 ]  Step [ 600/600 ]  d_loss : 0.00752  g_loss : 9.71986
Epoch [ 1/200 ]  Step [ 150/600 ]  d_loss : 0.00000  g_loss : 34.41891
Epoch [ 1/200 ]  Step [ 300/600 ]  d_loss : 0.00671  g_loss : 9.69560
Epoch [ 1/200 ]  Step [ 450/600 ]  d_loss : 0.00556  g_loss : 15.82375
Epoch [ 1/200 ]  Step [ 600/600 ]  d_loss : 0.01758  g_loss : 19.78484
Epoch [ 2/200 ]  Step [ 150/600 ]  d_loss : 0.34472  g_loss : 14.08827
Epoch [ 2/200 ]  Step [ 300/600 ]  d_loss : 0.10197  g_loss : 18.07524
Epoch [ 2/200 ]  Step [ 450/600 ]  d_loss : 0.20200  g_loss : 4.56785
Epoch [ 2/200 ]  Step [ 600/600 ]  d_loss : 0.79432  g_loss : 5.08132
Epoch [ 3/200 ]  Step [ 150/600 ]  d_loss : 0.20154  g_loss : 4.36742
Epoch [ 3/200 ]  Step [ 300/600 ]  d_loss : 1.21475  g_loss : 1.92222
Epoch [ 3/200 

KeyboardInterrupt: 