diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 08ca1a176e57..6656f333df55 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2328,7 +2328,8 @@ impl DataFrame { }) } - /// Cache DataFrame as a memory table. + /// Cache DataFrame as a memory table by default, or use + /// a [`crate::execution::session_state::CacheFactory`] if configured via [`SessionState`]. /// /// ``` /// # use datafusion::prelude::*; @@ -2344,14 +2345,23 @@ impl DataFrame { /// # } /// ``` pub async fn cache(self) -> Result { - let context = SessionContext::new_with_state((*self.session_state).clone()); - // The schema is consistent with the output - let plan = self.clone().create_physical_plan().await?; - let schema = plan.schema(); - let task_ctx = Arc::new(self.task_ctx()); - let partitions = collect_partitioned(plan, task_ctx).await?; - let mem_table = MemTable::try_new(schema, partitions)?; - context.read_table(Arc::new(mem_table)) + if let Some(cache_factory) = self.session_state.cache_factory() { + let new_plan = cache_factory(self.plan, self.session_state.as_ref())?; + Ok(Self { + session_state: self.session_state, + plan: new_plan, + projection_requires_validation: self.projection_requires_validation, + }) + } else { + let context = SessionContext::new_with_state((*self.session_state).clone()); + // The schema is consistent with the output + let plan = self.clone().create_physical_plan().await?; + let schema = plan.schema(); + let task_ctx = Arc::new(self.task_ctx()); + let partitions = collect_partitioned(plan, task_ctx).await?; + let mem_table = MemTable::try_new(schema, partitions)?; + context.read_table(Arc::new(mem_table)) + } } /// Apply an alias to the DataFrame. diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index d7a66db28ac4..6e300f80582d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -90,6 +90,13 @@ use sqlparser::{ use url::Url; use uuid::Uuid; +/// A [`CacheFactory`] can be registered via [`SessionState`] +/// to create a custom logical plan for caching. +/// Additionally, a custom [`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`] +/// may need to be implemented to handle such plans. +pub type CacheFactory = + fn(LogicalPlan, &SessionState) -> datafusion_common::Result; + /// `SessionState` contains all the necessary state to plan and execute queries, /// such as configuration, functions, and runtime environment. Please see the /// documentation on [`SessionContext`] for more information. @@ -185,6 +192,7 @@ pub struct SessionState { /// It will be invoked on `CREATE FUNCTION` statements. /// thus, changing dialect o PostgreSql is required function_factory: Option>, + cache_factory: Option, /// Cache logical plans of prepared statements for later execution. /// Key is the prepared statement name. prepared_plans: HashMap>, @@ -206,6 +214,7 @@ impl Debug for SessionState { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) + .field("cache_factory", &self.cache_factory) .field("expr_planners", &self.expr_planners); #[cfg(feature = "sql")] @@ -355,6 +364,16 @@ impl SessionState { self.function_factory.as_ref() } + /// Register a [`CacheFactory`] for custom caching strategy + pub fn set_cache_factory(&mut self, cache_factory: CacheFactory) { + self.cache_factory = Some(cache_factory); + } + + /// Get the cache factory + pub fn cache_factory(&self) -> Option<&CacheFactory> { + self.cache_factory.as_ref() + } + /// Get the table factories pub fn table_factories(&self) -> &HashMap> { &self.table_factories @@ -941,6 +960,7 @@ pub struct SessionStateBuilder { table_factories: Option>>, runtime_env: Option>, function_factory: Option>, + cache_factory: Option, // fields to support convenience functions analyzer_rules: Option>>, optimizer_rules: Option>>, @@ -978,6 +998,7 @@ impl SessionStateBuilder { table_factories: None, runtime_env: None, function_factory: None, + cache_factory: None, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -1030,7 +1051,7 @@ impl SessionStateBuilder { table_factories: Some(existing.table_factories), runtime_env: Some(existing.runtime_env), function_factory: existing.function_factory, - + cache_factory: existing.cache_factory, // fields to support convenience functions analyzer_rules: None, optimizer_rules: None, @@ -1319,6 +1340,12 @@ impl SessionStateBuilder { self } + /// Set a [`CacheFactory`] for custom caching strategy + pub fn with_cache_factory(mut self, cache_factory: Option) -> Self { + self.cache_factory = cache_factory; + self + } + /// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`] /// for more details. /// @@ -1382,6 +1409,7 @@ impl SessionStateBuilder { table_factories, runtime_env, function_factory, + cache_factory, analyzer_rules, optimizer_rules, physical_optimizer_rules, @@ -1418,6 +1446,7 @@ impl SessionStateBuilder { table_factories: table_factories.unwrap_or_default(), runtime_env, function_factory, + cache_factory, prepared_plans: HashMap::new(), }; @@ -1621,6 +1650,11 @@ impl SessionStateBuilder { &mut self.function_factory } + /// Returns the cache factory + pub fn cache_factory(&mut self) -> &mut Option { + &mut self.cache_factory + } + /// Returns the current analyzer_rules value pub fn analyzer_rules( &mut self, @@ -1659,6 +1693,7 @@ impl Debug for SessionStateBuilder { .field("table_options", &self.table_options) .field("table_factories", &self.table_factories) .field("function_factory", &self.function_factory) + .field("cache_factory", &self.cache_factory) .field("expr_planners", &self.expr_planners); #[cfg(feature = "sql")] let ret = ret.field("type_planner", &self.type_planner); diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 7149c5b0bd8c..eed5897e635a 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -25,6 +25,7 @@ pub mod csv; use futures::Stream; use std::any::Any; use std::collections::HashMap; +use std::fmt::Formatter; use std::fs::File; use std::io::Write; use std::path::Path; @@ -40,12 +41,15 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; -use crate::execution::SendableRecordBatchStream; +use crate::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; -use datafusion_common::TableReference; -use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use datafusion_common::{DFSchemaRef, TableReference}; +use datafusion_expr::{ + CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType, + UserDefinedLogicalNodeCore, +}; use std::pin::Pin; use async_trait::async_trait; @@ -282,3 +286,61 @@ impl RecordBatchStream for BoundedStream { self.record_batch.schema() } } + +#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)] +struct CacheNode { + input: LogicalPlan, +} + +impl UserDefinedLogicalNodeCore for CacheNode { + fn name(&self) -> &str { + "CacheNode" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CacheNode") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + assert_eq!(inputs.len(), 1, "input size inconsistent"); + Ok(Self { + input: inputs[0].clone(), + }) + } +} + +fn cache_factory( + plan: LogicalPlan, + _session_state: &SessionState, +) -> Result { + Ok(LogicalPlan::Extension(datafusion_expr::Extension { + node: Arc::new(CacheNode { input: plan }), + })) +} + +/// Create a test table registered to a session context with an associated cache factory +pub async fn test_table_with_cache_factory() -> Result { + let session_state = SessionStateBuilder::new() + .with_cache_factory(Some(cache_factory)) + .build(); + let ctx = SessionContext::new_with_state(session_state); + let name = "aggregate_test_100"; + register_aggregate_csv(&ctx, name).await?; + ctx.table(name).await +} diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index e2f3ece1e4ca..ec163c32f267 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -61,7 +61,7 @@ use datafusion::prelude::{ }; use datafusion::test_util::{ parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table, - test_table_with_name, + test_table_with_cache_factory, test_table_with_name, }; use datafusion_catalog::TableProvider; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; @@ -2335,6 +2335,29 @@ async fn cache_test() -> Result<()> { Ok(()) } +#[tokio::test] +async fn cache_producer_test() -> Result<()> { + let df = test_table_with_cache_factory() + .await? + .select_columns(&["c2", "c3"])? + .limit(0, Some(1))? + .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + + let cached_df = df.clone().cache().await?; + + assert_snapshot!( + cached_df.clone().into_optimized_plan().unwrap(), + @r###" + CacheNode + Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum + Projection: aggregate_test_100.c2, aggregate_test_100.c3 + Limit: skip=0, fetch=1 + TableScan: aggregate_test_100, fetch=1 + "### + ); + Ok(()) +} + #[tokio::test] async fn partition_aware_union() -> Result<()> { let left = test_table().await?.select_columns(&["c1", "c2"])?;