In [4]:
%load_ext autoreload
%autoreload 2
import sys 
if '/Users/ericliu/Desktop/Latent-Dirichilet-Allocation' not in sys.path: 
    sys.path.append('/Users/ericliu/Desktop/Latent-Dirichilet-Allocation')
import torch as tr 
import numpy as np 
import pandas as pd 
from collections import defaultdict
from pprint import pprint
from scipy.special import psi, polygamma, gammaln, loggamma
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 
from src.lda_model import LDASmoothed 
from src.generator import doc_generator 
from src.utils import (
    get_vocab_from_docs, 
    get_np_wct, 
    data_loader,
    text_pipeline, 
    process_documents,
) 
from src.text_pre_processor import (
    remove_accented_chars, 
    remove_special_characters, 
    remove_punctuation,
    remove_extra_whitespace_tabs,
    remove_stopwords,
)
from pprint import pprint 
import copy 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
gen = doc_generator(
    M = 3,
    L = 20, 
    topic_prior = tr.tensor([1,1,1,1,1], dtype=tr.double)
)

docs = gen.generate_doc()

Document: 0 | word: 0 -> topic: science -> word: scientst
Document: 0 | word: 1 -> topic: sport -> word: immunology
Document: 0 | word: 2 -> topic: sport -> word: athletics
Document: 0 | word: 3 -> topic: law -> word: court
Document: 0 | word: 4 -> topic: science -> word: form
Document: 0 | word: 5 -> topic: science -> word: infection
Document: 0 | word: 6 -> topic: science -> word: electricity
Document: 0 | word: 7 -> topic: art -> word: form
Document: 0 | word: 8 -> topic: sport -> word: Technique
Document: 0 | word: 9 -> topic: science -> word: astrophysics
Document 0: scientst immunology athletics court form infection electricity form Technique astrophysics

Document: 1 | word: 0 -> topic: health -> word: allergy
Document: 1 | word: 1 -> topic: health -> word: bruise
Document: 1 | word: 2 -> topic: sport -> word: recreation
Document: 1 | word: 3 -> topic: science -> word: genetics
Document: 1 | word: 4 -> topic: science -> word: attorney
Document: 1 | word: 5 -> topic: sport -> wor

In [6]:
docs

{0: 'scientst immunology athletics court form infection electricity form Technique astrophysics',
 1: 'allergy bruise recreation genetics attorney FIFA electricity asymmetrical quantum Technique',
 2: 'evidence fever contagious exercise scientst copyright picture decongestant content court'}

In [7]:
result = process_documents(docs, sample=True) 

There are 3 documents in the dataset after processing
On average estimated document length is 10.0 words per document after processing
There are 25 unique vocab in the corpus after processing


In [8]:
import warnings
def init_lda(docs, vocab, n_topic, gibbs=False, random_state=42):
    if gibbs:
        global V, k, N, M, alpha, eta, n_iw, n_di
    else:
        global V, k, N, M, alpha, beta, gamma, phi
        
    np.random.seed(random_state)

    V = len(vocab)
    k = n_topic  # number of topics
    N = np.array([doc.shape[0] for doc in docs])
    M = len(docs)

    print(f"V: {V}\nk: {k}\nN: {N[:10]}...\nM: {M}")

    # initialize α, β
    if gibbs:
        alpha = np.random.gamma(shape=100, scale=0.01, size=1)  # one for all k
        eta = np.random.gamma(shape=100, scale=0.01, size=1)  # one for all V
        print(f"α: {alpha}\nη: {eta}")
        
        n_iw = np.zeros((k, V), dtype=int)
        n_di = np.zeros((M, k), dtype=int)
        print(f"n_iw: dim {n_iw.shape}\nn_di: dim {n_di.shape}")
    else:
        alpha = np.random.gamma(shape=100, scale=0.01, size=k) #np.random.rand(k)
        beta = np.random.dirichlet(np.ones(V), k)
        print(f"α: dim {alpha.shape}\nβ: dim {beta.shape}")

        # initialize ϕ, γ
        ## ϕ: (M x max(N) x k) arrays with zero paddings on the right
        gamma = alpha + np.ones((M, k)) * V / k

        phi = np.ones((M, max(N), k)) / k
        for m, N_d in enumerate(N):
            phi[m, N_d:, :] = 0  # zero padding for vectorized operations

        print(f"γ: dim {gamma.shape}\nϕ: dim ({len(phi)}, N_d, {phi[0].shape[1]})")

def E_step(docs, phi, gamma, alpha, beta):
    """
    Minorize the joint likelihood function via variational inference.
    This is the E-step of variational EM algorithm for LDA.
    """
    # optimize phi
    for m in range(M):
        #print(N[m], docs[m])
        phi[m, :N[m], :] = (beta[:, docs[m]] * np.exp(
            psi(gamma[m, :]) - psi(gamma[m, :].sum())
        ).reshape(-1, 1)).T

        # Normalize phi
        phi[m, :N[m]] /= phi[m, :N[m]].sum(axis=1).reshape(-1, 1)
        if np.any(np.isnan(phi)):
            raise ValueError("phi nan")
        
        

    # optimize gamma
    gamma = alpha + phi.sum(axis=1)

    

    return phi, gamma


def M_step(docs, phi, gamma, alpha, beta, M):
    """
    maximize the lower bound of the likelihood.
    This is the M-step of variational EM algorithm for (smoothed) LDA.
    
    update of alpha follows from appendix A.2 of Blei et al., 2003.
    """
    # update alpha
    alpha = _update(alpha, gamma, M)
    
    # update beta
    for j in range(V):
        beta[:, j] = np.array([_phi_dot_w(docs, phi, m, j) for m in range(M)]).sum(axis=0)
    beta /= beta.sum(axis=1).reshape(-1, 1)

    return alpha, beta

def _update(var, vi_var, const, max_iter=10000, tol=1e-6):
    """
    From appendix A.2 of Blei et al., 2003.
    For hessian with shape `H = diag(h) + 1z1'`
    
    To update alpha, input var=alpha and vi_var=gamma, const=M.
    To update eta, input var=eta and vi_var=lambda, const=k.
    """
    for _ in range(max_iter):
        # store old value
        var0 = var.copy()
        
        # g: gradient 
        psi_sum = psi(vi_var.sum(axis=1)).reshape(-1, 1)
        g = const * (psi(var.sum()) - psi(var)) \
            + (psi(vi_var) - psi_sum).sum(axis=0)

        # H = diag(h) + 1z1'
        z = const * polygamma(1, var.sum())  # z: Hessian constant component
        h = -const * polygamma(1, var)       # h: Hessian diagonal component
        c = (g / h).sum() / (1./z + (1./h).sum())

        # update var
        var -= (g - c) / h
        print(f"{vi_var.sum()}|{var0} -> {var}")
        
        # check convergence
        err = np.sqrt(np.mean((var - var0) ** 2))
        crit = err < tol
        if crit:
            break
    else:
        warnings.warn(f"max_iter={max_iter} reached: values might not be optimal.")
    
    #print(err)
    return var

def _phi_dot_w(docs, phi, d, j):
    """
    \sum_{n=1}^{N_d} ϕ_{dni} w_{dn}^j
    """
    # doc = np.zeros(docs[m].shape[0] * V, dtype=int)
    # doc[np.arange(0, docs[m].shape[0] * V, V) + docs[m]] = 1
    # doc = doc.reshape(-1, V)
    # lam += phi[m, :N[m], :].T @ doc
    return (docs[d] == j) @ phi[d, :N[d], :]

def dg(gamma, d, i):
    """
    E[log θ_t] where θ_t ~ Dir(gamma)
    """
    return psi(gamma[d, i]) - psi(np.sum(gamma[d, :]))


def dl(lam, i, w_n):
    """
    E[log β_t] where β_t ~ Dir(lam)
    """
    return psi(lam[i, w_n]) - psi(np.sum(lam[i, :]))

def vlb(docs, phi, gamma, alpha, beta, M, N, k):
    """
    Average variational lower bound for joint log likelihood.
    """
    lb = 0
    for d in range(M):
        lb += (
            gammaln(np.sum(alpha))
            - np.sum(gammaln(alpha))
            + np.sum([(alpha[i] - 1) * dg(gamma, d, i) for i in range(k)])
        )

        lb -= (
            gammaln(np.sum(gamma[d, :]))
            - np.sum(gammaln(gamma[d, :]))
            + np.sum([(gamma[d, i] - 1) * dg(gamma, d, i) for i in range(k)])
        )

        for n in range(N[d]):
            w_n = int(docs[d][n])

            lb += np.sum([phi[d][n, i] * dg(gamma, d, i) for i in range(k)])
            lb += np.sum([phi[d][n, i] * np.log(beta[i, w_n]) for i in range(k)])
            lb -= np.sum([phi[d][n, i] * np.log(phi[d][n, i]) for i in range(k)])

    return lb / M

In [9]:
result['vocab_to_idx']

{'scientst': 0,
 'immunology': 1,
 'athletics': 2,
 'court': 3,
 'form': 4,
 'infection': 5,
 'electricity': 6,
 'Technique': 7,
 'astrophysics': 8,
 'allergy': 9,
 'bruise': 10,
 'recreation': 11,
 'genetics': 12,
 'attorney': 13,
 'FIFA': 14,
 'asymmetrical': 15,
 'quantum': 16,
 'evidence': 17,
 'fever': 18,
 'contagious': 19,
 'exercise': 20,
 'copyright': 21,
 'picture': 22,
 'decongestant': 23,
 'content': 24}

In [10]:
docs

{0: 'scientst immunology athletics court form infection electricity form Technique astrophysics',
 1: 'allergy bruise recreation genetics attorney FIFA electricity asymmetrical quantum Technique',
 2: 'evidence fever contagious exercise scientst copyright picture decongestant content court'}

In [11]:
docs_np = []
for doc in result['documents']: 

    doc_idx = []
    for n in range(len(doc)): 

        doc_idx.append(result['vocab_to_idx'][doc[n]])

    
    docs_np.append(np.array(doc_idx))
docs_np

[array([0, 1, 2, 3, 4, 5, 6, 4, 7, 8]),
 array([ 9, 10, 11, 12, 13, 14,  6, 15, 16,  7]),
 array([17, 18, 19, 20,  0, 21, 22, 23, 24,  3])]

In [12]:
docs = np.array(docs_np)
docs 

array([[ 0,  1,  2,  3,  4,  5,  6,  4,  7,  8],
       [ 9, 10, 11, 12, 13, 14,  6, 15, 16,  7],
       [17, 18, 19, 20,  0, 21, 22, 23, 24,  3]])

In [13]:
init_lda(docs, set(result['vocab_to_idx'].keys()), n_topic=5)

V: 25
k: 5
N: [10 10 10]...
M: 3
α: dim (5,)
β: dim (5, 25)
γ: dim (3, 5)
ϕ: dim (3, N_d, 5)


In [14]:
alpha

array([1.04708219, 0.98292693, 0.97347267, 0.97347428, 1.16278368])

In [15]:
gamma

array([[6.04708219, 5.98292693, 5.97347267, 5.97347428, 6.16278368],
       [6.04708219, 5.98292693, 5.97347267, 5.97347428, 6.16278368],
       [6.04708219, 5.98292693, 5.97347267, 5.97347428, 6.16278368]])

In [16]:
phi

array([[[0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2]],

       [[0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2]],

       [[0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.

In [17]:
print(len(result['vocab_to_idx']))

25


In [18]:
phi, gamma = E_step(docs, phi, gamma, alpha, beta)
print(phi)
print()
print()
print(gamma)

[[[0.59987533 0.05989332 0.05591898 0.25764781 0.02666457]
  [0.24724437 0.01369662 0.25195747 0.1616097  0.32549185]
  [0.04131649 0.19213417 0.0787964  0.43461685 0.2531361 ]
  [0.07367909 0.20507565 0.12492449 0.2699581  0.32636267]
  [0.06538056 0.04044009 0.26047636 0.04728988 0.58641311]
  [0.08941611 0.16223635 0.03862844 0.35425703 0.35546207]
  [0.17389084 0.00787611 0.39078812 0.38468867 0.04275626]
  [0.06538056 0.04044009 0.26047636 0.04728988 0.58641311]
  [0.09616628 0.39302956 0.01359186 0.16118991 0.33602239]
  [0.04597935 0.03851539 0.59727913 0.22651461 0.09171152]]

 [[0.16426484 0.18154784 0.26493801 0.13596289 0.25328643]
  [0.04178727 0.10004484 0.06357349 0.2367149  0.5578795 ]
  [0.16945999 0.34668658 0.00280193 0.31477295 0.16627854]
  [0.14744797 0.24619831 0.5634054  0.00957322 0.0333751 ]
  [0.25036064 0.08091476 0.52050471 0.05400647 0.09421342]
  [0.22688931 0.49616496 0.19872735 0.00542042 0.07279796]
  [0.17389084 0.00787611 0.39078812 0.38468867 0.04275

In [19]:
np.random.seed(42)

_alpha_ = np.random.gamma(shape=100, scale=0.01, size=5)
_gamma_ = _alpha_ + np.ones((M, k)) * 21 / k

print(_alpha_)
print(_gamma_)

[1.04708219 0.98292693 0.97347267 0.97347428 1.16278368]
[[5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]
 [5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]
 [5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]]


In [20]:
_update(_alpha_, _gamma_, M)

78.4192192526763|[1.04708219 0.98292693 0.97347267 0.97347428 1.16278368] -> [1.79562569 1.71883122 1.70721122 1.70721321 1.92565069]
78.4192192526763|[1.79562569 1.71883122 1.70721122 1.70721321 1.92565069] -> [2.89706016 2.81739023 2.80519713 2.80519922 3.0290319 ]
78.4192192526763|[2.89706016 2.81739023 2.80519713 2.80519922 3.0290319 ] -> [4.15126141 4.07750791 4.06632234 4.06632425 4.27649846]
78.4192192526763|[4.15126141 4.07750791 4.06632234 4.06632425 4.27649846] -> [5.00796502 4.94131996 4.93140917 4.93141086 5.12604242]
78.4192192526763|[5.00796502 4.94131996 4.93140917 4.93141086 5.12604242] -> [5.23567369 5.17138297 5.16190351 5.16190513 5.35149999]
78.4192192526763|[5.23567369 5.17138297 5.16190351 5.16190513 5.35149999] -> [5.24705619 5.18290058 5.17344626 5.17344787 5.36275799]
78.4192192526763|[5.24705619 5.18290058 5.17344626 5.17344787 5.36275799] -> [5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]
78.4192192526763|[5.24708219 5.18292693 5.17347267 5.17347428 

array([5.24708219, 5.18292693, 5.17347267, 5.17347428, 5.36278368])

In [21]:
V = 19 
print(k)
print(V)

5
19


In [22]:
_eta_ = 1
np.random.seed(42)
_lambda_ = np.random.gamma(shape=100, scale=0.01, size=(k, V))

print(_eta_)
print(_lambda_)

1
[[1.04708219 0.98292693 0.97347267 0.97347428 1.16278368 1.07526208
  0.95052839 1.05181933 1.02101821 0.81760009 0.89893338 1.0283693
  1.15026429 0.97429619 0.94330106 1.00778149 0.9378975  0.96782864
  0.9953198 ]
 [0.89475728 1.08105987 0.87968664 0.86986295 1.01644942 1.07222286
  1.01387299 0.98516565 0.96690772 0.95138329 1.10597141 1.02937237
  0.95871485 0.93060329 1.05898791 1.1031785  1.09256024 0.91521059
  0.96611462]
 [0.94959045 0.9782464  1.08000428 1.13828828 1.03320764 0.93363961
  0.99309432 1.16117251 1.00538214 0.96711283 0.97489658 1.03274514
  0.91811143 0.94740917 1.09087585 1.02984852 1.00638968 1.09653088
  0.92820856]
 [0.96431168 1.02652255 1.02295647 0.86193808 0.95525935 0.98065155
  1.03755099 1.02260049 0.98925295 0.81711981 0.99402204 1.00269174
  1.26335968 0.97758574 1.02707522 1.11516717 1.07363499 1.14337395
  0.86316381]
 [0.90101275 0.9411936  0.84973584 1.00352721 0.90766658 1.15954736
  0.92049911 0.96485874 1.0801087  0.87876646 0.84464572 1.

In [23]:
_lambda_[0]

array([1.04708219, 0.98292693, 0.97347267, 0.97347428, 1.16278368,
       1.07526208, 0.95052839, 1.05181933, 1.02101821, 0.81760009,
       0.89893338, 1.0283693 , 1.15026429, 0.97429619, 0.94330106,
       1.00778149, 0.9378975 , 0.96782864, 0.9953198 ])

In [24]:
def _update_eta(var, vi_var, const, max_iter=10000, tol=1e-9):
    """
    From appendix A.2 of Blei et al., 2003.
    For hessian with shape `H = diag(h) + 1z1'`
    
    To update alpha, input var=alpha and vi_var=gamma, const=M.
    To update eta, input var=eta and vi_var=lambda, const=k.
    """
    for _ in range(max_iter):
        # store old value
        var0 = var
        
        # g: gradient 
        psi_sum = psi(vi_var.sum(axis=1)).reshape(-1, 1)
        g = const * (V*psi(V*var) - V*psi(var)) + np.sum(psi(vi_var)) - np.sum(V*(psi_sum))

        h = const * (V**2 * polygamma(1, V*var) - V * polygamma(1, var))

        # # update var
        var -= g/h
        print(f"grad:{g}, hessian:{h}, eta:old{var0} -> {var}")

        if var == np.inf or var == -np.inf: 
            raise ValueError(f"Grad -> {g}, Hessian -> {h}, overflow")
        
        # check convergence
        err = np.sqrt(np.mean((var - var0) ** 2))
        crit = err < tol
        if crit:
            break
    else:
        warnings.warn(f"max_iter={max_iter} reached: values might not be optimal.")
    
    #print(err)
    return var

In [25]:
_update_eta(1, _lambda_, k)

grad:-0.8517960407552891, hessian:-58.72490095253988, eta:old1 -> 0.985495147255443
grad:0.013704776840540944, hessian:-60.62945883574141, eta:old0.985495147255443 -> 0.9857211888052723
grad:3.433071356084838e-06, hessian:-60.599086980475064, eta:old0.9857211888052723 -> 0.9857212454574683
grad:2.2737367544323206e-13, hessian:-60.59907937125736, eta:old0.9857212454574683 -> 0.9857212454574721


0.9857212454574721

In [26]:
_update(alpha, gamma, const=M)

45.4192192526763|[1.04708219 0.98292693 0.97347267 0.97347428 1.16278368] -> [1.60984251 1.50609863 1.54390368 1.58949613 1.77290523]
45.4192192526763|[1.60984251 1.50609863 1.54390368 1.58949613 1.77290523] -> [2.18336381 2.03750007 2.14067556 2.2532118  2.39239093]
45.4192192526763|[2.18336381 2.03750007 2.14067556 2.2532118  2.39239093] -> [2.51267568 2.34201721 2.48875269 2.64805    2.74756143]
45.4192192526763|[2.51267568 2.34201721 2.48875269 2.64805    2.74756143] -> [2.57809687 2.40243168 2.55861498 2.7284538  2.81807722]
45.4192192526763|[2.57809687 2.40243168 2.55861498 2.7284538  2.81807722] -> [2.58006301 2.40424528 2.56073231 2.7309235  2.82019604]
45.4192192526763|[2.58006301 2.40424528 2.56073231 2.7309235  2.82019604] -> [2.58006471 2.40424684 2.56073416 2.73092568 2.82019787]
45.4192192526763|[2.58006471 2.40424684 2.56073416 2.73092568 2.82019787] -> [2.58006471 2.40424684 2.56073416 2.73092568 2.82019787]


array([2.58006471, 2.40424684, 2.56073416, 2.73092568, 2.82019787])

In [27]:
%%time
N_EPOCH = 1000
TOL = 0.1

verbose = True
lb = -np.inf

for epoch in range(N_EPOCH): 
    # store old value
    lb_old = lb 
    
    # Variational EM
    phi, gamma = E_step(docs, phi, gamma, alpha, beta)
    alpha, beta = M_step(docs, phi, gamma, alpha, beta, M)
    
    # check anomaly
    if np.any(np.isnan(alpha)):
        print("NaN detected: alpha")
        break
    
    # check convergence
    lb = vlb(docs, phi, gamma, alpha, beta, M, N, k)
    err = abs(lb - lb_old)
    
    # check anomaly
    if np.isnan(lb):
        print("NaN detected: lb")
        break
        
    if verbose:
        print(f"{epoch: 04}:  variational_lb: {lb: .3f},  error: {err: .3f}")
    
    if err < TOL:
        break
else:
    warnings.warn(f"max_iter reached: values might not be optimal.")

print(" ========== TRAINING FINISHED ==========")

69.28850775900855|[2.58006471 2.40424684 2.56073416 2.73092568 2.82019787] -> [3.2802924  3.01731455 3.27298312 3.61080008 3.66333863]
69.28850775900855|[3.2802924  3.01731455 3.27298312 3.61080008 3.66333863] -> [3.56332256 3.26440639 3.56133202 3.97438339 4.00719189]
69.28850775900855|[3.56332256 3.26440639 3.56133202 3.97438339 4.00719189] -> [3.59329858 3.29053606 3.59189913 4.01349309 4.04381849]
69.28850775900855|[3.59329858 3.29053606 3.59189913 4.01349309 4.04381849] -> [3.59358325 3.29078395 3.5921896  4.01386984 4.04416798]
69.28850775900855|[3.59358325 3.29078395 3.5921896  4.01386984 4.04416798] -> [3.59358327 3.29078397 3.59218963 4.01386987 4.04416801]
 000:  variational_lb: -34.138,  error:  inf
85.60378428511153|[3.59358327 3.29078397 3.59218963 4.01386987 4.04416801] -> [4.30466117 3.79451801 4.19480609 4.7055656  4.67247706]
85.60378428511153|[4.30466117 3.79451801 4.19480609 4.7055656  4.67247706] -> [4.46325416 3.90457606 4.32658029 4.85703393 4.8100705 ]
85.6037842

  lb += np.sum([phi[d][n, i] * np.log(beta[i, w_n]) for i in range(k)])
  lb += np.sum([phi[d][n, i] * np.log(beta[i, w_n]) for i in range(k)])
  lb -= np.sum([phi[d][n, i] * np.log(phi[d][n, i]) for i in range(k)])
  lb -= np.sum([phi[d][n, i] * np.log(phi[d][n, i]) for i in range(k)])


In [28]:
alpha 

array([0.08531124, 0.08008666, 0.08052729, 0.07825668, 0.08192304])

In [29]:
beta

array([[2.50000158e-001, 3.73312601e-266, 7.32838132e-267,
        2.50000158e-001, 1.13237083e-266, 1.19013162e-266,
        5.75243540e-249, 6.34074407e-249, 2.13708622e-266,
        2.08630140e-272, 1.49136563e-272, 2.99707612e-272,
        1.01805620e-272, 2.27339414e-272, 1.93547137e-272,
        4.31694042e-273, 1.60010657e-272, 2.49999842e-001,
        2.49999842e-001, 1.77042842e-049, 3.53886271e-050,
        1.27319332e-050, 5.62869190e-049, 6.38003752e-049,
        3.12758634e-049],
       [4.14461028e-294, 1.02895608e-303, 1.69561163e-302,
        1.87534635e-293, 3.48489090e-303, 1.07439616e-302,
        5.30010883e-005, 8.09463677e-002, 8.90698206e-303,
        7.72087741e-002, 1.19557825e-001, 2.05310310e-001,
        5.69195132e-002, 2.46025178e-002, 1.41723529e-001,
        9.32491789e-002, 2.00428983e-001, 0.00000000e+000,
        0.00000000e+000, 7.05318388e-056, 2.56337538e-057,
        6.03484255e-057, 1.28042161e-057, 1.08873290e-056,
        1.36151144e-056],
    

In [30]:
phi 

array([[[1.47170412e-006, 1.45912181e-293, 7.01819230e-275,
         9.98209203e-001, 1.78932491e-003],
        [1.49325135e-265, 4.96540397e-303, 2.64532574e-285,
         3.84785019e-002, 9.61521498e-001],
        [2.93135438e-266, 8.18246460e-302, 9.71844389e-286,
         1.21561412e-001, 8.78438588e-001],
        [1.47149933e-006, 6.60196154e-293, 3.45628768e-274,
         9.95397358e-001, 4.60117074e-003],
        [2.26474309e-266, 8.40846923e-303, 1.56849819e-285,
         6.45777278e-003, 9.93542227e-001],
        [4.76052947e-266, 5.18468284e-302, 3.57513200e-286,
         7.43536295e-002, 9.25646370e-001],
        [1.16711753e-248, 1.44225496e-010, 5.58214553e-007,
         8.67112614e-001, 1.32886828e-001],
        [2.26474309e-266, 8.40846923e-303, 1.56849819e-285,
         6.45777278e-003, 9.93542227e-001],
        [1.26895427e-248, 2.14264323e-007, 3.30975060e-007,
         3.64662886e-001, 6.35336569e-001],
        [8.54835028e-266, 4.29821689e-302, 1.93037745e-284,
    

In [31]:
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.datasets import make_multilabel_classification

In [52]:
docs

array([[ 0,  1,  2,  3,  4,  5,  6,  4,  7,  8],
       [ 9, 10, 11, 12, 13, 14,  6, 15, 16,  7],
       [17, 18, 19, 20,  0, 21, 22, 23, 24,  3]])

In [56]:
lda = LatentDirichletAllocation(
    n_components=5,
    random_state=0,
    doc_topic_prior= 1,
    topic_word_prior= 1,
)
lda.fit(docs_np)

In [57]:
lda._approx_bound(
    docs, 
    doc_topic_distr=np.array(
        [
            [4.84708219, 4.78292693, 4.77347267, 4.77347428, 4.96278368],
            [4.84708219, 4.78292693, 4.77347267, 4.77347428, 4.96278368],
            [4.84708219, 4.78292693, 4.77347267, 4.77347428, 4.96278368],
        ]
    ), # this term is var inf parameter, gamma 
    sub_sampling=False
)

-848.1362078657901

In [55]:
lda._approx_bound(
    docs, 
    doc_topic_distr=np.array(
    [
        [4.84708219, 4.78292693, 4.77347267, 4.77347428, 4.96278368],
        [4.84708219, 4.78292693, 4.77347267, 4.77347428, 4.96278368],
        [4.84708219, 4.78292693, 4.77347267, 4.77347428, 4.96278368]
    ]
    ), 
    sub_sampling=True
)

-353568206.5163521