In [1]:
import numpy as np

from simulator import generate_companies
from theano import shared
import theano.tensor as tt

import pymc3 as pm
from pymc3 import math as pmmath
from pymc3 import Beta, Dirichlet
from pymc3.distributions.transforms import t_stick_breaking

In [2]:
import theano
# Global Theano option set before any model graph is built.
# NOTE(review): pm.Model may manage its own test-value mode inside its
# context — confirm this global setting is still needed.
setattr(theano.config, "compute_test_value", "ignore")

In [3]:
# Simulated corpus: `words_p_topic` words presumably characteristic of each
# industry topic, for the industries listed below.
# NOTE(review): no random seed is set here; if generate_companies draws random
# numbers, the corpus changes on every run — TODO confirm and seed it.
words_p_topic = 5
industries = ["pbc", "rubber", "media"]
(X, Z, company_industry,
 phi, phi_bg, word2id, id2word) = generate_companies(industries, words_p_topic=words_p_topic)
n_docs = X.shape[0]
n_topics, n_words = phi.shape

In [4]:
ind = 0
# Densify the sparse count matrix once and keep only the inspected row,
# instead of calling X.toarray() twice on the full matrix.
doc_counts = X.toarray()[ind, :]
# Columns [words_p_topic*k, words_p_topic*(k+1)) for industry id k —
# presumably that industry's characteristic words (TODO confirm against word2id).
industry_start = words_p_topic * company_industry[ind]
print(Z[ind])
print(company_industry[ind])
print(doc_counts)
print(doc_counts[industry_start:industry_start + words_p_topic])

52
0
[8. 2. 4. 4. 4. 0. 0. 2. 1. 0. 0. 1. 0. 1. 1. 1. 2. 0. 1. 1. 1. 0. 1. 1.
 0. 0. 0. 0. 2. 1. 3. 0. 0. 3. 0. 0. 0. 1. 1. 1. 0. 0. 0. 3. 0. 1. 0. 1.
 0. 0. 0. 1. 1. 1. 0. 2. 0. 0. 1. 1. 2. 1. 0. 2. 2. 2. 1. 1. 1. 1. 0. 0.
 2. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 2. 0. 1. 1. 1. 1.
 1. 1. 3. 0. 1. 1. 1. 0. 1. 0. 0. 0. 0. 1. 1. 1. 1. 0. 1.]
[8. 2. 4. 4. 4.]


In [5]:
# Explicit copy of the per-document industry ids as a numpy array.
tst = np.array(company_industry, copy=True)

In [6]:
print(np.shape(tst))

(1000,)


In [7]:
def logp_generator(phi, theta, vocab_size=None, n_groups=None):
    """Build the log-likelihood function passed to ``pm.DensityDist``.

    Each document's words are scored under a two-component mixture: the topic
    of the document's own industry (row ``ind`` of ``phi``) and the background
    topic (last row of ``phi``), weighted by that document's row of ``theta``.

    Parameters
    ----------
    phi : tensor
        Topic-word probabilities; in this notebook shape (n_topics + 1, n_words)
        with the last row being the background topic.
    theta : tensor
        Per-document mixture weights; in this notebook shape (n_docs, 2), where
        column 0 pairs with the industry topic and column 1 with the background.
    vocab_size : int, optional
        Number of word-count columns in the observed matrix. Defaults to the
        notebook-level ``n_words`` (looked up when the logp is evaluated).
    n_groups : int, optional
        Number of industries to loop over. Defaults to ``len(industries)``.

    Returns
    -------
    callable
        ``logp_docs(docs_industry)``. ``docs_industry`` holds word counts in its
        first ``vocab_size`` columns and the industry id after them. The result
        is ``sum(count * log p(word))`` over all non-zero counts.
    """
    # Loop-invariant: take the logs once rather than per industry.
    log_phi = tt.log(phi)
    log_theta = tt.log(theta)

    def logp_docs(docs_industry):
        n_cols = n_words if vocab_size is None else vocab_size
        n_ind = len(industries) if n_groups is None else n_groups

        docs = docs_industry[:, :n_cols]
        industry = docs_industry[:, n_cols:]
        ll_docs = 0
        for ind in range(n_ind):
            # Row mask selecting the documents that belong to industry `ind`.
            # NOTE(review): this mask indexing is presumably the source of the
            # repeated `inputs[0].__getitem__(inputs[1:])` warnings in the outputs.
            industry_ind = tt.eq(industry, ind).ravel()
            docs_ind = docs[industry_ind, :]
            # Zero counts add nothing to sum(w * log p), so only visit non-zeros.
            d, v = docs_ind.nonzero()
            w = docs_ind[d, v]
            # (nnz, 2): log p(word | industry topic), log p(word | background)
            log_p_word = log_phi[[ind, -1], :].T[v]
            # (nnz, 2): log mixture weights of the document owning each count
            log_weight = log_theta[industry_ind][d]
            ll_docs += tt.sum(w * pmmath.logsumexp(log_p_word + log_weight, axis=1).ravel())
        return ll_docs

    return logp_docs

In [8]:
# Observed matrix for the DensityDist: dense word counts with the industry id
# appended as one extra last column (logp_docs splits it apart again).
doc_industry_t = tt.concatenate(
    [X.toarray(), np.array(company_industry).reshape(-1, 1)],
    axis=1,
)

In [9]:
with pm.Model() as model:
    # One row per industry topic plus one extra row; the last row of phi is
    # the background topic used by every document.
    # NOTE(review): this rebinds the notebook-level `phi`, discarding the
    # simulator's ground-truth topic matrix produced in cell 3.
    phi = Dirichlet(
        'phi',
        a=pm.floatX(np.full((n_topics + 1, n_words), 1.0 / (n_topics + 1))),
        shape=(n_topics + 1, n_words),
    )
    # Per-document mixing weights: column 0 = own-industry topic,
    # column 1 = background topic (the pairing used in logp_generator).
    theta = Dirichlet(
        'theta',
        a=pm.floatX(np.full((n_docs, 2), 0.5)),
        shape=(n_docs, 2),
    )
    doc = pm.DensityDist('doc', logp_generator(phi, theta), observed=doc_industry_t)

  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])


In [10]:
with model:
    # Fixed seed so re-running the notebook reproduces the same draws.
    # NOTE(review): the recorded run reported divergences after tuning; the
    # sampler suggests raising `target_accept` (e.g. 0.9) or reparameterizing.
    trace = pm.sample(750, chains=2, random_seed=42)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
Multiprocess sampling (2 chains in 3 jobs)
NUTS: [theta, phi]
  rval = inputs[0].__getitem__(inputs[1:])
  rval = inputs[0].__getitem__(inputs[1:])
Sampling 2 chains: 100%|██████████| 2500/2500 [20:21<00:00,  2.05draws/s] 
There were 49 divergences after tuning. Increase `target_accept` or reparameterize.
There were 21 divergences after tuning. Increase `target_accept` or reparameterize.
The estimated number of effective samples is smaller than 200 for some parameters.


In [11]:
def plot_samples_topics(trace, chain=0, vocab=None, top_n=None):
    """Print each topic's words ranked by posterior-mean probability.

    Nothing is plotted; the name is kept for existing callers. Each topic is
    printed as one line of ``word:prob`` pairs, followed by a rule of ``#``.

    Parameters
    ----------
    trace : pymc3 MultiTrace (or any object with the same ``get_values``)
        Must contain the variable ``"phi"``.
    chain : int or None, default 0
        Index of the chain whose draws are averaged (the default matches the
        previous behaviour of using only the first chain). Pass ``None`` to
        pool the draws of every chain.
    vocab : mapping or sequence, optional
        Word-id -> word lookup. Defaults to the notebook-level ``id2word``.
    top_n : int, optional
        Print only the ``top_n`` most probable words per topic; all when None.
    """
    if vocab is None:
        vocab = id2word
    per_chain = trace.get_values("phi", combine=False)
    draws = np.concatenate(per_chain, axis=0) if chain is None else per_chain[chain]
    # Local name chosen so it does not read like the global `phi`.
    phi_mean = np.mean(draws, axis=0)
    for topic in phi_mean:
        ranked = sorted(((prob, i) for i, prob in enumerate(topic)), reverse=True)
        for prob, i in ranked[:top_n]:
            print(f"{vocab[i]}:{prob} ", end="")
        print()
        print("#" * 80)

In [12]:
plot_samples_topics(trace=trace)

pbc_word_0:0.10055565894279138 pbc_word_3:0.07638881477352504 pbc_word_4:0.07034518968766654 pbc_word_1:0.06495384769644433 pbc_word_2:0.05994793142213591 background_word_32:0.043739503330072635 background_word_91:0.030479783253153214 background_word_56:0.024330591842324827 background_word_81:0.02260019982842373 background_word_50:0.02060684613386145 background_word_89:0.019195158290517975 rubber_word_1:0.018529904362692275 background_word_51:0.014143636889154411 rubber_word_2:0.013140653788888691 background_word_39:0.012481513419963258 background_word_90:0.012323029858903546 background_word_63:0.011357176790802052 background_word_84:0.011181283073573927 background_word_34:0.01109525797252815 background_word_10:0.011092636688053242 media_word_3:0.011005945774161058 background_word_38:0.010490088353044057 background_word_35:0.010202840128425886 background_word_0:0.010041241410522648 background_word_72:0.009926985146410912 background_word_5:0.009593885238840366 background_word_4:0.009312