In [None]:
import os

import librosa
import matplotlib.pyplot as plt
import soundfile as sf
import tensorflow as tf
from tqdm import tqdm

**Paths**

In [None]:
# Root of the project relative to this notebook; the two top-level folders
# used below ("logs" and "data") are resolved against it.
BASE_PATH = "../../"
MONITORING, DATA = (os.path.join(BASE_PATH, folder) for folder in ("logs", "data"))

In [None]:
# Checkpoint directory under MONITORING, with one sub-directory per model family.
CHECKPOINT_PATH = os.path.join(MONITORING, "checkpoints")
CNN_CHECKPOINT_PATH, RNN_CHECKPOINT_PATH = (
    os.path.join(CHECKPOINT_PATH, family) for family in ("cnn", "rnn")
)

In [None]:
# TensorBoard log directory under MONITORING, with one sub-directory per model family.
TENSORBOARD_LOG_DIR = os.path.join(MONITORING, "tensorboard_logs")
CNN_TENSORBOARD_LOGS, RNN_TENSORBOARD_LOGS = (
    os.path.join(TENSORBOARD_LOG_DIR, family) for family in ("cnn", "rnn")
)

In [None]:
# CSV log directory under MONITORING, with one sub-directory per model family.
CSV_LOG_DIR = os.path.join(MONITORING, "csv_logs")
CNN_CSV_LOGS, RNN_CSV_LOGS = (
    os.path.join(CSV_LOG_DIR, family) for family in ("cnn", "rnn")
)

In [None]:
# "tuners" and "models" folders, both located inside the DATA directory.
TUNERS, MODELS = (os.path.join(DATA, folder) for folder in ("tuners", "models"))

**GPU/TPU Distribution Strategy Setup**

In [None]:
# Try to connect to a TPU and build a TPU strategy; when no TPU can be
# resolved (TPUClusterResolver raises ValueError) fall back to TensorFlow's
# current default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    # Instantiate the strategy with the resolver. Binding the bare class here
    # would break every later use of strategy.scope() and
    # strategy.num_replicas_in_sync.
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.get_strategy()

# Report the replica count on both the TPU path and the fallback path.
print("Number of replicas:", strategy.num_replicas_in_sync)

In [None]:
# Detect the available accelerators.
# `gpus` is bound up front so it exists on every path; previously it was only
# assigned when TPU detection failed, leaving it undefined on TPU hosts.
gpus = []
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
except ValueError:
    # No TPU could be resolved: fall back to the logical GPUs, if any.
    tpu = None
    gpus = tf.config.experimental.list_logical_devices("GPU")

In [None]:
# Pick the distribution strategy that matches the detected hardware:
# TPU -> TPUStrategy, several GPUs -> MirroredStrategy, otherwise the default
# (single-device) strategy.
if tpu:
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(
        tpu,
    )
    print("Running on TPU ", tpu.cluster_spec().as_dict()["worker"])
elif len(gpus) > 1:
    gpu_names = [gpu.name for gpu in gpus]
    # MirroredStrategy is the single-host, multi-GPU strategy and accepts a
    # list of device names. MultiWorkerMirroredStrategy expects a cluster
    # resolver as its first argument, so passing device names to it fails.
    strategy = tf.distribute.MirroredStrategy(devices=gpu_names)
    print("Running on multiple GPUs ", gpu_names)
elif len(gpus) == 1:
    strategy = tf.distribute.get_strategy()
    print("Running on single GPU ", gpus[0].name)
else:
    strategy = tf.distribute.get_strategy()
    print("Running on CPU")
print("Number of accelerators: ", strategy.num_replicas_in_sync)

**Hyperparameters**

In [None]:
# Training configuration shared by the data pipeline and both networks.
BATCH_SIZE = 32  # Big batch size, small learning rate
HEIGHT = WIDTH = 224  # square input images
IMG_SIZE = HEIGHT, WIDTH
IMG_FORMAT = (*IMG_SIZE, 3)  # image size plus three channels
NOISE_DIM, EPOCHS = 100, 100
SEED = 949953915

**Load Dataset**

In [None]:
# Load the MFCC images and split them 80/20 into training and validation sets.
# The directory is derived from DATA (instead of repeating the literal
# "../../data/..." path) so that changing BASE_PATH relocates the dataset too.
train_dataset, val_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(DATA, "dataset", "img", "mfcc"),
    validation_split=0.2,
    subset="both",  # return (train, validation) as a pair
    seed=SEED,  # same seed for both subsets so the split is consistent
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)

In [None]:
# Class names exposed by the training dataset; they are used below to title
# the preview images.
# NOTE(review): presumably this has to be read before train_dataset is
# re-assigned to a cached/shuffled pipeline further down — confirm that the
# transformed dataset no longer carries `class_names`.
class_names = train_dataset.class_names
# Number of classes; used as the one-hot depth for labels in both networks.
num_classes = len(class_names)

**Dataset Representation**

In [None]:
# Preview nine images from one training batch on a 3x3 grid, each titled with
# its class name.
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(9):
        pixels = images[i].numpy().astype("uint8")
        caption = class_names[labels[i]]

        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(pixels)
        ax.set_title(caption)
        ax.axis("off")

**Preprocessing**

In [None]:
# Let tf.data choose the prefetch buffer size dynamically; passed to
# prefetch() in the pipeline below.
autotune = tf.data.AUTOTUNE

In [None]:
# Training pipeline: cache the batches, shuffle them with a 1000-element
# buffer, and prefetch so input preparation overlaps with training.
train_dataset = (
    train_dataset
    .cache()
    .shuffle(1000)
    .prefetch(autotune)
)
# Validation pipeline: same caching and prefetching, no shuffling.
val_dataset = (
    val_dataset
    .cache()
    .prefetch(autotune)
)

In [None]:
# Layer that multiplies its input by 1/255.
normalization_layer = tf.keras.layers.Rescaling(1.0 / 255)

In [None]:
# Build a rescaled view of the training data and pull a single batch from it.
with strategy.scope():
    normalized_ds = train_dataset.map(
        lambda pixels, target: (normalization_layer(pixels), target)
    )
    first_batch = iter(normalized_ds)
    image_batch, labels_batch = next(first_batch)

**Generator Model**

In [None]:
def generator_model():
    """Build the conditional generator network.

    The model takes a noise vector of length ``NOISE_DIM`` and an integer
    class label. The label is one-hot encoded, concatenated with the noise,
    projected to a small feature map and upsampled by transposed convolutions
    into an image of shape ``(HEIGHT, WIDTH, 3)``. The last layer uses
    ``tanh``, so pixel values lie in ``[-1, 1]``.

    Returns:
        tf.keras.Model: model mapping ``[noise, label]`` to a generated image.

    Raises:
        ValueError: if ``HEIGHT`` or ``WIDTH`` is not a multiple of 16.
    """
    # The four stride-2 transposed convolutions below each double the spatial
    # size (16x in total). The seed feature map therefore has to be 1/16 of
    # the target image so the output matches IMG_FORMAT, which is the shape
    # the discriminator is built for. A fixed 8x8 seed only reaches 128x128.
    upscale = 2 ** 4
    if HEIGHT % upscale or WIDTH % upscale:
        raise ValueError(
            f"HEIGHT and WIDTH must be multiples of {upscale}, got {HEIGHT}x{WIDTH}."
        )
    seed_height, seed_width = HEIGHT // upscale, WIDTH // upscale

    noise = tf.keras.Input(shape=(NOISE_DIM,))
    label = tf.keras.Input(shape=(1,))

    # Integer label -> one-hot vector with a length-1 "sequence" axis.
    label_cast = tf.cast(label, tf.int32)
    label_one_hot = tf.one_hot(label_cast, num_classes)
    label_reshaped = tf.keras.layers.Reshape((-1, num_classes))(label_one_hot)

    # Give the noise the same (1, features) layout so both can be concatenated
    # along the feature axis.
    noise_expanded = tf.keras.layers.Reshape((1, NOISE_DIM))(noise)

    x = tf.keras.layers.Concatenate(axis=2)([noise_expanded, label_reshaped])
    x = tf.keras.layers.Dense(seed_height * seed_width * 128, use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    # Seed feature map: (HEIGHT / 16, WIDTH / 16, 128).
    x = tf.keras.layers.Reshape((seed_height, seed_width, 128))(x)

    # Upsampling stack: each stride-2 layer doubles height and width.
    x = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    x = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    x = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    x = tf.keras.layers.Conv2DTranspose(16, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU()(x)

    # Final stride-1 layer keeps the spatial size and maps to three channels.
    x = tf.keras.layers.Conv2DTranspose(3, (5, 5), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)

    return tf.keras.Model([noise, label], x)

**Discriminator Model**

In [None]:
def discriminator_model():
    """Build the conditional discriminator network.

    The model takes an image of shape ``IMG_FORMAT`` and an integer class
    label. The label is one-hot encoded, projected by a dense layer to an
    image-shaped tensor and concatenated with the image along the channel
    axis. Two strided convolutions followed by a dense layer then produce a
    single unnormalised score (logit) per sample.

    Returns:
        tf.keras.Model: model mapping ``[image, label]`` to a logit of shape
        ``(batch, 1)``.
    """
    image = tf.keras.Input(shape=IMG_FORMAT)
    label = tf.keras.Input(shape=(1,))

    # Keep `label` bound to the Input placeholder. Re-binding it to the cast
    # result would make the returned Model start from an intermediate tensor
    # instead of the declared input (and differ from generator_model).
    label_cast = tf.cast(label, tf.int32)
    label_one_hot = tf.one_hot(label_cast, num_classes)
    label_one_hot = tf.keras.layers.Reshape((-1, num_classes))(label_one_hot)

    # Project the one-hot label to a tensor with the same shape as the image
    # so it can be stacked onto the image as extra channels.
    label_b = tf.keras.layers.Reshape((1, 1, num_classes))(label_one_hot)
    label_b = tf.keras.layers.Dense(HEIGHT * WIDTH * 3, use_bias=False)(label_b)
    label_b = tf.keras.layers.Reshape(IMG_FORMAT)(label_b)

    x = tf.keras.layers.Concatenate(axis=-1)([image, label_b])

    x = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Dropout(0.3)(x)

    x = tf.keras.layers.Conv2D(32, (5, 5), strides=(2, 2), padding='same')(x)
    x = tf.keras.layers.LeakyReLU()(x)
    x = tf.keras.layers.Dropout(0.3)(x)

    x = tf.keras.layers.Flatten()(x)
    # No activation: the output is a logit, matching the from_logits=True loss.
    x = tf.keras.layers.Dense(1)(x)

    return tf.keras.Model([image, label], x)

**Loss Functions**

In [None]:
# Shared binary cross-entropy used by both loss functions below.
# from_logits=True: the predictions passed in are raw scores, not probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def generator_loss(fake_output):
    """Return the generator loss for a batch of discriminator scores.

    The generator is rewarded when the discriminator scores generated samples
    as real, so the scores are compared against an all-ones target.

    Args:
        fake_output: discriminator output for generated samples.

    Returns:
        The binary cross-entropy between an all-ones target and ``fake_output``.
    """
    wanted = tf.ones_like(fake_output)
    return cross_entropy(wanted, fake_output)

In [None]:
def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: discriminator output for real samples (target: ones).
        fake_output: discriminator output for generated samples (target: zeros).

    Returns:
        Sum of the cross-entropy on the real samples and on the generated ones.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total = loss_on_real + loss_on_fake
    return total

**Optimizers**

In [None]:
# Adam optimizer (learning rate 1e-4) that updates the generator's weights in train_step.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [None]:
# Separate Adam optimizer (learning rate 1e-4) for the discriminator's weights.
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

**Metrics**

In [None]:
# Running mean of the discriminator loss recorded by train_step.
train_loss = tf.keras.metrics.Mean(name='train_loss')
# Accuracy of the discriminator on real (target 1) and generated (target 0)
# samples, updated by train_step.
train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')

In [None]:
# Running mean of the discriminator loss recorded by test_step.
val_loss = tf.keras.metrics.Mean(name='val_loss')
# Accuracy of the discriminator on real (target 1) and generated (target 0)
# validation samples, updated by test_step.
val_accuracy = tf.keras.metrics.BinaryAccuracy(name='val_accuracy')

**Training**

In [None]:
# Instantiate both networks; train_step and test_step use these module-level objects.
# NOTE(review): the models are created outside strategy.scope() — confirm this
# is intended when a multi-device strategy has been selected.
generator = generator_model()
discriminator = discriminator_model()

In [None]:
def train_step(images, labels):
    """Run one adversarial training step on a batch and update the train metrics.

    Args:
        images: batch of real images as produced by the training dataset
            (pixel values in the 0-255 range).
        labels: integer class label for each image in the batch.

    Raises:
        AssertionError: if ``images`` and ``labels`` differ in length.
    """
    assert len(labels) == len(images), "Mismatch in batch size and number of labels."

    # The generator ends in tanh, so generated pixels lie in [-1, 1]. Bring the
    # real images onto the same scale; otherwise the discriminator can tell
    # real from fake purely from the value range.
    images = (tf.cast(images, tf.float32) - 127.5) / 127.5

    noise = tf.random.normal([len(images), NOISE_DIM])

    # Two tapes: one per network, both recording the same forward pass.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator([noise, labels], training=True)

        real_output = discriminator([images, labels], training=True)
        fake_output = discriminator([generated_images, labels], training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    train_loss(disc_loss)
    # The discriminator returns logits, while BinaryAccuracy thresholds its
    # predictions at 0.5. Convert to probabilities so that threshold is valid.
    train_accuracy(tf.ones_like(real_output), tf.sigmoid(real_output))
    train_accuracy(tf.zeros_like(fake_output), tf.sigmoid(fake_output))

In [None]:
def test_step(images, labels):
    """Evaluate the discriminator on one validation batch and update the val metrics.

    No weights are updated; both networks run with ``training=False``.

    Args:
        images: batch of real images as produced by the validation dataset
            (pixel values in the 0-255 range).
        labels: integer class label for each image in the batch.

    Raises:
        AssertionError: if ``images`` and ``labels`` differ in length.
    """
    assert len(labels) == len(images), "Mismatch in batch size and number of labels."

    # Match the generator's tanh output range of [-1, 1] so real and generated
    # images are compared on the same scale.
    images = (tf.cast(images, tf.float32) - 127.5) / 127.5

    noise = tf.random.normal([len(images), NOISE_DIM])
    generated_images = generator([noise, labels], training=False)

    real_output = discriminator([images, labels], training=False)
    fake_output = discriminator([generated_images, labels], training=False)

    t_loss = discriminator_loss(real_output, fake_output)

    val_loss(t_loss)
    # The discriminator returns logits, while BinaryAccuracy thresholds its
    # predictions at 0.5. Convert to probabilities so that threshold is valid.
    val_accuracy(tf.ones_like(real_output), tf.sigmoid(real_output))
    val_accuracy(tf.zeros_like(fake_output), tf.sigmoid(fake_output))

In [None]:
def train_gan():
    """Train the GAN for ``EPOCHS`` epochs and print a summary after each one.

    Every epoch runs ``train_step`` over the whole training dataset, then
    ``test_step`` over the whole validation dataset, and reports the loss and
    accuracy accumulated during that epoch.
    """
    epoch_metrics = (train_loss, train_accuracy, val_loss, val_accuracy)

    with strategy.scope():
        for epoch in range(EPOCHS):
            # Reset at the start of every epoch so the printed figures describe
            # this epoch only rather than a running average over all epochs.
            for metric in epoch_metrics:
                metric.reset_states()

            for train_images, train_labels in tqdm(train_dataset, desc="Training"):
                train_step(train_images, train_labels)

            for test_images, test_labels in tqdm(val_dataset, desc="Validating"):
                test_step(test_images, test_labels)

            template = f"Epoch {epoch + 1}/{EPOCHS}, Loss: {train_loss.result()}, Accuracy: {train_accuracy.result() * 100}, Val Loss: {val_loss.result()}, Val Accuracy: {val_accuracy.result() * 100}"
            # The summary used to be built and discarded; actually show it.
            print(template)

In [None]:
# Run the full training loop (EPOCHS epochs over the training and validation sets).
train_gan()

In [None]:
# Save the generator to "generator.h5", resolved against the current working directory.
# NOTE(review): the MODELS directory defined at the top of the notebook is not
# used here — confirm whether the file should be written there instead.
generator.save("generator.h5")

**Generate Images**

In [None]:
# Reload the generator from the file written by the save cell above; this
# replaces the in-memory `generator` object.
generator = tf.keras.models.load_model('generator.h5')

In [None]:
# Inputs for generating a single sample: one random noise vector of length
# NOISE_DIM and the class index 0, shaped (1, 1) to match the label input.
noise = tf.random.normal([1, NOISE_DIM])
labels = tf.constant([[0]])

In [None]:
# Generate one image and turn it into a 2-D NumPy array.
generated_mfcc = generator([noise, labels], training=False)
# Shift the values by +1 and halve them, then resize to 3600 x 2400.
generated_mfcc = tf.image.resize((generated_mfcc + 1) / 2.0, [3600, 2400])
# Average over the channel axis, drop the batch axis, and leave TensorFlow.
generated_mfcc = tf.squeeze(tf.reduce_mean(generated_mfcc, axis=-1), axis=0).numpy()

In [None]:
# Display the generated 2-D array as an image.
plt.imshow(generated_mfcc)

**Convert Spectrogram Into Audio**

In [None]:
# Invert the generated array in two steps: MFCC -> mel spectrogram -> STFT
# magnitude. Both calls rely on librosa's default parameters.
# NOTE(review): no `sr` is passed to mel_to_stft, so librosa's default sample
# rate applies (22050 Hz per the librosa docs — confirm), while the audio is
# later written at 44100 Hz.
# NOTE(review): generated_mfcc is a resized image with values rescaled from
# the generator output — verify that it is on the scale mfcc_to_mel expects.
mel_spectrogram = librosa.feature.inverse.mfcc_to_mel(generated_mfcc)
stft_spectrogram = librosa.feature.inverse.mel_to_stft(mel_spectrogram)

In [None]:
# Reconstruct a time-domain signal from the STFT magnitudes with the
# Griffin-Lim algorithm (librosa defaults).
audio = librosa.griffinlim(stft_spectrogram)

In [None]:
# Write the reconstructed signal to "reconstructed.wav" in the current working
# directory with a sample rate of 44100 Hz.
# NOTE(review): the spectrogram inversion above does not pass a sample rate,
# so it presumably assumes librosa's default rather than 44100 Hz — confirm
# the two agree, otherwise playback speed and pitch will be off.
sf.write('reconstructed.wav', audio, 44100)