In [1]:
import numpy as np
from scipy.special import digamma as dga
from scipy.special import gamma as ga
from scipy.special import loggamma as lga

In [2]:
eps=1e-10
def log(x):
    return np.log(x + eps)

def digamma(x):
    return dga(x + eps)

def loggamma(x):
    return lga(x + eps)

In [3]:
def init(data):
    vocab = np.array([i for i in range(100)])

    num_doc = data.shape[0]
    num_vocab = vocab.shape[0]
    len_doc = data.shape[1]
    num_topic = 10

    w = np.zeros([num_doc, len_doc, num_vocab])
    for d in range(num_doc):
        for n in range(len_doc):
            w[d, n, data[d, n]] = 1

    alpha = np.ones(shape=num_topic)
    eta = np.ones(shape=num_vocab)

    phi = np.random.rand(num_doc, len_doc, num_topic)
    for d in range(num_doc):
        for n in range(len_doc):
            phi[d, n] /= np.sum(phi[d, n])

    gam = np.random.rand(num_doc, num_topic)
    gam /= np.sum(gam, axis=1)[:, np.newaxis]

    lam = np.random.rand(num_topic, num_vocab)
    lam /= np.sum(lam, axis=1)[:, np.newaxis]
    return lam, gam, phi, w, num_doc, num_topic, num_vocab, len_doc, alpha, eta

In [4]:
def one_step(lam, gam, phi, w, num_doc, num_topic, num_vocab, len_doc, alpha, eta):
    #print(num_doc, num_topic, num_vocab)
    for k in range(num_topic):
        lam[k] = eta
        for d in range(num_doc):
            for n in range(len_doc):
                lam[k] += phi[d, n, k] * w[d, n]
    #lam /= np.sum(lam, axis=1)[:, np.newaxis]
    
    gam = alpha + np.sum(phi, axis=1)
    #gam /= np.sum(gam, axis=1)[:, np.newaxis]
    
    def get_single_doc(lam, gam, phi, w, d):
        for n in range(len_doc):
            #phi[d, n, :] = np.exp(digamma(gam[d, :]) + digamma(lam[:, data[d, n]]) - digamma(np.sum(lam, axis=1)))
            for k in range(num_topic):
                phi[d, n, k] = np.exp(digamma(lam[k, data[d, n]]) - digamma(np.sum(lam[k])) + digamma(gam[d, k]) - digamma(np.sum(gam[d])))
            phi[d, n, :] /= np.sum(phi[d, n, :])
        return phi[d], d
            
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        future_list = [executor.submit(get_single_doc, lam, gam, phi, w, d) for d in range(num_doc)]
        for future in concurrent.futures.as_completed(future_list):
            phi_d, d = future.result()
            phi[d] = phi_d
    
    return lam, gam, phi, w, num_doc, num_topic, num_vocab, len_doc, alpha, eta

In [5]:
import concurrent.futures

def get_res1(lam, gam, phi, w):
    res_1 = 0.0
    res_1 += num_topic * loggamma(np.sum(eta))
    res_1 -= num_topic * np.sum(loggamma(eta))
    '''
    for k in range(num_topic):
        for i in range(num_vocab):
            res_1 += (eta[i] - 1) * (digamma(lam[k, i]) - digamma(np.sum(lam[k])))
    '''
    return res_1


def get_res2(lam, gam, phi, w):          
    res_2 = 0.0
    for n in range(len_doc):
        for k in range(num_topic):
            res_2 += phi[:, n, k] * (digamma(gam[:, k]) - digamma(np.sum(gam, axis=1)))
    #res_2 -= digamma(np.sum(gam, axis=1))
    res_2 = np.sum(res_2)
    return res_2

    
def get_res3(lam, gam, phi, w):
    res_3 = 0.0
    res_3 += loggamma(np.sum(alpha))
    res_3 -= np.sum(loggamma(alpha))
    '''
    for k in range(num_topic):
        res_3 += (alpha[k] - 1) * (digamma(gam[:, k] - digamma(np.sum(gam[:, k]))))
    '''
    res_3 = np.sum(res_3)
    return res_3

def get_res4(lam, gam, phi, w):
    res_4 = 0.0
    def get_res4_single_loc(lam, gam, phi, w, n):
        res_loc = 0.0
        for k in range(num_topic):
            sum_lam_k = np.sum(lam[k])
            for i in range(num_vocab):
                res_loc += phi[:, n, k] * w[:, n, i] * (digamma(lam[k, i]) - digamma(sum_lam_k))
        res_loc = np.sum(res_loc)
        return res_loc
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        future_list = [executor.submit(get_res4_single_loc, lam, gam, phi, w, n) for n in range(len_doc)]
        for future in concurrent.futures.as_completed(future_list):
            res_4 += future.result()
    return res_4

def get_res5(lam, gam, phi, w):
    res_5 = 0.0
    for k in range(num_topic):
        res_5 += loggamma(np.sum(lam[k])) - np.sum(loggamma(lam[k]))
    for k in range(num_topic):
        sum_lam_k = np.sum(lam[k])
        #'''
        res_5 += np.sum((lam[k] - 1) * (digamma(lam[k]) - digamma(sum_lam_k)))
        '''
        for i in range(num_vocab):
            res_5 += (lam[k, i] - 1) * (digamma(lam[k, i]) - digamma(sum_lam_k))
        '''
    return -res_5
    
def get_res6(lam, gam, phi, w):
    res_6 = 0.0
    res_6 += np.sum(phi * log(phi))
    return -res_6

def get_res7(lam, gam, phi, w):
    res_7 = 0.0
    res_7 += loggamma(np.sum(gam, axis=1)) - np.sum(loggamma(gam), axis=1)
    #print(res_7)
    res_7 = np.sum(res_7)
    for d in range(num_doc):
        res_7 += np.sum((gam[d] - 1) * (digamma(gam[d]) - digamma(np.sum(gam[d]))))
    return -res_7

def elbo(lam, gam, phi, w):
    res = 0.0
    func_list = [get_res1, get_res2, get_res3, get_res4, get_res5, get_res6, get_res7]
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        future_list = [executor.submit(func, lam, gam, phi, w) for func in func_list]
        for future in concurrent.futures.as_completed(future_list):
            res += future.result()
    
    return res

In [None]:
elbo_list = []
#for dsize in range(10, 100, 10):
    #data = np.load("mcs_hw4_p1_lda.npy")[:dsize]
if True:
    data = np.load("mcs_hw4_p1_lda.npy")
    lam, gam, phi, w, num_doc, num_topic, num_vocab, len_doc, alpha, eta = init(data)
    for i in range(100):
        lam, gam, phi, w, num_doc, num_topic, num_vocab, len_doc, alpha, eta = one_step(lam, gam, phi, w, num_doc, num_topic, num_vocab, len_doc, alpha, eta)
        #print("iteration " + str(i) + " done")
        elbo_per_point = elbo(lam, gam, phi, w)
        elbo_list.append(elbo_per_point)
        print(elbo_per_point)

-9420546.988047626
-9420544.301276501
-9420541.213457767
-9420536.958097842
-9420529.763685262
-9420515.427796645
-9420483.928336091
-9420411.301455073
-9420240.115280135
-9419832.384197503
-9418855.715408616
-9416508.0438344
-9410859.329313288
-9397347.196271101
-9365847.900881317
-9297962.379420219
-9177502.401186656
-9030714.512037896
-8925304.483196799
-8879642.069215579
-8864554.334923856
-8859170.001810363
-8856460.607429592
-8854597.808926111
-8853093.300412925
-8851780.430030977
-8850582.36969959
-8849455.926321492
-8848374.356185053
-8847320.3119146
-8846282.23654943
-8845252.272003207
-8844224.94527465
-8843196.286206545
-8842163.197681094
-8841122.986028587
-8840073.00975797
-8839010.43488207
-8837932.101945505
-8836834.516131341
-8835713.968557987
-8834566.784790095
-8833389.67726437
-8832180.155435916
-8830936.92763444
-8829660.220890114
-8828351.957382761
-8827015.759896986
-8825656.803540893
-8824281.567533396
-8822897.55108735


In [None]:
import pickle
with open("lda_elbo", "wb") as f:
    pickle.dump(elbo_list, f)

In [2]:
import numpy as np
data = np.load("mcs_hw4_p1_lda.npy")

In [5]:
data[0]

array([20, 91,  0, 71, 60, 43, 21, 47, 20, 21, 81, 74, 11, 50,  8, 81, 38,
       81, 50, 21, 41,  0, 50, 44,  1, 78, 90, 81, 68, 12, 81, 51, 71, 23,
       54, 49, 97, 42, 79, 92, 12, 70, 40, 81, 31, 51, 80, 53, 91, 50, 41,
       81, 51, 81, 91, 69, 62, 94, 41, 70, 57,  1, 66, 14, 12, 24, 32, 28,
       37, 72, 62,  2, 67, 40, 12, 50, 22, 25, 90, 41, 36, 62, 81, 12,  0,
       11, 55, 62, 96, 50, 61,  3, 37,  2,  0, 31, 81, 81,  9, 39, 60, 41,
       13, 96, 42, 71, 70,  1, 31, 72, 81,  9, 62, 79, 80, 92, 41, 38, 54,
       79, 90, 15, 28, 89, 77, 20, 21, 48, 71, 82, 21, 16, 53, 93, 35, 68,
       77, 51, 91, 14, 61, 84, 12, 31, 10, 90, 55, 34, 53, 56, 11,  8, 55,
       93, 42, 31, 52, 67, 41, 27, 74, 54, 52, 75, 97, 51, 96, 41, 81, 71,
       31, 42, 91, 71, 15, 61, 27,  2, 18, 68, 20, 80, 77, 71, 18, 71, 16,
       50, 36, 81,  0, 81, 91,  1, 55,  0, 71, 70, 80, 46], dtype=int64)