Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support checkpointing in FileReader #4954

Merged
merged 15 commits into from
Jul 31, 2023
14 changes: 11 additions & 3 deletions dali/operators/reader/file_reader_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

namespace dali {

class FileReader : public DataReader<CPUBackend, ImageLabelWrapper> {
class FileReader : public DataReader<CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true> {
public:
/**
 * Constructs the reader from the operator spec.
 *
 * Reads the boolean `shuffle_after_epoch` argument from `spec` and creates the
 * `FileLabelLoader` through `InitLoader`, forwarding that flag to it.
 *
 * NOTE(review): the two base-class initializer lines below look like the removed and
 * added sides of one changed line in a diff rendering — confirm which one is meant to stay.
 */
explicit FileReader(const OpSpec& spec)
: DataReader<CPUBackend, ImageLabelWrapper>(spec) {
: DataReader<CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true>(spec) {
bool shuffle_after_epoch = spec.GetArgument<bool>("shuffle_after_epoch");
loader_ = InitLoader<FileLabelLoader>(spec, shuffle_after_epoch);
}
Expand All @@ -54,8 +54,16 @@ class FileReader : public DataReader<CPUBackend, ImageLabelWrapper> {
label_output.mutable_data<int>()[0] = image_label.label;
}

/**
 * Saves the reader state into the operator checkpoint.
 *
 * The value returned by `loader_->PopStateSnapshot()` is assigned to
 * `cpt.MutableCheckpointState()`. The `stream` argument is not used here.
 *
 * NOTE(review): presumably `PopStateSnapshot()` removes the returned snapshot from the
 * loader, so each call consumes one snapshot — confirm against the Loader implementation.
 */
void SaveState(OpCheckpoint &cpt, std::optional<cudaStream_t> stream) override {
cpt.MutableCheckpointState() = loader_->PopStateSnapshot();
}

/**
 * Restores the reader state from the operator checkpoint.
 *
 * The checkpoint state is read as a `LoaderStateSnapshot` and handed to
 * `loader_->RestoreStateFromSnapshot()`.
 */
void RestoreState(const OpCheckpoint &cpt) override {
loader_->RestoreStateFromSnapshot(cpt.CheckpointState<LoaderStateSnapshot>());
}

protected:
USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper);
USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true);
};

} // namespace dali
Expand Down
5 changes: 3 additions & 2 deletions dali/operators/reader/loader/coco_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,11 @@ struct RLEMask : public UniqueHandle<RLE, RLEMask> {

using RLEMaskPtr = std::shared_ptr<RLEMask>;

class DLL_PUBLIC CocoLoader : public FileLabelLoader {
class DLL_PUBLIC CocoLoader : public FileLabelLoaderBase<false> {
public:
explicit inline CocoLoader(const OpSpec &spec)
: FileLabelLoader(spec, spec.GetArgument<bool>("shuffle_after_epoch")), spec_(spec) {
: FileLabelLoaderBase<false>(spec, spec.GetArgument<bool>("shuffle_after_epoch"))
, spec_(spec) {
has_preprocessed_annotations_ = HasPreprocessedAnnotations(spec);
DALI_ENFORCE(has_preprocessed_annotations_ || spec.HasArgument("annotations_file"),
"Either ``annotations_file`` or ``preprocessed_annotations`` must be provided");
Expand Down
13 changes: 10 additions & 3 deletions dali/operators/reader/loader/file_label_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ namespace dali {

using filesystem::dir_sep;

// Prepares an empty sample: calls PrepareEmptyTensor on the `image` member of the wrapper.
// NOTE(review): the non-template signature on the next line and the templated one after it
// look like the removed/added sides of a diff — confirm that only the templated form remains.
void FileLabelLoader::PrepareEmpty(ImageLabelWrapper &image_label) {
template<bool checkpointing_supported>
void FileLabelLoaderBase<checkpointing_supported>::PrepareEmpty(ImageLabelWrapper &image_label) {
PrepareEmptyTensor(image_label.image);
}

void FileLabelLoader::ReadSample(ImageLabelWrapper &image_label) {
template<bool checkpointing_supported>
void FileLabelLoaderBase<checkpointing_supported>::ReadSample(ImageLabelWrapper &image_label) {
auto image_pair = image_label_pairs_[current_index_++];

// handle wrap-around
Expand Down Expand Up @@ -73,7 +75,12 @@ void FileLabelLoader::ReadSample(ImageLabelWrapper &image_label) {
image_label.image.SetMeta(meta);
}

// Returns the number of entries in `image_label_pairs_`, converted to `Index`.
// NOTE(review): the non-template signature on the next line and the templated one after it
// look like the removed/added sides of a diff — confirm that only the templated form remains.
Index FileLabelLoader::SizeImpl() {
template<bool checkpointing_supported>
Index FileLabelLoaderBase<checkpointing_supported>::SizeImpl() {
return static_cast<Index>(image_label_pairs_.size());
}

// Explicit instantiations of both variants. The member functions above are defined in
// this translation unit, so each specialization used elsewhere has to be instantiated here.
template class FileLabelLoaderBase<false>;
template class FileLabelLoaderBase<true>;

} // namespace dali
150 changes: 92 additions & 58 deletions dali/operators/reader/loader/file_label_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,82 +38,85 @@ struct ImageLabelWrapper {
int label;
};

class DLL_PUBLIC FileLabelLoader : public Loader<CPUBackend, ImageLabelWrapper> {
template<bool supports_checkpointing>
class DLL_PUBLIC FileLabelLoaderBase : public Loader<CPUBackend, ImageLabelWrapper,
supports_checkpointing> {
public:
explicit inline FileLabelLoader(
using Base = Loader<CPUBackend, ImageLabelWrapper, supports_checkpointing>;
explicit inline FileLabelLoaderBase(
const OpSpec& spec,
bool shuffle_after_epoch = false)
: Loader<CPUBackend, ImageLabelWrapper>(spec),
: Base(spec),
shuffle_after_epoch_(shuffle_after_epoch),
current_index_(0),
current_epoch_(0) {

vector<string> files;
vector<int> labels;
vector<string> files;
vector<int> labels;

has_files_arg_ = spec.TryGetRepeatedArgument(files, "files");
has_labels_arg_ = spec.TryGetRepeatedArgument(labels, "labels");
has_file_list_arg_ = spec.TryGetArgument(file_list_, "file_list");
has_file_root_arg_ = spec.TryGetArgument(file_root_, "file_root");
bool has_file_filters_arg = spec.TryGetRepeatedArgument(filters_, "file_filters");
has_files_arg_ = spec.TryGetRepeatedArgument(files, "files");
has_labels_arg_ = spec.TryGetRepeatedArgument(labels, "labels");
has_file_list_arg_ = spec.TryGetArgument(file_list_, "file_list");
has_file_root_arg_ = spec.TryGetArgument(file_root_, "file_root");
bool has_file_filters_arg = spec.TryGetRepeatedArgument(filters_, "file_filters");

// TODO(ksztenderski): CocoLoader inherits after FileLabelLoader and it doesn't work with
// GetArgument.
spec.TryGetArgument(case_sensitive_filter_, "case_sensitive_filter");
// TODO(ksztenderski): CocoLoader inherits after FileLabelLoader and it doesn't work with
// GetArgument.
spec.TryGetArgument(case_sensitive_filter_, "case_sensitive_filter");

DALI_ENFORCE(has_file_root_arg_ || has_files_arg_ || has_file_list_arg_,
"``file_root`` argument is required when not using ``files`` or ``file_list``.");
DALI_ENFORCE(has_file_root_arg_ || has_files_arg_ || has_file_list_arg_,
"``file_root`` argument is required when not using ``files`` or ``file_list``.");

DALI_ENFORCE(has_files_arg_ + has_file_list_arg_ <= 1,
"File paths can be provided through ``files`` or ``file_list`` but not both.");
DALI_ENFORCE(has_files_arg_ + has_file_list_arg_ <= 1,
"File paths can be provided through ``files`` or ``file_list`` but not both.");

DALI_ENFORCE(has_files_arg_ || !has_labels_arg_,
"The argument ``labels`` is valid only when file paths "
"are provided as ``files`` argument.");
DALI_ENFORCE(has_files_arg_ || !has_labels_arg_,
"The argument ``labels`` is valid only when file paths "
"are provided as ``files`` argument.");

DALI_ENFORCE(!has_file_filters_arg || filters_.size() > 0,
"``file_filters`` list cannot be empty.");
DALI_ENFORCE(!has_file_filters_arg || filters_.size() > 0,
"``file_filters`` list cannot be empty.");

if (has_file_list_arg_) {
DALI_ENFORCE(!file_list_.empty(), "``file_list`` argument cannot be empty");
if (!has_file_root_arg_) {
auto idx = file_list_.rfind(filesystem::dir_sep);
if (idx != string::npos) {
file_root_ = file_list_.substr(0, idx);
}
if (has_file_list_arg_) {
DALI_ENFORCE(!file_list_.empty(), "``file_list`` argument cannot be empty");
if (!has_file_root_arg_) {
auto idx = file_list_.rfind(filesystem::dir_sep);
if (idx != string::npos) {
file_root_ = file_list_.substr(0, idx);
}
}
}

if (has_files_arg_) {
DALI_ENFORCE(files.size() > 0, "``files`` specified an empty list.");
if (has_labels_arg_) {
DALI_ENFORCE(files.size() == labels.size(), make_string("Provided ", labels.size(),
" labels for ", files.size(), " files."));
if (has_files_arg_) {
DALI_ENFORCE(files.size() > 0, "``files`` specified an empty list.");
if (has_labels_arg_) {
DALI_ENFORCE(files.size() == labels.size(), make_string("Provided ", labels.size(),
" labels for ", files.size(), " files."));

for (int i = 0, n = files.size(); i < n; i++)
image_label_pairs_.emplace_back(std::move(files[i]), labels[i]);
} else {
for (int i = 0, n = files.size(); i < n; i++)
image_label_pairs_.emplace_back(std::move(files[i]), labels[i]);
} else {
for (int i = 0, n = files.size(); i < n; i++)
image_label_pairs_.emplace_back(std::move(files[i]), i);
}
image_label_pairs_.emplace_back(std::move(files[i]), i);
}
}

/*
* These options are mutually exclusive: `shuffle_after_epoch` makes every shard look different
* after each epoch, so combining it with `stick_to_shard` doesn't make any sense.
* Still, when `shuffle_after_epoch` is set, we set `stick_to_shard` internally in the FileLabelLoader
* so that all DALI instances do the shuffling after each epoch.
*/
DALI_ENFORCE(!(shuffle_after_epoch_ && stick_to_shard_),
"shuffle_after_epoch and stick_to_shard cannot be both true");
DALI_ENFORCE(!(shuffle_after_epoch_ && shuffle_),
"shuffle_after_epoch and random_shuffle cannot be both true");
/*
* Imply `stick_to_shard` from `shuffle_after_epoch`
*/
if (shuffle_after_epoch_) {
stick_to_shard_ = true;
}
/*
* These options are mutually exclusive: `shuffle_after_epoch` makes every shard look different
* after each epoch, so combining it with `stick_to_shard` doesn't make any sense.
* Still, when `shuffle_after_epoch` is set, we set `stick_to_shard` internally in the FileLabelLoader
* so that all DALI instances do the shuffling after each epoch.
*/
DALI_ENFORCE(!(shuffle_after_epoch_ && stick_to_shard_),
"shuffle_after_epoch and stick_to_shard cannot be both true");
DALI_ENFORCE(!(shuffle_after_epoch_ && shuffle_),
"shuffle_after_epoch and random_shuffle cannot be both true");
/*
* Imply `stick_to_shard` from `shuffle_after_epoch`
*/
if (shuffle_after_epoch_) {
stick_to_shard_ = true;
}
if (!dont_use_mmap_) {
mmap_reserver_ = FileStream::MappingReserver(
static_cast<unsigned int>(initial_buffer_fill_));
Expand Down Expand Up @@ -182,29 +185,58 @@ class DLL_PUBLIC FileLabelLoader : public Loader<CPUBackend, ImageLabelWrapper>
std::mt19937 g(kDaliDataloaderSeed);
std::shuffle(image_label_pairs_.begin(), image_label_pairs_.end(), g);
}

if (checkpointing_ && shuffle_after_epoch_) {
// save initial order
// moving to prevent one copy, as it is restored in the Reset()
backup_image_label_pairs_ = std::move(image_label_pairs_);
}

Reset(true);
}

/**
 * Starts a new epoch.
 *
 * Sets `current_index_` either to the start index of the shard (when `wrap_to_shard`
 * is true) or to 0, then increments `current_epoch_`. When `shuffle_after_epoch_` is set,
 * `image_label_pairs_` is reshuffled with a generator seeded by
 * `kDaliDataloaderSeed + current_epoch_`; with checkpointing enabled the list is first
 * reset to `backup_image_label_pairs_`, so the order depends only on the epoch number.
 *
 * @param wrap_to_shard if true, start reading at this loader's shard start instead of 0.
 */
void Reset(bool wrap_to_shard) override {
if (wrap_to_shard) {
// NOTE(review): the two assignments below look like the removed and added sides of one
// changed line (shard_id_ -> virtual_shard_id_); as written the second overwrites the
// first — confirm that only the virtual_shard_id_ form is meant to stay.
current_index_ = start_index(shard_id_, num_shards_, SizeImpl());
current_index_ = start_index(virtual_shard_id_, num_shards_, SizeImpl());
} else {
current_index_ = 0;
}

// The epoch counter is bumped before shuffling, so the seed below uses the new epoch number.
current_epoch_++;

if (shuffle_after_epoch_) {
if (checkpointing_) {
// With checkpointing enabled, dataset order must be easy to restore.
// The shuffling is run with different seed every epoch, so this doesn't impact
// the random distribution.
image_label_pairs_ = backup_image_label_pairs_;
}
std::mt19937 g(kDaliDataloaderSeed + current_epoch_);
std::shuffle(image_label_pairs_.begin(), image_label_pairs_.end(), g);
}
}

using Loader<CPUBackend, ImageLabelWrapper>::shard_id_;
using Loader<CPUBackend, ImageLabelWrapper>::num_shards_;
// Loader-specific part of restoring from a snapshot: only `current_epoch_` is taken
// from `state` here.
// NOTE(review): presumably the base Loader restores the remaining state and triggers
// Reset(), which re-derives the sample order from the epoch number — confirm against Loader.
void RestoreStateImpl(const LoaderStateSnapshot &state) override {
current_epoch_ = state.current_epoch;
}

using Base::shard_id_;
using Base::virtual_shard_id_;
using Base::num_shards_;
using Base::stick_to_shard_;
using Base::shuffle_;
using Base::dont_use_mmap_;
using Base::initial_buffer_fill_;
using Base::copy_read_data_;
using Base::read_ahead_;
using Base::checkpointing_;
using Base::PrepareEmptyTensor;
using Base::MoveToNextShard;
using Base::ShouldSkipImage;

string file_root_, file_list_;
vector<std::pair<string, int>> image_label_pairs_;
vector<std::pair<string, int>> backup_image_label_pairs_;
vector<string> filters_;

bool has_files_arg_ = false;
Expand All @@ -219,6 +251,8 @@ class DLL_PUBLIC FileLabelLoader : public Loader<CPUBackend, ImageLabelWrapper>
FileStream::MappingReserver mmap_reserver_;
};

// `FileLabelLoader` names the variant instantiated with `supports_checkpointing = true`.
using FileLabelLoader = FileLabelLoaderBase<true>;

} // namespace dali

#endif // DALI_OPERATORS_READER_LOADER_FILE_LABEL_LOADER_H_
Loading