From 5cbf27ecbb4318a9fd177307a52146aa993f86c8 Mon Sep 17 00:00:00 2001 From: Grey Li Date: Mon, 29 Mar 2021 19:53:31 +0800 Subject: [PATCH] Add type annotations and refactor some APIs - Add role and optional argument for auth_required. - Rename openapi._OpenAPIMixin to openapi.OpenAPI - Rename module scaffold to utils - Change scaffold.Scaffold to utils.method_route - Change _AuthErrorMixin to function handle_auth_error - Return None for current_user if not found - Merge openapi module back to app - Add mypy check in tox.ini - Change flake-8 max-line-length to 100 --- apiflask/__init__.py | 20 +- apiflask/app.py | 727 ++++++++++++++++++++-- apiflask/blueprint.py | 34 +- apiflask/decorators.py | 110 ++-- apiflask/errors.py | 26 +- apiflask/fields.py | 40 +- apiflask/openapi.py | 588 ----------------- apiflask/py.typed | 0 apiflask/scaffold.py | 43 -- apiflask/schemas.py | 8 +- apiflask/security.py | 48 +- apiflask/settings.py | 81 +-- apiflask/types.py | 18 + apiflask/utils.py | 50 ++ setup.cfg | 14 +- tests/{test_scaffold.py => test_utils.py} | 0 tox.ini | 9 +- 17 files changed, 993 insertions(+), 823 deletions(-) delete mode 100644 apiflask/openapi.py create mode 100644 apiflask/py.typed delete mode 100644 apiflask/scaffold.py create mode 100644 apiflask/types.py create mode 100644 apiflask/utils.py rename tests/{test_scaffold.py => test_utils.py} (100%) diff --git a/apiflask/__init__.py b/apiflask/__init__.py index 999e8dd9..19625fc3 100644 --- a/apiflask/__init__.py +++ b/apiflask/__init__.py @@ -1,9 +1,15 @@ -from .app import APIFlask # noqa: F401 -from .blueprint import APIBlueprint # noqa: F401 -from .decorators import auth_required, input, output, doc # noqa: F401 -from .errors import HTTPError, api_abort # noqa: F401 -from .schemas import Schema # noqa: F401 -from .fields import fields # noqa: F401 -from .security import HTTPBasicAuth, HTTPTokenAuth # noqa: F401 +# flake8: noqa +from .app import APIFlask +from .blueprint import APIBlueprint +from .decorators import input +from .decorators import output +from .decorators import doc +from .decorators import auth_required +from .errors import api_abort +from .errors import HTTPError +from .schemas import Schema +from . import fields +from .security import HTTPBasicAuth +from .security import HTTPTokenAuth __version__ = '0.3.0dev' diff --git a/apiflask/app.py b/apiflask/app.py index d4d00089..7ca90e31 100644 --- a/apiflask/app.py +++ b/apiflask/app.py @@ -1,53 +1,170 @@ +from typing import Iterable, Union, List, Optional, Type, Tuple, Any, Dict +import re +import sys + from flask import Flask +from flask import Blueprint +from flask import render_template +from flask.config import ConfigAttribute from flask.globals import _request_ctx_stack from werkzeug.exceptions import HTTPException as WerkzeugHTTPException +from apispec import APISpec +from apispec.ext.marshmallow import MarshmallowPlugin +from marshmallow import Schema as MarshmallowSchema +from flask_marshmallow import fields +try: + from flask_marshmallow import sqla +except ImportError: + sqla = None -from .openapi import _OpenAPIMixin -from .errors import HTTPError, default_error_handler -from .scaffold import Scaffold +from .errors import HTTPError +from .errors import default_error_handler +from .utils import route_shortcuts +from .security import HTTPBasicAuth +from .security import HTTPTokenAuth +from .errors import get_error_message +from .types import ResponseType +from .types import ErrorCallbackType +from .types import SpecCallbackType -class APIFlask(Flask, Scaffold, _OpenAPIMixin): +@route_shortcuts +class APIFlask(Flask): """ The Flask object with some Web API support. - :param import_name: the name of the application package. - :param title: The title of the API, defaults to "APIFlask". - You can change it to the name of your API (e.g. "Pet API"). - :param version: The version of the API, defaults to "1.0.0". - :param tags: The tags of the OpenAPI spec documentation, accepts a list. - See :attr:`tags` for more details. - :param spec_path: The path to OpenAPI Spec documentation. It - defaults to ``/openapi.json```, if the path end with ``.yaml`` - or ``.yml``, the YAML format of the OAS will be returned. - :param swagger_path: The path to Swagger UI documentation. - :param redoc_path: The path to Redoc documentation. - :param json_errors: If True, APIFlask will return a JSON response - for HTTP errors. - :param enable_openapi: If False, will disable OpenAPI spec and docs views. + Arguments: + import_name: the name of the application package. + title: The title of the API, defaults to "APIFlask". + You can change it to the name of your API (e.g. "Pet API"). + version: The version of the API, defaults to "1.0.0". + tags: The tags of the OpenAPI spec documentation, accepts a list. + See :attr:`tags` for more details. + spec_path: The path to OpenAPI Spec documentation. It + defaults to `/openapi.json`, if the path end with `.yaml` + or `.yml`, the YAML format of the OAS will be returned. + swagger_path: The path to Swagger UI documentation. + redoc_path: The path to Redoc documentation. + json_errors: If True, APIFlask will return a JSON response + for HTTP errors. + enable_openapi: If False, will disable OpenAPI spec and docs views. """ + #: The title of the API (openapi.info.title), defaults to "APIFlask". + #: You can change it to the name of your API (e.g. "Pet API"). + title: Optional[str] = None + + #: The version of the API (openapi.info.version), defaults to "1.0.0". + version: Optional[str] = None + + #: The description of the API (openapi.info.description). + #: + #: This attribute can also be configured from the config with the + #: ``DESCRIPTION`` configuration key. Defaults to ``None``. + description: Optional[str] = ConfigAttribute('DESCRIPTION') # type: ignore + + #: The tags of the OpenAPI spec documentation (openapi.tags), accepts a + #: list of dicts. + #: You can also pass a simple list contains the tag name:: + #: + #: app.tags = ['foo', 'bar', 'baz'] + #: + #: A standard OpenAPI tags list will look like this:: + #: + #: app.tags = [ + #: {'name': 'foo', 'description': 'The description of foo'}, + #: {'name': 'bar', 'description': 'The description of bar'}, + #: {'name': 'baz', 'description': 'The description of baz'} + #: ] + #: + #: If not set, the blueprint names will be used as tags. + #: + #: This attribute can also be configured from the config with the + #: ``TAGS`` configuration key. Defaults to ``None``. + tags: Optional[Union[List[str], List[Dict[str, str]]] + ] = ConfigAttribute('TAGS') # type: ignore + + #: The contact information of the API (openapi.info.contact). + #: Example value: + #: + #: app.contact = { + #: 'name': 'API Support', + #: 'url': 'http://www.example.com/support', + #: 'email': 'support@example.com' + #: } + #: + #: This attribute can also be configured from the config with the + #: ``CONTACT`` configuration key. Defaults to ``None``. + contact: Optional[Dict[str, str]] = ConfigAttribute('CONTACT') # type: ignore + + #: The license of the API (openapi.info.license). + #: Example value: + #: + #: app.license = { + #: 'name': 'Apache 2.0', + #: 'url': 'http://www.apache.org/licenses/LICENSE-2.0.html' + #: } + #: + #: This attribute can also be configured from the config with the + #: ``LICENSE`` configuration key. Defaults to ``None``. + license: Optional[Dict[str, str]] = ConfigAttribute('LICENSE') # type: ignore + + #: The servers information of the API (openapi.servers), accepts multiple + #: server dicts. + #: Example value: + #: + #: app.servers = [ + #: { + #: 'name': 'Production Server', + #: 'url': 'http://api.example.com' + #: } + #: ] + #: + #: This attribute can also be configured from the config with the + #: ``SERVERS`` configuration key. Defaults to ``None``. + servers: Optional[List[Dict[str, str]]] = ConfigAttribute('SERVERS') # type: ignore + + #: The external documentation information of the API (openapi.externalDocs). + #: Example value: + #: + #: app.external_docs = { + #: 'description': 'Find more info here', + #: 'url': 'http://docs.example.com' + #: } + #: + #: This attribute can also be configured from the config with the + #: ``EXTERNAL_DOCS`` configuration key. Defaults to ``None``. + external_docs: Optional[Dict[str, str]] = ConfigAttribute('EXTERNAL_DOCS') # type: ignore + + #: The terms of service URL of the API (openapi.info.termsOfService). + #: Example value: + #: + #: app.terms_of_service = "http://example.com/terms/" + #: + #: This attribute can also be configured from the config with the + #: ``TERMS_OF_SERVICE`` configuration key. Defaults to ``None``. + terms_of_service: Optional[str] = ConfigAttribute('TERMS_OF_SERVICE') # type: ignore def __init__( self, - import_name, - title='APIFlask', - version='0.1.0', - spec_path='/openapi.json', - docs_path='/docs', - docs_oauth2_redirect_path='/docs/oauth2-redirect', - redoc_path='/redoc', - json_errors=True, - enable_openapi=True, - static_url_path=None, - static_folder='static', - static_host=None, - host_matching=False, - subdomain_matching=False, - template_folder='templates', - instance_path=None, - instance_relative_config=False, - root_path=None - ): + import_name: str, + title: str = 'APIFlask', + version: str = '0.1.0', + spec_path: str = '/openapi.json', + docs_path: str = '/docs', + docs_oauth2_redirect_path: str = '/docs/oauth2-redirect', + redoc_path: str = '/redoc', + json_errors: bool = True, + enable_openapi: bool = True, + static_url_path: Optional[str] = None, + static_folder: str = 'static', + static_host: Optional[str] = None, + host_matching: bool = False, + subdomain_matching: bool = False, + template_folder: str = 'templates', + instance_path: Optional[str] = None, + instance_relative_config: bool = False, + root_path: Optional[str] = None + ) -> None: super(APIFlask, self).__init__( import_name, static_url_path=static_url_path, @@ -60,45 +177,48 @@ def __init__( instance_relative_config=instance_relative_config, root_path=root_path ) - _OpenAPIMixin.__init__( - self, - title=title, - version=version, - spec_path=spec_path, - docs_path=docs_path, - docs_oauth2_redirect_path=docs_oauth2_redirect_path, - redoc_path=redoc_path, - enable_openapi=enable_openapi - ) # Set default config self.config.from_object('apiflask.settings') + + self.title = title + self.version = version + self.spec_path = spec_path + self.docs_path = docs_path + self.redoc_path = redoc_path + self.docs_oauth2_redirect_path = docs_oauth2_redirect_path + self.enable_openapi = enable_openapi + self.json_errors = json_errors - self.spec_callback = None - self.error_callback = default_error_handler - self._spec = None + self.spec_callback: Optional[SpecCallbackType] = None # type: ignore + self.error_callback: ErrorCallbackType = default_error_handler # type: ignore + self._spec: Optional[Union[dict, str]] = None self._register_openapi_blueprint() @self.errorhandler(HTTPError) - def handle_http_error(error): + def handle_http_error( + error: HTTPError + ) -> ResponseType: return self.error_callback( error.status_code, error.message, error.detail, - error.headers + error.headers # type: ignore ) if self.json_errors: @self.errorhandler(WerkzeugHTTPException) - def handle_werkzeug_errrors(error): + def handle_werkzeug_errrors( + error: WerkzeugHTTPException + ) -> ResponseType: return self.error_callback( - error.code, + error.code, # type: ignore error.name, - detail=None, - headers=None + detail=None, # type: ignore + headers=None # type: ignore ) - def dispatch_request(self): + def dispatch_request(self) -> ResponseType: """Overwrite the default dispatch method to pass view arguments as positional arguments. With this overwrite, the view function can accept the parameters in a intuitive way (from top to bottom, from left to right):: @@ -127,7 +247,10 @@ def get_pet(name, pet_id, age, query, pet): # otherwise dispatch to the handler for that endpoint return self.view_functions[rule.endpoint](*req.view_args.values()) - def error_processor(self, f): + def error_processor( + self, + f: ErrorCallbackType + ) -> ErrorCallbackType: """Registers a error handler callback function. The callback function will be called when validation error hanppend when @@ -173,3 +296,491 @@ def my_error_handler(status_code, message, detail, headers): """ self.error_callback = f return f + + def _register_openapi_blueprint(self) -> None: + bp = Blueprint( + 'openapi', + __name__, + template_folder='templates', + static_folder='static', + static_url_path='/apiflask' + ) + + if self.spec_path: + @bp.route(self.spec_path) + def spec() -> Union[ + Union[Dict[Any, Any], str], + Tuple[Union[Dict[Any, Any], str], int, Dict[str, str]] + ]: + if self.spec_path.endswith('.yaml') or \ + self.spec_path.endswith('.yml'): + # YAML spec + return self.get_spec('yaml'), 200, \ + {'Content-Type': 'text/vnd.yaml'} + else: + # JSON spec + return self.get_spec('json') + + if self.docs_path: + @bp.route(self.docs_path) + def swagger_ui() -> str: + return render_template('apiflask/swagger_ui.html', + title=self.title, version=self.version) + + if self.docs_oauth2_redirect_path: + @bp.route(self.docs_oauth2_redirect_path) + def swagger_ui_oauth_redirect() -> str: + return render_template('apiflask/swagger_ui_oauth2_redirect.html', + title=self.title, version=self.version) + + if self.redoc_path: + @bp.route(self.redoc_path) + def redoc() -> str: + return render_template('apiflask/redoc.html', + title=self.title, version=self.version) + + if self.enable_openapi and ( + self.spec_path or self.docs_path or self.redoc_path + ): + self.register_blueprint(bp) + + def get_spec(self, spec_format: Optional[str] = None) -> Union[dict, str]: + if spec_format is None: + spec_format = self.config['SPEC_FORMAT'].lower() + if self._spec is None: + if spec_format == 'json': + self._spec = self._generate_spec().to_dict() + else: + self._spec = self._generate_spec().to_yaml() + if self.spec_callback: + self._spec = self.spec_callback(self._spec) + return self._spec + + def spec_processor(self, f: SpecCallbackType) -> SpecCallbackType: + self.spec_callback = f + return f + + @property + def spec(self) -> Union[dict, str]: + return self.get_spec() + + def _generate_spec(self) -> APISpec: + def resolver(schema: MarshmallowSchema) -> str: + name = schema.__class__.__name__ + if name.endswith('Schema'): + name = name[:-6] or name + if schema.partial: + name += 'Update' + return name + + # info object + info: dict = {} + if self.contact: + info['contact'] = self.contact + if self.license: + info['license'] = self.license + if self.terms_of_service: + info['termsOfService'] = self.terms_of_service + if self.description: + info['description'] = self.description + else: + # auto-generate info.description from module doc + if self.config['AUTO_DESCRIPTION']: + module_name = self.import_name + while module_name: + module = sys.modules[module_name] + if module.__doc__: + info['description'] = module.__doc__.strip() + break + if '.' not in module_name: + module_name = '.' + module_name + module_name = module_name.rsplit('.', 1)[0] + + # tags + tags: Optional[Union[List[str], List[Dict[str, Any]]]] = self.tags + if tags is not None: + # Convert simple tags list into standard OpenAPI tags + if isinstance(tags[0], str): + for index, tag in enumerate(tags): + tags[index] = {'name': tag} # type: ignore + else: + tags: List[str] = [] # type: ignore + if self.config['AUTO_TAGS']: + # auto-generate tags from blueprints + for name, blueprint in self.blueprints.items(): + if name == 'openapi' or name in self.config['DOCS_HIDE_BLUEPRINTS']: + continue + if hasattr(blueprint, 'tag') and blueprint.tag is not None: + if isinstance(blueprint.tag, dict): + tag = blueprint.tag + else: + tag = {'name': blueprint.tag} + else: + tag = {'name': name.title()} + module = sys.modules[blueprint.import_name] + if module.__doc__: + tag['description'] = module.__doc__.strip() + tags.append(tag) # type: ignore + + # additional fields + kwargs: dict = {} + if self.servers: + kwargs['servers'] = self.servers + if self.external_docs: + kwargs['externalDocs'] = self.external_docs + + ma_plugin: Type[MarshmallowPlugin] = MarshmallowPlugin(schema_name_resolver=resolver) + spec: Type[APISpec] = APISpec( + title=self.title, + version=self.version, + openapi_version='3.0.3', + plugins=[ma_plugin], + info=info, + tags=tags, + **kwargs + ) + + # configure flask-marshmallow URL types + ma_plugin.converter.field_mapping[fields.URLFor] = ('string', 'url') + ma_plugin.converter.field_mapping[fields.AbsoluteURLFor] = \ + ('string', 'url') + if sqla is not None: # pragma: no cover + ma_plugin.converter.field_mapping[sqla.HyperlinkRelated] = \ + ('string', 'url') + + # security schemes + auth_schemes: List[Union[Type[HTTPBasicAuth], Type[HTTPTokenAuth]]] = [] + auth_names: List[str] = [] + auth_blueprints: Dict[str, Dict[str, Any]] = {} + + def update_auth_schemas_names( + auth: Union[Type[HTTPBasicAuth], Type[HTTPTokenAuth]] + ) -> None: + auth_schemes.append(auth) + if isinstance(auth, HTTPBasicAuth): + name = 'BasicAuth' + elif isinstance(auth, HTTPTokenAuth): + if auth.scheme == 'Bearer' and auth.header is None: + name = 'BearerAuth' + else: + name = 'ApiKeyAuth' + else: + raise RuntimeError('Unknown authentication scheme') + if name in auth_names: + v = 2 + new_name = f'{name}_{v}' + while new_name in auth_names: + v += 1 + new_name = f'{name}_{v}' + name = new_name + auth_names.append(name) + + # detect auth_required on before_request functions + for blueprint_name, funcs in self.before_request_funcs.items(): + for f in funcs: + if hasattr(f, '_spec'): # pragma: no cover + auth = f._spec.get('auth') # type: ignore + if auth is not None and auth not in auth_schemes: + auth_blueprints[blueprint_name] = { # type: ignore + 'auth': auth, + 'roles': f._spec.get('roles') # type: ignore + } + update_auth_schemas_names(auth) + + for rule in self.url_map.iter_rules(): + view_func = self.view_functions[rule.endpoint] + if hasattr(view_func, '_spec'): + auth = view_func._spec.get('auth') + if auth is not None and auth not in auth_schemes: + update_auth_schemas_names(auth) + + security: Dict[Union[Type[HTTPBasicAuth], Type[HTTPTokenAuth]], str] = {} + security_schemes: Dict[str, Dict[str, str]] = {} + for name, auth in zip(auth_names, auth_schemes): + security[auth] = name + if isinstance(auth, HTTPTokenAuth): + if auth.scheme == 'Bearer' and auth.header is None: + security_schemes[name] = { + 'type': 'http', + 'scheme': 'Bearer', + } + else: + security_schemes[name] = { + 'type': 'apiKey', + 'name': auth.header, + 'in': 'header', + } + else: + security_schemes[name] = { + 'type': 'http', + 'scheme': 'Basic', + } + + if hasattr(auth, 'description') and auth.description is not None: + security_schemes[name]['description'] = auth.description + + for name, scheme in security_schemes.items(): + spec.components.security_scheme(name, scheme) + + # paths + paths: Dict[str, Dict[str, Any]] = {} + # rules: List[Any] = list(self.url_map.iter_rules()) + rules: List[Any] = sorted( + list(self.url_map.iter_rules()), key=lambda rule: len(rule.rule) + ) + for rule in rules: + operations: Dict[str, Any] = {} + view_func = self.view_functions[rule.endpoint] + # skip endpoints from openapi blueprint and the built-in static endpoint + if rule.endpoint.startswith('openapi') or \ + rule.endpoint.startswith('static'): + continue + # skip endpoints from blueprints in config DOCS_HIDE_BLUEPRINTS list + blueprint_name: Optional[str] = None # type: ignore + if '.' in rule.endpoint: + blueprint_name = rule.endpoint.split('.', 1)[0] + if blueprint_name in self.config['DOCS_HIDE_BLUEPRINTS']: + continue + # add a default 200 response for bare views + default_response = {'schema': {}, 'status_code': 200, 'description': None} + if not hasattr(view_func, '_spec'): + if self.config['AUTO_200_RESPONSE']: + view_func._spec = {'response': default_response} + else: + continue # pragma: no cover + # skip views flagged with @doc(hide=True) + if view_func._spec.get('hide'): + continue + + # tag + operation_tags: Optional[Union[str, List[str]]] = None + if view_func._spec.get('tags'): + operation_tags = view_func._spec.get('tags') + else: + # if tag not set, try to use blueprint name as tag + if self.tags is None and self.config['AUTO_TAGS'] and blueprint_name is not None: + blueprint = self.blueprints[blueprint_name] + if hasattr(blueprint, 'tag') and blueprint.tag is not None: + if isinstance(blueprint.tag, dict): + operation_tags = blueprint.tag['name'] + else: + operation_tags = blueprint.tag + else: + operation_tags = blueprint_name.title() + + for method in ['GET', 'POST', 'PUT', 'PATCH', 'DELETE']: + if method not in rule.methods: + continue + operation: Dict[str, Any] = { + 'parameters': [ + {'in': location, 'schema': schema} + for schema, location in view_func._spec.get('args', []) + ], + 'responses': {}, + } + if operation_tags: + if isinstance(operation_tags, list): + operation['tags'] = operation_tags + else: + operation['tags'] = [operation_tags] + + # summary + if view_func._spec.get('summary'): + operation['summary'] = view_func._spec.get('summary') + else: + # auto-generate summary from dotstring or view function name + if self.config['AUTO_PATH_SUMMARY']: + docs = (view_func.__doc__ or '').strip().split('\n') + if docs[0]: + # Use the first line of docstring as summary + operation['summary'] = docs[0] + else: + # Use the function name as summary + operation['summary'] = ' '.join( + view_func.__name__.split('_')).title() + + # description + if view_func._spec.get('description'): + operation['description'] = view_func._spec.get('description') + else: + # auto-generate description from dotstring + if self.config['AUTO_PATH_DESCRIPTION']: + docs = (view_func.__doc__ or '').strip().split('\n') + if len(docs) > 1: + # Use the remain lines of docstring as description + operation['description'] = '\n'.join(docs[1:]).strip() + + # deprecated + if view_func._spec.get('deprecated'): + operation['deprecated'] = view_func._spec.get('deprecated') + + # responses + descriptions: Dict[str, str] = { + '200': self.config['DEFAULT_200_DESCRIPTION'], + '201': self.config['DEFAULT_201_DESCRIPTION'], + '204': self.config['DEFAULT_204_DESCRIPTION'], + } + + def add_response( + status_code: str, + schema: Union[MarshmallowSchema, dict], + description: str + ) -> None: + operation['responses'][status_code] = { + 'content': { + 'application/json': { + 'schema': schema + } + } + } + operation['responses'][status_code]['description'] = description + + if view_func._spec.get('response'): + status_code: str = str(view_func._spec.get('response')['status_code']) + schema = view_func._spec.get('response')['schema'] + description: str = view_func._spec.get('response')['description'] or \ + descriptions.get(status_code, self.config['DEFAULT_2XX_DESCRIPTION']) + add_response(status_code, schema, description) + else: + # add a default 200 response for views without using @output + # or @doc(responses={...}) + if not view_func._spec.get('responses') and self.config['AUTO_200_RESPONSE']: + add_response('200', {}, descriptions['200']) + + def add_response_and_schema( + status_code: str, + schema: Union[MarshmallowSchema, dict], + schema_name: str, + description: str + ) -> None: + if isinstance(schema, type): + schema = schema() + add_response(status_code, schema, description) + elif isinstance(schema, dict): + if schema_name not in spec.components.schemas: + spec.components.schema(schema_name, schema) + schema_ref = {'$ref': f'#/components/schemas/{schema_name}'} + add_response(status_code, schema_ref, description) + else: + raise RuntimeError( + 'The schema must be a Marshamallow schema \ + class or an OpenAPI schema dict.' + ) + + # add validation error response + if self.config['AUTO_VALIDATION_ERROR_RESPONSE']: + if view_func._spec.get('body') or view_func._spec.get('args'): + status_code: str = str( # type: ignore + self.config['VALIDATION_ERROR_STATUS_CODE'] + ) + description: str = self.config[ # type: ignore + 'VALIDATION_ERROR_DESCRIPTION' + ] + schema: Union[ # type: ignore + MarshmallowSchema, dict + ] = self.config['VALIDATION_ERROR_SCHEMA'] + add_response_and_schema( + status_code, schema, 'ValidationError', description + ) + + # add authorization error response + if self.config['AUTO_AUTH_ERROR_RESPONSE']: + if view_func._spec.get('auth') or ( + blueprint_name is not None and blueprint_name in auth_blueprints + ): + status_code: str = str( # type: ignore + self.config['AUTH_ERROR_STATUS_CODE'] + ) + description: str = self.config['AUTH_ERROR_DESCRIPTION'] # type: ignore + schema: Union[ # type: ignore + MarshmallowSchema, dict + ] = self.config['AUTH_ERROR_SCHEMA'] + add_response_and_schema( + status_code, schema, 'AuthorizationError', description + ) + + if view_func._spec.get('responses'): + responses: Union[List[int], Dict[int, str]] \ + = view_func._spec.get('responses') + if isinstance(responses, list): + responses: Dict[int, str] = {} # type: ignore + for status_code in view_func._spec.get('responses'): + responses[ # type: ignore + status_code + ] = get_error_message(int(status_code)) + for status_code, description in responses.items(): # type: ignore + status_code: str = str(status_code) # type: ignore + if status_code in operation['responses']: + continue + if self.config['AUTO_HTTP_ERROR_RESPONSE'] and ( + status_code.startswith('4') or + status_code.startswith('5') # type: ignore + ): + schema: Union[ # type: ignore + MarshmallowSchema, dict + ] = self.config['HTTP_ERROR_SCHEMA'] + add_response_and_schema( + status_code, schema, 'HTTPError', description + ) + else: + add_response(status_code, {}, description) + + # requestBody + if view_func._spec.get('body'): + operation['requestBody'] = { + 'content': { + 'application/json': { + 'schema': view_func._spec['body'], + } + } + } + + # security + if blueprint_name is not None and blueprint_name in auth_blueprints: + operation['security'] = [{ + security[auth_blueprints[blueprint_name]['auth']]: + auth_blueprints[blueprint_name]['roles'] + }] + + if view_func._spec.get('auth'): + operation['security'] = [{ + security[view_func._spec['auth']]: view_func._spec['roles'] + }] + + operations[method.lower()] = operation + + # parameters + path_arguments: Iterable = re.findall(r'<(([^<:]+:)?([^>]+))>', rule.rule) + if path_arguments: + arguments: List[Dict[str, str]] = [] + for _, argument_type, argument_name in path_arguments: + argument = { + 'in': 'path', + 'name': argument_name, + } + if argument_type == 'int:': + argument['schema'] = {'type': 'integer'} + elif argument_type == 'float:': + argument['schema'] = {'type': 'number'} + else: + argument['schema'] = {'type': 'string'} + arguments.append(argument) + + for method, operation in operations.items(): + operation['parameters'] = arguments + operation['parameters'] + + path: str = re.sub(r'<([^<:]+:)?', '{', rule.rule).replace('>', '}') + if path not in paths: + paths[path] = operations + else: + paths[path].update(operations) + + for path, operations in paths.items(): + # sort by method before adding them to the spec + sorted_operations: Dict[str, Any] = {} + for method in ['get', 'post', 'put', 'patch', 'delete']: + if method in operations: + sorted_operations[method] = operations[method] + spec.path(path=path, operations=sorted_operations) + + return spec diff --git a/apiflask/blueprint.py b/apiflask/blueprint.py index 7dd87b48..d30683d9 100644 --- a/apiflask/blueprint.py +++ b/apiflask/blueprint.py @@ -1,10 +1,12 @@ -from flask import Blueprint as BaseBlueprint +from typing import Optional, Union +from flask import Blueprint -from .scaffold import Scaffold -from .scaffold import _sentinel +from .utils import route_shortcuts +from .utils import _sentinel -class APIBlueprint(BaseBlueprint, Scaffold): +@route_shortcuts +class APIBlueprint(Blueprint): """Flask's Blueprint with some API support. .. versionadded:: 0.2.0 @@ -12,18 +14,18 @@ class APIBlueprint(BaseBlueprint, Scaffold): def __init__( self, - name, - import_name, - tag=None, - static_folder=None, - static_url_path=None, - template_folder=None, - url_prefix=None, - subdomain=None, - url_defaults=None, - root_path=None, - cli_group=_sentinel, - ): + name: str, + import_name: str, + tag: Optional[Union[str, dict]] = None, + static_folder: Optional[str] = None, + static_url_path: Optional[str] = None, + template_folder: Optional[str] = None, + url_prefix: Optional[str] = None, + subdomain: Optional[str] = None, + url_defaults: Optional[dict] = None, + root_path: Optional[str] = None, + cli_group: Union[Optional[str]] = _sentinel # type: ignore + ) -> None: super(APIBlueprint, self).__init__( name, import_name, diff --git a/apiflask/decorators.py b/apiflask/decorators.py index 33ccf3d4..6eddae4f 100644 --- a/apiflask/decorators.py +++ b/apiflask/decorators.py @@ -1,45 +1,73 @@ +from typing import Callable, Union, List, Optional, Dict, Any, Type, Mapping from functools import wraps -from flask import Response, jsonify, current_app +from flask import Response +from flask import jsonify +from flask import current_app from webargs.flaskparser import FlaskParser as BaseFlaskParser +from marshmallow import ValidationError as MarshmallowValidationError +from marshmallow import Schema as MarshmallowSchema from .errors import ValidationError -from .scaffold import _sentinel +from .utils import _sentinel from .schemas import EmptySchema +from .security import HTTPBasicAuth +from .security import HTTPTokenAuth +from .types import DecoratedType +from .types import ResponseType +from .types import RequestType class FlaskParser(BaseFlaskParser): - def handle_error(self, error, req, schema, *, error_status_code, - error_headers): + def handle_error( # type: ignore + self, + error: MarshmallowValidationError, + req: RequestType, + schema: MarshmallowSchema, + *, + error_status_code: int, + error_headers: Mapping[str, str] + ) -> None: raise ValidationError( error_status_code or current_app.config['VALIDATION_ERROR_STATUS_CODE'], current_app.config['VALIDATION_ERROR_DESCRIPTION'], - error.messages) + error.messages, + error_headers + ) -parser = FlaskParser() -use_args = parser.use_args +parser: FlaskParser = FlaskParser() +use_args: Callable = parser.use_args -def _annotate(f, **kwargs): +def _annotate(f: Any, **kwargs: Any) -> None: if not hasattr(f, '_spec'): f._spec = {} for key, value in kwargs.items(): f._spec[key] = value -def auth_required(auth, **kwargs): +def auth_required( + auth: Union[Type[HTTPBasicAuth], Type[HTTPTokenAuth]], + role: Optional[Union[list, str]] = None, + optional: Optional[str] = None +) -> Callable[[DecoratedType], DecoratedType]: + roles = role + if not isinstance(role, list): # pragma: no cover + roles = [role] if role is not None else [] + def decorator(f): - roles = kwargs.get('role') - if not isinstance(roles, list): # pragma: no cover - roles = [roles] if roles is not None else [] _annotate(f, auth=auth, roles=roles) - return auth.login_required(**kwargs)(f) + return auth.login_required(role=role, optional=optional)(f) return decorator -def input(schema, location='json', **kwargs): +def input( + schema: MarshmallowSchema, + location: str = 'json', + **kwargs: Any +) -> Callable[[DecoratedType], DecoratedType]: if isinstance(schema, type): # pragma: no cover schema = schema() @@ -64,7 +92,11 @@ def decorator(f): return decorator -def output(schema, status_code=200, description=None): +def output( + schema: MarshmallowSchema, + status_code: int = 200, + description: Optional[str] = None +) -> Callable[[DecoratedType], DecoratedType]: if isinstance(schema, type): # pragma: no cover schema = schema() @@ -86,7 +118,7 @@ def _jsonify(obj, many=_sentinel, *args, **kwargs): # pragma: no cover return jsonify(data, *args, **kwargs) @wraps(f) - def _response(*args, **kwargs): + def _response(*args: Any, **kwargs: Any) -> ResponseType: rv = f(*args, **kwargs) if isinstance(rv, Response): # pragma: no cover raise RuntimeError( @@ -110,30 +142,34 @@ def _response(*args, **kwargs): def doc( - summary=None, - description=None, - tags=None, - responses=None, - deprecated=None, - hide=False -): + summary: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Union[List[str], List[Dict[str, Any]]]] = None, + responses: Optional[Union[List[int], Dict[int, str]]] = None, + deprecated: Optional[bool] = False, + hide: Optional[bool] = False +) -> Callable[[DecoratedType], DecoratedType]: """ Set up OpenAPI documentation for view function. - :param summary: The summary of this view function. If not set, the name of - the view function will be used. If your view function named with ``get_pet``, - then the summary will be "Get Pet". If the view function has docstring, then - the first line of the docstring will be used. The precedence will be: - @doc(summary='blah') > the frist line of docstring > the view functino name - :param description: The description of this view function. If not set, the lines - after the empty line of the docstring will be used. - :param tags: The tag list of this view function, map the tags you passed in the :attr:`tags` - attribute. You can pass a list of tag names or just a single tag string. If ``app.tags`` - not set, the blueprint name will be used as tag name. - :param responses: The other responses for this view function, accept a dict in a format - of ``{400: 'Bad Request'}``. - :param deprecated: Flag this endpoint as deprecated in API docs. Defaults to ``None``. - :param hide: Hide this endpoint in API docs. Defaults to ``False``. + Arguments: + summary: The summary of this view function. If not set, the name of + the view function will be used. If your view function named with `get_pet`, + then the summary will be "Get Pet". If the view function has docstring, then + the first line of the docstring will be used. The precedence will be: + @doc(summary='blah') > the frist line of docstring > the view functino name + description: The description of this view function. If not set, the lines + after the empty line of the docstring will be used. + tags: The tag list of this view function, map the tags you passed in the :attr:`tags` + attribute. You can pass a list of tag names or just a single tag string. If `app.tags` + not set, the blueprint name will be used as tag name. + responses: The other responses for this view function, accept a dict in a format + of `{400: 'Bad Request'}`. + deprecated: Flag this endpoint as deprecated in API docs. Defaults to `False`. + hide: Hide this endpoint in API docs. Defaults to `False`. + + .. versionchanged:: 0.3.0 + Change the default value of deprecated from `None` to `False`. .. versionadded:: 0.2.0 """ diff --git a/apiflask/errors.py b/apiflask/errors.py index adc9bfb9..e8390777 100644 --- a/apiflask/errors.py +++ b/apiflask/errors.py @@ -1,9 +1,17 @@ +from typing import Any, Optional, Mapping, Union, Tuple + from werkzeug.http import HTTP_STATUS_CODES class HTTPError(Exception): - def __init__(self, status_code, message=None, detail=None, headers=None): + def __init__( + self, + status_code: int, + message: Optional[str] = None, + detail: Optional[Any] = None, + headers: Optional[Mapping[str, str]] = None + ) -> None: super(HTTPError, self).__init__() self.status_code = status_code self.detail = detail @@ -19,15 +27,25 @@ class ValidationError(HTTPError): pass -def api_abort(status_code, message=None, detail=None, headers=None): +def api_abort( + status_code: int, + message: Optional[str] = None, + detail: Optional[Any] = None, + headers: Optional[Mapping[str, str]] = None +) -> None: raise HTTPError(status_code, message, detail, headers) -def get_error_message(status_code): +def get_error_message(status_code: int) -> str: return HTTP_STATUS_CODES.get(status_code, 'Unknown error') -def default_error_handler(status_code, message=None, detail=None, headers=None): +def default_error_handler( + status_code: int, + message: Optional[str] = None, + detail: Optional[Any] = None, + headers: Optional[Mapping[str, str]] = None +) -> Union[Tuple[dict, int], Tuple[dict, int, Mapping[str, str]]]: if message is None: message = get_error_message(status_code) if detail is None: diff --git a/apiflask/fields.py b/apiflask/fields.py index de3708e7..a1f1abba 100644 --- a/apiflask/fields.py +++ b/apiflask/fields.py @@ -1,7 +1,33 @@ -from marshmallow import fields # noqa: F401 -from marshmallow.fields import (Field, Raw, Nested, Mapping, Dict, List, Tuple, # noqa: F401 - String, UUID, Number, Integer, Decimal, Boolean, # noqa: F401 - Float, DateTime, NaiveDateTime, AwareDateTime, # noqa: F401 - Time, Date, TimeDelta, URL, Email, IP, IPv4, # noqa: F401 - IPv6, Method, Function, Constant, Pluck) # noqa: F401 -from flask_marshmallow.fields import URLFor, AbsoluteURLFor, Hyperlinks # noqa: F401 +# flake8: noqa +from marshmallow.fields import Field +from marshmallow.fields import Raw +from marshmallow.fields import Nested +from marshmallow.fields import Mapping +from marshmallow.fields import Dict +from marshmallow.fields import List +from marshmallow.fields import Tuple +from marshmallow.fields import String +from marshmallow.fields import UUID +from marshmallow.fields import Number +from marshmallow.fields import Integer +from marshmallow.fields import Decimal +from marshmallow.fields import Boolean +from marshmallow.fields import Float +from marshmallow.fields import DateTime +from marshmallow.fields import NaiveDateTime +from marshmallow.fields import AwareDateTime +from marshmallow.fields import Time +from marshmallow.fields import Date +from marshmallow.fields import TimeDelta +from marshmallow.fields import URL +from marshmallow.fields import Email +from marshmallow.fields import IP +from marshmallow.fields import IPv4 +from marshmallow.fields import IPv6 +from marshmallow.fields import Method +from marshmallow.fields import Function +from marshmallow.fields import Constant +from marshmallow.fields import Pluck +from flask_marshmallow.fields import URLFor +from flask_marshmallow.fields import Hyperlinks +from flask_marshmallow.fields import AbsoluteURLFor diff --git a/apiflask/openapi.py b/apiflask/openapi.py deleted file mode 100644 index 1cea61f3..00000000 --- a/apiflask/openapi.py +++ /dev/null @@ -1,588 +0,0 @@ -import re -import sys - -from flask import Blueprint -from flask import render_template -from flask.config import ConfigAttribute -from apispec import APISpec -from apispec.ext.marshmallow import MarshmallowPlugin -from flask_marshmallow import fields -try: - from flask_marshmallow import sqla -except ImportError: - sqla = None - -from .security import HTTPBasicAuth, HTTPTokenAuth -from .errors import get_error_message - - -class _OpenAPIMixin: - #: The title of the API (openapi.info.title), defaults to "APIFlask". - #: You can change it to the name of your API (e.g. "Pet API"). - title = None - - #: The version of the API (openapi.info.version), defaults to "1.0.0". - version = None - - #: The description of the API (openapi.info.description). - #: - #: This attribute can also be configured from the config with the - #: ``DESCRIPTION`` configuration key. Defaults to ``None``. - description = ConfigAttribute('DESCRIPTION') - - #: The tags of the OpenAPI spec documentation (openapi.tags), accepts a - #: list of dicts. - #: You can also pass a simple list contains the tag name:: - #: - #: app.tags = ['foo', 'bar', 'baz'] - #: - #: A standard OpenAPI tags list will look like this:: - #: - #: app.tags = [ - #: {'name': 'foo', 'description': 'The description of foo'}, - #: {'name': 'bar', 'description': 'The description of bar'}, - #: {'name': 'baz', 'description': 'The description of baz'} - #: ] - #: - #: If not set, the blueprint names will be used as tags. - #: - #: This attribute can also be configured from the config with the - #: ``TAGS`` configuration key. Defaults to ``None``. - tags = ConfigAttribute('TAGS') - - #: The contact information of the API (openapi.info.contact). - #: Example value: - #: - #: app.contact = { - #: 'name': 'API Support', - #: 'url': 'http://www.example.com/support', - #: 'email': 'support@example.com' - #: } - #: - #: This attribute can also be configured from the config with the - #: ``CONTACT`` configuration key. Defaults to ``None``. - contact = ConfigAttribute('CONTACT') - - #: The license of the API (openapi.info.license). - #: Example value: - #: - #: app.license = { - #: 'name': 'Apache 2.0', - #: 'url': 'http://www.apache.org/licenses/LICENSE-2.0.html' - #: } - #: - #: This attribute can also be configured from the config with the - #: ``LICENSE`` configuration key. Defaults to ``None``. - license = ConfigAttribute('LICENSE') - - #: The servers information of the API (openapi.servers), accepts multiple - #: server dicts. - #: Example value: - #: - #: app.servers = [ - #: { - #: 'name': 'Production Server', - #: 'url': 'http://api.example.com' - #: } - #: ] - #: - #: This attribute can also be configured from the config with the - #: ``SERVERS`` configuration key. Defaults to ``None``. - servers = ConfigAttribute('SERVERS') - - #: The external documentation information of the API (openapi.externalDocs). - #: Example value: - #: - #: app.external_docs = { - #: 'description': 'Find more info here', - #: 'url': 'http://docs.example.com' - #: } - #: - #: This attribute can also be configured from the config with the - #: ``EXTERNAL_DOCS`` configuration key. Defaults to ``None``. - external_docs = ConfigAttribute('EXTERNAL_DOCS') - - #: The terms of service URL of the API (openapi.info.termsOfService). - #: Example value: - #: - #: app.terms_of_service = "http://example.com/terms/" - #: - #: This attribute can also be configured from the config with the - #: ``TERMS_OF_SERVICE`` configuration key. Defaults to ``None``. - terms_of_service = ConfigAttribute('TERMS_OF_SERVICE') - - def __init__( - self, - title, - version, - spec_path, - docs_path, - redoc_path, - docs_oauth2_redirect_path, - enable_openapi - ): - self.title = title - self.version = version - self.spec_path = spec_path - self.docs_path = docs_path - self.redoc_path = redoc_path - self.docs_oauth2_redirect_path = docs_oauth2_redirect_path - self.enable_openapi = enable_openapi - - def _register_openapi_blueprint(self): - bp = Blueprint( - 'openapi', - __name__, - template_folder='templates', - static_folder='static', - static_url_path='/apiflask' - ) - - if self.spec_path: - @bp.route(self.spec_path) - def spec(): - if self.spec_path.endswith('.yaml') or \ - self.spec_path.endswith('.yml'): - # YAML spec - return self.get_spec('yaml'), 200, \ - {'Content-Type': 'text/vnd.yaml'} - else: - # JSON spec - return self.get_spec('json') - - if self.docs_path: - @bp.route(self.docs_path) - def swagger_ui(): - return render_template('apiflask/swagger_ui.html', - title=self.title, version=self.version) - - if self.docs_oauth2_redirect_path: - @bp.route(self.docs_oauth2_redirect_path) - def swagger_ui_oauth_redirect(): - return render_template('apiflask/swagger_ui_oauth2_redirect.html', - title=self.title, version=self.version) - - if self.redoc_path: - @bp.route(self.redoc_path) - def redoc(): - return render_template('apiflask/redoc.html', - title=self.title, version=self.version) - - if self.enable_openapi and ( - self.spec_path or self.docs_path or self.redoc_path - ): - self.register_blueprint(bp) - - def get_spec(self, spec_format=None): - if spec_format is None: - spec_format = self.config['SPEC_FORMAT'].lower() - if self._spec is None: - if spec_format == 'json': - self._spec = self._generate_spec().to_dict() - else: - self._spec = self._generate_spec().to_yaml() - if self.spec_callback: - self._spec = self.spec_callback(self._spec) - return self._spec - - def spec_processor(self, f): - self.spec_callback = f - return f - - @property - def spec(self): - return self.get_spec() - - def _generate_spec(self): - def resolver(schema): - name = schema.__class__.__name__ - if name.endswith('Schema'): - name = name[:-6] or name - if schema.partial: - name += 'Update' - return name - - # info object - info = {} - if self.contact: - info['contact'] = self.contact - if self.license: - info['license'] = self.license - if self.terms_of_service: - info['termsOfService'] = self.terms_of_service - if self.description: - info['description'] = self.description - else: - # auto-generate info.description from module doc - if self.config['AUTO_DESCRIPTION']: - module_name = self.import_name - while module_name: - module = sys.modules[module_name] - if module.__doc__: - info['description'] = module.__doc__.strip() - break - if '.' not in module_name: - module_name = '.' + module_name - module_name = module_name.rsplit('.', 1)[0] - - # tags - tags = self.tags - if tags is not None: - # Convert simple tags list into standard OpenAPI tags - if isinstance(tags[0], str): - for index, tag in enumerate(tags): - tags[index] = {'name': tag} - else: - tags = [] - if self.config['AUTO_TAGS']: - # auto-generate tags from blueprints - for name, blueprint in self.blueprints.items(): - if name == 'openapi' or name in self.config['DOCS_HIDE_BLUEPRINTS']: - continue - if hasattr(blueprint, 'tag') and blueprint.tag is not None: - if isinstance(blueprint.tag, dict): - tag = blueprint.tag - else: - tag = {'name': blueprint.tag} - else: - tag = {'name': name.title()} - module = sys.modules[blueprint.import_name] - if module.__doc__: - tag['description'] = module.__doc__.strip() - tags.append(tag) - - # additional fields - kwargs = {} - if self.servers: - kwargs['servers'] = self.servers - if self.external_docs: - kwargs['externalDocs'] = self.external_docs - - ma_plugin = MarshmallowPlugin(schema_name_resolver=resolver) - spec = APISpec( - title=self.title, - version=self.version, - openapi_version='3.0.3', - plugins=[ma_plugin], - info=info, - tags=tags, - **kwargs - ) - - # configure flask-marshmallow URL types - ma_plugin.converter.field_mapping[fields.URLFor] = ('string', 'url') - ma_plugin.converter.field_mapping[fields.AbsoluteURLFor] = \ - ('string', 'url') - if sqla is not None: # pragma: no cover - ma_plugin.converter.field_mapping[sqla.HyperlinkRelated] = \ - ('string', 'url') - - # security schemes - auth_schemes = [] - auth_names = [] - auth_blueprints = {} - - def update_auth_schemas_names(auth): - auth_schemes.append(auth) - if isinstance(auth, HTTPBasicAuth): - name = 'BasicAuth' - elif isinstance(auth, HTTPTokenAuth): - if auth.scheme == 'Bearer' and auth.header is None: - name = 'BearerAuth' - else: - name = 'ApiKeyAuth' - else: - raise RuntimeError('Unknown authentication scheme') - if name in auth_names: - v = 2 - new_name = f'{name}_{v}' - while new_name in auth_names: - v += 1 - new_name = f'{name}_{v}' - name = new_name - auth_names.append(name) - - # detect auth_required on before_request functions - for blueprint_name, funcs in self.before_request_funcs.items(): - for f in funcs: - if hasattr(f, '_spec'): # pragma: no cover - auth = f._spec.get('auth') - if auth is not None and auth not in auth_schemes: - auth_blueprints[blueprint_name] = { - 'auth': auth, - 'roles': f._spec.get('roles') - } - update_auth_schemas_names(auth) - - for rule in self.url_map.iter_rules(): - view_func = self.view_functions[rule.endpoint] - if hasattr(view_func, '_spec'): - auth = view_func._spec.get('auth') - if auth is not None and auth not in auth_schemes: - update_auth_schemas_names(auth) - - security = {} - security_schemes = {} - for name, auth in zip(auth_names, auth_schemes): - security[auth] = name - if isinstance(auth, HTTPTokenAuth): - if auth.scheme == 'Bearer' and auth.header is None: - security_schemes[name] = { - 'type': 'http', - 'scheme': 'Bearer', - } - else: - security_schemes[name] = { - 'type': 'apiKey', - 'name': auth.header, - 'in': 'header', - } - else: - security_schemes[name] = { - 'type': 'http', - 'scheme': 'Basic', - } - - if hasattr(auth, 'description') and auth.description is not None: - security_schemes[name]['description'] = auth.description - - for name, scheme in security_schemes.items(): - spec.components.security_scheme(name, scheme) - - # paths - paths = {} - rules = list(self.url_map.iter_rules()) - rules = sorted(rules, key=lambda rule: len(rule.rule)) - for rule in rules: - operations = {} - view_func = self.view_functions[rule.endpoint] - # skip endpoints from openapi blueprint and the built-in static endpoint - if rule.endpoint.startswith('openapi') or \ - rule.endpoint.startswith('static'): - continue - # skip endpoints from blueprints in config DOCS_HIDE_BLUEPRINTS list - if '.' in rule.endpoint: - blueprint_name = rule.endpoint.split('.', 1)[0] - if blueprint_name in self.config['DOCS_HIDE_BLUEPRINTS']: - continue - else: - blueprint_name = None - # add a default 200 response for bare views - default_response = {'schema': {}, 'status_code': 200, 'description': None} - if not hasattr(view_func, '_spec'): - if self.config['AUTO_200_RESPONSE']: - view_func._spec = {'response': default_response} - else: - continue # pragma: no cover - # skip views flagged with @doc(hide=True) - if view_func._spec.get('hide'): - continue - - # tag - tags = None - if view_func._spec.get('tags'): - tags = view_func._spec.get('tags') - else: - # if tag not set, try to use blueprint name as tag - if self.tags is None and self.config['AUTO_TAGS']: - if blueprint_name is not None: - blueprint = self.blueprints[blueprint_name] - if hasattr(blueprint, 'tag') and blueprint.tag is not None: - if isinstance(blueprint.tag, dict): - tags = blueprint.tag['name'] - else: - tags = blueprint.tag - else: - tags = blueprint_name.title() - - for method in ['GET', 'POST', 'PUT', 'PATCH', 'DELETE']: - if method not in rule.methods: - continue - operation = { - 'parameters': [ - {'in': location, 'schema': schema} - for schema, location in view_func._spec.get('args', []) - ], - 'responses': {}, - } - if tags: - if isinstance(tags, list): - operation['tags'] = tags - else: - operation['tags'] = [tags] - - # summary - if view_func._spec.get('summary'): - operation['summary'] = view_func._spec.get('summary') - else: - # auto-generate summary from dotstring or view function name - if self.config['AUTO_PATH_SUMMARY']: - docs = (view_func.__doc__ or '').strip().split('\n') - if docs[0]: - # Use the first line of docstring as summary - operation['summary'] = docs[0] - else: - # Use the function name as summary - operation['summary'] = ' '.join( - view_func.__name__.split('_')).title() - - # description - if view_func._spec.get('description'): - operation['description'] = view_func._spec.get('description') - else: - # auto-generate description from dotstring - if self.config['AUTO_PATH_DESCRIPTION']: - docs = (view_func.__doc__ or '').strip().split('\n') - if len(docs) > 1: - # Use the remain lines of docstring as description - operation['description'] = '\n'.join(docs[1:]).strip() - - # deprecated - if view_func._spec.get('deprecated'): - operation['deprecated'] = view_func._spec.get('deprecated') - - # responses - descriptions = { - '200': self.config['DEFAULT_200_DESCRIPTION'], - '201': self.config['DEFAULT_201_DESCRIPTION'], - '204': self.config['DEFAULT_204_DESCRIPTION'], - } - - def add_response(status_code, schema, description): - operation['responses'][status_code] = { - 'content': { - 'application/json': { - 'schema': schema - } - } - } - operation['responses'][status_code]['description'] = description - - if view_func._spec.get('response'): - status_code = str(view_func._spec.get('response')['status_code']) - schema = view_func._spec.get('response')['schema'] - description = view_func._spec.get('response')['description'] or \ - descriptions.get(status_code, self.config['DEFAULT_2XX_DESCRIPTION']) - add_response(status_code, schema, description) - else: - # add a default 200 response for views without using @output - # or @doc(responses={...}) - if not view_func._spec.get('responses') and self.config['AUTO_200_RESPONSE']: - add_response('200', {}, descriptions['200']) - - def add_response_and_schema(status_code, schema, schema_name, description): - if isinstance(schema, type): - schema = schema() - add_response(status_code, schema, description) - elif isinstance(schema, dict): - if schema_name not in spec.components.schemas: - spec.components.schema(schema_name, schema) - schema_ref = {'$ref': f'#/components/schemas/{schema_name}'} - add_response(status_code, schema_ref, description) - else: - raise RuntimeError( - 'The schema must be a Marshamallow schema \ - class or an OpenAPI schema dict.' - ) - - # add validation error response - if self.config['AUTO_VALIDATION_ERROR_RESPONSE']: - if view_func._spec.get('body') or view_func._spec.get('args'): - status_code = str(self.config['VALIDATION_ERROR_STATUS_CODE']) - description = self.config['VALIDATION_ERROR_DESCRIPTION'] - schema = self.config['VALIDATION_ERROR_SCHEMA'] - add_response_and_schema( - status_code, schema, 'ValidationError', description - ) - - # add authorization error response - if self.config['AUTO_AUTH_ERROR_RESPONSE']: - if view_func._spec.get('auth') or ( - blueprint_name is not None and blueprint_name in auth_blueprints - ): - status_code = str(self.config['AUTH_ERROR_STATUS_CODE']) - description = self.config['AUTH_ERROR_DESCRIPTION'] - schema = self.config['AUTH_ERROR_SCHEMA'] - add_response_and_schema( - status_code, schema, 'AuthorizationError', description - ) - - if view_func._spec.get('responses'): - responses = view_func._spec.get('responses') - if isinstance(responses, list): - responses = {} - for status_code in view_func._spec.get('responses'): - responses[status_code] = get_error_message(status_code) - for status_code, description in responses.items(): - status_code = str(status_code) - if status_code in operation['responses']: - continue - if self.config['AUTO_HTTP_ERROR_RESPONSE'] and ( - status_code.startswith('4') or status_code.startswith('5') - ): - schema = self.config['HTTP_ERROR_SCHEMA'] - add_response_and_schema( - status_code, schema, 'HTTPError', description - ) - else: - add_response(status_code, {}, description) - - # requestBody - if view_func._spec.get('body'): - operation['requestBody'] = { - 'content': { - 'application/json': { - 'schema': view_func._spec['body'], - } - } - } - - # security - if blueprint_name is not None and blueprint_name in auth_blueprints: - operation['security'] = [{ - security[auth_blueprints[blueprint_name]['auth']]: - auth_blueprints[blueprint_name]['roles'] - }] - - if view_func._spec.get('auth'): - operation['security'] = [{ - security[view_func._spec['auth']]: view_func._spec['roles'] - }] - - operations[method.lower()] = operation - - # parameters - path_arguments = re.findall(r'<(([^<:]+:)?([^>]+))>', rule.rule) - if path_arguments: - arguments = [] - for _, argument_type, argument_name in path_arguments: - argument = { - 'in': 'path', - 'name': argument_name, - } - if argument_type == 'int:': - argument['schema'] = {'type': 'integer'} - elif argument_type == 'float:': - argument['schema'] = {'type': 'number'} - else: - argument['schema'] = {'type': 'string'} - arguments.append(argument) - - for method, operation in operations.items(): - operation['parameters'] = arguments + operation['parameters'] - - path = re.sub(r'<([^<:]+:)?', '{', rule.rule).replace('>', '}') - if path not in paths: - paths[path] = operations - else: - paths[path].update(operations) - - for path, operations in paths.items(): - # sort by method before adding them to the spec - sorted_operations = {} - for method in ['get', 'post', 'put', 'patch', 'delete']: - if method in operations: - sorted_operations[method] = operations[method] - spec.path(path=path, operations=sorted_operations) - - return spec diff --git a/apiflask/py.typed b/apiflask/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/apiflask/scaffold.py b/apiflask/scaffold.py deleted file mode 100644 index d685a9ed..00000000 --- a/apiflask/scaffold.py +++ /dev/null @@ -1,43 +0,0 @@ -_sentinel = object() - - -class Scaffold: - """Base object for APIFlask and Blueprint. - - .. versionadded:: 0.2.0 - """ - # TODO Remove these shortcuts when pin Flask>=2.0 - def get(self, rule, **options): - """Shortcut for ``app.route()``. - - .. versionadded:: 0.2.0 - """ - return self.route(rule, methods=['GET'], **options) - - def post(self, rule, **options): - """Shortcut for ``app.route(methods=['POST'])``. - - .. versionadded:: 0.2.0 - """ - return self.route(rule, methods=['POST'], **options) - - def put(self, rule, **options): - """Shortcut for ``app.route(methods=['PUT'])``. - - .. versionadded:: 0.2.0 - """ - return self.route(rule, methods=['PUT'], **options) - - def patch(self, rule, **options): - """Shortcut for ``app.route(methods=['PATCH'])``. - - .. versionadded:: 0.2.0 - """ - return self.route(rule, methods=['PATCH'], **options) - - def delete(self, rule, **options): - """Shortcut for ``app.route(methods=['DELETE'])``. - - .. versionadded:: 0.2.0 - """ - return self.route(rule, methods=['DELETE'], **options) diff --git a/apiflask/schemas.py b/apiflask/schemas.py index c34f02ba..3e2af742 100644 --- a/apiflask/schemas.py +++ b/apiflask/schemas.py @@ -1,7 +1,9 @@ +from typing import Dict, Any + from flask_marshmallow import Schema -validation_error_detail_schema = { +validation_error_detail_schema: Dict[str, Any] = { "type": "object", "properties": { "": { @@ -19,7 +21,7 @@ } -validation_error_schema = { +validation_error_schema: Dict[str, Any] = { "properties": { "detail": validation_error_detail_schema, "message": { @@ -33,7 +35,7 @@ } -http_error_schema = { +http_error_schema: Dict[str, Any] = { "properties": { "detail": { "type": "object" diff --git a/apiflask/security.py b/apiflask/security.py index 79123b8f..e62370bf 100644 --- a/apiflask/security.py +++ b/apiflask/security.py @@ -1,3 +1,5 @@ +from typing import Optional, Union, Tuple, Any, Mapping + from flask import g, current_app from flask_httpauth import HTTPBasicAuth as BaseHTTPBasicAuth from flask_httpauth import HTTPTokenAuth as BaseHTTPTokenAuth @@ -7,37 +9,45 @@ class _AuthBase: - def __init__(self, description=None): + def __init__(self, description: Optional[str] = None) -> None: self.description = description @property - def current_user(self): - if hasattr(g, 'flask_httpauth_user'): # pragma: no cover - return g.flask_httpauth_user - + def current_user(self) -> Union[None, Any]: + return g.get('flask_httpauth_user', None) -class _AuthErrorMixin: - def __init__(self): - @self.error_handler - def handle_auth_error(status_code): - if current_app.json_errors: - return default_error_handler(status_code) - else: - return 'Unauthorized Access', status_code +def handle_auth_error( + status_code: int +) -> Union[Tuple[str, int], Tuple[dict, int], Tuple[dict, int, Mapping[str, str]]]: + if current_app.json_errors: + return default_error_handler(status_code) + else: + return 'Unauthorized Access', status_code -class HTTPBasicAuth(_AuthBase, BaseHTTPBasicAuth, _AuthErrorMixin): +class HTTPBasicAuth(_AuthBase, BaseHTTPBasicAuth): - def __init__(self, scheme=None, realm=None, description=None): + def __init__( + self, + scheme=None, + realm=None, + description=None + ) -> None: super(HTTPBasicAuth, self).__init__(description=description) BaseHTTPBasicAuth.__init__(self, scheme=scheme, realm=realm) - _AuthErrorMixin.__init__(self) + self.error_handler(handle_auth_error) -class HTTPTokenAuth(_AuthBase, BaseHTTPTokenAuth, _AuthErrorMixin): +class HTTPTokenAuth(_AuthBase, BaseHTTPTokenAuth): - def __init__(self, scheme='Bearer', realm=None, header=None, description=None): + def __init__( + self, + scheme='Bearer', + realm=None, + header=None, + description=None + ) -> None: super(HTTPTokenAuth, self).__init__(description=description) BaseHTTPTokenAuth.__init__(self, scheme=scheme, realm=realm, header=header) - _AuthErrorMixin.__init__(self) + self.error_handler(handle_auth_error) diff --git a/apiflask/settings.py b/apiflask/settings.py index 41b47ef1..d44c1765 100644 --- a/apiflask/settings.py +++ b/apiflask/settings.py @@ -1,47 +1,52 @@ -from .schemas import http_error_schema, validation_error_schema +from typing import Union, List, Optional, Type, Dict + +from marshmallow import Schema + +from .schemas import http_error_schema +from .schemas import validation_error_schema # OpenAPI fields -DESCRIPTION = None -TAGS = None -CONTACT = None -LICENSE = None -SERVERS = None -EXTERNAL_DOCS = None -TERMS_OF_SERVICE = None -SPEC_FORMAT = 'json' +DESCRIPTION: Optional[str] = None +TAGS: Optional[Union[List[str], List[Dict[str, str]]]] = None +CONTACT: Optional[Dict[str, str]] = None +LICENSE: Optional[Dict[str, str]] = None +SERVERS: Optional[List[Dict[str, str]]] = None +EXTERNAL_DOCS: Optional[Dict[str, str]] = None +TERMS_OF_SERVICE: Optional[str] = None +SPEC_FORMAT: str = 'json' # Automation behaviour control -AUTO_TAGS = True -AUTO_DESCRIPTION = True -AUTO_PATH_SUMMARY = True -AUTO_PATH_DESCRIPTION = True -AUTO_200_RESPONSE = True +AUTO_TAGS: bool = True +AUTO_DESCRIPTION: bool = True +AUTO_PATH_SUMMARY: bool = True +AUTO_PATH_DESCRIPTION: bool = True +AUTO_200_RESPONSE: bool = True # Response customization -DEFAULT_2XX_DESCRIPTION = 'Successful response' -DEFAULT_200_DESCRIPTION = 'Successful response' -DEFAULT_201_DESCRIPTION = 'Resource created' -DEFAULT_204_DESCRIPTION = 'Empty response' -AUTO_VALIDATION_ERROR_RESPONSE = True -VALIDATION_ERROR_STATUS_CODE = 400 -VALIDATION_ERROR_DESCRIPTION = 'Validation error' -VALIDATION_ERROR_SCHEMA = validation_error_schema -AUTO_AUTH_ERROR_RESPONSE = True -AUTH_ERROR_STATUS_CODE = 401 -AUTH_ERROR_DESCRIPTION = 'Authorization error' -AUTH_ERROR_SCHEMA = http_error_schema -AUTO_HTTP_ERROR_RESPONSE = True -HTTP_ERROR_SCHEMA = http_error_schema +DEFAULT_2XX_DESCRIPTION: str = 'Successful response' +DEFAULT_200_DESCRIPTION: str = 'Successful response' +DEFAULT_201_DESCRIPTION: str = 'Resource created' +DEFAULT_204_DESCRIPTION: str = 'Empty response' +AUTO_VALIDATION_ERROR_RESPONSE: bool = True +VALIDATION_ERROR_STATUS_CODE: int = 400 +VALIDATION_ERROR_DESCRIPTION: str = 'Validation error' +VALIDATION_ERROR_SCHEMA: Union[Type[Schema], dict] = validation_error_schema +AUTO_AUTH_ERROR_RESPONSE: bool = True +AUTH_ERROR_STATUS_CODE: int = 401 +AUTH_ERROR_DESCRIPTION: str = 'Authorization error' +AUTH_ERROR_SCHEMA: Union[Type[Schema], dict] = http_error_schema +AUTO_HTTP_ERROR_RESPONSE: bool = True +HTTP_ERROR_SCHEMA: Union[Type[Schema], dict] = http_error_schema # Swagger UI and Redoc -DOCS_HIDE_BLUEPRINTS = [] -DOCS_FAVICON = None -REDOC_USE_GOOGLE_FONT = True -REDOC_STANDALONE_JS = 'https://cdn.jsdelivr.net/npm/redoc@next/bundles/\ +DOCS_HIDE_BLUEPRINTS: List[str] = [] +DOCS_FAVICON: Optional[str] = None +REDOC_USE_GOOGLE_FONT: bool = True +REDOC_STANDALONE_JS: str = 'https://cdn.jsdelivr.net/npm/redoc@next/bundles/\ redoc.standalone.js' -SWAGGER_UI_CSS = 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css' -SWAGGER_UI_BUNDLE_JS = 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/\ +SWAGGER_UI_CSS: str = 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css' +SWAGGER_UI_BUNDLE_JS: str = 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/\ swagger-ui-bundle.js' -SWAGGER_UI_STANDALONE_PRESET_JS = 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/\ +SWAGGER_UI_STANDALONE_PRESET_JS: str = 'https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/\ swagger-ui-standalone-preset.js' -SWAGGER_UI_LAYOUT = 'BaseLayout' -SWAGGER_UI_CONFIG = None -SWAGGER_UI_OAUTH_CONFIG = None +SWAGGER_UI_LAYOUT: str = 'BaseLayout' +SWAGGER_UI_CONFIG: Optional[dict] = None +SWAGGER_UI_OAUTH_CONFIG: Optional[dict] = None diff --git a/apiflask/types.py b/apiflask/types.py new file mode 100644 index 00000000..0828343b --- /dev/null +++ b/apiflask/types.py @@ -0,0 +1,18 @@ +from typing import Any, Callable, TypeVar, Union, Dict, List, Tuple, Mapping + +from flask.wrappers import Response + +DecoratedType = TypeVar('DecoratedType', bound=Callable[..., Any]) +RequestType = TypeVar('RequestType') + +_Body = Union[str, bytes, Dict[str, Any], Response] +_Status = Union[str, int] +_Headers = Union[Dict[Any, Any], List[Tuple[Any, Any]]] +ResponseType = Union[ + _Body, + Tuple[_Body, _Status, _Headers], + Tuple[_Body, _Status], + Tuple[_Body, _Headers] +] +SpecCallbackType = Callable[[Union[dict, str]], Union[dict, str]] +ErrorCallbackType = Callable[[int, str, Any, Mapping[str, str]], ResponseType] diff --git a/apiflask/utils.py b/apiflask/utils.py new file mode 100644 index 00000000..6cdf405c --- /dev/null +++ b/apiflask/utils.py @@ -0,0 +1,50 @@ +from typing import Any + +_sentinel = object() + + +def route_shortcuts(cls): + cls_route = cls.route + + # TODO Remove these shortcuts when pin Flask>=2.0 + def get(self, rule: str, **options: Any): + """Shortcut for ``app.route()``. + + .. versionadded:: 0.2.0 + """ + return cls_route(self, rule, methods=['GET'], **options) + + def post(self, rule: str, **options: Any): + """Shortcut for ``app.route(methods=['POST'])``. + + .. versionadded:: 0.2.0 + """ + return cls_route(self, rule, methods=['POST'], **options) + + def put(self, rule: str, **options: Any): + """Shortcut for ``app.route(methods=['PUT'])``. + + .. versionadded:: 0.2.0 + """ + return cls_route(self, rule, methods=['PUT'], **options) + + def patch(self, rule: str, **options: Any): + """Shortcut for ``app.route(methods=['PATCH'])``. + + .. versionadded:: 0.2.0 + """ + return cls_route(self, rule, methods=['PATCH'], **options) + + def delete(self, rule: str, **options: Any): + """Shortcut for ``app.route(methods=['DELETE'])``. + + .. versionadded:: 0.2.0 + """ + return cls_route(self, rule, methods=['DELETE'], **options) + + cls.get = get + cls.post = post + cls.put = put + cls.patch = patch + cls.delete = delete + return cls diff --git a/setup.cfg b/setup.cfg index 61d90815..2bc681e5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,14 @@ [flake8] -max-line-length = 99 +max-line-length = 100 + +[mypy] +allow_redefinition = True + +[mypy-flask_marshmallow.*] +ignore_missing_imports = True + +[mypy-apispec.*] +ignore_missing_imports = True + +[mypy-flask_httpauth.*] +ignore_missing_imports = True diff --git a/tests/test_scaffold.py b/tests/test_utils.py similarity index 100% rename from tests/test_scaffold.py rename to tests/test_utils.py diff --git a/tox.ini b/tox.ini index d9c9216f..3c6cfb07 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,11 @@ [tox] -envlist=flake8,py37,py38,py39,pypy37,docs +envlist=flake8,py37,py38,py39,pypy37,docs,mypy skip_missing_interpreters=True [gh-actions] python = 3.7: py37 - 3.8: py38 + 3.8: py38, mypy 3.9: py39 pypy3.7: pypy37 @@ -31,3 +31,8 @@ whitelist_externals= mkdocs commands= mkdocs build + +[testenv:mypy] +deps = mypy +commands = + mypy apiflask/