In [2]:
# Tensorflow
import tensorflow as tf
# Image data and layers packages
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Reshape, LeakyReLU, Dropout, UpSampling2D, MaxPool2D, Rescaling
from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom
# Losses and Optimizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

# Base model class to subclass for training step
from tensorflow.keras.models import Model

# Load images from directory
from tensorflow.keras.utils import image_dataset_from_directory
# For callbacks and checkpoints
from tensorflow.keras.callbacks import Callback
from tensorflow.train import Checkpoint
# Extra
import numpy as np
import os

In [3]:
# Turn on memory growth for every physical GPU that TensorFlow reports.
gpus = tf.config.experimental.list_physical_devices("GPU")
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)

In [4]:
# Switch the working directory so the relative paths used in this notebook
# ("..\\resized_images", "..\\progress_images", "..//saves//...") resolve.
# NOTE(review): hardcoded absolute Windows path — os.chdir raises on any
# machine where this exact directory does not exist; consider making it
# configurable.
os.chdir("C:\\GANs\\zelda_gan\\code")
# Image directory later handed to image_dataset_from_directory.
resize_path = "..\\resized_images"

In [5]:
# Image geometry; passed to the loader as image_size = (img_height, img_width).
img_height, img_width = 40, 30
# Number of images per dataset batch.
batch_size = 32

In [6]:
def load_images(in_dir, img_height, img_width, batch_size, save_dir):
    """Build an augmenting directory iterator over the images under ``in_dir``.

    Args:
        in_dir: Directory passed to ``flow_from_directory``.
        img_height: Target image height.
        img_width: Target image width.
        batch_size: Number of images per yielded batch.
        save_dir: Forwarded as ``save_to_dir``.

    Returns:
        The iterator returned by ``ImageDataGenerator.flow_from_directory``,
        created with ``class_mode = None`` and ``shuffle = True``.
    """
    # Per-image scaler: cast to float32 and divide by 255.
    norm = lambda x: x.astype("float32")/255
    # Use the local numpy scaler so this function does not depend on a
    # notebook-level name that is only defined in a later cell.
    dg = ImageDataGenerator(preprocessing_function = norm, zoom_range = 0.25, 
                                  horizontal_flip = True, rotation_range = 0.05)
    x_train = dg.flow_from_directory(in_dir,
                                            target_size = (img_height, img_width),
                                            batch_size = batch_size,
                                            shuffle = True,
                                            save_to_dir = save_dir,
                                            classes = None,
                                            class_mode = None,
                                            subset = "training")
    return x_train

In [7]:

# Layer that multiplies pixel values by 1/255.
img_norm = Rescaling(1./255)
AUTOTUNE = tf.data.AUTOTUNE

# Unlabeled grayscale images -> rescale -> cache -> prefetch.
train_ds = (
    image_dataset_from_directory(
        resize_path,
        labels = None,
        color_mode = 'grayscale',
        image_size = (img_height, img_width),
        batch_size = batch_size,
        shuffle = True,
    )
    .map(img_norm)
    .cache()
    .prefetch(buffer_size = AUTOTUNE)
)

Found 2340 files belonging to 1 classes.


In [8]:
# Generative NN
def build_gen_nn(latent_dim = 64):
    """Build the generator network.

    The model projects a latent vector to a 20x15x128 feature map, upsamples
    it twice, pools once, and finishes with a single-channel sigmoid
    convolution. With the default factor of 2 for ``UpSampling2D`` and
    ``MaxPool2D`` the spatial size goes 20x15 -> 40x30 -> 80x60 -> 40x30.

    Args:
        latent_dim: Length of the latent input vector (defaults to 64).

    Returns:
        An uncompiled ``Sequential`` model.
    """
    # Initialize Model
    model = Sequential()

    # Takes in random values and reshapes them to 20x15x128
    model.add(Dense(20*15*128, input_dim = latent_dim))
    model.add(LeakyReLU(.2))
    model.add(Reshape((20, 15, 128)))

    # First upsampling block
    model.add(UpSampling2D())
    model.add(Conv2D(128, 5, padding = "same"))
    model.add(LeakyReLU(.2))

    # Second upsampling block
    model.add(UpSampling2D())
    model.add(Conv2D(256, 5, padding = "same"))
    model.add(LeakyReLU(.2))

    # Down sample
    model.add(MaxPool2D())
    model.add(Conv2D(256, 5, padding = "same"))
    model.add(LeakyReLU(.2))

    # Three plain convolutional blocks (kernel size 4, "same" padding)
    for filters in (256, 512, 512):
        model.add(Conv2D(filters, 4, padding = "same"))
        model.add(LeakyReLU(.2))

    # Conv layer to get to one channel; sigmoid keeps outputs in (0, 1)
    model.add(Conv2D(1, 4, padding = 'same', activation = "sigmoid"))

    return model

In [9]:
def build_disc_nn():
    """Build the discriminator network.

    Six Conv2D(kernel 5) -> LeakyReLU(.2) -> Dropout(.2) stages on a
    (40, 30, 1) input, followed by Flatten, Dropout(.4) and a single
    sigmoid unit.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = Sequential()

    # Convolutional stages; only the first one declares the input shape.
    stage_filters = (32, 64, 128, 256, 512, 512)
    for stage, filters in enumerate(stage_filters):
        if stage == 0:
            conv = Conv2D(filters, 5, input_shape = (40,30,1))
        else:
            conv = Conv2D(filters, 5)
        model.add(conv)
        model.add(LeakyReLU(.2))
        model.add(Dropout(.2))

    # Flatten then pass to a dense layer
    model.add(Flatten())
    model.add(Dropout(.4))
    model.add(Dense(1, activation = "sigmoid"))

    return model

In [10]:
class ZelGAN(Model):
    """Keras ``Model`` that trains a generator/discriminator pair via ``fit``.

    Label convention used by ``train_step``: the discriminator is trained
    towards 0 for real images and 1 for generated images, and the generator
    is trained so that the discriminator outputs 0 for its images.
    """

    def __init__(self, generator, discriminator, *args,
                 latent_dim = 64, disc_fake_batch = 32, gen_batch = 64, **kwargs):
        """Store both networks and build the augmentation pipeline.

        Args:
            generator: Model mapping latent vectors to images.
            discriminator: Model mapping images to a single score.
            *args: Forwarded to ``Model.__init__``.
            latent_dim: Length of each latent vector (keyword-only).
            disc_fake_batch: Number of generated images used in each
                discriminator update (keyword-only).
            gen_batch: Number of generated images used in each generator
                update (keyword-only).
            **kwargs: Forwarded to ``Model.__init__``.
        """
        # pass args and kwargs to base class
        super().__init__(*args, **kwargs)

        # Create attributes for two models
        self.generator = generator
        self.discriminator = discriminator

        # Sampling configuration (defaults match the previously hard-coded values)
        self.latent_dim = latent_dim
        self.disc_fake_batch = disc_fake_batch
        self.gen_batch = gen_batch

        # Random augmentation applied to real images in train_step
        self.transformation = Sequential([
            RandomFlip("horizontal_and_vertical"),
            RandomRotation(0.2),
            RandomZoom(height_factor=(-0.1, 0.1), width_factor = (-0.1, 0.1))
        ])

    def compile(self, gen_opt, disc_opt, gen_loss, disc_loss, *args, **kwargs):
        """Compile the base model and store the optimizers and loss functions.

        Args:
            gen_opt: Optimizer applied to the generator's variables.
            disc_opt: Optimizer applied to the discriminator's variables.
            gen_loss: Loss callable ``(y_true, y_pred)`` for the generator.
            disc_loss: Loss callable ``(y_true, y_pred)`` for the discriminator.
            *args: Forwarded to ``Model.compile``.
            **kwargs: Forwarded to ``Model.compile``.
        """
        # Compile with base class
        super().compile(*args, **kwargs)

        # Create attributes for losses and optimizers
        self.gen_opt = gen_opt
        self.disc_opt = disc_opt
        self.gen_loss = gen_loss
        self.disc_loss = disc_loss

    def _sample_latent(self, num_samples):
        """Draw ``num_samples`` latent vectors from a normal (mean 0, stddev 2)."""
        # Shape is (num_samples, latent_dim): no trailing singleton axis, so the
        # generator does not have to squeeze its input implicitly.
        return tf.random.normal((num_samples, self.latent_dim), 0, 2)

    def train_step(self, batch):
        """Run one discriminator update followed by one generator update.

        Args:
            batch: Batch of real images supplied by ``fit``.

        Returns:
            Dict with the scalar losses under ``"disc_loss"`` and ``"gen_loss"``.
        """
        # Get the data; explicitly request training mode so the random
        # augmentation layers are active.
        real_images = self.transformation(batch, training = True)
        gen_images = self.generator(self._sample_latent(self.disc_fake_batch), training = False)

        # Train the discriminator
        with tf.GradientTape() as disc_tape:
            # Pass the real and generated images to the discriminator
            yhat_real = self.discriminator(real_images, training = True)
            yhat_gen = self.discriminator(gen_images, training = True)
            yhat_all = tf.concat([yhat_real, yhat_gen], axis = 0)

            # Create labels for real (0) and generated (1) images
            y_all = tf.concat([tf.zeros_like(yhat_real), tf.ones_like(yhat_gen)], axis = 0)

            # Add some noise to the TRUE outputs
            noise_real = .1 * tf.random.normal(tf.shape(yhat_real), 0, 2)
            noise_gen = -.1 * tf.random.normal(tf.shape(yhat_gen), 0, 2)
            y_all += tf.concat([noise_real, noise_gen], axis = 0)
            # The noise is zero-mean normal, so it can push labels below 0 or
            # above 1; clip so the cross-entropy targets stay in [0, 1].
            y_all = tf.clip_by_value(y_all, 0.0, 1.0)

            # Calculate loss
            total_disc_loss = self.disc_loss(y_all, yhat_all)

        # Apply backprop
        disc_grad = disc_tape.gradient(total_disc_loss, self.discriminator.trainable_variables)
        self.disc_opt.apply_gradients(zip(disc_grad, self.discriminator.trainable_variables))

        with tf.GradientTape() as gen_tape:
            # Generate some new images
            gen_images = self.generator(self._sample_latent(self.gen_batch), training = True)

            # Create the predicted labels
            predicted_labels = self.discriminator(gen_images, training = False)

            # Calculate loss: the generator wants its images scored as real (0)
            total_gen_loss = self.gen_loss(tf.zeros_like(predicted_labels), predicted_labels)

        # Apply backprop
        gen_grad = gen_tape.gradient(total_gen_loss, self.generator.trainable_variables)
        self.gen_opt.apply_gradients(zip(gen_grad, self.generator.trainable_variables))

        return {"disc_loss" : total_disc_loss, "gen_loss" : total_gen_loss}

In [11]:
# Initialize Networks
generator, discriminator = build_gen_nn(), build_disc_nn()

# Set up Losses and Optimizers
# Both optimizers are Adam with beta_1 = 0.8 and gradient values clipped at 1.0;
# the discriminator's learning rate is ten times smaller than the generator's.
gen_opt = Adam(learning_rate = 5e-6, beta_1 = 0.8, clipvalue = 1.0)
disc_opt = Adam(learning_rate = 5e-7, beta_1 = 0.8, clipvalue = 1.0)
gen_loss, disc_loss = BinaryCrossentropy(), BinaryCrossentropy()


In [12]:
# Wrap both networks in the GAN model and attach optimizers and losses.
Zelda = ZelGAN(generator, discriminator)
Zelda.compile(
    gen_opt = gen_opt,
    disc_opt = disc_opt,
    gen_loss = gen_loss,
    disc_loss = disc_loss,
)

In [13]:
class ModelMonitor(Callback):
    """Callback that saves sample generator images as PNG files after each epoch."""

    def __init__(self, num_img = 1, latent_dim = 64, out_dir = "..\\progress_images"):
        """Configure the monitor.

        Args:
            num_img: Number of images to generate and save per epoch.
            latent_dim: Length of each latent vector fed to the generator.
            out_dir: Directory the PNG files are written to; created if missing.
        """
        # Let the Callback base class set up its own attributes.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.out_dir = out_dir

    def on_epoch_end(self, epoch, logs = None):
        """Generate ``num_img`` images and save them as ``generated_img_{epoch}_{i}.png``."""
        # Make sure the target directory exists before the first save.
        os.makedirs(self.out_dir, exist_ok = True)

        # Same latent distribution as training: normal with mean 0, stddev 2.
        random_latent_vectors = tf.random.normal((self.num_img, self.latent_dim), 0, 2)
        generated_images = self.model.generator(random_latent_vectors, training = False)
        # Scale by 255 and convert to a numpy array for array_to_img.
        generated_images = (generated_images * 255).numpy()
        for i in range(self.num_img):
            img = array_to_img(generated_images[i])
            img.save(os.path.join(self.out_dir, f"generated_img_{epoch}_{i}.png"))
        


In [14]:
# Load weights from the saves directory into both networks (the same files
# the save cell at the end of the notebook writes).
for network, weights_path in (
    (Zelda.generator, '..//saves//generator_norm.h5'),
    (Zelda.discriminator, '..//saves//discriminator_norm.h5'),
):
    network.load_weights(weights_path)

In [None]:
# Train for 512 epochs under the '/GPU:1' device scope; ModelMonitor saves
# sample images at the end of every epoch.
with tf.device('/GPU:1'):
    hist = Zelda.fit(
        train_ds,
        epochs = 512,
        callbacks = [ModelMonitor()],
    )

Epoch 1/512
Epoch 2/512
Epoch 3/512
Epoch 4/512
Epoch 5/512
Epoch 6/512
Epoch 7/512
Epoch 8/512
Epoch 9/512
Epoch 10/512
Epoch 11/512
Epoch 12/512
Epoch 13/512
Epoch 14/512
Epoch 15/512
Epoch 16/512
Epoch 17/512
Epoch 18/512
Epoch 19/512
Epoch 20/512
Epoch 21/512
Epoch 22/512
Epoch 23/512
Epoch 24/512
Epoch 25/512
Epoch 26/512
Epoch 27/512
Epoch 28/512
Epoch 29/512
Epoch 30/512
Epoch 31/512
Epoch 32/512
Epoch 33/512
Epoch 34/512
Epoch 35/512
Epoch 36/512
Epoch 37/512
Epoch 38/512
Epoch 39/512
Epoch 40/512
Epoch 41/512
Epoch 42/512
Epoch 43/512
Epoch 44/512
Epoch 45/512
Epoch 46/512
Epoch 47/512
Epoch 48/512
Epoch 49/512
Epoch 50/512
Epoch 51/512
Epoch 52/512
Epoch 53/512
Epoch 54/512
Epoch 55/512
Epoch 56/512
Epoch 57/512
Epoch 58/512
Epoch 59/512
Epoch 60/512
Epoch 61/512
Epoch 62/512
Epoch 63/512
Epoch 64/512
Epoch 65/512
Epoch 66/512
Epoch 67/512
Epoch 68/512
Epoch 69/512
Epoch 70/512
Epoch 71/512
Epoch 72/512
Epoch 73/512
Epoch 74/512
Epoch 75/512
Epoch 76/512
Epoch 77/512
Epoch 78

Epoch 80/512
Epoch 81/512
Epoch 82/512
Epoch 83/512
Epoch 84/512
Epoch 85/512
Epoch 86/512
Epoch 87/512
Epoch 88/512
Epoch 89/512
Epoch 90/512
Epoch 91/512
Epoch 92/512
Epoch 93/512
Epoch 94/512
Epoch 95/512
Epoch 96/512
Epoch 97/512
Epoch 98/512
Epoch 99/512
Epoch 100/512
Epoch 101/512
Epoch 102/512
Epoch 103/512
Epoch 104/512
Epoch 105/512
Epoch 106/512
Epoch 107/512
Epoch 108/512
Epoch 109/512
Epoch 110/512
Epoch 111/512
Epoch 112/512
Epoch 113/512
Epoch 114/512
Epoch 115/512
Epoch 116/512
Epoch 117/512
Epoch 118/512
Epoch 119/512
Epoch 120/512
Epoch 121/512
Epoch 122/512
Epoch 123/512
Epoch 124/512
Epoch 125/512
Epoch 126/512
Epoch 127/512
Epoch 128/512
Epoch 129/512
Epoch 130/512
Epoch 131/512
Epoch 132/512
Epoch 133/512
Epoch 134/512
Epoch 135/512
Epoch 136/512
Epoch 137/512
Epoch 138/512
Epoch 139/512
Epoch 140/512
Epoch 141/512
Epoch 142/512
Epoch 143/512
Epoch 144/512
Epoch 145/512
Epoch 146/512
Epoch 147/512
Epoch 148/512
Epoch 149/512
Epoch 150/512
Epoch 151/512
Epoch 152/51

Epoch 157/512
Epoch 158/512
Epoch 159/512
Epoch 160/512
Epoch 161/512
Epoch 162/512
Epoch 163/512
Epoch 164/512
Epoch 165/512
Epoch 166/512
Epoch 167/512
Epoch 168/512
Epoch 169/512
Epoch 170/512
Epoch 171/512
Epoch 172/512
Epoch 173/512
Epoch 174/512
Epoch 175/512
Epoch 176/512
Epoch 177/512
Epoch 178/512
Epoch 179/512
Epoch 180/512
Epoch 181/512
Epoch 182/512
Epoch 183/512
Epoch 184/512
Epoch 185/512
Epoch 186/512
Epoch 187/512
Epoch 188/512
Epoch 189/512
Epoch 190/512
Epoch 191/512
Epoch 192/512
Epoch 193/512
Epoch 194/512
Epoch 195/512
Epoch 196/512
Epoch 197/512
Epoch 198/512
Epoch 199/512
Epoch 200/512
Epoch 201/512
Epoch 202/512
Epoch 203/512
Epoch 204/512
Epoch 205/512
Epoch 206/512
Epoch 207/512
Epoch 208/512
Epoch 209/512
Epoch 210/512
Epoch 211/512
Epoch 212/512
Epoch 213/512
Epoch 214/512
Epoch 215/512
Epoch 216/512
Epoch 217/512
Epoch 218/512
Epoch 219/512
Epoch 220/512
Epoch 221/512
Epoch 222/512
Epoch 223/512
Epoch 224/512
Epoch 225/512
Epoch 226/512
Epoch 227/512
Epoch 

Epoch 234/512
Epoch 235/512
Epoch 236/512
Epoch 237/512
Epoch 238/512
Epoch 239/512
Epoch 240/512
Epoch 241/512
Epoch 242/512
Epoch 243/512
Epoch 244/512
Epoch 245/512
Epoch 246/512
Epoch 247/512
Epoch 248/512
Epoch 249/512
Epoch 250/512
Epoch 251/512
Epoch 252/512
Epoch 253/512
Epoch 254/512
Epoch 255/512
Epoch 256/512
Epoch 257/512
Epoch 258/512
Epoch 259/512
Epoch 260/512
Epoch 261/512
Epoch 262/512
Epoch 263/512
Epoch 264/512
Epoch 265/512
Epoch 266/512
Epoch 267/512
Epoch 268/512
Epoch 269/512
Epoch 270/512
Epoch 271/512
Epoch 272/512
Epoch 273/512
Epoch 274/512
Epoch 275/512
Epoch 276/512
Epoch 277/512
Epoch 278/512
Epoch 279/512
Epoch 280/512
Epoch 281/512
Epoch 282/512
Epoch 283/512
Epoch 284/512
Epoch 285/512
Epoch 286/512
Epoch 287/512
Epoch 288/512
Epoch 289/512
Epoch 290/512
Epoch 291/512
Epoch 292/512
Epoch 293/512
Epoch 294/512
Epoch 295/512
Epoch 296/512
Epoch 297/512
Epoch 298/512
Epoch 299/512
Epoch 300/512
Epoch 301/512
Epoch 302/512
Epoch 303/512
Epoch 304/512
Epoch 

Epoch 311/512
Epoch 312/512
Epoch 313/512
Epoch 314/512
Epoch 315/512
Epoch 316/512
Epoch 317/512
Epoch 318/512
Epoch 319/512
Epoch 320/512
Epoch 321/512
Epoch 322/512
Epoch 323/512
Epoch 324/512
Epoch 325/512
Epoch 326/512
Epoch 327/512
Epoch 328/512
Epoch 329/512
Epoch 330/512
Epoch 331/512
Epoch 332/512
Epoch 333/512
Epoch 334/512
Epoch 335/512
Epoch 336/512
Epoch 337/512
Epoch 338/512
Epoch 339/512
Epoch 340/512
Epoch 341/512
Epoch 342/512
Epoch 343/512
Epoch 344/512
Epoch 345/512
Epoch 346/512
Epoch 347/512
Epoch 348/512
Epoch 349/512
Epoch 350/512
Epoch 351/512
Epoch 352/512
Epoch 353/512
Epoch 354/512
Epoch 355/512
Epoch 356/512
Epoch 357/512
Epoch 358/512
Epoch 359/512
Epoch 360/512
Epoch 361/512
Epoch 362/512
Epoch 363/512
Epoch 364/512
Epoch 365/512
Epoch 366/512
Epoch 367/512
Epoch 368/512
Epoch 369/512
Epoch 370/512
Epoch 371/512
Epoch 372/512
Epoch 373/512
Epoch 374/512
Epoch 375/512
Epoch 376/512
Epoch 377/512
Epoch 378/512
Epoch 379/512
Epoch 380/512
Epoch 381/512
Epoch 

Epoch 388/512
Epoch 389/512
Epoch 390/512
Epoch 391/512
Epoch 392/512
Epoch 393/512
Epoch 394/512
Epoch 395/512
Epoch 396/512
Epoch 397/512
Epoch 398/512
Epoch 399/512
Epoch 400/512
 2/74 [..............................] - ETA: 1:19 - disc_loss: 0.7309 - gen_loss: 0.8395

In [None]:
# Persist both networks to the saves directory.
for network, save_path in (
    (Zelda.generator, '..//saves//generator_norm.h5'),
    (Zelda.discriminator, '..//saves//discriminator_norm.h5'),
):
    network.save(save_path)

In [None]:
#random_latent_vectors = tf.random.normal((50, 64, 1))
#generated_images = Zelda.generator(random_latent_vectors)
#generated_images *= 255
#generated_images.numpy()
#for i in range(50):
#    img = array_to_img(generated_images[i])
#    img.save(os.path.join("..\\progress_images", f"generated_img_{i}.png"))