From 95ea4db2b551ac8d9c3d1a49bd39564b75d3f347 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 13 May 2022 04:34:00 -0700 Subject: [PATCH 1/7] unbox scalars --- datafusion/common/src/scalar.rs | 94 +++++------ datafusion/core/src/scalar.rs | 156 +++++++----------- datafusion/core/src/sql/planner.rs | 10 +- .../physical-expr/src/aggregate/array_agg.rs | 77 +++------ .../src/aggregate/array_agg_distinct.rs | 79 +++------ .../src/aggregate/count_distinct.rs | 4 +- .../physical-expr/src/aggregate/tdigest.rs | 4 +- datafusion/proto/src/from_proto.rs | 16 +- datafusion/proto/src/lib.rs | 76 ++++----- datafusion/proto/src/to_proto.rs | 12 +- 10 files changed, 195 insertions(+), 333 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 03a59ff6d3db..b5cfa5aee735 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -72,8 +72,7 @@ pub enum ScalarValue { /// large binary LargeBinary(Option>), /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) - #[allow(clippy::box_collection)] - List(Option>>, Box), + List(Option>, DataType), /// Date stored as a signed 32bit int Date32(Option), /// Date stored as a signed 64bit int @@ -93,8 +92,7 @@ pub enum ScalarValue { /// Interval with MonthDayNano unit IntervalMonthDayNano(Option), /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) - #[allow(clippy::box_collection)] - Struct(Option>>, Box>), + Struct(Option>, Vec), } // manual implementation of `PartialEq` that uses OrderedFloat to @@ -392,7 +390,7 @@ macro_rules! build_list { ) } Some(values) => { - build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values.as_ref(), $SIZE) + build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values, $SIZE) } } }}; @@ -412,37 +410,34 @@ macro_rules! build_timestamp_list { $SIZE, ) } - Some(values) => { - let values = values.as_ref(); - match $TIME_UNIT { - TimeUnit::Second => { - build_values_list_tz!( - TimestampSecondBuilder, - TimestampSecond, - values, - $SIZE - ) - } - TimeUnit::Microsecond => build_values_list_tz!( - TimestampMillisecondBuilder, - TimestampMillisecond, - values, - $SIZE - ), - TimeUnit::Millisecond => build_values_list_tz!( - TimestampMicrosecondBuilder, - TimestampMicrosecond, + Some(values) => match $TIME_UNIT { + TimeUnit::Second => { + build_values_list_tz!( + TimestampSecondBuilder, + TimestampSecond, values, $SIZE - ), - TimeUnit::Nanosecond => build_values_list_tz!( - TimestampNanosecondBuilder, - TimestampNanosecond, - values, - $SIZE - ), + ) } - } + TimeUnit::Microsecond => build_values_list_tz!( + TimestampMillisecondBuilder, + TimestampMillisecond, + values, + $SIZE + ), + TimeUnit::Millisecond => build_values_list_tz!( + TimestampMicrosecondBuilder, + TimestampMicrosecond, + values, + $SIZE + ), + TimeUnit::Nanosecond => build_values_list_tz!( + TimestampNanosecondBuilder, + TimestampNanosecond, + values, + $SIZE + ), + }, } }}; } @@ -579,11 +574,9 @@ impl ScalarValue { ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, ScalarValue::Binary(_) => DataType::Binary, ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new( - "item", - data_type.as_ref().clone(), - true, - ))), + ScalarValue::List(_, data_type) => { + DataType::List(Box::new(Field::new("item", data_type.clone(), true))) + } ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::IntervalYearMonth(_) => { @@ -593,7 +586,7 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(_) => { DataType::Interval(IntervalUnit::MonthDayNano) } - ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), + ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()), } } @@ -794,7 +787,6 @@ impl ScalarValue { for scalar in scalars.into_iter() { match scalar { ScalarValue::List(Some(xs), _) => { - let xs = *xs; for s in xs { match s { ScalarValue::$SCALAR_TY(Some(val)) => { @@ -1000,7 +992,7 @@ impl ScalarValue { if let ScalarValue::List(values, _) = scalar { match values { Some(values) => { - let element_array = ScalarValue::iter_to_array(*values)?; + let element_array = ScalarValue::iter_to_array(values)?; // Add new offset index flat_len += element_array.len() as i32; @@ -1160,7 +1152,7 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(values, data_type) => Arc::new(match data_type.as_ref() { + ScalarValue::List(values, data_type) => Arc::new(match data_type { DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), DataType::Int8 => build_list!(Int8Builder, Int8, values, size), DataType::Int16 => build_list!(Int16Builder, Int16, values, size), @@ -1183,7 +1175,7 @@ impl ScalarValue { repeat(self.clone()).take(size), &DataType::List(Box::new(Field::new( "item", - data_type.as_ref().clone(), + data_type.clone(), true, ))), ) @@ -1303,8 +1295,7 @@ impl ScalarValue { Some(scalar_vec) } }; - let value = value.map(Box::new); - let data_type = Box::new(nested_type.data_type().clone()); + let data_type = nested_type.data_type().clone(); ScalarValue::List(value, data_type) } DataType::Date32 => { @@ -1389,7 +1380,7 @@ impl ScalarValue { let col_scalar = ScalarValue::try_from_array(col_array, index)?; field_values.push(col_scalar); } - Self::Struct(Some(Box::new(field_values)), Box::new(fields.clone())) + Self::Struct(Some(field_values), fields.clone()) } DataType::FixedSizeList(nested_type, _len) => { let list_array = @@ -1404,8 +1395,7 @@ impl ScalarValue { Some(scalar_vec) } }; - let value = value.map(Box::new); - let data_type = Box::new(nested_type.data_type().clone()); + let data_type = nested_type.data_type().clone(); ScalarValue::List(value, data_type) } other => { @@ -1610,7 +1600,7 @@ impl From> for ScalarValue { }) .unzip(); - Self::Struct(Some(Box::new(scalars)), Box::new(fields)) + Self::Struct(Some(scalars), fields) } } @@ -1738,11 +1728,9 @@ impl TryFrom<&DataType> for ScalarValue { value_type.as_ref().try_into()? } DataType::List(ref nested_type) => { - ScalarValue::List(None, Box::new(nested_type.data_type().clone())) - } - DataType::Struct(fields) => { - ScalarValue::Struct(None, Box::new(fields.clone())) + ScalarValue::List(None, nested_type.data_type().clone()) } + DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), _ => { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from data_type \"{:?}\"", diff --git a/datafusion/core/src/scalar.rs b/datafusion/core/src/scalar.rs index 774b8ebf86dc..9798dcb0a1d9 100644 --- a/datafusion/core/src/scalar.rs +++ b/datafusion/core/src/scalar.rs @@ -158,8 +158,7 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = - ScalarValue::List(None, Box::new(DataType::UInt64)).to_array(); + let list_array_ref = ScalarValue::List(None, DataType::UInt64).to_array(); let list_array = list_array_ref.as_any().downcast_ref::().unwrap(); assert!(list_array.is_null(0)); @@ -170,12 +169,12 @@ mod tests { #[test] fn scalar_list_to_array() { let list_array_ref = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::UInt64(Some(100)), ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), - ])), - Box::new(DataType::UInt64), + ]), + DataType::UInt64, ) .to_array(); @@ -605,51 +604,39 @@ mod tests { assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); assert_eq!( - List( - Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32), - ) - .partial_cmp(&List( - Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32), - )), + List(Some(vec![Int32(Some(1)), Int32(Some(5))]), DataType::Int32,) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + DataType::Int32, + )), Some(Ordering::Equal) ); assert_eq!( - List( - Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), - Box::new(DataType::Int32), - ) - .partial_cmp(&List( - Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32), - )), + List(Some(vec![Int32(Some(10)), Int32(Some(5))]), DataType::Int32,) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + DataType::Int32, + )), Some(Ordering::Greater) ); assert_eq!( - List( - Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32), - ) - .partial_cmp(&List( - Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), - Box::new(DataType::Int32), - )), + List(Some(vec![Int32(Some(1)), Int32(Some(5))]), DataType::Int32,) + .partial_cmp(&List( + Some(vec![Int32(Some(10)), Int32(Some(5))]), + DataType::Int32, + )), Some(Ordering::Less) ); // For different data type, `partial_cmp` returns None. assert_eq!( - List( - Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])), - Box::new(DataType::Int64), - ) - .partial_cmp(&List( - Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32), - )), + List(Some(vec![Int64(Some(1)), Int64(Some(5))]), DataType::Int64,) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + DataType::Int32, + )), None ); @@ -694,7 +681,7 @@ mod tests { ); let scalar = ScalarValue::Struct( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Int32(Some(23)), ScalarValue::Boolean(Some(false)), ScalarValue::Utf8(Some("Hello".to_string())), @@ -702,13 +689,13 @@ mod tests { ("e", ScalarValue::from(2i16)), ("f", ScalarValue::from(3i64)), ]), - ])), - Box::new(vec![ + ]), + vec![ field_a.clone(), field_b.clone(), field_c.clone(), field_d.clone(), - ]), + ], ); // Check Display @@ -866,26 +853,20 @@ mod tests { // Define primitive list scalars let l0 = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32), - ])), - Box::new(DataType::Int32), + ]), + DataType::Int32, ); let l1 = ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(4i32), - ScalarValue::from(5i32), - ])), - Box::new(DataType::Int32), + Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), + DataType::Int32, ); - let l2 = ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(6i32)])), - Box::new(DataType::Int32), - ); + let l2 = ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32); // Define struct scalars let s0 = ScalarValue::from(vec![ @@ -927,16 +908,12 @@ mod tests { assert_eq!(array, &expected); // Define list-of-structs scalars - let nl0 = ScalarValue::List( - Some(Box::new(vec![s0.clone(), s1.clone()])), - Box::new(s0.get_datatype()), - ); + let nl0 = + ScalarValue::List(Some(vec![s0.clone(), s1.clone()]), s0.get_datatype()); - let nl1 = - ScalarValue::List(Some(Box::new(vec![s2])), Box::new(s0.get_datatype())); + let nl1 = ScalarValue::List(Some(vec![s2]), s0.get_datatype()); - let nl2 = - ScalarValue::List(Some(Box::new(vec![s1])), Box::new(s0.get_datatype())); + let nl2 = ScalarValue::List(Some(vec![s1]), s0.get_datatype()); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); @@ -1080,61 +1057,40 @@ mod tests { fn test_nested_lists() { // Define inner list scalars let l1 = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32), - ])), - Box::new(DataType::Int32), + ]), + DataType::Int32, ), ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(4i32), - ScalarValue::from(5i32), - ])), - Box::new(DataType::Int32), + Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), + DataType::Int32, ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + ]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let l2 = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ + ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32), ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(6i32)])), - Box::new(DataType::Int32), + Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), + DataType::Int32, ), - ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(7i32), - ScalarValue::from(8i32), - ])), - Box::new(DataType::Int32), - ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + ]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let l3 = ScalarValue::List( - Some(Box::new(vec![ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(9i32)])), - Box::new(DataType::Int32), - )])), - Box::new(DataType::List(Box::new(Field::new( - "item", + Some(vec![ScalarValue::List( + Some(vec![ScalarValue::from(9i32)]), DataType::Int32, - true, - )))), + )]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index ca8cb23bc5e9..16e7ebbef450 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -2329,10 +2329,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { values.iter().map(|e| e.get_datatype()).collect(); if data_types.is_empty() { - Ok(Expr::Literal(ScalarValue::List( - None, - Box::new(DataType::Utf8), - ))) + Ok(Expr::Literal(ScalarValue::List(None, DataType::Utf8))) } else if data_types.len() > 1 { Err(DataFusionError::NotImplemented(format!( "Arrays with different types are not supported: {:?}", @@ -2341,10 +2338,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let data_type = values[0].get_datatype(); - Ok(Expr::Literal(ScalarValue::List( - Some(Box::new(values)), - Box::new(data_type), - ))) + Ok(Expr::Literal(ScalarValue::List(Some(values), data_type))) } } } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 4f2bc3ece1ea..57bbfc46dd1b 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -133,7 +133,7 @@ impl Accumulator for ArrayAggAccumulator { (0..arr.len()).try_for_each(|index| { let scalar = ScalarValue::try_from_array(arr, index)?; if let ScalarValue::List(Some(values), _) = scalar { - self.values.extend(*values); + self.values.extend(values); Ok(()) } else { Err(DataFusionError::Internal( @@ -149,8 +149,8 @@ impl Accumulator for ArrayAggAccumulator { fn evaluate(&self) -> Result { Ok(ScalarValue::List( - Some(Box::new(self.values.clone())), - Box::new(self.datatype.clone()), + Some(self.values.clone()), + self.datatype.clone(), )) } } @@ -172,14 +172,14 @@ mod tests { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); let list = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(2)), ScalarValue::Int32(Some(3)), ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5)), - ])), - Box::new(DataType::Int32), + ]), + DataType::Int32, ); generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) @@ -188,70 +188,45 @@ mod tests { #[test] fn array_agg_nested() -> Result<()> { let l1 = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32), - ])), - Box::new(DataType::Int32), + ]), + DataType::Int32, ), ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(4i32), - ScalarValue::from(5i32), - ])), - Box::new(DataType::Int32), + Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), + DataType::Int32, ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + ]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let l2 = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ + ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32), ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(6i32)])), - Box::new(DataType::Int32), + Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), + DataType::Int32, ), - ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(7i32), - ScalarValue::from(8i32), - ])), - Box::new(DataType::Int32), - ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + ]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let l3 = ScalarValue::List( - Some(Box::new(vec![ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(9i32)])), - Box::new(DataType::Int32), - )])), - Box::new(DataType::List(Box::new(Field::new( - "item", + Some(vec![ScalarValue::List( + Some(vec![ScalarValue::from(9i32)]), DataType::Int32, - true, - )))), + )]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let list = ScalarValue::List( - Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + Some(vec![l1.clone(), l2.clone(), l3.clone()]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index dfe19e0eb4c2..8807d5ce35de 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -121,8 +121,8 @@ impl DistinctArrayAggAccumulator { impl Accumulator for DistinctArrayAggAccumulator { fn state(&self) -> Result> { Ok(vec![ScalarValue::List( - Some(Box::new(self.values.clone().into_iter().collect())), - Box::new(self.datatype.clone()), + Some(self.values.clone().into_iter().collect()), + self.datatype.clone(), )]) } @@ -152,8 +152,8 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&self) -> Result { Ok(ScalarValue::List( - Some(Box::new(self.values.clone().into_iter().collect())), - Box::new(self.datatype.clone()), + Some(self.values.clone().into_iter().collect()), + self.datatype.clone(), )) } } @@ -207,14 +207,14 @@ mod tests { let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); let out = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(2)), ScalarValue::Int32(Some(7)), ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5)), - ])), - Box::new(DataType::Int32), + ]), + DataType::Int32, ); check_distinct_array_agg(col, out, DataType::Int32) @@ -224,72 +224,47 @@ mod tests { fn distinct_array_agg_nested() -> Result<()> { // [[1, 2, 3], [4, 5]] let l1 = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::from(1i32), ScalarValue::from(2i32), ScalarValue::from(3i32), - ])), - Box::new(DataType::Int32), + ]), + DataType::Int32, ), ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(4i32), - ScalarValue::from(5i32), - ])), - Box::new(DataType::Int32), + Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), + DataType::Int32, ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + ]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); // [[6], [7, 8]] let l2 = ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ + ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32), ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(6i32)])), - Box::new(DataType::Int32), + Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), + DataType::Int32, ), - ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(7i32), - ScalarValue::from(8i32), - ])), - Box::new(DataType::Int32), - ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + ]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); // [[9]] let l3 = ScalarValue::List( - Some(Box::new(vec![ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(9i32)])), - Box::new(DataType::Int32), - )])), - Box::new(DataType::List(Box::new(Field::new( - "item", + Some(vec![ScalarValue::List( + Some(vec![ScalarValue::from(9i32)]), DataType::Int32, - true, - )))), + )]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); let list = ScalarValue::List( - Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), + Some(vec![l1.clone(), l2.clone(), l3.clone()]), + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), ); // Duplicate l1 in the input array and check that it is deduped in the output. diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index cb32dcd4969b..5603fe9cf193 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -190,7 +190,7 @@ impl Accumulator for DistinctCountAccumulator { .map(|state_data_type| { let values = Box::new(Vec::new()); let data_type = Box::new(state_data_type.clone()); - ScalarValue::List(Some(values), data_type) + ScalarValue::List(Some(*values), *data_type) }) .collect::>(); @@ -238,7 +238,7 @@ mod tests { macro_rules! state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ match $LIST { - ScalarValue::List(_, data_type) => match data_type.as_ref() { + ScalarValue::List(_, data_type) => match data_type { &DataType::$DATA_TYPE => (), _ => panic!("Unexpected DataType for list"), }, diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 401fd5a239f6..4f306e3ffa03 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -626,7 +626,7 @@ impl TDigest { ScalarValue::Float64(Some(self.count.into_inner())), ScalarValue::Float64(Some(self.max.into_inner())), ScalarValue::Float64(Some(self.min.into_inner())), - ScalarValue::List(Some(Box::new(centroids)), Box::new(DataType::Float64)), + ScalarValue::List(Some(centroids), DataType::Float64), ] } @@ -647,7 +647,7 @@ impl TDigest { }; let centroids: Vec<_> = match &state[5] { - ScalarValue::List(Some(c), d) if **d == DataType::Float64 => c + ScalarValue::List(Some(c), d) if *d == DataType::Float64 => c .chunks(2) .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) .collect(), diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 37466dae207d..13ddb8e1b064 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -599,7 +599,7 @@ impl TryFrom<&protobuf::scalar_value::Value> for ScalarValue { Value::Float64Value(v) => ScalarValue::Float64(Some(*v)), Value::Date32Value(v) => ScalarValue::Date32(Some(*v)), Value::ListValue(v) => v.try_into()?, - Value::NullListValue(v) => ScalarValue::List(None, Box::new(v.try_into()?)), + Value::NullListValue(v) => ScalarValue::List(None, v.try_into()?), Value::NullValue(null_enum) => { let primitive = PrimitiveScalarType::try_from(null_enum)?; (&primitive).try_into()? @@ -672,10 +672,7 @@ impl TryFrom<&protobuf::ScalarListValue> for ScalarValue { typechecked_scalar_value_conversion(value, leaf_scalar_type) }) .collect::, _>>()?; - ScalarValue::List( - Some(Box::new(typechecked_values)), - Box::new(leaf_scalar_type.into()), - ) + ScalarValue::List(Some(typechecked_values), leaf_scalar_type.into()) } Datatype::List(list_type) => { let protobuf::ScalarListType { @@ -708,9 +705,9 @@ impl TryFrom<&protobuf::ScalarListValue> for ScalarValue { ScalarValue::List( match typechecked_values.len() { 0 => None, - _ => Some(Box::new(typechecked_values)), + _ => Some(typechecked_values), }, - Box::new((list_type).try_into()?), + (list_type).try_into()?, ) } }; @@ -852,18 +849,17 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } = &scalar_list; let scalar_type = opt_scalar_type.as_ref().required("datatype")?; - let scalar_type = Box::new(scalar_type); let typechecked_values: Vec = values .iter() .map(|val| val.try_into()) .collect::, _>>()?; - Self::List(Some(Box::new(typechecked_values)), scalar_type) + Self::List(Some(typechecked_values), scalar_type) } Value::NullListValue(v) => { let datatype = v.datatype.as_ref().required("datatype")?; - Self::List(None, Box::new(datatype)) + Self::List(None, datatype) } Value::NullValue(v) => { let null_type_enum = protobuf::PrimitiveScalarType::try_from(v)?; diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index dcc18373c848..d329085d8037 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -67,57 +67,49 @@ mod roundtrip_tests { let should_fail_on_seralize: Vec = vec![ // Should fail due to inconsistent types ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Int16(None), ScalarValue::Float32(Some(32.0)), - ])), - Box::new(DataType::List(new_box_field("item", DataType::Int16, true))), + ]), + DataType::List(new_box_field("item", DataType::Int16, true)), ), ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Float32(None), ScalarValue::Float32(Some(32.0)), - ])), - Box::new(DataType::List(new_box_field("item", DataType::Int16, true))), + ]), + DataType::List(new_box_field("item", DataType::Int16, true)), ), ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::List( None, - Box::new(DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - ))), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ])), - Box::new(DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - ))), + ]), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), ScalarValue::List( None, - Box::new(DataType::List(new_box_field( + DataType::List(new_box_field( "lists are typed inconsistently", DataType::Int16, true, - ))), + )), ), - ])), - Box::new(DataType::List(new_box_field( + ]), + DataType::List(new_box_field( "level1", DataType::List(new_box_field("level2", DataType::Float32, true)), true, - ))), + )), ), ]; @@ -150,7 +142,7 @@ mod roundtrip_tests { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::List(None, Box::new(DataType::Boolean)), + ScalarValue::List(None, DataType::Boolean), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -207,49 +199,37 @@ mod roundtrip_tests { ScalarValue::TimestampSecond(Some(0), Some("UTC".to_string())), ScalarValue::TimestampSecond(None, None), ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ])), - Box::new(DataType::List(new_box_field( - "level1", - DataType::Float32, - true, - ))), + ]), + DataType::List(new_box_field("level1", DataType::Float32, true)), ), ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::List( None, - Box::new(DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - ))), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), ScalarValue::List( - Some(Box::new(vec![ + Some(vec![ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), ScalarValue::Float32(Some(5.5)), ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), - ])), - Box::new(DataType::List(new_box_field( - "level2", - DataType::Float32, - true, - ))), + ]), + DataType::List(new_box_field("level2", DataType::Float32, true)), ), - ])), - Box::new(DataType::List(new_box_field( + ]), + DataType::List(new_box_field( "level1", DataType::List(new_box_field("level2", DataType::Float32, true)), true, - ))), + )), ), ]; diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 03a9f6b10432..68b83e5cafa6 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -799,13 +799,13 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::ListValue( protobuf::ScalarListValue { - datatype: Some(datatype.as_ref().try_into()?), + datatype: Some(datatype.try_into()?), values: Vec::new(), }, )), } } else { - let scalar_type = match datatype.as_ref() { + let scalar_type = match datatype { DataType::List(field) => field.as_ref().data_type(), _ => todo!("Proper error handling"), }; @@ -817,9 +817,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { scalar::ScalarValue::List(_, list_type), DataType::List(field), ) => { - if let DataType::List(list_field) = - list_type.as_ref() - { + if let DataType::List(list_field) = list_type { let scalar_datatype = field.data_type(); let list_datatype = list_field.data_type(); if std::mem::discriminant(list_datatype) @@ -893,7 +891,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::ListValue( protobuf::ScalarListValue { - datatype: Some(datatype.as_ref().try_into()?), + datatype: Some(datatype.try_into()?), values: type_checked_values, }, )), @@ -902,7 +900,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { } None => protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::NullListValue( - datatype.as_ref().try_into()?, + datatype.try_into()?, )), }, } From 5661512b5de0951fb058405aa72f6394731d3d5a Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 13 May 2022 05:20:13 -0700 Subject: [PATCH 2/7] fix conflicts --- datafusion/physical-expr/src/aggregate/sum_distinct.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 238722726547..862f634463c2 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -132,12 +132,14 @@ impl Accumulator for DistinctSumAccumulator { // 1. Stores aggregate state in `ScalarValue::List` // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { - let mut distinct_values = Box::new(Vec::new()); - let data_type = Box::new(self.data_type.clone()); + let mut distinct_values = Vec::new(); self.hash_values .iter() .for_each(|distinct_value| distinct_values.push(distinct_value.clone())); - vec![ScalarValue::List(Some(distinct_values), data_type)] + vec![ScalarValue::List( + Some(distinct_values), + self.data_type.clone(), + )] }; Ok(state_out) } From 77b35578db7702c32ed6813e1ebad44b5fc7d117 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 13 May 2022 07:03:03 -0700 Subject: [PATCH 3/7] fix clippy --- datafusion/common/src/scalar.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 29533929ed48..76395b0c45a9 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -926,16 +926,20 @@ impl ScalarValue { match values { Some(values) => { // Push value for each field - for c in 0..columns.len() { - let column = columns.get_mut(c).unwrap(); - column.push(values[c].clone()); + for (i, v) in + values.iter().enumerate().take(columns.len()) + { + let column = columns.get_mut(i).unwrap(); + column.push(v.clone()); } } None => { // Push NULL of the appropriate type for each field - for c in 0..columns.len() { - let dtype = fields[c].data_type(); - let column = columns.get_mut(c).unwrap(); + for (i, f) in + fields.iter().enumerate().take(columns.len()) + { + let dtype = f.data_type(); + let column = columns.get_mut(i).unwrap(); column.push(ScalarValue::try_from(dtype)?); } } From 1d883ac659ff02c2302d73be64bb0b7c07aa5f5a Mon Sep 17 00:00:00 2001 From: comphead Date: Sun, 15 May 2022 01:21:20 -0700 Subject: [PATCH 4/7] Update datafusion/common/src/scalar.rs Co-authored-by: Andrew Lamb --- datafusion/common/src/scalar.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 76395b0c45a9..7d6e318af61b 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -73,7 +73,7 @@ pub enum ScalarValue { Binary(Option>), /// large binary LargeBinary(Option>), - /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + /// list of nested ScalarValue List(Option>, DataType), /// Date stored as a signed 32bit int Date32(Option), From d64f2247e9927acd1e52fa9b3820f4e4563dfa33 Mon Sep 17 00:00:00 2001 From: comphead Date: Sun, 15 May 2022 01:21:25 -0700 Subject: [PATCH 5/7] Update datafusion/common/src/scalar.rs Co-authored-by: Andrew Lamb --- datafusion/common/src/scalar.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 7d6e318af61b..b05909c47341 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -93,7 +93,7 @@ pub enum ScalarValue { IntervalDayTime(Option), /// Interval with MonthDayNano unit IntervalMonthDayNano(Option), - /// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue)) + /// struct of nested ScalarValue Struct(Option>, Vec), } From f957fc3d4ddd3609d9e5c7e6e35abe8cdf628188 Mon Sep 17 00:00:00 2001 From: comphead Date: Mon, 16 May 2022 12:04:11 -0700 Subject: [PATCH 6/7] boxing datatype --- datafusion/common/src/scalar.rs | 32 ++--- datafusion/core/src/scalar.rs | 111 ++++++++++++------ datafusion/core/src/sql/planner.rs | 10 +- .../physical-expr/src/aggregate/array_agg.rs | 41 +++++-- .../src/aggregate/array_agg_distinct.rs | 43 +++++-- .../src/aggregate/count_distinct.rs | 4 +- .../src/aggregate/sum_distinct.rs | 6 +- .../physical-expr/src/aggregate/tdigest.rs | 4 +- datafusion/proto/src/from_proto.rs | 12 +- datafusion/proto/src/lib.rs | 48 +++++--- datafusion/proto/src/to_proto.rs | 12 +- 11 files changed, 214 insertions(+), 109 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index b05909c47341..6577a76ba9fd 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -74,7 +74,7 @@ pub enum ScalarValue { /// large binary LargeBinary(Option>), /// list of nested ScalarValue - List(Option>, DataType), + List(Option>, Box), /// Date stored as a signed 32bit int Date32(Option), /// Date stored as a signed 64bit int @@ -94,7 +94,7 @@ pub enum ScalarValue { /// Interval with MonthDayNano unit IntervalMonthDayNano(Option), /// struct of nested ScalarValue - Struct(Option>, Vec), + Struct(Option>, Box>), } // manual implementation of `PartialEq` that uses OrderedFloat to @@ -582,9 +582,11 @@ impl ScalarValue { ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, ScalarValue::Binary(_) => DataType::Binary, ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(_, data_type) => { - DataType::List(Box::new(Field::new("item", data_type.clone(), true))) - } + ScalarValue::List(_, data_type) => DataType::List(Box::new(Field::new( + "item", + data_type.as_ref().clone(), + true, + ))), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::IntervalYearMonth(_) => { @@ -594,7 +596,7 @@ impl ScalarValue { ScalarValue::IntervalMonthDayNano(_) => { DataType::Interval(IntervalUnit::MonthDayNano) } - ScalarValue::Struct(_, fields) => DataType::Struct(fields.clone()), + ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()), ScalarValue::Null => DataType::Null, } } @@ -1178,7 +1180,7 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(values, data_type) => Arc::new(match data_type { + ScalarValue::List(values, data_type) => Arc::new(match data_type.as_ref() { DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size), DataType::Int8 => build_list!(Int8Builder, Int8, values, size), DataType::Int16 => build_list!(Int16Builder, Int16, values, size), @@ -1201,7 +1203,7 @@ impl ScalarValue { repeat(self.clone()).take(size), &DataType::List(Box::new(Field::new( "item", - data_type.clone(), + data_type.as_ref().clone(), true, ))), ) @@ -1324,7 +1326,7 @@ impl ScalarValue { } }; let data_type = nested_type.data_type().clone(); - ScalarValue::List(value, data_type) + ScalarValue::List(value, Box::new(data_type)) } DataType::Date32 => { typed_cast!(array, index, Date32Array, Date32) @@ -1408,7 +1410,7 @@ impl ScalarValue { let col_scalar = ScalarValue::try_from_array(col_array, index)?; field_values.push(col_scalar); } - Self::Struct(Some(field_values), fields.clone()) + Self::Struct(Some(field_values), Box::new(fields.clone())) } DataType::FixedSizeList(nested_type, _len) => { let list_array = @@ -1424,7 +1426,7 @@ impl ScalarValue { } }; let data_type = nested_type.data_type().clone(); - ScalarValue::List(value, data_type) + ScalarValue::List(value, Box::new(data_type)) } other => { return Err(DataFusionError::NotImplemented(format!( @@ -1629,7 +1631,7 @@ impl From> for ScalarValue { }) .unzip(); - Self::Struct(Some(scalars), fields) + Self::Struct(Some(scalars), Box::new(fields)) } } @@ -1757,9 +1759,11 @@ impl TryFrom<&DataType> for ScalarValue { value_type.as_ref().try_into()? } DataType::List(ref nested_type) => { - ScalarValue::List(None, nested_type.data_type().clone()) + ScalarValue::List(None, Box::new(nested_type.data_type().clone())) + } + DataType::Struct(fields) => { + ScalarValue::Struct(None, Box::new(fields.clone())) } - DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()), DataType::Null => ScalarValue::Null, _ => { return Err(DataFusionError::NotImplemented(format!( diff --git a/datafusion/core/src/scalar.rs b/datafusion/core/src/scalar.rs index 9798dcb0a1d9..fb4d7030c1b9 100644 --- a/datafusion/core/src/scalar.rs +++ b/datafusion/core/src/scalar.rs @@ -158,7 +158,8 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array_ref = ScalarValue::List(None, DataType::UInt64).to_array(); + let list_array_ref = + ScalarValue::List(None, Box::new(DataType::UInt64)).to_array(); let list_array = list_array_ref.as_any().downcast_ref::().unwrap(); assert!(list_array.is_null(0)); @@ -174,7 +175,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]), - DataType::UInt64, + Box::new(DataType::UInt64), ) .to_array(); @@ -604,39 +605,51 @@ mod tests { assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None); assert_eq!( - List(Some(vec![Int32(Some(1)), Int32(Some(5))]), DataType::Int32,) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - DataType::Int32, - )), + List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Box::new(DataType::Int32), + )), Some(Ordering::Equal) ); assert_eq!( - List(Some(vec![Int32(Some(10)), Int32(Some(5))]), DataType::Int32,) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - DataType::Int32, - )), + List( + Some(vec![Int32(Some(10)), Int32(Some(5))]), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Box::new(DataType::Int32), + )), Some(Ordering::Greater) ); assert_eq!( - List(Some(vec![Int32(Some(1)), Int32(Some(5))]), DataType::Int32,) - .partial_cmp(&List( - Some(vec![Int32(Some(10)), Int32(Some(5))]), - DataType::Int32, - )), + List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(10)), Int32(Some(5))]), + Box::new(DataType::Int32), + )), Some(Ordering::Less) ); // For different data type, `partial_cmp` returns None. assert_eq!( - List(Some(vec![Int64(Some(1)), Int64(Some(5))]), DataType::Int64,) - .partial_cmp(&List( - Some(vec![Int32(Some(1)), Int32(Some(5))]), - DataType::Int32, - )), + List( + Some(vec![Int64(Some(1)), Int64(Some(5))]), + Box::new(DataType::Int32), + ) + .partial_cmp(&List( + Some(vec![Int32(Some(1)), Int32(Some(5))]), + Box::new(DataType::Int32), + )), None ); @@ -690,12 +703,12 @@ mod tests { ("f", ScalarValue::from(3i64)), ]), ]), - vec![ + Box::new(vec![ field_a.clone(), field_b.clone(), field_c.clone(), field_d.clone(), - ], + ]), ); // Check Display @@ -858,15 +871,18 @@ mod tests { ScalarValue::from(2i32), ScalarValue::from(3i32), ]), - DataType::Int32, + Box::new(DataType::Int32), ); let l1 = ScalarValue::List( Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, + Box::new(DataType::Int32), ); - let l2 = ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32); + let l2 = ScalarValue::List( + Some(vec![ScalarValue::from(6i32)]), + Box::new(DataType::Int32), + ); // Define struct scalars let s0 = ScalarValue::from(vec![ @@ -908,12 +924,14 @@ mod tests { assert_eq!(array, &expected); // Define list-of-structs scalars - let nl0 = - ScalarValue::List(Some(vec![s0.clone(), s1.clone()]), s0.get_datatype()); + let nl0 = ScalarValue::List( + Some(vec![s0.clone(), s1.clone()]), + Box::new(s0.get_datatype()), + ); - let nl1 = ScalarValue::List(Some(vec![s2]), s0.get_datatype()); + let nl1 = ScalarValue::List(Some(vec![s2]), Box::new(s0.get_datatype())); - let nl2 = ScalarValue::List(Some(vec![s1]), s0.get_datatype()); + let nl2 = ScalarValue::List(Some(vec![s1]), Box::new(s0.get_datatype())); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); @@ -1064,33 +1082,48 @@ mod tests { ScalarValue::from(2i32), ScalarValue::from(3i32), ]), - DataType::Int32, + Box::new(DataType::Int32), ), ScalarValue::List( Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, + Box::new(DataType::Int32), ), ]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let l2 = ScalarValue::List( Some(vec![ - ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32), + ScalarValue::List( + Some(vec![ScalarValue::from(6i32)]), + Box::new(DataType::Int32), + ), ScalarValue::List( Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, + Box::new(DataType::Int32), ), ]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let l3 = ScalarValue::List( Some(vec![ScalarValue::List( Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, + Box::new(DataType::Int32), )]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 22604a793900..cb02c3ed07ad 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -2431,7 +2431,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { values.iter().map(|e| e.get_datatype()).collect(); if data_types.is_empty() { - Ok(Expr::Literal(ScalarValue::List(None, DataType::Utf8))) + Ok(Expr::Literal(ScalarValue::List( + None, + Box::new(DataType::Utf8), + ))) } else if data_types.len() > 1 { Err(DataFusionError::NotImplemented(format!( "Arrays with different types are not supported: {:?}", @@ -2440,7 +2443,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let data_type = values[0].get_datatype(); - Ok(Expr::Literal(ScalarValue::List(Some(values), data_type))) + Ok(Expr::Literal(ScalarValue::List( + Some(values), + Box::new(data_type), + ))) } } } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index eb8931d52af0..2a40c8bad819 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -150,7 +150,7 @@ impl Accumulator for ArrayAggAccumulator { fn evaluate(&self) -> Result { Ok(ScalarValue::List( Some(self.values.clone()), - self.datatype.clone(), + Box::new(self.datatype.clone()), )) } } @@ -179,7 +179,7 @@ mod tests { ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5)), ]), - DataType::Int32, + Box::new(DataType::Int32), ); generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) @@ -195,38 +195,57 @@ mod tests { ScalarValue::from(2i32), ScalarValue::from(3i32), ]), - DataType::Int32, + Box::new(DataType::Int32), ), ScalarValue::List( Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, + Box::new(DataType::Int32), ), ]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let l2 = ScalarValue::List( Some(vec![ - ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32), + ScalarValue::List( + Some(vec![ScalarValue::from(6i32)]), + Box::new(DataType::Int32), + ), ScalarValue::List( Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, + Box::new(DataType::Int32), ), ]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let l3 = ScalarValue::List( Some(vec![ScalarValue::List( Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, + Box::new(DataType::Int32), )]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let list = ScalarValue::List( Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index c78f6fff4347..9448683c0d39 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -122,7 +122,7 @@ impl Accumulator for DistinctArrayAggAccumulator { fn state(&self) -> Result> { Ok(vec![ScalarValue::List( Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), + Box::new(self.datatype.clone()), )]) } @@ -153,7 +153,7 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&self) -> Result { Ok(ScalarValue::List( Some(self.values.clone().into_iter().collect()), - self.datatype.clone(), + Box::new(self.datatype.clone()), )) } } @@ -214,7 +214,7 @@ mod tests { ScalarValue::Int32(Some(4)), ScalarValue::Int32(Some(5)), ]), - DataType::Int32, + Box::new(DataType::Int32), ); check_distinct_array_agg(col, out, DataType::Int32) @@ -231,40 +231,59 @@ mod tests { ScalarValue::from(2i32), ScalarValue::from(3i32), ]), - DataType::Int32, + Box::new(DataType::Int32), ), ScalarValue::List( Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]), - DataType::Int32, + Box::new(DataType::Int32), ), ]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); // [[6], [7, 8]] let l2 = ScalarValue::List( Some(vec![ - ScalarValue::List(Some(vec![ScalarValue::from(6i32)]), DataType::Int32), + ScalarValue::List( + Some(vec![ScalarValue::from(6i32)]), + Box::new(DataType::Int32), + ), ScalarValue::List( Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]), - DataType::Int32, + Box::new(DataType::Int32), ), ]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); // [[9]] let l3 = ScalarValue::List( Some(vec![ScalarValue::List( Some(vec![ScalarValue::from(9i32)]), - DataType::Int32, + Box::new(DataType::Int32), )]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); let list = ScalarValue::List( Some(vec![l1.clone(), l2.clone(), l3.clone()]), - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), ); // Duplicate l1 in the input array and check that it is deduped in the output. diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 61ad533d0de7..f1e3afe6b041 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -193,7 +193,7 @@ impl Accumulator for DistinctCountAccumulator { .map(|state_data_type| { let values = Box::new(Vec::new()); let data_type = Box::new(state_data_type.clone()); - ScalarValue::List(Some(*values), *data_type) + ScalarValue::List(Some(*values), data_type) }) .collect::>(); @@ -241,7 +241,7 @@ mod tests { macro_rules! state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ match $LIST { - ScalarValue::List(_, data_type) => match data_type { + ScalarValue::List(_, data_type) => match data_type.as_ref() { &DataType::$DATA_TYPE => (), _ => panic!("Unexpected DataType for list"), }, diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 862f634463c2..2b887c1fe584 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -133,13 +133,11 @@ impl Accumulator for DistinctSumAccumulator { // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set let state_out = { let mut distinct_values = Vec::new(); + let data_type = Box::new(self.data_type.clone()); self.hash_values .iter() .for_each(|distinct_value| distinct_values.push(distinct_value.clone())); - vec![ScalarValue::List( - Some(distinct_values), - self.data_type.clone(), - )] + vec![ScalarValue::List(Some(distinct_values), data_type)] }; Ok(state_out) } diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs index 4f306e3ffa03..14d73b5bc018 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr/src/aggregate/tdigest.rs @@ -626,7 +626,7 @@ impl TDigest { ScalarValue::Float64(Some(self.count.into_inner())), ScalarValue::Float64(Some(self.max.into_inner())), ScalarValue::Float64(Some(self.min.into_inner())), - ScalarValue::List(Some(centroids), DataType::Float64), + ScalarValue::List(Some(centroids), Box::new(DataType::Float64)), ] } @@ -647,7 +647,7 @@ impl TDigest { }; let centroids: Vec<_> = match &state[5] { - ScalarValue::List(Some(c), d) if *d == DataType::Float64 => c + ScalarValue::List(Some(c), d) if **d == DataType::Float64 => c .chunks(2) .map(|v| Centroid::new(cast_scalar_f64!(v[0]), cast_scalar_f64!(v[1]))) .collect(), diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 2e5141893712..0bb767a347cd 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -601,7 +601,7 @@ impl TryFrom<&protobuf::scalar_value::Value> for ScalarValue { Value::Float64Value(v) => ScalarValue::Float64(Some(*v)), Value::Date32Value(v) => ScalarValue::Date32(Some(*v)), Value::ListValue(v) => v.try_into()?, - Value::NullListValue(v) => ScalarValue::List(None, v.try_into()?), + Value::NullListValue(v) => ScalarValue::List(None, Box::new(v.try_into()?)), Value::NullValue(null_enum) => { let primitive = PrimitiveScalarType::try_from(null_enum)?; (&primitive).try_into()? @@ -674,7 +674,10 @@ impl TryFrom<&protobuf::ScalarListValue> for ScalarValue { typechecked_scalar_value_conversion(value, leaf_scalar_type) }) .collect::, _>>()?; - ScalarValue::List(Some(typechecked_values), leaf_scalar_type.into()) + ScalarValue::List( + Some(typechecked_values), + Box::new(leaf_scalar_type.into()), + ) } Datatype::List(list_type) => { let protobuf::ScalarListType { @@ -709,7 +712,7 @@ impl TryFrom<&protobuf::ScalarListValue> for ScalarValue { 0 => None, _ => Some(typechecked_values), }, - (list_type).try_into()?, + Box::new((list_type).try_into()?), ) } }; @@ -851,6 +854,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } = &scalar_list; let scalar_type = opt_scalar_type.as_ref().required("datatype")?; + let scalar_type = Box::new(scalar_type); let typechecked_values: Vec = values .iter() @@ -861,7 +865,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } Value::NullListValue(v) => { let datatype = v.datatype.as_ref().required("datatype")?; - Self::List(None, datatype) + Self::List(None, Box::new(datatype)) } Value::NullValue(v) => { let null_type_enum = protobuf::PrimitiveScalarType::try_from(v)?; diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index d329085d8037..24809b9aa6f4 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -71,20 +71,24 @@ mod roundtrip_tests { ScalarValue::Int16(None), ScalarValue::Float32(Some(32.0)), ]), - DataType::List(new_box_field("item", DataType::Int16, true)), + Box::new(DataType::List(new_box_field("item", DataType::Int16, true))), ), ScalarValue::List( Some(vec![ ScalarValue::Float32(None), ScalarValue::Float32(Some(32.0)), ]), - DataType::List(new_box_field("item", DataType::Int16, true)), + Box::new(DataType::List(new_box_field("item", DataType::Int16, true))), ), ScalarValue::List( Some(vec![ ScalarValue::List( None, - DataType::List(new_box_field("level2", DataType::Float32, true)), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), ScalarValue::List( Some(vec![ @@ -94,22 +98,26 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - DataType::List(new_box_field("level2", DataType::Float32, true)), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), ScalarValue::List( None, - DataType::List(new_box_field( + Box::new(DataType::List(new_box_field( "lists are typed inconsistently", DataType::Int16, true, - )), + ))), ), ]), - DataType::List(new_box_field( + Box::new(DataType::List(new_box_field( "level1", DataType::List(new_box_field("level2", DataType::Float32, true)), true, - )), + ))), ), ]; @@ -142,7 +150,7 @@ mod roundtrip_tests { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::List(None, DataType::Boolean), + ScalarValue::List(None, Box::new(DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -206,13 +214,21 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - DataType::List(new_box_field("level1", DataType::Float32, true)), + Box::new(DataType::List(new_box_field( + "level1", + DataType::Float32, + true, + ))), ), ScalarValue::List( Some(vec![ ScalarValue::List( None, - DataType::List(new_box_field("level2", DataType::Float32, true)), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), ScalarValue::List( Some(vec![ @@ -222,14 +238,18 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - DataType::List(new_box_field("level2", DataType::Float32, true)), + Box::new(DataType::List(new_box_field( + "level2", + DataType::Float32, + true, + ))), ), ]), - DataType::List(new_box_field( + Box::new(DataType::List(new_box_field( "level1", DataType::List(new_box_field("level2", DataType::Float32, true)), true, - )), + ))), ), ]; diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 0cf8d7e7763e..5970a3c30a5c 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -801,13 +801,13 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::ListValue( protobuf::ScalarListValue { - datatype: Some(datatype.try_into()?), + datatype: Some(datatype.as_ref().try_into()?), values: Vec::new(), }, )), } } else { - let scalar_type = match datatype { + let scalar_type = match datatype.as_ref() { DataType::List(field) => field.as_ref().data_type(), _ => todo!("Proper error handling"), }; @@ -819,7 +819,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { scalar::ScalarValue::List(_, list_type), DataType::List(field), ) => { - if let DataType::List(list_field) = list_type { + if let DataType::List(list_field) = + list_type.as_ref() + { let scalar_datatype = field.data_type(); let list_datatype = list_field.data_type(); if std::mem::discriminant(list_datatype) @@ -893,7 +895,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::ListValue( protobuf::ScalarListValue { - datatype: Some(datatype.try_into()?), + datatype: Some(datatype.as_ref().try_into()?), values: type_checked_values, }, )), @@ -902,7 +904,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { } None => protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::NullListValue( - datatype.try_into()?, + datatype.as_ref().try_into()?, )), }, } From b9ea1a5b7bef5386acaaa641f918cd839428bd2f Mon Sep 17 00:00:00 2001 From: comphead Date: Tue, 17 May 2022 09:52:16 -0700 Subject: [PATCH 7/7] fixing test --- datafusion/common/src/scalar.rs | 16 ++++++---------- datafusion/core/src/scalar.rs | 2 +- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 6577a76ba9fd..758cc18bb1ca 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -928,21 +928,17 @@ impl ScalarValue { match values { Some(values) => { // Push value for each field - for (i, v) in - values.iter().enumerate().take(columns.len()) - { - let column = columns.get_mut(i).unwrap(); - column.push(v.clone()); + for (column, value) in columns.iter_mut().zip(values) { + column.push(value.clone()); } } None => { // Push NULL of the appropriate type for each field - for (i, f) in - fields.iter().enumerate().take(columns.len()) + for (column, field) in + columns.iter_mut().zip(fields.as_ref()) { - let dtype = f.data_type(); - let column = columns.get_mut(i).unwrap(); - column.push(ScalarValue::try_from(dtype)?); + column + .push(ScalarValue::try_from(field.data_type())?); } } }; diff --git a/datafusion/core/src/scalar.rs b/datafusion/core/src/scalar.rs index fb4d7030c1b9..d3094597bbb8 100644 --- a/datafusion/core/src/scalar.rs +++ b/datafusion/core/src/scalar.rs @@ -644,7 +644,7 @@ mod tests { assert_eq!( List( Some(vec![Int64(Some(1)), Int64(Some(5))]), - Box::new(DataType::Int32), + Box::new(DataType::Int64), ) .partial_cmp(&List( Some(vec![Int32(Some(1)), Int32(Some(5))]),