In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from configGAN import *
cfg = flying_objects_config()
import os

import tensorflow as tf
from tensorflow import keras
from utilsGAN import *
from sklearn.metrics import confusion_matrix
# import seaborn as sns
from datetime import datetime
import imageio
from skimage import img_as_ubyte

import pprint
# import the necessary packages
from keras.models import Sequential
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Conv3D, Conv2D, Conv1D, Convolution2D, Deconvolution2D, Cropping2D, UpSampling2D
from keras.layers import Input, Conv2DTranspose, ConvLSTM2D, TimeDistributed, GlobalMaxPooling2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers import Concatenate, concatenate, Reshape
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.models import Model
from keras.callbacks import TensorBoard
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.layers import Input, merge
from keras.regularizers import l2
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Reshape, core, Dropout, LeakyReLU
import keras.backend as kb
from tensorflow.python.keras.engine import compile_utils
import io

In [3]:

def limit_gpu():
    """Enable on-demand memory growth on every GPU TensorFlow can see.

    Memory growth has to be configured before TensorFlow initialises the GPUs;
    if they are already initialised ``set_memory_growth`` raises
    ``RuntimeError``, which is printed here rather than propagated.
    """
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # Best effort: GPUs were already initialised (e.g. cell re-run).
            print(e)


# Choose the visible device BEFORE anything enumerates the GPUs. CUDA reads
# CUDA_VISIBLE_DEVICES only when it is first initialised, and
# ``list_physical_devices`` (inside ``limit_gpu``) triggers that enumeration --
# setting the variable afterwards silently has no effect in this process.
# If an earlier cell already touched the GPU, restart the kernel.
if cfg.GPU >= 0:
    print("creating network model using gpu " + str(cfg.GPU))
    os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.GPU)
elif cfg.GPU >= -1:
    print("creating network model using cpu ")  
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

limit_gpu()


creating network model using gpu 0


In [4]:
# Print per-class image counts for the training, validation and test splits.
for split_dir, split_title in (
    (cfg.training_data_dir, " Training Data Statistics "),
    (cfg.validation_data_dir, " Validation Data Statistics "),
    (cfg.testing_data_dir, " Testing Data Statistics "),
):
    show_statistics(split_dir, fineGrained=False, title=split_title)


######################################################################
##################### Training Data Statistics #####################
######################################################################
total image number 	 10817
total class number 	 3
class square 	 3488 images
class circular 	 3626 images
class triangle 	 3703 images
######################################################################

######################################################################
##################### Validation Data Statistics #####################
######################################################################
total image number 	 2241
total class number 	 3
class triangle 	 745 images
class square 	 783 images
class circular 	 713 images
######################################################################

######################################################################
##################### Testing Data Statistics #####################
##########################

In [5]:
# Single source of truth for the batch size: steps_per_epoch / validation_steps
# are computed from cfg.BATCH_SIZE, and the preprocessed batches have that size
# (30 in this config). The previous hard-coded 64 was unused and misleading.
batch_size = cfg.BATCH_SIZE
# Channels-last input shape (height, width, channels) shared by both networks.
image_shape = (cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH, cfg.IMAGE_CHANNEL)

# Prepare dataset

# Prepare plots and logger

In [6]:
from improvedUtils import *

# Model Architecture

In [7]:
class architecture:
    """Model, optimizer and loss bundle for the default first-frame -> last-frame GAN.

    Read by the training cells: ``__name__`` (run name, looked up on an
    *instance*), ``__changes__`` (free-text notes written to the log dir),
    ``__normalization__`` / ``__jitter__`` (forwarded to ``preprocess``), the
    two optimizers, ``loss()`` and ``__model__()``.
    """
    # NOTE: ``architecture.__name__`` on the class still returns 'architecture'
    # because ``type.__name__`` is a data descriptor; only instance lookup
    # (``architecture().__name__``) sees this value.
    __name__='default_model_v0'
    # NOTE(review): this text says the discriminator optimizer was changed to
    # SGD, but ``discriminator_optimizer`` below is Adam -- confirm which is intended.
    __changes__="Default model given from teacher. Changed Discriminator optimizer to SGD. Disable Jitter"
    
    __normalization__='[0,1]'  # passed to preprocess(normalize_type=...)
    __jitter__= False          # passed to preprocess(jitter=...)
    
    # Created once at class-definition time, so all users of this class share
    # these optimizer objects (and their slot state).
    generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    
    @staticmethod
    def discriminator():
        """Build the conditional patch discriminator.

        Takes two images of ``image_shape``. ``GAN.train_step`` feeds slot 0
        the conditioning (input/first) frame and slot 1 the real or generated
        target frame. They are concatenated on the channel axis and passed
        through four stride-2 3x3 convolutions. The output is an un-activated
        map of per-patch scores (8x8 for the 128x128 inputs used here), meant
        for the MSE adversarial loss in ``loss()``.
        """
        first_img = Input(shape=image_shape)  # slot 0: conditioning frame
        last_img = Input(shape=image_shape)   # slot 1: frame being judged

        # Concatenate image and conditioning image by channels to produce input.
        combined_imgs = Concatenate(axis=-1)([first_img, last_img])

        d1 = Conv2D(32, (3, 3), strides=2, padding='same')(combined_imgs)
        d1 = Activation('relu')(d1)
        d2 = Conv2D(64, (3, 3), strides=2, padding='same')(d1)
        d2 = Activation('relu')(d2)
        d3 = Conv2D(128, (3, 3), strides=2, padding='same')(d2)
        d3 = Activation('relu')(d3)

        # No activation: raw scores, compared to 1/0 targets with MSE.
        validity = Conv2D(1, (3, 3), strides=2, padding='same')(d3)

        return Model([first_img, last_img], validity)
    
    @staticmethod
    def generator():
        """Build the small U-Net-style generator.

        One 2x down-sampling stage, with a skip connection back at full
        resolution. It ends in a 1x1 convolution with a sigmoid, so outputs lie
        in [0, 1], consistent with ``__normalization__ = '[0,1]'``.
        """
        inputs = Input(shape=image_shape)

        # Encoder: full-resolution features (kept for the skip) + one 2x pool.
        down1 = Conv2D(32, (3, 3), padding='same')(inputs)
        down1 = Activation('relu')(down1)
        down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)

        down2 = Conv2D(64, (3, 3), padding='same')(down1_pool)
        down2 = Activation('relu')(down2)

        # Decoder: upsample back and concatenate the skip on the channel axis.
        up1 = UpSampling2D((2, 2))(down2)
        up1 = concatenate([down1, up1], axis=3)
        up1 = Conv2D(256, (3, 3), padding='same')(up1)
        up1 = Activation('relu')(up1)

        up2 = Conv2D(256, (3, 3), padding='same')(up1)
        up2 = Activation('relu')(up2)

        nbr_img_channels = image_shape[2]
        outputs = Conv2D(nbr_img_channels, (1, 1), activation='sigmoid')(up2)

        return Model(inputs=inputs, outputs=outputs, name='Generator')
    
    @staticmethod
    def loss():
        """Return the generator and discriminator loss functions.

        The adversarial terms use mean squared error against 1 (real) and
        0 (fake) targets, i.e. a least-squares GAN objective. The generator
        also gets an L1 reconstruction term weighted by ``LAMBDA`` = 100, as in
        pix2pix (https://arxiv.org/abs/1611.07004).

        Returns:
            dict with
              ``'g_loss_fn'``: ``generator_loss(disc_generated_output, gen_output, target)``
                  -> ``(total_gen_loss, gan_loss, l1_loss)``;
              ``'d_loss_fn'``: ``discriminator_loss(disc_real_output, disc_generated_output)``
                  -> ``real_loss + generated_loss``.
        """
        loss_object = tf.keras.losses.MeanSquaredError()
        LAMBDA = 100  # weight of the L1 term relative to the adversarial term

        def generator_loss(disc_generated_output, gen_output, target):
            # Generator wants the discriminator to score its output as real (1).
            gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

            # Mean absolute error between generated and target frame.
            l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

            total_gen_loss = gan_loss + (LAMBDA * l1_loss)

            return total_gen_loss, gan_loss, l1_loss

        def discriminator_loss(disc_real_output, disc_generated_output):
            # Real pairs -> 1, generated pairs -> 0.
            real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
            generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

            return real_loss + generated_loss

        return {'g_loss_fn':generator_loss, 'd_loss_fn':discriminator_loss}
    
    @staticmethod
    def __model__():
        """Return ``[generator_factory, discriminator_factory]``; call each to build a fresh model."""
        return [architecture.generator, architecture.discriminator]

In [8]:
# Build train/validation/test batch generators (plus sample counts) using the
# normalisation and jitter settings declared by the chosen architecture.
(
    train_batch_generator,
    valid_batch_generator,
    test_batch_generator,
    nbr_train_data,
    nbr_valid_data,
    nbr_test_data,
) = preprocess(
    image_shape,
    normalize_type=architecture.__normalization__,
    jitter=architecture.__jitter__,
)

train_x (30, 128, 128, 3) float32 0.0 1.0
train_y (30, 128, 128, 3) float32 0.0 1.0
{'BATCH_SIZE': 30,
 'DATA_AUGMENTATION': True,
 'DEBUG_MODE': True,
 'DROPOUT_PROB': 0.5,
 'GPU': 0,
 'IMAGE_CHANNEL': 3,
 'IMAGE_HEIGHT': 128,
 'IMAGE_WIDTH': 128,
 'LEARNING_RATE': 0.01,
 'LR_DECAY_FACTOR': 0.1,
 'NUM_EPOCHS': 200,
 'PRINT_EVERY': 50,
 'SAVE_EVERY': 1,
 'SEQUENCE_LENGTH': 10,
 'testing_data_dir': '../data/FlyingObjectDataset_10K/testing',
 'training_data_dir': '../data/FlyingObjectDataset_10K/training',
 'validation_data_dir': '../data/FlyingObjectDataset_10K/validation'}


# Model

In [9]:
class GAN(keras.Model):
    """Conditional GAN that predicts a sequence's last frame from its first frame.

    Batches are ``(first_frame, last_frame)`` pairs. The generator maps
    ``first_frame`` to a predicted last frame. The discriminator scores a pair
    given as ``[first_frame, candidate_last_frame]``; that input order is used
    by both ``train_step`` and ``call``, so inference sees the same pairing the
    discriminator was trained on.

    Configure with ``special_compile`` (plain ``compile`` is disabled).
    """

    def __init__(self, discriminator, generator):
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
    
    def special_compile(self, 
                d_optimizer=None, 
                g_optimizer=None,
                d_loss=None,
                g_loss=None,               
                loss_fn=None,
                metrics=None,
                loss_weights=None,
                weighted_metrics=None,
                run_eagerly=None,
                steps_per_execution=None,
              **kwargs):
        """Store the GAN optimizers and losses, and compile ``metrics``.

        Args:
            d_optimizer: optimizer applied to the discriminator's weights.
            g_optimizer: optimizer applied to the generator's weights.
            d_loss: ``f(disc_real_output, disc_generated_output) -> loss``.
            g_loss: ``f(disc_generated_output, gen_output, target)
                -> (total_loss, gan_loss, l1_loss)``.
            metrics: Keras metrics, updated with ``(target, generated frame)``.

        ``loss_fn``, ``loss_weights``, ``weighted_metrics``, ``run_eagerly``,
        ``steps_per_execution`` and ``**kwargs`` are accepted for signature
        compatibility but are currently ignored.
        """
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss = d_loss
        self.g_loss = g_loss

        # One compile is enough: Keras only manages the metrics here; losses and
        # optimizers are applied manually in train_step.
        super().compile(metrics=metrics)
    
    def compile(self, **kwargs):
        """Disabled -- use ``special_compile`` so the GAN losses/optimizers are set."""
        raise NotImplementedError("Please use special_compile()")

    @tf.function
    def train_step(self, data):
        """Run one joint generator + discriminator update on a ``(input, target)`` batch.

        ``fit`` calls this once per batch drawn from its input. Returns the
        generator loss components, the discriminator loss and the current value
        of every compiled metric.
        """
        input_image, target = data
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_output = self.generator(input_image, training=True)

            # Score real and generated pairs, both conditioned on the input frame.
            disc_real_output = self.discriminator([input_image, target], training=True)
            disc_generated_output = self.discriminator([input_image, gen_output], training=True)

            gen_total_loss, gen_gan_loss, gen_l1_loss = self.g_loss(disc_generated_output, gen_output, target)
            disc_loss = self.d_loss(disc_real_output, disc_generated_output)

        # Differentiate and step outside the recording context, so the (still
        # open) tapes do not also record the gradient and optimizer ops.
        generator_gradients = gen_tape.gradient(gen_total_loss,
                                                self.generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(disc_loss,
                                                     self.discriminator.trainable_variables)

        self.g_optimizer.apply_gradients(zip(generator_gradients,
                                             self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(discriminator_gradients,
                                             self.discriminator.trainable_variables))

        self.compiled_metrics.update_state(target, gen_output)

        logs = {
            'gen_total_loss': gen_total_loss,
            'gen_gan_loss': gen_gan_loss,
            'gen_l1_loss': gen_l1_loss,
            'disc_loss': disc_loss,
        }
        logs.update({m.name: m.result() for m in self.metrics})
        return logs

    def test_step(self, data):
        """Evaluate the generator on a ``(first_frame, last_frame)`` batch.

        Metrics compare the generated frame against the true last frame,
        matching ``train_step`` (previously they were compared with the input frame).
        """
        first_frame, last_frame = data
        _, fake_last_frame = self(first_frame, training=False)

        self.compiled_metrics.update_state(last_frame, fake_last_frame)

        return {m.name: m.result() for m in self.metrics}
    
    def call(self, first_frame, training=False):
        """Generate the last frame and score it with the discriminator.

        Returns:
            ``[validate_frame, fake_last_frame]`` -- the discriminator's patch
            scores and the generated frame.
        """
        fake_last_frame = self.generator(first_frame, training=training)
        # Same input order as train_step: [conditioning frame, candidate frame].
        validate_frame = self.discriminator([first_frame, fake_last_frame], training=training)

        return [validate_frame, fake_last_frame]

In [None]:
steps_per_epoch = nbr_train_data // cfg.BATCH_SIZE
validation_steps = nbr_valid_data // cfg.BATCH_SIZE

# Instance lookup is deliberate: ``architecture.__name__`` on the class returns
# 'architecture', while an instance finds the class-body value 'default_model_v0'.
log_dir = logger(architecture().__name__)
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, update_freq='epoch')

# ``__model__`` returns model *factories*. Build each network exactly once, so
# the exported JSON describes the very instances that get trained (previously
# each factory was called twice, producing separate, differently-named models).
generator, discriminator = architecture.__model__()
discriminator_model = discriminator()
generator_model = generator()
loss = architecture.loss()
model_to_json(discriminator_model, log_dir + "/discriminator.json")
model_to_json(generator_model, log_dir + "/generator.json")

gan = GAN(discriminator=discriminator_model, generator=generator_model)

gan.special_compile(
    d_optimizer=architecture.discriminator_optimizer,
    g_optimizer=architecture.generator_optimizer,
    d_loss=loss['d_loss_fn'],
    g_loss=loss['g_loss_fn'],
    metrics=['accuracy', SSIM_loss]
)
# batch_size is not passed: the generators already yield batches, and Keras
# documents that it must not be specified for generator/Sequence/dataset input.
history = gan.fit(
    x=train_batch_generator,
    epochs=cfg.NUM_EPOCHS,
    verbose=1,
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_batch_generator,
    validation_steps=validation_steps,
    callbacks=[GANMonitor(num_img=3, validation_data=valid_batch_generator, log_dir=log_dir), tensorboard_callback],
)

# Marker file: only written when fit() returns without raising.
with open(log_dir+"/gan_finished", 'a') as f:
    f.write(architecture.__changes__)

Epoch 1/200
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/200
 53/360 [===>..........................] - ETA: 1:01 - gen_total_loss: 2.1899 - gen_gan_loss: 0.2817 - gen_l1_loss: 0.0191 - disc_loss: 0.4725 - accuracy: 0.3927 - SSIM_loss: 0.0458