Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,56 @@ mod tests {
Ok(ctx)
}

/// Builds the `SessionContext` used by `tpch_test_6`.
///
/// Registers `tests/testdata/tpch/lineitem.csv` under the table name
/// `FILENAME_PLACEHOLDER_0` through `register_csv`.
///
/// # Errors
///
/// Returns the first error reported by `register_csv`.
async fn create_context_tpch6() -> Result<SessionContext> {
    let ctx = SessionContext::new();

    // (table name, CSV path) pairs, registered in order.
    let tables = [("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/lineitem.csv")];
    for (name, csv_path) in tables {
        register_csv(&ctx, name, csv_path).await?;
    }

    Ok(ctx)
}
// TODO: contexts for TPC-H queries 7, 8, and 9 have not been added yet.

/// Builds the `SessionContext` used by `tpch_test_10`.
///
/// Registers the customer, orders, lineitem and nation CSV files under the
/// table names `FILENAME_PLACEHOLDER_0` through `FILENAME_PLACEHOLDER_3`, in
/// that order, through `register_csv`.
///
/// # Errors
///
/// Returns the first error reported by `register_csv`; later tables are not
/// registered once one fails.
async fn create_context_tpch10() -> Result<SessionContext> {
    let ctx = SessionContext::new();

    // (table name, CSV path) pairs, registered in order.
    let tables = [
        ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/customer.csv"),
        ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/orders.csv"),
        ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/lineitem.csv"),
        ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/nation.csv"),
    ];
    for (name, csv_path) in tables {
        register_csv(&ctx, name, csv_path).await?;
    }

    Ok(ctx)
}

/// Builds the `SessionContext` used by `tpch_test_11`.
///
/// Registers the partsupp, supplier and nation CSV files twice: once as
/// `FILENAME_PLACEHOLDER_0..=2` and again as `FILENAME_PLACEHOLDER_3..=5`.
/// NOTE(review): the expected plan in `tpch_test_11` scans placeholders 0-2
/// in the outer query and 3-5 in the subquery, which is presumably why each
/// file gets two names.
///
/// # Errors
///
/// Returns the first error reported by `register_csv`; later tables are not
/// registered once one fails.
async fn create_context_tpch11() -> Result<SessionContext> {
    let ctx = SessionContext::new();

    // (table name, CSV path) pairs, registered in order.
    let tables = [
        ("FILENAME_PLACEHOLDER_0", "tests/testdata/tpch/partsupp.csv"),
        ("FILENAME_PLACEHOLDER_1", "tests/testdata/tpch/supplier.csv"),
        ("FILENAME_PLACEHOLDER_2", "tests/testdata/tpch/nation.csv"),
        ("FILENAME_PLACEHOLDER_3", "tests/testdata/tpch/partsupp.csv"),
        ("FILENAME_PLACEHOLDER_4", "tests/testdata/tpch/supplier.csv"),
        ("FILENAME_PLACEHOLDER_5", "tests/testdata/tpch/nation.csv"),
    ];
    for (name, csv_path) in tables {
        register_csv(&ctx, name, csv_path).await?;
    }

    Ok(ctx)
}

#[tokio::test]
async fn tpch_test_1() -> Result<()> {
let ctx = create_context_tpch1().await?;
Expand Down Expand Up @@ -266,4 +316,85 @@ mod tests {
\n TableScan: REGION projection=[r_regionkey, r_name, r_comment]");
Ok(())
}

/// Checks that the Substrait plan stored in `query_6.json` is converted by
/// `from_substrait_plan` into the expected logical plan.
///
/// The plan is rendered with its `Debug` formatting and compared against a
/// fixed string. The test panics with "file not found" when the fixture
/// cannot be opened and with "failed to parse json" when it cannot be
/// deserialized into a `Plan`.
#[tokio::test]
async fn tpch_test_6() -> Result<()> {
    let ctx = create_context_tpch6().await?;

    // Load and deserialize the serialized Substrait plan.
    let plan_path = "tests/testdata/tpch_substrait_plans/query_6.json";
    let plan_file = File::open(plan_path).expect("file not found");
    let substrait_plan: Plan =
        serde_json::from_reader(BufReader::new(plan_file)).expect("failed to parse json");

    // Convert to a logical plan and render it for comparison.
    let logical_plan = from_substrait_plan(&ctx, &substrait_plan).await?;
    let rendered = format!("{:?}", logical_plan);

    // A trailing `\` drops the newline and any leading whitespace on the next
    // line, so only the explicit `\n` escapes produce line breaks.
    let expected = "Aggregate: groupBy=[[]], aggr=[[sum(FILENAME_PLACEHOLDER_0.l_extendedprice * FILENAME_PLACEHOLDER_0.l_discount) AS REVENUE]]\
\n Projection: FILENAME_PLACEHOLDER_0.l_extendedprice * FILENAME_PLACEHOLDER_0.l_discount\
\n Filter: FILENAME_PLACEHOLDER_0.l_shipdate >= CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_0.l_shipdate < CAST(Utf8(\"1995-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_0.l_discount >= Decimal128(Some(5),3,2) AND FILENAME_PLACEHOLDER_0.l_discount <= Decimal128(Some(7),3,2) AND FILENAME_PLACEHOLDER_0.l_quantity < CAST(Int32(24) AS Decimal128(19, 0))\
\n TableScan: FILENAME_PLACEHOLDER_0 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]";
    assert_eq!(rendered, expected);

    Ok(())
}

// TODO: tests for the TPC-H query 7, 8, and 9 plans are still missing.
/// Checks that the Substrait plan stored in `query_10.json` is converted by
/// `from_substrait_plan` into the expected logical plan.
///
/// The plan is rendered with its `Debug` formatting and compared against a
/// fixed string. The test panics with "file not found" when the fixture
/// cannot be opened and with "failed to parse json" when it cannot be
/// deserialized into a `Plan`.
#[tokio::test]
async fn tpch_test_10() -> Result<()> {
    let ctx = create_context_tpch10().await?;

    // Load and deserialize the serialized Substrait plan.
    let plan_path = "tests/testdata/tpch_substrait_plans/query_10.json";
    let plan_file = File::open(plan_path).expect("file not found");
    let substrait_plan: Plan =
        serde_json::from_reader(BufReader::new(plan_file)).expect("failed to parse json");

    // Convert to a logical plan and render it for comparison.
    let logical_plan = from_substrait_plan(&ctx, &substrait_plan).await?;
    let rendered = format!("{:?}", logical_plan);

    // A trailing `\` drops the newline and any leading whitespace on the next
    // line, so only the explicit `\n` escapes produce line breaks.
    let expected = "Projection: FILENAME_PLACEHOLDER_0.c_custkey AS C_CUSTKEY, FILENAME_PLACEHOLDER_0.c_name AS C_NAME, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) AS REVENUE, FILENAME_PLACEHOLDER_0.c_acctbal AS C_ACCTBAL, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.c_address AS C_ADDRESS, FILENAME_PLACEHOLDER_0.c_phone AS C_PHONE, FILENAME_PLACEHOLDER_0.c_comment AS C_COMMENT\
\n Limit: skip=0, fetch=20\
\n Sort: sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount) DESC NULLS FIRST\
\n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount), FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_0.c_comment\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment]], aggr=[[sum(FILENAME_PLACEHOLDER_2.l_extendedprice * Int32(1) - FILENAME_PLACEHOLDER_2.l_discount)]]\
\n Projection: FILENAME_PLACEHOLDER_0.c_custkey, FILENAME_PLACEHOLDER_0.c_name, FILENAME_PLACEHOLDER_0.c_acctbal, FILENAME_PLACEHOLDER_0.c_phone, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.c_address, FILENAME_PLACEHOLDER_0.c_comment, FILENAME_PLACEHOLDER_2.l_extendedprice * (CAST(Int32(1) AS Decimal128(19, 0)) - FILENAME_PLACEHOLDER_2.l_discount)\
\n Filter: FILENAME_PLACEHOLDER_0.c_custkey = FILENAME_PLACEHOLDER_1.o_custkey AND FILENAME_PLACEHOLDER_2.l_orderkey = FILENAME_PLACEHOLDER_1.o_orderkey AND FILENAME_PLACEHOLDER_1.o_orderdate >= CAST(Utf8(\"1993-10-01\") AS Date32) AND FILENAME_PLACEHOLDER_1.o_orderdate < CAST(Utf8(\"1994-01-01\") AS Date32) AND FILENAME_PLACEHOLDER_2.l_returnflag = Utf8(\"R\") AND FILENAME_PLACEHOLDER_0.c_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n TableScan: FILENAME_PLACEHOLDER_0 projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]\
\n TableScan: FILENAME_PLACEHOLDER_1 projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]\
\n TableScan: FILENAME_PLACEHOLDER_2 projection=[l_orderkey, l_partkey, l_suppkey, l_linenumber, l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate, l_commitdate, l_receiptdate, l_shipinstruct, l_shipmode, l_comment]\
\n TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]";
    assert_eq!(rendered, expected);

    Ok(())
}

/// Checks that the Substrait plan stored in `query_11.json` is converted by
/// `from_substrait_plan` into the expected logical plan, including the
/// scalar subquery used in the filter.
///
/// The plan is rendered with its `Debug` formatting and compared against a
/// fixed string. The test panics with "file not found" when the fixture
/// cannot be opened and with "failed to parse json" when it cannot be
/// deserialized into a `Plan`.
#[tokio::test]
async fn tpch_test_11() -> Result<()> {
    let ctx = create_context_tpch11().await?;

    // Load and deserialize the serialized Substrait plan.
    let plan_path = "tests/testdata/tpch_substrait_plans/query_11.json";
    let plan_file = File::open(plan_path).expect("file not found");
    let substrait_plan: Plan =
        serde_json::from_reader(BufReader::new(plan_file)).expect("failed to parse json");

    // Convert to a logical plan and render it for comparison.
    let logical_plan = from_substrait_plan(&ctx, &substrait_plan).await?;
    let rendered = format!("{:?}", logical_plan);

    // A trailing `\` drops the newline and any leading whitespace on the next
    // line, so only the explicit `\n` escapes produce line breaks.
    let expected = "Projection: FILENAME_PLACEHOLDER_0.ps_partkey AS PS_PARTKEY, sum(FILENAME_PLACEHOLDER_0.ps_supplycost * FILENAME_PLACEHOLDER_0.ps_availqty) AS value\
\n Sort: sum(FILENAME_PLACEHOLDER_0.ps_supplycost * FILENAME_PLACEHOLDER_0.ps_availqty) DESC NULLS FIRST\
\n Filter: sum(FILENAME_PLACEHOLDER_0.ps_supplycost * FILENAME_PLACEHOLDER_0.ps_availqty) > (<subquery>)\
\n Subquery:\
\n Projection: sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty) * Decimal128(Some(1000000),11,10)\
\n Aggregate: groupBy=[[]], aggr=[[sum(FILENAME_PLACEHOLDER_3.ps_supplycost * FILENAME_PLACEHOLDER_3.ps_availqty)]]\
\n Projection: FILENAME_PLACEHOLDER_3.ps_supplycost * CAST(FILENAME_PLACEHOLDER_3.ps_availqty AS Decimal128(19, 0))\
\n Filter: FILENAME_PLACEHOLDER_3.ps_suppkey = FILENAME_PLACEHOLDER_4.s_suppkey AND FILENAME_PLACEHOLDER_4.s_nationkey = FILENAME_PLACEHOLDER_5.n_nationkey AND FILENAME_PLACEHOLDER_5.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n TableScan: FILENAME_PLACEHOLDER_3 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\
\n TableScan: FILENAME_PLACEHOLDER_4 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\
\n TableScan: FILENAME_PLACEHOLDER_5 projection=[n_nationkey, n_name, n_regionkey, n_comment]\
\n Aggregate: groupBy=[[FILENAME_PLACEHOLDER_0.ps_partkey]], aggr=[[sum(FILENAME_PLACEHOLDER_0.ps_supplycost * FILENAME_PLACEHOLDER_0.ps_availqty)]]\
\n Projection: FILENAME_PLACEHOLDER_0.ps_partkey, FILENAME_PLACEHOLDER_0.ps_supplycost * CAST(FILENAME_PLACEHOLDER_0.ps_availqty AS Decimal128(19, 0))\
\n Filter: FILENAME_PLACEHOLDER_0.ps_suppkey = FILENAME_PLACEHOLDER_1.s_suppkey AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_2.n_nationkey AND FILENAME_PLACEHOLDER_2.n_name = CAST(Utf8(\"JAPAN\") AS Utf8)\
\n Inner Join: Filter: Boolean(true)\
\n Inner Join: Filter: Boolean(true)\
\n TableScan: FILENAME_PLACEHOLDER_0 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\
\n TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\
\n TableScan: FILENAME_PLACEHOLDER_2 projection=[n_nationkey, n_name, n_regionkey, n_comment]";
    assert_eq!(rendered, expected);

    Ok(())
}
}
Loading