Skip to content

Commit

Permalink
Pycturator should allow disabling specifc endpoints
Browse files Browse the repository at this point in the history
In order to disable pyctuator endpoints such as `/pyctuator/httptraces`, pyctuator now
can be initialized with flags selecting endpoints that shouldn't be registered in SBA and
should fail if queried (GET) directly.

Fixes #92
  • Loading branch information
michael.yak committed Jan 16, 2024
1 parent bb04e43 commit c05e36d
Show file tree
Hide file tree
Showing 13 changed files with 321 additions and 169 deletions.
13 changes: 13 additions & 0 deletions pyctuator/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from enum import Flag, auto


class Endpoints(Flag):
NONE = 0
ENV = auto()
INFO = auto()
HEALTH = auto()
METRICS = auto()
LOGGERS = auto()
THREAD_DUMP = auto()
LOGFILE = auto()
HTTP_TRACE = auto()
129 changes: 70 additions & 59 deletions pyctuator/impl/fastapi_pyctuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette.requests import Request
from starlette.responses import Response

from pyctuator.endpoints import Endpoints
from pyctuator.environment.environment_provider import EnvironmentData
from pyctuator.httptrace import TraceRecord, TraceRequest, TraceResponse
from pyctuator.httptrace.http_tracer import Traces
Expand All @@ -34,7 +35,8 @@ def __init__(
app: FastAPI,
pyctuator_impl: PyctuatorImpl,
include_in_openapi_schema: bool = False,
customizer: Optional[Callable[[APIRouter], None]] = None
customizer: Optional[Callable[[APIRouter], None]] = None,
disabled_endpoints: Endpoints = Endpoints.NONE,
) -> None:
super().__init__(app, pyctuator_impl)
router = APIRouter()
Expand Down Expand Up @@ -64,70 +66,79 @@ def options() -> None:
documentation.
"""

@router.get("/env", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_environment() -> EnvironmentData:
return pyctuator_impl.get_environment()
if Endpoints.ENV not in disabled_endpoints:
@router.get("/env", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_environment() -> EnvironmentData:
return pyctuator_impl.get_environment()

@router.get("/info", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_info() -> Dict:
return pyctuator_impl.get_app_info()
if Endpoints.INFO not in disabled_endpoints:
@router.get("/info", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_info() -> Dict:
return pyctuator_impl.get_app_info()

@router.get("/health", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_health(response: Response) -> object:
health = pyctuator_impl.get_health()
response.status_code = health.http_status()
return health
if Endpoints.HEALTH not in disabled_endpoints:
@router.get("/health", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_health(response: Response) -> object:
health = pyctuator_impl.get_health()
response.status_code = health.http_status()
return health

@router.get("/metrics", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_metric_names() -> MetricNames:
return pyctuator_impl.get_metric_names()
if Endpoints.METRICS not in disabled_endpoints:
@router.get("/metrics", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_metric_names() -> MetricNames:
return pyctuator_impl.get_metric_names()

@router.get("/metrics/{metric_name}", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_metric_measurement(metric_name: str) -> Metric:
return pyctuator_impl.get_metric_measurement(metric_name)
@router.get("/metrics/{metric_name}", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_metric_measurement(metric_name: str) -> Metric:
return pyctuator_impl.get_metric_measurement(metric_name)

# Retrieving All Loggers
@router.get("/loggers", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_loggers() -> LoggersData:
return pyctuator_impl.logging.get_loggers()

@router.post("/loggers/{logger_name}", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def set_logger_level(item: FastApiLoggerItem, logger_name: str) -> Dict:
pyctuator_impl.logging.set_logger_level(logger_name, item.configuredLevel)
return {}

@router.get("/loggers/{logger_name}", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_logger(logger_name: str) -> LoggerLevels:
return pyctuator_impl.logging.get_logger(logger_name)

@router.get("/dump", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
@router.get("/threaddump", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_thread_dump() -> ThreadDump:
return pyctuator_impl.get_thread_dump()

@router.get("/logfile", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_logfile(range_header: str = Header(default=None,
alias="range")) -> Response: # pylint: disable=redefined-builtin
if not range_header:
return Response(content=pyctuator_impl.logfile.log_messages.get_range())

str_res, start, end = pyctuator_impl.logfile.get_logfile(range_header)

my_res = Response(
status_code=HTTPStatus.PARTIAL_CONTENT.value,
content=str_res,
headers={
"Content-Type": "text/html; charset=UTF-8",
"Accept-Ranges": "bytes",
"Content-Range": f"bytes {start}-{end}/{end}",
})

return my_res

@router.get("/trace", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
@router.get("/httptrace", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_httptrace() -> Traces:
return pyctuator_impl.http_tracer.get_httptrace()
if Endpoints.LOGGERS not in disabled_endpoints:
@router.get("/loggers", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_loggers() -> LoggersData:
return pyctuator_impl.logging.get_loggers()

if Endpoints.LOGGERS not in disabled_endpoints:
@router.post("/loggers/{logger_name}", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def set_logger_level(item: FastApiLoggerItem, logger_name: str) -> Dict:
pyctuator_impl.logging.set_logger_level(logger_name, item.configuredLevel)
return {}

@router.get("/loggers/{logger_name}", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_logger(logger_name: str) -> LoggerLevels:
return pyctuator_impl.logging.get_logger(logger_name)

if Endpoints.THREAD_DUMP not in disabled_endpoints:
@router.get("/dump", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
@router.get("/threaddump", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_thread_dump() -> ThreadDump:
return pyctuator_impl.get_thread_dump()

if Endpoints.LOGFILE not in disabled_endpoints:
@router.get("/logfile", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_logfile(range_header: str = Header(default=None,
alias="range")) -> Response: # pylint: disable=redefined-builtin
if not range_header:
return Response(content=pyctuator_impl.logfile.log_messages.get_range())

str_res, start, end = pyctuator_impl.logfile.get_logfile(range_header)

my_res = Response(
status_code=HTTPStatus.PARTIAL_CONTENT.value,
content=str_res,
headers={
"Content-Type": "text/html; charset=UTF-8",
"Accept-Ranges": "bytes",
"Content-Range": f"bytes {start}-{end}/{end}",
})

return my_res

if Endpoints.HTTP_TRACE not in disabled_endpoints:
@router.get("/trace", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
@router.get("/httptrace", include_in_schema=include_in_openapi_schema, tags=["pyctuator"])
def get_httptrace() -> Traces:
return pyctuator_impl.http_tracer.get_httptrace()

@app.middleware("http")
async def intercept_requests_and_responses(
Expand Down
3 changes: 3 additions & 0 deletions pyctuator/impl/pyctuator_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Dict, Mapping, Optional, Callable
from urllib.parse import urlparse

from pyctuator.endpoints import Endpoints
from pyctuator.environment.environment_provider import EnvironmentData, EnvironmentProvider
from pyctuator.environment.scrubber import SecretScrubber
from pyctuator.health.health_provider import HealthStatus, HealthSummary, Status, HealthProvider
Expand Down Expand Up @@ -57,10 +58,12 @@ def __init__(
logfile_max_size: int,
logfile_formatter: str,
additional_app_info: Optional[dict],
disabled_endpoints: Endpoints
):
self.app_info = app_info
self.pyctuator_endpoint_url = pyctuator_endpoint_url
self.additional_app_info = additional_app_info
self.disabled_endpoints = disabled_endpoints

self.metrics_providers: List[MetricsProvider] = []
self.health_providers: List[HealthProvider] = []
Expand Down
53 changes: 22 additions & 31 deletions pyctuator/impl/pyctuator_router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any
from typing import Any, Optional, Mapping

from pyctuator.endpoints import Endpoints
from pyctuator.impl.pyctuator_impl import PyctuatorImpl


Expand All @@ -11,25 +12,9 @@ class LinkHref:
templated: bool


# mypy: ignore_errors
# pylint: disable=too-many-instance-attributes
@dataclass
class EndpointsLinks:
self: LinkHref
env: LinkHref
info: LinkHref
health: LinkHref
metrics: LinkHref
loggers: LinkHref
dump: LinkHref
threaddump: LinkHref
logfile: LinkHref
httptrace: LinkHref


@dataclass
class EndpointsData:
_links: EndpointsLinks
_links: Mapping[str, LinkHref]


class PyctuatorRouter(ABC):
Expand All @@ -45,16 +30,22 @@ def __init__(
def get_endpoints_data(self) -> EndpointsData:
return EndpointsData(self.get_endpoints_links())

def get_endpoints_links(self):
return EndpointsLinks(
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url, False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/env", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/info", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/health", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/metrics", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/loggers", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/dump", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/threaddump", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/logfile", False),
LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + "/httptrace", False),
)
def get_endpoints_links(self) -> Mapping[str, LinkHref]:
def link_href(endpoint: Endpoints, path: str) -> Optional[LinkHref]:
return None if endpoint in self.pyctuator_impl.disabled_endpoints \
else LinkHref(self.pyctuator_impl.pyctuator_endpoint_url + path, False)

endpoints = {
"self": LinkHref(self.pyctuator_impl.pyctuator_endpoint_url, False),
"env": link_href(Endpoints.ENV, "/env"),
"info": link_href(Endpoints.INFO, "/info"),
"health": link_href(Endpoints.HEALTH, "/health"),
"metrics": link_href(Endpoints.METRICS, "/metrics"),
"loggers": link_href(Endpoints.LOGGERS, "/loggers"),
"dump": link_href(Endpoints.THREAD_DUMP, "/dump"),
"threaddump": link_href(Endpoints.THREAD_DUMP, "/threaddump"),
"logfile": link_href(Endpoints.LOGFILE, "/logfile"),
"httptrace": link_href(Endpoints.HTTP_TRACE, "/httptrace"),
}

return {endpoint: link_href for (endpoint, link_href) in endpoints.items() if link_href is not None}
37 changes: 30 additions & 7 deletions pyctuator/pyctuator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# To do that, all imports are in conditional branches after detecting which frameworks are installed.
# DO NOT add any web-framework-dependent imports to the global scope.
from pyctuator.auth import Auth
from pyctuator.endpoints import Endpoints
from pyctuator.environment.custom_environment_provider import CustomEnvironmentProvider
from pyctuator.environment.os_env_variables_impl import OsEnvironmentVariableProvider
from pyctuator.health.diskspace_health_impl import DiskSpaceHealthProvider
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(
additional_app_info: Optional[dict] = None,
ssl_context: Optional[ssl.SSLContext] = None,
customizer: Optional[Callable] = None,
disabled_endpoints: Endpoints = Endpoints.NONE,
) -> None:
"""The entry point for integrating pyctuator with a web-frameworks such as FastAPI and Flask.
Expand Down Expand Up @@ -84,6 +86,7 @@ def __init__(
:param customizer: a function that can customize the integration with the web-framework which is therefore web-
framework specific. For FastAPI, the function receives pyctuator's APIRouter allowing to add "dependencies" and
anything else that's provided by the router. See fastapi_with_authentication_example_app.py
:param disabled_endpoints: optional set of endpoints (such as /pyctuator/health) that should be disabled
"""

self.auto_deregister = auto_deregister
Expand All @@ -96,6 +99,7 @@ def __init__(
logfile_max_size,
logfile_formatter,
additional_app_info,
disabled_endpoints,
)

# Register default health/metrics/environment providers
Expand All @@ -118,7 +122,7 @@ def __init__(
root_logger.addHandler(self.pyctuator_impl.logfile.log_messages)

# Find and initialize an integration layer between the web-framework adn pyctuator
framework_integrations: Dict[str, Callable[[Any, PyctuatorImpl, Optional[Callable]], bool]] = {
framework_integrations: Dict[str, Callable[[Any, PyctuatorImpl, Optional[Callable], Endpoints], bool]] = {
"flask": self._integrate_flask,
"fastapi": self._integrate_fastapi,
"aiohttp": self._integrate_aiohttp,
Expand All @@ -127,7 +131,7 @@ def __init__(
for framework_name, framework_integration_function in framework_integrations.items():
if self._is_framework_installed(framework_name):
logging.debug("Framework %s is installed, trying to integrate with it", framework_name)
success = framework_integration_function(app, self.pyctuator_impl, customizer)
success = framework_integration_function(app, self.pyctuator_impl, customizer, disabled_endpoints)
if success:
logging.debug("Integrated with framework %s", framework_name)
if registration_url is not None:
Expand Down Expand Up @@ -189,7 +193,8 @@ def _integrate_fastapi(
self,
app: Any,
pyctuator_impl: PyctuatorImpl,
customizer: Optional[Callable]
customizer: Optional[Callable],
disabled_endpoints: Endpoints = Endpoints.NONE,
) -> bool:
"""
This method should only be called if we detected that FastAPI is installed.
Expand All @@ -199,12 +204,18 @@ def _integrate_fastapi(
from fastapi import FastAPI
if isinstance(app, FastAPI):
from pyctuator.impl.fastapi_pyctuator import FastApiPyctuator
FastApiPyctuator(app, pyctuator_impl, False, customizer)
FastApiPyctuator(app, pyctuator_impl, False, customizer, disabled_endpoints)
return True
return False

# pylint: disable=unused-argument
def _integrate_flask(self, app: Any, pyctuator_impl: PyctuatorImpl, customizer: Optional[Callable]) -> bool:
def _integrate_flask(
self,
app: Any,
pyctuator_impl: PyctuatorImpl,
customizer: Optional[Callable],
disabled_endpoints: Endpoints = Endpoints.NONE,
) -> bool:
"""
This method should only be called if we detected that Flask is installed.
It will then check whether the given app is a Flask app, and if so - it will add the Pyctuator
Expand All @@ -218,7 +229,13 @@ def _integrate_flask(self, app: Any, pyctuator_impl: PyctuatorImpl, customizer:
return False

# pylint: disable=unused-argument
def _integrate_aiohttp(self, app: Any, pyctuator_impl: PyctuatorImpl, customizer: Optional[Callable]) -> bool:
def _integrate_aiohttp(
self,
app: Any,
pyctuator_impl: PyctuatorImpl,
customizer: Optional[Callable],
disabled_endpoints: Endpoints = Endpoints.NONE,
) -> bool:
"""
This method should only be called if we detected that aiohttp is installed.
It will then check whether the given app is a aiohttp app, and if so - it will add the Pyctuator
Expand All @@ -232,7 +249,13 @@ def _integrate_aiohttp(self, app: Any, pyctuator_impl: PyctuatorImpl, customizer
return False

# pylint: disable=unused-argument
def _integrate_tornado(self, app: Any, pyctuator_impl: PyctuatorImpl, customizer: Optional[Callable]) -> bool:
def _integrate_tornado(
self,
app: Any,
pyctuator_impl: PyctuatorImpl,
customizer: Optional[Callable],
disabled_endpoints: Endpoints = Endpoints.NONE,
) -> bool:
"""
This method should only be called if we detected that tornado is installed.
It will then check whether the given app is a tornado app, and if so - it will add the Pyctuator
Expand Down
4 changes: 3 additions & 1 deletion tests/aiohttp_test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from aiohttp import web

from pyctuator.endpoints import Endpoints
from pyctuator.pyctuator import Pyctuator
from tests.conftest import PyctuatorServer

Expand All @@ -15,7 +16,7 @@
# pylint: disable=unused-variable
class AiohttpPyctuatorServer(PyctuatorServer):

def __init__(self) -> None:
def __init__(self, disabled_endpoints: Endpoints = Endpoints.NONE) -> None:
global bind_port
self.port = bind_port
bind_port += 1
Expand All @@ -32,6 +33,7 @@ def __init__(self) -> None:
registration_interval_sec=1,
metadata=self.metadata,
additional_app_info=self.additional_app_info,
disabled_endpoints=disabled_endpoints,
)

@self.routes.get("/logfile_test_repeater")
Expand Down

0 comments on commit c05e36d

Please sign in to comment.