In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

In [2]:
def get_noise(batch_size, n_noise, rng=None):
    """Sample generator noise uniformly from the half-open interval [-1, 1).

    Args:
        batch_size: number of noise vectors (rows).
        n_noise: length of each noise vector (columns).
        rng: optional ``numpy.random.Generator`` (or ``RandomState``) to draw from,
            for reproducible samples. When omitted, NumPy's global RNG is used,
            exactly as before.

    Returns:
        float64 ndarray of shape (batch_size, n_noise).
    """
    source = np.random if rng is None else rng
    return source.uniform(-1, 1, (batch_size, n_noise))

In [3]:
def get_one_hot(targets, nb_classes):
    """Encode integer class labels as one-hot vectors.

    Args:
        targets: torch tensor of integer labels, any shape, on any device
            (it is moved to CPU before conversion).
        nb_classes: length of each one-hot vector.

    Returns:
        float64 numpy array of shape ``(*targets.shape, nb_classes)``.
    """
    identity = np.eye(nb_classes)
    flat_labels = np.asarray(targets.cpu().int()).ravel()
    one_hot_rows = identity[flat_labels]
    out_shape = tuple(targets.shape) + (nb_classes,)
    return one_hot_rows.reshape(out_shape)

In [4]:
# Hyperparameters shared by the cells below.
n_noise = 100      # length of the generator's noise vector (a 10-way label one-hot is appended to it)
n_channel = 1      # image channels (MNIST is grayscale)
batch_size = 100   # minibatch size used by the DataLoaders
condition = []     # NOTE(review): not referenced in the visible cells; candidate for removal

# Seed every RNG the notebook draws from (NumPy for get_noise; torch for weight
# init and DataLoader shuffling) so Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Convert images to tensors in [0, 1] and rescale them to [-1, 1] so that real
# images share the range of the generator's tanh output.
# MNIST tensors have a single channel, so mean/std must be 1-tuples. The previous
# 3-tuples fail to broadcast against a [1, 28, 28] tensor in current torchvision
# and raise a RuntimeError.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [6]:
# Load (downloading if necessary) both MNIST splits with the shared transform.
mnist_kwargs = dict(root='../../datasets/MNIST', download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

In [7]:
# Batch both splits; only the training split is reshuffled every epoch.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=is_train)
    for split, is_train in ((train_data, True), (test_data, False))
)

In [8]:
class Generator(nn.Module):
    """Conditional generator mapping (noise + label) vectors to 28x28 images.

    forward(x) expects a (batch, n_noise + 10) tensor: noise concatenated with a
    10-way label encoding. It returns a (batch, n_channel, 28, 28) tensor with
    values in [-1, 1] (tanh output).
    """

    def __init__(self):
        super(Generator, self).__init__()

        # (batch, n_noise + 10) -> (batch, 128*7*7), later viewed as a 128x7x7 map.
        self.fc1 = nn.Linear(n_noise + 10, 128 * 7 * 7)

        self.conv = nn.Sequential(
            # 128 x 7x7 -> 64 x 14x14
            nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            # 64 x 14x14 -> 32 x 28x28
            nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            # 32 x 28x28 -> n_channel x 28x28 (stride 1 keeps the spatial size)
            nn.ConvTranspose2d(32, n_channel, 3, 1, 1),
        )

    def forward(self, x):
        x = self.fc1(x)
        x = x.reshape(-1, 128, 7, 7)
        x = self.conv(x)
        # torch.tanh is the canonical form of the deprecated F.tanh; its [-1, 1]
        # range matches the Normalize(0.5, 0.5) scaling applied to real images.
        return torch.tanh(x)

In [9]:
class Discriminator(nn.Module):
    """Conditional discriminator scoring (image, label) pairs.

    forward(x, y) takes images that reshape to (batch, n_channel*28*28) and a
    (batch, 10) label encoding. It returns (batch, 1) sigmoid probabilities that
    each image is real for its label.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Mix the flattened image with its label, then project back to image size.
        # Every unit of this dense layer sees the label, so it conditions all
        # spatial positions handed to the conv stack.
        self.fc1 = nn.Linear(n_channel * 28 * 28 + 10, n_channel * 28 * 28)

        self.conv = nn.Sequential(
            nn.Conv2d(n_channel, 16, 3, 1, 1),   # 28x28, padding keeps the size
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, 2),                  # -> 14x14

            nn.Conv2d(16, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),

            nn.Conv2d(32, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, 2),                  # -> 7x7

            nn.Conv2d(128, 128, 3, 1, 1),        # 128 x 7x7
        )

        self.fc2 = nn.Linear(128 * 7 * 7, 1)

    def forward(self, x, y):
        # Flatten each image and append its label vector along the feature axis.
        # Use n_channel consistently (the old code hard-coded 1 here while the
        # conv stack used n_channel); identical for single-channel MNIST.
        x = torch.cat((x.float().reshape(-1, n_channel * 28 * 28), y.float()), 1)
        x = self.fc1(x).reshape(-1, n_channel, 28, 28)
        x = self.conv(x).reshape(-1, 128 * 7 * 7)
        x = self.fc2(x)
        # torch.sigmoid is the canonical form of the deprecated F.sigmoid; the
        # probabilities are consumed by nn.BCELoss.
        return torch.sigmoid(x)


In [11]:
# Use the first GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build both networks and move their parameters onto `device` (Module.to returns the module).
G = Generator().to(device)
D = Discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# One Adam optimizer per player, both at the same learning rate.
adam_lr = 0.0002
optimizerG = optim.Adam(G.parameters(), lr=adam_lr)
optimizerD = optim.Adam(D.parameters(), lr=adam_lr)

In [None]:
print("Training start")
start = time.time()

n_epochs = 200
# Sample images are numbered epoch + 200. This looks like an offset left over from
# resuming an earlier 200-epoch run. NOTE(review): confirm before reusing.
sample_index_offset = 200
sample_dir = './cGAN_pytorch_MNIST_samples'
os.makedirs(sample_dir, exist_ok=True)  # savefig fails if the directory is missing

# Fixed conditioning labels for the per-epoch sample strip: the first 10 test labels.
# `.targets` replaces the deprecated `.test_labels` alias.
test_label = torch.from_numpy(get_one_hot(test_data.targets[0:10], 10)).to(device).float()

for epoch in range(n_epochs):
    G.train()
    for data, target in train_loader:
        # Size the label and noise tensors from the actual batch, so a short final
        # batch does not cause a shape mismatch.
        n = data.size(0)
        train_x = data.reshape(-1, 1, 28, 28).to(device)
        train_y = torch.from_numpy(get_one_hot(target, 10)).to(device)
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # Discriminator step: push real pairs toward 1 and generated pairs toward 0.
        optimizerD.zero_grad()
        d_real_output = D(train_x, train_y)
        d_real_error = criterion(d_real_output, real_labels)

        d_gen_input = torch.from_numpy(get_noise(n, n_noise)).float().to(device)
        d_gen_input = torch.cat((d_gen_input, train_y.float()), 1)
        # detach() keeps the discriminator loss from backpropagating into G.
        d_fake_data = G(d_gen_input).detach()
        d_fake_output = D(d_fake_data, train_y)
        d_fake_error = criterion(d_fake_output, fake_labels)

        d_train_loss = d_real_error + d_fake_error
        d_train_loss.backward()
        optimizerD.step()

        # Generator step: reuse the same noise+label batch and push D(G(z|y)) toward 1.
        optimizerG.zero_grad()
        g_fake_output = G(d_gen_input)
        gd_fake_output = D(g_fake_output, train_y)
        g_error = criterion(gd_fake_output, real_labels)
        g_error.backward()
        optimizerG.step()

        lossD = d_train_loss
        lossG = g_error

    # Losses reported are those of the last batch of the epoch.
    G.eval()
    print('epoch %d' % epoch, 'lossD', lossD.item(), 'lossG', lossG.item())

    # Sampling needs no autograd graph.
    with torch.no_grad():
        test_noise = torch.from_numpy(get_noise(10, n_noise)).to(device).float()
        test_noise = torch.cat((test_noise, test_label), 1)
        image = G(test_noise)

    fig, ax = plt.subplots(1, 10, figsize=(15, 4))
    for i in range(10):
        ax[i].set_axis_off()
        ax[i].imshow(image[i].cpu().numpy().reshape(28, 28), cmap='gray')
    fig.savefig(
        os.path.join(sample_dir, '{}.png'.format(str(epoch + sample_index_offset).zfill(3))),
        bbox_inches='tight',
    )
    plt.close(fig)  # close explicitly so figures don't accumulate across epochs

end = time.time()

print('Elapsed Time : %.3f sec' % (end - start))