Skip to content

Commit

Permalink
feat: enable tdrz
Browse files Browse the repository at this point in the history
Enables tinydiarize models ggerganov/whisper.cpp#1058
  • Loading branch information
JEF1056 committed Sep 29, 2023
1 parent d07a114 commit e2999d5
Show file tree
Hide file tree
Showing 12 changed files with 12,445 additions and 8,888 deletions.
3 changes: 3 additions & 0 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Expand Up @@ -474,6 +474,8 @@ private int full(int jobId, ReadableMap options, float[] audioData, int audioDat
options.hasKey("speedUp") ? options.getBoolean("speedUp") : false,
// jboolean translate,
options.hasKey("translate") ? options.getBoolean("translate") : false,
// jboolean tdrz_enable,
options.hasKey("tdrzEnable") ? options.getBoolean("tdrzEnable") : false,
// jstring language,
options.hasKey("language") ? options.getString("language") : "auto",
// jstring prompt
Expand Down Expand Up @@ -645,6 +647,7 @@ protected static native int fullTranscribe(
int best_of,
boolean speed_up,
boolean translate,
boolean tdrz_enable,
String language,
String prompt,
ProgressCallback progressCallback
Expand Down
4 changes: 3 additions & 1 deletion android/src/main/jni.cpp
Expand Up @@ -232,6 +232,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
jint best_of,
jboolean speed_up,
jboolean translate,
jboolean tdrz_enable,
jstring language,
jstring prompt,
jobject progress_callback_instance
Expand All @@ -256,7 +257,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
params.print_realtime = false;
params.print_progress = false;
params.print_timestamps = false;
params.print_special = false;
params.print_special = true;
params.translate = translate;
const char *language_chars = env->GetStringUTFChars(language, nullptr);
params.language = language_chars;
Expand All @@ -265,6 +266,7 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
params.offset_ms = 0;
params.no_context = true;
params.single_segment = false;
params.tdrz_enable = tdrz_enable;

if (max_len > -1) {
params.max_len = max_len;
Expand Down
8 changes: 7 additions & 1 deletion cpp/coreml/whisper-encoder.mm
Expand Up @@ -22,7 +22,13 @@

NSURL * url_model = [NSURL fileURLWithPath: path_model_str];

const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model error:nil]);
// select which device to run the Core ML model on
MLModelConfiguration *config = [[MLModelConfiguration alloc] init];
config.computeUnits = MLComputeUnitsCPUAndGPU;
//config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;
//config.computeUnits = MLComputeUnitsAll;

const void * data = CFBridgingRetain([[whisper_encoder_impl alloc] initWithContentsOfURL:url_model configuration:config error:nil]);

if (data == NULL) {
return NULL;
Expand Down

0 comments on commit e2999d5

Please sign in to comment.