In [None]:
import sys, os
sys.path.insert(1, os.path.join(sys.path[0], '../modules'))
from data_manipulation import *
# NOTE(review): `tf` is presumably provided by the `data_manipulation` star
# import above — confirm. Resetting the default graph presumably lets this cell
# be re-run without "variable already exists" errors, since generator() creates
# its 'generator' scope with reuse=False.
tf.reset_default_graph()


import numpy as np

class Data_Creator_GAN(object):
    """Produce batches of flatness phases for GAN training.

    A "flatness" combines the data of two redundant baselines (one of them
    conjugated) with gain factors; see ``_flatness`` inside ``_gen_data``.
    ``gen_data`` starts a worker thread that appends one batch to an internal
    FIFO; ``get_data`` pops the oldest batch, waiting for the worker if the
    FIFO is empty.

    Args:
        num_flatnesses: number of flatnesses drawn per batch.
        bl_data: dict from a baseline key ``(a, b)`` to complex data that
            supports ``.conjugate()`` (e.g. a numpy array).
        bl_dict: dict from a unique-baseline key to a collection of at least
            two redundant baseline keys present in ``bl_data``.
        gains: dict from ``(a, 'x')`` (``a`` an element of a baseline key,
            presumably an antenna index) to a complex gain.
    """

    def __init__(self,
                 num_flatnesses,
                 bl_data = None,
                 bl_dict = None,
                 gains = None):

        self._num = num_flatnesses

        self._bl_data = bl_data
        self._bl_data_c = None      # conjugated bl_data, built on first use

        self._bl_dict = bl_dict

        self._gains = gains
        self._gains_c = None        # conjugated gains, built on first use

        self._epoch_batch = []      # FIFO of generated batches
        self._thread = None         # most recent worker started by gen_data()

    def gen_data(self):
        """Start a background thread that appends one new batch."""
        self._thread = Thread(target = self._gen_data, args=())
        self._thread.start()

    def get_data(self, timeout = 10):
        """Pop and return the oldest generated batch.

        If no batch is ready, wait up to ``timeout`` seconds for the worker
        started by ``gen_data``.

        Raises:
            IndexError: if no batch is available after waiting (e.g. the
                worker is still running, failed, or was never started).
        """
        if not self._epoch_batch and self._thread is not None:
            self._thread.join(timeout)

        if not self._epoch_batch:
            raise IndexError(
                'no data batch available after waiting {} s; call gen_data() '
                'first or increase timeout'.format(timeout))

        return self._epoch_batch.pop(0)

    def _gen_data(self):
        """Generate one batch of ``self._num`` flatnesses and append it.

        The appended batch is ``np.angle`` of the flatnesses reshaped into
        rows of 1024 values, then mapped from [-pi, pi] to [0, 1].
        """
        # Phase in [-pi, pi] -> [0, 1]. Inverse transform: x * 2*pi - pi.
        angle_tx = lambda x: (np.asarray(x) + np.pi) / (2. * np.pi)

        if self._bl_data_c is None:
            self._bl_data_c = {key: self._bl_data[key].conjugate() for key in self._bl_data}

        if self._gains_c is None:
            self._gains_c = {key: self._gains[key].conjugate() for key in self._gains}

        def _flatness(seps):
            """Create a flatness from a given pair of separations, their data & their gains."""

            a, b = seps[0][0], seps[0][1]
            c, d = seps[1][0], seps[1][1]

            return self._bl_data[seps[0]]   * self._gains_c[(a,'x')] * self._gains[(b,'x')] * \
                   self._bl_data_c[seps[1]] * self._gains[(c,'x')]   * self._gains_c[(d,'x')]

        inputs = []
        for _ in range(self._num):
            # random.sample() requires a sequence (sets / dict views raise
            # TypeError on Python 3.11+), so materialise the keys as lists.
            unique_baseline = random.sample(list(self._bl_dict), 1)[0]
            two_seps = random.sample(list(self._bl_dict[unique_baseline]), 2)

            inputs.append(_flatness(two_seps))

        inputs = np.angle(np.array(inputs).reshape(-1, 1024))

        self._epoch_batch.append(angle_tx(inputs))


def model_opt(d_loss, g_loss):
    """
    Build one Adam training op per network.

    Trainable variables are split by name prefix ('discriminator' vs
    'generator'). Both ops are created under a control dependency on the
    UPDATE_OPS collection, so the batch-normalization update ops run with
    every training step.

    Returns:
        (d_train_opt, g_train_opt)
    """
    trainable = tf.trainable_variables()

    def _vars_with_prefix(prefix):
        # Select the trainable variables that belong to one network's scope.
        return [v for v in trainable if v.name.startswith(prefix)]

    disc_vars = _vars_with_prefix('discriminator')
    gen_vars = _vars_with_prefix('generator')

    bn_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(bn_updates):
        d_train_opt = tf.train.AdamOptimizer().minimize(d_loss, var_list=disc_vars)
        g_train_opt = tf.train.AdamOptimizer().minimize(g_loss, var_list=gen_vars)

    return d_train_opt, g_train_opt

def model_inputs():
    """
    Create the graph's input placeholders.

    Returns:
        inputs_real: float32 placeholder of shape (None, 60, 1024, 1) for real samples.
        inputs_z: float32 placeholder of shape (None, 4096) for generator noise.
    """
    real_shape = (None, 60, 1024, 1)
    noise_shape = (None, 4096)

    inputs_real = tf.placeholder(tf.float32, shape=real_shape, name='input_real')
    inputs_z = tf.placeholder(tf.float32, shape=noise_shape, name='input_z')

    return inputs_real, inputs_z

def discriminator(inputs_real, reuse=False, is_training = True):
    """
    Create the discriminator network.

    Four conv -> batch-norm -> leaky-ReLU stages (all stride (2, 2), 'SAME'
    padding), followed by a single-unit dense layer.

    Args:
        inputs_real: image batch; model_inputs() uses shape (None, 60, 1024, 1).
        reuse: reuse the variables of the 'discriminator' scope (set True for
            the second call, e.g. on generated samples).
        is_training: passed to batch normalization as ``training``.

    Returns:
        (disc_out, logits): the sigmoid probability that each input is real,
        and the raw logits it was computed from, both of shape (batch, 1).
    """

    with tf.variable_scope('discriminator', reuse = reuse):
        with tf.variable_scope('conv_1'):
            conv_1 = tf.layers.conv2d(inputs_real, 2, (3, 3), (2, 2), 'SAME')
            conv_1 = tf.layers.batch_normalization(conv_1, training = is_training)
            conv_1 = tf.nn.leaky_relu(conv_1)
            # for a 60x1024 input: 30x512, 2 filters

        with tf.variable_scope('conv_2'):
            conv_2 = tf.layers.conv2d(conv_1, 2, (5, 5), (2, 2), 'SAME')
            conv_2 = tf.layers.batch_normalization(conv_2, training = is_training)
            conv_2 = tf.nn.leaky_relu(conv_2)
            # 15x256, 2 filters

        with tf.variable_scope('conv_3'):
            conv_3 = tf.layers.conv2d(conv_2, 4, (3, 3), (2, 2), 'SAME')
            conv_3 = tf.layers.batch_normalization(conv_3, training = is_training)
            conv_3 = tf.nn.leaky_relu(conv_3)
            # 8x128, 4 filters

        with tf.variable_scope('conv_4'):
            conv_4 = tf.layers.conv2d(conv_3, 4, (5, 5), (2, 2), 'SAME')
            conv_4 = tf.layers.batch_normalization(conv_4, training = is_training)
            conv_4 = tf.nn.leaky_relu(conv_4)
            # 4x64, 4 filters

        with tf.variable_scope('output'):
            logits = tf.layers.dense(tf.layers.flatten(conv_4), 1)
            # A single logit needs a sigmoid: softmax over one unit is always 1.0.
            # This also matches the sigmoid cross-entropy used in model_loss().
            disc_out = tf.nn.sigmoid(logits)
    return disc_out, logits


#inputs_z = tf.placeholder(tf.float32, (1, 2048), name='input_z')
def generator(z, is_training = True):
    """
    Create the generator network.

    Projects ``z`` to 2048 units and treats that as a 1x1 map with 2048
    channels. Four transposed-conv -> batch-norm -> leaky-ReLU stages then
    widen it: width 1 -> 3 -> 6 -> 12 -> 60, channels 2048 -> 1792 -> 1536 ->
    1280 -> 1024. The final map is reshaped to (batch, 60, 1024, 1).

    Args:
        z: noise batch (model_inputs() uses shape (None, 4096)).
        is_training: batch-norm ``training`` flag. When it is not True, the
            variables of the 'generator' scope are reused, so the network
            must first have been built with is_training=True.

    Returns:
        Tensor of shape (batch, 60, 1024, 1).
    """
    # NOTE(review): the output is an unbounded leaky-ReLU, while real samples
    # are mapped into [0, 1] by Data_Creator_GAN; consider a bounded output.

    # (variable scope, filters, kernel size, strides) per transposed-conv stage
    stages = (
        ('deconv_1', 1792, (5, 5), (1, 3)),
        ('deconv_2', 1536, (3, 3), (1, 2)),
        ('deconv_3', 1280, (5, 5), (1, 2)),
        ('deconv_4', 1024, (3, 3), (1, 5)),
    )
    reuse_vars = not (is_training == True)

    with tf.variable_scope('generator', reuse = reuse_vars):
        net = z
        for index, (scope_name, filters, kernel, strides) in enumerate(stages):
            with tf.variable_scope(scope_name):
                if index == 0:
                    # Dense projection of the noise into a 1x1x2048 map.
                    net = tf.layers.dense(net, 2048)
                    net = tf.reshape(net, (-1, 1, 1, 2048))
                net = tf.layers.conv2d_transpose(net, filters, kernel, strides, 'SAME')
                net = tf.layers.batch_normalization(net, training = is_training)
                net = tf.nn.leaky_relu(net)

        with tf.variable_scope('out'):
            gen_out = tf.reshape(net, (-1, 60, 1024, 1))

    return gen_out



def model_loss(input_real, input_z):
    """
    Build the discriminator and generator losses.

    The discriminator is applied to the real inputs and, with shared
    variables, to the generator's output. All terms are mean sigmoid
    cross-entropies. "Real" targets are smoothed to 0.9 for both the
    discriminator and the generator; "fake" targets are 0.

    Returns:
        (d_loss, g_loss): scalar loss tensors.
    """
    label_smoothing = 0.9

    fake_samples = generator(input_z)
    d_model_real, d_logits_real = discriminator(input_real)
    d_model_fake, d_logits_fake = discriminator(fake_samples, reuse=True)

    def _mean_xent(logits, targets):
        # Mean sigmoid cross-entropy between logits and target labels.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=targets))

    d_loss_real = _mean_xent(d_logits_real, tf.ones_like(d_model_real) * label_smoothing)
    d_loss_fake = _mean_xent(d_logits_fake, tf.zeros_like(d_model_fake))
    d_loss = d_loss_real + d_loss_fake

    g_loss = _mean_xent(d_logits_fake, tf.ones_like(d_model_fake) * label_smoothing)

    return d_loss, g_loss

def train(num_epochs, batch_size, bl_data, bl_dict, gains):
    """
    Train the GAN and draw samples from the trained generator.

    Each epoch pulls one batch of flatness rows from a background
    ``Data_Creator_GAN``, prefetching the next epoch's batch. It slices the
    rows into chunks of ``batch_size`` and reshapes each chunk to
    (-1, 60, 1024, 1) images. For every chunk it runs one discriminator step
    and one generator step, using one noise vector per real image.

    Args:
        num_epochs: number of data batches to train on.
        batch_size: rows (of 1024 values) per training step. Must be a
            positive multiple of 60 so each chunk reshapes into whole images.
        bl_data, bl_dict, gains: forwarded to ``Data_Creator_GAN``.

    Returns:
        (samples, batch_real): the generator output (inference mode) for the
        last noise batch, and the last real batch trained on.

    Raises:
        ValueError: if ``batch_size`` is not a positive multiple of 60, or if
            no training step was run (``num_epochs`` <= 0 or every epoch had
            fewer than ``batch_size`` rows).
    """
    rows_per_image = 60
    row_width = 1024
    noise_dim = 4096

    if batch_size <= 0 or batch_size % rows_per_image:
        raise ValueError(
            'batch_size must be a positive multiple of {}, got {}'.format(
                rows_per_image, batch_size))

    input_real, input_z = model_inputs()
    d_loss, g_loss = model_loss(input_real, input_z)
    d_opt, g_opt = model_opt(d_loss, g_loss)

    train_batcher = Data_Creator_GAN(1,
                                     bl_data,
                                     bl_dict,
                                     gains)
    train_batcher.gen_data()

    batch_real = None
    batch_z = None

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        for epoch in range(num_epochs):
            print('epoch {}'.format(epoch))

            batches_real = train_batcher.get_data()
            if epoch + 1 < num_epochs:
                # Prefetch the next epoch's data while this one trains.
                train_batcher.gen_data()

            num_steps = batches_real.shape[0] // batch_size
            if num_steps == 0:
                print('epoch {}: only {} rows (< batch_size {}), skipped'.format(
                    epoch, batches_real.shape[0], batch_size))
                continue

            for j in range(num_steps):
                batch_real = batches_real[j*batch_size:(j + 1)*batch_size].reshape(
                    -1, rows_per_image, row_width, 1)
                # One noise vector per real image so real and fake batches match.
                batch_z = np.random.uniform(0, 1, size=(batch_real.shape[0], noise_dim))

                # Feed both placeholders to both ops: the UPDATE_OPS dependency
                # pulls in the batch-norm updates of both networks.
                feed = {input_real: batch_real, input_z: batch_z}
                sess.run(d_opt, feed_dict=feed)
                sess.run(g_opt, feed_dict=feed)

            # At the end of every epoch, report the losses on its last batch.
            train_loss_d = d_loss.eval({input_z: batch_z, input_real: batch_real})
            train_loss_g = g_loss.eval({input_z: batch_z})

            print("Epoch {}...".format(epoch),
                  "Discriminator Loss: {:.4f}...".format(train_loss_d),
                  "Generator Loss: {:.4f}".format(train_loss_g))

        if batch_z is None:
            raise ValueError('no training step was run; check num_epochs and '
                             'that each data batch has at least batch_size rows')

        samples = sess.run(generator(input_z, False),
                           feed_dict={input_z: batch_z})
    return samples, batch_real

# NOTE(review): training_baselines_data, training_redundant_baselines_dict and
# gains are presumably defined in earlier notebook cells — confirm.
test = train(num_epochs=10,
             batch_size=60,
             bl_data=training_baselines_data,
             bl_dict=training_redundant_baselines_dict,
             gains=gains)