In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable


In [4]:
# Hyper-parameters
latent_size = 100
number_g_size = 64
number_d_size=64
num_epochs = 200
batch_size = 100
sample_dir = 'samples/mycifar10_cnn'
saver_dir = 'saved_data/mycifar10_cnn'


# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
    print("created folder")
if not os.path.exists(saver_dir):
    os.makedirs(saver_dir)
    print("created folder")
    
#Image processing
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
                                     std=(0.5, 0.5, 0.5))])
# # MNIST dataset (images and labels)
train_dataset = torchvision.datasets.CIFAR10(root='../../data', 
                                           train=True, 
                                           transform=transform,
                                           download=True)

test_dataset = torchvision.datasets.CIFAR10(root='../../data', 
                                          train=False, 
                                          transform=transform)

# Data loader (input pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
# Data loader (input pipeline)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)

# #lsun dataset

# train_dataset=torchvision.datasets.LSUN(root='../../data',
#                                         transform=transform)
# train_loader=torch.utils.data.DataLoader(train_dataset,
#                                         batch_size=batch_size,
#                                         shuffle=True)

Files already downloaded and verified


In [3]:
class DCGAN_discriminator(nn.Module):
    def __init__(self):
        super(DCGAN_discriminator,self).__init__()
        self.layer1=nn.Sequential(
#             #layer1 64*64
#             nn.Conv2d(3,number_d_size,
#                       4, 2, 1, bias=False),   # batch x 16 x 28 x 28
#             nn.BatchNorm2d(number_d_size),
#             nn.LeakyReLU(0.2,inplace=True),
            
            #layer2 32*32
            nn.Conv2d(3,number_d_size,
                      4, 2, 1,bias=False),   # batch x 16 x 28 x 28
#             nn.BatchNorm2d(number_d_size),
            nn.LeakyReLU(0.2,inplace=True),
            
            
            #layer3 16*16
            nn.Conv2d(number_d_size,number_d_size*2,
                      4, 2, 1, bias=False),   # batch x 16 x 28 x 28
            nn.BatchNorm2d(number_d_size*2),
            nn.LeakyReLU(0.2,inplace=True),
            
            #layer4 8*8
            nn.Conv2d(number_d_size*2,number_d_size*4,
                      4, 2, 1, bias=False),   # batch x 16 x 28 x 28
            nn.BatchNorm2d(number_d_size*4),
            nn.LeakyReLU(0.2,inplace=True),
            
            #4*4
            nn.Conv2d(number_d_size*4,1,
                      4, 1, 0, bias=False),
            nn.Sigmoid()
#             nn.LeakyReLU(0.2,inplace=True)

        )
    def forward(self,x):
        out=self.layer1(x)
        return out
class DCGAN_generator(nn.Module):
    def __init__(self):
        super(DCGAN_generator,self).__init__()
        self.layer1=nn.Sequential(
            #
#             nn.ConvTranspose2d(latent_size,number_g_size*8,
#                                4, 1, 0, bias=False),
#             nn.BatchNorm2d(number_g_size*8),
#             nn.ReLU(),
            #4*4
            nn.ConvTranspose2d(latent_size,number_g_size*4,
                               4, 1, 0, bias=False),
            nn.BatchNorm2d(number_g_size*4),
            nn.ReLU(),
            #8*8
            nn.ConvTranspose2d(number_g_size*4,number_g_size*2,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(number_g_size*2),
            nn.ReLU(),
            #16*16
            nn.ConvTranspose2d(number_g_size*2,number_g_size,
                               4, 2, 1, bias=False),
            nn.BatchNorm2d(number_g_size),
            nn.ReLU(),
            #32*32
            nn.ConvTranspose2d(number_g_size,3,
                               4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self,x):
        out=self.layer1(x)
        return out


In [4]:
discriminator=DCGAN_discriminator().cuda()
generator=DCGAN_generator().cuda()

In [5]:
criterion = nn.BCELoss()
fixed_noise = torch.randn(batch_size,latent_size , 1, 1)

In [6]:
# setup optimizer
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
optimizerG = torch.optim.Adam(generator.parameters(), lr=2e-4)


In [7]:
print(fixed_noise.shape)

torch.Size([100, 100, 1, 1])


In [8]:
def denorm(x):
    out=(x+1)/2
    return out.clamp(0,1)
def reset_grad():
    optimizerD.zero_grad()
    optimizerG.zero_grad()


In [9]:
total_step=len(train_loader)
for epochs in range(num_epochs):
    for i, (images,_)in enumerate(train_loader):
        images=Variable(images).cuda()
#         print("images shape : ",images.shape)
        
#         images=Variable(images).cuda()
        
        real_labels = torch.ones(batch_size, 1).cuda()
        fake_labels = torch.zeros(batch_size, 1).cuda()
        ##discriminator##
        outputs=discriminator.forward(images)
        d_loss_real=criterion(outputs,real_labels)
        real_score=outputs
#         print("discriminator_size : ",outputs.shape)
        
        fixed_noise = torch.randn(batch_size,latent_size , 1, 1).cuda()
#         print("z size : ",fixed_noise.shape)
        fake_images=generator.forward(fixed_noise)
        outputs=discriminator.forward(fake_images)
        d_loss_fake=criterion(outputs,fake_labels)
        fake_score=outputs
#         print("generator size : ",outputs.shape)
        
        d_loss=d_loss_real+d_loss_fake
        reset_grad()
        d_loss.backward()
        optimizerD.step()
        
        
        ##generator##
#         fixed_noise = torch.randn(batch_size,latent_size , 1, 1).cuda()
        fake_images = generator.forward(fixed_noise)
#         print("generator size : ",fake_images.shape)
        outputs = discriminator.forward(fake_images)
#         print("generator size : ",outputs.shape)
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        optimizerG.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epochs, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    
    # Save real images
    if (epochs+1) == 1:
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
#     fake_images = fake_images.reshape(fake_images.size(0), 3, 32, 32)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epochs+1)))
    if epochs%50==0:        
        # Save the model checkpoints 
        torch.save(optimizerG.state_dict(), saver_dir+'/G_cifar_{}.ckpt'.format(epochs+1))
        torch.save(optimizerD.state_dict(), saver_dir+'/D_cifar_{}.ckpt'.format(epochs+1))    
print("training finished!")

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/400], Step [200/500], d_loss: 0.0410, g_loss: 4.7330, D(x): 0.98, D(G(z)): 0.02
Epoch [0/400], Step [400/500], d_loss: 0.0585, g_loss: 5.5828, D(x): 0.96, D(G(z)): 0.01
Epoch [1/400], Step [200/500], d_loss: 0.0104, g_loss: 6.5475, D(x): 1.00, D(G(z)): 0.01
Epoch [1/400], Step [400/500], d_loss: 0.0798, g_loss: 6.1625, D(x): 0.98, D(G(z)): 0.04
Epoch [2/400], Step [200/500], d_loss: 0.0369, g_loss: 7.0229, D(x): 0.99, D(G(z)): 0.02
Epoch [2/400], Step [400/500], d_loss: 0.0394, g_loss: 7.9474, D(x): 0.99, D(G(z)): 0.02
Epoch [3/400], Step [200/500], d_loss: 0.0220, g_loss: 7.5327, D(x): 0.99, D(G(z)): 0.01
Epoch [3/400], Step [400/500], d_loss: 0.0600, g_loss: 7.2458, D(x): 0.97, D(G(z)): 0.02
Epoch [4/400], Step [200/500], d_loss: 0.0815, g_loss: 6.5677, D(x): 0.97, D(G(z)): 0.04
Epoch [4/400], Step [400/500], d_loss: 0.0258, g_loss: 6.7934, D(x): 0.99, D(G(z)): 0.01
Epoch [5/400], Step [200/500], d_loss: 0.0565, g_loss: 6.5097, D(x): 0.99, D(G(z)): 0.04
Epoch [5/400], Step [

Epoch [46/400], Step [200/500], d_loss: 0.2461, g_loss: 4.5910, D(x): 0.91, D(G(z)): 0.10
Epoch [46/400], Step [400/500], d_loss: 0.0616, g_loss: 4.5441, D(x): 0.99, D(G(z)): 0.04
Epoch [47/400], Step [200/500], d_loss: 0.1112, g_loss: 4.6634, D(x): 0.98, D(G(z)): 0.07
Epoch [47/400], Step [400/500], d_loss: 0.1402, g_loss: 5.5572, D(x): 0.93, D(G(z)): 0.03
Epoch [48/400], Step [200/500], d_loss: 0.3327, g_loss: 5.0777, D(x): 0.82, D(G(z)): 0.04
Epoch [48/400], Step [400/500], d_loss: 0.1070, g_loss: 4.9599, D(x): 0.97, D(G(z)): 0.07
Epoch [49/400], Step [200/500], d_loss: 0.1298, g_loss: 3.9724, D(x): 0.94, D(G(z)): 0.06
Epoch [49/400], Step [400/500], d_loss: 0.1069, g_loss: 4.2958, D(x): 0.99, D(G(z)): 0.08
Epoch [50/400], Step [200/500], d_loss: 0.1206, g_loss: 4.9305, D(x): 0.97, D(G(z)): 0.06
Epoch [50/400], Step [400/500], d_loss: 0.1429, g_loss: 4.5340, D(x): 0.98, D(G(z)): 0.08
Epoch [51/400], Step [200/500], d_loss: 0.2723, g_loss: 4.3904, D(x): 0.99, D(G(z)): 0.19
Epoch [51/

Epoch [92/400], Step [200/500], d_loss: 0.2179, g_loss: 3.4205, D(x): 0.99, D(G(z)): 0.16
Epoch [92/400], Step [400/500], d_loss: 0.1698, g_loss: 3.0007, D(x): 0.97, D(G(z)): 0.12
Epoch [93/400], Step [200/500], d_loss: 0.7187, g_loss: 4.5652, D(x): 0.59, D(G(z)): 0.01
Epoch [93/400], Step [400/500], d_loss: 0.1228, g_loss: 4.2269, D(x): 0.99, D(G(z)): 0.08
Epoch [94/400], Step [200/500], d_loss: 0.3756, g_loss: 3.8359, D(x): 0.79, D(G(z)): 0.05
Epoch [94/400], Step [400/500], d_loss: 0.3192, g_loss: 3.0379, D(x): 0.85, D(G(z)): 0.10
Epoch [95/400], Step [200/500], d_loss: 0.4445, g_loss: 2.5260, D(x): 0.95, D(G(z)): 0.26
Epoch [95/400], Step [400/500], d_loss: 0.1479, g_loss: 4.4082, D(x): 0.92, D(G(z)): 0.04
Epoch [96/400], Step [200/500], d_loss: 0.0701, g_loss: 6.0685, D(x): 0.94, D(G(z)): 0.01
Epoch [96/400], Step [400/500], d_loss: 0.1707, g_loss: 3.5069, D(x): 0.96, D(G(z)): 0.11
Epoch [97/400], Step [200/500], d_loss: 0.3762, g_loss: 3.0256, D(x): 0.98, D(G(z)): 0.24
Epoch [97/

Epoch [137/400], Step [400/500], d_loss: 0.5307, g_loss: 2.4394, D(x): 0.76, D(G(z)): 0.13
Epoch [138/400], Step [200/500], d_loss: 0.1992, g_loss: 4.2344, D(x): 0.98, D(G(z)): 0.13
Epoch [138/400], Step [400/500], d_loss: 0.4060, g_loss: 3.1724, D(x): 1.00, D(G(z)): 0.25
Epoch [139/400], Step [200/500], d_loss: 0.1545, g_loss: 4.5594, D(x): 0.89, D(G(z)): 0.02
Epoch [139/400], Step [400/500], d_loss: 0.1464, g_loss: 4.1436, D(x): 0.95, D(G(z)): 0.07
Epoch [140/400], Step [200/500], d_loss: 0.0705, g_loss: 6.1425, D(x): 0.99, D(G(z)): 0.05
Epoch [140/400], Step [400/500], d_loss: 0.0502, g_loss: 5.0146, D(x): 0.98, D(G(z)): 0.03
Epoch [141/400], Step [200/500], d_loss: 0.0886, g_loss: 3.8748, D(x): 0.97, D(G(z)): 0.05
Epoch [141/400], Step [400/500], d_loss: 0.1263, g_loss: 2.7786, D(x): 0.99, D(G(z)): 0.10
Epoch [142/400], Step [200/500], d_loss: 0.0563, g_loss: 4.6043, D(x): 0.97, D(G(z)): 0.02
Epoch [142/400], Step [400/500], d_loss: 0.1065, g_loss: 4.4406, D(x): 0.96, D(G(z)): 0.05

Epoch [183/400], Step [200/500], d_loss: 0.1665, g_loss: 4.6590, D(x): 0.91, D(G(z)): 0.05
Epoch [183/400], Step [400/500], d_loss: 0.0645, g_loss: 5.0548, D(x): 0.99, D(G(z)): 0.05
Epoch [184/400], Step [200/500], d_loss: 0.0718, g_loss: 5.2124, D(x): 0.96, D(G(z)): 0.03
Epoch [184/400], Step [400/500], d_loss: 0.1041, g_loss: 3.1353, D(x): 1.00, D(G(z)): 0.08
Epoch [185/400], Step [200/500], d_loss: 0.0811, g_loss: 4.4204, D(x): 0.97, D(G(z)): 0.04
Epoch [185/400], Step [400/500], d_loss: 0.2371, g_loss: 4.8199, D(x): 0.85, D(G(z)): 0.04
Epoch [186/400], Step [200/500], d_loss: 0.1683, g_loss: 3.7665, D(x): 0.98, D(G(z)): 0.12
Epoch [186/400], Step [400/500], d_loss: 0.2283, g_loss: 5.2294, D(x): 0.84, D(G(z)): 0.02
Epoch [187/400], Step [200/500], d_loss: 0.0444, g_loss: 5.0190, D(x): 0.99, D(G(z)): 0.03
Epoch [187/400], Step [400/500], d_loss: 0.1457, g_loss: 3.5495, D(x): 0.97, D(G(z)): 0.09
Epoch [188/400], Step [200/500], d_loss: 0.1677, g_loss: 4.2696, D(x): 0.96, D(G(z)): 0.09

Epoch [228/400], Step [400/500], d_loss: 0.0870, g_loss: 5.0853, D(x): 0.98, D(G(z)): 0.05
Epoch [229/400], Step [200/500], d_loss: 0.0774, g_loss: 6.1350, D(x): 0.98, D(G(z)): 0.04
Epoch [229/400], Step [400/500], d_loss: 0.0698, g_loss: 6.3270, D(x): 0.98, D(G(z)): 0.04
Epoch [230/400], Step [200/500], d_loss: 0.0286, g_loss: 6.1134, D(x): 1.00, D(G(z)): 0.02
Epoch [230/400], Step [400/500], d_loss: 0.3201, g_loss: 4.5353, D(x): 0.82, D(G(z)): 0.05
Epoch [231/400], Step [200/500], d_loss: 0.0719, g_loss: 5.4487, D(x): 1.00, D(G(z)): 0.05
Epoch [231/400], Step [400/500], d_loss: 0.0718, g_loss: 7.0675, D(x): 0.94, D(G(z)): 0.01
Epoch [232/400], Step [200/500], d_loss: 0.0587, g_loss: 5.3347, D(x): 0.97, D(G(z)): 0.03
Epoch [232/400], Step [400/500], d_loss: 0.0216, g_loss: 5.3382, D(x): 1.00, D(G(z)): 0.02
Epoch [233/400], Step [200/500], d_loss: 0.0501, g_loss: 5.4909, D(x): 0.99, D(G(z)): 0.03
Epoch [233/400], Step [400/500], d_loss: 0.0757, g_loss: 6.6057, D(x): 0.94, D(G(z)): 0.01

Epoch [274/400], Step [200/500], d_loss: 0.0782, g_loss: 6.6528, D(x): 0.98, D(G(z)): 0.05
Epoch [274/400], Step [400/500], d_loss: 0.0729, g_loss: 6.1574, D(x): 0.98, D(G(z)): 0.04
Epoch [275/400], Step [200/500], d_loss: 0.1622, g_loss: 5.3277, D(x): 0.99, D(G(z)): 0.09
Epoch [275/400], Step [400/500], d_loss: 0.1697, g_loss: 5.0810, D(x): 0.90, D(G(z)): 0.02
Epoch [276/400], Step [200/500], d_loss: 0.0845, g_loss: 7.2923, D(x): 0.96, D(G(z)): 0.02
Epoch [276/400], Step [400/500], d_loss: 0.0704, g_loss: 4.5869, D(x): 0.97, D(G(z)): 0.03
Epoch [277/400], Step [200/500], d_loss: 0.0557, g_loss: 6.4630, D(x): 0.96, D(G(z)): 0.01
Epoch [277/400], Step [400/500], d_loss: 0.0275, g_loss: 6.6459, D(x): 0.98, D(G(z)): 0.01
Epoch [278/400], Step [200/500], d_loss: 0.0942, g_loss: 5.8897, D(x): 0.93, D(G(z)): 0.01
Epoch [278/400], Step [400/500], d_loss: 0.0746, g_loss: 6.8237, D(x): 0.95, D(G(z)): 0.01
Epoch [279/400], Step [200/500], d_loss: 0.0731, g_loss: 6.8514, D(x): 0.96, D(G(z)): 0.02

Epoch [319/400], Step [400/500], d_loss: 0.0428, g_loss: 6.3658, D(x): 0.99, D(G(z)): 0.03
Epoch [320/400], Step [200/500], d_loss: 0.0661, g_loss: 5.2937, D(x): 0.96, D(G(z)): 0.01
Epoch [320/400], Step [400/500], d_loss: 0.2618, g_loss: 5.0478, D(x): 0.96, D(G(z)): 0.12
Epoch [321/400], Step [200/500], d_loss: 0.0423, g_loss: 6.5748, D(x): 0.99, D(G(z)): 0.03
Epoch [321/400], Step [400/500], d_loss: 0.0367, g_loss: 7.5696, D(x): 0.97, D(G(z)): 0.01
Epoch [322/400], Step [200/500], d_loss: 0.0265, g_loss: 7.2724, D(x): 0.98, D(G(z)): 0.01
Epoch [322/400], Step [400/500], d_loss: 0.0964, g_loss: 6.0697, D(x): 1.00, D(G(z)): 0.06
Epoch [323/400], Step [200/500], d_loss: 0.0964, g_loss: 5.2505, D(x): 0.96, D(G(z)): 0.04
Epoch [323/400], Step [400/500], d_loss: 0.2147, g_loss: 5.3448, D(x): 0.87, D(G(z)): 0.03
Epoch [324/400], Step [200/500], d_loss: 0.0418, g_loss: 7.0376, D(x): 0.99, D(G(z)): 0.02
Epoch [324/400], Step [400/500], d_loss: 0.0707, g_loss: 9.0977, D(x): 0.97, D(G(z)): 0.02

Epoch [365/400], Step [200/500], d_loss: 0.0990, g_loss: 5.0029, D(x): 0.99, D(G(z)): 0.05
Epoch [365/400], Step [400/500], d_loss: 0.0208, g_loss: 7.1473, D(x): 0.98, D(G(z)): 0.00
Epoch [366/400], Step [200/500], d_loss: 0.0249, g_loss: 7.7552, D(x): 0.99, D(G(z)): 0.01
Epoch [366/400], Step [400/500], d_loss: 0.0374, g_loss: 6.0115, D(x): 0.99, D(G(z)): 0.03
Epoch [367/400], Step [200/500], d_loss: 0.0726, g_loss: 5.9502, D(x): 0.98, D(G(z)): 0.05
Epoch [367/400], Step [400/500], d_loss: 0.0152, g_loss: 7.8761, D(x): 0.99, D(G(z)): 0.00
Epoch [368/400], Step [200/500], d_loss: 0.0151, g_loss: 9.1675, D(x): 0.99, D(G(z)): 0.00
Epoch [368/400], Step [400/500], d_loss: 0.0929, g_loss: 6.9993, D(x): 0.96, D(G(z)): 0.03
Epoch [369/400], Step [200/500], d_loss: 0.0788, g_loss: 5.1233, D(x): 0.96, D(G(z)): 0.03
Epoch [369/400], Step [400/500], d_loss: 0.1371, g_loss: 5.5916, D(x): 0.91, D(G(z)): 0.02
Epoch [370/400], Step [200/500], d_loss: 0.0204, g_loss: 6.6696, D(x): 0.99, D(G(z)): 0.01