diff --git a/fastapi_router_controller/__init__.py b/fastapi_router_controller/__init__.py index e2ec494..c8c901a 100644 --- a/fastapi_router_controller/__init__.py +++ b/fastapi_router_controller/__init__.py @@ -1,7 +1,5 @@ """FastAPI Router Contoller, FastAPI utility to allow Controller Class usage""" -__version__ = "0.1.0" - from fastapi_router_controller.lib.controller import Controller as Controller from fastapi_router_controller.lib.controller import OPEN_API_TAGS as ControllersTags from fastapi_router_controller.lib.controller_loader import ControllerLoader as ControllerLoader diff --git a/fastapi_router_controller/lib/controller.py b/fastapi_router_controller/lib/controller.py index e0ca476..dc05a8d 100644 --- a/fastapi_router_controller/lib/controller.py +++ b/fastapi_router_controller/lib/controller.py @@ -1,6 +1,7 @@ import inspect -import copy +from copy import deepcopy from fastapi import APIRouter, Depends +from fastapi_router_controller.lib.exceptions import MultipleResourceException, MultipleRouterException OPEN_API_TAGS = [] __app_controllers__ = [] @@ -35,22 +36,29 @@ class Controller: It expose some utilities and decorator functions to define a router controller class """ - - RC_KEY = "__router__" - SIGNATURE_KEY = "__signature__" + RC_KEY = '__router__' + SIGNATURE_KEY = '__signature__' + HAS_CONTROLLER_KEY = '__has_controller__' + RESOURCE_CLASS_KEY = '__resource_cls__' def __init__(self, router: APIRouter, openapi_tag: dict = None) -> None: """ :param router: The FastApi router to link to the Class :param openapi_tag: An openapi object that will describe your routes in the openapi tamplate """ - self.router = copy.deepcopy(router) + # Each Controller must be linked to one fastapi router + if hasattr(router, Controller.HAS_CONTROLLER_KEY): + raise MultipleRouterException() + + self.router = deepcopy(router) self.openapi_tag = openapi_tag self.cls = None if openapi_tag: OPEN_API_TAGS.append(openapi_tag) + setattr(router, Controller.HAS_CONTROLLER_KEY, True) + def __get_parent_routes(self, router: APIRouter): """ Private utility to get routes from an extended class @@ -65,14 +73,21 @@ def __get_parent_routes(self, router: APIRouter): self.router.add_api_route(route.path, route.endpoint, **options) def add_resource(self, cls): - if self.cls and cls != self.cls: - raise Exception("Every controller needs its own router!") - self.cls = cls - # check if cls was extended from another Controller + ''' + Mark a class as Controller Resource + ''' + # check if the same controller was already used for another cls (Resource) + if hasattr(self, Controller.RESOURCE_CLASS_KEY) and getattr(self, Controller.RESOURCE_CLASS_KEY) != cls: + raise MultipleResourceException() + + # check if cls (Resource) was exteded from another if hasattr(cls, Controller.RC_KEY): self.__get_parent_routes(cls.__router__) - cls.__router__ = self.router + + setattr(cls, Controller.RC_KEY, self.router) + setattr(self, Controller.RESOURCE_CLASS_KEY, cls) cls.router = lambda: Controller.__parse_controller_router(cls) + return cls def resource(self): @@ -101,28 +116,31 @@ def __parse_controller_router(cls): dependencies = None if hasattr(cls, "dependencies"): - dependencies = copy.deepcopy(cls.dependencies) + dependencies = deepcopy(cls.dependencies) delattr(cls, "dependencies") for route in router.routes: - # get the signature of the endpoint function - signature = inspect.signature(route.endpoint) - # get the parameters of the endpoint function - signature_parameters = list(signature.parameters.values()) # add class dependencies if dependencies: for depends in dependencies[::-1]: route.dependencies.insert(0, depends) + + # get the signature of the endpoint function + signature = inspect.signature(route.endpoint) + # get the parameters of the endpoint function + signature_parameters = list(signature.parameters.values()) # replace the class instance with the itself FastApi Dependecy signature_parameters[0] = signature_parameters[0].replace( default=Depends(cls) ) + # set self and after it the keyword args new_parameters = [signature_parameters[0]] + [ parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in signature_parameters[1:] ] + new_signature = signature.replace(parameters=new_parameters) setattr(route.endpoint, Controller.SIGNATURE_KEY, new_signature) diff --git a/fastapi_router_controller/lib/exceptions.py b/fastapi_router_controller/lib/exceptions.py new file mode 100644 index 0000000..112d572 --- /dev/null +++ b/fastapi_router_controller/lib/exceptions.py @@ -0,0 +1,7 @@ +class MultipleRouterException(Exception): + def __init__(self): + super().__init__('Router already used by another Controller') + +class MultipleResourceException(Exception): + def __init__(self): + super().__init__('Controller already used by another Resource') diff --git a/tests/test_controller.py b/tests/test_controller.py index c82f432..4785448 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -87,22 +87,19 @@ def root( id += self.x.create() return SampleObject(id=id) - def hello(self, f: Filter, y=Depends(get_y)): + def hello(self, + f: Filter, y=Depends(get_y) + ): _id = f.foo _id += y _id += self.x.create() return SampleObject(id=_id) - app = FastAPI( - title="A sample application using fastapi_router_controller", - version="0.1.0", - openapi_tags=ControllersTags, - ) - router = APIRouter() controller = Controller(router, openapi_tag={"name": "sample_controller"}) controller.add_resource(SampleController) + controller.route.add_api_route( "/", SampleController.root, @@ -111,9 +108,20 @@ def hello(self, f: Filter, y=Depends(get_y)): response_model=SampleObject, methods=["GET"], ) + controller.route.add_api_route( - "/hello", SampleController.hello, response_model=SampleObject, methods=["POST"] + "/hello", + SampleController.hello, + response_model=SampleObject, + methods=["POST"] + ) + + app = FastAPI( + title="A sample application using fastapi_router_controller", + version="0.1.0", + openapi_tags=ControllersTags, ) + app.include_router(SampleController.router()) return app diff --git a/tests/test_inherit.py b/tests/test_inherit.py index e76ebda..e1cd282 100644 --- a/tests/test_inherit.py +++ b/tests/test_inherit.py @@ -92,4 +92,4 @@ def test_invalid(self): class Controller2(Base): ... - self.assertEqual(str(ex.exception), "Every controller needs its own router!") + self.assertEqual(str(ex.exception), "Controller already used by another Resource")