From 6d4b8bbad95c7e4fec0c4f1fb755ad7a1c542983 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 29 Nov 2023 21:49:38 +0000 Subject: [PATCH] Support nested schema projection (#5148) (#5149) * Support nested schema projection * Tweak doc * Review feedback --- arrow-schema/src/fields.rs | 232 ++++++++++++++++++++++++++++++++++++- 1 file changed, 231 insertions(+), 1 deletion(-) diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index f90632455fd..400f42c59c3 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::{ArrowError, Field, FieldRef, SchemaBuilder}; use std::ops::Deref; use std::sync::Arc; +use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder}; + /// A cheaply cloneable, owned slice of [`FieldRef`] /// /// Similar to `Arc>` or `Arc<[FieldRef]>` @@ -99,6 +100,108 @@ impl Fields { .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b)) } + /// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate + /// + /// Performs a depth-first scan of [`Fields`] invoking `filter` for each [`FieldRef`] + /// containing no child [`FieldRef`], a leaf field, along with a count of the number + /// of such leaves encountered so far. Only [`FieldRef`] for which `filter` + /// returned `true` will be included in the result. + /// + /// This can therefore be used to select a subset of fields from nested types + /// such as [`DataType::Struct`] or [`DataType::List`]. + /// + /// ``` + /// # use arrow_schema::{DataType, Field, Fields}; + /// let fields = Fields::from(vec![ + /// Field::new("a", DataType::Int32, true), // Leaf 0 + /// Field::new("b", DataType::Struct(Fields::from(vec![ + /// Field::new("c", DataType::Float32, false), // Leaf 1 + /// Field::new("d", DataType::Float64, false), // Leaf 2 + /// Field::new("e", DataType::Struct(Fields::from(vec![ + /// Field::new("f", DataType::Int32, false), // Leaf 3 + /// Field::new("g", DataType::Float16, false), // Leaf 4 + /// ])), true), + /// ])), false) + /// ]); + /// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx)); + /// let expected = Fields::from(vec![ + /// Field::new("a", DataType::Int32, true), + /// Field::new("b", DataType::Struct(Fields::from(vec![ + /// Field::new("d", DataType::Float64, false), + /// Field::new("e", DataType::Struct(Fields::from(vec![ + /// Field::new("f", DataType::Int32, false), + /// Field::new("g", DataType::Float16, false), + /// ])), true), + /// ])), false) + /// ]); + /// assert_eq!(filtered, expected); + /// ``` + pub fn filter_leaves bool>(&self, mut filter: F) -> Self { + fn filter_field bool>( + f: &FieldRef, + filter: &mut F, + ) -> Option { + use DataType::*; + + let v = match f.data_type() { + Dictionary(_, v) => v.as_ref(), // Key must be integer + RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer + d => d, + }; + let d = match v { + List(child) => List(filter_field(child, filter)?), + LargeList(child) => LargeList(filter_field(child, filter)?), + Map(child, ordered) => Map(filter_field(child, filter)?, *ordered), + FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size), + Struct(fields) => { + let filtered: Fields = fields + .iter() + .filter_map(|f| filter_field(f, filter)) + .collect(); + + if filtered.is_empty() { + return None; + } + + Struct(filtered) + } + Union(fields, mode) => { + let filtered: UnionFields = fields + .iter() + .filter_map(|(id, f)| Some((id, filter_field(f, filter)?))) + .collect(); + + if filtered.is_empty() { + return None; + } + + Union(filtered, *mode) + } + _ => return filter(f).then(|| f.clone()), + }; + let d = match f.data_type() { + Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)), + RunEndEncoded(v, f) => { + RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d))) + } + _ => d, + }; + Some(Arc::new(f.as_ref().clone().with_data_type(d))) + } + + let mut leaf_idx = 0; + let mut filter = |f: &FieldRef| { + let t = filter(leaf_idx, f); + leaf_idx += 1; + t + }; + + self.0 + .iter() + .filter_map(|f| filter_field(f, &mut filter)) + .collect() + } + /// Remove a field by index and return it. /// /// # Panic @@ -307,3 +410,130 @@ impl FromIterator<(i8, FieldRef)> for UnionFields { Self(iter.into_iter().collect()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::UnionMode; + + #[test] + fn test_filter() { + let floats = Fields::from(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ]); + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("floats", DataType::Struct(floats.clone()), true), + Field::new("b", DataType::Int16, true), + Field::new( + "c", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + Field::new( + "d", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Struct(floats.clone())), + ), + false, + ), + Field::new_list( + "e", + Field::new("floats", DataType::Struct(floats.clone()), true), + true, + ), + Field::new( + "f", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3), + false, + ), + Field::new_map( + "g", + "entries", + Field::new("keys", DataType::LargeUtf8, false), + Field::new("values", DataType::Int32, true), + false, + false, + ), + Field::new( + "h", + DataType::Union( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("field1", DataType::UInt8, false), + Field::new("field3", DataType::Utf8, false), + ], + ), + UnionMode::Dense, + ), + true, + ), + Field::new( + "i", + DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)), + ), + false, + ), + ]); + + let floats_a = DataType::Struct(vec![floats[0].clone()].into()); + + let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1); + assert_eq!(r.len(), 2); + assert_eq!(r[0], fields[0]); + assert_eq!(r[1].data_type(), &floats_a); + + let r = fields.filter_leaves(|_, f| f.name() == "a"); + assert_eq!(r.len(), 5); + assert_eq!(r[0], fields[0]); + assert_eq!(r[1].data_type(), &floats_a); + assert_eq!( + r[2].data_type(), + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone())) + ); + assert_eq!( + r[3].as_ref(), + &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true) + ); + assert_eq!( + r[4].as_ref(), + &Field::new( + "i", + DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", floats_a.clone(), true)), + ), + false, + ) + ); + + let r = fields.filter_leaves(|_, f| f.name() == "floats"); + assert_eq!(r.len(), 0); + + let r = fields.filter_leaves(|idx, _| idx == 9); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[6]); + + let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[7]); + + let union = DataType::Union( + UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]), + UnionMode::Dense, + ); + + let r = fields.filter_leaves(|idx, _| idx == 12); + assert_eq!(r.len(), 1); + assert_eq!(r[0].data_type(), &union); + + let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[9]); + } +}