In [1]:
import sys
import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import librosa as libr
from IPython.display import Audio

In [2]:
class BaseUNetModel(nn.Module):
    """U-Net style encoder/decoder mapping a 1-channel 2-D input to a 1-channel output.

    Six stride-2 convolution stages downsample the input. Five transposed-convolution
    stages upsample it again, each concatenated (along channels) with the encoder
    output of matching resolution (skip connections). A final x2 nearest-neighbour
    upsample, a one-pixel top/left zero pad and a 4x4 convolution restore the input
    resolution and project to a single channel.

    Shape constraint (follows from the kernel/stride/padding choices below): the skip
    connections only line up when the input height is a multiple of 64 and the width
    W satisfies ``W % 64 == 56``, e.g. ``(batch, 1, 64, 184)`` or ``(batch, 1, 64, 376)``.
    Other sizes fail with a size mismatch in ``torch.cat``.

    Attribute names and the layer order inside each ``nn.Sequential`` define the
    ``state_dict`` keys (e.g. ``down_conv_layer_1.0.weight``), so they must not change
    or saved checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()

        # Encoder: every stage maps (H, W) -> (floor(H/2), floor(W/2)).
        self.down_conv_layer_1 = self._down_block(1, 64, norm=False)
        self.down_conv_layer_2 = self._down_block(64, 128)
        self.down_conv_layer_3 = self._down_block(128, 256)
        self.down_conv_layer_4 = self._down_block(256, 256, dropout=True)
        self.down_conv_layer_5 = self._down_block(256, 256, dropout=True)
        self.down_conv_layer_6 = self._down_block(256, 256, dropout=True)

        # Decoder. Layers 1-3 (kernel (2, 3), padding 0) map (H, W) -> (2H, 2W + 1);
        # layers 4-5 (kernel 4, padding 1) map (H, W) -> (2H, 2W). Input channel
        # counts of layers 2-5 include the concatenated skip connection.
        self.up_conv_layer_1 = self._up_block(256, 256, kernel_size=(2, 3), padding=0)
        self.up_conv_layer_2 = self._up_block(512, 256, kernel_size=(2, 3), padding=0)
        self.up_conv_layer_3 = self._up_block(512, 256, kernel_size=(2, 3), padding=0)
        self.up_conv_layer_4 = self._up_block(512, 128, kernel_size=4, padding=1)
        self.up_conv_layer_5 = self._up_block(256, 64, kernel_size=4, padding=1, dropout=False)

        # Output head.
        self.upsample_layer = nn.Upsample(scale_factor=2)
        # Pads one column on the left and one row on the top (left, right, top, bottom).
        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0))
        self.conv_layer_final = nn.Conv2d(128, 1, kernel_size=4, padding=1)

    @staticmethod
    def _down_block(in_channels, out_channels, norm=True, dropout=False):
        """Build Conv2d(4x4, stride 2, pad 1) [-> InstanceNorm2d] -> LeakyReLU(0.2) [-> Dropout(0.5)].

        The norm layer is sized from ``out_channels`` so it always matches the conv
        output; a hand-typed 128 on the 256-channel stages used to trigger PyTorch's
        num_features mismatch warning. With the default ``affine=False``,
        InstanceNorm2d holds no parameters, so this does not change the state_dict.
        """
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    @staticmethod
    def _up_block(in_channels, out_channels, kernel_size, padding, dropout=True):
        """Build ConvTranspose2d(stride 2) -> InstanceNorm2d -> ReLU [-> Dropout(0.5)]."""
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                               stride=2, padding=padding),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(),
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network.

        Args:
            x: tensor of shape (batch, 1, H, W); see the class docstring for the
               H/W values that make the skip connections line up.

        Returns:
            Tensor of shape (batch, 1, H, W).
        """
        enc1 = self.down_conv_layer_1(x)
        enc2 = self.down_conv_layer_2(enc1)
        enc3 = self.down_conv_layer_3(enc2)
        enc4 = self.down_conv_layer_4(enc3)
        enc5 = self.down_conv_layer_5(enc4)
        enc6 = self.down_conv_layer_6(enc5)

        # Each decoder output is concatenated with the encoder output of equal H x W.
        dec = torch.cat((self.up_conv_layer_1(enc6), enc5), 1)
        dec = torch.cat((self.up_conv_layer_2(dec), enc4), 1)
        dec = torch.cat((self.up_conv_layer_3(dec), enc3), 1)
        dec = torch.cat((self.up_conv_layer_4(dec), enc2), 1)
        dec = torch.cat((self.up_conv_layer_5(dec), enc1), 1)

        # x2 upsample, grow by one pixel via padding, then the 4x4 / padding-1 conv
        # shrinks by one pixel again -> back to the input H x W.
        out = self.upsample_layer(dec)
        out = self.zero_pad(out)
        return self.conv_layer_final(out)

In [3]:
PATH = './model'
model = BaseUNetModel()
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# the model is moved to the compute device right before inference.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load(PATH, map_location='cpu'))

<All keys matched successfully>

In [4]:
sample_path = './noisy_testset_wav/p232_001.wav'
noisy_sample, sr = torchaudio.load(sample_path)
# Play back at the file's own sample rate instead of a hard-coded 48 kHz.
Audio(noisy_sample, rate=sr)

In [9]:
sample_path = './clean_testset_wav/p232_001.wav'
clean_sample, sr = torchaudio.load(sample_path)
# Play back at the file's own sample rate instead of a hard-coded 48 kHz.
Audio(clean_sample, rate=sr)

In [5]:
# Audio / feature settings shared by the mel transform and the inference helper.
config = {
    'sample_rate': 48000,   # Hz
    'max_duration': 4,      # seconds; inputs are padded/trimmed to this length
    'n_fft': 1024,
    'hop_length': 512,
    'n_mels': 64,
}
# Derived rather than hard-coded so it cannot drift from sample_rate / max_duration.
config['num_samples'] = config['sample_rate'] * config['max_duration']

In [6]:
# Waveform -> mel power spectrogram; every setting is taken from `config` so the
# forward transform and its inversion later on stay in sync.
_mel_kwargs = {key: config[key] for key in ('sample_rate', 'n_fft', 'hop_length', 'n_mels')}
mel_spectrogram = torchaudio.transforms.MelSpectrogram(**_mel_kwargs)

In [7]:
def test_sample(sample, model=model, transform=None, device=None):
    """Denoise one waveform with the U-Net and return reconstructed time-domain audio.

    Pipeline: mono waveform -> zero-pad/trim to ``config['num_samples']`` ->
    ``transform`` -> model -> mel-to-linear-STFT inversion -> Griffin-Lim.

    Args:
        sample: waveform tensor, ``(samples,)`` or ``(channels, samples)`` as returned
            by ``torchaudio.load``. Multi-channel input is averaged to mono.
        model: network to run. The default is bound to the global ``model`` that
            exists when this cell is executed.
        transform: callable turning the 1-D waveform into a 2-D ``(n_mels, frames)``
            tensor (the ``mel_spectrogram`` transform). Without it the raw waveform
            is fed to the 2-D conv model, which presumably will not work - pass one.
        device: torch device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        1-D numpy array of audio intended for playback at ``config['sample_rate']``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Collapse to a 1-D mono signal; flattening (the old reshape(-1)) would have
    # concatenated the channels of a stereo file end to end.
    waveform = sample.reshape(-1, sample.shape[-1]).mean(dim=0)

    # Fixed-length input: right-pad with zeros, then trim any excess.
    num_samples = config['num_samples']
    if waveform.shape[0] < num_samples:
        waveform = F.pad(waveform, (0, num_samples - waveform.shape[0]))
    waveform = waveform[:num_samples]

    features = transform(waveform) if transform is not None else waveform
    # Add batch and channel dims: (H, W) -> (1, 1, H, W).
    features = features.unsqueeze(0).unsqueeze(0).to(device)

    model = model.to(device)
    model.eval()
    with torch.no_grad():  # inference only - no autograd graph needed
        pred = model(features).squeeze(1).squeeze(0)
    mel = pred.cpu().numpy()

    # Invert with the same filterbank torchaudio's MelSpectrogram uses by default
    # (HTK mel scale, no filter normalisation) and the configured n_fft; librosa's
    # own defaults (n_fft=2048, Slaney scale/norm) would not undo the forward transform.
    stft_magnitude = libr.feature.inverse.mel_to_stft(
        mel,
        sr=config['sample_rate'],
        n_fft=config['n_fft'],
        htk=True,
        norm=None,
    )
    # hop_length must match the forward STFT; griffinlim otherwise assumes n_fft // 4.
    # A fixed random_state makes the phase reconstruction reproducible.
    return libr.griffinlim(stft_magnitude, hop_length=config['hop_length'], random_state=0)

In [8]:
denoise_audio = test_sample(sample=noisy_sample, model=model, transform=mel_spectrogram)
# The reconstruction is produced at the configured model sample rate.
Audio(denoise_audio, rate=config['sample_rate'])

  
