In [1]:
import tensorflow as tf

from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Reshape, LeakyReLU, Dropout, UpSampling2D, BatchNormalization

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy

import os
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.callbacks import Callback

In [2]:
vgg = VGG19(include_top=False, input_shape=(224, 224, 3))
vgg.trainable = False

In [3]:
def build_generator():
    model = Sequential()
    # Beginnings of a generated image
    model.add(Dense(7*7*1024, input_dim=100))
    model.add(LeakyReLU(0.2))
    model.add(Reshape((7, 7,1024)))
    
    # 1 
    model.add(UpSampling2D())
    model.add(Conv2D(1024, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(0.2))
    
    # 2
    model.add(Conv2D(1024, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(0.2))
    # model.add(Conv2D(1024, 4, padding='same'))
    # model.add(LeakyReLU(0.2))
    
    # 3
    model.add(UpSampling2D())
    model.add(Conv2D(512, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(0.2))
    # model.add(Conv2D(512, 4, padding='same'))
    # model.add(LeakyReLU(0.2))

    # 4
    model.add(Conv2D(512, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(0.2))
    
    
    # 5
    model.add(UpSampling2D())
    model.add(Conv2D(256, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    # model.add(Conv2D(256, 4, padding='same'))
    # model.add(LeakyReLU(0.2))

    #6
    model.add(Conv2D(256, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(0.2))
    # model.add(Conv2D(256, 4, padding='same'))
    # model.add(LeakyReLU(0.2))

    #7
    model.add(UpSampling2D())
    model.add(Conv2D(128, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(0.2))

    #8
    model.add(Conv2D(128, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(0.2))
    # model.add(Conv2D(128, 4, padding='same'))
    # model.add(LeakyReLU(0.2))

    #9
    model.add(UpSampling2D())
    model.add(Conv2D(64, 4, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    # model.add(Conv2D(64, 4, padding='same'))
    # model.add(LeakyReLU(0.2))
    
    # Conv layer to get to one channel
    model.add(Conv2D(1, 3, padding='same', activation='sigmoid'))
    
    return model

In [4]:
generator = build_generator()

In [5]:
generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 50176)             5067776   
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 50176)             0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 1024)        0         
                                                                 
 up_sampling2d (UpSampling2  (None, 14, 14, 1024)      0         
 D)                                                              
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 1024)      16778240  
                                                                 
 batch_normalization (Batch  (None, 14, 14, 1024)      4096      
 Normalization)                                         

In [6]:
def build_discriminator():
    model = tf.keras.Sequential([
        vgg,
        layers.Flatten(),
        layers.Dense(1024),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

In [7]:
discriminator = build_discriminator()

In [8]:
discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vgg19 (Functional)          (None, 7, 7, 512)         20024384  
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 dense_1 (Dense)             (None, 1024)              25691136  
                                                                 
 leaky_re_lu_8 (LeakyReLU)   (None, 1024)              0         
                                                                 
 dropout (Dropout)           (None, 1024)              0         
                                                                 
 dense_2 (Dense)             (None, 1)                 1025      
                                                                 
Total params: 45716545 (174.39 MB)
Trainable params: 2

In [9]:
class FashionGAN(Model): 
    def __init__(self, generator, discriminator, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generator = generator 
        self.discriminator = discriminator 
        
    def compile(self, g_opt, d_opt, g_loss, d_loss, *args, **kwargs): 
        super().compile(*args, **kwargs)
        
        self.g_opt = g_opt
        self.d_opt = d_opt
        self.g_loss = g_loss
        self.d_loss = d_loss 

    def train_step(self, batch):
        # Get the data 
        real_images = batch
        fake_images = self.generator(tf.random.normal((224, 224, 1)), training=False)
        with tf.GradientTape() as d_tape: 
            yhat_real = self.discriminator(real_images, training=True) 
            yhat_fake = self.discriminator(fake_images, training=True)
            yhat_realfake = tf.concat([yhat_real, yhat_fake], axis=0)
            y_realfake = tf.concat([tf.zeros_like(yhat_real), tf.ones_like(yhat_fake)], axis=0)
            noise_real = 0.15*tf.random.uniform(tf.shape(yhat_real))
            noise_fake = -0.15*tf.random.uniform(tf.shape(yhat_fake))
            y_realfake += tf.concat([noise_real, noise_fake], axis=0)
            
            # Calculate loss - BINARYCROSS 
            total_d_loss = self.d_loss(y_realfake, yhat_realfake)
            
        dgrad = d_tape.gradient(total_d_loss, self.discriminator.trainable_variables) 
        self.d_opt.apply_gradients(zip(dgrad, self.discriminator.trainable_variables))
        
        # Train the generator 
        with tf.GradientTape() as g_tape: 
            # Generate some new images
            gen_images = self.generator(tf.random.normal((224,224,1)), training=True)
            predicted_labels = self.discriminator(gen_images, training=False)
            total_g_loss = self.g_loss(tf.zeros_like(predicted_labels), predicted_labels) 
            
        # Apply backprop
        ggrad = g_tape.gradient(total_g_loss, self.generator.trainable_variables)
        self.g_opt.apply_gradients(zip(ggrad, self.generator.trainable_variables))
        
        return {"d_loss":total_d_loss, "g_loss":total_g_loss}

In [10]:
g_opt = Adam(learning_rate=0.0001) 
d_opt = Adam(learning_rate=0.00001) 
g_loss = BinaryCrossentropy()
d_loss = BinaryCrossentropy()



In [11]:
fashgan = FashionGAN(generator, discriminator)

In [12]:
fashgan.compile(g_opt, d_opt, g_loss, d_loss)

In [13]:
class ModelMonitor(Callback):
    def __init__(self, num_img=3, latent_dim=100):
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        random_latent_vectors = tf.random.uniform((self.num_img, self.latent_dim,1))
        generated_images = self.model.generator(random_latent_vectors)
        generated_images *= 255
        generated_images.numpy()
        for i in range(self.num_img):
            img = array_to_img(generated_images[i])
            img.save(os.path.join('images', f'generated_img_{epoch}_{i}.png'))

In [None]:
hist = fashgan.fit(data, epochs=2000, callbacks=[ModelMonitor()])