In [1]:
# standard libraries
from matplotlib import pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal, norm
from scipy import linalg, stats
from particles import resampling as rs
import time
import importlib
import sys
sys.path.append('..')

import MultivariateGaussianAlgorithms as algo
import utils as u
importlib.reload(algo)
importlib.reload(u)

<module 'utils' from '/Users/francescacrucinio/Documents/PAPERS/SUBMITTED/SMC-WFR/Gaussian/../utils.py'>

In [2]:
def KL(mu0, Sigma0, mu1, Sigma1):
    """
    Computes KL(N0 || N1) for multivariate Gaussians.

    KL(N0 || N1) = 1/2 [ tr(Sigma1^{-1} Sigma0)
                         + (mu1 - mu0)^T Sigma1^{-1} (mu1 - mu0)
                         - k + log det Sigma1 - log det Sigma0 ]

    Parameters
    ----------
    mu0 : (k,) array_like
        Mean vector of distribution 0.
    Sigma0 : (k,k) array_like
        Covariance matrix of distribution 0.
    mu1 : (k,) array_like
        Mean vector of distribution 1.
    Sigma1 : (k,k) array_like
        Covariance matrix of distribution 1.

    Returns
    -------
    float
        KL divergence KL(N0 || N1). Returns +inf when Sigma0 is singular.
        Returns nan when either matrix has a negative determinant, i.e. it is
        not a valid covariance.

    Raises
    ------
    numpy.linalg.LinAlgError
        If Sigma1 is singular.
    """
    mu0 = np.asarray(mu0, dtype=float)
    mu1 = np.asarray(mu1, dtype=float)
    Sigma0 = np.asarray(Sigma0, dtype=float)
    Sigma1 = np.asarray(Sigma1, dtype=float)

    k = mu0.shape[0]

    # Work with log-determinants: det() of a k x k covariance overflows or
    # underflows quickly as k grows (e.g. det(10*I_400) = 1e400), which would
    # turn log(det1/det0) into inf/nan even when the KL itself is moderate.
    sign0, logdet0 = np.linalg.slogdet(Sigma0)
    sign1, logdet1 = np.linalg.slogdet(Sigma1)
    if sign0 < 0 or sign1 < 0:
        # Not a covariance matrix; mirrors log(negative) -> nan.
        return np.nan

    # Mahalanobis term, solving against Sigma1 instead of forming its inverse
    diff = mu1 - mu0
    mahal = diff @ np.linalg.solve(Sigma1, diff)

    # Trace term tr(Sigma1^{-1} Sigma0)
    trace_term = np.trace(np.linalg.solve(Sigma1, Sigma0))

    # KL divergence (logdet0 = -inf for singular Sigma0 gives +inf)
    kl = 0.5 * (trace_term + mahal - k + logdet1 - logdet0)

    return kl

In [3]:
# Ambient dimension shared by the Gaussians and by every particle below.
d: int = 20

In [4]:
# Reference Gaussian N(ms, Sigmas) that every KL below is measured against,
# and the starting Gaussian N(mu0, Sigma0) handed to the exact flows.
ms = np.full(d, 20.0)
Sigmas = 5.0 * np.identity(d)
mu0 = np.zeros(d)
Sigma0 = np.identity(d)

In [5]:
# Precision matrix of the reference Gaussian, passed to every sampler.
Sigmas_inv = linalg.inv(a=Sigmas)

## Algorithms

In [6]:
N = 100       # number of particles / parallel chains
SEED = 1234   # fixed so that Restart & Run All reproduces every figure
# Seed the global NumPy RNG: X0 is drawn from it, and the samplers in `algo`
# presumably draw from it too (TODO confirm).
np.random.seed(SEED)
# Initial particles are drawn from N(mu0, Sigma0), the same Gaussian the exact
# flows start from, so the particle and exact KL curves share a starting point.
# Shape: (N, d).
X0 = np.random.multivariate_normal(mu0, Sigma0, size = N)
gamma = 0.001  # step size; also the time increment of the exact flows

In [7]:
Niter = 20000
# Every sampler runs for the same number of iterations; separate names are
# kept so that any one of them can be tuned independently.
Niter_ula = Niter_mala = Niter_fr = Niter_smcula = Niter_smcmala = Niter

In [8]:
T = gamma * Niter  # total time covered by Niter steps of size gamma
# One KL value per step for the exact WFR, Wasserstein and Fisher-Rao flows.
kl_wfr_exact, kl_w_exact, kl_fr_exact = (np.zeros(Niter) for _ in range(3))

In [9]:
for step in range(Niter):
    s = step * gamma  # continuous time reached after `step` steps

    mean_s, cov_s = algo.GaussmultiD_WFRexact(mu0, Sigma0, ms, Sigmas, s, d)
    kl_wfr_exact[step] = KL(mean_s, cov_s, ms, Sigmas)

    mean_s, cov_s = algo.GaussmultiD_FR(mu0, Sigma0, ms, Sigmas, s)
    kl_fr_exact[step] = KL(mean_s, cov_s, ms, Sigmas)

    mean_s, cov_s = algo.GaussmultiD_Wass(mu0, Sigma0, ms, Sigmas, s)
    kl_w_exact[step] = KL(mean_s, cov_s, ms, Sigmas)

In [10]:
# MALA-type samplers use a much larger step than gamma (Metropolis correction).
gamma_mala = gamma_smcmala = 2.5

In [11]:
Nalgo = 6  # runtime columns: ULA, MALA, SMC-FR, SMC-WFR, SMC-ULA, SMC-MALA
Nrep = 1   # number of independent repetitions

# Positions are stored as (rep, iteration, dimension, particle) and
# weights as (rep, iteration, particle).
# NOTE(review): with Niter=20000, d=20, N=100 each position array holds
# 4e7 float64 values (~320 MB), so the six of them need ~2 GB in total.
ula_chain, mala_chain = (np.zeros((Nrep, n, d, N)) for n in (Niter_ula, Niter_mala))

fr_x, fr_w = np.zeros((Nrep, Niter_fr, d, N)), np.zeros((Nrep, Niter_fr, N))
wfr_x, wfr_w = np.zeros((Nrep, Niter, d, N)), np.zeros((Nrep, Niter, N))
smcula_x, smcula_w = np.zeros((Nrep, Niter_smcula, d, N)), np.zeros((Nrep, Niter_smcula, N))
smcmala_x, smcmala_w = np.zeros((Nrep, Niter_smcmala, d, N)), np.zeros((Nrep, Niter_smcmala, N))

runtime = np.zeros((Nrep, Nalgo))

In [12]:
def _timed(fn, *args):
    """Call fn(*args) and return (result, elapsed seconds) using a monotonic clock."""
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # time and can jump if the system clock is adjusted during a run.
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start


# runtime columns: 0 ULA, 1 MALA, 2 SMC-FR, 3 SMC-WFR, 4 SMC-ULA, 5 SMC-MALA
for i in range(Nrep):
    ## ULA
    ula_chain[i, :], runtime[i, 0] = _timed(
        algo.ParallelULA, gamma, Niter_ula, ms, Sigmas, Sigmas_inv, X0)
    ### MALA (also returns its acceptance record)
    # NOTE(review): accepted_mala / accepted_smcmala keep only the last repetition.
    (mala_chain[i, :], accepted_mala), runtime[i, 1] = _timed(
        algo.ParallelMALA, gamma_mala, Niter_mala, ms, Sigmas, Sigmas_inv, X0)
    ### SMC-FR
    (fr_x[i, :], fr_w[i, :]), runtime[i, 2] = _timed(
        algo.SMC_UnitFR, gamma, Niter_fr, ms, Sigmas, Sigmas_inv, X0, 1)
    ### SMC-WFR
    (wfr_x[i, :], wfr_w[i, :]), runtime[i, 3] = _timed(
        algo.SMC_WFR, gamma, Niter, ms, Sigmas, Sigmas_inv, X0, 1)
    ### SMC-ULA
    (smcula_x[i, :], smcula_w[i, :]), runtime[i, 4] = _timed(
        algo.SMC_ULA, gamma, Niter_smcula, ms, Sigmas, Sigmas_inv, X0, 1)
    ### SMC-MALA
    (smcmala_x[i, :], smcmala_w[i, :], accepted_smcmala), runtime[i, 5] = _timed(
        algo.SMC_MALA, gamma_smcmala, Niter_smcmala, ms, Sigmas, Sigmas_inv, X0, 1)

(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(

(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(20, 20)
(

KeyboardInterrupt: 

In [None]:
d1 = 0
d2 = 1
i = Nrep-1
# Final SMC-WFR particles (coloured by weight) vs. final ULA chains, projected
# onto coordinates d1 and d2 of the last repetition.
wfr_sc = plt.scatter(wfr_x[i, -1, d1, :], wfr_x[i, -1, d2, :], c = wfr_w[i, -1, :], label = 'SMC-WFR')
plt.scatter(ula_chain[i, -1, d1, :], ula_chain[i, -1, d2, :], c = 'red', label = 'ULA')
plt.colorbar(wfr_sc, label = 'SMC-WFR weight')
plt.xlabel(f'coordinate {d1}')
plt.ylabel(f'coordinate {d2}')
plt.legend();

In [None]:
# Average MALA acceptance rate: accepted moves summed over axis 0 (presumably the iteration axis — TODO confirm), divided by the iteration count, then averaged.
(np.sum(accepted_mala, axis=0) / Niter_mala).mean()

In [None]:
# Average SMC-MALA acceptance rate: accepted moves summed over axis 0 (presumably the iteration axis — TODO confirm), divided by the iteration count, then averaged.
(np.sum(accepted_smcmala, axis=0) / Niter_smcmala).mean()

In [None]:
# Mean elapsed seconds per algorithm (one column each) over the repetitions.
avg_runtime = runtime.mean(axis=0)
avg_runtime

In [None]:
def _runtime_axis(col, n_steps):
    """Spread the average runtime of algorithm `col` evenly over its n_steps iterations."""
    return avg_runtime[col] / n_steps * np.arange(n_steps)


xx_ula = _runtime_axis(0, Niter_ula)
xx_mala = _runtime_axis(1, Niter_mala)
xx_fr = _runtime_axis(2, Niter_fr)
xx_wfr = _runtime_axis(3, Niter)
xx_smcula = _runtime_axis(4, Niter_smcula)
xx_smcmala = _runtime_axis(5, Niter_smcmala)

In [None]:
# One empirical-KL trace per repetition and iteration for each sampler.
kl_ula, kl_mala, kl_fr, kl_wfr, kl_smcula, kl_smcmala = (
    np.zeros((Nrep, n_steps))
    for n_steps in (Niter_ula, Niter_mala, Niter_fr, Niter, Niter_smcula, Niter_smcmala)
)

In [None]:
def _moment_matched_kl(particles, weights=None):
    """KL(N(m, C) || N(ms, Sigmas)) for the Gaussian moment-matched to a particle cloud.

    particles : (d, N) array, one column per particle.
    weights   : optional (N,) importance weights. np.average and
                np.cov(aweights=...) both normalise them, so the mean and the
                covariance agree even if the weights do not sum to one.
    """
    if weights is None:
        m = np.mean(particles, axis = 1)
        # NOTE(review): unweighted clouds use the unbiased (N-1) covariance,
        # while weighted ones use bias=True; the two estimators differ slightly.
        C = np.cov(particles)
    else:
        m = np.average(particles, axis = 1, weights = weights)
        C = np.cov(particles, aweights = weights, bias = True)
    return KL(m, C, ms, Sigmas)


# (output buffer, positions, iterations) for the unweighted parallel chains
unweighted_runs = [
    (kl_ula, ula_chain, Niter_ula),
    (kl_mala, mala_chain, Niter_mala),
]
# (output buffer, positions, weights, iterations) for the weighted SMC clouds
weighted_runs = [
    (kl_wfr, wfr_x, wfr_w, Niter),
    (kl_fr, fr_x, fr_w, Niter_fr),
    (kl_smcula, smcula_x, smcula_w, Niter_smcula),
    (kl_smcmala, smcmala_x, smcmala_w, Niter_smcmala),
]

for j in range(Nrep):
    for kl_out, xs, n_steps in unweighted_runs:
        for i in range(n_steps):
            kl_out[j, i] = _moment_matched_kl(xs[j, i])
    for kl_out, xs, ws, n_steps in weighted_runs:
        for i in range(n_steps):
            kl_out[j, i] = _moment_matched_kl(xs[j, i], ws[j, i])

In [None]:
# (label, KL array of shape (Nrep, n_iter), runtime axis, line style)
curves = [
    ('ULA', kl_ula, xx_ula, 'solid'),
    ('MALA', kl_mala, xx_mala, 'solid'),
    ('SMC-tempering', kl_fr, xx_fr, 'dotted'),
    ('SMC-WFR', kl_wfr, xx_wfr, 'dashed'),
    ('SMC-ULA', kl_smcula, xx_smcula, 'dashed'),
    ('SMC-MALA', kl_smcmala, xx_smcmala, 'dashed'),
]

fig, (ax_iter, ax_time) = plt.subplots(1, 2, figsize = (20, 5))
for label, kl, xx, ls in curves:
    kl_mean = np.mean(kl, axis = 0)  # average over repetitions
    ax_iter.plot(kl_mean, label = label, lw = 2, linestyle = ls)
    ax_time.plot(xx, kl_mean, label = label, lw = 2, linestyle = ls)
# Optional reference curve for the exact WFR flow:
# ax_iter.plot(kl_wfr_exact, c = 'black', lw = 3)

for ax, xlabel in ((ax_iter, 'iterations'), (ax_time, 'runtime')):
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(xlabel, fontsize = 20)
    ax.set_ylabel('KL', fontsize = 20, labelpad = -1)
    ax.tick_params(labelsize = 15)

# Without a legend the six curves are indistinguishable; place one shared
# legend below both panels.
handles, labels = ax_iter.get_legend_handles_labels()
legend = fig.legend(handles, labels, loc = 'upper center', bbox_to_anchor = (0.5, 0.0), ncol = 6, fontsize = 20)
# fig.savefig('conv_iter_gm20D.pdf', bbox_inches="tight")