# Set up the model

In [None]:
import tensorflow as tf
from tensorflow import keras
from keras import layers

In [None]:
# Ask TensorFlow which GPU devices it can see and print the resulting list
physical_devices = tf.config.list_physical_devices(device_type='GPU')
print(physical_devices)


In [None]:
# Stacked GAN: a generator model followed by a discriminator model.
# This architecture was originally designed for Fashion MNIST.

# Length of the random codings vector fed to the generator
codings_size = 100

# Generator: project the codings to a 25x75x128 feature map, then apply two
# stride-2 transposed convolutions, ending in a single tanh channel.
generator = keras.models.Sequential()
generator.add(keras.layers.Dense(25 * 75 * 128, input_shape=[codings_size]))
generator.add(keras.layers.Reshape([25, 75, 128]))
generator.add(keras.layers.BatchNormalization())
generator.add(
    keras.layers.Conv2DTranspose(
        64, kernel_size=5, strides=2, padding="SAME", activation="selu"
    )
)
generator.add(keras.layers.BatchNormalization())
generator.add(
    keras.layers.Conv2DTranspose(
        1, kernel_size=3, strides=2, padding="SAME", activation="tanh"
    )
)

# Discriminator: two stride-2 convolutions with LeakyReLU(0.2) and dropout on a
# 100x300x1 input, followed by a single sigmoid unit.
discriminator = keras.models.Sequential()
discriminator.add(
    keras.layers.Conv2D(
        64,
        kernel_size=5,
        strides=2,
        padding="SAME",
        activation=keras.layers.LeakyReLU(0.2),
        input_shape=[100, 300, 1],
    )
)
discriminator.add(keras.layers.Dropout(0.4))
discriminator.add(
    keras.layers.Conv2D(
        128,
        kernel_size=5,
        strides=2,
        padding="SAME",
        activation=keras.layers.LeakyReLU(0.2),
    )
)
discriminator.add(keras.layers.Dropout(0.4))
discriminator.add(keras.layers.Flatten())
discriminator.add(keras.layers.Dense(1, activation="sigmoid"))

# Chain the two models so the generator output feeds the discriminator
gan = keras.models.Sequential()
gan.add(generator)
gan.add(discriminator)

In [None]:
## KERAS MODELS
# Length of the latent vector the generator takes as input
latent_dim = 128

# Generator: latent vector -> 16x32x128 feature map -> three stride-2
# transposed convolutions (each followed by LeakyReLU) -> one sigmoid channel.
generator = keras.Sequential(name="generator")
generator.add(keras.Input(shape=(latent_dim,)))
generator.add(layers.Dense(16 * 32 * 128))
generator.add(layers.Reshape((16, 32, 128)))
for n_filters in (128, 256, 512):
    generator.add(
        layers.Conv2DTranspose(n_filters, kernel_size=4, strides=2, padding="same")
    )
    generator.add(layers.LeakyReLU(alpha=0.2))
generator.add(layers.Conv2D(1, kernel_size=5, padding="same", activation="sigmoid"))
generator.summary()

In [None]:
# Discriminator: 128x256x1 image -> three stride-2 convolutions (each followed
# by LeakyReLU) -> flatten -> dropout -> one sigmoid unit.
discriminator = keras.Sequential(name="discriminator")
discriminator.add(keras.Input(shape=(128, 256, 1)))
for n_filters in (64, 128, 128):
    discriminator.add(
        layers.Conv2D(n_filters, kernel_size=4, strides=2, padding="same")
    )
    discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Flatten())
discriminator.add(layers.Dropout(0.2))
discriminator.add(layers.Dense(1, activation="sigmoid"))
discriminator.summary()

In [None]:
# Compile the stacked Sequential `gan` using the discriminator it actually
# contains. `gan` was assembled from the generator/discriminator objects that
# existed when it was built; the name `discriminator` is rebound to a different
# model afterwards, so freezing `discriminator` here would freeze that newer
# model (leaving it with no trainable weights) instead of the one inside `gan`.
stacked_discriminator = gan.layers[-1]

# Compile the discriminator on its own first, while it is still trainable.
stacked_discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")

# Freeze it before compiling the stacked model so that training `gan` only
# updates the generator's weights.
stacked_discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer="rmsprop")
gan.summary()

# Load in the data

In [None]:
# Mount Google Drive in the Colab runtime at /content/drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Extract a data set archive stored on the drive into the current working directory
import zipfile

with zipfile.ZipFile(
    "/content/drive/MyDrive/MotifGAN Capstone/spec_solo_dist_noise_1.zip"
) as archive:
    archive.extractall()

In [None]:
# `os` is not imported by any earlier cell, so import it here.
import os

# Gather the "negative" and "positive" folders under a single "spec" parent
# directory. exist_ok=True keeps the cell from failing when "spec" already exists.
os.makedirs("spec", exist_ok=True)
os.rename("negative", "spec/negative")
os.rename("positive", "spec/positive")

In [None]:
# `os` is needed for os.listdir below and is not imported by any earlier cell.
import os
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Pick one random file name from each class folder
pos_img_name = random.choice(os.listdir('./spec/positive'))
neg_img_name = random.choice(os.listdir('./spec/negative'))

# Show the positive example in its own figure
print(pos_img_name)
pos_img = mpimg.imread(f'./spec/positive/{pos_img_name}')
plt.figure()
plt.imshow(pos_img)

# Show the negative example in its own figure
print(neg_img_name)
neg_img = mpimg.imread(f'./spec/negative/{neg_img_name}')
plt.figure()
plt.imshow(neg_img)

# Create the Dataset

In [None]:
import os
import shutil
# Load dataset from directory with keras

# Sometimes a directory called .ipynb_checkpoints is present; remove it if so.
# shutil.rmtree deletes the directory together with its contents, whereas
# os.removedirs only removes empty directories (raising OSError otherwise) and
# would go on to prune empty parent directories.
if os.path.isdir('./spec/.ipynb_checkpoints'):
    shutil.rmtree('./spec/.ipynb_checkpoints')

train_dir = './spec/'
BATCH_SIZE = 16

# label_mode=None: the dataset yields image batches only (no labels).
# Images are loaded as single-channel and resized to 128x256, which matches the
# discriminator's declared input shape.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    color_mode="grayscale",
    label_mode=None,
    batch_size=BATCH_SIZE,
    image_size=(128, 256),
    seed=123
)

# Divide pixel values by 255 before feeding them to the GAN
dataset = train_ds.map(lambda x: x / 255.0)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Adapted from the Microsoft tutorial.
# Shows the spectrograms of one batch, titled by their index within the batch
# (the dataset is loaded without labels).
n_cols = 5
plt.figure(figsize=(50, 50))
for images in train_ds.take(1):
    n_images = len(images)
    # Size the grid from the batch: a fixed 3x5 grid only has 15 slots and
    # plt.subplot raises ValueError for the 16th image of a 16-image batch.
    n_rows = (n_images + n_cols - 1) // n_cols
    for i in range(n_images):
        ax = plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(images[i].numpy().astype(np.float64))
        plt.title(f"img {i}")
        plt.axis("off")

# Training
https://towardsdatascience.com/generative-adversarial-network-gan-for-dummies-a-step-by-step-tutorial-fdefff170391

1. Select a number of real images from the training set.
2. Generate a number of fake images. This is done by sampling random noise vectors and creating images from them using the generator
3. Train the discriminator for one or more epochs using both fake and real images. This updates only the discriminator's weights, by labeling all the real images as 1 and the fake images as 0.
4. Generate another number of fake images
5. Train the full GAN model for one or more epochs using only fake images. This will update only the generator's weights by labeling all fake images as 1. 

**SOURCE**: Link above

In [None]:
class GAN(keras.Model):
    """Couples a generator and a discriminator and trains them adversarially.

    Label convention used by ``train_step``: generated images are labelled 1
    and real images are labelled 0.
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Keep references to both sub-models and the latent vector length."""
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per sub-model, the loss, and two mean-loss trackers."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """The loss trackers created in ``compile``."""
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Returns a dict with the current ``d_loss`` and ``g_loss`` tracker values.
        """
        n_samples = tf.shape(real_images)[0]

        # ---- Discriminator update ----
        # Generate one fake image per real image from random latent vectors.
        noise = tf.random.normal(shape=(n_samples, self.latent_dim))
        fake_images = self.generator(noise)

        # Fakes first, then reals; the targets below follow the same order.
        disc_inputs = tf.concat([fake_images, real_images], axis=0)
        fake_targets = tf.ones((n_samples, 1))
        real_targets = tf.zeros((n_samples, 1))
        disc_targets = tf.concat([fake_targets, real_targets], axis=0)

        # Add a small amount of uniform noise to the targets.
        disc_targets = disc_targets + 0.05 * tf.random.uniform(tf.shape(disc_targets))

        with tf.GradientTape() as disc_tape:
            disc_scores = self.discriminator(disc_inputs)
            d_loss = self.loss_fn(disc_targets, disc_scores)
        disc_vars = self.discriminator.trainable_weights
        disc_grads = disc_tape.gradient(d_loss, disc_vars)
        self.d_optimizer.apply_gradients(zip(disc_grads, disc_vars))

        # ---- Generator update ----
        # Fresh latent vectors, scored against the label used for real images.
        noise = tf.random.normal(shape=(n_samples, self.latent_dim))
        wanted_targets = tf.zeros((n_samples, 1))

        # Only the generator's weights receive gradients in this step.
        with tf.GradientTape() as gen_tape:
            gen_scores = self.discriminator(self.generator(noise))
            g_loss = self.loss_fn(wanted_targets, gen_scores)
        gen_vars = self.generator.trainable_weights
        gen_grads = gen_tape.gradient(g_loss, gen_vars)
        self.g_optimizer.apply_gradients(zip(gen_grads, gen_vars))

        # Fold this step's losses into the trackers and report them.
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

In [None]:
import os


class GANMonitor(keras.callbacks.Callback):
    """Save a few generator samples to disk at the end of every epoch.

    Args:
        num_img: Number of images to generate and save per epoch.
        latent_dim: Length of the random latent vectors fed to the generator.
        output_dir: Directory the PNG files are written to. It is created if
            it does not exist yet.
    """

    def __init__(self, num_img=3, latent_dim=128, output_dir="output"):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``num_img`` images with the model's generator and save them.

        Files are named ``generated_img_<epoch>_<index>.png`` inside
        ``output_dir``.
        """
        # Saving an image does not create missing parent directories, so make
        # sure the target directory exists before writing.
        os.makedirs(self.output_dir, exist_ok=True)

        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Scale by 255 and keep the NumPy conversion (its result was
        # previously discarded).
        generated_images = (generated_images * 255).numpy()

        for i in range(self.num_img):
            img = keras.preprocessing.image.array_to_img(generated_images[i])
            file_name = "generated_img_%03d_%d.png" % (epoch, i)
            img.save(os.path.join(self.output_dir, file_name))

In [None]:
# Number of passes over the dataset. The original note here suggested ~100
# epochs; this run is configured for 1000.
epochs = 1000

# GAN.train_step applies gradients to discriminator.trainable_weights, and that
# list is empty while the model is frozen (trainable = False, e.g. left over
# from a stacked-Sequential GAN setup). Make sure the discriminator can learn.
discriminator.trainable = True

# Use the same latent_dim the generator's input was declared with, so the
# noise fed to it during training and monitoring always has the right length.
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss_fn=keras.losses.BinaryCrossentropy(),
)

# GANMonitor saves one generated sample at the end of every epoch.
gan.fit(
    dataset, epochs=epochs, callbacks=[GANMonitor(num_img=1, latent_dim=latent_dim)]
)