/
data_simulator.cu
114 lines (98 loc) · 4.43 KB
/
data_simulator.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <data_simulator.hpp>
#include <diagnose.hpp>
#include <random>
#include <utils.cuh>
namespace HugeCTR {
template <>
void UniformGenerator::fill<float>(float* ptr, size_t num_elements, float a, float b,
size_t sm_count, const curandGenerator_t& generator,
const cudaStream_t& stream) {
if (a >= b) {
HCTR_OWN_THROW(Error_t::WrongInput, "a must be smaller than b");
}
HCTR_LIB_THROW(curandGenerateUniform(generator, ptr, num_elements));
auto op = [a, b] __device__(float val) { return val * (b - a) + a; };
transform_array<<<sm_count * 2, 1024, 0, stream>>>(ptr, ptr, num_elements, op);
}
template <typename T>
__global__ void sinusoidal_kernel(T* output, int ev_size, int max_sequence_len) {
int row = blockIdx.x;
int col = threadIdx.x;
int offset = row * ev_size + col;
float log_result = __logf(10000) / (ev_size);
float exp_result = __expf(((col >> 1) << 1) * -1 * log_result);
if (col < ev_size) {
output[offset] = (col % 2) ? (T)__cosf(exp_result * row) : (T)__sinf(exp_result * row);
}
}
template <>
void SinusoidalGenerator::fill<float>(float* ptr, size_t num_elements, int ev_size,
int max_sequence_len, size_t sm_count,
const cudaStream_t& stream) {
sinusoidal_kernel<<<max_sequence_len, max(32, ev_size), 0, stream>>>(ptr, ev_size,
max_sequence_len);
}
template <>
void UniformGenerator::fill<float>(Tensor2<float>& tensor, float a, float b, size_t sm_count,
const curandGenerator_t& generator, const cudaStream_t& stream) {
UniformGenerator::fill<float>(tensor.get_ptr(), tensor.get_num_elements(), a, b, sm_count,
generator, stream);
}
template <>
void HostUniformGenerator::fill<float>(Tensor2<float>& tensor, float a, float b,
const curandGenerator_t& generator) {
if (a >= b) {
HCTR_OWN_THROW(Error_t::WrongInput, "a must be smaller than b");
}
HCTR_LIB_THROW(curandGenerateUniform(generator, tensor.get_ptr(),
tensor.get_num_elements() % 2 != 0
? tensor.get_num_elements() + 1
: tensor.get_num_elements()));
float* p = tensor.get_ptr();
for (size_t i = 0; i < tensor.get_num_elements(); i++) {
p[i] = p[i] * (b - a) + a;
}
}
template <>
void NormalGenerator::fill<float>(Tensor2<float>& tensor, float mean, float stddev, size_t sm_count,
const curandGenerator_t& generator, const cudaStream_t& stream) {
HCTR_LIB_THROW(
curandGenerateNormal(generator, tensor.get_ptr(), tensor.get_num_elements(), mean, stddev));
}
template <>
void HostNormalGenerator::fill<float>(Tensor2<float>& tensor, float mean, float stddev,
const curandGenerator_t& gen) {
HCTR_LIB_THROW(curandGenerateNormal(gen, tensor.get_ptr(),
tensor.get_num_elements() % 2 != 0
? tensor.get_num_elements() + 1
: tensor.get_num_elements(),
mean, stddev));
}
void ConstantDataSimulator::fill(Tensor2<float>& tensor, const curandGenerator_t& gen) {
float* p = tensor.get_ptr();
for (size_t i = 0; i < tensor.get_num_elements(); i++) {
p[i] = value_;
}
}
void UniformDataSimulator::fill(Tensor2<float>& tensor, const curandGenerator_t& gen) {
HostUniformGenerator::fill(tensor, min_, max_, gen);
}
void GaussianDataSimulator::fill(Tensor2<float>& tensor, const curandGenerator_t& gen) {
HostNormalGenerator::fill(tensor, mu_, sigma_, gen);
}
} // namespace HugeCTR