Skip to content

Commit

Permalink
feat(numeric): add PI & cot (#479)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenquan committed Jan 7, 2024
1 parent 0eb9d9e commit 7904866
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
12 changes: 12 additions & 0 deletions arroyo-sql-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod smoke_tests;
mod tests {
use arroyo_sql_macro::single_test_codegen;
use arroyo_types;
use std::f64::consts::PI;

// Casts
single_test_codegen!(
Expand Down Expand Up @@ -1308,4 +1309,15 @@ mod tests {
},
None
);

// PI
single_test_codegen!(
"pi",
"pi()",
arroyo_sql::TestStruct {
non_nullable_f64: PI,
..Default::default()
},
PI
);
}
35 changes: 31 additions & 4 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub enum Expression {
Case(CaseExpression),
WindowUDF(WindowType),
Unnest(Box<Expression>, bool),
PiConst(PiConstant),
}

pub struct JoinedPairedStruct {
Expand Down Expand Up @@ -252,6 +253,7 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for Expression {
Expression::WindowUDF(_window_type) => {
unreachable!("window functions shouldn't be computed off of a value pointer")
}
Expression::PiConst(pi) => pi.generate(input_context),
Expression::Unnest(_, taken) => {
if !taken {
panic!("unnest appeared in a non-projection context");
Expand Down Expand Up @@ -303,6 +305,7 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for Expression {
Expression::WindowUDF(_window_type) => {
unreachable!("window functions shouldn't be computed off of a value pointer")
}
Expression::PiConst(p) => p.expression_type(input_context),
Expression::Unnest(t, _) => match t.expression_type(input_context) {
TypeDef::DataType(DataType::List(inner), _) => {
TypeDef::DataType(inner.data_type().clone(), false)
Expand Down Expand Up @@ -424,7 +427,8 @@ impl Expression {
| Expression::RustUdf(_)
| Expression::WrapType(_)
| Expression::Unnest(_, _)
| Expression::Case(_) => Ok(None),
| Expression::Case(_)
| Expression::PiConst(_) => Ok(None),
}
}
fn get_duration(expression: &Expr) -> Result<Duration> {
Expand Down Expand Up @@ -553,6 +557,7 @@ impl Expression {
}
},
Expression::WindowUDF(_) => {}
Expression::PiConst(_) => {}
Expression::Unnest(n, _) => {
(&mut *n).traverse_mut(context, f);
}
Expand Down Expand Up @@ -755,7 +760,8 @@ impl<'a> ExpressionContext<'a> {
| BuiltinScalarFunction::Signum
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Log2
| BuiltinScalarFunction::Exp => Ok(NumericExpression::new(
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Cot => Ok(NumericExpression::new(
fun.clone(),
Box::new(arg_expressions.remove(0)),
)?),
Expand Down Expand Up @@ -864,7 +870,7 @@ impl<'a> ExpressionContext<'a> {
BuiltinScalarFunction::Uuid => bail!("UUID unimplemented"),
BuiltinScalarFunction::Cbrt => bail!("cube root unimplemented"),
BuiltinScalarFunction::Degrees => bail!("degrees not implemented yet"),
BuiltinScalarFunction::Pi => bail!("pi not implemented yet"),
BuiltinScalarFunction::Pi => PiConstant::new(),
BuiltinScalarFunction::Radians => bail!("radians not implemented yet"),
BuiltinScalarFunction::Factorial => bail!("factorial not implemented yet"),
BuiltinScalarFunction::Gcd => bail!("gcd not implemented yet"),
Expand All @@ -879,7 +885,6 @@ impl<'a> ExpressionContext<'a> {
),
BuiltinScalarFunction::Decode => bail!("decode not implemented yet"),
BuiltinScalarFunction::Encode => bail!("encode not implemented yet"),
BuiltinScalarFunction::Cot => bail!("cot not implemented yet"),
BuiltinScalarFunction::ArrayAppend
| BuiltinScalarFunction::ArrayConcat
| BuiltinScalarFunction::ArrayDims
Expand Down Expand Up @@ -1984,6 +1989,7 @@ enum NumericFunction {
Trunc,
Log2,
Exp,
Cot,
}

impl NumericFunction {
Expand Down Expand Up @@ -2013,6 +2019,7 @@ impl NumericFunction {
NumericFunction::Round => "round",
NumericFunction::Trunc => "trunc",
NumericFunction::Signum => "signum",
NumericFunction::Cot => "cot",
};
format_ident!("{}", name)
}
Expand Down Expand Up @@ -2047,6 +2054,7 @@ impl TryFrom<BuiltinScalarFunction> for NumericFunction {
BuiltinScalarFunction::Trunc => Ok(Self::Trunc),
BuiltinScalarFunction::Log2 => Ok(Self::Log2),
BuiltinScalarFunction::Exp => Ok(Self::Exp),
BuiltinScalarFunction::Cot => Ok(Self::Cot),
_ => bail!("{:?} is not a single argument numeric function", fun),
}
}
Expand Down Expand Up @@ -4139,3 +4147,22 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for DateTimeFunction
}
}
}

#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd)]
pub struct PiConstant;

impl PiConstant {
pub fn new() -> Result<Expression> {
Ok(Expression::PiConst(PiConstant {}))
}
}

impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for PiConstant {
fn generate(&self, _input_context: &ValuePointerContext) -> syn::Expr {
parse_quote!(std::f64::consts::PI)
}

fn expression_type(&self, _input_context: &ValuePointerContext) -> TypeDef {
TypeDef::DataType(DataType::Float64, false)
}
}

0 comments on commit 7904866

Please sign in to comment.