add torch.empty (#353)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
k223kim and pre-commit-ci[bot] committed May 13, 2024
1 parent 6a94fc1 commit eb6565b
Showing 5 changed files with 132 additions and 0 deletions.
7 changes: 7 additions & 0 deletions thunder/clang/__init__.py
@@ -280,6 +280,13 @@ def full_like(
return full(a.shape, fill_value, device=device, dtype=dtype)


@clangop()
def empty(shape: Sequence[int], *, device: DeviceLike, dtype: dtypes.dtype) -> TensorLike:
device = devices.to_device(device)

return prims.empty(tuple(shape), device=device, dtype=dtype)


@clangop()
def uniform(
shape: Sequence[int],
17 changes: 17 additions & 0 deletions thunder/core/prims.py
@@ -145,6 +145,7 @@ class PrimIDs(Enum):
UNIFORM = auto()
UNIFORM_PHILOX = auto()
RANDN = auto()
EMPTY = auto()
TENSOR_FROM_SEQUENCE = auto()
# Probability distribution-related ops
MULTINOMIAL = auto()
@@ -2674,6 +2675,22 @@ def _randn_meta(
randn = make_prim(PrimIDs.RANDN, "randn", meta=_randn_meta)


def _empty_meta(
shape: tuple[int, ...],
*,
device: devices.Device,
dtype: dtypes.dtype,
):
utils.check_type(device, devices.Device)
utils.check_type(dtype, dtypes.dtype)
utils.check_type(shape, tuple)
utils.check_valid_shape(shape)
return TensorProxy(shape=shape, device=device, dtype=dtype, requires_grad=False)


empty = make_prim(PrimIDs.EMPTY, "empty", meta=_empty_meta)


# Prim to construct a Tensor from sequence/nested sequence of Numbers.
def _tensor_from_sequence_meta(
seq: Sequence[Number | Sequence], *, dtype: None | dtypes.dtype, device: devices.Device
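To see where the new prim surfaces, here is a rough sketch, assuming the public thunder.jit and thunder.last_traces entry points (the exact trace rendering may differ by stage): tracing a call to torch.empty should now record the empty symbol added above instead of failing as an unsupported operation.

import torch
import thunder

def fn():
    return torch.empty(2, 3)

jfn = thunder.jit(fn)
jfn()
# The recorded traces should reference the empty symbol; the final
# (execution) trace dispatches to torch.empty via the torch executor.
print(thunder.last_traces(jfn)[-1])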
38 changes: 38 additions & 0 deletions thunder/executors/torchex.py
@@ -167,6 +167,7 @@ def no_autocast(fn):
zeros = _register_torch_operation("zeros")
zeros_like = _register_torch_operation("zeros_like")
randn = _register_torch_operation("randn")
empty = _register_torch_operation("empty")
einsum = _register_torch_operation("einsum")


@@ -417,6 +418,17 @@ def _randn_prims_transform(
return randn(shape, device=torch_device, dtype=torch_dtype)


def _empty_prims_transform(
shape: tuple[int, ...],
*,
device: devices.Device,
dtype: dtypes.dtype,
) -> TensorLike:
torch_device: torch.device = to_torch_device(device)
torch_dtype: torch.dtype = to_torch_dtype(dtype)
return empty(shape, device=torch_device, dtype=torch_dtype)


def _tensor_from_sequence_prims_transform(
seq_or_number, *, device: devices.Device, dtype: None | dtypes.dtype
) -> TensorLike:
@@ -432,6 +444,7 @@ def _tensor_from_sequence_prims_transform(
prims.uniform_philox, checker=_uniform_philox_prim_checker, execution_transform=_uniform_philox_prim_transform
)
_register_implementation(prims.randn, checker=_always_executable, execution_transform=_randn_prims_transform)
_register_implementation(prims.empty, checker=_always_executable, execution_transform=_empty_prims_transform)
_register_implementation(
prims.tensor_from_sequence, checker=_always_executable, execution_transform=_tensor_from_sequence_prims_transform
)
@@ -532,6 +545,30 @@ def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None)
return squeeze(a, dim)


def _empty_transform(
shape: Sequence[int],
device: None | DeviceLike = None,
dtype: None | dtypeLike = None,
out: None | TensorLike = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
pin_memory: bool = False,
memory_format: torch.memory_format = torch.contiguous_format,
):
torch_device: None | torch.device = to_torch_device(device)
torch_dtype: None | torch.dtype = to_torch_dtype(dtype)
return empty(
shape,
device=torch_device,
dtype=torch_dtype,
out=out,
layout=layout,
requires_grad=requires_grad,
pin_memory=pin_memory,
memory_format=memory_format,
)


_register_implementation(
prims.broadcast_in_dim, checker=_always_executable, execution_transform=_broadcast_in_dim_prim_transform
)
@@ -567,6 +604,7 @@ def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None)
_register_implementation(ltorch.unsqueeze, unsqueeze, checker=_always_executable)
_register_implementation(ltorch.view, view, checker=_always_executable)
_register_implementation(ltorch.view_as, view_as, checker=_always_executable)
_register_implementation(ltorch.empty, empty, checker=_always_executable, execution_transform=_empty_transform)

#
# Memory format operations
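For reference, the call that _empty_transform ultimately forwards to the torch executor looks like the plain-PyTorch snippet below. The torch-level symbol added in thunder/torch/__init__.py (further down in this diff) only permits the default values for layout, requires_grad, pin_memory, and memory_format, so in practice the executed call takes this shape.

import torch

# All keyword arguments besides device and dtype are constrained to these
# defaults by the checks in the torch-level empty symbol.
t = torch.empty(
    (2, 3),
    device=torch.device("cpu"),
    dtype=torch.float32,
    layout=torch.strided,
    requires_grad=False,
    pin_memory=False,
    memory_format=torch.contiguous_format,
)
print(t.shape, t.dtype, t.device)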
30 changes: 30 additions & 0 deletions thunder/tests/opinfos.py
@@ -5394,6 +5394,36 @@ def full_like_sample_generator(op, device, dtype, requires_grad, **kwargs):
tensor_creation_ops.append(full_like_opinfo)


def empty_sample_generator(op, device, dtype, requires_grad, **kwargs):
# shape
cases = (
(),
(4, 4),
(8, 1, 6),
(8, 7, 5, 1),
)

for shape in cases:
yield SampleInput(shape, device=device, dtype=dtype)


def empty_error_generator(op, device, **kwargs):
err_msg = "Can't safely cast fill_value of numbertype <class 'complex'> to dtype float32"
yield (SampleInput((1, 2), 1j, device=device, dtype=torch.float), RuntimeError, err_msg)


# Helper function for `empty` opinfo.
# It always returns zero tensors, so that the consistency tests pass.
def torch_empty_and_zero(*args, **kwargs):
return ltorch.full_like(ltorch.empty(*args, **kwargs), 0)


empty_opinfo = OpInfo(
name="empty", op=torch_empty_and_zero, sample_input_generator=empty_sample_generator, torch_reference=torch.zeros
)
tensor_creation_ops.append(empty_opinfo)


def fixed_value_tensor_creation_op_sample_generator(op, device, dtype, requires_grad, **kwargs):
# shape
cases = (
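The torch_empty_and_zero helper above exists because torch.empty returns uninitialized memory, which cannot be compared element-wise against a reference. A plain-PyTorch sketch of the same idea (torch_empty_then_zero is a hypothetical name used only for illustration):

import torch

def torch_empty_then_zero(*args, **kwargs):
    # Overwrite the uninitialized buffer with zeros so the output is
    # deterministic and can be checked against torch.zeros.
    return torch.full_like(torch.empty(*args, **kwargs), 0)

assert torch.equal(torch_empty_then_zero(2, 3), torch.zeros(2, 3))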
40 changes: 40 additions & 0 deletions thunder/torch/__init__.py
@@ -690,6 +690,46 @@ def zeros_like(a: TensorLike, /, *, device: DeviceLike | None = None, dtype: dty
return full_like(a, 0, device=device, dtype=dtype)


@torchsymbol(torch.empty)
def empty(
*size: int,
device: None | DeviceLike = None,
dtype: None | dtypeLike = None,
out: None | TensorLike = None,
layout: torch.layout = torch.strided,
requires_grad: bool = False,
pin_memory: bool = False,
memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLike:
size = utils.extract_shape_from_varargs(size)

utils.check(out is None, lambda: "empty(): out is not None which is currently unsupported", NotImplementedError)
utils.check(layout == torch.strided, lambda: "Only torch.strided layout is supported", NotImplementedError)
utils.check(
not requires_grad, lambda: "requires_grad=True is not yet supported within thunder.compile", NotImplementedError
)
utils.check(not pin_memory, lambda: "pin_memory=True is not supported within thunder.compile", NotImplementedError)
utils.check(
memory_format == torch.contiguous_format,
lambda: "Only torch.contiguous_format is supported",
NotImplementedError,
)

# For now we default to `float32`,
# however, we should add a default dtype or rely on `torch.get_default_dtype`.
if dtype is None:
dtype = torch.float
dtype = to_dtype(dtype)

# For now we default to "cpu",
# however, we should add a default device or rely on `torch.get_default_device`.
if device is None:
device = "cpu"
device = to_device(device)

return clang.empty(size, device=device, dtype=dtype)


#
# Shape operations
#
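Putting it together, a minimal usage sketch of the new symbol, assuming the public thunder.jit entry point; since the values are uninitialized, only metadata is checked.

import torch
import thunder

def fn():
    return torch.empty(4, 4, dtype=torch.float32, device="cpu")

jfn = thunder.jit(fn)
t = jfn()
# Values are uninitialized; only shape, dtype, and device are guaranteed.
assert t.shape == (4, 4)
assert t.dtype == torch.float32
assert t.device.type == "cpu"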
