-
Notifications
You must be signed in to change notification settings - Fork 18.7k
/
inner_product_layer.cpp
144 lines (132 loc) · 5.22 KB
/
inner_product_layer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Copyright 2013 Yangqing Jia
#include <mkl.h>
#include <cublas_v2.h>
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"
namespace caffe {
// Configures the layer from its parameters and the shape of the input blob.
//
// Reads num_output and biasterm from layer_param_ and derives the matrix
// dimensions used by the forward/backward passes:
//   M_ = bottom[0]->num()
//   K_ = bottom[0]->count() / bottom[0]->num()
//   N_ = num_output
// The single top blob is reshaped to (M_, N_, 1, 1). If the layer has no
// parameter blobs yet, a weight blob of shape (1, 1, N_, K_) and, when
// biasterm is set, a bias blob of shape (1, 1, 1, N_) are created and filled
// with the fillers named in layer_param_. If parameter blobs are already
// present they are kept, but their number and element counts are validated
// against the dimensions above. Finally, when biasterm is set, an all-ones
// buffer of M_ elements is allocated in bias_multiplier_.
//
// bottom: must contain exactly one blob with a non-empty batch (num() > 0).
// top:    must point to a vector holding exactly one blob; it is reshaped.
template <typename Dtype>
void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  CHECK_EQ(bottom.size(), 1) << "IP Layer takes a single blob as input.";
  CHECK_EQ(top->size(), 1) << "IP Layer takes a single blob as output.";
  const int num_output = this->layer_param_.num_output();
  biasterm_ = this->layer_param_.biasterm();
  // Figure out the dimensions. The batch size is checked first because it is
  // the divisor used to compute K_.
  CHECK_GT(bottom[0]->num(), 0) << "IP Layer requires a non-empty input batch.";
  M_ = bottom[0]->num();
  K_ = bottom[0]->count() / bottom[0]->num();
  N_ = num_output;
  (*top)[0]->Reshape(bottom[0]->num(), num_output, 1, 1);
  // Check if we need to set up the weights
  if (this->blobs_.size() > 0) {
    LOG(INFO) << "Skipping parameter initialization";
    // The existing blobs are used as-is by the forward/backward passes, so
    // make sure they are large enough for the dimensions computed above.
    const int required_blobs = biasterm_ ? 2 : 1;
    CHECK_GE(static_cast<int>(this->blobs_.size()), required_blobs)
        << "IP Layer has too few parameter blobs for its configuration.";
    CHECK_EQ(this->blobs_[0]->count(), N_ * K_)
        << "IP Layer weight blob does not match num_output x input dimension.";
    if (biasterm_) {
      CHECK_EQ(this->blobs_[1]->count(), N_)
          << "IP Layer bias blob does not match num_output.";
    }
  } else {
    if (biasterm_) {
      this->blobs_.resize(2);
    } else {
      this->blobs_.resize(1);
    }
    // Initialize the weight
    this->blobs_[0].reset(new Blob<Dtype>(1, 1, N_, K_));
    // fill the weights
    shared_ptr<Filler<Dtype> > weight_filler(
        GetFiller<Dtype>(this->layer_param_.weight_filler()));
    weight_filler->Fill(this->blobs_[0].get());
    // If necessary, initialize and fill the bias term
    if (biasterm_) {
      this->blobs_[1].reset(new Blob<Dtype>(1, 1, 1, N_));
      shared_ptr<Filler<Dtype> > bias_filler(
          GetFiller<Dtype>(this->layer_param_.bias_filler()));
      bias_filler->Fill(this->blobs_[1].get());
    }
  }  // parameter initialization
  // Setting up the bias multiplier: a buffer of M_ ones.
  if (biasterm_) {
    bias_multiplier_.reset(new SyncedMemory(M_ * sizeof(Dtype)));
    Dtype* bias_multiplier_data =
        reinterpret_cast<Dtype*>(bias_multiplier_->mutable_cpu_data());
    for (int i = 0; i < M_; ++i) {
      bias_multiplier_data[i] = 1.;
    }
  }
}
// Forward pass using the blobs' cpu_* accessors and caffe_cpu_gemm.
//
// Assuming caffe_cpu_gemm follows BLAS GEMM semantics, this computes
//   top = bottom * weight^T            (M_ x K_ times K_ x N_)
// and, when biasterm_ is set, adds the bias to every row via
//   top += bias_multiplier * bias      (M_ x 1 times 1 x N_).
//
// bottom: input blobs; only bottom[0] is read.
// top:    output blobs; only (*top)[0] is written.
template <typename Dtype>
void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype one = static_cast<Dtype>(1.);
  const Dtype zero = static_cast<Dtype>(0.);

  const Dtype* input = bottom[0]->cpu_data();
  Dtype* output = (*top)[0]->mutable_cpu_data();
  const Dtype* weights = this->blobs_[0]->cpu_data();

  // Overwrite the output with the product of the input and the weights.
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, one,
      input, weights, zero, output);

  if (!biasterm_) {
    return;
  }

  // Accumulate the bias onto the output (beta == 1 keeps the product above).
  const Dtype* ones =
      reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data());
  const Dtype* bias = this->blobs_[1]->cpu_data();
  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, one,
      ones, bias, one, output);
}
// Backward pass using the blobs' cpu_* accessors and caffe_cpu_gemm/gemv.
//
// Assuming the caffe_cpu_* helpers follow BLAS semantics, this writes
//   weight diff = top_diff^T * bottom_data       (N_ x M_ times M_ x K_)
//   bias diff   = top_diff^T * bias_multiplier   (only when biasterm_ is set)
//   bottom diff = top_diff * weight              (only when propagate_down)
// Every result overwrites its destination (beta == 0).
//
// top:            output blobs; top[0]'s diff is read.
// propagate_down: whether to compute the diff of (*bottom)[0].
// bottom:         input blobs; (*bottom)[0]'s data is read and, if
//                 propagate_down, its diff is written.
// Returns Dtype(0) unconditionally.
template <typename Dtype>
Dtype InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  const Dtype one = static_cast<Dtype>(1.);
  const Dtype zero = static_cast<Dtype>(0.);

  const Dtype* output_grad = top[0]->cpu_diff();
  const Dtype* input = (*bottom)[0]->cpu_data();

  // Weight gradient.
  caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, one,
      output_grad, input, zero, this->blobs_[0]->mutable_cpu_diff());

  // Bias gradient.
  if (biasterm_) {
    const Dtype* ones =
        reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data());
    caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, one, output_grad, ones, zero,
        this->blobs_[1]->mutable_cpu_diff());
  }

  // Gradient with respect to the input, only when requested.
  if (propagate_down) {
    const Dtype* weights = this->blobs_[0]->cpu_data();
    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, one,
        output_grad, weights, zero, (*bottom)[0]->mutable_cpu_diff());
  }

  return Dtype(0);
}
// Forward pass using the blobs' gpu_* accessors and caffe_gpu_gemm.
//
// Assuming caffe_gpu_gemm follows BLAS GEMM semantics, this computes
//   top = bottom * weight^T            (M_ x K_ times K_ x N_)
// and, when biasterm_ is set, adds the bias to every row via
//   top += bias_multiplier * bias      (M_ x 1 times 1 x N_).
//
// bottom: input blobs; only bottom[0] is read.
// top:    output blobs; only (*top)[0] is written.
template <typename Dtype>
void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype one = static_cast<Dtype>(1.);
  const Dtype zero = static_cast<Dtype>(0.);

  const Dtype* input = bottom[0]->gpu_data();
  Dtype* output = (*top)[0]->mutable_gpu_data();
  const Dtype* weights = this->blobs_[0]->gpu_data();

  // Overwrite the output with the product of the input and the weights.
  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, N_, K_, one,
      input, weights, zero, output);

  if (!biasterm_) {
    return;
  }

  // Accumulate the bias onto the output (beta == 1 keeps the product above).
  const Dtype* ones =
      reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data());
  const Dtype* bias = this->blobs_[1]->gpu_data();
  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, one,
      ones, bias, one, output);
}
// Backward pass using the blobs' gpu_* accessors and caffe_gpu_gemm/gemv.
//
// Assuming the caffe_gpu_* helpers follow BLAS semantics, this writes
//   weight diff = top_diff^T * bottom_data       (N_ x M_ times M_ x K_)
//   bias diff   = top_diff^T * bias_multiplier   (only when biasterm_ is set)
//   bottom diff = top_diff * weight              (only when propagate_down)
// Every result overwrites its destination (beta == 0).
//
// top:            output blobs; top[0]'s diff is read.
// propagate_down: whether to compute the diff of (*bottom)[0].
// bottom:         input blobs; (*bottom)[0]'s data is read and, if
//                 propagate_down, its diff is written.
// Returns Dtype(0) unconditionally.
template <typename Dtype>
Dtype InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  const Dtype one = static_cast<Dtype>(1.);
  const Dtype zero = static_cast<Dtype>(0.);

  const Dtype* output_grad = top[0]->gpu_diff();
  const Dtype* input = (*bottom)[0]->gpu_data();

  // Weight gradient.
  caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, N_, K_, M_, one,
      output_grad, input, zero, this->blobs_[0]->mutable_gpu_diff());

  // Bias gradient.
  if (biasterm_) {
    const Dtype* ones =
        reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data());
    caffe_gpu_gemv<Dtype>(CblasTrans, M_, N_, one, output_grad, ones, zero,
        this->blobs_[1]->mutable_gpu_diff());
  }

  // Gradient with respect to the input, only when requested.
  if (propagate_down) {
    const Dtype* weights = this->blobs_[0]->gpu_data();
    caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, one,
        output_grad, weights, zero, (*bottom)[0]->mutable_gpu_diff());
  }

  return Dtype(0);
}
// INSTANTIATE_CLASS is a macro defined outside this file; presumably it emits
// explicit template instantiations of InnerProductLayer for the supported
// Dtype types so the member definitions above are compiled into this
// translation unit -- TODO confirm against the macro's definition.
INSTANTIATE_CLASS(InnerProductLayer);
} // namespace caffe