# GAN for Cochlear HC Generation
### Cole Krudwig

In [None]:
!mkdir data

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; save_path below points inside this mount,
# so generated images and metrics are written to Drive.
drive.mount('/content/drive')

import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from IPython import display
import os
import numpy as np
import time

# Training configuration for generating 256x256 images.
# Adjust these values to match your own dataset and environment.

# Where training images are read from and where results are written.
data_dir = 'data'
save_path = '/content/drive/My Drive/Colab Generated Images 2/'

# Image geometry and batching.
IMG_HEIGHT = 256
IMG_WIDTH = 256
BATCH_SIZE = 32

# Generator input size and number of noise vectors kept for snapshots.
noise_dim = 200
num_examples_to_generate = 1

# Number of passes over the dataset.
EPOCHS = 50000

# Make sure the output folder exists before anything tries to write to it.
os.makedirs(save_path, exist_ok=True)

def generate_and_save_images(model, epoch, test_input):
    """Render the first image produced by `model` and save it as a PNG.

    Args:
        model: Callable model invoked as `model(test_input, training=False)`.
        epoch: Integer used in the output file name (zero-padded to 4 digits).
        test_input: Input batch passed to the model; only the first output
            of the batch is rendered.

    The file is written under the module-level `save_path`.
    """
    # Run in inference mode and keep only the first sample of the batch.
    first_sample = model(test_input, training=False)[0]

    # Undo the [-1, 1] scaling (x * 127.5 + 127.5) and convert to 8-bit pixels.
    rescaled = first_sample * 127.5 + 127.5
    pixels = rescaled.numpy().astype(np.uint8)

    out_file = f'{save_path}/image_at_epoch_{epoch:04d}.png'

    plt.figure(figsize=(4, 4))
    plt.imshow(pixels)
    plt.axis('off')
    plt.savefig(out_file)
    # Close the figure so repeated snapshots do not accumulate open figures.
    plt.close()

# Load every image under data_dir as unlabeled RGB batches, resized to
# IMG_HEIGHT x IMG_WIDTH and shuffled.
train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    label_mode=None,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    color_mode='rgb',
    shuffle=True,
)

# Rescale pixels from [0, 255] to [-1, 1] (x / 127.5 - 1).
# Uses the stable `layers.Rescaling` (TF >= 2.6); the old
# `layers.experimental.preprocessing` alias is not available in newer Keras
# releases and raises AttributeError there.
normalization_layer = layers.Rescaling(1./127.5, offset=-1)
train_dataset = train_dataset.map(lambda x: normalization_layer(x))

# Generator model for 256x256 images
def make_generator_model():
    """Build the generator: a `noise_dim` vector -> RGB image with tanh output.

    The network projects the noise to a (IMG_HEIGHT/4, IMG_WIDTH/4, 128)
    feature map and upsamples it with two stride-2 transposed convolutions,
    so the output is (IMG_HEIGHT, IMG_WIDTH, 3). With the default 256x256
    configuration the base feature map is 64x64, as before.

    Returns:
        An uncompiled `tf.keras.Sequential` model.

    Raises:
        ValueError: If IMG_HEIGHT or IMG_WIDTH is not a multiple of 4, since
            the two stride-2 layers could not reproduce that size exactly.
    """
    upsample_factor = 4  # two stride-2 Conv2DTranspose layers: 2 * 2
    if IMG_HEIGHT % upsample_factor or IMG_WIDTH % upsample_factor:
        raise ValueError(
            f"IMG_HEIGHT and IMG_WIDTH must be multiples of {upsample_factor}, "
            f"got {IMG_HEIGHT}x{IMG_WIDTH}"
        )
    base_height = IMG_HEIGHT // upsample_factor
    base_width = IMG_WIDTH // upsample_factor

    model = tf.keras.Sequential([
        # Project the noise vector onto the low-resolution feature map.
        layers.Dense(base_height * base_width * 128, use_bias=False, input_shape=(noise_dim,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((base_height, base_width, 128)),
        # First 2x upsampling.
        layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        # Second 2x upsampling to full resolution; tanh keeps outputs in [-1, 1].
        layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    ])
    return model

# Discriminator model for 256x256 images
def make_discriminator_model():
    """Build the discriminator: RGB image -> single unnormalized score (logit).

    The input shape follows the module-level IMG_HEIGHT / IMG_WIDTH so it
    stays in sync with the dataset loader; with the defaults this is
    256x256x3, as before. The final Dense layer has no activation, so the
    output is a logit.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    model = tf.keras.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[IMG_HEIGHT, IMG_WIDTH, 3]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1)
    ])
    return model

# Instantiate the two networks.
generator = make_generator_model()
discriminator = make_discriminator_model()

# Loss and optimizers.
# from_logits=True because the discriminator's last Dense(1) has no activation.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Each network gets its own Adam optimizer with the same learning rate.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Loss functions
def generator_loss(fake_output):
    """Return the generator loss for a batch of discriminator outputs.

    Args:
        fake_output: Discriminator outputs for generated images.

    Returns:
        Cross-entropy between `fake_output` and an all-ones target, i.e. the
        generator is scored on how "real" its images were judged.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: Discriminator outputs for real images.
        fake_output: Discriminator outputs for generated images.

    Returns:
        Sum of the cross-entropy of `real_output` against ones and of
        `fake_output` against zeros.
    """
    ones_target = tf.ones_like(real_output)
    loss_on_real = cross_entropy(ones_target, real_output)

    zeros_target = tf.zeros_like(fake_output)
    loss_on_fake = cross_entropy(zeros_target, fake_output)

    return loss_on_real + loss_on_fake

# Training step
@tf.function
def train_step(images):
    """Run one optimization step for both generator and discriminator.

    Args:
        images: A batch of real images; its leading dimension is the batch
            size.

    Returns:
        Tuple `(gen_loss, disc_loss)` for this batch.
    """
    # Draw as many noise vectors as there are real images in this batch, so a
    # smaller final batch is still matched by an equal number of fakes rather
    # than always BATCH_SIZE.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # Two tapes record the same forward pass so each network's gradients can
    # be taken from its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Training function
def train(dataset, epochs, test_input):
    """Train the GAN, periodically saving a sample image and the losses.

    Every 100 epochs a sample image is generated from `test_input` and the
    losses of that epoch's last batch are appended to `training_metrics.txt`
    under `save_path`. A final image is generated after the last epoch.

    Args:
        dataset: Iterable of image batches, each passed to `train_step`.
        epochs: Number of passes over `dataset`.
        test_input: Noise batch used for every sample image.

    Raises:
        ValueError: If `dataset` yields no batches, since there would be
            nothing to train on and no losses to report.
    """
    # os.path.join avoids a doubled separator when save_path ends with '/'.
    metrics_path = os.path.join(save_path, 'training_metrics.txt')

    # Start a fresh metrics file with a CSV header.
    with open(metrics_path, 'w') as f:
        f.write("Epoch,Generator Loss,Discriminator Loss\n")

    for epoch in range(epochs):
        start = time.time()

        gen_loss = disc_loss = None
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)

        # Fail with a clear message instead of an UnboundLocalError later on.
        if gen_loss is None:
            raise ValueError("dataset yielded no batches; nothing to train on")

        # Save images and losses every 100 epochs
        if (epoch + 1) % 100 == 0:
            display.clear_output(wait=True)
            generate_and_save_images(generator, epoch + 1, test_input)
            # Append and close each time so the file is flushed regularly.
            with open(metrics_path, 'a') as f:
                f.write(f"{epoch+1},{gen_loss.numpy()},{disc_loss.numpy()}\n")
            print(f'Time for epoch {epoch + 1} is {time.time() - start} sec')

    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, test_input)  # Final image generation

# Fixed noise batch reused for every snapshot during training.
seed = tf.random.normal(shape=[num_examples_to_generate, noise_dim])

# Start training
train(dataset=train_dataset, epochs=EPOCHS, test_input=seed)


Time for epoch 50000 is 0.5242900848388672 sec
