-
Notifications
You must be signed in to change notification settings - Fork 868
/
add_residual_kernels.cu
562 lines (518 loc) · 27.9 KB
/
add_residual_kernels.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/fastertransformer/kernels/add_residual_kernels.h"
#include "src/fastertransformer/utils/cuda_type_utils.cuh"
namespace fastertransformer {
// Fused element-wise kernel over an m x n row-major matrix:
//   output = in + residual1 (+ residual2 when RESIDUAL_NUM == 2) + bias
//
// Launch layout (see invokeAddBiasResidual): one matrix row per blockIdx.x and
// column = blockIdx.y * blockDim.x + threadIdx.x; threads whose column is >= n
// do nothing. `m` is not read; the row count comes from gridDim.x.
//
// T            : element type of output, residuals and bias.
// RESIDUAL_NUM : 1 -> add residual1 only, 2 -> add residual1 and residual2.
// T2           : element type of `input`. When T2 == T the input value is used
//                directly; otherwise it is converted to float and multiplied by
//                (*scale_inter) * (*scale_out) (the int32 path of the launcher).
// bias may be nullptr (treated as zero). scale_inter / scale_out are only
// dereferenced when T2 != T.
template<typename T, int RESIDUAL_NUM, typename T2 = T>
__global__ void addBiasResidual(T* output,
                                const T2* input,
                                const T* residual1,
                                const T* residual2,
                                const T* bias,
                                const float* scale_inter,
                                const float* scale_out,
                                const int m,
                                const int n)
{
    static_assert(RESIDUAL_NUM == 1 || RESIDUAL_NUM == 2, "addBiasResidual supports one or two residuals");
    const int col_index = blockIdx.y * blockDim.x + threadIdx.x;
    if (col_index < n) {
        // 64-bit element offset: `blockIdx.x * n` in 32-bit unsigned arithmetic wraps once m * n >= 2^32.
        const size_t idx = static_cast<size_t>(blockIdx.x) * n + col_index;
        T bias_val = (bias == nullptr) ? (T)(0.0f) : bias[col_index];
        T in;
        if (std::is_same<T, T2>::value) {
            in = cuda_cast<T>(input[idx]);  // cast required for compilation when T != T2
        }
        else {
            in = cuda_cast<float>(input[idx]) * (*scale_inter) * (*scale_out);
        }
        if (RESIDUAL_NUM == 1) {
            output[idx] = in + residual1[idx] + bias_val;
        }
        else if (RESIDUAL_NUM == 2) {
            output[idx] = in + residual1[idx] + residual2[idx] + bias_val;
        }
    }
}
// Launches addBiasResidual over an m x n row-major matrix on `stream`:
//   output = in + residual1 (+ residual2) + bias
// where `in` is `input` itself or, when scale_inter / scale_out are provided,
// `input` reinterpreted as int32_t and multiplied by (*scale_inter) * (*scale_out).
// residual2 == nullptr selects the single-residual variant; bias may be nullptr.
// scale_inter and scale_out must be both null or both non-null (checked).
// Empty problems (m <= 0 or n <= 0) return without launching: a zero-sized grid
// or block is an invalid launch configuration.
template<typename T>
void invokeAddBiasResidual(T* output,
                           const T* input,
                           const T* residual1,
                           const T* residual2,
                           const T* bias,
                           const float* scale_inter,
                           const float* scale_out,
                           const int m,
                           const int n,
                           cudaStream_t stream)
{
    FT_CHECK_WITH_INFO(!((scale_inter == nullptr) ^ (scale_out == nullptr)),
                       "Cannot use `scale_inter` without `scale_out`");
    if (m <= 0 || n <= 0) {
        return;
    }
    const bool    should_scale_input = scale_inter != nullptr;
    constexpr int max_threads        = 1024;
    // Integer ceil-div; float(n) / 1024 rounds n for n > 2^24 and can drop the last block.
    const int blocks_per_row = (n - 1) / max_threads + 1;
    dim3      grid(m, blocks_per_row);
    dim3      block(min(n, max_threads));
    if (residual2 == nullptr) {
        if (should_scale_input) {
            // T2 is deduced as int32_t, selecting the dequantizing branch of the kernel.
            addBiasResidual<T, 1><<<grid, block, 0, stream>>>(output,
                                                               reinterpret_cast<const int32_t*>(input),
                                                               residual1,
                                                               residual2,
                                                               bias,
                                                               scale_inter,
                                                               scale_out,
                                                               m,
                                                               n);
        }
        else {
            addBiasResidual<T, 1>
                <<<grid, block, 0, stream>>>(output, input, residual1, residual2, bias, nullptr, nullptr, m, n);
        }
    }
    else {
        if (should_scale_input) {
            addBiasResidual<T, 2><<<grid, block, 0, stream>>>(output,
                                                               reinterpret_cast<const int32_t*>(input),
                                                               residual1,
                                                               residual2,
                                                               bias,
                                                               scale_inter,
                                                               scale_out,
                                                               m,
                                                               n);
        }
        else {
            addBiasResidual<T, 2>
                <<<grid, block, 0, stream>>>(output, input, residual1, residual2, bias, nullptr, nullptr, m, n);
        }
    }
}
// In-place convenience overload:
//   output = output + residual1 (+ residual2) + bias
// Forwards to the general overload with `output` doubling as the input and
// with no input scaling.
template<typename T>
void invokeAddBiasResidual(
    T* output, const T* residual1, const T* residual2, const T* bias, const int m, const int n, cudaStream_t stream)
{
    const T*     in_place_input = output;
    const float* no_scale       = nullptr;
    invokeAddBiasResidual(output, in_place_input, residual1, residual2, bias, no_scale, no_scale, m, n, stream);
}
// Element-wise over an m x n row-major matrix:
//   block_output = ffn_output + attn_output + bias + block_input / block_input_tp_split
// The block_input term is zero when block_input is nullptr; bias is always read.
// One row per blockIdx.x, column = blockIdx.y * blockDim.x + threadIdx.x; columns
// >= n are skipped. `m` is not read.
template<typename T>
__global__ void addBiasAttentionFfnResidual(T* block_output,
                                            const T* ffn_output,
                                            const T* attn_output,
                                            const T* block_input,
                                            const T* bias,
                                            const int m,
                                            const int n,
                                            const int block_input_tp_split)
{
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;
    }
    const auto idx = blockIdx.x * n + col;
    const T residual =
        (block_input == nullptr) ?
            static_cast<T>(0.0f) :
            cuda_cast<T>((float)block_input[idx] / (float)block_input_tp_split);
    block_output[idx] = ffn_output[idx] + attn_output[idx] + bias[col] + residual;
}
// In-place variant: the residual is the current content of block_output.
//   block_output = block_output / block_input_tp_split + ffn_output + attn_output + bias
// (summed via the 4-argument `add` helper). Same launch layout as the
// out-of-place kernel; `m` is not read.
template<typename T>
__global__ void addBiasAttentionFfnResidual(T* block_output,
                                            const T* ffn_output,
                                            const T* attn_output,
                                            const T* bias,
                                            const int m,
                                            const int n,
                                            const int block_input_tp_split)
{
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;
    }
    const int idx          = blockIdx.x * n + col;
    const T   scaled_input = cuda_cast<T>((float)block_output[idx] / (float)block_input_tp_split);
    block_output[idx]      = add(scaled_input, ffn_output[idx], attn_output[idx], bias[col]);
}
// Launches addBiasAttentionFfnResidual over an m x n row-major matrix on `stream`:
//   block_output = ffn_output + attn_output + bias + block_input / block_input_tp_split
// When block_input aliases block_output, the in-place kernel overload is used, which
// reads the residual directly from block_output. block_input may be nullptr (no
// residual term). Empty problems (m <= 0 or n <= 0) return without launching,
// since a zero-sized grid or block is an invalid launch configuration.
template<typename T>
void invokeAddBiasAttentionFfnResidual(T* block_output,
                                       const T* ffn_output,
                                       const T* attn_output,
                                       const T* block_input,
                                       const T* bias,
                                       const int m,
                                       const int n,
                                       const int block_input_tp_split,
                                       cudaStream_t stream)
{
    if (m <= 0 || n <= 0) {
        return;
    }
    constexpr int max_threads = 1024;
    // Integer ceil-div; float(n) / 1024 rounds n for n > 2^24 and can drop the last block.
    const int blocks_per_row = (n - 1) / max_threads + 1;
    dim3      grid(m, blocks_per_row);
    dim3      block(min(n, max_threads));
    if (block_output == block_input) {
        addBiasAttentionFfnResidual<<<grid, block, 0, stream>>>(
            block_output, ffn_output, attn_output, bias, m, n, block_input_tp_split);
    }
    else {
        addBiasAttentionFfnResidual<<<grid, block, 0, stream>>>(
            block_output, ffn_output, attn_output, block_input, bias, m, n, block_input_tp_split);
    }
}
// Explicit instantiations of the 10-argument invokeAddBiasResidual for float, half
// and (when ENABLE_BF16 is defined) __nv_bfloat16. The helper macro is #undef'd
// immediately after use.
#define INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(T) \
template void invokeAddBiasResidual(T* output, \
const T* input, \
const T* residual1, \
const T* residual2, \
const T* bias, \
const float* scale_inter, \
const float* scale_out, \
const int m, \
const int n, \
cudaStream_t stream)
INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(float);
INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(half);
#ifdef ENABLE_BF16
INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL(__nv_bfloat16);
#endif
#undef INSTANTIATE_INVOKE_ADD_BIAS_RESIDUAL
// Explicit instantiations of the 7-argument in-place invokeAddBiasResidual overload.
template void invokeAddBiasResidual(float* output,
const float* residual1,
const float* residual2,
const float* bias,
const int m,
const int n,
cudaStream_t stream);
template void invokeAddBiasResidual(half* output,
const half* residual1,
const half* residual2,
const half* bias,
const int m,
const int n,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeAddBiasResidual(__nv_bfloat16* output,
const __nv_bfloat16* residual1,
const __nv_bfloat16* residual2,
const __nv_bfloat16* bias,
const int m,
const int n,
cudaStream_t stream);
#endif
// Explicit instantiations of invokeAddBiasAttentionFfnResidual (the 4th pointer is
// the launcher's `block_input` parameter; only the types matter here).
template void invokeAddBiasAttentionFfnResidual(float* block_output,
const float* ffn_output,
const float* attn_output,
const float* input,
const float* bias,
const int m,
const int n,
const int block_input_tp_split,
cudaStream_t stream);
template void invokeAddBiasAttentionFfnResidual(half* block_output,
const half* ffn_output,
const half* attn_output,
const half* input,
const half* bias,
const int m,
const int n,
const int block_input_tp_split,
cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeAddBiasAttentionFfnResidual(__nv_bfloat16* block_output,
const __nv_bfloat16* ffn_output,
const __nv_bfloat16* attn_output,
const __nv_bfloat16* input,
const __nv_bfloat16* bias,
const int m,
const int n,
const int block_input_tp_split,
cudaStream_t stream);
#endif
// output += input element-wise over an m x n row-major matrix, accumulating in
// float. For half only, results outside [-64512, 64512] are clamped to that bound
// so the stored value stays finite. One row per blockIdx.x, column =
// blockIdx.y * blockDim.x + threadIdx.x; columns >= n are skipped. `m` is not read.
template<typename T>
__global__ void T5addResidual(T* output, const T* input, const int m, const int n)
{
    const int col = blockIdx.y * blockDim.x + threadIdx.x;
    if (col >= n) {
        return;
    }
    const auto idx = blockIdx.x * n + col;
    float      sum = (float)output[idx] + (float)input[idx];
    const bool is_half = std::is_same<T, half>::value;
    if (is_half && (sum > 64512 || sum < -64512)) {
        sum = sum > 0 ? 64512 : -64512;
    }
    output[idx] = (T)sum;
}
// Launches T5addResidual (output += input, half results clamped to +/-64512) over
// an m x n row-major matrix on `stream`. Empty problems (m <= 0 or n <= 0) return
// without launching, since a zero-sized grid or block is an invalid launch configuration.
template<typename T>
void invokeT5AddResidual(T* output, const T* input, const int m, const int n, cudaStream_t stream)
{
    if (m <= 0 || n <= 0) {
        return;
    }
    constexpr int max_threads = 1024;
    // Integer ceil-div; float(n) / 1024 rounds n for n > 2^24 and can drop the last block.
    const int blocks_per_row = (n - 1) / max_threads + 1;
    dim3      grid(m, blocks_per_row);
    dim3      block(min(n, max_threads));
    T5addResidual<<<grid, block, 0, stream>>>(output, input, m, n);
}
// Explicit instantiations of invokeT5AddResidual for float, half and (when
// ENABLE_BF16 is defined) __nv_bfloat16.
template void invokeT5AddResidual(float* output, const float* input, const int m, const int n, cudaStream_t stream);
template void invokeT5AddResidual(half* output, const half* input, const int m, const int n, cudaStream_t stream);
#ifdef ENABLE_BF16
template void
invokeT5AddResidual(__nv_bfloat16* output, const __nv_bfloat16* input, const int m, const int n, cudaStream_t stream);
#endif
// T5 residual add over an m x n matrix: output += input (+ bias when provided).
// Without a bias it uses invokeT5AddResidual (which clamps half results to
// +/-64512); with a bias it forwards to invokeAddBiasResidual(output, input, bias, m, n, stream).
template<typename T>
void invokeT5AddBiasResidual(T* output, const T* input, const T* bias, const int m, const int n, cudaStream_t stream)
{
    if (bias == nullptr) {
        invokeT5AddResidual(output, input, m, n, stream);
        return;
    }
    invokeAddBiasResidual(output, input, bias, m, n, stream);
}
// Explicit instantiations of invokeT5AddBiasResidual for float, half and (when
// ENABLE_BF16 is defined) __nv_bfloat16.
template void invokeT5AddBiasResidual(
float* output, const float* input, const float* bias, const int m, const int n, cudaStream_t stream);
template void invokeT5AddBiasResidual(
half* output, const half* input, const half* bias, const int m, const int n, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeT5AddBiasResidual(__nv_bfloat16* output,
const __nv_bfloat16* input,
const __nv_bfloat16* bias,
const int m,
const int n,
cudaStream_t stream);
#endif
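/******************* invokeAddBiasResidualCol32 (int8 input1) ***********************/
// input1/input2/out are m x n matrices in the cublasLt CUBLASLT_ORDER_COL32 layout.
// The launcher uses grid = ((n + 31) / 32, (m + 31) / 32) and block = (8, 32):
// 8 x 32 threads per block, each thread handling 4 consecutive columns of one row
// (input1 is read as char4).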
// Generic (non-half) path: output = input1 * (*input1_deQFactor_ptr) + input2 + bias,
// with all three matrices in CUBLASLT_ORDER_COL32 layout (m rows, n columns).
//
// Uses the same 2D launch geometry as the half4 specialization and as
// invokeAddBiasResidualCol32: grid = ((n + 31) / 32, (m + 31) / 32), block = (8, 32).
// blockIdx.x / threadIdx.x select a group of 4 columns inside a 32-column tile and
// blockIdx.y / threadIdx.y select the row. (Previously this kernel indexed rows
// by blockIdx.x and ignored the y dimensions, which did not match that launch.)
// Assumes n is a multiple of 4 — presumably of 32 for COL32; TODO confirm with callers.
template<typename T>
__global__ void add_bias_input_COL32_int8I_DataTypeO(
    T* output, const int8_t* input1, const T* input2, const T* bias, int m, int n, const float* input1_deQFactor_ptr)
{
    const float input1_deQFactor = __ldg(input1_deQFactor_ptr);
    const int   col_start        = (blockIdx.x << 5) + (threadIdx.x << 2);
    const int   row_start        = (blockIdx.y << 5) + (threadIdx.y);
    if (col_start >= n || row_start >= m) {
        return;
    }
    float local_out[4];
    // COL32: element (row, col) sits at (col & ~31) * m + row * 32 + (col & 31); >> 2 gives the char4 index.
    const int    outIdx       = ((col_start & 0xffffffe0) * m + (row_start << 5) + (col_start & 31)) >> 2;
    const char4* input1TmpPtr = reinterpret_cast<const char4*>(input1);
    const char4  input1Tmp    = __ldg(input1TmpPtr + outIdx);
    int          col_start_tmp = col_start;
    local_out[0] = static_cast<float>(input2[(outIdx << 2) + 0]) + static_cast<float>(input1Tmp.x) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[1] = static_cast<float>(input2[(outIdx << 2) + 1]) + static_cast<float>(input1Tmp.y) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[2] = static_cast<float>(input2[(outIdx << 2) + 2]) + static_cast<float>(input1Tmp.z) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[3] = static_cast<float>(input2[(outIdx << 2) + 3]) + static_cast<float>(input1Tmp.w) * input1_deQFactor
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    for (int i = 0; i < 4; i++) {
        output[(outIdx << 2) + i] = static_cast<T>(local_out[i]);
    }
}
// half specialization, vectorized over 4 columns (half4 / char4):
//   output = input1 * (*input1_deQFactor_ptr) + bias + input2   (COL32 layout, m x n)
// Launch: grid = ((n + 31) / 32, (m + 31) / 32), block = (8, 32). Each thread owns
// 4 consecutive columns of one row; out-of-range threads do nothing.
template<>
__global__ void add_bias_input_COL32_int8I_DataTypeO(half4* output,
                                                     const int8_t* input1,
                                                     const half4* input2,
                                                     const half4* bias,
                                                     int m,
                                                     int n,
                                                     const float* input1_deQFactor_ptr)
{
    const float deq = __ldg(input1_deQFactor_ptr);
    const int   col = (blockIdx.x << 5) + (threadIdx.x << 2);
    const int   row = (blockIdx.y << 5) + (threadIdx.y);
    if (col >= n || row >= m) {
        return;
    }
    // COL32 element offset of (row, col), divided by 4 to index 4-wide vectors.
    const int   vec_idx = ((col & 0xffffffe0) * m + (row << 5) + (col & 31)) >> 2;
    const char4 q       = reinterpret_cast<const char4*>(input1)[vec_idx];
    const half4 res     = input2[vec_idx];
    const half4 b       = bias[col >> 2];

    auto fuse = [deq](signed char qv, half bv, half rv) {
        return static_cast<half>((float)qv * deq + (float)bv + (float)rv);
    };

    half4 out;
    out.x           = fuse(q.x, b.x, res.x);
    out.y           = fuse(q.y, b.y, res.y);
    out.z           = fuse(q.z, b.z, res.z);
    out.w           = fuse(q.w, b.w, res.w);
    output[vec_idx] = out;
}
// Launches add_bias_input_COL32_int8I_DataTypeO on `stream`:
//   output = input1 (int8) * (*input1_deQFactor_ptr) + input2 + bias   (COL32 layout, m x n)
// Geometry: grid = ((n + 31) / 32, (m + 31) / 32), block = (8, 32); both kernel
// variants index with this 2D tiling. half is routed to the half4 specialization;
// other types (float is the only other instantiation) use the generic kernel.
// Empty problems (m <= 0 or n <= 0) return without launching.
template<typename T>
void invokeAddBiasResidualCol32(T* output,
                                const int8_t* input1,
                                const T* input2,
                                const T* bias,
                                int m,
                                int n,
                                cudaStream_t stream,
                                const float* input1_deQFactor_ptr)
{
    if (m <= 0 || n <= 0) {
        return;
    }
    dim3 grid((n + 31) / 32, (m + 31) / 32);
    dim3 block(8, 32);
    // Dispatch on the exact type (not sizeof) so a 2-byte non-half type is never reinterpreted as half4.
    if (std::is_same<T, half>::value) {
        add_bias_input_COL32_int8I_DataTypeO<<<grid, block, 0, stream>>>(
            (half4*)output, input1, (const half4*)input2, (const half4*)bias, m, n, input1_deQFactor_ptr);
    }
    else {
        add_bias_input_COL32_int8I_DataTypeO<T>
            <<<grid, block, 0, stream>>>(output, input1, input2, bias, m, n, input1_deQFactor_ptr);
    }
}
// Explicit instantiations of the int8-input invokeAddBiasResidualCol32 for float and half.
template void invokeAddBiasResidualCol32(float* output,
const int8_t* input1,
const float* input2,
const float* bias,
int m,
int n,
cudaStream_t stream,
const float* input1_deQFactor_ptr);
template void invokeAddBiasResidualCol32(half* output,
const int8_t* input1,
const half* input2,
const half* bias,
int m,
int n,
cudaStream_t stream,
const float* input1_deQFactor_ptr);
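/******************* invokeAddBiasResidualCol32 (int32 input1) ***********************/
// input1/input2/out are m x n matrices in the cublasLt CUBLASLT_ORDER_COL32 layout.
// (grid, block) must be (m, n/4): one block per row, one thread per 4 consecutive columns.
// input1 holds int32 values and is read as int4.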
// Generic (non-half) path, COL32 layout (m x n):
//   output = input2 + input1 (int32) * (*input1_amax_ptr / 127) * weight_scale / 127 + bias
// Launch: grid = (m), block = (n / 4); blockIdx.x is the row and each thread owns
// 4 consecutive columns starting at threadIdx.x * 4.
// scale_is_vector == 1: weight_amax holds one scale per column (read as a float4
// per thread). Otherwise weight_amax[0] is a single scale broadcast to all columns.
// This matches the half4 specialization; the old code always read a float4 and so
// used weight_amax[1..3] as column scales even for a one-element scalar buffer.
template<typename T>
__global__ void add_bias_input_COL32_int32I_DataTypeO(T* output,
                                                      const int32_t* input1,
                                                      const T* input2,
                                                      const T* bias,
                                                      int m,
                                                      int n,
                                                      const float* weight_amax,
                                                      const float* input1_amax_ptr,
                                                      const int scale_is_vector)
{
    int           col_start           = threadIdx.x << 2;
    const float4* weight_scale_ptr    = (const float4*)weight_amax;
    const float   weight_scale_single = __ldg(weight_amax);
    const float4  weight_scale =
        scale_is_vector == 1 ?
             __ldg(weight_scale_ptr + threadIdx.x) :
             make_float4(weight_scale_single, weight_scale_single, weight_scale_single, weight_scale_single);
    const float input1_deQ = __ldg(input1_amax_ptr) / 127.0f;
    float       local_out[4];
    // COL32 element offset of (row = blockIdx.x, col_start), divided by 4 for int4 indexing.
    int   outIdx        = ((col_start & 0xffffffe0) * m + (blockIdx.x << 5) + (col_start & 31)) >> 2;
    int4* input1TmpPtr  = (int4*)input1;
    int4  input1Tmp     = input1TmpPtr[outIdx];
    int   col_start_tmp = col_start;
    local_out[0]        = static_cast<float>(input2[(outIdx << 2) + 0])
                   + static_cast<float>(input1Tmp.x) * input1_deQ * weight_scale.x / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[1]  = static_cast<float>(input2[(outIdx << 2) + 1])
                   + static_cast<float>(input1Tmp.y) * input1_deQ * weight_scale.y / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[2]  = static_cast<float>(input2[(outIdx << 2) + 2])
                   + static_cast<float>(input1Tmp.z) * input1_deQ * weight_scale.z / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    col_start_tmp = col_start_tmp + 1;
    local_out[3]  = static_cast<float>(input2[(outIdx << 2) + 3])
                   + static_cast<float>(input1Tmp.w) * input1_deQ * weight_scale.w / 127.0f
                   + static_cast<float>(__ldg(bias + col_start_tmp));
    for (int i = 0; i < 4; i++) {
        output[(outIdx << 2) + i] = static_cast<T>(local_out[i]);
    }
}
// half specialization, vectorized over 4 columns (half4 / int4), COL32 layout (m x n):
//   output = input2 + input1 * (*input1_amax_ptr / 127) * weight_scale / 127 + bias
// Launch: grid = (m), block = (n / 4); blockIdx.x is the row, each thread owns the
// 4 columns starting at threadIdx.x * 4. weight_scale is per-column (float4 slice)
// when scale_is_vector == 1, else weight_amax[0] broadcast to all four lanes.
template<>
__global__ void add_bias_input_COL32_int32I_DataTypeO(half4* output,
                                                      const int32_t* input1,
                                                      const half4* input2,
                                                      const half4* bias,
                                                      int m,
                                                      int n,
                                                      const float* weight_amax,
                                                      const float* input1_amax_ptr,
                                                      const int scale_is_vector)
{
    const float scalar_scale = __ldg(weight_amax);
    float4      w_scale;
    if (scale_is_vector == 1) {
        w_scale = __ldg(reinterpret_cast<const float4*>(weight_amax) + threadIdx.x);
    }
    else {
        w_scale = make_float4(scalar_scale, scalar_scale, scalar_scale, scalar_scale);
    }
    const float act_deq = __ldg(input1_amax_ptr) / 127.0f;

    // COL32 element offset of (row = blockIdx.x, first_col), divided by 4 for vector indexing.
    const int   first_col = threadIdx.x << 2;
    const int   vec_idx   = ((first_col & 0xffffffe0) * m + (blockIdx.x << 5) + (first_col & 31)) >> 2;
    const int4  acc       = reinterpret_cast<const int4*>(input1)[vec_idx];
    const half4 res       = input2[vec_idx];
    const half4 b         = bias[threadIdx.x];

    auto dequant_add = [act_deq](int q, float w, half r, half bb) -> half {
        const float v = static_cast<float>(r) + static_cast<float>(q) * act_deq * w / 127.0f + static_cast<float>(bb);
        return static_cast<half>(v);
    };

    half4 out;
    out.x           = dequant_add(acc.x, w_scale.x, res.x, b.x);
    out.y           = dequant_add(acc.y, w_scale.y, res.y, b.y);
    out.z           = dequant_add(acc.z, w_scale.z, res.z, b.z);
    out.w           = dequant_add(acc.w, w_scale.w, res.w, b.w);
    output[vec_idx] = out;
}
// Launches add_bias_input_COL32_int32I_DataTypeO on `stream` (COL32 layout, m x n):
//   output = input2 + input1 (int32) * (*input1_amax_ptr / 127) * weight_scale / 127 + bias
// Geometry: grid = (m), block = (n / 4). This requires n to be a multiple of 4 and
// n <= 4096 (at most 1024 threads per block). That is now checked in all builds;
// the previous assert vanished under NDEBUG. half is routed to the half4
// specialization; other types (float is the only other instantiation) use the
// generic kernel. Empty problems (m <= 0 or n <= 0) return without launching.
template<typename T>
void invokeAddBiasResidualCol32(T* output,
                                const int32_t* input1,
                                const T* input2,
                                const T* bias,
                                int m,
                                int n,
                                cudaStream_t stream,
                                const float* weight_amax,
                                const float* input1_amax_ptr,
                                const int scale_is_vector)
{
    if (m <= 0 || n <= 0) {
        return;
    }
    FT_CHECK_WITH_INFO(n % 4 == 0 && n / 4 <= 1024,
                       "invokeAddBiasResidualCol32 (int32 input) requires n to be a multiple of 4 and at most 4096");
    dim3 grid(m);
    dim3 block(n / 4);
    // Dispatch on the exact type (not sizeof) so a 2-byte non-half type is never reinterpreted as half4.
    if (std::is_same<T, half>::value) {
        add_bias_input_COL32_int32I_DataTypeO<<<grid, block, 0, stream>>>((half4*)output,
                                                                            input1,
                                                                            (const half4*)input2,
                                                                            (const half4*)bias,
                                                                            m,
                                                                            n,
                                                                            weight_amax,
                                                                            input1_amax_ptr,
                                                                            scale_is_vector);
    }
    else {
        add_bias_input_COL32_int32I_DataTypeO<T><<<grid, block, 0, stream>>>(
            output, input1, input2, bias, m, n, weight_amax, input1_amax_ptr, scale_is_vector);
    }
}
// Explicit instantiations of the int32-input invokeAddBiasResidualCol32 for float and half.
template void invokeAddBiasResidualCol32(float* output,
const int* input1,
const float* input2,
const float* bias,
int m,
int n,
cudaStream_t stream,
const float* weight_amax,
const float* input1_amax_ptr,
const int scale_is_vector);
template void invokeAddBiasResidualCol32(half* output,
const int* input1,
const half* input2,
const half* bias,
int m,
int n,
cudaStream_t stream,
const float* weight_amax,
const float* input1_amax_ptr,
const int scale_is_vector);
} // namespace fastertransformer