# Visualize non-hierarchical case

In [13]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import logsumexp

In [14]:
def get_jsd(mu_1=0, *, std_1=2, mu_2=0, std_2=2, n_samples=50000):
    """Monte Carlo estimate of JSD(q_1, q_2) for two 1-D Gaussians.

    q_1 = N(mu_1, std_1^2) and q_2 = N(mu_2, std_2^2). Uses

        JSD = log 2 + 1/2 * sum_s E_{q_s}[log q_s(z) - log(q_1(z) + q_2(z))],

    with ``n_samples`` draws from each component taken from numpy's global RNG
    (q_1 first, then q_2). Log densities are combined with ``np.logaddexp``
    instead of forming density ratios, so densities that underflow to 0 do not
    turn into log(0) or 0/0.

    Args:
        mu_1: Mean of q_1.
        std_1: Standard deviation of q_1 (keyword-only).
        mu_2: Mean of q_2 (keyword-only).
        std_2: Standard deviation of q_2 (keyword-only).
        n_samples: Number of Monte Carlo draws per component (keyword-only).

    Returns:
        The JSD estimate, between 0 and log 2 up to Monte Carlo error.
    """
    def log_gauss(z, mu, std):
        """Log density of N(mu, std^2) evaluated elementwise at z."""
        return -(z - mu)**2 / (2 * std**2) - np.log(std * np.sqrt(2 * np.pi))

    components = [(mu_1, std_1), (mu_2, std_2)]
    jsd = np.log(2)
    for s, (mu_s, std_s) in enumerate(components):
        mu_j, std_j = components[1 - s]
        z = np.random.normal(mu_s, std_s, n_samples)
        log_q_s = log_gauss(z, mu_s, std_s)
        log_q_j = log_gauss(z, mu_j, std_j)
        # log(q_s / (q_s + q_j)) computed entirely in log space
        jsd += 0.5 * np.mean(log_q_s - np.logaddexp(log_q_s, log_q_j))
    return jsd

In [15]:
# Alternative true (target) densities p(z): a multimodal six-component and a trimodal three-component Gaussian mixture
def p_fun(z):
    """Density of an equal-weight mixture of six Gaussians.

    The modes sit at -5, 0, 5, 10, 15, 20. Each component uses standard
    deviation 2 * 0.5 = 1. ``z`` may be a scalar or a numpy array.
    """
    modes = [-5, 0, 5, 10, 15, 20]
    width = 2 * 0.5  # component standard deviation
    weight = 1 / len(modes)
    normalizer = width * np.sqrt(2 * np.pi)
    return sum(
        weight * (np.exp(-(z - m)**2 / (2 * width**2)) / normalizer)
        for m in modes
    )

def p_fun_trimodal(z):
    """Density of an equal-weight mixture of three Gaussians.

    The modes sit at 0, 10, 20. Each component uses standard deviation
    2 * 1.1 = 2.2. ``z`` may be a scalar or a numpy array.
    """
    centers = [0, 10, 20]
    sigma = 2 * 1.1  # component standard deviation
    mix_weight = 1 / len(centers)
    scale = sigma * np.sqrt(2 * np.pi)
    total = 0
    for c in centers:
        bump = np.exp(-(z - c)**2 / (2 * sigma**2)) / scale
        total += mix_weight * bump
    return total

In [16]:
# computes $\Delta_L$
def get_difference(mu_1=0, L=1, J=10000, p=None):
    """Monte Carlo estimate of Delta_L = MISELBO - average IWELBO.

    The variational mixture has two equally weighted 1-D Gaussian components,
    q_1 = N(mu_1, 2^2) and q_2 = N(0, 2^2). For each component s, J independent
    sets of L samples are drawn from q_s, using numpy's global RNG (q_1 first).
    The MISELBO weights p(z) against the mixture density 0.5*q_1 + 0.5*q_2, and
    the IWELBO weights it against q_s alone. Both are averaged over the two
    components with weight 1/2.

    Args:
        mu_1: Mean of the first component q_1.
        L: Number of importance samples per estimate; must be >= 1.
        J: Number of independent repetitions averaged over.
        p: Vectorized target density, called on a (J, L) array of samples.
            Defaults to 0.5*N(0, 2^2) + 0.5*N(10, 4^2). ``p_fun`` or
            ``p_fun_trimodal`` can be passed instead.

    Returns:
        The estimate of Delta_L as a numpy float.

    Raises:
        ValueError: If L < 1 (the log-mean-exp over L samples is undefined).
    """
    if L < 1:
        raise ValueError(f"L must be a positive integer, got {L!r}")

    std_1 = 2
    mu_2 = 0
    std_2 = 2

    def gaussian_pdf(z, mu, std):
        """Density of N(mu, std^2) evaluated elementwise at z."""
        return np.exp(-(z - mu)**2 / (2 * std**2)) / (std * np.sqrt(2 * np.pi))

    if p is None:
        # Second mode of the default target. It equals mu_1_list[10] for the
        # plotting cell's mu_1_list = np.arange(0, 21), but no longer depends
        # on that global existing.
        p_2_mean = 10

        def target(z):
            return (0.5 * gaussian_pdf(z, 0, std_1)
                    + 0.5 * gaussian_pdf(z, p_2_mean, 2 * std_2))
    else:
        target = p

    miselbo = 0.0
    avg_iwelbo = 0.0
    components = [(mu_1, std_1), (mu_2, std_2)]
    for s, (mu_s, std_s) in enumerate(components):
        mu_j, std_j = components[1 - s]
        z = np.random.normal(mu_s, std_s, (J, L))

        log_p_z = np.log(target(z))
        q_s_z = gaussian_pdf(z, mu_s, std_s)
        q_j_z = gaussian_pdf(z, mu_j, std_j)
        log_q_mix = np.log(0.5 * q_s_z + 0.5 * q_j_z)

        # Each component's L-sample estimate includes its own -log L, and the
        # whole term carries the 1/2 mixture weight.
        miselbo += 0.5 * (np.mean(logsumexp(log_p_z - log_q_mix, axis=1)) - np.log(L))
        avg_iwelbo += 0.5 * (np.mean(logsumexp(log_p_z - np.log(q_s_z), axis=1)) - np.log(L))

    return miselbo - avg_iwelbo

In [None]:
import seaborn as sns
import pandas as pd
from matplotlib.patches import ConnectionPatch

# Apply the seaborn theme once, before the figure exists, so the new axes use it.
sns.set_theme(style="ticks")
fig, axes = plt.subplots(1, 2, sharex=False, figsize=(20, 5))
# Tick positions/labels shared by both axes of the left panel.
unit_ticks = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]


def gauss_pdf(z, mu, std):
    """Density of N(mu, std^2) evaluated elementwise at z."""
    return np.exp(-(z - mu)**2 / (2 * std**2)) / (std * np.sqrt(2 * np.pi))


# ---- Left panel: Delta_L against JSD(q_1, q_2) for several L ----
axes[0].axhline(y=np.log(2), label="log S", color="Black", linestyle=':', linewidth=4)

# Means of q_{phi_1} to sweep over.
# mu_1_list = np.linspace(0, 20, 50)
mu_1_list = np.arange(0, 21)
# get_jsd does not depend on L: estimate it once per mu_1 so every curve
# shares the same x-coordinates.
jsd = [get_jsd(mu_1) for mu_1 in mu_1_list]
for L in [1, 2, 10, 100, 1000]:
    gaps = [get_difference(mu_1, L) for mu_1 in mu_1_list]
    axes[0].plot(jsd, gaps, label=f'L {L}')

axes[0].set_xlabel(r'JSD($q_{\phi_1}(z)$, $q_{\phi_2}(z)$)', fontsize='xx-large')
secax_0 = axes[0].secondary_xaxis('top')
secax_1 = axes[1].secondary_xaxis('top')
secax_0.set_xlabel(r'$\mu$', fontsize='xx-large')
axes[1].set_xlabel(r'$z$', fontsize='xx-large')
axes[0].set_ylabel(r"$\Delta_L$", fontsize='xx-large')
axes[0].set_xticks(ticks=unit_ticks)
axes[0].set_xticklabels(labels=unit_ticks, fontsize='xx-large')
axes[0].set_yticks(ticks=unit_ticks)
axes[0].set_yticklabels(labels=unit_ticks, fontsize='xx-large')
axes[0].legend(loc='best', fontsize='x-large')

# The label ends with a literal tab character.
secax_1.set_xlabel(r'JSD($q_{\phi_1}(z)$, $q_{\phi_2}(z)$)' + '\t', fontsize='xx-large')

# ---- Right panel: q_1 for a few means, q_2 and the target p(z) ----
mu_choices = np.array([5, 10, 15], dtype=int)
std_1 = 2

for i, m in zip(mu_choices, list('ov*')):
    mu_1 = mu_1_list[i]
    z = np.linspace(0, 20, 50)
    axes[1].plot(z, gauss_pdf(z, mu_1, std_1), marker=m, color='#4169E1',
                 markersize=15, label=rf'$\mu_1={mu_1}$')

mu_2 = 0
std_2_q = 2  # std of the second variational component q_2
std_2 = 4  # std of the second mode of the target p
q_2 = lambda z: gauss_pdf(z, mu_2, std_2_q)
z = np.linspace(0, 20, 100)
q_2_z = q_2(z)

p_1 = lambda z: gauss_pdf(z, 0, std_1)
p_2 = lambda z: gauss_pdf(z, mu_1_list[10], std_2)

# Target density drawn here; it must match the model used by get_difference
# for the two panels to describe the same experiment.
p = lambda z: 0.5 * p_1(z) + 0.5 * p_2(z)
# p = p_fun_trimodal
# p = p_fun

axes[1].plot(z, q_2_z, c='green', linewidth=4, label=r'$\mu_2=0$')
axes[1].plot(z, p(z), color='DarkRed', linewidth=4, label=r'$p(z)$')
axes[1].legend(loc='lower center', bbox_to_anchor=(0.95, 0.05), fontsize='xx-large')

# Top axis of the left panel: mark selected mu values at their JSD positions.
idx = [0, 5, 15]
secax_0.set_xticks(ticks=np.array(jsd)[np.array(idx)])
secax_0.set_xticklabels(labels=[round(mu_1_list[i]) for i in idx], fontsize='xx-large')
axes[1].set_xticks(ticks=[0] + list(mu_choices))
axes[1].set_xticklabels(labels=[0] + list(mu_choices), fontsize='xx-large')
# Top axis of the right panel: show the JSD belonging to each plotted mu_1.
secax_1.set_xticks(ticks=np.array(mu_1_list)[mu_choices])
secax_1.set_xticklabels(labels=[f"{round(jsd[i], 2)}" for i in mu_choices], fontsize='xx-large')
axes[1].set_yticks([])
axes[1].set_ylabel(r'$q_{\phi}(z)$', fontsize='xx-large')
axes[0].grid(True)
axes[1].grid(True)
plt.tight_layout(h_pad=-1.08, w_pad=1)

plt.show()

# Visualize hierarchical case

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import logsumexp
from scipy.stats import norm, multivariate_normal
import seaborn as sns
import pandas as pd
from matplotlib.patches import ConnectionPatch




def get_jsd(logQ):
    """Monte Carlo estimate of the generalized JSD of S equally weighted components.

    Args:
        logQ: Array of shape (S, L, S). logQ[s, l, j] is log q_j evaluated at
            the l-th sample drawn from component s.

    Returns:
        log S + (1/S) * sum_s mean_l [log q_s(x_sl) - log sum_j q_j(x_sl)].

    Raises:
        ValueError: If logQ is not 3-D with equal first and last dimensions.
    """
    logQ = np.asarray(logQ)
    if logQ.ndim != 3 or logQ.shape[0] != logQ.shape[2]:
        raise ValueError(f"logQ must have shape (S, L, S), got {logQ.shape}")
    S = logQ.shape[0]
    # Own-component log density log q_s(x_sl). np.diagonal puts the diagonal
    # axis last, giving (L, S), so transpose it to (S, L).
    log_q_own = np.diagonal(logQ, axis1=0, axis2=2).T
    # log sum_j q_j(x_sl), shape (S, L)
    log_q_sum = logsumexp(logQ, axis=-1)
    # (1/S) * sum_s mean_l(...) is the mean over all (s, l) entries.
    return np.log(S) + np.mean(log_q_own - log_q_sum)


def compute_miselbo(logP, logQ):
    """Estimate the MISELBO from importance-sampling log densities.

    Args:
        logP: Shape (S, L). Log target density at the L samples drawn from
            each of the S components.
        logQ: Shape (S, L, S). logQ[s, l, j] is log q_j at the l-th sample
            from component s.

    Returns:
        (1/S) * sum_s log((1/L) * sum_l p(x_sl) / ((1/S) * sum_j q_j(x_sl))).
    """
    n_components, n_samples = logP.shape
    # Log density of the uniform mixture (1/S) * sum_j q_j at every sample.
    log_mixture = logsumexp(logQ, axis=-1) - np.log(n_components)
    # One L-sample importance-weighted estimate per component, shape (S,).
    per_component = logsumexp(logP - log_mixture, axis=-1) - np.log(n_samples)
    weight = 1 / n_components
    return np.sum(weight * per_component)


def compute_average_iwelbo(logP, logQ):
    """Per-component IWELBOs and their uniform average.

    Args:
        logP: Shape (S, L). Log target density at the L samples from each
            component.
        logQ: Shape (S, L, S). logQ[s, l, j] is log q_j at the l-th sample
            from component s. Only the own-component entries j == s are used.

    Returns:
        Tuple (average, iwelbos). ``iwelbos[s]`` is
        log((1/L) * sum_l p(x_sl) / q_s(x_sl)), and ``average`` is
        (1/S) * sum_s iwelbos[s].
    """
    n_components, n_samples = logP.shape
    log_norm = np.log(n_samples)
    iwelbos = np.array(
        [logsumexp(logP[s] - logQ[s, :, s], axis=-1) - log_norm
         for s in range(n_components)],
        dtype=float,
    )
    average = sum(1 / n_components * w for w in iwelbos)
    return average, iwelbos


def log_p(z_s, mu_s, prior_dist=None):
    """Log joint density log p(z, mu) of the hierarchical target model.

    log p(z, mu) = log[(1/6) * sum_m N(z; m, 0.5^2)] + log prior(mu), where the
    modes are m in {-5, 0, 5, 10, 15, 20}.

    Args:
        z_s: Sample(s) of z; a scalar or numpy array of any shape.
        mu_s: Sample(s) of mu, broadcastable against ``z_s``.
        prior_dist: Frozen scipy distribution for mu. Defaults to the
            module-level ``prior``, which is looked up only when this is None.

    Returns:
        Array (or scalar) of log joint densities, broadcast over z_s and mu_s.
    """
    if prior_dist is None:
        prior_dist = prior
    mu_list = [-5, 0, 5, 10, 15, 20]
    std = 0.5
    z_s = np.asarray(z_s)
    # One row per mixture component. The shape follows z_s itself, not a
    # global sample count.
    log_mix = np.stack(
        [np.log(1 / len(mu_list)) + norm(mu, std).logpdf(z_s) for mu in mu_list]
    )
    log_likelihood = logsumexp(log_mix, axis=0)
    log_likelihood += prior_dist.logpdf(mu_s)
    return log_likelihood


# Apply the seaborn theme before creating the figure so the axes pick it up
# (same order as the non-hierarchical figure).
sns.set_theme(style="ticks")
fig = plt.figure(figsize=(10, 5))
ax = fig.gca()
ax.grid(True)
ax.axhline(y=np.log(2), label="log S", color="Black", linestyle=':', linewidth=4)

# Model and variational family. log_p may look up `prior` (and the loop
# variable `L`) at module scope, so both stay module-level names.
prior = norm(10, 3)
q_1_std = 0.001  # std of the fixed first component q_1(mu)
q_1_mu = norm(10, q_1_std)
# stds of the second component q_2(mu) that are swept over
q_2_stds = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 1.5, 2]
stds = [2, 2]  # std of each component's conditional q(z | mu)
S = 2  # number of mixture components
J = 1000  # Monte Carlo repetitions averaged per setting

for L in [1, 2, 10, 100, 1000]:
    JSD = []
    DELTA = []
    for std_1 in q_2_stds:
        q_mu = [q_1_mu, norm(10, std_1)]

        jsd = 0
        miselbo = 0
        avg_iwelbo = 0

        # Re-seed so every setting starts from the same random stream.
        np.random.seed(0)
        for _ in range(J):
            logQ = np.zeros((S, L, S))  # logQ[s, l, j]: log q_j at sample l of component s
            logP = np.zeros((S, L))  # logP[s, l]: log p at sample l of component s

            for s, q_mu_s in enumerate(q_mu):
                # Ancestral sampling: mu ~ q_s(mu), then z ~ q_s(z | mu).
                mu_s = q_mu_s.rvs(L)
                q_conds = [norm(mu_s, std) for std in stds]
                z_s = q_conds[s].rvs(L)
                # Joint log density of (z_s, mu_s) under every component j.
                for j in range(S):
                    logQ[s, :, j] = q_conds[j].logpdf(z_s) + q_mu[j].logpdf(mu_s)
                logP[s] = log_p(z_s, mu_s)

            jsd += get_jsd(logQ) / J
            miselbo += compute_miselbo(logP, logQ) / J
            avg_iwelbo += compute_average_iwelbo(logP, logQ)[0] / J

        DELTA.append(miselbo - avg_iwelbo)
        JSD.append(jsd)
    ax.plot(JSD, DELTA, label=f"L {L}")

secax_0 = ax.secondary_xaxis('top')
secax_0.set_xlabel(r'$\sigma_1^2$', fontsize='xx-large')
# Mark three swept settings on the top axis. Each label is the variance ratio
# (q_2 std / q_1_std)^2 computed from the data. The marked settings give
# power-of-ten ratios (1, 10^2, 10^6).
tick_idx = [0, 2, 6]
tick_labels = []
for i in tick_idx:
    exponent = int(round(np.log10((q_2_stds[i] / q_1_std) ** 2)))
    tick_labels.append(r'$\sigma_2^2$' if exponent == 0 else rf'$10^{{{exponent}}}\sigma_2^2$')
secax_0.set_xticks(np.array(JSD)[tick_idx])
secax_0.set_xticklabels(tick_labels, fontsize='xx-large')

ax.set_ylabel(r"$\Delta_L$", fontsize='xx-large')
ax.set_xlabel(r'JSD($q_{\phi_1}(z, \mu)$, $q_{\phi_2}(z, \mu)$)', fontsize='xx-large')
ax.legend(loc='best', fontsize='x-large')
unit_ticks = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
ax.set_xticks(ticks=unit_ticks)
ax.set_xticklabels(labels=unit_ticks, fontsize='xx-large')
ax.set_yticks(ticks=unit_ticks)
ax.set_yticklabels(labels=unit_ticks, fontsize='xx-large')
plt.show()