Skip to content

Commit

Permalink
Merge pull request #130 from O1dLiu/fix_model_name_unique
Browse files Browse the repository at this point in the history
Fix form model overwritten
  • Loading branch information
kemingy committed May 13, 2021
2 parents cbee2a6 + 887fd24 commit 36c4d18
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 13 deletions.
3 changes: 2 additions & 1 deletion spectree/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,12 @@ def generate_spec(self):
responses[parse_code(code)] = {"description": DEFAULT_CODE_DESC[code]}

for code, model in self.code_models.items():
model_name = f"{model.__module__}.{model.__name__}"
responses[parse_code(code)] = {
"description": DEFAULT_CODE_DESC[code],
"content": {
"application/json": {
"schema": {"$ref": f"#/components/schemas/{model.__name__}"}
"schema": {"$ref": f"#/components/schemas/{model_name}"}
}
},
}
Expand Down
8 changes: 5 additions & 3 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,16 @@ async def async_validate(*args, **kwargs):
):
if model is not None:
assert issubclass(model, BaseModel)
self.models[model.__name__] = model.schema(
model_key = f"{model.__module__}.{model.__name__}"
self.models[model_key] = model.schema(
ref_template="#/components/schemas/{model}"
)
setattr(validation, name, model.__name__)
setattr(validation, name, model_key)

if resp:
for model in resp.models:
self.models[model.__name__] = model.schema(
model_key = f"{model.__module__}.{model.__name__}"
self.models[model_key] = model.schema(
ref_template="#/components/schemas/{model}"
)
validation.resp = resp
Expand Down
6 changes: 4 additions & 2 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
)
def test_plugin_spec(api):
models = {
m.__name__: m.schema(ref_template="#/components/schemas/{model}")
f"{m.__module__}.{m.__name__}": m.schema(
ref_template="#/components/schemas/{model}"
)
for m in (Query, JSON, Resp, Cookies, Headers)
}
for name, schema in models.items():
Expand Down Expand Up @@ -49,7 +51,7 @@ def test_plugin_spec(api):
assert user["tags"] == ["API", "test"]
assert (
user["requestBody"]["content"]["application/json"]["schema"]["$ref"]
== "#/components/schemas/JSON"
== "#/components/schemas/tests.common.JSON"
)
assert len(user["responses"]) == 3

Expand Down
4 changes: 2 additions & 2 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def test_response_spec():
assert spec["422"]["description"] == DEFAULT_CODE_DESC["HTTP_422"]
assert (
spec["201"]["content"]["application/json"]["schema"]["$ref"].split("/")[-1]
== "DemoModel"
== "tests.common.DemoModel"
)
assert (
spec["422"]["content"]["application/json"]["schema"]["$ref"].split("/")[-1]
== "UnprocessableEntity"
== "spectree.models.UnprocessableEntity"
)

assert spec.get(200) is None
Expand Down
14 changes: 9 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,27 @@ def test_parse_resp():
assert resp_spec["422"]["description"] == "Unprocessable Entity"
assert (
resp_spec["422"]["content"]["application/json"]["schema"]["$ref"]
== "#/components/schemas/UnprocessableEntity"
== "#/components/schemas/spectree.models.UnprocessableEntity"
)
assert (
resp_spec["200"]["content"]["application/json"]["schema"]["$ref"]
== "#/components/schemas/DemoModel"
== "#/components/schemas/tests.common.DemoModel"
)


def test_parse_request():
assert (
parse_request(demo_func)["content"]["application/json"]["schema"]["$ref"]
== "#/components/schemas/DemoModel"
== "#/components/schemas/tests.common.DemoModel"
)
assert parse_request(demo_class.demo_method) == {}


def test_parse_params():
models = {
"DemoModel": DemoModel.schema(ref_template="#/components/schemas/{model}")
"tests.common.DemoModel": DemoModel.schema(
ref_template="#/components/schemas/{model}"
)
}
assert parse_params(demo_func, [], models) == []
params = parse_params(demo_class.demo_method, [], models)
Expand All @@ -120,7 +122,9 @@ def test_parse_params():

def test_parse_params_with_route_param_keywords():
models = {
"DemoQuery": DemoQuery.schema(ref_template="#/components/schemas/{model}")
"tests.common.DemoQuery": DemoQuery.schema(
ref_template="#/components/schemas/{model}"
)
}
params = parse_params(demo_func_with_query, [], models)
assert params == [
Expand Down

0 comments on commit 36c4d18

Please sign in to comment.