Skip to content

Commit

Permalink
Fix typing of ssl parameter (#7335)
Browse files Browse the repository at this point in the history
`True` is not an allowed value.

(cherry picked from commit cff007e)
  • Loading branch information
Dreamsorcerer committed Jul 4, 2023
1 parent 156bf5d commit fa6ab8f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGES/7335.misc
@@ -0,0 +1 @@
Fixed annotation of ``ssl`` parameter to disallow ``True``. -- by :user:`Dreamsorcerer`
7 changes: 4 additions & 3 deletions aiohttp/client.py
Expand Up @@ -22,6 +22,7 @@
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Set,
Expand Down Expand Up @@ -396,7 +397,7 @@ async def _request(
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
read_bufsize: Optional[int] = None,
Expand Down Expand Up @@ -724,7 +725,7 @@ def ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
Expand Down Expand Up @@ -776,7 +777,7 @@ async def _ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
Expand Down
7 changes: 4 additions & 3 deletions aiohttp/client_reqrep.py
Expand Up @@ -16,6 +16,7 @@
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -210,7 +211,7 @@ class ConnectionKey:
host: str
port: Optional[int]
is_ssl: bool
ssl: Union[SSLContext, None, bool, Fingerprint]
ssl: Union[SSLContext, None, Literal[False], Fingerprint]
proxy: Optional[URL]
proxy_auth: Optional[BasicAuth]
proxy_headers_hash: Optional[int] # hash(CIMultiDict)
Expand Down Expand Up @@ -272,7 +273,7 @@ def __init__(
proxy_auth: Optional[BasicAuth] = None,
timer: Optional[BaseTimerContext] = None,
session: Optional["ClientSession"] = None,
ssl: Union[SSLContext, bool, Fingerprint, None] = None,
ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None,
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
Expand Down Expand Up @@ -330,7 +331,7 @@ def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

@property
def ssl(self) -> Union["SSLContext", None, bool, Fingerprint]:
def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]:
return self._ssl

@property
Expand Down
27 changes: 14 additions & 13 deletions aiohttp/connector.py
Expand Up @@ -20,6 +20,7 @@
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -487,7 +488,7 @@ def _available_connections(self, key: "ConnectionKey") -> int:
return available

async def connect(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Connection:
"""Get from pool or create new connection."""
key = req.connection_key
Expand Down Expand Up @@ -679,7 +680,7 @@ def _release(
)

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
raise NotImplementedError()

Expand Down Expand Up @@ -757,7 +758,7 @@ def __init__(
ttl_dns_cache: Optional[int] = 10,
family: int = 0,
ssl_context: Optional[SSLContext] = None,
ssl: Union[None, bool, Fingerprint, SSLContext] = None,
ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None,
local_addr: Optional[Tuple[str, int]] = None,
resolver: Optional[AbstractResolver] = None,
keepalive_timeout: Union[None, float, object] = sentinel,
Expand Down Expand Up @@ -894,7 +895,7 @@ async def _resolve_host(
return self._cached_hosts.next_addrs(key)

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
"""Create connection.
Expand Down Expand Up @@ -930,7 +931,7 @@ def _make_ssl_context(verified: bool) -> SSLContext:
sslcontext.set_default_verify_paths()
return sslcontext

def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]:
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context
0. if req.ssl is false, return None
Expand Down Expand Up @@ -963,7 +964,7 @@ def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]:
else:
return None

def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]:
def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
if isinstance(ret, Fingerprint):
return ret
Expand All @@ -975,7 +976,7 @@ def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]:
async def _wrap_create_connection(
self,
*args: Any,
req: "ClientRequest",
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
**kwargs: Any,
Expand Down Expand Up @@ -1040,7 +1041,7 @@ def _loop_supports_start_tls(self) -> bool:
def _warn_about_tls_in_tls(
self,
underlying_transport: asyncio.Transport,
req: "ClientRequest",
req: ClientRequest,
) -> None:
"""Issue a warning if the requested URL has HTTPS scheme."""
if req.request_info.url.scheme != "https":
Expand Down Expand Up @@ -1077,7 +1078,7 @@ def _warn_about_tls_in_tls(
async def _start_tls_connection(
self,
underlying_transport: asyncio.Transport,
req: "ClientRequest",
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
Expand Down Expand Up @@ -1137,7 +1138,7 @@ async def _start_tls_connection(

async def _create_direct_connection(
self,
req: "ClientRequest",
req: ClientRequest,
traces: List["Trace"],
timeout: "ClientTimeout",
*,
Expand Down Expand Up @@ -1214,7 +1215,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
raise last_exc

async def _create_proxy_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
self._fail_on_no_start_tls(req)
runtime_has_start_tls = self._loop_supports_start_tls()
Expand Down Expand Up @@ -1382,7 +1383,7 @@ def path(self) -> str:
return self._path

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
try:
async with ceil_timeout(
Expand Down Expand Up @@ -1444,7 +1445,7 @@ def path(self) -> str:
return self._path

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
try:
async with ceil_timeout(
Expand Down

0 comments on commit fa6ab8f

Please sign in to comment.