Skip to content

Commit

Permalink
typechecker: Add is variant bindings to return/match pattern expressions
Browse files Browse the repository at this point in the history
Currently `expand_context_for_bindings` is only used with if/ guard
conditions.

It is also useful to have though in expressions where there is no block
but the binding can be used as part of the expression.

There are probably other places where it would make sense in future.
  • Loading branch information
robryanx authored and awesomekling committed Oct 6, 2022
1 parent 9743d97 commit abab871
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 16 deletions.
39 changes: 23 additions & 16 deletions selfhost/typechecker.jakt
Original file line number Diff line number Diff line change
Expand Up @@ -3469,7 +3469,7 @@ struct Typechecker {
.error("Condition must be a boolean expression", new_condition.span())
}

let checked_block = .typecheck_block(new_then_block, parent_scope_id: scope_id, safety_mode)
let checked_block = .typecheck_block(new_then_block!, parent_scope_id: scope_id, safety_mode)
mut checked_else: CheckedStatement? = None
if new_else_statement.has_value() {
checked_else = .typecheck_statement(new_else_statement!, scope_id, safety_mode)
Expand Down Expand Up @@ -3662,7 +3662,7 @@ struct Typechecker {
return .typecheck_statement(rewritten_statement, scope_id, safety_mode)
}

function expand_context_for_bindings(mut this, condition: ParsedExpression, acc: ParsedExpression?, then_block: ParsedBlock, else_statement: ParsedStatement?, span: Span) throws -> (ParsedExpression, ParsedBlock, ParsedStatement?) {
function expand_context_for_bindings(mut this, condition: ParsedExpression, acc: ParsedExpression?, then_block: ParsedBlock?, else_statement: ParsedStatement?, span: Span) throws -> (ParsedExpression, ParsedBlock?, ParsedStatement?) {
match condition {
BinaryOp(lhs, op, rhs) => {
if op is LogicalAnd {
Expand All @@ -3688,18 +3688,20 @@ struct Typechecker {
outer_if_stmts.push(ParsedStatement::VarDecl(var, init: enum_variant_arg, span))
}
mut inner_condition = condition
mut new_then_block = then_block
mut new_else_statement = else_statement
if acc.has_value() {
inner_condition = acc!
outer_if_stmts.push(ParsedStatement::If(condition: inner_condition, then_block, else_statement, span))
} else {
for stmt in then_block.stmts.iterator() {
outer_if_stmts.push(stmt)
if then_block.has_value() {
if acc.has_value() {
inner_condition = acc!
outer_if_stmts.push(ParsedStatement::If(condition: inner_condition, then_block: then_block!, else_statement, span))
} else {
for stmt in then_block!.stmts.iterator() {
outer_if_stmts.push(stmt)
}
}
}
new_then_block = ParsedBlock(stmts: outer_if_stmts)
return .expand_context_for_bindings(condition: unary_op_single_condition, acc: None, then_block: new_then_block, else_statement: new_else_statement, span)

let new_then_block = ParsedBlock(stmts: outer_if_stmts)
return .expand_context_for_bindings(condition: unary_op_single_condition, acc: None, then_block: new_then_block, else_statement, span)

}
else => {}
}
Expand All @@ -3716,13 +3718,14 @@ struct Typechecker {
function typecheck_if(mut this, condition: ParsedExpression, then_block: ParsedBlock, else_statement: ParsedStatement?, scope_id: ScopeId, safety_mode: SafetyMode, span: Span) throws -> CheckedStatement {
let (new_condition, new_then_block, new_else_statement) = .expand_context_for_bindings(condition, acc: None, then_block, else_statement, span)
let checked_condition = .typecheck_expression_and_dereference_if_needed(new_condition, scope_id, safety_mode, type_hint: None, span)

if not checked_condition.type().equals(builtin(BuiltinType::Bool)) {
.error("Condition must be a boolean expression", new_condition.span())
}

let checked_block = .typecheck_block(new_then_block, parent_scope_id: scope_id, safety_mode)
let checked_block = .typecheck_block(new_then_block!, parent_scope_id: scope_id, safety_mode)
if checked_block.yielded_type.has_value() {
.error("An 'if' block is not allowed to yield values", new_then_block.find_yield_span()!)
.error("An 'if' block is not allowed to yield values", new_then_block!.find_yield_span()!)
}

mut checked_else: CheckedStatement? = None
Expand Down Expand Up @@ -4021,7 +4024,9 @@ struct Typechecker {
type_hint = Some(.get_function(.current_function_id!).return_type_id)
}

let checked_expr = .typecheck_expression(expr!, scope_id, safety_mode, type_hint)
let (new_condition, new_then_block, new_else_statement) = .expand_context_for_bindings(condition: expr!, acc: None, then_block: None, else_statement: None, span)
let checked_expr = .typecheck_expression_and_dereference_if_needed(new_condition, scope_id, safety_mode, type_hint, span)

return CheckedStatement::Return(val: checked_expr, span)
}

Expand Down Expand Up @@ -5687,7 +5692,9 @@ struct Typechecker {
}
is_value_match = true

let checked_expression = .typecheck_expression(expr, scope_id, safety_mode, type_hint: Some(subject_type_id))
let (new_condition, new_then_block, new_else_statement) = .expand_context_for_bindings(condition: expr, acc: None, then_block: None, else_statement: None, span)
let checked_expression = .typecheck_expression_and_dereference_if_needed(new_condition, scope_id, safety_mode, type_hint: Some(subject_type_id), span)

if not checked_expression.to_number_constant(program: .program).has_value() {
all_variants_constant = false
}
Expand Down
18 changes: 18 additions & 0 deletions tests/typechecker/is_variant_binding_match_case.jakt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/// Expect:
/// - output: "1\n"

enum Foo {
Bar(i64)
Baz(m: String)
}

function main() {
let foo = Foo::Baz(m: "Hello")

let result = match true {
(foo is Baz(m: n) and n == "Hello") => 1
else => 2
}

println("{}", result)
}
15 changes: 15 additions & 0 deletions tests/typechecker/is_variant_binding_return_statement.jakt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/// Expect:
/// - output: "true\n"

enum Foo {
Bar(i64)
Baz(m: String)
}

function match_foo(foo: Foo) => foo is Baz(m: n) and n == "Hello"

function main() {
let foo = Foo::Baz(m: "Hello")

println("{}", match_foo(foo))
}

0 comments on commit abab871

Please sign in to comment.