From b14e53e492c5d77309d9a1ccb0080e79785b3398 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Mon, 27 Feb 2023 15:59:11 +0000 Subject: [PATCH] fix: memoized decorator memory leak (#23139) (cherry picked from commit 79274eb5bca7c123842b08e075572d14f34cb5a3) --- superset/connectors/sqla/utils.py | 7 +- superset/constants.py | 2 + superset/jinja_context.py | 8 +- ...-24_620241d1153f_update_time_grain_sqla.py | 2 - superset/models/core.py | 7 +- superset/models/datasource_access_request.py | 2 - superset/models/slice.py | 11 ++- superset/utils/date_parser.py | 6 +- superset/utils/memoized.py | 81 ---------------- tests/integration_tests/utils_tests.py | 1 + tests/unit_tests/memoized_tests.py | 96 ------------------- tests/unit_tests/models/core_test.py | 4 +- 12 files changed, 24 insertions(+), 203 deletions(-) delete mode 100644 superset/utils/memoized.py delete mode 100644 tests/unit_tests/memoized_tests.py diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index e3745dac2a2b..4cf20e5511ad 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -17,6 +17,7 @@ from __future__ import annotations import logging +from functools import lru_cache from typing import ( Any, Callable, @@ -40,6 +41,7 @@ from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.sql.type_api import TypeEngine +from superset.constants import LRU_CACHE_MAX_SIZE from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( SupersetGenericDBErrorException, @@ -49,7 +51,6 @@ from superset.result_set import SupersetResultSet from superset.sql_parse import has_table_query, insert_rls, ParsedQuery from superset.superset_typing import ResultSetColumnType -from superset.utils.memoized import memoized if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable @@ -200,12 +201,12 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) -@memoized +@lru_cache(maxsize=LRU_CACHE_MAX_SIZE) def get_dialect_name(drivername: str) -> str: return SqlaURL.create(drivername).get_dialect().name -@memoized +@lru_cache(maxsize=LRU_CACHE_MAX_SIZE) def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]: return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote diff --git a/superset/constants.py b/superset/constants.py index 5c1f0e36fe26..cdbce050d360 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -37,6 +37,8 @@ QUERY_CANCEL_KEY = "cancel_query" QUERY_EARLY_CANCEL_KEY = "early_cancel_query" +LRU_CACHE_MAX_SIZE = 256 + class RouteMethod: # pylint: disable=too-few-public-methods """ diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 823c67451beb..d9409e297b5f 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -17,7 +17,7 @@ """Defines the templating context for SQL Lab""" import json import re -from functools import partial +from functools import lru_cache, partial from typing import ( Any, Callable, @@ -38,6 +38,7 @@ from sqlalchemy.types import String from typing_extensions import TypedDict +from superset.constants import LRU_CACHE_MAX_SIZE from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.exceptions import SupersetTemplateException from superset.extensions import feature_flag_manager @@ -46,7 +47,6 @@ get_user_id, merge_extra_filters, ) -from superset.utils.memoized import memoized if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable @@ -70,7 +70,7 @@ COLLECTION_TYPES = ("list", "dict", "tuple", "set") -@memoized +@lru_cache(maxsize=LRU_CACHE_MAX_SIZE) def context_addons() -> Dict[str, Any]: return current_app.config.get("JINJA_CONTEXT_ADDONS", {}) @@ -602,7 +602,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str: } -@memoized +@lru_cache(maxsize=LRU_CACHE_MAX_SIZE) def get_template_processors() -> Dict[str, Any]: processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {}) for engine, processor in DEFAULT_PROCESSORS.items(): diff --git a/superset/migrations/versions/2020-04-29_09-24_620241d1153f_update_time_grain_sqla.py b/superset/migrations/versions/2020-04-29_09-24_620241d1153f_update_time_grain_sqla.py index 97bea8f9d142..29735facbe87 100644 --- a/superset/migrations/versions/2020-04-29_09-24_620241d1153f_update_time_grain_sqla.py +++ b/superset/migrations/versions/2020-04-29_09-24_620241d1153f_update_time_grain_sqla.py @@ -34,7 +34,6 @@ from superset import db, db_engine_specs from superset.databases.utils import make_url_safe -from superset.utils.memoized import memoized Base = declarative_base() @@ -70,7 +69,6 @@ class Slice(Base): datasource_id = Column(Integer) -@memoized def duration_by_name(database: Database): return {grain.name: grain.duration for grain in database.grains()} diff --git a/superset/models/core.py b/superset/models/core.py index 4186a9a086dd..ac7cc517ef0d 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -24,6 +24,7 @@ from contextlib import closing, contextmanager, nullcontext from copy import deepcopy from datetime import datetime +from functools import lru_cache from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING import numpy @@ -54,7 +55,7 @@ from sqlalchemy.sql import expression, Select from superset import app, db_engine_specs -from superset.constants import PASSWORD_MASK +from superset.constants import LRU_CACHE_MAX_SIZE, PASSWORD_MASK from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain from superset.extensions import ( @@ -67,7 +68,6 @@ from superset.result_set import SupersetResultSet from superset.utils import cache as cache_util, core as utils from superset.utils.core import get_username -from superset.utils.memoized import memoized config = app.config custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] @@ -723,7 +723,7 @@ def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: return self.get_db_engine_spec(url) @classmethod - @memoized + @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]: backend = url.get_backend_name() try: @@ -897,7 +897,6 @@ def has_view(self, view_name: str, schema: Optional[str] = None) -> bool: def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool: return self.has_view(view_name=view_name, schema=schema) - @memoized def get_dialect(self) -> Dialect: sqla_url = make_url_safe(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()() diff --git a/superset/models/datasource_access_request.py b/superset/models/datasource_access_request.py index 60bfe0823828..1f286f96d8b4 100644 --- a/superset/models/datasource_access_request.py +++ b/superset/models/datasource_access_request.py @@ -22,7 +22,6 @@ from superset import app, db, security_manager from superset.models.helpers import AuditMixinNullable -from superset.utils.memoized import memoized if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource @@ -57,7 +56,6 @@ def datasource(self) -> "BaseDatasource": return self.get_datasource @datasource.getter # type: ignore - @memoized def get_datasource(self) -> "BaseDatasource": ds = db.session.query(self.cls_model).filter_by(id=self.datasource_id).first() return ds diff --git a/superset/models/slice.py b/superset/models/slice.py index 332d51d1af93..54429133d3c2 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -46,7 +46,6 @@ from superset.tasks.utils import get_current_user from superset.thumbnails.digest import get_chart_digest from superset.utils import core as utils -from superset.utils.memoized import memoized from superset.viz import BaseViz, viz_types if TYPE_CHECKING: @@ -151,9 +150,12 @@ def clone(self) -> "Slice": # pylint: disable=using-constant-test @datasource.getter # type: ignore - @memoized def get_datasource(self) -> Optional["BaseDatasource"]: - return db.session.query(self.cls_model).filter_by(id=self.datasource_id).first() + return ( + db.session.query(self.cls_model) + .filter_by(id=self.datasource_id) + .one_or_none() + ) @renders("datasource_name") def datasource_link(self) -> Optional[Markup]: @@ -189,8 +191,7 @@ def datasource_edit_url(self) -> Optional[str]: # pylint: enable=using-constant-test - @property # type: ignore - @memoized + @property def viz(self) -> Optional[BaseViz]: form_data = json.loads(self.params) viz_class = viz_types.get(self.viz_type) diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py index 7e79c72f1eb7..65938f8c0584 100644 --- a/superset/utils/date_parser.py +++ b/superset/utils/date_parser.py @@ -18,6 +18,7 @@ import logging import re from datetime import datetime, timedelta +from functools import lru_cache from time import struct_time from typing import Dict, List, Optional, Tuple @@ -45,8 +46,7 @@ TimeRangeAmbiguousError, TimeRangeParseFailError, ) -from superset.constants import NO_TIME_RANGE -from superset.utils.memoized import memoized +from superset.constants import LRU_CACHE_MAX_SIZE, NO_TIME_RANGE ParserElement.enablePackrat() @@ -394,7 +394,7 @@ def eval(self) -> datetime: ) -@memoized +@lru_cache(maxsize=LRU_CACHE_MAX_SIZE) def datetime_parser() -> ParseResults: # pylint: disable=too-many-locals ( # pylint: disable=invalid-name DATETIME, diff --git a/superset/utils/memoized.py b/superset/utils/memoized.py deleted file mode 100644 index 153542fbb7b1..000000000000 --- a/superset/utils/memoized.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import functools -from typing import Any, Callable, Dict, Optional, Tuple, Type - - -class _memoized: - """Decorator that caches a function's return value each time it is called - - If called later with the same arguments, the cached value is returned, and - not re-evaluated. - - Define ``watch`` as a tuple of attribute names if this Decorator - should account for instance variable changes. - """ - - def __init__( - self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None - ) -> None: - self.func = func - self.cache: Dict[Any, Any] = {} - self.is_method = False - self.watch = watch or () - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - key = [args, frozenset(kwargs.items())] - if self.is_method: - key.append(tuple(getattr(args[0], v, None) for v in self.watch)) - key = tuple(key) # type: ignore - try: - if key in self.cache: - return self.cache[key] - except TypeError as ex: - # Uncachable -- for instance, passing a list as an argument. - raise TypeError("Function cannot be memoized") from ex - value = self.func(*args, **kwargs) - try: - self.cache[key] = value - except TypeError as ex: - raise TypeError("Function cannot be memoized") from ex - return value - - def __repr__(self) -> str: - """Return the function's docstring.""" - return self.func.__doc__ or "" - - def __get__( - self, obj: Any, objtype: Type[Any] - ) -> functools.partial: # type: ignore - if not self.is_method: - self.is_method = True - # Support instance methods. - func = functools.partial(self.__call__, obj) - func.__func__ = self.func # type: ignore - return func - - -def memoized( - func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None -) -> Callable[..., Any]: - if func: - return _memoized(func) - - def wrapper(f: Callable[..., Any]) -> Callable[..., Any]: - return _memoized(f, watch) - - return wrapper diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 967a4e9388cf..e27ad6ec3c5e 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -991,6 +991,7 @@ def test_log_this(self) -> None: slc = self.get_slice("Girls", db.session) dashboard_id = 1 + assert slc.viz is not None resp = self.get_json_resp( f"/superset/explore_json/{slc.datasource_type}/{slc.datasource_id}/" + f'?form_data={{"slice_id": {slc.id}}}&dashboard_id={dashboard_id}', diff --git a/tests/unit_tests/memoized_tests.py b/tests/unit_tests/memoized_tests.py deleted file mode 100644 index 3b3f436606f5..000000000000 --- a/tests/unit_tests/memoized_tests.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from pytest import mark - -from superset.utils.memoized import memoized - - -@mark.unittest -class TestMemoized: - def test_memoized_on_functions(self): - watcher = {"val": 0} - - @memoized - def test_function(a, b, c): - watcher["val"] += 1 - return a * b * c - - result1 = test_function(1, 2, 3) - result2 = test_function(1, 2, 3) - assert result1 == result2 - assert watcher["val"] == 1 - - def test_memoized_on_methods(self): - class test_class: - def __init__(self, num): - self.num = num - self.watcher = 0 - - @memoized - def test_method(self, a, b, c): - self.watcher += 1 - return a * b * c * self.num - - instance = test_class(5) - result1 = instance.test_method(1, 2, 3) - result2 = instance.test_method(1, 2, 3) - assert result1 == result2 - assert instance.watcher == 1 - instance.num = 10 - assert result2 == instance.test_method(1, 2, 3) - - def test_memoized_on_methods_with_watches(self): - class test_class: - def __init__(self, x, y): - self.x = x - self.y = y - self.watcher = 0 - - @memoized(watch=("x", "y")) - def test_method(self, a, b, c): - self.watcher += 1 - return a * b * c * self.x * self.y - - instance = test_class(3, 12) - result1 = instance.test_method(1, 2, 3) - result2 = instance.test_method(1, 2, 3) - assert result1 == result2 - assert instance.watcher == 1 - result3 = instance.test_method(2, 3, 4) - assert instance.watcher == 2 - result4 = instance.test_method(2, 3, 4) - assert instance.watcher == 2 - assert result3 == result4 - assert result3 != result1 - instance.x = 1 - result5 = instance.test_method(2, 3, 4) - assert instance.watcher == 3 - assert result5 != result4 - result6 = instance.test_method(2, 3, 4) - assert instance.watcher == 3 - assert result6 == result5 - instance.x = 10 - instance.y = 10 - result7 = instance.test_method(2, 3, 4) - assert instance.watcher == 4 - assert result7 != result6 - instance.x = 3 - instance.y = 12 - result8 = instance.test_method(1, 2, 3) - assert instance.watcher == 4 - assert result1 == result8 diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 5eb60dc6f93e..f8534391d837 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -59,9 +59,7 @@ def get_metrics( }, ] - database.get_db_engine_spec = mocker.MagicMock( # type: ignore - return_value=CustomSqliteEngineSpec - ) + database.get_db_engine_spec = mocker.MagicMock(return_value=CustomSqliteEngineSpec) assert database.get_metrics("table") == [ { "expression": "COUNT(DISTINCT user_id)",