In [1]:
import os, time, math, glob
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

In [2]:
# Constants definition
IMAGE_SIZE = (64, 64)
N_CHANNELS = 3
BATCH_SIZE = 4
NUM_EXAMPLES = 1
NORM_FACTOR = 255. / 2
EPOCHS = 20
LATENT_DIMS = 100

In [3]:
# Load dataset
dataset_path = "dataset"

train_ds = keras.utils.image_dataset_from_directory(
    directory=dataset_path, 
    image_size=IMAGE_SIZE,
    seed=101,
    shuffle=True,
    batch_size=BATCH_SIZE
)

train_ds

Found 2872 files belonging to 1 classes.


<BatchDataset element_spec=(TensorSpec(shape=(None, 64, 64, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>

In [5]:
# Visualize sample images
# plt.figure(figsize=(8, 8), dpi=100)
# for (x, y) in train_ds.take(1):
#     for i in range(6):
#         plt.subplot(3, 3, i + 1)
#         plt.imshow(x[i].numpy().astype('uint8'))
#         plt.title(f'label: {y[i].numpy()} \n shape: {x[i].numpy().shape}')
#         plt.axis('off')

In [6]:
# Preprocess images
def preprocess(data, label):
    data = tf.image.resize(data, IMAGE_SIZE)
    return ((data - NORM_FACTOR) / NORM_FACTOR, label)

def deprocess(data):
    return data * NORM_FACTOR + NORM_FACTOR

train_ds = train_ds.map(preprocess).prefetch(1).cache()

In [7]:
# Visualize preprocessed images
# plt.figure(figsize=(6, 6), dpi=100)

# for (x, y) in train_ds.take(1):
#     for i in range(0, 6):
#         img = x[i]
#         label = y[i]
#         plt.subplot(3, 3, i + 1)
#         plt.imshow(deprocess(img.numpy()).astype(np.uint8))
#         plt.title(f'label: {label}')
#         plt.axis('off')

In [8]:
# GAN model
class GAN():
    def __init__(self, latent_dims, input_shape, batch_size, training_size=None):
        self.latent_dims = latent_dims
        self.image_shape = input_shape
        self.batch_size = batch_size
        self.training_size = training_size
        self.checkpoint_dir = 'training-checkpoints/abstract-art'
        self.checkpoint_prefix = os.path.join(self.checkpoint_dir, "ckpt")
        
        self.cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)
        
        self.generator_optimizer = keras.optimizers.Adam(3e-4)
        self.discriminator_optimizer = keras.optimizers.Adam(3e-4)
        
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()
    
    def build_generator(self):
        model = keras.Sequential([
            keras.layers.Dense(16 * 16 * 1024, input_shape=(self.latent_dims, )),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Reshape((16, 16, 1024)),
            keras.layers.Conv2DTranspose(1024, (4, 4), strides=2, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Conv2DTranspose(512, (4, 4), strides=2, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Conv2DTranspose(512, (4, 4), strides=1, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Conv2DTranspose(256, (4, 4), strides=1, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Conv2DTranspose(3, (4, 4), strides=1, padding='same', activation='tanh')
        ])
        return model
    
    def build_discriminator(self):
        model = keras.Sequential([
            keras.layers.Conv2D(256, (4, 4), strides=1, padding='same', input_shape=self.image_shape),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.5),
            keras.layers.Conv2D(512, (4, 4), strides=1, padding='same'),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.3),
            keras.layers.Conv2D(512, (4, 4), strides=2, padding='same'),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.3),
            keras.layers.Conv2D(1024, (4, 4), strides=2, padding='same'),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.3),
            keras.layers.Flatten(),
            keras.layers.Dense(1)
        ])
        return model
    
    def generator_loss(self, fakes):
        return self.cross_entropy(tf.ones_like(fakes), fakes)
    
    def discriminator_loss(self, reals, fakes):
        real_loss = self.cross_entropy(tf.ones_like(reals), reals)
        fake_loss = self.cross_entropy(tf.zeros_like(fakes), fakes)
        return real_loss + fake_loss
    
    @tf.function
    def train_step(self, images):
        noise = tf.random.normal(shape=(self.batch_size, self.latent_dims))
        
        with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
            generated_images = self.generator(noise, training=True)
            
            real_output = self.discriminator(images, training=True)
            fake_output = self.discriminator(generated_images, training=True)
            
            generator_loss = self.generator_loss(fakes=fake_output)
            discriminator_loss = self.discriminator_loss(reals=real_output, fakes=fake_output)
            
        generator_gradients = generator_tape.gradient(generator_loss, self.generator.trainable_variables)
        discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        
        self.generator_optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables))
    
    def get_checkpoint_callback(self):
        checkpoint = tf.train.Checkpoint(
            generator_optimizer=self.generator_optimizer,
            discriminator_optimizer=self.discriminator_optimizer,
            generator=self.generator,
            discriminator=self.discriminator
        )
        return checkpoint
    
    def train(self, dataset, epochs, seed, load_from_checkpoint=False):
        start_training = time.time()
        for epoch in range(1, epochs + 1):
            start = time.time()
            for batch_idx, (image_batch, _) in enumerate(dataset):
                self.train_step(image_batch)
                self.generate_and_save(epoch, seed)
                if self.training_size is not None:
                    print(f'Batch: {batch_idx+1}/{math.ceil(self.training_size/self.batch_size)}')
                else:
                    print(f'Batch: {batch_idx+1}')
            end = time.time()
            print(f'Epoch: {epoch}/{epochs} | Time: {math.ceil(end - start)} seconds')
        end_training = time.time()  
        print(f'Training duration: {math.ceil(end_training - start_training)} seconds')
        
        self.generate_and_save(epochs, seed)
        return True
    
    def generate_and_save(self, epoch, test):
        generated_images = self.generator(test, training=False)
        plt.figure(figsize=(4, 4), dpi=200)
        for i in range(generated_images.shape[0]):
            plt.imshow(deprocess(generated_images[i].numpy()).astype(np.uint8))
            plt.axis('off')
        plt.savefig(f'generated-images/abstract-art/image_at_epoch__{epoch}.png')
        plt.close()

In [9]:
# Create a directory for generated images
if not os.path.isdir('generated-images/abstract-art'):
    os.mkdir('generated-images/abstract-art')

In [10]:
# Random starting seed
seed = tf.random.normal([NUM_EXAMPLES, LATENT_DIMS])

In [11]:
# Initialize a GAN model
gan = GAN(
    latent_dims=LATENT_DIMS, 
    input_shape=IMAGE_SIZE + (N_CHANNELS, ), 
    batch_size=BATCH_SIZE
)

In [12]:
gan.generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 262144)            26476544  
                                                                 
 batch_normalization (BatchN  (None, 262144)           1048576   
 ormalization)                                                   
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 262144)            0         
                                                                 
 reshape (Reshape)           (None, 16, 16, 1024)      0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 32, 32, 1024)     16778240  
 nspose)                                                         
                                                                 
 batch_normalization_1 (Batc  (None, 32, 32, 1024)     4

In [None]:
keras.utils.plot_model(model=gan.generator, to_file='abstract-art-gan-generator.png', show_shapes=True)

In [13]:
gan.discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 64, 64, 256)       12544     
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 64, 64, 256)       0         
                                                                 
 dropout (Dropout)           (None, 64, 64, 256)       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 64, 64, 512)       2097664   
                                                                 
 leaky_re_lu_6 (LeakyReLU)   (None, 64, 64, 512)       0         
                                                                 
 dropout_1 (Dropout)         (None, 64, 64, 512)       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 32, 32, 512)      

In [None]:
keras.utils.plot_model(model=gan.discriminator, to_file='abstract-art-gan-discriminator.png', show_shapes=True)

In [14]:
# Train gan
gan.train(dataset=train_ds, epochs=EPOCHS, seed=seed)

Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Batch: 11
Batch: 12
Batch: 13
Batch: 14
Batch: 15
Batch: 16
Batch: 17
Batch: 18
Batch: 19
Batch: 20
Batch: 21
Batch: 22
Batch: 23
Batch: 24
Batch: 25


KeyboardInterrupt: 

<Figure size 800x800 with 0 Axes>