22 changes: 17 additions & 5 deletions examples/omni-vlm/omni-vlm-wrapper.cpp
@@ -24,6 +24,8 @@ struct omnivlm_context {
struct llama_model * model = NULL;
};

void* internal_chars = nullptr;

static struct gpt_params params;
static struct llama_model* model;
static struct omnivlm_context* ctx_omnivlm;
@@ -128,7 +130,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
return ret.c_str();
}

static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
static const char* process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
int n_past = 0;

const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
@@ -172,11 +174,11 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima
std::string response = "";
for (int i = 0; i < max_tgt_len; i++) {
const char * tmp = sample(ctx_sampling, ctx_omnivlm->ctx_llama, &n_past);
response += tmp;
if (strcmp(tmp, "<|im_end|>") == 0) break;
if (strcmp(tmp, "</s>") == 0) break;
// if (strstr(tmp, "###")) break; // Yi-VL behavior
printf("%s", tmp);
// printf("%s", tmp);
response += tmp;
// if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
// if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
// if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
@@ -186,6 +188,13 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima

llama_sampling_free(ctx_sampling);
printf("\n");

// const char* ret_char_ptr = (const char*)(malloc(sizeof(char)*response.size()));
if(internal_chars != nullptr) { free(internal_chars); }
internal_chars = malloc(sizeof(char)*(response.size()+1));
strncpy((char*)(internal_chars), response.c_str(), response.size());
((char*)(internal_chars))[response.size()] = '\0';
return (const char*)(internal_chars);
}

static void omnivlm_free(struct omnivlm_context * ctx_omnivlm) {
@@ -225,7 +234,7 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path)
ctx_omnivlm = omnivlm_init_context(&params, model);
}

void omnivlm_inference(const char *prompt, const char *imag_path) {
const char* omnivlm_inference(const char *prompt, const char *imag_path) {
std::string image = imag_path;
params.prompt = prompt;
auto * image_embed = load_image(ctx_omnivlm, &params, image);
@@ -234,13 +243,16 @@ void omnivlm_inference(const char *prompt, const char *imag_path) {
throw std::runtime_error("failed to load image " + image);
}
// process the prompt
process_prompt(ctx_omnivlm, image_embed, &params, params.prompt);
const char* ret_chars = process_prompt(ctx_omnivlm, image_embed, &params, params.prompt);

// llama_perf_print(ctx_omnivlm->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
omnivlm_image_embed_free(image_embed);

return ret_chars;
}

void omnivlm_free() {
if(internal_chars != nullptr) { free(internal_chars); }
ctx_omnivlm->model = NULL;
omnivlm_free(ctx_omnivlm);
llama_free_model(model);
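Note on the lifetime of the returned string in omni-vlm-wrapper.cpp: process_prompt() now copies the generated response into the module-level buffer internal_chars (freeing the previous allocation first), and omnivlm_inference() passes that pointer straight through to the caller. The pointer therefore stays valid only until the next omnivlm_inference() call or until omnivlm_free() releases the buffer, so callers should copy or decode the string before invoking the wrapper again.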
2 changes: 1 addition & 1 deletion examples/omni-vlm/omni-vlm-wrapper.h
@@ -22,7 +22,7 @@ extern "C" {

OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path);

OMNIVLM_API void omnivlm_inference(const char* prompt, const char* imag_path);
OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);

OMNIVLM_API void omnivlm_free();

2 changes: 1 addition & 1 deletion examples/omni-vlm/omni_vlm_cpp.py
@@ -73,7 +73,7 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):


_lib.omnivlm_inference.argtypes = [omni_char_p, omni_char_p]
_lib.omnivlm_inference.restype = None
_lib.omnivlm_inference.restype = omni_char_p


def omnivlm_free():
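Note on the binding change in omni_vlm_cpp.py: assuming omni_char_p is ctypes.c_char_p (as its use for the string arguments suggests), setting restype to omni_char_p makes ctypes convert the returned const char* into a Python bytes object. That conversion copies the characters out of the C-side buffer, so the Python value is unaffected when a later call overwrites internal_chars.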
6 changes: 4 additions & 2 deletions examples/omni-vlm/omni_vlm_demo.py
@@ -20,7 +20,7 @@ def __init__(self, llm_model_path: str, mmproj_model_path: str):
def inference(self, prompt: str, image_path: str):
prompt = ctypes.c_char_p(prompt.encode("utf-8"))
image_path = ctypes.c_char_p(image_path.encode("utf-8"))
omni_vlm_cpp.omnivlm_inference(prompt, image_path)
return omni_vlm_cpp.omnivlm_inference(prompt, image_path)

def __del__(self):
omni_vlm_cpp.omnivlm_free()
@@ -52,4 +52,6 @@ def __del__(self):
while not os.path.exists(image_path):
print("ERROR: can not find image in your input path, please check and input agian.")
image_path = input()
omni_vlm_obj.inference(prompt, image_path)
response = omni_vlm_obj.inference(prompt, image_path)
print("\tresponse:")
print(response.decode('utf-8'))
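
For reference, a minimal sketch of driving the updated binding directly through omni_vlm_cpp rather than through the demo class. The model paths, prompt, and image path are placeholders, and the omnivlm_init wrapper is assumed to mirror the C signature declared in omni-vlm-wrapper.h:

import ctypes
import omni_vlm_cpp

# Placeholder paths: point these at the actual GGUF language model and projector files.
llm_path = ctypes.c_char_p(b"/path/to/llm.gguf")
mmproj_path = ctypes.c_char_p(b"/path/to/mmproj.gguf")
omni_vlm_cpp.omnivlm_init(llm_path, mmproj_path)
try:
    prompt = ctypes.c_char_p("Describe this image.".encode("utf-8"))
    image_path = ctypes.c_char_p("/path/to/image.png".encode("utf-8"))
    # With restype set to omni_char_p, the call returns the generated text as bytes.
    response = omni_vlm_cpp.omnivlm_inference(prompt, image_path)
    print(response.decode("utf-8"))
finally:
    omni_vlm_cpp.omnivlm_free()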