Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
executable file 308 lines (259 sloc) 9.9 KB
#!/usr/bin/python3
import argparse
import multiprocessing
import contextlib
from collections import defaultdict
import os
import io
import sys
import warnings
import shutil
import logging
import traceback
import tempfile
import ast
from termcolor import cprint
import importlib
import json
import docutils
import docutils.core
import docutils.nodes
# Root logger (getLogger() called with no name); main() configures its level,
# stream and format through logging.basicConfig()
log = logging.getLogger()
class Fail(Exception):
    """
    Error whose message is printed to stderr by the script entry point, which
    then exits with status 1
    """
class RunEnv:
    """
    Test environment for python tests

    Wraps one collected snippet (``entry``, a dict with at least the ``code``
    and ``src`` keys) and annotates that same dict with the outcome of running
    it: ``result``, ``warnings``, and, depending on what happened,
    ``exception``, ``stdout`` and ``stderr``.
    """
    # Sample BUFR file imported into the scratch database before each run;
    # the path is relative to the current working directory
    TEST_BUFR = "extra/bufr/synop-sunshine.bufr"

    def __init__(self, entry):
        """
        Prepare the namespaces and capture buffers used to run ``entry``
        """
        self.entry = entry
        self.globals = dict(globals())
        # At this point locals() holds exactly `self` and `entry`, so both
        # names are visible to the executed snippet (presumably so that it
        # can call self.assertEqual)
        self.locals = dict(locals())
        # Names the snippet binds itself (filled in by check_reqs)
        self.assigned = set()
        # Names the snippet reads attributes from without binding them first
        self.wanted = set()
        self.stdout = io.StringIO()
        self.stderr = io.StringIO()

        # Import modules expected by the snippets
        for modname in ("dballe", "wreport", "os", "datetime"):
            self.globals[modname] = importlib.import_module(modname)

    def check_reqs(self):
        """
        Check if the snippet expects some well-known variables to be defined

        Walks the snippet's AST and records in ``self.assigned`` the names
        bound by a single-target assignment or by a single-item
        ``with ... as name``, and in ``self.wanted`` the names that have an
        attribute accessed without having been recorded as assigned.
        """
        tree = ast.parse(self.entry["code"], self.entry["src"], "exec")
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name):
                    self.assigned.add(node.targets[0].id)
            elif isinstance(node, ast.Attribute):
                if isinstance(node.value, ast.Name) and node.value.id not in self.assigned:
                    self.wanted.add(node.value.id)
            elif isinstance(node, ast.With):
                if len(node.items) != 1:
                    continue
                # optional_vars is None for `with ctx():` and a Tuple/List
                # node for `with ctx() as (a, b):`; only a plain name is
                # recorded
                target = node.items[0].optional_vars
                if isinstance(target, ast.Name):
                    self.assigned.add(target.id)

    # Provide some assertion methods the snippets can use
    def assertEqual(self, a, b):
        """
        Raise AssertionError showing both values if ``a != b``
        """
        if a != b:
            raise AssertionError(f"{a!r} != {b!r}")

    @contextlib.contextmanager
    def setup(self, capture_output=True):
        """
        Context manager that builds the scratch environment for one snippet

        Creates a temporary directory holding a copy of the test BUFR file and
        an sqlite database populated from it, defines the well-known variables
        (``db``, ``tr``, ``explorer``, ``msg``) listed in ``self.wanted``, sets
        ``DBA_DB`` in the environment, then yields with the temporary
        directory as current directory.

        If ``capture_output`` is true, sys.stdout and sys.stderr are replaced
        with ``self.stdout`` and ``self.stderr`` while the body runs. The
        working directory and the streams are restored on exit.
        """
        import dballe

        # Build a test environment in a temporary directory
        with tempfile.TemporaryDirectory() as root:
            test_bufr = os.path.join(root, "test.bufr")
            shutil.copy(self.TEST_BUFR, test_bufr)
            db_url = "sqlite:" + os.path.join(root, "db.sqlite")
            db = dballe.DB.connect(db_url + "?wipe=yes")
            importer = dballe.Importer("BUFR")
            with importer.from_file(test_bufr) as f:
                db.import_messages(f)

            # Create well-known variables if needed
            if "db" in self.wanted or "tr" in self.wanted or "explorer" in self.wanted:
                self.locals["db"] = db

            if "explorer" in self.wanted:
                self.locals["explorer"] = dballe.Explorer()
                with self.locals["explorer"].rebuild() as update:
                    with self.locals["db"].transaction() as tr:
                        update.add_db(tr)

            if "tr" in self.wanted:
                # NOTE(review): this transaction is handed to the snippet and
                # is never explicitly committed or rolled back here
                self.locals["tr"] = self.locals["db"].transaction()

            if "msg" in self.wanted:
                importer = dballe.Importer("BUFR")
                with importer.from_file(self.TEST_BUFR) as f:
                    # Each iteration overwrites the previous value, so the
                    # snippet sees the first element of the last item read
                    for msg in f:
                        self.locals["msg"] = msg[0]

            # Set env variables the snippets may use
            os.environ["DBA_DB"] = db_url

            # Run the code, in the temp directory, discarding stdout
            curdir = os.getcwd()
            orig_stdout = sys.stdout
            orig_stderr = sys.stderr
            try:
                # Redirection and chdir happen inside the try block so that
                # the finally clause undoes them even if chdir itself fails
                if capture_output:
                    sys.stdout = self.stdout
                    sys.stderr = self.stderr
                os.chdir(root)
                yield
            finally:
                if capture_output:
                    sys.stderr = orig_stderr
                    sys.stdout = orig_stdout
                os.chdir(curdir)

    def store_exc(self):
        """
        Record the exception currently being handled in ``entry["exception"]``

        The value is ``(lineno, repr(exception))`` when the traceback contains
        a frame whose file name is the snippet's ``src`` (the innermost such
        frame wins), otherwise ``(None, formatted traceback)``.
        """
        _, value, tb = sys.exc_info()
        for frame, lineno in traceback.walk_tb(tb):
            if frame.f_code.co_filename == self.entry["src"]:
                self.entry["exception"] = (lineno, repr(value))
        if "exception" not in self.entry:
            self.entry["exception"] = None, traceback.format_exc()

    def run(self):
        """
        Run a compiled test snippet

        Sets ``entry["result"]`` to ``"fail-compile"``, ``"fail-run"`` or
        ``"ok"`` and ``entry["warnings"]`` to the list of warnings recorded
        while compiling and running. ``entry["stdout"]`` and
        ``entry["stderr"]`` are set only if the snippet compiled.
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            # The list is stored right away and keeps filling up as warnings
            # are issued below
            self.entry["warnings"] = w

            try:
                code = compile(self.entry["code"], self.entry["src"], "exec")
            except Exception:
                self.entry["result"] = "fail-compile"
                self.store_exc()
                return

            self.check_reqs()

            with self.setup():
                try:
                    exec(code, self.globals, self.locals)
                except Exception:
                    self.entry["result"] = "fail-run"
                    self.store_exc()
                    return
                finally:
                    # Captured output is saved on success and failure alike
                    self.entry["stdout"] = self.stdout.getvalue()
                    self.entry["stderr"] = self.stderr.getvalue()

            self.entry["result"] = "ok"
def run_python_test(entry):
    """
    Execute one collected python snippet inside a fresh RunEnv

    The entry dict is annotated in place with the run results and is also
    returned, so the function can be used as a mapping callback.
    """
    RunEnv(entry).run()
    return entry
def format_python_result(entry, verbose=False):
    """
    Print a colour-coded report for a snippet annotated with run results

    The headline is red if an exception was recorded, yellow if only warnings
    were, green otherwise. The code listing is shown when there is something
    to annotate, when the snippet wrote to stderr, or, with ``verbose``, when
    it wrote to stdout. Captured stdout is printed only with ``verbose``.
    """
    failed = "exception" in entry
    if failed:
        headline_color, ok = "red", False
    elif entry["warnings"]:
        headline_color, ok = "yellow", False
    else:
        headline_color, ok = "green", True

    cprint(f"{entry['src']}: {entry['source']}:{entry['line']}: {entry['result']}", headline_color, attrs=["bold"])

    # Collect exceptions and warnings as (message, colour) notes keyed by
    # line number
    notes = defaultdict(list)
    if failed:
        exc_line, exc_msg = entry["exception"]
        notes[exc_line].append((exc_msg, "red"))
    for warning in entry["warnings"] or ():
        notes[warning.lineno].append((f"{warning.category}, {warning.message}", "yellow"))

    want_stdout = verbose and entry.get("stdout")
    captured_err = entry.get("stderr")

    if notes or captured_err or want_stdout:
        for number, text in enumerate(entry["code"].splitlines(), start=1):
            # Notes are removed as they are printed, so that only the ones
            # not matching any code line are left for the loop below
            attached = notes.pop(number, None)
            cprint(f" {number:2d} {text}", "red" if attached else "grey", attrs=["bold"])
            for msg, color in attached or ():
                cprint(f" ↪ {msg}", color)

    # Notes that could not be attached to a line of the listing
    for attached in notes.values():
        for msg, color in attached:
            for text in msg.splitlines():
                cprint(f" ↪ {text}", color)

    if want_stdout:
        for text in entry["stdout"].splitlines():
            cprint(f" O {text}", "blue")

    if captured_err:
        for text in captured_err.splitlines():
            cprint(f" E {text}", "blue")

    # Blank line after anything that is not a clean pass
    if not ok:
        print()
def run_tests(entries, verbose=False):
    """
    Run every collected python snippet and print a report for each

    Entries are grouped by their ``lang`` key, with ``"default"`` counted as
    python. The python entries are run in a single-worker process pool and
    then formatted in order.
    """
    grouped = defaultdict(list)
    for item in entries:
        key = item["lang"]
        grouped["python" if key == "default" else key].append(item)

    # Only python supported so far
    python_entries = grouped["python"]
    if not python_entries:
        return

    with multiprocessing.Pool(1) as pool:
        results = pool.map(run_python_test, python_entries)

    for result in results:
        format_python_result(result, verbose)
def get_code_from_rst(fname):
    """
    Return the code in the first code block in the given rst

    The result is an entry dict with the ``src``, ``lang``, ``code``,
    ``source`` and ``line`` keys, built from the first text node found inside
    a literal block, or None if the document has none.
    """
    with open(fname, "rt") as fd:
        doctree = docutils.core.publish_doctree(
            fd, source_class=docutils.io.FileInput, settings_overrides={"input_encoding": "unicode"})

    for node in doctree.traverse(docutils.nodes.literal_block):
        for subnode in node.traverse(docutils.nodes.Text):
            return {
                "src": fname,
                "lang": "python",
                # Store the node's text rather than the docutils Text node
                # itself, so the entry holds a string like the other fields
                "code": subnode.astext(),
                "source": node.source,
                "line": node.line,
            }
    return None
def main():
    """
    Command line entry point

    Parses the arguments and configures logging. With ``--run`` it runs the
    first code snippet of the given .rst file, otherwise it runs all the
    snippets listed in the given JSON file.

    Raises Fail if ``--run`` is used and the file contains no code.
    """
    parser = argparse.ArgumentParser(description="Run code found in documentation code blocks")
    parser.add_argument("--verbose", "-v", action="store_true", help="verbose output")
    parser.add_argument("--debug", action="store_true", help="debug output")
    parser.add_argument("-r", "--run", action="store_true", help="run the first code snippet found in a .rst file")
    parser.add_argument("code", action="store", nargs="?", default="doc/test_code.json",
                        help="JSON file with the collected test code")
    args = parser.parse_args()

    # Setup logging
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, stream=sys.stderr, format="%(asctime)-15s %(levelname)s %(message)s")

    verbose = args.debug or args.verbose

    if not args.run:
        with open(args.code, "rt") as fd:
            entries = json.load(fd)
        run_tests(entries, verbose=verbose)
        return

    entry = get_code_from_rst(args.code)
    if entry is None:
        raise Fail(f"No code found in {args.code}")
    env = RunEnv(entry)
    env.run()
    format_python_result(entry, verbose=verbose)
if __name__ == "__main__":
    # A Fail raised by main() becomes a message on stderr and exit status 1;
    # otherwise the script exits with whatever main() returned
    try:
        exit_status = main()
    except Fail as err:
        print(err, file=sys.stderr)
        exit_status = 1
    sys.exit(exit_status)
You can’t perform that action at this time.