### followed : https://www.tensorflow.org/tutorials/generative/dcgan

In [4]:
import models
import tensorflow as tf 
from tensorflow.keras import layers, Model, Input

import os
import time
import numpy as np

In [5]:
BUFFER_SIZE = 60000
BATCH_SIZE = 64

## Data preprocessing

In [6]:
# TODO
def serialize_example(signal, label):
    # Convert to tf.train.Example
    feature = {
        'signal': tf.train.Feature(float_list=tf.train.FloatList(value=signal.flatten())),
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

if ("tfrecords" not in os.listdir("./")): # no tf records
    os.mkdir("./tfrecords") 

    
    for idx, subject in enumerate(os.listdir("./data")): # iterate over subjects
        print(f"{idx} of 91 ({subject})")
        with tf.io.TFRecordWriter(f'./tfrecords/pos_{subject}.tfrecord') as pos_writer: # add tf_record for subject
            with tf.io.TFRecordWriter(f'./tfrecords/neg_{subject}.tfrecord') as neg_writer: # add tf_record for subject

                with open("./data/"+subject) as data_file:
                    for event in data_file.readlines():
                            event_data = np.array(event.split(","),dtype=float)
                            signal = np.array(event_data[3:])
                            label= int(event[0])
                            serialized = serialize_example(signal, label)
                            if label == 0:
                                neg_writer.write(serialized)
                            else:
                                pos_writer.write(serialized)


                        
                    
def parse_example(example_proto):
    # Define the structure of the data (must match how it was written!)
    feature_description = {
        'signal': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    
    # Parse the serialized data into a dictionary of tensors
    parsed = tf.io.parse_single_example(example_proto, feature_description)

    # Return a tuple (signal, label)
    return parsed['signal'], parsed['label']

raw_dataset = tf.data.TFRecordDataset(tf.io.gfile.glob('./tfrecords/pos_*.tfrecord'))

# Parse before shuffling or batching
parsed_dataset = raw_dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)

pos_train_dataset = parsed_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

raw_dataset = tf.data.TFRecordDataset(tf.io.gfile.glob('./tfrecords/pos_*.tfrecord'))

# Parse before shuffling or batching
parsed_dataset = raw_dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)

neg_train_dataset = parsed_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


2025-05-09 13:11:01.320431: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-09 13:11:01.907944: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15134 MB memory:  -> device: 0, name: Quadro P5000, pci bus id: 0000:65:00.0, compute capability: 6.1


## Model Initilization

In [7]:


generator = models.build_generator()
critic = models.build_critic()




### Loss and optimizers

In [8]:
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def critic_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output) 
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) 
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output) 



generator_optimizer = tf.keras.optimizers.Adam(1e-4)
critic_optimizer = tf.keras.optimizers.Adam(1e-4)

### Train utils

In [15]:
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16



@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    images = tf.reshape(images, [-1, 10, 640,1])
    print(images.shape)
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = critic(images, training=True)

      fake_output = critic(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = critic_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, critic.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    critic_optimizer.apply_gradients(zip(gradients_of_discriminator, critic.trainable_variables))


## Training loop

In [10]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=critic_optimizer,
                                 generator=generator,
                                 discriminator=critic)


def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch[0]) # only need signal for gan traing

    # Produce images for the GIF as you go
   
    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

    with open("gan_log.txt",'a') as f:
      f.write('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))


  


In [16]:
train(pos_train_dataset, EPOCHS)

(64, 10, 640, 1)


  return dispatch_target(*args, **kwargs)


KeyboardInterrupt: 