In [1]:
# useful reference: https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html

In [2]:
import argparse
import os
import numpy as np
import math

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import matplotlib.pyplot as plt

In [6]:
# Output directory for the sample grids written during training.
os.makedirs("images", exist_ok=True)

# --- Data ---
img_dir = '../dataset'   # dataset root passed to the image-folder dataset
img_size = 256           # images are resized / centre-cropped to img_size x img_size
channels = 3             # channels per image

# --- Model ---
latent_dim = 2           # size of the generator's input vector

# --- Optimisation ---
n_epochs = 100
batch_size = 20
lr = 0.0002              # Adam learning rate
b1 = 0.5                 # Adam beta1
b2 = 0.999               # Adam beta2
n_cpu = 8                # DataLoader worker processes

# --- Logging ---
sample_interval = 25     # a sample grid is saved on epochs divisible by this value

# --- Reproducibility ---
# Seed NumPy and PyTorch so weight initialisation and DataLoader shuffling are
# repeatable after a kernel restart. (GPU kernels may still introduce small
# run-to-run numerical differences.)
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Whether a CUDA device is available; is_available() already returns a bool.
cuda = torch.cuda.is_available()


def weights_init_normal(m):
    """Initialise convolution and BatchNorm2d parameters in place.

    Meant to be passed to ``nn.Module.apply``, which calls it once for every
    sub-module. Dispatch is by class name:

    * name contains ``"Conv"``: weight ~ N(0.0, 0.02)
    * name contains ``"BatchNorm2d"``: weight ~ N(1.0, 0.02), bias = 0

    Every other module (``nn.Linear`` included) is left untouched. A module
    whose name matches but which has no ``weight`` / ``bias`` tensor -- for
    example a custom container called ``ConvBlock`` or
    ``BatchNorm2d(affine=False)`` -- is skipped instead of raising
    ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default covers both "attribute missing" and "registered as None".
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        if weight is not None:
            torch.nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.constant_(bias.data, 0.0)


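# NOTE: earlier convolutional Generator, left commented out for reference.
# The fully connected Generator defined below is the one actually used.
# class Generator(nn.Module):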
#     def __init__(self):
#         super(Generator, self).__init__()

#         self.init_size = img_size // 4
#         self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

#         self.conv_blocks = nn.Sequential(
#             nn.BatchNorm2d(128),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(128, 128, 3, stride=1, padding=1),
#             nn.BatchNorm2d(128, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(128, 64, 3, stride=1, padding=1),
#             nn.BatchNorm2d(64, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(64, channels, 3, stride=1, padding=1),
#             nn.Tanh(),
#         )

#     def forward(self, z):
#         out = self.l1(z)
#         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
#         img = self.conv_blocks(out)
#         return img
    
class Generator(nn.Module):
    """Fully connected network mapping a latent vector to a square image.

    Architecture: Linear(latent, 512) -> LeakyReLU(0.2) -> Linear(512, C*H*W)
    -> Tanh, reshaped to ``(batch, C, H, W)`` with ``H == W``. The final Tanh
    keeps every output value in [-1, 1].

    The sizes are captured on the instance at construction time, so ``forward``
    keeps working even if the notebook-level ``img_size`` / ``channels`` are
    changed after the model was built.

    Args:
        latent_size: length of the input vector; defaults to the notebook-level
            ``latent_dim``.
        image_size: side length of the output image; defaults to the
            notebook-level ``img_size``.
        n_channels: number of output channels; defaults to the notebook-level
            ``channels``.
    """

    def __init__(self, latent_size=None, image_size=None, n_channels=None):
        super().__init__()

        # Fall back to the notebook configuration only for arguments not given.
        self.latent_dim = latent_dim if latent_size is None else latent_size
        self.img_size = img_size if image_size is None else image_size
        self.channels = channels if n_channels is None else n_channels

        # The attribute name `l1` is kept so existing state_dict keys still load.
        self.l1 = nn.Sequential(
            nn.Linear(self.latent_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, self.channels * self.img_size ** 2),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images for a batch of latent vectors.

        Args:
            z: tensor of shape ``(batch, latent_size)``.

        Returns:
            Tensor of shape ``(batch, channels, img_size, img_size)``.
        """
        out = self.l1(z)
        return out.view(out.shape[0], self.channels, self.img_size, self.img_size)

# Reconstruction loss: mean squared error between generated and target images
mse_loss = torch.nn.MSELoss()

# Build the generator
generator = Generator()

# Move the model and the loss module to the GPU when one is available
# (Module.cuda() returns the module itself, so rebinding is a no-op).
if cuda:
    generator = generator.cuda()
    mse_loss = mse_loss.cuda()

# Run the custom initialiser over every sub-module.
# NOTE(review): weights_init_normal only matches class names containing "Conv"
# or "BatchNorm2d"; the active Generator is built from Linear layers only, so
# this call appears to leave PyTorch's default initialisation in place.
generator.apply(weights_init_normal)

class ImageFolderWithPaths(torchvision.datasets.ImageFolder):
    """ImageFolder variant whose items also carry the source file path."""

    def __getitem__(self, index):
        """Return the parent's item tuple with the image's file path appended."""
        base_item = super().__getitem__(index)
        file_path = self.imgs[index][0]
        return base_item + (file_path,)

# Preprocessing: resize and centre-crop to img_size x img_size, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Every item is (image, label, path) -- see ImageFolderWithPaths.
train_dataset = ImageFolderWithPaths(root=img_dir, transform=transform)

# split it up to see how we generalize (currently disabled)
#path, dirs, files = next(os.walk(img_dir))
#file_count = len(files)
#num_train = int(0.7 * file_count)
#num_val = file_count - num_train

#train_dataset, val_dataset = torch.utils.data.random_split(dataset, [num_train, num_val])

# Shuffled mini-batches; an incomplete final batch is dropped.
dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,
    drop_last=True,
)

# Adam optimiser over all generator parameters
optimizer_G = torch.optim.Adam(params=generator.parameters(), lr=lr, betas=(b1, b2))

# Tensor-type aliases used by later cells for casting (both names refer to the
# same type: CUDA float tensors when `cuda` is set, CPU float tensors otherwise).
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
FloatTensor = Tensor

# Device handle for .to(device) calls; echoed as the cell output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [7]:
# ----------
#  Training
# ----------

def _light_coords_from_paths(file_paths):
    """Decode normalised generator inputs from image file names.

    Each file stem (name without directory and extension) is split on ``_``
    and parsed as floats, e.g. ``"-30_75.png"`` -> ``[-30.0, 75.0]``.
    Following the range notes in the original code, field 0 is taken to lie in
    [-50, 50] and field 1 in [0, 100]; both are rescaled to about [-1, 1].

    Uses ``os.path`` so it works with any path separator and any extension
    length (the previous ``split('/')`` / ``[:-4]`` slicing did not).

    Args:
        file_paths: iterable of image paths as yielded by the dataloader.

    Returns:
        Float NumPy array with one row per path and one column per
        ``_``-separated field (two are expected).
    """
    stems = [os.path.splitext(os.path.basename(p))[0] for p in file_paths]
    coords = np.array([stem.split('_') for stem in stems]).astype(float)
    coords[:, 0] /= 50                        # -50..50 -> -1..1
    coords[:, 1] = (coords[:, 1] - 50) / 50   #   0..100 -> -1..1
    return coords


for epoch in range(n_epochs):
    gen_imgs = None  # stays None if the dataloader yields no batch this epoch

    for i, (imgs, labels, paths) in enumerate(dataloader):
        # Target images, moved to the training device
        real_imgs = imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # The generator input is the pair of values encoded in each file name,
        # so every input vector corresponds to exactly one target image.
        z = torch.as_tensor(
            _light_coords_from_paths(paths), dtype=torch.float32, device=device
        )

        # Generate a batch of images for those inputs
        gen_imgs = generator(z)

        # Per-pixel reconstruction error against the paired target images
        g_loss = mse_loss(gen_imgs, real_imgs)

        g_loss.backward()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), g_loss.item())
        )

    # Write the sample grid once per sampled epoch, from the last batch. It
    # used to be rewritten after every batch, each write overwriting the last.
    if epoch % sample_interval == 0 and gen_imgs is not None:
        save_image(gen_imgs.detach()[:25], "images/%d.png" % epoch, nrow=5, normalize=True)

[Epoch 0/100] [Batch 0/6] [G loss: 0.200375]
[Epoch 0/100] [Batch 1/6] [G loss: 0.194708]
[Epoch 0/100] [Batch 2/6] [G loss: 0.185402]
[Epoch 0/100] [Batch 3/6] [G loss: 0.161506]
[Epoch 0/100] [Batch 4/6] [G loss: 0.154032]
[Epoch 0/100] [Batch 5/6] [G loss: 0.137365]
[Epoch 1/100] [Batch 0/6] [G loss: 0.132911]
[Epoch 1/100] [Batch 1/6] [G loss: 0.116047]
[Epoch 1/100] [Batch 2/6] [G loss: 0.108565]
[Epoch 1/100] [Batch 3/6] [G loss: 0.107540]
[Epoch 1/100] [Batch 4/6] [G loss: 0.096627]
[Epoch 1/100] [Batch 5/6] [G loss: 0.087328]
[Epoch 2/100] [Batch 0/6] [G loss: 0.079151]
[Epoch 2/100] [Batch 1/6] [G loss: 0.073386]
[Epoch 2/100] [Batch 2/6] [G loss: 0.068953]
[Epoch 2/100] [Batch 3/6] [G loss: 0.062020]
[Epoch 2/100] [Batch 4/6] [G loss: 0.054590]
[Epoch 2/100] [Batch 5/6] [G loss: 0.057345]
[Epoch 3/100] [Batch 0/6] [G loss: 0.050842]
[Epoch 3/100] [Batch 1/6] [G loss: 0.044288]
[Epoch 3/100] [Batch 2/6] [G loss: 0.041242]
[Epoch 3/100] [Batch 3/6] [G loss: 0.039246]
[Epoch 3/1

[Epoch 30/100] [Batch 0/6] [G loss: 0.002899]
[Epoch 30/100] [Batch 1/6] [G loss: 0.002872]
[Epoch 30/100] [Batch 2/6] [G loss: 0.002278]
[Epoch 30/100] [Batch 3/6] [G loss: 0.004270]
[Epoch 30/100] [Batch 4/6] [G loss: 0.003107]
[Epoch 30/100] [Batch 5/6] [G loss: 0.001778]
[Epoch 31/100] [Batch 0/6] [G loss: 0.002499]
[Epoch 31/100] [Batch 1/6] [G loss: 0.003248]
[Epoch 31/100] [Batch 2/6] [G loss: 0.002423]
[Epoch 31/100] [Batch 3/6] [G loss: 0.002154]
[Epoch 31/100] [Batch 4/6] [G loss: 0.002972]
[Epoch 31/100] [Batch 5/6] [G loss: 0.003467]
[Epoch 32/100] [Batch 0/6] [G loss: 0.003422]
[Epoch 32/100] [Batch 1/6] [G loss: 0.002878]
[Epoch 32/100] [Batch 2/6] [G loss: 0.002960]
[Epoch 32/100] [Batch 3/6] [G loss: 0.002697]
[Epoch 32/100] [Batch 4/6] [G loss: 0.002605]
[Epoch 32/100] [Batch 5/6] [G loss: 0.001757]
[Epoch 33/100] [Batch 0/6] [G loss: 0.002518]
[Epoch 33/100] [Batch 1/6] [G loss: 0.002283]
[Epoch 33/100] [Batch 2/6] [G loss: 0.002305]
[Epoch 33/100] [Batch 3/6] [G loss

[Epoch 59/100] [Batch 5/6] [G loss: 0.001452]
[Epoch 60/100] [Batch 0/6] [G loss: 0.002004]
[Epoch 60/100] [Batch 1/6] [G loss: 0.001379]
[Epoch 60/100] [Batch 2/6] [G loss: 0.001177]
[Epoch 60/100] [Batch 3/6] [G loss: 0.001630]
[Epoch 60/100] [Batch 4/6] [G loss: 0.001462]
[Epoch 60/100] [Batch 5/6] [G loss: 0.001366]
[Epoch 61/100] [Batch 0/6] [G loss: 0.000990]
[Epoch 61/100] [Batch 1/6] [G loss: 0.001573]
[Epoch 61/100] [Batch 2/6] [G loss: 0.001424]
[Epoch 61/100] [Batch 3/6] [G loss: 0.001465]
[Epoch 61/100] [Batch 4/6] [G loss: 0.001789]
[Epoch 61/100] [Batch 5/6] [G loss: 0.001559]
[Epoch 62/100] [Batch 0/6] [G loss: 0.001514]
[Epoch 62/100] [Batch 1/6] [G loss: 0.001362]
[Epoch 62/100] [Batch 2/6] [G loss: 0.001380]
[Epoch 62/100] [Batch 3/6] [G loss: 0.001517]
[Epoch 62/100] [Batch 4/6] [G loss: 0.001490]
[Epoch 62/100] [Batch 5/6] [G loss: 0.001470]
[Epoch 63/100] [Batch 0/6] [G loss: 0.001538]
[Epoch 63/100] [Batch 1/6] [G loss: 0.001426]
[Epoch 63/100] [Batch 2/6] [G loss

[Epoch 89/100] [Batch 5/6] [G loss: 0.001136]
[Epoch 90/100] [Batch 0/6] [G loss: 0.001084]
[Epoch 90/100] [Batch 1/6] [G loss: 0.001061]
[Epoch 90/100] [Batch 2/6] [G loss: 0.001135]
[Epoch 90/100] [Batch 3/6] [G loss: 0.000814]
[Epoch 90/100] [Batch 4/6] [G loss: 0.001211]
[Epoch 90/100] [Batch 5/6] [G loss: 0.001036]
[Epoch 91/100] [Batch 0/6] [G loss: 0.001128]
[Epoch 91/100] [Batch 1/6] [G loss: 0.001007]
[Epoch 91/100] [Batch 2/6] [G loss: 0.000762]
[Epoch 91/100] [Batch 3/6] [G loss: 0.000752]
[Epoch 91/100] [Batch 4/6] [G loss: 0.001493]
[Epoch 91/100] [Batch 5/6] [G loss: 0.000971]
[Epoch 92/100] [Batch 0/6] [G loss: 0.001083]
[Epoch 92/100] [Batch 1/6] [G loss: 0.001028]
[Epoch 92/100] [Batch 2/6] [G loss: 0.001136]
[Epoch 92/100] [Batch 3/6] [G loss: 0.000860]
[Epoch 92/100] [Batch 4/6] [G loss: 0.001132]
[Epoch 92/100] [Batch 5/6] [G loss: 0.001005]
[Epoch 93/100] [Batch 0/6] [G loss: 0.001424]
[Epoch 93/100] [Batch 1/6] [G loss: 0.001293]
[Epoch 93/100] [Batch 2/6] [G loss

In [9]:
# Explore latent space
from ipywidgets import interact

def sample(z1, z2):
    """Plot the image the generator produces for the latent point ``(z1, z2)``.

    Args:
        z1: value for latent component 0.
        z2: value for latent component 1.

    Any further latent components (when ``latent_dim > 2``) are left at zero.
    """
    # Build the input on whatever device the generator's parameters live on.
    param_device = next(generator.parameters()).device
    z = torch.zeros(1, latent_dim, device=param_device)
    z[0, 0] = z1
    z[0, 1] = z2

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        gen_imgs = generator(z)

    # Copy to host memory before handing the data to matplotlib (required when
    # the generator runs on CUDA; the grayscale branch used to skip this), and
    # rescale the Tanh output from [-1, 1] to [0, 1].
    img = gen_imgs[0].cpu().numpy() * .5 + .5

    plt.figure(figsize=(8,8))
    if channels == 1:
        plt.imshow(img[0], cmap='gray')
    else:
        # matplotlib expects (height, width, channels)
        plt.imshow(np.transpose(img, (1, 2, 0)))

interact(sample, z1=(-1, 1, .1), z2=(-1, 1, .1))

interactive(children=(FloatSlider(value=0.0, description='z1', max=1.0, min=-1.0), FloatSlider(value=0.0, desc…

<function __main__.sample(z1, z2)>