<a href="https://colab.research.google.com/github/Kokindos/audio_denoising/blob/main/audo_denoising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Listen to some samples

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Read data

In [2]:
import glob

clean_wavs_path = glob.glob('/content/drive/MyDrive/speach_dataset/clean/*.wav')
noisy_wavs_path = glob.glob('/content/drive/MyDrive/speach_dataset/noisy/*.wav')

print(len(clean_wavs_path))
print(len(noisy_wavs_path))

5689
5689


In [3]:
class config:
    target_sample_rate=48000
    duration=4
    n_fft=1024
    hop_length=512
    n_mels=64
    batch_size=64
    learning_rate=1e-6
    epochs=6

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchaudio

class CustomDataset(Dataset):
    def __init__(self, clean_data_path, noisy_data_path, transform=None,
                 target_sample_rate=config.target_sample_rate, duration=config.duration):
        self.root_clean = clean_data_path
        self.root_noisy = noisy_data_path
        self.transform = transform
        self.target_sample_rate = target_sample_rate
        self.num_samples = target_sample_rate*duration

    def __len__(self):
        return len(self.root_clean)

    def __getitem__(self, index):
        audio_path_clean = self.root_clean[index]
        audio_path_noisy = self.root_noisy[index]

        signal, sr = torchaudio.load(audio_path_clean)
        signal_noisy, sr_noisy = torchaudio.load(audio_path_noisy)
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            signal = resampler(signal)
            signal_noisy = resampler(signal_noisy)

        if signal.shape[0] > 1:
            signal = torch.mean(signal, axis=0, keepdim=True)

        if signal_noisy.shape[0] > 1:
            signal_noisy = torch.mean(signal_noisy, axis=0, keepdim=True)

        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]

        if signal_noisy.shape[1] > self.num_samples:
            signal_noisy = signal_noisy[:, :self.num_samples]

        if signal.shape[1] < self.num_samples:
            num_missing_samples = self.num_samples - signal.shape[1]
            signal = F.pad(signal, (0, num_missing_samples))

        if signal_noisy.shape[1] < self.num_samples:
            num_missing_samples = self.num_samples - signal_noisy.shape[1]
            signal_noisy = F.pad(signal_noisy, (0, num_missing_samples))

        mel = self.transform(signal)
        mel_noisy = self.transform(signal_noisy)
        #print(mel.shape)
        image = mel / torch.abs(mel).max()
        return mel, mel_noisy#, signal_noisy

In [5]:
mel_spectrogram = torchaudio.transforms.MelSpectrogram(sample_rate=config.target_sample_rate,
                                                      n_fft=config.n_fft,
                                                      hop_length=config.hop_length,
                                                      n_mels=config.n_mels)


test_clean = clean_wavs_path[:10]
test_noisy = noisy_wavs_path[:10]

num_items = len(clean_wavs_path)

training_dataset = CustomDataset(clean_wavs_path[10:(num_items-300)], noisy_wavs_path[10:(num_items-300)], mel_spectrogram)
validation_dataset = CustomDataset(clean_wavs_path[(num_items-300):], noisy_wavs_path[(num_items-300):], mel_spectrogram)

print(len(training_dataset.root_clean))
print(len(validation_dataset.root_clean))

5379
300


In [6]:
trainloader = DataLoader(training_dataset, batch_size=config.batch_size)
validloader = DataLoader(validation_dataset, batch_size=config.batch_size)



# Model

In [7]:
import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self, chnls_in=1, chnls_out=1):
        super(UNet, self).__init__()
        self.down_conv_layer_1 = DownConvBlock(chnls_in, 64, norm=False)
        self.down_conv_layer_2 = DownConvBlock(64, 128)
        self.down_conv_layer_3 = DownConvBlock(128, 256)
        self.down_conv_layer_4 = DownConvBlock(256, 256, dropout=0.5)
        self.down_conv_layer_5 = DownConvBlock(256, 256, dropout=0.5)
        self.down_conv_layer_6 = DownConvBlock(256, 256, dropout=0.5)

        self.up_conv_layer_1 = UpConvBlock(256, 256, kernel_size=(2,3), stride=2, padding=0, dropout=0.5)# 256+256 6 5 kernel_size=(2, 3), stride=2, padding=0
        self.up_conv_layer_2 = UpConvBlock(512, 256, kernel_size=(2,3), stride=2, padding=0, dropout=0.5) # 256+256 1 4
        self.up_conv_layer_3 = UpConvBlock(512, 256, kernel_size=(2,3), stride=2, padding=0, dropout=0.5) # 2 3
        self.up_conv_layer_4 = UpConvBlock(512, 128, dropout=0.5) # 3 2
        self.up_conv_layer_5 = UpConvBlock(256, 64) # 4 1
        self.up_conv_layer_6 = UpConvBlock(512, 128)
        self.up_conv_layer_7 = UpConvBlock(256, 64)
        self.upsample_layer = nn.Upsample(scale_factor=2)
        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.conv_layer_1 = nn.Conv2d(128, chnls_out, 4, padding=1)
        self.activation = nn.Tanh()

    def forward(self, x):
        #print('x', x.shape)
        enc1 = self.down_conv_layer_1(x) # [4, 64, 32, 188]
        #print('1', enc1.shape)
        enc2 = self.down_conv_layer_2(enc1) # [4, 128, 16, 94]
        #print('2', enc2.shape)
        enc3 = self.down_conv_layer_3(enc2) # [4, 256, 8, 47]
        #print('3', enc3.shape)
        enc4 = self.down_conv_layer_4(enc3) # [4, 256, 4, 23]
        #print('4', enc4.shape)
        enc5 = self.down_conv_layer_5(enc4) # [4, 256, 2, 11]
        #print('5', enc5.shape)
        enc6 = self.down_conv_layer_6(enc5) # [4, 256, 1, 5]
        #print('6', enc6.shape)

        dec1 = self.up_conv_layer_1(enc6, enc5)# enc6: 256 + enc5: 256 [4, 512, 2, 11]
        #print('d1', dec1.shape)
        dec2 = self.up_conv_layer_2(dec1, enc4)# enc4: 256 + dec1=enc5*2: [4, 512, 4, 23]
        #print('d2', dec2.shape)
        dec3 = self.up_conv_layer_3(dec2, enc3)# enc3: 256 + dec2=enc4*2: [4, 512, 8, 47]
        #print('d3', dec3.shape)
        dec4 = self.up_conv_layer_4(dec3, enc2)# enc2: 128 + dec3=enc3*2: [4, 256, 16, 94]
        #print('d4', dec4.shape)
        dec5 = self.up_conv_layer_5(dec4, enc1)# enc1: 64 + dec4=enc1*2: [4, 128, 32, 188]
        #print('d5', dec5.shape)

        final = self.upsample_layer(dec5)
        final = self.zero_pad(final)
        final = self.conv_layer_1(final)
        #print(final.shape)
        return final

class UpConvBlock(nn.Module):
    def __init__(self, ip_sz, op_sz, kernel_size=4, stride= 2, padding=1 ,dropout=0.0):
        super(UpConvBlock, self).__init__()
        self.layers = [
            nn.ConvTranspose2d(ip_sz, op_sz, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.InstanceNorm2d(op_sz),
            nn.ReLU(),
        ]
        if dropout:
            self.layers += [nn.Dropout(dropout)]
        self.layers = nn.ModuleList(self.layers) # This is needed to let use use a gpu

    def forward(self, x, enc_ip):
        x = nn.Sequential(*(self.layers))(x)
        #print('x', x.shape)
        #print('enc', enc_ip.shape)
        op = torch.cat((x, enc_ip), 1)
        return op


class DownConvBlock(nn.Module):
    def __init__(self, ip_sz, op_sz, kernel_size=4, norm=True, dropout=0.0):
        super(DownConvBlock, self).__init__()
        self.layers = [nn.Conv2d(ip_sz, op_sz, kernel_size, 2, 1)]
        if norm:
            self.layers.append(nn.InstanceNorm2d(op_sz))
        self.layers += [nn.LeakyReLU(0.2)]
        if dropout:
            self.layers += [nn.Dropout(dropout)]
        self.layers = nn.ModuleList(self.layers) # This is needed to let use use a gpu

    def forward(self, x):
        op = nn.Sequential(*(self.layers))(x)
        return op


In [8]:
from tqdm import tqdm

def train(dataloader, model, epoch, loss_fn, optimizer, device):
    model.train()
    total_loss = 0.0
    for i, (clean, noisy) in enumerate(tqdm(dataloader)):
        clean = clean.to(device)
        noisy = noisy.to(device)

        optimizer.zero_grad()
        pred = model(noisy)
        curr_loss = loss_fn(pred, clean)
        curr_loss.backward()
        optimizer.step()

        total_loss += curr_loss
        if i % 1000 == 0:
            print('[Epoch number : %d, Mini-batches: %5d] loss: %.3f' %
                  (epoch + 1, i + 1, total_loss / 200))
            total_loss = 0.0

def val(dataloader, model, epoch, loss_fn, device):
    model.eval()
    total_loss = 0.0
    print('-------------------------')
    with torch.no_grad():
        for i, (clean, noisy) in enumerate(tqdm(dataloader)):
            clean = clean.to(device)
            noisy = noisy.to(device)

            output = model(noisy)
            loss = loss_fn(output, clean)
            total_loss += loss
            if i % 100 == 0:
                print('[Valid Epoch number : %d, Mini-batches: %5d] loss: %.3f' %
                      (epoch + 1, i + 1, total_loss / 200))
                total_loss = 0.0



In [9]:
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

model = UNet()
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
loss_fn = torch.nn.functional.mse_loss

model.to(device)
print('model device:', next(model.parameters()).device)

for epoch in range(config.epochs):
    train(trainloader, model, epoch, loss_fn, optimizer, device)
    val(validloader, model, epoch, loss_fn, device)

cuda:0
model device: cuda:0


  1%|          | 1/85 [01:42<2:23:39, 102.62s/it]

[Epoch number : 1, Mini-batches:     1] loss: 38.446


100%|██████████| 85/85 [2:27:07<00:00, 103.85s/it]


-------------------------


 20%|██        | 1/5 [03:47<15:11, 227.92s/it]

[Valid Epoch number : 1, Mini-batches:     1] loss: 9.133


100%|██████████| 5/5 [10:01<00:00, 120.38s/it]
  1%|          | 1/85 [00:02<03:58,  2.84s/it]

[Epoch number : 2, Mini-batches:     1] loss: 37.999


100%|██████████| 85/85 [03:35<00:00,  2.54s/it]


-------------------------


 20%|██        | 1/5 [00:02<00:11,  2.81s/it]

[Valid Epoch number : 2, Mini-batches:     1] loss: 9.058


100%|██████████| 5/5 [00:11<00:00,  2.35s/it]
  1%|          | 1/85 [00:03<04:46,  3.41s/it]

[Epoch number : 3, Mini-batches:     1] loss: 37.595


100%|██████████| 85/85 [03:27<00:00,  2.44s/it]


-------------------------


 20%|██        | 1/5 [00:03<00:12,  3.17s/it]

[Valid Epoch number : 3, Mini-batches:     1] loss: 8.993


100%|██████████| 5/5 [00:11<00:00,  2.26s/it]
  1%|          | 1/85 [00:02<04:11,  3.00s/it]

[Epoch number : 4, Mini-batches:     1] loss: 37.244


100%|██████████| 85/85 [03:21<00:00,  2.37s/it]


-------------------------


 20%|██        | 1/5 [00:02<00:09,  2.41s/it]

[Valid Epoch number : 4, Mini-batches:     1] loss: 8.939


100%|██████████| 5/5 [00:11<00:00,  2.34s/it]
  1%|          | 1/85 [00:02<03:43,  2.66s/it]

[Epoch number : 5, Mini-batches:     1] loss: 36.936


100%|██████████| 85/85 [03:23<00:00,  2.39s/it]


-------------------------


 20%|██        | 1/5 [00:02<00:09,  2.40s/it]

[Valid Epoch number : 5, Mini-batches:     1] loss: 8.892


100%|██████████| 5/5 [00:11<00:00,  2.33s/it]
  1%|          | 1/85 [00:02<03:38,  2.60s/it]

[Epoch number : 6, Mini-batches:     1] loss: 36.671


100%|██████████| 85/85 [03:21<00:00,  2.37s/it]


-------------------------


 20%|██        | 1/5 [00:03<00:13,  3.37s/it]

[Valid Epoch number : 6, Mini-batches:     1] loss: 8.853


100%|██████████| 5/5 [00:11<00:00,  2.31s/it]


In [10]:
PATH = '/content/drive/MyDrive/trained_v55555'
torch.save(model.state_dict(), PATH)

In [11]:
import IPython.display as ipd

model = UNet()
model.load_state_dict(torch.load(PATH))
model.to(device)
model.eval()
def preprocess_audio(audio_path, target_sample_rate=48000, duration=4):
    signal, sr = torchaudio.load(audio_path)
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        signal = resampler(signal)

    if signal.shape[0] > 1:
        signal = torch.mean(signal, axis=0, keepdim=True)

    num_samples = target_sample_rate * duration
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    elif signal.shape[1] < num_samples:
        signal = F.pad(signal, (0, num_samples - signal.shape[1]))

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=1024,
        hop_length=512,
        n_mels=64
    )
    mel = mel_transform(signal)
    mel = mel / torch.abs(mel).max()  # Нормализация
    return mel.unsqueeze(0).to(device)  # Добавляем batch-размер и переносим на устройство

def denoise_audio(audio_path):
    noisy_mel = preprocess_audio(audio_path)
    with torch.no_grad():
        clean_mel = model(noisy_mel)

    # Преобразование Mel-спектрограммы обратно в аудио
    inverse_mel = torchaudio.transforms.InverseMelScale(
        sample_rate=config.target_sample_rate,
        n_stft=config.n_fft // 2 + 1,
        n_mels=config.n_mels
    )
    griffin_lim = torchaudio.transforms.GriffinLim(
        n_fft=config.n_fft,
        hop_length=config.hop_length
    )

    audio = griffin_lim(inverse_mel(clean_mel.squeeze().cpu()))
    return audio


noisy_audio_path = '/content/drive/MyDrive/speach_dataset/noisy/p226_001.wav'
clean_audio_path = '/content/drive/MyDrive/speach_dataset/clean/p226_001.wav'

denoised_audio = denoise_audio(noisy_audio_path)


clean_sample, _ = torchaudio.load(clean_audio_path)
noisy_sample, _ = torchaudio.load(noisy_audio_path)

print("Original clean audio:")
display(ipd.Audio(clean_sample.numpy(), rate=config.target_sample_rate))

print("\nNoisy audio:")
display(ipd.Audio(noisy_sample.numpy(), rate=config.target_sample_rate))

print("\nDenoised audio:")
display(ipd.Audio(denoised_audio.numpy(), rate=config.target_sample_rate))

Original clean audio:



Noisy audio:



Denoised audio:
