# Speech Enhancement

## 1. 데이터 로드

In [17]:
import os
import librosa
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# 데이터셋 경로
mix_path = "/data/cjchun/ejkim/SpeechEnhancement(25.1)/mix"
clean_path = "/data/cjchun/ejkim/SpeechEnhancement(25.1)/clean"

# 데이터셋 분석 함수
def analyze_dataset(dataset_path):
    """Summarize duration and sample-value statistics of the .wav files in a directory.

    Args:
        dataset_path: Directory to scan (non-recursive). Only entries whose
            path ends with ".wav" are loaded (at their native sample rate).

    Returns:
        dict with ``num_files``; ``min_length`` / ``max_length`` /
        ``avg_length`` in seconds; and ``min_amplitude`` / ``max_amplitude`` /
        ``avg_amplitude`` computed over every sample of every file.

    Raises:
        ValueError: if the directory contains no .wav files, or the .wav
            files contain no samples at all.
    """
    file_lengths = []
    # Running amplitude aggregates: avoids keeping every waveform in memory
    # and concatenating them all at the end.
    min_amplitude = np.inf
    max_amplitude = -np.inf
    amplitude_sum = 0.0
    sample_count = 0

    print(f"Analyzing dataset: {dataset_path}")
    for file_name in tqdm(sorted(os.listdir(dataset_path))):
        file_path = os.path.join(dataset_path, file_name)
        if not file_path.endswith(".wav"):
            continue

        audio, sr = librosa.load(file_path, sr=None)
        file_lengths.append(len(audio) / sr)  # Length in seconds
        if audio.size == 0:
            continue  # an empty file has a length of 0 s but no samples to aggregate

        min_amplitude = min(min_amplitude, float(audio.min()))
        max_amplitude = max(max_amplitude, float(audio.max()))
        # float64 accumulator keeps the mean accurate over many float32 samples.
        amplitude_sum += float(np.sum(audio, dtype=np.float64))
        sample_count += audio.size

    if not file_lengths:
        raise ValueError(f"No .wav files found in {dataset_path}")
    if sample_count == 0:
        raise ValueError(f"All .wav files in {dataset_path} are empty")

    return {
        "num_files": len(file_lengths),
        "min_length": min(file_lengths),
        "max_length": max(file_lengths),
        "avg_length": float(np.mean(file_lengths)),
        "min_amplitude": min_amplitude,
        "max_amplitude": max_amplitude,
        "avg_amplitude": amplitude_sum / sample_count,
    }

In [18]:
# Analyze both halves of the paired dataset (mix first, then clean).
mix_stats, clean_stats = (analyze_dataset(path) for path in (mix_path, clean_path))

# 분석 결과 출력
def print_stats(stats, name):
    """Print a readable summary of a statistics dict produced by ``analyze_dataset``.

    Args:
        stats: dict with the keys ``num_files``, ``min/max/avg_length`` and
            ``min/max/avg_amplitude``.
        name: label shown in the heading line.
    """
    print(f"\n{name} Dataset Statistics:")
    print(f"- Number of files: {stats['num_files']}")
    length_summary = " / ".join(
        f"{stats[key]:.2f}s" for key in ("min_length", "max_length", "avg_length")
    )
    print(f"- Length (min/max/avg): {length_summary}")
    amplitude_summary = " / ".join(
        f"{stats[key]:.4f}" for key in ("min_amplitude", "max_amplitude", "avg_amplitude")
    )
    print(f"- Amplitude (min/max/avg): {amplitude_summary}")

# Print the summaries: mix first, then clean.
for _stats, _label in ((mix_stats, "Mix"), (clean_stats, "Clean")):
    print_stats(_stats, _label)


## 데이터셋 인덱스 매칭

In [None]:
import os

# 데이터셋 경로
# Dataset directories (same locations as mix_path / clean_path in the analysis cell)
mix_dataset_path = "/data/cjchun/ejkim/SpeechEnhancement(25.1)/mix"
clean_dataset_path = "/data/cjchun/ejkim/SpeechEnhancement(25.1)/clean"

# Keep only .wav entries (the same filter analyze_dataset uses) so stray items
# such as notebook checkpoint folders or hidden files are never paired and
# later handed to torchaudio.load.
mix_files = sorted(name for name in os.listdir(mix_dataset_path) if name.endswith(".wav"))
clean_files = sorted(name for name in os.listdir(clean_dataset_path) if name.endswith(".wav"))

# Pair files that share the same file name. Membership is tested against a
# set (O(1) per lookup) instead of scanning the mix list for every clean file.
mix_file_set = set(mix_files)
paired_names = [name for name in clean_files if name in mix_file_set]

matched_mix_files = [os.path.join(mix_dataset_path, name) for name in paired_names]
matched_clean_files = [os.path.join(clean_dataset_path, name) for name in paired_names]

# Report the matching result
print(f"매칭된 mix 파일 수: {len(matched_mix_files)}")
print(f"매칭된 clean 파일 수: {len(matched_clean_files)}")

# Both lists are built from the same names, so their lengths always agree;
# the informative numbers are the files that found no partner.
unpaired_clean = len(clean_files) - len(paired_names)
unpaired_mix = len(mix_files) - len(paired_names)
if unpaired_clean or unpaired_mix:
    print(f"Warning: {unpaired_clean} clean and {unpaired_mix} mix .wav file(s) have no counterpart and were skipped.")
if not paired_names:
    raise RuntimeError("No mix/clean file pairs were found.")

assert len(matched_mix_files) == len(matched_clean_files), "매칭된 mix와 clean 파일 수가 다릅니다!"
print("매칭된 mix와 clean 파일 수가 동일합니다.")


## STFT

In [None]:
import torchaudio
import numpy as np
import matplotlib.pyplot as plt

# STFT 파라미터
import torch

# Compute device. It is defined here because this is the first cell that needs
# it; the training cell further down redefines it the same way.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# STFT parameters
n_fft = 1024  # FFT size
hop_length = 256  # hop between consecutive frames, in samples
win_length = 1024  # analysis window length, in samples
window = torch.hann_window(win_length).to(device)

# Per-file magnitude and phase spectrograms, filled by the STFT loop
mix_magnitude_list = []
mix_phase_list = []
clean_magnitude_list = []
clean_phase_list = []

# STFT 수행 함수
def compute_stft(file_path, n_fft, hop_length, win_length, window):
    """Load an audio file and return the magnitude and phase of its STFT.

    Args:
        file_path: audio file readable by ``torchaudio.load``.
        n_fft, hop_length, win_length: ``torch.stft`` framing parameters.
        window: window tensor of length ``win_length``. The waveform is moved
            to this tensor's device, so the function does not depend on a
            global ``device`` variable.

    Returns:
        (magnitude, phase): real tensors from the complex STFT — presumably
        (channels, n_fft // 2 + 1, frames), since torchaudio.load yields
        (channels, samples). TODO confirm for multi-channel files.
    """
    waveform, _sample_rate = torchaudio.load(file_path)
    # Follow the window's device instead of a global that may not be defined yet.
    waveform = waveform.to(window.device)

    stft_output = torch.stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    return torch.abs(stft_output), torch.angle(stft_output)

In [None]:
# Run the STFT on every matched (mix, clean) pair.
# Reset the result lists so re-running this cell does not append duplicates.
mix_magnitude_list, mix_phase_list = [], []
clean_magnitude_list, clean_phase_list = [], []

for mix_file, clean_file in zip(matched_mix_files, matched_clean_files):
    mix_magnitude, mix_phase = compute_stft(mix_file, n_fft, hop_length, win_length, window)
    clean_magnitude, clean_phase = compute_stft(clean_file, n_fft, hop_length, win_length, window)

    # Store results on the CPU: keeping every file's spectrograms on the GPU
    # makes device memory grow with the dataset size.
    mix_magnitude_list.append(mix_magnitude.cpu())
    mix_phase_list.append(mix_phase.cpu())
    clean_magnitude_list.append(clean_magnitude.cpu())
    clean_phase_list.append(clean_phase.cpu())

# Magnitude와 Phase 시각화 (첫 번째 샘플)
def plot_spectrogram(magnitude, title):
    """Display a linear STFT magnitude tensor as a dB image.

    Only index 0 of the leading axis is drawn (presumably the first audio
    channel). A 1e-6 offset keeps log10 finite for zero bins.
    """
    image_db = (20 * torch.log10(magnitude + 1e-6).cpu().numpy())[0]

    plt.figure(figsize=(10, 4))
    plt.imshow(image_db, origin='lower', aspect='auto', cmap='viridis')
    plt.colorbar(format='%+2.0f dB')
    for set_text, text in ((plt.title, title), (plt.xlabel, "Time"), (plt.ylabel, "Frequency")):
        set_text(text)
    plt.show()

# mix 데이터의 Magnitude 시각화
# Show the first sample's magnitude spectrogram for the mix, then the clean recording.
for _spectra, _label in (
    (mix_magnitude_list, "Mix Magnitude Spectrogram"),
    (clean_magnitude_list, "Clean Magnitude Spectrogram"),
):
    plot_spectrogram(_spectra[0], _label)

## Mel Spectrogram

In [None]:
import torchaudio.transforms as T
import numpy as np
import matplotlib.pyplot as plt

# Mel 변환 파라미터
import torch

# Compute device (same definition as in the training cell further down);
# needed here because the transforms below are moved to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mel transform parameters
n_fft = 1024
hop_length = 256
win_length = 1024
n_mels = 128

# Mel spectrogram at a fixed 16 kHz sample rate: inputs must be resampled to
# 16 kHz before they reach this transform.
mel_transform = T.MelSpectrogram(
    sample_rate=16000,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    n_mels=n_mels
).to(device)

# Power -> dB conversion
db_transform = T.AmplitudeToDB(stype="power").to(device)

# Mel 변환 및 dB 변환 수행
import torch
import torch.nn.functional as F
import torchaudio

# Sample rate the Mel transform was built for (sample_rate=16000 above).
_MEL_SAMPLE_RATE = 16000


def _load_mel_db(path):
    """Load an audio file and return its dB Mel spectrogram on the CPU.

    The waveform is averaged over channels (mono) and resampled to 16 kHz when
    its native rate differs, so the result is 2-D: (n_mels, frames).
    """
    waveform, sr = torchaudio.load(path)
    waveform = waveform.mean(dim=0)  # (channels, samples) -> (samples,)
    if sr != _MEL_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, _MEL_SAMPLE_RATE)
    mel = mel_transform(waveform.to(device))
    return db_transform(mel).cpu()


# Convert every matched (mix, clean) file pair into dB Mel spectrograms.
# The matching cell provides file paths in matched_mix_files / matched_clean_files.
mix_mel_list = []
clean_mel_list = []

print("Converting matched datasets to Mel-spectrograms...")
for mix_file, clean_file in zip(matched_mix_files, matched_clean_files):
    mix_mel_list.append(_load_mel_db(mix_file))
    clean_mel_list.append(_load_mel_db(clean_file))

print("Mel-spectrogram 변환 완료.")

if not mix_mel_list:
    raise RuntimeError("No matched files to convert to Mel-spectrograms.")

# Files have different durations, so pad every spectrogram along time to a
# common frame count before stacking. Padding uses the smallest dB value
# present (silence). The frame count is rounded up to an even number so the
# UNet's 2x max-pool / 2x upsample returns the same size.
all_mels = mix_mel_list + clean_mel_list
max_frames = max(mel.shape[-1] for mel in all_mels)
max_frames += max_frames % 2
pad_db = min(float(mel.min()) for mel in all_mels)

mix_mel_list = torch.stack(
    [F.pad(mel, (0, max_frames - mel.shape[-1]), value=pad_db) for mel in mix_mel_list]
)  # (N, n_mels, max_frames)
clean_mel_list = torch.stack(
    [F.pad(mel, (0, max_frames - mel.shape[-1]), value=pad_db) for mel in clean_mel_list]
)  # (N, n_mels, max_frames)

# 샘플 데이터 시각화 (격자 형태)
def plot_mel_spectrograms(mel_list, title, num_samples=5):
    """Show the first ``num_samples`` Mel spectrograms side by side.

    Args:
        mel_list: indexable collection of tensors; each item is squeezed to
            2-D for display.
        title: prefix of each panel title; panels are numbered from 1.
        num_samples: number of panels; clipped to ``len(mel_list)``. Nothing
            is drawn when it is 0 or the list is empty.
    """
    num_samples = min(num_samples, len(mel_list))
    if num_samples <= 0:
        return
    # squeeze=False keeps `axes` 2-D even for a single panel, so indexing works.
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        mel = mel_list[i].squeeze().cpu().numpy()
        ax.imshow(mel, origin="lower", aspect="auto", cmap="viridis")
        ax.set_title(f"{title} {i+1}")
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Mix 및 Clean Mel-spectrogram 시각화
# Five-panel previews: mix first, then clean.
for _mels, _label in ((mix_mel_list, "Mix Mel"), (clean_mel_list, "Clean Mel")):
    plot_mel_spectrograms(_mels, _label, num_samples=5)


## UNet 모델

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# UNet 모델 정의
class UNet(nn.Module):
    """Small convolutional encoder/decoder for spectrogram-to-spectrogram regression.

    Despite the name there are no skip connections: the encoder has one 2x
    max-pool and the decoder one 2x bilinear upsample.

    Input: (batch, 1, H, W). Output: the same shape; when H or W is odd the
    decoder output is resized back to the input size.
    """

    def __init__(self):
        super(UNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),
            # Linear output: the regression targets are dB values, which can be
            # negative, and a final ReLU would clamp every prediction to >= 0.
            # nn.Identity keeps the Sequential indices (and state_dict keys) unchanged.
            nn.Identity(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        if decoded.shape[-2:] != x.shape[-2:]:
            # MaxPool2d floors odd sizes, so 2x upsampling comes back one short;
            # resize so the output matches the input (and the loss target).
            decoded = nn.functional.interpolate(
                decoded, size=x.shape[-2:], mode="bilinear", align_corners=True
            )
        return decoded


# 하이퍼파라미터 설정
import torch

# Training hyper-parameters
batch_size = 16
learning_rate = 1e-3
num_epochs = 10

# Insert a channel axis at dim 1 so the data feeds the Conv2d(1, ...) input layer.
# NOTE(review): assumes mix_mel_list / clean_mel_list are stacked
# (N, n_mels, frames) tensors, giving (N, 1, n_mels, frames) here — confirm.
mix_mel_tensor = mix_mel_list.unsqueeze(1)
clean_mel_tensor = clean_mel_list.unsqueeze(1)

dataset = TensorDataset(mix_mel_tensor, clean_mel_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model, loss and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 모델 가중치 초기화
def init_weights(m):
    """Kaiming-uniform init (ReLU gain) for Conv2d weights and zero their biases.

    Intended for ``model.apply``; modules other than ``nn.Conv2d`` are left untouched.
    """
    if not isinstance(m, nn.Conv2d):
        return
    nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
    if m.bias is not None:
        nn.init.zeros_(m.bias)

model.apply(init_weights)

# Training loop
train_losses = []
print("Training UNet model...")
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for mix_batch, clean_batch in dataloader:
        mix_batch = mix_batch.to(device)
        clean_batch = clean_batch.to(device)

        # Forward pass and loss
        output = model(mix_batch)
        loss = criterion(output, clean_batch)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss returns a per-batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        epoch_loss += loss.item() * mix_batch.size(0)

    # Per-sample mean loss over the whole epoch
    avg_epoch_loss = epoch_loss / len(dataloader.dataset)
    train_losses.append(avg_epoch_loss)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_epoch_loss:.4f}")

# Training-loss curve
plt.plot(train_losses, label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("UNet Training Loss")
plt.legend()
plt.show()

print("UNet 학습 완료.")


## IRM Mask & Wiener Filter

In [None]:
import numpy as np
import torch.nn.functional as F

# IRM 마스크 생성 함수
def compute_irm(clean_magnitude, mix_magnitude, eps=1e-8):
    """Return the ratio mask ``clean_magnitude / (mix_magnitude + eps)``.

    Works element-wise on numpy arrays, tensors or scalars; ``eps`` guards
    against division by zero. The mask is not clipped.
    """
    denominator = mix_magnitude + eps
    return clean_magnitude / denominator

# Wiener 필터 적용 함수
def apply_wiener_filter(mix_stft, enhanced_stft, eps=1e-8):
    """Scale ``mix_stft`` by the power ratio |enhanced|^2 / (|mix|^2 + eps).

    Accepts real or complex numpy arrays; the gain is not clipped to [0, 1].
    """
    gain = np.abs(enhanced_stft) ** 2 / (np.abs(mix_stft) ** 2 + eps)
    return gain * mix_stft

# UNet 출력으로 IRM 마스크 적용
import torch

print("Applying IRM and Wiener filter...")

# Enhanced estimate from the trained UNet for every Mel spectrogram, in
# chunks of 16 so the whole dataset never sits on the GPU at once. The results
# are collected on the CPU.
model.eval()
with torch.no_grad():
    enhanced_chunks = [
        model(chunk.to(device)).squeeze(1).cpu()
        for chunk in mix_mel_list.unsqueeze(1).split(16)
    ]
enhanced_mel = torch.cat(enhanced_chunks)  # UNet output, same dB scale as its targets

# The Mel lists hold dB values of a power Mel spectrogram
# (AmplitudeToDB(stype="power") in the Mel cell), so linear magnitude is
# 10 ** (dB / 20). A ratio of the dB values themselves is not a mask.
mix_magnitude = 10.0 ** (mix_mel_list.cpu().numpy() / 20.0)
clean_magnitude = 10.0 ** (clean_mel_list.cpu().numpy() / 20.0)
irm_mask = compute_irm(clean_magnitude, mix_magnitude)

In [None]:
# IRM 마스크 적용
import torch

# Everything here stays in the Mel domain. The notebook never builds a complex
# STFT aligned with these Mel spectrograms, so masks and the Wiener gain are
# applied to Mel magnitudes. The Mel lists and the UNet output are dB values of
# a power Mel spectrogram, hence |X| = 10 ** (dB / 20).


def _db_to_magnitude(db):
    """Convert dB (10*log10 power) values to linear magnitude as a float32 numpy array."""
    db_tensor = torch.as_tensor(db).detach().cpu().float()
    return (10.0 ** (db_tensor / 20.0)).numpy()


def _magnitude_to_db(magnitude):
    """Convert linear magnitude back to dB (10*log10 power) with a 1e-10 power floor (-100 dB)."""
    return (10.0 * np.log10(np.maximum(np.square(magnitude), 1e-10))).astype(np.float32)


mix_mag = _db_to_magnitude(mix_mel_list)
clean_mag = _db_to_magnitude(clean_mel_list)
enhanced_mag = _db_to_magnitude(enhanced_mel)  # UNet speech estimate

# Oracle IRM applied to the mixture (reference upper bound).
enhanced_with_irm = compute_irm(clean_mag, mix_mag) * mix_mag

# Wiener filter driven by the UNet speech estimate.
enhanced_wiener = apply_wiener_filter(mix_mag, enhanced_mag)

# Back to dB so the result is on the same scale as clean_mel_list
# (it is the regression input in the next section). Shape: like mix_mel_list.
enhanced_wiener_mel = _magnitude_to_db(enhanced_wiener)

# Visualization of the first sample; all three panels in dB.
panels = (
    ("Original Mix Mel", mix_mel_list[0].cpu().numpy()),
    ("IRM Applied Mel", _magnitude_to_db(enhanced_with_irm[0])),
    ("Wiener Filter Applied Mel", enhanced_wiener_mel[0]),
)
plt.figure(figsize=(15, 5))
for position, (panel_title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(image, aspect="auto", origin="lower", cmap="viridis")
    plt.colorbar()

plt.tight_layout()
plt.show()

print("IRM과 Wiener 필터 적용 완료.")


## 회귀 학습

In [None]:
import torch

# Sample for the single-example regression: index 11, i.e. the 12th sample in
# 0-based indexing — the same sample_idx the result cells use.
sample_index = 11

# Input: Wiener-filtered Mel of that sample; target: its clean Mel. Both get
# batch and channel axes -> (1, 1, H, W). enhanced_wiener_mel is a numpy array
# whose dtype depends on the previous cell, so force float32 to match the
# model weights (a float64 input would make Conv2d fail).
mix_sample = (
    torch.as_tensor(enhanced_wiener_mel[sample_index], dtype=torch.float32)
    .unsqueeze(0)
    .unsqueeze(0)
    .to(device)
)
clean_sample = clean_mel_list[sample_index].unsqueeze(0).unsqueeze(0).to(device)

# Fresh UNet for this experiment (rebinds `model`, replacing the one trained above).
model = UNet().to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training-loop settings
num_epochs = 50
losses = []

print("Starting regression training...")

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # Forward pass, loss and update
    output = model(mix_sample)
    loss = criterion(output, clean_sample)
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    # Report every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.6f}")

# Loss curve
plt.plot(losses, label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Regression Training Loss")
plt.legend()
plt.show()

In [None]:
# Evaluate the trained regressor on the sample it was fitted to.
model.eval()
with torch.no_grad():
    predicted_clean = model(mix_sample).squeeze(0).squeeze(0).cpu().numpy()

# Side by side: clean target, Wiener-filtered input, prediction.
panels = (
    ("Original Clean Mel", clean_sample.squeeze(0).squeeze(0).cpu().numpy()),
    ("Input Noisy Mel (Wiener Filter)", mix_sample.squeeze(0).squeeze(0).cpu().numpy()),
    ("Predicted Clean Mel", predicted_clean),
)
plt.figure(figsize=(15, 5))
for position, (panel_title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(image, aspect="auto", origin="lower", cmap="viridis")
    plt.colorbar()

plt.tight_layout()
plt.show()

print("Regression training completed.")

## MSE 확인

In [None]:
# Continue training the same model/optimizer for another num_epochs steps,
# reporting every 10th epoch and always the final one.
num_epochs = 50
losses = []

print("Starting regression training with detailed loss outputs...")

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    output = model(mix_sample)
    loss = criterion(output, clean_sample)
    loss.backward()
    optimizer.step()

    current_loss = loss.item()
    losses.append(current_loss)

    is_report_epoch = (epoch + 1) % 10 == 0 or epoch == num_epochs - 1
    if is_report_epoch:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {current_loss:.6f}")

# Loss curve over all epochs of this run
plt.plot(losses, label="Training Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Regression Training Loss (Detailed Output)")
plt.legend()
plt.show()

# Prediction of the further-trained model
model.eval()
with torch.no_grad():
    predicted_clean = model(mix_sample).squeeze(0).squeeze(0).cpu().numpy()

# Side by side: clean target, Wiener-filtered input, prediction.
panels = (
    ("Original Clean Mel", clean_sample.squeeze(0).squeeze(0).cpu().numpy()),
    ("Input Noisy Mel (Wiener Filter)", mix_sample.squeeze(0).squeeze(0).cpu().numpy()),
    ("Predicted Clean Mel", predicted_clean),
)
plt.figure(figsize=(15, 5))
for position, (panel_title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(image, aspect="auto", origin="lower", cmap="viridis")
    plt.colorbar()

plt.tight_layout()
plt.show()

print("Loss output and spectrogram visualization completed.")


## HiFi-GAN 복원

In [None]:
from speechbrain.inference.vocoders import HIFIGAN

# HiFi-GAN 모델 경로 설정
import math

# Local directory holding the HiFi-GAN hyperparameters/checkpoint (also used as savedir).
local_model_path = '/data/hjhan/lab/Audio MNIST/hifigan'

# Load the HiFi-GAN vocoder
hifi_gan = HIFIGAN.from_hparams(
    source=local_model_path,
    savedir=local_model_path
)

# Vocode the UNet-enhanced Mel batch from the IRM cell (`enhanced_mel`, one
# spectrogram per matched file, in matched-file order), so index i here
# presumably lines up with matched_clean_files[i].
# enhanced_mel is in dB (10*log10 power); ln|X| = dB * ln(10) / 20 converts it
# to natural-log amplitude.
# NOTE(review): SpeechBrain HiFi-GAN recipes appear to be trained on
# natural-log amplitude Mel — confirm this local model's Mel format, n_mels
# (128 here), hop length (256 here) and sample rate (16 kHz) match; a mismatch
# produces garbage audio.
mel_spectrogram = enhanced_mel.detach().cpu().float() * (math.log(10.0) / 20.0)
reconstructed_waveform = hifi_gan.decode_batch(mel_spectrogram)


## 결과

In [None]:
import torch
import librosa.display
import matplotlib.pyplot as plt
from IPython.display import Audio

# 12번째 샘플 선택
# The 12th item of the vocoded batch (index 11, 0-based indexing).
sample_idx = 11

# Flatten that item to a numpy array for playback.
reconstructed_waveform_sample = (
    reconstructed_waveform[sample_idx]
    .squeeze()
    .cpu()
    .detach()
    .numpy()
)

# Play back the reconstructed waveform
Audio(reconstructed_waveform_sample, rate=16000)


In [None]:
from pystoi import stoi

# 12번째 샘플의 clean waveform을 불러옴
import torchaudio

# Reference: load the matched clean file for the same sample index.
# Downmix to mono and resample to 16 kHz, the rate used for the STOI call below.
clean_waveform, clean_sr = torchaudio.load(matched_clean_files[sample_idx])
clean_waveform = clean_waveform.mean(dim=0)
if clean_sr != 16000:
    clean_waveform = torchaudio.functional.resample(clean_waveform, clean_sr, 16000)
clean_waveform_sample = clean_waveform.squeeze().cpu().detach().numpy()

# STOI needs both signals to have the same length; the vocoded output can
# differ (e.g. from Mel padding or hop rounding), so trim both to the shorter one.
common_length = min(len(clean_waveform_sample), len(reconstructed_waveform_sample))
clean_waveform_sample = clean_waveform_sample[:common_length]
reconstructed_waveform_sample = reconstructed_waveform_sample[:common_length]

# STOI (Short-Time Objective Intelligibility)
stoi_value = stoi(clean_waveform_sample, reconstructed_waveform_sample, 16000)
print(f"STOI (Short-Time Objective Intelligibility) for 12th sample: {stoi_value}")


In [None]:
import librosa
import numpy as np
from skimage.metrics import structural_similarity as ssim

# Mel-spectrogram 계산 (시각적인 품질 비교를 위한 준비)
def compute_mel_spectrogram(waveform, sr=16000, n_mels=128, hop_length=512, win_length=1024):
    """Return ``librosa.feature.melspectrogram`` of ``waveform`` with the given settings.

    Used only for the visual/SSIM comparison between reference and reconstruction.
    """
    return librosa.feature.melspectrogram(
        y=waveform,
        sr=sr,
        n_mels=n_mels,
        hop_length=hop_length,
        win_length=win_length,
    )

# clean 및 reconstructed waveform을 Mel-spectrogram으로 변환
# Mel spectrograms of the reference and the reconstruction (same settings for both).
clean_mel = compute_mel_spectrogram(clean_waveform_sample)
reconstructed_mel = compute_mel_spectrogram(reconstructed_waveform_sample)

# SSIM requires images of identical shape: crop both to the shorter frame count.
n_frames = min(clean_mel.shape[-1], reconstructed_mel.shape[-1])

# Compare on a dB scale; raw Mel values span a large dynamic range and would
# be dominated by a few loud bins.
clean_mel_db = librosa.power_to_db(clean_mel[:, :n_frames])
reconstructed_mel_db = librosa.power_to_db(reconstructed_mel[:, :n_frames])

# skimage requires an explicit data_range for floating-point images.
data_range = float(
    max(clean_mel_db.max(), reconstructed_mel_db.max())
    - min(clean_mel_db.min(), reconstructed_mel_db.min())
)
ssim_value = ssim(clean_mel_db, reconstructed_mel_db, data_range=data_range)
print(f"SSIM (Structural Similarity Index) for 12th sample: {ssim_value}")


In [None]:
# Re-synthesize the 12th sample with HiFi-GAN from the Mel of its reconstruction.
# decode_batch takes a torch tensor with a batch axis, so convert the numpy
# Mel to float32 and add one: (1, n_mels, frames).
# NOTE(review): reconstructed_mel is librosa's Mel (128 bins, hop 512) of the
# already-vocoded waveform; the vocoder presumably expects its own Mel format —
# confirm before trusting this audio.
reconstructed_mel_batch = torch.as_tensor(reconstructed_mel, dtype=torch.float32).unsqueeze(0)
reconstructed_waveform_12th_sample = hifi_gan.decode_batch(reconstructed_mel_batch)

# Listen to the re-synthesized waveform
Audio(reconstructed_waveform_12th_sample.squeeze().cpu().detach().numpy(), rate=16000)
