# Dataset

In [1]:
import glob
import random
import os
import numpy as np

import torch
from torch.utils.data import Dataset
import PIL
from PIL import Image
from pathlib import Path
import torchvision.transforms as transforms

In [2]:
# Normalization parameters for pre-trained PyTorch models
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Model

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

In [4]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


In [5]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

# WGAN

In [6]:
import argparse
import os
import numpy as np
import math
import itertools
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

# from models import *
# from datasets import *
import torchvision.models as models
from torch.utils.data import Dataset
from torchvision import datasets

import torch.nn as nn
import torch.nn.functional as F
import torch

In [8]:
os.makedirs("images", exist_ok=True)

In [7]:
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=64, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")

_StoreAction(option_strings=['--sample_interval'], dest='sample_interval', nargs=None, const=None, default=400, type=<class 'int'>, choices=None, help='interval betwen image samples', metavar=None)

In [8]:
opt = parser.parse_args("")
print(opt)

Namespace(batch_size=64, channels=3, clip_value=0.01, img_size=64, latent_dim=100, lr=5e-05, n_cpu=8, n_critic=5, n_epochs=200, sample_interval=400)


In [9]:
img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False

In [10]:
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()


In [11]:
if cuda:
    generator.cuda()
    discriminator.cuda()

In [12]:
# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [15]:
data_dir = Path('/home/ec2-user/SageMaker/wikiart_post')

In [16]:
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(opt.img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(opt.img_size),
        transforms.CenterCrop(opt.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

print("Initializing Datasets and Dataloaders...")

# Create training and validation datasets
# image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'test']}
image_datasets = datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms[ 'train'])
# # Create training and validation dataloaders
# dataloader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batch_size, shuffle=True, num_workers=4) for x in ['train', 'test']}
dataloader = torch.utils.data.DataLoader(image_datasets, batch_size=opt.batch_size, shuffle=True, num_workers=4)



Initializing Datasets and Dataloaders...


In [17]:
# ----------
#  Training
# ----------

batches_done = 0
for epoch in range(opt.n_epochs):

    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)

        # Train the generator every n_critic iterations
        if i % opt.n_critic == 0:

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )

        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

[Epoch 0/200] [Batch 0/215] [D loss: -0.058177] [G loss: -0.009735]
[Epoch 0/200] [Batch 5/215] [D loss: -1.924978] [G loss: -0.014037]
[Epoch 0/200] [Batch 10/215] [D loss: -5.047487] [G loss: -0.103665]
[Epoch 0/200] [Batch 15/215] [D loss: -15.890151] [G loss: -0.474473]
[Epoch 0/200] [Batch 20/215] [D loss: -25.173355] [G loss: -1.211193]
[Epoch 0/200] [Batch 25/215] [D loss: -38.283615] [G loss: -2.284422]
[Epoch 0/200] [Batch 30/215] [D loss: -35.740757] [G loss: -3.572651]
[Epoch 0/200] [Batch 35/215] [D loss: -34.968380] [G loss: -5.168399]
[Epoch 0/200] [Batch 40/215] [D loss: -56.246059] [G loss: -7.444133]
[Epoch 0/200] [Batch 45/215] [D loss: -53.103363] [G loss: -9.719498]
[Epoch 0/200] [Batch 50/215] [D loss: -46.204273] [G loss: -12.900633]
[Epoch 0/200] [Batch 55/215] [D loss: -62.801132] [G loss: -15.771315]
[Epoch 0/200] [Batch 60/215] [D loss: -56.549217] [G loss: -18.519272]
[Epoch 0/200] [Batch 65/215] [D loss: -94.944603] [G loss: -23.149567]
[Epoch 0/200] [Batch 

KeyboardInterrupt: 