In [1]:
import os
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from torch.autograd import Variable
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.utils as vutils
%matplotlib inline

## Loading data

In [2]:
img_size = 64
batch_size=64
lr = 0.0002
beta1 = 0.5
niter= 25
outf= 'output'

dataset = datasets.CIFAR10( root = 'data',download=True,
                       transform=transforms.Compose([
                           transforms.Resize(img_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                       ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size,
                                         shuffle=True)


Files already downloaded and verified


In [3]:
#Size of latnet vector
nz = 100
# Filter size of generator
ngf = 64
# Filter size of discriminator
ndf = 64
# Output image channels
nc = 3

# Network Initialization

In [4]:
def weights_inititialisation(m):
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif class_name.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Generator

In [5]:
class _net_generator(nn.Module):
    def __init__(self):
        super(_net_generator, self).__init__()

        self.main = nn.Sequential(
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output


net_generator = _net_generator()
net_generator.apply(weights_inititialisation)
print(net_generator)

_net_generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)


# Discriminator

In [6]:
class _net_discriminator(nn.Module):
    def __init__(self):
        super(_net_discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1).squeeze(1)


net_discriminator = _net_discriminator()
net_discriminator.apply(weights_inititialisation)
print(net_discriminator)

_net_discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)


# Defining loss functions

In [7]:
criterion = nn.BCELoss()

input = torch.FloatTensor(batch_size, 3, img_size, img_size)
noise = torch.FloatTensor(batch_size, nz, 1, 1)
fixed_noise = torch.FloatTensor(batch_size, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(batch_size)
real_label = 1
fake_label = 0

In [8]:
if torch.cuda.is_available():
    net_discriminator.cuda()
    net_generator.cuda()
    criterion.cuda()
    input, label = input.cuda(), label.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

# Defining optimiser

In [9]:
fixed_noise = Variable(fixed_noise)

optimizer_discriminator = optim.Adam(net_discriminator.parameters(), lr, betas=(beta1, 0.95))
optimizer_generator = optim.Adam(net_generator.parameters(), lr, betas=(beta1, 0.95))

# Training 

In [None]:
for epoch in range(niter):
    for i, data in enumerate(dataloader, 0):
        # train with real
        net_discriminator.zero_grad()
        real_cpu, _ = data
        batch_size = real_cpu.size(0)
        if torch.cuda.is_available():
            real_cpu = real_cpu.cuda()
        input.resize_as_(real_cpu).copy_(real_cpu)
        label.resize_(batch_size).fill_(real_label)
        inputv = Variable(input)
        labelv = Variable(label)

        output = net_discriminator(inputv)
        err_discriminator_real = criterion(output, labelv)
        err_discriminator_real.backward()
        D_x = output.data.mean()

        noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
        noisev = Variable(noise)
        fake = net_generator(noisev)
        labelv = Variable(label.fill_(fake_label))
        output = net_discriminator(fake.detach())
        err_discriminator_fake = criterion(output, labelv)
        err_discriminator_fake.backward()
        D_G_z1 = output.data.mean()
        err_discriminator = err_discriminator_real + err_discriminator_fake
        optimizer_discriminator.step()

        net_generator.zero_grad()
        labelv = Variable(label.fill_(real_label))  # fake labels are real for generator cost
        output = net_discriminator(fake)
        err_generator = criterion(output, labelv)
        err_generator.backward()
        D_G_z2 = output.data.mean()
        optimizer_generator.step()

        print('[%d/%d][%d/%d] Loss_Discriminator: %.4f Loss_Generator: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, niter, i, len(dataloader),
                 err_discriminator.data[0], err_generator.data[0], D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                    '%s/real_samples.png' % outf,
                    normalize=True)
            fake = net_generator(fixed_noise)
            vutils.save_image(fake.data,
                    '%s/fake_samples_epoch_%03d.png' % (outf, epoch),
                    normalize=True)



[0/25][0/782] Loss_Discriminator: 1.5009 Loss_Generator: 5.5792 D(x): 0.4640 D(G(z)): 0.4075 / 0.0047
[0/25][1/782] Loss_Discriminator: 0.9016 Loss_Generator: 6.4903 D(x): 0.7805 D(G(z)): 0.4035 / 0.0020
[0/25][2/782] Loss_Discriminator: 0.8814 Loss_Generator: 6.0653 D(x): 0.7306 D(G(z)): 0.3064 / 0.0030
[0/25][3/782] Loss_Discriminator: 1.1415 Loss_Generator: 6.0359 D(x): 0.6943 D(G(z)): 0.3669 / 0.0032
[0/25][4/782] Loss_Discriminator: 1.1851 Loss_Generator: 7.0107 D(x): 0.7264 D(G(z)): 0.4358 / 0.0013
[0/25][5/782] Loss_Discriminator: 1.0506 Loss_Generator: 7.8329 D(x): 0.7478 D(G(z)): 0.4270 / 0.0006
[0/25][6/782] Loss_Discriminator: 0.9600 Loss_Generator: 7.1500 D(x): 0.6811 D(G(z)): 0.2328 / 0.0013
[0/25][7/782] Loss_Discriminator: 0.9689 Loss_Generator: 7.8842 D(x): 0.6859 D(G(z)): 0.3404 / 0.0006
[0/25][8/782] Loss_Discriminator: 0.5686 Loss_Generator: 8.9665 D(x): 0.8327 D(G(z)): 0.2671 / 0.0002
[0/25][9/782] Loss_Discriminator: 0.6097 Loss_Generator: 7.3556 D(x): 0.7349 D(G(z

[0/25][82/782] Loss_Discriminator: 0.1770 Loss_Generator: 6.1155 D(x): 0.8856 D(G(z)): 0.0194 / 0.0034
[0/25][83/782] Loss_Discriminator: 0.9941 Loss_Generator: 19.9823 D(x): 0.9596 D(G(z)): 0.5539 / 0.0000
[0/25][84/782] Loss_Discriminator: 0.9107 Loss_Generator: 19.8423 D(x): 0.6006 D(G(z)): 0.0000 / 0.0000
[0/25][85/782] Loss_Discriminator: 0.2054 Loss_Generator: 15.8387 D(x): 0.8631 D(G(z)): 0.0000 / 0.0000
[0/25][86/782] Loss_Discriminator: 0.0350 Loss_Generator: 11.2559 D(x): 0.9772 D(G(z)): 0.0001 / 0.0001
[0/25][87/782] Loss_Discriminator: 0.0350 Loss_Generator: 5.1225 D(x): 0.9916 D(G(z)): 0.0256 / 0.0112
[0/25][88/782] Loss_Discriminator: 1.4339 Loss_Generator: 15.1637 D(x): 0.9918 D(G(z)): 0.7105 / 0.0000
[0/25][89/782] Loss_Discriminator: 0.9277 Loss_Generator: 13.8700 D(x): 0.5732 D(G(z)): 0.0002 / 0.0000
[0/25][90/782] Loss_Discriminator: 0.2355 Loss_Generator: 9.7309 D(x): 0.8879 D(G(z)): 0.0015 / 0.0005
[0/25][91/782] Loss_Discriminator: 0.1773 Loss_Generator: 5.1003 D(

[0/25][161/782] Loss_Discriminator: 0.7112 Loss_Generator: 4.2318 D(x): 0.5936 D(G(z)): 0.0286 / 0.0207
[0/25][162/782] Loss_Discriminator: 0.5797 Loss_Generator: 4.9102 D(x): 0.8734 D(G(z)): 0.3239 / 0.0128
[0/25][163/782] Loss_Discriminator: 0.7331 Loss_Generator: 6.1861 D(x): 0.8023 D(G(z)): 0.3100 / 0.0050
[0/25][164/782] Loss_Discriminator: 0.9685 Loss_Generator: 2.1312 D(x): 0.5305 D(G(z)): 0.0603 / 0.1661
[0/25][165/782] Loss_Discriminator: 2.1062 Loss_Generator: 8.9309 D(x): 0.9290 D(G(z)): 0.8024 / 0.0005
[0/25][166/782] Loss_Discriminator: 1.4916 Loss_Generator: 6.7571 D(x): 0.4016 D(G(z)): 0.0095 / 0.0090
[0/25][167/782] Loss_Discriminator: 0.5314 Loss_Generator: 3.3582 D(x): 0.7148 D(G(z)): 0.0543 / 0.0930
[0/25][168/782] Loss_Discriminator: 0.9365 Loss_Generator: 5.5269 D(x): 0.9569 D(G(z)): 0.4592 / 0.0147
[0/25][169/782] Loss_Discriminator: 0.4258 Loss_Generator: 5.3197 D(x): 0.8146 D(G(z)): 0.1329 / 0.0085
[0/25][170/782] Loss_Discriminator: 0.3295 Loss_Generator: 4.105

In [None]:
mkdir output

In [None]:
ls -al output/

In [None]:
Image.open('output/real_samples.png')

In [None]:
Image.open('output/fake_samples_epoch_024.png')