diff --git a/CHANGELOG.md b/CHANGELOG.md
index e9b3614c..ceddc215 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ New error codes:
* Introduce Y060, which flags redundant inheritance from `Generic[]`.
* Introduce Y061: Do not use `None` inside a `Literal[]` slice.
For example, use `Literal["foo"] | None` instead of `Literal["foo", None]`.
+* Introduce Y062: Protocol method parameters should not be positional-or-keyword.
Other changes:
* The undocumented `pyi.__version__` and `pyi.PyiTreeChecker.version`
diff --git a/ERRORCODES.md b/ERRORCODES.md
index 687541e2..5c8045aa 100644
--- a/ERRORCODES.md
+++ b/ERRORCODES.md
@@ -64,6 +64,7 @@ The following warnings are currently emitted by default:
| Y059 | `Generic[]` should always be the last base class, if it is present in a class's bases tuple. At runtime, if `Generic[]` is not the final class in a the bases tuple, this [can cause the class creation to fail](https://github.com/python/cpython/issues/106102). In a stub file, however, this rule is enforced purely for stylistic consistency.
| Y060 | Redundant inheritance from `Generic[]`. For example, `class Foo(Iterable[_T], Generic[_T]): ...` can be written more simply as `class Foo(Iterable[_T]): ...`.
To avoid false-positive errors, and to avoid complexity in the implementation, this check is deliberately conservative: it only looks at classes that have exactly two bases.
| Y061 | Do not use `None` inside a `Literal[]` slice. For example, use `Literal["foo"] \| None` instead of `Literal["foo", None]`. While both are legal according to [PEP 586](https://peps.python.org/pep-0586/), the former is preferred for stylistic consistency.
+| Y062 | Protocol methods should not have positional-or-keyword parameters. Usually, a positional-only parameter is better.
## Warnings disabled by default
diff --git a/pyi.py b/pyi.py
index 83d5620c..4710c889 100644
--- a/pyi.py
+++ b/pyi.py
@@ -3,6 +3,7 @@
import argparse
import ast
+import contextlib
import logging
import re
import sys
@@ -968,6 +969,7 @@ def __init__(self, filename: str) -> None:
self.string_literals_allowed = NestingCounter()
self.in_function = NestingCounter()
self.in_class = NestingCounter()
+ self.in_protocol = NestingCounter()
self.visiting_arg = NestingCounter()
def __repr__(self) -> str:
@@ -1592,23 +1594,32 @@ def _check_class_bases(self, bases: list[ast.expr]) -> None:
self.error(Generic_basenode, Y060)
def visit_ClassDef(self, node: ast.ClassDef) -> None:
+ class_is_protocol = class_is_typeddict = False
+ for base in node.bases:
+ if _is_Protocol(base):
+ class_is_protocol = True
+ if _is_TypedDict(base):
+ class_is_typeddict = True
+
if node.name.startswith("_") and not self.in_class.active:
- for base in node.bases:
- if _is_Protocol(base):
- self.protocol_defs[node.name].append(node)
- break
- if _is_TypedDict(base):
- self.class_based_typeddicts[node.name].append(node)
- break
+ if class_is_protocol:
+ self.protocol_defs[node.name].append(node)
+ if class_is_typeddict:
+ self.class_based_typeddicts[node.name].append(node)
old_class_node = self.current_class_node
self.current_class_node = node
- with self.in_class.enabled():
+ with contextlib.ExitStack() as stack:
+ stack.enter_context(self.in_class.enabled())
+ if class_is_protocol:
+ stack.enter_context(self.in_protocol.enabled())
self.generic_visit(node)
self.current_class_node = old_class_node
self._check_class_bases(node.bases)
+ self.check_class_pass_and_ellipsis(node)
+ def check_class_pass_and_ellipsis(self, node: ast.ClassDef) -> None:
# empty class body should contain "..." not "pass"
if len(node.body) == 1:
statement = node.body[0]
@@ -1987,6 +1998,12 @@ def check_self_typevars(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> N
return_annotation=return_annotation,
)
+ def check_arg_kinds(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
+ for pos_or_kw in node.args.args[1:]: # exclude "self"
+ if pos_or_kw.arg.startswith("__"):
+ continue
+ self.error(pos_or_kw, Y062.format(arg=pos_or_kw.arg, method=node.name))
+
def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
with self.in_function.enabled():
self.generic_visit(node)
@@ -2010,6 +2027,8 @@ def _visit_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
if self.in_class.active:
self.check_self_typevars(node)
+ if self.in_protocol.active:
+ self.check_arg_kinds(node)
def visit_arg(self, node: ast.arg) -> None:
if _is_NoReturn(node.annotation):
@@ -2233,6 +2252,10 @@ def parse_options(options: argparse.Namespace) -> None:
"class would be inferred as generic anyway"
)
Y061 = 'Y061 None inside "Literal[]" expression. Replace with "{suggestion}"'
+Y062 = (
+ 'Y062 Argument "{arg}" to protocol method "{method}" should probably not be positional-or-keyword. '
+ "Make it positional-only, since usually you don't want to mandate a specific argument name"
+)
Y090 = (
'Y090 "{original}" means '
'"a tuple of length 1, in which the sole element is of type {typ!r}". '
diff --git a/tests/protocol_arg.pyi b/tests/protocol_arg.pyi
new file mode 100644
index 00000000..9bb77b94
--- /dev/null
+++ b/tests/protocol_arg.pyi
@@ -0,0 +1,9 @@
+from typing import Protocol
+
+class P(Protocol):
+ def method1(self, arg: int) -> None: ... # Y062 Argument "arg" to protocol method "method1" should not be positional-or-keyword (suggestion: make it positional-only)
+ def method2(self, arg: str, /) -> None: ...
+ def method3(self, *, arg: str) -> None: ...
+ def method4(self, __arg: int) -> None: ...
+ def method5(self, __arg: int, *, foo: str) -> None: ...
+ def method6(self, arg, /, *, foo: str) -> None: ...