Snapshot model weights/solver state to HDF5 files #2836

Merged
merged 5 commits into from Aug 7, 2015
View
@@ -61,7 +61,9 @@ Makefile.config
data/*
models/*
*.caffemodel
+*.caffemodel.h5
*.solverstate
+*.solverstate.h5
*.binaryproto
*leveldb
*lmdb
@@ -8,9 +8,9 @@ $TOOLS/caffe train \
# reduce learning rate by factor of 10
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_full_solver_lr1.prototxt \
- --snapshot=examples/cifar10/cifar10_full_iter_60000.solverstate
+ --snapshot=examples/cifar10/cifar10_full_iter_60000.solverstate.h5
# reduce learning rate by factor of 10
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_full_solver_lr2.prototxt \
- --snapshot=examples/cifar10/cifar10_full_iter_65000.solverstate
+ --snapshot=examples/cifar10/cifar10_full_iter_65000.solverstate.h5
@@ -8,4 +8,4 @@ $TOOLS/caffe train \
# reduce learning rate by factor of 10 after 8 epochs
$TOOLS/caffe train \
--solver=examples/cifar10/cifar10_quick_solver_lr1.prototxt \
- --snapshot=examples/cifar10/cifar10_quick_iter_4000.solverstate
+ --snapshot=examples/cifar10/cifar10_quick_iter_4000.solverstate.h5
@@ -2,4 +2,4 @@
./build/tools/caffe train \
--solver=models/bvlc_reference_caffenet/solver.prototxt \
- --snapshot=models/bvlc_reference_caffenet/caffenet_train_10000.solverstate
+ --snapshot=models/bvlc_reference_caffenet/caffenet_train_10000.solverstate.h5
View
@@ -10,7 +10,7 @@
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"
-const int kMaxBlobAxes = INT_MAX;
+const int kMaxBlobAxes = 32;
namespace caffe {
View
@@ -98,8 +98,12 @@ class Net {
*/
void CopyTrainedLayersFrom(const NetParameter& param);
void CopyTrainedLayersFrom(const string trained_filename);
+ void CopyTrainedLayersFromBinaryProto(const string trained_filename);
+ void CopyTrainedLayersFromHDF5(const string trained_filename);
/// @brief Writes the net to a proto.
void ToProto(NetParameter* param, bool write_diff = false) const;
+ /// @brief Writes the net to an HDF5 file.
+ void ToHDF5(const string& filename, bool write_diff = false) const;
/// @brief returns the network name.
inline const string& name() const { return name_; }
View
@@ -27,9 +27,9 @@ class Solver {
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
- // The Restore function implements how one should restore the solver to a
- // previously snapshotted state. You should implement the RestoreSolverState()
- // function that restores the state from a SolverState protocol buffer.
+ // The Restore method simply dispatches to one of the
+ // RestoreSolverStateFrom___ protected methods. You should implement these
+ // methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
virtual ~Solver() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
@@ -46,11 +46,15 @@ class Solver {
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
+ string SnapshotFilename(const string extension);
+ string SnapshotToBinaryProto();
+ string SnapshotToHDF5();
// The test routine
void TestAll();
void Test(const int test_net_id = 0);
- virtual void SnapshotSolverState(SolverState* state) = 0;
- virtual void RestoreSolverState(const SolverState& state) = 0;
+ virtual void SnapshotSolverState(const string& model_filename) = 0;
+ virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
+ virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
SolverParameter param_;
@@ -85,8 +89,11 @@ class SGDSolver : public Solver<Dtype> {
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
- virtual void SnapshotSolverState(SolverState * state);
- virtual void RestoreSolverState(const SolverState& state);
+ virtual void SnapshotSolverState(const string& model_filename);
+ virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
+ virtual void SnapshotSolverStateToHDF5(const string& model_filename);
+ virtual void RestoreSolverStateFromHDF5(const string& state_file);
+ virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
// history maintains the historical momentum data.
// update maintains update related data and is not needed in snapshots.
// temp maintains other information that might be needed in computation
@@ -0,0 +1,39 @@
+#ifndef CAFFE_UTIL_HDF5_H_
+#define CAFFE_UTIL_HDF5_H_
+
+#include <string>
+
+#include "hdf5.h"
+#include "hdf5_hl.h"
+
+#include "caffe/blob.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void hdf5_load_nd_dataset_helper(
+ hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+ Blob<Dtype>* blob);
+
+template <typename Dtype>
+void hdf5_load_nd_dataset(
+ hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+ Blob<Dtype>* blob);
+
+template <typename Dtype>
+void hdf5_save_nd_dataset(
+ const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob,
+ bool write_diff = false);
+
+int hdf5_load_int(hid_t loc_id, const string& dataset_name);
+void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i);
+string hdf5_load_string(hid_t loc_id, const string& dataset_name);
+void hdf5_save_string(hid_t loc_id, const string& dataset_name,
+ const string& s);
+
+int hdf5_get_num_links(hid_t loc_id);
+string hdf5_get_name_by_idx(hid_t loc_id, int idx);
+
+} // namespace caffe
+
+#endif // CAFFE_UTIL_HDF5_H_
View
@@ -5,15 +5,11 @@
#include <string>
#include "google/protobuf/message.h"
-#include "hdf5.h"
-#include "hdf5_hl.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
-#define HDF5_NUM_DIMS 4
-
namespace caffe {
using ::google::protobuf::Message;
@@ -140,20 +136,6 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
-template <typename Dtype>
-void hdf5_load_nd_dataset_helper(
- hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob);
-
-template <typename Dtype>
-void hdf5_load_nd_dataset(
- hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob);
-
-template <typename Dtype>
-void hdf5_save_nd_dataset(
- const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob);
-
} // namespace caffe
#endif // CAFFE_UTIL_IO_H_
View
@@ -456,31 +456,66 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
}
// copy data
Dtype* data_vec = mutable_cpu_data();
- for (int i = 0; i < count_; ++i) {
- data_vec[i] = proto.data(i);
+ if (proto.double_data_size() > 0) {
+ CHECK_EQ(count_, proto.double_data_size());
+ for (int i = 0; i < count_; ++i) {
+ data_vec[i] = proto.double_data(i);
+ }
+ } else {
+ CHECK_EQ(count_, proto.data_size());
+ for (int i = 0; i < count_; ++i) {
+ data_vec[i] = proto.data(i);
+ }
}
- if (proto.diff_size() > 0) {
+ if (proto.double_diff_size() > 0) {
+ CHECK_EQ(count_, proto.double_diff_size());
+ Dtype* diff_vec = mutable_cpu_diff();
+ for (int i = 0; i < count_; ++i) {
+ diff_vec[i] = proto.double_diff(i);
+ }
+ } else if (proto.diff_size() > 0) {
+ CHECK_EQ(count_, proto.diff_size());
Dtype* diff_vec = mutable_cpu_diff();
for (int i = 0; i < count_; ++i) {
diff_vec[i] = proto.diff(i);
}
}
}
-template <typename Dtype>
-void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
+template <>
+void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
+ proto->clear_shape();
+ for (int i = 0; i < shape_.size(); ++i) {
+ proto->mutable_shape()->add_dim(shape_[i]);
+ }
+ proto->clear_double_data();
+ proto->clear_double_diff();
+ const double* data_vec = cpu_data();
+ for (int i = 0; i < count_; ++i) {
+ proto->add_double_data(data_vec[i]);
+ }
+ if (write_diff) {
+ const double* diff_vec = cpu_diff();
+ for (int i = 0; i < count_; ++i) {
+ proto->add_double_diff(diff_vec[i]);
+ }
+ }
+}
+
+template <>
+void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const {
proto->clear_shape();
for (int i = 0; i < shape_.size(); ++i) {
proto->mutable_shape()->add_dim(shape_[i]);
}
proto->clear_data();
proto->clear_diff();
- const Dtype* data_vec = cpu_data();
+ const float* data_vec = cpu_data();
for (int i = 0; i < count_; ++i) {
proto->add_data(data_vec[i]);
}
if (write_diff) {
- const Dtype* diff_vec = cpu_diff();
+ const float* diff_vec = cpu_diff();
for (int i = 0; i < count_; ++i) {
proto->add_diff(diff_vec[i]);
}
@@ -16,7 +16,7 @@
#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
+#include "caffe/util/hdf5.hpp"
namespace caffe {
@@ -6,7 +6,7 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
+#include "caffe/util/hdf5.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
@@ -6,7 +6,6 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
Oops, something went wrong.