MADLIB-1351: Added stopping criteria on perplexity to LDA #432
Conversation
@hpandeycodeit
@kaknikhil I will add the test cases soon to the PR.
# the Model and Output Table
if self.evaluate_every > 0:
    self.perplexity.append(
        get_perplexity('madlib',
The schema should not be hardcoded to 'madlib' in all the places that call get_perplexity. Use the schema_madlib variable instead.
done
prep_string = ""
prep_itr_str = ""
if len(self.perplexity) > 0:
    prep_string = ", " + py_list_to_sql_string(self.perplexity)
Use .format instead of +.
done
END;
$$ LANGUAGE plpgsql;

select assert(validate_perplexity() = TRUE, 'Perplexity calculation is wrong');
missing new line
done
'lda_training',
'lda_model',
'lda_output_data',
20, 5, 2, 10, 0.01, 2, .2);
Maybe add the argument name as a comment after each of these numbers to make it more readable, and also add a newline after each argument.
done
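For illustration, the reformatted call might look like the following sketch (the argument names are my reading of the existing lda_train signature plus the two parameters added in this PR, so treat them as assumptions):

```sql
SELECT lda_train(
    'lda_training',     -- data_table
    'lda_model',        -- model_table
    'lda_output_data',  -- output_data_table
    20,                 -- voc_size
    5,                  -- topic_num
    2,                  -- iter_num
    10,                 -- alpha
    0.01,               -- beta
    2,                  -- evaluate_every
    .2                  -- perplexity_tol
);
```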
# JIRA: MADLIB-1351
# If the perplexity_diff is less than the perplexity_tol,
# stop the iteration
if self.perplexity_diff < self.perplexity_tol:
We should also add a test case for this condition, either a unit test or a dev check.
'lda_output_data',
20, 5, 3, 10, 0.01, 1, .1);

SELECT assert(cardinality(perplexity) = 3, 'Perplexity calculation is wrong') from lda_model;
I don't think the cardinality function is available in GPDB 4.3. If not, then we should replace it with something like array_upper.
done.
---------- TEST CASES FOR PERPLEXITY ----------

drop table if exists lda_model, lda_output_data;
SELECT lda_train(
We should add a few more test cases. In all these cases we need to assert that we calculated the perplexity at the right iteration:
1. no_of_iterations % evaluate_every != 0
2. both no_of_iters and evaluate_every = 1
3. no_of_iterations % evaluate_every == 0 and no_of_iterations != evaluate_every
4. set evaluate_every to 0 and -1
5. perplexity_tol is reached before finishing all the iterations
Added tests for 2 and 4. There are a few outstanding tests (1, 3, and 5) for which I need some more clarity; I will discuss those with you.
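As a sketch, a dev-check for case 4 with evaluate_every = 0 could look like the following (table names are reused from the existing tests; the expectation that the perplexity column stays an empty array when evaluation is disabled is an assumption based on the PR description):

```sql
-- evaluate_every = 0 disables perplexity evaluation, so no perplexity
-- values should be recorded; a companion call with -1 would assert the same.
DROP TABLE IF EXISTS lda_model, lda_output_data;
SELECT lda_train(
    'lda_training',     -- data_table
    'lda_model',        -- model_table
    'lda_output_data',  -- output_data_table
    20,                 -- voc_size
    5,                  -- topic_num
    2,                  -- iter_num
    10,                 -- alpha
    0.01,               -- beta
    0,                  -- evaluate_every: disabled
    .1                  -- perplexity_tol
);
SELECT assert(perplexity = '{}', 'Perplexity should not be computed when evaluate_every = 0')
FROM lda_model;
```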
# JIRA: MADLIB-1351
# Calculate perplexity every evaluate_every iterations.
# Skip the calculation at the first iteration, as the model generated
# at the first iteration is a random model
I think we should be more verbose in this comment. Something like (but definitely not limited to):
For each iteration:
1. The model table is updated (for the first iteration it is the random model; for iteration > 1, the model that is updated is the one learnt in the previous iteration).
2. __lda_count_topic_agg is called.
3. Then lda_gibbs_sample is called, which learns and updates the model (the updated model is not passed to Python; the learnt model is updated in the next iteration).
Because of this workflow we can safely ignore the first perplexity value.
done
# Calculate perplexity every evaluate_every iterations.
# Skip the calculation at the first iteration, as the model generated
# at the first iteration is a random model
if it > self.evaluate_every and self.evaluate_every > 0 and (
- We already assert that evaluate_every >= 0 (line 514), so we don't need to repeat this check.
- Unless I am missing something, the whole if check can be simplified by skipping the perplexity calculation when it == 0, instead of using it and it - 1.
- We could move this code logic (lines 206 - 216) to its own function and unit test all the logic.
This logic appends (it - 1) to perplexity_iters: perplexity_iters[0] = it - 1. Also moved the code to a separate function.
@@ -445,6 +511,12 @@ def lda_train(schema_madlib, train_table, model_table, output_data_table, voc_si
            'invalid argument: positive real expected for alpha')
    _assert(beta is not None and beta > 0,
            'invalid argument: positive real expected for beta')
    _assert(evaluate_every is not None and evaluate_every >= 0,
The user docs for evaluate_every say "Set it to 0 or negative number to not evaluate perplexity in training at all", but this check will throw an exception for evaluate_every < 0.
I have removed this check as we are not calculating the perplexity for 0 or -1.
A few more general comments.
prep_string = ""
prep_itr_str = ""
if len(self.perplexity) > 1:
    prep_string = ", {0}".format(py_list_to_sql_string(self.perplexity))
Can we give these two variables better names? What does prep mean (perplexity?)?
changed the names here.
if it > self.evaluate_every and self.evaluate_every > 0 and (
        it - 1) % self.evaluate_every == 0:
    self.gen_output_data_table(work_table_in)
    perplexity = 0.0
this line is not needed
done.
perplexity = get_perplexity(self.schema_madlib,
                            self.model_table,
                            self.output_data_table)
self.perplexity_diff = abs(self.perplexity[
Refactor self.perplexity[len(self.perplexity) - 1] as self.perplexity[-1].
done.
@@ -288,3 +288,126 @@ CREATE OR REPLACE FUNCTION validate_lda_output() RETURNS integer AS $$
$$ LANGUAGE plpgsql;

select validate_lda_output();

---------- TEST CASES FOR PERPLEXITY ----------
consider adding a description at the beginning of each test case
One-liner headings are already present for every test case. Let me know if you think adding more detail is a good idea.
'lda_training',
'lda_model',
'lda_output_data',
20, 5, 2, 10, 0.01, 2, .2);
Same comment as before: maybe add the argument name as a comment after each of these numbers to make it more readable, and also add a newline after each argument.
'lda_output_data',
20, 5, 2, 10, 0.01, 2, .2);

SELECT assert(perplexity_iters = '{2}', 'Number of Perplexity iterations are wrong') from lda_model;
- We can also assert the length of the perplexity array.
- Since we cannot deterministically assert the perplexity values themselves, we should at least assert that all the perplexity values are > 0.
Added the test cases for above as discussed.
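One compact way to express the second check is an ALL comparison over the array; this is a sketch, not the exact assert that was committed (the dev-checks quoted later in this thread test perplexity[1] > 0):

```sql
-- Assert every computed perplexity value is positive, not just the first one.
SELECT assert(0 < ALL(perplexity), 'All perplexity values should be greater than 0')
FROM lda_model;
```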
.1 -- perplexity_tol
);

SELECT assert(array_upper(perplexity,1) = 3, 'Perplexity calculation is wrong') from lda_model;
We should assert the value of perplexity_iters here, and also that all perplexity values are > 0.
Added test for this as well.
.1 -- perplexity_tol
);

select assert(perplexity = '{}', 'Perplexity calculation is wrong') from lda_model;
If evaluate_every = 1, why do we expect the perplexity array to be empty?
Fixed this one.
@hpandeycodeit the Jenkins build is failing for the latest commit. Can you take a look?
Can you also add a test for perplexity_tol?
Fixed these.
select assert(abs(perplexity[2] - perplexity[1]) < 10, 'Perplexity tol is less than the perplexity difference') from lda_model;
Why are we checking for < 10 if the tol is 100?
I think we can add another assert to all the dev-check tests: assert that all the perplexity values are unique. What do you think?
Do you mean checking whether the length of the calculated perplexity values matches the number of distinct perplexity values?
Fixed the other issues.
No, I mean adding an assert to check that all the perplexity values are different.
Added test case for distinct perplexity values as discussed.
<dt>evaluate_every</dt>
<dd>int, optional (default=0). How often to evaluate perplexity. Set it to 0 or a negative number to not evaluate perplexity in training at all. Evaluating perplexity can help you check convergence during training, but it will also increase total training time. Evaluating perplexity in every iteration might increase training time up to two-fold.</dd>
<dt>perplexity_tol</dt>
<dd>float, optional (default=1e-1). Perplexity tolerance to stop iterating. Only used when evaluate_every is greater than 0.</dd>
maybe @fmcquillan99 can add a more verbose explanation here.
@@ -438,7 +444,9 @@ select assert(array_upper(perplexity_iters,1) <= 5, 'Perplexity iterations are d
select assert(perplexity[1] > 0 , 'Perplexity value should be greater than 0') from lda_model ;

-- Test to check if the perplexity_tol is greater than the difference between two perplexity iterations --
-- Test: If the difference between the last two iterations is less than the perplexity_tol, the training will stop --
Instead of saying "last two iterations", we can just say "If the perplexity difference between any two iterations is less than the perplexity_tol, we will stop training."
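A sketch of a dev-check that pins this behavior down (the evaluation schedule, perplexity computed from the second iteration onward, is taken from the earlier discussion in this thread, so the expected count is an assumption):

```sql
-- With evaluate_every = 1 and a deliberately huge tolerance, the first
-- comparison of two perplexity values should already satisfy the tolerance,
-- so training stops well before all 10 iterations (at most 9 values could
-- ever be recorded here).
DROP TABLE IF EXISTS lda_model, lda_output_data;
SELECT lda_train(
    'lda_training', 'lda_model', 'lda_output_data',
    20, 5, 10, 10, 0.01,
    1,      -- evaluate_every
    1e9     -- perplexity_tol: huge on purpose
);
SELECT assert(array_upper(perplexity, 1) < 9,
              'Training should stop early when perplexity_tol is large')
FROM lda_model;
```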
(1)
(2)
produces
(3)
Train
Predict on input data
I would expect this to be
(4)
produces
@fmcquillan99 I don't see a verbose output when I am running the above query. Are you running it in GPDB or Postgres?
@hpandeycodeit I was running on GP5 from psql.
So this is not in the LDA code; this is part of GPDB 5. If a table does not have stats, it prints out messages about the missing stats. Once the stats are updated (run ANALYZE on these tables) and the above SQL is run again, these messages disappear.
However, the same output table (generated by training) produces the following perplexity values, with the last perplexity value 179.380131412. Now running the perplexity function returns a value which matches the last perplexity value calculated by training. Thanks!
This is fixed.
(5)
I think
(6)
Please implement as per
Fixed this and num_iterations.
Re-test after latest commits
(1) Now looks like: OK
(2) produces: OK
(3) Train, then perplexity on input data, which matches the last value in the array for the training function. OK
(6) still has an issue. This should be the same results as:
which actually does work if you put
@fmcquillan99 Fixed the issue with the NULL handling on the last param.
@hpandeycodeit
I checked (6) after the last commit and it works now. So LGTM on functionality.
Added the test cases for 2 and 3. There was already a test case covering scenario 1.
LGTM
@@ -474,3 +474,89 @@ select assert(array_upper(perplexity_iters,1) = 2, 'Perplexity iterations are d
select assert(perplexity[1] > 0 , 'Perplexity value should be greater than 0') from lda_model ;
select assert(array_upper(ARRAY(Select distinct unnest(perplexity)),1) = array_upper(perplexity,1) , 'Perplexity values should be unique') from lda_model ;

-- Test for evaluate_every = 1 and 0 : In this case the iterations should not stop early --
@hpandeycodeit I can't find the test for evaluate_every = 0. Am I missing something?
When evaluate_every = NULL, it takes the default evaluate_every = 0, and in that case we don't calculate perplexity. We have a test case covering evaluate_every = NULL.
Prior to this commit, LDA had no stopping criteria: it ran for all the provided iterations. This commit calculates the perplexity on each iteration and, when the difference between the last two perplexity values is less than the perplexity_tol, stops iterating. These are the two new parameters added to the function:
```
evaluate_every INTEGER,
perplexity_tol DOUBLE PRECISION
```
There is also a change to the model output table. The following new columns are added:
1. perplexity (DOUBLE PRECISION[]): an array of perplexity values, as per the 'evaluate_every' parameter.
2. perplexity_iters (INTEGER[]): an array indicating the iterations for which perplexity was calculated.
LDA:
Added stopping criteria on perplexity to LDA.
MADLIB-1351
Currently, in LDA there are no stopping criteria; it runs for all the provided iterations.
This PR calculates the perplexity on each iteration and, when the difference between the last two perplexity values is less than the perplexity_tol, stops iterating.
The two new parameters added to the function are evaluate_every (INTEGER) and perplexity_tol (DOUBLE PRECISION).
There is a change to the model output table as well. It will have these two extra columns:
perplexity: an array of perplexity values, as per the 'evaluate_every' parameter.
perplexity_iters: an array indicating the iterations for which perplexity is calculated.
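A minimal usage sketch (table names borrowed from the PR's test cases; the argument order follows the existing lda_train signature with the two new parameters appended, so the exact values are illustrative only):

```sql
-- Train with perplexity evaluated every 2 iterations; stop early once the
-- difference between the last two perplexity values drops below 0.1.
DROP TABLE IF EXISTS lda_model, lda_output_data;
SELECT madlib.lda_train(
    'lda_training',     -- data_table
    'lda_model',        -- model_table
    'lda_output_data',  -- output_data_table
    20,                 -- voc_size
    5,                  -- topic_num
    10,                 -- iter_num
    10,                 -- alpha
    0.01,               -- beta
    2,                  -- evaluate_every (new)
    0.1                 -- perplexity_tol (new)
);

-- Inspect the two new columns in the model table.
SELECT perplexity, perplexity_iters FROM lda_model;
```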