In [1]:
import numpy as np
from scipy import sparse
import artm
from base_regularizer import BaseRegularizer

In [2]:
def generate_word_in_doc_freqs(words_count, docs_count, density=0.01, max_freq=5):
    """Generate a random sparse word-by-document frequency matrix.

    ``round(density * words_count * docs_count)`` (word, doc) positions are
    drawn uniformly *with replacement*, and each one is assigned a frequency
    drawn uniformly from ``1..max_freq``. When a position is drawn more than
    once the last draw wins, so the realised fraction of non-zero cells is at
    most ``density`` (and close to it when ``density`` is small).

    Uses the global NumPy RNG; call ``np.random.seed`` beforehand for
    reproducible output.

    Args:
        words_count: number of rows (vocabulary size).
        docs_count: number of columns (documents).
        density: target fraction of cells to fill, in ``[0, 1]``.
        max_freq: largest frequency value; must be >= 1.

    Returns:
        ``scipy.sparse.dok_matrix`` of shape ``(words_count, docs_count)``
        with integer dtype.

    Raises:
        ValueError: if ``density`` is outside ``[0, 1]`` or ``max_freq < 1``.
    """
    if not 0.0 <= density <= 1.0:
        raise ValueError("density must be in [0, 1], got {!r}".format(density))
    if max_freq < 1:
        raise ValueError("max_freq must be >= 1, got {!r}".format(max_freq))

    word_in_doc_freqs = sparse.dok_matrix((words_count, docs_count), dtype=int)

    # round(), not int(): int() truncates float products such as
    # 0.29 * 100 == 28.999999999999996 down to 28.
    samples_count = int(round(density * words_count * docs_count))

    # Draw every index/value at once instead of three RNG calls per cell.
    word_indices = np.random.choice(words_count, size=samples_count)
    doc_indices = np.random.choice(docs_count, size=samples_count)
    freqs = np.random.choice(max_freq, size=samples_count) + 1  # shift 0..max_freq-1 -> 1..max_freq

    for word_index, doc_index, freq in zip(word_indices, doc_indices, freqs):
        word_in_doc_freqs[word_index, doc_index] = freq

    return word_in_doc_freqs

In [3]:
class ZeroRegularizer(BaseRegularizer):
    """Regularizer that contributes nothing: value 0.0 and all-zero gradients.

    In this notebook it is the only regularizer passed to ``artm.ARTM``.
    NOTE(review): presumably this makes training equivalent to an
    unregularized model — confirm against how ``artm`` combines regularizers.
    """

    def __init__(self, words_count, docs_count, topics_count):
        """Pre-build zero gradient templates of the required shapes.

        Args:
            words_count: rows of the word-in-topic gradient.
            docs_count: columns of the topic-in-doc gradient.
            topics_count: number of topics (shared dimension of both gradients).
        """
        self._word_in_topics_probs_grad = np.zeros((words_count, topics_count))
        self._topic_in_doc_probs_grad = np.zeros((topics_count, docs_count))

    def get_value(self, word_in_topics_probs, topic_in_doc_probs):
        """Return the regularizer value, which is always 0.0."""
        return 0.0

    def get_gradient(self, word_in_topics_probs, topic_in_doc_probs):
        """Return zero gradients with respect to both probability matrices.

        Returns:
            Tuple ``(word_in_topics_grad, topic_in_doc_grad)`` of zero arrays
            with shapes ``(words_count, topics_count)`` and
            ``(topics_count, docs_count)``.
        """
        # Hand out copies: returning the stored arrays themselves would let a
        # caller that updates the gradient in place (e.g. ``grad += ...``)
        # corrupt them, so later calls would no longer return zeros.
        return (
            self._word_in_topics_probs_grad.copy(),
            self._topic_in_doc_probs_grad.copy(),
        )

In [4]:
# Synthetic experiment: the corpus is random noise, so the learned "topics"
# are not expected to be meaningful; this only exercises the training pipeline.
RANDOM_SEED = 42
words_count = 10000
docs_count = 100
topics_count = 10

# Seed the global NumPy RNG so the generated corpus is identical on every
# Restart & Run All. Whether ARTM's own initialisation is also covered depends
# on it drawing from the global NumPy RNG (NOTE(review): confirm in artm).
np.random.seed(RANDOM_SEED)

word_in_doc_freqs = generate_word_in_doc_freqs(words_count, docs_count)
# Word labels are simply the row indices rendered as strings.
words_list = np.array([str(i) for i in range(words_count)])

zero_regularizer = ZeroRegularizer(words_count, docs_count, topics_count)

# NOTE(review): the third argument looks like the list of per-regularizer
# coefficients (one weight per regularizer) — confirm against artm.ARTM.
artm_model = artm.ARTM(topics_count, [zero_regularizer], [1])

In [5]:
# Fit the model for 10 iterations; with verbose=True the log-likelihood is
# printed after every iteration (see the output below).
train_result = artm_model.train(
    word_in_doc_freqs,
    words_list,
    iterations_count=10,
    verbose=True,
)

iter#1: loglike=-249709.4873482039
iter#2: loglike=-244631.21241853482
iter#3: loglike=-237713.0817303567
iter#4: loglike=-229910.31317012556
iter#5: loglike=-222479.7952583897
iter#6: loglike=-216177.54461735833
iter#7: loglike=-211287.45543143657
iter#8: loglike=-207815.25438644484
iter#9: loglike=-205532.86666285899
iter#10: loglike=-203990.8659191058


In [6]:
# Display the top 10 word labels for each topic; the result shown below is a
# (topics_count, 10) array of strings from words_list (presumably ordered by
# in-topic probability — confirm). Because the corpus is random, these topics
# are not expected to be semantically coherent.
train_result.get_top_words_in_topics(10)

array([['443', '4689', '5924', '1295', '676', '7474', '6804', '2371',
        '5695', '3814'],
       ['4970', '4755', '6780', '4466', '5060', '501', '837', '8655',
        '8552', '8402'],
       ['7105', '4697', '6634', '8244', '8366', '4140', '6420', '4094',
        '7929', '3577'],
       ['3898', '7945', '7753', '8269', '4224', '1916', '9959', '1292',
        '5497', '8358'],
       ['7990', '8561', '5329', '1918', '1858', '8805', '7795', '9386',
        '7221', '797'],
       ['8229', '6019', '7219', '5716', '1443', '2189', '417', '1229',
        '6380', '1341'],
       ['8567', '1824', '7602', '6167', '4771', '7248', '8520', '1303',
        '5801', '3179'],
       ['1001', '7821', '2879', '5426', '4686', '5163', '906', '77',
        '3742', '5969'],
       ['9224', '7699', '52', '6909', '4923', '2640', '3567', '1566',
        '8916', '5631'],
       ['2005', '9542', '7009', '7739', '3994', '5365', '6430', '750',
        '8941', '9067']],
      dtype='<U4')