# Variational inference via Wasserstein gradient flows

In [2]:
import numpy as np 
import matplotlib.pyplot as plt
import scipy.stats as stats
from distances import W2, KL_divergence

## VARIATIONAL INFERENCE WITH GAUSSIANS

### Bures–JKO scheme

Let $p_0 = \mathcal{N}(m_0, \sigma_0)$ and $p_1 = \mathcal{N}(m_1, \sigma_1)$ be two Gaussian distributions with means $m_0, m_1$ and covariance matrices $\sigma_0, \sigma_1$.

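Then the squared Wasserstein distance $W_2^2(p_0, p_1)$ admits a closed form: $W_2^2(p_0, p_1) = \| m_0 - m_1 \|^2 + \mathcal{B}(\sigma_0, \sigma_1)$, where $\mathcal{B}(\cdot, \cdot)$ is the squared Bures metric, $\mathcal{B}(\sigma_0, \sigma_1) = \operatorname{tr}\left( \sigma_0 + \sigma_1 - 2 \left( \sigma_0^{1/2} \sigma_1 \sigma_0^{1/2} \right)^{1/2} \right)$.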

Now, given a target density $\pi$ (which will be Gaussian in our case) and a step size $h > 0$, we define the iterates of the proximal point algorithm: $p_{k+1, h} = \operatorname*{argmin}_{p} \left[ \mathrm{KL}(p \,\|\, \pi) + \frac{1}{2h} W_2^2(p, p_{k,h}) \right]$, where the minimum is taken over Gaussian distributions $p$.

In [None]:
# Experiment configuration for the Gaussian variational-inference example.
N = 10_000    # number of samples
d = 2    # dimension of the problem
# Gaussian target, built from an explicit mean vector and covariance matrix.
target_distribution = stats.multivariate_normal(
    mean = [3, 2],
    cov = [
        [6, 1],
        [1, 3],
    ],
)

def Bures_JKO_scheme(target_distribution, N, d, MaxIter = 10 * 3):
    """ 
    Bures-JKO scheme for sampling from a target distribution (work in progress).

    Draws N samples from the target and N samples from the initial
    distribution N(0, I_d). The proximal (JKO) update itself is not
    implemented yet, so any call with MaxIter > 0 raises NotImplementedError.

    Parameters:
    -----------
    target_distribution: scipy.stats frozen distribution
        The target distribution we want to sample from. It must expose
        an rvs(size=...) method.

    N: int
        The number of samples to generate. It also sets the step size
        h = 1 / N, so it must be non-zero.
    
    d: int
        The dimension of the initial distribution. It is presumably meant to
        equal the dimension of the target; this is not checked.

    MaxIter: int
        The number of proximal iterations to run.
        NOTE(review): the default evaluates to 30; `10 ** 3` may have been
        intended -- confirm before relying on it.

    Returns:
    --------
    samples: np.array
        Array of shape (N, d) holding the samples of the current iterate.
        Until the proximal step is implemented this is only reached when
        MaxIter <= 0, in which case these are the initial samples.

    Raises:
    -------
    NotImplementedError
        If MaxIter > 0, because the proximal step is not implemented.
    """
    h = 1 / N    # step size (kept for the pending proximal step)
    # scipy frozen distributions draw samples with `rvs`; they have no `sample`
    # method. `rvs` squeezes singleton axes, so force a 2-D (N, dim) layout.
    pi = np.reshape(target_distribution.rvs(size = N), (N, -1))    # samples from the target distribution
    initial_distribution = stats.multivariate_normal(mean = np.zeros(d), cov = np.eye(d))
    p_0_h = np.reshape(initial_distribution.rvs(size = N), (N, d))    # samples from the initial distribution
    for _ in range(MaxIter):
        # Open problem: how to compute an argmin when we are working with
        # probability distributions? Fail loudly instead of returning
        # samples that were never updated.
        raise NotImplementedError(
            "The Bures-JKO proximal step (argmin over distributions) is not implemented yet."
        )
    return p_0_h
    