345 changes: 321 additions & 24 deletions common/common.cpp

Large diffs are not rendered by default.

23 changes: 17 additions & 6 deletions common/common.h
@@ -48,13 +48,18 @@ int32_t get_num_physical_cores();
// CLI argument parsing
//

struct cpu_params {
int32_t n_threads = -1;
bool cpumask[GGML_N_CORES_MAX] = {false}; // CPU affinity mask.
bool mask_valid = false; // Default: any CPU
int32_t priority = 0; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
bool strict_cpu = false; // Use strict CPU placement
bool poll = false; // Use polling (busywait) to wait for work
};

struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed

int32_t n_threads = get_math_cpu_count();
int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_threads_batch_draft = -1;
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
@@ -87,6 +92,11 @@ struct gpt_params {
ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;

struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
struct cpu_params draft_cpuparams;
struct cpu_params draft_cpuparams_batch;

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
@@ -214,8 +224,9 @@ std::string sampler_type_to_name_string(llama_sampler_type sampler_type);
// TODO: avoid tuplue, use struct
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);

struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params (const gpt_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);

struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params);
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params);
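The `cpu_params` block and `ggml_threadpool_params_from_cpu_params()` above are the glue between CLI thread settings and the threadpool calls used in `examples/main/main.cpp` further down. A minimal sketch of that wiring, using only functions introduced by this PR (the helper itself and the header names are assumptions, not part of the diff):

```cpp
#include "common.h"  // cpu_params, ggml_threadpool_params_from_cpu_params (assumed to pull in the ggml/llama headers)
#include "llama.h"

// Build a threadpool from cpu_params and attach it to a context.
// The caller keeps the returned pool and calls ggml_release_threadpool() after llama_free().
static ggml_compute_threadpool_t attach_threadpool(llama_context * ctx, const cpu_params & cp) {
    struct ggml_threadpool_params tpp = ggml_threadpool_params_from_cpu_params(cp);

    ggml_compute_threadpool_t tp = ggml_create_threadpool(&tpp);
    if (tp != nullptr) {
        llama_attach_threadpool(ctx, tp);
    }
    return tp;
}
```

A caller would set `params.cpuparams.n_threads` (and optionally `priority`, `strict_cpu`, or `poll`) before the conversion, and repeat the same steps with `cpuparams_batch` and `llama_attach_batch_threadpool()` for prompt processing, as `main.cpp` does below.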
2 changes: 1 addition & 1 deletion examples/baby-llama/baby-llama.cpp
@@ -19,7 +19,7 @@ constexpr float rms_norm_eps = 5e-6f;
#endif

static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);

if (plan.work_size > 0) {
buf.resize(plan.work_size);
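`ggml_graph_plan()` now takes a third argument of type `ggml_compute_threadpool_t`; every call site touched by this PR passes `nullptr`, which presumably keeps the previous spawn-per-call behavior. A hedged sketch of the same helper driven by an explicit pool (the helper name is hypothetical, and the final dispatch is left as a comment since this diff does not show how a cplan is executed against a pool):

```cpp
#include <cstdint>
#include <vector>
#include "ggml.h"        // ggml_cgraph, ggml_cplan, ggml_graph_plan
#include "ggml-alloc.h"  // ggml_compute_threadpool_t (typedef added by this PR)

// Like ggml_graph_compute_helper() above, but planning against an explicit threadpool.
static void graph_plan_with_threadpool(std::vector<uint8_t> & buf, ggml_cgraph * graph,
                                       int n_threads, ggml_compute_threadpool_t threadpool) {
    // nullptr here would match the call sites in this PR; passing a live pool is the assumption.
    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, threadpool);

    if (plan.work_size > 0) {
        buf.resize(plan.work_size);
        plan.work_data = buf.data();
    }

    // ... execute the plan as the existing helpers do (not shown in this diff).
}
```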
5 changes: 3 additions & 2 deletions examples/batched-bench/batched-bench.cpp
@@ -119,8 +119,9 @@ int main(int argc, char ** argv) {
ctx_params.n_ubatch = n_ubatch;
ctx_params.flash_attn = flash_attn;

ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.cpuparams.n_threads;
ctx_params.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;

// ensure enough sequences are available
ctx_params.n_seq_max = *std::max_element(n_pl.begin(), n_pl.end());
5 changes: 3 additions & 2 deletions examples/batched/batched.cpp
@@ -83,8 +83,9 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_seq_max = n_parallel;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.cpuparams.n_threads;
ctx_params.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

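batched-bench, batched, passkey, and simple all repeat the same fallback: a batch thread count of -1 inherits the generation thread count. Factored out, the convention is a one-liner (hypothetical helper, not part of the PR):

```cpp
#include <cstdint>
#include "common.h"  // cpu_params

// Resolve the "-1 means inherit n_threads" convention used when filling
// ctx_params.n_threads_batch in the examples (hypothetical helper).
static int32_t resolve_batch_threads(const cpu_params & gen, const cpu_params & batch) {
    return batch.n_threads == -1 ? gen.n_threads : batch.n_threads;
}
```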
2 changes: 1 addition & 1 deletion examples/benchmark/benchmark-matmult.cpp
@@ -21,7 +21,7 @@
#endif

static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);

if (plan.work_size > 0) {
buf.resize(plan.work_size);
2 changes: 1 addition & 1 deletion examples/export-lora/export-lora.cpp
@@ -344,7 +344,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int

ggml_gallocr_alloc_graph(alloc, gf);

struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads, nullptr);
static std::vector<uint8_t> data_work;
data_work.resize(cplan.work_size);
cplan.work_data = data_work.data();
2 changes: 1 addition & 1 deletion examples/finetune/finetune.cpp
@@ -1817,7 +1817,7 @@ int main(int argc, char ** argv) {
opt_cb_data.millis_per_iter = 0.0;

// measure required memory for work buffer
size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads, nullptr).work_size + GGML_OBJECT_SIZE;
printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));

// context for work buffer
2 changes: 1 addition & 1 deletion examples/llava/clip.cpp
@@ -1915,7 +1915,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
#endif

ggml_backend_graph_compute(ctx->backend, gf);
ggml_backend_graph_compute(ctx->backend, gf, NULL);

// the last node is the embedding tensor
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1];
4 changes: 2 additions & 2 deletions examples/llava/llava-cli.cpp
@@ -126,14 +126,14 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
if (!params->image.empty()) {
LOG_TEE("using base64 encoded image instead of command line image path\n");
}
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
if (!embed) {
LOG_TEE("%s: can't load image from prompt\n", __func__);
return NULL;
}
params->prompt = remove_image_from_prompt(prompt);
} else {
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, fname.c_str());
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
if (!embed) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
return NULL;
29 changes: 29 additions & 0 deletions examples/main/main.cpp
@@ -202,11 +202,38 @@ int main(int argc, char ** argv) {
ctx_guidance = llama_new_context_with_model(model, lparams);
}

LOG("%s: llama threadpool init = n_threads = %d\n",
__func__,
(int32_t) params.cpuparams.n_threads
);
struct ggml_threadpool_params tpp_batch =
ggml_threadpool_params_from_cpu_params(params.cpuparams_batch);
struct ggml_threadpool_params tpp =
ggml_threadpool_params_from_cpu_params(params.cpuparams);

struct ggml_compute_threadpool * threadpool_batch = ggml_create_threadpool(&tpp_batch);
if (!threadpool_batch) {
LOG_TEE("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads);
exit(1);
}
struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
if (!threadpool) {
LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
exit(1);
}

if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
return 1;
}

llama_attach_batch_threadpool(ctx, threadpool_batch);
llama_attach_threadpool(ctx, threadpool);
if (ctx_guidance) {
llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
llama_attach_threadpool(ctx_guidance, threadpool);
}

const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);
@@ -955,6 +982,8 @@ int main(int argc, char ** argv) {
llama_sampling_free(ctx_sampling);
llama_backend_free();

ggml_release_threadpool(threadpool);

#ifndef LOG_DISABLE_LOGS
LOG_TEE("Log end\n");
#endif // LOG_DISABLE_LOGS
5 changes: 3 additions & 2 deletions examples/passkey/passkey.cpp
@@ -94,8 +94,9 @@ int main(int argc, char ** argv) {
ctx_params.seed = seed;
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
ctx_params.n_batch = 512;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.cpuparams.n_threads;
ctx_params.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;

GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");

10 changes: 5 additions & 5 deletions examples/server/server.cpp
@@ -2329,7 +2329,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
printf("options:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.cpuparams.n_threads);
printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" --threads-http N number of threads in the http server pool to process requests (default: max(hardware concurrency - 1, --parallel N + 2))\n");
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
@@ -2612,7 +2612,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true;
break;
}
params.n_threads = std::stoi(argv[i]);
params.cpuparams.n_threads = std::stoi(argv[i]);
} else if (arg == "--grp-attn-n" || arg == "-gan") {
if (++i >= argc) {
invalid_param = true;
@@ -2632,7 +2632,7 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
invalid_param = true;
break;
}
params.n_threads_batch = std::stoi(argv[i]);
params.cpuparams_batch.n_threads = std::stoi(argv[i]);
} else if (arg == "--threads-http") {
if (++i >= argc) {
invalid_param = true;
@@ -2943,8 +2943,8 @@ int main(int argc, char ** argv) {
});

LOG_INFO("system info", {
{"n_threads", params.n_threads},
{"n_threads_batch", params.n_threads_batch},
{"n_threads", params.cpuparams.n_threads},
{"n_threads_batch", params.cpuparams_batch.n_threads},
{"total_threads", std::thread::hardware_concurrency()},
{"system_info", llama_print_system_info()},
});
5 changes: 3 additions & 2 deletions examples/simple/simple.cpp
@@ -53,8 +53,9 @@ int main(int argc, char ** argv) {

ctx_params.seed = 1234;
ctx_params.n_ctx = 2048;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
ctx_params.n_threads = params.cpuparams.n_threads;
ctx_params.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

39 changes: 36 additions & 3 deletions examples/speculative/speculative.cpp
@@ -24,6 +24,14 @@ struct seq_draft {
struct llama_sampling_context * ctx_sampling;
};

static void switch_active_threadpool(
ggml_compute_threadpool_t cur,
ggml_compute_threadpool_t nxt
) {
ggml_pause_threadpool(cur);
ggml_resume_threadpool(nxt);
}

int main(int argc, char ** argv) {
gpt_params params;

@@ -67,13 +75,19 @@ int main(int argc, char ** argv) {
// load the target model
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

ggml_threadpool_params tpp_tgt = ggml_threadpool_params_from_cpu_params(params.cpuparams);
ggml_compute_threadpool * threadpool_tgt = ggml_create_threadpool(&tpp_tgt);
if (!threadpool_tgt) {
LOG_TEE("%s: target threadpool create failed : n_threads %d\n", __func__, tpp_tgt.n_threads);
exit(1);
}

// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
if (params.n_threads_draft > 0) {
params.n_threads = params.n_threads_draft;
if (params.draft_cpuparams.n_threads > 0) {
params.cpuparams = params.draft_cpuparams;
}
params.n_threads_batch = params.n_threads_batch_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);

const bool vocab_type_tgt = llama_vocab_type(model_tgt);
@@ -98,6 +112,17 @@ int main(int argc, char ** argv) {
return 1;
}

ggml_threadpool_params tpp_dft = ggml_threadpool_params_from_cpu_params(params.draft_cpuparams);
ggml_compute_threadpool * threadpool_dft = ggml_create_threadpool(&tpp_dft);
if (!threadpool_dft) {
LOG_TEE("%s: draft threadpool create failed : n_threads %d\n", __func__, tpp_dft.n_threads);
exit(1);
}

llama_attach_threadpool(ctx_tgt, threadpool_tgt);
llama_attach_threadpool(ctx_dft, threadpool_dft);
ggml_pause_threadpool(threadpool_dft);

{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
@@ -153,6 +178,7 @@ int main(int argc, char ** argv) {
// eval the prompt with both models
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
switch_active_threadpool(threadpool_tgt, threadpool_dft);
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));

const auto t_enc_end = ggml_time_us();
@@ -542,13 +568,15 @@

// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);

++n_past_cur;
++n_drafted;

if (batch_tgt.n_tokens > n_draft) {
break;
}
}
switch_active_threadpool(threadpool_dft, threadpool_tgt);

// evaluate the target model on the drafted tokens
{
@@ -559,6 +587,8 @@

// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
switch_active_threadpool(threadpool_tgt, threadpool_dft);

++n_past_tgt;
}

@@ -608,6 +638,9 @@

llama_backend_free();

ggml_release_threadpool(threadpool_tgt);
ggml_release_threadpool(threadpool_dft);

fprintf(stderr, "\n\n");

return 0;
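The speculative example keeps both pools alive for the whole run but lets only one spin at a time: the draft pool starts paused, and `switch_active_threadpool()` pauses the pool that just finished and resumes the one about to do work. Condensed, the per-iteration hand-off shown above looks like this (a restatement of the calls in the diff, not additional code):

```cpp
// Draft pool is active while drafting; hand the CPU to the target pool for verification,
// then back to the draft pool for the next round.
llama_decode(ctx_dft, batch_dft);
switch_active_threadpool(threadpool_dft, threadpool_tgt);

llama_decode(ctx_tgt, batch_tgt);
switch_active_threadpool(threadpool_tgt, threadpool_dft);
```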
@@ -1210,7 +1210,7 @@ int main(int argc, char ** argv) {
opt_cb_data.millis_per_iter = 0.0;

// measure required memory for work buffer
size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads, nullptr).work_size + GGML_OBJECT_SIZE;
printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));

// context for work buffer
5 changes: 3 additions & 2 deletions ggml-alloc.h
@@ -7,8 +7,9 @@ extern "C" {
#endif

typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
typedef struct ggml_backend * ggml_backend_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
typedef struct ggml_backend * ggml_backend_t;
typedef struct ggml_compute_threadpool * ggml_compute_threadpool_t;

// Tensor allocator
struct ggml_tallocr {
5 changes: 3 additions & 2 deletions ggml-backend-impl.h
@@ -92,13 +92,14 @@ extern "C" {
void (*GGML_CALL synchronize)(ggml_backend_t backend);

// compute graph with a plan (not used currently)
ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph, ggml_compute_threadpool_t threadpool);
void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);

// compute graph with a plan
enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);

// compute graph without a plan (async)
enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph, ggml_compute_threadpool_t threadpool);

// check if the backend supports an operation
bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
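At the backend interface the pool is threaded through `graph_plan_create` and `graph_compute`, and the public `ggml_backend_graph_compute()` gains a matching argument (the `clip.cpp` hunk above passes `NULL`). A sketch of the two call styles, assuming the public wrapper mirrors the interface change and that the header names below are right:

```cpp
#include "common.h"        // cpu_params, ggml_threadpool_params_from_cpu_params
#include "ggml-backend.h"  // ggml_backend_t, ggml_backend_graph_compute (assumed header)

// Drive a backend graph without and then with an explicit threadpool (sketch).
static void compute_with_optional_pool(ggml_backend_t backend, struct ggml_cgraph * gf,
                                       const cpu_params & cp) {
    // Previous behavior: no pre-created pool, as in the clip.cpp call above.
    ggml_backend_graph_compute(backend, gf, /*threadpool=*/ nullptr);

    // With a pool built from cpu_params:
    struct ggml_threadpool_params tpp = ggml_threadpool_params_from_cpu_params(cp);
    ggml_compute_threadpool_t tp = ggml_create_threadpool(&tpp);
    if (tp != nullptr) {
        ggml_backend_graph_compute(backend, gf, tp);
        ggml_release_threadpool(tp);
    }
}
```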