Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add opt-in support for case insensitive webdataset #5016

Merged
merged 3 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DALI_EXTRA_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
a6765e17a8297ee2ddc98e2174e99d531a44390b
FIXME
18 changes: 16 additions & 2 deletions dali/operators/reader/loader/webdataset_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,19 @@ inline std::string SupportedTypesListGen() {
return out_str.substr(0, out_str.size() - 2 * (detail::wds::kSupportedTypes.size() > 0));
}

std::string str_tolower(std::string s) {
std::transform(s.begin(), s.end(), s.begin(),
[](unsigned char c){ return std::tolower(c); });
return s;
}

WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
: Loader(spec),
paths_(spec.GetRepeatedArgument<std::string>("paths")),
index_paths_(spec.GetRepeatedArgument<std::string>("index_paths")),
missing_component_behavior_(detail::wds::ParseMissingExtBehavior(
spec.GetArgument<std::string>("missing_component_behavior"))) {
spec.GetArgument<std::string>("missing_component_behavior"))),
case_insensitive_extensions_(spec.GetArgument<bool>("case_insensitive_extensions")) {
DALI_ENFORCE(paths_.size() == index_paths_.size() || index_paths_.size() == 0,
make_string("The number of index files, if any, must match the number of archives ",
"in the dataset"));
Expand All @@ -235,6 +242,9 @@ WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
std::string ext;
ext_.emplace_back();
while (std::getline(exts_stream, ext, detail::wds::kExtDelim)) {
if (case_insensitive_extensions_) {
ext = str_tolower(ext);
}
if (!ext_.back().count(ext)) {
ext_.back().insert(ext);
}
Expand Down Expand Up @@ -411,7 +421,11 @@ void WebdatasetLoader::PrepareMetadataImpl() {
for (auto& component : sample.components) {
component.outputs =
detail::wds::VectorRange<size_t>(output_indicies_, output_indicies_.size());
for (auto& output : ext_map[component.ext]) {
auto ext = component.ext;
if (case_insensitive_extensions_) {
ext = str_tolower(ext);
}
for (auto& output : ext_map[ext]) {
if (!was_output_set[output]) {
DALI_ENFORCE(component.size % dtype_sizes_[output] == 0,
make_string("Error in index file at ", GetSampleSource(new_sample),
Expand Down
1 change: 1 addition & 0 deletions dali/operators/reader/loader/webdataset_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class DLL_PUBLIC WebdatasetLoader : public Loader<CPUBackend, vector<Tensor<CPUB

bool generate_index_ = true;
std::string GetSampleSource(const detail::wds::SampleDesc& sample);
bool case_insensitive_extensions_ = false;
};

} // namespace dali
Expand Down
8 changes: 8 additions & 0 deletions dali/operators/reader/webdataset_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ with a semicolon (';') and may contain dots.

Example: "left.png;right.jpg")code",
DALI_STRING_VEC)
.AddOptionalArg("case_insensitive_extensions",
R"code(Determines whether the extensions provided via the `ext` should be case insensitive.

Allows mixing case sizes in the `ext` argument as well as in the webdataset container. For example
when turned on: jpg, JPG, jPG should work.

If the extension characters cannot be represented as ASCI the result of this option is undefined.
stiepan marked this conversation as resolved.
Show resolved Hide resolved
)code", false)
stiepan marked this conversation as resolved.
Show resolved Hide resolved
.AddOptionalArg("index_paths",
R"code(The list of the index files corresponding to the respective webdataset archives.

Expand Down
144 changes: 144 additions & 0 deletions dali/test/python/reader/test_webdataset_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,150 @@ def test_pax_format():
)


def test_case_sensitive_container_format():
global test_batch_size
stiepan marked this conversation as resolved.
Show resolved Hide resolved
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
case_insensitive_tar_file_path = os.path.join(get_dali_extra_path(),
"db/webdataset/case_insensitive/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
with assert_raises(AssertionError):
stiepan marked this conversation as resolved.
Show resolved Hide resolved
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
case_insensitive_tar_file_path,
None,
ext=["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_case_sensitive_arg_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
with assert_raises(AssertionError):
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
ext=["Jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_case_insensitive_container_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
case_insensitive_tar_file_path = os.path.join(get_dali_extra_path(),
"db/webdataset/case_insensitive/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
case_insensitive_tar_file_path,
None,
ext=["jpg", "cls"],
case_insensitive_extensions=True,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_case_insensitive_arg_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
ext=["Jpg", "cls"],
case_insensitive_extensions=True,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_index_generation():
global test_batch_size
num_samples = 3000
Expand Down
2 changes: 2 additions & 0 deletions dali/test/python/webdataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def webdataset_raw_pipeline(
paths,
index_paths,
ext,
case_insensitive_extensions=False,
missing_component_behavior="empty",
dtypes=None,
dont_use_mmap=False,
Expand All @@ -44,6 +45,7 @@ def webdataset_raw_pipeline(
paths=paths,
index_paths=index_paths,
ext=ext,
case_insensitive_extensions=case_insensitive_extensions,
missing_component_behavior=missing_component_behavior,
dtypes=dtypes,
dont_use_mmap=dont_use_mmap,
Expand Down