diff --git a/hub/api/tests/test_api.py b/hub/api/tests/test_api.py index 2210352dfc..5e7e9399e0 100644 --- a/hub/api/tests/test_api.py +++ b/hub/api/tests/test_api.py @@ -7,7 +7,7 @@ from hub.api.dataset import Dataset from hub.core.tests.common import parametrize_all_dataset_storages from hub.tests.common import assert_array_lists_equal -from hub.util.exceptions import TensorDtypeMismatchError +from hub.util.exceptions import TensorDtypeMismatchError, TensorInvalidSampleShapeError from hub.client.client import HubBackendClient from hub.client.utils import has_hub_testing_creds @@ -239,6 +239,16 @@ def test_scalar_samples(ds: Dataset): assert tensor.numpy(aslist=True) == expected.tolist() + assert tensor.shape == (11,) + + # len(shape) for a scalar is `()`. len(shape) for [1] is `(1,)` + with pytest.raises(TensorInvalidSampleShapeError): + tensor.append([1]) + + # len(shape) for a scalar is `()`. len(shape) for [1, 2] is `(2,)` + with pytest.raises(TensorInvalidSampleShapeError): + tensor.append([1, 2]) + @parametrize_all_dataset_storages def test_sequence_samples(ds: Dataset):