In [None]:
%reset -f

# Configure matplotlib under the `notebook` setting and initialize figure size
import ipympl
import matplotlib
from matplotlib import cm
import seaborn

from matplotlib import pyplot as plt
#%matplotlib notebook
# Presentation-wide matplotlib defaults, applied in one go: wide figures,
# max-open-figures warning threshold set to -1, zero-size figure titles and
# no toolbar.
matplotlib.rcParams.update({
    'figure.figsize': (14, 7),
    'figure.max_open_warning': -1,
    'figure.titlesize': 0,
    'toolbar': 'none',
})
# Start from a clean slate: close any figure left over from a previous run.
plt.close('all')

# Do the imports
import ipywidgets
from src.beads import *

# Drawing canvases shared by all figures of the presentation.
# NOTE(review): arguments look like (min x, max x, #x points, min y, max y,
# #y points) -- confirm against src.beads.Canvas.
canvas_small = Canvas(-20, 20, 200, -10, 10, 100)
canvas_big = Canvas(-40, 40, 200, -20, 20, 200)
canvas_square = Canvas(-30, 30, 200, -30, 30, 200)

# One colormap per source; the number of sources follows from this list.
cmaps = [getattr(cm, name) for name in ('Reds', 'Blues', 'Purples')]
ns = len(cmaps)

# Initial value of C (re-assigned later by the "summing beads" cell).
C = 10

from IPython.core.display import HTML
# Inject a CSS rule that hides (`display: none`) every element matching
# `.output_wrapper button.btn.btn-default` or `.output_wrapper .ui-dialog-titlebar`.
# NOTE(review): presumably these are the button and title bar that the
# interactive figures add to cell outputs -- confirm in the rendered notebook.
HTML("""
<style>
.output_wrapper button.btn.btn-default,
.output_wrapper .ui-dialog-titlebar {
  display: none;
}
</style>""")

# Audio source separation with magnitude priors: the BEADS model 

## Antoine Liutkus$^1$, Christian Rohlfing$^2$, Antoine Deleforge$^3$

$^1$ Zenith team, Inria, University of Montpellier, France<p>
$^2$ RWTH, Aachen University, Germany<p>
$^3$ Inria Rennes - Bretagne Atlantique, France<p>

<div class="inline-block">
    <img src="figures/zenith.jpg" style="height:3em; margin-top:5em">
</div>
<div class="inline-block">
    <img src="figures/inria.png" style="height:3em">
</div>
<div class="inline-block">
    <img src="figures/rwth.svg" style="height:3em">
</div>
<div class="inline-block">
    <img src="figures/anr.png" style="height:3em">
</div>


# Context

## Separation of complex random variables

# The source separation problem 
For each Time-Frequency bin, the mixture is the sum of sources $x=\sum_j s_j$

In [None]:
fig1, ax1 = canvas_small.fig()

# Module-level state filled in by the click handler of this figure:
# the picked complex sources, their magnitudes, their variances and their sum.
true_sources, radius, sigmas = [], [], []
mix = 0

def get_sources(event):
    """Mouse-click handler of the first figure: pick the true sources.

    While fewer than ``ns`` sources have been picked, a click inside the axes
    adds one source at the clicked position (``xdata + 1j*ydata``), updates
    the shared state (``true_sources``, ``radius``, ``sigmas``, ``mix``) and
    redraws every picked source as an arrow starting at the origin.

    Once ``ns`` sources exist, any further click redraws them head-to-tail
    together with their sum, and builds the models used by the later cells
    (``sources_lgm``, ``mix_lgm``, ``sources_donut``).

    Parameters
    ----------
    event : matplotlib mouse event
        Only ``inaxes``, ``xdata`` and ``ydata`` are read.
    """
    global radius, true_sources, sigmas, mix, sources_lgm, mix_lgm, sources_donut

    if len(radius) == ns:
        # All sources picked: show them chained, plus the resulting mixture.
        # NOTE(review): the axes are cleared through canvas_big although the
        # figure was created from canvas_small -- presumably to make room for
        # the summed vector; confirm.
        canvas_big.clear(ax1)
        current = 0
        handles = []
        for true_source, color in zip(true_sources, cmaps):
            handles.append(Canvas.arrow(ax1, current, true_source, linewidth=3, color=color(200)))
            current += true_source
        handles.append(Canvas.arrow(ax1, 0, current, linewidth=4, color='black'))
        ax1.legend(handles, ['$s_%d$' % j for j in range(ns)] + ['mix'], fontsize=12)

        # defining the LGM and donut model according to the selected configuration
        sources_lgm = [Beads(0, 0, b**2*2/np.pi, 1) for b in radius]
        mix_lgm = GMM.product(sources_lgm)
        sources_donut = [Donut(0, r, sigma) for (r, sigma) in zip(radius, sigmas)]
        return

    # A click outside the axes carries no data coordinates (xdata/ydata are
    # None); ignore it instead of failing on `None + 1j*None`.
    if event.inaxes is None:
        return

    canvas_small.clear(ax1)
    x = event.xdata + 1j*event.ydata
    true_sources += [x]
    mix += x
    radius += [np.abs(x)]
    # Variance attached to the source: a tenth of its magnitude, floored at
    # 1.5, squared.
    sigmas += [max(np.abs(x)/10, 1.5)**2]

    handles = []
    for true_source, color in zip(true_sources, cmaps):
        handles.append(Canvas.arrow(ax1, 0, true_source, linewidth=3, color=color(200)))
    # Label only the arrows drawn so far, so handles and labels always match.
    ax1.legend(handles, ['$s_%d$' % j for j in range(len(handles))], fontsize=12)
    
# Run get_sources on every mouse-button press in fig1; `cid` keeps the value
# returned by mpl_connect for this registration.
cid = fig1.canvas.mpl_connect('button_press_event', get_sources)

# Typical separation pipeline

<img src="figures/source_separation_pipeline.svg" style="height:10em">

## In this talk
* __Filtering__ from magnitude estimates $b_j>0$ to separated signals $s_j\in\mathbb{C}$ 
* Tractable model for __complex variables $s_j$ with (approximately) known magnitude $b_j$__

## In the paper
* The multichannel case
* Evaluation for audio coding

The classical Gaussian model $s_j\sim\mathcal{N}\left(0, \frac{2}{\pi}b_j^2\right)$ matches the prior $\mathbb{E}\left[\left|s_j\right|\right]=b_j$

In [None]:
# Contour plot of the prior of the first source under the Gaussian model, with
# the first picked source drawn on top (zorder=100) as an arrow from the origin.
# `sources_lgm` is only assigned by get_sources once all `ns` sources have been
# picked and the first figure has been clicked once more, so this cell fails
# with a NameError before that.
fig_lgm, ax_lgm = canvas_big.fig()
sources_lgm[0].contour(canvas=canvas_big, ax=ax_lgm, nlines=50, cmap=cmaps[0])
arrow_h=Canvas.arrow(ax_lgm, 0, true_sources[0],  linewidth=3, color=cmaps[0](200), zorder=100)

$\Rightarrow$ Highest probability mass on 0

The mixture is Gaussian $x\sim\mathcal{N}\left(0,\sum_j \sigma_j^2\right)$ with $\sigma_j^2=\frac{2}{\pi}b_j^2$, and the sources are recovered as: $s_j\mid x\sim \mathcal{N}\left(\frac{b_j^2}{\sum_k b_k^2} x, \sigma_j^2\left(1 - \frac{b_j^2}{\sum_k b_k^2}\right)\right)$

In [None]:
fig_lgmdemo, ax_lgmdemo = canvas_big.fig()

x = mix
# Initial view: one dashed circle per source, at that source's magnitude.
for (true_source, rad, sig, cmap) in zip(true_sources, radius, sigmas, cmaps):
    Canvas.circle(ax_lgmdemo, 0, rad, color=cmap(200), linewidth=3, alpha=0.2,
                  fill=False, linestyle="--")


def lgm_demo(event):
    """Click handler of the Gaussian-model demo figure.

    A click inside the axes takes the clicked position as the mixture value;
    a click outside falls back to the true mixture ``mix`` and also draws the
    true sources as faint arrows.  The per-source posteriors (obtained with
    ``.post`` from ``sources_lgm`` and ``mix_lgm``) are plotted for inside
    clicks, and for outside clicks only when button 3 was pressed.
    """
    canvas_big.clear(ax_lgmdemo)
    clicked_outside = event.inaxes is None

    if clicked_outside:
        observed = mix
        for source, colormap in zip(true_sources, cmaps):
            Canvas.arrow(ax_lgmdemo, 0, source, color=colormap(200), alpha=0.2, linewidth=3)
    else:
        observed = event.xdata + 1j*event.ydata

    posteriors = [prior.post(mix_lgm, observed) for prior in sources_lgm]

    # Dashed circles showing the magnitude of each source.
    for _, magnitude, _, colormap in zip(true_sources, radius, sigmas, cmaps):
        Canvas.circle(ax_lgmdemo, 0, magnitude, color=colormap(200), linewidth=3,
                      alpha=0.2, fill=False, linestyle="--")

    if not clicked_outside or event.button == 3:
        for posterior, colormap in zip(posteriors, cmaps):
            posterior.plot(canvas_big, ax_lgmdemo, colormap(200))

    # The mixture itself, in black.
    Canvas.arrow(ax_lgmdemo, 0, observed, linewidth=3, facecolor="black")


cid = fig_lgmdemo.canvas.mpl_connect('button_press_event', lgm_demo)


$\Rightarrow$ Aligned estimated sources, magnitudes inconsistent with prior<p>
$\Rightarrow$ Over-estimate the strongest source, under-estimate the others<p>
$\Rightarrow$ Uncertainty independent of the mixture

Another classical solution is to use magnitude ratios:
$\hat{s}_j=\frac{b_j}{\sum_k b_k}x$

In [None]:
fig_magdemo, ax_magdemo = canvas_big.fig()
x = mix
# Initial view: one dashed circle per source, at that source's magnitude.
for (true_source, rad, sig, cmap) in zip(true_sources, radius, sigmas, cmaps):
    Canvas.circle(ax_magdemo, 0, rad, color=cmap(200), linewidth=3, alpha=0.2,
                  fill=False, linestyle="--")


def mag_demo(event):
    """Click handler of the magnitude-ratio demo figure.

    A click inside the axes takes the clicked position as the mixture value;
    a click outside falls back to the true mixture ``mix`` and also draws the
    true sources as faint arrows.  Each source estimate is the mixture scaled
    by that source's share of the summed magnitudes; estimates are plotted
    for inside clicks, and for outside clicks only when button 3 was pressed.
    """
    canvas_big.clear(ax_magdemo)
    clicked_outside = event.inaxes is None

    if clicked_outside:
        observed = mix
        for source, colormap in zip(true_sources, cmaps):
            Canvas.arrow(ax_magdemo, 0, source, color=colormap(200), alpha=0.2, linewidth=3)
    else:
        observed = event.xdata + 1j*event.ydata

    # Gain of each source: its magnitude over the sum of all magnitudes.
    shares = radius/np.sum(radius)
    estimates = [Bead(share*observed, None) for share in shares]

    # Dashed circles showing the magnitude of each source.
    for _, magnitude, _, colormap in zip(true_sources, radius, sigmas, cmaps):
        Canvas.circle(ax_magdemo, 0, magnitude, color=colormap(200), linewidth=3,
                      fill=False, alpha=0.2, linestyle="--")

    if not clicked_outside or event.button == 3:
        for estimate, colormap in zip(estimates, cmaps):
            estimate.plot(canvas_big, ax_magdemo, colormap(200))

    # The mixture itself, in black.
    Canvas.arrow(ax_magdemo, 0, observed, linewidth=3, facecolor="black")


cid = fig_magdemo.canvas.mpl_connect('button_press_event', mag_demo)


$\Rightarrow$ More balanced source estimates<p>
$\Rightarrow$ Still estimating aligned sources rather than complying with the magnitude prior<p>
$\Rightarrow$ No tractable uncertainty 

# An ideal model

## The donut-shaped distribution

## Objective
What do we want from a probabilistic model for a complex random variable with (approximately) known magnitude?

In [None]:
fig_donut, ax_donut = canvas_big.fig()
arrow_h = Canvas.arrow(ax_donut, 0, radius[0], linewidth=3, color=cmaps[0](200))
# Points accumulated so far by the click handler below.
points = []


def donut_intro_callback(event):
    """Click handler of the donut introduction figure.

    * click outside the axes: forget all points;
    * button 1 inside the axes: append ``sources_donut[0].draw(n)`` to the
      points, with ``n`` = 1 while fewer than 5 points exist and 10 afterwards;
    * any other button inside the axes: keep the points and show the contour
      of the first donut distribution instead of the magnitude arrow.
    """
    global points

    inside = event.inaxes is not None
    show_density = inside and event.button != 1

    if not inside:
        # reinitializes points
        points = []
    elif event.button == 1:
        batch = 10 if len(points) >= 5 else 1
        points = np.concatenate((points, sources_donut[0].draw(batch)))

    canvas_big.clear(ax_donut)
    ax_donut.plot(np.real(points), np.imag(points), 'o', color=cmaps[0](200), markersize=8)

    if show_density:
        sources_donut[0].contour(canvas_big, ax_donut, cmap=cmaps[0], nlines=10)
    else:
        Canvas.arrow(ax_donut, 0, radius[0], linewidth=3, color=cmaps[0](200))


cid = fig_donut.canvas.mpl_connect('button_press_event', donut_intro_callback)


## The Donut distribution for modeling the sources

In [None]:
# Static figure: the donut distribution of every source, with the picked
# sources drawn on top (zorder=1000) as black-edged arrows.
fig_sourcesdonut, ax_sourcesdonut = canvas_big.fig()
canvas_big.clear(ax_sourcesdonut)

for (sdonut, cmap) in zip(sources_donut, cmaps):
    sdonut.contour(canvas_big, ax_sourcesdonut, nlines=20, cmap=cmap)

handles = []
for (true_source, cmap) in zip(true_sources, cmaps):
    handles.append(
        Canvas.arrow(ax_sourcesdonut, 0, true_source, facecolor=cmap(200),
                     linewidth=3, zorder=1000, edgecolor='black')
    )

source_labels = ['$s_%d$' % j for j in range(ns)]
ax_sourcesdonut.legend(handles, source_labels, fontsize=12)
plt.show()

$\Rightarrow$ No model for the sum of donut variables<p>
$\Rightarrow$ No easy way to separate: $\mathbb{P}\left[s\mid x\right]$ is intractable

## **BEADS** Bayesian Expansion to Approximate the Donut Shape

In [None]:
# Side-by-side comparison on the small canvas: a donut distribution centred at
# one quarter of the canvas width, and its approximation centred at three
# quarters of the width.
fig_beadsintro, ax_beadsintro = canvas_small.fig()
len_canvas = canvas_small.maxx - canvas_small.minx
mu_beadsintro = canvas_small.minx+len_canvas*3/4
sigma_beadsintro = len_canvas*0.05
rad_beadsintro = len_canvas*0.22
donutintro = Donut(canvas_small.minx+len_canvas/4, rad_beadsintro, sigma_beadsintro)

# Click counter driving the approximation (the value 3 selects the single Bead).
cintro = 3
# Colormap of this figure, kept under a dedicated name: the callback below runs
# long after this cell, and the generic name `cmap` is rebound by module-level
# `for (..., cmap) in ...` loops in other cells.  The colormap is read as an
# attribute of `cm`, as done for `cmaps`, rather than through `cm.get_cmap`,
# which recent matplotlib releases no longer provide.
cmap_beadsintro = cm.Reds

donutintro.contour(canvas_small, ax_beadsintro, cmap=cmap_beadsintro, nlines=10)


def beads_intro_callback(event):
    """Click handler of the BEADS introduction figure.

    A click outside the axes resets ``cintro`` to 3; a click inside increments
    it.  The figure is then redrawn with the donut distribution next to a
    single ``Bead`` (when ``cintro`` is 3) or a ``Beads`` object built with
    ``cintro`` as its last argument.
    """
    global cintro
    if event.inaxes is None:
        cintro = 3
    else:
        cintro += 1

    if cintro == 3:
        beadsintro = Bead(mu_beadsintro, rad_beadsintro**2)
    else:
        beadsintro = Beads(mu_beadsintro, rad_beadsintro, sigma_beadsintro, cintro)

    canvas_small.clear(ax_beadsintro)
    donutintro.contour(canvas_small, ax_beadsintro, cmap=cmap_beadsintro, nlines=10)
    beadsintro.contour(canvas_small, ax_beadsintro, cmap=cmap_beadsintro, nlines=10)


cid = fig_beadsintro.canvas.mpl_connect('button_press_event', beads_intro_callback)

Sources distribution as a Gaussian Mixture Model: $P\left[s_j\right] = \sum_c \pi[c] \mathcal{N}\left(s_j\mid b_j \omega^c, \sigma_j\right)$<p>
$\Rightarrow$ Only two parameters: $b_j$ and $\sigma_j$

## Summing beads random variables
BEADS model for the sources $\Rightarrow$ Gaussian Mixture Model for the mixture

In [None]:
# BEADS models of the sources and of their mixture; None until the figure
# below has been clicked inside its axes.
sources_beads = None
mix_beads = None

fig_beadssources, ax_beadssources = canvas_square.fig(2)
for (sdonut, cmap) in zip(sources_donut, cmaps):
    sdonut.contour(canvas_square, ax_beadssources[0], nlines=20, cmap=cmap)
C = 10
ax_beadssources[0].set_title('Sources distributions')
ax_beadssources[1].set_title('Mixture distribution')


# NOTE(review): this definition reuses the name of the callback defined in the
# BEADS introduction cell and therefore shadows it at module level.
def beads_intro_callback(event):
    """Click handler of the "summing beads" figure.

    A click inside the axes rebuilds ``sources_beads`` and ``mix_beads`` from
    the current ``C`` and then increases ``C`` by 3.  Each source gets
    ``max(1, int(C*r/sum(radius)))`` as its count ``c``; a source with
    ``c == 1`` is modelled by ``Beads(0, 0, r**2, 1)``, the others by
    ``Beads(0, r, (2*pi*r/c/2)**2, c)``.  A click outside the axes discards
    both models and sets ``C`` to 12.

    The left axes then show the BEADS source models (or the donut
    distributions when no model exists); the right axes show the mixture
    model when it exists.
    """
    global C, sources_beads, mix_beads

    if event.inaxes is not None:
        r_sum = np.sum(radius)
        rebuilt = []
        for b in radius:
            c = max(1, int(C*b/r_sum))
            if c > 1:
                rebuilt.append(Beads(0, b, (2*np.pi*b/c/2)**2, c))
            else:
                rebuilt.append(Beads(0, 0, b**2, 1))
        sources_beads = rebuilt
        mix_beads = GMM.product(sources_beads)
        C += 3
    else:
        sources_beads = None
        mix_beads = None
        C = 12

    canvas_square.clear(ax_beadssources)
    if sources_beads is not None:
        for model, colormap in zip(sources_beads, cmaps):
            model.contour(canvas_square, ax_beadssources[0], nlines=20, cmap=colormap)
        mix_beads.contour(canvas_square, ax_beadssources[1], nlines=50, cmap=getattr(cm, 'Greens'))
    else:
        for donut, colormap in zip(sources_donut, cmaps):
            donut.contour(canvas_square, ax_beadssources[0], nlines=20, cmap=colormap)

    ax_beadssources[0].set_title('Sources distributions')
    ax_beadssources[1].set_title('Mixture distribution')


cid = fig_beadssources.canvas.mpl_connect('button_press_event', beads_intro_callback)

The sources are estimated through Bayes' theorem as $\mathbb{P}\left[s\mid x\right]=\sum_c \pi(c\mid x)\mathcal{N}(s\mid \mu_{c\mid x}, \sigma_{\mid x})$

In [None]:
# Initial view of the BEADS separation demo: every picked source as an arrow
# from the origin with a dashed circle at its magnitude, plus the true mixture
# in black.
fig_beadsdemo, ax_beadsdemo = canvas_big.fig()

# NOTE: this module-level `x` is not read by beads_demo, which assigns its own
# local `x`.
x = mix
# `sig` is unpacked but unused here; this module-level loop also rebinds the
# globals `true_source`, `rad`, `sig` and `cmap`.
for (true_source, rad, sig, cmap) in zip(true_sources, radius, sigmas, cmaps):
    Canvas.arrow(ax_beadsdemo, 0, true_source, color=cmap(200), alpha=0.8, linewidth=3)
    Canvas.circle(ax_beadsdemo, 0, rad,  color=cmap(200), linewidth=3, alpha=0.2, fill=False, linestyle="--")   
Canvas.arrow(ax_beadsdemo, 0, mix, linewidth=3, facecolor="black")

def beads_demo(event):
    """Click handler of the BEADS separation demo figure.

    A click inside the axes takes the clicked position as the mixture value;
    a click outside falls back to the true mixture ``mix`` and also draws the
    true sources as arrows.  The figure then shows a dashed circle at each
    source's magnitude, the contour of each source posterior (obtained with
    ``.post`` from ``sources_beads`` and ``mix_beads``) and the mixture in
    black.

    ``sources_beads`` and ``mix_beads`` are ``None`` until the "summing beads"
    figure has been clicked inside its axes, and again after it has been
    reset; in that case the posteriors are skipped instead of raising.

    Parameters
    ----------
    event : matplotlib mouse event
        Only ``inaxes``, ``xdata`` and ``ydata`` are read.
    """
    canvas_big.clear(ax_beadsdemo)
    if event.inaxes is None:
        x = mix
        for (true_source, cmap) in zip(true_sources, cmaps):
            Canvas.arrow(ax_beadsdemo, 0, true_source, color=cmap(200), alpha=0.8, linewidth=3, zorder=10000)
    else:
        x = event.xdata+1j*event.ydata

    # Dashed circles showing the magnitude of each source.
    for (true_source, rad, sig, cmap) in zip(true_sources, radius, sigmas, cmaps):
        Canvas.circle(ax_beadsdemo, 0, rad, color=cmap(200), linewidth=3, alpha=0.2, fill=False, linestyle="--")

    # Posteriors can only be computed once the BEADS models have been built.
    if sources_beads is not None and mix_beads is not None:
        sources_post = [s.post(mix_beads, x) for s in sources_beads]
        for (spost, cmap) in zip(sources_post, cmaps):
            spost.contour(canvas_big, ax_beadsdemo, nlines=10, cmap=cmap)

    # The mixture itself, in black.
    Canvas.arrow(ax_beadsdemo, 0, x, linewidth=3, facecolor="black")

# Run beads_demo on every mouse-button press in this figure; `cid` keeps the
# value returned by mpl_connect for this registration.
cid = fig_beadsdemo.canvas.mpl_connect('button_press_event', beads_demo)

* Estimates consistent with the magnitude prior
* Uncertainty is mix-dependent
* Posterior is tractable

# Conclusion: The beads model

## Core advantages
* Complex random variables with approximately known magnitudes
* A sum of beads sources is a GMM
* Separation is as easy as GMM inference

## To go further
* Generalizes easily to multichannel
* Shared variances for the beads $\Rightarrow$ computational savings

## Source code for this presentation
[https://github.com/aliutkus/beads-presentation](https://github.com/aliutkus/beads-presentation)