In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
from tensorflow.keras import layers
import time
import matplotlib.pyplot as plt
import os


# --- Hyperparameters ---------------------------------------------------------
BATCH_SIZE = 32                # images per training step
noise_dim = 100                # length of the generator's latent input vector
EPOCHS = 2000
num_examples_to_generate = 16  # number of tiles in the fixed preview grid
CHECKPOINT_DIR = './training_checkpoints'

# --- Input pipeline ----------------------------------------------------------
DATA_DIR = '/content/drive/MyDrive/horse-or-human'
IMAGE_SIZE = (64, 64)

# Scale raw pixel values from [0, 255] down to [0, 1]; the later shift to the
# generator's tanh range [-1, 1] is applied once the images are in memory.
datagen = ImageDataGenerator(rescale=1.0 / 255.0)

dataset = datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode=None,  # unlabeled: the iterator yields image batches only
    shuffle=True,
)


def build_generator():
    """Build the DCGAN generator: latent vector -> 64x64x3 image in [-1, 1].

    The latent input is projected to an 8x8x256 feature map, then upsampled
    by three stride-2 transposed convolutions (8 -> 16 -> 32 -> 64). The final
    ``tanh`` keeps pixel values in [-1, 1], the same range the real training
    images are rescaled to.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose input has shape
        ``(batch, noise_dim)``.
    """
    model = tf.keras.Sequential()
    # An explicit Input layer replaces the deprecated ``input_shape=`` layer
    # kwarg, which makes Keras emit a UserWarning when the model is built.
    model.add(layers.Input(shape=(noise_dim,)))
    model.add(layers.Dense(8 * 8 * 256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((8, 8, 256)))

    # 8x8 -> 16x16
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # 16x16 -> 32x32
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # 32x32 -> 64x64, 3 channels, squashed into [-1, 1].
    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return model


generator = build_generator()

def build_discriminator():
    """Build the DCGAN discriminator: 64x64x3 image -> one real/fake logit.

    Two stride-2 convolutions (64 -> 32 -> 16) with LeakyReLU and dropout are
    followed by a flatten and a linear ``Dense(1)``. The output has no
    activation, so it is a raw logit intended for a loss constructed with
    ``from_logits=True``.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model whose input has shape
        ``(batch, 64, 64, 3)`` and whose output has shape ``(batch, 1)``.
    """
    model = tf.keras.Sequential()
    # An explicit Input layer replaces the deprecated ``input_shape=`` layer
    # kwarg, which makes Keras emit a UserWarning when the model is built.
    model.add(layers.Input(shape=(64, 64, 3)))

    # 64x64 -> 32x32
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # 32x32 -> 16x16
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))  # linear output: a logit, not a probability
    return model


discriminator = build_discriminator()

# The discriminator ends in a linear Dense(1), so its outputs are logits;
# ``from_logits=True`` lets the loss apply the sigmoid internally.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """Return the discriminator's BCE loss on one real and one fake batch.

    Real logits are scored against a target of 1 and fake logits against a
    target of 0; the result is the sum of the two mean losses.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake


def generator_loss(fake_output):
    """Return the generator's BCE loss: fake logits scored against target 1."""
    wanted = tf.ones_like(fake_output)
    return cross_entropy(wanted, fake_output)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)


@tf.function
def train_step(images):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        images: A batch of real images scaled to [-1, 1]. The dataset is not
            batched with ``drop_remainder``, so the last batch of an epoch may
            hold fewer than ``BATCH_SIZE`` images.

    Returns:
        A ``(gen_loss, disc_loss)`` tuple of scalar tensors for this step. The
        caller can log them eagerly, e.g. ``float(gen_loss)``. Python
        ``print`` inside a ``tf.function`` only runs while tracing and shows
        symbolic tensors, which is why the losses are returned, not printed.
    """
    # Size the noise to the real batch. A fixed BATCH_SIZE would hand the
    # discriminator a different number of fake than real samples on the
    # short final batch.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

def generate_and_save_images(model, epoch, test_input):
    """Render ``model(test_input)`` as a square image grid and save it as PNG.

    Args:
        model: Generator, called with ``training=False``. Its outputs are
            expected to lie in [-1, 1] (tanh), and are mapped to [0, 1]
            for display.
        epoch: Epoch number embedded in the output filename
            ``image_at_epoch_{epoch:04d}.png``, which is written to the
            current working directory.
        test_input: Batch of latent vectors. One grid tile is drawn per
            vector.
    """
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]
    # Use the smallest square grid that fits every sample. A fixed 4x4 grid
    # makes plt.subplot fail once there are more than 16 samples. With 16
    # samples, this still yields a 4x4 grid on a 4x4-inch figure.
    grid = max(1, int(np.ceil(np.sqrt(num_images))))
    fig = plt.figure(figsize=(grid, grid))
    try:
        for i in range(num_images):
            plt.subplot(grid, grid, i + 1)
            plt.imshow((predictions[i] * 0.5 + 0.5).numpy())
            plt.axis('off')
        plt.savefig(f'image_at_epoch_{epoch:04d}.png')
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # long runs do not accumulate open figures.
        plt.close(fig)

# Fixed latent vectors, reused every epoch so the preview images are
# comparable across training.
seed = tf.random.normal([num_examples_to_generate, noise_dim])

# Materialize one full pass of the Keras iterator in memory: len(dataset) is
# the number of batches per epoch, so this many next() calls cover each image.
images = np.concatenate([next(dataset) for _ in range(len(dataset))], axis=0)

# [0, 1] -> [-1, 1], matching the generator's tanh output range.
images = (images.astype('float32') - 0.5) / 0.5

train_dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .shuffle(60000)
    .batch(BATCH_SIZE)
)

checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")
checkpoint = tf.train.Checkpoint(generator=generator,
                                 discriminator=discriminator,
                                 generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer)


def train(dataset, epochs, start_epoch=0, max_to_keep=3):
    """Train the GAN for epochs ``start_epoch`` .. ``epochs - 1``.

    After every epoch, a preview grid for the fixed ``seed`` is saved, the
    epoch time is printed, and a checkpoint is written.

    Args:
        dataset: Iterable of real-image batches, each passed to ``train_step``.
        epochs: Total number of epochs (exclusive upper bound of the loop).
        start_epoch: Number of epochs already completed, e.g. when resuming.
        max_to_keep: How many of the newest checkpoints to keep on disk.
            ``None`` keeps every checkpoint.
    """
    # A bare ``checkpoint.save`` never deletes older checkpoints, so one save
    # per epoch would leave one full copy of both models and optimizers per
    # epoch on disk. CheckpointManager writes to the same "<dir>/ckpt-N"
    # prefix and prunes all but the newest ``max_to_keep``.
    manager = tf.train.CheckpointManager(
        checkpoint,
        directory=os.path.dirname(checkpoint_prefix),
        max_to_keep=max_to_keep,
        checkpoint_name=os.path.basename(checkpoint_prefix),
    )
    for epoch in range(start_epoch, epochs):
        start = time.time()
        for image_batch in dataset:
            train_step(image_batch)
        generate_and_save_images(generator, epoch + 1, seed)
        print(f'Time for epoch {epoch + 1} is {time.time() - start} sec')
        manager.save()
    # Final preview. This guarantees an image for ``epochs`` even when the
    # loop above did not run (e.g. resuming an already-finished run).
    generate_and_save_images(generator, epochs, seed)

# Resume from the newest checkpoint in CHECKPOINT_DIR when one exists,
# otherwise start training from scratch.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
if latest_checkpoint:
    checkpoint.restore(latest_checkpoint)
    # save_counter is incremented once per save, and training saves once per
    # epoch, so after a restore it equals the number of completed epochs.
    start_epoch = int(checkpoint.save_counter.numpy())
    print(f'Restored {latest_checkpoint}; resuming at epoch {start_epoch + 1}')
else:
    start_epoch = 0

train(train_dataset, EPOCHS, start_epoch)

Found 1027 images belonging to 2 classes.


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Tensor("binary_crossentropy/truediv:0", shape=(), dtype=float32)
Tensor("add:0", shape=(), dtype=float32)
Tensor("binary_crossentropy/truediv:0", shape=(), dtype=float32)
Tensor("add:0", shape=(), dtype=float32)
Tensor("binary_crossentropy/truediv:0", shape=(), dtype=float32)
Tensor("add:0", shape=(), dtype=float32)
Time for epoch 1 is 12.08403992652893 sec
Time for epoch 2 is 1.2366952896118164 sec
Time for epoch 3 is 1.238903284072876 sec
Time for epoch 4 is 1.2317898273468018 sec
Time for epoch 5 is 1.2240736484527588 sec
Time for epoch 6 is 1.6654422283172607 sec
Time for epoch 7 is 1.4312033653259277 sec
Time for epoch 8 is 1.2235941886901855 sec
Time for epoch 9 is 1.234706163406372 sec
Time for epoch 10 is 1.2278430461883545 sec
Time for epoch 11 is 1.5418109893798828 sec
Time for epoch 12 is 1.2356171607971191 sec
Time for epoch 13 is 1.2382020950317383 sec
Time for epoch 14 is 1.2508325576782227 sec
Time for epoch 15 is 1.3808927536010742 sec
Time for epoch 16 is 1.42547178268