In [1]:
import numpy as np
data = np.load("mcs_hw2_p3_data.npy")
x = data[:, :2]
y = data[:, 2]

In [20]:
import scipy.stats

def get_gradient_mu_t(beta, mu, sigma2, v):
    return (beta - mu) / sigma2

def get_gradient_logsigma2_t(beta, mu, sigma2, v):
    norm = beta - mu
    return (- 1 / sigma2 + norm * norm / (2 * sigma2 * sigma2)) * sigma2

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def get_log_p(x, y, beta):
    eps = 0.0
    res = 0.0
    res += np.sum(y * np.log(eps + sigmoid(np.dot(x, beta))) + (1.0 - y) * np.log(eps + 1 - sigmoid(np.dot(x, beta))))
    res += np.sum(scipy.stats.norm.logpdf(beta, np.zeros(2), np.ones(2)))
    return res

def get_log_q(mu, sigma2, beta, v):
    res = np.sum(scipy.stats.norm.logpdf(beta, mu, np.sqrt(sigma2)))
    return res

import concurrent.futures



def elbo(x, y, mu, sigma2, dof):
    res = 0.0
    sample_size = 1024
    sample_beta = np.zeros(shape=(2, sample_size))
    sample_beta = np.random.normal(mu, np.sqrt(sigma2), size=[sample_size, mu.shape[0]])
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        future_list = [executor.submit(get_log_p, x, y, beta) for beta in sample_beta]
        for future in concurrent.futures.as_completed(future_list):
            res += future.result()

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        future_list = [executor.submit(get_log_q, mu, sigma2, beta, dof) for beta in sample_beta]
        for future in concurrent.futures.as_completed(future_list):
            res += future.result()
            
    return res / sample_size

In [16]:
def bbvi_cv(x, y, mu, sigma2, lr, n_iter, m, v, dof):
    sample_size = 64
    sample_beta = np.zeros(shape=(2, sample_size))
    sample_beta = np.random.normal(mu, np.sqrt(sigma2), size=[sample_size, mu.shape[0]])
    # update mu
    loss_mu = np.zeros(shape=[sample_size, mu.shape[0]])
    loss_logsigma2 = np.zeros(shape=[sample_size, sigma2.shape[0]])
    cv_mu = np.zeros(shape=[sample_size, mu.shape[0]])
    cv_sigma2 = np.zeros(shape=[sample_size, sigma2.shape[0]])
    for i in range(sample_size):
        loss_mu[i] = cv_mu[i] = get_gradient_mu_t(sample_beta[i], mu, sigma2, dof)
        loss_logsigma2[i] = cv_sigma2[i] = get_gradient_logsigma2_t(sample_beta[i], mu, sigma2, dof)
        log_p = get_log_p(x, y, sample_beta[i])
        log_q = get_log_q(mu, sigma2, sample_beta[i], dof)
        loss_mu[i] *= (log_p - log_q)
        loss_logsigma2[i] *= (log_p - log_q)
        
    cov_mu0 = np.cov(np.stack((cv_mu.T[0], loss_mu.T[0]), axis=0))
    a_mu0 = cov_mu0[0][1] / cov_mu0[0][0]
    cov_mu1 = np.cov(np.stack((cv_mu.T[1], loss_mu.T[1]), axis=0))
    a_mu1 = cov_mu1[0][1] / cov_mu1[0][0]
    cov_logsigma2 = np.cov(np.stack((cv_sigma2.T[0], loss_logsigma2.T[0]), axis=0))
    a_logsigma2 = cov_logsigma2[0][1] / cov_logsigma2[0][0]
    
    update_mu = np.mean(loss_mu, axis=0)
    update_logsigma2 = np.mean(loss_logsigma2, axis=0)
    update_h_mu = np.mean(cv_mu, axis=0) * [a_mu0, a_mu1]
    update_h_logsigma2 = np.mean(cv_sigma2, axis=0) * a_logsigma2
    
    var_mu = np.var(loss_mu - cv_mu * [a_mu0, a_mu1], axis=0)
    var_sigma = np.var(loss_logsigma2 - cv_sigma2 * a_logsigma2, axis=0)
    
    grad = np.concatenate([update_mu - update_h_mu, update_logsigma2 - update_h_logsigma2])

    m = 0.9 * m + 0.1 * grad
    v = 0.999 * v + 0.001 * np.power(grad, 2)
    
    m_hat = m / (1 - np.power(0.9, n_iter))
    v_hat = v / (1 - np.power(0.999, n_iter))
    
    update = m_hat / (np.sqrt(v_hat))
    
    mu += lr * update[:2]
    sigma2 = np.exp(np.log(sigma2) + lr * update[2:])
    print(mu, sigma2)
    return mu, sigma2, m, v, var_mu, var_sigma

In [22]:
def train_bbvi_cv(x, y, n_iter, dof):
    mu_list = []
    sigma2_list = []
    var_list = []
    mu = np.random.normal(size=2)
    sigma2 = np.power(np.random.normal(size=2), 2)
    lr = 0.01
    m = np.zeros(shape=4)
    v = np.zeros(shape=4)
    for i in range(n_iter):
        mu, sigma2, m, v, var_mu, var_sigma = bbvi_cv(x, y, mu, sigma2, lr, i + 1, m, v, dof)
        mu_list.append(mu.copy())
        sigma2_list.append(sigma2.copy())
        var_list.append([var_mu.copy(), var_sigma.copy()])
    return mu_list, sigma2_list, var_list

In [23]:
'''
res_list = []
for dof in range(1, 100):
    mu_bbvi_cv, sigma2_bbvi_cv, var_bbvi_cv = train_bbvi_cv(x, y, 400, dof)
    elbo_list = [elbo(x, y, mu_bbvi_cv[i], sigma2_bbvi_cv[i], dof) for i in range(390, 400)]
    res_list.append(np.mean(np.array(elbo_list)))
    print(dof , ": ", res_list[dof - 1])
'''

'\nres_list = []\nfor dof in range(1, 100):\n    mu_bbvi_cv, sigma2_bbvi_cv, var_bbvi_cv = train_bbvi_cv(x, y, 400, dof)\n    elbo_list = [elbo(x, y, mu_bbvi_cv[i], sigma2_bbvi_cv[i], dof) for i in range(390, 400)]\n    res_list.append(np.mean(np.array(elbo_list)))\n    print(dof , ": ", res_list[dof - 1])\n'

In [24]:
mu_bbvi_cv, sigma2_bbvi_cv, var_bbvi_cv = train_bbvi_cv(x, y, 1000, 1)

[-2.34157193  1.14070749] [0.12040692 1.9326584 ]
[-2.33172741  1.14888661] [0.11942849 1.91442149]
[-2.32210357  1.15781139] [0.11925732 1.897884  ]
[-2.31246109  1.1627762 ] [0.1189421  1.88054191]
[-2.30275873  1.16926978] [0.11892357 1.86284904]
[-2.29304888  1.17543199] [0.11858116 1.845018  ]
[-2.283222    1.18216791] [0.11836959 1.82841465]
[-2.2732408   1.18868748] [0.11783873 1.81147843]
[-2.26327409  1.19295322] [0.11726351 1.79477624]
[-2.2534669   1.19662976] [0.11674394 1.77932028]
[-2.24374992  1.20094882] [0.11608334 1.76378546]
[-2.23395782  1.20453137] [0.11563529 1.75006186]
[-2.22430777  1.20857194] [0.11512982 1.7362491 ]
[-2.21449753  1.2108031 ] [0.11458154 1.72228109]
[-2.20478703  1.21282007] [0.11403412 1.70848408]
[-2.19485138  1.21455949] [0.11353899 1.69517201]
[-2.18501192  1.21716176] [0.11299737 1.68181854]
[-2.17511881  1.21858369] [0.11233136 1.66786894]
[-2.16538072  1.22070297] [0.11193574 1.65647937]
[-2.15574801  1.22155929] [0.11142187 1.64447921]


[-0.80890229  0.9232583 ] [0.08201538 0.61043706]
[-0.80054955  0.92101719] [0.08179466 0.60711948]
[-0.79194076  0.91699493] [0.08160935 0.60400509]
[-0.78369492  0.91422832] [0.08146104 0.60108743]
[-0.77530161  0.91192174] [0.08139384 0.59863656]
[-0.76704382  0.91066341] [0.08137078 0.59647166]
[-0.75869182  0.91008777] [0.08133401 0.59415665]
[-0.7505019  0.910265 ] [0.08134787 0.59181859]
[-0.74231054  0.91006455] [0.08125816 0.58903854]
[-0.73384985  0.90885253] [0.08114307 0.58619999]
[-0.72539278  0.90817317] [0.08110645 0.5836857 ]
[-0.71710908  0.90728631] [0.08080134 0.57986622]
[-0.70889081  0.9066275 ] [0.08044397 0.57587728]
[-0.70071493  0.9065429 ] [0.08017846 0.57235238]
[-0.69257465  0.90540019] [0.07993572 0.56870589]
[-0.68442503  0.90293557] [0.07967789 0.56508202]
[-0.67630685  0.89882919] [0.07945035 0.56179577]
[-0.66792873  0.8933946 ] [0.07925021 0.55877476]
[-0.65963559  0.88913984] [0.07899206 0.5555173 ]
[-0.65145625  0.88644482] [0.0787716  0.55245117]
[-

[0.34028876 0.72635828] [0.05711309 0.27223363]
[0.3453791  0.72598302] [0.05704659 0.27131195]
[0.35042098 0.72531562] [0.056984   0.27044544]
[0.35525702 0.72451294] [0.05693603 0.26963878]
[0.36013289 0.72356038] [0.05684778 0.26870737]
[0.36499299 0.72315978] [0.05675589 0.26773554]
[0.36983916 0.72203684] [0.05666058 0.26675269]
[0.37465525 0.72161418] [0.05657496 0.26580945]
[0.3794412 0.7217301] [0.05647757 0.26485097]
[0.3842536  0.72011894] [0.05637118 0.26385787]
[0.38898176 0.71861309] [0.05631471 0.26292067]
[0.39368707 0.71729914] [0.05628686 0.2619823 ]
[0.39823399 0.71702607] [0.05627699 0.26110486]
[0.40275497 0.71675049] [0.05624551 0.26021123]
[0.4072575  0.71607426] [0.05618887 0.25927613]
[0.41185149 0.71611792] [0.05613333 0.25836385]
[0.41634021 0.71601362] [0.0560388  0.25736871]
[0.42096239 0.71577812] [0.05593873 0.25641405]
[0.42556367 0.7153537 ] [0.05586369 0.25545541]
[0.43005306 0.71526668] [0.05580163 0.25455673]
[0.43452611 0.71474478] [0.05570429 0.2535

[0.99745177 0.77426047] [0.04595994 0.16110513]
[1.00053272 0.77700647] [0.04587862 0.16050662]
[1.00358697 0.77902788] [0.04579525 0.15993489]
[1.00662457 0.78061228] [0.04570575 0.15936503]
[1.00970069 0.78239039] [0.04562238 0.15881829]
[1.01283262 0.78349766] [0.04553913 0.15828217]
[1.01582836 0.78373265] [0.04547119 0.15779383]
[1.01875909 0.78407846] [0.04541582 0.15734198]
[1.02166403 0.78337602] [0.04536237 0.15690244]
[1.02452036 0.78255973] [0.0453109  0.15647363]
[1.02729006 0.78170872] [0.0452747 0.1561012]
[1.03007641 0.78077868] [0.04523288 0.15572667]
[1.03287383 0.77907701] [0.04518624 0.15536168]
[1.03560207 0.77688365] [0.04515258 0.15500406]
[1.03823527 0.77423606] [0.04513013 0.15468376]
[1.04082791 0.77205543] [0.04510272 0.15435618]
[1.04333947 0.76976274] [0.04506674 0.15402402]
[1.04580309 0.76788921] [0.04503349 0.15370146]
[1.04830952 0.76582859] [0.0450064  0.15334682]
[1.0508237  0.76415715] [0.04497604 0.15298652]
[1.05336792 0.76289242] [0.04493459 0.1526

[1.3859781  0.81888436] [0.03831624 0.10993902]
[1.38773415 0.8189267 ] [0.03828911 0.10973585]
[1.38941681 0.81924909] [0.03826556 0.10954771]
[1.39111266 0.81941705] [0.03824534 0.10937158]
[1.39271189 0.81969631] [0.03822675 0.10920889]
[1.39436665 0.820041  ] [0.03820323 0.10903886]
[1.39600296 0.82075765] [0.03818106 0.10887452]
[1.39766482 0.82153125] [0.03815399 0.10870399]
[1.39924967 0.82224601] [0.03812897 0.10854504]
[1.40083803 0.82260073] [0.03810726 0.10838992]
[1.40233406 0.82343581] [0.03808953 0.10823569]
[1.40377896 0.82468706] [0.03807107 0.10808441]
[1.40519737 0.82559496] [0.0380537  0.10793981]
[1.40660933 0.82647219] [0.03803707 0.10779987]
[1.40814067 0.82880775] [0.03801929 0.10764044]
[1.40974568 0.83133137] [0.03800121 0.10748017]
[1.411339   0.83307389] [0.0379768  0.10731222]
[1.41289728 0.83462089] [0.03795768 0.10713537]
[1.41437451 0.83627572] [0.03794335 0.10696702]
[1.41581689 0.83744698] [0.0379293  0.10680655]
[1.41721869 0.83772844] [0.03791061 0.10

[1.61053786 0.86512529] [0.03481475 0.08463067]
[1.61160801 0.86686578] [0.03478847 0.0844971 ]
[1.61269919 0.86802617] [0.03476312 0.08436333]
[1.61376588 0.86932596] [0.03474131 0.08423589]
[1.61486808 0.87016486] [0.03471617 0.08410453]
[1.61599746 0.87072522] [0.03468882 0.08397394]
[1.61709153 0.8710915 ] [0.0346562  0.08383947]
[1.61821005 0.87132869] [0.03462647 0.08371158]
[1.61929333 0.87205484] [0.03459988 0.0835814 ]
[1.6203779 0.8726926] [0.03457325 0.0834515 ]
[1.62145878 0.8730831 ] [0.03454646 0.08332817]
[1.62250378 0.87302538] [0.03451858 0.0831864 ]
[1.62348362 0.87332243] [0.03449059 0.08304308]
[1.62442327 0.87381942] [0.03446049 0.08289966]
[1.62529921 0.87450362] [0.03443513 0.0827643 ]
[1.6261705  0.87537264] [0.0344072  0.08262696]
[1.62706273 0.87531076] [0.03437711 0.08249005]
[1.62796078 0.87503394] [0.03434493 0.08235592]
[1.62887392 0.87569332] [0.03431291 0.08221869]
[1.62976454 0.8768932 ] [0.03428219 0.08208444]
[1.63067725 0.87735666] [0.03425231 0.0819

In [27]:
elbo(x, y, mu_bbvi_cv[-1], sigma2_bbvi_cv[-1], 1)

AttributeError: 'float' object has no attribute 'shape'

In [13]:
mu_bbvi_cv

[array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
 array([nan, nan]),
