<a href="https://colab.research.google.com/github/1995subhankar1995/AE-VAE-GAN-implementation/blob/master/GAN_torch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision

# ---------------------------------------------------------------------------
# Hyper-parameters (plain variables; the values are used by the code below).
# ---------------------------------------------------------------------------
workers = 2        # number of DataLoader worker processes
batchSize = 64     # input batch size
imageSize = 64     # the height / width of the input image to network
nz = 100           # size of the latent z vector
ngf = 64           # base channel width used by the Generator layers
ndf = 64           # base channel width used by the Discriminator layers
niter = 25         # number of epochs to train for
lr = 0.0002        # learning rate handed to both Adam optimizers
beta1 = 0.5        # first element of the Adam `betas` tuple


ngpu = 1           # number of GPUs the models may spread over via data_parallel
netG = ''          # placeholder; rebound to the Generator instance further down
netD = ''          # placeholder; rebound to the Discriminator instance further down
outf = '.'         # directory that receives sample images and checkpoints
classes = 'bedroom'  # not referenced by the visible code

nc = 3             # number of image channels
cudnn.benchmark = True

# `torch.cuda` is a module object and is therefore always truthy, so it cannot
# serve as an "is CUDA usable" flag.  Ask the runtime instead.  Set
# `use_cuda = False` by hand to force CPU execution on a CUDA machine; the
# warning below then points out that a GPU is sitting idle.
use_cuda = torch.cuda.is_available()
if torch.cuda.is_available() and not use_cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")


# CIFAR-10 images are resized to `imageSize` and normalised per channel with
# mean 0.5 / std 0.5, which matches the Tanh output range of the generator.
dataset = torchvision.datasets.CIFAR10(root='./data', download=True,
                           transform=transforms.Compose([
                               transforms.Resize(imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize,
                                         shuffle=True, num_workers=int(workers))

# Fall back to the CPU when no CUDA device is usable instead of always
# requesting "cuda:0" (which raises on CPU-only machines).
device = torch.device("cuda:0" if use_cuda else "cpu")
ngpu = int(ngpu)
nz = int(nz)
ngf = int(ngf)
ndf = int(ndf)


# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise the parameters of a single module in place.

    Dispatch is by substring match on the module's class name:

    * names containing ``'Conv'``: weight drawn from N(0.0, 0.02)
    * names containing ``'BatchNorm'``: weight drawn from N(1.0, 0.02),
      bias set to zero
    * anything else is left untouched

    Intended to be passed to ``nn.Module.apply``.
    """
    layer_name = m.__class__.__name__

    if 'Conv' in layer_name:
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_name:
        torch.nn.init.normal_(m.weight, mean=1.0, std=0.02)
        torch.nn.init.zeros_(m.bias)


class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    The network is a stack of ConvTranspose2d -> BatchNorm2d -> ReLU stages
    followed by a final ConvTranspose2d -> Tanh.  Channel widths come from the
    module-level ``nz``, ``ngf`` and ``nc`` settings.

    Args:
        ngpu: number of GPUs to spread the forward pass over; values above 1
            enable ``data_parallel`` for CUDA inputs.
    """

    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) for every
        # upsampling stage that is followed by BatchNorm + ReLU.
        upsample_stages = (
            (nz,      ngf * 8, 1, 0),  # Z               -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  # (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16
            (ngf * 2, ngf,     2, 1),  # (ngf*2) x 16x16 -> (ngf) x 32 x 32
        )

        layers = []
        for in_ch, out_ch, stride, padding in upsample_stages:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed by Tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        multi_gpu = input.is_cuda and self.ngpu > 1
        if not multi_gpu:
            return self.main(input)
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))


# Build the generator, move it to the selected device, then overwrite the
# default parameter initialisation with weights_init.
netG = Generator(ngpu)
netG = netG.to(device)
netG.apply(weights_init)

print(netG)


class Discriminator(nn.Module):
    """Strided-convolution discriminator producing a sigmoid score.

    The network opens with Conv2d -> LeakyReLU, continues with three
    Conv2d -> BatchNorm2d -> LeakyReLU stages that double the channel count
    each time, and ends with Conv2d -> Sigmoid.  Channel widths come from the
    module-level ``nc`` and ``ndf`` settings.

    Args:
        ngpu: number of GPUs to spread the forward pass over; values above 1
            enable ``data_parallel`` for CUDA inputs.
    """

    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu

        # Input stage (no batch norm): (nc) x 64 x 64 -> (ndf) x 32 x 32
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Downsampling stages:
        #   (ndf)   x 32 x 32 -> (ndf*2) x 16 x 16
        #   (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8
        #   (ndf*4) x 8 x 8   -> (ndf*8) x 4 x 4
        for mult in (1, 2, 4):
            in_ch = ndf * mult
            out_ch = ndf * (mult * 2)
            layers.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Output stage: (ndf*8) x 4 x 4 -> single-channel map, squashed by Sigmoid.
        layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Score ``input`` and return the result flattened to a 1-D tensor."""
        multi_gpu = input.is_cuda and self.ngpu > 1
        if multi_gpu:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)

        flat = scores.view(-1, 1)
        return flat.squeeze(1)


# Build the discriminator, move it to the selected device, then overwrite the
# default parameter initialisation with weights_init.
netD = Discriminator(ngpu)
netD = netD.to(device)
netD.apply(weights_init)

print(netD)

# Binary cross-entropy between discriminator scores and the 0/1 targets.
criterion = nn.BCELoss()

# One latent batch drawn up front and reused for every saved sample grid.
fixed_noise = torch.randn(batchSize, nz, 1, 1, device=device)
real_label, fake_label = 1, 0

# setup optimizer: both networks share the same Adam settings
_adam_kwargs = dict(lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(netD.parameters(), **_adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **_adam_kwargs)

# Adversarial training: every batch first updates the discriminator on a real
# and a generated batch, then updates the generator against the refreshed
# discriminator.  Progress is printed per batch, sample grids are written every
# 100 batches, and both networks are checkpointed at the end of each epoch.
for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)  # data[0] holds the images; the labels in data[1] are unused
        batch_size = real_cpu.size(0)  # the last batch of an epoch can be smaller than batchSize
        label = torch.full((batch_size,), real_label,
                           dtype=real_cpu.dtype, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Same generated batch, scored again by the just-updated discriminator,
        # this time with the graph back to the generator intact.
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                    '%s/real_samples.png' % outf,
                    normalize=True)
            # The fixed-noise sample is only written to disk, so skip autograd
            # bookkeeping instead of building a graph that is thrown away.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake,
                    '%s/fake_samples_epoch_%03d.png' % (outf, epoch),
                    normalize=True)
    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (outf, epoch))

Streaming output truncated to the last 5000 lines.
[18/25][475/782] Loss_D: 1.0489 Loss_G: 1.4157 D(x): 0.4759 D(G(z)): 0.1166 / 0.3228
[18/25][476/782] Loss_D: 0.5893 Loss_G: 2.7349 D(x): 0.9005 D(G(z)): 0.3545 / 0.0828
[18/25][477/782] Loss_D: 0.5528 Loss_G: 3.0373 D(x): 0.8127 D(G(z)): 0.2509 / 0.0610
[18/25][478/782] Loss_D: 0.7517 Loss_G: 1.5513 D(x): 0.5730 D(G(z)): 0.0824 / 0.2670
[18/25][479/782] Loss_D: 0.9986 Loss_G: 3.6612 D(x): 0.9251 D(G(z)): 0.5570 / 0.0341
[18/25][480/782] Loss_D: 0.6806 Loss_G: 2.1776 D(x): 0.5856 D(G(z)): 0.0713 / 0.1512
[18/25][481/782] Loss_D: 0.4569 Loss_G: 2.4586 D(x): 0.8663 D(G(z)): 0.2424 / 0.1130
[18/25][482/782] Loss_D: 0.6262 Loss_G: 3.5153 D(x): 0.8600 D(G(z)): 0.3256 / 0.0386
[18/25][483/782] Loss_D: 0.7050 Loss_G: 1.6997 D(x): 0.5923 D(G(z)): 0.0885 / 0.2213
[18/25][484/782] Loss_D: 0.4843 Loss_G: 2.7115 D(x): 0.9116 D(G(z)): 0.2919 / 0.0838
[18/25][485/782] Loss_D: 0.4736 Loss_G: 2.8640 D(x): 0.8145 D(G(z)): 0.2097 / 0.0738
