From 73d03160096aab2576461dbc15f9b7b570df5228 Mon Sep 17 00:00:00 2001
From: Shawn Graves
Date: Mon, 16 Feb 2026 18:39:47 -0500
Subject: [PATCH] vulkan: add support for GGML_OP_GATED_DELTA_NET

Adds a compute shader and the ggml-vulkan.cpp wiring for the gated delta
net recurrence op. Supports multi-token sequences and stride-based GQA
broadcast; the sigmoid(beta) and exp(g) transforms are applied inline in
the shader.

Tested 7/7 on an NVIDIA RTX 3080 Ti via test-backend-ops:
- head_size 32/64/128, single/multi token, single/multi seq, GQA, permuted
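
For reference, the recurrence the shader implements, in outline: one
workgroup handles one (sequence, head) pair, tokens are walked serially,
and thread c owns column c of the S_v x S_v state S (naming below is
informal):

    g    = exp(g_t)              // log-decay -> decay factor
    beta = sigmoid(beta_t)       // mixing coefficient
    S    = g * S                 // decay the whole state
    kv   = S^T k_t               // kv[c] = sum_j S[j][c] * k_t[j]
    dv   = beta * (v_t - kv)     // delta-rule correction
    S    = S + k_t dv^T          // rank-1 update
    o_t  = S^T q_t               // per-column output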
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 108 +++++++++++++++
 .../vulkan-shaders/gated_delta_net.comp       | 126 ++++++++++++++++++
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |   2 +
 3 files changed, 236 insertions(+)
 create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 72097ffd0ff..df293f17c9d 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -807,6 +807,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
+    vk_pipeline pipeline_gated_delta_net_f32;
     vk_pipeline pipeline_ssm_scan_f32_d128;
     vk_pipeline pipeline_ssm_scan_f32_d256;
    vk_pipeline pipeline_ssm_conv_f32;
@@ -1431,6 +1432,19 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t C;
     uint32_t H;
 };
+struct vk_op_gated_delta_net_push_constants {
+    uint32_t H;
+    uint32_t S_v;
+    uint32_t n_tokens;
+    uint32_t n_seqs;
+    uint32_t s_off;
+    uint32_t sq1, sq2, sq3;
+    uint32_t sk1, sk2, sk3;
+    uint32_t sv1, sv2, sv3;
+    uint32_t sg1, sg2;
+    uint32_t rq1, rq3, rk1, rk3;
+};
+
 struct vk_op_ssm_scan_push_constants {
     uint32_t nb02, nb03, nb12, nb13;
     uint32_t nb21, nb22, nb31;
@@ -4408,6 +4422,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32, "gated_delta_net_f32", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {}, 1);
+
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -9245,6 +9261,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv7_f32;
         }
         return nullptr;
+    case GGML_OP_GATED_DELTA_NET:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_gated_delta_net_f32;
+        }
+        return nullptr;
     case GGML_OP_SSM_SCAN:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             const uint32_t d_state = src0->ne[0];
@@ -10057,6 +10078,76 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
     );
 }
 
+static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    const ggml_tensor * src_q = dst->src[0];
+    const ggml_tensor * src_k = dst->src[1];
+    const ggml_tensor * src_v = dst->src[2];
+    const ggml_tensor * src_g = dst->src[3];
+    const ggml_tensor * src_state = dst->src[5];
+
+    const uint32_t S_v = (uint32_t)src_v->ne[0];
+    const uint32_t H = (uint32_t)src_v->ne[1];
+    const uint32_t n_tokens = (uint32_t)src_v->ne[2];
+    const uint32_t n_seqs = (uint32_t)src_v->ne[3];
+
+    const uint32_t s_off = S_v * H * n_tokens * n_seqs;
+
+    for (int i = 0; i < 6; i++) {
+        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+    }
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[6] = {};
+    for (int i = 0; i < 6; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
+    }
+
+    // Strides in elements (bytes / sizeof(float))
+    const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
+    const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
+    const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
+    const uint32_t sk1 = (uint32_t)(src_k->nb[1] / sizeof(float));
+    const uint32_t sk2 = (uint32_t)(src_k->nb[2] / sizeof(float));
+    const uint32_t sk3 = (uint32_t)(src_k->nb[3] / sizeof(float));
+    const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
+    const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
+    const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
+    // g is contiguous: ne[0]=H, ne[1]=n_tokens
+    const uint32_t sg1 = (uint32_t)src_g->ne[0];
+    const uint32_t sg2 = (uint32_t)(src_g->ne[0] * src_g->ne[1]);
+
+    // GQA ratios
+    const uint32_t rq1 = (uint32_t)(src_v->ne[1] / src_q->ne[1]);
+    const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
+    const uint32_t rk1 = (uint32_t)(src_v->ne[1] / src_k->ne[1]);
+    const uint32_t rk3 = (uint32_t)(src_v->ne[3] / src_k->ne[3]);
+
+    const vk_op_gated_delta_net_push_constants pc = {
+        H, S_v, n_tokens, n_seqs, s_off,
+        sq1, sq2, sq3,
+        sk1, sk2, sk3,
+        sv1, sv2, sv3,
+        sg1, sg2,
+        rq1, rq3, rk1, rk3
+    };
+
+    std::array<uint32_t, 3> elements = {
+        H * n_seqs,
+        1,
+        1
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+        pc, elements);
+}
+
 static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
     const size_t seq_length = dst->src[0]->ne[2];
     const size_t n_embed = dst->ne[0];
@@ -12796,6 +12887,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
 
+    case GGML_OP_GATED_DELTA_NET:
+        ggml_vk_gated_delta_net(ctx, compute_ctx, node);
+
+        break;
+
     case GGML_OP_SSM_SCAN:
         ggml_vk_ssm_scan(ctx, compute_ctx, node);
 
@@ -15034,6 +15130,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true; // all inputs are contiguous, see ggml.c
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                for (int i = 0; i < 6; i++) {
+                    if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
+                        return false;
+                    }
+                }
+                return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+            }
         case GGML_OP_SSM_SCAN:
             {
                 for (int i = 0; i < 6; i++) {
@@ -15855,6 +15960,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
                                     src_clone[2]);
     } else if (tensor->op == GGML_OP_ADD_ID) {
         tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
+    } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
+        tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
+                src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
     } else if (tensor->op == GGML_OP_SSM_SCAN) {
         tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
                 src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
new file mode 100644
index 00000000000..73a853e96fc
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
@@ -0,0 +1,126 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#ifndef BLOCK_SIZE
+#define BLOCK_SIZE 128
+#endif
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint H;
+    uint S_v;
+    uint n_tokens;
+    uint n_seqs;
+    uint s_off;
+    // q strides (in elements)
+    uint sq1, sq2, sq3;
+    // k strides (in elements)
+    uint sk1, sk2, sk3;
+    // v strides (in elements)
+    uint sv1, sv2, sv3;
+    // g strides: sg1 = ne[0] (H), sg2 = ne[0]*ne[1] (H*n_tokens)
+    uint sg1, sg2;
+    // GQA ratios
+    uint rq1, rq3, rk1, rk3;
+};
+
+layout(binding = 0) readonly buffer QBuf { A_TYPE q_in[]; };
+layout(binding = 1) readonly buffer KBuf { A_TYPE k_in[]; };
+layout(binding = 2) readonly buffer VBuf { A_TYPE v_in[]; };
+layout(binding = 3) readonly buffer GBuf { A_TYPE g_in[]; };
+layout(binding = 4) readonly buffer BetaBuf { A_TYPE beta_in[]; };
+layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };
+
+shared A_TYPE s_k[BLOCK_SIZE];
+shared A_TYPE s_q[BLOCK_SIZE];
+shared A_TYPE s_v[BLOCK_SIZE];
+
+void main() {
+    const uint head_id = gl_WorkGroupID.x % H;
+    const uint seq_id = gl_WorkGroupID.x / H;
+    const uint tid = gl_LocalInvocationID.x;
+
+    if (seq_id >= n_seqs || head_id >= H) {
+        return;
+    }
+
+    const bool is_active = (tid < S_v);
+
+    const uint state_size = S_v * S_v;
+    const uint state_base = (seq_id * H + head_id) * state_size;
+
+    // GQA: map v-head to q/k-head
+    const uint iq1 = head_id / rq1;
+    const uint iq3 = seq_id / rq3;
+    const uint ik1 = head_id / rk1;
+    const uint ik3 = seq_id / rk3;
+
+    // Load initial state (column-major in ggml)
+    A_TYPE state[BLOCK_SIZE];
+    if (is_active) {
+        for (uint j = 0; j < S_v; j++) {
+            state[j] = state_in[state_base + j * S_v + tid];
+        }
+    }
+
+    for (uint t = 0; t < n_tokens; t++) {
+        // All threads participate in shared memory load + barrier
+        if (is_active) {
+            const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
+            const uint k_off = ik3 * sk3 + t * sk2 + ik1 * sk1;
+            const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
+
+            s_q[tid] = q_in[q_off + tid];
+            s_k[tid] = k_in[k_off + tid];
+            s_v[tid] = v_in[v_off + tid];
+        }
+        barrier();
+
+        if (is_active) {
+            // g and beta: contiguous, layout [H, n_tokens, n_seqs]
+            const uint gb_off = seq_id * sg2 + t * sg1 + head_id;
+            const A_TYPE g_val = exp(g_in[gb_off]);
+            const A_TYPE beta_val = 1.0 / (1.0 + exp(-beta_in[gb_off]));
+
+            // Decay state
+            for (uint j = 0; j < S_v; j++) {
+                state[j] *= g_val;
+            }
+
+            // kv = dot(state_col, k)
+            A_TYPE kv = 0.0;
+            for (uint j = 0; j < S_v; j++) {
+                kv += state[j] * s_k[j];
+            }
+
+            // delta = (v[tid] - kv) * beta
+            A_TYPE delta = (s_v[tid] - kv) * beta_val;
+
+            // Rank-1 update: state[j] += k[j] * delta
+            for (uint j = 0; j < S_v; j++) {
+                state[j] += s_k[j] * delta;
+            }
+
+            // Output: out[tid] = dot(state_col, q)
+            A_TYPE out_val = 0.0;
+            for (uint j = 0; j < S_v; j++) {
+                out_val += state[j] * s_q[j];
+            }
+
+            // Attention output layout: [n_seqs, n_tokens, H, S_v]
+            dst[seq_id * n_tokens * H * S_v + t * H * S_v + head_id * S_v + tid] = out_val;
+        }
+
+        barrier();
+    }
+
+    // Write final state
+    if (is_active) {
+        for (uint j = 0; j < S_v; j++) {
+            dst[s_off + state_base + j * S_v + tid] = state[j];
+        }
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 42ebc21e2a6..2155b5c60db 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -971,6 +971,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));