#### Implemented by Liushiya Chen (lc3501@columbia.edu) for COMS 6998
Reference: http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixed-membership.pdf

In [None]:
%load_ext nb_black

In [71]:
import pyro
import pyro.distributions as dist
import torch


def mixed_membership_model(alpha, nu, num_docs, num_words, vocab):
    """Sample a synthetic corpus from a mixed-membership topic model.

    Generative process (every draw is a named Pyro sample site):
      1. For each of the ``len(alpha)`` topics, draw a distribution over word
         ids from ``Dirichlet(nu)`` -> site ``topic_{k}``.
      2. For each document, draw its topic proportions from
         ``Dirichlet(alpha)`` -> site ``doc_{d}``.
      3. For every word slot ``n`` of document ``d``, draw a topic id from
         that document's proportions (site ``assigned_topic_{d}_{n}``), then
         a word id from the chosen topic (site ``word_id_{d}_{n}``).

    Args:
        alpha: Dirichlet concentration over topics (assumed 1-D);
            ``len(alpha)`` is used as the number of topics.
        nu: Dirichlet concentration over word ids, shared by all topics.
            Presumably ``len(nu) == len(vocab)`` so every sampled id is a
            valid index into ``vocab`` -- confirm at the call site.
        num_docs: number of documents to generate.
        num_words: number of words generated per document.
        vocab: sequence mapping a word id to the word it stands for.

    Returns:
        A list of ``num_docs`` documents, each a list of ``num_words`` words
        taken from ``vocab``.
    """
    n_topics = len(alpha)

    # Per-topic word distributions, all drawn before any document.
    topic_word_probs = []
    for k in range(n_topics):
        topic_word_probs.append(
            pyro.sample("topic_{}".format(k), dist.Dirichlet(nu))
        )

    # Per-document topic proportions.
    doc_topic_probs = []
    for d in range(num_docs):
        doc_topic_probs.append(
            pyro.sample("doc_{}".format(d), dist.Dirichlet(alpha))
        )

    def _sample_word(d, n):
        # Choose a topic for this slot, then a word id from that topic.
        # The sampled ids are integer tensors; list indexing accepts them.
        z = pyro.sample(
            "assigned_topic_{}_{}".format(d, n),
            dist.Categorical(doc_topic_probs[d]),
        )
        w = pyro.sample(
            "word_id_{}_{}".format(d, n),
            dist.Categorical(topic_word_probs[z]),
        )
        return vocab[w]

    # Outer loop over documents, inner over word slots: this keeps the exact
    # sample-site order (topic assignment, then word) of a nested for-loop.
    return [
        [_sample_word(d, n) for n in range(num_words)] for d in range(num_docs)
    ]

<IPython.core.display.Javascript object>

In [74]:
# Toy vocabulary; a sampled word id is rendered as VOCAB[word_id].
VOCAB = [
    "budget", "leak", "designer", "raid", "colon", "grace", "defendant",
    "comprehensive", "retreat", "factor", "adjust", "drag", "complex",
    "retain", "relaxation", "government", "breeze", "idea", "concrete",
    "chimney", "bottom", "snow", "knot", "stick", "guilt", "exception",
    "sensitive", "settle", "enlarge", "issue", "harmful", "core",
]

# Corpus size.
num_docs = 10
num_words = 20

# All-ones concentrations: flat Dirichlet priors over topics and over words.
n_topics = 3
vocab_size = len(VOCAB)
alpha = torch.ones(n_topics)
nu = torch.ones(vocab_size)

sampled_corpus = mixed_membership_model(alpha, nu, num_docs, num_words, VOCAB)

# One line per document: space-separated words terminated by a period.
for doc in sampled_corpus:
    print(*doc, end=".\n")

core government chimney settle adjust drag enlarge retain bottom idea designer budget knot settle snow leak comprehensive factor adjust drag.
knot bottom sensitive chimney defendant adjust drag bottom relaxation comprehensive bottom budget drag knot budget core bottom relaxation exception retreat.
relaxation core knot grace factor leak defendant colon concrete drag concrete bottom guilt grace issue idea guilt designer idea idea.
drag harmful comprehensive budget stick exception harmful defendant bottom leak sensitive adjust complex retreat adjust adjust snow snow stick knot.
guilt idea budget drag retain guilt guilt adjust issue comprehensive government defendant settle snow drag stick bottom leak drag defendant.
concrete budget concrete factor comprehensive retreat factor exception idea budget harmful relaxation settle bottom settle factor knot exception guilt exception.
concrete harmful chimney snow leak factor sensitive concrete bottom core relaxation settle government guilt colon c

<IPython.core.display.Javascript object>