Skip to content

Commit

Permalink
Support backport Literal from typing_extensions (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
rfrowe committed May 28, 2020
1 parent 87a5f22 commit 6b8c22c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
10 changes: 10 additions & 0 deletions tests/test_typeguard_py36.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from typeguard import TypeChecker, typechecked
from typing_extensions import Literal


class TestTypeChecked:
Expand Down Expand Up @@ -106,3 +107,12 @@ def test_async_generator(self, checker, func):
func()

assert len(record) == 0


def test_literal():
@typechecked
def foo(a: Literal[1, 6, 8]):
pass

foo(6)
pytest.raises(TypeError, foo, 4).match(r'must be one of \(1, 6, 8\); got 4 instead$')
20 changes: 18 additions & 2 deletions typeguard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
except ImportError:
Literal = TypedDict = None

try:
from typing_extensions import Literal as BPLiteral
except ImportError:
BPLiteral = None

try:
from typing import AsyncGenerator
except ImportError:
Expand Down Expand Up @@ -465,9 +470,15 @@ def check_typevar(argname: str, value, typevar: TypeVar, memo: Optional[_CallMem


def check_literal(argname: str, value, expected_type, memo: Optional[_CallMemo]):
if value not in expected_type.__args__:
try:
args = expected_type.__args__
except AttributeError:
# Instance of Literal from typing_extensions
args = expected_type.__values__

if value not in args:
raise TypeError('the value of {} must be one of {}; got {} instead'.
format(argname, expected_type.__args__, value))
format(argname, args, value))


def check_number(argname: str, value, expected_type):
Expand Down Expand Up @@ -523,6 +534,8 @@ def check_protocol(argname: str, value, expected_type):
origin_type_checkers[Type] = check_class
if Literal is not None:
origin_type_checkers[Literal] = check_literal
if BPLiteral is not None:
origin_type_checkers[BPLiteral] = check_literal

generator_origin_types = (Generator, collections.abc.Generator,
Iterator, collections.abc.Iterator,
Expand Down Expand Up @@ -595,6 +608,9 @@ def check_type(argname: str, value, expected_type, memo: Optional[_CallMemo] = N
elif isinstance(expected_type, TypeVar):
# Only happens on < 3.6
check_typevar(argname, value, expected_type, memo)
elif BPLiteral is not None and isinstance(expected_type, BPLiteral.__class__):
# Only happens on < 3.7 when using Literal from typing_extensions
check_literal(argname, value, expected_type, memo)
elif (isfunction(expected_type) and
getattr(expected_type, "__module__", None) == "typing" and
getattr(expected_type, "__qualname__", None).startswith("NewType.") and
Expand Down

0 comments on commit 6b8c22c

Please sign in to comment.