Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataReader any iteration checkpointing #4967

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 3 additions & 11 deletions dali/operators/reader/file_reader_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

namespace dali {

class FileReader : public DataReader<CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true> {
class FileReader : public DataReader<CPUBackend, ImageLabelWrapper> {
public:
explicit FileReader(const OpSpec& spec)
: DataReader<CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true>(spec) {
: DataReader<CPUBackend, ImageLabelWrapper>(spec) {
bool shuffle_after_epoch = spec.GetArgument<bool>("shuffle_after_epoch");
loader_ = InitLoader<FileLabelLoader>(spec, shuffle_after_epoch);
}
Expand All @@ -54,16 +54,8 @@ class FileReader : public DataReader<CPUBackend, ImageLabelWrapper, ImageLabelWr
label_output.mutable_data<int>()[0] = image_label.label;
}

void SaveState(OpCheckpoint &cpt, std::optional<cudaStream_t> stream) override {
cpt.MutableCheckpointState() = loader_->PopStateSnapshot();
}

void RestoreState(const OpCheckpoint &cpt) override {
loader_->RestoreStateFromSnapshot(cpt.CheckpointState<LoaderStateSnapshot>());
}

protected:
USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true);
USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper);
};

} // namespace dali
Expand Down
4 changes: 2 additions & 2 deletions dali/operators/reader/loader/coco_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ struct RLEMask : public UniqueHandle<RLE, RLEMask> {

using RLEMaskPtr = std::shared_ptr<RLEMask>;

class DLL_PUBLIC CocoLoader : public FileLabelLoaderBase<false> {
class DLL_PUBLIC CocoLoader : public FileLabelLoader {
public:
explicit inline CocoLoader(const OpSpec &spec)
: FileLabelLoaderBase<false>(spec, spec.GetArgument<bool>("shuffle_after_epoch"))
: FileLabelLoader(spec, spec.GetArgument<bool>("shuffle_after_epoch"))
, spec_(spec) {
has_preprocessed_annotations_ = HasPreprocessedAnnotations(spec);
DALI_ENFORCE(has_preprocessed_annotations_ || spec.HasArgument("annotations_file"),
Expand Down
12 changes: 3 additions & 9 deletions dali/operators/reader/loader/file_label_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ namespace dali {

using filesystem::dir_sep;

template<bool checkpointing_supported>
void FileLabelLoaderBase<checkpointing_supported>::PrepareEmpty(ImageLabelWrapper &image_label) {
void FileLabelLoader::PrepareEmpty(ImageLabelWrapper &image_label) {
PrepareEmptyTensor(image_label.image);
}

template<bool checkpointing_supported>
void FileLabelLoaderBase<checkpointing_supported>::ReadSample(ImageLabelWrapper &image_label) {
void FileLabelLoader::ReadSample(ImageLabelWrapper &image_label) {
auto image_pair = image_label_pairs_[current_index_++];

// handle wrap-around
Expand Down Expand Up @@ -75,12 +73,8 @@ void FileLabelLoaderBase<checkpointing_supported>::ReadSample(ImageLabelWrapper
image_label.image.SetMeta(meta);
}

template<bool checkpointing_supported>
Index FileLabelLoaderBase<checkpointing_supported>::SizeImpl() {
Index FileLabelLoader::SizeImpl() {
return static_cast<Index>(image_label_pairs_.size());
}

template class FileLabelLoaderBase<false>;
template class FileLabelLoaderBase<true>;

} // namespace dali
72 changes: 34 additions & 38 deletions dali/operators/reader/loader/file_label_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,18 @@ struct ImageLabelWrapper {
int label;
};

template<bool supports_checkpointing>
class DLL_PUBLIC FileLabelLoaderBase : public Loader<CPUBackend, ImageLabelWrapper,
supports_checkpointing> {
struct FileLabelLoaderCheckpoint {
vector<std::pair<string, int>> image_label_pairs_;
Index current_index_;
int current_epoch_;
};

class DLL_PUBLIC FileLabelLoader : public Loader<CPUBackend, ImageLabelWrapper> {
public:
using Base = Loader<CPUBackend, ImageLabelWrapper, supports_checkpointing>;
explicit inline FileLabelLoaderBase(
explicit inline FileLabelLoader(
const OpSpec& spec,
bool shuffle_after_epoch = false)
: Base(spec),
: Loader<CPUBackend, ImageLabelWrapper>(spec),
shuffle_after_epoch_(shuffle_after_epoch),
current_index_(0),
current_epoch_(0) {
Expand Down Expand Up @@ -108,7 +111,7 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader<CPUBackend, ImageLabelWrapp
* DALI instances will do shuffling after each epoch
*/
DALI_ENFORCE(!(shuffle_after_epoch_ && stick_to_shard_),
"shuffle_after_epoch and stick_to_shard cannot be both true");
"shu/fileffle_after_epoch and stick_to_shard cannot be both true");
DALI_ENFORCE(!(shuffle_after_epoch_ && shuffle_),
"shuffle_after_epoch and random_shuffle cannot be both true");
/*
Expand All @@ -127,6 +130,15 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader<CPUBackend, ImageLabelWrapp
void PrepareEmpty(ImageLabelWrapper &tensor) override;
void ReadSample(ImageLabelWrapper &tensor) override;

ImageLabelWrapper CopyTarget(const ImageLabelWrapper &target) override {
Tensor<CPUBackend> tensor;
tensor.Copy(target.image);
return ImageLabelWrapper {
std::move(tensor),
target.label
};
}

protected:
Index SizeImpl() override;

Expand Down Expand Up @@ -186,57 +198,43 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader<CPUBackend, ImageLabelWrapp
std::shuffle(image_label_pairs_.begin(), image_label_pairs_.end(), g);
}

if (checkpointing_ && shuffle_after_epoch_) {
// save initial order
// moving to prevent one copy, as it is restored in the Reset()
backup_image_label_pairs_ = std::move(image_label_pairs_);
}

Reset(true);
}

void Reset(bool wrap_to_shard) override {
if (wrap_to_shard) {
current_index_ = start_index(virtual_shard_id_, num_shards_, SizeImpl());
current_index_ = start_index(shard_id_, num_shards_, SizeImpl());
} else {
current_index_ = 0;
}

current_epoch_++;

if (shuffle_after_epoch_) {
if (checkpointing_) {
// With checkpointing enabled, dataset order must be easy to restore.
// The shuffling is run with different seed every epoch, so this doesn't impact
// the random distribution.
image_label_pairs_ = backup_image_label_pairs_;
}
std::mt19937 g(kDaliDataloaderSeed + current_epoch_);
std::shuffle(image_label_pairs_.begin(), image_label_pairs_.end(), g);
}
}

void RestoreStateImpl(const LoaderStateSnapshot &state) override {
current_epoch_ = state.current_epoch;
std::any SaveStateImpl() override {
return FileLabelLoaderCheckpoint {
image_label_pairs_,
current_index_,
current_epoch_,
};
}

using Base::shard_id_;
using Base::virtual_shard_id_;
using Base::num_shards_;
using Base::stick_to_shard_;
using Base::shuffle_;
using Base::dont_use_mmap_;
using Base::initial_buffer_fill_;
using Base::copy_read_data_;
using Base::read_ahead_;
using Base::checkpointing_;
using Base::PrepareEmptyTensor;
using Base::MoveToNextShard;
using Base::ShouldSkipImage;
void RestoreStateImpl(const std::any &opaque_state) override {
const auto &state = std::any_cast<FileLabelLoaderCheckpoint>(opaque_state);
image_label_pairs_ = state.image_label_pairs_;
current_index_ = state.current_index_;
current_epoch_ = state.current_epoch_;
}

using Loader<CPUBackend, ImageLabelWrapper>::shard_id_;

string file_root_, file_list_;
vector<std::pair<string, int>> image_label_pairs_;
vector<std::pair<string, int>> backup_image_label_pairs_;
vector<string> filters_;

bool has_files_arg_ = false;
Expand All @@ -251,8 +249,6 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader<CPUBackend, ImageLabelWrapp
FileStream::MappingReserver mmap_reserver_;
};

using FileLabelLoader = FileLabelLoaderBase<true>;

} // namespace dali

#endif // DALI_OPERATORS_READER_LOADER_FILE_LABEL_LOADER_H_