# Mixture Consistency Projection

$$
\hat{s_m} = \underline{s}_m + \frac{1}{M}(x - \sum_{m'}\underline{s}_{m'} )
$$
* $x$ is mixed input source
* $\underline{s}_m$ is the outputted seperate sources

In [1]:
import tensorflow as tf 
import numpy as np

In [33]:
def enforce_mixture_consistency_time_domain(mixture_waveforms,
                                            separated_waveforms):
    """Projection implementing mixture consistency in time domain.
        This projection makes the sum across sources of separated_waveforms equal
        mixture_waveforms and minimizes the unweighted mean-squared error between the
        sum across sources of separated_waveforms and mixture_waveforms. See
        https://arxiv.org/abs/1811.08521 for the derivation.
        Args:
            mixture_waveforms: Tensor of mixture waveforms in waveform format.
            separated_waveforms: Tensor of separated waveforms in source image format.
        Returns:
            Projected separated_waveforms as a Tensor in source image format.
    """
    # Modify the source estimates such that they sum up to the mixture, where
    # the mixture is defined as the sum across sources of the true source
    # targets. Uses the least-squares solution under the constraint that the
    # resulting source estimates add up to the mixture.
    
    num_sources = 4.0
    mix_estimate = tf.reduce_sum(separated_waveforms, axis=1, keepdims=True)
    #mix_weights = tf.reduce_mean(tf.square(separated_waveforms), axis=[1, 2],keepdims=True)
    #mix_weights /= tf.reduce_sum(mix_weights, axis=1, keepdims=True)
    mix_weights = (1.0 / num_sources)
    mix_weights = tf.cast(mix_weights, tf.float32)
    correction = mix_weights * (mixture_waveforms - mix_estimate)
    separated_waveforms = separated_waveforms + correction

    return separated_waveforms

In [34]:
mix_wave = tf.Variable([[1,2,3],[4,5,6]],dtype=tf.float32)
mix_wave = tf.expand_dims(mix_wave,axis=1)
print(mix_wave.shape)

sep_wave = tf.Variable([[[1,2,4],[3,5,7]],[[1,2,4],[3,5,7]]],dtype=tf.float32)
print(sep_wave.shape)

projected_sep = enforce_mixture_consistency_time_domain(mix_wave,sep_wave)
print(projected_sep)
print(projected_sep.shape)

(2, 1, 3)
(2, 2, 3)
tf.Tensor(
[[[0.25 0.75 2.  ]
  [2.25 3.75 5.  ]]

 [[1.   1.5  2.75]
  [3.   4.5  5.75]]], shape=(2, 2, 3), dtype=float32)
(2, 2, 3)
