
Implement censored log-probabilities via the Clip Op #22

Merged
merged 2 commits into from Nov 10, 2021

Conversation

ricardoV94
Contributor

@ricardoV94 ricardoV94 commented Jun 22, 2021

This PR implements logprob for censored (clipped) RVs.

import aesara.tensor as at
import aeppl as ppl

x_rv = at.random.normal(0, 1)
cens_x_rv = at.clip(x_rv, -1, 1)
cens_x = cens_x_rv.type()
logp = ppl.joint_logprob({cens_x_rv: cens_x})
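For intuition, the censored log-probability being implemented can be sketched with scipy (a hypothetical scalar helper, not the aeppl implementation): interior points keep the base density, while each bound absorbs the probability mass of its tail.

```python
import numpy as np
from scipy import stats

def censored_normal_logp(value, mu, sigma, lower, upper):
    """Log-probability of y = clip(x, lower, upper) for x ~ Normal(mu, sigma).

    P(y == lower) = CDF(lower) and P(y == upper) = 1 - CDF(upper);
    strictly interior values keep the normal density.
    """
    dist = stats.norm(mu, sigma)
    if value < lower or value > upper:
        return -np.inf
    if value == lower:
        return dist.logcdf(lower)
    if value == upper:
        return dist.logsf(upper)  # log(1 - CDF(upper))
    return dist.logpdf(value)
```

With `lower = -1` and `upper = 1` this mirrors the `at.clip(x_rv, -1, 1)` example above.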

I placed the new methods and tests inside truncation.py, expecting that this file will also contain the methods for truncated RVs in the future.


Some things are still not working well / missing:

  • Issues with broadcasting / sizes (see xfailed tests)
  • Propagate names in the logp graph
  • Canonicalize set_subtensor to clip: x[x > ub] = ub -> clip(x, x, ub) (will do in another PR)
  • Add tests for the logcdf methods
  • Compute test values for new nodes (seems not to be necessary; tests pass with compute_test_value="raise")
  • Explore whether CensoredRVs should be created even when they don't have a direct value variable (e.g., so that they can work as input to other derived RVs) (postponed)

@codecov

codecov bot commented Jun 22, 2021

Codecov Report

Merging #22 (3300159) into main (18275e2) will decrease coverage by 0.07%.
The diff coverage is 94.39%.

❗ Current head 3300159 differs from pull request most recent head 4830635. Consider uploading reports for the commit 4830635 to get more accurate results

@@            Coverage Diff             @@
##             main      #22      +/-   ##
==========================================
- Coverage   94.92%   94.84%   -0.08%     
==========================================
  Files           9        8       -1     
  Lines        1260     1106     -154     
  Branches      164      133      -31     
==========================================
- Hits         1196     1049     -147     
+ Misses         31       27       -4     
+ Partials       33       30       -3     
Impacted Files Coverage Δ
aeppl/logprob.py 98.54% <85.18%> (-1.06%) ⬇️
aeppl/truncation.py 97.40% <97.40%> (ø)
aeppl/opt.py 93.93% <100.00%> (+1.08%) ⬆️
aeppl/transforms.py 93.87% <0.00%> (-0.54%) ⬇️
aeppl/mixture.py 98.96% <0.00%> (-0.02%) ⬇️
aeppl/abstract.py 100.00% <0.00%> (ø)
aeppl/cumsum.py
aeppl/scan.py
... and 1 more


@ricardoV94 ricardoV94 changed the title Censor rvs Censored RVs Jun 22, 2021
@brandonwillard brandonwillard added the important This label is used to indicate priority over things not given this label label Jun 22, 2021
@ricardoV94 ricardoV94 force-pushed the censor_rvs branch 2 times, most recently from c5e7da1 to 46fac13 Compare June 25, 2021 19:51
@brandonwillard
Member

To make broadcasting work, you should be able to use something like at.broadcast_arrays(value, lower_bound, upper_bound) in the log-probability calculation.
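The suggestion can be illustrated with NumPy, whose `np.broadcast_arrays` behaves analogously to `at.broadcast_arrays` (the shapes below are made up for illustration):

```python
import numpy as np

# All outputs share the common broadcasted shape, so elementwise
# comparisons of the value against the lower/upper bounds line up.
value = np.zeros((3, 1))
lower = np.array([-1.0, 0.0, 1.0, 2.0])  # shape (4,)
upper = 2.0                              # scalar

value_b, lower_b, upper_b = np.broadcast_arrays(value, lower, upper)
print(value_b.shape, lower_b.shape, upper_b.shape)  # (3, 4) (3, 4) (3, 4)
```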

@brandonwillard
Member

  • Canonicalize set_subtensors to clip x[x>ub] = ub -> clip(x, x, ub)

That's a very interesting idea! Can it be combined with other indices (e.g. x[idx_1, ..., x > ub, ..., idx_N] = ...)?

Regardless, don't make this PR conditional on such extensions. Let's get at.clip working and merged first.
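The proposed canonicalization rests on a simple identity, shown here with NumPy for the upper-bound-only case (using `x` itself as the lower bound leaves the lower side untouched, mirroring `at.clip(x, x, ub)`):

```python
import numpy as np

x = np.array([-2.0, 0.5, 3.0])
ub = 1.0

# Set-subtensor form: overwrite entries above the bound in place
y = x.copy()
y[y > ub] = ub

# Clip form: max(x, x) is a no-op, then values are capped at ub
z = np.clip(x, x, ub)

assert np.array_equal(y, z)  # both give [-2.0, 0.5, 1.0]
```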

@brandonwillard
Member

brandonwillard commented Jun 29, 2021

Tip: if you assign names to your test Variables, it can be much easier to read the aesara.dprint output when debugging, especially when a large graph is produced that corresponds to something in particular (e.g. the logprob of a specific RandomVariable output). The same goes for Variables created in the functions you're working on (e.g. in censor_logprob and/or censor_rvs).

Also, don't forget about test values! They will cause errors to arise during graph construction (i.e. where the symbolic objects themselves are defined). That combined with aesara.config.print_test_value = True can make things a lot easier.

@ricardoV94
Contributor Author

ricardoV94 commented Jun 29, 2021

To make broadcasting work, you should be able to use something like at.broadcast_arrays(value, lower_bound, upper_bound) in the log-probability calculation.

I still couldn't fix the failing test_broadcasted_censoring tests.
I seem to be losing the reference to lb_rv before I even get to the new censor_rvs opt.

lb_rv = at.random.uniform(0, 1, name="lb_rv")
x_rv = at.random.normal(0, 2, name="x_rv")
cens_x_rv = at.clip(x_rv, lb_rv, [1, 1])
cens_x_rv.name = "cens_x_rv"

lb = lb_rv.type()
lb.name = "lb"
cens_x = cens_x_rv.type()
cens_x.name = "cens_x"

logp = joint_logprob(cens_x_rv, {cens_x_rv: cens_x, lb_rv: lb})
assert_no_rvs(logp)

These are the fgraphs printouts before and after the optimization phase in joint_logprob:

Before:
Elemwise{clip,no_inplace} [id A] 'cens_x_rv'   4
 |InplaceDimShuffle{x} [id B] ''   3
 | |normal_rv.1 [id C] 'x_rv'   2
 |   |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F64081A3440>) [id D]
 |   |TensorConstant{[]} [id E]
 |   |TensorConstant{11} [id F]
 |   |TensorConstant{0} [id G]
 |   |TensorConstant{2} [id H]
 |InplaceDimShuffle{x} [id I] ''   1
 | |uniform_rv.1 [id J] 'lb_rv'   0
 |   |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F642C9B3A40>) [id K]
 |   |TensorConstant{[]} [id L]
 |   |TensorConstant{11} [id M]
 |   |TensorConstant{0} [id N]
 |   |TensorConstant{1} [id O]
 |TensorConstant{(2,) of 1} [id P]

After:
censored_rv.1 [id A] 'cens_x_rv'   5
 |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F64081A3440>) [id B]
 |TensorConstant{[]} [id C]
 |TensorConstant{11} [id D]
 |InplaceDimShuffle{x} [id E] ''   4
 | |TensorConstant{0} [id F]
 |InplaceDimShuffle{x} [id G] ''   3
 | |TensorConstant{2} [id H]
 |uniform_rv.1 [id I] ''   2
 | |RandomStateSharedVariable(<RandomState(MT19937) at 0x7F642C9B3A40>) [id J]
 | |TensorConstant{[]} [id K]
 | |TensorConstant{11} [id L]
 | |InplaceDimShuffle{x} [id M] ''   1
 | | |TensorConstant{0} [id N]
 | |InplaceDimShuffle{x} [id O] ''   0
 |   |TensorConstant{1} [id P]
 |TensorConstant{(2,) of 1} [id Q]

That uniform_rv.1 in the "After graph" is no longer related to the original one.
This is caused by the local_dimshuffle_rv_lift. During interactive debugging:

lower_bound
# uniform_rv.out

lower_bound in rv_map_feature.rv_values
# False

lower_bound.owner.tag
# scratchpad{'imported_by': ['local_dimshuffle_rv_lift']}

Also, that opt does not seem to propagate the variable name:

lower_bound.name
# None

@ricardoV94
Contributor Author

Also, don't forget about test values! They will cause errors to arise during graph construction (i.e. where the symbolic objects themselves are defined). That combined with aesara.config.print_test_value = True can make things a lot easier.

If I set the flag to "warn", even a simple uniform logp raises a lot of "Cannot compute test value..." warnings for every node in the logp. Is this something we need to address?

@aesara.config.change_flags(compute_test_value='warn')
def test_compute_test_value():
    x_rv = at.random.uniform(-1, 1)
    x = x_rv.type()
    logp = joint_logprob(x_rv, {x_rv: x})
=============================== warnings summary ===============================
tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{ge,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{-1}) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{le,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{1}) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{ge,no_inplace}.0) of Op Elemwise{and_,no_inplace}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{second,no_inplace}(<TensorType(float64, scalar)>, Elemwise{neg,no_inplace}.0) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/venv/lib/python3.8/site-packages/aesara/graph/op.py:272: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{and_,no_inplace}.0) of Op Elemwise{switch,no_inplace}(Elemwise{and_,no_inplace}.0, Elemwise{second,no_inplace}.0, TensorConstant{-inf}) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{second,no_inplace}(<TensorType(float64, scalar)>, Elemwise{neg,no_inplace}.0) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{le,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{1}) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (<TensorType(float64, scalar)>) of Op Elemwise{ge,no_inplace}(<TensorType(float64, scalar)>, TensorConstant{-1}) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{ge,no_inplace}.0) of Op Elemwise{and_,no_inplace}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0) missing default value
    compute_test_value(node)

tests/test_truncation.py::test_compute_test_value
  /home/ricardo/Documents/Projects/aeppl/aeppl/joint_logprob.py:170: UserWarning: Warning, Cannot compute test value: input 0 (Elemwise{and_,no_inplace}.0) of Op Elemwise{switch,no_inplace}(Elemwise{and_,no_inplace}.0, Elemwise{second,no_inplace}.0, TensorConstant{-inf}) missing default value
    compute_test_value(node)

-- Docs: https://docs.pytest.org/en/stable/warnings.html
======================== 1 passed, 10 warnings in 0.97s ========================

Process finished with exit code 0
PASSED                       [100%]

@ricardoV94 ricardoV94 force-pushed the censor_rvs branch 8 times, most recently from 9dd9052 to cfc3fbe Compare June 29, 2021 13:09
@brandonwillard
Member

Is this something we need to address?

It looks like you need to set a test value for x, or use x_rv.clone().

@ricardoV94
Contributor Author

ricardoV94 commented Jun 30, 2021

I added tests for the logcdf methods and checked whether the new opts work with compute_test_value="raise".

The only things missing are the broadcasting / local_dimshuffle_rv_lift issues I described in #22 (comment)

@brandonwillard
Member

Is the broadcasting still the blocking issue/change here?

@ricardoV94
Contributor Author

Is the broadcasting still the blocking issue/change here?

No, I was trying an alternative that did not involve subclassing from RandomVariable. I'll try to get this back on track soon.

@ricardoV94 ricardoV94 force-pushed the censor_rvs branch 2 times, most recently from 38f6d75 to d3fe4c0 Compare October 17, 2021 16:14
@ricardoV94 ricardoV94 marked this pull request as ready for review October 17, 2021 16:17
Member

@brandonwillard brandonwillard left a comment

Aside from the use of tag.ignore_logprob, this looks great.

Let's find a way to avoid using that feature, especially since nothing should be depending on it now and it's slated to be removed entirely.

Otherwise, if you want, submit the logcdf additions as a separate PR and we can push those through sooner.

aeppl/joint_logprob.py Outdated Show resolved Hide resolved
aeppl/truncation.py Outdated Show resolved Hide resolved
Comment on lines 218 to 223
# Filter out missing terms of variables with ignore_logprob
value_rvs = {v: k for k, v in updated_rv_values.items()}
for missing in tuple(missing_value_terms):
    rv_of_missing = value_rvs.get(missing, None)
    if rv_of_missing and getattr(rv_of_missing.tag, "ignore_logprob", False):
        missing_value_terms.remove(missing)
Contributor Author

If/when we abandon ignore_logprob, this section can be removed. I had to add it for backwards compatibility with some tests in test_joint_logprob.

Member

It looks like you added two tests in a different commit that explicitly require this functionality: test_fail_multiple_censored_single_base and test_fail_base_and_censored_have_values. If you remove those, nothing will depend on this functionality and the commit can be removed.

This needs to be done before merging, if only because the changes are unrelated to censoring.

Contributor Author

When I said something could be removed later, I meant the specific ignore_logprob flag logic, not the raising of a RuntimeError when a variable is missing.

While developing the censored variables, it would often fail silently and just return a graph with the Aesara clips unchanged and/or fewer terms than requested. This seems like a good way to catch such failures.

Contributor Author

Is there a reason why you don't want this type of check at the end of factorized_joint_logprob?

It's trivial to manually do the same check in those new tests I added, but this explicit check might be valuable enough to have as a default.

Member

When I said something could be removed later, I meant the specific ignore_logprob flag logic, not the raising a RuntimeError if a variable is missing.

While developing the censored variables it would often fail silently and just return a graph with aesara clips unchanged and/or less terms than requested. This seems like a good way to catch such failures.

From a simple development and design perspective, it sounds like you're addressing a testing-specific issue within a feature implementation, and that's generally not good.

Otherwise, if something is failing silently the first question is "What's failing?". Is it the factorized_joint_logprob loop? If not, the failure should be addressed closer to where its primary logic/code resides, and that doesn't appear to be here.

Is there a reason why you don't want this type of check at the end of factorized_joint_logprob?

It's trivial to manually do the same check in those new tests I added, but this explicit check might be valuable enough to have as a default.

The reason why I don't want these kinds of unrelated changes is that their inclusion makes a PR contingent on additional review work and discussions.

It takes extra time and effort to go through logic like this and determine its relevance, risk, etc. These are things that need to be done within issues and/or at the outset of a PR (e.g. the premise/description of a PR) in order to avoid delaying the inclusion of any agreed upon changes.

Contributor Author

From a simple development and design perspective, it sounds like you're addressing a testing-specific issue within a feature implementation, and that's generally not good.

What happened was that the test revealed that such a call to factorized_joint_logprob would return with a missing logp term and zero complaints, so I decided to add an explicit check there.

It was not for the sake of the test as that had already been solved and could be tested explicitly inside the test itself.

It was meant for further development, when we introduce rewrites for Ops that are otherwise valid in logp graphs. It's also a conceptually obvious check to me: a user requests a dictionary of rv_values, and we make sure we return a dictionary with an item for each original pair.

I don't mind splitting this into another PR
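The kind of completeness check under discussion can be sketched in a few lines of plain Python (`check_all_terms_returned` is a hypothetical helper, not part of aeppl):

```python
def check_all_terms_returned(rv_values, logp_terms):
    """Raise if any requested value variable got no logp term.

    rv_values maps random variables to their value variables;
    logp_terms maps value variables to derived logp graphs, as a
    factorized_joint_logprob-style function would return them.
    """
    missing = [vv for vv in rv_values.values() if vv not in logp_terms]
    if missing:
        raise RuntimeError(f"No logprob term was derived for: {missing}")
```

A request where some value variable silently got no term would then raise instead of passing unnoticed.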

Contributor Author

@ricardoV94 ricardoV94 Nov 9, 2021

I removed the commit but I am not very happy with the fact that this does not raise an error inside factorized_joint_logprob:

def test_fail_base_and_censored_have_values():
    """Test failure when both base_rv and clipped_rv are given value vars"""
    x_rv = at.random.normal(0, 1)
    cens_x_rv = at.clip(x_rv, x_rv, 1)
    cens_x_rv.name = "cens_x"

    x_vv = x_rv.clone()
    cens_x_vv = cens_x_rv.clone()
    logp_terms = factorized_joint_logprob({cens_x_rv: cens_x_vv, x_rv: x_vv})
    assert cens_x_vv not in logp_terms

Member

Again, why should the conditions be checked and the error be raised in factorized_joint_logprob specifically?

The two terms involved are very specific to the censored variable logic, and it looks like the error could've been initiated in find_censored_rvs—i.e. where all the relevant terms are identified and used. This approach could also short-circuit all the unnecessary down-stream logic, no?

You already have a warning there to that effect, so what do we gain by having an exception in factorized_joint_logprob?

Contributor Author

@ricardoV94 ricardoV94 Nov 9, 2021

I don't think it should be the rewrite's responsibility to raise a failure. There may be another rewrite (e.g., added by users of the library) that can handle the conversion.

The bigger problem is that nothing happens if you ask for a graph that we don't know how to handle. That is not specific to censored RVs; we just haven't tested it. For instance, this snippet does not complain at all:

import aesara.tensor as at
import aeppl

x_rv = at.random.normal(name='x')
y_rv = at.cos(x_rv)

x_vv = x_rv.clone()
y_vv = y_rv.clone()

logprob_dict = aeppl.factorized_joint_logprob({x_rv: x_vv, y_rv: y_vv})
logprob_dict
# {x: x_logprob}

This snippet would be more realistic about what a user might try, but it is now failing for a different reason (#87):

logprob_dict = aeppl.factorized_joint_logprob({y_rv: y_vv})

Member

The bigger problem is that nothing happens if you ask for a graph that we don't know how to handle. That is not specific to censoredRVs, we just haven't tested it. For instance, this snippet does not complain at all:

If our rewrites don't know how to handle something, that's not necessarily a problem. As a matter of fact, we expect that they won't know how to handle more things than they do.

The assumption underlying your statements and example seems to be that you know what should be done relative to specific rewrites, and this is what makes it reasonable to handle rewrite-relevant errors in the rewrite logic.

In other words, if you know an error/warning should be raised because a value variable specification is redundant, you only really know that because you also know that there's a specific rewrite that determines which value variables are and aren't relevant.

Otherwise, a generic warning for "unused" variable/value mappings is simply an interface choice that might help inform people of issues elsewhere and/or bad assumptions (e.g. that the resulting graph will depend on certain terms), but that's all.

All this relates directly to #85.

Member

@brandonwillard brandonwillard left a comment

You can use the same tests as before, but change the pytest.raises to look for the warnings instead.

@brandonwillard brandonwillard merged commit ec6535f into aesara-devs:main Nov 10, 2021
@brandonwillard brandonwillard added the op-probability Involves the implementation of log-probabilities for Aesara `Op`s label Nov 10, 2021
@brandonwillard brandonwillard changed the title Censored RVs Implement censored log-probabilities via the Clip Op Nov 10, 2021
@ricardoV94 ricardoV94 deleted the censor_rvs branch November 29, 2021 11:15