forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pow_op.h
136 lines (126 loc) · 4.57 KB
/
pow_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#ifndef CAFFE2_OPERATORS_POW_OP_H_
#define CAFFE2_OPERATORS_POW_OP_H_
#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/elementwise_ops.h"
#include "caffe2/operators/elementwise_ops_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/**
 * PowOp computes an elementwise power, C = pow(A, exponent-or-B).
 *
 * Two input forms are accepted (anything else throws at construction):
 *  - One input A plus an "exponent" argument: each element of A is raised to
 *    that scalar float exponent.
 *  - Two inputs A and B: A is raised elementwise to B. With broadcast=1, B may
 *    be a single element, or is broadcast onto A using the legacy broadcast
 *    rules of elementwise_ops_utils::ComputeLegacyBroadcastSizes, anchored at
 *    "axis" or at the position of the single-letter "axis_str" within "order"
 *    (default "NCHW"). "axis" and "axis_str" are mutually exclusive and are
 *    only allowed when broadcast=1.
 *
 * The output always takes A's shape; its element type is chosen by TypeMap.
 * The arithmetic itself is delegated to Functor.
 */
template <
    typename InputTypes,
    class Context,
    class Functor,
    class TypeMap = SameTypeAsInput>
class PowOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit PowOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_() {
    if ((InputSize() == 1) && HasArgument("exponent")) {
      // Unary form: a single tensor raised to a scalar exponent.
      exponent_ = this->template GetSingleArgument<float>("exponent", 0);
    } else if (InputSize() == 2) {
      // Binary form: resolve which axis of A the tensor B aligns with.
      if (enable_broadcast_) {
        if (axis_ != -1) {
          // An explicit numeric axis was given.
          CAFFE_ENFORCE_EQ(
              axis_str_.size(),
              0,
              "Args axis and axis_str cannot be used simultaneously.");
        } else if (axis_str_.size()) {
          // Resolve the axis semantically, e.g. "C" within order "NCHW".
          CAFFE_ENFORCE_EQ(
              axis_str_.size(), 1, "Unsupported axis string ", axis_str_);
          const size_t semantic_axis = order_.find(axis_str_);
          CAFFE_ENFORCE_NE(
              semantic_axis,
              string::npos,
              "Unrecognizable axis string ",
              axis_str_,
              " from order string ",
              order_);
          axis_ = static_cast<int>(semantic_axis);
        }
      } else {
        CAFFE_ENFORCE(
            axis_ == -1 && axis_str_.size() == 0,
            "Do not specify axis or axis_str if broadcast is not enabled.");
      }
    } else {
      CAFFE_THROW(
          "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
    }
  }

  /// Dispatches DoRunWithType on the element type of input 0.
  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }

  /**
   * Runs the power computation for element type T.
   * @return true on success; failures are reported by throwing.
   */
  template <typename T>
  bool DoRunWithType() {
    using OutT = typename TypeMap::template type<T>;
    if ((InputSize() == 1) && HasArgument("exponent")) {
      // Unary form: C = pow(A, exponent_).
      const auto& A = Input(0);
      auto* C = Output(0, A.sizes(), at::dtype<OutT>());
      const T* Adata = A.template data<T>();
      auto* Cdata = C->template mutable_data<OutT>();
      // B is passed as null and the scalar goes in the exponent slot;
      // presumably Functor keys off the null B pointer -- confirm in Functor.
      functor_.template Run<true, T, float, T>(
          A.numel(), Adata, nullptr, exponent_, Cdata, &context_);
    } else if (InputSize() == 2) {
      // Binary form: C = pow(A, B), optionally with B broadcast onto A.
      const auto& A = Input(0);
      const auto& B = Input(1);
      CAFFE_ENFORCE(
          !IsInputOutputAlias(1, 0) || !enable_broadcast_,
          "In-place is allowed only with the first tensor when broadcasting");
      if (!enable_broadcast_) {
        // Validate shapes BEFORE requesting the output. When running in place
        // on input 1 (permitted without broadcast), output 0 is B's tensor,
        // and Output(0, A.sizes(), ...) may reshape it to A's shape -- after
        // which comparing A against B would no longer detect a mismatch.
        CAFFE_ENFORCE_EQ(
            A.sizes(),
            B.sizes(),
            "Dimension mismatch - did you forget to set broadcast=1?");
      }
      auto* C = Output(0, A.sizes(), at::dtype<OutT>());
      const T* Adata = A.template data<T>();
      const T* Bdata = B.template data<T>();
      auto* Cdata = C->template mutable_data<OutT>();
      if (!enable_broadcast_) {
        // Same-shape elementwise power.
        functor_.template Run<false, T, T, T>(
            A.numel(), Adata, Bdata, 0, Cdata, &context_);
      } else if (B.numel() == 1) {
        // B holds a single element applied to every element of A.
        functor_.template Run<true, T, T, T>(
            A.numel(), Adata, Bdata, 0, Cdata, &context_);
      } else {
        // General legacy broadcast: A is viewed as (pre, n, post) with B
        // spanning the middle n elements.
        size_t pre, n, post;
        std::tie(pre, n, post) =
            elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
        if (post == 1) {
          functor_.template RunWithBroadcast<T, T, T>(
              Adata, Bdata, Cdata, pre, n, &context_);
        } else {
          functor_.template RunWithBroadcast2<T, T, T>(
              Adata, Bdata, Cdata, pre, n, post, &context_);
        }
      }
    } else {
      CAFFE_THROW(
          "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
    }
    return true;
  }

 private:
  bool enable_broadcast_;  // "broadcast" argument
  int axis_;               // broadcast axis; -1 lets the helper pick it
  string axis_str_;        // single-letter axis name resolved via order_
  string order_;           // layout string used to resolve axis_str_
  float exponent_ = 0.0f;  // scalar exponent; only meaningful in unary form
  Functor functor_;        // performs the actual elementwise math
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_POW_OP_H_