Expose `Solver::Snapshot` to pycaffe #3082

Merged
merged 2 commits into from Oct 31, 2015
Jump to file or symbol
Failed to load files and symbols.
+17 −7
Split
View
@@ -60,6 +60,11 @@ class Solver {
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
+ // The Solver::Snapshot function implements the basic snapshotting utility
+ // that stores the learned net. You should implement the SnapshotSolverState()
+ // function that produces a SolverState protocol buffer that needs to be
+ // written to disk together with the learned net.
+ void Snapshot();
virtual ~Solver() {}
inline const SolverParameter& param() const { return param_; }
inline shared_ptr<Net<Dtype> > net() { return net_; }
@@ -87,11 +92,6 @@ class Solver {
protected:
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
- // The Solver::Snapshot function implements the basic snapshotting utility
- // that stores the learned net. You should implement the SnapshotSolverState()
- // function that produces a SolverState protocol buffer that needs to be
- // written to disk together with the learned net.
- void Snapshot();
string SnapshotFilename(const string extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
View
@@ -286,7 +286,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
&Solver<Dtype>::Solve), SolveOverloads())
.def("step", &Solver<Dtype>::Step)
- .def("restore", &Solver<Dtype>::Restore);
+ .def("restore", &Solver<Dtype>::Restore)
+ .def("snapshot", &Solver<Dtype>::Snapshot);
bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
@@ -16,7 +16,8 @@ def setUp(self):
f.write("""net: '""" + net_f + """'
test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
- display: 100 max_iter: 100 snapshot_after_train: false""")
+ display: 100 max_iter: 100 snapshot_after_train: false
+ snapshot_prefix: "model" """)
f.close()
self.solver = caffe.SGDSolver(f.name)
# also make sure get_solver runs
@@ -51,3 +52,11 @@ def test_net_memory(self):
total += p.data.sum() + p.diff.sum()
for bl in six.itervalues(net.blobs):
total += bl.data.sum() + bl.diff.sum()
+
+ def test_snapshot(self):
+ self.solver.snapshot()
+ # Check that these files exist and then remove them
+ files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate']
+ for fn in files:
+ assert os.path.isfile(fn)
+ os.remove(fn)