diff --git a/normflows/flows/mixing.py b/normflows/flows/mixing.py index 830cf98..1ab9243 100644 --- a/normflows/flows/mixing.py +++ b/normflows/flows/mixing.py @@ -434,11 +434,11 @@ def inverse_no_cache(self, inputs): """ lower, upper = self._create_lower_upper() outputs = inputs - self.bias - outputs, _ = torch.triangular_solve( - outputs.t(), lower, upper=False, unitriangular=True + outputs = torch.linalg.solve_triangular( + lower, outputs.t(), upper=False, unitriangular=True ) - outputs, _ = torch.triangular_solve( - outputs, upper, upper=True, unitriangular=False + outputs = torch.linalg.solve_triangular( + upper, outputs, upper=True, unitriangular=False ) outputs = outputs.t()