# Metric Gaussian Variational Inference (MGVI), Julia and AutoDiff

## Background

### MGVI

* Bayesian inference scheme by Jakob Knollmüller
* Variational approximation of the true posterior with a Gaussian in standardized coordinates
* Minimizes the Kullback-Leibler divergence between the approximation and the true posterior
* Approximates covariance structures via samples whose covariance is the inverse Fisher metric

### Julia

* High-performance programming language
* Just-in-time (JIT) compiled
* Multiple dispatch
* AutoDiff packages written in Julia itself (e.g. ForwardDiff, Zygote)
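Multiple dispatch is what the notebook below relies on when it overloads `*` for FFTW plans acting on dual numbers and when it passes `LinearMap`s where matrices are expected. A minimal sketch of the idea, using a hypothetical `DiagonalOp` type that is not part of any library: a generic function written only in terms of `*` and `adjoint` accepts either a dense matrix or the custom matrix-free operator, and Julia picks the method from the argument types.

```julia
using LinearAlgebra

# Hypothetical matrix-free operator: applies a diagonal scaling without
# storing a dense matrix.
struct DiagonalOp{T<:Number}
    diag::Vector{T}
end

# Multiple dispatch: teach `*` how to apply the operator to a vector ...
Base.:*(A::DiagonalOp, x::AbstractVector) = A.diag .* x
# ... and how to form its adjoint (conjugate the diagonal).
Base.adjoint(A::DiagonalOp) = DiagonalOp(conj.(A.diag))

# Generic code that only uses `*` and `'` (adjoint); it never asks
# whether `A` is a dense matrix or something else.
normal_apply(A, x) = A' * (A * x)

x = [1.0, 2.0, 3.0]
A_dense = Matrix(Diagonal([2.0, 3.0, 4.0]))
A_op = DiagonalOp([2.0, 3.0, 4.0])

normal_apply(A_dense, x)  # [4.0, 18.0, 48.0]
normal_apply(A_op, x)     # [4.0, 18.0, 48.0], same result without a dense matrix
```

The same mechanism lets `LinearMap` objects stand in for the Jacobian and the Fisher metric in the conjugate-gradient solves further down.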

### AutoDiff

* "Differentiation is mechanics, integration is art."
* Forward mode (Jacobian-vector product, pushforward)
    * Memory is independent of the depth of the computational graph
    * Very efficient for "tall" Jacobians (few inputs, many outputs)
* Reverse mode (vector-Jacobian product, pullback)
    * Memory scales with the depth of the computational graph
    * Very efficient for "wide" Jacobians (many inputs, few outputs, e.g. gradients)
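To illustrate forward mode, here is a minimal sketch with a hand-rolled dual number. The `Dual` type, `jvp` and `jacobian_fwd` below are hypothetical illustrations, unrelated to `ForwardDiff.Dual`, and only cover the few operations the toy function needs. Each forward pass carries one tangent direction through the computation and yields one Jacobian-vector product, so a full Jacobian costs one pass per input; this is why forward mode suits tall Jacobians. Reverse mode is not shown.

```julia
# Hypothetical dual number: a primal value plus a directional derivative (tangent).
struct Dual{T<:Real}
    val::T
    der::T
end

# Propagation rules: update the value and apply the chain rule to the tangent.
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)
Base.exp(a::Dual) = Dual(exp(a.val), exp(a.val) * a.der)

# Toy function with 2 inputs and 3 outputs, i.e. a "tall" 3×2 Jacobian.
f(x) = [x[1] * x[2], sin(x[1]), exp(x[2])]

# Jacobian-vector product: seed the inputs with tangent v, evaluate f once,
# and read off the output tangents.
jvp(f, x, v) = [y.der for y in f(Dual.(x, v))]

# One forward pass per input direction (unit vector) gives one Jacobian column.
function jacobian_fwd(f, x)
    cols = [jvp(f, x, [i == j ? 1.0 : 0.0 for i in eachindex(x)]) for j in eachindex(x)]
    return reduce(hcat, cols)
end

J = jacobian_fwd(f, [1.0, 2.0])  # 3×2 matrix after 2 forward passes
```

The notebook's `jacobian` helper follows the same pattern with `ForwardDiff.Dual`, except that it never materializes the matrix and only exposes the product `J δ`.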

## Implementing MGVI in Julia

### Example problem

#### Problem definition

Let the data-generating equation be
$$d = R s(\xi) + n $$
with data $d$, signal model $s$, response $R$, Gaussian noise $n$ and parameters $\xi$.

The hierarchical signal model with standardized priors for $\xi$ is
$$ s(\xi) = \exp{\left[ \mathrm{HT}^{-1} \circ P \circ \xi \right]} $$
with harmonic transform operator $\mathrm{HT}$ and amplitude operator $P$.
The signal model contains both non-linearities (the exponential) and global linear operations (the harmonic transform).
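To make $\mathrm{HT}^{-1}$ concrete, a dense-matrix sketch using only the standard library, assuming FFTW's unnormalized discrete Hartley transform convention $H_{jk} = \cos(2\pi jk/n) + \sin(2\pi jk/n)$. Under this convention $H^2 = n\mathbb{1}$, so $\mathrm{HT}^{-1} = H/n$. The helper `hartley_matrix`, the size `n = 8` and the fixed slope `2.0` are illustrative choices; the notebook itself uses a planned FFTW transform rather than a dense matrix.

```julia
using LinearAlgebra, Random

Random.seed!(0)

# Hypothetical helper: dense unnormalized DHT, H[j, k] = cas(2π jk / n).
hartley_matrix(n) = [cos(2π * j * k / n) + sin(2π * j * k / n) for j in 0:n-1, k in 0:n-1]

n = 8
H = hartley_matrix(n)

# The DHT is its own inverse up to a factor n: H * H == n * I (up to rounding).
inv_error = maximum(abs.(H * H - n * I))

k = [i < n / 2 ? i : n - i for i in 0:n-1]  # same frequency indexing as the notebook
P = @. 50 / (k^2.0 + 1)                      # power-law amplitude with a fixed slope
ξ = randn(n)                                 # standardized latent parameters

# s(ξ) = exp[HT⁻¹(P ξ)], applied right to left; the result is strictly positive.
s = exp.((H / n) * (P .* ξ))
```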

The negative log-likelihood, i.e. the information Hamiltonian of the likelihood, then reads (up to an additive constant)
$$\mathcal{H}(d|\xi) = \mathrm{nll}(\xi) = \frac{1}{2} (d - Rs(\xi))^\dagger N^{-1} (d - Rs(\xi))$$
with $N$ the noise covariance.

Thus, the overall potential, i.e. the joint information Hamiltonian of data and signal, is
$$\mathcal{H}(d,\xi) = \mathrm{potential}(\xi) = \frac{1}{2} (d - Rs(\xi))^\dagger N^{-1} (d - Rs(\xi)) + \frac{1}{2} \xi^\dagger \xi.$$
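The natural-gradient step below needs the gradient of this potential. With $J_\xi = R\,\mathrm{diag}(s(\xi))\,A$ for a model $s(\xi) = \exp(A\xi)$, differentiating gives $\nabla_\xi \mathcal{H}(d,\xi) = -J_\xi^\dagger N^{-1}(d - Rs(\xi)) + \xi$. A minimal sketch that checks this formula against central finite differences; the small random matrix `A` is a hypothetical stand-in for $\mathrm{HT}^{-1} \circ P$, and `grad_potential` and `fd_grad` are illustrative helpers.

```julia
using LinearAlgebra, Random

Random.seed!(1)
n = 5
A = 0.3 * randn(n, n)                        # hypothetical stand-in for HT⁻¹ ∘ P
R = Diagonal(ones(n))                        # response (identity here)
Ninv = Diagonal(fill(1 / 0.1^2, n))          # inverse noise covariance N⁻¹
d = R * exp.(A * randn(n)) .+ 0.1 * randn(n) # synthetic data

signal(ξ) = exp.(A * ξ)

# Joint information Hamiltonian H(d, ξ), additive constants dropped.
function potential(ξ)
    r = d .- R * signal(ξ)
    return 0.5 * dot(r, Ninv * r) + 0.5 * dot(ξ, ξ)
end

# Analytic gradient: J = R diag(s) A, ∇H = -J† N⁻¹ (d - R s) + ξ.
function grad_potential(ξ)
    s = signal(ξ)
    J = R * Diagonal(s) * A
    return -J' * (Ninv * (d .- R * s)) .+ ξ
end

# Central finite differences, for comparison only.
function fd_grad(f, ξ; h=1e-6)
    g = similar(ξ)
    for i in eachindex(ξ)
        e = zeros(length(ξ))
        e[i] = h
        g[i] = (f(ξ .+ e) - f(ξ .- e)) / (2h)
    end
    return g
end

ξ0 = randn(n)
grad_potential(ξ0) ≈ fd_grad(potential, ξ0)  # agree up to finite-difference error
```

In the notebook the same gradient is obtained by Zygote, so the formula never has to be coded by hand.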

#### Solution with MGVI

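1. Draw a random initial position $\bar{\xi}$.
1. Construct the Fisher information metric at $\bar{\xi}$
$$\mathrm{Fisher}(\bar{\xi}) = J_{\bar{\xi}}^\dagger N^{-1} J_{\bar{\xi}} + \mathbb{1}$$
with metric $M = N^{-1}$ of the likelihood and $J_{\bar{\xi}}$ the Jacobian of the signal response $Rs(\xi)$ at $\bar{\xi}$. The identity $\mathbb{1}$ is the contribution of the standardized prior.
1. Draw samples $\Delta\xi_i$ from a zero-mean Gaussian whose covariance is the inverse Fisher metric, $\Delta\xi_i \sim \mathcal{G}(\Delta\xi, \mathrm{Fisher}(\bar{\xi})^{-1})$.
1. Holding this covariance fixed, approximate the Kullback-Leibler divergence, up to additive constants, by the sample average of the potential
$$\mathcal{D}_\mathrm{KL}(\bar{\xi}) \simeq {\left\langle \mathcal{H}(d,\xi) \right\rangle}_{\mathcal{G}(\xi-\bar{\xi},\,\mathrm{Fisher}(\bar{\xi})^{-1})} \approx \frac{1}{\#\text{samples}} \sum_{i=1}^{\#\text{samples}} \mathcal{H}(d, \bar{\xi} + \Delta\xi_i).$$
1. Minimize this KL with respect to $\bar{\xi}$ via natural gradient descent, keeping the samples fixed and using the sample-averaged Fisher metric as curvature.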

MGVI thus illustrates why having both forward-mode differentiation (to apply $J_{\bar{\xi}}$) and reverse-mode differentiation (to apply $J_{\bar{\xi}}^\dagger$) available in a programming language is so helpful.
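Step 3 never forms $\mathrm{Fisher}^{-1/2}$ explicitly. Instead, one draws $\xi \sim \mathcal{G}(\xi, \mathbb{1})$ from the prior, synthesizes noisy data $J\xi + \eta$ with $\eta \sim \mathcal{G}(\eta, N)$, and subtracts the reconstruction $\mathrm{Fisher}^{-1} J^\dagger N^{-1}(J\xi + \eta)$; the difference has covariance $\mathrm{Fisher}^{-1}(\mathbb{1} + J^\dagger N^{-1} J)\,\mathrm{Fisher}^{-1} = \mathrm{Fisher}^{-1}$. A minimal dense sketch that checks this empirically, assuming a random stand-in Jacobian `J` and a direct solve `\` in place of the conjugate-gradient solve used in the notebook; `fisher_sample` is a hypothetical helper name.

```julia
using LinearAlgebra, Random, Statistics

Random.seed!(2)
n, m = 3, 4
J = randn(m, n)                 # hypothetical stand-in for the Jacobian at ξ̄
N = Diagonal(fill(0.5, m))      # noise covariance
fisher = J' * (N \ J) + I       # Fisher(ξ̄) = J† N⁻¹ J + 1

# Draw Δξ ~ G(0, Fisher⁻¹) using only products with J, J† and one solve.
function fisher_sample(J, N, fisher)
    ξ = randn(size(J, 2))                           # draw from the standardized prior
    dsyn = J * ξ .+ sqrt(N) * randn(size(J, 1))     # synthetic data with noise ~ G(0, N)
    return ξ .- fisher \ (J' * (N \ dsyn))          # subtract the reconstruction
end

samples = [fisher_sample(J, N, fisher) for _ in 1:100_000]
emp_cov = cov(reduce(hcat, samples); dims=2)  # columns are samples

emp_cov ≈ inv(fisher)  # approximately, up to Monte Carlo error
```

The notebook's `covariance_sample` performs the same construction, but with `LinearMap`s and conjugate gradients so that neither the Jacobian nor the Fisher metric is ever stored as a matrix.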

In [None]:
import IterativeSolvers: cg
import Random: randn, seed!
import FFTW: plan_r2r, DHT
import Base: *  # extended below for FFTW plans acting on dual numbers
using ForwardDiff  # forward-mode autodiff (dual numbers)
using Zygote  # reverse-mode autodiff
using LinearAlgebra
using LinearMaps  # treat (linear) functions as if they were matrices
using Statistics: mean
using Plots


const VecOrNum = Union{Number,Vector{<:Number}}

seed!(42)

In [None]:
dims = 1024
k = [i < dims / 2 ? i :  dims-i for i = 0:dims-1]

# Define the harmonic transform operator as a matrix-like object
ht = plan_r2r(zeros(dims), DHT)
# Unfortunately neither Zygote nor ForwardDiff support planned Hartley
# transformations. While Zygote does not support AbstractFFTs.ScaledPlan,
# ForwardDiff does not overload the appropriate methods from AbstractFFTs.
function *(trafo::typeof(ht), u::Vector{ForwardDiff.Dual{T,V,P}}) where {T,V,P}
    # Unpack AoS -> SoA
    vs = ForwardDiff.value.(u)
    ps = mapreduce(ForwardDiff.partials, vcat, u)
    # Actual computation
    val = trafo * vs
    jvp = trafo * ps
    # Pack SoA -> AoS (depending on jvp, might need `eachrow`)
    return map((v, p) -> ForwardDiff.Dual{T}(v, p...), val, jvp)
end
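# The unnormalized DHT is symmetric, so the pullback of x -> inv(ht) * x
# applies the same scaled transform to the incoming cotangent Δ.
Zygote.@adjoint function *(trafo::typeof(inv(ht)), xs::T) where T
    return trafo * xs, Δ -> (nothing, trafo * Δ)
end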
Zygote.@adjoint function inv(trafo::typeof(ht))
    inv_t = inv(trafo)
    return inv_t, function (Δ)
        adj_inv_t = adjoint(inv_t)
        return (- adj_inv_t * Δ * adj_inv_t, )
    end
end

# ξ := latent variables
ξ_truth = randn(dims)

loglogslope = 2.0 + 0.5 * randn()
P = @. 50 / (k^loglogslope + 1)
function correlated_field(ξ::V) where V<:VecOrNum
    return inv(ht) * (P .* ξ)
end
function signal(ξ::V) where V<:VecOrNum
    return exp.(correlated_field(ξ))
end

N = Diagonal(0.01^2 * ones(dims))
R = ones(dims)
#R[100:200] .= 0
R = Diagonal(R)

function signal_response(ξ::V) where V<:VecOrNum
    return R * signal(ξ)
end

In [None]:
# Generate synthetic signal and data
ss = signal(ξ_truth)
d = R * ss .+ R * sqrt(N) * randn(dims)
plot(ss, color=:red, label="ground truth", linewidth=5)
plot!(d, seriestype=:scatter, marker=:x, color=:black, label="data")

In [None]:
@doc """Return a mapping function to translate a vector at a given
vector-valued position to a combined vector of dual numbers"""
function to_dual_at(ξ::V) where V<:VecOrNum
    return function to_dual(δ::V)
        return map((v, p) -> ForwardDiff.Dual(v, p...), ξ, δ)
    end
end

@doc """Return the Jacobian of `f` at `ξ` as an implicit (matrix-free) linear map"""
function jacobian(f::F, ξ::V) where {F<:Function, V<:VecOrNum}
    to_dual = to_dual_at(ξ)
    # Forward mode: push the tangent δ through f and collect the partials, i.e. J δ
    jvp(δ::V) = mapreduce(ForwardDiff.partials, vcat, f(to_dual(δ)))

    function vjp(δ::T) where T<:VecOrNum
        # Reverse mode: pull δ back through f, i.e. J† δ
        return first(Zygote.pullback(f, ξ)[2](δ))
    end

    return LinearMap{eltype(ξ)}(jvp, vjp, first(size(ξ)))
end

inv_noise_cov = inv(N)

function nll(ξ::L) where L
    res = d .- signal_response(ξ)
    return 0.5 * transpose(res) * inv_noise_cov * res
end

ham(ξ) = nll(ξ) + 0.5 * (ξ ⋅ ξ)  # joint Hamiltonian: likelihood plus standardized prior

In [None]:
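function covariance_sample(cov_inv::T, jac::N, metric::M) where {
    T<:Union{AbstractMatrix,LinearMap{E}},
    N<:Union{AbstractMatrix,LinearMap{E}},
    M<:Union{AbstractMatrix,LinearMap{E}}
} where E
    # Draw ξ from the standardized prior and synthesize noisy data from it
    ξ_new::Vector{E} = randn(first(size(cov_inv)))
    d_new::Vector{E} = jac * ξ_new .+ sqrt(inv(metric)) * randn(first(size(jac)))
    # Reconstruct ξ from the synthetic data, m = Fisher⁻¹ J† M d, via conjugate gradients
    j_new::Vector{E} = adjoint(jac) * metric * d_new
    m_new::Vector{E} = cg(cov_inv, j_new, log=true)[1]
    # The residual ξ - m is distributed as G(0, Fisher⁻¹)
    return ξ_new .- m_new
end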

function metric_gaussian_kl!(
    pos::P,
    n_samples::C;
    mirror_samples::Bool=false,
    n_grad_steps::C=10,
    nat_grad_scl::F=1.0
) where {P, C<:Int, F<:Number}
    jac = jacobian(signal_response, pos)
    fisher = adjoint(jac) * inv_noise_cov * jac + I

    samples = [covariance_sample(fisher, jac, inv_noise_cov) for i = 1 : n_samples]
    samples = mirror_samples ? vcat(samples, -samples) : samples
    
    kl(ξ::P) = reduce(+, ham(ξ + s) for s in samples) / length(samples)

    # Take the metric of the KL itself as curvature
    nll_fisher_by_s = mapreduce(+, samples) do s
        jac_s = jacobian(signal_response, pos + s)
        return adjoint(jac_s) * inv_noise_cov * jac_s
    end
    avg_fisher = nll_fisher_by_s / length(samples) + I

    for _ in 1 : n_grad_steps
        grad::P = first(gradient(kl, pos))
        Δξ = cg(avg_fisher, grad, log=true)[1]
        pos .-= nat_grad_scl * Δξ
    end
    return pos, samples
end



init_pos = 0.1 * randn(dims)

pos = deepcopy(init_pos)
n_samples = 3

In [None]:
pos, samples = metric_gaussian_kl!(pos, n_samples; mirror_samples=true, n_grad_steps=15, nat_grad_scl=0.1)
plot!(signal(pos), label="MGVI it. 1")

In [None]:
pos, samples = metric_gaussian_kl!(pos, n_samples; mirror_samples=true, n_grad_steps=15, nat_grad_scl=0.5)
for (i, s) in enumerate(samples)
    plot!(signal(pos + s), label="Post. Sample " * string(i), color=:gray)
end
plot!(signal(pos), label="Post. mean")

### Why Julia

* No error-prone hand-written Jacobians, thanks to AutoDiff (here, Zygote and ForwardDiff)
* Potentially faster, and deployable on a wider range of infrastructure, e.g. GPUs

## What to do next

### Minor

* Introduce a line search algorithm
* Advance Zygote's work on `pushforward`, i.e. forward-mode differentiation for vector-valued parameters, thereby enabling true GPU and XLA support
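For the line-search item, one common choice is backtracking with the Armijo sufficient-decrease condition; whether this is the variant intended here is an assumption. A minimal sketch, where `backtracking` is a hypothetical helper that takes the objective, the gradient at the current position and a descent direction such as the natural-gradient step $-\mathrm{Fisher}^{-1}\nabla\mathcal{D}_\mathrm{KL}$:

```julia
using LinearAlgebra

# Hypothetical helper: backtracking line search with the Armijo condition
# f(x + α dir) <= f(x) + c α ⟨g, dir⟩.
function backtracking(f, x, g, dir; α0=1.0, c=1e-4, shrink=0.5, max_iter=50)
    fx = f(x)
    slope = dot(g, dir)  # directional derivative, negative for a descent direction
    slope < 0 || throw(ArgumentError("dir is not a descent direction"))
    α = α0
    for _ in 1:max_iter
        f(x .+ α .* dir) <= fx + c * α * slope && return α
        α *= shrink
    end
    return α
end

# Ill-conditioned quadratic; its Hessian A plays the role of the metric.
A = [10.0 0.0; 0.0 1.0]
quad(x) = 0.5 * dot(x, A * x)
x = [1.0, 1.0]
g = A * x

α_nat = backtracking(quad, x, g, -(A \ g))  # natural-gradient direction: full step accepted
α_grad = backtracking(quad, x, g, -g)       # plain gradient direction: step shrinks to 0.125
```

With an exact metric the natural-gradient step is accepted at full length, whereas the plain gradient needs several halvings, which hints at why the notebook currently gets away with a fixed `nat_grad_scl`.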

### Major

* Handle more than one parameter vector
    * Ideally, parameters should be treated as a dictionary of vectors
    * Parameter dictionaries require major changes, including re-inventing LinearMap
* Wrap the code for use in Python and benchmark it against the implementation in Numerical Information Field Theory (NIFTy)
* Integrate into Bayesian Analysis Toolkit (BAT)
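Regarding multiple parameter vectors: one possible stopgap, short of re-inventing LinearMap, is to flatten a named collection of parameter vectors into the single flat vector that the current `jacobian`/`LinearMap` code expects, and to unflatten it inside the model. A minimal sketch, assuming a `NamedTuple` as a stand-in for the dictionary; `flatten_params` and `unflatten_params` are hypothetical helpers.

```julia
# Hypothetical helpers: pack a NamedTuple of parameter vectors into one flat
# vector and recover the named structure from it.
function flatten_params(params::NamedTuple)
    names = keys(params)
    lens = map(length, values(params))
    flat = reduce(vcat, values(params))
    return flat, (names, lens)  # the layout is needed to undo the packing
end

function unflatten_params(flat::AbstractVector, layout)
    names, lens = layout
    offsets = cumsum([0; collect(lens)])
    parts = ntuple(i -> flat[offsets[i]+1:offsets[i+1]], length(lens))
    return NamedTuple{names}(parts)
end

p = (amplitude = [1.0, 2.0], slope = [0.5])
flat, layout = flatten_params(p)     # flat == [1.0, 2.0, 0.5]
unflatten_params(flat, layout) == p  # true
```

A model such as `signal_response` could then take `flat` and call `unflatten_params` internally, at the cost of losing the named structure in the Jacobian and the Fisher metric.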