|
|
@@ -228,6 +228,27 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) { |
|
|
return bp::object();
|
|
|
}
|
|
|
|
|
|
+template<typename Dtype>
|
|
|
+class PythonCallback: public Solver<Dtype>::Callback {
|
|
|
+ protected:
|
|
|
+ bp::object on_start_, on_gradients_ready_;
|
|
|
+
|
|
|
+ public:
|
|
|
+ PythonCallback(bp::object on_start, bp::object on_gradients_ready)
|
|
|
+ : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { }
|
|
|
+ virtual void on_gradients_ready() {
|
|
|
+ on_gradients_ready_();
|
|
|
+ }
|
|
|
+ virtual void on_start() {
|
|
|
+ on_start_();
|
|
|
+ }
|
|
|
+};
|
|
|
+template<typename Dtype>
|
|
|
+void Solver_add_callback(Solver<Dtype> * solver, bp::object on_start,
|
|
|
+ bp::object on_gradients_ready) {
|
|
|
+ solver->add_callback(new PythonCallback<Dtype>(on_start, on_gradients_ready));
|
|
|
+}
|
|
|
+
|
|
|
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
|
|
|
|
|
|
BOOST_PYTHON_MODULE(_caffe) {
|
|
|
@@ -317,6 +338,7 @@ BOOST_PYTHON_MODULE(_caffe) { |
|
|
.add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
|
|
|
bp::return_internal_reference<>()))
|
|
|
.add_property("iter", &Solver<Dtype>::iter)
|
|
|
+ .def("add_callback", &Solver_add_callback<Dtype>)
|
|
|
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
|
|
|
&Solver<Dtype>::Solve), SolveOverloads())
|
|
|
.def("step", &Solver<Dtype>::Step)
|
|
|
|