@@ -411,23 +411,34 @@ void att_mix(
411411 int warp_id = threadIdx .y ;
412412 int t_stride = blockDim .y ;
413413
414- // Capacity 32 since there can be at most 32 warps in a block.
415- __shared__ float shared[32 ];
414+ // Each lane of the warp accumulates across 2 head elements at a time.
415+ // NOTE: Assumes warpSize is 32
416+ __shared__ float shared0[32 ]; // shared0[i] == chunk[2*i]
417+ __shared__ float shared1[32 ]; // shared1[i] == chunk[2*i+1]
416418
417- for (int i = threadIdx .x ; i < head_dim; i += warpSize ) {
419+ for (int i = 2 * threadIdx .x ; i < head_dim; i += 2 * warpSize ) {
418420 if (warp_id == 0 ) {
419- shared[threadIdx .x ] = 0 ;
421+ shared0[threadIdx .x ] = 0 ;
422+ shared1[threadIdx .x ] = 0 ;
420423 }
421424 __syncthreads ();
422- float sum = 0.0 ;
425+ float2 sum01 = make_float2 ( 0.0 , 0.0 ) ;
423426 for (int t = warp_id; t < seq_len; t += t_stride) {
424- sum += __half2float (vh[kv_stride * t + i]) * atth[t];
427+ float2 v01 = __half22float2 (*((half2*)&vh[kv_stride * t + i]));
428+ float att_t = atth[t];
429+ // Sadly CUDA does not have float2 SIMD ops
430+ sum01.x += v01.x * att_t ;
431+ sum01.y += v01.y * att_t ;
425432 }
426- atomicAdd (&shared[threadIdx .x ], sum);
433+ atomicAdd (&shared0[threadIdx .x ], sum01.x );
434+ atomicAdd (&shared1[threadIdx .x ], sum01.y );
427435 __syncthreads ();
428436 if (warp_id == 0 ) {
429- outh[i] = shared[threadIdx .x ];
430- shared[threadIdx .x ] = 0 ;
437+ float even = shared0[threadIdx .x ];
438+ float odd = shared1[threadIdx .x ];
439+ *((float2 *)&outh[i]) = make_float2 (even, odd);
440+ shared0[threadIdx .x ] = 0 ;
441+ shared1[threadIdx .x ] = 0 ;
431442 }
432443 }
433444}
@@ -459,10 +470,12 @@ inline void rope(
459470 float fcr = cosf (val);
460471 float fci = sinf (val);
461472
462- float v0 = x[pair_idx];
463- float v1 = x[pair_idx + 1 ];
464- out[pair_idx] = v0 * fcr - v1 * fci;
465- out[pair_idx + 1 ] = v0 * fci + v1 * fcr;
473+ float2 v01 = *((float2 *)&x[pair_idx]);
474+ float2 result = make_float2 (
475+ v01.x * fcr - v01.y * fci,
476+ v01.x * fci + v01.y * fcr
477+ );
478+ *((float2 *)&out[pair_idx]) = result;
466479 }
467480}
468481
@@ -477,10 +490,12 @@ inline void rope(
477490 float fcr = cosf (val);
478491 float fci = sinf (val);
479492
480- float v0 = x[pair_idx];
481- float v1 = x[pair_idx + 1 ];
482- out[pair_idx] = __float2half (v0 * fcr - v1 * fci);
483- out[pair_idx + 1 ] = __float2half (v0 * fci + v1 * fcr);
493+ float2 v01 = *((float2 *)&x[pair_idx]);
494+ half2 result = __floats2half2_rn (
495+ v01.x * fcr - v01.y * fci,
496+ v01.x * fci + v01.y * fcr
497+ );
498+ *((half2*)&out[pair_idx]) = result;
484499 }
485500}
486501
@@ -495,10 +510,12 @@ inline void rope(
495510 float fcr = cosf (val);
496511 float fci = sinf (val);
497512
498- float v0 = __half2float (x[pair_idx]);
499- float v1 = __half2float (x[pair_idx + 1 ]);
500- out[pair_idx] = __float2half (v0 * fcr - v1 * fci);
501- out[pair_idx + 1 ] = __float2half (v0 * fci + v1 * fcr);
513+ float2 v01 = __half22float2 (*((half2*)&x[pair_idx]));
514+ half2 result = __floats2half2_rn (
515+ v01.x * fcr - v01.y * fci,
516+ v01.x * fci + v01.y * fcr
517+ );
518+ *((half2*)&out[pair_idx]) = result;
502519 }
503520}
504521
0 commit comments