Skip to content

Commit

Permalink
Fix openapi >=3 versions compatibility (#91)
Browse files Browse the repository at this point in the history
Co-authored-by: Keming <kemingy94@gmail.com>
Co-authored-by: cgaunet <cyrilg@theodo.fr>
  • Loading branch information
3 people committed Dec 11, 2020
1 parent ee2e8b0 commit f37afe9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
11 changes: 7 additions & 4 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,16 @@ async def async_validate(*args, **kwargs):
):
if model is not None:
assert issubclass(model, BaseModel)
self.models[model.__name__] = model.schema()
self.models[model.__name__] = model.schema(
ref_template="#/components/schemas/{model}"
)
setattr(validation, name, model.__name__)

if resp:
for model in resp.models:
self.models[model.__name__] = model.schema()
self.models[model.__name__] = model.schema(
ref_template="#/components/schemas/{model}"
)
validation.resp = resp

if tags:
Expand Down Expand Up @@ -214,8 +218,7 @@ def _generate_spec(self):
},
"tags": list(tags.values()),
"paths": {**routes},
"components": {"schemas": {**self.models}},
"definitions": self._get_model_definitions(),
"components": {"schemas": {**self.models, **self._get_model_definitions()}},
}
return spec

Expand Down
5 changes: 4 additions & 1 deletion tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@

@pytest.mark.parametrize("api", [flask_api, falcon_api, starlette_api])
def test_plugin_spec(api):
models = {m.__name__: m.schema() for m in (Query, JSON, Resp, Cookies, Headers)}
models = {
m.__name__: m.schema(ref_template="#/components/schemas/{model}")
for m in (Query, JSON, Resp, Cookies, Headers)
}
for name, schema in models.items():
assert api.spec["components"]["schemas"][name] == schema

Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def test_parse_request():


def test_parse_params():
models = {"DemoModel": DemoModel.schema()}
models = {
"DemoModel": DemoModel.schema(ref_template="#/components/schemas/{model}")
}
assert parse_params(demo_func, [], models) == []
params = parse_params(demo_class.demo_method, [], models)
assert len(params) == 3
Expand Down

0 comments on commit f37afe9

Please sign in to comment.