Skip to content

Commit

Permalink
Allow overloaded __exit__ and __aexit__ definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Apr 24, 2024
1 parent 1c8849f commit 1f7d294
Show file tree
Hide file tree
Showing 6 changed files with 481 additions and 50 deletions.
96 changes: 95 additions & 1 deletion crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
from collections.abc import Awaitable
from types import TracebackType
from typing import Any, Type
from typing import Any, Type, overload

import _typeshed
import typing_extensions
Expand Down Expand Up @@ -73,3 +73,97 @@ async def __aexit__(self, /, typ: type[BaseException] | None, *args: Any) -> Awa
class BadSix:
def __exit__(self, typ, exc, tb, weird_extra_arg, extra_arg2 = None) -> None: ... # PYI036: Extra arg must have default
async def __aexit__(self, typ, exc, tb, *, weird_extra_arg) -> None: ... # PYI036: kwargs must have default

class AllPositionalOnlyArgs:
def __exit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...

class BadAllPositionalOnlyArgs:
def __exit__(self, typ: type[Exception] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType, /) -> None: ...

# Definitions not in a class scope can do whatever, we don't care
def __exit__(self, *args: bool) -> None: ...
async def __aexit__(self, *, go_crazy: bytes) -> list[str]: ...

# Here come the overloads...

class AcceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ...
@overload
def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ...
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...

# Using `object` or `Unused` in an overload definition is kinda strange,
# but let's allow it to be on the safe side
class AcceptableOverload2:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ...
@overload
def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ...
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ...

class AcceptableOverload3:
# Just ignore any overloads that don't have exactly 3 annotated non-self parameters.
# We don't have the ability (yet) to do arbitrary checking
# of whether one function definition is a subtype of another...
@overload
def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ...
@overload
def __exit__(self, *args: object) -> None: ...
def __exit__(self, *args: object) -> None: ...
@overload
async def __aexit__(self, exc_typ: bool, /, exc: bool, exc_tb: bool, *, keyword_only: str) -> None: ...
@overload
async def __aexit__(self, *args: object) -> None: ...
async def __aexit__(self, *args: object) -> None: ...

class AcceptableOverload4:
# Same as above
@overload
def __exit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType) -> None: ...
@overload
def __exit__(self, *args: object) -> None: ...
def __exit__(self, *args: object) -> None: ...
@overload
async def __aexit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType, *, extra: str = "foo") -> None: ...
@overload
async def __aexit__(self, exc_typ: None, exc: None, tb: None) -> None: ...
async def __aexit__(self, *args: object) -> None: ...

class StrangeNumberOfOverloads:
# Only one overload? Type checkers will emit an error, but we should just ignore it
@overload
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...
# More than two overloads? Anything could be going on; again, just ignore all the overloads
@overload
async def __aexit__(self, arg: bool) -> None: ...
@overload
async def __aexit__(self, arg: None, arg2: None, arg3: None) -> None: ...
@overload
async def __aexit__(self, arg: bool, arg2: bool, arg3: bool) -> None: ...
async def __aexit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...

# TODO: maybe we should emit an error on this one as well?
class BizarreAsyncSyncOverloadMismatch:
@overload
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
@overload
async def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
def __exit__(self, *args: object) -> None: ...

class UnacceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay
@overload
def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...

class UnacceptableOverload2:
@overload
def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036
@overload
def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036
def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...
85 changes: 84 additions & 1 deletion crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import types
import typing
from collections.abc import Awaitable
from types import TracebackType
from typing import Any, Type
from typing import Any, Type, overload

import _typeshed
import typing_extensions
Expand Down Expand Up @@ -80,3 +80,86 @@ def isolated_scope():

class ShouldNotError:
def __exit__(self, typ: Type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ...

class AllPositionalOnlyArgs:
def __exit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...

class BadAllPositionalOnlyArgs:
def __exit__(self, typ: type[Exception] | None, exc: BaseException | None, tb: TracebackType | None, /) -> None: ...
async def __aexit__(self, typ: type[BaseException] | None, exc: BaseException | None, tb: TracebackType, /) -> None: ...

# Definitions not in a class scope can do whatever, we don't care
def __exit__(self, *args: bool) -> None: ...
async def __aexit__(self, *, go_crazy: bytes) -> list[str]: ...

# Here come the overloads...

class AcceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ...
@overload
def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ...

# Using `object` or `Unused` in an overload definition is kinda strange,
# but let's allow it to be on the safe side
class AcceptableOverload2:
@overload
def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ...
@overload
def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ...

class AcceptableOverload3:
# Just ignore any overloads that don't have exactly 3 annotated non-self parameters.
# We don't have the ability (yet) to do arbitrary checking
# of whether one function definition is a subtype of another...
@overload
def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ...
@overload
def __exit__(self, *args: object) -> None: ...
@overload
async def __aexit__(self, exc_typ: bool, /, exc: bool, exc_tb: bool, *, keyword_only: str) -> None: ...
@overload
async def __aexit__(self, *args: object) -> None: ...

class AcceptableOverload4:
# Same as above
@overload
def __exit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType) -> None: ...
@overload
def __exit__(self, *args: object) -> None: ...
@overload
async def __aexit__(self, exc_typ: type[Exception], exc: type[Exception], exc_tb: types.TracebackType, *, extra: str = "foo") -> None: ...
@overload
async def __aexit__(self, exc_typ: None, exc: None, tb: None) -> None: ...

class StrangeNumberOfOverloads:
# Only one overload? Type checkers will emit an error, but we should just ignore it
@overload
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
# More than two overloads? Anything could be going on; again, just ignore all the overloads
@overload
async def __aexit__(self, arg: bool) -> None: ...
@overload
async def __aexit__(self, arg: None, arg2: None, arg3: None) -> None: ...
@overload
async def __aexit__(self, arg: bool, arg2: bool, arg3: bool) -> None: ...

# TODO: maybe we should emit an error on this one as well?
class BizarreAsyncSyncOverloadMismatch:
@overload
def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...
@overload
async def __exit__(self, exc_typ: bool, exc: bool, tb: bool) -> None: ...

class UnacceptableOverload1:
@overload
def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay
@overload
def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036

class UnacceptableOverload2:
@overload
def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036
@overload
def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036
2 changes: 1 addition & 1 deletion crates/ruff_linter/src/checkers/ast/analyze/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) {
}
}
if checker.enabled(Rule::BadExitAnnotation) {
flake8_pyi::rules::bad_exit_annotation(checker, *is_async, name, parameters);
flake8_pyi::rules::bad_exit_annotation(checker, function_def);
}
if checker.enabled(Rule::RedundantNumericUnion) {
flake8_pyi::rules::redundant_numeric_union(checker, parameters);
Expand Down
Loading

0 comments on commit 1f7d294

Please sign in to comment.