Skip to content

Commit

Permalink
fix incorrect translate/transcribe token_ids that are not static const
Browse files Browse the repository at this point in the history
  • Loading branch information
akashmjn committed Jun 26, 2023
1 parent 62c851b commit c8e1ed6
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 23 deletions.
4 changes: 2 additions & 2 deletions bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,12 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token {

// Task tokens
//
// Whisper_token_translate returns the token id of the "translate" task, as
// reported by the C function whisper_token_translate.
// NOTE(review): the updated call passes `ctx`, but this function declares
// neither a receiver nor a parameter with that name — confirm `ctx` is in
// scope here; it presumably needs a `(ctx *Context)` receiver like
// Whisper_token_lang.
func Whisper_token_translate() Token {
return Token(C.whisper_token_translate())
return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx)))
}

// Task tokens
//
// Whisper_token_transcribe returns the token id of the "transcribe" task, as
// reported by the C function whisper_token_transcribe.
// NOTE(review): the updated call passes `ctx`, but this function declares
// neither a receiver nor a parameter with that name — confirm `ctx` is in
// scope here; it presumably needs a `(ctx *Context)` receiver like
// Whisper_token_lang.
func Whisper_token_transcribe() Token {
return Token(C.whisper_token_transcribe())
return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx)))
}

// Performance information
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ public interface WhisperCppJnaLibrary extends Library {
int whisper_token_lang(Pointer ctx, int lang_id);

// Task tokens
// The updated declarations take ctx, a pointer to the native whisper context,
// matching the C signatures that now accept `struct whisper_context * ctx`.
int whisper_token_translate();
int whisper_token_transcribe();
int whisper_token_translate (Pointer ctx);
int whisper_token_transcribe(Pointer ctx);

// Performance information from the default state.
void whisper_print_timings(Pointer ctx);
Expand Down
37 changes: 20 additions & 17 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,16 +380,17 @@ struct whisper_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;

id token_eot = 50256;
id token_sot = 50257;
id token_solm = 50359; // ?? TODO@Akash - rename appropriately
id token_prev = 50360;
id token_not = 50362; // no timestamps
id token_beg = 50363; // begin timestamps

// available tasks
static const id token_translate = 50358; // TODO@Akash - technically it's 50357 for .en models
static const id token_transcribe = 50359; // TODO@Akash - technically it's 50358 for .en models
// reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
id token_eot = 50256;
id token_sot = 50257;
// task tokens (used only for multilingual models)
id token_translate = 50357;
id token_transcribe = 50358;
// other special tokens
id token_solm = 50359; // ?? TODO@Akash - rename appropriately
id token_prev = 50360;
id token_not = 50362; // no timestamps
id token_beg = 50363; // begin timestamps

bool is_multilingual() const {
return n_vocab == 51865;
Expand Down Expand Up @@ -966,8 +967,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
if (vocab.is_multilingual()) {
vocab.token_eot++;
vocab.token_sot++;
vocab.token_prev++;
vocab.token_translate++;
vocab.token_transcribe++;
vocab.token_solm++;
vocab.token_prev++;
vocab.token_not++;
vocab.token_beg++;
}
Expand Down Expand Up @@ -3228,12 +3231,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
return whisper_token_sot(ctx) + 1 + lang_id;
}

// Returns the id of the "translate" task token.
// The updated version reads the id from ctx->vocab instead of the former
// static constant whisper_vocab::token_translate, so the value is the one
// stored in the given context's vocabulary.
// ctx is dereferenced without a null check.
whisper_token whisper_token_translate(void) {
return whisper_vocab::token_translate;
whisper_token whisper_token_translate(struct whisper_context * ctx) {
return ctx->vocab.token_translate;
}

// Returns the id of the "transcribe" task token.
// The updated version reads the id from ctx->vocab instead of the former
// static constant whisper_vocab::token_transcribe, so the value is the one
// stored in the given context's vocabulary.
// ctx is dereferenced without a null check.
whisper_token whisper_token_transcribe(void) {
return whisper_vocab::token_transcribe;
whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
return ctx->vocab.token_transcribe;
}

void whisper_print_timings(struct whisper_context * ctx) {
Expand Down Expand Up @@ -4018,9 +4021,9 @@ int whisper_full_with_state(
state->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
prompt_init.push_back(whisper_token_translate(ctx));
} else {
prompt_init.push_back(whisper_token_transcribe());
prompt_init.push_back(whisper_token_transcribe(ctx));
}
}

Expand Down
4 changes: 2 additions & 2 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ extern "C" {
WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);

// Task tokens
// The updated declarations take the context whose vocabulary supplies the
// token id; the former versions took no arguments.
WHISPER_API whisper_token whisper_token_translate (void);
WHISPER_API whisper_token whisper_token_transcribe(void);
WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);

// Performance information from the default state.
WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
Expand Down

0 comments on commit c8e1ed6

Please sign in to comment.