From b211131dde0fd8bc9cc6ab7b6fdf4c883c0d148f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Thu, 6 Jul 2023 20:37:16 +0200 Subject: [PATCH] fix: refactor Flask init code --- aw_server/rest.py | 16 ---------- aw_server/server.py | 76 ++++++++++++++++++++++++++------------------- tests/conftest.py | 4 +-- 3 files changed, 46 insertions(+), 50 deletions(-) diff --git a/aw_server/rest.py b/aw_server/rest.py index 80d5287f..10dd9652 100644 --- a/aw_server/rest.py +++ b/aw_server/rest.py @@ -1,11 +1,9 @@ import json import traceback -from datetime import datetime, timedelta from functools import wraps from threading import Lock from typing import Dict -import flask.json.provider import iso8601 from aw_core import schema from aw_core.models import Event @@ -53,20 +51,6 @@ def decorator(*args, **kwargs): api = Api(blueprint, doc="/", decorators=[host_header_check]) -# TODO: Clean up JSONEncoder code? -# Move to server.py -class CustomJSONProvider(flask.json.provider.DefaultJSONProvider): - def default(self, obj, *args, **kwargs): - try: - if isinstance(obj, datetime): - return obj.isoformat() - if isinstance(obj, timedelta): - return obj.total_seconds() - except TypeError: - pass - return super().default(obj) - - # Loads event and bucket schema from JSONSchema in aw_core event = api.schema_model("Event", schema.get_json_schema("event")) bucket = api.schema_model("Bucket", schema.get_json_schema("bucket")) diff --git a/aw_server/server.py b/aw_server/server.py index 83766c06..50863b6a 100644 --- a/aw_server/server.py +++ b/aw_server/server.py @@ -1,8 +1,10 @@ import logging import os +from datetime import datetime, timedelta from typing import Dict, List import aw_datastore +import flask.json.provider from aw_datastore import Datastore from flask import ( Blueprint, @@ -26,40 +28,50 @@ class AWFlask(Flask): - def __init__(self, name, testing: bool, *args, **kwargs): - self.json_provider_class = rest.CustomJSONProvider - - # Only pretty-print JSON if in testing mode (because of performance) + def __init__( + self, + host: str, + testing: bool, + storage_method=None, + cors_origins=[], + custom_static=dict(), + *args, + **kwargs + ): + name = "aw-server" + self.json_provider_class = CustomJSONProvider + # only prettyprint JSON if testing (due to perf) self.json_provider_class.compact = not testing # Initialize Flask Flask.__init__(self, name, *args, **kwargs) - - # Is set on later initialization - self.api: ServerAPI = None # type: ignore - - -def create_app( - host: str, testing=True, storage_method=None, cors_origins=[], custom_static=dict() -) -> AWFlask: - app = AWFlask("aw-server", testing, static_folder=static_folder, static_url_path="") - - with app.app_context(): - _config_cors(cors_origins, testing) - - app.register_blueprint(root) - app.register_blueprint(rest.blueprint) - app.register_blueprint(get_custom_static_blueprint(custom_static)) - - if storage_method is None: - storage_method = aw_datastore.get_storage_methods()["memory"] - db = Datastore(storage_method, testing=testing) - app.api = ServerAPI(db=db, testing=testing) - - # needed for host-header check - app.config["HOST"] = host - - return app + self.config["HOST"] = host # needed for host-header check + with self.app_context(): + _config_cors(cors_origins, testing) + + # Initialize datastore and API + if storage_method is None: + storage_method = aw_datastore.get_storage_methods()["memory"] + db = Datastore(storage_method, testing=testing) + self.api = ServerAPI(db=db, testing=testing) + + self.register_blueprint(root) + self.register_blueprint(rest.blueprint) + self.register_blueprint(get_custom_static_blueprint(custom_static)) + + +class CustomJSONProvider(flask.json.provider.DefaultJSONProvider): + # encoding/decoding of datetime as iso8601 strings + # encoding of timedelta as second floats + def default(self, obj, *args, **kwargs): + try: + if isinstance(obj, datetime): + return obj.isoformat() + if isinstance(obj, timedelta): + return obj.total_seconds() + except TypeError: + pass + return super().default(obj) @root.route("/") @@ -105,10 +117,10 @@ def _start( cors_origins: List[str] = [], custom_static: Dict[str, str] = dict(), ): - app = create_app( + app = AWFlask( host, - storage_method=storage_method, testing=testing, + storage_method=storage_method, cors_origins=cors_origins, custom_static=custom_static, ) diff --git a/tests/conftest.py b/tests/conftest.py index 8849fbda..141bab1d 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,14 @@ import logging import pytest -from aw_server.server import create_app +from aw_server.server import AWFlask logging.basicConfig(level=logging.WARN) @pytest.fixture(scope="session") def app(): - return create_app("127.0.0.1", testing=True) + return AWFlask("127.0.0.1", testing=True) @pytest.fixture(scope="session")