<a href="https://colab.research.google.com/github/JHyunjun/DQTGAN/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Created by Hunjun, JANG
# Recent revision date : 23.07.15
# DQT-GAN(Data Quality Transformation-Generative Adversarial Network)

!pip install pytube
!pip install pydub
!pip install librosa

%cd /content/drive/MyDrive/Colab Notebooks/GAN/DQT-GAN/Data

# This cell runs before the main import cell, so bring TensorFlow into scope
# here; otherwise `tf` is undefined on a fresh runtime.
import tensorflow as tf

# Ask TensorFlow to allocate GPU memory on demand for every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Report the failure but keep the notebook running with default settings.
        print(e)

In [None]:
# Print the current working directory to check the path
! pwd

In [None]:
# Import libraries
import os
import numpy as np
import librosa
from pytube import YouTube
import tensorflow as tf
from pydub import AudioSegment
from tensorflow.keras import layers

# Sampling rate (Hz) the low-resolution audio is resampled to; also used as the
# generator's input length and the number of samples per training clip.
input_sampling_rate = 8000
# Sampling rate (Hz) the high-resolution target is resampled to; also used as
# the discriminator's input length and the number of samples per target clip.
output_sampling_rate = 32000

# Build the generator and discriminator networks of the SRGAN
def srgan_model():
    """Construct the two Keras models that make up the SRGAN.

    The generator takes a clip of shape (input_sampling_rate, 1), applies two
    64-filter convolutions with LeakyReLU, upsamples the time axis by 4 and
    projects back to a single channel. The discriminator takes a clip of shape
    (output_sampling_rate, 1), applies two stride-2 convolutions (64 and 128
    filters) with LeakyReLU, flattens, and ends in a single sigmoid unit.

    Returns:
        A (generator_model, discriminator_model) tuple of tf.keras.Model.
    """
    # ---- Generator ----
    gen_in = layers.Input(shape=(input_sampling_rate, 1))
    hidden = gen_in
    for _ in range(2):
        hidden = layers.Conv1D(64, kernel_size=3, padding='same')(hidden)
        hidden = layers.LeakyReLU()(hidden)
    hidden = layers.UpSampling1D(4)(hidden)  # stretch the time axis 4x
    gen_out = layers.Conv1D(1, kernel_size=3, padding='same')(hidden)
    generator_model = tf.keras.Model(gen_in, gen_out, name='Generator')

    # ---- Discriminator ----
    disc_in = layers.Input(shape=(output_sampling_rate, 1))
    features = disc_in
    for n_filters in (64, 128):
        features = layers.Conv1D(n_filters, kernel_size=3, strides=2, padding='same')(features)
        features = layers.LeakyReLU()(features)
    features = layers.Flatten()(features)
    disc_out = layers.Dense(1, activation='sigmoid')(features)
    discriminator_model = tf.keras.Model(disc_in, disc_out, name='Discriminator')

    return generator_model, discriminator_model

# Define the losses and optimizers used to train the models
def compile_srgan(generator, discriminator):
    """Create the loss functions and optimizers for SRGAN training.

    Args:
        generator: The generator model. Not used here; kept in the signature
            for existing callers.
        discriminator: The discriminator model. Not used here; kept in the
            signature for existing callers.

    Returns:
        A tuple (bce_loss_fn, mse_loss_fn, generator_optimizer,
        discriminator_optimizer).
    """
    # The discriminator built by srgan_model() ends in a sigmoid, so its output
    # is already a probability. from_logits must therefore be False; with True
    # the loss would apply a second sigmoid to the probabilities.
    bce_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False, label_smoothing=0.1)
    mse_loss_fn = tf.keras.losses.MeanSquaredError()
    generator_optimizer = tf.keras.optimizers.Adam()
    discriminator_optimizer = tf.keras.optimizers.Adam()
    return bce_loss_fn, mse_loss_fn, generator_optimizer, discriminator_optimizer

# Download and preprocess the audio
def download_and_preprocess(link):
    """Download a YouTube audio track and cut it into paired training clips.

    The audio stream is saved as 'audio.mp4' and converted to 'audio.wav' in
    the current working directory. Up to three minutes starting at the
    two-minute mark are loaded, resampled to input_sampling_rate and
    output_sampling_rate, and sliced into clips of that many samples each.

    Args:
        link: URL of the YouTube video.

    Returns:
        A (low_res_clips, high_res_clips) tuple of 2-D arrays with shapes
        (n_clips, input_sampling_rate) and (n_clips, output_sampling_rate).
        Both arrays contain the same number of full-length clips; any trailing
        partial clip is dropped.

    Raises:
        ValueError: If the video exposes no audio-only stream.
    """
    # Download audio
    youtube = YouTube(link)
    stream = youtube.streams.filter(only_audio=True).first()
    if stream is None:
        raise ValueError(f'No audio-only stream available for {link}')
    stream.download(filename='audio.mp4')

    # Convert to WAV so librosa can read it
    AudioSegment.from_file('audio.mp4').export('audio.wav', format='wav')

    # Load and resample audio
    audio, sr = librosa.load('audio.wav', sr=None, offset=2*60, duration=3*60)
    low_res = librosa.resample(audio, orig_sr=sr, target_sr=input_sampling_rate)
    high_res = librosa.resample(audio, orig_sr=sr, target_sr=output_sampling_rate)

    # Keep only complete clips, and the same number on both sides, so the
    # result is a regular 2-D array and clip i of one array matches clip i of
    # the other. A ragged final clip would otherwise break np.array / the model.
    n_clips = min(len(low_res) // input_sampling_rate, len(high_res) // output_sampling_rate)
    low_res_clips = low_res[:n_clips * input_sampling_rate].reshape(n_clips, input_sampling_rate)
    high_res_clips = high_res[:n_clips * output_sampling_rate].reshape(n_clips, output_sampling_rate)

    return low_res_clips, high_res_clips


# Define a training step
def train_step(generator, discriminator, bce_loss_fn, mse_loss_fn, generator_optimizer, discriminator_optimizer, input_audio, target_audio):
    """Run one generator/discriminator update on a single clip pair.

    Args:
        generator: Model mapping a low-resolution clip to a high-resolution one.
        discriminator: Model scoring a high-resolution clip as real or fake.
        bce_loss_fn: Adversarial (binary cross-entropy) loss.
        mse_loss_fn: Reconstruction (mean squared error) loss.
        generator_optimizer: Optimizer applied to the generator's variables.
        discriminator_optimizer: Optimizer applied to the discriminator's variables.
        input_audio: One low-resolution clip (1-D samples, or samples x 1).
        target_audio: The matching high-resolution clip (1-D samples, or samples x 1).

    Returns:
        A (gen_loss, disc_loss) tuple with the losses computed for this step.
    """
    # Give both clips an explicit (batch=1, samples, channels=1) float32 shape,
    # built once. This keeps the MSE term comparing like-shaped tensors instead
    # of depending on implicit rank handling between (1, N) and (1, N, 1).
    input_batch = tf.reshape(tf.cast(input_audio, tf.float32), (1, -1, 1))
    target_batch = tf.reshape(tf.cast(target_audio, tf.float32), (1, -1, 1))

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_audio = generator(input_batch, training=True)
        real_output = discriminator(target_batch, training=True)
        fake_output = discriminator(generated_audio, training=True)

        # Generator: match the target waveform and make the discriminator label fakes as real.
        gen_loss = mse_loss_fn(target_batch, generated_audio) + bce_loss_fn(tf.ones_like(fake_output), fake_output)
        # Discriminator: label real clips 1 and generated clips 0.
        disc_loss = bce_loss_fn(tf.ones_like(real_output), real_output) + bce_loss_fn(tf.zeros_like(fake_output), fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Define the training loop
def train_srgan(link, epochs):
    """Download audio from a YouTube link and train the SRGAN on it.

    Args:
        link: URL of the YouTube video used as training data.
        epochs: Number of full passes over the extracted clips.

    Progress (epoch, clip index, and both losses) is printed after every clip.
    NOTE(review): the trained models are local to this function and are not
    returned or saved.
    """
    # Download and preprocess audio
    audio_8k, audio_44k = download_and_preprocess(link)

    # Build and compile the SRGAN model
    generator, discriminator = srgan_model()
    bce_loss_fn, mse_loss_fn, generator_optimizer, discriminator_optimizer = compile_srgan(generator, discriminator)

    # zip() stops at the shorter array, so report that count as the total.
    n_clips = min(len(audio_8k), len(audio_44k))
    for epoch in range(epochs):
        for i, (low_res_clip, high_res_clip) in enumerate(zip(audio_8k, audio_44k)):
            gen_loss, disc_loss = train_step(generator, discriminator, bce_loss_fn, mse_loss_fn, generator_optimizer, discriminator_optimizer, low_res_clip, high_res_clip)
            print(f'Epoch {epoch+1}/{epochs}, Clip {i+1}/{n_clips}, Generator Loss: {gen_loss}, Discriminator Loss: {disc_loss}')

# Start training: pass the source YouTube video and the epoch count by name
train_srgan(link='https://www.youtube.com/watch?v=83EzIW3MbAI', epochs=10)
