In [359]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [374]:
import pyro
import numpy as np
import pandas as pd
import torch
import pyro.distributions as dist
import math, time
import torch.distributions.constraints as constraints
import matplotlib.pyplot as plt
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from IPython.display import display, clear_output

<IPython.core.display.Javascript object>

In [378]:
# Problem sizes shared (as module-level globals) by lda(), guide() and generate().
ndocs = 2  # number of documents; size of the "doc_loop" plate
ntopics = 3  # number of latent topics; size of the "topic_loop" plate
nwords = 13  # words per document; size of the "word_loop" plate
vocab_size = 6  # number of distinct word ids each topic distributes over


def lda(data):
    """Pyro model for LDA with uniform (all-ones) Dirichlet priors.

    Sample sites:
      * "topics" -- one Dirichlet draw over the vocabulary per topic
      * "theta"  -- one Dirichlet draw over topics per document
      * "z"      -- a topic assignment for every word slot of every document
      * "w"      -- the word ids, observed and conditioned on ``data``

    Args:
        data: tensor of word ids indexed as ``data[word_position, document]``;
            its second axis is sliced with the "doc_loop" plate indices.

    Relies on the module-level sizes ``ntopics``, ``vocab_size``, ``ndocs``
    and ``nwords``.
    """
    doc_concentration = torch.ones(ntopics)
    # Same values as repeating a ones(vocab_size) row ntopics times.
    word_concentration = torch.ones(ntopics, vocab_size)

    with pyro.plate("topic_loop", ntopics):
        topics = pyro.sample("topics", dist.Dirichlet(word_concentration))

    with pyro.plate("doc_loop", ndocs) as ind:
        # Keep only the document columns selected by the plate.
        observed = data[:, ind]
        theta = pyro.sample("theta", dist.Dirichlet(doc_concentration))
        with pyro.plate("word_loop", nwords):
            z = pyro.sample("z", dist.Categorical(theta))
            # topics[z] looks up the word distribution of each assigned topic.
            pyro.sample("w", dist.Categorical(topics[z]), obs=observed)


def guide(data):
    """Variational guide for ``lda``, with one learnable parameter per latent site.

    Parameters registered in the Pyro param store (all initialised to ones and
    constrained positive):
      * "lambda_q" (ntopics, vocab_size)     -- Dirichlet concentration for "topics"
      * "gamma_q"  (ndocs, ntopics)          -- Dirichlet concentration for "theta"
      * "phi_q"    (nwords, ndocs, ntopics)  -- Categorical weights for "z"

    Args:
        data: accepted so the signature matches ``lda``; not read here.
    """
    positive = constraints.positive
    topic_word_conc = pyro.param(
        "lambda_q", torch.ones(ntopics, vocab_size), constraint=positive
    )
    doc_topic_conc = pyro.param(
        "gamma_q", torch.ones(ndocs, ntopics), constraint=positive
    )
    word_topic_weights = pyro.param(
        "phi_q", torch.ones(nwords, ndocs, ntopics), constraint=positive
    )

    # The plate names and nesting mirror the ones used in the model.
    with pyro.plate("topic_loop", ntopics):
        pyro.sample("topics", dist.Dirichlet(topic_word_conc))
    with pyro.plate("doc_loop", ndocs):
        pyro.sample("theta", dist.Dirichlet(doc_topic_conc))
        with pyro.plate("word_loop", nwords):
            pyro.sample("z", dist.Categorical(word_topic_weights))


def generate():
    """Draw a synthetic corpus from the LDA generative process.

    Topics are sampled once from a uniform Dirichlet; then, for every document,
    topic proportions are sampled, and for every word slot a topic assignment
    and a word id are sampled.

    Returns:
        (data, topics):
            data   -- float tensor of shape (nwords, ndocs); ``data[w, d]`` is
                      the word id drawn for slot ``w`` of document ``d``.
            topics -- tensor of shape (ntopics, vocab_size) holding the
                      per-topic word distributions used for the draw.

    Relies on the module-level sizes ``ntopics``, ``vocab_size``, ``ndocs``
    and ``nwords``.
    """
    alpha_prior = torch.ones(ntopics)
    nu_prior = torch.ones(vocab_size)
    # Kept as a float tensor (torch.zeros default dtype) so the returned dtype
    # is unchanged for existing callers; it stores integer-valued word ids.
    data = torch.zeros([nwords, ndocs])
    topics = pyro.sample("topics", dist.Dirichlet(nu_prior.repeat(ntopics, 1)))
    for d in pyro.plate("doc_loop", ndocs):
        theta = pyro.sample(f"theta_{d}", dist.Dirichlet(alpha_prior))
        # pyro.iarange was a deprecated alias that newer Pyro releases removed;
        # iterating over pyro.plate is the equivalent sequential form and
        # matches the outer loop.
        for w in pyro.plate("word_loop", nwords):
            z = pyro.sample(f"z_{d}_{w}", dist.Categorical(theta))
            word = pyro.sample(f"w_{d}_{w}", dist.Categorical(topics[z.item()]))
            data[w, d] = word
    return data, topics


# Human-readable labels for the vocab_size word ids (not referenced by the
# model code in this cell).
vocab = ["banana", "carrot", "cake", "milk", "diapers", "beer"]

# Seed the RNGs before sampling so the synthetic corpus (and everything fitted
# to it afterwards) is the same on every Restart & Run All.
pyro.set_rng_seed(0)
data, topics = generate()
data.shape

torch.Size([13, 2])

<IPython.core.display.Javascript object>

In [None]:
# Start from a clean parameter store so re-running this cell re-initialises
# lambda_q / gamma_q / phi_q instead of continuing from a previous fit.
pyro.clear_param_store()

# set up the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(lda, guide, optimizer, loss=Trace_ELBO())

loss = []  # per-step loss values returned by svi.step
n_steps = 50000
# Print roughly 100 progress dots. Clamp to 1 so that a small n_steps (< 100)
# does not turn the modulo below into a division by zero.
progress_every = max(1, n_steps // 100)
for step in range(n_steps):
    loss.append(svi.step(data))
    if step % progress_every == 0:
        print(".", end="")


..............................................................................................

In [None]:
# svi.step returns the loss being minimised, i.e. the *negative* ELBO estimate,
# so the figure is labelled accordingly (the old title "ELBO" had the sign wrong).
loss_series = pd.Series(loss)  # build once, reuse for both curves

fig, ax = plt.subplots()
ax.plot(loss_series, alpha=0.3, label="Actual")
# The first 99 points of the rolling mean are NaN and are simply not drawn.
ax.plot(loss_series.rolling(100).mean(), alpha=1, label="100 rolling mean")
ax.set_xlabel("SVI step")
ax.set_ylabel("Loss (negative ELBO estimate)")
ax.set_title("Negative ELBO")
ax.legend();