Sampling from flow raises deprecation warning #12

Closed
timothygebhard opened this issue Nov 8, 2022 · 0 comments

@timothygebhard
Contributor

Running the following minimal example:

import normflows as nf
import torch

torch.manual_seed(42)

flow = nf.NormalizingFlow(
    nf.distributions.DiagGaussian(1, trainable=False),
    [
        nf.flows.AutoregressiveRationalQuadraticSpline(1, 1, 1),
        nf.flows.LULinearPermute(1)
    ]
)

with torch.no_grad():
    samples_flow, _ = flow.sample(4)

print(samples_flow)

raises a UserWarning about an upcoming deprecation:

/Users/timothy/Desktop/normalizing-flows/normflows/flows/mixing.py:437: UserWarning: torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangular and will be removed in a future PyTorch release.
torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.
X = torch.triangular_solve(B, A).solution
should be replaced with
X = torch.linalg.solve_triangular(A, B). (Triggered internally at  /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:2189.)
  outputs, _ = torch.triangular_solve(
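
The replacement suggested in the warning swaps the argument order. `torch.linalg.solve_triangular` also makes the `upper` keyword mandatory, whereas `torch.triangular_solve` defaulted to `upper=True`. The one-line replacement printed in the warning therefore needs `upper=True` added to behave the same. The sketch below uses an arbitrary, well-conditioned test matrix (not the matrices used inside normflows) and checks that the new call solves A X = B:

```python
import torch

torch.manual_seed(0)

# Arbitrary upper-triangular 3x3 matrix (diagonal boosted for conditioning)
# and a 3x2 right-hand side.
A = torch.triu(torch.rand(3, 3)) + 3 * torch.eye(3)
B = torch.rand(3, 2)

# Deprecated form:  X = torch.triangular_solve(B, A).solution  (upper=True by default)
# New form: A comes first, and `upper` must be given explicitly.
X = torch.linalg.solve_triangular(A, B, upper=True)

# X should satisfy A @ X = B up to floating-point error.
print(torch.allclose(A @ X, B, atol=1e-5))
```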

I will submit a PR shortly that fixes the issue 🙂

timothygebhard added a commit to timothygebhard/normalizing-flows that referenced this issue Nov 8, 2022
Replace `triangular_solve()` with `linalg.solve_triangular()` to
fix the deprecation warning. See also:

VincentStimper#12
VincentStimper pushed a commit that referenced this issue Nov 8, 2022
Replace `triangular_solve()` with `linalg.solve_triangular()` to
fix the deprecation warning. See also:

#12
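
Call sites that also pass `transpose=True` to `torch.triangular_solve` have no direct keyword counterpart, because `torch.linalg.solve_triangular` has no `transpose` argument. One way to handle this is to transpose the matrix yourself and flip `upper`. The helper below, `solve_triangular_compat`, is a hypothetical name used only for illustration and is not part of normflows or PyTorch. It is a sketch that reproduces the `.solution` of the old call:

```python
import torch


def solve_triangular_compat(b, A, upper=True, transpose=False, unitriangular=False):
    """Hypothetical stand-in for torch.triangular_solve(b, A, ...).solution.

    Solves op(A) @ X = b, where op(A) is A, or A transposed if `transpose`.
    Batch dimensions are handled by torch.linalg.solve_triangular.
    """
    if transpose:
        # The transpose of an upper-triangular matrix is lower-triangular.
        A = A.transpose(-2, -1)
        upper = not upper
    # Note the reversed argument order compared with triangular_solve.
    return torch.linalg.solve_triangular(
        A, b, upper=upper, unitriangular=unitriangular
    )


# Example: solve A^T @ X = B for an upper-triangular A.
A = torch.triu(torch.rand(3, 3)) + 3 * torch.eye(3)
B = torch.rand(3, 1)
X = solve_triangular_compat(B, A, transpose=True)
```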