In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import os
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule used by SketchesDataset
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Constants and helper methods

<b>Make sure that you run all the cells in this section!</b>

In [5]:
# Root of the 256x256 Sketchy data; photos and sketches each live under an
# augmentation-specific sub-folder of it.
DATA_DIR = '../learning/datasets/sketchy/256x256'

PHOTOS_AUG = 'tx_000100000000'
SKETCHES_AUG = 'tx_000000000010'

PHOTOS_DIR = os.path.join(DATA_DIR, 'photo', PHOTOS_AUG)
SKETCHES_DIR = os.path.join(DATA_DIR, 'sketch', SKETCHES_AUG)

# Per-sketch annotations (stats.csv, invalid-error.txt) read by SketchesDataset.
INFO_DIR = '../learning/datasets/info-06-04/info'

We define a few helper functions to convert between class labels, ImageNet IDs, and integer class IDs.

In [6]:
def generate_imagenet_id_class_map(photos_dir):
    """Returns a dictionary that maps ImageNet ID to class label, and vice versa.

    Each sub-directory of ``photos_dir`` is taken as a class label, and its
    ImageNet ID is read from the prefix (before the first ``'_'``) of the
    first file name inside it.

    Entries are skipped rather than raising in three cases:
    - stray files next to the class folders (e.g. ``.DS_Store``), which would
      otherwise raise ``NotADirectoryError``;
    - hidden files inside a class folder;
    - empty class folders, which would otherwise raise ``IndexError``.

    Folders and files are visited in sorted order, so the result does not
    depend on the arbitrary ordering of ``os.listdir``.

    Args:
        photos_dir (str): directory containing one sub-directory per class.

    Returns:
        dict: ``{imagenet_id: label, ..., label: imagenet_id, ...}``.
    """
    d = {}
    for entry in sorted(os.scandir(photos_dir), key=lambda e: e.name):
        if not entry.is_dir():
            continue
        files = sorted(name for name in os.listdir(entry.path) if not name.startswith('.'))
        if not files:
            continue
        imagenet_id = files[0].split('_')[0]
        d[imagenet_id] = entry.name
        d[entry.name] = imagenet_id
    return d

# Bidirectional ImageNet ID <-> class label lookup built from the photo class folders.
imagenet_id_class_map = generate_imagenet_id_class_map(photos_dir=PHOTOS_DIR)

In [7]:
def generate_labels_id_map(photos_dir):
    """Returns a bidirectional mapping between class labels and integer class ids.

    Ids are assigned by enumerating the *sub-directory* names of ``photos_dir``
    in sorted order. This is the same rule torchvision's ``ImageFolder`` uses
    for ``class_to_idx``, so sketch labels produced through this map line up
    with the photo labels. Plain files in ``photos_dir`` (e.g. ``.DS_Store``)
    are ignored; counting them would shift every subsequent id.

    Args:
        photos_dir (str): directory containing one sub-directory per class.

    Returns:
        dict: maps ``id (int) -> label (str)`` and ``label (str) -> id (int)``.
    """
    labels = sorted(entry.name for entry in os.scandir(photos_dir) if entry.is_dir())
    d = {}
    for i, label in enumerate(labels):
        d[i] = label
        d[label] = i
    return d

# Bidirectional class id <-> class label lookup; SketchesDataset uses it to label sketches.
labels_id_map = generate_labels_id_map(photos_dir=PHOTOS_DIR)

Since we need to filter out some sketches (using the annotations in `stats.csv`), we create a custom `Dataset` class for them.

In [8]:
class SketchesDataset(Dataset):
    """A custom Dataset of sketches, filtered using the annotations in ``stats.csv``.

    Items are ``(image, class_id)`` pairs. ``image`` is the sketch opened as an
    RGB PIL image, then passed through ``transform`` if one is given.
    ``class_id`` is looked up in the notebook-global ``labels_id_map`` by the
    sketch's class folder name.
    """

    def __init__(self, sketches_dir, info_dir, transform=None, remove_error=True, remove_ambiguous=False,
                 remove_pose=False, remove_context=False):
        """
        Initialize the sketches dataset.

        Args:
            sketches_dir (str): directory of sketches, divided by class
            info_dir (str): directory containing 'invalid-error.txt' and 'stats.csv'
            transform (callable, optional): applied to each loaded PIL image; None returns the image as-is
            remove_error (bool): set to True to remove sketches classified as erroneous
            remove_ambiguous (bool): set to True to remove sketches classified as ambiguous
            remove_pose (bool): set to True to remove sketches drawn from a wrong pose/perspective
            remove_context (bool): set to True to remove sketches with extraneous details
        """
        self.sketches_dir = sketches_dir
        self.info_dir = info_dir
        self.transform = transform
        # A context manager guarantees the file handle is closed; each entry is
        # stored without its trailing newline.
        with open(os.path.join(info_dir, 'invalid-error.txt'), 'r') as f:
            self.invalid = [line.rstrip('\n') for line in f]
        stats = pd.read_csv(os.path.join(info_dir, 'stats.csv'))
        # Every enabled flag keeps only the rows whose annotation column equals 0.
        column_filters = (
            (remove_error, 'Error?'),
            (remove_ambiguous, 'Ambiguous?'),
            (remove_pose, 'WrongPose?'),
            (remove_context, 'Context?'),
        )
        for enabled, column in column_filters:
            if enabled:
                stats = stats.loc[stats[column] == 0]
        self.stats = stats

    def __len__(self):
        """Number of sketches left after filtering."""
        return len(self.stats)

    def __getitem__(self, idx):
        """Load the idx-th remaining sketch and return ``(image, class_id)``."""
        row = self.stats.iloc[idx]
        # The class folder name is the category with spaces replaced by underscores.
        class_folder = row['Category'].replace(' ', '_')
        sketch_file = f"{row['ImageNetID']}-{row['SketchID']}.png"
        sketch_path = os.path.join(self.sketches_dir, class_folder, sketch_file)
        with open(sketch_path, 'rb') as f:
            # convert() loads the pixel data while the file is still open.
            image = PIL.Image.open(f).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, labels_id_map[class_folder]

We want to invert the colors of the sketches, so the background is black and the sketch lines are white. We'll write a custom `Transform` for this.

In [9]:
class InvertTransform:
    """Invert values in [0, 1]: every value ``v`` of the input becomes ``1 - v``.

    Used on sketches after ``ToTensor()`` so that the background becomes black
    and the sketch lines become white.
    """

    def __call__(self, sample):
        inverted = 1 - sample
        return inverted

# Pre-processing images

You don't need to run the cells in this section if you already have `dataset_stats.npz` — unless you have changed `SketchesDataset`, in which case rerun them so the saved statistics match the new dataset.

## Helper methods

The three helper functions below compute a dataset's per-channel mean, its per-channel standard deviation, and a scaling factor; we use these to normalize the dataset later.

In [37]:
def dataset_mean(dataset, batch_size=100):
    """Return the per-channel mean pixel value of ``dataset`` as a tensor of shape (C,).

    Each image's per-channel mean is summed over the dataset and divided by the
    number of images. This equals the pixel-level mean only because every image
    in a batch has the same number of pixels.
    """
    loader = DataLoader(dataset, batch_size=batch_size)
    channel_totals = 0.
    for images, _ in loader:
        # The final batch may hold fewer than batch_size images.
        n_images, n_channels = images.size(0), images.size(1)
        per_image_means = images.view(n_images, n_channels, -1).mean(2)
        channel_totals += per_image_means.sum(0)
    channel_totals /= len(loader.dataset)
    return channel_totals

def dataset_std(dataset, mean, batch_size=100):
    """Computes the per-channel standard deviation of the dataset.

    Args:
        dataset: map-style dataset whose items are ``(image, label)``, with
            ``image`` a ``(C, H, W)`` tensor. All images must share one shape,
            because ``H * W`` is taken from the first item.
        mean: per-channel mean of length ``C`` (e.g. from ``dataset_mean``).
            Accepts a torch tensor, a NumPy array (such as the values loaded
            from ``dataset_stats.npz``) or a plain sequence. A single value
            is broadcast to every channel.
        batch_size (int): number of images per DataLoader batch.

    Returns:
        torch.Tensor of shape ``(C,)``: the population standard deviation
        (squared deviations divided by the total pixel count, not ``N - 1``).
    """
    c, h, w = dataset[0][0].size()
    # as_tensor accepts tensors, arrays and sequences.
    # Shape (C, 1) (or (1, 1)) broadcasts against batches of shape (N, C, H*W).
    mean = torch.as_tensor(mean).reshape(-1, 1)
    dl = DataLoader(dataset, batch_size=batch_size)
    var = 0.
    for batch, _ in dl:
        batch_samples = batch.size(0)  # batch size (the last batch can have smaller size!)
        batch = batch.view(batch_samples, c, -1)
        var += ((batch - mean) ** 2).sum([0, 2])
    std = torch.sqrt(var / (len(dl.dataset) * h * w))
    return std

def dataset_scaling(dataset, batch_size=100):
    """
    Return the largest absolute value found anywhere in ``dataset``.

    The dataset passed in is expected to be mean-subtracted already. Dividing it
    by the returned factor then places every value in [-1, 1].
    An empty dataset yields 1.
    """
    loader = DataLoader(dataset, batch_size=batch_size)
    lowest, highest = 1, -1
    for images, _ in loader:
        # The final batch may hold fewer than batch_size images.
        flat = images.view(images.size(0), images.size(1), -1)
        highest = max(highest, torch.max(flat).item())
        lowest = min(lowest, torch.min(flat).item())
    return max(abs(lowest), abs(highest))

## Photos dataset

For photos, we can simply use PyTorch's `ImageFolder`.

First, we calculate the mean and standard deviation of the photos dataset, so we can normalize it.

In [42]:
# Only ToTensor here (no normalization yet): these are the raw statistics used to build the Normalize step.
photos_dataset = ImageFolder(root=PHOTOS_DIR, transform=transforms.ToTensor())
start = timer()
photos_mean = dataset_mean(photos_dataset)
print(f"time for mean: {timer() - start}")
print(photos_mean)

time for mean: 26.669545827026013
tensor([0.4714, 0.4475, 0.3958])


In [43]:
# Per-channel standard deviation of the un-normalized photos around photos_mean.
start = timer()
photos_std = dataset_std(photos_dataset, photos_mean)
print(f"time for std: {timer() - start}")
print(photos_std)

time for std: 31.831889882974792
tensor([0.2679, 0.2565, 0.2746])


Next, we find how much to scale the dataset after mean subtraction so that its values lie in the range [-1, 1].

In [44]:
# Re-create the dataset with mean subtraction only.
# std=1 leaves the scale unchanged, so dataset_scaling receives the mean-centred values it expects.
photos_dataset = ImageFolder(root=PHOTOS_DIR, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=photos_mean, std=np.array([1, 1, 1]))
]))
start = timer()
photos_scaling = dataset_scaling(photos_dataset)
print(f"time for scaling: {timer() - start}")
print(photos_scaling)

time for scaling: 30.865477380983066
0.6042009592056274


## Sketches dataset

We have already created the cells for the sketches dataset in [Constants and helper methods](#Constants-and-helper-methods).

We'll also need to find the mean and std of the sketches dataset to normalize, just like the photos dataset.

In [46]:
# Sketches are inverted before computing statistics, matching the ToTensor + InvertTransform
# pipeline used in the Image Loading section, so the statistics describe the inverted images.
sketches_dataset = SketchesDataset(SKETCHES_DIR, INFO_DIR, transform=transforms.Compose([
    transforms.ToTensor(),
    InvertTransform()
]))
start = timer()
sketches_mean = dataset_mean(sketches_dataset)
print(f"time for mean: {timer() - start}")
print(sketches_mean)

time for mean: 122.28621749696322
tensor([0.0388, 0.0388, 0.0388])


In [47]:
# Per-channel standard deviation of the inverted sketches around sketches_mean.
start = timer()
sketches_std = dataset_std(sketches_dataset, sketches_mean)
print(f"time for std: {timer() - start}")
print(sketches_std)

time for std: 146.3518540650257
tensor([0.1892, 0.1892, 0.1892])


Next, we find how much to scale the sketches dataset after mean subtraction so that its values lie in the range [-1, 1].

In [48]:
# Subtract the sketch mean only (std=1 keeps the scale), so dataset_scaling sees mean-centred values.
sketches_dataset = SketchesDataset(SKETCHES_DIR, INFO_DIR, transform=transforms.Compose([
    transforms.ToTensor(),
    InvertTransform(),
    transforms.Normalize(mean=sketches_mean, std=np.array([1, 1, 1]))
]))
start = timer()
sketches_scaling = dataset_scaling(sketches_dataset)
print(f"time for scaling: {timer() - start}")
print(sketches_scaling)

time for scaling: 139.5827461790177
0.9611556529998779


We save the means, standard deviations, and scaling factors of both datasets to a file, so we don't have to run this code again.

In [50]:
# Persist the statistics so the Image Loading section can reload them without recomputing.
# np.savez appends the ".npz" extension, so this writes dataset_stats.npz.
np.savez("dataset_stats", photos_mean=photos_mean.numpy(), photos_std=photos_std.numpy(), photos_scaling=photos_scaling,
        sketches_mean=sketches_mean.numpy(), sketches_std=sketches_std.numpy(), sketches_scaling=sketches_scaling)

# Image Loading

Run the code below after you have either run the cells in [Pre-processing images](#Pre-processing-images) or obtained the `dataset_stats.npz` file.

In [10]:
# Load the normalization statistics saved in the Pre-processing images section.
# np.load on an .npz returns an NpzFile that keeps the archive open; the context
# manager closes it once the arrays below have been read into memory.
with np.load("dataset_stats.npz") as npzfile:
    photos_mean = npzfile['photos_mean']
    photos_std = npzfile['photos_std']
    # Scalars saved with np.savez come back as 0-d arrays; float() restores a plain number.
    photos_scaling = float(npzfile['photos_scaling'])
    sketches_mean = npzfile['sketches_mean']
    sketches_std = npzfile['sketches_std']
    sketches_scaling = float(npzfile['sketches_scaling'])
print(photos_mean, photos_std, photos_scaling)
print(sketches_mean, sketches_std, sketches_scaling)

[0.47139016 0.44750962 0.395799  ] [0.2678542 0.2564591 0.274611 ] 0.6042009592056274
[0.03884432 0.03884432 0.03884432] [0.18919694 0.18919694 0.18919694] 0.9611556529998779


In [19]:
# Normalize subtracts the per-channel photo mean and divides every channel by the same factor.
# photos_scaling is the largest absolute mean-centred value, so photo values fall in [-1, 1].
photos_dataset = ImageFolder(root=PHOTOS_DIR, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=photos_mean, std=np.array([photos_scaling, photos_scaling, photos_scaling]))
]))
photos_dl = DataLoader(photos_dataset, batch_size=50, shuffle=True, num_workers=4)

In [20]:
# Same ToTensor + InvertTransform pipeline used to compute the sketch statistics.
# Normalize then subtracts the per-channel sketch mean and divides every channel by
# sketches_scaling (the largest absolute mean-centred value), so values fall in [-1, 1].
sketches_dataset = SketchesDataset(SKETCHES_DIR, INFO_DIR, transform=transforms.Compose([
    transforms.ToTensor(),
    InvertTransform(),
    transforms.Normalize(mean=sketches_mean, std=np.array([sketches_scaling, sketches_scaling, sketches_scaling]))
]))
sketches_dl = DataLoader(sketches_dataset, batch_size=50, shuffle=True, num_workers=4)

In [1]:
from generator import Generator
from discriminator import Discriminator
from dataset import load_sketchygan_dataset

In [5]:
# NOTE(review): this cell raised `TypeError: forward() takes 2 positional arguments but 3 were given`
# (see output below). A `Sequential` repr is printed just before the error, so presumably a
# Sequential inside Discriminator is called with an extra argument. TODO: fix in discriminator.py and rerun.
ds, dl = load_sketchygan_dataset(8)  # NOTE(review): meaning of 8 not visible here (batch size?) — confirm
d = Discriminator(125, 3)  # NOTE(review): meanings of 125 and 3 not visible here — confirm against discriminator.py
image = ds[0][0]
# Crop the first 64 entries of the last two dims (top-left 64x64 if ds[0][0] is (C, H, W) — TODO confirm).
image = image[:, :64, :64]
# Add a leading batch dimension of 1; assumes 3 channels.
image = image.view(1, 3, 64, 64)
d(image)

Sequential(
  (0): MRU(
    (conv_mi): ModuleList(
      (0): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Sigmoid()
    )
    (conv_ni): ModuleList(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Sigmoid()
    )
    (conv_zi): ModuleList(
      (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (conv_xi): ModuleList(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
  )
  (1): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
)


TypeError: forward() takes 2 positional arguments but 3 were given

In [3]:
from discriminator import Discriminator

In [4]:
# Display the module structure. The printed repr suggests the second argument is the input channel count.
Discriminator(125, 6)

Discriminator(
  (encoder): Encoder(
    (image_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (layer1): ModuleList(
      (0): MRU(
        (conv_mi): ModuleList(
          (0): Conv2d(9, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): Sigmoid()
        )
        (conv_ni): ModuleList(
          (0): Conv2d(9, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): Sigmoid()
        )
        (conv_zi): ModuleList(
          (0): Conv2d(9, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01)
        )
        (conv_xi): ModuleList(
          (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01)
        )
      )
      (1): Conv2d(64, 64, kernel_s