In [1]:
# Importing the libraries
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

In [6]:
batchSize = 10
imageSize = 64 

transform = transforms.Compose([transforms.Scale(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])
dataset = dset.CIFAR10(root = 'data', download = True, transform = transform) 
dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2) 

  "please use transforms.Resize instead.")


Files already downloaded and verified


In [7]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [8]:
class G(nn.Module):

    def __init__(self):
        super(G, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
            nn.Tanh()
        )

    def forward(self, input):
        output = self.main(input)
        return output

class D(nn.Module):

    def __init__(self):
        super(D, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias = False),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(64, 128, 4, 2, 1, bias = False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(128, 256, 4, 2, 1, bias = False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(256, 512, 4, 2, 1, bias = False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace = True),
            nn.Conv2d(512, 1, 4, 1, 0, bias = False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)

In [9]:
netG = G()
netG.apply(weights_init)
netD = D()
netD.apply(weights_init)

criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))

In [10]:
for epoch in range(25):

    for i, data in enumerate(dataloader, 0):
        
        netD.zero_grad()
        
        real, _ = data
        input = Variable(real)
        target = Variable(torch.ones(input.size()[0]))
        output = netD(input)
        errD_real = criterion(output, target)
        
        noise = Variable(torch.randn(input.size()[0], 100, 1, 1))
        fake = netG(noise)
        target = Variable(torch.zeros(input.size()[0]))
        output = netD(fake.detach())
        errD_fake = criterion(output, target)
        
        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        netG.zero_grad()
        target = Variable(torch.ones(input.size()[0]))
        output = netD(fake)
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()
        
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch, 25, i, len(dataloader), errD.data[0], errG.data[0]))
            vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)
            fake = netG(noise)
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize = True)


[0/25][0/5000] Loss_D: 1.5725 Loss_G: 9.3841
[0/25][100/5000] Loss_D: 0.0032 Loss_G: 26.5622
[0/25][200/5000] Loss_D: 0.6174 Loss_G: 4.2051
[0/25][300/5000] Loss_D: 0.0734 Loss_G: 6.5412
[0/25][400/5000] Loss_D: 0.0488 Loss_G: 5.5111
[0/25][500/5000] Loss_D: 0.1495 Loss_G: 4.4456
[0/25][600/5000] Loss_D: 0.2491 Loss_G: 3.8915
[0/25][700/5000] Loss_D: 1.6698 Loss_G: 3.4403
[0/25][800/5000] Loss_D: 0.5714 Loss_G: 7.4145
[0/25][900/5000] Loss_D: 0.5977 Loss_G: 3.2059
[0/25][1000/5000] Loss_D: 0.3390 Loss_G: 5.2973
[0/25][1100/5000] Loss_D: 1.3243 Loss_G: 6.2740
[0/25][1200/5000] Loss_D: 0.6900 Loss_G: 3.1250
[0/25][1300/5000] Loss_D: 1.4859 Loss_G: 2.2865
[0/25][1400/5000] Loss_D: 2.4444 Loss_G: 1.2546
[0/25][1500/5000] Loss_D: 0.3536 Loss_G: 3.3790
[0/25][1600/5000] Loss_D: 0.7944 Loss_G: 3.2478
[0/25][1700/5000] Loss_D: 0.9626 Loss_G: 2.3153
[0/25][1800/5000] Loss_D: 0.5188 Loss_G: 4.1035
[0/25][1900/5000] Loss_D: 0.7688 Loss_G: 4.4745
[0/25][2000/5000] Loss_D: 1.8064 Loss_G: 1.5016
[0/

[3/25][2100/5000] Loss_D: 1.8931 Loss_G: 7.3209
[3/25][2200/5000] Loss_D: 0.1101 Loss_G: 4.2347
[3/25][2300/5000] Loss_D: 0.4938 Loss_G: 4.1591
[3/25][2400/5000] Loss_D: 0.1893 Loss_G: 4.0698
[3/25][2500/5000] Loss_D: 0.0197 Loss_G: 4.8704
[3/25][2600/5000] Loss_D: 0.8883 Loss_G: 5.4762
[3/25][2700/5000] Loss_D: 0.9855 Loss_G: 5.8143
[3/25][2800/5000] Loss_D: 0.1204 Loss_G: 3.5654
[3/25][2900/5000] Loss_D: 0.2924 Loss_G: 4.0579
[3/25][3000/5000] Loss_D: 0.2962 Loss_G: 2.3217
[3/25][3100/5000] Loss_D: 0.1507 Loss_G: 3.6321
[3/25][3200/5000] Loss_D: 0.0860 Loss_G: 3.9772
[3/25][3300/5000] Loss_D: 0.7060 Loss_G: 3.9178
[3/25][3400/5000] Loss_D: 0.2353 Loss_G: 3.6837
[3/25][3500/5000] Loss_D: 0.2634 Loss_G: 2.9947
[3/25][3600/5000] Loss_D: 0.2949 Loss_G: 3.0522
[3/25][3700/5000] Loss_D: 0.0464 Loss_G: 5.0162
[3/25][3800/5000] Loss_D: 1.4622 Loss_G: 6.6917
[3/25][3900/5000] Loss_D: 0.3568 Loss_G: 4.7483
[3/25][4000/5000] Loss_D: 1.1308 Loss_G: 4.2164
[3/25][4100/5000] Loss_D: 0.3692 Loss_G:

KeyboardInterrupt: 