In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import numpy as np
import time


In [2]:
# ---- Hyper-parameters and model dimensions ----
n_epochs=200
batch_size=100
lr=0.0002
latent_size=128     # length of the noise vector fed to the generator
hidden_size=256     # width of the hidden layers in both networks
n_classes=10        # number of digit classes used for conditioning
img_size=28         # image height/width in pixels
channels=1          # number of image channels

# ---- Output locations ----
sample_dir = './samples/cgan/'
saved_dir='./saved_data/cgan/'

# Create each output directory if it does not exist yet and report what happened.
for directory in (sample_dir, saved_dir):
    if not os.path.exists(directory):
        os.makedirs(directory)
        print("folders created!")
    else:
        print("folders not created!")

# ---- Image processing ----
# Scale pixels from [0, 1] to [-1, 1]. The mean/std tuples must have one entry
# per image channel; the images here have `channels` (= 1) channel, and a
# 3-entry tuple does not broadcast onto a single-channel tensor.
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,) * channels,
                                     std=(0.5,) * channels)])

# MNIST dataset (images and labels)
train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=transform,
                                           download=True)

# Data loader (input pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
def denorm(x):
    """Undo the [-1, 1] normalisation: shift/scale ``x`` by (x + 1) / 2 and clamp to [0, 1]."""
    return x.add(1).div(2).clamp(0, 1)


folders not created!
folders not created!


In [3]:
class Generator(nn.Module):
    """Conditional MLP generator.

    Concatenates a noise vector with a one-hot class label and maps the result
    through three fully connected layers to a square image whose values lie in
    [-1, 1] (the final activation is ``Tanh``).

    Args:
        latent_dim: length of the noise vector. Defaults to the notebook-level
            ``latent_size``.
        hidden_dim: width of the hidden layers. Defaults to ``hidden_size``.
        num_classes: length of the one-hot label. Defaults to ``n_classes``.
        image_size: height/width of the generated image. Defaults to
            ``img_size``.
    """

    def __init__(self, latent_dim=None, hidden_dim=None, num_classes=None,
                 image_size=None):
        super(Generator, self).__init__()
        # Fall back to the notebook configuration for anything not given.
        latent_dim = latent_size if latent_dim is None else latent_dim
        hidden_dim = hidden_size if hidden_dim is None else hidden_dim
        num_classes = n_classes if num_classes is None else num_classes
        self.image_size = img_size if image_size is None else image_size

        self.layer1 = nn.Sequential(
            nn.Linear(latent_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.image_size * self.image_size),
            nn.Tanh())

    def forward(self, x, labels):
        """Generate images from noise ``x`` conditioned on one-hot ``labels``.

        Both inputs are concatenated along the last dimension, so they must
        share the same leading (batch) dimension.

        Returns:
            Tensor of shape ``(x.size(0), -1, image_size, image_size)``.
        """
        gen_input = torch.cat((x, labels), -1)
        out = self.layer1(gen_input)
        # Use the actual batch size of the input rather than the global
        # `batch_size`, so batches of any size can be generated.
        return out.view(x.size(0), -1, self.image_size, self.image_size)


class Discriminator(nn.Module):
    """Conditional MLP discriminator.

    Flattens each image, concatenates it with a one-hot class label and maps
    the result through three fully connected layers to a single value in
    [0, 1] (the final activation is ``Sigmoid``).

    Args:
        image_size: height/width of the input image. Defaults to the
            notebook-level ``img_size``.
        hidden_dim: width of the hidden layers. Defaults to ``hidden_size``.
        num_classes: length of the one-hot label. Defaults to ``n_classes``.
    """

    def __init__(self, image_size=None, hidden_dim=None, num_classes=None):
        super(Discriminator, self).__init__()
        # Fall back to the notebook configuration for anything not given.
        image_size = img_size if image_size is None else image_size
        hidden_dim = hidden_size if hidden_dim is None else hidden_dim
        num_classes = n_classes if num_classes is None else num_classes

        self.layer1 = nn.Sequential(
            nn.Linear(image_size * image_size + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid())

    def forward(self, x, labels):
        """Score images ``x`` conditioned on one-hot ``labels``.

        Returns:
            Tensor of shape ``(x.size(0), 1)`` with values in [0, 1].
        """
        # Flatten using the actual batch size of the input rather than the
        # global `batch_size`, so batches of any size can be scored.
        x = x.view(x.size(0), -1)
        dis_input = torch.cat((x, labels), -1)
        return self.layer1(dis_input)

In [4]:
# Place both networks on the GPU when one is available; otherwise keep them on
# the CPU so the models can still be constructed and inspected there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [5]:
# One Adam optimizer per network, both using the shared learning rate `lr`.
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=lr)
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=lr)

# Binary cross-entropy between the discriminator's output and the 0/1 targets.
loss_fun = nn.BCELoss()

In [6]:
start = time.time()
# To resume training, restore the optimizer states written by this cell, e.g.:
# g_optimizer.load_state_dict(torch.load('./saved_data/cgan/G_cgan_151.ckpt'))
# d_optimizer.load_state_dict(torch.load('./saved_data/cgan/D_cgan_151.ckpt'))
# and the network weights:
# generator.load_state_dict(torch.load('./saved_data/cgan/G_model_cgan_151.ckpt'))
# discriminator.load_state_dict(torch.load('./saved_data/cgan/D_model_cgan_151.ckpt'))
total_step=len(train_loader)

# Run everything on whatever device the generator's parameters live on.
device = next(generator.parameters()).device
# Row k of the identity matrix is the one-hot vector for class k.
class_eye = torch.eye(n_classes, device=device)

# `n_epochs` comes from the configuration cell.
for epochs in range(n_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Size the targets from the batch itself, not the global `batch_size`.
        cur_batch = images.size(0)
        true_labels = torch.ones(cur_batch, 1, device=device)    # target for "real"
        false_labels = torch.zeros(cur_batch, 1, device=device)  # target for "fake"

        images = images.to(device)
        # Convert the ground-truth labels to one-hot vectors.
        one_hot_labels = class_eye[labels.to(device)]

        #------------------------
        #training generator phase
        #------------------------
        g_optimizer.zero_grad()
        # Sample noise.
        z = torch.randn(cur_batch, latent_size, device=device)

        # Sample random class labels for the generated images.
        random_labels = torch.randint(0, n_classes, (cur_batch,), device=device)
        gen_one_hot_labels = class_eye[random_labels]

        generated_images = generator(z, gen_one_hot_labels)
        gen_valid = discriminator(generated_images, gen_one_hot_labels)

        # The generator wants the discriminator to call its images real.
        g_loss = loss_fun(gen_valid, true_labels)
        g_loss.backward()
        g_optimizer.step()

        #------------------------
        #training discriminator phase
        #------------------------
        # Also clears the discriminator gradients accumulated by g_loss.backward().
        d_optimizer.zero_grad()
        real_valid = discriminator(images, one_hot_labels)
        d_real_loss = loss_fun(real_valid, true_labels)

        # detach(): the discriminator update must not backpropagate into the
        # generator, whose weights were just changed in place by
        # g_optimizer.step().
        fake_valid = discriminator(generated_images.detach(), gen_one_hot_labels)
        d_fake_loss = loss_fun(fake_valid, false_labels)

        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

    # Losses of the last batch of the epoch.
    print ("[Epoch %d/%d] [D loss: %f] [G loss: %f]" % (epochs, n_epochs,
                                                            d_loss.item(), g_loss.item()))

    # Map the tanh output from [-1, 1] back to [0, 1] before writing the image;
    # otherwise every value below 0 is clipped to black.
    save_image(denorm(generated_images.detach()), os.path.join(
        sample_dir, 'fake_images-{}.png'.format(epochs+1)))

    # Checkpoint every 50 epochs and once more after the final epoch.
    if epochs % 50 == 0 or epochs == n_epochs - 1:
        # Optimizer states (existing file names).
        torch.save(g_optimizer.state_dict(), os.path.join(
            saved_dir, 'G_cgan_{}.ckpt'.format(epochs+1)))
        torch.save(d_optimizer.state_dict(), os.path.join(
            saved_dir, 'D_cgan_{}.ckpt'.format(epochs+1)))
        # Network weights, without which the trained models cannot be restored.
        torch.save(generator.state_dict(), os.path.join(
            saved_dir, 'G_model_cgan_{}.ckpt'.format(epochs+1)))
        torch.save(discriminator.state_dict(), os.path.join(
            saved_dir, 'D_model_cgan_{}.ckpt'.format(epochs+1)))

finished = time.time()
# time.time() is in seconds; convert before reporting minutes.
elapsed_minutes = (finished - start) / 60
print("training finished! %d minutes" % elapsed_minutes)

[Epoch 0/200] [D loss: 0.185662] [G loss: 2.077805]
[Epoch 1/200] [D loss: 0.122951] [G loss: 1.906410]
[Epoch 2/200] [D loss: 0.264744] [G loss: 2.089813]
[Epoch 3/200] [D loss: 0.414117] [G loss: 1.553129]
[Epoch 4/200] [D loss: 0.194032] [G loss: 3.497908]
[Epoch 5/200] [D loss: 0.148088] [G loss: 2.855737]
[Epoch 6/200] [D loss: 0.219789] [G loss: 2.454251]
[Epoch 7/200] [D loss: 0.454168] [G loss: 1.734352]
[Epoch 8/200] [D loss: 0.175178] [G loss: 3.392596]
[Epoch 9/200] [D loss: 0.070911] [G loss: 3.957847]
[Epoch 10/200] [D loss: 0.077189] [G loss: 3.348700]
[Epoch 11/200] [D loss: 0.169894] [G loss: 4.327690]
[Epoch 12/200] [D loss: 0.109357] [G loss: 3.486545]
[Epoch 13/200] [D loss: 0.117750] [G loss: 5.352796]
[Epoch 14/200] [D loss: 0.135896] [G loss: 3.581517]
[Epoch 15/200] [D loss: 0.038170] [G loss: 4.653923]
[Epoch 16/200] [D loss: 0.100294] [G loss: 3.034500]
[Epoch 17/200] [D loss: 0.091170] [G loss: 3.887155]
[Epoch 18/200] [D loss: 0.177643] [G loss: 4.012461]
[Ep

[Epoch 154/200] [D loss: 0.489371] [G loss: 1.241156]
[Epoch 155/200] [D loss: 0.599540] [G loss: 0.902987]
[Epoch 156/200] [D loss: 0.547441] [G loss: 1.068295]
[Epoch 157/200] [D loss: 0.743011] [G loss: 0.911170]
[Epoch 158/200] [D loss: 0.593401] [G loss: 1.153015]
[Epoch 159/200] [D loss: 0.468596] [G loss: 1.253079]
[Epoch 160/200] [D loss: 0.589373] [G loss: 1.159899]
[Epoch 161/200] [D loss: 0.527665] [G loss: 1.091381]
[Epoch 162/200] [D loss: 0.637500] [G loss: 1.146831]
[Epoch 163/200] [D loss: 0.501745] [G loss: 1.135080]
[Epoch 164/200] [D loss: 0.784317] [G loss: 1.034743]
[Epoch 165/200] [D loss: 0.630397] [G loss: 0.995328]
[Epoch 166/200] [D loss: 0.571397] [G loss: 0.910185]
[Epoch 167/200] [D loss: 0.563520] [G loss: 1.047877]
[Epoch 168/200] [D loss: 0.637534] [G loss: 0.992945]
[Epoch 169/200] [D loss: 0.566276] [G loss: 1.124361]
[Epoch 170/200] [D loss: 0.606888] [G loss: 0.931558]
[Epoch 171/200] [D loss: 0.566181] [G loss: 1.094968]
[Epoch 172/200] [D loss: 0.4

In [7]:

# Scratch example (not executed): build one-hot labels with NumPy and convert them to a tensor.
# targets=np.random.randint(0, n_classes, batch_size)

# one_hot_targets = np.eye(n_classes)[targets]
# int_one_hot=np.int_(one_hot_targets)
# print(int_one_hot.shape)
# tor_one_hot=Variable(torch.LongTensor(int_one_hot))
# print(tor_one_hot.shape,tor_one_hot.type())