Skip to content

Commit

Permalink
Update deprecations and set filterwarnings to error during testing
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Aug 22, 2022
1 parent 1bd2f56 commit dde2d83
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
6 changes: 3 additions & 3 deletions aehmc/metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Callable, Tuple

import aesara.tensor as at
import aesara.tensor.slinalg as slinalg
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import shape_tuple
from aesara.tensor.slinalg import cholesky, solve_triangular
from aesara.tensor.var import TensorVariable


Expand Down Expand Up @@ -51,9 +51,9 @@ def gaussian_metric(
dot, matmul = at.dot, lambda x, y: x * y
elif inverse_mass_matrix.ndim == 2:
shape = (shape_tuple(inverse_mass_matrix)[0],)
tril_inv = slinalg.cholesky(inverse_mass_matrix)
tril_inv = cholesky(inverse_mass_matrix)
identity = at.eye(*shape)
mass_matrix_sqrt = slinalg.solve_lower_triangular(tril_inv, identity)
mass_matrix_sqrt = solve_triangular(tril_inv, identity, lower=True)
dot, matmul = at.dot, at.dot
else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion aehmc/proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def update(initial_energy, state):

delta_energy = initial_energy - new_energy
delta_energy = at.where(at.isnan(delta_energy), -np.inf, delta_energy)
is_transition_divergent = at.abs_(delta_energy) > divergence_threshold
is_transition_divergent = at.abs(delta_energy) > divergence_threshold

weight = delta_energy
log_p_accept = at.where(
Expand Down
2 changes: 1 addition & 1 deletion aehmc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from aesara.graph.basic import Variable, ancestors
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.utils import rewrite_graph
from aesara.tensor.rewriting.shape import ShapeFeature
from aesara.tensor.rewriting.basic import ShapeFeature
from aesara.tensor.var import TensorVariable


Expand Down
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ convention = numpy
[tool:pytest]
python_files=test*.py
testpaths=tests
filterwarnings=
error:::aesara
error:::aeppl
error:::aemcmc
ignore:::xarray

[coverage:run]
omit =
Expand Down

0 comments on commit dde2d83

Please sign in to comment.