diff --git a/docs/internals/llm-streams.md b/docs/internals/llm-streams.md new file mode 100644 index 0000000..329f010 --- /dev/null +++ b/docs/internals/llm-streams.md @@ -0,0 +1,70 @@ +# LLM Stream Core + +This document describes the first part of the Layer 3 stream pipeline. + +The current implementation introduces two building blocks: + +- `sse_reader` in `gateway::streams::reader` +- `HubChunkStream` in `gateway::streams::hub` + +## Scope + +This slice only covers the hub-facing stream foundation. + +- `sse_reader` turns a byte stream into complete SSE lines. +- `HubChunkStream` turns provider stream lines into hub `ChatCompletionChunk` values. + +`BridgedStream` and `NativeStream` are intentionally deferred to later steps. + +## `sse_reader` + +`sse_reader` keeps the contract simple: it emits raw SSE lines as strings. + +Three details matter: + +- it preserves the original line content instead of stripping `data:` prefixes +- it appends a synthetic trailing newline so the last partial line is flushed on EOF +- it drops empty separator lines so downstream transforms only see meaningful records + +That behavior matches the current provider transforms, which already parse SSE framing themselves. + +## `HubChunkStream` + +`HubChunkStream` is the first stream adapter that works on top of provider transforms. + +Its polling behavior is deliberately ordered: + +1. drain the internal buffer first +2. poll the raw line stream only when the buffer is empty +3. call `ProviderCapabilities::transform_stream_chunk()` on each raw line +4. return the first produced hub chunk immediately and queue the rest + +That fixes the earlier class of bug where a provider transform could return multiple chunks for one raw input line and only the first chunk would be observed. + +## Usage Accumulation + +`HubChunkStream` also centralizes usage tracking. + +Whenever a transformed hub chunk carries `usage`, the stream copies `prompt_tokens` and `completion_tokens` into `ChatStreamState`. This keeps token accounting outside individual provider transforms while still making the latest usage totals available to later pipeline stages. + +## Stream State + +`ChatStreamState` now carries both aggregation data and provider stream metadata. + +It currently tracks: + +- buffered tool call assembly state +- latest input and output token counts +- streamed response metadata such as `id`, `model`, and `created` + +Those metadata fields are required because some providers only emit response identity once at stream start, while later events still need to be converted into well-formed hub chunks. + +## Current Limits + +This implementation is intentionally narrow. + +- only the SSE reader is implemented in this slice +- `JsonArrayStream` and `AwsEventStream` readers are still future work +- no format bridging happens here yet; this stream only produces hub chunks + +That keeps the first stream-layer step focused on correctness of buffering, polling order, and usage propagation. diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e3c903c..2bd2b3a 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -2,5 +2,6 @@ pub mod error; pub mod formats; pub mod provider_instance; pub mod providers; +pub mod streams; pub mod traits; pub mod types; diff --git a/src/gateway/streams/hub.rs b/src/gateway/streams/hub.rs new file mode 100644 index 0000000..177162d --- /dev/null +++ b/src/gateway/streams/hub.rs @@ -0,0 +1,208 @@ +use std::{ + collections::VecDeque, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures::Stream; +use pin_project::pin_project; + +use crate::gateway::{ + error::Result, + traits::{ChatStreamState, ProviderCapabilities}, + types::openai::ChatCompletionChunk, +}; + +/// Buffered hub stream adapter for provider-produced raw stream lines. +/// +/// `HubChunkStream` preserves output ordering when one raw input item expands +/// into multiple `ChatCompletionChunk` values. The first transformed chunk is +/// returned immediately and the remaining chunks are queued in `buffer` for +/// subsequent polls. +/// +/// The stream mutates `state` as transformed chunks flow through it. In +/// particular, provider-specific stream metadata and the latest observed usage +/// totals are accumulated there so later pipeline stages can inspect them. +/// Provider-specific transformation behavior is delegated to `def`, held as an +/// `Arc`. +#[pin_project] +pub struct HubChunkStream { + #[pin] + inner: Pin> + Send>>, + def: Arc, + pub(crate) state: ChatStreamState, + buffer: VecDeque, +} + +impl HubChunkStream { + /// Creates a `HubChunkStream` from raw provider stream lines. + /// + /// The input stream must preserve line order. The returned stream stays + /// `Send` as long as the input stream is `Send`, and every polled raw line + /// is transformed through the supplied provider definition. + pub fn new( + inner: impl Stream> + Send + 'static, + def: Arc, + ) -> Self { + Self { + inner: Box::pin(inner), + def, + state: ChatStreamState::default(), + buffer: VecDeque::new(), + } + } +} + +impl Stream for HubChunkStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if let Some(chunk) = this.buffer.pop_front() { + return Poll::Ready(Some(Ok(chunk))); + } + + loop { + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(raw))) => { + match this.def.transform_stream_chunk(&raw, this.state) { + Ok(chunks) => { + if chunks.is_empty() { + continue; + } + + this.state.chunk_index += chunks.len(); + for chunk in &chunks { + if let Some(usage) = &chunk.usage { + this.state.input_tokens = usage.prompt_tokens; + this.state.output_tokens = usage.completion_tokens; + } + } + + let mut chunks = VecDeque::from(chunks); + let first = chunks.pop_front().unwrap(); + this.buffer.extend(chunks); + return Poll::Ready(Some(Ok(first))); + } + Err(error) => return Poll::Ready(Some(Err(error))), + } + } + Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use futures::StreamExt; + use http::HeaderMap; + + use super::HubChunkStream; + use crate::gateway::{ + error::Result, + provider_instance::ProviderAuth, + traits::{ChatTransform, ProviderCapabilities, ProviderMeta, StreamReaderKind}, + types::openai::{ + ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionChunkDelta, + ChatCompletionUsage, + }, + }; + + struct DummyProvider; + + impl ProviderMeta for DummyProvider { + fn name(&self) -> &'static str { + "dummy" + } + + fn default_base_url(&self) -> &'static str { + "https://example.com" + } + + fn stream_reader_kind(&self) -> StreamReaderKind { + StreamReaderKind::Sse + } + + fn build_auth_headers(&self, _auth: &ProviderAuth) -> Result { + Ok(HeaderMap::new()) + } + } + + impl ChatTransform for DummyProvider { + fn transform_stream_chunk( + &self, + raw: &str, + _state: &mut crate::gateway::traits::ChatStreamState, + ) -> Result> { + match raw { + "data: buffered" => Ok(vec![ + chunk_with_content("first", None), + chunk_with_content("second", None), + ]), + "data: usage" => Ok(vec![chunk_with_content("usage", Some((7, 11)))]), + _ => Ok(vec![]), + } + } + } + + impl ProviderCapabilities for DummyProvider {} + + #[tokio::test] + async fn hub_chunk_stream_consumes_buffered_chunks_in_order() { + let raw_stream = futures::stream::iter(vec![Ok("data: buffered".to_string())]); + let mut stream = HubChunkStream::new(raw_stream, Arc::new(DummyProvider)); + + let first = stream.next().await.unwrap().unwrap(); + let second = stream.next().await.unwrap().unwrap(); + + assert_eq!(first.choices[0].delta.content.as_deref(), Some("first")); + assert_eq!(second.choices[0].delta.content.as_deref(), Some("second")); + assert!(stream.next().await.is_none()); + } + + #[tokio::test] + async fn hub_chunk_stream_accumulates_usage_from_emitted_chunks() { + let raw_stream = futures::stream::iter(vec![Ok("data: usage".to_string())]); + let mut stream = HubChunkStream::new(raw_stream, Arc::new(DummyProvider)); + + let chunk = stream.next().await.unwrap().unwrap(); + + assert_eq!(chunk.usage.as_ref().unwrap().prompt_tokens, 7); + assert_eq!(stream.state.input_tokens, 7); + assert_eq!(stream.state.output_tokens, 11); + assert_eq!(stream.state.chunk_index, 1); + } + + fn chunk_with_content(content: &str, usage: Option<(u32, u32)>) -> ChatCompletionChunk { + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1, + model: "gpt-test".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionChunkDelta { + role: None, + content: Some(content.into()), + tool_calls: None, + }, + finish_reason: None, + }], + usage: usage.map(|(prompt_tokens, completion_tokens)| ChatCompletionUsage { + prompt_tokens, + completion_tokens, + total_tokens: prompt_tokens + completion_tokens, + prompt_tokens_details: None, + completion_tokens_details: None, + }), + system_fingerprint: None, + } + } +} diff --git a/src/gateway/streams/mod.rs b/src/gateway/streams/mod.rs new file mode 100644 index 0000000..a090249 --- /dev/null +++ b/src/gateway/streams/mod.rs @@ -0,0 +1,5 @@ +pub mod hub; +pub mod reader; + +pub use hub::HubChunkStream; +pub use reader::sse_reader; diff --git a/src/gateway/streams/reader.rs b/src/gateway/streams/reader.rs new file mode 100644 index 0000000..ce730b5 --- /dev/null +++ b/src/gateway/streams/reader.rs @@ -0,0 +1,125 @@ +use std::pin::Pin; + +use bytes::{Bytes, BytesMut}; +use futures::{Stream, StreamExt}; + +use crate::gateway::error::{GatewayError, Result}; + +struct ReaderState { + buffer: BytesMut, + terminated: bool, +} + +/// `sse_reader` decodes a byte stream into newline-delimited SSE lines. +/// +/// The returned stream yields UTF-8 `String` values split on newline +/// boundaries, filters out empty separator lines, and flushes any buffered +/// partial line when the upstream stream ends cleanly. Transport failures are +/// surfaced as `GatewayError::Http` and terminate the reader without emitting +/// buffered partial data. +pub fn sse_reader(stream: S) -> Pin> + Send>> +where + S: Stream> + Send + 'static, +{ + let stream = stream + .chain(futures::stream::once(async { + Ok(Bytes::from_static(b"\n")) + })) + .scan( + ReaderState { + buffer: BytesMut::new(), + terminated: false, + }, + |state, result| { + if state.terminated { + return futures::future::ready(None); + } + + match result { + Ok(chunk) => { + state.buffer.extend_from_slice(&chunk); + + let mut lines = Vec::new(); + if let Some(last_newline) = + state.buffer.iter().rposition(|&byte| byte == b'\n') + { + let complete_data = state.buffer.split_to(last_newline + 1); + let text = String::from_utf8_lossy(&complete_data); + for line in text.lines() { + if !line.is_empty() { + lines.push(Ok(line.to_string())); + } + } + } + + futures::future::ready(Some(futures::stream::iter(lines))) + } + Err(error) => { + state.buffer.clear(); + state.terminated = true; + futures::future::ready(Some(futures::stream::iter(vec![Err( + GatewayError::Http(error), + )]))) + } + } + }, + ) + .flatten(); + + Box::pin(stream) +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use futures::StreamExt; + + use super::sse_reader; + use crate::gateway::error::GatewayError; + + #[tokio::test] + async fn sse_reader_splits_lines_across_chunks() { + let byte_stream = futures::stream::iter(vec![ + Ok(Bytes::from("data: first\n")), + Ok(Bytes::from("data: second")), + Ok(Bytes::from("\n")), + ]); + + let mut reader = sse_reader(byte_stream); + + assert_eq!(reader.next().await.unwrap().unwrap(), "data: first"); + assert_eq!(reader.next().await.unwrap().unwrap(), "data: second"); + assert!(reader.next().await.is_none()); + } + + #[tokio::test] + async fn sse_reader_flushes_trailing_partial_line_on_eof() { + let byte_stream = futures::stream::iter(vec![ + Ok(Bytes::from("data: first\n")), + Ok(Bytes::from("data: second")), + ]); + + let mut reader = sse_reader(byte_stream); + + assert_eq!(reader.next().await.unwrap().unwrap(), "data: first"); + assert_eq!(reader.next().await.unwrap().unwrap(), "data: second"); + assert!(reader.next().await.is_none()); + } + + #[tokio::test] + async fn sse_reader_does_not_flush_partial_line_after_error() { + let error = reqwest::Client::new() + .get("http://[::1") + .build() + .unwrap_err(); + let byte_stream = futures::stream::iter(vec![Ok(Bytes::from("data: partial")), Err(error)]); + + let mut reader = sse_reader(byte_stream); + + assert!(matches!( + reader.next().await.unwrap(), + Err(GatewayError::Http(_)) + )); + assert!(reader.next().await.is_none()); + } +}