Do the imports

In [64]:
!pip install pyro-ppl



In [65]:
import argparse
import functools
import logging

import torch
from torch import nn
from torch.distributions import constraints
import functools

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.contrib.autoguide import AutoDiagonalNormal, AutoMultivariateNormal, AutoGuideList, AutoDelta
from pyro.optim import ClippedAdam

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Get the modifed data matrix, where the documents are of equal length.

In [66]:
import requests, io
r = requests.get('https://github.com/MikkelGroenning/MBML_project/blob/master/data/processed/upsampled_data.npy?raw=true')

data = np.load(io.BytesIO(r.content)).astype('int32')
data.shape

(32211, 1500)

The data consists of 32.211 speeches, each with a length of 1500 words. We therefore only look at a subset of the data.

In [67]:
n = 10
data_sub = len(np.unique(data_sub)) # data[:n]
num_words = data_sub.max() + 1
num_topics = 25
num_docs = n
num_words_per_doc = data.shape[1]

AttributeError: 'int' object has no attribute 'max'

With these things defined we can now make a model.

In [None]:
def model(data=None, batch_size=1):
    """ Make a plate of size num_topics with name "topics" and define a variable "topic_words".
          This represents the phi above. Use the equivalent of a uniform distribution for it  """
    with pyro.plate("topics", num_topics):
        topic_words = pyro.sample("topic_words", dist.Dirichlet(torch.ones(num_words) / num_words))

    """ Make two (nested) plates in here. One over documents and one over words
          Documents, called "documents":
          The plate over the documents should hold a variable "doc_topics" representing the theta above.
            Use the equivalent of a uniform distribution for it.
          
          Words, called "words":
          The plate over words, should have a topic assignment for each word (z_{i,j} above) which 
            should be enumerated.
          The second variable should be the words themselves which should be drawn from the "topic_words"
            using the assigned z_{i,j} and the observed data.

     """
    with pyro.plate("documents", num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (num_words_per_doc, num_docs)
            data = data[:, ind]
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(torch.ones(num_topics)/ num_topics))
        with pyro.plate("words", num_words_per_doc):
            # The word_topics variable is marginalized out during inference,
            # achieved by specifying infer={"enumerate": "parallel"} and using
            # TraceEnum_ELBO for inference. Thus we can ignore this variable in
            # the guide.
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics), infer={"enumerate": "parallel"})
            data = pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]), obs=data)

    return topic_words, data

In [None]:
W_torch = torch.tensor(data_sub.T).long()
W_torch.shape

Here we use amortized inference of the local variables. This is acheived by using a multi-layer perceptron.

In [None]:
layer_sizes = np.arange(98,103)
layer_sizes = torch.tensor(layer_sizes)
print(layer_sizes.size())

def make_predictor(num_words, layer_sizes):
    layer_sizes = ([num_words] +
                   [int(s) for s in torch.split(layer_sizes,1)] +
                   [num_topics])
    logging.info('Creating MLP with sizes {}'.format(layer_sizes))
    layers = []
    for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        layer = nn.Linear(in_size, out_size)
        layer.weight.data.normal_(0, 0.001)
        layer.bias.data.normal_(0, 0.001)
        layers.append(layer)
        layers.append(nn.Sigmoid())
    layers.append(nn.Softmax(dim=-1))
    return nn.Sequential(*layers)

And the guide

In [None]:
def parametrized_guide(predictor, data, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(num_topics),
            constraint=constraints.positive)
    topic_words_posterior = pyro.param(
            "topic_words_posterior",
            lambda: torch.ones(num_topics, num_words),
            constraint=constraints.greater_than(0.5))
    with pyro.plate("topics", num_topics):
        pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_words", dist.Dirichlet(topic_words_posterior))

    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", num_docs, batch_size) as ind:
        data = data[:, ind]
        # The neural network will operate on histograms rather than word
        # index vectors, so we'll convert the raw data to a histogram.
        counts = (torch.zeros(num_words, ind.size(0)).scatter_add(0, data, torch.ones(data.shape)))
        doc_topics = predictor(counts.transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))



Now we train using SVI

In [None]:
print(num_topics)
print(num_words)
print(num_docs)

In [62]:
pyro.clear_param_store()

learning_rate = 0.05

predictor = make_predictor(num_words, layer_sizes)
guide = functools.partial(parametrized_guide, predictor)
# Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
elbo = TraceEnum_ELBO(max_plate_nesting=3)
optim = ClippedAdam({'lr': learning_rate})
svi = SVI(model, guide, optim, loss=elbo)

# Define the number of optimization steps
n_steps = 5

# do gradient steps
for step in range(n_steps):
    elbo = svi.step(W_torch, batch_size=2)
    if step % 1 == 0:
        #print('.', end='')
        print("[%d] ELBO: %.1f" % (step, elbo))

RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 6960600000 bytes. Buy new RAM!
(no backtrace available)