From 670b9409b9d8d36d3a3f5a034c070de4bca2339c Mon Sep 17 00:00:00 2001 From: stringhandler Date: Wed, 11 Mar 2026 15:31:57 +0200 Subject: [PATCH] feat: add infix operators and warnings --- .gitignore | 2 + examples/array_fold_2n.simf | 5 +- examples/infix_operators.simf | 86 +++ src/ast.rs | 1174 +++++++++++++++++++++++++++------ src/compile/mod.rs | 83 +++ src/error.rs | 53 +- src/lexer.rs | 266 +++++++- src/lib.rs | 325 ++++++++- src/main.rs | 57 +- src/parse.rs | 242 ++++++- src/tracker.rs | 4 +- src/unstable_flags.rs | 122 ++++ src/value.rs | 11 +- src/witness.rs | 9 +- 14 files changed, 2196 insertions(+), 243 deletions(-) create mode 100644 examples/infix_operators.simf create mode 100644 src/unstable_flags.rs diff --git a/.gitignore b/.gitignore index 4b40e9ab..f5961c43 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ node_modules/ # macOS .DS_Store + +.claude \ No newline at end of file diff --git a/examples/array_fold_2n.simf b/examples/array_fold_2n.simf index 9d86325c..0bba8da5 100644 --- a/examples/array_fold_2n.simf +++ b/examples/array_fold_2n.simf @@ -1,8 +1,9 @@ // From https://github.com/BlockstreamResearch/SimplicityHL/issues/153 fn sum(elt: u32, acc: u32) -> u32 { - let (_, acc): (bool, u32) = jet::add_32(elt, acc); - acc + // let (_, acc): (bool, u32) = jet::add_32(elt, acc); + elt + acc + // acc } fn main() { diff --git a/examples/infix_operators.simf b/examples/infix_operators.simf new file mode 100644 index 00000000..d655b62e --- /dev/null +++ b/examples/infix_operators.simf @@ -0,0 +1,86 @@ +/* + * INFIX OPERATORS + * + * Demonstrates all infix operators: +, -, *, /, %, &, |, ^, &&, ||, ==, !=, <, <=, >, >= + * + * Addition and subtraction panic at runtime on overflow/underflow. + * Multiplication returns a type twice the width of the operands (no overflow). + * Division and modulo return the same type as the operands. + * Bitwise operators return the same type as the operands. + * Logical operators short-circuit and return bool. + * Comparison operators return bool. + * + * Arithmetic operators require: simc -Z infix_arithmetic_operators + */ +fn main() { + let a: u8 = 20; + let b: u8 = 6; + + // run `simc` with `-Z infix_arithmetic_operators` to allow these. + // Addition: u8 + u8 → u8, panics on overflow + // let sum: u8 = a + b; + // assert!(jet::eq_8(sum, 26)); + + // Subtraction: u8 - u8 → u8, panics on underflow + // let diff: u8 = a - b; + // assert!(jet::eq_8(diff, 14)); + + // Multiplication: u8 * u8 → u16 (full precision, no overflow possible) + // let product: u16 = a * b; + // assert!(jet::eq_16(product, 120)); + + // Division: u8 / u8 → u8, panics if divisor is zero + // let quotient: u8 = a / b; + // assert!(jet::eq_8(quotient, 3)); + + // Modulo: u8 % u8 → u8, panics if divisor is zero + // let remainder: u8 = a % b; + // assert!(jet::eq_8(remainder, 2)); + + // Bitwise AND: u8 & u8 → u8 + let and: u8 = a & b; + assert!(jet::eq_8(and, 4)); // 0b00010100 & 0b00000110 = 0b00000100 + + // Bitwise OR: u8 | u8 → u8 + let or: u8 = a | b; + assert!(jet::eq_8(or, 22)); // 0b00010100 | 0b00000110 = 0b00010110 + + // Bitwise XOR: u8 ^ u8 → u8 + let xor: u8 = a ^ b; + assert!(jet::eq_8(xor, 18)); // 0b00010100 ^ 0b00000110 = 0b00010010 + + // Logical AND: bool && bool → bool, short-circuits (rhs not evaluated if lhs is false) + let a_eq_a: bool = a == a; // true + let b_eq_b: bool = b == b; // true + let a_eq_b: bool = a == b; // false + assert!(a_eq_a && b_eq_b); // true && true = true + assert!(true && true); + // false && _ short-circuits to false: + let and_false: bool = a_eq_b && b_eq_b; + assert!(match and_false { false => true, true => false, }); + + // Logical OR: bool || bool → bool, short-circuits (rhs not evaluated if lhs is true) + assert!(a_eq_a || a_eq_b); // true || _ = true (rhs not evaluated) + assert!(false || true); + // false || false = false: + let or_false: bool = a_eq_b || a_eq_b; + assert!(match or_false { false => true, true => false, }); + + // Equality: u8 == u8 → bool + assert!(a == a); // 20 == 20 is true + assert!(a != b); // 20 != 6 is true + + // Less than: u8 < u8 → bool + assert!(b < a); // 6 < 20 + + // Greater than: u8 > u8 → bool + assert!(a > b); // 20 > 6 + + // Less or equal: u8 <= u8 → bool + assert!(b <= a); // 6 <= 20 + assert!(b <= b); // 6 <= 6 + + // Greater or equal: u8 >= u8 → bool + assert!(a >= b); // 20 >= 6 + assert!(a >= a); // 20 >= 20 +} diff --git a/src/ast.rs b/src/ast.rs index 3c59f2f7..bc4ef200 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -1,5 +1,6 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::fmt; use std::num::NonZeroUsize; use std::str::FromStr; use std::sync::Arc; @@ -239,6 +240,16 @@ pub enum SingleExpressionInner { Call(Call), /// Match expression. Match(Match), + /// Infix operator expression: calls an arithmetic or comparison jet. + BinaryOp { + jet: Elements, + lhs: Arc, + rhs: Arc, + assert_no_carry: bool, + swap_args: bool, + negate_result: bool, + check_nonzero_divisor: bool, + }, } /// Call of a user-defined or of a builtin function. @@ -502,6 +513,9 @@ impl TreeLike for ExprTree<'_> { } S::Call(call) => Tree::Unary(Self::Call(call)), S::Match(match_) => Tree::Unary(Self::Match(match_)), + S::BinaryOp { lhs, rhs, .. } => { + Tree::Nary(Arc::new([Self::Expression(lhs), Self::Expression(rhs)])) + } }, Self::Call(call) => Tree::Nary(call.args().iter().map(Self::Expression).collect()), Self::Match(match_) => Tree::Nary(Arc::new([ @@ -704,6 +718,60 @@ impl Scope { } } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum WarningName { + ModuleItemIgnored, + ArithmeticOperationCouldOverflow, + DivisionCouldPanicOnZero, +} + +impl fmt::Display for WarningName { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + WarningName::ModuleItemIgnored => write!(f, "ModuleItem was ignored"), + WarningName::ArithmeticOperationCouldOverflow => write!(f, "This operator panics if the result overflows. To handle overflow, use the jet directly and destructure the (bool, uN) result."), + WarningName::DivisionCouldPanicOnZero => write!(f, "This operator panics if the divisor is zero. To handle division by zero, use a jet and check the divisor before using it."), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Warning { + /// Canonical name used for allowing and denying specific warnings. + pub canonical_name: WarningName, + /// Span in which this warning occured. + pub span: Span, +} + +impl Warning { + fn module_item_ignored>(span: S) -> Self { + Warning { + canonical_name: WarningName::ModuleItemIgnored, + span: span.into(), + } + } + + fn arthimetic_operation_could_overflow>(span: S) -> Self { + Warning { + canonical_name: WarningName::ArithmeticOperationCouldOverflow, + span: span.into(), + } + } + + fn division_could_panic_on_zero>(span: S) -> Self { + Warning { + canonical_name: WarningName::DivisionCouldPanicOnZero, + span: span.into(), + } + } +} + +impl From for RichError { + fn from(value: Warning) -> Self { + RichError::new(Error::DeniedWarning(value.canonical_name), value.span) + } +} + /// Part of the abstract syntax tree that can be generated from a precursor in the parse tree. trait AbstractSyntaxTree: Sized { /// Component of the parse tree. @@ -714,41 +782,61 @@ trait AbstractSyntaxTree: Sized { /// /// Check if the analyzed expression is of the expected type. /// Statements return no values so their expected type is always unit. - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result; + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError>; } impl Program { - pub fn analyze(from: &parse::Program) -> Result { + pub fn analyze(from: &parse::Program) -> Result<(Self, Vec), RichError> { let unit = ResolvedType::unit(); let mut scope = Scope::default(); - let items = from + let items: Vec<(Item, Vec)> = from .items() .iter() .map(|s| Item::analyze(s, &unit, &mut scope)) - .collect::, RichError>>()?; + .collect::>()?; debug_assert!(scope.is_topmost()); let (parameters, witness_types, call_tracker) = scope.destruct(); - let mut iter = items.into_iter().filter_map(|item| match item { - Item::Function(Function::Main(expr)) => Some(expr), - _ => None, - }); - let main = iter.next().ok_or(Error::MainRequired).with_span(from)?; - if iter.next().is_some() { - return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from); + let mut all_warnings: Vec = vec![]; + let mut main_expr = None; + let mut main_seen = false; + for (item, mut warnings) in items { + all_warnings.append(&mut warnings); + match item { + Item::Function(Function::Main(expr)) => { + if main_seen { + return Err(Error::FunctionRedefined(FunctionName::main())).with_span(from); + } + main_expr = Some(expr); + main_seen = true; + } + _ => {} + } } - Ok(Self { - main, - parameters, - witness_types, - call_tracker: Arc::new(call_tracker), - }) + let main = main_expr.ok_or(Error::MainRequired).with_span(from)?; + Ok(( + Self { + main, + parameters, + witness_types, + call_tracker: Arc::new(call_tracker), + }, + all_warnings, + )) } } impl AbstractSyntaxTree for Item { type From = parse::Item; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { assert!(ty.is_unit(), "Items cannot return anything"); assert!(scope.is_topmost(), "Items live in the topmost scope only"); @@ -757,12 +845,11 @@ impl AbstractSyntaxTree for Item { scope .insert_alias(alias.name().clone(), alias.ty().clone()) .with_span(alias)?; - Ok(Self::TypeAlias) + Ok((Self::TypeAlias, vec![])) } - parse::Item::Function(function) => { - Function::analyze(function, ty, scope).map(Self::Function) - } - parse::Item::Module => Ok(Self::Module), + parse::Item::Function(function) => Function::analyze(function, ty, scope) + .map(|(f, warnings)| (Self::Function(f), warnings)), + parse::Item::Module => Ok((Self::Module, vec![])), } } } @@ -770,7 +857,11 @@ impl AbstractSyntaxTree for Item { impl AbstractSyntaxTree for Function { type From = parse::Function; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { assert!(ty.is_unit(), "Function definitions cannot return anything"); assert!(scope.is_topmost(), "Items live in the topmost scope only"); @@ -795,15 +886,18 @@ impl AbstractSyntaxTree for Function { for param in params.iter() { scope.insert_variable(param.identifier().clone(), param.ty().clone()); } - let body = Expression::analyze(from.body(), &ret, scope).map(Arc::new)?; + let (body, warnings) = Expression::analyze(from.body(), &ret, scope)?; scope.pop_scope(); debug_assert!(scope.is_topmost()); - let function = CustomFunction { params, body }; + let function = CustomFunction { + params, + body: Arc::new(body), + }; scope .insert_function(from.name().clone(), function) .with_span(from)?; - return Ok(Self::Custom); + return Ok((Self::Custom, warnings)); } if !from.params().is_empty() { @@ -817,24 +911,26 @@ impl AbstractSyntaxTree for Function { } scope.push_main_scope(); - let body = Expression::analyze(from.body(), ty, scope)?; + let (body, warnings) = Expression::analyze(from.body(), ty, scope)?; scope.pop_main_scope(); - Ok(Self::Main(body)) + Ok((Self::Main(body), warnings)) } } impl AbstractSyntaxTree for Statement { type From = parse::Statement; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { assert!(ty.is_unit(), "Statements cannot return anything"); match from { - parse::Statement::Assignment(assignment) => { - Assignment::analyze(assignment, ty, scope).map(Self::Assignment) - } - parse::Statement::Expression(expression) => { - Expression::analyze(expression, ty, scope).map(Self::Expression) - } + parse::Statement::Assignment(assignment) => Assignment::analyze(assignment, ty, scope) + .map(|(a, warnings)| (Self::Assignment(a), warnings)), + parse::Statement::Expression(expression) => Expression::analyze(expression, ty, scope) + .map(|(e, warnings)| (Self::Expression(e), warnings)), } } } @@ -842,24 +938,31 @@ impl AbstractSyntaxTree for Statement { impl AbstractSyntaxTree for Assignment { type From = parse::Assignment; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { assert!(ty.is_unit(), "Assignments cannot return anything"); // The assignment is a statement that returns nothing. // // However, the expression evaluated in the assignment does have a type, // namely the type specified in the assignment. let ty_expr = scope.resolve(from.ty()).with_span(from)?; - let expression = Expression::analyze(from.expression(), &ty_expr, scope)?; + let (expression, warnings) = Expression::analyze(from.expression(), &ty_expr, scope)?; let typed_variables = from.pattern().is_of_type(&ty_expr).with_span(from)?; for (identifier, ty) in typed_variables { scope.insert_variable(identifier, ty); } - Ok(Self { - pattern: from.pattern().clone(), - expression, - span: *from.as_ref(), - }) + Ok(( + Self { + pattern: from.pattern().clone(), + expression, + span: *from.as_ref(), + }, + warnings, + )) } } @@ -872,7 +975,10 @@ impl Expression { /// /// The returned expression might not be evaluable at compile time. /// The details depend on the current state of the SimplicityHL compiler. - pub fn analyze_const(from: &parse::Expression, ty: &ResolvedType) -> Result { + pub fn analyze_const( + from: &parse::Expression, + ty: &ResolvedType, + ) -> Result<(Self, Vec), RichError> { let mut empty_scope = Scope::default(); Self::analyze(from, ty, &mut empty_scope) } @@ -881,27 +987,33 @@ impl Expression { impl AbstractSyntaxTree for Expression { type From = parse::Expression; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { match from.inner() { parse::ExpressionInner::Single(single) => { - let ast_single = SingleExpression::analyze(single, ty, scope)?; - Ok(Self { - ty: ty.clone(), - inner: ExpressionInner::Single(ast_single), - span: *from.as_ref(), - }) + let (ast_single, warnings) = SingleExpression::analyze(single, ty, scope)?; + Ok(( + Self { + ty: ty.clone(), + inner: ExpressionInner::Single(ast_single), + span: *from.as_ref(), + }, + warnings, + )) } parse::ExpressionInner::Block(statements, expression) => { scope.push_scope(); - let ast_statements = statements + let ast_statements: Vec<(Statement, Vec)> = statements .iter() .map(|s| Statement::analyze(s, &ResolvedType::unit(), scope)) - .collect::, RichError>>()?; - let ast_expression = match expression { + .collect::>()?; + let (ast_expression, mut ast_warnings) = match expression { Some(expression) => Expression::analyze(expression, ty, scope) - .map(Arc::new) - .map(Some), - None if ty.is_unit() => Ok(None), + .map(|(e, warnings)| (Some(Arc::new(e)), warnings)), + None if ty.is_unit() => Ok((None, vec![])), None => Err(Error::ExpressionTypeMismatch( ty.clone(), ResolvedType::unit(), @@ -910,21 +1022,227 @@ impl AbstractSyntaxTree for Expression { }?; scope.pop_scope(); - Ok(Self { - ty: ty.clone(), - inner: ExpressionInner::Block(ast_statements, ast_expression), - span: *from.as_ref(), - }) + let mut all_warnings = vec![]; + let mut all_statements = vec![]; + for (statement, mut warnings) in ast_statements { + all_warnings.append(&mut warnings); + all_statements.push(statement); + } + all_warnings.append(&mut ast_warnings); + + Ok(( + Self { + ty: ty.clone(), + inner: ExpressionInner::Block(all_statements.into(), ast_expression), + span: *from.as_ref(), + }, + all_warnings, + )) } } } } +/// Tries to infer the type of a parse expression using the current scope, +/// without performing a full analysis. Returns `None` if the type cannot be determined. +fn peek_expression_type(expr: &parse::Expression, scope: &Scope) -> Option { + match expr.inner() { + parse::ExpressionInner::Single(single) => peek_single_expression_type(single, scope), + parse::ExpressionInner::Block(_, Some(end_expr)) => peek_expression_type(end_expr, scope), + parse::ExpressionInner::Block(_, None) => Some(ResolvedType::unit()), + } +} + +fn peek_single_expression_type( + expr: &parse::SingleExpression, + scope: &Scope, +) -> Option { + match expr.inner() { + parse::SingleExpressionInner::Variable(id) => scope.get_variable(id).cloned(), + parse::SingleExpressionInner::Expression(inner) => peek_expression_type(inner, scope), + _ => None, + } +} + +/// Maps a comparison infix operator and operand type to: +/// - the Simplicity jet +/// - whether arguments should be swapped (for `>` and `>=`) Note: there are no GreaterThan jets, so the +/// args must be swapped. +/// - whether the result should be negated (for `!=`) +/// +/// Returns `Err` if there is no jet for the given combination. +fn determine_comparison_op_jet( + op: &parse::InfixOp, + arg_ty: &ResolvedType, +) -> Result<(Elements, bool, bool), Error> { + use parse::InfixOp::*; + use UIntType::*; + + let uint_ty = arg_ty + .as_integer() + .ok_or_else(|| Error::ExpressionUnexpectedType(arg_ty.clone()))?; + + match op { + Eq | Ne => { + let jet = match uint_ty { + U1 => Elements::Eq1, + U8 => Elements::Eq8, + U16 => Elements::Eq16, + U32 => Elements::Eq32, + U64 => Elements::Eq64, + _ => return Err(Error::ExpressionUnexpectedType(arg_ty.clone())), + }; + Ok((jet, false, matches!(op, Ne))) + } + Lt | Gt => { + let jet = match uint_ty { + U8 => Elements::Lt8, + U16 => Elements::Lt16, + U32 => Elements::Lt32, + U64 => Elements::Lt64, + _ => return Err(Error::ExpressionUnexpectedType(arg_ty.clone())), + }; + // Gt(a, b) = Lt(b, a) → swap_args=true for Gt + Ok((jet, matches!(op, Gt), false)) + } + Le | Ge => { + let jet = match uint_ty { + U8 => Elements::Le8, + U16 => Elements::Le16, + U32 => Elements::Le32, + U64 => Elements::Le64, + _ => return Err(Error::ExpressionUnexpectedType(arg_ty.clone())), + }; + // Ge(a, b) = Le(b, a) → swap_args=true for Ge + Ok((jet, matches!(op, Ge), false)) + } + _ => Err(Error::ExpressionUnexpectedType(arg_ty.clone())), + } +} + +/// Maps an infix operator and the expected output type to the corresponding Simplicity jet, +/// the expected input argument type, and whether the jet's raw output is `(bool, uN)` +/// and requires a carry/overflow assertion. +/// +/// Returns `Err` if there is no jet for the given operator + output type combination. +/// +/// | Operator | Output type | Jet | Input type | Assert no carry | +/// |----------|-------------|-------------|------------|-----------------| +/// | `+` | `uN` | `AddN` | `uN` | yes | +/// | `-` | `uN` | `SubtractN` | `uN` | yes | +/// | `*` | `u(2N)` | `MultiplyN` | `uN` | no | +/// | `/` | `uN` | `DivideN` | `uN` | no | +/// | `%` | `uN` | `ModuloN` | `uN` | no | +/// Maps a bitwise infix operator and operand type to the corresponding Simplicity jet. +/// +/// | Operator | Output type | Jet | Input type | +/// |----------|-------------|--------|------------| +/// | `&` | `uN` | `AndN` | `uN` | +/// | `\|` | `uN` | `OrN` | `uN` | +/// | `^` | `uN` | `XorN` | `uN` | +fn determine_infix_bitwise_op_jet( + op: &parse::InfixOp, + ty: &ResolvedType, +) -> Result { + use parse::InfixOp::*; + use UIntType::*; + + let uint_ty = ty + .as_integer() + .ok_or_else(|| Error::ExpressionUnexpectedType(ty.clone()))?; + + match (op, uint_ty) { + (BitAnd, U1) => Ok(Elements::And1), + (BitAnd, U8) => Ok(Elements::And8), + (BitAnd, U16) => Ok(Elements::And16), + (BitAnd, U32) => Ok(Elements::And32), + (BitAnd, U64) => Ok(Elements::And64), + (BitOr, U1) => Ok(Elements::Or1), + (BitOr, U8) => Ok(Elements::Or8), + (BitOr, U16) => Ok(Elements::Or16), + (BitOr, U32) => Ok(Elements::Or32), + (BitOr, U64) => Ok(Elements::Or64), + (BitXor, U1) => Ok(Elements::Xor1), + (BitXor, U8) => Ok(Elements::Xor8), + (BitXor, U16) => Ok(Elements::Xor16), + (BitXor, U32) => Ok(Elements::Xor32), + (BitXor, U64) => Ok(Elements::Xor64), + _ => Err(Error::ExpressionUnexpectedType(ty.clone())), + } +} + +fn determine_infix_arith_op_jet( + op: &parse::InfixOp, + ty: &ResolvedType, +) -> Result<(Elements, ResolvedType, bool), Error> { + use parse::InfixOp::*; + use UIntType::*; + + match op { + // Add and Sub: jet produces (bool, uN); assert no carry, then return uN + Add | Sub => { + let uint_ty = ty + .as_integer() + .ok_or_else(|| Error::ExpressionUnexpectedType(ty.clone()))?; + + let jet = match (op, uint_ty) { + (Add, U8) => Elements::Add8, + (Add, U16) => Elements::Add16, + (Add, U32) => Elements::Add32, + (Add, U64) => Elements::Add64, + (Sub, U8) => Elements::Subtract8, + (Sub, U16) => Elements::Subtract16, + (Sub, U32) => Elements::Subtract32, + (Sub, U64) => Elements::Subtract64, + _ => return Err(Error::ExpressionUnexpectedType(ty.clone())), + }; + + Ok((jet, ResolvedType::from(uint_ty), true)) + } + // Mul: jet takes (uN, uN) and produces u(2N), no overflow possible + Mul => { + let (jet, arg_uint) = match ty.as_integer() { + Some(U16) => (Elements::Multiply8, U8), + Some(U32) => (Elements::Multiply16, U16), + Some(U64) => (Elements::Multiply32, U32), + Some(U128) => (Elements::Multiply64, U64), + _ => return Err(Error::ExpressionUnexpectedType(ty.clone())), + }; + Ok((jet, ResolvedType::from(arg_uint), false)) + } + // Div and Rem: jet takes (uN, uN) and produces uN, no overflow flag + Div | Rem => { + let uint_ty = ty + .as_integer() + .ok_or_else(|| Error::ExpressionUnexpectedType(ty.clone()))?; + + let jet = match (op, uint_ty) { + (Div, U8) => Elements::Divide8, + (Div, U16) => Elements::Divide16, + (Div, U32) => Elements::Divide32, + (Div, U64) => Elements::Divide64, + (Rem, U8) => Elements::Modulo8, + (Rem, U16) => Elements::Modulo16, + (Rem, U32) => Elements::Modulo32, + (Rem, U64) => Elements::Modulo64, + _ => return Err(Error::ExpressionUnexpectedType(ty.clone())), + }; + Ok((jet, ResolvedType::from(uint_ty), false)) + } + // Comparison ops are handled by comparison_op_jet, not here + _ => Err(Error::ExpressionUnexpectedType(ty.clone())), + } +} + impl AbstractSyntaxTree for SingleExpression { type From = parse::SingleExpression; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { - let inner = match from.inner() { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { + let (inner, warnings) = match from.inner() { parse::SingleExpressionInner::Boolean(bit) => { if !ty.is_boolean() { return Err(Error::ExpressionTypeMismatch( @@ -933,17 +1251,20 @@ impl AbstractSyntaxTree for SingleExpression { )) .with_span(from); } - SingleExpressionInner::Constant(Value::from(*bit)) + (SingleExpressionInner::Constant(Value::from(*bit)), vec![]) } parse::SingleExpressionInner::Decimal(decimal) => { let ty = ty .as_integer() .ok_or(Error::ExpressionUnexpectedType(ty.clone())) .with_span(from)?; - UIntValue::parse_decimal(decimal, ty) - .with_span(from) - .map(Value::from) - .map(SingleExpressionInner::Constant)? + ( + UIntValue::parse_decimal(decimal, ty) + .with_span(from) + .map(Value::from) + .map(SingleExpressionInner::Constant)?, + vec![], + ) } parse::SingleExpressionInner::Binary(bits) => { let ty = ty @@ -951,23 +1272,26 @@ impl AbstractSyntaxTree for SingleExpression { .ok_or(Error::ExpressionUnexpectedType(ty.clone())) .with_span(from)?; let value = UIntValue::parse_binary(bits, ty).with_span(from)?; - SingleExpressionInner::Constant(Value::from(value)) + (SingleExpressionInner::Constant(Value::from(value)), vec![]) } parse::SingleExpressionInner::Hexadecimal(bytes) => { let value = Value::parse_hexadecimal(bytes, ty).with_span(from)?; - SingleExpressionInner::Constant(value) + (SingleExpressionInner::Constant(value), vec![]) } parse::SingleExpressionInner::Witness(name) => { scope .insert_witness(name.clone(), ty.clone()) .with_span(from)?; - SingleExpressionInner::Witness(name.clone()) + (SingleExpressionInner::Witness(name.clone()), vec![]) } parse::SingleExpressionInner::Parameter(name) => { scope .insert_parameter(name.shallow_clone(), ty.clone()) .with_span(from)?; - SingleExpressionInner::Parameter(name.shallow_clone()) + ( + SingleExpressionInner::Parameter(name.shallow_clone()), + vec![], + ) } parse::SingleExpressionInner::Variable(identifier) => { let bound_ty = scope @@ -979,12 +1303,12 @@ impl AbstractSyntaxTree for SingleExpression { .with_span(from); } scope.insert_variable(identifier.clone(), ty.clone()); - SingleExpressionInner::Variable(identifier.clone()) + (SingleExpressionInner::Variable(identifier.clone()), vec![]) } parse::SingleExpressionInner::Expression(parse) => { - Expression::analyze(parse, ty, scope) - .map(Arc::new) - .map(SingleExpressionInner::Expression)? + Expression::analyze(parse, ty, scope).map(|(e, warnings)| { + (SingleExpressionInner::Expression(Arc::new(e)), warnings) + })? } parse::SingleExpressionInner::Tuple(tuple) => { let types = ty @@ -994,12 +1318,23 @@ impl AbstractSyntaxTree for SingleExpression { if tuple.len() != types.len() { return Err(Error::ExpressionUnexpectedType(ty.clone())).with_span(from); } - tuple + let inner: Vec<(Expression, Vec)> = tuple .iter() .zip(types.iter()) .map(|(el_parse, el_ty)| Expression::analyze(el_parse, el_ty, scope)) - .collect::, RichError>>() - .map(SingleExpressionInner::Tuple)? + .collect::>()?; + + let mut all_warnings = vec![]; + let mut all_expressions: Vec = vec![]; + for i in inner { + all_warnings.extend_from_slice(&i.1); + all_expressions.push(i.0); + } + + ( + SingleExpressionInner::Tuple(all_expressions.into()), + all_warnings, + ) } parse::SingleExpressionInner::Array(array) => { let (el_ty, size) = ty @@ -1009,11 +1344,22 @@ impl AbstractSyntaxTree for SingleExpression { if array.len() != size { return Err(Error::ExpressionUnexpectedType(ty.clone())).with_span(from); } - array + let inner: Vec<(Expression, Vec)> = array .iter() .map(|el_parse| Expression::analyze(el_parse, el_ty, scope)) - .collect::, RichError>>() - .map(SingleExpressionInner::Array)? + .collect::>()?; + + let mut all_warnings = vec![]; + let mut all_expressions: Vec = vec![]; + for i in inner { + all_warnings.extend_from_slice(&i.1); + all_expressions.push(i.0); + } + + ( + SingleExpressionInner::Array(all_expressions.into()), + all_warnings, + ) } parse::SingleExpressionInner::List(list) => { let (el_ty, bound) = ty @@ -1023,10 +1369,22 @@ impl AbstractSyntaxTree for SingleExpression { if bound.get() <= list.len() { return Err(Error::ExpressionUnexpectedType(ty.clone())).with_span(from); } - list.iter() - .map(|e| Expression::analyze(e, el_ty, scope)) - .collect::, RichError>>() - .map(SingleExpressionInner::List)? + let inner: Vec<(Expression, Vec)> = list + .iter() + .map(|el_parse| Expression::analyze(el_parse, el_ty, scope)) + .collect::>()?; + + let mut all_warnings = vec![]; + let mut all_expressions: Vec = vec![]; + for i in inner { + all_warnings.extend_from_slice(&i.1); + all_expressions.push(i.0); + } + + ( + SingleExpressionInner::List(all_expressions.into()), + all_warnings, + ) } parse::SingleExpressionInner::Either(either) => { let (ty_l, ty_r) = ty @@ -1035,13 +1393,11 @@ impl AbstractSyntaxTree for SingleExpression { .with_span(from)?; match either { Either::Left(parse_l) => Expression::analyze(parse_l, ty_l, scope) - .map(Arc::new) - .map(Either::Left), + .map(|(l, warnings)| (Either::Left(Arc::new(l)), warnings)), Either::Right(parse_r) => Expression::analyze(parse_r, ty_r, scope) - .map(Arc::new) - .map(Either::Right), + .map(|(r, warnings)| (Either::Right(Arc::new(r)), warnings)), } - .map(SingleExpressionInner::Either)? + .map(|(e, warnings)| (SingleExpressionInner::Either(e), warnings))? } parse::SingleExpressionInner::Option(maybe_parse) => { let ty = ty @@ -1049,33 +1405,180 @@ impl AbstractSyntaxTree for SingleExpression { .ok_or(Error::ExpressionUnexpectedType(ty.clone())) .with_span(from)?; match maybe_parse { - Some(parse) => { - Some(Expression::analyze(parse, ty, scope).map(Arc::new)).transpose() - } - None => Ok(None), + Some(parse) => Expression::analyze(parse, ty, scope) + .map(|(e, warnings)| (Some(Arc::new(e)), warnings)), + None => Ok((None, vec![])), } - .map(SingleExpressionInner::Option)? - } - parse::SingleExpressionInner::Call(call) => { - Call::analyze(call, ty, scope).map(SingleExpressionInner::Call)? + .map(|(o, warnings)| (SingleExpressionInner::Option(o), warnings))? } - parse::SingleExpressionInner::Match(match_) => { - Match::analyze(match_, ty, scope).map(SingleExpressionInner::Match)? + parse::SingleExpressionInner::Call(call) => Call::analyze(call, ty, scope) + .map(|(c, warnings)| (SingleExpressionInner::Call(c), warnings))?, + parse::SingleExpressionInner::Match(match_) => Match::analyze(match_, ty, scope) + .map(|(m, warnings)| (SingleExpressionInner::Match(m), warnings))?, + parse::SingleExpressionInner::BinaryOp(binary) => { + use parse::InfixOp::*; + match binary.op() { + Eq | Ne | Lt | Le | Gt | Ge => { + // Comparison operators: output type must be bool + if !ty.is_boolean() { + return Err(Error::ExpressionUnexpectedType(ty.clone())) + .with_span(from); + } + // Infer operand type from lhs expression + let arg_ty = peek_expression_type(binary.lhs(), scope) + .ok_or(Error::ExpressionUnexpectedType(ty.clone())) + .with_span(from)?; + let (jet, swap_args, negate_result) = + determine_comparison_op_jet(binary.op(), &arg_ty).with_span(from)?; + let (lhs, mut lhs_warnings) = + Expression::analyze(binary.lhs(), &arg_ty, scope)?; + let (rhs, mut rhs_warnings) = + Expression::analyze(binary.rhs(), &arg_ty, scope)?; + lhs_warnings.append(&mut rhs_warnings); + scope.track_call(from, TrackedCallName::Jet); + ( + SingleExpressionInner::BinaryOp { + jet, + lhs: Arc::new(lhs), + rhs: Arc::new(rhs), + assert_no_carry: false, + swap_args, + negate_result, + check_nonzero_divisor: false, + }, + lhs_warnings, + ) + } + LogicalAnd | LogicalOr => { + // Desugar to a match so the Simplicity `case` combinator + // provides natural short-circuit evaluation — only the + // taken branch is evaluated. + // + // a && b => match a { false => false, true => b } + // a || b => match a { false => b, true => true } + if !ty.is_boolean() { + return Err(Error::ExpressionUnexpectedType(ty.clone())) + .with_span(from); + } + let bool_ty = ResolvedType::boolean(); + let span = *from.as_ref(); + let (lhs, mut lhs_warnings) = + Expression::analyze(binary.lhs(), &bool_ty, scope)?; + let (rhs, mut rhs_warnings) = + Expression::analyze(binary.rhs(), &bool_ty, scope)?; + lhs_warnings.append(&mut rhs_warnings); + + let make_bool = |b: bool| -> Expression { + Expression { + inner: ExpressionInner::Single(SingleExpression { + inner: SingleExpressionInner::Constant(Value::from(b)), + ty: bool_ty.clone(), + span, + }), + ty: bool_ty.clone(), + span, + } + }; + + let (left_expr, right_expr) = match binary.op() { + LogicalAnd => (make_bool(false), rhs), + LogicalOr => (rhs, make_bool(true)), + _ => unreachable!(), + }; + ( + SingleExpressionInner::Match(Match { + scrutinee: Arc::new(lhs), + left: MatchArm { + pattern: MatchPattern::False, + expression: Arc::new(left_expr), + }, + right: MatchArm { + pattern: MatchPattern::True, + expression: Arc::new(right_expr), + }, + span, + }), + lhs_warnings, + ) + } + BitAnd | BitOr | BitXor => { + // Bitwise operators: same input and output type, no carry or overflow + let jet = + determine_infix_bitwise_op_jet(binary.op(), ty).with_span(from)?; + let (lhs, mut lhs_warnings) = + Expression::analyze(binary.lhs(), ty, scope)?; + let (rhs, mut rhs_warnings) = + Expression::analyze(binary.rhs(), ty, scope)?; + lhs_warnings.append(&mut rhs_warnings); + scope.track_call(from, TrackedCallName::Jet); + ( + SingleExpressionInner::BinaryOp { + jet, + lhs: Arc::new(lhs), + rhs: Arc::new(rhs), + assert_no_carry: false, + swap_args: false, + negate_result: false, + check_nonzero_divisor: false, + }, + lhs_warnings, + ) + } + Add | Sub | Mul | Div | Rem => { + // Arithmetic operators + let (jet, arg_ty, assert_no_carry) = + determine_infix_arith_op_jet(binary.op(), ty).with_span(from)?; + let (lhs, mut lhs_warnings) = + Expression::analyze(binary.lhs(), &arg_ty, scope)?; + let (rhs, mut rhs_warnings) = + Expression::analyze(binary.rhs(), &arg_ty, scope)?; + lhs_warnings.append(&mut rhs_warnings); + if assert_no_carry { + lhs_warnings.push(Warning::arthimetic_operation_could_overflow(&from)) + } + if matches!(binary.op(), parse::InfixOp::Div | parse::InfixOp::Rem) { + lhs_warnings.push(Warning::division_could_panic_on_zero(&from)) + } + scope.track_call(from, TrackedCallName::Jet); + ( + SingleExpressionInner::BinaryOp { + jet, + lhs: Arc::new(lhs), + rhs: Arc::new(rhs), + assert_no_carry, + swap_args: false, + negate_result: false, + check_nonzero_divisor: matches!( + binary.op(), + parse::InfixOp::Div | parse::InfixOp::Rem + ), + }, + lhs_warnings, + ) + } + } } }; - Ok(Self { - inner, - ty: ty.clone(), - span: *from.as_ref(), - }) + Ok(( + Self { + inner, + ty: ty.clone(), + span: *from.as_ref(), + }, + warnings, + )) } } impl AbstractSyntaxTree for Call { type From = parse::Call; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { fn check_argument_types( parse_args: &[parse::Expression], expected_tys: &[ResolvedType], @@ -1108,17 +1611,23 @@ impl AbstractSyntaxTree for Call { parse_args: &[parse::Expression], args_tys: &[ResolvedType], scope: &mut Scope, - ) -> Result, RichError> { - let args = parse_args + ) -> Result<(Arc<[Expression]>, Vec), RichError> { + let args: Vec<(Expression, Vec)> = parse_args .iter() .zip(args_tys.iter()) .map(|(arg_parse, arg_ty)| Expression::analyze(arg_parse, arg_ty, scope)) - .collect::, RichError>>()?; - Ok(args) + .collect::>()?; + let mut all_warns = vec![]; + let mut all_args = vec![]; + for (e, mut w) in args { + all_warns.append(&mut w); + all_args.push(e); + } + Ok((all_args.into(), all_warns)) } - let name = CallName::analyze(from, ty, scope)?; - let args = match name.clone() { + let (name, name_warnings) = CallName::analyze(from, ty, scope)?; + let (args, mut args_warnings) = match name.clone() { CallName::Jet(jet) => { let args_tys = crate::jet::source_type(jet) .iter() @@ -1275,11 +1784,16 @@ impl AbstractSyntaxTree for Call { } }; - Ok(Self { - name, - args, - span: *from.as_ref(), - }) + let mut all_warnings = name_warnings; + all_warnings.append(&mut args_warnings); + Ok(( + Self { + name, + args, + span: *from.as_ref(), + }, + all_warnings, + )) } } @@ -1291,36 +1805,38 @@ impl AbstractSyntaxTree for CallName { from: &Self::From, _ty: &ResolvedType, scope: &mut Scope, - ) -> Result { + ) -> Result<(Self, Vec), RichError> { match from.name() { parse::CallName::Jet(name) => match Elements::from_str(name.as_inner()) { Ok(Elements::CheckSigVerify | Elements::Verify) | Err(_) => { Err(Error::JetDoesNotExist(name.clone())).with_span(from) } - Ok(jet) => Ok(Self::Jet(jet)), + Ok(jet) => Ok((Self::Jet(jet), vec![])), }, parse::CallName::UnwrapLeft(right_ty) => scope .resolve(right_ty) - .map(Self::UnwrapLeft) + .map(|c| (Self::UnwrapLeft(c), vec![])) .with_span(from), parse::CallName::UnwrapRight(left_ty) => scope .resolve(left_ty) - .map(Self::UnwrapRight) + .map(|c| (Self::UnwrapRight(c), vec![])) + .with_span(from), + parse::CallName::IsNone(some_ty) => scope + .resolve(some_ty) + .map(|c| (Self::IsNone(c), vec![])) + .with_span(from), + parse::CallName::Unwrap => Ok((Self::Unwrap, vec![])), + parse::CallName::Assert => Ok((Self::Assert, vec![])), + parse::CallName::Panic => Ok((Self::Panic, vec![])), + parse::CallName::Debug => Ok((Self::Debug, vec![])), + parse::CallName::TypeCast(target) => scope + .resolve(target) + .map(|c| (Self::TypeCast(c), vec![])) .with_span(from), - parse::CallName::IsNone(some_ty) => { - scope.resolve(some_ty).map(Self::IsNone).with_span(from) - } - parse::CallName::Unwrap => Ok(Self::Unwrap), - parse::CallName::Assert => Ok(Self::Assert), - parse::CallName::Panic => Ok(Self::Panic), - parse::CallName::Debug => Ok(Self::Debug), - parse::CallName::TypeCast(target) => { - scope.resolve(target).map(Self::TypeCast).with_span(from) - } parse::CallName::Custom(name) => scope .get_function(name) .cloned() - .map(Self::Custom) + .map(|c| (Self::Custom(c), vec![])) .ok_or(Error::FunctionUndefined(name.clone())) .with_span(from), parse::CallName::ArrayFold(name, size) => { @@ -1335,7 +1851,7 @@ impl AbstractSyntaxTree for CallName { { Err(Error::FunctionNotFoldable(name.clone())).with_span(from) } else { - Ok(Self::ArrayFold(function, *size)) + Ok((Self::ArrayFold(function, *size), vec![])) } } parse::CallName::Fold(name, bound) => { @@ -1350,7 +1866,7 @@ impl AbstractSyntaxTree for CallName { { Err(Error::FunctionNotFoldable(name.clone())).with_span(from) } else { - Ok(Self::Fold(function, *bound)) + Ok((Self::Fold(function, *bound), vec![])) } } parse::CallName::ForWhile(name) => { @@ -1382,7 +1898,7 @@ impl AbstractSyntaxTree for CallName { | UIntType::U4 | UIntType::U8 | UIntType::U16), - ) => Ok(Self::ForWhile(function, int_ty.bit_width())), + ) => Ok((Self::ForWhile(function, int_ty.bit_width()), vec![])), _ => Err(Error::FunctionNotLoopable(name.clone())).with_span(from), } } @@ -1393,60 +1909,73 @@ impl AbstractSyntaxTree for CallName { impl AbstractSyntaxTree for Match { type From = parse::Match; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { let scrutinee_ty = from.scrutinee_type(); let scrutinee_ty = scope.resolve(&scrutinee_ty).with_span(from)?; - let scrutinee = - Expression::analyze(from.scrutinee(), &scrutinee_ty, scope).map(Arc::new)?; + let (scrutinee, scrutinee_warnings) = + Expression::analyze(from.scrutinee(), &scrutinee_ty, scope)?; scope.push_scope(); if let Some((id_l, ty_l)) = from.left().pattern().as_typed_variable() { let ty_l = scope.resolve(ty_l).with_span(from)?; scope.insert_variable(id_l.clone(), ty_l); } - let ast_l = Expression::analyze(from.left().expression(), ty, scope).map(Arc::new)?; + let (ast_l, mut ast_l_warnings) = Expression::analyze(from.left().expression(), ty, scope)?; scope.pop_scope(); scope.push_scope(); if let Some((id_r, ty_r)) = from.right().pattern().as_typed_variable() { let ty_r = scope.resolve(ty_r).with_span(from)?; scope.insert_variable(id_r.clone(), ty_r); } - let ast_r = Expression::analyze(from.right().expression(), ty, scope).map(Arc::new)?; + let (ast_r, mut ast_r_warnings) = + Expression::analyze(from.right().expression(), ty, scope)?; scope.pop_scope(); - Ok(Self { - scrutinee, - left: MatchArm { - pattern: from.left().pattern().clone(), - expression: ast_l, + let mut all_warnings = scrutinee_warnings; + all_warnings.append(&mut ast_l_warnings); + all_warnings.append(&mut ast_r_warnings); + + Ok(( + Self { + scrutinee: Arc::new(scrutinee), + left: MatchArm { + pattern: from.left().pattern().clone(), + expression: Arc::new(ast_l), + }, + right: MatchArm { + pattern: from.right().pattern().clone(), + expression: Arc::new(ast_r), + }, + span: *from.as_ref(), }, - right: MatchArm { - pattern: from.right().pattern().clone(), - expression: ast_r, - }, - span: *from.as_ref(), - }) + all_warnings, + )) } } fn analyze_named_module( name: ModuleName, from: &parse::ModuleProgram, -) -> Result, RichError> { +) -> Result<(HashMap, Vec), RichError> { let unit = ResolvedType::unit(); let mut scope = Scope::default(); - let items = from + let items: Vec<(ModuleItem, Vec)> = from .items() .iter() .map(|s| ModuleItem::analyze(s, &unit, &mut scope)) - .collect::, RichError>>()?; + .collect::>()?; + debug_assert!(scope.is_topmost()); - let mut iter = items.into_iter().filter_map(|item| match item { - ModuleItem::Module(module) if module.name == name => Some(module), + let mut iter = items.into_iter().filter_map(|(item, warnings)| match item { + ModuleItem::Module(module) if module.name == name => Some((module, warnings)), _ => None, }); - let Some(witness_module) = iter.next() else { - return Ok(HashMap::new()); // "not present" is equivalent to empty + let Some((witness_module, warnings)) = iter.next() else { + return Ok((HashMap::new(), vec![])); // "not present" is equivalent to empty }; if iter.next().is_some() { return Err(Error::ModuleRedefined(name)).with_span(from); @@ -1462,32 +1991,43 @@ fn analyze_named_module( assignment.value().clone(), ); } - Ok(map) + Ok((map, warnings)) } impl WitnessValues { - pub fn analyze(from: &parse::ModuleProgram) -> Result { - analyze_named_module(ModuleName::witness(), from).map(Self::from) + pub fn analyze(from: &parse::ModuleProgram) -> Result<(Self, Vec), RichError> { + analyze_named_module(ModuleName::witness(), from) + .map(|(i, warnings)| (Self::from(i), warnings)) } } impl crate::witness::Arguments { - pub fn analyze(from: &parse::ModuleProgram) -> Result { - analyze_named_module(ModuleName::param(), from).map(Self::from) + pub fn analyze(from: &parse::ModuleProgram) -> Result<(Self, Vec), RichError> { + analyze_named_module(ModuleName::param(), from).map(|(i, warning)| (Self::from(i), warning)) } } impl AbstractSyntaxTree for ModuleItem { type From = parse::ModuleItem; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { assert!(ty.is_unit(), "Items cannot return anything"); assert!(scope.is_topmost(), "Items live in the topmost scope only"); match from { - parse::ModuleItem::Ignored => Ok(Self::Ignored), - parse::ModuleItem::Module(witness_module) => { - Module::analyze(witness_module, ty, scope).map(Self::Module) + parse::ModuleItem::Ignored => { + // TODO: confirm if this is a warning. + // TODO: find the correct span + Ok(( + Self::Ignored, + vec![Warning::module_item_ignored(Span::new(0, 0))], + )) } + parse::ModuleItem::Module(witness_module) => Module::analyze(witness_module, ty, scope) + .map(|(m, warning)| (Self::Module(m), warning)), } } } @@ -1495,40 +2035,61 @@ impl AbstractSyntaxTree for ModuleItem { impl AbstractSyntaxTree for Module { type From = parse::Module; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { assert!(ty.is_unit(), "Modules cannot return anything"); assert!(scope.is_topmost(), "Modules live in the topmost scope only"); - let assignments = from + let assignments: Vec<(ModuleAssignment, Vec)> = from .assignments() .iter() .map(|s| ModuleAssignment::analyze(s, ty, scope)) - .collect::, RichError>>()?; + .collect::>()?; debug_assert!(scope.is_topmost()); - Ok(Self { - name: from.name().shallow_clone(), - span: *from.as_ref(), - assignments, - }) + let mut all_warnings = vec![]; + let mut all_assignments = vec![]; + for (a, mut warnings) in assignments { + all_warnings.append(&mut warnings); + all_assignments.push(a); + } + + Ok(( + Self { + name: from.name().shallow_clone(), + span: *from.as_ref(), + assignments: all_assignments.into(), + }, + all_warnings, + )) } } impl AbstractSyntaxTree for ModuleAssignment { type From = parse::ModuleAssignment; - fn analyze(from: &Self::From, ty: &ResolvedType, scope: &mut Scope) -> Result { + fn analyze( + from: &Self::From, + ty: &ResolvedType, + scope: &mut Scope, + ) -> Result<(Self, Vec), RichError> { assert!(ty.is_unit(), "Assignments cannot return anything"); let ty_expr = scope.resolve(from.ty()).with_span(from)?; - let expression = Expression::analyze(from.expression(), &ty_expr, scope)?; + let (expression, warnings) = Expression::analyze(from.expression(), &ty_expr, scope)?; let value = Value::from_const_expr(&expression) .ok_or(Error::ExpressionUnexpectedType(ty_expr.clone())) .with_span(from.expression())?; - Ok(Self { - name: from.name().clone(), - value, - span: *from.as_ref(), - }) + Ok(( + Self { + name: from.name().clone(), + value, + span: *from.as_ref(), + }, + warnings, + )) } } @@ -1573,3 +2134,248 @@ impl AsRef for ModuleAssignment { &self.span } } + +#[cfg(test)] +mod tests { + use simplicity::jet::Elements; + + use crate::parse::InfixOp; + use crate::types::{ResolvedType, TypeConstructible}; + + use super::determine_infix_arith_op_jet; + + // --- infix_op_jet: Add --- + + #[test] + fn infix_op_jet_add_u8() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Add, &ResolvedType::u8()).unwrap(); + assert_eq!(jet, Elements::Add8); + assert_eq!(arg, ResolvedType::u8()); + assert!(carry); + } + + #[test] + fn infix_op_jet_add_u16() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Add, &ResolvedType::u16()).unwrap(); + assert_eq!(jet, Elements::Add16); + assert_eq!(arg, ResolvedType::u16()); + assert!(carry); + } + + #[test] + fn infix_op_jet_add_u32() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Add, &ResolvedType::u32()).unwrap(); + assert_eq!(jet, Elements::Add32); + assert_eq!(arg, ResolvedType::u32()); + assert!(carry); + } + + #[test] + fn infix_op_jet_add_u64() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Add, &ResolvedType::u64()).unwrap(); + assert_eq!(jet, Elements::Add64); + assert_eq!(arg, ResolvedType::u64()); + assert!(carry); + } + + // --- infix_op_jet: Sub --- + + #[test] + fn infix_op_jet_sub_u8() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Sub, &ResolvedType::u8()).unwrap(); + assert_eq!(jet, Elements::Subtract8); + assert_eq!(arg, ResolvedType::u8()); + assert!(carry); + } + + #[test] + fn infix_op_jet_sub_u16() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Sub, &ResolvedType::u16()).unwrap(); + assert_eq!(jet, Elements::Subtract16); + assert_eq!(arg, ResolvedType::u16()); + assert!(carry); + } + + #[test] + fn infix_op_jet_sub_u32() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Sub, &ResolvedType::u32()).unwrap(); + assert_eq!(jet, Elements::Subtract32); + assert_eq!(arg, ResolvedType::u32()); + assert!(carry); + } + + #[test] + fn infix_op_jet_sub_u64() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Sub, &ResolvedType::u64()).unwrap(); + assert_eq!(jet, Elements::Subtract64); + assert_eq!(arg, ResolvedType::u64()); + assert!(carry); + } + + // --- infix_op_jet: Mul --- + + #[test] + fn infix_op_jet_mul_u16() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Mul, &ResolvedType::u16()).unwrap(); + assert_eq!(jet, Elements::Multiply8); + assert_eq!(arg, ResolvedType::u8()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_mul_u32() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Mul, &ResolvedType::u32()).unwrap(); + assert_eq!(jet, Elements::Multiply16); + assert_eq!(arg, ResolvedType::u16()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_mul_u64() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Mul, &ResolvedType::u64()).unwrap(); + assert_eq!(jet, Elements::Multiply32); + assert_eq!(arg, ResolvedType::u32()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_mul_u128() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Mul, &ResolvedType::u128()).unwrap(); + assert_eq!(jet, Elements::Multiply64); + assert_eq!(arg, ResolvedType::u64()); + assert!(!carry); + } + + // --- infix_op_jet: Div --- + + #[test] + fn infix_op_jet_div_u8() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Div, &ResolvedType::u8()).unwrap(); + assert_eq!(jet, Elements::Divide8); + assert_eq!(arg, ResolvedType::u8()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_div_u16() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Div, &ResolvedType::u16()).unwrap(); + assert_eq!(jet, Elements::Divide16); + assert_eq!(arg, ResolvedType::u16()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_div_u32() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Div, &ResolvedType::u32()).unwrap(); + assert_eq!(jet, Elements::Divide32); + assert_eq!(arg, ResolvedType::u32()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_div_u64() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Div, &ResolvedType::u64()).unwrap(); + assert_eq!(jet, Elements::Divide64); + assert_eq!(arg, ResolvedType::u64()); + assert!(!carry); + } + + // --- infix_op_jet: Rem --- + + #[test] + fn infix_op_jet_rem_u8() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Rem, &ResolvedType::u8()).unwrap(); + assert_eq!(jet, Elements::Modulo8); + assert_eq!(arg, ResolvedType::u8()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_rem_u16() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Rem, &ResolvedType::u16()).unwrap(); + assert_eq!(jet, Elements::Modulo16); + assert_eq!(arg, ResolvedType::u16()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_rem_u32() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Rem, &ResolvedType::u32()).unwrap(); + assert_eq!(jet, Elements::Modulo32); + assert_eq!(arg, ResolvedType::u32()); + assert!(!carry); + } + + #[test] + fn infix_op_jet_rem_u64() { + let (jet, arg, carry) = + determine_infix_arith_op_jet(&InfixOp::Rem, &ResolvedType::u64()).unwrap(); + assert_eq!(jet, Elements::Modulo64); + assert_eq!(arg, ResolvedType::u64()); + assert!(!carry); + } + + // --- infix_op_jet: error cases --- + + #[test] + fn infix_op_jet_add_wrong_type_unit() { + let result = determine_infix_arith_op_jet(&InfixOp::Add, &ResolvedType::unit()); + assert!( + result.is_err(), + "Expected Err for Add with unit output type" + ); + } + + #[test] + fn infix_op_jet_sub_wrong_type_unit() { + let result = determine_infix_arith_op_jet(&InfixOp::Sub, &ResolvedType::unit()); + assert!( + result.is_err(), + "Expected Err for Sub with unit output type" + ); + } + + #[test] + fn infix_op_jet_mul_wrong_type_u8() { + // `*` on u8 inputs would produce u16, so u8 as output type is wrong + let result = determine_infix_arith_op_jet(&InfixOp::Mul, &ResolvedType::u8()); + assert!(result.is_err(), "Expected Err for Mul with u8 output type"); + } + + #[test] + fn infix_op_jet_div_wrong_type_bool() { + let result = determine_infix_arith_op_jet(&InfixOp::Div, &ResolvedType::boolean()); + assert!( + result.is_err(), + "Expected Err for Div with bool output type" + ); + } + + #[test] + fn infix_op_jet_rem_wrong_type_bool() { + let result = determine_infix_arith_op_jet(&InfixOp::Rem, &ResolvedType::boolean()); + assert!( + result.is_err(), + "Expected Err for Rem with bool output type" + ); + } +} diff --git a/src/compile/mod.rs b/src/compile/mod.rs index 2af17e6f..5f1fc930 100644 --- a/src/compile/mod.rs +++ b/src/compile/mod.rs @@ -222,6 +222,18 @@ impl<'brand> Scope<'brand> { } } +/// Returns the equality jet and a zero `simplicity::Value` for the operand type +/// of a divide or modulo jet. Used to generate the divisor-zero check. +fn divisor_check_jet_and_zero_for_binary_op(jet: Elements) -> (Elements, simplicity::Value) { + match jet { + Elements::Divide8 | Elements::Modulo8 => (Elements::Eq8, simplicity::Value::u8(0)), + Elements::Divide16 | Elements::Modulo16 => (Elements::Eq16, simplicity::Value::u16(0)), + Elements::Divide32 | Elements::Modulo32 => (Elements::Eq32, simplicity::Value::u32(0)), + Elements::Divide64 | Elements::Modulo64 => (Elements::Eq64, simplicity::Value::u64(0)), + _ => unreachable!("divisor_check_jet_and_zero called on non-divide/modulo jet"), + } +} + fn compile_blk<'brand>( stmts: &[Statement], scope: &mut Scope<'brand>, @@ -355,6 +367,77 @@ impl SingleExpression { } SingleExpressionInner::Call(call) => call.compile(scope)?, SingleExpressionInner::Match(match_) => match_.compile(scope)?, + SingleExpressionInner::BinaryOp { + jet, + lhs, + rhs, + assert_no_carry, + swap_args, + negate_result, + check_nonzero_divisor, + } => { + let args = if *swap_args { + rhs.compile(scope)?.pair(lhs.compile(scope)?) + } else { + lhs.compile(scope)?.pair(rhs.compile(scope)?) + }; + let jet_node = ProgNode::jet(scope.ctx(), *jet); + let args = if *check_nonzero_divisor { + // Emit a divisor-zero check before the divide/modulo jet. + // If rhs == 0 the program panics; otherwise the args pass through unchanged. + // + // args : Input → (lhs, rhs) + // + // Step 1: check_rhs_eq_zero : (lhs, rhs) → (1+1) + // = pair(drop(iden), unit_scribe(0)) >>> EqN + let (eq_jet_elem, zero_sv) = divisor_check_jet_and_zero_for_binary_op(*jet); + // drop(iden) : (lhs, rhs) → rhs + let rhs_from_pair = ProgNode::drop_(&ProgNode::iden(scope.ctx())); + // unit_scribe(0) : (lhs, rhs) → 0 + let zero_scribe = ProgNode::unit_scribe(scope.ctx(), &zero_sv); + // pair(rhs_from_pair, zero_scribe) : (lhs, rhs) → (rhs, 0) + let rhs_zero_pair = ProgNode::pair(&rhs_from_pair, &zero_scribe).unwrap(); + // >>> EqN : (rhs, 0) → (1+1) [right = equal, left = not-equal] + let eq_node = ProgNode::jet(scope.ctx(), eq_jet_elem); + let check_rhs_eq_zero = ProgNode::comp(&rhs_zero_pair, &eq_node).unwrap(); + + // Step 2: pair(check, iden) : (lhs, rhs) → ((1+1), (lhs, rhs)) + let check_and_iden = + ProgNode::pair(&check_rhs_eq_zero, &ProgNode::iden(scope.ctx())).unwrap(); + + // Step 3: assertl(drop(iden), fail) : ((1+1), (lhs, rhs)) → (lhs, rhs) + // asserts eq result is left (false = not-equal, i.e. divisor != 0) + // passes through (lhs, rhs) via drop(iden) + let assert_nonzero = + ProgNode::iden(scope.ctx()).assertl_drop(Cmr::fail(FailEntropy::ZERO)); + + // Compose: (lhs, rhs) → (lhs, rhs) [panics if rhs == 0] + let check_passthrough = + ProgNode::comp(&check_and_iden, &assert_nonzero).unwrap(); + + // args >>> check_passthrough : Input → (lhs, rhs) [with zero check] + args.comp(&check_passthrough).with_span(self)? + } else { + args + }; + let jet_result = args.comp(&jet_node).with_span(self)?; + if *assert_no_carry { + // jet returns (bool, uN): assert carry==false (panic on overflow), return uN + // assertl(drop_(iden), cmr_fail): ((unit+unit), uN) → uN + let assert_and_extract = + ProgNode::iden(scope.ctx()).assertl_drop(Cmr::fail(FailEntropy::ZERO)); + jet_result.comp(&assert_and_extract).with_span(self)? + } else if *negate_result { + // bool NOT: pair(iden, unit) >>> case_true_false: (1+1) → (1+1) + let input_and_unit = + PairBuilder::iden(scope.ctx()).pair(PairBuilder::unit(scope.ctx())); + let bool_not_node = ProgNode::case_true_false(scope.ctx()); + let bool_not = input_and_unit.comp(&bool_not_node).with_span(self)?; + jet_result.comp(&bool_not).with_span(self)? + } else { + jet_result + } + } }; scope diff --git a/src/error.rs b/src/error.rs index 1ded6fdd..69f369fb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,7 @@ use chumsky::DefaultExpected; use itertools::Itertools; use simplicity::elements; +use crate::ast::WarningName; use crate::lexer::Token; use crate::parse::MatchPattern; use crate::str::{AliasName, FunctionName, Identifier, JetName, ModuleName, WitnessName}; @@ -189,32 +190,37 @@ impl RichError { } } -impl fmt::Display for RichError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fn get_line_col(file: &str, offset: usize) -> (usize, usize) { - let mut line = 1; - let mut last_newline_offset = 0; +pub(crate) fn get_line_col(file: &str, offset: usize) -> (usize, usize) { + let mut line = 1; + let mut last_newline_offset = 0; - let slice = file.get(0..offset).unwrap_or_default(); + let slice = file.get(0..offset).unwrap_or_default(); - for (i, byte) in slice.bytes().enumerate() { - if byte == b'\n' { - line += 1; - last_newline_offset = i; - } - } - - let col = (offset - last_newline_offset) + 1; - (line, col) + for (i, byte) in slice.bytes().enumerate() { + if byte == b'\n' { + line += 1; + last_newline_offset = i; } + } + + let col = (offset - last_newline_offset) + 1; + (line, col) +} + +impl fmt::Display for RichError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + const RED_BOLD: &str = "\x1b[1;31m"; + const RESET: &str = "\x1b[0m"; match self.file { Some(ref file) if !file.is_empty() => { let (start_line, start_col) = get_line_col(file, self.span.start); let (end_line, end_col) = get_line_col(file, self.span.end); - let start_line_index = start_line - 1; + writeln!(f, "{RED_BOLD}error{RESET}: {}", self.error)?; + writeln!(f, " --> {start_line}:{start_col}")?; + let start_line_index = start_line - 1; let n_spanned_lines = end_line - start_line_index; let line_num_width = end_line.to_string().len(); @@ -229,18 +235,21 @@ impl fmt::Display for RichError { } let is_multiline = end_line > start_line; - let (underline_start, underline_length) = match is_multiline { true => (0, start_line_len), false => (start_col, end_col - start_col), }; write!(f, "{:width$} |", " ", width = line_num_width)?; write!(f, "{:width$}", " ", width = underline_start)?; - write!(f, "{:^ { - write!(f, "{}", self.error) + write!(f, "{RED_BOLD}error{RESET}: {}", self.error) } } } @@ -438,6 +447,7 @@ pub enum Error { ModuleRedefined(ModuleName), ArgumentMissing(WitnessName), ArgumentTypeMismatch(WitnessName, ResolvedType, ResolvedType), + DeniedWarning(WarningName), } #[rustfmt::skip] @@ -587,6 +597,9 @@ impl fmt::Display for Error { f, "Parameter `{name}` was declared with type `{declared}` but its assigned argument is of type `{assigned}`" ), + Error::DeniedWarning(warning) => write!( + f, "Warning treated as error: {warning}" + ) } } } diff --git a/src/lexer.rs b/src/lexer.rs index fef18927..59d2bb00 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -19,21 +19,65 @@ pub enum Token<'src> { Match, // Control symbols + /// `->` Arrow, + /// `:` Colon, + /// `;` Semi, + /// `,` Comma, + /// `=` Eq, + /// `=>` FatArrow, + /// `(` LParen, + /// `)` RParen, + /// `[` LBracket, + /// `]` RBracket, + /// `{` LBrace, + /// `}` RBrace, + /// `<` LAngle, + /// `>` RAngle, + // Infix operators + /// `+` + Plus, + /// `-` + Minus, + /// `*` + Star, + /// `/` + Slash, + /// `%` + Percent, + /// `&&` + AmpAmp, + /// `||` + PipePipe, + /// `&` + Ampersand, + /// `|` + Pipe, + /// `^` + Caret, + /// `==` + EqEq, + /// `!=` + BangEq, + /// `<=` + LtEq, + /// `>=` + GtEq, + // Number literals DecLiteral(Decimal), HexLiteral(Hexadecimal), @@ -85,6 +129,21 @@ impl<'src> fmt::Display for Token<'src> { Token::LAngle => write!(f, "<"), Token::RAngle => write!(f, ">"), + Token::Plus => write!(f, "+"), + Token::Minus => write!(f, "-"), + Token::Star => write!(f, "*"), + Token::Slash => write!(f, "/"), + Token::Percent => write!(f, "%"), + Token::AmpAmp => write!(f, "&&"), + Token::PipePipe => write!(f, "||"), + Token::Ampersand => write!(f, "&"), + Token::Pipe => write!(f, "|"), + Token::Caret => write!(f, "^"), + Token::EqEq => write!(f, "=="), + Token::BangEq => write!(f, "!="), + Token::LtEq => write!(f, "<="), + Token::GtEq => write!(f, ">="), + Token::DecLiteral(s) => write!(f, "{}", s), Token::HexLiteral(s) => write!(f, "0x{}", s), Token::BinLiteral(s) => write!(f, "0b{}", s), @@ -159,20 +218,38 @@ pub fn lexer<'src>( .map(Token::Param); let op = choice(( - just("->").to(Token::Arrow), - just("=>").to(Token::FatArrow), - just("=").to(Token::Eq), - just(":").to(Token::Colon), - just(";").to(Token::Semi), - just(",").to(Token::Comma), - just("(").to(Token::LParen), - just(")").to(Token::RParen), - just("[").to(Token::LBracket), - just("]").to(Token::RBracket), - just("{").to(Token::LBrace), - just("}").to(Token::RBrace), - just("<").to(Token::LAngle), - just(">").to(Token::RAngle), + choice(( + just("->").to(Token::Arrow), + just("=>").to(Token::FatArrow), + just("==").to(Token::EqEq), + just("!=").to(Token::BangEq), + just("=").to(Token::Eq), + just(":").to(Token::Colon), + just(";").to(Token::Semi), + just(",").to(Token::Comma), + just("(").to(Token::LParen), + just(")").to(Token::RParen), + just("[").to(Token::LBracket), + just("]").to(Token::RBracket), + just("{").to(Token::LBrace), + just("}").to(Token::RBrace), + )), + choice(( + just("<=").to(Token::LtEq), + just("<").to(Token::LAngle), + just(">=").to(Token::GtEq), + just(">").to(Token::RAngle), + just("+").to(Token::Plus), + just("-").to(Token::Minus), + just("*").to(Token::Star), + just("/").to(Token::Slash), + just("%").to(Token::Percent), + just("&&").to(Token::AmpAmp), + just("||").to(Token::PipePipe), + just("&").to(Token::Ampersand), + just("|").to(Token::Pipe), + just("^").to(Token::Caret), + )), )); let comment = just("//") @@ -260,6 +337,167 @@ mod tests { (tokens, errors) } + /// Helper function to get the variant name of a token + fn variant_name(token: &Token) -> &'static str { + match token { + Token::Fn => "Fn", + Token::Let => "Let", + Token::Type => "Type", + Token::Mod => "Mod", + Token::Const => "Const", + Token::Match => "Match", + Token::Arrow => "Arrow", + Token::Colon => "Colon", + Token::Semi => "Semi", + Token::Comma => "Comma", + Token::Eq => "Eq", + Token::FatArrow => "FatArrow", + Token::LParen => "LParen", + Token::RParen => "RParen", + Token::LBracket => "LBracket", + Token::RBracket => "RBracket", + Token::LBrace => "LBrace", + Token::RBrace => "RBrace", + Token::LAngle => "LAngle", + Token::RAngle => "RAngle", + Token::Plus => "Plus", + Token::Minus => "Minus", + Token::Star => "Star", + Token::Slash => "Slash", + Token::Percent => "Percent", + Token::AmpAmp => "AmpAmp", + Token::PipePipe => "PipePipe", + Token::Ampersand => "Ampersand", + Token::Pipe => "Pipe", + Token::Caret => "Caret", + Token::EqEq => "EqEq", + Token::BangEq => "BangEq", + Token::LtEq => "LtEq", + Token::GtEq => "GtEq", + Token::DecLiteral(_) => "DecLiteral", + Token::HexLiteral(_) => "HexLiteral", + Token::BinLiteral(_) => "BinLiteral", + Token::Bool(_) => "Bool", + Token::Ident(_) => "Ident", + Token::Jet(_) => "Jet", + Token::Witness(_) => "Witness", + Token::Param(_) => "Param", + Token::Macro(_) => "Macro", + Token::Comment => "Comment", + Token::BlockComment => "BlockComment", + } + } + + /// Macro to assert that a sequence of tokens matches the expected variant types + macro_rules! assert_tokens_match { + ($tokens:expr, $($expected:ident),* $(,)?) => { + { + let tokens = $tokens.as_ref().expect("Expected Some tokens"); + let expected_variants = vec![$( stringify!($expected) ),*]; + + assert_eq!( + tokens.len(), + expected_variants.len(), + "Expected {} tokens, got {}.\nTokens: {:?}", + expected_variants.len(), + tokens.len(), + tokens + ); + + for (idx, (token, expected_variant)) in tokens.iter().zip(expected_variants.iter()).enumerate() { + let actual_variant = variant_name(token); + assert_eq!( + actual_variant, + *expected_variant, + "Token at index {} does not match: expected {}, got {} (token: {:?})", + idx, + expected_variant, + actual_variant, + token + ); + } + } + }; + } + + #[test] + fn test_infix_add() { + let input = "b1 + b2"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, Plus, Ident); + } + + #[test] + fn test_infix_subtract() { + let input = "b1 - b2"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, Minus, Ident); + } + + #[test] + fn test_infix_multiply() { + let input = "b1 * b2"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, Star, Ident); + } + + #[test] + fn test_infix_divide() { + let input = "b1 / b2"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, Slash, Ident); + } + + #[test] + fn test_infix_modulo() { + let input = "b1 % b2"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, Percent, Ident); + } + + #[test] + fn test_infix_in_let_binding() { + let input = "let b3 : u8 = b1 + b2;"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Let, Ident, Colon, Ident, Eq, Ident, Plus, Ident, Semi); + } + + #[test] + fn test_slash_not_confused_with_comments() { + // `/` alone must lex as Slash, not the start of `//` or `/*` + let input = "a / b"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, Slash, Ident); + + // `//` must lex as a comment, not two Slash tokens + let input = "a // b"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, Comment); + + // `/*` must lex as a block comment, not Slash followed by Star + let input = "a /* b */"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Ident, BlockComment); + } + + #[test] + fn test_arrow_not_confused_with_minus() { + // `->` must lex as Arrow, not Minus followed by RAngle + let input = "fn foo() -> u8"; + let (tokens, errors) = lex(input); + assert!(errors.is_empty(), "Expected no errors, found: {:?}", errors); + assert_tokens_match!(tokens, Fn, Ident, LParen, RParen, Arrow, Ident); + } + #[test] fn test_block_comment_simple() { let input = "/* hello world */"; diff --git a/src/lib.rs b/src/lib.rs index b4a1032a..ef0fbe9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ mod serde; pub mod str; pub mod tracker; pub mod types; +pub mod unstable_flags; pub mod value; mod witness; @@ -29,10 +30,12 @@ pub extern crate either; pub extern crate simplicity; pub use simplicity::elements; +use crate::ast::Warning; use crate::debug::DebugSymbols; use crate::error::{ErrorCollector, WithFile}; use crate::parse::ParseFromStrWithErrors; pub use crate::types::ResolvedType; +pub use crate::unstable_flags::{UnstableFlags, with_flags}; pub use crate::value::Value; pub use crate::witness::{Arguments, Parameters, WitnessTypes, WitnessValues}; @@ -43,6 +46,7 @@ pub use crate::witness::{Arguments, Parameters, WitnessTypes, WitnessValues}; pub struct TemplateProgram { simfony: ast::Program, file: Arc, + warnings: Vec, } impl TemplateProgram { @@ -51,16 +55,31 @@ impl TemplateProgram { /// ## Errors /// /// The string is not a valid SimplicityHL program. - pub fn new>>(s: Str) -> Result { + pub fn new>>( + s: Str, + deny_all_warnings: bool, + flags: UnstableFlags, + ) -> Result { let file = s.into(); let mut error_handler = ErrorCollector::new(Arc::clone(&file)); - let parse_program = parse::Program::parse_from_str_with_errors(&file, &mut error_handler); + let parse_program = + with_flags(flags, || { + parse::Program::parse_from_str_with_errors(&file, &mut error_handler) + }); if let Some(program) = parse_program { - let ast_program = ast::Program::analyze(&program).with_file(Arc::clone(&file))?; - Ok(Self { - simfony: ast_program, - file, - }) + let (ast_program, warnings) = + ast::Program::analyze(&program).with_file(Arc::clone(&file))?; + + if !warnings.is_empty() && deny_all_warnings { + error_handler.update(warnings.into_iter().map(|w| w.into())); + Err(ErrorCollector::to_string(&error_handler))? + } else { + Ok(Self { + simfony: ast_program, + file, + warnings, + }) + } } else { Err(ErrorCollector::to_string(&error_handler))? } @@ -76,6 +95,62 @@ impl TemplateProgram { self.simfony.witness_types() } + /// Access any warnings produced during compilation. + pub fn warnings(&self) -> &[Warning] { + &self.warnings + } + + /// Format warnings for display in rustc style, with source location and yellow color. + pub fn format_warnings(&self, file_path: &str) -> String { + use crate::error::get_line_col; + use std::fmt::Write as _; + + const YELLOW_BOLD: &str = "\x1b[1;33m"; + const RESET: &str = "\x1b[0m"; + + let mut out = String::new(); + for warning in &self.warnings { + let message = warning.canonical_name.to_string(); + let _ = writeln!(out, "{YELLOW_BOLD}warning{RESET}: {message}"); + + if !self.file.is_empty() { + let (start_line, start_col) = get_line_col(&self.file, warning.span.start); + let (end_line, end_col) = get_line_col(&self.file, warning.span.end); + let _ = writeln!(out, " --> {file_path}:{start_line}:{start_col}"); + + let start_line_index = start_line - 1; + let n_spanned_lines = end_line - start_line_index; + let line_num_width = end_line.to_string().len(); + + let _ = writeln!(out, "{:width$} |", " ", width = line_num_width); + let mut lines = self.file.lines().skip(start_line_index).peekable(); + let start_line_len = lines.peek().map_or(0, |l| l.len()); + + for (i, line_str) in lines.take(n_spanned_lines).enumerate() { + let line_num = start_line_index + i + 1; + let _ = writeln!(out, "{line_num:line_num_width$} | {line_str}"); + } + + let is_multiline = end_line > start_line; + let (underline_start, underline_len) = if is_multiline { + (0, start_line_len) + } else { + (start_col, end_col - start_col) + }; + let _ = write!(out, "{:width$} |", " ", width = line_num_width); + let _ = write!(out, "{:width$}", " ", width = underline_start); + let _ = writeln!( + out, + "{YELLOW_BOLD}{:^ Result { - TemplateProgram::new(s) + TemplateProgram::new(s, deny_all_warnings, flags) .and_then(|template| template.instantiate(arguments, include_debug_symbols)) } @@ -217,8 +294,11 @@ impl SatisfiedProgram { arguments: Arguments, witness_values: WitnessValues, include_debug_symbols: bool, + deny_all_warnings: bool, + flags: UnstableFlags, ) -> Result { - let compiled = CompiledProgram::new(s, arguments, include_debug_symbols)?; + let compiled = + CompiledProgram::new(s, arguments, include_debug_symbols, deny_all_warnings, flags)?; compiled.satisfy(witness_values) } @@ -321,7 +401,11 @@ pub(crate) mod tests { } pub fn template_text(program_text: Cow) -> Self { - let program = match TemplateProgram::new(program_text.as_ref()) { + Self::template_text_with_flags(program_text, UnstableFlags::new()) + } + + pub fn template_text_with_flags(program_text: Cow, flags: UnstableFlags) -> Self { + let program = match TemplateProgram::new(program_text.as_ref(), false, flags) { Ok(x) => x, Err(error) => panic!("{error}"), }; @@ -662,6 +746,8 @@ fn main() { Arguments::default(), WitnessValues::default(), false, + false, + UnstableFlags::new(), ) { Ok(_) => panic!("Accepted faulty program"), Err(error) => { @@ -711,6 +797,225 @@ fn main() { .assert_run_success(); } + // --- infix operator integration tests (require -Z infix_arithmetic_operators) --- + + fn arith_flags() -> UnstableFlags { + let mut f = UnstableFlags::new(); + f.enable(crate::unstable_flags::UnstableFlag::InfixArithmeticOperators); + f + } + + #[test] + fn infix_op_add_u8() { + let prog = r#"fn main() { + let a: u8 = 9; + let b: u8 = 10; + let (_, sum): (bool, u8) = a + b; + assert!(jet::eq_8(sum, 19)); +}"#; + TestCase::template_text_with_flags(Cow::Borrowed(prog), arith_flags()) + .with_arguments(Arguments::default()) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_sub_u8() { + let prog = r#"fn main() { + let a: u8 = 20; + let b: u8 = 7; + let (_, diff): (bool, u8) = a - b; + assert!(jet::eq_8(diff, 13)); +}"#; + TestCase::template_text_with_flags(Cow::Borrowed(prog), arith_flags()) + .with_arguments(Arguments::default()) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_mul_u8() { + let prog = r#"fn main() { + let a: u8 = 6; + let b: u8 = 7; + let product: u16 = a * b; + assert!(jet::eq_16(product, 42)); +}"#; + TestCase::template_text_with_flags(Cow::Borrowed(prog), arith_flags()) + .with_arguments(Arguments::default()) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_div_u8() { + let prog = r#"fn main() { + let a: u8 = 20; + let b: u8 = 4; + let quotient: u8 = a / b; + assert!(jet::eq_8(quotient, 5)); +}"#; + TestCase::template_text_with_flags(Cow::Borrowed(prog), arith_flags()) + .with_arguments(Arguments::default()) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_rem_u8() { + let prog = r#"fn main() { + let a: u8 = 17; + let b: u8 = 5; + let remainder: u8 = a % b; + assert!(jet::eq_8(remainder, 2)); +}"#; + TestCase::template_text_with_flags(Cow::Borrowed(prog), arith_flags()) + .with_arguments(Arguments::default()) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + // --- logical operator tests --- + + #[test] + fn infix_op_logical_and_true() { + // true && true = true + let prog = r#"fn main() { + assert!(true && true); +}"#; + TestCase::program_text(Cow::Borrowed(prog)) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_logical_and_false_lhs() { + // false && true = false (short-circuits, rhs not evaluated) + let prog = r#"fn main() { + let result: bool = false && true; + assert!(match result { false => true, true => false, }); +}"#; + TestCase::program_text(Cow::Borrowed(prog)) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_logical_and_false_rhs() { + // true && false = false + let prog = r#"fn main() { + let result: bool = true && false; + assert!(match result { false => true, true => false, }); +}"#; + TestCase::program_text(Cow::Borrowed(prog)) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_logical_or_true_lhs() { + // true || false = true (short-circuits, rhs not evaluated) + let prog = r#"fn main() { + assert!(true || false); +}"#; + TestCase::program_text(Cow::Borrowed(prog)) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_logical_or_true_rhs() { + // false || true = true + let prog = r#"fn main() { + assert!(false || true); +}"#; + TestCase::program_text(Cow::Borrowed(prog)) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_logical_or_false() { + // false || false = false + let prog = r#"fn main() { + let result: bool = false || false; + assert!(match result { false => true, true => false, }); +}"#; + TestCase::program_text(Cow::Borrowed(prog)) + .with_witness_values(WitnessValues::default()) + .assert_run_success(); + } + + #[test] + fn infix_op_logical_and_wrong_type_error() { + // `&&` requires both operands to be bool + let prog = r#"fn main() { + let a: u8 = 1; + let result: bool = true && a; +}"#; + match SatisfiedProgram::new( + prog, + Arguments::default(), + WitnessValues::default(), + false, + false, + UnstableFlags::new(), + ) { + Ok(_) => panic!("Expected type error for `&&` with non-bool operand"), + Err(error) => assert!( + error.contains("Expected expression of type"), + "Unexpected error message: {error}", + ), + } + } + + #[test] + fn infix_op_logical_and_wrong_output_type_error() { + // `&&` result must be bool + let prog = r#"fn main() { + let result: u8 = true && false; +}"#; + match SatisfiedProgram::new( + prog, + Arguments::default(), + WitnessValues::default(), + false, + false, + UnstableFlags::new(), + ) { + Ok(_) => panic!("Expected type error for `&&` assigned to non-bool"), + Err(error) => assert!( + error.contains("Expected expression of type"), + "Unexpected error message: {error}", + ), + } + } + + #[test] + fn infix_op_add_wrong_output_type_error() { + // `+` returns (bool, u8), binding it to plain u8 must fail + let prog = r#"fn main() { + let a: u8 = 1; + let b: u8 = 2; + let sum: u8 = a + b; + assert!(jet::eq_8(sum, 3)); +}"#; + match SatisfiedProgram::new( + prog, + Arguments::default(), + WitnessValues::default(), + false, + false, + arith_flags(), + ) { + Ok(_) => panic!("Expected type error for `+` with plain u8 output"), + Err(error) => assert!( + error.contains("Expected expression of type"), + "Unexpected error message: {error}", + ), + } + } + #[cfg(feature = "serde")] mod regression { use super::TestCase; diff --git a/src/main.rs b/src/main.rs index 5cb7b571..c875bf0f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,8 @@ use base64::display::Base64Display; use base64::engine::general_purpose::STANDARD; use clap::{Arg, ArgAction, Command}; -use simplicityhl::{AbiMeta, CompiledProgram}; +use simplicityhl::{AbiMeta, TemplateProgram, UnstableFlags}; +use simplicityhl::unstable_flags::UnstableFlag; use std::{env, fmt}; #[cfg_attr(feature = "serde", derive(serde::Serialize))] @@ -83,6 +84,22 @@ fn main() -> Result<(), Box> { .action(ArgAction::SetTrue) .help("Additional ABI .simf contract types"), ) + .arg( + Arg::new("deny_warnings") + .long("deny-warnings") + .action(ArgAction::SetTrue) + .help("Treat warnings as errors"), + ) + .arg( + Arg::new("unstable_flags") + .short('Z') + .value_name("FLAG") + .action(ArgAction::Append) + .help( + "Enable an unstable feature flag (can be passed multiple times). \ + Known flags: infix_arithmetic_operators", + ), + ) }; let matches = command.get_matches(); @@ -113,7 +130,43 @@ fn main() -> Result<(), Box> { simplicityhl::Arguments::default() }; - let compiled = match CompiledProgram::new(prog_text, args_opt, include_debug_symbols) { + let deny_warnings = matches.get_flag("deny_warnings"); + + let mut flags = UnstableFlags::new(); + for flag_str in matches + .get_many::("unstable_flags") + .unwrap_or_default() + { + match flag_str.parse::() { + Ok(flag) => flags.enable(flag), + Err(e) => { + eprintln!("error: {e}"); + std::process::exit(1); + } + } + } + + let template = match TemplateProgram::new(prog_text, deny_warnings, flags) { + Ok(t) => t, + Err(e) => { + eprintln!("{}", e); + std::process::exit(1); + } + }; + let n_warnings = template.warnings().len(); + if n_warnings > 0 { + eprint!("{}", template.format_warnings(prog_file)); + let word = if n_warnings == 1 { + "warning" + } else { + "warnings" + }; + eprintln!( + "\x1b[1;33mwarning\x1b[0m: `{}` generated {} {}", + prog_file, n_warnings, word, + ); + } + let compiled = match template.instantiate(args_opt, include_debug_symbols) { Ok(program) => program, Err(e) => { eprintln!("{}", e); diff --git a/src/parse.rs b/src/parse.rs index c42c6f3d..26c42bae 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -303,6 +303,106 @@ pub enum ExpressionInner { Block(Arc<[Statement]>, Option>), } +/// An infix binary operator. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub enum InfixOp { + /// `+` + Add, + /// `-` + Sub, + /// `*` + Mul, + /// `/` + Div, + /// `%` + Rem, + /// `&&` — short-circuit logical AND + LogicalAnd, + /// `||` — short-circuit logical OR + LogicalOr, + /// `&` + BitAnd, + /// `|` + BitOr, + /// `^` + BitXor, + /// `==` + Eq, + /// `!=` + Ne, + /// `<` + Lt, + /// `<=` + Le, + /// `>` + Gt, + /// `>=` + Ge, +} + +impl fmt::Display for InfixOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + InfixOp::Add => write!(f, "+"), + InfixOp::Sub => write!(f, "-"), + InfixOp::Mul => write!(f, "*"), + InfixOp::Div => write!(f, "/"), + InfixOp::Rem => write!(f, "%"), + InfixOp::LogicalAnd => write!(f, "&&"), + InfixOp::LogicalOr => write!(f, "||"), + InfixOp::BitAnd => write!(f, "&"), + InfixOp::BitOr => write!(f, "|"), + InfixOp::BitXor => write!(f, "^"), + InfixOp::Eq => write!(f, "=="), + InfixOp::Ne => write!(f, "!="), + InfixOp::Lt => write!(f, "<"), + InfixOp::Le => write!(f, "<="), + InfixOp::Gt => write!(f, ">"), + InfixOp::Ge => write!(f, ">="), + } + } +} + +/// A binary infix operator expression: `lhs op rhs`. +#[derive(Clone, Debug)] +pub struct BinaryExpression { + lhs: Arc, + op: InfixOp, + rhs: Arc, + span: Span, +} + +impl BinaryExpression { + /// Access the left-hand side expression. + pub fn lhs(&self) -> &Expression { + &self.lhs + } + + /// Access the infix operator. + pub fn op(&self) -> &InfixOp { + &self.op + } + + /// Access the right-hand side expression. + pub fn rhs(&self) -> &Expression { + &self.rhs + } +} + +impl_eq_hash!(BinaryExpression; lhs, op, rhs); + +impl fmt::Display for BinaryExpression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", ExprTree::BinaryOp(self)) + } +} + +impl AsRef for BinaryExpression { + fn as_ref(&self) -> &Span { + &self.span + } +} + /// A single expression directly returns a value. #[derive(Clone, Debug)] pub struct SingleExpression { @@ -351,6 +451,8 @@ pub enum SingleExpressionInner { Expression(Arc), /// Match expression over a sum type Match(Match), + /// Binary infix operator expression, e.g. `a + b`. + BinaryOp(BinaryExpression), /// Tuple wrapper expression Tuple(Arc<[Expression]>), /// Array wrapper expression @@ -602,6 +704,7 @@ pub enum ExprTree<'a> { Single(&'a SingleExpression), Call(&'a Call), Match(&'a Match), + BinaryOp(&'a BinaryExpression), } impl TreeLike for ExprTree<'_> { @@ -642,6 +745,7 @@ impl TreeLike for ExprTree<'_> { | S::Expression(l) => Tree::Unary(Self::Expression(l)), S::Call(call) => Tree::Unary(Self::Call(call)), S::Match(match_) => Tree::Unary(Self::Match(match_)), + S::BinaryOp(binary) => Tree::Unary(Self::BinaryOp(binary)), S::Tuple(elements) | S::Array(elements) | S::List(elements) => { Tree::Nary(elements.iter().map(Self::Expression).collect()) } @@ -652,6 +756,10 @@ impl TreeLike for ExprTree<'_> { Self::Expression(match_.left().expression()), Self::Expression(match_.right().expression()), ])), + Self::BinaryOp(binary) => Tree::Nary(Arc::new([ + Self::Expression(binary.lhs()), + Self::Expression(binary.rhs()), + ])), } } } @@ -715,7 +823,7 @@ impl fmt::Display for ExprTree<'_> { write!(f, ")")?; } }, - S::Call(..) | S::Match(..) => {} + S::Call(..) | S::Match(..) | S::BinaryOp(..) => {} S::Tuple(tuple) => { if data.n_children_yielded == 0 { write!(f, "(")?; @@ -766,6 +874,11 @@ impl fmt::Display for ExprTree<'_> { write!(f, ",\n}}")?; } }, + Self::BinaryOp(binary) => match data.n_children_yielded { + 0 => {} + 1 => write!(f, " {} ", binary.op())?, + n => debug_assert_eq!(n, 2), + }, } } @@ -1607,14 +1720,56 @@ impl SingleExpression { .delimited_by(just(Token::LParen), just(Token::RParen)) .map(|es| SingleExpressionInner::Expression(Arc::from(es))); - choice(( + let op = select! { + Token::Plus if crate::unstable_flags::is_enabled(crate::unstable_flags::UnstableFlag::InfixArithmeticOperators) => InfixOp::Add, + Token::Minus if crate::unstable_flags::is_enabled(crate::unstable_flags::UnstableFlag::InfixArithmeticOperators) => InfixOp::Sub, + Token::Star if crate::unstable_flags::is_enabled(crate::unstable_flags::UnstableFlag::InfixArithmeticOperators) => InfixOp::Mul, + Token::Slash if crate::unstable_flags::is_enabled(crate::unstable_flags::UnstableFlag::InfixArithmeticOperators) => InfixOp::Div, + Token::Percent if crate::unstable_flags::is_enabled(crate::unstable_flags::UnstableFlag::InfixArithmeticOperators) => InfixOp::Rem, + Token::AmpAmp => InfixOp::LogicalAnd, + Token::PipePipe => InfixOp::LogicalOr, + Token::Ampersand => InfixOp::BitAnd, + Token::Pipe => InfixOp::BitOr, + Token::Caret => InfixOp::BitXor, + Token::EqEq => InfixOp::Eq, + Token::BangEq => InfixOp::Ne, + Token::LAngle => InfixOp::Lt, + Token::LtEq => InfixOp::Le, + Token::RAngle => InfixOp::Gt, + Token::GtEq => InfixOp::Ge, + }; + + let primary = choice(( left, right, some, none, boolean, match_expr, expression, list, array, tuple, call, literal, variable, )) .map_with(|inner, e| Self { inner, span: e.span(), - }) + }); + + primary + .clone() + .foldl_with(op.then(primary).repeated(), |lhs, (op, rhs), e| { + let span = e.span(); + let lhs_span = *lhs.span(); + let rhs_span = *rhs.span(); + Self { + inner: SingleExpressionInner::BinaryOp(BinaryExpression { + lhs: Arc::new(Expression { + inner: ExpressionInner::Single(lhs), + span: lhs_span, + }), + op, + rhs: Arc::new(Expression { + inner: ExpressionInner::Single(rhs), + span: rhs_span, + }), + span, + }), + span, + } + }) } } @@ -2163,3 +2318,84 @@ impl crate::ArbitraryRec for Match { }) } } + +#[cfg(test)] +mod test { + use super::*; + + /// Helper: extract a `BinaryExpression` from the outermost `Expression`. + fn unwrap_binary(expr: &Expression) -> &BinaryExpression { + match expr.inner() { + ExpressionInner::Single(single) => match single.inner() { + SingleExpressionInner::BinaryOp(binary) => binary, + _ => panic!("Expected SingleExpressionInner::BinaryOp"), + }, + _ => panic!("Expected ExpressionInner::Single"), + } + } + + #[test] + fn test_binary_op_add() { + let expr = Expression::parse_from_str("b1 + b2").expect("Failed to parse"); + assert_eq!(unwrap_binary(&expr).op(), &InfixOp::Add); + } + + #[test] + fn test_binary_op_all_operators() { + for (input, expected_op) in [ + ("a + b", InfixOp::Add), + ("a - b", InfixOp::Sub), + ("a * b", InfixOp::Mul), + ("a / b", InfixOp::Div), + ("a % b", InfixOp::Rem), + ("a == b", InfixOp::Eq), + ("a != b", InfixOp::Ne), + ("a < b", InfixOp::Lt), + ("a <= b", InfixOp::Le), + ("a > b", InfixOp::Gt), + ("a >= b", InfixOp::Ge), + ] { + let expr = Expression::parse_from_str(input).expect("Failed to parse"); + assert_eq!( + unwrap_binary(&expr).op(), + &expected_op, + "Wrong op for input: {input}" + ); + } + } + + #[test] + fn test_binary_op_left_associative() { + // `a + b + c` should parse as `(a + b) + c` + let expr = Expression::parse_from_str("a + b + c").expect("Failed to parse"); + let outer = unwrap_binary(&expr); + assert_eq!(outer.op(), &InfixOp::Add); + // LHS should wrap another BinaryOp: (a + b) + let inner = unwrap_binary(outer.lhs()); + assert_eq!(inner.op(), &InfixOp::Add); + } + + #[test] + fn test_binary_op_in_let_binding() { + let input = "{ let b3: u8 = b1 + b2; b3 }"; + let expr = Expression::parse_from_str(input).expect("Failed to parse"); + match expr.inner() { + ExpressionInner::Block(stmts, _) => { + assert_eq!(stmts.len(), 1); + match &stmts[0] { + Statement::Assignment(assign) => { + assert_eq!(unwrap_binary(assign.expression()).op(), &InfixOp::Add); + } + _ => panic!("Expected assignment statement"), + } + } + _ => panic!("Expected block expression"), + } + } + + #[test] + fn test_binary_op_display() { + let expr = Expression::parse_from_str("a + b").expect("Failed to parse"); + assert_eq!(expr.to_string(), "a + b"); + } +} diff --git a/src/tracker.rs b/src/tracker.rs index 4a6f693a..63a10850 100644 --- a/src/tracker.rs +++ b/src/tracker.rs @@ -472,7 +472,7 @@ mod tests { #[test] fn test_debug_and_jet_tracing() { - let program = TemplateProgram::new(TEST_PROGRAM).unwrap(); + let program = TemplateProgram::new(TEST_PROGRAM, false, crate::UnstableFlags::new()).unwrap(); let program = program.instantiate(Arguments::default(), true).unwrap(); let satisfied = program.satisfy(WitnessValues::default()).unwrap(); @@ -541,7 +541,7 @@ mod tests { fn test_arith_jet_trace_regression() { let env = create_test_env(); - let program = TemplateProgram::new(TEST_ARITHMETIC_JETS).unwrap(); + let program = TemplateProgram::new(TEST_ARITHMETIC_JETS, false, crate::UnstableFlags::new()).unwrap(); let program = program.instantiate(Arguments::default(), true).unwrap(); let satisfied = program.satisfy(WitnessValues::default()).unwrap(); diff --git a/src/unstable_flags.rs b/src/unstable_flags.rs new file mode 100644 index 00000000..46a94fe1 --- /dev/null +++ b/src/unstable_flags.rs @@ -0,0 +1,122 @@ +//! Unstable (experimental) compiler feature flags. +//! +//! These are analogous to `rustc`'s `-Z` flags and must be explicitly opted +//! into on the command line with `-Z `. Unstable features may +//! change or be removed without notice. + +use std::cell::Cell; +use std::fmt; +use std::str::FromStr; + +/// A single unstable compiler feature flag. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum UnstableFlag { + /// Enable infix arithmetic operators: `+`, `-`, `*`, `/`, `%`. + /// + /// These operators have side effects (overflow/underflow panics for `+`/`-`, + /// division-by-zero panics for `/`/`%`) and so are opt-in. + InfixArithmeticOperators, + // ^^^ Add new flags here. Each flag needs: + // 1. A variant in this enum + // 2. A bit position in `bit()` + // 3. A name string in `name()` / `from_str()` + // 4. An entry in `ALL` +} + +impl UnstableFlag { + const fn bit(self) -> u64 { + match self { + Self::InfixArithmeticOperators => 1 << 0, + } + } + + pub const fn name(self) -> &'static str { + match self { + Self::InfixArithmeticOperators => "infix_arithmetic_operators", + } + } + + /// All known flags, used for error messages. + pub const ALL: &'static [Self] = &[Self::InfixArithmeticOperators]; +} + +impl fmt::Display for UnstableFlag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.name()) + } +} + +impl FromStr for UnstableFlag { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "infix_arithmetic_operators" => Ok(Self::InfixArithmeticOperators), + _ => { + let known = Self::ALL + .iter() + .map(|f| f.name()) + .collect::>() + .join(", "); + Err(format!( + "unknown unstable flag `{s}`. Known flags: {known}" + )) + } + } + } +} + +/// A set of enabled unstable compiler flags. +/// +/// Stored as a bitmask; adding a new flag only requires extending +/// [`UnstableFlag`] — no changes to this type are needed. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct UnstableFlags(u64); + +impl UnstableFlags { + /// Returns an empty flag set (all features disabled). + pub const fn new() -> Self { + Self(0) + } + + /// Enable `flag` in this set. + pub fn enable(&mut self, flag: UnstableFlag) { + self.0 |= flag.bit(); + } + + /// Returns `true` if `flag` is enabled. + pub fn is_enabled(self, flag: UnstableFlag) -> bool { + self.0 & flag.bit() != 0 + } +} + +// --------------------------------------------------------------------------- +// Thread-local parse context +// --------------------------------------------------------------------------- + +thread_local! { + static ACTIVE: Cell = const { Cell::new(0) }; +} + +/// Returns `true` if `flag` is currently active in this thread's parse context. +/// +/// This is intended to be called from inside parser combinators, where the +/// active flags are set by [`with_flags`]. +pub fn is_enabled(flag: UnstableFlag) -> bool { + ACTIVE.with(|cell| cell.get() & flag.bit() != 0) +} + +/// Run `f` with `flags` active on the current thread. +/// +/// The previous flag state is restored when `f` returns, making this safe to +/// nest. +pub fn with_flags(flags: UnstableFlags, f: impl FnOnce() -> R) -> R { + let prev = ACTIVE.with(|cell| { + let old = cell.get(); + cell.set(flags.0); + old + }); + let result = f(); + ACTIVE.with(|cell| cell.set(prev)); + result +} diff --git a/src/value.rs b/src/value.rs index 3ca7fdca..02462e98 100644 --- a/src/value.rs +++ b/src/value.rs @@ -680,7 +680,8 @@ impl Value { | S::Parameter(..) | S::Variable(..) | S::Call(..) - | S::Match(..) => return None, // not const + | S::Match(..) + | S::BinaryOp { .. } => return None, // not const S::Expression(..) => continue, // skip S::Tuple(..) => { let elements = output.split_off(output.len() - size); @@ -784,7 +785,11 @@ impl Value { /// Parse a value of the given type from a string. pub fn parse_from_str(s: &str, ty: &ResolvedType) -> Result { let parse_expr = parse::Expression::parse_from_str(s)?; - let ast_expr = ast::Expression::analyze_const(&parse_expr, ty)?; + let (ast_expr, _warnings) = ast::Expression::analyze_const(&parse_expr, ty)?; + // TODO: confirm what to do with warnings here. There shouldn't be any. + // if !warnings.is_empty() {} + // Throw away warnings for now + Self::from_const_expr(&ast_expr) .ok_or(Error::ExpressionUnexpectedType(ty.clone())) .with_span(s) @@ -1317,7 +1322,7 @@ mod tests { for (string, ty, expected_value) in string_ty_value { let parse_expr = parse::Expression::parse_from_str(string).unwrap(); - let ast_expr = ast::Expression::analyze_const(&parse_expr, &ty).unwrap(); + let (ast_expr, _warnings) = ast::Expression::analyze_const(&parse_expr, &ty).unwrap(); let parsed_value = Value::from_const_expr(&ast_expr).unwrap(); assert_eq!(parsed_value, expected_value); assert!(parsed_value.is_of_type(&ty)); diff --git a/src/witness.rs b/src/witness.rs index 6d6ffefc..31e41fca 100644 --- a/src/witness.rs +++ b/src/witness.rs @@ -69,7 +69,10 @@ macro_rules! impl_name_value_map { impl ParseFromStr for $wrapper { fn parse_from_str(s: &str) -> Result { - parse::ModuleProgram::parse_from_str(s).and_then(|x| Self::analyze(&x)) + // TODO: Note warnings are dropped here. + parse::ModuleProgram::parse_from_str(s) + .and_then(|x| Self::analyze(&x)) + .map(|(s, _warnings)| s) } } @@ -247,7 +250,7 @@ mod tests { WitnessName::from_str_unchecked("A"), Value::u16(42), )])); - match SatisfiedProgram::new(s, Arguments::default(), witness, false) { + match SatisfiedProgram::new(s, Arguments::default(), witness, false, false, crate::UnstableFlags::new()) { Ok(_) => panic!("Ill-typed witness assignment was falsely accepted"), Err(error) => assert_eq!( "Witness `A` was declared with type `u32` but its assigned value is of type `u16`", @@ -266,7 +269,7 @@ fn main() { assert!(jet::is_zero_32(f())); }"#; - match CompiledProgram::new(s, Arguments::default(), false) { + match CompiledProgram::new(s, Arguments::default(), false, false, crate::UnstableFlags::new()) { Ok(_) => panic!("Witness outside main was falsely accepted"), Err(error) => { assert!(error