diff --git a/altair_saver/_utils.py b/altair_saver/_utils.py index 68f5c16..be5bb37 100644 --- a/altair_saver/_utils.py +++ b/altair_saver/_utils.py @@ -6,7 +6,7 @@ import subprocess import sys import tempfile -from typing import IO, Iterator, List, Optional, Union +from typing import Callable, IO, Iterator, List, Optional, Union import altair as alt @@ -168,31 +168,53 @@ def extract_format(fp: Union[IO, str]) -> str: def check_output_with_stderr( - cmd: Union[str, List[str]], shell: bool = False, input: Optional[bytes] = None + cmd: Union[str, List[str]], + shell: bool = False, + input: Optional[bytes] = None, + stderr_filter: Callable[[str], bool] = None, ) -> bytes: """Run a command in a subprocess, printing stderr to sys.stderr. - Arguments are passed directly to subprocess.run(). + This function exists because normally, stderr from subprocess in the notebook + is printed to the terminal rather than to the notebook itself. - This is important because subprocess stderr in notebooks is printed to the - terminal rather than the notebook. + Parameters + ---------- + cmd, shell, input : + Arguments are passed directly to `subprocess.run()`. + stderr_filter : function(str)->bool (optional) + If provided, this function is used to filter stderr lines from display. + + Returns + ------- + result : bytes + The stdout from the command + + Raises + ------ + subprocess.CalledProcessError : if the called process returns a non-zero exit code. """ try: ps = subprocess.run( cmd, shell=shell, + input=input, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - input=input, ) except subprocess.CalledProcessError as err: - if err.stderr: - sys.stderr.write(err.stderr.decode()) - sys.stderr.flush() + stderr = err.stderr raise else: - if ps.stderr: - sys.stderr.write(ps.stderr.decode()) - sys.stderr.flush() + stderr = ps.stderr return ps.stdout + finally: + s = stderr.decode() + if stderr_filter: + s = "\n".join(filter(stderr_filter, s.splitlines())) + if s: + if not s.endswith("\n"): + s += "\n" + sys.stderr.write(s) + sys.stderr.flush() diff --git a/altair_saver/tests/test_utils.py b/altair_saver/tests/test_utils.py index 497e62c..ae5840b 100644 --- a/altair_saver/tests/test_utils.py +++ b/altair_saver/tests/test_utils.py @@ -6,7 +6,7 @@ from typing import Any import pytest -from _pytest.capture import SysCaptureBinary +from _pytest.capture import SysCapture from altair_saver.types import JSONDict from altair_saver._utils import ( @@ -133,20 +133,27 @@ def test_infer_mode_from_spec(mode: str, spec: JSONDict) -> None: assert infer_mode_from_spec(spec) == mode -def test_check_output_with_stderr(capsysbinary: SysCaptureBinary) -> None: - output = check_output_with_stderr( - r'>&2 echo "the error" && echo "the output"', shell=True - ) - assert output == b"the output\n" - captured = capsysbinary.readouterr() - assert captured.out == b"" - assert captured.err == b"the error\n" +@pytest.mark.parametrize("cmd_error", [True, False]) +@pytest.mark.parametrize("use_filter", [True, False]) +def test_check_output_with_stderr( + capsys: SysCapture, use_filter: bool, cmd_error: bool +) -> None: + cmd = r'>&2 echo "first error\nsecond error" && echo "the output"' + stderr_filter = None if not use_filter else lambda line: line.startswith("second") + if cmd_error: + cmd += r" && exit 1" + with pytest.raises(subprocess.CalledProcessError) as err: + check_output_with_stderr(cmd, shell=True, stderr_filter=stderr_filter) + assert err.value.stderr == b"first error\nsecond error\n" + else: + output = check_output_with_stderr(cmd, shell=True, stderr_filter=stderr_filter) + assert output == b"the output\n" + + captured = capsys.readouterr() + assert captured.out == "" -def test_check_output_with_stderr_exit_1(capsysbinary: SysCaptureBinary) -> None: - with pytest.raises(subprocess.CalledProcessError) as err: - check_output_with_stderr(r'>&2 echo "the error" && exit 1', shell=True) - assert err.value.stderr == b"the error\n" - captured = capsysbinary.readouterr() - assert captured.out == b"" - assert captured.err == b"the error\n" + if use_filter: + assert captured.err == "second error\n" + else: + assert captured.err == "first error\nsecond error\n"