Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add opt-in support for case insensitive webdataset #5016

Merged
merged 3 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DALI_EXTRA_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
a6765e17a8297ee2ddc98e2174e99d531a44390b
2aa5813f4e9bdb8301312c29f25627b601353f1e
18 changes: 16 additions & 2 deletions dali/operators/reader/loader/webdataset_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,19 @@ inline std::string SupportedTypesListGen() {
return out_str.substr(0, out_str.size() - 2 * (detail::wds::kSupportedTypes.size() > 0));
}

std::string str_tolower(std::string s) {
std::transform(s.begin(), s.end(), s.begin(),
[](unsigned char c){ return std::tolower(c); });
return s;
}

WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
: Loader(spec),
paths_(spec.GetRepeatedArgument<std::string>("paths")),
index_paths_(spec.GetRepeatedArgument<std::string>("index_paths")),
missing_component_behavior_(detail::wds::ParseMissingExtBehavior(
spec.GetArgument<std::string>("missing_component_behavior"))) {
spec.GetArgument<std::string>("missing_component_behavior"))),
case_sensitive_extensions_(spec.GetArgument<bool>("case_sensitive_extensions")) {
DALI_ENFORCE(paths_.size() == index_paths_.size() || index_paths_.size() == 0,
make_string("The number of index files, if any, must match the number of archives ",
"in the dataset"));
Expand All @@ -235,6 +242,9 @@ WebdatasetLoader::WebdatasetLoader(const OpSpec& spec)
std::string ext;
ext_.emplace_back();
while (std::getline(exts_stream, ext, detail::wds::kExtDelim)) {
if (!case_sensitive_extensions_) {
ext = str_tolower(ext);
}
if (!ext_.back().count(ext)) {
ext_.back().insert(ext);
}
Expand Down Expand Up @@ -411,7 +421,11 @@ void WebdatasetLoader::PrepareMetadataImpl() {
for (auto& component : sample.components) {
component.outputs =
detail::wds::VectorRange<size_t>(output_indicies_, output_indicies_.size());
for (auto& output : ext_map[component.ext]) {
auto ext = component.ext;
if (!case_sensitive_extensions_) {
ext = str_tolower(ext);
}
for (auto& output : ext_map[ext]) {
if (!was_output_set[output]) {
DALI_ENFORCE(component.size % dtype_sizes_[output] == 0,
make_string("Error in index file at ", GetSampleSource(new_sample),
Expand Down
1 change: 1 addition & 0 deletions dali/operators/reader/loader/webdataset_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ class DLL_PUBLIC WebdatasetLoader : public Loader<CPUBackend, vector<Tensor<CPUB

bool generate_index_ = true;
std::string GetSampleSource(const detail::wds::SampleDesc& sample);
bool case_sensitive_extensions_ = true;
};

} // namespace dali
Expand Down
9 changes: 9 additions & 0 deletions dali/operators/reader/webdataset_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ with a semicolon (';') and may contain dots.

Example: "left.png;right.jpg")code",
DALI_STRING_VEC)
.AddOptionalArg("case_sensitive_extensions",
R"code(Determines whether the extensions provided via the `ext` should be case sensitive.

Allows mixing case sizes in the `ext` argument as well as in the webdataset container. For example
when turned off: jpg, JPG, jPG should work.

If the extension characters cannot be represented as ASCI the result of turing this option off
is undefined.
)code", true)
.AddOptionalArg("index_paths",
R"code(The list of the index files corresponding to the respective webdataset archives.

Expand Down
151 changes: 142 additions & 9 deletions dali/test/python/reader/test_webdataset_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@


def test_return_empty():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand Down Expand Up @@ -54,7 +53,6 @@ def test_return_empty():


def test_skip_sample():
global test_batch_size
num_samples = 500
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand Down Expand Up @@ -96,7 +94,6 @@ def test_skip_sample():


def test_raise_error_on_missing():
global test_batch_size
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar")
index_file = generate_temp_index_file(tar_file_path)
wds_pipeline = webdataset_raw_pipeline(
Expand All @@ -112,7 +109,6 @@ def test_raise_error_on_missing():


def test_different_components():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/scrambled.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand All @@ -139,7 +135,6 @@ def test_different_components():


def test_dtypes():
global test_batch_size
num_samples = 100
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/dtypes.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand All @@ -163,7 +158,6 @@ def test_dtypes():


def test_wds_sharding():
global test_batch_size
num_samples = 3000
tar_file_paths = [
os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
Expand Down Expand Up @@ -203,7 +197,6 @@ def test_wds_sharding():


def test_sharding():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)
Expand Down Expand Up @@ -240,7 +233,6 @@ def test_sharding():


def test_pax_format():
global test_batch_size
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
pax_tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/pax/devel-0.tar")
Expand Down Expand Up @@ -274,8 +266,149 @@ def test_pax_format():
)


def test_case_sensitive_container_format():
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
case_insensitive_tar_file_path = os.path.join(get_dali_extra_path(),
"db/webdataset/case_insensitive/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
with assert_raises(RuntimeError, glob="Underful sample detected at"):
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
case_insensitive_tar_file_path,
None,
ext=["jpg", "cls"],
missing_component_behavior="error",
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_case_sensitive_arg_format():
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
with assert_raises(RuntimeError, glob="Underful sample detected at"):
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
ext=["Jpg", "cls"],
missing_component_behavior="error",
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_case_insensitive_container_format():
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
case_insensitive_tar_file_path = os.path.join(get_dali_extra_path(),
"db/webdataset/case_insensitive/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
case_insensitive_tar_file_path,
None,
ext=["jpg", "cls"],
case_sensitive_extensions=False,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_case_insensitive_arg_format():
num_samples = 1000
tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar")
index_file = generate_temp_index_file(tar_file_path)

num_shards = 100
for shard_id in range(num_shards):
compare_pipelines(
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
["jpg", "cls"],
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
webdataset_raw_pipeline(
tar_file_path,
index_file.name,
ext=["Jpg", "cls"],
case_sensitive_extensions=False,
num_shards=num_shards,
shard_id=shard_id,
batch_size=test_batch_size,
device_id=0,
num_threads=1,
),
test_batch_size,
math.ceil(num_samples / num_shards / test_batch_size) * 2,
)


def test_index_generation():
global test_batch_size
num_samples = 3000
tar_file_paths = [
os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"),
Expand Down
2 changes: 2 additions & 0 deletions dali/test/python/webdataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def webdataset_raw_pipeline(
paths,
index_paths,
ext,
case_sensitive_extensions=True,
missing_component_behavior="empty",
dtypes=None,
dont_use_mmap=False,
Expand All @@ -44,6 +45,7 @@ def webdataset_raw_pipeline(
paths=paths,
index_paths=index_paths,
ext=ext,
case_sensitive_extensions=case_sensitive_extensions,
missing_component_behavior=missing_component_behavior,
dtypes=dtypes,
dont_use_mmap=dont_use_mmap,
Expand Down