
Conversation

@beverlylytle (Collaborator) commented Jul 17, 2025:

The backward tags are not being propagated consistently, so certain bound symbols that should be in the backward trace end up in the forward trace. After the backward trace is split, this produces errors about undefined tensors.
This is made evident in the following non-trivial example:

import torch


class mp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        scores: torch.Tensor,
        multiplier: torch.Tensor,
        selected_experts: torch.Tensor,
        masked_gates: torch.Tensor,
        mask_for_one: torch.Tensor,
    ):
        ctx.save_for_backward(selected_experts, masked_gates)
        return multiplier * mask_for_one

    @staticmethod
    def backward(ctx, grad_at_output: torch.Tensor):
        selected_experts, masked_gates = ctx.saved_tensors
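        # NOTE: this in-place scatter_add_ on a tensor saved from the forward is what
        # thunder functionalizes into copy_ / update_aliases bound symbols; those are
        # the bsyms that were missing the backward tag (see the note below the repro)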
        masked_gates.scatter_add_(dim=-1, index=selected_experts, src=grad_at_output)
        return masked_gates, None, None, None, None


def sparsemixer(scores, top_k=2, jitter_eps=0.01, training=True):
    mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)

    masked_gates = scores.masked_fill(torch.zeros_like(mask_logits_threshold).to(torch.bool), float("-inf"))
    selected_experts = masked_gates.max(dim=-1)[1].unsqueeze(-1)

    # compute scores for gradients
    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)

    # compute midpoint mask
    max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
    mask_for_one = torch.logical_or(
        selected_experts == max_ind,
        torch.rand_like(max_scores) > 0.75,
    )

    multiplier = mp.apply(
        scores,
        multiplier_o,
        selected_experts,
        masked_gates,
        mask_for_one,
    )

    # mask out the first expert
    masked_scores = torch.scatter(
        scores,
        -1,
        selected_experts,
        float("-inf"),
    )

    # apply mask
    masked_gates_top2 = masked_scores.masked_fill(torch.zeros_like(mask_logits_threshold).to(torch.bool), float("-inf"))

    # compute midpoint mask
    max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
    mask_for_one_top2 = torch.logical_or(
        max_ind == max_ind,
        torch.rand_like(max_scores).uniform_() > 0.75,
    )

    return (
        multiplier,
        selected_experts,
    )


class GRINMoESparseMoeBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_dim = 4096
        self.num_experts = 16
        # gating
        self.gate = torch.nn.Linear(self.hidden_dim, self.num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        """ """
        sequence_length, hidden_dim = hidden_states.shape
        hidden_states *= torch.empty_like(hidden_states).uniform_(0.1, 1.0)
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits = self.gate(hidden_states)

        routing_weights, selected_experts = sparsemixer(
            router_logits,
            training=True,
        )

        return routing_weights, selected_experts


from thunder.dynamo import ThunderCompiler

a = torch.rand((128, 4096), device="cuda")
model = GRINMoESparseMoeBlock().to("cuda")
back = ThunderCompiler(skip_inplace_alias_updates=False, skip_inplace_functionalization=True, disable_inplace_copy_check=True)
jfoo = torch.compile(model, backend=back)
jfoo(a)

The copy_ subsymbol of the scatter_add_ in mp.backward is the one missing the backward tag; the associated update_aliases bsyms were also missing it.

@beverlylytle (Collaborator, Author) commented:

I intend to add a test that's a smaller version of the code sample given in the description.
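
For illustration only, the pattern being distilled reduces to a custom torch.autograd.Function whose backward mutates a tensor saved from the forward in place. A hypothetical sketch of that shape (this is not the test added by this PR, and the class name is made up):

import torch


class InplaceInBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, buf):
        ctx.save_for_backward(buf)
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        (buf,) = ctx.saved_tensors
        # in-place update of a saved tensor inside backward, as in mp.backward above
        buf.add_(grad_out)
        return buf, None

Running a function that calls InplaceInBackward.apply through thunder should then exercise the same copy_ / update_aliases functionalization path in the backward trace.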

@IvanYashchuk (Collaborator) commented:

With the latest main (dbf6bad), there's no error. Is there anything missing in the code to trigger the error?

@beverlylytle (Collaborator, Author) commented:

No, this bug is not yet caught by existing tests, but the above (complicated) code snippet does expose the bug. I am working on simplifying the code snippet to add as a test.

@beverlylytle (Collaborator, Author) commented:

Sorry @IvanYashchuk, I misunderstood your comment. I'm fairly certain that, yes, the example code does error on main; I refer you to comment #2052 (comment).

@riccardofelluga (Collaborator) left a review comment:

LGTM!

Have you checked if the thunder.jit path is ok too?

@beverlylytle (Collaborator, Author) commented:

> Have you checked if the thunder.jit path is ok too?

The biggest difference between the jit path and the ThunderCompiler option in this situation is that the fusion strategy for jit is "consecutive", whereas for ThunderCompiler it is "dataflow". Because get_grad is not fusible, this bug cannot appear with jit, and adding these backward tags does no harm to that path.
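
For reference, the fusion strategy can be chosen explicitly on the jit path as well; a minimal sketch, assuming the fusion_type keyword that also appears in the repro in the next comment:

import torch
import thunder


def f(x):
    return x.sin() * x.cos()


# the jit path defaults to the "consecutive" fusion strategy (per the comment above)
jf_default = thunder.jit(f)

# opt into the "dataflow" strategy that ThunderCompiler uses
jf_dataflow = thunder.jit(f, fusion_type="dataflow")

x = torch.randn(4, requires_grad=True)
jf_dataflow(x)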

@IvanYashchuk (Collaborator) commented Jul 21, 2025:

My smaller repro from trying to understand the code in the PR description (but the test added here is much better):

import torch, thunder

def forward(self, router_logits):
    zeros = torch.zeros((8, 4), device="cuda", dtype=torch.bool)
    masked_gates = router_logits.masked_fill(zeros, float("-inf"))
    selected_experts = masked_gates.max(dim=-1)[1].unsqueeze(-1)
    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
    masked_scores = torch.scatter(router_logits, -1, selected_experts, float("-inf"))
    eq_1 = masked_scores.masked_fill(zeros, float("-inf")).max(dim=-1, keepdim=True)[1].__eq__(
        masked_scores.masked_fill(zeros, float("-inf")).max(dim=-1, keepdim=True)[1]
    )

    def fwd_body_0(ctx, scores, multiplier, selected_experts, masked_gates):
        return (multiplier * torch.ones((8, 1), device="cuda", dtype=torch.bool), [masked_gates, selected_experts])

    def bwd_body_0(ctx, grad_at_output, masked_gates, selected_experts):
        masked_gates.scatter_add_(dim=-1, index=selected_experts, src=grad_at_output)
        return (masked_gates, None, None, None)

    multiplier = torch.ops.higher_order.autograd_function_apply(
        fwd_body_0, bwd_body_0, router_logits, multiplier_o, selected_experts, masked_gates,
        args_tensor_mask=[True, True, True, True], non_differentiable_idx=[],
    )
    return (eq_1, multiplier)

router_logits = torch.rand((8, 4), device="cuda", requires_grad=True)
compiled_forward = thunder.jit(forward, fusion_type="dataflow", skip_inplace_alias_updates=False, skip_inplace_functionalization=True)
compiled_forward(None, router_logits)
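
If the repro gets past compilation, the resulting traces can be dumped for inspection; a minimal sketch, assuming thunder's public last_traces / last_backward_traces helpers (on a broken build the undefined-tensor error surfaces before this point):

# Print the final forward and backward traces to see where the functionalized
# copy_ / update_aliases bound symbols ended up after the backward split.
print(thunder.last_traces(compiled_forward)[-1])
print(thunder.last_backward_traces(compiled_forward)[-1])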

@t-vi (Collaborator) left a review comment.

@t-vi enabled auto-merge (squash) July 22, 2025 11:34
@t-vi merged commit 63e90a0 into main Jul 22, 2025
51 checks passed
@t-vi deleted the propagate_backward_tags branch July 22, 2025 11:42