Skip to content

Commit

Permalink
Add opt-in support for case insensitive webdataset
Browse files Browse the repository at this point in the history
- The current implementation of webdataset is case-sensitive when
  it comes to file extensions. This PR adds an option to make
  it case-insensitive, both for the extensions provided by the user
  and for the file extensions in the container.

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
  • Loading branch information
JanuszL committed Aug 23, 2023
1 parent 301d1a6 commit dadef03
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 3 deletions.
2 changes: 1 addition & 1 deletion DALI_EXTRA_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
a6765e17a8297ee2ddc98e2174e99d531a44390b
FIXME
18 changes: 16 additions & 2 deletions dali/operators/reader/loader/webdataset_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,19 @@ inline std::string SupportedTypesListGen() {
return out_str.substr(0, out_str.size() - 2 * (detail::wds::kSupportedTypes.size() > 0));
}

// Returns a copy of `s` in which every character has been passed through std::tolower.
// Each char is widened via unsigned char first, because std::tolower is only defined
// for values representable as unsigned char (or EOF).
std::string str_tolower(std::string s) {
  for (char &ch : s) {
    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
  }
  return s;
}

WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
: Loader(spec),
paths_(spec.GetRepeatedArgument<std::string>("paths")),
index_paths_(spec.GetRepeatedArgument<std::string>("index_paths")),
missing_component_behavior_(detail::wds::ParseMissingExtBehavior(
spec.GetArgument<std::string>("missing_component_behavior"))) {
spec.GetArgument<std::string>("missing_component_behavior"))),
case_insensitive_extensions_(spec.GetArgument<bool>("case_insensitive_extensions")) {
DALI_ENFORCE(paths_.size() == index_paths_.size() || index_paths_.size() == 0,
make_string("The number of index files, if any, must match the number of archives ",
"in the dataset"));
Expand All @@ -235,6 +242,9 @@ WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
std::string ext;
ext_.emplace_back();
while (std::getline(exts_stream, ext, detail::wds::kExtDelim)) {
if (case_insensitive_extensions_) {
ext = str_tolower(ext);
}
if (!ext_.back().count(ext)) {
ext_.back().insert(ext);
}
Expand Down Expand Up @@ -411,7 +421,11 @@ void WebdatasetLoader::PrepareMetadataImpl() {
for (auto& component : sample.components) {
component.outputs =
detail::wds::VectorRange<size_t>(output_indicies_, output_indicies_.size());
for (auto& output : ext_map[component.ext]) {
auto ext = component.ext;
if (case_insensitive_extensions_) {
ext = str_tolower(ext);
}
for (auto& output : ext_map[ext]) {
if (!was_output_set[output]) {
DALI_ENFORCE(component.size % dtype_sizes_[output] == 0,
make_string("Error in index file at ", GetSampleSource(new_sample),
Expand Down
1 change: 1 addition & 0 deletions dali/operators/reader/loader/webdataset_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class DLL_PUBLIC WebdatasetLoader : public Loader<CPUBackend, vector<Tensor<CPUB

bool generate_index_ = true;
std::string GetSampleSource(const detail::wds::SampleDesc& sample);
bool case_insensitive_extensions_ = false;
};

} // namespace dali
Expand Down
8 changes: 8 additions & 0 deletions dali/operators/reader/webdataset_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ with a semicolon (';') and may contain dots.
Example: "left.png;right.jpg")code",
DALI_STRING_VEC)
.AddOptionalArg("case_insensitive_extensions",
R"code(Determines whether the extensions provided via the `ext` should be case insensitive.
Allows mixing case sizes in the `ext` argument as well as in the webdataset container. For example
when turned on: jpg, JPG, jPG should work.
If the extension characters cannot be represented as ASCII the result of this option is undefined.
)code", false)
.AddOptionalArg("index_paths",
R"code(The list of the index files corresponding to the respective webdataset archives.
Expand Down
144 changes: 144 additions & 0 deletions dali/test/python/reader/test_webdataset_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,150 @@ def test_pax_format():
)


def test_case_sensitive_container_format():
    # With case_insensitive_extensions left at its default, reading the archive under
    # db/webdataset/case_insensitive with ext=["jpg", "cls"] is expected NOT to match the
    # MNIST reference archive: compare_pipelines must raise AssertionError.
    # NOTE(review): presumably the case_insensitive archive stores the same samples with
    # mixed-case file extensions - confirm against DALI_extra.
    global test_batch_size
    sample_count = 1000
    shard_count = 100

    reference_tar = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
    mixed_case_tar = os.path.join(
        get_dali_extra_path(), "db/webdataset/case_insensitive/devel-0.tar"
    )
    reference_index = generate_temp_index_file(reference_tar)

    with assert_raises(AssertionError):
        for current_shard in range(shard_count):
            # Settings shared by the reference pipeline and the pipeline under test.
            shared_kwargs = dict(
                num_shards=shard_count,
                shard_id=current_shard,
                batch_size=test_batch_size,
                device_id=0,
                num_threads=1,
            )
            expected_pipe = webdataset_raw_pipeline(
                reference_tar, reference_index.name, ["jpg", "cls"], **shared_kwargs
            )
            # No index file is passed for the second archive.
            actual_pipe = webdataset_raw_pipeline(
                mixed_case_tar, None, ext=["jpg", "cls"], **shared_kwargs
            )
            compare_pipelines(
                expected_pipe,
                actual_pipe,
                test_batch_size,
                math.ceil(sample_count / shard_count / test_batch_size) * 2,
            )


def test_case_sensitive_arg_format():
    # With case_insensitive_extensions left at its default, requesting ext=["Jpg", "cls"]
    # from the MNIST archive is expected NOT to match the same archive read with
    # ext=["jpg", "cls"]: compare_pipelines must raise AssertionError.
    global test_batch_size
    sample_count = 1000
    shard_count = 100

    archive = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
    archive_index = generate_temp_index_file(archive)

    with assert_raises(AssertionError):
        for current_shard in range(shard_count):
            # Settings shared by the reference pipeline and the pipeline under test.
            shared_kwargs = dict(
                num_shards=shard_count,
                shard_id=current_shard,
                batch_size=test_batch_size,
                device_id=0,
                num_threads=1,
            )
            expected_pipe = webdataset_raw_pipeline(
                archive, archive_index.name, ["jpg", "cls"], **shared_kwargs
            )
            # Same archive and index; only the letter case of the first extension differs.
            actual_pipe = webdataset_raw_pipeline(
                archive, archive_index.name, ext=["Jpg", "cls"], **shared_kwargs
            )
            compare_pipelines(
                expected_pipe,
                actual_pipe,
                test_batch_size,
                math.ceil(sample_count / shard_count / test_batch_size) * 2,
            )


def test_case_insensitive_container_format():
    # With case_insensitive_extensions=True, reading the archive under
    # db/webdataset/case_insensitive with ext=["jpg", "cls"] must produce the same output
    # as the MNIST reference archive for every shard.
    # NOTE(review): presumably the case_insensitive archive stores the same samples with
    # mixed-case file extensions - confirm against DALI_extra.
    global test_batch_size
    sample_count = 1000
    shard_count = 100

    reference_tar = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
    mixed_case_tar = os.path.join(
        get_dali_extra_path(), "db/webdataset/case_insensitive/devel-0.tar"
    )
    reference_index = generate_temp_index_file(reference_tar)

    for current_shard in range(shard_count):
        # Settings shared by the reference pipeline and the pipeline under test.
        shared_kwargs = dict(
            num_shards=shard_count,
            shard_id=current_shard,
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )
        expected_pipe = webdataset_raw_pipeline(
            reference_tar, reference_index.name, ["jpg", "cls"], **shared_kwargs
        )
        # No index file is passed for the second archive.
        actual_pipe = webdataset_raw_pipeline(
            mixed_case_tar,
            None,
            ext=["jpg", "cls"],
            case_insensitive_extensions=True,
            **shared_kwargs,
        )
        compare_pipelines(
            expected_pipe,
            actual_pipe,
            test_batch_size,
            math.ceil(sample_count / shard_count / test_batch_size) * 2,
        )


def test_case_insensitive_arg_format():
    # With case_insensitive_extensions=True, requesting ext=["Jpg", "cls"] from the MNIST
    # archive must produce the same output as requesting ext=["jpg", "cls"] for every shard.
    global test_batch_size
    sample_count = 1000
    shard_count = 100

    archive = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
    archive_index = generate_temp_index_file(archive)

    for current_shard in range(shard_count):
        # Settings shared by the reference pipeline and the pipeline under test.
        shared_kwargs = dict(
            num_shards=shard_count,
            shard_id=current_shard,
            batch_size=test_batch_size,
            device_id=0,
            num_threads=1,
        )
        expected_pipe = webdataset_raw_pipeline(
            archive, archive_index.name, ["jpg", "cls"], **shared_kwargs
        )
        # Same archive and index; only the letter case of the first extension differs.
        actual_pipe = webdataset_raw_pipeline(
            archive,
            archive_index.name,
            ext=["Jpg", "cls"],
            case_insensitive_extensions=True,
            **shared_kwargs,
        )
        compare_pipelines(
            expected_pipe,
            actual_pipe,
            test_batch_size,
            math.ceil(sample_count / shard_count / test_batch_size) * 2,
        )


def test_index_generation():
global test_batch_size
num_samples = 3000
Expand Down
2 changes: 2 additions & 0 deletions dali/test/python/webdataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def webdataset_raw_pipeline(
paths,
index_paths,
ext,
case_insensitive_extensions=False,
missing_component_behavior="empty",
dtypes=None,
dont_use_mmap=False,
Expand All @@ -44,6 +45,7 @@ def webdataset_raw_pipeline(
paths=paths,
index_paths=index_paths,
ext=ext,
case_insensitive_extensions=case_insensitive_extensions,
missing_component_behavior=missing_component_behavior,
dtypes=dtypes,
dont_use_mmap=dont_use_mmap,
Expand Down

0 comments on commit dadef03

Please sign in to comment.