In [6]:
import os

import librosa
from torch.utils import data


class Dataset(data.Dataset):
    """Inference dataset: yields each noisy waveform together with its file stem."""

    def __init__(self,
                 noisy_dataset,
                 limit,
                 offset,
                 sr,
                 ):
        """
        Args:
            noisy_dataset (str): noisy dir (wav format files) or noisy filenames list
                (a text file holding one wav path per line)
            limit (int or None): maximum number of files kept after ``offset`` is applied;
                ``None`` keeps everything
            offset (int or None): number of leading files to skip
            sr (int): sample rate handed to ``librosa.load`` when a file is read

        Raises:
            FileNotFoundError: ``noisy_dataset`` is neither an existing file nor a directory
        """
        super().__init__()
        noisy_dataset = os.path.abspath(os.path.expanduser(noisy_dataset))

        if os.path.isfile(noisy_dataset):
            # The context manager closes the list file right away instead of leaving the
            # handle to the garbage collector.
            with open(noisy_dataset, "r") as list_file:
                # Blank lines would otherwise become empty paths that only fail later,
                # inside librosa.load, so they are dropped here.
                noisy_wav_files = [line.rstrip('\n') for line in list_file if line.strip()]
            if offset:
                noisy_wav_files = noisy_wav_files[offset:]
            if limit:
                noisy_wav_files = noisy_wav_files[:limit]
        elif os.path.isdir(noisy_dataset):
            noisy_wav_files = librosa.util.find_files(noisy_dataset, ext="wav", limit=limit, offset=offset)
        else:
            raise FileNotFoundError(f"Please Check {noisy_dataset}")

        print(f"Number of noisy files in the dir {noisy_dataset}: {len(noisy_wav_files)}")

        self.length = len(noisy_wav_files)
        self.noisy_wav_files = noisy_wav_files
        self.sr = sr

    def __len__(self):
        """Return the number of noisy files selected in ``__init__``."""
        return self.length

    def __getitem__(self, item):
        """
        Load one noisy file.

        Returns:
            (noisy, name): the audio returned by ``librosa.load`` at ``self.sr`` and the
            file name without directory or extension.
        """
        noisy_path = self.noisy_wav_files[item]
        name = os.path.splitext(os.path.basename(noisy_path))[0]
        noisy, _ = librosa.load(noisy_path, sr=self.sr)
        return noisy, name

In [7]:
import os

import librosa
from torch.utils import data


class Dataset(data.Dataset):
    """Paired noisy/clean dataset driven by a text file of ``<noisy> <clean>`` path pairs."""

    def __init__(
            self,
            dataset_list,
            limit,
            offset,
            sr,
            n_fft,
            hop_length,
            train
    ):
        """
        dataset_list(*.txt):
            <noisy_path> <clean_path>\n
        e.g:
            noisy_1.wav clean_1.wav
            noisy_2.wav clean_2.wav
            ...
            noisy_n.wav clean_n.wav

        Args:
            dataset_list (str): path of the pair list described above
            limit (int or None): maximum number of pairs kept after ``offset``; falsy keeps all
            offset (int or None): number of leading pairs to skip
            sr (int): sample rate handed to ``librosa.load``
            n_fft (int): FFT size, also used as the STFT window length
            hop_length (int): STFT hop length
            train (bool): when true ``__getitem__`` returns magnitude spectrograms,
                otherwise the raw waveforms
        """
        super(Dataset, self).__init__()
        self.sr = sr
        self.train = train

        list_path = os.path.abspath(os.path.expanduser(dataset_list))
        # The context manager closes the list file instead of leaking the handle.
        with open(list_path, "r") as list_file:
            # Blank lines cannot be parsed as a pair later on, so they are skipped here.
            dataset_list = [line.rstrip('\n') for line in list_file if line.strip()]
        dataset_list = dataset_list[offset:]
        if limit:
            dataset_list = dataset_list[:limit]

        self.dataset_list = dataset_list
        self.length = len(self.dataset_list)
        self.n_fft = n_fft
        self.hop_length = hop_length

    def __len__(self):
        """Return the number of pairs selected in ``__init__``."""
        return self.length

    def __getitem__(self, item):
        """
        Load one noisy/clean pair.

        Returns:
            train mode: ``(noisy_mag, clean_mag, n_frames, name)`` where ``n_frames`` is the
            last dimension of ``noisy_mag``.
            otherwise: ``(noisy, clean, name)`` with the loaded waveforms.

        Raises:
            ValueError: the list entry does not consist of exactly two paths.
        """
        # split() without an argument tolerates tabs, repeated spaces and trailing
        # whitespace, all of which made split(" ") yield the wrong number of fields.
        fields = self.dataset_list[item].split()
        if len(fields) != 2:
            raise ValueError(
                f"Expected '<noisy_path> <clean_path>' but got: {self.dataset_list[item]!r}"
            )
        noisy_path, clean_path = fields
        name = os.path.splitext(os.path.basename(noisy_path))[0]
        noisy, _ = librosa.load(os.path.abspath(os.path.expanduser(noisy_path)), sr=self.sr)
        clean, _ = librosa.load(os.path.abspath(os.path.expanduser(clean_path)), sr=self.sr)

        if self.train:
            noisy_mag, _ = librosa.magphase(librosa.stft(noisy, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.n_fft))
            clean_mag, _ = librosa.magphase(librosa.stft(clean, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.n_fft))
            return noisy_mag, clean_mag, noisy_mag.shape[-1], name
        else:
            return noisy, clean, name

In [8]:
import os

import librosa
import numpy as np
import soundfile as sf
from torch.utils.data import Dataset


class WavDataset(Dataset):
    """
    Define train dataset
    """

    # Values that were previously inlined in __getitem__. They are class attributes so
    # the constructor signature stays untouched.
    SAMPLE_RATE = 16000  # sample rate every file must have
    WIN_LENGTH = 320     # window size, in samples, of the frame-count formula
    HOP_LENGTH = 160     # hop size, in samples, of the frame-count formula

    def __init__(self,
                 mixture_dataset,
                 clean_dataset,
                 limit=None,
                 offset=0,
                 ):
        """
        Construct train dataset
        Args:
            mixture_dataset (str): mixture dir (wav format files)
            clean_dataset (str): clean dir (wav format files)
            limit (int): the limit of the dataset
            offset (int): the offset of the dataset
        """
        # Expand "~" and relative paths before the existence check, otherwise a path such
        # as "~/data" is rejected even though it is valid.
        mixture_dataset = os.path.abspath(os.path.expanduser(mixture_dataset))
        clean_dataset = os.path.abspath(os.path.expanduser(clean_dataset))
        assert os.path.exists(mixture_dataset) and os.path.exists(clean_dataset), \
            f"Dataset dir not found: {mixture_dataset} or {clean_dataset}"

        print("Search datasets...")
        mixture_wav_files = librosa.util.find_files(mixture_dataset, ext="wav", limit=limit, offset=offset)
        clean_wav_files = librosa.util.find_files(clean_dataset, ext="wav", limit=limit, offset=offset)

        # NOTE(review): mixture and clean files are paired purely by list position;
        # presumably both dirs hold identically named files - confirm.
        assert len(mixture_wav_files) == len(clean_wav_files), \
            f"File count mismatch: {len(mixture_wav_files)} mixture vs {len(clean_wav_files)} clean"
        print(f"\t Original length: {len(mixture_wav_files)}")

        self.length = len(mixture_wav_files)
        self.mixture_wav_files = mixture_wav_files
        self.clean_wav_files = clean_wav_files

        print(f"\t Offset: {offset}")
        print(f"\t Limit: {limit}")
        print(f"\t Final length: {self.length}")

    def __len__(self):
        """Return the number of mixture/clean pairs."""
        return self.length

    def __getitem__(self, item):
        """
        Read one mixture/clean pair.

        Returns:
            (mixture, clean, n_frames, name): both signals as read by ``sf.read`` with
            ``dtype="float32"``, the frame count for WIN_LENGTH/HOP_LENGTH, and the clean
            file name without directory or extension.

        Raises:
            AssertionError: either file is not at SAMPLE_RATE, or the shapes differ.
        """
        mixture_path = self.mixture_wav_files[item]
        clean_path = self.clean_wav_files[item]
        name = os.path.splitext(os.path.basename(clean_path))[0]

        # Keep the two sample rates in separate variables: reusing one name meant only
        # the clean file's rate was ever validated.
        mixture, mixture_sr = sf.read(mixture_path, dtype="float32")
        clean, clean_sr = sf.read(clean_path, dtype="float32")
        assert mixture_sr == self.SAMPLE_RATE and clean_sr == self.SAMPLE_RATE, \
            f"Expected {self.SAMPLE_RATE} Hz, got mixture={mixture_sr} Hz, clean={clean_sr} Hz ({name})"
        assert mixture.shape == clean.shape, \
            f"Shape mismatch for {name}: mixture {mixture.shape} vs clean {clean.shape}"

        n_frames = (len(mixture) - self.WIN_LENGTH) // self.HOP_LENGTH + 1

        return mixture, clean, n_frames, name

In [9]:
import os

import librosa
import numpy as np
import soundfile as sf
from torch.utils.data import Dataset


class WavDataset(Dataset):
    """
    Define train dataset
    """

    # Values that were previously inlined in __getitem__. They are class attributes so
    # the constructor signature stays untouched.
    SAMPLE_RATE = 16000  # sample rate every file must have
    WIN_LENGTH = 320     # window size, in samples, of the frame-count formula
    HOP_LENGTH = 160     # hop size, in samples, of the frame-count formula

    def __init__(self,
                 mixture_dataset,
                 clean_dataset,
                 limit=None,
                 offset=0,
                 ):
        """
        Construct train dataset
        Args:
            mixture_dataset (str): mixture dir (wav format files)
            clean_dataset (str): clean dir (wav format files)
            limit (int): the limit of the dataset
            offset (int): the offset of the dataset
        """
        mixture_dataset = os.path.abspath(os.path.expanduser(mixture_dataset))
        clean_dataset = os.path.abspath(os.path.expanduser(clean_dataset))
        assert os.path.exists(mixture_dataset) and os.path.exists(clean_dataset), \
            f"Dataset dir not found: {mixture_dataset} or {clean_dataset}"

        print("Search datasets...")
        mixture_wav_files = librosa.util.find_files(mixture_dataset, ext="wav", limit=limit, offset=offset)
        clean_wav_files = librosa.util.find_files(clean_dataset, ext="wav", limit=limit, offset=offset)

        # NOTE(review): mixture and clean files are paired purely by list position;
        # presumably both dirs hold identically named files - confirm.
        assert len(mixture_wav_files) == len(clean_wav_files), \
            f"File count mismatch: {len(mixture_wav_files)} mixture vs {len(clean_wav_files)} clean"
        print(f"\t Original length: {len(mixture_wav_files)}")

        self.length = len(mixture_wav_files)
        self.mixture_wav_files = mixture_wav_files
        self.clean_wav_files = clean_wav_files

        print(f"\t Offset: {offset}")
        print(f"\t Limit: {limit}")
        print(f"\t Final length: {self.length}")

    def __len__(self):
        """Return the number of mixture/clean pairs."""
        return self.length

    def __getitem__(self, item):
        """
        Read one mixture/clean pair.

        Returns:
            (mixture, clean, n_frames, name): both signals as read by ``sf.read`` with
            ``dtype="float32"``, the frame count for WIN_LENGTH/HOP_LENGTH, and the clean
            file name without directory or extension.

        Raises:
            AssertionError: either file is not at SAMPLE_RATE, or the shapes differ.
        """
        mixture_path = self.mixture_wav_files[item]
        clean_path = self.clean_wav_files[item]
        name = os.path.splitext(os.path.basename(clean_path))[0]

        # Keep the two sample rates in separate variables: reusing one name meant only
        # the clean file's rate was ever validated.
        mixture, mixture_sr = sf.read(mixture_path, dtype="float32")
        clean, clean_sr = sf.read(clean_path, dtype="float32")
        assert mixture_sr == self.SAMPLE_RATE and clean_sr == self.SAMPLE_RATE, \
            f"Expected {self.SAMPLE_RATE} Hz, got mixture={mixture_sr} Hz, clean={clean_sr} Hz ({name})"
        assert mixture.shape == clean.shape, \
            f"Shape mismatch for {name}: mixture {mixture.shape} vs clean {clean.shape}"

        n_frames = (len(mixture) - self.WIN_LENGTH) // self.HOP_LENGTH + 1

        return mixture, clean, n_frames, name

In [10]:
import torch
import torch.nn as nn


class CausalConvBlock(nn.Module):
    """Encoder unit: 2-D convolution -> batch norm -> ELU, causal along the last axis."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The last axis is padded by one on both sides; forward() trims the extra trailing
        # column again, so no output column depends on a later input column.
        conv_options = dict(kernel_size=(3, 2), stride=(2, 1), padding=(0, 1))
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ELU()

    def forward(self, x):
        """
        Apply the causal convolution block.

        Args:
            x: [B, in_channels, F, T]
        Returns:
            [B, out_channels, F', T], where F' is F reduced by the 3-wide, stride-2 kernel
        """
        convolved = self.conv(x)
        trimmed = convolved[:, :, :, :-1]  # drop the column produced by the trailing pad
        return self.activation(self.norm(trimmed))


class CausalTransConvBlock(nn.Module):
    """Decoder unit: 2-D transposed convolution -> batch norm -> ELU (ReLU when last)."""

    def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=(3, 2),
            stride=(2, 1),
            output_padding=output_padding,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        # The final decoder block ends in ReLU, every other block in ELU.
        self.activation = nn.ReLU() if is_last else nn.ELU()

    def forward(self, x):
        """
        Apply the causal transposed-convolution block.

        Args:
            x: [B, in_channels, F, T]
        Returns:
            [B, out_channels, F', T], where F' is F enlarged by the 3-wide, stride-2 kernel
            (plus ``output_padding``)
        """
        upsampled = self.conv(x)
        trimmed = upsampled[:, :, :, :-1]  # the 2-wide kernel adds one column; drop it
        return self.activation(self.norm(trimmed))


class CRN(nn.Module):
    """
    Convolutional recurrent network: five causal conv blocks, a 2-layer LSTM, and five
    causal transposed-conv blocks with skip connections from the encoder.

    Input: [batch size, channels=1, F, T]
    Output: [batch size, channels=1, F, T]

    The LSTM width of 1024 equals 256 channels x 4 frequency bins, which is what the
    encoder produces for F=161 (161 -> 80 -> 39 -> 19 -> 9 -> 4).
    """

    def __init__(self):
        super(CRN, self).__init__()
        # Encoder: the channel count doubles at every block.
        self.conv_block_1 = CausalConvBlock(1, 16)
        self.conv_block_2 = CausalConvBlock(16, 32)
        self.conv_block_3 = CausalConvBlock(32, 64)
        self.conv_block_4 = CausalConvBlock(64, 128)
        self.conv_block_5 = CausalConvBlock(128, 256)

        # Recurrent bottleneck over the time axis.
        self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True)

        # Decoder: every block takes its input concatenated with the matching encoder
        # output, hence the doubled input channel counts.
        self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128)
        self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64)
        self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32)
        # output_padding restores the even bin count (39 -> 80) needed to match e_1.
        self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0))
        self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True)

    def forward(self, x):
        """Run encoder, LSTM and decoder; shape comments refer to a [2, 1, 161, 200] input."""
        self.lstm_layer.flatten_parameters()

        enc_1 = self.conv_block_1(x)
        enc_2 = self.conv_block_2(enc_1)
        enc_3 = self.conv_block_3(enc_2)
        enc_4 = self.conv_block_4(enc_3)
        enc_5 = self.conv_block_5(enc_4)  # [2, 256, 4, 200]

        n_batch, n_ch, n_bins, n_frames = enc_5.shape

        # Merge channels and bins into one feature axis and put time second:
        # [2, 256, 4, 200] -> [2, 1024, 200] -> [2, 200, 1024]
        sequence = enc_5.reshape(n_batch, n_ch * n_bins, n_frames).permute(0, 2, 1)
        sequence, _ = self.lstm_layer(sequence)  # [2, 200, 1024]
        # Undo the layout change: back to [2, 256, 4, 200]
        decoded = sequence.permute(0, 2, 1).reshape(n_batch, n_ch, n_bins, n_frames)

        # Walk the decoder from deepest to shallowest, pairing every block with the
        # encoder output of the same resolution.
        decoder_stages = (
            (self.tran_conv_block_1, enc_5),
            (self.tran_conv_block_2, enc_4),
            (self.tran_conv_block_3, enc_3),
            (self.tran_conv_block_4, enc_2),
            (self.tran_conv_block_5, enc_1),
        )
        for stage, skip in decoder_stages:
            decoded = stage(torch.cat((decoded, skip), 1))

        return decoded


if __name__ == '__main__':
    # Smoke test: push a random batch of shape [2, 1, 161, 200] through a freshly
    # initialised CRN and print the output shape (the recorded output shows the same
    # shape as the input).
    layer = CRN()
    a = torch.rand(2, 1, 161, 200)
    print(layer(a).shape)

torch.Size([2, 1, 161, 200])


In [12]:
from torch.utils.tensorboard import SummaryWriter


def writer(logs_dir):
    """
    Build a TensorBoard ``SummaryWriter`` for ``logs_dir``.

    Args:
        logs_dir: value forwarded as the writer's ``log_dir``
    Returns:
        SummaryWriter created with ``max_queue=5`` and ``flush_secs=30``
    """
    writer_options = dict(log_dir=logs_dir, max_queue=5, flush_secs=30)
    return SummaryWriter(**writer_options)

ImportError: TensorBoard logging requires TensorBoard version 1.15 or above

In [13]:
import tensorflow as tf

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [16]:
tf.__version__

import tensorflow as tf
# `tf.VERSION` is a TensorFlow 1.x alias that is gone from the top-level namespace in 2.x;
# `tf.__version__` reports the same string and works on both major versions.
print(tf.__version__)


1.14.0
