In [None]:
import torch.nn as nn
import torch
from torch.nn.modules import conv, Linear
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import torch.utils.data
import torch.backends.cudnn as cudnn
import random
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to images in [-1, 1].

    A linear layer projects each latent vector to a ``generator_dim*8 x 3 x 3`` feature
    map. Four BatchNorm/ReLU/ConvTranspose2d stages then upsample it spatially
    3 -> 7 -> 15 -> 31 -> 64 while narrowing channels 8x -> 4x -> 2x -> 1x -> input_dim.
    A final Tanh bounds the output to [-1, 1].

    Args:
        noise_size: length of the latent vector.
        input_dim: number of channels of the generated image (3 for RGB).
        generator_dim: base channel width of the upsampling stack.
    """

    def __init__(self, noise_size, input_dim, generator_dim):
        super(Generator, self).__init__()
        self.generator_dim = generator_dim
        self.linear = nn.Linear(noise_size, 3*3*generator_dim*8)
        self.generator = nn.Sequential(
            nn.BatchNorm2d(generator_dim*8),
            nn.ReLU(True),
            # k=5, s=2, p=1 maps a side of H to 2H+1: 3 -> 7
            nn.ConvTranspose2d(generator_dim*8, generator_dim*4, 5, 2, 1, bias=True),
            nn.BatchNorm2d(generator_dim*4),
            nn.ReLU(True),
            # 7 -> 15
            nn.ConvTranspose2d(generator_dim*4, generator_dim*2, 5, 2, 1, bias=True),
            nn.BatchNorm2d(generator_dim*2),
            nn.ReLU(True),
            # 15 -> 31
            nn.ConvTranspose2d(generator_dim*2, generator_dim, 5, 2, 1, bias=True),
            nn.BatchNorm2d(generator_dim),
            nn.ReLU(True),
            # output_padding=1 gives 31 -> 64 instead of 31 -> 63, so generated images
            # have the same 64x64 size as the center-cropped training images.
            # It adds no parameters, so existing state_dicts still load.
            nn.ConvTranspose2d(generator_dim, input_dim, 5, 2, 1, output_padding=1, bias=True),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from latent codes.

        Args:
            input: latent codes of shape (B, noise_size). Inputs with extra trailing
                singleton dims, e.g. (B, noise_size, 1, 1), are flattened to (B, noise_size).
        Returns:
            Tensor of shape (B, input_dim, 64, 64) with values in [-1, 1].
        """
        # nn.Linear acts on the last dim, so a (B, noise, 1, 1) latent must be flattened first.
        z = input.flatten(1) if input.dim() > 2 else input
        out = self.linear(z)
        out = out.view(-1, self.generator_dim*8, 3, 3)
        return self.generator(out)
    
class Discriminator(nn.Module):
    """Spectrally normalised convolutional discriminator.

    The network has four blocks of SN-Conv2d(5x5, stride 2, padding 1) -> BatchNorm2d ->
    LeakyReLU(0.1). Channels grow input_dim -> disc_dim -> 2x -> 4x -> 8x. For a 64x64
    (or 63x63) image the side shrinks 64 -> 31 -> 15 -> 7 -> 3. The flattened
    ``disc_dim*8 x 3 x 3`` features then go through a spectrally normalised linear layer
    with 1024 outputs and no activation. The notebook scores these raw values with
    BCEWithLogitsLoss.

    Args:
        input_dim: number of image channels.
        disc_dim: base channel width of the conv stack.
    """

    def __init__(self, input_dim, disc_dim):
        super(Discriminator, self).__init__()
        # The head is built before the conv stack, as in the original, so parameter
        # creation order (and the RNG draws of spectral_norm's u/v vectors) is unchanged.
        self.snlinear = nn.utils.spectral_norm(nn.Linear(disc_dim*8*3*3, 1024))

        widths = [input_dim] + [disc_dim * mult for mult in (1, 2, 4, 8)]
        blocks = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            blocks.extend([
                nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, 5, 2, 1, bias=True), eps=1e-12),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.1, inplace=True),
            ])
        # Same flat layout as before: indices 0/3/6/9 are convs, 1/4/7/10 BatchNorms,
        # 2/5/8/11 LeakyReLUs. This keeps state_dict keys compatible.
        self.discriminator = nn.Sequential(*blocks)

    def forward(self, input):
        """Return a (B, 1024) tensor of unnormalised scores for a batch of images."""
        features = self.discriminator(input)
        return self.snlinear(features.flatten(start_dim=1))

In [None]:
# Preprocessing pipeline:
# - resize so the short side is 64, then center-crop to 64x64;
# - scale to [0, 1] with ToTensor, then map to [-1, 1] with Normalize(0.5, 0.5),
#   the same range as the generator's Tanh output.
_preprocess = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.ImageFolder(root='../input/train', transform=_preprocess)

# drop_last=True discards a final partial batch, so every batch holds exactly 64 images.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# A fresh seed is drawn every run and printed, so a run can be reproduced from its log.
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(manualSeed)
#torch.cuda.set_device(0)

# Let cuDNN auto-tune conv algorithms for the fixed input size (may be nondeterministic).
cudnn.benchmark = True

In [None]:
def weight_filler(m):
    """Apply DCGAN-style initialisation to one module; meant for ``Module.apply``.

    - Modules whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...):
      weights are drawn from N(0, 0.02). Biases are left as they are.
    - Modules whose class name contains 'BatchNorm': weights are drawn from N(1, 0.02)
      and biases are set to 0.
    - Every other module (e.g. nn.Linear) is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        
# Model / optimiser configuration.
NOISE_SIZE = 100   # latent length fed to Generator.linear
IMAGE_SIZE = 64    # side of the center-cropped training images
batch_size = 64

# Use the GPU when present instead of calling .cuda() unconditionally.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator(noise_size=NOISE_SIZE, input_dim=3, generator_dim=64)
discriminator = Discriminator(input_dim=3, disc_dim=64)
# Keep the init BEFORE the device move, as originally ordered. For spectral-normed layers,
# `m.weight` is a plain attribute derived from `weight_orig`. After .to() it is not
# guaranteed to alias the moved parameter, so an in-place init there could be lost.
generator.apply(weight_filler)
discriminator.apply(weight_filler)
generator.to(device)
discriminator.to(device)

# Placeholders kept from the original cell; the training loop allocates its own tensors.
inputv = torch.empty(batch_size, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)
noise = torch.empty(batch_size, NOISE_SIZE, device=device)
label = torch.empty(batch_size, device=device)
real_label = 1
fake_label = 0

# Fixed latent batch for visually comparable samples across epochs.
# It is shaped (B, NOISE_SIZE) because Generator.linear is applied to the last dimension.
fixed_noise = torch.randn(batch_size, NOISE_SIZE).to(device)

# Per-element loss; the training loop sums over the discriminator's outputs itself.
criterion = nn.BCEWithLogitsLoss(reduction='none').to(device)

optimizer_generator = optim.Adam(generator.parameters(), lr=0.0002, betas=(0, 0.9))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0, 0.9))

In [None]:
NUM_EPOCHS = 200
LOG_EVERY = 20        # batches between loss printouts
LATE_EPOCH = 150      # mid-epoch snapshots are written only after this epoch...
SNAPSHOT_EVERY = 160  # ...every this many batches

# Follow whatever device the setup cell put the models on.
device = next(generator.parameters()).device
# Read the latent length from the model so it cannot drift from Generator(noise_size=...).
noise_size = generator.linear.in_features


def _gan_loss(logits, target_value):
    """Compute BCE-with-logits against a constant target.

    The loss is summed over each sample's outputs and averaged over the batch.
    This relies on `criterion` being BCEWithLogitsLoss(reduction='none'), as configured
    in the setup cell. The target is built from the output's own shape, so the batch
    size and width are not hard-coded.
    """
    per_element = criterion(logits, torch.full_like(logits, target_value))
    return per_element.sum() / per_element.shape[0]


def _save_samples(real_batch, epoch):
    """Write `real_batch` and generator samples from `fixed_noise` to PNG files.

    Returns the path of the fake-sample image.
    """
    vutils.save_image(real_batch, 'real_samples.png', normalize=True)
    # No autograd graph is needed just to render samples.
    with torch.no_grad():
        # view() accepts both a (B, n) and a (B, n, 1, 1) fixed_noise.
        z = fixed_noise.view(fixed_noise.size(0), -1).to(device)
        samples = generator(z)
    fake_path = 'fake_samples_epoch_%03d.png' % (epoch)
    vutils.save_image(samples, fake_path, normalize=True)
    return fake_path


def _show_samples(fake_path, epoch):
    """Display the saved real and fake grids side by side in one figure, then close it."""
    fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(20, 20))
    panels = ((ax_real, 'real_samples.png', 'Real samples'),
              (ax_fake, fake_path, 'Fake samples, epoch %d' % epoch))
    for ax, path, title in panels:
        with Image.open(path) as img:
            ax.imshow(np.asarray(img))
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    # Close explicitly so figures do not pile up on non-inline backends.
    plt.close(fig)


for epoch in range(NUM_EPOCHS):
    for i, (real_cpu, _) in enumerate(dataloader):
        real = real_cpu.to(device)
        batch_size = real.size(0)

        # ---- Discriminator step: push real toward 1, fake toward 0 ----
        discriminator.zero_grad()
        output = discriminator(real)
        errD_real = _gan_loss(output, 1.0)
        errD_real.backward()
        # Report probabilities (sigmoid of the logits) to match the "D(x)" label.
        D_x = torch.sigmoid(output.detach()).mean().item()

        # Standard-normal latents, the same distribution as fixed_noise
        # (previously N(-1, 1), which mismatched the sampling noise).
        noise = torch.randn(batch_size, noise_size, device=device)
        fake = generator(noise)
        output = discriminator(fake.detach())
        errD_fake = _gan_loss(output, 0.0)
        errD_fake.backward()
        D_G_z1 = torch.sigmoid(output.detach()).mean().item()
        errD = errD_real + errD_fake
        optimizer_discriminator.step()

        # ---- Generator step: make D label the fakes as real ----
        # D also accumulates grads here; they are cleared by discriminator.zero_grad()
        # at the start of the next iteration.
        generator.zero_grad()
        output = discriminator(fake)
        errG = _gan_loss(output, 1.0)
        errG.backward()
        D_G_z2 = torch.sigmoid(output.detach()).mean().item()
        optimizer_generator.step()

        if i % LOG_EVERY == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, NUM_EPOCHS, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if epoch > LATE_EPOCH and i % SNAPSHOT_EVERY == 0:
            _save_samples(real_cpu, epoch)

    # End-of-epoch snapshot (overwrites any mid-epoch file for this epoch) and preview.
    fake_path = _save_samples(real_cpu, epoch)
    _show_samples(fake_path, epoch)

# Checkpoint once, after the final epoch; this runs outside the epoch loop.
torch.save(generator.state_dict(), 'netG_epoch_%d.pth' % (epoch))
torch.save(discriminator.state_dict(), 'netD_epoch_%d.pth' % (epoch))
