Skip to content

Commit

Permalink
Fixed checks for assignments to varargs or varkwargs variables
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Mar 16, 2023
1 parent c333a35 commit 835fb65
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v
**UNRELEASED**

- Fixed ``warn_on_error()`` not showing where the type violation actually occurred
- Fixed local assignment to ``*args`` or ``**kwargs`` being type checked incorrectly

**3.0.1** (2023-03-16)

Expand Down
34 changes: 30 additions & 4 deletions src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,16 +426,42 @@ def visit_FunctionDef(
if sys.version_info >= (3, 8):
all_args.extend(node.args.posonlyargs)

for arg in node.args.vararg, node.args.kwarg:
if arg is not None:
all_args.append(arg)

arg_annotations = {
arg.arg: arg.annotation
for arg in all_args
if arg.annotation is not None
and not self._memo.name_matches(arg.annotation, *anytype_names)
}
if node.args.vararg:
if sys.version_info >= (3, 9):
container = Name("tuple", ctx=Load())
else:
container = self._get_import("typing", "Tuple")

annotation = Subscript(
container,
Tuple(
[node.args.vararg.annotation, Constant(Ellipsis)],
ctx=Load(),
),
)
arg_annotations[node.args.vararg.arg] = annotation

if node.args.kwarg:
if sys.version_info >= (3, 9):
container = Name("dict", ctx=Load())
else:
container = self._get_import("typing", "Dict")

annotation = Subscript(
container,
Tuple(
[Name("str", ctx=Load()), node.args.kwarg.annotation],
ctx=Load(),
),
)
arg_annotations[node.args.kwarg.arg] = annotation

if arg_annotations:
self._memo.variable_annotations.update(arg_annotations)

Expand Down
70 changes: 70 additions & 0 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,76 @@ def foo() -> None:
).strip()
)

def test_varargs_assign(self) -> None:
node = parse(
dedent(
"""
def foo(*args: int) -> None:
args = (5,)
"""
)
)
TypeguardTransformer().visit(node)

if sys.version_info < (3, 9):
extra_import = "from typing import Tuple\n"
tuple_type = "Tuple"
else:
extra_import = ""
tuple_type = "tuple"

assert (
unparse(node)
== dedent(
f"""
from typeguard import CallMemo
from typeguard._functions import check_argument_types, \
check_variable_assignment
{extra_import}
def foo(*args: int) -> None:
call_memo = CallMemo(foo, locals())
check_argument_types(call_memo)
args = check_variable_assignment((5,), \
{{'args': {tuple_type}[int, ...]}}, call_memo)
"""
).strip()
)

def test_kwargs_assign(self) -> None:
node = parse(
dedent(
"""
def foo(**kwargs: int) -> None:
kwargs = {'a': 5}
"""
)
)
TypeguardTransformer().visit(node)

if sys.version_info < (3, 9):
extra_import = "from typing import Dict\n"
dict_type = "Dict"
else:
extra_import = ""
dict_type = "dict"

assert (
unparse(node)
== dedent(
f"""
from typeguard import CallMemo
from typeguard._functions import check_argument_types, \
check_variable_assignment
{extra_import}
def foo(**kwargs: int) -> None:
call_memo = CallMemo(foo, locals())
check_argument_types(call_memo)
kwargs = check_variable_assignment({{'a': 5}}, \
{{'kwargs': {dict_type}[str, int]}}, call_memo)
"""
).strip()
)

@pytest.mark.skipif(sys.version_info >= (3, 10), reason="Requires Python < 3.10")
def test_pep604_assign(self) -> None:
node = parse(
Expand Down

0 comments on commit 835fb65

Please sign in to comment.