GAN model: generates detector-level images of halo photons (the same setup is used for prompt photons). A generator and a discriminator are trained adversarially for this purpose.

In [None]:
import tensorflow as tf

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
import time
from tensorflow.keras.models import load_model

from IPython import display

Setting the output directory and defining a helper that saves the full generator model (weights and assets) into a per-epoch sub-directory

In [None]:

pdir = "./Halo"


def save_checkpoint(model, epoch):
    """Save the complete ``model`` under ``pdir`` in a folder named after ``epoch``.

    The target path is ``<pdir>/halo_gen_model_at_epoch<epoch>``,
    e.g. ``./Halo/halo_gen_model_at_epoch7``.
    """
    target_dir = os.path.join(pdir, "halo_gen_model_at_epoch{}".format(epoch))
    tf.keras.models.save_model(model, target_dir)


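Loading the dataset and pre-processing it: each row is reshaped into a 9x9 single-channel image and rescaled with x/127.5 - 1 (which maps [0, 255] to [-1, 1])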

In [None]:
# Each CSV row is assumed to hold one flattened 9x9 image in its first 81 columns;
# any further columns are dropped. NOTE(review): confirm the file layout.
data_set = np.loadtxt("halo_data.csv", delimiter = ",")
x_train = data_set[:, :81]

# Reshape to (N, 9, 9, 1) and rescale with x / 127.5 - 1, i.e. [0, 255] -> [-1, 1],
# the range of the generator's tanh output. NOTE(review): this assumes the raw
# values lie in [0, 255] -- verify against the actual scale of the CSV data.
X_Net = x_train.reshape(x_train.shape[0], 9, 9, 1).astype('float32') / 127.5 - 1

# Shuffle across the whole dataset (buffer = number of rows), then batch. Without
# drop_remainder, the final batch is smaller when N is not a multiple of BATCH_SIZE.
BATCH_SIZE = 150
BUFFER_SIZE = data_set.shape[0]
train_dataset = (
    tf.data.Dataset.from_tensor_slices(X_Net)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

Defining the generator model

In [None]:
def make_generator_model():
    """Build the generator network.

    Maps a 100-dimensional latent vector to a single-channel 9x9 image whose
    values lie in [-1, 1] (tanh output), matching the scaling of the training data.
    Every layer works on the 9x9 grid (stride 1, 'same' padding); only the
    channel count changes: 256 -> 128 -> 64 -> 1.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    net = tf.keras.Sequential()

    # Project the latent vector, normalize, and reshape into a 9x9x256 feature map.
    net.add(layers.Dense(9 * 9 * 256, use_bias=False, input_shape=(100,)))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())
    net.add(layers.Reshape((9, 9, 256)))
    assert net.output_shape == (None, 9, 9, 256)  # None is the batch dimension

    # Two transposed convolutions that narrow the channels while keeping the 9x9 grid.
    for channels in (128, 64):
        net.add(layers.Conv2DTranspose(channels, (3, 3), strides=(1, 1),
                                       padding='same', use_bias=False))
        assert net.output_shape == (None, 9, 9, channels)
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    # Collapse to one channel; tanh bounds the output to [-1, 1].
    net.add(layers.Conv2DTranspose(1, (3, 3), strides=(1, 1), padding='same',
                                   use_bias=False, activation='tanh'))
    assert net.output_shape == (None, 9, 9, 1)

    return net

Creating an instance of Generator using the helper function

In [None]:
# Single generator instance shared by the rest of the notebook: it is updated in
# ``train_step`` and sampled in ``generate_and_save_images``.
generator = make_generator_model()

Defining the discriminator model

In [None]:
def make_discriminator_model():
    """Build the discriminator network.

    Takes a 9x9 single-channel image and returns one unbounded score (logit)
    per image. The last Dense layer has no activation, which is why the loss
    is built with ``from_logits=True``.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    return tf.keras.Sequential([
        # Two stride-1 conv blocks (64 then 128 filters), each with LeakyReLU and 30% dropout.
        layers.Conv2D(64, (3, 3), strides=(1, 1), padding='same', input_shape=[9, 9, 1]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (3, 3), strides=(1, 1), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        # Flatten the 9x9x128 features into a single real-vs-fake logit.
        layers.Flatten(),
        layers.Dense(1),
    ])

Creating a discriminator instance

In [None]:
# Single discriminator instance shared by the rest of the notebook: it scores
# both real and generated images in ``train_step``.
discriminator = make_discriminator_model()

In [None]:
# Binary cross-entropy loss object shared by the generator and discriminator
# losses. ``from_logits=True`` because the discriminator's final Dense(1) layer
# has no activation, so its outputs are raw logits rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

The discriminator's job is to tell a generated image from a real one. We use two trivial labels, 1 and 0, to represent real and fake images respectively.
We then use a binary cross-entropy loss. Recall that binary cross-entropy takes two arguments: y, the target label, and y', the model's prediction for that target (here a raw logit, since the loss is built with `from_logits=True`).
The real_loss and the fake_loss are the binary cross-entropies between the targets (`ones_like`/`zeros_like` produce a tensor of ones/zeros with the same shape as the given tensor) and the discriminator's outputs on the real and generated images, respectively.

In [None]:
def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Real images are labelled 1 and generated images 0; the result is the sum of
    the binary cross-entropy on each group.

    Args:
        real_output: discriminator logits for the real images.
        fake_output: discriminator logits for the generated images.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

The generator's goal is to fool the discriminator. Its loss is the binary cross-entropy between the discriminator's output on generated images and a target of all ones ("real"), so minimizing it pushes the generator toward images the discriminator classifies as real.

In [None]:
def generator_loss(fake_output):
    """Return the generator loss for one batch.

    Args:
        fake_output: discriminator logits for the generated images.
    """
    # The generator "wins" when its samples are scored as real, so the target is all ones.
    target = tf.ones_like(fake_output)
    return cross_entropy(target, fake_output)

In [None]:
# Separate Adam optimizers (learning rate 1e-4 each), so the generator and the
# discriminator keep independent optimizer state.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(1e-4),
    tf.keras.optimizers.Adam(1e-4),
)

In [None]:
# Training checkpoints (both optimizers and both networks) are written under
# <pdir>/training_checkpoints_for_prompt with file names prefixed "ckpt".
# NOTE(review): the folder name says "prompt" although pdir points at the Halo
# output directory -- confirm this name is intentional.
checkpoint_dir = "{}/training_checkpoints_for_prompt".format(pdir)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)


In [None]:
EPOCHS = 500                   # number of passes over train_dataset
noise_dim = 100                # latent length; must match the generator's input_shape=(100,)
num_examples_to_generate = 16  # samples drawn for each progress image

# One fixed batch of latent vectors, drawn once and reused every epoch, so the
# saved progress images show how the same inputs evolve during training.
seed = tf.random.normal([num_examples_to_generate, noise_dim])

In [None]:
# ``tf.function`` traces this step into a TensorFlow graph on first call; later
# calls with the same input signature reuse that graph (a new batch shape, such
# as a smaller final batch, triggers one retrace).
@tf.function
def train_step(images):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
        images: one batch of real training images -- presumably shaped
            (batch, 9, 9, 1) and scaled to [-1, 1] as in the data-loading cell;
            TODO confirm for other callers.
    """
    # One latent vector per real image. Sizing the noise by the actual batch
    # (rather than the global BATCH_SIZE) keeps real and fake counts equal on the
    # final, smaller batch that ``Dataset.batch`` emits without drop_remainder.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # Two tapes so each network's gradients are taken from its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [None]:
def train(dataset, epochs, save_every=1):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    After every epoch, a grid of samples generated from the fixed ``seed`` is
    saved as a PNG. Every ``save_every`` epochs, the generator is exported with
    ``save_checkpoint`` and a training checkpoint (both networks and both
    optimizers) is written.

    Args:
        dataset: iterable of real-image batches, each passed to ``train_step``.
        epochs: number of passes over ``dataset``.
        save_every: save the model and checkpoint every this many epochs.
            Defaults to 1 (every epoch), which matches the previous behaviour.

    Raises:
        ValueError: if ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError("save_every must be >= 1, got {}".format(save_every))

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Progress image for this epoch (1-based numbering in the filename).
        generate_and_save_images(generator, epoch + 1, seed)

        if (epoch + 1) % save_every == 0:
            save_checkpoint(generator, epoch + 1)
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # No extra image after the loop: the last iteration already saved
    # image_at_epoch_<epochs> from the same generator weights and seed.

In [None]:
# Make sure the output folder for the per-epoch sample grids exists.
os.makedirs(os.path.join(pdir, "images"), exist_ok=True)


def generate_and_save_images(model, epoch, test_input):
    """Render samples from ``model`` as a grid and save them as a PNG.

    The image is written to ``<pdir>/images/image_at_epoch_XXXX.png`` (epoch
    zero-padded to 4 digits). The figure is then closed, so repeated calls during
    training do not accumulate open matplotlib figures.

    Args:
        model: the generator; called with ``training=False``.
        epoch: epoch number used in the output filename.
        test_input: latent vectors, one per image to draw.
    """
    # training=False runs all layers in inference mode (BatchNormalization uses
    # its moving statistics rather than the statistics of this batch).
    predictions = model(test_input, training=False)

    n_images = predictions.shape[0]
    # Near-square grid large enough for every sample; 16 samples give a 4x4 grid.
    n_cols = max(1, int(np.ceil(np.sqrt(n_images))))
    n_rows = max(1, int(np.ceil(n_images / n_cols)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(n_images):
        plt.subplot(n_rows, n_cols, i + 1)
        # Undo the x / 127.5 - 1 scaling applied to the training data.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    imagefile = os.path.join(pdir, "images", "image_at_epoch_{:04d}.png".format(epoch))
    fig.savefig(imagefile)
    plt.close(fig)


In [None]:
# %%time
# Launch the full training run over the batched dataset for EPOCHS epochs.
train(train_dataset, epochs=EPOCHS)


In [None]:
# model1 = tf.keras.models.load_model("Prompt/prompt_gen_model_at_epoch6/")
# noise = tf.random.normal([1, 100])
# generated_image = model1(noise, training=False)


# plt.imshow(generated_image[0, :, :, 0], cmap='gray')
 