In [34]:
""" Creates batches of images to feed into the training network conditioned by genre, uses upsampling when creating batches to account for uneven distributions """



' Creates batches of images to feed into the training network conditioned by genre, uses upsampling when creating batches to account for uneven distributions '

In [35]:
# pip install tensorflow-gpu==1.2.1

In [36]:
import numpy as np
import imageio
import time
import random
import os
from pathlib import Path
from PIL import Image

import sys

In [37]:
# Set the dimension of images you want to be passed in to the network
DIM = 128

In [38]:
# Set your own path to images
# src_img_path  = Path('/home/ec2-user/SageMaker/genre-128')
src_img_path  = Path('/home/ec2-user/SageMaker/portrait_landscape')

In [39]:
os.listdir(src_img_path)

['1', '.ipynb_checkpoints', '0']

In [40]:
len(os.listdir(os.path.join(src_img_path,'1')))

14971

In [41]:
# This dictionary should be updated to hold the absolute number of images associated with each genre used during training
styles = {
    "portraits": 14981,
    "landscapes": 14971
}

styleNum = {
    "portraits": 0,
    "landscapes": 1
}
    
curPos = {
    "portraits": 0,
    "landscapes": 0
}

In [42]:
testNums = {}
trainNums = {}

In [43]:
# Generate test set of images made up of 1/20 of the images (per genre)
for k, v in styles.items():
    # put a twentieth of paintings in here
    nums = range(v)
    random.shuffle(list(nums))
    testNums[k] = nums[0 : v // 20]
    trainNums[k] = nums[v // 20 :]

In [44]:
trainNums

{'portraits': range(749, 14981), 'landscapes': range(748, 14971)}

In [45]:
def inf_gen(gen):
    while True:
        for (images, labels) in gen():
            yield images, labels

In [46]:
def make_generator(files, batch_size, n_classes): # add genre parameter
    if batch_size % n_classes != 0:
        raise ValueError(
            "Batch size {} must be divisible by num classes {}".format(batch_size, n_classes)
        )

    class_batch = batch_size // n_classes

    generators = []

    def get_epoch():

        while True:

            images = np.zeros((batch_size, 3, DIM, DIM), dtype="int32")
            labels = np.zeros((batch_size, n_classes))
            n = 0
            for style in styles:
#             for style in genre:
                styleLabel = styleNum[style]
                curr = curPos[style]
                for _ in range(class_batch):
                    if curr == styles[style]:
                        curr = 0
                        random.shuffle(list(files[style]))
                    img_path = str(Path(src_img_path, str(styleLabel), str(curr) + ".png"))
                    image = Image.open(img_path).convert(mode="RGB")
                    image = np.asarray(image)

                    images[n % batch_size] = image.transpose(2, 0, 1)
                    labels[n % batch_size, int(styleLabel)] = 1
                    n += 1
                    curr += 1
                curPos[style] = curr

            # randomize things but keep relationship between a conditioning vector and its associated image
            rng_state = np.random.get_state()
            np.random.shuffle(images)
            np.random.set_state(rng_state)
            np.random.shuffle(labels)
            yield (images, labels)

    return get_epoch

In [47]:
def load(batch_size):
    return (
        make_generator(trainNums, batch_size, len(styles)),
        make_generator(testNums, batch_size, len(styles)),
    )

In [48]:
import os, sys
from pathlib import Path

sys.path.append(os.getcwd())

from random import randint

import time
import functools
import math

import numpy as np
import tensorflow as tf

import tflib as lib
import tflib.ops.linear
import tflib.ops.conv2d
import tflib.ops.batchnorm
import tflib.ops.deconv2d
import tflib.save_images
# import tflib.wikiart_genre
import tflib.ops.layernorm
import tflib.plot

In [49]:
# pip install --upgrade tensorflow-gpu==1.2.1

In [50]:
# pip install --upgrade tensorflow-gpu

In [51]:
print(tf.test.gpu_device_name())
print(tf.__version__)


1.15.5


In [52]:
tf.test.is_gpu_available()

False

In [53]:
MODE = "acwgan"  # dcgan, wgan, wgan-gp, lsgan
genre = ['portraits','landscapes']
DIM = 128  # Model dimensionality
CRITIC_ITERS = 5  # How many iterations to train the critic for, increase it to 50 later
N_GPUS = 1  # Number of GPUs
BATCH_SIZE = 84  # Batch size. Must be a multiple of CLASSES and N_GPUS
ITERS = 200000  # How many iterations to train for
LAMBDA = 10  # Gradient penalty lambda hyperparameter
OUTPUT_DIM = DIM * DIM * 3  # Number of pixels in each image
CLASSES = len(genre)  # Number of classes, for genres probably 14
PREITERATIONS = 2000  # Number of preiteration training cycles to run
lib.print_model_settings(locals().copy())

Uppercase local vars:
	BATCH_SIZE: 84
	CLASSES: 2
	CRITIC_ITERS: 5
	DEVICES: ['/gpu:0']
	DIM: 128
	ITERS: 200000
	LAMBDA: 10
	MODE: acwgan
	N_GPUS: 1
	OUTPUT_DIM: 49152
	PREITERATIONS: 2000


In [54]:
# Ensure that directory exists where ground truth and plots will be saved to.
Path('generated').mkdir(parents=True, exist_ok=True)
Path('models').mkdir(parents=True, exist_ok=True)

In [55]:
def GeneratorAndDiscriminator():
    return kACGANGenerator, kACGANDiscriminator

In [56]:
DEVICES = ["/gpu:{}".format(i) for i in range(N_GPUS)]

In [57]:
def LeakyReLU(x, alpha=0.2):
    return tf.maximum(alpha * x, x)


def ReLULayer(name, n_in, n_out, inputs):
    output = lib.ops.linear.Linear(
        name + ".Linear", n_in, n_out, inputs, initialization="he"
    )
    return tf.nn.relu(output)


def LeakyReLULayer(name, n_in, n_out, inputs):
    output = lib.ops.linear.Linear(
        name + ".Linear", n_in, n_out, inputs, initialization="he"
    )
    return LeakyReLU(output)

In [58]:
def Batchnorm(name, axes, inputs):

    if ("Discriminator" in name) and (MODE == "wgan-gp" or MODE == "acwgan"):
        if axes != [0, 2, 3]:
            raise Exception("Layernorm over non-standard axes is unsupported")
        return lib.ops.layernorm.Layernorm(name, [1, 2, 3], inputs)
    else:
        return lib.ops.batchnorm.Batchnorm(name, axes, inputs, fused=True)

In [59]:
def pixcnn_gated_nonlinearity(name, output_dim, a, b, c=None, d=None):
    if c is not None and d is not None:
        a = a + c
        b = b + d
    result = tf.sigmoid(a) * tf.tanh(b)
    return result

In [60]:
def SubpixelConv2D(*args, **kwargs):
    kwargs["output_dim"] = 4 * kwargs["output_dim"]
    output = lib.ops.conv2d.Conv2D(*args, **kwargs)
    output = tf.transpose(output, [0, 2, 3, 1])
    output = tf.depth_to_space(output, 2)
    output = tf.transpose(output, [0, 3, 1, 2])
    return output

In [61]:
def ResidualBlock(name, input_dim, output_dim, filter_size, inputs, resample=None, he_init=True):
    """
    resample: None, 'down', or 'up'
    """
    if resample=='down':
        conv_shortcut = functools.partial(lib.ops.conv2d.Conv2D, stride=2)
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim//2)
        conv_1b       = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim//2, output_dim=output_dim//2, stride=2)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim//2, output_dim=output_dim)
    elif resample=='up':
        conv_shortcut = SubpixelConv2D
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim//2)
        conv_1b       = functools.partial(lib.ops.deconv2d.Deconv2D, input_dim=input_dim//2, output_dim=output_dim//2)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim//2, output_dim=output_dim)
    elif resample==None:
        conv_shortcut = lib.ops.conv2d.Conv2D
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim,  output_dim=input_dim//2)
        conv_1b       = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim//2,  output_dim=output_dim//2)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim//2, output_dim=output_dim)

    else:
        raise Exception('invalid resample value')

    if output_dim==input_dim and resample==None:
        shortcut = inputs # Identity skip-connection
    else:
        shortcut = conv_shortcut(name+'.Shortcut', input_dim=input_dim, output_dim=output_dim, filter_size=1,
                                 he_init=False, biases=True, inputs=inputs)

    output = inputs
    output = tf.nn.relu(output)
    output = conv_1(name+'.Conv1', filter_size=1, inputs=output, he_init=he_init, weightnorm=False)
    output = tf.nn.relu(output)
    output = conv_1b(name+'.Conv1B', filter_size=filter_size, inputs=output, he_init=he_init, weightnorm=False)
    output = tf.nn.relu(output)
    output = conv_2(name+'.Conv2', filter_size=1, inputs=output, he_init=he_init, weightnorm=False, biases=False)
    output = Batchnorm(name+'.BN', [0,2,3], output)

    return shortcut + (0.3*output)

In [62]:
def kACGANGenerator(n_samples, numClasses, labels, noise=None, dim=DIM, bn=True, nonlinearity=tf.nn.relu, condition=None):
    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)
    if noise is None:
        noise = tf.random_normal([n_samples, 128])

    labels = tf.cast(labels, tf.float32)        
    noise = tf.concat([noise, labels], 1)

    dim//=2
# new
#######################
    print('#######################')
    output = lib.ops.linear.Linear('Generator.Input', 128+numClasses, 16*2*2*dim*2, noise) #probs need to recalculate dimensions
    print('Generator linear output 0: ', output )
    output = tf.reshape(output, [-1, 16*dim*2, 2, 2])
    print('Generator output reshape 0: ', output )
    if bn:
        output = Batchnorm('Generator.BN0', [0,2,3], output)
    condition = lib.ops.linear.Linear('Generator.cond0', numClasses, 16*2*2*dim*2, labels,biases=False)
    print('Generator condition 0: ', condition )
    condition = tf.reshape(condition, [-1, 16*dim*2, 2, 2])
    print('Generator condition reshape 0: ', condition )
    output = pixcnn_gated_nonlinearity('Generator.nl0', 32*dim, output[:,::2], output[:,1::2], condition[:,::2], condition[:,1::2])
    print('Generator output 0 final: ', output )
    print('#######################')
#######################
#    output = lib.ops.linear.Linear('Generator.Input', 128+numClasses, 8*4*4*dim*2, noise) #probs need to recalculate dimensions
#    output = tf.reshape(output, [-1, 8*dim*2, 4, 4])
    output = lib.ops.deconv2d.Deconv2D('Generator.1', 16*dim, 8*dim*2, 5, output)
    print('Generator output 1: ', output )
    if bn:
        output = Batchnorm('Generator.BN1', [0,2,3], output)
    condition = lib.ops.linear.Linear('Generator.cond1', numClasses, 8*4*4*dim*2, labels,biases=False)
    print('Generator condition 1: ', condition )
    condition = tf.reshape(condition, [-1, 8*dim*2, 4, 4])
    print('Generator condition reshape 1: ', condition )
    output = pixcnn_gated_nonlinearity('Generator.nl1', 16*dim, output[:,::2], output[:,1::2], condition[:,::2], condition[:,1::2])
    print('Generator output 1 final: ', output )
    print('#######################')
#######################

    output = lib.ops.deconv2d.Deconv2D('Generator.2', 8*dim, 4*dim*2, 5, output)
    print('Generator output 2: ', output )
    if bn:
        output = Batchnorm('Generator.BN2', [0,2,3], output)
    condition = lib.ops.linear.Linear('Generator.cond2', numClasses, 4*8*8*dim*2, labels)
    print('Generator condition 2: ', condition )
    condition = tf.reshape(condition, [-1, 4*dim*2, 8, 8])
    print('Generator condition 2 reshape: ', condition )
    output = pixcnn_gated_nonlinearity('Generator.nl2', 4*dim,output[:,::2], output[:,1::2], condition[:,::2], condition[:,1::2])
    print('Generator output 2 final: ', output )
    print('#######################')
#######################

    output = lib.ops.deconv2d.Deconv2D('Generator.3', 4*dim, 2*dim*2, 5, output)
    print('Generator output 3: ', output )
    if bn:
        output = Batchnorm('Generator.BN3', [0,2,3], output)
    condition = lib.ops.linear.Linear('Generator.cond3', numClasses, 2*16*16*dim*2, labels)
    print('Generator condition 3: ', condition )
    condition = tf.reshape(condition, [-1, 2*dim*2, 16, 16])
    print('Generator condition 3 reshape: ', condition )
    output = pixcnn_gated_nonlinearity('Generator.nl3', 2*dim,output[:,::2], output[:,1::2], condition[:,::2], condition[:,1::2])
    print('Generator output 3 final: ', output )
    print('#######################')
#######################

    output = lib.ops.deconv2d.Deconv2D('Generator.4', 2*dim, dim*2, 5, output)
    print('Generator output 4: ', output )
    if bn:
        output = Batchnorm('Generator.BN4', [0,2,3], output)
    condition = lib.ops.linear.Linear('Generator.cond4', numClasses, 32*32*dim*2, labels)
    print('Generator condition 4: ', condition )
    condition = tf.reshape(condition, [-1, dim*2, 32, 32])
    print('Generator condition 4 reshape: ', condition )
    output = pixcnn_gated_nonlinearity('Generator.nl4', dim, output[:,::2], output[:,1::2], condition[:,::2], condition[:,1::2])
    print('Generator output 4 final: ', output )
    print('#######################')
#######################
    output = lib.ops.deconv2d.Deconv2D('Generator.5', dim, 3, 5, output)
    print('Generator output 5: ', output )
    output = tf.tanh(output)
    print('Generator output 5 final: ', output )
    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()
    print('#######################')
#######################
    return tf.reshape(output, [-1, OUTPUT_DIM]), labels

In [63]:
def kACGANDiscriminator(inputs, numClasses, dim=DIM, bn=True, nonlinearity=LeakyReLU):
#     dim//=2
    output = tf.reshape(inputs, [-1, 3, dim, dim])
    lib.ops.conv2d.set_weights_stdev(0.02)
    lib.ops.deconv2d.set_weights_stdev(0.02)
    lib.ops.linear.set_weights_stdev(0.02)
#######################   
    print('#######################')
    output = lib.ops.conv2d.Conv2D('Discriminator.1', 3, dim, 5, output, stride=2)
    print('Discriminator output 1: ', output )
    output = nonlinearity(output)
    print('Discriminator output 1 final: ', output )
    print('#######################')
#######################
    output = lib.ops.conv2d.Conv2D('Discriminator.2', dim, 2*dim, 5, output, stride=2)
    print('Discriminator output 2: ', output )
    if bn:
        output = Batchnorm('Discriminator.BN2', [0,2,3], output)
    output = nonlinearity(output)
    print('Discriminator output 2 final: ', output )
    print('#######################')
#######################
    output = lib.ops.conv2d.Conv2D('Discriminator.3', 2*dim, 4*dim, 5, output, stride=2)
    print('Discriminator output 3: ', output )
    if bn:
        output = Batchnorm('Discriminator.BN3', [0,2,3], output)
    output = nonlinearity(output)
    print('Discriminator output 3 final: ', output )
    print('#######################')
#######################
    
    output = lib.ops.conv2d.Conv2D('Discriminator.4', 4*dim, 8*dim, 5, output, stride=2)
    print('Discriminator output 4: ', output )
    if bn:
        output = Batchnorm('Discriminator.BN4', [0,2,3], output)
    output = nonlinearity(output)
    print('Discriminator output 4 final: ', output )
    print('#######################')
#######################
    finalLayer = tf.reshape(output, [-1, 4*4*8*dim])
    print('Discriminator final layer: ', finalLayer )
    sourceOutput = lib.ops.linear.Linear('Discriminator.sourceOutput', 4*4*8*dim, 1, finalLayer)
    print('Discriminator source output: ', sourceOutput )
    classOutput = lib.ops.linear.Linear('Discriminator.classOutput', 4*4*8*dim, numClasses, finalLayer)
    print('Discriminator class output: ', classOutput )
    lib.ops.conv2d.unset_weights_stdev()
    lib.ops.deconv2d.unset_weights_stdev()
    lib.ops.linear.unset_weights_stdev()
    print('#######################')
#######################
    return (tf.reshape(sourceOutput, [-1]), tf.reshape(classOutput, [-1, numClasses]))



In [64]:
def genRandomLabels(n_samples, numClasses, condition=None):
    labels = np.zeros([BATCH_SIZE, CLASSES], dtype=np.float32)
    for i in range(n_samples):
        if condition is not None:
            labelNum = condition
        else:
            labelNum = randint(0, numClasses - 1)
        labels[i, labelNum] = 1
    return labels

In [72]:
Generator, Discriminator = GeneratorAndDiscriminator()

In [73]:
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session:

    all_real_data_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, 3, DIM, DIM])
    all_real_label_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE,CLASSES])

    generated_labels_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE,CLASSES])
    sample_labels_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE,CLASSES])

    if tf.__version__.startswith('1.'):
        split_real_data_conv = tf.split(all_real_data_conv, len(DEVICES))
        split_real_label_conv = tf.split(all_real_label_conv, len(DEVICES))
        split_generated_labels_conv = tf.split(generated_labels_conv, len(DEVICES))
        split_sample_labels_conv = tf.split(sample_labels_conv, len(DEVICES))
    else:
        split_real_data_conv = tf.split(0, len(DEVICES), all_real_data_conv)
        split_real_data_label = tf.split(0, len(DEVICES), all_real_data_conv)
        split_generated_labels = tf.split(0, len(DEVICES), generated_labels_conv)
        split_sample_labels = tf.split(0, len(DEVICES), sample_labels_conv)

    gen_costs, disc_costs = [],[]

    for device_index, (device, real_data_conv, real_label_conv) in enumerate(zip(DEVICES, split_real_data_conv, split_real_label_conv)):
        with tf.device(device):

            real_data = tf.reshape(2*((tf.cast(real_data_conv, tf.float32)/255.)-.5), [BATCH_SIZE//len(DEVICES), OUTPUT_DIM])
            real_labels = tf.reshape(real_label_conv, [BATCH_SIZE//len(DEVICES), CLASSES])
            print("Real data: ", real_data)
            print("Real labels: ", real_labels)
            generated_labels = tf.reshape(split_generated_labels_conv, [BATCH_SIZE//len(DEVICES), CLASSES])
            sample_labels = tf.reshape(split_sample_labels_conv, [BATCH_SIZE//len(DEVICES), CLASSES])
            print("Generated labels: ", generated_labels)
            print("Sample labels: ", sample_labels)

            fake_data, fake_labels = Generator(BATCH_SIZE//len(DEVICES), CLASSES, generated_labels)
            print("Fake data: {}, {}".format(fake_data,fake_labels))
            #set up discrimnator results
            print('Classes: ', CLASSES)

            disc_fake,disc_fake_class = Discriminator(fake_data, CLASSES)
            print("Fake dsicriminator: {}, {}".format(disc_fake,disc_fake_class))
            disc_real,disc_real_class = Discriminator(real_data, CLASSES)
            print("Real dsicriminator: {}, {}".format(disc_real,disc_real_class))

            print("Fake shapes: fake data {}, fake disc {}, fake disc class {}".format(fake_data.shape,disc_fake.shape,disc_fake_class.shape))
            print("Real shapes: real data {}, real disc {}, real disc class {}".format(real_data.shape,disc_real.shape,disc_real_class.shape))


            prediction = tf.argmax(disc_fake_class, 1)
            print('Prediction 1: ', prediction)
            correct_answer = tf.argmax(fake_labels, 1)
            print('Correct 1: ', correct_answer)
            equality = tf.equal(prediction, correct_answer)
            print('Equality 1: ', equality)
            genAccuracy = tf.reduce_mean(tf.cast(equality, tf.float32))

            prediction = tf.argmax(disc_real_class, 1)
            print('Prediction 2: ', prediction)
            correct_answer = tf.argmax(real_labels, 1)
            print('Correct 2: ', correct_answer)
            equality = tf.equal(prediction, correct_answer)
            print('Equality 2: ', equality)
            realAccuracy = tf.reduce_mean(tf.cast(equality, tf.float32))

            gen_cost = -tf.reduce_mean(disc_fake)
            disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

            gen_cost_test = -tf.reduce_mean(disc_fake)
            disc_cost_test = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

            generated_class_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_fake_class,
                                                                                              labels=fake_labels))


            real_class_cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_real_class,
                                                                                              labels=real_labels))
            gen_cost += generated_class_cost
            disc_cost += real_class_cost

            alpha = tf.random_uniform(
                shape=[BATCH_SIZE//len(DEVICES),1],
                minval=0.,
                maxval=1.
            )
            differences = fake_data - real_data
            interpolates = real_data + (alpha*differences)
            gradients = tf.gradients(Discriminator(interpolates, CLASSES)[0], [interpolates])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
            gradient_penalty = tf.reduce_mean((slopes-1.)**2)
            disc_cost += LAMBDA*gradient_penalty

            real_class_cost_gradient = real_class_cost*50 + LAMBDA*gradient_penalty


            gen_costs.append(gen_cost)
            disc_costs.append(disc_cost)

    gen_cost = tf.add_n(gen_costs) / len(DEVICES)
    disc_cost = tf.add_n(disc_costs) / len(DEVICES)

    gen_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(gen_cost,
                                                                                             var_list=lib.params_with_name('Generator'),
                                                                                             colocate_gradients_with_ops=True)
    disc_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(disc_cost,
                                                                                              var_list=lib.params_with_name('Discriminator.'),
                                                                                              colocate_gradients_with_ops=True)
    class_train_op =  tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(real_class_cost_gradient,
                                                                                                var_list=lib.params_with_name('Discriminator.'),
                                                                                                colocate_gradients_with_ops=True)
    # For generating samples

    fixed_noise = tf.constant(np.random.normal(size=(BATCH_SIZE, 128)).astype('float32'))
    all_fixed_noise_samples = []
    for device_index, device in enumerate(DEVICES):
        n_samples = BATCH_SIZE // len(DEVICES)
        all_fixed_noise_samples.append(Generator(n_samples, CLASSES, sample_labels,noise=fixed_noise[device_index*n_samples:(device_index+1)*n_samples])[0])
        if tf.__version__.startswith('1.'):
            all_fixed_noise_samples = tf.concat(all_fixed_noise_samples, axis=0)
        else:
            all_fixed_noise_samples = tf.concat(0, all_fixed_noise_samples)


    def generate_image(iteration):
        for i in range(CLASSES):
            curLabel= genRandomLabels(BATCH_SIZE,CLASSES,condition=i)
            samples = session.run(all_fixed_noise_samples, feed_dict={sample_labels: curLabel})
            samples = ((samples+1.)*(255.99/2)).astype('int32')
            lib.save_images.save_images(samples.reshape((BATCH_SIZE, 3, DIM, DIM)), 'generated/samples_{}_{}.png'.format(str(i), iteration))
            lib.save_images.save_images_ind(samples.reshape((BATCH_SIZE, 3, DIM, DIM)), 'generated/ind/samples_{}_{}_'.format(str(i), iteration)+'{}.png')

    def generate_good_images(iteration,thresh=.95):
        NUM_TO_MAKE = BATCH_SIZE
        TRIES = BATCH_SIZE*5
        CONF_THRESH = thresh
        for i in range(CLASSES):
            l = 0
            curLabel= genRandomLabels(BATCH_SIZE,CLASSES,condition=i)
            j = 0
            images = None
            while(j<NUM_TO_MAKE and l<TRIES):
                genr = Generator(BATCH_SIZE, CLASSES, sample_labels)[0]
                samples = session.run(genr, feed_dict={sample_labels: curLabel})
                samples = np.reshape(samples,[-1, 3, DIM, DIM])
                samples = ((samples+1.)*(255.99/2)).astype('int32')
                prediction,accuracy = session.run([disc_real_class,realAccuracy] , feed_dict = {all_real_data_conv: samples, all_real_label_conv: curLabel})
                guess = np.argmax(prediction,1)
                my_equal = np.equal(guess,np.argmax(curLabel,1))
                for s,_ in enumerate(prediction):
                    prediction[s] = prediction[s]/np.sum(prediction[s])
                    confidence = np.amax(prediction,1)
                    for k,image in enumerate(samples):
                        if guess[k] == i and confidence[k]>CONF_THRESH and j < NUM_TO_MAKE:
                            if isinstance(images, np.ndarray):
                                images = np.concatenate((images,image),0)
                            else:
                                images = image
                        j+=1
                    l += 1
                CONF_THRESH = CONF_THRESH * .9
            try:
                samples = images
                lib.save_images.save_images(samples.reshape((-1, 3, DIM, DIM)), 'generated/good_samples_{}_{}.png'.format(str(i),iteration))
                lib.save_images.save_images_ind(samples.reshape((-1, 3, DIM, DIM)), 'generated/ind/good_samples_{}_{}_'.format(str(i), iteration)+'{}.png')
            except Exception as e:
                print(e)


    # Dataset iterator
    train_gen, dev_gen = lib.wikiartGenre.load(BATCH_SIZE)

    def softmax_cross_entropy(logit, y):
        return -tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y))

    def inf_train_gen():
        while True:
            for (images,labels) in train_gen():
                yield images,labels


    _sample_labels = genRandomLabels(BATCH_SIZE, CLASSES)
    # Save a batch of ground-truth samples
    _x,_y = next(train_gen())
    _x_r = session.run(real_data, feed_dict={all_real_data_conv: _x})
    _x_r = ((_x_r+1.)*(255.99/2)).astype('int32')
    lib.save_images.save_images(_x_r.reshape((BATCH_SIZE, 3, DIM, DIM)), 'generated/samples_groundtruth.png')



    session.run(tf.initialize_all_variables(), feed_dict={generated_labels_conv: genRandomLabels(BATCH_SIZE,CLASSES)})
    gen = train_gen()

    for iterp in range(PREITERATIONS):
        _data, _labels = next(gen)
        _ , accuracy = session.run([disc_train_op, realAccuracy],feed_dict = {all_real_data_conv: _data, all_real_label_conv: _labels, generated_labels_conv: genRandomLabels(BATCH_SIZE, CLASSES)})
        if iterp % 100 == 99:
            print('pretraining accuracy: ' + str(accuracy))


    for iteration in range(ITERS):
        start_time = time.time()
        # Train generator
        if iteration > 0:
            _ = session.run(gen_train_op, feed_dict={generated_labels_conv: genRandomLabels(BATCH_SIZE,CLASSES)})
        # Train critic
        disc_iters = CRITIC_ITERS
        for i in range(disc_iters):
            _data, _labels = next(gen)
            _disc_cost, _disc_cost_test, class_cost_test, gen_class_cost, _gen_cost_test, _genAccuracy, _realAccuracy, _ = session.run([disc_cost, disc_cost_test, real_class_cost, generated_class_cost, gen_cost_test, genAccuracy, realAccuracy, disc_train_op], feed_dict={all_real_data_conv: _data, all_real_label_conv: _labels, generated_labels_conv: genRandomLabels(BATCH_SIZE,CLASSES)})

        lib.plot.plot('train disc cost', _disc_cost)
        lib.plot.plot('time', time.time() - start_time)
        lib.plot.plot('wgan train disc cost', _disc_cost_test)
        lib.plot.plot('train class cost', class_cost_test)
        lib.plot.plot('generated class cost', gen_class_cost)
        lib.plot.plot('gen cost cost', _gen_cost_test)
        lib.plot.plot('gen accuracy', _genAccuracy)
        lib.plot.plot('real accuracy', _realAccuracy)

        if (iteration % 100 == 99 and iteration<1000) or iteration % 1000 == 999 :
            t = time.time()
            dev_disc_costs = []
            images, labels = next(dev_gen())
            _dev_disc_cost, _dev_disc_cost_test, _class_cost_test, _gen_class_cost, _dev_gen_cost_test, _dev_genAccuracy, _dev_realAccuracy = session.run([disc_cost, disc_cost_test, real_class_cost, generated_class_cost, gen_cost_test, genAccuracy, realAccuracy], feed_dict={all_real_data_conv: images, all_real_label_conv: labels, generated_labels_conv: genRandomLabels(BATCH_SIZE,CLASSES)})
            dev_disc_costs.append(_dev_disc_cost)
            lib.plot.plot('dev disc cost', np.mean(dev_disc_costs))
            lib.plot.plot('wgan dev disc cost', _dev_disc_cost_test)
            lib.plot.plot('dev class cost', _class_cost_test)
            lib.plot.plot('dev generated class cost', _gen_class_cost)
            lib.plot.plot('dev gen  cost', _dev_gen_cost_test)
            lib.plot.plot('dev gen accuracy', _dev_genAccuracy)
            lib.plot.plot('dev real accuracy', _dev_realAccuracy)


        if iteration % 1000 == 999:
            generate_image(iteration)
            generate_good_images(iteration)
            #Can add generate_good_images method in here if desired

        if (iteration < 10) or (iteration % 100 == 99):
            lib.plot.flush()

        lib.plot.tick()

Real data:  Tensor("Reshape_93:0", shape=(84, 49152), dtype=float32, device=/device:GPU:0)
Real labels:  Tensor("Reshape_94:0", shape=(84, 2), dtype=int32, device=/device:GPU:0)
Generated labels:  Tensor("Reshape_95:0", shape=(84, 2), dtype=int32, device=/device:GPU:0)
Sample labels:  Tensor("Reshape_96:0", shape=(84, 2), dtype=int32, device=/device:GPU:0)


TypeError: ResnetGenerator() takes from 1 to 2 positional arguments but 3 were given

In [75]:
def nonlinearity(x):
    return tf.nn.relu(x)

def Normalize(name, inputs):
    if ('Discriminator' in name) and NORMALIZATION_D:
        return lib.ops.layernorm.Layernorm(name,[1,2,3],inputs)
    elif ('Generator' in name) and NORMALIZATION_G:
        return lib.ops.batchnorm.Batchnorm(name,[0,2,3],inputs,fused=True)

def ConvMeanPool(name, input_dim, output_dim, filter_size, inputs, he_init=True, biases=True):
    output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=he_init, biases=biases)
    output = tf.add_n([output[:,:,::2,::2], output[:,:,1::2,::2], output[:,:,::2,1::2], output[:,:,1::2,1::2]]) / 4.
    return output

def MeanPoolConv(name, input_dim, output_dim, filter_size, inputs, he_init=True, biases=True):
    output = inputs
    output = tf.add_n([output[:,:,::2,::2], output[:,:,1::2,::2], output[:,:,::2,1::2], output[:,:,1::2,1::2]]) / 4.
    output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, output, he_init=he_init, biases=biases)
    return output

def ScaledUpsampleConv(name, input_dim, output_dim, filter_size, inputs, he_init=True, biases=True):
    output = inputs
    output = lib.concat([output, output, output, output], axis=1)
    output = tf.transpose(output, [0,2,3,1])
    output = tf.depth_to_space(output, 2)
    output = tf.transpose(output, [0,3,1,2])
    output = lib.ops.conv2d.Conv2D(name, input_dim, output_dim, filter_size, output, he_init=he_init, biases=biases, gain=0.5)
    return output

def ResidualBlock(name, input_dim, output_dim, filter_size, inputs, resample=None):
    """
    resample: None, 'down', or 'up'
    """
    if resample=='down':
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=output_dim, stride=2)
        # conv_shortcut = functools.partial(lib.ops.conv2d.Conv2D, stride=2)
        # conv_2        = functools.partial(ConvMeanPool, input_dim=input_dim, output_dim=output_dim)
        conv_shortcut = MeanPoolConv
    elif resample=='up':
        conv_1        = functools.partial(ScaledUpsampleConv, input_dim=input_dim, output_dim=output_dim)
        # conv_1        = functools.partial(lib.ops.deconv2d.Deconv2D, input_dim=input_dim, output_dim=output_dim)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim)
        conv_shortcut = ScaledUpsampleConv
    elif resample==None:
        conv_shortcut = lib.ops.conv2d.Conv2D
        conv_1        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim,  output_dim=output_dim)
        conv_2        = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim)
    else:
        raise Exception('invalid resample value')

    if output_dim==input_dim and resample==None:
        shortcut = inputs # Identity skip-connection
    else:
        shortcut = inputs
        # shortcut = Normalize(name+'.NShortcut', shortcut)
        shortcut = conv_shortcut(name+'.Shortcut', input_dim=input_dim, output_dim=output_dim, filter_size=1, he_init=False, biases=True, inputs=shortcut)

    output = inputs
    output = Normalize(name+'.N1', output)
    output = nonlinearity(output)
    output = conv_1(name+'.Conv1', filter_size=filter_size, inputs=output)
    output = Normalize(name+'.N2', output)
    output = nonlinearity(output)
    output = conv_2(name+'.Conv2', filter_size=filter_size, inputs=output)
    # output = Normalize(name+'.N3', output)
    # return output
    return shortcut + output
    # return 0.7*(shortcut+output)

def ResnetGenerator(n_samples, noise=None):
    if noise is None:
        noise = tf.random_normal([n_samples, 128])

    output = lib.ops.linear.Linear('Generator.Input', 128, 4*4*DIM_G_4, noise)
    output = tf.reshape(output, [-1, DIM_G_4, 4, 4])

    # output = ResidualBlock('Generator.4_1', DIM_G_4, DIM_G_4, 3, output, resample=None)
    # output = ResidualBlock('Generator.4_2', DIM_G_4, DIM_G_4, 3, output, resample=None)
    output = ResidualBlock('Generator.4_3', DIM_G_4, DIM_G_8, 3, output, resample='up')

    # output = ResidualBlock('Generator.8_1', DIM_G_8, DIM_G_8, 3, output, resample=None)
    # output = ResidualBlock('Generator.8_2', DIM_G_8, DIM_G_8, 3, output, resample=None)
    output = ResidualBlock('Generator.8_3', DIM_G_8, DIM_G_16, 3, output, resample='up')

    # output = ResidualBlock('Generator.16_1', DIM_G_16, DIM_G_16, 3, output, resample=None)
    # output = ResidualBlock('Generator.16_2', DIM_G_16, DIM_G_16, 3, output, resample=None)
    output = ResidualBlock('Generator.16_3', DIM_G_16, DIM_G_32, 3, output, resample='up')

    # output = ResidualBlock('Generator.32_1', DIM_G_32, DIM_G_32, 3, output, resample=None)
    # output = ResidualBlock('Generator.32_2', DIM_G_32, DIM_G_32, 3, output, resample=None)
    output = ResidualBlock('Generator.32_3', DIM_G_32, DIM_G_64, 3, output, resample='up')

    output = Normalize('Generator.OutputN', output)
    output = nonlinearity(output)
    output = ScaledUpsampleConv('Generator.Output', DIM_G_64, 3, 5, output, he_init=False)
    # output = lib.ops.deconv2d.Deconv2D('Generator.Output', DIM_G_64, 3, 5, output, he_init=False)

    output = tf.tanh(output)

    return tf.reshape(output, [-1, OUTPUT_DIM])

def ResnetDiscriminator(inputs):
    output = tf.reshape(inputs, [-1, 3, 128, 128])

    output = lib.ops.conv2d.Conv2D('Discriminator.Input', 3, DIM_D_64, 5, output, he_init=True, stride=2)

    # output = ResidualBlock('Discriminator.64_1', DIM_D_64, DIM_D_64, 3, output, resample=None)
    # output = ResidualBlock('Discriminator.64_2', DIM_D_64, DIM_D_64, 3, output, resample=None)
    output = ResidualBlock('Discriminator.64_3', DIM_D_64, DIM_D_32, 3, output, resample='down')

    # output = ResidualBlock('Discriminator.32_1', DIM_D_32, DIM_D_32, 3, output, resample=None)
    # output = ResidualBlock('Discriminator.32_2', DIM_D_32, DIM_D_32, 3, output, resample=None)
    output = ResidualBlock('Discriminator.32_3', DIM_D_32, DIM_D_16, 3, output, resample='down')

    # output = ResidualBlock('Discriminator.16_1', DIM_D_16, DIM_D_16, 3, output, resample=None)
    # output = ResidualBlock('Discriminator.16_2', DIM_D_16, DIM_D_16, 3, output, resample=None)
    output = ResidualBlock('Discriminator.16_3', DIM_D_16, DIM_D_8, 3, output, resample='down')

    output = ResidualBlock('Discriminator.8_1', DIM_D_8, DIM_D_8, 3, output, resample=None)
    output = ResidualBlock('Discriminator.8_2', DIM_D_8, DIM_D_8, 3, output, resample=None)
    # output = ResidualBlock('Discriminator.8_3', DIM_D_8, DIM_D_4, 3, output, resample='down')

    # output = ResidualBlock('Discriminator.4_1', DIM_D_4, DIM_D_4, 3, output, resample=None)
    # output = ResidualBlock('Discriminator.4_2', DIM_D_4, DIM_D_4, 3, output, resample=None)

    # output = Normalize('Discriminator.OutputN', output)
    # output = output / 10.
    output = tf.reduce_mean(output, axis=[2,3])
    output = lib.ops.linear.Linear('Discriminator.Output', DIM_D_8, 1, output)

    # output = Normalize('Discriminator.OutputN', output)
    # output = nonlinearity(output)
    # output = tf.reshape(output, [-1, 4*4*DIM_D_4])
    # output = lib.ops.linear.Linear('Discriminator.Output', 4*4*DIM_D_4, 1, output)

    return tf.reshape(output, [-1])

with tf.Session() as session:

    Generator, Discriminator = GeneratorAndDiscriminator()

    iteration = tf.placeholder(tf.int32, shape=None)
    all_real_data_conv = tf.placeholder(tf.int32, shape=[BATCH_SIZE, 3, 128, 128])

    if (len(DEVICES)%2==0) and (len(DEVICES)>=2):

        fake_data_splits = []
        for device in DEVICES:
            with tf.device(device):
                fake_data_splits.append(Generator(BATCH_SIZE/len(DEVICES)))
        # fake_data = tf.concat(fake_data_splits, axis=0)
        # fake_data_splits = tf.split(fake_data, len(DEVICES))

        all_real_data = tf.reshape(2*((tf.cast(all_real_data_conv, tf.float32)/255.)-.5), [BATCH_SIZE, OUTPUT_DIM])
        all_real_data_splits = tf.split(all_real_data, len(DEVICES)/2)

        DEVICES_B = DEVICES[:len(DEVICES)/2]
        DEVICES_A = DEVICES[len(DEVICES)/2:]

        disc_costs = []
        for i, device in enumerate(DEVICES_A):
            with tf.device(device):
                real_and_fake_data = lib.concat([all_real_data_splits[i]] + [fake_data_splits[i]] + [fake_data_splits[len(DEVICES_A)+i]], axis=0)
                disc_all = Discriminator(real_and_fake_data)
                disc_real = disc_all[:BATCH_SIZE/len(DEVICES_A)]
                disc_fake = disc_all[BATCH_SIZE/len(DEVICES_A):]
                disc_costs.append(tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real))

        for i, device in enumerate(DEVICES_B):
            with tf.device(device):
                real_data = tf.identity(all_real_data_splits[i]) # transfer from gpu0
                fake_data__ = lib.concat([fake_data_splits[i], fake_data_splits[len(DEVICES_A)+i]], axis=0)
                alpha = tf.random_uniform(
                    shape=[BATCH_SIZE/len(DEVICES_A),1], 
                    minval=0.,
                    maxval=1.
                )
                differences = fake_data__ - real_data
                interpolates = real_data + (alpha*differences)
                gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
                slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
                # print "WARNING NO LIPSCHITZ PENALTY"
                gradient_penalty = 10.*tf.reduce_mean((slopes-1.)**2)
                disc_costs.append(gradient_penalty)

        disc_cost = tf.add_n(disc_costs) / len(DEVICES_A)

        if DECAY:
            decay = tf.maximum(0., 1.-(tf.cast(iteration, tf.float32)/ITERS))
        else:
            decay = 1.
        disc_train_op = tf.train.AdamOptimizer(learning_rate=LR*decay, beta1=MOMENTUM_D, beta2=0.9).minimize(disc_cost, var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True)

        gen_costs = []
        for device in DEVICES:
            with tf.device(device):
                gen_costs.append(-tf.reduce_mean(Discriminator(Generator(GEN_BS_MULTIPLE*BATCH_SIZE/len(DEVICES)))))
        gen_cost = tf.add_n(gen_costs) / len(DEVICES)
        gen_train_op = tf.train.AdamOptimizer(learning_rate=LR*decay, beta1=MOMENTUM_G, beta2=0.9).minimize(gen_cost, var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True)


    else:
        raise Exception()
        # split_real_data_conv = lib.split(all_real_data_conv, len(DEVICES), axis=0)

        # gen_costs, disc_costs = [],[]

        # for device_index, (device, real_data_conv) in enumerate(zip(DEVICES, split_real_data_conv)):
        #     with tf.device(device):

        #         real_data = tf.reshape(2*((tf.cast(real_data_conv, tf.float32)/255.)-.5), [BATCH_SIZE/len(DEVICES), OUTPUT_DIM])
        #         fake_data = Generator(BATCH_SIZE/len(DEVICES))

        #         disc_all = Discriminator(lib.concat([real_data, fake_data],0))
        #         disc_real = disc_all[:tf.shape(real_data)[0]]
        #         disc_fake = disc_all[tf.shape(real_data)[0]:]

        #         gen_cost = -tf.reduce_mean(Discriminator(Generator(GEN_BS_MULTIPLE*BATCH_SIZE/len(DEVICES))))
        #         disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)

        #         alpha = tf.random_uniform(
        #             shape=[BATCH_SIZE/len(DEVICES),1], 
        #             minval=0.,
        #             maxval=1.
        #         )
        #         differences = fake_data - real_data
        #         interpolates = real_data + (alpha*differences)
        #         interpolates = tf.stop_gradient(interpolates)
        #         gradients = tf.gradients(Discriminator(interpolates), [interpolates])[0]
        #         slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        #         lipschitz_penalty = 100.*tf.reduce_mean((slopes-1.)**2)
        #         disc_cost += lipschitz_penalty

        #         gen_costs.append(gen_cost)
        #         disc_costs.append(disc_cost)

        # gen_cost = tf.add_n(gen_costs) / len(DEVICES)
        # disc_cost = tf.add_n(disc_costs) / len(DEVICES)

        # if DECAY:
        #     decay = tf.maximum(0., 1.-(tf.cast(iteration, tf.float32)/ITERS))
        # else:
        #     decay = 1.
        # gen_train_op = tf.train.AdamOptimizer(learning_rate=LR*decay, beta1=MOMENTUM_G, beta2=0.9).minimize(gen_cost, var_list=lib.params_with_name('Generator'), colocate_gradients_with_ops=True)
        # disc_train_op = tf.train.AdamOptimizer(learning_rate=LR*decay, beta1=MOMENTUM_D, beta2=0.9).minimize(disc_cost, var_list=lib.params_with_name('Discriminator.'), colocate_gradients_with_ops=True)


    frame_i = [0]
    fixed_noise = tf.constant(np.random.normal(size=(64, 128)).astype('float32'))
    fixed_noise_samples = Generator(64, noise=fixed_noise)
    def generate_image(frame):
        samples = session.run(fixed_noise_samples)
        samples = ((samples+1.)*(255.99/2)).astype('int32')
        lib.save_images.save_images(samples.reshape((64, 3, 128, 128)), 'samples_{}.png'.format(frame))

    if DATASET == 'imagenet':
        train_gen = lib.imagenet.load(BATCH_SIZE)

    def inf_train_gen():
        while True:
            for images, in train_gen():
                yield images

    session.run(tf.initialize_all_variables())

    generate_image(0)

    gen = inf_train_gen()

    saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)
    # Uncomment this to restore params
    # print "WARNING RESTORING PARAMS FROM CHECKPOINT"
    # saver.restore(session, os.getcwd()+"/params.ckpt")

    for _iteration in xrange(ITERS):
        start_time = time.time()

        for i in xrange(CRITIC_ITERS):
            _data = gen.next()
            _data = _data.reshape((BATCH_SIZE,3,128,128))
            _disc_cost, _ = session.run(
                [disc_cost, disc_train_op], 
                feed_dict={all_real_data_conv: _data, iteration: _iteration}#, fake_data: fake_data_buffer[np.random.choice(BUFFER_LEN*BATCH_SIZE, BATCH_SIZE)]}
            )

        _ = session.run(
            gen_train_op,
            feed_dict={iteration: _iteration}
        )

        lib.plot.plot('cost', _disc_cost)
        lib.plot.plot('time', time.time() - start_time)

        if _iteration % 100 == 0:
            generate_image(_iteration)

        if _iteration % 1000 == 0:
            saver.save(session, 'params.ckpt')

        if _iteration % 5 == 0:
            lib.plot.flush(print_stds=True)

        lib.plot.tick()

Exception: 