Inductor: Codegen for sympy Trunc is incorrect #126537

Open
ezyang opened this issue May 17, 2024 · 0 comments
Labels: module: dynamic shapes · module: inductor · oncall: pt2 · triaged


ezyang commented May 17, 2024

🐛 Describe the bug

I'm working on a general refactor that's going to fix this, but this was a really clean test case that I wanted to file for posterity.

This program:

import torch
import torch._dynamo.config
import math

torch._dynamo.config.capture_scalar_outputs = True

def f(x):
    r = math.sqrt(x.size(0))
    r = r ** 70
    return torch.tensor(math.trunc(r), dtype=torch.float64, device='cuda')

cf = torch.compile(fullgraph=True, dynamic=True)(f)

print(math.frexp(f(torch.randn(4, device='cuda')).item()))
print(math.frexp(cf(torch.randn(4, device='cuda')).item()))

prints

(0.5, 71) 
(0.9999999995343387, 31)
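The two frexp results can be decoded with plain Python (standard library only, no GPU needed). This minimal sketch rebuilds each value from its mantissa and exponent. The eager result is exactly 2**70. The compiled result is 2**31 - 1, the largest signed 32-bit integer.

```python
import math

# (mantissa, exponent) pairs as printed by the repro above
eager = (0.5, 71)
compiled = (0.9999999995343387, 31)

def decode(pair):
    """Rebuild the float from a math.frexp() result: mantissa * 2**exponent."""
    mantissa, exponent = pair
    return math.ldexp(mantissa, exponent)

INT32_MAX = 2**31 - 1

eager_value = decode(eager)
compiled_value = decode(compiled)

# Eager: sqrt(4) ** 70 == 2.0 ** 70, a power of two, so trunc keeps it exact.
print(eager_value == 2.0**70)       # True
# Compiled: the value has collapsed to the int32 maximum.
print(compiled_value == INT32_MAX)  # True
```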

The problem is:

    def _print_Trunc(self, expr):
        assert len(expr.args) == 1
        return (
            f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
        )

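trunc on float inputs should not convert to the index type. 2**70 does not fit in either int32 or int64, and the compiled result, 0.9999999995343387 * 2**31 = 2**31 - 1, is exactly INT32_MAX, so the value appears to have been clamped to a 32-bit index dtype.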
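To make the failure mode concrete, here is a plain-Python model of the two semantics. No Triton or Inductor code is involved, and the helper names are hypothetical. The model assumes a signed 32-bit index dtype and a saturating float-to-int conversion. That assumption is suggested by the compiled output (0.9999999995343387, 31), which equals 2**31 - 1. This is not Inductor's actual codegen.

```python
import math

# Assumption: the kernel's index dtype is a signed 32-bit integer and the
# float-to-int conversion saturates, matching the observed 2**31 - 1.
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1

def trunc_then_cast_to_index(x: float) -> float:
    """Hypothetical model of the buggy codegen: trunc(x).to(index_dtype),
    with the result later stored into a float64 tensor (finite x only)."""
    t = math.trunc(x)
    saturated = max(INT32_MIN, min(INT32_MAX, t))
    return float(saturated)

def trunc_keep_float(x: float) -> float:
    """Hypothetical model of float semantics: truncate toward zero but stay
    in floating point, so large magnitudes survive (finite x only)."""
    return float(math.trunc(x))

r = math.sqrt(4) ** 70
print(math.frexp(trunc_then_cast_to_index(r)))  # (0.9999999995343387, 31)
print(math.frexp(trunc_keep_float(r)))          # (0.5, 71)
```

The point of the model: truncating a float toward zero yields another float. Only an explicit conversion to an integer type should be allowed to narrow the range.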

Versions

main

cc @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

ezyang self-assigned this May 17, 2024
xmfan added the triaged label May 17, 2024