In [33]:
import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

In [69]:
# Number of training documents to model (a subset keeps the pure-Python Gibbs sampler tractable).
N_DOCS = 2000

newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))

# Binary bag-of-words: each (document, word) pair is counted at most once.
# max_features caps the vocabulary size (unrelated to N_DOCS above).
vectorizer = CountVectorizer(lowercase=True, stop_words="english",
                             analyzer='word', min_df=5, max_df=0.5, max_features=2000, binary=True)

# fit_transform both learns the vocabulary and builds the document-term matrix;
# calling fit() beforehand on the same data only repeated that work.
X_train = vectorizer.fit_transform(newsgroups_train.data[:N_DOCS])

In [70]:
def Sample(n, p):
    """Draw one index in ``range(n)`` with probability proportional to ``p``.

    Uses inverse-CDF sampling with a single draw from NumPy's global RNG
    (``np.random.rand``).

    Args:
        n: number of candidate indices to scan (normally ``len(p)``).
        p: non-negative, unnormalized weights (NumPy array).

    Returns:
        The sampled index as an ``int``.

    Raises:
        ValueError: if none of the first ``n`` weights is positive.
    """
    q = p / p.sum()
    u = np.random.rand()
    for i in range(n):
        if u < q[i]:
            return i
        u -= q[i]
    # Floating-point rounding can make the normalized masses sum to slightly
    # less than u, so the loop may finish without returning. Fall back to the
    # last index that actually has probability mass instead of returning None.
    positive = np.flatnonzero(np.asarray(q[:n]) > 0)
    if positive.size == 0:
        raise ValueError("Sample: weights contain no positive mass")
    return int(positive[-1])

def LDA_alg(K, a, N, b, W, w, D, d, iters, verbose=True):
    """Collapsed Gibbs sampling for LDA topic assignments.

    Each token i (word id ``w[i]`` in document ``d[i]``) is resampled from
        p(z_i = k) ∝ (n_dk + a_k) * (n_kw + b_w) / (n_k + sum(b))
    with token i's own counts removed first.

    Args:
        K: number of topics.
        a: length-K Dirichlet prior over topics (document side).
        N: vocabulary size.
        b: length-N Dirichlet prior over words (topic side).
        W: number of tokens (``len(w)``).
        w: word id of each token.
        D: number of documents.
        d: document id of each token.
        iters: number of full Gibbs sweeps.
        verbose: print a progress line per sweep (default True, the original behavior).

    Returns:
        List of length W with the final topic index of each token.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    b_sum = b.sum()

    nkw = np.zeros((K, N))  # topic-word counts
    ndk = np.zeros((D, K))  # document-topic counts
    nk = np.zeros(K)        # total tokens per topic

    # Initial assignments: every token is drawn from one shared topic
    # distribution sampled from Dir(a).
    t = np.random.dirichlet(a)
    z = [Sample(K, t) for _ in range(W)]

    for i in range(W):
        nkw[z[i], w[i]] += 1
        ndk[d[i], z[i]] += 1
        nk[z[i]] += 1

    for it in range(iters):
        if verbose:
            print('iteration: ' + str(it + 1))

        for i in range(W):
            wi, di = w[i], d[i]

            # Remove token i's current assignment from the counts.
            nkw[z[i], wi] -= 1
            ndk[di, z[i]] -= 1
            nk[z[i]] -= 1

            # Vectorized over topics; elementwise identical to the per-k formula.
            p = (ndk[di] + a) * (nkw[:, wi] + b[wi]) / (nk + b_sum)

            z[i] = Sample(K, p)

            nkw[z[i], wi] += 1
            ndk[di, z[i]] += 1
            nk[z[i]] += 1

    return z

SEED = 0     # Sample() and the Dirichlet init draw from NumPy's global RNG
TOP_N = 10   # number of words to show per topic

np.random.seed(SEED)

K = 20
a = np.ones(K)

N = len(vectorizer.vocabulary_)
b = np.ones(N)

# One token per nonzero (document, word) entry of the binary document-term matrix.
d, w = X_train.nonzero()
W = len(w)
D = X_train.shape[0]

iters = 100

z = LDA_alg(K, a, N, b, W, w, D, d, iters)

# Rebuild topic-word counts from the final assignments.
nkw = np.zeros((K, N))
np.add.at(nkw, (np.asarray(z, dtype=int), w), 1)

print(f'\nTop {TOP_N} words in each theme:')

words = vectorizer.get_feature_names_out()
top_n = min(TOP_N, N)  # argpartition needs kth < N
for k in range(K):
    idx = np.argpartition(nkw[k], -top_n)[-top_n:]
    # argpartition leaves the top entries unordered; rank them by count, highest first.
    idx = idx[np.argsort(nkw[k][idx])[::-1]]
    print(words[idx])

iteration: 1
iteration: 2
iteration: 3
iteration: 4
iteration: 5
iteration: 6
iteration: 7
iteration: 8
iteration: 9
iteration: 10
iteration: 11
iteration: 12
iteration: 13
iteration: 14
iteration: 15
iteration: 16
iteration: 17
iteration: 18
iteration: 19
iteration: 20
iteration: 21
iteration: 22
iteration: 23
iteration: 24
iteration: 25
iteration: 26
iteration: 27
iteration: 28
iteration: 29
iteration: 30
iteration: 31
iteration: 32
iteration: 33
iteration: 34
iteration: 35
iteration: 36
iteration: 37
iteration: 38
iteration: 39
iteration: 40
iteration: 41
iteration: 42
iteration: 43
iteration: 44
iteration: 45
iteration: 46
iteration: 47
iteration: 48
iteration: 49
iteration: 50
iteration: 51
iteration: 52
iteration: 53
iteration: 54
iteration: 55
iteration: 56
iteration: 57
iteration: 58
iteration: 59
iteration: 60
iteration: 61
iteration: 62
iteration: 63
iteration: 64
iteration: 65
iteration: 66
iteration: 67
iteration: 68
iteration: 69
iteration: 70
iteration: 71
iteration: 72
i