In [1]:
# useful reference: https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html

In [9]:
import argparse
import os
import numpy as np
import math

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import matplotlib.pyplot as plt

In [10]:
os.makedirs("images", exist_ok=True)  # output directory for sample grids written during training

img_dir = '../dataset'    # ImageFolder root (one sub-directory per class)
n_epochs = 100
batch_size = 20
lr = 0.0002               # Adam learning rate (shared by G and D)
b1 = .5                   # Adam beta1
b2 = .999                 # Adam beta2
n_cpu = 8                 # DataLoader worker processes
latent_dim = 2            # size of z; 2 so the latent space can be explored with two sliders
img_size = 256            # side length of the square training / generated images
channels = 3              # 3 = RGB, 1 = grayscale
sample_interval = 25      # save a grid of generated images every this many epochs
seed = 0                  # RNG seed so a full re-run reproduces the same training

cuda = torch.cuda.is_available()

# Seed every RNG the notebook draws from: numpy samples the latent vectors,
# torch drives weight init, dropout and DataLoader shuffling (on all devices).
np.random.seed(seed)
torch.manual_seed(seed)


def weights_init_normal(m):
    """DCGAN-style parameter initialiser, intended for ``module.apply``.

    Modules whose class name contains "Conv" get weights drawn from N(0, 0.02).
    Modules whose class name contains "BatchNorm2d" get weights from N(1, 0.02)
    and a zero bias. Every other module, and conv biases, is left untouched.
    """
    name = type(m).__name__
    if "Conv" in name:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm2d" not in name:
        return
    nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
    nn.init.zeros_(m.bias.data)


class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> image with values in [-1, 1].

    A linear layer projects ``z`` to a 128-channel feature map of side
    ``out_size // 4``. Two (2x nearest upsample -> 3x3 conv) stages bring it
    to ``out_size``, and a final 3x3 conv + Tanh emits ``out_channels``
    channels.

    Args:
        z_dim: length of the latent vector; defaults to the notebook's
            ``latent_dim``.
        out_size: side length of the square output image, must be divisible
            by 4; defaults to ``img_size``.
        out_channels: number of image channels; defaults to ``channels``.

    Raises:
        ValueError: if ``out_size`` is not divisible by 4.
    """

    def __init__(self, z_dim=None, out_size=None, out_channels=None):
        super(Generator, self).__init__()
        z_dim = latent_dim if z_dim is None else z_dim
        out_size = img_size if out_size is None else out_size
        out_channels = channels if out_channels is None else out_channels
        if out_size % 4:
            # Two 2x upsamplings only reach sizes that are multiples of 4.
            raise ValueError(
                "Generator output size must be divisible by 4, got %d" % out_size
            )

        self.init_size = out_size // 4
        self.l1 = nn.Sequential(nn.Linear(z_dim, 128 * self.init_size ** 2))

        # BatchNorm2d's second positional argument is ``eps`` (not momentum);
        # a value like 0.8 there would dominate the batch variance in the
        # normalisation denominator. Use PyTorch's default eps and momentum.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, out_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape (batch, z_dim) to images of shape (batch, C, H, W)."""
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)


class Discriminator(nn.Module):
    """DCGAN-style discriminator: image -> probability that it is real.

    Four stride-2 conv blocks (16, 32, 64, 128 filters) roughly halve the
    spatial size each time. A linear layer + Sigmoid then maps the flattened
    features to one probability per image.

    Args:
        in_channels: channels of the input image; defaults to the notebook's
            ``channels``.
        in_size: side length of the square input image; defaults to
            ``img_size``.
    """

    def __init__(self, in_channels=None, in_size=None):
        super(Discriminator, self).__init__()
        in_channels = channels if in_channels is None else in_channels
        in_size = img_size if in_size is None else in_size

        def discriminator_block(in_filters, out_filters, bn=True):
            # conv(k=3, s=2, p=1) -> LeakyReLU -> channel dropout [-> BatchNorm]
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # Default eps: BatchNorm2d's second positional argument is eps,
                # not momentum, so a bare 0.8 there would be a huge epsilon.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # A 3x3 / stride-2 / padding-1 conv maps side n to ceil(n / 2). Apply
        # that four times instead of n // 16, which undercounts whenever the
        # size is not a multiple of 16 (e.g. 28 -> 14 -> 7 -> 4 -> 2, not 1).
        ds_size = in_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        """Return a (batch, 1) tensor of real-probabilities for ``img``."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        return self.adv_layer(out)


# The discriminator ends in a Sigmoid, so binary cross-entropy on its output
# is the adversarial loss for both the generator and discriminator steps.
adversarial_loss = torch.nn.BCELoss()

# Build G first, then D (same construction order as always, so the default
# PyTorch initialisation consumes the RNG identically).
generator, discriminator = Generator(), Discriminator()

# Place models and loss on the GPU when one is available.
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Overwrite conv / batch-norm parameters with the DCGAN-style initialiser.
for _net in (generator, discriminator):
    _net.apply(weights_init_normal)
del _net


class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """ImageFolder variant that also reports the source file of each sample.

    Items are ``(sample, target, path)`` rather than ImageFolder's
    ``(sample, target)``.
    """

    def __getitem__(self, index):
        sample, target = super().__getitem__(index)
        path = self.imgs[index][0]
        return sample, target, path

# Preprocessing: resize the shorter side to img_size, center-crop to a square,
# then scale pixels from [0, 1] to [-1, 1] to match the generator's Tanh range.
# ImageFolder's default loader converts every image to RGB, so when
# channels == 1 convert to grayscale; otherwise real batches would carry 3
# channels while the discriminator's first conv is built for `channels`.
transform = transforms.Compose(
    [transforms.Resize(img_size), transforms.CenterCrop(img_size)]
    + ([transforms.Grayscale(num_output_channels=1)] if channels == 1 else [])
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * channels, std=(0.5,) * channels),
    ]
)

# img_dir must follow the ImageFolder layout (one sub-directory per class);
# labels and paths are returned but ignored by the training loop.
train_dataset = ImageFolderWithPaths(
    root=img_dir,
    transform=transform
)

# drop_last=True keeps every batch at exactly batch_size; note that a dataset
# smaller than batch_size therefore yields no batches at all.
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=n_cpu,
    shuffle=True,
    drop_last=True
)

# Optimizers: one Adam per network, sharing lr and betas.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# Legacy tensor constructor on the selected device; the training loop uses it
# to build targets and latent batches.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [11]:
# ----------
#  Training
# ----------

for epoch in range(n_epochs):
    for i, (imgs, _, _) in enumerate(dataloader):

        # Adversarial ground truths: 1 = real, 0 = fake. Plain tensors do not
        # require grad, so no (deprecated) Variable wrapper is needed.
        valid = Tensor(imgs.shape[0], 1).fill_(1.0)
        fake = Tensor(imgs.shape[0], 1).fill_(0.0)

        # Configure input: move the real batch to the training device/dtype.
        real_imgs = imgs.type(Tensor)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # zero_grad also clears the D gradients accumulated by g_loss.backward().
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples;
        # detach() keeps this step from back-propagating into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        # Save one sample grid per sampled epoch, from its last batch, and always
        # for the final epoch. Checking only the epoch here would rewrite the
        # same file on every batch of that epoch.
        is_last_batch = i == len(dataloader) - 1
        if is_last_batch and (epoch % sample_interval == 0 or epoch == n_epochs - 1):
            save_image(gen_imgs.detach()[:25], "images/%d.png" % epoch, nrow=5, normalize=True)

[Epoch 0/200] [Batch 0/6] [D loss: 0.693209] [G loss: 0.691682]
[Epoch 0/200] [Batch 1/6] [D loss: 0.692419] [G loss: 0.692087]
[Epoch 0/200] [Batch 2/6] [D loss: 0.691384] [G loss: 0.692457]
[Epoch 0/200] [Batch 3/6] [D loss: 0.689352] [G loss: 0.692803]
[Epoch 0/200] [Batch 4/6] [D loss: 0.687504] [G loss: 0.693199]
[Epoch 0/200] [Batch 5/6] [D loss: 0.682552] [G loss: 0.693178]
[Epoch 1/200] [Batch 0/6] [D loss: 0.674631] [G loss: 0.691690]
[Epoch 1/200] [Batch 1/6] [D loss: 0.657683] [G loss: 0.684936]
[Epoch 1/200] [Batch 2/6] [D loss: 0.642764] [G loss: 0.660109]
[Epoch 1/200] [Batch 3/6] [D loss: 0.658566] [G loss: 0.599663]
[Epoch 1/200] [Batch 4/6] [D loss: 0.691423] [G loss: 0.473432]
[Epoch 1/200] [Batch 5/6] [D loss: 0.776700] [G loss: 0.442369]
[Epoch 2/200] [Batch 0/6] [D loss: 0.743901] [G loss: 0.517997]
[Epoch 2/200] [Batch 1/6] [D loss: 0.706106] [G loss: 0.552059]
[Epoch 2/200] [Batch 2/6] [D loss: 0.697030] [G loss: 0.630843]
[Epoch 2/200] [Batch 3/6] [D loss: 0.703

[Epoch 21/200] [Batch 1/6] [D loss: 0.621979] [G loss: 0.814500]
[Epoch 21/200] [Batch 2/6] [D loss: 0.728016] [G loss: 0.689179]
[Epoch 21/200] [Batch 3/6] [D loss: 0.615353] [G loss: 0.666093]
[Epoch 21/200] [Batch 4/6] [D loss: 0.685234] [G loss: 0.555764]
[Epoch 21/200] [Batch 5/6] [D loss: 0.705958] [G loss: 0.619468]
[Epoch 22/200] [Batch 0/6] [D loss: 0.711321] [G loss: 0.598957]
[Epoch 22/200] [Batch 1/6] [D loss: 0.680507] [G loss: 0.654149]
[Epoch 22/200] [Batch 2/6] [D loss: 0.719054] [G loss: 0.609707]
[Epoch 22/200] [Batch 3/6] [D loss: 0.725845] [G loss: 0.677982]
[Epoch 22/200] [Batch 4/6] [D loss: 0.740699] [G loss: 0.690626]
[Epoch 22/200] [Batch 5/6] [D loss: 0.748055] [G loss: 0.638531]
[Epoch 23/200] [Batch 0/6] [D loss: 0.758780] [G loss: 0.700709]
[Epoch 23/200] [Batch 1/6] [D loss: 0.689735] [G loss: 0.750688]
[Epoch 23/200] [Batch 2/6] [D loss: 0.720902] [G loss: 0.720737]
[Epoch 23/200] [Batch 3/6] [D loss: 0.716493] [G loss: 0.674857]
[Epoch 23/200] [Batch 4/6

[Epoch 42/200] [Batch 2/6] [D loss: 0.690417] [G loss: 0.707697]
[Epoch 42/200] [Batch 3/6] [D loss: 0.694563] [G loss: 0.715728]
[Epoch 42/200] [Batch 4/6] [D loss: 0.703983] [G loss: 0.737373]
[Epoch 42/200] [Batch 5/6] [D loss: 0.682843] [G loss: 0.697803]
[Epoch 43/200] [Batch 0/6] [D loss: 0.686903] [G loss: 0.710279]
[Epoch 43/200] [Batch 1/6] [D loss: 0.695777] [G loss: 0.702707]
[Epoch 43/200] [Batch 2/6] [D loss: 0.702898] [G loss: 0.692184]
[Epoch 43/200] [Batch 3/6] [D loss: 0.710948] [G loss: 0.706478]
[Epoch 43/200] [Batch 4/6] [D loss: 0.715878] [G loss: 0.704295]
[Epoch 43/200] [Batch 5/6] [D loss: 0.699112] [G loss: 0.734653]
[Epoch 44/200] [Batch 0/6] [D loss: 0.705099] [G loss: 0.689840]
[Epoch 44/200] [Batch 1/6] [D loss: 0.695533] [G loss: 0.716887]
[Epoch 44/200] [Batch 2/6] [D loss: 0.695812] [G loss: 0.694710]
[Epoch 44/200] [Batch 3/6] [D loss: 0.697530] [G loss: 0.685388]
[Epoch 44/200] [Batch 4/6] [D loss: 0.699271] [G loss: 0.695562]
[Epoch 44/200] [Batch 5/6

[Epoch 63/200] [Batch 3/6] [D loss: 0.699474] [G loss: 0.715899]
[Epoch 63/200] [Batch 4/6] [D loss: 0.697720] [G loss: 0.718283]
[Epoch 63/200] [Batch 5/6] [D loss: 0.692012] [G loss: 0.711821]
[Epoch 64/200] [Batch 0/6] [D loss: 0.689568] [G loss: 0.711664]
[Epoch 64/200] [Batch 1/6] [D loss: 0.686847] [G loss: 0.703362]
[Epoch 64/200] [Batch 2/6] [D loss: 0.688535] [G loss: 0.698752]
[Epoch 64/200] [Batch 3/6] [D loss: 0.692824] [G loss: 0.696378]
[Epoch 64/200] [Batch 4/6] [D loss: 0.687072] [G loss: 0.701597]
[Epoch 64/200] [Batch 5/6] [D loss: 0.689062] [G loss: 0.684824]
[Epoch 65/200] [Batch 0/6] [D loss: 0.696504] [G loss: 0.689146]
[Epoch 65/200] [Batch 1/6] [D loss: 0.691750] [G loss: 0.692474]
[Epoch 65/200] [Batch 2/6] [D loss: 0.694195] [G loss: 0.693440]
[Epoch 65/200] [Batch 3/6] [D loss: 0.694324] [G loss: 0.683501]
[Epoch 65/200] [Batch 4/6] [D loss: 0.696495] [G loss: 0.690264]
[Epoch 65/200] [Batch 5/6] [D loss: 0.700037] [G loss: 0.682920]
[Epoch 66/200] [Batch 0/6

[Epoch 84/200] [Batch 4/6] [D loss: 0.692273] [G loss: 0.689792]
[Epoch 84/200] [Batch 5/6] [D loss: 0.696289] [G loss: 0.693777]
[Epoch 85/200] [Batch 0/6] [D loss: 0.691800] [G loss: 0.695257]
[Epoch 85/200] [Batch 1/6] [D loss: 0.689656] [G loss: 0.693422]
[Epoch 85/200] [Batch 2/6] [D loss: 0.694554] [G loss: 0.688672]
[Epoch 85/200] [Batch 3/6] [D loss: 0.692365] [G loss: 0.692883]
[Epoch 85/200] [Batch 4/6] [D loss: 0.690809] [G loss: 0.692525]
[Epoch 85/200] [Batch 5/6] [D loss: 0.690203] [G loss: 0.682778]
[Epoch 86/200] [Batch 0/6] [D loss: 0.692091] [G loss: 0.695071]
[Epoch 86/200] [Batch 1/6] [D loss: 0.697698] [G loss: 0.690735]
[Epoch 86/200] [Batch 2/6] [D loss: 0.694023] [G loss: 0.692892]
[Epoch 86/200] [Batch 3/6] [D loss: 0.694207] [G loss: 0.686378]
[Epoch 86/200] [Batch 4/6] [D loss: 0.693411] [G loss: 0.683194]
[Epoch 86/200] [Batch 5/6] [D loss: 0.685853] [G loss: 0.695348]
[Epoch 87/200] [Batch 0/6] [D loss: 0.690628] [G loss: 0.687362]
[Epoch 87/200] [Batch 1/6

[Epoch 105/200] [Batch 4/6] [D loss: 0.700055] [G loss: 0.690871]
[Epoch 105/200] [Batch 5/6] [D loss: 0.695119] [G loss: 0.692131]
[Epoch 106/200] [Batch 0/6] [D loss: 0.699956] [G loss: 0.690910]
[Epoch 106/200] [Batch 1/6] [D loss: 0.697867] [G loss: 0.690899]
[Epoch 106/200] [Batch 2/6] [D loss: 0.696175] [G loss: 0.691235]
[Epoch 106/200] [Batch 3/6] [D loss: 0.695573] [G loss: 0.693589]
[Epoch 106/200] [Batch 4/6] [D loss: 0.700749] [G loss: 0.695839]
[Epoch 106/200] [Batch 5/6] [D loss: 0.692638] [G loss: 0.692340]
[Epoch 107/200] [Batch 0/6] [D loss: 0.694503] [G loss: 0.685727]
[Epoch 107/200] [Batch 1/6] [D loss: 0.694282] [G loss: 0.694322]
[Epoch 107/200] [Batch 2/6] [D loss: 0.694913] [G loss: 0.690547]
[Epoch 107/200] [Batch 3/6] [D loss: 0.693541] [G loss: 0.688225]
[Epoch 107/200] [Batch 4/6] [D loss: 0.690244] [G loss: 0.683562]
[Epoch 107/200] [Batch 5/6] [D loss: 0.688061] [G loss: 0.691769]
[Epoch 108/200] [Batch 0/6] [D loss: 0.690544] [G loss: 0.689416]
[Epoch 108

[Epoch 126/200] [Batch 3/6] [D loss: 0.688419] [G loss: 0.687791]
[Epoch 126/200] [Batch 4/6] [D loss: 0.688320] [G loss: 0.692570]
[Epoch 126/200] [Batch 5/6] [D loss: 0.697530] [G loss: 0.698441]
[Epoch 127/200] [Batch 0/6] [D loss: 0.685710] [G loss: 0.704315]
[Epoch 127/200] [Batch 1/6] [D loss: 0.688974] [G loss: 0.712564]
[Epoch 127/200] [Batch 2/6] [D loss: 0.690864] [G loss: 0.691143]
[Epoch 127/200] [Batch 3/6] [D loss: 0.681373] [G loss: 0.707967]
[Epoch 127/200] [Batch 4/6] [D loss: 0.685514] [G loss: 0.703789]
[Epoch 127/200] [Batch 5/6] [D loss: 0.692171] [G loss: 0.688739]
[Epoch 128/200] [Batch 0/6] [D loss: 0.684465] [G loss: 0.703926]
[Epoch 128/200] [Batch 1/6] [D loss: 0.693987] [G loss: 0.692098]
[Epoch 128/200] [Batch 2/6] [D loss: 0.685724] [G loss: 0.680393]
[Epoch 128/200] [Batch 3/6] [D loss: 0.702328] [G loss: 0.689911]
[Epoch 128/200] [Batch 4/6] [D loss: 0.692703] [G loss: 0.673430]
[Epoch 128/200] [Batch 5/6] [D loss: 0.692041] [G loss: 0.692898]
[Epoch 129

[Epoch 147/200] [Batch 2/6] [D loss: 0.699765] [G loss: 0.690438]
[Epoch 147/200] [Batch 3/6] [D loss: 0.693886] [G loss: 0.696560]
[Epoch 147/200] [Batch 4/6] [D loss: 0.706342] [G loss: 0.688298]
[Epoch 147/200] [Batch 5/6] [D loss: 0.698401] [G loss: 0.703612]
[Epoch 148/200] [Batch 0/6] [D loss: 0.698344] [G loss: 0.700962]
[Epoch 148/200] [Batch 1/6] [D loss: 0.693815] [G loss: 0.695183]
[Epoch 148/200] [Batch 2/6] [D loss: 0.695590] [G loss: 0.715663]
[Epoch 148/200] [Batch 3/6] [D loss: 0.690979] [G loss: 0.708139]
[Epoch 148/200] [Batch 4/6] [D loss: 0.698185] [G loss: 0.693158]
[Epoch 148/200] [Batch 5/6] [D loss: 0.696037] [G loss: 0.701235]
[Epoch 149/200] [Batch 0/6] [D loss: 0.692879] [G loss: 0.701669]
[Epoch 149/200] [Batch 1/6] [D loss: 0.694200] [G loss: 0.689097]
[Epoch 149/200] [Batch 2/6] [D loss: 0.692941] [G loss: 0.708581]
[Epoch 149/200] [Batch 3/6] [D loss: 0.694392] [G loss: 0.718811]
[Epoch 149/200] [Batch 4/6] [D loss: 0.696906] [G loss: 0.714216]
[Epoch 149

[Epoch 168/200] [Batch 1/6] [D loss: 0.652242] [G loss: 1.038153]
[Epoch 168/200] [Batch 2/6] [D loss: 0.571082] [G loss: 1.392763]
[Epoch 168/200] [Batch 3/6] [D loss: 0.549716] [G loss: 1.271406]
[Epoch 168/200] [Batch 4/6] [D loss: 0.459686] [G loss: 1.214958]
[Epoch 168/200] [Batch 5/6] [D loss: 0.414729] [G loss: 1.084726]
[Epoch 169/200] [Batch 0/6] [D loss: 0.433346] [G loss: 0.776501]
[Epoch 169/200] [Batch 1/6] [D loss: 0.626656] [G loss: 0.458999]
[Epoch 169/200] [Batch 2/6] [D loss: 0.611728] [G loss: 0.563367]
[Epoch 169/200] [Batch 3/6] [D loss: 0.542424] [G loss: 0.737893]
[Epoch 169/200] [Batch 4/6] [D loss: 0.614206] [G loss: 0.962968]
[Epoch 169/200] [Batch 5/6] [D loss: 0.771790] [G loss: 0.662549]
[Epoch 170/200] [Batch 0/6] [D loss: 1.049061] [G loss: 0.503731]
[Epoch 170/200] [Batch 1/6] [D loss: 0.828050] [G loss: 0.773536]
[Epoch 170/200] [Batch 2/6] [D loss: 0.971503] [G loss: 0.694622]
[Epoch 170/200] [Batch 3/6] [D loss: 1.199655] [G loss: 0.438737]
[Epoch 170

[Epoch 189/200] [Batch 0/6] [D loss: 0.349656] [G loss: 1.450049]
[Epoch 189/200] [Batch 1/6] [D loss: 0.563122] [G loss: 0.866496]
[Epoch 189/200] [Batch 2/6] [D loss: 0.744121] [G loss: 0.695884]
[Epoch 189/200] [Batch 3/6] [D loss: 0.635363] [G loss: 2.308929]
[Epoch 189/200] [Batch 4/6] [D loss: 0.774921] [G loss: 0.484427]
[Epoch 189/200] [Batch 5/6] [D loss: 0.732974] [G loss: 1.083977]
[Epoch 190/200] [Batch 0/6] [D loss: 0.564276] [G loss: 1.148520]
[Epoch 190/200] [Batch 1/6] [D loss: 0.748715] [G loss: 0.274506]
[Epoch 190/200] [Batch 2/6] [D loss: 0.754760] [G loss: 1.681927]
[Epoch 190/200] [Batch 3/6] [D loss: 0.834144] [G loss: 1.618726]
[Epoch 190/200] [Batch 4/6] [D loss: 0.686828] [G loss: 0.450847]
[Epoch 190/200] [Batch 5/6] [D loss: 0.647730] [G loss: 0.525481]
[Epoch 191/200] [Batch 0/6] [D loss: 0.781273] [G loss: 0.953063]
[Epoch 191/200] [Batch 1/6] [D loss: 0.635535] [G loss: 0.782223]
[Epoch 191/200] [Batch 2/6] [D loss: 0.895526] [G loss: 0.629325]
[Epoch 191

In [12]:
# Explore latent space
from ipywidgets import interact

def sample(z1, z2):
    """Display the generator's image for the latent point (z1, z2).

    Used as the callback for the ``interact`` sliders. Latent dimensions
    beyond the first two stay at zero.
    """
    z = np.zeros((1, latent_dim))
    z[0, 0] = z1
    z[0, 1] = z2

    # A single latent vector is a batch of one: in train mode BatchNorm would
    # normalise with that sample's own statistics and overwrite the running
    # averages on every slider move. Evaluate in inference mode without
    # building a graph, then restore the generator's previous mode.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Move to CPU for matplotlib and map Tanh's [-1, 1] to [0, 1].
            img = generator(Tensor(z))[0].cpu().numpy() * .5 + .5
    finally:
        generator.train(was_training)

    if channels == 1:
        plt.imshow(img[0], cmap='gray')
    else:
        # CHW -> HWC for matplotlib colour images.
        plt.imshow(np.transpose(img, (1, 2, 0)))
    
interact(sample, z1=(-1, 1, .1), z2=(-1, 1, .1))

interactive(children=(FloatSlider(value=0.0, description='z1', max=1.0, min=-1.0), FloatSlider(value=0.0, desc…

<function __main__.sample(z1, z2)>