Skip to content

Commit

Permalink
feat: enable tdrz
Browse files Browse the repository at this point in the history
Enables tinydiarize models ggerganov/whisper.cpp#1058
  • Loading branch information
JEF1056 committed Sep 29, 2023
1 parent d07a114 commit 6de5fc8
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 1 deletion.
3 changes: 3 additions & 0 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Expand Up @@ -474,6 +474,8 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat
options.hasKey("speedUp") ? options.getBoolean("speedUp") : false,
// jboolean translate,
options.hasKey("translate") ? options.getBoolean("translate") : false,
// jboolean tdrz_enable,
options.hasKey("tdrzEnable") ? options.getBoolean("tdrzEnable") : false,
// jstring language,
options.hasKey("language") ? options.getString("language") : "auto",
// jstring prompt
Expand Down Expand Up @@ -645,6 +647,7 @@ protected static native int fullTranscribe(
int best_of,
boolean speed_up,
boolean translate,
boolean tdrz_enable,
String language,
String prompt,
ProgressCallback progressCallback
Expand Down
4 changes: 3 additions & 1 deletion android/src/main/jni.cpp
Expand Up @@ -232,6 +232,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
jint best_of,
jboolean speed_up,
jboolean translate,
jboolean tdrz_enable,
jstring language,
jstring prompt,
jobject progress_callback_instance
Expand All @@ -256,7 +257,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
params.print_realtime = false;
params.print_progress = false;
params.print_timestamps = false;
params.print_special = false;
params.print_special = true;
params.translate = translate;
const char *language_chars = env->GetStringUTFChars(language, nullptr);
params.language = language_chars;
Expand All @@ -265,6 +266,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;
params.tdrz_enable = tdrz_enable;

if (max_len > -1) {
params.max_len = max_len;
Expand Down
3 changes: 3 additions & 0 deletions cpp/whisper.cpp
Expand Up @@ -3727,6 +3727,7 @@ static void whisper_process_logits(
// [TDRZ] when tinydiarize is disabled, suppress solm token
if (params.tdrz_enable == false) {
logits[vocab.token_solm] = -INFINITY;
log("[TDRZ] solm token suppressed\n");
}

// suppress task tokens
Expand Down Expand Up @@ -4717,8 +4718,10 @@ int whisper_full_with_state(
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}

log("step");
// [TDRZ] record if speaker turn was predicted after current segment
if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) {
log("trdz status=%s\nSpeaker turn happened", params.tdrz_enable ? "enabled" : "disabled");
speaker_turn_next = true;
}

Expand Down
1 change: 1 addition & 0 deletions docs/API/README.md
Expand Up @@ -82,6 +82,7 @@ ___
| `prompt?` | `string` | Initial Prompt |
| `speedUp?` | `boolean` | Speed up audio by x2 (reduced accuracy) |
| `temperature?` | `number` | Initial decoding temperature |
| `tdrzEnable?` | `boolean` | Enable tinydiarize speaker-turn detection, see https://github.com/ggerganov/whisper.cpp/pull/1058 (Default: false) |
| `temperatureInc?` | `number` | - |
| `tokenTimestamps?` | `boolean` | Enable token-level timestamps |
| `translate?` | `boolean` | Translate from source language to english (Default: false) |
Expand Down
1 change: 1 addition & 0 deletions ios/RNWhisperContext.mm
Expand Up @@ -381,6 +381,7 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId
params.print_special = false;
params.speed_up = options[@"speedUp"] != nil ? [options[@"speedUp"] boolValue] : false;
params.translate = options[@"translate"] != nil ? [options[@"translate"] boolValue] : false;
params.tdrz_enable = options[@"tdrzEnable"] != nil ? [options[@"tdrzEnable"] boolValue] : false;
params.language = options[@"language"] != nil ? [options[@"language"] UTF8String] : "auto";
params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
params.offset_ms = 0;
Expand Down
2 changes: 2 additions & 0 deletions src/NativeRNWhisper.ts
Expand Up @@ -30,6 +30,8 @@ export type TranscribeOptions = {
bestOf?: number,
/** Speed up audio by x2 (reduced accuracy) */
speedUp?: boolean,
/** Enable tinydiarize (https://github.com/ggerganov/whisper.cpp/pull/1058) */
tdrzEnable?: boolean,
/** Initial Prompt */
prompt?: string,
}
Expand Down

0 comments on commit 6de5fc8

Please sign in to comment.