Skip to content

Commit

Permalink
fix creation_ops.torch.full for paddle frontend (ivy-llc#28267)
Browse files Browse the repository at this point in the history
Co-authored-by: NripeshN <86844847+NripeshN@users.noreply.github.com>
  • Loading branch information
2 people authored and Kacper-W-Kozdon committed Feb 27, 2024
1 parent 9583755 commit b726519
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
11 changes: 11 additions & 0 deletions ivy/functional/backends/paddle/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,17 @@ def from_dlpack(x, /, *, out: Optional[paddle.Tensor] = None):
return paddle.utils.dlpack.from_dlpack(capsule)


@with_unsupported_device_and_dtypes(
{
"2.6.0 and below": {
"cpu": (
"complex",
"bool",
)
}
},
backend_version,
)
def full(
shape: Union[ivy.NativeShape, Sequence[int]],
fill_value: Union[int, float, bool],
Expand Down
9 changes: 7 additions & 2 deletions ivy/functional/backends/torch/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def from_dlpack(x, /, *, out: Optional[torch.Tensor] = None):
return torch.from_dlpack(x)


@with_unsupported_dtypes({"2.2.0 and below": ("bfloat16",)}, backend_version)
def full(
shape: Union[ivy.NativeShape, Sequence[int]],
fill_value: Union[int, float, bool],
Expand All @@ -247,9 +248,13 @@ def full(
dtype = ivy.default_dtype(dtype=dtype, item=fill_value, as_native=True)
if isinstance(shape, int):
shape = (shape,)

shape = tuple(int(dim) for dim in shape)
fill_value = torch.tensor(fill_value, dtype=dtype)

return torch.full(
shape,
fill_value,
size=shape,
fill_value=fill_value,
dtype=dtype,
device=device,
out=out,
Expand Down
1 change: 1 addition & 0 deletions ivy/functional/frontends/torch/creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def frombuffer(
return ivy.frombuffer(buffer, dtype=dtype, count=count, offset=offset)


@with_unsupported_dtypes({"2.2.0 and below": ("bfloat16",)}, "torch")
@to_ivy_arrays_and_back
def full(
size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _as_tensor_helper(draw):
@st.composite
def _fill_value(draw):
with_array = draw(st.sampled_from([True, False]))
dtype = draw(st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"))[0]
dtype = draw(st.shared(helpers.get_dtypes("valid", full=False), key="dtype"))[0]
with BackendHandler.update_backend(test_globals.CURRENT_BACKEND) as ivy_backend:
if ivy_backend.is_uint_dtype(dtype):
ret = draw(helpers.ints(min_value=0, max_value=5))
Expand Down Expand Up @@ -512,7 +512,7 @@ def test_torch_frombuffer(
max_dim_size=10,
),
fill_value=_fill_value(),
dtype=st.shared(helpers.get_dtypes("numeric", full=False), key="dtype"),
dtype=st.shared(helpers.get_dtypes("valid", full=False), key="dtype"),
)
def test_torch_full(
*,
Expand Down

0 comments on commit b726519

Please sign in to comment.