Skip to content

Commit

Permalink
Tightening the runtime type check for ssl (#7698) (#8042)
Browse files Browse the repository at this point in the history
Currently, the valid types of ssl parameter are SSLContext,
Literal[False], Fingerprint or None.

If user sets ssl = False, we disable ssl certificate validation which
makes total sense. But if user set ssl = True by mistake, instead of
enabling ssl certificate validation or raising errors, we silently
disable the validation too which is a little subtle but weird.

In this PR, we added a check that if user sets ssl=True, we enable
certificate validation by treating it as using Default SSL Context.

---------

Co-authored-by: pre-commit-ci[bot]
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sviatoslav Sydorenko <wk.cvs.github@sydorenko.org.ua>
Co-authored-by: Sam Bull <aa6bs0@sambull.org>
Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Sam Bull <git@sambull.org>
(cherry picked from commit 9e14ea1)
(cherry picked from commit 4b91b53)
  • Loading branch information
Dreamsorcerer authored and patchback[bot] committed Jan 20, 2024
1 parent f000174 commit 564d83a
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGES/7698.feature
@@ -0,0 +1 @@
Added support for passing `True` to `ssl` while deprecating `None`. -- by :user:`xiangyan99`
17 changes: 12 additions & 5 deletions aiohttp/client.py
Expand Up @@ -22,7 +22,6 @@
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Set,
Expand Down Expand Up @@ -408,7 +407,7 @@ async def _request(
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
server_hostname: Optional[str] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
Expand Down Expand Up @@ -562,7 +561,7 @@ async def _request(
proxy_auth=proxy_auth,
timer=timer,
session=self,
ssl=ssl,
ssl=ssl if ssl is not None else True,
server_hostname=server_hostname,
proxy_headers=proxy_headers,
traces=traces,
Expand Down Expand Up @@ -738,7 +737,7 @@ def ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
Expand Down Expand Up @@ -790,7 +789,7 @@ async def _ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
verify_ssl: Optional[bool] = None,
fingerprint: Optional[bytes] = None,
ssl_context: Optional[SSLContext] = None,
Expand Down Expand Up @@ -824,6 +823,14 @@ async def _ws_connect(
extstr = ws_ext_gen(compress=compress)
real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr

# For the sake of backward compatibility, if user passes in None, convert it to True
if ssl is None:
warnings.warn(
"ssl=None is deprecated, please use ssl=True",
DeprecationWarning,
stacklevel=2,
)
ssl = True
ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint)

# send request
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/client_exceptions.py
Expand Up @@ -180,12 +180,12 @@ def port(self) -> Optional[int]:
return self._conn_key.port

@property
def ssl(self) -> Union[SSLContext, None, bool, "Fingerprint"]:
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
return self._conn_key.ssl

def __str__(self) -> str:
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)

# OSError.__reduce__ does too much black magick
Expand Down Expand Up @@ -219,7 +219,7 @@ def path(self) -> str:

def __str__(self) -> str:
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
self, self.ssl if self.ssl is not None else "default", self.strerror
self, "default" if self.ssl is True else self.ssl, self.strerror
)


Expand Down
21 changes: 10 additions & 11 deletions aiohttp/client_reqrep.py
Expand Up @@ -17,7 +17,6 @@
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -151,22 +150,22 @@ def check(self, transport: asyncio.Transport) -> None:
if ssl is not None:
SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None))
else: # pragma: no cover
SSL_ALLOWED_TYPES = type(None)
SSL_ALLOWED_TYPES = (bool, type(None))


def _merge_ssl_params(
ssl: Union["SSLContext", Literal[False], Fingerprint, None],
ssl: Union["SSLContext", bool, Fingerprint],
verify_ssl: Optional[bool],
ssl_context: Optional["SSLContext"],
fingerprint: Optional[bytes],
) -> Union["SSLContext", Literal[False], Fingerprint, None]:
) -> Union["SSLContext", bool, Fingerprint]:
if verify_ssl is not None and not verify_ssl:
warnings.warn(
"verify_ssl is deprecated, use ssl=False instead",
DeprecationWarning,
stacklevel=3,
)
if ssl is not None:
if ssl is not True:
raise ValueError(
"verify_ssl, ssl_context, fingerprint and ssl "
"parameters are mutually exclusive"
Expand All @@ -179,7 +178,7 @@ def _merge_ssl_params(
DeprecationWarning,
stacklevel=3,
)
if ssl is not None:
if ssl is not True:
raise ValueError(
"verify_ssl, ssl_context, fingerprint and ssl "
"parameters are mutually exclusive"
Expand All @@ -192,7 +191,7 @@ def _merge_ssl_params(
DeprecationWarning,
stacklevel=3,
)
if ssl is not None:
if ssl is not True:
raise ValueError(
"verify_ssl, ssl_context, fingerprint and ssl "
"parameters are mutually exclusive"
Expand All @@ -214,7 +213,7 @@ class ConnectionKey:
host: str
port: Optional[int]
is_ssl: bool
ssl: Union[SSLContext, None, Literal[False], Fingerprint]
ssl: Union[SSLContext, bool, Fingerprint]
proxy: Optional[URL]
proxy_auth: Optional[BasicAuth]
proxy_headers_hash: Optional[int] # hash(CIMultiDict)
Expand Down Expand Up @@ -276,7 +275,7 @@ def __init__(
proxy_auth: Optional[BasicAuth] = None,
timer: Optional[BaseTimerContext] = None,
session: Optional["ClientSession"] = None,
ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None,
ssl: Union[SSLContext, bool, Fingerprint] = True,
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
Expand Down Expand Up @@ -315,7 +314,7 @@ def __init__(
real_response_class = response_class
self.response_class: Type[ClientResponse] = real_response_class
self._timer = timer if timer is not None else TimerNoop()
self._ssl = ssl
self._ssl = ssl if ssl is not None else True
self.server_hostname = server_hostname

if loop.get_debug():
Expand Down Expand Up @@ -357,7 +356,7 @@ def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

@property
def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]:
def ssl(self) -> Union["SSLContext", bool, Fingerprint]:
return self._ssl

@property
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/connector.py
Expand Up @@ -762,7 +762,7 @@ def __init__(
ttl_dns_cache: Optional[int] = 10,
family: int = 0,
ssl_context: Optional[SSLContext] = None,
ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None,
ssl: Union[bool, Fingerprint, SSLContext] = True,
local_addr: Optional[Tuple[str, int]] = None,
resolver: Optional[AbstractResolver] = None,
keepalive_timeout: Union[None, float, object] = sentinel,
Expand Down Expand Up @@ -955,13 +955,13 @@ def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
sslcontext = req.ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not None:
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
sslcontext = self._ssl
if isinstance(sslcontext, ssl.SSLContext):
return sslcontext
if sslcontext is not None:
if sslcontext is not True:
# not verified or fingerprinted
return self._make_ssl_context(False)
return self._make_ssl_context(True)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_client_exceptions.py
Expand Up @@ -119,7 +119,7 @@ class TestClientConnectorError:
host="example.com",
port=8080,
is_ssl=False,
ssl=None,
ssl=True,
proxy=None,
proxy_auth=None,
proxy_headers_hash=None,
Expand All @@ -136,7 +136,7 @@ def test_ctor(self) -> None:
assert err.os_error.strerror == "No such file"
assert err.host == "example.com"
assert err.port == 8080
assert err.ssl is None
assert err.ssl is True

def test_pickle(self) -> None:
err = client.ClientConnectorError(
Expand All @@ -153,7 +153,7 @@ def test_pickle(self) -> None:
assert err2.os_error.strerror == "No such file"
assert err2.host == "example.com"
assert err2.port == 8080
assert err2.ssl is None
assert err2.ssl is True
assert err2.foo == "bar"

def test_repr(self) -> None:
Expand All @@ -171,7 +171,7 @@ def test_str(self) -> None:
os_error=OSError(errno.ENOENT, "No such file"),
)
assert str(err) == (
"Cannot connect to host example.com:8080 ssl:" "default [No such file]"
"Cannot connect to host example.com:8080 ssl:default [No such file]"
)


Expand All @@ -180,7 +180,7 @@ class TestClientConnectorCertificateError:
host="example.com",
port=8080,
is_ssl=False,
ssl=None,
ssl=True,
proxy=None,
proxy_auth=None,
proxy_headers_hash=None,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_client_fingerprint.py
Expand Up @@ -37,7 +37,7 @@ def test_fingerprint_check_no_ssl() -> None:

def test__merge_ssl_params_verify_ssl() -> None:
with pytest.warns(DeprecationWarning):
assert _merge_ssl_params(None, False, None, None) is False
assert _merge_ssl_params(True, False, None, None) is False


def test__merge_ssl_params_verify_ssl_conflict() -> None:
Expand All @@ -50,7 +50,7 @@ def test__merge_ssl_params_verify_ssl_conflict() -> None:
def test__merge_ssl_params_ssl_context() -> None:
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
with pytest.warns(DeprecationWarning):
assert _merge_ssl_params(None, None, ctx, None) is ctx
assert _merge_ssl_params(True, None, ctx, None) is ctx


def test__merge_ssl_params_ssl_context_conflict() -> None:
Expand All @@ -64,7 +64,7 @@ def test__merge_ssl_params_ssl_context_conflict() -> None:
def test__merge_ssl_params_fingerprint() -> None:
digest = hashlib.sha256(b"123").digest()
with pytest.warns(DeprecationWarning):
ret = _merge_ssl_params(None, None, None, digest)
ret = _merge_ssl_params(True, None, None, digest)
assert ret.fingerprint == digest


Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_request.py
Expand Up @@ -166,7 +166,7 @@ def test_host_port_default_http(make_request) -> None:
req = make_request("get", "http://python.org/")
assert req.host == "python.org"
assert req.port == 80
assert not req.ssl
assert not req.is_ssl()


def test_host_port_default_https(make_request) -> None:
Expand Down Expand Up @@ -400,7 +400,7 @@ def test_ipv6_default_http_port(make_request) -> None:
req = make_request("get", "http://[2001:db8::1]/")
assert req.host == "2001:db8::1"
assert req.port == 80
assert not req.ssl
assert not req.is_ssl()


def test_ipv6_default_https_port(make_request) -> None:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_connector.py
Expand Up @@ -28,19 +28,19 @@
@pytest.fixture()
def key():
# Connection key
return ConnectionKey("localhost", 80, False, None, None, None, None)
return ConnectionKey("localhost", 80, False, True, None, None, None)


@pytest.fixture
def key2():
# Connection key
return ConnectionKey("localhost", 80, False, None, None, None, None)
return ConnectionKey("localhost", 80, False, True, None, None, None)


@pytest.fixture
def ssl_key():
# Connection key
return ConnectionKey("localhost", 80, True, None, None, None, None)
return ConnectionKey("localhost", 80, True, True, None, None, None)


@pytest.fixture
Expand Down Expand Up @@ -1219,9 +1219,9 @@ async def test_cleanup_closed_disabled(loop, mocker) -> None:
assert not conn._cleanup_closed_transports


async def test_tcp_connector_ctor(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
assert conn._ssl is None
async def test_tcp_connector_ctor() -> None:
conn = aiohttp.TCPConnector()
assert conn._ssl is True

assert conn.use_dns_cache
assert conn.family == 0
Expand Down Expand Up @@ -1307,7 +1307,7 @@ async def test___get_ssl_context3(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop, ssl=ctx)
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = None
req.ssl = True
assert conn._get_ssl_context(req) is ctx


Expand All @@ -1333,7 +1333,7 @@ async def test___get_ssl_context6(loop) -> None:
conn = aiohttp.TCPConnector(loop=loop)
req = mock.Mock()
req.is_ssl.return_value = True
req.ssl = None
req.ssl = True
assert conn._get_ssl_context(req) is conn._make_ssl_context(True)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_proxy.py
Expand Up @@ -75,7 +75,7 @@ async def make_conn():
auth=None,
headers={"Host": "www.python.org"},
loop=self.loop,
ssl=None,
ssl=True,
)

conn.close()
Expand Down Expand Up @@ -117,7 +117,7 @@ async def make_conn():
auth=None,
headers={"Host": "www.python.org", "Foo": "Bar"},
loop=self.loop,
ssl=None,
ssl=True,
)

conn.close()
Expand Down

0 comments on commit 564d83a

Please sign in to comment.