In [None]:
import zipfile
import time

import torch
from torch import nn
from torch.utils.data import Dataset
import torchaudio
import torchaudio.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np

import scipy

# Prefix prepended verbatim to the dataset zip file names; "" means the
# archives are looked up in the current working directory.
path = ""

# Seed the global RNGs so that the random train/test split, loader shuffling,
# weight initialisation and the randomly picked plotting example are
# repeatable after a kernel restart.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [2]:
class NoisySpeech(Dataset):
    """Paired noisy/clean speech clips read on demand from two zip archives.

    Args:
        path: String prefix that is prepended verbatim to the archive file
            names (so a directory must include its trailing separator).
        device: Stored on the instance as ``self.device``; the dataset itself
            does not use it.

    Raises:
        ValueError: If the two archives do not contain the same number of
            ``.wav`` members, since items are paired by position.
    """

    def __init__(self, path, device='cpu'):
        self.path = path
        self.device = device
        self.__clean_zip__ = 'clean_trainset_28spk_wav.zip'
        self.__noisy_zip__ = 'noisy_trainset_28spk_wav.zip'
        self.__clean_wav_list__ = self._list_wavs(self.__clean_zip__)
        self.__noisy_wav_list__ = self._list_wavs(self.__noisy_zip__)
        if len(self.__clean_wav_list__) != len(self.__noisy_wav_list__):
            raise ValueError(
                "clean and noisy archives contain a different number of .wav files "
                f"({len(self.__clean_wav_list__)} vs {len(self.__noisy_wav_list__)})"
            )

    def _list_wavs(self, zip_name):
        """Return the ``.wav`` members of ``zip_name`` in a deterministic order."""
        with zipfile.ZipFile(self.path + zip_name, 'r') as archive:
            names = [s for s in archive.namelist() if s.endswith('.wav')]
        # Archive order is arbitrary, so sort by file name (then full member
        # path) to make index i refer to the same utterance in both archives.
        return sorted(names, key=lambda s: (s.rsplit('/', 1)[-1], s))

    def _read_wav(self, zip_name, member):
        """Read one archive member and return ``(sample_rate, samples)``."""
        # The archive is reopened for every read instead of keeping a handle
        # on the instance.
        with zipfile.ZipFile(self.path + zip_name, 'r') as archive:
            with archive.open(member) as wav_file:
                return scipy.io.wavfile.read(wav_file)

    def __len__(self):
        return len(self.__noisy_wav_list__)

    def __getitem__(self, idx):
        """Return ``(noisy, clean, sample_rate)`` as tensors for item ``idx``.

        The returned sample rate is the one read from the noisy file.
        """
        _, np_clean_audio = self._read_wav(self.__clean_zip__, self.__clean_wav_list__[idx])
        sr, np_noisy_audio = self._read_wav(self.__noisy_zip__, self.__noisy_wav_list__[idx])
        return torch.tensor(np_noisy_audio), torch.tensor(np_clean_audio), torch.tensor(sr)

def CollateNoisySpeech(itemlist):
    """Collate ``(noisy, clean, sr)`` items into equal-length batches.

    Every clip is cropped to the length of the shortest clip in the batch
    (and to at most 6 seconds at 48 kHz), so no padding remains in the output.

    Args:
        itemlist: Non-empty sequence of ``(noisy, clean, sr)`` tuples whose
            first two entries are 1-D tensors.

    Returns:
        ``(noisy_batch, clean_batch, sr)`` where both batches have shape
        ``(len(itemlist), sample_len)`` in the default floating dtype and
        ``sr`` is taken from the last item.

    Raises:
        ValueError: If ``itemlist`` is empty.
    """
    if len(itemlist) == 0:
        raise ValueError("CollateNoisySpeech received an empty batch")

    buffer_len = 6*48000 # Maximum length is 6 sec at 48kHz

    # Shortest clip in the batch decides the common length; the clean clip is
    # included so a clean/noisy length mismatch cannot cause a shape error.
    sample_len = buffer_len
    for noisy, clean, _ in itemlist:
        sample_len = min(sample_len, len(noisy), len(clean))

    dtype = torch.get_default_dtype()
    # Build each batch with one stack instead of growing it by repeated cat.
    noisy_batch = torch.stack([noisy[0:sample_len].to(dtype) for noisy, _, _ in itemlist])
    clean_batch = torch.stack([clean[0:sample_len].to(dtype) for _, clean, _ in itemlist])

    sr = itemlist[-1][2]
    return noisy_batch, clean_batch, sr

In [3]:
# Load data, create dataloaders, set parameters
dataset = NoisySpeech(path,device=device)

# The dataset returns the sample rate as a 0-dim tensor; keep a plain int so
# the value can be used directly in size arithmetic and transform constructors.
_, _, first_samplerate = dataset[0]
input_samplerate = int(first_samplerate)

resample_samplerate = 16000
window_length_ms = 30
batch_size = 10
# STFT frame length in samples at the resampled rate, with 50% overlap.
n_fft = (2*window_length_ms * resample_samplerate) // 2000
hop_length = n_fft // 2

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=CollateNoisySpeech)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=CollateNoisySpeech)

In [4]:
class PreProcessing(torch.nn.Module):
    """Resample a waveform and convert it to a complex spectrogram.

    Args:
        input_samplerate: Sample rate of the incoming waveform. Anything
            accepted by ``int()`` works, including a 0-dim tensor.
        resample_samplerate: Sample rate the waveform is converted to before
            the spectrogram is taken.
        window_length_ms: Analysis window length in milliseconds; the hop is
            half of the resulting ``n_fft``.

    Attributes:
        output_size: ``(n_fft + 2) // 2``.
    """

    def __init__(
        self,
        input_samplerate    = 16000,
        resample_samplerate = 16000,
        window_length_ms    = 30
    ):
        super().__init__()
        # Callers may pass sample rates as 0-dim tensors; normalise to plain
        # ints so the size arithmetic below yields Python ints.
        input_samplerate = int(input_samplerate)
        resample_samplerate = int(resample_samplerate)

        self.resample = torchaudio.transforms.Resample(orig_freq=input_samplerate, new_freq=resample_samplerate)
        n_fft = int((2*window_length_ms * resample_samplerate) // 2000)
        hop_length = n_fft // 2
        # power=None keeps the complex spectrogram instead of a magnitude/power one.
        self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft,power=None,hop_length=hop_length)
        self.output_size = (n_fft+2)//2

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the complex spectrogram of the resampled ``waveform``."""
        # Resample the input, then convert to a complex spectrogram
        resampled = self.resample(waveform)
        spec = self.spec(resampled)

        return spec

        

class PostProcessing(torch.nn.Module):
    """Convert a complex spectrogram back to a waveform and resample it.

    Args:
        output_samplerate: Sample rate of the returned waveform. Anything
            accepted by ``int()`` works, including a 0-dim tensor.
        resample_samplerate: Sample rate the spectrogram was computed at.
        window_length_ms: Window length in milliseconds used for the
            spectrogram; the hop is half of the resulting ``n_fft``.
    """

    def __init__(
        self,
        output_samplerate   = 16000,
        resample_samplerate = 16000,
        window_length_ms    = 30
    ):
        super().__init__()
        # Callers may pass sample rates as 0-dim tensors; normalise to plain
        # ints so the size arithmetic below yields Python ints.
        output_samplerate = int(output_samplerate)
        resample_samplerate = int(resample_samplerate)

        self.resample = torchaudio.transforms.Resample(orig_freq=resample_samplerate, new_freq=output_samplerate)
        n_fft = int((2*window_length_ms * resample_samplerate) // 2000)
        hop_length = n_fft // 2
        self.invspec = torchaudio.transforms.InverseSpectrogram(n_fft=n_fft,hop_length=hop_length)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Return the waveform for ``spec`` at the output sample rate."""
        # Invert the spectrogram, then resample the output
        waveform = self.invspec(spec)
        resampled = self.resample(waveform)

        return resampled

In [6]:
# Complex 2d conv (code from: https://github.com/pheepa/DCUnet/tree/master)
class CConv2d(nn.Module):
    """Complex 2-D convolution assembled from two real ``nn.Conv2d`` layers.

    The input keeps its real and imaginary parts at index 0 and 1 of the last
    axis; the output is laid out the same way.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        # Both branches share one configuration. The real branch is created
        # first so parameters are initialised in the same order as before.
        conv_config = dict(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_size=kernel_size,
                           padding=padding,
                           stride=stride)
        self.real_conv = nn.Conv2d(**conv_config)
        self.im_conv = nn.Conv2d(**conv_config)

        # Glorot initialization.
        for branch in (self.real_conv, self.im_conv):
            nn.init.xavier_uniform_(branch.weight)

    def forward(self, x):
        """Apply (W_r + i W_i) to (x_r + i x_i) and restack the two parts."""
        real_in = x.select(-1, 0)
        imag_in = x.select(-1, 1)

        real_out = self.real_conv(real_in) - self.im_conv(imag_in)
        imag_out = self.im_conv(real_in) + self.real_conv(imag_in)

        return torch.stack((real_out, imag_out), dim=-1)

In [7]:
# Complex transpose 2d conv (code from: https://github.com/pheepa/DCUnet/tree/master), modified
class CConvTranspose2d(nn.Module):
    """Complex 2-D transposed convolution made of two real ``nn.ConvTranspose2d``.

    Real and imaginary parts live at index 0 and 1 of the last axis of both
    the input and the output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=0, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.output_padding = output_padding
        self.padding = padding
        self.stride = stride

        # Both branches share one configuration. The real branch is created
        # first so parameters are initialised in the same order as before.
        convt_config = dict(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            output_padding=output_padding,
                            padding=padding,
                            stride=stride)
        self.real_convt = nn.ConvTranspose2d(**convt_config)
        self.im_convt = nn.ConvTranspose2d(**convt_config)

        # Glorot initialization.
        for branch in (self.real_convt, self.im_convt):
            nn.init.xavier_uniform_(branch.weight)

    def forward(self, x, output_size):
        """Complex transposed convolution; ``output_size`` is passed to each branch."""
        real_in = x.select(-1, 0)
        imag_in = x.select(-1, 1)

        real_out = self.real_convt(real_in, output_size) - self.im_convt(imag_in, output_size)
        imag_out = self.im_convt(real_in, output_size) + self.real_convt(imag_in, output_size)

        return torch.stack((real_out, imag_out), dim=-1)

In [8]:
# Complex 2d batch norm (code from: https://github.com/pheepa/DCUnet/tree/master)
class CBatchNorm2d(nn.Module):
    """Batch normalisation applied separately to the real and imaginary parts.

    Each part (index 0 / 1 of the last axis) has its own ``nn.BatchNorm2d``
    with identical settings.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        norm_config = dict(num_features=num_features, eps=eps, momentum=momentum,
                           affine=affine, track_running_stats=track_running_stats)
        self.real_b = nn.BatchNorm2d(**norm_config)
        self.im_b = nn.BatchNorm2d(**norm_config)

    def forward(self, x):
        """Normalise both parts independently and stack them back together."""
        real_normed = self.real_b(x.select(-1, 0))
        imag_normed = self.im_b(x.select(-1, 1))
        return torch.stack((real_normed, imag_normed), dim=-1)

In [9]:
# Encoder block (code from: https://github.com/pheepa/DCUnet/tree/master)
class Encoder(nn.Module):
    """Down-sampling block: complex convolution, complex batch norm, LeakyReLU."""

    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45, padding=(0,0)):
        super().__init__()

        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding

        self.cconv = CConv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=filter_size, stride=stride_size, padding=padding)
        self.cbn = CBatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x):
        """Run ``x`` through convolution, normalisation and activation in order."""
        return self.leaky_relu(self.cbn(self.cconv(x)))

In [10]:
# Decoder block (code from: https://github.com/pheepa/DCUnet/tree/master), modified
class Decoder(nn.Module):
    """Up-sampling block built around a complex transposed convolution.

    For ordinary layers the convolution is followed by complex batch norm and
    LeakyReLU. When ``last_layer`` is true the convolution output is instead
    turned into a bounded complex mask: unit phase times ``tanh`` of the
    magnitude, so the mask never leaves the unit disk.
    """

    def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45,
                 output_padding=(0,0), padding=(0,0), last_layer=False):
        super().__init__()

        self.filter_size = filter_size
        self.stride_size = stride_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_padding = output_padding
        self.padding = padding

        self.last_layer = last_layer

        self.cconvt = CConvTranspose2d(in_channels=self.in_channels, out_channels=self.out_channels,
                             kernel_size=self.filter_size, stride=self.stride_size, output_padding=self.output_padding, padding=self.padding)

        self.cbn = CBatchNorm2d(num_features=self.out_channels)

        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x, output_size):
        """Up-sample ``x`` to ``output_size``; return activations or the mask."""
        conved = self.cconvt(x, output_size)

        if not self.last_layer:
            normed = self.cbn(conved)
            output = self.leaky_relu(normed)
        else:
            # The last axis holds (real, imag), so the complex magnitude is
            # the 2-norm over that axis. An element-wise abs would treat the
            # two parts independently and reduce the mask to tanh(real),
            # tanh(imag).
            magnitude = torch.linalg.vector_norm(conved, dim=-1, keepdim=True)
            m_phase = conved / (magnitude + 1e-8)
            m_mag = torch.tanh(magnitude)
            output = m_phase * m_mag

        return output

In [11]:
#  Deep Complex U-Net (code from: https://github.com/pheepa/DCUnet/tree/master), modified
class DCUnet10(nn.Module):
    """Ten-layer deep complex U-Net that predicts a complex mask.

    Five encoder blocks are mirrored by five decoder blocks; each decoder
    input (except the first) is the previous decoder output concatenated with
    the matching encoder output along the channel axis.
    """

    def __init__(self):
        super().__init__()
        # downsampling/encoding
        self.downsample0 = Encoder(filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45)
        self.downsample1 = Encoder(filter_size=(7,5), stride_size=(2,2), in_channels=45, out_channels=90)
        self.downsample2 = Encoder(filter_size=(5,3), stride_size=(2,2), in_channels=90, out_channels=90)
        self.downsample3 = Encoder(filter_size=(5,3), stride_size=(2,2), in_channels=90, out_channels=90)
        self.downsample4 = Encoder(filter_size=(5,3), stride_size=(2,1), in_channels=90, out_channels=90)

        # upsampling/decoding (in_channels double where a skip connection is concatenated)
        self.upsample0 = Decoder(filter_size=(5,3), stride_size=(2,1), in_channels=90, out_channels=90)
        self.upsample1 = Decoder(filter_size=(5,3), stride_size=(2,2), in_channels=180, out_channels=90)
        self.upsample2 = Decoder(filter_size=(5,3), stride_size=(2,2), in_channels=180, out_channels=90)
        self.upsample3 = Decoder(filter_size=(7,5), stride_size=(2,2), in_channels=180, out_channels=45)
        self.upsample4 = Decoder(filter_size=(7,5), stride_size=(2,2), in_channels=90, output_padding=(0,1),
                                 out_channels=1, last_layer=True)

    def forward(self, x):
        """Estimate the enhanced spectrogram for a complex input ``x``.

        Args:
            x: Complex tensor; a channel axis is inserted at position 1.
               Presumably laid out as (batch, freq, time) - confirm with caller.

        Returns:
            ``(estimated_spec, mask)``, both complex and shaped like ``x``.
        """
        spec = x.unsqueeze(1)
        x = torch.view_as_real(spec)
        # downsampling/encoding
        d0 = self.downsample0(x)
        d1 = self.downsample1(d0)
        d2 = self.downsample2(d1)
        d3 = self.downsample3(d2)
        d4 = self.downsample4(d3)

        # upsampling/decoding; output_size pins each layer to the shape of
        # the encoder activation it is concatenated with (skip-connection)
        u0 = self.upsample0(d4, output_size=d3[..., 0].size())
        c0 = torch.cat((u0, d3), dim=1)
        u1 = self.upsample1(c0, output_size=d2[..., 0].size())
        c1 = torch.cat((u1, d2), dim=1)
        u2 = self.upsample2(c1, output_size=d1[..., 0].size())
        c2 = torch.cat((u2, d1), dim=1)
        u3 = self.upsample3(c2, output_size=d0[..., 0].size())
        c3 = torch.cat((u3, d0), dim=1)

        gains = self.upsample4(c3, output_size=x[..., 0].size())

        # Apply the mask as a complex product. Multiplying the (real, imag)
        # views element-wise would scale the two parts separately and could
        # not rotate the phase.
        mask = torch.view_as_complex(gains)
        estimated_spec = mask * spec

        return estimated_spec.squeeze(1), mask.squeeze(1)

In [12]:
def train(dataset, dataloader, model, preprocessor, loss_fn, optimizer, epochs=1,
          log_every=10, checkpoint_path="DCUnet10.pt"):
    """Train ``model`` on the batches produced by ``dataloader``.

    Args:
        dataset: Only its length is used, as the per-epoch sample count in
            the progress print-out.
        dataloader: Iterable yielding ``(noisy_batch, clean_batch, _)``.
        model: Module returning ``(estimated_spec, _)`` for a noisy spectrogram.
        preprocessor: Callable applied to each raw batch before it is moved
            to the model's device.
        loss_fn: Called as ``loss_fn(noisy_spec, est_clean_spec, clean_spec)``.
        optimizer: Optimizer over the model parameters.
        epochs: Number of passes over ``dataloader``.
        log_every: Checkpoint and print every this many batches of an epoch.
        checkpoint_path: File the whole model object is saved to.
    """
    size = len(dataset)
    total = size * epochs

    # Run on the device the model already lives on rather than on a notebook
    # global, so the function works no matter how the caller placed the model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    start_time = time.perf_counter()
    seen = 0  # samples processed so far, correct even when the last batch is short

    for epoch in range(epochs):
        for batch, (noisy_batch, clean_batch, _) in enumerate(dataloader):
            noisy_spec = preprocessor(noisy_batch).to(target_device)
            clean_spec = preprocessor(clean_batch).to(target_device)

            est_clean_spec, _ = model(noisy_spec)

            loss = loss_fn(noisy_spec, est_clean_spec, clean_spec)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            seen += noisy_batch.shape[0]

            if (batch+1) % log_every == 0:
                # Saves the full module object (not just a state_dict), as
                # the rest of the notebook does.
                torch.save(model, checkpoint_path)
                curr_time = time.perf_counter()
                loss_value = loss.item()
                print(f"loss: {loss_value:>7f} [{seen:>5d}/{total:>5d}] at {curr_time-start_time:>5f} sec")
                start_time = curr_time

In [14]:
# Loss function (code from: https://github.com/pheepa/DCUnet/tree/master), modified
def wsdr_fn(x_, y_pred_, y_true_, eps=1e-8):
    """Weighted SDR loss computed on time-domain signals.

    The three complex spectrograms are inverted with ``torch.istft`` using the
    notebook-level ``n_fft`` and ``hop_length``. The loss is an energy-weighted
    sum of the negative cosine similarity between true and predicted speech
    and between true and predicted noise.

    Args:
        x_: Noisy spectrogram.
        y_pred_: Estimated clean spectrogram.
        y_true_: Reference clean spectrogram.
        eps: Added to every denominator to avoid division by zero.

    Returns:
        Scalar tensor: the mean loss over the batch.
    """
    # to time-domain waveform
    # NOTE(review): no window is passed to istft here, while the spectrograms
    # presumably come from a windowed transform - confirm this is intended.
    y_true = torch.istft(y_true_, n_fft=n_fft, hop_length=hop_length)
    x = torch.istft(x_, n_fft=n_fft, hop_length=hop_length)
    y_pred = torch.istft(y_pred_, n_fft=n_fft, hop_length=hop_length)

    def sdr_fn(true, pred):
        # Negative cosine similarity. Uses the caller's eps, so a custom
        # value applies to these denominators as well.
        num = torch.sum(true * pred, dim=1)
        den = torch.norm(true, p=2, dim=1) * torch.norm(pred, p=2, dim=1)
        return -(num / (den + eps))

    # true and estimated noise
    z_true = x - y_true
    z_pred = x - y_pred

    # Share of the clean-speech energy in the mixture; weights the two terms.
    speech_energy = torch.sum(y_true**2, dim=1)
    noise_energy = torch.sum(z_true**2, dim=1)
    a = speech_energy / (speech_energy + noise_energy + eps)

    wSDR = a * sdr_fn(y_true, y_pred) + (1 - a) * sdr_fn(z_true, z_pred)
    return torch.mean(wSDR)

In [None]:
# TODO: the STFT preprocessing does not yet handle padding to the maximum clip length correctly.
enhancer = DCUnet10().to(device)

loss_fn = wsdr_fn
epochs = 5
optimizer = torch.optim.Adam(enhancer.parameters(), lr=0.001)
preprocessor = PreProcessing(input_samplerate=input_samplerate, resample_samplerate=resample_samplerate, window_length_ms=window_length_ms)

# Pass the split that the loader actually iterates, so the progress total
# matches, and reuse the `epochs` setting instead of repeating the literal.
train(train_dataset, train_dataloader, enhancer, preprocessor, loss_fn, optimizer=optimizer, epochs=epochs)

In [16]:
# Persist the whole trained module object to the notebook's working directory.
torch.save(obj=enhancer, f="DCUnet10.pt")

In [17]:
# Enhance one batch drawn from the test loader.
noisy_batch, clean_batch, sr = next(iter(test_dataloader))
noisy_spec = preprocessor(noisy_batch).to(device)
clean_spec = preprocessor(clean_batch).to(device)

# Use the same resample rate and window length as the preprocessor, so the
# inverse transform stays correct if those settings are changed.
postprocessor = PostProcessing(output_samplerate=input_samplerate,
                               resample_samplerate=resample_samplerate,
                               window_length_ms=window_length_ms)

enhancer.eval()
with torch.no_grad():
    enhanced_spec, gains = enhancer(noisy_spec)

enhanced_batch = postprocessor(enhanced_spec.to('cpu'))
clean_audio = postprocessor(clean_spec.to('cpu'))
noisy_audio = postprocessor(noisy_spec.to('cpu'))

In [None]:
import IPython.display

# Draw the example from the batch that was actually returned; it can hold
# fewer items than the configured batch_size.
idx = np.random.randint(noisy_batch.shape[0])
print(idx)

# Small offset so silent bins plot as a finite value rather than log(0) = -inf.
log_eps = 1e-8
panels = [("Noisy", noisy_spec), ("Enhanced", enhanced_spec), ("Clean", clean_spec)]

fig, axes = plt.subplots(1, 3, figsize=(8,3))
for ax, (title, spec) in zip(axes, panels):
    log_magnitude = (spec[idx,:,:].detach().to('cpu').abs() + log_eps).log().mT.numpy()
    ax.imshow(log_magnitude, origin='lower', aspect="auto")
    ax.set_title(title)
plt.show()

# Play the three versions in the same order as the panels above.
for audio in (noisy_batch, enhanced_batch, clean_batch):
    IPython.display.display(IPython.display.Audio(audio[idx,:].detach().numpy(),rate=int(sr)))