108 changes: 108 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -807,6 +807,7 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_gated_delta_net_f32;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
@@ -1431,6 +1432,19 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};

struct vk_op_gated_delta_net_push_constants {
uint32_t H;
uint32_t S_v;
uint32_t n_tokens;
uint32_t n_seqs;
uint32_t s_off;
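// q/k/v strides in elements; sg1/sg2: g strides; rq*/rk*: GQA broadcast ratios (head and sequence)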
uint32_t sq1, sq2, sq3;
uint32_t sk1, sk2, sk3;
uint32_t sv1, sv2, sv3;
uint32_t sg1, sg2;
uint32_t rq1, rq3, rk1, rk3;
};

struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
@@ -4408,6 +4422,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

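// 7 bindings: q, k, v, g, beta, initial state, dst (per-token outputs followed by the final state)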
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32, "gated_delta_net_f32", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {}, 1);

if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -9245,6 +9261,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_GATED_DELTA_NET:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_gated_delta_net_f32;
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
@@ -10057,6 +10078,76 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}

static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src_q = dst->src[0];
const ggml_tensor * src_k = dst->src[1];
const ggml_tensor * src_v = dst->src[2];
const ggml_tensor * src_g = dst->src[3];
const ggml_tensor * src_state = dst->src[5];

const uint32_t S_v = (uint32_t)src_v->ne[0];
const uint32_t H = (uint32_t)src_v->ne[1];
const uint32_t n_tokens = (uint32_t)src_v->ne[2];
const uint32_t n_seqs = (uint32_t)src_v->ne[3];

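// dst packs the per-token attention outputs first, then the final per-sequence states;
// s_off is the element offset where the state section begins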
const uint32_t s_off = S_v * H * n_tokens * n_seqs;

for (int i = 0; i < 6; i++) {
GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
}
GGML_ASSERT(dst->buffer != nullptr);

vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[6] = {};
for (int i = 0; i < 6; i++) {
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
}

// Strides in elements (byte strides nb[i] / sizeof(float))
const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
const uint32_t sk1 = (uint32_t)(src_k->nb[1] / sizeof(float));
const uint32_t sk2 = (uint32_t)(src_k->nb[2] / sizeof(float));
const uint32_t sk3 = (uint32_t)(src_k->nb[3] / sizeof(float));
const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
// g is contiguous with layout [H, n_tokens, n_seqs]: sg1 = H, sg2 = H * n_tokens
const uint32_t sg1 = (uint32_t)src_g->ne[0];
const uint32_t sg2 = (uint32_t)(src_g->ne[0] * src_g->ne[1]);

// GQA ratios
const uint32_t rq1 = (uint32_t)(src_v->ne[1] / src_q->ne[1]);
const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
const uint32_t rk1 = (uint32_t)(src_v->ne[1] / src_k->ne[1]);
const uint32_t rk3 = (uint32_t)(src_v->ne[3] / src_k->ne[3]);

const vk_op_gated_delta_net_push_constants pc = {
H, S_v, n_tokens, n_seqs, s_off,
sq1, sq2, sq3,
sk1, sk2, sk3,
sv1, sv2, sv3,
sg1, sg2,
rq1, rq3, rk1, rk3
};

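// one workgroup per (head, sequence) pair; the S_v active threads each own one column of the state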
std::array<uint32_t, 3> elements = {
H * n_seqs,
1,
1
};

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, elements);
}

static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
@@ -12796,6 +12887,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

break;

case GGML_OP_GATED_DELTA_NET:
ggml_vk_gated_delta_net(ctx, compute_ctx, node);

break;

case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node);

@@ -15034,6 +15130,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true; // all inputs are contiguous, see ggml.c
case GGML_OP_GATED_DELTA_NET:
{
for (int i = 0; i < 6; i++) {
if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
return false;
}
}
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
}
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
@@ -15855,6 +15960,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_SSM_SCAN) {
tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
126 changes: 126 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
@@ -0,0 +1,126 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#ifndef BLOCK_SIZE
#define BLOCK_SIZE 128
#endif

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
uint H;
uint S_v;
uint n_tokens;
uint n_seqs;
uint s_off;
// q strides (in elements)
uint sq1, sq2, sq3;
// k strides (in elements)
uint sk1, sk2, sk3;
// v strides (in elements)
uint sv1, sv2, sv3;
// g strides: sg1 = ne[0] (H), sg2 = ne[0]*ne[1] (H*n_tokens)
uint sg1, sg2;
// GQA ratios
uint rq1, rq3, rk1, rk3;
};

layout(binding = 0) readonly buffer QBuf { A_TYPE q_in[]; };
layout(binding = 1) readonly buffer KBuf { A_TYPE k_in[]; };
layout(binding = 2) readonly buffer VBuf { A_TYPE v_in[]; };
layout(binding = 3) readonly buffer GBuf { A_TYPE g_in[]; };
layout(binding = 4) readonly buffer BetaBuf { A_TYPE beta_in[]; };
layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[];};
layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };

shared A_TYPE s_k[BLOCK_SIZE];
shared A_TYPE s_q[BLOCK_SIZE];
shared A_TYPE s_v[BLOCK_SIZE];

void main() {
const uint head_id = gl_WorkGroupID.x % H;
const uint seq_id = gl_WorkGroupID.x / H;
const uint tid = gl_LocalInvocationID.x;

if (seq_id >= n_seqs || head_id >= H) {
return;
}

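// threads with tid >= S_v carry no state column but must not return early,
// or the barrier() calls below would no longer be workgroup-uniform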
const bool is_active = (tid < S_v);

const uint state_size = S_v * S_v;
const uint state_base = (seq_id * H + head_id) * state_size;

// GQA: map v-head to q/k-head
const uint iq1 = head_id / rq1;
const uint iq3 = seq_id / rq3;
const uint ik1 = head_id / rk1;
const uint ik3 = seq_id / rk3;

// Load initial state (column-major in ggml)
A_TYPE state[BLOCK_SIZE];
if (is_active) {
for (uint j = 0; j < S_v; j++) {
state[j] = state_in[state_base + j * S_v + tid];
}
}

for (uint t = 0; t < n_tokens; t++) {
// the loads are guarded by is_active, but every thread must reach the barrier
if (is_active) {
const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
const uint k_off = ik3 * sk3 + t * sk2 + ik1 * sk1;
const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;

s_q[tid] = q_in[q_off + tid];
s_k[tid] = k_in[k_off + tid];
s_v[tid] = v_in[v_off + tid];
}
barrier();

if (is_active) {
// g and beta: contiguous, layout [H, n_tokens, n_seqs]
const uint gb_off = seq_id * sg2 + t * sg1 + head_id;
const A_TYPE g_val = exp(g_in[gb_off]);
const A_TYPE beta_val = 1.0 / (1.0 + exp(-beta_in[gb_off]));

// Decay state
for (uint j = 0; j < S_v; j++) {
state[j] *= g_val;
}

// kv = dot(state_col, k)
A_TYPE kv = 0.0;
for (uint j = 0; j < S_v; j++) {
kv += state[j] * s_k[j];
}

// delta = (v[tid] - kv) * beta
A_TYPE delta = (s_v[tid] - kv) * beta_val;

// Rank-1 update: state[j] += k[j] * delta
for (uint j = 0; j < S_v; j++) {
state[j] += s_k[j] * delta;
}

// Output: out[tid] = dot(state_col, q)
A_TYPE out_val = 0.0;
for (uint j = 0; j < S_v; j++) {
out_val += state[j] * s_q[j];
}

// Attention output layout: [n_seqs, n_tokens, H, S_v]
dst[seq_id * n_tokens * H * S_v + t * H * S_v + head_id * S_v + tid] = out_val;
}

barrier();
}

// Write final state
if (is_active) {
for (uint j = 0; j < S_v; j++) {
dst[s_off + state_base + j * S_v + tid] = state[j];
}
}
}
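
In matrix form, the per-token update the loop above computes for each (sequence, head) pair is the gated delta rule. With S the S_v x S_v head state (thread tid owning column tid of S):

    \tilde{S}_t = e^{g_t}\, S_{t-1}, \qquad
    \delta_t    = \sigma(\beta_t)\,\bigl(v_t - \tilde{S}_t^{\top} k_t\bigr), \qquad
    S_t         = \tilde{S}_t + k_t\, \delta_t^{\top}, \qquad
    o_t         = S_t^{\top} q_t

A minimal scalar sketch of the same recurrence for one (sequence, head) pair, handy as a CPU cross-check; the function name and the flat, contiguous layouts are illustrative assumptions, not part of the patch:

#include <cmath>
#include <cstdint>
#include <vector>

// Scalar reference mirroring the shader's per-token loop for one (seq, head).
// q/k/v are [n_tokens][S_v], g/beta are [n_tokens], out is [n_tokens][S_v],
// and S is the S_v x S_v state with element (j, i) at S[j * S_v + i];
// index i plays the role of the shader's tid.
static void gated_delta_net_ref_head(uint32_t S_v, uint32_t n_tokens,
                                     const float * q, const float * k,
                                     const float * v, const float * g,
                                     const float * beta,
                                     std::vector<float> & S, float * out) {
    for (uint32_t t = 0; t < n_tokens; t++) {
        const float g_val = std::exp(g[t]);                      // per-head decay
        const float b_val = 1.0f / (1.0f + std::exp(-beta[t])); // sigmoid(beta)
        for (uint32_t i = 0; i < S_v; i++) {
            // decay column i, then kv = (S^T k)_i
            float kv = 0.0f;
            for (uint32_t j = 0; j < S_v; j++) {
                S[j * S_v + i] *= g_val;
                kv += S[j * S_v + i] * k[t * S_v + j];
            }
            const float delta = (v[t * S_v + i] - kv) * b_val;
            // rank-1 update S += k * delta^T, then out_i = (S^T q)_i
            float o = 0.0f;
            for (uint32_t j = 0; j < S_v; j++) {
                S[j * S_v + i] += k[t * S_v + j] * delta;
                o += S[j * S_v + i] * q[t * S_v + j];
            }
            out[t * S_v + i] = o;
        }
    }
}

On return, S holds the final state the shader writes at dst[s_off + ...], and out matches the attention-output section of dst.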
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -971,6 +971,8 @@ void process_shaders() {

string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
