Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS Float64) - #lineitem.l_discount)]]\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) - CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
Expand Down
35 changes: 20 additions & 15 deletions datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("c_address", DataType::Utf8, false),
Field::new("c_nationkey", DataType::Int64, false),
Field::new("c_phone", DataType::Utf8, false),
Field::new("c_acctbal", DataType::Float64, false),
Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
Field::new("c_mktsegment", DataType::Utf8, false),
Field::new("c_comment", DataType::Utf8, false),
]),
Expand All @@ -462,7 +462,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("o_orderkey", DataType::Int64, false),
Field::new("o_custkey", DataType::Int64, false),
Field::new("o_orderstatus", DataType::Utf8, false),
Field::new("o_totalprice", DataType::Float64, false),
Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
Field::new("o_orderdate", DataType::Date32, false),
Field::new("o_orderpriority", DataType::Utf8, false),
Field::new("o_clerk", DataType::Utf8, false),
Expand All @@ -475,10 +475,10 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("l_partkey", DataType::Int64, false),
Field::new("l_suppkey", DataType::Int64, false),
Field::new("l_linenumber", DataType::Int32, false),
Field::new("l_quantity", DataType::Float64, false),
Field::new("l_extendedprice", DataType::Float64, false),
Field::new("l_discount", DataType::Float64, false),
Field::new("l_tax", DataType::Float64, false),
Field::new("l_quantity", DataType::Decimal128(15, 2), false),
Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
Field::new("l_discount", DataType::Decimal128(15, 2), false),
Field::new("l_tax", DataType::Decimal128(15, 2), false),
Field::new("l_returnflag", DataType::Utf8, false),
Field::new("l_linestatus", DataType::Utf8, false),
Field::new("l_shipdate", DataType::Date32, false),
Expand All @@ -502,15 +502,15 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("s_address", DataType::Utf8, false),
Field::new("s_nationkey", DataType::Int64, false),
Field::new("s_phone", DataType::Utf8, false),
Field::new("s_acctbal", DataType::Float64, false),
Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
Field::new("s_comment", DataType::Utf8, false),
]),

"partsupp" => Schema::new(vec![
Field::new("ps_partkey", DataType::Int64, false),
Field::new("ps_suppkey", DataType::Int64, false),
Field::new("ps_availqty", DataType::Int32, false),
Field::new("ps_supplycost", DataType::Float64, false),
Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
Field::new("ps_comment", DataType::Utf8, false),
]),

Expand All @@ -522,7 +522,7 @@ fn get_tpch_table_schema(table: &str) -> Schema {
Field::new("p_type", DataType::Utf8, false),
Field::new("p_size", DataType::Int32, false),
Field::new("p_container", DataType::Utf8, false),
Field::new("p_retailprice", DataType::Float64, false),
Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
Field::new("p_comment", DataType::Utf8, false),
]),

Expand Down Expand Up @@ -573,9 +573,9 @@ async fn register_tpch_csv_data(
DataType::Int64 => {
cols.push(Box::new(Int64Builder::with_capacity(records.len())))
}
DataType::Float64 => {
cols.push(Box::new(Float64Builder::with_capacity(records.len())))
}
DataType::Decimal128(p, s) => cols.push(Box::new(
Decimal128Builder::with_capacity(records.len(), *p, *s),
)),
_ => {
let msg = format!("Not implemented: {}", field.data_type());
Err(DataFusionError::Plan(msg))?
Expand Down Expand Up @@ -606,9 +606,14 @@ async fn register_tpch_csv_data(
let sb = col.as_any_mut().downcast_mut::<Int64Builder>().unwrap();
sb.append_value(val.trim().parse().unwrap());
}
DataType::Float64 => {
let sb = col.as_any_mut().downcast_mut::<Float64Builder>().unwrap();
sb.append_value(val.trim().parse().unwrap());
DataType::Decimal128(_, _) => {
let sb = col
.as_any_mut()
.downcast_mut::<Decimal128Builder>()
.unwrap();
let val = val.trim().replace('.', "");
let value_i128 = val.parse::<i128>().unwrap();
sb.append_value(value_i128)?;
}
_ => Err(DataFusionError::Plan(format!(
"Not implemented: {}",
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,10 @@ async fn multiple_or_predicates() -> Result<()> {
let expected =vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[#part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
];
Expand Down
27 changes: 14 additions & 13 deletions datafusion/core/tests/sql/subqueries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ where c_acctbal < (
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
Projection: #customer.c_custkey
Filter: #customer.c_acctbal < #__sq_2.__value
Filter: CAST(#customer.c_acctbal AS Decimal128(25, 2)) < #__sq_2.__value
Inner Join: #customer.c_custkey = #__sq_2.o_custkey
TableScan: customer projection=[c_custkey, c_acctbal]
Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, alias=__sq_2
Aggregate: groupBy=[[#orders.o_custkey]], aggr=[[SUM(#orders.o_totalprice)]]
Filter: #orders.o_totalprice < #__sq_1.__value
Filter: CAST(#orders.o_totalprice AS Decimal128(25, 2)) < #__sq_1.__value
Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]
Projection: #lineitem.l_orderkey, #SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
Expand Down Expand Up @@ -229,6 +229,7 @@ async fn tpch_q4_correlated() -> Result<()> {
Ok(())
}

#[ignore] // https://github.com/apache/arrow-datafusion/issues/3437
#[tokio::test]
async fn tpch_q17_correlated() -> Result<()> {
let parts = r#"63700,goldenrod lavender spring chocolate lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly ironi
Expand Down Expand Up @@ -260,15 +261,15 @@ async fn tpch_q17_correlated() -> Result<()> {
.map_err(|e| format!("{:?} at {}", e, "error"))
.unwrap();
let actual = format!("{}", plan.display_indent());
let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7) AS avg_yearly
let expected = r#"Projection: CAST(#SUM(lineitem.l_extendedprice) AS Decimal128(38, 33)) / CAST(Float64(7) AS Decimal128(38, 33)) AS avg_yearly
Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]]
Filter: #lineitem.l_quantity < #__sq_1.__value
Filter: CAST(#lineitem.l_quantity AS Decimal128(38, 21)) < #__sq_1.__value
Inner Join: #part.p_partkey = #__sq_1.l_partkey
Inner Join: #lineitem.l_partkey = #part.p_partkey
TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice]
Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = Utf8("MED BOX")
TableScan: part projection=[p_partkey, p_brand, p_container]
Projection: #lineitem.l_partkey, Float64(0.2) * #AVG(lineitem.l_quantity) AS __value, alias=__sq_1
Projection: #lineitem.l_partkey, CAST(Float64(0.2) AS Decimal128(38, 21)) * CAST(#AVG(lineitem.l_quantity) AS Decimal128(38, 21)) AS __value, alias=__sq_1
Aggregate: groupBy=[[#lineitem.l_partkey]], aggr=[[AVG(#lineitem.l_quantity)]]
TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice]"#
.to_string();
Expand Down Expand Up @@ -328,14 +329,14 @@ order by s_name;
Filter: #nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("CANADA")]
Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
Filter: CAST(#partsupp.ps_availqty AS Float64) > #__sq_3.__value
Filter: CAST(#partsupp.ps_availqty AS Decimal128(38, 17)) > #__sq_3.__value
Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, #partsupp.ps_suppkey = #__sq_3.l_suppkey
Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
Projection: #part.p_partkey AS p_partkey, alias=__sq_1
Filter: #part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")]
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS __value, alias=__sq_3
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS Decimal128(38, 17)) AS __value, alias=__sq_3
Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]]
Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
Expand Down Expand Up @@ -384,15 +385,15 @@ order by cntrycode;"#;
Aggregate: groupBy=[[#custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), SUM(#custsale.c_acctbal)]]
Projection: #custsale.cntrycode, #custsale.c_acctbal, alias=custsale
Projection: substr(#customer.c_phone, Int64(1), Int64(2)) AS cntrycode, #customer.c_acctbal, alias=custsale
Filter: #customer.c_acctbal > #__sq_1.__value
Filter: CAST(#customer.c_acctbal AS Decimal128(19, 6)) > #__sq_1.__value
CrossJoin:
Anti Join: #customer.c_custkey = #orders.o_custkey
Filter: substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]
TableScan: orders projection=[o_custkey]
Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
Filter: #customer.c_acctbal > Float64(0) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]"#
.to_string();
assert_eq!(actual, expected);
Expand Down Expand Up @@ -443,17 +444,17 @@ order by value desc;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: #value DESC NULLS FIRST
Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) > #__sq_1.__value
Filter: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) > #__sq_1.__value
CrossJoin:
Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]]
Aggregate: groupBy=[[#partsupp.ps_partkey]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: #nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name], partial_filters=[#nation.n_name = Utf8("GERMANY")]
Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]]
Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
Expand Down