diff --git a/Cargo.lock b/Cargo.lock index e174c43f..b8edca1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2215,6 +2215,7 @@ dependencies = [ "lru", "num_cpus", "parking_lot", + "parquet", "petgraph 0.7.1", "proctitle", "prost", @@ -2228,6 +2229,7 @@ dependencies = [ "serde_yaml", "sqlparser", "strum", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-stream", diff --git a/Cargo.toml b/Cargo.toml index 87d4ea03..531601d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "sync", "tim serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9" serde_json = "1.0" -uuid = { version = "1.0", features = ["v4"] } +uuid = { version = "1.0", features = ["v4", "v7"] } log = "0.4" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -51,6 +51,7 @@ arrow = { version = "55", default-features = false } arrow-array = "55" arrow-ipc = "55" arrow-schema = { version = "55", features = ["serde"] } +parquet = "55" futures = "0.3" serde_json_path = "0.7" xxhash-rust = { version = "0.8", features = ["xxh3"] } @@ -78,3 +79,6 @@ governor = "0.8.0" default = ["incremental-cache", "python"] incremental-cache = ["wasmtime/incremental-cache"] python = [] + +[dev-dependencies] +tempfile = "3.27.0" diff --git a/Makefile b/Makefile index c8e1da4d..87a2339a 100644 --- a/Makefile +++ b/Makefile @@ -13,12 +13,49 @@ APP_NAME := function-stream VERSION := $(shell grep '^version' Cargo.toml | head -1 | awk -F '"' '{print $$2}') -ARCH := $(shell uname -m) -OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") +# 1. Auto-detect system environment & normalize architecture +RAW_ARCH := $(shell uname -m) +# Fix macOS M-series returning arm64 while Rust expects aarch64 +ifeq ($(RAW_ARCH), arm64) + ARCH := aarch64 +else ifeq ($(RAW_ARCH), amd64) + ARCH := x86_64 +else + ARCH := $(RAW_ARCH) +endif + +OS := $(shell uname -s | tr '[:upper:]' '[:lower:]') +OS_NAME := $(shell uname -s) + +# 2. Configure RUSTFLAGS and target triple per platform DIST_ROOT := dist -TARGET_DIR := target/release +ifeq ($(OS_NAME), Linux) + TRIPLE := $(ARCH)-unknown-linux-gnu + STATIC_FLAGS := +else ifeq ($(OS_NAME), Darwin) + # macOS: strip symbols but keep dynamic linking (Apple system restriction) + TRIPLE := $(ARCH)-apple-darwin + STATIC_FLAGS := +else ifneq (,$(findstring MINGW,$(OS_NAME))$(findstring MSYS,$(OS_NAME))) + # Windows (Git Bash / MSYS2): static-link MSVC runtime + TRIPLE := $(ARCH)-pc-windows-msvc + STATIC_FLAGS := -C target-feature=+crt-static +else + # Fallback + TRIPLE := $(ARCH)-unknown-linux-gnu + STATIC_FLAGS := +endif + +# 3. Aggressive optimization flags +# opt-level=z : size-oriented, minimize binary footprint +# strip=symbols: remove debug symbol table at link time +# Note: panic=abort is intentionally omitted to preserve stack unwinding +# for better fault tolerance in the streaming runtime +OPTIMIZE_FLAGS := -C opt-level=z -C strip=symbols $(STATIC_FLAGS) + +TARGET_DIR := target/$(TRIPLE)/release PYTHON_ROOT := python WASM_SOURCE := $(PYTHON_ROOT)/functionstream-runtime/target/functionstream-python-runtime.wasm @@ -67,18 +104,42 @@ help: @echo "" @echo " Version: $(VERSION) | Arch: $(ARCH) | OS: $(OS)" -build: .check-env .build-wasm - $(call log,BUILD,Rust Full Features) - @cargo build --release --features python --quiet +# 4. Auto-install missing Rust target toolchain +.ensure-target: + @rustup target list --installed | grep -q "$(TRIPLE)" || \ + (printf "$(C_Y)[!] Auto-installing target toolchain for $(OS_NAME): $(TRIPLE)$(C_0)\n" && \ + rustup target add $(TRIPLE)) + +# 5. Build targets (depend on .ensure-target for automatic toolchain setup) +build: .check-env .ensure-target .build-wasm + $(call log,BUILD,Rust Full [$(OS_NAME) / $(TRIPLE)]) + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + --features python \ + --quiet $(call log,BUILD,CLI) - @cargo build --release -p function-stream-cli --quiet + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + -p function-stream-cli \ + --quiet $(call success,Target: $(TARGET_DIR)/$(APP_NAME) $(TARGET_DIR)/cli) -build-lite: .check-env - $(call log,BUILD,Rust Lite No Python) - @cargo build --release --no-default-features --features incremental-cache --quiet +build-lite: .check-env .ensure-target + $(call log,BUILD,Rust Lite [$(OS_NAME) / $(TRIPLE)]) + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + --no-default-features \ + --features incremental-cache \ + --quiet $(call log,BUILD,CLI for dist) - @cargo build --release -p function-stream-cli --quiet + @RUSTFLAGS="$(OPTIMIZE_FLAGS)" \ + cargo build --release \ + --target $(TRIPLE) \ + -p function-stream-cli \ + --quiet $(call success,Target: $(TARGET_DIR)/$(APP_NAME) $(TARGET_DIR)/cli) .build-wasm: diff --git a/conf/config.yaml b/conf/config.yaml index 9d0f625e..c83809c7 100644 --- a/conf/config.yaml +++ b/conf/config.yaml @@ -49,6 +49,20 @@ wasm: # When cache exceeds this size, least recently used items will be evicted max_cache_size: 104857600 +# Streaming Runtime Configuration +streaming: + # Global memory pool for streaming pipeline execution (buffers, batch collect, backpressure). + # Default / example: 10 MiB (10485760 bytes). + streaming_runtime_memory_bytes: 10485760 + + # Per stateful operator (join / agg / window): in-memory state store cap before spill. + # Default / example: 5 MiB (5242880 bytes). + operator_state_store_memory_bytes: 5242880 + checkpoint_interval_ms: 60000 + pipeline_parallelism: 1 + # KeyBy (key extraction) operator pipeline parallelism in planned streaming jobs. + key_by_parallelism: 1 + # State Storage Configuration # Used to store runtime state data for tasks state_storage: diff --git a/protocol/proto/storage.proto b/protocol/proto/storage.proto index d7caf7bc..fd021727 100644 --- a/protocol/proto/storage.proto +++ b/protocol/proto/storage.proto @@ -43,6 +43,28 @@ message CatalogSourceTable { // Streaming table storage (CREATE STREAMING TABLE persistence) // ============================================================================= +// Partition offset for one Kafka partition at a completed checkpoint. +message KafkaPartitionOffset { + int32 partition = 1; + int64 offset = 2; +} + +// Kafka source subtask checkpoint: one file / one TaskContext (pipeline + subtask). +message KafkaSourceSubtaskCheckpoint { + uint32 pipeline_id = 1; + uint32 subtask_index = 2; + // Epoch of the barrier when this snapshot was taken (aligns with latest_checkpoint_epoch on commit). + uint64 checkpoint_epoch = 3; + repeated KafkaPartitionOffset partitions = 4; +} + +// Generic source checkpoint payload envelope (enum-like via oneof). +message SourceCheckpointPayload { + oneof checkpoint { + KafkaSourceSubtaskCheckpoint kafka = 1; + } +} + // Persisted record for one streaming table (CREATE STREAMING TABLE). // On restart, the engine re-submits each record to JobManager to resume the pipeline. message StreamingTableDefinition { @@ -52,6 +74,17 @@ message StreamingTableDefinition { // Stored as opaque bytes to avoid coupling storage schema with runtime API protos. bytes fs_program_bytes = 3; string comment = 4; + + uint64 checkpoint_interval_ms = 5; + + // Last globally-committed checkpoint epoch. + // Updated by JobManager after all operators ACK. Used for crash recovery. + uint64 latest_checkpoint_epoch = 6; + + // Kafka source per-subtask offsets at the same committed epoch as `latest_checkpoint_epoch`. + // Populated by the runtime coordinator from source checkpoint ACKs. Optional `.bin` files under + // the job state dir may exist only for local recovery materialization from this field. + repeated KafkaSourceSubtaskCheckpoint kafka_source_checkpoints = 7; } // ============================================================================= diff --git a/src/config/global_config.rs b/src/config/global_config.rs index c76bf4b0..dcfbcf5c 100644 --- a/src/config/global_config.rs +++ b/src/config/global_config.rs @@ -17,13 +17,33 @@ use uuid::Uuid; use crate::config::log_config::LogConfig; use crate::config::python_config::PythonConfig; use crate::config::service_config::ServiceConfig; +use crate::config::streaming_job::{ResolvedStreamingJobConfig, StreamingJobConfig}; use crate::config::wasm_config::WasmConfig; +/// Default for [`StreamingConfig::streaming_runtime_memory_bytes`] when unset. **10 MiB** (pipeline buffers, backpressure). +pub const DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES: u64 = 10 * 1024 * 1024; + +/// Default for [`StreamingConfig::operator_state_store_memory_bytes`] when unset. **5 MiB** per stateful operator cap. +pub const DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES: u64 = 5 * 1024 * 1024; + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct StreamingConfig { - /// Maximum heap memory (in bytes) available to the streaming runtime's memory pool. - /// Defaults to 256 MiB when absent. - pub max_memory_bytes: Option, + #[serde(flatten)] + pub job: StreamingJobConfig, + /// Bytes reserved in the global memory pool for streaming pipeline execution (buffers, + /// batch collect, backpressure). Default 10 MiB. + #[serde(default)] + pub streaming_runtime_memory_bytes: Option, + /// Per stateful operator: in-memory state store cap before spill. Default 5 MiB. + #[serde(default)] + pub operator_state_store_memory_bytes: Option, +} + +impl StreamingConfig { + #[inline] + pub fn resolved_job(&self) -> ResolvedStreamingJobConfig { + self.job.resolve() + } } #[derive(Debug, Clone, Serialize, Deserialize, Default)] diff --git a/src/config/mod.rs b/src/config/mod.rs index f08051af..e60dcfde 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,9 +17,13 @@ pub mod paths; pub mod python_config; pub mod service_config; pub mod storage; +pub mod streaming_job; +pub mod system; pub mod wasm_config; -pub use global_config::GlobalConfig; +pub use global_config::{ + DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES, GlobalConfig, +}; pub use loader::load_global_config; pub use log_config::LogConfig; #[allow(unused_imports)] @@ -31,3 +35,4 @@ pub use paths::{ }; #[cfg(feature = "python")] pub use python_config::PythonConfig; +pub use streaming_job::{DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_PIPELINE_PARALLELISM}; diff --git a/src/config/streaming_job.rs b/src/config/streaming_job.rs new file mode 100644 index 00000000..0b0d1cde --- /dev/null +++ b/src/config/streaming_job.rs @@ -0,0 +1,72 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use serde::{Deserialize, Serialize}; + +pub const DEFAULT_CHECKPOINT_INTERVAL_MS: u64 = 60 * 1000; +pub const DEFAULT_PIPELINE_PARALLELISM: u32 = 1; +pub const DEFAULT_KEY_BY_PARALLELISM: u32 = 1; +pub const DEFAULT_JOB_MANAGER_CONTROL_PLANE_THREADS: u32 = 1; + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct StreamingJobConfig { + #[serde(default)] + pub checkpoint_interval_ms: Option, + #[serde(default)] + pub pipeline_parallelism: Option, + /// Physical parallelism for KeyBy / key-extraction operators in planned streaming graphs. + #[serde(default)] + pub key_by_parallelism: Option, + #[serde(default)] + pub job_manager_control_plane_threads: Option, + #[serde(default)] + pub job_manager_data_plane_threads: Option, +} + +#[derive(Debug, Clone, Copy)] +pub struct ResolvedStreamingJobConfig { + pub checkpoint_interval_ms: u64, + pub pipeline_parallelism: u32, + pub key_by_parallelism: u32, + pub job_manager_control_plane_threads: u32, + pub job_manager_data_plane_threads: u32, +} + +impl StreamingJobConfig { + pub fn resolve(&self) -> ResolvedStreamingJobConfig { + let cpu_threads = std::thread::available_parallelism() + .map(|n| n.get() as u32) + .unwrap_or(1); + ResolvedStreamingJobConfig { + checkpoint_interval_ms: self + .checkpoint_interval_ms + .filter(|&ms| ms > 0) + .unwrap_or(DEFAULT_CHECKPOINT_INTERVAL_MS), + pipeline_parallelism: self + .pipeline_parallelism + .filter(|&p| p > 0) + .unwrap_or(DEFAULT_PIPELINE_PARALLELISM), + key_by_parallelism: self + .key_by_parallelism + .filter(|&p| p > 0) + .unwrap_or(DEFAULT_KEY_BY_PARALLELISM), + job_manager_control_plane_threads: self + .job_manager_control_plane_threads + .filter(|&p| p > 0) + .unwrap_or(DEFAULT_JOB_MANAGER_CONTROL_PLANE_THREADS), + job_manager_data_plane_threads: self + .job_manager_data_plane_threads + .filter(|&p| p > 0) + .unwrap_or(cpu_threads), + } + } +} diff --git a/src/config/system.rs b/src/config/system.rs new file mode 100644 index 00000000..1a6d2967 --- /dev/null +++ b/src/config/system.rs @@ -0,0 +1,230 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::io; + +pub struct SystemMemoryInfo { + pub total_physical: u64, + pub available_physical: u64, + pub total_virtual: u64, + pub available_virtual: u64, +} + +pub fn system_memory_info() -> io::Result { + sys::system_memory_info() +} + +#[cfg(target_os = "linux")] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + pub fn system_memory_info() -> io::Result { + let content = std::fs::read_to_string("/proc/meminfo")?; + + let mut total_physical: Option = None; + let mut available_physical: Option = None; + let mut swap_total: u64 = 0; + let mut swap_free: u64 = 0; + + for line in content.lines() { + if let Some(v) = parse_meminfo_kb(line, "MemTotal:") { + total_physical = Some(v); + } else if let Some(v) = parse_meminfo_kb(line, "MemAvailable:") { + available_physical = Some(v); + } else if let Some(v) = parse_meminfo_kb(line, "SwapTotal:") { + swap_total = v; + } else if let Some(v) = parse_meminfo_kb(line, "SwapFree:") { + swap_free = v; + } + } + + let total_phys = total_physical.ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + "MemTotal not found in /proc/meminfo", + ) + })?; + let avail_phys = available_physical.unwrap_or(0); + + Ok(SystemMemoryInfo { + total_physical: total_phys, + available_physical: avail_phys, + total_virtual: total_phys + swap_total, + available_virtual: avail_phys + swap_free, + }) + } + + fn parse_meminfo_kb(line: &str, prefix: &str) -> Option { + let rest = line.strip_prefix(prefix)?; + let kb: u64 = rest.trim().trim_end_matches("kB").trim().parse().ok()?; + Some(kb * 1024) + } +} + +#[cfg(target_os = "macos")] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + pub fn system_memory_info() -> io::Result { + let total_physical = sysctl_u64("hw.memsize")?; + + let page_size = sysctl_u64("hw.pagesize").unwrap_or(4096); + let vm_stats = read_vm_stat()?; + + let free_pages = vm_stats.free + vm_stats.inactive + vm_stats.purgeable; + let available_physical = free_pages * page_size; + + let swap = read_swap_usage(); + let swap_total = swap.0; + let swap_free = swap_total.saturating_sub(swap.1); + + Ok(SystemMemoryInfo { + total_physical, + available_physical, + total_virtual: total_physical + swap_total, + available_virtual: available_physical + swap_free, + }) + } + + fn sysctl_u64(name: &str) -> io::Result { + let output = std::process::Command::new("sysctl") + .arg("-n") + .arg(name) + .output()?; + if !output.status.success() { + return Err(io::Error::other(format!("sysctl {name} failed"))); + } + String::from_utf8_lossy(&output.stdout) + .trim() + .parse() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + } + + struct VmPages { + free: u64, + inactive: u64, + purgeable: u64, + } + + fn read_vm_stat() -> io::Result { + let output = std::process::Command::new("vm_stat").output()?; + let text = String::from_utf8_lossy(&output.stdout); + + let mut free = 0u64; + let mut inactive = 0u64; + let mut purgeable = 0u64; + + for line in text.lines() { + if let Some(v) = parse_vm_stat_line(line, "Pages free") { + free = v; + } else if let Some(v) = parse_vm_stat_line(line, "Pages inactive") { + inactive = v; + } else if let Some(v) = parse_vm_stat_line(line, "Pages purgeable") { + purgeable = v; + } + } + + Ok(VmPages { + free, + inactive, + purgeable, + }) + } + + fn parse_vm_stat_line(line: &str, key: &str) -> Option { + if !line.contains(key) { + return None; + } + let val_str = line.rsplit(':').next()?.trim().trim_end_matches('.'); + val_str.parse().ok() + } + + fn read_swap_usage() -> (u64, u64) { + let output = match std::process::Command::new("sysctl") + .arg("-n") + .arg("vm.swapusage") + .output() + { + Ok(o) => o, + Err(_) => return (0, 0), + }; + let text = String::from_utf8_lossy(&output.stdout); + let mut total = 0u64; + let mut used = 0u64; + for part in text.split_whitespace() { + if let Some(mb_str) = part.strip_suffix("M") + && let Ok(mb) = mb_str.parse::() + { + if total == 0 { + total = (mb * 1024.0 * 1024.0) as u64; + } else if used == 0 { + used = (mb * 1024.0 * 1024.0) as u64; + } + } + } + (total, used) + } +} + +#[cfg(target_os = "windows")] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + #[repr(C)] + struct MemoryStatusEx { + dw_length: u32, + dw_memory_load: u32, + ull_total_phys: u64, + ull_avail_phys: u64, + ull_total_page_file: u64, + ull_avail_page_file: u64, + ull_total_virtual: u64, + ull_avail_virtual: u64, + ull_avail_extended_virtual: u64, + } + + extern "system" { + fn GlobalMemoryStatusEx(lpBuffer: *mut MemoryStatusEx) -> i32; + } + + pub fn system_memory_info() -> io::Result { + unsafe { + let mut status = std::mem::zeroed::(); + status.dw_length = std::mem::size_of::() as u32; + if GlobalMemoryStatusEx(&mut status) == 0 { + return Err(io::Error::last_os_error()); + } + Ok(SystemMemoryInfo { + total_physical: status.ull_total_phys, + available_physical: status.ull_avail_phys, + total_virtual: status.ull_total_virtual, + available_virtual: status.ull_avail_virtual, + }) + } + } +} + +#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))] +mod sys { + use super::SystemMemoryInfo; + use std::io; + + pub fn system_memory_info() -> io::Result { + Err(io::Error::new( + io::ErrorKind::Unsupported, + "memory detection not supported on this platform", + )) + } +} diff --git a/src/coordinator/execution/executor.rs b/src/coordinator/execution/executor.rs index 0000d0cf..6fb03134 100644 --- a/src/coordinator/execution/executor.rs +++ b/src/coordinator/execution/executor.rs @@ -28,9 +28,12 @@ use crate::coordinator::plan::{ StartFunctionPlan, StopFunctionPlan, StreamingTable, StreamingTableConnectorPlan, }; use crate::coordinator::statement::{ConfigSource, FunctionSource}; +use crate::coordinator::streaming_table_options::{ + parse_checkpoint_interval_ms, parse_pipeline_parallelism, +}; use crate::runtime::streaming::job::JobManager; use crate::runtime::streaming::protocol::control::StopMode; -use crate::runtime::taskexecutor::TaskManager; +use crate::runtime::wasm::taskexecutor::TaskManager; use crate::sql::schema::show_create_catalog_table; use crate::sql::schema::table::Table as CatalogTable; use crate::storage::stream_catalog::CatalogManager; @@ -318,29 +321,45 @@ impl PlanVisitor for Executor { _context: &PlanVisitorContext, ) -> PlanVisitorResult { let execute = || -> Result { - let fs_program: FsProgram = plan.program.clone().into(); + let mut fs_program: FsProgram = plan.program.clone().into(); let job_manager: Arc = Arc::clone(&self.job_manager); + // Only override per-node parallelism when CREATE STREAMING TABLE specifies + // `WITH (parallelism = N)`. Otherwise keep planner-assigned values (e.g. keyed + // aggregates defaulting to a higher parallelism than the job-wide default). + if let Some(pipeline_parallelism) = + parse_pipeline_parallelism(plan.with_options.as_ref()) + { + let p = pipeline_parallelism.max(1); + for node in &mut fs_program.nodes { + node.parallelism = p; + } + } let job_id = plan.name.clone(); - let job_id = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current() - .block_on(job_manager.submit_job(job_id, fs_program.clone())) - }) - .map_err(|e| ExecuteError::Internal(format!("Failed to submit streaming job: {e}")))?; + + let custom_interval = parse_checkpoint_interval_ms(plan.with_options.as_ref()); self.catalog_manager .persist_streaming_job( &plan.name, &fs_program, plan.comment.as_deref().unwrap_or(""), + custom_interval.unwrap_or(0), ) .map_err(|e| { - ExecuteError::Internal(format!( - "Streaming job '{}' submitted but persistence failed: {e}", - plan.name - )) + ExecuteError::Internal(format!("Streaming job persistence failed: {e}",)) })?; + let job_id = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(job_manager.submit_job( + job_id, + fs_program, + custom_interval, + None, + )) + }) + .map_err(|e| ExecuteError::Internal(format!("Failed to submit streaming job: {e}")))?; + info!( job_id = %job_id, table = %plan.name, diff --git a/src/coordinator/mod.rs b/src/coordinator/mod.rs index 38d4637f..86598bc5 100644 --- a/src/coordinator/mod.rs +++ b/src/coordinator/mod.rs @@ -19,6 +19,7 @@ mod execution_context; mod plan; mod runtime_context; mod statement; +mod streaming_table_options; mod tool; pub use coordinator::Coordinator; diff --git a/src/coordinator/plan/logical_plan_visitor.rs b/src/coordinator/plan/logical_plan_visitor.rs index 6adc6420..d49d0314 100644 --- a/src/coordinator/plan/logical_plan_visitor.rs +++ b/src/coordinator/plan/logical_plan_visitor.rs @@ -168,10 +168,28 @@ impl LogicalPlanVisitor { let validated_program = self.validate_graph_topology(&final_logical_plan)?; + let streaming_with_options: Option> = + if with_options.is_empty() { + None + } else { + let map: std::collections::HashMap = with_options + .iter() + .filter_map(|opt| match opt { + SqlOption::KeyValue { key, value } => Some(( + key.value.clone(), + value.to_string().trim_matches('\'').to_string(), + )), + _ => None, + }) + .collect(); + if map.is_empty() { None } else { Some(map) } + }; + Ok(StreamingTable { name: sink_table_name, comment: comment.clone(), program: validated_program, + with_options: streaming_with_options, }) } diff --git a/src/coordinator/plan/streaming_table_plan.rs b/src/coordinator/plan/streaming_table_plan.rs index 512ec266..e155ba91 100644 --- a/src/coordinator/plan/streaming_table_plan.rs +++ b/src/coordinator/plan/streaming_table_plan.rs @@ -10,6 +10,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; + use super::{PlanNode, PlanVisitor, PlanVisitorContext, PlanVisitorResult}; use crate::sql::logical_node::logical::LogicalProgram; @@ -19,6 +21,7 @@ pub struct StreamingTable { pub name: String, pub comment: Option, pub program: LogicalProgram, + pub with_options: Option>, } impl PlanNode for StreamingTable { diff --git a/src/coordinator/runtime_context.rs b/src/coordinator/runtime_context.rs index 5d671b98..21b9d876 100644 --- a/src/coordinator/runtime_context.rs +++ b/src/coordinator/runtime_context.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use anyhow::Result; use crate::runtime::streaming::job::JobManager; -use crate::runtime::taskexecutor::TaskManager; +use crate::runtime::wasm::taskexecutor::TaskManager; use crate::sql::schema::StreamSchemaProvider; use crate::storage::stream_catalog::CatalogManager; diff --git a/src/coordinator/streaming_table_options.rs b/src/coordinator/streaming_table_options.rs new file mode 100644 index 00000000..51e020b0 --- /dev/null +++ b/src/coordinator/streaming_table_options.rs @@ -0,0 +1,47 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +fn parse_positive_u64(raw: &str) -> Option { + let t = raw.trim().trim_matches('\''); + t.parse::().ok().filter(|&v| v > 0) +} + +fn parse_positive_u32(raw: &str) -> Option { + let t = raw.trim().trim_matches('\''); + t.parse::().ok().filter(|&v| v > 0) +} + +pub fn parse_checkpoint_interval_ms(opts: Option<&HashMap>) -> Option { + opts.and_then(|m| m.get("checkpoint.interval")) + .and_then(|s| parse_positive_u64(s)) +} + +pub fn parse_pipeline_parallelism(opts: Option<&HashMap>) -> Option { + opts.and_then(|m| m.get("parallelism")) + .and_then(|s| parse_positive_u32(s)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_checkpoint_and_parallelism() { + let mut m = HashMap::new(); + m.insert("checkpoint.interval".to_string(), "30000".to_string()); + m.insert("parallelism".to_string(), "2".to_string()); + assert_eq!(parse_checkpoint_interval_ms(Some(&m)), Some(30_000)); + assert_eq!(parse_pipeline_parallelism(Some(&m)), Some(2)); + } +} diff --git a/src/runtime/memory/block.rs b/src/runtime/memory/block.rs new file mode 100644 index 00000000..2940b3e3 --- /dev/null +++ b/src/runtime/memory/block.rs @@ -0,0 +1,80 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use super::pool::MemoryPool; +use super::ticket::MemoryTicket; + +#[derive(Debug)] +pub struct MemoryBlock { + capacity: u64, + available_bytes: AtomicU64, + pool: Arc, +} + +impl MemoryBlock { + pub(crate) fn new(capacity: u64, pool: Arc) -> Arc { + Arc::new(Self { + capacity, + available_bytes: AtomicU64::new(capacity), + pool, + }) + } + + pub fn try_allocate(self: &Arc, bytes: u64) -> Option { + if bytes == 0 { + return Some(MemoryTicket::new(0, self.clone())); + } + + let mut current_available = self.available_bytes.load(Ordering::Acquire); + loop { + if current_available < bytes { + return None; + } + + match self.available_bytes.compare_exchange_weak( + current_available, + current_available - bytes, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Some(MemoryTicket::new(bytes, self.clone())), + Err(actual) => current_available = actual, + } + } + } + + #[inline] + pub fn available_bytes(&self) -> u64 { + self.available_bytes.load(Ordering::Relaxed) + } + + #[inline] + pub fn capacity(&self) -> u64 { + self.capacity + } + + pub(crate) fn release_ticket(&self, bytes: u64) { + if bytes > 0 { + self.available_bytes.fetch_add(bytes, Ordering::Release); + } + } +} + +impl Drop for MemoryBlock { + fn drop(&mut self) { + self.pool.release_block(self.capacity); + } +} diff --git a/src/runtime/memory/error.rs b/src/runtime/memory/error.rs new file mode 100644 index 00000000..008d5c71 --- /dev/null +++ b/src/runtime/memory/error.rs @@ -0,0 +1,64 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MemoryError { + AlreadyInitialized, + Uninitialized, + OsAllocationFailed { bytes: u64 }, +} + +impl fmt::Display for MemoryError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MemoryError::AlreadyInitialized => { + write!(f, "Global memory pool is already initialized") + } + MemoryError::Uninitialized => { + write!(f, "Global memory pool is not initialized") + } + MemoryError::OsAllocationFailed { bytes } => { + write!( + f, + "insufficient memory: failed to reserve {} bytes (virtual capacity for pool cap) from the OS allocator", + bytes + ) + } + } + } +} + +impl std::error::Error for MemoryError {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MemoryAllocationError { + InsufficientCapacity, + RequestLargerThanPool, +} + +impl fmt::Display for MemoryAllocationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MemoryAllocationError::InsufficientCapacity => { + write!(f, "Insufficient capacity in memory pool") + } + MemoryAllocationError::RequestLargerThanPool => { + write!(f, "Requested block exceeds memory pool maximum") + } + } + } +} + +impl std::error::Error for MemoryAllocationError {} diff --git a/src/runtime/memory/global.rs b/src/runtime/memory/global.rs new file mode 100644 index 00000000..42920147 --- /dev/null +++ b/src/runtime/memory/global.rs @@ -0,0 +1,39 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::{Arc, OnceLock}; + +use super::error::MemoryError; +use super::pool::MemoryPool; + +static GLOBAL_POOL: OnceLock> = OnceLock::new(); + +pub fn init_global_memory_pool(max_bytes: u64) -> Result<(), MemoryError> { + let pool = MemoryPool::try_new(max_bytes)?; + GLOBAL_POOL + .set(pool) + .map_err(|_| MemoryError::AlreadyInitialized) +} + +pub fn try_global_memory_pool() -> Result, MemoryError> { + GLOBAL_POOL.get().cloned().ok_or(MemoryError::Uninitialized) +} + +#[inline] +pub fn global_memory_pool() -> Arc { + try_global_memory_pool().expect("Global memory pool must be initialized before use") +} + +pub fn get_memory_metrics() -> Option<(u64, u64)> { + GLOBAL_POOL.get().map(|p| p.usage_metrics()) +} diff --git a/src/runtime/memory/mod.rs b/src/runtime/memory/mod.rs new file mode 100644 index 00000000..01a917a7 --- /dev/null +++ b/src/runtime/memory/mod.rs @@ -0,0 +1,36 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow_array::RecordBatch; + +mod block; +mod error; +pub mod global; +pub mod pool; +pub mod ticket; + +#[allow(unused_imports)] +pub use block::MemoryBlock; +#[allow(unused_imports)] +pub use error::{MemoryAllocationError, MemoryError}; +#[allow(unused_imports)] +pub use global::{ + get_memory_metrics, global_memory_pool, init_global_memory_pool, try_global_memory_pool, +}; +pub use pool::MemoryPool; +pub use ticket::MemoryTicket; + +#[inline] +pub fn get_array_memory_size(batch: &RecordBatch) -> u64 { + RecordBatch::get_array_memory_size(batch) as u64 +} diff --git a/src/runtime/memory/pool.rs b/src/runtime/memory/pool.rs new file mode 100644 index 00000000..98592d35 --- /dev/null +++ b/src/runtime/memory/pool.rs @@ -0,0 +1,139 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use parking_lot::Mutex; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; +use tokio::sync::Notify; +use tracing::{debug, warn}; + +use super::block::MemoryBlock; +use super::error::{MemoryAllocationError, MemoryError}; + +#[derive(Debug)] +pub struct MemoryPool { + max_bytes: u64, + used_bytes: AtomicU64, + available_bytes: Mutex, + notify: Notify, +} + +impl MemoryPool { + pub fn try_new(max_bytes: u64) -> Result, MemoryError> { + if max_bytes > 0 { + let n = usize::try_from(max_bytes) + .map_err(|_| MemoryError::OsAllocationFailed { bytes: max_bytes })?; + let mut v = Vec::::new(); + v.try_reserve_exact(n) + .map_err(|_| MemoryError::OsAllocationFailed { bytes: max_bytes })?; + } + Ok(Arc::new(Self { + max_bytes, + used_bytes: AtomicU64::new(0), + available_bytes: Mutex::new(max_bytes), + notify: Notify::new(), + })) + } + + pub fn new(max_bytes: u64) -> Arc { + Self::try_new(max_bytes).expect("MemoryPool::try_new failed") + } + + pub fn usage_metrics(&self) -> (u64, u64) { + (self.used_bytes.load(Ordering::Relaxed), self.max_bytes) + } + + pub fn try_request_block( + self: &Arc, + bytes: u64, + ) -> Result, MemoryAllocationError> { + if bytes == 0 { + return Ok(MemoryBlock::new(0, self.clone())); + } + if bytes > self.max_bytes { + return Err(MemoryAllocationError::RequestLargerThanPool); + } + let mut available = self.available_bytes.lock(); + if *available >= bytes { + *available -= bytes; + self.used_bytes.fetch_add(bytes, Ordering::Relaxed); + Ok(MemoryBlock::new(bytes, self.clone())) + } else { + Err(MemoryAllocationError::InsufficientCapacity) + } + } + + pub async fn request_block(self: &Arc, bytes: u64) -> Arc { + if bytes == 0 { + return MemoryBlock::new(0, self.clone()); + } + + if bytes > self.max_bytes { + warn!( + request_bytes = bytes, + max_bytes = self.max_bytes, + "Requested memory block exceeds total pool size! \ + Permitting to avoid pipeline deadlock, but critical OOM risk exists." + ); + self.used_bytes.fetch_add(bytes, Ordering::Relaxed); + return MemoryBlock::new(bytes, self.clone()); + } + + loop { + { + let mut available = self.available_bytes.lock(); + if *available >= bytes { + *available -= bytes; + self.used_bytes.fetch_add(bytes, Ordering::Relaxed); + return MemoryBlock::new(bytes, self.clone()); + } + } + + debug!( + bytes = bytes, + "Global backpressure engaged: waiting for memory..." + ); + self.notify.notified().await; + } + } + + pub fn force_reserve(&self, bytes: u64) { + if bytes == 0 { + return; + } + let mut available = self.available_bytes.lock(); + *available = available.saturating_sub(bytes); + self.used_bytes.fetch_add(bytes, Ordering::Relaxed); + } + + pub fn force_release(&self, bytes: u64) { + if bytes == 0 { + return; + } + self.release_block(bytes); + } + + pub(crate) fn release_block(&self, bytes: u64) { + if bytes == 0 { + return; + } + + { + let mut available = self.available_bytes.lock(); + *available += bytes; + } + + self.used_bytes.fetch_sub(bytes, Ordering::Relaxed); + self.notify.notify_waiters(); + } +} diff --git a/src/runtime/streaming/memory/ticket.rs b/src/runtime/memory/ticket.rs similarity index 71% rename from src/runtime/streaming/memory/ticket.rs rename to src/runtime/memory/ticket.rs index cb105be0..24362e2f 100644 --- a/src/runtime/streaming/memory/ticket.rs +++ b/src/runtime/memory/ticket.rs @@ -1,5 +1,6 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. +// // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 @@ -12,22 +13,27 @@ use std::sync::Arc; -use super::pool::MemoryPool; +use super::block::MemoryBlock; #[derive(Debug)] pub struct MemoryTicket { - bytes: usize, - pool: Arc, + bytes: u64, + block: Arc, } impl MemoryTicket { - pub(crate) fn new(bytes: usize, pool: Arc) -> Self { - Self { bytes, pool } + pub(crate) fn new(bytes: u64, block: Arc) -> Self { + Self { bytes, block } + } + + #[inline] + pub fn bytes(&self) -> u64 { + self.bytes } } impl Drop for MemoryTicket { fn drop(&mut self) { - self.pool.release(self.bytes); + self.block.release_ticket(self.bytes); } } diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 1ba5e2a3..8c72b507 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -12,11 +12,9 @@ // Runtime module -pub mod buffer_and_event; pub mod common; +pub mod memory; pub mod streaming; -pub mod task; -pub mod taskexecutor; pub mod util; pub mod wasm; diff --git a/src/runtime/streaming/api/context.rs b/src/runtime/streaming/api/context.rs index f9dc805e..8b778502 100644 --- a/src/runtime/streaming/api/context.rs +++ b/src/runtime/streaming/api/context.rs @@ -10,15 +10,20 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, anyhow}; use arrow_array::RecordBatch; +use protocol::storage::SourceCheckpointPayload; +use tokio::sync::mpsc; -use crate::runtime::streaming::memory::MemoryPool; +use crate::runtime::memory::{MemoryBlock, MemoryPool, get_array_memory_size}; use crate::runtime::streaming::network::endpoint::PhysicalSender; +use crate::runtime::streaming::protocol::control::JobMasterEvent; use crate::runtime::streaming::protocol::event::{StreamEvent, TrackedEvent}; +use crate::runtime::streaming::state::IoManager; #[derive(Debug, Clone)] pub struct TaskContextConfig { @@ -53,7 +58,7 @@ pub struct TaskContext { /// Downstream physical senders (outbound edges). downstream_senders: Vec, - /// Global memory pool for back-pressure and accounting. + /// Job-wide shared pool; memory is accounted only when [`Self::collect`] / [`Self::collect_keyed`] run. memory_pool: Arc, /// Latest aligned event-time watermark for this subtask. @@ -61,9 +66,24 @@ pub struct TaskContext { /// Subtask-level tunables. config: TaskContextConfig, + + pub state_dir: PathBuf, + pub io_manager: IoManager, + + /// Pipeline-wide slab from the global pool; each stateful operator sub-allocates a ticket. + pub pipeline_state_memory_block: Option>, + /// Bytes reserved per stateful operator from [`Self::pipeline_state_memory_block`]. + pub operator_state_memory_bytes: u64, + + /// Last globally-committed safe epoch for crash recovery. + safe_epoch: u64, + + /// When set, pipelines report checkpoint completion (and optional Kafka offsets) to the job coordinator. + checkpoint_ack_tx: Option>, } impl TaskContext { + #[allow(clippy::too_many_arguments)] pub fn new( job_id: String, pipeline_id: u32, @@ -71,6 +91,12 @@ impl TaskContext { parallelism: u32, downstream_senders: Vec, memory_pool: Arc, + io_manager: IoManager, + state_dir: PathBuf, + pipeline_state_memory_block: Option>, + operator_state_memory_bytes: u64, + safe_epoch: u64, + checkpoint_ack_tx: Option>, ) -> Self { let task_name = format!( "Task-[{}]-Pipe[{}]-Sub[{}/{}]", @@ -87,6 +113,34 @@ impl TaskContext { memory_pool, current_watermark: None, config: TaskContextConfig::default(), + state_dir, + io_manager, + pipeline_state_memory_block, + operator_state_memory_bytes, + safe_epoch, + checkpoint_ack_tx, + } + } + + #[inline] + pub fn latest_safe_epoch(&self) -> u64 { + self.safe_epoch + } + + /// Notify the job checkpoint coordinator that this pipeline has finished the barrier for `epoch`. + pub async fn send_checkpoint_ack( + &self, + epoch: u64, + source_payloads: Vec, + ) { + if let Some(tx) = &self.checkpoint_ack_tx { + let _ = tx + .send(JobMasterEvent::CheckpointAck { + pipeline_id: self.pipeline_id, + epoch, + source_payloads, + }) + .await; } } @@ -122,13 +176,19 @@ impl TaskContext { // ------------------------------------------------------------------------- /// Fan-out a data batch to all downstreams (forward / broadcast). + /// + /// Back-pressure and memory accounting happen here via [`MemoryPool::request_block`], not + /// when building the pipeline. pub async fn collect(&self, batch: RecordBatch) -> Result<()> { if self.downstream_senders.is_empty() { return Ok(()); } - let bytes_required = batch.get_array_memory_size(); - let ticket = self.memory_pool.request_memory(bytes_required).await; + let bytes_required = get_array_memory_size(&batch); + let block = self.memory_pool.request_block(bytes_required).await; + let ticket = block + .try_allocate(bytes_required) + .ok_or_else(|| anyhow!("memory block allocation failed"))?; let tracked_event = TrackedEvent::new(StreamEvent::Data(batch), Some(ticket)); self.broadcast_event(tracked_event).await @@ -141,8 +201,11 @@ impl TaskContext { return Ok(()); } - let bytes_required = batch.get_array_memory_size(); - let ticket = self.memory_pool.request_memory(bytes_required).await; + let bytes_required = get_array_memory_size(&batch); + let block = self.memory_pool.request_block(bytes_required).await; + let ticket = block + .try_allocate(bytes_required) + .ok_or_else(|| anyhow!("memory block allocation failed"))?; let event = TrackedEvent::new(StreamEvent::Data(batch), Some(ticket)); let target_idx = (key_hash as usize) % num_downstreams; diff --git a/src/runtime/streaming/api/operator.rs b/src/runtime/streaming/api/operator.rs index df8f0dcb..8eb9e8c4 100644 --- a/src/runtime/streaming/api/operator.rs +++ b/src/runtime/streaming/api/operator.rs @@ -16,7 +16,6 @@ use crate::runtime::streaming::protocol::event::StreamOutput; use crate::sql::common::{CheckpointBarrier, Watermark}; use arrow_array::RecordBatch; use async_trait::async_trait; -use std::time::Duration; // --------------------------------------------------------------------------- // ConstructedOperator @@ -27,6 +26,11 @@ pub enum ConstructedOperator { Operator(Box), } +#[async_trait] +pub trait Collector: Send { + async fn collect(&mut self, out: StreamOutput, ctx: &mut TaskContext) -> anyhow::Result<()>; +} + #[async_trait] pub trait Operator: Send + 'static { fn name(&self) -> &str; @@ -40,13 +44,15 @@ pub trait Operator: Send + 'static { input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> anyhow::Result>; + collector: &mut dyn Collector, + ) -> anyhow::Result<()>; async fn process_watermark( &mut self, watermark: Watermark, ctx: &mut TaskContext, - ) -> anyhow::Result>; + collector: &mut dyn Collector, + ) -> anyhow::Result<()>; async fn snapshot_state( &mut self, @@ -54,24 +60,25 @@ pub trait Operator: Send + 'static { ctx: &mut TaskContext, ) -> anyhow::Result<()>; + /// Global checkpoint **phase 2** (after metadata is durable): finalize external side effects. + /// + /// Default is no-op. Examples of overrides: transactional Kafka sink calls + /// `commit_transaction` on the producer stashed during [`Self::snapshot_state`]. async fn commit_checkpoint( &mut self, - _epoch: u32, + epoch: u32, _ctx: &mut TaskContext, ) -> anyhow::Result<()> { + let _ = epoch; Ok(()) } - fn tick_interval(&self) -> Option { - None - } - - async fn process_tick( - &mut self, - _tick_index: u64, - _ctx: &mut TaskContext, - ) -> anyhow::Result> { - Ok(vec![]) + /// Global checkpoint **rollback** when phase 2 must not commit (e.g. catalog persist failed). + /// + /// Default is no-op. Transactional Kafka sink overrides with `abort_transaction` on the stashed producer. + async fn abort_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> anyhow::Result<()> { + let _ = epoch; + Ok(()) } async fn on_close(&mut self, _ctx: &mut TaskContext) -> anyhow::Result> { diff --git a/src/runtime/streaming/api/source.rs b/src/runtime/streaming/api/source.rs index 81435b47..26851eb2 100644 --- a/src/runtime/streaming/api/source.rs +++ b/src/runtime/streaming/api/source.rs @@ -14,6 +14,9 @@ use crate::runtime::streaming::api::context::TaskContext; use crate::sql::common::{CheckpointBarrier, Watermark}; use arrow_array::RecordBatch; use async_trait::async_trait; +use protocol::storage::{ + KafkaSourceSubtaskCheckpoint, SourceCheckpointPayload, source_checkpoint_payload, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum SourceOffset { @@ -31,6 +34,22 @@ pub enum SourceEvent { EndOfStream, } +/// Optional metadata returned when a source completes a checkpoint barrier snapshot. +#[derive(Debug, Default, Clone)] +pub struct SourceCheckpointReport { + pub payloads: Vec, +} + +impl SourceCheckpointReport { + pub fn from_kafka_checkpoint(kafka: KafkaSourceSubtaskCheckpoint) -> Self { + Self { + payloads: vec![SourceCheckpointPayload { + checkpoint: Some(source_checkpoint_payload::Checkpoint::Kafka(kafka)), + }], + } + } +} + #[async_trait] pub trait SourceOperator: Send + 'static { fn name(&self) -> &str; @@ -49,13 +68,22 @@ pub trait SourceOperator: Send + 'static { &mut self, barrier: CheckpointBarrier, ctx: &mut TaskContext, - ) -> anyhow::Result<()>; + ) -> anyhow::Result; + /// Same checkpoint **phase 2** hook as [`super::operator::Operator::commit_checkpoint`]. + /// Kafka source keeps the default: offsets are reported at the barrier in [`Self::snapshot_state`]. async fn commit_checkpoint( &mut self, - _epoch: u32, + epoch: u32, _ctx: &mut TaskContext, ) -> anyhow::Result<()> { + let _ = epoch; + Ok(()) + } + + /// Same rollback hook as [`super::operator::Operator::abort_checkpoint`]. + async fn abort_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> anyhow::Result<()> { + let _ = epoch; Ok(()) } diff --git a/src/runtime/streaming/execution/operator_chain.rs b/src/runtime/streaming/execution/operator_chain.rs index a2e6c5c6..88e8f441 100644 --- a/src/runtime/streaming/execution/operator_chain.rs +++ b/src/runtime/streaming/execution/operator_chain.rs @@ -10,10 +10,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use anyhow::anyhow; use async_trait::async_trait; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::error::RunError; use crate::runtime::streaming::protocol::{ control::{ControlCommand, StopMode}, @@ -21,29 +22,39 @@ use crate::runtime::streaming::protocol::{ }; use crate::sql::common::CheckpointBarrier; +// ============================================================================ +// Core Traits +// ============================================================================ + #[async_trait] pub trait OperatorDrive: Send { async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<(), RunError>; + async fn process_event( &mut self, input_idx: usize, event: TrackedEvent, ctx: &mut TaskContext, ) -> Result; + async fn handle_control( &mut self, cmd: ControlCommand, ctx: &mut TaskContext, ) -> Result; + async fn on_close(&mut self, ctx: &mut TaskContext) -> Result<(), RunError>; } +// ============================================================================ +// Chain Builder +// ============================================================================ + pub struct ChainBuilder; impl ChainBuilder { pub fn build(mut operators: Vec>) -> Option> { let tail_operator = operators.pop()?; - let mut current_driver: Box = Box::new(TailDriver::new(tail_operator)); while let Some(op) = operators.pop() { @@ -54,6 +65,68 @@ impl ChainBuilder { } } +// ============================================================================ +// Collectors (Zero-Allocation Emission Abstractions) +// ============================================================================ + +struct ChainedCollector<'a> { + next: &'a mut dyn OperatorDrive, + op_name: String, +} + +impl<'a> ChainedCollector<'a> { + fn new(next: &'a mut dyn OperatorDrive, op_name: &str) -> Self { + Self { + next, + op_name: op_name.to_string(), + } + } +} + +#[async_trait] +impl<'a> Collector for ChainedCollector<'a> { + async fn collect(&mut self, out: StreamOutput, ctx: &mut TaskContext) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => { + self.next + .process_event(0, TrackedEvent::control(StreamEvent::Data(b)), ctx) + .await?; + } + StreamOutput::Watermark(wm) => { + self.next + .process_event(0, TrackedEvent::control(StreamEvent::Watermark(wm)), ctx) + .await?; + } + StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { + return Err(anyhow!( + "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", + self.op_name + )); + } + } + Ok(()) + } +} + +struct TaskCollector; + +#[async_trait] +impl Collector for TaskCollector { + async fn collect(&mut self, out: StreamOutput, ctx: &mut TaskContext) -> anyhow::Result<()> { + match out { + StreamOutput::Forward(b) => ctx.collect(b).await?, + StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, + StreamOutput::Broadcast(b) => ctx.collect(b).await?, + StreamOutput::Watermark(wm) => ctx.broadcast(StreamEvent::Watermark(wm)).await?, + } + Ok(()) + } +} + +// ============================================================================ +// Intermediate Driver (Middle of the Chain) +// ============================================================================ + pub struct IntermediateDriver { operator: Box, next: Box, @@ -64,34 +137,6 @@ impl IntermediateDriver { Self { operator, next } } - async fn dispatch_outputs( - &mut self, - outputs: Vec, - ctx: &mut TaskContext, - ) -> Result<(), RunError> { - for out in outputs { - match out { - StreamOutput::Forward(b) => { - self.next - .process_event(0, TrackedEvent::control(StreamEvent::Data(b)), ctx) - .await?; - } - StreamOutput::Watermark(wm) => { - self.next - .process_event(0, TrackedEvent::control(StreamEvent::Watermark(wm)), ctx) - .await?; - } - StreamOutput::Keyed(_, _) | StreamOutput::Broadcast(_) => { - return Err(RunError::internal(format!( - "Topology Violation: Keyed or Broadcast output emitted in the middle of chain by '{}'", - self.operator.name() - ))); - } - } - } - Ok(()) - } - async fn forward_signal( &mut self, event: StreamEvent, @@ -120,13 +165,17 @@ impl OperatorDrive for IntermediateDriver { ) -> Result { match tracked.event { StreamEvent::Data(batch) => { - let outputs = self.operator.process_data(input_idx, batch, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + let mut collector = ChainedCollector::new(self.next.as_mut(), self.operator.name()); + self.operator + .process_data(input_idx, batch, ctx, &mut collector) + .await?; Ok(false) } StreamEvent::Watermark(wm) => { - let outputs = self.operator.process_watermark(wm, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + let mut collector = ChainedCollector::new(self.next.as_mut(), self.operator.name()); + self.operator + .process_watermark(wm, ctx, &mut collector) + .await?; self.forward_signal(StreamEvent::Watermark(wm), ctx).await?; Ok(false) } @@ -152,12 +201,16 @@ impl OperatorDrive for IntermediateDriver { match &cmd { ControlCommand::TriggerCheckpoint { barrier } => { - let b: CheckpointBarrier = barrier.clone().into(); - self.operator.snapshot_state(b, ctx).await?; + self.operator + .snapshot_state(barrier.clone().into(), ctx) + .await?; } ControlCommand::Commit { epoch } => { self.operator.commit_checkpoint(*epoch, ctx).await?; } + ControlCommand::AbortCheckpoint { epoch } => { + self.operator.abort_checkpoint(*epoch, ctx).await?; + } ControlCommand::Stop { mode } if *mode == StopMode::Immediate => { stop = true; } @@ -173,12 +226,22 @@ impl OperatorDrive for IntermediateDriver { async fn on_close(&mut self, ctx: &mut TaskContext) -> Result<(), RunError> { let close_outs = self.operator.on_close(ctx).await?; - self.dispatch_outputs(close_outs, ctx).await?; + let mut collector = ChainedCollector::new(self.next.as_mut(), self.operator.name()); + + // 复用 Collector 处理 on_close 产生的数据 + for out in close_outs { + collector.collect(out, ctx).await?; + } + self.next.on_close(ctx).await?; Ok(()) } } +// ============================================================================ +// Tail Driver (End of the Chain) +// ============================================================================ + pub struct TailDriver { operator: Box, } @@ -188,22 +251,6 @@ impl TailDriver { Self { operator } } - async fn dispatch_outputs( - &mut self, - outputs: Vec, - ctx: &mut TaskContext, - ) -> Result<(), RunError> { - for out in outputs { - match out { - StreamOutput::Forward(b) => ctx.collect(b).await?, - StreamOutput::Keyed(hash, b) => ctx.collect_keyed(hash, b).await?, - StreamOutput::Broadcast(b) => ctx.collect(b).await?, - StreamOutput::Watermark(wm) => ctx.broadcast(StreamEvent::Watermark(wm)).await?, - } - } - Ok(()) - } - async fn forward_signal( &mut self, event: StreamEvent, @@ -234,13 +281,17 @@ impl OperatorDrive for TailDriver { ) -> Result { match tracked.event { StreamEvent::Data(batch) => { - let outputs = self.operator.process_data(input_idx, batch, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + let mut collector = TaskCollector; + self.operator + .process_data(input_idx, batch, ctx, &mut collector) + .await?; Ok(false) } StreamEvent::Watermark(wm) => { - let outputs = self.operator.process_watermark(wm, ctx).await?; - self.dispatch_outputs(outputs, ctx).await?; + let mut collector = TaskCollector; + self.operator + .process_watermark(wm, ctx, &mut collector) + .await?; self.forward_signal(StreamEvent::Watermark(wm), ctx).await?; Ok(false) } @@ -273,6 +324,9 @@ impl OperatorDrive for TailDriver { ControlCommand::Commit { epoch } => { self.operator.commit_checkpoint(*epoch, ctx).await?; } + ControlCommand::AbortCheckpoint { epoch } => { + self.operator.abort_checkpoint(*epoch, ctx).await?; + } ControlCommand::Stop { mode } if *mode == StopMode::Immediate => { stop = true; } @@ -284,7 +338,11 @@ impl OperatorDrive for TailDriver { async fn on_close(&mut self, ctx: &mut TaskContext) -> Result<(), RunError> { let close_outs = self.operator.on_close(ctx).await?; - self.dispatch_outputs(close_outs, ctx).await?; + let mut collector = TaskCollector; + + for out in close_outs { + collector.collect(out, ctx).await?; + } Ok(()) } } diff --git a/src/runtime/streaming/execution/pipeline.rs b/src/runtime/streaming/execution/pipeline.rs index d6ef06a3..91309a48 100644 --- a/src/runtime/streaming/execution/pipeline.rs +++ b/src/runtime/streaming/execution/pipeline.rs @@ -10,7 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::UnboundedReceiver; use tokio_stream::{StreamExt, StreamMap}; use tracing::{Instrument, info, info_span}; @@ -33,7 +33,7 @@ pub struct Pipeline { chain_head: Box, ctx: TaskContext, inboxes: Vec, - control_rx: Receiver, + control_rx: UnboundedReceiver, wm_tracker: WatermarkTracker, barrier_aligner: BarrierAligner, @@ -45,7 +45,7 @@ impl Pipeline { operators: Vec>, ctx: TaskContext, inboxes: Vec, - control_rx: Receiver, + control_rx: UnboundedReceiver, ) -> Result { let input_count = inboxes.len(); let chain_head = ChainBuilder::build(operators) @@ -110,6 +110,7 @@ impl Pipeline { } } AlignmentStatus::Complete => { + let epoch = barrier.epoch as u64; self.chain_head .process_event( idx, @@ -123,6 +124,7 @@ impl Pipeline { active_streams.insert(i, stream); } } + self.ctx.send_checkpoint_ack(epoch, vec![]).await; } } } diff --git a/src/runtime/streaming/execution/source_driver.rs b/src/runtime/streaming/execution/source_driver.rs index 6813a82a..b4e7d327 100644 --- a/src/runtime/streaming/execution/source_driver.rs +++ b/src/runtime/streaming/execution/source_driver.rs @@ -10,12 +10,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc::UnboundedReceiver; use tokio::time::{Instant, sleep}; use tracing::{Instrument, info, info_span, warn}; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::source::{SourceEvent, SourceOperator}; +use crate::runtime::streaming::api::source::{SourceCheckpointReport, SourceEvent, SourceOperator}; use crate::runtime::streaming::error::RunError; use crate::runtime::streaming::execution::OperatorDrive; use crate::runtime::streaming::protocol::{ @@ -28,7 +28,7 @@ pub struct SourceDriver { operator: Box, chain_head: Option>, ctx: TaskContext, - control_rx: Receiver, + control_rx: UnboundedReceiver, } impl SourceDriver { @@ -36,7 +36,7 @@ impl SourceDriver { operator: Box, chain_head: Option>, ctx: TaskContext, - control_rx: Receiver, + control_rx: UnboundedReceiver, ) -> Self { Self { operator, @@ -154,18 +154,25 @@ impl SourceDriver { async fn handle_control(&mut self, cmd: ControlCommand) -> Result { let mut stop = false; + let mut pending_source_checkpoint: Option<(u64, SourceCheckpointReport)> = None; match &cmd { ControlCommand::TriggerCheckpoint { barrier } => { let b: CheckpointBarrier = barrier.clone().into(); - self.operator.snapshot_state(b, &mut self.ctx).await?; + let report = self.operator.snapshot_state(b, &mut self.ctx).await?; self.dispatch_event(StreamEvent::Barrier(b)).await?; + pending_source_checkpoint = Some((b.epoch as u64, report)); } ControlCommand::Commit { epoch } => { self.operator .commit_checkpoint(*epoch, &mut self.ctx) .await?; } + ControlCommand::AbortCheckpoint { epoch } => { + self.operator + .abort_checkpoint(*epoch, &mut self.ctx) + .await?; + } ControlCommand::Stop { .. } => { stop = true; } @@ -178,6 +185,10 @@ impl SourceDriver { stop = true; } + if let Some((epoch, report)) = pending_source_checkpoint { + self.ctx.send_checkpoint_ack(epoch, report.payloads).await; + } + Ok(stop) } diff --git a/src/runtime/streaming/factory/connector/kafka.rs b/src/runtime/streaming/factory/connector/kafka.rs index 75135197..9d2f114d 100644 --- a/src/runtime/streaming/factory/connector/kafka.rs +++ b/src/runtime/streaming/factory/connector/kafka.rs @@ -200,7 +200,13 @@ impl KafkaConnectorDispatcher { let client_configs = merge_client_configs(&cfg.auth, &cfg.client_configs); let consistency = match cfg.commit_mode() { - KafkaSinkCommitMode::KafkaSinkExactlyOnce => ConsistencyMode::ExactlyOnce, + KafkaSinkCommitMode::KafkaSinkExactlyOnce => { + info!( + topic = %cfg.topic, + "Kafka sink exactly-once: transactional producer + checkpoint 2PC. Downstream Kafka consumers of this topic should set isolation.level=read_committed." + ); + ConsistencyMode::ExactlyOnce + } KafkaSinkCommitMode::KafkaSinkAtLeastOnce => ConsistencyMode::AtLeastOnce, }; diff --git a/src/runtime/streaming/job/job_manager.rs b/src/runtime/streaming/job/job_manager.rs index 3082dc56..a9bc546f 100644 --- a/src/runtime/streaming/job/job_manager.rs +++ b/src/runtime/streaming/job/job_manager.rs @@ -10,16 +10,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; -use std::sync::{Arc, OnceLock, RwLock}; +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::sync::{Arc, Mutex, OnceLock, RwLock}; +use std::time::Duration; use anyhow::{Context, Result, anyhow, bail, ensure}; -use tokio::sync::mpsc; +use tokio::sync::mpsc::{self, UnboundedSender}; +use tokio::task::JoinHandle as TokioJoinHandle; +use tokio::time::Instant; use tokio_stream::wrappers::ReceiverStream; use tracing::{error, info, warn}; use protocol::function_stream_graph::{ChainedOperator, FsProgram}; +use protocol::storage::{ + KafkaSourceSubtaskCheckpoint, SourceCheckpointPayload, source_checkpoint_payload, +}; +use crate::config::{ + DEFAULT_CHECKPOINT_INTERVAL_MS, DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, + DEFAULT_PIPELINE_PARALLELISM, +}; +use crate::runtime::memory::global_memory_pool; use crate::runtime::streaming::api::context::TaskContext; use crate::runtime::streaming::api::operator::{ConstructedOperator, Operator}; use crate::runtime::streaming::api::source::SourceOperator; @@ -29,9 +42,12 @@ use crate::runtime::streaming::job::edge_manager::EdgeManager; use crate::runtime::streaming::job::models::{ PhysicalExecutionGraph, PhysicalPipeline, PipelineStatus, StreamingJobRollupStatus, }; -use crate::runtime::streaming::memory::MemoryPool; use crate::runtime::streaming::network::endpoint::{BoxedEventStream, PhysicalSender}; -use crate::runtime::streaming::protocol::control::{ControlCommand, StopMode}; +use crate::runtime::streaming::protocol::control::{ControlCommand, JobMasterEvent, StopMode}; +use crate::runtime::streaming::protocol::event::CheckpointBarrier; +use crate::runtime::streaming::state::{IoManager, IoPool, NoopMetricsCollector}; +use crate::sql::logical_node::logical::OperatorName; +use crate::storage::stream_catalog::CatalogManager; #[derive(Debug, Clone)] pub struct StreamingJobSummary { @@ -57,12 +73,70 @@ pub struct StreamingJobDetail { pub program: FsProgram, } +#[derive(Debug, Clone)] +pub struct StateConfig { + pub max_background_spills: usize, + pub max_background_compactions: usize, + pub soft_limit_ratio: f64, + pub checkpoint_interval_ms: u64, + pub pipeline_parallelism: u32, + pub job_manager_control_plane_threads: u32, + pub job_manager_data_plane_threads: u32, + /// Total bytes shared by all [`crate::runtime::streaming::state::OperatorStateStore`] (global pool). + pub per_operator_memory_bytes: u64, +} + +impl Default for StateConfig { + fn default() -> Self { + Self { + max_background_spills: 4, + max_background_compactions: 2, + soft_limit_ratio: 0.7, + checkpoint_interval_ms: DEFAULT_CHECKPOINT_INTERVAL_MS, + pipeline_parallelism: DEFAULT_PIPELINE_PARALLELISM, + job_manager_control_plane_threads: 2, + job_manager_data_plane_threads: std::thread::available_parallelism() + .map(|n| n.get() as u32) + .unwrap_or(1), + per_operator_memory_bytes: DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, + } + } +} + static GLOBAL_JOB_MANAGER: OnceLock> = OnceLock::new(); +/// Operators that create an [`crate::runtime::streaming::state::OperatorStateStore`] at runtime. +fn pipeline_state_store_operator_count(operators: &[ChainedOperator]) -> usize { + operators + .iter() + .filter(|op| { + OperatorName::from_str(op.operator_name.as_str()) + .ok() + .is_some_and(|n| { + matches!( + n, + OperatorName::Join + | OperatorName::InstantJoin + | OperatorName::WindowFunction + | OperatorName::TumblingWindowAggregate + | OperatorName::SlidingWindowAggregate + | OperatorName::SessionWindowAggregate + | OperatorName::UpdatingAggregate + ) + }) + }) + .count() +} + pub struct JobManager { active_jobs: Arc>>, operator_factory: Arc, - memory_pool: Arc, + io_manager_client: IoManager, + io_pool: Mutex>, + state_base_dir: PathBuf, + state_config: StateConfig, + control_rt: Arc, + data_rt: Arc, } struct PreparedChain { @@ -75,6 +149,18 @@ enum PipelineRunner { Standard(Pipeline), } +struct CheckpointCoordinatorConfig { + job_id: String, + source_control_txs: Vec>, + all_pipeline_control_txs: Vec>, + job_master_rx: mpsc::Receiver, + expected_pipeline_ids: HashSet, + interval_ms: u64, + start_epoch: u64, + job_state_dir: PathBuf, + timeout: Duration, +} + impl PipelineRunner { async fn run(self) -> Result<(), crate::runtime::streaming::error::RunError> { match self { @@ -84,18 +170,70 @@ impl PipelineRunner { } } +fn decode_kafka_checkpoints_from_source_payloads( + payloads: Vec, + epoch: u64, +) -> Vec { + let mut out = Vec::new(); + for p in payloads { + match p.checkpoint { + Some(source_checkpoint_payload::Checkpoint::Kafka(mut cp)) => { + if cp.checkpoint_epoch != epoch { + cp.checkpoint_epoch = epoch; + } + out.push(cp); + } + None => warn!("Skip empty source checkpoint payload"), + } + } + out +} + impl JobManager { - pub fn new(operator_factory: Arc, max_memory_bytes: usize) -> Self { - Self { + pub fn new( + operator_factory: Arc, + state_base_dir: impl AsRef, + state_config: StateConfig, + ) -> Result { + let control_rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(state_config.job_manager_control_plane_threads.max(1) as usize) + .thread_name("fs-control-plane") + .enable_all() + .build() + .context("Failed to initialize control runtime")?; + let data_rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(state_config.job_manager_data_plane_threads.max(1) as usize) + .thread_name("fs-data-plane") + .enable_all() + .build() + .context("Failed to initialize data runtime")?; + let metrics = Arc::new(NoopMetricsCollector); + let (io_pool, io_manager_client) = IoPool::try_new( + state_config.max_background_spills, + state_config.max_background_compactions, + metrics, + ) + .context("Failed to initialize state engine I/O pool")?; + + Ok(Self { active_jobs: Arc::new(RwLock::new(HashMap::new())), operator_factory, - memory_pool: MemoryPool::new(max_memory_bytes), - } + io_manager_client, + io_pool: Mutex::new(Some(io_pool)), + state_base_dir: state_base_dir.as_ref().to_path_buf(), + state_config, + control_rt: Arc::new(control_rt), + data_rt: Arc::new(data_rt), + }) } - pub fn init(factory: Arc, memory_bytes: usize) -> Result<()> { + pub fn init( + factory: Arc, + state_base_dir: PathBuf, + state_config: StateConfig, + ) -> Result<()> { GLOBAL_JOB_MANAGER - .set(Arc::new(Self::new(factory, memory_bytes))) + .set(Arc::new(Self::new(factory, state_base_dir, state_config)?)) .map_err(|_| anyhow!("JobManager singleton already initialized")) } @@ -106,19 +244,57 @@ impl JobManager { .ok_or_else(|| anyhow!("JobManager not initialized. Call init() first.")) } - pub async fn submit_job(&self, job_id: String, program: FsProgram) -> Result { + pub fn shutdown(&self) { + if let Some(pool) = self.io_pool.lock().unwrap().take() { + pool.shutdown(); + } + } + + #[inline] + pub fn default_pipeline_parallelism(&self) -> u32 { + self.state_config.pipeline_parallelism + } + + /// Per-job state directory (Kafka offset snapshots, operator state roots, etc.). + #[inline] + pub fn job_state_directory(&self, job_id: &str) -> PathBuf { + self.state_base_dir.join(job_id) + } + + pub async fn submit_job( + &self, + job_id: String, + program: FsProgram, + custom_checkpoint_interval_ms: Option, + recovery_epoch: Option, + ) -> Result { let mut edge_manager = EdgeManager::build(&program.nodes, &program.edges); let mut pipelines = HashMap::with_capacity(program.nodes.len()); + let mut source_control_txs = Vec::new(); + let mut all_pipeline_control_txs = Vec::new(); + let mut expected_pipeline_ids = HashSet::new(); + + let job_state_dir = self.state_base_dir.join(&job_id); + std::fs::create_dir_all(&job_state_dir).context("Failed to create job state dir")?; + + let (job_master_tx, job_master_rx) = mpsc::channel(256); + + let safe_epoch = recovery_epoch.unwrap_or(0); + for node in &program.nodes { let pipeline_id = node.node_index as u32; - let pipeline = self + let (pipeline, is_source) = self .build_and_spawn_pipeline( job_id.clone(), pipeline_id, &node.operators, + node.parallelism, &mut edge_manager, + &job_state_dir, + job_master_tx.clone(), + safe_epoch, ) .with_context(|| { format!( @@ -127,9 +303,29 @@ impl JobManager { ) })?; + if is_source { + source_control_txs.push(pipeline.control_tx.clone()); + } + all_pipeline_control_txs.push(pipeline.control_tx.clone()); + expected_pipeline_ids.insert(pipeline_id); pipelines.insert(pipeline_id, pipeline); } + let interval_ms = + custom_checkpoint_interval_ms.unwrap_or(self.state_config.checkpoint_interval_ms); + + self.spawn_checkpoint_coordinator(CheckpointCoordinatorConfig { + job_id: job_id.clone(), + source_control_txs, + all_pipeline_control_txs, + job_master_rx, + expected_pipeline_ids, + interval_ms, + start_epoch: safe_epoch + 1, + job_state_dir: job_state_dir.clone(), + timeout: Duration::from_millis(interval_ms.max(1) * 3), + }); + let graph = PhysicalExecutionGraph { job_id: job_id.clone(), program, @@ -143,7 +339,7 @@ impl JobManager { .map_err(|e| anyhow!("Active jobs lock poisoned: {}", e))?; jobs_guard.insert(job_id.clone(), graph); - info!(job_id = %job_id, "Job submitted successfully."); + info!(job_id = %job_id, interval_ms, recovery_epoch = safe_epoch, "Job submitted successfully."); Ok(job_id) } @@ -151,7 +347,7 @@ impl JobManager { let control_senders = self.extract_control_senders(job_id)?; for tx in control_senders { - let _ = tx.send(ControlCommand::Stop { mode: mode.clone() }).await; + let _ = tx.send(ControlCommand::Stop { mode: mode.clone() }); } info!(job_id = %job_id, mode = ?mode, "Job stop signal dispatched."); @@ -303,7 +499,10 @@ impl JobManager { StreamingJobRollupStatus::Reconciling } } - fn extract_control_senders(&self, job_id: &str) -> Result>> { + fn extract_control_senders( + &self, + job_id: &str, + ) -> Result>> { let jobs_guard = self .active_jobs .read() @@ -320,13 +519,18 @@ impl JobManager { .collect()) } + #[allow(clippy::too_many_arguments)] fn build_and_spawn_pipeline( &self, job_id: String, pipeline_id: u32, operators: &[ChainedOperator], + declared_parallelism: u32, edge_manager: &mut EdgeManager, - ) -> Result { + job_state_dir: &Path, + job_master_tx: mpsc::Sender, + recovery_epoch: u64, + ) -> Result<(PhysicalPipeline, bool)> { let (raw_inboxes, raw_outboxes) = edge_manager.take_endpoints(pipeline_id).with_context(|| { format!( @@ -352,6 +556,8 @@ impl JobManager { ) })?; + let is_source = chain.source.is_some(); + ensure!( chain.source.is_some() || !physical_inboxes.is_empty(), "Topology Error: Pipeline '{}' contains no source and has no upstream inputs (Dead end).", @@ -363,18 +569,45 @@ impl JobManager { pipeline_id ); - let (control_tx, control_rx) = mpsc::channel(64); + let (control_tx, control_rx) = mpsc::unbounded_channel(); let status = Arc::new(RwLock::new(PipelineStatus::Initializing)); let subtask_index = 0; - let parallelism = 1; + let parallelism = if declared_parallelism > 0 { + declared_parallelism + } else { + self.state_config.pipeline_parallelism + } + .max(1); + + let per_op = self.state_config.per_operator_memory_bytes; + let n_state_ops = pipeline_state_store_operator_count(operators); + let pipeline_state_memory_block = if n_state_ops > 0 { + let bytes = per_op + .checked_mul(n_state_ops as u64) + .ok_or_else(|| anyhow!("pipeline state memory byte size overflow"))?; + Some( + global_memory_pool() + .try_request_block(bytes) + .map_err(|e| anyhow!("pipeline state memory reservation failed: {e}"))?, + ) + } else { + None + }; + let ctx = TaskContext::new( job_id.clone(), pipeline_id, subtask_index, parallelism, physical_outboxes, - Arc::clone(&self.memory_pool), + Arc::clone(&global_memory_pool()), + self.io_manager_client.clone(), + job_state_dir.to_path_buf(), + pipeline_state_memory_block, + per_op, + recovery_epoch, + Some(job_master_tx.clone()), ); let runner = if let Some(source) = chain.source { @@ -388,16 +621,15 @@ impl JobManager { ) }; - let handle = self - .spawn_worker_thread(job_id, pipeline_id, runner, Arc::clone(&status)) - .with_context(|| format!("Failed to spawn OS thread for pipeline {}", pipeline_id))?; + let handle = self.spawn_worker_task(job_id, pipeline_id, runner, Arc::clone(&status)); - Ok(PhysicalPipeline { + let pipeline = PhysicalPipeline { pipeline_id, handle: Some(handle), status, control_tx, - }) + }; + Ok((pipeline, is_source)) } fn build_operator_chain(&self, operator_configs: &[ChainedOperator]) -> Result { @@ -431,55 +663,39 @@ impl JobManager { }) } - fn spawn_worker_thread( + fn spawn_worker_task( &self, job_id: String, pipeline_id: u32, runner: PipelineRunner, status: Arc>, - ) -> Result> { - let thread_name = format!("Task-{job_id}-{pipeline_id}"); - - let handle = std::thread::Builder::new() - .name(thread_name) - .spawn(move || { - if let Ok(mut st) = status.write() { - *st = PipelineStatus::Running; - } + ) -> TokioJoinHandle<()> { + self.data_rt.spawn(async move { + if let Ok(mut st) = status.write() { + *st = PipelineStatus::Running; + } - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to build current-thread Tokio runtime"); - - let execution_result = - std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - rt.block_on(async move { - runner - .run() - .await - .map_err(|e| anyhow!("Execution failed: {e}")) - }) - })); - - Self::handle_pipeline_exit(&job_id, pipeline_id, execution_result, &status); - })?; + let execution_result = runner + .run() + .await + .map_err(|e| anyhow!("Execution failed: {e}")); - Ok(handle) + Self::handle_pipeline_exit(&job_id, pipeline_id, execution_result, &status); + }) } fn handle_pipeline_exit( job_id: &str, pipeline_id: u32, - thread_result: std::thread::Result>, + result: Result<()>, status: &RwLock, ) { - let (final_status, is_fatal) = match thread_result { - Ok(Ok(_)) => { + let (final_status, is_fatal) = match result { + Ok(_) => { info!(job_id = %job_id, pipeline_id = pipeline_id, "Pipeline finished gracefully."); (PipelineStatus::Finished, false) } - Ok(Err(e)) => { + Err(e) => { error!(job_id = %job_id, pipeline_id = pipeline_id, error = %e, "Pipeline failed."); ( PipelineStatus::Failed { @@ -489,16 +705,6 @@ impl JobManager { true, ) } - Err(_) => { - error!(job_id = %job_id, pipeline_id = pipeline_id, "Pipeline thread panicked!"); - ( - PipelineStatus::Failed { - error: "Unexpected panic in task thread".into(), - is_panic: true, - }, - true, - ) - } }; if let Ok(mut st) = status.write() { @@ -509,4 +715,173 @@ impl JobManager { warn!(job_id = %job_id, pipeline_id = pipeline_id, "Pipeline failure detected. Job degraded."); } } + + // ======================================================================== + // Chandy-Lamport distributed snapshot barrier coordinator + // ======================================================================== + + fn spawn_checkpoint_coordinator( + &self, + cfg: CheckpointCoordinatorConfig, + ) -> TokioJoinHandle<()> { + self.control_rt.spawn(async move { + let CheckpointCoordinatorConfig { + job_id, + mut source_control_txs, + all_pipeline_control_txs, + mut job_master_rx, + expected_pipeline_ids, + interval_ms, + start_epoch, + job_state_dir, + timeout, + } = cfg; + if interval_ms == 0 { + info!(job_id = %job_id, "Checkpoint disabled for this job"); + return; + } + + let mut interval = tokio::time::interval(Duration::from_millis(interval_ms)); + interval.tick().await; + + let mut current_epoch: u64 = start_epoch; + struct PendingCheckpoint { + epoch: u64, + missing_acks: HashSet, + start_time: Instant, + source_reports: Vec, + } + let mut active_checkpoint: Option = None; + + let broadcast_cmd = |cmd: ControlCommand| { + for tx in &all_pipeline_control_txs { + let _ = tx.send(cmd.clone()); + } + }; + + loop { + tokio::select! { + biased; + + Some(event) = job_master_rx.recv() => { + match event { + JobMasterEvent::CheckpointAck { + pipeline_id, + epoch, + source_payloads, + } => { + if let Some(pending) = &mut active_checkpoint { + if pending.epoch != epoch { + continue; + } + pending.missing_acks.remove(&pipeline_id); + if !source_payloads.is_empty() { + pending.source_reports.extend(source_payloads); + } + + if pending.missing_acks.is_empty() { + info!( + job_id = %job_id, epoch = epoch, + "Checkpoint Epoch is GLOBALLY COMPLETED (phase 1); persisting metadata and notifying operators (phase 2)" + ); + + let completed = active_checkpoint.take().expect("active checkpoint exists"); + let kf = decode_kafka_checkpoints_from_source_payloads(completed.source_reports, epoch); + let epoch_u32 = u32::try_from(epoch).unwrap_or(u32::MAX); + + let mut catalog_ok = true; + if let Some(catalog) = CatalogManager::try_global() { + if let Err(e) = catalog.commit_job_checkpoint( + &job_id, + epoch, + &job_state_dir, + kf, + ) { + catalog_ok = false; + error!( + job_id = %job_id, epoch = epoch, + error = %e, + "Failed to commit checkpoint metadata to Catalog — aborting transactional sinks" + ); + } + } else { + warn!( + job_id = %job_id, epoch = epoch, + "CatalogManager not available; proceeding with operator Commit (Kafka transactional commit) only" + ); + } + + let phase2 = if catalog_ok { + ControlCommand::Commit { epoch: epoch_u32 } + } else { + ControlCommand::AbortCheckpoint { epoch: epoch_u32 } + }; + broadcast_cmd(phase2); + } + } + } + JobMasterEvent::CheckpointDecline { pipeline_id, epoch, reason } => { + if let Some(pending) = &active_checkpoint + && pending.epoch == epoch + { + error!( + job_id = %job_id, epoch = epoch, pipeline_id = pipeline_id, + reason = %reason, "Checkpoint FAILED!" + ); + broadcast_cmd(ControlCommand::AbortCheckpoint { + epoch: u32::try_from(epoch).unwrap_or(u32::MAX), + }); + active_checkpoint = None; + } + } + } + } + + _ = interval.tick() => { + if let Some(pending) = &active_checkpoint { + if pending.start_time.elapsed() > timeout { + warn!( + job_id = %job_id, + epoch = pending.epoch, + "Checkpoint timed out; aborting active epoch" + ); + broadcast_cmd(ControlCommand::AbortCheckpoint { + epoch: u32::try_from(pending.epoch).unwrap_or(u32::MAX), + }); + } else { + continue; + } + } + + source_control_txs.retain(|tx| !tx.is_closed()); + if source_control_txs.is_empty() { + info!(job_id = %job_id, "All source pipelines closed; checkpoint coordinator exiting"); + break; + } + + info!(job_id = %job_id, epoch = current_epoch, "Triggering global Checkpoint Barrier."); + + let barrier = CheckpointBarrier { + epoch: current_epoch as u32, + min_epoch: 0, + timestamp: std::time::SystemTime::now(), + then_stop: false, + }; + active_checkpoint = Some(PendingCheckpoint { + epoch: current_epoch, + missing_acks: expected_pipeline_ids.clone(), + start_time: Instant::now(), + source_reports: Vec::new(), + }); + + for tx in &source_control_txs { + let cmd = ControlCommand::trigger_checkpoint(barrier); + let _ = tx.send(cmd); + } + current_epoch += 1; + } + } + } + }) + } } diff --git a/src/runtime/streaming/job/mod.rs b/src/runtime/streaming/job/mod.rs index 02e0343c..59d5c61f 100644 --- a/src/runtime/streaming/job/mod.rs +++ b/src/runtime/streaming/job/mod.rs @@ -14,4 +14,4 @@ pub mod edge_manager; pub mod job_manager; pub mod models; -pub use job_manager::{JobManager, StreamingJobSummary}; +pub use job_manager::{JobManager, StateConfig, StreamingJobSummary}; diff --git a/src/runtime/streaming/job/models.rs b/src/runtime/streaming/job/models.rs index f4e2f280..e81649f2 100644 --- a/src/runtime/streaming/job/models.rs +++ b/src/runtime/streaming/job/models.rs @@ -13,11 +13,11 @@ use std::collections::HashMap; use std::fmt; use std::sync::{Arc, RwLock}; -use std::thread::JoinHandle; use std::time::Instant; use protocol::function_stream_graph::FsProgram; use tokio::sync::mpsc; +use tokio::task::JoinHandle; use crate::runtime::streaming::protocol::control::ControlCommand; @@ -78,7 +78,7 @@ pub struct PhysicalPipeline { pub pipeline_id: u32, pub handle: Option>, pub status: Arc>, - pub control_tx: mpsc::Sender, + pub control_tx: mpsc::UnboundedSender, } pub struct PhysicalExecutionGraph { diff --git a/src/runtime/streaming/memory/pool.rs b/src/runtime/streaming/memory/pool.rs deleted file mode 100644 index b6a06ad2..00000000 --- a/src/runtime/streaming/memory/pool.rs +++ /dev/null @@ -1,89 +0,0 @@ -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use parking_lot::Mutex; -use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; -use tokio::sync::Notify; -use tracing::{debug, warn}; - -use super::ticket::MemoryTicket; - -#[derive(Debug)] -pub struct MemoryPool { - max_bytes: usize, - used_bytes: AtomicUsize, - available_bytes: Mutex, - notify: Notify, -} - -impl MemoryPool { - pub fn new(max_bytes: usize) -> Arc { - Arc::new(Self { - max_bytes, - used_bytes: AtomicUsize::new(0), - available_bytes: Mutex::new(max_bytes), - notify: Notify::new(), - }) - } - - pub fn usage_metrics(&self) -> (usize, usize) { - (self.used_bytes.load(Ordering::Relaxed), self.max_bytes) - } - - pub async fn request_memory(self: &Arc, bytes: usize) -> MemoryTicket { - if bytes == 0 { - return MemoryTicket::new(0, self.clone()); - } - - if bytes > self.max_bytes { - warn!( - "Requested memory ({} B) exceeds total pool size ({} B)! \ - Permitting to avoid pipeline deadlock, but OOM risk is critical.", - bytes, self.max_bytes - ); - self.used_bytes.fetch_add(bytes, Ordering::Relaxed); - return MemoryTicket::new(bytes, self.clone()); - } - - loop { - { - let mut available = self.available_bytes.lock(); - if *available >= bytes { - *available -= bytes; - self.used_bytes.fetch_add(bytes, Ordering::Relaxed); - return MemoryTicket::new(bytes, self.clone()); - } - } - - debug!( - "Backpressure engaged: waiting for {} bytes to be freed...", - bytes - ); - self.notify.notified().await; - } - } - - pub(crate) fn release(&self, bytes: usize) { - if bytes == 0 { - return; - } - - { - let mut available = self.available_bytes.lock(); - *available += bytes; - } - - self.used_bytes.fetch_sub(bytes, Ordering::Relaxed); - self.notify.notify_waiters(); - } -} diff --git a/src/runtime/streaming/mod.rs b/src/runtime/streaming/mod.rs index 7e0ba57a..b092c85d 100644 --- a/src/runtime/streaming/mod.rs +++ b/src/runtime/streaming/mod.rs @@ -19,9 +19,9 @@ pub mod execution; pub mod factory; pub mod format; pub mod job; -pub mod memory; pub mod network; pub mod operators; pub mod protocol; +pub mod state; pub use protocol::StreamOutput; diff --git a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs index 43e0e657..a2325e7c 100644 --- a/src/runtime/streaming/operators/grouping/incremental_aggregate.rs +++ b/src/runtime/streaming/operators/grouping/incremental_aggregate.rs @@ -11,15 +11,17 @@ // limitations under the License. use crate::sql::common::constants::updating_state_field; -use anyhow::{Result, bail}; -use arrow::compute::max_array; +use anyhow::{Result, anyhow, bail}; +use arrow::compute::{concat_batches, max_array}; use arrow::row::{RowConverter, SortField}; use arrow_array::builder::{ BinaryBuilder, TimestampNanosecondBuilder, UInt32Builder, UInt64Builder, }; use arrow_array::cast::AsArray; use arrow_array::types::UInt64Type; -use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, StructArray}; +use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, RecordBatch, StructArray, UInt32Array, UInt64Array, +}; use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaBuilder, TimeUnit}; use datafusion::common::{Result as DFResult, ScalarValue}; use datafusion::physical_expr::aggregate::AggregateFunctionExpr; @@ -36,14 +38,15 @@ use std::collections::HashSet; use std::sync::LazyLock; use std::time::{Duration, Instant, SystemTime}; use std::{collections::HashMap, mem, sync::Arc}; -use tracing::{debug, warn}; +use tracing::{debug, info, warn}; // ========================================================================= // ========================================================================= use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::runtime::streaming::operators::{Key, UpdatingCache}; +use crate::runtime::streaming::state::OperatorStateStore; use crate::runtime::util::decode_aggregate; use crate::sql::common::{ CheckpointBarrier, FsSchema, TIMESTAMP_FIELD, UPDATING_META_FIELD, Watermark, to_nanos, @@ -213,10 +216,15 @@ pub struct IncrementalAggregatingFunc { ttl: Duration, key_converter: RowConverter, new_generation: u64, + + state_store: Option>, } static GLOBAL_KEY: LazyLock>> = LazyLock::new(|| Arc::new(Vec::new())); +const KEY_SLIDING_SNAPSHOT: &[u8] = &[0x01]; +const KEY_BATCH_SNAPSHOT: &[u8] = &[0x02]; + impl IncrementalAggregatingFunc { fn update_batch( &mut self, @@ -437,40 +445,38 @@ impl IncrementalAggregatingFunc { // ========================================================================= fn checkpoint_sliding(&mut self) -> DFResult>> { - if self.updated_keys.is_empty() { + let keys = self.accumulators.keys(); + if keys.is_empty() { return Ok(None); } let mut states = vec![vec![]; self.sliding_state_schema.schema.fields.len()]; let parser = self.key_converter.parser(); - let mut generation_builder = UInt64Builder::with_capacity(self.updated_keys.len()); - - let mut cols = self - .key_converter - .convert_rows(self.updated_keys.keys().map(|k| { - let (accumulators, generation) = - self.accumulators.get_mut_generation(k.0.as_ref()).unwrap(); - generation_builder.append_value(generation); - - for (state, agg) in accumulators.iter_mut().zip(self.aggregates.iter()) { - let IncrementalState::Sliding { expr, accumulator } = state else { - continue; - }; - let state = accumulator.state().unwrap_or_else(|_| { - let state = accumulator.state().unwrap(); - *accumulator = expr.create_sliding_accumulator().unwrap(); - let states: Vec<_> = - state.iter().map(|s| s.to_array()).try_collect().unwrap(); - accumulator.merge_batch(&states).unwrap(); - state - }); - - for (idx, v) in agg.state_cols.iter().zip(state) { - states[*idx].push(v); - } + let mut generation_builder = UInt64Builder::with_capacity(keys.len()); + + let mut cols = self.key_converter.convert_rows(keys.iter().map(|k| { + let (accumulators, generation) = + self.accumulators.get_mut_generation(k.0.as_ref()).unwrap(); + generation_builder.append_value(generation); + + for (state, agg) in accumulators.iter_mut().zip(self.aggregates.iter()) { + let IncrementalState::Sliding { expr, accumulator } = state else { + continue; + }; + let state = accumulator.state().unwrap_or_else(|_| { + let state = accumulator.state().unwrap(); + *accumulator = expr.create_sliding_accumulator().unwrap(); + let states: Vec<_> = state.iter().map(|s| s.to_array()).try_collect().unwrap(); + accumulator.merge_batch(&states).unwrap(); + state + }); + + for (idx, v) in agg.state_cols.iter().zip(state) { + states[*idx].push(v); } - parser.parse(k.0.as_ref()) - }))?; + } + parser.parse(k.0.as_ref()) + }))?; cols.extend( states @@ -482,7 +488,7 @@ impl IncrementalAggregatingFunc { let generations = generation_builder.finish(); self.new_generation = self .new_generation - .max(max_array::(&generations).unwrap()); + .max(max_array::(&generations).unwrap_or(0)); cols.push(Arc::new(generations)); Ok(Some(cols)) @@ -496,12 +502,22 @@ impl IncrementalAggregatingFunc { { return Ok(None); } - if self.updated_keys.is_empty() { + + let keys = self.accumulators.keys(); + + let mut size = 0; + for k in &keys { + for state in self.accumulators.get_mut(k.0.as_ref()).unwrap().iter_mut() { + if let IncrementalState::Batch { data, .. } = state { + size += data.len(); + } + } + } + if size == 0 { return Ok(None); } - let size = self.updated_keys.len(); - let mut rows = Vec::with_capacity(size); + let mut key_bytes_for_rows = Vec::with_capacity(size); let mut accumulator_builder = UInt32Builder::with_capacity(size); let mut args_row_builder = BinaryBuilder::with_capacity(size, size * 4); let mut count_builder = UInt64Builder::with_capacity(size); @@ -509,10 +525,8 @@ impl IncrementalAggregatingFunc { let mut generation_builder = UInt64Builder::with_capacity(size); let now = to_nanos(SystemTime::now()) as i64; - let parser = self.key_converter.parser(); - for k in self.updated_keys.keys() { - let row = parser.parse(&k.0); + for k in keys { for (i, state) in self .accumulators .get_mut(k.0.as_ref()) @@ -520,29 +534,27 @@ impl IncrementalAggregatingFunc { .iter_mut() .enumerate() { - let IncrementalState::Batch { - data, - changed_values, - .. - } = state - else { + let IncrementalState::Batch { data, .. } = state else { continue; }; - for vk in changed_values.iter() { - if let Some(count) = data.get(vk) { - accumulator_builder.append_value(i as u32); - args_row_builder.append_value(&*vk.0); - count_builder.append_value(count.count); - generation_builder.append_value(count.generation); - timestamp_builder.append_value(now); - rows.push(row.to_owned()) - } + for (vk, count_data) in data.iter() { + accumulator_builder.append_value(i as u32); + args_row_builder.append_value(&*vk.0); + count_builder.append_value(count_data.count); + generation_builder.append_value(count_data.generation); + timestamp_builder.append_value(now); + key_bytes_for_rows.push(k.0.clone()); } data.retain(|_, v| v.count > 0); } } + let parser = self.key_converter.parser(); + let rows: Vec<_> = key_bytes_for_rows + .iter() + .map(|kb| parser.parse(kb).to_owned()) + .collect(); let mut cols = self.key_converter.convert_rows(rows)?; cols.push(Arc::new(accumulator_builder.finish())); cols.push(Arc::new(args_row_builder.finish())); @@ -552,7 +564,7 @@ impl IncrementalAggregatingFunc { let generations = generation_builder.finish(); self.new_generation = self .new_generation - .max(max_array::(&generations).unwrap()); + .max(max_array::(&generations).unwrap_or(0)); cols.push(Arc::new(generations)); Ok(Some(cols)) @@ -710,7 +722,152 @@ impl Operator for IncrementalAggregatingFunc { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Updating Aggregate recovering state from LSM-Tree..." + ); + + let mut sliding_batches = Vec::new(); + let mut batch_batches = Vec::new(); + + for key in active_keys { + if key == KEY_SLIDING_SNAPSHOT { + sliding_batches + .extend(store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?); + } else if key == KEY_BATCH_SNAPSHOT { + batch_batches + .extend(store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?); + } + } + + let num_keys = self + .input_schema + .routing_keys() + .map(|k| k.len()) + .unwrap_or(0); + let now = Instant::now(); + + // Restore sliding (reversible) accumulator state + if !sliding_batches.is_empty() { + let combined = concat_batches(&self.sliding_state_schema.schema, &sliding_batches)?; + let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); + let aggregate_states: Vec> = self + .aggregates + .iter() + .map(|agg| { + agg.state_cols + .iter() + .map(|&idx| combined.column(idx).clone()) + .collect() + }) + .collect(); + let gen_col = combined + .column(combined.num_columns() - 1) + .as_any() + .downcast_ref::() + .expect("generation column must be UInt64Array"); + + let rows = self.key_converter.convert_columns(&key_cols)?; + for i in 0..combined.num_rows() { + let key = rows.row(i).as_ref().to_vec(); + let generation = gen_col.value(i); + self.restore_sliding(&key, now, i, &aggregate_states, generation)?; + } + info!( + rows = combined.num_rows(), + "Restored sliding accumulator state." + ); + } + + // Restore batch (non-reversible) detail dictionaries + if !batch_batches.is_empty() { + let combined = concat_batches(&self.batch_state_schema.schema, &batch_batches)?; + let key_cols: Vec = combined.columns()[0..num_keys].to_vec(); + + let acc_idx_col = combined + .column(num_keys) + .as_any() + .downcast_ref::() + .expect("accumulator index column must be UInt32Array"); + let args_col = combined + .column(num_keys + 1) + .as_any() + .downcast_ref::() + .expect("args_row column must be BinaryArray"); + let count_col = combined + .column(num_keys + 2) + .as_any() + .downcast_ref::() + .expect("count column must be UInt64Array"); + // column num_keys+3 is timestamp, skip + let gen_col = combined + .column(num_keys + 4) + .as_any() + .downcast_ref::() + .expect("generation column must be UInt64Array"); + + let rows = self.key_converter.convert_columns(&key_cols)?; + + for i in 0..combined.num_rows() { + let key = rows.row(i).as_ref().to_vec(); + let acc_idx = acc_idx_col.value(i) as usize; + let args_row = args_col.value(i).to_vec(); + let count = count_col.value(i); + let generation = gen_col.value(i); + + if !self.accumulators.contains_key(&key) { + self.accumulators.insert( + Arc::new(key.clone()), + now, + generation, + self.make_accumulators(), + ); + } + + if let Some(accs) = self.accumulators.get_mut(&key) + && let Some(IncrementalState::Batch { + data, + changed_values, + .. + }) = accs.get_mut(acc_idx) + { + let vk = Key(Arc::new(args_row.clone())); + data.insert(vk.clone(), BatchData { count, generation }); + changed_values.insert(vk); + } + } + info!(rows = combined.num_rows(), "Restored batch detail state."); + } + + info!( + groups = self.accumulators.keys().len(), + "Updating Aggregate successfully restored active groups." + ); + } + self.initialize(ctx).await?; + self.state_store = Some(store); Ok(()) } @@ -719,33 +876,88 @@ impl Operator for IncrementalAggregatingFunc { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { if self.has_routing_keys { self.keyed_aggregate(&batch)?; } else { self.global_aggregate(&batch)?; } - Ok(vec![]) + Ok(()) } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { if let Some(changelog_batch) = self.generate_changelog()? { - Ok(vec![StreamOutput::Forward(changelog_batch)]) - } else { - Ok(vec![]) + collector + .collect(StreamOutput::Forward(changelog_batch), _ctx) + .await?; } + Ok(()) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + let store = self + .state_store + .clone() + .expect("State store not initialized"); + + // Tombstone previous epoch snapshots for disk space reclamation + store + .remove_batches(KEY_SLIDING_SNAPSHOT.to_vec()) + .map_err(|e| anyhow!("{e}"))?; + store + .remove_batches(KEY_BATCH_SNAPSHOT.to_vec()) + .map_err(|e| anyhow!("{e}"))?; + + // Full snapshot of sliding (reversible) accumulator state + if let Some(cols) = self.checkpoint_sliding()? { + let batch = RecordBatch::try_new(self.sliding_state_schema.schema.clone(), cols)?; + store + .put(KEY_SLIDING_SNAPSHOT.to_vec(), batch) + .await + .map_err(|e| anyhow!("{e}"))?; + } + + // Full snapshot of batch (non-reversible) detail state + if let Some(cols) = self.checkpoint_batch()? { + let batch = RecordBatch::try_new(self.batch_state_schema.schema.clone(), cols)?; + store + .put(KEY_BATCH_SNAPSHOT.to_vec(), batch) + .await + .map_err(|e| anyhow!("{e}"))?; + } + + // Flush to Parquet + store + .prepare_checkpoint_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!( + epoch = barrier.epoch, + "Updating Aggregate snapshotted successfully." + ); + + self.updated_keys.clear(); + + Ok(()) + } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("state store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; Ok(()) } @@ -907,6 +1119,7 @@ impl IncrementalAggregatingConstructor { sliding_state_schema, batch_state_schema, new_generation: 0, + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/grouping/updating_cache.rs b/src/runtime/streaming/operators/grouping/updating_cache.rs index 37f2ba04..34c732fc 100644 --- a/src/runtime/streaming/operators/grouping/updating_cache.rs +++ b/src/runtime/streaming/operators/grouping/updating_cache.rs @@ -64,6 +64,10 @@ impl Iterator for TTLIter<'_, T> { } impl UpdatingCache { + pub fn keys(&self) -> Vec { + self.map.keys().cloned().collect() + } + pub fn with_time_to_idle(ttl: Duration) -> Self { Self { map: HashMap::new(), diff --git a/src/runtime/streaming/operators/joins/join_instance.rs b/src/runtime/streaming/operators/joins/join_instance.rs index 75513542..098e5a73 100644 --- a/src/runtime/streaming/operators/joins/join_instance.rs +++ b/src/runtime/streaming/operators/joins/join_instance.rs @@ -11,9 +11,8 @@ // limitations under the License. use anyhow::{Result, anyhow}; -use arrow::compute::{max, min, partition, sort_to_indices, take}; +use arrow::compute::{concat_batches, max, min, partition, sort_to_indices, take}; use arrow_array::{RecordBatch, TimestampNanosecondArray}; -use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::context::SessionContext; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::ExecutionPlan; @@ -21,80 +20,79 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; -use std::time::SystemTime; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use std::time::UNIX_EPOCH; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; -use crate::sql::common::constants::mem_exec_join_side; -use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos}; +use crate::runtime::streaming::state::OperatorStateStore; +use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; use protocol::function_stream_graph::JoinOperator; #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum JoinSide { - Left, - Right, + Left = 0, + Right = 1, } -impl JoinSide { - #[allow(dead_code)] - fn name(&self) -> &'static str { - match self { - JoinSide::Left => mem_exec_join_side::LEFT, - JoinSide::Right => mem_exec_join_side::RIGHT, - } - } -} +// ============================================================================ +// Lightweight state index: composite key [Side(1B)] + [Timestamp(8B BE)] +// ============================================================================ -struct JoinInstance { - left_tx: UnboundedSender, - right_tx: UnboundedSender, - result_stream: SendableRecordBatchStream, +struct InstantStateIndex { + side: JoinSide, + active_timestamps: BTreeSet, } -impl JoinInstance { - fn feed_data(&self, batch: RecordBatch, side: JoinSide) -> Result<()> { - match side { - JoinSide::Left => self - .left_tx - .send(batch) - .map_err(|e| anyhow!("Left send err: {}", e)), - JoinSide::Right => self - .right_tx - .send(batch) - .map_err(|e| anyhow!("Right send err: {}", e)), +impl InstantStateIndex { + fn new(side: JoinSide) -> Self { + Self { + side, + active_timestamps: BTreeSet::new(), } } - async fn close_and_drain(self) -> Result> { - drop(self.left_tx); - drop(self.right_tx); - - let mut outputs = Vec::new(); - let mut stream = self.result_stream; + fn build_key(side: JoinSide, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(side as u8); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key + } - while let Some(result_batch) = stream.next().await { - outputs.push(result_batch?); + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None } - - Ok(outputs) } } +// ============================================================================ +// InstantJoinOperator (persistent state refactor) +// ============================================================================ + pub struct InstantJoinOperator { left_input_schema: FsSchemaRef, right_input_schema: FsSchemaRef, - active_joins: BTreeMap, - left_receiver_hook: Arc>>>, - right_receiver_hook: Arc>>>, + left_schema: FsSchemaRef, + right_schema: FsSchemaRef, + + left_passer: Arc>>, + right_passer: Arc>>, join_exec_plan: Arc, + + left_state: InstantStateIndex, + right_state: InstantStateIndex, + state_store: Option>, } impl InstantJoinOperator { @@ -105,32 +103,26 @@ impl InstantJoinOperator { } } - fn get_or_create_join_instance(&mut self, time: SystemTime) -> Result<&mut JoinInstance> { - use std::collections::btree_map::Entry; - - if let Entry::Vacant(e) = self.active_joins.entry(time) { - let (left_tx, left_rx) = unbounded_channel(); - let (right_tx, right_rx) = unbounded_channel(); + async fn compute_pair( + &mut self, + left: RecordBatch, + right: RecordBatch, + ) -> Result> { + self.left_passer.write().unwrap().replace(left); + self.right_passer.write().unwrap().replace(right); - *self.left_receiver_hook.write().unwrap() = Some(left_rx); - *self.right_receiver_hook.write().unwrap() = Some(right_rx); + self.join_exec_plan.reset().map_err(|e| anyhow!("{e}"))?; - self.join_exec_plan.reset().map_err(|e| anyhow!("{e}"))?; - let result_stream = self - .join_exec_plan - .execute(0, SessionContext::new().task_ctx()) - .map_err(|e| anyhow!("{e}"))?; + let mut result_stream = self + .join_exec_plan + .execute(0, SessionContext::new().task_ctx()) + .map_err(|e| anyhow!("{e}"))?; - e.insert(JoinInstance { - left_tx, - right_tx, - result_stream, - }); + let mut outputs = Vec::new(); + while let Some(batch) = result_stream.next().await { + outputs.push(batch.map_err(|e| anyhow!("{e}"))?); } - - self.active_joins - .get_mut(&time) - .ok_or_else(|| anyhow!("join instance missing after insert")) + Ok(outputs) } async fn process_side_internal( @@ -142,6 +134,10 @@ impl InstantJoinOperator { if batch.num_rows() == 0 { return Ok(()); } + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); let time_column = batch .column(self.input_schema(side).timestamp_index) @@ -152,19 +148,28 @@ impl InstantJoinOperator { let min_timestamp = min(time_column).ok_or_else(|| anyhow!("empty timestamp column"))?; let max_timestamp = max(time_column).ok_or_else(|| anyhow!("empty timestamp column"))?; - if let Some(watermark) = ctx.current_watermark() - && watermark > from_nanos(min_timestamp as u128) - { - warn!("Dropped late batch from {:?} before watermark", side); - return Ok(()); + if let Some(watermark) = ctx.current_watermark() { + let watermark_nanos = watermark.duration_since(UNIX_EPOCH).unwrap().as_nanos() as i64; + if watermark_nanos > min_timestamp { + warn!("Dropped late batch from {:?} before watermark", side); + return Ok(()); + } } let unkeyed_batch = self.input_schema(side).unkeyed_batch(&batch)?; + let state_index = match side { + JoinSide::Left => &mut self.left_state, + JoinSide::Right => &mut self.right_state, + }; if max_timestamp == min_timestamp { - let time_key = from_nanos(max_timestamp as u128); - let join_instance = self.get_or_create_join_instance(time_key)?; - join_instance.feed_data(unkeyed_batch, side)?; + let ts_nanos = max_timestamp as u64; + let key = InstantStateIndex::build_key(side, ts_nanos); + store + .put(key, unkeyed_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + state_index.active_timestamps.insert(ts_nanos); return Ok(()); } @@ -179,16 +184,21 @@ impl InstantJoinOperator { let typed_timestamps = sorted_timestamps .as_any() .downcast_ref::() - .ok_or_else(|| anyhow!("sorted timestamps downcast failed"))?; + .unwrap(); + let ranges = partition(std::slice::from_ref(&sorted_timestamps)) .unwrap() .ranges(); for range in ranges { let sub_batch = sorted_batch.slice(range.start, range.end - range.start); - let time_key = from_nanos(typed_timestamps.value(range.start) as u128); - let join_instance = self.get_or_create_join_instance(time_key)?; - join_instance.feed_data(sub_batch, side)?; + let ts_nanos = typed_timestamps.value(range.start) as u64; + let key = InstantStateIndex::build_key(side, ts_nanos); + store + .put(key, sub_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + state_index.active_timestamps.insert(ts_nanos); } Ok(()) @@ -201,7 +211,44 @@ impl Operator for InstantJoinOperator { "InstantJoin" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + for key in active_keys { + if let Some(ts) = InstantStateIndex::extract_timestamp(&key) { + if key[0] == JoinSide::Left as u8 { + self.left_state.active_timestamps.insert(ts); + } else if key[0] == JoinSide::Right as u8 { + self.right_state.active_timestamps.insert(ts); + } + } + } + + info!( + pipeline_id = ctx.pipeline_id, + restored_left = self.left_state.active_timestamps.len(), + restored_right = self.right_state.active_timestamps.len(), + "Instant Join Operator recovered state." + ); + + self.state_store = Some(store); Ok(()) } @@ -210,56 +257,128 @@ impl Operator for InstantJoinOperator { input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let side = if input_idx == 0 { JoinSide::Left } else { JoinSide::Right }; self.process_side_internal(side, batch, ctx).await?; - Ok(vec![]) + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; - let mut emit_outputs = Vec::new(); + let store = self.state_store.clone().unwrap(); + let cutoff_nanos = current_time.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + + let mut all_active_ts = BTreeSet::new(); + all_active_ts.extend(self.left_state.active_timestamps.iter()); + all_active_ts.extend(self.right_state.active_timestamps.iter()); + + let expired_ts: Vec = all_active_ts + .into_iter() + .filter(|&ts| ts < cutoff_nanos) + .collect(); + + if expired_ts.is_empty() { + return Ok(()); + } - let mut expired_times = Vec::new(); - for key in self.active_joins.keys() { - if *key < current_time { - expired_times.push(*key); + // Phase 1: Harvest — extract all expired timestamp data from LSM-Tree + let mut pending_pairs: Vec<(u64, RecordBatch, RecordBatch)> = + Vec::with_capacity(expired_ts.len()); + + for &ts in &expired_ts { + let left_key = InstantStateIndex::build_key(JoinSide::Left, ts); + let right_key = InstantStateIndex::build_key(JoinSide::Right, ts); + + let left_batches = store + .get_batches(&left_key) + .await + .map_err(|e| anyhow!("{e}"))?; + let right_batches = store + .get_batches(&right_key) + .await + .map_err(|e| anyhow!("{e}"))?; + + let left_input = if left_batches.is_empty() { + RecordBatch::new_empty(self.left_schema.schema.clone()) } else { - break; - } + concat_batches(&self.left_schema.schema, left_batches.iter())? + }; + let right_input = if right_batches.is_empty() { + RecordBatch::new_empty(self.right_schema.schema.clone()) + } else { + concat_batches(&self.right_schema.schema, right_batches.iter())? + }; + + pending_pairs.push((ts, left_input, right_input)); } - for time_key in expired_times { - if let Some(join_instance) = self.active_joins.remove(&time_key) { - let joined_batches = join_instance.close_and_drain().await?; - for batch in joined_batches { - emit_outputs.push(StreamOutput::Forward(batch)); - } + // Phase 2: Compute — all data extracted, no store reference held + for (_, left_input, right_input) in pending_pairs { + if left_input.num_rows() == 0 && right_input.num_rows() == 0 { + continue; + } + let results = self.compute_pair(left_input, right_input).await?; + for batch in results { + collector + .collect(StreamOutput::Forward(batch), _ctx) + .await?; } } - Ok(emit_outputs) + // Phase 3: Cleanup — tombstone LSM-Tree entries and update in-memory index + for ts in expired_ts { + let left_key = InstantStateIndex::build_key(JoinSide::Left, ts); + let right_key = InstantStateIndex::build_key(JoinSide::Right, ts); + store.remove_batches(left_key).map_err(|e| anyhow!("{e}"))?; + store + .remove_batches(right_key) + .map_err(|e| anyhow!("{e}"))?; + self.left_state.active_timestamps.remove(&ts); + self.right_state.active_timestamps.remove(&ts); + } + + Ok(()) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .unwrap() + .prepare_checkpoint_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + Ok(()) + } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .unwrap() + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; Ok(()) } } +// ============================================================================ +// Constructor +// ============================================================================ + pub struct InstantJoinConstructor; impl InstantJoinConstructor { @@ -268,21 +387,23 @@ impl InstantJoinConstructor { config: JoinOperator, registry: Arc, ) -> anyhow::Result { - let join_physical_plan_node = PhysicalPlanNode::decode(&mut config.join_plan.as_slice())?; - let left_input_schema: Arc = Arc::new(config.left_schema.unwrap().try_into()?); let right_input_schema: Arc = Arc::new(config.right_schema.unwrap().try_into()?); - let left_receiver_hook = Arc::new(RwLock::new(None)); - let right_receiver_hook = Arc::new(RwLock::new(None)); + let left_schema = Arc::new(left_input_schema.schema_without_keys()?); + let right_schema = Arc::new(right_input_schema.schema_without_keys()?); + + let left_passer = Arc::new(RwLock::new(None)); + let right_passer = Arc::new(RwLock::new(None)); let codec = StreamingExtensionCodec { - context: StreamingDecodingContext::LockedJoinStream { - left: left_receiver_hook.clone(), - right: right_receiver_hook.clone(), + context: StreamingDecodingContext::LockedJoinPair { + left: left_passer.clone(), + right: right_passer.clone(), }, }; + let join_physical_plan_node = PhysicalPlanNode::decode(&mut config.join_plan.as_slice())?; let join_exec_plan = join_physical_plan_node.try_into_physical_plan( registry.as_ref(), &RuntimeEnvBuilder::new().build()?, @@ -292,10 +413,14 @@ impl InstantJoinConstructor { Ok(InstantJoinOperator { left_input_schema, right_input_schema, - active_joins: BTreeMap::new(), - left_receiver_hook, - right_receiver_hook, + left_schema, + right_schema, + left_passer, + right_passer, join_exec_plan, + left_state: InstantStateIndex::new(JoinSide::Left), + right_state: InstantStateIndex::new(JoinSide::Right), + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/joins/join_with_expiration.rs b/src/runtime/streaming/operators/joins/join_with_expiration.rs index 60bbe7e3..6a2a240c 100644 --- a/src/runtime/streaming/operators/joins/join_with_expiration.rs +++ b/src/runtime/streaming/operators/joins/join_with_expiration.rs @@ -19,15 +19,16 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::{physical_plan::AsExecutionPlan, protobuf::PhysicalPlanNode}; use futures::StreamExt; use prost::Message; -use std::collections::VecDeque; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; -use std::time::{Duration, SystemTime}; -use tracing::warn; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; @@ -35,49 +36,91 @@ use protocol::function_stream_graph::JoinOperator; #[derive(Debug, Copy, Clone, Eq, PartialEq)] enum JoinSide { - Left, - Right, + Left = 0, + Right = 1, } // ============================================================================ +// Persistent state buffer: composite key [Side(1B)] + [Timestamp(8B BE)] // ============================================================================ -struct StateBuffer { - batches: VecDeque<(SystemTime, RecordBatch)>, +struct PersistentStateBuffer { + side: JoinSide, ttl: Duration, + active_timestamps: BTreeSet, } -impl StateBuffer { - fn new(ttl: Duration) -> Self { +impl PersistentStateBuffer { + fn new(side: JoinSide, ttl: Duration) -> Self { Self { - batches: VecDeque::new(), + side, ttl, + active_timestamps: BTreeSet::new(), } } - fn insert(&mut self, batch: RecordBatch, time: SystemTime) { - self.batches.push_back((time, batch)); + fn build_key(side: JoinSide, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(side as u8); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key } - fn expire(&mut self, current_time: SystemTime) { - let cutoff = current_time - .checked_sub(self.ttl) - .unwrap_or(SystemTime::UNIX_EPOCH); - while let Some((time, _)) = self.batches.front() { - if *time < cutoff { - self.batches.pop_front(); - } else { - break; - } + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None } } - fn get_all_batches(&self) -> Vec { - self.batches.iter().map(|(_, b)| b.clone()).collect() + async fn insert( + &mut self, + batch: RecordBatch, + time: SystemTime, + store: &Arc, + ) -> Result<()> { + let ts_nanos = time.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + self.active_timestamps.insert(ts_nanos); + let key = Self::build_key(self.side, ts_nanos); + store.put(key, batch).await.map_err(|e| anyhow!("{e}")) + } + + fn expire(&mut self, current_time: SystemTime, store: &Arc) -> Result<()> { + let cutoff = current_time.checked_sub(self.ttl).unwrap_or(UNIX_EPOCH); + let cutoff_nanos = cutoff.duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64; + + let expired_ts: Vec = self + .active_timestamps + .iter() + .take_while(|&&ts| ts < cutoff_nanos) + .copied() + .collect(); + + for ts in expired_ts { + let key = Self::build_key(self.side, ts); + store.remove_batches(key).map_err(|e| anyhow!("{e}"))?; + self.active_timestamps.remove(&ts); + } + + Ok(()) + } + + async fn get_all_batches(&self, store: &Arc) -> Result> { + let mut all_batches = Vec::new(); + for &ts in &self.active_timestamps { + let key = Self::build_key(self.side, ts); + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + all_batches.extend(batches); + } + Ok(all_batches) } } // ============================================================================ +// JoinWithExpirationOperator // ============================================================================ pub struct JoinWithExpirationOperator { @@ -90,8 +133,9 @@ pub struct JoinWithExpirationOperator { right_passer: Arc>>, join_exec_plan: Arc, - left_state: StateBuffer, - right_state: StateBuffer, + left_state: PersistentStateBuffer, + right_state: PersistentStateBuffer, + state_store: Option>, } impl JoinWithExpirationOperator { @@ -131,24 +175,37 @@ impl JoinWithExpirationOperator { side: JoinSide, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let current_time = ctx.current_watermark().unwrap_or_else(SystemTime::now); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); - self.left_state.expire(current_time); - self.right_state.expire(current_time); + self.left_state.expire(current_time, store)?; + self.right_state.expire(current_time, store)?; match side { - JoinSide::Left => self.left_state.insert(batch.clone(), current_time), - JoinSide::Right => self.right_state.insert(batch.clone(), current_time), + JoinSide::Left => { + self.left_state + .insert(batch.clone(), current_time, store) + .await? + } + JoinSide::Right => { + self.right_state + .insert(batch.clone(), current_time, store) + .await? + } } let opposite_batches = match side { - JoinSide::Left => self.right_state.get_all_batches(), - JoinSide::Right => self.left_state.get_all_batches(), + JoinSide::Left => self.right_state.get_all_batches(store).await?, + JoinSide::Right => self.left_state.get_all_batches(store).await?, }; if opposite_batches.is_empty() { - return Ok(vec![]); + return Ok(()); } let opposite_schema = match side { @@ -168,11 +225,10 @@ impl JoinWithExpirationOperator { }; let result_batches = self.compute_pair(left_input, right_input).await?; - - Ok(result_batches - .into_iter() - .map(StreamOutput::Forward) - .collect()) + for b in result_batches { + collector.collect(StreamOutput::Forward(b), ctx).await?; + } + Ok(()) } } @@ -182,7 +238,44 @@ impl Operator for JoinWithExpirationOperator { "JoinWithExpiration" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + for key in active_keys { + if let Some(ts) = PersistentStateBuffer::extract_timestamp(&key) { + if key[0] == JoinSide::Left as u8 { + self.left_state.active_timestamps.insert(ts); + } else if key[0] == JoinSide::Right as u8 { + self.right_state.active_timestamps.insert(ts); + } + } + } + + info!( + pipeline_id = ctx.pipeline_id, + restored_left = self.left_state.active_timestamps.len(), + restored_right = self.right_state.active_timestamps.len(), + "Join Operator restored state from LSM-Tree." + ); + + self.state_store = Some(store); Ok(()) } @@ -191,28 +284,49 @@ impl Operator for JoinWithExpirationOperator { input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let side = if input_idx == 0 { JoinSide::Left } else { JoinSide::Right }; - self.process_side(side, batch, ctx).await + self.process_side(side, batch, ctx, collector).await } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![]) + _collector: &mut dyn Collector, + ) -> Result<()> { + Ok(()) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + store + .prepare_checkpoint_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!(epoch = barrier.epoch, "Join Operator snapshotted state."); + Ok(()) + } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; Ok(()) } @@ -222,6 +336,7 @@ impl Operator for JoinWithExpirationOperator { } // ============================================================================ +// Constructor // ============================================================================ pub struct JoinWithExpirationConstructor; @@ -273,8 +388,9 @@ impl JoinWithExpirationConstructor { left_passer, right_passer, join_exec_plan, - left_state: StateBuffer::new(ttl), - right_state: StateBuffer::new(ttl), + left_state: PersistentStateBuffer::new(JoinSide::Left, ttl), + right_state: PersistentStateBuffer::new(JoinSide::Right, ttl), + state_store: None, }) } } diff --git a/src/runtime/streaming/operators/key_by.rs b/src/runtime/streaming/operators/key_by.rs index 59206688..90c55d08 100644 --- a/src/runtime/streaming/operators/key_by.rs +++ b/src/runtime/streaming/operators/key_by.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::sql::common::{CheckpointBarrier, Watermark}; use protocol::function_stream_graph::KeyPlanOperator; @@ -57,10 +57,11 @@ impl Operator for KeyByOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let num_rows = batch.num_rows(); if num_rows == 0 { - return Ok(vec![]); + return Ok(()); } let mut key_columns = Vec::with_capacity(self.key_extractors.len()); @@ -110,15 +111,22 @@ impl Operator for KeyByOperator { start_idx = end_idx; } - Ok(outputs) + for out in outputs { + collector.collect(out, _ctx).await?; + } + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { + collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/key_operator.rs b/src/runtime/streaming/operators/key_operator.rs index 1f4f48c6..7a89d2f2 100644 --- a/src/runtime/streaming/operators/key_operator.rs +++ b/src/runtime/streaming/operators/key_operator.rs @@ -17,7 +17,7 @@ use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::operators::StatelessPhysicalExecutor; use crate::sql::common::{CheckpointBarrier, Watermark}; use ahash::RandomState; @@ -67,7 +67,8 @@ impl Operator for KeyExecutionOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let mut outputs = Vec::new(); let mut stream = self.executor.process_batch(batch).await?; @@ -122,15 +123,22 @@ impl Operator for KeyExecutionOperator { start_idx = end_idx; } } - Ok(outputs) + for out in outputs { + collector.collect(out, _ctx).await?; + } + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { + collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/projection.rs b/src/runtime/streaming/operators/projection.rs index 1a2ff3a1..b84d74aa 100644 --- a/src/runtime/streaming/operators/projection.rs +++ b/src/runtime/streaming/operators/projection.rs @@ -24,7 +24,7 @@ use protocol::function_stream_graph::ProjectionOperator as ProjectionOperatorPro use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::global::Registry; use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark}; use crate::sql::logical_node::logical::OperatorName; @@ -98,9 +98,10 @@ impl Operator for ProjectionOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { if batch.num_rows() == 0 { - return Ok(vec![]); + return Ok(()); } let projected_columns = self @@ -114,15 +115,22 @@ impl Operator for ProjectionOperator { let out_batch = RecordBatch::try_new(self.output_schema.schema.clone(), projected_columns)?; - Ok(vec![StreamOutput::Forward(out_batch)]) + collector + .collect(StreamOutput::Forward(out_batch), _ctx) + .await?; + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { + collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/sink/kafka/mod.rs b/src/runtime/streaming/operators/sink/kafka/mod.rs index a24a098d..a9c4b50e 100644 --- a/src/runtime/streaming/operators/sink/kafka/mod.rs +++ b/src/runtime/streaming/operators/sink/kafka/mod.rs @@ -10,6 +10,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! ## Exactly-once Kafka sink and checkpoint 2PC +//! +//! - **Pre-commit (barrier / `snapshot_state`)**: flush in-flight sends, rotate to a new transactional +//! producer for post-barrier records, and stash the producer that covered this checkpoint interval. +//! - **Commit (`commit_checkpoint`)**: after the job coordinator persists checkpoint metadata (catalog), +//! it broadcasts `ControlCommand::Commit`; this operator calls `commit_transaction` on the stashed +//! producer so consumers with `isolation.level=read_committed` observe the batch. +//! - **Abort (`abort_checkpoint`)**: if metadata commit fails or the checkpoint is declined, the +//! coordinator broadcasts `AbortCheckpoint` and this operator calls `abort_transaction` on the +//! stashed producer. + use anyhow::{Result, anyhow, bail}; use arrow_array::Array; use arrow_array::RecordBatch; @@ -27,7 +38,7 @@ use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::format::DataSerializer; use crate::sql::common::constants::factory_operator_name; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark}; @@ -115,6 +126,12 @@ impl KafkaSinkOperator { if let Some(idx) = tx_index { config.set("enable.idempotence", "true"); + if config.get("acks").is_none() { + config.set("acks", "all"); + } + if config.get("transaction.timeout.ms").is_none() { + config.set("transaction.timeout.ms", "600000"); + } let transactional_id = format!( "fs-tx-{}-{}-{}-{}", ctx.job_id, self.topic, ctx.subtask_index, idx @@ -243,7 +260,8 @@ impl Operator for KafkaSinkOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let payloads = self.serializer.serialize(&batch)?; let producer = self.current_producer().clone(); @@ -281,15 +299,16 @@ impl Operator for KafkaSinkOperator { } } - Ok(vec![]) + Ok(()) } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![]) + _collector: &mut dyn Collector, + ) -> Result<()> { + Ok(()) } async fn snapshot_state( @@ -361,6 +380,34 @@ impl Operator for KafkaSinkOperator { Ok(()) } + async fn abort_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + if matches!(self.consistency_mode, ConsistencyMode::AtLeastOnce) { + return Ok(()); + } + + let state = self.transactional_state.as_mut().unwrap(); + let Some(stale) = state.producer_awaiting_commit.take() else { + warn!( + "AbortCheckpoint epoch {} but no stashed transactional producer (already committed or duplicate signal)", + epoch + ); + return Ok(()); + }; + + match stale.abort_transaction(Timeout::After(Duration::from_secs(30))) { + Ok(()) => info!( + "Aborted Kafka transaction for epoch {} (checkpoint metadata did not commit)", + epoch + ), + Err(e) => warn!( + "Kafka abort_transaction for epoch {} returned error (producer dropped): {}", + epoch, e + ), + } + + Ok(()) + } + async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result> { self.flush_to_broker().await?; info!("Kafka sink shut down gracefully."); diff --git a/src/runtime/streaming/operators/source/kafka/mod.rs b/src/runtime/streaming/operators/source/kafka/mod.rs index e73d18fa..9f5b84ad 100644 --- a/src/runtime/streaming/operators/source/kafka/mod.rs +++ b/src/runtime/streaming/operators/source/kafka/mod.rs @@ -10,21 +10,28 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Kafka source checkpointing: `enable.auto.commit=false`, offsets captured at the checkpoint barrier +//! and reported to the job coordinator for catalog persistence; restart rewinds from that snapshot. + use anyhow::{Context as _, Result, anyhow}; use arrow_array::RecordBatch; use arrow_schema::SchemaRef; use async_trait::async_trait; use bincode::{Decode, Encode}; use governor::{DefaultDirectRateLimiter, Quota, RateLimiter as GovernorRateLimiter}; +use protocol::storage::{KafkaPartitionOffset, KafkaSourceSubtaskCheckpoint}; use rdkafka::consumer::{CommitMode, Consumer, StreamConsumer}; use rdkafka::{ClientConfig, Message as KMessage, Offset, TopicPartitionList}; use std::collections::HashMap; use std::num::NonZeroU32; +use std::path::PathBuf; use std::time::{Duration, Instant}; use tracing::{debug, error, info, warn}; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::source::{SourceEvent, SourceOffset, SourceOperator}; +use crate::runtime::streaming::api::source::{ + SourceCheckpointReport, SourceEvent, SourceOffset, SourceOperator, +}; use crate::runtime::streaming::format::{BadDataPolicy, DataDeserializer, Format}; use crate::sql::common::fs_schema::FieldValueType; use crate::sql::common::{CheckpointBarrier, MetadataField}; @@ -33,8 +40,74 @@ use crate::sql::common::{CheckpointBarrier, MetadataField}; #[derive(Copy, Clone, Debug, Encode, Decode, PartialEq, PartialOrd)] pub struct KafkaState { - partition: i32, - offset: i64, + pub partition: i32, + pub offset: i64, +} + +/// Last committed partition offsets for this source subtask, tied to a checkpoint epoch. +/// Materialized into a `.bin` under the job state dir from catalog before restart; see +/// [`TaskContext::latest_safe_epoch`] and `StreamingTableDefinition` in `storage.proto`. +#[derive(Debug, Encode, Decode)] +pub(crate) struct KafkaSourceSavedOffsets { + /// Same numbering as [`CheckpointBarrier::epoch`] / catalog `latest_checkpoint_epoch` (as u64). + pub(crate) epoch: u64, + pub(crate) partitions: Vec, +} + +pub(crate) fn encode_kafka_offset_snapshot(saved: &KafkaSourceSavedOffsets) -> Result> { + bincode::encode_to_vec(saved, bincode::config::standard()) + .map_err(|e| anyhow!("bincode encode Kafka offset snapshot: {e}")) +} + +pub(crate) fn decode_kafka_offset_snapshot(bytes: &[u8]) -> Result { + let (saved, _) = bincode::decode_from_slice(bytes, bincode::config::standard()) + .map_err(|e| anyhow!("bincode decode Kafka offset snapshot: {e}"))?; + Ok(saved) +} + +pub(crate) fn kafka_snapshot_path( + job_dir: &std::path::Path, + pipeline_id: u32, + subtask_index: u32, +) -> PathBuf { + job_dir.join(format!( + "kafka_source_offsets_pipe{}_sub{}.bin", + pipeline_id, subtask_index + )) +} + +fn kafka_offsets_snapshot_path(ctx: &TaskContext) -> PathBuf { + kafka_snapshot_path(&ctx.state_dir, ctx.pipeline_id, ctx.subtask_index) +} + +fn load_saved_offsets_if_recovering(ctx: &TaskContext) -> Option { + let safe = ctx.latest_safe_epoch(); + if safe == 0 { + return None; + } + let path = kafka_offsets_snapshot_path(ctx); + let bytes = std::fs::read(&path).ok()?; + let saved = match decode_kafka_offset_snapshot(&bytes) { + Ok(v) => v, + Err(e) => { + warn!( + path = %path.display(), + error = %e, + "Failed to decode Kafka offset snapshot" + ); + return None; + } + }; + if saved.epoch > safe { + warn!( + path = %path.display(), + saved_epoch = saved.epoch, + safe_epoch = safe, + "Ignoring Kafka offset snapshot newer than catalog safe epoch" + ); + return None; + } + Some(saved) } pub trait BatchDeserializer: Send + 'static { @@ -182,7 +255,11 @@ impl KafkaSourceOperator { } } - async fn init_and_assign_consumer(&mut self, ctx: &mut TaskContext) -> Result<()> { + async fn init_and_assign_consumer( + &mut self, + ctx: &mut TaskContext, + saved_offsets: Option, + ) -> Result<()> { info!("Creating kafka consumer for {}", self.bootstrap_servers); let mut client_config = ClientConfig::new(); @@ -205,8 +282,24 @@ impl KafkaSourceOperator { .set("group.id", &group_id) .create()?; - let has_state = false; - let state_map: HashMap = HashMap::new(); + let (has_state, state_map) = if let Some(saved) = saved_offsets { + info!( + job_id = %ctx.job_id, + pipeline_id = ctx.pipeline_id, + subtask = ctx.subtask_index, + epoch = saved.epoch, + safe_epoch = ctx.latest_safe_epoch(), + partitions = saved.partitions.len(), + "Restoring Kafka source offsets from materialized checkpoint snapshot" + ); + let mut m = HashMap::with_capacity(saved.partitions.len()); + for s in saved.partitions { + m.insert(s.partition, s); + } + (true, m) + } else { + (false, HashMap::new()) + }; let metadata = consumer .fetch_metadata(Some(&self.topic), Duration::from_secs(30)) @@ -224,9 +317,10 @@ impl KafkaSourceOperator { for p in partitions { if p.id().rem_euclid(pmax) == ctx.subtask_index as i32 { + // `current_offsets` / snapshot store last consumed offset; resume at next offset. let offset = state_map .get(&p.id()) - .map(|s| Offset::Offset(s.offset)) + .map(|s| Offset::Offset(s.offset.saturating_add(1))) .unwrap_or_else(|| { if has_state { Offset::Beginning @@ -264,7 +358,8 @@ impl SourceOperator for KafkaSourceOperator { } async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { - self.init_and_assign_consumer(ctx).await?; + let saved = load_saved_offsets_if_recovering(ctx); + self.init_and_assign_consumer(ctx, saved).await?; self.rate_limiter = Some(GovernorRateLimiter::direct(Quota::per_second( self.messages_per_second, ))); @@ -363,10 +458,13 @@ impl SourceOperator for KafkaSourceOperator { async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, ctx: &mut TaskContext, - ) -> Result<()> { - debug!("Source [{}] executing checkpoint", ctx.subtask_index); + ) -> Result { + debug!( + "Source [{}] executing checkpoint epoch {}", + ctx.subtask_index, barrier.epoch + ); let mut topic_partitions = TopicPartitionList::new(); for (&partition, &offset) in &self.current_offsets { @@ -381,7 +479,27 @@ impl SourceOperator for KafkaSourceOperator { warn!("Failed to commit async offset to Kafka Broker: {:?}", e); } - Ok(()) + let epoch = u64::from(barrier.epoch); + if self.current_offsets.is_empty() { + return Ok(SourceCheckpointReport::default()); + } + + let kafka_subtask = { + let mut parts: Vec<(i32, i64)> = + self.current_offsets.iter().map(|(&p, &o)| (p, o)).collect(); + parts.sort_by_key(|x| x.0); + KafkaSourceSubtaskCheckpoint { + pipeline_id: ctx.pipeline_id, + subtask_index: ctx.subtask_index, + checkpoint_epoch: epoch, + partitions: parts + .into_iter() + .map(|(partition, offset)| KafkaPartitionOffset { partition, offset }) + .collect(), + } + }; + + Ok(SourceCheckpointReport::from_kafka_checkpoint(kafka_subtask)) } async fn on_close(&mut self, _ctx: &mut TaskContext) -> Result<()> { diff --git a/src/runtime/streaming/operators/value_execution.rs b/src/runtime/streaming/operators/value_execution.rs index ff952dda..b93cd78b 100644 --- a/src/runtime/streaming/operators/value_execution.rs +++ b/src/runtime/streaming/operators/value_execution.rs @@ -17,7 +17,7 @@ use futures::StreamExt; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::operators::StatelessPhysicalExecutor; use crate::sql::common::{CheckpointBarrier, Watermark}; @@ -43,26 +43,31 @@ impl Operator for ValueExecutionOperator { _input_idx: usize, batch: RecordBatch, _ctx: &mut TaskContext, - ) -> Result> { - let mut outputs = Vec::new(); - + collector: &mut dyn Collector, + ) -> Result<()> { let mut stream = self.executor.process_batch(batch).await?; while let Some(batch_result) = stream.next().await { let out_batch = batch_result?; if out_batch.num_rows() > 0 { - outputs.push(StreamOutput::Forward(out_batch)); + collector + .collect(StreamOutput::Forward(out_batch), _ctx) + .await?; } } - Ok(outputs) + Ok(()) } async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![StreamOutput::Watermark(watermark)]) + collector: &mut dyn Collector, + ) -> Result<()> { + collector + .collect(StreamOutput::Watermark(watermark), _ctx) + .await?; + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/watermark/watermark_generator.rs b/src/runtime/streaming/operators/watermark/watermark_generator.rs index b74a92f2..497553eb 100644 --- a/src/runtime/streaming/operators/watermark/watermark_generator.rs +++ b/src/runtime/streaming/operators/watermark/watermark_generator.rs @@ -23,11 +23,11 @@ use datafusion_proto::protobuf::PhysicalExprNode; use prost::Message; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use tracing::{debug, info}; +use tracing::debug; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_millis}; use async_trait::async_trait; @@ -107,10 +107,6 @@ impl Operator for WatermarkGeneratorOperator { "ExpressionWatermarkGenerator" } - fn tick_interval(&self) -> Option { - Some(Duration::from_secs(1)) - } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { self.last_event_wall = SystemTime::now(); Ok(()) @@ -121,13 +117,16 @@ impl Operator for WatermarkGeneratorOperator { _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { self.last_event_wall = SystemTime::now(); - let mut outputs = vec![StreamOutput::Forward(batch.clone())]; + collector + .collect(StreamOutput::Forward(batch.clone()), ctx) + .await?; let Some(max_batch_ts) = self.extract_max_timestamp(&batch) else { - return Ok(outputs); + return Ok(()); }; let new_watermark = self.evaluate_watermark(&batch)?; @@ -145,42 +144,27 @@ impl Operator for WatermarkGeneratorOperator { to_millis(self.state.max_watermark) ); - outputs.push(StreamOutput::Watermark(Watermark::EventTime( - self.state.max_watermark, - ))); + collector + .collect( + StreamOutput::Watermark(Watermark::EventTime(self.state.max_watermark)), + ctx, + ) + .await?; self.state.last_watermark_emitted_at = max_batch_ts; self.is_idle = false; } - Ok(outputs) + Ok(()) } async fn process_watermark( &mut self, _watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { - Ok(vec![]) - } - - async fn process_tick( - &mut self, - _tick_index: u64, - ctx: &mut TaskContext, - ) -> Result> { - if let Some(idle_timeout) = self.idle_time { - let elapsed = self.last_event_wall.elapsed().unwrap_or(Duration::ZERO); - if !self.is_idle && elapsed > idle_timeout { - info!( - "task [{}] entering Idle after {:?}", - ctx.subtask_index, idle_timeout - ); - self.is_idle = true; - return Ok(vec![StreamOutput::Watermark(Watermark::Idle)]); - } - } - Ok(vec![]) + _collector: &mut dyn Collector, + ) -> Result<()> { + Ok(()) } async fn snapshot_state( diff --git a/src/runtime/streaming/operators/windows/session_aggregating_window.rs b/src/runtime/streaming/operators/windows/session_aggregating_window.rs index 4293ea7c..2da2c285 100644 --- a/src/runtime/streaming/operators/windows/session_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/session_aggregating_window.rs @@ -30,15 +30,17 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; +use tracing::info; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::converter::Converter; use crate::sql::common::{ CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos, to_nanos, @@ -170,6 +172,7 @@ impl ActiveSession { } } +#[derive(Clone)] struct SessionWindowResult { window_start: SystemTime, window_end: SystemTime, @@ -389,9 +392,39 @@ pub struct SessionWindowOperator { session_states: HashMap, KeySessionState>, pq_watermark_actions: BTreeMap>>, pq_start_times: BTreeMap>>, + + // LSM-Tree state engine and per-routing-key timestamp index + state_store: Option>, + pending_timestamps: HashMap, BTreeSet>, } impl SessionWindowOperator { + // State key: [RoutingKey bytes] + [8-byte big-endian timestamp] + fn build_state_key(routing_key: &[u8], ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(routing_key.len() + 8); + key.extend_from_slice(routing_key); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() >= 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[key.len() - 8..]); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + + fn extract_routing_key(key: &[u8]) -> Vec { + if key.len() >= 8 { + key[..key.len() - 8].to_vec() + } else { + Vec::new() + } + } + fn filter_batch_by_time( &self, batch: RecordBatch, @@ -430,6 +463,7 @@ impl SessionWindowOperator { &mut self, sorted_batch: RecordBatch, watermark: Option, + is_recovery_replay: bool, ) -> Result<()> { let partition_ranges = if !self.config.input_schema_ref.has_routing_keys() { std::iter::once(0..sorted_batch.num_rows()).collect::>() @@ -470,6 +504,32 @@ impl SessionWindowOperator { .to_vec() }; + // Write-ahead persistence: skip during recovery replay to avoid duplicate writes + if !is_recovery_replay { + let ts_col = key_batch + .column(self.config.input_schema_ref.timestamp_index) + .as_any() + .downcast_ref::() + .unwrap(); + let ts_nanos = ts_col.value(0) as u64; + + let state_key = Self::build_state_key(&row_key, ts_nanos); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + store + .put(state_key, key_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + + self.pending_timestamps + .entry(row_key.clone()) + .or_default() + .insert(ts_nanos); + } + let state = self .session_states .entry(row_key.clone()) @@ -529,7 +589,10 @@ impl SessionWindowOperator { Ok(()) } - async fn evaluate_watermark(&mut self, watermark: SystemTime) -> Result> { + async fn evaluate_watermark_with_meta( + &mut self, + watermark: SystemTime, + ) -> Result, Vec)>> { let mut emit_results: Vec<(Vec, Vec)> = Vec::new(); loop { @@ -588,11 +651,7 @@ impl SessionWindowOperator { } } - if emit_results.is_empty() { - return Ok(vec![]); - } - - Ok(vec![self.format_to_arrow(emit_results)?]) + Ok(emit_results) } fn format_to_arrow( @@ -666,52 +725,168 @@ impl Operator for SessionWindowOperator { "SessionWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery & event sourcing: rebuild in-memory sessions from LSM-Tree + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Session Operator recovering active state keys from LSM-Tree..." + ); + + let mut recovered_batches = Vec::new(); + + for key in active_keys { + if let Some(ts) = Self::extract_timestamp(&key) { + let row_key = Self::extract_routing_key(&key); + self.pending_timestamps + .entry(row_key) + .or_default() + .insert(ts); + } + + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + recovered_batches.extend(batches); + } + + // Temporal ordering is critical: replay must preserve watermark/session merge invariants + recovered_batches.sort_by_key(|b| { + b.column(self.config.input_schema_ref.timestamp_index) + .as_any() + .downcast_ref::() + .map(|ts| ts.value(0)) + .unwrap_or(0) + }); + + for batch in recovered_batches { + self.ingest_sorted_batch(batch, None, true).await?; + } + + info!( + pipeline_id = ctx.pipeline_id, + "Session Window Operator successfully replayed events and rebuilt in-memory sessions." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data before in-memory ingestion async fn process_data( &mut self, _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let watermark_time = ctx.current_watermark(); let filtered_batch = self.filter_batch_by_time(batch, watermark_time)?; if filtered_batch.num_rows() == 0 { - return Ok(vec![]); + return Ok(()); } let sorted_batch = self.sort_batch(&filtered_batch)?; - self.ingest_sorted_batch(sorted_batch, watermark_time) + self.ingest_sorted_batch(sorted_batch, watermark_time, false) .await?; - Ok(vec![]) + Ok(()) } + // Watermark-driven session closure with precise LSM-Tree garbage collection async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; - let output_batches = self.evaluate_watermark(current_time).await?; - Ok(output_batches - .into_iter() - .map(StreamOutput::Forward) - .collect()) + let completed_sessions = self.evaluate_watermark_with_meta(current_time).await?; + if completed_sessions.is_empty() { + return Ok(()); + } + + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + + // GC: tombstone expired raw data covered by closed sessions + for (row_key, session_results) in &completed_sessions { + if let Some(ts_set) = self.pending_timestamps.get_mut(row_key) { + for session_res in session_results { + let start_nanos = to_nanos(session_res.window_start) as u64; + let end_nanos = to_nanos(session_res.window_end - self.config.gap) as u64; + + let expired_ts: Vec = + ts_set.range(start_nanos..=end_nanos).copied().collect(); + + for ts in expired_ts { + let state_key = Self::build_state_key(row_key, ts); + store + .remove_batches(state_key) + .map_err(|e| anyhow!("{e}"))?; + ts_set.remove(&ts); + } + } + } + } + + let output_batch = self.format_to_arrow(completed_sessions)?; + collector + .collect(StreamOutput::Forward(output_batch), _ctx) + .await?; + Ok(()) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .prepare_checkpoint_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + + info!( + epoch = barrier.epoch, + "Session Window Operator snapshotted state." + ); + Ok(()) + } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; Ok(()) } @@ -797,6 +972,8 @@ impl SessionAggregatingWindowConstructor { pq_start_times: BTreeMap::new(), pq_watermark_actions: BTreeMap::new(), row_converter, + state_store: None, + pending_timestamps: HashMap::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs index 73ba4dc9..3516e950 100644 --- a/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/sliding_aggregating_window.rs @@ -11,7 +11,7 @@ // limitations under the License. use anyhow::{Result, anyhow, bail}; -use arrow::compute::{partition, sort_to_indices, take}; +use arrow::compute::{concat_batches, partition, sort_to_indices, take}; use arrow_array::{Array, PrimitiveArray, RecordBatch, types::TimestampNanosecondType}; use arrow_schema::SchemaRef; use datafusion::common::ScalarValue; @@ -27,20 +27,49 @@ use datafusion_proto::{ }; use futures::StreamExt; use prost::Message; -use std::collections::{BTreeMap, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; +use tracing::info; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_nanos}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; use protocol::function_stream_graph::SlidingWindowAggregateOperator; // ============================================================================ +// Dual-layer state key: [StateType(1B)] + [Timestamp(8B BE)] +// STATE_TYPE_RAW = 0 (raw input data, pending partial aggregation) +// STATE_TYPE_PARTIAL = 1 (pre-aggregated pane results) +// ============================================================================ + +const STATE_TYPE_RAW: u8 = 0; +const STATE_TYPE_PARTIAL: u8 = 1; + +fn build_state_key(state_type: u8, ts_nanos: u64) -> Vec { + let mut key = Vec::with_capacity(9); + key.push(state_type); + key.extend_from_slice(&ts_nanos.to_be_bytes()); + key +} + +fn parse_state_key(key: &[u8]) -> Option<(u8, u64)> { + if key.len() == 9 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(&key[1..9]); + Some((key[0], u64::from_be_bytes(ts_bytes))) + } else { + None + } +} + +// ============================================================================ +// RecordBatchTier & TieredRecordBatchHolder // ============================================================================ #[derive(Default, Debug)] @@ -263,6 +292,11 @@ pub struct SlidingWindowOperator { active_bins: BTreeMap, tiered_record_batches: TieredRecordBatchHolder, + + // LSM-Tree state engine with dual-layer index + state_store: Option>, + pending_raw_bins: BTreeSet, + pending_partial_bins: BTreeSet, } impl SlidingWindowOperator { @@ -309,16 +343,89 @@ impl Operator for SlidingWindowOperator { "SlidingWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: restore dual-layer state (partial panes + raw active bins) + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + let mut raw_recovery_batches = Vec::new(); + + for key in active_keys { + if let Some((state_type, ts_nanos)) = parse_state_key(&key) { + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + if batches.is_empty() { + continue; + } + + if state_type == STATE_TYPE_PARTIAL { + let bin_start = from_nanos(ts_nanos as u128); + for b in batches { + self.tiered_record_batches.insert(b, bin_start)?; + } + self.pending_partial_bins.insert(ts_nanos); + } else if state_type == STATE_TYPE_RAW { + let schema = batches[0].schema(); + let combined = concat_batches(&schema, &batches)?; + raw_recovery_batches.push((ts_nanos, combined)); + } + } + } + + // Temporal ordering guarantees correct DataFusion session replay + raw_recovery_batches.sort_by_key(|(ts, _)| *ts); + + for (ts_nanos, batch) in raw_recovery_batches { + let bin_start = from_nanos(ts_nanos as u128); + let slot = self.active_bins.entry(bin_start).or_default(); + Self::ensure_bin_running( + slot, + self.partial_aggregation_plan.clone(), + &self.receiver_hook, + )?; + + slot.sender + .as_ref() + .unwrap() + .send(batch) + .map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.insert(ts_nanos); + } + + info!( + pipeline_id = ctx.pipeline_id, + partial_bins = self.pending_partial_bins.len(), + raw_bins = self.pending_raw_bins.len(), + "Sliding Window Operator recovered state." + ); + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data (Type 0) before in-memory computation async fn process_data( &mut self, _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let bin_array = self .binning_function .evaluate(&batch)? @@ -340,6 +447,10 @@ impl Operator for SlidingWindowOperator { let partition_ranges = partition(std::slice::from_ref(&sorted_bins))?.ranges(); let watermark = ctx.current_watermark(); + let store = self + .state_store + .clone() + .expect("State store not initialized"); for range in partition_ranges { let bin_start = from_nanos(typed_bin.value(range.start) as u128); @@ -351,8 +462,16 @@ impl Operator for SlidingWindowOperator { } let bin_batch = sorted.slice(range.start, range.end - range.start); - let slot = self.active_bins.entry(bin_start).or_default(); + let bin_start_nanos = to_nanos(bin_start) as u64; + let key = build_state_key(STATE_TYPE_RAW, bin_start_nanos); + store + .put(key, bin_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.insert(bin_start_nanos); + + let slot = self.active_bins.entry(bin_start).or_default(); Self::ensure_bin_running( slot, self.partial_aggregation_plan.clone(), @@ -368,20 +487,24 @@ impl Operator for SlidingWindowOperator { .map_err(|e| anyhow!("partial channel send: {e}"))?; } - Ok(vec![]) + Ok(()) } + // State morphing (Type 0 → Type 1) and dual-layer GC async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; let watermark_bin = self.bin_start(current_time); - - let mut final_outputs = Vec::new(); + let store = self + .state_store + .clone() + .expect("State store not initialized"); let mut expired_bins = Vec::new(); for &k in self.active_bins.keys() { @@ -398,12 +521,34 @@ impl Operator for SlidingWindowOperator { .remove(&bin_start) .ok_or_else(|| anyhow!("missing active bin"))?; let bin_end = bin_start + self.slide; + let bin_start_nanos = to_nanos(bin_start) as u64; + // Phase 1: drain partial aggregation from DataFusion bin.close_and_drain().await?; - for b in bin.finished_batches { - self.tiered_record_batches.insert(b, bin_start)?; + + // Phase 2: state morphing — persist partial result (Type 1), feed tiered holder + if !bin.finished_batches.is_empty() { + let schema = bin.finished_batches[0].schema(); + let combined_partial = concat_batches(&schema, &bin.finished_batches)?; + + let p_key = build_state_key(STATE_TYPE_PARTIAL, bin_start_nanos); + store + .put(p_key, combined_partial) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_partial_bins.insert(bin_start_nanos); + + for b in bin.finished_batches { + self.tiered_record_batches.insert(b, bin_start)?; + } } + // Phase 3: tombstone raw data (Type 0) — no longer needed after partial is saved + let r_key = build_state_key(STATE_TYPE_RAW, bin_start_nanos); + store.remove_batches(r_key).map_err(|e| anyhow!("{e}"))?; + self.pending_raw_bins.remove(&bin_start_nanos); + + // Phase 4: compute final sliding window result let interval_start = bin_end - self.width; let interval_end = bin_end; @@ -433,21 +578,52 @@ impl Operator for SlidingWindowOperator { .execute(0, SessionContext::new().task_ctx())?; while let Some(batch) = proj_exec.next().await { - final_outputs.push(StreamOutput::Forward(batch?)); + collector + .collect(StreamOutput::Forward(batch?), _ctx) + .await?; } - self.tiered_record_batches - .delete_before(bin_end + self.slide - self.width)?; + // Phase 5: GC expired partial bins (Type 1) that fall outside the window + let cutoff_time = bin_end + self.slide - self.width; + self.tiered_record_batches.delete_before(cutoff_time)?; + + let cutoff_nanos = to_nanos(cutoff_time) as u64; + let expired_partials: Vec = self + .pending_partial_bins + .iter() + .take_while(|&&ts| ts < cutoff_nanos) + .copied() + .collect(); + + for ts in expired_partials { + let p_key = build_state_key(STATE_TYPE_PARTIAL, ts); + store.remove_batches(p_key).map_err(|e| anyhow!("{e}"))?; + self.pending_partial_bins.remove(&ts); + } } - Ok(final_outputs) + Ok(()) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .prepare_checkpoint_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + Ok(()) + } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; Ok(()) } @@ -531,6 +707,9 @@ impl SlidingAggregatingWindowConstructor { final_batches_passer, active_bins: BTreeMap::new(), tiered_record_batches: TieredRecordBatchHolder::new(vec![slide])?, + state_store: None, + pending_raw_bins: BTreeSet::new(), + pending_partial_bins: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs index de576bf0..6b6b6029 100644 --- a/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs +++ b/src/runtime/streaming/operators/windows/tumbling_aggregating_window.rs @@ -27,17 +27,18 @@ use datafusion_proto::{ }; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use std::mem; use std::sync::{Arc, RwLock}; use std::time::{Duration, SystemTime}; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::time_utils::print_time; use crate::sql::common::{CheckpointBarrier, FsSchema, Watermark, from_nanos, to_nanos}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; @@ -94,9 +95,28 @@ pub struct TumblingWindowOperator { final_batches_passer: Arc>>, active_bins: BTreeMap, + + // LSM-Tree state engine and pending window timestamp index + state_store: Option>, + pending_bins: BTreeSet, } impl TumblingWindowOperator { + // State key: 8-byte big-endian bin_start_nanos + fn build_state_key(ts_nanos: u64) -> Vec { + ts_nanos.to_be_bytes().to_vec() + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(key); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + fn bin_start(&self, timestamp: SystemTime) -> SystemTime { if self.width == Duration::ZERO { return timestamp; @@ -141,16 +161,79 @@ impl Operator for TumblingWindowOperator { "TumblingWindow" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: replay raw data from LSM-Tree into DataFusion sessions + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Tumbling Window Operator recovering active windows from LSM-Tree..." + ); + + for key in active_keys { + if let Some(ts_nanos) = Self::extract_timestamp(&key) { + let bin_start = from_nanos(ts_nanos as u128); + + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + if batches.is_empty() { + continue; + } + + let slot = self.active_bins.entry(bin_start).or_default(); + Self::ensure_bin_running( + slot, + self.partial_aggregation_plan.clone(), + &self.receiver_hook, + )?; + + let sender = slot.sender.as_ref().unwrap(); + for batch in batches { + sender + .send(batch) + .map_err(|e| anyhow!("recovery channel send: {e}"))?; + } + + self.pending_bins.insert(ts_nanos); + } + } + + info!( + pipeline_id = ctx.pipeline_id, + "Tumbling Window Operator successfully replayed events and rebuilt in-memory state." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist raw data before in-memory computation async fn process_data( &mut self, _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let bin_array = self .binning_function .evaluate(&batch)? @@ -171,6 +254,11 @@ impl Operator for TumblingWindowOperator { .ok_or_else(|| anyhow!("binning function must produce TimestampNanosecond"))?; let partition_ranges = partition(std::slice::from_ref(&sorted_bins))?.ranges(); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + for range in partition_ranges { let bin_start = from_nanos(typed_bin.value(range.start) as u128); @@ -186,8 +274,16 @@ impl Operator for TumblingWindowOperator { } let bin_batch = sorted.slice(range.start, range.end - range.start); - let slot = self.active_bins.entry(bin_start).or_default(); + let bin_start_nanos = to_nanos(bin_start) as u64; + + let state_key = Self::build_state_key(bin_start_nanos); + store + .put(state_key, bin_batch.clone()) + .await + .map_err(|e| anyhow!("{e}"))?; + self.pending_bins.insert(bin_start_nanos); + let slot = self.active_bins.entry(bin_start).or_default(); Self::ensure_bin_running( slot, self.partial_aggregation_plan.clone(), @@ -203,19 +299,23 @@ impl Operator for TumblingWindowOperator { .map_err(|e| anyhow!("partial channel send: {e}"))?; } - Ok(vec![]) + Ok(()) } + // Watermark-driven window closure with LSM-Tree GC async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; - - let mut final_outputs = Vec::new(); + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); let mut expired_bins = Vec::new(); for &k in self.active_bins.keys() { @@ -227,10 +327,8 @@ impl Operator for TumblingWindowOperator { } for bin_start in expired_bins { - let mut bin = self - .active_bins - .remove(&bin_start) - .ok_or_else(|| anyhow!("missing tumbling bin"))?; + let mut bin = self.active_bins.remove(&bin_start).unwrap(); + let bin_start_nanos = to_nanos(bin_start) as u64; bin.close_and_drain().await?; let partial_batches = mem::take(&mut bin.finished_batches); @@ -255,7 +353,9 @@ impl Operator for TumblingWindowOperator { )?; if self.final_projection.is_none() { - final_outputs.push(StreamOutput::Forward(with_timestamp)); + collector + .collect(StreamOutput::Forward(with_timestamp), _ctx) + .await?; } else { aggregate_results.push(with_timestamp); } @@ -268,19 +368,42 @@ impl Operator for TumblingWindowOperator { final_projection.execute(0, SessionContext::new().task_ctx())?; while let Some(batch) = proj_exec.next().await { - final_outputs.push(StreamOutput::Forward(batch?)); + collector + .collect(StreamOutput::Forward(batch?), _ctx) + .await?; } } + + // Tombstone the raw data — window is fully closed + let state_key = Self::build_state_key(bin_start_nanos); + store + .remove_batches(state_key) + .map_err(|e| anyhow!("{e}"))?; + self.pending_bins.remove(&bin_start_nanos); } - Ok(final_outputs) + Ok(()) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .prepare_checkpoint_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + Ok(()) + } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; Ok(()) } @@ -367,6 +490,8 @@ impl TumblingAggregateWindowConstructor { receiver_hook, final_batches_passer, active_bins: BTreeMap::new(), + state_store: None, + pending_bins: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/operators/windows/window_function.rs b/src/runtime/streaming/operators/windows/window_function.rs index 5e340fec..1249233e 100644 --- a/src/runtime/streaming/operators/windows/window_function.rs +++ b/src/runtime/streaming/operators/windows/window_function.rs @@ -13,7 +13,6 @@ use anyhow::{Result, anyhow}; use arrow::compute::{max, min}; use arrow_array::RecordBatch; -use datafusion::execution::SendableRecordBatchStream; use datafusion::execution::context::SessionContext; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::ExecutionPlan; @@ -21,57 +20,26 @@ use datafusion_proto::physical_plan::AsExecutionPlan; use datafusion_proto::protobuf::PhysicalPlanNode; use futures::StreamExt; use prost::Message; -use std::collections::BTreeMap; +use std::collections::BTreeSet; use std::sync::{Arc, RwLock}; use std::time::SystemTime; -use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel}; -use tracing::warn; +use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel}; +use tracing::{info, warn}; use crate::runtime::streaming::StreamOutput; use crate::runtime::streaming::api::context::TaskContext; -use crate::runtime::streaming::api::operator::Operator; +use crate::runtime::streaming::api::operator::{Collector, Operator}; use crate::runtime::streaming::factory::Registry; +use crate::runtime::streaming::state::OperatorStateStore; use crate::sql::common::time_utils::print_time; -use crate::sql::common::{CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos}; +use crate::sql::common::{ + CheckpointBarrier, FsSchema, FsSchemaRef, Watermark, from_nanos, to_nanos, +}; use crate::sql::physical::{StreamingDecodingContext, StreamingExtensionCodec}; use async_trait::async_trait; // ============================================================================ -// ============================================================================ - -struct ActiveWindowExec { - sender: Option>, - result_stream: Option, -} - -impl ActiveWindowExec { - fn new( - plan: Arc, - hook: &Arc>>>, - ) -> Result { - let (tx, rx) = unbounded_channel(); - *hook.write().unwrap() = Some(rx); - plan.reset()?; - let result_stream = plan.execute(0, SessionContext::new().task_ctx())?; - Ok(Self { - sender: Some(tx), - result_stream: Some(result_stream), - }) - } - - async fn close_and_drain(&mut self) -> Result> { - self.sender.take(); - let mut results = Vec::new(); - if let Some(mut stream) = self.result_stream.take() { - while let Some(batch) = stream.next().await { - results.push(batch?); - } - } - Ok(results) - } -} - -// ============================================================================ +// WindowFunctionOperator: LSM-Tree backed lazy-compute model // ============================================================================ pub struct WindowFunctionOperator { @@ -79,10 +47,28 @@ pub struct WindowFunctionOperator { input_schema_unkeyed: FsSchemaRef, window_exec_plan: Arc, receiver_hook: Arc>>>, - active_execs: BTreeMap, + + // LSM-Tree state engine and lightweight timestamp index + state_store: Option>, + pending_timestamps: BTreeSet, } impl WindowFunctionOperator { + // State key: 8-byte big-endian timestamp (nanos) + fn build_state_key(ts_nanos: u64) -> Vec { + ts_nanos.to_be_bytes().to_vec() + } + + fn extract_timestamp(key: &[u8]) -> Option { + if key.len() == 8 { + let mut ts_bytes = [0u8; 8]; + ts_bytes.copy_from_slice(key); + Some(u64::from_be_bytes(ts_bytes)) + } else { + None + } + } + fn filter_and_split_batches( &self, batch: RecordBatch, @@ -137,18 +123,6 @@ impl WindowFunctionOperator { } Ok(batches) } - - fn get_or_create_exec(&mut self, timestamp: SystemTime) -> Result<&mut ActiveWindowExec> { - use std::collections::btree_map::Entry; - match self.active_execs.entry(timestamp) { - Entry::Vacant(v) => { - let new_exec = - ActiveWindowExec::new(self.window_exec_plan.clone(), &self.receiver_hook)?; - Ok(v.insert(new_exec)) - } - Entry::Occupied(o) => Ok(o.into_mut()), - } - } } #[async_trait] @@ -157,70 +131,155 @@ impl Operator for WindowFunctionOperator { "WindowFunction" } - async fn on_start(&mut self, _ctx: &mut TaskContext) -> Result<()> { + // Recovery: restore the lightweight timestamp index from LSM-Tree. + // Data stays on disk until process_watermark triggers on-demand compute. + async fn on_start(&mut self, ctx: &mut TaskContext) -> Result<()> { + let pipeline_block = ctx + .pipeline_state_memory_block + .as_ref() + .ok_or_else(|| anyhow!("missing pipeline state memory block"))?; + let store = OperatorStateStore::new( + ctx.pipeline_id, + ctx.state_dir.clone(), + ctx.io_manager.clone(), + Arc::clone(pipeline_block), + ctx.operator_state_memory_bytes, + ) + .map_err(|e| anyhow!("Failed to init state store: {e}"))?; + + let safe_epoch = ctx.latest_safe_epoch(); + let active_keys = store + .restore_metadata(safe_epoch) + .await + .map_err(|e| anyhow!("State recovery failed: {e}"))?; + + if !active_keys.is_empty() { + info!( + pipeline_id = ctx.pipeline_id, + key_count = active_keys.len(), + "Window Function Operator recovering active timestamps from LSM-Tree..." + ); + + for key in active_keys { + if let Some(ts_nanos) = Self::extract_timestamp(&key) { + self.pending_timestamps.insert(ts_nanos); + } + } + + info!( + pipeline_id = ctx.pipeline_id, + "Window Function Operator successfully rebuilt in-memory indices." + ); + } + + self.state_store = Some(store); Ok(()) } + // Write-ahead: persist data into LSM-Tree, defer computation to watermark async fn process_data( &mut self, _input_idx: usize, batch: RecordBatch, ctx: &mut TaskContext, - ) -> Result> { + _collector: &mut dyn Collector, + ) -> Result<()> { let current_watermark = ctx.current_watermark(); let split_batches = self.filter_and_split_batches(batch, current_watermark)?; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); for (sub_batch, timestamp) in split_batches { - let exec = self.get_or_create_exec(timestamp)?; - exec.sender - .as_ref() - .ok_or_else(|| anyhow!("window exec sender missing"))? - .send(sub_batch) - .map_err(|e| anyhow!("route batch to plan: {e}"))?; + let ts_nanos = to_nanos(timestamp) as u64; + let key = Self::build_state_key(ts_nanos); + + store + .put(key, sub_batch) + .await + .map_err(|e| anyhow!("{e}"))?; + + self.pending_timestamps.insert(ts_nanos); } - Ok(vec![]) + Ok(()) } + // On-demand compute & GC: pull data from LSM-Tree, run DataFusion, tombstone async fn process_watermark( &mut self, watermark: Watermark, _ctx: &mut TaskContext, - ) -> Result> { + collector: &mut dyn Collector, + ) -> Result<()> { let Watermark::EventTime(current_time) = watermark else { - return Ok(vec![]); + return Ok(()); }; - - let mut final_outputs = Vec::new(); - - let mut expired_timestamps = Vec::new(); - for &k in self.active_execs.keys() { - if k < current_time { - expired_timestamps.push(k); - } else { - break; + let store = self + .state_store + .as_ref() + .expect("State store not initialized"); + let current_nanos = to_nanos(current_time) as u64; + + let expired_ts: Vec = self + .pending_timestamps + .iter() + .take_while(|&&ts| ts < current_nanos) + .copied() + .collect(); + + for ts in expired_ts { + let key = Self::build_state_key(ts); + + let batches = store.get_batches(&key).await.map_err(|e| anyhow!("{e}"))?; + + if !batches.is_empty() { + let (tx, rx) = unbounded_channel(); + *self.receiver_hook.write().unwrap() = Some(rx); + + self.window_exec_plan.reset()?; + let mut stream = self + .window_exec_plan + .execute(0, SessionContext::new().task_ctx())?; + + for batch in batches { + tx.send(batch) + .map_err(|e| anyhow!("Failed to send batch to execution plan: {e}"))?; + } + drop(tx); + + while let Some(res) = stream.next().await { + collector.collect(StreamOutput::Forward(res?), _ctx).await?; + } } - } - for ts in expired_timestamps { - let mut exec = self - .active_execs - .remove(&ts) - .ok_or_else(|| anyhow!("missing window exec"))?; - let result_batches = exec.close_and_drain().await?; - for batch in result_batches { - final_outputs.push(StreamOutput::Forward(batch)); - } + store.remove_batches(key).map_err(|e| anyhow!("{e}"))?; + self.pending_timestamps.remove(&ts); } - Ok(final_outputs) + Ok(()) } async fn snapshot_state( &mut self, - _barrier: CheckpointBarrier, + barrier: CheckpointBarrier, _ctx: &mut TaskContext, ) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .prepare_checkpoint_epoch(barrier.epoch as u64) + .map_err(|e| anyhow!("Snapshot failed: {e}"))?; + Ok(()) + } + + async fn commit_checkpoint(&mut self, epoch: u32, _ctx: &mut TaskContext) -> Result<()> { + self.state_store + .as_ref() + .expect("State store not initialized") + .commit_checkpoint_epoch(epoch as u64) + .map_err(|e| anyhow!("Commit checkpoint failed: {e}"))?; Ok(()) } @@ -275,7 +334,8 @@ impl WindowFunctionConstructor { input_schema_unkeyed, window_exec_plan, receiver_hook, - active_execs: BTreeMap::new(), + state_store: None, + pending_timestamps: BTreeSet::new(), }) } } diff --git a/src/runtime/streaming/protocol/control.rs b/src/runtime/streaming/protocol/control.rs index 3b23cb09..6d0bc492 100644 --- a/src/runtime/streaming/protocol/control.rs +++ b/src/runtime/streaming/protocol/control.rs @@ -11,6 +11,7 @@ // limitations under the License. use super::event::CheckpointBarrier; +use protocol::storage::SourceCheckpointPayload; use serde::{Deserialize, Serialize}; use std::time::Duration; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -55,11 +56,24 @@ impl From for CheckpointBarrier { #[derive(Debug, Clone, Serialize, Deserialize)] pub enum ControlCommand { Start, - Stop { mode: StopMode }, + Stop { + mode: StopMode, + }, DropState, - Commit { epoch: u32 }, - UpdateConfig { config_json: String }, - TriggerCheckpoint { barrier: CheckpointBarrierWire }, + /// Phase 2 of checkpoint 2PC: metadata durable; transactional Kafka sink should `commit_transaction`. + Commit { + epoch: u32, + }, + /// Roll back pre-committed transactional Kafka writes when checkpoint metadata commit failed or barrier declined. + AbortCheckpoint { + epoch: u32, + }, + UpdateConfig { + config_json: String, + }, + TriggerCheckpoint { + barrier: CheckpointBarrierWire, + }, } impl ControlCommand { @@ -79,3 +93,18 @@ pub enum StopMode { pub fn control_channel(capacity: usize) -> (Sender, Receiver) { mpsc::channel(capacity) } + +#[derive(Debug, Clone)] +pub enum JobMasterEvent { + CheckpointAck { + pipeline_id: u32, + epoch: u64, + /// Source protocol checkpoint payloads (enum-style oneof envelope). + source_payloads: Vec, + }, + CheckpointDecline { + pipeline_id: u32, + epoch: u64, + reason: String, + }, +} diff --git a/src/runtime/streaming/protocol/event.rs b/src/runtime/streaming/protocol/event.rs index 823035f8..21be6852 100644 --- a/src/runtime/streaming/protocol/event.rs +++ b/src/runtime/streaming/protocol/event.rs @@ -17,7 +17,7 @@ use std::time::SystemTime; use arrow_array::RecordBatch; -use crate::runtime::streaming::memory::MemoryTicket; +use crate::runtime::memory::MemoryTicket; #[derive(Debug, Copy, Clone, PartialEq, Eq, Encode, Decode, Serialize, Deserialize)] pub enum Watermark { diff --git a/src/runtime/streaming/protocol/mod.rs b/src/runtime/streaming/protocol/mod.rs index e91e8d8c..28fd85a4 100644 --- a/src/runtime/streaming/protocol/mod.rs +++ b/src/runtime/streaming/protocol/mod.rs @@ -13,4 +13,6 @@ pub mod control; pub mod event; +#[allow(unused_imports)] +pub use control::{ControlCommand, JobMasterEvent, StopMode}; pub use event::{CheckpointBarrier, StreamOutput, Watermark}; diff --git a/src/runtime/streaming/state/error.rs b/src/runtime/streaming/state/error.rs new file mode 100644 index 00000000..37bc6481 --- /dev/null +++ b/src/runtime/streaming/state/error.rs @@ -0,0 +1,51 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crossbeam_channel::TrySendError; +use thiserror::Error; + +use crate::runtime::memory::MemoryAllocationError; + +#[derive(Error, Debug)] +pub enum StateEngineError { + #[error("I/O error during state persistence: {0}")] + IoError(#[from] std::io::Error), + + #[error("Parquet serialization/deserialization failed: {0}")] + ParquetError(#[from] parquet::errors::ParquetError), + + #[error("Arrow computation failed: {0}")] + ArrowError(#[from] arrow::error::ArrowError), + + #[error("Memory hard limit exceeded and spill channel is full")] + MemoryBackpressureTimeout, + + #[error("Background I/O pool has been shut down or disconnected")] + IoPoolDisconnected, + + #[error("State metadata corrupted: {0}")] + Corruption(String), + + #[error("State memory block reservation failed: {0}")] + MemoryReservation(#[from] MemoryAllocationError), +} + +pub type Result = std::result::Result; + +impl From> for StateEngineError { + fn from(err: TrySendError) -> Self { + match err { + TrySendError::Full(_) => StateEngineError::MemoryBackpressureTimeout, + TrySendError::Disconnected(_) => StateEngineError::IoPoolDisconnected, + } + } +} diff --git a/src/runtime/streaming/state/io_manager.rs b/src/runtime/streaming/state/io_manager.rs new file mode 100644 index 00000000..9b37da1d --- /dev/null +++ b/src/runtime/streaming/state/io_manager.rs @@ -0,0 +1,151 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. + +#[allow(unused_imports)] +use super::error::StateEngineError; +use super::metrics::StateMetricsCollector; +use super::operator_state::{MemTable, OperatorStateStore, TombstoneMap}; +use crossbeam_channel::{Receiver, Sender, TrySendError, bounded}; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; +use std::time::Instant; + +pub struct SpillJob { + pub store: Arc, + pub epoch: u64, + pub data: MemTable, + pub tombstone_snapshot: TombstoneMap, +} + +pub enum CompactJob { + Minor { store: Arc }, + Major { store: Arc }, +} + +pub struct IoPool { + spill_tx: Option>, + compact_tx: Option>, + worker_handles: Vec>, +} + +impl IoPool { + pub fn try_new( + spill_threads: usize, + compact_threads: usize, + metrics: Arc, + ) -> std::io::Result<(Self, IoManager)> { + let (spill_tx, spill_rx) = bounded::(1024); + let (compact_tx, compact_rx) = bounded::(256); + let mut worker_handles = Vec::with_capacity(spill_threads + compact_threads); + + for i in 0..spill_threads.max(1) { + let rx = spill_rx.clone(); + let m = metrics.clone(); + let handle = thread::Builder::new() + .name(format!("fs-spill-worker-{i}")) + .spawn(move || spill_worker_loop(rx, m))?; + worker_handles.push(handle); + } + + for i in 0..compact_threads.max(1) { + let rx = compact_rx.clone(); + let m = metrics.clone(); + let handle = thread::Builder::new() + .name(format!("fs-compact-worker-{i}")) + .spawn(move || compact_worker_loop(rx, m))?; + worker_handles.push(handle); + } + + let manager = IoManager { + spill_tx: spill_tx.clone(), + compact_tx: compact_tx.clone(), + }; + + Ok(( + Self { + spill_tx: Some(spill_tx), + compact_tx: Some(compact_tx), + worker_handles, + }, + manager, + )) + } + + pub fn shutdown(mut self) { + tracing::info!("Initiating graceful shutdown for IoPool..."); + self.spill_tx.take(); + self.compact_tx.take(); + for handle in self.worker_handles.drain(..) { + if let Err(e) = handle.join() { + tracing::error!("I/O Worker thread panicked during shutdown: {:?}", e); + } + } + tracing::info!("IoPool graceful shutdown completed."); + } +} + +#[derive(Clone)] +pub struct IoManager { + spill_tx: Sender, + compact_tx: Sender, +} + +impl IoManager { + pub fn try_send_spill(&self, job: SpillJob) -> Result<(), TrySendError> { + self.spill_tx.try_send(job) + } + pub fn try_send_compact(&self, job: CompactJob) -> Result<(), TrySendError> { + self.compact_tx.try_send(job) + } + pub fn pending_spills(&self) -> usize { + self.spill_tx.len() + } +} + +fn spill_worker_loop(rx: Receiver, metrics: Arc) { + while let Ok(job) = rx.recv() { + let op_id = job.store.operator_id; + let epoch = job.epoch; + let start = Instant::now(); + + let result = catch_unwind(AssertUnwindSafe(|| { + job.store + .execute_spill_sync(job.epoch, job.data, job.tombstone_snapshot, &metrics) + })); + + let duration_ms = start.elapsed().as_millis(); + metrics.record_spill_duration(op_id, duration_ms); + + match result { + Ok(Ok(())) => tracing::debug!(op_id, epoch, duration_ms, "Spill success"), + Ok(Err(e)) => tracing::error!(op_id, epoch, duration_ms, %e, "Spill I/O Error"), + Err(_) => tracing::error!(op_id, epoch, "CRITICAL: Spill thread PANICKED! Recovered."), + } + } +} + +fn compact_worker_loop(rx: Receiver, metrics: Arc) { + while let Ok(job) = rx.recv() { + let (store, is_major) = match job { + CompactJob::Minor { store } => (store, false), + CompactJob::Major { store } => (store, true), + }; + + let op_id = store.operator_id; + let start = Instant::now(); + + let result = catch_unwind(AssertUnwindSafe(|| { + store.execute_compact_sync(is_major, &metrics) + })); + + let duration_ms = start.elapsed().as_millis(); + metrics.record_compaction_duration(op_id, is_major, duration_ms); + + match result { + Ok(Ok(())) => tracing::info!(op_id, is_major, duration_ms, "Compaction success"), + Ok(Err(e)) => tracing::error!(op_id, is_major, duration_ms, %e, "Compaction I/O Error"), + Err(_) => tracing::error!(op_id, is_major, "CRITICAL: Compact thread PANICKED!"), + } + } +} diff --git a/src/runtime/streaming/state/metrics.rs b/src/runtime/streaming/state/metrics.rs new file mode 100644 index 00000000..4a86a64f --- /dev/null +++ b/src/runtime/streaming/state/metrics.rs @@ -0,0 +1,18 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. + +pub trait StateMetricsCollector: Send + Sync + 'static { + fn record_memory_usage(&self, operator_id: u32, bytes: u64); + fn record_spill_duration(&self, operator_id: u32, duration_ms: u128); + fn record_compaction_duration(&self, operator_id: u32, is_major: bool, duration_ms: u128); + fn inc_io_errors(&self, operator_id: u32); +} + +/// Default no-op implementation. +pub struct NoopMetricsCollector; +impl StateMetricsCollector for NoopMetricsCollector { + fn record_memory_usage(&self, _: u32, _: u64) {} + fn record_spill_duration(&self, _: u32, _: u128) {} + fn record_compaction_duration(&self, _: u32, _: bool, _: u128) {} + fn inc_io_errors(&self, _: u32) {} +} diff --git a/src/runtime/streaming/memory/mod.rs b/src/runtime/streaming/state/mod.rs similarity index 59% rename from src/runtime/streaming/memory/mod.rs rename to src/runtime/streaming/state/mod.rs index 45fc3194..7d5bb3ef 100644 --- a/src/runtime/streaming/memory/mod.rs +++ b/src/runtime/streaming/state/mod.rs @@ -10,8 +10,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub mod pool; -pub mod ticket; +pub mod error; +mod io_manager; +pub mod metrics; +mod operator_state; -pub use pool::MemoryPool; -pub use ticket::MemoryTicket; +#[allow(unused_imports)] +pub use error::{Result, StateEngineError}; +#[allow(unused_imports)] +pub use io_manager::{CompactJob, IoManager, IoPool, SpillJob}; +#[allow(unused_imports)] +pub use metrics::{NoopMetricsCollector, StateMetricsCollector}; +#[allow(unused_imports)] +pub use operator_state::OperatorStateStore; diff --git a/src/runtime/streaming/state/operator_state.rs b/src/runtime/streaming/state/operator_state.rs new file mode 100644 index 00000000..a3514461 --- /dev/null +++ b/src/runtime/streaming/state/operator_state.rs @@ -0,0 +1,1176 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::error::{Result, StateEngineError}; +use super::io_manager::{CompactJob, IoManager, SpillJob}; +use super::metrics::StateMetricsCollector; +use crate::runtime::memory::{MemoryBlock, MemoryTicket}; +use arrow_array::builder::{BinaryBuilder, BooleanBuilder, UInt64Builder}; +use arrow_array::{Array, BinaryArray, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; +use crossbeam_channel::TrySendError; +use parking_lot::{Mutex, RwLock}; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet::arrow::{ArrowWriter, ProjectionMask}; +use parquet::file::properties::WriterProperties; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::fs::{self, File}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; +use tokio::sync::Notify; +use uuid::Uuid; + +pub(crate) const PARTITION_KEY_COL: &str = "__fs_partition_key"; + +pub type PartitionKey = Vec; +pub type MemTable = HashMap>; +pub type TombstoneMap = HashMap; + +const TOMBSTONE_ENTRY_OVERHEAD: usize = 8 + 16; + +pub struct OperatorStateStore { + pub operator_id: u32, + current_epoch: AtomicU64, + + active_table: RwLock, + immutable_tables: Mutex>, + + data_files: RwLock>, + tombstone_files: RwLock>, + tombstones: RwLock, + + state_ticket: Arc, + state_used: AtomicU64, + state_quota: u64, + soft_limit: u64, + io_manager: IoManager, + + data_dir: PathBuf, + tombstone_dir: PathBuf, + + spill_notify: Arc, + is_spilling: AtomicBool, + is_compacting: AtomicBool, +} + +const DEFAULT_SOFT_LIMIT_RATIO: f64 = 0.7; + +impl OperatorStateStore { + /// `pipeline_state_memory_block` is the pipeline-wide slab reserved at job spawn; this store + /// takes one ticket of `operator_state_memory_bytes` from it. + pub fn new( + operator_id: u32, + base_dir: impl AsRef, + io_manager: IoManager, + pipeline_state_memory_block: Arc, + operator_state_memory_bytes: u64, + ) -> Result> { + let ticket = pipeline_state_memory_block + .try_allocate(operator_state_memory_bytes) + .ok_or_else(|| { + StateEngineError::Corruption( + "pipeline state memory block exhausted (operator state ticket)".into(), + ) + })?; + let state_ticket = Arc::new(ticket); + let soft_limit = (operator_state_memory_bytes as f64 * DEFAULT_SOFT_LIMIT_RATIO) as u64; + + let op_dir = base_dir.as_ref().join(format!("op_{operator_id}")); + let data_dir = op_dir.join("data"); + let tombstone_dir = op_dir.join("tombstones"); + + fs::create_dir_all(&data_dir).map_err(StateEngineError::IoError)?; + fs::create_dir_all(&tombstone_dir).map_err(StateEngineError::IoError)?; + + Ok(Arc::new(Self { + operator_id, + current_epoch: AtomicU64::new(1), + active_table: RwLock::new(HashMap::new()), + immutable_tables: Mutex::new(VecDeque::new()), + data_files: RwLock::new(Vec::new()), + tombstone_files: RwLock::new(Vec::new()), + tombstones: RwLock::new(HashMap::new()), + state_ticket, + state_used: AtomicU64::new(0), + state_quota: operator_state_memory_bytes, + soft_limit, + io_manager, + data_dir, + tombstone_dir, + spill_notify: Arc::new(Notify::new()), + is_spilling: AtomicBool::new(false), + is_compacting: AtomicBool::new(false), + })) + } + + fn state_bytes_used(&self) -> u64 { + self.state_used.load(Ordering::Relaxed) + } + + fn state_should_spill(&self) -> bool { + self.state_bytes_used() > self.soft_limit + } + + fn rebuild_state_used_from_tables(&self) { + let mut n = 0u64; + for rows in self.active_table.read().values() { + for b in rows { + n += b.get_array_memory_size() as u64; + } + } + for (_, table) in self.immutable_tables.lock().iter() { + for rows in table.values() { + for b in rows { + n += b.get_array_memory_size() as u64; + } + } + } + self.state_used.store(n, Ordering::Release); + } + + async fn wait_until_memory_available_async(self: Arc, need: u64) { + while self.state_used.load(Ordering::Relaxed).saturating_add(need) > self.state_quota { + self.trigger_spill(); + self.spill_notify.notified().await; + } + } + + fn wait_until_memory_available_blocking(self: &Arc, need: u64) -> Result<()> { + loop { + if self.state_used.load(Ordering::Relaxed).saturating_add(need) <= self.state_quota { + return Ok(()); + } + self.trigger_spill(); + let start = Instant::now(); + while self.is_spilling.load(Ordering::SeqCst) { + if start.elapsed() > Duration::from_secs(120) { + return Err(StateEngineError::Corruption( + "state memory wait for spill timed out".into(), + )); + } + std::thread::sleep(Duration::from_millis(1)); + } + } + } + + pub async fn put(self: &Arc, key: PartitionKey, batch: RecordBatch) -> Result<()> { + let size = batch.get_array_memory_size() as u64; + self.clone().wait_until_memory_available_async(size).await; + self.state_used.fetch_add(size, Ordering::Relaxed); + self.active_table + .write() + .entry(key) + .or_default() + .push(batch); + + if self.state_should_spill() { + self.downgrade_active_table(self.current_epoch.load(Ordering::Acquire)); + self.trigger_spill(); + } + Ok(()) + } + + pub fn remove_batches(self: &Arc, key: PartitionKey) -> Result<()> { + let current_ep = self.current_epoch.load(Ordering::Acquire); + let tombstone_mem_size = (key.len() + TOMBSTONE_ENTRY_OVERHEAD) as u64; + + { + let mut tb_guard = self.tombstones.write(); + if !tb_guard.contains_key(&key) { + self.wait_until_memory_available_blocking(tombstone_mem_size)?; + self.state_used + .fetch_add(tombstone_mem_size, Ordering::Relaxed); + tb_guard.insert(key.clone(), current_ep); + } + } + + let released_active: u64 = self + .active_table + .write() + .remove(&key) + .map(|rows| rows.iter().map(|b| b.get_array_memory_size() as u64).sum()) + .unwrap_or(0); + + let mut released_imm = 0u64; + for (_, table) in self.immutable_tables.lock().iter_mut() { + if let Some(rows) = table.remove(&key) { + released_imm += rows + .iter() + .map(|b| b.get_array_memory_size() as u64) + .sum::(); + } + } + + let released = released_active.saturating_add(released_imm); + if released > 0 { + self.state_used.fetch_sub(released, Ordering::Relaxed); + } + + Ok(()) + } + + /// Checkpoint phase 1: flush mutable in-memory state into an epoch-tagged immutable table and + /// trigger spill. This does NOT advance `current_epoch`. + pub fn prepare_checkpoint_epoch(self: &Arc, epoch: u64) -> Result<()> { + self.downgrade_active_table(epoch); + self.trigger_spill(); + Ok(()) + } + + /// Checkpoint phase 2: once global metadata commit succeeds, advance the durable safe epoch. + pub fn commit_checkpoint_epoch(self: &Arc, epoch: u64) -> Result<()> { + self.current_epoch + .store(epoch.saturating_add(1), Ordering::Release); + Ok(()) + } + + /// Checkpoint rollback: do not advance `current_epoch`. Any already-spilled files are kept and + /// filtered by safe epoch during restore. + pub fn abort_checkpoint_epoch(self: &Arc, _epoch: u64) -> Result<()> { + Ok(()) + } + + /// Backward-compatible helper (phase1 + phase2 in one call). + pub fn snapshot_epoch(self: &Arc, epoch: u64) -> Result<()> { + self.prepare_checkpoint_epoch(epoch)?; + self.commit_checkpoint_epoch(epoch)?; + Ok(()) + } + + pub async fn await_spill_complete(&self) { + while self.is_spilling.load(Ordering::SeqCst) { + self.spill_notify.notified().await; + } + } + + fn downgrade_active_table(&self, epoch: u64) { + let mut active_guard = self.active_table.write(); + if active_guard.is_empty() { + return; + } + let old_active = std::mem::take(&mut *active_guard); + self.immutable_tables.lock().push_back((epoch, old_active)); + } + + pub async fn get_batches(&self, key: &[u8]) -> Result> { + let deleted_epoch = self.tombstones.read().get(key).copied(); + let mut out = Vec::new(); + + if let Some(batches) = self.active_table.read().get(key) { + out.extend(batches.clone()); + } + + for (table_epoch, table) in self.immutable_tables.lock().iter().rev() { + if let Some(del_ep) = deleted_epoch + && *table_epoch <= del_ep + { + continue; + } + if let Some(batches) = table.get(key) { + out.extend(batches.clone()); + } + } + + let paths: Vec = self.data_files.read().clone(); + if paths.is_empty() { + return Ok(out); + } + + let pk = key.to_vec(); + let merged = tokio::task::spawn_blocking(move || -> Result> { + let mut acc = Vec::new(); + for path in paths { + let file_epoch = extract_epoch(&path); + if let Some(del_ep) = deleted_epoch + && file_epoch <= del_ep + { + continue; + } + + // Native Bloom Filter intercepts empty reads here + let file = File::open(&path).map_err(StateEngineError::IoError)?; + let mut reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + for maybe in reader.by_ref() { + if let Some(filtered) = filter_and_strip_partition_key(&maybe?, &pk)? { + acc.push(filtered); + } + } + } + Ok(acc) + }) + .await + .map_err(|_| StateEngineError::Corruption("Tokio task panicked".into()))??; + + out.extend(merged); + Ok(out) + } + + fn trigger_spill(self: &Arc) { + if !self.is_spilling.swap(true, Ordering::SeqCst) { + let target = self.immutable_tables.lock().pop_front(); + let Some((epoch, data)) = target else { + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + return; + }; + + let tombstone_snapshot = self.tombstones.read().clone(); + let job = SpillJob { + store: self.clone(), + epoch, + data, + tombstone_snapshot, + }; + + match self.io_manager.try_send_spill(job) { + Ok(()) => {} + Err(TrySendError::Full(j)) | Err(TrySendError::Disconnected(j)) => { + self.immutable_tables.lock().push_front((j.epoch, j.data)); + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + } + } + } + } + + pub fn trigger_minor_compaction(self: &Arc) { + if !self.is_compacting.swap(true, Ordering::SeqCst) { + let _ = self.io_manager.try_send_compact(CompactJob::Minor { + store: self.clone(), + }); + } + } + + pub fn trigger_major_compaction(self: &Arc) { + if !self.is_compacting.swap(true, Ordering::SeqCst) { + let _ = self.io_manager.try_send_compact(CompactJob::Major { + store: self.clone(), + }); + } + } + + pub(crate) fn execute_spill_sync( + self: &Arc, + epoch: u64, + data: MemTable, + tombstones: TombstoneMap, + metrics: &Arc, + ) -> Result<()> { + let mut batches_to_write = Vec::new(); + let mut spilled_bytes: u64 = 0; + let distinct_keys_count = data.len() as u64; + + for (key, batches) in data { + for batch in batches { + spilled_bytes += batch.get_array_memory_size() as u64; + batches_to_write.push(inject_partition_key(&batch, &key)?); + } + } + + if !batches_to_write.is_empty() { + let path = self.data_dir.join(Self::generate_data_file_name(epoch)); + if let Err(e) = + write_parquet_with_bloom_atomic(&path, &batches_to_write, distinct_keys_count) + { + metrics.inc_io_errors(self.operator_id); + let restored = restore_memtable_from_injected_batches(batches_to_write)?; + self.immutable_tables.lock().push_front((epoch, restored)); + self.rebuild_state_used_from_tables(); + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + return Err(e); + } + self.data_files.write().push(path); + } + + if !tombstones.is_empty() { + let mut key_builder = BinaryBuilder::new(); + let mut epoch_builder = UInt64Builder::new(); + let tomb_ndv = tombstones.len() as u64; + + for (key, del_epoch) in tombstones.iter() { + key_builder.append_value(key); + epoch_builder.append_value(*del_epoch); + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("deleted_key", DataType::Binary, false), + Field::new("deleted_epoch", DataType::UInt64, false), + ])); + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(key_builder.finish()), + Arc::new(epoch_builder.finish()), + ], + )?; + + let path = self + .tombstone_dir + .join(Self::generate_tombstone_file_name(epoch)); + if let Err(e) = write_parquet_with_bloom_atomic(&path, &[batch], tomb_ndv) { + metrics.inc_io_errors(self.operator_id); + return Err(e); + } + self.tombstone_files.write().push(path); + } + + if spilled_bytes > 0 { + self.state_used.fetch_sub(spilled_bytes, Ordering::Relaxed); + } + + metrics.record_memory_usage(self.operator_id, self.state_bytes_used()); + + self.is_spilling.store(false, Ordering::SeqCst); + self.spill_notify.notify_waiters(); + + if !self.immutable_tables.lock().is_empty() { + self.trigger_spill(); + } + Ok(()) + } + + pub(crate) fn execute_compact_sync( + self: &Arc, + is_major: bool, + metrics: &Arc, + ) -> Result<()> { + let result = (|| -> Result<()> { + let files_to_merge = { + let df = self.data_files.read(); + if df.len() < 2 { + return Ok(()); + } + if is_major { + df.clone() + } else { + df.iter().take(2).cloned().collect() + } + }; + + let tombstone_snapshot = self.tombstones.read().clone(); + let compacted_watermark_epoch = files_to_merge + .iter() + .map(|p| extract_epoch(p)) + .max() + .unwrap_or(0); + let new_path = self + .data_dir + .join(Self::generate_data_file_name(compacted_watermark_epoch)); + + let mut all_batches = Vec::new(); + let mut estimated_rows = 0; + + for path in &files_to_merge { + let file_epoch = extract_epoch(path); + let file = File::open(path).map_err(StateEngineError::IoError)?; + let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + for batch in reader { + let b = batch?; + if let Some(filtered) = + filter_tombstones_from_batch(&b, &tombstone_snapshot, file_epoch)? + { + estimated_rows += filtered.num_rows() as u64; + all_batches.push(filtered); + } + } + } + + if !all_batches.is_empty() { + if let Err(e) = write_parquet_with_bloom_atomic( + &new_path, + &all_batches, + estimated_rows.max(100), + ) { + metrics.inc_io_errors(self.operator_id); + return Err(e); + } + let mut df = self.data_files.write(); + df.retain(|p| !files_to_merge.contains(p)); + df.push(new_path); + } else { + let mut df = self.data_files.write(); + df.retain(|p| !files_to_merge.contains(p)); + } + + for path in &files_to_merge { + let _ = fs::remove_file(path); + } + + { + let mut tg = self.tombstones.write(); + let keys_before: Vec = tg.keys().cloned().collect(); + tg.retain(|_key, deleted_epoch| *deleted_epoch > compacted_watermark_epoch); + let mut tomb_freed = 0u64; + for k in keys_before { + if !tg.contains_key(&k) { + tomb_freed += (k.len() + TOMBSTONE_ENTRY_OVERHEAD) as u64; + } + } + if tomb_freed > 0 { + self.state_used.fetch_sub(tomb_freed, Ordering::Relaxed); + } + metrics.record_memory_usage(self.operator_id, self.state_bytes_used()); + } + + { + let mut tf_guard = self.tombstone_files.write(); + tf_guard.retain(|p| { + if extract_epoch(p) <= compacted_watermark_epoch { + let _ = fs::remove_file(p); + return false; + } + true + }); + } + + Ok(()) + })(); + + self.is_compacting.store(false, Ordering::SeqCst); + result + } + + pub async fn restore_metadata( + self: &Arc, + safe_epoch: u64, + ) -> Result> { + self.state_used.store(0, Ordering::Release); + self.active_table.write().clear(); + self.immutable_tables + .lock() + .retain(|(e, _)| *e <= safe_epoch); + + let cleanup_future = |files: &mut Vec| { + files.retain(|path| { + if extract_epoch(path) > safe_epoch { + let _ = fs::remove_file(path); + false + } else { + true + } + }); + }; + cleanup_future(&mut self.data_files.write()); + cleanup_future(&mut self.tombstone_files.write()); + + let tomb_paths = self.tombstone_files.read().clone(); + type RawTombstones = HashMap; + let raw_tombstones = tokio::task::spawn_blocking(move || -> Result { + let mut map = RawTombstones::new(); + for path in tomb_paths { + let file = File::open(&path).map_err(StateEngineError::IoError)?; + let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?; + for batch in reader { + let batch = batch?; + let key_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let ep_col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + for i in 0..key_col.len() { + let k = key_col.value(i).to_vec(); + let e = ep_col.value(i); + let current_max = map.get(&k).copied().unwrap_or(0); + if e > current_max { + map.insert(k, e); + } + } + } + } + Ok(map) + }) + .await + .map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; + + let tomb_epoch_map = raw_tombstones.clone(); + + *self.tombstones.write() = raw_tombstones; + self.rebuild_state_used_from_tables(); + let tomb_overhead: u64 = self + .tombstones + .read() + .keys() + .map(|k| (k.len() + TOMBSTONE_ENTRY_OVERHEAD) as u64) + .sum(); + if tomb_overhead > 0 { + self.state_used.fetch_add(tomb_overhead, Ordering::Relaxed); + } + + let data_paths = self.data_files.read().clone(); + let active_keys = tokio::task::spawn_blocking(move || -> Result> { + let mut keys = HashSet::new(); + for path in data_paths { + let file_epoch = extract_epoch(&path); + let file = File::open(&path).map_err(StateEngineError::IoError)?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file)?; + let schema = builder.parquet_schema(); + let mask = ProjectionMask::leaves(schema, vec![schema.columns().len() - 1]); + let reader = builder.with_projection(mask).build()?; + + for batch in reader { + let batch = batch?; + let key_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..key_col.len() { + let k = key_col.value(i).to_vec(); + let is_active = match tomb_epoch_map.get(&k) { + Some(del_ep) => *del_ep < file_epoch, + None => true, + }; + if is_active { + keys.insert(k); + } + } + } + } + Ok(keys) + }) + .await + .map_err(|_| StateEngineError::Corruption("Task Panicked".into()))??; + + self.current_epoch.store(safe_epoch + 1, Ordering::Release); + Ok(active_keys) + } + + // ======================================================================== + // UUID-based file name generators + // ======================================================================== + + fn generate_data_file_name(epoch: u64) -> String { + format!("data-epoch-{}_uuid-{}.parquet", epoch, Uuid::now_v7()) + } + + fn generate_tombstone_file_name(epoch: u64) -> String { + format!("tombstone-epoch-{}_uuid-{}.parquet", epoch, Uuid::now_v7()) + } +} + +// ============================================================================ +// Internal helper functions +// ============================================================================ + +fn write_parquet_with_bloom_atomic(path: &Path, batches: &[RecordBatch], ndv: u64) -> Result<()> { + if batches.is_empty() { + return Ok(()); + } + let tmp = path.with_extension("tmp"); + { + let file = File::create(&tmp).map_err(StateEngineError::IoError)?; + let props = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_bloom_filter_ndv(ndv) + .build(); + + let mut writer = ArrowWriter::try_new(&file, batches[0].schema(), Some(props))?; + for b in batches { + writer.write(b)?; + } + writer.close()?; + file.sync_all().map_err(StateEngineError::IoError)?; + } + fs::rename(&tmp, path).map_err(StateEngineError::IoError)?; + Ok(()) +} + +fn extract_epoch(path: &Path) -> u64 { + let name = path + .file_name() + .unwrap_or_default() + .to_str() + .unwrap_or_default(); + if let Some(start) = name.find("-epoch-") { + let after = &name[start + 7..]; + if let Some(end) = after.find("_uuid-") { + return after[..end].parse().unwrap_or(0); + } + } + 0 +} + +fn inject_partition_key(batch: &RecordBatch, key: &[u8]) -> Result { + let mut fields = batch.schema().fields().to_vec(); + fields.push(Arc::new(Field::new( + PARTITION_KEY_COL, + DataType::Binary, + false, + ))); + let schema = Arc::new(Schema::new(fields)); + let key_array = Arc::new(BinaryArray::from_iter_values(std::iter::repeat_n( + key, + batch.num_rows(), + ))); + let mut cols = batch.columns().to_vec(); + cols.push(key_array as Arc); + Ok(RecordBatch::try_new(schema, cols)?) +} + +fn filter_tombstones_from_batch( + batch: &RecordBatch, + tombstones: &TombstoneMap, + file_epoch: u64, +) -> Result> { + if tombstones.is_empty() { + return Ok(Some(batch.clone())); + } + let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { + return Ok(Some(batch.clone())); + }; + + let key_col = batch + .column(idx) + .as_any() + .downcast_ref::() + .unwrap(); + let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows()); + let mut has_valid = false; + + for i in 0..batch.num_rows() { + let key = key_col.value(i).to_vec(); + let keep = match tombstones.get(&key) { + Some(deleted_epoch) => *deleted_epoch < file_epoch, + None => true, + }; + mask_builder.append_value(keep); + if keep { + has_valid = true; + } + } + + if !has_valid { + return Ok(None); + } + let mask = mask_builder.finish(); + Ok(Some(arrow::compute::filter_record_batch(batch, &mask)?)) +} + +fn filter_and_strip_partition_key( + batch: &RecordBatch, + target_key: &[u8], +) -> Result> { + let Ok(idx) = batch.schema().index_of(PARTITION_KEY_COL) else { + return Ok(Some(batch.clone())); + }; + let key_col = batch + .column(idx) + .as_any() + .downcast_ref::() + .unwrap(); + let mut mask_builder = BooleanBuilder::with_capacity(batch.num_rows()); + for i in 0..batch.num_rows() { + mask_builder.append_value(key_col.value(i) == target_key); + } + let mask = mask_builder.finish(); + let filtered = arrow::compute::filter_record_batch(batch, &mask)?; + if filtered.num_rows() == 0 { + return Ok(None); + } + let mut proj: Vec = (0..filtered.num_columns()).collect(); + proj.retain(|&i| i != idx); + Ok(Some(filtered.project(&proj)?)) +} + +fn restore_memtable_from_injected_batches(batches: Vec) -> Result { + let mut m = MemTable::new(); + for batch in batches { + let idx = batch.schema().index_of(PARTITION_KEY_COL).unwrap(); + let key_col = batch + .column(idx) + .as_any() + .downcast_ref::() + .unwrap(); + let pk = key_col.value(0).to_vec(); + let mut proj: Vec = (0..batch.num_columns()).collect(); + proj.retain(|&i| i != idx); + let projected = batch.project(&proj)?; + m.entry(pk).or_default().push(projected); + } + Ok(m) +} + +#[cfg(test)] +mod tests { + use super::super::io_manager::IoPool; + use super::super::metrics::NoopMetricsCollector; + use super::*; + use crate::runtime::memory::{MemoryBlock, MemoryPool, global_memory_pool}; + use arrow_array::Int64Array; + use tempfile::TempDir; + + fn test_schema() -> Arc { + Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + false, + )])) + } + + fn make_batch(values: &[i64]) -> RecordBatch { + RecordBatch::try_new( + test_schema(), + vec![Arc::new(Int64Array::from(values.to_vec()))], + ) + .unwrap() + } + + const TEST_OPERATOR_MEMORY: u64 = 2 * 1024 * 1024; + + fn ensure_global_memory_pool() { + use crate::runtime::memory::{init_global_memory_pool, try_global_memory_pool}; + use std::sync::Once; + static INIT: Once = Once::new(); + INIT.call_once(|| { + if try_global_memory_pool().is_err() { + init_global_memory_pool(TEST_OPERATOR_MEMORY.saturating_mul(64)) + .expect("global memory pool init"); + } + }); + } + + fn state_block(bytes: u64) -> Arc { + ensure_global_memory_pool(); + global_memory_pool() + .try_request_block(bytes) + .expect("test pipeline state memory block") + } + + fn setup() -> (TempDir, IoManager, IoPool) { + ensure_global_memory_pool(); + let tmp = TempDir::new().unwrap(); + let metrics: Arc = Arc::new(NoopMetricsCollector); + let (pool, mgr) = IoPool::try_new(1, 1, metrics).unwrap(); + (tmp, mgr, pool) + } + + #[tokio::test] + async fn test_put_and_get() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + let key = b"key-a".to_vec(); + let batch = make_batch(&[10, 20, 30]); + store.put(key.clone(), batch).await.unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert_eq!(result.len(), 1); + let col = result[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[10, 20, 30]); + } + + #[tokio::test] + async fn test_multiple_puts_same_key() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + let key = b"key-x".to_vec(); + store.put(key.clone(), make_batch(&[1])).await.unwrap(); + store.put(key.clone(), make_batch(&[2])).await.unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert_eq!(result.len(), 2); + } + + #[tokio::test] + async fn test_get_nonexistent_key() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + let result = store.get_batches(b"no-such-key").await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_remove_batches() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + let key = b"key-del".to_vec(); + store.put(key.clone(), make_batch(&[42])).await.unwrap(); + + store.remove_batches(key.clone()).unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_remove_does_not_affect_other_keys() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + let k1 = b"key-1".to_vec(); + let k2 = b"key-2".to_vec(); + store.put(k1.clone(), make_batch(&[1])).await.unwrap(); + store.put(k2.clone(), make_batch(&[2])).await.unwrap(); + + store.remove_batches(k1.clone()).unwrap(); + + assert!(store.get_batches(&k1).await.unwrap().is_empty()); + assert_eq!(store.get_batches(&k2).await.unwrap().len(), 1); + } + + #[tokio::test] + async fn test_snapshot_epoch_advances() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + store.put(b"k".to_vec(), make_batch(&[1])).await.unwrap(); + store.snapshot_epoch(5).unwrap(); + + assert_eq!(store.current_epoch.load(Ordering::Acquire), 6); + } + + #[tokio::test] + async fn test_data_survives_snapshot_via_spill() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + let key = b"persist".to_vec(); + store.put(key.clone(), make_batch(&[99])).await.unwrap(); + store.snapshot_epoch(1).unwrap(); + store.await_spill_complete().await; + + let result = store.get_batches(&key).await.unwrap(); + assert!(!result.is_empty()); + let col = result[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[99]); + } + + #[tokio::test] + async fn test_tombstone_hides_immutable_data() { + let (tmp, mgr, _pool) = setup(); + let store = OperatorStateStore::new( + 1, + tmp.path(), + mgr, + state_block(TEST_OPERATOR_MEMORY), + TEST_OPERATOR_MEMORY, + ) + .unwrap(); + + let key = b"will-die".to_vec(); + store.put(key.clone(), make_batch(&[7])).await.unwrap(); + + // Move to immutable via snapshot + store.snapshot_epoch(1).unwrap(); + + // Tombstone at epoch 2 (> immutable epoch 1) + store.current_epoch.store(2, Ordering::Release); + store.remove_batches(key.clone()).unwrap(); + + let result = store.get_batches(&key).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_state_block_tracking() { + let mem = MemoryPool::new(2048); + assert_eq!(mem.usage_metrics().0, 0); + + mem.force_reserve(100); + assert_eq!(mem.usage_metrics().0, 100); + + mem.force_release(40); + assert_eq!(mem.usage_metrics().0, 60); + + let soft_limit = 1000u64; + assert!(mem.usage_metrics().0 <= soft_limit); + mem.force_reserve(1000); + assert!(mem.usage_metrics().0 > soft_limit); + } + + #[tokio::test] + async fn test_state_block_hard_limit() { + let mem = MemoryPool::new(1024); + assert!(mem.usage_metrics().0 + 500 <= mem.usage_metrics().1); + assert!(mem.usage_metrics().0 + 1025 > mem.usage_metrics().1); + + mem.force_reserve(800); + assert!(mem.usage_metrics().0 + 300 > mem.usage_metrics().1); + assert!(mem.usage_metrics().0 + 200 <= mem.usage_metrics().1); + } + + #[test] + fn test_extract_epoch() { + let path = PathBuf::from("/tmp/data-epoch-42_uuid-abc123.parquet"); + assert_eq!(extract_epoch(&path), 42); + + let path2 = PathBuf::from("/tmp/tombstone-epoch-100_uuid-def456.parquet"); + assert_eq!(extract_epoch(&path2), 100); + + let path3 = PathBuf::from("/tmp/random-file.parquet"); + assert_eq!(extract_epoch(&path3), 0); + } + + #[test] + fn test_inject_and_strip_partition_key() { + let batch = make_batch(&[1, 2, 3]); + let key = b"pk-test"; + + let injected = inject_partition_key(&batch, key).unwrap(); + assert_eq!(injected.num_columns(), 2); + assert!(injected.schema().index_of(PARTITION_KEY_COL).is_ok()); + + let stripped = filter_and_strip_partition_key(&injected, key) + .unwrap() + .unwrap(); + assert_eq!(stripped.num_columns(), 1); + let col = stripped + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[1, 2, 3]); + } + + #[test] + fn test_filter_partition_key_mismatch() { + let batch = make_batch(&[1, 2]); + let injected = inject_partition_key(&batch, b"pk-a").unwrap(); + + let result = filter_and_strip_partition_key(&injected, b"pk-b").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_restore_memtable_roundtrip() { + let batch1 = inject_partition_key(&make_batch(&[10]), b"k1").unwrap(); + let batch2 = inject_partition_key(&make_batch(&[20]), b"k2").unwrap(); + let batch3 = inject_partition_key(&make_batch(&[30]), b"k1").unwrap(); + + let restored = + restore_memtable_from_injected_batches(vec![batch1, batch2, batch3]).unwrap(); + + assert_eq!(restored.len(), 2); + assert_eq!(restored[b"k1".as_ref()].len(), 2); + assert_eq!(restored[b"k2".as_ref()].len(), 1); + } + + #[test] + fn test_write_and_read_parquet() { + let tmp = TempDir::new().unwrap(); + let path = tmp.path().join("test.parquet"); + + let batch = make_batch(&[100, 200, 300]); + write_parquet_with_bloom_atomic(&path, std::slice::from_ref(&batch), 1).unwrap(); + + let file = File::open(&path).unwrap(); + let reader = ParquetRecordBatchReaderBuilder::try_new(file) + .unwrap() + .build() + .unwrap(); + + let read_batches: Vec = reader.map(|r| r.unwrap()).collect(); + assert_eq!(read_batches.len(), 1); + let col = read_batches[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.values(), &[100, 200, 300]); + } + + #[test] + fn test_filter_tombstones_from_batch() { + let batch = make_batch(&[1, 2, 3]); + let key = b"victim"; + let injected = inject_partition_key(&batch, key).unwrap(); + + let mut tombstones: TombstoneMap = HashMap::new(); + tombstones.insert(key.to_vec(), 10); + + // file_epoch <= tombstone epoch => fully filtered + let result = filter_tombstones_from_batch(&injected, &tombstones, 5).unwrap(); + assert!(result.is_none()); + + // file_epoch > tombstone epoch => data survives + let result = filter_tombstones_from_batch(&injected, &tombstones, 15).unwrap(); + assert!(result.is_some()); + } + + #[test] + fn test_write_empty_batches_is_noop() { + let tmp = TempDir::new().unwrap(); + let path = tmp.path().join("empty.parquet"); + + write_parquet_with_bloom_atomic(&path, &[], 0).unwrap(); + assert!(!path.exists()); + } +} diff --git a/src/runtime/buffer_and_event/buffer_or_event.rs b/src/runtime/wasm/buffer_and_event/buffer_or_event.rs similarity index 100% rename from src/runtime/buffer_and_event/buffer_or_event.rs rename to src/runtime/wasm/buffer_and_event/buffer_or_event.rs diff --git a/src/runtime/buffer_and_event/mod.rs b/src/runtime/wasm/buffer_and_event/mod.rs similarity index 100% rename from src/runtime/buffer_and_event/mod.rs rename to src/runtime/wasm/buffer_and_event/mod.rs diff --git a/src/runtime/buffer_and_event/stream_element/mod.rs b/src/runtime/wasm/buffer_and_event/stream_element/mod.rs similarity index 100% rename from src/runtime/buffer_and_event/stream_element/mod.rs rename to src/runtime/wasm/buffer_and_event/stream_element/mod.rs diff --git a/src/runtime/buffer_and_event/stream_element/stream_element.rs b/src/runtime/wasm/buffer_and_event/stream_element/stream_element.rs similarity index 100% rename from src/runtime/buffer_and_event/stream_element/stream_element.rs rename to src/runtime/wasm/buffer_and_event/stream_element/stream_element.rs diff --git a/src/runtime/wasm/input/input_protocol.rs b/src/runtime/wasm/input/input_protocol.rs index 69fae972..50294201 100644 --- a/src/runtime/wasm/input/input_protocol.rs +++ b/src/runtime/wasm/input/input_protocol.rs @@ -10,7 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use std::time::Duration; pub trait InputProtocol: Send + Sync + 'static { diff --git a/src/runtime/wasm/input/input_provider.rs b/src/runtime/wasm/input/input_provider.rs index 3f6606cd..8eee649d 100644 --- a/src/runtime/wasm/input/input_provider.rs +++ b/src/runtime/wasm/input/input_provider.rs @@ -11,7 +11,7 @@ // limitations under the License. use crate::runtime::input::Input; -use crate::runtime::task::InputConfig; +use crate::runtime::wasm::task::InputConfig; pub struct InputProvider; diff --git a/src/runtime/wasm/input/input_runner.rs b/src/runtime/wasm/input/input_runner.rs index 854e4de8..ece85e3d 100644 --- a/src/runtime/wasm/input/input_runner.rs +++ b/src/runtime/wasm/input/input_runner.rs @@ -10,13 +10,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::common::TaskCompletionFlag; use crate::runtime::input::input_protocol::InputProtocol; use crate::runtime::input::{Input, InputState}; use crate::runtime::processor::function_error::FunctionErrorReport; -use crate::runtime::task::ControlMailBox; -use crate::runtime::task::InputRuntimeConfig; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::task::ControlMailBox; +use crate::runtime::wasm::task::InputRuntimeConfig; use crossbeam_channel::{Receiver, Sender, bounded, unbounded}; use std::sync::{Arc, Mutex}; use std::thread; @@ -250,7 +250,7 @@ impl InputRunner

{ impl Input for InputRunner

{ fn init_with_context( &mut self, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { if !matches!(*self.state.lock().unwrap(), InputState::Uninitialized) { return Ok(()); diff --git a/src/runtime/wasm/input/interface.rs b/src/runtime/wasm/input/interface.rs index dd89ba77..06da4923 100644 --- a/src/runtime/wasm/input/interface.rs +++ b/src/runtime/wasm/input/interface.rs @@ -10,8 +10,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::taskexecutor::InitContext; pub use crate::runtime::common::ComponentState as InputState; diff --git a/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs b/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs index 85336c53..1fb487a6 100644 --- a/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs +++ b/src/runtime/wasm/input/protocol/kafka/kafka_protocol.rs @@ -11,8 +11,8 @@ // limitations under the License. use super::config::KafkaConfig; -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::input::input_protocol::InputProtocol; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use rdkafka::Message; use rdkafka::TopicPartitionList; use rdkafka::config::ClientConfig; diff --git a/src/runtime/wasm/mod.rs b/src/runtime/wasm/mod.rs index b1c82f4c..78be72e2 100644 --- a/src/runtime/wasm/mod.rs +++ b/src/runtime/wasm/mod.rs @@ -13,6 +13,9 @@ //! WebAssembly runtime integration. +pub mod buffer_and_event; pub mod input; pub mod output; pub mod processor; +pub mod task; +pub mod taskexecutor; diff --git a/src/runtime/wasm/output/interface.rs b/src/runtime/wasm/output/interface.rs index e7c3b903..21c3055d 100644 --- a/src/runtime/wasm/output/interface.rs +++ b/src/runtime/wasm/output/interface.rs @@ -10,8 +10,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::taskexecutor::InitContext; pub trait Output: Send + Sync { fn init_with_context( diff --git a/src/runtime/wasm/output/output_protocol.rs b/src/runtime/wasm/output/output_protocol.rs index dd502ca6..6140d3eb 100644 --- a/src/runtime/wasm/output/output_protocol.rs +++ b/src/runtime/wasm/output/output_protocol.rs @@ -10,7 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; pub trait OutputProtocol: Send + Sync + 'static { fn name(&self) -> String; diff --git a/src/runtime/wasm/output/output_provider.rs b/src/runtime/wasm/output/output_provider.rs index c6d01fef..25ca8431 100644 --- a/src/runtime/wasm/output/output_provider.rs +++ b/src/runtime/wasm/output/output_provider.rs @@ -11,7 +11,7 @@ // limitations under the License. use crate::runtime::output::Output; -use crate::runtime::task::OutputConfig; +use crate::runtime::wasm::task::OutputConfig; pub struct OutputProvider; diff --git a/src/runtime/wasm/output/output_runner.rs b/src/runtime/wasm/output/output_runner.rs index 85ba99b4..ca6d780c 100644 --- a/src/runtime/wasm/output/output_runner.rs +++ b/src/runtime/wasm/output/output_runner.rs @@ -10,13 +10,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::common::{ComponentState, TaskCompletionFlag}; use crate::runtime::output::Output; use crate::runtime::output::output_protocol::OutputProtocol; use crate::runtime::processor::function_error::FunctionErrorReport; -use crate::runtime::task::ControlMailBox; -use crate::runtime::task::OutputRuntimeConfig; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::task::ControlMailBox; +use crate::runtime::wasm::task::OutputRuntimeConfig; use crossbeam_channel::{Receiver, Sender, bounded, unbounded}; use std::sync::{Arc, Mutex}; use std::thread; @@ -288,7 +288,7 @@ impl OutputRunner

{ impl Output for OutputRunner

{ fn init_with_context( &mut self, - ctx: &crate::runtime::taskexecutor::InitContext, + ctx: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { if !matches!(*self.state.lock().unwrap(), ComponentState::Uninitialized) { return Ok(()); diff --git a/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs b/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs index 2083294d..d9e6db4d 100644 --- a/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs +++ b/src/runtime/wasm/output/protocol/kafka/kafka_protocol.rs @@ -11,8 +11,8 @@ // limitations under the License. use super::producer_config::KafkaProducerConfig; -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::output::output_protocol::OutputProtocol; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use rdkafka::producer::{BaseRecord, DefaultProducerContext, Producer, ThreadedProducer}; use std::sync::Mutex; use std::time::Duration; diff --git a/src/runtime/wasm/processor/wasm/wasm_host.rs b/src/runtime/wasm/processor/wasm/wasm_host.rs index 009dd6b4..2bf7d4f0 100644 --- a/src/runtime/wasm/processor/wasm/wasm_host.rs +++ b/src/runtime/wasm/processor/wasm/wasm_host.rs @@ -10,9 +10,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::output::Output; use crate::runtime::processor::wasm::wasm_cache; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; use crate::storage::state_backend::{StateStore, StateStoreFactory}; use std::sync::{Arc, OnceLock}; use wasmtime::component::{Component, HasData, Linker, Resource, bindgen}; @@ -449,7 +449,7 @@ pub fn create_wasm_host_with_component( engine: &Engine, component: &Component, outputs: Vec>, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, task_name: String, create_time: u64, ) -> anyhow::Result<(Processor, Store)> { @@ -495,7 +495,7 @@ pub fn create_wasm_host_with_component( pub fn create_wasm_host( wasm_bytes: &[u8], outputs: Vec>, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, task_name: String, create_time: u64, ) -> anyhow::Result<(Processor, Store)> { diff --git a/src/runtime/wasm/processor/wasm/wasm_processor.rs b/src/runtime/wasm/processor/wasm/wasm_processor.rs index 1afc9dcf..52234bfe 100644 --- a/src/runtime/wasm/processor/wasm/wasm_processor.rs +++ b/src/runtime/wasm/processor/wasm/wasm_processor.rs @@ -134,7 +134,7 @@ impl WasmProcessorImpl { impl WasmProcessor for WasmProcessorImpl { fn init_with_context( &mut self, - _init_context: &crate::runtime::taskexecutor::InitContext, + _init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { if self.initialized { log::warn!("WasmProcessor '{}' already initialized", self.name); @@ -405,7 +405,7 @@ impl WasmProcessor for WasmProcessorImpl { fn init_wasm_host( &mut self, outputs: Vec>, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, task_name: String, create_time: u64, ) -> Result<(), Box> { diff --git a/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs b/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs index 23a9f703..fb2c17fb 100644 --- a/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs +++ b/src/runtime/wasm/processor/wasm/wasm_processor_trait.rs @@ -11,7 +11,7 @@ // limitations under the License. use crate::runtime::output::Output; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::taskexecutor::InitContext; pub trait WasmProcessor: Send + Sync { fn process( diff --git a/src/runtime/wasm/processor/wasm/wasm_task.rs b/src/runtime/wasm/processor/wasm/wasm_task.rs index c61f385f..4330aaaf 100644 --- a/src/runtime/wasm/processor/wasm/wasm_task.rs +++ b/src/runtime/wasm/processor/wasm/wasm_task.rs @@ -13,13 +13,13 @@ use super::input_strategy::{InputStrategy, RoundRobinStrategy, from_selector_name}; use super::thread_pool::ThreadGroup; use super::wasm_processor_trait::WasmProcessor; -use crate::runtime::buffer_and_event::BufferOrEvent; use crate::runtime::common::{ComponentState, TaskCompletionFlag}; use crate::runtime::input::Input; use crate::runtime::output::Output; use crate::runtime::processor::function_error::FunctionErrorReport; -use crate::runtime::task::ProcessorRuntimeConfig; -use crate::runtime::task::{ControlMailBox, TaskControlSignal, TaskLifecycle}; +use crate::runtime::wasm::buffer_and_event::BufferOrEvent; +use crate::runtime::wasm::task::ProcessorRuntimeConfig; +use crate::runtime::wasm::task::{ControlMailBox, TaskControlSignal, TaskLifecycle}; use crate::storage::task::FunctionInfo; use crossbeam_channel::{Receiver, after, select, unbounded}; use std::sync::atomic::{AtomicBool, Ordering}; @@ -120,7 +120,7 @@ impl WasmTask { pub fn init_with_context( &mut self, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { let mut inputs = self.inputs.take().ok_or_else(|| { Box::new(std::io::Error::other("inputs already moved to thread")) @@ -262,7 +262,7 @@ impl WasmTask { shared_state: Arc>, failure_cause: Arc>>, execution_state: Arc>, - _init_context: crate::runtime::taskexecutor::InitContext, + _init_context: crate::runtime::wasm::taskexecutor::InitContext, ) { let mut state = TaskState::Initialized; let mut last_idx: usize = 0; @@ -729,7 +729,7 @@ impl WasmTask { impl TaskLifecycle for WasmTask { fn init_with_context( &mut self, - init_context: &crate::runtime::taskexecutor::InitContext, + init_context: &crate::runtime::wasm::taskexecutor::InitContext, ) -> Result<(), Box> { ::init_with_context(self, init_context) } diff --git a/src/runtime/task/builder/mod.rs b/src/runtime/wasm/task/builder/mod.rs similarity index 100% rename from src/runtime/task/builder/mod.rs rename to src/runtime/wasm/task/builder/mod.rs diff --git a/src/runtime/task/builder/processor/mod.rs b/src/runtime/wasm/task/builder/processor/mod.rs similarity index 97% rename from src/runtime/task/builder/processor/mod.rs rename to src/runtime/wasm/task/builder/processor/mod.rs index 418271dd..c1306924 100644 --- a/src/runtime/task/builder/processor/mod.rs +++ b/src/runtime/wasm/task/builder/processor/mod.rs @@ -19,8 +19,8 @@ use crate::runtime::output::{Output, OutputProvider}; use crate::runtime::processor::wasm::wasm_processor::WasmProcessorImpl; use crate::runtime::processor::wasm::wasm_processor_trait::WasmProcessor; use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; -use crate::runtime::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/builder/python/mod.rs b/src/runtime/wasm/task/builder/python/mod.rs similarity index 95% rename from src/runtime/task/builder/python/mod.rs rename to src/runtime/wasm/task/builder/python/mod.rs index 03f6ca0f..1b31d2e5 100644 --- a/src/runtime/task/builder/python/mod.rs +++ b/src/runtime/wasm/task/builder/python/mod.rs @@ -20,8 +20,8 @@ use crate::runtime::processor::python::get_python_engine_and_component; use crate::runtime::processor::wasm::wasm_processor::WasmProcessorImpl; use crate::runtime::processor::wasm::wasm_processor_trait::WasmProcessor; use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; -use crate::runtime::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::{InputConfig, OutputConfig, ProcessorConfig, WasmTaskConfig}; use serde_yaml::Value; use std::sync::Arc; @@ -33,7 +33,7 @@ impl PythonBuilder { yaml_value: &Value, modules: &[(String, Vec)], create_time: u64, - ) -> Result, Box> + ) -> Result, Box> { let config_type = yaml_value .get(TYPE) diff --git a/src/runtime/task/builder/sink/mod.rs b/src/runtime/wasm/task/builder/sink/mod.rs similarity index 97% rename from src/runtime/task/builder/sink/mod.rs rename to src/runtime/wasm/task/builder/sink/mod.rs index f1babbd6..65e8bc95 100644 --- a/src/runtime/task/builder/sink/mod.rs +++ b/src/runtime/wasm/task/builder/sink/mod.rs @@ -15,7 +15,7 @@ // Specifically handles building logic for Sink type configuration (future support) use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/builder/source/mod.rs b/src/runtime/wasm/task/builder/source/mod.rs similarity index 97% rename from src/runtime/task/builder/source/mod.rs rename to src/runtime/wasm/task/builder/source/mod.rs index d766ebbe..fc81bea9 100644 --- a/src/runtime/task/builder/source/mod.rs +++ b/src/runtime/wasm/task/builder/source/mod.rs @@ -15,7 +15,7 @@ // Specifically handles building logic for Source type configuration (future support) use crate::runtime::processor::wasm::wasm_task::WasmTask; -use crate::runtime::task::yaml_keys::{TYPE, type_values}; +use crate::runtime::wasm::task::yaml_keys::{TYPE, type_values}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/builder/task_builder.rs b/src/runtime/wasm/task/builder/task_builder.rs similarity index 94% rename from src/runtime/task/builder/task_builder.rs rename to src/runtime/wasm/task/builder/task_builder.rs index 9f89dbba..2246d6d8 100644 --- a/src/runtime/task/builder/task_builder.rs +++ b/src/runtime/wasm/task/builder/task_builder.rs @@ -15,13 +15,13 @@ //! Provides unified factory methods to create TaskLifecycle instances from YAML config. //! Dispatches to specific builders (Processor, Source, Sink, Python) based on task type. -use crate::runtime::task::TaskLifecycle; -use crate::runtime::task::builder::processor::ProcessorBuilder; +use crate::runtime::wasm::task::TaskLifecycle; +use crate::runtime::wasm::task::builder::processor::ProcessorBuilder; #[cfg(feature = "python")] -use crate::runtime::task::builder::python::PythonBuilder; -use crate::runtime::task::builder::sink::SinkBuilder; -use crate::runtime::task::builder::source::SourceBuilder; -use crate::runtime::task::yaml_keys::{NAME, TYPE, type_values}; +use crate::runtime::wasm::task::builder::python::PythonBuilder; +use crate::runtime::wasm::task::builder::sink::SinkBuilder; +use crate::runtime::wasm::task::builder::source::SourceBuilder; +use crate::runtime::wasm::task::yaml_keys::{NAME, TYPE, type_values}; use serde_yaml::Value; use std::sync::Arc; diff --git a/src/runtime/task/control_mailbox.rs b/src/runtime/wasm/task/control_mailbox.rs similarity index 100% rename from src/runtime/task/control_mailbox.rs rename to src/runtime/wasm/task/control_mailbox.rs diff --git a/src/runtime/task/lifecycle.rs b/src/runtime/wasm/task/lifecycle.rs similarity index 97% rename from src/runtime/task/lifecycle.rs rename to src/runtime/wasm/task/lifecycle.rs index 2b857f81..ea00f7c2 100644 --- a/src/runtime/task/lifecycle.rs +++ b/src/runtime/wasm/task/lifecycle.rs @@ -15,8 +15,8 @@ // Defines the complete lifecycle management interface for Task, including initialization, start, stop, checkpoint and close use crate::runtime::common::ComponentState; -use crate::runtime::task::control_mailbox::ControlMailBox; -use crate::runtime::taskexecutor::InitContext; +use crate::runtime::wasm::task::control_mailbox::ControlMailBox; +use crate::runtime::wasm::taskexecutor::InitContext; use crate::storage::task::FunctionInfo; use std::sync::Arc; diff --git a/src/runtime/task/mod.rs b/src/runtime/wasm/task/mod.rs similarity index 100% rename from src/runtime/task/mod.rs rename to src/runtime/wasm/task/mod.rs diff --git a/src/runtime/task/processor_config.rs b/src/runtime/wasm/task/processor_config.rs similarity index 99% rename from src/runtime/task/processor_config.rs rename to src/runtime/wasm/task/processor_config.rs index fe515647..a3069adc 100644 --- a/src/runtime/task/processor_config.rs +++ b/src/runtime/wasm/task/processor_config.rs @@ -608,7 +608,7 @@ impl WasmTaskConfig { task_name: String, value: &Value, ) -> Result> { - use crate::runtime::task::yaml_keys::{INPUT_GROUPS, INPUTS, NAME, OUTPUTS}; + use crate::runtime::wasm::task::yaml_keys::{INPUT_GROUPS, INPUTS, NAME, OUTPUTS}; // 1. Get name from config (if exists), otherwise use the passed task_name let config_name = value diff --git a/src/runtime/task/yaml_keys.rs b/src/runtime/wasm/task/yaml_keys.rs similarity index 100% rename from src/runtime/task/yaml_keys.rs rename to src/runtime/wasm/task/yaml_keys.rs diff --git a/src/runtime/taskexecutor/init_context.rs b/src/runtime/wasm/taskexecutor/init_context.rs similarity index 97% rename from src/runtime/taskexecutor/init_context.rs rename to src/runtime/wasm/taskexecutor/init_context.rs index 13ad5c81..fca44a32 100644 --- a/src/runtime/taskexecutor/init_context.rs +++ b/src/runtime/wasm/taskexecutor/init_context.rs @@ -15,7 +15,7 @@ // Provides various resources needed for task initialization, including state storage, task storage, thread pool, etc. use crate::runtime::processor::wasm::thread_pool::{TaskThreadPool, ThreadGroup}; -use crate::runtime::task::ControlMailBox; +use crate::runtime::wasm::task::ControlMailBox; use crate::storage::state_backend::StateStorageServer; use crate::storage::task::TaskStorage; use std::sync::{Arc, Mutex}; diff --git a/src/runtime/taskexecutor/mod.rs b/src/runtime/wasm/taskexecutor/mod.rs similarity index 100% rename from src/runtime/taskexecutor/mod.rs rename to src/runtime/wasm/taskexecutor/mod.rs diff --git a/src/runtime/taskexecutor/task_manager.rs b/src/runtime/wasm/taskexecutor/task_manager.rs similarity index 98% rename from src/runtime/taskexecutor/task_manager.rs rename to src/runtime/wasm/taskexecutor/task_manager.rs index f11997d5..897e0a3d 100644 --- a/src/runtime/taskexecutor/task_manager.rs +++ b/src/runtime/wasm/taskexecutor/task_manager.rs @@ -13,8 +13,8 @@ use crate::config::GlobalConfig; use crate::runtime::common::ComponentState; use crate::runtime::processor::wasm::thread_pool::{GlobalTaskThreadPool, TaskThreadPool}; -use crate::runtime::task::{TaskBuilder, TaskLifecycle}; -use crate::runtime::taskexecutor::init_context::InitContext; +use crate::runtime::wasm::task::{TaskBuilder, TaskLifecycle}; +use crate::runtime::wasm::taskexecutor::init_context::InitContext; use crate::storage::state_backend::StateStorageServer; use crate::storage::task::{ FunctionInfo, StoredTaskInfo, TaskModuleBytes, TaskStorage, TaskStorageFactory, diff --git a/src/server/initializer.rs b/src/server/initializer.rs index 785321b8..8a04608e 100644 --- a/src/server/initializer.rs +++ b/src/server/initializer.rs @@ -19,6 +19,12 @@ use crate::config::GlobalConfig; pub type InitializerFn = fn(&GlobalConfig) -> Result<()>; +fn initialize_streaming_sql_planning(config: &GlobalConfig) -> Result<()> { + let job = config.streaming.resolved_job(); + crate::sql::planning_runtime::install_sql_planning_from_streaming_job(&job); + Ok(()) +} + #[derive(Clone)] pub struct Component { pub name: &'static str, @@ -94,8 +100,10 @@ impl ComponentRegistry { pub fn build_core_registry() -> ComponentRegistry { let builder = { let b = ComponentRegistryBuilder::new() + .register("StreamingSqlPlanning", initialize_streaming_sql_planning) .register("WasmCache", initialize_wasm_cache) .register("TaskManager", initialize_task_manager) + .register("MemoryService", initialize_memory_service) .register("JobManager", initialize_job_manager); #[cfg(feature = "python")] let b = b.register("PythonService", initialize_python_service); @@ -143,7 +151,7 @@ fn initialize_wasm_cache(config: &GlobalConfig) -> Result<()> { } fn initialize_task_manager(config: &GlobalConfig) -> Result<()> { - crate::runtime::taskexecutor::TaskManager::init(config) + crate::runtime::wasm::taskexecutor::TaskManager::init(config) .context("TaskManager service failed to start")?; Ok(()) } @@ -155,28 +163,48 @@ fn initialize_python_service(config: &GlobalConfig) -> Result<()> { Ok(()) } +fn initialize_memory_service(config: &GlobalConfig) -> Result<()> { + crate::server::memory_service::MemoryService::initialize(config) +} + fn initialize_job_manager(config: &GlobalConfig) -> Result<()> { use crate::runtime::streaming::factory::OperatorFactory; use crate::runtime::streaming::factory::Registry; - use crate::runtime::streaming::job::JobManager; + use crate::runtime::streaming::job::{JobManager, StateConfig}; use std::sync::Arc; + let per_operator_memory_bytes = config + .streaming + .operator_state_store_memory_bytes + .unwrap_or(crate::config::DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES); + let job = config.streaming.resolved_job(); + let registry = Arc::new(Registry::new()); let factory = Arc::new(OperatorFactory::new(registry)); - let max_memory_bytes = config - .streaming - .max_memory_bytes - .unwrap_or(256 * 1024 * 1024); - JobManager::init(factory, max_memory_bytes).context("JobManager service failed to start")?; + let state_base_dir = std::env::temp_dir().join("function-stream").join("state"); + let state_config = StateConfig { + checkpoint_interval_ms: job.checkpoint_interval_ms, + pipeline_parallelism: job.pipeline_parallelism, + job_manager_control_plane_threads: job.job_manager_control_plane_threads, + job_manager_data_plane_threads: job.job_manager_data_plane_threads, + per_operator_memory_bytes, + ..StateConfig::default() + }; + + JobManager::init(factory, state_base_dir, state_config) + .context("JobManager service failed to start")?; Ok(()) } fn initialize_coordinator(_config: &GlobalConfig) -> Result<()> { - crate::runtime::taskexecutor::TaskManager::get() + crate::runtime::wasm::taskexecutor::TaskManager::get() .context("Dependency violation: Coordinator requires TaskManager")?; + crate::runtime::memory::try_global_memory_pool() + .context("Dependency violation: Coordinator requires MemoryService")?; + crate::storage::stream_catalog::CatalogManager::global() .context("Dependency violation: Coordinator requires StreamCatalog")?; diff --git a/src/server/memory_service.rs b/src/server/memory_service.rs new file mode 100644 index 00000000..2ba24eee --- /dev/null +++ b/src/server/memory_service.rs @@ -0,0 +1,61 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::{Context, Result}; +use tracing::info; + +use crate::config::{ + DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES, DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES, GlobalConfig, +}; + +pub struct MemoryService; + +impl MemoryService { + pub fn initialize(config: &GlobalConfig) -> Result<()> { + use crate::config::system::system_memory_info; + + let mem_info = system_memory_info().ok(); + let total_physical = mem_info.as_ref().map(|m| m.total_physical).unwrap_or(0); + let avail_physical = mem_info.as_ref().map(|m| m.available_physical).unwrap_or(0); + let total_virtual = mem_info.as_ref().map(|m| m.total_virtual).unwrap_or(0); + let avail_virtual = mem_info.as_ref().map(|m| m.available_virtual).unwrap_or(0); + + let streaming_runtime_memory_bytes = config + .streaming + .streaming_runtime_memory_bytes + .unwrap_or(DEFAULT_STREAMING_RUNTIME_MEMORY_BYTES); + + let operator_state_store_memory_bytes = config + .streaming + .operator_state_store_memory_bytes + .unwrap_or(DEFAULT_OPERATOR_STATE_STORE_MEMORY_BYTES); + + info!( + total_physical_mb = total_physical / (1024 * 1024), + available_physical_mb = avail_physical / (1024 * 1024), + total_virtual_mb = total_virtual / (1024 * 1024), + available_virtual_mb = avail_virtual / (1024 * 1024), + streaming_runtime_memory_mb = streaming_runtime_memory_bytes / (1024 * 1024), + operator_state_store_memory_mb = operator_state_store_memory_bytes / (1024 * 1024), + "MemoryService: global streaming + operator state pools" + ); + + let total_pool_bytes = + streaming_runtime_memory_bytes.saturating_add(operator_state_store_memory_bytes); + crate::runtime::memory::init_global_memory_pool(total_pool_bytes) + .context("Global memory pool initialization failed")?; + + info!("MemoryService initialized"); + Ok(()) + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index cb7a4a85..def6ac9e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -14,6 +14,7 @@ mod handler; mod initializer; +pub mod memory_service; mod service; pub use handler::FunctionStreamServiceImpl; diff --git a/src/sql/analysis/aggregate_rewriter.rs b/src/sql/analysis/aggregate_rewriter.rs index d7be0db8..ddcb0294 100644 --- a/src/sql/analysis/aggregate_rewriter.rs +++ b/src/sql/analysis/aggregate_rewriter.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use crate::sql::analysis::streaming_window_analzer::StreamingWindowAnalzer; use crate::sql::logical_node::aggregate::StreamWindowAggregateNode; use crate::sql::logical_node::key_calculation::{KeyExtractionNode, KeyExtractionStrategy}; +use crate::sql::logical_node::updating_aggregate::ContinuousAggregateNode; use crate::sql::schema::StreamSchemaProvider; use crate::sql::types::{ QualifiedField, TIMESTAMP_FIELD, WindowBehavior, WindowType, build_df_schema_with_metadata, @@ -70,10 +71,10 @@ impl TreeNodeRewriter for AggregateRewriter<'_> { }) .collect(); - // 3. Dispatch to Updating Aggregate if no windowing is detected. + // 3. Dispatch to ContinuousAggregateNode (UpdatingAggregate) if no windowing is detected. let input_window = StreamingWindowAnalzer::get_window(&agg.input)?; if window_exprs.is_empty() && input_window.is_none() { - return self.rewrite_as_updating_aggregate( + return self.rewrite_as_continuous_updating_aggregate( agg.input, key_fields, agg.group_expr, @@ -174,9 +175,9 @@ impl<'a> AggregateRewriter<'a> { })) } - /// [Strategy] Rewrites standard GROUP BY into a non-windowed updating aggregate. + /// [Strategy] Rewrites standard GROUP BY into a ContinuousAggregateNode with retraction semantics. /// Injected max(_timestamp) ensures the streaming pulse (Watermark) continues to propagate. - fn rewrite_as_updating_aggregate( + fn rewrite_as_continuous_updating_aggregate( &self, input: Arc, key_fields: Vec, @@ -184,6 +185,7 @@ impl<'a> AggregateRewriter<'a> { mut aggr_expr: Vec, schema: Arc, ) -> Result> { + let key_count = key_fields.len(); let keyed_input = self.build_keyed_input(input, &group_expr, &key_fields)?; // Ensure the updating stream maintains time awareness. @@ -207,14 +209,23 @@ impl<'a> AggregateRewriter<'a> { schema.metadata().clone(), )?); - let aggregate = Aggregate::try_new_with_schema( + let base_aggregate = Aggregate::try_new_with_schema( Arc::new(keyed_input), group_expr, aggr_expr, output_schema, )?; - Ok(Transformed::yes(LogicalPlan::Aggregate(aggregate))) + let continuous_node = ContinuousAggregateNode::try_new( + LogicalPlan::Aggregate(base_aggregate), + (0..key_count).collect(), + None, + self.schema_provider.planning_options.ttl, + )?; + + Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(continuous_node), + }))) } /// [Strategy] Reconciles window definitions between the input stream and the current GROUP BY. @@ -232,24 +243,16 @@ impl<'a> AggregateRewriter<'a> { let has_group_window = !window_expr_info.is_empty(); match (input_window, has_group_window) { - // Re-aggregation or subquery with an existing window. (Some(i_win), true) => { let (idx, g_win) = window_expr_info.pop().unwrap(); if i_win != g_win { - return plan_err!( - "Inconsistent windowing: input is {:?}, but group by is {:?}", - i_win, - g_win - ); + return plan_err!("Inconsistent windowing detected"); } if let Some(field) = visitor.fields.iter().next() { group_expr[idx] = Expr::Column(field.qualified_column()); Ok(WindowBehavior::InData) } else { - if matches!(i_win, WindowType::Session { .. }) { - return plan_err!("Nested session windows are not supported"); - } group_expr.remove(idx); Ok(WindowBehavior::FromOperator { window: i_win, @@ -259,7 +262,6 @@ impl<'a> AggregateRewriter<'a> { }) } } - // First-time windowing defined in this aggregate. (None, true) => { let (idx, g_win) = window_expr_info.pop().unwrap(); group_expr.remove(idx); @@ -270,9 +272,8 @@ impl<'a> AggregateRewriter<'a> { is_nested: false, }) } - // Passthrough: input is already windowed, no new window in group by. (Some(_), false) => Ok(WindowBehavior::InData), - _ => unreachable!("Dispatched to non-windowed path previously"), + _ => unreachable!("Handled by updating path"), } } } diff --git a/src/sql/common/constants.rs b/src/sql/common/constants.rs index 40642cd7..19fdbcb3 100644 --- a/src/sql/common/constants.rs +++ b/src/sql/common/constants.rs @@ -107,7 +107,11 @@ pub mod sql_field { } pub mod sql_planning_default { - pub const DEFAULT_PARALLELISM: usize = 4; + pub const DEFAULT_PARALLELISM: usize = 1; + /// Default physical parallelism for `KeyBy` / key-extraction pipelines (configurable via YAML). + pub const DEFAULT_KEY_BY_PARALLELISM: usize = 1; + /// Parallelism for aggregations that run after `KeyBy` / shuffle on non-empty routing keys. + pub const KEYED_AGGREGATE_DEFAULT_PARALLELISM: usize = 8; pub const PLANNING_TTL_SECS: u64 = 24 * 60 * 60; } diff --git a/src/sql/common/operator_config.rs b/src/sql/common/operator_config.rs index b5360cd7..209bee48 100644 --- a/src/sql/common/operator_config.rs +++ b/src/sql/common/operator_config.rs @@ -1,5 +1,14 @@ // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. use serde::{Deserialize, Serialize}; diff --git a/src/sql/logical_node/aggregate.rs b/src/sql/logical_node/aggregate.rs index d9833c50..1e288ab5 100644 --- a/src/sql/logical_node/aggregate.rs +++ b/src/sql/logical_node/aggregate.rs @@ -64,6 +64,13 @@ multifield_partial_ord!( ); impl StreamWindowAggregateNode { + /// This node is only emitted after `KeyExtractionNode` in streaming rewrites; `partition_keys` + /// may be empty when GROUP BY is only a window call (window column stripped from key list), + /// but the pipeline still consumes a shuffle — use keyed aggregate parallelism. + fn parallelism_after_keyed_shuffle(&self, planner: &Planner) -> usize { + planner.keyed_aggregate_parallelism() + } + /// Safely constructs a new node, computing the final projection without panicking. pub fn try_new( window_spec: WindowBehavior, @@ -126,7 +133,7 @@ impl StreamWindowAggregateNode { OperatorName::TumblingWindowAggregate, operator_config.encode_to_vec(), format!("TumblingWindow<{}>", operator_config.name), - 1, + self.parallelism_after_keyed_shuffle(planner), )) } @@ -176,7 +183,7 @@ impl StreamWindowAggregateNode { OperatorName::SlidingWindowAggregate, operator_config.encode_to_vec(), proto_operator_name::SLIDING_WINDOW_LABEL.to_string(), - 1, + self.parallelism_after_keyed_shuffle(planner), )) } @@ -243,7 +250,7 @@ impl StreamWindowAggregateNode { OperatorName::SessionWindowAggregate, operator_config.encode_to_vec(), operator_config.name.clone(), - 1, + self.parallelism_after_keyed_shuffle(planner), )) } @@ -299,7 +306,7 @@ impl StreamWindowAggregateNode { OperatorName::TumblingWindowAggregate, operator_config.encode_to_vec(), proto_operator_name::INSTANT_WINDOW_LABEL.to_string(), - 1, + self.parallelism_after_keyed_shuffle(planner), )) } } diff --git a/src/sql/logical_node/async_udf.rs b/src/sql/logical_node/async_udf.rs index 6cd2da7b..1c35398e 100644 --- a/src/sql/logical_node/async_udf.rs +++ b/src/sql/logical_node/async_udf.rs @@ -160,7 +160,7 @@ impl StreamingOperatorBlueprint for AsyncFunctionExecutionNode { OperatorName::AsyncUdf, operator_config.encode_to_vec(), format!("AsyncUdf<{}>", self.operator_name), - 1, + planner.default_parallelism(), ); let upstream_schema = input_schemas.remove(0); diff --git a/src/sql/logical_node/join.rs b/src/sql/logical_node/join.rs index ea142d0a..15631f1f 100644 --- a/src/sql/logical_node/join.rs +++ b/src/sql/logical_node/join.rs @@ -191,7 +191,7 @@ impl StreamingOperatorBlueprint for StreamingJoinNode { self.determine_operator_type(), operator_config.encode_to_vec(), runtime_operator_kind::STREAMING_JOIN.to_string(), - 1, + planner.default_parallelism(), ); let left_edge = diff --git a/src/sql/logical_node/key_calculation.rs b/src/sql/logical_node/key_calculation.rs index 6bcad784..ec83e108 100644 --- a/src/sql/logical_node/key_calculation.rs +++ b/src/sql/logical_node/key_calculation.rs @@ -238,7 +238,7 @@ impl StreamingOperatorBlueprint for KeyExtractionNode { engine_operator_name, protobuf_payload, format!("Key<{}>", self.operator_label.as_deref().unwrap_or("_")), - 1, + planner.key_by_parallelism(), ); let data_edge = diff --git a/src/sql/logical_node/logical/operator_chain.rs b/src/sql/logical_node/logical/operator_chain.rs index 34a01a5c..2aecddd6 100644 --- a/src/sql/logical_node/logical/operator_chain.rs +++ b/src/sql/logical_node/logical/operator_chain.rs @@ -128,4 +128,15 @@ impl OperatorChain { pub fn is_sink(&self) -> bool { self.operators[0].operator_name == OperatorName::ConnectorSink } + + /// Operators safe to run at a higher upstream `TaskContext::parallelism` when fused after a + /// stateful node (e.g. window aggregate @ 8 → projection @ 1). + pub fn is_parallelism_upstream_expandable(&self) -> bool { + self.operators.iter().all(|op| { + matches!( + op.operator_name, + OperatorName::Projection | OperatorName::Value | OperatorName::ExpressionWatermark + ) + }) + } } diff --git a/src/sql/logical_node/lookup.rs b/src/sql/logical_node/lookup.rs index 00f624a7..c060ba82 100644 --- a/src/sql/logical_node/lookup.rs +++ b/src/sql/logical_node/lookup.rs @@ -199,7 +199,7 @@ impl StreamingOperatorBlueprint for StreamReferenceJoinNode { "DictionaryJoin<{}>", self.external_dictionary.table_identifier ), - 1, + planner.default_parallelism(), ); let incoming_edge = diff --git a/src/sql/logical_node/projection.rs b/src/sql/logical_node/projection.rs index df55e575..3c5cfccb 100644 --- a/src/sql/logical_node/projection.rs +++ b/src/sql/logical_node/projection.rs @@ -170,7 +170,7 @@ impl StreamingOperatorBlueprint for StreamProjectionNode { OperatorName::Projection, operator_config.encode_to_vec(), label, - 1, + planner.default_parallelism(), ); let routing_strategy = if self.requires_shuffle { diff --git a/src/sql/logical_node/remote_table.rs b/src/sql/logical_node/remote_table.rs index d43a87e0..bde1d47f 100644 --- a/src/sql/logical_node/remote_table.rs +++ b/src/sql/logical_node/remote_table.rs @@ -119,7 +119,7 @@ impl StreamingOperatorBlueprint for RemoteTableBoundaryNode { OperatorName::Value, operator_payload, self.table_identifier.to_string(), - 1, + planner.default_parallelism(), ); let routing_edges: Vec = input_schemas diff --git a/src/sql/logical_node/sink.rs b/src/sql/logical_node/sink.rs index dbfcaa55..2edf8f27 100644 --- a/src/sql/logical_node/sink.rs +++ b/src/sql/logical_node/sink.rs @@ -149,7 +149,7 @@ impl StreamingOperatorBlueprint for StreamEgressNode { fn compile_to_graph_node( &self, - _planner: &Planner, + planner: &Planner, node_index: usize, input_schemas: Vec, ) -> Result { @@ -167,7 +167,7 @@ impl StreamingOperatorBlueprint for StreamEgressNode { OperatorName::ConnectorSink, operator_payload, operator_description, - 1, + planner.default_parallelism(), ); let routing_edges: Vec = input_schemas diff --git a/src/sql/logical_node/table_source.rs b/src/sql/logical_node/table_source.rs index 65f4459f..b1c6bfdd 100644 --- a/src/sql/logical_node/table_source.rs +++ b/src/sql/logical_node/table_source.rs @@ -147,7 +147,7 @@ impl StreamingOperatorBlueprint for StreamIngestionNode { fn compile_to_graph_node( &self, - _compiler_context: &Planner, + compiler_context: &Planner, node_id_sequence: usize, upstream_schemas: Vec, ) -> Result { @@ -167,7 +167,7 @@ impl StreamingOperatorBlueprint for StreamIngestionNode { OperatorName::ConnectorSource, connector_payload, operator_description, - 1, + compiler_context.default_parallelism(), ); Ok(CompiledTopologyNode::new(execution_unit, vec![])) diff --git a/src/sql/logical_node/updating_aggregate.rs b/src/sql/logical_node/updating_aggregate.rs index 598d20eb..0ddb2b28 100644 --- a/src/sql/logical_node/updating_aggregate.rs +++ b/src/sql/logical_node/updating_aggregate.rs @@ -218,13 +218,15 @@ impl StreamingOperatorBlueprint for ContinuousAggregateNode { let operator_config = self.compile_operator_config(planner, &upstream_schema)?; + let parallelism = planner.keyed_aggregate_parallelism(); + let logical_node = LogicalNode::single( node_index as u32, format!("updating_aggregate_{node_index}"), OperatorName::UpdatingAggregate, operator_config.encode_to_vec(), proto_operator_name::UPDATING_AGGREGATE.to_string(), - 1, + parallelism, ); let shuffle_edge = diff --git a/src/sql/logical_node/watermark_node.rs b/src/sql/logical_node/watermark_node.rs index 7c83c429..9a8fc9d6 100644 --- a/src/sql/logical_node/watermark_node.rs +++ b/src/sql/logical_node/watermark_node.rs @@ -209,7 +209,7 @@ impl StreamingOperatorBlueprint for EventTimeWatermarkNode { OperatorName::ExpressionWatermark, operator_config.encode_to_vec(), runtime_operator_kind::WATERMARK_GENERATOR.to_string(), - 1, + planner.default_parallelism(), ); let incoming_edge = LogicalEdge::project_all( diff --git a/src/sql/logical_node/windows_function.rs b/src/sql/logical_node/windows_function.rs index a79ceff3..9be37382 100644 --- a/src/sql/logical_node/windows_function.rs +++ b/src/sql/logical_node/windows_function.rs @@ -163,13 +163,15 @@ impl StreamingOperatorBlueprint for StreamingWindowFunctionNode { window_function_plan: evaluation_plan_payload, }; + let parallelism = planner.keyed_aggregate_parallelism(); + let logical_node = LogicalNode::single( node_index as u32, format!("window_function_{node_index}"), OperatorName::WindowFunction, operator_config.encode_to_vec(), runtime_operator_kind::STREAMING_WINDOW_EVALUATOR.to_string(), - 1, + parallelism, ); let routing_edge = diff --git a/src/sql/logical_planner/optimizers/chaining.rs b/src/sql/logical_planner/optimizers/chaining.rs index 8260df19..ea7bd885 100644 --- a/src/sql/logical_planner/optimizers/chaining.rs +++ b/src/sql/logical_planner/optimizers/chaining.rs @@ -45,9 +45,14 @@ impl Optimizer for ChainingOptimizer { let source_node = plan.node_weight(node_idx).expect("Source node missing"); let target_node = plan.node_weight(target_idx).expect("Target node missing"); + let parallelism_ok = source_node.parallelism == target_node.parallelism + || target_node + .operator_chain + .is_parallelism_upstream_expandable(); + if source_node.operator_chain.is_source() || target_node.operator_chain.is_sink() - || source_node.parallelism != target_node.parallelism + || !parallelism_ok { continue; } @@ -93,6 +98,8 @@ impl Optimizer for ChainingOptimizer { source_node.description = format!("{} -> {}", source_node.description, target_node.description); + source_node.parallelism = source_node.parallelism.max(target_node.parallelism); + source_node .operator_chain .operators @@ -150,6 +157,31 @@ mod tests { ) } + /// Window aggregate at higher default parallelism may forward into projection @ 1: still fuse + /// so each branch does not reserve a separate global state-memory block for the same sub-chain. + #[test] + fn fusion_stateful_high_parallelism_into_expandable_low() { + let mut g = LogicalGraph::new(); + let n0 = g.add_node(source_node()); + let n1 = g.add_node(proj_node(1, "tumble")); + let n2 = g.add_node(proj_node(2, "proj")); + let n1w = g.node_weight_mut(n1).unwrap(); + n1w.parallelism = 8; + let e = forward_edge(); + g.add_edge(n0, n1, e.clone()); + g.add_edge(n1, n2, e); + + let changed = ChainingOptimizer {}.optimize_once(&mut g); + assert!(changed); + assert_eq!(g.node_count(), 2); + let fused = g + .node_weights() + .find(|n| n.description.contains("->")) + .unwrap(); + assert_eq!(fused.parallelism, 8); + assert_eq!(fused.operator_chain.len(), 2); + } + /// Regression: upstream at last `NodeIndex` + remove non-last downstream swaps indices. #[test] fn fusion_remaps_when_upstream_was_last_node_index() { diff --git a/src/sql/logical_planner/streaming_planner.rs b/src/sql/logical_planner/streaming_planner.rs index 4619fb3f..1e999c2a 100644 --- a/src/sql/logical_planner/streaming_planner.rs +++ b/src/sql/logical_planner/streaming_planner.rs @@ -42,6 +42,7 @@ use datafusion_common::TableReference; use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec; use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use crate::sql::common::constants::sql_planning_default; use crate::sql::common::{FsSchema, FsSchemaRef}; use crate::sql::logical_node::debezium::{ PACK_NODE_NAME, UNROLL_NODE_NAME, UnrollDebeziumPayloadNode, @@ -96,6 +97,22 @@ pub(crate) struct Planner<'a> { } impl<'a> Planner<'a> { + #[inline] + pub(crate) fn default_parallelism(&self) -> usize { + self.schema_provider.default_parallelism() + } + + #[inline] + pub(crate) fn key_by_parallelism(&self) -> usize { + self.schema_provider.key_by_parallelism() + } + + /// Parallelism for operators that consume a keyed shuffle (non-empty partition keys). + #[inline] + pub(crate) fn keyed_aggregate_parallelism(&self) -> usize { + sql_planning_default::KEYED_AGGREGATE_DEFAULT_PARALLELISM + } + pub(crate) fn new( schema_provider: &'a StreamSchemaProvider, session_state: &'a SessionState, diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 71dd4dd1..529c7a2d 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -19,6 +19,7 @@ pub mod logical_node; pub mod logical_planner; pub mod parse; pub mod physical; +pub(crate) mod planning_runtime; pub mod schema; pub mod types; diff --git a/src/sql/planning_runtime.rs b/src/sql/planning_runtime.rs new file mode 100644 index 00000000..dc4749ad --- /dev/null +++ b/src/sql/planning_runtime.rs @@ -0,0 +1,35 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Runtime-installed SQL planning defaults (from `GlobalConfig` / `conf/config.yaml`). + +use std::sync::OnceLock; + +use crate::config::streaming_job::ResolvedStreamingJobConfig; +use crate::sql::common::constants::sql_planning_default; +use crate::sql::types::SqlConfig; + +static SQL_PLANNING: OnceLock = OnceLock::new(); + +/// Installs [`SqlConfig`] derived from resolved streaming job YAML (KeyBy parallelism, etc.). +/// Safe to call once at bootstrap; later calls are ignored if already set. +pub fn install_sql_planning_from_streaming_job(job: &ResolvedStreamingJobConfig) { + let cfg = SqlConfig { + default_parallelism: sql_planning_default::DEFAULT_PARALLELISM, + key_by_parallelism: job.key_by_parallelism as usize, + }; + let _ = SQL_PLANNING.set(cfg).ok(); +} + +pub(crate) fn sql_planning_snapshot() -> SqlConfig { + SQL_PLANNING.get().cloned().unwrap_or_default() +} diff --git a/src/sql/schema/kafka_operator_config.rs b/src/sql/schema/kafka_operator_config.rs index d9251310..d87dda8f 100644 --- a/src/sql/schema/kafka_operator_config.rs +++ b/src/sql/schema/kafka_operator_config.rs @@ -24,6 +24,9 @@ use crate::sql::common::formats::{ use crate::sql::common::with_option_keys as opt; use crate::sql::schema::table_role::TableRole; +const STREAMING_JOB_OPTION_CHECKPOINT_INTERVAL: &str = "checkpoint.interval"; +const STREAMING_JOB_OPTION_PARALLELISM: &str = "parallelism"; + fn sql_format_to_proto(fmt: &SqlFormat) -> DFResult { match fmt { SqlFormat::Json(j) => Ok(FormatConfig { @@ -194,7 +197,10 @@ pub fn build_kafka_proto_config( }; let group_id_prefix = options.pull_opt_str(opt::KAFKA_GROUP_ID_PREFIX)?; - let client_configs = options.drain_remaining_string_values()?; + let mut client_configs = options.drain_remaining_string_values()?; + // Streaming job-level options are parsed by planner/coordinator, not Kafka client. + client_configs.remove(STREAMING_JOB_OPTION_CHECKPOINT_INTERVAL); + client_configs.remove(STREAMING_JOB_OPTION_PARALLELISM); Ok(ProtoConfig::KafkaSource(KafkaSourceConfig { topic, @@ -242,7 +248,10 @@ pub fn build_kafka_proto_config( None => options.pull_opt_str(opt::KAFKA_TIMESTAMP_FIELD_LEGACY)?, }; - let client_configs = options.drain_remaining_string_values()?; + let mut client_configs = options.drain_remaining_string_values()?; + // Streaming job-level options are parsed by planner/coordinator, not Kafka client. + client_configs.remove(STREAMING_JOB_OPTION_CHECKPOINT_INTERVAL); + client_configs.remove(STREAMING_JOB_OPTION_PARALLELISM); Ok(ProtoConfig::KafkaSink(KafkaSinkConfig { topic, diff --git a/src/sql/schema/schema_provider.rs b/src/sql/schema/schema_provider.rs index d5405dd2..26fd43e8 100644 --- a/src/sql/schema/schema_provider.rs +++ b/src/sql/schema/schema_provider.rs @@ -29,7 +29,7 @@ use crate::sql::common::constants::{planning_placeholder_udf, window_fn}; use crate::sql::logical_node::logical::{DylibUdfConfig, LogicalProgram}; use crate::sql::schema::table::Table as CatalogTable; use crate::sql::schema::utils::window_arrow_struct; -use crate::sql::types::{PlanningOptions, PlanningPlaceholderUdf}; +use crate::sql::types::{PlanningOptions, PlanningPlaceholderUdf, SqlConfig}; pub type ObjectName = UniCase; @@ -129,6 +129,7 @@ pub struct StreamPlanningContext { pub config_options: datafusion::config::ConfigOptions, pub planning_options: PlanningOptions, pub analyzer: Analyzer, + pub sql_config: SqlConfig, } /// Back-compat name for [`StreamPlanningContext`]. @@ -139,14 +140,26 @@ impl StreamPlanningContext { StreamPlanningContextBuilder::default() } + #[inline] + pub fn default_parallelism(&self) -> usize { + self.sql_config.default_parallelism + } + + #[inline] + pub fn key_by_parallelism(&self) -> usize { + self.sql_config.key_by_parallelism + } + /// Same registration order as the historical `StreamSchemaProvider::new` (placeholders, then DataFusion defaults). pub fn new() -> Self { - Self::builder() + let mut ctx = Self::builder() .with_streaming_extensions() .expect("streaming extensions") .with_default_functions() .expect("default functions") - .build() + .build(); + ctx.sql_config = crate::sql::planning_runtime::sql_planning_snapshot(); + ctx } pub fn register_stream_table(&mut self, table: StreamTable) { diff --git a/src/sql/types/mod.rs b/src/sql/types/mod.rs index c9d80681..d5124bcc 100644 --- a/src/sql/types/mod.rs +++ b/src/sql/types/mod.rs @@ -38,12 +38,15 @@ pub enum ProcessingMode { #[derive(Clone, Debug)] pub struct SqlConfig { pub default_parallelism: usize, + /// Physical pipeline parallelism for [`KeyExtractionNode`](crate::sql::logical_node::key_calculation::KeyExtractionNode) / KeyBy. + pub key_by_parallelism: usize, } impl Default for SqlConfig { fn default() -> Self { Self { default_parallelism: sql_planning_default::DEFAULT_PARALLELISM, + key_by_parallelism: sql_planning_default::DEFAULT_KEY_BY_PARALLELISM, } } } diff --git a/src/storage/stream_catalog/manager.rs b/src/storage/stream_catalog/manager.rs index 3804a95a..3c9d561e 100644 --- a/src/storage/stream_catalog/manager.rs +++ b/src/storage/stream_catalog/manager.rs @@ -10,6 +10,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::path::Path; use std::sync::{Arc, OnceLock}; use anyhow::{Context, anyhow, bail}; @@ -17,7 +18,9 @@ use datafusion::common::{Result as DFResult, internal_err, plan_err}; use prost::Message; use protocol::function_stream_graph::FsProgram; use protocol::storage::{self as pb, table_definition}; -use tracing::{info, warn}; +use tracing::{debug, info, warn}; + +use crate::runtime::streaming::operators::source::kafka as kafka_snap; use unicase::UniCase; use crate::sql::common::constants::sql_field; @@ -33,6 +36,93 @@ use super::meta_store::MetaStore; const CATALOG_KEY_PREFIX: &str = "catalog:stream_table:"; const STREAMING_JOB_KEY_PREFIX: &str = "streaming_job:"; +/// One persisted streaming job row from catalog (program + checkpoint metadata + Kafka offsets). +#[derive(Debug, Clone)] +pub struct StoredStreamingJob { + pub table_name: String, + pub program: FsProgram, + pub checkpoint_interval_ms: u64, + pub latest_checkpoint_epoch: u64, + pub kafka_source_checkpoints: Vec, +} + +fn parse_kafka_offset_snapshot_filename(name: &str) -> Option<(u32, u32)> { + const PREFIX: &str = "kafka_source_offsets_pipe"; + const SUFFIX: &str = ".bin"; + if !name.starts_with(PREFIX) || !name.ends_with(SUFFIX) { + return None; + } + let mid = name.strip_prefix(PREFIX)?.strip_suffix(SUFFIX)?; + let (pipe, sub_part) = mid.split_once("_sub")?; + Some((pipe.parse().ok()?, sub_part.parse().ok()?)) +} + +/// Removes on-disk staging snapshots once their payload is committed into catalog (same epoch). +fn cleanup_kafka_offset_snapshots_for_epoch(job_dir: &Path, epoch: u64) { + let Ok(rd) = std::fs::read_dir(job_dir) else { + return; + }; + for ent in rd.flatten() { + let path = ent.path(); + let name = ent.file_name().to_string_lossy().into_owned(); + if parse_kafka_offset_snapshot_filename(&name).is_none() { + continue; + } + let Ok(bytes) = std::fs::read(&path) else { + continue; + }; + let Ok(saved) = kafka_snap::decode_kafka_offset_snapshot(&bytes) else { + continue; + }; + if saved.epoch == epoch && std::fs::remove_file(&path).is_err() { + debug!(path = %path.display(), "Could not remove staged Kafka offset snapshot (non-fatal)"); + } + } +} + +/// Writes catalog-stored Kafka checkpoints back to the job state dir before `submit_job` resumes sources. +pub fn materialize_kafka_source_checkpoints_from_catalog( + job_dir: &Path, + checkpoints: &[pb::KafkaSourceSubtaskCheckpoint], +) -> DFResult<()> { + if checkpoints.is_empty() { + return Ok(()); + } + std::fs::create_dir_all(job_dir).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "create job state dir {}: {e}", + job_dir.display() + )) + })?; + for c in checkpoints { + let saved = kafka_snap::KafkaSourceSavedOffsets { + epoch: c.checkpoint_epoch, + partitions: c + .partitions + .iter() + .map(|p| kafka_snap::KafkaState { + partition: p.partition, + offset: p.offset, + }) + .collect(), + }; + let path = kafka_snap::kafka_snapshot_path(job_dir, c.pipeline_id, c.subtask_index); + let bytes = kafka_snap::encode_kafka_offset_snapshot(&saved).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "encode kafka snapshot for {}: {e}", + path.display() + )) + })?; + std::fs::write(&path, &bytes).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "write kafka snapshot {}: {e}", + path.display() + )) + })?; + } + Ok(()) +} + pub struct CatalogManager { store: Arc, } @@ -88,6 +178,7 @@ impl CatalogManager { table_name: &str, fs_program: &FsProgram, comment: &str, + checkpoint_interval_ms: u64, ) -> DFResult<()> { let program_bytes = fs_program.encode_to_vec(); let def = pb::StreamingTableDefinition { @@ -95,11 +186,14 @@ impl CatalogManager { created_at_millis: chrono::Utc::now().timestamp_millis(), fs_program_bytes: program_bytes, comment: comment.to_string(), + checkpoint_interval_ms, + latest_checkpoint_epoch: 0, + kafka_source_checkpoints: vec![], }; let payload = def.encode_to_vec(); let key = Self::build_streaming_job_key(table_name); self.store.put(&key, payload)?; - info!(table = %table_name, "Streaming job definition persisted"); + info!(table = %table_name, interval_ms = checkpoint_interval_ms, "Streaming job definition persisted"); Ok(()) } @@ -110,7 +204,56 @@ impl CatalogManager { Ok(()) } - pub fn load_streaming_job_definitions(&self) -> DFResult> { + /// Persist the globally-completed checkpoint epoch after all operators ACK. + /// Only advances forward; stale epochs are silently ignored. + /// + /// `kafka_source_checkpoints` is assembled by the job coordinator from source pipeline checkpoint + /// ACKs (in-memory); it is stored next to `latest_checkpoint_epoch` in the catalog. + /// + /// `job_state_dir` is only used to remove legacy on-disk staging snapshots for this epoch, if present. + pub fn commit_job_checkpoint( + &self, + table_name: &str, + epoch: u64, + job_state_dir: &Path, + kafka_source_checkpoints: Vec, + ) -> DFResult<()> { + let key = Self::build_streaming_job_key(table_name); + + let current_payload = self.store.get(&key)?.ok_or_else(|| { + datafusion::common::DataFusionError::Plan(format!( + "Cannot commit checkpoint: Streaming job '{}' not found in catalog", + table_name + )) + })?; + + let mut def = + pb::StreamingTableDefinition::decode(current_payload.as_slice()).map_err(|e| { + datafusion::common::DataFusionError::Execution(format!( + "Protobuf decode error: {}", + e + )) + })?; + + if epoch > def.latest_checkpoint_epoch { + def.latest_checkpoint_epoch = epoch; + def.kafka_source_checkpoints = kafka_source_checkpoints; + let new_payload = def.encode_to_vec(); + self.store.put(&key, new_payload)?; + debug!( + table = %table_name, + epoch = epoch, + kafka_subtasks = def.kafka_source_checkpoints.len(), + "Checkpoint metadata committed to Catalog" + ); + cleanup_kafka_offset_snapshots_for_epoch(job_state_dir, epoch); + } + + Ok(()) + } + + /// Load all persisted streaming jobs (including Kafka offset checkpoints for restore). + pub fn load_streaming_job_definitions(&self) -> DFResult> { let records = self.store.scan_prefix(STREAMING_JOB_KEY_PREFIX)?; let mut out = Vec::with_capacity(records.len()); for (key, payload) in records { @@ -136,7 +279,13 @@ impl CatalogManager { continue; } }; - out.push((def.table_name, program)); + out.push(StoredStreamingJob { + table_name: def.table_name, + program, + checkpoint_interval_ms: def.checkpoint_interval_ms, + latest_checkpoint_epoch: def.latest_checkpoint_epoch, + kafka_source_checkpoints: def.kafka_source_checkpoints, + }); } Ok(out) } @@ -522,12 +671,45 @@ pub fn restore_streaming_jobs_from_store() { let mut restored = 0usize; let mut failed = 0usize; - for (table_name, fs_program) in definitions { + for job in definitions { + let StoredStreamingJob { + table_name, + program, + checkpoint_interval_ms: interval_ms, + latest_checkpoint_epoch: latest_epoch, + kafka_source_checkpoints, + } = job; let jm = job_manager.clone(); let name = table_name.clone(); - match rt.block_on(jm.submit_job(name.clone(), fs_program)) { + + let job_dir = jm.job_state_directory(&table_name); + if let Err(e) = + materialize_kafka_source_checkpoints_from_catalog(&job_dir, &kafka_source_checkpoints) + { + warn!( + table = %table_name, + error = %e, + "Failed to materialize Kafka checkpoints from catalog before job restore" + ); + } + + let custom_interval = if interval_ms > 0 { + Some(interval_ms) + } else { + None + }; + let recovery_epoch = if latest_epoch > 0 { + Some(latest_epoch) + } else { + None + }; + + match rt.block_on(jm.submit_job(name.clone(), program, custom_interval, recovery_epoch)) { Ok(job_id) => { - info!(table = %table_name, job_id = %job_id, "Streaming job restored"); + info!( + table = %table_name, job_id = %job_id, + epoch = latest_epoch, "Streaming job restored" + ); restored += 1; } Err(e) => { diff --git a/src/storage/stream_catalog/mod.rs b/src/storage/stream_catalog/mod.rs index 6f31317a..ef176c40 100644 --- a/src/storage/stream_catalog/mod.rs +++ b/src/storage/stream_catalog/mod.rs @@ -17,8 +17,10 @@ mod manager; mod meta_store; mod rocksdb_meta_store; +#[allow(unused_imports)] pub use manager::{ - CatalogManager, initialize_stream_catalog, restore_global_catalog_from_store, + CatalogManager, StoredStreamingJob, initialize_stream_catalog, + materialize_kafka_source_checkpoints_from_catalog, restore_global_catalog_from_store, restore_streaming_jobs_from_store, }; pub use meta_store::{InMemoryMetaStore, MetaStore}; diff --git a/tests/integration/framework/kafka_manager.py b/tests/integration/framework/kafka_manager.py index e495f638..3898fc7d 100644 --- a/tests/integration/framework/kafka_manager.py +++ b/tests/integration/framework/kafka_manager.py @@ -19,7 +19,7 @@ import logging import time -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, TypeVar import docker @@ -28,6 +28,8 @@ from confluent_kafka import TopicPartition as _new_topic_partition from confluent_kafka.admin import AdminClient, NewTopic +from .utils import find_free_port + logger = logging.getLogger(__name__) T = TypeVar("T") @@ -97,7 +99,10 @@ def __init__( config: Optional[KafkaConfig] = None, docker_client: Optional[docker.DockerClient] = None, ) -> None: - self.config = config or KafkaConfig() + if config is None: + host_port = find_free_port() + config = KafkaConfig(bootstrap_servers=f"127.0.0.1:{host_port}") + self.config = config # Dependency Injection: Allow passing an existing client, or create a lazy one. self._docker_client = docker_client @@ -173,7 +178,11 @@ def _ensure_container(self) -> None: self.docker_client.containers.run( image=self.config.image, name=self.config.container_name, - ports={f"{self.config.internal_port}/tcp": self.config.internal_port}, + ports={ + f"{self.config.internal_port}/tcp": int( + self.config.bootstrap_servers.rsplit(":", 1)[1] + ) + }, environment=self.config.environment_vars, detach=True, remove=True, # Auto-remove on stop diff --git a/tests/integration/test/wasm/python_sdk/conftest.py b/tests/integration/test/wasm/python_sdk/conftest.py index aa5d60c6..e0acd26e 100644 --- a/tests/integration/test/wasm/python_sdk/conftest.py +++ b/tests/integration/test/wasm/python_sdk/conftest.py @@ -54,10 +54,25 @@ def kafka_topics(kafka: KafkaDockerManager) -> str: return kafka.config.bootstrap_servers -def _sanitize_node_id(nodeid: str) -> str: - """Converts a pytest nodeid into a safe directory name.""" - clean_name = re.sub(r"[^\w\-]+", "-", nodeid) - return clean_name.strip("-") +def _sanitize_segment(segment: str) -> str: + clean = re.sub(r"[^\w\-]+", "_", segment).strip("_") + return clean or "unknown" + + +def _nodeid_to_workspace_path(nodeid: str) -> str: + """ + Convert pytest nodeid into a readable nested path under target/. + + Example: + test/wasm/python_sdk/test_data_flow.py::TestDataFlow::test_single_word_counting + -> + test/wasm/python_sdk/test_data_flow/TestDataFlow/test_single_word_counting + """ + parts = nodeid.split("::") + file_part = Path(parts[0]).with_suffix("") + file_segments = [_sanitize_segment(seg) for seg in file_part.parts] + extra_segments = [_sanitize_segment(seg) for seg in parts[1:]] + return str(Path(*file_segments, *extra_segments)) @pytest.fixture @@ -66,7 +81,7 @@ def fs_server(request: pytest.FixtureRequest) -> Generator[FunctionStreamInstanc Function-scoped FunctionStream instance. Uses Context Manager to ensure SIGKILL and workspace cleanup. """ - test_name = _sanitize_node_id(request.node.nodeid) + test_name = _nodeid_to_workspace_path(request.node.nodeid) with FunctionStreamInstance(test_name=test_name) as instance: yield instance diff --git a/tests/integration/test/wasm/python_sdk/test_data_flow.py b/tests/integration/test/wasm/python_sdk/test_data_flow.py index 9e9532a2..7fc89d7f 100644 --- a/tests/integration/test/wasm/python_sdk/test_data_flow.py +++ b/tests/integration/test/wasm/python_sdk/test_data_flow.py @@ -74,6 +74,13 @@ def consume_messages( deadline = time.time() + timeout try: + logger.info( + "Start consuming topic=%s expected_count=%d timeout=%.1fs bootstrap=%s", + topic, + expected_count, + timeout, + bootstrap, + ) while len(collected) < expected_count and time.time() < deadline: msg = consumer.poll(timeout=POLL_INTERVAL_S) if msg is None: @@ -85,11 +92,15 @@ def consume_messages( payload = msg.value().decode("utf-8") collected.append(json.loads(payload)) + logger.info("Consumed topic=%s count=%d payload=%s", topic, len(collected), payload) finally: consumer.close() if len(collected) < expected_count: - raise TimeoutError(f"Expected {expected_count} messages, received {len(collected)}") + raise TimeoutError( + f"Expected {expected_count} messages, received {len(collected)}. " + f"topic={topic}, collected={collected}" + ) return collected diff --git a/tests/integration/test/wasm/python_sdk/test_streaming_sql_kafka.py b/tests/integration/test/wasm/python_sdk/test_streaming_sql_kafka.py new file mode 100644 index 00000000..dcf5df18 --- /dev/null +++ b/tests/integration/test/wasm/python_sdk/test_streaming_sql_kafka.py @@ -0,0 +1,256 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime as dt +import json +import time +import uuid +from typing import Any, Dict, List + +from .test_data_flow import consume_messages, produce_messages + + +def _uid(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex[:8]}" + + +def _sql_ok(fs_server: Any, sql: str) -> Any: + resp = fs_server.execute_sql(sql) + assert resp.status_code == 200, f"SQL failed: {sql}\nstatus={resp.status_code}\nmsg={resp.message}" + return resp + + +class TestStreamingSqlKafka: + @staticmethod + def _create_impression_source(fs_server: Any, source_name: str, in_topic: str, bootstrap: str) -> None: + _sql_ok( + fs_server, + f""" + CREATE TABLE {source_name} ( + impression_id VARCHAR, + ad_id BIGINT, + campaign_id BIGINT, + user_id VARCHAR, + impression_time TIMESTAMP NOT NULL, + WATERMARK FOR impression_time AS impression_time - INTERVAL '1' SECOND + ) WITH ( + 'connector' = 'kafka', + 'topic' = '{in_topic}', + 'format' = 'json', + 'scan.startup.mode' = 'earliest', + 'bootstrap.servers' = '{bootstrap}' + ); + """, + ) + + @staticmethod + def _create_click_source(fs_server: Any, source_name: str, in_topic: str, bootstrap: str) -> None: + _sql_ok( + fs_server, + f""" + CREATE TABLE {source_name} ( + click_id VARCHAR, + impression_id VARCHAR, + ad_id BIGINT, + click_time TIMESTAMP NOT NULL, + WATERMARK FOR click_time AS click_time - INTERVAL '1' SECOND + ) WITH ( + 'connector' = 'kafka', + 'topic' = '{in_topic}', + 'format' = 'json', + 'scan.startup.mode' = 'earliest', + 'bootstrap.servers' = '{bootstrap}' + ); + """, + ) + + def test_tumble_window_with_kafka_produce_consume( + self, + fs_server: Any, + kafka: Any, + kafka_topics: str, + ) -> None: + source_name = _uid("ad_impressions_src") + stream_name = _uid("metric_tumble_impressions") + in_topic = _uid("topic_in") + out_topic = _uid("topic_out") + + kafka.create_topics_if_not_exist([in_topic, out_topic]) + + _sql_ok( + fs_server, + f""" + CREATE TABLE {source_name} ( + impression_id VARCHAR, + campaign_id BIGINT, + impression_time TIMESTAMP NOT NULL, + WATERMARK FOR impression_time AS impression_time - INTERVAL '1' SECOND + ) WITH ( + 'connector' = 'kafka', + 'topic' = '{in_topic}', + 'format' = 'json', + 'scan.startup.mode' = 'earliest', + 'bootstrap.servers' = '{kafka_topics}' + ); + """, + ) + + _sql_ok( + fs_server, + f""" + CREATE STREAMING TABLE {stream_name} WITH ( + 'connector' = 'kafka', + 'topic' = '{out_topic}', + 'format' = 'json', + 'bootstrap.servers' = '{kafka_topics}' + ) AS + SELECT + TUMBLE(INTERVAL '2' SECOND) AS time_window, + campaign_id, + COUNT(*) AS total_impressions + FROM {source_name} + GROUP BY 1, campaign_id; + """, + ) + + now = dt.datetime.now(dt.timezone.utc) + base = now.replace(microsecond=0) - dt.timedelta(seconds=8) + old_window_msgs: List[Dict[str, Any]] = [ + { + "impression_id": "i-1", + "campaign_id": 1001, + "impression_time": (base + dt.timedelta(milliseconds=100)).isoformat(), + }, + { + "impression_id": "i-2", + "campaign_id": 1001, + "impression_time": (base + dt.timedelta(milliseconds=500)).isoformat(), + }, + { + "impression_id": "i-3", + "campaign_id": 1002, + "impression_time": (base + dt.timedelta(milliseconds=900)).isoformat(), + }, + ] + advance_wm = { + "impression_id": "i-4", + "campaign_id": 9999, + "impression_time": dt.datetime.now(dt.timezone.utc).isoformat(), + } + + produce_messages(kafka_topics, in_topic, [json.dumps(x) for x in old_window_msgs + [advance_wm]]) + time.sleep(1.0) + + records = consume_messages(kafka_topics, out_topic, expected_count=2, timeout=15.0) + got = {(int(r["campaign_id"]), int(r["total_impressions"])) for r in records} + assert got == {(1001, 2), (1002, 1)} + + _sql_ok(fs_server, f"DROP STREAMING TABLE {stream_name};") + + def test_hop_window_with_where_filter( + self, + fs_server: Any, + kafka: Any, + kafka_topics: str, + ) -> None: + source_name = _uid("ad_impressions_src") + stream_name = _uid("metric_hop_uv") + in_topic = _uid("topic_in") + out_topic = _uid("topic_out") + kafka.create_topics_if_not_exist([in_topic, out_topic]) + self._create_impression_source(fs_server, source_name, in_topic, kafka_topics) + + _sql_ok( + fs_server, + f""" + CREATE STREAMING TABLE {stream_name} WITH ( + 'connector' = 'kafka', + 'topic' = '{out_topic}', + 'format' = 'json', + 'bootstrap.servers' = '{kafka_topics}' + ) AS + SELECT + HOP(INTERVAL '1' SECOND, INTERVAL '4' SECOND) AS time_window, + ad_id, + COUNT(*) AS kept_rows + FROM {source_name} + WHERE campaign_id = 2001 + GROUP BY 1, ad_id; + """, + ) + + now = dt.datetime.now(dt.timezone.utc) + base = now.replace(microsecond=0) - dt.timedelta(seconds=8) + msgs = [ + {"impression_id": "h1", "ad_id": 11, "campaign_id": 2001, "user_id": "u1", + "impression_time": (base + dt.timedelta(milliseconds=100)).isoformat()}, + {"impression_id": "h2", "ad_id": 11, "campaign_id": 2002, "user_id": "u2", + "impression_time": (base + dt.timedelta(milliseconds=300)).isoformat()}, + {"impression_id": "h3", "ad_id": 12, "campaign_id": 2001, "user_id": "u3", + "impression_time": (base + dt.timedelta(milliseconds=600)).isoformat()}, + {"impression_id": "h4", "ad_id": 999, "campaign_id": 9999, "user_id": "wm", + "impression_time": dt.datetime.now(dt.timezone.utc).isoformat()}, + ] + produce_messages(kafka_topics, in_topic, [json.dumps(x) for x in msgs]) + rows = consume_messages(kafka_topics, out_topic, expected_count=2, timeout=15.0) + got = {(int(r["ad_id"]), int(r["kept_rows"])) for r in rows} + assert got == {(11, 1), (12, 1)} + _sql_ok(fs_server, f"DROP STREAMING TABLE {stream_name};") + + def test_session_window_user_activity( + self, + fs_server: Any, + kafka: Any, + kafka_topics: str, + ) -> None: + source_name = _uid("ad_impressions_src") + stream_name = _uid("metric_session_impr") + in_topic = _uid("topic_in") + out_topic = _uid("topic_out") + kafka.create_topics_if_not_exist([in_topic, out_topic]) + self._create_impression_source(fs_server, source_name, in_topic, kafka_topics) + + _sql_ok( + fs_server, + f""" + CREATE STREAMING TABLE {stream_name} WITH ( + 'connector' = 'kafka', + 'topic' = '{out_topic}', + 'format' = 'json', + 'bootstrap.servers' = '{kafka_topics}' + ) AS + SELECT + SESSION(INTERVAL '2' SECOND) AS time_window, + user_id, + COUNT(*) AS impressions_in_session + FROM {source_name} + GROUP BY 1, user_id; + """, + ) + + now = dt.datetime.now(dt.timezone.utc) + base = now.replace(microsecond=0) - dt.timedelta(seconds=10) + msgs = [ + {"impression_id": "s1", "ad_id": 1, "campaign_id": 1, "user_id": "uA", + "impression_time": (base + dt.timedelta(milliseconds=100)).isoformat()}, + {"impression_id": "s2", "ad_id": 1, "campaign_id": 1, "user_id": "uA", + "impression_time": (base + dt.timedelta(milliseconds=900)).isoformat()}, + {"impression_id": "s3", "ad_id": 2, "campaign_id": 1, "user_id": "uB", + "impression_time": (base + dt.timedelta(milliseconds=1200)).isoformat()}, + {"impression_id": "s4", "ad_id": 999, "campaign_id": 9999, "user_id": "wm", + "impression_time": dt.datetime.now(dt.timezone.utc).isoformat()}, + ] + produce_messages(kafka_topics, in_topic, [json.dumps(x) for x in msgs]) + rows = consume_messages(kafka_topics, out_topic, expected_count=2, timeout=15.0) + got = {(r["user_id"], int(r["impressions_in_session"])) for r in rows} + assert got == {("uA", 2), ("uB", 1)} + _sql_ok(fs_server, f"DROP STREAMING TABLE {stream_name};") \ No newline at end of file