<h1> Mixture Density Network </h1>

This notebook is essentially a reproduction of the "Simple Inverse Problem" presented in 
<a href="http://publications.aston.ac.uk/373/1/NCRG_94_004.pdf">Mixture Density Networks (Bishop, 1994)</a>

It shows the weakness of a standard neural net when the target data has more than one mode for a given input, and motivates training the net to output a mixture model instead of a single conditional average of the target distribution.

<h2> Two data sets that are really the same data:</h2>

In [1]:
# Noisy target of the toy problem: the curve x + 0.3*sin(2πx) shifted by a
# fresh uniform draw from [-0.1, 0.1) on every call.
function datafunction(x)
    clean = x + 0.3*sin(2*π*x)
    return clean + rand()*0.2 - 0.1
end
# Inputs are stored as a 1×1001 matrix (one column per sample); the targets
# are obtained by applying datafunction to every input.
x = reshape(collect(0:0.001:1), 1, 1001)
t = datafunction.(x)

using Plots
gr(size=(800,400), leg=false)

# Same point cloud drawn twice: x -> t ("easy") and with the axes swapped ("hard").
p1 = scatter(x', t', title="easy")
p2 = scatter(t', x', title="hard")
plot(p1, p2)

<h2> Functions for implementing a Standard Neural Net and a Mixture Density Net</h2>

In [2]:
# Forward pass through the two-layer net, writing into preallocated buffers.
#   x    : input batch (one column per sample)
#   W, b : per-layer weights and biases
#   z, y : per-layer pre-activations and activations, overwritten in place
# The hidden layer uses tanh; the output layer is linear (y[2] mirrors z[2]).
# Returns y[2].
function fprop!(x, W,b,z,y)
    # hidden layer: affine map followed by tanh
    A_mul_B!(z[1], W[1], x)
    z[1] .= z[1] .+ b[1]
    y[1] .= tanh.(z[1])

    # output layer: affine map, identity activation
    A_mul_B!(z[2], W[2], y[1])
    z[2] .= z[2] .+ b[2]
    y[2] .= 1 .* z[2]
    return y[2]
end

# Start back-propagation for a squared-error loss: the gradient at the output
# pre-activations is the difference between the net output y[2] and the
# targets t. The result is written into ∇z[2] and returned.
function ∇MSE!(∇z, y,t)
    ∇z[2] .= y[2] .- t
    return ∇z[2]
end

# Backward pass. ∇z[2] must already hold the gradient at the output
# pre-activations. Fills, in place:
#   ∇W[2] = ∇z[2] * y[1]'                (output-layer weight gradient)
#   ∇y[1] = W[2]' * ∇z[2]                (gradient at the hidden activations)
#   ∇z[1] = ∇y[1] .* (1 - y[1]^2)        (through the tanh derivative)
#   ∇W[1] = ∇z[1] * x'                   (hidden-layer weight gradient)
# Bias gradients are not computed here.
function bprop!(∇W,∇z,∇y, W,x,y)
    # output layer
    A_mul_Bt!(∇W[2], ∇z[2], y[1])
    At_mul_B!(∇y[1], W[2], ∇z[2])

    # hidden layer: tanh'(z) expressed through the stored activation y[1]
    ∇z[1] .= ∇y[1] .* (1 .- y[1].^2)
    return A_mul_Bt!(∇W[1], ∇z[1], x)
end

# Gradient-descent step with step size α, applied in place to every layer.
# Weights move along ∇W; each bias moves along the row sums of ∇z, i.e. the
# pre-activation gradient summed over the columns of the batch.
function adjust!(W,b, ∇W,∇z, α)
    for layer in 1:length(b)
        W[layer] .= W[layer] .- α .* ∇W[layer]
        b[layer] .= b[layer] .- α*sum(∇z[layer], 2)
    end
end

# Allocate every buffer the training loop needs.
#   Lw  : layer widths, input first (layer k maps Lw[k] -> Lw[k+1] units)
#   bsz : number of columns in the per-batch buffers
# Weights are Gaussian, scaled by 1/sqrt(fan-in); biases and the z/y buffers
# start at zero. Each ∇-buffer is an independent deep copy of its counterpart.
# Returns W, b, z, y, ∇W, ∇b, ∇z, ∇y.
function preallocate(Lw, bsz)
    layers = 1:length(Lw)-1
    # one zero matrix per layer with `cols` columns
    blank(cols) = [zeros(Lw[k+1], cols) for k in layers]

    W = [randn(Lw[k+1], Lw[k]) ./ sqrt(Lw[k]) for k in layers]
    b = blank(1)
    z = blank(bsz)
    y = blank(bsz)

    ∇W, ∇b = deepcopy(W), deepcopy(b)
    ∇z, ∇y = deepcopy(z), deepcopy(y)
    return W,b,z,y, ∇W,∇b,∇z,∇y
end

# Column-wise softmax: every column of the result is non-negative and sums
# to one. The column maximum is subtracted before exponentiating so that
# exp never sees a positive argument.
function softmax(x)
    shifted = x .- maximum(x, 1)
    expd = exp.(shifted)
    return expd ./ sum(expd, 1)
end

# Isotropic Gaussian kernel of a mixture density network:
#
#   Φ(t) = exp(-‖t - μ‖² / (2σ²)) / ((2π)^(c/2) * σ^c)
#
#   t : target (scalar, or vector of length c)
#   μ : kernel centre, same shape as t
#   σ : common standard deviation of all c components
#   c : dimensionality of the target
#
# The normalisation uses σ^c: a c-dimensional isotropic Gaussian has
# covariance determinant σ^(2c), so dividing by plain σ only integrates to
# one when c == 1. For c == 1 the value is the same as before.
function Φ(t,μ,σ,c)
    # squared Euclidean distance; sum(abs2, ·) accepts scalars and vectors
    # alike and avoids taking a square root only to square it again
    dist2 = sum(abs2, t .- μ)
    return 1/(σ^c*(2*π)^(c/2)) * exp(-dist2/(2*σ^2))
end

# Gradient of the mixture-density loss with respect to the raw network
# outputs z (one column per sample).
#   rows 1:m      -> kernel centres μ
#   rows m+1:2m   -> log standard deviations (σ = exp of these rows)
#   rows 2m+1:3m  -> mixing-coefficient logits (α = column-wise softmax)
# Only the first row of t is read. Returns the three gradient blocks stacked
# in the same row order as z.
function ∇MD(z,t, m,c)
    μ = z[1:m,:]
    σ = exp.(z[m+1:2*m,:])
    α = softmax(z[2*m+1:3*m,:])
    targets = t[1,:]

    # posterior probability of each kernel given the sample
    lik = hcat([Φ.(targets, μ[i,:], σ[i,:], c) for i=1:m]...)'
    joint = α .* lik
    post = joint ./ sum(joint,1)

    # per-kernel terms, weighted by the posteriors below
    dμ = hcat([(μ[i,:] .- targets) ./ σ[i,:].^2 for i=1:m]...)'
    dσ = hcat([(targets .- μ[i,:]).^2 ./ σ[i,:].^2 .- c for i=1:m]...)'

    return vcat(post .* dμ, -post .* dσ, α .- post)
end

∇MD (generic function with 1 method)

<h2> Standard Neural Net</h2>

In [3]:
# Train a standard feed-forward net with mini-batch gradient descent on the
# squared error.
#   x, t   : inputs and targets, one sample per column; samples are picked by
#            linear index, so both are expected to be single-row (1×N)
#   Lw     : layer widths handed to preallocate
#   bsz    : mini-batch size
#   Mepoch : number of passes over the (reshuffled) data
#   α      : step size handed to adjust!
# Returns the trained weights W and biases b.
function trainFFN(x, t, Lw, bsz, Mepoch, α)
    W,b,z,y, ∇W,∇b,∇z,∇y = preallocate(Lw,bsz)
    nsamples = size(x,2)
    for epoch=1:Mepoch
        idx=randperm(nsamples)
        # The buffers hold exactly bsz columns, so only full batches are used.
        # nsamples-bsz+1 is the last start index at which a full batch still
        # fits; stopping at nsamples-bsz would drop the final full batch
        # whenever nsamples is a multiple of bsz.
        for n=1:bsz:nsamples-bsz+1
            bidx=idx[n:n+bsz-1]
            xb=x[bidx]'
            fprop!(xb, W,b,z,y)
            ∇MSE!(∇z, y,t[bidx]')
            bprop!(∇W,∇z,∇y, W,xb,y)
            adjust!(W,b, ∇W,∇z, α)
        end
    end
    return W, b
end

trainFFN (generic function with 1 method)

<h2> Mixture Density Net</h2>

In [4]:
# Train a mixture density net with mini-batch gradient descent.
#   x, t   : inputs and targets, one sample per column; samples are picked by
#            linear index, so both are expected to be single-row (1×N)
#   Lw     : layer widths handed to preallocate (∇MD reads 3*m output rows)
#   m, c   : number of kernels and target dimensionality, handed to ∇MD
#   bsz    : mini-batch size
#   Mepoch : number of passes over the (reshuffled) data
#   α      : step size handed to adjust!
# Returns the trained weights W and biases b.
function trainMDN(x, t, Lw, m,c, bsz, Mepoch, α)
    W,b,z,y, ∇W,∇b,∇z,∇y = preallocate(Lw,bsz)
    nsamples = size(x,2)
    for epoch=1:Mepoch
        idx=randperm(nsamples)
        # The buffers hold exactly bsz columns, so only full batches are used.
        # nsamples-bsz+1 is the last start index at which a full batch still
        # fits; stopping at nsamples-bsz would drop the final full batch
        # whenever nsamples is a multiple of bsz.
        for n=1:bsz:nsamples-bsz+1
            bidx=idx[n:n+bsz-1]
            xb=x[bidx]'
            fprop!(xb, W,b,z,y)
            # the mixture-density gradient replaces the squared-error one
            ∇z[2] = ∇MD(z[2], t[bidx]', m,c)
            bprop!(∇W,∇z,∇y, W,xb,y)
            adjust!(W,b, ∇W,∇z, α)
        end
    end
    return W, b
end

trainMDN (generic function with 1 method)

<h2> Train standard net</h2>

In [5]:
c = 1                          # dimensionality of output
bsz = 32                       # mini-batch size
x_eval = collect(0:0.01:1)'    # inputs at which the trained nets are evaluated

# FFN on the easy mapping x -> t, then its predictions on the evaluation grid
W_e, b_e = trainFFN(x, t, [1,20,c], bsz, 10000, 1e-2)
predictions_easy = W_e[2]*tanh.(W_e[1]*x_eval .+ b_e[1]) .+ b_e[2]

# FFN on the hard mapping t -> x (ten times as many epochs)
W_h, b_h = trainFFN(t, x, [1,20,c], bsz, 100000, 1e-2)
predictions_hard = W_h[2]*tanh.(W_h[1]*x_eval .+ b_h[1]) .+ b_h[2];

<h2> This is what it learns</h2>
It models the x->t mapping easily, but struggles with the t->x mapping.

The red dots are what the net predicts, given a certain input

In [6]:
# Training data with the FFN predictions overlaid, one panel per direction.
p1 = scatter(x', t', title="easy")
scatter!(p1, x_eval', predictions_easy', xlim=(0,1), ylim=(0,1))

p2 = scatter(t', x', title="hard")
scatter!(p2, x_eval', predictions_hard', xlim=(0,1), ylim=(0,1))

plot(p1, p2)

In [7]:
# Turn raw mixture-density-net outputs into mixture parameters.
#   preds : raw outputs, one column per evaluated input; rows are laid out as
#           m centres, m log standard deviations, m mixing logits
#   m     : number of kernels
# Returns the centres μ, the standard deviations σ, the mixing coefficients α
# and mp, the position (as reported by findmax along dimension 1) of the
# largest mixing coefficient in each column.
function preds2params(preds, m)
    centres = preds[1:m, :]
    widths = exp.(preds[m+1:2*m, :])
    weights = softmax(preds[2*m+1:3*m, :])
    _, mp = findmax(weights, 1)
    return centres, widths, weights, mp
end

preds2params (generic function with 1 method)

<h2> Train Mixture Density Net</h2>

In [8]:
# MDN on the easy mapping x -> t: a single kernel is enough
m = 1    # nr of kernels
W_e, b_e = trainMDN(x, t, [1,20,(c+2)*m], m, c, bsz, 10000, 1e-5)
predictions_easy = W_e[2]*tanh.(W_e[1]*x_eval .+ b_e[1]) .+ b_e[2]
μ_e, σ_e, α_e, mp_e = preds2params(predictions_easy, m)

# MDN on the hard mapping t -> x: three kernels
m = 3    # nr of kernels
W_h, b_h = trainMDN(t, x, [1,20,(c+2)*m], m, c, bsz, 100000, 1e-5)
predictions_hard = W_h[2]*tanh.(W_h[1]*x_eval .+ b_h[1]) .+ b_h[2]
μ_h, σ_h, α_h, mp_h = preds2params(predictions_hard, m);

<h2> This is what it learns</h2>
It also models the x->t mapping easily, and it does a better job with the t->x mapping.

The red dots are what the net predicts, given a certain input

In [9]:
# Training data with, for each evaluated input, the centre of the kernel that
# has the largest mixing coefficient.
p1 = scatter(x', t', title="easy")
scatter!(p1, x_eval', μ_e[mp_e]', xlim=(0,1), ylim=(0,1))

p2 = scatter(t', x', title="hard")
scatter!(p2, x_eval', μ_h[mp_h]', xlim=(0,1), ylim=(0,1))

plot(p1, p2)

<h2> Not only that</h2>
We get means, variances and amplitudes of the mixture components. A heatmap visualizes more clearly what the mixture density net has learned. The plot above only shows the mean of the most probable kernel, but they all contribute to the posterior probabilities.

With an input value of 0.5 the trained Mixture Density Net will say that the output is
<ul>
<li>about 0.2 with 25% probability,</li>
<li>about 0.5 with 50% probability,</li>
<li>about 0.8 with 25% probability.</li>
</ul>

This is something a standard neural net cannot do.

See figure below, the strength of color corresponds to probability that the output (given an input) should lie in certain regions.

In [10]:
gr(size=(800,400))

# Both axes use the grid 0, 0.01, ..., 1: x_eval was built from it, and the
# target axis is evaluated on it as well, so the heatmaps cover the same
# [0,1]×[0,1] square as the scatter plots.
eval_grid = 0:0.01:1
hm_easy = zeros(length(eval_grid), length(eval_grid))
hm_hard = zeros(length(eval_grid), length(eval_grid))
for X=1:length(eval_grid)          # column of x_eval (network input)
    for Y=1:length(eval_grid)      # candidate output value
        # Read the value from the grid rather than accumulating 0.01 steps:
        # the first row is exactly 0.0 (it used to start at 0.01 and end at
        # 1.01) and no rounding error builds up along the column.
        t_eval = eval_grid[Y]
        # mixture density = sum over all kernels of weight times Gaussian
        hm_easy[Y,X] = sum([α_e[i,X]*Φ(t_eval, μ_e[i,X],σ_e[i,X],1) for i=1:size(α_e,1)])
        hm_hard[Y,X] = sum([α_h[i,X]*Φ(t_eval, μ_h[i,X],σ_h[i,X],1) for i=1:size(α_h,1)])
    end
end
# pass the grid so the axes show data coordinates instead of matrix indices
p1=heatmap(eval_grid, eval_grid, hm_easy, title="easy", color=:pu_or)
p2=heatmap(eval_grid, eval_grid, hm_hard, title="hard", color=:pu_or)
plot(p1,p2)