In [1]:
import numpy as np
from scipy import sparse
import artm
from base_regularizer import BaseRegularizer

In [2]:
def generate_word_in_doc_freqs(words_count, docs_count, density=0.01, max_freq=5):
    """Generate a random sparse word-by-document frequency matrix.

    ``int(density * words_count * docs_count)`` (word, document) positions are
    drawn uniformly at random *with replacement*, and each drawn position is set
    to a frequency drawn uniformly from ``1..max_freq``. Because a position can
    be drawn more than once (a later draw overwrites an earlier one), the
    fraction of non-zero cells is at most ``density`` and usually slightly less.

    Randomness comes from NumPy's global generator (``np.random``); call
    ``np.random.seed`` beforehand for reproducible output.

    Parameters
    ----------
    words_count : int
        Number of rows (words).
    docs_count : int
        Number of columns (documents).
    density : float, optional
        Target fraction of cells to fill, in ``[0, 1]``. Default ``0.01``.
    max_freq : int, optional
        Largest frequency value that can be assigned; must be >= 1.
        Default ``5``.

    Returns
    -------
    scipy.sparse.dok_matrix
        Integer matrix of shape ``(words_count, docs_count)``.

    Raises
    ------
    ValueError
        If ``density`` is outside ``[0, 1]`` or ``max_freq < 1``.
    """
    if not 0.0 <= density <= 1.0:
        raise ValueError("density must be in [0, 1], got {!r}".format(density))
    if max_freq < 1:
        raise ValueError("max_freq must be >= 1, got {!r}".format(max_freq))

    word_in_doc_freqs = sparse.dok_matrix((words_count, docs_count), dtype=int)

    samples_count = int(density * words_count * docs_count)
    for _ in range(samples_count):
        # Keep the draw order (word, doc, freq) so that a given seed produces the
        # same matrix as before the parameters were introduced.
        word_index = np.random.choice(words_count)
        doc_index = np.random.choice(docs_count)
        # choice(max_freq) is in 0..max_freq-1; shift to 1..max_freq so no
        # drawn cell is ever stored as an explicit zero.
        word_in_doc_freqs[word_index, doc_index] = np.random.choice(max_freq) + 1

    return word_in_doc_freqs

In [3]:
class ZeroRegularizer(BaseRegularizer):
    """Regularizer that contributes nothing: its value is 0.0 and both gradients are zero.

    Intended as a baseline. NOTE(review): with this regularizer, training
    presumably reduces to the unregularized objective; confirm against ``artm``.
    """

    def __init__(self, words_count, docs_count, topics_count):
        """Pre-allocate the zero gradient arrays once.

        Args:
            words_count: number of words (rows of the word-in-topic gradient).
            docs_count: number of documents (columns of the topic-in-doc gradient).
            topics_count: number of topics.
        """
        word_topic_shape = (words_count, topics_count)
        topic_doc_shape = (topics_count, docs_count)

        self._word_in_topics_probs_grad = np.zeros(word_topic_shape)
        self._topic_in_doc_probs_grad = np.zeros(topic_doc_shape)

    def get_value(self, word_in_topics_probs, topic_in_doc_probs):
        """Return the regularizer value; always 0.0, regardless of the inputs."""
        return 0.0

    def get_gradient(self, word_in_topics_probs, topic_in_doc_probs):
        """Return ``(word_in_topics_grad, topic_in_doc_grad)``, both all zeros.

        The same two pre-allocated arrays are returned on every call, so a caller
        that modifies them in place would affect later calls.
        NOTE(review): confirm that ``artm`` treats them as read-only.
        """
        gradients = (self._word_in_topics_probs_grad, self._topic_in_doc_probs_grad)
        return gradients

In [4]:
# Seed NumPy's global generator first: generate_word_in_doc_freqs draws from
# np.random, so the synthetic corpus depends on this value.
np.random.seed(0)

# Synthetic corpus dimensions.
words_count, docs_count, topics_count = 10000, 100, 10

word_in_doc_freqs = generate_word_in_doc_freqs(words_count, docs_count)
# Vocabulary: word i is labelled by the string form of its row index.
words_list = np.array(list(map(str, range(words_count))))

zero_regularizer = ZeroRegularizer(words_count, docs_count, topics_count)

# A single regularizer with weight 1. NOTE(review): presumably the second list
# holds one coefficient per regularizer; confirm against artm.ARTM.
artm_model = artm.ARTM(topics_count, [zero_regularizer], [1])

In [5]:
# Train for a fixed number of passes; verbose=True prints the log-likelihood
# after each iteration.
train_result = artm_model.train(
    word_in_doc_freqs,
    words_list,
    iterations_count=10,
    verbose=True,
)

iter#1: loglike=-253375.08594793288
iter#2: loglike=-248155.68445356327
iter#3: loglike=-240993.1539559371
iter#4: loglike=-232894.04538079945
iter#5: loglike=-225252.51292302832
iter#6: loglike=-218928.35373559463
iter#7: loglike=-214095.58940937347
iter#8: loglike=-210671.68357721512
iter#9: loglike=-208376.48245071087
iter#10: loglike=-206869.89796774142


In [6]:
# Display the highest-ranked words of each topic (presumably one row per topic).
top_words_per_topic = 10
train_result.get_top_words_in_topics(top_words_per_topic)

array([['7326', '5398', '9117', '2136', '3057', '5825', '1747', '3845',
        '1692', '7526'],
       ['2824', '6332', '7097', '2735', '2789', '9496', '4471', '5995',
        '902', '3957'],
       ['9537', '7037', '2874', '9457', '1695', '2953', '3248', '367',
        '3392', '6169'],
       ['1632', '7464', '3708', '2460', '8019', '2522', '682', '8904',
        '1609', '6327'],
       ['3419', '1221', '6538', '5146', '8160', '4479', '621', '6762',
        '9477', '2168'],
       ['1100', '1273', '6923', '9966', '7853', '999', '4459', '3713',
        '6223', '5479'],
       ['4740', '3764', '1372', '5311', '5607', '4980', '7415', '1671',
        '6583', '692'],
       ['9098', '2755', '3996', '9536', '2010', '1564', '1265', '5865',
        '4267', '3825'],
       ['7932', '4186', '1004', '8174', '3555', '0', '8986', '6514',
        '1983', '2003'],
       ['2317', '1565', '9275', '8777', '5807', '6572', '2547', '5213',
        '7543', '4420']],
      dtype='<U4')