In [1]:
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torch.autograd import Variable
from torchvision import transforms, utils

In [2]:
'''Vanilla Generative Adversarial Network'''

# disciminator network
class D(nn.Module):
    def __init__(self):
        super(D, self).__init__()
        self.ndf = 32
        self.main = nn.Sequential(
            nn.Conv2d(3, self.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf, self.ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.main(x)

# generator network 
class G(nn.Module):
    def __init__(self, latent):
        super(G, self).__init__()
        self.ngf = 32
        self.latent = latent
        self.main = nn.Sequential(
            nn.ConvTranspose2d(self.latent, self.ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 4),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(self.ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

In [3]:
# custom pytorch dataset
class PokeDataset(Dataset):
    def __init__(self, root_dir):
        self.root = root_dir
        self.tform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(os.listdir(self.root))

    def __getitem__(self, idx):
        file = os.path.dirname("")
        working_dir = os.path.join(file, self.root)
        imname = str(idx+1) + '.png'
        impath = os.path.join(working_dir, imname)
        tmp_imp = Image.open(impath)
        # remove alpha channel
        img_rgb = Image.new('RGB', tmp_imp.size, color=(255, 255, 255))
        img_rgb.paste(tmp_imp, mask=tmp_imp.split()[3])
        img_64rgb = img_rgb.resize((64,64))
        return self.tform(img_64rgb)

In [4]:
epochs = 1000
lr = 0.0001
torch.manual_seed(12345)
batch_size = 64
use_cuda = torch.cuda.is_available()
latent_size = 100

In [5]:
dataset = PokeDataset('./data/pokemon')
dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2)
discriminator = D()
generator = G(latent_size)

In [6]:
print(dataset[0])


(0 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(1 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...             ⋱             ...          
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000

(2 ,.,.) = 
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
  1.0000  1.0000  1.0000  ...   1.0000  1.0000  1.0000
           ...            

In [7]:
# loss(o, t) = - 1/n \sum_i (t[i] log(o[i]) + (1 - t[i]) log(1 - o[i]))
loss = nn.BCELoss(size_average=True)

if use_cuda:
    print('CUDA device found and active')
    discriminator.cuda()
    generator.cuda()
    loss.cuda()

# optimizers
optimD = optim.Adam(discriminator.parameters(), lr, betas=(0.5, 0.999))
optimG = optim.Adam(generator.parameters(), lr, betas=(0.5, 0.999))

CUDA device found and active


    Found GPU0 GeForce GTX 1060 with Max-Q Design which requires CUDA_VERSION >= 8000 for
     optimal performance and fast startup time, but your PyTorch was compiled
     with CUDA_VERSION 7050. Please install the correct PyTorch binary
     using instructions from http://pytorch.org
    


In [8]:
print_every = 100
im_samples_every = 100

test_noise = torch.Tensor(batch_size, latent_size, 1, 1).normal_(0, 1)
if use_cuda:
    test_noise = test_noise.cuda()

test_noiseV = Variable(test_noise)

for i in range(epochs):
    for j, data in enumerate(dataloader):
        latent = torch.Tensor(data.size(0), latent_size, 1, 1)
        label = torch.Tensor(data.size(0), 1, 1, 1)

        if use_cuda:
            latent = latent.cuda()
            label = label.cuda()
            data = data.cuda()

        # train discriminator        
        # train on real
        # input an image, 0|1 if fake|real        
        optimD.zero_grad()
        real_label = Variable(label.fill_(1), requires_grad=False)
        real_im = Variable(data, requires_grad=False)

        out = discriminator(real_im)
        loss_real = loss(out, real_label)
        loss_real.backward()

        # train D on fake
        noise = Variable(latent.normal_(0, 1), requires_grad=False)
        fake_label = Variable(label.fill_(0), requires_grad=False)

        fake = generator(noise)
        out = discriminator(fake.detach())
        loss_fake = loss(out, fake_label)
        loss_fake.backward()
        optimD.step()

        # train generator
        fake_real_label = Variable(label.fill_(1), requires_grad=False)       
        optimG.zero_grad()
        out = discriminator(fake)
        loss_gen = loss(out, fake_real_label)
        loss_gen.backward()
        optimG.step()
        
        if i % print_every == 0 and j % 12 == 0:
            print('epoch [{}]/[{}]    lossD {:.5f}    lossG {:.5f}'.format(
                    i, epochs, (loss_real.cpu().data[0] + loss_fake.cpu().data[0]), 
                    loss_gen.cpu().data[0]))

        if i % im_samples_every == 0:
            out = generator(test_noiseV).cpu().data
            utils.save_image(out, './data/fake/fake'+str(i//im_samples_every)+'.jpg', normalize=True)
            torch.save(discriminator, 'poke_dis.pkl')
            torch.save(generator, 'poke_gen.pkl')

epoch [0]/[1000]    lossD 1.32882    lossG 0.89517


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


epoch [100]/[1000]    lossD 0.76900    lossG 5.91544
epoch [200]/[1000]    lossD 0.28005    lossG 2.99047
epoch [300]/[1000]    lossD 4.48831    lossG 11.28469
epoch [400]/[1000]    lossD 0.07096    lossG 3.68977
epoch [500]/[1000]    lossD 0.06381    lossG 4.38413
epoch [600]/[1000]    lossD 0.03484    lossG 4.45693
epoch [700]/[1000]    lossD 0.04137    lossG 4.47763
epoch [800]/[1000]    lossD 0.03775    lossG 4.56968
epoch [900]/[1000]    lossD 0.19890    lossG 14.35035


In [9]:
torch.save(discriminator, 'poke_dis.pkl')
torch.save(generator, 'poke_gen.pkl')

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
