From 15573883c43e83bf85faacd3f9bdabfebe43541c Mon Sep 17 00:00:00 2001 From: Timothy Gebhard Date: Tue, 8 Nov 2022 22:11:12 +0100 Subject: [PATCH] Fix deprecation warning (#13) Replace `triangular_solve()` with `linalg.solve_triangular()` to fix the deprecation warning. See also: https://github.com/VincentStimper/normalizing-flows/issues/12 --- normflows/flows/mixing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/normflows/flows/mixing.py b/normflows/flows/mixing.py index 830cf98..1ab9243 100644 --- a/normflows/flows/mixing.py +++ b/normflows/flows/mixing.py @@ -434,11 +434,11 @@ def inverse_no_cache(self, inputs): """ lower, upper = self._create_lower_upper() outputs = inputs - self.bias - outputs, _ = torch.triangular_solve( - outputs.t(), lower, upper=False, unitriangular=True + outputs = torch.linalg.solve_triangular( + lower, outputs.t(), upper=False, unitriangular=True ) - outputs, _ = torch.triangular_solve( - outputs, upper, upper=True, unitriangular=False + outputs = torch.linalg.solve_triangular( + upper, outputs, upper=True, unitriangular=False ) outputs = outputs.t()