# Dataset

In [128]:
import glob
import random
import os
import numpy as np

import torch
from torch.utils.data import Dataset
import PIL
from PIL import Image
from pathlib import Path
import torchvision.transforms as transforms

In [129]:
# Normalization parameters for pre-trained PyTorch models
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Model

In [130]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

In [131]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


In [132]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

# WGAN

In [133]:
import argparse
import os
import numpy as np
import math
import itertools
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

# from models import *
# from datasets import *
import torchvision.models as models
from torch.utils.data import Dataset
from torchvision import datasets

import torch.nn as nn
import torch.nn.functional as F
import torch

In [134]:
os.makedirs("images", exist_ok=True)

In [145]:
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=2000, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=1024, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate") # was 0.00005
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=64, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=10, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=100, help="interval betwen image samples")

_StoreAction(option_strings=['--sample_interval'], dest='sample_interval', nargs=None, const=None, default=100, type=<class 'int'>, choices=None, help='interval betwen image samples', metavar=None)

In [146]:
opt = parser.parse_args("")
print(opt)

Namespace(batch_size=1024, channels=3, clip_value=0.01, img_size=64, latent_dim=100, lr=0.0002, n_cpu=8, n_critic=10, n_epochs=2000, sample_interval=100)


In [147]:
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

In [148]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()


In [149]:
if cuda:
    generator.cuda()
    discriminator.cuda()

In [150]:
# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [151]:
data_dir = Path('/home/ec2-user/SageMaker/genre-224')

In [152]:
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(opt.img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
# image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'test']}
image_datasets = datasets.ImageFolder(os.path.join(data_dir), data_transforms)
# # Create training and validation dataloaders
# dataloader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batch_size, shuffle=True, num_workers=4) for x in ['train', 'test']}
dataloader = torch.utils.data.DataLoader(image_datasets, batch_size=opt.batch_size, shuffle=True, num_workers=4)



Initializing Datasets and Dataloaders...


In [153]:
image_datasets

Dataset ImageFolder
    Number of datapoints: 14981
    Root location: /home/ec2-user/SageMaker/genre-224
    StandardTransform
Transform: Compose(
               RandomResizedCrop(size=(64, 64), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=PIL.Image.BILINEAR)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [154]:
# path = Path('/home/ec2-user/SageMaker/portrait_landscape/1')
corrupted = []
for filename in os.listdir(path):
    if filename.endswith('.jpg'):
        try:
            img = Image.open(os.path.join(path,filename)) # open the image file
            img.verify() # verify that it is, in fact an image
        except (IOError, SyntaxError) as e:
#             pass
            print('Bad file:', filename) # print out the names of corrupt files
            corrupted.append(filename)

In [155]:
for images in corrupted:
    os.remove(os.path.join(path,images))

In [None]:
# ----------
#  Training
# ----------

batches_done = 0
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        def __getitem__(self, idx):
            try:
                img, label = load_img(idx)
            except:
                return None
            return [img, label]
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)

        # Train the generator every n_critic iterations
        if i % opt.n_critic == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )

        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

[Epoch 0/2000] [Batch 0/15] [D loss: 0.013199] [G loss: 0.011938]
[Epoch 0/2000] [Batch 10/15] [D loss: -140.359360] [G loss: -9.024948]
[Epoch 1/2000] [Batch 0/15] [D loss: -175.444611] [G loss: -36.094719]
[Epoch 1/2000] [Batch 10/15] [D loss: -186.070480] [G loss: -88.891174]
[Epoch 2/2000] [Batch 0/15] [D loss: -118.922623] [G loss: -155.759979]
[Epoch 2/2000] [Batch 10/15] [D loss: -45.034836] [G loss: -207.456299]
[Epoch 3/2000] [Batch 0/15] [D loss: 16.427322] [G loss: -191.760590]
[Epoch 3/2000] [Batch 10/15] [D loss: -32.049347] [G loss: -72.178139]
[Epoch 4/2000] [Batch 0/15] [D loss: -48.712421] [G loss: -28.688244]
[Epoch 4/2000] [Batch 10/15] [D loss: -89.197983] [G loss: 28.255409]
[Epoch 5/2000] [Batch 0/15] [D loss: -93.053131] [G loss: 48.222977]
[Epoch 5/2000] [Batch 10/15] [D loss: -114.225204] [G loss: 65.388237]
[Epoch 6/2000] [Batch 0/15] [D loss: -112.659637] [G loss: 63.903999]
[Epoch 6/2000] [Batch 10/15] [D loss: -115.170555] [G loss: 61.986237]
[Epoch 7/2000]