Skip to content

Commit

Permalink
Merge pull request #37 from KenyonY/feat-cache
Browse files Browse the repository at this point in the history
feat: Efficient caching
  • Loading branch information
KenyonY committed May 19, 2024
2 parents 1c8c8f4 + 74a3666 commit 3482096
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 34 deletions.
2 changes: 1 addition & 1 deletion flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .core import LevelDBDict, LMDBDict, RemoteDBDict

__version__ = "0.2.8"
__version__ = "0.2.9-alpha"

__all__ = [
"FlaxKV",
Expand Down
94 changes: 72 additions & 22 deletions flaxkv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _init(self):
self._static_view = self._db_manager.new_static_view()

self.buffer_dict = {}
self._stat_buffer_num = 0
self.delete_buffer_set = set()

self._buffered_count = 0
Expand Down Expand Up @@ -289,6 +290,8 @@ def get(self, key: Any, default=None):
"""
with self._buffer_lock:
if key in self.delete_buffer_set:
self.delete_buffer_set.discard(key)
self.buffer_dict[key] = default
return default

if key in self.buffer_dict:
Expand All @@ -297,13 +300,16 @@ def get(self, key: Any, default=None):
if self._cache_all_db:
return self._cache_dict.get(key, default)

key = self._encode_key(key)
value = self._static_view.get(key)
_encode_key = self._encode_key(key)
value = self._static_view.get(_encode_key)

if value is None:
self.buffer_dict[key] = default
return default

return value if self._raw else decode(value)
v = value if self._raw else decode(value)
self.buffer_dict[key] = v
return v

def get_db_value(self, key: str):
"""
Expand Down Expand Up @@ -358,6 +364,7 @@ def _set(self, key, value):
with self._buffer_lock:
self.buffer_dict[key] = value
self.delete_buffer_set.discard(key)
self._stat_buffer_num = len(self.buffer_dict)

self._buffered_count += 1
self._last_set_time = time.time()
Expand Down Expand Up @@ -401,7 +408,8 @@ def update(self, d: dict):
self.buffer_dict[key] = value
self.delete_buffer_set.discard(key)

self._buffered_count += 1
self._stat_buffer_num = len(self.buffer_dict)
self._buffered_count += len(d)

self._last_set_time = time.time()
# Trigger immediate write if buffer size exceeds MAX_BUFFER_SIZE
Expand Down Expand Up @@ -480,6 +488,7 @@ def _write_buffer_to_db(
f"write {self._db_manager.db_type.upper()} buffer to db successfully! "
f"current_num={current_write_num} latest_num={self._latest_write_num}"
)
self._stat_buffer_num = len(self.buffer_dict)

def __iter__(self):
"""
Expand Down Expand Up @@ -527,6 +536,9 @@ def __delitem__(self, key):
self._last_set_time = time.time()
if key in self.buffer_dict:
del self.buffer_dict[key]
# If it is in the buffer (possibly obtained through get), then _stat_buffer_num -= 1,
# and _stat_buffer_num can be negative
self._stat_buffer_num -= 1
return
else:
if self._cache_all_db:
Expand All @@ -552,6 +564,7 @@ def pop(self, key, default=None):
self._last_set_time = time.time()
if key in self.buffer_dict:
value = self.buffer_dict.pop(key)
self._stat_buffer_num -= 1
if self._raw:
return decode(value)
else:
Expand Down Expand Up @@ -727,6 +740,13 @@ def keys(self, decode_raw=True):
yield d_key
self._db_manager.close_static_view(view)

def to_dict(self, decode_raw=True):
    """
    Retrieves all the key-value pairs in the database and buffer.

    Thin alias that delegates to ``db_dict``; provided as a more
    conventional name for snapshotting the store as a plain ``dict``.

    Args:
        decode_raw (bool): passed through to ``db_dict``; presumably
            controls whether raw-encoded values are decoded — confirm
            against ``db_dict``'s implementation.

    Returns: dict
    """
    return self.db_dict(decode_raw=decode_raw)

def db_dict(self, decode_raw=True):
"""
Retrieves all the key-value pairs in the database and buffer.
Expand Down Expand Up @@ -898,18 +918,26 @@ def set_mapsize(self, map_size):
def stat(self):
    """Return summary statistics for this LMDB-backed dict.

    Returns:
        dict: with keys
            'count' - estimated number of live keys: db entries plus
                buffered writes, minus keys marked for deletion (the
                cached branch skips the subtraction because
                ``_cache_dict`` already reflects deletions),
            'buffer' - number of entries tracked in the write buffer
                (``_stat_buffer_num``),
            'db' - number of entries persisted in the database (or held
                in ``_cache_dict`` when the whole db is cached),
            'marked_delete' - number of keys queued for deletion,
            'type' - backend identifier, always 'lmdb'.

    NOTE(review): 'count' is an estimate — buffered keys that already
    exist in the db are double-counted; verify against callers' needs.
    """
    if self._cache_all_db:
        # Entire db is mirrored in memory; no environment query needed.
        db_count = len(self._cache_dict)
        count = db_count + self._stat_buffer_num
    else:
        env = self._db_manager.get_env()
        stats = env.stat()
        db_count = stats['entries']
        # Keys in delete_buffer_set are still counted by the env but are
        # logically deleted, so subtract them from the estimate.
        count = db_count + self._stat_buffer_num - len(self.delete_buffer_set)
    return {
        'count': count,
        'buffer': self._stat_buffer_num,
        'db': db_count,
        'marked_delete': len(self.delete_buffer_set),
        'type': 'lmdb',
    }


class LevelDBDict(BaseDBDict):
Expand Down Expand Up @@ -959,10 +987,18 @@ def _iter_db_view(self, view, include_key=True, include_value=True):
yield key_or_value

def stat(self):
buffer_keys = set(self.buffer_dict.keys())

if self._cache_all_db:
db_keys = set(self._cache_dict.keys())
db_count = len(db_keys)
count = db_count + self._stat_buffer_num
return {
'count': count,
'buffer': self._stat_buffer_num,
'db': db_count,
'marked_delete': len(self.delete_buffer_set),
"type": 'leveldb',
}
else:
with self._buffer_lock:
view = self._db_manager.new_static_view()
Expand All @@ -971,12 +1007,21 @@ def stat(self):
view.close()

db_count = len(db_keys)
db_valid_keys = db_keys - self.delete_buffer_set
intersection_count = len(buffer_keys.intersection(db_valid_keys))
buffer_count = len(buffer_keys)
count = len(db_valid_keys) + buffer_count - intersection_count

return {'count': count, 'buffer': buffer_count, "db": db_count}
# db_valid_keys = db_keys - self.delete_buffer_set
# buffer_keys = set(self.buffer_dict.keys())
# intersection_count = len(buffer_keys.intersection(db_valid_keys))
# count = len(db_valid_keys) + self._stat_buffer_num - intersection_count
count = db_count + self._stat_buffer_num - len(self.delete_buffer_set)

# db_valid_keys = db_keys.union(buffer_keys) - self.delete_buffer_set
# count = len(db_valid_keys)
return {
'count': count,
'buffer': self._stat_buffer_num,
"db": db_count,
'marked_delete': len(self.delete_buffer_set),
'type': 'leveldb',
}


class RemoteDBDict(BaseDBDict):
Expand Down Expand Up @@ -1154,17 +1199,22 @@ def db_dict(self, decode_raw=True):
def stat(self):
    """Return summary statistics for this remote-backed dict.

    Returns:
        dict: with keys
            'count' - estimated number of live keys: db entries plus
                buffered writes, minus keys marked for deletion (the
                cached branch skips the subtraction because
                ``_cache_dict`` already reflects deletions),
            'buffer' - number of entries tracked in the write buffer
                (``_stat_buffer_num``),
            'db' - number of entries reported by the remote backend (or
                held in ``_cache_dict`` when the whole db is cached),
            'marked_delete' - number of keys queued for deletion,
            'type' - backend identifier, always 'remote'.
    """
    if self._cache_all_db:
        # Entire db is mirrored in memory; no remote query needed.
        db_count = len(self._cache_dict)
        count = db_count + self._stat_buffer_num
    else:
        # fixme: flagged for review in the original — remote count path.
        env = self._db_manager.get_env()
        stats = env.stat()
        db_count = stats['count']
        # Keys in delete_buffer_set are still counted remotely but are
        # logically deleted, so subtract them from the estimate.
        count = db_count + self._stat_buffer_num - len(self.delete_buffer_set)

    return {
        'count': count,
        'buffer': self._stat_buffer_num,
        'db': db_count,
        'marked_delete': len(self.delete_buffer_set),
        'type': 'remote',
    }

def __repr__(self):
Expand Down
41 changes: 40 additions & 1 deletion flaxkv/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
# limitations under the License.


from __future__ import annotations

import logging
import time
from functools import wraps
from typing import TYPE_CHECKING

from rich import print
from rich.text import Text

from .pack import encode
if TYPE_CHECKING:
from flaxkv import FlaxKV

ENABLED_MEASURE_TIME_DECORATOR = True

Expand Down Expand Up @@ -55,6 +59,8 @@ def wrapper(self, *args, **kwargs):


def msg_encoder(func):
from .pack import encode

@wraps(func)
async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs)
Expand Down Expand Up @@ -99,3 +105,36 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def cache(db: "FlaxKV | None" = None):
    """Keep a cache of previous function calls.

    Results are memoized in *db*, keyed by the call's positional
    arguments plus its keyword arguments sorted by name (so keyword
    order does not affect the cache key).

    Args:
        db: a dict-like store (e.g. a FlaxKV instance) used as the cache
            backend, shared by every function this decorator wraps with
            the same *db*. When None, a fresh in-process dict is used.
            NOTE: the original annotation ``db: FlaxKV = None`` was an
            implicit Optional (PEP 484); made explicit here.

    Returns:
        A decorator that wraps the target function with caching.

    NOTE(review): with a plain dict backend, all arguments must be
    hashable; unhashable args (lists, dicts) raise TypeError.
    """
    if db is None:
        db = {}

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Sort kwargs so the key is independent of keyword order.
            key = (args, tuple(sorted(kwargs.items())))
            if key in db:
                return db[key]
            result = func(*args, **kwargs)
            db[key] = result
            return result

        return wrapper

    return decorator


def singleton(cls):
    """Class decorator ensuring at most one instance of *cls* ever exists.

    The first call constructs the instance with whatever arguments are
    supplied; every later call returns that same instance and ignores
    its arguments.
    """
    _instances = {}

    @wraps(cls)
    def get_instance(*args, **kwargs):
        try:
            return _instances[cls]
        except KeyError:
            # First construction wins; later argument sets are ignored.
            instance = cls(*args, **kwargs)
            _instances[cls] = instance
            return instance

    return get_instance
8 changes: 5 additions & 3 deletions flaxkv/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def check_pandas_type(obj):
return type(obj).__name__ == "DataFrame"


# def check_ext_type(obj):
# return isinstance(obj, (tuple, set))


def encode_hook(obj):
if isinstance(obj, np.ndarray):
return msgspec.msgpack.Ext(
Expand All @@ -54,8 +58,7 @@ def encode_hook(obj):
NPArray(dtype=obj.dtype.str, shape=obj.shape, data=obj.data)
),
)
elif check_pandas_type(obj):
# return msgspec.msgpack.Ext(2, pyarrow.serialize_pandas(obj).to_pybytes())
else:
return msgspec.msgpack.Ext(2, pickle.dumps(obj))
return obj

Expand All @@ -67,7 +70,6 @@ def ext_hook(type, data: memoryview):
serialized_array_rep.data, dtype=serialized_array_rep.dtype
).reshape(serialized_array_rep.shape)
elif type == 2:
# return pyarrow.deserialize_pandas(pyarrow.py_buffer(data.tobytes()))
return pickle.loads(data.tobytes())
return data

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ test = [
"litestar>=2.5.0",
"pytest",
"pytest-aiohttp",
"sparrow-python",
"uvicorn",
"httpx[http2]",
"pandas",
Expand Down
Loading

0 comments on commit 3482096

Please sign in to comment.