In [1]:
import numpy as np
data = np.load("mcs_hw2_p3_data.npy")
x = data[:, :2]
y = data[:, 2]

In [109]:
import scipy.stats

def get_gradient_mu_t(beta, mu, sigma2, v):
    if mu[0] < beta[0]:
        loss_0 = -1
    else:
        loss_0 = 1
    if mu[1] < beta[1]:
        loss_1 = -1
    else:
        loss_1 = 1
    res = - 1 / np.sqrt(sigma2) * [loss_0, loss_1]
    return res

def get_gradient_logsigma2_t(beta, mu, sigma2, v):
    res_1 = - 1 / sigma2
    norm = np.sum(np.abs(beta - mu))
    res_2 = 0.5 * norm * np.power(sigma2, -1.5)
    return (res_1 + res_2) * sigma2

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def get_log_p(x, y, beta):
    eps = 1e-10
    res = 0.0
    res += np.sum(y * np.log(eps + sigmoid(np.dot(x, beta))) + (1.0 - y) * np.log(eps + 1 - sigmoid(np.dot(x, beta))))
    res += np.sum(scipy.stats.norm.logpdf(beta, np.zeros(2), np.ones(2)))
    return res

def get_log_q(mu, sigma2, beta, v):
    eps = 1e-10
    norm = np.sum(np.abs(beta - mu))
    res = - np.log(sigma2) - norm / np.sqrt(sigma2) - np.log(4)
    return res

import concurrent.futures



def elbo(x, y, mu, sigma2, dof):
    res = 0.0
    sample_size = 1024
    sample_beta = np.zeros(shape=(2, sample_size))
    sample_beta[0] = np.random.laplace(mu[0], np.sqrt(sigma2), sample_size)
    sample_beta[1] = np.random.laplace(mu[1], np.sqrt(sigma2), sample_size)
    sample_beta = sample_beta.T
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        future_list = [executor.submit(get_log_p, x, y, beta) for beta in sample_beta]
        for future in concurrent.futures.as_completed(future_list):
            res += future.result()

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        future_list = [executor.submit(get_log_q, mu, sigma2, beta, dof) for beta in sample_beta]
        for future in concurrent.futures.as_completed(future_list):
            res -= future.result()
            
    return res / sample_size

In [114]:
def bbvi_cv(x, y, mu, sigma2, lr, n_iter, m, v, dof):
    sample_size = 8
    sample_beta = np.zeros(shape=(2, sample_size))
    sample_beta[0] = np.random.laplace(mu[0], np.sqrt(sigma2), sample_size)
    sample_beta[1] = np.random.laplace(mu[1], np.sqrt(sigma2), sample_size)
    sample_beta = sample_beta.T
    # update mu
    loss_mu = np.zeros(shape=[sample_size, mu.shape[0]])
    loss_logsigma2 = np.zeros(shape=[sample_size, sigma2.shape[0]])
    cv_mu = np.zeros(shape=[sample_size, mu.shape[0]])
    cv_sigma2 = np.zeros(shape=[sample_size, sigma2.shape[0]])
    for i in range(sample_size):
        loss_mu[i] = cv_mu[i] = get_gradient_mu_t(sample_beta[i], mu, sigma2, dof)
        loss_logsigma2[i] = cv_sigma2[i] = get_gradient_logsigma2_t(sample_beta[i], mu, sigma2, dof)
        log_p = get_log_p(x, y, sample_beta[i])
        log_q = get_log_q(mu, sigma2, sample_beta[i], dof)
        loss_mu[i] *= (log_p - log_q)
        loss_logsigma2[i] *= (log_p - log_q)
        
    cov_mu0 = np.cov(np.stack((cv_mu.T[0], loss_mu.T[0]), axis=0))
    a_mu0 = cov_mu0[0][1] / cov_mu0[0][0]
    cov_mu1 = np.cov(np.stack((cv_mu.T[1], loss_mu.T[1]), axis=0))
    a_mu1 = cov_mu1[0][1] / cov_mu1[0][0]
    cov_logsigma2 = np.cov(np.stack((cv_sigma2.T[0], loss_logsigma2.T[0]), axis=0))
    a_logsigma2 = cov_logsigma2[0][1] / cov_logsigma2[0][0]
    
    update_mu = np.mean(loss_mu, axis=0)
    update_logsigma2 = np.mean(loss_logsigma2, axis=0)
    update_h_mu = np.mean(cv_mu, axis=0) * [a_mu0, a_mu1]
    update_h_logsigma2 = np.mean(cv_sigma2, axis=0) * a_logsigma2
    
    var_mu = np.var(loss_mu - cv_mu * [a_mu0, a_mu1], axis=0)
    var_sigma = np.var(loss_logsigma2 - cv_sigma2 * a_logsigma2, axis=0)
    
    grad = np.concatenate([update_mu - update_h_mu, update_logsigma2 - update_h_logsigma2])

    m = 0.9 * m + 0.1 * grad
    v = 0.999 * v + 0.001 * np.power(grad, 2)
    
    m_hat = m / (1 - np.power(0.9, n_iter))
    v_hat = v / (1 - np.power(0.999, n_iter))
    
    update = m_hat / (np.sqrt(v_hat) + 1e-10)
    
    mu += lr * update[:2]
    sigma2 = np.exp(np.log(sigma2) + lr * update[2])
    print(mu, sigma2)
    return mu, sigma2, m, v, var_mu, var_sigma

In [105]:
def train_bbvi_cv(x, y, n_iter, dof):
    mu_list = []
    sigma2_list = []
    var_list = []
    mu = np.random.normal(size=2)
    sigma2 = np.power(np.random.normal(size=1), 2)
    lr = 0.1
    m = np.zeros(shape=3)
    v = np.zeros(shape=3)
    for i in range(n_iter):
        mu, sigma2, m, v, var_mu, var_sigma = bbvi_cv(x, y, mu, sigma2, lr, i + 1, m, v, dof)
        mu_list.append(mu.copy())
        sigma2_list.append(sigma2.copy())
        var_list.append([var_mu.copy(), var_sigma.copy()])
    return mu_list, sigma2_list, var_list

In [5]:
res_list = []
for dof in range(1, 100):
    mu_bbvi_cv, sigma2_bbvi_cv, var_bbvi_cv = train_bbvi_cv(x, y, 400, dof)
    elbo_list = [elbo(x, y, mu_bbvi_cv[i], sigma2_bbvi_cv[i], dof) for i in range(390, 400)]
    res_list.append(np.mean(np.array(elbo_list)))
    print(dof , ": ", res_list[dof - 1])

1 :  -4446.445408558689
2 :  -4446.259170712138
3 :  -4445.953974288847
4 :  -4446.1872659429255
5 :  -4446.159092701291
6 :  -4446.2525093520135
7 :  -4446.20928070885
8 :  -4446.082215539985
9 :  -4446.152149318978
10 :  -4446.229557747922
11 :  -4446.161786911572
12 :  -4446.235500797165
13 :  -4446.2013525643915
14 :  -4445.732661728964
15 :  -4446.2323358609765
16 :  -4446.146108915905
17 :  -4446.116916871005
18 :  -4446.23180681882
19 :  -4446.225302927462
20 :  -4446.172475618041
21 :  -4446.281815778396
22 :  -4446.4085013841095
23 :  -4446.210226006423
24 :  -4446.3843183067065
25 :  -4446.061518992862
26 :  -4446.247041152078
27 :  -4446.530916147807
28 :  -4446.2977863624565
29 :  -4446.218619161301
30 :  -4446.3242331217425
31 :  -4446.258963262054
32 :  -4446.367395590993
33 :  -4446.521292384042
34 :  -4446.143007993945
35 :  -4446.484053950106
36 :  -4446.402975377282
37 :  -4446.180957403305
38 :  -4445.958189992486
39 :  -4446.212293279376
40 :  -4446.490106564879
41 

In [6]:
import matplotlib.pyplot as plt
plt.plot(res_list)
plt.savefig("t_dist")

In [115]:
mu_bbvi_cv, sigma2_bbvi_cv, var_bbvi_cv = train_bbvi_cv(x, y, 1000, 1)

[ 1.03233136 -1.1760971 ] [0.19445286]
[ 1.12440038 -1.08802428] [0.21115546]
[ 1.20857615 -0.99589585] [0.22227032]
[ 1.29746925 -0.90067917] [0.22319599]
[ 1.39016904 -0.80519037] [0.2195912]
[ 1.48391564 -0.71091703] [0.21745392]
[ 1.5410731  -0.61944126] [0.21510753]
[ 1.61052425 -0.5269049 ] [0.20588562]
[ 1.67275599 -0.44233664] [0.19636594]
[ 1.73288264 -0.36234577] [0.18752326]
[ 1.77449637 -0.28546026] [0.18087461]
[ 1.81956918 -0.21004346] [0.17556751]
[ 1.85878094 -0.13675132] [0.17069225]
[ 1.89776381 -0.05985045] [0.1646847]
[1.91806876 0.0186223 ] [0.15848996]
[1.93021372 0.09170181] [0.15249051]
[1.93236311 0.16521376] [0.14680902]
[1.93482239 0.23547024] [0.14144557]
[1.9390892  0.30077352] [0.13681954]
[1.94462942 0.3634279 ] [0.13239533]
[1.94675632 0.42453033] [0.12798859]
[1.94435549 0.48594572] [0.12386013]
[1.94677289 0.54730915] [0.11978283]
[1.94402666 0.60641463] [0.11613293]
[1.94373667 0.65995571] [0.11274423]
[1.94812021        nan] [0.10924846]
[nan nan] [n

  return (self.a <= x) & (x <= self.b)
  return (self.a <= x) & (x <= self.b)


[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan

[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan] [nan]
[nan nan

In [111]:
elbo(x, y, mu_bbvi_cv[-1], sigma2_bbvi_cv[-1], 1)

array([-4466.28226808])

In [13]:
mu_bbvi_cv

[array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
