In [40]:
import itertools
import numpy as np
import tensorflow as tf
import keras.backend as K
import keras 
import pandas as pd

import sys
sys.path.insert(0,'..')

from scipy import signal, fftpack
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
from keras.layers import Conv1D, Conv2D, MaxPooling2D, GlobalMaxPooling2D, UpSampling2D, LeakyReLU, Lambda, Add, Activation
from keras.backend.tensorflow_backend import set_session
from keras.utils import plot_model
from keras.optimizers import Adam
from functools import partial
from New_Layers import *
from keras.layers.merge import _Merge
from multiprocessing import Pool
import copy 
import pandas as pd
from os import listdir
from os.path import isfile, join
from os import listdir
from os.path import isfile, join
import numpy as np
import astropy
from tqdm import tqdm_notebook as tqdm 
from astropy.io.fits.card import UNDEFINED
from astropy.io import fits

from sklearn.model_selection import train_test_split

# TensorFlow session used by Keras: expose only GPU "1", let memory grow on
# demand (bounded by 95% of that GPU), and install it as Keras' global session.
config = tf.ConfigProto()
gpu_opts = config.gpu_options
gpu_opts.visible_device_list = "1"
gpu_opts.allow_growth = True
gpu_opts.per_process_gpu_memory_fraction = 0.95
set_session(tf.Session(config=config))


%matplotlib inline

class IWGAN_TrainingScheme(object):
    """Improved-WGAN (WGAN-GP) training scheme used by ``GAN``.

    Every member is static, so the class object itself is passed around as the
    ``training_scheme``. The critic is trained with the Wasserstein loss on real
    (+1) and generated (-1) samples plus a gradient penalty on random
    interpolates. The generator is trained on a weighted sum of its own output
    loss and the critic's score of that output.
    """

    @staticmethod
    def wasserstein_loss(y_true, y_pred):
        """Wasserstein loss ``mean(y_true * y_pred)`` for +1 / -1 labels."""
        return K.mean(y_true * y_pred)

    @staticmethod
    def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
        """Gradient penalty ``weight * (1 - ||d y_pred / d averaged_samples||_2) ** 2``, batch-averaged.

        ``y_true`` is unused; Keras requires it in every loss signature. The L2
        norm is taken per sample over all non-batch axes.
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        gradients_sqr = K.square(gradients)
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
        return K.mean(gradient_penalty)

    class RandomWeightedAverage(_Merge):
        """Merge layer returning ``w * inputs[0] + (1 - w) * inputs[1]`` with ``w ~ U(0, 1)`` per sample."""

        def _merge_function(self, inputs):
            # One weight per sample, broadcast over every other axis. The batch size
            # is read from the tensor instead of being hard-coded to 32, so the
            # weights match whatever batch size is actually fed.
            batch_size = K.shape(inputs[0])[0]
            weights = K.random_uniform((batch_size,) + (1,) * (K.ndim(inputs[0]) - 1))
            return (weights * inputs[0]) + ((1 - weights) * inputs[1])

    @staticmethod
    def compile_discriminator(generator, discriminator, **kwargs):
        """Build and compile the critic's training model with the generator frozen.

        Inputs are ``[real_samples, generator_input]``. Outputs are the critic's
        scores for the real samples, the generated samples and random
        interpolates of the two; the last output feeds the gradient penalty.

        Keyword Args:
            optimizer: Keras optimizer (required).
            gp_weight: gradient-penalty weight (optional, default 10).

        Returns:
            The compiled Keras model. A diagram is written to 'DCGAN_model_dis.png'.
        """
        optimizer = kwargs['optimizer']
        gp_weight = kwargs.get('gp_weight', 10)

        # Freeze the generator while the critic's training graph is compiled.
        for layer in generator.layers:
            layer.trainable = False
        generator.trainable = False

        inp = Input(tuple(generator.layers[0].input_shape[1:]))
        in_gen = generator(inp)
        in_real = Input(tuple(discriminator.layers[0].input_shape[1:]))
        discriminator_output_from_generator = discriminator(in_gen)
        discriminator_output_from_real_samples = discriminator(in_real)
        averaged_samples = IWGAN_TrainingScheme.RandomWeightedAverage()([in_real, in_gen])
        averaged_samples_out = discriminator(averaged_samples)
        partial_gp_loss = partial(IWGAN_TrainingScheme.gradient_penalty_loss, averaged_samples=averaged_samples,
                                  gradient_penalty_weight=gp_weight)
        # Keras uses the loss function's __name__ when reporting it.
        partial_gp_loss.__name__ = 'gradient_penalty'

        # ----- Discriminator -----
        discriminator_model = Model(inputs=[in_real, inp], outputs=[discriminator_output_from_real_samples,
                                                                    discriminator_output_from_generator,
                                                                    averaged_samples_out])
        discriminator_model.layers[1].trainable = False
        discriminator_model.compile(optimizer=optimizer, loss=[IWGAN_TrainingScheme.wasserstein_loss,
                                                               IWGAN_TrainingScheme.wasserstein_loss, partial_gp_loss])
        # ----- Discriminator -----

        for layer in generator.layers:
            layer.trainable = True
        generator.trainable = True
        plot_model(discriminator_model, show_shapes=True, to_file='DCGAN_model_dis.png')
        return discriminator_model

    @staticmethod
    def compile_generator(generator, discriminator, **kwargs):
        """Build and compile the generator's training model with the discriminator frozen.

        The model maps the generator input to ``[generator output, critic score of
        that output]``. The outputs are keyed 'generator' and 'discriminator', so
        the two sub-models must carry those layer names.

        Keyword Args (all required):
            gen_loss, dis_loss: losses for the generator output / critic score.
            optimizer: Keras optimizer.
            gen_metrics, dis_metrics: metrics per output.
            gen_dis_loss_ratio: weight of the generator-output loss; the critic
                loss gets ``1 - gen_dis_loss_ratio``.

        Returns:
            The compiled Keras model.
        """
        gen_loss = kwargs['gen_loss']
        dis_loss = kwargs['dis_loss']
        optimizer = kwargs['optimizer']
        gen_metrics = kwargs['gen_metrics']
        dis_metrics = kwargs['dis_metrics']
        gen_dis_loss_ratio = kwargs['gen_dis_loss_ratio']

        for layer in discriminator.layers:
            layer.trainable = False
        discriminator.trainable = False
        inp = Input(tuple(generator.layers[0].input_shape[1:]))
        gen_out = generator(inp)

        # ----- Generator -----
        # Score the same generator output that the generator loss sees. Calling
        # generator(inp) a second time would build a duplicate branch, doubling the
        # forward cost, and with stochastic layers such as Dropout the critic would
        # judge a different sample than the one being trained.
        model = Model(inp, [gen_out, discriminator(gen_out)])
        model.layers[2].trainable = False
        model.compile(loss={'discriminator': dis_loss, 'generator': gen_loss},
                      optimizer=optimizer, metrics={'discriminator': dis_metrics, 'generator': gen_metrics},
                      loss_weights={'discriminator': 1 - gen_dis_loss_ratio, 'generator': gen_dis_loss_ratio})
        # ----- Generator -----

        for layer in discriminator.layers:
            layer.trainable = True
        discriminator.trainable = True
        return model

    @staticmethod
    def train_discriminator(discriminator, x, y, batch_size):
        """Run one critic update on a batch.

        ``y`` (real targets) is fed to the real-sample input and ``x`` to the
        generator input, matching the ``[in_real, inp]`` order in
        ``compile_discriminator``. Real samples get label +1 and generated ones
        -1. The interpolates get a dummy 0 target, which the penalty loss ignores.

        Returns:
            The list returned by ``train_on_batch``.
        """
        loss = discriminator.train_on_batch([y, x],  # [in_real, inp]
                                            [np.ones((batch_size, 1), dtype=np.float32),  # discriminator_output_from_real_samples
                                             -np.ones((batch_size, 1), dtype=np.float32),  # discriminator_output_from_generator
                                             np.zeros((batch_size, 1), dtype=np.float32)])  # averaged_samples_out
        return loss

    @staticmethod
    def train_generator(generator, x, y, batch_size):
        """Run one generator update: targets are ``y`` for the output and +1 for the critic score.

        Returns:
            The list returned by ``train_on_batch``.
        """
        # One critic label per sample in the batch (was len(x_true), an undefined name).
        loss = generator.train_on_batch(x, [y, np.ones((len(x), 1), dtype=np.float32)])
        return loss


class GAN(object):
    """Pair a generator and a discriminator and train them with a pluggable scheme.

    ``training_scheme`` (for example ``IWGAN_TrainingScheme``) must provide
    ``compile_discriminator``, ``compile_generator``, ``train_discriminator``
    and ``train_generator``. ``generator`` / ``discriminator`` are factory
    callables that build the raw models from ``generator_kwargs`` /
    ``discriminator_kwargs``. The ``*_training_kwargs`` are forwarded to the
    scheme's compile functions.

    Raises:
        ValueError: no training scheme.
        TypeError: a factory is not callable, or a kwargs argument is not a dict.
    """

    def __init__(self, generator, discriminator, training_scheme, generator_kwargs=None, discriminator_kwargs=None,
                 generator_training_kwargs=None, discriminator_training_kwargs=None):
        if training_scheme is None:
            raise ValueError("No training scheme selected!")
        if not callable(generator):
            raise TypeError("No generator function supplied!")
        if not callable(discriminator):
            raise TypeError("No discriminator function supplied!")

        # None defaults avoid one mutable dict being shared by every instance.
        generator_kwargs = {} if generator_kwargs is None else generator_kwargs
        discriminator_kwargs = {} if discriminator_kwargs is None else discriminator_kwargs
        generator_training_kwargs = {} if generator_training_kwargs is None else generator_training_kwargs
        discriminator_training_kwargs = {} if discriminator_training_kwargs is None else discriminator_training_kwargs
        for value, message in ((generator_kwargs, "generator kwargs are not a dictionary!"),
                               (discriminator_kwargs, "discriminator kwargs are not a dictionary!"),
                               (generator_training_kwargs, "generator training kwargs are not a dictionary!"),
                               (discriminator_training_kwargs, "discriminator training kwargs are not a dictionary!")):
            if not isinstance(value, dict):
                raise TypeError(message)

        self._training_scheme = training_scheme
        # TODO: ----- Unrelated, Move out -----
        self.__generator_model = generator(**generator_kwargs)
        self.__discriminator_model = discriminator(**discriminator_kwargs)
        # TODO: ----- Unrelated, Move out -----

        # Both training models are assembled from the raw models built above.
        self.__discriminator = self._training_scheme.compile_discriminator(
            self.__generator_model, self.__discriminator_model, **discriminator_training_kwargs)
        self.__generator = self._training_scheme.compile_generator(
            self.__generator_model, self.__discriminator_model, **generator_training_kwargs)

    def generator_model(self): return self.__generator_model
    def discriminator_model(self): return self.__discriminator_model
    def generator(self): return self.__generator
    def discriminator(self): return self.__discriminator

    def fit(self, x, y, verbose=True, steps=None, epochs=None, steps_per_epoch=None, batch_size=None,
            generator_training_multiplier=1, discriminator_training_multiplier=1,
            generator_callbacks=None, discriminator_callbacks=None):
        """Alternate generator and discriminator updates over ``x`` (inputs) and ``y`` (targets).

        Supply either ``steps``, which records history once per step, or both
        ``epochs`` and ``steps_per_epoch``. In epoch mode, losses are averaged per
        epoch and callbacks get ``on_epoch_begin`` / ``on_epoch_end``. Batches are
        consecutive slices of the data that wrap around its end.

        Args:
            batch_size: samples per update (required).
            generator_training_multiplier, discriminator_training_multiplier:
                updates of each model per step.
            generator_callbacks, discriminator_callbacks: iterables of Keras
                callbacks bound to the respective training model. None means none.

        Returns:
            A ``keras.callbacks.History`` with ``generator_*`` and
            ``discriminator_*`` entries.

        Raises:
            ValueError: inconsistent step/epoch arguments or missing batch size.
            TypeError: callbacks are not iterable.
        """
        step_mode = steps is not None and epochs is None and steps_per_epoch is None
        epoch_mode = steps is None and epochs is not None and steps_per_epoch is not None
        if not (step_mode or epoch_mode):
            raise ValueError("please supply either steps OR epochs and steps per epoch")
        if epoch_mode and steps_per_epoch < 1:
            raise ValueError("steps_per_epoch must be at least 1")
        if batch_size is None:
            raise ValueError("please supply a batch size")

        self.x = x
        self.y = y
        self.verbose = verbose
        self.callbacks_generator = self._bind_callbacks(generator_callbacks, self.__generator,
                                                        "generator callbacks are not iterable!")
        self.callbacks_discriminator = self._bind_callbacks(discriminator_callbacks, self.__discriminator,
                                                            "discriminator callbacks are not iterable!")
        self.History = keras.callbacks.History()
        model_callbacks = self.callbacks_generator + self.callbacks_discriminator

        for callback in model_callbacks + [self.History]:
            callback.on_train_begin()

        if step_mode:
            for step in range(steps):
                loss_discriminator, loss_generator = self.__fit_on_batch(
                    step,
                    generator_training_multiplier=generator_training_multiplier,
                    discriminator_training_multiplier=discriminator_training_multiplier,
                    batch_size=batch_size)
                self.History.on_epoch_end(step, self.__history_logs(loss_generator, loss_discriminator))
        else:
            for epoch in range(epochs):
                for callback in model_callbacks:
                    callback.on_epoch_begin(epoch)
                total_discriminator, total_generator = None, None
                for step in range(steps_per_epoch):
                    batch_discriminator, batch_generator = self.__fit_on_batch(
                        step, epoch, steps_per_epoch, generator_training_multiplier,
                        discriminator_training_multiplier, batch_size)
                    # Element-wise sums: the losses are per-output lists, which ``+`` would concatenate.
                    total_discriminator = self._sum_losses(total_discriminator, batch_discriminator)
                    total_generator = self._sum_losses(total_generator, batch_generator)

                loss_discriminator = [float(v) / steps_per_epoch for v in total_discriminator]
                loss_generator = [float(v) / steps_per_epoch for v in total_generator]

                for callback in self.callbacks_generator:
                    callback.on_epoch_end(epoch, logs=dict(zip(self.__generator.metrics_names, loss_generator)))
                for callback in self.callbacks_discriminator:
                    callback.on_epoch_end(epoch, logs=dict(zip(self.__discriminator.metrics_names, loss_discriminator)))

                self.History.on_epoch_end(epoch, self.__history_logs(loss_generator, loss_discriminator))

        for callback in model_callbacks:
            callback.on_train_end()

        return self.History

    @staticmethod
    def _bind_callbacks(callbacks, model, error_message):
        """Return ``callbacks`` as a list (None -> []) after calling ``set_model(model)`` on each."""
        if callbacks is None:
            return []
        try:
            callbacks = list(callbacks)
        except TypeError:
            raise TypeError(error_message)
        for callback in callbacks:
            callback.set_model(model)
        return callbacks

    @staticmethod
    def _sum_losses(total, losses):
        """Return ``total + losses`` element-wise as floats; a None ``total`` starts a new sum.

        ``train_on_batch`` returns a scalar for single-output models and a list
        for multi-output ones; both are handled as 1-D arrays.
        """
        losses = np.atleast_1d(np.asarray(losses, dtype=np.float64))
        return losses if total is None else total + losses

    def __history_logs(self, loss_generator, loss_discriminator):
        """Merge both loss lists into one dict with 'generator_' / 'discriminator_' prefixes."""
        logs = {}
        for prefix, names, losses in (('generator_', self.__generator.metrics_names, loss_generator),
                                      ('discriminator_', self.__discriminator.metrics_names, loss_discriminator)):
            for name, value in zip(names, losses):
                logs[prefix + name] = value
        return logs

    def __train_repeated(self, train_fn, model, global_step, multiplier, batch_size):
        """Call ``train_fn`` on ``multiplier`` consecutive batches and return the mean per-output loss."""
        if multiplier < 1:
            raise ValueError("training multipliers must be at least 1")
        n_samples = self.x.shape[0]
        total = None
        for j in range(multiplier):
            start = (global_step * multiplier + j) * batch_size
            # Wrap around the end of the data set.
            idxs = [i % n_samples for i in range(start, start + batch_size)]
            total = self._sum_losses(total, train_fn(model, self.x[idxs], self.y[idxs], batch_size))
        return [float(v) / multiplier for v in total]

    def __fit_on_batch(self, step, epoch=None, steps_per_epoch=None, generator_training_multiplier=1,
                       discriminator_training_multiplier=1, batch_size=None):
        """Run one combined step: generator update(s), then discriminator update(s).

        Returns:
            ``(loss_discriminator, loss_generator)``: per-output loss lists, each
            averaged over its multiplier.
        """
        for callback in self.callbacks_generator + self.callbacks_discriminator:
            callback.on_train_batch_begin(step, logs={'batch': step, 'size': batch_size})

        # Global step index, so successive epochs keep walking forward through the data.
        global_step = step if epoch is None else step + (epoch * steps_per_epoch)

        loss_generator = self.__train_repeated(self._training_scheme.train_generator, self.__generator,
                                               global_step, generator_training_multiplier, batch_size)
        if self.verbose:
            print('Generator Loss: {}'.format(loss_generator))

        loss_discriminator = self.__train_repeated(self._training_scheme.train_discriminator, self.__discriminator,
                                                   global_step, discriminator_training_multiplier, batch_size)
        if self.verbose:
            print('Discriminator Loss: {}'.format(loss_discriminator))

        for callback in self.callbacks_generator:  # callbacks for generator
            callback.on_train_batch_end(step, logs=dict(zip(self.__generator.metrics_names, loss_generator)))

        for callback in self.callbacks_discriminator:  # callbacks for discriminator
            callback.on_train_batch_end(step, logs=dict(zip(self.__discriminator.metrics_names, loss_discriminator)))

        return loss_discriminator, loss_generator
    
class classifier_training(keras.callbacks.Callback):
    """Keras callback that trains a separate classifier for a few batches after every epoch.

    Args:
        classifier_model: compiled Keras classifier.
        training_ratio: number of random mini-batches to train per epoch.
        batch_size: samples per mini-batch.
    """

    def __init__(self, classifier_model, training_ratio, batch_size):
        # Initialise the Keras base-class attributes before adding our own.
        super(classifier_training, self).__init__()
        self.classifier_model = classifier_model
        self.train_ratio = training_ratio
        # Inputs and has-transit labels (third item) from the simulated training set.
        self.x, _, self.y = load_simulation_data()
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        # ``train_ratio`` mini-batches drawn uniformly at random, with replacement.
        for _ in range(self.train_ratio):
            idxs = np.random.randint(0, self.x.shape[0], size=self.batch_size)
            self.classifier_model.train_on_batch(self.x[idxs], self.y[idxs])

class log_results(keras.callbacks.Callback):
    """Keras callback that periodically plots, scores and checkpoints the models.

    Every ``logging_frequency`` epochs (0 disables it), it plots the first
    simulated training target above the GAN's prediction for it, then calls
    ``get_score``.

    Args:
        wgan_model: wrapper whose ``gen_dis()`` returns the Keras model used for
            prediction and checkpointing, or None to skip generator scoring.
        classifier_model: Keras classifier, or None to skip classifier scoring.
        logging_frequency: plot/score every this many epochs; 0 disables it.
        log_path: text file that score lines are appended to.
    """

    def __init__(self, wgan_model=None, classifier_model=None, logging_frequency=0, log_path=''):
        super(log_results, self).__init__()
        self.logging_frequency = logging_frequency
        self.log_path = log_path
        # Kept on its own attribute too: Keras' ``set_model`` (also called by GAN.fit)
        # replaces ``self.model`` with the model this callback is attached to, which
        # presumably has no ``gen_dis()``.
        self.wgan_model = wgan_model
        self.model = wgan_model
        self.classifier_model = classifier_model
        self.x, self.y, _ = load_simulation_data()
        # Best scores seen so far; a checkpoint is written whenever one improves.
        self.last_best_classifier, self.last_best_generator, self.last_best_combined = 0, 0, 0

    def on_epoch_end(self, epoch, logs=None):
        if self.logging_frequency == 0 or epoch % self.logging_frequency != 0:
            return
        if self.wgan_model is not None:
            self._plot_prediction(epoch)
        self.get_score(epoch, self.wgan_model is not None, self.classifier_model is not None)

    def _plot_prediction(self, epoch):
        """Plot the first simulated target above the GAN's prediction for it; save and show the figure."""
        time = np.linspace(0, 28.625, 20610)
        plt.figure(1, figsize=(10,10))
        plt.subplot(211)
        plt.scatter(time,  self.y[:1][0, 0, :20610, 0], s=0.5)
        plt.title("Simulation Output, " + str(epoch))

        final_res = self.wgan_model.gen_dis().predict_on_batch(self.x[:1])
        plt.subplot(212)
        plt.scatter(time, final_res[0][0, 0, :20610, 0], s=0.5)
        plt.title("Neural Net Output,"+ str(final_res[1][0,0]))
        plt.tight_layout()
        plt.savefig('img_'+str(epoch)+'_test.png')
        plt.show()

    def _append_log(self, line):
        """Append ``line`` to ``log_path`` (file closed afterwards) and echo it to stdout."""
        with open(self.log_path, 'a') as log_file:
            log_file.write(line)
        print(line)

    def _save_checkpoint(self, generator_path, classifier_path):
        """Save whichever of the GAN generator / classifier models are available."""
        if self.wgan_model is not None:
            self.wgan_model.gen_dis().save(generator_path)
        if self.classifier_model is not None:
            self.classifier_model.save(classifier_path)

    def _evaluate_generator(self, i_in):
        """Score period recovery on the test set and plot it.

        Returns:
            ``(auc_p, at_1perc, at_2perc)``: area under the period-detection
            curve and the detection rates at epsilon 0.01 / 0.02.
        """
        periods = np.load('../Data/total_params_sim_test_true_3.npy')[:,1]
        transits_ref = np.load('../Data/total_transits_sim_test_true_3.npy')
        x_test = np.expand_dims(np.load('../Data/total_x_sim_test_true_3.npy'),axis=1)
        x_test = np.pad(x_test, ((0,0), (0,0) ,(0, 30+ 96), (0,0)), 'constant', constant_values=(0, 0))
        print('X loaded')
        transits = self.wgan_model.gen_dis().predict(x_test, verbose=1)[0][:, 0, :20610, :]
        print("Finished predicting data")
        transits_ref = transits_ref[:,:,0]
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaler.fit(transits_ref)
        transits_ref = scaler.transform(transits_ref)
        print("Finished loading data")
        # The stored values are exponentiated base 10 before use.
        periods = np.power(10, periods)
        np.warnings.filterwarnings('ignore')
        # At most 10000 samples, capped at the number predicted so a smaller test set cannot overrun.
        n_eval = min(10000, len(transits))
        model_preds = [[transits[i, :, 0], periods[i], 1000, transits_ref[i, :]] for i in tqdm(range(n_eval))]
        model_preds = np.asarray(imap_unordered_bar(process_transit, model_preds, 5))
        auc_p, percentages, epsilon_range = p_epsilon_chart(model_preds[:, 0], model_preds[:, 1])
        at_1perc = percentages[np.argmin(np.abs(epsilon_range - 0.01))]
        at_2perc = percentages[np.argmin(np.abs(epsilon_range - 0.02))]
        plt.plot([0, 1], [0, 1], 'k--')
        plt.plot(epsilon_range, percentages, label='Keras (area = {:.5f})'.format(auc_p))
        plt.xlabel('epsilon')
        plt.ylabel('period detection rate')
        plt.title('Period ROC curve')
        plt.legend(loc='best')
        plt.ylim(0, 1)
        plt.xlim(0, 0.1)
        plt.savefig('img_PAUC_'+str(i_in)+'_test.png')
        plt.show()
        return auc_p, at_1perc, at_2perc

    def _evaluate_classifier(self, i_in):
        """Compute and plot the classifier ROC curve on the test set.

        Returns:
            ``(TPR at FPR~0.001, TPR at FPR~0.01, ROC AUC)``.
        """
        y_test = np.load('../Data/total_params_sim_test_3.npy')[:,1] > 0
        x_test = np.expand_dims(np.load('../Data/total_x_sim_test_3.npy'),axis=1)
        x_test = np.pad(x_test, ((0,0), (0,0) ,(0, 30+ 96), (0,0)), 'constant', constant_values=(0, 0))
        y_pred = self.classifier_model.predict(x_test, verbose=1)[:,0]
        fpr, tpr, _ = roc_curve(y_test, y_pred)
        roc_auc = roc_auc_score(y_test, y_pred)
        plt.plot([0, 1], [0, 1], 'k--')
        plt.plot(fpr, tpr, label='Keras (area = {:.5f})'.format(roc_auc))
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.title('Classifier ROC curve')
        plt.legend(loc='best')
        plt.ylim(0, 1)
        plt.xlim(0, 0.02)
        plt.savefig('img_ClassAUC_'+str(i_in)+'_test.png')
        plt.show()
        return tpr[find_nearest(fpr, 0.001)], tpr[find_nearest(fpr, 0.01)], roc_auc

    def get_score(self, i_in, gen_test = False, class_test = False):
        """Evaluate the GAN and/or classifier on the held-out simulation set.

        Saves ``best_*.h5`` checkpoints whenever a tracked score improves, shows
        and saves the plots, and appends a summary line to ``log_path``.

        Args:
            i_in: epoch index, used in file names and log lines.
            gen_test: score period recovery from the GAN's predictions.
            class_test: score the classifier's ROC curve.
        """
        # Neutral factor for the combined score when the generator is not evaluated.
        at_2perc = 1
        if gen_test:
            auc_p, at_1perc, at_2perc = self._evaluate_generator(i_in)
            if at_2perc > self.last_best_generator:
                self._save_checkpoint('best_generator_test.h5', 'best_classifier_test_generator_based.h5')
                self.last_best_generator = at_2perc
            if not class_test:
                self._append_log('i: %d, width: 1000, PAUC: %.5f, Percentage at 0.01: %.5f, Percentage at 0.02: %.5f\n' % (i_in,auc_p, at_1perc, at_2perc))

        if class_test:
            tpr_at_0001, tpr_at_001, roc_auc = self._evaluate_classifier(i_in)
            if tpr_at_001 > self.last_best_classifier:
                self._save_checkpoint('best_generator_test_classifier_based.h5', 'best_classifier_test.h5')
                self.last_best_classifier = tpr_at_001
            combined = tpr_at_001 * at_2perc
            if combined > self.last_best_combined:
                self._save_checkpoint('best_generator_combined_test.h5', 'best_classifier_combined_test.h5')
                self.last_best_combined = combined
            if not gen_test:
                self._append_log('i: %d, Class Percentage at 0.001: %.5f, Class Percentage at 0.01: %.5f, Class AUC: %.5f\n' % (i_in, tpr_at_0001, tpr_at_001, roc_auc))

        if class_test and gen_test:
            self._append_log('i: %d, width: 1000, PAUC: %.5f, Percentage at 0.01: %.5f, Percentage at 0.02: %.5f, Class Percentage at 0.001: %.5f, Class Percentage at 0.01: %.5f, Class AUC: %.5f\n' % (i_in,auc_p, at_1perc, at_2perc, tpr_at_0001, tpr_at_001, roc_auc))

# ------------------------------ Anything below this line is unrelated to the GAN class ------------------------------


# ---------------------------------------- Networks ----------------------------------------

def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    On ties the first such index wins (``argmin`` semantics).
    """
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()

def dice_coef(y_true, y_pred, threshold=0.01, verbose=True):
    """Dice coefficient between a reference mask and a thresholded prediction.

    ``y_pred`` is binarised with ``y_pred > threshold``. The score is
    ``2 * sum(y_true * pred) / (sum(y_true) + sum(pred))``.

    Args:
        y_true: reference mask (numeric or boolean).
        y_pred: prediction array, same shape as ``y_true``.
        threshold: values strictly above it count as positive (default 0.01).
        verbose: print the score and its components, as before.

    Returns:
        The Dice score, or 1.0 when both masks are empty (instead of 0/0 = NaN).
    """
    # Work in float: with two boolean arrays NumPy's ``+`` is a logical OR, which
    # would undercount the denominator.
    truth = np.asarray(y_true, dtype=np.float64)
    pred = (np.asarray(y_pred) > threshold).astype(np.float64)
    denominator = np.sum(truth) + np.sum(pred)
    oyt = 1.0 if denominator == 0 else np.sum(2.0 * truth * pred) / denominator
    if verbose:
        print(oyt, np.sum(truth * pred), np.sum(truth), np.sum(pred))
    return oyt

def intersection_loss(y_true, y_pred):
    """Weighted sum of four error terms minus a Dice term.

    Returns ``10 * absolute_true_error + 3 * intersection_true_error
    + absolute_false_error + intersection_false_error - dice_coef1``, each
    evaluated on ``(y_true, y_pred)``. The five helpers are not defined in this
    section (presumably provided by ``New_Layers``; TODO confirm).
    """
    abs_true = absolute_true_error(y_true, y_pred)
    inter_true = intersection_true_error(y_true, y_pred)
    abs_false = absolute_false_error(y_true, y_pred)
    inter_false = intersection_false_error(y_true, y_pred)
    dice = dice_coef1(y_true, y_pred)
    return 10 * abs_true + 3 * inter_true + abs_false + inter_false - dice



def compile_classifier(classifier, generator, loss, optimizer, metrics, loss_weights, name, reverse_freeze=False):
    """Compile a classifier fed by intermediate generator activations plus the raw input.

    The generator is frozen while the model is built and compiled, then made
    trainable again. ``classifier`` is called on the outputs of generator
    layers 8, 16, 24, ... followed by the raw input tensor. ``name`` is unused.
    With ``reverse_freeze`` the layer at ``model.layers[1]`` is also marked
    non-trainable before compiling. Prints a summary and writes
    'DCGAN_model_classifier.png'.
    """
    def set_generator_trainable(flag):
        for gen_layer in generator.layers:
            gen_layer.trainable = flag
        generator.trainable = flag

    set_generator_trainable(False)
    model_input = Input(tuple(generator.layers[0].input_shape[1:]))
    # Every 8th generator layer, dropping index 0.
    tapped_outputs = [gen_layer.output for gen_layer in generator.layers[::8]][1:]
    feature_model = Model(generator.inputs, tapped_outputs)
    prediction = classifier(feature_model(model_input) + [model_input])
    model = Model(model_input, prediction)
    if reverse_freeze:
        model.layers[1].trainable = False
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics, loss_weights=loss_weights)
    model.summary()
    plot_model(model, show_shapes=True, to_file='DCGAN_model_classifier.png')
    set_generator_trainable(True)
    return model


# ---------------------------------------- Utils ----------------------------------------

class RocAucMetricCallback(keras.callbacks.Callback):
    """Report ROC AUC on ``validation_data`` at the end of every epoch.

    Adds ``roc_auc_val`` to the epoch logs (``-inf`` when there is no validation
    data). It also plots the ROC curve for FPR in [0, 0.02] and prints the TPR
    at FPR ~0.001 and ~0.01.

    Args:
        predict_batch_size: batch size used when predicting on the validation data.
        include_on_batch: stored; not read by this callback.
    """

    def __init__(self, predict_batch_size=80, include_on_batch=False):
        super(RocAucMetricCallback, self).__init__()
        self.predict_batch_size = predict_batch_size
        self.include_on_batch = include_on_batch

    def on_train_begin(self, logs=None):
        # Register the metric name in the training params, once.
        if 'roc_auc_val' not in self.params['metrics']:
            self.params['metrics'].append('roc_auc_val')

    def on_epoch_end(self, epoch, logs=None):
        # Fresh dict per call: the former ``logs={}`` default was a single shared
        # dict that this method mutated.
        logs = {} if logs is None else logs
        logs['roc_auc_val'] = float('-inf')
        if not self.validation_data:
            return
        y_pred = self.model.predict(self.validation_data[0], batch_size=self.predict_batch_size)[:, 0]
        y_test = self.validation_data[1]
        fpr, tpr, _ = roc_curve(y_test, y_pred)
        logs['roc_auc_val'] = roc_auc_score(y_test, y_pred)
        print(logs['roc_auc_val'])

        plt.plot([0, 1], [0, 1], 'k--')
        plt.plot(fpr, tpr, label='Keras (area = {:.5f})'.format(logs['roc_auc_val']))
        plt.xlabel('False positive rate')
        plt.ylabel('True positive rate')
        plt.title('ROC curve')
        plt.legend(loc='best')
        plt.ylim(0, 1)
        plt.xlim(0, 0.02)
        plt.show()
        print('Percentage at 0.001: %.5f, Percentage at 0.01: %.5f, Class AUC: %.5f\n' % (tpr[find_nearest(fpr, 0.001)], tpr[find_nearest(fpr, 0.01)], logs['roc_auc_val']))

# ---------------------------------------- Utils ----------------------------------------


# --------------------------------------------- STRuDL ---------------------------------------------

def load_simulation_data(header='', batch_size=128, only_true=False, cross_validation=False):
    """Load the simulated training set from ``../Data``.

    ``x`` gets a singleton axis 1 and 126 trailing zeros on axis 2. ``y`` is
    built from the first channel of the transits file. It is min-max scaled per
    column, every non-zero value is set to 1, and it is then reshaped and padded
    like ``x``. ``has_transit`` is ``params[:, 1] != 0``.

    Args:
        header, batch_size: unused; kept for backward compatibility.
        only_true: keep only samples where ``has_transit`` is True.
        cross_validation: return ``train_test_split(x, y, test_size=0.2,
            random_state=0)`` (four arrays) instead of the triple below.

    Returns:
        ``(x, y, has_transit)``. With ``only_true`` all three are filtered, so
        the labels stay aligned with the samples.
    """
    x = np.load('../Data/total_x_sim_train_3.npy')
    x = np.expand_dims(x, axis=1)
    x = np.pad(x, ((0, 0), (0, 0), (0, 126), (0, 0)), 'constant', constant_values=(0.0, 0.0))
    print('X loaded')
    y = np.load('../Data/total_transits_sim_train_3.npy')
    y = y[:, :, 0]
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(y)
    y = scaler.transform(y)
    y = np.expand_dims(y, axis=2)
    ya = y[:, :, 0]
    # Binarise: after scaling, any value above its column minimum becomes 1.
    ya[ya != 0] = 1
    y = np.expand_dims(ya, axis=2)
    y = np.expand_dims(y, axis=1)
    y = np.pad(y, ((0, 0), (0, 0), (0, 30 + 96), (0, 0)), 'constant', constant_values=(0, 0))
    print(y.shape)
    print('Y loaded')
    print(x.shape, y.shape)
    has_transit = np.load('../Data/total_params_sim_train_3.npy')[:,1] != 0
    print('{} {}'.format(np.sum(has_transit), has_transit.shape))
    print('Params loaded')
    print("Finished Loading Data")
    if only_true:
        x = x[has_transit]
        y = y[has_transit]
        # Filter the labels too so they stay aligned with x / y.
        has_transit = has_transit[has_transit]
    if cross_validation:
        return train_test_split(x, y, test_size=0.2, random_state=0)
    return x, y, has_transit

def p_epsilon_chart(p_test, p_pred, n_points=10000):
    """Detection-rate curve over a grid of relative-error tolerances.

    For each ``epsilon`` in ``np.linspace(0, 1, n_points)``, the rate is the
    fraction of predictions with ``|1 - p_pred / p_test| < epsilon``.
    ``auc_p`` is the rectangle-rule area: the sum of the rates divided by
    ``n_points``.

    Args:
        p_test: reference values (nonzero), array-like.
        p_pred: predicted values, same length as ``p_test``.
        n_points: number of tolerance grid points (default 10000, as before).

    Returns:
        ``(auc_p, percentages, epsilon_range)``: the area, a list of rates and
        the tolerance grid.
    """
    p_test = np.asarray(p_test)
    p_pred = np.asarray(p_pred)
    # Relative error does not depend on epsilon, so compute it once.
    relative_error = np.abs(1 - (p_pred / p_test))
    n_total = float(p_pred.shape[0])
    percentages = []
    auc_p = 0
    epsilon_range = np.linspace(0, 1, n_points)
    for epsilon in epsilon_range:
        percentages.append(float(np.count_nonzero(relative_error < epsilon)) / n_total)
        auc_p += float(percentages[-1]) / n_points
    return auc_p, percentages, epsilon_range

# --------------------------------------------- STRuDL ---------------------------------------------

import pandas as pd

def load_data_now(header='', batch_size=128):
    """Load the full simulated set (``total_x_sim_3`` / ``total_params_sim_3``) from ``../Data``.

    ``header`` and ``batch_size`` are unused. ``x`` gets a singleton axis 1 and
    126 zeros appended on axis 2.

    Returns:
        ``(x, 0, params)``; the middle slot is a constant 0 placeholder.
    """
    padding = ((0, 0), (0, 0), (0, 30 + 96), (0, 0))
    samples = np.pad(np.expand_dims(np.load('../Data/total_x_sim_3.npy'), axis=1),
                     padding, 'constant', constant_values=(0, 0))
    print('X loaded')

    parameters = np.load('../Data/total_params_sim_3.npy')
    print('Params loaded')
    return samples, 0, parameters




In [2]:
def residual_block_v2(inx, filters, kernel, activation, pooling, pool_padding='same'):
    """Pre-activation residual block with squeeze-and-excitation and a strided-conv downsample.

    Args:
        inx: input tensor; its channel count must equal ``filters`` for the residual ``add``.
        filters: number of convolution filters.
        kernel: convolution width along axis 2 (kernels are ``(1, kernel)``).
        activation: activation name passed to ``Activation``.
        pooling: stride along axis 2 of the final (downsampling) convolution.
        pool_padding: padding of the downsampling convolution. With ``'same'``
            the output length is ``ceil(L / pooling)``, which lines up with the
            ``Conv2DTranspose(..., padding='same')`` upsampling in
            ``residual_block_up_v2``. ``'valid'`` restores the old length
            ``floor((L - kernel) / pooling) + 1``.

    Returns:
        ``(x, x_k)``: the downsampled output, and the residual sum before
        downsampling (used as a U-Net skip connection).
    """
    x = BatchNormalization()(inx)  # full pre-activation
    x = Activation(activation=activation)(x)
    x = Conv2D(filters, (1, kernel), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation(activation=activation)(x)
    x = Conv2D(filters, (1, kernel), padding='same')(x)
    x = se_block(x, filters, 4)  # squeeze-and-excitation channel gating

    x_k = add([x, inx])  # residual sum, also returned as the U-Net skip

    x = BatchNormalization()(x_k)
    x = Activation(activation=activation)(x)
    # Strided convolution as learned pooling. The old implicit 'valid' padding
    # shrank the length by an extra kernel-dependent amount, so the down/up
    # paths in Generator no longer matched the skip tensors they are added to.
    x = Conv2D(filters, (1, kernel), strides=(1, pooling), padding=pool_padding)(x)
    return x, x_k

def residual_block_up_v2(inx, x_k, filters, kernel, activation, pooling):
    """Pre-activation residual block with SE and a U-Net skip, followed by transposed-conv upsampling.

    Two BatchNormalization -> Activation -> Conv2D(filters, (1, kernel), 'same')
    units are applied to ``inx`` and gated by ``se_block``. The result is summed
    with the skip tensor ``x_k`` and with ``inx`` (all three must have matching
    shapes), then upsampled by a ``Conv2DTranspose`` with stride ``(1, pooling)``.
    """
    branch = inx
    for _ in range(2):
        branch = BatchNormalization()(branch)
        branch = Activation(activation=activation)(branch)
        branch = Conv2D(filters, (1, kernel), padding='same')(branch)
    branch = se_block(branch, filters, 4)
    merged = add([branch, x_k, inx])
    return Conv2DTranspose(filters, (1, kernel), strides=(1, pooling), padding='same')(merged)

def residual_block(inx, filters, kernel, activation, pooling):
    """Post-activation residual block followed by a non-overlapping strided-conv downsample.

    Conv-BN-Act-Conv-BN is added to ``inx`` (whose channel count must equal
    ``filters``) and activated. That activation is the skip output ``x_k``. A
    ``Conv2D`` with kernel and stride ``(1, pooling)`` then downsamples it.

    Returns:
        ``(downsampled, x_k)``.
    """
    branch = inx
    for unit in range(2):
        branch = Conv2D(filters, (1, kernel), padding='same')(branch)
        branch = BatchNormalization()(branch)
        if unit == 0:
            branch = Activation(activation=activation)(branch)
    summed = add([branch, inx])
    x_k = Activation(activation=activation)(summed)
    downsampled = Conv2D(filters, (1, pooling), strides=(1, pooling))(x_k)
    return downsampled, x_k

def residual_block_up(inx, x_k, filters, kernel, activation, pooling):
    """Post-activation residual block with a U-Net skip, followed by ``UpSampling2D``.

    Conv-BN-Act-Conv-BN is applied to ``inx``, summed with the skip ``x_k`` and
    with ``inx`` (shapes must match), activated, and repeated ``pooling`` times
    along axis 2.
    """
    branch = inx
    for unit in range(2):
        branch = Conv2D(filters, (1, kernel), padding='same')(branch)
        branch = BatchNormalization()(branch)
        if unit == 0:
            branch = Activation(activation=activation)(branch)
    merged = add([branch, x_k, inx])
    activated = Activation(activation=activation)(merged)
    return UpSampling2D((1, pooling))(activated)

def se_block(inx, ch, ratio=16):
    """Squeeze-and-excitation: rescale each channel of ``inx`` by a learned gate in (0, 1).

    The gate is GlobalAveragePooling2D -> Dense(ch // ratio, relu) ->
    Dense(ch, sigmoid), multiplied back onto ``inx``. ``ch`` must equal the
    channel count of ``inx``.
    """
    x = GlobalAveragePooling2D()(inx)
    # At least one bottleneck unit, so ch < ratio cannot produce an empty Dense layer.
    x = Dense(max(1, ch // ratio), activation='relu')(x)
    x = Dense(ch, activation='sigmoid')(x)
    # ``multiply`` is Keras' functional merge helper and takes the input list
    # directly; ``multiply()`` with no arguments raises TypeError. The (batch, ch)
    # gate has lower rank than ``inx``; this relies on Keras' merge layers
    # expanding it to broadcast (TODO confirm for the Keras version in use).
    return multiply([inx, x])


In [None]:
def Generator(input_shape):
    """Build the U-Net style generator for inputs of shape ``input_shape``.

    The encoder has eight ``residual_block_v2`` stages (4 x 32 filters, then
    4 x 64), each with stride 2 along axis 2. The bottleneck is a conv stack
    with two concatenated skip branches (m1, m2) and Dropout(0.25). The decoder
    has nine ``residual_block_up_v2`` stages fed with the encoder skips (the
    last with stride 1). A final ``residual_block_v2`` and a 1x1 sigmoid conv
    named 'gen_output' close the network.

    Returns:
        ``Model(gen_inputs, gen_outputs, name='generator')``. The name matters:
        ``IWGAN_TrainingScheme.compile_generator`` keys its losses on 'generator'.
    """
    gen_inputs = Input(shape=input_shape)
    x = Conv2D(32, (1, 5), padding='same')(gen_inputs)
    x, x_k_1 = residual_block_v2(x, 32, 5, 'relu', 2)
    x, x_k_2 = residual_block_v2(x, 32, 5, 'relu', 2)
    x, x_k_3 = residual_block_v2(x, 32, 5, 'relu', 2)
    x, x_k_4 = residual_block_v2(x, 32, 5, 'relu', 2)

    # Widen to 64 channels so the residual adds of the next stages match.
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (1, 5), padding='same')(x)

    x, x_k_5 = residual_block_v2(x, 64, 5, 'relu', 2)
    x, x_k_6 = residual_block_v2(x, 64, 5, 'relu', 2)
    x, x_k_7 = residual_block_v2(x, 64, 5, 'relu', 2)
    y, x_k_8 = residual_block_v2(x, 64, 5, 'relu', 2)  # y: encoder output, reused as the first decoder skip

    # ----- Bottleneck -----
    x = BatchNormalization()(y)
    x = Activation('relu')(x)
    x = Conv2D(64, (1, 5), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    decoder = x
    x = Conv2D(64, (1, 5), padding='same')(decoder)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    m1 = x

    x = Conv2D(128, (1, 5), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    m2 = x

    x = Conv2D(256, (1, 5), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Dropout(0.25)(x)

    x = Conv2D(256, (1, 5), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = concatenate([m2, x])
    x = Conv2D(128, (1, 5), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = concatenate([m1, x])
    x = Conv2D(64, (1, 5), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # ----- Decoder -----
    x = residual_block_up_v2(x, y, 64, 5, 'relu', 2)
    x = residual_block_up_v2(x, x_k_8, 64, 5, 'relu', 2)
    x = residual_block_up_v2(x, x_k_7, 64, 5, 'relu', 2)
    x = residual_block_up_v2(x, x_k_6, 64, 5, 'relu', 2)
    x = residual_block_up_v2(x, x_k_5, 64, 5, 'relu', 2)

    # Narrow back to 32 channels to match the first four skips.
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(32, (1, 5), padding='same')(x)

    x = residual_block_up_v2(x, x_k_4, 32, 5, 'relu', 2)
    x = residual_block_up_v2(x, x_k_3, 32, 5, 'relu', 2)
    x = residual_block_up_v2(x, x_k_2, 32, 5, 'relu', 2)
    x = residual_block_up_v2(x, x_k_1, 32, 5, 'relu', 1)

    # Only the full-resolution residual sum is kept; the block's strided output is discarded.
    _, x = residual_block_v2(x, 32, 5, 'relu', 1)

    gen_outputs = Conv2D(1, (1, 1), activation='sigmoid', padding='same', name='gen_output')(x)
    # The model was built but never returned, so callers received None.
    return Model(gen_inputs, gen_outputs, name='generator')


In [None]:
def Discriminator(input_shape):
    """Build the critic (discriminator) graph.

    Ten stride-(1, 2) stages of ``Conv2D(k, (1, 5)) -> BatchNormalization ->
    LeakyReLU(0.2)`` are stacked, with ``Dropout(0.25)`` after every stage
    except the last. The feature map is then global-max-pooled and passed
    through two ``Dense(128, relu)`` layers (each followed by dropout). A final
    single linear unit named ``'dis_output'`` produces the score.

    Layers are created in exactly the same order as the previous unrolled
    version, so auto-generated layer names and weight ordering are unchanged.

    Args:
        input_shape: shape passed to ``Input`` (without the batch axis).

    Returns:
        ``(dis_inputs, dis_outputs)``: the input tensor and the un-activated
        output tensor. The previous version built these tensors but returned
        nothing, so callers could not get at them.
    """
    dis_inputs = Input(shape=input_shape)

    # Filter count of each strided conv stage, in creation order.
    stage_filters = [64, 64, 64, 64, 64, 128, 128, 128, 128, 256]
    last_stage = len(stage_filters) - 1

    x = dis_inputs
    for stage, n_filters in enumerate(stage_filters):
        conv_kwargs = {'strides': (1, 2), 'padding': 'same'}
        if stage > 0:
            # The first conv keeps Keras' default initializer; later ones use He init.
            conv_kwargs['kernel_initializer'] = 'he_normal'
        x = Conv2D(n_filters, (1, 5), **conv_kwargs)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
        if stage < last_stage:
            # The final conv stage goes straight into global pooling without dropout.
            x = Dropout(0.25)(x)

    x = GlobalMaxPooling2D()(x)
    x = Dropout(0.25)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.25)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.25)(x)
    dis_outputs = Dense(1, kernel_initializer='he_normal', name='dis_output')(x)
    return dis_inputs, dis_outputs


In [None]:
def Classifier(input_shape):
    """Build the classifier graph that fuses encoder skip features.

    The classifier runs its own residual encoder over ``class_in``. After each
    of the first eight stages it adds (``Add``) one externally supplied feature
    tensor from ``class_inputs``. Three more residual blocks and two conv
    stages follow, then global max pooling, two ``Dense(256, relu)`` layers
    with ``Dropout(0.4)``, and a single sigmoid output unit.

    Args:
        input_shape: shape passed to ``Input`` for ``class_in`` (no batch axis).

    Returns:
        ``(class_in, class_inputs, class_out)``:
        - the image input tensor;
        - the list of 8 skip-feature ``Input`` tensors;
        - the sigmoid output tensor.
        The previous version built these but returned nothing.
    """
    class_in = Input(shape=input_shape)
    # The skip inputs are added to features derived from ``class_in``, so their
    # spatial shape must follow ``class_in``. The old code read a global
    # ``gen_inputs`` here instead.
    n_rows, n_cols = K.int_shape(class_in)[1:3]
    class_lengths = [32, 32, 32, 32, 64, 64, 64, 64]
    # NOTE(review): floor division assumes each stride-2 residual block exactly
    # halves the width, i.e. the input width is divisible by 2**7 -- confirm.
    class_inputs = [Input(shape=(n_rows, n_cols // (2 ** i), length))
                    for i, length in enumerate(class_lengths)]

    z = Conv2D(32, (1, 5), padding='same')(class_in)
    z = Add()([z, class_inputs[0]])

    # 32-channel stages, each fused with the matching skip input.
    for i in range(1, 4):
        z, _skip = residual_block_v2(z, 32, 5, 'relu', 2)
        z = Add()([z, class_inputs[i]])

    z = BatchNormalization()(z)
    z = Activation('relu')(z)
    z = Conv2D(64, (1, 5), padding='same')(z)

    # 64-channel stages, each fused with the matching skip input.
    for i in range(4, 8):
        z, _skip = residual_block_v2(z, 64, 5, 'relu', 2)
        z = Add()([z, class_inputs[i]])

    # Extra downsampling residual blocks with no skip fusion.
    for _block in range(3):
        z, _skip = residual_block_v2(z, 64, 5, 'relu', 2)

    z = BatchNormalization()(z)
    z = Activation('relu')(z)
    z = Conv2D(128, (1, 5), kernel_initializer='he_normal', padding='same')(z)
    z = BatchNormalization()(z)
    z = LeakyReLU(0.2)(z)
    z = Dropout(0.25)(z)

    z = Conv2D(256, (1, 5), kernel_initializer='he_normal', padding='same')(z)
    z = BatchNormalization()(z)
    z = LeakyReLU(0.2)(z)
    z = Dropout(0.25)(z)

    z = GlobalMaxPooling2D()(z)
    z = Dense(256, activation='relu')(z)
    z = Dropout(0.4)(z)
    z = Dense(256, activation='relu')(z)
    z = Dropout(0.4)(z)

    class_out = Dense(1, activation='sigmoid')(z)
    return class_in, class_inputs, class_out


In [44]:
class log_results_multi(keras.callbacks.Callback):
    """Keras callback that periodically plots and scores the generator.

    Every ``logging_frequency`` epochs (0 disables it), the callback:

    * saves a three-panel scatter plot for ``x_test[myslice]``: input, target
      and generator output. The file is
      ``TrainingLogs/<network_name>/img_<epoch>_test.png``.
    * scores ``model.gen_dis()`` on the full ``(x_test, y_test)`` with
      ``dice_coef``.
    * appends that score to ``TrainingLogs/<network_name>/<log_path>``.
    * saves the generator to ``TrainingLogs/<network_name>/<save_path>``
      whenever the score beats the best seen so far.
    """

    # Number of leading samples plotted/scored and the time span they cover.
    # 20610 samples x 2 min = 28.625, so presumably days -- TODO confirm.
    _N_SAMPLES = 20610
    _TIME_SPAN = 28.625

    def __init__(self, wgan_model=None, logging_frequency=0, log_path='', x_test=None, y_test=None, myslice=slice(0,1), network_name='', save_path='best_generator_test_sectors.h5'):
        # Initialise Keras' Callback bookkeeping before setting our own attributes.
        super(log_results_multi, self).__init__()
        self.LOGDIR = 'TrainingLogs'
        self.network_name = network_name
        out_dir = os.path.join(self.LOGDIR, self.network_name)
        # makedirs also creates LOGDIR itself when needed.
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        self.logging_frequency = logging_frequency
        self.log_path = log_path
        self.model = wgan_model
        self.save_path = save_path
        # NOTE(review): self.x / self.y are not used anywhere in this class, yet
        # loading them costs a full data load per callback instance -- confirm
        # whether any external code needs them before removing.
        self.x, self.y, _ = load_simulation_data()
        self.x_test, self.y_test = x_test, y_test
        self.last_best_generator, self.last_best_combined = 0, 0
        self.myslice = myslice

    def _out_path(self, filename):
        """Return ``filename`` joined onto this run's log directory."""
        return os.path.join(self.LOGDIR, self.network_name, filename)

    def on_epoch_end(self, epoch, logs=None):
        """Plot the sample prediction and update the score every ``logging_frequency`` epochs."""
        if self.logging_frequency == 0 or epoch % self.logging_frequency != 0:
            return

        n = self._N_SAMPLES
        t = np.linspace(0, self._TIME_SPAN, n)
        x_sample = self.x_test[self.myslice]
        y_sample = self.y_test[self.myslice]

        plt.figure(1, figsize=(15,10))
        plt.subplot(311)
        plt.scatter(t, x_sample[0, 0, :n, 0], s=0.5)
        plt.title("Simulation Output, " + str(epoch))

        plt.subplot(312)
        plt.scatter(t, y_sample[0, 0, :n, 0], s=0.5)
        plt.title("Simulation Output, " + str(epoch))

        final_res = self.model.gen_dis().predict_on_batch(x_sample)
        plt.subplot(313)
        plt.scatter(t, final_res[0][0, 0, :n, 0], s=0.5)
        plt.title("Neural Net Output,"+ str(final_res[1][0,0]))
        plt.tight_layout()
        plt.savefig(self._out_path('img_' + str(epoch) + '_test.png'))
        plt.clf()
        plt.close()

        sample_dice = dice_coef(y_sample[0, 0, :n, 0], final_res[0][0, 0, :n, 0])
        print("epoch: %d current pic dice coeff: %f" % (epoch, sample_dice))
        self.get_score(epoch, self.model is not None)

    def get_score(self, i_in, gen_test = False):
        """Score the generator on the whole test set, log it, and checkpoint on improvement.

        Args:
            i_in: epoch number used in the log line.
            gen_test: when False, nothing is done.
        """
        import warnings

        if not gen_test:
            return

        n = self._N_SAMPLES
        generator = self.model.gen_dis()
        transits = generator.predict(self.x_test, verbose=1)[0][:, 0, :n, 0]
        transits_ref = self.y_test[:, 0, :n, 0]
        print(transits.shape, transits_ref.shape)

        # Silence numeric warnings only while computing the score. The old code
        # used np.warnings.filterwarnings('ignore'), which disabled warnings for
        # the whole process; np.warnings no longer exists in NumPy >= 1.24.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            dice_coeff = dice_coef(transits_ref, transits)

        if dice_coeff > self.last_best_generator:
            generator.save(self._out_path(self.save_path))
            self.last_best_generator = dice_coeff

        # Context manager so the log file handle is closed after every write.
        with open(self._out_path(self.log_path), 'a') as log_file:
            log_file.write('epoch: %d ,dice coef: %f\n' % (i_in, dice_coeff))
        print('epoch: %d ,dice coef: %f\n' % (i_in, dice_coeff))

#print 'loaded GAN'
#model = GAN(gen_inputs, gen_outputs, Adam(1e-5, beta_1=0.5, beta_2=0.9), intersection_loss, dis_inputs, dis_outputs, class_in, class_out, dis_metrics=['accuracy'], gen_dis_loss_ratio = 0.75, gen_metrics=[dice_coef1, masked_mse])
#model.gen_dis().load_weights('./best_generator_test.h5')

In [48]:

class SpaceGanFCN(object):
    """Wrapper bundling the GAN (generator + critic) and the classifier.

    NOTE(review): the constructor reads several names from module scope, so
    all of them must exist before instantiation:
    - graph tensors ``gen_inputs``, ``gen_outputs``, ``dis_inputs``,
      ``dis_outputs``, ``class_in``, ``class_inputs``, ``class_out``;
    - helpers ``GAN``, ``compile_classifier``, ``intersection_loss``,
      ``dice_coef1``, ``masked_mse``.
    """

    # Leading samples shown in plot_prediction and the time span they cover.
    _N_PLOT_SAMPLES = 20610
    _PLOT_TIME_SPAN = 28.625

    def __init__(self, gen_weights=None, class_weights=None, save_gen_weights=None, save_class_weights=None, lr=1e-4):
        """Build the GAN and the compiled classifier, optionally loading weights.

        Args:
            gen_weights: optional weights file loaded into ``model.gen_dis()``.
            class_weights: optional weights file loaded into the compiled classifier.
            save_gen_weights: currently unused; kept for interface compatibility.
            save_class_weights: currently unused; kept for interface compatibility.
            lr: Adam learning rate for the GAN (``beta_1=0.5``, ``beta_2=0.9``).
        """
        model = GAN(gen_inputs, gen_outputs, Adam(lr, beta_1=0.5, beta_2=0.9), intersection_loss,
                    dis_inputs, dis_outputs, class_in, class_out,
                    dis_metrics=['accuracy'], gen_dis_loss_ratio=0.75,
                    gen_metrics=[dice_coef1, masked_mse])
        if gen_weights:
            model.gen_dis().load_weights(gen_weights)
        print('loaded GAN')
        self.model = model
        self.gen_dis = model.gen_dis
        self.classifier = Model(class_inputs + [class_in], class_out, 'classifier')
        self.classifier_model = compile_classifier(self.classifier, model.encoder(), 'binary_crossentropy',
                                                   Adam(), ['accuracy'], None, 'classifier',
                                                   reverse_freeze=True)
        if class_weights:
            self.classifier_model.load_weights(class_weights)

    def train(self, x_train, y_train, x_test, y_test, log_files=('sector_training.txt', 'sector_training.txt'), epochs=100001):
        """Fit the GAN with a logging/checkpoint callback for each of train and test.

        Args:
            x_train, y_train: training inputs and targets.
            x_test, y_test: held-out inputs and targets.
            log_files: ``(train_log, test_log)`` file names inside ``TrainingLogs/FCN``.
                Each callback saves its best generator next to its log, under the
                same base name with an ``.h5`` extension.
            epochs: number of epochs passed to ``GAN.fit``.

        NOTE(review): with the default ``log_files`` both callbacks share one
        log file and one ``sector_training.h5`` checkpoint, so the test-set
        checkpoint can overwrite the train-set one. Pass two distinct names to
        keep them apart.
        """
        import os.path

        def checkpoint_name(log_file):
            # 'run.txt' -> 'run.h5'. Unlike slicing off 4 chars, this handles any extension length.
            return os.path.splitext(log_file)[0] + '.h5'

        train_log, test_log = log_files[0], log_files[1]
        log_callback_train = log_results_multi(self.model, 100, train_log, network_name='FCN',
                                               x_test=x_train, y_test=y_train,
                                               save_path=checkpoint_name(train_log))
        log_callback_test = log_results_multi(self.model, 100, test_log, network_name='FCN',
                                              x_test=x_test, y_test=y_test,
                                              save_path=checkpoint_name(test_log))
        self.model.fit(x_train, y_train, gen_overtraining_multiplier=1, verbose=False, epochs=epochs,
                       batch_size=32, callbacks=[log_callback_train, log_callback_test])

    def generate_on_batch(self, batch):
        """Run ``gen_dis`` on a single batch and return its raw outputs."""
        return self.model.gen_dis().predict_on_batch(batch)

    def generate(self, batch):
        """Run ``gen_dis`` over ``batch`` (Keras batches it internally)."""
        return self.model.gen_dis().predict(batch)

    def classify(self, batch):
        """Return the compiled classifier's predictions for ``batch``."""
        return self.classifier_model.predict(batch)

    def plot_prediction(self, x, y):
        """Show input, target and generator output for the first sample of ``x``/``y``."""
        n = self._N_PLOT_SAMPLES
        t = np.linspace(0, self._PLOT_TIME_SPAN, n)
        plt.figure(1, figsize=(10,10))
        plt.subplot(311)
        plt.scatter(t, x[0, 0, :n, 0], s=0.5)
        plt.title("Simulation Output, " + str(-1))

        plt.subplot(312)
        plt.scatter(t, y[0, 0, :n, 0], s=0.5)
        plt.title("Simulation Output, " + str(-1))

        final_res = self.model.gen_dis().predict_on_batch(x)
        plt.subplot(313)
        plt.scatter(t, final_res[0][0, 0, :n, 0], s=0.5)
        plt.title("Neural Net Output,"+ str(final_res[1][0,0]))
        plt.tight_layout()
        plt.show()

    

# ---------- Tests ----------

# x_train, y_train, x_test, y_test = load_nasa_tces()
#x, y, has_transits = load_simulation_data(only_true=True)

# model = SpaceGanFCN(gen_weights='./best_generator_test.h5')
# model.train(x_train, y_train, x_test, y_test, log_files=['sector_training_train.txt', 'sector_training_test.txt'])
# final_res = model.generate_on_batch(x_test[myslice])

# model = SpaceGanFCN(gen_weights='./best_generator_test_sectors.h5', class_weights='best_classifier_test_generator_based.h5')
# model.train(x_train, y_train, x_test, y_test, log_files=['sector_training_train.txt', 'sector_training_test.txt'])

# ---------- Tests ----------

