Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Added sklearn wrapper for LDASeq model #1405

Merged
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
73cd770
added new file for LDASeq model's sklearn wrapper
chinmayapancholi13 Jun 9, 2017
4744c7b
PEP8 changes
chinmayapancholi13 Jun 9, 2017
d79f125
added 'transform' and 'partial_fit' methods
chinmayapancholi13 Jun 12, 2017
07efa33
added unit_tests for ldaseq model
chinmayapancholi13 Jun 12, 2017
d73838e
PEP8 changes
chinmayapancholi13 Jun 12, 2017
6e57c5f
PEP8 changes
chinmayapancholi13 Jun 12, 2017
c969c8b
refactored code acc. to composite design pattern
chinmayapancholi13 Jun 13, 2017
8b0cced
refactored wrapper and tests
chinmayapancholi13 Jun 14, 2017
ea9922e
removed 'self.corpus' attribute
chinmayapancholi13 Jun 14, 2017
8f88a10
updated 'self.__model' to 'self.gensim_model'
chinmayapancholi13 Jun 15, 2017
4f33248
updated 'fit' and 'transform' functions
chinmayapancholi13 Jun 15, 2017
8aa6898
updated 'testTransform' test
chinmayapancholi13 Jun 15, 2017
77a8672
updated 'testTransform' test
chinmayapancholi13 Jun 15, 2017
ad895a2
added 'NotFittedError' in 'transform' function
chinmayapancholi13 Jun 16, 2017
6f9929a
added 'testPersistence' and 'testModelNotFitted' tests
chinmayapancholi13 Jun 16, 2017
05b63e3
added description for 'docs' in docstring of 'transform'
chinmayapancholi13 Jun 16, 2017
3452e80
added 'testPipeline' test
chinmayapancholi13 Jun 18, 2017
492fbc6
PEP8 change
chinmayapancholi13 Jun 18, 2017
dec60e1
replaced 'text_lda' variable with 'text_ldaseq'
chinmayapancholi13 Jun 18, 2017
fd5fc90
updated 'testPersistence' test
chinmayapancholi13 Jun 19, 2017
e041431
set fixed seed in 'testPipeline' test
chinmayapancholi13 Jun 19, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions gensim/models/ldaseqmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class LdaSeqModel(utils.SaveLoad):
"""

def __init__(self, corpus=None, time_slice=None, id2word=None, alphas=0.01, num_topics=10,
initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10,
initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10,
random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100):
"""
`corpus` is any iterable gensim corpus
Expand Down Expand Up @@ -389,7 +389,7 @@ def dtm_vis(self, time, corpus):
for doc_no, doc in enumerate(corpus):
for pair in doc:
term_frequency[pair[0]] += pair[1]

vocab = [self.id2word[i] for i in range(0, len(self.id2word))]
# returns np arrays for doc_topic proportions, topic_term proportions, and document_lengths, term_frequency.
# these should be passed to the `pyLDAvis.prepare` method to visualise one time-slice of DTM topics.
Expand All @@ -398,7 +398,7 @@ def dtm_vis(self, time, corpus):

def dtm_coherence(self, time):
"""
returns all topics of a particular time-slice without probabilitiy values for it to be used
returns all topics of a particular time-slice without probabilitiy values for it to be used
for either "u_mass" or "c_v" coherence.
"""
coherence_topics = []
Expand Down Expand Up @@ -487,9 +487,9 @@ def compute_post_variance(self, word, chain_variance):

Fwd_Variance(t) ≡ E((beta_{t,w} − mean_{t,w})^2 |beta_{t} for 1:t)
= (obs_variance / fwd_variance[t - 1] + chain_variance + obs_variance ) * (fwd_variance[t - 1] + obs_variance)

Variance(t) ≡ E((beta_{t,w} − mean_cap{t,w})^2 |beta_cap{t} for 1:t)
= fwd_variance[t - 1] + (fwd_variance[t - 1] / fwd_variance[t - 1] + obs_variance)^2 * (variance[t - 1] - (fwd_variance[t-1] + obs_variance))
= fwd_variance[t - 1] + (fwd_variance[t - 1] / fwd_variance[t - 1] + obs_variance)^2 * (variance[t - 1] - (fwd_variance[t-1] + obs_variance))

"""
INIT_VARIANCE_CONST = 1000
Expand All @@ -506,7 +506,7 @@ def compute_post_variance(self, word, chain_variance):
c = 0
fwd_variance[t] = c * (fwd_variance[t - 1] + chain_variance)

# backward pass
# backward pass
variance[T] = fwd_variance[T]
for t in range(T - 1, -1, -1):
if fwd_variance[t] > 0.0:
Expand All @@ -516,7 +516,7 @@ def compute_post_variance(self, word, chain_variance):
variance[t] = (c * (variance[t + 1] - chain_variance)) + ((1 - c) * fwd_variance[t])

return variance, fwd_variance


def compute_post_mean(self, word, chain_variance):
"""
Expand All @@ -526,9 +526,9 @@ def compute_post_mean(self, word, chain_variance):

Fwd_Mean(t) ≡ E(beta_{t,w} | beta_ˆ 1:t )
= (obs_variance / fwd_variance[t - 1] + chain_variance + obs_variance ) * fwd_mean[t - 1] + (1 - (obs_variance / fwd_variance[t - 1] + chain_variance + obs_variance)) * beta

Mean(t) ≡ E(beta_{t,w} | beta_ˆ 1:T )
= fwd_mean[t - 1] + (obs_variance / fwd_variance[t - 1] + obs_variance) + (1 - obs_variance / fwd_variance[t - 1] + obs_variance)) * mean[t]
= fwd_mean[t - 1] + (obs_variance / fwd_variance[t - 1] + obs_variance) + (1 - obs_variance / fwd_variance[t - 1] + obs_variance)) * mean[t]

"""
T = self.num_time_slices
Expand All @@ -537,7 +537,7 @@ def compute_post_mean(self, word, chain_variance):
mean = self.mean[word]
fwd_mean = self.fwd_mean[word]

# forward
# forward
fwd_mean[0] = 0
for t in range(1, T + 1):
c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance)
Expand Down Expand Up @@ -644,7 +644,7 @@ def fit_sslm(self, sstats):

def compute_bound(self, sstats, totals):
"""
Compute log probability bound.
Compute log probability bound.
Forumula is as described in appendix of DTM by Blei. (formula no. 5)
"""
W = self.vocab_len
Expand Down Expand Up @@ -691,12 +691,12 @@ def compute_bound(self, sstats, totals):
val += term_2 + term_3 + ent - term_1

return val


def update_obs(self, sstats, totals):
"""
Function to perform optimization of obs. Parameters are suff_stats set up in the fit_sslm method.

TODO:
This is by far the slowest function in the whole algorithm.
Replacing or improving the performance of this would greatly speed things up.
Expand Down Expand Up @@ -725,7 +725,7 @@ def update_obs(self, sstats, totals):
if counts_norm < OBS_NORM_CUTOFF and norm_cutoff_obs is not None:
obs = self.obs[w]
norm_cutoff_obs = np.copy(obs)
else:
else:
if counts_norm < OBS_NORM_CUTOFF:
w_counts = np.zeros(len(w_counts))

Expand Down Expand Up @@ -753,10 +753,10 @@ def update_obs(self, sstats, totals):
self.obs[w] = obs

self.zeta = self.update_zeta()

return self.obs, self.zeta


def compute_mean_deriv(self, word, time, deriv):
"""
Used in helping find the optimum function.
Expand Down Expand Up @@ -842,7 +842,7 @@ def compute_obs_deriv(self, word, word_counts, totals, mean_deriv_mtx, deriv):
term1 = 0.0

deriv[t] = term1 + term2 + term3 + term4

return deriv
# endclass sslm

Expand Down Expand Up @@ -880,7 +880,7 @@ def update_phi(self, doc_number, time):
Update variational multinomial parameters, based on a document and a time-slice.
This is done based on the original Blei-LDA paper, where:
log_phi := beta * exp(Ψ(gamma)), over every topic for every word.

TODO: incorporate lee-sueng trick used in **Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001**.
"""
num_topics = self.lda.num_topics
Expand All @@ -904,7 +904,7 @@ def update_phi(self, doc_number, time):
v = np.logaddexp(v, log_phi_row[i])

# subtract every element by v
log_phi_row = log_phi_row - v
log_phi_row = log_phi_row - v
phi_row = np.exp(log_phi_row)
self.log_phi[n] = log_phi_row
self.phi[n] = phi_row
Expand Down Expand Up @@ -949,7 +949,7 @@ def compute_lda_lhood(self):

# to be used in DIM
# sigma_l = 0
# sigma_d = 0
# sigma_d = 0

lhood = gammaln(np.sum(self.lda.alpha)) - gammaln(gamma_sum)
self.lhood[num_topics] = lhood
Expand Down Expand Up @@ -979,8 +979,8 @@ def compute_lda_lhood(self):

return lhood

def fit_lda_post(self, doc_number, time, ldaseq, LDA_INFERENCE_CONVERGED = 1e-8,
lda_inference_max_iter = 25, g=None, g3_matrix=None, g4_matrix=None, g5_matrix=None):
def fit_lda_post(self, doc_number, time, ldaseq, LDA_INFERENCE_CONVERGED=1e-8,
lda_inference_max_iter=25, g=None, g3_matrix=None, g4_matrix=None, g5_matrix=None):
"""
Posterior inference for lda.
g, g3, g4 and g5 are matrices used in Document Influence Model and not used currently.
Expand All @@ -989,7 +989,7 @@ def fit_lda_post(self, doc_number, time, ldaseq, LDA_INFERENCE_CONVERGED = 1e-8,
self.init_lda_post()
# sum of counts in a doc
total = sum(count for word_id, count in self.doc)

model = "DTM"
if model == "DIM":
# if in DIM then we initialise some variables here
Expand Down Expand Up @@ -1067,7 +1067,7 @@ def f_obs(x, *args):
term2 = 0

# term 3 and 4 for DIM
term3 = 0
term3 = 0
term4 = 0

sslm.obs[word] = x
Expand Down Expand Up @@ -1095,7 +1095,7 @@ def f_obs(x, *args):
pass

if sslm.chain_variance > 0.0:

term1 = - (term1 / (2 * sslm.chain_variance))
term1 = term1 - mean[0] * mean[0] / (2 * init_mult * sslm.chain_variance)
else:
Expand All @@ -1122,4 +1122,3 @@ def df_obs(x, *args):
deriv = sslm.compute_obs_deriv_fixed(p.word, p.word_counts, p.totals, p.sslm, p.mean_deriv_mtx, deriv)

return np.negative(deriv)

95 changes: 95 additions & 0 deletions gensim/sklearn_integration/sklearn_wrapper_gensim_ldaseqmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <radimrehurek@seznam.cz>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""

import numpy as np
from sklearn.base import TransformerMixin, BaseEstimator

from gensim import models
from gensim.sklearn_integration import base_sklearn_wrapper


class SklLdaSeqModel(base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator):
"""
Base LdaSeq module
"""

def __init__(self, time_slice=None, id2word=None, alphas=0.01, num_topics=10,
initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10,
random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100):
"""
Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel
"""
self.gensim_model = None
self.time_slice = time_slice
self.id2word = id2word
self.alphas = alphas
self.num_topics = num_topics
self.initialize = initialize
self.sstats = sstats
self.lda_model = lda_model
self.obs_variance = obs_variance
self.chain_variance = chain_variance
self.passes = passes
self.random_state = random_state
self.lda_inference_max_iter = lda_inference_max_iter
self.em_min_iter = em_min_iter
self.em_max_iter = em_max_iter
self.chunksize = chunksize

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"time_slice": self.time_slice, "id2word": self.id2word,
"alphas": self.alphas, "num_topics": self.num_topics, "initialize": self.initialize,
"sstats": self.sstats, "lda_model": self.lda_model, "obs_variance": self.obs_variance,
"chain_variance": self.chain_variance, "passes": self.passes, "random_state": self.random_state,
"lda_inference_max_iter": self.lda_inference_max_iter, "em_min_iter": self.em_min_iter,
"em_max_iter": self.em_max_iter, "chunksize": self.chunksize}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklLdaSeqModel, self).set_params(**parameters)

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Calls gensim.models.LdaSeqModel
"""
self.gensim_model = models.LdaSeqModel(corpus=X, time_slice=self.time_slice, id2word=self.id2word,
alphas=self.alphas, num_topics=self.num_topics, initialize=self.initialize, sstats=self.sstats,
lda_model=self.lda_model, obs_variance=self.obs_variance, chain_variance=self.chain_variance,
passes=self.passes, random_state=self.random_state, lda_inference_max_iter=self.lda_inference_max_iter,
em_min_iter=self.em_min_iter, em_max_iter=self.em_max_iter, chunksize=self.chunksize)
return self

def transform(self, docs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chek case, when you create instance and call transform immediately (without fit), you need to raise exception like sklearn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, please add an example of docs param in docstring.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@menshikh-iv For checking if the model has been fitted, would it be a good idea to check if self.gensim_model is None or not? This approach would clearly give an error when fit hasn't been called before calling transform but this also allows the user to set the value of self.gensim_model through set_params function (or even as wrapper.gensim_model=...) and then call transform function, which makes sense for us to allow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely forgot about set_param, so, I think if you disable gensim_model in set_param, you can check model is None (it does not cover all cases, but covers the most obvious)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate the meaning of "disabling" gensim_model param from the function set_params?
Actually, gensim_model is a public attribute of the model so it can be set like ldaseq_wrapper.gensim_model = some_model, which is almost the same as using set_params function to set this value. So, checking whether self.gensim_model is None should be enough, right?
This would be like :

    def transform(self, docs):
        """
        Return the topic proportions for the documents passed.
        """
        if self.gensim_model is None:
            raise NotFittedError("This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method.")

        # The input as array of array
        check = lambda x: [x] if isinstance(x[0], tuple) else x
        ..........................................................................
        ..........................................................................
        ..........................................................................
        ..........................................................................

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, as a temporary option.

"""
Return the topic proportions for the documents passed.
"""
# The input as array of array
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for _ in range(0, len(docs))]

for k, v in enumerate(docs):
transformed_author = self.gensim_model[v]
# Everything should be equal in length
if len(transformed_author) != self.num_topics:
transformed_author.extend([1e-12] * (self.num_topics - len(transformed_author)))
X[k] = transformed_author

return np.reshape(np.array(X), (len(docs), self.num_topics))

def partial_fit(self, X):
raise NotImplementedError("'partial_fit' has not been implemented for the LDA Seq model")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LDA Seq model -> SklLdaSeqModel

Loading