<a href="https://colab.research.google.com/github/JHyunjun/DQTGAN/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Created by Hunjun, JANG
# Recent revision date : 23.07.15
# DQT-GAN(Data Quality Transformation-Generative Adversarial Network)

!pip install pytube
!pip install pydub
!pip install librosa

%cd /content/drive/MyDrive/Colab Notebooks/GAN/DQT-GAN/Data

# TensorFlow is imported here because this cell runs before the cell that
# holds the notebook's main import block; without it `tf` is undefined.
import tensorflow as tf

# Ask TensorFlow to allocate GPU memory on demand instead of reserving all of
# it up front. Memory growth has to be configured before the GPUs are
# initialised; TensorFlow raises RuntimeError otherwise, which is reported
# but not treated as fatal.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

In [None]:
#Check the Path
! pwd

In [None]:
import os
import numpy as np
import librosa
from pytube import YouTube
import tensorflow as tf
from pydub import AudioSegment
from tensorflow.keras import layers

# Sampling rate (Hz) of the low-resolution signal: used as the resampling
# target for the generator input and to size the generator's input layer.
input_sampling_rate = 400
# Sampling rate (Hz) of the high-resolution signal: used as the resampling
# target for the training target and to size the discriminator's input layer.
output_sampling_rate = 1600
clip_duration = 4  # Clip duration in seconds

# Build the generator / discriminator pair
def srgan_model():
    """Construct the 1-D super-resolution generator and its discriminator.

    Generator: takes ``(input_sampling_rate * clip_duration, 1)`` samples,
    applies four kernel-5 'same'-padded Conv1D layers (256, 128, 64, 32
    filters) each followed by LeakyReLU, upsamples along time by
    ``int(output_sampling_rate / input_sampling_rate)`` and projects back to
    one channel with a final kernel-5 Conv1D.

    Discriminator: takes ``(output_sampling_rate * clip_duration, 1)``
    samples, applies four stride-2 kernel-3 Conv1D layers (32, 64, 128, 256
    filters) each followed by LeakyReLU, flattens, and emits one score
    through a Dense layer with no activation.

    Returns:
        Tuple ``(generator_model, discriminator_model)`` of uncompiled
        ``tf.keras.Model`` instances named 'Generator' and 'Discriminator'.
    """
    upscale_factor = int(output_sampling_rate / input_sampling_rate)

    # ---- Generator ----
    low_res_in = layers.Input(shape=(input_sampling_rate * clip_duration, 1))
    features = low_res_in
    for n_filters in (256, 128, 64, 32):
        features = layers.Conv1D(n_filters, kernel_size=5, padding='same')(features)
        features = layers.LeakyReLU()(features)
    features = layers.UpSampling1D(upscale_factor)(features)
    high_res_out = layers.Conv1D(1, kernel_size=5, padding='same')(features)
    generator_model = tf.keras.Model(low_res_in, high_res_out, name='Generator')

    # ---- Discriminator ----
    critic_in = layers.Input(shape=(output_sampling_rate * clip_duration, 1))
    features = critic_in
    for n_filters in (32, 64, 128, 256):
        # Each strided convolution halves the time resolution.
        features = layers.Conv1D(n_filters, kernel_size=3, strides=2, padding='same')(features)
        features = layers.LeakyReLU()(features)
    features = layers.Flatten()(features)
    critic_out = layers.Dense(1)(features)
    discriminator_model = tf.keras.Model(critic_in, critic_out, name='Discriminator')

    return generator_model, discriminator_model

# Create the optimizers used to train the two networks
def compile_srgan(generator, discriminator):
    """Return a pair of RMSprop optimizers, one per network.

    Args:
        generator: Accepted for interface symmetry; not used here.
        discriminator: Accepted for interface symmetry; not used here.

    Returns:
        Tuple ``(generator_optimizer, discriminator_optimizer)``, each an
        independent ``tf.keras.optimizers.RMSprop`` with learning rate 1e-7.
    """
    def _new_optimizer():
        # A fresh instance per network so optimizer state is never shared.
        return tf.keras.optimizers.RMSprop(learning_rate=1e-7)

    optimizer_for_generator = _new_optimizer()
    optimizer_for_discriminator = _new_optimizer()
    return optimizer_for_generator, optimizer_for_discriminator

# Download and preprocess the audio
def _count_clips(total_samples, clip_samples, hop_samples):
    """Return how many full clips of `clip_samples` fit when stepping by `hop_samples`."""
    if total_samples < clip_samples:
        return 0
    return (total_samples - clip_samples) // hop_samples + 1


def download_and_preprocess(link):
    """Download a YouTube audio track and cut it into aligned low/high-rate clips.

    The audio-only stream is saved as 'audio.mp4', converted to 'audio.wav',
    and five minutes starting at the two-minute mark are loaded at the file's
    native rate. That excerpt is resampled to ``input_sampling_rate`` and to
    ``output_sampling_rate``, and both versions are cut into
    ``clip_duration``-second clips with a one-second hop.

    Clip ``k`` of both outputs starts at second ``k`` of the excerpt, so each
    low-rate clip and its high-rate counterpart cover the same time span.

    Args:
        link: URL of the YouTube video to download.

    Returns:
        Tuple ``(low_rate_clips, high_rate_clips)`` of 2-D arrays with shapes
        ``(n_clips, input_sampling_rate * clip_duration)`` and
        ``(n_clips, output_sampling_rate * clip_duration)``. ``n_clips`` is 0
        when the excerpt is shorter than one clip.

    Raises:
        ValueError: If the video exposes no audio-only stream.
    """
    # Download audio
    stream = YouTube(link).streams.filter(only_audio=True).first()
    if stream is None:
        raise ValueError(f'No audio-only stream available for {link}')
    stream.download(filename='audio.mp4')

    AudioSegment.from_file('audio.mp4').export('audio.wav', format='wav')

    # Load at the native rate (sr=None), then derive both resolutions from it
    waveform, native_sr = librosa.load('audio.wav', sr=None, offset=2 * 60, duration=5 * 60)
    low_rate = librosa.resample(waveform, orig_sr=native_sr, target_sr=input_sampling_rate)
    high_rate = librosa.resample(waveform, orig_sr=native_sr, target_sr=output_sampling_rate)

    low_clip_len = input_sampling_rate * clip_duration
    high_clip_len = output_sampling_rate * clip_duration

    # The hop is one second, i.e. `input_sampling_rate` samples at the low
    # rate and `output_sampling_rate` samples at the high rate. Taking the
    # minimum keeps every clip full-length even if the two resampled signals
    # differ by a few samples due to rounding.
    n_clips = min(
        _count_clips(len(low_rate), low_clip_len, input_sampling_rate),
        _count_clips(len(high_rate), high_clip_len, output_sampling_rate),
    )

    # Start offsets are expressed in each signal's own sample units so that
    # clip k begins at the same instant in both resolutions.
    low_rate_clips = np.array([
        low_rate[k * input_sampling_rate:k * input_sampling_rate + low_clip_len]
        for k in range(n_clips)
    ]).reshape(n_clips, low_clip_len)
    high_rate_clips = np.array([
        high_rate[k * output_sampling_rate:k * output_sampling_rate + high_clip_len]
        for k in range(n_clips)
    ]).reshape(n_clips, high_clip_len)

    return low_rate_clips, high_rate_clips


# Define a training step
def train_step(generator, discriminator, generator_optimizer, discriminator_optimizer, input_audio, target_audio):
    """Run one adversarial update on a single low-rate/high-rate clip pair.

    Both networks are updated once from the same forward pass, using
    Wasserstein-style losses on the discriminator's raw scores:

    * generator loss     = -mean(D(G(x)))   (raise the score of generated audio)
    * discriminator loss = mean(D(G(x))) - mean(D(y))
      (score real audio above generated audio)

    The discriminator loss is the opposite sign of the generator's objective
    on the fake term; otherwise both networks would push the fake score in the
    same direction and training would not be adversarial.

    Args:
        generator: Keras model mapping low-rate clips to high-rate clips.
        discriminator: Keras model scoring high-rate clips.
        generator_optimizer: Optimizer applied to the generator's variables.
        discriminator_optimizer: Optimizer applied to the discriminator's variables.
        input_audio: One low-rate clip, shape ``(samples,)`` or ``(samples, 1)``.
        target_audio: The matching high-rate clip, same layout as `input_audio`.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar tensors.
    """
    # Add explicit batch and channel axes -> (1, samples, 1), the layout the
    # Conv1D models are built for, rather than relying on implicit expansion.
    low_rate_batch = np.reshape(input_audio, (1, -1, 1))
    high_rate_batch = np.reshape(target_audio, (1, -1, 1))

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_audio = generator(low_rate_batch, training=True)
        real_output = discriminator(high_rate_batch, training=True)
        fake_output = discriminator(generated_audio, training=True)

        gen_loss = -tf.reduce_mean(fake_output)
        disc_loss = tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)
        # NOTE(review): no weight clipping or gradient penalty is applied, so
        # the discriminator's scores are unconstrained — confirm whether a
        # Lipschitz constraint is intended for this Wasserstein-style loss.

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

# Define the training loop
def train_srgan(link, epochs):
    """Download the audio at `link` and train a fresh generator/discriminator pair.

    Each epoch visits every clip pair once, running one `train_step` per pair
    and printing the resulting losses.

    Args:
        link: URL of the YouTube video used as training data.
        epochs: Number of passes over all clip pairs.

    Returns:
        Tuple ``(generator, discriminator)`` holding the trained models, so
        they remain usable after training finishes.
    """
    # Download and preprocess audio
    low_rate_clips, high_rate_clips = download_and_preprocess(link)

    # Build the models and their optimizers
    generator, discriminator = srgan_model()
    generator_optimizer, discriminator_optimizer = compile_srgan(generator, discriminator)

    n_clips = len(low_rate_clips)
    for epoch in range(epochs):
        clip_pairs = zip(low_rate_clips, high_rate_clips)
        for clip_number, (low_rate_clip, high_rate_clip) in enumerate(clip_pairs, start=1):
            gen_loss, disc_loss = train_step(generator, discriminator, generator_optimizer, discriminator_optimizer, low_rate_clip, high_rate_clip)
            print(f'Epoch {epoch+1}/{epochs}, Clip {clip_number}/{n_clips}, Generator Loss: {gen_loss}, Discriminator Loss: {disc_loss}')

    return generator, discriminator

# Finally, call the training function with the YouTube link and number of epochs.
# The returned models are bound at notebook level so later cells can use them.
generator, discriminator = train_srgan('https://www.youtube.com/watch?v=83EzIW3MbAI', 5)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Define the validation function for a single sample
def validate_sample(generator, input_audio, target_audio, clip_duration):
    """Plot a low-rate clip, the generator's output for it, and the true target.

    The three waveforms are drawn in stacked panels against a shared time
    axis in seconds, so they can be compared directly.

    Args:
        generator: Keras model mapping a low-rate clip to a high-rate clip.
        input_audio: 1-D array holding one clip sampled at ``input_sampling_rate``.
        target_audio: 1-D array holding the matching clip sampled at
            ``output_sampling_rate``.
        clip_duration: Clip length in seconds; used as the upper x-axis limit.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Add explicit batch and channel axes -> (1, samples, 1) for the Conv1D model.
    model_input = np.reshape(input_audio, (1, -1, 1))
    generated_audio = generator.predict(model_input).squeeze()

    # Time axes in seconds: sample index divided by that signal's sampling rate
    time_input = np.arange(input_audio.shape[0]) / input_sampling_rate
    time_generated = np.arange(generated_audio.shape[0]) / output_sampling_rate
    time_target = np.arange(target_audio.shape[0]) / output_sampling_rate

    # (time axis, samples, legend label, panel title) for each row of the figure
    panels = (
        (time_input, input_audio, 'Input', 'Input 0.4kHz'),
        (time_generated, generated_audio, 'Generated 1.6kHz', 'Generated Audio (1.6kHz)'),
        (time_target, target_audio, 'Target', 'Original 1.6kHz(Target)'),
    )

    plt.figure(figsize=(20, 8))

    for row, (time_axis, samples, label, title) in enumerate(panels, start=1):
        plt.subplot(3, 1, row)
        plt.plot(time_axis, samples, label=label)
        plt.title(title)
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        # Same window on every panel so the waveforms line up vertically.
        plt.xlim(0, clip_duration)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Download and preprocess audio again to obtain clip pairs for validation.
# Despite the "8k"/"44k" names, the clips are sampled at input_sampling_rate
# and output_sampling_rate respectively.
audio_8k_clips, audio_44k_clips = download_and_preprocess('https://www.youtube.com/watch?v=83EzIW3MbAI')

# Finally, call the validation function with the generator, input audio, and target audio.
# `generator` must already be bound in the notebook's global namespace (the
# trained model) before this cell runs; only the first clip pair is plotted.
validate_sample(generator, audio_8k_clips[0], audio_44k_clips[0], clip_duration)
