forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prefetch_op.h
142 lines (128 loc) · 4.55 KB
/
prefetch_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_
#define CAFFE2_OPERATORS_PREFETCH_OP_H_
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread> // NOLINT

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// PrefetchOperator is an operator that prefetches the next batch. It should
// almost always be used to read things from disk, so I am setting the input to
// zero blobs.
//
// For any operator that is derived from PrefetchOperator, it should
// explicitly call the Finalize() function in its destructor, so that the
// prefetching thread is properly destructed.
// Note: We inherit from OperatorBase since we control the
// synchronization properties of this operator ourselves (we inform
// the waiting producer after we synchronize). This is a special-case
// - you should generally inherit from Operator<Context> directly.
template <class Context>
class PrefetchOperator : public OperatorBase {
public:
PrefetchOperator(const OperatorDef& operator_def, Workspace* ws)
: OperatorBase(operator_def, ws),
context_(operator_def.device_option()),
prefetched_(false),
prefetch_success_(true),
finalize_(false),
no_prefetch_(GetSingleArgument<bool>("no_prefetch", false)) {
context_.SwitchToDevice();
}
virtual ~PrefetchOperator() noexcept {
CHECK(finalize_ || !prefetch_thread_.get()) <<
"YOU MADE A PROGRAMING ERROR: derived class of PrefetchOperator "
"should call Finalize() in its destructor so the prefetching "
"thread is joined. ";
}
void Finalize() {
if (prefetch_thread_.get()) {
{
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
while (!prefetched_)
consumer_.wait(lock);
finalize_ = true;
prefetched_ = false;
}
producer_.notify_one();
prefetch_thread_->join();
prefetch_thread_.reset();
} else {
// If we never initialized the prefetch thread, just set
// finalize anyway.
finalize_ = true;
}
}
bool Run(int /* unused */ /*stream_id*/) override {
if (no_prefetch_) {
context_.SwitchToDevice();
bool result = Prefetch() && CopyPrefetched();
context_.FinishDeviceComputation();
return result;
}
// Note(jiayq): We only start the prefetch_thread at the Run() function
// instead of in the constructor, because the prefetch_thread needs to start
// after all derived classes' constructors finish.
if (!prefetch_thread_) {
prefetch_thread_.reset(
new std::thread([this] { this->PrefetchWorker(); }));
}
context_.SwitchToDevice();
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
while (!prefetched_)
consumer_.wait(lock);
if (!prefetch_success_) {
LOG(ERROR) << "Prefetching failed.";
return false;
}
if (!CopyPrefetched()) {
LOG(ERROR) << "Error when copying prefetched data.";
return false;
}
prefetched_ = false;
context_.FinishDeviceComputation();
producer_.notify_one();
return true;
}
void PrefetchWorker() {
context_.SwitchToDevice();
std::unique_lock<std::mutex> lock(prefetch_access_mutex_);
while (prefetched_)
producer_.wait(lock);
while (!finalize_) {
// We will need to run a FinishDeviceComputation() call because the
// prefetcher thread and the main thread are potentially using different
// streams (like on GPU).
try {
prefetch_success_ = Prefetch();
context_.FinishDeviceComputation();
} catch (const std::exception& e) {
// TODO: propagate exception_ptr to the caller side
LOG(ERROR) << "Prefetching error " << e.what();
prefetch_success_ = false;
}
prefetched_ = true;
consumer_.notify_one();
while (prefetched_)
producer_.wait(lock);
}
}
// You will need to implement this instead of the Run function.
virtual bool Prefetch() = 0;
virtual bool CopyPrefetched() = 0;
protected:
Context context_;
std::mutex prefetch_access_mutex_;
std::condition_variable producer_, consumer_;
// prefetched_ is used to tell the operator that it is done.
std::atomic<bool> prefetched_;
// prefetch_success_ is used to see if prefetching failed or not.
std::atomic<bool> prefetch_success_;
// finalize_ is used to tell the prefetcher to quit.
std::atomic<bool> finalize_;
unique_ptr<std::thread> prefetch_thread_;
// Whether to do prefetching or run this as a normal operator
const bool no_prefetch_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_PREFETCH_OP_H_