Modern Art: ARTGAN

In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
import tensorflow.keras as keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Reshape, Conv2D, Conv2DTranspose, Flatten, Dropout
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing import image

In [2]:
latent_dim = 32
height = 32
width = 32
channels = 3

In [3]:
# generator
generator_input = Input(shape=(latent_dim,))

x = Dense(128 * 16 * 16)(generator_input)
x = LeakyReLU()(x)
x = Reshape((16, 16, 128))(x)

x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)

x = Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = LeakyReLU()(x)

x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)
x = Conv2D(256, 5, padding='same')(x)
x = LeakyReLU()(x)

x = Conv2D(channels, 7, activation='tanh', padding='same')(x)

generator = Model(inputs=generator_input, outputs=x)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [4]:
# discriminator
disc_input = Input(shape=(height, width, channels))
x = Conv2D(128, 3)(disc_input)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Conv2D(128, 4, strides=2)(x)
x = LeakyReLU()(x)
x = Flatten()(x)
x = Dropout(0.4)(x)
x = Dense(1, activation='sigmoid')(x)

discriminator = Model(inputs=disc_input, outputs=x)

disc_opt = keras.optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=1e-8)
discriminator.compile(optimizer=disc_opt, loss='binary_crossentropy')

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [5]:
# adversarial
discriminator.trainable = False

gan_input = Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))

gan = Model(inputs=gan_input, outputs=gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

In [6]:
# Load the dataset
(X_train, y_train), (_, _) = cifar10.load_data()
X_train = X_train[y_train.flatten() == 6]  
X_train = X_train.astype('float32') / 255

In [None]:
iterations = 100
batch_size = 20
save_dir = './Star_date'

if not os.path.isdir(save_dir):
    os.mkdir(save_dir)

start = 0
for step in range(iterations):
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
    
    generated_images = generator.predict(random_latent_vectors)
    stop = start + batch_size
    real_images = X_train[start: stop]
    combined_images = np.concatenate([generated_images, real_images])
    labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
    labels += 0.05 * np.random.random(labels.shape)
    
    # train discriminator
    d_loss = discriminator.train_on_batch(combined_images, labels)
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
    
    # train generator
    misleading_targets = np.zeros((batch_size, 1))  
    a_loss = gan.train_on_batch(random_latent_vectors,
                               misleading_targets)  
    start += batch_size
    if start > len(X_train) - batch_size:
        start = 0
        
    if step % 100 == 0:
        gan.save_weights('gan.1')
        
        print('discriminator loss:', d_loss)
        print('adversarial loss:', a_loss)
        
        img = image.array_to_img(generated_images[0] * 255, scale=False)
        img.save(os.path.join(save_dir, 'generated_image' + str(step) + '.png'))
        
        img = image.array_to_img(real_images[0] * 255, scale=False)
        img.save(os.path.join(save_dir, 'real_image' + str(step) + '.png'))

discriminator loss: 0.70526505
adversarial loss: 0.7005836


In [None]:
figsize=(16, 6)
plt.imshow(plt.imread('./Star_date/generated_image0.png'))

plt.show()
plt.imshow(plt.imread('./Star_date/real_image0.png'))

In [None]:
img = image.array_to_img(generated_images[1] * 255, scale=False)
img.save(os.path.join(save_dir, 'generated_image' + str(step) + '.png'))
        
img = image.array_to_img(real_images[1] * 255, scale=False)
img.save(os.path.join(save_dir, 'real_image' + str(step) + '.png'))

In [None]:
figsize=(16, 6)
plt.imshow(plt.imread('./Star_date/generated_image99.png'))

plt.show()
plt.imshow(plt.imread('./Star_date/real_image99.png'))