In [1]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dst
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from __future__ import print_function

In [2]:
batch_size = 64
image_size = 64

In [3]:
transform = transforms.Compose([transforms.Resize(image_size) ,
                                transforms.ToTensor() ,
                                transforms.Normalize((0.5 , 0.5 , 0.5) , (0.5 , 0.5 , 0.5))
                               ])

In [4]:
dataset = dst.CIFAR10(root = 'data' , download = True , transform = transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 47213478.17it/s]


Extracting data/cifar-10-python.tar.gz to data


In [5]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True , num_workers=2)

In [6]:
device = 'cuda'

In [7]:
def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1 :
        m.weight.data.normal_(0.0 , 0.02)
    elif classname.find('Batchnorm') != -1:
        m.weight.data.normal_(1.0 , 0.02)
        m.bias.data.fill_(0)

In [10]:
class G(nn.Module):

    def __init__(self):
        super(G , self).__init__()


        self.main = nn.Sequential(

        nn.ConvTranspose2d(100 , 512 , 4 , 1 , 0 , bias = True),
        nn.BatchNorm2d(512),
        nn.ReLU(True),

        nn.ConvTranspose2d(512 , 256 , 4 , 2 , 1 , bias = True),
        nn.BatchNorm2d(256),
        nn.ReLU(True),

        nn.ConvTranspose2d(256 , 128 , 4 , 2 , 1 , bias = True),
        nn.BatchNorm2d(128),
        nn.ReLU(True),

        nn.ConvTranspose2d(128 , 64 , 4 , 2 , 1 , bias = True),
        nn.BatchNorm2d(64),
        nn.ReLU(True),

        nn.ConvTranspose2d(64 , 3 , 4 , 2 , 1 , bias = True),
        nn.Tanh()
        ).to(device)


    def forward(self , input):
        output = self.main(input)
        return output

In [11]:
netG = G()
netG.apply(weights_init)

G(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (13): Tanh()
  )
)

In [12]:
class D(nn.Module):

    def __init__(self):
        super(D , self).__init__()

        self.main = nn.Sequential(

        # Conv2d(input , output , kernelsize , stride , padding , bias)
        nn.Conv2d(3 , 64 , 4 , 2 , 1, bias = True),

        nn.LeakyReLU(0.2 , inplace = True),

        nn.Conv2d(64 , 128 , 4 , 2 , 1 , bias = True),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2 , True),

        nn.Conv2d(128 , 256 , 4 , 2 , 1 , bias = True),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2 , True),

        nn.Conv2d(256 , 512 , 4 , 2 , 1 , bias = True),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2 , True),

        nn.Conv2d(512 , 1 , 4 , 1 , 0 , bias = True),
        nn.Sigmoid()
        ).to(device)

    def forward(self ,input):
        output = self.main(input)
        return output.view(-1)

In [13]:
netD = D()
netD.apply(weights_init)

D(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)

In [14]:
criterion = nn.BCELoss()
# Adam(parameters , learning rate , coefficients)
optimizerG = optim.Adam(netG.parameters() , lr = 0.0002 , betas = (0.5 , 0.999))
optimizerD = optim.Adam(netD.parameters() , lr = 0.0002 , betas = (0.5 , 0.999))

In [18]:
# Training loop
for epoch in range(15):
    for i, data in enumerate(dataloader, 0):
        real, _ = data
        real = real.to(device)

        # Discriminator update
        netD.zero_grad()
        target_real = torch.ones(real.size(0)).to(device)
        output_real = netD(real)
        errD_real = criterion(output_real, target_real)

        noise = torch.randn(real.size(0), 100, 1, 1).to(device)
        fake = netG(noise)
        target_fake = torch.zeros(real.size(0)).to(device)
        output_fake = netD(fake.detach())
        errD_fake = criterion(output_fake, target_fake)

        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        # Generator update
        netG.zero_grad()
        target_gen = torch.ones(real.size(0)).to(device)
        output_gen = netD(fake)
        errG = criterion(output_gen, target_gen)

        errG.backward()
        optimizerG.step()

        # Print loss and save images
        print(f"[{epoch+1}/{15}][{i+1}/{len(dataloader)}] Loss_D: {errD.item():} Loss_G: {errG.item():}")

        if i % 100 == 0:
            vutils.save_image(real , '%s/real_sample.png' % 'results' , normalize = True)
            fake = netG(noise)
            vutils.save_image(fake.data , '%s/fake_sample_epoch_%03d.png' % ('results' , epoch) , normalize = True)

[1/15][1/782] Loss_D: 0.7805328965187073 Loss_G: 24.40935516357422
[1/15][2/782] Loss_D: 0.7407083511352539 Loss_G: 23.446062088012695
[1/15][3/782] Loss_D: 0.7836392521858215 Loss_G: 19.230785369873047
[1/15][4/782] Loss_D: 0.2704241871833801 Loss_G: 13.050256729125977
[1/15][5/782] Loss_D: 0.056919604539871216 Loss_G: 5.088726997375488
[1/15][6/782] Loss_D: 1.9166302680969238 Loss_G: 19.534854888916016
[1/15][7/782] Loss_D: 0.698429524898529 Loss_G: 22.621641159057617
[1/15][8/782] Loss_D: 0.8734854459762573 Loss_G: 21.47257423400879
[1/15][9/782] Loss_D: 0.4727477729320526 Loss_G: 18.290748596191406
[1/15][10/782] Loss_D: 0.40392911434173584 Loss_G: 13.009635925292969
[1/15][11/782] Loss_D: 0.5447190403938293 Loss_G: 6.149919033050537
[1/15][12/782] Loss_D: 1.0382089614868164 Loss_G: 14.36469841003418
[1/15][13/782] Loss_D: 0.290138840675354 Loss_G: 16.266094207763672
[1/15][14/782] Loss_D: 0.1614851951599121 Loss_G: 14.473166465759277
[1/15][15/782] Loss_D: 0.46044573187828064 Loss

  return F.conv_transpose2d(


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[9/15][475/782] Loss_D: 0.12721958756446838 Loss_G: 4.806495666503906
[9/15][476/782] Loss_D: 0.10292631387710571 Loss_G: 4.4180192947387695
[9/15][477/782] Loss_D: 0.12146349251270294 Loss_G: 4.7516326904296875
[9/15][478/782] Loss_D: 0.21737852692604065 Loss_G: 6.027863502502441
[9/15][479/782] Loss_D: 0.09813395887613297 Loss_G: 6.063605308532715
[9/15][480/782] Loss_D: 0.10064670443534851 Loss_G: 4.009065628051758
[9/15][481/782] Loss_D: 0.14226984977722168 Loss_G: 4.387025833129883
[9/15][482/782] Loss_D: 0.09943057596683502 Loss_G: 5.123204231262207
[9/15][483/782] Loss_D: 0.11160195618867874 Loss_G: 4.867332935333252
[9/15][484/782] Loss_D: 0.1971191167831421 Loss_G: 3.646179676055908
[9/15][485/782] Loss_D: 0.5066353678703308 Loss_G: 11.580012321472168
[9/15][486/782] Loss_D: 2.9500463008880615 Loss_G: 2.6742281913757324
[9/15][487/782] Loss_D: 0.9669626951217651 Loss_G: 0.07966593652963638
[9/15][488/782] Loss_D: