Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
15r10nk committed Jan 3, 2023
1 parent 8f641b1 commit 57d8fb0
Showing 1 changed file with 41 additions and 30 deletions.
71 changes: 41 additions & 30 deletions executing/_index_node_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dis
import math
from types import FrameType
from typing import Set
from typing import Counter, Set
from .executing import EnhancedAST, NotOneValueFound, Source, function_node_types
from ._exceptions import KnownIssue, MultipleMatches

Expand All @@ -25,13 +25,14 @@


def walk_code(code):
for inst in dis.get_instructions(code):
instructions=list(dis.get_instructions(code))
for inst in instructions:
if isinstance(inst.argval, CodeType):
yield from walk_code(inst.argval)
yield code
yield code,instructions


Match = namedtuple("Match", "indexes code")
Match = namedtuple("Match", "indexes code index_code")


def all_equal(seq):
Expand Down Expand Up @@ -126,16 +127,14 @@ def inst_repr(inst):


class OpCodes:
def __getattr__(self, name) -> int:
def __getattr__(self, _:str) -> int:
# getattr makes mypy happy

raise AttributeError()


for key, value in dis.opmap.items():
setattr(OpCodes, key, value)


opcodes = OpCodes()

hasjcond = [
Expand All @@ -156,12 +155,14 @@ def __getattr__(self, name) -> int:

class CodeMap:
def __init__(self, tree: EnhancedAST, *, rewrite=lambda tree: None):
self.tree = tree


sys.setrecursionlimit(5000)

index_tree = copy.deepcopy(tree)
original_tree = copy.deepcopy(tree)

self.tree = tree

lineno_map = {}

Expand All @@ -179,25 +180,28 @@ def __init__(self, tree: EnhancedAST, *, rewrite=lambda tree: None):

# rewrite the AST the same way some other tools do (pytest, for example)
rewrite(index_tree)
rewrite(original_tree)

# compile the code
# inst.starts_line now contains the index
original_bc = compile(tree, "<tree>", "exec")
original_bc = compile(original_tree, "<original_tree>", "exec")
index_bc = compile(index_tree, "<index_tree>", "exec")

# create a code map where every code-block can be found
# the key is not unique, but it is a good heuristic to speed up the search
self.code_map = defaultdict(list)

self.duplicated_code=set()

self.code_key_cache = {}
index_code_key_cache = {}

for original_code, index_code in zip(
for (original_code,_) , (index_code,instructions) in zip(
walk_code(original_bc), walk_code(index_bc)
):
indexes = []
last_index = None
for inst in dis.get_instructions(index_code):
for inst in instructions:
if inst.starts_line != None:
last_index = inst.starts_line

Expand All @@ -206,6 +210,12 @@ def __init__(self, tree: EnhancedAST, *, rewrite=lambda tree: None):

indexes.append(last_index)

double_codes=Counter(inst.argval for inst in instructions if isinstance(inst.argval,CodeType) and inst.argval.co_name in ("<genexpr>","<listexpr>"))
for code,num in double_codes.items():
if num > 1:
self.duplicated_code.add(code)


original_key = self.code_key_cached(original_code, cache=self.code_key_cache)

# walk_child iterates over the children first
Expand All @@ -216,15 +226,15 @@ def __init__(self, tree: EnhancedAST, *, rewrite=lambda tree: None):
lineno_map=lineno_map.get,
)


key_error=original_key!=index_key

if key_error:
index_key=AnyCode()

index_code_key_cache[index_code]=index_key

if not key_error:
self.code_map[original_key].append(Match(indexes, index_code))
self.code_map[original_key].append(Match(indexes, original_code,index_code))

self.code_map=dict(self.code_map)

Expand Down Expand Up @@ -358,14 +368,6 @@ def optimize_jump(inst):

val = normalize_equal(val)

# some bug with this code:
# (a and b or c)
# if opname == "JUMP_IF_FALSE_OR_POP":
# opname = "POP_JUMP_IF_FALSE"

# (a or a) and c
# if opname == "JUMP_IF_TRUE_OR_POP":
# opname = "POP_JUMP_IF_TRUE"

if isinstance(val, CodeType):
val = child_code(val)
Expand All @@ -386,8 +388,13 @@ def optimize_jump(inst):
return key

def __getitem__(self, code) -> list[Match]:

if code in self.duplicated_code:
raise KnownIssue("code object is referenced twice in the same parent code object")

key = self.code_key_cached(code, cache=self.code_key_cache)
return self.code_map.get(
self.code_key_cached(code, cache=self.code_key_cache), []
key, []
)


Expand Down Expand Up @@ -420,35 +427,39 @@ def __init__(
print(matches)

# maybe pytest has rewritten the assertions
if not matches and False:
if not matches:
code_map_pytest = CodeMap(
tree, rewrite=lambda tree: pytest_rewrite_assert(tree, source.text)
)
matches = code_map_pytest[frame.f_code]

if not matches:
# search for closest match
pytest.skip()

matches = [m for matches in code_map.code_map.values() for m in matches]
matches += [
matches = [m for matches in code_map.code_map.values() for m in matches]+[
m for matches in code_map_pytest.code_map.values() for m in matches
]

differ = difflib.SequenceMatcher()
seq1 = [inst_repr(inst) for inst in instructions]
differ.set_seq1(seq1)

seqs2 = [[line.inst for line in match] for match in matches]
seqs2 = [(match.code,[inst_repr(inst) for inst in dis.get_instructions(match.code)]) for match in matches]

def ratio(seq):
differ.set_seq2(seq)
return differ.ratio()

seqs2.sort(key=ratio)
seqs2.sort(key=lambda a: ratio(a[1]))

print("simalar key:")
print(*difflib.unified_diff(seq1, seqs2[-1]), sep="\n")
print(*difflib.unified_diff(seq1, seqs2[-1][1]), sep="\n")
print("code:")
dis.dis(frame.f_code)
print("best match:")
dis.dis(seqs2[-1][0])



raise NotOneValueFound("no match found")

Expand All @@ -461,7 +472,7 @@ def ratio(seq):
# while (t for t in s):
# pass

if all_equal(m.code for m in matches):
if all_equal(m.index_code for m in matches):
match = matches[0]
else:
# pytest.skip()
Expand Down

0 comments on commit 57d8fb0

Please sign in to comment.