@shino16 shino16 commented Dec 9, 2025

Fixes #2790. Previously, `inp[idx] = value` was translated into `_copy_(inp, clang.copy_with_setitem(inp, idx, val))`, where `copy_with_setitem` was an opaque primitive that only torchex could execute. This PR makes the entire setitem operation an opaque primitive instead.
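For context, the out-of-place `copy_with_setitem` that the old lowering relied on is semantically a clone-then-assign. A minimal sketch in plain PyTorch (a hypothetical standalone function, not thunder's actual implementation):

```python
import torch

def copy_with_setitem(inp, idx, val):
    # Out-of-place setitem: return a copy of `inp` with `inp[idx]`
    # replaced by `val`, leaving the original tensor untouched.
    out = inp.clone()
    out[idx] = val
    return out

inp = torch.zeros(3, 4, 5)
idx = torch.tensor([1])
val = torch.ones(1, 4, 5)
out = copy_with_setitem(inp, idx, val)
```

The old translation then copied `out` back into `inp` via `_copy_` to recover the in-place semantics.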

Demo

```python
import torch, thunder

def fn(x, idx, val):
    x[idx] = val

jf = thunder.jit(fn)
x = torch.zeros(3, 4, 5, device="cpu")
x_ref = x.clone()
idx = torch.tensor([1], device="cpu")
val = torch.ones(1, 4, 5, device="cpu")
fn(x_ref, idx, val)
jf(x, idx, val)

print(thunder.last_traces(jf)[-1])
torch.testing.assert_close(x, x_ref)
```
```python
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, idx, val):
  # x: "cpu f32[3, 4, 5]"
  # idx: "cpu i64[1]"
  # val: "cpu f32[1, 4, 5]"
  (t4,) = update_aliases((x,))
  del x

  # /opt/pytorch/lightning-thunder/tmp/main.py:4:           x[idx] = val
  t5 = setitem(t4, idx, val)  # t5: "cpu f32[3, 4, 5]"
    # t5 = ltorch.setitem_(t4, idx, val)  # t5: "cpu f32[3, 4, 5]"
      # t5 = prims.setitem(t4, idx, val)  # t5: "cpu f32[3, 4, 5]"
  del t4
  return ()
```

@shino16 shino16 marked this pull request as draft December 10, 2025 06:53
@shino16 shino16 force-pushed the setitem-without-copy branch from aaaaf67 to 3d9be2f Compare December 10, 2025 07:38
kiya00 commented Dec 10, 2025

The failures in CI seem relevant.

@shino16 shino16 marked this pull request as ready for review December 10, 2025 16:49
@shino16 shino16 force-pushed the setitem-without-copy branch from ad552de to 7011cc2 Compare December 10, 2025 22:16
@shino16 shino16 marked this pull request as draft December 11, 2025 11:21
```python
return VJPDual(primal, residuals)


@register_augmented_forward(prims.PrimIDs.SETITEM)
```
shino16 (Collaborator, Author) commented:
We have to define the forward and backward separately. Otherwise, when we extract the forward trace from the combined forward-and-backward function, we hit `prims.setitem(clone(g), index, 0)` in the backward, which survives DCE because `setitem` is tagged `DONT_DCE`.
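In eager PyTorch terms, the split rule can be sketched roughly as follows (hypothetical helper names; the actual code registers these through thunder's augmented-forward/backward machinery):

```python
import torch

def setitem_augmented_forward(inp, idx, val):
    # Forward: perform the update; save what backward will need.
    inp[idx] = val
    residuals = (idx,)
    return inp, residuals

def setitem_backward(idx, g):
    # Gradient w.r.t. `val` is the incoming gradient at the written positions.
    grad_val = g[idx]
    # Gradient w.r.t. `inp` is `g` with the overwritten positions zeroed,
    # mirroring the `prims.setitem(clone(g), index, 0)` seen in the backward trace.
    grad_inp = g.clone()
    grad_inp[idx] = 0
    return grad_inp, grad_val

g = torch.ones(3, 4)
idx = torch.tensor([1])
grad_inp, grad_val = setitem_backward(idx, g)
```

With forward and backward defined as separate rules, the zeroing `setitem` lives only in the backward trace and cannot leak into the extracted forward trace.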


Successfully merging this pull request may close these issues:

- Investigate `torch.Tensor.__setitem__` as unsupported in thunderfx
