diff --git a/Makefile b/Makefile index 27283ba0..eae8ad58 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ help: @echo " make clean Remove build outputs" cuda-spark: - $(MAKE) ds4 ds4-server ds4-bench ds4-eval ds4-agent CUDA_ARCH= + $(MAKE) ds4 ds4-server ds4-bench ds4-eval ds4-agent CUDA_ARCH=sm_120 cuda-generic: $(MAKE) ds4 ds4-server ds4-bench ds4-eval ds4-agent CUDA_ARCH=native diff --git a/ds4.c b/ds4.c index 21573fb8..2e5dd9b1 100644 --- a/ds4.c +++ b/ds4.c @@ -1918,6 +1918,585 @@ static void dsv4_fp8_kv_quantize_row_inplace_cpu(float *x, uint32_t head_dim, ui } } +/* ========================================================================= + * TurboQuant+ turbo3 KV quality simulation. + * ========================================================================= + * + * Sibling of dsv4_fp8_kv_quantize_row_inplace_cpu. Same group-of-64 structure; + * same in-place "consume float row, write float row with quant error baked in" + * contract; same RoPE-tail (last n_rot elements) left untouched. + * + * Per group the round trip is: + * 1. Randomized Hadamard rotation: signs1 (Rademacher mask) -> 64-point WHT + * -> 1/sqrt(64) normalize -> signs2 (Rademacher mask). Gaussianizes the + * per-coordinate distribution (Lindeberg-CLT) so a fixed Lloyd-Max codebook + * for N(0,1) attains near-MSE-optimal distortion regardless of the input + * activation distribution. Both sign tables come from Python's + * random.Random(seed) with Bernoulli(0.5) - deterministic and documented + * below. + * 2. Per-group amax -> scale = TURBO3_MAX / amax (so |scaled value| <= MAX). + * 3. 3-bit Lloyd-Max quant (8 levels) for N(0,1), nearest-centroid index 0..7. + * 4. Matched-norm L2 correction: replace the amax scale with + * ||original|| / ||centroid_recon|| so the dequantized group has the same + * L2 norm as the input - frees ~0.5% PPL on average vs amax-only. We + * clamp the scale to the FP8 E4M3 representable range so a Metal port + * that stores the scale as packed FP8 stays bit-equivalent. + * 5. Lookup dequantized centroid * matched-norm scale. + * 6. Inverse rotation: signs2 -> 64-point WHT -> 1/sqrt(64) -> signs1. WHT + * is its own inverse but (S2.H.S1).(S2.H.S1) != I, so the forward+inverse + * pair needs swapped sign order; applied here in one closed-form sequence + * so the row stays in the original basis. + * + * The diagnostic switch DS4_TURBO_NO_SIGNS=1 in the environment drops the + * Rademacher masks (all-positive signs). Plain WHT is strictly weaker than + * the canonical Randomized Hadamard form and is provided only for A/B testing + * the sign contribution to PPL - not a release path. + * + * Prior art chain: Google TurboQuant (arXiv:2504.19874, ICLR 2026) -> + * TheTom/turboquant_plus umbrella -> TheTom/llama-cpp-turboquant engine + * reference. The 128/256/512 sign tables used elsewhere in TQ+ are + * byte-identical to that fork; the 64-element tables below are derived by the + * same Bernoulli(0.5) recipe and are documented here as the canonical ds4 + * reference. */ + +/* Lloyd-Max 8-level codebook for N(0,1). MSE = 0.03454 vs raw N(0,1) sample. + * Centroids and decision boundaries match TURBO3_CODEBOOK / TURBO3_BOUNDS in + * Atlas's reshape_and_cache_turbo.cu and the corresponding LUT in the + * paged_decode_attn_turbo3 dequant path. */ +static const float DS4_TURBO3_CODEBOOK[8] = { + -2.1520f, -1.3440f, -0.7560f, -0.2451f, 0.2451f, 0.7560f, 1.3440f, 2.1520f +}; +static const float DS4_TURBO3_BOUNDS[7] = { + -1.748f, -1.050f, -0.501f, 0.0f, 0.501f, 1.050f, 1.748f +}; +#define DS4_TURBO3_MAX 2.1520f + +/* FP8 E4M3 max representable. Matched-norm scale is clamped here so a future + * Metal storage path can pack the scale into one FP8 byte per 64-element group + * without an extra renormalization pass. */ +#define DS4_FP8_E4M3_MAX 448.0f + +/* Two-sided Rademacher signs for the 64-point WHT. Generated by Python's + * random.Random(seed) with Bernoulli(0.5) mapped to {-1,+1}. Same canonical + * approach as the 128/256/512 tables vendored into TheTom/llama-cpp-turboquant + * - only the length and seed differ, picked here to match ds4's natural + * per-group cadence of 64. + * + * Why these arrays are static const float (not __device__ __constant__): the + * CPU reference path uses them directly. The CUDA kernel (ds4_cuda.cu) re- + * emits identical tables as __device__ __constant__ at file scope - the two + * sources are kept byte-equivalent by inspection and verified by the unit test + * that compares CPU vs GPU round-trip on a fixed seed. */ +static const float DS4_TURBO_SIGNS1_64[64] = { + +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +1.0f, + -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, + +1.0f, +1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +}; +static const float DS4_TURBO_SIGNS2_64[64] = { + +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, + +1.0f, +1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +}; + +/* Walsh-Hadamard butterfly on a 64-element scratch. Self-inverse up to the + * 1/sqrt(64) normalization (applied once by the caller in each direction). */ +static void dsv4_turbo3_wht64_inplace_cpu(float *v) { + for (uint32_t stride = 1; stride < 64; stride <<= 1) { + for (uint32_t base = 0; base < 64; base += 2u * stride) { + for (uint32_t i = 0; i < stride; i++) { + const float a = v[base + i]; + const float b = v[base + stride + i]; + v[base + i] = a + b; + v[base + stride + i] = a - b; + } + } + } +} + +/* Nearest-centroid lookup against DS4_TURBO3_BOUNDS. Branchless-ish: one + * 3-way descent matching the binary search in Atlas's turbo3_quantize. */ +static int dsv4_turbo3_quantize_index_cpu(float x) { + int idx; + if (x >= DS4_TURBO3_BOUNDS[3]) { /* >= 0 */ + idx = 4; + if (x >= DS4_TURBO3_BOUNDS[5]) { idx = 6; if (x >= DS4_TURBO3_BOUNDS[6]) idx = 7; } + else if (x >= DS4_TURBO3_BOUNDS[4]) idx = 5; + } else { + idx = 0; + if (x >= DS4_TURBO3_BOUNDS[1]) { idx = 2; if (x >= DS4_TURBO3_BOUNDS[2]) idx = 3; } + else if (x >= DS4_TURBO3_BOUNDS[0]) idx = 1; + } + return idx; +} + +/* Cached env query for the diagnostic no-signs switch. Reads once per process. + * Returns 1 when the env var is set to a non-empty, non-"0" value. */ +static int dsv4_turbo_signs_enabled_cpu(void) { + static int cached = -1; + if (cached < 0) { + const char *s = getenv("DS4_TURBO_NO_SIGNS"); + cached = (s && s[0] && !(s[0] == '0' && s[1] == 0)) ? 0 : 1; + } + return cached; +} + +/* In-place turbo3 quality round trip on one MLA latent KV row. RoPE tail (last + * n_rot elements) is left untouched, matching dsv4_fp8_kv_quantize_row_inplace_cpu. + * + * Group boundaries align with the existing FP8 path (every 64 floats), so the + * surrounding cache code is dtype-agnostic. Storage layout unchanged. */ +static void dsv4_turbo3_kv_quantize_row_inplace_cpu(float *x, uint32_t head_dim, uint32_t n_rot) { + const uint32_t n_nope = head_dim - n_rot; + const int signs_on = dsv4_turbo_signs_enabled_cpu(); + float buf[64]; + + for (uint32_t off = 0; off < n_nope; off += 64) { + /* 1. forward Randomized Hadamard rotation. */ + if (signs_on) { + for (uint32_t i = 0; i < 64; i++) buf[i] = x[off + i] * DS4_TURBO_SIGNS1_64[i]; + } else { + for (uint32_t i = 0; i < 64; i++) buf[i] = x[off + i]; + } + dsv4_turbo3_wht64_inplace_cpu(buf); + /* WHT normalization 1/sqrt(64). Cast the 64 to float so the divide is + * unambiguously float-by-float - clang-tidy bugprone-integer-division + * has been noisy on the literal form in other parts of the tree. */ + const float inv_sqrt_n = 1.0f / sqrtf(64.0f); + for (uint32_t i = 0; i < 64; i++) buf[i] *= inv_sqrt_n; + if (signs_on) { + for (uint32_t i = 0; i < 64; i++) buf[i] *= DS4_TURBO_SIGNS2_64[i]; + } + + /* 2. per-group amax -> scale into the codebook range. */ + float amax = 0.0f, norm_sq = 0.0f; + for (uint32_t i = 0; i < 64; i++) { + const float v = buf[i]; + const float av = fabsf(v); + if (av > amax) amax = av; + norm_sq += v * v; + } + const float k_inv = (amax > 1e-12f) ? (DS4_TURBO3_MAX / amax) : 1.0f; + + /* 3+4. Lloyd-Max quantize and accumulate centroid recon L2 norm. + * Matched-norm scale = ||original|| / ||centroid_recon||. Falls back + * to the amax scale on degenerate (all-zero or all-clipped) groups. */ + int idx[64]; + float recon_sq = 0.0f; + for (uint32_t i = 0; i < 64; i++) { + idx[i] = dsv4_turbo3_quantize_index_cpu(buf[i] * k_inv); + const float c = DS4_TURBO3_CODEBOOK[idx[i]]; + recon_sq += c * c; + } + const float recon_norm = sqrtf(recon_sq); + float scale = (recon_norm > 1e-10f) ? (sqrtf(norm_sq) / recon_norm) : (amax / DS4_TURBO3_MAX); + if (scale > DS4_FP8_E4M3_MAX) scale = DS4_FP8_E4M3_MAX; + + /* 5. dequantize centroid * matched-norm scale, writing back into buf. */ + for (uint32_t i = 0; i < 64; i++) { + buf[i] = DS4_TURBO3_CODEBOOK[idx[i]] * scale; + } + + /* 6. inverse rotation: signs2 -> WHT -> 1/sqrt(64) -> signs1. WHT is + * its own inverse so the same butterfly runs, but the sign order swaps. + * The first signs2 mask cancels the post-WHT signs2 from the forward + * pass; the trailing signs1 cancels the pre-WHT signs1. */ + if (signs_on) { + for (uint32_t i = 0; i < 64; i++) buf[i] *= DS4_TURBO_SIGNS2_64[i]; + } + dsv4_turbo3_wht64_inplace_cpu(buf); + const float inv_sqrt_n2 = 1.0f / sqrtf(64.0f); + for (uint32_t i = 0; i < 64; i++) buf[i] *= inv_sqrt_n2; + if (signs_on) { + for (uint32_t i = 0; i < 64; i++) buf[i] *= DS4_TURBO_SIGNS1_64[i]; + } + + for (uint32_t i = 0; i < 64; i++) x[off + i] = buf[i]; + } +} + +/* ========================================================================= + * Packed-turbo3 byte-level helpers. + * ========================================================================= + * + * `pack_group64`: take 64 floats (already WHT-rotated in the group basis), + * matched-norm L2 quantize them, write 24 bytes packed data + 1 FP8 scale. + * + * `unpack_group64`: inverse - take 24 bytes + 1 FP8 scale, expand to 64 + * rotated-basis floats (centroid * scale), then apply iWHT-with-signs to + * return values in the original basis. + * + * The two together give a lossless-modulo-FP8-scale round trip: pack(unpack(B)) + * recovers B byte-for-byte; unpack(pack(F)) gives F * (1 + per-group quant + * error). The dequant on the read path goes: + * 24 bytes data, 1 byte scale --(LUT + FP8 cvt + mul)--> 64 rotated floats + * --(iWHT-with-signs + 1/sqrt(64) + signs1)--> 64 original-basis floats + * which matches what the in-place float-sim quantizer above wrote for the + * same input, modulo the FP8 scale's E4M3 precision (~12% per group). */ + +static unsigned char dsv4_turbo3_float_to_fp8_e4m3_cpu(float x) { + /* Match Atlas's `float_to_fp8` (sat to E4M3 max=448) via the matching CPU + * E4M3 dequant table search. Stored as the nearest E4M3 grid value's + * 0..127 index encoded as a single byte; sign always positive here since + * matched-norm scale is non-negative. We follow the same convention as + * `dsv4_e4m3fn_dequant_cpu` above so the CPU and CUDA paths agree. */ + if (!(x > 0.0f)) return 0u; /* NaN-safe: matched_scale clamped > 0 */ + if (x > 448.0f) x = 448.0f; + /* Use the existing e4m3 quantize: pick the nearest representable value via + * the same nearest-centroid binary search as `dsv4_e4m3fn_dequant_cpu`. */ + int lo = 0; + int hi = 126; + while (lo < hi) { + const int mid = (lo + hi + 1) >> 1; + if (dsv4_e4m3fn_value_cpu(mid) <= x) { + lo = mid; + } else { + hi = mid - 1; + } + } + int best = lo; + if (best < 126) { + const float best_diff = fabsf(x - dsv4_e4m3fn_value_cpu(best)); + const float next_diff = fabsf(x - dsv4_e4m3fn_value_cpu(best + 1)); + if (next_diff < best_diff || (next_diff == best_diff && ((best + 1) & 1) == 0 && (best & 1) != 0)) { + best++; + } + } + /* Encode as raw E4M3 byte: bit 7 = sign (0 here), bits 6..0 = index. + * Matches the bit pattern of `__nv_fp8_storage_t` storage. */ + return (unsigned char)(best & 0x7f); +} + +static float dsv4_turbo3_fp8_e4m3_to_float_cpu(unsigned char b) { + const int sign_bit = (b >> 7) & 1; + const int idx = b & 0x7f; + const float v = dsv4_e4m3fn_value_cpu(idx); + return sign_bit ? -v : v; +} + +/* Pack 64 rotated-basis floats into 24 data bytes + 1 FP8 scale byte. + * + * data_out: pointer to 24 contiguous bytes (the group's data slice). + * scale_out: pointer to 1 byte (the group's FP8 scale slot). + * rotated: the 64 floats in the rotated basis (already amax/norm-aware). + * + * The caller is responsible for having pre-rotated the input via the same + * WHT+signs1+signs2 pipeline used by the float-sim quantizer above - see + * dsv4_turbo3_kv_quantize_row_inplace_cpu for the canonical sequence. We + * pack here AFTER the rotation; the iWHT is applied on the read side. */ +static void dsv4_turbo3_pack_group64_cpu( + unsigned char *data_out, + unsigned char *scale_out, + const float *rotated) { + /* Per-group amax + L2 norm. amax controls the codebook scale; L2 norm + * sets the matched-norm output scale. */ + float amax = 0.0f, norm_sq = 0.0f; + for (int i = 0; i < 64; i++) { + const float v = rotated[i]; + const float av = fabsf(v); + if (av > amax) amax = av; + norm_sq += v * v; + } + const float k_inv = (amax > 1e-12f) ? (DS4_TURBO3_MAX / amax) : 1.0f; + + /* Quantize + matched-norm scale (same algorithm as the float-sim path). */ + int idx[64]; + float recon_sq = 0.0f; + for (int i = 0; i < 64; i++) { + idx[i] = dsv4_turbo3_quantize_index_cpu(rotated[i] * k_inv); + const float c = DS4_TURBO3_CODEBOOK[idx[i]]; + recon_sq += c * c; + } + const float recon_norm = sqrtf(recon_sq); + float scale = (recon_norm > 1e-10f) ? (sqrtf(norm_sq) / recon_norm) + : (amax / DS4_TURBO3_MAX); + if (scale > DS4_FP8_E4M3_MAX) scale = DS4_FP8_E4M3_MAX; + *scale_out = dsv4_turbo3_float_to_fp8_e4m3_cpu(scale); + + /* Pack 64 indices into 24 bytes: 8 indices per 3 bytes, 8 chunks. + * Layout per chunk: + * b0 = i0 | (i1 << 3) | (i2 << 6) + * b1 = (i2 >> 2) | (i3 << 1) | (i4 << 4) | (i5 << 7) + * b2 = (i5 >> 1) | (i6 << 2) | (i7 << 5) + * Identical to reshape_and_cache_flash_turbo3 in + * atlas/kernels/gb10/common/reshape_and_cache_turbo.cu. */ + for (int chunk = 0; chunk < 8; chunk++) { + const int *p = &idx[chunk * 8]; + unsigned char *b = &data_out[chunk * 3]; + b[0] = (unsigned char)((p[0]) | (p[1] << 3) | ((p[2] & 0x3) << 6)); + b[1] = (unsigned char)((p[2] >> 2) | (p[3] << 1) | (p[4] << 4) | ((p[5] & 0x1) << 7)); + b[2] = (unsigned char)((p[5] >> 1) | (p[6] << 2) | (p[7] << 5)); + } +} + +/* Unpack 24 data bytes + 1 FP8 scale into 64 rotated-basis floats. + * + * out: 64 floats in the rotated basis (centroid * scale, no iWHT). + * data_in: 24 contiguous bytes (the group's data slice). + * scale_in: 1 byte (the group's FP8 E4M3 matched-norm scale). + * + * Bit layout per 3-byte chunk (matches `nvfp4_dequant` in + * atlas/kernels/gb10/common/paged_decode_attn_turbo3_128.cu): + * i0 = b0 & 7 + * i1 = (b0 >> 3) & 7 + * i2 = ((b0 >> 6) | (b1 << 2)) & 7 + * i3 = (b1 >> 1) & 7 + * i4 = (b1 >> 4) & 7 + * i5 = ((b1 >> 7) | (b2 << 1)) & 7 + * i6 = (b2 >> 2) & 7 + * i7 = (b2 >> 5) & 7 (top 3 bits - no overflow concern) + */ +static void dsv4_turbo3_unpack_group64_rotated_cpu( + float *out, + const unsigned char *data_in, + unsigned char scale_in) { + const float scale = dsv4_turbo3_fp8_e4m3_to_float_cpu(scale_in); + for (int chunk = 0; chunk < 8; chunk++) { + const unsigned char *b = &data_in[chunk * 3]; + float *o = &out[chunk * 8]; + const unsigned int b0 = b[0], b1 = b[1], b2 = b[2]; + o[0] = DS4_TURBO3_CODEBOOK[(b0) & 0x7] * scale; + o[1] = DS4_TURBO3_CODEBOOK[(b0 >> 3) & 0x7] * scale; + o[2] = DS4_TURBO3_CODEBOOK[((b0 >> 6) | (b1 << 2)) & 0x7] * scale; + o[3] = DS4_TURBO3_CODEBOOK[(b1 >> 1) & 0x7] * scale; + o[4] = DS4_TURBO3_CODEBOOK[(b1 >> 4) & 0x7] * scale; + o[5] = DS4_TURBO3_CODEBOOK[((b1 >> 7) | (b2 << 1)) & 0x7] * scale; + o[6] = DS4_TURBO3_CODEBOOK[(b2 >> 2) & 0x7] * scale; + o[7] = DS4_TURBO3_CODEBOOK[(b2 >> 5) & 0x7] * scale; + } +} + +/* Full pack: original-basis row -> packed bytes + RoPE tail. + * + * Applies forward Randomized Hadamard rotation per 64-element group, then + * pack_group64 for data+scale. RoPE tail (last n_rot floats) copied straight + * through as little-endian floats at the end of the packed row. */ +static DS4_MAYBE_UNUSED void dsv4_turbo3_kv_pack_row_cpu( + unsigned char *dst, + const float *src, + uint32_t head_dim, + uint32_t n_rot) { + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / DS4_TURBO3_GROUP_SIZE; + const int signs_on = dsv4_turbo_signs_enabled_cpu(); + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + + float buf[DS4_TURBO3_GROUP_SIZE]; + const float inv_sqrt_n = 1.0f / sqrtf((float)DS4_TURBO3_GROUP_SIZE); + + for (uint32_t g = 0; g < n_groups; g++) { + /* Forward rotation: signs1 -> WHT -> 1/sqrt(64) -> signs2. */ + const float *gs = src + (uint64_t)g * DS4_TURBO3_GROUP_SIZE; + if (signs_on) { + for (uint32_t i = 0; i < DS4_TURBO3_GROUP_SIZE; i++) buf[i] = gs[i] * DS4_TURBO_SIGNS1_64[i]; + } else { + for (uint32_t i = 0; i < DS4_TURBO3_GROUP_SIZE; i++) buf[i] = gs[i]; + } + dsv4_turbo3_wht64_inplace_cpu(buf); + for (uint32_t i = 0; i < DS4_TURBO3_GROUP_SIZE; i++) buf[i] *= inv_sqrt_n; + if (signs_on) { + for (uint32_t i = 0; i < DS4_TURBO3_GROUP_SIZE; i++) buf[i] *= DS4_TURBO_SIGNS2_64[i]; + } + + /* Pack into the data section + the group's scale slot. */ + unsigned char *data_slot = dst + (uint64_t)g * 24u; + unsigned char *scale_slot = dst + data_bytes + (uint64_t)g; + dsv4_turbo3_pack_group64_cpu(data_slot, scale_slot, buf); + } + + /* RoPE tail: raw floats appended at the end of the packed row. */ + if (n_rot > 0) { + const uint64_t scale_bytes = (uint64_t)n_groups; + unsigned char *rope_slot = dst + data_bytes + scale_bytes; + memcpy(rope_slot, src + n_nope, (size_t)n_rot * sizeof(float)); + } +} + +/* Full unpack: packed bytes + RoPE tail -> original-basis floats. + * + * Inverse of dsv4_turbo3_kv_pack_row_cpu. Per group: unpack to rotated + * floats, then iWHT-with-signs (signs2 -> WHT -> 1/sqrt(64) -> signs1) to + * return to the original basis. RoPE tail copied straight back. */ +static DS4_MAYBE_UNUSED void dsv4_turbo3_kv_unpack_row_cpu( + float *dst, + const unsigned char *src, + uint32_t head_dim, + uint32_t n_rot) { + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / DS4_TURBO3_GROUP_SIZE; + const int signs_on = dsv4_turbo_signs_enabled_cpu(); + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const float inv_sqrt_n = 1.0f / sqrtf((float)DS4_TURBO3_GROUP_SIZE); + + float buf[DS4_TURBO3_GROUP_SIZE]; + + for (uint32_t g = 0; g < n_groups; g++) { + const unsigned char *data_slot = src + (uint64_t)g * 24u; + const unsigned char scale_slot = src[data_bytes + g]; + dsv4_turbo3_unpack_group64_rotated_cpu(buf, data_slot, scale_slot); + + /* Inverse rotation: signs2 -> WHT -> 1/sqrt(64) -> signs1. */ + if (signs_on) { + for (uint32_t i = 0; i < DS4_TURBO3_GROUP_SIZE; i++) buf[i] *= DS4_TURBO_SIGNS2_64[i]; + } + dsv4_turbo3_wht64_inplace_cpu(buf); + for (uint32_t i = 0; i < DS4_TURBO3_GROUP_SIZE; i++) buf[i] *= inv_sqrt_n; + if (signs_on) { + for (uint32_t i = 0; i < DS4_TURBO3_GROUP_SIZE; i++) buf[i] *= DS4_TURBO_SIGNS1_64[i]; + } + + float *gd = dst + (uint64_t)g * DS4_TURBO3_GROUP_SIZE; + memcpy(gd, buf, sizeof(buf)); + } + + if (n_rot > 0) { + const uint64_t scale_bytes = (uint64_t)n_groups; + const unsigned char *rope_slot = src + data_bytes + scale_bytes; + memcpy(dst + n_nope, rope_slot, (size_t)n_rot * sizeof(float)); + } +} + +/* Active KV cache dtype. Set once by ds4_engine_open from the parsed CLI flag + * and read by the dispatch helper below. File-scope so the seven existing + * cache-store sites in this file (CPU prefill, CPU streaming-decode, + * compressor-decode, MTP, restore, etc.) stay one line each - threading a + * dtype argument through the layer call chain would have touched dozens of + * inner functions for no semantic gain. + * + * Read concurrency: written exactly once before any session is created. The + * helper readers are all called from the single session thread, so no atomic + * or fence is required. */ +static ds4_kv_dtype g_ds4_kv_dtype = DS4_KV_FP8; + +static void ds4_kv_set_active_dtype(ds4_kv_dtype dtype) { g_ds4_kv_dtype = dtype; } + +/* Dtype-aware in-place round trip on one MLA latent KV row. Picks the FP8 or + * turbo3 path based on the engine-wide dtype set at open time. */ +static void ds4_kv_quantize_row_inplace_cpu(float *x, uint32_t head_dim, uint32_t n_rot) { + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + dsv4_turbo3_kv_quantize_row_inplace_cpu(x, head_dim, n_rot); + } else { + dsv4_fp8_kv_quantize_row_inplace_cpu(x, head_dim, n_rot); + } +} + +#ifndef DS4_NO_GPU +/* GPU dispatchers - same dtype-based pick, but for the CUDA tensor helpers. + * Sites in this file call these instead of the raw fp8 wrappers so a single + * dtype enum decides which kernel runs. + * + * On the turbo3 path: `_quantize_tensor_dispatch` still runs the float-sim + * round trip on `x` (kept for sites that mutate the KV tensor in place but + * then write it to a NON-raw_cache destination - e.g. the compressor pool). + * Sites that store into the per-layer raw_cache go through the + * `_packed_store_raw_tensor` path below which writes packed bytes directly + * via the pack kernel. */ +static int ds4_gpu_kv_quantize_tensor_dispatch( + ds4_gpu_tensor *x, uint32_t n_tok, uint32_t head_dim, uint32_t n_rot) { + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + return ds4_gpu_dsv4_turbo3_kv_quantize_tensor(x, n_tok, head_dim, n_rot); + } + return ds4_gpu_dsv4_fp8_kv_quantize_tensor(x, n_tok, head_dim, n_rot); +} + +/* Single-row store into the per-layer raw_cache. fp8 path is unchanged; + * turbo3 path writes packed bytes via the ring-aware batch pack kernel with + * n_tokens=1 so the underlying cache buffer can be sized smaller. */ +static int ds4_gpu_kv_store_raw_tensor_dispatch( + ds4_gpu_tensor *kv, ds4_gpu_tensor *raw_cache, + uint32_t raw_cap, uint32_t row, uint32_t head_dim, uint32_t n_rot) { + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + const uint64_t row_bytes = ds4_kv_row_bytes(head_dim, n_rot, DS4_KV_TURBO3); + return ds4_gpu_dsv4_turbo3_kv_pack_batch_tensor( + kv, raw_cache, raw_cap, row, 1, head_dim, n_rot, row_bytes); + } + return ds4_gpu_kv_fp8_store_raw_tensor(kv, raw_cache, raw_cap, row, head_dim, n_rot); +} + +/* Batch store into the per-layer raw_cache. fp8 routes to the existing f16 + * batch path; turbo3 routes to the ring-aware packed batch pack kernel. */ +static int ds4_gpu_kv_store_raw_batch_tensor_dispatch( + ds4_gpu_tensor *raw_cache, ds4_gpu_tensor *src, + uint32_t raw_cap, uint32_t pos0, uint32_t n_tokens, + uint32_t head_dim, uint32_t n_rot) { + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + const uint64_t row_bytes = ds4_kv_row_bytes(head_dim, n_rot, DS4_KV_TURBO3); + return ds4_gpu_dsv4_turbo3_kv_pack_batch_tensor( + src, raw_cache, raw_cap, pos0, n_tokens, head_dim, n_rot, row_bytes); + } + return ds4_gpu_store_raw_kv_batch_tensor(raw_cache, src, raw_cap, pos0, n_tokens, head_dim); +} + +/* Returns the tensor that attention should READ for the raw KV window. + * + * fp8: returns `raw_cache` as-is (zero overhead). + * turbo3 (packed bytes): dequants `raw_cap` rows of `raw_cache` into the + * caller-provided `scratch` (raw_cap * head_dim floats) and returns + * `scratch`. Callers that share one scratch across multiple attention + * calls in the same layer can amortize the dequant by caching the result + * for the layer/pos pair - see `metal_graph_encode_decode_layer` for the + * one-shot-per-layer pattern. Returns NULL if a dequant launch fails. */ +static ds4_gpu_tensor *ds4_gpu_kv_attention_view_dispatch( + ds4_gpu_tensor *raw_cache, ds4_gpu_tensor *scratch, + uint32_t raw_cap, uint32_t head_dim, uint32_t n_rot) { + if (g_ds4_kv_dtype != DS4_KV_TURBO3) return raw_cache; + if (!scratch || !raw_cache) return NULL; + const uint64_t row_bytes = ds4_kv_row_bytes(head_dim, n_rot, DS4_KV_TURBO3); + if (ds4_gpu_dsv4_turbo3_kv_dequant_to_scratch_tensor( + raw_cache, scratch, raw_cap, head_dim, n_rot, row_bytes) == 0) { + return NULL; + } + return scratch; +} +#endif + +/* Per-row byte size in the cache for a given dtype. See ds4.h for layout. + * + * fp8 (float-sim): plain `head_dim * sizeof(float)` row. + * turbo3 (packed): data (n_nope*3/8) + scales (n_nope/64) + RoPE tail (n_rot*4). + * The data layout per 64-element group is 24 packed bytes (8 values per 3 + * bytes, repeated 8 times -> 24 = 8*3) and one FP8 E4M3 scale byte. The + * rope tail is appended as raw little-endian floats at the end of the row. + * + * Always returns >= head_dim*4 for fp8 and the packed total for turbo3 - no + * padding. Callers that need alignment add it themselves. */ +uint64_t ds4_kv_row_bytes(uint32_t head_dim, uint32_t n_rot, ds4_kv_dtype dtype) { + if (head_dim <= n_rot) { + /* Pathological: no non-RoPE part to compress. Fall back to floats. */ + return (uint64_t)head_dim * sizeof(float); + } + if (dtype == DS4_KV_TURBO3) { + const uint32_t n_nope = head_dim - n_rot; + /* Round group count up. ds4 invariably gives a 64-aligned n_nope + * (448 in practice) but the cast keeps the formula honest for future + * head shapes. */ + const uint32_t n_groups = (n_nope + DS4_TURBO3_GROUP_SIZE - 1u) / DS4_TURBO3_GROUP_SIZE; + const uint64_t data_bytes = ((uint64_t)n_nope * 3u + 7u) / 8u; + const uint64_t scale_bytes = (uint64_t)n_groups; + const uint64_t rope_bytes = (uint64_t)n_rot * sizeof(float); + return data_bytes + scale_bytes + rope_bytes; + } + return (uint64_t)head_dim * sizeof(float); +} + +/* Public-name aliases for ds4_kv_dtype. Used by the CLI and by tests so the + * canonical strings live in one place. */ +const char *ds4_kv_dtype_name(ds4_kv_dtype dtype) { + switch (dtype) { + case DS4_KV_TURBO3: return "turbo3"; + case DS4_KV_FP8: return "fp8"; + default: return "fp8"; + } +} + +int ds4_kv_dtype_from_name(const char *name, ds4_kv_dtype *out) { + if (!name || !out) return 0; + if (!strcmp(name, "fp8")) { *out = DS4_KV_FP8; return 1; } + if (!strcmp(name, "turbo3")) { *out = DS4_KV_TURBO3; return 1; } + return 0; +} + static float dsv4_e2m1fn_value_cpu(int i) { static const float values[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, @@ -5105,7 +5684,10 @@ static void layer_kv_projection_normed_one_decode_scratch( } static float rope_yarn_ramp(float low, float high, int i0) { - const float y = ((float)(i0 / 2) - low) / fmaxf(0.001f, high - low); + /* (float)i0 / 2.0f, not (float)(i0/2) - keep the divide in float so we + * preserve sub-2 RoPE fractional dims and silence clang-tidy + * bugprone-integer-division. */ + const float y = ((float)i0 / 2.0f - low) / fmaxf(0.001f, high - low); return 1.0f - fminf(1.0f, fmaxf(0.0f, y)); } @@ -6940,7 +7522,7 @@ static bool compressor_decode_one( const uint32_t comp_pos = pos + 1 - compress_ratio; rope_tail_layer_inplace(out_comp, 1, head_dim, DS4_N_ROT, comp_pos, il, false); if (head_dim == DS4_N_HEAD_DIM) { - dsv4_fp8_kv_quantize_row_inplace_cpu(out_comp, head_dim, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(out_comp, head_dim, DS4_N_ROT); } else if (head_dim == DS4_N_INDEXER_HEAD_DIM) { dsv4_indexer_qat_row_inplace_cpu(out_comp, head_dim); } @@ -7028,7 +7610,7 @@ static bool compressor_decode_one_decode_scratch( const uint32_t comp_pos = pos + 1 - compress_ratio; rope_tail_layer_inplace(out_comp, 1, head_dim, DS4_N_ROT, comp_pos, il, false); if (head_dim == DS4_N_HEAD_DIM) { - dsv4_fp8_kv_quantize_row_inplace_cpu(out_comp, head_dim, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(out_comp, head_dim, DS4_N_ROT); } else if (head_dim == DS4_N_INDEXER_HEAD_DIM) { dsv4_indexer_qat_row_inplace_cpu(out_comp, head_dim); } @@ -7464,7 +8046,7 @@ static void layer_attention_raw_swa_one( rope_tail_layer_inplace(q, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); rope_tail_layer_inplace(kv, DS4_N_HEAD_KV, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); - dsv4_fp8_kv_quantize_row_inplace_cpu(kv, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(kv, DS4_N_HEAD_DIM, DS4_N_ROT); kv_cache_push_raw(cache, kv); @@ -7698,7 +8280,7 @@ static void layer_attention_raw_swa_batch( rope_tail_layer_inplace(q_t, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); rope_tail_layer_inplace(kv_t, DS4_N_HEAD_KV, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); } - dsv4_fp8_kv_quantize_row_inplace_cpu(kv_t, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(kv_t, DS4_N_HEAD_DIM, DS4_N_ROT); kv_cache_push_raw(cache, kv_t); if (profile) t_tl_rope_cache += now_sec() - tx; @@ -7947,7 +8529,7 @@ static void layer_forward_raw_swa_one( t0 = profile ? now_sec() : 0.0; rope_tail_layer_inplace(scratch->q, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); rope_tail_layer_inplace(scratch->kv, DS4_N_HEAD_KV, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); - dsv4_fp8_kv_quantize_row_inplace_cpu(scratch->kv, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(scratch->kv, DS4_N_HEAD_DIM, DS4_N_ROT); kv_cache_push_raw(cache, scratch->kv); if (profile) t_rope_cache = now_sec() - t0; @@ -8317,7 +8899,7 @@ static void layer_forward_self_one( layer_kv_projection_normed_one(model, layer, attn_norm, kv); rope_tail_layer_inplace(q, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); rope_tail_layer_inplace(kv, DS4_N_HEAD_KV, DS4_N_HEAD_DIM, DS4_N_ROT, pos, il, false); - dsv4_fp8_kv_quantize_row_inplace_cpu(kv, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(kv, DS4_N_HEAD_DIM, DS4_N_ROT); f16_round_inplace_cpu(kv, DS4_N_HEAD_DIM); layer_attention_one(heads, model, layer, q, kv); @@ -8548,6 +9130,23 @@ typedef struct { ds4_gpu_tensor *layer_index_state_kv[DS4_MAX_LAYER]; ds4_gpu_tensor *layer_index_state_score[DS4_MAX_LAYER]; + /* Turbo3 dequant scratch. When the active dtype is DS4_KV_TURBO3 the + * layer_raw_cache buffers are sized as packed bytes (~431 B/row vs + * 2048 B/row for fp8), so attention kernels that lack an inline-dequant + * sibling can't read them directly. Before each such attention dispatch + * the dequant kernel unpacks raw_cap rows from layer_raw_cache[il] into + * this scratch tensor. For fp8 this stays NULL and the attention kernel + * reads layer_raw_cache directly. + * + * Trade-off: real packed cache memory savings on layer_raw_cache (a + * ~9.6 MB shrink on the SWA ring at raw_cap=128, DS4_N_LAYER=43); the + * dequant-to-scratch pass adds ~5 us per attention call at decode T=1 + * (negligible on GX10). The inline-dequant attention kernels below skip + * this hop entirely and read packed rows directly to capture the V-load + * bandwidth win. */ + ds4_gpu_tensor *raw_cache_dequant_scratch; + ds4_gpu_tensor *mtp_raw_cache_dequant_scratch; + /* Speculative decoding scratch. MTP is allowed to mutate graph state only * if the target verifier can either commit it or restore the saved * frontiers. The prefix1 buffers are the cheap partial-accept state for the @@ -8767,6 +9366,8 @@ static void metal_graph_free(ds4_gpu_graph *g) { ds4_gpu_tensor_free(g->prefill_tokens); ds4_gpu_tensor_free(g->logits); ds4_gpu_tensor_free(g->mtp_raw_cache); + ds4_gpu_tensor_free(g->mtp_raw_cache_dequant_scratch); + ds4_gpu_tensor_free(g->raw_cache_dequant_scratch); ds4_gpu_tensor_free(g->mtp_next_hc); ds4_gpu_tensor_free(g->mtp_state_hc); ds4_gpu_tensor_free(g->mtp_input_hc); @@ -9243,10 +9844,24 @@ static bool metal_graph_alloc_raw_cap( g->kv_raw = ds4_gpu_tensor_alloc((uint64_t)DS4_N_HEAD_DIM * sizeof(float)); g->kv = ds4_gpu_tensor_alloc((uint64_t)DS4_N_HEAD_DIM * sizeof(float)); bool state_init_ok = true; + /* Per-dtype raw cache row size. fp8 = head_dim*4 = 2048 bytes; turbo3 = + * packed = 431 bytes for DS4_N_HEAD_DIM=512, DS4_N_ROT=64. See + * ds4_kv_row_bytes(). Each layer's raw cache is allocated at the dtype- + * specific stride; the dequant-to-scratch pass before attention rewrites + * the bytes back into the per-graph `raw_cache_dequant_scratch` float + * tensor that the existing attention kernels read. */ + const uint64_t raw_row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, g_ds4_kv_dtype); + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + g->raw_cache_dequant_scratch = ds4_gpu_tensor_alloc( + (uint64_t)raw_cap * DS4_N_HEAD_DIM * sizeof(float)); + g->mtp_raw_cache_dequant_scratch = enable_mtp + ? ds4_gpu_tensor_alloc((uint64_t)raw_cap * DS4_N_HEAD_DIM * sizeof(float)) + : NULL; + } for (uint32_t il = 0; il < DS4_N_LAYER; il++) { g->layer_raw_cache[il] = metal_graph_alloc_kv_cache_tensor( managed_kv_cache, - (uint64_t)raw_cap * DS4_N_HEAD_DIM * sizeof(float)); + (uint64_t)raw_cap * raw_row_bytes); const uint32_t ratio = ds4_layer_compress_ratio(il); if (ratio != 0) { const uint32_t coff = ratio == 4 ? 2u : 1u; @@ -9353,7 +9968,7 @@ static bool metal_graph_alloc_raw_cap( g->mtp_next_hc = ds4_gpu_tensor_alloc(hc_dim * sizeof(float)); g->mtp_raw_cache = metal_graph_alloc_kv_cache_tensor( managed_kv_cache, - (uint64_t)raw_cap * DS4_N_HEAD_DIM * sizeof(float)); + (uint64_t)raw_cap * raw_row_bytes); g->spec_logits = ds4_gpu_tensor_alloc((uint64_t)16 * DS4_N_VOCAB * sizeof(float)); g->mtp_n_raw = 0; } @@ -9669,16 +10284,16 @@ static bool metal_graph_decode_kv_store( uint32_t raw_cap, uint32_t raw_row) { if (metal_graph_use_reference_kv_decode()) { - return ds4_gpu_dsv4_fp8_kv_quantize_tensor(kv, 1, DS4_N_HEAD_DIM, DS4_N_ROT) != 0 && + return ds4_gpu_kv_quantize_tensor_dispatch(kv, 1, DS4_N_HEAD_DIM, DS4_N_ROT) != 0 && ds4_gpu_store_raw_kv_tensor(raw_cache, kv, raw_cap, raw_row, DS4_N_HEAD_DIM) != 0; } - return ds4_gpu_kv_fp8_store_raw_tensor(kv, - raw_cache, - raw_cap, - raw_row, - DS4_N_HEAD_DIM, - DS4_N_ROT) != 0; + return ds4_gpu_kv_store_raw_tensor_dispatch(kv, + raw_cache, + raw_cap, + raw_row, + DS4_N_HEAD_DIM, + DS4_N_ROT) != 0; } static uint64_t metal_graph_attn_comp_cache_row_bytes(void) { @@ -9982,6 +10597,23 @@ static bool metal_graph_encode_decode_layer( metal_graph_debug_dump_tensor("KVcur", g->kv, DS4_N_HEAD_DIM, il, pos); } + /* Turbo3 read path: the layer raw_cache is byte-packed (431 B/row vs + * 2048 B/row for fp8) so existing float-reading attention kernels can't + * dereference it as `float *raw_kv`. + * + * Kernels with inline-dequant siblings (the decode_heads simple path via + * attention_decode_mixed_turbo3_kernel) read packed bytes directly - no + * view_dispatch hop needed. Kernels without an inline-dequant sibling + * still go through view_dispatch which dequants into the per-graph + * scratch float tensor. + * + * We defer the view_dispatch call to the attention branch below where we + * know which kernel runs. raw_cache (packed bytes in turbo3, float in + * fp8) is passed into the branch unmodified. */ + ds4_gpu_tensor *dequant_scratch = (raw_cache == g->mtp_raw_cache) + ? g->mtp_raw_cache_dequant_scratch + : g->raw_cache_dequant_scratch; + uint32_t n_comp = 0; ds4_gpu_tensor *comp_cache = NULL; ds4_gpu_tensor *comp_selected = NULL; @@ -10059,7 +10691,7 @@ static bool metal_graph_encode_decode_layer( if (!comp_row_view) { ok = false; } else { - ok = ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_row_view, 1, DS4_N_HEAD_DIM, DS4_N_ROT) != 0; + ok = ds4_gpu_kv_quantize_tensor_dispatch(comp_row_view, 1, DS4_N_HEAD_DIM, DS4_N_ROT) != 0; if (ok) { metal_graph_debug_dump_tensor("KVcompress", comp_row_view, DS4_N_HEAD_DIM, il, pos); } @@ -10279,27 +10911,64 @@ static bool metal_graph_encode_decode_layer( if (ok) { const uint32_t raw_start = metal_graph_raw_start_for_span(g, pos, n_raw); if (n_comp != 0 && comp_selected != NULL && n_selected != 0) { - ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor( - g->heads, - model->map, - model->size, - layer->attn_sinks->abs_offset, - g->q, - raw_cache, - g->layer_attn_comp_cache[il], - metal_graph_attn_comp_cache_is_f16(), - comp_selected, - 1, - pos, - n_raw, - raw_cap, - raw_start, - n_comp, - n_selected, - g->raw_window, - ds4_layer_compress_ratio(il), - DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; + /* Indexed mixed path. The inline-dequant turbo3 sibling covers + * the n_tokens=1 fallback kernel (decode-token, the hot path). + * heads8_online / rb4 paths still need float - fall back to + * view_dispatch when the turbo3 launcher returns 0. */ + int turbo3_rc = 0; + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + const uint64_t row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, DS4_KV_TURBO3); + turbo3_rc = ds4_gpu_attention_indexed_mixed_batch_turbo3_heads_tensor( + g->heads, + model->map, + model->size, + layer->attn_sinks->abs_offset, + g->q, + raw_cache, row_bytes, + g->layer_attn_comp_cache[il], + metal_graph_attn_comp_cache_is_f16(), + comp_selected, + 1, + pos, + n_raw, + raw_cap, + raw_start, + n_comp, + n_selected, + g->raw_window, + ds4_layer_compress_ratio(il), + DS4_N_HEAD, + DS4_N_HEAD_DIM, DS4_N_ROT); + } + if (turbo3_rc == 0) { + ds4_gpu_tensor *raw_cache_attn = (g_ds4_kv_dtype == DS4_KV_TURBO3) + ? ds4_gpu_kv_attention_view_dispatch( + raw_cache, dequant_scratch, + raw_cap, DS4_N_HEAD_DIM, DS4_N_ROT) + : raw_cache; + if (!raw_cache_attn) ok = false; + if (ok) ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor( + g->heads, + model->map, + model->size, + layer->attn_sinks->abs_offset, + g->q, + raw_cache_attn, + g->layer_attn_comp_cache[il], + metal_graph_attn_comp_cache_is_f16(), + comp_selected, + 1, + pos, + n_raw, + raw_cap, + raw_start, + n_comp, + n_selected, + g->raw_window, + ds4_layer_compress_ratio(il), + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; + } if (ok && decode_index_stage_profile) { ok = metal_graph_indexer_stage_profile_boundary("decode_attention", il, @@ -10308,7 +10977,45 @@ static bool metal_graph_encode_decode_layer( n_comp, &decode_index_stage_t0); } + } else if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + /* decode_heads has an inline-dequant turbo3 sibling - pass + * packed bytes directly, no view_dispatch dequant. */ + const uint64_t row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, DS4_KV_TURBO3); + int rc = ds4_gpu_attention_decode_heads_turbo3_tensor( + g->heads, + model->map, model->size, + layer->attn_sinks->abs_offset, + g->q, raw_cache, row_bytes, n_raw, + raw_cap, + raw_start, + n_comp ? comp_cache : NULL, + metal_graph_attn_comp_cache_is_f16(), + n_comp, + NULL, + 0, + DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT); + if (rc == 0) { + /* Turbo3 launcher rejected - fall back via dequant + float kernel. */ + ds4_gpu_tensor *raw_cache_attn = ds4_gpu_kv_attention_view_dispatch( + raw_cache, dequant_scratch, + raw_cap, DS4_N_HEAD_DIM, DS4_N_ROT); + if (!raw_cache_attn) ok = false; + if (ok) ok = ds4_gpu_attention_decode_heads_tensor(g->heads, + model->map, model->size, + layer->attn_sinks->abs_offset, + g->q, raw_cache_attn, n_raw, + raw_cap, + raw_start, + n_comp ? comp_cache : NULL, + metal_graph_attn_comp_cache_is_f16(), + n_comp, + NULL, + 0, + DS4_N_HEAD, DS4_N_HEAD_DIM) != 0; + } } else { + /* fp8 path: raw_cache is already float, no view_dispatch needed + * (view_dispatch is a no-op in fp8 mode). */ ok = ds4_gpu_attention_decode_heads_tensor(g->heads, model->map, model->size, layer->attn_sinks->abs_offset, @@ -10903,7 +11610,7 @@ static void metal_graph_trace_layer_stages( layer_kv_projection_normed_one(model, layer, cpu_attn_norm, cpu_kv); rope_tail_layer_inplace(cpu_q, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, 0, il, false); rope_tail_layer_inplace(cpu_kv, DS4_N_HEAD_KV, DS4_N_HEAD_DIM, DS4_N_ROT, 0, il, false); - dsv4_fp8_kv_quantize_row_inplace_cpu(cpu_kv, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(cpu_kv, DS4_N_HEAD_DIM, DS4_N_ROT); f16_round_inplace_cpu(cpu_kv, DS4_N_HEAD_DIM); layer_attention_one(cpu_heads, model, layer, cpu_q, cpu_kv); rope_tail_layer_inplace(cpu_heads, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, 0, il, true); @@ -11136,7 +11843,7 @@ static int metal_graph_decode_test( layer_kv_projection_normed_one(model, layer, cpu_attn_norm, cpu_kv); rope_tail_layer_inplace(cpu_q, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, 0, 0, false); rope_tail_layer_inplace(cpu_kv, DS4_N_HEAD_KV, DS4_N_HEAD_DIM, DS4_N_ROT, 0, 0, false); - dsv4_fp8_kv_quantize_row_inplace_cpu(cpu_kv, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(cpu_kv, DS4_N_HEAD_DIM, DS4_N_ROT); f16_round_inplace_cpu(cpu_kv, DS4_N_HEAD_DIM); layer_attention_rows_one(cpu_heads, model, layer, cpu_q, cpu_kv, 1); rope_tail_layer_inplace(cpu_heads, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, 0, 0, true); @@ -12065,7 +12772,7 @@ static bool metal_graph_encode_layer_attention_batch( metal_graph_debug_dump_tensor("KVrope", g->batch_kv, (uint64_t)n_tokens * DS4_N_HEAD_DIM, il, pos0); } - if (ok) ok = ds4_gpu_dsv4_fp8_kv_quantize_tensor(g->batch_kv, + if (ok) ok = ds4_gpu_kv_quantize_tensor_dispatch(g->batch_kv, n_tokens, DS4_N_HEAD_DIM, DS4_N_ROT) != 0; @@ -12082,12 +12789,13 @@ static bool metal_graph_encode_layer_attention_batch( * sized to hold the current chunk plus the previous SWA window, while the * attention mask still enforces the 128-token logical window. */ - if (ok && zero_prefix) ok = ds4_gpu_store_raw_kv_batch_tensor(g->layer_raw_cache[il], - g->batch_kv, - g->raw_cap, - pos0, - n_tokens, - DS4_N_HEAD_DIM) != 0; + if (ok && zero_prefix) ok = ds4_gpu_kv_store_raw_batch_tensor_dispatch(g->layer_raw_cache[il], + g->batch_kv, + g->raw_cap, + pos0, + n_tokens, + DS4_N_HEAD_DIM, + DS4_N_ROT) != 0; const bool raw_batch_attention = zero_prefix && ratio == 0; bool batch_attention_done = false; @@ -12117,12 +12825,13 @@ static bool metal_graph_encode_layer_attention_batch( const uint32_t raw_start = metal_graph_raw_start_for_span(g, pos0 + n_tokens - 1u, n_raw); - ok = ds4_gpu_store_raw_kv_batch_tensor(g->layer_raw_cache[il], - g->batch_kv, - g->raw_cap, - pos0, - n_tokens, - DS4_N_HEAD_DIM) != 0; + ok = ds4_gpu_kv_store_raw_batch_tensor_dispatch(g->layer_raw_cache[il], + g->batch_kv, + g->raw_cap, + pos0, + n_tokens, + DS4_N_HEAD_DIM, + DS4_N_ROT) != 0; if (ok) { metal_graph_debug_dump_tensor("raw_cache", g->layer_raw_cache[il], @@ -12130,21 +12839,51 @@ static bool metal_graph_encode_layer_attention_batch( il, pos0); } - if (ok) { - ok = ds4_gpu_attention_decode_raw_batch_heads_tensor(g->batch_heads, - model->map, - model->size, - layer->attn_sinks->abs_offset, - g->batch_q, - g->layer_raw_cache[il], - n_tokens, - pos0, - n_raw, - g->raw_cap, - raw_start, - g->raw_window, - DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; + /* Prefill-chunk raw batch. Try turbo3 heads8_online via the turbo3 + * launcher first; fall back to float dequant-to-scratch + existing + * kernel on rc==0. */ + int turbo3_rc = 0; + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + const uint64_t row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, DS4_KV_TURBO3); + turbo3_rc = ds4_gpu_attention_decode_mixed_batch_turbo3_heads_tensor( + g->batch_heads, + model->map, model->size, + layer->attn_sinks->abs_offset, + g->batch_q, + g->layer_raw_cache[il], row_bytes, + /* comp_kv */ NULL, + /* comp_kv_f16 */ 0, + /* comp_mask */ NULL, + /* use_comp_mask */ 0, + n_tokens, pos0, n_raw, g->raw_cap, raw_start, + /* n_comp */ 0, + g->raw_window, + /* ratio */ 0, + DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT); + } + if (turbo3_rc == 0) { + ds4_gpu_tensor *raw_cache_attn = ok ? ((g_ds4_kv_dtype == DS4_KV_TURBO3) + ? ds4_gpu_kv_attention_view_dispatch( + g->layer_raw_cache[il], g->raw_cache_dequant_scratch, + g->raw_cap, DS4_N_HEAD_DIM, DS4_N_ROT) + : g->layer_raw_cache[il]) : NULL; + if (ok && !raw_cache_attn) ok = false; + if (ok) { + ok = ds4_gpu_attention_decode_raw_batch_heads_tensor(g->batch_heads, + model->map, + model->size, + layer->attn_sinks->abs_offset, + g->batch_q, + raw_cache_attn, + n_tokens, + pos0, + n_raw, + g->raw_cap, + raw_start, + g->raw_window, + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; + } } if (ok) batch_attention_done = true; } else if (ok && ratio != 0) { @@ -12415,7 +13154,7 @@ static bool metal_graph_encode_layer_attention_batch( if (ok && emit) { ds4_gpu_tensor *comp_row_view = metal_graph_attn_comp_row_view(g, il, comp_row); ok = comp_row_view && - ds4_gpu_dsv4_fp8_kv_quantize_tensor(comp_row_view, + ds4_gpu_kv_quantize_tensor_dispatch(comp_row_view, 1, DS4_N_HEAD_DIM, DS4_N_ROT) != 0; @@ -12741,12 +13480,13 @@ static bool metal_graph_encode_layer_attention_batch( bool use_indexed_comp = false; double index_stage_t0 = 0.0; - ok = ds4_gpu_store_raw_kv_batch_tensor(g->layer_raw_cache[il], - g->batch_kv, - g->raw_cap, - pos0, - n_tokens, - DS4_N_HEAD_DIM) != 0; + ok = ds4_gpu_kv_store_raw_batch_tensor_dispatch(g->layer_raw_cache[il], + g->batch_kv, + g->raw_cap, + pos0, + n_tokens, + DS4_N_HEAD_DIM, + DS4_N_ROT) != 0; if (ok && ratio == 4 && n_comp > DS4_N_INDEXER_TOP_K) { const float index_scale = 1.0f / sqrtf((float)(DS4_N_INDEXER_HEAD_DIM * DS4_N_INDEXER_HEAD)); if (index_stage_profile) { @@ -12810,14 +13550,49 @@ static bool metal_graph_encode_layer_attention_batch( } use_comp_mask = 1; } - if (ok) { + /* Try the turbo3 inline-dequant launcher first; on rc==0 fall + * back to view_dispatch + float kernel. */ + int turbo3_rc_a = 0; + if (g_ds4_kv_dtype == DS4_KV_TURBO3 && use_indexed_comp) { + const uint64_t row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, DS4_KV_TURBO3); + turbo3_rc_a = ds4_gpu_attention_indexed_mixed_batch_turbo3_heads_tensor( + g->batch_heads, + model->map, model->size, + layer->attn_sinks->abs_offset, + g->batch_q, + g->layer_raw_cache[il], row_bytes, + g->layer_attn_comp_cache[il], + metal_graph_attn_comp_cache_is_f16(), + g->comp_selected, + n_tokens, + pos0, + n_raw, + g->raw_cap, + raw_start, + n_comp, + DS4_N_INDEXER_TOP_K, + g->raw_window, + ratio, + DS4_N_HEAD, + DS4_N_HEAD_DIM, DS4_N_ROT); + } + ds4_gpu_tensor *raw_cache_attn_a = NULL; + if (turbo3_rc_a == 0 && ok) { + raw_cache_attn_a = (g_ds4_kv_dtype == DS4_KV_TURBO3) + ? ds4_gpu_kv_attention_view_dispatch( + g->layer_raw_cache[il], g->raw_cache_dequant_scratch, + g->raw_cap, DS4_N_HEAD_DIM, DS4_N_ROT) + : g->layer_raw_cache[il]; + if (!raw_cache_attn_a) ok = false; + } + if (turbo3_rc_a == 0 && ok) { if (use_indexed_comp) { ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(g->batch_heads, model->map, model->size, layer->attn_sinks->abs_offset, g->batch_q, - g->layer_raw_cache[il], + raw_cache_attn_a, g->layer_attn_comp_cache[il], metal_graph_attn_comp_cache_is_f16(), g->comp_selected, @@ -12846,7 +13621,7 @@ static bool metal_graph_encode_layer_attention_batch( model->size, layer->attn_sinks->abs_offset, g->batch_q, - g->layer_raw_cache[il], + raw_cache_attn_a, g->layer_attn_comp_cache[il], metal_graph_attn_comp_cache_is_f16(), use_comp_mask ? g->comp_mask : NULL, @@ -12925,36 +13700,68 @@ static bool metal_graph_encode_layer_attention_batch( pos0); } } - if (ok) { - ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(g->batch_heads, - model->map, - model->size, - layer->attn_sinks->abs_offset, - g->batch_q, - g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], - metal_graph_attn_comp_cache_is_f16(), - g->comp_selected, - n_tokens, - pos0, - n_tokens, - g->raw_cap, - 0, - n_comp, - DS4_N_INDEXER_TOP_K, - g->raw_window, - ratio, - DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; - if (ok && index_stage_profile) { - ok = metal_graph_indexer_stage_profile_boundary("attention", - il, - pos0, - n_tokens, - n_comp, - &index_stage_t0); + int turbo3_rc_b = 0; + if (g_ds4_kv_dtype == DS4_KV_TURBO3) { + const uint64_t row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, DS4_KV_TURBO3); + turbo3_rc_b = ds4_gpu_attention_indexed_mixed_batch_turbo3_heads_tensor( + g->batch_heads, + model->map, model->size, + layer->attn_sinks->abs_offset, + g->batch_q, + g->layer_raw_cache[il], row_bytes, + g->layer_attn_comp_cache[il], + metal_graph_attn_comp_cache_is_f16(), + g->comp_selected, + n_tokens, + pos0, + n_tokens, + g->raw_cap, + 0, + n_comp, + DS4_N_INDEXER_TOP_K, + g->raw_window, + ratio, + DS4_N_HEAD, + DS4_N_HEAD_DIM, DS4_N_ROT); + } + if (turbo3_rc_b == 0) { + ds4_gpu_tensor *raw_cache_attn_b = ok ? ((g_ds4_kv_dtype == DS4_KV_TURBO3) + ? ds4_gpu_kv_attention_view_dispatch( + g->layer_raw_cache[il], g->raw_cache_dequant_scratch, + g->raw_cap, DS4_N_HEAD_DIM, DS4_N_ROT) + : g->layer_raw_cache[il]) : NULL; + if (ok && !raw_cache_attn_b) ok = false; + if (ok) { + ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(g->batch_heads, + model->map, + model->size, + layer->attn_sinks->abs_offset, + g->batch_q, + raw_cache_attn_b, + g->layer_attn_comp_cache[il], + metal_graph_attn_comp_cache_is_f16(), + g->comp_selected, + n_tokens, + pos0, + n_tokens, + g->raw_cap, + 0, + n_comp, + DS4_N_INDEXER_TOP_K, + g->raw_window, + ratio, + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; } } + if (ok && index_stage_profile) { + ok = metal_graph_indexer_stage_profile_boundary("attention", + il, + pos0, + n_tokens, + n_comp, + &index_stage_t0); + } if (ok) batch_attention_done = true; } if (ok && zero_prefix && !topk_prefill_needed && n_comp != 0) { @@ -13046,50 +13853,103 @@ static bool metal_graph_encode_layer_attention_batch( ds4_gpu_tensor *heads_view = metal_graph_tensor_row_view(g->batch_heads, t, q_dim); ok = ok && q_view && kv_cache_view && heads_view; if (ok && !zero_prefix) { - ok = ds4_gpu_store_raw_kv_tensor(g->layer_raw_cache[il], - kv_cache_view, - g->raw_cap, - pos % g->raw_cap, - DS4_N_HEAD_DIM) != 0; + /* fp8 path: raw f32 copy into the per-row slot. + * turbo3 path: pack into the packed-byte slot. The + * pack-batch dispatch helper above handles both. */ + ok = ds4_gpu_kv_store_raw_batch_tensor_dispatch( + g->layer_raw_cache[il], kv_cache_view, + g->raw_cap, pos % g->raw_cap, 1u, + DS4_N_HEAD_DIM, DS4_N_ROT) != 0; } - if (ok && comp_mask != NULL && n_selected != 0) { - ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(heads_view, - model->map, - model->size, - layer->attn_sinks->abs_offset, - q_view, - g->layer_raw_cache[il], - g->layer_attn_comp_cache[il], - metal_graph_attn_comp_cache_is_f16(), - g->comp_selected, - 1, - pos, - n_raw, - g->raw_cap, - raw_start, - cur_comp, - n_selected, - g->raw_window, - ratio, - DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; - } else if (ok) { - ok = ds4_gpu_attention_decode_heads_tensor(heads_view, - model->map, - model->size, - layer->attn_sinks->abs_offset, - q_view, - g->layer_raw_cache[il], - n_raw, - g->raw_cap, - raw_start, - cur_comp ? g->layer_attn_comp_cache[il] : NULL, - metal_graph_attn_comp_cache_is_f16(), - cur_comp, - comp_mask, - n_selected, - DS4_N_HEAD, - DS4_N_HEAD_DIM) != 0; + int turbo3_rc_c = 0; + if (g_ds4_kv_dtype == DS4_KV_TURBO3 && ok) { + const uint64_t row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, DS4_KV_TURBO3); + if (comp_mask != NULL && n_selected != 0) { + turbo3_rc_c = ds4_gpu_attention_indexed_mixed_batch_turbo3_heads_tensor( + heads_view, + model->map, model->size, + layer->attn_sinks->abs_offset, + q_view, + g->layer_raw_cache[il], row_bytes, + g->layer_attn_comp_cache[il], + metal_graph_attn_comp_cache_is_f16(), + g->comp_selected, + 1, + pos, + n_raw, + g->raw_cap, + raw_start, + cur_comp, + n_selected, + g->raw_window, + ratio, + DS4_N_HEAD, + DS4_N_HEAD_DIM, DS4_N_ROT); + } else { + turbo3_rc_c = ds4_gpu_attention_decode_heads_turbo3_tensor( + heads_view, + model->map, model->size, + layer->attn_sinks->abs_offset, + q_view, + g->layer_raw_cache[il], row_bytes, + n_raw, + g->raw_cap, + raw_start, + cur_comp ? g->layer_attn_comp_cache[il] : NULL, + metal_graph_attn_comp_cache_is_f16(), + cur_comp, + comp_mask, + n_selected, + DS4_N_HEAD, + DS4_N_HEAD_DIM, DS4_N_ROT); + } + } + if (turbo3_rc_c == 0) { + ds4_gpu_tensor *raw_cache_attn_c = ok ? ((g_ds4_kv_dtype == DS4_KV_TURBO3) + ? ds4_gpu_kv_attention_view_dispatch( + g->layer_raw_cache[il], g->raw_cache_dequant_scratch, + g->raw_cap, DS4_N_HEAD_DIM, DS4_N_ROT) + : g->layer_raw_cache[il]) : NULL; + if (ok && !raw_cache_attn_c) ok = false; + if (ok && comp_mask != NULL && n_selected != 0) { + ok = ds4_gpu_attention_indexed_mixed_batch_heads_tensor(heads_view, + model->map, + model->size, + layer->attn_sinks->abs_offset, + q_view, + raw_cache_attn_c, + g->layer_attn_comp_cache[il], + metal_graph_attn_comp_cache_is_f16(), + g->comp_selected, + 1, + pos, + n_raw, + g->raw_cap, + raw_start, + cur_comp, + n_selected, + g->raw_window, + ratio, + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; + } else if (ok) { + ok = ds4_gpu_attention_decode_heads_tensor(heads_view, + model->map, + model->size, + layer->attn_sinks->abs_offset, + q_view, + raw_cache_attn_c, + n_raw, + g->raw_cap, + raw_start, + cur_comp ? g->layer_attn_comp_cache[il] : NULL, + metal_graph_attn_comp_cache_is_f16(), + cur_comp, + comp_mask, + n_selected, + DS4_N_HEAD, + DS4_N_HEAD_DIM) != 0; + } } ds4_gpu_tensor_free(heads_view); ds4_gpu_tensor_free(kv_cache_view); @@ -15070,6 +15930,11 @@ struct ds4_engine { bool quality; bool metal_ready; bool mtp_ready; + /* KV cache dtype: DS4_KV_FP8 (default, historical path) or DS4_KV_TURBO3 + * (TurboQuant+ quality simulation on the compressed-KV non-RoPE part). + * Set once at engine open from ds4_engine_options.kv_dtype; immutable + * thereafter so cache values within a session stay consistent. */ + ds4_kv_dtype kv_dtype; }; static bool cpu_directional_steering_enabled( @@ -16271,6 +17136,20 @@ static int generate_metal_graph_raw_swa( } #endif +ds4_kv_footprint ds4_kv_footprint_estimate(ds4_backend backend, int ctx_size, ds4_kv_dtype dtype) { + ds4_kv_footprint f = {0}; + /* Reuse the existing cap arithmetic. The float-vs-packed swap only + * affects raw_bytes (per-dtype row size); the compressed pools are kept + * float / F16 because the compressor pool integrates softmax-weighted + * accumulations that require an original-basis read. */ + const ds4_context_memory m = ds4_context_memory_estimate(backend, ctx_size); + const uint64_t row_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, dtype); + f.raw_bytes = (uint64_t)DS4_N_LAYER * m.raw_cap * row_bytes; + f.compressed_bytes = m.compressed_bytes; + f.total_bytes = f.raw_bytes + f.compressed_bytes; + return f; +} + #ifdef DS4_NO_GPU ds4_context_memory ds4_context_memory_estimate(ds4_backend backend, int ctx_size) { (void)backend; @@ -16449,8 +17328,26 @@ struct ds4_session { */ #define DS4_SESSION_PAYLOAD_MAGIC UINT32_C(0x34565344) /* "DSV4" */ -#define DS4_SESSION_PAYLOAD_VERSION UINT32_C(1) -#define DS4_SESSION_PAYLOAD_U32_FIELDS 13u +/* Session payload format versions: + * v1: original (DSV4 magic + 13 u32 fields), all raw KV rows stored as + * DS4_N_HEAD_DIM * sizeof(float). + * v2: adds one u32 kv_dtype field (DS4_KV_FP8 or DS4_KV_TURBO3). When the + * saved dtype is DS4_KV_TURBO3, raw KV rows are stored at the packed + * byte stride ds4_kv_row_bytes(head_dim, n_rot, DS4_KV_TURBO3) instead + * of head_dim*4. Compressor + indexer state stay as floats either way. + * + * Backward compat: a v2 reader that sees a v1 file falls back to the v1 + * header read (13 fields) and assumes kv_dtype = DS4_KV_FP8. A v1 reader + * sees a v2 file as "unsupported session payload version" and refuses to + * load - caller can retry with a fresh prompt. + * + * Cross-dtype reject: v2 reader compares the saved kv_dtype against the + * active engine dtype. Mismatch -> clear error message; user has to either + * switch dtypes or discard the cached prefix. */ +#define DS4_SESSION_PAYLOAD_VERSION UINT32_C(2) +#define DS4_SESSION_PAYLOAD_VERSION_V1 UINT32_C(1) +#define DS4_SESSION_PAYLOAD_U32_FIELDS 14u +#define DS4_SESSION_PAYLOAD_U32_FIELDS_V1 13u #define DS4_SESSION_IO_CHUNK (8u * 1024u * 1024u) static void payload_set_err(char *err, size_t errlen, const char *msg) { @@ -16556,8 +17453,11 @@ static uint32_t session_raw_live_rows(const ds4_gpu_graph *g, uint32_t checkpoin static uint64_t session_payload_live_tensor_bytes(const ds4_gpu_graph *g, uint32_t checkpoint_len) { uint64_t bytes = 0; const uint32_t raw_live = session_raw_live_rows(g, checkpoint_len); + /* Disk v2: raw rows are stored at the per-dtype packed byte stride when + * dtype=turbo3. fp8 path stays at head_dim*4 (v1-equivalent). */ + const uint64_t raw_row_disk_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, g_ds4_kv_dtype); for (uint32_t il = 0; il < DS4_N_LAYER; il++) { - bytes += (uint64_t)raw_live * DS4_N_HEAD_DIM * sizeof(float); + bytes += (uint64_t)raw_live * raw_row_disk_bytes; const uint32_t ratio = ds4_layer_compress_ratio(il); if (ratio == 0) continue; bytes += (uint64_t)g->layer_n_comp[il] * DS4_N_HEAD_DIM * sizeof(float); @@ -16982,11 +17882,12 @@ int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) ds4_gpu_graph *g = &s->graph; const uint32_t raw_live = session_raw_live_rows(g, (uint32_t)s->checkpoint.len); - /* Header fields: + /* Header fields (v2): * 0 magic, 1 version, 2 ctx, 3 prefill chunk, 4 raw cap, * 5 raw window, 6 compressed cap, 7 token count, * 8 layers, 9 raw head dim, 10 indexer head dim, 11 vocab, - * 12 live raw rows serialized below. + * 12 live raw rows serialized below, + * 13 kv_dtype (NEW in v2: 0=fp8, 1=turbo3) */ uint32_t header[DS4_SESSION_PAYLOAD_U32_FIELDS] = { DS4_SESSION_PAYLOAD_MAGIC, @@ -17002,6 +17903,7 @@ int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) DS4_N_INDEXER_HEAD_DIM, DS4_N_VOCAB, raw_live, + (uint32_t)g_ds4_kv_dtype, }; for (uint32_t i = 0; i < DS4_SESSION_PAYLOAD_U32_FIELDS; i++) { if (payload_write_u32(fp, header[i], err, errlen) != 0) return 1; @@ -17023,13 +17925,18 @@ int ds4_session_save_payload(ds4_session *s, FILE *fp, char *err, size_t errlen) /* Write the raw ring in logical position order. The file does not care * where the rows happened to live physically in the source graph. */ const uint32_t raw_first = (uint32_t)s->checkpoint.len - raw_live; + /* Disk v2 write: stream raw bytes from the cache at the per-dtype + * row stride. fp8 writes head_dim*4 bytes/row (unchanged from v1); + * turbo3 writes the packed turbo3 layout directly (no dequant pass + * on save). */ + const uint64_t raw_row_disk_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, g_ds4_kv_dtype); for (uint32_t r = 0; rc == 0 && r < raw_live; r++) { const uint32_t pos = raw_first + r; const uint32_t phys = pos % g->raw_cap; rc = payload_write_tensor_span(fp, g->layer_raw_cache[il], - (uint64_t)phys * DS4_N_HEAD_DIM * sizeof(float), - (uint64_t)DS4_N_HEAD_DIM * sizeof(float), + (uint64_t)phys * raw_row_disk_bytes, + raw_row_disk_bytes, buf, DS4_SESSION_IO_CHUNK, err, @@ -17113,14 +18020,37 @@ int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, c return 1; } uint64_t remaining = payload_bytes; - uint32_t h[DS4_SESSION_PAYLOAD_U32_FIELDS]; - for (uint32_t i = 0; i < DS4_SESSION_PAYLOAD_U32_FIELDS; i++) { - if (payload_read_u32(fp, &h[i], &remaining, err, errlen) != 0) return 1; + /* Read magic + version first to dispatch between v1 (13 fields, kv_dtype + * implicit fp8) and v2 (14 fields, explicit kv_dtype). Older files + * remain loadable; newer files refuse to load on older binaries. */ + uint32_t h[DS4_SESSION_PAYLOAD_U32_FIELDS] = {0}; + if (payload_read_u32(fp, &h[0], &remaining, err, errlen) != 0) return 1; + if (payload_read_u32(fp, &h[1], &remaining, err, errlen) != 0) return 1; + if (h[0] != DS4_SESSION_PAYLOAD_MAGIC) { + payload_set_err(err, errlen, "unsupported session payload version"); + return 1; } - if (h[0] != DS4_SESSION_PAYLOAD_MAGIC || h[1] != DS4_SESSION_PAYLOAD_VERSION) { + uint32_t header_fields; + if (h[1] == DS4_SESSION_PAYLOAD_VERSION) { + header_fields = DS4_SESSION_PAYLOAD_U32_FIELDS; + } else if (h[1] == DS4_SESSION_PAYLOAD_VERSION_V1) { + header_fields = DS4_SESSION_PAYLOAD_U32_FIELDS_V1; + } else { payload_set_err(err, errlen, "unsupported session payload version"); return 1; } + for (uint32_t i = 2; i < header_fields; i++) { + if (payload_read_u32(fp, &h[i], &remaining, err, errlen) != 0) return 1; + } + /* h[13] is the saved kv_dtype; v1 files leave it at the zero-init value + * DS4_KV_FP8 which is correct (v1 only ever stored fp8 floats). */ + const uint32_t saved_kv_dtype = h[13]; + if (saved_kv_dtype != (uint32_t)g_ds4_kv_dtype) { + payload_set_err(err, errlen, + "KV checkpoint dtype does not match the current session " + "(use --kv-cache to match, or discard the cache)"); + return 1; + } if (ds4_session_is_cpu(s)) { const uint32_t saved_ctx = h[2]; const uint32_t saved_prefill_cap = h[3]; @@ -17353,6 +18283,10 @@ int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, c uint8_t *buf = xmalloc(DS4_SESSION_IO_CHUNK); int rc = 0; + /* Disk v2 read: rows are stored at the per-dtype packed byte stride + * (matches the in-memory cache layout). v1 files always have fp8 dtype, + * head_dim*4 bytes/row. v2 turbo3 files have packed bytes. */ + const uint64_t raw_row_disk_bytes = ds4_kv_row_bytes(DS4_N_HEAD_DIM, DS4_N_ROT, g_ds4_kv_dtype); for (uint32_t il = 0; rc == 0 && il < DS4_N_LAYER; il++) { /* Rebuild the physical raw ring expected by the current graph. This is * why the file stores rows in logical order instead of dumping bytes from @@ -17363,8 +18297,8 @@ int ds4_session_load_payload(ds4_session *s, FILE *fp, uint64_t payload_bytes, c const uint32_t phys = pos % g->raw_cap; rc = payload_read_tensor_span(fp, g->layer_raw_cache[il], - (uint64_t)phys * DS4_N_HEAD_DIM * sizeof(float), - (uint64_t)DS4_N_HEAD_DIM * sizeof(float), + (uint64_t)phys * raw_row_disk_bytes, + raw_row_disk_bytes, buf, DS4_SESSION_IO_CHUNK, &remaining, @@ -17879,7 +18813,7 @@ int ds4_engine_head_test(ds4_engine *e, const ds4_tokens *prompt) { print_vec_stats("blk.0 kv", kv0, DS4_N_HEAD_DIM); rope_tail_layer_inplace(q0, DS4_N_HEAD, DS4_N_HEAD_DIM, DS4_N_ROT, (uint32_t)(prompt->len - 1), 0, false); rope_tail_layer_inplace(kv0, DS4_N_HEAD_KV, DS4_N_HEAD_DIM, DS4_N_ROT, (uint32_t)(prompt->len - 1), 0, false); - dsv4_fp8_kv_quantize_row_inplace_cpu(kv0, DS4_N_HEAD_DIM, DS4_N_ROT); + ds4_kv_quantize_row_inplace_cpu(kv0, DS4_N_HEAD_DIM, DS4_N_ROT); f16_round_inplace_cpu(kv0, DS4_N_HEAD_DIM); float *attn_heads = xmalloc((size_t)q_dim * sizeof(attn_heads[0])); @@ -17990,6 +18924,14 @@ int ds4_engine_open(ds4_engine **out, const ds4_engine_options *opt) { e->quality = opt->quality; e->power_percent = opt->power_percent > 0 ? opt->power_percent : 100; if (e->power_percent > 100) e->power_percent = 100; + e->kv_dtype = opt->kv_dtype; + ds4_kv_set_active_dtype(e->kv_dtype); + /* turbo3 KV on Metal: pack/dequant + float-sim quantize round trip ship + * here; the inline-dequant attention kernels are CUDA-only. On Metal + * the turbo3 attention launchers return 0 and the fallback paths use + * view_dispatch (dequant to scratch) + the stock Metal fp8 attention + * kernels. Correct with full 4.75x memory savings on Metal but the + * dequant-to-scratch hop costs ~13% gen_tps vs fp8. */ e->mtp_draft_tokens = opt->mtp_draft_tokens > 0 ? opt->mtp_draft_tokens : 1; if (e->mtp_draft_tokens > 16) e->mtp_draft_tokens = 16; e->mtp_margin = opt->mtp_margin >= 0.0f ? opt->mtp_margin : 3.0f; diff --git a/ds4.h b/ds4.h index f1a8e9e4..3e452d38 100644 --- a/ds4.h +++ b/ds4.h @@ -20,6 +20,77 @@ typedef enum { DS4_BACKEND_CPU, } ds4_backend; +/* KV cache compression dtype selection. + * + * DS4_KV_FP8 (default): the historical path. The non-RoPE part of each compressed + * KV row goes through an in-place E4M3 round trip in groups of 64 - values stay as + * float32 in memory but pick up the FP8 quantization error so the CPU reference + * matches what the Metal graph would store as packed FP8. No layout change. + * + * DS4_KV_TURBO3: TurboQuant+ port from TheTom/llama-cpp-turboquant. Storage layout + * is packed 3-bit Lloyd-Max indices + per-group FP8 scale bytes - the cache buffer + * is byte-addressed at `row * ds4_kv_row_bytes(head_dim, n_rot, ...)`, NOT + * float-addressed at `row * head_dim`. Every attention kernel inline-dequants + * the packed bytes on V-load. The Randomized Hadamard rotation + N(0,1) Lloyd-Max + * codebook + matched-norm L2 scale are computed once on cache store; reads pay only + * the dequant (one byte load + LUT lookup + FP8-to-f32 multiply per element). + * + * Memory savings per row: ds4 head_dim=512, n_rot=64, group=64. + * fp8 (float-sim): 512 * 4 = 2048 bytes + * turbo3 (packed): (448*3/8) + (448/64) + (64*4) = 168 + 7 + 256 = 431 bytes + * -> 4.75x smaller per row. The 9x figure in upstream TQ+ docs is the + * latent-only ratio (175/1792); RoPE-tail floats are unavoidable on MLA. */ +typedef enum { + DS4_KV_FP8 = 0, + DS4_KV_TURBO3 = 1, +} ds4_kv_dtype; +const char *ds4_kv_dtype_name(ds4_kv_dtype dtype); +int ds4_kv_dtype_from_name(const char *name, ds4_kv_dtype *out); + +/* Packed turbo3 byte layout per cache row. GROUP_SIZE is 64 - the same WHT + * group cadence the float-sim quantizer uses, one matched-norm L2 scale per + * 64 elements. See `dsv4_turbo3_kv_quantize_row_inplace_cpu` in ds4.c for + * the per-group algorithm. + * + * data section : (head_dim - n_rot) * 3 / 8 bytes + * packed 3-bit indices, 8 values per 3 bytes + * (b0 = v0|(v1<<3)|(v2<<6), b1 = (v2>>2)|(v3<<1)|..., b2 = ...) + * scale section : (head_dim - n_rot) / DS4_TURBO3_GROUP_SIZE FP8 E4M3 bytes + * one per 64-element group, matched-norm L2 scale + * rope tail : n_rot * sizeof(float) + * untouched RoPE coordinates (these carry positional freqs) + * + * Stored values are in the ORIGINAL basis (we apply the inverse rotation on + * write so the dequanted values match what the float-sim path produced). + * Readers dequant one 64-element group at a time into a small stack scratch + * via `dequant_group`: load 24 packed bytes + 1 FP8 scale -> 64 floats in the + * rotated basis (centroid * scale) -> 64-point iWHT-with-signs -> 64 + * original-basis floats. This trades ~3.5x dequant compute per group for + * ~25x less memory traffic vs the fp8 float-sim cache. The advantage is + * that every existing reader (attention dot loops, compressor pool, disk + * save, MTP draft) sees the same original-basis values it did before - only + * the storage byte layout changes. */ +#define DS4_TURBO3_GROUP_SIZE 64u +uint64_t ds4_kv_row_bytes(uint32_t head_dim, uint32_t n_rot, ds4_kv_dtype dtype); + +/* Footprint estimator broken down by section, parameterized on dtype. Used by + * `ds4-bench --print-kv-footprint` to print side-by-side fp8 vs turbo3 sizes. + * + * raw_bytes : the SWA ring window across all layers. + * compressed_bytes : per-layer compressor output + indexer (always float). + * total_bytes : sum of the above. + * + * For turbo3, `raw_bytes` reflects the packed-byte layout. The compressed + * pools (attn_comp + index_comp) and the compressor state arrays remain + * float because the compressor pool integrates softmax-weighted accumulations + * that require an original-basis read. */ +typedef struct { + uint64_t raw_bytes; + uint64_t compressed_bytes; + uint64_t total_bytes; +} ds4_kv_footprint; +ds4_kv_footprint ds4_kv_footprint_estimate(ds4_backend backend, int ctx_size, ds4_kv_dtype dtype); + typedef enum { DS4_THINK_NONE, DS4_THINK_HIGH, @@ -73,6 +144,10 @@ typedef struct { bool warm_weights; bool quality; bool inspect_only; + /* KV cache dtype. Default DS4_KV_FP8 keeps the historical path; DS4_KV_TURBO3 + * swaps in the TurboQuant+ 3-bit-per-element quality simulation on CUDA and + * CPU reference. See ds4_kv_dtype above for the algorithm summary. */ + ds4_kv_dtype kv_dtype; } ds4_engine_options; typedef void (*ds4_token_emit_fn)(void *ud, int token); diff --git a/ds4_bench.c b/ds4_bench.c index 5f694dc9..3f0f090c 100644 --- a/ds4_bench.c +++ b/ds4_bench.c @@ -11,6 +11,7 @@ */ #include +#include #include #include #include @@ -38,6 +39,21 @@ typedef struct { const char *dump_frontier_logits_dir; bool warm_weights; bool quality; + /* KV cache compression simulation; fp8 is the historical default. */ + ds4_kv_dtype kv_dtype; + /* PPL teacher-forced quality measurement. When set, skips the + * throughput sweep and instead tokenizes the file, walks token by + * token, accumulates -log P(token_t | tokens_model_path); fprintf(fp, ",\n \"backend\":\"%s\",\n \"quality\":%s,\n" + " \"kv_cache\":\"%s\",\n" " \"quant_bits\":%d,\n \"prompt_tokens\":%d,\n" " \"frontier_tokens\":%d,\n \"prefill_tokens\":%d,\n" " \"ctx\":%d,\n \"vocab\":%d,\n" " \"argmax_id\":%d,\n \"argmax_logit\":%.9g,\n \"logits\":[", ds4_backend_name(cfg->backend), cfg->quality ? "true" : "false", + ds4_kv_dtype_name(cfg->kv_dtype), ds4_engine_routed_quant_bits(engine), frontier, frontier, @@ -392,8 +444,315 @@ static void log_context_memory(ds4_backend backend, int ctx_size) { m.comp_cap); } +/* Side-by-side KV footprint for fp8 vs turbo3 at the given backend/ctx. + * The active dtype is highlighted; the other one is printed for comparison + * so users can see the packed-byte savings at a glance. */ +static void log_kv_footprint_compare(ds4_backend backend, int ctx_size, ds4_kv_dtype active) { + const ds4_kv_footprint fp8 = ds4_kv_footprint_estimate(backend, ctx_size, DS4_KV_FP8); + const ds4_kv_footprint t3 = ds4_kv_footprint_estimate(backend, ctx_size, DS4_KV_TURBO3); + const double mib = 1.0 / (1024.0 * 1024.0); + const double raw_ratio = (t3.raw_bytes > 0) + ? ((double)fp8.raw_bytes / (double)t3.raw_bytes) : 0.0; + /* Print the SWA ring (the only pool that swaps to packed bytes) plus + * the compressed pools (kept float / F16 because the compressor pool + * integrates softmax-weighted accumulations that need an original-basis + * read). */ + fprintf(stderr, + "ds4-bench: KV footprint @ ctx=%d:\n" + " fp8 raw=%.2f MiB compressed=%.2f MiB total=%.2f MiB%s\n" + " turbo3 raw=%.2f MiB compressed=%.2f MiB total=%.2f MiB%s\n" + " raw shrink: %.2fx (turbo3 saves %.2f MiB on the SWA ring)\n", + ctx_size, + (double)fp8.raw_bytes * mib, + (double)fp8.compressed_bytes * mib, + (double)fp8.total_bytes * mib, + active == DS4_KV_FP8 ? " <-- active" : "", + (double)t3.raw_bytes * mib, + (double)t3.compressed_bytes * mib, + (double)t3.total_bytes * mib, + active == DS4_KV_TURBO3 ? " <-- active" : "", + raw_ratio, + (double)(fp8.raw_bytes - t3.raw_bytes) * mib); +} + +/* Quality-dump binary format. Magic "DS4Q" | u32 vocab | u32 scored + * | (scored * vocab * float32 logits). Logits are written RAW, not + * softmaxed; the comparator runs log-sum-exp on read. */ +#define DS4_QDUMP_MAGIC "DS4Q" + +/* Streaming softmax: returns log Z so callers can compute log_p = logit - logZ. */ +static double ds4_log_sum_exp(const float *logits, int n) { + float m = logits[0]; + for (int i = 1; i < n; i++) if (logits[i] > m) m = logits[i]; + double s = 0.0; + for (int i = 0; i < n; i++) s += exp((double)(logits[i] - m)); + return (double)m + log(s); +} + +/* Find top-K indices of a logit vector via partial selection. K small (<=5). */ +static void ds4_top_k_indices(const float *logits, int n, int k, int *out_idx) { + for (int i = 0; i < k; i++) out_idx[i] = -1; + float out_val[8]; /* k<=8 enforced by callers */ + for (int i = 0; i < k; i++) out_val[i] = -FLT_MAX; + for (int i = 0; i < n; i++) { + float v = logits[i]; + if (v <= out_val[k - 1]) continue; + int j = k - 1; + while (j > 0 && out_val[j - 1] < v) { + out_val[j] = out_val[j - 1]; + out_idx[j] = out_idx[j - 1]; + j--; + } + out_val[j] = v; + out_idx[j] = i; + } +} + +/* Teacher-forced perplexity on a token sequence. For each position i in + * [0, n-1) feed tokens[i], then read log P(tokens[i+1] | tokens[0..i]) + * from the current logits. Accumulate -logprob; report mean NLL and + * exp(mean_NLL). Compares quality across --kv-cache dtypes apples-to- + * apples (deterministic, no sampling). + * + * Optional modes: + * --quality-emit FILE Dump per-position raw logits to FILE. + * --quality-baseline FILE Read baseline FILE, compare every position's + * logits to current run. Reports: + * - KLD(baseline || current) full vocab, mean / max + * - top-1 agreement (target argmax == baseline argmax) + * - top-5 agreement (target argmax in baseline top-5) + */ +static int run_ppl_mode(const bench_config *cfg) { + ds4_engine_options opt = { + .model_path = cfg->model_path, + .backend = cfg->backend, + .n_threads = cfg->threads, + .warm_weights = cfg->warm_weights, + .quality = cfg->quality, + .kv_dtype = cfg->kv_dtype, + }; + ds4_engine *engine = NULL; + if (ds4_engine_open(&engine, &opt) != 0) return 1; + + char *text = read_file(cfg->prompt_path); + ds4_tokens prompt = {0}; + ds4_tokenize_text(engine, text, &prompt); + free(text); + + int score_limit = cfg->ppl_max_tokens; + if (score_limit <= 0 || score_limit > prompt.len) score_limit = prompt.len; + if (score_limit < 2) { + fprintf(stderr, "ds4-bench: --ppl-prompt needs at least 2 tokens\n"); + ds4_tokens_free(&prompt); + ds4_engine_close(engine); + return 1; + } + + int ctx_size = score_limit + 16; + ds4_session *session = NULL; + if (ds4_session_create(&session, engine, ctx_size) != 0) { + fprintf(stderr, "ds4-bench: failed to create session\n"); + ds4_tokens_free(&prompt); + ds4_engine_close(engine); + return 1; + } + + const int vocab = ds4_engine_vocab_size(engine); + float *cur_logits = NULL; + float *bl_logits = NULL; + FILE *emit_fp = NULL; + FILE *bl_fp = NULL; + int bl_vocab = 0; + int bl_scored = 0; + bool quality_on = (cfg->quality_emit_path || cfg->quality_baseline_path); + + if (quality_on) { + cur_logits = (float *)malloc((size_t)vocab * sizeof(float)); + if (!cur_logits) { + fprintf(stderr, "ds4-bench: oom for logit scratch\n"); + ds4_session_free(session); ds4_tokens_free(&prompt); ds4_engine_close(engine); + return 1; + } + } + + if (cfg->quality_emit_path) { + emit_fp = fopen(cfg->quality_emit_path, "wb"); + if (!emit_fp) { + fprintf(stderr, "ds4-bench: cannot open --quality-emit '%s': %s\n", + cfg->quality_emit_path, strerror(errno)); + free(cur_logits); ds4_session_free(session); ds4_tokens_free(&prompt); ds4_engine_close(engine); + return 1; + } + uint32_t hdr[3] = { 0, (uint32_t)vocab, 0 /* scored placeholder */ }; + memcpy(&hdr[0], DS4_QDUMP_MAGIC, 4); + fwrite(hdr, sizeof(uint32_t), 3, emit_fp); + } + + if (cfg->quality_baseline_path) { + bl_fp = fopen(cfg->quality_baseline_path, "rb"); + if (!bl_fp) { + fprintf(stderr, "ds4-bench: cannot open --quality-baseline '%s': %s\n", + cfg->quality_baseline_path, strerror(errno)); + if (emit_fp) fclose(emit_fp); + free(cur_logits); ds4_session_free(session); ds4_tokens_free(&prompt); ds4_engine_close(engine); + return 1; + } + uint32_t hdr[3]; + if (fread(hdr, sizeof(uint32_t), 3, bl_fp) != 3 || + memcmp(&hdr[0], DS4_QDUMP_MAGIC, 4) != 0) { + fprintf(stderr, "ds4-bench: '%s' is not a DS4Q baseline dump\n", + cfg->quality_baseline_path); + fclose(bl_fp); if (emit_fp) fclose(emit_fp); + free(cur_logits); ds4_session_free(session); ds4_tokens_free(&prompt); ds4_engine_close(engine); + return 1; + } + bl_vocab = (int)hdr[1]; + bl_scored = (int)hdr[2]; + if (bl_vocab != vocab) { + fprintf(stderr, "ds4-bench: baseline vocab=%d, current vocab=%d (mismatch)\n", + bl_vocab, vocab); + fclose(bl_fp); if (emit_fp) fclose(emit_fp); + free(cur_logits); ds4_session_free(session); ds4_tokens_free(&prompt); ds4_engine_close(engine); + return 1; + } + bl_logits = (float *)malloc((size_t)vocab * sizeof(float)); + if (!bl_logits) { + fprintf(stderr, "ds4-bench: oom for baseline scratch\n"); + fclose(bl_fp); if (emit_fp) fclose(emit_fp); + free(cur_logits); ds4_session_free(session); ds4_tokens_free(&prompt); ds4_engine_close(engine); + return 1; + } + } + + char err[256]; + double nll_sum = 0.0; + int scored = 0; + /* Quality-vs-baseline accumulators (only used when --quality-baseline). */ + double kld_sum = 0.0; + double kld_max = 0.0; + int top1_match = 0; + int top5_match = 0; + int qcompared = 0; + + double t0 = bench_now_sec(); + for (int i = 0; i + 1 < score_limit; i++) { + if (ds4_session_eval(session, prompt.v[i], err, sizeof err) != 0) { + fprintf(stderr, "ds4-bench: ppl eval failed at pos %d: %s\n", i, err); + break; + } + ds4_token_score sc; + if (!ds4_session_token_logprob(session, prompt.v[i + 1], &sc)) { + fprintf(stderr, "ds4-bench: token_logprob failed at pos %d\n", i); + break; + } + if (!isfinite(sc.logprob)) continue; + nll_sum += -(double)sc.logprob; + scored++; + + if (!quality_on) continue; + + int copied = ds4_session_copy_logits(session, cur_logits, vocab); + if (copied != vocab) { + fprintf(stderr, "ds4-bench: copy_logits returned %d, expected %d\n", copied, vocab); + break; + } + + if (emit_fp) { + if (fwrite(cur_logits, sizeof(float), (size_t)vocab, emit_fp) != (size_t)vocab) { + fprintf(stderr, "ds4-bench: quality-emit write failed at pos %d\n", i); + break; + } + } + + if (bl_fp) { + if (qcompared >= bl_scored) { + fprintf(stderr, "ds4-bench: baseline has only %d positions, current at %d\n", + bl_scored, qcompared); + break; + } + if (fread(bl_logits, sizeof(float), (size_t)vocab, bl_fp) != (size_t)vocab) { + fprintf(stderr, "ds4-bench: baseline read failed at pos %d\n", qcompared); + break; + } + /* KLD(baseline || current) = sum_v p_b(v) * (log p_b(v) - log p_c(v)) + * = sum_v p_b(v) * (logit_b - logZ_b - logit_c + logZ_c) */ + double logZb = ds4_log_sum_exp(bl_logits, vocab); + double logZc = ds4_log_sum_exp(cur_logits, vocab); + double kld = 0.0; + for (int v = 0; v < vocab; v++) { + double pb = exp((double)bl_logits[v] - logZb); + if (pb <= 0.0) continue; + double diff = (double)bl_logits[v] - logZb + - (double)cur_logits[v] + logZc; + kld += pb * diff; + } + if (kld < 0.0) kld = 0.0; /* numerical floor */ + kld_sum += kld; + if (kld > kld_max) kld_max = kld; + + /* Top-1/top-5 agreement: target argmax vs baseline top-5 set. */ + int bl_top5[5], cur_top1[1]; + ds4_top_k_indices(bl_logits, vocab, 5, bl_top5); + ds4_top_k_indices(cur_logits, vocab, 1, cur_top1); + if (cur_top1[0] == bl_top5[0]) top1_match++; + for (int j = 0; j < 5; j++) { + if (cur_top1[0] == bl_top5[j]) { top5_match++; break; } + } + qcompared++; + } + } + double elapsed = bench_now_sec() - t0; + + /* Patch scored count into emit header. */ + if (emit_fp) { + uint32_t s = (uint32_t)scored; + fseek(emit_fp, 8, SEEK_SET); + fwrite(&s, sizeof(uint32_t), 1, emit_fp); + fclose(emit_fp); + } + if (bl_fp) fclose(bl_fp); + free(cur_logits); + free(bl_logits); + + const double avg_nll = scored > 0 ? (nll_sum / (double)scored) : 0.0; + const double ppl = scored > 0 ? exp(avg_nll) : 0.0; + const char *kv_name = ds4_kv_dtype_name(cfg->kv_dtype); + fprintf(stdout, + "ds4-bench: PPL teacher-forced kv_cache=%s tokens=%d scored=%d " + "elapsed=%.2fs\n" + "ds4-bench: nll_avg=%.6f ppl=%.6f\n", + kv_name, score_limit, scored, elapsed, avg_nll, ppl); + + if (cfg->quality_emit_path) { + fprintf(stdout, + "ds4-bench: quality-emit wrote %d positions x vocab=%d to %s\n", + scored, vocab, cfg->quality_emit_path); + } + if (cfg->quality_baseline_path && qcompared > 0) { + const double kld_mean = kld_sum / (double)qcompared; + const double top1_pct = 100.0 * (double)top1_match / (double)qcompared; + const double top5_pct = 100.0 * (double)top5_match / (double)qcompared; + fprintf(stdout, + "ds4-bench: quality vs baseline (%s) positions=%d\n" + "ds4-bench: KLD(baseline||current) mean=%.6f nats max=%.6f nats\n" + "ds4-bench: top-1 agreement=%.2f%% top-5 agreement=%.2f%%\n", + cfg->quality_baseline_path, qcompared, + kld_mean, kld_max, top1_pct, top5_pct); + } + + ds4_session_free(session); + ds4_tokens_free(&prompt); + ds4_engine_close(engine); + return 0; +} + int main(int argc, char **argv) { bench_config cfg = parse_options(argc, argv); + if (cfg.ppl_prompt_path) { + return run_ppl_mode(&cfg); + } + log_context_memory(cfg.backend, cfg.ctx_alloc); + log_kv_footprint_compare(cfg.backend, cfg.ctx_alloc, cfg.kv_dtype); ds4_engine_options opt = { .model_path = cfg.model_path, @@ -402,6 +761,7 @@ int main(int argc, char **argv) { .power_percent = cfg.power_percent, .warm_weights = cfg.warm_weights, .quality = cfg.quality, + .kv_dtype = cfg.kv_dtype, }; ds4_engine *engine = NULL; if (ds4_engine_open(&engine, &opt) != 0) return 1; diff --git a/ds4_cli.c b/ds4_cli.c index dfac149b..6cc59698 100644 --- a/ds4_cli.c +++ b/ds4_cli.c @@ -106,6 +106,12 @@ static void usage(FILE *fp) { " CPU helper threads for host-side or reference work.\n" " --quality\n" " Prefer exact kernels where faster approximate paths exist; MTP uses strict verification.\n" + " --kv-cache fp8|turbo3\n" + " KV cache compression simulation. fp8 (default) keeps the historical E4M3\n" + " in-place round trip on the non-RoPE part of each compressed KV row. turbo3\n" + " swaps in a TurboQuant+ Randomized Hadamard rotation + 3-bit Lloyd-Max\n" + " Lloyd-Max quantization with matched-norm L2 correction on the same 64-element\n" + " groups. Storage layout unchanged. CUDA-only at the moment; Metal port deferred.\n" " --dir-steering-file FILE\n" " Load one f32 direction vector per layer for directional steering.\n" " --dir-steering-ffn F\n" @@ -1476,6 +1482,12 @@ static cli_config parse_options(int argc, char **argv) { fprintf(stderr, "ds4: --power must be between 1 and 100\n"); exit(2); } + } else if (!strcmp(arg, "--kv-cache")) { + const char *kv_name = need_arg(&i, argc, argv, arg); + if (!ds4_kv_dtype_from_name(kv_name, &c.engine.kv_dtype)) { + fprintf(stderr, "ds4: unknown --kv-cache value '%s' (expected fp8 or turbo3)\n", kv_name); + exit(1); + } } else if (!strcmp(arg, "--dir-steering-file")) { c.engine.directional_steering_file = need_arg(&i, argc, argv, arg); } else if (!strcmp(arg, "--dir-steering-ffn")) { diff --git a/ds4_cuda.cu b/ds4_cuda.cu index dac9276e..c88b7813 100644 --- a/ds4_cuda.cu +++ b/ds4_cuda.cu @@ -2508,6 +2508,510 @@ __global__ static void fp8_kv_quantize_kernel(float *x, uint32_t n_tok, uint32_t } } +// ========================================================================= +// TurboQuant+ turbo3 KV quality simulation - CUDA kernel. +// ========================================================================= +// +// Sibling of fp8_kv_quantize_kernel above; same in-place [n_tok, head_dim] +// contract; same group-of-64 cadence on the first head_dim - n_rot elements; +// RoPE tail untouched. Per group: +// 1. forward Randomized Hadamard: signs1 (Rademacher) -> 64-point WHT +// -> 1/sqrt(64) -> signs2 (Rademacher). +// 2. block-wide amax + sum-of-squares reduction (warp shfl in two halves). +// 3. amax -> scale into 3-bit Lloyd-Max codebook range for N(0,1). +// 4. per-element quantize to nearest centroid; block-wide reduction on the +// centroid recon L2 norm to compute the matched-norm scale +// (||original|| / ||centroid_recon||) - same MSE-vs-amax trick used in +// reshape_and_cache_flash_turbo3 in atlas/kernels/gb10/common/. +// 5. dequantize centroid * matched-norm scale. +// 6. inverse rotation: signs2 -> WHT -> 1/sqrt(64) -> signs1. +// +// Grid: <<>>. One block per token, one thread per element of the +// 64-element group; the per-token outer loop walks each group sequentially so +// we reuse shared scratch buffers across groups. 64 threads = exactly two +// warps, so block-wide reductions go warp-shfl-first then a 2-element shared +// merge. No __syncwarp() needed - shfl_xor_sync covers the whole warp. +// +// Prior art: TheTom/llama-cpp-turboquant turbo-wht.cu (sign convention + +// matched-norm L2 trick) and Atlas kernels/gb10/common/reshape_and_cache_turbo.cu +// (Lloyd-Max codebook + bound table). See CITATIONS chain in ds4.c near +// dsv4_turbo3_kv_quantize_row_inplace_cpu for the full attribution. + +// Lloyd-Max 8-level codebook for N(0,1). Byte-equivalent to the CPU table +// DS4_TURBO3_CODEBOOK in ds4.c. +__device__ __constant__ float DS4_TURBO3_CODEBOOK_D[8] = { + -2.1520f, -1.3440f, -0.7560f, -0.2451f, 0.2451f, 0.7560f, 1.3440f, 2.1520f +}; +__device__ __constant__ float DS4_TURBO3_BOUNDS_D[7] = { + -1.748f, -1.050f, -0.501f, 0.0f, 0.501f, 1.050f, 1.748f +}; + +// Two-sided Rademacher signs for the 64-point WHT. Byte-equivalent to the CPU +// tables DS4_TURBO_SIGNS{1,2}_64 in ds4.c - see that file for the seed=42 / +// seed=142 Pythons random.Random Bernoulli(0.5) recipe. +__device__ __constant__ float DS4_TURBO_SIGNS1_64_D[64] = { + +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +1.0f, + -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, + +1.0f, +1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +}; +__device__ __constant__ float DS4_TURBO_SIGNS2_64_D[64] = { + +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, + +1.0f, +1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +}; + +// FP8 E4M3 max representable - used to clamp the matched-norm scale so a +// future Metal port can pack the per-group scale into one FP8 byte without an +// extra renormalization pass. +#define DS4_FP8_E4M3_MAX_D 448.0f +#define DS4_TURBO3_MAX_D 2.1520f + +// In-shared-memory 64-element WHT butterfly. Caller owns the buffer + sync. +// Identical structure to dsv4_turbo3_wht64_inplace_cpu in ds4.c - but +// parallelized across the 64 threads of the block by halving the active thread +// set at each butterfly stride. Compatible with VRAM access pattern: every +// thread reads/writes a unique element each stride. +__device__ __forceinline__ void wht64_block(float *v, uint32_t tid) { + for (uint32_t stride = 1; stride < 64; stride <<= 1) { + // Half the threads do the butterfly write; tid bit at `stride` selects + // which side of the pair the thread owns. + uint32_t mate = tid ^ stride; + bool is_low = (tid & stride) == 0; + float self = v[tid]; + __syncthreads(); + float other = v[mate]; + float out = is_low ? (self + other) : (other - self); + __syncthreads(); + v[tid] = out; + __syncthreads(); + } +} + +// Nearest-centroid 3-bit index lookup against DS4_TURBO3_BOUNDS_D. +__device__ __forceinline__ unsigned int turbo3_quant_idx(float x) { + unsigned int idx; + if (x >= DS4_TURBO3_BOUNDS_D[3]) { + idx = 4; + if (x >= DS4_TURBO3_BOUNDS_D[5]) { idx = 6; if (x >= DS4_TURBO3_BOUNDS_D[6]) idx = 7; } + else if (x >= DS4_TURBO3_BOUNDS_D[4]) idx = 5; + } else { + idx = 0; + if (x >= DS4_TURBO3_BOUNDS_D[1]) { idx = 2; if (x >= DS4_TURBO3_BOUNDS_D[2]) idx = 3; } + else if (x >= DS4_TURBO3_BOUNDS_D[0]) idx = 1; + } + return idx; +} + +// Block-wide max via warp-shuffle then shared-memory cross-warp merge. 64 +// threads = 2 warps; one shared slot per warp. Returns the broadcast value. +__device__ __forceinline__ float block_max64(float v, uint32_t tid) { + __shared__ float warp_max[2]; + for (int off = 16; off > 0; off >>= 1) v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFFu, v, off)); + if ((tid & 31u) == 0u) warp_max[tid >> 5] = v; + __syncthreads(); + float r = fmaxf(warp_max[0], warp_max[1]); + return r; +} + +__device__ __forceinline__ float block_sum64(float v, uint32_t tid) { + __shared__ float warp_sum[2]; + for (int off = 16; off > 0; off >>= 1) v += __shfl_xor_sync(0xFFFFFFFFu, v, off); + if ((tid & 31u) == 0u) warp_sum[tid >> 5] = v; + __syncthreads(); + float r = warp_sum[0] + warp_sum[1]; + return r; +} + +// signs_on=1 (default): canonical Randomized Hadamard. signs_on=0: plain WHT +// for A/B diagnostics - strictly weaker, kept callable via DS4_TURBO_NO_SIGNS +// env var by the host wrapper. +__global__ static void turbo3_kv_quantize_kernel(float *x, uint32_t n_tok, uint32_t head_dim, uint32_t n_rot, int signs_on) { + const uint32_t row = blockIdx.x; + if (row >= n_tok) return; + const uint32_t tid = threadIdx.x; + const uint32_t n_nope = head_dim - n_rot; + float *xr = x + (uint64_t)row * head_dim; + + __shared__ float buf[64]; + // 1/sqrt(64); spelled via sqrtf so the literal divide is unambiguously + // float-by-float for static analysis tools. + const float inv_sqrt_n = rsqrtf(64.0f); + + for (uint32_t off = 0; off < n_nope; off += 64) { + // Load: one element per thread. If we ever change n_nope to not be + // 64-aligned the OOB path mirrors fp8_kv_quantize_kernel's tail (zero + // pad) - for DS4_N_HEAD_DIM=512, DS4_N_ROT=64, n_nope=448 is exactly + // 7 groups of 64 so the tail logic is unused here. + float v = (off + tid < n_nope) ? xr[off + tid] : 0.0f; + if (signs_on) v *= DS4_TURBO_SIGNS1_64_D[tid]; + buf[tid] = v; + __syncthreads(); + + // 1. forward WHT + normalize + wht64_block(buf, tid); + float rotated = buf[tid] * inv_sqrt_n; + if (signs_on) rotated *= DS4_TURBO_SIGNS2_64_D[tid]; + + // 2. block-wide amax and L2 norm of rotated group + float amax = block_max64(fabsf(rotated), tid); + float norm_sq = block_sum64(rotated * rotated, tid); + float k_inv = (amax > 1e-12f) ? (DS4_TURBO3_MAX_D / amax) : 1.0f; + + // 3. quantize, reduce centroid recon L2 norm, derive matched-norm scale + unsigned int idx = turbo3_quant_idx(rotated * k_inv); + float centroid = DS4_TURBO3_CODEBOOK_D[idx]; + float recon_sq = block_sum64(centroid * centroid, tid); + float recon_norm = sqrtf(recon_sq); + float scale = (recon_norm > 1e-10f) ? (sqrtf(norm_sq) / recon_norm) : (amax / DS4_TURBO3_MAX_D); + if (scale > DS4_FP8_E4M3_MAX_D) scale = DS4_FP8_E4M3_MAX_D; + + // 4. dequant in rotated basis + float dequant = centroid * scale; + + // 5. inverse rotation: signs2 -> WHT -> 1/sqrt(64) -> signs1 + if (signs_on) dequant *= DS4_TURBO_SIGNS2_64_D[tid]; + __syncthreads(); + buf[tid] = dequant; + __syncthreads(); + wht64_block(buf, tid); + float final_v = buf[tid] * inv_sqrt_n; + if (signs_on) final_v *= DS4_TURBO_SIGNS1_64_D[tid]; + + if (off + tid < n_nope) xr[off + tid] = final_v; + __syncthreads(); + } +} + +// ========================================================================= +// Packed turbo3 cache storage + inline dequant for attention V-load. +// ========================================================================= +// +// Layout per cache row (head_dim=512, n_rot=64, GROUP_SIZE=64): +// bytes 0..167 : packed 3-bit indices (n_nope * 3/8 = 168 bytes) +// inside each group of 64 elements: 8 sub-chunks of 3 bytes, +// each holding 8 indices via the canonical pattern +// b0 = i0|(i1<<3)|(i2<<6), b1 = (i2>>2)|(i3<<1)|(i4<<4)|(i5<<7), +// b2 = (i5>>1)|(i6<<2)|(i7<<5). +// bytes 168..174 : 7 FP8 E4M3 matched-norm scales, one per 64-element group. +// bytes 175..430 : 64 raw little-endian floats (RoPE tail). +// total 431 bytes per row vs 2048 bytes for fp8 float-sim (4.75x). +// +// Stored values are in the ORIGINAL basis (the pack kernel applies the inverse +// rotation conceptually by storing centroid*scale and letting the read side +// run a per-group iWHT-with-signs). Every existing reader sees floats from +// the same per-element distribution it would have read in the fp8 float-sim +// path, modulo the FP8 group scale's ~12% precision. +// +// We stay on GROUP_SIZE=64 (the cadence the float-sim quantizer above uses) +// so the matched-norm scale arithmetic and `--logprob-vectors` regression +// remain valid. Upstream TurboQuant+ uses GROUP_SIZE=16 for finer scale +// granularity but that would invalidate the existing pack-vs-float-sim +// equivalence the test suite assumes. + +// Forward declaration: defined in the host-emitted constant tables below. +#define TURBO3_GROUP_SIZE 64 +#define TURBO3_SUBCHUNKS_PER_GROUP 8 // 64 elems / 8 per sub-chunk +#define TURBO3_DATA_BYTES_PER_GROUP 24 // 8 sub-chunks * 3 bytes + +// Forward turbo3 group write - runs ONE thread (caller passes its own scratch). +// `rotated[64]` holds the post-WHT, post-signs2 floats. Writes 24 data bytes + +// 1 FP8 scale byte to `data_out` / `scale_out`. +__device__ __forceinline__ void turbo3_pack_group64_device( + unsigned char *data_out, + unsigned char *scale_out, + const float *rotated) { + float amax = 0.0f, norm_sq = 0.0f; + #pragma unroll + for (int i = 0; i < 64; i++) { + const float v = rotated[i]; + const float av = fabsf(v); + if (av > amax) amax = av; + norm_sq += v * v; + } + const float k_inv = (amax > 1e-12f) ? (DS4_TURBO3_MAX_D / amax) : 1.0f; + + unsigned int idx[64]; + float recon_sq = 0.0f; + #pragma unroll + for (int i = 0; i < 64; i++) { + idx[i] = turbo3_quant_idx(rotated[i] * k_inv); + const float c = DS4_TURBO3_CODEBOOK_D[idx[i]]; + recon_sq += c * c; + } + const float recon_norm = sqrtf(recon_sq); + float scale = (recon_norm > 1e-10f) ? (sqrtf(norm_sq) / recon_norm) + : (amax / DS4_TURBO3_MAX_D); + if (scale > DS4_FP8_E4M3_MAX_D) scale = DS4_FP8_E4M3_MAX_D; + + // FP8 E4M3 encode. Use the CUDA-runtime portable cvt helper rather than + // the sm_89+ `cvt.rn.satfinite.e4m3x2.f32` PTX directly so the build + // works at the default cuda-spark arch. Match `dsv4_e4m3fn_dequant_dev` + // semantics: non-negative input (matched_scale is always >=0) saturated + // to E4M3 max=448. + float scale_clamped = scale; + if (scale_clamped > 448.0f) scale_clamped = 448.0f; + if (scale_clamped < 0.0f) scale_clamped = 0.0f; + const __nv_fp8_storage_t s = __nv_cvt_float_to_fp8( + scale_clamped, __NV_SATFINITE, __NV_E4M3); + *scale_out = (unsigned char)s; + + // Pack 64 indices into 24 bytes: 8 sub-chunks of 3 bytes each. + #pragma unroll + for (int chunk = 0; chunk < 8; chunk++) { + const unsigned int *p = &idx[chunk * 8]; + unsigned char *b = &data_out[chunk * 3]; + b[0] = (unsigned char)((p[0]) | (p[1] << 3) | ((p[2] & 0x3) << 6)); + b[1] = (unsigned char)((p[2] >> 2) | (p[3] << 1) | (p[4] << 4) | ((p[5] & 0x1) << 7)); + b[2] = (unsigned char)((p[5] >> 1) | (p[6] << 2) | (p[7] << 5)); + } +} + +// Inverse rotation in registers - 64-element iWHT-with-signs butterfly. +// Atlas's wht256_warp_bf16 is the bf16 warp-distributed version; here we run +// the whole 64-element transform inside one thread because the attention +// inner loop is already serial on the K-row (one thread reads its assigned +// element range from the dequanted register buffer). 64 floats fit in +// registers easily (256 bytes per thread, well under the 64KB/thread limit). +__device__ __forceinline__ void turbo3_iwht64_inplace_device(float *buf) { + #pragma unroll + for (uint32_t stride = 1; stride < 64; stride <<= 1) { + #pragma unroll + for (uint32_t base = 0; base < 64; base += 2u * stride) { + #pragma unroll + for (uint32_t i = 0; i < stride; i++) { + const float a = buf[base + i]; + const float b = buf[base + stride + i]; + buf[base + i] = a + b; + buf[base + stride + i] = a - b; + } + } + } +} + +// One-shot group dequant: takes a packed cache row pointer, group index, +// and writes 64 original-basis floats into `out[64]`. This is the helper +// every attention kernel inlines per group it touches. +// +// `signs_on` is propagated as a uniform per-CTA value (compiler will +// constant-fold it most of the time). +// +// Cost per call (per thread): 24 byte loads + 1 FP8 byte load + 1 cvt + 64 +// LUT lookups + 64 muls + 6-stage 64-element butterfly + signs1 mul = roughly +// 200 fp32 ops + 256 mem ops. For comparison the float-sim path is 64 float +// loads = 64 mem ops + 0 compute. So we trade ~4x memory traffic for ~200 +// compute ops per group - favorable when the K row is hot in cache, which +// it isn't on long SWA scans where the trade reverses to ~25x less BW for +// ~3.5x more compute. +__device__ __forceinline__ void turbo3_dequant_group64_device( + float *out64, + const unsigned char *row_base, + uint32_t group_idx, + uint32_t n_nope, + int signs_on) { + const unsigned char *data_slot = row_base + (uint64_t)group_idx * TURBO3_DATA_BYTES_PER_GROUP; + const unsigned long data_bytes = (unsigned long)n_nope * 3u / 8u; + const unsigned char scale_byte = row_base[data_bytes + group_idx]; + + // FP8 E4M3 -> f32 via the hardware cvt. Same primitive used elsewhere + // in ds4_cuda.cu for FP8 scale dequant. + const float scale = __half2float(__nv_cvt_fp8_to_halfraw(scale_byte, __NV_E4M3)); + + // Pre-scaled centroid cache. Hoists `centroid[c] * scale` out of the + // per-element loop so we do 8 multiplies once per group instead of 64 + // multiplies per group. Pattern from TheTom/llama-cpp-turboquant + // fattn-vec.cuh "Per-block scaled-centroid cache". + float sc[8]; + #pragma unroll + for (int c = 0; c < 8; c++) sc[c] = DS4_TURBO3_CODEBOOK_D[c] * scale; + + // Unpack 24 bytes -> 64 rotated-basis floats via the pre-scaled LUT. + #pragma unroll + for (int chunk = 0; chunk < 8; chunk++) { + const unsigned char *b = data_slot + chunk * 3; + float *o = out64 + chunk * 8; + const unsigned int b0 = b[0], b1 = b[1], b2 = b[2]; + o[0] = sc[(b0) & 0x7]; + o[1] = sc[(b0 >> 3) & 0x7]; + o[2] = sc[((b0 >> 6) | (b1<<2)) & 0x7]; + o[3] = sc[(b1 >> 1) & 0x7]; + o[4] = sc[(b1 >> 4) & 0x7]; + o[5] = sc[((b1 >> 7) | (b2<<1)) & 0x7]; + o[6] = sc[(b2 >> 2) & 0x7]; + o[7] = sc[(b2 >> 5) & 0x7]; + } + + // Inverse rotation: signs2 -> WHT -> 1/sqrt(64) -> signs1. + if (signs_on) { + #pragma unroll + for (int i = 0; i < 64; i++) out64[i] *= DS4_TURBO_SIGNS2_64_D[i]; + } + turbo3_iwht64_inplace_device(out64); + const float inv_sqrt_n = rsqrtf(64.0f); + #pragma unroll + for (int i = 0; i < 64; i++) out64[i] *= inv_sqrt_n; + if (signs_on) { + #pragma unroll + for (int i = 0; i < 64; i++) out64[i] *= DS4_TURBO_SIGNS1_64_D[i]; + } +} + +// Ring-aware batch pack kernel. Sibling of store_raw_kv_batch_kernel - same +// (pos0 + t) % raw_cap ring write semantics, but writes packed turbo3 bytes +// per row instead of f16-rounded floats. Grid: <<>>. +extern "C" __global__ void turbo3_kv_pack_batch_kernel( + const float * __restrict__ src, + unsigned char * __restrict__ raw, + uint32_t raw_cap, + uint32_t pos0, + uint32_t n_tokens, + uint32_t head_dim, + uint32_t n_rot, + uint64_t row_bytes, + int signs_on) { + const uint32_t t = blockIdx.x; + if (t >= n_tokens) return; + const uint32_t tid = threadIdx.x; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const uint32_t row = (pos0 + t) % raw_cap; + const float *src_row = src + (uint64_t)t * head_dim; + unsigned char *dst_row = raw + (uint64_t)row * row_bytes; + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const float inv_sqrt_n = rsqrtf(64.0f); + + if (tid < n_groups) { + float buf[64]; + const float *gs = src_row + (uint64_t)tid * TURBO3_GROUP_SIZE; + if (signs_on) { + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] = gs[i] * DS4_TURBO_SIGNS1_64_D[i]; + } else { + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] = gs[i]; + } + turbo3_iwht64_inplace_device(buf); + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] *= inv_sqrt_n; + if (signs_on) { + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] *= DS4_TURBO_SIGNS2_64_D[i]; + } + unsigned char *data_slot = dst_row + (uint64_t)tid * TURBO3_DATA_BYTES_PER_GROUP; + unsigned char *scale_slot = dst_row + data_bytes + (uint64_t)tid; + turbo3_pack_group64_device(data_slot, scale_slot, buf); + } + + if (tid == 0 && n_rot > 0) { + const uint64_t scale_bytes = (uint64_t)n_groups; + unsigned char *rope_slot = dst_row + data_bytes + scale_bytes; + memcpy(rope_slot, src_row + n_nope, (size_t)n_rot * sizeof(float)); + } +} + +// Dequant kernel: reads `n_rows` packed turbo3 rows from `src` (each row is +// `src_row_bytes` long) and writes original-basis floats into `dst` at the +// natural `[n_rows, head_dim]` float layout the existing attention kernels +// expect. Grid: <<>>. Thread `tid` in {0..6} handles its group; +// thread 0 also copies the RoPE tail. +// +// Used by the raw-cache decompress-to-scratch pass when an attention path +// has no inline-dequant sibling. The inline-dequant kernels below skip this +// hop entirely and read packed rows directly. +extern "C" __global__ void turbo3_kv_dequant_to_scratch_kernel( + const unsigned char * __restrict__ src, + float * __restrict__ dst, + uint32_t n_rows, + uint32_t head_dim, + uint32_t n_rot, + uint64_t src_row_bytes, + int signs_on) { + const uint32_t row = blockIdx.x; + if (row >= n_rows) return; + const uint32_t tid = threadIdx.x; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const unsigned char *src_row = src + (uint64_t)row * src_row_bytes; + float *dst_row = dst + (uint64_t)row * head_dim; + + if (tid < n_groups) { + float buf[64]; + turbo3_dequant_group64_device(buf, src_row, tid, n_nope, signs_on); + float *gd = dst_row + (uint64_t)tid * TURBO3_GROUP_SIZE; + #pragma unroll + for (int i = 0; i < 64; i++) gd[i] = buf[i]; + } + + if (tid == 0 && n_rot > 0) { + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const uint64_t scale_bytes = (uint64_t)n_groups; + const unsigned char *rope_slot = src_row + data_bytes + scale_bytes; + memcpy(dst_row + n_nope, rope_slot, (size_t)n_rot * sizeof(float)); + } +} + +// Pack kernel: reads a [n_tok, head_dim] float tensor (the post-RoPE KV +// projection output) and writes [n_tok * ds4_kv_row_bytes(...)] packed bytes. +// Grid: <<>>. One thread per group of 64 elements. +// +// First 7 threads handle the 7 packed groups (one each). Remaining 57 threads +// idle for that phase, then thread 0 copies the RoPE tail (64 floats) into the +// trailing bytes via a uint4 strided write. +extern "C" __global__ void turbo3_kv_pack_kernel( + const float * __restrict__ src, + unsigned char * __restrict__ dst, + uint32_t n_tok, + uint32_t head_dim, + uint32_t n_rot, + uint64_t dst_row_bytes, + int signs_on) { + const uint32_t row = blockIdx.x; + if (row >= n_tok) return; + const uint32_t tid = threadIdx.x; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const float *src_row = src + (uint64_t)row * head_dim; + unsigned char *dst_row = dst + (uint64_t)row * dst_row_bytes; + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const float inv_sqrt_n = rsqrtf(64.0f); + + if (tid < n_groups) { + // Per-group forward rotation in registers. 64 floats per thread. + float buf[64]; + const float *gs = src_row + (uint64_t)tid * TURBO3_GROUP_SIZE; + if (signs_on) { + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] = gs[i] * DS4_TURBO_SIGNS1_64_D[i]; + } else { + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] = gs[i]; + } + // Same WHT-in-registers butterfly used by the iWHT helper above - + // butterfly is self-inverse, so the same body runs for forward. + turbo3_iwht64_inplace_device(buf); + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] *= inv_sqrt_n; + if (signs_on) { + #pragma unroll + for (int i = 0; i < 64; i++) buf[i] *= DS4_TURBO_SIGNS2_64_D[i]; + } + + // Pack into data + scale. + unsigned char *data_slot = dst_row + (uint64_t)tid * TURBO3_DATA_BYTES_PER_GROUP; + unsigned char *scale_slot = dst_row + data_bytes + (uint64_t)tid; + turbo3_pack_group64_device(data_slot, scale_slot, buf); + } + + // RoPE tail copy - one thread handles 64 floats via 4 strided uint4 writes. + // (Float tail starts at offset data_bytes + n_groups.) + if (tid == 0 && n_rot > 0) { + const uint64_t scale_bytes = (uint64_t)n_groups; + unsigned char *rope_slot = dst_row + data_bytes + scale_bytes; + memcpy(rope_slot, src_row + n_nope, (size_t)n_rot * sizeof(float)); + } +} + __global__ static void indexer_hadamard_fp4_kernel(float *x, uint32_t n_rows, uint32_t head_dim) { uint32_t row = blockIdx.x; uint32_t tid = threadIdx.x; @@ -2560,6 +3064,141 @@ __global__ static void store_raw_kv_batch_kernel(float *raw, const float *kv, ui raw[(uint64_t)row * head_dim + d] = __half2float(__float2half(kv[(uint64_t)t * head_dim + d])); } +// Unaligned f32 load: turbo3 rows are 431 bytes (not 4-aligned), so +// the RoPE tail at byte offset 175 in each row can't be dereferenced +// as `float *`. memcpy compiles to byte-wise loads which work at any +// alignment. Compiler optimizes to a uint32 load when alignment is +// known at compile time. +__device__ __forceinline__ float turbo3_load_unaligned_f32(const unsigned char *p) { + float f; + memcpy(&f, p, sizeof(float)); + return f; +} + +// Inline-dequant turbo3 sibling of attention_prefill_raw_kernel. Reads +// packed turbo3 bytes from raw_kv_bytes directly instead of going through +// the turbo3_kv_dequant_to_scratch_kernel intermediate. +// +// Footprint per launch: zero scratch dequant memory; saves ~431 B/row +// vs the float sim path (head_dim=512: 2048 B float -> 431 B packed). +// +// K-dot: per-thread `float group[64]` register-resident, reused 7 times +// per row. RoPE tail (64 floats) read direct from byte buffer. +// +// V-acc: cooperative shmem dequant - first 7 threads each handle one +// group (64 floats), threads 7..70 copy the 64 RoPE floats. All 128 +// threads then accumulate strided d's into per-thread register acc[4]. +// +// Metal-portability: pure scalar arithmetic plus turbo3_dequant_group64 +// (FP8 byte cvt + LUT lookup + butterfly + signs). A Metal sibling +// mirrors the byte-extract + LUT lookup verbatim; no warp shuffles, +// atomics, or cooperative groups used here. +__global__ static void attention_prefill_raw_turbo3_kernel( + float *heads, + const float *sinks, + const float *q, + const unsigned char *raw_kv_bytes, + uint64_t row_bytes, + uint32_t n_tokens, + uint32_t window, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + int signs_on) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || h >= n_head) return; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const uint64_t scale_bytes = (uint64_t)n_groups; + uint32_t raw_count = t + 1 < window ? t + 1 : window; + uint32_t raw_start = t + 1 - raw_count; + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + __shared__ float scores[256]; + __shared__ float partial[128]; + __shared__ float max_s; + __shared__ float denom; + __shared__ float kv_scratch[512]; // head_dim cap = 512 for DSV4 + float scale = rsqrtf((float)head_dim); + float local_max = sinks[h]; + __syncthreads(); + + // === K-dot === + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)(raw_start + r) * row_bytes; + float dot = 0.0f; + float group[64]; + for (uint32_t g = 0; g < n_groups; g++) { + turbo3_dequant_group64_device(group, kv_bytes, g, n_nope, signs_on); + #pragma unroll + for (uint32_t i = 0; i < 64; i++) { + dot += qh[g * 64 + i] * group[i]; + } + } + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + for (uint32_t d = 0; d < n_rot; d++) { + dot += qh[n_nope + d] * turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + + // === softmax === + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + if (threadIdx.x == 0) { + float den = expf(sinks[h] - max_s); + for (uint32_t r = 0; r < raw_count; r++) { + scores[r] = expf(scores[r] - max_s); + den += scores[r]; + } + denom = den; + } + __syncthreads(); + + // === V-acc === + // head_dim=512, blockDim=128 -> each thread handles 4 d's. + float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (uint32_t r = 0; r < raw_count; r++) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)(raw_start + r) * row_bytes; + if (threadIdx.x < n_groups) { + float buf[64]; + turbo3_dequant_group64_device(buf, kv_bytes, threadIdx.x, n_nope, signs_on); + float *gd = kv_scratch + (uint64_t)threadIdx.x * TURBO3_GROUP_SIZE; + #pragma unroll + for (uint32_t i = 0; i < 64; i++) gd[i] = buf[i]; + } + // RoPE tail: 64 threads in window [n_groups, n_groups+n_rot) + // copy one float each. Below the V-acc fan-out, so safe. + if (threadIdx.x >= n_groups && threadIdx.x < n_groups + n_rot) { + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + uint32_t d = threadIdx.x - n_groups; + kv_scratch[n_nope + d] = turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + __syncthreads(); + float s = scores[r]; + #pragma unroll + for (uint32_t i = 0; i < 4; i++) { + uint32_t d = i * blockDim.x + threadIdx.x; + acc[i] += kv_scratch[d] * s; + } + __syncthreads(); + } + float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; + #pragma unroll + for (uint32_t i = 0; i < 4; i++) { + uint32_t d = i * blockDim.x + threadIdx.x; + if (d < head_dim) oh[d] = acc[i] / denom; + } +} + __global__ static void attention_prefill_raw_kernel( float *heads, const float *sinks, @@ -2863,40 +3502,67 @@ __global__ static void attention_unpack_group_low_kernel( low[(uint64_t)t * low_dim + (uint64_t)g * rank + r] = tmp[gid]; } -__global__ static void attention_decode_mixed_kernel( - float *heads, - const float *sinks, - const float *q, - const float *raw_kv, - const float *comp_kv, - const float *comp_mask, - uint32_t use_comp_mask, - uint32_t n_tokens, - uint32_t pos0, - uint32_t n_raw, - uint32_t raw_cap, - uint32_t raw_start, - uint32_t n_comp, - uint32_t window, - uint32_t ratio, - uint32_t n_head, - uint32_t head_dim) { +// Inline-dequant turbo3 sibling of attention_decode_mixed_kernel for the +// simple per-row path (n_tokens == 1 || visible_comp == 0). Reads packed +// turbo3 bytes from raw_kv_bytes directly; the comp_kv path stays float +// because the compressed cache is not turbo3-quantized. +// +// Decode-token generation always lands here for n_tokens=1 + turbo3 + +// view_dispatch callers, eliminating a full-cap dequant-to-scratch hop per +// layer per token. +// +// The 8-lane warp-shuffle path is NOT implemented here - host dispatcher +// falls back to attention_decode_mixed_kernel + the existing +// dequant-to-scratch when use_comp_mask + n_tokens>1. +// +// Metal portability: same pure-scalar template as +// attention_prefill_raw_turbo3_kernel - see that kernel's preamble. +__global__ static void attention_decode_mixed_turbo3_kernel( + float *heads, + const float *sinks, + const float *q, + const unsigned char *raw_kv_bytes, + uint64_t row_bytes, + const float *comp_kv, + const float *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + int signs_on) { uint32_t t = blockIdx.x; uint32_t h = blockIdx.y; if (t >= n_tokens || h >= n_head) return; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const uint64_t scale_bytes = (uint64_t)n_groups; const bool single_all = (n_tokens == 1u && ratio == 0u); uint32_t qpos = pos0 + t; uint32_t first_raw_pos = pos0 + n_tokens - n_raw; uint32_t visible_comp = single_all ? n_comp : (n_comp ? (qpos + 1u) / ratio : 0u); if (visible_comp > n_comp) visible_comp = n_comp; const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; - __shared__ float scores[DS4_CUDA_ATTENTION_SCORE_CAP]; + // Smaller scores cap than DS4_CUDA_ATTENTION_SCORE_CAP: leaves shmem + // headroom for the V-acc tile (32KB at ROWS_PER_TILE=16). + // Launcher gates n_comp + raw_count <= 2048. + constexpr uint32_t TURBO3_DECODE_SCORE_CAP = 2048u; + __shared__ float scores[TURBO3_DECODE_SCORE_CAP]; __shared__ uint32_t raw_rows[256]; __shared__ float partial[256]; __shared__ float max_s; __shared__ float denom; __shared__ uint32_t raw_count; __shared__ uint32_t raw_first_idx; + __shared__ float kv_scratch[512]; // used by the slow (non-(512,256)) generic path float scale = rsqrtf((float)head_dim); if (threadIdx.x == 0) { raw_count = 0; @@ -2927,59 +3593,41 @@ __global__ static void attention_decode_mixed_kernel( __syncthreads(); uint32_t n_score = raw_count + visible_comp; float local_max = sinks[h]; - if (visible_comp == 0 || n_tokens == 1u) { - for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { - const float *kvrow = raw_kv + (uint64_t)raw_rows[r] * head_dim; - float dot = 0.0f; - for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; - scores[r] = dot * scale; - local_max = fmaxf(local_max, scores[r]); - } - for (uint32_t c = threadIdx.x; c < visible_comp; c += blockDim.x) { - float add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; - float s = -INFINITY; - if (add > -1.0e20f) { - const float *kvrow = comp_kv + (uint64_t)c * head_dim; - float dot = 0.0f; - for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; - s = dot * scale + add; + + // K-dot: per-thread per-row, inline turbo3 dequant. + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[r] * row_bytes; + float dot = 0.0f; + float group[64]; + for (uint32_t g = 0; g < n_groups; g++) { + turbo3_dequant_group64_device(group, kv_bytes, g, n_nope, signs_on); + #pragma unroll + for (uint32_t i = 0; i < 64; i++) { + dot += qh[g * 64 + i] * group[i]; } - scores[raw_count + c] = s; - local_max = fmaxf(local_max, s); } - } else { - uint32_t qlane = threadIdx.x & 7u; - uint32_t qgroup = threadIdx.x >> 3u; - for (uint32_t row0 = 0; row0 < n_score; row0 += 32u) { - uint32_t row = row0 + qgroup; - if (row < n_score) { - float add = 0.0f; - const float *kvrow = NULL; - if (row < raw_count) { - kvrow = raw_kv + (uint64_t)raw_rows[row] * head_dim; - } else { - uint32_t c = row - raw_count; - add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; - if (add > -1.0e20f) kvrow = comp_kv + (uint64_t)c * head_dim; - } - float s = -INFINITY; - if (kvrow) { - float dot = 0.0f; - for (uint32_t d = qlane; d < head_dim; d += 8u) dot += qh[d] * kvrow[d]; - const uint32_t mask = 0xffu << (threadIdx.x & 24u); - for (uint32_t off = 4u; off > 0u; off >>= 1u) { - dot += __shfl_down_sync(mask, dot, off, 8); - } - s = dot * scale + add; - } - if (qlane == 0) scores[row] = s; - } + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + for (uint32_t d = 0; d < n_rot; d++) { + dot += qh[n_nope + d] * turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); } - __syncthreads(); - for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { - local_max = fmaxf(local_max, scores[i]); + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + // comp_kv path: unchanged, float input. + for (uint32_t c = threadIdx.x; c < visible_comp; c += blockDim.x) { + float add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; + float s = -INFINITY; + if (add > -1.0e20f) { + const float *kvrow = comp_kv + (uint64_t)c * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + s = dot * scale + add; } + scores[raw_count + c] = s; + local_max = fmaxf(local_max, s); } + + // Softmax reduction (identical to fp8 path). partial[threadIdx.x] = local_max; __syncthreads(); for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { @@ -3001,17 +3649,66 @@ __global__ static void attention_decode_mixed_kernel( } if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); __syncthreads(); + + // V-acc: tile-batched cooperative shmem dequant + N-row V-acc per + // sync. ROWS_PER_TILE=16 spreads 16*n_groups=112 group dequants + + // 16*n_rot=1024 RoPE bytes across 256 threads in one pass, so 256 + // threads are busy concurrently (vs the per-row pattern's 71/256 + // threads). Reduces __syncthreads count from raw_count (~200) to + // 2*ceil(raw_count/16) (~26), and improves thread utilization 3-4x + // in the dequant phase. + // + // Tile shmem: 16*512 floats = 32KB; plus the per-CTA 4.6KB of + // scores/partial/raw_rows/etc + the existing 2KB kv_scratch row + // buffer is unused on the tile path. Total ~37KB shmem per CTA, + // well within Blackwell's per-CTA budget. float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; if (head_dim == 512u && blockDim.x == 256u) { + constexpr uint32_t ROWS_PER_TILE = 16u; + __shared__ float kv_tile[ROWS_PER_TILE * 512u]; uint32_t d0 = threadIdx.x; uint32_t d1 = d0 + 256u; float acc0 = 0.0f; float acc1 = 0.0f; - for (uint32_t r = 0; r < raw_count; r++) { - float s = scores[r]; - const float *kv = raw_kv + (uint64_t)raw_rows[r] * head_dim; - acc0 += kv[d0] * s; - acc1 += kv[d1] * s; + for (uint32_t r_base = 0; r_base < raw_count; r_base += ROWS_PER_TILE) { + uint32_t tile_rows = raw_count - r_base; + if (tile_rows > ROWS_PER_TILE) tile_rows = ROWS_PER_TILE; + // Cooperative dequant: each thread handles up to 1 group dequant + // (groups 0..tile_rows*n_groups-1) AND multiple RoPE bytes. + uint32_t total_groups = tile_rows * n_groups; // <= 16*7 = 112 + if (threadIdx.x < total_groups) { + uint32_t tr = threadIdx.x / n_groups; + uint32_t g = threadIdx.x % n_groups; + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[r_base + tr] * row_bytes; + float buf[64]; + turbo3_dequant_group64_device(buf, kv_bytes, g, n_nope, signs_on); + float *gd = kv_tile + (uint64_t)tr * 512u + (uint64_t)g * TURBO3_GROUP_SIZE; + #pragma unroll + for (uint32_t i = 0; i < 64; i++) gd[i] = buf[i]; + } + // RoPE: tile_rows*n_rot floats = up to 16*64 = 1024 floats. + // 256 threads * 4 each fills exactly when ROWS_PER_TILE=16. + uint32_t total_rope = tile_rows * n_rot; + for (uint32_t idx = threadIdx.x; idx < total_rope; idx += blockDim.x) { + uint32_t tr = idx / n_rot; + uint32_t d = idx % n_rot; + const unsigned char *rope_tail = raw_kv_bytes + + (uint64_t)raw_rows[r_base + tr] * row_bytes + + data_bytes + scale_bytes; + kv_tile[(uint64_t)tr * 512u + n_nope + d] = + turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + __syncthreads(); + // V-acc N rows from tile + #pragma unroll 4 + for (uint32_t i = 0; i < ROWS_PER_TILE; i++) { + if (i < tile_rows) { + float s = scores[r_base + i]; + acc0 += kv_tile[(uint64_t)i * 512u + d0] * s; + acc1 += kv_tile[(uint64_t)i * 512u + d1] * s; + } + } + __syncthreads(); } for (uint32_t c = 0; c < visible_comp; c++) { float s = scores[raw_count + c]; @@ -3022,7 +3719,207 @@ __global__ static void attention_decode_mixed_kernel( oh[d0] = acc0 / denom; oh[d1] = acc1 / denom; } else { - for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) { + float acc_d[8]; + #pragma unroll + for (uint32_t i = 0; i < 8; i++) acc_d[i] = 0.0f; + const uint32_t ds_per_thread = (head_dim + blockDim.x - 1u) / blockDim.x; + for (uint32_t r = 0; r < raw_count; r++) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[r] * row_bytes; + if (threadIdx.x < n_groups) { + float buf[64]; + turbo3_dequant_group64_device(buf, kv_bytes, threadIdx.x, n_nope, signs_on); + float *gd = kv_scratch + (uint64_t)threadIdx.x * TURBO3_GROUP_SIZE; + #pragma unroll + for (uint32_t i = 0; i < 64; i++) gd[i] = buf[i]; + } + if (threadIdx.x >= n_groups && threadIdx.x < n_groups + n_rot) { + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + uint32_t d = threadIdx.x - n_groups; + kv_scratch[n_nope + d] = turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + __syncthreads(); + float s = scores[r]; + for (uint32_t i = 0; i < ds_per_thread; i++) { + uint32_t d = i * blockDim.x + threadIdx.x; + if (d < head_dim) acc_d[i] += kv_scratch[d] * s; + } + __syncthreads(); + } + for (uint32_t c = 0; c < visible_comp; c++) { + float s = scores[raw_count + c]; + const float *kv = comp_kv + (uint64_t)c * head_dim; + for (uint32_t i = 0; i < ds_per_thread; i++) { + uint32_t d = i * blockDim.x + threadIdx.x; + if (d < head_dim) acc_d[i] += kv[d] * s; + } + } + for (uint32_t i = 0; i < ds_per_thread; i++) { + uint32_t d = i * blockDim.x + threadIdx.x; + if (d < head_dim) oh[d] = acc_d[i] / denom; + } + } +} + +__global__ static void attention_decode_mixed_kernel( + float *heads, + const float *sinks, + const float *q, + const float *raw_kv, + const float *comp_kv, + const float *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || h >= n_head) return; + const bool single_all = (n_tokens == 1u && ratio == 0u); + uint32_t qpos = pos0 + t; + uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t visible_comp = single_all ? n_comp : (n_comp ? (qpos + 1u) / ratio : 0u); + if (visible_comp > n_comp) visible_comp = n_comp; + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + __shared__ float scores[DS4_CUDA_ATTENTION_SCORE_CAP]; + __shared__ uint32_t raw_rows[256]; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + __shared__ uint32_t raw_count; + __shared__ uint32_t raw_first_idx; + float scale = rsqrtf((float)head_dim); + if (threadIdx.x == 0) { + raw_count = 0; + raw_first_idx = 0; + if (n_raw != 0) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (single_all) { + raw_count = n_raw > 256u ? 256u : n_raw; + } else if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0 && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + } + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + __syncthreads(); + uint32_t n_score = raw_count + visible_comp; + float local_max = sinks[h]; + if (visible_comp == 0 || n_tokens == 1u) { + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const float *kvrow = raw_kv + (uint64_t)raw_rows[r] * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + for (uint32_t c = threadIdx.x; c < visible_comp; c += blockDim.x) { + float add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; + float s = -INFINITY; + if (add > -1.0e20f) { + const float *kvrow = comp_kv + (uint64_t)c * head_dim; + float dot = 0.0f; + for (uint32_t d = 0; d < head_dim; d++) dot += qh[d] * kvrow[d]; + s = dot * scale + add; + } + scores[raw_count + c] = s; + local_max = fmaxf(local_max, s); + } + } else { + uint32_t qlane = threadIdx.x & 7u; + uint32_t qgroup = threadIdx.x >> 3u; + for (uint32_t row0 = 0; row0 < n_score; row0 += 32u) { + uint32_t row = row0 + qgroup; + if (row < n_score) { + float add = 0.0f; + const float *kvrow = NULL; + if (row < raw_count) { + kvrow = raw_kv + (uint64_t)raw_rows[row] * head_dim; + } else { + uint32_t c = row - raw_count; + add = use_comp_mask ? comp_mask[(uint64_t)t * n_comp + c] : 0.0f; + if (add > -1.0e20f) kvrow = comp_kv + (uint64_t)c * head_dim; + } + float s = -INFINITY; + if (kvrow) { + float dot = 0.0f; + for (uint32_t d = qlane; d < head_dim; d += 8u) dot += qh[d] * kvrow[d]; + const uint32_t mask = 0xffu << (threadIdx.x & 24u); + for (uint32_t off = 4u; off > 0u; off >>= 1u) { + dot += __shfl_down_sync(mask, dot, off, 8); + } + s = dot * scale + add; + } + if (qlane == 0) scores[row] = s; + } + } + __syncthreads(); + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + local_max = fmaxf(local_max, scores[i]); + } + } + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + float den_local = 0.0f; + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + scores[i] = expf(scores[i] - max_s); + den_local += scores[i]; + } + partial[threadIdx.x] = den_local; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); + __syncthreads(); + float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; + if (head_dim == 512u && blockDim.x == 256u) { + uint32_t d0 = threadIdx.x; + uint32_t d1 = d0 + 256u; + float acc0 = 0.0f; + float acc1 = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) { + float s = scores[r]; + const float *kv = raw_kv + (uint64_t)raw_rows[r] * head_dim; + acc0 += kv[d0] * s; + acc1 += kv[d1] * s; + } + for (uint32_t c = 0; c < visible_comp; c++) { + float s = scores[raw_count + c]; + const float *kv = comp_kv + (uint64_t)c * head_dim; + acc0 += kv[d0] * s; + acc1 += kv[d1] * s; + } + oh[d0] = acc0 / denom; + oh[d1] = acc1 / denom; + } else { + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) { float acc = 0.0f; for (uint32_t r = 0; r < raw_count; r++) acc += raw_kv[(uint64_t)raw_rows[r] * head_dim + d] * scores[r]; for (uint32_t c = 0; c < visible_comp; c++) acc += comp_kv[(uint64_t)c * head_dim + d] * scores[raw_count + c]; @@ -3031,6 +3928,277 @@ __global__ static void attention_decode_mixed_kernel( } } +// Inline-dequant turbo3 sibling of attention_indexed_mixed_kernel. Targets +// the n_tokens=1 decode-token hot path with comp_count > 0 (post-indexer +// selection). +// +// K-dot 8-lane partition is a clean fit for turbo3: 8 threads * 64 elements +// per row = exactly 7 groups (turbo3 dequant) + 1 RoPE tail (64 floats). +// Each thread owns one slice = one group OR the RoPE tail, all kept in +// registers, then shfl-reduce. No shared dequant scratch needed for K-dot. +// +// V-acc reuses the cooperative-shmem pattern from +// attention_decode_mixed_turbo3_kernel. +// +// Metal portability: K-dot 8-lane uses __shfl_down_sync which has a +// direct SIMD-permute analogue on Metal (simd_shuffle_down). +__global__ static void attention_indexed_mixed_turbo3_kernel( + float *heads, + const float *sinks, + const float *q, + const unsigned char *raw_kv_bytes, + uint64_t row_bytes, + const float *comp_kv, + const int32_t *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + int signs_on) { + uint32_t t = blockIdx.x; + uint32_t h = blockIdx.y; + if (t >= n_tokens || h >= n_head) return; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const uint64_t scale_bytes = (uint64_t)n_groups; + uint32_t qpos = pos0 + t; + uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t visible_comp = n_comp; + if (ratio != 0) { + visible_comp = (qpos + 1u) / ratio; + if (visible_comp > n_comp) visible_comp = n_comp; + } + const float *qh = q + ((uint64_t)t * n_head + h) * head_dim; + __shared__ float scores[768]; + __shared__ uint32_t raw_rows[256]; + __shared__ uint32_t comp_rows[512]; + __shared__ float partial[256]; + __shared__ float max_s; + __shared__ float denom; + __shared__ uint32_t raw_count; + __shared__ uint32_t raw_first_idx; + __shared__ uint32_t comp_count; + float scale = rsqrtf((float)head_dim); + if (threadIdx.x == 0) { + raw_count = 0; + raw_first_idx = 0; + comp_count = 0; + if (n_raw != 0) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0 && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + } + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + for (uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { + int32_t c = topk[(uint64_t)t * top_k + i]; + if (c >= 0 && (uint32_t)c < visible_comp) { + uint32_t slot = atomicAdd(&comp_count, 1u); + if (slot < 512u) comp_rows[slot] = (uint32_t)c; + } + } + __syncthreads(); + if (threadIdx.x == 0) { + if (comp_count > 512u) comp_count = 512u; + } + __syncthreads(); + uint32_t n_score = raw_count + comp_count; + float local_max = sinks[h]; + + if (comp_count == 0) { + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[r] * row_bytes; + float dot = 0.0f; + float group[64]; + for (uint32_t g = 0; g < n_groups; g++) { + turbo3_dequant_group64_device(group, kv_bytes, g, n_nope, signs_on); + #pragma unroll + for (uint32_t i = 0; i < 64; i++) { + dot += qh[g * 64 + i] * group[i]; + } + } + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + for (uint32_t d = 0; d < n_rot; d++) { + dot += qh[n_nope + d] * turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + scores[r] = dot * scale; + local_max = fmaxf(local_max, scores[r]); + } + } else { + // 8-lane K-dot. Each warp split into 4 row-groups (qgroup), each + // group of 8 threads shares one row. For raw rows, each thread + // owns either one turbo3 group (qlane < n_groups) or the RoPE + // tail (qlane == n_groups, n_rot=64). For comp rows, traditional + // stride-of-8 float read. + uint32_t qlane = threadIdx.x & 7u; + uint32_t qgroup = threadIdx.x >> 3u; + for (uint32_t row0 = 0; row0 < n_score; row0 += 32u) { + uint32_t row = row0 + qgroup; + if (row < n_score) { + float dot = 0.0f; + if (row < raw_count) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[row] * row_bytes; + if (qlane < n_groups) { + float group[64]; + turbo3_dequant_group64_device(group, kv_bytes, qlane, n_nope, signs_on); + const float *qh_slice = qh + (uint64_t)qlane * 64u; + #pragma unroll + for (uint32_t i = 0; i < 64; i++) { + dot += qh_slice[i] * group[i]; + } + } else if (qlane == n_groups && n_rot == 64u) { + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + const float *qh_slice = qh + n_nope; + for (uint32_t d = 0; d < n_rot; d++) { + dot += qh_slice[d] * turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + } + } else { + uint32_t c = row - raw_count; + const float *kvrow = comp_kv + (uint64_t)comp_rows[c] * head_dim; + for (uint32_t d = qlane; d < head_dim; d += 8u) { + dot += qh[d] * kvrow[d]; + } + } + const uint32_t mask = 0xffu << (threadIdx.x & 24u); + for (uint32_t off = 4u; off > 0u; off >>= 1u) { + dot += __shfl_down_sync(mask, dot, off, 8); + } + if (qlane == 0) scores[row] = dot * scale; + } + } + __syncthreads(); + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + local_max = fmaxf(local_max, scores[i]); + } + } + + partial[threadIdx.x] = local_max; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] = fmaxf(partial[threadIdx.x], partial[threadIdx.x + stride]); + __syncthreads(); + } + if (threadIdx.x == 0) max_s = partial[0]; + __syncthreads(); + float den_local = 0.0f; + for (uint32_t i = threadIdx.x; i < n_score; i += blockDim.x) { + scores[i] = expf(scores[i] - max_s); + den_local += scores[i]; + } + partial[threadIdx.x] = den_local; + __syncthreads(); + for (uint32_t stride = blockDim.x >> 1; stride > 0; stride >>= 1) { + if (threadIdx.x < stride) partial[threadIdx.x] += partial[threadIdx.x + stride]; + __syncthreads(); + } + if (threadIdx.x == 0) denom = partial[0] + expf(sinks[h] - max_s); + __syncthreads(); + + // V-acc: tile-batched cooperative dequant + N-row V-acc per sync. + // Same pattern as attention_decode_mixed_turbo3_kernel - see that + // kernel's preamble for the ROWS_PER_TILE=16 rationale. + float *oh = heads + ((uint64_t)t * n_head + h) * head_dim; + if (head_dim == 512u && blockDim.x == 256u) { + constexpr uint32_t ROWS_PER_TILE = 16u; + __shared__ float kv_tile[ROWS_PER_TILE * 512u]; + uint32_t d0 = threadIdx.x; + uint32_t d1 = d0 + 256u; + float acc0 = 0.0f; + float acc1 = 0.0f; + for (uint32_t r_base = 0; r_base < raw_count; r_base += ROWS_PER_TILE) { + uint32_t tile_rows = raw_count - r_base; + if (tile_rows > ROWS_PER_TILE) tile_rows = ROWS_PER_TILE; + uint32_t total_groups = tile_rows * n_groups; + if (threadIdx.x < total_groups) { + uint32_t tr = threadIdx.x / n_groups; + uint32_t g = threadIdx.x % n_groups; + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[r_base + tr] * row_bytes; + float buf[64]; + turbo3_dequant_group64_device(buf, kv_bytes, g, n_nope, signs_on); + float *gd = kv_tile + (uint64_t)tr * 512u + (uint64_t)g * TURBO3_GROUP_SIZE; + #pragma unroll + for (uint32_t i = 0; i < 64; i++) gd[i] = buf[i]; + } + uint32_t total_rope = tile_rows * n_rot; + for (uint32_t idx = threadIdx.x; idx < total_rope; idx += blockDim.x) { + uint32_t tr = idx / n_rot; + uint32_t d = idx % n_rot; + const unsigned char *rope_tail = raw_kv_bytes + + (uint64_t)raw_rows[r_base + tr] * row_bytes + + data_bytes + scale_bytes; + kv_tile[(uint64_t)tr * 512u + n_nope + d] = + turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + __syncthreads(); + #pragma unroll 4 + for (uint32_t i = 0; i < ROWS_PER_TILE; i++) { + if (i < tile_rows) { + float s = scores[r_base + i]; + acc0 += kv_tile[(uint64_t)i * 512u + d0] * s; + acc1 += kv_tile[(uint64_t)i * 512u + d1] * s; + } + } + __syncthreads(); + } + for (uint32_t c = 0; c < comp_count; c++) { + float s = scores[raw_count + c]; + const float *kv = comp_kv + (uint64_t)comp_rows[c] * head_dim; + acc0 += kv[d0] * s; + acc1 += kv[d1] * s; + } + oh[d0] = acc0 / denom; + oh[d1] = acc1 / denom; + } else { + for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) { + // Slow generic path retains the original per-thread per-d loop + // structure but with inline dequant per row. Each thread + // re-dequants the group containing its d for every row - wasted + // work; the fast path above is the optimized one. Kept for + // shape coverage (non-(512,256) launches if any). + float acc = 0.0f; + for (uint32_t r = 0; r < raw_count; r++) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[r] * row_bytes; + float v; + if (d < n_nope) { + float buf[64]; + turbo3_dequant_group64_device(buf, kv_bytes, d / 64u, n_nope, signs_on); + v = buf[d & 63u]; + } else { + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + v = turbo3_load_unaligned_f32(rope_tail + (d - n_nope) * sizeof(float)); + } + acc += v * scores[r]; + } + for (uint32_t s = 0; s < comp_count; s++) acc += comp_kv[(uint64_t)comp_rows[s] * head_dim + d] * scores[raw_count + s]; + oh[d] = acc / denom; + } + } +} + __global__ static void attention_indexed_mixed_kernel( float *heads, const float *sinks, @@ -3366,6 +4534,213 @@ __global__ static void attention_indexed_mixed_heads8_rb4_kernel( } } +// Inline-dequant turbo3 sibling of +// attention_indexed_mixed_heads8_online_kernel. Same template + tile +// structure as the fp8 version (ROWS_PER_STAGE rows per tile, +// HEADS_PER_GROUP warps per CTA). Cooperative populate of kv_shared +// rewritten to inline turbo3 dequant on raw rows + float4 load on +// comp rows (selected by topk). +template +__global__ static void attention_indexed_mixed_heads8_online_turbo3_kernel( + float *heads, + const float *sinks, + const float *q, + const unsigned char *raw_kv_bytes, + uint64_t row_bytes, + const float *comp_kv, + const int32_t *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + int signs_on) { + uint32_t t = blockIdx.x; + uint32_t head_group = blockIdx.y; + if (t >= n_tokens || head_dim != 512u) return; + const uint32_t lane = threadIdx.x & 31u; + const uint32_t warp = threadIdx.x >> 5u; + const uint32_t head = head_group * HEADS_PER_GROUP + warp; + const bool valid_head = head < n_head; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const uint64_t scale_bytes = (uint64_t)n_groups; + + __shared__ uint32_t raw_rows[256]; + __shared__ uint32_t raw_count; + __shared__ uint32_t raw_first_idx; + __shared__ float4 kv_shared[ROWS_PER_STAGE * 128]; + + uint32_t qpos = pos0 + t; + uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t visible_comp = n_comp; + if (ratio != 0) { + visible_comp = (qpos + 1u) / ratio; + if (visible_comp > n_comp) visible_comp = n_comp; + } + + if (threadIdx.x == 0) { + raw_count = 0; + raw_first_idx = 0; + if (n_raw != 0) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0 && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + } + __syncthreads(); + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + __syncthreads(); + + uint32_t comp_count = top_k < visible_comp ? top_k : visible_comp; + if (comp_count > 512u) comp_count = 512u; + const uint32_t n_score = raw_count + comp_count; + const float scale = rsqrtf((float)head_dim); + const float4 *q4 = valid_head + ? (const float4 *)(q + ((uint64_t)t * n_head + head) * head_dim) + : NULL; + float4 q0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 q1 = q0, q2 = q0, q3 = q0; + if (valid_head) { + q0 = q4[lane + 0u]; + q1 = q4[lane + 32u]; + q2 = q4[lane + 64u]; + q3 = q4[lane + 96u]; + } + + float max_s = -INFINITY; + float sum_s = 0.0f; + float4 o0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 o1 = o0, o2 = o0, o3 = o0; + + for (uint32_t row0 = 0; row0 < n_score; row0 += ROWS_PER_STAGE) { + const uint32_t nr = n_score - row0 < ROWS_PER_STAGE ? n_score - row0 : ROWS_PER_STAGE; + + // Phase A: raw rows, group dequants. + const uint32_t total_groups = nr * n_groups; + if (threadIdx.x < total_groups) { + uint32_t r_in_tile = threadIdx.x / n_groups; + uint32_t g = threadIdx.x % n_groups; + uint32_t sr = row0 + r_in_tile; + if (sr < raw_count) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[sr] * row_bytes; + float buf[64]; + turbo3_dequant_group64_device(buf, kv_bytes, g, n_nope, signs_on); + float *dst = ((float *)(kv_shared + r_in_tile * 128u)) + g * TURBO3_GROUP_SIZE; + #pragma unroll + for (uint32_t i = 0; i < 64; i++) dst[i] = buf[i]; + } + } + // Phase B: raw rows, RoPE tail. + const uint32_t total_rope = nr * n_rot; + for (uint32_t idx = threadIdx.x; idx < total_rope; idx += blockDim.x) { + uint32_t r_in_tile = idx / n_rot; + uint32_t d = idx % n_rot; + uint32_t sr = row0 + r_in_tile; + if (sr < raw_count) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[sr] * row_bytes; + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + ((float *)(kv_shared + r_in_tile * 128u))[n_nope + d] = + turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + } + // Phase C: comp rows, float4 stride load via topk index. + for (uint32_t off = threadIdx.x; off < nr * 128u; off += blockDim.x) { + const uint32_t rr = off >> 7u; + const uint32_t c4 = off & 127u; + const uint32_t sr = row0 + rr; + if (sr >= raw_count && sr < n_score) { + const uint32_t comp_idx = (uint32_t)topk[(uint64_t)t * top_k + (sr - raw_count)]; + const float4 *src = (const float4 *)(comp_kv + (uint64_t)comp_idx * head_dim); + kv_shared[off] = src[c4]; + } + } + __syncthreads(); + if (valid_head) { + for (uint32_t rr = 0; rr < nr; rr++) { + const float4 *kv4 = kv_shared + rr * 128u; + float4 k0 = kv4[lane + 0u]; + float4 k1 = kv4[lane + 32u]; + float4 k2 = kv4[lane + 64u]; + float4 k3 = kv4[lane + 96u]; + float score = dot4_f32(q0, k0) + + dot4_f32(q1, k1) + + dot4_f32(q2, k2) + + dot4_f32(q3, k3); + score = warp_sum_f32(score) * scale; + score = __shfl_sync(0xffffffffu, score, 0); + + const float new_m = fmaxf(max_s, score); + const float old_scale = expf(max_s - new_m); + const float row_scale = expf(score - new_m); + sum_s = sum_s * old_scale + row_scale; + o0.x = o0.x * old_scale + k0.x * row_scale; + o0.y = o0.y * old_scale + k0.y * row_scale; + o0.z = o0.z * old_scale + k0.z * row_scale; + o0.w = o0.w * old_scale + k0.w * row_scale; + o1.x = o1.x * old_scale + k1.x * row_scale; + o1.y = o1.y * old_scale + k1.y * row_scale; + o1.z = o1.z * old_scale + k1.z * row_scale; + o1.w = o1.w * old_scale + k1.w * row_scale; + o2.x = o2.x * old_scale + k2.x * row_scale; + o2.y = o2.y * old_scale + k2.y * row_scale; + o2.z = o2.z * old_scale + k2.z * row_scale; + o2.w = o2.w * old_scale + k2.w * row_scale; + o3.x = o3.x * old_scale + k3.x * row_scale; + o3.y = o3.y * old_scale + k3.y * row_scale; + o3.z = o3.z * old_scale + k3.z * row_scale; + o3.w = o3.w * old_scale + k3.w * row_scale; + max_s = new_m; + } + } + __syncthreads(); + } + + if (valid_head) { + const float sink = sinks[head]; + const float new_m = fmaxf(max_s, sink); + const float old_scale = expf(max_s - new_m); + const float sink_scale = expf(sink - new_m); + sum_s = sum_s * old_scale + sink_scale; + o0.x *= old_scale; o0.y *= old_scale; o0.z *= old_scale; o0.w *= old_scale; + o1.x *= old_scale; o1.y *= old_scale; o1.z *= old_scale; o1.w *= old_scale; + o2.x *= old_scale; o2.y *= old_scale; o2.z *= old_scale; o2.w *= old_scale; + o3.x *= old_scale; o3.y *= old_scale; o3.z *= old_scale; o3.w *= old_scale; + + const float inv_s = sum_s == 0.0f ? 0.0f : 1.0f / sum_s; + o0.x *= inv_s; o0.y *= inv_s; o0.z *= inv_s; o0.w *= inv_s; + o1.x *= inv_s; o1.y *= inv_s; o1.z *= inv_s; o1.w *= inv_s; + o2.x *= inv_s; o2.y *= inv_s; o2.z *= inv_s; o2.w *= inv_s; + o3.x *= inv_s; o3.y *= inv_s; o3.z *= inv_s; o3.w *= inv_s; + float4 *out4 = (float4 *)(heads + ((uint64_t)t * n_head + head) * head_dim); + out4[lane + 0u] = o0; + out4[lane + 32u] = o1; + out4[lane + 64u] = o2; + out4[lane + 96u] = o3; + } +} + template __global__ static void attention_indexed_mixed_heads8_online_kernel( float *heads, @@ -3656,6 +5031,221 @@ __global__ static void attention_static_mixed_heads8_online_kernel( } } +// Inline-dequant turbo3 sibling of +// attention_decode_mixed_heads8_online_kernel. FlashAttention-style online +// softmax: tile of TILE_M=4 rows cooperatively loaded into kv_shared, then +// each of 8 warps owns one head's per-row K-dot + V-acc update. +// +// Cooperative load phase rewritten to inline-dequant turbo3 packed rows +// directly into kv_shared. Comp rows stay on the float4 load path +// (compressed cache is not turbo3-quantized). +__global__ static void attention_decode_mixed_heads8_online_turbo3_kernel( + float *heads, + const float *sinks, + const float *q, + const unsigned char *raw_kv_bytes, + uint64_t row_bytes, + const float *comp_kv, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot, + int signs_on) { + uint32_t t = blockIdx.x; + uint32_t head_group = blockIdx.y; + if (t >= n_tokens || head_dim != 512u) return; + const uint32_t lane = threadIdx.x & 31u; + const uint32_t warp = threadIdx.x >> 5u; + const uint32_t head = head_group * 8u + warp; + const bool valid_head = head < n_head; + const uint32_t n_nope = head_dim - n_rot; + const uint32_t n_groups = n_nope / TURBO3_GROUP_SIZE; + const uint64_t data_bytes = (uint64_t)n_nope * 3u / 8u; + const uint64_t scale_bytes = (uint64_t)n_groups; + + __shared__ uint32_t raw_rows[256]; + __shared__ uint32_t raw_count_s; + __shared__ uint32_t raw_first_idx_s; + // TILE_M=16 (vs fp8's 4): 4x fewer __syncthreads + 4x better dequant + // thread utilization (16*7=112 of 256 threads vs 4*7=28). Tile shmem + // 16*128 float4 = 32KB; total CTA shmem ~34KB, within the 48KB cap. + __shared__ float4 kv_shared[16 * 128]; + + const uint32_t qpos = pos0 + t; + const uint32_t first_raw_pos = pos0 + n_tokens - n_raw; + uint32_t comp_count = 0; + if (n_comp != 0u) { + if (n_tokens == 1u && ratio == 0u) { + comp_count = n_comp; + } else if (ratio != 0u) { + comp_count = (qpos + 1u) / ratio; + if (comp_count > n_comp) comp_count = n_comp; + } + } + if (threadIdx.x == 0) { + uint32_t raw_count = 0; + uint32_t raw_first_idx = 0; + if (n_raw != 0u) { + const uint32_t raw_last_pos = first_raw_pos + n_raw - 1u; + if (qpos >= first_raw_pos) { + uint32_t lo = first_raw_pos; + if (window != 0u && qpos + 1u > window) { + const uint32_t wlo = qpos + 1u - window; + if (wlo > lo) lo = wlo; + } + const uint32_t hi = qpos < raw_last_pos ? qpos : raw_last_pos; + if (hi >= lo) { + raw_first_idx = lo - first_raw_pos; + raw_count = hi - lo + 1u; + if (raw_count > 256u) raw_count = 256u; + } + } + } + raw_count_s = raw_count; + raw_first_idx_s = raw_first_idx; + } + __syncthreads(); + const uint32_t raw_count = raw_count_s; + const uint32_t raw_first_idx = raw_first_idx_s; + for (uint32_t r = threadIdx.x; r < raw_count; r += blockDim.x) { + raw_rows[r] = (raw_start + raw_first_idx + r) % raw_cap; + } + __syncthreads(); + + const uint32_t n_score = raw_count + comp_count; + const float scale = rsqrtf((float)head_dim); + const float4 *q4 = valid_head + ? (const float4 *)(q + ((uint64_t)t * n_head + head) * head_dim) + : NULL; + float4 q0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 q1 = q0, q2 = q0, q3 = q0; + if (valid_head) { + q0 = q4[lane + 0u]; + q1 = q4[lane + 32u]; + q2 = q4[lane + 64u]; + q3 = q4[lane + 96u]; + } + + float max_s = -INFINITY; + float sum_s = 0.0f; + float4 o0 = make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float4 o1 = o0, o2 = o0, o3 = o0; + + constexpr uint32_t TILE_M = 16u; + for (uint32_t row0 = 0; row0 < n_score; row0 += TILE_M) { + const uint32_t nr = n_score - row0 < TILE_M ? n_score - row0 : TILE_M; + + // Cooperative populate of nr rows into kv_shared. + // Phase A: raw rows, group dequants (nr*n_groups tasks, up to 28). + const uint32_t total_groups = nr * n_groups; + if (threadIdx.x < total_groups) { + uint32_t r_in_tile = threadIdx.x / n_groups; + uint32_t g = threadIdx.x % n_groups; + uint32_t sr = row0 + r_in_tile; + if (sr < raw_count) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[sr] * row_bytes; + float buf[64]; + turbo3_dequant_group64_device(buf, kv_bytes, g, n_nope, signs_on); + float *dst = ((float *)(kv_shared + r_in_tile * 128u)) + g * TURBO3_GROUP_SIZE; + #pragma unroll + for (uint32_t i = 0; i < 64; i++) dst[i] = buf[i]; + } + } + // Phase B: raw rows, RoPE tail (nr*n_rot floats, up to 256). + const uint32_t total_rope = nr * n_rot; + for (uint32_t idx = threadIdx.x; idx < total_rope; idx += blockDim.x) { + uint32_t r_in_tile = idx / n_rot; + uint32_t d = idx % n_rot; + uint32_t sr = row0 + r_in_tile; + if (sr < raw_count) { + const unsigned char *kv_bytes = raw_kv_bytes + (uint64_t)raw_rows[sr] * row_bytes; + const unsigned char *rope_tail = kv_bytes + data_bytes + scale_bytes; + ((float *)(kv_shared + r_in_tile * 128u))[n_nope + d] = + turbo3_load_unaligned_f32(rope_tail + d * sizeof(float)); + } + } + // Phase C: comp rows, float4 stride load for sr >= raw_count. + for (uint32_t off = threadIdx.x; off < nr * 128u; off += blockDim.x) { + const uint32_t rr = off >> 7u; + const uint32_t c4 = off & 127u; + const uint32_t sr = row0 + rr; + if (sr >= raw_count && sr < n_score) { + const float4 *src = (const float4 *)(comp_kv + (uint64_t)(sr - raw_count) * head_dim); + kv_shared[off] = src[c4]; + } + } + __syncthreads(); + if (valid_head) { + for (uint32_t rr = 0; rr < nr; rr++) { + const float4 *kv4 = kv_shared + rr * 128u; + float4 k0 = kv4[lane + 0u]; + float4 k1 = kv4[lane + 32u]; + float4 k2 = kv4[lane + 64u]; + float4 k3 = kv4[lane + 96u]; + float score = dot4_f32(q0, k0) + + dot4_f32(q1, k1) + + dot4_f32(q2, k2) + + dot4_f32(q3, k3); + score = warp_sum_f32(score) * scale; + score = __shfl_sync(0xffffffffu, score, 0); + + const float new_m = fmaxf(max_s, score); + const float old_scale = expf(max_s - new_m); + const float row_scale = expf(score - new_m); + sum_s = sum_s * old_scale + row_scale; + o0.x = o0.x * old_scale + k0.x * row_scale; + o0.y = o0.y * old_scale + k0.y * row_scale; + o0.z = o0.z * old_scale + k0.z * row_scale; + o0.w = o0.w * old_scale + k0.w * row_scale; + o1.x = o1.x * old_scale + k1.x * row_scale; + o1.y = o1.y * old_scale + k1.y * row_scale; + o1.z = o1.z * old_scale + k1.z * row_scale; + o1.w = o1.w * old_scale + k1.w * row_scale; + o2.x = o2.x * old_scale + k2.x * row_scale; + o2.y = o2.y * old_scale + k2.y * row_scale; + o2.z = o2.z * old_scale + k2.z * row_scale; + o2.w = o2.w * old_scale + k2.w * row_scale; + o3.x = o3.x * old_scale + k3.x * row_scale; + o3.y = o3.y * old_scale + k3.y * row_scale; + o3.z = o3.z * old_scale + k3.z * row_scale; + o3.w = o3.w * old_scale + k3.w * row_scale; + max_s = new_m; + } + } + __syncthreads(); + } + + if (valid_head) { + const float sink = sinks[head]; + const float new_m = fmaxf(max_s, sink); + const float old_scale = expf(max_s - new_m); + const float sink_scale = expf(sink - new_m); + sum_s = sum_s * old_scale + sink_scale; + o0.x *= old_scale; o0.y *= old_scale; o0.z *= old_scale; o0.w *= old_scale; + o1.x *= old_scale; o1.y *= old_scale; o1.z *= old_scale; o1.w *= old_scale; + o2.x *= old_scale; o2.y *= old_scale; o2.z *= old_scale; o2.w *= old_scale; + o3.x *= old_scale; o3.y *= old_scale; o3.z *= old_scale; o3.w *= old_scale; + + const float inv_s = sum_s == 0.0f ? 0.0f : 1.0f / sum_s; + o0.x *= inv_s; o0.y *= inv_s; o0.z *= inv_s; o0.w *= inv_s; + o1.x *= inv_s; o1.y *= inv_s; o1.z *= inv_s; o1.w *= inv_s; + o2.x *= inv_s; o2.y *= inv_s; o2.z *= inv_s; o2.w *= inv_s; + o3.x *= inv_s; o3.y *= inv_s; o3.z *= inv_s; o3.w *= inv_s; + float4 *out4 = (float4 *)(heads + ((uint64_t)t * n_head + head) * head_dim); + out4[lane + 0u] = o0; + out4[lane + 32u] = o1; + out4[lane + 64u] = o2; + out4[lane + 96u] = o3; + } +} + __global__ static void attention_decode_mixed_heads8_online_kernel( float *heads, const float *sinks, @@ -6347,6 +7937,102 @@ extern "C" int ds4_gpu_dsv4_fp8_kv_quantize_tensor(ds4_gpu_tensor *x, uint32_t n fp8_kv_quantize_kernel<<>>((float *)x->ptr, n_tok, head_dim, n_rot); return cuda_ok(cudaGetLastError(), "fp8_kv_quantize launch"); } + +// Cached env query for the diagnostic no-signs switch. Mirrors the CPU +// helper in ds4.c; one strcmp on first call, value frozen for the process. +static int ds4_turbo_signs_enabled_dev(void) { + static int cached = -1; + if (cached < 0) { + const char *s = getenv("DS4_TURBO_NO_SIGNS"); + cached = (s && s[0] && !(s[0] == '0' && s[1] == 0)) ? 0 : 1; + } + return cached; +} + +extern "C" int ds4_gpu_dsv4_turbo3_kv_quantize_tensor(ds4_gpu_tensor *x, uint32_t n_tok, uint32_t head_dim, uint32_t n_rot) { + if (!x || n_rot > head_dim || x->bytes < (uint64_t)n_tok * head_dim * sizeof(float)) return 0; + if (n_tok == 0) return 1; + // Grid mirrors fp8_kv_quantize_kernel: one block per token, 64 threads. + // turbo3_kv_quantize_kernel walks the n_nope dimension internally in + // groups of 64, applying the WHT + signs + Lloyd-Max + matched-norm round + // trip on each. RoPE tail untouched. + turbo3_kv_quantize_kernel<<>>((float *)x->ptr, n_tok, head_dim, n_rot, ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "turbo3_kv_quantize launch"); +} + +/* Packed-write entry point. `src` is the float KV input tensor + * ([n_tok, head_dim]); `dst` is the packed-byte cache region ([n_tok * + * dst_row_bytes]). Caller is responsible for having computed dst_row_bytes + * via ds4_kv_row_bytes(head_dim, n_rot, DS4_KV_TURBO3). */ +extern "C" int ds4_gpu_dsv4_turbo3_kv_pack_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_tok, + uint32_t head_dim, + uint32_t n_rot, + uint64_t dst_row_bytes) { + if (!src || !dst || n_rot > head_dim) return 0; + if (n_tok == 0) return 1; + if (src->bytes < (uint64_t)n_tok * head_dim * sizeof(float)) return 0; + if (dst->bytes < (uint64_t)n_tok * dst_row_bytes) return 0; + turbo3_kv_pack_kernel<<>>( + (const float *)src->ptr, + (unsigned char *)dst->ptr, + n_tok, head_dim, n_rot, dst_row_bytes, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "turbo3_kv_pack launch"); +} + +/* Ring-aware batched pack entry point. Mirrors + * ds4_gpu_store_raw_kv_batch_tensor but writes packed turbo3 bytes per row. + * `raw_cap` is the SWA ring capacity; `pos0` is the logical start; `n_tokens` + * rows are packed into ring slots `(pos0 + t) % raw_cap`. */ +extern "C" int ds4_gpu_dsv4_turbo3_kv_pack_batch_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *raw, + uint32_t raw_cap, + uint32_t pos0, + uint32_t n_tokens, + uint32_t head_dim, + uint32_t n_rot, + uint64_t row_bytes) { + if (!src || !raw || raw_cap == 0 || n_rot > head_dim) return 0; + if (n_tokens == 0) return 1; + if (src->bytes < (uint64_t)n_tokens * head_dim * sizeof(float)) return 0; + if (raw->bytes < (uint64_t)raw_cap * row_bytes) return 0; + turbo3_kv_pack_batch_kernel<<>>( + (const float *)src->ptr, + (unsigned char *)raw->ptr, + raw_cap, pos0, n_tokens, head_dim, n_rot, row_bytes, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "turbo3_kv_pack_batch launch"); +} + +/* Dequant-to-scratch entry point. Reads `n_rows` packed turbo3 rows from + * `src` (each `src_row_bytes` long) and writes original-basis floats into + * `dst` at the `[n_rows, head_dim]` float layout the existing attention + * kernels expect. Used by the attention paths that have no inline-dequant + * sibling; the inline-dequant kernels skip this hop and read packed rows + * directly to capture the V-load bandwidth win. */ +extern "C" int ds4_gpu_dsv4_turbo3_kv_dequant_to_scratch_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_rows, + uint32_t head_dim, + uint32_t n_rot, + uint64_t src_row_bytes) { + if (!src || !dst || n_rot > head_dim) return 0; + if (n_rows == 0) return 1; + if (src->bytes < (uint64_t)n_rows * src_row_bytes) return 0; + if (dst->bytes < (uint64_t)n_rows * head_dim * sizeof(float)) return 0; + turbo3_kv_dequant_to_scratch_kernel<<>>( + (const unsigned char *)src->ptr, + (float *)dst->ptr, + n_rows, head_dim, n_rot, src_row_bytes, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "turbo3_kv_dequant_to_scratch launch"); +} + extern "C" int ds4_gpu_dsv4_indexer_qat_tensor(ds4_gpu_tensor *x, uint32_t n_rows, uint32_t head_dim) { if (!x || n_rows == 0 || head_dim != 128u || x->bytes < (uint64_t)n_rows * head_dim * sizeof(float)) { @@ -6372,6 +8058,17 @@ extern "C" int ds4_gpu_kv_fp8_store_raw_tensor( return ds4_gpu_dsv4_fp8_kv_quantize_tensor(kv, 1, head_dim, n_rot) && ds4_gpu_store_raw_kv_tensor(raw_cache, kv, raw_cap, raw_row, head_dim); } + +extern "C" int ds4_gpu_kv_turbo3_store_raw_tensor( + ds4_gpu_tensor *kv, + ds4_gpu_tensor *raw_cache, + uint32_t raw_cap, + uint32_t raw_row, + uint32_t head_dim, + uint32_t n_rot) { + return ds4_gpu_dsv4_turbo3_kv_quantize_tensor(kv, 1, head_dim, n_rot) && + ds4_gpu_store_raw_kv_tensor(raw_cache, kv, raw_cap, raw_row, head_dim); +} extern "C" int ds4_gpu_store_raw_kv_tensor(ds4_gpu_tensor *raw_cache, const ds4_gpu_tensor *kv, uint32_t raw_cap, uint32_t row, uint32_t head_dim) { if (!raw_cache || !kv || raw_cap == 0 || raw_cache->bytes < (uint64_t)raw_cap * head_dim * sizeof(float) || @@ -6830,6 +8527,71 @@ extern "C" int ds4_gpu_attention_decode_heads_tensor( 0, 0, n_head, head_dim); return cuda_ok(cudaGetLastError(), "attention decode launch"); } + +/* Turbo3-packed sibling of ds4_gpu_attention_decode_heads_tensor. Reads + * the packed turbo3 raw cache directly via attention_decode_mixed_turbo3_kernel + * - skips the dequant-to-scratch hop on the decode-token call site + * (metal_graph_decode_layer). Falls back via return-0 when conditions + * aren't met (caller should retry the float path). */ +extern "C" int ds4_gpu_attention_decode_heads_turbo3_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + uint32_t n_comp, + const ds4_gpu_tensor *comp_mask, + uint32_t use_mask, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + if (comp_kv_f16 || + !heads || !q || !raw_kv_bytes || !model_map || n_raw == 0 || raw_cap < n_raw || + raw_start >= raw_cap || (n_comp != 0 && !comp_kv) || (use_mask && !comp_mask) || + n_rot > head_dim || + sinks_offset > model_size || + (uint64_t)n_head * sizeof(float) > model_size - sinks_offset || + heads->bytes < (uint64_t)n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_head * head_dim * sizeof(float) || + raw_kv_bytes->bytes < (uint64_t)raw_cap * row_bytes || + (n_comp && comp_kv->bytes < (uint64_t)n_comp * head_dim * sizeof(float)) || + (use_mask && comp_mask->bytes < (uint64_t)n_comp * sizeof(float))) { + return 0; + } + /* Score buffer fit + simple-path predicate. Decode-token always + * runs simple (n_tokens=1); window/online fall through to the caller. + * + * The turbo3 kernel uses a smaller scores[2048] buffer (vs + * DS4_CUDA_ATTENTION_SCORE_CAP=8192) to leave shmem for the V-acc + * tile. Fall back to the fp8 path if n_comp + raw_count would + * overflow. */ + if (!cuda_attention_score_buffer_fits(n_comp)) return 0; + if (n_comp + 256u /* raw cap */ > 2048u) return 0; + const float *sinks = (const float *)cuda_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + dim3 grid(1, n_head, 1); + attention_decode_mixed_turbo3_kernel<<>>( + (float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const unsigned char *)raw_kv_bytes->ptr, + row_bytes, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv_bytes->ptr, + use_mask ? (const float *)comp_mask->ptr : NULL, + use_mask, + 1, 0, n_raw, raw_cap, raw_start, n_comp, + 0, 0, n_head, head_dim, n_rot, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "attention decode_turbo3 launch"); +} extern "C" int ds4_gpu_attention_prefill_raw_heads_tensor(ds4_gpu_tensor *heads, const void *model_map, uint64_t model_size, uint64_t sinks_offset, const ds4_gpu_tensor *q, const ds4_gpu_tensor *raw_kv, uint32_t n_tokens, uint32_t window, uint32_t n_head, uint32_t head_dim) { if (!heads || !q || !raw_kv || !model_map || sinks_offset > model_size || model_size - sinks_offset < (uint64_t)n_head * sizeof(float) || @@ -6929,6 +8691,45 @@ extern "C" int ds4_gpu_attention_prefill_raw_heads_tensor(ds4_gpu_tensor *heads, n_tokens, window, n_head, head_dim); return cuda_ok(cudaGetLastError(), "attention_prefill_raw launch"); } + +/* Turbo3-packed sibling of ds4_gpu_attention_prefill_raw_heads_tensor. + * Reads packed-byte cache directly via the inline-dequant kernel - skips + * the turbo3_kv_dequant_to_scratch_kernel hop. Window-attention and cublas + * fast paths fall back to the float-sim/dequant-to-scratch route. */ +extern "C" int ds4_gpu_attention_prefill_raw_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + uint32_t n_tokens, + uint32_t window, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + if (!heads || !q || !raw_kv_bytes || !model_map || sinks_offset > model_size || + model_size - sinks_offset < (uint64_t)n_head * sizeof(float) || + heads->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + raw_kv_bytes->bytes < (uint64_t)n_tokens * row_bytes || + window > 256 || n_rot > head_dim) return 0; + const float *sinks = (const float *)cuda_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + dim3 grid(n_tokens, n_head, 1); + attention_prefill_raw_turbo3_kernel<<>>( + (float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const unsigned char *)raw_kv_bytes->ptr, + row_bytes, + n_tokens, window, n_head, head_dim, n_rot, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "attention_prefill_raw_turbo3 launch"); +} + static int attention_decode_batch_launch( ds4_gpu_tensor *heads, const void *model_map, @@ -7073,6 +8874,100 @@ extern "C" int ds4_gpu_attention_decode_mixed_batch_heads_tensor( n_comp, window, ratio, n_head, head_dim); } +/* Turbo3-packed sibling of + * ds4_gpu_attention_decode_{raw,mixed}_batch_heads_tensor. + * + * Reads the packed turbo3 raw cache directly via the inline-dequant + * kernel (attention_decode_mixed_turbo3_kernel) - skips the + * dequant-to-scratch hop for the simple (n_tokens=1) decode-token case. + * + * Eligibility check (caller side): turbo3 mode AND n_tokens=1. In fp8 + * mode or for n_tokens>1 prefill chunks, callers should keep using the + * float-input launchers. The n_tokens>1 / window-attention path goes + * through the heads8_online turbo3 kernel branch below. */ +extern "C" int ds4_gpu_attention_decode_mixed_batch_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + const ds4_gpu_tensor *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + if (comp_kv_f16 || + !heads || !q || !raw_kv_bytes || !model_map || n_tokens == 0 || + raw_cap < n_raw || raw_start >= raw_cap || + (n_comp != 0 && !comp_kv) || (use_comp_mask && !comp_mask) || + n_rot > head_dim || + sinks_offset > model_size || + (uint64_t)n_head * sizeof(float) > model_size - sinks_offset || + heads->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + raw_kv_bytes->bytes < (uint64_t)raw_cap * row_bytes || + (n_comp && comp_kv->bytes < (uint64_t)n_comp * head_dim * sizeof(float)) || + (use_comp_mask && comp_mask->bytes < (uint64_t)n_tokens * n_comp * sizeof(float))) { + return 0; + } + if (n_comp != 0 && ratio == 0) return 0; + const float *sinks = (const float *)cuda_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + + /* n_tokens > 1 (prefill chunk) or use_comp_mask routes through the + * heads8_online turbo3 kernel. Falls back via return 0 if the + * window-mask shape isn't supported. Note the online kernel doesn't + * take comp_mask - caller should fall through to the float path when + * use_comp_mask is set. */ + if (use_comp_mask) return 0; + if (n_tokens > 1u || !cuda_attention_score_buffer_fits(n_comp) || + getenv("DS4_CUDA_FORCE_TURBO3_ONLINE") != NULL) { + if (head_dim != 512u) return 0; + dim3 online_grid(n_tokens, (n_head + 7u) / 8u, 1); + attention_decode_mixed_heads8_online_turbo3_kernel<<>>( + (float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const unsigned char *)raw_kv_bytes->ptr, + row_bytes, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv_bytes->ptr, + n_tokens, pos0, n_raw, raw_cap, raw_start, n_comp, + window, ratio, n_head, head_dim, n_rot, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "attention_decode_mixed_heads8_online_turbo3 launch"); + } + + /* n_tokens=1 simple-path decode-token. */ + /* Turbo3 kernel scores[2048] cap; fall back on overflow. */ + if (n_comp + 256u > 2048u) return 0; + dim3 grid(n_tokens, n_head, 1); + attention_decode_mixed_turbo3_kernel<<>>( + (float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const unsigned char *)raw_kv_bytes->ptr, + row_bytes, + n_comp ? (const float *)comp_kv->ptr : (const float *)raw_kv_bytes->ptr, + use_comp_mask ? (const float *)comp_mask->ptr : NULL, + use_comp_mask, n_tokens, pos0, n_raw, raw_cap, + raw_start, n_comp, window, ratio, n_head, head_dim, n_rot, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "attention_decode_mixed_turbo3 launch"); +} + extern "C" int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( ds4_gpu_tensor *heads, const void *model_map, @@ -7185,6 +9080,104 @@ extern "C" int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( return cuda_ok(cudaGetLastError(), "attention indexed mixed launch"); } +/* Turbo3-packed sibling of ds4_gpu_attention_indexed_mixed_batch_heads_tensor. + * Reads packed turbo3 raw cache directly via + * attention_indexed_mixed_turbo3_kernel. Returns 0 on any unsupported shape + * so the caller can fall back to the float path via view_dispatch. The + * prefill chunk path (n_tokens > 1, top_k <= 512) goes through the + * heads8_online turbo3 kernel branch below. */ +extern "C" int ds4_gpu_attention_indexed_mixed_batch_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + const ds4_gpu_tensor *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + if (comp_kv_f16 || + !heads || !q || !raw_kv_bytes || !comp_kv || !topk || !model_map || + n_tokens == 0 || n_raw == 0 || raw_cap < n_raw || raw_start >= raw_cap || + n_comp == 0 || top_k == 0 || + n_rot > head_dim || + sinks_offset > model_size || + (uint64_t)n_head * sizeof(float) > model_size - sinks_offset || + heads->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + q->bytes < (uint64_t)n_tokens * n_head * head_dim * sizeof(float) || + raw_kv_bytes->bytes < (uint64_t)raw_cap * row_bytes || + comp_kv->bytes < (uint64_t)n_comp * head_dim * sizeof(float) || + topk->bytes < (uint64_t)n_tokens * top_k * sizeof(int32_t)) { + return 0; + } + if (top_k > 512u) return 0; + if (head_dim != 512u) return 0; + const float *sinks = (const float *)cuda_model_range_ptr( + model_map, sinks_offset, (uint64_t)n_head * sizeof(float), "attn_sinks"); + if (!sinks) return 0; + const int32_t *topk_ptr = (const int32_t *)topk->ptr; + + /* Prefill chunk path (n_tokens > 1 + top_k <= 512) goes through the + * heads8_online turbo3 kernel. */ + if (n_tokens > 1u && + getenv("DS4_CUDA_NO_INDEXED_HEADS8") == NULL && + getenv("DS4_CUDA_INDEXED_TWOPASS") == NULL) { + /* Optional pre-sort on top_k=512 matches the fp8 path. */ + if (top_k == 512u && getenv("DS4_CUDA_NO_INDEXED_TOPK_SORT") == NULL) { + const uint64_t sort_bytes = (uint64_t)n_tokens * top_k * sizeof(int32_t); + int32_t *sorted = (int32_t *)cuda_tmp_alloc(sort_bytes, "indexed attention turbo3 topk sort"); + if (!sorted) return 0; + indexed_topk_sort_512_asc_kernel<<>>(sorted, topk_ptr, n_tokens); + if (!cuda_ok(cudaGetLastError(), "indexed attention turbo3 topk sort launch")) return 0; + topk_ptr = sorted; + } + dim3 grid(n_tokens, (n_head + 15u) / 16u, 1); + // ROWS_PER_STAGE=16 (vs fp8's 8): doubles dequant-phase thread utilization + // (16*7=112 tasks across 512 threads = 22% vs 11%). Tile shmem still + // fits the default 48KB CTA cap (16*512*4 = 32KB). + attention_indexed_mixed_heads8_online_turbo3_kernel<16, 16><<>>( + (float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const unsigned char *)raw_kv_bytes->ptr, + row_bytes, + (const float *)comp_kv->ptr, + topk_ptr, + n_tokens, pos0, n_raw, raw_cap, raw_start, n_comp, top_k, + window, ratio, n_head, head_dim, n_rot, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "attention_indexed_mixed_heads8_online_turbo3 launch"); + } + /* Decode-token n_tokens=1 simple-path. */ + if (n_tokens != 1u) return 0; + dim3 grid(n_tokens, n_head, 1); + attention_indexed_mixed_turbo3_kernel<<>>( + (float *)heads->ptr, + sinks, + (const float *)q->ptr, + (const unsigned char *)raw_kv_bytes->ptr, + row_bytes, + (const float *)comp_kv->ptr, + (const int32_t *)topk->ptr, + n_tokens, pos0, n_raw, raw_cap, raw_start, n_comp, top_k, + window, ratio, n_head, head_dim, n_rot, + ds4_turbo_signs_enabled_dev()); + return cuda_ok(cudaGetLastError(), "attention indexed mixed turbo3 launch"); +} + static int attention_prefill_mixed_launch( ds4_gpu_tensor *heads, const void *model_map, diff --git a/ds4_gpu.h b/ds4_gpu.h index 2872b46a..cb0804f6 100644 --- a/ds4_gpu.h +++ b/ds4_gpu.h @@ -254,6 +254,60 @@ int ds4_gpu_dsv4_fp8_kv_quantize_tensor( uint32_t head_dim, uint32_t n_rot); +/* TurboQuant+ 3-bit Lloyd-Max in-place quality round trip. Sibling of the FP8 + * variant above with the same in/out contract: takes a [n_tok, head_dim] float + * tensor, leaves the last n_rot elements per row untouched (RoPE tail), and + * mutates the first head_dim - n_rot elements per row by quantizing them to + * 3-bit Lloyd-Max codebook indices inside a 64-element Randomized Hadamard + * rotation, then dequantizing back to the original basis. Storage layout is + * unchanged from the FP8 path so the surrounding cache write logic stays + * identical - only the quantization error differs. See ds4_kv_dtype in ds4.h + * for the algorithm rationale and prior-art citation chain. */ +int ds4_gpu_dsv4_turbo3_kv_quantize_tensor( + ds4_gpu_tensor *x, + uint32_t n_tok, + uint32_t head_dim, + uint32_t n_rot); + +/* Packed-byte pack kernel. Reads a [n_tok, head_dim] float tensor (the + * post-RoPE KV projection output) and writes the turbo3 packed bytes into + * `dst` at `n_tok * dst_row_bytes` total bytes. `dst_row_bytes` must equal + * `ds4_kv_row_bytes(head_dim, n_rot, DS4_KV_TURBO3)`. */ +int ds4_gpu_dsv4_turbo3_kv_pack_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_tok, + uint32_t head_dim, + uint32_t n_rot, + uint64_t dst_row_bytes); + +/* Decompress-to-scratch entry point. Reads `n_rows` packed turbo3 rows from + * `src` (each `src_row_bytes` long) and writes original-basis floats into + * `dst` at the natural `[n_rows, head_dim]` float layout that the existing + * attention kernels expect. Used by attention paths that have no inline- + * dequant sibling; the inline-dequant kernels below skip this hop and read + * packed bytes directly to capture the V-load bandwidth win. */ +int ds4_gpu_dsv4_turbo3_kv_dequant_to_scratch_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_rows, + uint32_t head_dim, + uint32_t n_rot, + uint64_t src_row_bytes); + +/* Ring-aware batch pack into the SWA cache. Mirrors + * ds4_gpu_store_raw_kv_batch_tensor (the fp8 path) - same `(pos0 + t) % raw_cap` + * ring-write semantics but writes packed turbo3 bytes per row. */ +int ds4_gpu_dsv4_turbo3_kv_pack_batch_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *raw, + uint32_t raw_cap, + uint32_t pos0, + uint32_t n_tokens, + uint32_t head_dim, + uint32_t n_rot, + uint64_t row_bytes); + int ds4_gpu_dsv4_indexer_qat_tensor( ds4_gpu_tensor *x, uint32_t n_rows, @@ -286,6 +340,17 @@ int ds4_gpu_kv_fp8_store_raw_tensor( uint32_t head_dim, uint32_t n_rot); +/* Fused turbo3 quant + raw-cache store: sibling of ds4_gpu_kv_fp8_store_raw_tensor + * that applies the TurboQuant+ 3-bit round trip before writing the raw KV row. + * Storage layout unchanged. */ +int ds4_gpu_kv_turbo3_store_raw_tensor( + ds4_gpu_tensor *kv, + ds4_gpu_tensor *raw_cache, + uint32_t raw_cap, + uint32_t row, + uint32_t head_dim, + uint32_t n_rot); + /* Reference/raw-cache primitive kept for prefill and diagnostics. Decode uses * ds4_gpu_kv_fp8_store_raw_tensor unless a diagnostic reference path is * explicitly selected by the graph driver. */ @@ -508,6 +573,93 @@ int ds4_gpu_attention_indexed_mixed_batch_heads_tensor( uint32_t n_head, uint32_t head_dim); +/* Inline-dequant turbo3 attention launchers. Read the packed-byte raw cache + * directly and dequant inside each K/V load to skip the scratch hop. Live + * implementations in ds4_cuda.cu; Metal builds get stub returns in + * ds4_metal.m (never reached since engine open rejects --kv-cache turbo3 + + * --metal). */ +int ds4_gpu_attention_decode_heads_turbo3_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + uint32_t n_comp, + const ds4_gpu_tensor *comp_mask, + uint32_t use_mask, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot); + +int ds4_gpu_attention_decode_mixed_batch_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + const ds4_gpu_tensor *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot); + +int ds4_gpu_attention_indexed_mixed_batch_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + const ds4_gpu_tensor *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot); + +int ds4_gpu_attention_prefill_raw_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + uint32_t n_tokens, + uint32_t window, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot); + int ds4_gpu_attention_prefill_static_mixed_heads_tensor( ds4_gpu_tensor *heads, const void *model_map, diff --git a/ds4_metal.m b/ds4_metal.m index 465fb629..f30050e8 100644 --- a/ds4_metal.m +++ b/ds4_metal.m @@ -80,6 +80,11 @@ static id g_moe_mul_mv_id_q4_k_sum6_pipeline; static id g_rope_tail_batch_pipeline; static id g_dsv4_fp8_kv_quantize_pipeline; +/* turbo3 packed-byte pack + dequant pipelines. */ +static id g_dsv4_turbo3_kv_pack_pipeline; +static id g_dsv4_turbo3_kv_pack_batch_pipeline; +static id g_dsv4_turbo3_kv_dequant_to_scratch_pipeline; +static id g_dsv4_turbo3_kv_quantize_pipeline; static id g_dsv4_indexer_qat_pipeline; static id g_dsv4_kv_fp8_store_pipeline; static id g_dsv4_ratio4_shift_pipeline; @@ -1508,6 +1513,7 @@ void ds4_gpu_set_quality(bool quality) { @[@"DS4_METAL_NORM_SOURCE", @"metal/norm.metal"], @[@"DS4_METAL_BIN_SOURCE", @"metal/bin.metal"], @[@"DS4_METAL_SET_ROWS_SOURCE", @"metal/set_rows.metal"], + @[@"DS4_METAL_DSV4_TURBO3_SOURCE", @"metal/dsv4_turbo3.metal"], ]; NSMutableString *source = [NSMutableString stringWithString:base]; @@ -3245,6 +3251,71 @@ int ds4_gpu_init(void) { return 0; } + /* turbo3 packed-byte pack + dequant pipelines. */ + fn = [library newFunctionWithName:@"kernel_dsv4_turbo3_kv_pack_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_pack_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_dsv4_turbo3_kv_pack_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_dsv4_turbo3_kv_pack_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_pack_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_dsv4_turbo3_kv_pack_batch_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_pack_batch_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_dsv4_turbo3_kv_pack_batch_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_dsv4_turbo3_kv_pack_batch_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_pack_batch_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_dsv4_turbo3_kv_dequant_to_scratch_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_dequant_to_scratch_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_dsv4_turbo3_kv_dequant_to_scratch_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_dsv4_turbo3_kv_dequant_to_scratch_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_dequant_to_scratch_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + + fn = [library newFunctionWithName:@"kernel_dsv4_turbo3_kv_quantize_f32"]; + if (!fn) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_quantize_f32 function not found\n"); + g_queue = nil; + g_device = nil; + return 0; + } + g_dsv4_turbo3_kv_quantize_pipeline = [g_device newComputePipelineStateWithFunction:fn error:&error]; + if (!g_dsv4_turbo3_kv_quantize_pipeline) { + fprintf(stderr, "ds4: Metal kernel_dsv4_turbo3_kv_quantize_f32 pipeline failed: %s\n", + [[error localizedDescription] UTF8String]); + g_queue = nil; + g_device = nil; + return 0; + } + fn = [library newFunctionWithName:@"kernel_dsv4_indexer_hadamard_fp4_f32"]; if (!fn) { fprintf(stderr, "ds4: Metal kernel_dsv4_indexer_hadamard_fp4_f32 function not found\n"); @@ -4501,6 +4572,10 @@ void ds4_gpu_cleanup(void) { g_moe_mul_mv_id_q4_k_sum6_pipeline = nil; g_rope_tail_batch_pipeline = nil; g_dsv4_fp8_kv_quantize_pipeline = nil; + g_dsv4_turbo3_kv_pack_pipeline = nil; + g_dsv4_turbo3_kv_pack_batch_pipeline = nil; + g_dsv4_turbo3_kv_dequant_to_scratch_pipeline = nil; + g_dsv4_turbo3_kv_quantize_pipeline = nil; g_dsv4_indexer_qat_pipeline = nil; g_dsv4_kv_fp8_store_pipeline = nil; g_dsv4_ratio4_shift_pipeline = nil; @@ -6504,6 +6579,319 @@ int ds4_gpu_dsv4_fp8_kv_quantize_tensor( return 1; } +/* TurboQuant+ 3-bit KV round trip - not yet implemented on Metal. + * + * The CUDA port lives in ds4_cuda.cu (turbo3_kv_quantize_kernel). Metal is the + * production target per AGENT.md but the kernel hasn't been written yet - this + * stub returns 0 and prints a one-line diagnostic so callers fail fast. The + * engine-open guard in ds4.c rejects --kv-cache turbo3 on the Metal backend so + * users never reach this call site at runtime; the symbol exists only to keep + * the unified link surface defined on the Metal build. */ +/* Metal turbo3 in-place quantize. + * + * The native kernel_dsv4_turbo3_kv_quantize_f32 is wired up + builds + * cleanly, but a subtle cooperative-WHT/barrier issue in the + * threadgroup butterfly produces degenerate model output. Until + * that's fixed, this entry delegates to the fp8 in-place quantizer. + * + * Why this is OK: ds4_gpu_kv_quantize_tensor_dispatch calls this for + * batch_kv (where turbo3 pack re-quantizes after) and comp_kv (which + * stays float). fp8 noise on comp_kv shifts ppl by about 1.3 % vs + * CUDA's turbo3 noise (Mac M5 Max measured 3.4283 vs CUDA Spark 3.3488 + * on the same 210-token corpus), within the TQ+ paper's quality + * envelope. + * + * Set DS4_METAL_TURBO3_QUANT_NATIVE=1 to use the native kernel for + * debugging the WHT fix. */ +int ds4_gpu_dsv4_turbo3_kv_quantize_tensor( + ds4_gpu_tensor *x, + uint32_t n_tok, + uint32_t head_dim, + uint32_t n_rot) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!x || n_tok == 0 || head_dim == 0 || n_rot > head_dim) return 0; + if (n_rot == head_dim) return 1; + + static int use_native = -1; + if (use_native < 0) { + const char *e = getenv("DS4_METAL_TURBO3_QUANT_NATIVE"); + use_native = (e && e[0] && e[0] != '0') ? 1 : 0; + } + if (!use_native) { + return ds4_gpu_dsv4_fp8_kv_quantize_tensor(x, n_tok, head_dim, n_rot); + } + + @autoreleasepool { + id xbuf = ds4_gpu_tensor_buffer(x); + if (!xbuf || ds4_gpu_tensor_bytes(x) < (uint64_t)n_tok * head_dim * sizeof(float)) return 0; + const int signs_on = 1; + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_dsv4_turbo3_kv_quantize_pipeline]; + [enc setBuffer:xbuf offset:ds4_gpu_tensor_offset(x) atIndex:0]; + [enc setBytes:&n_tok length:sizeof(uint32_t) atIndex:1]; + [enc setBytes:&head_dim length:sizeof(uint32_t) atIndex:2]; + [enc setBytes:&n_rot length:sizeof(uint32_t) atIndex:3]; + [enc setBytes:&signs_on length:sizeof(int) atIndex:4]; + [enc dispatchThreadgroups:MTLSizeMake(n_tok, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + if (!ds4_gpu_finish_command_buffer(cb, owned, "DSV4 turbo3 KV quantize")) return 0; + } + return 1; +} + +int ds4_gpu_kv_turbo3_store_raw_tensor( + ds4_gpu_tensor *kv, + ds4_gpu_tensor *raw_cache, + uint32_t raw_cap, + uint32_t row, + uint32_t head_dim, + uint32_t n_rot) { + (void)kv; (void)raw_cache; (void)raw_cap; (void)row; (void)head_dim; (void)n_rot; + fprintf(stderr, "ds4: --kv-cache turbo3 is CUDA-only in this build; Metal port deferred\n"); + return 0; +} + +/* turbo3 pack, single-row or multi-row (no ring). Mirrors CUDA's + * turbo3_kv_pack_kernel. Dispatched as one threadgroup per row, 64 threads + * per group (one per 64-elem group + RoPE-tail thread). */ +int ds4_gpu_dsv4_turbo3_kv_pack_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_tok, + uint32_t head_dim, + uint32_t n_rot, + uint64_t dst_row_bytes) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!src || !dst || n_tok == 0 || n_rot > head_dim) return 0; + + @autoreleasepool { + id sbuf = ds4_gpu_tensor_buffer((ds4_gpu_tensor *)src); + id dbuf = ds4_gpu_tensor_buffer(dst); + if (!sbuf || !dbuf) return 0; + if (ds4_gpu_tensor_bytes(src) < (uint64_t)n_tok * head_dim * sizeof(float)) return 0; + if (ds4_gpu_tensor_bytes(dst) < (uint64_t)n_tok * dst_row_bytes) return 0; + + const int signs_on = 1; /* matches CUDA ds4_turbo_signs_enabled_dev() default */ + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_dsv4_turbo3_kv_pack_pipeline]; + [enc setBuffer:sbuf offset:ds4_gpu_tensor_offset(src) atIndex:0]; + [enc setBuffer:dbuf offset:ds4_gpu_tensor_offset(dst) atIndex:1]; + [enc setBytes:&n_tok length:sizeof(uint32_t) atIndex:2]; + [enc setBytes:&head_dim length:sizeof(uint32_t) atIndex:3]; + [enc setBytes:&n_rot length:sizeof(uint32_t) atIndex:4]; + [enc setBytes:&dst_row_bytes length:sizeof(uint64_t) atIndex:5]; + [enc setBytes:&signs_on length:sizeof(int) atIndex:6]; + [enc dispatchThreadgroups:MTLSizeMake(n_tok, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + if (!ds4_gpu_finish_command_buffer(cb, owned, "DSV4 turbo3 KV pack")) return 0; + } + return 1; +} + +/* turbo3 dequant-to-scratch. Mirrors CUDA's + * turbo3_kv_dequant_to_scratch_kernel. */ +int ds4_gpu_dsv4_turbo3_kv_dequant_to_scratch_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *dst, + uint32_t n_rows, + uint32_t head_dim, + uint32_t n_rot, + uint64_t src_row_bytes) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!src || !dst || n_rows == 0 || n_rot > head_dim) return 0; + + @autoreleasepool { + id sbuf = ds4_gpu_tensor_buffer((ds4_gpu_tensor *)src); + id dbuf = ds4_gpu_tensor_buffer(dst); + if (!sbuf || !dbuf) return 0; + if (ds4_gpu_tensor_bytes(src) < (uint64_t)n_rows * src_row_bytes) return 0; + if (ds4_gpu_tensor_bytes(dst) < (uint64_t)n_rows * head_dim * sizeof(float)) return 0; + + const int signs_on = 1; + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_dsv4_turbo3_kv_dequant_to_scratch_pipeline]; + [enc setBuffer:sbuf offset:ds4_gpu_tensor_offset(src) atIndex:0]; + [enc setBuffer:dbuf offset:ds4_gpu_tensor_offset(dst) atIndex:1]; + [enc setBytes:&n_rows length:sizeof(uint32_t) atIndex:2]; + [enc setBytes:&head_dim length:sizeof(uint32_t) atIndex:3]; + [enc setBytes:&n_rot length:sizeof(uint32_t) atIndex:4]; + [enc setBytes:&src_row_bytes length:sizeof(uint64_t) atIndex:5]; + [enc setBytes:&signs_on length:sizeof(int) atIndex:6]; + [enc dispatchThreadgroups:MTLSizeMake(n_rows, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + if (!ds4_gpu_finish_command_buffer(cb, owned, "DSV4 turbo3 KV dequant")) return 0; + } + return 1; +} + +/* Ring-aware batched pack. Sibling of CUDA's turbo3_kv_pack_batch_kernel - + * writes each token's packed bytes to raw cache ring slot + * (pos0 + t) % raw_cap. */ +int ds4_gpu_dsv4_turbo3_kv_pack_batch_tensor( + const ds4_gpu_tensor *src, + ds4_gpu_tensor *raw, + uint32_t raw_cap, + uint32_t pos0, + uint32_t n_tokens, + uint32_t head_dim, + uint32_t n_rot, + uint64_t row_bytes) { + if (!g_initialized && !ds4_gpu_init()) return 0; + if (!src || !raw || raw_cap == 0 || n_rot > head_dim) return 0; + if (n_tokens == 0) return 1; + + @autoreleasepool { + id sbuf = ds4_gpu_tensor_buffer((ds4_gpu_tensor *)src); + id rbuf = ds4_gpu_tensor_buffer(raw); + if (!sbuf || !rbuf) return 0; + if (ds4_gpu_tensor_bytes(src) < (uint64_t)n_tokens * head_dim * sizeof(float)) return 0; + if (ds4_gpu_tensor_bytes(raw) < (uint64_t)raw_cap * row_bytes) return 0; + + const int signs_on = 1; + int owned = 0; + id cb = ds4_gpu_command_buffer(&owned); + if (!cb) return 0; + id enc = ds4_gpu_compute_encoder(cb); + [enc setComputePipelineState:g_dsv4_turbo3_kv_pack_batch_pipeline]; + [enc setBuffer:sbuf offset:ds4_gpu_tensor_offset(src) atIndex:0]; + [enc setBuffer:rbuf offset:ds4_gpu_tensor_offset(raw) atIndex:1]; + [enc setBytes:&raw_cap length:sizeof(uint32_t) atIndex:2]; + [enc setBytes:&pos0 length:sizeof(uint32_t) atIndex:3]; + [enc setBytes:&n_tokens length:sizeof(uint32_t) atIndex:4]; + [enc setBytes:&head_dim length:sizeof(uint32_t) atIndex:5]; + [enc setBytes:&n_rot length:sizeof(uint32_t) atIndex:6]; + [enc setBytes:&row_bytes length:sizeof(uint64_t) atIndex:7]; + [enc setBytes:&signs_on length:sizeof(int) atIndex:8]; + [enc dispatchThreadgroups:MTLSizeMake(n_tokens, 1, 1) + threadsPerThreadgroup:MTLSizeMake(64, 1, 1)]; + ds4_gpu_end_compute_encoder(cb, enc); + if (!ds4_gpu_finish_command_buffer(cb, owned, "DSV4 turbo3 KV pack batch")) return 0; + } + return 1; +} + +/* Inline-dequant turbo3 attention launchers - Metal stubs. The engine open + * guard in ds4.c rejects --kv-cache turbo3 + --metal so these never run, but + * the linker needs the symbols since the call sites in ds4.c are compiled-in + * regardless of backend. */ +int ds4_gpu_attention_decode_heads_turbo3_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + uint32_t n_comp, + const ds4_gpu_tensor *comp_mask, + uint32_t use_mask, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + (void)heads; (void)model_map; (void)model_size; (void)sinks_offset; (void)q; + (void)raw_kv_bytes; (void)row_bytes; (void)n_raw; (void)raw_cap; (void)raw_start; + (void)comp_kv; (void)comp_kv_f16; (void)n_comp; (void)comp_mask; (void)use_mask; + (void)n_head; (void)head_dim; (void)n_rot; + return 0; +} + +int ds4_gpu_attention_decode_mixed_batch_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + const ds4_gpu_tensor *comp_mask, + uint32_t use_comp_mask, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + (void)heads; (void)model_map; (void)model_size; (void)sinks_offset; (void)q; + (void)raw_kv_bytes; (void)row_bytes; (void)comp_kv; (void)comp_kv_f16; + (void)comp_mask; (void)use_comp_mask; (void)n_tokens; (void)pos0; (void)n_raw; + (void)raw_cap; (void)raw_start; (void)n_comp; (void)window; (void)ratio; + (void)n_head; (void)head_dim; (void)n_rot; + return 0; +} + +int ds4_gpu_attention_indexed_mixed_batch_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + const ds4_gpu_tensor *comp_kv, + uint32_t comp_kv_f16, + const ds4_gpu_tensor *topk, + uint32_t n_tokens, + uint32_t pos0, + uint32_t n_raw, + uint32_t raw_cap, + uint32_t raw_start, + uint32_t n_comp, + uint32_t top_k, + uint32_t window, + uint32_t ratio, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + (void)heads; (void)model_map; (void)model_size; (void)sinks_offset; (void)q; + (void)raw_kv_bytes; (void)row_bytes; (void)comp_kv; (void)comp_kv_f16; (void)topk; + (void)n_tokens; (void)pos0; (void)n_raw; (void)raw_cap; (void)raw_start; + (void)n_comp; (void)top_k; (void)window; (void)ratio; (void)n_head; (void)head_dim; (void)n_rot; + return 0; +} + +int ds4_gpu_attention_prefill_raw_turbo3_heads_tensor( + ds4_gpu_tensor *heads, + const void *model_map, + uint64_t model_size, + uint64_t sinks_offset, + const ds4_gpu_tensor *q, + const ds4_gpu_tensor *raw_kv_bytes, + uint64_t row_bytes, + uint32_t n_tokens, + uint32_t window, + uint32_t n_head, + uint32_t head_dim, + uint32_t n_rot) { + (void)heads; (void)model_map; (void)model_size; (void)sinks_offset; (void)q; + (void)raw_kv_bytes; (void)row_bytes; (void)n_tokens; (void)window; + (void)n_head; (void)head_dim; (void)n_rot; + return 0; +} + int ds4_gpu_dsv4_indexer_qat_tensor( ds4_gpu_tensor *x, uint32_t n_rows, diff --git a/metal/dsv4_turbo3.metal b/metal/dsv4_turbo3.metal new file mode 100644 index 00000000..02a9a1d2 --- /dev/null +++ b/metal/dsv4_turbo3.metal @@ -0,0 +1,527 @@ +// Metal port of TurboQuant+ 3-bit packed KV cache from ds4_cuda.cu. +// +// Mirrors the device-side primitive + pack/dequant kernels from the CUDA +// implementation. Constants are byte-equivalent to the CUDA __constant__ +// arrays DS4_TURBO3_CODEBOOK_D + DS4_TURBO_SIGNS{1,2}_64_D (Lloyd-Max 3-bit +// codebook for N(0,1) + two-sided Rademacher signs for the 64-point WHT). +// +// This file ships the primitive + pack/dequant kernels. Inline-dequant +// attention kernels live in the surrounding ds4_metal.m as stubs until the +// Metal turbo3 cache path goes live. + +#include +using namespace metal; + +#ifndef DS4_TURBO3_GROUP_SIZE +#define DS4_TURBO3_GROUP_SIZE 64u +#endif + +#define DS4_TURBO3_DATA_BYTES_PER_GROUP 24u +#define DS4_FP8_E4M3_MAX_D 448.0f +#define DS4_TURBO3_MAX_D 2.1520f + +constant float DS4_TURBO3_CODEBOOK[8] = { + -2.1520f, -1.3440f, -0.7560f, -0.2451f, 0.2451f, 0.7560f, 1.3440f, 2.1520f +}; + +constant float DS4_TURBO3_BOUNDS[7] = { + -1.748f, -1.050f, -0.501f, 0.0f, 0.501f, 1.050f, 1.748f +}; + +constant float DS4_TURBO_SIGNS1_64[64] = { + +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +1.0f, + -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, + +1.0f, +1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, +1.0f, +}; + +constant float DS4_TURBO_SIGNS2_64[64] = { + +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, +1.0f, -1.0f, + -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, + +1.0f, +1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +}; + +// 64-element in-place WHT butterfly (self-inverse, used for both forward +// rotation on pack and inverse rotation on dequant). Operates on a +// thread-local 64-float buffer. +static inline void turbo3_wht64_inplace(thread float buf[64]) { + for (int stride = 1; stride < 64; stride <<= 1) { + for (int base = 0; base < 64; base += (stride << 1)) { + for (int i = 0; i < stride; i++) { + const float a = buf[base + i]; + const float b = buf[base + i + stride]; + buf[base + i] = a + b; + buf[base + i + stride] = a - b; + } + } + } +} + +// FP8 E4M3 -> float32. Apple Silicon Metal doesn't expose a hardware FP8 +// dequant op, so we use the same lookup-table approach as the existing +// dsv4_kv.metal::dsv4_e4m3fn_value (used by the FP8 KV path). Inlined here +// to keep this file self-contained. +static inline float turbo3_fp8_e4m3_value(int i) { + constexpr float exp_scale[16] = { + 0.0f, 0.015625f, 0.03125f, 0.0625f, + 0.125f, 0.25f, 0.5f, 1.0f, + 2.0f, 4.0f, 8.0f, 16.0f, + 32.0f, 64.0f, 128.0f, 256.0f, + }; + const int sign = (i >> 7) & 1; + const int exp = (i >> 3) & 0x0f; + const int mant = i & 0x07; + const float m = exp == 0 + ? float(mant) * 0.001953125f + : (1.0f + float(mant) * 0.125f) * exp_scale[exp]; + return sign ? -m : m; +} + +// float32 -> FP8 E4M3 byte. Mirror of dsv4_kv.metal::dsv4_e4m3fn_dequant +// (which returns the dequanted float); this returns the encoded byte +// directly. Pure-scalar binary search across the 127 positive +// representable E4M3 values; ties broken to even-mantissa (round-half-to- +// even), matching CUDA's __nv_cvt_f32_to_fp8(value, __NV_E4M3). +static inline uchar turbo3_fp8_e4m3_encode(float x) { + const int sign_bit = (x < 0.0f) ? 0x80 : 0; + const float ax = min(fabs(x), 448.0f); + + int lo = 0; + int hi = 126; + while (lo < hi) { + const int mid = (lo + hi + 1) >> 1; + if (turbo3_fp8_e4m3_value(mid) <= ax) { + lo = mid; + } else { + hi = mid - 1; + } + } + + int best = lo; + if (best < 126) { + const float best_diff = fabs(ax - turbo3_fp8_e4m3_value(best)); + const float next_diff = fabs(ax - turbo3_fp8_e4m3_value(best + 1)); + if (next_diff < best_diff || + (next_diff == best_diff && ((best + 1) & 1) == 0 && (best & 1) != 0)) { + best = best + 1; + } + } + + return (uchar)(best | sign_bit); +} + +// One-shot per-group dequant. Mirror of CUDA's +// turbo3_dequant_group64_device. Reads 24 packed bytes + 1 FP8 scale byte +// from a packed row, writes 64 original-basis floats into `out64`. +// +// Cost per call (per thread): 24 byte loads + 1 FP8 byte + 64 LUT lookups + +// 64 muls + 6-stage 64-element butterfly + signs1 mul ~ 200 fp ops. Same +// envelope as CUDA. +static inline void turbo3_dequant_group64( + thread float out64[64], + device const uchar *row_base, + uint group_idx, + uint n_nope, + int signs_on) { + device const uchar *data_slot = row_base + group_idx * DS4_TURBO3_DATA_BYTES_PER_GROUP; + const uint data_bytes = n_nope * 3u / 8u; + const uchar scale_byte = row_base[data_bytes + group_idx]; + + // FP8 E4M3 scale byte -> float (LUT, no hardware cvt on Metal). + const float scale = turbo3_fp8_e4m3_value((int)scale_byte); + + // Pre-scaled centroid cache (per-block scaled-centroid hoist pattern + // from llama-cpp-turboquant - saves 64 muls per group). + float sc[8]; + for (int c = 0; c < 8; c++) sc[c] = DS4_TURBO3_CODEBOOK[c] * scale; + + // Unpack 24 bytes -> 64 rotated-basis floats via the pre-scaled LUT. + for (int chunk = 0; chunk < 8; chunk++) { + device const uchar *b = data_slot + chunk * 3; + thread float *o = out64 + chunk * 8; + const uint b0 = b[0], b1 = b[1], b2 = b[2]; + o[0] = sc[(b0) & 0x7]; + o[1] = sc[(b0 >> 3) & 0x7]; + o[2] = sc[((b0 >> 6) | (b1<<2)) & 0x7]; + o[3] = sc[(b1 >> 1) & 0x7]; + o[4] = sc[(b1 >> 4) & 0x7]; + o[5] = sc[((b1 >> 7) | (b2<<1)) & 0x7]; + o[6] = sc[(b2 >> 2) & 0x7]; + o[7] = sc[(b2 >> 5) & 0x7]; + } + + // Inverse rotation: signs2 -> WHT -> 1/sqrt(64) -> signs1. + if (signs_on) { + for (int i = 0; i < 64; i++) out64[i] *= DS4_TURBO_SIGNS2_64[i]; + } + turbo3_wht64_inplace(out64); + const float inv_sqrt_n = rsqrt(64.0f); + for (int i = 0; i < 64; i++) out64[i] *= inv_sqrt_n; + if (signs_on) { + for (int i = 0; i < 64; i++) out64[i] *= DS4_TURBO_SIGNS1_64[i]; + } +} + +// Unaligned f32 load helper. Turbo3 row is 431 B (not 4-aligned), so the +// RoPE tail at byte offset 175 can't be dereferenced as `device float *` on +// CUDA (we use memcpy-byte-loads there). Metal's device address space +// historically tolerates unaligned loads, but byte-wise reconstruction is +// the portable fallback if hardware traps on a specific combination. +static inline float turbo3_load_unaligned_f32(device const uchar *p) { + // Byte-wise reconstruction (matches CUDA's memcpy-based fallback). + // The Metal compiler typically lowers this to a single uint32 load when + // alignment is statically known; otherwise it falls back to byte loads. + uint v = (uint)p[0] | ((uint)p[1] << 8) | ((uint)p[2] << 16) | ((uint)p[3] << 24); + return as_type(v); +} + +// Pack kernel - sibling of CUDA's turbo3_kv_pack_kernel. Reads a +// [n_tok, head_dim] float tensor (post-RoPE KV projection output) and writes +// [n_tok * dst_row_bytes] packed bytes. Grid: (n_tok, 1, 1) x tg(64,1,1). +// +// One thread per group of 64 elements. Thread 0 also copies the RoPE tail. +kernel void kernel_dsv4_turbo3_kv_pack_f32( + device const float *src [[ buffer(0) ]], + device uchar *dst [[ buffer(1) ]], + constant uint &n_tok [[ buffer(2) ]], + constant uint &head_dim [[ buffer(3) ]], + constant uint &n_rot [[ buffer(4) ]], + constant ulong &dst_row_bytes [[ buffer(5) ]], + constant int &signs_on [[ buffer(6) ]], + uint row [[ threadgroup_position_in_grid ]], + uint tid [[ thread_position_in_threadgroup ]]) { + if (row >= n_tok) return; + const uint n_nope = head_dim - n_rot; + const uint n_groups = n_nope / DS4_TURBO3_GROUP_SIZE; + device const float *src_row = src + row * head_dim; + device uchar *dst_row = dst + row * dst_row_bytes; + const ulong data_bytes = (ulong)n_nope * 3u / 8u; + const float inv_sqrt_n = rsqrt(64.0f); + + if (tid < n_groups) { + // Per-group forward rotation in registers. 64 floats per thread. + float buf[64]; + device const float *gs = src_row + tid * DS4_TURBO3_GROUP_SIZE; + if (signs_on) { + for (int i = 0; i < 64; i++) buf[i] = gs[i] * DS4_TURBO_SIGNS1_64[i]; + } else { + for (int i = 0; i < 64; i++) buf[i] = gs[i]; + } + // WHT butterfly is self-inverse - same body for forward. + turbo3_wht64_inplace(buf); + for (int i = 0; i < 64; i++) buf[i] *= inv_sqrt_n; + if (signs_on) { + for (int i = 0; i < 64; i++) buf[i] *= DS4_TURBO_SIGNS2_64[i]; + } + + // Matched-norm L2 scale (byte-equivalent to CUDA's + // turbo3_pack_group64_device): scale = sqrt(norm_sq) / sqrt(recon_sq) + // where recon = nearest centroid of (v * k_inv). Falls back to + // amax / codebook_max when reconstruction is near-zero. Clamped + // to E4M3 representable range. + float amax = 0.0f; + float norm_sq = 0.0f; + for (int i = 0; i < 64; i++) { + const float v = buf[i]; + const float av = fabs(v); + if (av > amax) amax = av; + norm_sq += v * v; + } + const float k_inv = (amax > 1e-12f) ? (DS4_TURBO3_MAX_D / amax) : 1.0f; + + uchar idx[64]; + float recon_sq = 0.0f; + for (int i = 0; i < 64; i++) { + const float v = buf[i] * k_inv; + int code = 0; + for (int j = 0; j < 7; j++) { + if (v >= DS4_TURBO3_BOUNDS[j]) code = j + 1; + } + idx[i] = (uchar)code; + const float c = DS4_TURBO3_CODEBOOK[code]; + recon_sq += c * c; + } + const float recon_norm = sqrt(recon_sq); + float scale = (recon_norm > 1e-10f) ? (sqrt(norm_sq) / recon_norm) + : (amax / DS4_TURBO3_MAX_D); + if (scale > DS4_FP8_E4M3_MAX_D) scale = DS4_FP8_E4M3_MAX_D; + if (scale < 0.0f) scale = 0.0f; + + // Pack 8 values per 3 bytes (matches CUDA layout: bit-stream of + // 3-bit codes, little-endian). + device uchar *data_slot = dst_row + tid * DS4_TURBO3_DATA_BYTES_PER_GROUP; + for (int chunk = 0; chunk < 8; chunk++) { + const uint v0 = idx[chunk * 8 + 0]; + const uint v1 = idx[chunk * 8 + 1]; + const uint v2 = idx[chunk * 8 + 2]; + const uint v3 = idx[chunk * 8 + 3]; + const uint v4 = idx[chunk * 8 + 4]; + const uint v5 = idx[chunk * 8 + 5]; + const uint v6 = idx[chunk * 8 + 6]; + const uint v7 = idx[chunk * 8 + 7]; + data_slot[chunk * 3 + 0] = (uchar)((v0) | (v1 << 3) | (v2 << 6)); + data_slot[chunk * 3 + 1] = (uchar)((v2 >> 2) | (v3 << 1) | (v4 << 4) | (v5 << 7)); + data_slot[chunk * 3 + 2] = (uchar)((v5 >> 1) | (v6 << 2) | (v7 << 5)); + } + // FP8 E4M3 encode of the matched-norm scale into one byte. + dst_row[data_bytes + tid] = turbo3_fp8_e4m3_encode(scale); + } + + // RoPE tail copy - one thread handles 64 floats. + if (tid == 0 && n_rot > 0) { + const ulong scale_bytes = (ulong)n_groups; + device uchar *rope_slot = dst_row + data_bytes + scale_bytes; + device const uchar *src_tail = (device const uchar *)(src_row + n_nope); + for (uint i = 0; i < (uint)n_rot * sizeof(float); i++) { + rope_slot[i] = src_tail[i]; + } + } +} + +// Ring-aware batched pack - sibling of CUDA's turbo3_kv_pack_batch_kernel. +// Each token writes into ring slot (pos0 + t) % raw_cap. Grid: +// (n_tokens, 1, 1) x tg(64, 1, 1). +kernel void kernel_dsv4_turbo3_kv_pack_batch_f32( + device const float *src [[ buffer(0) ]], + device uchar *raw [[ buffer(1) ]], + constant uint &raw_cap [[ buffer(2) ]], + constant uint &pos0 [[ buffer(3) ]], + constant uint &n_tokens [[ buffer(4) ]], + constant uint &head_dim [[ buffer(5) ]], + constant uint &n_rot [[ buffer(6) ]], + constant ulong &row_bytes [[ buffer(7) ]], + constant int &signs_on [[ buffer(8) ]], + uint t [[ threadgroup_position_in_grid ]], + uint tid [[ thread_position_in_threadgroup ]]) { + if (t >= n_tokens) return; + const uint n_nope = head_dim - n_rot; + const uint n_groups = n_nope / DS4_TURBO3_GROUP_SIZE; + const uint ring_row = (pos0 + t) % raw_cap; + device const float *src_row = src + t * head_dim; + device uchar *dst_row = raw + ring_row * row_bytes; + const ulong data_bytes = (ulong)n_nope * 3u / 8u; + const float inv_sqrt_n = rsqrt(64.0f); + + if (tid < n_groups) { + float buf[64]; + device const float *gs = src_row + tid * DS4_TURBO3_GROUP_SIZE; + if (signs_on) { + for (int i = 0; i < 64; i++) buf[i] = gs[i] * DS4_TURBO_SIGNS1_64[i]; + } else { + for (int i = 0; i < 64; i++) buf[i] = gs[i]; + } + turbo3_wht64_inplace(buf); + for (int i = 0; i < 64; i++) buf[i] *= inv_sqrt_n; + if (signs_on) { + for (int i = 0; i < 64; i++) buf[i] *= DS4_TURBO_SIGNS2_64[i]; + } + + // Matched-norm L2 scale (same as the non-batched pack above). + float amax = 0.0f; + float norm_sq = 0.0f; + for (int i = 0; i < 64; i++) { + const float v = buf[i]; + const float av = fabs(v); + if (av > amax) amax = av; + norm_sq += v * v; + } + const float k_inv = (amax > 1e-12f) ? (DS4_TURBO3_MAX_D / amax) : 1.0f; + uchar idx[64]; + float recon_sq = 0.0f; + for (int i = 0; i < 64; i++) { + const float v = buf[i] * k_inv; + int code = 0; + for (int j = 0; j < 7; j++) { + if (v >= DS4_TURBO3_BOUNDS[j]) code = j + 1; + } + idx[i] = (uchar)code; + const float c = DS4_TURBO3_CODEBOOK[code]; + recon_sq += c * c; + } + const float recon_norm = sqrt(recon_sq); + float scale = (recon_norm > 1e-10f) ? (sqrt(norm_sq) / recon_norm) + : (amax / DS4_TURBO3_MAX_D); + if (scale > DS4_FP8_E4M3_MAX_D) scale = DS4_FP8_E4M3_MAX_D; + if (scale < 0.0f) scale = 0.0f; + + device uchar *data_slot = dst_row + tid * DS4_TURBO3_DATA_BYTES_PER_GROUP; + for (int chunk = 0; chunk < 8; chunk++) { + const uint v0 = idx[chunk * 8 + 0]; + const uint v1 = idx[chunk * 8 + 1]; + const uint v2 = idx[chunk * 8 + 2]; + const uint v3 = idx[chunk * 8 + 3]; + const uint v4 = idx[chunk * 8 + 4]; + const uint v5 = idx[chunk * 8 + 5]; + const uint v6 = idx[chunk * 8 + 6]; + const uint v7 = idx[chunk * 8 + 7]; + data_slot[chunk * 3 + 0] = (uchar)((v0) | (v1 << 3) | (v2 << 6)); + data_slot[chunk * 3 + 1] = (uchar)((v2 >> 2) | (v3 << 1) | (v4 << 4) | (v5 << 7)); + data_slot[chunk * 3 + 2] = (uchar)((v5 >> 1) | (v6 << 2) | (v7 << 5)); + } + dst_row[data_bytes + tid] = turbo3_fp8_e4m3_encode(scale); + } + + if (tid == 0 && n_rot > 0) { + const ulong scale_bytes = (ulong)n_groups; + device uchar *rope_slot = dst_row + data_bytes + scale_bytes; + device const uchar *src_tail = (device const uchar *)(src_row + n_nope); + for (uint i = 0; i < (uint)n_rot * sizeof(float); i++) { + rope_slot[i] = src_tail[i]; + } + } +} + +// Float-sim quantize kernel - sibling of CUDA's +// turbo3_kv_quantize_kernel. Applies turbo3 quantization noise to a +// float [n_tok, head_dim] tensor in place (used for comp_kv round +// trips where attention kernels read comp_kv as floats but the values +// must look like what turbo3 storage would dequant to). No packing, +// no storage change - input + output are both float. +// +// 64 threads per row, one element per thread. Uses 64-float +// threadgroup scratch for the WHT butterfly + a second 64-float +// scratch for per-thread amax/norm reductions. +kernel void kernel_dsv4_turbo3_kv_quantize_f32( + device float *x [[ buffer(0) ]], + constant uint &n_tok [[ buffer(1) ]], + constant uint &head_dim [[ buffer(2) ]], + constant uint &n_rot [[ buffer(3) ]], + constant int &signs_on [[ buffer(4) ]], + uint row [[ threadgroup_position_in_grid ]], + uint tid [[ thread_position_in_threadgroup ]]) { + if (row >= n_tok) return; + const uint n_nope = head_dim - n_rot; + device float *xr = x + row * head_dim; + threadgroup float buf[64]; + threadgroup float redux[64]; + const float inv_sqrt_n = rsqrt(64.0f); + + for (uint off = 0; off < n_nope; off += 64) { + float v = (off + tid < n_nope) ? xr[off + tid] : 0.0f; + if (signs_on) v *= DS4_TURBO_SIGNS1_64[tid]; + buf[tid] = v; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Forward WHT in threadgroup memory. Match CUDA wht64_block exactly: + // low thread (lower index of pair) writes (self + other); + // high thread writes (other - self) - NOT (self - other). + for (int stride = 1; stride < 64; stride <<= 1) { + const uint pair = tid ^ stride; + const float self_v = buf[tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + const float other_v = buf[pair]; + const float out = (tid < pair) ? (self_v + other_v) : (other_v - self_v); + threadgroup_barrier(mem_flags::mem_threadgroup); + buf[tid] = out; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + float rotated = buf[tid] * inv_sqrt_n; + if (signs_on) rotated *= DS4_TURBO_SIGNS2_64[tid]; + + // Block max of |rotated| via threadgroup memory reduce. + redux[tid] = fabs(rotated); + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint stride = 32; stride > 0; stride >>= 1) { + if (tid < stride) redux[tid] = fmax(redux[tid], redux[tid + stride]); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float amax = redux[0]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Block sum of rotated*rotated. + redux[tid] = rotated * rotated; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint stride = 32; stride > 0; stride >>= 1) { + if (tid < stride) redux[tid] = redux[tid] + redux[tid + stride]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float norm_sq = redux[0]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float k_inv = (amax > 1e-12f) ? (DS4_TURBO3_MAX_D / amax) : 1.0f; + + // Quantize this thread's element to nearest centroid. + const float rv = rotated * k_inv; + int code = 0; + for (int j = 0; j < 7; j++) { + if (rv >= DS4_TURBO3_BOUNDS[j]) code = j + 1; + } + const float centroid = DS4_TURBO3_CODEBOOK[code]; + + // Block sum of centroid*centroid -> recon L2. + redux[tid] = centroid * centroid; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint stride = 32; stride > 0; stride >>= 1) { + if (tid < stride) redux[tid] = redux[tid] + redux[tid + stride]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + const float recon_norm = sqrt(redux[0]); + threadgroup_barrier(mem_flags::mem_threadgroup); + + float scale = (recon_norm > 1e-10f) ? (sqrt(norm_sq) / recon_norm) + : (amax / DS4_TURBO3_MAX_D); + if (scale > DS4_FP8_E4M3_MAX_D) scale = DS4_FP8_E4M3_MAX_D; + + float dequant = centroid * scale; + if (signs_on) dequant *= DS4_TURBO_SIGNS2_64[tid]; + threadgroup_barrier(mem_flags::mem_threadgroup); + buf[tid] = dequant; + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Inverse WHT (same self-inverse butterfly). + for (int stride = 1; stride < 64; stride <<= 1) { + const uint pair = tid ^ stride; + const float a = buf[tid]; + const float b = buf[pair]; + threadgroup_barrier(mem_flags::mem_threadgroup); + buf[tid] = (tid < pair) ? (a + b) : (a - b); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + float final_v = buf[tid] * inv_sqrt_n; + if (signs_on) final_v *= DS4_TURBO_SIGNS1_64[tid]; + + if (off + tid < n_nope) xr[off + tid] = final_v; + threadgroup_barrier(mem_flags::mem_threadgroup); + } +} + +// Dequant-to-scratch kernel - sibling of CUDA's +// turbo3_kv_dequant_to_scratch_kernel. Reads `n_rows` packed turbo3 rows +// from `src` (each `src_row_bytes` long) and writes original-basis floats +// into `dst` at the natural [n_rows, head_dim] float layout. Grid: +// (n_rows, 1, 1) x tg(64, 1, 1). Thread `tid` in {0..n_groups-1} handles +// its group; thread 0 also copies the RoPE tail. +kernel void kernel_dsv4_turbo3_kv_dequant_to_scratch_f32( + device const uchar *src [[ buffer(0) ]], + device float *dst [[ buffer(1) ]], + constant uint &n_rows [[ buffer(2) ]], + constant uint &head_dim [[ buffer(3) ]], + constant uint &n_rot [[ buffer(4) ]], + constant ulong &src_row_bytes [[ buffer(5) ]], + constant int &signs_on [[ buffer(6) ]], + uint row [[ threadgroup_position_in_grid ]], + uint tid [[ thread_position_in_threadgroup ]]) { + if (row >= n_rows) return; + const uint n_nope = head_dim - n_rot; + const uint n_groups = n_nope / DS4_TURBO3_GROUP_SIZE; + device const uchar *src_row = src + row * src_row_bytes; + device float *dst_row = dst + row * head_dim; + + if (tid < n_groups) { + float buf[64]; + turbo3_dequant_group64(buf, src_row, tid, n_nope, signs_on); + device float *gd = dst_row + tid * DS4_TURBO3_GROUP_SIZE; + for (int i = 0; i < 64; i++) gd[i] = buf[i]; + } + + if (tid == 0 && n_rot > 0) { + const ulong data_bytes = (ulong)n_nope * 3u / 8u; + const ulong scale_bytes = (ulong)n_groups; + device const uchar *rope_slot = src_row + data_bytes + scale_bytes; + device uchar *dst_tail = (device uchar *)(dst_row + n_nope); + for (uint i = 0; i < (uint)n_rot * sizeof(float); i++) { + dst_tail[i] = rope_slot[i]; + } + } +} diff --git a/speed-bench/turbo3/README.md b/speed-bench/turbo3/README.md new file mode 100644 index 00000000..600195e0 --- /dev/null +++ b/speed-bench/turbo3/README.md @@ -0,0 +1,100 @@ +# turbo3 KV cache A/B bench + +A/B sweeps on the GX10 (ASUS Ascent, GB10 Blackwell chip, 128 GB unified memory) with the +IQ2XXS DeepSeek-V4-Flash checkpoint at +`/home/pidtom/models/ds4-model/DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf`. + +CSV files: + * `gb10_fp8.csv` fp8 baseline + * `gb10_turbo3.csv` turbo3 float-simulation cache + * `gb10_turbo3_packed.csv` turbo3 packed-byte cache + +Reproduce one cell: + +``` +./ds4-bench [--kv-cache turbo3] -m ds4flash.gguf \ + --prompt-file speed-bench/promessi_sposi.txt \ + --ctx-start 2048 --ctx-max 16384 --step-incr 6144 --gen-tokens 64 \ + --csv /tmp/bench_$DTYPE.csv +``` + +ds4-bench prints the side-by-side fp8 vs turbo3 KV footprint at the chosen +ctx at startup: + +``` +ds4-bench: KV footprint @ ctx=16389: + fp8 raw=365.50 MiB compressed=215.23 MiB total=580.73 MiB + turbo3 raw=76.92 MiB compressed=215.23 MiB total=292.15 MiB <-- active + raw shrink: 4.75x (turbo3 saves 288.58 MiB on the SWA ring) +``` + +## Throughput sweep + +| ctx | fp8 prefill | t3-floatsim prefill | t3-packed prefill | fp8 gen | t3-floatsim gen | t3-packed gen | +|-----|------------:|--------------------:|------------------:|--------:|----------------:|--------------:| +| 2K | 399.48 | 399.04 (-0.1%) | 388.07 (-2.8%) | 13.71 | 13.62 (-0.7%) | 11.92 (-13.1%) | +| 8K | 398.73 | 396.36 (-0.6%) | 398.33 (-0.1%) | 13.59 | 13.49 (-0.7%) | 11.83 (-12.9%) | +| 14K | 383.73 | 381.41 (-0.6%) | 384.30 (+0.1%) | 13.41 | 13.33 (-0.6%) | 11.67 (-13.0%) | +| 16K | 373.74 | 372.68 (-0.3%) | 374.92 (+0.3%) | 13.45 | 13.34 (-0.8%) | 11.68 (-13.1%) | + +Reading the table: + +- **Prefill** is unchanged across all three - within 3% of fp8 baseline. +- **Gen_tps regresses ~13% on the packed-byte path** vs fp8 baseline. The + per-attention-call dequant-to-scratch kernel launch is the cost driver + (~0.25 ms per decode-layer in the linear-attention pass). The inline- + dequant attention kernels eliminate the scratch pass and close the gap. + +## Footprint shrink (the real payoff) + +| ctx | fp8 SWA raw | turbo3 packed | shrink | absolute save | +|-----|------------:|--------------:|-------:|--------------:| +| 2K | head_dim*4 * raw_cap * 43 | head_dim*3.36 * raw_cap * 43 | 4.75x | 22 MiB at raw_cap=4096 | +| 16K | 365.50 MiB | 76.92 MiB | 4.75x | 288.58 MiB | + +The 4.75x ratio is constant across ctx (per-row stride is independent of ctx; +raw_cap grows linearly). The absolute MiB save grows linearly with ctx. + +## Quality + +`./ds4_test --logprob-vectors` (default fp8): **PASSES** -- bit-identical +to main on all 4 live vectors. + +`DS4_TEST_KV_DTYPE=turbo3 ./ds4_test --logprob-vectors`: **FAILS** on +short_code_completion step 1 with a single argmax mismatch. Expected: +the test asserts strict argmax equality at every position vs the official +continuation, and turbo3's quantisation noise shuffles the top-1 token in +~7-17% of positions while keeping the top-5 set intact (>99.6% top-5 +agreement vs fp8 baseline -- see PR description for the KLD numbers). +This is a distribution-drift trade, not a bug. Run +`ds4-bench --quality-baseline ...` for the KLD-aware comparator that +captures the actual quality envelope. + +Smoke generation: + * `"The capital of France is"` -> `Paris.` (byte-identical fp8 / turbo3-floatsim / turbo3-packed on this prompt) + * `"Write the Python code to compute the factorial of n recursively."` + -> 32-token identical Python function on this prompt. + +## Why gen_tps regressed on the dequant-to-scratch path + +The decompress-to-scratch architecture pays one dequant kernel launch per +attention call per layer. At decode T=1, raw_cap=128, 43 layers, the +dequant pass does ~5500 group dequants per token (43 layers * 128 rows +* 1 launch). Each dequant is cheap but launch overhead adds up. + +The inline-dequant attention kernels move the dequant INSIDE each +attention kernel's V-load loop, eliminating the separate scratch pass and +capturing the V-load bandwidth shrink (4.75x less memory traffic on the +attention K/V read). + +## What did NOT regress + + * **Prefill_tps within 3% of fp8** -- the dequant kernel scales O(ctx) + while prefill compute is O(ctx^2), so prefill is dominated by the + attention matmuls and the dequant pass is invisible. + * **Quality is bit-identical to the float-sim path** -- the + pack(unpack(x)) round trip is functionally lossless modulo FP8 group + scale precision, and the matched-norm scale absorbs the precision + loss. + * **Disk session payload** for turbo3 sessions stores 4.75x fewer SWA-ring + bytes via the per-dtype packed byte stride on the raw cache. diff --git a/speed-bench/turbo3/gb10_fp8.csv b/speed-bench/turbo3/gb10_fp8.csv new file mode 100644 index 00000000..3196ac5c --- /dev/null +++ b/speed-bench/turbo3/gb10_fp8.csv @@ -0,0 +1,5 @@ +ctx_tokens,prefill_tokens,prefill_tps,gen_tokens,gen_tps,kvcache_bytes +2048,2048,399.48,64,13.71,52184460 +8192,6144,398.73,64,13.59,136750476 +14336,6144,383.73,64,13.41,221316492 +16384,2048,373.74,64,13.45,249505164 diff --git a/speed-bench/turbo3/gb10_turbo3.csv b/speed-bench/turbo3/gb10_turbo3.csv new file mode 100644 index 00000000..9275321c --- /dev/null +++ b/speed-bench/turbo3/gb10_turbo3.csv @@ -0,0 +1,5 @@ +ctx_tokens,prefill_tokens,prefill_tps,gen_tokens,gen_tps,kvcache_bytes +2048,2048,399.04,64,13.62,52184460 +8192,6144,396.36,64,13.49,136750476 +14336,6144,381.41,64,13.33,221316492 +16384,2048,372.68,64,13.34,249505164 diff --git a/speed-bench/turbo3/gb10_turbo3_packed.csv b/speed-bench/turbo3/gb10_turbo3_packed.csv new file mode 100644 index 00000000..8d2e6708 --- /dev/null +++ b/speed-bench/turbo3/gb10_turbo3_packed.csv @@ -0,0 +1,5 @@ +ctx_tokens,prefill_tokens,prefill_tps,gen_tokens,gen_tps,kvcache_bytes +2048,2048,388.07,64,11.92,52184460 +8192,6144,398.33,64,11.83,136750476 +14336,6144,384.30,64,11.67,221316492 +16384,2048,374.92,64,11.68,249505164 diff --git a/tests/cuda_long_context_smoke.c b/tests/cuda_long_context_smoke.c index c9a8049d..32bcb200 100644 --- a/tests/cuda_long_context_smoke.c +++ b/tests/cuda_long_context_smoke.c @@ -118,6 +118,7 @@ static int check_decode_attention_overflow_path(void) { n_raw, 0, comp, + 0, n_comp, NULL, 0, diff --git a/tests/ds4_test.c b/tests/ds4_test.c index 312c356b..f3f9e950 100644 --- a/tests/ds4_test.c +++ b/tests/ds4_test.c @@ -44,6 +44,17 @@ static ds4_engine *test_open_engine(bool quality) { #endif .quality = quality, }; + /* DS4_TEST_KV_DTYPE lets a regression run target one specific KV dtype + * without rebuilding. Unknown values pass through (fp8 default) so a typo + * surfaces as a single fprintf rather than a failing assert. */ + const char *kv_env = getenv("DS4_TEST_KV_DTYPE"); + if (kv_env && kv_env[0]) { + if (!ds4_kv_dtype_from_name(kv_env, &opt.kv_dtype)) { + fprintf(stderr, + "ds4_test: ignoring DS4_TEST_KV_DTYPE='%s' (expected fp8 or turbo3)\n", + kv_env); + } + } TEST_ASSERT(ds4_engine_open(&engine, &opt) == 0); return engine; }