[Producer] Add Sampler and ProduceStrategy#1491
[Producer] Add Sampler and ProduceStrategy#1491YanhuiDua merged 2 commits intoInternLM:rl_designfrom
Conversation
| except Exception as e: | ||
| print(f"Error in generating trajectory: {e}") | ||
|
|
||
| if len(pending_tasks) + completed_sample_count < data_concurrency + init_completed_sample_count: |
There was a problem hiding this comment.
Shoud be len(pending_tasks) + completed_sample_count < data_concurrency ?
| async def produce_batch( | ||
| self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int, prompt_k: int | ||
| ): | ||
| data_concurrency = (1 + self.staleness_threshold) * batch_size |
| task_name=agent_loop.task_name, group_status=Status.COMPLETED | ||
| ) | ||
| completed_sample_count = init_completed_sample_count | ||
| while completed_sample_count < data_concurrency: |
There was a problem hiding this comment.
should be completed_sample_count < batch_size ?
| self.dataloader = dataloader_cfg.build( | ||
| tokenizer=self.tokenizer, dp_mesh=None, global_batch_size=1, micro_batch_size=1, seed=1 | ||
| ) | ||
| self.dataloader_iter = iter(self.dataloader) |
There was a problem hiding this comment.
不要在初始化时 触发 iter dataloader,这会导致多进程 dataloader resume失败。这是因为:
- resume时会修改 dataset offset信息(比如 global consumed samples)
- 但是如果resume前已经在__init__时触发 iter dataloader,dataloader子进程中已经按主进程初始的dataset运行,resume时再修改主进程的dataset已经无效。
正确用法是在 sample中按需做 iter dataloader,维护一个局部的 dataloader_iter 变量。
| init_completed_sample_count = await self.replay_buffer.count( | ||
| task_name=agent_loop.task_name, group_status=Status.COMPLETED | ||
| ) | ||
| completed_sample_count = init_completed_sample_count |
There was a problem hiding this comment.
assert completed_sample_count == 0 and clear replay_buffer's completed samples at end of this function?
There was a problem hiding this comment.
嗯嗯 同步的生成策略可以把这些都删掉
There was a problem hiding this comment.
同步情况下,不会生成多余的样本,所以应该不用clear吧
| data = next(self.dataloader_iter)[0] | ||
| return data | ||
|
|
||
| async def sample(self) -> list[RolloutState]: |
There was a problem hiding this comment.
返回类型应该和Sampler保持一致,是 RolloutState ?
| self.tail_batch_candidate_step = tail_batch_candidate_step | ||
|
|
||
| async def produce_batch( | ||
| self, agent_loop: AgentLoop, sampler: SamplerWithReplayBuffer, batch_size: int, prompt_k: int |
| pending_tasks.add(task) | ||
|
|
||
| completed_sample_count = await self.replay_buffer.count( | ||
| task_name=agent_loop.task_name, group_status=Status.COMPLETED |
There was a problem hiding this comment.
Update AgentLoop with new member task_name
There was a problem hiding this comment.
task_name怎么用需要再想想,在agentloopmanager的代码中统一修改
* [Producer] Add Sampler, SamplerWithBuffer, SyncProduceStrategy, AsyncProduceStrategy * add tqdm in ProduceStrategy and fix comments on sampler
producer.py
负责一批数据从Prompts至Trajectory的整个流程,其中:
AgentLoop.generate_group()和AgentLoop.generate(),generate_group为生成一组数据,generate为生成单条数据AgentLoop的输入RolloutState由Sampler生成,最简单的Sampler为从dataloader中采样,若需要支持异步,则使用SamplerWithReplayBuffer,可从ReplayBuffer中采样中断、过期等状态的样本。使用示例:
说明 @hhaAndroid
最终实现与原本设计不同的地方:Dataloader不是在ProduceStrategy初始化的时候传入,而是在调用produce_batch函数中传入Sampler实例,这样的好处的不会绑定数据集与生产的策略,当存在不同的数据集时,可通过以下代码来实现,并且prompt_repeat_k作为Sampler的配置
ProduceStrategy
Sampler
TODO