In [7]:
import numpy as np
from scipy import sparse
import artm
from base_regularizer import BaseRegularizer
from smoothing_regularizer import SmoothingRegularizer
from combined_smoothing_sparsing_regularizer import CombinedSmoothingSparsingRegularizer

In [8]:
def generate_word_in_doc_freqs(words_count, docs_count, density=0.01, max_freq=5):
    """Generate a random sparse word-by-document frequency matrix.

    Draws ``int(density * words_count * docs_count)`` random (word, doc)
    cells using the global NumPy RNG (seed it with ``np.random.seed`` for
    reproducibility). Each drawn cell is assigned an integer frequency drawn
    uniformly from ``1..max_freq``. A cell drawn more than once keeps only its
    last value, so the realised number of non-zeros can be smaller than the
    requested ``density`` implies.

    Parameters
    ----------
    words_count : int
        Number of rows (vocabulary size).
    docs_count : int
        Number of columns (documents).
    density : float, optional
        Target fraction of non-zero cells, in ``[0, 1]``. Default 0.01.
    max_freq : int, optional
        Largest frequency value; must be >= 1. Default 5.

    Returns
    -------
    scipy.sparse.dok_matrix
        Integer matrix of shape ``(words_count, docs_count)``.

    Raises
    ------
    ValueError
        If ``density`` is outside ``[0, 1]`` or ``max_freq < 1``.
    """
    if not 0.0 <= density <= 1.0:
        raise ValueError("density must be in [0, 1], got {!r}".format(density))
    if max_freq < 1:
        raise ValueError("max_freq must be >= 1, got {!r}".format(max_freq))

    word_in_doc_freqs = sparse.dok_matrix((words_count, docs_count), dtype=int)

    samples_count = int(density * words_count * docs_count)
    # Keep the per-sample draw order (word, doc, freq) unchanged so that a given
    # seed reproduces the same matrix as the original hard-coded version.
    for _ in range(samples_count):
        word_index = np.random.choice(words_count)
        doc_index = np.random.choice(docs_count)
        word_in_doc_freqs[word_index, doc_index] = np.random.choice(max_freq) + 1

    return word_in_doc_freqs

In [9]:
class ZeroRegularizer(BaseRegularizer):
    """Regularizer that contributes nothing: value 0.0 and all-zero gradients.

    Used in this notebook as the baseline regularizer for the first ARTM run.
    Both gradient arrays are allocated once in ``__init__``, and the same array
    objects are returned from every ``get_gradient`` call.
    NOTE(review): callers must not modify the returned arrays in place, or
    later calls will no longer return zeros.
    """

    def __init__(self, words_count, docs_count, topics_count):
        # Shape (words_count, topics_count) for the word-in-topics gradient slot.
        word_topic_shape = (words_count, topics_count)
        # Shape (topics_count, docs_count) for the topic-in-doc gradient slot.
        topic_doc_shape = (topics_count, docs_count)
        self._word_in_topics_probs_grad = np.zeros(word_topic_shape)
        self._topic_in_doc_probs_grad = np.zeros(topic_doc_shape)

    def get_value(self, word_in_topics_probs, topic_in_doc_probs):
        """Return the regularizer value; always 0.0 regardless of inputs."""
        zero_penalty = 0.0
        return zero_penalty

    def get_gradient(self, word_in_topics_probs, topic_in_doc_probs):
        """Return the preallocated zero gradients as (word-topic, topic-doc)."""
        gradients = (self._word_in_topics_probs_grad, self._topic_in_doc_probs_grad)
        return gradients

In [10]:
# Fix the global NumPy seed so the synthetic corpus below (and presumably any
# global-RNG use inside artm.ARTM) is reproducible on Restart & Run All.
np.random.seed(seed=0)

# Synthetic corpus / model size.
words_count = 10000
docs_count = 100
topics_count = 10

word_in_doc_freqs = generate_word_in_doc_freqs(words_count, docs_count)
# Vocabulary tokens are the row indices rendered as strings ('0' .. '9999').
words_list = np.array([str(i) for i in range(words_count)])

# Baseline run: a regularizer with zero value and zero gradients.
zero_regularizer = ZeroRegularizer(words_count, docs_count, topics_count)

# Float coefficient, consistent with the other ARTM runs in this notebook.
artm_model = artm.ARTM(topics_count, [zero_regularizer], [1.])

In [11]:
# Fit the zero-regularizer baseline; verbose=True prints the log-likelihood per iteration.
train_result = artm_model.train(
    word_in_doc_freqs,
    words_list,
    iterations_count=10,
    verbose=True,
)

iter#1: loglike=-253375.08594793238
iter#2: loglike=-248155.68445356327
iter#3: loglike=-240993.1539559364
iter#4: loglike=-232894.0453807986
iter#5: loglike=-225252.51292302719
iter#6: loglike=-218928.3537355944
iter#7: loglike=-214095.5894093739
iter#8: loglike=-210671.68357721536
iter#9: loglike=-208376.4824507103
iter#10: loglike=-206869.89796774133


In [12]:
# Top words of each topic for the zero-regularizer baseline (one row per topic).
train_result.get_top_words_in_topics(10)

array([['7326', '5398', '9117', '2136', '3057', '5825', '1747', '3845',
        '1692', '7526'],
       ['2824', '6332', '7097', '2735', '2789', '9496', '4471', '5995',
        '902', '3957'],
       ['9537', '7037', '2874', '9457', '1695', '2953', '3248', '367',
        '3392', '6169'],
       ['1632', '7464', '3708', '2460', '8019', '2522', '682', '8904',
        '1609', '6327'],
       ['3419', '1221', '6538', '5146', '8160', '4479', '621', '6762',
        '9477', '2168'],
       ['1100', '1273', '6923', '9966', '7853', '999', '4459', '3713',
        '6223', '5479'],
       ['4740', '3764', '1372', '5311', '5607', '4980', '7415', '1671',
        '6583', '692'],
       ['9098', '2755', '3996', '9536', '2010', '1564', '1265', '5865',
        '4267', '3825'],
       ['7932', '4186', '1004', '8174', '3555', '0', '8986', '6514',
        '1983', '2003'],
       ['2317', '1565', '9275', '8777', '5807', '6572', '2547', '5213',
        '7543', '4420']],
      dtype='<U4')

### Smoothing regularizer

In [13]:
# Smoothing regularizer with uniform 1e-4 vectors: beta has one entry per word,
# alpha one entry per topic.
smoothing_regularizer = SmoothingRegularizer(
    beta_0=0.5,
    alpha_0=0.5,
    beta=np.full(words_count, 1e-4),
    alpha=np.full(topics_count, 1e-4),
    num_topics=topics_count,
    num_words=words_count,
    num_docs=docs_count,
)

In [14]:
# Same topic count as the baseline, now with the smoothing regularizer (coefficient 1.0).
artm_model = artm.ARTM(
    topics_count,
    [smoothing_regularizer],
    [1.0],
)

In [15]:
# Fit the smoothing-regularized model; verbose=True prints the log-likelihood per iteration.
train_result = artm_model.train(
    word_in_doc_freqs,
    words_list,
    iterations_count=10,
    verbose=True,
)

iter#1: loglike=-253219.94032536348
iter#2: loglike=-247916.4319508186
iter#3: loglike=-240702.23933397373
iter#4: loglike=-232505.00464822774
iter#5: loglike=-224862.14859385107
iter#6: loglike=-218885.28773205075
iter#7: loglike=-214763.95457877737
iter#8: loglike=-212025.89114727278
iter#9: loglike=-210116.20784951735
iter#10: loglike=-208688.23683997602


### Combined smoothing and sparsing regularizer

In [17]:
# Combined smoothing/sparsing regularizer with uniform 1e-4 vectors (beta has one
# entry per word, alpha one per topic). The topics are split into two halves: the
# first half is passed as domain-specific topics, the second half as background
# topics. The split is derived from topics_count (instead of hard-coded 5 / 10),
# so it stays valid if the topic count changes. For topics_count=10 it is 0..4
# and 5..9, exactly as before.
sparse_smooth_reg = CombinedSmoothingSparsingRegularizer(
    beta_0=0.5,
    alpha_0=0.5,
    beta=np.full(words_count, 1e-4),
    alpha=np.full(topics_count, 1e-4),
    num_topics=topics_count,
    num_words=words_count,
    num_docs=docs_count,
    domain_specific_topics=np.arange(topics_count // 2),
    background_topics=np.arange(topics_count // 2, topics_count),
)

In [18]:
# Same topic count, now with the combined smoothing/sparsing regularizer (coefficient 1.0).
artm_model = artm.ARTM(
    topics_count,
    [sparse_smooth_reg],
    [1.0],
)

In [19]:
# Fit the combined-regularizer model; verbose=True prints the log-likelihood per iteration.
train_result = artm_model.train(
    word_in_doc_freqs,
    words_list,
    iterations_count=10,
    verbose=True,
)

iter#1: loglike=-253372.61544487116
iter#2: loglike=-248071.50356896667
iter#3: loglike=-240742.03783599974
iter#4: loglike=-232577.38516974295
iter#5: loglike=-225318.40019621555
iter#6: loglike=-219681.819856372
iter#7: loglike=-215474.55468768373
iter#8: loglike=-212378.0939240878
iter#9: loglike=-210210.30299463565
iter#10: loglike=-208686.06198959038


In [20]:
# Top words of each topic for the combined smoothing/sparsing model (one row per topic).
train_result.get_top_words_in_topics(10)

array([['4728', '4895', '4634', '5339', '1221', '553', '8535', '1709',
        '2126', '2953'],
       ['2756', '8781', '6332', '5388', '1562', '7233', '6700', '2344',
        '1175', '1350'],
       ['2488', '7267', '4479', '8201', '9576', '3057', '1111', '4681',
        '9457', '9181'],
       ['1747', '9089', '6953', '4409', '5119', '8116', '6283', '1498',
        '349', '6385'],
       ['8741', '8315', '3323', '3918', '6337', '8809', '2546', '1115',
        '2426', '8717'],
       ['2977', '3923', '2736', '6572', '6679', '8517', '9191', '6624',
        '4766', '3963'],
       ['4922', '6054', '2208', '7642', '6583', '1206', '2', '3753',
        '5713', '296'],
       ['1208', '6458', '9076', '8897', '9098', '4028', '9718', '2641',
        '6359', '9658'],
       ['3091', '4600', '1291', '4950', '2714', '1609', '5758', '7521',
        '1454', '2490'],
       ['4963', '9908', '40', '8269', '7326', '3644', '420', '2558',
        '2689', '7835']],
      dtype='<U4')