## DCASE 2020 Task 4 Metrics and Loss
https://github.com/turpaultn/dcase20_task4

In [1]:
import numpy as np 
import tensorflow as tf

import warnings
import pprint
import functools
import itertools
import typing

## Metrics

In [2]:
def _resolve_permutation(loss_matrix):
    """Pick, per batch item, the one-to-one assignment with the lowest total loss.

    Every ordering of the `source` indices is enumerated and scored, so the
    amount of work grows as source! and is only practical for a few sources.

    Args:
        loss_matrix: tensor of shape [batch, source, source] holding the loss
            of every (axis-1 item, axis-2 item) pair.  `_apply` builds it with
            the reference on axis 1 and the estimate on axis 2.  Both leading
            dimensions must be statically known.

    Returns:
        permutation: integer tensor of shape [batch, source, 1].  Entry
            [b, i, 0] is the axis-2 index assigned to axis-1 row i, so
            tf.gather_nd(estimates, permutation, batch_dims=1) reorders the
            estimates to line up with the rows of axis 1.
    """
    num_batch, num_source = loss_matrix.shape[0], loss_matrix.shape[1]

    # All orderings of the source indices: [source!, source].
    orderings = tf.constant(list(itertools.permutations(range(num_source))))

    # Add a leading batch axis and a trailing index axis, then replicate over
    # the batch: [batch, source!, source, 1].
    candidates = tf.tile(orderings[tf.newaxis, :, :, tf.newaxis],
                         [num_batch, 1, 1, 1])

    # Give every candidate ordering its own copy of the loss matrix:
    # [batch, source!, source, source].
    num_candidates = candidates.shape[1]
    replicated_loss = tf.tile(tf.expand_dims(loss_matrix, 1),
                              [1, num_candidates, 1, 1])

    # For candidate p and row i pick loss[i, p[i]], then sum over rows to get
    # the total loss of each candidate: [batch, source!].
    picked_loss = tf.gather_nd(replicated_loss, candidates, batch_dims=3)
    total_loss = tf.math.reduce_sum(picked_loss, axis=2)

    # Index of the cheapest candidate per batch item: [batch, 1].
    best_candidate = tf.expand_dims(tf.math.argmin(total_loss, axis=1), 1)

    # Look the winning ordering up again: [batch, source, 1].
    return tf.gather_nd(candidates, best_candidate, batch_dims=1)


def _apply(loss_fn: typing.Callable[..., tf.Tensor],
           reference: tf.Tensor,
           estimate: tf.Tensor,
           allow_repeated: bool,
           enable: bool) -> typing.Any:
    """Return permutation invariant loss.

    Note that loss_fn must in general handle an arbitrary number of sources, since
    this function may expand in that dimension to get losses on all
    reference-estimate pairs.

    Args:
        loss_fn: function with the following signature:
            Args
                reference [batch, source', ...] tensor
                estimate [batch, source', ...] tensor
            Returns
                A [batch, source'] tensor of dtype=tf.float32
        reference: [batch, source, ...] tensor.
        estimate: [batch, source, ...] tensor.
        allow_repeated: If true, allow the same estimate to be used to match
          multiple references.
        enable: If False, apply the loss function in fixed order and return its
          value and the unpermuted estimates.

    Returns:
        loss, A [batch, source] tensor of dtype=tf.float32
        permuted_estimate, A tensor like estimate.
    """
    reference = tf.convert_to_tensor(reference)
    estimate = tf.convert_to_tensor(estimate)

    # Permutation search disabled: score the pairs in their given order.
    if not enable:
        return loss_fn(reference, estimate), estimate

    # The batch and source dimensions are read statically and used as Python
    # values below, so they must be known and must agree.
    assert reference.shape[:2] == estimate.shape[:2]
    batch = reference.shape[0]
    source = reference.shape[1]
    
    # Repeat the whole estimate sequence `source` times along axis 1:
    # [e0, e1, ..., e0, e1, ...].
    # estimate_tiled.shape will be (batch, source * source, ...)
    multiples = np.ones_like(estimate.shape)
    multiples[1] = source
    estimate_tiled = tf.tile(estimate, multiples)

    # Replicate reference on new axis 2, then combine axes [1, 2], which
    # repeats each reference `source` times in a row: [r0, r0, ..., r1, r1, ...].
    # reference_tiled.shape will be (batch, source * source, ...)
    reference_tiled = tf.expand_dims(reference, 2)
    multiples = np.ones_like(reference_tiled.shape)
    multiples[2] = source
    reference_tiled = tf.tile(reference_tiled, multiples)
    reference_tiled = tf.reshape(reference_tiled, estimate_tiled.shape)

    # Compute the loss matrix.
    # loss_matrix.shape will be (batch, source, source).
    # With the tiling above, flat position i * source + j pairs reference i
    # with estimate j, so axis 1 is the reference and axis 2 is the estimate.
    loss_matrix = tf.reshape(loss_fn(reference_tiled, estimate_tiled),
                           [batch, source, source])

    # Get the best permutation.
    # permutation.shape will be (batch, source, 1); entry [b, i, 0] is the
    # index of the estimate matched to reference i.
    if allow_repeated:
        # Each reference independently takes its lowest-loss estimate, so one
        # estimate may be selected more than once.
        permutation = tf.math.argmin(loss_matrix, axis=2, output_type=tf.int32)
        permutation = tf.expand_dims(permutation, 2)
    else:
        permutation = _resolve_permutation(loss_matrix)
    assert permutation.shape == (batch, source, 1), permutation.shape

    # Permute the estimates according to the best permutation, and pick the
    # matching entry of each loss_matrix row.
    estimate_permuted = tf.gather_nd(estimate, permutation, batch_dims=1)
    loss_permuted = tf.gather_nd(loss_matrix, permutation, batch_dims=2)

    return loss_permuted, estimate_permuted


def wrap(loss_fn: typing.Callable[..., tf.Tensor],
         allow_repeated: bool = False,
         enable: bool = True) -> typing.Callable[..., typing.Any]:
    """Build a permutation invariant version of loss_fn.

    Args:
        loss_fn: function with the following signature:
            Args
                reference [batch, source', ...] tensor
                estimate [batch, source', ...] tensor
                **args Any remaining arguments to loss_fn
            Returns
                A [batch, source'] tensor of dtype=tf.float32
        allow_repeated: If true, allow the same estimate to be used to match
          multiple references.
        enable: If False, return a function that applies the loss function in
          fixed order, returning its value and the (unpermuted) estimate.

    Returns:
        A function taking (reference, estimate, **args) and returning the pair
        (loss, permuted_estimate) produced by `_apply`.
    """
    def wrapped_loss_fn(reference, estimate, **args):
        # Bind the extra keyword arguments up front so `_apply` only ever sees
        # a two-argument (reference, estimate) loss.
        bound_loss_fn = functools.partial(loss_fn, **args)
        return _apply(bound_loss_fn, reference, estimate, allow_repeated, enable)

    return wrapped_loss_fn

In [3]:
def calculate_signal_to_noise_ratio_from_power(signal_power, noise_power, epsilon):
    """Convert signal and noise powers into a signal to noise ratio in dB.

    Evaluates 10 * log10((signal_power + epsilon) / (noise_power + epsilon))
    element-wise.

    Args:
        signal_power: A tensor of unknown shape and arbitrary rank.
        noise_power: A tensor matching the signal_power tensor.
        epsilon: A float added to both powers for numerical stability, since
            silences can lead to divide-by-zero.

    Returns:
        A tensor with the SNR computed between matching elements of the
        signal_power and noise_power tensors.
    """
    stabilized_ratio = tf.math.truediv(signal_power + epsilon,
                                       noise_power + epsilon)

    # 10 * log10(r) == (10 / ln(10)) * ln(r); the factor is cast to the input
    # dtype so the product does not mix dtypes.
    decibel_scale = tf.cast(10.0 / tf.math.log(10.0), signal_power.dtype)

    return decibel_scale * tf.math.log(stabilized_ratio)


def calculate_signal_to_noise_ratio(signal, noise, epsilon=1e-8):
    """Computes the signal to noise ratio given signal and noise.

    Args:
        signal: A [..., samples] tensor of unknown shape and arbitrary rank.
        noise: A tensor matching the signal tensor.
        epsilon: An optional float for numerical stability, since silences
          can lead to divide-by-zero.

    Returns:
        A tensor of size [...] with SNR computed between matching slices of the
        input signal and noise tensors.
    """
    def power(x):
        # Mean square over the sample axis.  This must use the argument `x`:
        # the previous version squared an unrelated global, which made the
        # signal and noise powers identical.
        return tf.math.reduce_mean(tf.square(x), axis=-1)

    return calculate_signal_to_noise_ratio_from_power(
        power(signal), power(noise), epsilon)


def signal_to_noise_ratio_gain_invariant(estimate, target, epsilon=1e-8):
    """Computes the signal to noise ratio in a gain invariant manner.

      The SNR is computed under the assumption that the signal equals the
      target multiplied by an unknown gain, and that the noise is orthogonal to
      the target.  This quantity is also known as SI-SDR [1, equation 5].

      The estimate follows a formula given e.g. in equation 4.38 from [2]:
      both inputs are scaled to unit norm, the signal power is the squared
      cosine similarity between them and the noise power is its complement.

      [1] Jonathan Le Roux, Scott Wisdom, Hakan Erdogan, John R. Hershey,
      "SDR--half-baked or well done?",ICASSP 2019,
      https://arxiv.org/abs/1811.02508.
      [2] Magnus Borga, "Learning Multidimensional Signal Processing"
      https://www.diva-portal.org/smash/get/diva2:302872/FULLTEXT01.pdf

      Args:
        estimate: An estimate of the target of size [..., samples].
        target: A ground truth tensor, matching estimate above.
        epsilon: An optional float added to the signal and noise powers in the
          final ratio for numerical stability.

      Returns:
        A tensor of size [...] with SNR computed between matching slices of the
        estimate and target tensors.
    """
    def to_unit_norm(x):
        # Energy is floored at 1e-16 so an all-zero input scales to zeros
        # instead of dividing by zero.
        energy = tf.math.reduce_sum(tf.square(x), keepdims=True, axis=[-1])
        inverse_norm = tf.math.rsqrt(tf.math.maximum(energy, 1e-16))
        return tf.math.multiply(x, inverse_norm)

    unit_estimate = to_unit_norm(estimate)
    unit_target = to_unit_norm(target)

    cosine = tf.math.reduce_sum(
        tf.math.multiply(unit_estimate, unit_target), axis=[-1])
    signal_power = tf.math.square(cosine)

    # Noise power as the complement of the signal power.
    noise_power_complement = 1. - signal_power

    # Noise power measured directly as the energy of the part of the estimate
    # that is orthogonal to the target.
    orthogonal_part = unit_estimate - unit_target * tf.expand_dims(cosine, -1)
    noise_power_direct = tf.math.reduce_sum(
        tf.math.square(orthogonal_part), axis=[-1])

    # The complement is a difference of two nearly equal floats when the noise
    # power is small, which loses accuracy; below 0.01 switch to the direct
    # measurement instead.
    noise_power = tf.where(
        tf.greater_equal(noise_power_complement, 0.01),
        noise_power_complement,
        noise_power_direct)

    return calculate_signal_to_noise_ratio_from_power(
        signal_power, noise_power, epsilon)


def signal_to_noise_ratio_residual(estimate, target, epsilon=1e-8):
    """Computes the signal to noise ratio using residuals.

    The signal is the original target and the noise is the residual between
    the target and the estimate; the result is the SNR of that pair as
    computed by `calculate_signal_to_noise_ratio`.

    Args:
        estimate: An estimate of the target of size [..., samples].
        target: A ground truth tensor, matching estimate above.
        epsilon: An optional float for numerical stability, since silences
          can lead to divide-by-zero.

    Returns:
        A tensor of size [...] with SNR computed between matching slices of the
        target and the residual.
    """
    residual = target - estimate
    return calculate_signal_to_noise_ratio(target, residual, epsilon=epsilon)

In [4]:
def _weights_for_nonzero_refs(source_waveforms):
    """Return a boolean mask of the reference signals that are nonzero.

    A reference counts as nonzero when its RMS over the last axis exceeds
    1e-8, so the mask has the shape of source_waveforms without that axis.
    """
    mean_square = tf.reduce_mean(tf.square(source_waveforms), axis=-1)
    rms_level = tf.sqrt(mean_square)
    return tf.greater(rms_level, 1e-8)

def _weights_for_active_seps(power_sources, power_separated):
    """Return a boolean mask of the separated signals that are active.

    A separated signal is active when its power exceeds 1% of the smallest
    value of power_sources along the last axis.
    """
    quietest_source_power = tf.reduce_min(power_sources, axis=-1, keepdims=True)
    activity_threshold = 0.01 * quietest_source_power
    return tf.greater(power_separated, activity_threshold)

In [5]:
def compute_final_metrics(source_waveforms, separated_waveforms, mixture_waveform):
    """Permutation-invariant SI-SNR, powers, and under/equal/over-separation.

    The separated signals are first reordered to best match the sources under
    negative SI-SNR; every other quantity is computed on that aligned order.

    NOTE(review): the active-signal counts are summed over every axis, so the
    under/equal/over-separation flags describe a single example only when the
    batch dimension is 1 -- confirm against callers.

    Args:
        source_waveforms: [batch, source, samples] reference signals.
        separated_waveforms: [batch, source, samples] separated signals.
        mixture_waveform: [batch, 1, samples] mixture, tiled across sources.

    Returns:
        A dict of tensors with the SI-SNR of the separated signals and of the
        mixture, their difference, the signal powers, the activity masks and
        counts, and the under/equal/over-separation indicators.
    """
    def negative_sisnr(tar, est):
        return -signal_to_noise_ratio_gain_invariant(est, tar)

    # Align the separated signals with the references.
    _, separated_waveforms = wrap(negative_sisnr)(source_waveforms,
                                                  separated_waveforms)

    # Mean-square power of every separated and source signal.
    power_separated = tf.reduce_mean(separated_waveforms ** 2, axis=-1)
    power_sources = tf.reduce_mean(source_waveforms ** 2, axis=-1)

    # A pair is active when its source is nonzero and its separated power is
    # above the quietest nonzero source power - 20 dB.
    weights_active_refs = _weights_for_nonzero_refs(source_waveforms)
    nonzero_source_powers = tf.boolean_mask(power_sources, weights_active_refs)
    weights_active_seps = _weights_for_active_seps(nonzero_source_powers,
                                                   power_separated)
    weights_active_pairs = tf.logical_and(weights_active_refs,
                                          weights_active_seps)

    def count_true(mask):
        return tf.math.reduce_sum(tf.cast(mask, tf.int32))

    # SI-SNR of the separated signals, then the activity counts, then the
    # SI-SNR of the unprocessed mixture as the baseline.
    sisnr_separated = signal_to_noise_ratio_gain_invariant(separated_waveforms,
                                                           source_waveforms)
    num_active_refs = count_true(weights_active_refs)
    num_active_seps = count_true(weights_active_seps)
    num_active_pairs = count_true(weights_active_pairs)
    tiled_mixture = tf.tile(mixture_waveform,
                            (1, source_waveforms.shape[1], 1))
    sisnr_mixture = signal_to_noise_ratio_gain_invariant(tiled_mixture,
                                                         source_waveforms)

    # Fewer, the same number of, or more active outputs than active sources.
    under_separation = tf.cast(tf.less(num_active_seps, num_active_refs),
                               tf.float32)
    equal_separation = tf.cast(tf.equal(num_active_seps, num_active_refs),
                               tf.float32)
    over_separation = tf.cast(tf.greater(num_active_seps, num_active_refs),
                              tf.float32)

    results = {}
    results['sisnr_separated'] = sisnr_separated
    results['sisnr_mixture'] = sisnr_mixture
    results['sisnr_improvement'] = sisnr_separated - sisnr_mixture
    results['power_separated'] = power_separated
    results['power_sources'] = power_sources
    results['under_separation'] = under_separation
    results['equal_separation'] = equal_separation
    results['over_separation'] = over_separation
    results['weights_active_refs'] = weights_active_refs
    results['weights_active_seps'] = weights_active_seps
    results['weights_active_pairs'] = weights_active_pairs
    results['num_active_refs'] = num_active_refs
    results['num_active_seps'] = num_active_seps
    results['num_active_pairs'] = num_active_pairs
    return results

In [6]:
def _report_score_stats(metric_per_source_count, label='', counts=None):
    """Report mean and std dev for specified counts."""
    values_all = []
    if counts is None:
        counts = metric_per_source_count.keys()
    for count in counts:
        values = metric_per_source_count[count]
        values_all.extend(list(values))
    return '%s for count(s) %s = %.1f +/- %.1f dB' % (label, counts, np.mean(values_all), np.std(values_all))

In [7]:
def getFinalMetricsFuss(sources,sep_wave,mix_wave,n_samples):
    """Evaluate the first n_samples examples and print aggregate FUSS metrics.

    Each example is scored on its own with `compute_final_metrics`; SI-SNR and
    SI-SNR improvement of the active (source, separated) pairs are bucketed by
    the number of active references of the example, and the summary is printed
    once after the last example.

    Args:
        sources: [example, source, samples] array of reference signals.
        sep_wave: [example, source, samples] array of separated signals.
        mix_wave: [example, 1, samples] array of mixtures.
        n_samples: number of leading examples to evaluate.

    Returns:
        None.  The report is written to stdout.
    """
    max_count = 4
    sisnr_per_source_count = {c: [] for c in range(1, max_count + 1)}
    sisnri_per_source_count = {c: [] for c in range(1, max_count + 1)}
    under_seps = []
    equal_seps = []
    over_seps = []

    with warnings.catch_warnings():
        # Buckets without any value make np.mean/np.std warn and return nan;
        # the nan is shown in the report, the warning is noise.
        warnings.simplefilter("ignore", category=RuntimeWarning)

        for i in range(n_samples):
            # Score one example at a time, keeping a batch axis of size 1.
            # The references come from the `sources` argument, not from a
            # global of the calling notebook.
            metrics_dict = compute_final_metrics(sources[i][np.newaxis],
                                                 sep_wave[i][np.newaxis],
                                                 mix_wave[i][np.newaxis])
            metrics_dict = {k: v.numpy() for k, v in metrics_dict.items()}

            sisnr_sep = metrics_dict['sisnr_separated']
            sisnr_imp = metrics_dict['sisnr_improvement']
            weights_active_pairs = metrics_dict['weights_active_pairs']

            under_seps.append(metrics_dict['under_separation'])
            equal_seps.append(metrics_dict['equal_separation'])
            over_seps.append(metrics_dict['over_separation'])

            # Store metrics per source count.  setdefault keeps an example
            # whose count falls outside 1..max_count from raising KeyError.
            num_active_refs = int(metrics_dict['num_active_refs'])
            sisnr_per_source_count.setdefault(num_active_refs, []).extend(
                sisnr_sep[weights_active_pairs].tolist())
            sisnri_per_source_count.setdefault(num_active_refs, []).extend(
                sisnr_imp[weights_active_pairs].tolist())

        # Build the report once, after the loop; this also keeps `lines`
        # defined when n_samples is 0.
        lines = [
            'Metrics after %d examples:' % len(under_seps),
            _report_score_stats(sisnr_per_source_count, 'SI-SNR',
                                counts=[1]),
            _report_score_stats(sisnri_per_source_count, 'SI-SNRi',
                                counts=[2]),
            _report_score_stats(sisnri_per_source_count, 'SI-SNRi',
                                counts=[3]),
            _report_score_stats(sisnri_per_source_count, 'SI-SNRi',
                                counts=[4]),
            _report_score_stats(sisnri_per_source_count, 'SI-SNRi',
                                counts=[2, 3, 4]),
            'Under separation: %.2f' % np.mean(under_seps),
            'Equal separation: %.2f' % np.mean(equal_seps),
            'Over separation: %.2f' % np.mean(over_seps),
        ]

        print('')
        for line in lines:
            print(line)

## Example of final metrics over fuss

In [8]:
# Toy inputs of shape (batch, source, samples) = (3, 4, 5):
# Batch size is 3
# Max number of sources is 4
# Length is 5 samples
# All-zero rows stand for sources that are absent from an example.

source_wave = np.array([[[1,5,7,3,5],[2,5,8,3,6],[0,0,0,0,0],[1,1,1,1,1]],[[.7,.7,.7,.7,.7],\
                        [1,4,6,4,3],[2,4,7,4,2],[2,2,2,2,2]],[[1,1,1,1,1],[0,0,0,0,0],\
                            [0,0,0,0,0],[0,0,0,0,0]]],dtype='float32')
print(source_wave.shape,source_wave.dtype)

# Separated signals, same shape as source_wave.
sep_wave = np.array([[[1.4,5.4,7.3,3.1,5.2],[2.1,5.7,8.3,3.4,6.6],[.1,.1,.1,.1,.1],\
                      [1,1,1,1,1]],[[1,1,1,1,1],[1.5,4.3,6.6,4.3,3.4],[2.5,4.7,7.3,4.6,2.4],\
                      [0,0,0,0,0]],[[1,1,1,1,1],[0,0,0,0,0],\
                       [0,0,0,0,0],[0,0,0,0,0]]],dtype='float32')
print(sep_wave.shape,sep_wave.dtype)

# The mixture is the sum of the sources; keepdims leaves a source axis of size 1.
mix_wave = np.sum(source_wave,axis=1,keepdims=True)
print(mix_wave.shape,mix_wave.dtype)

# Evaluate all 3 examples and print the aggregate report.
getFinalMetricsFuss(source_wave,sep_wave,mix_wave,3)

(3, 4, 5) float32
(3, 4, 5) float32
(3, 1, 5) float32

Metrics after 2 examples:
SI-SNR for count(s) [1] = 80.0 +/- 0.0 dB
SI-SNRi for count(s) [2] = nan +/- nan dB
SI-SNRi for count(s) [3] = 27.1 +/- 32.0 dB
SI-SNRi for count(s) [4] = 29.1 +/- 29.3 dB
SI-SNRi for count(s) [2, 3, 4] = 28.1 +/- 30.7 dB
Under separation: 0.33
Equal separation: 0.33
Over separation: 0.33


## Loss function

$$ 
L(s, \hat{s}) = \min_{\pi \in \Pi} \left[ \sum_{m_a = 1}^{M_a} L_{SNR}(s_{m_a}, \hat{s}_{\pi(m_a)}) + \sum_{m_o = M_a + 1}^{M} L_o(x, \hat{s}_{\pi(m_o)}) \right]
$$

$$
L_{SNR}(y, \hat{y}) = 10 \log_{10} ( ||y-\hat{y}||^2 + \tau ||y||^2 )
$$

$$
L_{o}(x, \hat{s}) = 10 \log_{10} ( ||\hat{s}||^2 + \tau ||x||^2 )
$$

In [None]:
def static_or_dynamic_dim_size(tensor, i):
    """Size of dimension `i`: the static value when truthy, else the dynamic one.

    A static size that is unknown (None) or zero falls back to the matching
    entry of tf.shape(tensor).
    """
    known_shape = tensor.shape
    runtime_shape = tf.shape(tensor)

    dim = known_shape[i]
    # Some shape entries wrap the integer in an object exposing `.value`.
    if hasattr(dim, 'value'):
        dim = dim.value

    return dim or runtime_shape[i]


def smart_shape(tensor):
    """Shape of tensor with static and/or dynamic dimensions.

    Args:
        tensor: A tf.Tensor.

    Returns:
        A tuple with one entry per dimension: the static size (type int) where
        available, otherwise the dynamic size (tf.Tensor).
    """
    rank = len(tensor.shape)
    return tuple(static_or_dynamic_dim_size(tensor, axis) for axis in range(rank))

In [None]:
def _stabilized_log_base(x, base=10., stabilizer=1e-8):
    """Logarithm of (x + stabilizer) in the given base.

    The stabilizer keeps the argument of the log away from zero.
    """
    natural_log = tf.math.log(x + stabilizer)
    # Change of base: log_b(v) = ln(v) / ln(b), with ln(b) in the same dtype.
    log_of_base = tf.math.log(tf.constant(base, dtype=natural_log.dtype))
    return natural_log / log_of_base

def log_mse_loss(source, separated, max_snr=1e6, bias_ref_signal=None):
    """Log of the squared error plus a soft threshold, in dB.

    Returns 10 * log10(||source - separated||^2 + tau * ||ref||^2) over the
    last axis, where tau = 10 ** (-max_snr / 10) and ref is bias_ref_signal
    when given, otherwise source.

    Args:
        source: [..., samples] reference signal.
        separated: [..., samples] estimate, matching source.
        max_snr: SNR in dB that sets the threshold factor tau.
        bias_ref_signal: optional [..., samples] signal whose energy scales the
            threshold instead of the energy of source.

    Returns:
        A [...] tensor with one loss value per signal.
    """
    error_energy = tf.math.reduce_sum(tf.math.square(source - separated), axis=-1)

    threshold_factor = 10.**(-max_snr / 10.)
    threshold_signal = source if bias_ref_signal is None else bias_ref_signal
    threshold_energy = tf.math.reduce_sum(tf.math.square(threshold_signal), axis=-1)

    soft_threshold = threshold_factor * threshold_energy
    return 10. * _stabilized_log_base(soft_threshold + error_energy)

In [None]:
def groupwise_apply(loss_fns: typing.Dict[str, typing.Callable[..., typing.Any]],
          signal_names: typing.List[str],
          reference: tf.Tensor,
          estimate: tf.Tensor,
          permutation_invariant_losses: typing.List[str]):
    """Apply loss functions to the corresponding references and estimates.

    For each kind of signal, gather corresponding references and estimates, and
    apply the loss function.  Scatter-add the results into the loss.

    For elements of signal_names not in loss_fns, no loss will be applied, and
    names in loss_fns that match no element of signal_names are skipped.

    Args:
        loss_fns: dictionary of string -> loss_fn.
            Each string is a name to match elements of signal_names.
            Each loss_fn has the following signature:
                Args
                    reference [batch, grouped_source, ...] tensor
                    estimate [batch, grouped_source, ...] tensor
                Returns
                    A [batch, grouped_source] tensor of dtype=tf.float32
        signal_names: list of names of each signal.
        reference: [batch, source, ...] tensor.
        estimate: [batch, source, ...] tensor.
        permutation_invariant_losses: List of losses to be permutation invariant.

    Returns:
        loss, A [batch, source] tensor of dtype=tf.float32
        permuted_estimates, A tensor like reference holding, for every group,
            the estimates in the order chosen for that group.

    Raises:
        ValueError: if the (batch, source) axes of reference and estimate differ.
    """
    if reference.shape[:2] != estimate.shape[:2]:
        raise ValueError('First two axes (batch, source) of reference and estimate '
                         'must be equal, got {}, {}'.format(
                             reference.shape[:2], estimate.shape[:2]))

    batch = reference.shape[0]
    loss = tf.zeros(shape=reference.shape[:2], dtype=tf.float32)
    permuted_estimates = tf.zeros_like(reference)

    # For each kind of signal, e.g. 'speech', 'noise', gather subsets of reference
    # and estimate, apply loss function and scatter-add into the loss tensor.
    for name, loss_fn in loss_fns.items():
        idxs = [idx for idx, value in enumerate(signal_names) if value == name]
        if not idxs:
            # No signal of this kind: nothing to gather or scatter.
            continue

        # Build [batch, len(idxs), 2] (batch index, source index) pairs that
        # address this group's signals in every batch item.
        idxs_0 = tf.tile(
            tf.expand_dims(tf.range(batch), 1),
            [1, len(idxs)])
        idxs_1 = tf.tile(
            tf.expand_dims(tf.constant(idxs, dtype=tf.int32), 0),
            [batch, 1])
        idxs_nd = tf.stack([idxs_0, idxs_1], axis=2)

        reference_key = tf.gather_nd(reference, idxs_nd)
        estimate_key = tf.gather_nd(estimate, idxs_nd)

        # Search over permutations only for the losses that ask for it.
        group_loss_fn = wrap(
            loss_fn,
            enable=name in permutation_invariant_losses)
        loss_key, permuted_estimates_key = group_loss_fn(reference_key, estimate_key)

        # Both accumulators start at zero, so the scatter-add writes the
        # group's results into their original source slots.
        loss = tf.tensor_scatter_nd_add(loss, idxs_nd, loss_key)
        permuted_estimates = tf.tensor_scatter_nd_add(
            permuted_estimates, idxs_nd, permuted_estimates_key)

    return loss, permuted_estimates

In [None]:
def getFussLoss(mixture_waveforms,source_waveforms,separated_waveforms,batch_size):
    """Training loss split between all-zero and nonzero reference signals.

    The separated signals are first reordered to best match the references
    under `log_mse_loss`.  Pairs whose reference is nonzero are scored with
    `log_mse_loss` at max_snr=30; pairs whose reference is all-zero are scored
    at max_snr=20 with the mixture as the threshold reference.  Both sums are
    divided by batch_size * num_sources and added.

    Args:
        mixture_waveforms: [batch, 1, samples] mixture signals.
        source_waveforms: [batch, source, samples] reference signals; the
            number of sources is read from its static shape.
        separated_waveforms: [batch, source, samples] separated signals.
        batch_size: number of examples, used to normalise the loss.

    Returns:
        A scalar tensor with the total loss.
    """
    # Take the number of sources from the input instead of assuming 4.
    num_sources = int(source_waveforms.shape[1])

    # Every signal is of the same kind, so a single permutation-invariant
    # group covers all sources.
    hparams_signal_types = ['source'] * num_sources
    unique_signal_types = list(set(hparams_signal_types))
    loss_fns = {signal_type: log_mse_loss for signal_type in unique_signal_types}

    _, separated_waveforms = groupwise_apply(loss_fns,
                                            hparams_signal_types,
                                            source_waveforms,
                                            separated_waveforms,
                                            unique_signal_types)

    # Build loss split between all-zero and nonzero reference signals.
    source_is_nonzero = _weights_for_nonzero_refs(source_waveforms)
    source_is_zero = tf.math.logical_not(source_is_nonzero)

    # Waveforms with nonzero references, one signal per row: [n, 1, samples].
    source_waveforms_nonzero = tf.boolean_mask(
          source_waveforms, source_is_nonzero)[:, tf.newaxis]
    separated_waveforms_nonzero = tf.boolean_mask(
          separated_waveforms, source_is_nonzero)[:, tf.newaxis]

    # Waveforms with all-zero references.
    source_waveforms_zero = tf.boolean_mask(
          source_waveforms, source_is_zero)[:, tf.newaxis]
    separated_waveforms_zero = tf.boolean_mask(
          separated_waveforms, source_is_zero)[:, tf.newaxis]

    weight = 1. / tf.cast(batch_size * num_sources, tf.float32)

    # Loss for zero references.  The reference has no energy of its own, so
    # the mixture sets the scale of the soft threshold.
    mixture_waveforms_zero = tf.boolean_mask(
            tf.tile(mixture_waveforms[:, 0:1], (1, num_sources, 1)),
            source_is_zero)[:, tf.newaxis]
    loss = tf.math.reduce_sum(log_mse_loss(source_waveforms_zero,
                                          separated_waveforms_zero,
                                          max_snr=20,
                                          bias_ref_signal=mixture_waveforms_zero))
    loss_zero = tf.identity(1 * weight * loss, name='loss_ref_zero')

    # Loss for nonzero references.
    loss = tf.math.reduce_sum(log_mse_loss(source_waveforms_nonzero,
                                        separated_waveforms_nonzero,
                                        max_snr=30))
    loss_nonzero = tf.identity(weight * loss, name='loss_ref_nonzero')

    return loss_zero+loss_nonzero

## Example of loss

In [None]:
# Toy inputs of shape (batch, source, samples) = (2, 4, 2).
# All-zero rows stand for sources that are absent from an example.
source_waveforms = np.array([[[1,5],[2,5],[0,0],[1,1]],
                            [[0,0],[2,5],[0,0],[1,1]]],dtype='float32')
print(source_waveforms.shape) 

# Separated signals, same shape as source_waveforms.
separated_waveforms = np.array([[[2.1,5.7],[1.4,5.4],
                                 [1,1],[.1,.1]],[[0,0],[1.4,5.4],
                                 [1,1],[.1,.1]]],dtype='float32')
print(separated_waveforms.shape)

# The mixture is the sum of the sources; keepdims leaves a source axis of size 1.
mixture_waveforms = np.sum(source_waveforms,keepdims=True,axis=1)
print(mixture_waveforms.shape)

In [None]:
# Total loss for the toy batch; the last argument is the batch size (2).
loss = getFussLoss(mixture_waveforms,source_waveforms,separated_waveforms,2)
print(loss)