Conversation

protonu (Collaborator) commented on Jun 3, 2025

What does this PR do?

In #2043 we added custom decompositions for cross-entropy loss fwd and bwd for the nvfuser executor.

We noticed that when computing the bwd, we would end up recomputing the fwd: note the nv_cross_entropy_fwd call inside nvFusion0 in the trace below, whose outputs t45, t43, and t44 exist only to feed nv_cross_entropy_bwd. The final bwd trace looked something like:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  # C0: "Collection"
  # None
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t46, = cotangents
  # t46: "cuda:0 f32[]"
  clear_mutable_collection(cotangents)
  del cotangents
  labels, logits, = C0
  # labels: "cuda:0 i64[4096]"
  # logits: "cuda:0 bf16[1, 4096, 131072]"
  clear_mutable_collection(C0)
  del C0
  [grad_for_logits] = nvFusion0(labels, logits, t46)
    # t29 = prims.pad(labels, 0, [(0, 1, 0)])  # t29: "cuda:0 i64[4097]"
    # t36 = prims.convert_element_type(logits, dtypes.float32)  # t36: "cuda:0 f32[1, 4096, 131072]"
    # t33 = prims.slice_prim(t29, [1], [4097], [1])  # t33: "cuda:0 i64[4096]"
    # t39 = prims.squeeze(t36, (0,))  # t39: "cuda:0 f32[4096, 131072]"
    # (_, t45, t43, t44) = thunder.executors.nvfuserex_impl.nv_cross_entropy_fwd(t39, t33, None, None, -100, None, 'mean', 0.0)
    # bw_t47 = thunder.executors.nvfuserex_impl.nv_cross_entropy_bwd(t46, t39, target=t33, a_max=t45, max_log_sum_exp=t43, valid_indices=t44, ignore_index=-100, label_smoothing=0.0)  # bw_t47: "cuda:0 f32[4096, 131072]"
    # bw_t41 = prims.broadcast_in_dim(bw_t47, [1, 4096, 131072], [1, 2])  # bw_t41: "cuda:0 f32[1, 4096, 131072]"
    # grad_for_logits = prims.convert_element_type(bw_t41, dtypes.bfloat16)  # grad_for_logits: "cuda:0 bf16[1, 4096, 131072]"
  del labels, logits, t46
  return (grad_for_logits, None)

In this PR we mark the fwd operator so that it is not recomputed during the bwd computation; its auxiliary outputs are saved for the bwd instead.
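
The mechanism, roughly: the fwd symbol is registered with a tag that the recomputation pass consults before pulling fwd ops back into the bwd trace. A minimal sketch, assuming a tag named OpTags.DONT_RECOMPUTE_IN_BACKWARD and a tags= keyword on register_operator (the exact names in the diff may differ):

# Sketch only: the names below are assumptions, not the verbatim diff.
from thunder.core.prims import OpTags

# ex is the nvfuser OperatorExecutor; the meta/fn arguments stand in for
# the existing cross-entropy fwd meta function and implementation.
nv_cross_entropy_fwd = ex.register_operator(
    "nv_cross_entropy_fwd",
    meta=cross_entropy_fwd_meta,                # assumed name
    fn=cross_entropy_fwd_impl,                  # assumed name
    tags=(OpTags.DONT_RECOMPUTE_IN_BACKWARD,),  # assumed tag
)

# The recompute pass skips any symbol carrying this tag, so the bwd trace
# consumes the saved fwd outputs instead of re-running the fwd.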
Our trace for the bwd will look something like:

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  # C0: "Collection"
  # None
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t30, = cotangents
  # t30: "cuda:0 f32[]"
  clear_mutable_collection(cotangents)
  del cotangents
  labels, logits, t27, t28, t29, = C0
  # labels: "cuda:0 i64[8192]"
  # logits: "cuda:0 f32[8192, 32064]"
  # t27: "cuda:0 f32[8192]"
  # t28: "cuda:0 f32[]"
  # t29: "cuda:0 f32[8192]"
  clear_mutable_collection(C0)
  del C0
  [grad_for_logits] = nvFusion0(t30, logits, t29, t27, labels, t28)
    # grad_for_logits = thunder.executors.nvfuserex_impl.nv_cross_entropy_bwd(t30, logits, target=labels, a_max=t29, max_log_sum_exp=t27, valid_indices=t28, ignore_index=-100, label_smoothing=0.0)  # grad_for_logits: "cuda:0 f32[8192, 32064]"
  del t30, logits, t29, t27, labels, t28
  return (grad_for_logits, None)

With this change the fwd intermediates (t27, t28, t29 above) are saved for the bwd and passed directly to nv_cross_entropy_bwd. Not recomputing the fwd significantly improves the performance of the bwd.
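
The trade-off here is the classic save-versus-recompute choice. As a generic illustration in plain PyTorch (not thunder's implementation, and not the nvfuser kernels): the fwd saves its log-softmax intermediates so the bwd can form the gradient without redoing the reductions.

import torch

class CrossEntropySaved(torch.autograd.Function):
    """Illustration only: save fwd intermediates instead of recomputing them in bwd."""

    @staticmethod
    def forward(ctx, logits, target):
        # Numerically stable log-softmax, computed once in the fwd.
        a_max = logits.max(dim=-1, keepdim=True).values
        shifted = logits - a_max
        log_sum_exp = shifted.exp().sum(dim=-1, keepdim=True).log()
        log_probs = shifted - log_sum_exp
        # Mean negative log-likelihood over all rows.
        loss = -log_probs.gather(-1, target.unsqueeze(-1)).mean()
        # Save the intermediates so the bwd does not redo the reductions.
        ctx.save_for_backward(log_probs, target)
        return loss

    @staticmethod
    def backward(ctx, grad_out):
        log_probs, target = ctx.saved_tensors
        # d(mean NLL)/d(logits) = (softmax(logits) - one_hot(target)) / N
        grad = log_probs.exp()
        grad.scatter_add_(
            -1,
            target.unsqueeze(-1),
            -torch.ones_like(target, dtype=grad.dtype).unsqueeze(-1),
        )
        return grad * (grad_out / target.numel()), None

logits = torch.randn(8, 16, requires_grad=True)
target = torch.randint(16, (8,))
CrossEntropySaved.apply(logits, target).backward()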


t-vi (Collaborator) commented on Jun 3, 2025

Lovely. Thank you!

t-vi merged commit fe96f3b into main on Jun 3, 2025 (49 checks passed).
t-vi deleted the pbasu_cross_entropy_loss_bwd_edit branch on Jun 3, 2025 at 16:16.