# Imports and TF setup

In [1]:
import itertools
import numpy as np
import tensorflow as tf
import keras.backend as K
import keras 
import pandas as pd

import sys
sys.path.insert(0,'..')

from scipy import signal, fftpack
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
from keras.layers import Conv1D, Conv2D, MaxPooling2D, GlobalMaxPooling2D, GlobalAveragePooling2D
from keras.layers import UpSampling2D, LeakyReLU, Lambda, Add, Multiply, Activation, Conv2DTranspose
from keras.layers import Cropping2D, ZeroPadding2D, Flatten, Subtract
from keras.backend.tensorflow_backend import set_session
from keras.utils import plot_model
from keras.optimizers import Adam
from functools import partial
from New_Layers import *
from keras.layers.merge import _Merge
from multiprocessing import Pool
import copy 
import pandas as pd
from os import listdir
from os.path import isfile, join
from os import listdir
from os.path import isfile, join
import numpy as np
import astropy
from tqdm import tqdm_notebook as tqdm 
from astropy.io.fits.card import UNDEFINED
from astropy.io import fits

from sklearn.model_selection import train_test_split

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.95
config.gpu_options.visible_device_list = "0"
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))
%matplotlib inline


Using TensorFlow backend.


#  
# GAN Interface

In [2]:
class GAN(object):
    def __init__(self, generator, discriminator, training_scheme, 
                 generator_kwargs={}, discriminator_kwargs={}, 
                 generator_training_kwargs={}, discriminator_training_kwargs={}):
        
        assert training_scheme is not None , "No training scheme selected!"
        assert isinstance(generator, keras.models.Model), "Generator is not a model!"
        assert isinstance(discriminator, keras.models.Model), "Discriminator is not a model!"
        
        assert type(generator_kwargs) is dict , "generator kwargs are not a dictionary!"
        assert type(discriminator_kwargs) is dict , "discriminator kwargs are not a dictionary!"
        assert type(generator_training_kwargs) is dict , "discriminator training kwargs are not a dictionary!"
        assert type(discriminator_training_kwargs) is dict , "generator training kwargs are not a dictionary!"
        
        
        self._training_scheme = training_scheme

        self.__generator_model = generator
        self.__discriminator_model = discriminator

        self.__discriminator = self._training_scheme.compile_discriminator(self.__generator_model, 
                                                                           self.__discriminator_model, 
                                                                           **discriminator_kwargs)
        self.__generator = self._training_scheme.compile_generator(self.__generator_model, 
                                                                   self.__discriminator_model,
                                                                   **generator_kwargs)
        self.generator_training_kwargs = generator_training_kwargs
        self.discriminator_training_kwargs = discriminator_training_kwargs
    
    def generator_model(self): return self.__generator_model
    def discriminator_model(self): return self.__discriminator_model
    def generator(self): return self.__generator
    def discriminator(self): return self.__discriminator
    def summaries(self):
        print "\n\n\nGenerator Summary: \n"
        self.__generator_model.summary()
        plot_model(self.__generator_model, show_shapes=True, to_file='GAN_Generator_Model.png')
        
        print "\n\n\nDiscriminator Summary: \n"
        self.__discriminator_model.summary()
        plot_model(self.__discriminator_model, show_shapes=True, to_file='GAN_Discriminator_Model.png')
        
        print "\n\n\nGenerator Training Model Summary: \n"
        self.__generator.summary()
        plot_model(self.__generator, show_shapes=True, to_file='GAN_Generator_Training_Model.png')
       
        print "\n\n\nDiscriminator Training Model Summary: \n"
        self.__discriminator.summary()
        plot_model(self.__discriminator, show_shapes=True, to_file='GAN_Discriminator_Training_Model.png')
   
    def fit(self, x, y, verbose=False, shuffle=False, steps=None, epochs=None, steps_per_epoch=None, batch_size=None, 
            generator_training_multiplier=1, discriminator_training_multiplier=1, 
            generator_callbacks=[],discriminator_callbacks=[], **kwargs):
        self.verbose = verbose
        self.callbacks_generator, self.callbacks_discriminator= [], []
        self.History = keras.callbacks.History()
        self.shuffle = shuffle
        assert (steps is None and epochs is not None and steps_per_epoch is not None) or \
               (steps is not None and epochs is None and steps_per_epoch is None), "please supply either steps OR epochs and steps per epoch"
        
        assert batch_size is not None, "batch size is None, please provide batch size"
        try:
            iterator = iter(generator_callbacks)            
        except TypeError:
            assert False, "generator callbacks are not iterable!"
        
        try:
            iterator = iter(discriminator_callbacks)            
        except TypeError:
            assert False, "discriminator callbacks are not iterable!"
        
        for c in generator_callbacks:
            c.set_model(self.__generator)
            self.callbacks_generator.append(c)
        
        for c in discriminator_callbacks:
            c.set_model(self.__discriminator)
            self.callbacks_discriminator.append(c)
        
        
        for callback in self.callbacks_generator + self.callbacks_discriminator + [self.History]:
            callback.on_train_begin()
        
        
        if steps is not None:
            for i in tqdm(xrange(steps)):
                    temp_loss_discriminator, temp_loss_generator = self.__fit_on_batch(x=x, y=y, step=i, 
                                                                                       shuffle=self.shuffle,
                                                                                       batch_size=batch_size, 
                                                                                       generator_training_multiplier=generator_training_multiplier,
                                                                                       discriminator_training_multiplier=discriminator_training_multiplier)
                    loss_discriminator = {('discriminator_'+self.__discriminator.metrics_names[i]):item for i,item in enumerate(temp_loss_discriminator)}
                    loss_generator = {('generator_'+self.__generator.metrics_names[i]):item for i,item in enumerate(temp_loss_generator)}
                    losses = loss_discriminator.copy().update(loss_generator)
                    self.History.on_epoch_end(i,losses) # populate history
        elif steps_per_epoch is not None and epochs is not None:
            for k in xrange(epochs):
                for callback in self.callbacks_generator + self.callbacks_discriminator:
                    callback.on_epoch_begin(k)
                loss_discriminator, loss_generator = None, None
                for i in xrange(steps_per_epoch):
                    temp_loss_discriminator, temp_loss_generator = self.__fit_on_batch(x=x, y=y, step=i, epoch=k, 
                                                                                       shuffle=self.shuffle,
                                                                                       steps_per_epoch=steps_per_epoch, 
                                                                                       batch_size=batch_size,
                                                                                       generator_training_multiplier=generator_training_multiplier,
                                                                                       discriminator_training_multiplier=discriminator_training_multiplier)

                    if loss_discriminator is None: loss_discriminator = temp_loss_discriminator
                    elif hasattr(loss_discriminator, '__iter__'): loss_discriminator = [x+y for x,y in zip(loss_discriminator, temp_loss_discriminator)]
                    else: loss_discriminator += temp_loss_discriminator
                        
                    if loss_generator is None: loss_generator = temp_loss_generator
                    elif hasattr(loss_generator, '__iter__'): loss_generator = [x+y for x,y in zip(loss_generator, temp_loss_generator)]
                    else: loss_generator += temp_loss_generator
                
                loss_discriminator = [item * 1.0/steps_per_epoch for item in loss_discriminator]
                loss_generator = [item * 1.0/steps_per_epoch for item in loss_generator]
                                    
                for callback in self.callbacks_generator:
                    callback.on_epoch_end(k, logs={self.__generator.metrics_names[i]:item for i,item in enumerate(loss_generator)})

                for callback in self.callbacks_discriminator:
                    callback.on_epoch_end(k, logs={self.__discriminator.metrics_names[i]:item for i,item in enumerate(loss_discriminator)})

                loss_discriminator = {('discriminator_'+self.__discriminator.metrics_names[i]):item for i,item in enumerate(loss_discriminator)}
                loss_generator = {('generator_'+self.__generator.metrics_names[i]):item for i,item in enumerate(loss_generator)}
                losses = loss_discriminator.copy().update(loss_generator)
                self.History.on_epoch_end(k,losses) # populate history
                    
        for callback in self.callbacks_generator + self.callbacks_discriminator:
            callback.on_train_end()

        return self.History


    def __fit_on_batch(self, x, y, step, epoch=None, steps_per_epoch=None, generator_training_multiplier=1, 
                       discriminator_training_multiplier=1, batch_size=None, shuffle=False, **kwargs):
        for callback in self.callbacks_generator + self.callbacks_discriminator:
                if epoch is not None:
                    callback.on_batch_begin(step, logs={'batch':step, 'size':batch_size})
                else:
                    callback.on_epoch_begin(step)
                    
        steps = step if epoch is None else step + (epoch * steps_per_epoch)
        curr_x, curr_y = x, y

        # Discriminator Training
        loss = None
        for j in range(discriminator_training_multiplier): 
            curr = (steps*discriminator_training_multiplier + j) * batch_size
            nex = curr + batch_size
            idxs = [i % curr_x.shape[0] for i in range(curr,nex)]
            if shuffle: idxs = np.random.randint(0, curr_x.shape[0], size=batch_size)
            train_x, train_y = curr_x[idxs], curr_y[idxs]
            temp_loss = self._training_scheme.train_discriminator(self.__discriminator, train_x, train_y, batch_size, **self.discriminator_training_kwargs)
            if loss is None: loss = temp_loss
            elif hasattr(loss, '__iter__'): loss = [x+y for x,y in zip(loss, temp_loss)]
            else: loss += temp_loss
             
        loss_discriminator = [item * 1.0/discriminator_training_multiplier for item in loss]
        if self.verbose: print 'Discriminator Loss:', loss_discriminator
        
        # Generator Training
        loss = None
        for j in range(generator_training_multiplier): 
            curr = (steps*generator_training_multiplier + j) * batch_size
            nex = curr + batch_size
            idxs = [i % curr_x.shape[0] for i in range(curr,nex)]
            if shuffle: idxs = np.random.randint(0, curr_x.shape[0], size=batch_size)
            train_x, train_y = curr_x[idxs], curr_y[idxs]
            temp_loss = self._training_scheme.train_generator(self.__generator, train_x, train_y, batch_size, **self.generator_training_kwargs)
            if loss is None: loss = temp_loss
            elif hasattr(loss, '__iter__'): loss = [x+y for x,y in zip(loss, temp_loss)]
            else: loss += temp_loss
            
        loss_generator = [item * 1.0/generator_training_multiplier for item in loss]
        if self.verbose: print 'Generator Loss:', loss_generator

        for callback in self.callbacks_generator: # callbacks for generator
            if epoch is not None:
                callback.on_batch_end(step, logs={self.__generator.metrics_names[i]:item for i,item in enumerate(loss_generator)})
            else:
                callback.on_epoch_end(step, logs={self.__generator.metrics_names[i]:item for i,item in enumerate(loss_generator)})
            
        for callback in self.callbacks_discriminator: # callbacks for discriminator
            if epoch is not None:
                callback.on_batch_end(step, logs={self.__discriminator.metrics_names[i]:item for i,item in enumerate(loss_discriminator)})
            else:
                callback.on_epoch_end(step, logs={self.__discriminator.metrics_names[i]:item for i,item in enumerate(loss_discriminator)})

        return loss_discriminator, loss_generator


#  
 
# IWGAN TrainingScheme


In [3]:
class Base_TrainingScheme(object):
    """Interface a GAN training scheme must implement.

    Every hook is a static method, so a scheme is used as the class itself.
    ``GAN`` calls the two ``compile_*`` hooks once to build its training models
    and the two ``train_*`` hooks on each batch. Subclasses override all four;
    the base versions only raise ``NotImplementedError``.
    """

    @staticmethod
    def compile_discriminator(generator, discriminator, **options):
        """Return the compiled model used to train the discriminator."""
        raise NotImplementedError("Please Implement this method")

    @staticmethod
    def compile_generator(generator, discriminator, **options):
        """Return the compiled model used to train the generator."""
        raise NotImplementedError("Please Implement this method")

    @staticmethod
    def train_discriminator(discriminator, x, y, batch_size, **options):
        """Run one discriminator update on a batch and return its loss."""
        raise NotImplementedError("Please Implement this method")

    @staticmethod
    def train_generator(generator, x, y, batch_size, **options):
        """Run one generator update on a batch and return its loss."""
        raise NotImplementedError("Please Implement this method")
    
class IWGAN_TrainingScheme(Base_TrainingScheme):
    """Improved-WGAN (gradient-penalty) training scheme.

    Discriminator training model: inputs ``[real, generator_input]``; outputs
    the discriminator score on real samples, on generated samples, and on
    random interpolations of the two. The first two outputs use
    ``wasserstein_loss``; the third uses a gradient penalty with weight 10.

    Generator training model: input ``generator_input``; outputs the generated
    sample (trained against ``y`` with ``gen_loss``) and the frozen
    discriminator's score of it (trained with ``dis_loss``), weighted by
    ``gen_dis_loss_ratio`` and ``1 - gen_dis_loss_ratio`` respectively.
    """

    @staticmethod
    def wasserstein_loss(y_true, y_pred):
        """Wasserstein loss: batch mean of ``y_true * y_pred`` (labels are +1 / -1)."""
        return K.mean(y_true * y_pred)

    @staticmethod
    def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
        """Gradient penalty ``weight * (1 - ||d y_pred / d averaged_samples||_2)^2``, batch-averaged.

        ``y_true`` is unused (Keras requires the argument). The L2 norm is taken
        per sample over every non-batch axis.
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
        return K.mean(gradient_penalty)

    class RandomWeightedAverage(_Merge):
        """Merge layer returning ``w * inputs[0] + (1 - w) * inputs[1]`` with a random ``w`` per sample.

        ``batch_size`` is fixed at construction, so every batch fed through the
        layer must have exactly that many samples.
        """

        def __init__(self, batch_size, **kwargs):
            super(IWGAN_TrainingScheme.RandomWeightedAverage, self).__init__(**kwargs)
            self.batch_size = batch_size

        def _merge_function(self, inputs):
            # One uniform weight per sample, shaped (batch, 1, ..., 1) so it
            # broadcasts over inputs of any rank (4-D images included).
            weight_shape = (self.batch_size,) + (1,) * (K.ndim(inputs[0]) - 1)
            weights = K.random_uniform(weight_shape)
            return (weights * inputs[0]) + ((1 - weights) * inputs[1])

    @staticmethod
    def compile_discriminator(generator, discriminator, optimizer, batch_size, **kwargs):
        """Build and compile the discriminator training model.

        The generator is marked non-trainable while compiling and restored
        afterwards, so this model is meant to update only discriminator weights.

        Args:
            generator, discriminator: raw Keras models.
            optimizer: Keras optimizer for this model.
            batch_size: fixed batch size required by ``RandomWeightedAverage``.
            **kwargs: ignored.
        Returns:
            The compiled ``Model([real, generator_input], [d_real, d_fake, d_interp])``.
        """
        gp_weight = 10
        for layer in generator.layers: layer.trainable = False
        generator.trainable = False
        inp = Input(tuple(generator.layers[0].input_shape[1:]))
        in_gen = generator(inp)
        in_real = Input(tuple(discriminator.layers[0].input_shape[1:]))
        discriminator_output_from_generator = discriminator(in_gen)
        discriminator_output_from_real_samples = discriminator(in_real)
        averaged_samples = IWGAN_TrainingScheme.RandomWeightedAverage(batch_size)([in_real, in_gen])
        averaged_samples_out = discriminator(averaged_samples)
        partial_gp_loss = partial(IWGAN_TrainingScheme.gradient_penalty_loss, averaged_samples=averaged_samples,
                                  gradient_penalty_weight=gp_weight)
        # Keras uses the loss function's __name__ for metric names; partials have none.
        partial_gp_loss.__name__ = 'gradient_penalty'
        # NOTE(review): sets an attribute on a tensor, not a layer; appears to have no effect.
        in_gen.trainable = False

        # ----- Discriminator -----
        discriminator_model = Model(inputs=[in_real, inp], outputs=[discriminator_output_from_real_samples,
                                                                    discriminator_output_from_generator,
                                                                    averaged_samples_out])
        # NOTE(review): presumably layers[1] is the generator; depends on Keras layer ordering — verify.
        discriminator_model.layers[1].trainable = False
        discriminator_model.compile(optimizer=optimizer, loss=[IWGAN_TrainingScheme.wasserstein_loss,
                                                               IWGAN_TrainingScheme.wasserstein_loss, partial_gp_loss])
        # ----- Discriminator -----

        for layer in generator.layers: layer.trainable = True
        generator.trainable = True
        return discriminator_model

    @staticmethod
    def compile_generator(generator, discriminator, gen_loss, dis_loss, optimizer,
                          gen_metrics, dis_metrics, gen_dis_loss_ratio, **kwargs):
        """Build and compile the generator training model.

        The discriminator is marked non-trainable while compiling and restored
        afterwards. Loss/metric dicts are keyed by output-layer name, so this
        assumes the sub-models are named 'generator' and 'discriminator'.

        Args:
            gen_loss / gen_metrics: loss and metrics for the generated sample vs ``y``.
            dis_loss / dis_metrics: loss and metrics for the discriminator's score.
            optimizer: Keras optimizer for this model.
            gen_dis_loss_ratio: weight of the generator loss; the discriminator
                loss gets ``1 - gen_dis_loss_ratio``.
            **kwargs: ignored.
        Returns:
            The compiled ``Model(generator_input, [generated, d_generated])``.
        """
        for layer in discriminator.layers: layer.trainable = False
        discriminator.trainable = False
        inp = Input(tuple(generator.layers[0].input_shape[1:]))
        gen_out = generator(inp)

        # ----- Generator -----
        # Score the same generated tensor that feeds the reconstruction loss;
        # calling the generator a second time would duplicate its forward pass
        # (and any stochastic/BatchNorm behaviour) inside the training graph.
        model = Model(inp, [gen_out, discriminator(gen_out)])
        model.layers[2].trainable = False
        model.compile(loss={'discriminator': dis_loss, 'generator': gen_loss},
                      optimizer=optimizer, metrics={'discriminator': dis_metrics, 'generator': gen_metrics},
                      loss_weights={'discriminator': 1 - gen_dis_loss_ratio, 'generator': gen_dis_loss_ratio})
        # ----- Generator -----

        for layer in discriminator.layers: layer.trainable = True
        discriminator.trainable = True
        return model

    @staticmethod
    def train_discriminator(discriminator, x, y, batch_size, **kwargs):
        """One discriminator update: ``y`` is fed as the real input, ``x`` as the generator input.

        Targets: +1 for real scores, -1 for generated scores, and 0 (a dummy
        that the gradient-penalty loss ignores) for the interpolation scores.
        Returns the Keras ``train_on_batch`` losses.
        """
        loss = discriminator.train_on_batch([y, x],  # inp_real, x
                                            [np.ones((batch_size, 1), dtype=np.float32),  # discriminator_output_from_real_samples
                                             -np.ones((batch_size, 1), dtype=np.float32),  # discriminator_output_from_generator
                                             np.zeros((batch_size, 1), dtype=np.float32)])  # averaged_samples_out

        return loss

    @staticmethod
    def train_generator(generator, x, y, batch_size, **kwargs):
        """One generator update: fit the output to ``y`` and the discriminator score to +1 (the real label)."""
        loss = generator.train_on_batch(x, [y, np.ones((batch_size, 1), dtype=np.float32)])  # x, [y , 1]

        return loss


In [None]:
model = GAN(generator=Generator(x_train.shape[1:]), 
            discriminator=Discriminator(x_train.shape[1:]), 
            training_scheme=IWGAN_TrainingScheme,
            generator_kwargs=Generator_kwargs(), 
            discriminator_kwargs=Discriminator_kwargs(batch_size=32), 
            generator_training_kwargs={}, 
            discriminator_training_kwargs={})
print 'loaded GAN'
#model.summaries()

#classifier = compile_classifier(Classifier(), model.generator_model(), 'binary_crossentropy', Adam(), ['accuracy'], None, 'classifier', reverse_freeze=True)
#print 'loaded Classifier'
#class_train = classifier_training(classifier, 1, 32)
log_callback_train = log_results_multi(100, 'best_model_train_resnet_se.txt', network_name='FCN-SE', 
                                       x=x_train, y=y_train, save_path='best_model_train_resnet_se.h5', image_name='train')
log_callback_test = log_results_multi(100, 'best_model_test_resnet_se.txt', network_name='FCN-SE', 
                                      x=x_test, y=y_test, save_path='best_model_test_resnet_se.h5', image_name='test')
def sch(epoch):
    if epoch < 750:
        return 1e-3
    
    if epoch < 1500:
        if epoch == 750: print("changed lr to 5e-5")
        return 5e-5 
    
    if epoch == 1500: print("changed lr to 1e-5")
    return 1e-5
import missinglink
missinglink_callback_gen = missinglink.KerasCallback()
missinglink_callback_gen.set_properties(display_name='generator test')
missinglink_callback_dis = missinglink.KerasCallback()    
missinglink_callback_dis.set_properties(display_name='discriminator test')
schedule = keras.callbacks.LearningRateScheduler(sch)
model.fit(x_train, y_train, steps=20001, batch_size=32, shuffle=True,
          generator_callbacks=[log_callback_train, log_callback_test, schedule, missinglink_callback_gen],
          discriminator_callbacks=[schedule, missinglink_callback_dis])
