In [1]:
from __future__ import print_function
import argparse
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from utils import weights_init, compute_acc
from network import _netG, _netD, _netD_CIFAR10, _netG_CIFAR10
from folder import ImageFolder
import matplotlib.pyplot as plt
import matplotlib.colors as colors

In [2]:
# ---- Data ---------------------------------------------------------------------
dataset, dataroot = 'cifar10', 'data'
imageSize = 32
num_classes = 10
batchSize = 64
workers = 2

# ---- Model sizes --------------------------------------------------------------
nz = 110          # latent length; the first num_classes entries carry a one-hot class code
ngf = ndf = 64
ngpu = 1

# ---- Optimisation -------------------------------------------------------------
niter = 200
lr = 2e-4
beta1 = 0.5
cuda = True

# ---- Output / checkpoints -----------------------------------------------------
outf = './output2'
netG_ckpt = outf + '/netG.pth'
netD_ckpt = outf + '/netD.pth'
manualSeed = None  # None -> a random seed is drawn (and printed) later

# Create the output directory and seed every RNG the notebook draws from, so a run
# can be reproduced from its printed seed. `manualSeed` comes from the config
# settings and is deliberately NOT reset to None here (that discarded any seed
# the user had configured).
os.makedirs(outf, exist_ok=True)  # unlike `except OSError: pass`, real errors (e.g. permissions) surface

if manualSeed is None:
    manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
np.random.seed(manualSeed)  # class labels and latent noise are sampled with np.random
torch.manual_seed(manualSeed)
if cuda:
    torch.cuda.manual_seed_all(manualSeed)

# Lets cuDNN pick the fastest convolution algorithms; note this can make runs
# non-deterministic even with fixed seeds.
cudnn.benchmark = True


# CIFAR-10 training set. ToTensor yields pixels in [0, 1]; Normalize with mean 0.5
# and std 0.5 per channel maps them to [-1, 1]. transforms.Resize replaces the
# deprecated (and in newer torchvision releases removed) transforms.Scale.
dataset = dset.CIFAR10(root=dataroot, download=True,
                       transform=transforms.Compose([
                           transforms.Resize(imageSize),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                       ]))

assert dataset  # fails if the dataset is empty
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                                         shuffle=True, num_workers=int(workers))


Random Seed:  6747


  "please use transforms.Resize instead.")


Files already downloaded and verified


In [3]:
# Build the generator / discriminator pair with freshly initialised weights, then
# resume from saved checkpoints when they exist. A missing checkpoint (e.g. on the
# very first run) no longer aborts the notebook; it is reported and skipped.
netG = _netG_CIFAR10(ngpu, nz)
netG.apply(weights_init)
netD = _netD_CIFAR10(ngpu, num_classes)
netD.apply(weights_init)

for net, ckpt_path in ((netG, netG_ckpt), (netD, netD_ckpt)):
    if not ckpt_path:
        continue  # empty path: checkpoint loading disabled
    if os.path.exists(ckpt_path):
        # map_location='cpu' lets a GPU-saved checkpoint load without CUDA; the
        # networks are moved to the GPU (when enabled) in a later cell.
        net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    else:
        print('Checkpoint %s not found; starting from freshly initialised weights.' % ckpt_path)


In [4]:
# Losses: BCE on D's real/fake probability, NLL on its auxiliary class head.
# NOTE(review): NLLLoss expects log-probabilities; confirm netD's aux head applies log-softmax.
aux_criterion = nn.NLLLoss()
dis_criterion = nn.BCELoss()

# Preallocated (uninitialised) buffers, resized and refilled in place by the training loop.
input = torch.empty(batchSize, 3, imageSize, imageSize, dtype=torch.float32)
noise = torch.empty(batchSize, nz, 1, 1, dtype=torch.float32)
eval_noise = torch.empty(batchSize, nz, 1, 1, dtype=torch.float32).normal_(0, 1)
dis_label = torch.empty(batchSize, dtype=torch.float32)
aux_label = torch.empty(batchSize, dtype=torch.int64)

# Target values for the adversarial (real vs. fake) loss.
real_label, fake_label = 1, 0

In [5]:
# Move models, losses and buffers to the GPU only when CUDA is enabled. This
# matches the training loop, which also checks `cuda` before moving batches, so
# the notebook can still run on a CPU-only machine with cuda = False.
if cuda:
    netD.cuda()
    netG.cuda()
    dis_criterion.cuda()
    aux_criterion.cuda()
    input, dis_label, aux_label = input.cuda(), dis_label.cuda(), aux_label.cuda()
    noise, eval_noise = noise.cuda(), eval_noise.cuda()

In [6]:
# Wrap the buffers in (legacy) autograd Variables as the training loop expects.
input, noise, eval_noise = Variable(input), Variable(noise), Variable(eval_noise)
dis_label, aux_label = Variable(dis_label), Variable(aux_label)

# Fixed evaluation latents: Gaussian noise whose first num_classes entries are
# replaced by a one-hot code of a random class. They are reused every time the
# training loop saves a sample grid.
eval_noise_ = np.random.normal(0, 1, (batchSize, nz))
eval_label = np.random.randint(0, num_classes, batchSize)
eval_onehot = np.eye(num_classes)[eval_label]
eval_noise_[:, :num_classes] = eval_onehot
eval_noise_ = torch.from_numpy(eval_noise_)
with torch.no_grad():
    eval_noise.data.copy_(eval_noise_.view(batchSize, nz, 1, 1))

# One Adam optimiser per network, same hyper-parameters for both.
optimizerD, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999)) for net in (netD, netG)
)

# Running averages reported by the training loop.
avg_loss_D = avg_loss_G = avg_loss_A = 0.0

In [7]:
for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        #     together with the auxiliary class-prediction loss.
        ###########################
        # train with real
        netD.zero_grad()  # also clears the grads errG.backward() left on netD last iteration
        real_cpu, label = data
        batch_size = real_cpu.size(0)  # the final batch of an epoch may be smaller
        if cuda:
            real_cpu = real_cpu.cuda()
        with torch.no_grad():
            # reuse the preallocated buffers, resized to this batch
            input.resize_as_(real_cpu).copy_(real_cpu)
            dis_label.resize_(batch_size).fill_(real_label)
            aux_label.resize_(batch_size).copy_(label)
        dis_output, aux_output = netD(input)

        dis_errD_real = dis_criterion(dis_output, dis_label)
        aux_errD_real = aux_criterion(aux_output, aux_label)
        errD_real = dis_errD_real + aux_errD_real
        errD_real.backward()
        D_x = dis_output.detach().mean().item()

        # compute the current classification accuracy
        accuracy = compute_acc(aux_output, aux_label)

        # train with fake: the latent vector is Gaussian noise whose first
        # num_classes entries are replaced by a one-hot code of a random class
        with torch.no_grad():
            # resize_ sets the shape; the sampled values are overwritten by the copy below
            noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        label = np.random.randint(0, num_classes, batch_size)
        noise_ = np.random.normal(0, 1, (batch_size, nz))
        class_onehot = np.zeros((batch_size, num_classes))
        class_onehot[np.arange(batch_size), label] = 1
        noise_[np.arange(batch_size), :num_classes] = class_onehot[np.arange(batch_size)]
        noise_ = torch.from_numpy(noise_)
        with torch.no_grad():
            noise.data.copy_(noise_.view(batch_size, nz, 1, 1))
            aux_label.resize_(batch_size).copy_(torch.from_numpy(label))
            dis_label.fill_(fake_label)

        fake = netG(noise)
        # detach: the discriminator step must not backpropagate into G
        dis_output, aux_output = netD(fake.detach())
        dis_errD_fake = dis_criterion(dis_output, dis_label)
        aux_errD_fake = aux_criterion(aux_output, aux_label)
        errD_fake = dis_errD_fake + aux_errD_fake
        errD_fake.backward()
        D_G_z1 = dis_output.detach().mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        #     and make D predict the class G was asked for (aux_label still holds it).
        ###########################
        netG.zero_grad()
        dis_label.data.fill_(real_label)  # fake labels are real for generator cost
        dis_output, aux_output = netD(fake)
        dis_errG = dis_criterion(dis_output, dis_label)
        aux_errG = aux_criterion(aux_output, aux_label)
        errG = dis_errG + aux_errG
        errG.backward()
        D_G_z2 = dis_output.detach().mean().item()
        optimizerG.step()

        # Running means over every iteration so far. Python floats via .item(),
        # so no tensors are kept alive just for logging.
        curr_iter = epoch * len(dataloader) + i
        errD_val, errG_val = errD.item(), errG.item()
        avg_loss_D = (avg_loss_D * curr_iter + errD_val) / (curr_iter + 1)
        avg_loss_G = (avg_loss_G * curr_iter + errG_val) / (curr_iter + 1)
        avg_loss_A = (avg_loss_A * curr_iter + accuracy) / (curr_iter + 1)

        print('[%d/%d][%d/%d] Loss_D: %.4f (%.4f) Loss_G: %.4f (%.4f) D(x): %.4f D(G(z)): %.4f / %.4f Acc: %.4f (%.4f)'
              % (epoch, niter, i, len(dataloader),
                 errD_val, avg_loss_D, errG_val, avg_loss_G, D_x, D_G_z1, D_G_z2, accuracy, avg_loss_A))
        if i % 100 == 0:
            # Real images are normalised to [-1, 1] by the dataset transform, so
            # save_image must rescale them (normalize=True); otherwise every
            # negative pixel is clipped to 0. The same applies to G's samples,
            # which are trained to match that data range.
            vutils.save_image(real_cpu, '%s/real_samples.png' % outf, normalize=True)
            print('Label for eval = {}'.format(eval_label))
            with torch.no_grad():  # sampling only: no autograd graph needed
                eval_fake = netG(eval_noise)
            vutils.save_image(
                eval_fake,
                '%s/fake_samples_epoch_%03d.png' % (outf, epoch),
                normalize=True,
            )

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG.pth' % (outf))
    torch.save(netD.state_dict(), '%s/netD.pth' % (outf))

  classes = self.softmax(fc_aux)


[0/200][0/782] Loss_D: 1.2223 (1.2223) Loss_G: 0.6493 (0.6493) D(x): 0.4839 D(G(z)): 0.4740 / 0.4854 Acc: 10.9375 (10.9375)
Label for eval = [8 0 5 7 3 1 1 3 6 8 1 5 2 5 4 7 8 4 7 1 6 1 8 4 0 0 4 9 9 6 9 9 9 9 1 6 8
 4 7 7 0 1 6 7 8 5 8 9 7 3 5 7 9 0 8 0 4 0 4 9 1 1 0 2]
[0/200][1/782] Loss_D: 1.3637 (1.2930) Loss_G: 0.6573 (0.6533) D(x): 0.4432 D(G(z)): 0.4933 / 0.4897 Acc: 14.0625 (12.5000)
[0/200][2/782] Loss_D: 1.2296 (1.2718) Loss_G: 0.5942 (0.6336) D(x): 0.4999 D(G(z)): 0.4706 / 0.5185 Acc: 6.2500 (10.4167)
[0/200][3/782] Loss_D: 1.2018 (1.2543) Loss_G: 0.6878 (0.6471) D(x): 0.5381 D(G(z)): 0.5113 / 0.4740 Acc: 10.9375 (10.5469)
[0/200][4/782] Loss_D: 1.1671 (1.2369) Loss_G: 0.7248 (0.6627) D(x): 0.5110 D(G(z)): 0.4725 / 0.4597 Acc: 7.8125 (10.0000)
[0/200][5/782] Loss_D: 1.1239 (1.2181) Loss_G: 0.7567 (0.6783) D(x): 0.5155 D(G(z)): 0.4570 / 0.4460 Acc: 10.9375 (10.1562)
[0/200][6/782] Loss_D: 1.2276 (1.2194) Loss_G: 0.7516 (0.6888) D(x): 0.4924 D(G(z)): 0.4702 / 0.4505 Acc: 6.25

[0/200][65/782] Loss_D: 0.9368 (1.1310) Loss_G: 1.0523 (0.8118) D(x): 0.5278 D(G(z)): 0.3662 / 0.3514 Acc: 18.7500 (11.3400)
[0/200][66/782] Loss_D: 0.8762 (1.1272) Loss_G: 1.1791 (0.8173) D(x): 0.5941 D(G(z)): 0.3805 / 0.3176 Acc: 9.3750 (11.3106)
[0/200][67/782] Loss_D: 0.8134 (1.1226) Loss_G: 1.0396 (0.8206) D(x): 0.5865 D(G(z)): 0.3730 / 0.3596 Acc: 20.3125 (11.4430)
[0/200][68/782] Loss_D: 0.9311 (1.1198) Loss_G: 1.1486 (0.8253) D(x): 0.6144 D(G(z)): 0.4271 / 0.3153 Acc: 10.9375 (11.4357)
[0/200][69/782] Loss_D: 1.1190 (1.1198) Loss_G: 1.1575 (0.8301) D(x): 0.5186 D(G(z)): 0.3948 / 0.3216 Acc: 9.3750 (11.4063)
[0/200][70/782] Loss_D: 1.0520 (1.1188) Loss_G: 1.1093 (0.8340) D(x): 0.5210 D(G(z)): 0.3927 / 0.3423 Acc: 12.5000 (11.4217)
[0/200][71/782] Loss_D: 0.8930 (1.1157) Loss_G: 1.0759 (0.8374) D(x): 0.5709 D(G(z)): 0.3995 / 0.3490 Acc: 15.6250 (11.4800)
[0/200][72/782] Loss_D: 1.0797 (1.1152) Loss_G: 1.1654 (0.8419) D(x): 0.5296 D(G(z)): 0.4170 / 0.3152 Acc: 12.5000 (11.4940)
[0

[0/200][133/782] Loss_D: 0.9499 (1.0065) Loss_G: 0.8634 (1.1078) D(x): 0.6602 D(G(z)): 0.4886 / 0.4373 Acc: 10.9375 (13.8643)
[0/200][134/782] Loss_D: 1.2547 (1.0083) Loss_G: 1.0042 (1.1071) D(x): 0.5721 D(G(z)): 0.4997 / 0.4095 Acc: 10.9375 (13.8426)
[0/200][135/782] Loss_D: 1.3457 (1.0108) Loss_G: 0.8752 (1.1054) D(x): 0.4974 D(G(z)): 0.4968 / 0.4416 Acc: 23.4375 (13.9131)
[0/200][136/782] Loss_D: 1.2201 (1.0123) Loss_G: 0.5420 (1.1013) D(x): 0.5746 D(G(z)): 0.5308 / 0.5318 Acc: 15.6250 (13.9256)
[0/200][137/782] Loss_D: 1.2752 (1.0142) Loss_G: 0.5966 (1.0976) D(x): 0.5112 D(G(z)): 0.4829 / 0.5094 Acc: 12.5000 (13.9153)
[0/200][138/782] Loss_D: 1.2931 (1.0162) Loss_G: 0.7624 (1.0952) D(x): 0.5913 D(G(z)): 0.5463 / 0.4796 Acc: 10.9375 (13.8939)
[0/200][139/782] Loss_D: 1.3378 (1.0185) Loss_G: 0.6210 (1.0918) D(x): 0.5109 D(G(z)): 0.5345 / 0.4879 Acc: 17.1875 (13.9174)
[0/200][140/782] Loss_D: 1.3351 (1.0208) Loss_G: 0.6462 (1.0886) D(x): 0.5119 D(G(z)): 0.5329 / 0.4891 Acc: 17.1875 (1

[0/200][201/782] Loss_D: 0.9540 (1.0342) Loss_G: 0.7833 (0.9603) D(x): 0.6027 D(G(z)): 0.5008 / 0.4284 Acc: 21.8750 (15.9189)
[0/200][202/782] Loss_D: 0.9998 (1.0341) Loss_G: 0.9304 (0.9602) D(x): 0.6214 D(G(z)): 0.4747 / 0.4234 Acc: 17.1875 (15.9252)
[0/200][203/782] Loss_D: 0.9639 (1.0337) Loss_G: 0.6728 (0.9588) D(x): 0.5616 D(G(z)): 0.4483 / 0.4688 Acc: 21.8750 (15.9544)
[0/200][204/782] Loss_D: 0.8140 (1.0327) Loss_G: 0.8668 (0.9583) D(x): 0.6520 D(G(z)): 0.4381 / 0.4152 Acc: 20.3125 (15.9756)
[0/200][205/782] Loss_D: 0.9259 (1.0321) Loss_G: 0.7333 (0.9572) D(x): 0.5579 D(G(z)): 0.4241 / 0.4519 Acc: 18.7500 (15.9891)
[0/200][206/782] Loss_D: 0.9504 (1.0317) Loss_G: 0.6289 (0.9556) D(x): 0.5688 D(G(z)): 0.4294 / 0.4809 Acc: 15.6250 (15.9873)
[0/200][207/782] Loss_D: 1.0739 (1.0320) Loss_G: 0.8313 (0.9550) D(x): 0.5996 D(G(z)): 0.4605 / 0.4360 Acc: 15.6250 (15.9856)
[0/200][208/782] Loss_D: 0.8556 (1.0311) Loss_G: 0.7899 (0.9542) D(x): 0.6231 D(G(z)): 0.4638 / 0.4326 Acc: 23.4375 (1

[0/200][270/782] Loss_D: 1.1297 (1.0297) Loss_G: 1.0583 (0.9277) D(x): 0.5590 D(G(z)): 0.4728 / 0.3581 Acc: 12.5000 (16.7205)
[0/200][271/782] Loss_D: 1.2658 (1.0306) Loss_G: 0.8841 (0.9275) D(x): 0.4860 D(G(z)): 0.4933 / 0.4019 Acc: 20.3125 (16.7337)
[0/200][272/782] Loss_D: 1.0554 (1.0307) Loss_G: 0.9756 (0.9277) D(x): 0.5265 D(G(z)): 0.4508 / 0.3830 Acc: 18.7500 (16.7411)
[0/200][273/782] Loss_D: 1.0814 (1.0309) Loss_G: 1.0348 (0.9281) D(x): 0.6258 D(G(z)): 0.5240 / 0.3615 Acc: 14.0625 (16.7313)
[0/200][274/782] Loss_D: 1.0391 (1.0309) Loss_G: 1.1321 (0.9288) D(x): 0.5452 D(G(z)): 0.4597 / 0.3222 Acc: 15.6250 (16.7273)
[0/200][275/782] Loss_D: 1.1359 (1.0313) Loss_G: 0.9205 (0.9288) D(x): 0.4697 D(G(z)): 0.4229 / 0.3841 Acc: 25.0000 (16.7572)
[0/200][276/782] Loss_D: 1.0706 (1.0314) Loss_G: 0.9288 (0.9288) D(x): 0.5220 D(G(z)): 0.4442 / 0.3806 Acc: 17.1875 (16.7588)
[0/200][277/782] Loss_D: 1.0727 (1.0316) Loss_G: 1.0560 (0.9293) D(x): 0.5949 D(G(z)): 0.5251 / 0.3472 Acc: 21.8750 (1

[0/200][337/782] Loss_D: 1.0637 (1.0428) Loss_G: 0.7468 (0.9011) D(x): 0.5171 D(G(z)): 0.4899 / 0.4261 Acc: 20.3125 (17.6359)
[0/200][338/782] Loss_D: 1.0847 (1.0429) Loss_G: 0.7634 (0.9007) D(x): 0.5133 D(G(z)): 0.4924 / 0.4285 Acc: 21.8750 (17.6484)
[0/200][339/782] Loss_D: 0.9119 (1.0425) Loss_G: 0.8033 (0.9004) D(x): 0.5219 D(G(z)): 0.4494 / 0.4070 Acc: 26.5625 (17.6746)
[0/200][340/782] Loss_D: 0.9922 (1.0424) Loss_G: 0.6862 (0.8998) D(x): 0.5209 D(G(z)): 0.4579 / 0.4569 Acc: 21.8750 (17.6870)
[0/200][341/782] Loss_D: 1.2961 (1.0431) Loss_G: 0.8276 (0.8996) D(x): 0.5114 D(G(z)): 0.4899 / 0.4234 Acc: 9.3750 (17.6626)
[0/200][342/782] Loss_D: 1.0706 (1.0432) Loss_G: 0.7802 (0.8992) D(x): 0.5214 D(G(z)): 0.4490 / 0.4549 Acc: 21.8750 (17.6749)
[0/200][343/782] Loss_D: 1.1115 (1.0434) Loss_G: 0.6864 (0.8986) D(x): 0.4679 D(G(z)): 0.4485 / 0.4603 Acc: 20.3125 (17.6826)
[0/200][344/782] Loss_D: 1.0580 (1.0435) Loss_G: 0.8933 (0.8986) D(x): 0.5332 D(G(z)): 0.4581 / 0.4032 Acc: 21.8750 (17

[0/200][402/782] Loss_D: 1.0677 (1.0289) Loss_G: 1.3026 (0.9289) D(x): 0.6047 D(G(z)): 0.5131 / 0.2962 Acc: 21.8750 (17.7419)
[0/200][403/782] Loss_D: 0.9985 (1.0288) Loss_G: 1.3212 (0.9299) D(x): 0.5657 D(G(z)): 0.4051 / 0.2983 Acc: 18.7500 (17.7444)
[0/200][404/782] Loss_D: 0.8739 (1.0284) Loss_G: 1.2066 (0.9306) D(x): 0.6105 D(G(z)): 0.4385 / 0.3092 Acc: 25.0000 (17.7623)
[0/200][405/782] Loss_D: 1.1820 (1.0288) Loss_G: 0.8995 (0.9305) D(x): 0.5287 D(G(z)): 0.4762 / 0.3942 Acc: 15.6250 (17.7571)
[0/200][406/782] Loss_D: 0.9898 (1.0287) Loss_G: 1.0172 (0.9307) D(x): 0.6412 D(G(z)): 0.4830 / 0.3801 Acc: 14.0625 (17.7480)
[0/200][407/782] Loss_D: 0.9087 (1.0284) Loss_G: 1.1723 (0.9313) D(x): 0.6571 D(G(z)): 0.4569 / 0.3262 Acc: 10.9375 (17.7313)
[0/200][408/782] Loss_D: 0.9814 (1.0283) Loss_G: 1.0709 (0.9316) D(x): 0.5617 D(G(z)): 0.4228 / 0.3456 Acc: 10.9375 (17.7147)
[0/200][409/782] Loss_D: 1.0015 (1.0282) Loss_G: 1.0737 (0.9320) D(x): 0.5769 D(G(z)): 0.3617 / 0.3773 Acc: 7.8125 (17

[0/200][470/782] Loss_D: 0.7486 (1.0213) Loss_G: 1.0125 (0.9458) D(x): 0.5587 D(G(z)): 0.3339 / 0.3367 Acc: 17.1875 (17.8377)
[0/200][471/782] Loss_D: 0.7171 (1.0206) Loss_G: 1.4199 (0.9468) D(x): 0.6383 D(G(z)): 0.4301 / 0.2517 Acc: 25.0000 (17.8529)
[0/200][472/782] Loss_D: 0.8043 (1.0202) Loss_G: 1.3848 (0.9477) D(x): 0.5864 D(G(z)): 0.3246 / 0.2688 Acc: 9.3750 (17.8350)
[0/200][473/782] Loss_D: 0.7519 (1.0196) Loss_G: 1.1448 (0.9481) D(x): 0.6461 D(G(z)): 0.3720 / 0.3064 Acc: 10.9375 (17.8204)
[0/200][474/782] Loss_D: 0.7393 (1.0190) Loss_G: 1.1287 (0.9485) D(x): 0.5811 D(G(z)): 0.3281 / 0.3443 Acc: 17.1875 (17.8191)
[0/200][475/782] Loss_D: 0.7835 (1.0185) Loss_G: 1.2553 (0.9491) D(x): 0.7041 D(G(z)): 0.3885 / 0.2958 Acc: 6.2500 (17.7948)
[0/200][476/782] Loss_D: 0.6452 (1.0177) Loss_G: 1.3120 (0.9499) D(x): 0.6701 D(G(z)): 0.3835 / 0.2639 Acc: 10.9375 (17.7804)
[0/200][477/782] Loss_D: 0.7541 (1.0172) Loss_G: 1.0918 (0.9502) D(x): 0.5349 D(G(z)): 0.3201 / 0.2997 Acc: 15.6250 (17.

[0/200][535/782] Loss_D: 1.8989 (1.0082) Loss_G: 0.9642 (0.9669) D(x): 0.3016 D(G(z)): 0.5175 / 0.3705 Acc: 23.4375 (17.8434)
[0/200][536/782] Loss_D: 1.5620 (1.0092) Loss_G: 0.4927 (0.9660) D(x): 0.3550 D(G(z)): 0.5006 / 0.5330 Acc: 15.6250 (17.8393)
[0/200][537/782] Loss_D: 1.4134 (1.0100) Loss_G: 0.7351 (0.9655) D(x): 0.5638 D(G(z)): 0.6032 / 0.4685 Acc: 17.1875 (17.8381)
[0/200][538/782] Loss_D: 1.1039 (1.0102) Loss_G: 1.3891 (0.9663) D(x): 0.5747 D(G(z)): 0.4898 / 0.2563 Acc: 15.6250 (17.8340)
[0/200][539/782] Loss_D: 1.2805 (1.0107) Loss_G: 1.0877 (0.9665) D(x): 0.4005 D(G(z)): 0.3934 / 0.3465 Acc: 17.1875 (17.8328)
[0/200][540/782] Loss_D: 0.9771 (1.0106) Loss_G: 0.9491 (0.9665) D(x): 0.5446 D(G(z)): 0.4419 / 0.3526 Acc: 15.6250 (17.8287)
[0/200][541/782] Loss_D: 0.8887 (1.0104) Loss_G: 1.1333 (0.9668) D(x): 0.5799 D(G(z)): 0.4264 / 0.3353 Acc: 17.1875 (17.8275)
[0/200][542/782] Loss_D: 0.8954 (1.0102) Loss_G: 0.9546 (0.9668) D(x): 0.5617 D(G(z)): 0.4331 / 0.3708 Acc: 20.3125 (1

Label for eval = [8 0 5 7 3 1 1 3 6 8 1 5 2 5 4 7 8 4 7 1 6 1 8 4 0 0 4 9 9 6 9 9 9 9 1 6 8
 4 7 7 0 1 6 7 8 5 8 9 7 3 5 7 9 0 8 0 4 0 4 9 1 1 0 2]
[0/200][601/782] Loss_D: 0.8917 (1.0017) Loss_G: 0.8406 (0.9624) D(x): 0.5983 D(G(z)): 0.4728 / 0.3952 Acc: 21.8750 (18.2283)
[0/200][602/782] Loss_D: 0.7815 (1.0013) Loss_G: 0.9335 (0.9624) D(x): 0.5742 D(G(z)): 0.4310 / 0.3602 Acc: 29.6875 (18.2473)
[0/200][603/782] Loss_D: 0.7823 (1.0010) Loss_G: 1.0680 (0.9625) D(x): 0.5850 D(G(z)): 0.4549 / 0.3393 Acc: 32.8125 (18.2714)
[0/200][604/782] Loss_D: 0.8910 (1.0008) Loss_G: 0.8755 (0.9624) D(x): 0.5606 D(G(z)): 0.4170 / 0.4255 Acc: 21.8750 (18.2774)
[0/200][605/782] Loss_D: 0.7583 (1.0004) Loss_G: 1.4049 (0.9631) D(x): 0.6030 D(G(z)): 0.4000 / 0.2474 Acc: 20.3125 (18.2807)
[0/200][606/782] Loss_D: 0.9165 (1.0003) Loss_G: 1.3153 (0.9637) D(x): 0.5217 D(G(z)): 0.3789 / 0.2751 Acc: 20.3125 (18.2841)
[0/200][607/782] Loss_D: 0.9487 (1.0002) Loss_G: 1.1429 (0.9640) D(x): 0.5574 D(G(z)): 0.3923 / 

[0/200][668/782] Loss_D: 1.0136 (0.9878) Loss_G: 0.7976 (0.9738) D(x): 0.4582 D(G(z)): 0.4461 / 0.3703 Acc: 20.3125 (18.4510)
[0/200][669/782] Loss_D: 0.7961 (0.9876) Loss_G: 1.1603 (0.9740) D(x): 0.6111 D(G(z)): 0.4320 / 0.3061 Acc: 20.3125 (18.4538)
[0/200][670/782] Loss_D: 0.6854 (0.9871) Loss_G: 1.2252 (0.9744) D(x): 0.5970 D(G(z)): 0.4130 / 0.2731 Acc: 23.4375 (18.4613)
[0/200][671/782] Loss_D: 0.5253 (0.9864) Loss_G: 1.3748 (0.9750) D(x): 0.6491 D(G(z)): 0.3396 / 0.2698 Acc: 25.0000 (18.4710)
[0/200][672/782] Loss_D: 0.6459 (0.9859) Loss_G: 1.3277 (0.9755) D(x): 0.6324 D(G(z)): 0.3886 / 0.2490 Acc: 21.8750 (18.4760)
[0/200][673/782] Loss_D: 0.5274 (0.9852) Loss_G: 1.2222 (0.9759) D(x): 0.6389 D(G(z)): 0.3216 / 0.2885 Acc: 21.8750 (18.4811)
[0/200][674/782] Loss_D: 0.8635 (0.9850) Loss_G: 1.1088 (0.9761) D(x): 0.6156 D(G(z)): 0.4130 / 0.3395 Acc: 18.7500 (18.4815)
[0/200][675/782] Loss_D: 0.9010 (0.9849) Loss_G: 1.1172 (0.9763) D(x): 0.6823 D(G(z)): 0.5221 / 0.2946 Acc: 18.7500 (1

[0/200][734/782] Loss_D: 0.8575 (0.9792) Loss_G: 1.0596 (0.9680) D(x): 0.5738 D(G(z)): 0.4563 / 0.3010 Acc: 20.3125 (18.8116)
[0/200][735/782] Loss_D: 0.9222 (0.9791) Loss_G: 0.8637 (0.9678) D(x): 0.4318 D(G(z)): 0.3061 / 0.3610 Acc: 20.3125 (18.8137)
[0/200][736/782] Loss_D: 0.6538 (0.9787) Loss_G: 0.6170 (0.9674) D(x): 0.6325 D(G(z)): 0.4429 / 0.4393 Acc: 21.8750 (18.8178)
[0/200][737/782] Loss_D: 0.6295 (0.9782) Loss_G: 0.8030 (0.9671) D(x): 0.5964 D(G(z)): 0.4099 / 0.3744 Acc: 17.1875 (18.8156)
[0/200][738/782] Loss_D: 0.7126 (0.9779) Loss_G: 1.2450 (0.9675) D(x): 0.6639 D(G(z)): 0.4583 / 0.2636 Acc: 20.3125 (18.8177)
[0/200][739/782] Loss_D: 0.7434 (0.9775) Loss_G: 1.0331 (0.9676) D(x): 0.5322 D(G(z)): 0.3338 / 0.3512 Acc: 25.0000 (18.8260)
[0/200][740/782] Loss_D: 0.8087 (0.9773) Loss_G: 0.8114 (0.9674) D(x): 0.5670 D(G(z)): 0.4462 / 0.3811 Acc: 23.4375 (18.8322)
[0/200][741/782] Loss_D: 0.5197 (0.9767) Loss_G: 0.8280 (0.9672) D(x): 0.6397 D(G(z)): 0.3717 / 0.3779 Acc: 20.3125 (1

[1/200][21/782] Loss_D: 0.9380 (0.9670) Loss_G: 0.6686 (0.9598) D(x): 0.6250 D(G(z)): 0.5469 / 0.4270 Acc: 20.3125 (19.1484)
[1/200][22/782] Loss_D: 1.0540 (0.9672) Loss_G: 0.7203 (0.9595) D(x): 0.4488 D(G(z)): 0.4235 / 0.4235 Acc: 20.3125 (19.1498)
[1/200][23/782] Loss_D: 0.8453 (0.9670) Loss_G: 0.5723 (0.9590) D(x): 0.5467 D(G(z)): 0.4790 / 0.4806 Acc: 25.0000 (19.1571)
[1/200][24/782] Loss_D: 0.9374 (0.9670) Loss_G: 0.4873 (0.9584) D(x): 0.5795 D(G(z)): 0.5772 / 0.4836 Acc: 28.1250 (19.1682)
[1/200][25/782] Loss_D: 0.7582 (0.9667) Loss_G: 0.8685 (0.9583) D(x): 0.5168 D(G(z)): 0.4390 / 0.3457 Acc: 26.5625 (19.1774)
[1/200][26/782] Loss_D: 0.8943 (0.9666) Loss_G: 0.7232 (0.9580) D(x): 0.4892 D(G(z)): 0.4594 / 0.3825 Acc: 26.5625 (19.1865)
[1/200][27/782] Loss_D: 1.0436 (0.9667) Loss_G: 0.7312 (0.9577) D(x): 0.4740 D(G(z)): 0.4762 / 0.4126 Acc: 23.4375 (19.1917)
[1/200][28/782] Loss_D: 0.8203 (0.9665) Loss_G: 0.7086 (0.9574) D(x): 0.5815 D(G(z)): 0.5113 / 0.4194 Acc: 29.6875 (19.2047)


[1/200][88/782] Loss_D: 0.6100 (0.9594) Loss_G: 0.9032 (0.9426) D(x): 0.6205 D(G(z)): 0.3906 / 0.3562 Acc: 21.8750 (19.4209)
[1/200][89/782] Loss_D: 0.7604 (0.9592) Loss_G: 0.7353 (0.9424) D(x): 0.5209 D(G(z)): 0.4259 / 0.4105 Acc: 31.2500 (19.4345)
[1/200][90/782] Loss_D: 0.7012 (0.9589) Loss_G: 0.6170 (0.9420) D(x): 0.5658 D(G(z)): 0.4333 / 0.4366 Acc: 26.5625 (19.4427)
[1/200][91/782] Loss_D: 0.6186 (0.9585) Loss_G: 0.7280 (0.9418) D(x): 0.6342 D(G(z)): 0.4136 / 0.4098 Acc: 18.7500 (19.4419)
[1/200][92/782] Loss_D: 1.0329 (0.9586) Loss_G: 0.5764 (0.9413) D(x): 0.4919 D(G(z)): 0.4841 / 0.4400 Acc: 20.3125 (19.4429)
[1/200][93/782] Loss_D: 0.8274 (0.9584) Loss_G: 0.5015 (0.9408) D(x): 0.4938 D(G(z)): 0.4306 / 0.4707 Acc: 26.5625 (19.4510)
[1/200][94/782] Loss_D: 0.8005 (0.9582) Loss_G: 0.9182 (0.9408) D(x): 0.6084 D(G(z)): 0.4972 / 0.3447 Acc: 21.8750 (19.4537)
[1/200][95/782] Loss_D: 0.6358 (0.9579) Loss_G: 0.8663 (0.9407) D(x): 0.6054 D(G(z)): 0.4499 / 0.3350 Acc: 26.5625 (19.4618)


[1/200][154/782] Loss_D: 1.1880 (0.9489) Loss_G: 0.4504 (0.9258) D(x): 0.4870 D(G(z)): 0.5814 / 0.4999 Acc: 21.8750 (19.6838)
[1/200][155/782] Loss_D: 1.0359 (0.9490) Loss_G: 0.6082 (0.9255) D(x): 0.5427 D(G(z)): 0.5347 / 0.4358 Acc: 18.7500 (19.6828)
[1/200][156/782] Loss_D: 0.9781 (0.9490) Loss_G: 0.7910 (0.9253) D(x): 0.4726 D(G(z)): 0.4602 / 0.3957 Acc: 23.4375 (19.6868)
[1/200][157/782] Loss_D: 1.1889 (0.9493) Loss_G: 0.5578 (0.9249) D(x): 0.4867 D(G(z)): 0.4773 / 0.4820 Acc: 9.3750 (19.6759)
[1/200][158/782] Loss_D: 0.9651 (0.9493) Loss_G: 0.4238 (0.9244) D(x): 0.5279 D(G(z)): 0.5383 / 0.4824 Acc: 15.6250 (19.6716)
[1/200][159/782] Loss_D: 1.0187 (0.9493) Loss_G: 0.6062 (0.9241) D(x): 0.5039 D(G(z)): 0.4798 / 0.4404 Acc: 20.3125 (19.6722)
[1/200][160/782] Loss_D: 0.6281 (0.9490) Loss_G: 0.4149 (0.9235) D(x): 0.5171 D(G(z)): 0.4024 / 0.4690 Acc: 23.4375 (19.6762)
[1/200][161/782] Loss_D: 0.6723 (0.9487) Loss_G: 0.7330 (0.9233) D(x): 0.6082 D(G(z)): 0.5070 / 0.3684 Acc: 26.5625 (19

[1/200][218/782] Loss_D: 0.6224 (0.9352) Loss_G: 0.6971 (0.9055) D(x): 0.6305 D(G(z)): 0.4694 / 0.3807 Acc: 18.7500 (19.8973)
[1/200][219/782] Loss_D: 0.6275 (0.9349) Loss_G: 0.7472 (0.9053) D(x): 0.5323 D(G(z)): 0.4096 / 0.3676 Acc: 28.1250 (19.9055)
[1/200][220/782] Loss_D: 0.9234 (0.9349) Loss_G: 0.6329 (0.9051) D(x): 0.5384 D(G(z)): 0.5026 / 0.4089 Acc: 21.8750 (19.9075)
[1/200][221/782] Loss_D: 0.8493 (0.9348) Loss_G: 0.8294 (0.9050) D(x): 0.5875 D(G(z)): 0.4777 / 0.3864 Acc: 15.6250 (19.9032)
[1/200][222/782] Loss_D: 0.8516 (0.9347) Loss_G: 0.6145 (0.9047) D(x): 0.4690 D(G(z)): 0.4428 / 0.4075 Acc: 21.8750 (19.9052)
[1/200][223/782] Loss_D: 0.9829 (0.9348) Loss_G: 0.6954 (0.9045) D(x): 0.5239 D(G(z)): 0.4702 / 0.4480 Acc: 23.4375 (19.9087)
[1/200][224/782] Loss_D: 0.6550 (0.9345) Loss_G: 0.6749 (0.9043) D(x): 0.5405 D(G(z)): 0.4195 / 0.3546 Acc: 18.7500 (19.9075)
[1/200][225/782] Loss_D: 1.0293 (0.9346) Loss_G: 0.5813 (0.9039) D(x): 0.4670 D(G(z)): 0.5175 / 0.4289 Acc: 23.4375 (1

KeyboardInterrupt: 

In [7]:
def generate_img(G, num_samples, latent_dim=None, n_classes=None):
    """Sample ``num_samples`` class-conditional outputs from generator ``G``.

    Each latent vector is Gaussian noise whose first ``n_classes`` entries are
    replaced by a one-hot code of a uniformly random class label.

    Args:
        G: generator module; it is fed latents of shape (num_samples, latent_dim, 1, 1).
        num_samples: number of samples to draw.
        latent_dim: latent length; defaults to the notebook-level ``nz``.
        n_classes: number of classes; defaults to the notebook-level ``num_classes``.

    Returns:
        (imgs, random_label): G's output (computed under ``torch.no_grad``) and
        the numpy array of sampled class labels.
    """
    if latent_dim is None:
        latent_dim = nz
    if n_classes is None:
        n_classes = num_classes

    # Follow G's device instead of hard-coding .cuda(); parameter-free modules
    # fall back to the CPU.
    first_param = next(G.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    noise_ = np.random.normal(0, 1, (num_samples, latent_dim))
    random_label = np.random.randint(0, n_classes, num_samples)
    noise_[:, :n_classes] = np.eye(n_classes)[random_label]
    noise = torch.from_numpy(noise_).float().view(num_samples, latent_dim, 1, 1).to(device)

    with torch.no_grad():
        imgs = G(noise)
    return imgs, random_label
    
def generate_single_img(G, label, latent_dim=None, n_classes=None):
    """Generate one sample of class ``label`` from generator ``G``.

    The latent vector is Gaussian noise whose first ``n_classes`` entries are
    replaced by the one-hot code of ``label``; fresh noise is drawn on every call.

    Args:
        G: generator module; it is fed a latent of shape (1, latent_dim, 1, 1).
        label: integer class index in ``[0, n_classes)``.
        latent_dim: latent length; defaults to the notebook-level ``nz``.
        n_classes: number of classes; defaults to the notebook-level ``num_classes``.

    Returns:
        G's output for the single latent, computed under ``torch.no_grad``.

    Raises:
        ValueError: if ``label`` is outside ``[0, n_classes)``. Previously a
            negative label silently selected a class from the end.
    """
    if latent_dim is None:
        latent_dim = nz
    if n_classes is None:
        n_classes = num_classes
    if not 0 <= label < n_classes:
        raise ValueError('label must be in [0, %d), got %r' % (n_classes, label))

    # Follow G's device instead of hard-coding .cuda(); parameter-free modules
    # fall back to the CPU.
    first_param = next(G.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    noise_ = np.random.normal(0, 1, (1, latent_dim))
    noise_[:, :n_classes] = np.eye(n_classes)[label]
    noise = torch.from_numpy(noise_).float().view(1, latent_dim, 1, 1).to(device)

    with torch.no_grad():
        imgs = G(noise)
    return imgs
    
def plt_images(images, rows, cols, value_range=(-1.0, 1.0)):
    """Tile up to ``rows * cols`` channel-first images into one borderless figure.

    Args:
        images: sequence/array of images, each (C, H, W) with C = 3 or 4.
        rows, cols: grid size; also the figure size in inches.
        value_range: (low, high) range of the pixel values, linearly mapped onto
            the [0, 1] range ``imshow`` expects for float RGB data. The default
            matches the [-1, 1] normalisation used for training. Without this
            mapping, negative values are clipped. Pass None to display values unchanged.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    fig = plt.figure(figsize=(cols, rows))
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.05, hspace=0.05)
    for i, x in enumerate(images[:cols * rows]):
        ax = fig.add_subplot(rows, cols, i + 1)
        ax.axis('off')
        hwc = np.transpose(x, (1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
        if value_range is not None:
            low, high = value_range
            hwc = np.clip((hwc - low) / (high - low), 0.0, 1.0)
        ax.imshow(hwc, norm=colors.NoNorm())
    return fig

def plt_single_gen_with_fixed_noise(G, img_num, class_gen):
    """Generate one image of class ``class_gen`` with ``G`` and save it as a PNG.

    The file is written to ``<outf>/img_gen/<class_gen:03>/<img_num:05>.png``;
    the class sub-directory is created if needed. Despite the name, a fresh
    random latent is drawn on every call (via ``generate_single_img``).
    """
    imgs = generate_single_img(G, class_gen)
    fig = plt_images(imgs.detach().cpu().numpy(), 1, 1)
    try:
        out_dir = os.path.join(outf, 'img_gen', str(class_gen).zfill(3))
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, '%s.png' % str(img_num).zfill(5)))
    finally:
        # Always release the figure: this is called thousands of times in a loop.
        plt.close(fig)

In [8]:
# Sanity check: one conditional sample should be a single 3 x imageSize x imageSize image.
sample_shape = generate_single_img(netG, 1).shape
assert sample_shape == (1, 3, imageSize, imageSize), sample_shape
sample_shape

torch.Size([1, 3, 32, 32])

In [10]:
# Write `num_images` generated samples per class to <outf>/img_gen/<class:03>/<index:05>.png.
# makedirs(exist_ok=True) keeps this cell safe to re-run when the folders already exist.
num_images = 1000
for class_idx in range(num_classes):
    os.makedirs(os.path.join(outf, 'img_gen', str(class_idx).zfill(3)), exist_ok=True)
    for img_idx in range(num_images):
        plt_single_gen_with_fixed_noise(netG, img_idx, class_idx)

