<a href="https://colab.research.google.com/github/JHyunjun/DQTGAN/blob/main/230714_DQTGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Created by Hunjun, JANG
# Recent revision date : 23.07.15
# DQT-GAN(Data Quality Transformation-Generative Adversarial Network)

!pip install pytube
!pip install pydub
!pip install librosa

%cd /content/drive/MyDrive/Colab Notebooks/GAN/DQT-GAN/Data

In [None]:
# Print the current working directory to confirm that the `%cd` in the setup cell took effect
! pwd

In [None]:
from pytube import YouTube
from pydub import AudioSegment
import librosa
import soundfile as sf
import numpy as np
import os
import matplotlib.pyplot as plt

# Source video on YouTube
url = 'https://www.youtube.com/watch?v=83EzIW3MbAI'

# Download the first audio-only stream of the video
yt = YouTube(url)
stream = yt.streams.filter(only_audio=True).first()
if stream is None:
    # `.first()` returns None when the filter matches nothing; fail with a clear message
    # instead of an AttributeError on the next line.
    raise RuntimeError('No audio-only stream available for {}'.format(url))
stream.download(filename='temp.mp4')  # save it as 'temp.mp4'

# Decode the downloaded mp4 so it can be exported as wav
audio = AudioSegment.from_file('temp.mp4')

# Keep only the 2:00-5:00 excerpt (pydub slices are in milliseconds)
audio = audio[2*60*1000:5*60*1000]
# For a quick experiment, keep just 4 seconds instead:
#audio = audio[2*60*1000:2*60*1000+4*1000]
audio.export('audio.wav', format='wav')

# Load the excerpt at both sampling rates (librosa resamples on load)
y_8k, sr_8k = librosa.load('audio.wav', sr=8000)    # low-quality version
y_44k, sr_44k = librosa.load('audio.wav', sr=44100)  # high-quality version
os.makedirs('slices', exist_ok=True)

# Cut the excerpt into consecutive 4-second slices and save every one of them.
#   slices/slice_<i>.wav     -> 44.1 kHz (the file later cells load and resample)
#   slices/slice_<i>_8k.wav  -> 8 kHz copy, stored under its own name so it is no longer
#                               overwritten by the 44.1 kHz write
# Writing all slices (not just slice 0) makes 'slices/slice_1.wav', which the validation
# cell reads, actually exist after a fresh run.
slice_seconds = 4
samples_per_slice_8k = sr_8k * slice_seconds
samples_per_slice_44k = sr_44k * slice_seconds
# Only full-length slices are written; max(1, ...) keeps the previous behaviour of always
# writing slice 0, even if the excerpt turns out to be shorter than 4 seconds.
num_slices = max(1, min(len(y_8k) // samples_per_slice_8k,
                        len(y_44k) // samples_per_slice_44k))
for idx in range(num_slices):
    chunk_8k = y_8k[idx * samples_per_slice_8k:(idx + 1) * samples_per_slice_8k]
    chunk_44k = y_44k[idx * samples_per_slice_44k:(idx + 1) * samples_per_slice_44k]
    sf.write('slices/slice_{}_8k.wav'.format(idx), chunk_8k, sr_8k)
    sf.write('slices/slice_{}.wav'.format(idx), chunk_44k, sr_44k)

# The first 4-second slice at each rate, kept in memory for the plots below
first_slice_8k = y_8k[0:samples_per_slice_8k]
first_slice_44k = y_44k[0:samples_per_slice_44k]

# Plot the first slice at each sampling rate
for samples, title in (
    (first_slice_8k, '8kHz Waveform of the first 4-second audio'),
    (first_slice_44k, '44.1kHz Waveform of the first 4-second audio'),
):
    plt.figure(figsize=(6, 1))
    plt.plot(samples)
    plt.ylabel('Amplitude')
    plt.xlabel('Sample index')
    plt.title(title)
    plt.show()

# Optional: delete the temporary data once the slices have been written
# os.remove('temp.mp4')
# os.remove('audio.wav')

In [None]:
import torch
import torch.nn as nn
import librosa

# RNN-based model
class AudioUpsampler(nn.Module):
    """Recurrent model that maps a low-rate waveform to a higher-rate waveform.

    The input is run through a single-layer ``nn.RNN``; a linear layer projects the
    hidden features back to ``input_size`` channels.

    Two modes are supported:

    * **Upsampling mode** (``upsample_factor`` given at construction, or
      ``target_length`` passed to ``forward``): the RNN output of *every* time step is
      projected and then linearly interpolated along the time axis, so the result is a
      full waveform of the requested length.
    * **Legacy mode** (neither is given): only the final hidden state is projected,
      giving one frame per batch item. This is kept solely so that existing
      two-argument construction keeps returning the shape it always did.

    Args:
        input_size: number of channels per time step (1 for mono audio).
        hidden_size: width of the RNN hidden state.
        upsample_factor: optional ratio ``output_rate / input_rate``
            (e.g. ``44100 / 8000``). ``None`` selects legacy mode by default.
    """

    def __init__(self, input_size, hidden_size, upsample_factor=None):
        super(AudioUpsampler, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, input_size)
        self.upsample_factor = upsample_factor

    def forward(self, x, target_length=None):
        """Run the model.

        Args:
            x: tensor of shape ``(batch, seq_len, input_size)`` (``batch_first=True``).
            target_length: optional number of output time steps. When omitted it is
                derived as ``round(seq_len * upsample_factor)``.

        Returns:
            ``(batch, target_length, input_size)`` in upsampling mode, or
            ``(batch, 1, input_size)`` in legacy mode.
        """
        if target_length is None and self.upsample_factor is None:
            # Legacy mode: a single frame computed from the last hidden state.
            _, hidden = self.rnn(x)
            output = self.linear(hidden.squeeze(0))
            return output.unsqueeze(1)

        if target_length is None:
            target_length = int(round(x.size(1) * self.upsample_factor))

        # Per-time-step features -> per-time-step samples at the input rate.
        features, _ = self.rnn(x)          # (batch, seq_len, hidden_size)
        frames = self.linear(features)     # (batch, seq_len, input_size)

        # `interpolate` works on (batch, channels, length), hence the transposes.
        upsampled = nn.functional.interpolate(
            frames.transpose(1, 2),
            size=target_length,
            mode='linear',
            align_corners=False,
        )
        # contiguous() keeps downstream `.view(...)` calls valid after the transpose.
        return upsampled.transpose(1, 2).contiguous()


# Make weight initialisation (and therefore the training run) reproducible
torch.manual_seed(0)

# Training data: the same 4-second slice loaded at the input and the target rate
input_data_8k, _ = librosa.load('slices/slice_0.wav', sr=8000)
output_data_44k, _ = librosa.load('slices/slice_0.wav', sr=44100)

# Reshape to (batch=1, seq_len, channels=1) as expected by the batch_first RNN
input_data_8k = torch.tensor(input_data_8k, dtype=torch.float32).view(1, -1, 1)
output_data_44k = torch.tensor(output_data_44k, dtype=torch.float32).view(1, -1, 1)

# Hyperparameters
input_size = 1
hidden_size = 64
learning_rate = 0.001
num_epochs = 100

# Model, loss function and optimizer.
# upsample_factor lets later cells call `model(x)` and still get a full 44.1 kHz waveform.
model = AudioUpsampler(input_size, hidden_size, upsample_factor=44100 / 8000)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Ask the model for exactly as many samples as the target has, so the loss compares the
# two waveforms sample by sample instead of broadcasting one predicted value over the
# whole target.
target_length = output_data_44k.size(1)

# Training
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(input_data_8k, target_length=target_length)
    loss = criterion(output, output_data_44k)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))


In [None]:
import matplotlib.pyplot as plt

# Validation clip: the same file loaded at 8 kHz (model input) and 44.1 kHz (reference)
validation_input_data_8k, _ = librosa.load('slices/slice_1.wav', sr=8000)
validation_output_data_44k, _ = librosa.load('slices/slice_1.wav', sr=44100)

# Reshape the input to (batch=1, seq_len, channels=1)
validation_input_data_8k = torch.Tensor(validation_input_data_8k).view(1, -1, 1)

# Predict without tracking gradients, then flatten the result into a 1-D numpy array
model.eval()
with torch.no_grad():
    validation_output_predicted = model(validation_input_data_8k).view(-1).numpy()

# Draw the reference waveform first, then the model's prediction, one figure each
for waveform, heading in (
    (validation_output_data_44k, 'Waveform of the original 4-second audio (44.1kHz)'),
    (validation_output_predicted, 'Waveform of the predicted 4-second audio (44.1kHz)'),
):
    fig, ax = plt.subplots(figsize=(14, 5))
    ax.plot(waveform)
    ax.set_ylabel('Amplitude')
    ax.set_xlabel('Sample index')
    ax.set_title(heading)
    plt.show()
