# Extending Variational Boosting to Multiple Competing Variational Families

This jupyter notebook contains all of the code used to generate the plots in the report.

In [657]:
using Plots, LinearAlgebra, ReverseDiff, ForwardDiff, SpecialFunctions, Distributions, Statistics, Random

## ADVI

We begin by implementing the ADVI algorithm that can be used to obtain the first component in the mixture approximation.

In [939]:
function ADVI(logπ, grad, γ, Ngrad, Nobj, ψ0, d, num_iter, transform, diag_sig)
    μs = zeros(d, num_iter)
    Σs = zeros(d, d, num_iter)
    obj = zeros(num_iter)

    for k=1:num_iter
        # transform ψ to mean and covariance
        μ, L = transform(ψ0, diag_sig, d)
        # store parameter
        μs[:,k] = μ
        Σs[:,:,k] = L * transpose(L)
        # sample from κ
        y = rand(MvNormal(d, 1), Ngrad)
        # compute the gradient
        ∇ψ = grad(y, ψ0, logπ, Ngrad, diag_sig, d)
        # take a step
        ψ0 = ψ0 .- γ(k) .* ∇ψ
        # estimate the objective function
        μ, L = transform(ψ0, diag_sig, d)
        if d > 1
            xobj = rand(MvNormal(μ, L * transpose(L)), Nobj)
        else
            xobj = rand(Normal(μ[1], L[1,1]^2), Nobj)
            println(μ[1])
            println(L[1,1]^2)
        end
        for i=1:size(xobj)[1]
            obj[k] += logpdf(MvNormal(μ, L * transpose(L)), xobj[:,i]) - logπ(xobj[:,i])
        end
        obj[k] /= Nobj
    end

    return μs, Σs, obj
end

ADVI (generic function with 1 method)

In [940]:
function transform(ψ, diag_sig, d)
    μ = ψ[1:d]
    if diag_sig
        L = Diagonal(exp.(ψ[d+1:d+d]./2))
    else
        L = reshape(ψ[d+1:end], (d, d))
    end
    return μ, L
end

function logπψ(y, ψ0, logπ, diag_sig, d)
    μ, L = transform(ψ0, diag_sig, d)
    x = vec(μ .+ L * y)
    return -(logπ(x) .+ log(abs(det(L))))
end

function grad(ys, ψ0, logπ, Ngrad, diag_sig, d)
    ret = zeros(size(ψ0)[1])
    for i in 1:Ngrad
        y = vec(ys[:,i])    
        ret = ret .+ ReverseDiff.gradient(ψψ -> logπψ(y, ψψ, logπ, diag_sig, d), ψ0)
    end
    return (ret./Ngrad)
end

grad (generic function with 1 method)

## Variational Boosting

We now implement the variational boosting algorithm. First we implement the weighted EM method for initializing new components.

In [941]:
# weighted EM helper
function sample_from_current_mixture(μs, Σs, λs, Ngrad, d)
    xs = zeros(d, Ngrad)
    num_comp = size(μs)[2]
    for i in 1:Ngrad
        # identify component
        u = rand(Uniform(0,1))
        comp = 0
        for j in 1:num_comp
            if u <= sum(λs[1:j])
                comp = j
                break
            end
        end
        # sample from component
        xs[:,i] = rand(MvNormal(vec(μs[:,comp]), Σs[:,:,comp]))
    end
    return xs
end

sample_from_current_mixture (generic function with 1 method)

In [942]:
# weighted EM helper
function log_mixture(μs, Σs, λs, x)
    mix = 0.
    num_comp = size(μs)[2]
    for i in 1:num_comp
        mix = mix + λs[i] * pdf(MvNormal(vec(μs[:,i]), Σs[:,:,i]), vec(x))
    end
    return log(mix)
end

log_mixture (generic function with 2 methods)

In [943]:
# weighted EM helper
function compute_EM_weights(μs, Σs, λs, logπ, xs)
    N = size(xs)[2]
    logws = zeros(N)
    for i in 1:N
        x = xs[:,i]
        logws[i] = logπ(x) - log_mixture(μs, Σs, λs, x)
    end
    return exp.(logws)
end

compute_EM_weights (generic function with 1 method)

In [944]:
# weighted EM helper
function weighted_EM_initialization(μs, Σs, λs, weights, xs, diag_sig)
    num_comp = size(μs)[2] + 1
    N = size(xs)[2]
    d = size(μs)[1]
    λ = 1. / num_comp
    μ = vec(zeros(d))
    Σ = Diagonal(ones(d))
    # EM loop
    keep = true
    while(keep)
        # E step
        resp = zeros(N, num_comp)
        λs_curr = (1 - λ) .* λs
        for i in 1:N
            x = vec(xs[:,i])
            for j in 1:num_comp
                if j < num_comp
                    resp[i,j] = pdf(MvNormal(vec(μs[:,j]), Σs[:,:,j]), x) * λs_curr[j]
                else
                    resp[i,j] = pdf(MvNormal(vec(μ),Σ), x) * λ
                end
            end
        end
        # normalize by dividing row sum
        rowsum = sum(resp, dims=2)
        resp = resp ./ rowsum
        # M step
        Nk = sum(resp, dims=1)[num_comp]
        newλ = Nk / N
        newμ = vec(zeros(d))
        newΣ = zeros(d,d)
        # update μ
        for i in 1:N
            newμ = newμ .+ (1. / Nk) .* resp[i, num_comp] .* weights[i] .* vec(xs[:,i])
        end
        # update Σ
        for i in 1:N
            diff = vec(xs[:,i]) .- newμ
            newΣ = newΣ .+ (1. / Nk) .* resp[i, num_comp] .* weights[i] .* (diff) * transpose(diff)
        end
        # ensure Σ is Hermitian
        for j in 1:d
            for k in 1:j
                newΣ[j,k] = newΣ[k,j]
            end
        end
        # check convergence
        if (newλ - λ)^2 <= 1e-3
            println("finished initializing ψ and λ")
            keep = false
        end
        # update params
        λ = newλ
        μ = newμ
        Σ = newΣ
    end
    # modify Σ to follow specification of diag_sig
    if diag_sig
        Σ = Diagonal(Σ)
    end

    return λ, μ, Σ
end

weighted_EM_initialization (generic function with 3 methods)

The following functions compute the gradient for component optimization.

In [947]:
# component optimization gradient helper
function log_mix(x, μs, Σs, λs, μ_opt, L_opt, λ)
    ret = 0.
    for i in 1:size(μs)[2]
        ret = ret + (1 - λ) * λs[i] * pdf(MvNormal(vec(μs[:,i]), Σs[:,:,i]), vec(x))
    end
    ret = ret + λ * pdf(MvNormal(vec(μ_opt), L_opt * transpose(L_opt)), vec(x))
    return log(ret)
end

log_mix (generic function with 1 method)

In [948]:
# component optimization gradient helper
function obj_comp(y, ψ, λ, logπ, diag_sig, d, μs, Σs, λs)
    μ_opt, L_opt = transform(ψ, diag_sig, d)
    num_comp = size(μs)[2] + 1
    ret = 0.
    for i in 1:num_comp
        if i < num_comp
            if diag_sig
                L = Diagonal(sqrt.(Σs[:,:,i]))
            else
                C = cholesky(Σs[:,:,i])
                L = convert(Array{Float64,2}, C.L)
            end
            x = vec(μs[:,i] .+ L * y)
            # ln (mixture approx) - ln (π)
            ret = ret - (1-λ) * λs[i] * logπ(x)
            ret = ret + (1-λ) * λs[i] * log_mix(x, μs, Σs, λs, μ_opt, L_opt, λ)
        else
            x = vec(μ_opt .+ L_opt * y)
            # ln (mixture approx) - ln (π)
            ret = ret - λ * logπ(x)
            ret = ret + λ * log_mix(x, μs, Σs, λs, μ_opt, L_opt, λ)
        end
    end
    
    return ret
end

obj_comp (generic function with 1 method)

In [949]:
# component optimization gradient helper
function grad_comp(ys, ψ, λ, logπ, Ngrad, diag_sig, d, μs, Σs, λs)
    retψ = zeros(size(ψ)[1])
    retλ = 0.
    for i in 1:Ngrad
        y = vec(ys[:,i])
        retψ = retψ .+ ForwardDiff.gradient(ψψ -> obj_comp(y, ψψ, λ, logπ, diag_sig, d, μs, Σs, λs), ψ)
        retλ = retλ + ForwardDiff.derivative(λλ -> obj_comp(y, ψ, λλ, logπ, diag_sig, d, μs, Σs, λs), λ)
    end
    return (retψ./Ngrad), (retλ/Ngrad)
end

grad_comp (generic function with 1 method)

The following function performs the component optimization.

In [967]:
function optimize_component(μs, Σs, λs, λ_k, μ_k, Σ_k, logπ, γ, Ngrad, d, num_iter, transform, diag_sig)
    # turn μ_k and Σ_k to the format of ψ
    ψ0 = copy(μ_k)
    if diag_sig
        append!(ψ0, log.([Σ_k[i,i] for i in 1:d]))
    else
        C = cholesky(Σ_k)
        append!(ψ0, vec(C.L))
    end
    
    ψ = vec(copy(ψ0))
    println(ψ)
    println(typeof(ψ))

    # start optimization
    λks = zeros(2*num_iter)
    μks = zeros(d, num_iter)
    Σks = zeros(d, d, num_iter)
    λ = λ_k
    
    for k=1:(2*num_iter)
        if k <= num_iter
            # transform ψ to mean and covariance
            μ, L = transform(ψ, diag_sig, d)
            # store parameter
            μks[:,k] = μ
            Σks[:,:,k] = L * transpose(L)
        end

        λks[k] = λ
        
        # sample from κ
        y = rand(MvNormal(d, 1), Ngrad)
        
        # compute the gradient
        ∇ψ, ∇λ = grad_comp(y, ψ, λ, logπ, Ngrad, diag_sig, d, μs, Σs, λs)
        # take a step
        if k <= num_iter
            ψ = ψ .- γ(k) .* ∇ψ
        end
        λ = λ - (γ(k)/200) * ∇λ
        # project λ to feasible region
        if λ < 0.
            λ = 0.
        elseif λ > 1.
            λ = 1.
        end
    end

    println("-----")
    y = rand(MvNormal(d, 1), Ngrad)
    ∇ψ, ∇λ = grad_comp(y, ψ, λ, logπ, Ngrad, diag_sig, d, μs, Σs, λs)
    println("ψ gradient: ", ∇ψ)
    println("λ gradient: ", ∇λ)
    println("-----")
    
    return λks[2*num_iter], μks[:,num_iter], Σks[:,:,num_iter]
end

optimize_component (generic function with 1 method)

The following function performs variational boosting, where the first component is optimized using ADVI.

In [985]:
function VB(logπ, grad, γ, Ncomp, Ngrad, Nobj, ψ0, d, num_iter, transform, diag_sig)
    μs = zeros(d, Ncomp)
    Σs = zeros(d, d, Ncomp)
    λs = [1.]
    objs = zeros(Ncomp)

    # fitting first component
    println("optimizing component 1 / ", Ncomp)
    mus, sigmas, obj = ADVI(logπ, grad, γ, Ngrad, Nobj, ψ0, d, num_iter, transform, diag_sig)
    μs[:,1] = mus[:,num_iter]
    Σs[:,:,1] = sigmas[:,:,num_iter]
    println("component 1")
    println("mean: ", μs[:,1])
    println("variance: ", Σs[:,:,1])
    println("-----")

    # fitting remaining components
    for k in 2:Ncomp
        println("optimizing component ", k, " / ", Ncomp)
        # init new component
        xs = sample_from_current_mixture(μs[:,1:k-1], Σs[:,:,1:k-1], λs[1:k-1], 40, d)
        weights = compute_EM_weights(μs[:,1:k-1], Σs[:,:,1:k-1], λs[1:k-1], logπ, xs)
        λ_k, μ_k, Σ_k = weighted_EM_initialization(μs[:,1:k-1], Σs[:,:,1:k-1], λs[1:k-1], weights, xs, diag_sig)
        println("initialization of weight, mean, and variance")
        println(λ_k)
        println(μ_k)
        println(Σ_k)
        println("-----")

        # optimize new component
        λ_k, μ_k, Σ_k = optimize_component(μs[:,1:k-1], Σs[:,:,1:k-1], λs[1:k-1], λ_k, μ_k, Σ_k,
                                           logπ, γ, Ngrad, d, num_iter, transform, diag_sig)
        # update parameters
        λs = (1. - λ_k) .* λs
        push!(λs, λ_k)
        μs[:,k] .= μ_k
        Σs[:,:,k] .= Σ_k

        println("component ", k)
        println("mean: ", μs[:,k])
        println("variance: ", Σs[:,:,k])
        println("weight: ", λ_k)
        println("-----")
    end

    return μs, Σs, λs
end

VB (generic function with 1 method)

We first demonstrate a case where the weighted EM algorithm fails.

In [None]:
μ = vec([2. 2.])
Σ = [1 0.; 0. 1]
logπ = x -> log(0.5 * pdf(MvNormal(μ, Σ), x) + 0.5 * pdf(MvNormal(-μ, Σ), x))
γ = k -> 0.1 /sqrt(k)
Ngrad = 200
Nobj = 200
ψ0 = vec([1.5 1.5 1. 1.])
diag_sig = true
d = 2
num_iter = 1000
Ncomp = 2

Random.seed!(1)
μs, Σs, λs = VB(logπ, grad, γ, Ncomp, Ngrad, Nobj, ψ0, d, num_iter, transform, diag_sig);

In [None]:
for i in 1:Ncomp
    println("mu: ", μs[:,i])
    println("sigma: ", Σs[:,:,i])
    println("lambda: ", λs[i])
end

In [None]:
function πtrue(x,y)
    return 0.5 * pdf(MvNormal(μ, Σ), vec([x y])) + 0.5 * pdf(MvNormal(-μ, Σ), vec([x y]))
end

function πapprox(x,y)
    ret = 0.
    for i in 1:Ncomp
        ret = ret + λs[i] * pdf(MvNormal(μs[:,i], Σs[:,:,i]), vec([x y]))
    end
    return ret
end

x1s = -5:0.1:5
x2s = -5:0.1:5

contour(x1s, x2s, (x, y) -> πtrue(x,y), xlabel="x1", ylabel="x2", lw=1, levels=30)
png("true_post_far")

In [931]:
contour(x1s, x2s, (x, y) -> πapprox(x,y), lw=1, levels=30, xlabel="x1", ylabel="x2")
png("approx_post_far")

We hence discard the weighted EM initialization and instead intialize the next component with a component centred at the origin with a large variance.

In [968]:
function VB(logπ, grad, γ, Ncomp, Ngrad, Nobj, ψ0, d, num_iter, transform, diag_sig)
    μs = zeros(d, Ncomp)
    Σs = zeros(d, d, Ncomp)
    λs = [1.]
    objs = zeros(Ncomp)

    # fitting first component
    println("optimizing component 1 / ", Ncomp)
    mus, sigmas, obj = ADVI(logπ, grad, γ, Ngrad, Nobj, ψ0, d, num_iter, transform, diag_sig)
    μs[:,1] = mus[:,num_iter]
    Σs[:,:,1] = sigmas[:,:,num_iter]
    println("component 1")
    println("mean: ", μs[:,1])
    println("variance: ", Σs[:,:,1])
    println("-----")
    
    # fitting remaining components
    for k in 2:Ncomp
        println("optimizing component ", k, " / ", Ncomp)
        # init new component
        λ_k = 1. / Ncomp
        Σ_k = Diagonal(ones(d))
        idx = Int64(floor(rand()*k))+1
        μ_k = -vec(ones(d))

        # optimize new component
        λ_k, μ_k, Σ_k = optimize_component(μs[:,1:k-1], Σs[:,:,1:k-1], λs[1:k-1], λ_k, μ_k, Σ_k,
                                           logπ, γ, Ngrad, d, num_iter, transform, diag_sig)
        # update parameters
        λs = (1. - λ_k) .* λs
        push!(λs, λ_k)
        μs[:,k] .= μ_k
        Σs[:,:,k] .= Σ_k

        println("component ", k)
        println("mean: ", μs[:,k])
        println("variance: ", Σs[:,:,k])
        println("weight: ", λ_k)
        println("-----")
    end

    return μs, Σs, λs
end

VB (generic function with 1 method)

In [979]:
μ = vec([3. 3.])
Σ = [1 0.1; 0.1 1]
logπ = x -> log(0.5 * pdf(MvNormal(μ, Σ), x) + 0.5 * pdf(MvNormal(-μ, Σ), x))
γ = k -> 0.1 /sqrt(k)
Ngrad = 200
Nobj = 200
ψ0 = vec([1.5 1.5 1. 1.])
diag_sig = true
d = 2
num_iter = 1000
Ncomp = 2

Random.seed!(1)
μs, Σs, λs = VB(logπ, grad, γ, Ncomp, Ngrad, Nobj, ψ0, d, num_iter, transform, diag_sig);

optimizing component 1 / 2
component 1
mean: [2.9941526171455397, 2.9928674816328695]
variance: [1.0257877522063494 0.0; 0.0 1.0299760983657085]
-----
optimizing component 2 / 2
[-1.0, -1.0, 0.0, 0.0]
Array{Float64,1}
-----
ψ gradient: [0.06557776016618942, 0.05759889421271219, 0.06281868025584292, 0.002957233010158445]
λ gradient: -0.059942449696490326
-----
component 2
mean: [-2.85627062883778, -2.8546278764693818]
variance: [1.016294174541848 0.0; 0.0 1.0195371494456904]
weight: 0.4808757132944286
-----


In [980]:
for i in 1:Ncomp
    println("mu: ", μs[:,i])
    println("sigma: ", Σs[:,:,i])
    println("lambda: ", λs[i])
end

mu: [2.9941526171455397, 2.9928674816328695]
sigma: [1.0257877522063494 0.0; 0.0 1.0299760983657085]
lambda: 0.5191242867055714
mu: [-2.85627062883778, -2.8546278764693818]
sigma: [1.016294174541848 0.0; 0.0 1.0195371494456904]
lambda: 0.4808757132944286


In [983]:
function πtrue(x,y)
    return 0.5 * pdf(MvNormal(μ, Σ), vec([x y])) + 0.5 * pdf(MvNormal(-μ, Σ), vec([x y]))
end

function πapprox(x,y)
    ret = 0.
    for i in 1:Ncomp
        ret = ret + λs[i] * pdf(MvNormal(μs[:,i], Σs[:,:,i]), vec([x y]))
    end
    return ret
end

x1s = -6:0.1:6
x2s = -6:0.1:6

contour(x1s, x2s, (x, y) -> πtrue(x,y), xlabel="x1", ylabel="x2", lw=1, levels=30)
png("true_post_close")

In [984]:
contour(x1s, x2s, (x, y) -> πapprox(x,y), lw=1, levels=30, xlabel="x1", ylabel="x2")
png("approx_post_close")

## Extending to Multiple Variational Families