In [1]:
%load_ext autoreload
%autoreload 2
import sys 
if '/Users/ericliu/Desktop/Latent-Dirichilet-Allocation' not in sys.path: 
    sys.path.append('/Users/ericliu/Desktop/Latent-Dirichilet-Allocation')
import torch as tr 
import numpy as np 
import pandas as pd 
from collections import defaultdict
from pprint import pprint
from scipy.special import psi, polygamma, gammaln, loggamma
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 
from src.lda_model import LDASmoothed 
from src.generator import doc_generator 
from src.utils import (
    get_vocab_from_docs, 
    get_np_wct, 
    data_loader,
    text_pipeline, 
    process_documents,
) 
from src.text_pre_processor import (
    remove_accented_chars, 
    remove_special_characters, 
    remove_punctuation,
    remove_extra_whitespace_tabs,
    remove_stopwords,
)
from pprint import pprint 
import copy 

  from .autonotebook import tqdm as notebook_tqdm


In [36]:
5*gammaln(22*2) - 5 * 22 * gammaln(2)

607.6654075771932

In [38]:
5 * 22 * gammaln(2) 

0.0

In [40]:
22 * gammaln(1) 

0.0

In [37]:
5*gammaln(22*2) -  22 * gammaln(2)

607.6654075771932

In [2]:
gen = doc_generator(
    M = 3,
    L = 20, 
    topic_prior = tr.tensor([1,1,1,1,1], dtype=tr.double)
)

docs = gen.generate_doc()

Document: 0 | word: 0 -> topic: health -> word: injection
Document: 0 | word: 1 -> topic: science -> word: astrophysics
Document: 0 | word: 2 -> topic: health -> word: decongestant
Document: 0 | word: 3 -> topic: art -> word: Craftsmanship
Document: 0 | word: 4 -> topic: health -> word: exercise
Document: 0 | word: 5 -> topic: law -> word: attorney
Document: 0 | word: 6 -> topic: health -> word: decongestant
Document: 0 | word: 7 -> topic: health -> word: fever
Document: 0 | word: 8 -> topic: health -> word: Symmetrical
Document: 0 | word: 9 -> topic: sport -> word: evidence
Document 0: injection astrophysics decongestant Craftsmanship exercise attorney decongestant fever Symmetrical evidence

Document: 1 | word: 0 -> topic: science -> word: immunology
Document: 1 | word: 1 -> topic: sport -> word: electricity
Document: 1 | word: 2 -> topic: art -> word: energy
Document: 1 | word: 3 -> topic: sport -> word: physical
Document: 1 | word: 4 -> topic: sport -> word: athletics
Document: 1 |

In [3]:
result = process_documents(docs, sample=True) 

There are 3 documents in the dataset after processing
On average estimated document length is 10.0 words per document after processing
There are 22 unique vocab in the corpus after processing


In [4]:
import warnings
def init_lda(docs, vocab, n_topic, gibbs=False, random_state=0):
    if gibbs:
        global V, k, N, M, alpha, eta, n_iw, n_di
    else:
        global V, k, N, M, alpha, beta, gamma, phi
        
    np.random.seed(random_state)

    V = len(vocab)
    k = n_topic  # number of topics
    N = np.array([doc.shape[0] for doc in docs])
    M = len(docs)

    print(f"V: {V}\nk: {k}\nN: {N[:10]}...\nM: {M}")

    # initialize α, β
    np.random.random(42)
    if gibbs:
        alpha = np.random.gamma(shape=100, scale=0.01, size=1)  # one for all k
        eta = np.random.gamma(shape=100, scale=0.01, size=1)  # one for all V
        print(f"α: {alpha}\nη: {eta}")
        
        n_iw = np.zeros((k, V), dtype=int)
        n_di = np.zeros((M, k), dtype=int)
        print(f"n_iw: dim {n_iw.shape}\nn_di: dim {n_di.shape}")
    else:
        alpha = np.random.gamma(shape=100, scale=0.01, size=k) #np.random.rand(k)
        beta = np.random.dirichlet(np.ones(V), k)
        print(f"α: dim {alpha.shape}\nβ: dim {beta.shape}")

        # initialize ϕ, γ
        ## ϕ: (M x max(N) x k) arrays with zero paddings on the right
        gamma = alpha + np.ones((M, k)) * N.reshape(-1, 1) / k

        phi = np.ones((M, max(N), k)) / k
        for m, N_d in enumerate(N):
            phi[m, N_d:, :] = 0  # zero padding for vectorized operations

        print(f"γ: dim {gamma.shape}\nϕ: dim ({len(phi)}, N_d, {phi[0].shape[1]})")

def E_step(docs, phi, gamma, alpha, beta):
    """
    Minorize the joint likelihood function via variational inference.
    This is the E-step of variational EM algorithm for LDA.
    """
    # optimize phi
    for m in range(M):
        #print(N[m], docs[m])
        phi[m, :N[m], :] = (beta[:, docs[m]] * np.exp(
            psi(gamma[m, :]) - psi(gamma[m, :].sum())
        ).reshape(-1, 1)).T

        # Normalize phi
        phi[m, :N[m]] /= phi[m, :N[m]].sum(axis=1).reshape(-1, 1)
        if np.any(np.isnan(phi)):
            raise ValueError("phi nan")

    # optimize gamma
    gamma = alpha + phi.sum(axis=1)

    return phi, gamma


def M_step(docs, phi, gamma, alpha, beta, M):
    """
    maximize the lower bound of the likelihood.
    This is the M-step of variational EM algorithm for (smoothed) LDA.
    
    update of alpha follows from appendix A.2 of Blei et al., 2003.
    """
    # update alpha
    alpha = _update(alpha, gamma, M)
    
    # update beta
    for j in range(V):
        beta[:, j] = np.array([_phi_dot_w(docs, phi, m, j) for m in range(M)]).sum(axis=0)
    beta /= beta.sum(axis=1).reshape(-1, 1)

    return alpha, beta

def _update(var, vi_var, const, max_iter=10000, tol=1e-6):
    """
    From appendix A.2 of Blei et al., 2003.
    For hessian with shape `H = diag(h) + 1z1'`
    
    To update alpha, input var=alpha and vi_var=gamma, const=M.
    To update eta, input var=eta and vi_var=lambda, const=k.
    """
    for _ in range(max_iter):
        # store old value
        var0 = var.copy()
        
        # g: gradient 
        psi_sum = psi(vi_var.sum(axis=1)).reshape(-1, 1)
        g = const * (psi(var.sum()) - psi(var)) \
            + (psi(vi_var) - psi_sum).sum(axis=0)

        # H = diag(h) + 1z1'
        z = const * polygamma(1, var.sum())  # z: Hessian constant component
        h = -const * polygamma(1, var)       # h: Hessian diagonal component
        c = (g / h).sum() / (1./z + (1./h).sum())

        # update var
        var -= (g - c) / h
        print(f"{vi_var.sum()}|{var0} -> {var}")
        
        # check convergence
        err = np.sqrt(np.mean((var - var0) ** 2))
        crit = err < tol
        if crit:
            break
    else:
        warnings.warn(f"max_iter={max_iter} reached: values might not be optimal.")
    
    #print(err)
    return var

def _phi_dot_w(docs, phi, d, j):
    """
    \sum_{n=1}^{N_d} ϕ_{dni} w_{dn}^j
    """
    # doc = np.zeros(docs[m].shape[0] * V, dtype=int)
    # doc[np.arange(0, docs[m].shape[0] * V, V) + docs[m]] = 1
    # doc = doc.reshape(-1, V)
    # lam += phi[m, :N[m], :].T @ doc
    return (docs[d] == j) @ phi[d, :N[d], :]

def dg(gamma, d, i):
    """
    E[log θ_t] where θ_t ~ Dir(gamma)
    """
    return psi(gamma[d, i]) - psi(np.sum(gamma[d, :]))


def dl(lam, i, w_n):
    """
    E[log β_t] where β_t ~ Dir(lam)
    """
    return psi(lam[i, w_n]) - psi(np.sum(lam[i, :]))

def vlb(docs, phi, gamma, alpha, beta, M, N, k):
    """
    Average variational lower bound for joint log likelihood.
    """
    lb = 0
    for d in range(M):
        lb += (
            gammaln(np.sum(alpha))
            - np.sum(gammaln(alpha))
            + np.sum([(alpha[i] - 1) * dg(gamma, d, i) for i in range(k)])
        )

        lb -= (
            gammaln(np.sum(gamma[d, :]))
            - np.sum(gammaln(gamma[d, :]))
            + np.sum([(gamma[d, i] - 1) * dg(gamma, d, i) for i in range(k)])
        )

        for n in range(N[d]):
            w_n = int(docs[d][n])

            lb += np.sum([phi[d][n, i] * dg(gamma, d, i) for i in range(k)])
            lb += np.sum([phi[d][n, i] * np.log(beta[i, w_n]) for i in range(k)])
            lb -= np.sum([phi[d][n, i] * np.log(phi[d][n, i]) for i in range(k)])

    return lb / M

In [5]:
result['vocab_to_idx']

{'injection': 0,
 'astrophysics': 1,
 'decongestant': 2,
 'Craftsmanship': 3,
 'exercise': 4,
 'attorney': 5,
 'fever': 6,
 'Symmetrical': 7,
 'evidence': 8,
 'immunology': 9,
 'electricity': 10,
 'energy': 11,
 'physical': 12,
 'athletics': 13,
 'allergy': 14,
 'bruise': 15,
 'infection': 16,
 'genetics': 17,
 'form': 18,
 'appetite': 19,
 'asymmetrical': 20,
 'Technique': 21}

In [6]:
docs

{0: 'injection astrophysics decongestant Craftsmanship exercise attorney decongestant fever Symmetrical evidence',
 1: 'immunology electricity energy physical athletics allergy bruise infection genetics form',
 2: 'form physical Symmetrical appetite athletics asymmetrical Craftsmanship Symmetrical Technique energy'}

In [7]:
docs_np = []
for doc in result['documents']: 

    doc_idx = []
    for n in range(len(doc)): 

        doc_idx.append(result['vocab_to_idx'][doc[n]])

    
    docs_np.append(np.array(doc_idx))
docs_np

[array([0, 1, 2, 3, 4, 5, 2, 6, 7, 8]),
 array([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]),
 array([18, 12,  7, 19, 13, 20,  3,  7, 21, 11])]

In [8]:
docs = docs_np

In [9]:
init_lda(docs, set(result['vocab_to_idx'].keys()), n_topic=5)

V: 22
k: 5
N: [10 10 10]...
M: 3
α: dim (5,)
β: dim (5, 22)
γ: dim (3, 5)
ϕ: dim (3, N_d, 5)


In [10]:
alpha

array([0.9623354 , 1.01235711, 0.95849651, 0.96679042, 0.8358445 ])

In [11]:
gamma

array([[2.9623354 , 3.01235711, 2.95849651, 2.96679042, 2.8358445 ],
       [2.9623354 , 3.01235711, 2.95849651, 2.96679042, 2.8358445 ],
       [2.9623354 , 3.01235711, 2.95849651, 2.96679042, 2.8358445 ]])

In [12]:
phi

array([[[0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2]],

       [[0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2]],

       [[0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.2, 0.2, 0.2, 0.2],
        [0.2, 0.

In [31]:
print(len(result['vocab_to_idx']))

22


In [34]:
np.random.seed(42)

_alpha_ = np.random.gamma(shape=100, scale=0.01, size=5)
_gamma_ = _alpha_ + np.ones((M, k)) * 21 / k

print(_alpha_)
print(_gamma_)

[1.04708219 0.98292693 0.97347267 0.97347428 1.16278368]
[[5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]
 [5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]
 [5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]]


In [35]:
_update(_alpha_, _gamma_, M)

78.4192192526763|[1.04708219 0.98292693 0.97347267 0.97347428 1.16278368] -> [1.79562569 1.71883122 1.70721122 1.70721321 1.92565069]
78.4192192526763|[1.79562569 1.71883122 1.70721122 1.70721321 1.92565069] -> [2.89706016 2.81739023 2.80519713 2.80519922 3.0290319 ]
78.4192192526763|[2.89706016 2.81739023 2.80519713 2.80519922 3.0290319 ] -> [4.15126141 4.07750791 4.06632234 4.06632425 4.27649846]
78.4192192526763|[4.15126141 4.07750791 4.06632234 4.06632425 4.27649846] -> [5.00796502 4.94131996 4.93140917 4.93141086 5.12604242]
78.4192192526763|[5.00796502 4.94131996 4.93140917 4.93141086 5.12604242] -> [5.23567369 5.17138297 5.16190351 5.16190513 5.35149999]
78.4192192526763|[5.23567369 5.17138297 5.16190351 5.16190513 5.35149999] -> [5.24705619 5.18290058 5.17344626 5.17344787 5.36275799]
78.4192192526763|[5.24705619 5.18290058 5.17344626 5.17344787 5.36275799] -> [5.24708219 5.18292693 5.17347267 5.17347428 5.36278368]
78.4192192526763|[5.24708219 5.18292693 5.17347267 5.17347428 

array([5.24708219, 5.18292693, 5.17347267, 5.17347428, 5.36278368])

In [25]:
V = 19 
print(k)
print(V)

5
19


In [26]:
_eta_ = 1
np.random.seed(42)
_lambda_ = np.random.gamma(shape=100, scale=0.01, size=(k, V))

print(_eta_)
print(_lambda_)

1
[[1.04708219 0.98292693 0.97347267 0.97347428 1.16278368 1.07526208
  0.95052839 1.05181933 1.02101821 0.81760009 0.89893338 1.0283693
  1.15026429 0.97429619 0.94330106 1.00778149 0.9378975  0.96782864
  0.9953198 ]
 [0.89475728 1.08105987 0.87968664 0.86986295 1.01644942 1.07222286
  1.01387299 0.98516565 0.96690772 0.95138329 1.10597141 1.02937237
  0.95871485 0.93060329 1.05898791 1.1031785  1.09256024 0.91521059
  0.96611462]
 [0.94959045 0.9782464  1.08000428 1.13828828 1.03320764 0.93363961
  0.99309432 1.16117251 1.00538214 0.96711283 0.97489658 1.03274514
  0.91811143 0.94740917 1.09087585 1.02984852 1.00638968 1.09653088
  0.92820856]
 [0.96431168 1.02652255 1.02295647 0.86193808 0.95525935 0.98065155
  1.03755099 1.02260049 0.98925295 0.81711981 0.99402204 1.00269174
  1.26335968 0.97758574 1.02707522 1.11516717 1.07363499 1.14337395
  0.86316381]
 [0.90101275 0.9411936  0.84973584 1.00352721 0.90766658 1.15954736
  0.92049911 0.96485874 1.0801087  0.87876646 0.84464572 1.

In [27]:
_lambda_[0]

array([1.04708219, 0.98292693, 0.97347267, 0.97347428, 1.16278368,
       1.07526208, 0.95052839, 1.05181933, 1.02101821, 0.81760009,
       0.89893338, 1.0283693 , 1.15026429, 0.97429619, 0.94330106,
       1.00778149, 0.9378975 , 0.96782864, 0.9953198 ])

In [28]:
def _update_eta(var, vi_var, const, max_iter=10000, tol=1e-9):
    """
    From appendix A.2 of Blei et al., 2003.
    For hessian with shape `H = diag(h) + 1z1'`
    
    To update alpha, input var=alpha and vi_var=gamma, const=M.
    To update eta, input var=eta and vi_var=lambda, const=k.
    """
    for _ in range(max_iter):
        # store old value
        var0 = var
        
        # g: gradient 
        psi_sum = psi(vi_var.sum(axis=1)).reshape(-1, 1)
        g = const * (V*psi(V*var) - V*psi(var)) + np.sum(psi(vi_var)) - np.sum(V*(psi_sum))
        #print(g)

        # H = diag(h) + 1z1'
        # z = const * polygamma(1, var.sum())  # z: Hessian constant component
        # h = -const * polygamma(1, var)       # h: Hessian diagonal component
        # c = (g / h).sum() / (1./z + (1./h).sum())

        h = const * (V**2 * polygamma(1, V*var) - V * polygamma(1, var))

        # # update var
        var -= g/h
        print(f"{g}, {h}, {var0} -> {var}")

        if var == np.inf or var == -np.inf: 
            raise ValueError(f"Grad -> {g}, Hessian -> {h}, overflow")
        
        # check convergence
        err = np.sqrt(np.mean((var - var0) ** 2))
        crit = err < tol
        if crit:
            break
    else:
        warnings.warn(f"max_iter={max_iter} reached: values might not be optimal.")
    
    #print(err)
    return var

In [29]:
_update_eta(1, _lambda_, k)

-0.8517960407552891, -58.72490095253988, 1 -> 0.985495147255443
0.013704776840540944, -60.62945883574141, 0.985495147255443 -> 0.9857211888052723
3.433071356084838e-06, -60.599086980475064, 0.9857211888052723 -> 0.9857212454574683
2.2737367544323206e-13, -60.59907937125736, 0.9857212454574683 -> 0.9857212454574721


0.9857212454574721

In [None]:
_update()

In [47]:
%%time
N_EPOCH = 1000
TOL = 0.1

verbose = True
lb = -np.inf

for epoch in range(N_EPOCH): 
    # store old value
    lb_old = lb 
    
    # Variational EM
    phi, gamma = E_step(docs, phi, gamma, alpha, beta)
    alpha, beta = M_step(docs, phi, gamma, alpha, beta, M)
    
    # check anomaly
    if np.any(np.isnan(alpha)):
        print("NaN detected: alpha")
        break
    
    # check convergence
    lb = vlb(docs, phi, gamma, alpha, beta, M, N, k)
    err = abs(lb - lb_old)
    
    # check anomaly
    if np.isnan(lb):
        print("NaN detected: lb")
        break
        
    if verbose:
        print(f"{epoch: 04}:  variational_lb: {lb: .3f},  error: {err: .3f}")
    
    if err < TOL:
        break
else:
    warnings.warn(f"max_iter reached: values might not be optimal.")

print(" ========== TRAINING FINISHED ==========")

45.912398470001705|[1.18335424 1.03715177 1.1949785  0.90225084 0.98639747] -> [1.73826588 1.72348485 1.81439381 1.34476821 1.6441462 ]
45.912398470001705|[1.73826588 1.72348485 1.81439381 1.34476821 1.6441462 ] -> [2.3167313  2.49773291 2.46711123 1.80097078 2.39044765]
45.912398470001705|[2.3167313  2.49773291 2.46711123 1.80097078 2.39044765] -> [2.67856843 3.00225603 2.87683695 2.0835062  2.87898895]
45.912398470001705|[2.67856843 3.00225603 2.87683695 2.0835062  2.87898895] -> [2.76380019 3.12393204 2.97345962 2.14956357 2.99722601]
45.912398470001705|[2.76380019 3.12393204 2.97345962 2.14956357 2.99722601] -> [2.76730548 3.12902289 2.97743433 2.1522629  3.00218893]
45.912398470001705|[2.76730548 3.12902289 2.97743433 2.1522629  3.00218893] -> [2.76731107 3.12903114 2.97744067 2.15226717 3.00219699]
45.912398470001705|[2.76731107 3.12903114 2.97744067 2.15226717 3.00219699] -> [2.76731107 3.12903114 2.97744067 2.15226717 3.00219699]
 000:  variational_lb: -30.853,  error:  inf
72.

In [None]:
np.random.seed(42)

_alpha_ = np.random.gamma(shape=100, scale=0.01, size=5)
_gamma_ = _alpha_ + np.ones((M, 10)) * 10 / 5


In [14]:
alpha 

array([6.35049448, 7.69240324, 7.39146901, 5.22342831, 7.20036156])

In [15]:
beta 

array([[0.07400785, 0.14550193, 0.18075127, 0.01376765, 0.02578853,
        0.00450348, 0.06253672, 0.08156986, 0.10041701, 0.06151805,
        0.03819088, 0.03338672, 0.038713  , 0.00971253, 0.02052171,
        0.00168327, 0.06356642, 0.00347359, 0.02345899, 0.00141224,
        0.01551829],
       [0.05505653, 0.13390507, 0.0012131 , 0.17568639, 0.14754379,
        0.14400777, 0.09256184, 0.05702787, 0.01292405, 0.00270679,
        0.0083801 , 0.00098671, 0.00826088, 0.04601553, 0.00139751,
        0.00044219, 0.00964543, 0.00246672, 0.04286658, 0.00307505,
        0.0538301 ],
       [0.00113359, 0.00435   , 0.00131078, 0.09023193, 0.01676564,
        0.09682216, 0.00105029, 0.00100453, 0.0108035 , 0.00658834,
        0.00136415, 0.00455233, 0.00452761, 0.20450217, 0.00079125,
        0.00763315, 0.01627071, 0.12942339, 0.20310205, 0.12678065,
        0.07099178],
       [0.03791479, 0.00198767, 0.00663336, 0.00697271, 0.14571166,
        0.02789297, 0.00384071, 0.00831163, 0.0056685

In [16]:
gamma

array([[ 9.43061307, 11.83846511,  8.05761531,  5.86301808,  8.85171606],
       [ 7.83473102,  8.80924744,  8.5312483 ,  9.56677131,  9.29942956],
       [ 7.30802872,  9.29018623, 12.50237239,  5.3282612 ,  9.61257909]])

In [17]:
phi

array([[[0.38730778, 0.37146697, 0.00780928, 0.19508738, 0.03832859],
        [0.38073006, 0.45172946, 0.01498353, 0.0051137 , 0.14744324],
        [0.94593173, 0.00818479, 0.00902995, 0.03413139, 0.00272214],
        [0.02613875, 0.42801611, 0.17700026, 0.0125832 , 0.35626167],
        [0.07383752, 0.57645515, 0.05663298, 0.29094579, 0.00212856],
        [0.02613875, 0.42801611, 0.17700026, 0.0125832 , 0.35626167],
        [0.00896085, 0.38328767, 0.185269  , 0.04071927, 0.38176321],
        [0.32727554, 0.62451571, 0.00723545, 0.01976203, 0.02121127],
        [0.38073006, 0.45172946, 0.01498353, 0.0051137 , 0.14744324],
        [0.42688232, 0.38476762, 0.00692021, 0.0427668 , 0.13866306]],

       [[0.27397908, 0.04225919, 0.03023781, 0.01891114, 0.63461278],
        [0.32194448, 0.01826274, 0.04538696, 0.58907069, 0.02533514],
        [0.1998656 , 0.0565406 , 0.00939758, 0.6211255 , 0.11307073],
        [0.17472387, 0.00665735, 0.03136095, 0.51948208, 0.26777575],
        [0.2025980

In [None]:
e