Skip to content

Commit

Permalink
Fix overload positional args and infer from defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Jul 13, 2024
1 parent 661599d commit 37e1608
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 87 deletions.
8 changes: 4 additions & 4 deletions .mypy/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -10467,7 +10467,7 @@
"code": "redundant-expr",
"column": 18,
"message": "Condition is always true",
"offset": 105,
"offset": 71,
"src": "while True:",
"target": "mypy.ipc.IPCBase.read"
}
Expand Down Expand Up @@ -17097,7 +17097,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"type\" is not using @override but is overriding a method in class \"mypy.semanal_shared.SemanticAnalyzerCoreInterface\"",
"offset": 477,
"offset": 480,
"src": "def type(self) -> TypeInfo | None:",
"target": "mypy.semanal"
},
Expand Down Expand Up @@ -18297,7 +18297,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_any\" is not using @override but is overriding a method in class \"mypy.type_visitor.TypeTranslator\"",
"offset": 117,
"offset": 215,
"src": "def visit_any(self, t: AnyType) -> Type:",
"target": "mypy.semanal.MakeAnyNonExplicit.visit_any"
},
Expand Down Expand Up @@ -32283,7 +32283,7 @@
"code": "redundant-expr",
"column": 19,
"message": "Condition is always true",
"offset": 498,
"offset": 495,
"src": "if extra_attrs_set is None:",
"target": "mypy.typeops.make_simplified_union"
},
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
- `work_not_properly_function_names` made available to per module configuration (#699)
- Support `BASEDMYPY_TYPE_CHECKING` (#702)
- Enable stub mode within `TYPE_CHECKING` branches (#702)
- Infer from overloads - add default value in impl (#697)
### Fixes
- positional arguments on overloads break super (#697)
- positional arguments on overloads duplicate unions (#697)

## [2.5.0]
### Added
Expand Down
105 changes: 103 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@

from __future__ import annotations

from collections import defaultdict
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Collection, Final, Iterable, Iterator, List, TypeVar, cast
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard

import mypy.state
from mypy import errorcodes as codes, message_registry
from mypy.constant_fold import constant_fold_expr
from mypy.errorcodes import ErrorCode
Expand Down Expand Up @@ -221,6 +223,7 @@
set_callable_name as set_callable_name,
)
from mypy.semanal_typeddict import TypedDictAnalyzer
from mypy.subtypes import is_subtype
from mypy.tvar_scope import TypeVarLikeScope
from mypy.typeanal import (
SELF_TYPE_NAMES,
Expand All @@ -242,7 +245,6 @@
callable_type,
function_type,
get_type_vars,
infer_impl_from_parts,
try_getting_str_literals_from_type,
)
from mypy.types import (
Expand Down Expand Up @@ -283,6 +285,7 @@
TypeVarTupleType,
TypeVarType,
UnboundType,
UnionType,
UnpackType,
UntypedType,
get_proper_type,
Expand Down Expand Up @@ -1333,7 +1336,7 @@ def analyze_overload_sigs_and_impl(
else:
non_overload_indexes.append(i)
if self.options.infer_function_types and impl and not non_overload_indexes:
infer_impl_from_parts(
self.infer_impl_from_parts(
impl, types, self.named_type("builtins.function"), self.named_type
)
return types, impl, non_overload_indexes
Expand Down Expand Up @@ -7162,6 +7165,104 @@ def parse_dataclass_transform_field_specifiers(self, arg: Expression) -> tuple[s
names.append(specifier.fullname)
return tuple(names)

def infer_impl_from_parts(
self,
impl: OverloadPart,
types: list[CallableType],
fallback: Instance,
named_type: Callable[[str, list[Type]], Type],
):
impl_func = impl if isinstance(impl, FuncDef) else impl.func
# infer the types of the impl from the overload types
arg_types: dict[str | int, dict[Type, None]] = defaultdict(dict)
ret_types: dict[Type, None] = {}
for tp in types:
for i, arg_type in enumerate(tp.arg_types):
arg_name = tp.arg_names[i]
if not arg_name: # if it's positional only
arg_types[i][arg_type] = None
else:
if arg_name in impl_func.arg_names:
if arg_type not in arg_types[arg_name]:
arg_types[arg_name][arg_type] = None
if arg_name and arg_name in impl_func.arg_names:
if arg_type not in arg_types[arg_name]:
arg_types[arg_name][arg_type] = None
t = get_proper_type(tp.ret_type)
if isinstance(t, Instance) and t.type.fullname == "typing.Coroutine":
ret_type = t.args[2]
else:
ret_type = tp.ret_type
ret_types[ret_type] = None

res_arg_types = [
(
UnionType.make_union(
tuple({**(arg_types[arg_name_] if arg_name_ else {}), **arg_types[i]})
)
if arg_kind not in (ARG_STAR, ARG_STAR2)
else UntypedType()
)
for i, (arg_name_, arg_kind) in enumerate(
zip(impl_func.arg_names, impl_func.arg_kinds)
)
]

if isinstance(impl, Decorator):
impl = impl.func
for i, arg in enumerate(impl.arguments):
init = arg.initializer
if not arg.initializer:
continue
typ: Type | None
if isinstance(init, NameExpr) and init.fullname == "builtins.None":
typ = NoneType()
else:
typ = self.analyze_simple_literal_type(arg.initializer, True, do_inner=True)
if not typ:
continue
with mypy.state.state.strict_optional_set(self.options.strict_optional):
if not is_subtype(typ, res_arg_types[i], options=self.options):
res_arg_types[i] = UnionType.make_union((res_arg_types[i], typ))

ret_type = UnionType.make_union(tuple(ret_types))

if impl_func.is_coroutine:
# if the impl is a coroutine, then assume the parts are also, if not need annotation
any_type = AnyType(TypeOfAny.special_form)
ret_type = named_type("typing.Coroutine", [any_type, any_type, ret_type])

# use unanalyzed_type because we would have already tried to infer from defaults
if impl_func.unanalyzed_type:
assert isinstance(impl_func.unanalyzed_type, CallableType)
assert isinstance(impl_func.type, CallableType)
impl_func.type = impl_func.type.copy_modified(
arg_types=[
i if not is_unannotated_any(u) else r
for i, u, r in zip(
impl_func.type.arg_types,
impl_func.unanalyzed_type.arg_types,
res_arg_types,
)
],
ret_type=(
ret_type
if isinstance(
get_proper_type(impl_func.unanalyzed_type.ret_type), (AnyType, NoneType)
)
else impl_func.type.ret_type
),
)
else:
impl_func.type = CallableType(
res_arg_types,
impl_func.arg_kinds,
impl_func.arg_names,
ret_type,
fallback,
definition=impl_func,
)


def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
Expand Down
81 changes: 1 addition & 80 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from __future__ import annotations

import itertools
from collections import defaultdict
from typing import Any, Callable, Iterable, List, Sequence, TypeVar, cast
from typing import Any, Iterable, List, Sequence, TypeVar, cast

from mypy.copytype import copy_type
from mypy.expandtype import expand_type, expand_type_by_instance
Expand All @@ -25,7 +24,6 @@
FuncDef,
FuncItem,
OverloadedFuncDef,
OverloadPart,
StrExpr,
TypeInfo,
Var,
Expand Down Expand Up @@ -65,7 +63,6 @@
flatten_nested_unions,
get_proper_type,
get_proper_types,
is_unannotated_any,
)
from mypy.typevars import fill_typevars

Expand Down Expand Up @@ -1126,82 +1123,6 @@ def separate_union_literals(t: UnionType) -> tuple[Sequence[LiteralType], Sequen
return literal_items, union_items


def infer_impl_from_parts(
impl: OverloadPart,
types: list[CallableType],
fallback: Instance,
named_type: Callable[[str, list[Type]], Type],
):
impl_func = impl if isinstance(impl, FuncDef) else impl.func
# infer the types of the impl from the overload types
arg_types: dict[str | int, list[Type]] = defaultdict(list)
ret_types = []
for tp in types:
for i, arg_type in enumerate(tp.arg_types):
arg_name = tp.arg_names[i]
if not arg_name: # if it's positional only
if arg_type not in arg_types[i]:
arg_types[i].append(arg_type)
else:
if arg_name in impl_func.arg_names:
if arg_type not in arg_types[arg_name]:
arg_types[arg_name].append(arg_type)
if arg_name and arg_name in impl_func.arg_names:
if arg_type not in arg_types[arg_name]:
arg_types[arg_name].append(arg_type)
t = get_proper_type(tp.ret_type)
if isinstance(t, Instance) and t.type.fullname == "typing.Coroutine":
ret_type = t.args[2]
else:
ret_type = tp.ret_type
if ret_type not in ret_types:
ret_types.append(ret_type)
res_arg_types = [
(
UnionType.make_union((arg_types[arg_name_] if arg_name_ else []) + arg_types[i])
if arg_kind not in (ARG_STAR, ARG_STAR2)
else UntypedType()
)
for i, (arg_name_, arg_kind) in enumerate(zip(impl_func.arg_names, impl_func.arg_kinds))
]

ret_type = UnionType.make_union(ret_types)

if impl_func.is_coroutine:
# if the impl is a coroutine, then assume the parts are also, if not need annotation
any_type = AnyType(TypeOfAny.special_form)
ret_type = named_type("typing.Coroutine", [any_type, any_type, ret_type])

# use unanalyzed_type because we would have already tried to infer from defaults
if impl_func.unanalyzed_type:
assert isinstance(impl_func.unanalyzed_type, CallableType)
assert isinstance(impl_func.type, CallableType)
impl_func.type = impl_func.type.copy_modified(
arg_types=[
i if not is_unannotated_any(u) else r
for i, u, r in zip(
impl_func.type.arg_types, impl_func.unanalyzed_type.arg_types, res_arg_types
)
],
ret_type=(
ret_type
if isinstance(
get_proper_type(impl_func.unanalyzed_type.ret_type), (AnyType, NoneType)
)
else impl_func.type.ret_type
),
)
else:
impl_func.type = CallableType(
res_arg_types,
impl_func.arg_kinds,
impl_func.arg_names,
ret_type,
fallback,
definition=impl_func,
)


def try_getting_instance_fallback(typ: Type) -> Instance | None:
"""Returns the Instance fallback for this type if one exists or None."""
typ = get_proper_type(typ)
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-based-infer-function-types.test
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def f(a, b):
reveal_type(a)
reveal_type(b)
[out]
main:9: note: Revealed type is "int | int"
main:9: note: Revealed type is "int"
main:10: note: Revealed type is "str"


Expand Down
25 changes: 25 additions & 0 deletions test-data/unit/check-based-overload.test
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,28 @@ o: object
assert isinstance(o, str)
f(lambda _: reveal_type(o)) # N: Revealed type is "object"
[builtins fixtures/tuple.pyi]


[case testPositional]
from typing import overload

@overload
def f(a: int, b: int): ...
@overload
def f(a: int, c: str, /): ...
def f(a, b):
reveal_type(a) # N: Revealed type is "int"
reveal_type(b) # N: Revealed type is "int | str"


[case testInferDefault]
from typing import overload

@overload
def f(a: int, b: int, c: 1): ...
@overload
def f(a: int, b: str, c: 2): ...
def f(a="who", b=None, c=3):
reveal_type(a) # N: Revealed type is "int | 'who'"
reveal_type(b) # N: Revealed type is "int | str | None"
reveal_type(c) # N: Revealed type is "1 | 2 | 3"

0 comments on commit 37e1608

Please sign in to comment.