In [3]:
import sys
import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import librosa as libr
from IPython.display import Audio

In [4]:
import numpy as np
import pandas as pd
import os
import random
import librosa as libr
import librosa.display as disp
from IPython.display import Audio
from scipy.io import wavfile

In [16]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchaudio.transforms import Resample
from torchsummary import summary
from torch.utils.data import SubsetRandomSampler,Subset,DataLoader
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
from torchmetrics.audio import pesq as PESQ

In [6]:
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

In [10]:
config = {
    'sample_rate':48000,
    'max_duration':4,
    'batch_size':16,
    'learning_rate': 0.0001,
    'lr_decay':(0.9, 0.999),
    'epochs':200,
}

In [7]:
train_noisy_data_path = "./noisy_trainset_28spk_wav/"
train_clean_data_path = "./clean_trainset_28spk_wav/"
test_noisy_data_path = "./noisy_testset_wav/"
test_clean_data_path = "./clean_testset_wav/"

In [8]:
class AudioDataset(Dataset):
    def __init__(self,noisy_path,clean_path, transform=None, sample_rate=None,max_duration=None):

        #get file paths
        noisy_all_items = os.listdir(noisy_path)
        noisy_files = [item for item in noisy_all_items if os.path.isfile(os.path.join(noisy_path, item)) and item.lower().endswith('.wav')]
        noisy_file_paths = [os.path.join(noisy_path, file_name) for file_name in noisy_files]
        clean_file_paths = [os.path.join(clean_path, file_name) for file_name in noisy_files]

        #initialize variables
        self.noisy_data = noisy_file_paths
        self.clean_data = clean_file_paths
        self.transform = transform
        self.sample_rate = sample_rate
        self.max_duration = max_duration
        self.num_samples = sample_rate*max_duration

    def __len__(self):
        return len(self.noisy_data)

    def __getitem__(self, idx):
        noisy_waveform, sr = torchaudio.load(self.noisy_data[idx])
        if sr != config['sample_rate']:
            resampler = torchaudio.transforms.Resample(sr, config['sample_rate'])
            noisy_waveform = resampler(noisy_waveform)
        noisy_waveform = torch.tensor(noisy_waveform.numpy().reshape(-1))
        if noisy_waveform.shape[0] < self.num_samples:
            num_missing_samples = self.num_samples - noisy_waveform.shape[0]
            noisy_waveform = F.pad(noisy_waveform, (0, num_missing_samples))
        noisy_waveform = torch.tensor(noisy_waveform[:self.num_samples])

        clean_waveform, sr = torchaudio.load(self.clean_data[idx])
        if sr != config['sample_rate']:
            resampler = torchaudio.transforms.Resample(sr, config['sample_rate'])
            clean_waveform = resampler(clean_waveform)
        clean_waveform = torch.tensor(clean_waveform.numpy().reshape(-1))
        if clean_waveform.shape[0] < self.num_samples:
            num_missing_samples = self.num_samples - clean_waveform.shape[0]
            clean_waveform = F.pad(torch.tensor(clean_waveform), (0, num_missing_samples))
        clean_waveform = torch.tensor(clean_waveform[:self.num_samples])  
              
        return noisy_waveform, clean_waveform

In [12]:
def split_dataset(dataset, num=16):
    dataset_indices = list(range(len(dataset)))
    random.shuffle(dataset_indices)
    train_indices = dataset_indices[num:]
    val_indices = dataset_indices[:num]
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    return train_dataset,val_dataset

In [13]:
train_dataset = AudioDataset(noisy_path=train_noisy_data_path,
                       clean_path=train_clean_data_path,
                       sample_rate=config['sample_rate'],
                       max_duration=config['max_duration'],
                       )
test_dataset = AudioDataset(noisy_path=test_noisy_data_path,
                       clean_path=test_clean_data_path,
                       sample_rate=config['sample_rate'],
                       max_duration=config['max_duration'],
                       )

train_dataset, val_dataset = split_dataset(train_dataset,16)

train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

print(f"Dataloaders: {train_dataloader, test_dataloader}")
print(f"Length of train dataloader: {len(train_dataloader)} batches of {config['batch_size']}")
print(f"Length of val dataloader: {len(val_dataloader)} batches of {config['batch_size']}")
print(f"Length of test dataloader: {len(test_dataloader)} batches of {config['batch_size']}")

Dataloaders: (<torch.utils.data.dataloader.DataLoader object at 0x7fb4a7c18310>, <torch.utils.data.dataloader.DataLoader object at 0x7fb4a7ff6fd0>)
Length of train dataloader: 723 batches of 16
Length of val dataloader: 1 batches of 16
Length of test dataloader: 52 batches of 16


In [14]:
def helper(noise, clean):
    print(noise.shape, clean.shape)

In [66]:

def denoise_audio(signal,clean_signal ,sample_rate=config['sample_rate'] ,t1=0.0,t2=1.0,save=0):
    fft = np.fft.fft(signal)
    magnitude = np.abs(fft)
    frequency =np.linspace(0,sample_rate,len(magnitude))
    psd = fft * np.conj(fft) /  len(signal)
    
    idx = [(psd >= t1) &  (psd <= t2)]
    psd_filtered = psd * idx
    filtered_fft = idx * fft

    filtered_signal = np.real(np.fft.ifft(filtered_fft))
    filtered_signal = filtered_signal.squeeze(0).astype(np.float32)
    filtered_signal = torch.Tensor(filtered_signal)

    filt_magnitude = np.abs(filtered_signal)
    filt_frequency =np.linspace(0,sample_rate,len(filt_magnitude))
    
    if save:
        filtered_sample_path = 'denoised.wav'
        wavfile.write(filtered_sample_path, sample_rate, filtered_signal)
    try:
        ssnr = ScaleInvariantSignalNoiseRatio()
        snr_clean = ssnr(signal,clean_signal).item() 
        snr_denoised = ssnr(filtered_signal, signal).item()

        #compute_snr(clean_signal,signal)
        pesq = PESQ.PerceptualEvaluationSpeechQuality(fs=16000, mode='wb')
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(16000)
            signal = resampler(signal)
            clean_signal = resampler(clean_signal)
            filtered_signal = resampler(filtered_signal)

        pesq_clean = pesq(clean_signal,signal)#compute_snr(clean_signal,filtered_signal)
        pesq_denoised = pesq(clean_signal, filtered_signal)

        criterion = nn.MSELoss()
        clean_loss = criterion(signal.float(),clean_signal .float())
        denoise_loss = criterion(filtered_signal.float(),clean_signal .float())
        
        return (snr_clean,snr_denoised), (pesq_clean,pesq_denoised), (clean_loss,denoise_loss)
    except Exception as e:
        pass

In [67]:
ssnr_scores_clean,ssnr_scores_denoise = [],[]
pesq_scores_clean, pesq_scores_denoise = [],[]
loss_scores_clean, loss_scores_denoise = [],[]
for i, data in enumerate(test_dataloader, 0):
      # get the inputs; data is a list of [inputs, labels]
      X, y = data
      try:
            ssnr_res, pesq_res, loss_res = denoise_audio(X,y)
            ssnr_scores_clean.append(ssnr_res[0])
            ssnr_scores_denoise.append(ssnr_res[1])
            pesq_scores_clean.append(pesq_res[0])
            pesq_scores_denoise.append(pesq_res[1])
            loss_scores_clean.append(loss_res[0])
            loss_scores_denoise.append(loss_res[1])
      except Exception as e:
        pass


In [68]:
np.mean(ssnr_scores_clean),np.mean(ssnr_scores_denoise)

(8.610924019533044, -25.502140606150906)

In [69]:
np.mean(pesq_scores_clean),np.mean(pesq_scores_denoise)

(1.832648, 1.0372517)

In [75]:
np.mean(loss_scores_clean),np.mean(loss_scores_denoise)

(0.0007543463, 0.002589878)

In [71]:
class BaseUNetModel(nn.Module):
    def __init__(self):
        super(BaseUNetModel,self).__init__()
        
        self.down_conv_layer_1 = nn.Sequential(
            nn.Conv2d(1, 64 , kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )        
 
        self.down_conv_layer_2 = nn.Sequential(
            nn.Conv2d(64, 128 , kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            
        ) 
    
        self.down_conv_layer_3 = nn.Sequential(
            nn.Conv2d(128, 256 , kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            
        ) 
        
        self.down_conv_layer_4 = nn.Sequential(
            nn.Conv2d(256, 256 , kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),   
        )

        self.down_conv_layer_5 = nn.Sequential(
            nn.Conv2d(256, 256 , kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),   
        ) 

        self.down_conv_layer_6 = nn.Sequential(
            nn.Conv2d(256, 256 , kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),   
        ) 
        
        self.up_conv_layer_1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, kernel_size=(2,3), stride=2, padding=0),
            nn.InstanceNorm2d(256),
            nn.ReLU(), 
            nn.Dropout(0.5), 
        )
        
        self.up_conv_layer_2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=(2,3), stride=2, padding=0),
            nn.InstanceNorm2d(256),
            nn.ReLU(), 
            nn.Dropout(0.5), 
        )
        
        self.up_conv_layer_3 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=(2,3), stride=2, padding=0),
            nn.InstanceNorm2d(256),
            nn.ReLU(), 
            nn.Dropout(0.5), 
        )

        self.up_conv_layer_4 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, kernel_size=(4), stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(), 
            nn.Dropout(0.5), 
        )

        self.up_conv_layer_5 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=(4), stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(), 
        )
                
        self.upsample_layer = nn.Upsample(scale_factor=2)
        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.conv_layer_final = nn.Conv2d(128, 1, kernel_size=4, padding=1)

    def forward(self,x):
        # print("x", x.shape)
        enc1 = self.down_conv_layer_1(x)
        # print("enc1", enc1.shape)
        enc2 = self.down_conv_layer_2(enc1) 
        # print("enc2", enc2.shape)
        enc3 = self.down_conv_layer_3(enc2) 
        # print("enc3", enc3.shape)
        enc4 = self.down_conv_layer_4(enc3)
        # print("enc4", enc4.shape)
        enc5 = self.down_conv_layer_5(enc4)
        # print("enc5", enc5.shape)
        enc6 = self.down_conv_layer_6(enc5)
        # print("enc6", enc6.shape)

        dec1 = self.up_conv_layer_1(enc6)
        # print("dec1", dec1.shape)
        dec15 = torch.cat((dec1, enc5), 1)
        # print("dec15", dec15.shape)

        dec2 = self.up_conv_layer_2(dec15)
        # print("dec2", dec2.shape)
        dec24 = torch.cat((dec2, enc4), 1)
        # print("dec24", dec24.shape)

        dec3 = self.up_conv_layer_3(dec24)
        # print("dec3", dec3.shape)
        dec33 = torch.cat((dec3, enc3), 1)
        # print("dec33", dec33.shape)

        dec4 = self.up_conv_layer_4(dec33)
        # print("dec4", dec4.shape)
        dec42 = torch.cat((dec4, enc2), 1)
        # print("dec42", dec42.shape)

        dec5 = self.up_conv_layer_5(dec42)
        # print("dec5", dec5.shape)
        dec51 = torch.cat((dec5, enc1), 1)
        # print("dec51", dec51.shape)

        final = self.upsample_layer(dec51)
        # print("up", final.shape)
        final = self.zero_pad(final)
        # print("zero", final.shape)
        final = self.conv_layer_final(final)
        # print("final", final.shape)
        return final

In [82]:
config = {
    'sample_rate':48000,
    'max_duration':4,
    'n_fft':1024,
    'hop_length':512,
    'n_mels':64,
    'batch_size':128,
    'learning_rate': 5e-5,
    'epochs':200,
    'val_split':0.9,
}

In [83]:

test_dataset = AudioDataset(noisy_path=test_noisy_data_path,
                       clean_path=test_clean_data_path,
                       sample_rate=config['sample_rate'],
                       max_duration=config['max_duration'],
                       )

test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)

In [86]:
model1 = BaseUNetModel()
model1.load_state_dict(torch.load('./model'))
model1 = model1.cuda()  

In [88]:
model1.eval()
criterion = nn.MSELoss()
test_loss,test_ssnr,test_pesq = 0.0,0.0,0.0
with torch.no_grad():
  for i, data in enumerate(test_dataloader, 0):
    X, y = data
    X = X.unsqueeze(1)
    X, y = Variable(X.cuda()), Variable(y.cuda())
    preds = model1(X)
    preds = preds.squeeze(1).cuda()

    loss = criterion(preds.float().squeeze(1), y.float())
    # ssnr_score = get_ssnr(preds, clean)
    # pesq_score = get_pesq(preds, clean)

    test_loss += loss.item()
    # test_ssnr += ssnr_score
    # test_pesq += pesq_score

print(f'Test: Loss: {test_loss:.4f}')#SSNR: {test_ssnr:.4f} PESQ: {test_pesq:.4f}')

torch.Size([128, 1, 192000]) torch.Size([128, 192000])


RuntimeError: Given groups=1, weight of size [64, 1, 4, 4], expected input[1, 128, 1, 192000] to have 1 channels, but got 128 channels instead