### Adapted from the TensorFlow DCGAN tutorial: https://www.tensorflow.org/tutorials/generative/dcgan

In [None]:
import models
import tensorflow as tf 
import os
import time
import numpy as np

In [None]:
# Shuffle-buffer size used when building train_dataset. 60000 looks carried over
# from the DCGAN tutorial — NOTE(review): confirm it suits this dataset's size.
BUFFER_SIZE = 60000
# Number of examples per batch passed to .batch() when building train_dataset.
BATCH_SIZE = 64

## Data preprocessing

In [None]:
# TODO
def serialize_example(signal, label):
    """Serialize one (signal, label) event as a ``tf.train.Example`` byte string.

    Args:
        signal: Array-like of numeric samples (ndarray, list, tuple, ...). It is
            flattened and stored as the 'signal' float_list feature.
        label: Integral class label -- a Python int, a NumPy integer/float
            scalar, or a 0-d array such as ``np.array(1.0)``. Stored as the
            single value of the 'label' int64_list feature.

    Returns:
        The serialized ``tf.train.Example`` proto as ``bytes``.

    Raises:
        ValueError: If ``label`` has a fractional part and therefore cannot be
            stored as an int64 without silently truncating it.
    """
    # Int64List rejects float values, so coerce numeric labels explicitly --
    # but refuse to silently truncate something like 1.5 down to 1.
    label_value = int(label)
    if label_value != label:
        raise ValueError(f"label must be integral, got {label!r}")

    # np.asarray also accepts plain lists/tuples, not only ndarrays.
    values = np.asarray(signal, dtype=float).ravel().tolist()

    feature = {
        'signal': tf.train.Feature(float_list=tf.train.FloatList(value=values)),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label_value])),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# Directory holding one TFRecord file per subject. It must match the glob
# ('data/tfrecords/*.tfrecord') used when building the training dataset.
TFRECORD_DIR = os.path.join("data", "tfrecords")

os.makedirs(TFRECORD_DIR, exist_ok=True)
if not any(name.endswith(".tfrecord") for name in os.listdir(TFRECORD_DIR)):  # no tf records yet
    for subject in sorted(os.listdir("./data")):  # iterate over subjects
        subject_path = os.path.join("./data", subject)
        # TFRECORD_DIR lives inside ./data, so skip it (and any other non-file).
        if not os.path.isfile(subject_path):
            continue
        record_path = os.path.join(TFRECORD_DIR, f"{subject}.tfrecord")
        with tf.io.TFRecordWriter(record_path) as writer, open(subject_path) as data_file:
            for event in data_file:  # one comma-separated event per line
                if not event.strip():  # tolerate blank lines instead of crashing
                    continue
                event_data = np.array(event.split(","), dtype=float)
                # Column 0 is the label and columns 3+ are the signal; columns
                # 1-2 are not used here.
                signal = event_data[3:]
                # Int64List needs an int, not a float array.
                label = int(event_data[0])
                # Write the serialized bytes (previously the function object
                # itself was passed to writer.write).
                writer.write(serialize_example(signal, label))


                        
                    


# Build the training dataset from the per-subject TFRecord files.
def _parse_signal(serialized):
    """Decode one serialized tf.train.Example and return its 'signal' tensor.

    The spec mirrors `serialize_example`: 'signal' is a float list of arbitrary
    length, so it is parsed as a VarLenFeature and densified. The 'label'
    feature is not requested (parse_single_example ignores unrequested keys)
    because `train_step` feeds each batch element straight into the critic.
    """
    parsed = tf.io.parse_single_example(
        serialized, {'signal': tf.io.VarLenFeature(tf.float32)})
    return tf.sparse.to_dense(parsed['signal'])


dataset = tf.data.Dataset.list_files('data/tfrecords/*.tfrecord')  # dataset of file names
# NOTE(review): batching assumes every event has the same signal length and
# that the critic accepts a (batch, length) tensor -- confirm against models.
train_dataset = (
    tf.data.TFRecordDataset(dataset)
    .map(_parse_signal)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

## Model Initialization

In [None]:
# Instantiate both networks via the project-local `models` module
# (generator first, then critic).
generator, critic = models.build_generator(), models.build_critic()



### Loss and optimizers

In [None]:
# Binary cross-entropy applied to raw critic outputs (from_logits=True, so no
# sigmoid is expected on the critic's last layer); shared by both losses below.
# NOTE(review): despite the "critic" naming, this is the standard DCGAN BCE loss,
# not a Wasserstein critic loss -- confirm that is intended.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def critic_loss(real_output, fake_output):
    """Return the critic's BCE loss for one step.

    Real samples are scored against a target of 1 and generated samples
    against a target of 0; the two cross-entropy terms are summed.
    """
    return (cross_entropy(tf.ones_like(real_output), real_output)
            + cross_entropy(tf.zeros_like(fake_output), fake_output))

def generator_loss(fake_output):
    """Return the generator's BCE loss: it wants the critic to score fakes as real (1)."""
    target = tf.ones_like(fake_output)
    loss = cross_entropy(target, fake_output)
    return loss



# Independent Adam optimizers, one per network, with the same learning rate.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
critic_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

### Train utils

In [None]:
EPOCHS = 50  # presumably passed as `epochs` to train(); the call is not visible here
noise_dim = 100  # length of each latent noise vector fed to the generator
# NOTE(review): not referenced in the visible code; in the DCGAN tutorial it sizes
# the fixed sample grid used for progress images -- confirm it is still needed.
num_examples_to_generate = 16



@tf.function
def train_step(images):
    """Run one adversarial update of the generator and the critic.

    Args:
        images: One batch of real samples; it is fed to `critic` unchanged.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar loss tensors for this step.
        (Previously nothing was returned; callers that ignore the result are
        unaffected.)
    """
    # Draw exactly one noise vector per real sample so the real and fake terms
    # of the critic loss stay balanced, including on a smaller final batch.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as critic_tape:
        generated_images = generator(noise, training=True)

        real_output = critic(images, training=True)
        fake_output = critic(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = critic_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_critic = critic_tape.gradient(disc_loss, critic.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    critic_optimizer.apply_gradients(zip(gradients_of_critic, critic.trainable_variables))

    return gen_loss, disc_loss


## Training loop

In [None]:
# All checkpoint files are written under ./training_checkpoints with a "ckpt" prefix.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# The keyword names become the object keys stored in the checkpoint, so they keep
# the tutorial's "discriminator" naming even though the model is `critic` here;
# renaming them would stop previously saved checkpoints from restoring.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=critic_optimizer,
    generator=generator,
    discriminator=critic,
)


def train(dataset, epochs, *, save_every=15, log_path="gan_log.txt"):
    """Run the GAN training loop over `dataset` for `epochs` epochs.

    Each element of `dataset` is passed to `train_step` as one batch of real
    samples. A checkpoint is written every `save_every` epochs and after the
    final epoch. Each epoch's wall-clock time is printed and appended, one
    line per epoch, to `log_path`.

    Args:
        dataset: Iterable of real-sample batches, e.g. `train_dataset`.
        epochs: Number of passes over `dataset`.
        save_every: Checkpoint interval in epochs (default 15, as before).
        log_path: File the per-epoch timing lines are appended to.

    Raises:
        ValueError: If `save_every` is less than 1.
    """
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every!r}")

    for epoch in range(epochs):
        start = time.time()

        for batch in dataset:
            train_step(batch)

        epoch_num = epoch + 1
        # Save periodically, and always after the last epoch, so the final
        # weights are kept even when `epochs` is not a multiple of `save_every`.
        if epoch_num % save_every == 0 or epoch_num == epochs:
            checkpoint.save(file_prefix=checkpoint_prefix)

        # Measure once so the printed and logged durations agree.
        message = 'Time for epoch {} is {} sec'.format(epoch_num, time.time() - start)
        print(message)

        # Newline-terminate so each epoch gets its own line in the log file.
        with open(log_path, 'a', encoding='utf-8') as f:
            f.write(message + '\n')


  
