In [2]:
import numpy as np
import pandas as pd
import tensorflow as tf

from matplotlib import pyplot as plt

from tensorflow import keras
from keras import layers
from keras.models import Sequential, Model
from keras.initializers import RandomNormal
from keras.optimizers import Adam

from tensorflow.keras.initializers import RandomNormal
import tensorflow.keras.backend as K

from wasserstein_loss import wasserstein_generator_loss, wasserstein_discriminator_loss
from image_wgan_gp import imageWGANGP
from GANmonitor import GANMonitor

In [3]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel
# session, including TensorFlow/Keras deprecation notices that may explain
# later failures. Consider narrowing it (by category or module) once the
# noisy warning is identified.
warnings.filterwarnings('ignore')

# Building the generator and discriminator

In [5]:
def upsample_block(x, filters, size, strides, upsize, apply_dropout=True):
    """Generator block: UpSampling2D -> Conv2D -> BatchNorm -> [Dropout] -> LeakyReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of Conv2D output channels.
        size: Conv2D kernel size.
        strides: Conv2D strides (the generator uses 1, so resolution is set by `upsize`).
        upsize: UpSampling2D factor, e.g. (2, 2).
        apply_dropout: If True, insert Dropout(0.3) before the activation.

    Returns:
        The output Keras tensor.
    """
    # N(0, 0.02) weight init, matching convolution_block and the generator's
    # output layer.
    initializer = tf.random_normal_initializer(0., 0.02)

    x = layers.UpSampling2D(upsize)(x)
    # No bias: the following BatchNormalization provides its own offset.
    x = layers.Conv2D(
        filters, size, strides=strides, padding="same",
        kernel_initializer=initializer, use_bias=False,
    )(x)
    x = layers.BatchNormalization()(x)
    if apply_dropout:
        x = layers.Dropout(0.3)(x)
    x = layers.LeakyReLU(alpha=0.2)(x)
    return x

def convolution_block(x, filters, size, strides, apply_batchnorm=False, apply_layernorm=True, apply_dropout=False):
    """Critic block: Conv2D -> [BatchNorm] -> [LayerNorm] -> [Dropout(0.3)] -> LeakyReLU(0.2).

    Args:
        x: Input Keras tensor.
        filters: Number of Conv2D output channels.
        size: Conv2D kernel size.
        strides: Conv2D strides (the discriminator uses 2 to halve resolution).
        apply_batchnorm: Insert BatchNormalization after the convolution.
        apply_layernorm: Insert LayerNormalization (on by default; the
            discriminator relies on this default).
        apply_dropout: Insert Dropout(0.3) before the activation.

    Returns:
        The output Keras tensor.
    """
    conv = layers.Conv2D(filters, size, strides=strides, padding='same',
                         kernel_initializer=tf.random_normal_initializer(0., 0.02),
                         use_bias=False)

    # Optional layers in a fixed order; the activation always comes last.
    stages = []
    if apply_batchnorm:
        stages.append(layers.BatchNormalization())
    if apply_layernorm:
        stages.append(layers.LayerNormalization())
    if apply_dropout:
        stages.append(layers.Dropout(0.3))
    stages.append(layers.LeakyReLU(alpha=0.2))

    out = conv(x)
    for stage in stages:
        out = stage(out)
    return out

In [6]:
def build_generator(latent_dim=128, num_classes=10, img_height=32, img_width=32, img_channels=3):
    """Build the class-conditional generator.

    The input is a noise vector concatenated with a one-hot class label, i.e. a
    vector of length ``latent_dim + num_classes``. A Dense layer projects it to
    an ``(img_height // 8, img_width // 8, latent_dim + num_classes)`` feature
    map. ``upsample_block`` then upsamples it 2x three times, giving an
    ``(img_height, img_width, img_channels)`` image with values in [-1, 1] (tanh).

    Args:
        latent_dim: Length of the noise vector.
        num_classes: Length of the one-hot label appended to the noise.
        img_height: Output image height; must be divisible by 8.
        img_width: Output image width; must be divisible by 8.
        img_channels: Number of output channels.

    Returns:
        A ``keras.Model`` named ``"generator"``.

    Raises:
        ValueError: If ``img_height`` or ``img_width`` is not divisible by 8.
    """
    upsample_filters = (256, 128, 64)
    # Each upsample_block doubles height and width.
    scale = 2 ** len(upsample_filters)
    if img_height % scale or img_width % scale:
        raise ValueError(
            f"img_height and img_width must be divisible by {scale}, "
            f"got {img_height}x{img_width}"
        )
    base_height, base_width = img_height // scale, img_width // scale

    initializer = tf.random_normal_initializer(0., 0.02)
    in_channels = latent_dim + num_classes

    inputs = layers.Input((in_channels,))
    x = layers.Dense(base_height * base_width * in_channels)(inputs)
    x = layers.Reshape((base_height, base_width, in_channels))(x)

    for filters in upsample_filters:
        x = upsample_block(x, filters=filters, size=3, strides=1, upsize=(2, 2))

    # 7x7 projection to image channels; tanh matches the [-1, 1] scaling
    # applied to the training images.
    x = layers.Conv2D(img_channels, 7,
                      padding='same',
                      kernel_initializer=initializer,
                      activation='tanh')(x)

    return Model(inputs=inputs, outputs=x, name="generator")

def build_discriminator(num_classes=10, img_height=32, img_width=32, img_channels=3):
    """Build the class-conditional critic used with the Wasserstein losses.

    Takes two inputs with the same spatial size: the image and a label map with
    ``num_classes`` channels. The label map is presumably the one-hot label
    broadcast over every pixel; it is built inside imageWGANGP, which is not
    visible here, so confirm there. The two inputs are concatenated along the
    channel axis and passed through four stride-2 ``convolution_block`` stages
    (64, 128, 256, 512 filters). The result is flattened and mapped to a single
    linear score.

    Returns:
        A ``keras.Model`` named ``"discriminator"`` taking ``[image, label_map]``.
    """
    image_input = layers.Input((img_height, img_width, img_channels))
    label_map_input = layers.Input((img_height, img_width, num_classes))

    features = layers.concatenate([image_input, label_map_input])
    for n_filters in (64, 128, 256, 512):
        features = convolution_block(features, filters=n_filters, size=5, strides=2)

    flatten_layer = layers.Flatten()
    score_layer = layers.Dense(1)  # no activation: an unbounded critic score

    features = flatten_layer(features)
    features = layers.Dropout(0.2)(features)
    score = score_layer(features)

    return Model(inputs=[image_input, label_map_input], outputs=score, name="discriminator")

# Reading and preparing the data

In [5]:
# Shared configuration for the GAN, the data pipeline and the classifiers.
latent_dim = 128    # length of the generator's noise vector
num_classes = 10    # CIFAR-10 classes; also the length of the one-hot label
batch_size = 32
img_height = 32
img_width = 32
img_channels = 3

# Seed NumPy and TensorFlow so weight init, shuffling and sampled noise are
# repeatable (GPU kernels may still add small nondeterminism).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [6]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Scale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
x_train = (x_train.astype("float32") - 127.5) / 127.5
x_train = np.reshape(x_train, (-1, img_height, img_width, img_channels))
assert x_train.min() >= -1.0 and x_train.max() <= 1.0, "unexpected pixel range"

# One-hot labels: the generator is conditioned on one-hot class vectors
# (see build_generator's input size).
y_train = keras.utils.to_categorical(y_train, num_classes)

# Prefetch overlaps input preparation with training steps.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

print(f"Shape of training images: {x_train.shape}")
print(f"Shape of training labels: {y_train.shape}")

Shape of training images: (50000, 32, 32, 3)
Shape of training labels: (50000, 10)


# Building and training the GAN model

In [7]:
# Pass the config explicitly so the model always matches the constants cell,
# instead of silently relying on build_generator's defaults.
generator = build_generator(latent_dim=latent_dim, num_classes=num_classes,
                            img_height=img_height, img_width=img_width,
                            img_channels=img_channels)
generator.summary()

Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 138)]             0         
                                                                 
 dense (Dense)               (None, 2208)              306912    
                                                                 
 reshape (Reshape)           (None, 4, 4, 138)         0         
                                                                 
 up_sampling2d (UpSampling2D  (None, 8, 8, 138)        0         
 )                                                               
                                                                 
 conv2d (Conv2D)             (None, 8, 8, 256)         317952    
                                                                 
 batch_normalization (BatchN  (None, 8, 8, 256)        1024      
 ormalization)                                           

In [7]:
# Pass the config explicitly so the model always matches the constants cell,
# instead of silently relying on build_discriminator's defaults.
discriminator = build_discriminator(num_classes=num_classes, img_height=img_height,
                                    img_width=img_width, img_channels=img_channels)
discriminator.summary()

Model: "discriminator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 32, 32, 10)  0           []                               
                                ]                                                                 
                                                                                                  
 concatenate (Concatenate)      (None, 32, 32, 13)   0           ['input_1[0][0]',                
                                                                  'input_2[0][0]']                
                                                                                      

In [9]:
# Wrap both networks in the WGAN-GP training model.
cifar10gan = imageWGANGP(
    generator,
    discriminator,
    latent_dim=latent_dim,
    num_classes=num_classes,
    img_height=img_height,
    img_width=img_width,
    img_channels=img_channels,
)

# Each network gets its own Adam instance (optimizer state is per-network),
# both with the same hyperparameters.
adam_settings = {"learning_rate": 0.0002, "beta_1": 0.5, "beta_2": 0.9}
cifar10gan.compile(
    discriminator_optimizer=Adam(**adam_settings),
    generator_optimizer=Adam(**adam_settings),
    discriminator_loss=wasserstein_discriminator_loss,
    generator_loss=wasserstein_generator_loss,
)

In [10]:
# Training callbacks for run "cifarep2".
# GANMonitor is project-local; presumably it renders `num_img` sample images
# per epoch (one per class here). See GANMonitor.py to confirm.
cbk = GANMonitor(
    num_img=num_classes,
    latent_dim=latent_dim,
    num_classes=num_classes,
    img_height=img_height,
    img_width=img_width,
    img_channels=img_channels,
    name="cifarep2",
)

# Weights-only checkpoint after every epoch. Keras fills {epoch} with the
# 1-based epoch number, so epoch 50 is written to cifarep2_0050.ckpt.
checkpoint_path = "checkpoints_cifar/cifarep2_{epoch:04d}.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
)

In [11]:
# GAN training is slow. Set TRAIN_GAN = False to skip it and use the saved
# checkpoint loaded in the next cell instead.
TRAIN_GAN = True
epochs = 50
if TRAIN_GAN:
    # `dataset` is already batched, so batch_size must not also be passed
    # to fit().
    cifar10gan.fit(dataset, epochs=epochs, callbacks=[cbk, cp_callback])

In [12]:
# Load the weights from the final training epoch. Deriving the path from the
# checkpoint template and `epochs` keeps it in sync with the training cell.
cifar10gan.load_weights(checkpoint_path.format(epoch=epochs))

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x1273b1c2bf0>

In [13]:
# Demo grid: one row per class, `samples_per_class` generated images per row.
samples_per_class = 10
demo_noise = tf.random.normal(shape=(num_classes * samples_per_class, latent_dim))
# Row k is the one-hot vector for class k // samples_per_class.
demo_labels = np.repeat(np.eye(num_classes, dtype="float32"), samples_per_class, axis=0)
random_latent_vectors = tf.concat([demo_noise, demo_labels], axis=1)
generated_images = cifar10gan.generator(random_latent_vectors, training=False)
# Undo the [-1, 1] scaling: tanh output -> pixel values in [0, 255].
generated_images = (generated_images * 127.5) + 127.5

# (class, sample, H, W, C) -> (class, H, sample, W, C) -> one tiled image.
grid = generated_images.numpy().reshape(
    num_classes, samples_per_class, img_height, img_width, img_channels
)
grid = grid.transpose(0, 2, 1, 3, 4).reshape(
    num_classes * img_height, samples_per_class * img_width, img_channels
)
print(grid.shape)
# scale=False keeps the true pixel values; the default scale=True would
# min-max stretch the whole grid and shift its colours.
img = keras.preprocessing.image.array_to_img(np.clip(grid, 0, 255), scale=False)
img.save("cifar10_demo.png")

(320, 320, 3)


## Generating CIFAR images

In [14]:
# Class-balanced synthetic training set with as many images as the real
# CIFAR-10 training set (5,000 per class). It is ordered by class.
images_per_class = 5000
gen_batch_size = 500  # images per generator call; lower it if memory is tight

img_list = []
label_list = []
for label in range(num_classes):
    for start in range(0, images_per_class, gen_batch_size):
        n = min(gen_batch_size, images_per_class - start)
        labels = keras.utils.to_categorical(np.full(n, label), num_classes)
        random_latent_vectors = tf.concat(
            [tf.random.normal(shape=(n, latent_dim)), labels], axis=1
        )
        img_list.append(cifar10gan.generator(random_latent_vectors, training=False))
        label_list.append(labels)

generated_images = tf.concat(img_list, axis=0)
generated_labels = tf.concat(label_list, axis=0)
n_generated = num_classes * images_per_class
assert tuple(generated_images.shape) == (n_generated, img_height, img_width, img_channels)
assert tuple(generated_labels.shape) == (n_generated, num_classes)

In [15]:
# One-hot -> integer class ids (SparseCategoricalCrossentropy expects ids).
# The guard keeps the cell safe to re-run: a second argmax would collapse
# the labels to a single scalar.
if np.ndim(generated_labels) == 2 and np.shape(generated_labels)[-1] == num_classes:
    generated_labels = np.argmax(generated_labels, axis=-1)

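## Classification test

Train the same CNN twice: once on the real CIFAR-10 training images and once on the generated images. Validate both on the real, normalized CIFAR-10 test set. The closer the second classifier's validation accuracy comes to the first's, the more class-faithful the generated data is.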

In [16]:
def make_classifier(input_shape=(32, 32, 3), num_classes=10, learning_rate=0.001):
    """Build and compile a small CNN image classifier.

    Architecture: three 3x3 ReLU conv layers (32, 32, 64 filters) with 2x2 max
    pooling after the first two, then a 64-unit ReLU dense layer and a linear
    output layer with one logit per class.

    Args:
        input_shape: Shape of one input image (height, width, channels).
        num_classes: Number of output logits.
        learning_rate: Adam learning rate.

    Returns:
        A compiled ``Sequential`` model that expects integer class labels
        (sparse categorical cross-entropy on logits) and reports accuracy.
    """
    classifier = Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        # No softmax: the loss below is configured with from_logits=True.
        layers.Dense(num_classes),
    ])
    classifier.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    return classifier

In [17]:
# Two independently initialised copies of the same architecture:
# classifier1 is trained on real images, classifier2 on generated ones.
classifier1, classifier2 = make_classifier(), make_classifier()

In [18]:
# One-hot -> integer class ids (SparseCategoricalCrossentropy expects ids).
# The guard keeps the cell safe to re-run: a second argmax would collapse
# the labels to a single scalar.
if np.ndim(y_train) == 2 and np.shape(y_train)[-1] == num_classes:
    y_train = np.argmax(y_train, axis=-1)

In [19]:
# Scale the test images to [-1, 1] exactly like x_train, so both classifiers
# are validated on inputs in the range they were trained on.
x_test_norm = np.subtract(x_test.astype(np.float32), 127.5) / 127.5

In [25]:
# Baseline: train on the real CIFAR-10 images, validate on the real test set.
classifier1.fit(
    x_train,
    y_train,
    epochs=10,
    batch_size=batch_size,
    validation_data=(x_test_norm, y_test),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x12937e6e1d0>

In [26]:
# Train the same architecture on generated images only, still validating on
# the *real* test set. The generated set is ordered by class, so shuffling
# (Keras' default, made explicit here) is essential.
classifier2.fit(
    generated_images,
    generated_labels,
    epochs=10,
    batch_size=batch_size,
    shuffle=True,
    validation_data=(x_test_norm, y_test),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x12938937910>