|
|
@@ -331,11 +331,20 @@ void Solver<Dtype>::Test(const int test_net_id) { |
|
|
<< ", Testing net (#" << test_net_id << ")";
|
|
|
CHECK_NOTNULL(test_nets_[test_net_id].get())->
|
|
|
ShareTrainedLayersWith(net_.get());
|
|
|
+
|
|
|
+ for (int i = 0; i < callbacks_.size(); ++i) {
|
|
|
+ callbacks_[i]->on_test_start(test_net_id);
|
|
|
+ }
|
|
|
+
|
|
|
vector<Dtype> test_score;
|
|
|
vector<int> test_score_output_id;
|
|
|
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
|
|
|
Dtype loss = 0;
|
|
|
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
|
|
|
+ for (int i = 0; i < callbacks_.size(); ++i) {
|
|
|
+ callbacks_[i]->on_test_iter_start(test_net_id, i);
|
|
|
+ }
|
|
|
+
|
|
|
SolverAction::Enum request = GetRequestedAction();
|
|
|
// Check to see if stoppage of testing/training has been requested.
|
|
|
while (request != SolverAction::NONE) {
|
|
|
@@ -375,6 +384,11 @@ void Solver<Dtype>::Test(const int test_net_id) { |
|
|
}
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ for (int i = 0; i < callbacks_.size(); ++i) {
|
|
|
+ callbacks_[i]->on_test_end(test_net_id);
|
|
|
+ }
|
|
|
+
|
|
|
if (requested_early_exit_) {
|
|
|
LOG(INFO) << "Test interrupted.";
|
|
|
return;
|
|
|
|