From c1e57371472ab675ea3ff69fb2f041543030a18e Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sun, 22 May 2022 20:33:52 +0800 Subject: [PATCH 1/4] pipeline expr wrapper --- datafusion/jit/src/api.rs | 15 +++ datafusion/jit/src/ast.rs | 16 ++- datafusion/jit/src/compile.rs | 180 ++++++++++++++++++++++++++++++++++ datafusion/jit/src/jit.rs | 35 +++++++ datafusion/jit/src/lib.rs | 1 + datafusion/row/src/lib.rs | 4 + 6 files changed, 246 insertions(+), 5 deletions(-) create mode 100644 datafusion/jit/src/compile.rs diff --git a/datafusion/jit/src/api.rs b/datafusion/jit/src/api.rs index d95f9ccc7ac5..0bdc8e15172f 100644 --- a/datafusion/jit/src/api.rs +++ b/datafusion/jit/src/api.rs @@ -604,6 +604,21 @@ impl<'a> CodeBlock<'a> { internal_err!("No func with the name {} exist", fn_name) } } + + pub fn deref(&self, ptr: Expr, ty: JITType) -> Result { + // if ptr.get_type() != PTR { + // internal_err!("cannot dereference {}", ptr.get_type()) + // } else { + // Ok(Expr::Deref(Box::new(ptr), ty)) + // } + + Ok(Expr::Deref(Box::new(ptr), ty)) + } + + pub fn store(&mut self, value: Expr, ptr: Expr) -> Result<()> { + self.stmts.push(Stmt::Store(Box::new(value), Box::new(ptr))); + Ok(()) + } } impl Display for GeneratedFunction { diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs index fd10a909e783..1bfaa30bddbb 100644 --- a/datafusion/jit/src/ast.rs +++ b/datafusion/jit/src/ast.rs @@ -32,6 +32,8 @@ pub enum Stmt { Call(String, Vec), /// declare a new variable of type Declare(String, JITType), + /// store value (the first expr) to a pointer (the second expr) + Store(Box, Box), } #[derive(Copy, Clone, Debug, PartialEq)] @@ -54,6 +56,8 @@ pub enum Expr { Binary(BinaryExpr), /// call function expression Call(String, Vec, JITType), + /// dereference a pointer + Deref(Box, JITType), } impl Expr { @@ -63,6 +67,7 @@ impl Expr { Expr::Identifier(_, ty) => *ty, Expr::Binary(bin) => bin.get_type(), Expr::Call(_, _, ty) => *ty, + Expr::Deref(_, ty) => *ty, } } } @@ -272,12 +277,9 @@ pub const R64: JITType = JITType { native: ir::types::R64, code: 0x7f, }; +pub const PTR_SIZE: usize = std::mem::size_of::(); /// The pointer type to use based on our currently target. -pub const PTR: JITType = if std::mem::size_of::() == 8 { - R64 -} else { - R32 -}; +pub const PTR: JITType = if PTR_SIZE == 8 { R64 } else { R32 }; impl Stmt { /// print the statement with indentation @@ -323,6 +325,9 @@ impl Stmt { Stmt::Declare(name, ty) => { writeln!(f, "{}let {}: {};", ident_str, name, ty) } + Stmt::Store(value, ptr) => { + writeln!(f, "{}*({}) = {}", ident_str, ptr, value) + } } } } @@ -352,6 +357,7 @@ impl Display for Expr { .join(", ") ) } + Expr::Deref(ptr, _) => write!(f, "*({})", ptr,), } } } diff --git a/datafusion/jit/src/compile.rs b/datafusion/jit/src/compile.rs new file mode 100644 index 000000000000..bd408170014b --- /dev/null +++ b/datafusion/jit/src/compile.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Convert DataFusion logical plan to JIT execution plan. + +use datafusion_common::Result; + +use crate::api::Assembler; +use crate::{ + api::GeneratedFunction, + ast::{Expr as JITExpr, I64, PTR_SIZE}, +}; + +fn build_calc_fn( + assembler: &Assembler, + jit_expr: JITExpr, + input_names: Vec, +) -> Result { + let mut builder = assembler.new_func_builder("calc_fn"); + for input in &input_names { + builder = builder.param(format!("{}_array", input), I64); + } + let mut builder = builder.param("result", I64).param("len", I64); + + let mut fn_body = builder.enter_block(); + + fn_body.declare_as("index", fn_body.lit_i(0))?; + fn_body.while_block( + |cond| cond.lt(cond.id("index")?, cond.id("len")?), + |w| { + w.declare_as("offset", w.mul(w.id("index")?, w.lit_i(PTR_SIZE as i64))?)?; + for input in &input_names { + w.declare_as( + format!("{}_ptr", input), + w.add(w.id(format!("{}_array", input))?, w.id("offset")?)?, + )?; + w.declare_as(input, w.deref(w.id(format!("{}_ptr", input))?, I64)?)?; + } + w.declare_as("res_ptr", w.add(w.id("result")?, w.id("offset")?)?)?; + w.declare_as("res", jit_expr.clone())?; + w.store(w.id("res")?, w.id("res_ptr")?)?; + + w.assign("index", w.add(w.id("index")?, w.lit_i(1))?)?; + Ok(()) + }, + )?; + + let gen_func = fn_body.build(); + Ok(gen_func) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::{ + array::{Array, PrimitiveArray}, + datatypes::{DataType, Int64Type}, + }; + use datafusion_common::{DFSchema, DataFusionError}; + use datafusion_expr::Expr as DFExpr; + + use crate::ast::BinaryExpr; + + use super::*; + + fn run_df_expr( + assembler: &Assembler, + df_expr: DFExpr, + schema: Arc, + lhs: PrimitiveArray, + rhs: PrimitiveArray, + ) -> Result> { + if lhs.null_count() != 0 || rhs.null_count() != 0 { + return Err(DataFusionError::NotImplemented( + "Computing on nullable array not yet supported".to_string(), + )); + } + if lhs.len() != rhs.len() { + return Err(DataFusionError::NotImplemented( + "Computing on different length arrays not yet supported".to_string(), + )); + } + + let input_fields = schema.field_names(); + let jit_expr: JITExpr = (df_expr, schema).try_into()?; + + let len = lhs.len(); + let result: Vec = Vec::with_capacity(len); + + let gen_func = build_calc_fn(assembler, jit_expr, input_fields)?; + + println!("{}", format!("{}", &gen_func)); + + todo!() + } + + #[test] + fn mvp_driver() { + let array_a: PrimitiveArray = + PrimitiveArray::from_iter_values((0..10).map(|x| x + 1)); + let array_b: PrimitiveArray = + PrimitiveArray::from_iter_values((0..10).map(|x| x + 1)); + + let df_expr = datafusion_expr::col("a") + datafusion_expr::col("b"); + let schema = Arc::new( + DFSchema::new_with_metadata( + vec![ + datafusion_common::DFField::new( + Some("table1"), + "a", + DataType::Int64, + false, + ), + datafusion_common::DFField::new( + Some("table1"), + "b", + DataType::Int64, + false, + ), + ], + std::collections::HashMap::new(), + ) + .unwrap(), + ); + + let assembler = Assembler::default(); + let result = run_df_expr(&assembler, df_expr, schema, array_a, array_b); + } + + #[test] + fn calc_fn_builder() { + let expr = JITExpr::Binary(BinaryExpr::Add( + Box::new(JITExpr::Identifier("table1.a".to_string(), I64)), + Box::new(JITExpr::Identifier("table1.b".to_string(), I64)), + )); + let fields = vec!["table1.a".to_string(), "table1.b".to_string()]; + + let expected = r#"fn calc_fn_0(table1.a_array: i64, table1.b_array: i64, result: i64, len: i64) -> () { + let index: i64; + index = 0; + while index < len { + let offset: i64; + offset = index * 8; + let table1.a_ptr: i64; + table1.a_ptr = table1.a_array + offset; + let table1.a: i64; + table1.a = *(table1.a_ptr); + let table1.b_ptr: i64; + table1.b_ptr = table1.b_array + offset; + let table1.b: i64; + table1.b = *(table1.b_ptr); + let res_ptr: i64; + res_ptr = result + offset; + let res: i64; + res = table1.a + table1.b; + *(res_ptr) = res + index = index + 1; + } +}"#; + + let assembler = Assembler::default(); + let gen_func = build_calc_fn(&assembler, expr, fields).unwrap(); + assert_eq!(format!("{}", &gen_func), expected); + } +} diff --git a/datafusion/jit/src/jit.rs b/datafusion/jit/src/jit.rs index 0460cc805d65..4bc411b1d35a 100644 --- a/datafusion/jit/src/jit.rs +++ b/datafusion/jit/src/jit.rs @@ -263,6 +263,8 @@ impl<'a> FunctionTranslator<'a> { Ok(()) } Stmt::Declare(_, _) => Ok(()), + // Stmt::Store(value, offset, ptr) => self.translate_store(*ptr, offset, *value), + Stmt::Store(value, ptr) => self.translate_store(*ptr, *value), } } @@ -289,6 +291,8 @@ impl<'a> FunctionTranslator<'a> { } Expr::Binary(b) => self.translate_binary_expr(b), Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret), + // Expr::Deref(ptr, offset, ty) => self.translate_deref(*ptr, offset, ty), + Expr::Deref(ptr, ty) => self.translate_deref(*ptr, ty), } } @@ -462,6 +466,37 @@ impl<'a> FunctionTranslator<'a> { Ok(()) } + fn translate_deref(&mut self, ptr: Expr, ty: JITType) -> Result { + let ptr = self.translate_expr(ptr)?; + Ok(self.builder.ins().load(ty.native, MemFlags::new(), ptr, 0)) + } + + fn translate_store(&mut self, ptr: Expr, value: Expr) -> Result<()> { + let ptr = self.translate_expr(ptr)?; + let value = self.translate_expr(value)?; + self.builder.ins().store(MemFlags::new(), value, ptr, 0); + Ok(()) + } + + // fn translate_deref(&mut self, ptr: Expr, offset: Expr, ty: JITType) -> Result { + // let ptr = self.translate_expr(ptr)?; + // let offset = self.translate_expr(offset)?; + // Ok(self + // .builder + // .ins() + // .load(ty.native, MemFlags::new(), ptr, offset)) + // } + + // fn translate_store(&mut self, ptr: Expr, offset: Expr, value: Expr) -> Result<()> { + // let ptr = self.translate_expr(ptr)?; + // let offset = self.translate_expr(offset)?; + // let value = self.translate_expr(value)?; + // self.builder + // .ins() + // .store(MemFlags::new(), value, ptr, offset); + // Ok(()) + // } + fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result { let lhs = self.translate_expr(lhs)?; let rhs = self.translate_expr(rhs)?; diff --git a/datafusion/jit/src/lib.rs b/datafusion/jit/src/lib.rs index dff27da317e4..377d32d8a37d 100644 --- a/datafusion/jit/src/lib.rs +++ b/datafusion/jit/src/lib.rs @@ -19,6 +19,7 @@ pub mod api; pub mod ast; +pub mod compile; pub mod jit; #[cfg(test)] diff --git a/datafusion/row/src/lib.rs b/datafusion/row/src/lib.rs index c05cbcd0ef1c..d77c37063e92 100644 --- a/datafusion/row/src/lib.rs +++ b/datafusion/row/src/lib.rs @@ -30,10 +30,12 @@ //! we append their actual content to the end of the var length region and //! store their offset relative to row base and their length, packed into an 8-byte word. //! +//! ```plaintext //! ┌────────────────┬──────────────────────────┬───────────────────────┐ ┌───────────────────────┬────────────┐ //! │Validity Bitmask│ Fixed Width Field │ Variable Width Field │ ... │ vardata area │ padding │ //! │ (byte aligned) │ (native type width) │(vardata offset + len) │ │ (variable length) │ bytes │ //! └────────────────┴──────────────────────────┴───────────────────────┘ └───────────────────────┴────────────┘ +//! ``` //! //! For example, given the schema (Int8, Utf8, Float32, Utf8) //! @@ -41,10 +43,12 @@ //! //! Requires 32 bytes (31 bytes payload and 1 byte padding to make each tuple 8-bytes aligned): //! +//! ```plaintext //! ┌──────────┬──────────┬──────────────────────┬──────────────┬──────────────────────┬───────────────────────┬──────────┐ //! │0b00001011│ 0x01 │0x00000016 0x00000006│ 0x00000000 │0x0000001C 0x00000003│ FooBarbaz │ 0x00 │ //! └──────────┴──────────┴──────────────────────┴──────────────┴──────────────────────┴───────────────────────┴──────────┘ //! 0 1 2 10 14 22 31 32 +//! ``` //! use arrow::array::{make_builder, ArrayBuilder, ArrayRef}; From 675092b5cae2dad66ed2d0bb81444d93c382447f Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sun, 22 May 2022 21:32:08 +0800 Subject: [PATCH 2/4] clean up Signed-off-by: Ruihang Xia --- datafusion/jit/src/api.rs | 8 ++--- datafusion/jit/src/ast.rs | 2 +- datafusion/jit/src/compile.rs | 58 +++++++++++++++++++---------------- datafusion/jit/src/jit.rs | 21 ------------- 4 files changed, 34 insertions(+), 55 deletions(-) diff --git a/datafusion/jit/src/api.rs b/datafusion/jit/src/api.rs index 0bdc8e15172f..41a39952e769 100644 --- a/datafusion/jit/src/api.rs +++ b/datafusion/jit/src/api.rs @@ -153,6 +153,7 @@ impl FunctionBuilder { } /// Add one more parameter to the function. + #[must_use] pub fn param(mut self, name: impl Into, ty: JITType) -> Self { let name = name.into(); assert!(!self.fields.back().unwrap().contains_key(&name)); @@ -163,6 +164,7 @@ impl FunctionBuilder { /// Set return type for the function. Functions are of `void` type by default if /// you do not set the return type. + #[must_use] pub fn ret(mut self, name: impl Into, ty: JITType) -> Self { let name = name.into(); assert!(!self.fields.back().unwrap().contains_key(&name)); @@ -606,12 +608,6 @@ impl<'a> CodeBlock<'a> { } pub fn deref(&self, ptr: Expr, ty: JITType) -> Result { - // if ptr.get_type() != PTR { - // internal_err!("cannot dereference {}", ptr.get_type()) - // } else { - // Ok(Expr::Deref(Box::new(ptr), ty)) - // } - Ok(Expr::Deref(Box::new(ptr), ty)) } diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs index 1bfaa30bddbb..afc7028c2038 100644 --- a/datafusion/jit/src/ast.rs +++ b/datafusion/jit/src/ast.rs @@ -32,7 +32,7 @@ pub enum Stmt { Call(String, Vec), /// declare a new variable of type Declare(String, JITType), - /// store value (the first expr) to a pointer (the second expr) + /// store value (the first expr) to an address (the second expr) Store(Box, Box), } diff --git a/datafusion/jit/src/compile.rs b/datafusion/jit/src/compile.rs index bd408170014b..7394f76a32d8 100644 --- a/datafusion/jit/src/compile.rs +++ b/datafusion/jit/src/compile.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Convert DataFusion logical plan to JIT execution plan. +//! Compile DataFusion Expr to JIT'd function. use datafusion_common::Result; @@ -25,7 +25,8 @@ use crate::{ ast::{Expr as JITExpr, I64, PTR_SIZE}, }; -fn build_calc_fn( +/// Wrap JIT Expr to array compute function. +pub fn build_calc_fn( assembler: &Assembler, jit_expr: JITExpr, input_names: Vec, @@ -37,7 +38,6 @@ fn build_calc_fn( let mut builder = builder.param("result", I64).param("len", I64); let mut fn_body = builder.enter_block(); - fn_body.declare_as("index", fn_body.lit_i(0))?; fn_body.while_block( |cond| cond.lt(cond.id("index")?, cond.id("len")?), @@ -65,13 +65,13 @@ fn build_calc_fn( #[cfg(test)] mod test { - use std::sync::Arc; + use std::{collections::HashMap, sync::Arc}; use arrow::{ array::{Array, PrimitiveArray}, datatypes::{DataType, Int64Type}, }; - use datafusion_common::{DFSchema, DataFusionError}; + use datafusion_common::{DFField, DFSchema, DataFusionError}; use datafusion_expr::Expr as DFExpr; use crate::ast::BinaryExpr; @@ -79,7 +79,6 @@ mod test { use super::*; fn run_df_expr( - assembler: &Assembler, df_expr: DFExpr, schema: Arc, lhs: PrimitiveArray, @@ -96,50 +95,55 @@ mod test { )); } + // translate DF Expr to JIT Expr let input_fields = schema.field_names(); let jit_expr: JITExpr = (df_expr, schema).try_into()?; + // allocate memory for calc result let len = lhs.len(); - let result: Vec = Vec::with_capacity(len); - - let gen_func = build_calc_fn(assembler, jit_expr, input_fields)?; + let result = vec![0i64; len]; - println!("{}", format!("{}", &gen_func)); + // compile and run JIT code + let assembler = Assembler::default(); + let gen_func = build_calc_fn(&assembler, jit_expr, input_fields)?; + let mut jit = assembler.create_jit(); + let code_ptr = jit.compile(gen_func)?; + let code_fn = + unsafe { core::mem::transmute::<_, fn(i64, i64, i64, i64) -> ()>(code_ptr) }; + code_fn( + lhs.values().as_ptr() as i64, + rhs.values().as_ptr() as i64, + result.as_ptr() as i64, + len as i64, + ); - todo!() + let result_array = PrimitiveArray::::from_iter(result); + Ok(result_array) } #[test] - fn mvp_driver() { + fn array_add() { let array_a: PrimitiveArray = PrimitiveArray::from_iter_values((0..10).map(|x| x + 1)); let array_b: PrimitiveArray = PrimitiveArray::from_iter_values((0..10).map(|x| x + 1)); + let expected = + arrow::compute::kernels::arithmetic::add(&array_a, &array_b).unwrap(); let df_expr = datafusion_expr::col("a") + datafusion_expr::col("b"); let schema = Arc::new( DFSchema::new_with_metadata( vec![ - datafusion_common::DFField::new( - Some("table1"), - "a", - DataType::Int64, - false, - ), - datafusion_common::DFField::new( - Some("table1"), - "b", - DataType::Int64, - false, - ), + DFField::new(Some("table1"), "a", DataType::Int64, false), + DFField::new(Some("table1"), "b", DataType::Int64, false), ], - std::collections::HashMap::new(), + HashMap::new(), ) .unwrap(), ); - let assembler = Assembler::default(); - let result = run_df_expr(&assembler, df_expr, schema, array_a, array_b); + let result = run_df_expr(df_expr, schema, array_a, array_b).unwrap(); + assert_eq!(result, expected); } #[test] diff --git a/datafusion/jit/src/jit.rs b/datafusion/jit/src/jit.rs index 4bc411b1d35a..87bb020ff2d9 100644 --- a/datafusion/jit/src/jit.rs +++ b/datafusion/jit/src/jit.rs @@ -263,7 +263,6 @@ impl<'a> FunctionTranslator<'a> { Ok(()) } Stmt::Declare(_, _) => Ok(()), - // Stmt::Store(value, offset, ptr) => self.translate_store(*ptr, offset, *value), Stmt::Store(value, ptr) => self.translate_store(*ptr, *value), } } @@ -291,7 +290,6 @@ impl<'a> FunctionTranslator<'a> { } Expr::Binary(b) => self.translate_binary_expr(b), Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret), - // Expr::Deref(ptr, offset, ty) => self.translate_deref(*ptr, offset, ty), Expr::Deref(ptr, ty) => self.translate_deref(*ptr, ty), } } @@ -478,25 +476,6 @@ impl<'a> FunctionTranslator<'a> { Ok(()) } - // fn translate_deref(&mut self, ptr: Expr, offset: Expr, ty: JITType) -> Result { - // let ptr = self.translate_expr(ptr)?; - // let offset = self.translate_expr(offset)?; - // Ok(self - // .builder - // .ins() - // .load(ty.native, MemFlags::new(), ptr, offset)) - // } - - // fn translate_store(&mut self, ptr: Expr, offset: Expr, value: Expr) -> Result<()> { - // let ptr = self.translate_expr(ptr)?; - // let offset = self.translate_expr(offset)?; - // let value = self.translate_expr(value)?; - // self.builder - // .ins() - // .store(MemFlags::new(), value, ptr, offset); - // Ok(()) - // } - fn translate_icmp(&mut self, cmp: IntCC, lhs: Expr, rhs: Expr) -> Result { let lhs = self.translate_expr(lhs)?; let rhs = self.translate_expr(rhs)?; From 5634eb9e9b660cfc8debeb148dd06893eb86b2c2 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 24 May 2022 12:53:24 +0800 Subject: [PATCH 3/4] Apply suggestions from code review Add doc for `deref()` and `store()` Co-authored-by: Andrew Lamb --- datafusion/jit/src/api.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/jit/src/api.rs b/datafusion/jit/src/api.rs index 41a39952e769..e2f70c6397e2 100644 --- a/datafusion/jit/src/api.rs +++ b/datafusion/jit/src/api.rs @@ -607,10 +607,12 @@ impl<'a> CodeBlock<'a> { } } + /// Return the value pointed to by the ptr stored in `ptr` pub fn deref(&self, ptr: Expr, ty: JITType) -> Result { Ok(Expr::Deref(Box::new(ptr), ty)) } + /// Store the value in `value` to the address in `ptr` pub fn store(&mut self, value: Expr, ptr: Expr) -> Result<()> { self.stmts.push(Stmt::Store(Box::new(value), Box::new(ptr))); Ok(()) From cb6c79eaa66474d7891d892912d0bf1f26e88eee Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 24 May 2022 14:01:10 +0800 Subject: [PATCH 4/4] CR improvement: doc, naming and hardcode Signed-off-by: Ruihang Xia --- datafusion/jit/src/api.rs | 4 +-- datafusion/jit/src/ast.rs | 42 +++++++++++++++---------- datafusion/jit/src/compile.rs | 59 ++++++++++++++++++++++++----------- datafusion/jit/src/jit.rs | 2 +- 4 files changed, 69 insertions(+), 38 deletions(-) diff --git a/datafusion/jit/src/api.rs b/datafusion/jit/src/api.rs index e2f70c6397e2..7020985a733a 100644 --- a/datafusion/jit/src/api.rs +++ b/datafusion/jit/src/api.rs @@ -608,8 +608,8 @@ impl<'a> CodeBlock<'a> { } /// Return the value pointed to by the ptr stored in `ptr` - pub fn deref(&self, ptr: Expr, ty: JITType) -> Result { - Ok(Expr::Deref(Box::new(ptr), ty)) + pub fn load(&self, ptr: Expr, ty: JITType) -> Result { + Ok(Expr::Load(Box::new(ptr), ty)) } /// Store the value in `value` to the address in `ptr` diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs index afc7028c2038..55731a650548 100644 --- a/datafusion/jit/src/ast.rs +++ b/datafusion/jit/src/ast.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::datatypes::DataType; use cranelift::codegen::ir; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use std::fmt::{Display, Formatter}; @@ -56,8 +57,8 @@ pub enum Expr { Binary(BinaryExpr), /// call function expression Call(String, Vec, JITType), - /// dereference a pointer - Deref(Box, JITType), + /// Load a value from pointer + Load(Box, JITType), } impl Expr { @@ -67,7 +68,7 @@ impl Expr { Expr::Identifier(_, ty) => *ty, Expr::Binary(bin) => bin.get_type(), Expr::Call(_, _, ty) => *ty, - Expr::Deref(_, ty) => *ty, + Expr::Load(_, ty) => *ty, } } } @@ -179,19 +180,7 @@ impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr { let field = schema.field_from_column(col)?; let ty = field.data_type(); - let jit_type = match ty { - arrow::datatypes::DataType::Int64 => I64, - arrow::datatypes::DataType::Float32 => F32, - arrow::datatypes::DataType::Float64 => F64, - arrow::datatypes::DataType::Boolean => BOOL, - - _ => { - return Err(DataFusionError::NotImplemented(format!( - "Compiling Expression with type {} not yet supported in JIT mode", - ty - ))) - } - }; + let jit_type = JITType::try_from(ty)?; Ok(Expr::Identifier(field.qualified_name(), jit_type)) } @@ -281,6 +270,25 @@ pub const PTR_SIZE: usize = std::mem::size_of::(); /// The pointer type to use based on our currently target. pub const PTR: JITType = if PTR_SIZE == 8 { R64 } else { R32 }; +impl TryFrom<&DataType> for JITType { + type Error = DataFusionError; + + /// Try to convert DataFusion's [DataType] to [JITType] + fn try_from(df_type: &DataType) -> Result { + match df_type { + DataType::Int64 => Ok(I64), + DataType::Float32 => Ok(F32), + DataType::Float64 => Ok(F64), + DataType::Boolean => Ok(BOOL), + + _ => Err(DataFusionError::NotImplemented(format!( + "Compiling Expression with type {} not yet supported in JIT mode", + df_type + ))), + } + } +} + impl Stmt { /// print the statement with indentation pub fn fmt_ident(&self, ident: usize, f: &mut Formatter) -> std::fmt::Result { @@ -357,7 +365,7 @@ impl Display for Expr { .join(", ") ) } - Expr::Deref(ptr, _) => write!(f, "*({})", ptr,), + Expr::Load(ptr, _) => write!(f, "*({})", ptr,), } } } diff --git a/datafusion/jit/src/compile.rs b/datafusion/jit/src/compile.rs index 7394f76a32d8..4e68b52104c0 100644 --- a/datafusion/jit/src/compile.rs +++ b/datafusion/jit/src/compile.rs @@ -20,6 +20,7 @@ use datafusion_common::Result; use crate::api::Assembler; +use crate::ast::{JITType, I32}; use crate::{ api::GeneratedFunction, ast::{Expr as JITExpr, I64, PTR_SIZE}, @@ -29,26 +30,36 @@ use crate::{ pub fn build_calc_fn( assembler: &Assembler, jit_expr: JITExpr, - input_names: Vec, + inputs: Vec<(String, JITType)>, + ret_type: JITType, ) -> Result { + // Alias pointer type. + // The raw pointer `R64` or `R32` is not compatible with integers. + const PTR_TYPE: JITType = if PTR_SIZE == 8 { I64 } else { I32 }; + let mut builder = assembler.new_func_builder("calc_fn"); - for input in &input_names { - builder = builder.param(format!("{}_array", input), I64); + // Declare in-param. + // Each input takes one position, following by a pointer to place result, + // and the last is the length of inputs/output arrays. + for (name, _) in &inputs { + builder = builder.param(format!("{}_array", name), PTR_TYPE); } - let mut builder = builder.param("result", I64).param("len", I64); + let mut builder = builder.param("result", ret_type).param("len", I64); + // Start build function body. + // It's loop that calculates the result one by one. let mut fn_body = builder.enter_block(); fn_body.declare_as("index", fn_body.lit_i(0))?; fn_body.while_block( |cond| cond.lt(cond.id("index")?, cond.id("len")?), |w| { w.declare_as("offset", w.mul(w.id("index")?, w.lit_i(PTR_SIZE as i64))?)?; - for input in &input_names { + for (name, ty) in &inputs { w.declare_as( - format!("{}_ptr", input), - w.add(w.id(format!("{}_array", input))?, w.id("offset")?)?, + format!("{}_ptr", name), + w.add(w.id(format!("{}_array", name))?, w.id("offset")?)?, )?; - w.declare_as(input, w.deref(w.id(format!("{}_ptr", input))?, I64)?)?; + w.declare_as(name, w.load(w.id(format!("{}_ptr", name))?, *ty)?)?; } w.declare_as("res_ptr", w.add(w.id("result")?, w.id("offset")?)?)?; w.declare_as("res", jit_expr.clone())?; @@ -96,7 +107,16 @@ mod test { } // translate DF Expr to JIT Expr - let input_fields = schema.field_names(); + let input_fields = schema + .fields() + .iter() + .map(|field| { + Ok(( + field.qualified_name(), + JITType::try_from(field.data_type())?, + )) + }) + .collect::>>()?; let jit_expr: JITExpr = (df_expr, schema).try_into()?; // allocate memory for calc result @@ -105,15 +125,18 @@ mod test { // compile and run JIT code let assembler = Assembler::default(); - let gen_func = build_calc_fn(&assembler, jit_expr, input_fields)?; + let gen_func = build_calc_fn(&assembler, jit_expr, input_fields, I64)?; let mut jit = assembler.create_jit(); let code_ptr = jit.compile(gen_func)?; - let code_fn = - unsafe { core::mem::transmute::<_, fn(i64, i64, i64, i64) -> ()>(code_ptr) }; + let code_fn = unsafe { + core::mem::transmute::<_, fn(*const i64, *const i64, *const i64, i64) -> ()>( + code_ptr, + ) + }; code_fn( - lhs.values().as_ptr() as i64, - rhs.values().as_ptr() as i64, - result.as_ptr() as i64, + lhs.values().as_ptr(), + rhs.values().as_ptr(), + result.as_ptr(), len as i64, ); @@ -126,7 +149,7 @@ mod test { let array_a: PrimitiveArray = PrimitiveArray::from_iter_values((0..10).map(|x| x + 1)); let array_b: PrimitiveArray = - PrimitiveArray::from_iter_values((0..10).map(|x| x + 1)); + PrimitiveArray::from_iter_values((10..20).map(|x| x + 1)); let expected = arrow::compute::kernels::arithmetic::add(&array_a, &array_b).unwrap(); @@ -152,7 +175,7 @@ mod test { Box::new(JITExpr::Identifier("table1.a".to_string(), I64)), Box::new(JITExpr::Identifier("table1.b".to_string(), I64)), )); - let fields = vec!["table1.a".to_string(), "table1.b".to_string()]; + let fields = vec![("table1.a".to_string(), I64), ("table1.b".to_string(), I64)]; let expected = r#"fn calc_fn_0(table1.a_array: i64, table1.b_array: i64, result: i64, len: i64) -> () { let index: i64; @@ -178,7 +201,7 @@ mod test { }"#; let assembler = Assembler::default(); - let gen_func = build_calc_fn(&assembler, expr, fields).unwrap(); + let gen_func = build_calc_fn(&assembler, expr, fields, I64).unwrap(); assert_eq!(format!("{}", &gen_func), expected); } } diff --git a/datafusion/jit/src/jit.rs b/datafusion/jit/src/jit.rs index 87bb020ff2d9..21b0d44fb0b5 100644 --- a/datafusion/jit/src/jit.rs +++ b/datafusion/jit/src/jit.rs @@ -290,7 +290,7 @@ impl<'a> FunctionTranslator<'a> { } Expr::Binary(b) => self.translate_binary_expr(b), Expr::Call(name, args, ret) => self.translate_call_expr(name, args, ret), - Expr::Deref(ptr, ty) => self.translate_deref(*ptr, ty), + Expr::Load(ptr, ty) => self.translate_deref(*ptr, ty), } }