diff --git a/deeplake/enterprise/test_pytorch.py b/deeplake/enterprise/test_pytorch.py index a803d1e1c0..f22757d4ae 100644 --- a/deeplake/enterprise/test_pytorch.py +++ b/deeplake/enterprise/test_pytorch.py @@ -6,7 +6,11 @@ from deeplake.util.remove_cache import get_base_storage from deeplake.core.index.index import IndexEntry -from deeplake.tests.common import requires_torch, requires_libdeeplake +from deeplake.tests.common import ( + requires_torch, + requires_libdeeplake, + convert_data_according_to_torch_version, +) from deeplake.core.dataset import Dataset from deeplake.constants import KB @@ -458,7 +462,7 @@ def test_pytorch_decode(hub_cloud_ds, compressed_image_paths, compression): ptds = hub_cloud_ds.dataloader().pytorch(decode_method={"image": "tobytes"}) for i, batch in enumerate(ptds): - image = batch["image"] + image = convert_data_according_to_torch_version(batch["image"]) assert isinstance(image, bytes) if i < 5 and not compression: np.testing.assert_array_equal( diff --git a/deeplake/integrations/pytorch/shuffle_buffer.py b/deeplake/integrations/pytorch/shuffle_buffer.py index 9793a45a1a..08f13f7e8b 100644 --- a/deeplake/integrations/pytorch/shuffle_buffer.py +++ b/deeplake/integrations/pytorch/shuffle_buffer.py @@ -118,6 +118,10 @@ def _num_torch_tensors(self, sample): return 0 if isinstance(sample, TorchTensor): return 1 + elif isinstance(sample, bytes): + return 0 + elif isinstance(sample, str): + return 0 elif isinstance(sample, dict): return sum(self._num_torch_tensors(tensor) for tensor in sample.values()) elif isinstance(sample, Sequence): diff --git a/deeplake/requirements/tests.txt b/deeplake/requirements/tests.txt index 1ccbf435d9..c4de825ac1 100644 --- a/deeplake/requirements/tests.txt +++ b/deeplake/requirements/tests.txt @@ -16,3 +16,5 @@ boto3-stubs[essential] lz4 rich wandb + +pandas; python_version >= '3.11' and sys_platform == 'win32' diff --git a/deeplake/tests/common.py b/deeplake/tests/common.py index eaaa5f6397..fac807d1d2 100644 --- a/deeplake/tests/common.py +++ b/deeplake/tests/common.py @@ -142,9 +142,7 @@ def __exit__(self, *args, **kwargs): def convert_data_according_to_torch_version(batch): - import torch - - if torch.__version__ < "2.0.0": + if isinstance(batch, List): return batch[0] else: return batch