Skip to content

Commit

Permalink
Fix deprecation warning (#13)
Browse files Browse the repository at this point in the history
Replace `triangular_solve()` with `linalg.solve_triangular()` to
fix the deprecation warning. See also:

#12
  • Loading branch information
timothygebhard committed Nov 8, 2022
1 parent f753a36 commit 1557388
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions normflows/flows/mixing.py
Expand Up @@ -434,11 +434,11 @@ def inverse_no_cache(self, inputs):
"""
lower, upper = self._create_lower_upper()
outputs = inputs - self.bias
outputs, _ = torch.triangular_solve(
outputs.t(), lower, upper=False, unitriangular=True
outputs = torch.linalg.solve_triangular(
lower, outputs.t(), upper=False, unitriangular=True
)
outputs, _ = torch.triangular_solve(
outputs, upper, upper=True, unitriangular=False
outputs = torch.linalg.solve_triangular(
upper, outputs, upper=True, unitriangular=False
)
outputs = outputs.t()

Expand Down

0 comments on commit 1557388

Please sign in to comment.