Skip to content

Commit ce8efcd

Browse files
committed
wider kv transactions using vectorized loads
1 parent 19a2191 commit ce8efcd

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ ifneq (,$(wildcard /usr/local/cuda))
4040
LDFLAGS+=-L/usr/local/cuda/lib64
4141
endif
4242

43-
CUFLAGS+=-g -O2 -lineinfo -Ivendor
43+
CUFLAGS+=-O2 -lineinfo -Ivendor
4444
CUFLAGS+=-allow-unsupported-compiler # for recent CUDA versions
4545

4646
ifeq ($(CUARCH),)

src/infer.cu

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -411,23 +411,34 @@ void att_mix(
411411
int warp_id = threadIdx.y;
412412
int t_stride = blockDim.y;
413413

414-
// Capacity 32 since there can be at most 32 warps in a block.
415-
__shared__ float shared[32];
414+
// Each lane of the warp accumulates across 2 head elements at a time.
415+
// NOTE: Assumes warpSize is 32
416+
__shared__ float shared0[32]; // shared0[i] == chunk[2*i]
417+
__shared__ float shared1[32]; // shared1[i] == chunk[2*i+1]
416418

417-
for (int i = threadIdx.x; i < head_dim; i += warpSize) {
419+
for (int i = 2*threadIdx.x; i < head_dim; i += 2*warpSize) {
418420
if (warp_id == 0) {
419-
shared[threadIdx.x] = 0;
421+
shared0[threadIdx.x] = 0;
422+
shared1[threadIdx.x] = 0;
420423
}
421424
__syncthreads();
422-
float sum = 0.0;
425+
float2 sum01 = make_float2(0.0, 0.0);
423426
for (int t = warp_id; t < seq_len; t += t_stride) {
424-
sum += __half2float(vh[kv_stride * t + i]) * atth[t];
427+
float2 v01 = __half22float2(*((half2*)&vh[kv_stride * t + i]));
428+
float att_t = atth[t];
429+
// Sadly CUDA does not have float2 SIMD ops
430+
sum01.x += v01.x * att_t;
431+
sum01.y += v01.y * att_t;
425432
}
426-
atomicAdd(&shared[threadIdx.x], sum);
433+
atomicAdd(&shared0[threadIdx.x], sum01.x);
434+
atomicAdd(&shared1[threadIdx.x], sum01.y);
427435
__syncthreads();
428436
if (warp_id == 0) {
429-
outh[i] = shared[threadIdx.x];
430-
shared[threadIdx.x] = 0;
437+
float even = shared0[threadIdx.x];
438+
float odd = shared1[threadIdx.x];
439+
*((float2*)&outh[i]) = make_float2(even, odd);
440+
shared0[threadIdx.x] = 0;
441+
shared1[threadIdx.x] = 0;
431442
}
432443
}
433444
}
@@ -459,10 +470,12 @@ inline void rope(
459470
float fcr = cosf(val);
460471
float fci = sinf(val);
461472

462-
float v0 = x[pair_idx];
463-
float v1 = x[pair_idx + 1];
464-
out[pair_idx] = v0 * fcr - v1 * fci;
465-
out[pair_idx + 1] = v0 * fci + v1 * fcr;
473+
float2 v01 = *((float2*)&x[pair_idx]);
474+
float2 result = make_float2(
475+
v01.x * fcr - v01.y * fci,
476+
v01.x * fci + v01.y * fcr
477+
);
478+
*((float2*)&out[pair_idx]) = result;
466479
}
467480
}
468481

@@ -477,10 +490,12 @@ inline void rope(
477490
float fcr = cosf(val);
478491
float fci = sinf(val);
479492

480-
float v0 = x[pair_idx];
481-
float v1 = x[pair_idx + 1];
482-
out[pair_idx] = __float2half(v0 * fcr - v1 * fci);
483-
out[pair_idx + 1] = __float2half(v0 * fci + v1 * fcr);
493+
float2 v01 = *((float2*)&x[pair_idx]);
494+
half2 result = __floats2half2_rn(
495+
v01.x * fcr - v01.y * fci,
496+
v01.x * fci + v01.y * fcr
497+
);
498+
*((half2*)&out[pair_idx]) = result;
484499
}
485500
}
486501

@@ -495,10 +510,12 @@ inline void rope(
495510
float fcr = cosf(val);
496511
float fci = sinf(val);
497512

498-
float v0 = __half2float(x[pair_idx]);
499-
float v1 = __half2float(x[pair_idx + 1]);
500-
out[pair_idx] = __float2half(v0 * fcr - v1 * fci);
501-
out[pair_idx + 1] = __float2half(v0 * fci + v1 * fcr);
513+
float2 v01 = __half22float2(*((half2*)&x[pair_idx]));
514+
half2 result = __floats2half2_rn(
515+
v01.x * fcr - v01.y * fci,
516+
v01.x * fci + v01.y * fcr
517+
);
518+
*((half2*)&out[pair_idx]) = result;
502519
}
503520
}
504521

0 commit comments

Comments
 (0)