embed_layer.cpp
#include <vector>

#include "caffe/filler.hpp"
#include "caffe/layers/embed_layer.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {
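
// EmbedLayer looks up a learned embedding: each value in the bottom blob is
// treated as an integer index into a K_ x N_ weight table (input_dim rows,
// num_output columns), and the forward pass copies out the selected rows.
// This is equivalent to an InnerProductLayer applied to a one-hot encoding of
// the input, without ever materializing the one-hot vectors.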
template <typename Dtype>
void EmbedLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  N_ = this->layer_param_.embed_param().num_output();
  CHECK_GT(N_, 0) << "EmbedLayer num_output must be positive.";
  K_ = this->layer_param_.embed_param().input_dim();
  CHECK_GT(K_, 0) << "EmbedLayer input_dim must be positive.";
  bias_term_ = this->layer_param_.embed_param().bias_term();
  // Check if we need to set up the weights
  if (this->blobs_.size() > 0) {
    LOG(INFO) << "Skipping parameter initialization";
  } else {
    if (bias_term_) {
      this->blobs_.resize(2);
    } else {
      this->blobs_.resize(1);
    }
    // Initialize the weights --
    // transposed from InnerProductLayer for spatial locality.
    vector<int> weight_shape(2);
    weight_shape[0] = K_;
    weight_shape[1] = N_;
    this->blobs_[0].reset(new Blob<Dtype>(weight_shape));
    // fill the weights
    shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
        this->layer_param_.embed_param().weight_filler()));
    weight_filler->Fill(this->blobs_[0].get());
    // If necessary, initialize and fill the bias term
    if (bias_term_) {
      vector<int> bias_shape(1, N_);
      this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
      shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
          this->layer_param_.embed_param().bias_filler()));
      bias_filler->Fill(this->blobs_[1].get());
    }
  }  // parameter initialization
  this->param_propagate_down_.resize(this->blobs_.size(), true);
}

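// Reshape flattens the bottom blob: M_ is the total number of indices, and
// the top blob gets the bottom's shape with an extra trailing axis of size
// N_. The bias multiplier is a vector of M_ ones, used to broadcast the bias
// across all output rows with a single GEMM in Forward_cpu.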
template <typename Dtype>
void EmbedLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Figure out the dimensions
  M_ = bottom[0]->count();
  vector<int> top_shape = bottom[0]->shape();
  top_shape.push_back(N_);
  top[0]->Reshape(top_shape);
  // Set up the bias multiplier
  if (bias_term_) {
    vector<int> bias_shape(1, M_);
    bias_multiplier_.Reshape(bias_shape);
    caffe_set(M_, Dtype(1), bias_multiplier_.mutable_cpu_data());
  }
}

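// The forward pass is a row gather: for each of the M_ indices, copy the
// corresponding N_-wide row of the weight table into the top blob. The
// DCHECKs (active only in debug builds) verify that each input value is an
// integer in [0, K_). With a bias, a rank-1 GEMM (M_x1 ones times 1xN_ bias)
// adds the same bias vector to every output row.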
template <typename Dtype>
void EmbedLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  const Dtype* weight = this->blobs_[0]->cpu_data();
  Dtype* top_data = top[0]->mutable_cpu_data();
  int index;
  for (int n = 0; n < M_; ++n) {
    index = static_cast<int>(bottom_data[n]);
    DCHECK_GE(index, 0);
    DCHECK_LT(index, K_);
    DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
        << "non-integer input";
    caffe_copy(N_, weight + index * N_, top_data + n * N_);
  }
  if (bias_term_) {
    const Dtype* bias = this->blobs_[1]->cpu_data();
    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, 1, Dtype(1),
        bias_multiplier_.cpu_data(), bias, Dtype(1), top_data);
  }
}

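// The backward pass mirrors the gather with a scatter-add: each row of
// top_diff is accumulated (axpy) into the weight-diff row selected by its
// index. Indices are discrete, so no gradient can flow to the bottom blob.
// The bias gradient is top_diff^T times the ones vector, i.e. a column sum
// computed with a single GEMV.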
template <typename Dtype>
void EmbedLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  CHECK(!propagate_down[0]) << "Can't backpropagate to EmbedLayer input.";
  if (this->param_propagate_down_[0]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const Dtype* bottom_data = bottom[0]->cpu_data();
    // Gradient with respect to weight
    Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
    int index;
    for (int n = 0; n < M_; ++n) {
      index = static_cast<int>(bottom_data[n]);
      DCHECK_GE(index, 0);
      DCHECK_LT(index, K_);
      DCHECK_EQ(static_cast<Dtype>(index), bottom_data[n])
          << "non-integer input";
      caffe_axpy(N_, Dtype(1), top_diff + n * N_, weight_diff + index * N_);
    }
  }
  if (bias_term_ && this->param_propagate_down_[1]) {
    const Dtype* top_diff = top[0]->cpu_diff();
    Dtype* bias_diff = this->blobs_[1]->mutable_cpu_diff();
    caffe_cpu_gemv<Dtype>(CblasTrans, M_, N_, Dtype(1), top_diff,
        bias_multiplier_.cpu_data(), Dtype(1), bias_diff);
  }
}

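// In CPU-only builds, STUB_GPU generates Forward_gpu/Backward_gpu stubs that
// abort if called, so the layer still links without the .cu implementation.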
#ifdef CPU_ONLY
STUB_GPU(EmbedLayer);
#endif

INSTANTIATE_CLASS(EmbedLayer);
REGISTER_LAYER_CLASS(Embed);

}  // namespace caffe
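
// A minimal usage sketch (not part of this file; the layer and blob names
// below are made up, but the embed_param fields match those read in
// LayerSetUp above, and the "Embed" type matches REGISTER_LAYER_CLASS):
//
//   layer {
//     name: "embedding"
//     type: "Embed"
//     bottom: "indices"   # blob of integer-valued indices in [0, input_dim)
//     top: "embedded"     # bottom shape with a trailing axis of num_output
//     embed_param {
//       input_dim: 10000  # K_: vocabulary size (rows of the weight table)
//       num_output: 128   # N_: dimension of each embedding vector
//       bias_term: false
//       weight_filler { type: "uniform" min: -0.08 max: 0.08 }
//     }
//   }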