In [None]:
import tensorflow as tf
from tensorflow import keras
import glob
#import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
from IPython import display
from keras import backend #added to implement custom loss function
from keras.constraints import Constraint



In [None]:
#https://keras.io/examples/generative/dcgan_overriding_train_step/

# Side length, in pixels, that every image is resized to (square images).
img_width = 128
img_height = img_width

# Unlabeled, batched image dataset read from disk; each batch is divided by
# 255 right away so downstream code sees the normalized values.
# need to use real AND fake so discriminator is also trained
# NOTE(review): with label_mode=None every image found under train/ is yielded
# without a label, so train_step treats all of them as "real" — confirm that
# including the dataset's own fake images is intended.
train_images = tf.keras.utils.image_dataset_from_directory(
    "./datasets/real_vs_fake/real-vs-fake/train/",
    label_mode=None,
    #seed=123,
    batch_size=32,
    image_size=(img_height, img_width),
).map(lambda batch: batch / 255.0)  # normalize images


# Sanity check: show the first image of the first batch, then stop iterating.
for x in train_images:
    first_image = x.numpy()[0]
    plt.axis("off")
    plt.imshow((first_image * 255).astype("int32"))
    break
    
#for x in train_images:
#    x = tf.image.resize(
#        x,
#        (64, 64),
#        #method=ResizeMethod.BILINEAR,
#        preserve_aspect_ratio=False,
#        antialias=False,
#        name=None
#    )



In [None]:
#https://www.tensorflow.org/tutorials/keras/save_and_load

#checkpoint_path = "training_checkpoints/cp.ckpt"
#checkpoint_dir = os.path.dirname(checkpoint_path)

#cp_callback = tf.keras.callbacks.ModelCheckpoint(
#    filepath=checkpoint_path, 
#    verbose=1, 
#    save_weights_only=True,
#    save_freq='epoch') # length of one epoch


In [None]:
# Discriminator: maps a 128x128 RGB image to a single sigmoid score.
# Three stride-2 convolutions (each followed by LeakyReLU) precede a
# flatten -> dropout -> dense head.
discriminator_layers = [keras.Input(shape=(128, 128, 3))]
for conv_filters in (128, 256, 256):
    discriminator_layers.append(
        layers.Conv2D(conv_filters, kernel_size=4, strides=2, padding="same")
    )
    discriminator_layers.append(layers.LeakyReLU(alpha=0.2))
discriminator_layers.extend(
    [
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1, activation="sigmoid"),
    ]
)

discriminator = keras.Sequential(discriminator_layers, name="discriminator")
discriminator.summary()





# Size of the random noise vector the generator consumes.
latent_dim = 128

# Generator: projects a latent vector to a 16x16x256 feature map, applies
# three stride-2 transposed convolutions (each followed by LeakyReLU), and
# finishes with a 3-channel sigmoid convolution.
generator_layers = [
    keras.Input(shape=(latent_dim,)),
    layers.Dense(16 * 16 * 256),
    layers.Reshape((16, 16, 256)),
]
for upsample_step in range(3):
    generator_layers += [
        layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
    ]
generator_layers.append(
    layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid")
)

generator = keras.Sequential(generator_layers, name="generator")
generator.summary()




#generator = keras.Sequential(
#    [
#        keras.Input(shape=(latent_dim,)),
#        layers.Dense(8 * 8 * 128),
#        layers.Reshape((8, 8, 128)),
#        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
#        layers.LeakyReLU(alpha=0.2),
#        layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
#        layers.LeakyReLU(alpha=0.2),
#        layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
#        layers.LeakyReLU(alpha=0.2),
#        layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
#    ],
#    name="generator",
#)
#generator.summary()


class GAN(keras.Model):
    """Keras model that trains a generator and a discriminator adversarially.

    The model overrides ``train_step`` so that ``fit`` alternates one
    discriminator update and one generator update per batch.

    Label convention used by ``train_step``: generated images are labelled 1
    and images from the dataset are labelled 0.

    Checkpointing is not handled by this class; build a ``tf.train.Checkpoint``
    outside of it if one is needed.

    Args:
        discriminator: Model mapping a batch of images to one score per image.
        generator: Model mapping latent vectors to a batch of images.
        latent_dim: Length of the latent vectors sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the optimizers and loss, and create the loss trackers.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both
                the discriminator and the generator loss.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Return the two loss trackers created in ``compile``."""
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: One batch of images from the training dataset.

        Returns:
            Dict with the running means of the discriminator loss
            (``"d_loss"``) and the generator loss (``"g_loss"``).
        """
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images (generated first, then real)
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images:
        # ones for the generated half, zeros for the real half.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample a fresh set of random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images" (real == 0 here)
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)! Only the generator's weights receive gradients.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }


    
# Checkpoint directory and the file-name stem ("ckpt") placed inside it.
# NOTE(review): nothing visible in this notebook reads these yet; the
# save/restore code is commented out.
checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

    
#saves images
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves sample generator outputs as PNGs after each epoch.

    Args:
        num_img: Number of images to generate and save per epoch.
        latent_dim: Length of the latent vectors fed to the generator.
    """

    def __init__(self, num_img=3, latent_dim=128):
        # Initialise the base Callback state before adding our own attributes.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``num_img`` images and write them to the working directory.

        Files are named ``generated_img_<epoch>_<index>.png``.
        """
        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Multiply by 255 and keep the NumPy conversion (the original
        # discarded the result of .numpy()).
        generated_images = (generated_images * 255).numpy()
        for i in range(self.num_img):
            img = keras.preprocessing.image.array_to_img(generated_images[i])
            img.save("generated_img_%03d_%d.png" % (epoch, i))



            
            
# Number of passes over the training dataset.
epochs = 200

# Wrap the two networks in the adversarial training model.
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)

# gan.save_weights(checkpoint_path.format(epoch=0)) #restore point thing


# The generator's learning rate is twice the discriminator's.
discriminator_optimizer = keras.optimizers.Adam(learning_rate=0.0001)
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002)
adversarial_loss = keras.losses.BinaryCrossentropy()

gan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    loss_fn=adversarial_loss,
)

# restore most recent save
#print(os.listdir(checkpoint_dir))
#latest = tf.train.latest_checkpoint(checkpoint_dir)
#print(latest) # print out latest checkpoint to double check it loaded proper
#gan.load_weights(latest)


# Save 10 sample images after every epoch; no checkpoint callback is attached.
image_monitor = GANMonitor(num_img=10, latent_dim=latent_dim)
gan.fit(train_images, epochs=epochs, callbacks=[image_monitor])