In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable


In [2]:
# Hyper-parameters
latent_size = 100
number_g_size =64
number_d_size=64
num_epochs = 200
batch_size = 200
sample_dir = 'samples/SGAN/'
saver_dir = 'saved_data/SGAN/'


# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
    print("created folder")
if not os.path.exists(saver_dir):
    os.makedirs(saver_dir)
    print("created folder")

In [3]:
#Image processing
transform = transforms.Compose([
                transforms.Resize((32,32)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
                                     std=(0.5, 0.5, 0.5))])
train_dataset=torchvision.datasets.MNIST(root='../../data',
                                         train=True,
                                         transform=transform)
# Data loader (input pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
print(len(train_loader))

300


In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.layer1=nn.Sequential(
            #layer1 32*32
            nn.Conv2d(1,number_d_size,
                      4, 2, 1,bias=False),   
#             nn.BatchNorm2d(number_d_size),
            nn.LeakyReLU(0.2,inplace=True),
            
            
            #layer3 16*16
            nn.Conv2d(number_d_size,number_d_size*2,
                      4, 2, 1, bias=False),   
            nn.BatchNorm2d(number_d_size*2),
            nn.LeakyReLU(0.2,inplace=True),
            
            #layer4 8*8
            nn.Conv2d(number_d_size*2,number_d_size*4,
                      4, 2, 1, bias=False),   
            nn.BatchNorm2d(number_d_size*4),
            nn.LeakyReLU(0.2,inplace=True),
            
        )
        self.adv_layer=nn.Sequential(
            
            nn.Linear(number_d_size*4*4*4,1),
            nn.Sigmoid())
        self.aux_layer=nn.Sequential(
            nn.Linear(number_d_size*4*4*4,11),
            nn.Softmax())
    def forward(self,x):
        out=self.layer1(x)
        out=out.view(batch_size,-1)
        validity=self.adv_layer(out)
        label=self.aux_layer(out)
        return validity,label
    
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.layer1=nn.Sequential(
            #4*4
            nn.ConvTranspose2d(latent_size,number_g_size*4,
                               4, 1, 0, bias=False),
            nn.BatchNorm2d(number_g_size*4),
            nn.ReLU(),
            #8*8
            nn.ConvTranspose2d(number_g_size*4,number_g_size*2,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(number_g_size*2),
            nn.ReLU(),
            #16*16
            nn.ConvTranspose2d(number_g_size*2,number_g_size,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(number_g_size),
            nn.ReLU(),
            #32*32
            nn.ConvTranspose2d(number_g_size,1,
                               4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self,x):
        out=self.layer1(x)
        return out


In [5]:
discriminator=Discriminator().cuda()
generator=Generator().cuda()
#loss function
adversarial_loss = torch.nn.BCELoss().cuda()
auxiliary_loss = torch.nn.CrossEntropyLoss().cuda()
# setup optimizer
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
optimizerG = torch.optim.Adam(generator.parameters(), lr=2e-4)

In [6]:
def denorm(x):
    out=(x+1)/2
    return out.clamp(0,1)
def reset_grad():
    optimizerD.zero_grad()
    optimizerG.zero_grad()


In [7]:
import numpy as np
import time
# optimizerD.load_state_dict(torch.load('./saved_data/cat_dcgan/D_cifar_151.ckpt'))
# optimizerG.load_state_dict(torch.load('./saved_data/cat_dcgan/G_cifar_151.ckpt'))
start=time.time()
total_step=len(train_loader)
for epochs in range(num_epochs):
    for i, (images,labels)in enumerate(train_loader):
        images=Variable(images).cuda()
        labels=Variable(labels.type(torch.LongTensor)).cuda()
        
        real_labels =Variable(torch.FloatTensor(batch_size, 1).fill_(1.0)).cuda()
        fake_labels =Variable(torch.FloatTensor(batch_size, 1).fill_(0.0)).cuda()
        fake_aux_gt =Variable(torch.LongTensor(batch_size).fill_(10)).cuda()
        
        ##generator##
        optimizerG.zero_grad()
        fixed_noise = torch.randn(batch_size,latent_size , 1, 1).cuda()
        fake_images = generator.forward(fixed_noise)
        validity,_ = discriminator.forward(fake_images)
#         print(validity.shape,real_labels.shape)

        #생성한 이미지를 참으로 받아들이도록 학습
        g_loss = adversarial_loss(validity, real_labels)
        
        # Backprop and optimize
        g_loss.backward(retain_graph=True )
        optimizerG.step()
        
        
        ##discriminator##
        optimizerD.zero_grad()
        
        real_pred,real_aux=discriminator(images)
#         print("real image : ",real_aux.type(),labels.type())
        
        d_real_loss =  (adversarial_loss(real_pred, real_labels) +auxiliary_loss(real_aux, labels)) / 2
        # Loss for fake images
        fake_pred, fake_aux = discriminator(fake_images)
#         print("fake image : ",fake_pred.shape,fake_aux.shape,fake_aux_gt.shape)
        d_fake_loss =  (adversarial_loss(fake_pred, fake_labels) +auxiliary_loss(fake_aux, fake_aux_gt)) / 2
        
        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward(retain_graph=True)
        optimizerD.step()
        
        
        
        
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                  .format(epochs, num_epochs, i+1, total_step, d_loss.item(), g_loss.item()))
    
    # Save real images
    if(epochs+1) == 1:
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
#     fake_images = fake_images.reshape(fake_images.size(0), 3, 32, 32)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epochs+1)))
    if epochs%50==0:        
        # Save the model checkpoints 
        torch.save(optimizerG.state_dict(), saver_dir+'/G_sgan_{}.ckpt'.format(epochs+1))
        torch.save(optimizerD.state_dict(), saver_dir+'/D_sgan_{}.ckpt'.format(epochs+1)) 
finished=time.time()
hours=finished-start
print("training finished! %d minutes"%hours)

  input = module(input)


Epoch [0/200], Step [200/300], d_loss: 0.8456, g_loss: 4.2643
Epoch [1/200], Step [200/300], d_loss: 0.7939, g_loss: 5.7963
Epoch [2/200], Step [200/300], d_loss: 0.7901, g_loss: 6.0817
Epoch [3/200], Step [200/300], d_loss: 0.7953, g_loss: 4.6891
Epoch [4/200], Step [200/300], d_loss: 0.7829, g_loss: 5.3314
Epoch [5/200], Step [200/300], d_loss: 0.7857, g_loss: 4.8865
Epoch [6/200], Step [200/300], d_loss: 0.7792, g_loss: 6.7758
Epoch [7/200], Step [200/300], d_loss: 0.7956, g_loss: 4.3497
Epoch [8/200], Step [200/300], d_loss: 0.7882, g_loss: 5.2386
Epoch [9/200], Step [200/300], d_loss: 0.7984, g_loss: 3.5530
Epoch [10/200], Step [200/300], d_loss: 0.7927, g_loss: 5.2008
Epoch [11/200], Step [200/300], d_loss: 0.8236, g_loss: 2.8420
Epoch [12/200], Step [200/300], d_loss: 0.8040, g_loss: 4.7949
Epoch [13/200], Step [200/300], d_loss: 0.8302, g_loss: 2.6568
Epoch [14/200], Step [200/300], d_loss: 0.8965, g_loss: 4.4326
Epoch [15/200], Step [200/300], d_loss: 0.8118, g_loss: 4.5898
Ep

Epoch [130/200], Step [200/300], d_loss: 7.9209, g_loss: 0.0000
Epoch [131/200], Step [200/300], d_loss: 7.9206, g_loss: 0.0000
Epoch [132/200], Step [200/300], d_loss: 7.9193, g_loss: 0.0000
Epoch [133/200], Step [200/300], d_loss: 7.9186, g_loss: 0.0000
Epoch [134/200], Step [200/300], d_loss: 7.9187, g_loss: 0.0000
Epoch [135/200], Step [200/300], d_loss: 7.9184, g_loss: 0.0000
Epoch [136/200], Step [200/300], d_loss: 7.9180, g_loss: 0.0000
Epoch [137/200], Step [200/300], d_loss: 7.9183, g_loss: 0.0000
Epoch [138/200], Step [200/300], d_loss: 7.9192, g_loss: 0.0000
Epoch [139/200], Step [200/300], d_loss: 0.8308, g_loss: 4.7578
Epoch [140/200], Step [200/300], d_loss: 0.7879, g_loss: 4.4261
Epoch [141/200], Step [200/300], d_loss: 0.7884, g_loss: 4.0265
Epoch [142/200], Step [200/300], d_loss: 0.7839, g_loss: 5.3882
Epoch [143/200], Step [200/300], d_loss: 0.7779, g_loss: 5.5036
Epoch [144/200], Step [200/300], d_loss: 0.7787, g_loss: 6.0919
Epoch [145/200], Step [200/300], d_loss: