forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tt_linear_op.h
195 lines (169 loc) · 6.3 KB
/
tt_linear_op.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#ifndef CAFFE2_OPERATORS_TT_LINEAR_OP_H_
#define CAFFE2_OPERATORS_TT_LINEAR_OP_H_
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#endif // CAFFE2_USE_MKL
#include "Eigen/Core"
#include "Eigen/Dense"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// Tensor-Train (TT) decomposed fully-connected layer.
//
// The weight matrix of a linear layer is represented as a train of d small
// cores instead of one dense matrix. Inputs:
//   Input(0): X      - input activations, first dim is the batch size
//   Input(1): b      - 1-D bias vector
//   Input(2): cores  - all TT-cores flattened into a single 1-D array
// Operator arguments:
//   inp_sizes - factorization of the input feature dimension   (length d)
//   out_sizes - factorization of the output feature dimension  (length d)
//   tt_ranks  - TT-ranks; must have length d + 1 (core i has shape
//               inp_sizes[i] * tt_ranks[i+1]  x  tt_ranks[i] * out_sizes[i])
// Output(0): Y = TT_matrix(cores) applied to X, plus bias b.
template <typename T, class Context, class Engine = DefaultEngine>
class TTLinearOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit TTLinearOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        inp_sizes_(this->template GetRepeatedArgument<int>("inp_sizes")),
        out_sizes_(this->template GetRepeatedArgument<int>("out_sizes")),
        tt_ranks_(this->template GetRepeatedArgument<int>("tt_ranks")),
        Y_temp_(unique_ptr<Blob>(new Blob())) {}
  ~TTLinearOp() {}

  bool RunOnDevice() override {
    const auto& X = Input(0); // Input array
    const auto& b = Input(1); // Bias array
    const auto& cores = Input(2); // 1D array containing the TT-cores
    CAFFE_ENFORCE(X.dim() > 1, "Number of dimensions in X: ", X.dim());
    CAFFE_ENFORCE(b.dim() == 1, "Number of dimensions in b: ", b.dim());
    CAFFE_ENFORCE(
        inp_sizes_.size() == out_sizes_.size(),
        "inp_sizes has size: ",
        inp_sizes_.size(),
        ", out_sizes has size: ",
        out_sizes_.size());
    // The loop below reads tt_ranks_[i] and tt_ranks_[i + 1] for each of the
    // d cores, so tt_ranks_ must hold exactly d + 1 entries. An empty
    // factorization would also leave Y unset below, so require d >= 1.
    CAFFE_ENFORCE(
        !inp_sizes_.empty(), "inp_sizes must contain at least one factor");
    CAFFE_ENFORCE(
        tt_ranks_.size() == inp_sizes_.size() + 1,
        "tt_ranks has size: ",
        tt_ranks_.size(),
        ", expected: ",
        inp_sizes_.size() + 1);
    CAFFE_ENFORCE(
        cores.dim() == 1, "Number of dimensions in cores: ", cores.dim());
    // batch size
    const int batch_size = X.dim() > 1 ? X.dim32(0) : 1;
    // dimension d of tensors
    const int d = inp_sizes_.size();
    // Keep track of index of current core in multiplication
    int cores_idx = 0;
    // Temporary buffer to facilitate multiplication of TT-cores with input
    auto Y_buf = BlobGetMutableTensor(Y_temp_.get(), Context::GetDeviceType());
    Y_buf->ResizeLike(X);
    Y_buf->CopyFrom(X);
    // Pointer to the output tensor; assigned on every loop iteration. The
    // d >= 1 enforcement above guarantees it is set before the final use.
    Tensor* Y = nullptr;
    // The overall forward pass involves multiplication with each core, where
    // each core has sizes dictated by inp_sizes_ and out_sizes_. Each core thus
    // has size inp_sizes_[i] * tt_ranks_[i] * tt_ranks_[i + 1] * out_sizes_[i].
    for (int i = (d - 1); i >= 0; --i) {
      int curr_rows = inp_sizes_[i] * tt_ranks_[i + 1];
      int curr_cols = tt_ranks_[i] * out_sizes_[i];
      // TODO Replace by Reshape(), once wrappers are written
      Y_buf->Resize(Y_buf->numel() / curr_rows, curr_rows);
      Y = Output(
          0, {Y_buf->numel() / curr_rows, curr_cols}, at::dtype<float>());
      // Defensive checks
      CAFFE_ENFORCE(Y_buf->numel() % curr_rows == 0, Y_buf->numel(), curr_rows);
      CAFFE_ENFORCE(
          cores_idx + curr_rows * curr_cols <= cores.numel(),
          cores_idx + curr_rows * curr_cols,
          cores.numel());
      // Multiply ith core with the intermediate output
      math::Gemm<float, Context, Engine>(
          CblasNoTrans,
          CblasNoTrans,
          Y_buf->numel() / curr_rows,
          curr_cols,
          curr_rows,
          1,
          Y_buf->template data<float>(),
          cores.template data<float>() + cores_idx,
          0,
          Y->template mutable_data<float>(),
          &context_);
      CAFFE_ENFORCE(Y->numel() % out_sizes_[i] == 0, Y->numel(), out_sizes_[i]);
      // Transpose so that the next core multiplies against the correct
      // layout; done on CPU via Eigen.
      // TODO Add GPU support by writing a generic wrapper.
      auto Y_mat = EigenMatrixMap<float>(
          Y->template mutable_data<float>(),
          Y->numel() / out_sizes_[i],
          out_sizes_[i]);
      Y_mat = ConstEigenMatrixMap<float>(
          Y->template data<float>(),
          out_sizes_[i],
          Y->numel() / out_sizes_[i])
          .transpose()
          .eval();
      // Resize operation
      Y_buf->Resize(Y->dim32(0), Y->dim32(1));
      context_.template CopyFromCPU<float>(
          Y->numel(),
          Y->template data<float>(),
          Y_buf->template mutable_data<float>());
      cores_idx += curr_rows * curr_cols;
    }
    // Final transpose back to (batch_size, features) layout.
    // TODO Add GPU support by writing a generic wrapper.
    auto Y_mat = EigenMatrixMap<float>(
        Y->template mutable_data<float>(), batch_size, Y->numel() / batch_size);
    Y_mat = ConstEigenMatrixMap<float>(
        Y->template data<float>(), Y->numel() / batch_size, batch_size)
        .transpose()
        .eval();
    // TODO Replace by Reshape(), once wrappers are written
    Y = Output(0, {batch_size, Y->numel() / batch_size}, at::dtype<float>());
    // Check that output size of Y is the element-wise product of out_sizes
    int prod_out_sizes = 1;
    for (int i = 0; i < out_sizes_.size(); i++) {
      prod_out_sizes *= out_sizes_[i];
    }
    CAFFE_ENFORCE(
        Y->dim32(1) == prod_out_sizes,
        "Output dimension of Y: ",
        Y->dim32(1),
        ", product of out_sizes: ",
        prod_out_sizes);
    // Add bias term
    if (bias_multiplier_.numel() != batch_size) {
      // If the helper bias multiplier is not M, reshape and fill it with one.
      ReinitializeTensor(
          &bias_multiplier_,
          {batch_size},
          at::dtype<T>().device(Context::GetDeviceType()));
      math::Set<T, Context>(
          batch_size,
          static_cast<T>(1),
          bias_multiplier_.template mutable_data<T>(),
          &context_);
    }
    // Y += bias_multiplier (batch_size x 1) * b (1 x out_features)
    math::Gemm<T, Context, Engine>(
        CblasNoTrans,
        CblasNoTrans,
        Y->dim32(0),
        Y->dim32(1),
        1,
        1,
        bias_multiplier_.template data<T>(),
        b.template data<T>(),
        1,
        Y->template mutable_data<T>(),
        &context_);
    return true;
  }

 protected:
  Tensor bias_multiplier_;       // cached (batch_size) vector of ones for bias Gemm
  std::vector<int> inp_sizes_;   // input-dimension factorization, length d
  std::vector<int> out_sizes_;   // output-dimension factorization, length d
  std::vector<int> tt_ranks_;    // TT-ranks, length d + 1
  std::unique_ptr<Blob> Y_temp_; // scratch blob reused across calls
};
// Backward pass for TTLinearOp.
// TODO: Complete after verifying utility of TT-layer's forward pass.
template <typename T, class Context, class Engine = DefaultEngine>
class TTLinearGradientOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit TTLinearGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}

  ~TTLinearGradientOp() {}

  // Not implemented: unconditionally reports failure.
  bool RunOnDevice() override {
    return false;
  }

 protected:
  // Reserved for the eventual bias-gradient reduction.
  Tensor bias_multiplier_{Context::GetDeviceType()};
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_TT_LINEAR_OP_H_