diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py index e43cb820..680717d0 100644 --- a/absl/testing/absltest.py +++ b/absl/testing/absltest.py @@ -45,6 +45,7 @@ from typing import Any, AnyStr, BinaryIO, Callable, ContextManager, IO, Iterator, List, Mapping, MutableMapping, MutableSequence, NoReturn, Optional, Sequence, Text, TextIO, Tuple, Type, Union import unittest from unittest import mock # pylint: disable=unused-import Allow absltest.mock. +import unittest.case from urllib import parse from absl import app # pylint: disable=g-import-not-at-top @@ -53,18 +54,6 @@ from absl.testing import _pretty_print_reporter from absl.testing import xml_reporter -# Use an if-type-checking block to prevent leakage of type-checking only -# symbols. We don't want people relying on these at runtime. -if typing.TYPE_CHECKING: - # Unbounded TypeVar for general usage - _T = typing.TypeVar('_T') - - import unittest.case # pylint: disable=g-import-not-at-top,g-bad-import-order - - _OutcomeType = unittest.case._Outcome # pytype: disable=module-attr - - -# pylint: enable=g-import-not-at-top # Re-export a bunch of unittest functions we support so that people don't # have to import unittest to get them @@ -80,6 +69,9 @@ FLAGS = flags.FLAGS +# Private typing symbols. +_T = typing.TypeVar('_T') # Unbounded TypeVar for general usage +_OutcomeType = unittest.case._Outcome # pytype: disable=module-attr _TEXT_OR_BINARY_TYPES = (str, bytes) # Suppress surplus entries in AssertionError stack traces. @@ -125,8 +117,7 @@ class TempFileCleanup(enum.Enum): # pylint: disable=invalid-name -def _get_default_test_random_seed(): - # type: () -> int +def _get_default_test_random_seed() -> int: random_seed = 301 value = os.environ.get('TEST_RANDOM_SEED', '') try: @@ -136,14 +127,12 @@ def _get_default_test_random_seed(): return random_seed -def get_default_test_srcdir(): - # type: () -> Text +def get_default_test_srcdir() -> str: """Returns default test source dir.""" return os.environ.get('TEST_SRCDIR', '') -def get_default_test_tmpdir(): - # type: () -> Text +def get_default_test_tmpdir() -> str: """Returns default test temp dir.""" tmpdir = os.environ.get('TEST_TMPDIR', '') if not tmpdir: @@ -152,8 +141,7 @@ def get_default_test_tmpdir(): return tmpdir -def _get_default_randomize_ordering_seed(): - # type: () -> int +def _get_default_randomize_ordering_seed() -> int: """Returns default seed to use for randomizing test order. This function first checks the --test_randomize_ordering_seed flag, and then @@ -235,12 +223,10 @@ def _get_default_randomize_ordering_seed(): # We might need to monkey-patch TestResult so that it stops considering an # unexpected pass as a as a "successful result". For details, see # http://bugs.python.org/issue20165 -def _monkey_patch_test_result_for_unexpected_passes(): - # type: () -> None +def _monkey_patch_test_result_for_unexpected_passes() -> None: """Workaround for .""" - def wasSuccessful(self): - # type: () -> bool + def wasSuccessful(self) -> bool: """Tells whether or not this result was a success. Any unexpected pass is to be counted as a non-success. @@ -266,8 +252,9 @@ def wasSuccessful(self): _monkey_patch_test_result_for_unexpected_passes() -def _open(filepath, mode, _open_func=open): - # type: (Text, Text, Callable[..., IO]) -> IO +def _open( + filepath: str, mode: str, _open_func: Callable[..., IO[AnyStr]] = open +) -> IO[AnyStr]: """Opens a file. Like open(), but ensure that we can open real files even if tests stub out @@ -294,14 +281,12 @@ class _TempDir(object): to e.g. `os.path.join()`. """ - def __init__(self, path): - # type: (Text) -> None + def __init__(self, path: str) -> None: """Module-private: do not instantiate outside module.""" self._path = path @property - def full_path(self): - # type: () -> Text + def full_path(self) -> str: """Returns the path, as a string, for the directory. TIP: Instead of e.g. `os.path.join(temp_dir.full_path)`, you can simply @@ -309,14 +294,18 @@ def full_path(self): """ return self._path - def __fspath__(self): - # type: () -> Text + def __fspath__(self) -> str: """See os.PathLike.""" return self.full_path - def create_file(self, file_path=None, content=None, mode='w', encoding='utf8', - errors='strict'): - # type: (Optional[Text], Optional[AnyStr], Text, Text, Text) -> _TempFile + def create_file( + self, + file_path: Optional[str] = None, + content: Optional[AnyStr] = None, + mode: str = 'w', + encoding: str = 'utf8', + errors: str = 'strict', + ) -> '_TempFile': """Create a file in the directory. NOTE: If the file already exists, it will be made writable and overwritten. @@ -343,8 +332,7 @@ def create_file(self, file_path=None, content=None, mode='w', encoding='utf8', errors) return tf - def mkdir(self, dir_path=None): - # type: (Optional[Text]) -> _TempDir + def mkdir(self, dir_path: Optional[str] = None) -> '_TempDir': """Create a directory in the directory. Args: @@ -375,16 +363,20 @@ class _TempFile(object): to e.g. `os.path.join()`. """ - def __init__(self, path): - # type: (Text) -> None + def __init__(self, path: str) -> None: """Private: use _create instead.""" self._path = path - # pylint: disable=line-too-long @classmethod - def _create(cls, base_path, file_path, content, mode, encoding, errors): - # type: (Text, Optional[Text], AnyStr, Text, Text, Text) -> Tuple[_TempFile, Text] - # pylint: enable=line-too-long + def _create( + cls, + base_path: str, + file_path: Optional[str], + content: AnyStr, + mode: str, + encoding: str, + errors: str, + ) -> Tuple['_TempFile', str]: """Module-private: create a tempfile instance.""" if file_path: cleanup_path = os.path.join(base_path, _get_first_part(file_path)) @@ -415,8 +407,7 @@ def _create(cls, base_path, file_path, content, mode, encoding, errors): return tf, cleanup_path @property - def full_path(self): - # type: () -> Text + def full_path(self) -> str: """Returns the path, as a string, for the file. TIP: Instead of e.g. `os.path.join(temp_file.full_path)`, you can simply @@ -424,25 +415,27 @@ def full_path(self): """ return self._path - def __fspath__(self): - # type: () -> Text + def __fspath__(self) -> str: """See os.PathLike.""" return self.full_path - def read_text(self, encoding='utf8', errors='strict'): - # type: (Text, Text) -> Text + def read_text(self, encoding: str = 'utf8', errors: str = 'strict') -> str: """Return the contents of the file as text.""" with self.open_text(encoding=encoding, errors=errors) as fp: return fp.read() - def read_bytes(self): - # type: () -> bytes + def read_bytes(self) -> bytes: """Return the content of the file as bytes.""" with self.open_bytes() as fp: return fp.read() - def write_text(self, text, mode='w', encoding='utf8', errors='strict'): - # type: (Text, Text, Text, Text) -> None + def write_text( + self, + text: str, + mode: str = 'w', + encoding: str = 'utf8', + errors: str = 'strict', + ) -> None: """Write text to the file. Args: @@ -456,8 +449,7 @@ def write_text(self, text, mode='w', encoding='utf8', errors='strict'): with self.open_text(mode, encoding=encoding, errors=errors) as fp: fp.write(text) - def write_bytes(self, data, mode='wb'): - # type: (bytes, Text) -> None + def write_bytes(self, data: bytes, mode: str = 'wb') -> None: """Write bytes to the file. Args: @@ -468,8 +460,9 @@ def write_bytes(self, data, mode='wb'): with self.open_bytes(mode) as fp: fp.write(data) - def open_text(self, mode='rt', encoding='utf8', errors='strict'): - # type: (Text, Text, Text) -> ContextManager[TextIO] + def open_text( + self, mode: str = 'rt', encoding: str = 'utf8', errors: str = 'strict' + ) -> ContextManager[TextIO]: """Return a context manager for opening the file in text mode. Args: @@ -492,8 +485,7 @@ def open_text(self, mode='rt', encoding='utf8', errors='strict'): cm = self._open(mode, encoding, errors) return cm - def open_bytes(self, mode='rb'): - # type: (Text) -> ContextManager[BinaryIO] + def open_bytes(self, mode: str = 'rb') -> ContextManager[BinaryIO]: """Return a context manager for opening the file in binary mode. Args: @@ -538,26 +530,24 @@ class _method(object): (e.g. Cls.method(self, ...)) but is still situationally useful. """ - def __init__(self, finstancemethod): - # type: (Callable[..., Any]) -> None + def __init__(self, finstancemethod: Callable[..., Any]) -> None: self._finstancemethod = finstancemethod self._fclassmethod = None - def classmethod(self, fclassmethod): - # type: (Callable[..., Any]) -> _method + def classmethod(self, fclassmethod: Callable[..., Any]) -> '_method': self._fclassmethod = classmethod(fclassmethod) return self - def __doc__(self): - # type: () -> str + def __doc__(self) -> str: if getattr(self._finstancemethod, '__doc__'): return self._finstancemethod.__doc__ elif getattr(self._fclassmethod, '__doc__'): return self._fclassmethod.__doc__ return '' - def __get__(self, obj, type_): - # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any] + def __get__( + self, obj: Optional[Any], type_: Optional[Type[Any]] + ) -> Callable[..., Any]: func = self._fclassmethod if obj is None else self._finstancemethod return func.__get__(obj, type_) # pytype: disable=attribute-error @@ -571,9 +561,7 @@ class TestCase(unittest.TestCase): # tempfile module. This can be overridden at the class level, instance level, # or with the `cleanup` arg of `create_tempfile()` and `create_tempdir()`. See # `TempFileCleanup` for details on the different values. - # TODO(b/70517332): Remove the type comment and the disable once pytype has - # better support for enums. - tempfile_cleanup = TempFileCleanup.ALWAYS # type: TempFileCleanup # pytype: disable=annotation-type-mismatch + tempfile_cleanup: TempFileCleanup = TempFileCleanup.ALWAYS maxDiff = 80 * 20 longMessage = True @@ -586,7 +574,7 @@ class TestCase(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestCase, self).__init__(*args, **kwargs) # This is to work around missing type stubs in unittest.pyi - self._outcome = getattr(self, '_outcome') # type: Optional[_OutcomeType] + self._outcome: Optional[_OutcomeType] = getattr(self, '_outcome') def setUp(self): super(TestCase, self).setUp() @@ -609,8 +597,11 @@ def setUpClass(cls): cls._cls_exit_stack = contextlib.ExitStack() cls.addClassCleanup(cls._cls_exit_stack.close) - def create_tempdir(self, name=None, cleanup=None): - # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir + def create_tempdir( + self, + name: Optional[str] = None, + cleanup: Optional[TempFileCleanup] = None, + ) -> _TempDir: """Create a temporary directory specific to the test. NOTE: The directory and its contents will be recursively cleared before @@ -663,11 +654,15 @@ def test_foo(self): return _TempDir(path) - # pylint: disable=line-too-long - def create_tempfile(self, file_path=None, content=None, mode='w', - encoding='utf8', errors='strict', cleanup=None): - # type: (Optional[Text], Optional[AnyStr], Text, Text, Text, Optional[TempFileCleanup]) -> _TempFile - # pylint: enable=line-too-long + def create_tempfile( + self, + file_path: Optional[str] = None, + content: Optional[AnyStr] = None, + mode: str = 'w', + encoding: str = 'utf8', + errors: str = 'strict', + cleanup: Optional[TempFileCleanup] = None, + ) -> _TempFile: """Create a temporary file specific to the test. This creates a named file on disk that is isolated to this test, and will @@ -720,8 +715,7 @@ def test_foo(self): return tf @_method - def enter_context(self, manager): - # type: (ContextManager[_T]) -> _T + def enter_context(self, manager: ContextManager[_T]) -> _T: """Returns the CM's value after registering it with the exit stack. Entering a context pushes it onto a stack of contexts. When `enter_context` @@ -754,8 +748,7 @@ def enter_context(self, manager): return self._exit_stack.enter_context(manager) @enter_context.classmethod - def enter_context(cls, manager): # pylint: disable=no-self-argument - # type: (ContextManager[_T]) -> _T + def enter_context(cls, manager: ContextManager[_T]) -> _T: # pylint: disable=no-self-argument if sys.version_info >= (3, 11): return cls.enterClassContext(manager) @@ -766,23 +759,23 @@ def enter_context(cls, manager): # pylint: disable=no-self-argument return cls._cls_exit_stack.enter_context(manager) @classmethod - def _get_tempdir_path_cls(cls): - # type: () -> Text + def _get_tempdir_path_cls(cls) -> str: return os.path.join(TEST_TMPDIR.value, cls.__qualname__.replace('__main__.', '')) - def _get_tempdir_path_test(self): - # type: () -> Text + def _get_tempdir_path_test(self) -> str: return os.path.join(self._get_tempdir_path_cls(), self._testMethodName) - def _get_tempfile_cleanup(self, override): - # type: (Optional[TempFileCleanup]) -> TempFileCleanup + def _get_tempfile_cleanup( + self, override: Optional[TempFileCleanup] + ) -> TempFileCleanup: if override is not None: return override return self.tempfile_cleanup - def _maybe_add_temp_path_cleanup(self, path, cleanup): - # type: (Text, Optional[TempFileCleanup]) -> None + def _maybe_add_temp_path_cleanup( + self, path: str, cleanup: Optional[TempFileCleanup] + ) -> None: cleanup = self._get_tempfile_cleanup(cleanup) if cleanup == TempFileCleanup.OFF: return @@ -835,8 +828,7 @@ def _internal_ran_and_passed_when_called_during_cleanup( self._feedErrorsToResult(result, outcome.errors) # pytype: disable=attribute-error return result.wasSuccessful() - def shortDescription(self): - # type: () -> Text + def shortDescription(self) -> str: """Formats both the test method name and the first line of its docstring. If no docstring is given, only returns the method name. @@ -1871,8 +1863,9 @@ def assertJsonEqual(self, first, second, msg=None): self.assertSameStructure(first_structured, second_structured, aname='first', bname='second', msg=msg) - def _getAssertEqualityFunc(self, first, second): - # type: (Any, Any) -> Callable[..., None] + def _getAssertEqualityFunc( + self, first: Any, second: Any + ) -> Callable[..., None]: try: return super(TestCase, self)._getAssertEqualityFunc(first, second) except AttributeError: @@ -1890,8 +1883,9 @@ def fail(self, msg=None, user_msg=None) -> NoReturn: return super(TestCase, self).fail(self._formatMessage(user_msg, msg)) -def _sorted_list_difference(expected, actual): - # type: (List[_T], List[_T]) -> Tuple[List[_T], List[_T]] +def _sorted_list_difference( + expected: List[_T], actual: List[_T] +) -> Tuple[List[_T], List[_T]]: """Finds elements in only one or the other of two, sorted input lists. Returns a two-element tuple of lists. The first list contains those @@ -1940,25 +1934,21 @@ def _sorted_list_difference(expected, actual): return missing, unexpected -def _are_both_of_integer_type(a, b): - # type: (object, object) -> bool +def _are_both_of_integer_type(a: object, b: object) -> bool: return isinstance(a, int) and isinstance(b, int) -def _are_both_of_sequence_type(a, b): - # type: (object, object) -> bool +def _are_both_of_sequence_type(a: object, b: object) -> bool: return isinstance(a, abc.Sequence) and isinstance( b, abc.Sequence) and not isinstance( a, _TEXT_OR_BINARY_TYPES) and not isinstance(b, _TEXT_OR_BINARY_TYPES) -def _are_both_of_set_type(a, b): - # type: (object, object) -> bool +def _are_both_of_set_type(a: object, b: object) -> bool: return isinstance(a, abc.Set) and isinstance(b, abc.Set) -def _are_both_of_mapping_type(a, b): - # type: (object, object) -> bool +def _are_both_of_mapping_type(a: object, b: object) -> bool: return isinstance(a, abc.Mapping) and isinstance( b, abc.Mapping) @@ -2082,8 +2072,7 @@ def get_command_stderr(command, env=None, close_fds=True): return (exit_status, output) -def _quote_long_string(s): - # type: (Union[Text, bytes, bytearray]) -> Text +def _quote_long_string(s: Union[str, bytes, bytearray]) -> str: """Quotes a potentially multi-line string to make the start and end obvious. Args: @@ -2102,8 +2091,7 @@ def _quote_long_string(s): '----------->8\n') -def print_python_version(): - # type: () -> None +def print_python_version() -> None: # Having this in the test output logs by default helps debugging when all # you've got is the log and no other idea of which Python was used. sys.stderr.write('Running tests under Python {0[0]}.{0[1]}.{0[2]}: ' @@ -2112,8 +2100,7 @@ def print_python_version(): sys.executable if sys.executable else 'embedded.')) -def main(*args, **kwargs): - # type: (Text, Any) -> None +def main(*args: str, **kwargs: Any) -> None: """Executes a set of Python unit tests. Usually this function is called without arguments, so the @@ -2131,8 +2118,7 @@ def main(*args, **kwargs): _run_in_app(run_tests, args, kwargs) -def _is_in_app_main(): - # type: () -> bool +def _is_in_app_main() -> bool: """Returns True iff app.run is active.""" f = sys._getframe().f_back # pylint: disable=protected-access while f: @@ -2142,8 +2128,7 @@ def _is_in_app_main(): return False -def _register_sigterm_with_faulthandler(): - # type: () -> None +def _register_sigterm_with_faulthandler() -> None: """Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts.""" if getattr(faulthandler, 'register', None): # faulthandler.register is not available on Windows. @@ -2155,8 +2140,11 @@ def _register_sigterm_with_faulthandler(): '%r; ignoring.\n' % e) -def _run_in_app(function, args, kwargs): - # type: (Callable[..., None], Sequence[Text], Mapping[Text, Any]) -> None +def _run_in_app( + function: Callable[..., None], + args: Sequence[str], + kwargs: Mapping[str, Any], +) -> None: """Executes a set of Python unit tests, ensuring app.run. This is a private function, users should call absltest.main(). @@ -2236,8 +2224,9 @@ def main_function(argv): app.run(main=main_function) -def _is_suspicious_attribute(testCaseClass, name): - # type: (Type, Text) -> bool +def _is_suspicious_attribute( + testCaseClass: Type[unittest.TestCase], name: str +) -> bool: """Returns True if an attribute is a method named like a test method.""" if name.startswith('Test') and len(name) > 4 and name[4].isupper(): attr = getattr(testCaseClass, name) @@ -2249,8 +2238,7 @@ def _is_suspicious_attribute(testCaseClass, name): return False -def skipThisClass(reason): - # type: (Text) -> Callable[[_T], _T] +def skipThisClass(reason: str) -> Callable[[_T], _T]: """Skip tests in the decorated TestCase, but not any of its subclasses. This decorator indicates that this class should skip all its tests, but not @@ -2366,8 +2354,7 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name return names -def get_default_xml_output_filename(): - # type: () -> Optional[Text] +def get_default_xml_output_filename() -> Optional[str]: if os.environ.get('XML_OUTPUT_FILE'): return os.environ['XML_OUTPUT_FILE'] elif os.environ.get('RUNNING_UNDER_TEST_DAEMON'): @@ -2409,8 +2396,7 @@ def _setup_filtering(argv: MutableSequence[str]) -> bool: return True -def _setup_test_runner_fail_fast(argv): - # type: (MutableSequence[Text]) -> None +def _setup_test_runner_fail_fast(argv: MutableSequence[str]) -> None: """Implements the bazel test fail fast protocol. The following environment variable is used in this method: @@ -2696,8 +2682,7 @@ def run_tests( sys.exit(not result.wasSuccessful()) -def _rmtree_ignore_errors(path): - # type: (Text) -> None +def _rmtree_ignore_errors(path: str) -> None: if os.path.isfile(path): try: os.unlink(path) @@ -2707,7 +2692,6 @@ def _rmtree_ignore_errors(path): shutil.rmtree(path, ignore_errors=True) -def _get_first_part(path): - # type: (Text) -> Text +def _get_first_part(path: str) -> str: parts = path.split(os.sep, 1) return parts[0]