padding_layer.cu
// Copyright 2013 Yangqing Jia

#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"

#include <iostream>

namespace caffe {

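// PaddingLayer zero-pads a single input blob along its spatial dimensions:
// a (num, channels, height, width) input becomes a
// (num, channels, height + 2*pad, width + 2*pad) output whose border pixels
// are zero. The pad size is read from the layer parameter's pad field.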
template <typename Dtype>
void PaddingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  PAD_ = this->layer_param_.pad();
  CHECK_EQ(bottom.size(), 1) << "Padding Layer takes a single blob as input.";
  CHECK_EQ(top->size(), 1) << "Padding Layer takes a single blob as output.";
  NUM_ = bottom[0]->num();
  CHANNEL_ = bottom[0]->channels();
  HEIGHT_IN_ = bottom[0]->height();
  WIDTH_IN_ = bottom[0]->width();
  HEIGHT_OUT_ = HEIGHT_IN_ + PAD_ * 2;
  WIDTH_OUT_ = WIDTH_IN_ + PAD_ * 2;
  (*top)[0]->Reshape(NUM_, CHANNEL_, HEIGHT_OUT_, WIDTH_OUT_);
}

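// CPU forward pass: zero the whole output, then copy each input row into the
// interior of the output, shifted by PAD_ in both height and width.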
template <typename Dtype>
void PaddingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  Dtype* top_data = (*top)[0]->mutable_cpu_data();
  const Dtype* bottom_data = bottom[0]->cpu_data();
  memset(top_data, 0, sizeof(Dtype) * (*top)[0]->count());
  // In short, top[n, c, h, w] = bottom[n, c, h-pad, w-pad] if in range
  for (int n = 0; n < NUM_; ++n) {
    for (int c = 0; c < CHANNEL_; ++c) {
      for (int h = 0; h < HEIGHT_IN_; ++h) {
        // copy the width part
        memcpy(
            top_data + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
                * WIDTH_OUT_ + PAD_,
            bottom_data + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
            sizeof(Dtype) * WIDTH_IN_);
      }
    }
  }
}

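// CPU backward pass: copy the interior rows of top_diff back into
// bottom_diff; the gradient w.r.t. the zero-padded border is simply dropped.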
template <typename Dtype>
Dtype PaddingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  const Dtype* top_diff = top[0]->cpu_diff();
  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
  // No need to zero bottom_diff: every element is overwritten below.
  for (int n = 0; n < NUM_; ++n) {
    for (int c = 0; c < CHANNEL_; ++c) {
      for (int h = 0; h < HEIGHT_IN_; ++h) {
        // copy the width part
        memcpy(
            bottom_diff + ((n * CHANNEL_ + c) * HEIGHT_IN_ + h) * WIDTH_IN_,
            top_diff + ((n * CHANNEL_ + c) * HEIGHT_OUT_ + h + PAD_)
                * WIDTH_OUT_ + PAD_,
            sizeof(Dtype) * WIDTH_IN_);
      }
    }
  }
  return Dtype(0.);
}

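// GPU forward kernel: one thread per input element. The linear index is
// decoded into (n, c, h, w) over the input shape, and the value is written
// to the shifted location (h + pad, w + pad) in the padded output.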
template <typename Dtype>
__global__ void PaddingForward(const int count, const Dtype* in, Dtype* out,
    const int num, const int channel, const int height_in, const int width_in,
    const int pad) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < count) {
    int height_out = height_in + pad + pad;
    int width_out = width_in + pad + pad;
    int w = index % width_in;
    index /= width_in;
    int h = index % height_in;
    index /= height_in;
    int c = index % channel;
    index /= channel;
    out[((index * channel + c) * height_out + h + pad) * width_out + pad + w] =
        in[((index * channel + c) * height_in + h) * width_in + w];
  }
}

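// GPU forward pass: zero the output buffer (so the border stays zero), then
// launch one thread per input element to fill in the interior.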
template <typename Dtype>
void PaddingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  const int count = bottom[0]->count();
  // First, set all data to be zero for the boundary pixels
  CUDA_CHECK(cudaMemset(top_data, 0, sizeof(Dtype) * (*top)[0]->count()));
  PaddingForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, bottom_data, top_data, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
      PAD_);
  CUDA_POST_KERNEL_CHECK;
}

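// GPU backward kernel: the inverse of PaddingForward. One thread per bottom
// element reads the corresponding interior element of the padded top diff.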
template <typename Dtype>
__global__ void PaddingBackward(const int count, const Dtype* in, Dtype* out,
    const int num, const int channel, const int height_in, const int width_in,
    const int pad) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < count) {
    int height_out = height_in + pad + pad;
    int width_out = width_in + pad + pad;
    int w = index % width_in;
    index /= width_in;
    int h = index % height_in;
    index /= height_in;
    int c = index % channel;
    index /= channel;
    out[((index * channel + c) * height_in + h) * width_in + w] =
        in[((index * channel + c) * height_out + h + pad) * width_out + pad + w];
  }
}

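// GPU backward pass: only copies diffs when propagate_down is set; the
// gradient for the border is discarded along with the padding.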
template <typename Dtype>
Dtype PaddingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down,
    vector<Blob<Dtype>*>* bottom) {
  if (propagate_down) {
    const Dtype* top_diff = top[0]->gpu_diff();
    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
    const int count = (*bottom)[0]->count();
    PaddingBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
        count, top_diff, bottom_diff, NUM_, CHANNEL_, HEIGHT_IN_, WIDTH_IN_,
        PAD_);
    CUDA_POST_KERNEL_CHECK;
  }
  return Dtype(0);
}

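// Instantiate the layer for float and double.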
INSTANTIATE_CLASS(PaddingLayer);

}  // namespace caffe