Commit 19a2191

f16 kv cache

1 parent 600a77c

File tree

5 files changed: +130 −67

src/infer.cpp

Lines changed: 35 additions & 21 deletions

@@ -7,6 +7,24 @@
 #include "immintrin.h"
 #include "f16cintrin.h"
 
+#if defined(__AVX2__) && defined(__F16C__)
+inline float half_to_float(f16_t x) {
+  return _cvtsh_ss(x);
+}
+inline f16_t float_to_half(float x) {
+  return _cvtss_sh(x, 0);
+}
+#else
+inline float half_to_float(f16_t x) {
+  assert(false && "float16 not supported on this platform");
+  return 0.0f;
+}
+inline f16_t float_to_half(float x) {
+  assert(false && "float16 not supported on this platform");
+  return 0;
+}
+#endif
+
 #if DEBUG_MODEL
 static std::map<std::string, std::vector<float>> _debug_map;
 std::map<std::string, std::vector<float>>& debug_map_cpu() {
@@ -164,8 +182,8 @@ void attn(
   float* xout,    // (dim,) - output vector
   float* atth,    // (kv_len,) - scratch space to hold attention scores of the sequence
   float* qh,      // (head_dim,) - query vector for this head
-  float* kh,      // (kv_len, n_kv_heads, head_dim) - buffer containing key vectors of the sequence for all KV heads
-  float* vh,      // (kv_len, n_kv_heads, head_dim) - buffer containing value vectors of the sequence for all KV heads
+  f16_t* kh,      // (kv_len, n_kv_heads, head_dim) - buffer containing key vectors of the sequence for all KV heads
+  f16_t* vh,      // (kv_len, n_kv_heads, head_dim) - buffer containing value vectors of the sequence for all KV heads
   int head_dim,   // size of the "key-space"
   int n_kv_heads, // number of kv heads, can be < n_heads (1 is MultiQueryAttention, >1 is GroupedQueryAttention)
   int kv_len      // number of tokens of the sequence we will attend over
@@ -175,7 +193,7 @@ void attn(
   for (int t = 0; t < kv_len; ++t) {
     float score = 0.0f;
     for (int i = 0; i < head_dim; ++i) {
-      score += qh[i] * kh[t * kv_stride + i];
+      score += qh[i] * half_to_float(kh[t * kv_stride + i]);
     }
     score /= sqrtf(head_dim);
     atth[t] = score;
@@ -188,7 +206,7 @@ void attn(
   for (int i = 0; i < head_dim; ++i) {
     float vi = 0.0f;
     for (int t = 0; t < kv_len; ++t) {
-      vi += atth[t] * vh[t * kv_stride + i];
+      vi += atth[t] * half_to_float(vh[t * kv_stride + i]);
     }
     xout[i] = vi;
   }
@@ -239,12 +257,12 @@ void Block::_block_cpu(
   rope(s.k(), kv_dim, c.head_dim, pos, c.rope_theta, c.rotary_dim);
 
   // key and value point to the kv cache
-  float* kb = key_cache();
-  float* vb = value_cache();
+  f16_t* kb = key_cache();
+  f16_t* vb = value_cache();
   // update kv cache
   for (int i = 0; i < kv_dim; ++i) {
-    kb[kv_pos * kv_dim + i] = s.k()[i];
-    vb[kv_pos * kv_dim + i] = s.v()[i];
+    kb[kv_pos * kv_dim + i] = float_to_half(s.k()[i]);
+    vb[kv_pos * kv_dim + i] = float_to_half(s.v()[i]);
   }
 
   // Sink tokens remain untouched while the rest of the KV cache is incrementally
@@ -253,13 +271,13 @@ void Block::_block_cpu(
   // forward by 1. See https://arxiv.org/abs/2309.17453 for more.
   for (int r = 0; r < kv_sink; r++) {
     for (int i = 0; i < kv_dim; ++i) {
-      s.k()[i] = kb[r * kv_dim + i];
+      s.k()[i] = half_to_float(kb[r * kv_dim + i]);
     }
 
     rope(s.k(), kv_dim, c.head_dim, 1, c.rope_theta, c.rotary_dim);
 
     for (int i = 0; i < kv_dim; i++) {
-      kb[r * kv_dim + i] = s.k()[i];
+      kb[r * kv_dim + i] = float_to_half(s.k()[i]);
     }
   }
 
@@ -269,8 +287,8 @@ void Block::_block_cpu(
 #pragma omp parallel for private(h)
   for (h = 0; h < c.n_heads; h++) {
     int kv_head_offset = (h / q_per_kv_head) * c.head_dim;
-    float* kh = kb + kv_head_offset;
-    float* vh = vb + kv_head_offset;
+    f16_t* kh = kb + kv_head_offset;
+    f16_t* vh = vb + kv_head_offset;
     attn(s.xb2(h), s.att(h), s.q(h), kh, vh, c.head_dim, c.n_kv_heads, kv_len);
   }
 
@@ -319,8 +337,8 @@ void Block::_block_cpu(
 void mha_cpu(
   float* xout, // (n_heads, head_dim)
   float* att,  // (n_heads, max_seq_len)
-  float* kb,   // (max_seq_len, n_kv_heads, head_dim)
-  float* vb,   // (max_seq_len, n_kv_heads, head_dim)
+  f16_t* kb,   // (max_seq_len, n_kv_heads, head_dim)
+  f16_t* vb,   // (max_seq_len, n_kv_heads, head_dim)
   float* q,    // (n_heads, head_dim)
   int head_dim, int kv_len, int max_seq_len, int n_heads, int n_kv_heads
 ) {
@@ -330,8 +348,8 @@ void mha_cpu(
 #pragma omp parallel for private(h)
   for (h = 0; h < n_heads; h++) {
     int kv_head_offset = (h / q_per_kv_head) * head_dim;
-    float* kh = kb + kv_head_offset;
-    float* vh = vb + kv_head_offset;
+    f16_t* kh = kb + kv_head_offset;
+    f16_t* vh = vb + kv_head_offset;
     attn(
       xout + head_dim * h, att + max_seq_len * h, q + head_dim * h,
       kh, vh, head_dim, n_kv_heads, kv_len
@@ -393,14 +411,10 @@ void Model::_copy_embedding(InferenceState& s, int token) {
       break;
     }
     case DType::F16: {
-#if defined(__AVX2__) && defined(__F16C__)
       f16_t* emb = static_cast<f16_t*>(token_embedding_table);
      for (int i = 0; i < c.dim; i+=1) {
-        s.x()[i] = _cvtsh_ss(emb[token * c.dim + i]);
+        s.x()[i] = half_to_float(emb[token * c.dim + i]);
      }
-#else
-      assert(false && "float16 not supported on this platform");
-#endif
      break;
    }
    default: {
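
For readers unfamiliar with the F16C intrinsics these new helpers wrap: a minimal standalone round-trip sketch, assuming f16_t is a uint16_t holding raw IEEE-754 binary16 bits (as the wrappers above treat it) and an x86 compiler invoked with something like -mavx2 -mf16c. This is an illustration, not part of the commit.

#include <cstdint>
#include <cstdio>
#include <immintrin.h>

typedef uint16_t f16_t; // assumption: raw binary16 bits, matching the codebase

inline float half_to_float(f16_t x) { return _cvtsh_ss(x); }
inline f16_t float_to_half(float x) { return _cvtss_sh(x, 0); }

int main() {
  float x = 0.3333333f;
  f16_t h = float_to_half(x);
  // binary16 keeps an 11-bit significand, so expect roughly 3 decimal digits back
  printf("%.7f -> 0x%04x -> %.7f\n", x, (unsigned)h, half_to_float(h));
  return 0;
}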

src/infer.cu

Lines changed: 57 additions & 21 deletions

@@ -330,7 +330,7 @@ void fused_qkv_matmul_clip(
 
 __global__
 void attn(
-  const float* kb, // (max_seq_len, n_kv_heads, head_dim)
+  const half* kb,  // (max_seq_len, n_kv_heads, head_dim)
   const float* q,  // (n_heads, head_dim)
   int head_dim,
   int kv_len,
@@ -345,10 +345,10 @@ void attn(
   if (t >= kv_len || h >= n_heads) return;
 
   const float* query = q + h * head_dim;
-  const float* key = kb + n_kv_heads * head_dim * t + head_dim * group;
+  const half* key = kb + n_kv_heads * head_dim * t + head_dim * group;
   float score = 0.0;
   for (int i = 0; i < head_dim; i++) {
-    score += query[i] * key[i];
+    score += query[i] * __half2float(key[i]);
   }
   out[h * max_seq_len + t] = score / sqrtf((float)head_dim);
 }
@@ -389,7 +389,7 @@ void attn_softmax(
 
 __global__
 void att_mix(
-  const float* vb,  // (max_seq_len, n_kv_heads, head_dim)
+  const half* vb,   // (max_seq_len, n_kv_heads, head_dim)
   const float* att, // (n_heads, kv_len)
   int head_dim,
   int n_heads,
@@ -405,7 +405,7 @@ void att_mix(
   int kv_stride = n_kv_heads * head_dim;
 
   const float* atth = att + max_seq_len * h;
-  const float* vh = vb + head_dim * g;
+  const half* vh = vb + head_dim * g;
   float* outh = out + head_dim * h;
 
   int warp_id = threadIdx.y;
@@ -421,7 +421,7 @@ void att_mix(
   __syncthreads();
   float sum = 0.0;
   for (int t = warp_id; t < seq_len; t += t_stride) {
-    sum += vh[kv_stride * t + i] * atth[t];
+    sum += __half2float(vh[kv_stride * t + i]) * atth[t];
   }
   atomicAdd(&shared[threadIdx.x], sum);
   __syncthreads();
@@ -466,6 +466,42 @@ inline void rope(
   }
 }
 
+__device__
+inline void rope(
+  const float* x, int pair_idx, int head_dim, int pos, float theta, int rotary_dim, half* out
+) {
+  int j_head = pair_idx % head_dim;
+  if (j_head < head_dim - 1) { // Ensure we have a pair of elements
+    float freq = j_head >= rotary_dim ? 0.f : 1.0f / powf(theta, (float)j_head / (float)rotary_dim);
+    float val = pos * freq;
+    float fcr = cosf(val);
+    float fci = sinf(val);
+
+    float v0 = x[pair_idx];
+    float v1 = x[pair_idx + 1];
+    out[pair_idx] = __float2half(v0 * fcr - v1 * fci);
+    out[pair_idx + 1] = __float2half(v0 * fci + v1 * fcr);
+  }
+}
+
+__device__
+inline void rope(
+  const half* x, int pair_idx, int head_dim, int pos, float theta, int rotary_dim, half* out
+) {
+  int j_head = pair_idx % head_dim;
+  if (j_head < head_dim - 1) { // Ensure we have a pair of elements
+    float freq = j_head >= rotary_dim ? 0.f : 1.0f / powf(theta, (float)j_head / (float)rotary_dim);
+    float val = pos * freq;
+    float fcr = cosf(val);
+    float fci = sinf(val);
+
+    float v0 = __half2float(x[pair_idx]);
+    float v1 = __half2float(x[pair_idx + 1]);
+    out[pair_idx] = __float2half(v0 * fcr - v1 * fci);
+    out[pair_idx + 1] = __float2half(v0 * fci + v1 * fcr);
+  }
+}
+
 template <ActivationType A> __device__ inline float act(float x);
 template<> __device__ inline float act<ActivationType::SILU>(float x) {
   return x / (1.0f + expf(-x));
@@ -538,8 +574,8 @@ void fused_rope_and_cache_update(
   float theta,    // RoPE theta parameter
   int rotary_dim, // how many dimensions to rotate
   float* q_out,   // (n_heads * head_dim,)
-  float* kb,      // (max_seq_len, n_kv_heads, head_dim)
-  float* vb       // (max_seq_len, n_kv_heads, head_dim)
+  half* kb,       // (max_seq_len, n_kv_heads, head_dim)
+  half* vb        // (max_seq_len, n_kv_heads, head_dim)
 ) {
   // Each thread handles two consecutive elements (for RoPE complex rotation)
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -555,7 +591,7 @@ void fused_rope_and_cache_update(
 
   // Handle K matrix RoPE and cache update
   if (pair_idx < n_kv_heads * head_dim) {
-    float* k_out = &kb[kv_pos * (n_kv_heads * head_dim)];
+    half* k_out = &kb[kv_pos * (n_kv_heads * head_dim)];
     rope(
       k, pair_idx, head_dim, pos,
       theta, rotary_dim, k_out
@@ -566,15 +602,15 @@ void fused_rope_and_cache_update(
   if (pair_idx < n_kv_heads * head_dim) {
     int cache_idx = kv_pos * (n_kv_heads * head_dim) + pair_idx;
     if (pair_idx < n_kv_heads * head_dim - 1) {
-      vb[cache_idx] = v[pair_idx];
-      vb[cache_idx + 1] = v[pair_idx + 1];
+      vb[cache_idx] = __float2half(v[pair_idx]);
+      vb[cache_idx + 1] = __float2half(v[pair_idx + 1]);
     }
   }
 }
 
 __global__
 void rotate_sink_tokens(
-  float* kb,
+  half* kb,
   int kv_sink,  // number of attention sinks
   int kv_dim,   // size of each entry (all concatenated heads) in KV cache
   int head_dim,
@@ -588,7 +624,7 @@ void rotate_sink_tokens(
 
   if (pair_idx < kv_dim) {
     for (int r = 0; r < kv_sink; r++) {
-      float* k = kb + r * kv_dim;
+      half* k = kb + r * kv_dim;
      rope(k, pair_idx, head_dim, 1, theta, rotary_dim, k);
    }
  }
@@ -635,8 +671,8 @@ void Block::_block_cuda(
   // Update Q, K with RoPE relative positional encoding:
   // complex-valued rotate q and k in each head
   // Also copy K, V to KV cache
-  float* kb = key_cache();
-  float* vb = value_cache();
+  half* kb = (half*)key_cache();
+  half* vb = (half*)value_cache();
   {
     // Calculate number of thread blocks needed
     // We need enough threads to handle the largest of:
@@ -749,8 +785,8 @@ void Block::_block_cuda(
 void mha_cuda(
   float* xout, // (n_heads, head_dim)
   float* att,  // (n_heads, max_seq_len)
-  float* kb,   // (max_seq_len, n_kv_heads, head_dim)
-  float* vb,   // (max_seq_len, n_kv_heads, head_dim)
+  f16_t* kb,   // (max_seq_len, n_kv_heads, head_dim)
+  f16_t* vb,   // (max_seq_len, n_kv_heads, head_dim)
   float* q,    // (n_heads, head_dim)
   int head_dim, int kv_len, int max_seq_len, int n_heads, int n_kv_heads
 ) {
@@ -759,8 +795,8 @@ void mha_cuda(
   // all cuda uploads leak forever...
   register_cuda_host(xout, n_heads * head_dim * sizeof(float));
   register_cuda_host(att, n_heads * max_seq_len * sizeof(float));
-  kb = static_cast<float*>(upload_cuda(kb, max_seq_len * n_kv_heads * head_dim * sizeof(float)));
-  vb = static_cast<float*>(upload_cuda(vb, max_seq_len * n_kv_heads * head_dim * sizeof(float)));
+  kb = static_cast<f16_t*>(upload_cuda(kb, max_seq_len * n_kv_heads * head_dim * sizeof(f16_t)));
+  vb = static_cast<f16_t*>(upload_cuda(vb, max_seq_len * n_kv_heads * head_dim * sizeof(f16_t)));
   q = static_cast<float*>(upload_cuda(q, n_heads * head_dim * sizeof(float)));
   // multihead attention: dot products and softmax
   {
@@ -771,7 +807,7 @@ void mha_cuda(
     blocks.x = (kv_len + tpb.x - 1) / tpb.x;
     blocks.y = (n_heads + tpb.y - 1) / tpb.y;
     attn<<<blocks, tpb>>>(
-      kb, q, head_dim, kv_len, max_seq_len, n_heads, n_kv_heads, att
+      (half*)kb, q, head_dim, kv_len, max_seq_len, n_heads, n_kv_heads, att
    );
    attn_softmax<<<n_heads, warp_size>>>(
      att, kv_len, max_seq_len, n_heads, att
@@ -785,7 +821,7 @@ void mha_cuda(
     dim3 blocks;
     blocks.x = n_heads;
     att_mix<<<blocks, tpb>>>(
-      vb, att,
+      (half*)vb, att,
      head_dim, n_heads, n_kv_heads,
      kv_len, max_seq_len, xout
    );
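
Note the pattern shared by attn, att_mix, and the rope overloads above: the KV cache is stored as half, but every arithmetic step widens to float first, so scores and weighted sums still accumulate in fp32. A standalone sketch of that pattern (a hypothetical kernel for illustration, not part of this commit):

#include <cuda_fp16.h>

// Dot product of an fp16-stored vector against an fp32 query.
// Launch <<<1, 1>>> for clarity; the real kernels parallelize over t and h.
__global__ void dot_fp16_kv(const half* kv, const float* q, int n, float* out) {
  float sum = 0.0f;                     // fp32 accumulator
  for (int i = 0; i < n; i++) {
    sum += q[i] * __half2float(kv[i]);  // widen each fp16 load before the multiply
  }
  *out = sum;
}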

src/model.cpp

Lines changed: 5 additions & 5 deletions

@@ -73,7 +73,7 @@ size_t Config::active_bytes(size_t pos) const {
   bytes_per_block += n_heads * head_dim * dim * weight_size; // wo
   bytes_per_block += 3 * dim * hidden_dim * weight_size; // w1, w2, w3
   size_t kv_len = std::min(static_cast<size_t>(max_seq_len), pos + 1);
-  size_t kv_entry_size = sizeof(float);
+  size_t kv_entry_size = sizeof(f16_t);
   bytes_per_block += 2 * kv_len * n_kv_heads * head_dim * kv_entry_size; // key_cache, value_cache
 
   size_t bytes = 0;
@@ -174,8 +174,8 @@ Block::Block(
     w3, config->weight_dtype, {config->hidden_dim, config->dim, 0, 0}
   );
 
-  _key_cache = new float[config->max_seq_len * config->n_kv_heads * config->head_dim]();
-  _value_cache = new float[config->max_seq_len * config->n_kv_heads * config->head_dim]();
+  _key_cache = new f16_t[config->max_seq_len * config->n_kv_heads * config->head_dim]();
+  _value_cache = new f16_t[config->max_seq_len * config->n_kv_heads * config->head_dim]();
 }
 
 Block::~Block() {
@@ -210,8 +210,8 @@ void Block::cuda() {
   _w3 = upload_cuda(_w3, _config->hidden_dim * _config->dim * weight_size);
 
   // kv cache
-  _key_cache = static_cast<float*>(upload_cuda(_key_cache, _config->max_seq_len * _config->n_kv_heads * _config->head_dim * sizeof(float)));
-  _value_cache = static_cast<float*>(upload_cuda(_value_cache, _config->max_seq_len * _config->n_kv_heads * _config->head_dim * sizeof(float)));
+  _key_cache = static_cast<f16_t*>(upload_cuda(_key_cache, _config->max_seq_len * _config->n_kv_heads * _config->head_dim * sizeof(f16_t)));
+  _value_cache = static_cast<f16_t*>(upload_cuda(_value_cache, _config->max_seq_len * _config->n_kv_heads * _config->head_dim * sizeof(f16_t)));
 }
 
 void Block::block(
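
The payoff of sizeof(f16_t) in the allocation and accounting above: each KV cache entry drops from 4 bytes to 2. As a worked example with hypothetical Llama-7B-style numbers (32 blocks, max_seq_len = 4096, n_kv_heads = 32, head_dim = 128), the full-length cache goes from 32 × 2 × 4096 × 32 × 128 × 4 B = 4 GiB in fp32 to 2 GiB in fp16, and active_bytes shrinks proportionally at every pos.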
