In [13]:
import os
import torch
from io import BytesIO
from torch import nn, optim
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from torchvision.transforms import functional as trans_fn
from PIL import Image
from utils import Logger

# Reset Notebook

In [22]:
gen = None
disc = None
generator = None
discriminator = None
logger = None

# Variables

In [5]:
ROOT_PATH = '/Users/alex/Projects/datasets/FLICKR/'
CATEGORIES = ['SEARCH_MOUNTAIN/']

IMAGE_SIZE = 200

# Dataset

In [6]:
def resize_and_convert(img, size, resample, quality=100):
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    img = img.convert('RGB')
    buffer = BytesIO()
    img.save(buffer, format='jpeg', quality=quality)
    val = buffer.getvalue()
    
    return img

class Category():
    def __init__(self, root_path, category):
        self.root_path = root_path
        self.category = category
        
    def __len__(self):
        path = self.root_path + self.category
        return len([name for name in os.listdir(path) if os.path.isfile(path + name)])
    
    def getdir(self):
        return self.root_path + self.category
    
    def getitem(self, idx):
        images = sorted(os.listdir(self.root_path + self.category))
        img_path = self.root_path + self.category + images[idx]
        return img_path
    

class PlacesDataset(Dataset):
    def __init__(self, root_path, categories, transform=None):
        self.root_path = root_path
        self.transform = transform
        self.categories = []
        
        for c in categories:
            self.categories.append(Category(root_path=root_path, category=c))

    def __len__(self):
        count = 0
        for c in self.categories:
            count += len(c)
            
        return count

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        mod_idx = idx
        category = None
        for c in self.categories:
            category = c
            if mod_idx > len(c) - 1:
                mod_idx -= len(c)
            else:
                break
        
        img_path = category.getitem(mod_idx)
        loc_img = Image.open(img_path)
        img = resize_and_convert(loc_img, IMAGE_SIZE, Image.LANCZOS)

        if self.transform is not None:
            img = self.transform(img)

        return img

In [7]:
compose = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])
    ])

# dataset from tutorial
def mnist_data():
    out_dir = './dataset'
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)

# Load data
data = PlacesDataset(root_path=ROOT_PATH, categories=CATEGORIES, transform=compose)

# Create loader with data, so that we can iterate over it
data_loader = DataLoader(data, batch_size=100)

# Num batches
num_batches = len(data_loader)

data.__getitem__(idx=576)

tensor([[[-0.3804, -0.3098, -0.2549,  ..., -1.0000, -1.0000, -1.0000],
         [-0.4510, -0.4118, -0.3569,  ..., -1.0000, -1.0000, -1.0000],
         [-0.5608, -0.5373, -0.4824,  ..., -1.0000, -1.0000, -1.0000],
         ...,
         [-0.4745, -0.5765, -0.8902,  ..., -0.0745, -0.1059, -0.0667],
         [-0.5686, -0.7569, -0.8510,  ..., -0.0902, -0.0980, -0.0510],
         [-0.5843, -0.8353, -0.9216,  ..., -0.0902, -0.0745, -0.0510]],

        [[ 0.0431,  0.0745,  0.0980,  ..., -0.2000, -0.2000, -0.2000],
         [ 0.0039,  0.0275,  0.0510,  ..., -0.1843, -0.1843, -0.1922],
         [-0.0353, -0.0196, -0.0039,  ..., -0.1529, -0.1529, -0.1608],
         ...,
         [-0.6000, -0.6549, -0.9137,  ..., -0.1608, -0.1608, -0.1686],
         [-0.6549, -0.7725, -0.8824,  ..., -0.1843, -0.1843, -0.1529],
         [-0.6549, -0.8588, -0.9451,  ..., -0.1922, -0.1765, -0.1373]],

        [[ 0.4353,  0.4510,  0.4667,  ...,  0.3333,  0.3333,  0.3333],
         [ 0.4196,  0.4275,  0.4353,  ...,  0

# Discriminator

In [8]:
class DiscriminatorNet(torch.nn.Module):
    """
    A three hidden-layer discriminative neural network
    """
    def __init__(self):
        super(DiscriminatorNet, self).__init__()
        n_features = IMAGE_SIZE * IMAGE_SIZE * 3
        n_out = 1
        
        self.hidden0 = nn.Sequential( 
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out = nn.Sequential(
            torch.nn.Linear(256, n_out),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
    
discriminator = DiscriminatorNet()

# Generator

In [9]:
class GeneratorNet(torch.nn.Module):
    """
    A three hidden-layer generative neural network
    """
    def __init__(self):
        super(GeneratorNet, self).__init__()
        n_features = 100
        n_out = IMAGE_SIZE * IMAGE_SIZE * 3
        
        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU  (0.2)
        )
        self.hidden1 = nn.Sequential(            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x
    
generator = GeneratorNet()

# Helpers

In [10]:
# Discriminator Helpers

def images_to_vectors(images):
    return images.view(images.size(0), IMAGE_SIZE * IMAGE_SIZE * 3)

def vectors_to_images(vectors):
    return vectors.view(vectors.size(0), 3, IMAGE_SIZE, IMAGE_SIZE)

# Generator Helpers

def noise(size):
    '''
    Generates a 1-d vector of gaussian sampled random values
    '''
    n = Variable(torch.randn(size, 100))
    return n

d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
loss = nn.BCELoss()

# Discriminator Loss
def ones_target(size):
    '''
    Tensor containing ones, with shape = size
    '''
    data = Variable(torch.ones(size, 1))
    return data

def zeros_target(size):
    '''
    Tensor containing zeros, with shape = size
    '''
    data = Variable(torch.zeros(size, 1))
    return data

def train_discriminator(optimizer, real_data, fake_data):
    N = real_data.size(0)
    # Reset gradients
    optimizer.zero_grad()
    
    # 1.1 Train on Real Data
    prediction_real = discriminator(real_data)
    # Calculate error and backpropagate
    error_real = loss(prediction_real, ones_target(N) )
    error_real.backward()

    # 1.2 Train on Fake Data
    prediction_fake = discriminator(fake_data)
    # Calculate error and backpropagate
    error_fake = loss(prediction_fake, zeros_target(N))
    error_fake.backward()
    
    # 1.3 Update weights with gradients
    optimizer.step()
    
    # Return error and predictions for real and fake inputs
    return error_real + error_fake, prediction_real, prediction_fake

def train_generator(optimizer, fake_data):
    N = fake_data.size(0)
    # Reset gradients
    optimizer.zero_grad()
    # Sample noise and generate fake data
    prediction = discriminator(fake_data)
    # Calculate error and backpropagate
    error = loss(prediction, ones_target(N))
    error.backward()
    # Update weights with gradients
    optimizer.step()
    # Return error
    return error

# Testing
num_test_samples = 16
test_noise = noise(num_test_samples)

# Train

In [19]:
# Create logger instance
logger = Logger(model_name='VGAN', data_name='FLICKR_MOUNTAIN')
# Total number of epochs to train
num_epochs = 400
for epoch in range(279, num_epochs):
    for n_batch, real_batch in enumerate(data_loader):
        N = real_batch.size(0)
        # 1. Train Discriminator
        real_data = Variable(images_to_vectors(real_batch))
        # Generate fake data and detach 
        # (so gradients are not calculated for generator)
        fake_data = generator(noise(N)).detach()
        # Train D
        d_error, d_pred_real, d_pred_fake = \
              train_discriminator(d_optimizer, real_data, fake_data)

        # 2. Train Generator
        # Generate fake data
        fake_data = generator(noise(N))
        # Train G
        g_error = train_generator(g_optimizer, fake_data)
        # Log batch error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)
        # Display Progress every few batches
        if (n_batch) % 100 == 0: 
            test_images = vectors_to_images(generator(test_noise))
            test_images = test_images.data
            logger.log_images(
                test_images, num_test_samples, 
                epoch, n_batch, num_batches
            );
            # Display status Logs
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake
            )

KeyboardInterrupt: 

# Run Model on Random Latent Vector

In [20]:
gen_noise = noise(24)

fake_data = generator(gen_noise)
test_images = vectors_to_images(fake_data)
test_images = test_images.data
logger.log_images(
                test_images, 1, 
                1, 1, 1
            );

TypeError: 'NoneType' object is not callable

# Save Model

In [242]:
# torch.save(generator.state_dict(), "./models/flickr-mountain-generator-200.pt")
torch.save(discriminator.state_dict(), "./models/flickr-mountain-discriminator-200.pt")

# Load Saved Model

In [21]:
gen = GeneratorNet()
gen.load_state_dict(torch.load("./models/flickr-mountain-generator-200.pt"))
gen.eval()

disc = DiscriminatorNet()
disc.load_state_dict(torch.load("./models/flickr-mountain-discriminator-200.pt"))
disc.eval()

# Replace current generator and discriminator
# generator = gen
# discriminator = disc