Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c2deee7
feat: allow custom caching via logical node
jizezhang Nov 14, 2025
e83038a
formatting
jizezhang Nov 14, 2025
d9cd725
fix import
jizezhang Nov 14, 2025
b17e5bb
remove unused test imports
jizezhang Nov 14, 2025
8ae150f
minor: Use allow->expect to explicitly suppress Clippy lint checks (#…
2010YOUY01 Nov 14, 2025
13a8a3b
chore(deps): bump taiki-e/install-action from 2.62.50 to 2.62.51 (#18…
dependabot[bot] Nov 14, 2025
bc88780
chore(deps): bump crate-ci/typos from 1.39.1 to 1.39.2 (#18694)
dependabot[bot] Nov 14, 2025
ad7bea8
Remove FilterExec from CoalesceBatches optimization rule, add fetch s…
Dandandan Nov 14, 2025
0af2832
minor: refactor with `assert_or_internal_err!()` in `datafusion/datas…
kumarUjjawal Nov 14, 2025
244456f
chore: Enforce lint rule `clippy::needless_pass_by_value` to datafusi…
AryanBagade Nov 14, 2025
13adf2a
feat: Handle edge case with `corr` with single row and `NaN` (#18677)
comphead Nov 14, 2025
ac86c20
[main] Update changelog for 51.0.0 RC2 (#18710)
alamb Nov 14, 2025
8ce7aca
Refactor Spark crc32/sha1 signatures (#18662)
Jefffrey Nov 15, 2025
aaabd1f
feat: support spark csc (#18642)
psvri Nov 15, 2025
d61021c
CI: try free up space in `Rust / cargo test (amd64)` action (#18709)
Jefffrey Nov 15, 2025
a1c04c9
chore: enforce clippy lint needless_pass_by_value to datafusion-proto…
foskey51 Nov 15, 2025
1921919
chore: enforce clippy lint needless_pass_by_value to datafusion-spark…
foskey51 Nov 15, 2025
55b8055
pass session state to cache producer
jizezhang Nov 15, 2025
376453b
Merge branch 'main' into cache-udn
jizezhang Nov 15, 2025
f95bc67
fix imports in doc
jizezhang Nov 15, 2025
da949d5
Merge branch 'main' into cache-udn
jizezhang Nov 18, 2025
cb0b43c
Merge branch 'main' into cache-udn
jizezhang Nov 18, 2025
213c6f6
refactor
jizezhang Nov 18, 2025
1c67bec
update doc and comment
jizezhang Nov 18, 2025
5d1c999
fix test name
jizezhang Nov 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,8 @@ impl DataFrame {
})
}

/// Cache DataFrame as a memory table.
/// Cache DataFrame as a memory table by default, or use
/// a [`crate::execution::session_state::CacheFactory`] if configured via [`SessionState`].
///
/// ```
/// # use datafusion::prelude::*;
Expand All @@ -2344,14 +2345,23 @@ impl DataFrame {
/// # }
/// ```
pub async fn cache(self) -> Result<DataFrame> {
    if let Some(cache_factory) = self.session_state.cache_factory() {
        // A custom caching strategy is registered: let the factory rewrite
        // the logical plan (e.g. wrap it in a user-defined extension node).
        // Execution of the cached plan is then up to the user's planner.
        let new_plan = cache_factory(self.plan, self.session_state.as_ref())?;
        Ok(Self {
            session_state: self.session_state,
            plan: new_plan,
            projection_requires_validation: self.projection_requires_validation,
        })
    } else {
        // Default strategy: eagerly execute the plan and materialize all
        // partitions into an in-memory table.
        let context = SessionContext::new_with_state((*self.session_state).clone());
        // The schema is consistent with the output
        let plan = self.clone().create_physical_plan().await?;
        let schema = plan.schema();
        let task_ctx = Arc::new(self.task_ctx());
        let partitions = collect_partitioned(plan, task_ctx).await?;
        let mem_table = MemTable::try_new(schema, partitions)?;
        context.read_table(Arc::new(mem_table))
    }
}

/// Apply an alias to the DataFrame.
Expand Down
37 changes: 36 additions & 1 deletion datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ use sqlparser::{
use url::Url;
use uuid::Uuid;

/// A [`CacheFactory`] can be registered via [`SessionState`]
/// to create a custom logical plan for caching.
///
/// When set, `DataFrame::cache` delegates to this function instead of
/// materializing results into a `MemTable`. The function receives the
/// DataFrame's logical plan and the current session state, and returns the
/// logical plan to use as the "cached" plan.
///
/// Additionally, a custom [`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`]
/// may need to be implemented to handle such plans.
pub type CacheFactory =
    fn(LogicalPlan, &SessionState) -> datafusion_common::Result<LogicalPlan>;

/// `SessionState` contains all the necessary state to plan and execute queries,
/// such as configuration, functions, and runtime environment. Please see the
/// documentation on [`SessionContext`] for more information.
Expand Down Expand Up @@ -185,6 +192,7 @@ pub struct SessionState {
/// It will be invoked on `CREATE FUNCTION` statements.
/// thus, changing dialect to PostgreSQL is required
function_factory: Option<Arc<dyn FunctionFactory>>,
cache_factory: Option<CacheFactory>,
/// Cache logical plans of prepared statements for later execution.
/// Key is the prepared statement name.
prepared_plans: HashMap<String, Arc<PreparedPlan>>,
Expand All @@ -206,6 +214,7 @@ impl Debug for SessionState {
.field("table_options", &self.table_options)
.field("table_factories", &self.table_factories)
.field("function_factory", &self.function_factory)
.field("cache_factory", &self.cache_factory)
.field("expr_planners", &self.expr_planners);

#[cfg(feature = "sql")]
Expand Down Expand Up @@ -355,6 +364,16 @@ impl SessionState {
self.function_factory.as_ref()
}

/// Register a [`CacheFactory`] for custom caching strategy.
///
/// Replaces any previously registered factory. See [`CacheFactory`] for
/// how the factory is used by `DataFrame::cache`.
pub fn set_cache_factory(&mut self, cache_factory: CacheFactory) {
    self.cache_factory = Some(cache_factory);
}

/// Get the registered [`CacheFactory`], if any.
///
/// Returns `None` when no custom caching strategy has been configured,
/// in which case `DataFrame::cache` falls back to a `MemTable`.
pub fn cache_factory(&self) -> Option<&CacheFactory> {
    self.cache_factory.as_ref()
}

/// Get the table factories
pub fn table_factories(&self) -> &HashMap<String, Arc<dyn TableProviderFactory>> {
&self.table_factories
Expand Down Expand Up @@ -941,6 +960,7 @@ pub struct SessionStateBuilder {
table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
runtime_env: Option<Arc<RuntimeEnv>>,
function_factory: Option<Arc<dyn FunctionFactory>>,
cache_factory: Option<CacheFactory>,
// fields to support convenience functions
analyzer_rules: Option<Vec<Arc<dyn AnalyzerRule + Send + Sync>>>,
optimizer_rules: Option<Vec<Arc<dyn OptimizerRule + Send + Sync>>>,
Expand Down Expand Up @@ -978,6 +998,7 @@ impl SessionStateBuilder {
table_factories: None,
runtime_env: None,
function_factory: None,
cache_factory: None,
// fields to support convenience functions
analyzer_rules: None,
optimizer_rules: None,
Expand Down Expand Up @@ -1030,7 +1051,7 @@ impl SessionStateBuilder {
table_factories: Some(existing.table_factories),
runtime_env: Some(existing.runtime_env),
function_factory: existing.function_factory,

cache_factory: existing.cache_factory,
// fields to support convenience functions
analyzer_rules: None,
optimizer_rules: None,
Expand Down Expand Up @@ -1319,6 +1340,12 @@ impl SessionStateBuilder {
self
}

/// Set a [`CacheFactory`] for custom caching strategy.
///
/// Passing `None` clears any factory carried over from an existing state,
/// restoring the default `MemTable`-based caching behavior.
pub fn with_cache_factory(mut self, cache_factory: Option<CacheFactory>) -> Self {
    self.cache_factory = cache_factory;
    self
}

/// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`]
/// for more details.
///
Expand Down Expand Up @@ -1382,6 +1409,7 @@ impl SessionStateBuilder {
table_factories,
runtime_env,
function_factory,
cache_factory,
analyzer_rules,
optimizer_rules,
physical_optimizer_rules,
Expand Down Expand Up @@ -1418,6 +1446,7 @@ impl SessionStateBuilder {
table_factories: table_factories.unwrap_or_default(),
runtime_env,
function_factory,
cache_factory,
prepared_plans: HashMap::new(),
};

Expand Down Expand Up @@ -1621,6 +1650,11 @@ impl SessionStateBuilder {
&mut self.function_factory
}

/// Returns a mutable reference to the cache factory slot,
/// allowing it to be inspected or replaced in place.
pub fn cache_factory(&mut self) -> &mut Option<CacheFactory> {
    &mut self.cache_factory
}

/// Returns the current analyzer_rules value
pub fn analyzer_rules(
&mut self,
Expand Down Expand Up @@ -1659,6 +1693,7 @@ impl Debug for SessionStateBuilder {
.field("table_options", &self.table_options)
.field("table_factories", &self.table_factories)
.field("function_factory", &self.function_factory)
.field("cache_factory", &self.cache_factory)
.field("expr_planners", &self.expr_planners);
#[cfg(feature = "sql")]
let ret = ret.field("type_planner", &self.type_planner);
Expand Down
68 changes: 65 additions & 3 deletions datafusion/core/src/test_util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub mod csv;
use futures::Stream;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Formatter;
use std::fs::File;
use std::io::Write;
use std::path::Path;
Expand All @@ -40,12 +41,15 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
use crate::physical_plan::ExecutionPlan;
use crate::prelude::{CsvReadOptions, SessionContext};

use crate::execution::SendableRecordBatchStream;
use crate::execution::{SendableRecordBatchStream, SessionState, SessionStateBuilder};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_catalog::Session;
use datafusion_common::TableReference;
use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
use datafusion_common::{DFSchemaRef, TableReference};
use datafusion_expr::{
CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType,
UserDefinedLogicalNodeCore,
};
use std::pin::Pin;

use async_trait::async_trait;
Expand Down Expand Up @@ -282,3 +286,61 @@ impl RecordBatchStream for BoundedStream {
self.record_batch.schema()
}
}

/// Test-only user-defined logical node that marks its input as "cached".
///
/// Used by `test_table_with_cache_factory` to verify that a registered
/// `CacheFactory` is invoked by `DataFrame::cache`. The derives satisfy the
/// `UserDefinedLogicalNodeCore` trait bounds (Hash/Eq/PartialOrd/Debug).
#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)]
struct CacheNode {
    // The plan whose results would be cached.
    input: LogicalPlan,
}

impl UserDefinedLogicalNodeCore for CacheNode {
    /// Name displayed in plan output (e.g. EXPLAIN).
    fn name(&self) -> &str {
        "CacheNode"
    }

    /// The single child plan being cached.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![&self.input]
    }

    /// Caching is transparent: the output schema is the input's schema.
    fn schema(&self) -> &DFSchemaRef {
        self.input.schema()
    }

    /// This node introduces no expressions of its own.
    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        write!(f, "CacheNode")
    }

    /// Rebuild the node with new inputs; `_exprs` is ignored because this
    /// node declares no expressions.
    ///
    /// Panics if not given exactly one input — that would indicate a
    /// planner bug, since `inputs()` reports a single child.
    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        mut inputs: Vec<LogicalPlan>,
    ) -> Result<Self> {
        assert_eq!(inputs.len(), 1, "input size inconsistent");
        // Move the sole input out of the owned Vec instead of deep-cloning
        // the whole logical plan tree.
        Ok(Self {
            input: inputs.swap_remove(0),
        })
    }
}

/// Test `CacheFactory` implementation: wraps the plan in a [`CacheNode`]
/// extension instead of materializing it.
fn cache_factory(
    plan: LogicalPlan,
    _session_state: &SessionState,
) -> Result<LogicalPlan> {
    let node = Arc::new(CacheNode { input: plan });
    Ok(LogicalPlan::Extension(datafusion_expr::Extension { node }))
}

/// Create a test table registered to a session context with an associated cache factory
pub async fn test_table_with_cache_factory() -> Result<DataFrame> {
    // Build a session whose state carries the custom cache factory.
    let state = SessionStateBuilder::new()
        .with_cache_factory(Some(cache_factory))
        .build();
    let ctx = SessionContext::new_with_state(state);

    // Register the standard aggregate CSV fixture and return it as a DataFrame.
    let table_name = "aggregate_test_100";
    register_aggregate_csv(&ctx, table_name).await?;
    ctx.table(table_name).await
}
25 changes: 24 additions & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ use datafusion::prelude::{
};
use datafusion::test_util::{
parquet_test_data, populate_csv_partitions, register_aggregate_csv, test_table,
test_table_with_name,
test_table_with_cache_factory, test_table_with_name,
};
use datafusion_catalog::TableProvider;
use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
Expand Down Expand Up @@ -2335,6 +2335,29 @@ async fn cache_test() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn cache_producer_test() -> Result<()> {
let df = test_table_with_cache_factory()
.await?
.select_columns(&["c2", "c3"])?
.limit(0, Some(1))?
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;

let cached_df = df.clone().cache().await?;

assert_snapshot!(
cached_df.clone().into_optimized_plan().unwrap(),
@r###"
CacheNode
Projection: aggregate_test_100.c2, aggregate_test_100.c3, CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS Int64) AS Int64) AS sum
Projection: aggregate_test_100.c2, aggregate_test_100.c3
Limit: skip=0, fetch=1
TableScan: aggregate_test_100, fetch=1
"###
);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
);
);
let df_results = df.collect().await?;
let cached_df_results = cached_df.collect().await?;
assert_eq!(&df_results, &cached_df_results);

to test the physical plan too
Does it need a custom ExtensionPlanner too for that ?!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think we need to cover physical plan for this case, as it would be up to user to provide it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be used as an example how to do it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we can do it as a follow up in examples

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, once this PR is approved and merged, I can work on writing up an example.

Ok(())
}

#[tokio::test]
async fn partition_aware_union() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
Expand Down