In [None]:
%matplotlib inline
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torch import autograd
from torch.autograd import Variable
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [None]:
# Reference: "The GAN Zoo", a catalogue of GAN variants — https://github.com/hindupuravinash/the-gan-zoo

# Data: Fashion-MNIST (Kaggle CSV release) — https://www.kaggle.com/zalando-research/fashionmnist

class FashionMNIST(Dataset):
    """Fashion-MNIST dataset read from a Kaggle-style CSV file.

    The CSV must have a ``label`` column as its first column, followed by the
    784 pixel columns of a 28x28 image. Samples are returned as
    ``(image, label)``, where ``image`` is a PIL image (or the result of
    ``transform`` applied to it).

    Args:
        transform: optional callable applied to each PIL image.
        csv_path: path of the CSV to load. Defaults to the training split, so
            existing ``FashionMNIST()`` / ``FashionMNIST(transform=...)`` calls
            behave as before; pass the test CSV to load the other split.
    """

    def __init__(self, transform=None, csv_path='data/fashionmnist/fashion-mnist_train.csv'):
        self.transform = transform
        fashion_df = pd.read_csv(csv_path)
        self.labels = fashion_df.label.values
        # Every column after the first (label) column is a pixel intensity:
        # 784 values per row, reshaped into one 28x28 uint8 image per sample.
        self.images = fashion_df.iloc[:, 1:].values.astype('uint8').reshape(-1, 28, 28)

    def __len__(self):
        """Return the number of samples (rows) in the CSV."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, applying ``transform`` if set."""
        label = self.labels[idx]
        # A 2-D uint8 array becomes a single-channel (grayscale) PIL image.
        img = Image.fromarray(self.images[idx])

        if self.transform:
            img = self.transform(img)

        return img, label

In [None]:
# Sanity check: load the raw (untransformed) dataset and display its first image.
dataset = FashionMNIST()
first_image, first_label = dataset[0]
first_image

In [None]:
# Fashion-MNIST images are single-channel (grayscale), so Normalize must get a
# one-element mean/std. A 3-element tuple fails to broadcast against a
# (1, 28, 28) tensor in current torchvision.
# (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1], the same range as the
# generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
dataset = FashionMNIST(transform=transform)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator that scores an (image, class label) pair.

    The image is flattened to 784 values and concatenated with a learned
    10-dim label embedding. An MLP with LeakyReLU and dropout then maps the
    794-dim vector to a sigmoid probability that the pair is real.
    """

    def __init__(self):
        super().__init__()

        # In a conditional GAN the label is also an input: one learned
        # 10-dim embedding for each of the 10 classes.
        self.label_emb = nn.Embedding(10, 10)

        self.model = nn.Sequential(
            nn.Linear(794, 1024),  # 784 image pixels (28 * 28) + 10 label-embedding values.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Return per-sample real/fake probabilities with shape (batch,).

        Args:
            x: images with 784 elements per sample, e.g. (B, 28, 28) or (B, 1, 28, 28).
            labels: integer class indices in [0, 10), shape (B,).
        """
        # Flatten each image. reshape (unlike view) also accepts non-contiguous input.
        x = x.reshape(x.size(0), 784)
        c = self.label_emb(labels)
        # The image and its label embedding are scored jointly.
        x = torch.cat([x, c], 1)
        out = self.model(x)
        # squeeze(1) instead of squeeze(): a batch of one keeps shape (1,)
        # instead of collapsing to a 0-d tensor, so the output always matches
        # the (B,) targets it is compared against.
        return out.squeeze(1)

In [None]:
class Generator(nn.Module): # G의 input은 noise vector.
    def __init__(self):
        super().__init__()
        
        self.label_emb = nn.Embedding(10, 10)
        
        self.model = nn.Sequential(
            nn.Linear(110, 256), # noise vector 100 + label 10.
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
        
    def forward(self, z, labels):
        z = z.view(z.size(0), 100) # noise vector.
        c = self.label_emb(labels) # label.
        x = torch.cat([z, c], 1) # noise 와 label을 concat.
        out = self.model(x)
        return out.view(x.size(0), 28, 28)

In [None]:
# Instantiate both networks on the CPU. For a GPU run, move the models and
# every tensor passed to them with .cuda() / .to(device).
generator, discriminator = Generator().cpu(), Discriminator().cpu()

In [None]:
# MSE between the discriminator's sigmoid output and 1 (real) / 0 (fake)
# targets, with an independent Adam optimizer for each network.
LEARNING_RATE = 1e-4
criterion = nn.MSELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE)

In [None]:
def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion):
    """Run one optimisation step for the conditional generator.

    Draws ``batch_size`` noise vectors and random class labels in [0, 10),
    generates images, and trains the generator so the discriminator scores
    them as real (target 1).

    Returns:
        The generator loss as a CPU tensor detached from the autograd graph.
    """
    g_optimizer.zero_grad()

    noise = torch.randn(batch_size, 100).cpu()
    # Random conditioning labels: integers 0..9 (randint's upper bound is exclusive).
    sampled_labels = torch.LongTensor(np.random.randint(0, 10, batch_size)).cpu()
    generated = generator(noise, sampled_labels)

    scores = discriminator(generated, sampled_labels)
    all_real = torch.ones(batch_size).cpu()
    loss = criterion(scores, all_real)

    loss.backward()
    g_optimizer.step()
    return loss.data.cpu()

In [None]:
def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion,
                             real_images, labels):
    """Run one optimisation step for the conditional discriminator.

    Args:
        batch_size: number of samples in ``real_images``; the same number of
            fake samples is drawn.
        discriminator: network scoring (image, label) pairs.
        generator: network producing fake images. It is not updated here.
        d_optimizer: optimizer expected to hold the discriminator's parameters.
        criterion: loss comparing predicted validity with 1 (real) / 0 (fake) targets.
        real_images: batch of real images for this step.
        labels: integer class labels of ``real_images``.

    Returns:
        The summed real + fake loss as a CPU tensor detached from the graph.
    """
    d_optimizer.zero_grad()

    # Real images paired with their true labels should score as real (target 1).
    real_validity = discriminator(real_images, labels)
    real_loss = criterion(real_validity, torch.ones(batch_size).cpu())

    # Fake images conditioned on random labels in [0, 10) should score as fake (target 0).
    z = torch.randn(batch_size, 100).cpu()
    fake_labels = torch.LongTensor(np.random.randint(0, 10, batch_size)).cpu()
    # detach(): this step only trains D, so don't backpropagate into G's parameters.
    fake_images = generator(z, fake_labels).detach()
    fake_validity = discriminator(fake_images, fake_labels)
    fake_loss = criterion(fake_validity, torch.zeros(batch_size).cpu())

    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()

    return d_loss.detach().cpu()

In [None]:
# Training configuration.
num_epochs = 30
n_critic = 5        # NOTE(review): unused below — the loop does one D step per G step.
display_step = 300  # NOTE(review): unused below.

for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch))
    # The generator is put in eval() for the end-of-epoch samples, so switch
    # it back to training mode once per epoch before its updates.
    generator.train()
    for images, labels in data_loader:
        real_images = images.cpu()
        labels = labels.cpu()
        # Take the size from the batch itself; the final batch may be smaller
        # than the DataLoader's batch_size.
        batch_size = real_images.size(0)
        d_loss = discriminator_train_step(batch_size, discriminator,
                                          generator, d_optimizer, criterion,
                                          real_images, labels)
        g_loss = generator_train_step(batch_size, discriminator, generator,
                                      g_optimizer, criterion)

    # Report the losses of the epoch's last batch and show one sample per label 0-8.
    generator.eval()
    print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
    with torch.no_grad():  # inference only; no autograd graph needed
        z = torch.randn(9, 100).cpu()
        sample_labels = torch.LongTensor(np.arange(9)).cpu()
        sample_images = generator(z, sample_labels).unsqueeze(1).cpu()
    grid = make_grid(sample_images, nrow=3, normalize=True).permute(1, 2, 0).numpy()
    plt.imshow(grid)
    plt.title('Epoch {}: samples for labels 0-8'.format(epoch))
    plt.axis('off')
    plt.show()

In [None]:
# Generate one image for each of the class labels 0-8 and show them as a 3x3
# grid (three images per row, in label order).
# eval() + no_grad(): this is inference only, so no autograd graph is built.
generator.eval()
with torch.no_grad():
    z = torch.randn(9, 100).cpu()
    labels = torch.LongTensor(np.arange(9)).cpu()
    sample_images = generator(z, labels).unsqueeze(1).cpu()
grid = make_grid(sample_images, nrow=3, normalize=True).permute(1, 2, 0).numpy()
plt.imshow(grid)
plt.title('Generated samples for labels 0-8')
plt.axis('off')
plt.show()

In [None]:
# Record the installed PyTorch version alongside the results for reproducibility.
torch_version = torch.__version__
print(torch_version)