Set propagate_down=true to force backprop to a particular bottom #3942

Merged
Merged 2 commits on Apr 5, 2016
Jump to file or symbol
Failed to load files and symbols.
+112 −6
Split
View
@@ -427,12 +427,11 @@ int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
bottom_id_vecs_[layer_id].push_back(blob_id);
available_blobs->erase(blob_name);
- bool propagate_down = true;
+ bool need_backward = blob_need_backward_[blob_id];
// Check if the backpropagation on bottom_id should be skipped
- if (layer_param.propagate_down_size() > 0)
- propagate_down = layer_param.propagate_down(bottom_id);
- const bool need_backward = blob_need_backward_[blob_id] &&
- propagate_down;
+ if (layer_param.propagate_down_size() > 0) {
+ need_backward = layer_param.propagate_down(bottom_id);
+ }
bottom_need_backward_[layer_id].push_back(need_backward);
return blob_id;
}
@@ -328,7 +328,12 @@ message LayerParameter {
// The blobs containing the numeric parameters of the layer.
repeated BlobProto blobs = 7;
- // Specifies on which bottoms the backpropagation should be skipped.
+ // Specifies whether to backpropagate to each bottom. If unspecified,
+ // Caffe will automatically infer whether each input needs backpropagation
+ // to compute parameter gradients. If set to true for some inputs,
+ // backpropagation to those inputs is forced; if set false for some inputs,
+ // backpropagation to those inputs is skipped.
+ //
// The size must be either 0 or equal to the number of bottoms.
repeated bool propagate_down = 11;
View
@@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
InitNetFromProtoString(proto);
}
+ // Builds 'ForcePropTestNetwork': a DummyData layer producing 'data'
+ // (gaussian-filled 5x2x3x4) and 'label' (constant-0, length 5), an
+ // InnerProduct layer on 'data', and a SigmoidCrossEntropyLoss on
+ // (innerproduct, label). If test_force_true is set, 'propagate_down: true'
+ // is added to the InnerProduct layer so backprop to its 'data' bottom is
+ // forced even though the data layer itself has nothing to learn.
+ virtual void InitForcePropNet(bool test_force_true) {
+ string proto =
+ "name: 'ForcePropTestNetwork' "
+ "layer { "
+ " name: 'data' "
+ " type: 'DummyData' "
+ " dummy_data_param { "
+ " shape { "
+ " dim: 5 "
+ " dim: 2 "
+ " dim: 3 "
+ " dim: 4 "
+ " } "
+ " data_filler { "
+ " type: 'gaussian' "
+ " std: 0.01 "
+ " } "
+ " shape { "
+ " dim: 5 "
+ " } "
+ " data_filler { "
+ " type: 'constant' "
+ " value: 0 "
+ " } "
+ " } "
+ " top: 'data' "
+ " top: 'label' "
+ "} "
+ "layer { "
+ " name: 'innerproduct' "
+ " type: 'InnerProduct' "
+ " inner_product_param { "
+ " num_output: 1 "
+ " weight_filler { "
+ " type: 'gaussian' "
+ " std: 0.01 "
+ " } "
+ " } "
+ " bottom: 'data' "
+ " top: 'innerproduct' ";
+ // propagate_down is a per-bottom repeated field; with a single bottom
+ // one 'true' entry forces backprop to 'data'.
+ if (test_force_true) {
+ proto += " propagate_down: true ";
+ }
+ proto +=
+ "} "
+ "layer { "
+ " name: 'loss' "
+ " bottom: 'innerproduct' "
+ " bottom: 'label' "
+ " top: 'cross_entropy_loss' "
+ " type: 'SigmoidCrossEntropyLoss' "
+ "} ";
+ InitNetFromProtoString(proto);
+ }
+
int seed_;
shared_ptr<Net<Dtype> > net_;
};
@@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
}
}
+ // Verifies the propagate_down override added in this PR: with no
+ // propagate_down specified, Net::Init infers that backprop to 'data' is
+ // unnecessary (the DummyData layer has no learnable parameters); with
+ // 'propagate_down: true' on the InnerProduct layer, backprop to 'data'
+ // is forced on. The 'label' bottom of the loss never needs backprop.
+ TYPED_TEST(NetTest, TestForcePropagateDown) {
+ // Case 1: no propagate_down field -- rely on automatic inference.
+ this->InitForcePropNet(false);
+ vector<bool> layer_need_backward = this->net_->layer_need_backward();
+ for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+ const string& layer_name = this->net_->layer_names()[layer_id];
+ // NOTE(review): binding by const& would avoid copying this vector.
+ const vector<bool> need_backward =
+ this->net_->bottom_need_backward()[layer_id];
+ if (layer_name == "data") {
+ ASSERT_EQ(need_backward.size(), 0);
+ EXPECT_FALSE(layer_need_backward[layer_id]);
+ } else if (layer_name == "innerproduct") {
+ ASSERT_EQ(need_backward.size(), 1);
+ EXPECT_FALSE(need_backward[0]); // data
+ EXPECT_TRUE(layer_need_backward[layer_id]);
+ } else if (layer_name == "loss") {
+ ASSERT_EQ(need_backward.size(), 2);
+ EXPECT_TRUE(need_backward[0]); // innerproduct
+ EXPECT_FALSE(need_backward[1]); // label
+ EXPECT_TRUE(layer_need_backward[layer_id]);
+ } else {
+ LOG(FATAL) << "Unknown layer: " << layer_name;
+ }
+ }
+ // Case 2: propagate_down: true -- backprop to 'data' must now be forced.
+ this->InitForcePropNet(true);
+ layer_need_backward = this->net_->layer_need_backward();
+ for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+ const string& layer_name = this->net_->layer_names()[layer_id];
+ const vector<bool> need_backward =
+ this->net_->bottom_need_backward()[layer_id];
+ if (layer_name == "data") {
+ ASSERT_EQ(need_backward.size(), 0);
+ EXPECT_FALSE(layer_need_backward[layer_id]);
+ } else if (layer_name == "innerproduct") {
+ ASSERT_EQ(need_backward.size(), 1);
+ EXPECT_TRUE(need_backward[0]); // data -- forced by propagate_down
+ EXPECT_TRUE(layer_need_backward[layer_id]);
+ } else if (layer_name == "loss") {
+ ASSERT_EQ(need_backward.size(), 2);
+ EXPECT_TRUE(need_backward[0]); // innerproduct
+ EXPECT_FALSE(need_backward[1]); // label
+ EXPECT_TRUE(layer_need_backward[layer_id]);
+ } else {
+ LOG(FATAL) << "Unknown layer: " << layer_name;
+ }
+ }
+}
+
} // namespace caffe