Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Box ScalarValue:Lists, reduce size by half size #788

Merged
merged 2 commits into from
Jul 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value
}
protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
protobuf::scalar_value::Value::NullListValue(v) => {
ScalarValue::List(None, v.try_into()?)
ScalarValue::List(None, Box::new(v.try_into()?))
}
protobuf::scalar_value::Value::NullValue(null_enum) => {
PrimitiveScalarType::from_i32(*null_enum)
Expand Down Expand Up @@ -581,8 +581,8 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarListValue {
})
.collect::<Result<Vec<_>, _>>()?;
datafusion::scalar::ScalarValue::List(
Some(typechecked_values),
leaf_scalar_type.into(),
Some(Box::new(typechecked_values)),
Box::new(leaf_scalar_type.into()),
)
}
Datatype::List(list_type) => {
Expand Down Expand Up @@ -626,9 +626,9 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarListValue {
datafusion::scalar::ScalarValue::List(
match typechecked_values.len() {
0 => None,
_ => Some(typechecked_values),
_ => Some(Box::new(typechecked_values)),
},
list_type.try_into()?,
Box::new(list_type.try_into()?),
)
}
};
Expand Down Expand Up @@ -766,14 +766,16 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
.map(|val| val.try_into())
.collect::<Result<Vec<_>, _>>()?;
let scalar_type: DataType = pb_scalar_type.try_into()?;
ScalarValue::List(Some(typechecked_values), scalar_type)
let scalar_type = Box::new(scalar_type);
ScalarValue::List(Some(Box::new(typechecked_values)), scalar_type)
}
protobuf::scalar_value::Value::NullListValue(v) => {
let pb_datatype = v
.datatype
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: NullListValue message missing required field 'datatyp'"))?;
ScalarValue::List(None, pb_datatype.try_into()?)
let pb_datatype = Box::new(pb_datatype.try_into()?);
ScalarValue::List(None, pb_datatype)
}
protobuf::scalar_value::Value::NullValue(v) => {
let null_type_enum = protobuf::PrimitiveScalarType::from_i32(*v)
Expand Down
76 changes: 48 additions & 28 deletions ballista/rust/core/src/serde/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,49 +126,57 @@ mod roundtrip_tests {
let should_fail_on_seralize: Vec<ScalarValue> = vec![
//Should fail due to inconsistent types
ScalarValue::List(
Some(vec![
Some(Box::new(vec![
ScalarValue::Int16(None),
ScalarValue::Float32(Some(32.0)),
]),
DataType::List(new_box_field("item", DataType::Int16, true)),
])),
Box::new(DataType::List(new_box_field("item", DataType::Int16, true))),
),
ScalarValue::List(
Some(vec![
Some(Box::new(vec![
ScalarValue::Float32(None),
ScalarValue::Float32(Some(32.0)),
]),
DataType::List(new_box_field("item", DataType::Int16, true)),
])),
Box::new(DataType::List(new_box_field("item", DataType::Int16, true))),
),
ScalarValue::List(
Some(vec![
Some(Box::new(vec![
ScalarValue::List(
None,
DataType::List(new_box_field("level2", DataType::Float32, true)),
Box::new(DataType::List(new_box_field(
"level2",
DataType::Float32,
true,
))),
),
ScalarValue::List(
Some(vec![
Some(Box::new(vec![
ScalarValue::Float32(Some(-213.1)),
ScalarValue::Float32(None),
ScalarValue::Float32(Some(5.5)),
ScalarValue::Float32(Some(2.0)),
ScalarValue::Float32(Some(1.0)),
]),
DataType::List(new_box_field("level2", DataType::Float32, true)),
])),
Box::new(DataType::List(new_box_field(
"level2",
DataType::Float32,
true,
))),
),
ScalarValue::List(
None,
DataType::List(new_box_field(
Box::new(DataType::List(new_box_field(
"lists are typed inconsistently",
DataType::Int16,
true,
)),
))),
),
]),
DataType::List(new_box_field(
])),
Box::new(DataType::List(new_box_field(
"level1",
DataType::List(new_box_field("level2", DataType::Float32, true)),
true,
)),
))),
),
];

Expand Down Expand Up @@ -200,7 +208,7 @@ mod roundtrip_tests {
ScalarValue::UInt64(None),
ScalarValue::Utf8(None),
ScalarValue::LargeUtf8(None),
ScalarValue::List(None, DataType::Boolean),
ScalarValue::List(None, Box::new(DataType::Boolean)),
ScalarValue::Date32(None),
ScalarValue::TimestampMicrosecond(None),
ScalarValue::TimestampNanosecond(None),
Expand Down Expand Up @@ -248,37 +256,49 @@ mod roundtrip_tests {
ScalarValue::TimestampMicrosecond(Some(i64::MAX)),
ScalarValue::TimestampMicrosecond(None),
ScalarValue::List(
Some(vec![
Some(Box::new(vec![
ScalarValue::Float32(Some(-213.1)),
ScalarValue::Float32(None),
ScalarValue::Float32(Some(5.5)),
ScalarValue::Float32(Some(2.0)),
ScalarValue::Float32(Some(1.0)),
]),
DataType::List(new_box_field("level1", DataType::Float32, true)),
])),
Box::new(DataType::List(new_box_field(
"level1",
DataType::Float32,
true,
))),
),
ScalarValue::List(
Some(vec![
Some(Box::new(vec![
ScalarValue::List(
None,
DataType::List(new_box_field("level2", DataType::Float32, true)),
Box::new(DataType::List(new_box_field(
"level2",
DataType::Float32,
true,
))),
),
ScalarValue::List(
Some(vec![
Some(Box::new(vec![
ScalarValue::Float32(Some(-213.1)),
ScalarValue::Float32(None),
ScalarValue::Float32(Some(5.5)),
ScalarValue::Float32(Some(2.0)),
ScalarValue::Float32(Some(1.0)),
]),
DataType::List(new_box_field("level2", DataType::Float32, true)),
])),
Box::new(DataType::List(new_box_field(
"level2",
DataType::Float32,
true,
))),
),
]),
DataType::List(new_box_field(
])),
Box::new(DataType::List(new_box_field(
"level1",
DataType::List(new_box_field("level2", DataType::Float32, true)),
true,
)),
))),
),
];

Expand Down
33 changes: 20 additions & 13 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,30 +565,37 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::ListValue(
protobuf::ScalarListValue {
datatype: Some(datatype.try_into()?),
datatype: Some(datatype.as_ref().try_into()?),
values: Vec::new(),
},
)),
}
} else {
let scalar_type = match datatype {
let scalar_type = match datatype.as_ref() {
DataType::List(field) => field.as_ref().data_type(),
_ => todo!("Proper error handling"),
};
println!("Current scalar type for list: {:?}", scalar_type);
let type_checked_values: Vec<protobuf::ScalarValue> = values
.iter()
.map(|scalar| match (scalar, scalar_type) {
(scalar::ScalarValue::List(_, DataType::List(list_field)), DataType::List(field)) => {
let scalar_datatype = field.data_type();
let list_datatype = list_field.data_type();
if std::mem::discriminant(list_datatype) != std::mem::discriminant(scalar_datatype) {
return Err(proto_error(format!(
"Protobuf serialization error: Lists with inconsistent typing {:?} and {:?} found within list",
list_datatype, scalar_datatype
)));
(scalar::ScalarValue::List(_, list_type), DataType::List(field)) => {
if let DataType::List(list_field) = list_type.as_ref() {
let scalar_datatype = field.data_type();
let list_datatype = list_field.data_type();
if std::mem::discriminant(list_datatype) != std::mem::discriminant(scalar_datatype) {
return Err(proto_error(format!(
"Protobuf serialization error: Lists with inconsistent typing {:?} and {:?} found within list",
list_datatype, scalar_datatype
)));
}
scalar.try_into()
} else {
Err(proto_error(format!(
"Protobuf serialization error, {:?} was inconsistent with designated type {:?}",
scalar, datatype
)))
}
scalar.try_into()
}
(scalar::ScalarValue::Boolean(_), DataType::Boolean) => scalar.try_into(),
(scalar::ScalarValue::Float32(_), DataType::Float32) => scalar.try_into(),
Expand All @@ -612,7 +619,7 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::ListValue(
protobuf::ScalarListValue {
datatype: Some(datatype.try_into()?),
datatype: Some(datatype.as_ref().try_into()?),
values: type_checked_values,
},
)),
Expand All @@ -621,7 +628,7 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
}
None => protobuf::ScalarValue {
value: Some(protobuf::scalar_value::Value::NullListValue(
datatype.try_into()?,
datatype.as_ref().try_into()?,
)),
},
}
Expand Down
8 changes: 5 additions & 3 deletions datafusion/src/physical_plan/distinct_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ impl Accumulator for DistinctCountAccumulator {
.state_data_types
.iter()
.map(|state_data_type| {
ScalarValue::List(Some(Vec::new()), state_data_type.clone())
let values = Box::new(Vec::new());
let data_type = Box::new(state_data_type.clone());
ScalarValue::List(Some(values), data_type)
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -254,8 +256,8 @@ mod tests {
macro_rules! state_to_vec {
($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
match $LIST {
ScalarValue::List(_, data_type) => match data_type {
DataType::$DATA_TYPE => (),
ScalarValue::List(_, data_type) => match data_type.as_ref() {
&DataType::$DATA_TYPE => (),
_ => panic!("Unexpected DataType for list"),
},
_ => panic!("Expected a ScalarValue::List"),
Expand Down
Loading