Skip to content

Commit

Permalink
Merge pull request #52 from vstinner/copy-slice-legalize
Browse files Browse the repository at this point in the history
Allow to copy and slice Bytecode
  • Loading branch information
MatthieuDartiailh committed Jan 29, 2020
2 parents 7be986f + 4c60dff commit 2d93b01
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 15 deletions.
2 changes: 1 addition & 1 deletion bytecode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bytecode.flags import CompilerFlags
from bytecode.instr import (UNSET, Label, SetLineno, Instr, CellVar, FreeVar, # noqa
Compare)
from bytecode.bytecode import BaseBytecode, _InstrList, Bytecode # noqa
from bytecode.bytecode import BaseBytecode, _BaseBytecodeList, _InstrList, Bytecode # noqa
from bytecode.concrete import (ConcreteInstr, ConcreteBytecode, # noqa
# import needed to use it in bytecode.py
_ConvertBytecodeToConcrete)
Expand Down
60 changes: 59 additions & 1 deletion bytecode/bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,59 @@ def update_flags(self, *, is_async=False):
self.flags = infer_flags(self, is_async)


class _BaseBytecodeList(BaseBytecode, list):
"""List subclass providing type stable slicing and copying.
"""
def __getitem__(self, index):
value = super().__getitem__(index)
if isinstance(index, slice):
value = type(self)(value)
value._copy_attr_from(self)

return value

def copy(self):
new = type(self)(super().copy())
new._copy_attr_from(self)
return new

def legalize(self):
"""Check that all the element of the list are valid and remove SetLineno.
"""
lineno_pos = []
set_lineno = None
current_lineno = self.first_lineno

for pos, instr in enumerate(self):
if isinstance(instr, SetLineno):
set_lineno = instr.lineno
lineno_pos.append(pos)
continue
# Filter out Labels
if not isinstance(instr, Instr):
continue
if set_lineno is not None:
instr.lineno = set_lineno
elif instr.lineno is None:
instr.lineno = current_lineno
else:
current_lineno = instr.lineno

for i in reversed(lineno_pos):
del self[i]

def __iter__(self):
instructions = super().__iter__()
for instr in instructions:
self._check_instr(instr)
yield instr

def _check_instr(self, instr):
raise NotImplementedError()


class _InstrList(list):

def _flat(self):
Expand Down Expand Up @@ -112,7 +165,7 @@ def __eq__(self, other):
return (self._flat() == other._flat())


class Bytecode(_InstrList, BaseBytecode):
class Bytecode(_InstrList, _BaseBytecodeList):

def __init__(self, instructions=()):
BaseBytecode.__init__(self)
Expand All @@ -135,6 +188,11 @@ def _check_instr(self, instr):
"but %s was found"
% type(instr).__name__)

def _copy_attr_from(self, bytecode):
super()._copy_attr_from(bytecode)
if isinstance(bytecode, Bytecode):
self.argnames = bytecode.argnames

@staticmethod
def from_code(code):
concrete = _bytecode.ConcreteBytecode.from_code(code)
Expand Down
46 changes: 46 additions & 0 deletions bytecode/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,44 @@ def __iter__(self):

yield instr

def __getitem__(self, index):
value = super().__getitem__(index)
if isinstance(index, slice):
value = type(self)(value)
value.next_block = self.next_block

return value

def copy(self):
new = type(self)(super().copy())
new.next_block = self.next_block
return new

def legalize(self, first_lineno):
"""Check that all the element of the list are valid and remove SetLineno.
"""
lineno_pos = []
set_lineno = None
current_lineno = first_lineno

for pos, instr in enumerate(self):
if isinstance(instr, SetLineno):
set_lineno = current_lineno = instr.lineno
lineno_pos.append(pos)
continue
if set_lineno is not None:
instr.lineno = set_lineno
elif instr.lineno is None:
instr.lineno = current_lineno
else:
current_lineno = instr.lineno

for i in reversed(lineno_pos):
del self[i]

return current_lineno

def get_jump(self):
if not self:
return None
Expand Down Expand Up @@ -98,6 +136,14 @@ def __init__(self):

self.add_block()

def legalize(self):
"""Legalize all blocks.
"""
current_lineno = self.first_lineno
for block in self._blocks:
current_lineno = block.legalize(current_lineno)

def get_block_index(self, block):
try:
return self._block_index[id(block)]
Expand Down
9 changes: 8 additions & 1 deletion bytecode/concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def disassemble(cls, lineno, code, offset):
return cls(name, arg, lineno=lineno)


class ConcreteBytecode(_bytecode.BaseBytecode, list):
class ConcreteBytecode(_bytecode._BaseBytecodeList):

def __init__(self, instructions=(), *, consts=(), names=(), varnames=()):
super().__init__()
Expand All @@ -156,6 +156,13 @@ def _check_instr(self, instr):
"but %s was found"
% type(instr).__name__)

def _copy_attr_from(self, bytecode):
super()._copy_attr_from(bytecode)
if isinstance(bytecode, ConcreteBytecode):
self.consts = bytecode.consts
self.names = bytecode.names
self.varnames = bytecode.varnames

def __repr__(self):
return '<ConcreteBytecode instr#=%s>' % len(self)

Expand Down
74 changes: 68 additions & 6 deletions bytecode/tests/test_bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,71 @@ def test_invalid_types(self):
code.append(123)
with self.assertRaises(ValueError):
list(code)
with self.assertRaises(ValueError):
code.legalize()
with self.assertRaises(ValueError):
Bytecode([123])

def test_legalize(self):
code = Bytecode()
code.first_lineno = 3
code.extend([Instr("LOAD_CONST", 7),
Instr("STORE_NAME", 'x'),
Instr("LOAD_CONST", 8, lineno=4),
Instr("STORE_NAME", 'y'),
Label(),
SetLineno(5),
Instr("LOAD_CONST", 9, lineno=6),
Instr("STORE_NAME", 'z')])

code.legalize()
self.assertListEqual(code, [Instr("LOAD_CONST", 7, lineno=3),
Instr("STORE_NAME", "x", lineno=3),
Instr("LOAD_CONST", 8, lineno=4),
Instr("STORE_NAME", "y", lineno=4),
Label(),
Instr("LOAD_CONST", 9, lineno=5),
Instr("STORE_NAME", "z", lineno=5)])

def test_slice(self):
code = Bytecode()
code.first_lineno = 3
code.extend([Instr("LOAD_CONST", 7),
Instr("STORE_NAME", 'x'),
SetLineno(4),
Instr("LOAD_CONST", 8),
Instr("STORE_NAME", 'y'),
SetLineno(5),
Instr("LOAD_CONST", 9),
Instr("STORE_NAME", 'z')])
sliced_code = code[:]
self.assertEqual(code, sliced_code)
for name in ("argcount", "posonlyargcount", "kwonlyargcount", "first_lineno",
"name", "filename", "docstring", "cellvars", "freevars",
"argnames"):
self.assertEqual(getattr(code, name, None),
getattr(sliced_code, name, None))

def test_copy(self):
code = Bytecode()
code.first_lineno = 3
code.extend([Instr("LOAD_CONST", 7),
Instr("STORE_NAME", 'x'),
SetLineno(4),
Instr("LOAD_CONST", 8),
Instr("STORE_NAME", 'y'),
SetLineno(5),
Instr("LOAD_CONST", 9),
Instr("STORE_NAME", 'z')])

copy_code = code.copy()
self.assertEqual(code, copy_code)
for name in ("argcount", "posonlyargcount", "kwonlyargcount", "first_lineno",
"name", "filename", "docstring", "cellvars", "freevars",
"argnames"):
self.assertEqual(getattr(code, name, None),
getattr(copy_code, name, None))

def test_from_code(self):
code = get_code("""
if test:
Expand Down Expand Up @@ -97,12 +159,12 @@ def test_setlineno(self):
concrete = code.to_concrete_bytecode()
self.assertEqual(concrete.consts, [7, 8, 9])
self.assertEqual(concrete.names, ['x', 'y', 'z'])
code.extend([ConcreteInstr("LOAD_CONST", 0, lineno=3),
ConcreteInstr("STORE_NAME", 0, lineno=3),
ConcreteInstr("LOAD_CONST", 1, lineno=4),
ConcreteInstr("STORE_NAME", 1, lineno=4),
ConcreteInstr("LOAD_CONST", 2, lineno=5),
ConcreteInstr("STORE_NAME", 2, lineno=5)])
self.assertListEqual(list(concrete), [ConcreteInstr("LOAD_CONST", 0, lineno=3),
ConcreteInstr("STORE_NAME", 0, lineno=3),
ConcreteInstr("LOAD_CONST", 1, lineno=4),
ConcreteInstr("STORE_NAME", 1, lineno=4),
ConcreteInstr("LOAD_CONST", 2, lineno=5),
ConcreteInstr("STORE_NAME", 2, lineno=5)])

def test_to_code(self):
code = Bytecode()
Expand Down
41 changes: 41 additions & 0 deletions bytecode/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def test_iter_invalid_types(self):
block.append(Label())
with self.assertRaises(ValueError):
list(block)
with self.assertRaises(ValueError):
block.legalize(1)

# Only one jump allowed and only at the end
block = BasicBlock()
Expand All @@ -42,13 +44,31 @@ def test_iter_invalid_types(self):
Instr('NOP')])
with self.assertRaises(ValueError):
list(block)
with self.assertRaises(ValueError):
block.legalize(1)

# jump target must be a BasicBlock
block = BasicBlock()
label = Label()
block.extend([Instr('JUMP_ABSOLUTE', label)])
with self.assertRaises(ValueError):
list(block)
with self.assertRaises(ValueError):
block.legalize(1)

def test_slice(self):
block = BasicBlock([Instr("NOP")])
next_block = BasicBlock()
block.next_block = next_block
self.assertEqual(block, block[:])
self.assertIs(next_block, block[:].next_block)

def test_copy(self):
block = BasicBlock([Instr("NOP")])
next_block = BasicBlock()
block.next_block = next_block
self.assertEqual(block, block.copy())
self.assertIs(next_block, block.copy().next_block)


class BytecodeBlocksTests(TestCase):
Expand Down Expand Up @@ -135,6 +155,27 @@ def test_setlineno(self):
Instr("LOAD_CONST", 9),
Instr("STORE_NAME", 'z')])

def test_legalize(self):
code = Bytecode()
code.first_lineno = 3
code.extend([Instr("LOAD_CONST", 7),
Instr("STORE_NAME", 'x'),
Instr("LOAD_CONST", 8, lineno=4),
Instr("STORE_NAME", 'y'),
SetLineno(5),
Instr("LOAD_CONST", 9, lineno=6),
Instr("STORE_NAME", 'z')])

blocks = ControlFlowGraph.from_bytecode(code)
blocks.legalize()
self.assertBlocksEqual(blocks,
[Instr("LOAD_CONST", 7, lineno=3),
Instr("STORE_NAME", 'x', lineno=3),
Instr("LOAD_CONST", 8, lineno=4),
Instr("STORE_NAME", 'y', lineno=4),
Instr("LOAD_CONST", 9, lineno=5),
Instr("STORE_NAME", 'z', lineno=5)])

def test_to_bytecode(self):
# if test:
# x = 2
Expand Down
53 changes: 53 additions & 0 deletions bytecode/tests/test_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def test_invalid_types(self):
code.append(Label())
with self.assertRaises(ValueError):
list(code)
with self.assertRaises(ValueError):
code.legalize()
with self.assertRaises(ValueError):
ConcreteBytecode([Label()])

Expand Down Expand Up @@ -379,6 +381,57 @@ def test_explicit_stacksize(self):
new_code_obj = concrete.to_code(stacksize=explicit_stacksize)
self.assertEqual(new_code_obj.co_stacksize, explicit_stacksize)

def test_legalize(self):
concrete = ConcreteBytecode()
concrete.first_lineno = 3
concrete.consts = [7, 8, 9]
concrete.names = ['x', 'y', 'z']
concrete.extend([ConcreteInstr("LOAD_CONST", 0),
ConcreteInstr("STORE_NAME", 0),
ConcreteInstr("LOAD_CONST", 1, lineno=4),
ConcreteInstr("STORE_NAME", 1),
SetLineno(5),
ConcreteInstr("LOAD_CONST", 2, lineno=6),
ConcreteInstr("STORE_NAME", 2)])

concrete.legalize()
self.assertListEqual(list(concrete), [ConcreteInstr("LOAD_CONST", 0, lineno=3),
ConcreteInstr("STORE_NAME", 0, lineno=3),
ConcreteInstr("LOAD_CONST", 1, lineno=4),
ConcreteInstr("STORE_NAME", 1, lineno=4),
ConcreteInstr("LOAD_CONST", 2, lineno=5),
ConcreteInstr("STORE_NAME", 2, lineno=5)])

def test_slice(self):
concrete = ConcreteBytecode()
concrete.first_lineno = 3
concrete.consts = [7, 8, 9]
concrete.names = ['x', 'y', 'z']
concrete.extend([ConcreteInstr("LOAD_CONST", 0),
ConcreteInstr("STORE_NAME", 0),
SetLineno(4),
ConcreteInstr("LOAD_CONST", 1),
ConcreteInstr("STORE_NAME", 1),
SetLineno(5),
ConcreteInstr("LOAD_CONST", 2),
ConcreteInstr("STORE_NAME", 2)])
self.assertEqual(concrete, concrete[:])

def test_copy(self):
concrete = ConcreteBytecode()
concrete.first_lineno = 3
concrete.consts = [7, 8, 9]
concrete.names = ['x', 'y', 'z']
concrete.extend([ConcreteInstr("LOAD_CONST", 0),
ConcreteInstr("STORE_NAME", 0),
SetLineno(4),
ConcreteInstr("LOAD_CONST", 1),
ConcreteInstr("STORE_NAME", 1),
SetLineno(5),
ConcreteInstr("LOAD_CONST", 2),
ConcreteInstr("STORE_NAME", 2)])
self.assertEqual(concrete, concrete.copy())


class ConcreteFromCodeTests(TestCase):

Expand Down

0 comments on commit 2d93b01

Please sign in to comment.