diff --git a/service_client/plugins.py b/service_client/plugins.py index e5db946..7f736fd 100644 --- a/service_client/plugins.py +++ b/service_client/plugins.py @@ -2,6 +2,8 @@ from asyncio import coroutine from urllib.parse import quote_plus from aiohttp.multidict import CIMultiDict +from aiohttp.multipart import MultipartWriter + from service_client.utils import IncompleteFormatter from dirty_loader import LoaderNamespaceReversedCached @@ -103,3 +105,61 @@ def prepare_session(self, service_desc, session, request_params): session.request.set_request_params(request_params) except AttributeError: pass + + +class Multipart(BasePlugin): + + def __init__(self, default_multipart_content_type='form-data', + default_content_disposition='attachment'): + self.default_multipart_content_type = default_multipart_content_type + self.default_content_disposition = default_content_disposition + + @coroutine + def before_request(self, service_desc, session, request_params): + if request_params['method'].upper() in ['GET', 'DELETE'] or 'files' not in request_params: + return + + try: + multipart_content_type = service_desc['multipart']['content-type'] + except KeyError: # pragma: no cover + multipart_content_type = self.default_multipart_content_type + + try: + content_type = request_params['headers'].pop('content-type') + except KeyError: # pragma: no cover + content_type = '' + + mp = MultipartWriter(multipart_content_type) + + try: + data = request_params['data'] + mp.append(data, {'content-type': content_type}) + except KeyError: # pragma: no cover + pass + + files = request_params.pop('files') + + for f in files: + try: + file_headers = f['headers'] + except KeyError: + file_headers = None + part = mp.append(f['file'], file_headers) + + try: + content_disposition = f['content-disposition'] + except KeyError: + try: + content_disposition = service_desc['multipart']['content-disposition'] + except KeyError: # pragma: no cover + content_disposition = self.default_content_disposition + + params = {} + try: + params['filename'] = f['filename'] + except KeyError: + pass + + part.set_content_disposition(content_disposition, **params) + + request_params['data'] = mp diff --git a/tests/tests_plugins.py b/tests/tests_plugins.py index b8d9c6e..19255cb 100644 --- a/tests/tests_plugins.py +++ b/tests/tests_plugins.py @@ -5,10 +5,11 @@ ''' from asyncio import coroutine from aiohttp.client import ClientSession +from aiohttp.multipart import MultipartWriter from asynctest.case import TestCase from service_client import SessionWrapper -from service_client.plugins import Path, Timeout, Headers, QueryParams, Mock +from service_client.plugins import Path, Timeout, Headers, QueryParams, Mock, Multipart class PathTest(TestCase): @@ -248,7 +249,7 @@ def setUp(self): }} self.service_client = type('DynTestServiceClient', (), - {'rest_service_name': 'test_service_name', + {'service_name': 'test_service_name', 'loop': self.loop})() self.plugin.assign_service_client(self.service_client) @@ -259,3 +260,73 @@ def test_calling_mock(self): self.assertIsInstance(self.session.request, FakeMock) response = self.session.request('POST', 'default_url') self.assertEqual(200, response.status) + + +class TestMultipart(TestCase): + + def setUp(self): + + self.plugin = Multipart() + self.session = SessionWrapper(ClientSession()) + self.service_desc = {'multipart': { + 'content-type': 'alternative', + 'content-disposition': 'inline' + }} + + self.service_client = type('DynTestServiceClient', (), + {'service_name': 'test_service_name', + 'loop': self.loop})() + self.plugin.assign_service_client(self.service_client) + + @coroutine + def test_no_files(self): + + request_params = {'method': 'post', + 'data': 'aaaaaaa'} + yield from self.plugin.before_request(self.service_desc, self.session, request_params) + + self.assertEqual(request_params['data'], 'aaaaaaa') + + @coroutine + def test_multi_files_get(self): + request_params = {'method': 'get', + 'data': 'aaaaaaa', + 'files': [{'file': 'eeeee', + 'filename': 'foo.txt', + 'content-disposition': 'attachment'}, + {'file': 'barbar', + 'filename': 'bar.txt', + 'content-disposition': 'form-data'}, + {'file': 'other'}]} + yield from self.plugin.before_request(self.service_desc, self.session, request_params) + + self.assertEqual(request_params['data'], 'aaaaaaa') + + @coroutine + def test_multi_files_post(self): + request_params = {'method': 'post', + 'headers': { + 'content-type': 'application/json' + }, + 'data': 'aaaaaaa', + 'files': [{'file': 'eeeee', + 'filename': 'foo.txt', + 'headers': {'content-type': 'text/plain'}, + 'content-disposition': 'attachment'}, + {'file': 'barbar', + 'filename': 'bar.txt', + 'content-disposition': 'form-data'}, + {'file': 'other'}]} + yield from self.plugin.before_request(self.service_desc, self.session, request_params) + + self.assertNotIn('files', request_params) + self.assertIsInstance(request_params['data'], MultipartWriter) + self.assertEqual(len(request_params['data'].parts), 4) + self.assertIn('content-type', request_params['data'].headers) + self.assertIn('multipart/alternative', request_params['data'].headers['content-type']) + self.assertEquals('application/json', request_params['data'].parts[0].headers['content-type']) + self.assertIn('attachment', request_params['data'].parts[1].headers['content-disposition']) + self.assertIn('filename="foo.txt"', request_params['data'].parts[1].headers['content-disposition']) + self.assertEquals('text/plain', request_params['data'].parts[1].headers['content-type']) + self.assertIn('form-data', request_params['data'].parts[2].headers['content-disposition']) + self.assertIn('inline', request_params['data'].parts[3].headers['content-disposition'])