Skip to content

Commit

Permalink
Fixed mypy errors and updated type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed May 12, 2022
1 parent 85eec8c commit 303db25
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 52 deletions.
3 changes: 2 additions & 1 deletion docs/versionhistory.rst
Expand Up @@ -7,9 +7,10 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- **BACKWARDS INCOMPATIBLE** Replaced AnyIO's own ``ExceptionGroup`` class with the PEP 654
``BaseExceptionGroup`` and ``ExceptionGroup``
- Bumped minimum version of trio to v0.19
- Bumped minimum version of trio to v0.20
- Changed the pytest plugin to run both the setup and teardown phases of asynchronous generator
fixtures within a single task to enable use cases where a context manager straddles the ``yield``
- Updated type annotations on ``open_process()`` to accept bytes and sequences of bytes

**3.5.0**

Expand Down
8 changes: 4 additions & 4 deletions src/anyio/_backends/_asyncio.py
Expand Up @@ -1916,7 +1916,7 @@ def create_blocking_portal(cls) -> abc.BlockingPortal:
@classmethod
async def open_process(
cls,
command: str | Sequence[str],
command: str | bytes | Sequence[str | bytes],
*,
shell: bool,
stdin: int | IO[Any] | None,
Expand All @@ -1929,9 +1929,9 @@ async def open_process(
await cls.checkpoint()
if shell:
process = await asyncio.create_subprocess_shell(
command,
cast("str | bytes", command),
stdin=stdin,
stdout=stdout, # type: ignore[arg-type]
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
Expand Down Expand Up @@ -2033,7 +2033,7 @@ async def create_udp_socket(
@classmethod
async def getaddrinfo(
cls,
host: str | bytes,
host: bytes | str | None,
port: str | int | None,
*,
family: int | AddressFamily = 0,
Expand Down
42 changes: 17 additions & 25 deletions src/anyio/_backends/_trio.py
Expand Up @@ -32,8 +32,15 @@
)

import trio.from_thread
import trio.lowlevel
from outcome import Error, Outcome, Value
from trio.lowlevel import TrioToken
from trio.lowlevel import (
TrioToken,
current_root_task,
current_task,
wait_readable,
wait_writable,
)
from trio.socket import SocketType as TrioSocketType
from trio.to_thread import run_sync

Expand All @@ -53,21 +60,6 @@
from ..abc import IPSockAddrType, UDPPacketType
from ..abc._eventloop import AsyncBackend

try:
from trio import lowlevel as trio_lowlevel
except ImportError:
from trio import hazmat as trio_lowlevel # type: ignore[no-redef]
from trio.hazmat import wait_readable, wait_writable
else:
from trio.lowlevel import wait_readable, wait_writable

try:
from trio.lowlevel import (
open_process as trio_open_process, # type: ignore[attr-defined]
)
except ImportError:
from trio import open_process as trio_open_process

if TYPE_CHECKING:
from trio_typing import TaskStatus

Expand Down Expand Up @@ -853,7 +845,7 @@ def create_blocking_portal(cls) -> abc.BlockingPortal:
@classmethod
async def open_process(
cls,
command: str | Sequence[str],
command: str | bytes | Sequence[str | bytes],
*,
shell: bool,
stdin: int | IO[Any] | None,
Expand All @@ -863,7 +855,7 @@ async def open_process(
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
) -> Process:
process = await trio_open_process(
process = await trio.lowlevel.open_process( # type: ignore[attr-defined]
command,
stdin=stdin,
stdout=stdout,
Expand Down Expand Up @@ -944,17 +936,17 @@ async def create_udp_socket(
@classmethod
async def getaddrinfo(
cls,
host: str | bytes,
host: bytes | str | None,
port: str | int | None,
*,
family: int | AddressFamily = 0,
type: int | SocketKind = 0,
proto: int = 0,
flags: int = 0,
) -> GetAddrInfoReturnType:
# https: // github.com / python - trio / trio - typing / pull / 57
# https://github.com/python-trio/trio-typing/pull/57
return await trio.socket.getaddrinfo( # type: ignore[return-value]
host, port, family, type, proto, flags
host, port, family, type, proto, flags # type: ignore[arg-type]
)

@classmethod
Expand All @@ -967,7 +959,7 @@ async def getnameinfo(
@classmethod
async def wait_socket_readable(cls, sock: socket.socket) -> None:
try:
await wait_readable(sock)
await wait_readable(sock) # type: ignore[arg-type]
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
Expand All @@ -976,7 +968,7 @@ async def wait_socket_readable(cls, sock: socket.socket) -> None:
@classmethod
async def wait_socket_writable(cls, sock: socket.socket) -> None:
try:
await wait_writable(sock)
await wait_writable(sock) # type: ignore[arg-type]
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
Expand All @@ -1000,7 +992,7 @@ def open_signal_receiver(cls, *signals: Signals) -> ContextManager:

@classmethod
def get_current_task(cls) -> TaskInfo:
task = trio_lowlevel.current_task()
task = current_task()

parent_id = None
if task.parent_nursery and task.parent_nursery.parent_task:
Expand All @@ -1010,7 +1002,7 @@ def get_current_task(cls) -> TaskInfo:

@classmethod
def get_running_tasks(cls) -> list[TaskInfo]:
root_task = trio_lowlevel.current_root_task()
root_task = current_root_task()
task_infos = [TaskInfo(id(root_task), None, root_task.name, root_task.coro)]
nurseries = root_task.child_nurseries
while nurseries:
Expand Down
6 changes: 3 additions & 3 deletions src/anyio/_core/_sockets.py
Expand Up @@ -289,7 +289,7 @@ async def create_tcp_listener(
gai_res = await getaddrinfo(
local_host,
local_port,
family=family, # type: ignore[arg-type]
family=family,
type=socket.SOCK_STREAM,
flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
)
Expand Down Expand Up @@ -462,7 +462,7 @@ async def create_connected_udp_socket(


async def getaddrinfo(
host: str,
host: bytes | str | None,
port: str | int | None,
*,
family: int | AddressFamily = 0,
Expand Down Expand Up @@ -493,7 +493,7 @@ async def getaddrinfo(
# Handle unicode hostnames
if isinstance(host, str):
try:
encoded_host = host.encode("ascii")
encoded_host: bytes | None = host.encode("ascii")
except UnicodeEncodeError:
import idna

Expand Down
35 changes: 23 additions & 12 deletions src/anyio/_core/_subprocesses.py
Expand Up @@ -86,7 +86,7 @@ async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None:


async def open_process(
command: str | Sequence[str],
command: str | bytes | Sequence[str | bytes],
*,
stdin: int | IO[Any] | None = PIPE,
stdout: int | IO[Any] | None = PIPE,
Expand Down Expand Up @@ -116,14 +116,25 @@ async def open_process(
:return: an asynchronous process object
"""
shell = isinstance(command, str)
return await get_async_backend().open_process(
command,
shell=shell,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
start_new_session=start_new_session,
)
if isinstance(command, (str, bytes)):
return await get_async_backend().open_process(
command,
shell=True,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
start_new_session=start_new_session,
)
else:
return await get_async_backend().open_process(
command,
shell=False,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
start_new_session=start_new_session,
)
49 changes: 46 additions & 3 deletions src/anyio/abc/_eventloop.py
Expand Up @@ -6,9 +6,20 @@
from os import PathLike
from signal import Signals
from socket import AddressFamily, SocketKind, socket
from typing import IO, TYPE_CHECKING, Any, Callable, ContextManager, Sequence, TypeVar
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
ContextManager,
Sequence,
TypeVar,
overload,
)

if TYPE_CHECKING:
from typing import Literal

from .._core._sockets import GetAddrInfoReturnType
from .._core._synchronization import CapacityLimiter, Event
from .._core._tasks import CancelScope
Expand Down Expand Up @@ -199,11 +210,43 @@ def run_sync_from_thread(
def create_blocking_portal(cls) -> BlockingPortal:
pass

@classmethod
@overload
async def open_process(
cls,
command: str | bytes,
*,
shell: Literal[True],
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
cwd: str | bytes | PathLike[str] | None = None,
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
) -> Process:
pass

@classmethod
@overload
async def open_process(
cls,
command: Sequence[str | bytes],
*,
shell: Literal[False],
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
cwd: str | bytes | PathLike[str] | None = None,
env: Mapping[str, str] | None = None,
start_new_session: bool = False,
) -> Process:
pass

@classmethod
@abstractmethod
async def open_process(
cls,
command: str | Sequence[str],
command: str | bytes | Sequence[str | bytes],
*,
shell: bool,
stdin: int | IO[Any] | None,
Expand Down Expand Up @@ -257,7 +300,7 @@ async def create_udp_socket(
@abstractmethod
async def getaddrinfo(
cls,
host: str | bytes,
host: bytes | str | None,
port: str | int | None,
*,
family: int | AddressFamily = 0,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_from_thread.py
Expand Up @@ -454,8 +454,8 @@ def taskfunc(*, task_status: TaskStatus) -> None:

with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
future, start_value = portal.start_task(
taskfunc, name="testname"
) # type: ignore[arg-type]
taskfunc, name="testname" # type: ignore[arg-type]
)
assert start_value == "testname"

def test_contextvar_propagation_sync(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_taskgroups.py
Expand Up @@ -769,8 +769,8 @@ async def task_group_generator() -> AsyncGenerator[None, None]:

gen = task_group_generator()
anyio.run(
gen.__anext__,
backend=anyio_backend_name, # type: ignore[arg-type]
gen.__anext__, # type: ignore[arg-type]
backend=anyio_backend_name,
backend_options=anyio_backend_options,
)
pytest.raises(
Expand Down

0 comments on commit 303db25

Please sign in to comment.