In [1]:
import os
from pyspark.sql import SparkSession

In [2]:
import numpy as np

from simulator import generate_companies

from scipy.special import psi
from scipy.sparse import csr_matrix, hstack, lil_matrix, vstack

from gensim.matutils import mean_absolute_difference
from itertools import islice

In [3]:
words_p_topic = 5
industries = ["pbc", "rubber", "media"]
X, Z, company_industry, phi, phi_bg, word2id, id2word = generate_companies(industries, 
                                                                           words_p_topic=words_p_topic,
                                                                           num_companies=100000,
                                                                          )
n_topics, n_words = phi.shape
n_docs = X.shape[0]

In [4]:
ind = 0
print(Z[ind])
print(company_industry[ind])
print(X.toarray()[ind,:])
print(X.toarray()[ind,(words_p_topic*company_industry[ind]):(words_p_topic*(company_industry[ind]+1))])

62
2
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 2. 0. 1. 1. 5. 1. 0. 1. 1. 0. 2. 1. 0. 1.
 0. 1. 0. 2. 2. 2. 0. 1. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 3. 0. 1. 1.
 1. 1. 0. 2. 0. 2. 1. 0. 1. 0. 3. 1. 2. 1. 3. 0. 1. 0. 1. 1. 0. 1. 1. 0.
 3. 3. 2. 1. 0. 1. 1. 0. 1. 0. 1. 5. 1. 0. 3. 0. 0. 2. 0. 0. 1. 1. 2. 2.
 0. 1. 0. 3. 0. 0. 0. 0. 2. 0. 3. 0. 0. 0. 0. 0. 4. 0. 1.]
[2. 0. 1. 1. 5.]


In [5]:
spark = (SparkSession.builder
         .master("local[*]")
         .appName("SparkTest")
         .getOrCreate()
        )

In [6]:
class SimpleLDA:
    def __init__(self, word2id, 
                 industries, 
                 a=1.0, 
                 b=1.0, 
                 e_step_iter=50,
                 print_every=25,
                 estep_threshold=1e-5
                ):
        self.a = a
        self.b = b
        self.estep_threshold = estep_threshold
        
        self.word2id = word2id
        self.industries = industries
        id2word = {}
        for k,v in word2id.items():
            id2word[v] = k
        self.id2word = id2word
        self.e_step_iter = e_step_iter
        self.print_every = print_every
    
    def _e_step(self, X, metadata):
        import numpy
        from scipy.special import psi
        from scipy.sparse import csr_matrix, hstack, lil_matrix, vstack
        
        a,b = self.a, self.b
        D,V = X.shape
        q_theta = np.random.gamma(1.0, size=(D,2))
        q_z = lil_matrix((D,V))
        phi = self.phi

        for industry in range(len(self.industries)):
            ind = metadata==industry
            q_theta_ind = q_theta[ind,:]
            X_ind = X[ind]
            q_z_ind = lil_matrix(X_ind.shape)
            d,v = X_ind.nonzero()
            q_theta_ind_old = q_theta_ind.copy()
            
            for i in range(self.e_step_iter):
                # Compute q(z)
                coef = np.clip(psi(q_theta_ind[d,-1])-psi(np.sum(q_theta_ind[d], axis=1)), a_min=-100.0, a_max=100.0)
                
                bg_w = np.exp(coef)*phi[-1,v]
                ind_w = np.exp(-coef)*phi[industry,v]
                q_z_ind[d,v] = bg_w/(bg_w+ind_w+1e-9)

                # Compute q(theta)
                q_theta_ind[:,0] = a+np.sum(q_z_ind.multiply(X_ind), axis=1).ravel()
                q_theta_ind[:,1] = b+np.sum(X_ind-q_z_ind.multiply(X_ind), axis=1).ravel()
                
                if mean_absolute_difference(q_theta_ind.ravel(), q_theta_ind_old.ravel()) <= self.estep_threshold:
                    break
                q_theta_ind_old = q_theta_ind.copy()

            q_z[ind] = q_z_ind
            q_theta[ind] = q_theta_ind
        
        return q_z, q_theta
    
    def _m_step(self, X, q_z, metadata):
        import numpy
        from scipy.special import psi
        from scipy.sparse import csr_matrix, hstack, lil_matrix, vstack
        
        industries = self.industries
        id2word = self.id2word
        _sstats = np.zeros(shape=(len(industries)+1, len(id2word)))

        for industry in range(len(industries)):
            ind = metadata==industry
            q_z_ind = q_z[ind]
            X_ind = X[ind]
            mlt = q_z_ind.multiply(X_ind)
            q_z_sm_bg = np.sum(mlt, axis=0)
            q_z_sm_ind = np.sum(X_ind - mlt, axis=0)

            # Background sstats
            _sstats[-1,:] = _sstats[-1,:] + q_z_sm_bg

            # Industry sstats
            _sstats[industry,:] = _sstats[industry,:] + q_z_sm_ind
        
        return _sstats
        
    def _update_phi(self):
        self.phi = self._sstats/np.sum(self._sstats, axis=1, keepdims=True)
        self._sstats = np.zeros(shape=(len(self.industries)+1, len(self.id2word)))

    def train(self, X, metadata, n_iter=50):
        industries = self.industries
        id2word = self.id2word
        self._sstats = np.zeros(shape=(len(industries)+1, len(id2word)))
        self.phi = np.random.dirichlet([1.0]*(len(id2word)), size=(len(industries)+1))

        for i in range(n_iter):
            if i%self.print_every == 0:
                print(i)
            q_z,_ = self._e_step(X, metadata)
            self._sstats = self._m_step(X, q_z, metadata)
            self._update_phi()
            
    def train_distributed(self, X_metadata_rdd, n_iter=50):
        industries = self.industries
        id2word = self.id2word
        self._sstats = np.zeros(shape=(len(industries)+1, len(id2word)))
        self.phi = np.random.dirichlet([1.0]*(len(id2word)), size=(len(industries)+1))
        
        for i in range(n_iter):
            print(i)
            self._sstats = (X_metadata_rdd
                             .mapPartitions(lambda u: [list(u)])
                             .map(lambda u: (vstack([el[0] for el in u], format="lil"), 
                                             np.array([el[1] for el in u])
                                             )
                                  )
                             .map(lambda line: (line[0], line[1], self._e_step(line[0], line[1])))
                             .map(lambda line: self._m_step(line[0], line[2][0], line[1]))
                             .reduce(lambda a,b: a+b)
                            )
            self._update_phi()

In [7]:
tst = SimpleLDA(word2id, industries, e_step_iter=50, print_every=1)

In [23]:
tst.train(X, np.array(company_industry), n_iter=2)

0
1


In [8]:
X_metadata_rdd = spark.sparkContext.parallelize(list(zip(X,company_industry)), 10)

In [None]:
tst.train_distributed(X_metadata_rdd)

In [56]:
q_z, q_theta = tst._e_step(X, np.array(company_industry))



In [45]:
tst._m_step(X, q_z, np.array(company_industry))

In [47]:
tst._update_phi()

In [50]:
ind=9
Z[ind]/100

0.16

In [51]:
alpha, beta = q_theta[ind]
(alpha-1)/(alpha+beta-2)

0.16203939683294755

In [11]:
def plot_topics(model):
    phi = model.phi
    id2word = model.id2word

    for i,topic in enumerate(phi):
        topic_name = model.industries[i] if i < len(model.industries) else "background"
        print("#"*20, topic_name, "#"*20)
        for prob,i in islice(sorted([(prob,i) for i,prob in enumerate(topic)], reverse=True), 20):
            print(f"{id2word[i]}:{prob} ", end="")
        print()
        print("#"*80)

In [12]:
plot_topics(tst)

#################### pbc ####################
pbc_word_4:0.07534915206131437 pbc_word_0:0.058010320558367745 pbc_word_1:0.03995175223929919 pbc_word_3:0.033103261374531845 pbc_word_2:0.02876636046517686 background_word_66:0.022131855208278744 background_word_9:0.02167824945386023 background_word_4:0.021507901615137017 background_word_60:0.020116755735080708 background_word_74:0.01854155487978196 background_word_97:0.01586223407996577 background_word_99:0.014991985130582744 background_word_80:0.014073735303582245 background_word_89:0.012785135468585946 background_word_84:0.012635328702914641 background_word_47:0.012228231119423946 background_word_29:0.012099746867287525 background_word_82:0.011961562089294586 background_word_22:0.01170929831443301 background_word_67:0.011591268913134079 
################################################################################
#################### rubber ####################
rubber_word_0:0.05292214429586776 rubber_word_3:0.05140293818137645 rubb

In [30]:
tst.phi[2]

array([4.77360101e-04, 5.08047892e-09, 4.95699909e-09, 5.95946442e-09,
       4.45524698e-09, 1.06981330e-08, 1.23745819e-08, 3.50080409e-03,
       4.46296889e-09, 1.95095651e-03, 5.05757351e-02, 4.23133669e-02,
       5.21601942e-02, 6.52181788e-02, 2.88467847e-02, 1.39022941e-02,
       4.11016556e-03, 7.69024571e-06, 4.85166323e-03, 1.43319551e-02,
       2.85845865e-03, 1.14937039e-02, 5.81033200e-03, 4.51968913e-04,
       9.11293664e-03, 4.82247234e-03, 4.15702485e-03, 1.19372560e-02,
       1.20693406e-02, 1.45752363e-02, 3.54375647e-03, 1.22369497e-02,
       1.39254050e-03, 1.20255880e-02, 8.62462024e-03, 7.72820938e-03,
       6.64722434e-03, 5.29770320e-03, 6.86598470e-03, 7.68281416e-04,
       8.96557302e-04, 8.97453427e-03, 1.34233474e-02, 1.43754160e-02,
       2.05959709e-02, 7.59492367e-03, 7.88963925e-03, 2.29643447e-03,
       1.12868267e-04, 3.63032997e-03, 1.05711903e-02, 8.55567312e-03,
       8.01980738e-03, 5.58202900e-03, 1.20722750e-02, 1.38479580e-02,
      

In [28]:
tst.phi[-1]

array([0.00255062, 0.00374005, 0.00143705, 0.01602743, 0.00730041,
       0.00575224, 0.01068066, 0.00208225, 0.00461285, 0.00737695,
       0.00445051, 0.01031931, 0.00862009, 0.00426616, 0.00575331,
       0.00270197, 0.01428479, 0.01895589, 0.0069042 , 0.00038038,
       0.00566887, 0.00488394, 0.00657537, 0.0221984 , 0.0174086 ,
       0.00028983, 0.01210961, 0.01028605, 0.00820914, 0.01900917,
       0.01156574, 0.00867146, 0.01760637, 0.01293167, 0.00373843,
       0.00449055, 0.01899139, 0.01627245, 0.00764862, 0.01483029,
       0.02352811, 0.01402309, 0.00346546, 0.00026798, 0.00638356,
       0.00831331, 0.00132077, 0.00892662, 0.01900948, 0.01524567,
       0.00753582, 0.00454427, 0.00617305, 0.00258613, 0.00291366,
       0.00991292, 0.00859201, 0.00969812, 0.00317606, 0.0101634 ,
       0.01005122, 0.01480652, 0.01329156, 0.00498139, 0.00473017,
       0.00881681, 0.00245789, 0.01447389, 0.01531067, 0.00529541,
       0.01432582, 0.0075923 , 0.00063904, 0.01877674, 0.00335

In [277]:
tst.phi[-1, 33]

0.009806433896111504

In [273]:
tst.id2word

{0: 'pbc_word_0',
 1: 'pbc_word_1',
 2: 'pbc_word_2',
 3: 'pbc_word_3',
 4: 'pbc_word_4',
 5: 'rubber_word_0',
 6: 'rubber_word_1',
 7: 'rubber_word_2',
 8: 'rubber_word_3',
 9: 'rubber_word_4',
 10: 'media_word_0',
 11: 'media_word_1',
 12: 'media_word_2',
 13: 'media_word_3',
 14: 'media_word_4',
 15: 'background_word_0',
 16: 'background_word_1',
 17: 'background_word_2',
 18: 'background_word_3',
 19: 'background_word_4',
 20: 'background_word_5',
 21: 'background_word_6',
 22: 'background_word_7',
 23: 'background_word_8',
 24: 'background_word_9',
 25: 'background_word_10',
 26: 'background_word_11',
 27: 'background_word_12',
 28: 'background_word_13',
 29: 'background_word_14',
 30: 'background_word_15',
 31: 'background_word_16',
 32: 'background_word_17',
 33: 'background_word_18',
 34: 'background_word_19',
 35: 'background_word_20',
 36: 'background_word_21',
 37: 'background_word_22',
 38: 'background_word_23',
 39: 'background_word_24',
 40: 'background_word_25',
 41: 'ba