<a href="https://colab.research.google.com/github/JHyunjun/DQTGAN/blob/main/WavtoImage_GAN_230815.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Code Author : Hyunjun JANG, Hyundai Motors Group
# Date : 2023.08.02
# GAN based Audio Data Quality Transformation
# Project code : DQT-GAN
# All rights reserved : github.com/JHyunjun


!pip install yt-dlp
!pip install pydub
!pip install librosa
!pip install soundfile
import os
import yt_dlp
import librosa
from pydub import AudioSegment
import numpy as np
import torch

# YouTube video URL (the `list=WL&index=1` part names a playlist; see `noplaylist` below)
youtube_url = 'https://www.youtube.com/watch?v=I2ZEMjFJtzM&list=WL&index=1'

# Download the best audio stream and convert it to downloaded_audio.wav via ffmpeg;
# later cells read that exact filename.
ydl_opts = {
    'format': 'bestaudio/best',
    'outtmpl': 'downloaded_audio.%(ext)s',
    # The URL also carries a playlist id. Fetch only the single video: otherwise
    # yt-dlp may try to download the playlist, and every entry would be written to
    # the same fixed output name above.
    'noplaylist': True,
    'postprocessors': [{
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'wav',
        'preferredquality': '192',
    }],
}

with yt_dlp.YoutubeDL(ydl_opts) as ydl:
    ydl.download([youtube_url])

# If yt-dlp raises an error here, try restarting the runtime and rerunning the cell.

In [None]:
# NOTE(review): this cell is a disabled earlier version of the preprocessing pipeline
# (8192 Hz sampling, n_fft=512, hop_length=128, linear |STFT| magnitude, 256x256 crop,
# z-score normalization). The whole body is wrapped in a string literal, so running
# the cell only evaluates that string (Jupyter echoes it as the cell output); none of
# the code inside executes. Consider deleting it or moving it to version history.
'''
#%% Basic settings
audio_length = 4 # seconds
audio_length_ms = audio_length * 1000
data_overlap = 50 # percent
data_overlap_ps = data_overlap / 100
sampling_rate = 8192

os.makedirs("data_folder/wav_data", exist_ok=True)
os.makedirs("data_folder/mp3_data", exist_ok=True)

wav_path = "data_folder/wav_data"
mp3_path = "data_folder/mp3_data"

# Load the audio file
base_wav = AudioSegment.from_wav("downloaded_audio.wav")
audio = base_wav.set_frame_rate(sampling_rate)

# Segment the audio file and save each segment
num_segments = int(len(audio) / (audio_length_ms * data_overlap_ps))

for i in range(1, num_segments):
    tmp_fname_wav = f"audio_wav_{i:04}.wav"
    tmp_fname_mp3 = f"audio_mp3_{i:04}.mp3"
    tmp_audio = audio[(i-1)*audio_length_ms*data_overlap_ps : (i+1)*audio_length_ms*data_overlap_ps]
    tmp_audio.export(os.path.join(wav_path, tmp_fname_wav), format="wav")
    tmp_audio.export(os.path.join(mp3_path, tmp_fname_mp3), format="mp3")

# Load the segmented audio files and compute their STFT
n_fft = 512
hop_length = 128

wav_files = os.listdir(wav_path)
mp3_files = os.listdir(mp3_path)

wav_data = []
mp3_data = []

for i, file in enumerate(wav_files):
    y, sr = librosa.load(os.path.join(wav_path, file), sr=sampling_rate)
    S1 = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    # Apply absolute to get the magnitude
    globals()[f"wav_{i:04}"] = np.abs(S1)
    wav_data.append(np.abs(S1))

for i, file in enumerate(mp3_files):
    y, sr = librosa.load(os.path.join(mp3_path, file), sr=sampling_rate)
    S2 = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    # Apply absolute to get the magnitude
    globals()[f"mp3_{i:04}"] = np.abs(S2)
    mp3_data.append(np.abs(S2))

# Convert the lists to numpy arrays
wav_data = np.array(wav_data)
mp3_data = np.array(mp3_data)

wav_data = wav_data[:, :256, :256]
mp3_data = mp3_data[:, :256, :256]

# Reshape the data if necessary
wav_data = np.expand_dims(wav_data, axis=1)  # Add channel dimension for PyTorch
mp3_data = np.expand_dims(mp3_data, axis=1)  # Add channel dimension for PyTorch

# Convert numpy arrays to PyTorch tensors
wav_data = torch.tensor(wav_data).float()
mp3_data = torch.tensor(mp3_data).float()

print("Max mp3 : ",torch.max(mp3_data))
print("Max wav : ",torch.max(wav_data))


##Z-score
# Calculate the mean and standard deviation
wav_mean = torch.mean(wav_data)
wav_std = torch.std(wav_data)
mp3_mean = torch.mean(mp3_data)
mp3_std = torch.std(mp3_data)

# Normalize data by subtracting the mean and dividing by the standard deviation (z-score normalization)
wav_data = (wav_data - wav_mean) / wav_std
mp3_data = (mp3_data - mp3_mean) / mp3_std


# Create PyTorch datasets
wav_dataset = torch.utils.data.TensorDataset(wav_data)
mp3_dataset = torch.utils.data.TensorDataset(mp3_data)

'''

In [None]:
# Basic settings
audio_length = 4  # segment length (seconds)
audio_length_ms = audio_length * 1000
data_overlap = 50  # overlap between consecutive segments (percent, 0 <= x < 100)
data_overlap_ps = data_overlap / 100
sampling_rate = 16384

os.makedirs("data_folder/wav_data", exist_ok=True)
os.makedirs("data_folder/mp3_data", exist_ok=True)

wav_path = "data_folder/wav_data"
mp3_path = "data_folder/mp3_data"

# Load the downloaded audio and resample it to the working rate
base_wav = AudioSegment.from_wav("downloaded_audio.wav")
audio = base_wav.set_frame_rate(sampling_rate)

# Split the audio into overlapping windows of exactly audio_length_ms and export each
# window twice: lossless wav (target) and lossy mp3 (input). The hop is derived from
# the overlap so the window length stays fixed for any overlap setting; at 50 % this
# yields the same 0-4 s, 2-6 s, ... windows and file names as before.
assert 0 <= data_overlap_ps < 1, "data_overlap must be in [0, 100)"
hop_ms = audio_length_ms * (1 - data_overlap_ps)
num_segments = max(0, int((len(audio) - audio_length_ms) // hop_ms) + 1)

# Record the exported names in order. os.listdir() order is arbitrary, and training
# pairs wav_data[k] with mp3_data[k], so the k-th entries must be the same segment.
# This also ignores stale files left in the folders by earlier runs.
wav_files = []
mp3_files = []
for i in range(1, num_segments + 1):
    tmp_fname_wav = f"audio_wav_{i:04}.wav"
    tmp_fname_mp3 = f"audio_mp3_{i:04}.mp3"
    start_ms = (i - 1) * hop_ms
    tmp_audio = audio[start_ms : start_ms + audio_length_ms]
    tmp_audio.export(os.path.join(wav_path, tmp_fname_wav), format="wav")
    tmp_audio.export(os.path.join(mp3_path, tmp_fname_mp3), format="mp3")
    wav_files.append(tmp_fname_wav)
    mp3_files.append(tmp_fname_mp3)

# STFT settings for the segments
n_fft = 510       # n_fft // 2 + 1 = 256 frequency bins
hop_length = 257  # 1 + 65536 // 257 = 256 frames for a 4 s segment at 16384 Hz (librosa's default center=True)


def _log_magnitude_db(path):
    """Load ``path`` at ``sampling_rate`` and return its STFT magnitude in dB.

    Uses ``librosa.amplitude_to_db`` with its default arguments, exactly as before.
    """
    y, _ = librosa.load(path, sr=sampling_rate)
    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
    return librosa.amplitude_to_db(np.abs(S))


wav_specs = [_log_magnitude_db(os.path.join(wav_path, f)) for f in wav_files]
mp3_specs = [_log_magnitude_db(os.path.join(mp3_path, f)) for f in mp3_files]

# A decoded mp3 may not contain exactly as many samples as its wav twin (encoder
# padding / resampling), which would make spectrogram shapes differ and break
# stacking. Crop everything to the common (freq, time) shape; no-op when equal.
n_freq = min(s.shape[0] for s in wav_specs + mp3_specs)
n_frames = min(s.shape[1] for s in wav_specs + mp3_specs)
wav_data = np.stack([s[:n_freq, :n_frames] for s in wav_specs])
mp3_data = np.stack([s[:n_freq, :n_frames] for s in mp3_specs])

# Diagnostics for the first pair (previously the max/min lines showed the last file)
print("MP3")
print(mp3_data[0])
print(np.max(mp3_data[0]))
print(np.min(mp3_data[0]))
print("--------------------------------------------------------")
print("WAV")
print(wav_data[0])
print(np.max(wav_data[0]))
print(np.min(wav_data[0]))

# Joint dB range over both domains so wav and mp3 share one normalization scale
mag_max = max(np.max(wav_data), np.max(mp3_data))
mag_min = min(np.min(wav_data), np.min(mp3_data))

wav_data = np.expand_dims(wav_data, axis=1)  # add the channel dimension for PyTorch
mp3_data = np.expand_dims(mp3_data, axis=1)  # add the channel dimension for PyTorch

# Min-max normalize to [-1, 1] (the generator's Tanh output range)
wav_data = ((wav_data - mag_min) / (mag_max - mag_min)) * 2 - 1
mp3_data = ((mp3_data - mag_min) / (mag_max - mag_min)) * 2 - 1

# numpy -> PyTorch tensors
wav_data = torch.tensor(wav_data).float()
mp3_data = torch.tensor(mp3_data).float()

# PyTorch datasets (index k of each holds the same audio segment)
wav_dataset = torch.utils.data.TensorDataset(wav_data)
mp3_dataset = torch.utils.data.TensorDataset(mp3_data)


In [None]:
import matplotlib.pyplot as plt

# Quick look at the first normalized mp3 / wav spectrogram pair.
# origin="lower" puts frequency bin 0 at the bottom (imshow's default draws row 0 on top).
for label, data in (("mp3", mp3_data), ("wav", wav_data)):
    spectrogram = data[0][0]
    plt.imshow(spectrogram, origin="lower", aspect="auto")
    plt.title(f"First {label} spectrogram (normalized dB)")
    print(spectrogram.shape)
    plt.show()

In [None]:
# Normalization bounds shared by the wav and mp3 spectrograms
print(mag_max)
print(mag_min)

# Shapes of the stacked (N, 1, freq, time) tensors
for stacked in (wav_data, mp3_data):
    print(stacked.shape)

# First mp3 item as stored in the dataset (each item is a 1-tuple of tensors)
first_mp3_item, = mp3_dataset[0]
print(first_mp3_item)

In [None]:
import soundfile as sf

# Round-trip check: de-normalize one stored mp3 spectrogram, attach the phase of the
# same mp3 segment, and write the result back to audio.
# mp3_dataset[k] was built from mp3_files[k]; use a single index for both so the
# magnitude and the phase come from the same segment.
sample_idx = 0

mp3_tensor = mp3_dataset[sample_idx][0]

# Undo the [-1, 1] min-max normalization (back to dB)
mp3_tensor = (mp3_tensor + 1) * (mag_max - mag_min) / 2 + mag_min
print(mp3_tensor)

# dB -> linear amplitude (approximate if amplitude_to_db clipped low values by default)
log_magnitude_S2 = mp3_tensor.squeeze().numpy()
magnitude_S2 = librosa.db_to_amplitude(log_magnitude_S2)

# Phase from the complex STFT of the same mp3 segment
original_mp3_file_path = os.path.join(mp3_path, mp3_files[sample_idx])
y_original_mp3, _ = librosa.load(original_mp3_file_path, sr=sampling_rate)
original_S2 = librosa.stft(y_original_mp3, n_fft=n_fft, hop_length=hop_length)
mp3_phase = np.exp(1j * np.angle(original_S2))
# Crop to the stored spectrogram's shape in case preprocessing trimmed it (no-op otherwise)
mp3_phase = mp3_phase[:magnitude_S2.shape[0], :magnitude_S2.shape[1]]

# Complex-valued STFT -> time-domain signal
S2_restored = magnitude_S2 * mp3_phase
y_restored = librosa.istft(S2_restored, hop_length=hop_length)

# Save the restored signal
restored_path = 'restored_mp3.wav'
sf.write(restored_path, y_restored, sampling_rate)


In [None]:
import matplotlib.pyplot as plt
# Dump the first normalized spectrogram of each domain with its extrema, then plot both.
first_pair = {"WAV": wav_data[0, 0], "MP3": mp3_data[0, 0]}
for position, (label, spec) in enumerate(first_pair.items()):
    if position:
        print(" ")
    print(label)
    print(spec)
    print(torch.max(spec))
    print(torch.min(spec))

plt.imshow(first_pair["WAV"])
plt.show()
plt.imshow(first_pair["MP3"])


In [None]:
# Using librosa.amplitude_to_db()
import matplotlib.pyplot as plt
import librosa.display

# Choose the first .wav and .mp3 file and recompute their linear STFT magnitudes.
# The active preprocessing stores only normalized dB spectrograms and never defines
# `wav_0000` / `mp3_0000` globals, so reading them via globals() raises KeyError on a
# fresh kernel. Rebuild them from disk with the current sampling_rate / n_fft /
# hop_length; index 0 of the file lists is the file behind wav_data[0] / mp3_data[0].
def _linear_magnitude(folder, fname):
    """Return |STFT| (linear amplitude) of the audio file ``folder/fname``."""
    y, _ = librosa.load(os.path.join(folder, fname), sr=sampling_rate)
    return np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))


wav_stft = _linear_magnitude(wav_path, wav_files[0])
mp3_stft = _linear_magnitude(mp3_path, mp3_files[0])

# Convert amplitude to dB
wav_stft_db = librosa.amplitude_to_db(wav_stft)
mp3_stft_db = librosa.amplitude_to_db(mp3_stft)

print(wav_stft_db.shape)
print(mp3_stft_db.shape)

plt.figure(figsize=(14, 5))
plt.subplot(1, 2, 2)
librosa.display.specshow(wav_stft_db, sr=sampling_rate, hop_length=hop_length, x_axis='time', y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.title('Spectrogram (.wav)')
plt.subplot(1, 2, 1)
librosa.display.specshow(mp3_stft_db, sr=sampling_rate, hop_length=hop_length, x_axis='time', y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.title('Spectrogram (.mp3)')
plt.tight_layout()
plt.show()

print("mp3_data.shape : ",mp3_data.shape, mp3_stft_db)
print("wav_data.shape : ",wav_data.shape, wav_stft_db)


In [None]:
# Plot the stored spectrograms directly: they are already dB-scaled and min-max
# normalized, so no further librosa.amplitude_to_db() pass is applied here.
wav_stft = wav_data[0, 0].cpu().numpy()
mp3_stft = mp3_data[0, 0].cpu().numpy()

plt.figure(figsize=(14, 5))
# wav goes in the right panel, mp3 in the left; drawn in that order.
panel_specs = (
    (2, wav_stft, 'Spectrogram (.wav)'),
    (1, mp3_stft, 'Spectrogram (.mp3)'),
)
for panel, spec, heading in panel_specs:
    plt.subplot(1, 2, panel)
    librosa.display.specshow(spec, sr=sampling_rate, hop_length=hop_length, x_axis='time', y_axis='linear')
    plt.colorbar()
    plt.title(heading)
plt.tight_layout()
plt.show()

# Full first item (channel dimension included) and the overall tensor shape
print(wav_data.shape)
print(wav_data[0])

# Spot-check one pixel: batch 0, channel 0, then (row, col)
row, col = 10, 10
print(f"The value at row {row}, column {col} is: {wav_data[0][0][row, col]}")



In [None]:
import torch
from torch import nn
import torch.optim as optim

# Use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# cuDNN is disabled globally for this notebook (no reason is recorded in the source).
torch.backends.cudnn.enabled = False

class Generator(nn.Module):
    """Fully convolutional mp3 -> wav spectrogram mapper.

    Six 3x3 stride-1, padding-1 convolutions (1 -> 64 -> 128 -> 256 -> 512 -> 512 -> 512
    channels, each followed by LeakyReLU(0.2)), then a 3x3 transposed convolution back
    to one channel and a Tanh. Every layer preserves the spatial size, so the output has
    the same H x W as the input, with values in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        channel_plan = [1, 64, 128, 256, 512, 512, 512]
        layers = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Project back to a single channel; size-preserving (same H x W as the input).
        layers.append(nn.ConvTranspose2d(512, 1, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (B, 1, H, W) — or unbatched (1, H, W) — spectrogram to the same shape."""
        return self.model(x)


class Discriminator(nn.Module):
    """Convolutional critic mapping a (B, 1, H, W) spectrogram to one score per sample.

    Four 4x4 stride-2 convolutions halve H and W each time, so for a square input of side
    ``input_size`` the last feature map is (64, input_size // 16, input_size // 16). That
    is 16384 features for the default 256x256 input.

    Args:
        input_size: side length of the square input (default 256).
        output_sigmoid: append a Sigmoid to the score. Off by default: the critic is
            trained with a Wasserstein (WGAN-GP) objective, which needs an unbounded
            score; a Sigmoid saturates it to (0, 1). Enable only for a BCE-style GAN loss.
    """

    def __init__(self, input_size: int = 256, output_sigmoid: bool = False):
        super().__init__()
        feat = input_size // 16
        layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            # NOTE(review): the WGAN-GP paper advises against batch norm in the critic
            # (the penalty is per-sample); kept so layer indices / checkpoints match.
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 64, kernel_size=4, stride=2, padding=1),
            nn.Flatten(),
            nn.Linear(64 * feat * feat, 1),
        ]
        if output_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (B, 1) tensor of critic scores."""
        return self.model(x)

# Wasserstein GAN with Gradient Penalty (WGAN-GP) Loss
class WGAN_GP_Loss(nn.Module):
    """Generator and critic losses for WGAN-GP.

    Args:
        lambda_gp: weight of the gradient-penalty term (default 10).
    """

    def __init__(self, lambda_gp=10):
        super(WGAN_GP_Loss, self).__init__()
        self.lambda_gp = lambda_gp

    def forward(self, real_scores, fake_scores, real_images, generated_images, critic=None):
        """Return ``(generator_loss, discriminator_loss)``.

        Args:
            real_scores / fake_scores: critic outputs for the real and generated batches.
            real_images / generated_images: the corresponding batches (same shape,
                batch first).
            critic: network evaluated on the interpolates for the gradient penalty.
                Defaults to the module-level ``discriminator`` for backward compatibility.
        """
        if critic is None:
            critic = discriminator

        # Wasserstein losses
        generator_loss = -torch.mean(fake_scores)
        discriminator_loss = torch.mean(fake_scores) - torch.mean(real_scores)

        # Gradient penalty on random interpolates between real and generated samples:
        # one alpha per sample, broadcast over the remaining dims, on the inputs' device.
        # Inputs are detached so the penalty only trains the critic.
        alpha_shape = (real_images.size(0),) + (1,) * (real_images.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_images.device)
        interpolates = (alpha * real_images.detach()
                        + (1 - alpha) * generated_images.detach()).requires_grad_(True)
        disc_interpolates = critic(interpolates)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones_like(disc_interpolates),
                                        create_graph=True, retain_graph=True)[0]
        # Norm over each whole sample, not per element
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.lambda_gp

        discriminator_loss = discriminator_loss + gradient_penalty

        return generator_loss, discriminator_loss

# Build both networks (Generator first, then Discriminator) and place them on `device`.
generator, discriminator = (net.to(device) for net in (Generator(), Discriminator()))

# WGAN-GP loss helper
wgan_gp_loss = WGAN_GP_Loss()



In [None]:
# Hyperparameters
batch_size = 1
lr_generator = 1e-5       # Adam learning rate for the generator
lr_discriminator = 5e-6   # RMSprop learning rate for the critic
num_epochs = 5
gradient_penalty_constant = 1  # NOTE(review): the WGAN-GP paper uses 10
discriminator_updates_per_generator_update = 1

# Data loaders. shuffle=False on both: the loop pairs mp3/wav by position via zip.
assert len(mp3_dataset) == len(wav_dataset), "mp3/wav datasets must be paired 1:1"
mp3_loader = torch.utils.data.DataLoader(mp3_dataset, batch_size=batch_size, shuffle=False)
wav_loader = torch.utils.data.DataLoader(wav_dataset, batch_size=batch_size, shuffle=False)

# Loss helper (the loop below computes the same WGAN-GP terms inline)
wgan_gp_loss = WGAN_GP_Loss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr_generator)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr_discriminator)


def _gradient_penalty(critic, real, fake):
    """Mean over samples of (||d critic(x) / dx||_2 - 1)^2, x interpolated between real and fake."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    grads = torch.autograd.grad(outputs=scores, inputs=interpolates,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True, retain_graph=True)[0]
    # Flatten so the norm covers each whole sample; norm(dim=1) on (B, 1, H, W) would
    # only span the single channel, i.e. penalize every pixel separately.
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()


d_loss = g_loss = None
for epoch in range(num_epochs):
    for mp3_batch, wav_batch in zip(mp3_loader, wav_loader):
        # Move the data to the chosen device
        mp3 = mp3_batch[0].to(device)
        wav = wav_batch[0].to(device)

        # Critic updates. Fakes are generated without autograd so this step neither
        # builds a generator graph nor accumulates generator gradients.
        for _ in range(discriminator_updates_per_generator_update):
            with torch.no_grad():
                fake_images = generator(mp3)
            d_loss_real = -torch.mean(discriminator(wav))
            d_loss_fake = torch.mean(discriminator(fake_images))
            gradient_penalty = _gradient_penalty(discriminator, wav, fake_images)
            d_loss = d_loss_real + d_loss_fake + gradient_penalty_constant * gradient_penalty

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

        # Generator update
        fake_images = generator(mp3)
        g_loss = -torch.mean(discriminator(fake_images))

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    if d_loss is not None:
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}')


In [None]:
import matplotlib.pyplot as plt

# Side-by-side check of the first pair: mp3 input, generator output, wav target.
mp3_data_np = mp3_data[0, 0].cpu().numpy()

# Feed a proper (1, 1, F, T) batch and skip autograd bookkeeping for inference
with torch.no_grad():
    generated_wav = generator(mp3_data[0:1].to(device))
generated_wav_np = generated_wav.cpu().squeeze().numpy()

wav_data_np = wav_data[0, 0].cpu().numpy()

fig, axs = plt.subplots(3, 1, figsize=(10, 15))
panels = (
    (mp3_data_np, 'MP3 Spectrogram'),
    (generated_wav_np, 'Generated Spectrogram'),
    (wav_data_np, 'Wav Spectrogram'),
)
for ax, (spec, title) in zip(axs, panels):
    # origin='lower' puts frequency bin 0 at the bottom; aspect='auto' fills the panel
    im = ax.imshow(spec, origin='lower', aspect='auto')
    ax.set_title(title)
    ax.set_xlabel('Time')
    ax.set_ylabel('Frequency')
    # Values are min-max normalized dB in [-1, 1], not raw dB
    fig.colorbar(im, ax=ax, label='Normalized log-magnitude [-1, 1]')

plt.tight_layout()
plt.show()


In [None]:
# Extrema and raw values of the input, target and generated spectrograms (normalized scale)
summaries = [("Mp3", mp3_data_np), ("Wav", wav_data_np), ("Generated", generated_wav_np)]
for position, (label, arr) in enumerate(summaries):
    if position > 0:
        print(" ")
    print(label, "Max : ", np.max(arr), "Min : ", np.min(arr))
    print(arr)

In [None]:
import librosa
import numpy as np
import soundfile as sf

def _to_linear_magnitude(normalized_db):
    """Map a [-1, 1] min-max normalized dB spectrogram back to linear amplitude."""
    db = ((normalized_db + 1) / 2) * (mag_max - mag_min) + mag_min
    return librosa.db_to_amplitude(db)


# Phase must come from the complex STFT of the actual mp3 audio. mp3_data_np is a
# real-valued normalized log-magnitude, so np.angle() of it is only ever 0 or pi.
# mp3_data[0] (the source of mp3_data_np) was built from mp3_files[0].
y_mp3, _ = librosa.load(os.path.join(mp3_path, mp3_files[0]), sr=sampling_rate)
mp3_phase = np.exp(1j * np.angle(librosa.stft(y_mp3, n_fft=n_fft, hop_length=hop_length)))
# Crop to the spectrogram shape in case preprocessing trimmed it (no-op otherwise)
mp3_phase = mp3_phase[:generated_wav_np.shape[0], :generated_wav_np.shape[1]]

# The generator output lives in the same normalized-dB space as the data, so
# de-normalize it directly. Taking abs() first would fold quiet (negative) bins onto
# loud ones.
generated_magnitude = _to_linear_magnitude(generated_wav_np)
mp3_magnitude = _to_linear_magnitude(mp3_data_np)

# Complex spectrograms
reconstructed_spectrogram = generated_magnitude * mp3_phase
original_spectrogram = mp3_magnitude * mp3_phase

# Back to the time domain via ISTFT
reconstructed_signal = librosa.istft(reconstructed_spectrogram, hop_length=hop_length)
original_signal = librosa.istft(original_spectrogram, hop_length=hop_length)

# Save to files
sf.write('reconstructed_230812.wav', reconstructed_signal, sampling_rate)
sf.write('original_230812.wav', original_signal, sampling_rate)


In [None]:
# Show the unit-magnitude phase matrix used for reconstruction, then its shape
for item in (mp3_phase, mp3_phase.shape):
    print(item)