Skip to content

Commit

Permalink
refactor WebRequest.post() method
Browse files Browse the repository at this point in the history
  • Loading branch information
fafhrd91 committed Feb 21, 2017
1 parent caa6bdb commit f2dbdf9
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 86 deletions.
4 changes: 2 additions & 2 deletions aiohttp/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ class AbstractPayloadWriter(ABC):
def write(self, chunk):
"""Write chunk into stream"""

@asyncio.coroutine # pragma: no branch
@asyncio.coroutine
@abstractmethod
def write_eof(self, chunk=b''):
"""Write last chunk"""

@asyncio.coroutine # pragma: no branch
@asyncio.coroutine
@abstractmethod
def drain(self):
"""Flush the write buffer."""
28 changes: 19 additions & 9 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH,
CONTENT_TRANSFER_ENCODING, CONTENT_TYPE)
from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype
from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype, reify
from .http import HttpParser
from .payload import (BytesPayload, LookupError, Payload, StringPayload,
get_payload)
Expand Down Expand Up @@ -113,18 +113,19 @@ def unescape(text, *, chars=''.join(map(re.escape, CHAR))):
return disptype.lower(), params


def content_disposition_filename(params):
def content_disposition_filename(params, name='filename'):
name_suf = '%s*' % name
if not params:
return None
elif 'filename*' in params:
return params['filename*']
elif 'filename' in params:
return params['filename']
elif name_suf in params:
return params[name_suf]
elif name in params:
return params[name]
else:
parts = []
fnparams = sorted((key, value)
for key, value in params.items()
if key.startswith('filename*'))
if key.startswith(name_suf))
for num, (key, value) in enumerate(fnparams):
_, tail = key.split('*', 1)
if tail.endswith('*'):
Expand Down Expand Up @@ -203,6 +204,7 @@ def __init__(self, boundary, headers, content):
self._unread = deque()
self._prev_chunk = None
self._content_eof = 0
self._cache = {}

if PY_35:
def __aiter__(self):
Expand Down Expand Up @@ -466,13 +468,21 @@ def get_charset(self, default=None):
*_, params = parse_mimetype(ctype)
return params.get('charset', default)

@property
@reify
def name(self):
"""Returns filename specified in Content-Disposition header or ``None``
if missed or header is malformed."""
_, params = parse_content_disposition(
self.headers.get(CONTENT_DISPOSITION))
return content_disposition_filename(params, 'name')

@reify
def filename(self):
"""Returns filename specified in Content-Disposition header or ``None``
if missed or header is malformed."""
_, params = parse_content_disposition(
self.headers.get(CONTENT_DISPOSITION))
return content_disposition_filename(params)
return content_disposition_filename(params, 'filename')


class MultipartReader(object):
Expand Down
97 changes: 40 additions & 57 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import asyncio
import binascii
import cgi
import collections
import datetime
import io
import json
import re
import tempfile
import warnings
from email.utils import parsedate
from types import MappingProxyType
from urllib.parse import parse_qsl

from multidict import CIMultiDict, MultiDict, MultiDictProxy
from yarl import URL
Expand All @@ -19,7 +18,8 @@

__all__ = ('BaseRequest', 'FileField', 'Request')

FileField = collections.namedtuple('Field', 'name filename file content_type')
FileField = collections.namedtuple(
'Field', 'name filename file content_type headers')


############################################################
Expand All @@ -40,7 +40,6 @@ def __init__(self, message, payload, protocol, time_service, task, *,
self._protocol = protocol
self._transport = protocol.transport
self._post = None
self._post_files_cache = None

self._payload = payload
self._headers = message.headers
Expand Down Expand Up @@ -233,18 +232,6 @@ def GET(self):
DeprecationWarning)
return self._rel_url.query

@reify
def POST(self):
"""A multidict with all the variables in the POST parameters.
post() methods has to be called before using this attribute.
"""
warnings.warn("POST property is deprecated, use .post() instead",
DeprecationWarning)
if self._post is None:
raise RuntimeError("POST is not available before post()")
return self._post

@property
def headers(self):
"""A case-insensitive multidict proxy with all headers."""
Expand Down Expand Up @@ -400,47 +387,43 @@ def post(self):
warnings.warn('To process multipart requests use .multipart'
' coroutine instead.', DeprecationWarning)

body = yield from self.read()
content_charset = self.charset or 'utf-8'

environ = {'REQUEST_METHOD': self._method,
'CONTENT_LENGTH': str(len(body)),
'QUERY_STRING': '',
'CONTENT_TYPE': self._headers.get(hdrs.CONTENT_TYPE)}

fs = cgi.FieldStorage(fp=io.BytesIO(body),
environ=environ,
keep_blank_values=True,
encoding=content_charset)

supported_transfer_encoding = {
'base64': binascii.a2b_base64,
'quoted-printable': binascii.a2b_qp
}

out = MultiDict()
_count = 1
for field in fs.list or ():
transfer_encoding = field.headers.get(
hdrs.CONTENT_TRANSFER_ENCODING, None)
if field.filename:
ff = FileField(field.name,
field.filename,
field.file, # N.B. file closed error
field.type)
if self._post_files_cache is None:
self._post_files_cache = {}
self._post_files_cache[field.name+str(_count)] = field
_count += 1
out.add(field.name, ff)
else:
value = field.value
if transfer_encoding in supported_transfer_encoding:
# binascii accepts bytes
value = value.encode('utf-8')
value = supported_transfer_encoding[
transfer_encoding](value)
out.add(field.name, value)

if content_type == 'multipart/form-data':
multipart = yield from self.multipart()

field = yield from multipart.next()
while field is not None:
content_type = field.headers.get(hdrs.CONTENT_TYPE)

if field.filename:
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = yield from field.read_chunk(size=2**16)
while chunk:
tmp.write(field.decode(chunk))
chunk = yield from field.read_chunk(size=2**16)
tmp.seek(0)

ff = FileField(field.name, field.filename,
tmp, content_type, field.headers)
out.add(field.name, ff)
else:
value = yield from field.read(decode=True)
if content_type.startswith('text/'):
charset = field.get_charset(default='utf-8')
value = value.decode(charset)
out.add(field.name, value)

field = yield from multipart.next()
else:
data = yield from self.read()
if data:
charset = self.charset or 'utf-8'
out.extend(
parse_qsl(
data.rstrip().decode(charset),
encoding=charset))

self._post = MultiDictProxy(out)
return self._post
Expand Down
6 changes: 4 additions & 2 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, *, status=200, reason=None, headers=None):
self._req = None
self._payload_writer = None
self._eof_sent = False
self._body_length = 0

if headers is not None:
self._headers = CIMultiDict(headers)
Expand Down Expand Up @@ -98,11 +99,11 @@ def force_close(self):

@property
def body_length(self):
return self._payload_writer.output_length
return self._body_length

@property
def output_length(self):
return self._payload_writer.output_length
return self._payload_writer.buffer_size

def enable_chunked_encoding(self, chunk_size=None):
"""Enables automatic chunked transfer encoding."""
Expand Down Expand Up @@ -418,6 +419,7 @@ def write_eof(self, data=b''):
yield from self._payload_writer.write_eof(data)
self._eof_sent = True
self._req = None
self._body_length = self._payload_writer.output_length
self._payload_writer = None

def __repr__(self):
Expand Down
5 changes: 5 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@ timeout = 10
[isort]
known_third_party=jinja2
known_first_party=aiohttp,aiohttp_jinja2,aiopg

[report]
exclude_lines =
@abc.abstractmethod
@abstractmethod
9 changes: 4 additions & 5 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,6 @@ def handler(request):
resp.close()


@pytest.mark.xfail
@asyncio.coroutine
def test_POST_DATA_with_charset_post(loop, test_client):
@asyncio.coroutine
Expand All @@ -1064,8 +1063,9 @@ def test_POST_DATA_with_context_transfer_encoding(loop, test_client):
@asyncio.coroutine
def handler(request):
data = yield from request.post()
assert data['name'] == b'text' # should it be str?
return web.Response(body=data['name'])
print(data)
assert data['name'] == 'text'
return web.Response(text=data['name'])

app = web.Application(loop=loop)
app.router.add_post('/', handler)
Expand All @@ -1081,14 +1081,13 @@ def handler(request):
resp.close()


@pytest.mark.xfail
@asyncio.coroutine
def test_POST_DATA_with_content_type_context_transfer_encoding(
loop, test_client):
@asyncio.coroutine
def handler(request):
data = yield from request.post()
assert data['name'] == 'text' # should it be str?
assert data['name'] == 'text'
return web.Response(body=data['name'])

app = web.Application(loop=loop)
Expand Down
11 changes: 0 additions & 11 deletions tests/test_web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,6 @@ def test_doubleslashes(make_request):
assert '/bar//foo/' == req.path


def test_POST(make_request):
req = make_request('POST', '/')
with pytest.raises(RuntimeError):
req.POST

marker = object()
req._post = marker
assert req.POST is marker
assert req.POST is marker


def test_content_type_not_specified(make_request):
req = make_request('Get', '/')
assert 'application/octet-stream' == req.content_type
Expand Down

0 comments on commit f2dbdf9

Please sign in to comment.