Improve numerical stability of censored logps #156

Merged: 1 commit, Jul 22, 2022

aseyboldt
Contributor

Using at.log(1 - at.exp(x)) instead of the more stable at.log1mexp(x) can lead to numerical issues in the gradient, because the stabilization rewrites are only applied after the gradient graph has been built. For example:

import aesara
import aesara.tensor as at
import numpy as np
import pymc as pm  # needed for pm.logcdf and pm.Weibull below

x = at.scalar("x")
a = at.scalar("a")
b = at.scalar("b")

logcdf = pm.logcdf(pm.Weibull.dist(a, b), x)
logccdf = at.log(1 - at.exp(logcdf))
logccdf2 = at.log1mexp(logcdf)

logcdf.name = "logcdf"
logccdf.name = "logccdf"

func = aesara.function([x, a, b], [at.grad(logccdf, a), at.grad(logccdf2, a)])

func(50, 1, 1)
# [array(-inf), array(-195.60115027)]
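
The -inf can be traced by hand. What follows is a minimal sketch in plain Python, not the Aesara graph itself. It assumes the Weibull CDF 1 - exp(-(x / b) ** a), and the variable names (t, outer, inner) are made up for illustration. At x=50, a=1, b=1 the log-CDF is about -1.9e-22, so exp(logcdf) rounds to exactly 1.0 and the naive 1 - exp(logcdf) collapses to 0.0. The stable form keeps the gradient finite and agrees with the closed-form value -(x / b) ** a * log(x / b).

```python
import math

# Values from the example above: Weibull(a=1, b=1) evaluated at x=50.
x, a, b = 50.0, 1.0, 1.0

t = (x / b) ** a                      # Weibull exponent
logcdf = math.log1p(-math.exp(-t))    # log(1 - exp(-t)), about -1.9e-22

# Naive route: exp(logcdf) rounds to exactly 1.0, so the denominator of
# d/dl log(1 - exp(l)) = -exp(l) / (1 - exp(l)) becomes 0.0.
naive_denominator = 1.0 - math.exp(logcdf)

# Stable route: the same derivative written as exp(l) / expm1(l),
# which stays finite for any l < 0.
outer = math.exp(logcdf) / math.expm1(logcdf)

# Chain rule: d logcdf / da = exp(-t) * t * log(x / b) / (1 - exp(-t)).
inner = math.exp(-t) * t * math.log(x / b) / (-math.expm1(-t))

stable_grad = outer * inner

# Closed form: log(1 - cdf) = -t, so the gradient in a is -t * log(x / b).
analytic_grad = -t * math.log(x / b)

print(naive_denominator)           # 0.0
print(stable_grad, analytic_grad)  # both close to -195.60115
```

The closed-form value matches the second entry of the output above, which suggests that the log1mexp result is the correct gradient and the -inf is purely a floating-point artifact.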

@codecov

codecov bot commented Jul 22, 2022

Codecov Report

Merging #156 (9bf8c24) into main (44e441f) will not change coverage.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main     #156   +/-   ##
=======================================
  Coverage   94.88%   94.88%           
=======================================
  Files          12       12           
  Lines        1780     1780           
  Branches      263      263           
=======================================
  Hits         1689     1689           
  Misses         51       51           
  Partials       40       40           
Impacted Files        Coverage Δ
aeppl/truncation.py   98.27% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@ricardoV94 ricardoV94 merged commit 8f65562 into aesara-devs:main Jul 22, 2022
@brandonwillard
Member

Why wasn't this covered by a stabilizing rewrite?

@ricardoV94
Contributor

ricardoV94 commented Jul 22, 2022

The only stabilizing rewrite we have replaces log(1 - exp(x)) with the log1mexp Op, but we don't have anything that stabilizes the grad of the general, unrewritten log(1 - exp(x)) graph. The grad of log1mexp was tweaked some time ago here: aesara-devs/aesara#725

I am not sure we can or want to apply those tweaks to all graphs of that kind, or if they only work when we know they came from a grad of a log1mexp.

It would help if grad was not eager and we could run some rewrites before introducing their graphs.
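
For reference, the kind of stabilization a log1mexp function provides can be sketched in plain Python. This is an illustration only, not Aesara's implementation: it assumes the common two-branch scheme that switches at -log(2), and the function names are hypothetical. The gradient uses the algebraically equivalent form exp(x) / expm1(x), which avoids ever forming 1 - exp(x).

```python
import math


def log1mexp(x: float) -> float:
    """Compute log(1 - exp(x)) for x < 0 without catastrophic cancellation."""
    if x >= 0.0:
        raise ValueError("log1mexp is only defined for x < 0")
    if x > -math.log(2.0):
        # x close to 0: exp(x) is close to 1, so use expm1 to keep precision.
        return math.log(-math.expm1(x))
    # x far from 0: exp(x) is small, so log1p is the accurate choice.
    return math.log1p(-math.exp(x))


def log1mexp_grad(x: float) -> float:
    """Derivative of log(1 - exp(x)): -exp(x) / (1 - exp(x)) = exp(x) / expm1(x)."""
    return math.exp(x) / math.expm1(x)


# Near zero the naive 1 - exp(x) would round to 0.0; these stay finite.
print(log1mexp(-1e-20))
print(log1mexp_grad(-1e-20))
```

A rewrite that only recognizes log(1 - exp(x)) after the gradient has been taken cannot recover this form for the gradient graph, because the subtraction 1 - exp(x) is already baked into the derivative expression.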
