127 changes: 126 additions & 1 deletion examples/omni-vlm/omni-vlm-wrapper.cpp
@@ -15,10 +15,10 @@
#include <vector>
#include <string>
#include <iostream>
#include <memory>

#include "omni-vlm-wrapper.h"


struct omnivlm_context {
struct clip_ctx * ctx_clip = NULL;
struct llama_context * ctx_llama = NULL;
@@ -30,6 +30,53 @@ void* internal_chars = nullptr;
static struct common_params params;
static struct llama_model* model;
static struct omnivlm_context* ctx_omnivlm;
static std::unique_ptr<struct omni_streaming_sample> g_oss = nullptr;

static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past);
static void omnivlm_free(struct omnivlm_context * ctx_omnivlm);

struct omni_streaming_sample {
struct common_sampler * ctx_sampling_;
std::string image_;
std::string ret_str_;
int32_t n_past_;
int32_t dec_cnt_;

omni_streaming_sample() = delete;
    explicit omni_streaming_sample(const std::string& image)
        : image_(image) {
        n_past_  = 0;
        dec_cnt_ = 0;
        // top_k = 1 restricts sampling to the argmax token, i.e. greedy decoding
        params.sparams.top_k = 1;
        params.sparams.top_p = 1.0f;
        ctx_sampling_ = common_sampler_init(model, params.sparams);
    }

    int32_t sample() {
        // sample the next token and record it in the sampler state
        const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
        common_sampler_accept(ctx_sampling_, id, true);
        if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
            ret_str_ = "</s>";
        } else {
            ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
        }
        // feed the token back so the next call continues from it
        eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);

        ++dec_cnt_;
        return id;
    }

    ~omni_streaming_sample() {
        common_sampler_free(ctx_sampling_);
        // tear down the per-request context; the shared model is freed separately
        if (ctx_omnivlm != nullptr) {
            ctx_omnivlm->model = nullptr;
            omnivlm_free(ctx_omnivlm);
            free(ctx_omnivlm);
            ctx_omnivlm = nullptr;
        }
    }
};
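// Editorial sketch, not part of this patch: with top_k = 1 the sampler above
// reduces to greedy decoding, i.e. it always picks the argmax token, roughly:
//
//     llama_token greedy_pick(const float * logits, int n_vocab) {
//         llama_token best = 0;
//         for (int i = 1; i < n_vocab; ++i) {
//             if (logits[i] > logits[best]) best = i;
//         }
//         return best;
//     }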


static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {

@@ -286,3 +333,81 @@ void omnivlm_free() {
}
llama_free_model(model);
}


struct omni_streaming_sample* omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
    // drop any previous streaming session; its destructor frees the old context
    if (g_oss) {
        g_oss.reset();
    }
    g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));

ctx_omnivlm = omnivlm_init_context(&params, model);

params.prompt = prompt;

if (params.omni_vlm_version == "vlm-81-ocr") {
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
} else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
    } else {
        LOG_ERR("%s : error: unknown omni_vlm_version: '%s'.\n", __func__, params.omni_vlm_version.c_str());
        throw std::runtime_error("unknown omni_vlm_version: " + params.omni_vlm_version);
    }

return g_oss.get();
}
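// Editorial note, not part of this patch: for the instruct versions the
// assembled prompt has the shape
//     <system block><|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|><user prompt><|im_end|>
// sample() below locates "<|image_pad|>" on the first call, evaluates the text
// before it, splices in the image embedding, then evaluates the remainder.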

int32_t sample(omni_streaming_sample* oss) {
    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
    int32_t ret_id;
    if (oss->n_past_ == 0) {
        // first call: evaluate the prompt and image embedding, then sample the first token
        auto * image_embed = load_image(ctx_omnivlm, &params, oss->image_);
        if (!image_embed) {
            LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
            throw std::runtime_error("failed to load image " + oss->image_);
        }

        // split the prompt at the image placeholder; the embedding is spliced in between
        size_t image_pos = params.prompt.find("<|image_pad|>");
        std::string system_prompt, user_prompt;

        system_prompt = params.prompt.substr(0, image_pos);
        user_prompt = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
        if (params.verbose_prompt) {
            auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
            for (int i = 0; i < (int) tmp.size(); i++) {
                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
            }
            tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
            for (int i = 0; i < (int) tmp.size(); i++) {
                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
            }
        }

        eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
        omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
        eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);

        omnivlm_image_embed_free(image_embed);

        ret_id = oss->sample();
        if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
            ret_id = -1;  // end-of-generation token
        }
    } else {
        // subsequent calls: decode one more token
        if (oss->dec_cnt_ == max_tgt_len) {
            ret_id = -2;  // generation length cap (n_predict) reached
        } else {
            ret_id = oss->sample();
            if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
                ret_id = -1;  // end-of-generation token
            }
        }
    }
    return ret_id;
}

const char* get_str(omni_streaming_sample* oss) {
return oss->ret_str_.c_str();
}
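A minimal caller sketch for the streaming API above (editorial example, not part of this patch; the model, projector, and image paths are placeholders):

    #include <cstdio>
    #include "omni-vlm-wrapper.h"

    int main() {
        omnivlm_init("llm.gguf", "projector.gguf", "vlm-81-instruct");
        struct omni_streaming_sample * s =
            omnivlm_inference_streaming("Describe this image.", "example.png");
        for (;;) {
            int32_t id = sample(s);    // -1: end-of-generation token, -2: n_predict cap hit
            if (id < 0) break;
            fputs(get_str(s), stdout); // text piece decoded at this step
        }
        omnivlm_free();
        return 0;
    }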
12 changes: 10 additions & 2 deletions examples/omni-vlm/omni-vlm-wrapper.h
@@ -1,6 +1,6 @@

#ifndef OMNIVLMWRAPPER_H
#define OMNIVLMWRAPPER_H
#include <stdint.h>

#ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__)
@@ -20,14 +20,22 @@
extern "C" {
#endif

struct omni_streaming_sample;

OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version);

OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);

OMNIVLM_API struct omni_streaming_sample* omnivlm_inference_streaming(const char* prompt, const char* imag_path);

OMNIVLM_API int32_t sample(struct omni_streaming_sample *);

OMNIVLM_API const char* get_str(struct omni_streaming_sample *);

OMNIVLM_API void omnivlm_free();

#ifdef __cplusplus
}
#endif

#endif