From c50099d26236475c4c2fc0916b86e7eddc8a16d9 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Wed, 11 Feb 2026 20:50:06 +0800 Subject: [PATCH 1/3] [ReplayBuffer] add ReplayBuffer with various StorageBackend: FIFO, Staleness, or Database(implement in the future) --- tests/ray/test_replay_buffer.py | 48 +++++++++++++++ xtuner/v1/rl/base/replay_buffer.py | 94 ++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 tests/ray/test_replay_buffer.py create mode 100644 xtuner/v1/rl/base/replay_buffer.py diff --git a/tests/ray/test_replay_buffer.py b/tests/ray/test_replay_buffer.py new file mode 100644 index 000000000..5ed58467b --- /dev/null +++ b/tests/ray/test_replay_buffer.py @@ -0,0 +1,48 @@ +import unittest +import asyncio +from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOStorageBackend, StalenessStorageBackend +from xtuner.v1.data_proto.rl_data import RolloutState, Status + +class MockState: + def __init__(self, id, staleness=0): + self.id = id + self.seq_staleness = staleness + +class TestReplayBuffer(unittest.IsolatedAsyncioTestCase): + async def test_fifo_backend(self): + backend = FIFOStorageBackend() + buffer = ReplayBuffer(storage_backend=backend) + states = [MockState(i) for i in range(1, 4)] + + await buffer.put(states, "task1", Status.COMPLETED) + res = await buffer.get(2, "task1", Status.COMPLETED) + + self.assertEqual(len(res), 2) + self.assertEqual(res[0].id, 1) + self.assertEqual(res[1].id, 2) + + async def test_staleness_priority(self): + backend = StalenessStorageBackend(min_staleness=0, max_staleness=5) + buffer = ReplayBuffer(storage_backend=backend) + + s1 = MockState(id="low", staleness=1) + s5 = MockState(id="high", staleness=5) + + await buffer.put([s1], "task1", Status.COMPLETED) + await buffer.put([s5], "task1", Status.COMPLETED) + + res = await buffer.get(2, "task1", Status.COMPLETED) + self.assertEqual(res[0].id, "high") + self.assertEqual(res[1].id, "low") + + async def test_multi_task(self): + buffer = ReplayBuffer() + await buffer.put([MockState(100)], "task_a", Status.COMPLETED) + await buffer.put([MockState(200)], "task_b", Status.COMPLETED) + + res_a = await buffer.get(10, "task_a", Status.COMPLETED) + res_b = await buffer.get(10, "task_b", Status.COMPLETED) + self.assertEqual(len(res_a), 1) + self.assertEqual(res_a[0].id, 100) + self.assertEqual(len(res_b), 1) + self.assertEqual(res_b[0].id, 200) \ No newline at end of file diff --git a/xtuner/v1/rl/base/replay_buffer.py b/xtuner/v1/rl/base/replay_buffer.py new file mode 100644 index 000000000..72f654733 --- /dev/null +++ b/xtuner/v1/rl/base/replay_buffer.py @@ -0,0 +1,94 @@ +import asyncio +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass + +from xtuner.v1.data_proto.rl_data import RolloutState, Status + + +@dataclass(frozen=True) +class StorageIndices: + # 为不同存储后段提供统一的接口 + task_name: str | None = None + group_status: Status | None = None + + def get_key(self): + # 给用户留出重新定义索引的接口 + return (self.task_name, self.group_status) + + +class StorageBackend(ABC): + @abstractmethod + def put(self, items: list[RolloutState], storage_indices: StorageIndices): ... + @abstractmethod + def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ... + + +class FIFOStorageBackend(StorageBackend): + # 普通的先进先出,用完就丢,不持久保存,目前同步应该就够用了 + def __init__(self, limit: int = 0): + self.limit = limit + self._storage = defaultdict(list) + + def put(self, items: list[RolloutState], storage_indices: StorageIndices): + indices = storage_indices.get_key() + target_list = self._storage[indices] + target_list.extend(items) + if self.limit > 0 and len(target_list) > self.limit: + self._storage[indices] = target_list[-self.limit :] + + def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: + indices = storage_indices.get_key() + target_count = min(count, len(self._storage[indices])) + target_items = self._storage[indices][:target_count] + self._storage[indices] = self._storage[indices][target_count:] + return target_items + + +class StalenessStorageBackend(StorageBackend): + # xtuner v1的异步的replay buffer的实现,同样不持久保存 + # TODO(@duanyanhui): 还没实现completed/aborted/expired状态的切换,这个考虑下在哪里完成 + def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0): + self.limit = limit + self.max_staleness = max_staleness + self.min_staleness = min_staleness + self._storage = defaultdict(lambda: {i: [] for i in range(min_staleness, max_staleness + 1)}) + self._bucket_counts = defaultdict(int) + + def put(self, items: list[RolloutState], storage_indices: StorageIndices): + indices = storage_indices.get_key() + group_seq_staleness = max([item.seq_staleness for item in items]) + self._storage[indices][group_seq_staleness].extend(items) + self._bucket_counts[indices] += len(items) + + def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: + indices = storage_indices.get_key() + if self._bucket_counts[indices] == 0: + return [] + + target_items = [] + for s in range(self.max_staleness, self.min_staleness - 1, -1): + cur_bucket = self._storage[indices][s] + needed = count - len(target_items) + take = min(len(cur_bucket), needed) + target_items.extend(cur_bucket[:take]) + self._storage[indices][s] = self._storage[indices][s][take:] + self._bucket_counts[indices] -= take + + if len(target_items) >= count: + break + return target_items + + +class ReplayBuffer: + def __init__(self, storage_backend: StorageBackend = None): + self._storage = FIFOStorageBackend() if storage_backend is None else storage_backend + self._lock = asyncio.Lock() + + async def put(self, items: list[RolloutState], task_name: str, group_status: Status): + async with self._lock: + self._storage.put(items, StorageIndices(task_name=task_name, group_status=group_status)) + + async def get(self, batch_size: int, task_name: str, group_status: Status) -> list[RolloutState]: + async with self._lock: + return self._storage.get(batch_size, StorageIndices(task_name=task_name, group_status=group_status)) From 0aba58064f4a791d52334fa8983263ed5d6a3af9 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Thu, 12 Feb 2026 13:46:13 +0800 Subject: [PATCH 2/3] [ReplayBuffer] optimize implementation of ReplayBuffer --- xtuner/v1/rl/base/replay_buffer.py | 200 +++++++++++++++++++++++------ 1 file changed, 164 insertions(+), 36 deletions(-) diff --git a/xtuner/v1/rl/base/replay_buffer.py b/xtuner/v1/rl/base/replay_buffer.py index 72f654733..1f146ff10 100644 --- a/xtuner/v1/rl/base/replay_buffer.py +++ b/xtuner/v1/rl/base/replay_buffer.py @@ -1,49 +1,60 @@ import asyncio from abc import ABC, abstractmethod -from collections import defaultdict -from dataclasses import dataclass +from collections import defaultdict, deque +from dataclasses import dataclass, field from xtuner.v1.data_proto.rl_data import RolloutState, Status -@dataclass(frozen=True) +@dataclass class StorageIndices: - # 为不同存储后段提供统一的接口 + # 为不同存储后段提供统一的索引接口 task_name: str | None = None group_status: Status | None = None - - def get_key(self): - # 给用户留出重新定义索引的接口 - return (self.task_name, self.group_status) + tags: dict = field(default_factory=dict) # 非等于的条件则使用 scores_gt > 0.8 class StorageBackend(ABC): @abstractmethod - def put(self, items: list[RolloutState], storage_indices: StorageIndices): ... + async def put(self, items: list[RolloutState], storage_indices: StorageIndices): ... + @abstractmethod + async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ... @abstractmethod - def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ... + def __len__(self): ... class FIFOStorageBackend(StorageBackend): # 普通的先进先出,用完就丢,不持久保存,目前同步应该就够用了 def __init__(self, limit: int = 0): self.limit = limit - self._storage = defaultdict(list) + if limit > 0: + self._storage = defaultdict(lambda: deque(maxlen=limit)) + else: + self._storage = defaultdict(deque) - def put(self, items: list[RolloutState], storage_indices: StorageIndices): - indices = storage_indices.get_key() - target_list = self._storage[indices] - target_list.extend(items) - if self.limit > 0 and len(target_list) > self.limit: - self._storage[indices] = target_list[-self.limit :] + async def put(self, items: list[RolloutState], storage_indices: StorageIndices): + indices = self._hash_storage_indices(storage_indices) + self._storage[indices].extend(items) - def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: - indices = storage_indices.get_key() + async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: + indices = self._hash_storage_indices(storage_indices) target_count = min(count, len(self._storage[indices])) - target_items = self._storage[indices][:target_count] - self._storage[indices] = self._storage[indices][target_count:] + target_items = [] + for _ in range(target_count): + target_items.append(self._storage[indices].popleft()) return target_items + def _hash_storage_indices(self, indices: StorageIndices) -> tuple: + base = (indices.task_name, indices.group_status) + + if indices.tags: + sorted_tags = tuple(sorted(indices.tags.items())) + return base + sorted_tags + return base + + def __len__(self): + return sum(len(q) for q in self._storage.values()) + class StalenessStorageBackend(StorageBackend): # xtuner v1的异步的replay buffer的实现,同样不持久保存 @@ -52,43 +63,160 @@ def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = self.limit = limit self.max_staleness = max_staleness self.min_staleness = min_staleness - self._storage = defaultdict(lambda: {i: [] for i in range(min_staleness, max_staleness + 1)}) + self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)}) self._bucket_counts = defaultdict(int) - def put(self, items: list[RolloutState], storage_indices: StorageIndices): - indices = storage_indices.get_key() + async def put(self, items: list[RolloutState], storage_indices: StorageIndices): + indices = self._hash_storage_indices(storage_indices) group_seq_staleness = max([item.seq_staleness for item in items]) self._storage[indices][group_seq_staleness].extend(items) self._bucket_counts[indices] += len(items) - def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: - indices = storage_indices.get_key() + async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: + indices = self._hash_storage_indices(storage_indices) if self._bucket_counts[indices] == 0: return [] target_items = [] + needed = count + for s in range(self.max_staleness, self.min_staleness - 1, -1): + if needed <= 0: + break cur_bucket = self._storage[indices][s] - needed = count - len(target_items) take = min(len(cur_bucket), needed) - target_items.extend(cur_bucket[:take]) - self._storage[indices][s] = self._storage[indices][s][take:] + for _ in range(take): + target_items.append(cur_bucket.popleft()) self._bucket_counts[indices] -= take - - if len(target_items) >= count: - break + needed -= take return target_items + def _hash_storage_indices(self, indices: StorageIndices) -> tuple: + base = (indices.task_name, indices.group_status) + + if indices.tags: + sorted_tags = tuple(sorted(indices.tags.items())) + return base + sorted_tags + return base + + def __len__(self): + return sum(count for count in self._bucket_counts.values()) + + +class PandasStorageBackend(StorageBackend): + def __init__(self, limit: int = 0): + raise NotImplementedError("PandasStorageBackend is under development and not yet implemented.") + import pandas as pd + + self._df = pd.DataFrame(columns=["task_name", "group_status", "data"]) + + def __len__(self): ... + async def put(self, items: list[RolloutState], indices: StorageIndices): + import pandas as pd + + new_rows = [] + base_info = {"task_name": indices.task_name, "group_status": indices.group_status, **indices.tags} + + for item in items: + row = base_info.copy() + row["data"] = item + new_rows.append(row) + + new_df = pd.DataFrame(new_rows) + self._df = pd.concat([self._df, new_df], ignore_index=True, sort=False) + + def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: + if self._df.empty: + return [] + mask = (self._df["task_name"] == indices.task_name) & (self._df["group_status"] == indices.group_status) + for key, value in indices.tags.items(): + if key in self._df.columns: + mask &= self._df[key] == value + else: + return [] + target_df = self._df[mask].head(count) + if target_df.empty: + return [] + result = target_df["data"].tolist() + self._df.drop(target_df.index, inplace=True) + return result + + +class SQLStorageBackend(StorageBackend): + def __init__(self, db_path: str = ":memory:"): + raise NotImplementedError("SQLStorageBackend is under development and not yet implemented.") + self.db_path = db_path + self._init_db() + + def _init_db(self): ... + def _serialize_item(self, item: RolloutState) -> bytes: ... + def _deserialize_item(self, blob: bytes) -> RolloutState: ... + def __len__(self): ... + + async def put(self, items: list[RolloutState], indices: StorageIndices): + import json + import sqlite3 + + rows = [] + tags_json = json.dumps(indices.tags) + + for item in items: + data_blob = self._serialize_item(item) + rows.append((indices.task_name, indices.group_status, tags_json, data_blob)) + + with sqlite3.connect(self.db_path) as conn: + conn.executemany( + "INSERT INTO replay_buffer (task_name, group_status, tags, data) VALUES (?, ?, ?, ?)", rows + ) + + async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: + import sqlite3 + + # 构建动态查询 + query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?" + params = [indices.task_name, indices.group_status] + + # SQLite 的 JSON 查询语法 (需要 SQLite 3.38+,如果是旧版本需要用 LIKE 模拟或不做 DB 级过滤) + # 这里演示简单的方法:如果在 Python 端过滤 tags 效率低,但在 SQL 端过滤 JSON 语法较复杂。 + # 为了通用性,这里我只用 task 和 status 查出候选集,然后用 Python 过滤 Tags (如果 tags 很复杂建议把 tags 独立成列) + # 或者使用 JSON_EXTRACT (推荐) + for key, value in indices.tags.items(): + # 注意:JSON 中数值和字符串的区别。这里假设 value 都是简单类型。 + # $.key 取出对应的值 + query += f" AND json_extract(tags, '$.{key}') = ?" + params.append(value) + + query += f" LIMIT {count}" + + results = [] + ids_to_delete = [] + + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute(query, params) + rows = cursor.fetchall() + + for row_id, data_blob in rows: + results.append(self._deserialize_item(data_blob)) + ids_to_delete.append(row_id) + + if ids_to_delete: + placeholders = ",".join("?" for _ in ids_to_delete) + conn.execute(f"DELETE FROM replay_buffer WHERE id IN ({placeholders})", ids_to_delete) + + return results + class ReplayBuffer: def __init__(self, storage_backend: StorageBackend = None): self._storage = FIFOStorageBackend() if storage_backend is None else storage_backend self._lock = asyncio.Lock() - async def put(self, items: list[RolloutState], task_name: str, group_status: Status): + async def put(self, items: list[RolloutState], task_name: str, group_status: Status, **kwargs) -> None: + indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs) async with self._lock: - self._storage.put(items, StorageIndices(task_name=task_name, group_status=group_status)) + await self._storage.put(items, indices) - async def get(self, batch_size: int, task_name: str, group_status: Status) -> list[RolloutState]: + async def get(self, batch_size: int, task_name: str, group_status: Status, **kwargs) -> list[RolloutState]: + indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs) async with self._lock: - return self._storage.get(batch_size, StorageIndices(task_name=task_name, group_status=group_status)) + return await self._storage.get(batch_size, indices) From c6ce6d1bc9d874dbad01a0a416639938cc1befb2 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Wed, 25 Feb 2026 14:06:32 +0800 Subject: [PATCH 3/3] fix comments: add NaiveStorage and take fifo/staleness as policy for getting item --- tests/ray/test_replay_buffer.py | 17 ++-- xtuner/v1/data_proto/rl_data.py | 45 +++++++++ xtuner/v1/rl/base/replay_buffer.py | 143 +++++++++++++++-------------- 3 files changed, 127 insertions(+), 78 deletions(-) diff --git a/tests/ray/test_replay_buffer.py b/tests/ray/test_replay_buffer.py index 5ed58467b..78b367b54 100644 --- a/tests/ray/test_replay_buffer.py +++ b/tests/ray/test_replay_buffer.py @@ -1,20 +1,21 @@ import unittest import asyncio -from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOStorageBackend, StalenessStorageBackend +from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOBackend, StalenessBackend from xtuner.v1.data_proto.rl_data import RolloutState, Status class MockState: def __init__(self, id, staleness=0): self.id = id self.seq_staleness = staleness + self.status = Status.COMPLETED class TestReplayBuffer(unittest.IsolatedAsyncioTestCase): async def test_fifo_backend(self): - backend = FIFOStorageBackend() + backend = FIFOBackend() buffer = ReplayBuffer(storage_backend=backend) states = [MockState(i) for i in range(1, 4)] - await buffer.put(states, "task1", Status.COMPLETED) + await buffer.put(states, "task1") res = await buffer.get(2, "task1", Status.COMPLETED) self.assertEqual(len(res), 2) @@ -22,14 +23,14 @@ async def test_fifo_backend(self): self.assertEqual(res[1].id, 2) async def test_staleness_priority(self): - backend = StalenessStorageBackend(min_staleness=0, max_staleness=5) + backend = StalenessBackend(min_staleness=0, max_staleness=5) buffer = ReplayBuffer(storage_backend=backend) s1 = MockState(id="low", staleness=1) s5 = MockState(id="high", staleness=5) - await buffer.put([s1], "task1", Status.COMPLETED) - await buffer.put([s5], "task1", Status.COMPLETED) + await buffer.put([s1], "task1") + await buffer.put([s5], "task1") res = await buffer.get(2, "task1", Status.COMPLETED) self.assertEqual(res[0].id, "high") @@ -37,8 +38,8 @@ async def test_staleness_priority(self): async def test_multi_task(self): buffer = ReplayBuffer() - await buffer.put([MockState(100)], "task_a", Status.COMPLETED) - await buffer.put([MockState(200)], "task_b", Status.COMPLETED) + await buffer.put([MockState(100)], "task_a") + await buffer.put([MockState(200)], "task_b") res_a = await buffer.get(10, "task_a", Status.COMPLETED) res_b = await buffer.get(10, "task_b", Status.COMPLETED) diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index c3df4b984..6b4783ec8 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -95,6 +95,7 @@ class RolloutState(BaseModel): reward: float | list[float] | list[dict] | None = None # --- 状态 --- + task_name: str | None = None status: Status = Status.INIT error_msg: str | None = None seq_staleness: int = 0 # 整条序列的staleness,一般为最大的token_staleness @@ -139,3 +140,47 @@ def update_status_from_finish_reason(finish_reason: str | None) -> Status: else: logger.error(f"finish_reason '{finish_reason}' is unknown, setting status to FAILED.") return Status.FAILED + + +def update_group_status(rollout_states: list[RolloutState]) -> Status: + """Updates the group status based on the individual rollout states. + + Group Status Logic: + ------------------------------------------------------------- + | Individual Rollout States | Group Status (Output) | + | :----------------------------- | :----------------------- | + | All `Status.COMPLETED` | `Status.COMPLETED` | + | Any `Status.FAILED` | `Status.FAILED` | + | Any `Status.ABORTED` | `Status.ABORTED` | + | Any `Status.EXPIRED` | `Status.EXPIRED` | + | Any `Status.FILTERED` | `Status.FILTERED` | + | *Others* | *Determined by priority*| + ------------------------------------------------------------- + + Priority Order (from highest to lowest): + 1. FAILED + 2. ABORTED + 3. EXPIRED + 4. FILTERED + 5. COMPLETED + + Args: + rollout_states (list[RolloutState]): A list of individual rollout states. + + Returns: + Status: The aggregated group status based on the individual states. + """ + if all(state.status == Status.COMPLETED for state in rollout_states): + return Status.COMPLETED + elif any(state.status == Status.FAILED for state in rollout_states): + return Status.FAILED + elif any(state.status == Status.ABORTED for state in rollout_states): + return Status.ABORTED + elif any(state.status == Status.EXPIRED for state in rollout_states): + return Status.EXPIRED + elif any(state.status == Status.FILTERED for state in rollout_states): + return Status.FILTERED + else: + # If there are other statuses, we can determine the group status based on a defined priority order. + # For now, we will default to COMPLETED if none of the above conditions are met. + return Status.COMPLETED diff --git a/xtuner/v1/rl/base/replay_buffer.py b/xtuner/v1/rl/base/replay_buffer.py index 1f146ff10..33bd5c7a6 100644 --- a/xtuner/v1/rl/base/replay_buffer.py +++ b/xtuner/v1/rl/base/replay_buffer.py @@ -3,7 +3,7 @@ from collections import defaultdict, deque from dataclasses import dataclass, field -from xtuner.v1.data_proto.rl_data import RolloutState, Status +from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status @dataclass @@ -14,7 +14,7 @@ class StorageIndices: tags: dict = field(default_factory=dict) # 非等于的条件则使用 scores_gt > 0.8 -class StorageBackend(ABC): +class Storage(ABC): @abstractmethod async def put(self, items: list[RolloutState], storage_indices: StorageIndices): ... @abstractmethod @@ -23,26 +23,9 @@ async def get(self, count: int, storage_indices: StorageIndices) -> list[Rollout def __len__(self): ... -class FIFOStorageBackend(StorageBackend): - # 普通的先进先出,用完就丢,不持久保存,目前同步应该就够用了 - def __init__(self, limit: int = 0): - self.limit = limit - if limit > 0: - self._storage = defaultdict(lambda: deque(maxlen=limit)) - else: - self._storage = defaultdict(deque) - - async def put(self, items: list[RolloutState], storage_indices: StorageIndices): - indices = self._hash_storage_indices(storage_indices) - self._storage[indices].extend(items) - - async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: - indices = self._hash_storage_indices(storage_indices) - target_count = min(count, len(self._storage[indices])) - target_items = [] - for _ in range(target_count): - target_items.append(self._storage[indices].popleft()) - return target_items +class NaiveStorage(Storage): + def __init__(self): + self._storage = defaultdict(list) def _hash_storage_indices(self, indices: StorageIndices) -> tuple: base = (indices.task_name, indices.group_status) @@ -52,58 +35,23 @@ def _hash_storage_indices(self, indices: StorageIndices) -> tuple: return base + sorted_tags return base - def __len__(self): - return sum(len(q) for q in self._storage.values()) - - -class StalenessStorageBackend(StorageBackend): - # xtuner v1的异步的replay buffer的实现,同样不持久保存 - # TODO(@duanyanhui): 还没实现completed/aborted/expired状态的切换,这个考虑下在哪里完成 - def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0): - self.limit = limit - self.max_staleness = max_staleness - self.min_staleness = min_staleness - self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)}) - self._bucket_counts = defaultdict(int) - async def put(self, items: list[RolloutState], storage_indices: StorageIndices): indices = self._hash_storage_indices(storage_indices) - group_seq_staleness = max([item.seq_staleness for item in items]) - self._storage[indices][group_seq_staleness].extend(items) - self._bucket_counts[indices] += len(items) + self._storage[indices].extend(items) async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: indices = self._hash_storage_indices(storage_indices) - if self._bucket_counts[indices] == 0: - return [] - - target_items = [] - needed = count - - for s in range(self.max_staleness, self.min_staleness - 1, -1): - if needed <= 0: - break - cur_bucket = self._storage[indices][s] - take = min(len(cur_bucket), needed) - for _ in range(take): - target_items.append(cur_bucket.popleft()) - self._bucket_counts[indices] -= take - needed -= take - return target_items - - def _hash_storage_indices(self, indices: StorageIndices) -> tuple: - base = (indices.task_name, indices.group_status) - - if indices.tags: - sorted_tags = tuple(sorted(indices.tags.items())) - return base + sorted_tags - return base + target_list = self._storage[indices] + target_count = min(count, len(target_list)) + result = target_list[:target_count] + self._storage[indices] = target_list[target_count:] + return result def __len__(self): - return sum(count for count in self._bucket_counts.values()) + return sum(len(v) for v in self._storage.values()) -class PandasStorageBackend(StorageBackend): +class PandasStorage(Storage): def __init__(self, limit: int = 0): raise NotImplementedError("PandasStorageBackend is under development and not yet implemented.") import pandas as pd @@ -125,7 +73,7 @@ async def put(self, items: list[RolloutState], indices: StorageIndices): new_df = pd.DataFrame(new_rows) self._df = pd.concat([self._df, new_df], ignore_index=True, sort=False) - def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: + async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: if self._df.empty: return [] mask = (self._df["task_name"] == indices.task_name) & (self._df["group_status"] == indices.group_status) @@ -142,7 +90,7 @@ def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: return result -class SQLStorageBackend(StorageBackend): +class SQLStorage(Storage): def __init__(self, db_path: str = ":memory:"): raise NotImplementedError("SQLStorageBackend is under development and not yet implemented.") self.db_path = db_path @@ -206,12 +154,67 @@ async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]: return results +class FIFOBackend(NaiveStorage): + # 普通的先进先出,用完就丢,不持久保存,目前同步应该就够用了 + def __init__(self, limit: int = 0): + self.limit = limit + self._storage = defaultdict(lambda: deque(maxlen=limit) if limit > 0 else deque()) + + async def put(self, items: list[RolloutState], storage_indices: StorageIndices): + await super().put(items, storage_indices) + + async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: + indices = self._hash_storage_indices(storage_indices) + target_count = min(count, len(self._storage[indices])) + return [self._storage[indices].popleft() for _ in range(target_count)] + + +class StalenessBackend(NaiveStorage): + # xtuner v1的异步的replay buffer的实现,同样不持久保存 + # TODO(@duanyanhui): 还没实现completed/aborted/expired状态的切换,这个考虑下在哪里完成 + def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0): + self.limit = limit + self.max_staleness = max_staleness + self.min_staleness = min_staleness + self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)}) + self._bucket_counts = defaultdict(int) + + async def put(self, items: list[RolloutState], storage_indices: StorageIndices): + indices = self._hash_storage_indices(storage_indices) + group_seq_staleness = max([item.seq_staleness for item in items]) + self._storage[indices][group_seq_staleness].extend(items) + self._bucket_counts[indices] += len(items) + + async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: + indices = self._hash_storage_indices(storage_indices) + if self._bucket_counts[indices] == 0: + return [] + + target_items = [] + needed = count + + for s in range(self.max_staleness, self.min_staleness - 1, -1): + if needed <= 0: + break + cur_bucket = self._storage[indices][s] + take = min(len(cur_bucket), needed) + for _ in range(take): + target_items.append(cur_bucket.popleft()) + self._bucket_counts[indices] -= take + needed -= take + return target_items + + def __len__(self): + return sum(count for count in self._bucket_counts.values()) + + class ReplayBuffer: - def __init__(self, storage_backend: StorageBackend = None): - self._storage = FIFOStorageBackend() if storage_backend is None else storage_backend + def __init__(self, storage_backend: Storage = None): + self._storage = FIFOBackend() if storage_backend is None else storage_backend self._lock = asyncio.Lock() - async def put(self, items: list[RolloutState], task_name: str, group_status: Status, **kwargs) -> None: + async def put(self, items: list[RolloutState], task_name: str, **kwargs) -> None: + group_status = update_group_status(items) indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs) async with self._lock: await self._storage.put(items, indices)