Added solver callback for testing nets #5710

Open
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+18 −0
Split
View
@@ -80,6 +80,10 @@ class Solver {
virtual void on_start() = 0;
virtual void on_gradients_ready() = 0;
+ virtual void on_test_start(int test_net_id) {}
+ virtual void on_test_end(int test_net_id) {}
+ virtual void on_test_iter_start(int test_net_id, int iter) {}
+
template <typename T>
friend class Solver;
};
View
@@ -331,11 +331,20 @@ void Solver<Dtype>::Test(const int test_net_id) {
<< ", Testing net (#" << test_net_id << ")";
CHECK_NOTNULL(test_nets_[test_net_id].get())->
ShareTrainedLayersWith(net_.get());
+
+ for (int i = 0; i < callbacks_.size(); ++i) {
+ callbacks_[i]->on_test_start(test_net_id);
+ }
+
vector<Dtype> test_score;
vector<int> test_score_output_id;
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
+ for (int i = 0; i < callbacks_.size(); ++i) {
+ callbacks_[i]->on_test_iter_start(test_net_id, i);
+ }
+
SolverAction::Enum request = GetRequestedAction();
// Check to see if stoppage of testing/training has been requested.
while (request != SolverAction::NONE) {
@@ -375,6 +384,11 @@ void Solver<Dtype>::Test(const int test_net_id) {
}
}
}
+
+ for (int i = 0; i < callbacks_.size(); ++i) {
+ callbacks_[i]->on_test_end(test_net_id);
+ }
+
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;