diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index cd8889b..17740a4 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -474,6 +474,8 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat options.hasKey("speedUp") ? options.getBoolean("speedUp") : false, // jboolean translate, options.hasKey("translate") ? options.getBoolean("translate") : false, + // jboolean tdrz_enable, + options.hasKey("tdrzEnable") ? options.getBoolean("tdrzEnable") : false, // jstring language, options.hasKey("language") ? options.getString("language") : "auto", // jstring prompt @@ -645,6 +647,7 @@ protected static native int fullTranscribe( int best_of, boolean speed_up, boolean translate, + boolean tdrz_enable, String language, String prompt, ProgressCallback progressCallback diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 20e6f02..90e2ea5 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -232,6 +232,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( jint best_of, jboolean speed_up, jboolean translate, + jboolean tdrz_enable, jstring language, jstring prompt, jobject progress_callback_instance @@ -256,7 +257,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.print_realtime = false; params.print_progress = false; params.print_timestamps = false; - params.print_special = false; + params.print_special = true; params.translate = translate; const char *language_chars = env->GetStringUTFChars(language, nullptr); params.language = language_chars; @@ -265,6 +266,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.offset_ms = 0; params.no_context = true; params.single_segment = false; + params.tdrz_enable = tdrz_enable; if (max_len > -1) { params.max_len = max_len; diff --git a/cpp/whisper.cpp b/cpp/whisper.cpp index be83206..defe127 100644 --- a/cpp/whisper.cpp +++ b/cpp/whisper.cpp @@ -3727,6 +3727,7 @@ static void whisper_process_logits( // [TDRZ] when tinydiarize is disabled, suppress solm token if (params.tdrz_enable == false) { logits[vocab.token_solm] = -INFINITY; + log("[TDRZ] solm token suppressed\n"); } // suppress task tokens @@ -4717,8 +4718,10 @@ int whisper_full_with_state( text += whisper_token_to_str(ctx, tokens_cur[i].id); } + log("step"); // [TDRZ] record if speaker turn was predicted after current segment if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) { + log("trdz status=%s\nSpeaker turn happened", params.tdrz_enable ? "enabled" : "disabled"); speaker_turn_next = true; } diff --git a/docs/API/README.md b/docs/API/README.md index f276dbb..d9f12df 100644 --- a/docs/API/README.md +++ b/docs/API/README.md @@ -82,6 +82,7 @@ ___ | `prompt?` | `string` | Initial Prompt | | `speedUp?` | `boolean` | Speed up audio by x2 (reduced accuracy) | | `temperature?` | `number` | Tnitial decoding temperature | +| `tdrzEnable?` | `boolean` | Enable tinydiarize https://github.com/ggerganov/whisper.cpp/pull/1058 | | `temperatureInc?` | `number` | - | | `tokenTimestamps?` | `boolean` | Enable token-level timestamps | | `translate?` | `boolean` | Translate from source language to english (Default: false) | diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 6401fd6..0f9aa12 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -381,6 +381,7 @@ - (struct whisper_full_params)getParams:(NSDictionary *)options jobId:(int)jobId params.print_special = false; params.speed_up = options[@"speedUp"] != nil ? [options[@"speedUp"] boolValue] : false; params.translate = options[@"translate"] != nil ? [options[@"translate"] boolValue] : false; + params.tdrz_enable = options[@"tdrzEnable"] != nil ? [options[@"tdrzEnable"] boolValue] : false; params.language = options[@"language"] != nil ? [options[@"language"] UTF8String] : "auto"; params.n_threads = n_threads > 0 ? n_threads : default_n_threads; params.offset_ms = 0; diff --git a/src/NativeRNWhisper.ts b/src/NativeRNWhisper.ts index b7992a1..68175f6 100644 --- a/src/NativeRNWhisper.ts +++ b/src/NativeRNWhisper.ts @@ -30,6 +30,8 @@ export type TranscribeOptions = { bestOf?: number, /** Speed up audio by x2 (reduced accuracy) */ speedUp?: boolean, + /** Enable tinydiarize (https://github.com/ggerganov/whisper.cpp/pull/1058) */ + tdrzEnable?: boolean, /** Initial Prompt */ prompt?: string, }