Skip to content

Commit

Permalink
Merge pull request #33 from KenyonY/remote-db-sync
Browse files Browse the repository at this point in the history
feat: Implement cache synchronization among multiple clients
  • Loading branch information
KenyonY committed Jan 21, 2024
2 parents 256dbed + 2a046a5 commit 408c473
Show file tree
Hide file tree
Showing 16 changed files with 709 additions and 369 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ Very suitable for scenarios that require high-frequency insertion/update of data
- **Machine Learning:**
`flaxkv` is very suitable for saving various large datasets of embeddings, images, texts, and other key-value structures in machine learning.

### Limitations
* In the current version, due to the delayed writing feature, in a multi-process environment,
one process cannot read the data written by another process in real-time (usually delayed by a few seconds).
If immediate writing is desired, the .write_immediately() method must be called.
This limitation does not exist in a single-process environment.
* By default, values do not support the `Tuple` and `Set` types. If values of these types are stored anyway, they will be deserialized as a `List`.

## Citation
If `FlaxKV` has been helpful to your research, please cite:
```bibtex
Expand Down
3 changes: 3 additions & 0 deletions README_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ pytest -s -v run.py
- **机器学习**
适用于保存机器学习中的各种嵌入向量、图像、文本和其它键-值结构的大型数据集。

### 限制
* 当前版本下,由于延迟写入的特性,在多进程中使用时,一个进程无法实时读取到另一进程写入的数据(一般会延后几秒钟),若希望即刻写入需调用`.write_immediately()`方法。这个限制在单进程下不存在。
* 值(value) 默认不支持`Tuple`,`Set`类型,若强行set,这些类型会被反序列化为`List`.

## 引用
如果`FlaxKV`对你的研究有帮助,欢迎引用:
Expand Down
12 changes: 6 additions & 6 deletions benchmark/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import random
import shutil
import subprocess
import time

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -76,7 +75,7 @@ def process_result():
"Sqlite3",
# "flaxkv-LMDB",
"flaxkv-LevelDB",
# "flaxkv-REMOTE",
"flaxkv-REMOTE",
]
)
def temp_db(request):
Expand Down Expand Up @@ -122,12 +121,13 @@ def benchmark(db, db_name, n=200):
db.write_immediately(block=True)

mt.start()
keys = []
for key in db.keys():
...
keys.append(key)
mt.show_interval(f"{db_name} read (keys only)")
idx = 0
for key, value in db.items():
idx += 1

for key in keys:
value = db[key]
read_cost = float(mt.show_interval(f"{db_name} read (traverse elements) "))
print("--------------------------")
return write_cost, read_cost
Expand Down
2 changes: 1 addition & 1 deletion flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .core import LevelDBDict, LMDBDict, RemoteDBDict

__version__ = "0.2.6"
__version__ = "0.2.7-alpha"

__all__ = [
"FlaxKV",
Expand Down
51 changes: 37 additions & 14 deletions flaxkv/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
import platform

import fire
import uvicorn

try:
import uvloop

uvloop.install()
except:
...


class Cli:
@staticmethod
def run(port=8000, workers=1, **kwargs):
def run(port=8000, **kwargs):
"""
Runs the application using the Uvicorn server.
Args:
port (int): The port number on which to run the server. Default is 8000.
workers (int): The number of worker processes to run. Default is 1.
Returns:
None
Expand All @@ -36,17 +42,34 @@ def run(port=8000, workers=1, **kwargs):
if platform.system() == "Windows":
os.environ["TZ"] = ""

uvicorn.run(
app="flaxkv.serve.app:app",
host=kwargs.get("host", "0.0.0.0"),
port=port,
workers=workers,
app_dir="..",
ssl_keyfile=kwargs.get("ssl_keyfile", None),
ssl_certfile=kwargs.get("ssl_certfile", None),
use_colors=False,
log_level="info",
)
log_level = kwargs.get("log", "info")
os.environ['FLAXKV_LOG_LEVEL'] = log_level.upper()

http2 = kwargs.get("http2", False)
if http2:
print("use http2")
from hypercorn.asyncio import serve
from hypercorn.config import Config

from flaxkv.serve.app import app

config = Config()
config.bind = [f"0.0.0.0:{port}"]
asyncio.run(serve(app, config))
else:
import uvicorn

uvicorn.run(
app="flaxkv.serve.app:app",
host=kwargs.get("host", "0.0.0.0"),
port=port,
workers=1,
app_dir="..",
ssl_keyfile=kwargs.get("ssl_keyfile", None),
ssl_certfile=kwargs.get("ssl_certfile", None),
use_colors=True,
log_level=log_level.lower(),
)


def main():
Expand Down
124 changes: 91 additions & 33 deletions flaxkv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .decorators import class_measure_time
from .helper import SimpleQueue
from .log import setting_log
from .manager import DBManager
from .manager import DBManager, RemoteTransaction
from .pack import check_pandas_type, decode, decode_key, encode

if TYPE_CHECKING:
Expand Down Expand Up @@ -74,12 +74,17 @@ def __init__(
raw (bool): Only used by the server.
"""
log_level = kwargs.pop('log', None)
if log_level:
stdout = kwargs.pop("stdout", True)
if log_level and stdout:
log_configs = setting_log(
level="DEBUG" if log_level is True else log_level,
stdout=kwargs.pop("stdout", False),
stdout=stdout,
save_file=kwargs.pop('save_log', False),
)
try:
logger.remove(0)
except Exception:
pass
log_ids = [logger.add(**log_conf) for log_conf in log_configs]
self._logger = logger.bind(flaxkv=True)

Expand Down Expand Up @@ -961,6 +966,8 @@ def __init__(
backend='leveldb',
**kwargs,
):
self._start_event = threading.Event()

super().__init__(
"remote",
root_path_or_url=root_path_or_url,
Expand All @@ -969,19 +976,68 @@ def __init__(
rebuild=rebuild,
**kwargs,
)
self._start_event.wait()

def _iter_db_view(self, view, include_key=True, include_value=True):
def _start(self):
"""
Iterates over the items in the database view.
Args:
view: The database view to iterate over.
Starts the background worker thread.
"""

if include_key and include_value:
...
else:
...
self._thread_sync_notify = threading.Thread(target=self._attach_db)
self._thread_sync_notify.daemon = True
self._thread_sync_notify.start()

self._thread_running = True
self._thread.start()
self._thread_write_monitor.start()

def _attach_db(self):
    """
    Background worker: subscribes to the remote server's change feed and
    applies each incoming sync message to the local cache.
    """

    def set_cache(data):
        # Apply one sync message. `data` is attribute-accessed
        # (data.type / data.data) — presumably a small message object
        # produced by `attach_db`; confirm against RemoteTransaction.
        if data.type == "buffer_dict":
            # Batch of upserted key/value pairs, still raw-encoded.
            buffer_dict = data.data
            for raw_key, raw_value in buffer_dict.items():
                self._cache_dict[decode_key(raw_key)] = decode(raw_value)
        elif data.type == "delete_keys":
            # Keys deleted remotely; drop them locally. Default=None keeps a
            # key that was never cached locally from raising KeyError.
            for raw_key in data.data.keys():
                self._cache_dict.pop(decode_key(raw_key), None)
        else:
            # Bug fix: `data` is attribute-accessed everywhere above, so the
            # old `data['type']` subscript would raise TypeError here instead
            # of the intended ValueError.
            raise ValueError(f"Unknown data type: {data.type}")

    view: RemoteTransaction = self._db_manager.new_static_view()

    # `attach_db` yields sync messages (and signals `self._start_event` once
    # the stream is established); a None message marks end-of-stream.
    for data in view.attach_db(self._start_event):
        if data is None:
            break
        set_cache(data)

def _pull_db_data_to_cache(self, decode_raw=True):
    """
    Downloads the full remote database content and merges it into the local
    cache, skipping keys that have pending local deletions.

    Args:
        decode_raw (bool): Passed through to `_get_status_info`; whether raw
            buffered keys/values are decoded while collecting status info.
            Defaults to True.
    """

    # Block until the sync worker signals readiness — presumably set inside
    # `attach_db` once the server stream is established; confirm.
    self._start_event.wait()

    (
        buffer_dict,
        buffer_keys,
        buffer_values,
        delete_buffer_set,
        view,
    ) = self._get_status_info(return_view=True, decode_raw=decode_raw)

    view: RemoteTransaction
    view.check_db_exist()
    # Stream the serialized remote dict in chunks rather than one huge read.
    with view.client.stream("GET", f"/dict_stream?db_name={self._db_name}") as r:
        buffer = bytearray()
        for data in r.iter_bytes():
            buffer.extend(data)

    remote_db_dict = decode(bytes(buffer))
    # Keys deleted locally (but not yet flushed to the server) must not be
    # resurrected by the remote snapshot.
    for dk, dv in remote_db_dict.items():
        if dk not in delete_buffer_set:
            self._cache_dict[dk] = dv

def _iter_db_view(self, view, include_key=True, include_value=True):
    """
    Just a placeholder, now we don't use it.

    Intentionally a no-op: the remote backend pulls data over HTTP streams
    rather than iterating a local storage view. Kept, presumably, to match
    the interface of the local backends — confirm before removing.
    """

def keys(self, fetch_all=True, decode_raw=True):
(
Expand Down Expand Up @@ -1021,8 +1077,6 @@ def keys(self, fetch_all=True, decode_raw=True):
else:
raise NotImplementedError

# self._db_manager.close_static_view(view)

def items(self, fetch_all=True, decode_raw=True):
if fetch_all:
return self.db_dict(decode_raw=decode_raw).items()
Expand Down Expand Up @@ -1063,27 +1117,8 @@ def db_dict(self, decode_raw=True):
else:
_db_dict = buffer_dict

# self._db_manager.close_static_view(view)
return _db_dict

def _pull_db_data_to_cache(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
view,
) = self._get_status_info(return_view=True, decode_raw=decode_raw)
with view.client.stream("GET", f"/dict_stream?db_name={self._db_name}") as r:
buffer = bytearray()
for data in r.iter_bytes():
buffer.extend(data)

remote_db_dict = decode(bytes(buffer))
for dk, dv in remote_db_dict.items():
if dk not in delete_buffer_set:
self._cache_dict[dk] = dv

def stat(self):
if self._cache_all_db:
db_count = len(self._cache_dict)
Expand All @@ -1101,4 +1136,27 @@ def stat(self):
}

def __repr__(self):
    """Render the full cached mapping when mirrored locally, else the key count."""
    if self._cache_all_db:
        # Full local mirror available: show the cache contents directly.
        return repr(self._cache_dict)
    # Only the count is known cheaply without a remote round-trip.
    return repr({"keys": self.stat()["count"]})

def clear(self, wait=True):
    # Clearing the server-side database from a client is not supported;
    # `wait` is presumably kept for signature compatibility with the local
    # backends — confirm.
    raise NotImplementedError

def destroy(self):
    # Destroying the server-side database from a client is not supported.
    raise NotImplementedError

def close(self, write=True, wait=False):
    """
    Closes the database and stops the background worker.

    Args:
        write (bool, optional): Whether to write the buffer to the database before closing. Defaults to True.
        wait (bool, optional): Whether to wait for the background worker to finish. Defaults to False.
    """
    self._close_background_worker(write=write, block=wait)

    # Tear down in dependency order: server-side view, manager, then the
    # underlying remote connection.
    self._db_manager.close_static_view(self._static_view)
    self._db_manager.close()
    self._db_manager.env.close_connection()
    self._logger.info(f"Closed ({self._db_manager.db_type.upper()}) successfully")
38 changes: 38 additions & 0 deletions flaxkv/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,41 @@ async def wrapper(*args, **kwargs):
return encode(result)

return wrapper


def retry(max_retries=3, delay=1, backoff=2, exceptions=(Exception,)):
    """
    A decorator for automatically retrying a function upon encountering specified exceptions.

    Args:
        max_retries (int): The maximum number of retries after the initial attempt
            (i.e. up to ``max_retries + 1`` total calls).
        delay (float): The initial delay between retries in seconds.
        backoff (float): The multiplier by which the delay should increase after each retry.
        exceptions (tuple): A tuple of exception classes upon which to retry.

    Returns:
        The return value of the wrapped function, if it succeeds.
        Raises the last encountered exception if the function never succeeds.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            retries = 0
            current_delay = delay
            # Loop forever; termination is either a successful return or the
            # re-raise below once the retry budget is exhausted.
            while True:
                try:
                    return func(*args, **kwargs)
                except exceptions:
                    retries += 1
                    # Bug fix: the old `retries == max_retries` check gave up
                    # one attempt early, and with max_retries=0 the loop
                    # guard expired without re-raising, silently returning
                    # None and swallowing the exception.
                    if retries > max_retries:
                        raise
                    print(
                        f"Retrying `{func.__name__}` after {current_delay} seconds, retry : {retries}\n"
                    )
                    time.sleep(current_delay)
                    current_delay *= backoff

        return wrapper

    return decorator
Loading

0 comments on commit 408c473

Please sign in to comment.