-
Notifications
You must be signed in to change notification settings - Fork 18.7k
/
pooling_layer.cpp
198 lines (187 loc) · 6.7 KB
/
pooling_layer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
// Copyright 2013 Yangqing Jia
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <limits>
#include <vector>
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"
using std::max;
using std::min;
namespace caffe {

namespace {

// Returns the number of pooling windows along one spatial axis of length
// `input_size`, for a kernel of `kernel_size` moved with step `stride`.
//
// The count is rounded up (ceil), so a partial window at the trailing edge is
// still produced; the pooling loops clip it to the input extent. When
// `stride > kernel_size`, ceil rounding can place the last window's start at
// or beyond the end of the input. That window would be empty: MAX would emit
// its sentinel and AVE would divide by a zero-sized window. Such a window is
// dropped here. When `stride <= kernel_size` the last start is always inside
// the input, so the result equals the plain ceil formula.
//
// The caller must ensure `stride > 0`. The result may be <= 0 when the kernel
// is much larger than the input; the caller is expected to reject that.
int PooledOutputSize(int input_size, int kernel_size, int stride) {
  int pooled = static_cast<int>(std::ceil(
      static_cast<float>(input_size - kernel_size) / stride)) + 1;
  if (pooled > 1 && (pooled - 1) * stride >= input_size) {
    --pooled;
  }
  return pooled;
}

}  // namespace

// Reads the kernel size and stride from the layer parameters, validates them,
// and caches the input geometry of bottom[0]. It then reshapes (*top)[0] to
// (num, channels, pooled_height, pooled_width). For STOCHASTIC pooling it
// also shapes rand_idx_ like the output.
template <typename Dtype>
void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  CHECK_EQ(bottom.size(), 1) << "PoolingLayer takes a single blob as input.";
  CHECK_EQ(top->size(), 1) << "PoolingLayer takes a single blob as output.";
  KSIZE_ = this->layer_param_.kernelsize();
  STRIDE_ = this->layer_param_.stride();
  // A zero stride would divide by zero below; a zero kernel makes every
  // window empty.
  CHECK_GT(KSIZE_, 0) << "Pooling kernel size must be positive.";
  CHECK_GT(STRIDE_, 0) << "Pooling stride must be positive.";
  CHANNELS_ = bottom[0]->channels();
  HEIGHT_ = bottom[0]->height();
  WIDTH_ = bottom[0]->width();
  POOLED_HEIGHT_ = PooledOutputSize(HEIGHT_, KSIZE_, STRIDE_);
  POOLED_WIDTH_ = PooledOutputSize(WIDTH_, KSIZE_, STRIDE_);
  // A kernel far larger than the input gives a non-positive output size,
  // which must not reach Reshape.
  CHECK_GT(POOLED_HEIGHT_, 0) << "Pooling kernel is too large for input height.";
  CHECK_GT(POOLED_WIDTH_, 0) << "Pooling kernel is too large for input width.";
  (*top)[0]->Reshape(bottom[0]->num(), CHANNELS_, POOLED_HEIGHT_,
      POOLED_WIDTH_);
  // If stochastic pooling, we will initialize the random index part.
  if (this->layer_param_.pool() == LayerParameter_PoolMethod_STOCHASTIC) {
    rand_idx_.Reshape(bottom[0]->num(), CHANNELS_, POOLED_HEIGHT_,
        POOLED_WIDTH_);
  }
}

// TODO(Yangqing): Is there a faster way to do pooling in the channel-first
// case?
//
// CPU forward pass. Each (n, c) plane of bottom[0] is processed in turn.
// Output cell (ph, pw) summarizes the input window
//   rows [ph * STRIDE_, ph * STRIDE_ + KSIZE_) x
//   cols [pw * STRIDE_, pw * STRIDE_ + KSIZE_),
// clipped to the input extent. MAX takes the largest element and AVE takes
// the mean over the clipped window. STOCHASTIC is not implemented on the CPU.
template <typename Dtype>
void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->cpu_data();
  Dtype* top_data = (*top)[0]->mutable_cpu_data();
  // Element distance between consecutive (n, c) planes of each blob.
  const int bottom_step = bottom[0]->offset(0, 1);
  const int top_step = (*top)[0]->offset(0, 1);
  const int num = bottom[0]->num();
  // The switch is kept outside the loops so the inner loops carry no
  // per-element branch on the pooling method, at the cost of repeated code.
  switch (this->layer_param_.pool()) {
  case LayerParameter_PoolMethod_MAX:
    for (int n = 0; n < num; ++n) {
      for (int c = 0; c < CHANNELS_; ++c) {
        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
          const int hstart = ph * STRIDE_;
          const int hend = min(hstart + KSIZE_, HEIGHT_);
          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
            const int wstart = pw * STRIDE_;
            const int wend = min(wstart + KSIZE_, WIDTH_);
            // Seed with the most negative finite Dtype. -FLT_MAX is only the
            // float limit and is far from the lowest double.
            Dtype max_val = -std::numeric_limits<Dtype>::max();
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                max_val = max(max_val, bottom_data[h * WIDTH_ + w]);
              }
            }
            top_data[ph * POOLED_WIDTH_ + pw] = max_val;
          }
        }
        bottom_data += bottom_step;
        top_data += top_step;
      }
    }
    break;
  case LayerParameter_PoolMethod_AVE:
    for (int n = 0; n < num; ++n) {
      for (int c = 0; c < CHANNELS_; ++c) {
        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
          const int hstart = ph * STRIDE_;
          const int hend = min(hstart + KSIZE_, HEIGHT_);
          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
            const int wstart = pw * STRIDE_;
            const int wend = min(wstart + KSIZE_, WIDTH_);
            Dtype sum = Dtype(0);
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                sum += bottom_data[h * WIDTH_ + w];
              }
            }
            // The divisor is the clipped window area, so edge windows
            // average only the elements they actually cover. An empty window
            // yields 0 rather than dividing by a zero or negative size.
            top_data[ph * POOLED_WIDTH_ + pw] =
                (hstart < hend && wstart < wend)
                    ? sum / ((hend - hstart) * (wend - wstart))
                    : Dtype(0);
          }
        }
        bottom_data += bottom_step;
        top_data += top_step;
      }
    }
    break;
  case LayerParameter_PoolMethod_STOCHASTIC:
    NOT_IMPLEMENTED;
    break;
  default:
    LOG(FATAL) << "Unknown pooling method.";
  }
}

// CPU backward pass. Scatters top[0]'s diff onto (*bottom)[0]'s diff using
// the same window geometry as Forward_cpu. It does nothing when
// `propagate_down` is false. Always returns 0.
template <typename Dtype>
Dtype PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  if (!propagate_down) {
    return Dtype(0.);
  }
  const Dtype* top_diff = top[0]->cpu_diff();
  const Dtype* top_data = top[0]->cpu_data();
  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
  // Element distance between consecutive (n, c) planes of each blob.
  const int bottom_step = (*bottom)[0]->offset(0, 1);
  const int top_step = top[0]->offset(0, 1);
  const int num = top[0]->num();
  // Windows may overlap, so gradients are accumulated into a zeroed buffer.
  std::fill_n(bottom_diff, (*bottom)[0]->count(), Dtype(0));
  // The switch is kept outside the loops so the inner loops carry no
  // per-element branch on the pooling method, at the cost of repeated code.
  switch (this->layer_param_.pool()) {
  case LayerParameter_PoolMethod_MAX:
    for (int n = 0; n < num; ++n) {
      for (int c = 0; c < CHANNELS_; ++c) {
        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
          const int hstart = ph * STRIDE_;
          const int hend = min(hstart + KSIZE_, HEIGHT_);
          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
            const int wstart = pw * STRIDE_;
            const int wend = min(wstart + KSIZE_, WIDTH_);
            const int top_index = ph * POOLED_WIDTH_ + pw;
            const Dtype pooled_max = top_data[top_index];
            const Dtype grad = top_diff[top_index];
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                // Route the gradient to inputs equal to the pooled max.
                // NOTE(review): tied maxima each receive the full gradient
                // (not a share of it). This is kept as-is for consistency
                // with the existing behavior.
                bottom_diff[h * WIDTH_ + w] +=
                    grad * (bottom_data[h * WIDTH_ + w] == pooled_max);
              }
            }
          }
        }
        bottom_data += bottom_step;
        top_data += top_step;
        bottom_diff += bottom_step;
        top_diff += top_step;
      }
    }
    break;
  case LayerParameter_PoolMethod_AVE:
    // Only the diffs are needed here; the data blobs are not read.
    for (int n = 0; n < num; ++n) {
      for (int c = 0; c < CHANNELS_; ++c) {
        for (int ph = 0; ph < POOLED_HEIGHT_; ++ph) {
          const int hstart = ph * STRIDE_;
          const int hend = min(hstart + KSIZE_, HEIGHT_);
          for (int pw = 0; pw < POOLED_WIDTH_; ++pw) {
            const int wstart = pw * STRIDE_;
            const int wend = min(wstart + KSIZE_, WIDTH_);
            if (hstart >= hend || wstart >= wend) {
              continue;  // Empty window: nothing to scatter, avoid 0-divide.
            }
            const int poolsize = (hend - hstart) * (wend - wstart);
            // Each covered input receives an equal share of the gradient.
            const Dtype grad = top_diff[ph * POOLED_WIDTH_ + pw] / poolsize;
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                bottom_diff[h * WIDTH_ + w] += grad;
              }
            }
          }
        }
        bottom_diff += bottom_step;
        top_diff += top_step;
      }
    }
    break;
  case LayerParameter_PoolMethod_STOCHASTIC:
    NOT_IMPLEMENTED;
    break;
  default:
    LOG(FATAL) << "Unknown pooling method.";
  }
  return Dtype(0.);
}

INSTANTIATE_CLASS(PoolingLayer);

}  // namespace caffe