
Conversation

@beverlylytle (Collaborator) commented Oct 10, 2025

Fixes #2599

In the HF model ibm-granite/granite-3.1-3b-a800m-instruct, a certain set of indices is computed by dividing two integer tensors. In general, `div` is differentiable, but not when both operands have an exact dtype and the rounding mode is `"trunc"`. Because there is no exception for this case, autodiff fetches the augmented forward coming from `div`'s grad transform, and this returns an inexact result. This PR introduces a new prim, `DIV_EXACT`, which is the same as `prims.DIV` except that no grad transform is registered for it. A check on dtypes forwards the call to `DIV_EXACT` instead of `DIV`.
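
For context, here is a rough sketch of the dtype-based dispatch described above. The names `div_with_grad` and `div_exact` are illustrative stand-ins for the prims, not thunder's actual symbols:

```py
import torch

# Stand-in for prims.DIV: this path has a grad transform registered, so autodiff
# would pick up its augmented forward (which returns an inexact result).
def div_with_grad(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.true_divide(a, b)

# Stand-in for the new DIV_EXACT prim: integer division with no grad transform.
def div_exact(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.div(a, b, rounding_mode="trunc")

def div(a: torch.Tensor, b: torch.Tensor, rounding_mode: str | None = None) -> torch.Tensor:
    both_exact = not a.dtype.is_floating_point and not b.dtype.is_floating_point
    if rounding_mode == "trunc" and both_exact:
        # Exact-dtype trunc division is not differentiable: skip the grad transform.
        return div_exact(a, b)
    out = div_with_grad(a, b)
    return out.trunc() if rounding_mode == "trunc" else out
```

The point of the routing is that the exact-dtype, `"trunc"` branch never touches the differentiable path, so no augmented forward is recorded for it.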

@beverlylytle beverlylytle marked this pull request as ready for review October 10, 2025 12:28
@kshitij12345 (Collaborator) left a comment


LGTM, thanks @beverlylytle

@t-vi (Collaborator) commented Oct 10, 2025

Is `div_exact` a good name? I thought `floor_div` was great to represent `//`

@beverlylytle (Collaborator, Author) replied:

> Is `div_exact` a good name? I thought `floor_div` was great to represent `//`

I am open to changing the name, but since this prim is used for both `floor_div` and `trunc_div`, I don't think it should be named `floor_div`. Do you have another suggestion?
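
For reference, the two rounding modes disagree on negative operands, which is why a prim shared by `//` (floor) and `rounding_mode="trunc"` can't reasonably carry the name `floor_div`:

```py
import torch

a = torch.tensor([-7, 7])
b = torch.tensor([2, 2])
print(torch.div(a, b, rounding_mode="floor"))  # tensor([-4,  3])  -- matches //
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([-3,  3])  -- rounds toward zero
```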

@kiya00 (Collaborator) left a comment


LGTM, thank you @beverlylytle

@t-vi (Collaborator) left a comment


@t-vi merged commit 80461ba into main Oct 13, 2025
48 of 51 checks passed
@t-vi deleted the add_div_exact branch October 13, 2025 08:21
The following review comment was left on `test_div_exact` in the tests:

```py
assert_close(fn(a), jfn(a))


def test_div_exact():
```

The test function works even without the change introduced in this PR:

`git checkout 80461ba^`

```py
In [1]: import torch, thunder

In [2]:     def fn(a, b, c):
   ...:         indices = torch.div(a, b, rounding_mode="trunc")
   ...:         # this would throw an error if indices are not ints
   ...:         return c[indices]
   ...: 

In [3]:     jfn = thunder.jit(fn)
   ...:     a = torch.randint(1, 5, (5,))
   ...:     b = torch.ones(5, dtype=torch.int32)
   ...:     c = torch.randn(5, 5)
   ...:     fn(a, b, c), jfn(a, b, c)
Out[3]: 
(tensor([[ 0.5290, -2.2824,  1.0693, -1.6769, -1.9725],
         [-1.0802, -0.4437, -0.7387, -1.1378,  0.6227],
         [-0.1425,  0.3009,  1.3933, -0.8111,  0.2195],
         [-0.1425,  0.3009,  1.3933, -0.8111,  0.2195],
         [-0.1425,  0.3009,  1.3933, -0.8111,  0.2195]]),
 tensor([[ 0.5290, -2.2824,  1.0693, -1.6769, -1.9725],
         [-1.0802, -0.4437, -0.7387, -1.1378,  0.6227],
         [-0.1425,  0.3009,  1.3933, -0.8111,  0.2195],
         [-0.1425,  0.3009,  1.3933, -0.8111,  0.2195],
         [-0.1425,  0.3009,  1.3933, -0.8111,  0.2195]))
```

The `c` tensor must have `requires_grad=True` to trigger the buggy code path.

@beverlylytle (Collaborator, Author) replied:

Oh, shoot. `c` lost its `requires_grad` in the various iterations. Good catch.
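
For completeness, the reviewer's snippet with that fix applied (a sketch; the merged test may differ in its details):

```py
import torch
import thunder

def fn(a, b, c):
    indices = torch.div(a, b, rounding_mode="trunc")
    # this would throw an error if indices are not ints
    return c[indices]

jfn = thunder.jit(fn)
a = torch.randint(1, 5, (5,))
b = torch.ones(5, dtype=torch.int32)
# requires_grad=True is what engages autodiff and, before this PR, div's grad transform
c = torch.randn(5, 5, requires_grad=True)
torch.testing.assert_close(fn(a, b, c), jfn(a, b, c))
```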



Successfully merging this pull request may close these issues.

HF ibm-granite/granite-3.1-3b-a800m-instruct: index_add_(): Expected dtype int32/int64 for index but got: Float
