In [71]:
import pyro
import pyro.distributions as dist
import torch


def mixed_membership_model(alpha, nu, num_docs, num_words, vocab):

    # Each topic is a categorical distribution over words. E.g. topics[i] = {word1: 0.2, word2: 0.8}
    topics = [
        pyro.sample("topic_{}".format(i), dist.Dirichlet(nu)) for i in range(len(alpha))
    ]

    # Every document is a categorical distribution over topics. E.g. documents[i] = {science: 0.2, business: 0.8}
    documents = [
        pyro.sample("doc_{}".format(i), dist.Dirichlet(alpha)) for i in range(num_docs)
    ]

    # A corpus is a set of documents
    corpus = []

    for i in range(num_docs):
        document = []
        for j in range(num_words):
            topic_id = pyro.sample(
                "assigned_topic_{}_{}".format(i, j), dist.Categorical(documents[i])
            )
            word_id = pyro.sample(
                "word_id_{}_{}".format(i, j), dist.Categorical(topics[topic_id])
            )
            document.append(vocab[word_id])
        corpus.append(document)
    return corpus

<IPython.core.display.Javascript object>

In [72]:
VOCAB = [
    "budget",
    "leak",
    "designer",
    "raid",
    "colon",
    "grace",
    "defendant",
    "comprehensive",
    "retreat",
    "factor",
    "adjust",
    "drag",
    "complex",
    "retain",
    "relaxation",
    "government",
    "breeze",
    "idea",
    "concrete",
    "chimney",
    "bottom",
    "snow",
    "knot",
    "stick",
    "guilt",
    "exception",
    "sensitive",
    "settle",
    "enlarge",
    "issue",
    "harmful",
    "core",
]

n_topics = 3
vocab_size = len(VOCAB)
alpha = torch.ones(n_topics)
nu = torch.ones(vocab_size)
num_docs = 10
num_words = 20

sampled_corpus = mixed_membership_model(alpha, nu, num_docs, num_words, VOCAB)
[print(" ".join(doc) + ".") for doc in sampled_corpus]

complex issue designer comprehensive designer bottom exception raid retreat breeze retreat issue government factor designer complex core budget colon relaxation.
factor factor breeze enlarge issue relaxation sensitive settle defendant issue complex defendant budget relaxation breeze breeze snow colon settle breeze.
breeze colon harmful colon colon snow breeze defendant colon harmful raid grace raid adjust colon snow issue harmful budget factor.
colon designer retain issue government issue designer bottom harmful harmful snow colon knot raid factor designer designer issue factor idea.
settle exception complex issue colon chimney designer complex exception harmful exception comprehensive idea chimney issue retain complex designer settle issue.
idea retain settle harmful settle drag designer concrete idea issue concrete breeze government comprehensive retain complex drag designer issue colon.
colon breeze breeze settle issue colon factor idea harmful complex raid exception designer factor

[None, None, None, None, None, None, None, None, None, None]

<IPython.core.display.Javascript object>