In [1]:
import os

import librosa
from torch.utils import data


class Dataset(data.Dataset):
    """Inference-time dataset of noisy recordings.

    Indexing returns ``(noisy, name)``: the waveform loaded by ``librosa.load``
    at the configured sample rate, and the file's base name without extension.
    """

    def __init__(self,
                 noisy_dataset,
                 limit,
                 offset,
                 sr,
                 ):
        """
        Args:
            noisy_dataset (str): noisy dir (wav format files) or a text file
                listing one noisy wav path per line
            limit (int or None): keep at most this many files; a falsy value keeps all
            offset (int or None): skip this many leading files; a falsy value skips none
            sr (int): sample rate handed to ``librosa.load`` for every file

        Raises:
            FileNotFoundError: ``noisy_dataset`` is neither a file nor a directory.
        """
        noisy_dataset = os.path.abspath(os.path.expanduser(noisy_dataset))

        if os.path.isfile(noisy_dataset):
            # The path is already absolute and user-expanded. The context manager
            # closes the list file instead of leaving that to the garbage collector.
            with open(noisy_dataset, "r") as list_file:
                listed_paths = [line.rstrip('\n') for line in list_file]
            # Blank lines (e.g. an empty last line) are not paths; keeping them
            # would only make __getitem__ fail later.
            noisy_wav_files = [path for path in listed_paths if path.strip()]
            if offset:
                noisy_wav_files = noisy_wav_files[offset:]
            if limit:
                noisy_wav_files = noisy_wav_files[:limit]
        elif os.path.isdir(noisy_dataset):
            noisy_wav_files = librosa.util.find_files(noisy_dataset, ext="wav", limit=limit, offset=offset)
        else:
            raise FileNotFoundError(f"Please Check {noisy_dataset}")

        print(f"Number of noisy files in the dir {noisy_dataset}: {len(noisy_wav_files)}")

        self.length = len(noisy_wav_files)
        self.noisy_wav_files = noisy_wav_files
        self.sr = sr

    def __len__(self):
        """Return the number of noisy files kept after offset/limit."""
        return self.length

    def __getitem__(self, item):
        """Load the ``item``-th noisy file and return ``(waveform, name)``."""
        noisy_path = self.noisy_wav_files[item]
        name = os.path.splitext(os.path.basename(noisy_path))[0]
        noisy, _ = librosa.load(noisy_path, sr=self.sr)
        return noisy, name

In [2]:
import os

import librosa
from torch.utils import data


class Dataset(data.Dataset):
    """Paired noisy/clean dataset driven by a text list of file pairs.

    In training mode an item is ``(noisy_mag, clean_mag, n_frames, name)`` where
    the magnitudes come from ``librosa.stft`` and ``n_frames`` is the size of the
    last axis of the noisy magnitude. Otherwise an item is
    ``(noisy, clean, name)`` with the raw waveforms.
    """

    def __init__(
            self,
            dataset_list,
            limit,
            offset,
            sr,
            n_fft,
            hop_length,
            train
    ):
        """
        dataset_list(*.txt):
            <noisy_path> <clean_path>\n
        e.g:
            noisy_1.wav clean_1.wav
            noisy_2.wav clean_2.wav
            ...
            noisy_n.wav clean_n.wav

        Args:
            dataset_list (str): path of the text file described above
            limit (int or None): keep at most this many pairs; a falsy value keeps all
            offset (int or None): number of leading pairs to skip
            sr (int): sample rate handed to ``librosa.load``
            n_fft (int): FFT size, also used as the STFT window length
            hop_length (int): STFT hop length
            train (bool): return STFT magnitudes when true, waveforms otherwise
        """
        super(Dataset, self).__init__()
        self.sr = sr
        self.train = train

        list_path = os.path.abspath(os.path.expanduser(dataset_list))
        # Close the list file deterministically rather than leaking the handle.
        with open(list_path, "r") as list_file:
            entries = [line.rstrip('\n') for line in list_file]
        # A blank line cannot be split into a noisy/clean pair, so drop it here
        # instead of crashing in __getitem__.
        entries = [entry for entry in entries if entry.strip()]
        entries = entries[offset:]
        if limit:
            entries = entries[:limit]

        self.dataset_list = entries
        self.length = len(self.dataset_list)
        self.n_fft = n_fft
        self.hop_length = hop_length

    def __len__(self):
        """Return the number of noisy/clean pairs kept after offset/limit."""
        return self.length

    def __getitem__(self, item):
        """Load the ``item``-th pair; see the class docstring for the return layout."""
        # Each entry is "<noisy_path> <clean_path>" separated by one space.
        noisy_path, clean_path = self.dataset_list[item].split(" ")
        name = os.path.splitext(os.path.basename(noisy_path))[0]
        noisy, _ = librosa.load(os.path.abspath(os.path.expanduser(noisy_path)), sr=self.sr)
        clean, _ = librosa.load(os.path.abspath(os.path.expanduser(clean_path)), sr=self.sr)

        if self.train:
            noisy_mag, _ = librosa.magphase(librosa.stft(noisy, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.n_fft))
            clean_mag, _ = librosa.magphase(librosa.stft(clean, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.n_fft))
            return noisy_mag, clean_mag, noisy_mag.shape[-1], name
        else:
            return noisy, clean, name

In [3]:
import os

import librosa
import numpy as np
import soundfile as sf
from torch.utils.data import Dataset


class WavDataset(Dataset):
    """
    Training dataset of paired mixture/clean wav files.

    Files are paired by position: the i-th entry of the mixture listing is
    matched with the i-th entry of the clean listing.
    """

    def __init__(self,
                 mixture_dataset,
                 clean_dataset,
                 limit=None,
                 offset=0,
                 ):
        """
        Construct train dataset
        Args:
            mixture_dataset (str): mixture dir (wav format files)
            clean_dataset (str): clean dir (wav format files)
            limit (int): the limit of the dataset
            offset (int): the offset of the dataset

        Raises:
            AssertionError: a directory is missing, or the two directories
                yield a different number of wav files.
        """
        # Accept "~" and relative paths, like the sibling dataset classes do.
        mixture_dataset = os.path.abspath(os.path.expanduser(mixture_dataset))
        clean_dataset = os.path.abspath(os.path.expanduser(clean_dataset))
        assert os.path.exists(mixture_dataset) and os.path.exists(clean_dataset)

        print("Search datasets...")
        mixture_wav_files = librosa.util.find_files(mixture_dataset, ext="wav", limit=limit, offset=offset)
        clean_wav_files = librosa.util.find_files(clean_dataset, ext="wav", limit=limit, offset=offset)

        assert len(mixture_wav_files) == len(clean_wav_files)
        print(f"\t Original length: {len(mixture_wav_files)}")

        self.length = len(mixture_wav_files)
        self.mixture_wav_files = mixture_wav_files
        self.clean_wav_files = clean_wav_files

        print(f"\t Offset: {offset}")
        print(f"\t Limit: {limit}")
        print(f"\t Final length: {self.length}")

    def __len__(self):
        """Return the number of mixture/clean pairs."""
        return self.length

    def __getitem__(self, item):
        """
        Load one pair.

        Returns:
            (mixture, clean, n_frames, name): float32 waveforms, the frame count
            ``(len(mixture) - 320) // 160 + 1``, and the clean file's base name
            without extension.

        Raises:
            AssertionError: either file is not sampled at 16000 Hz, or the two
                waveforms differ in shape.
        """
        mixture_path = self.mixture_wav_files[item]
        clean_path = self.clean_wav_files[item]
        name = os.path.splitext(os.path.basename(clean_path))[0]

        # Keep the two sample rates in separate names so that BOTH are checked;
        # reusing one variable would validate only the file read last.
        mixture, mixture_sr = sf.read(mixture_path, dtype="float32")
        clean, clean_sr = sf.read(clean_path, dtype="float32")
        assert mixture_sr == 16000, f"Unexpected mixture sample rate {mixture_sr}: {mixture_path}"
        assert clean_sr == 16000, f"Unexpected clean sample rate {clean_sr}: {clean_path}"
        assert mixture.shape == clean.shape

        # presumably a 320-sample window with a 160-sample hop - confirm against the STFT settings
        n_frames = (len(mixture) - 320) // 160 + 1

        return mixture, clean, n_frames, name

In [4]:
import os

import librosa
import numpy as np
import soundfile as sf
from torch.utils.data import Dataset


class WavDataset(Dataset):
    """
    Training dataset of paired mixture/clean wav files.

    Files are paired by position: the i-th entry of the mixture listing is
    matched with the i-th entry of the clean listing.
    """

    def __init__(self,
                 mixture_dataset,
                 clean_dataset,
                 limit=None,
                 offset=0,
                 ):
        """
        Construct train dataset
        Args:
            mixture_dataset (str): mixture dir (wav format files)
            clean_dataset (str): clean dir (wav format files)
            limit (int): the limit of the dataset
            offset (int): the offset of the dataset

        Raises:
            AssertionError: a directory is missing, or the two directories
                yield a different number of wav files.
        """
        mixture_dataset = os.path.abspath(os.path.expanduser(mixture_dataset))
        clean_dataset = os.path.abspath(os.path.expanduser(clean_dataset))
        assert os.path.exists(mixture_dataset) and os.path.exists(clean_dataset)

        print("Search datasets...")
        mixture_wav_files = librosa.util.find_files(mixture_dataset, ext="wav", limit=limit, offset=offset)
        clean_wav_files = librosa.util.find_files(clean_dataset, ext="wav", limit=limit, offset=offset)

        assert len(mixture_wav_files) == len(clean_wav_files)
        print(f"\t Original length: {len(mixture_wav_files)}")

        self.length = len(mixture_wav_files)
        self.mixture_wav_files = mixture_wav_files
        self.clean_wav_files = clean_wav_files

        print(f"\t Offset: {offset}")
        print(f"\t Limit: {limit}")
        print(f"\t Final length: {self.length}")

    def __len__(self):
        """Return the number of mixture/clean pairs."""
        return self.length

    def __getitem__(self, item):
        """
        Load one pair.

        Returns:
            (mixture, clean, n_frames, name): float32 waveforms, the frame count
            ``(len(mixture) - 320) // 160 + 1``, and the clean file's base name
            without extension.

        Raises:
            AssertionError: either file is not sampled at 16000 Hz, or the two
                waveforms differ in shape.
        """
        mixture_path = self.mixture_wav_files[item]
        clean_path = self.clean_wav_files[item]
        name = os.path.splitext(os.path.basename(clean_path))[0]

        # Keep the two sample rates in separate names so that BOTH are checked;
        # reusing one variable would validate only the file read last.
        mixture, mixture_sr = sf.read(mixture_path, dtype="float32")
        clean, clean_sr = sf.read(clean_path, dtype="float32")
        assert mixture_sr == 16000, f"Unexpected mixture sample rate {mixture_sr}: {mixture_path}"
        assert clean_sr == 16000, f"Unexpected clean sample rate {clean_sr}: {clean_path}"
        assert mixture.shape == clean.shape

        # presumably a 320-sample window with a 160-sample hop - confirm against the STFT settings
        n_frames = (len(mixture) - 320) // 160 + 1

        return mixture, clean, n_frames, name

In [5]:
import torch
import torch.nn as nn


class CausalConvBlock(nn.Module):
    """Encoder stage: strided 2-D convolution, batch norm, then ELU.

    The convolution uses a (3, 2) kernel with stride (2, 1) and pads the last
    axis by one on each side, so its output is one step longer than the input
    along that axis; ``forward`` drops the trailing step again.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_options = dict(kernel_size=(3, 2), stride=(2, 1), padding=(0, 1))
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_options)
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.activation = nn.ELU()

    def forward(self, x):
        """
        2D Causal convolution.
        Args:
            x: [B, C, F, T]
        Returns:
            [B, C, F, T] (C becomes ``out_channels``; F shrinks with the stride-2 kernel)
        """
        conv_out = self.conv(x)
        # Chomp: discard the extra trailing step produced by the symmetric padding.
        chomped = conv_out[..., :-1]
        return self.activation(self.norm(chomped))


class CausalTransConvBlock(nn.Module):
    """Decoder stage: strided 2-D transposed convolution, batch norm, activation.

    The activation is ReLU when ``is_last`` is true and ELU otherwise.
    """

    def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 2),
            stride=(2, 1),
            output_padding=output_padding,
        )
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.activation = nn.ReLU() if is_last else nn.ELU()

    def forward(self, x):
        """
        2D Causal transposed convolution.
        Args:
            x: [B, C, F, T]
        Returns:
            [B, C, F, T] (C becomes ``out_channels``; F grows with the stride-2 kernel)
        """
        upsampled = self.conv(x)
        # Chomp: the width-2 kernel adds one trailing step on the last axis; drop it.
        upsampled = upsampled[..., :-1]
        normalized = self.norm(upsampled)
        return self.activation(normalized)


class CRN(nn.Module):
    """
    Convolutional recurrent network: five encoder blocks, a two-layer LSTM, and
    five decoder blocks that each consume the matching encoder output as a skip
    connection.

    Input: [batch size, channels=1, F, T]
    Output: [batch size, channels=1, F, T]

    The LSTM width (1024) equals 256 encoder channels times 4 remaining
    frequency bins, which is what an input with F = 161 produces.
    """

    def __init__(self):
        super(CRN, self).__init__()
        # Encoder
        self.conv_block_1 = CausalConvBlock(1, 16)
        self.conv_block_2 = CausalConvBlock(16, 32)
        self.conv_block_3 = CausalConvBlock(32, 64)
        self.conv_block_4 = CausalConvBlock(64, 128)
        self.conv_block_5 = CausalConvBlock(128, 256)

        # LSTM
        self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True)

        # Decoder: input channels are doubled because of the concatenated skip.
        self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128)
        self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64)
        self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32)
        self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0))
        self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True)

    def forward(self, x):
        """Map a [B, 1, F, T] input to a tensor of the same layout."""
        self.lstm_layer.flatten_parameters()

        # Encoder pass; every stage output is kept for the skip connections.
        encoder_outputs = []
        hidden = x
        for encoder_block in (
                self.conv_block_1,
                self.conv_block_2,
                self.conv_block_3,
                self.conv_block_4,
                self.conv_block_5,
        ):
            hidden = encoder_block(hidden)
            encoder_outputs.append(hidden)

        # e.g. [2, 256, 4, 200] for the deepest encoder output
        batch_size, n_channels, n_f_bins, n_frame_size = hidden.shape

        # [B, C, F, T] => [B, C * F, T] => [B, T, C * F]
        lstm_in = hidden.reshape(batch_size, n_channels * n_f_bins, n_frame_size).permute(0, 2, 1)
        lstm_out, _ = self.lstm_layer(lstm_in)
        # [B, T, C * F] => [B, C * F, T] => [B, C, F, T]
        lstm_out = lstm_out.permute(0, 2, 1).reshape(batch_size, n_channels, n_f_bins, n_frame_size)

        # Decoder pass; skips are consumed deepest-first.
        decoded = lstm_out
        decoder_blocks = (
            self.tran_conv_block_1,
            self.tran_conv_block_2,
            self.tran_conv_block_3,
            self.tran_conv_block_4,
            self.tran_conv_block_5,
        )
        for decoder_block, skip in zip(decoder_blocks, reversed(encoder_outputs)):
            decoded = decoder_block(torch.cat((decoded, skip), 1))

        return decoded


if __name__ == '__main__':
    # Smoke test: push a random [2, 1, 161, 200] batch through a fresh CRN and
    # print the output shape.
    layer = CRN()
    a = torch.rand(2, 1, 161, 200)
    print(layer(a).shape)

torch.Size([2, 1, 161, 200])


In [6]:
from torch.utils.tensorboard import SummaryWriter


def writer(logs_dir):
    """Build the TensorBoard ``SummaryWriter`` used for experiment logging.

    Args:
        logs_dir: directory handed to the writer as ``log_dir``.

    Returns:
        A ``SummaryWriter`` created with ``max_queue=5`` and ``flush_secs=30``.
    """
    writer_options = {"max_queue": 5, "flush_secs": 30}
    return SummaryWriter(log_dir=logs_dir, **writer_options)

In [7]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class CRNN(nn.Module):
    """
    Convolutional recurrent network with skip connections.

    Input: [batch size, T, F]  (a channel axis is inserted inside ``forward``)
    Output: [batch size, T, F]

    The LSTM width (1024) equals 256 encoder channels times 4 remaining
    frequency bins, which is what F = 161 produces after five stride-2
    convolutions with kernel width 3.
    """
    def __init__(self):
        super(CRNN, self).__init__()
        # Encoder
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(1, 3), stride=(1, 2))
        self.bn1 = nn.BatchNorm2d(num_features=16)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(1, 3), stride=(1, 2))
        self.bn2 = nn.BatchNorm2d(num_features=32)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 3), stride=(1, 2))
        self.bn3 = nn.BatchNorm2d(num_features=64)
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, 3), stride=(1, 2))
        self.bn4 = nn.BatchNorm2d(num_features=128)
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, 3), stride=(1, 2))
        self.bn5 = nn.BatchNorm2d(num_features=256)

        # LSTM
        self.LSTM1 = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True)

        # Decoder (input channels are doubled by the concatenated skip connection)
        self.convT1 = nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=(1, 3), stride=(1, 2))
        self.bnT1 = nn.BatchNorm2d(num_features=128)
        self.convT2 = nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=(1, 3), stride=(1, 2))
        self.bnT2 = nn.BatchNorm2d(num_features=64)
        self.convT3 = nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=(1, 3), stride=(1, 2))
        self.bnT3 = nn.BatchNorm2d(num_features=32)
        # output_padding is 1 here, otherwise the computed size would be 79
        # and would not match the 80-bin skip connection from conv1.
        self.convT4 = nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=(1, 3), stride=(1, 2), output_padding=(0, 1))
        self.bnT4 = nn.BatchNorm2d(num_features=16)
        self.convT5 = nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(1, 3), stride=(1, 2))
        self.bnT5 = nn.BatchNorm2d(num_features=1)

    def forward(self, x):
        """
        Args:
            x: [B, T, F]; the tensor is not modified.
        Returns:
            [B, T, F]
        """
        # conv
        # [B, T, F] => [B, 1, T, F]. Out-of-place on purpose: the in-place
        # unsqueeze_ would silently reshape the tensor the caller still holds.
        x = x.unsqueeze(1)
        x1 = F.elu(self.bn1(self.conv1(x)))
        x2 = F.elu(self.bn2(self.conv2(x1)))
        x3 = F.elu(self.bn3(self.conv3(x2)))
        x4 = F.elu(self.bn4(self.conv4(x3)))
        x5 = F.elu(self.bn5(self.conv5(x4)))
        # reshape: [B, C, T, F'] => [B, T, C, F'] => [B, T, C * F']
        out5 = x5.permute(0, 2, 1, 3)
        out5 = out5.reshape(out5.size()[0], out5.size()[1], -1)
        # lstm
        lstm, _ = self.LSTM1(out5)
        # reshape: [B, T, C * F'] => [B, T, C, F'] => [B, C, T, F']
        output = lstm.reshape(lstm.size()[0], lstm.size()[1], 256, -1)
        output = output.permute(0, 2, 1, 3)
        # ConvTrans
        res = torch.cat((output, x5), 1)
        res1 = F.elu(self.bnT1(self.convT1(res)))
        res1 = torch.cat((res1, x4), 1)
        res2 = F.elu(self.bnT2(self.convT2(res1)))
        res2 = torch.cat((res2, x3), 1)
        res3 = F.elu(self.bnT3(self.convT3(res2)))
        res3 = torch.cat((res3, x2), 1)
        res4 = F.elu(self.bnT4(self.convT4(res3)))
        res4 = torch.cat((res4, x1), 1)
        # (B, o_c, T, F)
        res5 = F.relu(self.bnT5(self.convT5(res4)))
        # Remove only the channel axis; a bare squeeze() would also drop the
        # batch axis when B == 1 (or the time axis when T == 1).
        return res5.squeeze(1)


In [8]:
import torch
from torch.nn.utils.rnn import pad_sequence


def mse_loss_for_variable_length_data():
    """Return a masked-MSE loss for batches padded along the last (time) axis."""

    def loss_function(target, ipt, n_frames_list, device):
        """
        Calculate the MSE loss for variable length dataset.

        Args:
            ipt: [B, F, T] or [B, 1, F, T]; any layout whose first axis is the
                batch and whose last axis is time
            target: same shape as ``ipt``
            n_frames_list: valid frame count of each batch item (list or 1-D tensor)
            device: device on which the mask is built

        Returns:
            Scalar tensor: squared error summed over valid frames, divided by
            the number of valid elements. A batch of one is never padded, so it
            falls back to plain ``mse_loss``.
        """
        if target.shape[0] == 1:
            return torch.nn.functional.mse_loss(target, ipt)

        E = 1e-8
        with torch.no_grad():
            n_total_frames = ipt.shape[-1]
            valid_lengths = torch.as_tensor([int(n) for n in n_frames_list], device=device)
            frame_index = torch.arange(n_total_frames, device=device)
            # [B, T]: 1 where the frame index lies inside the item's real length.
            valid = (frame_index.unsqueeze(0) < valid_lengths.unsqueeze(1)).to(torch.float32)
            # Put singleton axes between batch and time so the mask lines up
            # with the batch axis for 3-D and 4-D inputs alike, then expand it
            # so the denominator counts every valid element.
            mask_shape = (valid.shape[0],) + (1,) * (ipt.dim() - 2) + (n_total_frames,)
            binary_mask = valid.view(mask_shape).expand_as(ipt)

        masked_ipt = ipt * binary_mask
        masked_target = target * binary_mask
        # Padded frames contribute nothing; only valid values are averaged.
        return ((masked_ipt - masked_target) ** 2).sum() / (binary_mask.sum() + E)

    return loss_function


In [9]:
import numpy as np
from pesq import pesq
from pystoi.stoi import stoi


def SI_SDR(reference, estimation):
    """
    Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)

    The estimate is split into its projection onto the reference and the
    residual; the result is the power ratio of the two in decibels.

    Args:
        reference: numpy.ndarray, [..., T]
        estimation: numpy.ndarray, [..., T]
    Returns:
        SI-SDR in dB, with the last axis reduced
    [1] SDR– Half- Baked or Well Done?
    http://www.merl.com/publications/docs/TR2019-013.pdf
    """
    estimation, reference = np.broadcast_arrays(estimation, reference)

    # $\alpha$ after Equation (3) in [1]: least-squares scale of the reference.
    reference_energy = (reference ** 2).sum(axis=-1, keepdims=True)
    cross_term = (reference * estimation).sum(axis=-1, keepdims=True)
    alpha = cross_term / reference_energy

    # $e_{\text{target}}$ and $e_{\text{res}}$ in Equation (4) in [1].
    e_target = alpha * reference
    e_res = estimation - e_target

    target_power = (e_target ** 2).sum(axis=-1)
    residual_power = (e_res ** 2).sum(axis=-1)
    return 10 * np.log10(target_power / residual_power)

def STOI(ref, est, sr=16000):
    """Score ``est`` against ``ref`` with ``stoi`` (non-extended variant) at sample rate ``sr``."""
    score = stoi(ref, est, sr, extended=False)
    return score


def PESQ(ref, est, sr=16000):
    """Score ``est`` against ``ref`` with ``pesq`` in "wb" mode at sample rate ``sr``."""
    score = pesq(sr, ref, est, "wb")
    return score

In [10]:
import importlib
import os
import time

import numpy as np
import torch
from pesq import pesq
from pystoi.stoi import stoi


def load_checkpoint(checkpoint_path, device):
    """Load model weights from a ``.pth`` or ``.tar`` checkpoint.

    Args:
        checkpoint_path (str): checkpoint file; ``~`` and relative paths are resolved.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        For ``.pth`` files, the loaded object itself. For ``.tar`` files, the
        ``"model"`` entry of the loaded dict (its ``"epoch"`` entry is printed).

    Raises:
        AssertionError: the file extension is neither ``.pth`` nor ``.tar``.
    """
    _, ext = os.path.splitext(os.path.basename(checkpoint_path))
    assert ext in (".pth", ".tar"), "Only support .pth and .tar extensions of model checkpoint."
    # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
    model_checkpoint = torch.load(os.path.abspath(os.path.expanduser(checkpoint_path)), map_location=device)

    if ext == ".pth":
        print(f"Loading {checkpoint_path}.")
        return model_checkpoint
    else:  # tar
        print(f"Loading {checkpoint_path}, epoch = {model_checkpoint['epoch']}.")
        return model_checkpoint["model"]


def get_sub_band_bound(idx, n_bins, n_neighbor):
    """
    Compute the index bounds of the sub-band around a given bin.

    The band reaches ``n_neighbor`` bins to each side of ``idx``. When one side
    is cut short by the edge of the spectrum, the missing bins are added to the
    other side instead.

    Args:
        idx: current bin index
        n_bins: total number of frequency bins
        n_neighbor: number of bins to extend on each side
    Returns:
        (index of the band's low-index bound, index of its high-index bound)
    """
    # Reach available on each side before hitting the spectrum edge.
    reach_up = np.min([(n_bins - 1) - idx, n_neighbor])
    reach_down = np.min([idx, n_neighbor])

    # Move the shortfall of a clipped side over to the opposite side.
    missing_up = n_neighbor - reach_up
    missing_down = n_neighbor - reach_down
    if missing_up > 0:
        reach_down += missing_up
    elif missing_down > 0:
        reach_up += missing_down

    return idx - reach_down, idx + reach_up


def overlap_cat(chunk_list, dim=-1):
    """
    Concatenate chunks along ``dim`` assuming a 50% overlap between neighbours.

    Every chunk is cut into two halves along ``dim``. The first half of a chunk
    is averaged with the last half of the previous chunk; the last half is
    appended as-is.

    Args:
        dim: the axis to split and concatenate along
        chunk_list(list): [[B, T], [B, T], ...]
    Returns:
        Tensor holding the overlap-averaged concatenation.
    Raises:
        ValueError: a chunk has an empty or odd length along ``dim`` and so
            cannot be cut into two equal halves.
    """
    overlap_output = []
    for i, chunk in enumerate(chunk_list):
        # Measure along the axis that is actually split, not always the last one.
        chunk_length = chunk.size(dim)
        if chunk_length == 0 or chunk_length % 2 != 0:
            raise ValueError(
                f"Chunk {i} has length {chunk_length} along dim {dim}; a positive even length is required."
            )
        first_half, last_half = torch.split(chunk, chunk_length // 2, dim=dim)
        if i == 0:
            overlap_output += [first_half, last_half]
        else:
            overlap_output[-1] = (overlap_output[-1] + first_half) / 2
            overlap_output.append(last_half)

    overlap_output = torch.cat(overlap_output, dim=dim)
    return overlap_output


def prepare_empty_dir(dirs, resume=False):
    """
    Make sure the experiment directories are in place.

    When resuming, every directory must already exist; otherwise each one is
    created together with its parents (an existing directory is accepted).

    Args:
        dirs (list): directories, as objects offering ``exists()`` and ``mkdir()``
        resume (bool): whether to resume experiment, default is False
    Raises:
        AssertionError: ``resume`` is true and a directory is missing.
    """
    for target_dir in dirs:
        if not resume:
            target_dir.mkdir(parents=True, exist_ok=True)
            continue
        assert target_dir.exists()


class ExecutionTime:
    """
    Wall-clock stopwatch that starts when the object is created.

    Usage:
        timer = ExecutionTime()
        <Something...>
        print(f'Finished in {timer.duration()} seconds.')
    """

    def __init__(self):
        # Reference point taken from time.time() at construction.
        self.start_time = time.time()

    def duration(self):
        """Return the whole seconds elapsed since construction (truncated)."""
        elapsed = time.time() - self.start_time
        return int(elapsed)


def initialize_config(module_cfg, pass_args=True):
    """
    Dynamically load the object described by a config item.

    e.g. config items as follows:
        module_cfg = {
            "module": "model.model",
            "main": "Model",
            "args": {...}
        }
    1. Import the module named by "module".
    2. Look up the attribute named by "main" (a function or a class).
    3. If ``pass_args`` is true, call it with "args" as keyword arguments and
       return the result; otherwise return the attribute itself, uncalled.
    """
    module = importlib.import_module(module_cfg["module"])
    target = getattr(module, module_cfg["main"])

    if not pass_args:
        return target
    return target(**module_cfg["args"])


def compute_PESQ(clean_signal, noisy_signal, sr=16000):
    """Score ``noisy_signal`` against ``clean_signal`` with ``pesq`` in "wb" mode at sample rate ``sr``."""
    score = pesq(sr, clean_signal, noisy_signal, "wb")
    return score


def z_score(m):
    """Standardize ``m`` with its overall mean and standard deviation.

    Returns:
        (standardized values, mean, standard deviation); the last two are the
        statistics needed by ``reverse_z_score``.
    """
    mean, std_var = np.mean(m), np.std(m)
    standardized = (m - mean) / std_var
    return standardized, mean, std_var


def reverse_z_score(m, mean, std_var):
    """Undo ``z_score``: rescale by ``std_var`` and shift by ``mean``."""
    rescaled = m * std_var
    return rescaled + mean


def min_max(m):
    """Rescale ``m`` linearly so its minimum maps to 0 and its maximum to 1.

    Returns:
        (rescaled values, maximum, minimum); note the max-before-min order,
        which matches the parameters of ``reverse_min_max``.
    """
    m_min, m_max = np.min(m), np.max(m)
    value_range = m_max - m_min

    return (m - m_min) / value_range, m_max, m_min


def reverse_min_max(m, m_max, m_min):
    """Undo ``min_max``: stretch by ``m_max - m_min`` and shift by ``m_min``."""
    value_range = m_max - m_min
    return m * value_range + m_min


def sample_fixed_length_data_aligned(data_a, data_b, sample_length):
    """
    Take the same fixed-length segment from two aligned arrays.

    The last axis is the one being cut. Longer inputs are cropped at a random
    start drawn with ``np.random.randint`` (same start for both arrays); inputs
    of exactly ``sample_length`` are returned unchanged; shorter inputs are
    padded at the end with float32 zeros.

    Args:
        data_a, data_b: 1-D or 2-D arrays of identical shape
        sample_length: required size of the last axis
    Returns:
        (segment of data_a, segment of data_b)
    Raises:
        AssertionError: the shapes differ, or the arrays are not 1-D/2-D.
    """
    assert data_a.shape == data_b.shape, "Inconsistent dataset size."
    dim = np.ndim(data_a)
    assert dim == 1 or dim == 2, "Only support 1D or 2D."

    frames_total = data_a.shape[-1]

    if frames_total == sample_length:
        return data_a, data_b

    if frames_total > sample_length:
        start = np.random.randint(frames_total - sample_length + 1)
        window = slice(start, start + sample_length)
        return data_a[..., window], data_b[..., window]

    # Too short: append zeros along the last axis.
    n_missing = sample_length - frames_total
    if dim == 1:
        padding = np.zeros(n_missing, dtype=np.float32)
    else:
        padding = np.zeros(shape=(data_a.shape[0], n_missing), dtype=np.float32)
    return np.concatenate((data_a, padding), axis=-1), np.concatenate((data_b, padding), axis=-1)


def compute_STOI(clean_signal, noisy_signal, sr=16000):
    """Score ``noisy_signal`` against ``clean_signal`` with ``stoi`` (non-extended variant) at sample rate ``sr``."""
    score = stoi(clean_signal, noisy_signal, sr, extended=False)
    return score


def print_tensor_info(tensor, flag="Tensor"):
    """Print ``flag`` followed by the max, min, mean and std of ``tensor``.

    All four statistics are truncated (toward zero) to three decimals; the
    minimum used to be printed at full precision, unlike the other three.
    """
    def truncate(value):
        # Cut to three decimals by truncation, not rounding.
        return int(float(value) * 1000) / 1000

    print(flag)
    print(
        f"\tmax: {truncate(torch.max(tensor))}, min: {truncate(torch.min(tensor))}, "
        f"mean: {truncate(torch.mean(tensor))}, std: {truncate(torch.std(tensor))}")


def set_requires_grad(nets, requires_grad=False):
    """
    Set ``requires_grad`` on every parameter of the given networks.

    Args:
        nets: a network or a list of networks; ``None`` entries are skipped
        requires_grad: value assigned to each parameter's ``requires_grad``
    """
    networks = nets if isinstance(nets, list) else [nets]
    for network in networks:
        if network is None:
            continue
        for parameter in network.parameters():
            parameter.requires_grad = requires_grad


def prepare_device(n_gpu: int, cudnn_deterministic=False):
    """Choose to use CPU or GPU depending on ``n_gpu``.

    Args:
        n_gpu(int): the number of GPUs used in the experiment.
            if n_gpu is 0, use CPU;
            otherwise use the first GPU ("cuda:0").
        cudnn_deterministic (bool): repeatability switch, applied on the GPU
            path only. When true, cuDNN deterministic mode is enabled and
            ``cudnn.benchmark`` is disabled.

    Returns:
        torch.device: ``cpu`` or ``cuda:0``.
    """
    if n_gpu == 0:
        print("Using CPU in the experiment.")
        return torch.device("cpu")

    if cudnn_deterministic:
        print("Using CuDNN deterministic mode in the experiment.")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    return torch.device("cuda:0")

In [12]:
import matplotlib.pyplot as plt
import time
from pathlib import Path

import json5
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch

from util import visualization
from util.metrics import STOI, PESQ, SI_SDR
from util.utils import prepare_empty_dir, ExecutionTime, prepare_device

plt.switch_backend('agg')


class BaseTrainer:
    def __init__(self, config, resume: bool, model, loss_function, optimizer):
        """Set up device, model, bookkeeping, directories and logging.

        Args:
            config (dict): experiment configuration. Keys read here:
                "cudnn_deterministic", "trainer" ("epochs",
                "save_checkpoint_interval", "validation" with "interval",
                "find_max" and "custom", optional "train" with optional
                "custom"), "root_dir", "experiment_name" and
                "preloaded_model_path".
            resume (bool): continue from ``checkpoints/latest_model.tar``.
            model: network to train; moved to the selected device.
            loss_function: stored as ``self.loss_function``.
            optimizer: stored as ``self.optimizer``.
        """
        # Device and model placement; several visible GPUs enable DataParallel.
        self.n_gpu = torch.cuda.device_count()
        self.device = prepare_device(self.n_gpu, cudnn_deterministic=config["cudnn_deterministic"])

        self.optimizer = optimizer
        self.loss_function = loss_function

        self.model = model.to(self.device)
        if self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(self.n_gpu)))

        # Trainer settings
        trainer_config = config["trainer"]
        self.epochs = trainer_config["epochs"]
        self.save_checkpoint_interval = trainer_config["save_checkpoint_interval"]
        self.validation_config = trainer_config["validation"]
        self.train_config = trainer_config.get("train", {})
        self.validation_interval = self.validation_config["interval"]
        self.find_max = self.validation_config["find_max"]
        self.validation_custom_config = self.validation_config["custom"]
        self.train_custom_config = self.train_config.get("custom", {})

        # Not part of the config file; overwritten below when resuming.
        self.start_epoch = 1
        self.best_score = -np.inf if self.find_max else np.inf

        # Experiment directory layout: <root_dir>/<experiment_name>/{checkpoints,logs}
        self.root_dir = Path(config["root_dir"]).expanduser().absolute() / config["experiment_name"]
        self.checkpoints_dir = self.root_dir / "checkpoints"
        self.logs_dir = self.root_dir / "logs"
        prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)

        self.writer = visualization.writer(self.logs_dir.as_posix())
        self.writer.add_text(
            tag="Configuration",
            text_string=f"<pre>  \n{json5.dumps(config, indent=4, sort_keys=False)}  \n</pre>",
            global_step=1
        )

        if resume:
            self._resume_checkpoint()
        if config["preloaded_model_path"]:
            self._preload_model(Path(config["preloaded_model_path"]))

        print("Configurations are as follows: ")
        print(json5.dumps(config, indent=2, sort_keys=False))

        # Keep a timestamped copy of the configuration next to the results.
        config_copy_path = self.root_dir / f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json"
        with open(config_copy_path.as_posix(), "w") as handle:
            json5.dump(config, handle, indent=2, sort_keys=False)

        self._print_networks([self.model])

    def _preload_model(self, model_path):
        """
        Preload *.pth file of the model at the start of the current experiment.
        Args:
            model_path(Path): the path of the *.pth file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(), f"Preloaded *.pth file is not exist. Please check the file path: {model_path.as_posix()}"
        model_checkpoint = torch.load(model_path.as_posix(), map_location=self.device)

        if isinstance(self.model, torch.nn.DataParallel):
            self.model.module.load_state_dict(model_checkpoint, strict=False)
        else:
            self.model.load_state_dict(model_checkpoint, strict=False)

        print(f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """Resume experiment from latest checkpoint.
        Notes:
            To be careful at Loading model. if model is an instance of DataParallel, we need to set model.module.*
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute() / "latest_model.tar"
        assert latest_model_path.exists(), f"{latest_model_path} does not exist, can not load latest checkpoint."

        checkpoint = torch.load(latest_model_path.as_posix(), map_location=self.device)

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])

        if isinstance(self.model, torch.nn.DataParallel):
            self.model.module.load_state_dict(checkpoint["model"])
        else:
            self.model.load_state_dict(checkpoint["model"])

        print(f"Model checkpoint loaded. Training will begin in {self.start_epoch} epoch.")

    def _save_checkpoint(self, epoch, is_best=False):
        """Save checkpoint to <root_dir>/checkpoints directory, which contains:
            - current epoch
            - best score in history
            - optimizer parameters
            - model parameters
        Args:
            is_best(bool): if current checkpoint got the best score, it also will be saved in <root_dir>/checkpoints/best_model.tar.
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        # Construct checkpoint tar package
        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict()
        }

        if isinstance(self.model, torch.nn.DataParallel):  # Parallel
            state_dict["model"] = self.model.module.cpu().state_dict()
        else:
            state_dict["model"] = self.model.cpu().state_dict()

        """
        Notes:
            - latest_model.tar:
                Contains all checkpoint information, including optimizer parameters, model parameters, etc. New checkpoint will overwrite old one.
            - model_<epoch>.pth: 
                The parameters of the model. Follow-up we can specify epoch to inference.
            - best_model.tar:
                Like latest_model, but only saved when <is_best> is True.
        """
        torch.save(state_dict, (self.checkpoints_dir / "latest_model.tar").as_posix())
        torch.save(state_dict["model"], (self.checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth").as_posix())
        if is_best:
            print(f"\t Found best score in {epoch} epoch, saving...")
            torch.save(state_dict, (self.checkpoints_dir / "best_model.tar").as_posix())

        # Use model.cpu() or model.to("cpu") will migrate the model to CPU, at which point we need re-migrate model back.
        # No matter tensor.cuda() or tensor.to("cuda"), if tensor in CPU, the tensor will not be migrated to GPU, but the model will.
        self.model.to(self.device)

    def _is_best(self, score, find_max=True):
        """Check if the current model is the best model
        """
        if find_max and score >= self.best_score:
            self.best_score = score
            return True
        elif not find_max and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _transform_pesq_range(pesq_score):
        """transform [-0.5 ~ 4.5] to [0 ~ 1]
        """
        return (pesq_score + 0.5) / 5

    @staticmethod
    def _print_networks(nets: list):
        print(f"This project contains {len(nets)} networks, the number of the parameters: ")
        params_of_all_networks = 0
        for i, net in enumerate(nets, start=1):
            params_of_network = 0
            for param in net.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {i}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(f"The amount of parameters in the project is {params_of_all_networks / 1e6} million.")

    def _set_models_to_train_mode(self):
        """Put ``self.model`` into training mode by calling its ``train()``."""
        self.model.train()

    def _set_models_to_eval_mode(self):
        """Put ``self.model`` into evaluation mode by calling its ``eval()``."""
        self.model.eval()

    def spec_audio_visualization(self, noisy, enhanced, clean, name, epoch):
        """Log audio and spectrograms of one utterance to the summary writer.

        The three waveforms are logged as audio at 16000 Hz. Their STFT
        magnitudes (n_fft=320, hop_length=160, win_length=320) are drawn as
        three stacked dB spectrograms, each titled with its mean/std/max/min.
        No waveform plot is produced.

        Args:
            noisy, enhanced, clean: 1-D waveforms of the same utterance.
            name: utterance name used in the logging tags.
            epoch: step value under which everything is logged.
        """
        # Visualize audio
        self.writer.add_audio(f"Speech/{name}_Noisy", noisy, epoch, sample_rate=16000)
        self.writer.add_audio(f"Speech/{name}_Enhanced", enhanced, epoch, sample_rate=16000)
        self.writer.add_audio(f"Speech/{name}_Clean", clean, epoch, sample_rate=16000)

        # Visualize spectrogram: magnitudes in noisy / enhanced / clean order.
        magnitudes = []
        for waveform in (noisy, enhanced, clean):
            magnitude, _ = librosa.magphase(librosa.stft(waveform, n_fft=320, hop_length=160, win_length=320))
            magnitudes.append(magnitude)

        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for axis, mag in zip(axes, magnitudes):
            axis.set_title(f"mean: {np.mean(mag):.3f}, "
                           f"std: {np.std(mag):.3f}, "
                           f"max: {np.max(mag):.3f}, "
                           f"min: {np.min(mag):.3f}")
            librosa.display.specshow(librosa.amplitude_to_db(mag), cmap="magma", y_axis="linear", ax=axis, sr=16000)
        plt.tight_layout()
        self.writer.add_figure(f"Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self, noisy_list, clean_list, enhanced_list, epoch):
        stoi_clean_noisy = []  # Clean and noisy
        stoi_clean_denoise = []  # Clean and denoisy

        pesq_clean_noisy = []
        pesq_clean_denoise = []

        sisdr_clean_noisy = []
        sisdr_clean_denoise = []

        for noisy, enhanced, clean in zip(noisy_list, enhanced_list, clean_list):
            stoi_clean_noisy.append(STOI(clean, noisy, sr=16000))
            stoi_clean_denoise.append(STOI(clean, enhanced, sr=16000))

            pesq_clean_noisy.append(PESQ(clean, noisy, sr=16000))
            pesq_clean_denoise.append(PESQ(clean, enhanced, sr=16000))

            sisdr_clean_noisy.append(SI_SDR(clean, noisy))
            sisdr_clean_denoise.append(SI_SDR(clean, enhanced))

        self.writer.add_scalars(f"Validation/STOI", {
            "clean and noisy": np.mean(stoi_clean_noisy),
            "clean and enhanced": np.mean(stoi_clean_denoise)
        }, epoch)
        self.writer.add_scalars(f"Validation/PESQ", {
            "clean and noisy": np.mean(pesq_clean_noisy),
            "clean and enhanced": np.mean(pesq_clean_denoise)
        }, epoch)
        self.writer.add_scalars(f"Validation/SI-SDR", {
            "clean and noisy": np.mean(sisdr_clean_noisy),
            "clean and enhanced": np.mean(sisdr_clean_denoise)
        }, epoch)

        return (self._transform_pesq_range(np.mean(pesq_clean_denoise)) + np.mean(stoi_clean_denoise)) / 2

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            print(f"============== {epoch} epoch ==============")
            print("[0 seconds] Begin training...")
            timer = ExecutionTime()

            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            if self.save_checkpoint_interval != 0 and (epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            if self.validation_interval != 0 and epoch % self.validation_interval == 0:
                print(f"[{timer.duration()} seconds] Training is over, Validation is in progress...")

                self._set_models_to_eval_mode()
                score = self._validation_epoch(epoch)

                if self._is_best(score, find_max=self.find_max):
                    self._save_checkpoint(epoch, is_best=True)

            print(f"[{timer.duration()} seconds] End this epoch.")

    def _train_epoch(self, epoch):
        """Train for one epoch; must be overridden by subclasses.

        ``train`` calls this once per epoch with the epoch number.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        """Validate for one epoch; must be overridden by subclasses.

        ``train`` passes the value returned here to ``_is_best`` as the score.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

In [13]:
import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from trainer.base_trainer import BaseTrainer

# Select matplotlib's "agg" backend for every figure drawn from this cell.
# NOTE(review): presumably chosen so plotting works without a display (headless training) — confirm.
plt.switch_backend('agg')


class Trainer(BaseTrainer):
    """Trainer that feeds magnitude spectrograms to the model.

    Training batches already arrive as spectrograms together with a per-sample
    frame count; validation batches arrive as waveforms, are transformed with
    librosa's STFT, enhanced, and resynthesised using the noisy phase.
    """

    def __init__(self, config, resume: bool, model, loss_function, optimizer, train_dataloader, validation_dataloader):
        """Keep the two dataloaders; all other arguments are forwarded to ``BaseTrainer``."""
        super(Trainer, self).__init__(config, resume, model, loss_function, optimizer)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader

    def _train_epoch(self, epoch):
        """Run one optimisation pass over ``train_dataloader`` and log the mean batch loss.

        Args:
            epoch: Epoch number, used as the step when logging "Loss/Train".
        """
        loss_total = 0.0

        for noisy, clean, n_frames_list, _ in self.train_dataloader:
            self.optimizer.zero_grad()

            noisy = noisy.to(self.device).unsqueeze(1)  # [B, F, T] => [B, 1, F, T]
            clean = clean.to(self.device).unsqueeze(1)  # [B, F, T] => [B, 1, F, T]

            enhanced = self.model(noisy)  # [B, 1, F, T]

            loss = self.loss_function(enhanced, clean, n_frames_list, self.device)
            loss.backward()
            self.optimizer.step()
            loss_total += loss.item()

        self.writer.add_scalar("Loss/Train", loss_total / len(self.train_dataloader), epoch)

    @torch.no_grad()
    def _validation_epoch(self, epoch):
        """Enhance every validation utterance, log the mean loss, and return the metric score.

        The STFT parameters and ``visualization_limit`` are read from
        ``self.validation_custom_config``. At most ``visualization_limit``
        utterances (the first ones yielded by the dataloader) are passed to
        ``spec_audio_visualization``.

        Args:
            epoch: Epoch number, used as the step for all logging.

        Returns:
            Whatever ``self.metrics_visualization`` returns for the collected
            noisy / clean / enhanced waveforms.
        """
        noisy_list = []
        clean_list = []
        enhanced_list = []

        loss_total = 0.0

        visualization_limit = self.validation_custom_config["visualization_limit"]
        n_fft = self.validation_custom_config["n_fft"]
        hop_length = self.validation_custom_config["hop_length"]
        win_length = self.validation_custom_config["win_length"]

        # tqdm wraps the dataloader itself (not the enumerate object) so it can
        # read the total length and draw a real progress bar.
        for i, (noisy, clean, name) in enumerate(tqdm(self.validation_dataloader, desc="Inference")):
            assert len(name) == 1, "The batch size of inference stage must 1."
            name = name[0]

            # Flatten the single-item batch to a 1-D waveform.
            noisy = noisy.numpy().reshape(-1)
            clean = clean.numpy().reshape(-1)

            # Magnitude and phase of the noisy signal; only the magnitude of the clean one is needed.
            noisy_mag, noisy_phase = librosa.magphase(librosa.stft(noisy, n_fft=n_fft, hop_length=hop_length, win_length=win_length))  # both [F, T]
            clean_mag, _ = librosa.magphase(librosa.stft(clean, n_fft=n_fft, hop_length=hop_length, win_length=win_length))  # [F, T]

            noisy_mag = torch.tensor(noisy_mag, device=self.device)[None, None, :, :]  # [F, T] => [1, 1, F, T]
            clean_mag = torch.tensor(clean_mag, device=self.device)[None, None, :, :]  # [F, T] => [1, 1, F, T]

            enhanced_mag = self.model(noisy_mag)

            # NOTE(review): arguments are (clean, enhanced) here but (enhanced, clean) in
            # _train_epoch — harmless only if the loss is symmetric; confirm.
            loss_total += self.loss_function(clean_mag, enhanced_mag, [clean_mag.shape[-1], ], self.device).item()

            enhanced_mag = enhanced_mag.squeeze(0).squeeze(0).detach().cpu().numpy()  # [1, 1, F, T] => [F, T]
            # Resynthesise using the noisy phase, trimmed/padded to the input length.
            enhanced = librosa.istft(enhanced_mag * noisy_phase, hop_length=hop_length, win_length=win_length, length=len(noisy))

            assert len(noisy) == len(clean) == len(enhanced)

            # Visualise only the first `visualization_limit` utterances. The previous
            # `i <= limit` check let one extra utterance through.
            if i < visualization_limit:
                self.spec_audio_visualization(noisy, enhanced, clean, name, epoch)

            noisy_list.append(noisy)
            clean_list.append(clean)
            enhanced_list.append(enhanced)

        self.writer.add_scalar("Loss/Validation", loss_total / len(self.validation_dataloader), epoch)
        return self.metrics_visualization(noisy_list, clean_list, enhanced_list, epoch)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from inferencer.inferencer import inference_wrapper
from trainer.base_trainer import BaseTrainer

# Select matplotlib's "agg" backend for every figure drawn from this cell.
# NOTE(review): presumably chosen so plotting works without a display (headless training) — confirm.
plt.switch_backend('agg')


class Trainer(BaseTrainer):
    """Trainer whose validation step is delegated to ``inference_wrapper``."""

    def __init__(self, config, resume: bool, model, loss_function, optimizer, train_dataloader, validation_dataloader):
        """Keep the two dataloaders; all other arguments are forwarded to ``BaseTrainer``."""
        super(Trainer, self).__init__(config, resume, model, loss_function, optimizer)
        self.validation_dataloader = validation_dataloader
        self.train_dataloader = train_dataloader

    def _train_epoch(self, epoch):
        """Run one optimisation pass over ``train_dataloader`` and log the mean batch loss.

        Args:
            epoch: Epoch number, used as the step when logging "Train/Loss".
        """
        running_loss = 0.0

        for batch_noisy, batch_clean, _ in self.train_dataloader:
            batch_noisy = batch_noisy.to(self.device)
            batch_clean = batch_clean.to(self.device)

            self.optimizer.zero_grad()

            batch_enhanced = self.model(batch_noisy)
            batch_loss = self.loss_function(batch_clean, batch_enhanced)

            batch_loss.backward()
            self.optimizer.step()

            running_loss += batch_loss.item()

        mean_loss = running_loss / len(self.train_dataloader)
        self.writer.add_scalar(f"Train/Loss", mean_loss, epoch)

    @torch.no_grad()
    def _validation_epoch(self, epoch):
        """Run inference on the validation set, log the loss, visualise, and score.

        Args:
            epoch: Epoch number, used as the step for all logging.

        Returns:
            Whatever ``self.metrics_visualization`` returns for the collected
            noisy / clean / enhanced signals.
        """
        inference_outputs = inference_wrapper(
            dataloader=self.validation_dataloader,
            model=self.model,
            loss_function=self.loss_function,
            device=self.device,
            inference_args=self.validation_custom_config,
            enhanced_dir=None
        )
        noisy_list, enhanced_list, clean_list, name_list, loss = inference_outputs

        self.writer.add_scalar(f"Validation/Loss", loss, epoch)

        # Visualise at most `visualization_limit` items, never more than the dataloader length.
        visualization_limit = self.validation_custom_config["visualization_limit"]
        n_to_visualize = np.min([visualization_limit, len(self.validation_dataloader)])
        for idx in range(n_to_visualize):
            self.spec_audio_visualization(noisy_list[idx], enhanced_list[idx], clean_list[idx], name_list[idx], epoch)

        return self.metrics_visualization(noisy_list, clean_list, enhanced_list, epoch)

In [14]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as functional
import torchaudio
from tqdm import tqdm

from trainer.base_trainer import BaseTrainer
from util.utils import overlap_cat

# Select matplotlib's "agg" backend for every figure drawn from this cell.
# NOTE(review): presumably chosen so plotting works without a display (headless training) — confirm.
plt.switch_backend('agg')


class Trainer(BaseTrainer):
    """Trainer that works on waveforms and computes the loss in the time domain.

    Each batch is transformed with ``torch.stft``, the model enhances the
    magnitude, and the waveform is rebuilt with the noisy phase before the loss
    is evaluated against the clean waveform. Validation processes each
    utterance in overlapping fixed-width spectrogram chunks.
    """

    def __init__(self, config, resume: bool, model, loss_function, optimizer, train_dataloader, validation_dataloader):
        """Keep the two dataloaders; all other arguments are forwarded to ``BaseTrainer``."""
        super(Trainer, self).__init__(config, resume, model, loss_function, optimizer)
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader

    def _train_epoch(self, epoch):
        """Run one optimisation pass over ``train_dataloader`` and log the mean batch loss.

        Args:
            epoch: Epoch number, used as the step when logging "Loss/Train".
        """
        loss_total = 0.0

        # STFT settings for training are fixed here (not read from the config).
        n_fft, hop_length, win_length = 320, 160, 320
        # The analysis/synthesis window is loop-invariant, so build it once per epoch.
        window = torch.hann_window(win_length).to(self.device)

        for noisy, clean, _ in self.train_dataloader:
            self.optimizer.zero_grad()

            noisy = noisy.to(self.device)  # [B, T]
            clean = clean.to(self.device)  # [B, T]

            noisy_d = torch.stft(
                noisy,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window)  # [B, F, T, 2]

            noisy_mag, noisy_phase = torchaudio.functional.magphase(noisy_d)  # [B, F, T], [B, F, T]

            enhanced_mag = self.model(noisy_mag)

            # Recombine the enhanced magnitude with the noisy phase as (real, imag) pairs.
            enhanced_d = torch.cat([
                (enhanced_mag * torch.cos(noisy_phase)).unsqueeze(-1),
                (enhanced_mag * torch.sin(noisy_phase)).unsqueeze(-1)
            ], dim=-1)  # [B, F, T, 2]

            enhanced = torchaudio.functional.istft(
                enhanced_d,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                length=noisy.shape[1])

            loss = self.loss_function(clean, enhanced)
            loss.backward()
            self.optimizer.step()
            loss_total += loss.item()

        self.writer.add_scalar("Loss/Train", loss_total / len(self.train_dataloader), epoch)

    @torch.no_grad()
    def _validation_epoch(self, epoch):
        """Enhance every validation utterance chunk-wise, log the mean loss, and return the score.

        STFT parameters, ``batch_size``, ``unfold_size`` and
        ``visualization_limit`` are read from ``self.validation_custom_config``.
        At most ``visualization_limit`` utterances (the first ones yielded by
        the dataloader) are passed to ``spec_audio_visualization``.

        Args:
            epoch: Epoch number, used as the step for all logging.

        Returns:
            Whatever ``self.metrics_visualization`` returns for the collected
            noisy / clean / enhanced waveforms.
        """
        noisy_list = []
        clean_list = []
        enhanced_list = []

        loss_total = 0.0

        visualization_limit = self.validation_custom_config["visualization_limit"]
        n_fft = self.validation_custom_config["n_fft"]
        hop_length = self.validation_custom_config["hop_length"]
        win_length = self.validation_custom_config["win_length"]
        batch_size = self.validation_custom_config["batch_size"]
        unfold_size = self.validation_custom_config["unfold_size"]

        # The window depends only on the config, so build it once for the whole pass.
        window = torch.hann_window(win_length).to(self.device)

        # tqdm wraps the dataloader itself (not the enumerate object) so it can
        # read the total length and draw a real progress bar.
        for i, (noisy, clean, name) in enumerate(tqdm(self.validation_dataloader, desc="Inference")):
            assert len(name) == 1, "The batch size of inference stage must 1."
            name = name[0]
            padded_length = 0  # number of zero frames appended below; stripped again after enhancement

            noisy = noisy.to(self.device)  # [1, T]
            clean = clean.to(self.device)  # [1, T]

            noisy_d = torch.stft(
                noisy,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window)  # [1, F, T, 2]
            noisy_mag, noisy_phase = torchaudio.functional.magphase(noisy_d)  # [1, F, T]

            # === start overlap enhancement ===
            noisy_mag = noisy_mag[None, :, :, :]  # [1, F, T] => [1, 1, F, T]; the extra dim is required by unfold

            # Zero-pad the time axis up to a multiple of unfold_size.
            if noisy_mag.size(-1) % unfold_size != 0:
                padded_length = unfold_size - (noisy_mag.size(-1) % unfold_size)
                noisy_mag = torch.cat([noisy_mag, torch.zeros(1, 1, noisy_mag.size(2), padded_length, device=self.device)], dim=-1)  # [1, 1, F, T]

            # Slide a full-height window of width unfold_size with a stride of half that width.
            noisy_mag = functional.unfold(noisy_mag, kernel_size=(n_fft // 2 + 1, unfold_size), stride=unfold_size // 2)
            # [1, F * unfold_size, N] => [1, F, unfold_size, N] => [N, 1, F, unfold_size] => [N, F, unfold_size],
            # where N is the number of chunks.
            noisy_mag = noisy_mag.reshape(1, n_fft // 2 + 1, unfold_size, -1).permute(3, 0, 1, 2).squeeze(1)
            noisy_mag_chunks = torch.split(noisy_mag, batch_size, dim=0)  # [N, F, T] => ([B, F, T], ...), last group may be smaller

            enhanced_mag_chunks = []
            for noisy_mag_chunk in noisy_mag_chunks:
                enhanced_mag_chunk = self.model(noisy_mag_chunk)  # [B, F, T]
                enhanced_mag_chunks.extend(torch.split(enhanced_mag_chunk, 1, dim=0))  # [B, F, T] => ([1, F, T], [1, F, T], ...)

            enhanced_mag = overlap_cat(enhanced_mag_chunks)  # ([1, F, T], [1, F, T], ...) => [1, F, T]
            enhanced_mag = enhanced_mag if padded_length == 0 else enhanced_mag[:, :, :-padded_length]  # [1, F, T]
            # === end overlap enhancement ===

            # Recombine the enhanced magnitude with the noisy phase as (real, imag) pairs.
            enhanced_d = torch.cat([
                (enhanced_mag * torch.cos(noisy_phase)).unsqueeze(-1),
                (enhanced_mag * torch.sin(noisy_phase)).unsqueeze(-1)
            ], dim=-1)  # [1, F, T, 2]

            enhanced = torchaudio.functional.istft(
                enhanced_d,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                length=noisy.shape[1])  # [1, T]

            loss_total += self.loss_function(clean, enhanced).item()

            noisy = noisy.detach().squeeze(0).cpu().numpy()
            clean = clean.detach().squeeze(0).cpu().numpy()
            enhanced = enhanced.detach().squeeze(0).cpu().numpy()

            assert len(noisy) == len(clean) == len(enhanced)

            # Visualise only the first `visualization_limit` utterances. The previous
            # `i <= limit` check let one extra utterance through.
            if i < visualization_limit:
                self.spec_audio_visualization(noisy, enhanced, clean, name, epoch)

            noisy_list.append(noisy)
            clean_list.append(clean)
            enhanced_list.append(enhanced)

        self.writer.add_scalar("Loss/Validation", loss_total / len(self.validation_dataloader), epoch)
        return self.metrics_visualization(noisy_list, clean_list, enhanced_list, epoch)
