Skip to content

Commit

Permalink
Added support for protocols
Browse files Browse the repository at this point in the history
Fixes #83.
  • Loading branch information
agronholm committed Nov 26, 2019
1 parent f5517a8 commit a00f03e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ Type Notes
``List`` Contents are typechecked
``Literal``
``NamedTuple`` Field values are typechecked
``Protocol`` Value type checked with ``issubclass()`` against the
protocol
``Set`` Contents are typechecked
``Sequence`` Contents are typechecked
``Tuple`` Contents are typechecked
Expand Down
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v

**UNRELEASED**

- Added support for ``typing.Protocol`` subclasses
- Fixed the handling of ``total=False`` in ``TypedDict``
- Removed support of default values in ``TypedDict``, as they are not supported in the spec

Expand Down
17 changes: 16 additions & 1 deletion tests/test_typeguard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from io import StringIO, BytesIO
from typing import (
Any, Callable, Dict, List, Set, Tuple, Union, TypeVar, Sequence, NamedTuple, Iterable,
Container, Generic, BinaryIO, TextIO, Generator, Iterator)
Container, Generic, BinaryIO, TextIO, Generator, Iterator, SupportsInt)

import pytest

Expand Down Expand Up @@ -1068,6 +1068,21 @@ def func(x: annotation) -> None:
func(Child())
pytest.raises(TypeError, func, 'foo')

@pytest.mark.parametrize('value, error_re', [
(1, None),
('foo',
r'type of argument "arg" \(str\) is not compatible with the SupportsInt protocol')
], ids=['int', 'str'])
def test_protocol(self, value, error_re):
@typechecked
def foo(arg: SupportsInt):
pass

if error_re:
pytest.raises(TypeError, foo, value).match(error_re)
else:
foo(value)


class TestTypeChecker:
@pytest.fixture
Expand Down
15 changes: 14 additions & 1 deletion typeguard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import OrderedDict
from enum import Enum
from functools import wraps, partial
from inspect import Parameter, isclass, isfunction, isgeneratorfunction
from inspect import Parameter, isclass, isfunction, isgeneratorfunction, getattr_static
from io import TextIOBase, RawIOBase, IOBase, BufferedIOBase
from traceback import extract_stack, print_stack
from types import CodeType, FunctionType
Expand All @@ -25,6 +25,11 @@
except ImportError:
Literal = TypedDict = None

try:
from typing import Protocol
except ImportError:
from typing import _Protocol as Protocol

try:
from typing import AsyncGenerator
except ImportError:
Expand Down Expand Up @@ -481,6 +486,12 @@ def check_io(argname: str, value, expected_type):
format(argname, qualified_name(value.__class__)))


def check_protocol(argname: str, value, expected_type):
if not issubclass(type(value), expected_type):
raise TypeError('type of {} ({}) is not compatible with the {} protocol'.
format(argname, type(value).__qualname__, expected_type.__qualname__))


# Equality checks are applied to these
origin_type_checkers = {
Callable: check_callable,
Expand Down Expand Up @@ -558,6 +569,8 @@ def check_type(argname: str, value, expected_type, memo: Optional[_CallMemo] = N
check_io(argname, value, expected_type)
elif issubclass(expected_type, dict) and hasattr(expected_type, '__annotations__'):
check_typed_dict(argname, value, expected_type, memo)
elif getattr_static(expected_type, '_is_protocol', False):
check_protocol(argname, value, expected_type)
else:
expected_type = (getattr(expected_type, '__extra__', None) or origin_type or
expected_type)
Expand Down

0 comments on commit a00f03e

Please sign in to comment.