
test-backend-ops : add moe test
slaren committed Dec 10, 2023
1 parent e640cbe · commit cefebb3
Showing 1 changed file with 116 additions and 12 deletions.
tests/test-backend-ops.cpp (+116 −12)
@@ -51,7 +51,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         t.join();
     }
 
-    if (tensor->type == GGML_TYPE_F32) {
+    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
         ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
         GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
@@ -233,14 +233,18 @@ static bool ggml_is_view_op(enum ggml_op op) {
 struct test_case {
     virtual ~test_case() {}
 
+    virtual std::string op_desc(ggml_tensor * t) {
+        return ggml_op_desc(t);
+    }
+
     virtual std::string vars() {
         return "";
     }
 
     virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
 
     virtual double max_nmse_err() {
-        return 1e-6;
+        return 1e-7;
     }
 
     virtual void initialize_tensors(ggml_context * ctx) {
@@ -270,13 +274,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -317,29 +321,40 @@ struct test_case {
             for (size_t i = 0; i < f1.size(); i++) {
                 // check for nans
                 if (std::isnan(f1[i]) || std::isnan(f2[i])) {
-                    printf("NaN at index %zu ", i);
+                    printf("[%s] NaN at index %zu ", ggml_op_desc(t1), i);
                     ud->ok = false;
                     return true;
                 }
                 // check for infs: both must be inf of the same sign, or both must be finite
                 if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                     if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                         if (std::signbit(f1[i]) != std::signbit(f2[i])) {
-                            printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
+                            printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                             ud->ok = false;
                             return true;
                         }
                     } else {
-                        printf("inf mismatch: %f %f ", f1[i], f2[i]);
+                        printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
                         ud->ok = false;
                         return true;
                     }
                 }
             }
 
+            //if (t1->op == GGML_OP_SOFT_MAX) {
+            //    printf("[%s] ", ggml_op_desc(t1));
+            //    for (int i = 0; i < f1.size(); i++) {
+            //        printf("(%x, %x) ", *(uint32_t*)&f1[i], *(uint32_t*)&f2[i]);
+            //    }
+            //    printf("\n");
+            //}
             double err = nmse(f1.data(), f2.data(), f1.size());
             if (err > ud->max_err) {
-                printf("NMSE = %f ", err);
+                printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+                //for (int i = 0; i < f1.size(); i++) {
+                //    printf("(%f, %f) ", f1[i], f2[i]);
+                //}
+                //printf("\n");
                 ud->ok = false;
             }
             return true;
@@ -374,13 +389,13 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
-            //printf("  %s: skipping\n", ggml_op_desc(out));
+        if (op_name != nullptr && op_desc(out) != op_name) {
+            //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
         }
 
-        int len = printf("  %s(%s): ", ggml_op_desc(out), vars().c_str());
+        int len = printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
         // check if backends support op
@@ -1122,6 +1137,91 @@ struct test_sum_rows : public test_case {
     }
 };
 
+struct test_moe : public test_case {
+    const int n_experts = 8;
+    const int n_experts_per_tok = 2;
+    const int n_tokens = 1;
+    const int n_embd = 4096;
+    const int n_ff = 14336;
+
+    std::string op_desc(ggml_tensor * t) override {
+        return "MOE";
+        GGML_UNUSED(t);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
+    }
+
+    test_moe() {
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
+
+        std::vector<ggml_tensor *> ffn_up_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
+        std::vector<ggml_tensor *> ffn_down_exp(n_experts);
+
+        for (int i = 0; i < n_experts; ++i) {
+            ffn_up_exp[i]   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+            ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+        }
+
+        ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+
+        ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_tokens, num_experts]
+        ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_tokens, num_experts]
+
+        // select experts
+        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+
+        ggml_tensor * weights = ggml_get_rows(ctx,
+                ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
+        printf("get rows args %ld %ld %ld %ld, %ld %ld %ld %ld\n",
+            weights->src[0]->ne[0], weights->src[0]->ne[1], weights->src[0]->ne[2], weights->src[0]->ne[3],
+            weights->src[1]->ne[0], weights->src[1]->ne[1], weights->src[1]->ne[2], weights->src[1]->ne[3]);
+
+
+        weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
+
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
+
+        weights = ggml_div(ctx, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+
+        // compute expert outputs
+        ggml_tensor * moe_out = nullptr;
+
+        for (int i = 0; i < n_experts_per_tok; ++i) {
+            ggml_tensor * cur_expert;
+
+            ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
+
+            ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
+
+            cur_gate = ggml_silu(ctx, cur_gate);
+
+            cur_expert = ggml_mul(ctx, cur_up, cur_gate); // [n_tokens, n_ff]
+
+            cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+
+            cur_expert = ggml_mul(ctx, cur_expert,
+                    ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+
+            if (i == 0) {
+                moe_out = cur_expert;
+            } else {
+                moe_out = ggml_add(ctx, moe_out, cur_expert);
+            }
+        }
+
+        cur = moe_out;
+
+        return cur;
+    }
+};
+
 enum test_mode {
     MODE_TEST,
     MODE_PERF,
@@ -1140,11 +1240,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q6_K
     };
 
+    test_cases.emplace_back(new test_moe());
+
     // unary ops
     for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
        test_cases.emplace_back(new test_unary((ggml_unary_op) op));
     }
 
+    test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
             for (bool v : {false, true}) {
@@ -1265,6 +1368,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_concat());
 
     for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
     }
 
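A note for readers of the diff: the new test_moe case builds the standard top-k gated mixture-of-experts routing in ggml graph form. A gating matmul yields one logit per expert, ggml_soft_max converts the logits into probabilities, ggml_top_k selects n_experts_per_tok expert indices, the selected weights are renormalized to sum to 1 (the ggml_div by ggml_sum_rows), and ggml_mul_mat_id applies only the selected experts' FFNs before the weighted outputs are summed. Below is a minimal standalone sketch of that routing arithmetic in plain C++ — illustrative names and values, not the ggml implementation:

    // Hypothetical standalone sketch of the routing math exercised by test_moe:
    // softmax over expert logits, top-k selection, renormalization of the
    // selected weights, and a weighted blend of expert outputs.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        const int n_experts = 8;
        const int n_used    = 2; // n_experts_per_tok

        // per-token gating logits (in the test these come from the ffn_gate_inp matmul)
        std::vector<float> logits = {0.1f, 1.2f, -0.3f, 0.7f, 2.0f, -1.0f, 0.4f, 0.0f};

        // softmax -> expert probabilities (ggml_soft_max)
        float mx = *std::max_element(logits.begin(), logits.end());
        std::vector<float> probs(n_experts);
        float sum = 0.0f;
        for (int i = 0; i < n_experts; i++) {
            probs[i] = std::exp(logits[i] - mx);
            sum += probs[i];
        }
        for (float & p : probs) {
            p /= sum;
        }

        // indices of the n_used largest probabilities (ggml_top_k)
        std::vector<int> idx(n_experts);
        std::iota(idx.begin(), idx.end(), 0);
        std::partial_sort(idx.begin(), idx.begin() + n_used, idx.end(),
            [&probs](int a, int b) { return probs[a] > probs[b]; });

        // renormalize the selected weights so they sum to 1
        // (the ggml_div by ggml_sum_rows in build_graph)
        float wsum = 0.0f;
        for (int i = 0; i < n_used; i++) {
            wsum += probs[idx[i]];
        }
        for (int i = 0; i < n_used; i++) {
            // the final MoE output is sum_i w_i * ffn_down_i(silu(gate_i(x)) * up_i(x))
            printf("expert %d, weight %.3f\n", idx[i], probs[idx[i]] / wsum);
        }
        return 0;
    }

With the values above, experts 4 and 1 are selected and blended with weights that sum to 1 — the property the renormalization step in build_graph guarantees.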

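On the tolerance change in test_case: backend outputs are compared with normalized mean squared error, and this commit tightens the default max_nmse_err from 1e-6 to 1e-7 while prefixing every mismatch message with the op name. For reference, a sketch of NMSE under the conventional definition — the exact normalization used by this file's nmse() helper is an assumption here:

    #include <cstddef>

    // Normalized mean squared error between a reference buffer a and a test
    // buffer b. Assumed definition: sum((a[i]-b[i])^2) / sum(a[i]^2); the real
    // nmse() in test-backend-ops.cpp may normalize slightly differently.
    static double nmse_sketch(const float * a, const float * b, size_t n) {
        double num = 0.0;
        double den = 0.0;
        for (size_t i = 0; i < n; i++) {
            const double d = (double) a[i] - (double) b[i];
            num += d * d;
            den += (double) a[i] * (double) a[i];
        }
        return den > 0.0 ? num / den : num;
    }

Because the error is normalized by the magnitude of the reference output, a single threshold such as 1e-7 can be applied uniformly across ops whose outputs differ widely in scale.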