Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/num workers 0 bug #1112

Merged
merged 10 commits into from
Aug 11, 2021
2 changes: 1 addition & 1 deletion hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"ingest_kaggle",
]

__version__ = "2.0.5"
__version__ = "2.0.6"
__encoded_version__ = np.array(__version__)

hub_reporter.tags.append(f"version:{__version__}")
Expand Down
8 changes: 5 additions & 3 deletions hub/core/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from click.testing import CliRunner
from hub.core.storage.memory import MemoryProvider
from hub.util.remove_cache import remove_memory_cache
from hub.tests.common import parametrize_num_workers
from hub.tests.dataset_fixtures import enabled_datasets
from hub.util.exceptions import InvalidOutputDatasetError

Expand Down Expand Up @@ -71,7 +72,8 @@ def test_single_transform_hub_dataset(ds):


@enabled_datasets
def test_single_transform_hub_dataset_htypes(ds):
@parametrize_num_workers
def test_single_transform_hub_dataset_htypes(ds, num_workers):
with CliRunner().isolated_filesystem():
with hub.dataset("./test/transform_hub_in_htypes") as data_in:
data_in.create_tensor("image", htype="image", sample_compression="png")
Expand All @@ -83,7 +85,7 @@ def test_single_transform_hub_dataset_htypes(ds):
ds_out = ds
ds_out.create_tensor("image")
ds_out.create_tensor("label")
fn2(copy=1, mul=2).eval(data_in, ds_out, num_workers=5)
fn2(copy=1, mul=2).eval(data_in, ds_out, num_workers=num_workers)
assert len(ds_out) == 99
for index in range(1, 100):
np.testing.assert_array_equal(
Expand All @@ -104,7 +106,7 @@ def test_chain_transform_list_small(ds):
ds_out.create_tensor("image")
ds_out.create_tensor("label")
pipeline = hub.compose([fn1(mul=5, copy=2), fn2(mul=3, copy=3)])
pipeline.eval(ls, ds_out, num_workers=5)
pipeline.eval(ls, ds_out, num_workers=3)
assert len(ds_out) == 600
for i in range(100):
for index in range(6 * i, 6 * i + 6):
Expand Down
1 change: 1 addition & 0 deletions hub/core/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def run(
"""Runs the pipeline on the input data to produce output samples and stores in the dataset.
This receives arguments processed and sanitized by the Pipeline.eval method.
"""
num_workers = max(num_workers, 1)
size = math.ceil(len(data_in) / num_workers)
slices = [data_in[i * size : (i + 1) * size] for i in range(num_workers)]

Expand Down
2 changes: 1 addition & 1 deletion hub/requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pathos
humbug>=0.2.6
types-requests
types-click
tqdm
tqdm
2 changes: 1 addition & 1 deletion hub/requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ numcodecs~=0.7.3
Pillow~=8.2.0
lz4~=3.1.3
zstd~=1.4.5
requests~=2.25.1
requests~=2.25.1
3 changes: 3 additions & 0 deletions hub/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
NUM_BATCHES_PARAM = "num_batches"
DTYPE_PARAM = "dtype"
CHUNK_SIZE_PARAM = "chunk_size"
NUM_WORKERS_PARAM = "num_workers"

NUM_BATCHES = (1, 5)
NUM_WORKERS = (0, 1, 2, 4)

CHUNK_SIZES = (
1 * KB,
Expand All @@ -39,6 +41,7 @@
parametrize_chunk_sizes = pytest.mark.parametrize(CHUNK_SIZE_PARAM, CHUNK_SIZES)
parametrize_dtypes = pytest.mark.parametrize(DTYPE_PARAM, DTYPES)
parametrize_num_batches = pytest.mark.parametrize(NUM_BATCHES_PARAM, NUM_BATCHES)
parametrize_num_workers = pytest.mark.parametrize(NUM_WORKERS_PARAM, NUM_WORKERS)


def current_test_name() -> str:
Expand Down
12 changes: 9 additions & 3 deletions hub/util/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,11 @@ def check_transform_data_in(data_in, scheduler: str) -> None:
f"The data_in to transform is invalid. It should support __len__ operation."
)
if isinstance(data_in, hub.core.dataset.Dataset):
base_storage = get_base_storage(data_in.storage)
if isinstance(base_storage, MemoryProvider) and scheduler != "threaded":
input_base_storage = get_base_storage(data_in.storage)
if isinstance(input_base_storage, MemoryProvider) and scheduler not in [
"serial",
"threaded",
]:
raise InvalidOutputDatasetError(
f"Transforms with data_in as a Dataset having base storage as MemoryProvider are only supported in threaded and serial mode. Current mode is {scheduler}."
)
Expand All @@ -191,7 +194,10 @@ def check_transform_ds_out(ds_out: hub.core.dataset.Dataset, scheduler: str) ->
)

output_base_storage = get_base_storage(ds_out.storage)
if isinstance(output_base_storage, MemoryProvider) and scheduler != "threaded":
if isinstance(output_base_storage, MemoryProvider) and scheduler not in [
"serial",
"threaded",
]:
raise InvalidOutputDatasetError(
f"Transforms with ds_out having base storage as MemoryProvider are only supported in threaded and serial mode. Current mode is {scheduler}."
)