In [46]:
import numpy as np
import gensim

def lda(vocabulary, beta, alpha, xi, printFlag=True, rng=None):
    """Sample one document from the LDA generative process.

    The document length is drawn from Poisson(xi), the topic mixture from
    Dirichlet(alpha), and then, for every word position, a topic is drawn from
    the mixture and a word from that topic's row of ``beta``.

    Parameters
    ----------
    vocabulary : sequence of str
        Candidate words; entry ``j`` corresponds to column ``j`` of ``beta``.
    beta : array-like
        Per-topic word distributions, one row per topic (each row must be a
        valid probability vector over ``vocabulary``).
    alpha : array-like
        Dirichlet concentration parameters, one per topic.
    xi : float
        Mean of the Poisson distribution for the document length.
    printFlag : bool, default True
        If true, print the sampled length and topic distribution.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state, so
        existing callers (and ``np.random.seed``) behave exactly as before.

    Returns
    -------
    list
        The sampled words, in generation order.

    Raises
    ------
    ValueError
        If ``beta`` does not have exactly one row per entry of ``alpha``.
    """
    if rng is None:
        rng = np.random
    num_topics = len(alpha)
    # A mismatch would otherwise either be silently ignored (extra rows) or
    # surface later as an opaque IndexError (missing rows).
    if len(beta) != num_topics:
        raise ValueError(
            "beta must have one row per topic (len(alpha) = {}), got {} rows"
            .format(num_topics, len(beta)))

    # Draw order (length, mixture, then topic/word per position) matches the
    # original implementation so seeded runs reproduce the same documents.
    num_words = rng.poisson(xi)
    dirichlet = rng.dirichlet(alpha)
    if (printFlag):
        print("Num of Words: {}".format(num_words))
        print("Topic Distribution: {}".format(dirichlet))
    document = []
    for _ in range(num_words):
        topic = rng.choice(num_topics, p=dirichlet)
        word = rng.choice(vocabulary, p=beta[topic])
        document.append(word)
    return document

In [47]:
# `lda` samples from the global NumPy RNG; seed it so this cell's output is reproducible.
np.random.seed(46)

vocabulary = ['bass', 'pike', 'deep', 'tuba', 'horn', 'catapult']
# Topic-word distributions: one row per topic, one column per vocabulary word.
beta = np.array([
    [0.4, 0.4, 0.2, 0.0, 0.0, 0.0],
    [0.0, 0.3, 0.1, 0.0, 0.3, 0.3],
    [0.3, 0.0, 0.2, 0.3, 0.2, 0.0]
])
alpha = np.array([1, 3, 8])  # Dirichlet prior over the three topics
xi = 50                      # Poisson mean of the document length

# Sanity-check the generative parameters before sampling.
assert beta.shape == (len(alpha), len(vocabulary)), "beta must be (num_topics, vocab_size)"
assert np.allclose(beta.sum(axis=1), 1.0), "each row of beta must sum to 1"

document = lda(vocabulary, beta, alpha, xi)
print("Document for Q1: " + " ".join(document))

Num of Words: 52
Topic Distribution: [0.0676924  0.30212408 0.63018352]
Document for Q1: bass tuba pike deep horn catapult tuba deep bass deep tuba pike horn tuba horn horn pike bass deep horn bass pike catapult bass deep bass bass catapult horn pike horn deep tuba bass pike deep catapult bass horn horn bass catapult pike pike bass horn pike deep deep pike tuba deep


In [48]:
# Generate a synthetic corpus from the known parameters above, then check how
# well gensim's LDA recovers the topic-word distributions (beta).
num_documents = 500
num_topics = len(alpha)  # fit as many topics as the generator used
PASSES = 50
ITERATIONS = 400
CORPUS_SEED = 48

np.random.seed(CORPUS_SEED)  # `lda` draws from the global NumPy RNG
documents = [lda(vocabulary, beta, alpha, xi, False) for _ in range(num_documents)]

# Named `dictionary` (not `dict`) so the Python builtin is not shadowed.
dictionary = gensim.corpora.Dictionary(documents)
corpus = [dictionary.doc2bow(doc) for doc in documents]
ldamodel = gensim.models.LdaModel(
    corpus,
    alpha='auto',
    eta='auto',
    num_topics=num_topics,
    id2word=dictionary,
    passes=PASSES,
    iterations=ITERATIONS,
    random_state=CORPUS_SEED,  # make the fit itself reproducible
)

print("Learnt Beta:")
# Learnt topics are unordered: topic k need not correspond to row k of `beta`.
for topic_id in range(num_topics):
    print(ldamodel.show_topic(topic_id))

Learnt Beta:
[('pike', 0.3918084), ('catapult', 0.19643642), ('bass', 0.1914098), ('horn', 0.11623159), ('deep', 0.08610583), ('tuba', 0.018008059)]
[('bass', 0.25454065), ('tuba', 0.22923991), ('horn', 0.20607078), ('deep', 0.19505307), ('pike', 0.07129663), ('catapult', 0.043799013)]
[('horn', 0.29902488), ('catapult', 0.26223847), ('pike', 0.1896851), ('deep', 0.13503475), ('tuba', 0.078078814), ('bass', 0.03593799)]
