[Producer] Add Sampler and ProduceStrategy #1491

Merged
YanhuiDua merged 2 commits into InternLM:rl_design from YanhuiDua:dyh/dev_produce_strategy
Feb 26, 2026

Collaborator

@YanhuiDua YanhuiDua commented Feb 25, 2026

producer.py

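Handles the entire flow that takes a batch of data from Prompts to Trajectory, where:

  1. How each sample is generated is determined by AgentLoop.generate_group() and AgentLoop.generate(): generate_group generates a group of samples, and generate generates a single sample.
  2. The AgentLoop input, RolloutState, is produced by a Sampler. The simplest Sampler samples from the dataloader. To support asynchronous production, use SamplerWithReplayBuffer, which can also sample interrupted, expired, and similar samples from the ReplayBuffer.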

Usage example:

strategy.produce_batch(agent_loop, sampler, batch_size, task_name)

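Note @hhaAndroid
Where the final implementation differs from the original design: the Dataloader is not passed in when ProduceStrategy is initialized. Instead, a Sampler instance is passed in the call to produce_batch. The benefit is that the dataset is not bound to the production strategy. When there are several datasets, they can be handled with the code below, and prompt_repeat_k becomes a Sampler configuration option.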

data_batch_1 = sync_strategy.produce_batch(agent_loop_1, sampler, 10, task_name)
data_batch_2 = async_strategy.produce_batch(agent_loop_2, sampler_with_buffer, 20, task_name)
data_batch_3 = sync_strategy.produce_batch(agent_loop_3, sampler_3, 30, task_name)

ProduceStrategy

class ProduceStrategy(ABC):
    def __init__(self, replay_buffer: ReplayBuffer):...
    @abstractmethod
    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str): ...

# Synchronous production strategy: generates enough samples to fill batch_size
class SyncProduceStrategy(ProduceStrategy):
    async def produce_batch(self, agent_loop: AgentLoop, sampler: Sampler, batch_size: int, task_name: str): ...

# Asynchronous production strategy: builds on the synchronous one and adds over-issuing and pause logic
class AsyncProduceStrategy(ProduceStrategy):
    def __init__(self, replay_buffer, staleness_threshold, enable_partial_rollout, tail_batch_trigger_size, tail_batch_candidate_step): ...
    async def produce_batch(self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int, task_name: str): ...

Sampler

# Wraps the Dataloader iteration logic and supports resume
class Sampler:
    def __init__(self, dataloader_cfg: DataloaderConfig, tokenizer):...
    def sample(self) -> RolloutState:...

# Adds support for sampling from the ReplayBuffer
class SamplerWithReplayBuffer(Sampler):
    def __init__(self, task_name: str, dataloader_cfg: DataloaderConfig, tokenizer, replay_buffer):...
    def sample(self) -> list[RolloutState]:...
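
The relationship between the two samplers can be sketched roughly as follows. This is an illustration only: ToyReplayBuffer, ToySampler and ToySamplerWithReplayBuffer are hypothetical stand-ins, strings replace RolloutState, and serving buffered samples before fresh ones is an assumption rather than something this PR states.

```python
from collections import deque
from typing import Deque, List, Optional


class ToyReplayBuffer:
    """Hypothetical buffer of samples that were interrupted or expired."""

    def __init__(self) -> None:
        self._unfinished: Deque[str] = deque()

    def put_unfinished(self, sample: str) -> None:
        self._unfinished.append(sample)

    def pop_unfinished(self) -> Optional[str]:
        return self._unfinished.popleft() if self._unfinished else None


class ToySampler:
    """Hypothetical base sampler: hands out fresh prompts in order."""

    def __init__(self, prompts: List[str]) -> None:
        self._prompts = list(prompts)
        self._pos = 0

    def sample(self) -> str:
        prompt = self._prompts[self._pos]
        self._pos += 1
        return prompt


class ToySamplerWithReplayBuffer(ToySampler):
    """Assumed policy: drain unfinished samples first, then fall back to fresh data."""

    def __init__(self, prompts: List[str], replay_buffer: ToyReplayBuffer) -> None:
        super().__init__(prompts)
        self._replay_buffer = replay_buffer

    def sample(self) -> str:
        buffered = self._replay_buffer.pop_unfinished()
        if buffered is not None:
            return buffered
        return super().sample()


buffer = ToyReplayBuffer()
sampler = ToySamplerWithReplayBuffer(["p0", "p1"], buffer)
first = sampler.sample()              # fresh prompt, buffer is empty
buffer.put_unfinished("p0-partial")   # pretend the rollout of p0 was interrupted
second = sampler.sample()             # the interrupted sample is served again
third = sampler.sample()              # buffer is empty, back to fresh data
```

Because the strategy only ever calls sample(), swapping the plain sampler for the buffered one does not change the calling code.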

TODO

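  • Where to put the logic that converts a user's own dataset into RolloutState
  • The logic for whether to replenish data after filtering has not been added yet
  • task_name and pause are temporarily written as members of agent_loop; how to implement these two features needs more thought
  • The unit tests currently cover only the over-issuing logic; the other asynchronous logic is not tested yet
  • Once all modules are ready, check whether combining them is convenient for users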

Comment thread xtuner/v1/rl/base/producer.py Outdated
except Exception as e:
print(f"Error in generating trajectory: {e}")

if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count:
Collaborator

Should be len(pending_tasks) + completed_sample_count < data_concurrency?

Collaborator Author

@YanhuiDua YanhuiDua Feb 26, 2026

OK, the synchronous case can be written that way directly.
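
A minimal sketch of the simplified synchronous condition, assuming nothing is carried over from earlier calls (so the initial completed count is zero) and that the concurrency limit equals batch_size. ToySampler, ToyAgentLoop and the standalone produce_batch are hypothetical stand-ins, not the code in producer.py.

```python
import asyncio


class ToySampler:
    """Hypothetical stand-in for Sampler: yields integer prompts."""

    def __init__(self) -> None:
        self._next_id = 0

    def sample(self) -> int:
        self._next_id += 1
        return self._next_id


class ToyAgentLoop:
    """Hypothetical stand-in for AgentLoop.generate_group()."""

    async def generate_group(self, state: int) -> str:
        await asyncio.sleep(0)  # pretend to call the rollout engine
        return f"trajectory-{state}"


async def produce_batch(agent_loop, sampler, batch_size: int) -> list:
    pending = set()
    completed = []
    while len(completed) < batch_size:
        # Top up only while in-flight plus finished work is below the limit,
        # so the synchronous path never generates more than batch_size samples.
        while len(pending) + len(completed) < batch_size:
            state = sampler.sample()
            pending.add(asyncio.create_task(agent_loop.generate_group(state)))
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        completed.extend(task.result() for task in done)
    return completed


batch = asyncio.run(produce_batch(ToyAgentLoop(), ToySampler(), batch_size=4))
```

In this sketch no task can be left over when the loop exits, which is consistent with the point made later in the thread that the synchronous path has nothing extra to clear.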

Comment thread xtuner/v1/rl/base/producer.py Outdated
async def produce_batch(
self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int, prompt_k: int
):
data_concurrency = (1 + self.staleness_threshold) * batch_size
Collaborator

ceil and cast to int
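
One way to read this suggestion, as a sketch: when staleness_threshold is fractional the product is a float, and it may not be a whole number. The helper name compute_data_concurrency is hypothetical; in Python 3, math.ceil already returns an int.

```python
import math


def compute_data_concurrency(batch_size: int, staleness_threshold: float) -> int:
    """Round the over-issue budget up to a whole number of samples."""
    return math.ceil((1 + staleness_threshold) * batch_size)


raw = (1 + 0.25) * 10                           # 12.5, not usable as a count
rounded = compute_data_concurrency(10, 0.25)    # 13
no_overissue = compute_data_concurrency(8, 0)   # 8, same as batch_size
```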

Comment thread xtuner/v1/rl/base/producer.py Outdated
task_name=agent_loop.task_name, group_status=Status.COMPLETED
)
completed_sample_count = init_completed_sample_count
while completed_sample_count < data_concurrency:
Collaborator

Should be completed_sample_count < batch_size?

Collaborator Author

Yes.
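
The difference between the two bounds can be sketched as follows, assuming this loop belongs to a strategy that over-issues requests: data_concurrency limits how many rollouts are launched, while batch_size decides when to stop waiting. The generate coroutine and this standalone produce_batch are hypothetical, and cancelling the leftover tasks stands in for whatever pause logic the real strategy uses.

```python
import asyncio


async def generate(sample_id: int) -> int:
    """Hypothetical rollout; higher ids take longer to finish."""
    await asyncio.sleep(0.01 * sample_id)
    return sample_id


async def produce_batch(batch_size: int, data_concurrency: int):
    assert data_concurrency >= batch_size, "cannot fill a batch with fewer launches"
    # Launch more rollouts than needed (over-issuing).
    pending = {asyncio.create_task(generate(i)) for i in range(data_concurrency)}
    completed = []
    # Stop as soon as one batch is ready, not when every launched task has finished.
    while len(completed) < batch_size:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        completed.extend(task.result() for task in done)
    unfinished = len(pending)
    for task in pending:  # placeholder for pausing unfinished rollouts
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)
    return completed, unfinished


completed, unfinished = asyncio.run(produce_batch(batch_size=4, data_concurrency=6))
```

With the original bound (completed count below data_concurrency) the loop would wait for all six rollouts, which removes the benefit of over-issuing.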

Comment thread xtuner/v1/rl/base/producer.py Outdated
self.dataloader = dataloader_cfg.build(
tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1
)
self.dataloader_iter = iter(self.dataloader)
Collaborator

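Do not trigger iter(dataloader) during initialization; it makes resuming a multi-process dataloader fail. This is because:

  1. Resume modifies the dataset offset information (for example, global consumed samples).
  2. If iter(dataloader) was already triggered in __init__ before the resume, the dataloader subprocesses are already running with the main process's initial dataset, so modifying the main-process dataset at resume time has no effect.

The correct usage is to call iter(dataloader) on demand in sample and maintain a local dataloader_iter variable.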
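
The failure mode can be reproduced in a single process with a toy loader. This is a sketch under assumptions: ToyDataloader, EagerSampler and LazySampler are hypothetical, the offset attribute stands in for the consumed-samples state, and a lazily created attribute is one possible reading of "maintain a local dataloader_iter variable".

```python
class ToyDataloader:
    """Hypothetical loader whose iterator snapshots `offset` when it is created."""

    def __init__(self, data):
        self.data = list(data)
        self.offset = 0  # stands in for "global consumed samples"

    def __iter__(self):
        # A real multi-process dataloader hands dataset state to its workers here.
        return iter(self.data[self.offset:])


class EagerSampler:
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.dataloader_iter = iter(dataloader)  # snapshot taken before any resume

    def sample(self):
        return next(self.dataloader_iter)


class LazySampler:
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self._dataloader_iter = None  # no iteration during initialization

    def sample(self):
        if self._dataloader_iter is None:
            self._dataloader_iter = iter(self.dataloader)  # created on first use
        return next(self._dataloader_iter)


def resume_then_sample(sampler_cls):
    dataloader = ToyDataloader(["a", "b", "c", "d"])
    sampler = sampler_cls(dataloader)
    dataloader.offset = 2  # resume: two samples were consumed before the restart
    return sampler.sample()


eager_first = resume_then_sample(EagerSampler)  # "a": the resume was ignored
lazy_first = resume_then_sample(LazySampler)    # "c": the resume took effect
```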

Collaborator Author

OK.

Comment thread xtuner/v1/rl/base/producer.py Outdated
init_completed_sample_count = await self.replay_buffer.count(
task_name=agent_loop.task_name, group_status=Status.COMPLETED
)
completed_sample_count = init_completed_sample_count
Collaborator

assert completed_sample_count == 0 and clear replay_buffer's completed samples at end of this function?

Collaborator Author

Yes, the synchronous production strategy can drop all of these.

Collaborator Author

In the synchronous case no extra samples are generated, so clearing should not be needed, right?

Comment thread xtuner/v1/rl/base/producer.py Outdated
data = next(self.dataloader_iter)[0]
return data

async def sample(self) -> list[RolloutState]:
Collaborator

Shouldn't the return type be consistent with Sampler, i.e. RolloutState?

Comment thread xtuner/v1/rl/base/producer.py Outdated
self.tail_batch_candidate_step = tail_batch_candidate_step

async def produce_batch(
self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int, prompt_k: int
Collaborator

Move prompt_k into Sampler.

Comment thread xtuner/v1/rl/base/producer.py Outdated
pending_tasks.add(task)

completed_sample_count = await self.replay_buffer.count(
task_name=agent_loop.task_name, group_status=Status.COMPLETED
Collaborator

Update AgentLoop with new member task_name

Collaborator Author

How task_name should be used needs more thought; it will be changed uniformly in the agentloopmanager code.

@YanhuiDua YanhuiDua merged commit bad220a into InternLM:rl_design Feb 26, 2026
0 of 3 checks passed
@YanhuiDua YanhuiDua deleted the dyh/dev_produce_strategy branch March 3, 2026 11:35
YanhuiDua added a commit that referenced this pull request Apr 27, 2026
* [Producer] Add Sampler, SamplerWithBuffer, SyncProduceStrategy, AsyncProduceStrategy

* add tqdm in ProduceStrategy and fix comments on sampler