diff --git a/docs/source/_static/images/tracing.png b/docs/source/_static/images/tracing.png index 78897512fe..46224945f7 100644 Binary files a/docs/source/_static/images/tracing.png and b/docs/source/_static/images/tracing.png differ diff --git a/docs/source/contributor-guide/tracing.md b/docs/source/contributor-guide/tracing.md index b9b4fe0dcc..88a291f421 100644 --- a/docs/source/contributor-guide/tracing.md +++ b/docs/source/contributor-guide/tracing.md @@ -40,25 +40,76 @@ Example output: { "name": "decodeShuffleBlock", "cat": "PERF", "ph": "E", "pid": 1, "tid": 5, "ts": 10109228835 }, { "name": "decodeShuffleBlock", "cat": "PERF", "ph": "B", "pid": 1, "tid": 5, "ts": 10109245928 }, { "name": "decodeShuffleBlock", "cat": "PERF", "ph": "E", "pid": 1, "tid": 5, "ts": 10109248843 }, -{ "name": "execute_plan", "cat": "PERF", "ph": "E", "pid": 1, "tid": 5, "ts": 10109350935 }, -{ "name": "CometExecIterator_getNextBatch", "cat": "PERF", "ph": "E", "pid": 1, "tid": 5, "ts": 10109367116 }, -{ "name": "CometExecIterator_getNextBatch", "cat": "PERF", "ph": "B", "pid": 1, "tid": 5, "ts": 10109479156 }, +{ "name": "executePlan", "cat": "PERF", "ph": "E", "pid": 1, "tid": 5, "ts": 10109350935 }, +{ "name": "getNextBatch[JVM] stage=2", "cat": "PERF", "ph": "E", "pid": 1, "tid": 5, "ts": 10109367116 }, +{ "name": "getNextBatch[JVM] stage=2", "cat": "PERF", "ph": "B", "pid": 1, "tid": 5, "ts": 10109479156 }, ``` -Traces can be viewed with [Trace Viewer]. +Traces can be viewed with [Perfetto UI]. -[Trace Viewer]: https://github.com/catapult-project/catapult/blob/main/tracing/README.md +[Perfetto UI]: https://ui.perfetto.dev Example trace visualization: ![tracing](../_static/images/tracing.png) +## Analyzing Memory Usage + +The `analyze_trace` tool parses a trace log and compares jemalloc usage against the sum of per-thread +Comet memory pool reservations. This is useful for detecting untracked native memory growth where jemalloc +allocations exceed what the memory pools account for. + +Build and run: + +```shell +cd native +cargo run --bin analyze_trace -- /path/to/comet-event-trace.json +``` + +The tool reads counter events from the trace log. Because tracing logs metrics per thread, `jemalloc_allocated` +is a process-wide value (the same global allocation reported from whichever thread logs it), while +`thread_NNN_comet_memory_reserved` values are per-thread pool reservations that are summed to get the total +tracked memory. + +Sample output: + +``` +=== Comet Trace Memory Analysis === + +Counter events parsed: 193104 +Threads with memory pools: 8 +Peak jemalloc allocated: 3068.2 MB +Peak pool total: 2864.6 MB +Peak excess (jemalloc - pool): 364.6 MB + +WARNING: jemalloc exceeded pool reservation at 138 sampled points: + + Time (us) jemalloc pool_total excess +-------------------------------------------------------------- + 179578 210.8 MB 0.1 MB 210.7 MB + 429663 420.5 MB 145.1 MB 275.5 MB + 1304969 2122.5 MB 1797.2 MB 325.2 MB + 21974838 407.0 MB 42.3 MB 364.6 MB + 33543599 5.5 MB 0.1 MB 5.3 MB + +--- Final per-thread pool reservations --- + + thread_60_comet_memory_reserved: 0.0 MB + thread_95_comet_memory_reserved: 0.0 MB + thread_96_comet_memory_reserved: 0.0 MB + ... + + Total: 0.0 MB +``` + +Some excess is expected (jemalloc metadata, fragmentation, non-pool allocations like Arrow IPC buffers). +Large or growing excess may indicate memory that is not being tracked by the pool. + ## Definition of Labels -| Label | Meaning | -| --------------------- | -------------------------------------------------------------- | -| jvm_heapUsed | JVM heap memory usage of live objects for the executor process | -| jemalloc_allocated | Native memory usage for the executor process | -| task_memory_comet_NNN | Off-heap memory allocated by Comet for query execution | -| task_memory_spark_NNN | On-heap & Off-heap memory allocated by Spark | -| comet_shuffle_NNN | Off-heap memory allocated by Comet for columnar shuffle | +| Label | Meaning | +| -------------------------------- | ------------------------------------------------------------------------------------------------------------------------ | +| jvm_heap_used | JVM heap memory usage of live objects for the executor process | +| jemalloc_allocated | Native memory usage for the executor process (requires `jemalloc` feature) | +| thread_NNN_comet_memory_reserved | Memory reserved by Comet's DataFusion memory pool (summed across all contexts on the thread). NNN is the Rust thread ID. | +| thread_NNN_comet_jvm_shuffle | Off-heap memory allocated by Comet for columnar shuffle. NNN is the Rust thread ID. | diff --git a/native/common/Cargo.toml b/native/common/Cargo.toml index 3bbc44856e..5704263eb6 100644 --- a/native/common/Cargo.toml +++ b/native/common/Cargo.toml @@ -38,3 +38,7 @@ thiserror = { workspace = true } [lib] name = "datafusion_comet_common" path = "src/lib.rs" + +[[bin]] +name = "analyze_trace" +path = "src/bin/analyze_trace.rs" diff --git a/native/common/src/bin/analyze_trace.rs b/native/common/src/bin/analyze_trace.rs new file mode 100644 index 0000000000..8df83f7cea --- /dev/null +++ b/native/common/src/bin/analyze_trace.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzes a Comet chrome trace event log (`comet-event-trace.json`) and +//! compares jemalloc usage against the sum of per-thread Comet memory pool +//! reservations. Reports any points where jemalloc exceeds the total pool size. +//! +//! Usage: +//! cargo run --bin analyze_trace -- + +use serde::Deserialize; +use std::collections::HashMap; +use std::io::{BufRead, BufReader}; +use std::{env, fs::File}; + +/// A single Chrome trace event (only the fields we care about). +#[derive(Deserialize)] +struct TraceEvent { + name: String, + ph: String, + #[allow(dead_code)] + tid: u64, + ts: u64, + #[serde(default)] + args: HashMap, +} + +/// Snapshot of memory state at a given timestamp. +struct MemorySnapshot { + ts: u64, + jemalloc: u64, + pool_total: u64, +} + +fn format_bytes(bytes: u64) -> String { + const MB: f64 = 1024.0 * 1024.0; + format!("{:.1} MB", bytes as f64 / MB) +} + +fn main() { + let args: Vec = env::args().collect(); + if args.len() != 2 { + eprintln!("Usage: analyze_trace "); + std::process::exit(1); + } + + let file = File::open(&args[1]).expect("Failed to open trace file"); + let reader = BufReader::new(file); + + // Latest jemalloc value (global, not per-thread) + let mut latest_jemalloc: u64 = 0; + // Per-thread pool reservations: thread_NNN -> bytes + let mut pool_by_thread: HashMap = HashMap::new(); + // Points where jemalloc exceeded pool total + let mut violations: Vec = Vec::new(); + // Track peak values + let mut peak_jemalloc: u64 = 0; + let mut peak_pool_total: u64 = 0; + let mut peak_excess: u64 = 0; + let mut counter_events: u64 = 0; + + // Each line is one JSON event, possibly with a trailing comma. + // The file starts with "[ " on the first event line or as a prefix. + for line in reader.lines() { + let line = line.expect("Failed to read line"); + let trimmed = line.trim(); + + // Skip empty lines or bare array brackets + if trimmed.is_empty() || trimmed == "[" || trimmed == "]" { + continue; + } + + // Strip leading "[ " (first event) and trailing comma + let json_str = trimmed + .trim_start_matches("[ ") + .trim_start_matches('[') + .trim_end_matches(','); + + if json_str.is_empty() { + continue; + } + + // Only parse counter events (they contain "\"ph\": \"C\"") + if !json_str.contains("\"ph\": \"C\"") { + continue; + } + + let event: TraceEvent = match serde_json::from_str(json_str) { + Ok(e) => e, + Err(_) => continue, + }; + + if event.ph != "C" { + continue; + } + + counter_events += 1; + + if event.name == "jemalloc_allocated" { + if let Some(val) = event.args.get("jemalloc_allocated") { + latest_jemalloc = val.as_u64().unwrap_or(0); + if latest_jemalloc > peak_jemalloc { + peak_jemalloc = latest_jemalloc; + } + } + } else if event.name.contains("comet_memory_reserved") { + // Name format: thread_NNN_comet_memory_reserved + let thread_key = event.name.clone(); + if let Some(val) = event.args.get(&event.name) { + let bytes = val.as_u64().unwrap_or(0); + pool_by_thread.insert(thread_key, bytes); + } + } else { + // Skip jvm_heap_used and other counters + continue; + } + + // After each jemalloc or pool update, check the current state + let pool_total: u64 = pool_by_thread.values().sum(); + if pool_total > peak_pool_total { + peak_pool_total = pool_total; + } + + if latest_jemalloc > 0 && pool_total > 0 && latest_jemalloc > pool_total { + let excess = latest_jemalloc - pool_total; + if excess > peak_excess { + peak_excess = excess; + } + // Record violation (sample - don't record every single one) + if violations.is_empty() + || event.ts.saturating_sub(violations.last().unwrap().ts) > 1_000_000 + || excess == peak_excess + { + violations.push(MemorySnapshot { + ts: event.ts, + jemalloc: latest_jemalloc, + pool_total, + }); + } + } + } + + // Print summary + println!("=== Comet Trace Memory Analysis ===\n"); + println!("Counter events parsed: {counter_events}"); + println!("Threads with memory pools: {}", pool_by_thread.len()); + println!("Peak jemalloc allocated: {}", format_bytes(peak_jemalloc)); + println!( + "Peak pool total: {}", + format_bytes(peak_pool_total) + ); + println!( + "Peak excess (jemalloc - pool): {}", + format_bytes(peak_excess) + ); + println!(); + + if violations.is_empty() { + println!("OK: jemalloc never exceeded the total pool reservation."); + } else { + println!( + "WARNING: jemalloc exceeded pool reservation at {} sampled points:\n", + violations.len() + ); + println!( + "{:>14} {:>14} {:>14} {:>14}", + "Time (us)", "jemalloc", "pool_total", "excess" + ); + println!("{}", "-".repeat(62)); + for snap in &violations { + let excess = snap.jemalloc - snap.pool_total; + println!( + "{:>14} {:>14} {:>14} {:>14}", + snap.ts, + format_bytes(snap.jemalloc), + format_bytes(snap.pool_total), + format_bytes(excess), + ); + } + } + + // Show final per-thread pool state + println!("\n--- Final per-thread pool reservations ---\n"); + let mut threads: Vec<_> = pool_by_thread.iter().collect(); + threads.sort_by_key(|(k, _)| (*k).clone()); + for (thread, bytes) in &threads { + println!(" {thread}: {}", format_bytes(**bytes)); + } + println!("\n Total: {}", format_bytes(pool_by_thread.values().sum())); +} diff --git a/native/common/src/tracing.rs b/native/common/src/tracing.rs index 58bea64a7a..aad4e7269b 100644 --- a/native/common/src/tracing.rs +++ b/native/common/src/tracing.rs @@ -64,10 +64,9 @@ impl Recorder { } pub fn log_memory_usage(&self, name: &str, usage_bytes: u64) { - let usage_mb = (usage_bytes as f64 / 1024.0 / 1024.0) as usize; let json = format!( - "{{ \"name\": \"{name}\", \"cat\": \"PERF\", \"ph\": \"C\", \"pid\": 1, \"tid\": {}, \"ts\": {}, \"args\": {{ \"{name}\": {usage_mb} }} }},\n", - Self::get_thread_id(), + "{{ \"name\": \"{name}\", \"cat\": \"PERF\", \"ph\": \"C\", \"pid\": 1, \"tid\": {}, \"ts\": {}, \"args\": {{ \"{name}\": {usage_bytes} }} }},\n", + get_thread_id(), self.now.elapsed().as_micros() ); let mut writer = self.writer.lock().unwrap(); @@ -80,7 +79,7 @@ impl Recorder { let json = format!( "{{ \"name\": \"{}\", \"cat\": \"PERF\", \"ph\": \"{ph}\", \"pid\": 1, \"tid\": {}, \"ts\": {} }},\n", name, - Self::get_thread_id(), + get_thread_id(), self.now.elapsed().as_micros() ); let mut writer = self.writer.lock().unwrap(); @@ -88,15 +87,15 @@ impl Recorder { .write_all(json.as_bytes()) .expect("Error writing tracing"); } +} - fn get_thread_id() -> u64 { - let thread_id = std::thread::current().id(); - format!("{thread_id:?}") - .trim_start_matches("ThreadId(") - .trim_end_matches(")") - .parse() - .expect("Error parsing thread id") - } +pub fn get_thread_id() -> u64 { + let thread_id = std::thread::current().id(); + format!("{thread_id:?}") + .trim_start_matches("ThreadId(") + .trim_end_matches(")") + .parse() + .expect("Error parsing thread id") } pub fn trace_begin(name: &str) { diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 0a0b46478d..93f75bae96 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -73,6 +73,7 @@ use jni::{ sys::{jboolean, jdouble, jint, jlong}, Env, EnvUnowned, }; +use parking_lot::Mutex; use std::collections::HashMap; use std::path::PathBuf; use std::time::{Duration, Instant}; @@ -87,7 +88,9 @@ use crate::execution::operators::{ScanExec, ShuffleScanExec}; use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; use crate::execution::spark_plan::SparkPlan; -use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; +use crate::execution::tracing::{ + get_thread_id, log_memory_usage, trace_begin, trace_end, with_trace, +}; use crate::execution::memory_pools::logging_pool::LoggingMemoryPool; use crate::execution::spark_config::{ @@ -103,6 +106,53 @@ use tikv_jemalloc_ctl::{epoch, stats}; static TOKIO_RUNTIME: OnceLock = OnceLock::new(); +#[cfg(feature = "jemalloc")] +fn log_jemalloc_usage() { + let e = epoch::mib().unwrap(); + let allocated = stats::allocated::mib().unwrap(); + e.advance().unwrap(); + log_memory_usage("jemalloc_allocated", allocated.read().unwrap() as u64); +} + +/// Registry of active memory pools per Rust thread ID. +/// Used to sum memory reservations across all contexts on the same thread for tracing. +type ThreadPoolMap = HashMap>>; + +static THREAD_MEMORY_POOLS: OnceLock> = OnceLock::new(); + +fn get_thread_memory_pools() -> &'static Mutex { + THREAD_MEMORY_POOLS.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn register_memory_pool(thread_id: u64, context_id: i64, pool: Arc) { + get_thread_memory_pools() + .lock() + .entry(thread_id) + .or_default() + .insert(context_id, pool); +} + +/// Unregister a context's pool and return the remaining total reserved for the thread. +fn unregister_and_total(thread_id: u64, context_id: i64) -> usize { + let mut map = get_thread_memory_pools().lock(); + if let Some(pools) = map.get_mut(&thread_id) { + pools.remove(&context_id); + if pools.is_empty() { + map.remove(&thread_id); + return 0; + } + return pools.values().map(|p| p.reserved()).sum::(); + } + 0 +} + +fn total_reserved_for_thread(thread_id: u64) -> usize { + let map = get_thread_memory_pools().lock(); + map.get(&thread_id) + .map(|pools| pools.values().map(|p| p.reserved()).sum::()) + .unwrap_or(0) +} + fn parse_usize_env_var(name: &str) -> Option { std::env::var_os(name).and_then(|n| n.to_str().and_then(|s| s.parse::().ok())) } @@ -138,6 +188,52 @@ pub fn get_runtime() -> &'static Runtime { TOKIO_RUNTIME.get_or_init(|| build_runtime(None)) } +/// Returns a short name for an OpStruct variant. +fn op_name(op: &OpStruct) -> &'static str { + match op { + OpStruct::Scan(_) => "Scan", + OpStruct::Projection(_) => "Projection", + OpStruct::Filter(_) => "Filter", + OpStruct::Sort(_) => "Sort", + OpStruct::HashAgg(_) => "HashAgg", + OpStruct::Limit(_) => "Limit", + OpStruct::ShuffleWriter(_) => "ShuffleWriter", + OpStruct::Expand(_) => "Expand", + OpStruct::SortMergeJoin(_) => "SortMergeJoin", + OpStruct::HashJoin(_) => "HashJoin", + OpStruct::Window(_) => "Window", + OpStruct::NativeScan(_) => "NativeScan", + OpStruct::IcebergScan(_) => "IcebergScan", + OpStruct::ParquetWriter(_) => "ParquetWriter", + OpStruct::Explode(_) => "Explode", + OpStruct::CsvScan(_) => "CsvScan", + OpStruct::ShuffleScan(_) => "ShuffleScan", + } +} + +/// Collect distinct operator names from a plan tree and build a tracing event name. +fn build_tracing_event_name(plan: &Operator) -> String { + let mut names = std::collections::BTreeSet::new(); + collect_op_names(plan, &mut names); + if names.is_empty() { + "executePlan".to_string() + } else { + format!( + "executePlan({})", + names.into_iter().collect::>().join(",") + ) + } +} + +fn collect_op_names<'a>(op: &'a Operator, names: &mut std::collections::BTreeSet<&'a str>) { + if let Some(ref op_struct) = op.op_struct { + names.insert(op_name(op_struct)); + } + for child in &op.children { + collect_op_names(child, names); + } +} + /// Comet native execution context. Kept alive across JNI calls. struct ExecutionContext { /// The id of the execution context. @@ -180,6 +276,12 @@ struct ExecutionContext { pub memory_pool_config: MemoryPoolConfig, /// Whether to log memory usage on each call to execute_plan pub tracing_enabled: bool, + /// Rust thread ID, used for aggregating tracing metrics per thread + pub rust_thread_id: u64, + /// Pre-computed metric name for tracing memory usage + pub tracing_memory_metric_name: String, + /// Pre-computed tracing event name for executePlan calls + pub tracing_event_name: String, } /// Accept serialized query plan and return the address of the native query plan. @@ -308,6 +410,25 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( ); } + let session = Arc::new(session); + + // Register this context's memory pool so we can sum all pools + // on the same thread when emitting tracing metrics. + let rust_thread_id = get_thread_id(); + if tracing_enabled { + register_memory_pool( + rust_thread_id, + id, + Arc::clone(&session.runtime_env().memory_pool), + ); + } + + let tracing_event_name = if tracing_enabled { + build_tracing_event_name(&spark_plan) + } else { + String::new() + }; + let exec_context = Box::new(ExecutionContext { id, task_attempt_id, @@ -324,11 +445,16 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( metrics_last_update_time: Instant::now(), poll_count_since_metrics_check: 0, plan_creation_time, - session_ctx: Arc::new(session), + session_ctx: session, debug_native, explain_native, memory_pool_config, tracing_enabled, + rust_thread_id, + tracing_memory_metric_name: format!( + "thread_{rust_thread_id}_comet_memory_reserved" + ), + tracing_event_name, }); Ok(Box::into_raw(exec_context) as i64) @@ -527,23 +653,18 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // Retrieve the query let exec_context = get_execution_context(exec_context); - let tracing_event_name = match &exec_context.spark_plan.op_struct { - Some(OpStruct::ShuffleWriter(_)) => "executePlan(ShuffleWriter)", - _ => "executePlan", + let tracing_enabled = exec_context.tracing_enabled; + // Clone the label only when tracing is enabled. The clone is needed + // because the closure below mutably borrows exec_context. + let owned_label; + let tracing_label = if tracing_enabled { + owned_label = exec_context.tracing_event_name.clone(); + owned_label.as_str() + } else { + "" }; - if exec_context.tracing_enabled { - #[cfg(feature = "jemalloc")] - { - let e = epoch::mib().unwrap(); - let allocated = stats::allocated::mib().unwrap(); - e.advance().unwrap(); - use crate::execution::tracing::log_memory_usage; - log_memory_usage("jemalloc_allocated", allocated.read().unwrap() as u64); - } - } - - with_trace(tracing_event_name, exec_context.tracing_enabled, || { + let result = with_trace(tracing_label, tracing_enabled, || { let exec_context_id = exec_context.id; // Initialize the execution stream. @@ -651,16 +772,22 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let next_item = exec_context.stream.as_mut().unwrap().next(); let poll_output = poll!(next_item); - // Only check time every 100 polls to reduce syscall overhead - if let Some(interval) = exec_context.metrics_update_interval { - exec_context.poll_count_since_metrics_check += 1; - if exec_context.poll_count_since_metrics_check >= 100 { + // Only check time/tracing every 100 polls to reduce overhead + exec_context.poll_count_since_metrics_check += 1; + if exec_context.poll_count_since_metrics_check >= 100 { + exec_context.poll_count_since_metrics_check = 0; + if let Some(interval) = exec_context.metrics_update_interval { let now = Instant::now(); if now - exec_context.metrics_last_update_time >= interval { update_metrics(env, exec_context)?; exec_context.metrics_last_update_time = now; } - exec_context.poll_count_since_metrics_check = 0; + } + if exec_context.tracing_enabled { + log_memory_usage( + &exec_context.tracing_memory_metric_name, + total_reserved_for_thread(exec_context.rust_thread_id) as u64, + ); } } @@ -687,7 +814,18 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( } } }) - }) + }); + + if exec_context.tracing_enabled { + #[cfg(feature = "jemalloc")] + log_jemalloc_usage(); + log_memory_usage( + &exec_context.tracing_memory_metric_name, + total_reserved_for_thread(exec_context.rust_thread_id) as u64, + ); + } + + result }) } @@ -709,6 +847,16 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( execution_context.task_attempt_id, ); + // Unregister this context's pool and emit the remaining total for the thread + if execution_context.tracing_enabled { + let remaining = + unregister_and_total(execution_context.rust_thread_id, execution_context.id); + log_memory_usage( + &execution_context.tracing_memory_metric_name, + remaining as u64, + ); + } + let _: Box = Box::from_raw(execution_context); Ok(()) }) @@ -950,6 +1098,16 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_logMemoryUsage( }) } +#[no_mangle] +/// Returns the Rust thread ID for the current thread. +/// This allows Java code to use Rust thread IDs in tracing metric names. +pub extern "system" fn Java_org_apache_comet_Native_getRustThreadId( + _e: EnvUnowned, + _class: JClass, +) -> jlong { + get_thread_id() as jlong +} + // ============================================================================ // Native Columnar to Row Conversion // ============================================================================ diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index 4ac4fc287b..8502c79624 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -38,7 +38,6 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }, }; -use datafusion_comet_common::tracing::with_trace_async; use futures::{StreamExt, TryFutureExt, TryStreamExt}; use std::{ any::Any, @@ -207,66 +206,61 @@ async fn external_shuffle( tracing_enabled: bool, write_buffer_size: usize, ) -> Result { - with_trace_async("external_shuffle", tracing_enabled, || async { - let schema = input.schema(); - - let mut repartitioner: Box = match &partitioning { - _ if schema.fields().is_empty() => { - log::debug!("found empty schema, overriding {partitioning:?} partitioning with EmptySchemaShufflePartitioner"); - Box::new(EmptySchemaShufflePartitioner::try_new( - output_data_file, - output_index_file, - Arc::clone(&schema), - partitioning.partition_count(), - metrics, - codec, - )?) - } - any if any.partition_count() == 1 => { - Box::new(SinglePartitionShufflePartitioner::try_new( - output_data_file, - output_index_file, - Arc::clone(&schema), - metrics, - context.session_config().batch_size(), - codec, - write_buffer_size, - )?) - } - _ => Box::new(MultiPartitionShuffleRepartitioner::try_new( - partition, + let schema = input.schema(); + + let mut repartitioner: Box = match &partitioning { + _ if schema.fields().is_empty() => { + log::debug!("found empty schema, overriding {partitioning:?} partitioning with EmptySchemaShufflePartitioner"); + Box::new(EmptySchemaShufflePartitioner::try_new( output_data_file, output_index_file, Arc::clone(&schema), - partitioning, + partitioning.partition_count(), metrics, - context.runtime_env(), - context.session_config().batch_size(), codec, - tracing_enabled, - write_buffer_size, - )?), - }; - - while let Some(batch) = input.next().await { - // Await the repartitioner to insert the batch and shuffle the rows - // into the corresponding partition buffer. - // Otherwise, pull the next batch from the input stream might overwrite the - // current batch in the repartitioner. - repartitioner - .insert_batch(batch?) - .await - .map_err(|err| exec_datafusion_err!("Error inserting batch: {err}"))?; + )?) } - + any if any.partition_count() == 1 => Box::new(SinglePartitionShufflePartitioner::try_new( + output_data_file, + output_index_file, + Arc::clone(&schema), + metrics, + context.session_config().batch_size(), + codec, + write_buffer_size, + )?), + _ => Box::new(MultiPartitionShuffleRepartitioner::try_new( + partition, + output_data_file, + output_index_file, + Arc::clone(&schema), + partitioning, + metrics, + context.runtime_env(), + context.session_config().batch_size(), + codec, + tracing_enabled, + write_buffer_size, + )?), + }; + + while let Some(batch) = input.next().await { + // Await the repartitioner to insert the batch and shuffle the rows + // into the corresponding partition buffer. + // Otherwise, pull the next batch from the input stream might overwrite the + // current batch in the repartitioner. repartitioner - .shuffle_write() - .map_err(|err| exec_datafusion_err!("Error in shuffle write: {err}"))?; + .insert_batch(batch?) + .await + .map_err(|err| exec_datafusion_err!("Error inserting batch: {err}"))?; + } + + repartitioner + .shuffle_write() + .map_err(|err| exec_datafusion_err!("Error in shuffle write: {err}"))?; - // shuffle writer always has empty output - Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone(&schema))) as SendableRecordBatchStream) - }) - .await + // shuffle writer always has empty output + Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone(&schema))) as SendableRecordBatchStream) } #[cfg(test)] diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java index a58ec7851b..70721366c7 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java @@ -233,8 +233,9 @@ public void write(Iterator> records) throws IOException { } Native _native = new Native(); + String shuffleMemKey = "thread_" + _native.getRustThreadId() + "_comet_jvm_shuffle"; if (tracingEnabled) { - _native.logMemoryUsage("comet_shuffle_", allocator.getUsed()); + _native.logMemoryUsage(shuffleMemKey, allocator.getUsed()); } long spillRecords = 0; @@ -247,7 +248,7 @@ public void write(Iterator> records) throws IOException { } if (tracingEnabled) { - _native.logMemoryUsage("comet_shuffle_", allocator.getUsed()); + _native.logMemoryUsage(shuffleMemKey, allocator.getUsed()); } if (outputRows != spillRecords) { diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java index 736c42aafa..8930a52884 100644 --- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java +++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java @@ -212,9 +212,9 @@ public void write(scala.collection.Iterator> records) throws IOEx // generic throwables. boolean success = false; if (tracingEnabled) { - nativeLib.traceBegin("CometUnsafeShuffleWriter"); + nativeLib.traceBegin("comet_unsafe_shuffle_writer"); } - String offheapMemKey = "comet_shuffle_" + Thread.currentThread().getId(); + String offheapMemKey = "thread_" + nativeLib.getRustThreadId() + "_comet_jvm_shuffle"; try { while (records.hasNext()) { insertRecordIntoSorter(records.next()); @@ -226,7 +226,7 @@ public void write(scala.collection.Iterator> records) throws IOEx success = true; } finally { if (tracingEnabled) { - nativeLib.traceEnd("CometUnsafeShuffleWriter"); + nativeLib.traceEnd("comet_unsafe_shuffle_writer"); } if (sorter != null) { try { diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index e198ac99ff..f0c6373149 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -138,14 +138,10 @@ class CometExecIterator( private def getNextBatch: Option[ColumnarBatch] = { assert(partitionIndex >= 0 && partitionIndex < numParts) - if (tracingEnabled) { - traceMemoryUsage() - } - val ctx = TaskContext.get() try { - withTrace( + val result = withTrace( s"getNextBatch[JVM] stage=${ctx.stageId()}", tracingEnabled, { nativeUtil.getNextBatch( @@ -154,6 +150,12 @@ class CometExecIterator( nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs) }) }) + + if (tracingEnabled) { + traceMemoryUsage() + } + + result } catch { // Handle CometQueryExecutionException with JSON payload first case e: CometQueryExecutionException => @@ -252,13 +254,7 @@ class CometExecIterator( } private def traceMemoryUsage(): Unit = { - nativeLib.logMemoryUsage("jvm_heapUsed", memoryMXBean.getHeapMemoryUsage.getUsed) - val totalTaskMemory = cometTaskMemoryManager.internal.getMemoryConsumptionForThisTask - val cometTaskMemory = cometTaskMemoryManager.getUsed - val sparkTaskMemory = totalTaskMemory - cometTaskMemory - val threadId = Thread.currentThread().getId - nativeLib.logMemoryUsage(s"task_memory_comet_$threadId", cometTaskMemory) - nativeLib.logMemoryUsage(s"task_memory_spark_$threadId", sparkTaskMemory) + nativeLib.logMemoryUsage("jvm_heap_used", memoryMXBean.getHeapMemoryUsage.getUsed) } } diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index f6800626d6..c003bcd138 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -203,6 +203,11 @@ class Native extends NativeBase { */ @native def logMemoryUsage(name: String, memoryUsageBytes: Long): Unit + /** + * Returns the Rust thread ID for the current thread. + */ + @native def getRustThreadId(): Long + // Native Columnar to Row conversion methods /** diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index d704d3fd88..f27d021ac4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -307,6 +307,8 @@ class CometNativeShuffleWriter[K, V]( s"Partitioning $outputPartitioning is not supported.") } + shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) + val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() shuffleWriterOpBuilder .setShuffleWriter(shuffleWriterBuilder)