# Imports

In [1]:
import os
# Project root on the Drive mount used in Colab. Data, checkpoint and output paths
# below are all built from it. Note the trailing slash: some paths append directly
# (f'{PATH}serialized/'), others insert an extra '/'.
PATH = r'/content/drive/My Drive/Saarthi.ai Assignment/'
# Make relative paths resolve against the project root.
os.chdir(PATH)
import librosa
import time
import numpy as np


import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn.modules import Module
from torch.utils.data import DataLoader
from torch.utils import data
from torch.autograd import Variable
from torch import optim

# !pip install soundfile

import soundfile
from scipy import signal
from scipy.io import wavfile

# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



# Audio Pre-processing

In [2]:
# Audio is resampled and cut into fixed windows: 50% overlap for the train set, no overlap
# for inference. The signal is zero-padded to a whole number of windows so the tail is kept.

def slice_signal(filepath, window_size, stride, sample_rate = 8000):
  """Load an audio file and cut it into fixed-length, possibly overlapping windows.

  Args:
    filepath: path of the audio file to read.
    window_size: number of samples in each slice.
    stride: hop between consecutive slices as a fraction of `window_size`
      (0.5 -> 50% overlap, 1 -> no overlap).
    sample_rate: rate the audio is resampled to when loaded.

  Returns:
    list of 1-D numpy arrays, each exactly `window_size` samples long. The dtype is
    whatever librosa.load returns; padding no longer upcasts it. An empty file yields
    an empty list.

  Raises:
    ValueError: if `window_size` or the derived hop (window_size * stride) is not positive.
  """
  hop = int(window_size * stride)
  if window_size <= 0 or hop <= 0:
    # Checked before the (slow) load; a zero hop would otherwise fail inside range().
    raise ValueError(
        f'window_size and window_size * stride must be positive, got {window_size} and {hop}')

  wav, sr = librosa.load(filepath, sr = sample_rate)
  if sr != sample_rate: print(f'Sampling rate error. Current rate is {sr}')

  # Pad up to the next multiple of window_size. An already aligned signal gets no padding
  # (previously it received a whole extra window of zeros).
  pad = -len(wav) % window_size
  wav = np.pad(wav, (0, pad))

  return [wav[end - window_size:end] for end in range(window_size, len(wav) + 1, hop)]

In [None]:
# Slice every clean/noisy training pair and save each aligned window pair as its own .npy
# file, so training can load one small file per sample.

def serialize():
  """Slice the clean/noisy training set into aligned window pairs on disk.

  For each file in `clean_trainset_28spk_wav/`, the same-named file in
  `noisy_trainset_28spk_wav/` is sliced with identical settings (8192-sample windows,
  50% overlap, 8 kHz). Window k of both is saved as a 2-row array `[clean, noisy]` to
  `serialized/<file>_<k>.npy`. The total elapsed time is printed at the end.
  """
  start_time = time.time()
  window_size = 2 ** 13 # 8192 samples
  sample_rate = 8000
  stride = 0.5

  clean_data_path = f'{PATH}/clean_trainset_28spk_wav/'
  noisy_data_path = f'{PATH}/noisy_trainset_28spk_wav/'
  serialized_path = f'{PATH}/serialized'
  # np.save does not create directories; make sure the target exists before the long loop.
  os.makedirs(serialized_path, exist_ok=True)

  for file in os.listdir(clean_data_path):
    clean_sliced = slice_signal(f'{clean_data_path}{file}', window_size, stride, sample_rate)
    noisy_sliced = slice_signal(f'{noisy_data_path}{file}', window_size, stride, sample_rate)

    # zip stops at the shorter list: if the two files differ in length, extra windows are dropped.
    for idx, (clean_slice, noisy_slice) in enumerate(zip(clean_sliced, noisy_sliced)):
      pair = np.array([clean_slice, noisy_slice])
      np.save(os.path.join(serialized_path, '{}_{}'.format(file, idx)), arr=pair)

  end_time = time.time()
  print('Total elapsed time for preprocessing : {}'.format(end_time - start_time))

In [None]:
# One-off preprocessing step. The recorded run (output below) took about 14,700 s (~4 hours).
serialize()

Total elapsed time for preprocessing : 14714.036752700806


In [10]:
# Pre-emphasis filter

def pre_emphasis(batch, coeff=0.95):
    """Apply the first-order pre-emphasis filter y[n] = x[n] - coeff * x[n-1].

    Filtering runs along the last axis, so `batch` may be one signal or an N-D batch
    of signals. `coeff` defaults to 0.95, the value used throughout this notebook.
    """
    return signal.lfilter([1, -coeff], [1], batch)

In [11]:
def de_emphasis(batch, coeff=0.95):
    """Undo pre-emphasis with the recursive filter y[n] = x[n] + coeff * y[n-1].

    Filtering runs along the last axis. `coeff` must equal the coefficient used for
    pre-emphasis (0.95 by default).
    """
    return signal.lfilter([1], [1, -coeff], batch)

# Dataloader

In [None]:
# Dataloader from https://github.com/dansuh17/segan-pytorch

class AudioSampleGenerator(data.Dataset):
  """
  Audio sample reader over a folder of serialized `[clean, noisy]` .npy pairs.
  Used alongside the DataLoader class to generate batches.
  see: http://pytorch.org/docs/master/data.html#torch.utils.data.Dataset
  """
  SAMPLE_LENGTH = 8192

  def __init__(self, data_folder_path: str):
    """
    Args:
        data_folder_path(str): folder holding the serialized .npy files.
    Raises:
        FileNotFoundError: if the folder does not exist.
    """
    if not os.path.exists(data_folder_path):
      raise FileNotFoundError(f'Data folder not found: {data_folder_path}')

    # Store full paths only. The dataset is too large to hold in memory, so files are
    # read lazily in __getitem__(). Sorting makes the index -> file mapping (and the
    # seeded random choices below) independent of the filesystem's listing order.
    self.filepaths = sorted(
        os.path.join(data_folder_path, filename) for filename in os.listdir(data_folder_path))
    self.num_data = len(self.filepaths)

  def reference_batch(self, batch_size: int):
    """
    Randomly selects a reference batch from the dataset (sampled with replacement).
    The reference batch provides the statistics for virtual batch normalization.
    Args:
        batch_size(int): batch size
    Returns:
        ref_batch(torch.Tensor): the selected arrays stacked along a new first axis
    """
    ref_filenames = np.random.choice(self.filepaths, batch_size)
    ref_batch = torch.from_numpy(np.stack([np.load(f) for f in ref_filenames]))
    return ref_batch

  def fixed_test_audio(self, num_test_audio: int):
    """
    Randomly chosen batch (sampled with replacement) for testing generated results.
    Args:
        num_test_audio(int): number of test audio.
            Must be same as batch size of training,
            otherwise it cannot go through the forward step of generator.
    Returns:
        test_basenames: file names of the chosen samples
        test_clean_set: numpy array of shape (num_test_audio, 1, SAMPLE_LENGTH)
        test_noisy_set: numpy array of shape (num_test_audio, 1, SAMPLE_LENGTH)
    """
    test_filenames = np.random.choice(self.filepaths, num_test_audio)
    # stack the data for all test audios; row 0 of each pair is clean, row 1 noisy
    test_audios = np.stack([np.load(f) for f in test_filenames])
    test_clean_set = test_audios[:, 0].reshape((num_test_audio, 1, self.SAMPLE_LENGTH))
    test_noisy_set = test_audios[:, 1].reshape((num_test_audio, 1, self.SAMPLE_LENGTH))
    test_basenames = [os.path.basename(fpath) for fpath in test_filenames]
    return test_basenames, test_clean_set, test_noisy_set

  def __getitem__(self, idx):
    """Load and return the stored array for index `idx`."""
    pair = np.load(self.filepaths[idx])
    return pair

  def __len__(self):
    return self.num_data

# Virtual Batch Normalization

In [3]:
# Virtual Batch Norm from https://github.com/dansuh17/segan-pytorch

class VirtualBatchNorm1d(Module):
    """
    Module for Virtual Batch Normalization over 3-D (batch, num_features, length) tensors.
    Implementation borrowed and modified from Rafael_Valle's code + help of SimonW from this discussion thread:
    https://discuss.pytorch.org/t/parameter-grad-of-conv-weight-is-none-after-virtual-batch-normalization/9036
    """
    def __init__(self, num_features: int, eps: float=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps  # added to the variance before the square root
        # Placeholders kept from the original implementation. They stay None, so they never
        # appear in state_dict; reference statistics are passed to forward() instead.
        self.register_parameter('ref_mean', None)
        self.register_parameter('ref_mean_sq', None)

        # Learnable per-feature scale gamma ~ N(1, 0.02) and shift beta = 0.
        gamma = torch.normal(mean=torch.ones(1, num_features, 1), std=0.02)
        self.gamma = Parameter(gamma.float().to(device))
        # Move the data first, THEN wrap it in Parameter. The former Parameter(...).to(device)
        # returned a plain non-leaf tensor on CUDA. As a result beta was never registered,
        # never optimized, and never saved.
        # NOTE: discriminator checkpoints written before this fix lack the `beta` entries;
        # load them with load_state_dict(..., strict=False).
        self.beta = Parameter(torch.zeros(1, num_features, 1).to(device))

    def get_stats(self, x):
        """
        Calculates mean and mean square for given batch x.
        Args:
            x: tensor of shape (batch, num_features, length)
        Returns:
            mean: mean over batch and length, shape (1, num_features, 1)
            mean_sq: mean of squares over batch and length, shape (1, num_features, 1)
        """
        mean = x.mean(2, keepdim=True).mean(0, keepdim=True)
        mean_sq = (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True)
        return mean, mean_sq

    def forward(self, x, ref_mean=None, ref_mean_sq=None):
        """
        Forward pass of virtual batch normalization.
        Call it first on the reference batch with ref_mean/ref_mean_sq left as None. That
        batch is normalized with its own (detached) statistics, which are returned.
        Then call it on the train batch with those statistics. They are blended with the
        train batch's own statistics, weighting the train batch by 1 / (batch_size + 1).
        Args:
            x: input tensor of shape (batch, num_features, length)
            ref_mean: mean returned by the reference pass (None for the reference pass)
            ref_mean_sq: mean square returned by the reference pass (None for the reference pass)
        Returns:
            out: normalized tensor, same shape as x
            mean, mean_sq: the statistics that were used for normalization
        """
        mean, mean_sq = self.get_stats(x)
        if ref_mean is None or ref_mean_sq is None:
            # reference mode - works like batch norm; statistics carry no gradient
            mean = mean.clone().detach()
            mean_sq = mean_sq.clone().detach()
        else:
            batch_size = x.size(0)
            new_coeff = 1. / (batch_size + 1.)
            old_coeff = 1. - new_coeff
            mean = new_coeff * mean + old_coeff * ref_mean
            mean_sq = new_coeff * mean_sq + old_coeff * ref_mean_sq
        out = self._normalize(x, mean, mean_sq)
        return out, mean, mean_sq

    def _normalize(self, x, mean, mean_sq):
        """
        Normalize tensor x given the statistics, then apply gamma and beta.
        Args:
            x: input tensor of shape (batch, num_features, length)
            mean: mean over features, shape (1, num_features, 1)
            mean_sq: mean of squares over features, shape (1, num_features, 1)
        Returns:
            x: normalized batch tensor
        Raises:
            ValueError: if the statistics do not have num_features channels
        """
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3  # specific for 1d VBN
        if mean.size(1) != self.num_features:
            raise ValueError(
                    'Mean size not equal to number of featuers : given {}, expected {}'
                    .format(mean.size(1), self.num_features))
        if mean_sq.size(1) != self.num_features:
            raise ValueError(
                    'Squared mean tensor size not equal to number of features : given {}, expected {}'
                    .format(mean_sq.size(1), self.num_features))

        # Var[x] = E[x^2] - E[x]^2
        std = torch.sqrt(self.eps + mean_sq - mean**2)
        x = (x - mean) / std
        return x * self.gamma + self.beta

    def __repr__(self):
        return ('{name}(num_features={num_features}, eps={eps})'
                .format(name=self.__class__.__name__, **self.__dict__))

# Discriminator

In [6]:
# Discriminator modified to take an 8KHz audio input

class Discriminator(nn.Module):
    """SEGAN discriminator adapted to 8192-sample windows.

    Input: a (batch, 2, 8192) tensor stacking two signals along the channel axis (the
    training loop stacks clean-or-generated with noisy). Ten strided conv / virtual batch
    norm / LeakyReLU stages reduce it to (batch, 2048, 8). A 1x1 conv and a linear layer
    then give one unbounded score per example.
    """
    def __init__(self, dropout_drop=0.5):
        super().__init__()
        # Conv1d args: (in_channels, out_channels, kernel_size, stride, padding).
        # Each stride-2 conv halves the length. Shape comments are channels x length.
        # in : 2 x 8192
        # The original third stage (conv3/vbn3/lrelu3 + dropout1) was removed for 8192-sample
        # input. The remaining attribute names are kept so existing checkpoints still load.
        negative_slope = 0.03
        self.conv1 = nn.Conv1d(in_channels=2, out_channels=32, kernel_size=31, stride=2, padding=15)   # 32 x 4096
        self.vbn1 = VirtualBatchNorm1d(32)
        self.lrelu1 = nn.LeakyReLU(negative_slope)
        self.conv2 = nn.Conv1d(32, 64, 31, 2, 15)  # 64 x 2048
        self.vbn2 = VirtualBatchNorm1d(64)
        self.lrelu2 = nn.LeakyReLU(negative_slope)
        self.conv4 = nn.Conv1d(64, 128, 31, 2, 15)  # 128 x 1024
        self.vbn4 = VirtualBatchNorm1d(128)
        self.lrelu4 = nn.LeakyReLU(negative_slope)
        self.conv5 = nn.Conv1d(128, 128, 31, 2, 15)  # 128 x 512
        self.vbn5 = VirtualBatchNorm1d(128)
        self.lrelu5 = nn.LeakyReLU(negative_slope)
        self.conv6 = nn.Conv1d(128, 256, 31, 2, 15)  # 256 x 256
        self.dropout2 = nn.Dropout(dropout_drop)
        self.vbn6 = VirtualBatchNorm1d(256)
        self.lrelu6 = nn.LeakyReLU(negative_slope)
        self.conv7 = nn.Conv1d(256, 256, 31, 2, 15)  # 256 x 128
        self.vbn7 = VirtualBatchNorm1d(256)
        self.lrelu7 = nn.LeakyReLU(negative_slope)
        self.conv8 = nn.Conv1d(256, 512, 31, 2, 15)  # 512 x 64
        self.vbn8 = VirtualBatchNorm1d(512)
        self.lrelu8 = nn.LeakyReLU(negative_slope)
        self.conv9 = nn.Conv1d(512, 512, 31, 2, 15)  # 512 x 32
        self.dropout3 = nn.Dropout(dropout_drop)
        self.vbn9 = VirtualBatchNorm1d(512)
        self.lrelu9 = nn.LeakyReLU(negative_slope)
        self.conv10 = nn.Conv1d(512, 1024, 31, 2, 15)  # 1024 x 16
        self.vbn10 = VirtualBatchNorm1d(1024)
        self.lrelu10 = nn.LeakyReLU(negative_slope)
        self.conv11 = nn.Conv1d(1024, 2048, 31, 2, 15)  # 2048 x 8
        self.vbn11 = VirtualBatchNorm1d(2048)
        self.lrelu11 = nn.LeakyReLU(negative_slope)
        # 1x1 size kernel for dimension and parameter reduction
        self.conv_final = nn.Conv1d(2048, 1, kernel_size=1, stride=1)  # 1 x 8
        self.lrelu_final = nn.LeakyReLU(negative_slope)
        self.fully_connected = nn.Linear(in_features=8, out_features=1)  # 1
        self.sigmoid = nn.Sigmoid()  # unused: forward returns the raw linear output

        # initialize weights
        self.init_weights()

    def init_weights(self):
        """
        Initialize weights for convolution layers using Xavier initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.xavier_normal_(m.weight.data)

    def _stages(self):
        """Return the (conv, dropout or None, vbn, activation) stages in forward order."""
        return (
            (self.conv1, None, self.vbn1, self.lrelu1),
            (self.conv2, None, self.vbn2, self.lrelu2),
            (self.conv4, None, self.vbn4, self.lrelu4),
            (self.conv5, None, self.vbn5, self.lrelu5),
            (self.conv6, self.dropout2, self.vbn6, self.lrelu6),
            (self.conv7, None, self.vbn7, self.lrelu7),
            (self.conv8, None, self.vbn8, self.lrelu8),
            (self.conv9, self.dropout3, self.vbn9, self.lrelu9),
            (self.conv10, None, self.vbn10, self.lrelu10),
            (self.conv11, None, self.vbn11, self.lrelu11),
        )

    def forward(self, x, ref_x):
        """
        Forward pass of discriminator.
        Args:
            x: batch to score
            ref_x: reference batch for virtual batch norm
        Returns:
            tensor of shape (batch, 1) with one score per example
        """
        stages = self._stages()

        # Reference pass: record the VBN statistics of every stage on ref_x.
        ref_stats = []
        for conv, dropout, vbn, act in stages:
            ref_x = conv(ref_x)
            if dropout is not None:
                ref_x = dropout(ref_x)
            ref_x, mean, mean_sq = vbn(ref_x, None, None)
            ref_stats.append((mean, mean_sq))
            ref_x = act(ref_x)

        # Train pass: normalize each stage using the matching reference statistics.
        for (conv, dropout, vbn, act), (mean, mean_sq) in zip(stages, ref_stats):
            x = conv(x)
            if dropout is not None:
                x = dropout(x)
            x, _, _ = vbn(x, mean, mean_sq)
            x = act(x)

        x = self.lrelu_final(self.conv_final(x))  # (batch, 1, 8)
        # Flatten per example. torch.squeeze would also drop the batch axis when batch == 1.
        x = torch.flatten(x, 1)  # (batch, 8)
        # Raw score (sigmoid disabled); the training loop applies squared-error losses to it.
        return self.fully_connected(x)

# Generator

In [7]:
# Generator modified to take an 8KHz audio input

class Generator(nn.Module):
    """SEGAN-style encoder/decoder generator for 8192-sample windows.

    Takes a noisy signal x of shape (batch, 1, 8192) and a latent z of shape
    (batch, 1024, 8). Returns a tanh-bounded signal of shape (batch, 1, 8192).
    Each decoder output is concatenated with the pre-activation output of the encoder
    layer with the same index before its PReLU.
    """

    # (layer index, in_channels, out_channels), input -> bottleneck. Index 2 is absent:
    # that stage was removed to fit 8192-sample inputs (lengths 4096 -> 8).
    _ENCODER = ((1, 1, 16), (3, 16, 32), (4, 32, 64), (5, 64, 64), (6, 64, 128),
                (7, 128, 128), (8, 128, 256), (9, 256, 256), (10, 256, 512), (11, 512, 1024))
    # (layer index, in_channels, out_channels), bottleneck -> output. The input channels are
    # doubled by the skip concatenation (dec10 also sees the 1024-channel latent).
    _DECODER = ((10, 2048, 512), (9, 1024, 256), (8, 512, 256), (7, 512, 128), (6, 256, 128),
                (5, 256, 64), (4, 128, 64), (3, 128, 32), (1, 64, 16))

    def __init__(self):
        super().__init__()
        # All layers use kernel 32, stride 2, padding 15: encoders halve the length,
        # decoders double it. Layers are registered in the same order as always
        # (enc1, enc1_nl, enc3, ... dec1_nl, dec_final), so state_dict keys match checkpoints.
        for idx, c_in, c_out in self._ENCODER:
            setattr(self, f'enc{idx}', nn.Conv1d(c_in, c_out, 32, 2, 15))
            setattr(self, f'enc{idx}_nl', nn.PReLU())
        for idx, c_in, c_out in self._DECODER:
            setattr(self, f'dec{idx}', nn.ConvTranspose1d(c_in, c_out, 32, 2, 15))
            setattr(self, f'dec{idx}_nl', nn.PReLU())
        self.dec_final = nn.ConvTranspose1d(32, 1, 32, 2, 15)  # [B x 1 x 8192]
        self.dec_tanh = nn.Tanh()

        self.init_weights()

    def init_weights(self):
        """
        Re-initialize every (transposed) convolution weight with Xavier-normal values.
        """
        conv_types = (nn.Conv1d, nn.ConvTranspose1d)
        for module in self.modules():
            if isinstance(module, conv_types):
                nn.init.xavier_normal_(module.weight.data)

    def forward(self, x, z):
        """
        Forward pass of generator.
        Args:
            x: input batch (signal)
            z: latent vector, concatenated to the bottleneck along channels
        """
        # Encoder: remember each layer's output *before* its PReLU for the skip connections.
        skips = {}
        h = x
        for idx, _, _ in self._ENCODER:
            h = getattr(self, f'enc{idx}')(h)
            skips[idx] = h
            h = getattr(self, f'enc{idx}_nl')(h)

        # h is now the activated bottleneck (the 'thought vector'); append the latent code.
        h = torch.cat((h, z), dim=1)

        # Decoder: upsample, concatenate the matching skip, then apply the PReLU.
        for idx, _, _ in self._DECODER:
            h = getattr(self, f'dec{idx}')(h)
            h = getattr(self, f'dec{idx}_nl')(torch.cat((h, skips[idx]), dim=1))

        return self.dec_tanh(self.dec_final(h))

In [None]:
def split_pair_to_vars(sample_batch_pair):
    """
    Pre-emphasize a batch of [clean_signal, noisy_signal] pairs and split it into tensors.

    Args:
        sample_batch_pair(torch.Tensor): CPU tensor of pairs, indexed as
            [example, 0 = clean / 1 = noisy, sample] (as loaded from the serialized .npy files)
    Returns:
        batch_pairs_var(torch.Tensor): float32 tensor on `device` holding both signals
        clean_batch_var(torch.Tensor): float32 (batch, 1, length) tensor of clean signals
        noisy_batch_var(torch.Tensor): float32 (batch, 1, length) tensor of noisy signals
    """
    # Pre-emphasis filters along the last (time) axis.
    emphasized = pre_emphasis(sample_batch_pair.numpy())

    batch_pairs_var = torch.from_numpy(emphasized).type(torch.FloatTensor).to(device)
    # Slicing with 0:1 / 1:2 keeps the channel axis, giving (batch, 1, length) directly;
    # this replaces the former per-pair Python loop and np.stack.
    clean_batch_var = batch_pairs_var[:, 0:1, :].contiguous()
    noisy_batch_var = batch_pairs_var[:, 1:2, :].contiguous()
    return batch_pairs_var, clean_batch_var, noisy_batch_var


def sample_latent(num_samples=None):
    """
    Sample a latent tensor from a standard normal distribution.
    Args:
        num_samples(int, optional): size of the batch dimension; defaults to the
            global `batch_size` used for training
    Returns:
        z(torch.Tensor): tensor of shape (num_samples, 1024, 8) on `device`, matching the
            generator's 1024-channel, 8-step bottleneck
    """
    if num_samples is None:
        num_samples = batch_size
    return torch.randn((num_samples, 1024, 8)).to(device)

# Training

In [9]:
# --- optimisation ---
# Training started at batch size 256; Colab failed and from epoch 7 on it ran at 128.
batch_size = 128
d_learning_rate = g_learning_rate = 1e-4  # same Adam step size for D and G
g_lambda = 100  # weight of the generator's L1 (conditional) loss term

# --- data / schedule ---
sample_rate = 8000
num_epochs = 86  # only 13 epochs finished before Colab revoked GPU access

In [None]:
# Lazily reads the serialized pair files. Reshuffled every epoch; the incomplete last
# batch is dropped so every batch has exactly `batch_size` examples.
sample_generator = AudioSampleGenerator(f'{PATH}serialized/')
random_data_loader = DataLoader(sample_generator, batch_size=batch_size, shuffle=True,
                                drop_last=True, pin_memory=False)
print('DataLoader created')

DataLoader created


In [None]:
# No cell in this notebook instantiates the models, so the optimizer lines below raised
# NameError on a fresh run. They are wrapped in DataParallel, matching test(), which loads
# the checkpoint's generator state into nn.DataParallel(Generator()).
discriminator = nn.DataParallel(Discriminator()).to(device)
generator = nn.DataParallel(Generator()).to(device)
# NOTE(review): the training loop resumes numbering at epoch 13; to truly continue that
# run, load checkpoints/state-13.pkl into both models and optimizers after creating them.

# Fixed reference batch whose statistics drive the discriminator's virtual batch norm.
ref_batch_pairs = sample_generator.reference_batch(batch_size)
ref_batch_var, ref_clean_var, ref_noisy_var = split_pair_to_vars(ref_batch_pairs)

# optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=g_learning_rate, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=d_learning_rate, betas=(0.5, 0.999))

In [None]:
print('Starting Training...')
# Epochs 1-13 ran in an earlier session; keep the epoch numbering (and checkpoint names) going.
start_epoch = 13
checkpoint_dir = f'{PATH}/checkpoints'
# torch.save does not create directories. Create it now rather than crash after the
# first (multi-hour) epoch has been trained.
os.makedirs(checkpoint_dir, exist_ok=True)
total_steps = 1
for epoch in range(start_epoch, num_epochs):
    for i, sample_batch_pairs in enumerate(random_data_loader):
        # using the sample batch pair, split into
        # batch of combined pairs, clean signals, and noisy signals
        batch_pairs_var, clean_batch_var, noisy_batch_var = split_pair_to_vars(sample_batch_pairs)

        # latent vector - normal distribution
        z = sample_latent()

        ##### TRAIN D #####
        # D should score real (clean, noisy) pairs as 1
        outputs = discriminator(batch_pairs_var, ref_batch_var)  # out: [n_batch x 1]
        clean_loss = torch.mean((outputs - 1.0) ** 2)  # L2 loss - we want them all to be 1

        # D should score (generated, noisy) pairs as 0; detach so this step does not update G
        generated_outputs = generator(noisy_batch_var, z)
        disc_in_pair = torch.cat((generated_outputs.detach(), noisy_batch_var), dim=1)
        outputs = discriminator(disc_in_pair, ref_batch_var)
        noisy_loss = torch.mean(outputs ** 2)  # L2 loss - we want them all to be 0
        d_loss = 0.5 * (clean_loss + noisy_loss)

        # back-propagate and update
        discriminator.zero_grad()
        d_loss.backward()
        d_optimizer.step()  # update parameters

        ##### TRAIN G #####
        # TRAIN G so that D recognizes G(z) as real
        z = sample_latent()
        generated_outputs = generator(noisy_batch_var, z)
        gen_noise_pair = torch.cat((generated_outputs, noisy_batch_var), dim=1)
        outputs = discriminator(gen_noise_pair, ref_batch_var)

        g_loss_ = 0.5 * torch.mean((outputs - 1.0) ** 2)
        # L1 loss between generated output and clean sample
        l1_dist = torch.abs(generated_outputs - clean_batch_var)
        g_cond_loss = g_lambda * torch.mean(l1_dist)  # conditional loss
        g_loss = g_loss_ + g_cond_loss

        # back-propagate and update
        generator.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # print losses every 20 steps
        if (i + 1) % 20 == 0:
            print(
                'Epoch {}\t'
                'Step {}\t'
                'd_loss {:.5f}\t'
                'd_clean_loss {:.5f}\t'
                'd_noisy_loss {:.5f}\t'
                'g_loss {:.5f}\t'
                'g_loss_cond {:.5f}'
                .format(epoch + 1, i + 1, d_loss.item(), clean_loss.item(),
                        noisy_loss.item(), g_loss.item(), g_cond_loss.item()))

        total_steps += 1

    # save model and optimizer states at the end of every epoch
    state_path = os.path.join(checkpoint_dir, 'state-{}.pkl'.format(epoch + 1))
    state = {
        'discriminator': discriminator.state_dict(),
        'generator': generator.state_dict(),
        'g_optimizer': g_optimizer.state_dict(),
        'd_optimizer': d_optimizer.state_dict(),
    }
    torch.save(state, state_path)

    ### Can be loaded using, for example:
    # state = torch.load(state_path)
    # discriminator.load_state_dict(state['discriminator'])

print('Finished Training!')

Starting Training...
Epoch 14	Step 20	d_loss 0.28101	d_clean_loss 0.33827	d_noisy_loss 0.22375	g_loss 1.05236	g_loss_cond 0.88966
Epoch 14	Step 40	d_loss 0.26211	d_clean_loss 0.23245	d_noisy_loss 0.29176	g_loss 1.03236	g_loss_cond 0.91211
Epoch 14	Step 60	d_loss 0.25594	d_clean_loss 0.23028	d_noisy_loss 0.28161	g_loss 1.02954	g_loss_cond 0.92669
Epoch 14	Step 80	d_loss 0.24754	d_clean_loss 0.25868	d_noisy_loss 0.23641	g_loss 1.07919	g_loss_cond 0.92367
Epoch 14	Step 100	d_loss 0.25702	d_clean_loss 0.24930	d_noisy_loss 0.26473	g_loss 0.97049	g_loss_cond 0.84467
Epoch 14	Step 120	d_loss 0.24907	d_clean_loss 0.22422	d_noisy_loss 0.27393	g_loss 1.01501	g_loss_cond 0.90509
Epoch 14	Step 140	d_loss 0.24986	d_clean_loss 0.24379	d_noisy_loss 0.25593	g_loss 1.03557	g_loss_cond 0.91380


# Inference

In [26]:
# Inference function

def test(filename, checkpoint_path=None, sample_rate=8000):
  """Denoise one audio file with a trained generator and write the enhanced audio.

  Processing steps:
  - Resample the input to `sample_rate`.
  - Cut it into non-overlapping 8192-sample windows; the last window is zero-padded.
  - Pre-emphasize each window and run it through the generator with a fresh N(0, 1) latent.
  - De-emphasize the outputs and concatenate them.

  The result is written as a 32-bit float WAV to `{PATH}/output/enhanced_<input name>`.
  The output keeps the zero-padded tail, so it can be longer than the input.

  Args:
    filename: path of the noisy input audio file.
    checkpoint_path: training checkpoint to load; defaults to
      `{PATH}/checkpoints/state-13.pkl`.
    sample_rate: rate used to load the input and to write the output.
  """
  window_size = 2 ** 13
  if checkpoint_path is None:
    checkpoint_path = f'{PATH}/checkpoints/state-13.pkl'

  # Wrapped in DataParallel (as before) so parameter names match the saved generator state.
  generator = nn.DataParallel(Generator())
  state = torch.load(checkpoint_path, map_location=device)
  generator.load_state_dict(state['generator'])
  generator.to(device)
  generator.eval()

  noisy_slices = slice_signal(filename, window_size, 1, sample_rate)
  enhanced_speech = []
  with torch.no_grad():  # inference only: don't build an autograd graph
    for noisy_slice in noisy_slices:
      noisy_slice = noisy_slice.reshape(1, 1, window_size)
      # Tensor.to() is not in-place; the moved tensor must be assigned.
      noisy_tensor = torch.from_numpy(pre_emphasis(noisy_slice)).type(torch.FloatTensor).to(device)
      z = torch.randn(1, 1024, 8).to(device)
      generated_speech = generator(noisy_tensor, z).cpu().numpy()
      enhanced_speech.append(de_emphasis(generated_speech).reshape(-1))

  enhanced = np.concatenate(enhanced_speech) if enhanced_speech else np.zeros(0)

  output_dir = f'{PATH}/output'
  os.makedirs(output_dir, exist_ok=True)
  output_path = f'{output_dir}/enhanced_{os.path.basename(filename)}'
  # librosa.output.write_wav was removed in librosa 0.8; write the float WAV with scipy directly.
  wavfile.write(output_path, sample_rate, enhanced.astype(np.float32))

In [27]:
# Enhance a single noisy training utterance as a smoke test of the inference path.
test(filename=f'{PATH}/noisy_trainset_28spk_wav/p226_007.wav')

  # This is added back by InteractiveShellApp.init_path()


In [1]:
!pip list

Package                       Version        
----------------------------- ---------------
absl-py                       0.9.0          
alabaster                     0.7.12         
albumentations                0.1.12         
altair                        4.1.0          
argon2-cffi                   20.1.0         
asgiref                       3.2.10         
astor                         0.8.1          
astropy                       4.0.1.post1    
astunparse                    1.6.3          
atari-py                      0.2.6          
atomicwrites                  1.4.0          
attrs                         19.3.0         
audioread                     2.1.8          
autograd                      1.3            
Babel                         2.8.0          
backcall                      0.2.0          
beautifulsoup4                4.6.3          
bleach                        3.1.5          
blis                          0.4.1          
bokeh                         2.1.

In [None]:
torch 1.6.0+cu101 
numpy 1.18.5
scipy 1.4.1
SoundFile 0.10.3.post1 
librosa 0.6.3