diff --git a/hub/__init__.py b/hub/__init__.py index 112f27218c..d834f4f18e 100644 --- a/hub/__init__.py +++ b/hub/__init__.py @@ -35,7 +35,7 @@ "ingest_kaggle", ] -__version__ = "2.0.5" +__version__ = "2.0.6" __encoded_version__ = np.array(__version__) hub_reporter.tags.append(f"version:{__version__}") diff --git a/hub/core/transform/test_transform.py b/hub/core/transform/test_transform.py index cc5a5dd225..a1833e4ced 100644 --- a/hub/core/transform/test_transform.py +++ b/hub/core/transform/test_transform.py @@ -4,6 +4,7 @@ from click.testing import CliRunner from hub.core.storage.memory import MemoryProvider from hub.util.remove_cache import remove_memory_cache +from hub.tests.common import parametrize_num_workers from hub.tests.dataset_fixtures import enabled_datasets from hub.util.exceptions import InvalidOutputDatasetError @@ -71,7 +72,8 @@ def test_single_transform_hub_dataset(ds): @enabled_datasets -def test_single_transform_hub_dataset_htypes(ds): +@parametrize_num_workers +def test_single_transform_hub_dataset_htypes(ds, num_workers): with CliRunner().isolated_filesystem(): with hub.dataset("./test/transform_hub_in_htypes") as data_in: data_in.create_tensor("image", htype="image", sample_compression="png") @@ -83,7 +85,7 @@ def test_single_transform_hub_dataset_htypes(ds): ds_out = ds ds_out.create_tensor("image") ds_out.create_tensor("label") - fn2(copy=1, mul=2).eval(data_in, ds_out, num_workers=5) + fn2(copy=1, mul=2).eval(data_in, ds_out, num_workers=num_workers) assert len(ds_out) == 99 for index in range(1, 100): np.testing.assert_array_equal( @@ -104,7 +106,7 @@ def test_chain_transform_list_small(ds): ds_out.create_tensor("image") ds_out.create_tensor("label") pipeline = hub.compose([fn1(mul=5, copy=2), fn2(mul=3, copy=3)]) - pipeline.eval(ls, ds_out, num_workers=5) + pipeline.eval(ls, ds_out, num_workers=3) assert len(ds_out) == 600 for i in range(100): for index in range(6 * i, 6 * i + 6): diff --git a/hub/core/transform/transform.py b/hub/core/transform/transform.py index 4d6770c207..004de93661 100644 --- a/hub/core/transform/transform.py +++ b/hub/core/transform/transform.py @@ -122,6 +122,7 @@ def run( """Runs the pipeline on the input data to produce output samples and stores in the dataset. This receives arguments processed and sanitized by the Pipeline.eval method. """ + num_workers = max(num_workers, 1) size = math.ceil(len(data_in) / num_workers) slices = [data_in[i * size : (i + 1) * size] for i in range(num_workers)] diff --git a/hub/requirements/common.txt b/hub/requirements/common.txt index 47824d4137..e8f7dd16fb 100644 --- a/hub/requirements/common.txt +++ b/hub/requirements/common.txt @@ -6,4 +6,4 @@ pathos humbug>=0.2.6 types-requests types-click -tqdm +tqdm \ No newline at end of file diff --git a/hub/requirements/requirements.txt b/hub/requirements/requirements.txt index 0c39d0a344..3acb88a2c9 100644 --- a/hub/requirements/requirements.txt +++ b/hub/requirements/requirements.txt @@ -7,4 +7,4 @@ numcodecs~=0.7.3 Pillow~=8.2.0 lz4~=3.1.3 zstd~=1.4.5 -requests~=2.25.1 +requests~=2.25.1 \ No newline at end of file diff --git a/hub/tests/common.py b/hub/tests/common.py index 67fe130261..d24032e857 100644 --- a/hub/tests/common.py +++ b/hub/tests/common.py @@ -20,8 +20,10 @@ NUM_BATCHES_PARAM = "num_batches" DTYPE_PARAM = "dtype" CHUNK_SIZE_PARAM = "chunk_size" +NUM_WORKERS_PARAM = "num_workers" NUM_BATCHES = (1, 5) +NUM_WORKERS = (0, 1, 2, 4) CHUNK_SIZES = ( 1 * KB, @@ -39,6 +41,7 @@ parametrize_chunk_sizes = pytest.mark.parametrize(CHUNK_SIZE_PARAM, CHUNK_SIZES) parametrize_dtypes = pytest.mark.parametrize(DTYPE_PARAM, DTYPES) parametrize_num_batches = pytest.mark.parametrize(NUM_BATCHES_PARAM, NUM_BATCHES) +parametrize_num_workers = pytest.mark.parametrize(NUM_WORKERS_PARAM, NUM_WORKERS) def current_test_name() -> str: diff --git a/hub/util/transform.py b/hub/util/transform.py index b1d311cc84..1da744cb72 100644 --- a/hub/util/transform.py +++ b/hub/util/transform.py @@ -172,8 +172,11 @@ def check_transform_data_in(data_in, scheduler: str) -> None: f"The data_in to transform is invalid. It should support __len__ operation." ) if isinstance(data_in, hub.core.dataset.Dataset): - base_storage = get_base_storage(data_in.storage) - if isinstance(base_storage, MemoryProvider) and scheduler != "threaded": + input_base_storage = get_base_storage(data_in.storage) + if isinstance(input_base_storage, MemoryProvider) and scheduler not in [ + "serial", + "threaded", + ]: raise InvalidOutputDatasetError( f"Transforms with data_in as a Dataset having base storage as MemoryProvider are only supported in threaded and serial mode. Current mode is {scheduler}." ) @@ -191,7 +194,10 @@ def check_transform_ds_out(ds_out: hub.core.dataset.Dataset, scheduler: str) -> ) output_base_storage = get_base_storage(ds_out.storage) - if isinstance(output_base_storage, MemoryProvider) and scheduler != "threaded": + if isinstance(output_base_storage, MemoryProvider) and scheduler not in [ + "serial", + "threaded", + ]: raise InvalidOutputDatasetError( f"Transforms with ds_out having base storage as MemoryProvider are only supported in threaded and serial mode. Current mode is {scheduler}." )