-
Notifications
You must be signed in to change notification settings - Fork 15
/
repeated_rotations_search.py
126 lines (102 loc) · 4.17 KB
/
repeated_rotations_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import functools
import itertools
import typing
from typing import Sequence
import attrs
import structlog
from puya.teal import models
logger = structlog.get_logger(__name__)

# Immutable (and therefore hashable) run of cover/uncover ops; being hashable
# lets it serve as a key for the functools.cache-decorated helpers below.
TealOpSequence: typing.TypeAlias = tuple[models.TealOpN, ...]
class InvalidOpSequenceError(Exception):
    """Raised by ``TealStack.apply`` when an op cannot be applied to the modelled stack.

    This covers both an op whose depth reaches past the bottom of the stack and an
    op code other than ``cover``/``uncover``.
    """
@attrs.define
class TealStack:
    """Symbolic stack of distinct integer labels used to compare rotation effects.

    The last element of ``stack`` is the top of the stack. Two instances compare by
    value through the attrs-generated ``__eq__``, so equal stacks mean two op
    sequences produced the same permutation.
    """

    stack: list[int]

    @classmethod
    def from_stack_size(cls, stack_size: int) -> "TealStack":
        """Build a stack holding the labels ``0 .. stack_size - 1`` (top is the largest)."""
        return cls(stack=[*range(stack_size)])

    def apply(self, ops: Sequence[models.TealOpN]) -> "TealStack":
        """Return a new stack with ``ops`` applied in order; ``self`` is left untouched.

        An op whose ``n`` is falsy (0) is skipped entirely, and its op code is not
        inspected.

        Raises:
            InvalidOpSequenceError: if an op's depth addresses a slot outside the
                stack, or if its op code is neither ``cover`` nor ``uncover``.
        """
        values = list(self.stack)
        for op in ops:
            depth = op.n
            if not depth:
                continue
            # Position of the slot `depth` below the top (the top sits at len - 1).
            pos = len(values) - depth - 1
            if not 0 <= pos < len(values):
                raise InvalidOpSequenceError
            if op.op_code == "cover":
                # Sink the top value so that `depth` values end up above it.
                values.insert(pos, values.pop())
            elif op.op_code == "uncover":
                # Lift the value at `depth` up to the top.
                values.append(values.pop(pos))
            else:
                raise InvalidOpSequenceError
        return TealStack(values)
@functools.cache
def simplify_rotation_ops(original_ops: TealOpSequence) -> TealOpSequence | None:
    """Search for a strictly shorter cover/uncover sequence with the same stack effect.

    The ops are applied to a symbolic stack that is just deep enough for the deepest
    op. Candidate sequences are then enumerated by increasing length until one
    reproduces the same final stack.

    Note: candidates come from ``itertools.permutations``, so no candidate uses the
    same op twice (e.g. ``uncover 2; uncover 2`` is never tried).

    Args:
        original_ops: a run of cover/uncover ops.

    Returns:
        ``()`` if the run has no net effect. Otherwise, the shortest equivalent
        sequence found, with source locations cleared. ``None`` if no shorter
        equivalent exists among the candidates.
    """
    # One slot per possible depth plus the top; 0 floors the depth for empty input.
    max_rot_op_n = max([0, *(op.n for op in original_ops)])
    original_stack = TealStack.from_stack_size(max_rot_op_n + 1)
    expected = original_stack.apply(original_ops)
    # entire sequence can be removed!
    if expected == original_stack:
        return ()
    possible_rotation_ops = get_possible_rotation_ops(max_rot_op_n)
    # TODO: use a non-bruteforce approach and/or capture common simplifications as data
    # The empty sequence was already ruled out above, so search lengths
    # 1 .. len(original_ops) - 1; only strictly shorter results are useful.
    for num_rotation_ops in range(1, len(original_ops)):
        for maybe_ops in itertools.permutations(possible_rotation_ops, num_rotation_ops):
            try:
                stack = original_stack.apply(maybe_ops)
            except InvalidOpSequenceError:
                continue
            if stack == expected:
                return tuple(attrs.evolve(op, source_location=None) for op in maybe_ops)
    return None
@functools.cache
def get_possible_rotation_ops(n: int) -> TealOpSequence:
    """Return ``cover i`` then ``uncover i`` for each depth ``i`` in ``1 .. n``.

    Ops carry no source location. The tuple is cached per ``n``, so repeated
    searches share it.
    """
    return tuple(
        op
        for depth in range(1, n + 1)
        for op in (
            models.Cover(depth, source_location=None),
            models.Uncover(depth, source_location=None),
        )
    )
# Op codes whose consecutive runs are collected by repeated_rotation_ops_search
# and handed to the brute-force simplifier.
ROTATION_SIMPLIFY_OPS = frozenset(("cover", "uncover"))
def repeated_rotation_ops_search(teal_ops: list[models.TealOp]) -> list[models.TealOp]:
    """Try to shorten every maximal run of consecutive cover/uncover ops.

    Any op outside ``ROTATION_SIMPLIFY_OPS`` ends the current run. That run is
    passed through ``_maybe_simplified`` and the op is copied through unchanged. A
    trailing run is flushed the same way. A new list is returned; ``teal_ops`` is
    only read.
    """
    optimized = list[models.TealOp]()
    pending = list[models.TealOpN]()
    for op in teal_ops:
        if op.op_code not in ROTATION_SIMPLIFY_OPS:
            optimized.extend(_maybe_simplified(pending))
            optimized.append(op)
            pending = []
            continue
        pending.append(typing.cast(models.TealOpN, op))
    optimized.extend(_maybe_simplified(pending))
    return optimized
def _maybe_simplified(
    maybe_remove_rotations: list[models.TealOpN], window_size: int = 5
) -> Sequence[models.TealOpN]:
    """Replace the first simplifiable window of a cover/uncover run.

    A window of up to ``window_size + 1`` ops slides across the run. The first window
    for which ``simplify_rotation_ops`` finds a shorter equivalent is spliced out and
    replaced. At most one replacement is made per call.

    Args:
        maybe_remove_rotations: a run of consecutive cover/uncover ops; not mutated.
        window_size: one less than the maximum number of ops searched at once.

    Returns:
        A new list with the window replaced. If the run is shorter than two ops or
        nothing simplifies, the input list object itself is returned.
    """
    if len(maybe_remove_rotations) < 2:
        return maybe_remove_rotations
    for start_idx in range(len(maybe_remove_rotations) - 1):
        end_idx = start_idx + window_size + 1
        window = maybe_remove_rotations[start_idx:end_idx]
        simplified = simplify_rotation_ops(tuple(window))
        if simplified is None:
            continue
        # Log only the window that is actually rewritten, so that both the quoted
        # ops and the reported saving are accurate when the run exceeds the window.
        logger.debug(
            f"Replaced '{'; '.join(map(str, window))}'"
            f" with '{'; '.join(map(str, simplified))}',"
            f" reducing by {len(window) - len(simplified)} ops by search"
        )
        result_ = maybe_remove_rotations.copy()
        result_[start_idx:end_idx] = simplified
        # simplify_rotation_ops only returns strictly shorter sequences.
        assert result_ != maybe_remove_rotations
        return result_
    return maybe_remove_rotations