Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-4957: [Rust] [DataFusion] Re-implement get_supertype #7253

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/datafusion/src/execution/context.rs
Expand Up @@ -658,7 +658,7 @@ mod tests {
let mut ctx = create_ctx(&tmp_dir, partition_count)?;

let logical_plan =
ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?;
ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE CAST(c1 AS double) > 0 AND CAST(c1 AS double) < 3")?;
let logical_plan = ctx.optimize(&logical_plan)?;

let physical_plan = ctx.create_physical_plan(&logical_plan, 1024)?;
Expand Down
17 changes: 11 additions & 6 deletions rust/datafusion/src/logicalplan.rs
Expand Up @@ -288,14 +288,14 @@ impl Expr {
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
Ok(self.clone())
} else if can_coerce_from(cast_to_type, &this_type) {
} else if cast_supported(cast_to_type, &this_type) {
Ok(Expr::Cast {
expr: Box::new(self.clone()),
data_type: cast_to_type.clone(),
})
} else {
Err(ExecutionError::General(format!(
"Cannot automatically convert {:?} to {:?}",
"Cannot cast from {:?} to {:?}",
this_type, cast_to_type
)))
}
Expand Down Expand Up @@ -723,23 +723,28 @@ impl fmt::Debug for LogicalPlan {
}

/// Verify a given type cast can be performed
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
pub fn cast_supported(type_into: &DataType, type_from: &DataType) -> bool {
use self::DataType::*;

if type_from == type_into {
return true;
}

match type_into {
Int8 => match type_from {
Int8 => true,
_ => false,
},
Int16 => match type_from {
Int8 | Int16 | UInt8 => true,
Int8 | Int16 => true,
_ => false,
},
Int32 => match type_from {
Int8 | Int16 | Int32 | UInt8 | UInt16 => true,
Int8 | Int16 | Int32 => true,
_ => false,
},
Int64 => match type_from {
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 => true,
Int8 | Int16 | Int32 | Int64 => true,
_ => false,
},
UInt8 => match type_from {
Expand Down
44 changes: 28 additions & 16 deletions rust/datafusion/src/optimizer/type_coercion.rs
Expand Up @@ -187,8 +187,10 @@ mod tests {
use crate::execution::context::ExecutionContext;
use crate::execution::physical_plan::csv::CsvReadOptions;
use crate::logicalplan::Expr::*;
use crate::logicalplan::{col, Operator};
use crate::logicalplan::{cast_supported, col, Operator};
use crate::optimizer::utils::get_supertype;
use crate::test::arrow_testdata_path;
use arrow::datatypes::DataType::*;
use arrow::datatypes::{DataType, Field, Schema};

#[test]
Expand All @@ -212,6 +214,30 @@ mod tests {
Ok(())
}

#[test]
fn test_type_matrix() -> Result<()> {
let types = vec![
Boolean, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32,
Float64, Utf8,
];

for from_type in &types {
for to_type in &types {
match get_supertype(from_type, to_type) {
Ok(t) => {
// swapping from and to should result in same supertype
assert_eq!(t, get_supertype(to_type, from_type)?);
// both from and to types should be coercable to the supertype
assert!(cast_supported(&t, &from_type));
assert!(cast_supported(&t, &to_type));
}
Err(_) => assert!(get_supertype(to_type, from_type).is_err()),
}
}
}
Ok(())
}

#[test]
fn test_add_i32_i64() {
binary_cast_test(
Expand Down Expand Up @@ -254,20 +280,6 @@ mod tests {
);
}

#[test]
fn test_add_u32_i64() {
binary_cast_test(
DataType::UInt32,
DataType::Int64,
"CAST(#0 AS Int64) Plus #1",
);
binary_cast_test(
DataType::Int64,
DataType::UInt32,
"#0 Plus CAST(#1 AS Int64)",
);
}

fn binary_cast_test(left_type: DataType, right_type: DataType, expected: &str) {
let schema = Schema::new(vec![
Field::new("c0", left_type, true),
Expand All @@ -285,6 +297,6 @@ mod tests {

let expr2 = rule.rewrite_expr(&expr, &schema).unwrap();

assert_eq!(expected, format!("{:?}", expr2));
assert_eq!(format!("{:?}", expr2), expected);
}
}
174 changes: 68 additions & 106 deletions rust/datafusion/src/optimizer/utils.rs
Expand Up @@ -131,116 +131,78 @@ pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result<Vec<Fi

/// Given two datatypes, determine the supertype that both types can safely be cast to
pub fn get_supertype(l: &DataType, r: &DataType) -> Result<DataType> {
match _get_supertype(l, r) {
Some(dt) => Ok(dt),
None => _get_supertype(r, l).ok_or_else(|| {
ExecutionError::InternalError(format!(
"Failed to determine supertype of {:?} and {:?}",
l, r
))
}),
}
}

/// Given two datatypes, determine the supertype that both types can safely be cast to
fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (l, r) {
(UInt8, Int8) => Some(Int8),
(UInt8, Int16) => Some(Int16),
(UInt8, Int32) => Some(Int32),
(UInt8, Int64) => Some(Int64),

(UInt16, Int16) => Some(Int16),
(UInt16, Int32) => Some(Int32),
(UInt16, Int64) => Some(Int64),

(UInt32, Int32) => Some(Int32),
(UInt32, Int64) => Some(Int64),

(UInt64, Int64) => Some(Int64),

(Int8, UInt8) => Some(Int8),

(Int16, UInt8) => Some(Int16),
(Int16, UInt16) => Some(Int16),

(Int32, UInt8) => Some(Int32),
(Int32, UInt16) => Some(Int32),
(Int32, UInt32) => Some(Int32),

(Int64, UInt8) => Some(Int64),
(Int64, UInt16) => Some(Int64),
(Int64, UInt32) => Some(Int64),
(Int64, UInt64) => Some(Int64),

(UInt8, UInt8) => Some(UInt8),
(UInt8, UInt16) => Some(UInt16),
(UInt8, UInt32) => Some(UInt32),
(UInt8, UInt64) => Some(UInt64),
(UInt8, Float32) => Some(Float32),
(UInt8, Float64) => Some(Float64),

(UInt16, UInt8) => Some(UInt16),
(UInt16, UInt16) => Some(UInt16),
(UInt16, UInt32) => Some(UInt32),
(UInt16, UInt64) => Some(UInt64),
(UInt16, Float32) => Some(Float32),
(UInt16, Float64) => Some(Float64),

(UInt32, UInt8) => Some(UInt32),
(UInt32, UInt16) => Some(UInt32),
(UInt32, UInt32) => Some(UInt32),
(UInt32, UInt64) => Some(UInt64),
(UInt32, Float32) => Some(Float32),
(UInt32, Float64) => Some(Float64),

(UInt64, UInt8) => Some(UInt64),
(UInt64, UInt16) => Some(UInt64),
(UInt64, UInt32) => Some(UInt64),
(UInt64, UInt64) => Some(UInt64),
(UInt64, Float32) => Some(Float32),
(UInt64, Float64) => Some(Float64),

(Int8, Int8) => Some(Int8),
(Int8, Int16) => Some(Int16),
(Int8, Int32) => Some(Int32),
(Int8, Int64) => Some(Int64),
(Int8, Float32) => Some(Float32),
(Int8, Float64) => Some(Float64),

(Int16, Int8) => Some(Int16),
(Int16, Int16) => Some(Int16),
(Int16, Int32) => Some(Int32),
(Int16, Int64) => Some(Int64),
(Int16, Float32) => Some(Float32),
(Int16, Float64) => Some(Float64),

(Int32, Int8) => Some(Int32),
(Int32, Int16) => Some(Int32),
(Int32, Int32) => Some(Int32),
(Int32, Int64) => Some(Int64),
(Int32, Float32) => Some(Float32),
(Int32, Float64) => Some(Float64),

(Int64, Int8) => Some(Int64),
(Int64, Int16) => Some(Int64),
(Int64, Int32) => Some(Int64),
(Int64, Int64) => Some(Int64),
(Int64, Float32) => Some(Float32),
(Int64, Float64) => Some(Float64),

(Float32, Float32) => Some(Float32),
(Float32, Float64) => Some(Float64),
(Float64, Float32) => Some(Float64),
(Float64, Float64) => Some(Float64),

(Utf8, _) => Some(Utf8),
(_, Utf8) => Some(Utf8),

(Boolean, Boolean) => Some(Boolean),
if l == r {
return Ok(l.clone());
}

let super_type = match l {
UInt8 => match r {
UInt16 | UInt32 | UInt64 => Some(r.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
UInt16 => match r {
UInt8 => Some(l.clone()),
UInt32 | UInt64 => Some(r.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
UInt32 => match r {
UInt8 | UInt16 => Some(l.clone()),
UInt64 => Some(r.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
UInt64 => match r {
UInt8 | UInt16 | UInt32 => Some(l.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
Int8 => match r {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we support signed and unsigned types with the same width as well? For example, (int8, uint8) -> int16.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't (int8, uint8) -> int8 be the result instead of int16? I think signed and unsigned types of the same width should return the signed type, so I agree with your question.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With regard to int8 vs int16, I was thinking that 255 from uint8 doesn't fit into the range of int8, so an increase of the width to 16 bits is required.

Int16 | Int32 | Int64 => Some(r.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
Int16 => match r {
Int8 => Some(l.clone()),
Int32 | Int64 => Some(r.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
Int32 => match r {
Int8 | Int16 => Some(l.clone()),
Int64 => Some(r.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
Int64 => match r {
Int8 | Int16 | Int32 => Some(l.clone()),
Float32 | Float64 => Some(r.clone()),
_ => None,
},
Float32 => match r {
Int8 | Int16 | Int32 | Int64 => Some(Float32),
UInt8 | UInt16 | UInt32 | UInt64 => Some(Float32),
Float64 => Some(Float64),
_ => None,
},
Float64 => match r {
Int8 | Int16 | Int32 | Int64 => Some(Float64),
UInt8 | UInt16 | UInt32 | UInt64 => Some(Float64),
Float32 | Float64 => Some(Float64),
_ => None,
},
_ => None,
};

match super_type {
Some(dt) => Ok(dt),
None => Err(ExecutionError::InternalError(format!(
"Failed to determine supertype of {:?} and {:?}",
l, r
))),
}
}

Expand Down