Skip to content

Commit

Permalink
Revert "Expire session storage cache on an async timer (streamlit#8083)…
Browse files Browse the repository at this point in the history
…" (streamlit#8281)

## Describe your changes

This reverts commit 44227ad.

---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
  • Loading branch information
kmcgrady committed Mar 12, 2024
1 parent e1960fb commit b90d822
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 69 deletions.
4 changes: 2 additions & 2 deletions lib/streamlit/runtime/caching/cache_resource_api.py
Expand Up @@ -22,6 +22,7 @@
from datetime import timedelta
from typing import Any, Callable, Final, TypeVar, cast, overload

from cachetools import TTLCache
from typing_extensions import TypeAlias

import streamlit as st
Expand All @@ -47,7 +48,6 @@
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
from streamlit.util import TimedCleanupCache

_LOGGER: Final = get_logger(__name__)

Expand Down Expand Up @@ -472,7 +472,7 @@ def __init__(
super().__init__()
self.key = key
self.display_name = display_name
self._mem_cache: TimedCleanupCache[str, MultiCacheResults] = TimedCleanupCache(
self._mem_cache: TTLCache[str, MultiCacheResults] = TTLCache(
maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
)
self._mem_cache_lock = threading.Lock()
Expand Down
Expand Up @@ -16,6 +16,8 @@
import math
import threading

from cachetools import TTLCache

from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.storage.cache_storage_protocol import (
Expand All @@ -24,7 +26,6 @@
CacheStorageKeyNotFoundError,
)
from streamlit.runtime.stats import CacheStat
from streamlit.util import TimedCleanupCache

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
self.function_display_name = context.function_display_name
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
self._mem_cache: TimedCleanupCache[str, bytes] = TimedCleanupCache(
self._mem_cache: TTLCache[str, bytes] = TTLCache(
maxsize=self.max_entries,
ttl=self.ttl_seconds,
timer=cache_utils.TTLCACHE_TIMER,
Expand Down
5 changes: 3 additions & 2 deletions lib/streamlit/runtime/memory_session_storage.py
Expand Up @@ -16,8 +16,9 @@

from typing import MutableMapping

from cachetools import TTLCache

from streamlit.runtime.session_manager import SessionInfo, SessionStorage
from streamlit.util import TimedCleanupCache


class MemorySessionStorage(SessionStorage):
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
inaccessible and will be removed eventually.
"""

self._cache: MutableMapping[str, SessionInfo] = TimedCleanupCache(
self._cache: MutableMapping[str, SessionInfo] = TTLCache(
maxsize=maxsize, ttl=ttl_seconds
)

Expand Down
39 changes: 1 addition & 38 deletions lib/streamlit/util.py
Expand Up @@ -16,16 +16,13 @@

from __future__ import annotations

import asyncio
import dataclasses
import functools
import hashlib
import os
import subprocess
import sys
from typing import Any, Callable, Final, Generic, Iterable, Mapping, TypeVar

from cachetools import TTLCache
from typing import Any, Callable, Final, Iterable, Mapping, TypeVar

from streamlit import env_util

Expand Down Expand Up @@ -202,37 +199,3 @@ def extract_key_query_params(
]
for item in sublist
}


K = TypeVar("K")
V = TypeVar("V")


class TimedCleanupCache(TTLCache, Generic[K, V]):
"""A TTLCache that asynchronously expires its entries."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._task: asyncio.Task[Any] | None = None

def __setitem__(self, key: K, value: V) -> None:
# Set an expiration task to run periodically
# Can't be created in init because that only runs once and
# the event loop might not exist yet.
if self._task is None:
try:
self._task = asyncio.create_task(expire_cache(self))
except RuntimeError:
# Just continue if the event loop isn't started yet.
pass
super().__setitem__(key, value)

def __del__(self):
if self._task is not None:
self._task.cancel()


async def expire_cache(cache: TTLCache) -> None:
while True:
await asyncio.sleep(30)
cache.expire()
25 changes: 0 additions & 25 deletions lib/tests/streamlit/util_test.py
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import gc
import random
import unittest
from typing import Dict, List, Set
Expand Down Expand Up @@ -187,26 +185,3 @@ def test_calc_md5_can_handle_bytes_and_strings(self):
util.calc_md5("eventually bytes"),
util.calc_md5("eventually bytes".encode("utf-8")),
)

def test_timed_cleanup_cache_gc(self):
"""Test that the TimedCleanupCache does not leave behind tasks when
the cache is not externally reachable"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

async def create_cache():
cache = util.TimedCleanupCache(maxsize=2, ttl=10)
cache["foo"] = "bar"

# expire_cache and create_cache
assert len(asyncio.all_tasks()) > 1

asyncio.run(create_cache())

gc.collect()

async def check():
# Only has this function running
assert len(asyncio.all_tasks()) == 1

asyncio.run(check())

0 comments on commit b90d822

Please sign in to comment.