Skip to content

Commit

Permalink
Support passing APISpec object to spec processor
Browse files Browse the repository at this point in the history
Add a `SPEC_PROCESSOR_PASS_OBJECT` to control the argument type of the spec processor.
  • Loading branch information
greyli committed Mar 12, 2023
1 parent 0faff1f commit 2702b68
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 16 deletions.
10 changes: 10 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@

## Version 1.3.0

- Add config `SPEC_PROCESSOR_PASS_OBJECT` to control the argument type of
spec processor. The `spec` argument will be an `apispec.APISpec` object
when this config is `True` ([issue #213][issue_213]).

[issue_213]: https://github.com/apiflask/apiflask/issues/213


## Version 1.2.4

Released: -
Expand Down
18 changes: 18 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,24 @@ app.config['YAML_SPEC_MIMETYPE'] = 'text/x-yaml'
This configuration variable was added in the [version 0.4.0](/changelog/#version-040).


### SPEC_PROCESSOR_PASS_OBJECT

If `True`, the `spec` argument passed to the spec processor will be an
[`apispec.APISpec`](https://apispec.readthedocs.io/en/latest/api_core.html#apispec.APISpec) object.

- Type: `bool`
- Default value: `False`
- Examples:

```python
app.config['SPEC_PROCESSOR_PASS_OBJECT'] = True
```

!!! warning "Version >= 1.3.0"

This configuration variable was added in the [version 1.3.0](/changelog/#version-130).


## Automation behavior control

The following configuration variables are used to control the automation behavior
Expand Down
22 changes: 18 additions & 4 deletions docs/openapi.md
Original file line number Diff line number Diff line change
Expand Up @@ -920,11 +920,25 @@ def update_spec(spec):
return spec
```

Notice the format of the spec depends on the value of the configuration
variable `SPEC_FORMAT` (defaults to `'json'`):
By default, the `spec` argument is a dict. When the `SPEC_PROCESSOR_PASS_OBJECT` config is
`True`, the `spec` argument will be an
[`apispec.APISpec`](https://apispec.readthedocs.io/en/latest/api_core.html#apispec.APISpec) object.

- `'json'` -> dict
- `'yaml'` -> string
```python
from apiflask import APIFlask

app = APIFlask(__name__)
app.config['SPEC_PROCESSOR_PASS_OBJECT'] = True

class FooSchema(Schema):
id = Integer()

@app.spec_processor
def update_spec(spec):
spec.title = 'Updated Title'
spec.components.schema('Foo', schema=FooSchema) # add a schema manually
return spec
```

Check out [the example application](https://github.com/apiflask/apiflask/tree/main/examples/openapi/app.py)
for OpenAPI support, see [the examples page](/examples) for running the example application.
24 changes: 19 additions & 5 deletions src/apiflask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from apispec import APISpec
from apispec import BasePlugin
from apispec.ext.marshmallow import MarshmallowPlugin
from apispec.yaml_utils import dict_to_yaml
from flask import Blueprint
from flask import Flask
from flask import has_request_context
Expand Down Expand Up @@ -680,16 +681,29 @@ def _get_spec(
- Rename the method name to `_get_spec`.
- Add the `force_update` parameter.
*Version changed: 1.3.0*
- Add the `SPEC_PROCESSOR_PASS_OBJECT` config to control the argument type
when calling the spec processor.
"""
if spec_format is None:
spec_format = self.config['SPEC_FORMAT']
if self._spec is None or force_update:
if spec_format == 'json':
self._spec = self._generate_spec().to_dict()
else:
self._spec = self._generate_spec().to_yaml()
spec_object: APISpec = self._generate_spec()
if self.spec_callback:
self._spec = self.spec_callback(self._spec) # type: ignore
if self.config['SPEC_PROCESSOR_PASS_OBJECT']:
self._spec = self.spec_callback(
spec_object # type: ignore
).to_dict()
else:
self._spec = self.spec_callback(
spec_object.to_dict()
)
else:
self._spec = spec_object.to_dict()
if spec_format in ['yml', 'yaml']:
self._spec = dict_to_yaml(self._spec) # type: ignore
# sync local spec
if self.config['SYNC_LOCAL_SPEC']:
spec_path = self.config['LOCAL_SPEC_PATH']
Expand Down
4 changes: 4 additions & 0 deletions src/apiflask/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
LOCAL_SPEC_PATH: t.Optional[str] = None
LOCAL_SPEC_JSON_INDENT: int = 2
SYNC_LOCAL_SPEC: t.Optional[bool] = None
SPEC_PROCESSOR_PASS_OBJECT: bool = False
# Automation behavior control
AUTO_TAGS: bool = True
AUTO_SERVERS: bool = True
Expand Down Expand Up @@ -71,3 +72,6 @@

# Version changed: 1.2.0
# Change VALIDATION_ERROR_STATUS_CODE from 400 to 422.

# Version added: 1.3.0
# SPEC_PROCESSOR_PASS_OBJECT
34 changes: 27 additions & 7 deletions tests/test_openapi_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .schemas import Baz
from .schemas import Foo
from apiflask import APIBlueprint
from apiflask import Schema as BaseSchema
from apiflask import Schema
from apiflask.commands import spec_command
from apiflask.fields import Integer

Expand All @@ -35,6 +35,26 @@ def edit_spec(spec):
assert rv.json['info']['title'] == 'Foo'


def test_spec_processor_pass_object(app, client):
app.config['SPEC_PROCESSOR_PASS_OBJECT'] = True

class NotUsedSchema(Schema):
id = Integer()

@app.spec_processor
def process_spec(spec):
spec.title = 'Foo'
spec.components.schema('NotUsed', schema=NotUsedSchema)
return spec

rv = client.get('/openapi.json')
assert rv.status_code == 200
validate_spec(rv.json)
assert rv.json['info']['title'] == 'Foo'
assert 'NotUsed' in rv.json['components']['schemas']
assert 'id' in rv.json['components']['schemas']['NotUsed']['properties']


@pytest.mark.parametrize('spec_format', ['json', 'yaml', 'yml'])
def test_get_spec(app, spec_format):
spec = app._get_spec(spec_format)
Expand Down Expand Up @@ -114,20 +134,20 @@ def bar():
def baz():
pass

class Spam(BaseSchema):
class Spam(Schema):
id = Integer()

@app.route('/spam')
@app.output(Spam)
def spam():
pass

class Schema(BaseSchema):
class Ham(Schema):
id = Integer()

@app.route('/schema')
@app.output(Schema)
def schema():
@app.route('/ham')
@app.output(Ham)
def ham():
pass

spec = app.spec
Expand All @@ -136,7 +156,7 @@ def schema():
assert 'Bar' in spec['components']['schemas']
assert 'Baz' in spec['components']['schemas']
assert 'Spam' in spec['components']['schemas']
assert 'Schema' in spec['components']['schemas']
assert 'Ham' in spec['components']['schemas']


def test_servers_and_externaldocs(app):
Expand Down

0 comments on commit 2702b68

Please sign in to comment.