In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel 
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable

In [2]:
batchsize = 64
imagesize = 64

In [None]:
transform = transforms.Compose([transforms.Scale(imagesize),transforms.ToTensor(),transforms.Normalize\
                               (mean=(0.485,0.456,0.406),std=(0.229,0.224,0.225) )] ) 

In [4]:
dataset = dset.CIFAR10(root='./data',download=True, transform=transform)

Files already downloaded and verified


In [5]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size= batchsize, shuffle= True, num_workers=4)


In [6]:
cuda0 = torch.device('cuda:0')
torch.set_default_dtype = cuda0

In [7]:
import numpy

In [8]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0,0.02).cuda()
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0,0.02).cuda()
        m.bias.data.fill_(0).cuda()

In [9]:
class G(nn.Module):
    
    def __init__(self):
        super(G, self).__init__()
        self.main = nn.Sequential(
        nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
        nn.BatchNorm2d(512),
        nn.ReLU(True),
        nn.ConvTranspose2d(512,256,4,2,1, bias = False),
        nn.BatchNorm2d(256),
        nn.ReLU(True),
        nn.ConvTranspose2d(256,128,4,2,1, bias = False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        nn.ConvTranspose2d(128,32,4,2,1, bias = False),
        nn.BatchNorm2d(32),
        nn.ReLU(True),
        nn.ConvTranspose2d(32,3,4,2,1, bias = False),
        nn.Tanh()
        )
    def forward(self, input):
        return(self.main(input.to(cuda0)))

In [10]:
netG = G().to(cuda0)
netG.apply(weights_init)

G(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

In [11]:
class D(nn.Module):
    
    def __init__(self):
        super(D,self).__init__()
        self.main = nn.Sequential(
        nn.Conv2d(3, 64, 4, 2 , 1, bias= False),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(64, 128, 4, 2, 1, bias= False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(128, 256, 4, 2, 1, bias=True),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(256, 512, 4, 2, 1, bias=True),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(512, 1, 4, 1, 0, bias=True),
        nn.Sigmoid()
        )
    def forward(self,input):
        output = self.main(input.to(cuda0))
        output = output.view(-1).to(cuda0)
        return output

In [12]:
netD = D().to(cuda0)
netD.apply(weights_init)

D(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)

In [13]:
criterion= nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(),lr=0.0002, betas= (0.5,0.999))
optimizerG = optim.Adam(netG.parameters(),lr=0.0002, betas= (0.5,0.999))

In [15]:
for epoch in range(50):
    for i, data in enumerate(dataloader, 0):
        netD.zero_grad()
        real, _ = data 
        input = Variable(real)
        target = Variable(torch.ones(input.size()[0],device=cuda0))
        output = netD(input).cuda()
        
        errD_real = criterion(output, target)
        
        noise = Variable(torch.randn(input.size()[0], 100 , 1, 1, device=cuda0))
        fake= netG(noise).cuda()
        target = Variable(torch.zeros(input.size()[0],device=cuda0))
        
        output = netD(fake.detach())
        errD_fake = criterion(output, target)
        
        errD = errD_fake+ errD_real
        errD.backward()
        optimizerD.step()
        
        netG.zero_grad()
        target = Variable(torch.ones(input.size()[0],device=cuda0))
        output = netD(fake)
        
        errG = criterion(output, target)
        errG.backward()
        optimizerG.step()
        
        
        if i % 100 == 0:
            print((epoch, 25,i, len(dataloader), errD.item(), errG.item()))
            vutils.save_image(real, '%s/real_samples.png' % "./results", normalize=True)
            fake = netG(noise)
            vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results",epoch), normalize=True)
            
            
            
        
        

(0, 25, 0, 782, 0.04658062383532524, 6.782224655151367)
(0, 25, 100, 782, 0.21423789858818054, 3.8183646202087402)
(0, 25, 200, 782, 0.9487955570220947, 8.667064666748047)
(0, 25, 300, 782, 0.3661361038684845, 1.8020191192626953)
(0, 25, 400, 782, 0.46146145462989807, 3.708737850189209)
(0, 25, 500, 782, 0.7442079186439514, 4.517812252044678)
(0, 25, 600, 782, 0.006271081510931253, 5.5968217849731445)
(0, 25, 700, 782, 0.014883400872349739, 5.5108160972595215)
(1, 25, 0, 782, 0.38002851605415344, 9.019732475280762)
(1, 25, 100, 782, 0.21582162380218506, 4.184281826019287)
(1, 25, 200, 782, 0.643979012966156, 8.367136001586914)
(1, 25, 300, 782, 0.23294028639793396, 3.063279151916504)
(1, 25, 400, 782, 0.018846724182367325, 5.372779846191406)
(1, 25, 500, 782, 0.015436206944286823, 6.301811218261719)
(1, 25, 600, 782, 0.3203301429748535, 3.83286714553833)
(1, 25, 700, 782, 0.6672513484954834, 6.247830390930176)
(2, 25, 0, 782, 2.3371238708496094, 12.453275680541992)
(2, 25, 100, 782, 0.

(17, 25, 500, 782, 0.0001673525694059208, 9.14365291595459)
(17, 25, 600, 782, 0.001131592900492251, 7.166050910949707)
(17, 25, 700, 782, 0.0009347595041617751, 7.2957940101623535)
(18, 25, 0, 782, 0.0008367101545445621, 7.369678974151611)
(18, 25, 100, 782, 0.0003123607602901757, 8.457815170288086)
(18, 25, 200, 782, 0.0006827160832472146, 10.981554985046387)
(18, 25, 300, 782, 0.0003807727189268917, 9.261443138122559)
(18, 25, 400, 782, 0.00043745210859924555, 9.389287948608398)
(18, 25, 500, 782, 0.0003142166242469102, 8.811992645263672)
(18, 25, 600, 782, 0.0007097726338542998, 7.913066864013672)
(18, 25, 700, 782, 0.00015682335651945323, 9.02233600616455)
(19, 25, 0, 782, 3.8352345654857345e-06, 18.204914093017578)
(19, 25, 100, 782, 5.162989691598341e-05, 10.370983123779297)
(19, 25, 200, 782, 0.000557219609618187, 7.7516984939575195)
(19, 25, 300, 782, 7.952935266075656e-05, 9.732998847961426)
(19, 25, 400, 782, 0.00022744137095287442, 8.633962631225586)
(19, 25, 500, 782, 0.00

(34, 25, 400, 782, 9.313225746154785e-10, 18.06257438659668)
(34, 25, 500, 782, 0.0, 19.238300323486328)
(34, 25, 600, 782, 4.6566128730773926e-08, 17.296585083007812)
(34, 25, 700, 782, 0.0, 18.378963470458984)
(35, 25, 0, 782, 3.3527612686157227e-08, 18.0323486328125)
(35, 25, 100, 782, 0.0, 19.421281814575195)
(35, 25, 200, 782, 1.434236907016384e-07, 16.54471206665039)
(35, 25, 300, 782, 5.122274160385132e-08, 17.501033782958984)
(35, 25, 400, 782, 1.2759119272232056e-07, 16.907411575317383)
(35, 25, 500, 782, 5.029141902923584e-08, 17.622032165527344)
(35, 25, 600, 782, 3.4458935260772705e-08, 18.18533706665039)
(35, 25, 700, 782, 4.7497451305389404e-08, 17.472164154052734)
(36, 25, 0, 782, 2.2351741790771484e-08, 18.48625946044922)
(36, 25, 100, 782, 1.862645149230957e-09, 19.130769729614258)
(36, 25, 200, 782, 6.51925802230835e-09, 18.802082061767578)
(36, 25, 300, 782, 3.725290298461914e-09, 18.934736251831055)
(36, 25, 400, 782, 1.1175872671742582e-08, 18.967788696289062)
(36,