-
Notifications
You must be signed in to change notification settings - Fork 419
[ReplayBuffer] add ReplayBuffer with various StorageBackend #1490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| import unittest | ||
| import asyncio | ||
| from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOBackend, StalenessBackend | ||
| from xtuner.v1.data_proto.rl_data import RolloutState, Status | ||
|
|
||
| class MockState: | ||
| def __init__(self, id, staleness=0): | ||
| self.id = id | ||
| self.seq_staleness = staleness | ||
| self.status = Status.COMPLETED | ||
|
|
||
| class TestReplayBuffer(unittest.IsolatedAsyncioTestCase): | ||
| async def test_fifo_backend(self): | ||
| backend = FIFOBackend() | ||
| buffer = ReplayBuffer(storage_backend=backend) | ||
| states = [MockState(i) for i in range(1, 4)] | ||
|
|
||
| await buffer.put(states, "task1") | ||
| res = await buffer.get(2, "task1", Status.COMPLETED) | ||
|
|
||
| self.assertEqual(len(res), 2) | ||
| self.assertEqual(res[0].id, 1) | ||
| self.assertEqual(res[1].id, 2) | ||
|
|
||
| async def test_staleness_priority(self): | ||
| backend = StalenessBackend(min_staleness=0, max_staleness=5) | ||
| buffer = ReplayBuffer(storage_backend=backend) | ||
|
|
||
| s1 = MockState(id="low", staleness=1) | ||
| s5 = MockState(id="high", staleness=5) | ||
|
|
||
| await buffer.put([s1], "task1") | ||
| await buffer.put([s5], "task1") | ||
|
|
||
| res = await buffer.get(2, "task1", Status.COMPLETED) | ||
| self.assertEqual(res[0].id, "high") | ||
| self.assertEqual(res[1].id, "low") | ||
|
|
||
| async def test_multi_task(self): | ||
| buffer = ReplayBuffer() | ||
| await buffer.put([MockState(100)], "task_a") | ||
| await buffer.put([MockState(200)], "task_b") | ||
|
|
||
| res_a = await buffer.get(10, "task_a", Status.COMPLETED) | ||
| res_b = await buffer.get(10, "task_b", Status.COMPLETED) | ||
| self.assertEqual(len(res_a), 1) | ||
| self.assertEqual(res_a[0].id, 100) | ||
| self.assertEqual(len(res_b), 1) | ||
| self.assertEqual(res_b[0].id, 200) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,225 @@ | ||
| import asyncio | ||
| from abc import ABC, abstractmethod | ||
| from collections import defaultdict, deque | ||
| from dataclasses import dataclass, field | ||
|
|
||
| from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status | ||
|
|
||
|
|
||
| @dataclass | ||
| class StorageIndices: | ||
| # 为不同存储后段提供统一的索引接口 | ||
|
YanhuiDua marked this conversation as resolved.
|
||
| task_name: str | None = None | ||
| group_status: Status | None = None | ||
| tags: dict = field(default_factory=dict) # 非等于的条件则使用 scores_gt > 0.8 | ||
|
|
||
|
YanhuiDua marked this conversation as resolved.
|
||
|
|
||
| class Storage(ABC): | ||
| @abstractmethod | ||
| async def put(self, items: list[RolloutState], storage_indices: StorageIndices): ... | ||
| @abstractmethod | ||
| async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ... | ||
| @abstractmethod | ||
| def __len__(self): ... | ||
|
|
||
|
|
||
| class NaiveStorage(Storage): | ||
| def __init__(self): | ||
| self._storage = defaultdict(list) | ||
|
|
||
| def _hash_storage_indices(self, indices: StorageIndices) -> tuple: | ||
| base = (indices.task_name, indices.group_status) | ||
|
|
||
| if indices.tags: | ||
| sorted_tags = tuple(sorted(indices.tags.items())) | ||
| return base + sorted_tags | ||
| return base | ||
|
YanhuiDua marked this conversation as resolved.
|
||
|
|
||
| async def put(self, items: list[RolloutState], storage_indices: StorageIndices): | ||
| indices = self._hash_storage_indices(storage_indices) | ||
| self._storage[indices].extend(items) | ||
|
|
||
| async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: | ||
| indices = self._hash_storage_indices(storage_indices) | ||
| target_list = self._storage[indices] | ||
| target_count = min(count, len(target_list)) | ||
| result = target_list[:target_count] | ||
| self._storage[indices] = target_list[target_count:] | ||
| return result | ||
|
|
||
| def __len__(self): | ||
| return sum(len(v) for v in self._storage.values()) | ||
|
|
||
|
|
||
| class PandasStorage(Storage): | ||
| def __init__(self, limit: int = 0): | ||
| raise NotImplementedError("PandasStorageBackend is under development and not yet implemented.") | ||
| import pandas as pd | ||
|
YanhuiDua marked this conversation as resolved.
|
||
|
|
||
| self._df = pd.DataFrame(columns=["task_name", "group_status", "data"]) | ||
|
|
||
| def __len__(self): ... | ||
| async def put(self, items: list[RolloutState], indices: StorageIndices): | ||
| import pandas as pd | ||
|
|
||
| new_rows = [] | ||
| base_info = {"task_name": indices.task_name, "group_status": indices.group_status, **indices.tags} | ||
|
|
||
| for item in items: | ||
| row = base_info.copy() | ||
| row["data"] = item | ||
| new_rows.append(row) | ||
|
|
||
| new_df = pd.DataFrame(new_rows) | ||
| self._df = pd.concat([self._df, new_df], ignore_index=True, sort=False) | ||
|
|
||
| async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: | ||
| if self._df.empty: | ||
| return [] | ||
| mask = (self._df["task_name"] == indices.task_name) & (self._df["group_status"] == indices.group_status) | ||
| for key, value in indices.tags.items(): | ||
| if key in self._df.columns: | ||
| mask &= self._df[key] == value | ||
| else: | ||
| return [] | ||
| target_df = self._df[mask].head(count) | ||
| if target_df.empty: | ||
| return [] | ||
| result = target_df["data"].tolist() | ||
| self._df.drop(target_df.index, inplace=True) | ||
| return result | ||
|
|
||
|
|
||
| class SQLStorage(Storage): | ||
| def __init__(self, db_path: str = ":memory:"): | ||
| raise NotImplementedError("SQLStorageBackend is under development and not yet implemented.") | ||
| self.db_path = db_path | ||
|
YanhuiDua marked this conversation as resolved.
|
||
| self._init_db() | ||
|
|
||
| def _init_db(self): ... | ||
| def _serialize_item(self, item: RolloutState) -> bytes: ... | ||
| def _deserialize_item(self, blob: bytes) -> RolloutState: ... | ||
| def __len__(self): ... | ||
|
|
||
| async def put(self, items: list[RolloutState], indices: StorageIndices): | ||
| import json | ||
| import sqlite3 | ||
|
|
||
| rows = [] | ||
| tags_json = json.dumps(indices.tags) | ||
|
|
||
| for item in items: | ||
| data_blob = self._serialize_item(item) | ||
| rows.append((indices.task_name, indices.group_status, tags_json, data_blob)) | ||
|
|
||
| with sqlite3.connect(self.db_path) as conn: | ||
| conn.executemany( | ||
| "INSERT INTO replay_buffer (task_name, group_status, tags, data) VALUES (?, ?, ?, ?)", rows | ||
| ) | ||
|
|
||
| async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: | ||
| import sqlite3 | ||
|
|
||
| # 构建动态查询 | ||
| query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?" | ||
| params = [indices.task_name, indices.group_status] | ||
|
|
||
| # SQLite 的 JSON 查询语法 (需要 SQLite 3.38+,如果是旧版本需要用 LIKE 模拟或不做 DB 级过滤) | ||
| # 这里演示简单的方法:如果在 Python 端过滤 tags 效率低,但在 SQL 端过滤 JSON 语法较复杂。 | ||
| # 为了通用性,这里我只用 task 和 status 查出候选集,然后用 Python 过滤 Tags (如果 tags 很复杂建议把 tags 独立成列) | ||
| # 或者使用 JSON_EXTRACT (推荐) | ||
| for key, value in indices.tags.items(): | ||
| # 注意:JSON 中数值和字符串的区别。这里假设 value 都是简单类型。 | ||
| # $.key 取出对应的值 | ||
| query += f" AND json_extract(tags, '$.{key}') = ?" | ||
| params.append(value) | ||
|
|
||
| query += f" LIMIT {count}" | ||
|
|
||
|
YanhuiDua marked this conversation as resolved.
|
||
| results = [] | ||
| ids_to_delete = [] | ||
|
|
||
| with sqlite3.connect(self.db_path) as conn: | ||
| cursor = conn.execute(query, params) | ||
| rows = cursor.fetchall() | ||
|
|
||
| for row_id, data_blob in rows: | ||
| results.append(self._deserialize_item(data_blob)) | ||
| ids_to_delete.append(row_id) | ||
|
|
||
| if ids_to_delete: | ||
| placeholders = ",".join("?" for _ in ids_to_delete) | ||
| conn.execute(f"DELETE FROM replay_buffer WHERE id IN ({placeholders})", ids_to_delete) | ||
|
|
||
| return results | ||
|
|
||
|
|
||
| class FIFOBackend(NaiveStorage): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FIFO/Staleness的逻辑与SQL/Pandas/Naive的逻辑并列,不是继承关系,应该解耦开 |
||
| # 普通的先进先出,用完就丢,不持久保存,目前同步应该就够用了 | ||
| def __init__(self, limit: int = 0): | ||
| self.limit = limit | ||
| self._storage = defaultdict(lambda: deque(maxlen=limit) if limit > 0 else deque()) | ||
|
|
||
| async def put(self, items: list[RolloutState], storage_indices: StorageIndices): | ||
| await super().put(items, storage_indices) | ||
|
|
||
| async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: | ||
| indices = self._hash_storage_indices(storage_indices) | ||
| target_count = min(count, len(self._storage[indices])) | ||
| return [self._storage[indices].popleft() for _ in range(target_count)] | ||
|
|
||
|
|
||
| class StalenessBackend(NaiveStorage): | ||
| # xtuner v1的异步的replay buffer的实现,同样不持久保存 | ||
| # TODO(@duanyanhui): 还没实现completed/aborted/expired状态的切换,这个考虑下在哪里完成 | ||
| def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0): | ||
| self.limit = limit | ||
| self.max_staleness = max_staleness | ||
| self.min_staleness = min_staleness | ||
| self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)}) | ||
| self._bucket_counts = defaultdict(int) | ||
|
|
||
| async def put(self, items: list[RolloutState], storage_indices: StorageIndices): | ||
| indices = self._hash_storage_indices(storage_indices) | ||
| group_seq_staleness = max([item.seq_staleness for item in items]) | ||
| self._storage[indices][group_seq_staleness].extend(items) | ||
| self._bucket_counts[indices] += len(items) | ||
|
|
||
| async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: | ||
| indices = self._hash_storage_indices(storage_indices) | ||
| if self._bucket_counts[indices] == 0: | ||
| return [] | ||
|
|
||
| target_items = [] | ||
| needed = count | ||
|
|
||
| for s in range(self.max_staleness, self.min_staleness - 1, -1): | ||
| if needed <= 0: | ||
| break | ||
| cur_bucket = self._storage[indices][s] | ||
| take = min(len(cur_bucket), needed) | ||
| for _ in range(take): | ||
| target_items.append(cur_bucket.popleft()) | ||
| self._bucket_counts[indices] -= take | ||
| needed -= take | ||
| return target_items | ||
|
|
||
| def __len__(self): | ||
| return sum(count for count in self._bucket_counts.values()) | ||
|
|
||
|
|
||
| class ReplayBuffer: | ||
| def __init__(self, storage_backend: Storage = None): | ||
| self._storage = FIFOBackend() if storage_backend is None else storage_backend | ||
| self._lock = asyncio.Lock() | ||
|
|
||
| async def put(self, items: list[RolloutState], task_name: str, **kwargs) -> None: | ||
| group_status = update_group_status(items) | ||
| indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs) | ||
| async with self._lock: | ||
| await self._storage.put(items, indices) | ||
|
|
||
| async def get(self, batch_size: int, task_name: str, group_status: Status, **kwargs) -> list[RolloutState]: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. group_state 我们可以自己基于 items 算出来吗?我感觉让用户自己决定可能比较难,如果这是一套可以固定下来的逻辑,是不是我们自己处理就行。 |
||
| indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs) | ||
| async with self._lock: | ||
| return await self._storage.get(batch_size, indices) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tests/ray/test_replay_buffer.pyhas several unused imports (asyncio,StorageIndices,RolloutState). Cleaning these up avoids confusion about what the tests actually exercise.