Make the translate/transcribe task tokens depend on the loaded model.

The task-token ids become regular per-context fields of whisper_vocab
(shifted by one for multilingual models together with the other special
tokens) instead of static constants, so whisper_token_translate() and
whisper_token_transcribe() now take a whisper_context. The Go and Java
bindings are updated to pass the context; in the Go binding the two
wrappers become methods on *Context so that `ctx` is actually in scope.

diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go
index 8a5efa7de0c..d2ea756679e 100644
--- a/bindings/go/whisper.go
+++ b/bindings/go/whisper.go
@@ -271,12 +271,12 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token {
 
 // Task tokens
-func Whisper_token_translate() Token {
-	return Token(C.whisper_token_translate())
+func (ctx *Context) Whisper_token_translate() Token {
+	return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx)))
 }
 
 // Task tokens
-func Whisper_token_transcribe() Token {
-	return Token(C.whisper_token_transcribe())
+func (ctx *Context) Whisper_token_transcribe() Token {
+	return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx)))
 }
 
 // Performance information
diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java
index c1fb4f8e3b0..ad9faa0be70 100644
--- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java
+++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java
@@ -224,8 +224,8 @@ public interface WhisperCppJnaLibrary extends Library {
     int whisper_token_lang(Pointer ctx, int lang_id);
 
     // Task tokens
-    int whisper_token_translate();
-    int whisper_token_transcribe();
+    int whisper_token_translate (Pointer ctx);
+    int whisper_token_transcribe(Pointer ctx);
 
     // Performance information from the default state.
     void whisper_print_timings(Pointer ctx);
diff --git a/whisper.cpp b/whisper.cpp
index 58ffca341c2..79112b896d3 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -380,16 +380,17 @@ struct whisper_vocab {
     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
 
-    id token_eot  = 50256;
-    id token_sot  = 50257;
-    id token_solm = 50359; // ?? TODO@Akash - rename appropriately
-    id token_prev = 50360;
-    id token_not  = 50362; // no timestamps
-    id token_beg  = 50363; // begin timestamps
-
-    // available tasks
-    static const id token_translate  = 50358; // TODO@Akash - technically it's 50357 for .en models
-    static const id token_transcribe = 50359; // TODO@Akash - technically it's 50358 for .en models
+    // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
+    id token_eot        = 50256;
+    id token_sot        = 50257;
+    // task tokens (used only for multilingual models)
+    id token_translate  = 50357;
+    id token_transcribe = 50358;
+    // other special tokens
+    id token_solm       = 50359; // ?? TODO@Akash - rename appropriately
+    id token_prev       = 50360;
+    id token_not        = 50362; // no timestamps
+    id token_beg        = 50363; // begin timestamps
 
     bool is_multilingual() const {
         return n_vocab == 51865;
@@ -966,8 +967,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         if (vocab.is_multilingual()) {
             vocab.token_eot++;
             vocab.token_sot++;
-            vocab.token_prev++;
+            vocab.token_translate++;
+            vocab.token_transcribe++;
             vocab.token_solm++;
+            vocab.token_prev++;
             vocab.token_not++;
             vocab.token_beg++;
         }
@@ -3228,12 +3231,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
     return whisper_token_sot(ctx) + 1 + lang_id;
 }
 
-whisper_token whisper_token_translate(void) {
-    return whisper_vocab::token_translate;
+whisper_token whisper_token_translate(struct whisper_context * ctx) {
+    return ctx->vocab.token_translate;
 }
 
-whisper_token whisper_token_transcribe(void) {
-    return whisper_vocab::token_transcribe;
+whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
+    return ctx->vocab.token_transcribe;
 }
 
 void whisper_print_timings(struct whisper_context * ctx) {
@@ -4018,9 +4021,9 @@ int whisper_full_with_state(
             state->lang_id = lang_id;
             prompt_init.push_back(whisper_token_lang(ctx, lang_id));
             if (params.translate) {
-                prompt_init.push_back(whisper_token_translate());
+                prompt_init.push_back(whisper_token_translate(ctx));
             } else {
-                prompt_init.push_back(whisper_token_transcribe());
+                prompt_init.push_back(whisper_token_transcribe(ctx));
             }
         }
 
diff --git a/whisper.h b/whisper.h
index e983c7d4fa3..6525b47df03 100644
--- a/whisper.h
+++ b/whisper.h
@@ -284,8 +284,8 @@ extern "C" {
     WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
 
     // Task tokens
-    WHISPER_API whisper_token whisper_token_translate (void);
-    WHISPER_API whisper_token whisper_token_transcribe(void);
+    WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
+    WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
 
     // Performance information from the default state.
     WHISPER_API void whisper_print_timings(struct whisper_context * ctx);