Fix memory bug in grammar parser
The llama.cpp grammar parser had a bug where forgetting to add a closing
quotation mark to a string could cause an overrun read. Anyone running a
server on a public endpoint is advised to upgrade. To reproduce this bug:

    ./llamafile -m foo.gguf -p bar --grammar 'root::="'
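
The overrun happens because the literal-string scanning loop in parse_sequence advances through the input looking for the closing '"' and never checks for the NUL terminator, so an unterminated literal walks past the end of the grammar buffer. A minimal sketch of the failure mode and the guard (illustrative only, not the parser's actual code):

    // Scanning for a closing quote without a NUL check reads past the buffer
    // when the quote is missing; checking *pos first turns that into an error.
    #include <stdexcept>

    static const char * scan_literal(const char * pos) {
        while (*pos != '"') {        // unterminated string: never matches
            if (!*pos) {             // the guard this commit adds at each scan site
                throw std::runtime_error("unexpected end of input");
            }
            ++pos;                   // without the guard, pos keeps running past the end
        }
        return pos + 1;              // position just after the closing quote
    }
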
jart committed May 8, 2024
1 parent 68aef2b commit 22aba95
Showing 5 changed files with 29 additions and 5 deletions.
8 changes: 3 additions & 5 deletions llama.cpp/common.cpp
@@ -1393,14 +1393,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
-
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
-    }
-
-    if (invalid_param) {
-        throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+        if (invalid_param) {
+            throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+        }
     }
 
     if (params.prompt_cache_all &&
Expand Down
12 changes: 12 additions & 0 deletions llama.cpp/grammar-parser.cpp
@@ -1,3 +1,6 @@
+// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-
+// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
+
 #include "grammar-parser.h"
 #include <cstdint>
 #include <cwchar>
@@ -142,6 +145,9 @@ namespace grammar_parser {
                 pos++;
                 last_sym_start = out_elements.size();
                 while (*pos != '"') {
+                    if (!*pos) { // [jart] don't sync until upstream fixes bug
+                        throw std::runtime_error("unexpected end of input");
+                    }
                     auto char_pair = parse_char(pos);
                     pos = char_pair.second;
                     out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
@@ -156,6 +162,9 @@
                 }
                 last_sym_start = out_elements.size();
                 while (*pos != ']') {
+                    if (!*pos) { // [jart] don't sync until upstream fixes bug
+                        throw std::runtime_error("unexpected end of input");
+                    }
                     auto char_pair = parse_char(pos);
                     pos = char_pair.second;
                     enum llama_gretype type = last_sym_start < out_elements.size()
@@ -164,6 +173,9 @@

                     out_elements.push_back({type, char_pair.first});
                     if (pos[0] == '-' && pos[1] != ']') {
+                        if (!pos[1]) { // [jart] don't sync until upstream fixes bug
+                            throw std::runtime_error("unexpected end of input");
+                        }
                         auto endchar_pair = parse_char(pos + 1);
                         pos = endchar_pair.second;
                         out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
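
With these checks in place, the malformed grammar from the commit message is rejected instead of being read past its terminator. A sketch of how a caller could exercise the fix, assuming the grammar_parser::parse() entry point from grammar-parser.h, which in upstream llama.cpp catches parse errors and returns an empty parse_state:

    #include "grammar-parser.h"
    #include <cstdio>

    int main() {
        // Same malformed grammar as the repro command: an unterminated string literal.
        grammar_parser::parse_state state = grammar_parser::parse("root::=\"");
        if (state.rules.empty()) {
            fprintf(stderr, "grammar rejected cleanly instead of overrunning the buffer\n");
            return 1;
        }
        return 0;
    }
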
5 changes: 5 additions & 0 deletions llama.cpp/llava/llava-cli.cpp
@@ -202,6 +202,11 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     LOG_TEE("\n");
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    if (!ctx_sampling) { // [jart] fixes crash
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }
+
     std::string response = "";
     for (int i = 0; i < max_tgt_len; i++) {
         const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
4 changes: 4 additions & 0 deletions llama.cpp/main/main.cpp
@@ -595,6 +595,10 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    if (!ctx_sampling) { // [jart] fixes crash
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        return 1;
+    }
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
5 changes: 5 additions & 0 deletions llama.cpp/server/server.cpp
@@ -871,6 +871,11 @@ struct llama_server_context
             llama_sampling_free(slot->ctx_sampling);
         }
         slot->ctx_sampling = llama_sampling_init(slot->sparams);
+        if (!slot->ctx_sampling) { // [jart] fixes crash
+            LOG_TEE("%s: failed to initialize sampling subsystem\n", __func__);
+            return false;
+        }
+
         llama_set_rng_seed(ctx, slot->params.seed);
         slot->command = LOAD_PROMPT;
 
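
The last three hunks apply the same defensive pattern: check the result of llama_sampling_init() before using it, since it can return a null pointer when sampler setup fails (for example, when the grammar fails to parse). A sketch of one way to centralize that check; the helper name checked_sampling_init is hypothetical, and llamafile keeps the checks inline as shown above:

    #include <cstdio>
    #include <cstdlib>
    #include "sampling.h"   // llama_sampling_init / llama_sampling_context (in-tree path may differ)

    // Hypothetical wrapper: fail loudly instead of dereferencing a null context later.
    static struct llama_sampling_context * checked_sampling_init(
            const struct llama_sampling_params & sparams, const char * who) {
        struct llama_sampling_context * ctx = llama_sampling_init(sparams);
        if (!ctx) {
            fprintf(stderr, "%s: failed to initialize sampling subsystem\n", who);
            exit(1);
        }
        return ctx;
    }
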
