Add DIV_EXACT prim #2626
Conversation
kshitij12345 left a comment
LGTM, thanks @beverlylytle
Is `div_exact` a good name? I thought `floor_div` was great to represent it.

I am open to changing the name, but since this prim is used for both `floor_div` and `trunc_div`, I don't think it should be named `floor_div`. Do you have another suggestion?
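For context, the two rounding modes only agree when the true quotient is non-negative or exact, so a prim shared by both cannot borrow either mode's name. A quick PyTorch illustration:

```py
import torch

a = torch.tensor([-7, 7])
b = torch.tensor([2, 2])
print(torch.div(a, b, rounding_mode="floor"))  # tensor([-4,  3])
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([-3,  3])
```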
kiya00 left a comment
LGTM, thank you @beverlylytle
t-vi left a comment
Thank you @beverlylytle @kiya00 @kshitij12345
Review comment on `def test_div_exact():`
The test function works even without the change introduced in this PR:
`git checkout 80461ba^`
```py
In [1]: import torch, thunder
In [2]: def fn(a, b, c):
...: indices = torch.div(a, b, rounding_mode="trunc")
...: # this would throw an error if indices are not ints
...: return c[indices]
...:
In [3]: jfn = thunder.jit(fn)
...: a = torch.randint(1, 5, (5,))
...: b = torch.ones(5, dtype=torch.int32)
...: c = torch.randn(5, 5)
...: fn(a, b, c), jfn(a, b, c)
Out[3]:
(tensor([[ 0.5290, -2.2824, 1.0693, -1.6769, -1.9725],
[-1.0802, -0.4437, -0.7387, -1.1378, 0.6227],
[-0.1425, 0.3009, 1.3933, -0.8111, 0.2195],
[-0.1425, 0.3009, 1.3933, -0.8111, 0.2195],
[-0.1425, 0.3009, 1.3933, -0.8111, 0.2195]]),
tensor([[ 0.5290, -2.2824, 1.0693, -1.6769, -1.9725],
[-1.0802, -0.4437, -0.7387, -1.1378, 0.6227],
[-0.1425, 0.3009, 1.3933, -0.8111, 0.2195],
[-0.1425, 0.3009, 1.3933, -0.8111, 0.2195],
 [-0.1425,  0.3009,  1.3933, -0.8111,  0.2195]]))
```
The `c` tensor must have `requires_grad=True` to trigger the buggy code path.
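For reference, a minimal variant of the reproduction above with that flag set (a sketch using the same shapes and dtypes; that this exercises the buggy path on the pre-fix commit is an assumption based on the PR description):

```py
import torch, thunder

def fn(a, b, c):
    # trunc division of two integer tensors; indexing requires integer indices
    indices = torch.div(a, b, rounding_mode="trunc")
    return c[indices]

jfn = thunder.jit(fn)
a = torch.randint(1, 5, (5,))
b = torch.ones(5, dtype=torch.int32)
# requires_grad=True sends the trace through autodiff, which is where
# div's grad transform produced an inexact result before this fix
c = torch.randn(5, 5, requires_grad=True)
fn(a, b, c), jfn(a, b, c)
```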
Oh, shoot. `c` lost its `requires_grad` in the various iterations. Good catch.
Fixes #2599
In the HF model ibm-granite/granite-3.1-3b-a800m-instruct, a certain set of indices is computed by dividing two tensors of ints. In general, `div` is differentiable, but not when both operands have an exact dtype and the rounding mode is `"trunc"`. Because there is no exception for this case, autodiff fetches the augmented forward part coming from `div`'s grad transform, and this returns an inexact result. This PR introduces a new prim `DIV_EXACT`, which is the same as `prims.DIV` except that no grad transform is registered for it. A check on dtypes forwards the call to `DIV_EXACT` instead of `DIV`.
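As a rough illustration of that dtype check (the helper names below are hypothetical, not thunder's actual internals, and the sketch calls `torch.div` where the real code would forward to the `DIV_EXACT` or `DIV` prim):

```py
import torch

def _is_exact(dtype: torch.dtype) -> bool:
    # "exact" dtypes are bools and integers (neither floating point nor complex)
    return not (dtype.is_floating_point or dtype.is_complex)

def _div(a: torch.Tensor, b: torch.Tensor, rounding_mode=None) -> torch.Tensor:
    if rounding_mode == "trunc" and _is_exact(a.dtype) and _is_exact(b.dtype):
        # would lower to DIV_EXACT: no grad transform, result keeps an exact dtype
        return torch.div(a, b, rounding_mode="trunc")
    # would lower to DIV: the differentiable path with a registered grad transform
    return torch.div(a, b, rounding_mode=rounding_mode)
```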