-
Notifications
You must be signed in to change notification settings - Fork 8
/
collapse_blocks.py
151 lines (126 loc) · 5.69 KB
/
collapse_blocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import contextlib
import itertools
from typing import Iterable
import attrs
import structlog
from puya.context import CompileContext
from puya.ir import models
from puya.ir.visitor_mutator import IRMutator
from puya.utils import unique
# Module-level structured logger; the passes in this module emit debug messages
# describing each control-flow rewrite they perform.
logger: structlog.typing.FilteringBoundLogger = structlog.get_logger(__name__)
@attrs.define
class BlockReferenceReplacer(IRMutator):
    """IR mutator that redirects references to one basic block onto another.

    Visiting a block rewrites, in place, every terminator target, phi
    ``through`` edge and predecessor entry that refers to ``find`` so that it
    refers to ``replacement`` instead. After retargeting, branch-like
    terminators are passed through ``_replace_single_target_with_goto``, which
    may collapse them into a plain ``Goto``.
    """

    # the block whose references are being rewritten
    find: models.BasicBlock
    # the block that takes its place
    replacement: models.BasicBlock

    @classmethod
    def apply(
        cls,
        find: models.BasicBlock,
        replacement: models.BasicBlock,
        blocks: Iterable[models.BasicBlock],
    ) -> None:
        """Rewrite references to ``find`` as ``replacement`` inside each of ``blocks``."""
        mutator = cls(find=find, replacement=replacement)
        for visited in blocks:
            mutator.visit_block(visited)

    def visit_block(self, block: models.BasicBlock) -> None:
        # visit phis/ops/terminator first, then fix up the predecessor list
        super().visit_block(block)
        if self.find not in block.predecessors:
            return
        block.predecessors = [
            self.replacement if pred is self.find else pred for pred in block.predecessors
        ]
        logger.debug(f"Replaced predecessor {self.find} with {self.replacement} in {block}")

    def visit_phi_argument(self, arg: models.PhiArgument) -> models.PhiArgument:
        # a value that arrived through `find` now arrives through `replacement`
        if arg.through == self.find:
            arg.through = self.replacement
        return arg

    def visit_conditional_branch(self, branch: models.ConditionalBranch) -> models.ControlOp:
        if branch.zero == self.find:
            branch.zero = self.replacement
        if branch.non_zero == self.find:
            branch.non_zero = self.replacement
        # both arms may now point at the same block
        return _replace_single_target_with_goto(branch)

    def visit_goto(self, goto: models.Goto) -> models.Goto:
        if goto.target == self.find:
            goto.target = self.replacement
        return goto

    def visit_goto_nth(self, goto_nth: models.GotoNth) -> models.ControlOp:
        for position, candidate in enumerate(goto_nth.blocks):
            if candidate == self.find:
                goto_nth.blocks[position] = self.replacement
        goto_nth.default = goto_nth.default.accept(self)
        return _replace_single_target_with_goto(goto_nth)

    def visit_switch(self, switch: models.Switch) -> models.ControlOp:
        # collect the matching cases first, then retarget them in place
        matching_cases = [
            case for case, destination in switch.cases.items() if destination == self.find
        ]
        for case in matching_cases:
            switch.cases[case] = self.replacement
        switch.default = switch.default.accept(self)
        return _replace_single_target_with_goto(switch)
def _replace_single_target_with_goto(terminator: models.ControlOp) -> models.ControlOp:
"""
If a ControlOp has a single target, replace it with a Goto, otherwise return the original op.
"""
match terminator:
case models.ControlOp(unique_targets=[single_target], can_exit=False):
replacement = models.Goto(
source_location=terminator.source_location,
target=single_target,
)
logger.debug(f"replaced {terminator} with {replacement}")
return replacement
case _:
return terminator
def remove_linear_jump(_context: CompileContext, subroutine: models.Subroutine) -> bool:
changes = False
for block in subroutine.body[1:]:
match block.predecessors:
case [models.BasicBlock(terminator=models.Goto(target=successor)) as predecessor]:
assert successor is block
# can merge blocks when there is an unconditional jump between them
predecessor.phis.extend(block.phis)
predecessor.ops.extend(block.ops)
# this will update the predecessors of all block.successors to
# now point back to predecessor e.g.
# predecessor <-> block <-> [ss1, ...]
# predecessor <-> [ss1, ...]
BlockReferenceReplacer.apply(
find=block, replacement=predecessor, blocks=block.successors
)
predecessor.terminator = block.terminator
# update block to reflect modifications
subroutine.body.remove(block)
changes = True
logger.debug(f"Merged linear {block} into {predecessor}")
return changes
def remove_empty_blocks(_context: CompileContext, subroutine: models.Subroutine) -> bool:
    """Remove blocks that contain nothing but an unconditional jump.

    Every reference to such a block is redirected to its jump target, and the
    target inherits the removed block's predecessors. If the removed block was
    the entry block, the target is moved to the front of the body.

    A block is kept when:
      * its target has phi nodes, because phi arguments record the block they
        arrive ``through`` and would be invalidated by the redirect; or
      * it jumps to itself (an empty infinite loop), because there is no other
        block to redirect its predecessors to.

    Returns True if any block was removed.
    """
    changes = False
    # iterate over a snapshot because blocks are removed from the body as we go
    for block in subroutine.body.copy():
        if block.phis or block.ops or not isinstance(block.terminator, models.Goto):
            continue
        empty_block = block
        target = block.terminator.target
        if target.phis:
            logger.debug(
                f"Not removing empty block {empty_block} because it's used by phi nodes"
            )
            continue
        if target is empty_block:
            # redirecting a self-loop onto itself is a no-op, so removing it would leave
            # its other predecessors jumping to a block that is no longer in the body
            logger.debug(
                f"Not removing empty block {empty_block} because it jumps to itself"
            )
            continue
        # point every terminator (and phi/predecessor reference) in the empty block's
        # predecessors at the target instead
        BlockReferenceReplacer.apply(
            find=empty_block,
            replacement=target,
            blocks=empty_block.predecessors,
        )
        # remove the empty block from the target's predecessors, and add any of the
        # empty block's predecessors that aren't already present
        target.predecessors = unique(
            itertools.chain(empty_block.predecessors, target.predecessors)
        )
        # might have already been replaced by BlockReferenceReplacer
        with contextlib.suppress(ValueError):
            target.predecessors.remove(empty_block)
        if empty_block is subroutine.body[0]:
            # place target at start of body so it's now the new entry block
            subroutine.body.remove(target)
            subroutine.body.insert(0, target)
        subroutine.body.remove(empty_block)
        changes = True
        logger.debug(f"Removed empty block: {empty_block}")
    return changes