Skip to content

Commit

Permalink
Weight parameter in solver is used in caffe.exe
Browse files Browse the repository at this point in the history
Loading of weights is moved from caffe.exe into the Solver class, so the new "weights" solver parameter takes effect not only when set from the command line but also when Caffe is used as a library (including from Python)

corrected formatting

fixed line length

more formatting corrected
  • Loading branch information
IlyaOvodov committed Feb 10, 2018
1 parent 7e97067 commit c326294
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 17 deletions.
12 changes: 11 additions & 1 deletion src/caffe/proto/caffe.proto
Expand Up @@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
// SolverParameter next available ID: 43 (last added: weights)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
Expand Down Expand Up @@ -241,6 +241,16 @@ message SolverParameter {

// Overlap compute and communication for data parallel training
optional bool layer_wise_reduce = 41 [default = true];

// Path to caffemodel file(s) with pretrained weights to initialize finetuning.
// The same as the command line --weights parameter for the caffe train command.
// If the command line --weights parameter is specified, it has higher priority
// and overwrites this one(s).
// If the --snapshot command line parameter is specified, this one(s) are
// ignored.
// If several model files are expected, they can be listed in one
// weights parameter separated by ',' (like in a command string) or
// in repeated weights parameters separately.
repeated string weights = 42;
}

// A message that stores the solver snapshots
Expand Down
21 changes: 21 additions & 0 deletions src/caffe/solver.cpp
Expand Up @@ -3,6 +3,7 @@
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
Expand Down Expand Up @@ -59,6 +60,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
current_step_ = 0;
}

// Copies trained layers into `net` from every caffemodel file named in
// `model_list`.
//
// `model_list` is a single path or several paths separated by ','.
// Whitespace around each path is trimmed. Empty entries (e.g. produced by a
// trailing or doubled comma, or by an empty string) are skipped instead of
// being passed on as an empty file name.
template <typename Dtype>
void LoadNetWeights(shared_ptr<Net<Dtype> > net,
    const std::string& model_list) {
  std::vector<std::string> model_names;
  boost::split(model_names, model_list, boost::is_any_of(","));
  // Unsigned index: avoids a signed/unsigned comparison against size().
  for (size_t i = 0; i < model_names.size(); ++i) {
    boost::trim(model_names[i]);
    if (model_names[i].empty()) {
      continue;  // nothing to load for an empty list entry
    }
    LOG(INFO) << "Finetuning from " << model_names[i];
    net->CopyTrainedLayersFrom(model_names[i]);
  }
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
const int num_train_nets = param_.has_net() + param_.has_net_param() +
Expand Down Expand Up @@ -98,6 +113,9 @@ void Solver<Dtype>::InitTrainNet() {
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
net_.reset(new Net<Dtype>(net_param));
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
LoadNetWeights(net_, param_.weights(w_idx));
}
}

template <typename Dtype>
Expand Down Expand Up @@ -173,6 +191,9 @@ void Solver<Dtype>::InitTestNets() {
<< "Creating test net (#" << i << ") specified by " << sources[i];
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
test_nets_[i]->set_debug_info(param_.debug_info());
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
LoadNetWeights(test_nets_[i], param_.weights(w_idx));
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/caffe/test/test_upgrade_proto.cpp
Expand Up @@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
for (int i = 0; i < 6; ++i) {
const string& input_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
Expand All @@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
"solver_type: " + std::string(old_type_vec[i]) + " ";
const string& expected_output_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
Expand Down
23 changes: 7 additions & 16 deletions tools/caffe.cpp
Expand Up @@ -146,20 +146,6 @@ int device_query() {
}
RegisterBrewFunction(device_query);

// Copies trained layers from each caffemodel listed in the comma-separated
// `model_list` into the solver's train net and into every one of its test
// nets.
void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
  std::vector<std::string> caffemodels;
  boost::split(caffemodels, model_list, boost::is_any_of(","));
  for (int m = 0; m < caffemodels.size(); ++m) {
    const std::string& path = caffemodels[m];
    LOG(INFO) << "Finetuning from " << path;
    // Train net first, then each test net, for the same model file.
    solver->net()->CopyTrainedLayersFrom(path);
    for (int t = 0; t < solver->test_nets().size(); ++t) {
      solver->test_nets()[t]->CopyTrainedLayersFrom(path);
    }
  }
}

// Translate the signal effect the user specified on the command-line to the
// corresponding enumeration.
caffe::SolverAction::Enum GetRequestedAction(
Expand Down Expand Up @@ -233,6 +219,13 @@ int train() {
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));

if (FLAGS_snapshot.size()) {
solver_param.clear_weights();
} else if (FLAGS_weights.size()) {
solver_param.clear_weights();
solver_param.add_weights(FLAGS_weights);
}

shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

Expand All @@ -241,8 +234,6 @@ int train() {
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
CopyLayers(solver.get(), FLAGS_weights);
}

LOG(INFO) << "Starting Optimization";
Expand Down

1 comment on commit c326294

@hliang
Copy link

@hliang hliang commented on c326294 May 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compiling would give error like:
src/caffe/solver.cpp:116:38: error: ‘class caffe::SolverParameter’ has no member named ‘weights_size’
src/caffe/solver.cpp:117:33: error: ‘class caffe::SolverParameter’ has no member named ‘weights’
using: gpu k80, cuda 8.0, cudnn 6.0, g++ 4.8.5

Please sign in to comment.