Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
  • Loading branch information
JanuszL committed Aug 24, 2023
1 parent dadef03 commit 6f8678a
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions dali/operators/reader/loader/webdataset_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
index_paths_(spec.GetRepeatedArgument<std::string>("index_paths")),
missing_component_behavior_(detail::wds::ParseMissingExtBehavior(
spec.GetArgument<std::string>("missing_component_behavior"))),
case_insensitive_extensions_(spec.GetArgument<bool>("case_insensitive_extensions")) {
case_sensitive_extensions_(spec.GetArgument<bool>("case_sensitive_extensions")) {
DALI_ENFORCE(paths_.size() == index_paths_.size() || index_paths_.size() == 0,
make_string("The number of index files, if any, must match the number of archives ",
"in the dataset"));
Expand All @@ -242,7 +242,7 @@ WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
std::string ext;
ext_.emplace_back();
while (std::getline(exts_stream, ext, detail::wds::kExtDelim)) {
if (case_insensitive_extensions_) {
if (!case_sensitive_extensions_) {
ext = str_tolower(ext);
}
if (!ext_.back().count(ext)) {
Expand Down Expand Up @@ -422,7 +422,7 @@ void WebdatasetLoader::PrepareMetadataImpl() {
component.outputs =
detail::wds::VectorRange<size_t>(output_indicies_, output_indicies_.size());
auto ext = component.ext;
if (case_insensitive_extensions_) {
if (!case_sensitive_extensions_) {
ext = str_tolower(ext);
}
for (auto& output : ext_map[ext]) {
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/reader/loader/webdataset_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class DLL_PUBLIC WebdatasetLoader : public Loader<CPUBackend, vector<Tensor<CPUB

bool generate_index_ = true;
std::string GetSampleSource(const detail::wds::SampleDesc& sample);
bool case_insensitive_extensions_ = false;
bool case_sensitive_extensions_ = true;
};

} // namespace dali
Expand Down
6 changes: 3 additions & 3 deletions dali/operators/reader/webdataset_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ with a semicolon (';') and may contain dots.
Example: "left.png;right.jpg")code",
DALI_STRING_VEC)
.AddOptionalArg("case_insensitive_extensions",
R"code(Determines whether the extensions provided via the `ext` should be case insensitive.
.AddOptionalArg("case_sensitive_extensions",
R"code(Determines whether the extensions provided via the `ext` should be case sensitive.
Allows mixing case sizes in the `ext` argument as well as in the webdataset container. For example
when turned on: jpg, JPG, jPG should work.
If the extension characters cannot be represented as ASCI the result of this option is undefined.
)code", false)
)code", true)
.AddOptionalArg("index_paths",
R"code(The list of the index files corresponding to the respective webdataset archives.
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/reader/test_webdataset_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def test_case_insensitive_container_format():
case_insensitive_tar_file_path,
None,
ext=["jpg", "cls"],
case_insensitive_extensions=True,
case_sensitive_extensions=True,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
Expand Down Expand Up @@ -406,7 +406,7 @@ def test_case_insensitive_arg_format():
tar_file_path,
index_file.name,
ext=["Jpg", "cls"],
case_insensitive_extensions=True,
case_sensitive_extensions=True,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
Expand Down
4 changes: 2 additions & 2 deletions dali/test/python/webdataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def webdataset_raw_pipeline(
paths,
index_paths,
ext,
case_insensitive_extensions=False,
case_sensitive_extensions=True,
missing_component_behavior="empty",
dtypes=None,
dont_use_mmap=False,
Expand All @@ -45,7 +45,7 @@ def webdataset_raw_pipeline(
paths=paths,
index_paths=index_paths,
ext=ext,
case_insensitive_extensions=case_insensitive_extensions,
case_sensitive_extensions=case_sensitive_extensions,
missing_component_behavior=missing_component_behavior,
dtypes=dtypes,
dont_use_mmap=dont_use_mmap,
Expand Down

0 comments on commit 6f8678a

Please sign in to comment.