Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 79 additions & 4 deletions openless-all/app/src-tauri/src/asr/volcengine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! quirks are preserved verbatim — see comments tagged with `[asr]` for the
//! original learnings (especially the "definite=true is NOT stream end" bug).

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;

Expand All @@ -12,7 +13,7 @@ use parking_lot::Mutex as ParkingMutex;
use serde_json::{json, Value};
use tokio::net::TcpStream;
use tokio::runtime::Handle;
use tokio::sync::{oneshot, Mutex as AsyncMutex};
use tokio::sync::{oneshot, Mutex as AsyncMutex, Notify};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::HeaderValue;
use tokio_tungstenite::tungstenite::Message;
Expand Down Expand Up @@ -72,6 +73,10 @@ struct SyncState {
final_tx: Option<oneshot::Sender<Result<RawTranscript, VolcengineASRError>>>,
runtime: Option<Handle>,
start: Option<Instant>,
/// 最近一次 partial(非 final)的累积 transcript。服务端在 final 帧到达前
/// 关闭连接 / 网络中断时,作为 fallback 回给上层,避免「用户的话已经识别出来
/// 但没拿到 final」就丢光。
last_partial_text: String,
}

pub struct VolcengineStreamingASR {
Expand All @@ -83,6 +88,11 @@ pub struct VolcengineStreamingASR {
/// of the lifetime of any particular `&self` borrow.
writer: SharedWriter,
final_rx: ParkingMutex<Option<oneshot::Receiver<Result<RawTranscript, VolcengineASRError>>>>,
/// 在飞的 audio 帧 spawn 数。consume_pcm_chunk +1,spawn 内 send 完成 -1。
/// send_last_frame 必须等它降到 0 才能安全发末帧,否则末帧可能被服务端先收到
/// 而把后续 chunk 当成「stream 已结束」之后的多余数据丢弃 → 尾句丢失。
pending_sends: Arc<AtomicUsize>,
send_done: Arc<Notify>,
}

impl VolcengineStreamingASR {
Expand All @@ -93,6 +103,8 @@ impl VolcengineStreamingASR {
state: ParkingMutex::new(SyncState::default()),
writer: Arc::new(AsyncMutex::new(None)),
final_rx: ParkingMutex::new(None),
pending_sends: Arc::new(AtomicUsize::new(0)),
send_done: Arc::new(Notify::new()),
}
}

Expand Down Expand Up @@ -152,7 +164,9 @@ impl VolcengineStreamingASR {
st.final_tx = Some(tx);
st.runtime = Some(Handle::current());
st.start = Some(Instant::now());
st.last_partial_text.clear();
}
self.pending_sends.store(0, Ordering::SeqCst);
*self.final_rx.lock() = Some(rx);
*self.writer.lock().await = Some(write);

Expand Down Expand Up @@ -186,14 +200,17 @@ impl VolcengineStreamingASR {
}
}
Ok(Message::Close(_)) => {
// Server closed without a final frame — treat as no result.
this.signal_error(VolcengineASRError::NoFinalResult);
// 服务端没发 final 就关连接 → 用最近一次 partial 兜底,不丢已识别的文字。
this.fallback_to_partial_or_error(VolcengineASRError::NoFinalResult);
break;
}
Ok(_) => { /* ignore text/ping/pong */ }
Err(e) => {
log::error!("[asr] receive loop error: {}", e);
this.signal_error(VolcengineASRError::ConnectionFailed(e.to_string()));
// 网络中断同样回退到 partial,让用户至少拿到已经识别的部分。
this.fallback_to_partial_or_error(VolcengineASRError::ConnectionFailed(
e.to_string(),
));
break;
}
}
Expand All @@ -207,6 +224,23 @@ impl VolcengineStreamingASR {
}

pub async fn send_last_frame(&self) -> Result<(), VolcengineASRError> {
// 等所有 fire-and-forget 发送完成。否则末帧(NegativeSequence)可能比尾部
// chunk 先到服务端,被识别为「流已结束」之后再到的 chunk 全部丢弃 = 尾句吞掉。
// 给一个 800ms 上限避免极端网络下永远等。
let drain_deadline = Instant::now() + std::time::Duration::from_millis(800);
while self.pending_sends.load(Ordering::SeqCst) > 0 {
let remaining = drain_deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
log::warn!(
"[asr] send_last_frame: pending {} 帧未发送完,超时强制继续",
self.pending_sends.load(Ordering::SeqCst)
);
break;
}
// notified() 返回 future,被 timeout 包住 → 等待发送完成或超时
let _ = tokio::time::timeout(remaining, self.send_done.notified()).await;
}

// Drain leftover audio (if any) into one final positive-sequence frame.
let leftover = {
let mut st = self.state.lock();
Expand Down Expand Up @@ -394,6 +428,12 @@ impl VolcengineStreamingASR {
}
}

// 缓存最新的 partial transcript:服务端在 final 帧前断连时 fallback 用。
// 仅在非空且不是 final 时更新(final 走另一条路径)。
if !has_final && !full_text.is_empty() {
self.state.lock().last_partial_text = full_text.clone();
}

if has_final {
let duration_ms = self
.state
Expand Down Expand Up @@ -425,6 +465,34 @@ impl VolcengineStreamingASR {
let _ = tx.send(Err(err));
}
}

/// 服务端 close / 网络中断时调用:如果有缓存的 partial 文本,作为 transcript
/// 兜底返回;否则才报错。配合 `last_partial_text` 实现「至少不丢用户已识别出的话」。
fn fallback_to_partial_or_error(&self, err: VolcengineASRError) {
let (partial, duration_ms) = {
let st = self.state.lock();
(
st.last_partial_text.clone(),
st.start
.map(|s| s.elapsed().as_millis() as u64)
.unwrap_or(0),
)
};
if !partial.is_empty() {
log::warn!(
"[asr] {}; 使用 partial 兜底({} 字)",
err,
partial.chars().count()
);
self.signal_success(RawTranscript {
text: partial,
duration_ms,
});
} else {
self.signal_error(err);
}
self.state.lock().is_connected = false;
}
}

impl AudioConsumer for VolcengineStreamingASR {
Expand Down Expand Up @@ -464,6 +532,10 @@ impl AudioConsumer for VolcengineStreamingASR {
};

let writer = Arc::clone(&self.writer);
// pending_sends + Notify 让 send_last_frame 知道何时所有 chunk 都已发出。
self.pending_sends.fetch_add(1, Ordering::SeqCst);
let pending = Arc::clone(&self.pending_sends);
let notify = Arc::clone(&self.send_done);
runtime.spawn(async move {
let frame = frame::build(
MessageType::AudioOnlyRequest,
Expand All @@ -476,6 +548,9 @@ impl AudioConsumer for VolcengineStreamingASR {
// 把丢帧错误顶到日志里,定位"为什么服务端只收到 100ms"
log::error!("[asr] audio frame seq={} send 失败: {}", seq, e);
}
if pending.fetch_sub(1, Ordering::SeqCst) == 1 {
notify.notify_waiters();
}
});
}
}
Expand Down
19 changes: 17 additions & 2 deletions openless-all/app/src-tauri/src/asr/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,38 @@ impl WhisperBatchASR {

/// Stop collecting audio, encode the buffer as WAV, and POST to the
/// Whisper transcriptions endpoint.
///
/// 失败时**保留** PCM buffer,让上层有机会重试或在历史中至少留一个失败记录;
/// 之前的实现一进函数就 `mem::take` 把 buffer 清空,凭证错或网络中断都会
/// 让用户的录音直接消失。
pub async fn transcribe(&self) -> Result<RawTranscript> {
let pcm = std::mem::take(&mut *self.buffer.lock());
// clone 而不是 take:~30s 16 kHz 16-bit 音频 ≈ 960 KB,会话末调用一次,可接受。
let pcm = self.buffer.lock().clone();
if pcm.is_empty() {
return Ok(RawTranscript {
text: String::new(),
duration_ms: 0,
});
}

let result = self.transcribe_inner(&pcm).await;
// 仅在成功路径上才清 buffer。失败时 PCM 还在,coordinator 拿到 Err 但
// 用户重新触发 stop 时仍能再发一次,或日后增加重试入口时复用。
if result.is_ok() {
self.buffer.lock().clear();
}
result
}

async fn transcribe_inner(&self, pcm: &[u8]) -> Result<RawTranscript> {
// 16 kHz mono 16-bit: 2 bytes per sample.
let duration_ms = (pcm.len() as u64 / 2) * 1000 / 16_000;

if self.api_key.is_empty() {
anyhow::bail!("Whisper API key missing");
}

let wav = encode_wav_16k_mono(&pcm);
let wav = encode_wav_16k_mono(pcm);
let base_url = self.base_url.trim_end_matches('/');
let url = format!("{}/audio/transcriptions", base_url);

Expand Down