Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 44 additions & 52 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ pub enum ScalarValue {
Binary(Option<Vec<u8>>),
/// large binary
LargeBinary(Option<Vec<u8>>),
/// list of nested ScalarValue (boxed to reduce size_of(ScalarValue))
#[allow(clippy::box_collection)]
List(Option<Box<Vec<ScalarValue>>>, Box<DataType>),
/// list of nested ScalarValue
List(Option<Vec<ScalarValue>>, Box<DataType>),
/// Date stored as a signed 32bit int
Date32(Option<i32>),
/// Date stored as a signed 64bit int
Expand All @@ -94,9 +93,8 @@ pub enum ScalarValue {
IntervalDayTime(Option<i64>),
/// Interval with MonthDayNano unit
IntervalMonthDayNano(Option<i128>),
/// struct of nested ScalarValue (boxed to reduce size_of(ScalarValue))
#[allow(clippy::box_collection)]
Struct(Option<Box<Vec<ScalarValue>>>, Box<Vec<Field>>),
/// struct of nested ScalarValue
Struct(Option<Vec<ScalarValue>>, Box<Vec<Field>>),
}

// manual implementation of `PartialEq` that uses OrderedFloat to
Expand Down Expand Up @@ -400,7 +398,7 @@ macro_rules! build_list {
)
}
Some(values) => {
build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values.as_ref(), $SIZE)
build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values, $SIZE)
}
}
}};
Expand All @@ -420,37 +418,34 @@ macro_rules! build_timestamp_list {
$SIZE,
)
}
Some(values) => {
let values = values.as_ref();
match $TIME_UNIT {
TimeUnit::Second => {
build_values_list_tz!(
TimestampSecondBuilder,
TimestampSecond,
values,
$SIZE
)
}
TimeUnit::Microsecond => build_values_list_tz!(
TimestampMillisecondBuilder,
TimestampMillisecond,
values,
$SIZE
),
TimeUnit::Millisecond => build_values_list_tz!(
TimestampMicrosecondBuilder,
TimestampMicrosecond,
values,
$SIZE
),
TimeUnit::Nanosecond => build_values_list_tz!(
TimestampNanosecondBuilder,
TimestampNanosecond,
Some(values) => match $TIME_UNIT {
TimeUnit::Second => {
build_values_list_tz!(
TimestampSecondBuilder,
TimestampSecond,
values,
$SIZE
),
)
}
}
TimeUnit::Microsecond => build_values_list_tz!(
TimestampMillisecondBuilder,
TimestampMillisecond,
values,
$SIZE
),
TimeUnit::Millisecond => build_values_list_tz!(
TimestampMicrosecondBuilder,
TimestampMicrosecond,
values,
$SIZE
),
TimeUnit::Nanosecond => build_values_list_tz!(
TimestampNanosecondBuilder,
TimestampNanosecond,
values,
$SIZE
),
},
}
}};
}
Expand Down Expand Up @@ -804,7 +799,6 @@ impl ScalarValue {
for scalar in scalars.into_iter() {
match scalar {
ScalarValue::List(Some(xs), _) => {
let xs = *xs;
for s in xs {
match s {
ScalarValue::$SCALAR_TY(Some(val)) => {
Expand Down Expand Up @@ -934,17 +928,17 @@ impl ScalarValue {
match values {
Some(values) => {
// Push value for each field
for c in 0..columns.len() {
let column = columns.get_mut(c).unwrap();
column.push(values[c].clone());
for (column, value) in columns.iter_mut().zip(values) {
column.push(value.clone());
}
}
None => {
// Push NULL of the appropriate type for each field
for c in 0..columns.len() {
let dtype = fields[c].data_type();
let column = columns.get_mut(c).unwrap();
column.push(ScalarValue::try_from(dtype)?);
for (column, field) in
columns.iter_mut().zip(fields.as_ref())
{
column
.push(ScalarValue::try_from(field.data_type())?);
}
}
};
Expand Down Expand Up @@ -1022,7 +1016,7 @@ impl ScalarValue {
if let ScalarValue::List(values, _) = scalar {
match values {
Some(values) => {
let element_array = ScalarValue::iter_to_array(*values)?;
let element_array = ScalarValue::iter_to_array(values)?;

// Add new offset index
flat_len += element_array.len() as i32;
Expand Down Expand Up @@ -1327,9 +1321,8 @@ impl ScalarValue {
Some(scalar_vec)
}
};
let value = value.map(Box::new);
let data_type = Box::new(nested_type.data_type().clone());
ScalarValue::List(value, data_type)
let data_type = nested_type.data_type().clone();
ScalarValue::List(value, Box::new(data_type))
}
DataType::Date32 => {
typed_cast!(array, index, Date32Array, Date32)
Expand Down Expand Up @@ -1413,7 +1406,7 @@ impl ScalarValue {
let col_scalar = ScalarValue::try_from_array(col_array, index)?;
field_values.push(col_scalar);
}
Self::Struct(Some(Box::new(field_values)), Box::new(fields.clone()))
Self::Struct(Some(field_values), Box::new(fields.clone()))
}
DataType::FixedSizeList(nested_type, _len) => {
let list_array =
Expand All @@ -1428,9 +1421,8 @@ impl ScalarValue {
Some(scalar_vec)
}
};
let value = value.map(Box::new);
let data_type = Box::new(nested_type.data_type().clone());
ScalarValue::List(value, data_type)
let data_type = nested_type.data_type().clone();
ScalarValue::List(value, Box::new(data_type))
}
other => {
return Err(DataFusionError::NotImplemented(format!(
Expand Down Expand Up @@ -1635,7 +1627,7 @@ impl From<Vec<(&str, ScalarValue)>> for ScalarValue {
})
.unzip();

Self::Struct(Some(Box::new(scalars)), Box::new(fields))
Self::Struct(Some(scalars), Box::new(fields))
}
}

Expand Down
73 changes: 31 additions & 42 deletions datafusion/core/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ mod tests {
#[test]
fn scalar_list_to_array() {
let list_array_ref = ScalarValue::List(
Some(Box::new(vec![
Some(vec![
ScalarValue::UInt64(Some(100)),
ScalarValue::UInt64(None),
ScalarValue::UInt64(Some(101)),
])),
]),
Box::new(DataType::UInt64),
)
.to_array();
Expand Down Expand Up @@ -606,35 +606,35 @@ mod tests {

assert_eq!(
List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
Some(vec![Int32(Some(1)), Int32(Some(5))]),
Box::new(DataType::Int32),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
Some(vec![Int32(Some(1)), Int32(Some(5))]),
Box::new(DataType::Int32),
)),
Some(Ordering::Equal)
);

assert_eq!(
List(
Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
Some(vec![Int32(Some(10)), Int32(Some(5))]),
Box::new(DataType::Int32),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
Some(vec![Int32(Some(1)), Int32(Some(5))]),
Box::new(DataType::Int32),
)),
Some(Ordering::Greater)
);

assert_eq!(
List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
Some(vec![Int32(Some(1)), Int32(Some(5))]),
Box::new(DataType::Int32),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
Some(vec![Int32(Some(10)), Int32(Some(5))]),
Box::new(DataType::Int32),
)),
Some(Ordering::Less)
Expand All @@ -643,11 +643,11 @@ mod tests {
// For different data types, `partial_cmp` returns None.
assert_eq!(
List(
Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])),
Some(vec![Int64(Some(1)), Int64(Some(5))]),
Box::new(DataType::Int64),
)
.partial_cmp(&List(
Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
Some(vec![Int32(Some(1)), Int32(Some(5))]),
Box::new(DataType::Int32),
)),
None
Expand Down Expand Up @@ -694,15 +694,15 @@ mod tests {
);

let scalar = ScalarValue::Struct(
Some(Box::new(vec![
Some(vec![
ScalarValue::Int32(Some(23)),
ScalarValue::Boolean(Some(false)),
ScalarValue::Utf8(Some("Hello".to_string())),
ScalarValue::from(vec![
("e", ScalarValue::from(2i16)),
("f", ScalarValue::from(3i64)),
]),
])),
]),
Box::new(vec![
field_a.clone(),
field_b.clone(),
Expand Down Expand Up @@ -866,24 +866,21 @@ mod tests {

// Define primitive list scalars
let l0 = ScalarValue::List(
Some(Box::new(vec![
Some(vec![
ScalarValue::from(1i32),
ScalarValue::from(2i32),
ScalarValue::from(3i32),
])),
]),
Box::new(DataType::Int32),
);

let l1 = ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(4i32),
ScalarValue::from(5i32),
])),
Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]),
Box::new(DataType::Int32),
);

let l2 = ScalarValue::List(
Some(Box::new(vec![ScalarValue::from(6i32)])),
Some(vec![ScalarValue::from(6i32)]),
Box::new(DataType::Int32),
);

Expand Down Expand Up @@ -928,15 +925,13 @@ mod tests {

// Define list-of-structs scalars
let nl0 = ScalarValue::List(
Some(Box::new(vec![s0.clone(), s1.clone()])),
Some(vec![s0.clone(), s1.clone()]),
Box::new(s0.get_datatype()),
);

let nl1 =
ScalarValue::List(Some(Box::new(vec![s2])), Box::new(s0.get_datatype()));
let nl1 = ScalarValue::List(Some(vec![s2]), Box::new(s0.get_datatype()));

let nl2 =
ScalarValue::List(Some(Box::new(vec![s1])), Box::new(s0.get_datatype()));
let nl2 = ScalarValue::List(Some(vec![s1]), Box::new(s0.get_datatype()));

// iter_to_array for list-of-struct
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
Expand Down Expand Up @@ -1080,23 +1075,20 @@ mod tests {
fn test_nested_lists() {
// Define inner list scalars
let l1 = ScalarValue::List(
Some(Box::new(vec![
Some(vec![
ScalarValue::List(
Some(Box::new(vec![
Some(vec![
ScalarValue::from(1i32),
ScalarValue::from(2i32),
ScalarValue::from(3i32),
])),
]),
Box::new(DataType::Int32),
),
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(4i32),
ScalarValue::from(5i32),
])),
Some(vec![ScalarValue::from(4i32), ScalarValue::from(5i32)]),
Box::new(DataType::Int32),
),
])),
]),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
Expand All @@ -1105,19 +1097,16 @@ mod tests {
);

let l2 = ScalarValue::List(
Some(Box::new(vec![
Some(vec![
ScalarValue::List(
Some(Box::new(vec![ScalarValue::from(6i32)])),
Some(vec![ScalarValue::from(6i32)]),
Box::new(DataType::Int32),
),
ScalarValue::List(
Some(Box::new(vec![
ScalarValue::from(7i32),
ScalarValue::from(8i32),
])),
Some(vec![ScalarValue::from(7i32), ScalarValue::from(8i32)]),
Box::new(DataType::Int32),
),
])),
]),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
Expand All @@ -1126,10 +1115,10 @@ mod tests {
);

let l3 = ScalarValue::List(
Some(Box::new(vec![ScalarValue::List(
Some(Box::new(vec![ScalarValue::from(9i32)])),
Some(vec![ScalarValue::List(
Some(vec![ScalarValue::from(9i32)]),
Box::new(DataType::Int32),
)])),
)]),
Box::new(DataType::List(Box::new(Field::new(
"item",
DataType::Int32,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2444,7 +2444,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let data_type = values[0].get_datatype();

Ok(Expr::Literal(ScalarValue::List(
Some(Box::new(values)),
Some(values),
Box::new(data_type),
)))
}
Expand Down
Loading