From f2dbdf95e95066e162f1f0587d02cba46b437eb4 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 21 Feb 2017 00:00:09 -0800 Subject: [PATCH] refactor WebRequest.post() method --- aiohttp/abc.py | 4 +- aiohttp/multipart.py | 28 +++++++--- aiohttp/web_request.py | 97 ++++++++++++++------------------- aiohttp/web_response.py | 6 +- setup.cfg | 5 ++ tests/test_client_functional.py | 9 ++- tests/test_web_request.py | 11 ---- 7 files changed, 74 insertions(+), 86 deletions(-) diff --git a/aiohttp/abc.py b/aiohttp/abc.py index bc38f9bf7c8..d1e4c460ec6 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -137,12 +137,12 @@ class AbstractPayloadWriter(ABC): def write(self, chunk): """Write chunk into stream""" - @asyncio.coroutine # pragma: no branch + @asyncio.coroutine @abstractmethod def write_eof(self, chunk=b''): """Write last chunk""" - @asyncio.coroutine # pragma: no branch @asyncio.coroutine + @abstractmethod def drain(self): """Flush the write buffer.""" diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 7952905c681..ccc3b1a11c0 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -13,7 +13,7 @@ from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) -from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype +from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype, reify from .http import HttpParser from .payload import (BytesPayload, LookupError, Payload, StringPayload, get_payload) @@ -113,18 +113,19 @@ def unescape(text, *, chars=''.join(map(re.escape, CHAR))): return disptype.lower(), params -def content_disposition_filename(params): +def content_disposition_filename(params, name='filename'): + name_suf = '%s*' % name if not params: return None - elif 'filename*' in params: - return params['filename*'] - elif 'filename' in params: - return params['filename'] + elif name_suf in params: + return params[name_suf] + elif name in params: + return params[name] else: parts = [] fnparams = sorted((key, value) for key, value in params.items() - if key.startswith('filename*')) + if key.startswith(name_suf)) for num, (key, value) in enumerate(fnparams): _, tail = key.split('*', 1) if tail.endswith('*'): @@ -203,6 +204,7 @@ def __init__(self, boundary, headers, content): self._unread = deque() self._prev_chunk = None self._content_eof = 0 + self._cache = {} if PY_35: def __aiter__(self): @@ -466,13 +468,21 @@ def get_charset(self, default=None): *_, params = parse_mimetype(ctype) return params.get('charset', default) - @property + @reify + def name(self): + """Returns filename specified in Content-Disposition header or ``None`` + if missed or header is malformed.""" + _, params = parse_content_disposition( + self.headers.get(CONTENT_DISPOSITION)) + return content_disposition_filename(params, 'name') + + @reify def filename(self): """Returns filename specified in Content-Disposition header or ``None`` if missed or header is malformed.""" _, params = parse_content_disposition( self.headers.get(CONTENT_DISPOSITION)) - return content_disposition_filename(params) + return content_disposition_filename(params, 'filename') class MultipartReader(object): diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 424605321b8..4428efe4544 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -1,14 +1,13 @@ import asyncio -import binascii -import cgi import collections import datetime -import io import json import re +import tempfile import warnings from email.utils import parsedate from types import MappingProxyType +from urllib.parse import parse_qsl from multidict import CIMultiDict, MultiDict, MultiDictProxy from yarl import URL @@ -19,7 +18,8 @@ __all__ = ('BaseRequest', 'FileField', 'Request') -FileField = collections.namedtuple('Field', 'name filename file content_type') +FileField = collections.namedtuple( + 'Field', 'name filename file content_type headers') ############################################################ @@ -40,7 +40,6 @@ def __init__(self, message, payload, protocol, time_service, task, *, self._protocol = protocol self._transport = protocol.transport self._post = None - self._post_files_cache = None self._payload = payload self._headers = message.headers @@ -233,18 +232,6 @@ def GET(self): DeprecationWarning) return self._rel_url.query - @reify - def POST(self): - """A multidict with all the variables in the POST parameters. - - post() methods has to be called before using this attribute. - """ - warnings.warn("POST property is deprecated, use .post() instead", - DeprecationWarning) - if self._post is None: - raise RuntimeError("POST is not available before post()") - return self._post - @property def headers(self): """A case-insensitive multidict proxy with all headers.""" @@ -400,47 +387,43 @@ def post(self): warnings.warn('To process multipart requests use .multipart' ' coroutine instead.', DeprecationWarning) - body = yield from self.read() - content_charset = self.charset or 'utf-8' - - environ = {'REQUEST_METHOD': self._method, - 'CONTENT_LENGTH': str(len(body)), - 'QUERY_STRING': '', - 'CONTENT_TYPE': self._headers.get(hdrs.CONTENT_TYPE)} - - fs = cgi.FieldStorage(fp=io.BytesIO(body), - environ=environ, - keep_blank_values=True, - encoding=content_charset) - - supported_transfer_encoding = { - 'base64': binascii.a2b_base64, - 'quoted-printable': binascii.a2b_qp - } - out = MultiDict() - _count = 1 - for field in fs.list or (): - transfer_encoding = field.headers.get( - hdrs.CONTENT_TRANSFER_ENCODING, None) - if field.filename: - ff = FileField(field.name, - field.filename, - field.file, # N.B. file closed error - field.type) - if self._post_files_cache is None: - self._post_files_cache = {} - self._post_files_cache[field.name+str(_count)] = field - _count += 1 - out.add(field.name, ff) - else: - value = field.value - if transfer_encoding in supported_transfer_encoding: - # binascii accepts bytes - value = value.encode('utf-8') - value = supported_transfer_encoding[ - transfer_encoding](value) - out.add(field.name, value) + + if content_type == 'multipart/form-data': + multipart = yield from self.multipart() + + field = yield from multipart.next() + while field is not None: + content_type = field.headers.get(hdrs.CONTENT_TYPE) + + if field.filename: + # store file in temp file + tmp = tempfile.TemporaryFile() + chunk = yield from field.read_chunk(size=2**16) + while chunk: + tmp.write(field.decode(chunk)) + chunk = yield from field.read_chunk(size=2**16) + tmp.seek(0) + + ff = FileField(field.name, field.filename, + tmp, content_type, field.headers) + out.add(field.name, ff) + else: + value = yield from field.read(decode=True) + if content_type.startswith('text/'): + charset = field.get_charset(default='utf-8') + value = value.decode(charset) + out.add(field.name, value) + + field = yield from multipart.next() + else: + data = yield from self.read() + if data: + charset = self.charset or 'utf-8' + out.extend( + parse_qsl( + data.rstrip().decode(charset), + encoding=charset)) self._post = MultiDictProxy(out) return self._post diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index eff07f3df37..2050eebad5f 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -45,6 +45,7 @@ def __init__(self, *, status=200, reason=None, headers=None): self._req = None self._payload_writer = None self._eof_sent = False + self._body_length = 0 if headers is not None: self._headers = CIMultiDict(headers) @@ -98,11 +99,11 @@ def force_close(self): @property def body_length(self): - return self._payload_writer.output_length + return self._body_length @property def output_length(self): - return self._payload_writer.output_length + return self._payload_writer.buffer_size def enable_chunked_encoding(self, chunk_size=None): """Enables automatic chunked transfer encoding.""" @@ -418,6 +419,7 @@ def write_eof(self, data=b''): yield from self._payload_writer.write_eof(data) self._eof_sent = True self._req = None + self._body_length = self._payload_writer.output_length self._payload_writer = None def __repr__(self): diff --git a/setup.cfg b/setup.cfg index 090855dab3c..2884c49c168 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,3 +14,8 @@ timeout = 10 [isort] known_third_party=jinja2 known_first_party=aiohttp,aiohttp_jinja2,aiopg + +[report] +exclude_lines = + @abc.abstractmethod + @abstractmethod diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 42767ff4dee..15f37399f29 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1037,7 +1037,6 @@ def handler(request): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_DATA_with_charset_post(loop, test_client): @asyncio.coroutine @@ -1064,8 +1063,9 @@ def test_POST_DATA_with_context_transfer_encoding(loop, test_client): @asyncio.coroutine def handler(request): data = yield from request.post() - assert data['name'] == b'text' # should it be str? - return web.Response(body=data['name']) + print(data) + assert data['name'] == 'text' + return web.Response(text=data['name']) app = web.Application(loop=loop) app.router.add_post('/', handler) @@ -1081,14 +1081,13 @@ def handler(request): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_DATA_with_content_type_context_transfer_encoding( loop, test_client): @asyncio.coroutine def handler(request): data = yield from request.post() - assert data['name'] == 'text' # should it be str? + assert data['name'] == 'text' return web.Response(body=data['name']) app = web.Application(loop=loop) diff --git a/tests/test_web_request.py b/tests/test_web_request.py index d7f634550d9..2776f928251 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -58,17 +58,6 @@ def test_doubleslashes(make_request): assert '/bar//foo/' == req.path -def test_POST(make_request): - req = make_request('POST', '/') - with pytest.raises(RuntimeError): - req.POST - - marker = object() - req._post = marker - assert req.POST is marker - assert req.POST is marker - - def test_content_type_not_specified(make_request): req = make_request('Get', '/') assert 'application/octet-stream' == req.content_type