Skip to content

Commit

Permalink
fix: for loop predicate
Browse files Browse the repository at this point in the history
This PR fixes for loops executing once when the predicate already should not be met for decrementing loops.
I have also re-implemented the codegen logic for for-loops, resulting in fewer predecessors and hopefully more
readable IR.

Resolves #1207
  • Loading branch information
mhasel committed Jun 17, 2024
1 parent 44df0db commit 745fd10
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 595 deletions.
128 changes: 111 additions & 17 deletions src/codegen/generators/statement_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {
/// - `end` the value indicating the end of the for loop
/// - `by_step` the step of the loop
/// - `body` the statements inside the for-loop
fn generate_for_statement(
fn _generate_for_statement(
&self,
counter: &AstNode,
start: &AstNode,
Expand All @@ -401,12 +401,11 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {
//Check loop condition
builder.position_at_end(condition_check);
let exp_gen = self.create_expr_generator();
let counter_statement = exp_gen.generate_expression(counter)?;

let counter_statement = dbg!(exp_gen.generate_expression(counter))?;
//. / and_2 \
//. / and 1 \
//. (counter_end_le && counter_start_ge) || (counter_end_ge && counter_start_le)
let or_eval = self.generate_compare_expression(counter, end, start, &exp_gen)?;
let or_eval = dbg!(self.generate_compare_expression(counter, end, start, by_step, &exp_gen))?;

builder.build_conditional_branch(to_i1(or_eval.into_int_value(), builder), for_body, continue_block);

Expand Down Expand Up @@ -457,42 +456,137 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> {
Ok(())
}

fn generate_for_statement(
&self,
counter: &AstNode,
start: &AstNode,
end: &AstNode,
by_step: &Option<Box<AstNode>>,
body: &[AstNode],
) -> Result<(), Diagnostic> {
let (builder, current_function, context) = self.get_llvm_deps();
let exp_gen = self.create_expr_generator();
self.generate_assignment_statement(counter, start)?;
let predicate_incrementing = context.append_basic_block(current_function, "predicate_inc");
let predicate_decrementing = context.append_basic_block(current_function, "predicate_dec");
let loop_body = context.append_basic_block(current_function, "loop");
let afterloop = context.append_basic_block(current_function, "continue");

let counter = exp_gen.generate_lvalue(counter)?;
let end = exp_gen.generate_expression(end)?;
let counter_value = builder.build_load(counter, "");

let by_step = by_step.as_ref().map_or_else(
|| {
self.llvm.create_const_numeric(
&counter_value.get_type(),
"1",
SourceLocation::undefined(),
)
},
|step| {
self.register_debug_location(step);
exp_gen.generate_expression(step)
},
)?;
let is_incrementing = builder.build_int_compare(inkwell::IntPredicate::SGT, counter_value.into_int_value(), self.llvm.i32_type().const_zero(), "is_incrementing");

// --check loop predicate--
builder.build_conditional_branch(is_incrementing, predicate_incrementing, predicate_decrementing);
// --incrementing loops--
builder.position_at_end(predicate_incrementing);
let value = builder.build_load(counter, "");
let inc_cmp = builder.build_int_compare(inkwell::IntPredicate::SLE, value.into_int_value(), end.into_int_value(), "condition");
builder.build_conditional_branch(inc_cmp, loop_body, afterloop);
// --decrementing loops--
builder.position_at_end(predicate_decrementing);
let value = builder.build_load(counter, "");
let dec_cmp = builder.build_int_compare(inkwell::IntPredicate::SGE, value.into_int_value(), end.into_int_value(), "condition");
builder.build_conditional_branch(dec_cmp, loop_body, afterloop);

// --body--
let body_generator = StatementCodeGenerator {
current_loop_exit: Some(afterloop),
current_loop_continue: Some(predicate_incrementing),
load_prefix: self.load_prefix.clone(),
load_suffix: self.load_suffix.clone(),
..*self
};
builder.position_at_end(loop_body);
body_generator.generate_body(body)?;
// --increment--
let value = builder.build_load(counter, "");
let inc = builder.build_int_add(value.into_int_value(), by_step.into_int_value(), "increment");
builder.build_store(counter, inc);
//--check condition again--
// builder.build_phi(self.llvm.i32_type(), "phi");
builder.build_conditional_branch(is_incrementing, predicate_incrementing, predicate_decrementing);
// --continue--
builder.position_at_end(afterloop);
Ok(())
}

fn generate_compare_expression(
&'a self,
counter: &AstNode,
end: &AstNode,
start: &AstNode,
start: &AstNode, // step
step: &Option<Box<AstNode>>,
exp_gen: &'a ExpressionCodeGenerator,
) -> Result<BasicValueEnum<'a>, Diagnostic> {
let bool_id = self.annotations.get_bool_id();

// correct(hopefully): (i <= end && step > 0) || (i >= end && step < 0)
// step == 0 => infinite loop (if step is const literal, validate, otherwise runtime sanitizer?/UB)

let counter_end_ge = AstFactory::create_binary_expression(
counter.clone(),
Operator::GreaterOrEqual,
end.clone(),
bool_id,
);
let counter_start_ge = AstFactory::create_binary_expression(
counter.clone(),
Operator::GreaterOrEqual,
start.clone(),
let step = if let Some(step) = step.as_deref() {
dbg!(step.clone())
} else {
AstFactory::create_literal(
plc_ast::literals::AstLiteral::Integer(1),
SourceLocation::internal(),
bool_id,
)
};

let zero = AstFactory::create_literal(
plc_ast::literals::AstLiteral::Integer(0),
SourceLocation::internal(),
bool_id,
);

let step_negative =
AstFactory::create_binary_expression(step.clone(), Operator::Greater, zero.clone(), bool_id);
let step_positive = AstFactory::create_binary_expression(step.clone(), Operator::Less, zero, bool_id);
// let counter_start_ge = AstFactory::create_binary_expression(
// counter.clone(),
// Operator::GreaterOrEqual,
// start.clone(),
// bool_id,
// );
let counter_end_le = AstFactory::create_binary_expression(
counter.clone(),
Operator::LessOrEqual,
end.clone(),
bool_id,
);
let counter_start_le = AstFactory::create_binary_expression(
counter.clone(),
Operator::LessOrEqual,
start.clone(),
bool_id,
);
// let counter_start_le = AstFactory::create_binary_expression(
// counter.clone(),
// Operator::LessOrEqual,
// start.clone(),
// bool_id,
// );
let and_1 =
AstFactory::create_binary_expression(counter_end_le, Operator::And, counter_start_ge, bool_id);
AstFactory::create_binary_expression(counter_end_le, Operator::And, step_positive, bool_id);

let and_2 =
AstFactory::create_binary_expression(counter_end_ge, Operator::And, counter_start_le, bool_id);
AstFactory::create_binary_expression(counter_end_ge, Operator::And, step_negative, bool_id);
let or = AstFactory::create_binary_expression(and_1, Operator::Or, and_2, bool_id);

self.register_debug_location(&or);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,67 +11,30 @@ entry:
call void @llvm.dbg.declare(metadata i32* %myFunc, metadata !9, metadata !DIExpression()), !dbg !11
store i32 0, i32* %myFunc, align 4, !dbg !8
store i32 1, i32* %myFunc, align 4, !dbg !12
br label %condition_check, !dbg !12

condition_check: ; preds = %increment, %entry
%load_myFunc = load i32, i32* %myFunc, align 4, !dbg !12
%load_myFunc1 = load i32, i32* %myFunc, align 4, !dbg !12
%tmpVar = icmp sle i32 %load_myFunc1, 20, !dbg !12
%0 = zext i1 %tmpVar to i8, !dbg !12
%1 = icmp ne i8 %0, 0, !dbg !12
br i1 %1, label %2, label %5, !dbg !12

for_body: ; preds = %12
store i32 1, i32* %myFunc, align 4, !dbg !13
br label %increment, !dbg !13

increment: ; preds = %for_body
%tmpVar8 = add i32 %load_myFunc, 2, !dbg !14
store i32 %tmpVar8, i32* %myFunc, align 4, !dbg !14
br label %condition_check, !dbg !14

continue: ; preds = %12
%0 = load i32, i32* %myFunc, align 4, !dbg !12
%is_incrementing = icmp sgt i32 %0, 0, !dbg !13
br i1 %is_incrementing, label %predicate_inc, label %predicate_dec, !dbg !13

predicate_inc: ; preds = %loop, %entry
%1 = load i32, i32* %myFunc, align 4, !dbg !13
%condition = icmp sle i32 %1, 20, !dbg !13
br i1 %condition, label %loop, label %continue, !dbg !13

predicate_dec: ; preds = %loop, %entry
%2 = load i32, i32* %myFunc, align 4, !dbg !13
%condition1 = icmp sge i32 %2, 20, !dbg !13
br i1 %condition1, label %loop, label %continue, !dbg !13

loop: ; preds = %predicate_dec, %predicate_inc
store i32 1, i32* %myFunc, align 4, !dbg !14
%3 = load i32, i32* %myFunc, align 4, !dbg !14
%increment = add i32 %3, 2, !dbg !14
store i32 %increment, i32* %myFunc, align 4, !dbg !14
br i1 %is_incrementing, label %predicate_inc, label %predicate_dec, !dbg !14

continue: ; preds = %predicate_dec, %predicate_inc
%myFunc_ret = load i32, i32* %myFunc, align 4, !dbg !14
ret i32 %myFunc_ret, !dbg !14

2: ; preds = %condition_check
%load_myFunc2 = load i32, i32* %myFunc, align 4, !dbg !12
%tmpVar3 = icmp sge i32 %load_myFunc2, 1, !dbg !12
%3 = zext i1 %tmpVar3 to i8, !dbg !12
%4 = icmp ne i8 %3, 0, !dbg !12
br label %5, !dbg !12

5: ; preds = %2, %condition_check
%6 = phi i1 [ %1, %condition_check ], [ %4, %2 ], !dbg !12
%7 = zext i1 %6 to i8, !dbg !12
%8 = icmp ne i8 %7, 0, !dbg !12
br i1 %8, label %12, label %9, !dbg !12

9: ; preds = %5
%load_myFunc4 = load i32, i32* %myFunc, align 4, !dbg !12
%tmpVar5 = icmp sge i32 %load_myFunc4, 20, !dbg !12
%10 = zext i1 %tmpVar5 to i8, !dbg !12
%11 = icmp ne i8 %10, 0, !dbg !12
br i1 %11, label %16, label %19, !dbg !12

12: ; preds = %19, %5
%13 = phi i1 [ %8, %5 ], [ %22, %19 ], !dbg !12
%14 = zext i1 %13 to i8, !dbg !12
%15 = icmp ne i8 %14, 0, !dbg !12
br i1 %15, label %for_body, label %continue, !dbg !12

16: ; preds = %9
%load_myFunc6 = load i32, i32* %myFunc, align 4, !dbg !12
%tmpVar7 = icmp sle i32 %load_myFunc6, 1, !dbg !12
%17 = zext i1 %tmpVar7 to i8, !dbg !12
%18 = icmp ne i8 %17, 0, !dbg !12
br label %19, !dbg !12

19: ; preds = %16, %9
%20 = phi i1 [ %11, %9 ], [ %18, %16 ], !dbg !12
%21 = zext i1 %20 to i8, !dbg !12
%22 = icmp ne i8 %21, 0, !dbg !12
br label %12, !dbg !12
}

; Function Attrs: nofree nosync nounwind readnone speculatable willreturn
Expand All @@ -95,5 +58,5 @@ attributes #0 = { nofree nosync nounwind readnone speculatable willreturn }
!10 = !DIBasicType(name: "DINT", size: 32, encoding: DW_ATE_signed, flags: DIFlagPublic)
!11 = !DILocation(line: 2, column: 17, scope: !4)
!12 = !DILocation(line: 3, column: 16, scope: !4)
!13 = !DILocation(line: 4, column: 16, scope: !4)
!14 = !DILocation(line: 3, column: 37, scope: !4)
!13 = !DILocation(line: 3, column: 37, scope: !4)
!14 = !DILocation(line: 4, column: 16, scope: !4)
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,27 @@ define void @prg(%prg* %0) section "fn-$RUSTY$prg:v" {
entry:
%x = getelementptr inbounds %prg, %prg* %0, i32 0, i32 0
store i32 3, i32* %x, align 4
br label %condition_check

condition_check: ; preds = %increment, %entry
%1 = load i32, i32* %x, align 4
%is_incrementing = icmp sgt i32 %1, 0
br i1 %is_incrementing, label %predicate_inc, label %predicate_dec

predicate_inc: ; preds = %loop, %entry
%2 = load i32, i32* %x, align 4
%condition = icmp sle i32 %2, 10
br i1 %condition, label %loop, label %continue

predicate_dec: ; preds = %loop, %entry
%3 = load i32, i32* %x, align 4
%condition1 = icmp sge i32 %3, 10
br i1 %condition1, label %loop, label %continue

loop: ; preds = %predicate_dec, %predicate_inc
%4 = load i32, i32* %x, align 4
%increment = add i32 %4, 1
store i32 %increment, i32* %x, align 4
br i1 %is_incrementing, label %predicate_inc, label %predicate_dec

continue: ; preds = %predicate_dec, %predicate_inc
%load_x = load i32, i32* %x, align 4
%load_x1 = load i32, i32* %x, align 4
%tmpVar = icmp sle i32 %load_x1, 10
%1 = zext i1 %tmpVar to i8
%2 = icmp ne i8 %1, 0
br i1 %2, label %3, label %6

for_body: ; preds = %13
br label %increment

increment: ; preds = %for_body
%tmpVar8 = add i32 %load_x, 1
store i32 %tmpVar8, i32* %x, align 4
br label %condition_check

continue: ; preds = %13
%load_x9 = load i32, i32* %x, align 4
ret void

3: ; preds = %condition_check
%load_x2 = load i32, i32* %x, align 4
%tmpVar3 = icmp sge i32 %load_x2, 3
%4 = zext i1 %tmpVar3 to i8
%5 = icmp ne i8 %4, 0
br label %6

6: ; preds = %3, %condition_check
%7 = phi i1 [ %2, %condition_check ], [ %5, %3 ]
%8 = zext i1 %7 to i8
%9 = icmp ne i8 %8, 0
br i1 %9, label %13, label %10

10: ; preds = %6
%load_x4 = load i32, i32* %x, align 4
%tmpVar5 = icmp sge i32 %load_x4, 10
%11 = zext i1 %tmpVar5 to i8
%12 = icmp ne i8 %11, 0
br i1 %12, label %17, label %20

13: ; preds = %20, %6
%14 = phi i1 [ %9, %6 ], [ %23, %20 ]
%15 = zext i1 %14 to i8
%16 = icmp ne i8 %15, 0
br i1 %16, label %for_body, label %continue

17: ; preds = %10
%load_x6 = load i32, i32* %x, align 4
%tmpVar7 = icmp sle i32 %load_x6, 3
%18 = zext i1 %tmpVar7 to i8
%19 = icmp ne i8 %18, 0
br label %20

20: ; preds = %17, %10
%21 = phi i1 [ %12, %10 ], [ %19, %17 ]
%22 = zext i1 %21 to i8
%23 = icmp ne i8 %22, 0
br label %13
}
Loading

0 comments on commit 745fd10

Please sign in to comment.