Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Authentication and Authorization logic. #127

Merged
merged 19 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions spectree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging

from .models import Tag
from .models import SecurityScheme, Tag
from .response import Response
from .spec import SpecTree

__all__ = ["SpecTree", "Response", "Tag"]
__all__ = ["SpecTree", "Response", "Tag", "SecurityScheme"]

# setup library logging
logging.getLogger(__name__).addHandler(logging.NullHandler())
6 changes: 6 additions & 0 deletions spectree/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import logging
from typing import List, Optional

from .models import SecurityScheme


class Config:
Expand All @@ -11,6 +14,7 @@ class Config:
:ivar TITLE: service name
:ivar VERSION: service version
:ivar DOMAIN: service host domain
:ivar SECURITY_SCHEMES: OpenAPI `securitySchemes` JSON with list of auth configs
"""

def __init__(self, **kwargs):
Expand All @@ -28,6 +32,8 @@ def __init__(self, **kwargs):
self.VERSION = "0.1"
self.DOMAIN = None

self.SECURITY_SCHEMES: Optional[List[SecurityScheme]] = None

self.logger = logging.getLogger(__name__)

self.update(**kwargs)
Expand Down
96 changes: 95 additions & 1 deletion spectree/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import re
from enum import Enum
from typing import Any, Dict, Sequence

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, root_validator, validator

# OpenAPI names validation regexp
OpenAPI_NAME_RE = re.compile(r"^[A-Za-z0-9-._]+")


class ExternalDocs(BaseModel):
Expand Down Expand Up @@ -44,3 +49,92 @@ class UnprocessableEntity(BaseModel):
"""Model of 422 Unprocessable Entity error."""

__root__: Sequence[UnprocessableEntityElement]


class SecureType(str, Enum):
HTTP = "http"
API_KEY = "apiKey"
OAUTH_TWO = "oauth2"
OPEN_ID_CONNECT = "openIdConnect"


class InType(str, Enum):
HEADER = "header"
QUERY = "query"
COOKIE = "cookie"


type_req_fields = {
SecureType.HTTP: ["scheme"],
SecureType.API_KEY: ["name", "field_in"],
SecureType.OAUTH_TWO: ["flows"],
SecureType.OPEN_ID_CONNECT: ["openIdConnectUrl"],
}


class SecuritySchemeData(BaseModel):
"""
Security scheme data
https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.3.md#securitySchemeObject
"""

type: SecureType = Field(..., description="Secure scheme type")
description: str = Field(
None,
description="A short description for security scheme.",
)
name: str = Field(
None,
description="The name of the header, query or cookie parameter to be used.",
)
field_in: InType = Field(
None, alias="in", description="The location of the API key."
)
scheme: str = Field(None, description="The name of the HTTP Authorization scheme.")
bearerFormat: dict = Field(
None,
description="A hint to the client to identify how the bearer token is formatted.",
)
flows: dict = Field(
None,
description="Containing configuration information for the flow types supported.",
)
openIdConnectUrl: str = Field(
None, description="OpenId Connect URL to discover OAuth2 configuration values."
)

@root_validator()
def check_type_required_fields(cls, values: dict):
exist_fields = {key for key in values.keys() if values[key]}
if not values.get("type"):
raise ValueError("Type field is required")

if not set(type_req_fields[values["type"]]).issubset(exist_fields):
raise ValueError(
f"For `{values['type']}` type `{', '.join(type_req_fields[values['type']])}` field(s) is required."
)
return values

class Config:
validate_assignment = True


class SecurityScheme(BaseModel):
"""
Named security scheme
"""

name: str = Field(
...,
description="Custom security scheme name. Can only contain - [A-Za-z0-9-._]",
)
AndreiDrang marked this conversation as resolved.
Show resolved Hide resolved
data: SecuritySchemeData = Field(..., description="Security scheme data")

@validator("name")
def check_name(cls, value: str):
if not OpenAPI_NAME_RE.fullmatch(value):
raise ValueError("Name not match OpenAPI rules")
return value

class Config:
validate_assignment = True
16 changes: 15 additions & 1 deletion spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,23 @@ def validate(
cookies=None,
resp=None,
tags=(),
security=None,
before=None,
after=None,
):
"""
- validate query, json, headers in request
- validate response body and status code
- add tags to this API route
- add security to this API route

:param query: `pydantic.BaseModel`, query in uri like `?name=value`
:param json: `pydantic.BaseModel`, JSON format request body
:param headers: `pydantic.BaseModel`, if you have specific headers
:param cookies: `pydantic.BaseModel`, if you have cookies for this route
:param resp: `spectree.Response`
:param tags: a tuple of strings or :class:`spectree.models.Tag`
:param security: dict with security config for current route and method
:param before: :meth:`spectree.utils.default_before_handler` for
specific endpoint
:param after: :meth:`spectree.utils.default_after_handler` for
Expand Down Expand Up @@ -187,6 +190,8 @@ async def async_validate(*args, **kwargs):
if tags:
validation.tags = tags

# if security not exist - set empty dict
validation.security = security or {}
# register decorator
validation._decorator = self
return validation
Expand Down Expand Up @@ -219,6 +224,7 @@ def _generate_spec(self):
"operationId": f"{method.lower()}_{path}",
"description": desc or "",
"tags": [str(x) for x in getattr(func, "tags", ())],
"security": [getattr(func, "security", {})],
"parameters": parse_params(func, parameters[:], self.models),
"responses": parse_resp(func),
}
Expand All @@ -236,7 +242,15 @@ def _generate_spec(self):
},
"tags": list(tags.values()),
"paths": {**routes},
"components": {"schemas": {**self.models, **self._get_model_definitions()}},
"components": {
"schemas": {**self.models, **self._get_model_definitions()},
"securitySchemes": {
scheme.name: scheme.data.dict(exclude_none=True, by_alias=True)
for scheme in self.config.SECURITY_SCHEMES
}
if self.config.SECURITY_SCHEMES
else None,
},
}
return spec

Expand Down
64 changes: 63 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, Field, root_validator

from spectree import Tag
from spectree import SecurityScheme, Tag

api_tag = Tag(name="API", description="🐱", externalDocs={"url": "https://pypi.org"})

Expand Down Expand Up @@ -67,3 +67,65 @@ def get_paths(spec):

paths.sort()
return paths


# data from example - https://swagger.io/docs/specification/authentication/
SECURITY_SCHEMAS = [
SecurityScheme(
name="auth_apiKey",
data={"type": "apiKey", "name": "Authorization", "in": "header"},
),
SecurityScheme(name="auth_BasicAuth", data={"type": "http", "scheme": "basic"}),
SecurityScheme(name="auth_BearerAuth", data={"type": "http", "scheme": "bearer"}),
SecurityScheme(
name="auth_openID",
data={
"type": "openIdConnect",
"openIdConnectUrl": "https://example.com/.well-known/openid-configuration",
},
),
SecurityScheme(
name="auth_oauth2",
data={
"type": "oauth2",
"flows": {
"authorizationCode": {
"authorizationUrl": "https://example.com/oauth/authorize",
"tokenUrl": "https://example.com/oauth/token",
"scopes": {
"read": "Grants read access",
"write": "Grants write access",
"admin": "Grants access to admin operations",
},
},
},
},
),
]
WRONG_SECURITY_SCHEMAS_DATA = [
{
"name": "auth_apiKey_name",
"data": {"type": "apiKey", "name": "Authorization"},
},
{
"name": "auth_apiKey_in",
"data": {"type": "apiKey", "in": "header"},
},
{
"name": "auth_BasicAuth_scheme",
"data": {"type": "http"},
},
{
"name": "auth_openID_openIdConnectUrl",
"data": {"type": "openIdConnect"},
},
{
"name": "auth_oauth2_flows",
"data": {"type": "oauth2"},
},
{
"name": "empty_Data",
"data": {},
},
{"name": "wrong_Data", "data": {"x": "y"}},
]
49 changes: 49 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import string

import pytest
from pydantic import ValidationError

from spectree import SecurityScheme
from spectree.config import Config

from .common import SECURITY_SCHEMAS, WRONG_SECURITY_SCHEMAS_DATA


@pytest.fixture
def config():
Expand All @@ -16,6 +22,7 @@ def test_update_config(config):
assert config.FILENAME == default.FILENAME
assert config.TITLE == "demo"
assert config.VERSION == "latest"
assert config.SECURITY_SCHEMES is None

config.update(unknown="missing")
with pytest.raises(AttributeError):
Expand All @@ -38,3 +45,45 @@ def test_update_mode(config):
with pytest.raises(AssertionError) as e:
config.update(mode="true")
assert "MODE" in str(e.value)


@pytest.mark.parametrize(("secure_item"), SECURITY_SCHEMAS)
def test_update_security_scheme(config, secure_item: SecurityScheme):
# update and validate each schema type
config.update(security_schemes={secure_item.name: secure_item.data})
assert config.SECURITY_SCHEMES == {secure_item.name: secure_item.data}


def test_update_security_schemes(config):
# update and validate ALL schemas types
config.update(security_schemes=SECURITY_SCHEMAS)
assert config.SECURITY_SCHEMES == SECURITY_SCHEMAS


@pytest.mark.parametrize(("secure_item"), SECURITY_SCHEMAS)
def test_update_security_scheme_wrong_type(config, secure_item: SecurityScheme):
# update and validate each schema type
with pytest.raises(ValidationError):
secure_item.data.type += "_wrong"


@pytest.mark.parametrize(
"symbol", [symb for symb in string.punctuation if symb not in "-._"]
)
@pytest.mark.parametrize(("secure_item"), SECURITY_SCHEMAS)
def test_update_security_scheme_wrong_name(
config, secure_item: SecurityScheme, symbol: str
):
# update and validate each schema name
with pytest.raises(ValidationError):
secure_item.name += symbol

with pytest.raises(ValidationError):
secure_item.name = symbol + secure_item.name


@pytest.mark.parametrize(("secure_item"), WRONG_SECURITY_SCHEMAS_DATA)
def test_update_security_scheme_wrong_data(config, secure_item: dict):
# update and validate each schema type
with pytest.raises(ValidationError):
SecurityScheme(**secure_item)
Loading