Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add PI & cot #479

Merged
merged 1 commit into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions arroyo-sql-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod smoke_tests;
mod tests {
use arroyo_sql_macro::single_test_codegen;
use arroyo_types;
use std::f64::consts::PI;

// Casts
single_test_codegen!(
Expand Down Expand Up @@ -1308,4 +1309,15 @@ mod tests {
},
None
);

// PI
single_test_codegen!(
"pi",
"pi()",
arroyo_sql::TestStruct {
non_nullable_f64: PI,
..Default::default()
},
PI
);
}
35 changes: 31 additions & 4 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub enum Expression {
Case(CaseExpression),
WindowUDF(WindowType),
Unnest(Box<Expression>, bool),
PiConst(PiConstant),
}

pub struct JoinedPairedStruct {
Expand Down Expand Up @@ -252,6 +253,7 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for Expression {
Expression::WindowUDF(_window_type) => {
unreachable!("window functions shouldn't be computed off of a value pointer")
}
Expression::PiConst(pi) => pi.generate(input_context),
Expression::Unnest(_, taken) => {
if !taken {
panic!("unnest appeared in a non-projection context");
Expand Down Expand Up @@ -303,6 +305,7 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for Expression {
Expression::WindowUDF(_window_type) => {
unreachable!("window functions shouldn't be computed off of a value pointer")
}
Expression::PiConst(p) => p.expression_type(input_context),
Expression::Unnest(t, _) => match t.expression_type(input_context) {
TypeDef::DataType(DataType::List(inner), _) => {
TypeDef::DataType(inner.data_type().clone(), false)
Expand Down Expand Up @@ -424,7 +427,8 @@ impl Expression {
| Expression::RustUdf(_)
| Expression::WrapType(_)
| Expression::Unnest(_, _)
| Expression::Case(_) => Ok(None),
| Expression::Case(_)
| Expression::PiConst(_) => Ok(None),
}
}
fn get_duration(expression: &Expr) -> Result<Duration> {
Expand Down Expand Up @@ -553,6 +557,7 @@ impl Expression {
}
},
Expression::WindowUDF(_) => {}
Expression::PiConst(_) => {}
Expression::Unnest(n, _) => {
(&mut *n).traverse_mut(context, f);
}
Expand Down Expand Up @@ -755,7 +760,8 @@ impl<'a> ExpressionContext<'a> {
| BuiltinScalarFunction::Signum
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Log2
| BuiltinScalarFunction::Exp => Ok(NumericExpression::new(
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Cot => Ok(NumericExpression::new(
fun.clone(),
Box::new(arg_expressions.remove(0)),
)?),
Expand Down Expand Up @@ -864,7 +870,7 @@ impl<'a> ExpressionContext<'a> {
BuiltinScalarFunction::Uuid => bail!("UUID unimplemented"),
BuiltinScalarFunction::Cbrt => bail!("cube root unimplemented"),
BuiltinScalarFunction::Degrees => bail!("degrees not implemented yet"),
BuiltinScalarFunction::Pi => bail!("pi not implemented yet"),
BuiltinScalarFunction::Pi => PiConstant::new(),
BuiltinScalarFunction::Radians => bail!("radians not implemented yet"),
BuiltinScalarFunction::Factorial => bail!("factorial not implemented yet"),
BuiltinScalarFunction::Gcd => bail!("gcd not implemented yet"),
Expand All @@ -879,7 +885,6 @@ impl<'a> ExpressionContext<'a> {
),
BuiltinScalarFunction::Decode => bail!("decode not implemented yet"),
BuiltinScalarFunction::Encode => bail!("encode not implemented yet"),
BuiltinScalarFunction::Cot => bail!("cot not implemented yet"),
BuiltinScalarFunction::ArrayAppend
| BuiltinScalarFunction::ArrayConcat
| BuiltinScalarFunction::ArrayDims
Expand Down Expand Up @@ -1984,6 +1989,7 @@ enum NumericFunction {
Trunc,
Log2,
Exp,
Cot,
}

impl NumericFunction {
Expand Down Expand Up @@ -2013,6 +2019,7 @@ impl NumericFunction {
NumericFunction::Round => "round",
NumericFunction::Trunc => "trunc",
NumericFunction::Signum => "signum",
NumericFunction::Cot => "cot",
};
format_ident!("{}", name)
}
Expand Down Expand Up @@ -2047,6 +2054,7 @@ impl TryFrom<BuiltinScalarFunction> for NumericFunction {
BuiltinScalarFunction::Trunc => Ok(Self::Trunc),
BuiltinScalarFunction::Log2 => Ok(Self::Log2),
BuiltinScalarFunction::Exp => Ok(Self::Exp),
BuiltinScalarFunction::Cot => Ok(Self::Cot),
_ => bail!("{:?} is not a single argument numeric function", fun),
}
}
Expand Down Expand Up @@ -4139,3 +4147,22 @@ impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for DateTimeFunction
}
}
}

#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd)]
pub struct PiConstant;

impl PiConstant {
pub fn new() -> Result<Expression> {
Ok(Expression::PiConst(PiConstant {}))
}
}

impl CodeGenerator<ValuePointerContext, TypeDef, syn::Expr> for PiConstant {
fn generate(&self, _input_context: &ValuePointerContext) -> syn::Expr {
parse_quote!(std::f64::consts::PI)
}

fn expression_type(&self, _input_context: &ValuePointerContext) -> TypeDef {
TypeDef::DataType(DataType::Float64, false)
}
}