Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing of ssl parameter #7335

Merged
merged 5 commits into from Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/7335.misc
@@ -0,0 +1 @@
Fixed annotation of ``ssl`` parameter to disallow ``True``. -- by :user:`Dreamsorcerer`
7 changes: 4 additions & 3 deletions aiohttp/client.py
Expand Up @@ -22,6 +22,7 @@
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Set,
Expand Down Expand Up @@ -344,7 +345,7 @@ async def _request(
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel,
ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
read_bufsize: Optional[int] = None,
Expand Down Expand Up @@ -677,7 +678,7 @@ def ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
Expand Down Expand Up @@ -723,7 +724,7 @@ async def _ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
Expand Down
7 changes: 4 additions & 3 deletions aiohttp/client_reqrep.py
Expand Up @@ -17,6 +17,7 @@
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -156,7 +157,7 @@ class ConnectionKey:
host: str
port: Optional[int]
is_ssl: bool
ssl: Union[SSLContext, None, bool, Fingerprint]
ssl: Union[SSLContext, None, Literal[False], Fingerprint]
proxy: Optional[URL]
proxy_auth: Optional[BasicAuth]
proxy_headers_hash: Optional[int] # hash(CIMultiDict)
Expand Down Expand Up @@ -210,7 +211,7 @@ def __init__(
proxy_auth: Optional[BasicAuth] = None,
timer: Optional[BaseTimerContext] = None,
session: Optional["ClientSession"] = None,
ssl: Union[SSLContext, bool, Fingerprint, None] = None,
ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None,
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
Expand Down Expand Up @@ -270,7 +271,7 @@ def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

@property
def ssl(self) -> Union["SSLContext", None, bool, Fingerprint]:
def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]:
return self._ssl

@property
Expand Down
27 changes: 14 additions & 13 deletions aiohttp/connector.py
Expand Up @@ -22,6 +22,7 @@
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -464,7 +465,7 @@ def _available_connections(self, key: "ConnectionKey") -> int:
return available

async def connect(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Connection:
"""Get from pool or create new connection."""
key = req.connection_key
Expand Down Expand Up @@ -659,7 +660,7 @@ def _release(
)

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
raise NotImplementedError()

Expand Down Expand Up @@ -734,7 +735,7 @@ def __init__(
use_dns_cache: bool = True,
ttl_dns_cache: Optional[int] = 10,
family: int = 0,
ssl: Union[None, bool, Fingerprint, SSLContext] = None,
ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None,
local_addr: Optional[Tuple[str, int]] = None,
resolver: Optional[AbstractResolver] = None,
keepalive_timeout: Union[None, float, _SENTINEL] = sentinel,
Expand Down Expand Up @@ -870,7 +871,7 @@ async def _resolve_host(
return self._cached_hosts.next_addrs(key)

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
"""Create connection.

Expand Down Expand Up @@ -906,7 +907,7 @@ def _make_ssl_context(verified: bool) -> SSLContext:
sslcontext.set_default_verify_paths()
return sslcontext

def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]:
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context

0. if req.ssl is false, return None
Expand Down Expand Up @@ -939,7 +940,7 @@ def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]:
else:
return None

def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]:
def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
if isinstance(ret, Fingerprint):
return ret
Expand All @@ -951,7 +952,7 @@ def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]:
async def _wrap_create_connection(
self,
*args: Any,
req: "ClientRequest",
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
**kwargs: Any,
Expand All @@ -973,7 +974,7 @@ async def _wrap_create_connection(
def _warn_about_tls_in_tls(
self,
underlying_transport: asyncio.Transport,
req: "ClientRequest",
req: ClientRequest,
) -> None:
"""Issue a warning if the requested URL has HTTPS scheme."""
if req.request_info.url.scheme != "https":
Expand Down Expand Up @@ -1010,7 +1011,7 @@ def _warn_about_tls_in_tls(
async def _start_tls_connection(
self,
underlying_transport: asyncio.Transport,
req: "ClientRequest",
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
Expand Down Expand Up @@ -1070,7 +1071,7 @@ async def _start_tls_connection(

async def _create_direct_connection(
self,
req: "ClientRequest",
req: ClientRequest,
traces: List["Trace"],
timeout: "ClientTimeout",
*,
Expand Down Expand Up @@ -1146,7 +1147,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
raise last_exc

async def _create_proxy_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
headers: Dict[str, str] = {}
if req.proxy_headers is not None:
Expand Down Expand Up @@ -1285,7 +1286,7 @@ def path(self) -> str:
return self._path

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
try:
async with ceil_timeout(
Expand Down Expand Up @@ -1345,7 +1346,7 @@ def path(self) -> str:
return self._path

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
try:
async with ceil_timeout(
Expand Down