Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
  • Loading branch information
JanuszL committed Aug 24, 2023
1 parent dadef03 commit 5416660
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 26 deletions.
6 changes: 3 additions & 3 deletions dali/operators/reader/loader/webdataset_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
index_paths_(spec.GetRepeatedArgument<std::string>("index_paths")),
missing_component_behavior_(detail::wds::ParseMissingExtBehavior(
spec.GetArgument<std::string>("missing_component_behavior"))),
case_insensitive_extensions_(spec.GetArgument<bool>("case_insensitive_extensions")) {
case_sensitive_extensions_(spec.GetArgument<bool>("case_sensitive_extensions")) {
DALI_ENFORCE(paths_.size() == index_paths_.size() || index_paths_.size() == 0,
make_string("The number of index files, if any, must match the number of archives ",
"in the dataset"));
Expand All @@ -242,7 +242,7 @@ WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
std::string ext;
ext_.emplace_back();
while (std::getline(exts_stream, ext, detail::wds::kExtDelim)) {
if (case_insensitive_extensions_) {
if (!case_sensitive_extensions_) {
ext = str_tolower(ext);
}
if (!ext_.back().count(ext)) {
Expand Down Expand Up @@ -422,7 +422,7 @@ void WebdatasetLoader::PrepareMetadataImpl() {
component.outputs =
detail::wds::VectorRange<size_t>(output_indicies_, output_indicies_.size());
auto ext = component.ext;
if (case_insensitive_extensions_) {
if (!case_sensitive_extensions_) {
ext = str_tolower(ext);
}
for (auto& output : ext_map[ext]) {
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/reader/loader/webdataset_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class DLL_PUBLIC WebdatasetLoader : public Loader<CPUBackend, vector<Tensor<CPUB

bool generate_index_ = true;
std::string GetSampleSource(const detail::wds::SampleDesc& sample);
bool case_insensitive_extensions_ = false;
bool case_sensitive_extensions_ = true;
};

} // namespace dali
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/reader/webdataset_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ with a semicolon (';') and may contain dots.
Example: "left.png;right.jpg")code",
DALI_STRING_VEC)
.AddOptionalArg("case_insensitive_extensions",
R"code(Determines whether the extensions provided via the `ext` should be case insensitive.
.AddOptionalArg("case_sensitive_extensions",
R"code(Determines whether the extensions provided via the `ext` should be case sensitive.
Allows mixing case sizes in the `ext` argument as well as in the webdataset container. For example
when turned on: jpg, JPG, jPG should work.
If the extension characters cannot be represented as ASCI the result of this option is undefined.
)code", false)
)code", true)
.AddOptionalArg("index_paths",
R"code(The list of the index files corresponding to the respective webdataset archives.
Expand Down
23 changes: 6 additions & 17 deletions dali/test/python/reader/test_webdataset_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@


def test_return_empty():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand Down Expand Up @@ -54,7 +53,6 @@ def test_return_empty():


def test_skip_sample():
global test_batch_size
num_samples = 500
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand Down Expand Up @@ -96,7 +94,6 @@ def test_skip_sample():


def test_raise_error_on_missing():
global test_batch_size
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
index_file = generate_temp_index_file(tar_file_path)
wds_pipeline = webdataset_raw_pipeline(
Expand All @@ -112,7 +109,6 @@ def test_raise_error_on_missing():


def test_different_components():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/scrambled.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand All @@ -139,7 +135,6 @@ def test_different_components():


def test_dtypes():
global test_batch_size
num_samples = 100
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/dtypes.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand All @@ -163,7 +158,6 @@ def test_dtypes():


def test_wds_sharding():
global test_batch_size
num_samples = 3000
tar_file_paths = [
os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
Expand Down Expand Up @@ -203,7 +197,6 @@ def test_wds_sharding():


def test_sharding():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand Down Expand Up @@ -240,7 +233,6 @@ def test_sharding():


def test_pax_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
pax_tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/pax/devel-0.tar")
Expand Down Expand Up @@ -275,15 +267,14 @@ def test_pax_format():


def test_case_sensitive_container_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
case_insensitive_tar_file_path = os.path.join(get_dali_extra_path(),
"db/webdataset/case_insensitive/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
with assert_raises(AssertionError):
with assert_raises(RuntimeError, glob="Underful sample detected at"):
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
Expand All @@ -300,6 +291,7 @@ def test_case_sensitive_container_format():
case_insensitive_tar_file_path,
None,
ext=["jpg", "cls"],
missing_component_behavior="error",
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
Expand All @@ -312,13 +304,12 @@ def test_case_sensitive_container_format():


def test_case_sensitive_arg_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
with assert_raises(AssertionError):
with assert_raises(RuntimeError, glob="Underful sample detected at"):
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
Expand All @@ -335,6 +326,7 @@ def test_case_sensitive_arg_format():
tar_file_path,
index_file.name,
ext=["Jpg", "cls"],
missing_component_behavior="error",
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
Expand All @@ -347,7 +339,6 @@ def test_case_sensitive_arg_format():


def test_case_insensitive_container_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
case_insensitive_tar_file_path = os.path.join(get_dali_extra_path(),
Expand All @@ -371,7 +362,7 @@ def test_case_insensitive_container_format():
case_insensitive_tar_file_path,
None,
ext=["jpg", "cls"],
case_insensitive_extensions=True,
case_sensitive_extensions=False,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
Expand All @@ -384,7 +375,6 @@ def test_case_insensitive_container_format():


def test_case_insensitive_arg_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand All @@ -406,7 +396,7 @@ def test_case_insensitive_arg_format():
tar_file_path,
index_file.name,
ext=["Jpg", "cls"],
case_insensitive_extensions=True,
case_sensitive_extensions=False,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
Expand All @@ -419,7 +409,6 @@ def test_case_insensitive_arg_format():


def test_index_generation():
global test_batch_size
num_samples = 3000
tar_file_paths = [
os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/webdataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def webdataset_raw_pipeline(
paths,
index_paths,
ext,
case_insensitive_extensions=False,
case_sensitive_extensions=True,
missing_component_behavior="empty",
dtypes=None,
dont_use_mmap=False,
Expand All @@ -45,7 +45,7 @@ def webdataset_raw_pipeline(
paths=paths,
index_paths=index_paths,
ext=ext,
case_insensitive_extensions=case_insensitive_extensions,
case_sensitive_extensions=case_sensitive_extensions,
missing_component_behavior=missing_component_behavior,
dtypes=dtypes,
dont_use_mmap=dont_use_mmap,
Expand Down

0 comments on commit 5416660

Please sign in to comment.