From 6b5ab3fcc75bcfc3d5ef05b8a3ae67044c15a5d6 Mon Sep 17 00:00:00 2001 From: McCrearyD Date: Wed, 7 Jul 2021 15:23:54 -0700 Subject: [PATCH] add test failure case to make sure scalar tensors remain as such --- hub/api/tests/test_api.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/hub/api/tests/test_api.py b/hub/api/tests/test_api.py index 2210352dfc..5e7e9399e0 100644 --- a/hub/api/tests/test_api.py +++ b/hub/api/tests/test_api.py @@ -7,7 +7,7 @@ from hub.api.dataset import Dataset from hub.core.tests.common import parametrize_all_dataset_storages from hub.tests.common import assert_array_lists_equal -from hub.util.exceptions import TensorDtypeMismatchError +from hub.util.exceptions import TensorDtypeMismatchError, TensorInvalidSampleShapeError from hub.client.client import HubBackendClient from hub.client.utils import has_hub_testing_creds @@ -239,6 +239,16 @@ def test_scalar_samples(ds: Dataset): assert tensor.numpy(aslist=True) == expected.tolist() + assert tensor.shape == (11,) + + # len(shape) for a scalar is `()`. len(shape) for [1] is `(1,)` + with pytest.raises(TensorInvalidSampleShapeError): + tensor.append([1]) + + # len(shape) for a scalar is `()`. len(shape) for [1, 2] is `(2,)` + with pytest.raises(TensorInvalidSampleShapeError): + tensor.append([1, 2]) + @parametrize_all_dataset_storages def test_sequence_samples(ds: Dataset):