Skip to content

Commit

Permalink
flags: fix the flag inference mechanism (#56)
Browse files Browse the repository at this point in the history
* flags: fix the flag inference mechanism that could easily generate invalid flags

* flags: fix the default of update_flags and test using update_flags

* docs: update the changelog and api docs

* travis: avoid building branches and improve tests

* coverage: ignore coverage of the test files
  • Loading branch information
MatthieuDartiailh committed Mar 2, 2020
1 parent 54d9ad7 commit 7c8c768
Show file tree
Hide file tree
Showing 14 changed files with 363 additions and 59 deletions.
14 changes: 14 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[run]
omit =
setup.py
bytecode/tests/*

[report]
# Regexes for lines to exclude from consideration
exclude_lines =
# Have to re-enable the standard pragma
pragma: no cover

# Don't complain if tests don't hit defensive assertion code:
raise NotImplementedError
pass
4 changes: 4 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ language: python
dist: xenial
cache: pip

branches:
only:
- master

matrix:
include:
- python: 3.5
Expand Down
6 changes: 3 additions & 3 deletions bytecode/bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ def __eq__(self, other):
return False
if self.kwonlyargcount != other.kwonlyargcount:
return False
if self.compute_stacksize() != other.compute_stacksize():
return False
if self.flags != other.flags:
return False
if self.first_lineno != other.first_lineno:
Expand All @@ -62,6 +60,8 @@ def __eq__(self, other):
return False
if self.freevars != other.freevars:
return False
if self.compute_stacksize() != other.compute_stacksize():
return False

return True

Expand All @@ -75,7 +75,7 @@ def flags(self, value):
value = _bytecode.CompilerFlags(value)
self._flags = value

def update_flags(self, *, is_async=False):
def update_flags(self, *, is_async=None):
self.flags = infer_flags(self, is_async)


Expand Down
112 changes: 86 additions & 26 deletions bytecode/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,35 @@ class CompilerFlags(IntFlag):
GENERATOR = 0x00020 # noqa
NOFREE = 0x00040 # noqa
# New in Python 3.5
# Used for coroutines defined using async def ie native coroutine
COROUTINE = 0x00080 # noqa
# Used for coroutines defined as a generator and then decorated using
# types.coroutine
ITERABLE_COROUTINE = 0x00100 # noqa
# New in Python 3.6
# Generator defined in an async def function
ASYNC_GENERATOR = 0x00200 # noqa

# __future__ flags
FUTURE_GENERATOR_STOP = 0x80000 # noqa


def infer_flags(bytecode, is_async=False):
def infer_flags(bytecode, is_async=None):
"""Infer the proper flags for a bytecode based on the instructions.
Because the bytecode does not have enough context to guess if a function
is asynchronous the algorithm tries to be conservative and will never turn
a previously async code into a sync one.
Parameters
----------
bytecode : Bytecode | ConcreteBytecode | ControlFlowGraph
Bytecode for which to infer the proper flags
is_async : bool | None, optional
Force the code to be marked as asynchronous if True, prevent it from
being marked as asynchronous if False and simply infer the best
solution based on the opcode and the existing flag if None.
"""
flags = CompilerFlags(0)
if not isinstance(bytecode, (_bytecode.Bytecode,
Expand All @@ -49,41 +66,84 @@ def infer_flags(bytecode, is_async=False):
if not isinstance(i, (_bytecode.SetLineno,
_bytecode.Label))}

# Identify optimized code
if not (instr_names & {'STORE_NAME', 'LOAD_NAME', 'DELETE_NAME'}):
flags |= CompilerFlags.OPTIMIZED

# Check for free variables
if not (instr_names & {'LOAD_CLOSURE', 'LOAD_DEREF', 'STORE_DEREF',
'DELETE_DEREF', 'LOAD_CLASSDEREF'}):
flags |= CompilerFlags.NOFREE

# Copy flags for which we cannot infer the right value
flags |= bytecode.flags & (CompilerFlags.NEWLOCALS
| CompilerFlags.VARARGS
| CompilerFlags.VARKEYWORDS
| CompilerFlags.NESTED)

if instr_names & {'YIELD_VALUE', 'YIELD_FROM'}:
if not is_async and not bytecode.flags & CompilerFlags.ASYNC_GENERATOR:
flags |= CompilerFlags.GENERATOR
sure_generator = instr_names & {'YIELD_VALUE'}
maybe_generator = instr_names & {'YIELD_VALUE', 'YIELD_FROM'}

sure_async = instr_names & {'GET_AWAITABLE', 'GET_AITER', 'GET_ANEXT',
'BEFORE_ASYNC_WITH', 'SETUP_ASYNC_WITH',
'END_ASYNC_FOR'}

# If performing inference or forcing an async behavior, first inspect
# the flags since this is the only way to identify iterable coroutines
if is_async in (None, True):

if bytecode.flags & CompilerFlags.COROUTINE:
if sure_generator:
flags |= CompilerFlags.ASYNC_GENERATOR
else:
flags |= CompilerFlags.COROUTINE
elif bytecode.flags & CompilerFlags.ITERABLE_COROUTINE:
if sure_async:
msg = ("The ITERABLE_COROUTINE flag is set but bytecode that"
"can only be used in async functions have been "
"detected. Please unset that flag before performing "
"inference.")
raise ValueError(msg)
flags |= CompilerFlags.ITERABLE_COROUTINE
elif bytecode.flags & CompilerFlags.ASYNC_GENERATOR:
if not sure_generator:
flags |= CompilerFlags.COROUTINE
else:
flags |= CompilerFlags.ASYNC_GENERATOR

# If the code was not asynchronous before determine if it should now be
# asynchronous based on the opcode and the is_async argument.
else:
flags |= CompilerFlags.ASYNC_GENERATOR

if not (instr_names & {'LOAD_CLOSURE', 'LOAD_DEREF', 'STORE_DEREF',
'DELETE_DEREF', 'LOAD_CLASSDEREF'}):
flags |= CompilerFlags.NOFREE

if (not (bytecode.flags & CompilerFlags.ITERABLE_COROUTINE
or flags & CompilerFlags.ASYNC_GENERATOR)
and (instr_names & {'GET_AWAITABLE', 'GET_AITER', 'GET_ANEXT',
'BEFORE_ASYNC_WITH', 'SETUP_ASYNC_WITH'}
or bytecode.flags & CompilerFlags.COROUTINE)):
flags |= CompilerFlags.COROUTINE

flags |= bytecode.flags & CompilerFlags.ITERABLE_COROUTINE
if sure_async:
# YIELD_FROM is not allowed in async generator
if sure_generator:
flags |= CompilerFlags.ASYNC_GENERATOR
else:
flags |= CompilerFlags.COROUTINE

elif maybe_generator:
if is_async:
if sure_generator:
flags |= CompilerFlags.ASYNC_GENERATOR
else:
flags |= CompilerFlags.COROUTINE
else:
flags |= CompilerFlags.GENERATOR

elif is_async:
flags |= CompilerFlags.COROUTINE

# If the code should not be asynchronous, check first it is possible and
# next set the GENERATOR flag if relevant
else:
if sure_async:
raise ValueError("The is_async argument is False but bytecodes "
"that can only be used in async functions have "
"been detected.")

if maybe_generator:
flags |= CompilerFlags.GENERATOR

flags |= bytecode.flags & CompilerFlags.FUTURE_GENERATOR_STOP

if ([bool(flags & getattr(CompilerFlags, k))
for k in ('COROUTINE', 'ITERABLE_COROUTINE', 'GENERATOR',
'ASYNC_GENERATOR')].count(True) > 1):
raise ValueError("Code should not have more than one of the "
"following flag set : generator, coroutine, "
"iterable coroutine and async generator, got:"
"%s" % flags)

return flags
2 changes: 1 addition & 1 deletion bytecode/tests/test_bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,4 +184,4 @@ def test_to_code(self):


if __name__ == "__main__":
unittest.main()
unittest.main() # pragma: no cover
35 changes: 34 additions & 1 deletion bytecode/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ def test_legalize(self):
Instr("LOAD_CONST", 9, lineno=5),
Instr("STORE_NAME", 'z', lineno=5)])

def test_repr(self):
r = repr(ControlFlowGraph())
self.assertIn("ControlFlowGraph", r)
self.assertIn("1", r)

def test_to_bytecode(self):
# if test:
# x = 2
Expand Down Expand Up @@ -379,13 +384,32 @@ def test_eq(self):
code2 = disassemble(source)
self.assertEqual(code1, code2)

# Type mismatch
self.assertFalse(code1 == 1)

# argnames mismatch
cfg = ControlFlowGraph()
cfg.argnames = 10
self.assertFalse(code1 == cfg)

# instr mismatch
cfg = ControlFlowGraph()
cfg.argnames = code1.argnames
self.assertFalse(code1 == cfg)

def check_getitem(self, code):
# check internal Code block indexes (index by index, index by label)
for block_index, block in enumerate(code):
self.assertIs(code[block_index], block)
self.assertIs(code[block], block)
self.assertEqual(code.get_block_index(block), block_index)

def test_delitem(self):
cfg = ControlFlowGraph()
b = cfg.add_block()
del cfg[b]
self.assertEqual(len(cfg.get_instructions()), 0)

def sample_code(self):
code = disassemble('x = 1', remove_last_return_none=True)
self.assertBlocksEqual(code,
Expand Down Expand Up @@ -413,6 +437,13 @@ def test_split_block(self):
[Instr('NOP', lineno=1)])
self.check_getitem(code)

with self.assertRaises(TypeError):
code.split_block(1, 1)

with self.assertRaises(ValueError) as e:
code.split_block(code[0], -2)
self.assertIn("positive", e.exception.args[0])

def test_split_block_end(self):
code = self.sample_code()

Expand Down Expand Up @@ -540,7 +571,9 @@ def check_stack_size(self, func):
self.assertEqual(code.co_stacksize, cfg.compute_stacksize())

def test_empty_code(self):
self.assertEqual(ControlFlowGraph().compute_stacksize(), 0)
cfg = ControlFlowGraph()
del cfg[0]
self.assertEqual(cfg.compute_stacksize(), 0)

def test_handling_of_set_lineno(self):
code = Bytecode()
Expand Down
2 changes: 1 addition & 1 deletion bytecode/tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def func(*, arg, arg2):


if __name__ == "__main__":
unittest.main()
unittest.main() # pragma: no cover
35 changes: 34 additions & 1 deletion bytecode/tests/test_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import unittest
import textwrap
from bytecode import (UNSET, Label, Instr, SetLineno, Bytecode,
CellVar, FreeVar,
CellVar, FreeVar, CompilerFlags,
ConcreteInstr, ConcreteBytecode)
from bytecode.tests import get_code, TestCase, WORDCODE

Expand Down Expand Up @@ -162,6 +162,39 @@ def test_get_jump_target(self):

class ConcreteBytecodeTests(TestCase):

def test_repr(self):
r = repr(ConcreteBytecode())
self.assertIn("ConcreteBytecode", r)
self.assertIn("0", r)

def test_eq(self):
code = ConcreteBytecode()
self.assertFalse(code == 1)

for name, val in (("names", ["a"]), ("varnames", ["a"]),
("consts", [1]),
("argcount", 1), ("kwonlyargcount", 2),
("flags", CompilerFlags(CompilerFlags.GENERATOR)),
("first_lineno", 10), ("filename", "xxxx.py"),
("name", "__x"), ("docstring", "x-x-x"),
("cellvars", [CellVar("x")]),
("freevars", [FreeVar("x")])):
c = ConcreteBytecode()
setattr(c, name, val)
# For obscure reasons using assertNotEqual here fail
self.assertFalse(code == c)

if sys.version_info > (3, 8):
c = ConcreteBytecode()
c.posonlyargcount = 10
self.assertFalse(code == c)

c = ConcreteBytecode()
c.consts = [1]
code.consts = [1]
c.append(ConcreteInstr("LOAD_CONST", 0))
self.assertFalse(code == c)

def test_attr(self):
code_obj = get_code("x = 5")
code = ConcreteBytecode.from_code(code_obj)
Expand Down

0 comments on commit 7c8c768

Please sign in to comment.