This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
tensor_gpu-inl.h
264 lines (247 loc) · 10.5 KB
/
tensor_gpu-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2014 by Contributors
* \file tensor_gpu-inl.h
* \brief implementation of GPU host code
* \author Bing Xu, Tianqi Chen
*/
#ifndef MSHADOW_TENSOR_GPU_INL_H_
#define MSHADOW_TENSOR_GPU_INL_H_
#include "./base.h"
#include "./tensor.h"
namespace mshadow {
#if MSHADOW_USE_CUDA
/*!
 * \brief initialize the tensor engine for the GPU device
 *
 * Queries the number of CUDA devices, makes the requested one current with
 * cudaSetDevice and issues a cudaGetDeviceProperties query for it.
 *
 * \param dev_id index of the device to use; any negative value selects device 0
 */
template<>
inline void InitTensorEngine<gpu>(int dev_id) {
  int device_count = 0;
  // check the status of the query as well: a failing cudaGetDeviceCount would
  // otherwise only be visible indirectly through device_count staying at 0
  MSHADOW_CUDA_CALL(cudaGetDeviceCount(&device_count));
  CHECK_GT(device_count, 0) << "Cannot find CUDA device. Please check CUDA-Configuration";
  // negative ids fall back to the first device
  const int device_id = dev_id < 0 ? 0 : dev_id;
  CHECK_LT(device_id, device_count) << "Incorrect Device ID";
  MSHADOW_CUDA_CALL(cudaSetDevice(device_id));
  // the returned properties are not read afterwards; the query is kept so a
  // device that cannot be queried is still reported at this point
  cudaDeviceProp prop;
  MSHADOW_CUDA_CALL(cudaGetDeviceProperties(&prop, device_id));
}
/*!
 * \brief shut down the tensor engine for the GPU device
 *
 * The body is empty: this specialization performs no work.
 */
template<>
inline void ShutdownTensorEngine<gpu>(void) {
}
/*!
 * \brief make the given CUDA device current
 * \param devid device index, passed unchanged to cudaSetDevice
 */
template<>
inline void SetDevice<gpu>(int devid) {
// NOTE(review): MSHADOW_CUDA_CALL is defined elsewhere; presumably it reports
// a non-success CUDA status - confirm against base.h
MSHADOW_CUDA_CALL(cudaSetDevice(devid));
}
/*!
 * \brief allocate device memory for a GPU tensor and set its stride
 *
 * Two layouts are produced:
 *  - padded: when pad is true and the lowest dimension holds at least
 *    MSHADOW_MIN_PAD_RATIO * 32 elements, the memory is requested as a pitched
 *    2D block (row width = lowest dimension, row count = first entry of the
 *    FlatTo2D shape) and stride_ is derived from the pitch that
 *    cudaMallocPitch reports;
 *  - compact: otherwise all elements are requested as one single row and
 *    stride_ equals the size of the lowest dimension.
 *
 * \param obj tensor to allocate for; its shape is read, dptr_ and stride_ are written
 * \param pad whether the lowest dimension may be padded
 */
template<int dim, typename DType>
inline void AllocSpace(Tensor<gpu, dim, DType> *obj, bool pad) {
  size_t pitch_bytes;
  void **slot = reinterpret_cast<void**>(&(obj->dptr_));
  // 32 is the usual alignment unit for CUDA memory
  const bool padded = pad && obj->size(dim - 1) >= MSHADOW_MIN_PAD_RATIO * 32;
  if (!padded) {
    // compact layout: rows are stored back to back, so the pitch is not needed
    obj->stride_ = obj->size(dim - 1);
    MSHADOW_CUDA_CALL(cudaMallocPitch(slot, &pitch_bytes,
                                      obj->shape_.Size() * sizeof(DType), 1));
    return;
  }
  MSHADOW_CUDA_CALL(cudaMallocPitch(slot, &pitch_bytes,
                                    obj->size(dim - 1) * sizeof(DType),
                                    obj->shape_.FlatTo2D()[0]));
  // the pitch is reported in bytes while stride_ counts elements
  obj->stride_ = static_cast<index_t>(pitch_bytes / sizeof(DType));
}
/*!
 * \brief release the device memory held by a GPU tensor
 *
 * Passes obj->dptr_ to cudaFree and then resets the pointer to NULL.
 * \param obj tensor whose dptr_ is freed and cleared
 */
template<int dim, typename DType>
inline void FreeSpace(Tensor<gpu, dim, DType> *obj) {
MSHADOW_CUDA_CALL(cudaFree(obj->dptr_));
// clear the pointer so the tensor no longer refers to the freed memory
obj->dptr_ = NULL;
}
/*!
 * \brief generic 2D copy between two tensors through the CUDA runtime
 *
 * Both tensors are flattened to 2D and transferred with cudaMemcpy2DAsync,
 * each side using its own stride as the row pitch. When no stream is given
 * the function synchronizes stream 0 before returning.
 *
 * \param _dst destination tensor
 * \param _src source tensor, must have the same shape as _dst
 * \param kind direction of the transfer
 * \param stream stream to issue the copy on, may be NULL
 */
template<typename A, typename B, int dim, typename DType>
inline void Copy(Tensor<A, dim, DType> _dst,
                 Tensor<B, dim, DType> _src,
                 cudaMemcpyKind kind,
                 Stream<gpu> *stream) {
  CHECK_EQ(_dst.shape_, _src.shape_) << "Copy:shape mismatch";
  Tensor<A, 2, DType> to = _dst.FlatTo2D();
  Tensor<B, 2, DType> from = _src.FlatTo2D();
  // pitches and row width are passed in bytes, strides are kept in elements
  const size_t to_pitch = to.stride_ * sizeof(DType);
  const size_t from_pitch = from.stride_ * sizeof(DType);
  const size_t row_bytes = to.size(1) * sizeof(DType);
  MSHADOW_CUDA_CALL(cudaMemcpy2DAsync(to.dptr_, to_pitch,
                                      from.dptr_, from_pitch,
                                      row_bytes, to.size(0), kind,
                                      Stream<gpu>::GetStream(stream)));
  if (stream != NULL) return;
  // without an explicit stream, behave like a synchronous call
  MSHADOW_CUDA_CALL(cudaStreamSynchronize(0));
}
/*!
 * \brief copy a GPU tensor into a CPU tensor
 *
 * Forwards to the generic Copy with cudaMemcpyDeviceToHost.
 * \param dst destination tensor on the cpu
 * \param src source tensor on the gpu
 * \param stream stream handed on to the generic Copy
 */
template<int dim, typename DType>
inline void Copy(Tensor<cpu, dim, DType> dst,
const Tensor<gpu, dim, DType> &src,
Stream<gpu> *stream) {
Copy(dst, src, cudaMemcpyDeviceToHost, stream);
}
/*!
 * \brief copy a GPU tensor into another GPU tensor
 *
 * Forwards to the generic Copy with cudaMemcpyDeviceToDevice.
 * \param dst destination tensor on the gpu
 * \param src source tensor on the gpu
 * \param stream stream handed on to the generic Copy
 */
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
const Tensor<gpu, dim, DType> &src,
Stream<gpu> *stream) {
Copy(dst, src, cudaMemcpyDeviceToDevice, stream);
}
/*!
 * \brief copy a CPU tensor into a GPU tensor
 *
 * Forwards to the generic Copy with cudaMemcpyHostToDevice.
 * \param dst destination tensor on the gpu
 * \param src source tensor on the cpu
 * \param stream stream handed on to the generic Copy
 */
template<int dim, typename DType>
inline void Copy(Tensor<gpu, dim, DType> dst,
const Tensor<cpu, dim, DType> &src,
Stream<gpu> *stream) {
Copy(dst, src, cudaMemcpyHostToDevice, stream);
}
#endif // MSHADOW_USE_CUDA
} // namespace mshadow
// the following part is included only if compiler is nvcc
#ifdef __CUDACC__
#include "./cuda/tensor_gpu-inl.cuh"
namespace mshadow {
/*!
 * \brief evaluate an expression and store the result into a GPU destination
 *
 * Runs the compile-time type check, verifies that the expression shape either
 * has a zero first entry or equals the destination shape, then launches
 * cuda::MapPlan on the stream associated with the destination.
 *
 * \tparam Saver saving operation handed on to cuda::MapPlan
 * \param dst destination of the assignment
 * \param exp expression to evaluate
 */
template<typename Saver, typename R, int dim,
         typename DType, typename E, int etype>
inline void MapExp(TRValue<R, gpu, dim, DType> *dst,
                   const expr::Exp<E, DType, etype> &exp) {
  typedef expr::TypeCheckPass<expr::TypeCheck<gpu, dim, DType, E>::kMapPass> MapTypeGuard;
  MapTypeGuard::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
  Shape<dim> eshape = expr::ShapeCheck<dim, E>::Check(exp.self());
  Shape<dim> dshape = expr::ShapeCheck<dim, R>::Check(dst->self());
  CHECK(eshape[0] == 0 || eshape == dshape)
      << "Assignment: Shape of Tensors are not consistent with target, "
      << "eshape: " << eshape << " dshape:" << dshape;
  // the kernel runs on whatever stream the destination is bound to
  Stream<gpu> *dst_stream = expr::StreamInfo<gpu, R>::Get(dst->self());
  cuda::MapPlan<Saver>(MakePlan(dst->self()),
                       MakePlan(exp.self()),
                       dshape.FlatTo2D(),
                       Stream<gpu>::GetStream(dst_stream));
}
/*!
 * \brief reduce an expression into a 1D GPU destination, keeping the lowest dimension
 *
 * The expression shape is flattened to 2D; its second entry must equal the
 * length of the destination and its first entry must be non-zero. The work is
 * done by cuda::MapReduceKeepLowest on the stream associated with the
 * destination.
 *
 * \tparam Saver saving operation handed on to the kernel
 * \tparam Reducer reduction operation handed on to the kernel
 * \param dst 1D destination of the reduction
 * \param exp expression to reduce
 * \param scale scaling factor handed on to the kernel
 */
template<typename Saver, typename Reducer,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepLowest(TRValue<R, gpu, 1, DType> *dst,
                                const expr::Exp<E, DType, etype> &exp,
                                DType scale) {
  typedef expr::TypeCheckPass<expr::TypeCheck<gpu, 1, DType, E>::kRedPass> ReduceTypeGuard;
  ReduceTypeGuard::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  Shape<2> src_shape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self()).FlatTo2D();
  Shape<1> dst_shape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(src_shape[1], dst_shape[0]) << "MapReduceKeepLowest::reduction dimension do not match";
  CHECK_NE(src_shape[0], 0U) << "can not reduce over empty tensor";
  // the kernel runs on whatever stream the destination is bound to
  Stream<gpu> *dst_stream = expr::StreamInfo<gpu, R>::Get(dst->self());
  cuda::MapReduceKeepLowest<Saver, Reducer>
      (MakePlan(dst->self()), MakePlan(exp.self()), scale, src_shape,
       Stream<gpu>::GetStream(dst_stream));
}
/*!
 * \brief reduce an expression into a 1D GPU destination, keeping dimension dimkeep
 *
 * The size of dimension dimkeep of the expression must equal the length of
 * the destination. The expression shape is rewritten as a 4D shape
 * (product of dims before dimkeep, dimkeep, product of dims between dimkeep
 * and the last one, last dim) and handed to cuda::MapReduceKeepDim1, which
 * runs on the stream associated with the destination.
 *
 * \tparam Saver saving operation handed on to the kernel
 * \tparam Reducer reduction operation handed on to the kernel
 * \tparam dimkeep index of the dimension that is kept
 * \param dst 1D destination of the reduction
 * \param exp expression to reduce
 * \param scale scaling factor handed on to the kernel
 */
template<typename Saver, typename Reducer, int dimkeep,
         typename R, typename DType, typename E, int etype>
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
                                 const expr::Exp<E, DType, etype> &exp,
                                 DType scale) {
  typedef expr::TypeCheckPass<expr::TypeCheck<gpu, dimkeep, DType, E>::kRedPass> ReduceTypeGuard;
  ReduceTypeGuard::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
  typedef Shape<expr::ExpInfo<E>::kDim> SrcShape;
  SrcShape src_shape = expr::ShapeCheck<expr::ExpInfo<E>::kDim, E>
      ::Check(exp.self());
  Shape<1> dst_shape = expr::ShapeCheck<1, R>::Check(dst->self());
  CHECK_EQ(src_shape[dimkeep], dst_shape[0]) << "MapReduceKeepHighDim::reduction dimension do not match";
  // equivalent 4D form in which the kept dimension sits at index 1
  Shape<4> folded = Shape4(src_shape.ProdShape(0, dimkeep),
                           src_shape[dimkeep],
                           src_shape.ProdShape(dimkeep + 1, SrcShape::kSubdim),
                           src_shape[SrcShape::kSubdim]);
  // the kernel runs on whatever stream the destination is bound to
  Stream<gpu> *dst_stream = expr::StreamInfo<gpu, R>::Get(dst->self());
  // reduce the 4D form while keeping its dimension 1
  cuda::MapReduceKeepDim1<Saver, Reducer>
      (MakePlan(dst->self()), MakePlan(exp.self()), scale, folded,
       Stream<gpu>::GetStream(dst_stream));
}
/*!
 * \brief GPU entry point of Softmax for 2D tensors
 *
 * Forwards both arguments unchanged to cuda::Softmax.
 * NOTE(review): presumably dst receives the softmax of src - confirm against
 * the kernel in cuda/tensor_gpu-inl.cuh.
 * \param dst first argument handed to cuda::Softmax
 * \param src second argument handed to cuda::Softmax
 */
template<typename DType>
inline void Softmax(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 2, DType>& src) {
cuda::Softmax(dst, src);
}
/*!
 * \brief GPU entry point of Softmax for 3D tensors
 *
 * Forwards both arguments unchanged to cuda::Softmax.
 * NOTE(review): which axis the softmax runs over is decided by the kernel in
 * cuda/tensor_gpu-inl.cuh - confirm there.
 * \param dst first argument handed to cuda::Softmax
 * \param src second argument handed to cuda::Softmax
 */
template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> dst,
const Tensor<gpu, 3, DType>& src) {
cuda::Softmax(dst, src);
}
/*!
 * \brief GPU entry point of SoftmaxGrad for 2D tensors with 1D labels
 *
 * Forwards all arguments unchanged to cuda::SoftmaxGrad.
 * NOTE(review): presumably dst receives the gradient computed from src and
 * label - confirm against the kernel in cuda/tensor_gpu-inl.cuh.
 * \param dst first argument handed to cuda::SoftmaxGrad
 * \param src second argument handed to cuda::SoftmaxGrad
 * \param label third argument handed to cuda::SoftmaxGrad
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label) {
cuda::SoftmaxGrad(dst, src, label);
}
/*!
 * \brief GPU entry point of SmoothSoftmaxGrad for 2D tensors with 1D labels
 *
 * Forwards all arguments unchanged to cuda::SmoothSoftmaxGrad.
 * NOTE(review): alpha is presumably the label smoothing strength - confirm
 * against the kernel in cuda/tensor_gpu-inl.cuh.
 * \param dst first argument handed to cuda::SmoothSoftmaxGrad
 * \param src second argument handed to cuda::SmoothSoftmaxGrad
 * \param label third argument handed to cuda::SmoothSoftmaxGrad
 * \param alpha fourth argument handed to cuda::SmoothSoftmaxGrad
 */
template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label,
const float alpha) {
cuda::SmoothSoftmaxGrad(dst, src, label, alpha);
}
/*!
 * \brief GPU entry point of SoftmaxGrad for 2D tensors with an ignore label
 *
 * Forwards all arguments unchanged to cuda::SoftmaxGrad.
 * NOTE(review): entries whose label equals ignore_label are presumably
 * skipped - confirm against the kernel in cuda/tensor_gpu-inl.cuh.
 * \param dst first argument handed to cuda::SoftmaxGrad
 * \param src second argument handed to cuda::SoftmaxGrad
 * \param label third argument handed to cuda::SoftmaxGrad
 * \param ignore_label fourth argument handed to cuda::SoftmaxGrad
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label,
const DType &ignore_label) {
cuda::SoftmaxGrad(dst, src, label, ignore_label);
}
/*!
 * \brief GPU entry point of SmoothSoftmaxGrad for 2D tensors with an ignore label
 *
 * Forwards all arguments unchanged to cuda::SmoothSoftmaxGrad.
 * NOTE(review): the meaning of ignore_label and alpha is defined by the
 * kernel in cuda/tensor_gpu-inl.cuh - confirm there.
 * \param dst first argument handed to cuda::SmoothSoftmaxGrad
 * \param src second argument handed to cuda::SmoothSoftmaxGrad
 * \param label third argument handed to cuda::SmoothSoftmaxGrad
 * \param ignore_label fourth argument handed to cuda::SmoothSoftmaxGrad
 * \param alpha fifth argument handed to cuda::SmoothSoftmaxGrad
 */
template<typename DType>
inline void SmoothSoftmaxGrad(const Tensor<gpu, 2, DType> &dst,
const Tensor<gpu, 2, DType> &src,
const Tensor<gpu, 1, DType> &label,
const DType &ignore_label,
const float alpha) {
cuda::SmoothSoftmaxGrad(dst, src, label, ignore_label, alpha);
}
/*!
 * \brief GPU entry point of SoftmaxGrad for 3D tensors with 2D labels
 *
 * Forwards all arguments unchanged to cuda::SoftmaxGrad.
 * \param dst first argument handed to cuda::SoftmaxGrad
 * \param src second argument handed to cuda::SoftmaxGrad
 * \param label third argument handed to cuda::SoftmaxGrad
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src,
const Tensor<gpu, 2, DType> &label) {
cuda::SoftmaxGrad(dst, src, label);
}
/*!
 * \brief GPU entry point of SoftmaxGrad for 3D tensors with an ignore label
 *
 * Forwards all arguments unchanged to cuda::SoftmaxGrad.
 * \param dst first argument handed to cuda::SoftmaxGrad
 * \param src second argument handed to cuda::SoftmaxGrad
 * \param label third argument handed to cuda::SoftmaxGrad
 * \param ignore_label fourth argument handed to cuda::SoftmaxGrad
 */
template<typename DType>
inline void SoftmaxGrad(const Tensor<gpu, 3, DType> &dst,
const Tensor<gpu, 3, DType> &src,
const Tensor<gpu, 2, DType> &label,
const DType &ignore_label) {
cuda::SoftmaxGrad(dst, src, label, ignore_label);
}
/*!
 * \brief GPU entry point of AddTakeGrad
 *
 * Forwards all arguments, together with the clip, IndexType and DType
 * template arguments, to cuda::AddTakeGrad.
 * NOTE(review): presumably rows of src are accumulated into dst at the
 * positions given by index, with clip selecting how out-of-range indices are
 * treated - confirm against the kernel in cuda/tensor_gpu-inl.cuh.
 * \param dst first argument handed to cuda::AddTakeGrad
 * \param index second argument handed to cuda::AddTakeGrad
 * \param src third argument handed to cuda::AddTakeGrad
 */
template<bool clip, typename IndexType, typename DType>
inline void AddTakeGrad(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGrad<clip, IndexType, DType>(dst, index, src);
}
/*!
 * \brief GPU entry point of AddTakeGradLargeBatch
 *
 * Forwards all arguments unchanged to cuda::AddTakeGradLargeBatch.
 * NOTE(review): how sorted and index relate to each other is defined by the
 * kernel in cuda/tensor_gpu-inl.cuh - confirm there.
 * \param dst first argument handed to cuda::AddTakeGradLargeBatch
 * \param sorted second argument handed to cuda::AddTakeGradLargeBatch
 * \param index third argument handed to cuda::AddTakeGradLargeBatch
 * \param src fourth argument handed to cuda::AddTakeGradLargeBatch
 */
template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& sorted,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::AddTakeGradLargeBatch(dst, sorted, index, src);
}
/*!
 * \brief GPU entry point of SortByKey
 *
 * Forwards all arguments unchanged to cuda::SortByKey.
 * NOTE(review): presumably keys and values are reordered together, in
 * ascending key order when is_ascend is true - confirm against the
 * implementation in cuda/tensor_gpu-inl.cuh.
 * \param keys first argument handed to cuda::SortByKey
 * \param values second argument handed to cuda::SortByKey
 * \param is_ascend third argument handed to cuda::SortByKey
 */
template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
bool is_ascend) {
cuda::SortByKey(keys, values, is_ascend);
}
/*!
 * \brief GPU entry point of IndexFill
 *
 * Forwards all arguments unchanged to cuda::IndexFill.
 * NOTE(review): presumably rows of dst selected by index are overwritten with
 * rows of src - confirm against the kernel in cuda/tensor_gpu-inl.cuh.
 * \param dst first argument handed to cuda::IndexFill
 * \param index second argument handed to cuda::IndexFill
 * \param src third argument handed to cuda::IndexFill
 */
template<typename IndexType, typename DType>
inline void IndexFill(Tensor<gpu, 2, DType> dst,
const Tensor<gpu, 1, IndexType>& index,
const Tensor<gpu, 2, DType> &src) {
cuda::IndexFill(dst, index, src);
}
} // namespace mshadow
#endif // __CUDACC__
#endif // MSHADOW_TENSOR_GPU_INL_H_