# 2次元ガウス分布に対する変分推論

参考：https://machine-learning.hatenablog.com/entry/2016/01/31/172500

In [23]:
using Plots
using LinearAlgebra

In [51]:
# Initialize the variational parameters of q(z).
"""
    init_params(D)

Create the starting state for the variational approximation: a mean vector of
`D` standard-normal draws and a `D×D` precision matrix filled with zeros.
Return them as a `(mean, precision)` tuple.
"""
init_params(D) = (randn(D), zeros(D, D))

init_params (generic function with 1 method)

In [52]:
# Mean-field (coordinate-ascent) update equations.
"""
    update_params(μ, Λ, μ_history, Λ_history)

Perform one sweep of coordinate-ascent updates of the factorized Gaussian
`q(z) = ∏ᵢ N(zᵢ | mᵢ, Λᵢᵢ⁻¹)` that approximates `p(z) = N(z | μ, Λ⁻¹)`.

For each component `i`, in order, set

    mᵢ = μᵢ - Λᵢᵢ⁻¹ Σ_{j≠i} Λᵢⱼ (mⱼ - μⱼ)

using the already-updated values of earlier components (Gauss–Seidel order).
Also set the precision of that factor, `Λ_history[i, i] = Λ[i, i]`.

`μ_history` (the current means `m`) and `Λ_history` (the diagonal precision of
`q`) are modified in place. Off-diagonal entries of `Λ_history` are never
written. Return the same two (mutated) arrays.
"""
function update_params(μ, Λ, μ_history, Λ_history)
    for i in eachindex(μ)
        # Coupling of zᵢ to the other components through the off-diagonal precision.
        coupling = zero(eltype(μ_history))
        for j in eachindex(μ)
            j == i && continue
            coupling += Λ[i, j] * (μ_history[j] - μ[j])
        end
        μ_history[i] = μ[i] - coupling / Λ[i, i]
        Λ_history[i, i] = Λ[i, i]
    end
    return μ_history, Λ_history
end

update_params (generic function with 1 method)

In [79]:
# KL divergence between the approximation q and the target p.
"""
    calc_KL(μ, Λ, μ̂, Λ̂)

Return `KL(q ‖ p)` for the Gaussians `q = N(μ̂, Λ̂⁻¹)` (approximation) and
`p = N(μ, Λ⁻¹)` (target), both parameterized by precision matrices:

    KL = ½ [ tr(Λ Λ̂⁻¹) + (μ - μ̂)ᵀ Λ (μ - μ̂) - D + log|Λ̂| - log|Λ| ]

`Λ` and `Λ̂` must be positive definite; otherwise `logdet` throws.
"""
function calc_KL(μ, Λ, μ̂, Λ̂)
    D = length(μ)
    δ = μ - μ̂
    # tr(Λ Σ̂) with Σ̂ = Λ̂⁻¹, computed by a solve rather than an explicit inverse.
    trace_term = tr(Λ̂ \ Λ)
    quad_term = dot(δ, Λ * δ)
    # log(|Σ_p| / |Σ_q|) = log|Λ̂| - log|Λ| (the original had this sign reversed).
    return 0.5 * (trace_term + quad_term - D + logdet(Λ̂) - logdet(Λ))
end

calc_KL (generic function with 1 method)

In [84]:
# Variational inference driver.
"""
    learn_VI(D, μ, Λ, maxiter)

Fit a factorized Gaussian `q(z)` to `p(z) = N(z | μ, Λ⁻¹)` by running `maxiter`
sweeps of `update_params`. The starting point comes from `init_params(D)`.

Return `(result, KL)`:
- `result[k]` is `(mean, covariance)` of `q` after sweep `k`. The mean is a
  snapshot copy, and the covariance is `inv` of the diagonal precision.
- `KL[k]` is `calc_KL(μ, Λ, mean, precision)` after sweep `k`.
"""
function learn_VI(D, μ, Λ, maxiter)
    μ_history, Λ_history = init_params(D)
    result = Tuple{Vector{Float64},Matrix{Float64}}[]
    KL = Float64[]
    for _ in 1:maxiter
        μ_history, Λ_history = update_params(μ, Λ, μ_history, Λ_history)
        # `copy` snapshots the in-place-updated mean; `inv` already allocates.
        push!(result, (copy(μ_history), inv(Λ_history)))
        push!(KL, calc_KL(μ, Λ, μ_history, Λ_history))
    end
    return result, KL
end

learn_VI (generic function with 2 methods)

In [86]:
# Dimensionality of the latent variable z.
D = 2

# Parameters of the target Gaussian p(z) = N(z | μ, Λ⁻¹) that q(z) approximates.
# These stay fixed during inference; they are not initial values.
μ = [0.0, 0.0]

# Rotation matrix for θ = 2π/12 (30°).
θ = 2π / 12
A = [
    cos(θ) -sin(θ)
    sin(θ) cos(θ)
]

# Precision of p(z): rotate the axis-aligned covariance inv(diag(1, 10)) by A
# and invert, giving correlated components z₁ and z₂.
Λ = inv(A * inv([1.0 0.0; 0.0 10.0]) * A')

# Run variational inference for a fixed number of coordinate-ascent sweeps,
# keeping the per-sweep (mean, covariance) snapshots and KL values.
maxiter = 10
(result, KL) = learn_VI(D, μ, Λ, maxiter);