Propagate backward tags more consistently #2336
Conversation
I intend to add a test that's a smaller version of the code sample given in the description.
With the latest main (dbf6bad), there's no error. Is there anything missing in the code to trigger the error?
No, this bug is not yet caught by existing tests, but the (complicated) code snippet above does expose it. I am working on simplifying the snippet to add as a test.
Sorry @IvanYashchuk, I misunderstood your comment. I'm fairly certain that yes, the example code does error on main, and I refer you to comment #2052 (comment).
riccardofelluga
left a comment
LGTM!
Have you checked if the thunder.jit path is ok too?
The biggest difference between the
My smaller repro trying to understand the code from the PR description (but the test added here is much better):

import torch, thunder

def forward(self, router_logits):
    zeros = torch.zeros((8, 4), device="cuda", dtype=torch.bool)
    masked_gates = router_logits.masked_fill(zeros, float("-inf"))
    selected_experts = masked_gates.max(dim=-1)[1].unsqueeze(-1)
    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
    masked_scores = torch.scatter(router_logits, -1, selected_experts, float("-inf"))
    eq_1 = masked_scores.masked_fill(zeros, float("-inf")).max(dim=-1, keepdim=True)[1].__eq__(masked_scores.masked_fill(zeros, float("-inf")).max(dim=-1, keepdim=True)[1])

    def fwd_body_0(ctx, scores, multiplier, selected_experts, masked_gates):
        return (multiplier * torch.ones((8, 1), device="cuda", dtype=torch.bool), [masked_gates, selected_experts])

    def bwd_body_0(ctx, grad_at_output, masked_gates, selected_experts):
        masked_gates.scatter_add_(dim=-1, index=selected_experts, src=grad_at_output)
        return (masked_gates, None, None, None)

    multiplier = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, router_logits, multiplier_o, selected_experts, masked_gates, args_tensor_mask=[True, True, True, True], non_differentiable_idx=[])
    return (eq_1, multiplier)

router_logits = torch.rand((8, 4), device="cuda", requires_grad=True)
compiled_forward = thunder.jit(forward, fusion_type="dataflow", skip_inplace_alias_updates=False, skip_inplace_functionalization=True)
compiled_forward(None, router_logits)
t-vi
left a comment
Thank you @beverlylytle @riccardofelluga @IvanYashchuk
The backward tags are not being propagated consistently, leading to certain bound symbols that should be in the backward trace appearing in the forward trace instead. After the backward trace is split, there are errors about undefined tensors.
This is made evident in the following non-trivial example:
The copy_ subsymbol of mp.backward's scatter_add_ is what is missing the backward tag. The associated update_aliases bsyms were also missing the tag.
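For intuition, here is a minimal, self-contained sketch of the kind of propagation described above. The Tag and BoundSymbol classes and the helper functions are illustrative stand-ins chosen for this sketch, not thunder's actual internals:

from dataclasses import dataclass, field
from enum import Enum, auto

class Tag(Enum):
    BACKWARD = auto()

@dataclass
class BoundSymbol:
    # Illustrative stand-in for a trace bound symbol with tags and subsymbols.
    name: str
    tags: set = field(default_factory=set)
    subsymbols: list = field(default_factory=list)

def propagate_backward_tag(bsym: BoundSymbol) -> None:
    # If a bound symbol carries the BACKWARD tag, push it down to all of its subsymbols.
    if Tag.BACKWARD in bsym.tags:
        for sub in bsym.subsymbols:
            sub.tags.add(Tag.BACKWARD)
            propagate_backward_tag(sub)

def split_by_tag(bsyms):
    # Toy stand-in for the forward/backward split: routes each bsym by its tag.
    fwd = [b for b in bsyms if Tag.BACKWARD not in b.tags]
    bwd = [b for b in bsyms if Tag.BACKWARD in b.tags]
    return fwd, bwd

# scatter_add_ decomposes into a copy_ subsymbol; only the parent is tagged.
copy_ = BoundSymbol("copy_")
scatter_add_ = BoundSymbol("scatter_add_", tags={Tag.BACKWARD}, subsymbols=[copy_])

# Without propagation, copy_ carries no BACKWARD tag, so a split over the flattened
# bsyms would leave it in the forward trace and the backward trace would then
# reference a tensor that is never defined there.
propagate_backward_tag(scatter_add_)
fwd, bwd = split_by_tag([copy_])
assert copy_ in bwd

The same idea applies to the associated update_aliases bsyms mentioned above, which also need to carry the tag so the split routes them into the backward trace.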