# Imports and TF setup

In [None]:
import itertools
import numpy as np
import tensorflow as tf
import keras.backend as K
import keras 
import pandas as pd

import sys
sys.path.insert(0,'..')

from scipy import signal, fftpack
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
from keras.layers import Conv1D, Conv2D, MaxPooling2D, GlobalMaxPooling2D, GlobalAveragePooling2D
from keras.layers import UpSampling2D, LeakyReLU, Lambda, Add, Multiply, Activation, Conv2DTranspose
from keras.layers import Cropping2D, ZeroPadding2D, Flatten, Subtract
from keras.backend.tensorflow_backend import set_session
from keras.utils import plot_model
from keras.optimizers import Adam
from functools import partial
from New_Layers import *
from keras.layers.merge import _Merge
from multiprocessing import Pool
import copy 
import pandas as pd
from os import listdir
from os.path import isfile, join
from os import listdir
from os.path import isfile, join
import numpy as np
import astropy
from tqdm import tqdm_notebook as tqdm 
from astropy.io.fits.card import UNDEFINED
from astropy.io import fits

from sklearn.model_selection import train_test_split

# TF1-style session configuration, installed as the session Keras uses.
config = tf.ConfigProto()
# Cap this process at 95% of the GPU's memory.
config.gpu_options.per_process_gpu_memory_fraction = 0.95
# Make only GPU "0" visible to this session.
config.gpu_options.visible_device_list = "0"
# Grow GPU memory usage on demand instead of reserving it all at start-up.
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline


#  
# GAN Interface

In [None]:
class GAN(object):
    """Pair a generator with a discriminator and train them alternately.

    All compile/train logic is delegated to ``training_scheme``, which must
    provide ``compile_discriminator``, ``compile_generator``,
    ``train_discriminator`` and ``train_generator`` (see
    ``Base_TrainingScheme``). The raw models passed in are kept as-is; the
    compiled training models returned by the scheme are what ``fit`` trains.

    Args:
        generator, discriminator: ``keras.models.Model`` instances.
        training_scheme: object (or class) implementing the scheme hooks.
        generator_kwargs / discriminator_kwargs: extra keyword arguments for
            ``compile_generator`` / ``compile_discriminator``.
        generator_training_kwargs / discriminator_training_kwargs: extra
            keyword arguments for every ``train_generator`` /
            ``train_discriminator`` call.
    """

    def __init__(self, generator, discriminator, training_scheme,
                 generator_kwargs=None, discriminator_kwargs=None,
                 generator_training_kwargs=None, discriminator_training_kwargs=None):
        # ``None`` defaults (instead of ``{}``) avoid sharing one dict object
        # between every GAN instance.
        generator_kwargs = {} if generator_kwargs is None else generator_kwargs
        discriminator_kwargs = {} if discriminator_kwargs is None else discriminator_kwargs
        generator_training_kwargs = {} if generator_training_kwargs is None else generator_training_kwargs
        discriminator_training_kwargs = {} if discriminator_training_kwargs is None else discriminator_training_kwargs

        assert training_scheme is not None, "No training scheme selected!"
        assert isinstance(generator, keras.models.Model), "Generator is not a model!"
        assert isinstance(discriminator, keras.models.Model), "Discriminator is not a model!"

        assert isinstance(generator_kwargs, dict), "generator kwargs are not a dictionary!"
        assert isinstance(discriminator_kwargs, dict), "discriminator kwargs are not a dictionary!"
        assert isinstance(generator_training_kwargs, dict), "generator training kwargs are not a dictionary!"
        assert isinstance(discriminator_training_kwargs, dict), "discriminator training kwargs are not a dictionary!"

        self._training_scheme = training_scheme

        self.__generator_model = generator
        self.__discriminator_model = discriminator

        # The discriminator training model is compiled first, then the
        # generator training model (same order as before).
        self.__discriminator = self._training_scheme.compile_discriminator(self.__generator_model,
                                                                           self.__discriminator_model,
                                                                           **discriminator_kwargs)
        self.__generator = self._training_scheme.compile_generator(self.__generator_model,
                                                                   self.__discriminator_model,
                                                                   **generator_kwargs)
        self.generator_training_kwargs = generator_training_kwargs
        self.discriminator_training_kwargs = discriminator_training_kwargs

    def generator_model(self): return self.__generator_model
    def discriminator_model(self): return self.__discriminator_model
    def generator(self): return self.__generator
    def discriminator(self): return self.__discriminator

    def summaries(self):
        """Print Keras summaries of all four models and save their diagrams as PNG files."""
        print("\n\n\nGenerator Summary: \n")
        self.__generator_model.summary()
        plot_model(self.__generator_model, show_shapes=True, to_file='GAN_Generator_Model.png')

        print("\n\n\nDiscriminator Summary: \n")
        self.__discriminator_model.summary()
        plot_model(self.__discriminator_model, show_shapes=True, to_file='GAN_Discriminator_Model.png')

        print("\n\n\nGenerator Training Model Summary: \n")
        self.__generator.summary()
        plot_model(self.__generator, show_shapes=True, to_file='GAN_Generator_Training_Model.png')

        print("\n\n\nDiscriminator Training Model Summary: \n")
        self.__discriminator.summary()
        plot_model(self.__discriminator, show_shapes=True, to_file='GAN_Discriminator_Training_Model.png')

    def fit(self, x, y, verbose=False, shuffle=False, steps=None, epochs=None, steps_per_epoch=None, batch_size=None,
            generator_training_multiplier=1, discriminator_training_multiplier=1,
            generator_callbacks=None, discriminator_callbacks=None, **kwargs):
        """Train both networks and return a ``keras.callbacks.History``.

        Exactly one schedule must be given: ``steps`` (each step is reported
        to callbacks/history as an epoch), or ``epochs`` together with
        ``steps_per_epoch`` (losses are averaged over each epoch).

        Args:
            x, y: training arrays indexable with integer index arrays.
            verbose: print the per-step losses.
            shuffle: draw each batch uniformly at random (with replacement)
                instead of walking through the data sequentially.
            batch_size: required batch size.
            generator_training_multiplier / discriminator_training_multiplier:
                number of batches each network is trained on per step.
            generator_callbacks / discriminator_callbacks: iterables of Keras
                callbacks bound to the compiled generator / discriminator
                training models.

        Returns:
            The ``History`` callback; its ``history`` dict holds
            ``generator_<metric>`` and ``discriminator_<metric>`` series.
        """
        generator_callbacks = [] if generator_callbacks is None else generator_callbacks
        discriminator_callbacks = [] if discriminator_callbacks is None else discriminator_callbacks

        self.verbose = verbose
        self.callbacks_generator, self.callbacks_discriminator = [], []
        self.History = keras.callbacks.History()
        self.shuffle = shuffle
        assert (steps is None and epochs is not None and steps_per_epoch is not None) or \
               (steps is not None and epochs is None and steps_per_epoch is None), "please supply either steps OR epochs and steps per epoch"

        assert batch_size is not None, "batch size is None, please provide batch size"
        try:
            iter(generator_callbacks)
        except TypeError:
            raise AssertionError("generator callbacks are not iterable!")

        try:
            iter(discriminator_callbacks)
        except TypeError:
            raise AssertionError("discriminator callbacks are not iterable!")

        for c in generator_callbacks:
            c.set_model(self.__generator)
            self.callbacks_generator.append(c)

        for c in discriminator_callbacks:
            c.set_model(self.__discriminator)
            self.callbacks_discriminator.append(c)

        for callback in self.callbacks_generator + self.callbacks_discriminator + [self.History]:
            callback.on_train_begin()

        batch_kwargs = dict(shuffle=self.shuffle,
                            batch_size=batch_size,
                            generator_training_multiplier=generator_training_multiplier,
                            discriminator_training_multiplier=discriminator_training_multiplier)

        if steps is not None:
            for i in tqdm(range(steps)):
                loss_discriminator, loss_generator = self.__fit_on_batch(x=x, y=y, step=i, **batch_kwargs)
                self.History.on_epoch_end(i, self._history_logs(loss_discriminator, loss_generator))
        elif steps_per_epoch is not None and epochs is not None:
            for k in range(epochs):
                for callback in self.callbacks_generator + self.callbacks_discriminator:
                    callback.on_epoch_begin(k)
                loss_discriminator, loss_generator = None, None
                for i in range(steps_per_epoch):
                    step_discriminator, step_generator = self.__fit_on_batch(x=x, y=y, step=i, epoch=k,
                                                                             steps_per_epoch=steps_per_epoch,
                                                                             **batch_kwargs)
                    # Summing through a helper (not an inline ``[x+y for x,y in ...]``)
                    # matters: in Python 2 that comprehension leaks ``x``/``y``
                    # and would overwrite the training-data arguments.
                    loss_discriminator = self._accumulate(loss_discriminator, step_discriminator)
                    loss_generator = self._accumulate(loss_generator, step_generator)

                loss_discriminator = [item * 1.0 / steps_per_epoch for item in loss_discriminator]
                loss_generator = [item * 1.0 / steps_per_epoch for item in loss_generator]

                for callback in self.callbacks_generator:
                    callback.on_epoch_end(k, logs=self._named_logs(self.__generator, loss_generator))

                for callback in self.callbacks_discriminator:
                    callback.on_epoch_end(k, logs=self._named_logs(self.__discriminator, loss_discriminator))

                self.History.on_epoch_end(k, self._history_logs(loss_discriminator, loss_generator))

        for callback in self.callbacks_generator + self.callbacks_discriminator:
            callback.on_train_end()

        return self.History

    @staticmethod
    def _as_list(loss):
        """Wrap a scalar loss (single-output model without metrics) in a list."""
        return list(loss) if hasattr(loss, '__iter__') else [loss]

    @staticmethod
    def _accumulate(total, new):
        """Return the element-wise sum of two loss lists; ``total`` may be None."""
        if total is None:
            return new
        return [a + b for a, b in zip(total, new)]

    @staticmethod
    def _named_logs(model, values):
        """Map ``model.metrics_names`` onto the matching loss ``values``."""
        return dict(zip(model.metrics_names, values))

    def _history_logs(self, loss_discriminator, loss_generator):
        """Merge both loss lists into one dict with prefixed metric names.

        Built explicitly because ``dict.update`` returns ``None``: the old
        ``d.copy().update(g)`` expression handed ``None`` to ``History``.
        """
        logs = {'discriminator_' + name: value
                for name, value in zip(self.__discriminator.metrics_names, loss_discriminator)}
        logs.update(('generator_' + name, value)
                    for name, value in zip(self.__generator.metrics_names, loss_generator))
        return logs

    def _train_phase(self, train_fn, model, x, y, global_step, multiplier, batch_size, shuffle, train_kwargs):
        """Train ``model`` on ``multiplier`` batches via ``train_fn``; return the mean loss list."""
        n_samples = x.shape[0]
        total = None
        for j in range(multiplier):
            if shuffle:
                idxs = np.random.randint(0, n_samples, size=batch_size)
            else:
                # Consecutive window for this (step, j), wrapping around the data.
                start = (global_step * multiplier + j) * batch_size
                idxs = np.arange(start, start + batch_size) % n_samples
            loss = self._as_list(train_fn(model, x[idxs], y[idxs], batch_size, **train_kwargs))
            total = self._accumulate(total, loss)
        return [item * 1.0 / multiplier for item in total]

    def __fit_on_batch(self, x, y, step, epoch=None, steps_per_epoch=None, generator_training_multiplier=1,
                       discriminator_training_multiplier=1, batch_size=None, shuffle=False, **kwargs):
        """Run one training step: a discriminator phase, then a generator phase.

        When ``epoch`` is None (``steps`` mode) the step is reported to the
        callbacks as an epoch; otherwise it is reported as a batch.

        Returns:
            (discriminator_losses, generator_losses), each a list averaged over
            its training multiplier.
        """
        for callback in self.callbacks_generator + self.callbacks_discriminator:
            if epoch is not None:
                callback.on_batch_begin(step, logs={'batch': step, 'size': batch_size})
            else:
                callback.on_epoch_begin(step)

        # Global step index across epochs, used for sequential batch selection.
        global_step = step if epoch is None else step + (epoch * steps_per_epoch)

        loss_discriminator = self._train_phase(self._training_scheme.train_discriminator, self.__discriminator,
                                               x, y, global_step, discriminator_training_multiplier,
                                               batch_size, shuffle, self.discriminator_training_kwargs)
        if self.verbose: print('Discriminator Loss: %s' % (loss_discriminator,))

        loss_generator = self._train_phase(self._training_scheme.train_generator, self.__generator,
                                           x, y, global_step, generator_training_multiplier,
                                           batch_size, shuffle, self.generator_training_kwargs)
        if self.verbose: print('Generator Loss: %s' % (loss_generator,))

        for callback in self.callbacks_generator:
            logs = self._named_logs(self.__generator, loss_generator)
            if epoch is not None:
                callback.on_batch_end(step, logs=logs)
            else:
                callback.on_epoch_end(step, logs=logs)

        for callback in self.callbacks_discriminator:
            logs = self._named_logs(self.__discriminator, loss_discriminator)
            if epoch is not None:
                callback.on_batch_end(step, logs=logs)
            else:
                callback.on_epoch_end(step, logs=logs)

        return loss_discriminator, loss_generator


#  
 
# IWGAN TrainingScheme


In [None]:
class Base_TrainingScheme(object):
    """Interface that every GAN training scheme implements.

    All hooks are static. The ``compile_*`` hooks build and compile a training
    model from the raw generator and discriminator. The ``train_*`` hooks run
    one update on such a compiled model and return its loss(es). Every hook
    here raises ``NotImplementedError``; concrete schemes override all four.
    """

    @staticmethod
    def compile_discriminator(generator, discriminator, **kwargs):
        """Return the compiled model used to train the discriminator."""
        raise NotImplementedError("Please Implement this method")

    @staticmethod
    def compile_generator(generator, discriminator, **kwargs):
        """Return the compiled model used to train the generator."""
        raise NotImplementedError("Please Implement this method")

    @staticmethod
    def train_discriminator(discriminator, x, y, batch_size, **kwargs):
        """Run one discriminator update on a batch and return its loss(es)."""
        raise NotImplementedError("Please Implement this method")

    @staticmethod
    def train_generator(generator, x, y, batch_size, **kwargs):
        """Run one generator update on a batch and return its loss(es)."""
        raise NotImplementedError("Please Implement this method")
    
class IWGAN_TrainingScheme(Base_TrainingScheme):
    """Improved Wasserstein GAN (gradient penalty) training scheme.

    The discriminator training model sees real samples (target +1), generated
    samples (target -1) and random interpolations between the two (gradient
    penalty, target 0). The generator training model is trained on targets
    ``[y, 1]``: a user-supplied ``gen_loss`` on its own output and ``dis_loss``
    on the discriminator's score, mixed by ``gen_dis_loss_ratio``.
    """

    @staticmethod
    def wasserstein_loss(y_true, y_pred):
        """Mean of ``y_true * y_pred``; ``y_true`` is +1 / -1 and selects the sign."""
        return K.mean(y_true * y_pred)

    @staticmethod
    def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
        """``weight * (1 - ||d y_pred / d averaged_samples||_2)^2``, averaged over the batch.

        ``y_true`` is ignored; the norm is taken over all non-batch axes.
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
        return K.mean(gradient_penalty)

    # No @staticmethod here: a nested class is reachable as
    # ``IWGAN_TrainingScheme.RandomWeightedAverage`` without any decorator.
    class RandomWeightedAverage(_Merge):
        """Merge layer returning ``w * a + (1 - w) * b`` with one uniform ``w`` per sample.

        ``batch_size`` is fixed at construction because the random weights
        need a static shape.
        """
        def __init__(self, batch_size, **kwargs):
            super(IWGAN_TrainingScheme.RandomWeightedAverage, self).__init__(**kwargs)
            self.batch_size = batch_size

        def _merge_function(self, inputs):
            # One weight per sample, broadcast over the remaining axes. Any
            # input rank is supported, not only 4-D.
            weights_shape = (self.batch_size,) + (1,) * (K.ndim(inputs[0]) - 1)
            weights = K.random_uniform(weights_shape)
            return (weights * inputs[0]) + ((1 - weights) * inputs[1])

    @staticmethod
    def compile_discriminator(generator, discriminator, optimizer, batch_size, **kwargs):
        """Build and compile the discriminator training model.

        The model takes inputs ``[real_samples, generator_input]``. It outputs
        the discriminator scores for the real samples, the generated samples
        and random interpolations between them; the last feeds the gradient
        penalty.

        Args:
            optimizer: optimizer passed to ``compile``.
            batch_size: fixed batch size required by ``RandomWeightedAverage``.
            gradient_penalty_weight (optional keyword): penalty weight,
                default 10.
        """
        gp_weight = kwargs.get('gradient_penalty_weight', 10)
        # Freeze the generator while this model is built and compiled; the
        # flags are restored before returning.
        for layer in generator.layers: layer.trainable = False
        generator.trainable = False
        inp = Input(tuple(generator.layers[0].input_shape[1:]))
        in_gen = generator(inp)
        in_real = Input(tuple(discriminator.layers[0].input_shape[1:]))
        discriminator_output_from_generator = discriminator(in_gen)
        discriminator_output_from_real_samples = discriminator(in_real)
        averaged_samples = IWGAN_TrainingScheme.RandomWeightedAverage(batch_size)([in_real, in_gen])
        averaged_samples_out = discriminator(averaged_samples)
        partial_gp_loss = partial(IWGAN_TrainingScheme.gradient_penalty_loss, averaged_samples=averaged_samples,
                                  gradient_penalty_weight=gp_weight)
        # functools.partial objects have no __name__, and Keras needs one for loss names.
        partial_gp_loss.__name__ = 'gradient_penalty'

        # ----- Discriminator -----
        discriminator_model = Model(inputs=[in_real, inp], outputs=[discriminator_output_from_real_samples,
                                                                    discriminator_output_from_generator,
                                                                    averaged_samples_out])
        # NOTE(review): set after the model is built; which layer index 1 refers
        # to depends on Keras's layer ordering — confirm intent.
        discriminator_model.layers[1].trainable = False
        discriminator_model.compile(optimizer=optimizer, loss=[IWGAN_TrainingScheme.wasserstein_loss,
                                                               IWGAN_TrainingScheme.wasserstein_loss, partial_gp_loss])
        # ----- Discriminator -----

        for layer in generator.layers: layer.trainable = True
        generator.trainable = True
        return discriminator_model

    @staticmethod
    def compile_generator(generator, discriminator, gen_loss, dis_loss, optimizer,
                          gen_metrics, dis_metrics, gen_dis_loss_ratio, **kwargs):
        """Build and compile the generator training model.

        The model maps the generator input to ``[generator_output,
        discriminator(generator_output)]``. The losses are weighted
        ``gen_dis_loss_ratio`` and ``1 - gen_dis_loss_ratio``.
        The loss/metric dict keys 'generator' and 'discriminator' must match
        the model's output names (presumably the two sub-models' names —
        confirm).
        """
        # Freeze the discriminator while this model is compiled; restored below.
        for layer in discriminator.layers: layer.trainable = False
        discriminator.trainable = False
        inp = Input(tuple(generator.layers[0].input_shape[1:]))
        gen_out = generator(inp)

        # ----- Generator -----
        # Score the same generator output instead of applying the generator a
        # second time: this avoids a duplicate forward pass, and the
        # discriminator sees exactly the sample the generator loss measures.
        model = Model(inp, [gen_out, discriminator(gen_out)])
        model.layers[2].trainable = False
        model.compile(loss={'discriminator': dis_loss, 'generator': gen_loss},
                      optimizer=optimizer, metrics={'discriminator': dis_metrics, 'generator': gen_metrics},
                      loss_weights={'discriminator': 1 - gen_dis_loss_ratio, 'generator': gen_dis_loss_ratio})
        # ----- Generator -----

        for layer in discriminator.layers: layer.trainable = True
        discriminator.trainable = True
        return model

    @staticmethod
    def train_discriminator(discriminator, x, y, batch_size, **kwargs):
        """One discriminator update with inputs ``[y, x]`` (real samples, generator input).

        Targets are +1 for real scores, -1 for generated scores and 0 (ignored
        by the gradient penalty) for the interpolated scores.
        """
        loss = discriminator.train_on_batch([y, x],
                                            [np.ones((batch_size, 1), dtype=np.float32),
                                             -np.ones((batch_size, 1), dtype=np.float32),
                                             np.zeros((batch_size, 1), dtype=np.float32)])
        return loss

    @staticmethod
    def train_generator(generator, x, y, batch_size, **kwargs):
        """One generator update: input ``x``, targets ``[y, 1]``."""
        loss = generator.train_on_batch(x, [y, np.ones((batch_size, 1), dtype=np.float32)])
        return loss


#  
 
# Callbacks and Helper Functions

In [None]:
class classifier_training(keras.callbacks.Callback):
    """Callback that trains a separate classifier at the end of every epoch.

    Each ``on_epoch_end`` runs ``training_ratio`` ``train_on_batch`` updates
    of ``classifier_model``. Every update uses a batch drawn at random (with
    replacement) from the ``(x, has_transit)`` data returned by
    ``load_simulation_data()``.
    """
    def __init__(self, classifier_model, training_ratio, batch_size):
        # Initialise the Keras Callback base so its attributes exist.
        super(classifier_training, self).__init__()
        self.classifier_model = classifier_model
        self.train_ratio = training_ratio
        # load_simulation_data() returns (x, y, has_transit); the classifier
        # is fit on x -> has_transit.
        self.x, _, self.y = load_simulation_data()
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        n_samples = self.x.shape[0]
        for _ in range(self.train_ratio):
            idxs = np.random.randint(0, n_samples, size=self.batch_size)
            self.classifier_model.train_on_batch(self.x[idxs], self.y[idxs])

class log_results(keras.callbacks.Callback):
    """Periodically plot a sample prediction and evaluate generator/classifier.

    Every ``logging_frequency`` epochs it plots ``self.y[:1]`` against the
    generator's prediction for ``self.x[:1]``. It then calls ``get_score``,
    which evaluates on the ``../Data/*_test*`` files, appends a line to
    ``log_path`` and saves the best models seen so far.

    NOTE(review): ``self.x`` / ``self.y`` are never assigned by this class;
    they must be set on the instance before logging triggers — TODO confirm.
    ``wgan_model`` must expose ``generator()`` (as ``GAN`` does).
    """
    def __init__(self, wgan_model=None, classifier_model=None, logging_frequency=0, log_path=''):
        super(log_results, self).__init__()
        self.logging_frequency = logging_frequency
        self.log_path = log_path
        self.model = wgan_model
        self.classifier_model = classifier_model
        # Best scores seen so far; they decide when checkpoints are saved.
        self.last_best_classifier, self.last_best_generator, self.last_best_combined = 0, 0, 0

    def on_epoch_end(self, epoch, logs=None):
        if self.logging_frequency == 0 or epoch % self.logging_frequency != 0:
            return
        time = np.linspace(0, 28.625, 20610)
        plt.figure(1, figsize=(10, 10))
        plt.subplot(211)
        plt.scatter(time, self.y[:1][0, 0, :20610, 0], s=0.5)
        plt.title("Simulation Output, " + str(epoch))

        final_res = self.model.generator().predict_on_batch(self.x[:1])
        plt.subplot(212)
        plt.scatter(time, final_res[0][0, 0, :20610, 0], s=0.5)
        plt.title("Neural Net Output," + str(final_res[1][0, 0]))
        plt.tight_layout()
        plt.savefig('img_' + str(epoch) + '_test.png')
        plt.show()
        self.get_score(epoch, self.model is not None, self.classifier_model is not None)

    def _append_log(self, message):
        """Append ``message`` to ``self.log_path`` (file closed afterwards) and echo it."""
        with open(self.log_path, 'a') as log_file:
            log_file.write(message)
        print(message)

    def get_score(self, i_in, gen_test=False, class_test=False):
        """Score the generator (period detection) and/or the classifier (ROC).

        Saves checkpoints whenever a score improves and logs one line per call.
        The combined score is the classifier TPR at FPR 0.01 multiplied by
        the generator's detection rate at epsilon 0.02. That rate is 1 when
        the generator is not evaluated.
        """
        import warnings

        at_2perc = 1  # neutral factor for the combined score when the generator is not scored
        if gen_test:
            periods = np.load('../Data/total_params_sim_test_true_3.npy')[:, 1]
            transits_ref = np.load('../Data/total_transits_sim_test_true_3.npy')
            x_test = np.expand_dims(np.load('../Data/total_x_sim_test_true_3.npy'), axis=1)
            # Zero-pad the third axis by 30 + 96 samples (same padding as the training data).
            x_test = np.pad(x_test, ((0, 0), (0, 0), (0, 30 + 96), (0, 0)), 'constant', constant_values=(0, 0))
            print('X loaded')
            transits = self.model.generator().predict(x_test, verbose=1)[0][:, 0, :20610, :]
            print("Finished predicting data")
            transits_ref = transits_ref[:, :, 0]
            scaler = MinMaxScaler(feature_range=(0, 1))
            scaler.fit(transits_ref)
            transits_ref = scaler.transform(transits_ref)
            print("Finished loading data")
            periods = np.power(10, periods)  # presumably stored as log10(period)
            warnings.filterwarnings('ignore')  # note: silences warnings process-wide
            # At most 10000 samples, as before, but never index past a smaller test set.
            n_eval = min(10000, transits.shape[0], periods.shape[0], transits_ref.shape[0])
            model_preds = [[transits[i, :, 0], periods[i], 1000, transits_ref[i, :]] for i in tqdm(range(n_eval))]
            model_preds = np.asarray(imap_unordered_bar(process_transit, model_preds, 5))
            auc_p, percentages, epsilon_range = p_epsilon_chart(model_preds[:, 0], model_preds[:, 1])
            at_1perc = percentages[np.argmin(np.abs(epsilon_range - 0.01))]
            at_2perc = percentages[np.argmin(np.abs(epsilon_range - 0.02))]
            plt.plot([0, 1], [0, 1], 'k--')
            plt.plot(epsilon_range, percentages, label='Keras (area = {:.5f})'.format(auc_p))
            plt.xlabel('epsilon')
            plt.ylabel('period detection rate')
            plt.title('Period ROC curve')
            plt.legend(loc='best')
            plt.ylim(0, 1)
            plt.xlim(0, 0.1)
            plt.savefig('img_PAUC_' + str(i_in) + '_test.png')
            plt.show()

            if at_2perc > self.last_best_generator:
                self.model.generator().save('best_generator_test.h5')
                # The classifier is optional when only the generator is scored.
                if self.classifier_model is not None:
                    self.classifier_model.save('best_classifier_test_generator_based.h5')
                self.last_best_generator = at_2perc

            if not class_test:
                self._append_log('i: %d, width: 1000, PAUC: %.5f, Percentage at 0.01: %.5f, Percentage at 0.02: %.5f\n'
                                 % (i_in, auc_p, at_1perc, at_2perc))

        if class_test:
            y_test = np.load('../Data/total_params_sim_test_3.npy')[:, 1] > 0
            x_test = np.expand_dims(np.load('../Data/total_x_sim_test_3.npy'), axis=1)
            x_test = np.pad(x_test, ((0, 0), (0, 0), (0, 30 + 96), (0, 0)), 'constant', constant_values=(0, 0))
            y_pred = self.classifier_model.predict(x_test, verbose=1)[:, 0]
            fpr, tpr, _ = roc_curve(y_test, y_pred)
            roc_auc = roc_auc_score(y_test, y_pred)
            tpr_fpr_0p001 = tpr[find_nearest(fpr, 0.001)]
            tpr_fpr_0p01 = tpr[find_nearest(fpr, 0.01)]

            if tpr_fpr_0p01 > self.last_best_classifier:
                # The generator model is optional when only the classifier is scored.
                if self.model is not None:
                    self.model.generator().save('best_generator_test_classifier_based.h5')
                self.classifier_model.save('best_classifier_test.h5')
                self.last_best_classifier = tpr_fpr_0p01

            combined = tpr_fpr_0p01 * at_2perc
            if combined > self.last_best_combined:
                if self.model is not None:
                    self.model.generator().save('best_generator_combined_test.h5')
                self.classifier_model.save('best_classifier_combined_test.h5')
                self.last_best_combined = combined
            plt.plot([0, 1], [0, 1], 'k--')
            plt.plot(fpr, tpr, label='Keras (area = {:.5f})'.format(roc_auc))
            plt.xlabel('False positive rate')
            plt.ylabel('True positive rate')
            plt.title('Classifier ROC curve')
            plt.legend(loc='best')
            plt.ylim(0, 1)
            plt.xlim(0, 0.02)
            plt.savefig('img_ClassAUC_' + str(i_in) + '_test.png')
            plt.show()

            if not gen_test:
                self._append_log('i: %d, Class Percentage at 0.001: %.5f, Class Percentage at 0.01: %.5f, Class AUC: %.5f\n'
                                 % (i_in, tpr_fpr_0p001, tpr_fpr_0p01, roc_auc))

        if class_test and gen_test:
            self._append_log('i: %d, width: 1000, PAUC: %.5f, Percentage at 0.01: %.5f, Percentage at 0.02: %.5f, Class Percentage at 0.001: %.5f, Class Percentage at 0.01: %.5f, Class AUC: %.5f\n'
                             % (i_in, auc_p, at_1perc, at_2perc, tpr_fpr_0p001, tpr_fpr_0p01, roc_auc))

class RocAucMetricCallback(keras.callbacks.Callback):
    """Compute, print and plot the validation ROC AUC after every epoch.

    The value is stored in ``logs['roc_auc_val']`` so callbacks that run
    later in the same epoch can read it. It is ``-inf`` when no validation
    data is available.
    """
    def __init__(self, predict_batch_size=80, include_on_batch=False):
        super(RocAucMetricCallback, self).__init__()
        self.predict_batch_size = predict_batch_size  # batch size for validation predictions
        self.include_on_batch = include_on_batch  # NOTE(review): not used anywhere visible

    def on_train_begin(self, logs=None):
        # Register the metric name in the training params, if Keras provides that list.
        metrics = self.params.get('metrics') if self.params else None
        if metrics is not None and 'roc_auc_val' not in metrics:
            metrics.append('roc_auc_val')

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:  # no shared mutable default dict
            logs = {}
        logs['roc_auc_val'] = float('-inf')
        validation_data = getattr(self, 'validation_data', None)
        if not validation_data:
            return
        y_pred = self.model.predict(validation_data[0], batch_size=self.predict_batch_size)[:, 0]
        y_test = validation_data[1]
        fpr, tpr, _ = roc_curve(y_test, y_pred)
        logs['roc_auc_val'] = roc_auc_score(y_test, y_pred)
        print(logs['roc_auc_val'])

        plt.plot([0, 1], [0, 1], 'k--')
        plt.plot(fpr, tpr, label='Keras (area = {:.5f})'.format(logs['roc_auc_val']))
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.title('ROC curve')
        plt.legend(loc='best')
        plt.ylim(0, 1)
        plt.xlim(0, 0.02)
        plt.show()
        print('Percentage at 0.001: %.5f, Percentage at 0.01: %.5f, Class AUC: %.5f\n'
              % (tpr[find_nearest(fpr, 0.001)], tpr[find_nearest(fpr, 0.01)], logs['roc_auc_val']))

class log_results_multi(keras.callbacks.Callback):
    """Periodically plot one sample and track the Dice score of the attached model.

    Every ``logging_frequency`` epochs it plots the input ``x[myslice]``, the
    target ``y[myslice]`` and the model's first output. The figure is saved
    under ``TrainingLogs/<network_name>/``. It then scores the whole
    ``x``/``y`` set with ``dice_coef``, saves the model to ``save_path``
    whenever the score improves, and appends the score to ``log_path``. Both
    files live in the same directory.

    The model is the one Keras attaches via ``set_model``. It is indexed as
    returning at least two outputs, the first laid out like ``y``.
    """
    def __init__(self, logging_frequency=0, log_path='', x=None, y=None,
                 myslice=slice(0,1), network_name='', save_path='best_generator_test_sectors.h5', image_name=''):
        super(log_results_multi, self).__init__()
        self.LOGDIR = 'TrainingLogs'
        self.network_name = network_name
        run_dir = os.path.join(self.LOGDIR, self.network_name)
        if not os.path.exists(run_dir):
            os.makedirs(run_dir)  # also creates LOGDIR when it is missing
        self.logging_frequency = logging_frequency
        self.log_path = log_path
        self.model = None
        self.save_path = save_path
        self.x_test, self.y_test = x, y
        self.last_best_generator = 0
        self.myslice = myslice
        self.image_name = image_name

    def _run_path(self, filename):
        """Return ``filename`` inside this run's log directory."""
        return os.path.join(self.LOGDIR, self.network_name, filename)

    def on_epoch_end(self, epoch, logs=None):
        if self.logging_frequency == 0 or epoch % self.logging_frequency != 0:
            return
        time = np.linspace(0, 28.625, 20610)
        sample_x = self.x_test[self.myslice]
        sample_y = self.y_test[self.myslice]
        plt.figure(1, figsize=(15, 10))
        plt.subplot(311)
        plt.scatter(time, sample_x[0, 0, :20610, 0], s=0.5)
        plt.title("Simulation Output, " + str(epoch))

        plt.subplot(312)
        plt.scatter(time, sample_y[0, 0, :20610, 0], s=0.5)
        plt.title("Simulation Output, " + str(epoch))

        final_res = self.model.predict_on_batch(sample_x)
        plt.subplot(313)
        plt.scatter(time, final_res[0][0, 0, :20610, 0], s=0.5)
        plt.title("Neural Net Output," + str(final_res[1][0, 0]))
        plt.tight_layout()
        plt.savefig(self._run_path('img_' + str(epoch) + '_' + self.image_name + '_test.png'))
        plt.clf()
        plt.close()
        sample_dice = dice_coef(sample_y[0, 0, :20610, 0], final_res[0][0, 0, :20610, 0])
        print("epoch: %d current pic dice coeff: %f" % (epoch, sample_dice))
        self.get_score(epoch, self.model is not None)

    def get_score(self, i_in, gen_test=False):
        """Dice-score the model on the full set; checkpoint on improvement and log."""
        import warnings

        if not gen_test:
            return
        transits = self.model.predict(self.x_test, verbose=1)[0][:, 0, :20610, 0]
        transits_ref = self.y_test[:, 0, :20610, 0]
        print(transits.shape, transits_ref.shape)
        warnings.filterwarnings('ignore')  # note: silences warnings process-wide
        dice_coeff = dice_coef(transits_ref, transits)

        if dice_coeff > self.last_best_generator:
            self.model.save(self._run_path(self.save_path))
            self.last_best_generator = dice_coeff
        message = 'epoch: %d ,dice coef: %f\n' % (i_in, dice_coeff)
        with open(self._run_path(self.log_path), 'a') as log_file:
            log_file.write(message)
        print(message)

            
def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    Ties resolve to the first such index (``argmin`` semantics); the
    returned index is a NumPy integer.
    """
    distances = np.abs(np.asarray(array) - value)
    return np.argmin(distances)

def dice_coef(y_true, y_pred, threshold=0.01):
    """Dice coefficient between a target mask and a thresholded prediction.

    ``y_pred`` is binarised as ``y_pred > threshold``; the score is
    ``2 * sum(y_true * pred) / (sum(y_true) + sum(pred))``. The two sums are
    taken separately, so a boolean ``y_true`` is counted rather than OR-ed
    with the prediction. When the denominator is 0 (both masks empty) the
    masks agree trivially and 1.0 is returned instead of NaN.

    Also prints the score and its component sums (debug output).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred) > threshold
    intersection = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    score = 2.0 * intersection / denominator if denominator != 0 else 1.0
    print(score, intersection, np.sum(y_true), np.sum(y_pred))
    return score

def p_epsilon_chart(p_test, p_pred, n_points=10000):
    """Detection rate of predicted periods as a function of relative tolerance.

    For each epsilon in ``np.linspace(0, 1, n_points)``, the rate is the
    fraction of entries with ``|1 - p_pred / p_test| < epsilon``.

    Returns:
        (auc_p, percentages, epsilon_range): the mean rate over the sweep
        (area under the curve on [0, 1]), the list of rates, and the epsilon
        grid.

    Raises:
        ZeroDivisionError: if ``p_pred`` is empty.
    """
    p_test = np.asarray(p_test)
    p_pred = np.asarray(p_pred)
    n_total = float(p_pred.shape[0])
    epsilon_range = np.linspace(0, 1, n_points)
    # Sort the relative errors once, then count "< epsilon" with one binary
    # search per epsilon. This is O((n + k) log n) instead of rescanning all
    # n predictions for each of the k epsilons. NaNs sort last and never count.
    rel_err = np.sort(np.abs(1 - (p_pred / p_test)).ravel())
    counts = np.searchsorted(rel_err, epsilon_range, side='left')
    percentages = [float(count) / n_total for count in counts]
    auc_p = 0
    for fraction in percentages:
        auc_p += fraction / n_points
    return auc_p, percentages, epsilon_range


#  
 
# General Helper Functions

In [None]:
def intersection_loss(y_true, y_pred):
    """Weighted sum of the project's error terms, minus a Dice term.

    loss = 10 * absolute_true_error + 3 * intersection_true_error
           + absolute_false_error + intersection_false_error - dice_coef1

    The helper terms are defined outside this section. They are evaluated
    and combined in the same order as the single-expression form.
    """
    loss = 10 * absolute_true_error(y_true, y_pred)
    loss = loss + 3 * intersection_true_error(y_true, y_pred)
    loss = loss + absolute_false_error(y_true, y_pred)
    loss = loss + intersection_false_error(y_true, y_pred)
    # An optional "+ masked_mse(y_true, y_pred) / 10" term is currently disabled.
    return loss - dice_coef1(y_true, y_pred)

def load_simulation_data(header='', only_true=False, cross_validation=False):
    """Load simulated inputs (x) and binarised transit masks (y) from ../Data.

    Reads total_x_sim_train_3.npy, total_transits_sim_train_3.npy and
    total_params_sim_train_3.npy.
    - x gets a singleton axis 1, and its third axis is right-padded with
      126 zeros.
    - y is reduced to channel 0, binarised (see the NOTE below) and reshaped
      to (n, 1, L + 126, 1) with the same right padding.

    Args:
        header: unused.
        only_true: if True, keep only the samples whose params[:, 1] is
            non-zero.
        cross_validation: if True, return
            train_test_split(x, y, test_size=0.2, random_state=0).

    Returns:
        (x_train, x_test, y_train, y_test) when cross_validation is True,
        otherwise (x, y, has_transit).
    """
    x = np.load('../Data/total_x_sim_train_3.npy')
    x = np.expand_dims(x,axis=1)
    # The 4-entry pad spec requires x to be 4-D here, i.e. the file holds a 3-D array.
    x = np.pad(x, ((0, 0), (0, 0), (0, 126), (0, 0)), 'constant', constant_values=(0.0, 0.0))
    print 'X loaded'
    y = np.load('../Data/total_transits_sim_train_3.npy')
    y = y[:, :, 0]  # keep channel 0 -> 2-D (n_samples, L)
    # MinMaxScaler rescales each column (time step) independently across samples.
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(y)
    y = scaler.transform(y)
    y = np.expand_dims(y, axis=2)
    ya = y[:, :, 0]  # basic slice, so ya is a view into y
    # NOTE(review): after per-column scaling, 0 marks each column's minimum, so this
    # mask means "differs from the column minimum". That equals "non-zero in the raw
    # transit data" only if every column's minimum is 0 — confirm this is intended.
    ya[ya != 0] = 1
    y = np.expand_dims(ya, axis=2)
    y = np.expand_dims(y, axis=1)
    # Pad by the same 126 steps as x so both have length L + 126.
    y = np.pad(y, ((0, 0), (0, 0), (0, 30 + 96), (0, 0)), 'constant', constant_values=(0, 0))
    print y.shape
    print 'Y loaded'
    print(x.shape, y.shape)
    # A sample "has a transit" when column 1 of the params array is non-zero.
    has_transit = np.load('../Data/total_params_sim_train_3.npy')[:,1] != 0
    print np.sum(has_transit), has_transit.shape
    print 'Params loaded'
    print("Finished Loading Data")
    if only_true:
        x = x[has_transit]
        y = y[has_transit]
    if cross_validation:
        return train_test_split(x, y, test_size=0.2, random_state=0)
    # NOTE(review): with only_true=True, has_transit is returned unfiltered, so its
    # length no longer matches x and y.
    return x, y, has_transit

    

#  
 
# Network Blocks

In [None]:
def residual_block_v2(inx, filters, kernel, activation, pooling):
    """Pre-activation residual block with squeeze-and-excitation, then a strided conv.

    Args:
        inx: input tensor. Its channel count must equal ``filters`` so the
            residual ``add`` is shape-compatible.
        filters: number of conv filters in every conv of the block.
        kernel: conv kernel length (kernels are ``(1, kernel)``).
        activation: activation name passed to ``Activation``.
        pooling: stride along the second spatial axis of the final conv.

    Returns:
        (downsampled, skip): the strided-conv output, and the residual sum
        before downsampling (used as a U-Net skip connection).
    """
    h = inx
    # Two full pre-activation stages: BN -> activation -> conv.
    for _ in range(2):
        h = BatchNormalization()(h)
        h = Activation(activation=activation)(h)
        h = Conv2D(filters, (1, kernel), padding='same')(h)
    h = se_block(h, filters, 2)  # channel re-weighting
    skip = add([h, inx])  # identity shortcut

    h = BatchNormalization()(skip)
    h = Activation(activation=activation)(h)
    # A strided convolution replaces max pooling for downsampling.
    downsampled = Conv2D(filters, (1, kernel), strides=(1, pooling), padding='same')(h)
    return downsampled, skip

def residual_block_up_v2(inx, x_k, filters, kernel, activation, pooling):
    """Decoder counterpart of ``residual_block_v2``: residual + skip merge, then upsample.

    Args:
        inx: decoder input tensor with ``filters`` channels.
        x_k: encoder skip tensor; must have the same shape as ``inx`` for
            the three-way ``add``.
        filters: number of conv filters.
        kernel: conv kernel length (kernels are ``(1, kernel)``).
        activation: activation name passed to ``Activation``.
        pooling: upsampling stride of the final transposed conv along the
            second spatial axis.

    Returns:
        The transposed-convolution output tensor.
    """
    h = inx
    # Two full pre-activation stages: BN -> activation -> conv.
    for _ in range(2):
        h = BatchNormalization()(h)
        h = Activation(activation=activation)(h)
        h = Conv2D(filters, (1, kernel), padding='same')(h)
    h = se_block(h, filters, 2)  # channel re-weighting
    merged = add([h, x_k, inx])  # U-Net skip plus residual shortcut

    merged = BatchNormalization()(merged)
    merged = Activation(activation=activation)(merged)
    # A transposed convolution replaces plain UpSampling2D.
    return Conv2DTranspose(filters, (1, kernel), strides=(1, pooling), padding='same')(merged)

def residual_block(inx, filters, kernel, activation, pooling):
    """Post-activation residual block followed by a non-overlapping strided conv.

    Args:
        inx: input tensor with ``filters`` channels (required by the ``add``).
        filters: number of conv filters.
        kernel: conv kernel length (kernels are ``(1, kernel)``).
        activation: activation name passed to ``Activation``.
        pooling: kernel length and stride of the downsampling conv.

    Returns:
        (downsampled, skip): the strided-conv output, and the activated
        residual sum before downsampling.
    """
    h = inx
    # conv -> BN -> act, then conv -> BN (the second activation follows the add).
    for stage in range(2):
        h = Conv2D(filters, (1, kernel), padding='same')(h)
        h = BatchNormalization()(h)
        if stage == 0:
            h = Activation(activation=activation)(h)
    merged = add([h, inx])
    skip = Activation(activation=activation)(merged)
    downsampled = Conv2D(filters, (1, pooling), strides=(1, pooling))(skip)
    return downsampled, skip

def residual_block_up(inx, x_k, filters, kernel, activation, pooling):
    """Decoder counterpart of ``residual_block``: residual + skip merge, then UpSampling2D.

    Args:
        inx: decoder input tensor with ``filters`` channels.
        x_k: encoder skip tensor, same shape as ``inx``.
        filters: number of conv filters.
        kernel: conv kernel length (kernels are ``(1, kernel)``).
        activation: activation name passed to ``Activation``.
        pooling: repeat factor along the second spatial axis.

    Returns:
        The upsampled tensor.
    """
    h = inx
    # conv -> BN -> act, then conv -> BN (the second activation follows the add).
    for stage in range(2):
        h = Conv2D(filters, (1, kernel), padding='same')(h)
        h = BatchNormalization()(h)
        if stage == 0:
            h = Activation(activation=activation)(h)
    merged = add([h, x_k, inx])
    merged = Activation(activation=activation)(merged)
    return UpSampling2D((1, pooling))(merged)

def se_block(inx, ch, ratio=16):
    """Squeeze-and-excitation gating of the channels of ``inx``.

    The block global-average-pools ``inx`` to one value per channel and
    passes it through Dense(ch // ratio, relu) -> Dense(ch, sigmoid). It
    then multiplies ``inx`` by the resulting per-channel gates.

    NOTE(review): the gates are 2-D (batch, ch) while ``inx`` is 4-D. This
    relies on keras' ``Multiply`` broadcasting the lower-rank input; confirm
    for the installed Keras version.
    """
    squeezed = GlobalAveragePooling2D()(inx)
    bottleneck = Dense(ch // ratio, activation='relu')(squeezed)
    gates = Dense(ch, activation='sigmoid')(bottleneck)
    return Multiply()([inx, gates])

#  
 
# Generator Functions

In [None]:
def Generator(input_shape):
    """Build the U-Net-style segmentation generator from pre-activation residual blocks.

    - Encoder: an initial 32-filter conv, four 32-filter ``residual_block_v2``
      stages, a BN/ReLU/64-filter conv transition, and four 64-filter stages.
      Each stage uses stride 2.
    - Bridge: conv stack 64 -> 64 -> 128 -> 256 -> dropout -> 256, with
      concatenation shortcuts back to the 128- and 64-filter maps.
    - Decoder: ``residual_block_up_v2`` stages add the matching encoder skip
      tensors and upsample with transposed convs; the last stage uses
      stride 1.

    Args:
        input_shape: shape of one sample, excluding the batch axis.

    Returns:
        keras Model named 'generator' with a single-channel sigmoid output
        layer named 'gen_output'.
    """
    def bn_relu(t):
        t = BatchNormalization()(t)
        return Activation('relu')(t)

    def conv_bn_relu(t, n_filters):
        t = Conv2D(n_filters, (1, 5), padding='same')(t)
        return bn_relu(t)

    gen_inputs = Input(shape=input_shape)
    h = Conv2D(32, (1, 5), padding='same')(gen_inputs)

    # Encoder. skips collects the pre-downsampling tensors, shallowest first.
    skips = []
    for _ in range(4):
        h, skip = residual_block_v2(h, 32, 5, 'relu', 2)
        skips.append(skip)
    h = bn_relu(h)
    h = Conv2D(64, (1, 5), padding='same')(h)
    for _ in range(4):
        h, skip = residual_block_v2(h, 64, 5, 'relu', 2)
        skips.append(skip)
    bottleneck = h

    # Bridge.
    h = bn_relu(bottleneck)
    h = Conv2D(64, (1, 5), padding='same')(h)
    h = bn_relu(h)
    m1 = conv_bn_relu(h, 64)
    m2 = conv_bn_relu(m1, 128)
    h = conv_bn_relu(m2, 256)
    h = Dropout(0.25)(h)
    h = conv_bn_relu(h, 256)
    h = conv_bn_relu(concatenate([m2, h]), 128)
    h = conv_bn_relu(concatenate([m1, h]), 64)

    # Decoder, 64-filter half: bottleneck first, then the deepest four skips.
    for skip in [bottleneck] + list(reversed(skips[4:])):
        h = residual_block_up_v2(h, skip, 64, 5, 'relu', 2)
    h = bn_relu(h)
    h = Conv2D(32, (1, 5), padding='same')(h)
    # Decoder, 32-filter half; the shallowest skip is merged without upsampling.
    for skip, stride in zip(reversed(skips[:4]), (2, 2, 2, 1)):
        h = residual_block_up_v2(h, skip, 32, 5, 'relu', stride)

    # Keep the block's residual sum; its strided-conv output is unused.
    _, h = residual_block_v2(h, 32, 5, 'relu', 1)
    h = bn_relu(h)

    gen_outputs = Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='gen_output')(h)
    return Model(gen_inputs, gen_outputs, name='generator')
    
def Generator_2(input_shape):
    """Build the U-Net-style generator variant that uses post-activation residual blocks.

    Same layout as ``Generator``, with these differences:
    - it uses ``residual_block`` / ``residual_block_up`` (UpSampling2D);
    - the 32 -> 64 transition is a bare conv;
    - there is no BN/ReLU before the output conv.

    Args:
        input_shape: shape of one sample, excluding the batch axis.

    Returns:
        keras Model named 'generator' with a single-channel sigmoid output
        layer named 'gen_output'.
    """
    def conv_bn_relu(t, n_filters):
        t = Conv2D(n_filters, (1, 5), padding='same')(t)
        t = BatchNormalization()(t)
        return Activation('relu')(t)

    gen_inputs = Input(shape=input_shape)
    h = Conv2D(32, (1, 5), padding='same')(gen_inputs)

    # Encoder. skips collects the pre-downsampling tensors, shallowest first.
    skips = []
    for _ in range(4):
        h, skip = residual_block(h, 32, 5, 'relu', 2)
        skips.append(skip)
    h = Conv2D(64, (1, 5), padding='same')(h)
    for _ in range(4):
        h, skip = residual_block(h, 64, 5, 'relu', 2)
        skips.append(skip)
    bottleneck = h

    # Bridge.
    h = conv_bn_relu(bottleneck, 64)
    m1 = conv_bn_relu(h, 64)
    m2 = conv_bn_relu(m1, 128)
    h = conv_bn_relu(m2, 256)
    h = Dropout(0.25)(h)
    h = conv_bn_relu(h, 256)
    h = conv_bn_relu(concatenate([m2, h]), 128)
    h = conv_bn_relu(concatenate([m1, h]), 64)

    # Decoder, 64-filter half: bottleneck first, then the deepest four skips.
    for skip in [bottleneck] + list(reversed(skips[4:])):
        h = residual_block_up(h, skip, 64, 5, 'relu', 2)
    h = Conv2D(32, (1, 5), padding='same')(h)
    # Decoder, 32-filter half; the shallowest skip is merged without upsampling.
    for skip, stride in zip(reversed(skips[:4]), (2, 2, 2, 1)):
        h = residual_block_up(h, skip, 32, 5, 'relu', stride)

    # Keep the block's activated residual sum; its strided-conv output is unused.
    _, h = residual_block(h, 32, 5, 'relu', 1)
    gen_outputs = Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='gen_output')(h)
    return Model(gen_inputs, gen_outputs, name='generator')


#  
 
# Discriminator Functions

In [None]:
def Discriminator(input_shape):
    """Build a convolutional critic that outputs one linear (unbounded) score per sample.

    The network has:
    - nine stride-2 Conv/BN/LeakyReLU(0.2)/Dropout(0.25) stages
      (5 x 64 filters, then 4 x 128);
    - a stride-2 256-filter Conv/BN/LeakyReLU;
    - global max pooling;
    - two Dropout/Dense(128, relu) pairs, then a dropout and a linear
      Dense(1) output.

    Args:
        input_shape: shape of one sample, excluding the batch axis.

    Returns:
        keras Model named 'discriminator' whose output layer is 'dis_output'.
    """
    dis_inputs = Input(shape=input_shape)
    h = dis_inputs

    stage_filters = [64] * 5 + [128] * 4
    for idx, n_filters in enumerate(stage_filters):
        # The very first conv keeps Keras' default initializer; the rest use he_normal.
        conv_kwargs = {} if idx == 0 else {'kernel_initializer': 'he_normal'}
        h = Conv2D(n_filters, (1, 5), strides=(1, 2), padding='same', **conv_kwargs)(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(0.2)(h)
        h = Dropout(0.25)(h)

    h = Conv2D(256, (1, 5), kernel_initializer='he_normal', strides=(1, 2), padding='same')(h)
    h = BatchNormalization()(h)
    h = LeakyReLU(0.2)(h)
    h = GlobalMaxPooling2D()(h)

    for _ in range(2):
        h = Dropout(0.25)(h)
        h = Dense(128, activation='relu')(h)
    h = Dropout(0.25)(h)

    # No activation: a raw score, as used with the Wasserstein loss.
    dis_outputs = Dense(1, kernel_initializer='he_normal', name='dis_output')(h)
    return Model(dis_inputs, dis_outputs, name='discriminator')



#  
 
# Classifier Functions

In [None]:

def Classifier(input_shape):
    """Build a binary classifier that fuses generator skip features into its own encoder.

    The model takes eight feature maps plus the raw input. Feature ``i`` has
    shape ``(H, W // 2**i, class_lengths[i])``, where ``H`` and ``W`` are the
    first two dimensions of ``input_shape``. Feature ``i`` is added into the
    matching stage of the classifier's own ``residual_block_v2`` encoder.

    Model inputs are ordered ``[feature_0, ..., feature_7, raw_input]``. This
    matches ``compile_classifier_stack``, which calls the model with
    ``gen_multi_out(inp) + [inp]``.

    Args:
        input_shape: shape of the raw input, excluding the batch axis.

    Returns:
        keras Model named 'Classifier' with a single sigmoid output.
    """
    class_in = Input(shape=input_shape)
    # Derive the feature shapes from this model's own input (previously an
    # undefined name ``gen_inputs`` was referenced here -> NameError).
    in_shape = K.int_shape(class_in)
    class_inputs = []
    class_lengths = [32, 32, 32, 32, 64, 64, 64, 64]
    for i in range(8):
        # Stage i of the encoder below has been downsampled i times by stride 2.
        class_inputs.append(Input(shape=(in_shape[1], in_shape[2] // (2 ** i), class_lengths[i])))

    z = Conv2D(32, (1, 5), padding='same')(class_in)
    z = Add()([z, class_inputs[0]])

    z, _ = residual_block_v2(z, 32, 5, 'relu', 2)
    z = Add()([z, class_inputs[1]])

    z, _ = residual_block_v2(z, 32, 5, 'relu', 2)
    z = Add()([z, class_inputs[2]])

    z, _ = residual_block_v2(z, 32, 5, 'relu', 2)
    z = Add()([z, class_inputs[3]])

    z = BatchNormalization()(z)
    z = Activation('relu')(z)
    z = Conv2D(64, (1, 5), padding='same')(z)

    z, _ = residual_block_v2(z, 64, 5, 'relu', 2)
    z = Add()([z, class_inputs[4]])

    z, _ = residual_block_v2(z, 64, 5, 'relu', 2)
    z = Add()([z, class_inputs[5]])

    z, _ = residual_block_v2(z, 64, 5, 'relu', 2)
    z = Add()([z, class_inputs[6]])

    z, _ = residual_block_v2(z, 64, 5, 'relu', 2)
    z = Add()([z, class_inputs[7]])

    z, _ = residual_block_v2(z, 64, 5, 'relu', 2)
    z, _ = residual_block_v2(z, 64, 5, 'relu', 2)
    z, _ = residual_block_v2(z, 64, 5, 'relu', 2)

    z = BatchNormalization()(z)
    z = Activation('relu')(z)
    z = Conv2D(128, (1, 5), kernel_initializer='he_normal', padding='same')(z)
    z = BatchNormalization()(z)
    z = LeakyReLU(0.2)(z)
    z = Dropout(0.25)(z)

    z = Conv2D(256, (1, 5), kernel_initializer='he_normal', padding='same')(z)
    z = BatchNormalization()(z)
    z = LeakyReLU(0.2)(z)
    z = Dropout(0.25)(z)

    z = GlobalMaxPooling2D()(z)
    z = Dense(256, activation='relu')(z)
    z = Dropout(0.4)(z)
    z = Dense(256, activation='relu')(z)
    z = Dropout(0.4)(z)

    class_out = Dense(1, activation='sigmoid')(z)
    # class_in must be a model input too, otherwise the graph is disconnected.
    # It goes last to match the call order in compile_classifier_stack.
    return Model(class_inputs + [class_in], class_out, name='Classifier')

def compile_classifier_stack(classifier, generator, loss, optimizer, metrics, loss_weights, name, reverse_freeze=False):
    """Stack ``classifier`` on top of the generator's skip features and compile it.

    The stacked model:
    1. takes the generator's input;
    2. runs a sub-model exposing the outputs of the generator's first eight
       ``Add`` layers;
    3. feeds those eight tensors, followed by the raw input, to
       ``classifier``.

    The generator's layers are frozen only while the stack is built and
    compiled. Afterwards their previous ``trainable`` flags are restored
    exactly, also if building or compiling raises. Previously every layer was
    set to trainable=True, clobbering any prior freeze state, and an error
    left the generator frozen.

    Args:
        classifier: model taking ``[feat_0, ..., feat_7, raw_input]``.
        generator: generator model whose ``Add`` outputs are reused.
        loss, optimizer, metrics, loss_weights: forwarded to ``Model.compile``.
        name: currently unused.
        reverse_freeze: if True, also mark ``model.layers[1]`` non-trainable
            before compiling (presumably the feature sub-model — confirm).

    Returns:
        The compiled stacked model. As side effects, this prints its summary
        and writes 'DCGAN_model_classifier.png'.
    """
    previous_layer_flags = [layer.trainable for layer in generator.layers]
    previous_model_flag = generator.trainable
    for layer in generator.layers:
        layer.trainable = False
    generator.trainable = False
    try:
        inp = Input(tuple(generator.layers[0].input_shape[1:]))
        # NOTE(review): assumes the first eight Add layers, in generator.layers order,
        # are the encoder skip tensors the classifier expects — confirm.
        gen_out = [layer.output for layer in generator.layers
                   if isinstance(layer, keras.layers.Add)][:8]
        gen_multi_out = Model(generator.inputs, gen_out)
        classifier_out = classifier(gen_multi_out(inp) + [inp])
        model = Model(inp, classifier_out)
        if reverse_freeze:
            model.layers[1].trainable = False
        model.compile(loss=loss, optimizer=optimizer, metrics=metrics, loss_weights=loss_weights)
        model.summary()
        plot_model(model, show_shapes=True, to_file='DCGAN_model_classifier.png')
    finally:
        # Restore the caller's exact freeze state instead of unfreezing everything.
        for layer, flag in zip(generator.layers, previous_layer_flags):
            layer.trainable = flag
        generator.trainable = previous_model_flag
    return model

#  
 
# Training Example

In [None]:
# Delete any existing img_*.png files under ./TrainingLogs/FCN-SE/ before training
# (IPython "!" shell escape; only valid inside a notebook).
!rm -rf ./TrainingLogs/FCN-SE/img_*.png

def Discriminator_kwargs(batch_size):
    """Return the discriminator training options.

    The options are a fresh Adam(1e-3, beta_1=0.5, beta_2=0.9) optimizer and
    the given batch size.
    """
    return {
        "optimizer": Adam(1e-3, beta_1=0.5, beta_2=0.9),
        "batch_size": batch_size,
    }


def Generator_kwargs():
    """Return the generator-side training options passed to ``GAN``.

    The options are:
    - ``intersection_loss`` as the generator loss;
    - the IWGAN scheme's Wasserstein loss for the discriminator;
    - a fresh Adam(1e-3, beta_1=0.5, beta_2=0.9) optimizer;
    - per-network metrics;
    - ``gen_dis_loss_ratio`` = 0.85.
    """
    options = {}
    options["gen_loss"] = intersection_loss
    options["dis_loss"] = IWGAN_TrainingScheme.wasserstein_loss
    options["optimizer"] = Adam(1e-3, beta_1=0.5, beta_2=0.9)
    options["dis_metrics"] = ['accuracy']
    options["gen_metrics"] = [dice_coef1, masked_mse]
    options["gen_dis_loss_ratio"] = 0.85
    return options

# Keep only samples flagged as containing a transit, then split 80/20 into train/test.
x_train, x_test, y_train, y_test = load_simulation_data(
    cross_validation=True,
    only_true=True,
)


In [None]:
#!rm -rf ./TrainingLogs/FCN-SE/img_*.png
#!rm -rf ./TrainingLogs/FCN/img_*.png
# Build both networks on the per-sample shape, then wrap them in the IWGAN trainer.
sample_shape = x_train.shape[1:]
generator_net = Generator(sample_shape)
discriminator_net = Discriminator(sample_shape)
model = GAN(generator=generator_net,
            discriminator=discriminator_net,
            training_scheme=IWGAN_TrainingScheme,
            generator_kwargs=Generator_kwargs(),
            discriminator_kwargs=Discriminator_kwargs(batch_size=32),
            generator_training_kwargs={},
            discriminator_training_kwargs={})
print('loaded GAN')
#model.summaries()

# NOTE(review): compile_classifier is not defined in this notebook section; the visible
# helper is compile_classifier_stack(classifier, generator, loss, optimizer, metrics,
# loss_weights, name, reverse_freeze=...).
#classifier = compile_classifier(Classifier(), model.generator_model(), 'binary_crossentropy', Adam(), ['accuracy'], None, 'classifier', reverse_freeze=True)
#print 'loaded Classifier'
#class_train = classifier_training(classifier, 1, 32)

# Logging callbacks for the train and test splits; each writes its own log/model files.
log_callback_train = log_results_multi(
    100, 'best_model_train_resnet_se.txt',
    network_name='FCN-SE',
    x=x_train,
    y=y_train,
    save_path='best_model_train_resnet_se.h5',
    image_name='train',
)
log_callback_test = log_results_multi(
    100, 'best_model_test_resnet_se.txt',
    network_name='FCN-SE',
    x=x_test,
    y=y_test,
    save_path='best_model_test_resnet_se.h5',
    image_name='test',
)
def sch(epoch):
    """Step learning-rate schedule: 1e-3 before epoch 750, 5e-5 until 1500, then 1e-5.

    A message is printed exactly at epochs 750 and 1500, where the rate
    changes.
    """
    if epoch >= 1500:
        if epoch == 1500:
            print("changed lr to 1e-5")
        return 1e-5
    if epoch >= 750:
        if epoch == 750:
            print("changed lr to 5e-5")
        return 5e-5
    return 1e-3
import missinglink

# One missinglink experiment-tracking callback per network, each with its display name.
missinglink_callback_gen = missinglink.KerasCallback()
missinglink_callback_gen.set_properties(display_name='generator test')
missinglink_callback_dis = missinglink.KerasCallback()
missinglink_callback_dis.set_properties(display_name='discriminator test')
schedule = keras.callbacks.LearningRateScheduler(sch)

# NOTE(review): the same LearningRateScheduler instance goes into both callback lists.
# Whether that schedules both optimizers depends on how GAN.fit calls set_model — confirm.
generator_callbacks = [log_callback_train, log_callback_test, schedule, missinglink_callback_gen]
discriminator_callbacks = [schedule, missinglink_callback_dis]
model.fit(x_train, y_train,
          steps=20001,
          batch_size=32,
          shuffle=True,
          generator_callbacks=generator_callbacks,
          discriminator_callbacks=discriminator_callbacks)


[16A