Skip to content

Commit

Permalink
Fix serialization of Content-Disposition
Browse files Browse the repository at this point in the history
  • Loading branch information
JWCook committed Apr 9, 2021
1 parent 575ee00 commit ce714de
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 25 deletions.
63 changes: 38 additions & 25 deletions aiohttp_client_cache/response.py
Expand Up @@ -5,8 +5,8 @@
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union

import attr
from aiohttp import ClientResponse, ClientResponseError
from aiohttp.client_reqrep import ContentDisposition, RequestInfo
from aiohttp import ClientResponse, ClientResponseError, hdrs, multipart
from aiohttp.client_reqrep import ContentDisposition, MappingProxyType, RequestInfo
from aiohttp.typedefs import RawHeaders, StrOrURL
from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy
from yarl import URL
Expand Down Expand Up @@ -41,11 +41,10 @@ class CachedResponse:
method: str = attr.ib()
reason: str = attr.ib()
status: int = attr.ib()
url: StrOrURL = attr.ib()
url: URL = attr.ib(converter=URL)
version: str = attr.ib()
_body: Any = attr.ib(default=None)
_links: LinkItems = attr.ib(factory=list)
content_disposition: ContentDisposition = attr.ib(default=None)
cookies: SimpleCookie = attr.ib(default=None)
created_at: datetime = attr.ib(factory=datetime.utcnow)
encoding: str = attr.ib(default=None)
Expand Down Expand Up @@ -78,24 +77,22 @@ async def from_client_response(cls, client_response: ClientResponse, expires: da
except RuntimeError:
pass

response.url = str(client_response.url)
if client_response.history:
response.history = (
*[await cls.from_client_response(r) for r in client_response.history],
)
return response

@property
def ok(self) -> bool:
"""Returns ``True`` if ``status`` is less than ``400``, ``False`` if not"""
try:
self.raise_for_status()
return True
except ClientResponseError:
return False

def get_encoding(self):
return self.encoding
def content_disposition(self) -> Optional[ContentDisposition]:
"""Get Content-Disposition headers, if any"""
raw = self.headers.get(hdrs.CONTENT_DISPOSITION)
if raw is None:
return None
disposition_type, params_dct = multipart.parse_content_disposition(raw)
params = MappingProxyType(params_dct)
filename = multipart.content_disposition_filename(params)
return ContentDisposition(disposition_type, params, filename)

@property
def headers(self) -> CIMultiDictProxy[str]:
Expand All @@ -109,6 +106,10 @@ def decode_header(header):

return CIMultiDictProxy(CIMultiDict([decode_header(h) for h in self.raw_headers]))

@property
def host(self) -> str:
return self.url.host or ''

@property
def is_expired(self) -> bool:
"""Determine if this cached response is expired"""
Expand All @@ -120,14 +121,35 @@ def links(self) -> LinkMultiDict:
items = [(k, _to_url_multidict(v)) for k, v in self._links]
return MultiDictProxy(MultiDict([(k, MultiDictProxy(v)) for k, v in items]))

@property
def ok(self) -> bool:
"""Returns ``True`` if ``status`` is less than ``400``, ``False`` if not"""
try:
self.raise_for_status()
return True
except ClientResponseError:
return False

@property
def request_info(self) -> RequestInfo:
return RequestInfo(
url=URL(self.url),
method=self.method,
headers=self.headers,
real_url=URL(self.real_url),
)

def get_encoding(self):
return self.encoding

async def json(self, encoding: Optional[str] = None, **kwargs) -> Optional[Dict[str, Any]]:
"""Read and decode JSON response"""
stripped = self._body.strip()
if not stripped:
return None
return json.loads(stripped.decode(encoding or self.encoding))

def raise_for_status(self) -> None:
def raise_for_status(self):
if self.status >= 400:
raise ClientResponseError(
self.request_info, # type: ignore # These types are interchangeable
Expand All @@ -144,15 +166,6 @@ async def read(self) -> bytes:
def release(self):
"""No-op function for compatibility with ClientResponse"""

@property
def request_info(self) -> RequestInfo:
return RequestInfo(
url=URL(self.url),
method=self.method,
headers=self.headers,
real_url=URL(self.real_url),
)

async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str:
"""Read response payload and decode"""
return self._body.decode(encoding or self.encoding, errors=errors)
Expand Down
1 change: 1 addition & 0 deletions test/unit/test_response.py
Expand Up @@ -37,6 +37,7 @@ async def test_basic_attrs(aiohttp_client):
assert response.method == 'GET'
assert response.reason == 'Not Found'
assert response.status == 404
assert isinstance(response.url, URL)
assert response.encoding == 'utf-8'
assert response.headers['Content-Type'] == 'text/plain; charset=utf-8'
assert await response.text() == '404: Not Found'
Expand Down

0 comments on commit ce714de

Please sign in to comment.