@@ -330,7 +330,7 @@ void fused_qkv_matmul_clip(
330330
331331__global__
332332void attn (
333- const float * kb, // (max_seq_len, n_kv_heads, head_dim)
333+ const half * kb, // (max_seq_len, n_kv_heads, head_dim)
334334 const float * q, // (n_heads, head_dim)
335335 int head_dim,
336336 int kv_len,
@@ -345,10 +345,10 @@ void attn(
345345 if (t >= kv_len || h >= n_heads) return ;
346346
347347 const float * query = q + h * head_dim;
348- const float * key = kb + n_kv_heads * head_dim * t + head_dim * group;
348+ const half * key = kb + n_kv_heads * head_dim * t + head_dim * group;
349349 float score = 0.0 ;
350350 for (int i = 0 ; i < head_dim; i++) {
351- score += query[i] * key[i];
351+ score += query[i] * __half2float ( key[i]) ;
352352 }
353353 out[h * max_seq_len + t] = score / sqrtf ((float )head_dim);
354354}
@@ -389,7 +389,7 @@ void attn_softmax(
389389
390390__global__
391391void att_mix (
392- const float * vb, // (max_seq_len, n_kv_heads, head_dim)
392+ const half * vb, // (max_seq_len, n_kv_heads, head_dim)
393393 const float * att, // (n_heads, kv_len)
394394 int head_dim,
395395 int n_heads,
@@ -405,7 +405,7 @@ void att_mix(
405405 int kv_stride = n_kv_heads * head_dim;
406406
407407 const float * atth = att + max_seq_len * h;
408- const float * vh = vb + head_dim * g;
408+ const half * vh = vb + head_dim * g;
409409 float * outh = out + head_dim * h;
410410
411411 int warp_id = threadIdx .y ;
@@ -421,7 +421,7 @@ void att_mix(
421421 __syncthreads ();
422422 float sum = 0.0 ;
423423 for (int t = warp_id; t < seq_len; t += t_stride) {
424- sum += vh[kv_stride * t + i] * atth[t];
424+ sum += __half2float ( vh[kv_stride * t + i]) * atth[t];
425425 }
426426 atomicAdd (&shared[threadIdx .x ], sum);
427427 __syncthreads ();
@@ -466,6 +466,42 @@ inline void rope(
466466 }
467467}
468468
// Rotate one (even, odd) element pair of `x` by its RoPE angle and store the
// result into `out` as fp16.
// x         - fp32 source vector (concatenated heads)
// pair_idx  - index of the first element of the pair within `x`/`out`
// head_dim  - elements per head; pair_idx % head_dim locates the pair in its head
// pos       - token position used to scale the rotation angle
// theta     - RoPE base frequency parameter
// rotary_dim- head dimensions beyond this are passed through unrotated
// out       - fp16 destination (same layout as `x`)
__device__
inline void rope(
  const float* x, int pair_idx, int head_dim, int pos, float theta, int rotary_dim, half* out
) {
  int j_head = pair_idx % head_dim;
  // Guard: the pair must lie entirely inside this head.
  if (j_head >= head_dim - 1) return;

  // freq = 0 for dims past rotary_dim -> angle 0 -> identity copy of the pair.
  float freq = (j_head < rotary_dim)
      ? 1.0f / powf(theta, (float)j_head / (float)rotary_dim)
      : 0.f;
  float angle = pos * freq;
  float c = cosf(angle);
  float s = sinf(angle);

  // Complex rotation: (re + i*im) * (c + i*s)
  float re = x[pair_idx];
  float im = x[pair_idx + 1];
  out[pair_idx]     = __float2half(re * c - im * s);
  out[pair_idx + 1] = __float2half(re * s + im * c);
}
486+
// Rotate one (even, odd) fp16 element pair of `x` by its RoPE angle and store
// the result into `out` as fp16. `x` and `out` may alias (in-place rotation of
// cache entries), so both inputs are loaded before either output is written.
// x         - fp16 source vector (concatenated heads)
// pair_idx  - index of the first element of the pair within `x`/`out`
// head_dim  - elements per head; pair_idx % head_dim locates the pair in its head
// pos       - token position used to scale the rotation angle
// theta     - RoPE base frequency parameter
// rotary_dim- head dimensions beyond this are passed through unrotated
// out       - fp16 destination (same layout as `x`)
__device__
inline void rope(
  const half* x, int pair_idx, int head_dim, int pos, float theta, int rotary_dim, half* out
) {
  int j_head = pair_idx % head_dim;
  // Guard: the pair must lie entirely inside this head.
  if (j_head >= head_dim - 1) return;

  // freq = 0 for dims past rotary_dim -> angle 0 -> identity copy of the pair.
  float freq = (j_head < rotary_dim)
      ? 1.0f / powf(theta, (float)j_head / (float)rotary_dim)
      : 0.f;
  float angle = pos * freq;
  float c = cosf(angle);
  float s = sinf(angle);

  // Read both halves first: callers pass the same pointer for x and out
  // when rotating cache entries in place.
  float re = __half2float(x[pair_idx]);
  float im = __half2float(x[pair_idx + 1]);
  out[pair_idx]     = __float2half(re * c - im * s);
  out[pair_idx + 1] = __float2half(re * s + im * c);
}
504+
469505template <ActivationType A> __device__ inline float act (float x);
470506template <> __device__ inline float act<ActivationType::SILU>(float x) {
471507 return x / (1 .0f + expf (-x));
@@ -538,8 +574,8 @@ void fused_rope_and_cache_update(
538574 float theta, // RoPE theta parameter
539575 int rotary_dim, // how many dimensions to rotate
540576 float * q_out, // (n_heads * head_dim,)
541- float * kb, // (max_seq_len, n_kv_heads, head_dim)
542- float * vb // (max_seq_len, n_kv_heads, head_dim)
577+ half * kb, // (max_seq_len, n_kv_heads, head_dim)
578+ half * vb // (max_seq_len, n_kv_heads, head_dim)
543579) {
544580 // Each thread handles two consecutive elements (for RoPE complex rotation)
545581 int tid = blockIdx .x * blockDim .x + threadIdx .x ;
@@ -555,7 +591,7 @@ void fused_rope_and_cache_update(
555591
556592 // Handle K matrix RoPE and cache update
557593 if (pair_idx < n_kv_heads * head_dim) {
558- float * k_out = &kb[kv_pos * (n_kv_heads * head_dim)];
594+ half * k_out = &kb[kv_pos * (n_kv_heads * head_dim)];
559595 rope (
560596 k, pair_idx, head_dim, pos,
561597 theta, rotary_dim, k_out
@@ -566,15 +602,15 @@ void fused_rope_and_cache_update(
566602 if (pair_idx < n_kv_heads * head_dim) {
567603 int cache_idx = kv_pos * (n_kv_heads * head_dim) + pair_idx;
568604 if (pair_idx < n_kv_heads * head_dim - 1 ) {
569- vb[cache_idx] = v[pair_idx];
570- vb[cache_idx + 1 ] = v[pair_idx + 1 ];
605+ vb[cache_idx] = __float2half ( v[pair_idx]) ;
606+ vb[cache_idx + 1 ] = __float2half ( v[pair_idx + 1 ]) ;
571607 }
572608 }
573609}
574610
575611__global__
576612void rotate_sink_tokens (
577- float * kb,
613+ half * kb,
578614 int kv_sink, // number of attention sinks
579615 int kv_dim, // size of each entry (all concatenated heads) in KV cache
580616 int head_dim,
@@ -588,7 +624,7 @@ void rotate_sink_tokens(
588624
589625 if (pair_idx < kv_dim) {
590626 for (int r = 0 ; r < kv_sink; r++) {
591- float * k = kb + r * kv_dim;
627+ half * k = kb + r * kv_dim;
592628 rope (k, pair_idx, head_dim, 1 , theta, rotary_dim, k);
593629 }
594630 }
@@ -635,8 +671,8 @@ void Block::_block_cuda(
635671 // Update Q, K with RoPE relative positional encoding:
636672 // complex-valued rotate q and k in each head
637673 // Also copy K, V to KV cache
638- float * kb = key_cache ();
639- float * vb = value_cache ();
674+ half * kb = (half*) key_cache ();
675+ half * vb = (half*) value_cache ();
640676 {
641677 // Calculate number of thread blocks needed
642678 // We need enough threads to handle the largest of:
@@ -749,8 +785,8 @@ void Block::_block_cuda(
749785void mha_cuda (
750786 float * xout, // (n_heads, head_dim)
751787 float * att, // (n_heads, max_seq_len)
752- float * kb, // (max_seq_len, n_kv_heads, head_dim)
753- float * vb, // (max_seq_len, n_kv_heads, head_dim)
788+ f16_t * kb, // (max_seq_len, n_kv_heads, head_dim)
789+ f16_t * vb, // (max_seq_len, n_kv_heads, head_dim)
754790 float * q, // (n_heads, head_dim)
755791 int head_dim, int kv_len, int max_seq_len, int n_heads, int n_kv_heads
756792) {
@@ -759,8 +795,8 @@ void mha_cuda(
759795 // all cuda uploads leak forever...
760796 register_cuda_host (xout, n_heads * head_dim * sizeof (float ));
761797 register_cuda_host (att, n_heads * max_seq_len * sizeof (float ));
762- kb = static_cast <float *>(upload_cuda (kb, max_seq_len * n_kv_heads * head_dim * sizeof (float )));
763- vb = static_cast <float *>(upload_cuda (vb, max_seq_len * n_kv_heads * head_dim * sizeof (float )));
798+ kb = static_cast <f16_t *>(upload_cuda (kb, max_seq_len * n_kv_heads * head_dim * sizeof (f16_t )));
799+ vb = static_cast <f16_t *>(upload_cuda (vb, max_seq_len * n_kv_heads * head_dim * sizeof (f16_t )));
764800 q = static_cast <float *>(upload_cuda (q, n_heads * head_dim * sizeof (float )));
765801 // multihead attention: dot products and softmax
766802 {
@@ -771,7 +807,7 @@ void mha_cuda(
771807 blocks.x = (kv_len + tpb.x - 1 ) / tpb.x ;
772808 blocks.y = (n_heads + tpb.y - 1 ) / tpb.y ;
773809 attn<<<blocks, tpb>>> (
774- kb, q, head_dim, kv_len, max_seq_len, n_heads, n_kv_heads, att
810+ (half*) kb, q, head_dim, kv_len, max_seq_len, n_heads, n_kv_heads, att
775811 );
776812 attn_softmax<<<n_heads, warp_size>>> (
777813 att, kv_len, max_seq_len, n_heads, att
@@ -785,7 +821,7 @@ void mha_cuda(
785821 dim3 blocks;
786822 blocks.x = n_heads;
787823 att_mix<<<blocks, tpb>>> (
788- vb, att,
824+ (half*) vb, att,
789825 head_dim, n_heads, n_kv_heads,
790826 kv_len, max_seq_len, xout
791827 );
0 commit comments