Skip to content

Commit

Permalink
Fix unpacking bug in gaussian_metric
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard authored and rlouf committed Oct 14, 2021
1 parent 4155344 commit f413056
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions aehmc/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ def gaussian_metric(
"""

if inverse_mass_matrix.ndim == 0:
shape = ()
shape: Tuple = ()
mass_matrix_sqrt = aet.sqrt(aet.reciprocal(inverse_mass_matrix))
dot, matmul = lambda x, y: x * y, lambda x, y: x * y
elif inverse_mass_matrix.ndim == 1:
shape = shape_tuple(inverse_mass_matrix)[0]
shape = (shape_tuple(inverse_mass_matrix)[0],)
mass_matrix_sqrt = aet.sqrt(aet.reciprocal(inverse_mass_matrix))
dot, matmul = lambda x, y: x * y, lambda x, y: x * y
elif inverse_mass_matrix.ndim == 2:
shape = shape_tuple(inverse_mass_matrix)[0]
shape = (shape_tuple(inverse_mass_matrix)[0],)
tril_inv = slinalg.cholesky(inverse_mass_matrix)
identity = aet.eye(*shape)
mass_matrix_sqrt = slinalg.solve_lower_triangular(tril_inv, identity)
Expand Down

0 comments on commit f413056

Please sign in to comment.