diff --git a/dali/operators/reader/file_reader_op.h b/dali/operators/reader/file_reader_op.h index 2491c32710..91a37294b7 100644 --- a/dali/operators/reader/file_reader_op.h +++ b/dali/operators/reader/file_reader_op.h @@ -24,10 +24,10 @@ namespace dali { -class FileReader : public DataReader { +class FileReader : public DataReader { public: explicit FileReader(const OpSpec& spec) - : DataReader(spec) { + : DataReader(spec) { bool shuffle_after_epoch = spec.GetArgument("shuffle_after_epoch"); loader_ = InitLoader(spec, shuffle_after_epoch); } @@ -54,16 +54,8 @@ class FileReader : public DataReader()[0] = image_label.label; } - void SaveState(OpCheckpoint &cpt, std::optional stream) override { - cpt.MutableCheckpointState() = loader_->PopStateSnapshot(); - } - - void RestoreState(const OpCheckpoint &cpt) override { - loader_->RestoreStateFromSnapshot(cpt.CheckpointState()); - } - protected: - USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true); + USE_READER_OPERATOR_MEMBERS(CPUBackend, ImageLabelWrapper); }; } // namespace dali diff --git a/dali/operators/reader/loader/coco_loader.h b/dali/operators/reader/loader/coco_loader.h index b65ee13bce..bcbfe92298 100644 --- a/dali/operators/reader/loader/coco_loader.h +++ b/dali/operators/reader/loader/coco_loader.h @@ -97,10 +97,10 @@ struct RLEMask : public UniqueHandle { using RLEMaskPtr = std::shared_ptr; -class DLL_PUBLIC CocoLoader : public FileLabelLoaderBase { +class DLL_PUBLIC CocoLoader : public FileLabelLoader { public: explicit inline CocoLoader(const OpSpec &spec) - : FileLabelLoaderBase(spec, spec.GetArgument("shuffle_after_epoch")) + : FileLabelLoader(spec, spec.GetArgument("shuffle_after_epoch")) , spec_(spec) { has_preprocessed_annotations_ = HasPreprocessedAnnotations(spec); DALI_ENFORCE(has_preprocessed_annotations_ || spec.HasArgument("annotations_file"), diff --git a/dali/operators/reader/loader/file_label_loader.cc b/dali/operators/reader/loader/file_label_loader.cc index 46c67c4ada..8356d5582b 100644 --- a/dali/operators/reader/loader/file_label_loader.cc +++ b/dali/operators/reader/loader/file_label_loader.cc @@ -23,13 +23,11 @@ namespace dali { using filesystem::dir_sep; -template -void FileLabelLoaderBase::PrepareEmpty(ImageLabelWrapper &image_label) { +void FileLabelLoader::PrepareEmpty(ImageLabelWrapper &image_label) { PrepareEmptyTensor(image_label.image); } -template -void FileLabelLoaderBase::ReadSample(ImageLabelWrapper &image_label) { +void FileLabelLoader::ReadSample(ImageLabelWrapper &image_label) { auto image_pair = image_label_pairs_[current_index_++]; // handle wrap-around @@ -75,12 +73,8 @@ void FileLabelLoaderBase::ReadSample(ImageLabelWrapper image_label.image.SetMeta(meta); } -template -Index FileLabelLoaderBase::SizeImpl() { +Index FileLabelLoader::SizeImpl() { return static_cast(image_label_pairs_.size()); } -template class FileLabelLoaderBase; -template class FileLabelLoaderBase; - } // namespace dali diff --git a/dali/operators/reader/loader/file_label_loader.h b/dali/operators/reader/loader/file_label_loader.h index 51a90ea369..a53fcc5e58 100755 --- a/dali/operators/reader/loader/file_label_loader.h +++ b/dali/operators/reader/loader/file_label_loader.h @@ -38,15 +38,18 @@ struct ImageLabelWrapper { int label; }; -template -class DLL_PUBLIC FileLabelLoaderBase : public Loader { +struct FileLabelLoaderCheckpoint { + vector> image_label_pairs_; + Index current_index_; + int current_epoch_; +}; + +class DLL_PUBLIC FileLabelLoader : public Loader { public: - using Base = Loader; - explicit inline FileLabelLoaderBase( + explicit inline FileLabelLoader( const OpSpec& spec, bool shuffle_after_epoch = false) - : Base(spec), + : Loader(spec), shuffle_after_epoch_(shuffle_after_epoch), current_index_(0), current_epoch_(0) { @@ -108,7 +111,7 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader tensor; + tensor.Copy(target.image); + return ImageLabelWrapper { + std::move(tensor), + target.label + }; + } + protected: Index SizeImpl() override; @@ -186,18 +198,12 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader(opaque_state); + image_label_pairs_ = state.image_label_pairs_; + current_index_ = state.current_index_; + current_epoch_ = state.current_epoch_; + } + + using Loader::shard_id_; string file_root_, file_list_; vector> image_label_pairs_; - vector> backup_image_label_pairs_; vector filters_; bool has_files_arg_ = false; @@ -251,8 +249,6 @@ class DLL_PUBLIC FileLabelLoaderBase : public Loader; - } // namespace dali #endif // DALI_OPERATORS_READER_LOADER_FILE_LABEL_LOADER_H_ diff --git a/dali/operators/reader/loader/loader.h b/dali/operators/reader/loader/loader.h index 868097a301..957a1aa8ba 100644 --- a/dali/operators/reader/loader/loader.h +++ b/dali/operators/reader/loader/loader.h @@ -26,6 +26,7 @@ #include #include #include +#include #include "dali/core/nvtx.h" #include "dali/core/common.h" @@ -43,12 +44,25 @@ DLL_PUBLIC size_t start_index(const size_t shard_id, DLL_PUBLIC Index num_samples(const size_t shard_num, const size_t size); -/** - * @brief Structure describing Loader base state, at the begining of an epoch. -*/ -struct LoaderStateSnapshot { - std::default_random_engine rng; - int current_epoch; +struct ShardBoundaries { + Index start; + Index end; +}; + +template +struct LoaderCheckpoint { + using LoadTargetSharedPtr = std::shared_ptr; + + std::default_random_engine e_; + std::vector sample_buffer_; + bool initial_buffer_filled_ = false; + Index read_sample_counter_; + Index returned_sample_counter_; + int virtual_shard_id_; + LoadTargetSharedPtr last_sample_ptr_tmp; + std::deque shards_; + + std::any subclass_state_; }; /** @@ -59,8 +73,7 @@ struct LoaderStateSnapshot { * @tparam LoadTarget Type into which samples are loaded. * @tparam supports_checkpointing A marker for checkpointing support. */ -template +template class Loader { public: using LoadTargetUniquePtr = std::unique_ptr; @@ -71,9 +84,6 @@ class Loader { initial_empty_size_(2 * options.GetArgument("prefetch_queue_depth") * options.GetArgument("max_batch_size")), tensor_init_bytes_(options.GetArgument("tensor_init_bytes")), - state_queue_front_(0), - state_queue_back_(0), - checkpoint_epoch_(0), seed_(options.GetArgument("seed")), shard_id_(options.GetArgument("shard_id")), num_shards_(options.GetArgument("num_shards")), @@ -95,31 +105,6 @@ class Loader { std::seed_seq seq({seed_}); e_ = std::default_random_engine(seq); virtual_shard_id_ = shard_id_; - - // TODO(mstaniewski): add a proper internal argument in schema - if (!options.TryGetArgument(checkpointing_, "checkpointing")) { - checkpointing_ = false; - } - - if (checkpointing_) { - DALI_ENFORCE(supports_checkpointing, "Checkpointing is disabled for this loader. "); - - // TODO(mstaniewski): support pad_last_batch=false - DALI_ENFORCE(pad_last_batch_, - "Currently, checkpointing is only supported with pad_last_batch=true"); - - /* - * A checkpoint is created every time the prefetching thread starts working - * on a new epoch. Therefore, we are guaranteed, there will be at most - * prefetch_queue_depth checkpoints waiting in the queue at a time. - * - * The +1 is added, because there could be a situation, where a batch is - * collected from the prefetch_queue (possibly leading to creation of another checkpoint), - * but the checkpoint from the corresponding epoch is not yet collected. - */ - state_queue_.resize(options.GetArgument("prefetch_queue_depth") + 1); - PushStateSnapshot(); - } } virtual ~Loader() { @@ -155,47 +140,6 @@ class Loader { DALI_ERROR("Please overload PrepareEmpty for custom LoadTarget type other than Tensor"); } - /** - * @brief Called when loader is moving to the next shard, - * to create a new snapshot and store it in the inner queue. - */ - void PushStateSnapshot() { - std::lock_guard lock(state_queue_mutex_); - state_queue_[state_queue_back_].rng = e_; - state_queue_[state_queue_back_].current_epoch = checkpoint_epoch_++; - state_queue_back_ = (state_queue_back_ + 1) % state_queue_.size(); - } - - /** - * @brief Collects a state snapshot from the inner queue. - */ - LoaderStateSnapshot PopStateSnapshot() { - DALI_ENFORCE(checkpointing_, "PopStateSnapshot called, but checkpointing is not enabled. "); - std::lock_guard lock(state_queue_mutex_); - auto result = state_queue_[state_queue_front_]; - state_queue_front_ = (state_queue_front_ + 1) % state_queue_.size(); - return result; - } - - /** - * @brief Restores the loader's state from a snapshot. - */ - void RestoreStateFromSnapshot(const LoaderStateSnapshot &state) { - e_ = state.rng; - checkpoint_epoch_ = state.current_epoch; - if (!stick_to_shard_) - virtual_shard_id_ = (shard_id_ + state.current_epoch) % num_shards_; - - RestoreStateImpl(state); - - // Re-run reset - Reset(true); - - // Reset checkpointing - state_queue_front_ = state_queue_back_; - PushStateSnapshot(); - } - // Get a random read sample LoadTargetSharedPtr ReadOne(bool is_new_batch) { PrepareMetadata(); @@ -244,11 +188,6 @@ class Loader { // remove shard that was fully consumed shards_.pop_front(); returned_sample_counter_ = 0; - - if (checkpointing_) { - // Create a checkpoint before processing the new shard - PushStateSnapshot(); - } } // choose the random index @@ -336,6 +275,57 @@ class Loader { return stick_to_shard_; } + using Checkpoint = LoaderCheckpoint; + + Checkpoint SaveState() { + std::vector copied_buffer_; + for (const auto &sample : sample_buffer_) + copied_buffer_.push_back(std::make_shared(CopyTarget(*sample))); + + auto copied_last_sample_= std::make_shared(CopyTarget(*last_sample_ptr_tmp)); + + return Checkpoint { + e_, + std::move(copied_buffer_), + initial_buffer_filled_, + read_sample_counter_, + returned_sample_counter_, + virtual_shard_id_, + copied_last_sample_, + shards_, + SaveStateImpl(), + }; + } + + void RestoreState(const Checkpoint &state) { + e_ = state.e_; + // sample_buffer_ = std::move(state.sample_buffer_); + initial_buffer_filled_ = state.initial_buffer_filled_; + read_sample_counter_ = state.read_sample_counter_; + returned_sample_counter_ = state.returned_sample_counter_; + virtual_shard_id_ = state.virtual_shard_id_; + last_sample_ptr_tmp = state.last_sample_ptr_tmp; + shards_ = state.shards_; + + for (const auto &sample : state.sample_buffer_) + sample_buffer_.push_back(std::make_unique(CopyTarget(*sample))); + + if (initial_buffer_filled_) { + std::lock_guard lock(empty_tensors_mutex_); + for (int i = 0; i < initial_empty_size_; ++i) { + auto tensor_ptr = LoadTargetUniquePtr(new LoadTarget()); + PrepareEmpty(*tensor_ptr); + empty_tensors_.push_back(std::move(tensor_ptr)); + } + } + + RestoreStateImpl(state.subclass_state_); + } + + virtual LoadTarget CopyTarget(const LoadTarget &target) { + DALI_FAIL("This loader does not support checkpointing. "); + } + protected: virtual Index SizeImpl() = 0; @@ -349,9 +339,6 @@ class Loader { // Reset reader to the first sample virtual void Reset(bool wrap_to_shard) = 0; - // Overloadable method to handle restoring state in subclasses - virtual void RestoreStateImpl(const LoaderStateSnapshot &state) {} - // Check if given reader moved to the next shard virtual inline bool IsNextShard(Index current_index) { return current_index >= Size() || @@ -398,6 +385,9 @@ class Loader { return cache_ && cache_->IsCached(key); } + virtual std::any SaveStateImpl() { return {}; } + virtual void RestoreStateImpl(const std::any &opaque_state) {} + std::vector sample_buffer_; std::vector empty_tensors_; @@ -410,14 +400,6 @@ class Loader { const int tensor_init_bytes_; bool initial_buffer_filled_ = false; - // when enabled, the loader creates a checkpoint at the start of every epoch. - bool checkpointing_; - std::vector state_queue_; - Index state_queue_front_; - Index state_queue_back_; - int checkpoint_epoch_; - std::mutex state_queue_mutex_; - // rng std::default_random_engine e_; Index seed_; @@ -469,11 +451,6 @@ class Loader { // Keeps pointer to the last returned sample just in case it needs to be cloned LoadTargetSharedPtr last_sample_ptr_tmp; - struct ShardBoundaries { - Index start; - Index end; - }; - std::deque shards_; }; diff --git a/dali/operators/reader/reader_op.h b/dali/operators/reader/reader_op.h index 2fc905c3c2..db154e6a7b 100644 --- a/dali/operators/reader/reader_op.h +++ b/dali/operators/reader/reader_op.h @@ -31,6 +31,36 @@ namespace dali { +template +struct DataReaderCheckpoint { + using LoadTargetPtr = std::shared_ptr; + using BatchQueueElement = std::vector; + + /* DataReader state */ + std::vector prefetched_batch_queue_; + int curr_batch_consumer_; + int curr_batch_producer_; + bool consumer_cycle_; + bool producer_cycle_; + + LoaderCheckpoint loader_state_; + + DataReaderCheckpoint( + std::vector &&batch_queue_copy, + int curr_batch_consumer, + int curr_batch_producer, + bool consumer_cycle, + bool producer_cycle, + LoaderCheckpoint &&loader_state) + : prefetched_batch_queue_(std::move(batch_queue_copy)) + , curr_batch_consumer_(curr_batch_consumer) + , curr_batch_producer_(curr_batch_producer) + , consumer_cycle_(consumer_cycle) + , producer_cycle_(producer_cycle) + , loader_state_(std::move(loader_state)) {} +}; + /** * @brief BaseClass for operators that perform prefetching work * @@ -49,7 +79,7 @@ namespace dali { * @tparam supports_checkpointing A marker for checkpointing support. */ template + typename ParseTarget = LoadTarget> class DataReader : public Operator { public: using LoadTargetPtr = std::shared_ptr; @@ -64,8 +94,7 @@ class DataReader : public Operator { curr_batch_producer_(0), consumer_cycle_(false), producer_cycle_(false), - device_id_(-1), - samples_processed_(0) { + device_id_(-1) { if (std::is_same::value) { device_id_ = spec.GetArgument("device_id"); } @@ -224,6 +253,49 @@ class DataReader : public Operator { prefetched_batch_queue_[curr_batch_consumer_].clear(); } + using Checkpoint = DataReaderCheckpoint; + + void SaveState(OpCheckpoint &cpt, std::optional stream) override { + // Make sure the prefetch thread is running + StartPrefetchThread(); + + { + // Wait until the prefetch queue is full. + std::unique_lock prefetch_lock(prefetch_access_mutex_); + consumer_.wait(prefetch_lock, [this]() { return finished_ || IsPrefetchQueueFull(); }); + if (prefetch_error_) std::rethrow_exception(prefetch_error_); + } + + std::vector batch_queue_copy; + for (const auto &batch : prefetched_batch_queue_) { + batch_queue_copy.emplace_back(); + for (const auto &sample : batch) + batch_queue_copy.back().push_back( + std::make_shared(loader_->CopyTarget(*sample))); + } + + // Now, the prefetching thread is idle, because there is no space in queue. + // We can safely access prefetching data. + cpt.MutableCheckpointState() = (Checkpoint { + std::move(batch_queue_copy), + curr_batch_consumer_, + curr_batch_producer_, + consumer_cycle_, + producer_cycle_, + loader_->SaveState(), + }); + } + + void RestoreState(const OpCheckpoint &cpt) override { + const auto &state = cpt.CheckpointState(); + prefetched_batch_queue_ = state.prefetched_batch_queue_; + curr_batch_consumer_ = state.curr_batch_consumer_; + curr_batch_producer_ = state.curr_batch_producer_; + consumer_cycle_ = state.consumer_cycle_; + producer_cycle_ = state.producer_cycle_; + loader_->RestoreState(state.loader_state_); + } + protected: void ParseIfNeeded(const Tensor& tensor, SampleWorkspace* ws) { using OutputCache = std::unordered_map>>; @@ -352,14 +424,11 @@ class DataReader : public Operator { bool producer_cycle_; int device_id_; - // keep track of how many samples have been processed over all threads. - std::atomic samples_processed_; - // stores any catched exceptions in the prefetch worker std::exception_ptr prefetch_error_; // Loader - std::unique_ptr> loader_; + std::unique_ptr> loader_; // Parser std::unique_ptr> parser_; @@ -375,19 +444,10 @@ class DataReader : public Operator { using DataReader::parser_; \ using DataReader::prefetched_batch_queue_; -#define USE_READER_OPERATOR_MEMBERS_3(Backend, LoadTarget, ParseTarget, supports_checkpointing) \ - using DataReader::loader_; \ - using DataReader::parser_; \ - using DataReader::prefetched_batch_queue_; - -#define GET_MACRO3(_1, _2, _3, NAME, ...) NAME - #define USE_READER_OPERATOR_MEMBERS(Backend, ...) \ - GET_MACRO3(__VA_ARGS__, \ - USE_READER_OPERATOR_MEMBERS_3, \ - USE_READER_OPERATOR_MEMBERS_2, \ - USE_READER_OPERATOR_MEMBERS_1)(Backend, __VA_ARGS__) + GET_MACRO(__VA_ARGS__, \ + USE_READER_OPERATOR_MEMBERS_2, \ + USE_READER_OPERATOR_MEMBERS_1)(Backend, __VA_ARGS__) }; // namespace dali diff --git a/dali/operators/reader/reader_op_test.cc b/dali/operators/reader/reader_op_test.cc index 8067687f3e..a5bb2a2ab0 100644 --- a/dali/operators/reader/reader_op_test.cc +++ b/dali/operators/reader/reader_op_test.cc @@ -512,31 +512,24 @@ class FileReaderTest : public DALITest { pipe.Build(outputs); } - std::vector RunEpoch(Pipeline &pipe, int batch_size, - int num_shards = 1, bool stick_to_shard = false) { + std::vector RunIteration(Pipeline &pipe) { std::vector result; - int samples_per_shard = (filepaths_.size() + num_shards - 1) / num_shards; - int batches_per_shard = (samples_per_shard + batch_size - 1) / batch_size; - for (int it = 0; it < batches_per_shard; it++) { - pipe.RunCPU(); - pipe.RunGPU(); - pipe.Outputs(&ws_); - - auto shape = ws_.Output(0).AsTensor().shape(); - for (int nr = 0; nr < shape[0]; nr++) - result.push_back(ws_.Output(0).tensor(0)[nr]); - } + pipe.RunCPU(); + pipe.RunGPU(); + pipe.Outputs(&ws_); + + auto shape = ws_.Output(0).AsTensor().shape(); + for (int nr = 0; nr < shape[0]; nr++) + result.push_back(ws_.Output(0).tensor(0)[nr]); + return result; } - std::pair, OpCheckpoint> CheckpointEpoch(Pipeline &pipe, int batch_size, - int epoch_nr, int num_shards = 1, - bool stick_to_shard = false) { + OpCheckpoint MakeCheckpoint(Pipeline &pipe) { auto node = pipe.GetOperatorNode("file_reader"); OpCheckpoint cpt(node->spec); node->op->SaveState(cpt, std::nullopt); - EXPECT_EQ(cpt.CheckpointState().current_epoch, epoch_nr); - return {RunEpoch(pipe, batch_size, num_shards, stick_to_shard), cpt}; + return cpt; } void RestoreCheckpointedState(Pipeline &pipe, const OpCheckpoint &cpt) { @@ -550,98 +543,90 @@ class FileReaderTest : public DALITest { TEST_F(FileReaderTest, SimpleCheckpointing) { constexpr int batch_size = 3; - constexpr int epochs = 4; + constexpr int iterations = 25; auto prepare_pipeline = [this](Pipeline &pipe) { pipe.AddOperator( MakeOpSpec() - .AddArg("prefetch_queue_depth", 20) + .AddArg("prefetch_queue_depth", 5) .AddArg("initial_fill", 3), "file_reader"); BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i]); - } + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); + + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); + + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } TEST_F(FileReaderTest, CheckpointingRandomShuffle) { constexpr int batch_size = 7; - constexpr int epochs = 8; + constexpr int iterations = 15; auto prepare_pipeline = [this](Pipeline &pipe) { pipe.AddOperator( MakeOpSpec() .AddArg("random_shuffle", true) + .AddArg("prefetch_queue_depth", 1) .AddArg("initial_fill", 3), "file_reader"); BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i]); - } + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); + + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); + + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } TEST_F(FileReaderTest, CheckpointingShuffleAfterEpoch) { constexpr int batch_size = 5; - constexpr int epochs = 8; + constexpr int iterations = 50; auto prepare_pipeline = [this](Pipeline &pipe) { pipe.AddOperator( MakeOpSpec() .AddArg("shuffle_after_epoch", true) - .AddArg("prefetch_queue_depth", 9) + .AddArg("prefetch_queue_depth", 3) .AddArg("initial_fill", 10), "file_reader"); BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i]); - } + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); + + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); + + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } TEST_F(FileReaderTest, CheckpointingMultipleShards) { constexpr int batch_size = 5; - constexpr int epochs = 8; + constexpr int iterations = 50; constexpr int num_shards = 5; auto prepare_pipeline = [this, num_shards](Pipeline &pipe) { @@ -654,27 +639,24 @@ TEST_F(FileReaderTest, CheckpointingMultipleShards) { BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i, num_shards); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size, num_shards), results[i]); - } + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); + + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); + + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } -TEST_F(FileReaderTest, CheckpointingStickStoShard) { +TEST_F(FileReaderTest, CheckpointingSticktoShard) { constexpr int batch_size = 3; - constexpr int epochs = 8; + constexpr int iterations = 50; constexpr int num_shards = 3; auto prepare_pipeline = [this, num_shards](Pipeline &pipe) { @@ -683,71 +665,24 @@ TEST_F(FileReaderTest, CheckpointingStickStoShard) { .AddArg("stick_to_shard", true) .AddArg("shard_id", 1) .AddArg("num_shards", num_shards) - .AddArg("initial_fill", 5), "file_reader"); - BuildPipeline(pipe); - }; - - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i, num_shards, true); - results.push_back(res); - checkpoints.push_back(cpt); - } - - for (int i = 0; i < epochs; i++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, checkpoints[i]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size, num_shards, true), results[i]); - } -} - - -TEST_F(FileReaderTest, CheckpointingResumeThenSave) { - constexpr int batch_size = 3; - constexpr int epochs = 4; - - auto prepare_pipeline = [this](Pipeline &pipe) { - pipe.AddOperator( - MakeOpSpec() - .AddArg("shuffle_after_epoch", true) .AddArg("prefetch_queue_depth", 2) - .AddArg("initial_fill", 3), "file_reader"); + .AddArg("initial_fill", 5), "file_reader"); BuildPipeline(pipe); }; - Pipeline pipe(batch_size, 1, 0); - prepare_pipeline(pipe); - std::vector> results; - std::vector checkpoints; - for (int i = 0; i < epochs; i++) { - auto [res, cpt] = CheckpointEpoch(pipe, batch_size, i); - results.push_back(res); - checkpoints.push_back(cpt); - } + Pipeline original_pipe(batch_size, 1, 0); + Pipeline restored_pipe(batch_size, 1, 0); + prepare_pipeline(original_pipe); + prepare_pipeline(restored_pipe); - for (int i = 0; i < epochs; i++) { - Pipeline intermediate_pipe(batch_size, 1, 0); - prepare_pipeline(intermediate_pipe); - RestoreCheckpointedState(intermediate_pipe, checkpoints[i]); + for (int i = 0; i < iterations; i++) + RunIteration(original_pipe); - std::vector cpts_after_resume; - for (int j = 0; i + j < epochs; j++) { - auto [res, cpt] = CheckpointEpoch(intermediate_pipe, batch_size, i + j); - EXPECT_EQ(res, results[i + j]); - cpts_after_resume.push_back(cpt); - } + auto cpt = MakeCheckpoint(original_pipe); + RestoreCheckpointedState(restored_pipe, cpt); - for (int j = 0; i + j < epochs; j++) { - Pipeline fresh_pipe(batch_size, 1, 0); - prepare_pipeline(fresh_pipe); - RestoreCheckpointedState(fresh_pipe, cpts_after_resume[j]); - EXPECT_EQ(RunEpoch(fresh_pipe, batch_size), results[i + j]); - } - } + for (int i = 0; i < iterations; i++) + EXPECT_EQ(RunIteration(original_pipe), RunIteration(restored_pipe)); } }; // namespace dali