Merged
49 changes: 49 additions & 0 deletions tests/ray/test_replay_buffer.py
@@ -0,0 +1,49 @@
import unittest
import asyncio
from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOBackend, StalenessBackend
from xtuner.v1.data_proto.rl_data import RolloutState, Status
Comment on lines +2 to +4
Copilot AI Feb 12, 2026
tests/ray/test_replay_buffer.py has several unused imports (asyncio, StorageIndices, RolloutState). Cleaning these up avoids confusion about what the tests actually exercise.

Suggested change
- import asyncio
- from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOBackend, StalenessBackend
- from xtuner.v1.data_proto.rl_data import RolloutState, Status
+ from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, FIFOBackend, StalenessBackend
+ from xtuner.v1.data_proto.rl_data import Status


class MockState:
def __init__(self, id, staleness=0):
self.id = id
self.seq_staleness = staleness
self.status = Status.COMPLETED

class TestReplayBuffer(unittest.IsolatedAsyncioTestCase):
async def test_fifo_backend(self):
backend = FIFOBackend()
buffer = ReplayBuffer(storage_backend=backend)
states = [MockState(i) for i in range(1, 4)]

await buffer.put(states, "task1")
res = await buffer.get(2, "task1", Status.COMPLETED)

self.assertEqual(len(res), 2)
self.assertEqual(res[0].id, 1)
self.assertEqual(res[1].id, 2)

async def test_staleness_priority(self):
backend = StalenessBackend(min_staleness=0, max_staleness=5)
buffer = ReplayBuffer(storage_backend=backend)

s1 = MockState(id="low", staleness=1)
s5 = MockState(id="high", staleness=5)

await buffer.put([s1], "task1")
await buffer.put([s5], "task1")

res = await buffer.get(2, "task1", Status.COMPLETED)
self.assertEqual(res[0].id, "high")
self.assertEqual(res[1].id, "low")

async def test_multi_task(self):
buffer = ReplayBuffer()
await buffer.put([MockState(100)], "task_a")
await buffer.put([MockState(200)], "task_b")

res_a = await buffer.get(10, "task_a", Status.COMPLETED)
res_b = await buffer.get(10, "task_b", Status.COMPLETED)
self.assertEqual(len(res_a), 1)
self.assertEqual(res_a[0].id, 100)
self.assertEqual(len(res_b), 1)
self.assertEqual(res_b[0].id, 200)
45 changes: 45 additions & 0 deletions xtuner/v1/data_proto/rl_data.py
@@ -95,6 +95,7 @@ class RolloutState(BaseModel):
reward: float | list[float] | list[dict] | None = None

# --- status ---
task_name: str | None = None
status: Status = Status.INIT
error_msg: str | None = None
seq_staleness: int = 0 # staleness of the whole sequence, usually the maximum token_staleness
@@ -139,3 +140,47 @@ def update_status_from_finish_reason(finish_reason: str | None) -> Status:
else:
logger.error(f"finish_reason '{finish_reason}' is unknown, setting status to FAILED.")
return Status.FAILED


def update_group_status(rollout_states: list[RolloutState]) -> Status:
"""Updates the group status based on the individual rollout states.

Group Status Logic:
-------------------------------------------------------------
| Individual Rollout States | Group Status (Output) |
| :----------------------------- | :----------------------- |
| All `Status.COMPLETED` | `Status.COMPLETED` |
| Any `Status.FAILED` | `Status.FAILED` |
| Any `Status.ABORTED` | `Status.ABORTED` |
| Any `Status.EXPIRED` | `Status.EXPIRED` |
| Any `Status.FILTERED` | `Status.FILTERED` |
| *Others* | *Determined by priority* |
-------------------------------------------------------------

Priority Order (from highest to lowest):
1. FAILED
2. ABORTED
3. EXPIRED
4. FILTERED
5. COMPLETED

Args:
rollout_states (list[RolloutState]): A list of individual rollout states.

Returns:
Status: The aggregated group status based on the individual states.
"""
if all(state.status == Status.COMPLETED for state in rollout_states):
return Status.COMPLETED
elif any(state.status == Status.FAILED for state in rollout_states):
return Status.FAILED
elif any(state.status == Status.ABORTED for state in rollout_states):
return Status.ABORTED
elif any(state.status == Status.EXPIRED for state in rollout_states):
return Status.EXPIRED
elif any(state.status == Status.FILTERED for state in rollout_states):
return Status.FILTERED
else:
# Remaining status mixes could be resolved by a finer-grained priority order;
# for now, default to COMPLETED when none of the conditions above match.
return Status.COMPLETED
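
For reference, the aggregation can be sanity-checked with duck-typed states (a hypothetical snippet mirroring the tests' MockState; update_group_status only reads .status, even though the annotation says list[RolloutState]):

class _S:
    def __init__(self, status):
        self.status = status

# any FAILED member outranks COMPLETED ones
assert update_group_status([_S(Status.COMPLETED), _S(Status.FAILED)]) == Status.FAILED
# a uniformly COMPLETED group stays COMPLETED
assert update_group_status([_S(Status.COMPLETED)] * 3) == Status.COMPLETED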
225 changes: 225 additions & 0 deletions xtuner/v1/rl/base/replay_buffer.py
@@ -0,0 +1,225 @@
import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass, field

from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status


@dataclass
class StorageIndices:
# unified indexing interface across the different storage backends
task_name: str | None = None
group_status: Status | None = None
tags: dict = field(default_factory=dict)  # non-equality conditions use suffixed keys, e.g. scores_gt > 0.8


class Storage(ABC):
@abstractmethod
async def put(self, items: list[RolloutState], storage_indices: StorageIndices): ...
@abstractmethod
async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ...
@abstractmethod
def __len__(self): ...


class NaiveStorage(Storage):
def __init__(self):
self._storage = defaultdict(list)

def _hash_storage_indices(self, indices: StorageIndices) -> tuple:
base = (indices.task_name, indices.group_status)

if indices.tags:
sorted_tags = tuple(sorted(indices.tags.items()))
return base + sorted_tags
return base

async def put(self, items: list[RolloutState], storage_indices: StorageIndices):
indices = self._hash_storage_indices(storage_indices)
self._storage[indices].extend(items)

async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]:
indices = self._hash_storage_indices(storage_indices)
target_list = self._storage[indices]
target_count = min(count, len(target_list))
result = target_list[:target_count]
self._storage[indices] = target_list[target_count:]
return result

def __len__(self):
return sum(len(v) for v in self._storage.values())
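
Tags passed as keyword arguments to ReplayBuffer (defined at the end of this file) become part of the bucket key, so a get() must repeat the exact tags used by the put() to address the same bucket. A hypothetical end-to-end check:

import asyncio

class _State:  # minimal duck-typed rollout state, as in the tests' MockState
    def __init__(self, id):
        self.id = id
        self.seq_staleness = 0
        self.status = Status.COMPLETED

async def demo():
    buffer = ReplayBuffer()  # defaults to FIFOBackend
    await buffer.put([_State(1), _State(2)], "task1", difficulty="hard")
    # same (task_name, group_status, tags) triple -> same bucket
    hit = await buffer.get(2, "task1", Status.COMPLETED, difficulty="hard")
    # omitting the tag hashes to a different, empty bucket
    miss = await buffer.get(2, "task1", Status.COMPLETED)
    assert len(hit) == 2 and miss == []

asyncio.run(demo())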


class PandasStorage(Storage):
def __init__(self, limit: int = 0):
raise NotImplementedError("PandasStorageBackend is under development and not yet implemented.")
import pandas as pd

self._df = pd.DataFrame(columns=["task_name", "group_status", "data"])

def __len__(self): ...
async def put(self, items: list[RolloutState], indices: StorageIndices):
import pandas as pd

new_rows = []
base_info = {"task_name": indices.task_name, "group_status": indices.group_status, **indices.tags}

for item in items:
row = base_info.copy()
row["data"] = item
new_rows.append(row)

new_df = pd.DataFrame(new_rows)
self._df = pd.concat([self._df, new_df], ignore_index=True, sort=False)

async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
if self._df.empty:
return []
mask = (self._df["task_name"] == indices.task_name) & (self._df["group_status"] == indices.group_status)
for key, value in indices.tags.items():
if key in self._df.columns:
mask &= self._df[key] == value
else:
return []
target_df = self._df[mask].head(count)
if target_df.empty:
return []
result = target_df["data"].tolist()
self._df.drop(target_df.index, inplace=True)
return result


class SQLStorage(Storage):
def __init__(self, db_path: str = ":memory:"):
raise NotImplementedError("SQLStorageBackend is under development and not yet implemented.")
self.db_path = db_path
self._init_db()

def _init_db(self): ...
def _serialize_item(self, item: RolloutState) -> bytes: ...
def _deserialize_item(self, blob: bytes) -> RolloutState: ...
def __len__(self): ...

async def put(self, items: list[RolloutState], indices: StorageIndices):
import json
import sqlite3

rows = []
tags_json = json.dumps(indices.tags)

for item in items:
data_blob = self._serialize_item(item)
rows.append((indices.task_name, indices.group_status, tags_json, data_blob))

with sqlite3.connect(self.db_path) as conn:
conn.executemany(
"INSERT INTO replay_buffer (task_name, group_status, tags, data) VALUES (?, ?, ?, ?)", rows
)

async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
import sqlite3

# build the dynamic query
query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?"
params = [indices.task_name, indices.group_status]

# SQLite's JSON functions need SQLite 3.38+; on older versions, emulate with LIKE or skip DB-level filtering.
# Trade-off: filtering tags on the Python side is inefficient, while JSON filtering on the SQL side is syntactically heavier.
# A generic fallback is to fetch candidates by task and status only and filter tags in Python (if tags get complex, promote them to dedicated columns);
# here we use json_extract instead (recommended).
for key, value in indices.tags.items():
# Note: JSON distinguishes numbers from strings; values are assumed to be simple types here.
# '$.key' extracts the value stored under that tag
query += f" AND json_extract(tags, '$.{key}') = ?"
params.append(value)

query += f" LIMIT {count}"

results = []
ids_to_delete = []

with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(query, params)
rows = cursor.fetchall()

for row_id, data_blob in rows:
results.append(self._deserialize_item(data_blob))
ids_to_delete.append(row_id)

if ids_to_delete:
placeholders = ",".join("?" for _ in ids_to_delete)
conn.execute(f"DELETE FROM replay_buffer WHERE id IN ({placeholders})", ids_to_delete)

return results


class FIFOBackend(NaiveStorage):
Collaborator
The FIFO/Staleness logic sits alongside the SQL/Pandas/Naive logic rather than specializing it; the two concerns should be decoupled instead of expressed through inheritance (see the sketch after this class).

# plain FIFO: items are dropped once consumed, nothing is persisted; synchronous use should suffice for now
def __init__(self, limit: int = 0):
self.limit = limit
self._storage = defaultdict(lambda: deque(maxlen=limit) if limit > 0 else deque())

async def put(self, items: list[RolloutState], storage_indices: StorageIndices):
await super().put(items, storage_indices)

async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]:
indices = self._hash_storage_indices(storage_indices)
target_count = min(count, len(self._storage[indices]))
return [self._storage[indices].popleft() for _ in range(target_count)]
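
One shape the suggested decoupling could take (a hypothetical sketch, not part of this PR): the eviction/sampling order becomes a strategy object that a single storage class delegates to, instead of FIFOBackend and StalenessBackend subclassing NaiveStorage.

class SamplingPolicy(ABC):
    # decides which buffered items a get() returns, independent of storage
    @abstractmethod
    def take(self, bucket: deque, count: int) -> list: ...

class FIFOPolicy(SamplingPolicy):
    def take(self, bucket: deque, count: int) -> list:
        # oldest-first, at most `count` items
        return [bucket.popleft() for _ in range(min(count, len(bucket)))]

class PolicyStorage(NaiveStorage):
    # the storage keeps the bucketing/hashing; the policy owns the order
    def __init__(self, policy: SamplingPolicy):
        self._policy = policy
        self._storage = defaultdict(deque)

    async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]:
        bucket = self._storage[self._hash_storage_indices(storage_indices)]
        return self._policy.take(bucket, count)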


class StalenessBackend(NaiveStorage):
# async replay buffer implementation for xtuner v1; likewise not persisted
# TODO(@duanyanhui): the completed/aborted/expired status transitions are not implemented yet; decide where that belongs
def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0):
self.limit = limit
self.max_staleness = max_staleness
self.min_staleness = min_staleness
self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)})
self._bucket_counts = defaultdict(int)

async def put(self, items: list[RolloutState], storage_indices: StorageIndices):
indices = self._hash_storage_indices(storage_indices)
group_seq_staleness = max(item.seq_staleness for item in items)
# NOTE: assumes min_staleness <= group_seq_staleness <= max_staleness;
# a value outside the preallocated buckets would raise KeyError here
self._storage[indices][group_seq_staleness].extend(items)
self._bucket_counts[indices] += len(items)

async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]:
indices = self._hash_storage_indices(storage_indices)
if self._bucket_counts[indices] == 0:
return []

target_items = []
needed = count

for s in range(self.max_staleness, self.min_staleness - 1, -1):
if needed <= 0:
break
cur_bucket = self._storage[indices][s]
take = min(len(cur_bucket), needed)
for _ in range(take):
target_items.append(cur_bucket.popleft())
self._bucket_counts[indices] -= take
needed -= take
return target_items

def __len__(self):
return sum(self._bucket_counts.values())


class ReplayBuffer:
def __init__(self, storage_backend: Storage | None = None):
self._storage = FIFOBackend() if storage_backend is None else storage_backend
self._lock = asyncio.Lock()

async def put(self, items: list[RolloutState], task_name: str, **kwargs) -> None:
group_status = update_group_status(items)
indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
async with self._lock:
await self._storage.put(items, indices)

async def get(self, batch_size: int, task_name: str, group_status: Status, **kwargs) -> list[RolloutState]:
Collaborator
Could we derive group_status from the items ourselves? Asking callers to decide it seems error-prone; if the logic can be pinned down, we should probably just handle it internally (a possible shape is sketched after this method).

indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
async with self._lock:
return await self._storage.get(batch_size, indices)
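
If the logic can indeed be pinned down, a minimal variant of get() might default the status instead of requiring it (hypothetical sketch, not what this PR merges):

async def get(self, batch_size: int, task_name: str,
              group_status: Status = Status.COMPLETED, **kwargs) -> list[RolloutState]:
    # default to COMPLETED so callers that only consume finished rollouts
    # need not choose; put() already derives the status via update_group_status()
    indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
    async with self._lock:
        return await self._storage.get(batch_size, indices)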