<h1> Mixture Density Network </h1>

This notebook essentially reproduces the "Simple Inverse Problem" presented in
<a href="http://publications.aston.ac.uk/373/1/NCRG_94_004.pdf">Mixture Density Networks (Bishop, 1994)</a>.

It shows the weakness of a standard neural net, trained on a sum-of-squares error, when the target data has more than one mode for a given input: such a net approximates the conditional average of the target, which can fall between the modes. This motivates training the net to output the parameters of a mixture model instead of a single conditional average.
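To make the "conditional average" point concrete, here is a minimal Julia sketch with made-up targets (not the notebook's data). If the same input is paired with the targets 0.2 and 0.8 equally often, the single prediction that minimizes the mean squared error is 0.5, a value that none of the targets actually take.

```julia
# Hypothetical bimodal targets observed for one and the same input
targets = [0.2, 0.2, 0.8, 0.8]

# Mean squared error of predicting the single value y for every sample
mse(y) = sum((targets .- y).^2) / length(targets)

# Brute-force search over candidate predictions in [0, 1]
candidates = 0:0.01:1
losses = [mse(y) for y in candidates]
best = candidates[findmin(losses)[2]]

println(best)   # the squared-error optimum sits between the two modes
```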

We create 2 datasets and compare 2 neural net architectures. In total, 4 neural nets are trained in this notebook.

<h2> Two data sets that are really the same data:</h2>

In [1]:
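# Bishop's toy data: t = x + 0.3 sin(2πx) + noise drawn uniformly from [-0.1, 0.1)
datafunction(x) = x + 0.3*sin(2*π*x) + rand()*0.2-0.1
x = zeros(1,1001)            # 1×1001 matrix of inputs
x .= collect(0:0.001:1)'
t = datafunction.(x)         # noisy targets, same shape as x

using Plots; gr(size=(800,400), ylim=(0,1), xlim=(0,1), leg=false)
p1=scatter(x',t', title="easy")   # x -> t is single-valued
p2=scatter(t',x', title="hard")   # t -> x is multi-valued for t near 0.5
plot(p1,p2)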
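Why the "hard" direction is hard can be checked directly on the noise-free curve. The sketch below counts how many x values in [0, 1] map to a given t by looking for sign changes on a fine grid; the helper `count_preimages` and the grid resolution are illustrative choices, not part of the notebook. For t = 0.3 or 0.7 there is one solution, but for t = 0.5 there are three, so t->x is not a function.

```julia
f(x) = x + 0.3*sin(2π*x)   # datafunction without the noise term

# Count solutions of f(x) = t0 on [0, 1] via sign changes of f(x) - t0 on a grid.
# A grid value that is exactly zero is treated as non-negative, so each root is counted once.
function count_preimages(t0; xs = 0:1e-4:1)
    g = [f(x) - t0 for x in xs]
    return count(k -> (g[k] < 0) != (g[k+1] < 0), 1:length(g)-1)
end

println(count_preimages(0.3))   # one x maps to t = 0.3
println(count_preimages(0.5))   # x ≈ 0.21, 0.5 and 0.79 all map to t = 0.5
```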

<h2> Functions for implementing a Standard Neural Net and a Mixture Density Net</h2>

In [2]:
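include("../toy_ffnet.jl")   # defines the `nn` module used below (preallocate, fprop!, bprop!, adjust!, inference, softmax)

# Gradient of the (half) squared error w.r.t. the network output
function ∇MSE!(∇z, y,t)
    #get bprop started
    @. ∇z[2] = y[2]-t
end

# Isotropic Gaussian kernel in c dimensions with standard deviation σ
Φ(t,μ,σ,c) = 1/(σ^c*(2*π)^(c/2)) * exp(-norm(t-μ)^2/(2*σ^2))

# Gradient of the mixture negative log-likelihood w.r.t. the raw outputs z (Bishop, 1994)
function ∇MD!(∇z, z,t, m,c)
    μ,σ,α = preds2params(z, m)
    p = α.*hcat([Φ.(t[1,:],μ[i,:],σ[i,:], c) for i=1:m]...)'   # α_i φ_i, an m×batch matrix
    pp = p./sum(p,1)                                            # posterior probability of each component
    
    #get bprop started
    ∇zμ = pp.*hcat([(μ[i,:].-t[1,:])./σ[i,:].^2 for i=1:m]...)'
    ∇zσ = -pp.*hcat([(t[1,:].-μ[i,:]).^2 ./ σ[i,:].^2 .- c for i=1:m]...)'
    ∇zα = α.-pp
    ∇z[2] = vcat(∇zμ,∇zσ,∇zα)
end

# Split raw outputs into means, standard deviations (exp keeps them positive) and
# mixing coefficients (softmax makes them sum to 1). Assumes c = 1: one mean per component.
function preds2params(z, m)
    μ = z[1:m,:]
    σ = exp.(z[m+1:2*m,:])
    α = nn.softmax(z[2*m+1:3*m,:])
    return μ,σ,α
end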
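A quick way to check gradient formulas like those in `∇MD!` is a finite-difference test. The standalone sketch below re-implements the mixture negative log-likelihood for a single scalar target with c = 1, using the same parameterization as `preds2params` (identity for means, exp for σ, softmax for mixing coefficients), and compares Bishop's analytic gradients with central differences. The helper names (`mdn_nll`, `mdn_grad`, `numgrad`) and the parameter values in `z` are hypothetical.

```julia
# Gaussian kernel for a scalar target (c = 1 here)
Φ(t, μ, σ, c) = 1 / (σ^c * (2π)^(c/2)) * exp(-(t - μ)^2 / (2σ^2))

softmax(v) = exp.(v .- maximum(v)) ./ sum(exp.(v .- maximum(v)))

# Negative log-likelihood of target t under the mixture encoded by z = [zμ; zσ; zα]
function mdn_nll(z, t, m, c)
    μ = z[1:m]; σ = exp.(z[m+1:2m]); α = softmax(z[2m+1:3m])
    return -log(sum(α[i] * Φ(t, μ[i], σ[i], c) for i in 1:m))
end

# Analytic gradient w.r.t. z, same formulas as ∇MD!
function mdn_grad(z, t, m, c)
    μ = z[1:m]; σ = exp.(z[m+1:2m]); α = softmax(z[2m+1:3m])
    p = [α[i] * Φ(t, μ[i], σ[i], c) for i in 1:m]
    post = p ./ sum(p)                          # posterior of each component
    gμ = post .* (μ .- t) ./ σ.^2
    gσ = -post .* ((t .- μ).^2 ./ σ.^2 .- c)
    gα = α .- post
    return vcat(gμ, gσ, gα)
end

# Central finite differences of a scalar function f at z
function numgrad(f, z; h=1e-6)
    g = zeros(length(z))
    for k in eachindex(z)
        zp = copy(z); zp[k] += h
        zm = copy(z); zm[k] -= h
        g[k] = (f(zp) - f(zm)) / (2h)
    end
    return g
end

m, c, t0 = 3, 1, 0.45
z = [0.2, 0.5, 0.8, log(0.1), log(0.15), log(0.1), 0.0, 0.7, 0.0]
ga = mdn_grad(z, t0, m, c)
gn = numgrad(zz -> mdn_nll(zz, t0, m, c), z)
println(maximum(abs.(ga .- gn)))   # small if the analytic gradient matches the loss
```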


<h2> Standard Neural Net vs. Mixture Density Net</h2>

In [3]:
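# Train a standard feed-forward net with mini-batch gradient updates on the squared error.
# Lw: layer widths, bsz: batch size, Mepoch: number of epochs, α: learning rate
function trainFFN(x, t, Lw, bsz, Mepoch, α)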
    W,b,z,y, ∇W,∇b,∇z,∇y = nn.preallocate(Lw,bsz)
    for epoch=1:Mepoch
        idx=randperm(size(x,2))
        for n=1:bsz:size(x,2)-bsz
            bidx=idx[n:n+bsz-1]
            xb=x[bidx]'
            nn.fprop!(xb, W,b,z,y)
            ∇MSE!(∇z, y,t[bidx]') # mean squared error output
            nn.bprop!(∇W,∇z,∇y, W,xb,y)
            nn.adjust!(W,b, ∇W,∇z, α)
        end
    end
    return W, b
end

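# Same training loop, but the output layer parameterizes a mixture of m Gaussians
# over a c-dimensional target and is trained on the mixture negative log-likelihood.
function trainMDN(x, t, Lw, m,c, bsz, Mepoch, α)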
    W,b,z,y, ∇W,∇b,∇z,∇y = nn.preallocate(Lw,bsz)
    for epoch=1:Mepoch
        idx=randperm(size(x,2))
        for n=1:bsz:size(x,2)-bsz
            bidx=idx[n:n+bsz-1]
            xb=x[bidx]'
            nn.fprop!(xb, W,b,z,y)
            ∇MD!(∇z, z[2], t[bidx]', m,c) #mixture model output
            nn.bprop!(∇W,∇z,∇y, W,xb,y)
            nn.adjust!(W,b, ∇W,∇z, α)
        end
    end
    return W, b
end
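The batching logic shared by `trainFFN` and `trainMDN` can be isolated as below. This is a standalone sketch: `minibatches` is a hypothetical helper, and `sortperm(rand(N))` stands in for `randperm` as another way to draw a random permutation. With N = 1001 samples and a batch size of 16, the loop bounds `1:bsz:N-bsz` give 62 full batches, so 992 samples are visited per epoch and 9 are skipped; because the order is reshuffled every epoch, different samples are skipped each time.

```julia
# Shuffled mini-batch indices with the same loop bounds as trainFFN/trainMDN:
#   for n = 1:bsz:N-bsz  ->  batch = idx[n:n+bsz-1]
function minibatches(N, bsz)
    idx = sortperm(rand(N))   # a random permutation of 1:N
    return [idx[n:n+bsz-1] for n in 1:bsz:N-bsz]
end

batches = minibatches(1001, 16)
println(length(batches))            # 62 batches per epoch
println(length(vcat(batches...)))   # 992 of the 1001 samples are used
```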


<h2>Train standard nets</h2>

In [4]:
Layerwidths=[1,20,1]
batchsize=16
learningrate=0.01
epochs=10000

We, be = trainFFN(x,t, Layerwidths, batchsize, epochs, learningrate) #easy data (x->t mapping)
Wh, bh = trainFFN(t,x, Layerwidths, batchsize, epochs, learningrate) #hard data (t->x mapping)

xeval=collect(0:0.01:1)'
preds_easy = nn.inference(xeval, We, be)
preds_hard = nn.inference(xeval, Wh, bh);

<h2> This is what they learn</h2>
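It models the x->t mapping easily but struggles with the t->x mapping. Trained on the squared error, the net approximates the conditional average of x for each input, so where one input corresponds to several x values it can only return a single averaged value instead of following the separate branches.

The red dots show what the net predicts for each input.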

In [5]:
p1=scatter(x',t', title="easy"); scatter!(p1, xeval', preds_easy')
p2=scatter(t',x', title="hard"); scatter!(p2, xeval', preds_hard')
plot(p1,p2)

<h2> Train Mixture Density Nets</h2>

In [6]:
learningrate=0.0001
c=1  #dimensionality of the output
m1=1 #number of mixture components for the easy data
m2=3 #number of mixture components for the hard data
Layerwidths1=[1,20,(2+c)*m1]   #per component: c means, 1 σ, 1 mixing coefficient
Layerwidths2=[1,20,(2+c)*m2]

We, be = trainMDN(x,t, Layerwidths1, m1,c, batchsize, epochs, learningrate) #easy data (x->t mapping)
Wh, bh = trainMDN(t,x, Layerwidths2, m2,c, batchsize, epochs, learningrate) #hard data (t->x mapping)

preds_easy = nn.inference(xeval, We, be); μe, σe, αe = preds2params(preds_easy, m1)
preds_hard = nn.inference(xeval, Wh, bh); μh, σh, αh = preds2params(preds_hard, m2);
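The output layer width `(2+c)*m` counts, for each of the m components, c means, one standard deviation and one mixing coefficient; with c = 1 and m = 3 that is 9 outputs. `preds2params` slices the outputs as if c = 1. The sketch below is a hypothetical generalization for a single output vector and any c: `preds2params_c` is not part of the notebook, and storing the means first, component by component, is an assumed layout.

```julia
softmax(v) = exp.(v .- maximum(v)) ./ sum(exp.(v .- maximum(v)))

# Assumed layout of one (c+2)*m output vector z:
#   z[1:c*m]          -> means, component by component (reshaped to c×m)
#   z[c*m+1:c*m+m]    -> log standard deviations, one per component
#   z[c*m+m+1:c*m+2m] -> mixing-coefficient logits
function preds2params_c(z::AbstractVector, m::Int, c::Int)
    @assert length(z) == (c + 2) * m
    μ = reshape(z[1:c*m], c, m)
    σ = exp.(z[c*m+1:c*m+m])
    α = softmax(z[c*m+m+1:c*m+2m])
    return μ, σ, α
end

# c = 1, m = 3: (1 + 2) * 3 = 9 outputs, the same split as preds2params
μ, σ, α = preds2params_c(collect(1.0:9.0), 3, 1)
```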

<h2> This is what they learn</h2>
It also models the x->t mapping easily, and it does a better job with the t->x mapping.

The red dots show the mean of the most probable mixture component for each input.

In [7]:
p1=scatter(x',t', title="easy"); scatter!(p1, xeval', μe[findmax(αe,1)[2]]')
p2=scatter(t',x', title="hard"); scatter!(p2, xeval', μh[findmax(αh,1)[2]]')
plot(p1,p2)
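`μe[findmax(αe,1)[2]]` picks, for each input (column), the mean of the component with the largest mixing coefficient; in Julia 0.6, `findmax(A, 1)` returns the linear indices of the column maxima. Julia 1.x replaced this form with the keyword version `findmax(A; dims=1)`, which returns `CartesianIndex` values. A Julia 1.x sketch of the same selection, using small made-up matrices (`most_probable_mean`, `means` and `weights` are hypothetical), is:

```julia
# Hypothetical parameters: 3 components (rows) for 2 inputs (columns)
means   = [0.2 0.1;
           0.5 0.4;
           0.8 0.9]
weights = [0.25 0.7;
           0.50 0.2;
           0.25 0.1]

# For each column, take the mean of the component with the largest mixing coefficient
most_probable_mean(μ, α) = [μ[argmax(α[:, j]), j] for j in 1:size(α, 2)]

println(most_probable_mean(means, weights))
```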

<h2> Not only that</h2>
We also get the means, standard deviations and mixing coefficients of the mixture components. A heatmap of the resulting conditional density shows more clearly what the mixture density net has learned. The plot above only shows the mean of the most probable component, but all components contribute to the predicted probability density.
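For one input, the mixture's conditional density is p(t|x) = Σᵢ αᵢ(x) φᵢ(t|x). The sketch below evaluates it for hypothetical parameters of an input where three branches overlap (the values in `αs`, `μs` and `σs` are invented for illustration), checks that it integrates to about 1, and shows that the outer components carry real probability mass even though the most probable component has its mean at 0.5.

```julia
# Gaussian kernel with c = 1 (same form as Φ in the notebook)
φ(t, μ, σ) = exp(-(t - μ)^2 / (2σ^2)) / (σ * sqrt(2π))

# Hypothetical mixture parameters for a single input
αs = [0.25, 0.5, 0.25]
μs = [0.2, 0.5, 0.8]
σs = [0.03, 0.03, 0.03]

# Conditional density p(t | x) = Σ_i α_i φ_i(t)
density(t) = sum(αs[i] * φ(t, μs[i], σs[i]) for i in eachindex(αs))

ts = -0.5:0.001:1.5
mass = sum(density.(ts)) * step(ts)                  # close to 1 because the α_i sum to 1
lower = sum(density.(0.1:0.001:0.3)) * 0.001         # mass near the lower branch
println((mass, μs[argmax(αs)], lower))
```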

For an input value of 0.5, the mixture density net trained on the hard data says that the output is
<ul>
<li>about 0.2 with 25% probability,</li>
<li>about 0.5 with 50% probability,</li>
<li>about 0.8 with 25% probability.</li>
</ul>

A standard neural net trained on the squared error cannot express this: it returns a single value for each input.
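These proportions can be roughly cross-checked against the data-generating process. Assuming the inputs x are spread uniformly and ignoring the noise, the fraction of training points near t = 0.5 that come from each branch is proportional to 1/|f'(xᵢ)| at the branch's solution xᵢ, which is what a well-trained net's mixing coefficients should approach. The sketch below finds the three solutions by bisection between the turning points of f; the helper `bisect` is a hypothetical name.

```julia
f(x)  = x + 0.3*sin(2π*x)      # noise-free data function
df(x) = 1 + 0.6π*cos(2π*x)     # its derivative

# Bisection on [lo, hi], assuming g(lo) and g(hi) have opposite signs
function bisect(g, lo, hi; iters=60)
    glo = g(lo)
    for k in 1:iters
        mid = (lo + hi)/2
        if sign(g(mid)) == sign(glo)
            lo, glo = mid, g(mid)
        else
            hi = mid
        end
    end
    return (lo + hi)/2
end

t0 = 0.5
g(x) = f(x) - t0
xturn = acos(-1/(0.6π))/(2π)   # f turns where cos(2πx) = -1/(0.6π), x ≈ 0.339 and 1 - 0.339
brackets = [(0.0, xturn), (xturn, 1 - xturn), (1 - xturn, 1.0)]
roots = [bisect(g, lo, hi) for (lo, hi) in brackets]
w = [1/abs(df(r)) for r in roots]
probs = w ./ sum(w)
println((roots, probs))
```

The result, about 27%, 45% and 27% at x ≈ 0.21, 0.5 and 0.79, is in line with the rounded 25% / 50% / 25% reading.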

<h4>In the heatmaps below, color intensity corresponds to the probability density that the output, given an input, lies in a particular region.</h4>

In [8]:
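hm_easy = zeros(101,101)   # rows: output value, columns: input value
hm_hard = zeros(101,101)
for X=1:101                # X indexes the inputs xeval = 0:0.01:1
    for Y=1:101
        teval = (Y-1)*0.01 # output values 0:0.01:1, matching the input grid
        hm_easy[Y,X] = sum([αe[i,X]*Φ(teval, μe[i,X],σe[i,X],c) for i=1:m1])
        hm_hard[Y,X] = sum([αh[i,X]*Φ(teval, μh[i,X],σh[i,X],c) for i=1:m2])
    end
end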
p1=heatmap(hm_easy, title="easy", color=:pu_or, xlim=(1,101), ylim=(1,101))
p2=heatmap(hm_hard, title="hard", color=:pu_or, xlim=(1,101), ylim=(1,101))
plot(p1,p2)