diff --git a/crates/oq3_semantics/src/semantic_error.rs b/crates/oq3_semantics/src/semantic_error.rs index 6c6d9e8..425472f 100644 --- a/crates/oq3_semantics/src/semantic_error.rs +++ b/crates/oq3_semantics/src/semantic_error.rs @@ -131,6 +131,10 @@ impl SemanticError { self.node.text_range() } + pub fn kind(&self) -> &SemanticErrorKind { + &self.error_kind + } + pub fn message(&self) -> String { format!("{:?}", self.error_kind) } diff --git a/crates/oq3_semantics/src/syntax_to_semantics.rs b/crates/oq3_semantics/src/syntax_to_semantics.rs index 3ab7bfa..e1a4453 100644 --- a/crates/oq3_semantics/src/syntax_to_semantics.rs +++ b/crates/oq3_semantics/src/syntax_to_semantics.rs @@ -196,14 +196,20 @@ fn from_stmt(stmt: synast::Stmt, context: &mut Context) -> Option { match stmt { synast::Stmt::IfStmt(if_stmt) => { let condition = from_expr(if_stmt.condition().unwrap(), context); - let then_branch = from_block_expr(if_stmt.then_branch().unwrap(), context); - let else_branch = if_stmt.else_branch().map(|ex| from_block_expr(ex, context)); + with_scope!(context, ScopeType::Local, + let then_branch = from_block_expr(if_stmt.then_branch().unwrap(), context); + ); + with_scope!(context, ScopeType::Local, + let else_branch = if_stmt.else_branch().map(|ex| from_block_expr(ex, context)); + ); Some(asg::If::new(condition.unwrap(), then_branch, else_branch).to_stmt()) } synast::Stmt::WhileStmt(while_stmt) => { let condition = from_expr(while_stmt.condition().unwrap(), context); - let loop_body = from_block_expr(while_stmt.body().unwrap(), context); + with_scope!(context, ScopeType::Local, + let loop_body = from_block_expr(while_stmt.body().unwrap(), context); + ); Some(asg::While::new(condition.unwrap(), loop_body).to_stmt()) } diff --git a/crates/oq3_semantics/tests/from_string_tests.rs b/crates/oq3_semantics/tests/from_string_tests.rs index e6f72d5..bad57b0 100644 --- a/crates/oq3_semantics/tests/from_string_tests.rs +++ b/crates/oq3_semantics/tests/from_string_tests.rs @@ -2,7 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use oq3_semantics::asg; -use oq3_semantics::semantic_error::SemanticErrorList; +use oq3_semantics::semantic_error::{SemanticErrorKind, SemanticErrorList}; use oq3_semantics::symbols::{SymbolTable, SymbolType}; use oq3_semantics::syntax_to_semantics::parse_source_string; use oq3_semantics::types::{ArrayDims, IsConst, Type}; @@ -117,6 +117,53 @@ while (false) { assert_eq!(inner, vec![&1u128, &2u128]); } +#[test] +fn test_from_string_while_stmt_scope() { + let code = r##" +while (false) { + int x = 1; +} +x = 2; +"##; + let (program, errors, _symbol_table) = parse_string(code); + assert!(matches!( + &errors[0].kind(), + SemanticErrorKind::UndefVarError + )); + assert_eq!(errors.len(), 1); + assert_eq!(program.len(), 2); +} + +#[test] +fn test_from_string_if_stmt_scope() { + let code = r##" +if (false) { + int x = 1; +} +x = 2; +"##; + let (program, errors, _symbol_table) = parse_string(code); + assert!(matches!( + &errors[0].kind(), + SemanticErrorKind::UndefVarError + )); + assert_eq!(errors.len(), 1); + assert_eq!(program.len(), 2); +} + +#[test] +fn test_from_string_if_stmt_scope_2() { + let code = r##" +if (false) { + int x = 1; +} +int x = 2; +"##; + let (program, errors, _symbol_table) = parse_string(code); + assert_eq!(errors.len(), 0); + assert_eq!(program.len(), 2); +} + #[test] fn test_indexed_identifier() { let code = r##"