Skip to content

Commit

Permalink
feat: added tensor long (#402)
Browse files Browse the repository at this point in the history
  • Loading branch information
k223kim committed May 13, 2024
1 parent 83ccc18 commit 6a94fc1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
22 changes: 22 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2919,6 +2919,28 @@ def type_as_sample_generator(op, device, dtype, requires_grad, **kwargs):
data_movement_ops.append(type_as_sample)


def long_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

# input shape
shapes = (
(),
(0,),
(2,),
(1, 2),
(1, 2, 3),
)

for shape in shapes:
yield SampleInput(make(shape))


long_opinfo = OpInfo(
ltorch.long,
sample_input_generator=long_sample_generator,
torch_reference=torch.Tensor.long,
)

opinfos.extend(data_movement_ops)

#
Expand Down
5 changes: 5 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ def type_as(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:
return to(a, b.true_dtype)


@torchsymbol(torch.Tensor.long, is_method=True)
def long(a: TensorLike, /, memory_format: torch.memory_format = torch.preserve_format) -> TensorLike:
return to(a, dtype=dtypes.int64, memory_format=memory_format)


#
# Tensor creation operations
#
Expand Down

0 comments on commit 6a94fc1

Please sign in to comment.