Tied weights with transpose flag for InnerProduct layer #3612

Merged · 1 commit · Feb 25, 2016
+304 −15
@@ -44,6 +44,7 @@ class InnerProductLayer : public Layer<Dtype> {
int N_;
bool bias_term_;
Blob<Dtype> bias_multiplier_;
+ bool transpose_; ///< if true, assume transposed weights
};
} // namespace caffe
@@ -11,6 +11,7 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int num_output = this->layer_param_.inner_product_param().num_output();
bias_term_ = this->layer_param_.inner_product_param().bias_term();
+ transpose_ = this->layer_param_.inner_product_param().transpose();
N_ = num_output;
const int axis = bottom[0]->CanonicalAxisIndex(
this->layer_param_.inner_product_param().axis());
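
For reference when reading the shape choices below: M_ is the batch size (the product of the dimensions before `axis`), K_ is the flattened input dimension starting at `axis`, and N_ is num_output. The non-transposed weight blob is therefore N_ x K_.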
@@ -27,10 +28,15 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
} else {
this->blobs_.resize(1);
}
- // Intialize the weight
+ // Initialize the weights
vector<int> weight_shape(2);
- weight_shape[0] = N_;
- weight_shape[1] = K_;
+ if (transpose_) {
+ weight_shape[0] = K_;
+ weight_shape[1] = N_;
+ } else {
+ weight_shape[0] = N_;
+ weight_shape[1] = K_;
+ }
this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
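
Shape note: with transpose == false the weight blob keeps the usual N_ x K_ layout; with transpose == true it is allocated as K_ x N_ instead. This is what makes weight tying work: an encoder with input dimension K and num_output N stores an N x K blob, while a decoder mapping back from N to K with transpose: true expects K_ x N_ = N x K, exactly the same shape, so the two layers can share one blob.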
@@ -80,7 +86,8 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, transpose_ ? CblasNoTrans : CblasTrans,
+ M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (bias_term_) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
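
In shapes, the forward GEMM above computes (a worked restatement, not new code):

    transpose_ == false:  top (M_ x N_) = bottom (M_ x K_) * W^T,  W stored N_ x K_
    transpose_ == true:   top (M_ x N_) = bottom (M_ x K_) * W,    W stored K_ x N_

The result is identical either way; only the storage layout and the GEMM transpose flag differ.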
@@ -97,8 +104,17 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->cpu_diff();
const Dtype* bottom_data = bottom[0]->cpu_data();
// Gradient with respect to weight
- caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
- top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ if (transpose_) {
+ caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ K_, N_, M_,
+ (Dtype)1., bottom_data, top_diff,
+ (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ } else {
+ caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ N_, K_, M_,
+ (Dtype)1., top_diff, bottom_data,
+ (Dtype)1., this->blobs_[0]->mutable_cpu_diff());
+ }
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->cpu_diff();
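
The two weight-gradient branches compute the same quantity in the layout of the stored blob, one being the transpose of the other:

    transpose_ == false:  dW (N_ x K_) += top_diff^T (N_ x M_) * bottom (M_ x K_)
    transpose_ == true:   dW (K_ x N_) += bottom^T (K_ x M_) * top_diff (M_ x N_)

In both cases beta is (Dtype)1., so the gradient accumulates into the existing diff instead of overwriting it.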
@@ -110,9 +126,17 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->cpu_diff();
// Gradient with respect to bottom data
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
- top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
- bottom[0]->mutable_cpu_diff());
+ if (transpose_) {
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
+ (Dtype)0., bottom[0]->mutable_cpu_diff());
+ } else {
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->cpu_data(),
+ (Dtype)0., bottom[0]->mutable_cpu_diff());
+ }
}
}
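
Both bottom-gradient branches compute dbottom (M_ x K_) = top_diff (M_ x N_) * W_eff, where W_eff is the weight viewed in N_ x K_ layout: the stored blob as-is when transpose_ == false, and the stored K_ x N_ blob read through CblasTrans when transpose_ == true.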
@@ -19,7 +19,9 @@ void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
caffe_gpu_axpy<Dtype>(N_, bias_multiplier_.cpu_data()[0],
this->blobs_[1]->gpu_data(), top_data);
} else {
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, (Dtype)1.,
+ caffe_gpu_gemm<Dtype>(CblasNoTrans,
+ transpose_ ? CblasNoTrans : CblasTrans,
+ M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
if (bias_term_)
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, (Dtype)1.,
@@ -36,8 +38,17 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
// Gradient with respect to weight
- caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
- top_diff, bottom_data, (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ if (transpose_) {
+ caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ K_, N_, M_,
+ (Dtype)1., bottom_data, top_diff,
+ (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ } else {
+ caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+ N_, K_, M_,
+ (Dtype)1., top_diff, bottom_data,
+ (Dtype)1., this->blobs_[0]->mutable_gpu_diff());
+ }
}
if (bias_term_ && this->param_propagate_down_[1]) {
const Dtype* top_diff = top[0]->gpu_diff();
@@ -49,9 +60,17 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[0]) {
const Dtype* top_diff = top[0]->gpu_diff();
// Gradient with respect to bottom data
- caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
- top_diff, this->blobs_[0]->gpu_data(), (Dtype)0.,
- bottom[0]->mutable_gpu_diff());
+ if (transpose_) {
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->gpu_data(),
+ (Dtype)0., bottom[0]->mutable_gpu_diff());
+ } else {
+ caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+ M_, K_, N_,
+ (Dtype)1., top_diff, this->blobs_[0]->gpu_data(),
+ (Dtype)0., bottom[0]->mutable_gpu_diff());
+ }
}
}
@@ -786,6 +786,11 @@ message InnerProductParameter {
// all preceding axes are retained in the output.
// May be negative to index from the end (e.g., -1 for the last axis).
optional int32 axis = 5 [default = 1];
+ // Specify whether to transpose the weight matrix or not.
+ // If transpose == true, all operations are performed on the transpose of
+ // the weight matrix. The weight matrix itself is not transposed in memory;
+ // instead, the transpose flag of each GEMM operation is toggled accordingly.
+ optional bool transpose = 6 [default = false];
}
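
A minimal usage sketch of the new flag for tied weights (layer and param names here are illustrative, and a 784-dimensional input such as flattened MNIST is assumed; sharing blobs by param name is standard Caffe):

    layer {
      name: "encode"
      type: "InnerProduct"
      bottom: "data"
      top: "encode"
      param { name: "tied_weights" }   # weight blob shared by name
      inner_product_param {
        num_output: 100                # weight shape: 100 x 784 (N_ x K_)
      }
    }
    layer {
      name: "decode"
      type: "InnerProduct"
      bottom: "encode"
      top: "decode"
      param { name: "tied_weights" }   # same name -> same underlying blob
      inner_product_param {
        num_output: 784                # back to the input dimension
        transpose: true                # expects K_ x N_ = 100 x 784: shapes match
      }
    }

Because the shared blob accumulates diffs from every layer that uses it, the tied weights receive gradient contributions from both the encoder and the decoder in a single backward pass.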
// Message that stores parameters used by LogLayer