The code below was run on a GPU enabled server to produce results.


Reference: https://github.com/pytorch/examples/tree/master/dcgan
(the DCGAN example from the official PyTorch examples repository on GitHub)

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Mon Oct 20 17:33:41 2019

@author: Karthik Vikram
"""

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import math
import warnings
import numpy as np
from scipy import linalg
import torch
import torch.nn.functional as F
import torchvision.models as models
#from tqdm import tqdm_notebook as tqdm
import time
from torch.utils.tensorboard import SummaryWriter

# Command-line configuration for the DCGAN training run.
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=128, help='input batch size')
parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=120, help='number of epochs to train for')
# The help text previously advertised 0.0002 although the actual default is 0.0005.
parser.add_argument('--lr', type=float, default=0.0005, help='learning rate, default=0.0005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')

opt = parser.parse_args()

# Create the output folder for images and checkpoints. An already-existing
# folder is fine; any other OS error (e.g. permissions) is surfaced here
# instead of being silently ignored and failing later at save time.
os.makedirs(opt.outf, exist_ok=True)

def setManualSeed(opt):
    """Seed Python's `random` module and PyTorch from `opt.manualSeed`.

    When `opt.manualSeed` is None a seed is drawn uniformly from
    [1, 10000] and written back onto `opt` so the run can be reproduced.
    The seed in use is printed.
    """
    seed = opt.manualSeed
    if seed is None:
        seed = random.randint(1, 10000)
        opt.manualSeed = seed
    print("Random Seed: ", opt.manualSeed)
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(opt.manualSeed)

# Seed the RNGs (a seed is drawn at random when --manualSeed is not given).
setManualSeed(opt)

# Enable the cuDNN benchmark mode.
cudnn.benchmark = True

# CIFAR-10 images carry three colour channels (R, G, B).
nc = 3

def setData():
    """Build the CIFAR-10 datasets and their data loaders.

    Images are converted to tensors and normalized per channel with
    mean 0.5 and std 0.5. The datasets are downloaded into
    `opt.dataroot` when missing.

    Returns:
        A tuple `(trainset, trainloader, testset, testloader, dataloader)`.
        `dataloader` is a second shuffled loader over the training set.

    Raises:
        RuntimeError: if either split is empty.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    num_workers = int(opt.workers)

    trainset = dset.CIFAR10(root=opt.dataroot, train=True,
                            download=True, transform=transform)
    testset = dset.CIFAR10(root=opt.dataroot, train=False,
                           download=True, transform=transform)
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(trainset) == 0:
        raise RuntimeError("CIFAR-10 training split is empty")
    if len(testset) == 0:
        raise RuntimeError("CIFAR-10 test split is empty")

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batchSize,
                                              shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=opt.batchSize,
                                             shuffle=False, num_workers=num_workers)

    # This loader previously wrapped the literal string "cifar10", which
    # would iterate over its characters; it now wraps the training set.
    dataloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=num_workers)
    return (trainset, trainloader, testset, testloader, dataloader)

# Datasets and loaders for CIFAR-10.
trainset, trainloader, testset, testloader, dataloader = setData()

# Use the first GPU when --cuda is passed or a GPU is available (same as
# before on GPU machines); otherwise fall back to the CPU instead of
# failing on the hard-coded "cuda:0".
device = torch.device("cuda:0" if (opt.cuda or torch.cuda.is_available()) else "cpu")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)


# Custom weight initialization, applied to netG and netD via Module.apply.
def weights_init(m):
    """Re-initialize module `m` in place based on its class name.

    Classes whose name contains 'Conv' get weights drawn from N(0, 0.02);
    classes whose name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and a zero bias. Any other module is left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to an `nc`-channel 32x32 image.

    The stack is a 4x4 transposed convolution followed by three
    stride-2 transposed convolutions (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32),
    with batch norm + ReLU between them and a final Tanh.

    Args:
        ngpu: number of GPUs to spread the forward pass over.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a transposed convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            # state size. (nc) x 32 x 32, squashed to [-1, 1]
            nn.Tanh()
        )

    def forward(self, input):
        """Run the generator on `input` (a batch of latent tensors).

        `data_parallel` is only used for CUDA inputs with more than one
        GPU; otherwise the stack is called directly, which also lets the
        model run on CPU.
        """
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


# Build the generator on the selected device and apply the custom init.
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    # Resume from the generator checkpoint (this previously read opt.netD).
    loaded_model = torch.load(opt.netG, map_location=device)
    # Accept both a wrapper dict holding "gen_state_dict" and a raw
    # state dict such as the netG_epoch_*.pth files this script saves.
    netG.load_state_dict(loaded_model.get("gen_state_dict", loaded_model))
print(netG)


class Discriminator(nn.Module):
    """DCGAN discriminator: scores each image with a probability of being real.

    Three stride-2 4x4 convolutions (each with batch norm and
    LeakyReLU(0.2)) are followed by a final 4x4 convolution to a single
    channel and a Sigmoid.
    NOTE(review): the layer arithmetic implies 32x32 inputs
    (32 -> 16 -> 8 -> 4 -> 1) - confirm against the data pipeline.

    Args:
        ngpu: number of GPUs to spread the forward pass over.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return a 1-D tensor with one score per element of the output.

        `data_parallel` is only used for CUDA inputs with more than one
        GPU; otherwise the stack is called directly, which also lets the
        model run on CPU.
        """
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)

# Build the discriminator on the selected device and apply the custom init.
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    # map_location lets a checkpoint saved on another device load here.
    loaded_model = torch.load(opt.netD, map_location=device)
    # Accept both a wrapper dict holding "disc_state_dict" and a raw
    # state dict such as the netD_epoch_*.pth files this script saves.
    netD.load_state_dict(loaded_model.get("disc_state_dict", loaded_model))
print(netD)

# Binary cross-entropy loss on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Fixed latent batch, reused at the end of every epoch so the saved
# sample grids are comparable across epochs.
fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label, fake_label = 1, 0

# Adam optimizers. The two networks use different learning rates:
# the generator's is twice the discriminator's.
optimizerG = optim.Adam(netG.parameters(), lr=2.0 * opt.lr, betas=(opt.beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

class FID():
    '''
    Frechet Inception Distance (FID) between two sets of images.

    Features are the pooled activations of a pretrained torchvision
    InceptionV3 (everything up to and including Mixed_7c followed by
    global average pooling). Code for the FID calculation was taken
    from the TA's piazza post.

    Args:
        cache_dir: directory exported as TORCH_HOME before the pretrained
            weights are requested.
        device: device the Inception network runs on.
        transform_input: whether to apply the per-channel affine rescaling
            in `build_maps` before the network.
    '''
    def __init__(self, cache_dir='./Cache', device='cpu', transform_input=True):
        # Honour `cache_dir` (it was previously ignored in favour of a
        # hard-coded "./Cache"; the default keeps the old location).
        os.environ["TORCH_HOME"] = cache_dir
        self.device=device
        self.transform_input = transform_input
        self.InceptionV3 = models.inception_v3(pretrained=True, transform_input=False, aux_logits=False).to(device=self.device)
        # Evaluation mode: the network is only used as a feature extractor.
        self.InceptionV3.eval()
    
    def build_maps(self, x):
        '''
        Return pooled InceptionV3 features for the image batch `x`.

        `x` is resized to 299x299 when needed, optionally rescaled per
        channel, and passed through the Inception trunk without gradient
        tracking. The shape comments below give the result as
        N x 2048 x 1 x 1.
        '''
        # Resize to Fit InceptionV3
        if list(x.shape[-2:]) != [299,299]:
            with warnings.catch_warnings():
                # Silence any warning raised by the interpolation call.
                warnings.simplefilter("ignore")
                x = F.interpolate(x, size=[299,299], mode='bilinear')
        # Transform Input to InceptionV3 Standards
        # NOTE(review): the constants are the ImageNet mean/std re-expressed
        # against a 0.5/0.5 normalization - confirm callers pass images in
        # the convention this expects.
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        # Run Through Partial InceptionV3 Model
        with torch.no_grad():
            # N x 3 x 299 x 299
            x = self.InceptionV3.Conv2d_1a_3x3(x)
            # N x 32 x 149 x 149
            x = self.InceptionV3.Conv2d_2a_3x3(x)
            # N x 32 x 147 x 147
            x = self.InceptionV3.Conv2d_2b_3x3(x)
            # N x 64 x 147 x 147
            x = F.max_pool2d(x, kernel_size=3, stride=2)
            # N x 64 x 73 x 73
            x = self.InceptionV3.Conv2d_3b_1x1(x)
            # N x 80 x 73 x 73
            x = self.InceptionV3.Conv2d_4a_3x3(x)
            # N x 192 x 71 x 71
            x = F.max_pool2d(x, kernel_size=3, stride=2)
            # N x 192 x 35 x 35
            x = self.InceptionV3.Mixed_5b(x)
            # N x 256 x 35 x 35
            x = self.InceptionV3.Mixed_5c(x)
            # N x 288 x 35 x 35
            x = self.InceptionV3.Mixed_5d(x)
            # N x 288 x 35 x 35
            x = self.InceptionV3.Mixed_6a(x)
            # N x 768 x 17 x 17
            x = self.InceptionV3.Mixed_6b(x)
            # N x 768 x 17 x 17
            x = self.InceptionV3.Mixed_6c(x)
            # N x 768 x 17 x 17
            x = self.InceptionV3.Mixed_6d(x)
            # N x 768 x 17 x 17
            x = self.InceptionV3.Mixed_6e(x)
            # N x 768 x 17 x 17
            x = self.InceptionV3.Mixed_7a(x)
            # N x 1280 x 8 x 8
            x = self.InceptionV3.Mixed_7b(x)
            # N x 2048 x 8 x 8
            x = self.InceptionV3.Mixed_7c(x)
            # N x 2048 x 8 x 8
            # Adaptive average pooling
            x = F.adaptive_avg_pool2d(x, (1, 1))
            # N x 2048 x 1 x 1
            return x
    
    def compute_fid(self, real_images, generated_images, batch_size=64):
        '''
        Compute the FID between `real_images` and `generated_images`.

        Both sets must contain the same number of images (asserted).
        Features are extracted in chunks of `batch_size` on `self.device`
        and gathered on the CPU. The result is
        ||mu_g - mu_x||^2 + Tr(sigma_g + sigma_x - 2*sqrt(sigma_g.sigma_x)).
        '''
        # Ensure Set Sizes are the Same
        assert(real_images.shape[0] == generated_images.shape[0])
        # Build Random Sampling Orders
        real_images = real_images[np.random.permutation(real_images.shape[0])]
        generated_images = generated_images[np.random.permutation(generated_images.shape[0])]
        # Lists of Maps per Batch
        real_maps = []
        generated_maps = []
        # Build Maps
        for s in range(int(math.ceil(real_images.shape[0]/batch_size))):
            sidx = np.arange(batch_size*s, min(batch_size*(s+1), real_images.shape[0]))
            real_maps.append(self.build_maps(real_images[sidx].to(device=self.device)).detach().to(device='cpu'))
            generated_maps.append(self.build_maps(generated_images[sidx].to(device=self.device)).detach().to(device='cpu'))
        # Concatenate Maps
        real_maps = np.squeeze(torch.cat(real_maps).numpy())
        generated_maps = np.squeeze(torch.cat(generated_maps).numpy())
        # Calculate FID
        # Activation Statistics
        mu_g = np.mean(generated_maps, axis=0)
        mu_x = np.mean(real_maps, axis=0)
        sigma_g = np.cov(generated_maps, rowvar=False)
        sigma_x = np.cov(real_maps, rowvar=False)
        # Sum of Squared Differences
        ssd = np.sum((mu_g - mu_x)**2)
        # Square Root of Product of Covariances
        covmean = linalg.sqrtm(sigma_g.dot(sigma_x), disp=False)[0]
        # Keep only the real part if the matrix square root came back complex.
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        # Final FID Computation
        return ssd + np.trace(sigma_g + sigma_x - 2*covmean)

fid_obj = FID()

# TensorBoard writer for the per-epoch losses and FID. This was previously
# an empty tuple, so every writer.add_scalar call raised AttributeError.
writer = SummaryWriter()
for epoch in range(opt.niter):  # for each epoch
    start = time.time()
    for i, data in enumerate(trainloader, 0):  # for each training batch
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train discriminator with real
        netD.zero_grad()  # clear grad data for this iteration
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        # Explicit float dtype: an integer fill value would otherwise give
        # an integer tensor on current PyTorch, which BCELoss rejects.
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        output_D = netD(real_cpu)  # discriminator prediction on real data
        errD_real = criterion(output_D, label)  # BCE against the real label
        errD_real.backward()
        D_x = output_D.mean().item()  # mean discriminator output on real data

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)  # generate fake images from noise
        label.fill_(fake_label)
        # detach so this backward pass does not reach the generator
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()  # mean discriminator output on fakes
        errD = errD_real + errD_fake  # total discriminator loss
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean output on fakes after the D step
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(trainloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # end of epoch operations
        if i == len(trainloader) - 1:  # last iteration in the epoch
            vutils.save_image(real_cpu,  # save the last real batch
                    '%s/real_samples.png' % opt.outf,
                    normalize=True)
            # Samples are only saved, so no autograd graph is needed.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                    '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                    normalize=True)

            # Number of images to compare for calculating the FID score
            comparison_size = 1000
            # fake images for the FID score (evaluation only, no gradients)
            noise = torch.randn(comparison_size, nz, 1, 1, device=device)
            with torch.no_grad():
                fake_data = netG(noise)

            # real images drawn at random (with replacement) from the test
            # set; take the single batch directly instead of a for/break
            # loop that rebound the outer `i` and `data`
            rand_sampler = torch.utils.data.RandomSampler(testset, num_samples=comparison_size, replacement=True)
            test_sampler = torch.utils.data.DataLoader(testset, batch_size=comparison_size, sampler=rand_sampler)
            real_data = next(iter(test_sampler))[0]

            # computing FID score
            fid_val = fid_obj.compute_fid(real_data, fake_data)
            print(fid_val)

            if epoch % 10 == 0:  # checkpoint both networks once every 10 epochs
                torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
                torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

            writer.add_scalar('Loss/real_discriminator', errD_real.item(), epoch)
            writer.add_scalar('Loss/fake_discriminator', errD_fake.item(), epoch)
            writer.add_scalar('Loss/discriminator', errD.item(), epoch)
            writer.add_scalar('Loss/generator', errG.item(), epoch)
            writer.add_scalar('FID', fid_val, epoch)

    print('Time per epoch: ', time.time()-start)

# Flush pending TensorBoard events to disk.
writer.close()


Run the cell below to see the output for DCGAN

In [None]:
from __future__ import print_function
# The line above was misspelled as `_future_` (a module that does not
# exist) and must be the first statement of the cell.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

# --- Settings for the inference cell ---

# Data / sample geometry
batchSize = 64
imageSize = 32  # image resolution
nc = 3  # number of image channels

# Network sizes
nz = 100  # length of the latent noise vector
ngf = 64  # base number of generator filters
ndf = 64  # base number of discriminator filters

# Training hyper-parameters
niter = 1  # number of epochs  # TODO: change to 100000
lr = 0.0002  # learning rate
manualseed = 350  # random seed value

# Execution environment
cuda = True  # run on the GPU
ngpu = 1  # number of GPUs used
outf = '.'  # output file path
device = torch.device("cuda:0" if cuda else "cpu")

# generator model definition
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to an `nc`-channel image.

    A 4x4 transposed convolution is followed by three stride-2
    transposed convolutions (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32), with
    batch norm + ReLU between them and a final Tanh.

    Args:
        ngpu: number of GPUs; stored on the instance but not used by
            `forward`.
    """

    # The constructor was previously spelled `_init_` (single underscores),
    # so it never ran and `self.main` was never created.
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a transposed convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Return the generated images for the latent batch `input`."""
        return self.main(input)


netG = Generator(ngpu).to(device)  # initialize generator

# map_location lets the checkpoint load on whichever device is selected,
# even if it was saved on a different one.
netG.load_state_dict(torch.load('./DCGAN.pth', map_location=device))
noise = torch.randn(64, nz, 1, 1, device=device)  # latent noise batch
# Inference only: skip building the autograd graph.
with torch.no_grad():
    fake = netG(noise)  # generated fake images

vutils.save_image(fake.detach(),'./image_final.png',normalize=True)

# Display the saved image grid.
img = mpimg.imread('image_final.png')
plt.imshow(img)
plt.show()