Support checkpointing in Numpy reader (#5198)
This PR adds checkpointing support to FileLoader and the fn.readers.numpy reader.
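A minimal usage sketch of what this enables (not part of the commit; the `enable_checkpointing` and `checkpoint` names, and the data path, are assumptions about the Python-level checkpointing API and may differ from the released version):

```python
# Hypothetical sketch: resuming an fn.readers.numpy pipeline from a checkpoint.
# Argument names below (enable_checkpointing, checkpoint) are assumptions.
import nvidia.dali.fn as fn
from nvidia.dali import pipeline_def

@pipeline_def(batch_size=4, num_threads=2, device_id=0, enable_checkpointing=True)
def numpy_pipeline(file_root):
    # With this PR, the reader's position (shard, epoch, shuffle order) can be saved.
    return fn.readers.numpy(device="cpu", file_root=file_root, random_shuffle=True)

pipe = numpy_pipeline("/path/to/npy_dir")  # hypothetical data location
pipe.build()
for _ in range(100):
    pipe.run()

cpt = pipe.checkpoint()                        # serialize reader + RNG state
resumed = numpy_pipeline("/path/to/npy_dir", checkpoint=cpt)
resumed.build()                                # continues from iteration 100
```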

Signed-off-by: Szymon Karpiński <skarpinski@nvidia.com>
szkarpinski authored and stiepan committed Dec 11, 2023
1 parent 9efebf1 commit f8a7cc7
Showing 10 changed files with 151 additions and 42 deletions.
4 changes: 2 additions & 2 deletions dali/operators/reader/fits_reader_gpu_op.h
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -35,7 +35,7 @@ class FitsReaderGPU : public FitsReader<GPUBackend, FitsFileWrapperGPU> {
using Operator<GPUBackend>::RunImpl;

private:
USE_READER_OPERATOR_MEMBERS(GPUBackend, FitsFileWrapperGPU);
USE_READER_OPERATOR_MEMBERS(GPUBackend, FitsFileWrapperGPU, FitsFileWrapperGPU, true);
};

} // namespace dali
25 changes: 17 additions & 8 deletions dali/operators/reader/fits_reader_op.h
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -27,22 +27,31 @@
namespace dali {

template <typename Backend, typename Target>
class FitsReader : public DataReader<Backend, Target> {
class FitsReader : public DataReader<Backend, Target, Target, true> {
public:
explicit FitsReader(const OpSpec& spec) : DataReader<Backend, Target>(spec) {}
explicit FitsReader(const OpSpec& spec) : DataReader<Backend, Target, Target, true>(spec) {}

bool CanInferOutputs() const override {
return true;
}

USE_READER_OPERATOR_MEMBERS(Backend, Target);
using DataReader<Backend, Target>::GetCurrBatchSize;
using DataReader<Backend, Target>::GetSample;
// TODO(skarpinski) Debug fits reader and add checkpointing support
void SaveState(OpCheckpoint &cpt, AccessOrder order) override {
DALI_FAIL("Fits reader does not support checkpointing.");
}

void RestoreState(const OpCheckpoint &cpt) override {
DALI_FAIL("Fits reader does not support checkpointing.");
}

USE_READER_OPERATOR_MEMBERS(Backend, Target, Target, true);
using DataReader<Backend, Target, Target, true>::GetCurrBatchSize;
using DataReader<Backend, Target, Target, true>::GetSample;
using Operator<Backend>::spec_;

bool SetupImpl(std::vector<OutputDesc>& output_desc, const Workspace& ws) override {
// If necessary start prefetching thread and wait for a consumable batch
DataReader<Backend, Target>::SetupImpl(output_desc, ws);
DataReader<Backend, Target, Target, true>::SetupImpl(output_desc, ws);

int num_outputs = ws.NumOutput();
int num_samples = GetCurrBatchSize(); // samples here are synonymous with files
@@ -103,7 +112,7 @@ class FitsReaderCPU : public FitsReader<CPUBackend, FitsFileWrapper> {
using Operator<CPUBackend>::RunImpl;

private:
USE_READER_OPERATOR_MEMBERS(CPUBackend, FitsFileWrapper);
USE_READER_OPERATOR_MEMBERS(CPUBackend, FitsFileWrapper, FitsFileWrapper, true);
};

} // namespace dali
48 changes: 32 additions & 16 deletions dali/operators/reader/loader/file_loader.h
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2021, 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -41,10 +41,10 @@ struct FileWrapper {

template <typename Backend = CPUBackend, typename Target = FileWrapper,
typename InputStream = FileStream>
class FileLoader : public Loader<Backend, Target> {
class FileLoader : public Loader<Backend, Target, true> {
public:
FileLoader(const OpSpec &spec, bool shuffle_after_epoch)
: Loader<Backend, Target>(spec),
: Loader<Backend, Target, true>(spec),
file_filter_(spec.GetArgument<string>("file_filter")),
shuffle_after_epoch_(shuffle_after_epoch),
current_index_(0),
@@ -125,6 +125,9 @@ class FileLoader : public Loader<Backend, Target> {
}
DALI_ENFORCE(SizeImpl() > 0, "No files found.");

if (IsCheckpointingEnabled()) {
backup_files_ = files_;
}
if (shuffle_) {
// seeded with hardcoded value to get
// the same sequence on every shard
@@ -136,34 +139,47 @@

void Reset(bool wrap_to_shard) override {
if (wrap_to_shard) {
current_index_ = start_index(shard_id_, num_shards_, SizeImpl());
current_index_ = start_index(virtual_shard_id_, num_shards_, SizeImpl());
} else {
current_index_ = 0;
}

current_epoch_++;

if (shuffle_after_epoch_) {
if (IsCheckpointingEnabled()) {
// With checkpointing enabled, the dataset order must be easy to restore.
// Shuffling is run with a different seed every epoch, so this doesn't
// reduce the randomness.
files_ = backup_files_;
}
std::mt19937 g(kDaliDataloaderSeed + current_epoch_);
std::shuffle(files_.begin(), files_.end(), g);
}
}

using Loader<Backend, Target>::shard_id_;
using Loader<Backend, Target>::num_shards_;
using Loader<Backend, Target>::stick_to_shard_;
using Loader<Backend, Target>::shuffle_;
using Loader<Backend, Target>::dont_use_mmap_;
using Loader<Backend, Target>::initial_buffer_fill_;
using Loader<Backend, Target>::copy_read_data_;
using Loader<Backend, Target>::read_ahead_;
using Loader<Backend, Target>::MoveToNextShard;
using Loader<Backend, Target>::ShouldSkipImage;
using Loader<Backend, Target>::Size;
using Loader<Backend, Target>::PrepareEmptyTensor;
void RestoreStateImpl(const LoaderStateSnapshot &state) override {
current_epoch_ = state.current_epoch;
}

using Loader<Backend, Target, true>::shard_id_;
using Loader<Backend, Target, true>::virtual_shard_id_;
using Loader<Backend, Target, true>::num_shards_;
using Loader<Backend, Target, true>::stick_to_shard_;
using Loader<Backend, Target, true>::shuffle_;
using Loader<Backend, Target, true>::dont_use_mmap_;
using Loader<Backend, Target, true>::initial_buffer_fill_;
using Loader<Backend, Target, true>::copy_read_data_;
using Loader<Backend, Target, true>::read_ahead_;
using Loader<Backend, Target, true>::MoveToNextShard;
using Loader<Backend, Target, true>::ShouldSkipImage;
using Loader<Backend, Target, true>::Size;
using Loader<Backend, Target, true>::PrepareEmptyTensor;
using Loader<Backend, Target, true>::IsCheckpointingEnabled;

string file_list_, file_root_, file_filter_;
vector<std::string> files_;
vector<std::string> backup_files_;

bool has_files_arg_ = false;
bool has_file_list_arg_ = false;
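The backup_files_/reshuffle logic added to FileLoader::Reset above makes each epoch's order a pure function of the epoch number: restoring starts from the pristine list and replays the shuffle with an epoch-dependent seed. An illustrative sketch of that idea (plain Python, not DALI code; the seed constant is a placeholder):

```python
# Illustrative sketch of the order-restoration idea in FileLoader::Reset.
import random

DALI_DATALOADER_SEED = 1234  # placeholder for kDaliDataloaderSeed

def epoch_order(backup_files, epoch):
    files = list(backup_files)                    # start from the pristine order
    random.Random(DALI_DATALOADER_SEED + epoch).shuffle(files)
    return files                                  # reproducible, yet different per epoch

# Restoring epoch 7 from a checkpoint reproduces exactly the order the original run saw.
assert epoch_order(["a.npy", "b.npy", "c.npy"], 7) == epoch_order(["a.npy", "b.npy", "c.npy"], 7)
```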
6 changes: 5 additions & 1 deletion dali/operators/reader/loader/numpy_loader.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -115,4 +115,8 @@ void NumpyLoader::ReadSample(NumpyFileWrapper& target) {
target.fortran_order = header.fortran_order;
}

void NumpyLoader::Skip() {
MoveToNextShard(++current_index_);
}

} // namespace dali
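The new NumpyLoader::Skip() only advances the index and shard bookkeeping, without opening or parsing the file, which presumably lets a restored reader fast-forward to its checkpointed position cheaply. A rough sketch of how such a hook could be driven during restore (illustrative only; the loader interface below is hypothetical):

```python
# Hypothetical restore loop: skip already-consumed samples instead of re-reading them.
def fast_forward(loader, samples_already_consumed):
    for _ in range(samples_already_consumed):
        loader.skip()  # advance index/shard state only; no I/O or header parsing
```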
3 changes: 2 additions & 1 deletion dali/operators/reader/loader/numpy_loader.h
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -100,6 +100,7 @@ class NumpyLoader : public FileLoader<CPUBackend, NumpyFileWrapper> {

// we want to make it possible to override this function as well
void ReadSample(NumpyFileWrapper& target) override;
void Skip() override;

private:
detail::NumpyHeaderCache header_cache_;
5 changes: 3 additions & 2 deletions dali/operators/reader/numpy_reader_gpu_op.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -40,6 +40,7 @@ NumpyReaderGPU::NumpyReaderGPU(const OpSpec& spec)
// init loader
bool shuffle_after_epoch = spec.GetArgument<bool>("shuffle_after_epoch");
loader_ = InitLoader<NumpyLoaderGPU>(spec, std::vector<string>(), shuffle_after_epoch);
this->SetInitialSnapshot();

kmgr_transpose_.Resize<TransposeKernel>(1);
}
@@ -54,7 +55,7 @@ void NumpyReaderGPU::Prefetch() {
// We actually prepare the next batch
DomainTimeRange tr("[DALI][NumpyReaderGPU] Prefetch #" + to_string(curr_batch_producer_),
DomainTimeRange::kRed);
DataReader<GPUBackend, NumpyFileWrapperGPU>::Prefetch();
DataReader<GPUBackend, NumpyFileWrapperGPU, NumpyFileWrapperGPU, true>::Prefetch();
auto &curr_batch = prefetched_batch_queue_[curr_batch_producer_];
auto &curr_tensor_list = prefetched_batch_tensors_[curr_batch_producer_];

4 changes: 2 additions & 2 deletions dali/operators/reader/numpy_reader_gpu_op.h
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -70,7 +70,7 @@ class NumpyReaderGPU : gds::GDSLazyInit, public NumpyReader<GPUBackend, NumpyFil
using Operator<GPUBackend>::RunImpl;


USE_READER_OPERATOR_MEMBERS(GPUBackend, NumpyFileWrapperGPU);
USE_READER_OPERATOR_MEMBERS(GPUBackend, NumpyFileWrapperGPU, NumpyFileWrapperGPU, true);

private:
using TransposeKernel = kernels::TransposeGPU;
4 changes: 2 additions & 2 deletions dali/operators/reader/numpy_reader_op.cc
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -244,7 +244,7 @@ void NumpyReaderCPU::Prefetch() {
// We actually prepare the next batch
DomainTimeRange tr("[DALI][NumpyReaderCPU] Prefetch #" + to_string(curr_batch_producer_),
DomainTimeRange::kRed);
DataReader<CPUBackend, NumpyFileWrapper>::Prefetch();
DataReader<CPUBackend, NumpyFileWrapper, NumpyFileWrapper, true>::Prefetch();

if (!dont_use_mmap_)
return;
17 changes: 9 additions & 8 deletions dali/operators/reader/numpy_reader_op.h
@@ -1,4 +1,4 @@
// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -32,10 +32,10 @@
namespace dali {

template <typename Backend, typename Target>
class NumpyReader : public DataReader<Backend, Target> {
class NumpyReader : public DataReader<Backend, Target, Target, true> {
public:
explicit NumpyReader(const OpSpec& spec)
: DataReader<Backend, Target>(spec),
: DataReader<Backend, Target, Target, true>(spec),
slice_attr_(spec, "roi_start", "rel_roi_start", "roi_end", "rel_roi_end", "roi_shape",
"rel_roi_shape", "roi_axes", nullptr) {
out_of_bounds_policy_ = GetOutOfBoundsPolicy(spec);
@@ -48,14 +48,14 @@ class NumpyReader : public DataReader<Backend, Target> {
return true;
}

USE_READER_OPERATOR_MEMBERS(Backend, Target);
using DataReader<Backend, Target>::GetCurrBatchSize;
using DataReader<Backend, Target>::GetSample;
USE_READER_OPERATOR_MEMBERS(Backend, Target, Target, true);
using DataReader<Backend, Target, Target, true>::GetCurrBatchSize;
using DataReader<Backend, Target, Target, true>::GetSample;
using Operator<Backend>::spec_;

bool SetupImpl(std::vector<OutputDesc>& output_desc, const Workspace &ws) override {
// If necessary start prefetching thread and wait for a consumable batch
DataReader<Backend, Target>::SetupImpl(output_desc, ws);
DataReader<Backend, Target, Target, true>::SetupImpl(output_desc, ws);

int batch_size = GetCurrBatchSize();
const auto& file_0 = GetSample(0);
@@ -167,6 +167,7 @@ class NumpyReaderCPU : public NumpyReader<CPUBackend, NumpyFileWrapper> {
}
loader_ = InitLoader<NumpyLoader>(spec, shuffle_after_epoch, use_o_direct_, o_direct_alignm_,
o_direct_read_len_alignm_);
this->SetInitialSnapshot();
}
~NumpyReaderCPU() override;
void Prefetch() override;
@@ -176,7 +177,7 @@
using Operator<CPUBackend>::RunImpl;

private:
USE_READER_OPERATOR_MEMBERS(CPUBackend, NumpyFileWrapper);
USE_READER_OPERATOR_MEMBERS(CPUBackend, NumpyFileWrapper, NumpyFileWrapper, true);

bool dont_use_mmap_ = false;
bool use_o_direct_ = false;
77 changes: 77 additions & 0 deletions dali/test/python/checkpointing/test_dali_checkpointing.py
@@ -16,6 +16,7 @@
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import os
import shutil
import webdataset_base
import numpy as np
from nvidia.dali.pipeline import pipeline_def
@@ -25,6 +26,7 @@
from nose.plugins.attrib import attr
from dataclasses import dataclass
from nvidia.dali import tfrecord as tfrec
from reader.test_numpy import is_gds_supported

data_root = get_dali_extra_path()
images_dir = os.path.join(data_root, "db", "single", "jpeg")
@@ -568,6 +570,81 @@ def test_nemo_asr_reader(
manifest.close()


# device,
# num_epochs, batch_size, shard_id, num_shards,
# random_shuffle, shuffle_after_epoch, stick_to_shard, pad_last_batch,
# iters_into_epoch, initial_fill
@params(
("cpu", 0, 1, 0, 1, False, False, False, False, None),
("cpu", 5, 2, 4, 7, False, False, False, True, 1),
("cpu", 4, 4, 0, 2, False, False, True, False, 2),
("cpu", 3, 8, 4, 6, False, False, True, True, 3),
("cpu", 6, 1, 2, 3, False, True, False, False, 4),
("cpu", 5, 2, 2, 5, False, True, False, True, 3),
("cpu", 4, 4, 3, 4, True, False, False, False, 2),
("cpu", 3, 8, 1, 4, True, False, False, True, 1),
("cpu", 2, 1, 1, 2, True, False, True, False, None),
("cpu", 0, 2, 0, 1, True, False, True, True, 2),
*(
[
("gpu", 2, 1, 1, 2, False, False, False, False, None),
("gpu", 5, 2, 0, 5, False, False, False, True, 1),
("gpu", 3, 4, 2, 3, False, False, True, False, 2),
("gpu", 6, 8, 3, 5, False, False, True, True, 3),
("gpu", 7, 1, 1, 4, False, True, False, False, 4),
("gpu", 3, 2, 2, 4, False, True, False, True, 3),
("gpu", 3, 4, 2, 5, True, False, False, False, 2),
("gpu", 4, 8, 0, 2, True, False, False, True, 1),
("gpu", 1, 1, 2, 3, True, False, True, False, None),
("gpu", 0, 2, 0, 2, True, False, True, True, 2),
]
if is_gds_supported()
else []
),
)
def test_numpy_reader(
device,
num_epochs,
batch_size,
shard_id,
num_shards,
random_shuffle,
shuffle_after_epoch,
stick_to_shard,
pad_last_batch,
iters_into_epoch=None,
initial_fill=1024,
):
numpy_dir = os.path.join(data_root, "db", "3D", "MRI", "Knee", "npy_2d_slices", "STU00001")

# GDS doesn't support overlayfs, so we need to use runner's scratch
gds_data_root = "/scratch/"
if not os.path.isdir(gds_data_root):
gds_data_root = os.getcwd() + "/scratch/"
if not os.path.isdir(gds_data_root):
os.mkdir(gds_data_root)
assert os.path.isdir(gds_data_root)

with tempfile.TemporaryDirectory(prefix=gds_data_root) as test_data_root:
shutil.copytree(numpy_dir, os.path.join(test_data_root, "numpy"))

check_reader_checkpointing(
fn.readers.numpy,
num_epochs,
batch_size,
iters_into_epoch,
device=device,
file_root=os.path.join(test_data_root, "numpy"),
pad_last_batch=pad_last_batch,
random_shuffle=random_shuffle,
shuffle_after_epoch=shuffle_after_epoch,
shard_id=shard_id,
num_shards=num_shards,
stick_to_shard=stick_to_shard,
initial_fill=initial_fill,
)


@attr("pytorch")
@params(
(1, 3, 0, 1, True, False, False),
