In [None]:
import os

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid

In [None]:
# Fashion-MNIST: a dataset with 10 clothing classes
# (T-shirt, trouser, pullover, dress, coat, ...).
# ToTensor followed by Normalize(0.5, 0.5) rescales the single channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
to_image = transforms.ToPILImage()

trainset = FashionMNIST('./data', train=True, transform=transform, download=True)
trainloader = DataLoader(dataset=trainset, batch_size=100, shuffle=True)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 18362241.43it/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 360122.19it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 6166942.28it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 5191699.20it/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

cpu


모델 정의

In [None]:
# Two models: a generator that synthesises images, and a discriminator that
# tells generated (fake) images apart from real ones.
# Vanilla GAN: the plain GAN built from Linear layers (swapping in CNN layers
# would make it a convolutional GAN).

class Generator(nn.Module):
  """Fully connected generator mapping a latent vector to a 1x28x28 image.

  The final Tanh keeps every output value within [-1, 1].
  """

  def __init__(self):
    super().__init__()
    # Size of the latent vector fed to the generator (a free design choice).
    self.n_features = 128
    # Flattened image size: 28 * 28 = 784 (fixed).
    self.n_out = 784

    # Hidden widths grow from the latent size up to 1024 before the output layer.
    widths = [self.n_features, 256, 512, 1024]
    layers = []
    for in_dim, out_dim in zip(widths[:-1], widths[1:]):
      layers.append(nn.Linear(in_dim, out_dim))
      layers.append(nn.LeakyReLU(0.2))
    layers.append(nn.Linear(widths[-1], self.n_out))
    layers.append(nn.Tanh())
    self.linear = nn.Sequential(*layers)

  def forward(self, x):
    """Return generated images reshaped to (batch, 1, 28, 28)."""
    flat = self.linear(x)
    return flat.view(-1, 1, 28, 28)

class Discriminator(nn.Module):
  """Fully connected binary classifier over flattened 784-value images.

  The final Sigmoid yields one score in [0, 1] per input image.
  """

  def __init__(self):
    super().__init__()
    self.n_in = 784
    self.n_out = 1

    # Each hidden stage is Linear -> LeakyReLU -> Dropout with shrinking width.
    widths = [self.n_in, 1024, 512, 256]
    layers = []
    for in_dim, out_dim in zip(widths[:-1], widths[1:]):
      layers += [nn.Linear(in_dim, out_dim), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
    # Binary real-vs-fake classification, hence a single sigmoid output.
    layers += [nn.Linear(widths[-1], self.n_out), nn.Sigmoid()]
    self.linear = nn.Sequential(*layers)

  def forward(self, x):
    """Flatten the input to (batch, 784) and return the (batch, 1) scores."""
    flat = x.view(-1, 784)
    return self.linear(flat)

손실함수 및 최적화 방법 정의

In [None]:
# The GAN loss has a min-max form.
# Parameters therefore have to be updated twice per step (discriminator, then
# generator), which makes this setup inherently hard to train.

generator = Generator().to(device)
discriminator = Discriminator().to(device)

pretrained = False
if pretrained:
  # map_location remaps the stored tensors onto the active device, so a
  # checkpoint written on a GPU can still be restored on a CPU-only machine.
  discriminator.load_state_dict(torch.load('models/fmnist_disc.pth', map_location=device))
  generator.load_state_dict(torch.load('models/fmnist_gner.pth', map_location=device))

g_optim = optim.Adam(generator.parameters(), lr=2e-4)
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)

g_losses = []  # per-epoch mean generator loss
d_losses = []  # per-epoch mean discriminator loss
images = []    # one image grid per epoch, collected to assemble a GIF later

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

def noise(n, n_features=128):
  """Sample `n` latent vectors, i.e. the input fed to the generator.

  Args:
    n: number of vectors to draw.
    n_features: length of each latent vector.

  Returns:
    An (n, n_features) tensor of standard-normal samples on `device`.
  """
  # Allocate directly on the target device: the deprecated Variable wrapper is
  # a no-op in current PyTorch and the CPU-then-copy round trip is unnecessary.
  return torch.randn(n, n_features, device=device)

def label_ones(size):
  """Return a (size, 1) tensor of ones on `device`, used as the "real" target."""
  # Created directly on the device; the deprecated Variable wrapper is not needed.
  return torch.ones(size, 1, device=device)

def label_zeros(size):
  """Return a (size, 1) tensor of zeros on `device`, used as the "fake" target."""
  # Created directly on the device; the deprecated Variable wrapper is not needed.
  return torch.zeros(size, 1, device=device)

학습 전략 정의      

In [None]:
# Loss strategy: compute one loss on the real images and a second loss on the
# fake images, then add the two together.
# In theory the end goal of the game is a discriminator that can no longer tell
# real and fake apart.

def train_discriminator(optimizer, real_data, fake_data):
  """Run one discriminator update on a real batch and a generated batch.

  Args:
    optimizer: optimizer holding the discriminator's parameters.
    real_data: batch of real images.
    fake_data: batch of generated images.

  Returns:
    The summed real + fake BCE loss as a 0-dim tensor detached from the
    autograd graph, so callers can accumulate it without retaining history.
  """
  optimizer.zero_grad()

  # Real images should be scored close to 1.
  prediction_real = discriminator(real_data)
  loss_real = criterion(prediction_real, label_ones(real_data.size(0)))
  loss_real.backward()

  # Generated images should be scored as 0. Detaching keeps this update from
  # back-propagating into the generator even if the caller forgot to detach.
  # The targets are sized from the fake batch itself.
  prediction_fake = discriminator(fake_data.detach())
  loss_fake = criterion(prediction_fake, label_zeros(fake_data.size(0)))
  loss_fake.backward()
  optimizer.step()

  return (loss_real + loss_fake).detach()

def train_generator(optimizer, fake_data):
  """Run one generator update using the discriminator's verdict on `fake_data`.

  Args:
    optimizer: optimizer holding the generator's parameters.
    fake_data: generated images still attached to the generator's graph.

  Returns:
    The generator BCE loss as a 0-dim tensor detached from the autograd graph,
    so callers can accumulate it without retaining history.
  """
  n = fake_data.size(0)
  optimizer.zero_grad()

  # The generator is updated by labelling its own images as "real" and
  # computing the loss against that target.
  # A loss near 0 means the discriminator scores the generated images as real.
  prediction = discriminator(fake_data)
  loss = criterion(prediction, label_ones(n))

  loss.backward()
  optimizer.step()

  return loss.detach()

학습하기

In [None]:
num_epoch = 201
# A single latent batch reused every epoch for the per-epoch image snapshots.
test_noise = noise(64)

num_batches = len(trainloader)
l = num_batches  # original short name, kept so later cells that read `l` still work

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
os.makedirs('models', exist_ok=True)

for epoch in range(num_epoch):
  g_loss = 0.0
  d_loss = 0.0

  for imgs, _ in trainloader:
    n = len(imgs)
    real_data = imgs.to(device)

    # Discriminator step: the fake batch is detached so only D is updated.
    fake_data = generator(noise(n)).detach()
    # .detach() on the returned loss keeps the running sum free of autograd history.
    d_loss += train_discriminator(d_optim, real_data, fake_data).detach()

    # Generator step: a fresh fake batch that stays attached to G's graph.
    fake_data = generator(noise(n))
    g_loss += train_generator(g_optim, fake_data).detach()

  # Snapshot of the generator's output for this epoch; no graph is needed here.
  with torch.no_grad():
    img = generator(test_noise).cpu()
  img = make_grid(img)
  images.append(img)
  g_losses.append(g_loss/num_batches)
  d_losses.append(d_loss/num_batches)

  if epoch % 10 == 0:
    print("epoch: {}, g_loss: {}, d_loss: {}\r".format(epoch, g_loss/num_batches, d_loss/num_batches))
    torch.save(discriminator.state_dict(), 'models/fmnist_disc.pth')
    torch.save(generator.state_dict(), 'models/fmnist_gner.pth')

print('Training Finished')

epoch: 0, g_loss: 3.403390407562256, d_loss: 0.49047815799713135


KeyboardInterrupt: 