In [1]:
import argparse
import logging
import os
import pathlib
from functools import partial
from typing import Dict, List, NoReturn

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin
import torch

### YAML / LOGGING

In [2]:
import yaml

def read_yaml(config_yaml: str) -> Dict:
    """Parse a YAML config file.

    Args:
        config_yaml: str, path of the YAML file to read

    Returns:
        configs: Dict, the content of the file as parsed by yaml.load
    """
    # NOTE(review): yaml.FullLoader is not the safe loader; only read trusted
    # config files with it.
    with open(config_yaml, "r") as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


In [3]:
def check_configs_gramma(configs: Dict) -> None:
    r"""Check if the grammar of the config dictionary for training is legal.

    The check is only done when configs['train']['paired_input_target_data']
    is False. In that case, every source type listed under one of the
    per-source augmentations ('mixaudio', 'pitch_shift', 'magnitude_scale',
    'swap_channel', 'flip_axis') of configs['train']['augmentations'] must
    also be listed in configs['train']['input_source_types'].

    Args:
        configs: Dict, config dictionary with a 'train' section

    Raises:
        ValueError: if an augmentation refers to a source type that is not
            one of the input source types.
    """
    train_configs = configs['train']

    if train_configs['paired_input_target_data'] is not False:
        return

    input_source_types = train_configs['input_source_types']
    augmentations = train_configs['augmentations']

    # Checked in a fixed order so that the reported error is deterministic.
    per_source_augmentation_types = (
        'mixaudio',
        'pitch_shift',
        'magnitude_scale',
        'swap_channel',
        'flip_axis',
    )

    for augmentation_type in per_source_augmentation_types:

        if augmentation_type not in augmentations:
            continue

        for source_type in augmentations[augmentation_type].keys():
            if source_type not in input_source_types:
                raise ValueError(
                    "The source type '{}' in configs['train']['augmentations']['{}'] "
                    "must be one of input_source_types {}".format(
                        source_type, augmentation_type, input_source_types
                    )
                )


In [4]:
import os
import logging

def create_logging(log_dir: str, filemode: str) -> logging:
    r"""Create logging to write out log files and to print to the console.

    Log files are named "0000.log", "0001.log", ...; the first index that
    does not exist yet in log_dir is used. The console handler is attached to
    the root logger only once, so calling this function again (e.g., when a
    notebook cell is re-run) does not duplicate console messages.

    Args:
        log_dir: str, directory to write out logs (created if missing)
        filemode: str, mode used to open the log file, e.g., "w"

    Returns:
        logging: the logging module
    """
    os.makedirs(log_dir, exist_ok=True)

    # Find the first log file index that is not used yet.
    log_index = 0

    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(log_index))):
        log_index += 1

    log_path = os.path.join(log_dir, "{:04d}.log".format(log_index))

    # NOTE: logging.basicConfig has no effect when the root logger already has
    # handlers, so a repeated call keeps the previously configured handlers.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Print to console. The handler is named so that it can be recognized and
    # is not added a second time.
    console_name = "create_logging_console"
    root_logger = logging.getLogger("")

    if not any(
        handler.get_name() == console_name for handler in root_logger.handlers
    ):
        console = logging.StreamHandler()
        console.set_name(console_name)
        console.setLevel(logging.INFO)
        console.setFormatter(
            logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
        )
        root_logger.addHandler(console)

    return logging


In [5]:
def get_dirs(
    workspace: str,
    task_name: str,
    filename: str,
    config_yaml: str,
    gpus: int,
) -> List[str]:
    r"""Create the output directories of an experiment and set up logging.

    Args:
        workspace: str, root directory of all outputs
        task_name: str, e.g., 'musdb18'
        filename: str
        config_yaml: str, path of the config file; only its stem is used
        gpus: int, e.g., 0 for cpu and 8 for training with 8 gpu cards

    Returns:
        checkpoints_dir: str
        logs_dir: str
        logger: pl.loggers.TensorBoardLogger
        statistics_path: str
    """
    config_name = pathlib.Path(config_yaml).stem
    run_tag = "config={},gpus={}".format(config_name, gpus)

    def run_dir(kind: str) -> str:
        # <workspace>/<kind>/<task_name>/<filename>/<run_tag>
        return os.path.join(workspace, kind, task_name, filename, run_tag)

    # Directory to save checkpoints.
    checkpoints_dir = run_dir("checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)

    # Directory of text logs; logging is configured to write there.
    logs_dir = run_dir("logs")
    os.makedirs(logs_dir, exist_ok=True)
    create_logging(logs_dir, filemode='w')

    # Tensorboard logs are shared under one directory of the workspace.
    tb_logs_dir = os.path.join(workspace, "tensorboard_logs")
    os.makedirs(tb_logs_dir, exist_ok=True)
    logger = pl.loggers.TensorBoardLogger(
        save_dir=tb_logs_dir,
        name=os.path.join(task_name, filename, config_name),
    )

    # Path of the statistics file; its parent directory is created here.
    statistics_dir = run_dir("statistics")
    os.makedirs(statistics_dir, exist_ok=True)
    statistics_path = os.path.join(statistics_dir, "statistics.pkl")

    return checkpoints_dir, logs_dir, logger, statistics_path


### UTILS

In [6]:
import datetime
import logging
import os
import pickle
from typing import Dict, NoReturn

import librosa
import numpy as np
import yaml


def create_logging(log_dir: str, filemode: str) -> logging:
    r"""Create logging to write out log files and to print to the console.

    Log files are named "0000.log", "0001.log", ...; the first index that
    does not exist yet in log_dir is used. The console handler is attached to
    the root logger only once, so calling this function again (e.g., when a
    notebook cell is re-run) does not duplicate console messages.

    Args:
        log_dir: str, directory to write out logs (created if missing)
        filemode: str, mode used to open the log file, e.g., "w"

    Returns:
        logging: the logging module
    """
    os.makedirs(log_dir, exist_ok=True)

    # Find the first log file index that is not used yet.
    log_index = 0

    while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(log_index))):
        log_index += 1

    log_path = os.path.join(log_dir, "{:04d}.log".format(log_index))

    # NOTE: logging.basicConfig has no effect when the root logger already has
    # handlers, so a repeated call keeps the previously configured handlers.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    # Print to console. The handler is named so that it can be recognized and
    # is not added a second time.
    console_name = "create_logging_console"
    root_logger = logging.getLogger("")

    if not any(
        handler.get_name() == console_name for handler in root_logger.handlers
    ):
        console = logging.StreamHandler()
        console.set_name(console_name)
        console.setLevel(logging.INFO)
        console.setFormatter(
            logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
        )
        root_logger.addHandler(console)

    return logging


def load_audio(
    audio_path: str,
    mono: bool,
    sample_rate: float,
    offset: float = 0.0,
    duration: float = None,
) -> np.array:
    r"""Load audio with librosa and return it as a 2-D array.

    Args:
        audio_path: str, path of the audio file
        mono: bool, forwarded to librosa as `mono`
        sample_rate: float, forwarded to librosa as `sr`
        offset: float, forwarded to librosa as `offset`
        duration: float, forwarded to librosa as `duration`

    Returns:
        audio: (channels_num, audio_samples); a 1-D result of librosa is
            expanded to (1, audio_samples)
    """
    loaded, _ = librosa.core.load(
        audio_path, sr=sample_rate, mono=mono, offset=offset, duration=duration
    )
    # loaded: (audio_samples,) | (channels_num, audio_samples)

    # Add a leading channel axis when librosa returned a 1-D waveform.
    return loaded[None, :] if loaded.ndim == 1 else loaded


def load_random_segment(
    audio_path: str,
    random_state: np.random.RandomState,
    segment_seconds: float,
    mono: bool,
    sample_rate: int,
) -> np.array:
    r"""Randomly select an audio segment from a recording.

    Args:
        audio_path: str, path of the recording
        random_state: object with a `uniform(low, high)` method, e.g.,
            np.random.RandomState; used to draw the start time
        segment_seconds: float, duration of the segment to load
        mono: bool
        sample_rate: int

    Returns:
        audio: (channels_num, audio_samples), as returned by load_audio
    """

    duration = librosa.get_duration(filename=audio_path)

    # When the recording is shorter than the segment, start at 0 instead of
    # drawing from an interval with a negative upper bound (which would give
    # a negative offset).
    latest_start_time = max(0.0, duration - segment_seconds)

    start_time = random_state.uniform(0.0, latest_start_time)

    audio = load_audio(
        audio_path=audio_path,
        mono=mono,
        sample_rate=sample_rate,
        offset=start_time,
        duration=segment_seconds,
    )
    # (channels_num, audio_samples)

    return audio


def float32_to_int16(x: np.float32) -> np.int16:
    r"""Convert floating point samples to int16 samples.

    Values are clipped to [-1, 1] and scaled by 32767 before the cast.
    """
    clipped = np.clip(x, -1, 1)
    scaled = clipped * 32767.0

    return scaled.astype(np.int16)


def int16_to_float32(x: np.int16) -> np.float32:
    r"""Convert int16 samples to float32 samples by dividing by 32767."""
    scaled = x / 32767.0

    return scaled.astype(np.float32)


def read_yaml(config_yaml: str) -> Dict:
    """Parse a YAML config file.

    Args:
        config_yaml: str, path of the YAML file to read

    Returns:
        configs: Dict, the content of the file as parsed by yaml.load
    """
    # NOTE(review): yaml.FullLoader is not the safe loader; only read trusted
    # config files with it.
    with open(config_yaml, "r") as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


def check_configs_gramma(configs: Dict) -> None:
    r"""Check if the grammar of the config dictionary for training is legal.

    The check is only done when configs['train']['paired_input_target_data']
    is False. In that case, every source type listed under one of the
    per-source augmentations ('mixaudio', 'pitch_shift', 'magnitude_scale',
    'swap_channel', 'flip_axis') of configs['train']['augmentations'] must
    also be listed in configs['train']['input_source_types'].

    Args:
        configs: Dict, config dictionary with a 'train' section

    Raises:
        ValueError: if an augmentation refers to a source type that is not
            one of the input source types.
    """
    train_configs = configs['train']

    if train_configs['paired_input_target_data'] is not False:
        return

    input_source_types = train_configs['input_source_types']
    augmentations = train_configs['augmentations']

    # Checked in a fixed order so that the reported error is deterministic.
    per_source_augmentation_types = (
        'mixaudio',
        'pitch_shift',
        'magnitude_scale',
        'swap_channel',
        'flip_axis',
    )

    for augmentation_type in per_source_augmentation_types:

        if augmentation_type not in augmentations:
            continue

        for source_type in augmentations[augmentation_type].keys():
            if source_type not in input_source_types:
                raise ValueError(
                    "The source type '{}' in configs['train']['augmentations']['{}'] "
                    "must be one of input_source_types {}".format(
                        source_type, augmentation_type, input_source_types
                    )
                )


def magnitude_to_db(x: float) -> float:
    r"""Convert a magnitude to decibels, 20 * log10(x), with x floored at 1e-10."""
    floor = 1e-10
    floored_magnitude = max(x, floor)
    return 20.0 * np.log10(floored_magnitude)


def db_to_magnitude(x: float) -> float:
    r"""Convert decibels to a magnitude, 10 ** (x / 20)."""
    exponent = x / 20
    return 10.0 ** exponent


def get_pitch_shift_factor(shift_pitch: float) -> float:
    r"""The factor of the audio length to be scaled, 2 ** (shift_pitch / 12)."""
    octaves = shift_pitch / 12
    return 2 ** octaves


class StatisticsContainer(object):
    r"""Collect statistics of the "train" and "test" splits and dump them
    to pickle files."""

    def __init__(self, statistics_path):
        r"""Create an empty container.

        Args:
            statistics_path: str, path of the pickle file written by dump().
                A backup path is derived from it by inserting the creation
                time before the ".pkl" extension.
        """
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        self.statistics_dict = {"train": [], "test": []}

    def append(self, steps, statistics, split):
        r"""Append the statistics of one evaluation.

        Args:
            steps: value stored under the "steps" key of `statistics`
            statistics: dict; modified in place (the "steps" key is set)
            split: str, "train" | "test"
        """
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

    def dump(self):
        r"""Write all statistics to statistics_path and to the backup path."""
        # Context managers make sure both files are flushed and closed.
        for path in (self.statistics_path, self.backup_statistics_path):
            with open(path, "wb") as f:
                pickle.dump(self.statistics_dict, f)

        logging.info("    Dump statistics to {}".format(self.statistics_path))
        logging.info("    Dump statistics to {}".format(self.backup_statistics_path))

    # Disabled code kept for reference:
    #
    # def load_state_dict(self, resume_steps):
    #     self.statistics_dict = pickle.load(open(self.statistics_path, "rb"))
    #
    #     resume_statistics_dict = {"train": [], "test": []}
    #
    #     for key in self.statistics_dict.keys():
    #         for statistics in self.statistics_dict[key]:
    #             if statistics["steps"] <= resume_steps:
    #                 resume_statistics_dict[key].append(statistics)
    #
    #     self.statistics_dict = resume_statistics_dict


def calculate_sdr(ref: np.array, est: np.array) -> float:
    r"""Calculate the signal-to-distortion ratio (in dB) of an estimate.

    The mean power of `ref` and the mean power of `est - ref` are both
    floored at 1e-8 before taking the log ratio.
    """
    power_floor = 1e-8
    reference_power = np.clip(np.mean(ref ** 2), power_floor, np.inf)
    error_power = np.clip(np.mean((est - ref) ** 2), power_floor, np.inf)

    return 10.0 * (np.log10(reference_power) - np.log10(error_power))


### Augmentor

In [7]:
from typing import Dict

import librosa
import numpy as np

# from bytesep.utils import db_to_magnitude, get_pitch_shift_factor, magnitude_to_db


class Augmentor:
    r"""Random augmentations applied to one waveform segment."""

    def __init__(self, augmentations: Dict, random_seed=1234):
        r"""Augmentor for augmenting one segment.

        Args:
            augmentations: Dict, e.g, {
                'mixaudio': {'vocals': 2, 'accompaniment': 2}
                'pitch_shift': {'vocals': 4, 'accompaniment': 4},
                ...,
            }
            random_seed: int, seed of the random state shared by all
                augmentations of this object
        """
        self.augmentations = augmentations
        self.random_state = np.random.RandomState(random_seed)

    def __call__(self, waveform: np.array, source_type: str) -> np.array:
        r"""Augment a waveform with every configured augmentation.

        The augmentations are applied in a fixed order: pitch shift,
        magnitude scale, channel swap, and axis flip.

        Args:
            waveform: (input_channels, audio_samples)
            source_type: str

        Returns:
            new_waveform: (input_channels, new_audio_samples)
        """
        configured = self.augmentations.keys()

        # Each name is both a key of the config and the name of the method.
        for name in ('pitch_shift', 'magnitude_scale', 'swap_channel', 'flip_axis'):
            if name in configured:
                waveform = getattr(self, name)(waveform, source_type)

        return waveform

    def pitch_shift(self, waveform: np.array, source_type: str) -> np.array:
        r"""Shift the pitch of a waveform by resampling it. Resampling is used
        for speed, so the speed and the length of the waveform change too.

        Args:
            waveform: (input_channels, audio_samples)
            source_type: str

        Returns:
            new_waveform: (input_channels, new_audio_samples)
        """

        # Maximum pitch shift in semitones of this source type.
        max_pitch_shift = self.augmentations['pitch_shift'][source_type]

        # A maximum of 0 disables the augmentation.
        if max_pitch_shift == 0:
            return waveform

        # Pitch shift drawn uniformly from [-max_pitch_shift, max_pitch_shift).
        rand_pitch = self.random_state.uniform(
            low=-max_pitch_shift, high=max_pitch_shift
        )

        # librosa.resample is used instead of librosa.effects.pitch_shift
        # because it is 10x times faster. Only the ratio of the two sample
        # rates matters, so a dummy constant is used for the original one.
        pitch_shift_factor = get_pitch_shift_factor(rand_pitch)
        dummy_sample_rate = 10000

        single_channel = waveform.shape[0] == 1

        # A single-channel waveform is resampled as a squeezed array.
        resample_input = np.squeeze(waveform) if single_channel else waveform

        resampled = librosa.resample(
            y=resample_input,
            orig_sr=dummy_sample_rate,
            target_sr=dummy_sample_rate / pitch_shift_factor,
            res_type='linear',
            axis=-1,
        )

        # Restore the channel axis that was squeezed above.
        return resampled[None, :] if single_channel else resampled

    def magnitude_scale(self, waveform: np.array, source_type: str) -> np.array:
        r"""Scale the magnitude of a waveform by a random gain in dB.

        Args:
            waveform: (input_channels, audio_samples)
            source_type: str

        Returns:
            new_waveform: (input_channels, audio_samples)
        """
        db_range = self.augmentations['magnitude_scale'][source_type]
        lower_db = db_range['lower_db']
        higher_db = db_range['higher_db']

        # Both bounds at 0 disable the augmentation.
        if lower_db == 0 and higher_db == 0:
            return waveform

        # The magnitude (in dB) of the sample with the maximum absolute value.
        peak_db = magnitude_to_db(np.max(np.abs(waveform)))

        # New peak magnitude, drawn relative to the current one.
        target_db = self.random_state.uniform(peak_db + lower_db, peak_db + higher_db)

        gain = db_to_magnitude(target_db - peak_db)

        return waveform * gain

    def swap_channel(self, waveform: np.array, source_type: str) -> np.array:
        r"""Randomly permute the channels of a waveform.

        Args:
            waveform: (input_channels, audio_samples)
            source_type: str, not used

        Returns:
            new_waveform: (input_channels, audio_samples)
        """
        channels_num = waveform.shape[0]

        # Nothing to permute (and no random number drawn) for one channel.
        if channels_num == 1:
            return waveform

        channel_order = self.random_state.permutation(channels_num)
        return waveform[channel_order, :]

    def flip_axis(self, waveform: np.array, source_type: str) -> np.array:
        r"""Randomly flip the waveform along x-axis: every channel is
        multiplied by a value drawn from {-1, 1}.

        Args:
            waveform: (input_channels, audio_samples)
            source_type: str, not used

        Returns:
            new_waveform: (input_channels, audio_samples)
        """
        channels_num = waveform.shape[0]
        signs = self.random_state.choice([-1, 1], size=channels_num)

        return waveform * signs[:, None]


### Sampler

In [8]:
import pickle
from typing import Dict, List, NoReturn

import numpy as np
import torch.distributed as dist


class SegmentSampler:
    def __init__(
        self,
        indexes_dict_path: str,
        input_source_types: List[str],
        target_source_types: List[str],
        segment_samples: int,
        remixing_sources: bool,
        mixaudio_dict: Dict,
        batch_size: int,
        steps_per_epoch: int,
        random_seed=1234,
    ):
        r"""Sample training indexes of sources.

        Args:
            indexes_dict_path: str, path of the pickled indexes dict
            input_source_types: list of str, e.g., ['vocals', 'accompaniment']
            target_source_types: list of str, e.g., ['vocals']
            segment_samples: int, length of a segment; used to set the
                'end_sample' of every sampled meta
            remixing_sources: bool, if True the indexes of every source type
                are shuffled with a different seed; otherwise all source
                types are shuffled with `random_seed`
            mixaudio_dict, dict, mix-audio data augmentation parameters,
                e.g., {'vocals': 2, 'accompaniment': 2}
            batch_size: int
            steps_per_epoch: int, #steps_per_epoch is called an `epoch`
            random_seed: int
        """
        self.segment_samples = segment_samples
        self.mixaudio_dict = mixaudio_dict
        self.remixing_sources = remixing_sources
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

        # The file is closed as soon as the dict is loaded.
        # NOTE(review): pickle must only be used on trusted index files.
        with open(indexes_dict_path, "rb") as f:
            self.meta_dict = pickle.load(f)
        # E.g., {
        #     'vocals': [
        #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
        #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
        #         ... (e.g., 225752 dicts)
        #     ],
        #     'accompaniment': [
        #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 0, 'end_sample': 132300},
        #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4410, 'end_sample': 445410},
        #         ... (e.g., 225752 dicts)
        #     ]
        # }

        # Sorted so that the order does not depend on the iteration order of
        # a set of strings, which may differ between processes. The seeds
        # below are drawn in this order, so every process that builds this
        # sampler with the same arguments gets the same shuffles.
        self.source_types = sorted(set(input_source_types) | set(target_source_types))
        # E.g., ['accompaniment', 'vocals']

        self.pointers_dict = {source_type: 0 for source_type in self.source_types}
        # E.g., {'vocals': 0, 'accompaniment': 0}

        self.indexes_dict = {
            source_type: np.arange(len(self.meta_dict[source_type]))
            for source_type in self.source_types
        }
        # E.g. {
        #     'vocals': [0, 1, ..., 225751],
        #     'accompaniment': [0, 1, ..., 225751]
        # }

        random_state = np.random.RandomState(random_seed)
        self.random_state_dict = {}

        for source_type in self.source_types:

            if remixing_sources:
                # Use different seeds for different sources.
                source_random_seed = random_state.randint(low=0, high=10000)

            else:
                # Use same seeds for different sources.
                source_random_seed = random_seed

            self.random_state_dict[source_type] = np.random.RandomState(
                source_random_seed
            )

            self.random_state_dict[source_type].shuffle(self.indexes_dict[source_type])
            # E.g., [198036, 196736, ..., 103408]

            print("{}: {}".format(source_type, len(self.indexes_dict[source_type])))

    def __iter__(self) -> List[Dict]:
        r"""Yield batches of meta info, endlessly.

        Yields:
            batch_meta_list: (batch_size,) e.g., when mix-audio is 2, looks like [
                {'vocals': [
                    {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
                    {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
                'accompaniment': [
                    {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 14579460, 'end_sample': 14711760},
                    {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 3995460, 'end_sample': 4127760}]
                },
                ...
            ]
        """
        batch_size = self.batch_size

        while True:
            batch_meta_dict = {source_type: [] for source_type in self.source_types}

            for source_type in self.source_types:
                # E.g., ['vocals', 'accompaniment']

                # Number of segments mixed into one training example of this
                # source type; 1 when no mix-audio is configured for it.
                mix_audios_num = self.mixaudio_dict.get(source_type, 1)

                # Largest pointer from which mix_audios_num indexes can
                # still be read.
                largest_index = len(self.indexes_dict[source_type]) - mix_audios_num
                # E.g., 225750 = 225752 - 2

                # Loop until get a mini-batch.
                while len(batch_meta_dict[source_type]) != batch_size:

                    if self.pointers_dict[source_type] > largest_index:

                        # Reset pointer, and shuffle indexes.
                        self.pointers_dict[source_type] = 0
                        self.random_state_dict[source_type].shuffle(
                            self.indexes_dict[source_type]
                        )

                    source_metas = []

                    for _ in range(mix_audios_num):

                        pointer = self.pointers_dict[source_type]
                        # E.g., 1

                        index = self.indexes_dict[source_type][pointer]
                        # E.g., 12231

                        self.pointers_dict[source_type] += 1

                        source_meta = self.meta_dict[source_type][index]
                        # E.g., {
                        #     'hdf5_path': 'xx/song_A.h5',
                        #     'key_in_hdf5': 'vocals',
                        #     'begin_sample': 13406400,
                        # }

                        # Re-assign the end_sample. This updates the entry of
                        # self.meta_dict in place.
                        source_meta['end_sample'] = (
                            source_meta['begin_sample'] + self.segment_samples
                        )

                        source_metas.append(source_meta)

                    batch_meta_dict[source_type].append(source_metas)

            # When mix-audio is 2, batch_meta_dict looks like: {
            #     'vocals': [
            #         [{'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
            #          {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}
            #         ],
            #         ... (batch_size)
            #     ]
            #     'accompaniment': [
            #         [{'hdf5_path': 'songG.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 24232950, 'end_sample': 24365250},
            #          {'hdf5_path': 'songH.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 1569960, 'end_sample': 1702260}
            #         ],
            #         ... (batch_size)
            #     ]
            # }

            # Transpose from a dict of lists to a list of dicts.
            batch_meta_list = [
                {
                    source_type: batch_meta_dict[source_type][i]
                    for source_type in self.source_types
                }
                for i in range(batch_size)
            ]
            # When mix-audio is 2, batch_meta_list looks like: [
            #     {'vocals': [
            #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
            #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
            #      'accompaniment': [
            #         {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 14579460, 'end_sample': 14711760},
            #         {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 3995460, 'end_sample': 4127760}]
            #     }
            #     ... (batch_size)
            # ]

            yield batch_meta_list

    def __len__(self) -> int:
        r"""Number of steps that is called an `epoch`."""
        return self.steps_per_epoch

    def state_dict(self) -> Dict:
        r"""Return the pointers and the shuffled indexes of every source type.
        The random states are not part of the returned state."""
        state = {'pointers_dict': self.pointers_dict, 'indexes_dict': self.indexes_dict}
        return state

    def load_state_dict(self, state) -> NoReturn:
        r"""Restore the pointers and indexes returned by state_dict()."""
        self.pointers_dict = state['pointers_dict']
        self.indexes_dict = state['indexes_dict']


class DistributedSamplerWrapper:
    def __init__(self, sampler):
        r"""Distributed wrapper of a batch sampler.

        Args:
            sampler: iterable of batches (lists); also needs to support len()
        """
        self.sampler = sampler

    def __iter__(self) -> List[Dict]:
        r"""For every batch of the wrapped sampler, yield the entries at
        positions rank, rank + world_size, rank + 2 * world_size, ..."""

        world_size = dist.get_world_size()  # number of processes (GPUs).
        rank = dist.get_rank()  # rank of the current process (GPU).

        # Every batch_meta_list looks like (when mix-audio is 2): [
        #     {'vocals': [
        #         {'hdf5_path': 'songA.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 13406400, 'end_sample': 13538700},
        #         {'hdf5_path': 'songB.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 4440870, 'end_sample': 4573170}]
        #      'accompaniment': [
        #         {'hdf5_path': 'songE.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 14579460, 'end_sample': 14711760},
        #         {'hdf5_path': 'songF.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 3995460, 'end_sample': 4127760}]
        #     }
        #     ... (batch_size)
        # ]
        # Only the subset that belongs to this rank is yielded.
        yield from (
            batch_meta_list[rank::world_size] for batch_meta_list in self.sampler
        )

    def __len__(self) -> int:
        r"""Length of the wrapped sampler."""
        return len(self.sampler)


### DATA MODULE

In [9]:
from typing import Dict, List, NoReturn, Optional

import h5py
import librosa
import numpy as np
import torch
from pytorch_lightning.core.datamodule import LightningDataModule


class DataModule(LightningDataModule):
    def __init__(
        self,
        train_sampler: object,
        train_dataset: object,
        num_workers: int,
        distributed: bool,
    ):
        r"""Data module that builds the training data loader.

        Args:
            train_sampler: Sampler object, used as batch sampler
            train_dataset: Dataset object
            num_workers: int, number of workers of the data loader
            distributed: bool, if True the sampler is wrapped with
                DistributedSamplerWrapper in setup()
        """
        super().__init__()
        self._train_sampler = train_sampler
        self.train_dataset = train_dataset
        self.num_workers = num_workers
        self.distributed = distributed

    def setup(self, stage: Optional[str] = None) -> NoReturn:
        r"""called on every device. Sets self.train_sampler."""

        # SegmentSampler is used for sampling segment indexes for training.
        # On multiple devices, each SegmentSampler samples a part of mini-batch
        # data.
        sampler = self._train_sampler

        self.train_sampler = (
            DistributedSamplerWrapper(sampler) if self.distributed else sampler
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader. Uses self.train_sampler, which is set by setup()."""
        return torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_sampler=self.train_sampler,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )


class Dataset:
    def __init__(
        self,
        input_source_types: List[str],
        target_source_types: List[str],
        paired_input_target_data: bool,
        input_channels: int,
        augmentor: Augmentor,
        segment_samples: int,
    ):
        r"""Used for getting data according to a meta.

        Args:
            input_source_types: list of str, e.g., ['vocals', 'accompaniment']
            target_source_types: list of str, e.g., ['vocals']
            paired_input_target_data: bool, if True the source types to load
                are the union of input and target source types and no
                augmentation is applied
            input_channels: int
            augmentor: Augmentor, or a false value to disable augmentation
            segment_samples: int, length every loaded segment is fixed to
        """
        self.input_source_types = input_source_types
        self.paired_input_target_data = paired_input_target_data
        self.input_channels = input_channels
        self.augmentor = augmentor
        self.segment_samples = segment_samples

        self.source_types = (
            list(set(input_source_types) | set(target_source_types))
            if paired_input_target_data
            else input_source_types
        )

    def __getitem__(self, meta: Dict) -> Dict:
        r"""Return data according to a meta.

        For every source type, all segments listed in the meta are loaded and
        summed (mix-audio augmentation).

        Args:
            meta: dict, e.g., {
                'vocals': [
                    {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 6332760, 'end_sample': 6465060},
                    {'hdf5_path': 'song_B.h5', 'key_in_hdf5': 'vocals', 'begin_sample': 198450, 'end_sample': 330750}],
                'accompaniment': [
                    {'hdf5_path': 'song_C.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 24232920, 'end_sample': 24365220},
                    {'hdf5_path': 'song_D.h5', 'key_in_hdf5': 'accompaniment', 'begin_sample': 1569960, 'end_sample': 1702260}]
            }

        Returns:
            data_dict: dict, e.g., {
                'vocals': (channels, segments_num),
                'accompaniment': (channels, segments_num),
            }
        """
        data_dict = {}

        for source_type in self.source_types:
            # E.g., ['vocals', 'accompaniment']

            # Audio segments to be mix-audio augmented.
            segments = [
                self._load_segment(m, source_type) for m in meta[source_type]
            ]
            # E.g., [(input_channels, audio_samples), (input_channels, audio_samples)]

            # mix-audio augmentation
            data_dict[source_type] = np.sum(segments, axis=0)
            # data_dict[source_type]: (input_channels, audio_samples)

        return data_dict

    def _load_segment(self, m: Dict, source_type: str) -> np.array:
        r"""Load, augment and length-fix the segment described by one meta.

        Args:
            m: dict, e.g., {
                'hdf5_path': '.../song_A.h5',
                'key_in_hdf5': 'vocals',
                'begin_sample': 13406400,
                'end_sample': 13538700,
            }
            source_type: str

        Returns:
            waveform: (channels, segment_samples)
        """
        hdf5_path = m['hdf5_path']
        key_in_hdf5 = m['key_in_hdf5']
        bgn_sample = m['begin_sample']
        end_sample = m['end_sample']

        with h5py.File(hdf5_path, 'r') as hf:

            if source_type == 'audioset':
                # One waveform is selected by its index before slicing, and
                # a channel axis is added afterwards.
                index_in_hdf5 = m['index_in_hdf5']
                waveform = int16_to_float32(
                    hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
                )
                waveform = waveform[None, :]
            else:
                # NOTE(review): assumes the dataset is (channels, samples)
                # -- confirm against the files that are packed for training.
                waveform = int16_to_float32(
                    hf[key_in_hdf5][:, bgn_sample:end_sample]
                )

        # Augmentation is only applied to unpaired data. Nothing is done for
        # paired data yet (TODO).
        if not self.paired_input_target_data and self.augmentor:
            waveform = self.augmentor(waveform, source_type)

        if source_type in self.input_source_types:
            waveform = self.match_waveform_to_input_channels(
                waveform=waveform, input_channels=self.input_channels
            )
            # (input_channels, segments_num)

        # Pad or crop along the time axis to exactly segment_samples.
        return librosa.util.fix_length(waveform, size=self.segment_samples, axis=1)

    def match_waveform_to_input_channels(
        self,
        waveform: np.array,
        input_channels: int,
    ) -> np.array:
        r"""Match waveform to channels num. A single channel is repeated to
        get more channels; several channels are averaged to get one channel.

        Args:
            waveform: (input_channels, segments_num)
            input_channels: int

        Outputs:
            output: (new_input_channels, segments_num)
        """
        waveform_channels = waveform.shape[0]

        if waveform_channels == input_channels:
            return waveform

        if waveform_channels < input_channels:
            # Only a single channel can be expanded.
            assert waveform_channels == 1
            return np.tile(waveform, (input_channels, 1))

        # Several channels can only be reduced to a single channel.
        assert input_channels == 1
        return np.mean(waveform, axis=0)[None, :]


def collate_fn(list_data_dict: List[Dict]) -> Dict:
    r"""Merge per-example dictionaries into one dictionary of batched tensors.

    The source types of the batch are taken from the first example; every
    example is expected to provide an array under each of those keys.

    Args:
        list_data_dict: e.g., [
            {'vocals': (input_channels, segment_samples),
             'accompaniment': (input_channels, segment_samples),
             'mixture': (input_channels, segment_samples)
            },
            {'vocals': (input_channels, segment_samples),
             'accompaniment': (input_channels, segment_samples),
             'mixture': (input_channels, segment_samples)
            },
            ...]

    Returns:
        data_dict: e.g. {
            'vocals': (batch_size, input_channels, segment_samples),
            'accompaniment': (batch_size, input_channels, segment_samples),
            'mixture': (batch_size, input_channels, segment_samples)
            }
    """
    source_types = list_data_dict[0].keys()

    # Stack the examples of each source type into one array first, then
    # convert that array to a tensor in a single step.
    return {
        source_type: torch.Tensor(
            np.array([example[source_type] for example in list_data_dict])
        )
        for source_type in source_types
    }


In [10]:
def get_pitch_shifted_segment_samples(segment_samples: int, augmentations: Dict) -> int:
    r"""Get new segment samples depending on maximum pitch shift.

    When no pitch shift is configured (the 'pitch_shift' entry is missing,
    None, or an empty mapping) the segment length is returned unchanged.
    Otherwise the length is scaled by the factor that
    `get_pitch_shift_factor` returns for the largest configured shift.

    Args:
        segment_samples: int, number of samples of a training segment
        augmentations: Dict, may contain a 'pitch_shift' entry that maps
            source types to their maximum pitch shift

    Returns:
        ex_segment_samples: int
    """
    pitch_shift_dict = augmentations.get('pitch_shift')

    # Nothing to shift: an empty mapping would otherwise make max() fail.
    if not pitch_shift_dict:
        return segment_samples

    # The largest shift over all source types decides the segment length.
    max_pitch_shift = max(pitch_shift_dict.values())

    ex_segment_samples = int(segment_samples * get_pitch_shift_factor(max_pitch_shift))

    return ex_segment_samples

In [11]:
def get_data_module(
    workspace: str,
    config_yaml: str,
    num_workers: int,
    distributed: bool,
) -> DataModule:
    r"""Create data_module. Here is an example to fetch a mini-batch:

    code-block:: python

        data_module.setup()
        for batch_data_dict in data_module.train_dataloader():
            print(batch_data_dict.keys())
            break

    Args:
        workspace: str, directory that configs['train']['indexes_dict_path']
            is joined onto
        config_yaml: str, path of the YAML config file
        num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores
            for preparing data in parallel
        distributed: bool, forwarded to the DataModule

    Returns:
        data_module: DataModule

    Raises:
        AssertionError: if input and target data are paired while
            remixing_sources is not False.
    """
    configs = read_yaml(config_yaml)
    train_configs = configs['train']

    input_source_types = train_configs['input_source_types']
    target_source_types = train_configs['target_source_types']
    paired_input_target_data = train_configs['paired_input_target_data']
    indexes_dict_path = os.path.join(workspace, train_configs['indexes_dict_path'])
    sample_rate = train_configs['sample_rate']
    input_channels = train_configs['input_channels']
    segment_seconds = train_configs['segment_seconds']
    augmentations = train_configs['augmentations']
    batch_size = train_configs['batch_size']
    steps_per_epoch = train_configs['steps_per_epoch']

    segment_samples = int(segment_seconds * sample_rate)

    if paired_input_target_data:
        assert (
            augmentations['remixing_sources'] is False
        ), "Must set remixing_sources to False if input and target data are paired."

    # The sampler works with the pitch-shift adjusted segment length, while
    # the dataset below receives the original segment length.
    ex_segment_samples = get_pitch_shifted_segment_samples(
        segment_samples=segment_samples,
        augmentations=augmentations,
    )

    # sampler
    train_sampler = SegmentSampler(
        indexes_dict_path=indexes_dict_path,
        input_source_types=input_source_types,
        target_source_types=target_source_types,
        segment_samples=ex_segment_samples,
        remixing_sources=augmentations['remixing_sources'],
        mixaudio_dict=augmentations['mixaudio'],
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
    )

    # augmentor
    augmentor = Augmentor(augmentations=augmentations)

    # dataset
    train_dataset = Dataset(
        input_source_types=input_source_types,
        target_source_types=target_source_types,
        paired_input_target_data=paired_input_target_data,
        input_channels=input_channels,
        augmentor=augmentor,
        segment_samples=segment_samples,
    )

    # data module
    # Honour the caller's choice instead of always disabling distributed mode.
    data_module = DataModule(
        train_sampler=train_sampler,
        train_dataset=train_dataset,
        num_workers=num_workers,
        distributed=distributed,
    )

    return data_module

### Batch Data Preprocessor

In [12]:
from typing import Dict, List

import torch
import torch.nn as nn


class MixtureTargetBatchDataPreprocessor(nn.Module):
    def __init__(self, input_source_types: List[str], target_source_types: List[str]):
        r"""Batch data preprocessor that builds (mixture, target) pairs for
        training. The mixture is the sum of all input sources. When several
        target source types are given, their waveforms are concatenated along
        the channel dimension.

        Args:
            input_source_types: List[str], e.g., ['vocals', 'bass', ...]
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        super(MixtureTargetBatchDataPreprocessor, self).__init__()

        self.input_source_types = input_source_types
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> List[Dict]:
        r"""Build the model input and the training target from a mini-batch.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, input_channels, segment_samples),
                'vocals': (batch_size, input_channels, segment_samples),
                'bass': (batch_size, input_channels, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
            }
            output_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * output_channels, segment_samples)
            }
        """
        # Mixture: stack the input sources on a new axis and sum it away.
        input_sources = [
            batch_data_dict[source_type] for source_type in self.input_source_types
        ]
        stacked_sources = torch.stack(input_sources, dim=1)
        # stacked_sources: (batch_size, input_sources, input_channels, segment_samples)

        input_waveforms = stacked_sources.sum(dim=1)
        # input_waveforms: (batch_size, input_channels, segment_samples)

        # Targets: join all target sources along the channel axis.
        target_sources = [
            batch_data_dict[source_type] for source_type in self.target_source_types
        ]
        target_waveforms = torch.cat(target_sources, dim=1)
        # target_waveforms: (batch_size, target_sources_num * output_channels, segment_samples)

        return {'waveform': input_waveforms}, {'waveform': target_waveforms}


class MixtureTargetConditionalBatchDataPreprocessor:
    def __init__(self, input_source_types: List[str], target_source_types: List[str]):
        r"""Conditional single input single output (SISO) batch data
        preprocessor. Select one target source from several target sources as
        training target and prepare the corresponding conditional vector.

        Args:
            input_source_types: List[str], e.g., ['vocals', 'bass', ...]
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
        """
        self.input_source_types = input_source_types
        self.target_source_types = target_source_types

        self.target_sources_num = len(self.target_source_types)

    def __call__(self, batch_data_dict: Dict) -> List[Dict]:
        r"""Format waveforms and targets for training.

        The n-th example of the batch is assigned the target source with
        index `n % target_sources_num`, and its condition vector is the
        one-hot encoding of that index.

        Args:
            batch_data_dict: dict, e.g., {
                'mixture': (batch_size, input_channels, segment_samples),
                'vocals': (batch_size, input_channels, segment_samples),
                'bass': (batch_size, input_channels, segment_samples),
                ...,
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
                'condition': (batch_size, target_sources_num),
            }
            output_dict: dict, e.g., {
                'waveform': (batch_size, output_channels, segment_samples)
            }

        Raises:
            AssertionError: if the batch size is not a multiple of the number
                of target sources.
        """
        # The batch size is read from whichever source comes first.
        first_source_type = list(batch_data_dict.keys())[0]
        batch_size = batch_data_dict[first_source_type].shape[0]

        # Implicit string concatenation keeps the message free of the
        # indentation that a backslash continuation would embed in it.
        assert batch_size % self.target_sources_num == 0, (
            "Batch size should be "
            "evenly divided by target sources number."
        )

        # Get mixture. Sum waveforms all sources.
        stacked_sources = torch.stack(
            [batch_data_dict[source_type] for source_type in self.input_source_types],
            dim=1,
        )
        # stacked_sources: (batch_size, input_sources, input_channels, segment_samples)

        input_waveforms = torch.sum(stacked_sources, dim=1)
        # input_waveforms: (batch_size, input_channels, segment_samples)

        conditions = torch.zeros(
            batch_size, self.target_sources_num, device=input_waveforms.device
        )
        # conditions: (batch_size, target_sources_num)

        target_waveforms = []

        for n in range(batch_size):

            k = n % self.target_sources_num  # source class index
            source_type = self.target_source_types[k]

            target_waveforms.append(batch_data_dict[source_type][n])

            conditions[n, k] = 1

        # conditions will look like:
        # [[1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  [0, 0, 1, 0],
        #  [0, 0, 0, 1],
        #  [1, 0, 0, 0],
        #  [0, 1, 0, 0],
        #  ...,
        # ]

        target_waveforms = torch.stack(target_waveforms, dim=0)
        # target_waveforms: (batch_size, output_channels, segment_samples)

        input_dict = {
            'waveform': input_waveforms,
            'condition': conditions,
        }

        target_dict = {'waveform': target_waveforms}

        return input_dict, target_dict


class AmbisonicBinauralBatchDataPreprocessor(nn.Module):
    def __init__(self, input_source_types: List[str], target_source_types: List[str]):
        r"""Batch data preprocessor for ambisonic-to-binaural training. The
        'ambisonic' entry of a mini-batch becomes the input and the 'binaural'
        entry becomes the target.

        Args:
            input_source_types: List[str], e.g., ['ambisonic']
            target_source_types: List[str], e.g., ['binaural']
        """
        super(AmbisonicBinauralBatchDataPreprocessor, self).__init__()

        self.input_source_types = input_source_types
        self.target_source_types = target_source_types

    def __call__(self, batch_data_dict: Dict) -> List[Dict]:
        r"""Pick the input and target waveforms out of a mini-batch.

        The keys 'ambisonic' and 'binaural' are fixed here; the source type
        lists stored on the instance are not consulted.

        Args:
            batch_data_dict: dict, e.g., {
                'ambisonic': (batch_size, input_channels, segment_samples),
                'binaural': (batch_size, output_channels, segment_samples),
            }

        Returns:
            input_dict: dict, e.g., {
                'waveform': (batch_size, input_channels, segment_samples),
            }
            output_dict: dict, e.g., {
                'waveform': (batch_size, output_channels, segment_samples)
            }
        """
        ambisonic_waveforms = batch_data_dict['ambisonic']
        binaural_waveforms = batch_data_dict['binaural']

        return {'waveform': ambisonic_waveforms}, {'waveform': binaural_waveforms}


def get_batch_data_preprocessor_class(batch_data_preprocessor_type: str) -> nn.Module:
    r"""Look up the batch data preprocessor class for a type name.

    Args:
        batch_data_preprocessor_type: str, one of 'MixtureTarget',
            'MixtureTargetConditional' or 'AmbisonicBinaural'

    Returns:
        The matching preprocessor class (the class itself, not an instance).

    Raises:
        NotImplementedError: if the type name is not one of the above.
    """
    if batch_data_preprocessor_type == 'MixtureTarget':
        return MixtureTargetBatchDataPreprocessor

    if batch_data_preprocessor_type == 'MixtureTargetConditional':
        return MixtureTargetConditionalBatchDataPreprocessor

    if batch_data_preprocessor_type == 'AmbisonicBinaural':
        return AmbisonicBinauralBatchDataPreprocessor

    # No known preprocessor goes by this name.
    raise NotImplementedError


### PyTorch modules

In [13]:
from typing import List, NoReturn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_embedding(layer: nn.Module) -> NoReturn:
    r"""Initialize a layer in place: weights are drawn uniformly from
    [-1, 1] and the bias, when the layer has one, is set to zero."""
    nn.init.uniform_(layer.weight, -1.0, 1.0)

    # Covers both "no bias attribute" and "bias is None".
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.0)


def init_layer(layer: nn.Module) -> NoReturn:
    r"""Initialize a Linear or Convolutional layer in place: Xavier-uniform
    weights and, when the layer has a bias, a zero bias."""
    nn.init.xavier_uniform_(layer.weight)

    # Layers without a bias (attribute missing or set to None) are done.
    if getattr(layer, "bias", None) is None:
        return

    layer.bias.data.fill_(0.0)


def init_bn(bn: nn.Module) -> NoReturn:
    r"""Reset a Batchnorm layer in place: zero bias and running mean, unit
    weight and running variance."""
    resets = (
        (bn.bias, 0.0),
        (bn.weight, 1.0),
        (bn.running_mean, 0.0),
        (bn.running_var, 1.0),
    )

    for tensor, value in resets:
        tensor.data.fill_(value)


def act(x: torch.Tensor, activation: str) -> torch.Tensor:
    r"""Apply the activation selected by name.

    "relu" and "leaky_relu" operate in place on `x`; "swish" returns a new
    tensor.

    Args:
        x: torch.Tensor
        activation: str, "relu" | "leaky_relu" | "swish"

    Returns:
        The activated tensor.

    Raises:
        Exception: if `activation` is none of the supported names.
    """
    if activation == "relu":
        return F.relu_(x)

    if activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)

    if activation == "swish":
        return x * torch.sigmoid(x)

    raise Exception("Incorrect activation!")


class Base:
    def __init__(self):
        r"""Mixin with helpers for extracting magnitude spectrogram, cos, and
        sin of the STFT. The class using it must provide `self.stft`, which is
        called with a waveform tensor and must return a (real, imag) pair."""
        pass

    def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        r"""Calculate the magnitude spectrogram.

        Args:
            input: (batch_size, segments_num)
            eps: float, lower bound applied to the power before the square root

        Returns:
            spectrogram: magnitude of the STFT, with the same shape as the
                tensors returned by `self.stft`
        """
        real, imag = self.stft(input)
        power = real ** 2 + imag ** 2
        return torch.clamp(power, eps, np.inf) ** 0.5

    def spectrogram_phase(
        self, input: torch.Tensor, eps: float = 0.0
    ) -> List[torch.Tensor]:
        r"""Calculate the magnitude, cos, and sin of the STFT of input.

        With the default eps of 0.0, bins whose magnitude is exactly zero
        produce NaN in cos and sin because of the division by the magnitude.

        Args:
            input: (batch_size, segments_num)
            eps: float, lower bound applied to the power before the square root

        Returns:
            mag, cos, sin: each with the same shape as the tensors returned
                by `self.stft`
        """
        real, imag = self.stft(input)
        power = real ** 2 + imag ** 2
        mag = torch.clamp(power, eps, np.inf) ** 0.5

        # Unit phasor components of every STFT bin.
        return mag, real / mag, imag / mag

    def wav_to_spectrogram_phase(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> List[torch.Tensor]:
        r"""Convert multi-channel waveforms to magnitude, cos, and sin of STFT.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: float

        Outputs:
            mag: (batch_size, channels_num, time_steps, freq_bins)
            cos: (batch_size, channels_num, time_steps, freq_bins)
            sin: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, channels_num, segment_samples = input.shape

        # Fold the channels into the batch axis so that every channel is
        # handed to the STFT as an individual (n, segments_num) waveform.
        flat_waveforms = input.reshape(batch_size * channels_num, segment_samples)

        mag, cos, sin = self.spectrogram_phase(flat_waveforms, eps=eps)
        # mag, cos, sin are unpacked as 4-D below; the first two axes are
        # replaced by (batch_size, channels_num).

        _, _, time_steps, freq_bins = mag.shape
        unfolded_shape = (batch_size, channels_num, time_steps, freq_bins)

        return (
            mag.reshape(unfolded_shape),
            cos.reshape(unfolded_shape),
            sin.reshape(unfolded_shape),
        )

    def wav_to_spectrogram(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> List[torch.Tensor]:
        r"""Convert multi-channel waveforms to magnitude spectrograms.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: float

        Outputs:
            mag: (batch_size, channels_num, time_steps, freq_bins)
        """
        # Only the magnitude is needed; cos and sin are discarded.
        return self.wav_to_spectrogram_phase(input, eps)[0]


class Subband:
    def __init__(self, subbands_num: int):
        r"""Warning!! This class is not used!!

        Splits a time-frequency representation into subbands along the
        frequency axis. This does not work as well as [1], which splits
        subbands in the time-domain. Please refer to [1] for the formal
        implementation.

        [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
        accompaniment separation on high resolution music." arXiv preprint arXiv:2008.05216 (2020).

        Args:
            subbands_num: int, e.g., 4
        """
        self.subbands_num = subbands_num

    def analysis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Split a time-frequency representation into subbands and stack the
        subbands along the channel axis.

        Args:
            x: (batch_size, channels_num, time_steps, freq_bins)

        Returns:
            output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
        """
        batch_size, channels_num, time_steps, freq_bins = x.shape
        subband_freq_bins = freq_bins // self.subbands_num

        # Cut the frequency axis into subbands_num consecutive pieces.
        bands = x.reshape(
            batch_size,
            channels_num,
            time_steps,
            self.subbands_num,
            subband_freq_bins,
        )
        # bands: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)

        # Bring the subband axis next to the channel axis so both can merge.
        bands = bands.permute(0, 1, 3, 2, 4)
        # bands: (batch_size, channels_num, subbands_num, time_steps, freq_bins // subbands_num)

        return bands.reshape(
            batch_size,
            channels_num * self.subbands_num,
            time_steps,
            subband_freq_bins,
        )

    def synthesis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Merge subband time-frequency representations back into the
        original time-frequency representation (inverse of `analysis`).

        Args:
            x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)

        Returns:
            output: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, subband_channels_num, time_steps, subband_freq_bins = x.shape

        channels_num = subband_channels_num // self.subbands_num
        freq_bins = subband_freq_bins * self.subbands_num

        # Separate the merged channel axis into (channels, subbands).
        bands = x.reshape(
            batch_size,
            channels_num,
            self.subbands_num,
            time_steps,
            subband_freq_bins,
        )
        # bands: (batch_size, channels_num, subbands_num, time_steps, freq_bins // subbands_num)

        # Put the subband axis right before the frequency axis again.
        bands = bands.permute(0, 1, 3, 2, 4)
        # bands: (batch_size, channels_num, time_steps, subbands_num, freq_bins // subbands_num)

        return bands.reshape(batch_size, channels_num, time_steps, freq_bins)


### Model

In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from inplace_abn.abn import InPlaceABNSync
from torchlibrosa.stft import ISTFT, STFT, magphase

# from bytesep.models.pytorch_modules import Base, init_bn, init_layer


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        r"""Residual block made of two convolutions with a skip connection.

        Args:
            in_channels: int, channels of the block input
            out_channels: int, channels of the block output
            kernel_size: pair of ints, kernel size of both convolutions
            activation: str, stored on the instance; the activations applied
                inside this block are fixed to leaky ReLU
            momentum: float, momentum of the normalization layers
        """
        super(ConvBlockRes, self).__init__()

        self.activation = activation

        # Half of the kernel size on each axis.
        half_kernel_0 = kernel_size[0] // 2
        half_kernel_1 = kernel_size[1] // 2
        padding = [half_kernel_0, half_kernel_1]

        # ABN is not used for bn1 because we found using abn1 will degrade performance.
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        self.abn2 = InPlaceABNSync(
            num_features=out_channels, momentum=momentum, activation='leaky_relu'
        )

        # Both convolutions share everything except their input channels.
        shared_conv_kwargs = dict(
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv1 = nn.Conv2d(in_channels=in_channels, **shared_conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=out_channels, **shared_conv_kwargs)

        # A 1x1 convolution adapts the skip path when the channel count changes.
        self.is_shortcut = in_channels != out_channels

        if self.is_shortcut:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )

        self.init_weights()

    def init_weights(self):
        r"""Initialize bn1, both convolutions and, if present, the shortcut."""
        init_bn(self.bn1)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        r"""Run bn1 -> leaky ReLU -> conv1 -> abn2 -> conv2 and add the skip
        path (the input itself, or its 1x1 projection)."""
        out = self.bn1(x)
        out = F.leaky_relu_(out, negative_slope=0.01)
        out = self.conv1(out)
        out = self.abn2(out)
        out = self.conv2(out)

        skip = self.shortcut(x) if self.is_shortcut else x

        return skip + out


class EncoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, downsample, activation, momentum
    ):
        r"""Encoder block: four residual blocks (8 convolutional layers)
        followed by average pooling.

        Args:
            in_channels: int, channels of the block input
            out_channels: int, channels produced by every residual block
            kernel_size: pair of ints, kernel size of the convolutions
            downsample: pair of ints, kernel size of the average pooling
            activation: str, passed on to the residual blocks
            momentum: float, passed on to the residual blocks
        """
        super(EncoderBlockRes4B, self).__init__()

        # Only the first residual block changes the number of channels.
        shared_args = (out_channels, kernel_size, activation, momentum)

        self.conv_block1 = ConvBlockRes(in_channels, *shared_args)
        self.conv_block2 = ConvBlockRes(out_channels, *shared_args)
        self.conv_block3 = ConvBlockRes(out_channels, *shared_args)
        self.conv_block4 = ConvBlockRes(out_channels, *shared_args)

        self.downsample = downsample

    def forward(self, x):
        r"""Return (pooled features, features before pooling)."""
        encoder = x

        for conv_block in (
            self.conv_block1,
            self.conv_block2,
            self.conv_block3,
            self.conv_block4,
        ):
            encoder = conv_block(encoder)

        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)

        return encoder_pool, encoder


class DecoderBlockRes4B(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, upsample, activation, momentum
    ):
        r"""Decoder block: 1 transpose convolutional layer followed by four
        residual blocks (8 convolutional layers).

        Args:
            in_channels: int, channels of the tensor to be upsampled
            out_channels: int, channels produced by the block
            kernel_size: pair of ints, kernel size of the residual blocks
            upsample: pair of ints, used as both kernel size and stride of
                the transpose convolution
            activation: str, passed on to the residual blocks
            momentum: float, momentum of the normalization layers
        """
        super(DecoderBlockRes4B, self).__init__()

        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        # Normalizes the input of the transpose convolution.
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        # The first residual block receives the upsampled tensor concatenated
        # with the skip tensor, hence twice the channels.
        shared_args = (out_channels, kernel_size, activation, momentum)

        self.conv_block2 = ConvBlockRes(out_channels * 2, *shared_args)
        self.conv_block3 = ConvBlockRes(out_channels, *shared_args)
        self.conv_block4 = ConvBlockRes(out_channels, *shared_args)
        self.conv_block5 = ConvBlockRes(out_channels, *shared_args)

        self.init_weights()

    def init_weights(self):
        r"""Initialize bn1 and the transpose convolution."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, input_tensor, concat_tensor):
        r"""Upsample `input_tensor`, concatenate `concat_tensor` along the
        channel axis and run the four residual blocks."""
        x = self.bn1(input_tensor)
        x = F.relu_(x)
        x = self.conv1(x)

        x = torch.cat((x, concat_tensor), dim=1)

        for conv_block in (
            self.conv_block2,
            self.conv_block3,
            self.conv_block4,
            self.conv_block5,
        ):
            x = conv_block(x)

        return x


class ResUNet143_DecouplePlusInplaceABN_ISMIR2021(nn.Module, Base):
    def __init__(self, input_channels, target_sources_num):
        super(ResUNet143_DecouplePlusInplaceABN_ISMIR2021, self).__init__()

        self.input_channels = input_channels
        self.target_sources_num = target_sources_num

        window_size = 2048
        hop_size = 441
        center = True
        pad_mode = 'reflect'
        window = 'hann'
        activation = 'leaky_relu'
        momentum = 0.01

        self.subbands_num = 1

        assert (
            self.subbands_num == 1
        ), "Using subbands_num > 1 on spectrogram \
            will lead to unexpected performance sometimes. Suggest to use \
            subband method on waveform."

        # Downsample rate along the time axis.
        self.K = 4  # outputs: |M|, cos∠M, sin∠M, Q
        self.time_downsample_ratio = 2 ** 5  # This number equals 2^{#encoder_blcoks}

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.encoder_block1 = EncoderBlockRes4B(
            in_channels=input_channels * self.subbands_num,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block2 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block3 = EncoderBlockRes4B(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block4 = EncoderBlockRes4B(
            in_channels=128,
            out_channels=256,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block5 = EncoderBlockRes4B(
            in_channels=256,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block6 = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7a = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7b = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7c = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7d = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block1 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block2 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block3 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=256,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block4 = DecoderBlockRes4B(
            in_channels=256,
            out_channels=128,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block5 = DecoderBlockRes4B(
            in_channels=128,
            out_channels=64,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block6 = DecoderBlockRes4B(
            in_channels=64,
            out_channels=32,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv_block1 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=target_sources_num
            * input_channels
            * self.K
            * self.subbands_num,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert feature maps to waveform.

        Args:
            input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
            sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
            cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)

        Outputs:
            waveform: (batch_size, target_sources_num * input_channels, segment_samples)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.input_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, input_channles, K, time_steps, freq_bins)

        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        linear_mag = x[:, :, :, 3, :, :]
        # mask_cos, mask_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)
        # out_sin: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

        # Calculate |Y|.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
        # out_mag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, input_channles, time_steps, freq_bins)

        # Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
        shape = (
            batch_size * self.target_sources_num * self.input_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * input_channels, segments_num)

        # Reshape.
        waveform = x.reshape(
            batch_size, self.target_sources_num * self.input_channels, audio_length
        )
        # (batch_size, target_sources_num * input_channels, segments_num)

        return waveform

    def forward(self, input_dict):
        r"""Separate the mixture waveform held in ``input_dict``.

        The mixture is turned into a magnitude spectrogram plus phase terms,
        the magnitude is run through the residual UNet, and the resulting
        feature maps are converted back to waveforms by
        ``self.feature_maps_to_wav``.

        Args:
            input_dict: dict, e.g., {
                waveform: (batch_size, input_channels, segment_samples),
                ...,
            }

        Outputs:
            output_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * input_channels, segment_samples),
            }
        """
        mixture_wave = input_dict['waveform']
        # (batch_size, input_channels, segment_samples)

        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixture_wave)
        # mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)

        # bn0 normalizes over dim 1, so the frequency axis is swapped into that
        # position for the call and swapped back afterwards.
        feat = self.bn0(mag.transpose(1, 3)).transpose(1, 3)
        # feat: (batch_size, input_channels, time_steps, freq_bins)

        # Zero-pad the end of the time axis up to the next multiple of the
        # time downsample ratio.
        ratio = self.time_downsample_ratio
        origin_len = feat.shape[2]
        padded_len = int(np.ceil(origin_len / ratio)) * ratio
        feat = F.pad(feat, pad=(0, 0, 0, padded_len - origin_len))
        # (batch_size, channels, padded_time_steps, freq_bins)

        # Drop the last frequency bin so the count is even, e.g., 1025 -> 1024.
        feat = feat[..., :-1]

        if self.subbands_num > 1:
            feat = self.subband.analysis(feat)
            # (bs, input_channels, T, F'), where F' = F // subbands_num

        # UNet encoder: every block returns (pooled, unpooled); the unpooled
        # maps are kept as skip connections for the decoder.
        skips = []
        for encoder in (
            self.encoder_block1,
            self.encoder_block2,
            self.encoder_block3,
            self.encoder_block4,
            self.encoder_block5,
            self.encoder_block6,
        ):
            feat, skip = encoder(feat)
            skips.append(skip)

        # Bottleneck blocks: only the first element of each returned pair is used.
        for center_block in (
            self.conv_block7a,
            self.conv_block7b,
            self.conv_block7c,
            self.conv_block7d,
        ):
            feat, _ = center_block(feat)

        # UNet decoder: consume the skip connections deepest-first.
        for decoder in (
            self.decoder_block1,
            self.decoder_block2,
            self.decoder_block3,
            self.decoder_block4,
            self.decoder_block5,
            self.decoder_block6,
        ):
            feat = decoder(feat, skips.pop())

        feat, _ = self.after_conv_block1(feat)

        feat = self.after_conv2(feat)
        # (batch_size, target_sources_num * input_channels * self.K * subbands_num, T, F')

        if self.subbands_num > 1:
            feat = self.subband.synthesis(feat)
            # (batch_size, target_sources_num * input_channels * self.K, T, F)

        # Undo the shape changes made before the UNet: restore the dropped
        # frequency bin (e.g., 1024 -> 1025) and remove the time padding.
        feat = F.pad(feat, pad=(0, 1))
        feat = feat[:, :, :origin_len, :]
        # (batch_size, target_sources_num * input_channels * self.K, T, F)

        audio_length = mixture_wave.shape[2]

        separated_audio = self.feature_maps_to_wav(
            feat, mag, sin_in, cos_in, audio_length
        )
        # separated_audio: (batch_size, target_sources_num * input_channels, segments_num)

        return {'waveform': separated_audio}


In [15]:
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    """
    Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation
    Link - https://arxiv.org/abs/1505.04597

    >>> UNet(num_classes=2, num_layers=3)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    UNet(
      (layers): ModuleList(
        (0): DoubleConv(...)
        (1): Down(...)
        (2): Down(...)
        (3): Up(...)
        (4): Up(...)
        (5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    """

    def __init__(
        self,
        num_classes: int = 19,
        num_layers: int = 5,
        features_start: int = 64,
        bilinear: bool = False,
        input_channels: int = 3,
    ):
        """
        Args:
            num_classes: Number of output classes required (default 19 for KITTI dataset)
            num_layers: Number of layers in each side of U-net
            features_start: Number of features in first layer
            bilinear: Whether to use bilinear interpolation or transposed convolutions for upsampling.
            input_channels: Number of channels of the input images. Added as the
                last parameter (default 3) so existing positional calls keep working.

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        # With fewer than one layer the first DoubleConv would be treated as an
        # Up layer in ``forward`` and fail there; reject it up front instead.
        if num_layers < 1:
            raise ValueError(
                "num_layers = {}, expected: num_layers > 0".format(num_layers)
            )

        super().__init__()
        self.num_layers = num_layers

        layers = [DoubleConv(input_channels, features_start)]

        # Contracting path: each Down doubles the feature count.
        feats = features_start
        for _ in range(num_layers - 1):
            layers.append(Down(feats, feats * 2))
            feats *= 2

        # Expanding path: each Up halves the feature count again.
        for _ in range(num_layers - 1):
            layers.append(Up(feats, feats // 2, bilinear))
            feats //= 2

        # 1x1 convolution mapping the final features to class scores.
        layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))

        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Run the U-Net and return the output of the final 1x1 convolution."""
        xi = [self.layers[0](x)]
        # Down path: keep every intermediate output for the skip connections.
        for layer in self.layers[1:self.num_layers]:
            xi.append(layer(xi[-1]))
        # Up path: xi[-1] is overwritten in place with the decoder state while
        # xi[-2 - i] walks backwards through the stored encoder outputs.
        for i, layer in enumerate(self.layers[self.num_layers:-1]):
            xi[-1] = layer(xi[-1], xi[-2 - i])
        return self.layers[-1](xi[-1])


class DoubleConv(nn.Module):
    """
    Two stacked "3x3 convolution -> batch norm -> ReLU" stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch`` channels. Padding of 1 keeps the spatial size unchanged.

    >>> DoubleConv(4, 4)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    DoubleConv(
      (net): Sequential(...)
    )
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        # Same three-module stage twice; only the stage's input width differs.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        out = self.net(x)
        return out


class Down(nn.Module):
    """
    Downsampling step: a 2x2 max pool (stride 2) followed by a DoubleConv.

    >>> Down(4, 8)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    Down(
      (net): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (net): Sequential(...)
        )
      )
    )
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        conv = DoubleConv(in_ch, out_ch)
        self.net = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and run the double convolution on the result."""
        out = self.net(x)
        return out


class Up(nn.Module):
    """
    Decoder step: upsample ``x1`` (bilinear interpolation + 1x1 conv, or a
    transposed convolution), pad it to the spatial size of the skip tensor
    ``x2``, concatenate both along the channel axis and apply a DoubleConv.

    >>> Up(8, 4)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    Up(
      (upsample): ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2))
      (conv): DoubleConv(
        (net): Sequential(...)
      )
    )
    """

    def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
        super().__init__()
        # Both variants halve the channel count so that concatenating with the
        # skip tensor gives ``in_ch`` channels again.
        half_ch = in_ch // 2
        if not bilinear:
            self.upsample = nn.ConvTranspose2d(in_ch, half_ch, kernel_size=2, stride=2)
        else:
            self.upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(in_ch, half_ch, kernel_size=1),
            )

        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2`` and fuse the two tensors."""
        upsampled = self.upsample(x1)

        # Size gap between the skip tensor and the upsampled tensor.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]

        # Split each gap between the two sides; the odd remainder goes to the
        # right / bottom edge.
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        # Skip tensor first, then the upsampled tensor, along the channel axis.
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


In [16]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import os.path as op
import pathlib
import os
from scipy.io import loadmat


def load_mat2numpy(fname=""):
    '''Load a MATLAB ``.mat`` file through ``scipy.io.loadmat``.

    Args:
        fname: path to the ``.mat`` file. An empty value means "nothing to load".

    Returns:
        The dict produced by ``loadmat``, or None when ``fname`` is empty.
    '''
    if len(fname) > 0:
        return loadmat(fname)
    return None


class PQMF(nn.Module):
    r"""Subband analysis / synthesis filter bank (PQMF) built from fixed Conv1d filters.

    ``analysis`` splits every waveform channel into ``N`` subband signals with
    a strided convolution, and ``synthesis`` merges every group of ``N``
    subband signals back into one waveform channel. The filter coefficients
    are read from ``.mat`` files and frozen (``requires_grad = False``).
    """

    def __init__(self, N, M, project_root):
        r"""Build the filter bank and load its coefficients.

        Args:
            N: int, number of subbands.
            M: int, number of filter taps.
            project_root: unused; kept so existing callers keep working. The
                filters are always read from ``~/bytesep_data/filters``.

        Raises:
            OSError, subprocess.CalledProcessError: if a missing filter file
                cannot be downloaded with ``wget``.
        """
        super().__init__()
        import warnings

        self.N = N  # nsubband
        self.M = M  # nfilter
        # Explicit check instead of ``assert`` inside a bare ``except``: asserts
        # vanish under ``python -O`` and the bare except hid unrelated errors.
        if (N, M) not in [(8, 64), (4, 64), (2, 64)]:
            warnings.warn(
                "{} subbands and {} filter taps is not a supported configuration".format(
                    N, M
                )
            )
        self.pad_samples = 64
        self.name = str(N) + "_" + str(M) + ".mat"
        self.ana_conv_filter = nn.Conv1d(
            1, out_channels=N, kernel_size=M, stride=N, bias=False
        )

        filters_dir = '{}/bytesep_data/filters'.format(str(pathlib.Path.home()))
        analysis_path = op.join(filters_dir, "f_" + self.name)
        synthesis_path = op.join(filters_dir, "h_" + self.name)

        # Fetch exactly the files loaded below (previously the 4_64 pair was
        # downloaded regardless of N and M).
        # NOTE(review): only the 4_64 files are known to be hosted remotely;
        # other configurations now fail with a download error - confirm.
        for _path in (analysis_path, synthesis_path):
            if not os.path.isfile(_path):
                self._download_filter(op.basename(_path), _path)

        # Analysis filters: scale by 1 / N, reverse along the last axis and
        # reshape to the Conv1d weight layout (out_channels=N, in_channels=1, M).
        data = load_mat2numpy(analysis_path)
        data = data['f'].astype(np.float32) / N
        data = np.flipud(data.T).T
        data = np.reshape(data, (N, 1, M)).copy()
        self.ana_pad = nn.ConstantPad1d((M - N, 0), 0)
        # load_state_dict (rather than copy_) keeps the strict shape check.
        self.ana_conv_filter.load_state_dict({'weight': torch.from_numpy(data)})

        self.syn_pad = nn.ConstantPad1d((0, M // N - 1), 0)
        self.syn_conv_filter = nn.Conv1d(
            N, out_channels=N, kernel_size=M // N, stride=1, bias=False
        )
        # Synthesis filters: rearranged to the Conv1d weight layout (N, N, M // N).
        gk = load_mat2numpy(synthesis_path)
        gk = gk['h'].astype(np.float32)
        gk = np.transpose(np.reshape(gk, (N, M // N, N)), (1, 0, 2)) * N
        gk = np.transpose(gk[::-1, :, :], (2, 1, 0)).copy()
        self.syn_conv_filter.load_state_dict({'weight': torch.from_numpy(gk)})

        # The filters are fixed coefficients, not trainable weights.
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _download_filter(file_name, local_path):
        r"""Download one filter file with ``wget`` and store it at ``local_path``.

        The command is passed as an argument list (no shell) and its exit
        status is checked. On failure the partially written file is removed so
        that the next run retries instead of loading a truncated file.
        """
        import subprocess

        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        remote_path = "https://zenodo.org/record/5513378/files/{}?download=1".format(
            file_name
        )
        try:
            subprocess.run(["wget", "-O", local_path, remote_path], check=True)
        except (OSError, subprocess.CalledProcessError):
            if os.path.isfile(local_path):
                os.remove(local_path)
            raise

    def __analysis_channel(self, inputs):
        # inputs: (batch_size, 1, samples) -> (batch_size, N, subband_samples)
        return self.ana_conv_filter(self.ana_pad(inputs))

    def __systhesis_channel(self, inputs):
        # inputs: (batch_size, N, subband_samples). The N filter outputs are
        # interleaved along time, giving (batch_size, 1, N * subband_samples).
        ret = self.syn_conv_filter(self.syn_pad(inputs)).permute(0, 2, 1)
        return torch.reshape(ret, (ret.shape[0], 1, -1))

    def analysis(self, inputs):
        r"""Split every channel of a waveform into ``N`` subband signals.

        Args:
            inputs: (batch_size, channels, samples)

        Returns:
            (batch_size, channels * N, subband_samples), where the subbands of
            channel ``c`` occupy indices ``c * N`` to ``c * N + N - 1``.
            Returns None when ``inputs`` has zero channels (unchanged legacy
            behaviour).
        """
        # Zero-pad the end; ``synthesis`` removes the same number of samples.
        inputs = F.pad(inputs, ((0, self.pad_samples)))
        subbands = [
            self.__analysis_channel(inputs[:, i : i + 1, :])
            for i in range(inputs.size()[1])
        ]
        if not subbands:
            return None
        # Concatenate once instead of growing the result inside the loop.
        return torch.cat(subbands, dim=1)

    def synthesis(self, data):
        r"""Merge every group of ``N`` subband signals back into one channel.

        Args:
            data: (batch_size, channels * N, subband_samples)

        Returns:
            (batch_size, channels, N * subband_samples - pad_samples)

        Raises:
            ValueError: if the channel count is zero or not a multiple of ``N``.
        """
        channels_num = data.size()[1]
        if channels_num == 0 or channels_num % self.N != 0:
            raise ValueError(
                "Expected a positive multiple of {} subband channels, got {}".format(
                    self.N, channels_num
                )
            )
        waveforms = [
            self.__systhesis_channel(data[:, i : i + self.N, :])
            for i in range(0, channels_num, self.N)
        ]
        ret = torch.cat(waveforms, dim=1)
        # Drop the samples that correspond to the padding added in ``analysis``.
        return ret[..., : -self.pad_samples]

    def forward(self, inputs):
        r"""Apply the analysis filters to a (batch_size, 1, samples) input.

        Unlike ``analysis``, no ``pad_samples`` padding is added.
        """
        return self.ana_conv_filter(self.ana_pad(inputs))

In [17]:
from typing import Dict, List, NoReturn, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import ISTFT, STFT, magphase

# from bytesep.models.pytorch_modules import Base, init_bn, init_layer
# from bytesep.models.subband_tools.pqmf import PQMF


class ConvBlockRes(nn.Module):
    r"""Pre-activation residual block with two convolutions.

    Each convolution is preceded by batch normalization and a leaky ReLU. The
    block input is added to the result, going through a 1x1 convolution first
    when the input and output channel counts differ.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        activation: str,
        momentum: float,
    ):
        r"""Residual block.

        Args:
            in_channels: int, number of input feature maps.
            out_channels: int, number of output feature maps.
            kernel_size: tuple, (time, freq) kernel size of both convolutions.
                Padding is ``kernel_size // 2`` per axis, which keeps the
                spatial size for odd kernel sizes.
            activation: str, stored on the module but not consulted:
                ``forward`` always applies a leaky ReLU with slope 0.01.
            momentum: float, momentum of the batch normalization layers.
        """
        super(ConvBlockRes, self).__init__()

        self.activation = activation
        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        # A 1x1 projection is only needed when the residual branch changes
        # the channel count; otherwise the input is added unchanged.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.init_weights()

    def init_weights(self) -> None:
        r"""Initialize the batch norm and convolution layers of this block.

        Annotated ``None`` rather than ``NoReturn``: the method returns
        normally, whereas ``NoReturn`` marks functions that never return.
        """
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        # The in-place leaky ReLU acts on the fresh batch norm output, so
        # ``input_tensor`` itself is left untouched for the residual sum.
        x = self.conv1(F.leaky_relu_(self.bn1(input_tensor), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(input_tensor) + x
        else:
            return input_tensor + x


class EncoderBlockRes4B(nn.Module):
    r"""Encoder block: four residual blocks followed by average pooling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        downsample: Tuple,
        activation: str,
        momentum: float,
    ):
        r"""Encoder block, contains 8 convolutional layers.

        Args:
            in_channels: int, number of input feature maps.
            out_channels: int, number of output feature maps.
            kernel_size: tuple, kernel size passed to every ConvBlockRes.
            downsample: tuple, (time, freq) kernel size of the average pooling;
                (1, 1) leaves the resolution unchanged.
            activation: str, forwarded to every ConvBlockRes.
            momentum: float, batch normalization momentum.
        """
        super(EncoderBlockRes4B, self).__init__()

        # Only the first block changes the channel count.
        self.conv_block1 = ConvBlockRes(
            in_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes(
            out_channels, out_channels, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, input_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)

        Returns:
            encoder_pool: (batch_size, output_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            encoder: (batch_size, output_feature_maps, time_steps, freq_bins)

            Both tensors are returned as a 2-tuple ``(encoder_pool, encoder)``;
            ``encoder`` is the un-pooled output used as a skip connection.
        """
        encoder = self.conv_block1(input_tensor)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder


class DecoderBlockRes4B(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        upsample: Tuple,
        activation: str,
        momentum: float,
    ):
        r"""Decoder block: one transposed convolution for upsampling followed
        by four residual blocks (8 convolutional layers).

        Args:
            in_channels: int, feature maps of the tensor to be upsampled.
            out_channels: int, feature maps produced by the block; the skip
                tensor concatenated in ``forward`` must have the same count.
            kernel_size: tuple, kernel size of the residual blocks.
            upsample: tuple, (time, freq) factor, used as both kernel size and
                stride of the transposed convolution.
            activation: str, forwarded to the residual blocks.
            momentum: float, batch normalization momentum.
        """
        super(DecoderBlockRes4B, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=upsample,
            stride=upsample,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        def residual_block(block_in_channels: int) -> ConvBlockRes:
            # All four residual blocks share everything but the input width.
            return ConvBlockRes(
                block_in_channels, out_channels, kernel_size, activation, momentum
            )

        # The first block sees the upsampled tensor and the skip tensor
        # stacked on the channel axis, hence twice the channels.
        self.conv_block2 = residual_block(out_channels * 2)
        self.conv_block3 = residual_block(out_channels)
        self.conv_block4 = residual_block(out_channels)
        self.conv_block5 = residual_block(out_channels)

        self.init_weights()

    def init_weights(self):
        r"""Initialize the batch norm and the transposed convolution."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(
        self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor
    ) -> torch.Tensor:
        r"""Upsample ``input_tensor``, merge it with ``concat_tensor`` and refine.

        Args:
            input_tensor: (batch_size, input_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            concat_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        upsampled = self.conv1(F.relu_(self.bn1(input_tensor)))
        # upsampled: (batch_size, output_feature_maps, time_steps, freq_bins)

        merged = torch.cat((upsampled, concat_tensor), dim=1)
        # merged: (batch_size, output_feature_maps * 2, time_steps, freq_bins)

        for block in (
            self.conv_block2,
            self.conv_block3,
            self.conv_block4,
            self.conv_block5,
        ):
            merged = block(merged)
        # merged: (batch_size, output_feature_maps, time_steps, freq_bins)

        return merged


class ResUNet143_Subbandtime(nn.Module, Base):
    r"""Residual UNet source separator operating on time-domain subbands.

    The input waveform is split into ``subbands_num`` subband waveforms with a
    PQMF filter bank, each subband is converted to a spectrogram, and the
    stacked subband magnitudes are processed by a UNet built from
    EncoderBlockRes4B / DecoderBlockRes4B blocks. For every subband the network
    predicts ``K`` feature maps (magnitude mask, two phase-mask components and
    a linear magnitude term) that are turned back into a subband waveform; the
    subband waveforms are finally merged again by the PQMF synthesis filters.
    """

    def __init__(
        self, input_channels: int, output_channels: int, target_sources_num: int
    ):
        r"""Build all layers of the model.

        Args:
            input_channels: int, channels of the input waveform.
            output_channels: int, channels of every separated source. The
                broadcasting in ``feature_maps_to_wav`` expects this to match
                ``input_channels``.
            target_sources_num: int, number of sources to separate.
        """
        super(ResUNet143_Subbandtime, self).__init__()

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.target_sources_num = target_sources_num

        # STFT settings applied to the subband waveforms.
        window_size = 512  # 2048 // 4
        hop_size = 110  # 441 // 4
        center = True
        pad_mode = "reflect"
        window = "hann"
        activation = "leaky_relu"
        momentum = 0.01

        self.subbands_num = 4
        self.K = 4  # outputs: |M|, cos∠M, sin∠M, Q

        # Encoder blocks 1-5 halve the time axis; block 6 only halves frequency.
        self.time_downsample_ratio = 2 ** 5  # This number equals 2^{number of time-downsampling encoder blocks}

        # NOTE: PQMF as defined in this file does not read ``project_root``.
        self.pqmf = PQMF(
            N=self.subbands_num,
            M=64,
            project_root='bytesep/models/subband_tools/filters',
        )

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # One batch norm feature per frequency bin (window_size // 2 + 1 bins).
        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        # Encoder: the first block takes all subbands of all input channels.
        self.encoder_block1 = EncoderBlockRes4B(
            in_channels=self.input_channels * self.subbands_num,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block2 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=64,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block3 = EncoderBlockRes4B(
            in_channels=64,
            out_channels=128,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block4 = EncoderBlockRes4B(
            in_channels=128,
            out_channels=256,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.encoder_block5 = EncoderBlockRes4B(
            in_channels=256,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        # Downsamples the frequency axis only.
        self.encoder_block6 = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        # Bottleneck: four encoder blocks without downsampling.
        self.conv_block7a = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7b = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7c = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        self.conv_block7d = EncoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )
        # Decoder: mirrors the encoder; the first block upsamples frequency only.
        self.decoder_block1 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(1, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block2 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block3 = DecoderBlockRes4B(
            in_channels=384,
            out_channels=256,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block4 = DecoderBlockRes4B(
            in_channels=256,
            out_channels=128,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block5 = DecoderBlockRes4B(
            in_channels=128,
            out_channels=64,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )
        self.decoder_block6 = DecoderBlockRes4B(
            in_channels=64,
            out_channels=32,
            kernel_size=(3, 3),
            upsample=(2, 2),
            activation=activation,
            momentum=momentum,
        )

        self.after_conv_block1 = EncoderBlockRes4B(
            in_channels=32,
            out_channels=32,
            kernel_size=(3, 3),
            downsample=(1, 1),
            activation=activation,
            momentum=momentum,
        )

        # Output head: K feature maps per (source, output channel, subband).
        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=self.target_sources_num
            * self.output_channels
            * self.K
            * self.subbands_num,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        r"""Initialize the input batch norm and the output convolution."""
        init_bn(self.bn0)
        init_layer(self.after_conv2)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert feature maps to waveform.

        Args:
            input_tensor: (batch_size, target_sources_num * output_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, input_channels, time_steps, freq_bins)
            sin_in: (batch_size, input_channels, time_steps, freq_bins)
            cos_in: (batch_size, input_channels, time_steps, freq_bins)
            audio_length: int, length passed to the ISTFT and used for the
                final reshape.

            (There is input_channels == output_channels for the source separation task.)

        Outputs:
            waveform: (batch_size, target_sources_num * output_channels, segment_samples)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.output_channels,
            self.K,
            time_steps,
            freq_bins,
        )
        # x: (batch_size, target_sources_num, output_channels, self.K, time_steps, freq_bins)

        # Split the K feature maps: magnitude mask in (0, 1), the two
        # components of the phase mask, and an additive magnitude term.
        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])
        linear_mag = torch.tanh(x[:, :, :, 3, :, :])
        # Only the cos / sin outputs of magphase are kept; its magnitude is discarded.
        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Y = |Y|cos∠Y + j|Y|sin∠Y
        #   = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
        #   = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
        # The ``None`` index adds a source axis so the mixture phase broadcasts
        # over all target sources.
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )
        # out_cos: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)
        # out_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Calculate |Y|. The ReLU keeps the magnitude non-negative.
        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
        # out_mag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Calculate Y_{real} and Y_{imag} for ISTFT.
        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)

        # Reformat shape to (N, 1, time_steps, freq_bins) for ISTFT where
        # N = batch_size * target_sources_num * output_channels
        shape = (
            batch_size * self.target_sources_num * self.output_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        # ISTFT.
        x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * output_channels, segments_num)

        # Reshape.
        waveform = x.reshape(
            batch_size, self.target_sources_num * self.output_channels, audio_length
        )
        # (batch_size, target_sources_num * output_channels, segments_num)

        return waveform

    def forward(self, input_dict):
        r"""Forward data into the module.

        Args:
            input_dict: dict, e.g., {
                waveform: (batch_size, input_channels, segment_samples),
                ...,
            }

        Outputs:
            output_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * output_channels, segment_samples),
                ...,
            }
        """
        mixtures = input_dict['waveform']
        # (batch_size, input_channels, segment_samples)

        subband_x = self.pqmf.analysis(mixtures)
        # subband_x: (batch_size, input_channels * subbands_num, subband_samples)

        # wav_to_spectrogram_phase is not defined in this class; presumably
        # inherited from Base - TODO confirm.
        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(subband_x)
        # mag, cos_in, sin_in: (batch_size, input_channels * subbands_num, time_steps, freq_bins)

        # Batch normalize on individual frequency bins.
        x = mag.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        # (batch_size, input_channels * subbands_num, time_steps, freq_bins)

        # Pad spectrogram to be evenly divided by downsample ratio.
        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.time_downsample_ratio))
            * self.time_downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins)

        # Let frequency bins be evenly divided by 2, e.g., 257 -> 256
        x = x[..., 0 : x.shape[-1] - 1]  # (bs, input_channels * subbands_num, T, F)
        # x: (batch_size, input_channels * subbands_num, padded_time_steps, freq_bins - 1)

        # UNet
        x1_pool, x1 = self.encoder_block1(x)  # x1_pool: (bs, 32, T / 2, F / 2)
        x2_pool, x2 = self.encoder_block2(x1_pool)  # x2_pool: (bs, 64, T / 4, F / 4)
        x3_pool, x3 = self.encoder_block3(x2_pool)  # x3_pool: (bs, 128, T / 8, F / 8)
        x4_pool, x4 = self.encoder_block4(x3_pool)  # x4_pool: (bs, 256, T / 16, F / 16)
        x5_pool, x5 = self.encoder_block5(x4_pool)  # x5_pool: (bs, 384, T / 32, F / 32)
        x6_pool, x6 = self.encoder_block6(x5_pool)  # x6_pool: (bs, 384, T / 32, F / 64)
        x_center, _ = self.conv_block7a(x6_pool)  # (bs, 384, T / 32, F / 64)
        x_center, _ = self.conv_block7b(x_center)  # (bs, 384, T / 32, F / 64)
        x_center, _ = self.conv_block7c(x_center)  # (bs, 384, T / 32, F / 64)
        x_center, _ = self.conv_block7d(x_center)  # (bs, 384, T / 32, F / 64)
        x7 = self.decoder_block1(x_center, x6)  # (bs, 384, T / 32, F / 32)
        x8 = self.decoder_block2(x7, x5)  # (bs, 384, T / 16, F / 16)
        x9 = self.decoder_block3(x8, x4)  # (bs, 256, T / 8, F / 8)
        x10 = self.decoder_block4(x9, x3)  # (bs, 128, T / 4, F / 4)
        x11 = self.decoder_block5(x10, x2)  # (bs, 64, T / 2, F / 2)
        x12 = self.decoder_block6(x11, x1)  # (bs, 32, T, F)
        x, _ = self.after_conv_block1(x12)  # (bs, 32, T, F)

        x = self.after_conv2(x)
        # (batch_size, target_sources_num * output_channels * self.K * subbands_num, T, F')

        # Recover shape
        x = F.pad(x, pad=(0, 1))  # Pad frequency, e.g., 256 -> 257.

        # Remove the time padding added before the UNet.
        x = x[:, :, 0:origin_len, :]
        # (batch_size, target_sources_num * output_channels * self.K * subbands_num, T, F')

        audio_length = subband_x.shape[2]

        # Recover each subband spectrograms to subband waveforms. Then synthesis
        # the subband waveforms to a waveform.
        # The strided slice ``j :: subbands_num`` selects every channel that
        # belongs to subband j.
        separated_subband_audio = torch.stack(
            [
                self.feature_maps_to_wav(
                    input_tensor=x[:, j :: self.subbands_num, :, :],
                    # input_tensor: (batch_size, target_sources_num * output_channels * self.K, T, F')
                    sp=mag[:, j :: self.subbands_num, :, :],
                    # sp: (batch_size, input_channels, T, F')
                    sin_in=sin_in[:, j :: self.subbands_num, :, :],
                    # sin_in: (batch_size, input_channels, T, F')
                    cos_in=cos_in[:, j :: self.subbands_num, :, :],
                    # cos_in: (batch_size, input_channels, T, F')
                    audio_length=audio_length,
                )
                # (batch_size, target_sources_num * output_channels, segments_num)
                for j in range(self.subbands_num)
            ],
            dim=2,
        )
        # (batch_size, target_sources_num * output_channels, subbands_num, subband_samples)

        # Format for synthesis: flatten so that the subbands of one output
        # channel are adjacent, which is the grouping pqmf.synthesis consumes.
        shape = (
            separated_subband_audio.shape[0],  # batch_size
            self.target_sources_num * self.output_channels * self.subbands_num,
            audio_length,
        )
        separated_subband_audio = separated_subband_audio.reshape(shape)
        # (batch_size, target_sources_num * output_channels * subbands_num, subband_samples)

        separated_audio = self.pqmf.synthesis(separated_subband_audio)
        # (batch_size, target_sources_num * output_channels, segment_samples)

        output_dict = {'waveform': separated_audio}

        return output_dict


In [18]:
from typing import Any, Callable, Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class LitSourceSeparation(pl.LightningModule):
    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""Pytorch Lightning wrapper of PyTorch model, including forward,
        optimization of model, etc.

        Args:
            batch_data_preprocessor: object, used for preparing inputs and
                targets for training. E.g., BasicBatchDataPreprocessor is used
                for preparing data in dictionary into tensor.
            model: nn.Module
            loss_function: function, called with the keyword arguments
                `output`, `target` and `mixture`
            optimizer_type: str, "Adam" | "AdamW"
            learning_rate: float
            lr_lambda: function, multiplicative learning rate factor passed
                to LambdaLR
        """
        super().__init__()

        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.Tensor:
        r"""Forward a mini-batch data to model, calculate loss function, and
        train for one step. A mini-batch data is evenly distributed to multiple
        devices (if there are) for parallel training.

        Args:
            batch_data_dict: e.g. {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int

        Returns:
            loss: value returned by the loss function for this mini-batch
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
        # input_dict: {
        #     'waveform': (batch_size, channels_num, segment_samples),
        #     (if_exist) 'condition': (batch_size, channels_num),
        # }
        # target_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        # Forward. Training mode is re-enabled on every step because
        # Separator._forward_in_mini_batches calls model.eval() during
        # evaluation (presumably on this same model instance).
        self.model.train()

        output_dict = self.model(input_dict)
        # output_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        outputs = output_dict['waveform']
        # outputs:, e.g, (batch_size, target_sources_num * channels_num, segment_samples)

        # Calculate loss.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss

    def configure_optimizers(self) -> Any:
        r"""Configure optimizer and a per-step LambdaLR scheduler.

        Returns:
            ([optimizer], [scheduler_dict]) in the format expected by
            PyTorch Lightning.

        Raises:
            NotImplementedError: if optimizer_type is neither "Adam" nor "AdamW".
        """
        optimizer_classes = {"Adam": optim.Adam, "AdamW": optim.AdamW}

        if self.optimizer_type not in optimizer_classes:
            raise NotImplementedError(
                "Optimizer type {} not implemented!".format(self.optimizer_type)
            )

        # Both supported optimizers share the same hyper-parameters.
        optimizer = optimizer_classes[self.optimizer_type](
            self.model.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.0,
            amsgrad=True,
        )

        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',  # Update the learning rate after every step.
            'frequency': 1,
        }

        return [optimizer], [scheduler]


def get_model_class(model_type):
    r"""Look up the model class that corresponds to a model type name.

    Args:
        model_type: str, e.g., 'ResUNet143_Subbandtime'

    Returns:
        the model class itself (not an instance)

    Raises:
        NotImplementedError: if model_type is not one of the supported names.
    """
    # Each class name is only resolved when its own branch is taken, so a
    # model whose definition cell has not been run does not break the others.
    if model_type == 'ResUNet143_DecouplePlusInplaceABN_ISMIR2021':
        return ResUNet143_DecouplePlusInplaceABN_ISMIR2021

    if model_type == 'UNet':
        return UNet

    if model_type == 'ResUNet143_Subbandtime':
        return ResUNet143_Subbandtime

    # TODO - add later. Model types from the bytesep.models package that are
    # not available in this notebook yet:
    #   UNetSubbandTime, MobileNet_Subbandtime, MobileTiny_Subbandtime,
    #   ResUNet143_DecouplePlus, ConditionalUNet, LevelRNN, WavUNet,
    #   WavUNetLevelRNN, TTnet, TTnetNoTransformer, JiafengCNN, JiafengTTNet,
    #   ResUNet143FC_Subbandtime, AmbisonicToBinaural_UNetSubbandtimePhase,
    #   AmbisonicToBinaural_ResUNetSubbandtimePhase, MobileNetSubbandTime,
    #   WrapperDemucs, WrapperHDemucs
    raise NotImplementedError("{} not implemented!".format(model_type))


### Loss

In [19]:
import math
from typing import Callable

import torch
import torch.nn as nn
from torchlibrosa.stft import STFT

# from bytesep.models.pytorch_modules import Base


def l1(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""Mean absolute error between two tensors.

    Args:
        output: torch.Tensor
        target: torch.Tensor
        **kwargs: accepted and ignored, so that all loss functions can be
            called with the same keyword arguments

    Returns:
        loss: torch.Tensor, scalar mean of |output - target|
    """
    absolute_error = (output - target).abs()
    return absolute_error.mean()


def l1_wav(output: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""L1 loss between waveforms, i.e., computed in the time-domain.

    Args:
        output: torch.Tensor
        target: torch.Tensor
        **kwargs: accepted and ignored

    Returns:
        loss: torch.Tensor, result of l1(output, target)
    """
    # The waveform loss is the plain L1 loss applied to the raw tensors.
    return l1(output=output, target=target)


class L1_Wav_L1_Sp(nn.Module, Base):
    def __init__(self):
        r"""L1 loss in the time-domain and L1 loss on the spectrogram."""
        super(L1_Wav_L1_Sp, self).__init__()

        self.window_size = 2048
        hop_size = 441
        center = True
        pad_mode = "reflect"
        window = "hann"

        # STFT with frozen parameters.
        # NOTE(review): presumably used by the wav_to_spectrogram helper
        # inherited from Base - confirm against Base.
        self.stft = STFT(
            n_fft=self.window_size,
            hop_length=hop_size,
            win_length=self.window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

    # NOTE: __call__ (not forward) is overridden, so calling the loss object
    # runs this method directly instead of going through nn.Module.__call__.
    def __call__(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain and on the spectrogram.

        Args:
            output: torch.Tensor
            target: torch.Tensor
            **kwargs: accepted and ignored

        Returns:
            loss: torch.Tensor, sum of the waveform loss and spectrogram loss
        """

        # L1 loss in the time-domain.
        wav_loss = l1_wav(output, target)

        # L1 loss on the spectrogram.
        sp_loss = l1(
            self.wav_to_spectrogram(output, eps=1e-8),
            self.wav_to_spectrogram(target, eps=1e-8),
        )

        # Total loss.
        return wav_loss + sp_loss


class L1_Wav_L1_CompressedSp(nn.Module, Base):
    # Exponent applied to the spectrogram magnitudes before the L1 losses.
    COMPRESSION_EXPONENT = 0.3

    def __init__(self):
        r"""L1 loss in the time-domain plus L1 losses on the compressed
        spectrogram (magnitude raised to the power COMPRESSION_EXPONENT)."""
        super(L1_Wav_L1_CompressedSp, self).__init__()

        self.window_size = 2048
        hop_size = 441
        center = True
        pad_mode = "reflect"
        window = "hann"

        # STFT with frozen parameters.
        # NOTE(review): presumably used by the wav_to_spectrogram_phase helper
        # inherited from Base - confirm against Base.
        self.stft = STFT(
            n_fft=self.window_size,
            hop_length=hop_size,
            win_length=self.window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

    # NOTE: __call__ (not forward) is overridden, so calling the loss object
    # runs this method directly instead of going through nn.Module.__call__.
    def __call__(
        self, output: torch.Tensor, target: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        r"""L1 loss in the time-domain and on the compressed spectrogram.

        The total is the sum of four L1 terms: waveform, compressed magnitude,
        compressed magnitude * cos and compressed magnitude * sin.

        Args:
            output: torch.Tensor
            target: torch.Tensor
            **kwargs: accepted and ignored

        Returns:
            loss: torch.Tensor, sum of the four loss terms
        """

        # L1 loss in the time-domain.
        wav_loss = l1_wav(output, target)

        # NOTE(review): the three returned values are taken to be the
        # magnitude and the cosine / sine of the phase - confirm against Base.
        output_mag, output_cos, output_sin = self.wav_to_spectrogram_phase(
            output, eps=1e-8
        )
        target_mag, target_cos, target_sin = self.wav_to_spectrogram_phase(
            target, eps=1e-8
        )

        # Compress each magnitude once and reuse it for all three terms.
        output_compressed = output_mag ** self.COMPRESSION_EXPONENT
        target_compressed = target_mag ** self.COMPRESSION_EXPONENT

        mag_loss = l1(output_compressed, target_compressed)
        real_loss = l1(output_compressed * output_cos, target_compressed * target_cos)
        imag_loss = l1(output_compressed * output_sin, target_compressed * target_sin)

        total_loss = wav_loss + mag_loss + real_loss + imag_loss

        return total_loss


def get_loss_function(loss_type: str) -> Callable:
    r"""Get loss function.

    Args:
        loss_type: str, "l1_wav" | "l1_wav_l1_sp" | "l1_wav_l1_compressed_sp"

    Returns:
        loss function: Callable. "l1_wav" returns the plain function; the
            other types return a newly constructed loss object.

    Raises:
        NotImplementedError: if loss_type is not one of the supported names.
    """
    # Names are only resolved inside the branch that is taken.
    if loss_type == "l1_wav":
        return l1_wav

    if loss_type == "l1_wav_l1_sp":
        return L1_Wav_L1_Sp()

    if loss_type == "l1_wav_l1_compressed_sp":
        return L1_Wav_L1_CompressedSp()

    # Report the offending value, in the same way as get_model_class.
    raise NotImplementedError("{} not implemented!".format(loss_type))


### Callbacks

In [20]:
import logging
import os
from typing import NoReturn

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only


class SaveCheckpointsCallback(pl.Callback):
    def __init__(
        self,
        model: nn.Module,
        checkpoints_dir: str,
        save_step_frequency: int,
    ):
        r"""Callback that writes the model weights to disk every
        #save_step_frequency steps.

        Args:
            model: nn.Module, model whose state_dict is saved
            checkpoints_dir: str, directory to save checkpoints, created if
                it does not exist
            save_step_frequency: int
        """
        self.model = model
        self.checkpoints_dir = checkpoints_dir
        self.save_step_frequency = save_step_frequency
        os.makedirs(self.checkpoints_dir, exist_ok=True)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Save a checkpoint when the global step is a multiple of
        save_step_frequency (this includes step 0)."""
        step = trainer.global_step

        # Nothing to do between two checkpoint steps.
        if step % self.save_step_frequency != 0:
            return

        file_name = "step={}.pth".format(step)
        checkpoint_path = os.path.join(self.checkpoints_dir, file_name)

        # Only the step counter and the model weights are stored.
        torch.save({'step': step, 'model': self.model.state_dict()}, checkpoint_path)
        logging.info("Save checkpoint to {}".format(checkpoint_path))


In [21]:
import numpy as np

def preprocess_audio(
    audio: np.array, mono: bool, origin_sr: float, sr: float, resample_type: str
) -> np.array:
    r"""Optionally down-mix an audio clip to mono, then resample it.

    Args:
        audio: (channels_num, audio_samples), input audio
        mono: bool, average all channels into one before resampling
        origin_sr: float, original sample rate
        sr: float, target sample rate
        resample_type: str, e.g., 'kaiser_fast'

    Returns:
        output: ndarray, (1, audio_samples) if the resampled audio is 1-D,
            otherwise the resampled array unchanged
    """
    # Averaging over axis 0 drops the channel axis: (audio_samples,)
    source = np.mean(audio, axis=0) if mono else audio

    resampled = librosa.core.resample(
        source, orig_sr=origin_sr, target_sr=sr, res_type=resample_type
    )
    # (audio_samples,) | (channels_num, audio_samples)

    # Restore a leading channel axis for 1-D results: (1, audio_samples)
    return resampled[None, :] if resampled.ndim == 1 else resampled

In [22]:
from typing import Dict

import numpy as np
import torch
import torch.nn as nn


class Separator:
    def __init__(
        self, model: nn.Module, segment_samples: int, batch_size: int, device: str
    ):
        r"""Separator that runs a model over a long audio clip by cutting it
        into half-overlapping segments and stitching the outputs together.

        Args:
            model: nn.Module, trained model
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size, int, e.g., 12
            device: str, e.g., 'cuda'
        """
        self.model = model
        self.segment_samples = segment_samples
        self.batch_size = batch_size
        self.device = device

    def separate(self, input_dict: Dict) -> np.array:
        r"""Run the model on an audio clip and return the separated audio.

        Args:
            input_dict: dict, e.g., {
                waveform: (channels_num, audio_samples),
                ...,
            }
            An optional 'condition' entry is repeated for every segment.

        Returns:
            sep_audio: model output stitched together and trimmed to the
                input length, (output_channels, audio_samples)
        """
        waveform = input_dict['waveform']
        original_length = waveform.shape[-1]

        # Zero-pad at the end to a whole number of segments, then cut the
        # clip into segments.
        padded_waveform = self.pad_audio(waveform)
        segments = self.enframe(padded_waveform, self.segment_samples)
        # (segments_num, channels_num, segment_samples)

        model_inputs = {'waveform': segments}

        if 'condition' in input_dict:
            # Every segment receives the same condition vector.
            model_inputs['condition'] = np.tile(
                input_dict['condition'][None, :], (len(segments), 1)
            )
            # (segments_num, condition_dim)

        # Separate in mini-batches.
        model_outputs = self._forward_in_mini_batches(
            self.model, model_inputs, self.batch_size
        )
        # model_outputs['waveform']: (segments_num, output_channels, segment_samples)

        # Stitch the segments back together and drop the padding.
        stitched = self.deframe(model_outputs['waveform'])
        # (output_channels, padded_audio_samples)

        return stitched[:, 0:original_length]

    def pad_audio(self, audio: np.array) -> np.array:
        r"""Append zeros so that the audio length becomes a multiple of
        self.segment_samples.

        Args:
            audio: (channels_num, audio_samples)

        Returns:
            padded_audio: (channels_num, padded_audio_samples)
        """
        channels_num, audio_samples = audio.shape

        # Smallest number of whole segments that covers the clip.
        segments_num = int(np.ceil(audio_samples / self.segment_samples))
        pad_samples = segments_num * self.segment_samples - audio_samples

        silence = np.zeros((channels_num, pad_samples))

        return np.concatenate((audio, silence), axis=1)

    def enframe(self, audio: np.array, segment_samples: int) -> np.array:
        r"""Cut audio into segments that overlap by half a segment.

        Args:
            audio: (channels_num, audio_samples), audio_samples must be a
                multiple of segment_samples
            segment_samples: int

        Returns:
            segments: (segments_num, channels_num, segment_samples)
        """
        total_samples = audio.shape[1]
        assert total_samples % segment_samples == 0

        # Consecutive segments start half a segment apart.
        hop_samples = segment_samples // 2

        frames = []
        start = 0

        while start + segment_samples <= total_samples:
            frames.append(audio[:, start : start + segment_samples])
            start += hop_samples

        return np.array(frames)

    def deframe(self, segments: np.array) -> np.array:
        r"""Stitch half-overlapping segments back into one long audio clip.

        The first segment contributes its first 75%, the last segment its
        last 75%, and every segment in between its middle 50%.

        Args:
            segments: (segments_num, channels_num, segment_samples)

        Returns:
            output: (channels_num, audio_samples)
        """
        segments_num, _, segment_samples = segments.shape

        # A single segment needs no stitching.
        if segments_num == 1:
            return segments[0]

        # The cut points must fall on whole samples.
        assert self._is_integer(segment_samples * 0.25)
        assert self._is_integer(segment_samples * 0.75)

        quarter = int(segment_samples * 0.25)
        three_quarters = int(segment_samples * 0.75)

        pieces = [segments[0, :, 0:three_quarters]]
        pieces.extend(
            segments[index, :, quarter:three_quarters]
            for index in range(1, segments_num - 1)
        )
        pieces.append(segments[-1, :, quarter:])

        return np.concatenate(pieces, axis=-1)

    def _is_integer(self, x: float) -> bool:
        r"""Return True if the fractional part of x is below 1e-10."""
        return bool(x - int(x) < 1e-10)

    def _forward_in_mini_batches(
        self, model: nn.Module, segments_input_dict: Dict, batch_size: int
    ) -> Dict:
        r"""Run the model on the segments, batch_size segments at a time.

        Args:
            model: nn.Module
            segments_input_dict: dict, e.g., {
                'waveform': (segments_num, channels_num, segment_samples),
                ...,
            }
            batch_size: int

        Returns:
            output_dict: dict, e.g. {
                'waveform': (segments_num, channels_num, segment_samples),
            }
        """
        collected = {}

        segments_num = len(segments_input_dict['waveform'])
        start = 0

        while start < segments_num:
            stop = start + batch_size

            # Slice every input the same way and move it to the device.
            batch_input_dict = {
                key: torch.Tensor(value[start:stop]).to(self.device)
                for key, value in segments_input_dict.items()
            }

            start = stop

            with torch.no_grad():
                model.eval()
                batch_output_dict = model(batch_input_dict)

            for key, value in batch_output_dict.items():
                self._append_to_dict(collected, key, value.data.cpu().numpy())

        # Join the per-batch arrays along the segment axis.
        return {
            key: np.concatenate(chunks, axis=0) for key, chunks in collected.items()
        }

    def _append_to_dict(self, dict, key, value):
        r"""Append value to the list stored under key, creating the list on
        first use."""
        dict.setdefault(key, []).append(value)


In [23]:
import logging
import os
import time
from typing import Dict, List, NoReturn

import librosa
import musdb
import museval
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_only

# from bytesep.callbacks.base import SaveCheckpointsCallback
# from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio
# from bytesep.separate import Separator
# from bytesep.utils import StatisticsContainer, read_yaml


def get_musdb18_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Build the MUSDB18 training callbacks described by a config yaml.

    Args:
        config_yaml: str
        workspace: str, evaluation audios are read from
            <workspace>/evaluation_audios/<task_name>
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, passed to StatisticsContainer
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback], the checkpoint callback followed by the
            evaluation callback of the 'test' split
    """
    configs = read_yaml(config_yaml)

    task_name = configs['task_name']
    train_configs = configs['train']
    evaluation_callback = train_configs['evaluation_callback']
    target_source_types = train_configs['target_source_types']
    input_channels = train_configs['input_channels']
    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)

    evaluate_configs = configs['evaluate']
    test_segment_seconds = evaluate_configs['segment_seconds']
    sample_rate = train_configs['sample_rate']
    test_segment_samples = int(test_segment_seconds * sample_rate)
    test_batch_size = evaluate_configs['batch_size']

    evaluate_step_frequency = train_configs['evaluate_step_frequency']
    save_step_frequency = train_configs['save_step_frequency']

    # save checkpoint callback
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # evaluation callback class
    EvaluationCallback = _get_evaluation_callback_class(evaluation_callback)

    # statistics container, shared by the evaluation callbacks
    statistics_container = StatisticsContainer(statistics_path)

    # Everything except the split is identical for both evaluation callbacks.
    shared_arguments = {
        'dataset_dir': evaluation_audios_dir,
        'model': model,
        'target_source_types': target_source_types,
        'sample_rate': sample_rate,
        'input_channels': input_channels,
        'segment_samples': test_segment_samples,
        'batch_size': test_batch_size,
        'device': evaluate_device,
        'evaluate_step_frequency': evaluate_step_frequency,
        'logger': logger,
        'statistics_container': statistics_container,
    }

    # The train callback is constructed but not registered below. With
    # Musdb18EvaluationCallback, construction still loads the 'train' subset
    # and asserts that it is not empty.
    evaluate_train_callback = EvaluationCallback(split='train', **shared_arguments)

    evaluate_test_callback = EvaluationCallback(split='test', **shared_arguments)

    # callbacks = [save_checkpoints_callback, evaluate_train_callback, evaluate_test_callback]
    return [save_checkpoints_callback, evaluate_test_callback]


def _get_evaluation_callback_class(evaluation_callback) -> pl.Callback:
    r"""Get evaluation callback class.

    Args:
        evaluation_callback: str, "Musdb18" | "Musdb18Conditional"

    Returns:
        the evaluation callback class itself (not an instance)

    Raises:
        NotImplementedError: if evaluation_callback is not a supported name.
    """
    if evaluation_callback == "Musdb18":
        return Musdb18EvaluationCallback

    if evaluation_callback == 'Musdb18Conditional':
        return Musdb18ConditionalEvaluationCallback

    # Report the offending value, in the same way as get_model_class.
    raise NotImplementedError("{} not implemented!".format(evaluation_callback))


class Musdb18EvaluationCallback(pl.Callback):
    def __init__(
        self,
        dataset_dir: str,
        split: str,
        model: nn.Module,
        target_source_types: List[str],
        sample_rate: int,
        input_channels: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #evaluate_step_frequency steps.

        Args:
            dataset_dir: str
            split: 'train' | 'test'
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            sample_rate: int
            input_channels: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size, int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: object
            statistics_container: StatisticsContainer

        Raises:
            AssertionError: if no tracks are found for the split in dataset_dir.
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate separation SDRs of audio recordings.

        Runs when the global step is a multiple of evaluate_step_frequency
        (this includes step 0). The per-track SDRs and their medians over all
        tracks are logged and appended to the statistics container.
        """
        global_step = trainer.global_step

        # Nothing to do between two evaluation steps.
        if global_step % self.evaluate_step_frequency != 0:
            return

        sdr_dict = {}

        logging.info("--- Step {} ---".format(global_step))
        logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

        eval_time = time.time()

        for track in self.mus.tracks:

            audio_name = track.name

            # Get waveform of mixture.
            mixture = preprocess_audio(
                audio=track.audio.T,
                mono=self.mono,
                origin_sr=track.rate,
                sr=self.sample_rate,
                resample_type=self.resample_type,
            )
            # (channels_num, audio_samples)

            target_dict = {}
            sdr_dict[audio_name] = {}

            # Get waveform of all target source types.
            for source_type in self.target_source_types:
                # E.g., ['vocals', 'bass', ...]

                target_dict[source_type] = preprocess_audio(
                    audio=track.targets[source_type].audio.T,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

            # Separate.
            input_dict = {'waveform': mixture}

            sep_wavs = self.separator.separate(input_dict)
            # sep_wavs: (target_sources_num * channels_num, audio_samples)

            # Post process separation results. mono must be False here:
            # axis 0 of sep_wavs holds the stacked target sources, so a mono
            # down-mix would average the different sources into one.
            sep_wavs = preprocess_audio(
                audio=sep_wavs,
                mono=False,
                origin_sr=self.sample_rate,
                sr=track.rate,
                resample_type=self.resample_type,
            )
            # sep_wavs: (target_sources_num * channels_num, audio_samples)

            # NOTE(review): sep_wavs is resampled back to track.rate, while
            # mixture and the targets stay at self.sample_rate. The lengths
            # and rates only agree when both rates are equal - confirm.
            sep_wavs = librosa.util.fix_length(
                sep_wavs, size=mixture.shape[1], axis=1
            )
            # sep_wavs: (target_sources_num * channels_num, audio_samples)

            sep_wav_dict = get_separated_wavs_from_simo_output(
                sep_wavs, self.input_channels, self.target_source_types
            )
            # sep_wav_dict: dict, e.g., {
            #     'vocals': (channels_num, audio_samples),
            #     'bass': (channels_num, audio_samples),
            #     ...,
            # }

            # Evaluate for all target source types.
            for source_type in self.target_source_types:
                # E.g., ['vocals', 'bass', ...]

                # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
                (sdrs, _, _, _) = museval.evaluate(
                    [target_dict[source_type].T], [sep_wav_dict[source_type].T]
                )

                # Median over the values returned for this track, ignoring NaNs.
                sdr = np.nanmedian(sdrs)
                sdr_dict[audio_name][source_type] = sdr

                logging.info(
                    "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                )

        logging.info("-----------------------------")
        median_sdr_dict = {}

        # Calculate median SDRs of all songs.
        for source_type in self.target_source_types:
            # E.g., ['vocals', 'bass', ...]

            median_sdr = np.median(
                [sdr_dict[audio_name][source_type] for audio_name in sdr_dict.keys()]
            )

            median_sdr_dict[source_type] = median_sdr

            logging.info(
                "Step: {}, {}, Median SDR: {:.3f}".format(
                    global_step, source_type, median_sdr
                )
            )

        logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))

        statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
        self.statistics_container.append(global_step, statistics, self.split)
        self.statistics_container.dump()


def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict:
    r"""Split the stacked output of a single input multiple output (SIMO)
    system into one waveform per target source.

    The j-th source occupies rows [j * input_channels, (j + 1) * input_channels)
    of x, following the order of target_source_types.

    Args:
        x: (target_sources_num * channels_num, audio_samples)
        input_channels: int, number of channels belonging to each source
        target_source_types: List[str], e.g., ['vocals', 'bass', ...]

    Returns:
        output_dict: dict, e.g., {
            'vocals': (channels_num, audio_samples),
            'bass': (channels_num, audio_samples),
            ...,
        }
    """
    return {
        name: x[index * input_channels : (index + 1) * input_channels]
        for index, name in enumerate(target_source_types)
    }


class Musdb18ConditionalEvaluationCallback(pl.Callback):
    r"""Callback that periodically evaluates a conditional separation model
    on MUSDB18 tracks.

    For every track and every target source type, a one-hot condition vector
    selecting that source is passed to the separator together with the
    mixture, and the SDR between the separated waveform and the reference
    waveform is computed with museval.
    """

    def __init__(
        self,
        dataset_dir: str,
        split: str,
        model: nn.Module,
        target_source_types: List[str],
        sample_rate: int,
        input_channels: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate every #evaluate_step_frequency steps.

        Args:
            dataset_dir: str
            split: 'train' | 'test'
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            sample_rate: int
            input_channels: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size, int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: object
            statistics_container: StatisticsContainer

        Raises:
            AssertionError: if no tracks are found under dataset_dir for the
                requested split.
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        # Audio is treated as mono when the model takes a single input channel.
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    @rank_zero_only
    def on_batch_end(self, trainer: pl.Trainer, _) -> None:
        r"""Evaluate separation SDRs of audio recordings.

        Runs whenever trainer.global_step is a multiple of
        evaluate_step_frequency (this includes step 0). Per-track SDRs and
        the per-source median SDRs over all tracks are logged, appended to
        the statistics container, and dumped.

        Args:
            trainer: pl.Trainer, only its global_step is read
            _: unused
        """
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            # sdr_dict: {audio_name: {source_type: sdr}}
            sdr_dict = {}

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

            eval_time = time.time()

            for track in self.mus.tracks:

                audio_name = track.name

                # Get waveform of mixture.
                mixture = track.audio.T
                # (channels_num, audio_samples)

                mixture = preprocess_audio(
                    audio=mixture,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

                target_dict = {}
                sdr_dict[audio_name] = {}

                # Get waveform of all target source types.
                for j, source_type in enumerate(self.target_source_types):
                    # E.g., ['vocals', 'bass', ...]

                    audio = track.targets[source_type].audio.T

                    audio = preprocess_audio(
                        audio=audio,
                        mono=self.mono,
                        origin_sr=track.rate,
                        sr=self.sample_rate,
                        resample_type=self.resample_type,
                    )
                    # (channels_num, audio_samples)

                    target_dict[source_type] = audio
                    # (channels_num, audio_samples)

                    # One-hot condition vector selecting the j-th target source.
                    condition = np.zeros(len(self.target_source_types))
                    condition[j] = 1

                    input_dict = {'waveform': mixture, 'condition': condition}

                    sep_wav = self.separator.separate(input_dict)
                    # sep_wav: (channels_num, audio_samples)

                    # Convert the separated waveform from the model sample rate
                    # to the sample rate of the track.
                    sep_wav = preprocess_audio(
                        audio=sep_wav,
                        mono=self.mono,
                        origin_sr=self.sample_rate,
                        sr=track.rate,
                        resample_type=self.resample_type,
                    )
                    # sep_wav: (channels_num, audio_samples)

                    # NOTE(review): sep_wav is converted to track.rate above, but
                    # it is length-fixed to (and compared against) audio that was
                    # converted to self.sample_rate. This looks consistent only
                    # when self.sample_rate == track.rate -- TODO confirm.
                    sep_wav = librosa.util.fix_length(
                        sep_wav, size=mixture.shape[1], axis=1
                    )
                    # sep_wav: (channels_num, audio_samples), same length as mixture

                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan)
                    (sdrs, _, _, _) = museval.evaluate(
                        [target_dict[source_type].T], [sep_wav.T]
                    )

                    # Median over the returned SDR values, ignoring NaNs.
                    sdr = np.nanmedian(sdrs)
                    sdr_dict[audio_name][source_type] = sdr

                    logging.info(
                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                    )

            logging.info("-----------------------------")
            median_sdr_dict = {}

            # Calculate median SDRs of all songs.
            for source_type in self.target_source_types:

                # np.median (not nanmedian): a NaN SDR for any song makes the
                # median of that source type NaN.
                median_sdr = np.median(
                    [
                        sdr_dict[audio_name][source_type]
                        for audio_name in sdr_dict.keys()
                    ]
                )

                median_sdr_dict[source_type] = median_sdr

                logging.info(
                    "Step: {}, {}, Median SDR: {:.3f}".format(
                        global_step, source_type, median_sdr
                    )
                )

            logging.info("Evlauation time: {:.3f}".format(time.time() - eval_time))

            # Store the statistics of this evaluation and write them out.
            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
            self.statistics_container.append(global_step, statistics, self.split)
            self.statistics_container.dump()


In [24]:
from typing import List

import pytorch_lightning as pl
import torch.nn as nn


def get_callbacks(
    task_name: str,
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get callbacks of a task and config yaml file.

    Dispatches on task_name and forwards all remaining arguments unchanged
    to the task-specific callback factory.

    Args:
        task_name: str, currently only 'musdb18' is supported
        config_yaml: str
        workspace: str, containing useful files such as audios for evaluation
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Return:
        callbacks: List[pl.Callback]

    Raises:
        NotImplementedError: if task_name is not a supported task.
    """
    if task_name == 'musdb18':

        # In the packaged project this factory lives in
        # bytesep.callbacks.musdb18; in this notebook it is defined in an
        # earlier cell.
        return get_musdb18_callbacks(
            config_yaml=config_yaml,
            workspace=workspace,
            checkpoints_dir=checkpoints_dir,
            statistics_path=statistics_path,
            logger=logger,
            model=model,
            evaluate_device=evaluate_device,
        )

    # Name the offending task so a typo in the config is easy to spot.
    raise NotImplementedError(
        "Callbacks for task_name '{}' are not implemented. "
        "Supported tasks: ['musdb18']".format(task_name)
    )


### LitSourceSeparation

In [25]:
from typing import Any, Callable, Dict

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


class LitSourceSeparation(pl.LightningModule):
    r"""Pytorch Lightning module that trains a source separation model."""

    def __init__(
        self,
        batch_data_preprocessor,
        model: nn.Module,
        loss_function: Callable,
        optimizer_type: str,
        learning_rate: float,
        lr_lambda: Callable,
    ):
        r"""Pytorch Lightning wrapper of PyTorch model, including forward,
        optimization of model, etc.

        Args:
            batch_data_preprocessor: object, used for preparing inputs and
                targets for training. E.g., BasicBatchDataPreprocessor is used
                for preparing data in dictionary into tensor.
            model: nn.Module
            loss_function: function, called with keyword arguments `output`,
                `target` and `mixture`
            optimizer_type: str, 'Adam' | 'AdamW'
            learning_rate: float
            lr_lambda: function, maps a step to a learning rate factor (see
                torch.optim.lr_scheduler.LambdaLR)
        """
        super().__init__()

        self.batch_data_preprocessor = batch_data_preprocessor
        self.model = model
        self.optimizer_type = optimizer_type
        self.loss_function = loss_function
        self.learning_rate = learning_rate
        self.lr_lambda = lr_lambda

    def training_step(self, batch_data_dict: Dict, batch_idx: int) -> torch.Tensor:
        r"""Forward a mini-batch data to model, calculate loss function, and
        train for one step. A mini-batch data is evenly distributed to multiple
        devices (if there are) for parallel training.

        Args:
            batch_data_dict: e.g. {
                'vocals': (batch_size, channels_num, segment_samples),
                'accompaniment': (batch_size, channels_num, segment_samples),
                'mixture': (batch_size, channels_num, segment_samples)
            }
            batch_idx: int

        Returns:
            loss: value returned by the loss function for this mini-batch
        """
        input_dict, target_dict = self.batch_data_preprocessor(batch_data_dict)
        # input_dict: {
        #     'waveform': (batch_size, channels_num, segment_samples),
        #     (if_exist) 'condition': (batch_size, channels_num),
        # }
        # target_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        # Forward. Explicitly (re-)enter training mode first, in case other
        # code switched the shared model to evaluation mode.
        self.model.train()

        output_dict = self.model(input_dict)
        # output_dict: {
        #     'waveform': (batch_size, target_sources_num * channels_num, segment_samples),
        # }

        outputs = output_dict['waveform']
        # outputs:, e.g, (batch_size, target_sources_num * channels_num, segment_samples)

        # Calculate loss.
        loss = self.loss_function(
            output=outputs,
            target=target_dict['waveform'],
            mixture=input_dict['waveform'],
        )

        return loss

    def configure_optimizers(self) -> Any:
        r"""Configure optimizer and a per-step LambdaLR scheduler.

        Returns:
            ([optimizer], [scheduler_config]): a one-element list holding the
                optimizer and a one-element list holding the scheduler
                configuration dictionary.

        Raises:
            NotImplementedError: if optimizer_type is not 'Adam' or 'AdamW'.
        """
        # Both supported optimizers are built with identical hyper-parameters.
        optimizer_classes = {"Adam": optim.Adam, "AdamW": optim.AdamW}

        if self.optimizer_type not in optimizer_classes:
            raise NotImplementedError(
                "optimizer_type '{}' is not supported. Supported types: {}".format(
                    self.optimizer_type, sorted(optimizer_classes.keys())
                )
            )

        optimizer = optimizer_classes[self.optimizer_type](
            self.model.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.0,
            amsgrad=True,
        )

        # Step the scheduler after every optimization step (not every epoch).
        scheduler = {
            'scheduler': LambdaLR(optimizer, self.lr_lambda),
            'interval': 'step',
            'frequency': 1,
        }

        return [optimizer], [scheduler]

### Optimizers

In [26]:
def get_lr_lambda(step, warm_up_steps: int, reduce_lr_steps: int):
    r"""Get the learning rate factor of a step for LambdaLR. E.g.,

    .. code-block: python
        lr_lambda = lambda step: get_lr_lambda(step, warm_up_steps=1000, reduce_lr_steps=10000)

        from torch.optim.lr_scheduler import LambdaLR
        LambdaLR(optimizer, lr_lambda)

    Args:
        step: int, current step
        warm_up_steps: int, steps for warm up. The factor grows linearly from
            0 to 1 during warm up. Warm up is skipped if not positive.
        reduce_lr_steps: int, reduce learning rate by 0.9 every #reduce_lr_steps steps

    Returns:
        lr_scale: float, multiplicative factor applied to the base learning rate
    """
    # Only warm up when warm_up_steps is positive: LambdaLR evaluates the
    # lambda at step 0 on construction, so warm_up_steps == 0 would otherwise
    # raise ZeroDivisionError (0 / 0).
    if warm_up_steps > 0 and step <= warm_up_steps:
        return step / warm_up_steps

    return 0.9 ** (step // reduce_lr_steps)


### Train

In [27]:
def train(
    workspace: str,
    gpus: int,
    config_yaml: str,
    filename: str,
    num_workers: int = 8,
) -> None:
    r"""Train & evaluate and save checkpoints.

    Args:
        workspace: str, directory of workspace
        gpus: int, number of GPUs; 0 evaluates on CPU
        config_yaml: str, path of config file for training
        filename: str, passed to get_dirs together with the other arguments
            to derive the checkpoint / log / statistics locations
        num_workers: int, number of workers passed to the data module
    """
    distributed = gpus > 1
    evaluate_device = "cuda" if gpus > 0 else "cpu"

    # Read config file.
    configs = read_yaml(config_yaml)
    check_configs_gramma(configs)

    task_name = configs['task_name']
    train_configs = configs['train']

    input_source_types = train_configs['input_source_types']
    target_source_types = train_configs['target_source_types']
    input_channels = train_configs['input_channels']
    output_channels = train_configs['output_channels']
    batch_data_preprocessor_type = train_configs['batch_data_preprocessor']
    model_type = train_configs['model_type']
    loss_type = train_configs['loss_type']
    optimizer_type = train_configs['optimizer_type']
    learning_rate = float(train_configs['learning_rate'])
    precision = train_configs['precision']
    early_stop_steps = train_configs['early_stop_steps']
    warm_up_steps = train_configs['warm_up_steps']
    reduce_lr_steps = train_configs['reduce_lr_steps']
    resume_checkpoint_path = train_configs['resume_checkpoint_path']

    target_sources_num = len(target_source_types)

    # paths (the logs directory returned by get_dirs is not used here)
    checkpoints_dir, _logs_dir, logger, statistics_path = get_dirs(
        workspace, task_name, filename, config_yaml, gpus
    )

    # training data module
    data_module = get_data_module(
        workspace=workspace,
        config_yaml=config_yaml,
        num_workers=num_workers,
        distributed=distributed,
    )

    # batch data preprocessor
    BatchDataPreprocessor = get_batch_data_preprocessor_class(
        batch_data_preprocessor_type=batch_data_preprocessor_type
    )

    batch_data_preprocessor = BatchDataPreprocessor(
        input_source_types=input_source_types, target_source_types=target_source_types
    )

    # model
    Model = get_model_class(model_type=model_type)
    model = Model(
        input_channels=input_channels,
        output_channels=output_channels,
        target_sources_num=target_sources_num,
    )

    if resume_checkpoint_path:
        # Load onto CPU first: the checkpoint may have been saved from a GPU
        # that is unavailable here, and this avoids holding a second copy of
        # the weights in GPU memory. load_state_dict copies the values into
        # the existing parameters of the model.
        checkpoint = torch.load(resume_checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint['model'])
        logging.info(
            "Load pretrained checkpoint from {}".format(resume_checkpoint_path)
        )

    # loss function
    loss_function = get_loss_function(loss_type=loss_type)

    # callbacks
    callbacks = get_callbacks(
        task_name=task_name,
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    # learning rate reduce function
    lr_lambda = partial(
        get_lr_lambda, warm_up_steps=warm_up_steps, reduce_lr_steps=reduce_lr_steps
    )

    # pytorch-lightning model
    pl_model = LitSourceSeparation(
        batch_data_preprocessor=batch_data_preprocessor,
        model=model,
        optimizer_type=optimizer_type,
        loss_function=loss_function,
        learning_rate=learning_rate,
        lr_lambda=lr_lambda,
    )

    # trainer
    # NOTE(review): limit_train_batches=3 / limit_val_batches=3 look like
    # settings for a quick smoke run of the notebook -- confirm before a
    # real training run.
    trainer = pl.Trainer(
        checkpoint_callback=False,
        gpus=gpus,
        callbacks=callbacks,
        max_steps=early_stop_steps,
        accelerator="ddp",
        sync_batchnorm=True,
        precision=precision,
        replace_sampler_ddp=False,
        plugins=[DDPPlugin(find_unused_parameters=False)],
        profiler='simple',
        limit_train_batches=3,
        limit_val_batches=3,
    )

    # Fit, evaluate, and save checkpoints.
    trainer.fit(pl_model, data_module)


# if __name__ == "__main__":

#     parser = argparse.ArgumentParser(description="")
#     subparsers = parser.add_subparsers(dest="mode")

#     parser_train = subparsers.add_parser("train")
#     parser_train.add_argument(
#         "--workspace", type=str, required=True, help="Directory of workspace."
#     )
#     parser_train.add_argument("--gpus", type=int, required=True)
#     parser_train.add_argument(
#         "--config_yaml",
#         type=str,
#         required=True,
#         help="Path of config file for training.",
#     )

#     args = parser.parse_args()
#     args.filename = pathlib.Path(__file__).stem

#     if args.mode == "train":
#         train(args)

#     else:
#         raise Exception("Error argument!")


In [28]:
import pathlib
import os

# Settings of this notebook run.
workspace = "."
gpus = 1
config_yaml = "../scripts/4_train/musdb18/configs/vocals-accompaniment,resunet_subbandtime.yaml"

# The name of the current working directory is used as the run name.
filename = pathlib.Path.cwd().stem

# Start training.
train(workspace, gpus, config_yaml, filename)

accompaniment: 225752
vocals: 225752


GPU available: True, used: True
lightning   : INFO     GPU available: True, used: True
TPU available: None, using: 0 TPU cores
lightning   : INFO     TPU available: None, using: 0 TPU cores
pytorch_lightning.accelerators.gpu: INFO     LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
lightning   : INFO     initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1

  | Name                    | Type                               | Params
-------------------------------------------------------------------------------
0 | batch_data_preprocessor | MixtureTargetBatchDataPreprocessor | 0     
1 | model                   | ResUNet143_Subbandtime             | 103 M 
-------------------------------------------------------------------------------
102 M     Trainable params
787 K     Non-trainable params
103 M     Total params
413.442   Total estimated model params size (MB)
lightning   : INFO     
  | Name                    | Type                               | Para

Epoch 0:   0%|          | 0/3 [00:03<?, ?it/s] 




Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  5.4644         	|  100 %          	|
------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  3.5489         	|1              	|  3.5489         	|  64.945         	|
get_train_batch                    	|  1.7855         	|1              	|  1.7855         	|  32.676         	|
run_training_batch                 	|  1.763          	|1              	|  1.763          	|  32.264         	|
optimizer_step_and_closure_0       	|  1.7601         	|1              	|  1.7601         	|  32.21          	|
training_step_and_backward         	|  1.

OutOfMemoryError: CUDA out of memory. Tried to allocate 160.00 MiB. GPU 0 has a total capacity of 3.71 GiB of which 83.88 MiB is free. Including non-PyTorch memory, this process has 3.61 GiB memory in use. Of the allocated memory 3.30 GiB is allocated by PyTorch, and 94.45 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)