Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions examples/nexa-omni-audio/omni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
// Constants
//

void* internal_chars = nullptr;

static const char *AUDIO_TOKEN = "<|AUDIO|>";

//
Expand Down Expand Up @@ -570,7 +572,7 @@ static omni_params get_omni_params_from_context_params(omni_context_params &para
all_params.gpt.n_gpu_layers = params.n_gpu_layers;
all_params.gpt.model = params.model;
all_params.gpt.prompt = params.prompt;

// Initialize whisper params
all_params.whisper.model = params.mmproj;
all_params.whisper.fname_inp = {params.file};
Expand Down Expand Up @@ -703,6 +705,10 @@ struct omni_context *omni_init_context(omni_context_params &params)

void omni_free(struct omni_context *ctx_omni)
{
if(internal_chars != nullptr)
{
free(internal_chars);
}
if (ctx_omni->ctx_whisper)
{
whisper_free(ctx_omni->ctx_whisper);
Expand Down Expand Up @@ -792,7 +798,7 @@ ggml_tensor *omni_process_audio(struct omni_context *ctx_omni, omni_params &para
return embed_proj;
}

void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
const char* omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed, omni_params &params, const std::string &prompt)
{
int n_past = 0;

Expand Down Expand Up @@ -833,12 +839,11 @@ void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed
for (int i = 0; i < max_tgt_len; i++)
{
const char * tmp = sample(ctx_sampling, ctx_omni->ctx_llama, &n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0)
break;
if (strstr(tmp, "###"))
break; // Yi-VL behavior
printf("%s", tmp);
// printf("%s", tmp);
if (strstr(response.c_str(), "<|im_end|>"))
break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
if (strstr(response.c_str(), "<|im_start|>"))
Expand All @@ -847,16 +852,22 @@ void omni_process_prompt(struct omni_context *ctx_omni, ggml_tensor *audio_embed
break; // mistral llava-1.6

fflush(stdout);
response += tmp;
}

llama_sampling_free(ctx_sampling);
printf("\n");
if(internal_chars != nullptr) { free(internal_chars); }
internal_chars = malloc(sizeof(char)*(response.size()+1));
strncpy((char*)(internal_chars), response.c_str(), response.size());
((char*)(internal_chars))[response.size()] = '\0';
return (const char*)(internal_chars);
}

void omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
const char* omni_process_full(struct omni_context *ctx_omni, omni_context_params &params)
{
omni_params all_params = get_omni_params_from_context_params(params);

ggml_tensor *audio_embed = omni_process_audio(ctx_omni, all_params);
omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
}
return omni_process_prompt(ctx_omni, audio_embed, all_params, all_params.gpt.prompt);
}
4 changes: 2 additions & 2 deletions examples/nexa-omni-audio/omni.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ OMNI_AUDIO_API struct omni_context *omni_init_context(omni_context_params &param

OMNI_AUDIO_API void omni_free(struct omni_context *ctx_omni);

OMNI_AUDIO_API void omni_process_full(
OMNI_AUDIO_API const char* omni_process_full(
struct omni_context *ctx_omni,
omni_context_params &params
);

#ifdef __cplusplus
}
#endif
#endif