diff --git a/Cargo.lock b/Cargo.lock index 26d24faa6c1b..6e52b6b4167c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -712,6 +712,15 @@ dependencies = [ "anyhow", ] +[[package]] +name = "datadriven" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d7337e53bfa2d8e547c582c1141882bc3b500904bab157fca7b4fa6b6a5ff5a" +dependencies = [ + "anyhow", +] + [[package]] name = "dataflow" version = "0.1.0" @@ -1852,7 +1861,7 @@ dependencies = [ "comm", "compile-time-run", "coord", - "datadriven", + "datadriven 0.3.0", "dataflow", "dataflow-types", "fallible-iterator", @@ -2563,7 +2572,7 @@ version = "0.1.0" dependencies = [ "anyhow", "bytes", - "datadriven", + "datadriven 0.3.0", "fallible-iterator", "getopts", "postgres", @@ -3678,7 +3687,7 @@ name = "sql-parser" version = "0.1.0" dependencies = [ "anyhow", - "datadriven", + "datadriven 0.3.0", "lazy_static", "log", "matches", @@ -4404,7 +4413,7 @@ name = "transform" version = "0.1.0" dependencies = [ "anyhow", - "datadriven", + "datadriven 0.4.0", "dataflow-types", "expr", "itertools", @@ -4540,7 +4549,7 @@ name = "walkabout" version = "0.1.0" dependencies = [ "anyhow", - "datadriven", + "datadriven 0.3.0", "fstrings", "ore", "quote", diff --git a/src/expr/src/linear.rs b/src/expr/src/linear.rs index 044686dfa784..d45a8f480c5d 100644 --- a/src/expr/src/linear.rs +++ b/src/expr/src/linear.rs @@ -146,6 +146,11 @@ impl MapFilterProject { self } + pub fn apply(self, other: &MapFilterProject) -> MapFilterProject { + let (m, f, p) = other.as_map_filter_project(); + self.map(m).filter(f).project(p) + } + /// As the arguments to `Map`, `Filter`, and `Project` operators. /// /// In principle, this operator can be implemented as a sequence of @@ -161,6 +166,10 @@ impl MapFilterProject { (map, filter, project) } + pub fn arity(&self) -> usize { + self.projection.len() + } + /// Optimize the internal expression evaluation order. pub fn optimize(&mut self) { // This should probably resemble existing scalar cse. diff --git a/src/expr/src/relation/mod.rs b/src/expr/src/relation/mod.rs index ebb4ff301097..e0843e0256ca 100644 --- a/src/expr/src/relation/mod.rs +++ b/src/expr/src/relation/mod.rs @@ -507,17 +507,25 @@ impl RelationExpr { /// Retains only the columns specified by `output`. pub fn project(self, outputs: Vec) -> Self { - RelationExpr::Project { - input: Box::new(self), - outputs, + if outputs.len() == self.arity() && outputs.iter().enumerate().all(|(i, j)| i == *j) { + self + } else { + RelationExpr::Project { + input: Box::new(self), + outputs, + } } } /// Append to each row the results of applying elements of `scalar`. pub fn map(self, scalars: Vec) -> Self { - RelationExpr::Map { - input: Box::new(self), - scalars, + if scalars.is_empty() { + self + } else { + RelationExpr::Map { + input: Box::new(self), + scalars, + } } } @@ -536,9 +544,14 @@ impl RelationExpr { where I: IntoIterator, { - RelationExpr::Filter { - input: Box::new(self), - predicates: predicates.into_iter().collect(), + let preds: Vec = predicates.into_iter().collect(); + if preds.is_empty() { + self + } else { + RelationExpr::Filter { + input: Box::new(self), + predicates: preds, + } } } diff --git a/src/transform/Cargo.toml b/src/transform/Cargo.toml index cf878ad1ba35..323f8d9a0098 100644 --- a/src/transform/Cargo.toml +++ b/src/transform/Cargo.toml @@ -12,5 +12,5 @@ itertools = "0.9" repr = { path = "../repr" } [dev-dependencies] -datadriven = "0.3.0" +datadriven = "0.4.0" anyhow = "1.0.33" diff --git a/src/transform/src/lib.rs b/src/transform/src/lib.rs index c29995777b1d..6eba2f73c031 100644 --- a/src/transform/src/lib.rs +++ b/src/transform/src/lib.rs @@ -36,6 +36,7 @@ pub mod inline_let; pub mod join_elision; pub mod join_implementation; pub mod map_lifting; +pub mod mfp_pushdown; pub mod nonnull_requirements; pub mod nonnullable; pub mod predicate_pushdown; diff --git a/src/transform/src/mfp_pushdown.rs b/src/transform/src/mfp_pushdown.rs new file mode 100644 index 000000000000..62c34a5c84aa --- /dev/null +++ b/src/transform/src/mfp_pushdown.rs @@ -0,0 +1,272 @@ +// Copyright Materialize, Inc. All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! Push down MFP into various operators. TODO. + +use std::collections::{HashMap, HashSet}; + +use expr::{MapFilterProject, RelationExpr, ScalarExpr}; + +use crate::TransformArgs; + +/// Push down MFP TODO +#[derive(Debug)] +pub struct MFPPushdown; + +impl crate::Transform for MFPPushdown { + fn transform( + &self, + relation: &mut RelationExpr, + _: TransformArgs, + ) -> Result<(), crate::TransformError> { + self.action( + relation, + MapFilterProject::new(relation.typ().column_types.len()), + ); + Ok(()) + } +} + +impl MFPPushdown { + /// ... + pub fn action(&self, relation: &mut RelationExpr, mfp: MapFilterProject) { + match relation { + RelationExpr::Map { input, scalars } => { + let m = MapFilterProject::new(input.arity()).map(scalars.clone()); + self.action(input, m.apply(&mfp)); + *relation = input.take_dangerous(); + } + RelationExpr::Filter { input, predicates } => { + let f = MapFilterProject::new(input.arity()).filter(predicates.clone()); + self.action(input, f.apply(&mfp)); + *relation = input.take_dangerous(); + } + RelationExpr::Project { input, outputs } => { + let p = MapFilterProject::new(input.arity()).project(outputs.clone()); + self.action(input, p.apply(&mfp)); + *relation = input.take_dangerous(); + } + RelationExpr::Union { base, inputs } => { + for inp in inputs { + self.action(inp, mfp.clone()); + } + self.action(base, mfp); + } + RelationExpr::FlatMap { + input, + func, + exprs, + demand: _, + } => { + // TODO(justin): we might want to absorb FlatMap into MFP, since it's also linear. + let input_arity = input.arity(); + let added_cols = func.output_type().arity(); + let flatmap_arity = input_arity + added_cols; + + let (m, f, mut p) = mfp.as_map_filter_project(); + + // Columns have several different names at different points in this process. + // The final structure will look like this: + // + // Input + // 1. v + // Pushed Map + // 2. v + // Pushed Filter + // v + // Pushed Project + // 3. v + // FlatMap + // 4. v + // Residual Map + // 5. v + // Residual Filter + // v + // Residual Project + + // First, figure out which Map expressions can be pushed down. + + // pushed_exprs contains all the map expressions we were able to move beneath the + // FlatMap. + let mut pushed_exprs = vec![]; + // residual_exprs contains all the map expressions we were not able to move beneath + // the FlatMap. + let mut residual_exprs = vec![]; + // pushdown_map maps the names of columns in the original expression to where to + // find them beneath the FlatMap. This is used both for pushed down map expressions + // and pushed down filter expressions. + let mut pushdown_map = HashMap::new(); + // col_map maps the original names of the columns to where to find them at the + // current stage of the stack above. First, immediately after Input, all the input + // columns are the same place they were originally. + let mut col_map = (0..input.arity()) + .map(|c| (c, c)) + .collect::>(); + // col_map is currently at stage 1 in the diagram above. + + for (idx, map_expr) in m.iter().enumerate() { + // We can push down any expressions bound by input, as well as the expressions + // we have already pushed down. + if map_expr.support().iter().all(|c| col_map.contains_key(c)) { + pushdown_map.insert(flatmap_arity + idx, input_arity + pushed_exprs.len()); + // Update col_map to reflect the new position of this column. + col_map.insert(flatmap_arity + idx, input_arity + pushed_exprs.len()); + + // This expression might refer to previous map expressions, so we have to + // make sure we give those references their new names. + let mut remapped_expr = map_expr.clone(); + remapped_expr.visit_mut(&mut |e| { + if let ScalarExpr::Column(c) = e { + *c = *col_map.get(c).unwrap(); + } + }); + + pushed_exprs.push(remapped_expr); + } else { + // Some of these names might be wrong now, if they referenced other map + // expressions, but those will still undergo more transformation (the + // pushed projection and FlatMap) so we will wait to remap them until + // later. + residual_exprs.push(map_expr.clone()); + } + } + + // col_map is now at stage 2 in the diagram. + + // Now we have to figure out which filters we can push down. + let (mut bound, mut unbound): (Vec<_>, Vec<_>) = f + .into_iter() + .partition(|expr| expr.support().iter().all(|c| col_map.contains_key(c))); + + // The filters that get pushed down need to use the pushed-down names for the map + // expressions (but the ones from the input stay the same). + for expr in bound.iter_mut() { + expr.visit_mut(&mut |e| { + if let ScalarExpr::Column(c) = e { + *c = *col_map.get(c).unwrap(); + } + }); + } + + // Finally, we can push down a projection which strips away columns that are not + // required by any of the M/F/P or the FlatMap. This will end up doing another + // renaming of all the columns. + let mut demanded_cols = HashSet::new(); + for c in p.iter() { + demanded_cols.insert(*c); + } + for e in unbound.iter() { + demanded_cols.extend(e.support()); + } + for e in residual_exprs.iter() { + demanded_cols.extend(e.support()); + } + for e in exprs.iter() { + demanded_cols.extend(e.support()); + } + + // We don't care about columns added by residual_exprs, nor do we care about ones + // the FlatMap introduces. + demanded_cols.retain(|c| col_map.contains_key(c)); + + // Change the demanded cols to their new names. + let demanded_cols = demanded_cols + .iter() + .map(|c| col_map.get(c).unwrap()) + .collect::>(); + + // We're going to prune away a bunch of columns. then we have to remap everything + // according to that. + let pushed_projection: Vec<_> = (0..(input_arity + pushed_exprs.len())) + .filter(|i| demanded_cols.contains(&i)) + .collect(); + + // Now we need a new col map, which maps via pushed_projection. + // This col_map represents stage 3. + let mut col_map = col_map + .into_iter() + .filter_map(|(k, v)| Some((k, pushed_projection.iter().position(|c| *c == v)?))) + .collect::>(); + + // Remap all of the FlatMap expressions. + for expr in exprs.iter_mut() { + expr.visit_mut(&mut |e| { + if let ScalarExpr::Column(c) = e { + *c = *col_map.get(c).unwrap(); + } + }); + } + + // Then add each of the FlatMap added columns to col_map. + for i in 0..added_cols { + col_map.insert(input_arity + i, pushed_projection.len() + i); + } + + // col_map is now at stage 4. + + // Now get all the map cols. + let new_flatmap_arity = pushed_projection.len() + added_cols; + + // Now extend col_map to include the residual map expressions. + col_map.extend( + (flatmap_arity..(flatmap_arity + m.len())) + .filter(|c| !col_map.contains_key(&c)) + .enumerate() + .map(|(idx, col)| (col, new_flatmap_arity + idx)) + .collect::>(), + ); + + // col_map is now at stage 5. + + // Remap all the residual expressions. + for expr in residual_exprs.iter_mut() { + expr.visit_mut(&mut |e| { + if let ScalarExpr::Column(c) = e { + *c = *col_map.get(c).unwrap(); + } + }); + } + + // Remap all the residual filters. + for expr in unbound.iter_mut() { + expr.visit_mut(&mut |e| { + if let ScalarExpr::Column(c) = e { + *c = *col_map.get(c).unwrap(); + } + }); + } + + // Finally, remap the projection (essentially putting another projection before + // it). + for c in p.iter_mut() { + *c = *col_map.get(c).unwrap(); + } + + self.action( + input, + MapFilterProject::new(input_arity) + .map(pushed_exprs) + .filter(bound) + .project(pushed_projection), + ); + + *relation = input + .take_dangerous() + .flat_map(func.clone(), exprs.iter().cloned().collect()) + .map(residual_exprs) + .filter(unbound) + .project(p); + } + _ => { + let (m, f, p) = mfp.as_map_filter_project(); + *relation = relation.take_dangerous().map(m).filter(f).project(p); + } + }; + } +} diff --git a/src/transform/tests/test_runner.rs b/src/transform/tests/test_runner.rs index e5ec91c78341..e3313c122e21 100644 --- a/src/transform/tests/test_runner.rs +++ b/src/transform/tests/test_runner.rs @@ -81,6 +81,7 @@ impl SexpParser { || ch == '-' || ch == '_' || ch == '#' + || ch == '+' } fn parse(&mut self) -> Result { @@ -130,7 +131,10 @@ impl SexpParser { mod tests { use super::{Sexp, SexpParser}; use anyhow::{anyhow, bail, Error}; - use expr::{GlobalId, Id, IdHumanizer, JoinImplementation, LocalId, RelationExpr, ScalarExpr}; + use expr::{ + BinaryFunc, GlobalId, Id, IdHumanizer, JoinImplementation, LocalId, RelationExpr, + ScalarExpr, TableFunc, + }; use repr::{ColumnType, Datum, RelationType, Row, ScalarType}; use std::collections::HashMap; use std::fmt::Write; @@ -283,6 +287,13 @@ mod tests { body: Box::new(body), }) } + // (flat-map [arguments]) + "flat-map" => Ok(RelationExpr::FlatMap { + input: Box::new(build_rel(nth(&s, 1)?, catalog, scope)?), + func: get_table_func(&nth(&s, 2)?)?, + exprs: build_scalar_list(nth(&s, 3)?)?, + demand: None, + }), // (map [expressions]) "map" => Ok(RelationExpr::Map { input: Box::new(build_rel(nth(&s, 1)?, catalog, scope)?), @@ -436,7 +447,14 @@ mod tests { } } }, - s => Err(anyhow!("expected {} to be a scalar", s)), + Sexp::List(l) => match try_atom(&l[0])?.as_str() { + "+" => Ok(ScalarExpr::CallBinary { + func: BinaryFunc::AddInt32, + expr1: Box::new(build_scalar(l[1].clone())?), + expr2: Box::new(build_scalar(l[2].clone())?), + }), + _ => Err(anyhow!("couldn't parse scalar: {:?}", l)), + }, } } @@ -456,6 +474,14 @@ mod tests { Ok(RelationType::new(col_types)) } + fn get_table_func(s: &Sexp) -> Result { + // TODO(justin): can this delegate to the planner? + match try_atom(s)?.as_str() { + "generate_series_int32" => Ok(TableFunc::GenerateSeriesInt32), + _ => Err(anyhow!("unknown table func {}", s)), + } + } + fn handle_cat(s: Sexp, cat: &mut TestCatalog) -> Result<(), Error> { match try_atom(&nth(&s, 0)?)?.as_str() { "defsource" => { @@ -567,6 +593,8 @@ mod tests { // transforms? match name { "PredicatePushdown" => Ok(Box::new(transform::predicate_pushdown::PredicatePushdown)), + "MFPPushdown" => Ok(Box::new(transform::mfp_pushdown::MFPPushdown)), + "Demand" => Ok(Box::new(transform::demand::Demand)), _ => Err(anyhow!( "no transform named {} (you might have to add it to get_transform)", name diff --git a/src/transform/tests/testdata/demand b/src/transform/tests/testdata/demand new file mode 100644 index 000000000000..0f6bf743fc48 --- /dev/null +++ b/src/transform/tests/testdata/demand @@ -0,0 +1,87 @@ +# Copyright Materialize, Inc. All rights reserved. +# +# Use of this software is governed by the Business Source License +# included in the LICENSE file. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0. + +cat +(defsource x [bool bool]) +---- +ok + +build apply=Demand +(project + (join + [(get x) (get x)] + [[#0 #2]]) + [#1]) +---- +---- +%0 = +| Get x (u0) + +%1 = +| Get x (u0) +| Project (#0) +| Map dummy +| Project (#0, #1) + +%2 = +| Join %0 %1 (= #0 #2) +| | implementation = Unimplemented +| | demand = (#1) +| Project (#1) +---- +---- + +build apply=Demand +(project + (join + [(get x) (get x)] + [[#0 #3]]) + [#1]) +---- +---- +%0 = +| Get x (u0) + +%1 = +| Get x (u0) +| Project (#1) +| Map dummy +| Project (#1, #0) + +%2 = +| Join %0 %1 (= #0 #3) +| | implementation = Unimplemented +| | demand = (#1) +| Project (#1) +---- +---- + +opt +(project + (join + [(get x) (get x)] + [[#0 #3]]) + [#1]) +---- +---- +%0 = +| Get x (u0) +| ArrangeBy (#0) + +%1 = +| Get x (u0) +| Map dummy + +%2 = +| Join %0 %1 (= #0 #3) +| | implementation = Differential %1 %0.(#0) +| Map dummy +| Project (#1) +---- +---- diff --git a/src/transform/tests/testdata/pushdown b/src/transform/tests/testdata/pushdown new file mode 100644 index 000000000000..e1fa8fc030d0 --- /dev/null +++ b/src/transform/tests/testdata/pushdown @@ -0,0 +1,277 @@ +# Copyright Materialize, Inc. All rights reserved. +# +# Use of this software is governed by the Business Source License +# included in the LICENSE file. +# +# As of the Change Date specified in that file, in accordance with +# the Business Source License, use of this software will be governed +# by the Apache License, Version 2.0. + +cat +(defsource x [bool bool]) +---- +ok + +# It should consolidate MFPs above an operator. + +build apply=MFPPushdown +(project + (project + (get x) + [#0]) + [#0]) +---- +%0 = +| Get x (u0) +| Project (#0) + +# Union + +build apply=MFPPushdown +(project + (union + [(get x) + (get x)]) + [#0]) +---- +---- +%0 = +| Get x (u0) +| Project (#0) + +%1 = +| Get x (u0) +| Project (#0) + +%2 = +| Union %0 %1 +---- +---- + +# FlatMap + +build apply=MFPPushdown +(map + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [true false #0 #1 #2])) +---- +%0 = +| Get x (u0) +| Map true, false, #0, #1 +| FlatMap generate_series(#0, #1) +| Map #6 +| Project (#0, #1, #6, #2..#5, #7) + +build apply=MFPPushdown +(map + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#1 (+ #3 #0)])) +---- +%0 = +| Get x (u0) +| Map #1, (#2 + #0) +| FlatMap generate_series(#0, #1) +| Project (#0, #1, #4, #2, #3) + +build apply=MFPPushdown +(map + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#1 (+ #3 #2)])) +---- +%0 = +| Get x (u0) +| Map #1 +| FlatMap generate_series(#0, #1) +| Map (#2 + #3) +| Project (#0, #1, #3, #2, #4) + +build apply=MFPPushdown +(map + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#1 (+ #3 #2) #1 #5])) +---- +%0 = +| Get x (u0) +| Map #1, #1, #3 +| FlatMap generate_series(#0, #1) +| Map (#2 + #5) +| Project (#0, #1, #5, #2, #6, #3, #4) + +build apply=MFPPushdown +(filter + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [true false #0 #1 #2])) +---- +%0 = +| Get x (u0) +| Filter true, false, #0, #1 +| FlatMap generate_series(#0, #1) +| Filter #2 + +build apply=MFPPushdown +(project + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#1 #0])) +---- +%0 = +| Get x (u0) +| FlatMap generate_series(#0, #1) +| Project (#1, #0) + +build apply=MFPPushdown +(project + (flat-map + (get x) + generate_series_int32 + [#1 #1]) + [#1 #1])) +---- +%0 = +| Get x (u0) +| Project (#1) +| FlatMap generate_series(#0, #0) +| Project (#0, #0) + +build apply=MFPPushdown +(map + (filter + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#1 #1 #0 #2]) + [true false #1 #0]) +---- +%0 = +| Get x (u0) +| Map true, false, #1, #0 +| Filter #0, #1, #1 +| FlatMap generate_series(#0, #1) +| Filter #6 +| Project (#0, #1, #6, #2..#5) + +build apply=MFPPushdown +(filter + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#0 #1 #2]) +---- +%0 = +| Get x (u0) +| Filter #0, #1 +| FlatMap generate_series(#0, #1) +| Filter #2 + +build apply=MFPPushdown +(project + (filter + (flat-map + (get x) + generate_series_int32 + [#0 #0]) + [#0 #1 #2]) + [#0]) +---- +%0 = +| Get x (u0) +| Filter #0, #1 +| Project (#0) +| FlatMap generate_series(#0, #0) +| Filter #1 +| Project (#0) + +build apply=MFPPushdown +(map + (project + (flat-map + (get x) + generate_series_int32 + [#0 #0]) + [#2 #0 #1]) + [#0 #1 #2]) +---- +%0 = +| Get x (u0) +| Map #0, #1 +| FlatMap generate_series(#0, #0) +| Map #4 +| Project (#4, #0, #1, #5, #2, #3) + +# Map where we can only push down one thing. + +build apply=MFPPushdown +(map + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#0 #1 #2 #5]) +---- +%0 = +| Get x (u0) +| Map #0, #1 +| FlatMap generate_series(#0, #1) +| Map #4, #5 +| Project (#0, #1, #4, #2, #3, #5, #6) + +# Map where we push down one thing but not a thing that references it. +build apply=MFPPushdown +(map + (flat-map + (get x) + generate_series_int32 + [#0 #1]) + [#0 #0 (+ #2 #3)]) +---- +%0 = +| Get x (u0) +| Map #0, #0 +| FlatMap generate_series(#0, #1) +| Map (#4 + #2) +| Project (#0, #1, #4, #2, #3, #5) + +build apply=MFPPushdown +(project + (flat-map + (get x) + generate_series_int32 + [#0 #0]) + [#2 #1 #0])) +---- +%0 = +| Get x (u0) +| FlatMap generate_series(#0, #0) +| Project (#2, #1, #0) + +# Constant + +# build apply=MFPPushdown +# (project +# (constant +# [[1 2 3] +# [4 5 6]] +# [int64 int64 int64]) +# [#0]) +# ---- +# %0 = +# | Constant (1, 2, 3) (4, 5, 6)