forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
filler.h
140 lines (119 loc) · 3.82 KB
/
filler.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#ifndef CAFFE2_FILLER_H_
#define CAFFE2_FILLER_H_
#include <algorithm>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <vector>

#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// TODO: replace filler distribution enum with a better abstraction
// Distribution used by TensorFiller::Fill:
//   FD_UNIFORM   - uniform random values in [min, max] (math::RandUniform)
//   FD_FIXEDSUM  - random values in [min, max] whose total equals the
//                  configured fixed sum (math::RandFixedSum)
//   FD_SYNTHETIC - values in [min, max] from math::RandSyntheticData
//                  (exact distribution defined by that helper)
enum FillerDistribution { FD_UNIFORM, FD_FIXEDSUM, FD_SYNTHETIC };
// Configurable random filler for test/benchmark tensors. Bounds and the
// fixed sum are stored as double and cast to the element type at fill time
// (see the TODO on the data members below).
class TensorFiller {
 public:
  // Resizes `tensor` to `shape_` and fills it with random values of `Type`
  // drawn from the configured distribution, using `context`'s RNG.
  // Throws (via CAFFE_ENFORCE) on null arguments or an empty/inverted range.
  template <class Type, class Context>
  void Fill(Tensor* tensor, Context* context) const {
    CAFFE_ENFORCE(context, "context is null");
    CAFFE_ENFORCE(tensor, "tensor is null");
    // Clamp the configured bounds to Type's representable range.
    // NOTE: use lowest(), not min() — for floating-point types
    // numeric_limits<Type>::min() is the smallest *positive* normal value,
    // so clamping against it would silently discard negative minimums.
    auto min = (min_ < static_cast<double>(std::numeric_limits<Type>::lowest()))
        ? std::numeric_limits<Type>::lowest()
        : static_cast<Type>(min_);
    auto max = (max_ > static_cast<double>(std::numeric_limits<Type>::max()))
        ? std::numeric_limits<Type>::max()
        : static_cast<Type>(max_);
    CAFFE_ENFORCE_LE(min, max);

    // Swap in a fresh tensor of the target shape/device, then fill its
    // newly allocated buffer in place.
    Tensor temp_tensor(shape_, Context::GetDeviceType());
    std::swap(*tensor, temp_tensor);
    Type* data = tensor->template mutable_data<Type>();

    // Dispatch on the configured distribution.
    switch (dist_) {
      case FD_UNIFORM: {
        math::RandUniform<Type, Context>(
            tensor->numel(), min, max, data, context);
        break;
      }
      case FD_FIXEDSUM: {
        auto fixed_sum = static_cast<Type>(fixed_sum_);
        // The requested sum must be achievable with every element in
        // [min, max].
        CAFFE_ENFORCE_LE(min * tensor->numel(), fixed_sum);
        CAFFE_ENFORCE_GE(max * tensor->numel(), fixed_sum);
        // Pass the validated Type-cast sum (not the raw double member) so
        // the generator targets exactly the value the checks above vetted.
        math::RandFixedSum<Type, Context>(
            tensor->numel(), min, max, fixed_sum, data, context);
        break;
      }
      case FD_SYNTHETIC: {
        math::RandSyntheticData<Type, Context>(
            tensor->numel(), min, max, data, context);
        break;
      }
    }
  }

  // Selects the distribution used by Fill(). Returns *this for chaining.
  TensorFiller& Dist(FillerDistribution dist) {
    dist_ = dist;
    return *this;
  }

  // Sets the lower bound (stored as double; cast to the element type at
  // fill time). Returns *this for chaining.
  template <class Type>
  TensorFiller& Min(Type min) {
    min_ = static_cast<double>(min);
    return *this;
  }

  // Sets the upper bound (stored as double; cast to the element type at
  // fill time). Returns *this for chaining.
  template <class Type>
  TensorFiller& Max(Type max) {
    max_ = static_cast<double>(max);
    return *this;
  }

  // Sets the target sum and switches the distribution to FD_FIXEDSUM.
  // Returns *this for chaining.
  template <class Type>
  TensorFiller& FixedSum(Type fixed_sum) {
    dist_ = FD_FIXEDSUM;
    fixed_sum_ = static_cast<double>(fixed_sum);
    return *this;
  }

  // A helper function to construct the lengths vector for sparse features.
  // We try to pad at least one index per batch unless the total_length is 0
  // (hence min is min(1, total_length)).
  template <class Type>
  TensorFiller& SparseLengths(Type total_length) {
    return FixedSum(total_length)
        .Min(std::min(static_cast<Type>(1), total_length))
        .Max(total_length);
  }

  // A helper function to construct the segments vector for sparse features:
  // synthetic values in [0, max_segment]. Must not be combined with a
  // fixed-sum configuration (enforced below).
  template <class Type>
  TensorFiller& SparseSegments(Type max_segment) {
    CAFFE_ENFORCE(dist_ != FD_FIXEDSUM);
    return Min(0).Max(max_segment).Dist(FD_SYNTHETIC);
  }

  // Sets the output shape used by Fill(). Returns *this for chaining.
  TensorFiller& Shape(const std::vector<int64_t>& shape) {
    shape_ = shape;
    return *this;
  }

  // Constructs a fixed-sum filler for the given shape and target sum.
  template <class Type>
  TensorFiller(const std::vector<int64_t>& shape, Type fixed_sum)
      : shape_(shape), dist_(FD_FIXEDSUM), fixed_sum_((double)fixed_sum) {}

  // Constructs a uniform filler over [min_, max_] (defaults [0, 1]) for the
  // given shape.
  TensorFiller(const std::vector<int64_t>& shape)
      : shape_(shape), dist_(FD_UNIFORM), fixed_sum_(0) {}

  // Default: uniform filler with an empty shape.
  TensorFiller() : TensorFiller(std::vector<int64_t>()) {}

  // Human-readable summary of the configuration (shape, bounds,
  // distribution) for logging/debugging.
  std::string DebugString() const {
    std::stringstream stream;
    stream << "shape = [" << shape_ << "]; min = " << min_
           << "; max = " << max_;
    switch (dist_) {
      case FD_FIXEDSUM:
        stream << "; dist = FD_FIXEDSUM";
        break;
      case FD_SYNTHETIC:
        stream << "; dist = FD_SYNTHETIC";
        break;
      default:
        stream << "; dist = FD_UNIFORM";
        break;
    }
    return stream.str();
  }

 private:
  std::vector<int64_t> shape_;
  // TODO: type is unknown until a user starts to fill data;
  // cast everything to double for now.
  double min_ = 0.0;
  double max_ = 1.0;
  FillerDistribution dist_;
  double fixed_sum_;
};
} // namespace caffe2
#endif // CAFFE2_FILLER_H_