In [None]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np
import os
import librosa
import librosa.display
import matplotlib.pyplot as plt
from pydub import AudioSegment
import soundfile as sf

In [None]:
# List the GPU devices TensorFlow has detected; an empty list means
# no GPU is visible to this process.
physical_devices = tf.config.list_physical_devices(device_type='GPU')
print(physical_devices)

# Define the model

In [None]:
## KERAS MODELS
# Length of the random noise vector the generator takes as input.
latent_dim = 128

# Generator: noise vector -> dense projection -> (32, 16, 128) feature map,
# followed by three upsampling stages and a final single-channel tanh conv.
generator_layers = [
    keras.Input(shape=(latent_dim,)),
    layers.Dense(32 * 16 * 128),
    layers.Reshape((32, 16, 128)),
]

# Each stage is BatchNorm -> stride-2 transposed conv -> ReLU. With
# "same" padding a stride of 2 doubles height and width, so the three
# stages take the 32x16 map to 256x128.
for n_filters in (128, 256, 512):
    generator_layers += [
        layers.BatchNormalization(),
        layers.Conv2DTranspose(n_filters, kernel_size=4, strides=2, padding="same"),
        layers.ReLU(),
    ]

# Collapse to one channel; tanh bounds the output to [-1, 1].
generator_layers.append(
    layers.Conv2D(1, kernel_size=5, padding="same", activation="tanh")
)

generator = keras.Sequential(generator_layers, name="generator")
generator.summary()

In [None]:
# Discriminator: (256, 128, 1) input -> three downsampling stages ->
# flatten -> dropout -> single sigmoid score.
discriminator_layers = [keras.Input(shape=(256, 128, 1))]

# Each stage is BatchNorm -> stride-2 conv -> LeakyReLU(0.2).
for n_filters in (64, 128, 128):
    discriminator_layers += [
        layers.BatchNormalization(),
        layers.Conv2D(n_filters, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
    ]

# Classification head producing one probability-like value per sample.
discriminator_layers += [
    layers.Flatten(),
    layers.Dropout(0.2),
    layers.Dense(1, activation="sigmoid"),
]

discriminator = keras.Sequential(discriminator_layers, name="discriminator")
discriminator.summary()

# Create the Dataset

In [None]:
# Read the pre-computed training array from a NumPy .npy file.
# NOTE(review): presumably shaped (N, 256, 128, 1) to match the
# discriminator input -- confirm against the script that wrote the file.
DATA_PATH = "data_magnitudes.npy"
BATCH_SIZE = 4

mega_tensor = np.load(DATA_PATH)

# Slice along the first axis (one element per sample), then group the
# samples into fixed-size batches in their stored order.
train_ds = tf.data.Dataset.from_tensor_slices(mega_tensor)
dataset = train_ds.batch(BATCH_SIZE)


# Training
https://towardsdatascience.com/generative-adversarial-network-gan-for-dummies-a-step-by-step-tutorial-fdefff170391

1. Select a number of real images from the training set.
2. Generate a number of fake images. This is done by sampling random noise vectors and creating images from them using the generator.
3. Train the discriminator for one or more epochs using both fake and real images. This will update only the discriminator's weights by labeling all the real images as 1 and the fake images as 0.
4. Generate another number of fake images
5. Train the full GAN model for one or more epochs using only fake images. This will update only the generator's weights by labeling all fake images as 1. 

**SOURCE**: Link above

In [None]:
class GAN(keras.Model):
    """Couples a generator and a discriminator and trains them adversarially.

    Label convention used in `train_step`: generated samples are labelled 1
    and real samples are labelled 0. The generator is therefore trained
    towards label 0, i.e. towards being classified as real.

    Args:
        discriminator: model mapping a batch of images to one score each.
        generator: model mapping latent vectors to a batch of images.
        latent_dim: length of the latent vector sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the two optimizers and the loss, and create the loss trackers.

        Args:
            d_optimizer: optimizer applied to the discriminator weights.
            g_optimizer: optimizer applied to the generator weights.
            loss_fn: callable `loss_fn(labels, predictions)` used for both steps.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """Running-mean trackers for the discriminator and generator losses."""
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: one batch of real samples from the dataset.

        Returns:
            Dict with the running means of the two losses under the keys
            "d_loss" and "g_loss".
        """
        # Sample random points in the latent space
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Decode them to fake images. `training=True` is passed explicitly:
        # inside a custom train_step Keras does not infer training mode, and
        # without it BatchNormalization would use its moving statistics.
        generated_images = self.generator(random_latent_vectors, training=True)

        # Combine them with real images (generated first, then real)
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images:
        # 1 for the generated half, 0 for the real half.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator (training mode enables Dropout and
        # batch-statistic BatchNormalization).
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Assemble labels that say "all real images" (real == 0 here)
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator. Gradients are taken only with respect to the
        # generator's weights, so the discriminator is *not* updated here.
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

In [None]:
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves generator samples at the end of every epoch.

    For each sample it writes the raw array to `output-arrays/` and a
    spectrogram image to `output-specs/`.

    Args:
        num_img: number of samples to generate per epoch.
        latent_dim: length of the latent vector; must match the generator's
            input size.
    """

    def __init__(self, num_img=3, latent_dim=256):
        # Let the base class set up its own attributes before adding ours.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim

    def on_epoch_end(self, epoch, logs=None):
        """Generate `num_img` samples and save each as .npy and .png.

        Args:
            epoch: zero-based epoch index supplied by Keras; file names use
                `epoch + 1`.
            logs: metric dict supplied by Keras (unused).
        """
        # Create the output folders up front so saving cannot fail on a
        # missing directory.
        os.makedirs("output-arrays", exist_ok=True)
        os.makedirs("output-specs", exist_ok=True)

        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        # Inference mode for sampling; convert once to a NumPy array.
        generated_images = self.model.generator(
            random_latent_vectors, training=False
        ).numpy()

        for i in range(self.num_img):
            # Drop the channel axis to get a 2-D (256, 128) array.
            sample = np.reshape(generated_images[i], (256, 128))

            # Save the numpy array
            np.save(f"output-arrays/zz-epoch_{epoch+1}_sample_{i}.npy", sample)

            # Save a spectrogram
            fig, ax = plt.subplots()
            img = librosa.display.specshow(
                librosa.amplitude_to_db(sample),
                x_axis='time',
                y_axis='cqt_note',
                ax=ax,
            )
            ax.set_title('Constant-Q power spectrum')
            fig.colorbar(img, ax=ax, format="%+2.0f dB")
            fig.savefig(f"output-specs/zz-epoch_{epoch+1}_sample_{i}.png")
            # Close this figure explicitly so figures do not accumulate
            # over hundreds of epochs.
            plt.close(fig)


In [None]:
# Epoch count for the first training run (the earlier note here
# suggested ~100 epochs in practice).
epochs = 300

# Both networks use Adam with the same learning rate; binary
# cross-entropy scores the discriminator's sigmoid output.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0001)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0001)
loss_fn = keras.losses.BinaryCrossentropy()

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=128)
gan.compile(d_optimizer=d_optimizer, g_optimizer=g_optimizer, loss_fn=loss_fn)

# Save two generated samples after every epoch while training.
sample_monitor = GANMonitor(num_img=2, latent_dim=128)
gan.fit(dataset, epochs=epochs, callbacks=[sample_monitor])

In [None]:
# Train the same `gan` for 200 more epochs. `initial_epoch` continues the
# epoch numbering after the first 300-epoch run; without it Keras restarts
# at epoch 0 and GANMonitor would overwrite the files saved for epochs
# 1-200. With `initial_epoch` set, `epochs` is the final epoch index
# (300 + 200 = 500), not the number of epochs to run.
gan.fit(
    dataset,
    initial_epoch=300,
    epochs=500,
    callbacks=[GANMonitor(num_img=2, latent_dim=128)],
)

In [None]:
# Train the same `gan` for a further 200 epochs. `initial_epoch` continues
# the epoch numbering after the 500 epochs already run (300 + 200), so
# GANMonitor writes new files instead of overwriting earlier ones. With
# `initial_epoch` set, `epochs` is the final epoch index (500 + 200 = 700).
gan.fit(
    dataset,
    initial_epoch=500,
    epochs=700,
    callbacks=[GANMonitor(num_img=2, latent_dim=128)],
)