In [24]:
""" Creates batches of images to feed into the training network conditioned by genre, uses upsampling when creating batches to account for uneven distributions """



' Creates batches of images to feed into the training network conditioned by genre, uses upsampling when creating batches to account for uneven distributions '

In [25]:
# pip install tensorflow-gpu==1.2.1

In [26]:
import numpy as np
import imageio
import time
import random
import os
from pathlib import Path
from PIL import Image

import sys

In [27]:
# Set the dimension of images you want to be passed in to the network
DIM = 128

In [28]:
# Set your own path to images
# src_img_path  = Path('/home/ec2-user/SageMaker/genre-128')
src_img_path  = Path('/home/ec2-user/SageMaker/portrait_landscape')

In [29]:
os.listdir(src_img_path)

['1', '.ipynb_checkpoints', '0']

In [30]:
len(os.listdir(os.path.join(src_img_path,'1')))

14971

In [31]:
# This dictionary should be updated to hold the absolute number of images associated with each genre used during training
styles = {
    "portraits": 14981,
    "landscapes": 14971
}

styleNum = {
    "portraits": 0,
    "landscapes": 1
}
    
curPos = {
    "portraits": 0,
    "landscapes": 0
}

In [32]:
testNums = {}
trainNums = {}

In [33]:
# Generate test set of images made up of 1/20 of the images (per genre)
for k, v in styles.items():
    # put a twentieth of paintings in here
    nums = range(v)
    random.shuffle(list(nums))
    testNums[k] = nums[0 : v // 20]
    trainNums[k] = nums[v // 20 :]

In [34]:
def inf_gen(gen):
    while True:
        for (images, labels) in gen():
            yield images, labels

In [35]:
def make_generator(files, batch_size, n_classes): # add genre parameter
    if batch_size % n_classes != 0:
        raise ValueError(
            "Batch size {} must be divisible by num classes {}".format(batch_size, n_classes)
        )

    class_batch = batch_size // n_classes

    generators = []

    def get_epoch():

        while True:

            images = np.zeros((batch_size, 3, DIM, DIM), dtype="int32")
            labels = np.zeros((batch_size, n_classes))
            n = 0
            for style in styles:
#             for style in genre:
                styleLabel = styleNum[style]
                curr = curPos[style]
                for _ in range(class_batch):
                    if curr == styles[style]:
                        curr = 0
                        random.shuffle(list(files[style]))
                    img_path = str(Path(src_img_path, str(styleLabel), str(curr) + ".png"))
                    image = Image.open(img_path).convert(mode="RGB")
                    image = np.asarray(image)

                    images[n % batch_size] = image.transpose(2, 0, 1)
                    labels[n % batch_size, int(styleLabel)] = 1
                    n += 1
                    curr += 1
                curPos[style] = curr

            # randomize things but keep relationship between a conditioning vector and its associated image
            rng_state = np.random.get_state()
            np.random.shuffle(images)
            np.random.set_state(rng_state)
            np.random.shuffle(labels)
            yield (images, labels)

    return get_epoch

In [36]:
def load(batch_size):
    return (
        make_generator(trainNums, batch_size, len(styles)),
        make_generator(testNums, batch_size, len(styles)),
    )

In [37]:
import os, sys
from pathlib import Path

sys.path.append(os.getcwd())

from random import randint

import time
import functools
import math

import numpy as np
import tensorflow as tf

import tflib as lib
import tflib.ops.linear
import tflib.ops.conv2d
import tflib.ops.batchnorm
import tflib.ops.deconv2d
import tflib.save_images
# import tflib.wikiart_genre
import tflib.ops.layernorm
import tflib.plot

In [38]:
# pip install --upgrade tensorflow-gpu==1.2.1

In [39]:
# pip install --upgrade tensorflow-gpu

In [40]:
print(tf.test.gpu_device_name())
print(tf.__version__)


1.15.5


In [41]:
tf.test.is_gpu_available()

False

In [42]:
MODE = "acwgan"  # dcgan, wgan, wgan-gp, lsgan
genre = ['portraits','landscapes']
DIM = 128  # Model dimensionality
CRITIC_ITERS = 5  # How many iterations to train the critic for, increase it to 50 later
N_GPUS = 1  # Number of GPUs
BATCH_SIZE = 84  # Batch size. Must be a multiple of CLASSES and N_GPUS
ITERS = 200000  # How many iterations to train for
LAMBDA = 10  # Gradient penalty lambda hyperparameter
OUTPUT_DIM = DIM * DIM * 3  # Number of pixels in each image
CLASSES = len(genre)  # Number of classes, for genres probably 14
PREITERATIONS = 2000  # Number of preiteration training cycles to run
lib.print_model_settings(locals().copy())

Uppercase local vars:
	BATCH_SIZE: 84
	CLASSES: 2
	CRITIC_ITERS: 5
	DATASET: /home/ec2-user/SageMaker/portrait_landscape
	DECAY: True
	DIM: 128
	DIM_D_16: 512
	DIM_D_32: 256
	DIM_D_4: 1024
	DIM_D_64: 128
	DIM_D_8: 1024
	DIM_G_16: 256
	DIM_G_32: 128
	DIM_G_4: 512
	DIM_G_64: 64
	DIM_G_8: 512
	GEN_BS_MULTIPLE: 1
	ITERS: 200000
	LAMBDA: 10
	LR: 0.0001
	MODE: acwgan
	MOMENTUM_D: 0.0
	MOMENTUM_G: 0.0
	NORMALIZATION_D: True
	NORMALIZATION_G: True
	N_GPUS: 1
	OUTPUT_DIM: 49152
	PREITERATIONS: 2000


In [43]:
BATCH_SIZE = 64
DATASET = '/home/ec2-user/SageMaker/portrait_landscape'

DIM_G_64  = 64
DIM_G_32  = 128
DIM_G_16  = 256
DIM_G_8   = 512
DIM_G_4   = 512

DIM_D_64  = 128
DIM_D_32  = 256
DIM_D_16  = 512
DIM_D_8   = 1024
DIM_D_4   = 1024

NORMALIZATION_G = True
NORMALIZATION_D = True

ITERS = 200000
LR = 1e-4
DECAY = True
CRITIC_ITERS = 5
MOMENTUM_G = 0.
MOMENTUM_D = 0.
GEN_BS_MULTIPLE = 1

In [44]:
# Ensure that directory exists where ground truth and plots will be saved to.
Path('generated').mkdir(parents=True, exist_ok=True)
Path('models').mkdir(parents=True, exist_ok=True)

In [45]:
def GeneratorAndDiscriminator():
    return kACGANGenerator, kACGANDiscriminator

In [46]:
DEVICES = ["/gpu:{}".format(i) for i in range(N_GPUS)]

In [47]:
def LeakyReLU(x, alpha=0.2):
    return tf.maximum(alpha * x, x)


def ReLULayer(name, n_in, n_out, inputs):
    output = lib.ops.linear.Linear(
        name + ".Linear", n_in, n_out, inputs, initialization="he"
    )
    return tf.nn.relu(output)


def LeakyReLULayer(name, n_in, n_out, inputs):
    output = lib.ops.linear.Linear(
        name + ".Linear", n_in, n_out, inputs, initialization="he"
    )
    return LeakyReLU(output)

In [48]:
def Batchnorm(name, axes, inputs):

    if ("Discriminator" in name) and (MODE == "wgan-gp" or MODE == "acwgan"):
        if axes != [0, 2, 3]:
            raise Exception("Layernorm over non-standard axes is unsupported")
        return lib.ops.layernorm.Layernorm(name, [1, 2, 3], inputs)
    else:
        return lib.ops.batchnorm.Batchnorm(name, axes, inputs, fused=True)

In [49]:
def pixcnn_gated_nonlinearity(name, output_dim, a, b, c=None, d=None):
    if c is not None and d is not None:
        a = a + c
        b = b + d

    result = tf.sigmoid(a) * tf.tanh(b)
    return result

In [50]:
def SubpixelConv2D(*args, **kwargs):
    kwargs["output_dim"] = 4 * kwargs["output_dim"]
    output = lib.ops.conv2d.Conv2D(*args, **kwargs)
    output = tf.transpose(output, [0, 2, 3, 1])
    output = tf.depth_to_space(output, 2)
    output = tf.transpose(output, [0, 3, 1, 2])
    return output

In [51]:
def ResidualBlock(
    name, input_dim, output_dim, filter_size, inputs, resample=None, he_init=True
):
    """
    resample: None, 'down', or 'up'
    """
    if resample == "down":
        conv_shortcut = functools.partial(lib.ops.conv2d.Conv2D, stride=2)
        conv_1 = functools.partial(
            lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim // 2
        )
        conv_1b = functools.partial(
            lib.ops.conv2d.Conv2D,
            input_dim=input_dim // 2,
            output_dim=output_dim // 2,
            stride=2,
        )
        conv_2 = functools.partial(
            lib.ops.conv2d.Conv2D, input_dim=output_dim // 2, output_dim=output_dim
        )
    elif resample == "up":
        conv_shortcut = SubpixelConv2D
        conv_1 = functools.partial(
            lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim // 2
        )
        conv_1b = functools.partial(
            lib.ops.deconv2d.Deconv2D,
            input_dim=input_dim // 2,
            output_dim=output_dim // 2,
        )
        conv_2 = functools.partial(
            lib.ops.conv2d.Conv2D, input_dim=output_dim // 2, output_dim=output_dim
        )
    elif resample == None:
        conv_shortcut = lib.ops.conv2d.Conv2D
        conv_1 = functools.partial(
            lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim // 2
        )
        conv_1b = functools.partial(
            lib.ops.conv2d.Conv2D, input_dim=input_dim // 2, output_dim=output_dim // 2
        )
        conv_2 = functools.partial(
            lib.ops.conv2d.Conv2D, input_dim=input_dim // 2, output_dim=output_dim
        )

    else:
        raise Exception("invalid resample value")

    if output_dim == input_dim and resample == None:
        shortcut = inputs  # Identity skip-connection
    else:
        shortcut = conv_shortcut(
            name + ".Shortcut",
            input_dim=input_dim,
            output_dim=output_dim,
            filter_size=1,
            he_init=False,
            biases=True,
            inputs=inputs,
        )

    output = inputs
    output = tf.nn.relu(output)
    output = conv_1(
        name + ".Conv1", filter_size=1, inputs=output, he_init=he_init, weightnorm=False
    )
    output = tf.nn.relu(output)
    output = conv_1b(
        name + ".Conv1B",
        filter_size=filter_size,
        inputs=output,
        he_init=he_init,
        weightnorm=False,
    )
    output = tf.nn.relu(output)
    output = conv_2(
        name + ".Conv2",
        filter_size=1,
        inputs=output,
        he_init=he_init,
        weightnorm=False,
        biases=False,
    )
    output = Batchnorm(name + ".BN", [0, 2, 3], output)

    return shortcut + (0.3 * output)

In [52]:
def kACGANGenerator(
    n_samples,
    numClasses,
    labels,
    noise=None,
    dim=DIM,
    bn=True,
    nonlinearity=tf.nn.relu,
    condition=None,
):
    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)
    if noise is None:
        noise = tf.random_normal([n_samples, 128])

    labels = tf.cast(labels, tf.float32)
    noise = tf.concat([noise, labels], 1)
#######################
    output = lib.ops.linear.Linear(
        "Generator.Input", 128 + numClasses, 4*4*DIM_G_4, noise
    )  # probs need to recalculate dimensions
    print('output 1: ', output)
    output = tf.reshape(output, [-1, DIM_G_4, 4, 4])
#     output = tf.reshape(output, [-1, dim , 4, 4])
    print('output 1_reshape: ', output, output.shape)
    if bn:
        output = Batchnorm("Generator.BN1", [0, 2, 3], output)
        print('output 1_bn: ', output)
    condition = lib.ops.linear.Linear(
        "Generator.cond1", numClasses, 8 * 4 * 4 * dim * 2, labels, biases=False
    )
    condition = tf.reshape(condition, [-1, 8 * dim * 2, 4, 4])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl1",
        8 * dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
    print('output 1_final: ', output)
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.2", DIM_G_4, DIM_G_8, 3, output)
    print('output 2: ', output)
    if bn:
        output = Batchnorm("Generator.BN2", [0, 2, 3], output)
        print('output 2_bn: ', output)
    condition = lib.ops.linear.Linear(
        "Generator.cond2", numClasses, 4 * 8 * 8 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, 4 * dim * 2, 8, 8])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl2",
        4 * dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
    print('output 2_final: ', output)
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.3", DIM_G_8, DIM_G_16, 3, output)
    if bn:
        output = Batchnorm("Generator.BN3", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond3", numClasses, 2 * 16 * 16 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, 2 * dim * 2, 16, 16])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl3",
        2 * dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.4", DIM_G_16, DIM_G_32, 3, output)
    if bn:
        output = Batchnorm("Generator.BN4", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond4", numClasses, 32 * 32 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, dim * 2, 32, 32])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl4",
        dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.5",  DIM_G_32, DIM_G_64, 3, output)
    if bn:
        output = Batchnorm("Generator.BN5", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond5", numClasses, 32 * 64 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, dim , 64, 64])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl5",
        dim//2,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.6", DIM_G_64, 3, 5, output)

    output = tf.tanh(output)

    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()

    return tf.reshape(output, [-1, OUTPUT_DIM]), labels

In [61]:
def kACGANGenerator(
    n_samples,
    numClasses,
    labels,
    noise=None,
    dim=DIM,
    bn=True,
    nonlinearity=tf.nn.relu,
    condition=None,
):
    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)
    if noise is None:
        noise = tf.random_normal([n_samples, 128])

    labels = tf.cast(labels, tf.float32)
    noise = tf.concat([noise, labels], 1)

    output = lib.ops.linear.Linear(
        "Generator.Input", 128 + numClasses, 8 * 4 * 4 * dim * 2, noise
    )  # probs need to recalculate dimensions
    print('output 1: ', output)
    output = tf.reshape(output, [-1, 8 * dim * 2, 4, 4])
#     output = tf.reshape(output, [-1, dim , 4, 4])
    print('output 2: ', output, output.shape)
    if bn:
        output = Batchnorm("Generator.BN1", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond1", numClasses, 8 * 4 * 4 * dim * 2, labels, biases=False
    )
    condition = tf.reshape(condition, [-1, 8 * dim * 2, 4, 4])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl1",
        8 * dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.2", 8 * dim, 4 * dim * 2, 5, output)
    if bn:
        output = Batchnorm("Generator.BN2", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond2", numClasses, 4 * 8 * 8 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, 4 * dim * 2, 8, 8])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl2",
        4 * dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.3", 4 * dim, 2 * dim * 2, 5, output)
    if bn:
        output = Batchnorm("Generator.BN3", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond3", numClasses, 2 * 16 * 16 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, 2 * dim * 2, 16, 16])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl3",
        2 * dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.4", 2 * dim, dim * 2, 5, output)
    if bn:
        output = Batchnorm("Generator.BN4", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond4", numClasses, 32 * 32 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, dim * 2, 32, 32])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl4",
        dim,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.4", dim, dim , 5, output)
    if bn:
        output = Batchnorm("Generator.BN4", [0, 2, 3], output)
    condition = lib.ops.linear.Linear(
        "Generator.cond4", numClasses, 32 * 64 * dim * 2, labels
    )
    condition = tf.reshape(condition, [-1, dim , 64, 64])
    output = pixcnn_gated_nonlinearity(
        "Generator.nl4",
        dim//2,
        output[:, ::2],
        output[:, 1::2],
        condition[:, ::2],
        condition[:, 1::2],
    )
#######################
    output = lib.ops.deconv2d.Deconv2D("Generator.5", dim, 3, 5, output)

    output = tf.tanh(output)

    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()

    return tf.reshape(output, [-1, OUTPUT_DIM]), labels

In [60]:
def kACGANDiscriminator(inputs, numClasses, dim=DIM, bn=True, nonlinearity=LeakyReLU):
    output = tf.reshape(inputs, [-1, 3, dim, dim])

    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)

    output = lib.ops.conv2d.Conv2D("Discriminator.1", 3, DIM_D_64, 5, output, stride=2)
    output = nonlinearity(output)
#######################
    output = lib.ops.conv2d.Conv2D("Discriminator.2", DIM_D_64, DIM_D_32, 3, output, stride=2)
    if bn:
        output = Batchnorm("Discriminator.BN2", [0, 2, 3], output)
    output = nonlinearity(output)
#######################
    output = lib.ops.conv2d.Conv2D(
        "Discriminator.3", DIM_D_32, DIM_D_16, 3, stride=2
    )
    if bn:
        output = Batchnorm("Discriminator.BN3", [0, 2, 3], output)
    output = nonlinearity(output)
#######################
    output = lib.ops.conv2d.Conv2D(
        "Discriminator.4", DIM_D_16, DIM_D_8, 3, output, stride=2
    )
    if bn:
        output = Batchnorm("Discriminator.BN4", [0, 2, 3], output)
    output = nonlinearity(output)
    finalLayer = tf.reshape(output, [-1, 4 * 4 * 8 * dim])
#######################
    sourceOutput = lib.ops.linear.Linear(
        "Discriminator.sourceOutput", 4 * 4 * 8 * dim, 1, finalLayer
    )
#######################
    classOutput = lib.ops.linear.Linear(
        "Discriminator.classOutput", 4 * 4 * 8 * dim, numClasses, finalLayer
    )
#######################
    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()

    return (tf.reshape(sourceOutput, [-1]), tf.reshape(classOutput, [-1, numClasses]))

In [62]:
def kACGANDiscriminator(inputs, numClasses, dim=DIM, bn=True, nonlinearity=LeakyReLU):
    output = tf.reshape(inputs, [-1, 3, dim, dim])

    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)

    output = lib.ops.conv2d.Conv2D("Discriminator.1", 3, dim, 5, output, stride=2)
    output = nonlinearity(output)

    output = lib.ops.conv2d.Conv2D("Discriminator.2", dim, 2 * dim, 5, output, stride=2)
    if bn:
        output = Batchnorm("Discriminator.BN2", [0, 2, 3], output)
    output = nonlinearity(output)

    output = lib.ops.conv2d.Conv2D(
        "Discriminator.3", 2 * dim, 4 * dim, 5, output, stride=2
    )
    if bn:
        output = Batchnorm("Discriminator.BN3", [0, 2, 3], output)
    output = nonlinearity(output)

    output = lib.ops.conv2d.Conv2D(
        "Discriminator.4", 4 * dim, 8 * dim, 5, output, stride=2
    )
    if bn:
        output = Batchnorm("Discriminator.BN4", [0, 2, 3], output)
    output = nonlinearity(output)
    finalLayer = tf.reshape(output, [-1, 4 * 4 * 8 * dim])

    sourceOutput = lib.ops.linear.Linear(
        "Discriminator.sourceOutput", 4 * 4 * 8 * dim, 1, finalLayer
    )

    classOutput = lib.ops.linear.Linear(
        "Discriminator.classOutput", 4 * 4 * 8 * dim, numClasses, finalLayer
    )

    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()

    return (tf.reshape(sourceOutput, [-1]), tf.reshape(classOutput, [-1, numClasses]))

In [63]:
def genRandomLabels(n_samples, numClasses, condition=None):
    labels = np.zeros([BATCH_SIZE, CLASSES], dtype=np.float32)
    for i in range(n_samples):
        if condition is not None:
            labelNum = condition
        else:
            labelNum = randint(0, numClasses - 1)
        labels[i, labelNum] = 1
    return labels

In [64]:
Generator, Discriminator = GeneratorAndDiscriminator()

In [65]:
DEVICES

['/gpu:0']

In [66]:
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session:
    print('started ...')
    best_accuracy = 0
    
    all_real_data_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, 3, DIM, DIM])
    all_real_label_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, CLASSES])

    generated_labels_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, CLASSES])
    sample_labels_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, CLASSES])

    split_real_data_conv = tf.split(all_real_data_conv, len(DEVICES))
    split_real_label_conv = tf.split(all_real_label_conv, len(DEVICES))
    split_generated_labels_conv = tf.split(generated_labels_conv, len(DEVICES))
    split_sample_labels_conv = tf.split(sample_labels_conv, len(DEVICES))

    gen_costs, disc_costs = [], []

    for device_index, (device, real_data_conv, real_label_conv) in enumerate(
        zip(DEVICES, split_real_data_conv, split_real_label_conv)
    ):
        with tf.device(device):

            real_data = tf.reshape(
                2 * ((tf.cast(real_data_conv, tf.float32) / 255.0) - 0.5),
                [BATCH_SIZE // len(DEVICES), OUTPUT_DIM],
            )
            print('real_data: ', real_data)
            real_labels = tf.reshape(
                real_label_conv, [BATCH_SIZE // len(DEVICES), CLASSES]
            )
            print('real_labels: ', real_labels)
            generated_labels = tf.reshape(
                split_generated_labels_conv, [BATCH_SIZE // len(DEVICES), CLASSES]
            )
            print('generated_labels: ', generated_labels)
            sample_labels = tf.reshape(
                split_sample_labels_conv, [BATCH_SIZE // len(DEVICES), CLASSES]
            )
            print('sample_labels: ', sample_labels)
            fake_data, fake_labels = Generator(
                BATCH_SIZE // len(DEVICES), CLASSES, generated_labels
            )
#             print(real_data,real_labels,generated_labels,sample_labels,fake_data, fake_labels)
            print('fake_data: ', fake_data)
            print('fake_labels: ', fake_labels)
            # set up discrimnator results

            disc_fake, disc_fake_class = Discriminator(fake_data, CLASSES)
            print('disc fake: ', disc_fake, disc_fake_class)
            disc_real, disc_real_class = Discriminator(real_data, CLASSES)
            print('disc real: ', disc_real, disc_real_class)
            prediction = tf.argmax(disc_fake_class, 1)
            print('prediction 1: ', prediction, disc_fake_class)
            correct_answer = tf.argmax(fake_labels, 1)
            print('correct 1: ', correct_answer, fake_labels)
            equality = tf.equal(prediction, correct_answer)
            print('equality 1: ', equality)
            genAccuracy = tf.reduce_mean(tf.cast(equality, tf.float32))
            print('accuracy 1: ', genAccuracy)
            
            prediction = tf.argmax(disc_real_class, 1)
            print('prediction 2: ', prediction)
            print('disc real class: ', disc_real_class)
            correct_answer = tf.argmax(real_labels, 1)
            print('correct 2: ', correct_answer, real_labels)
#             equality = tf.equal(correct_answer, prediction)
            equality = tf.equal(prediction, correct_answer)
            print('equality 2: ', equality)
            realAccuracy = tf.reduce_mean(tf.cast(equality, tf.float32))

            gen_cost = -tf.reduce_mean(disc_fake)
            disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

            gen_cost_test = -tf.reduce_mean(disc_fake)
            disc_cost_test = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

            generated_class_cost = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=disc_fake_class, labels=fake_labels
                )
            )

            real_class_cost = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=disc_real_class, labels=real_labels
                )
            )
            gen_cost += generated_class_cost
            disc_cost += real_class_cost

            alpha = tf.random_uniform(
                shape=[BATCH_SIZE // len(DEVICES), 1], minval=0.0, maxval=1.0
            )
            differences = fake_data - real_data
            interpolates = real_data + (alpha * differences)
            gradients = tf.gradients(
                Discriminator(interpolates, CLASSES)[0], [interpolates]
            )[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
            gradient_penalty = tf.reduce_mean((slopes - 1.0) ** 2)
            disc_cost += LAMBDA * gradient_penalty

            real_class_cost_gradient = real_class_cost * 50 + LAMBDA * gradient_penalty

            gen_costs.append(gen_cost)
            disc_costs.append(disc_cost)

    gen_cost = tf.add_n(gen_costs) / len(DEVICES)
    disc_cost = tf.add_n(disc_costs) / len(DEVICES)

    gen_train_op = tf.train.AdamOptimizer(
        learning_rate=1e-4, beta1=0.5, beta2=0.9
    ).minimize(
        gen_cost,
        var_list=lib.params_with_name("Generator"),
        colocate_gradients_with_ops=True,
    )
    disc_train_op = tf.train.AdamOptimizer(
        learning_rate=1e-4, beta1=0.5, beta2=0.9
    ).minimize(
        disc_cost,
        var_list=lib.params_with_name("Discriminator."),
        colocate_gradients_with_ops=True,
    )
    class_train_op = tf.train.AdamOptimizer(
        learning_rate=1e-4, beta1=0.5, beta2=0.9
    ).minimize(
        real_class_cost_gradient,
        var_list=lib.params_with_name("Discriminator."),
        colocate_gradients_with_ops=True,
    )
    # For generating samples

    fixed_noise = tf.constant(
        np.random.normal(size=(BATCH_SIZE, 128)).astype("float32")
    )
    all_fixed_noise_samples = []
    for device_index, device in enumerate(DEVICES):
        n_samples = BATCH_SIZE // len(DEVICES)
        all_fixed_noise_samples.append(
            Generator(
                n_samples,
                CLASSES,
                sample_labels,
                noise=fixed_noise[
                    device_index * n_samples : (device_index + 1) * n_samples
                ],
            )[0]
        )
        if tf.__version__.startswith("1."):
            all_fixed_noise_samples = tf.concat(all_fixed_noise_samples, axis=0)
        else:
            all_fixed_noise_samples = tf.concat(0, all_fixed_noise_samples)

    def generate_image(iteration):
        # this might be where we add the conditionality
        for i in range(CLASSES):
            curLabel = genRandomLabels(BATCH_SIZE, CLASSES, condition=i)
            samples = session.run(
                all_fixed_noise_samples, feed_dict={sample_labels: curLabel}
            )
            samples = ((samples + 1.0) * (255.99 / 2)).astype("int32")
            lib.save_images.save_images(
                samples.reshape((BATCH_SIZE, 3, DIM, DIM)),
                "generated/samples_{}_{}.png".format(str(i), iteration),
            )

    # Dataset iterator
#     train_gen, dev_gen = lib.wikiart_genre.load(BATCH_SIZE)
    train_gen, dev_gen = load(BATCH_SIZE)

    def softmax_cross_entropy(logit, y):
        return -tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y)
        )

    def inf_train_gen():
        while True:
            for (images, labels) in train_gen():
                yield images, labels

    _sample_labels = genRandomLabels(BATCH_SIZE, CLASSES)
    # Save a batch of ground-truth samples
    _x, _y = next(train_gen())
    _x_r = session.run(real_data, feed_dict={all_real_data_conv: _x})
    _x_r = ((_x_r + 1.0) * (255.99 / 2)).astype("int32")
    lib.save_images.save_images(
        _x_r.reshape((BATCH_SIZE, 3, DIM, DIM)), "generated/samples_groundtruth.png"
    )

    session.run(
        tf.initialize_all_variables(),
        feed_dict={generated_labels_conv: genRandomLabels(BATCH_SIZE, CLASSES)},
    )
    gen = train_gen()

    for iterp in range(PREITERATIONS):
        _data, _labels = next(gen)
        _, accuracy = session.run(
            [disc_train_op, realAccuracy],
            feed_dict={
                all_real_data_conv: _data,
                all_real_label_conv: _labels,
                generated_labels_conv: genRandomLabels(BATCH_SIZE, CLASSES),
            },
        )
        if iterp % 100 == 99:
            print("pretraining accuracy: " + str(accuracy))

    for iteration in range(ITERS):
        start_time = time.time()
        # Train generator
        if iteration > 0:
            _ = session.run(
                gen_train_op,
                feed_dict={generated_labels_conv: genRandomLabels(BATCH_SIZE, CLASSES)},
            )
        # Train critic
        disc_iters = CRITIC_ITERS
        for i in range(disc_iters):
            _data, _labels = next(gen)
            (
                _disc_cost,
                _disc_cost_test,
                class_cost_test,
                gen_class_cost,
                _gen_cost_test,
                _genAccuracy,
                _realAccuracy,
                _,
            ) = session.run(
                [
                    disc_cost,
                    disc_cost_test,
                    real_class_cost,
                    generated_class_cost,
                    gen_cost_test,
                    genAccuracy,
                    realAccuracy,
                    disc_train_op,
                ],
                feed_dict={
                    all_real_data_conv: _data,
                    all_real_label_conv: _labels,
                    generated_labels_conv: genRandomLabels(BATCH_SIZE, CLASSES),
                },
            )

        lib.plot.plot("train disc cost", _disc_cost)
        lib.plot.plot("time", time.time() - start_time)
        lib.plot.plot("wgan train disc cost", _disc_cost_test)
        lib.plot.plot("train class cost", class_cost_test)
        lib.plot.plot("generated class cost", gen_class_cost)
        lib.plot.plot("gen cost cost", _gen_cost_test)
        lib.plot.plot("gen accuracy", _genAccuracy)
        lib.plot.plot("real accuracy", _realAccuracy)

        if (iteration % 100 == 99 and iteration < 1000) or iteration % 1000 == 999:
            t = time.time()
            dev_disc_costs = []
            images, labels = next(dev_gen())
            (
                _dev_disc_cost,
                _dev_disc_cost_test,
                _class_cost_test,
                _gen_class_cost,
                _dev_gen_cost_test,
                _dev_genAccuracy,
                _dev_realAccuracy,
            ) = session.run(
                [
                    disc_cost,
                    disc_cost_test,
                    real_class_cost,
                    generated_class_cost,
                    gen_cost_test,
                    genAccuracy,
                    realAccuracy,
                ],
                feed_dict={
                    all_real_data_conv: images,
                    all_real_label_conv: labels,
                    generated_labels_conv: genRandomLabels(BATCH_SIZE, CLASSES),
                },
            )
            dev_disc_costs.append(_dev_disc_cost)
            lib.plot.plot("dev disc cost", np.mean(dev_disc_costs))
            lib.plot.plot("wgan dev disc cost", _dev_disc_cost_test)
            lib.plot.plot("dev class cost", _class_cost_test)
            lib.plot.plot("dev generated class cost", _gen_class_cost)
            lib.plot.plot("dev gen  cost", _dev_gen_cost_test)
            lib.plot.plot("dev gen accuracy", _dev_genAccuracy)
            lib.plot.plot("dev real accuracy", _dev_realAccuracy)

        if iteration % 100 == 999:
            generate_image(iteration)
            # Can add generate_good_images method in here if desired

        if (iteration < 10) or (iteration % 100 == 99):
            lib.plot.flush()

        lib.plot.tick()

started ...
real_data:  Tensor("Reshape_6:0", shape=(64, 49152), dtype=float32, device=/device:GPU:0)
real_labels:  Tensor("Reshape_7:0", shape=(64, 2), dtype=int32, device=/device:GPU:0)
generated_labels:  Tensor("Reshape_8:0", shape=(64, 2), dtype=int32, device=/device:GPU:0)
sample_labels:  Tensor("Reshape_9:0", shape=(64, 2), dtype=int32, device=/device:GPU:0)
output 1:  Tensor("Generator.Input_1/BiasAdd:0", shape=(64, 8192), dtype=float32, device=/device:GPU:0)
output 2:  Tensor("Reshape_10:0", shape=(16, 2048, 4, 4), dtype=float32, device=/device:GPU:0) (16, 2048, 4, 4)


ValueError: Dimensions must be equal, but are 2048 and 512 for 'FusedBatchNormV3_1' (op: 'FusedBatchNormV3') with input shapes: [16,2048,4,4], [512], [512], [0], [0].