Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,89 +1,166 @@
import typing
import pydantic
import os
import json
from typing import Union, Literal, List, Dict, Any, Type

from pydantic import (
BaseModel,
ValidationError,
constr,
conint,
validator,
root_validator,
)

from ansible_collections.nhsd.apigee.plugins.module_utils.models.apigee.rate_limiting_config import (
RateLimitingConfig,
)

def _literal_name(class_):
# This accesses the 'attribute_name' from
# class class_:
# name: typing.Literal['attribute_name']
return class_.__fields__['name'].type_.__args__[0]

MANUAL_APPROVAL_EXCEPTIONS = ["canary-api-prod"]


class ApigeeProductAttributeRateLimiting(BaseModel):
name: Literal["ratelimiting"]
value: Union[Dict[str, RateLimitingConfig], str]

@validator("value")
def validate_ratelimiting(
cls, ratelimiting: Union[Dict[str, RateLimitingConfig], str]
) -> str:
"""
Apigee API requires a string. We decode it as JSON in the
shared flow.

class ApigeeProductAttributeAccess(pydantic.BaseModel):
name: typing.Literal["access"]
value: typing.Literal["public", "private"]
So if pydantic has happily parsed this into a
Dict[str,RateLimitingConfig], then json dump it.

Otherwise, if we've gotten a string (e.g. by calling the
Apigee API) check the schema is valid using the pydantic
models.

class ApigeeProductAttributeRateLimit(pydantic.BaseModel):
name: typing.Literal["ratelimit"]
value: pydantic.constr(regex=r"^[0-9]+(ps|pm)$")
Running strings through a JSON parser will also 'normalize'
the JSON string, so whitespace and key order doesn't matter
for diffs.
"""
error_msg = f"Malformed 'ratelimiting' attribute: {ratelimiting}"

if isinstance(ratelimiting, str):
# If we have a string, run it through Pydantic by hand.
try:
ratelimiting_dict = json.loads(ratelimiting)
for key, value in ratelimiting_dict.items():
ratelimiting_dict[key] = RateLimitingConfig(**value)
ratelimiting = ratelimiting_dict
except (ValidationError, json.JSONDecodeError):
raise ValueError(error_msg)

# Apigee enforces these must be strings, so do a nicely sorted
# JSON dump.
ratelimiting_dict = {}
for proxy_name, config in ratelimiting.items():
ratelimiting_dict[proxy_name] = config.dict()
ratelimiting = json.dumps(ratelimiting_dict, sort_keys=True)
return ratelimiting


class ApigeeProductAttributeAccess(BaseModel):
name: Literal["access"]
value: Literal["public", "private"]


class ApigeeProductAttributeRateLimit(BaseModel):
name: Literal["ratelimit"]
value: constr(regex=r"^[0-9]+(ps|pm)$")


def _literal_name(class_):
# This accesses the 'attribute_name' from
# class class_:
# name: Literal['attribute_name']
return class_.__fields__["name"].type_.__args__[0]


# This ensures that a generic ApigeeProductAttribute can't be
# constructed from a more specific one that fails valiation.
# constructed from a more specific one that fails valiation. Sadly
# the pydantic error message is a mess, e.g. if you pass in
# 'ratelimiting' with invalid JSON, the error messages will tell you
# you failed validation for all our customized ApigeeProductAttribute
# types.
PRODUCT_ATTRIBUTE_REGEX = (
"^(?!("
+ "|".join(
_literal_name(c)
for c in [
ApigeeProductAttributeAccess,
ApigeeProductAttributeRateLimit,
ApigeeProductAttributeRateLimiting,
]
)
+ ")$)"
)


class ApigeeProductAttribute(pydantic.BaseModel):
name: pydantic.constr(regex=PRODUCT_ATTRIBUTE_REGEX)
class ApigeeProductAttribute(BaseModel):
name: constr(regex=PRODUCT_ATTRIBUTE_REGEX)
value: str


class ApigeeProduct(pydantic.BaseModel):
def _count_cls(items: List[Any], cls: Type):
return sum(isinstance(item, cls) for item in items)


class ApigeeProduct(BaseModel):
name: str
approvalType: typing.Literal["auto", "manual"]
attributes: typing.List[
typing.Union[
ApigeeProductAttributeAccess,
ApigeeProductAttributeRateLimit,
ApigeeProductAttribute,
],
]
approvalType: Literal["auto", "manual"]
attributes: List[
Union[
ApigeeProductAttributeAccess,
ApigeeProductAttributeRateLimit,
ApigeeProductAttributeRateLimiting,
ApigeeProductAttribute,
],
]
description: str
displayName: str
environments: typing.List[str]
proxies: typing.List[str]
quota: str
quotaInterval: str
quotaTimeUnit: typing.Literal["minute", "hour"]
scopes: typing.List[str]

@pydantic.root_validator
environments: List[str]
proxies: List[str]
quota: constr(regex=r"[1-9][0-9]*")
quotaInterval: constr(regex=r"[1-9][0-9]*")
quotaTimeUnit: Literal["minute", "hour"]
scopes: List[str]

@root_validator
def override_approval_type_for_prod(cls, values):
manual_approval_exceptions = ["canary-api-prod"]
if "prod" in values["environments"]:
if values["approvalType"] == "auto" and not values["name"] in manual_approval_exceptions:
values["approvalType"] = "manual"
name = values["name"]
environments = values["environments"]
if "prod" in environments and name not in MANUAL_APPROVAL_EXCEPTIONS:
values["approvalType"] = "manual"
return values

@pydantic.validator("environments", "scopes", "proxies")
@validator("environments", "scopes", "proxies")
def sorted(cls, v):
return sorted(v)

@pydantic.validator("attributes")
@validator("attributes")
def validate_attributes(cls, attributes, values):
attributes = sorted(attributes, key=lambda a: a.name)

for class_ in [
ApigeeProductAttributeAccess,
ApigeeProductAttributeRateLimit,
]:
attrs = [a for a in attributes if isinstance(a, class_)]
if len(attrs) != 1:
class_min_max = [
(ApigeeProductAttributeAccess, 1, 1),
(ApigeeProductAttributeRateLimit, 1, 1),
(ApigeeProductAttributeRateLimiting, 0, 1),
]

for _class, _min, _max in class_min_max:
count = _count_cls(attributes, _class)
if count < _min or count > _max:
if _min == _max:
count_msg = f"exactly {_min}"
else:
count_msg = f"between {_min} and {_max}"
raise AssertionError(
f"Product {values['name']} must contain exactly 1 "
+ f"attribute with name: '{_literal_name(class_)}', "
+ f"found {len(attrs)}"
f"Product {values['name']} must contain {count_msg} "
+ f"'{_literal_name(_class)}' attributes , "
+ f"your product has {count}."
)

return attributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Pydantic class for the rateliming config JSON, attached to products
and apps to control the ApplyRateLimiting shared flow.
"""
from typing import Literal

from pydantic import BaseModel, conint, constr, Extra


class ExcludeNoneModel(BaseModel):

"""
Providing default values for ratelimiting here would mean that
changing defaults required a redeploy for all proxies.

Therefore we set None as the default value on all
RateLimitingConfig attributes, and *do not* export them as JSON.

The platform defaults are used to fill in the missing values
inside the ApplyRateLimiting shared flow. This pattern us to
update the defaults for everyone by just by updating the shared
flow.
"""
def dict(self, **kwargs):
kwargs["exclude_none"] = True
return super().dict(**kwargs)

class Config:
extra=Extra.forbid


class QuotaConfig(ExcludeNoneModel):
enabled: bool = None
interval: conint(gt=0) = None
limit: conint(gt=0) = None
timeunit: Literal["minute", "hour"] = None


class SpikeArrestConfig(ExcludeNoneModel):
enabled: bool = None
ratelimit: constr(regex=r"^[1-9][0-9]*(ps|pm)$") = None


class RateLimitingConfig(ExcludeNoneModel):
quota: QuotaConfig = None
spikeArrest: SpikeArrestConfig = None
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ansible_collections.nhsd.apigee.plugins.module_utils.paas import api_registry

SCHEMA_VERSION = "1.1.1"
SCHEMA_VERSION = "1.1.2"

_REGISTRY_DATA = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def invalid_guid():
def mock_api_registry(monkeypatch):
def _mock_api_registry_get(name: str):
if name == CANARY_API["name"]:
print("HELLO")
return CANARY_API
else:
raise ValueError(f"No API named {name} found.")
Expand Down
Loading