From 5f1ce2bc3d1c9a026e60b6eb202eb42b8110cf95 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Mon, 17 Nov 2025 10:37:18 +0800 Subject: [PATCH 1/3] [recipe, TransferQueue] fix: remove unused param get_n_samples & update _balance_batch func --- recipe/transfer_queue/ray_trainer.py | 42 ++++++++++++++++++---------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 83b10d8c467..0200c1cc5aa 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -81,6 +81,7 @@ from verl.utils.metric import reduce_metrics from verl.utils.rollout_skip import RolloutSkip from verl.utils.seqlen_balancing import ( + calculate_workload, get_seqlen_balanced_partitions, log_seqlen_unbalance, ) @@ -678,7 +679,6 @@ def _validate(self): data_fields=["input_ids", "uid", "reward_model"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, partition_id=f"val_{self.global_steps - 1}", - get_n_samples=False, task_name="get_data", ) ) @@ -697,7 +697,6 @@ def _validate(self): data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 - get_n_samples=False, task_name="generate_sequences", ) ) @@ -727,7 +726,6 @@ def _validate(self): data_fields=["responses"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 - get_n_samples=False, task_name="get_response", ) ) @@ -756,7 +754,6 @@ def _validate(self): data_fields=compute_reward_fields, batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, partition_id=f"val_{self.global_steps - 1}", - get_n_samples=False, 
task_name="compute_reward", ) ) @@ -780,7 +777,6 @@ def _validate(self): data_fields=["__num_turns__"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 - get_n_samples=False, task_name="get_num_turns", ) ) @@ -794,7 +790,6 @@ def _validate(self): data_fields=["data_source"], batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1 - get_n_samples=False, task_name="get_data_source", ) ) @@ -1098,17 +1093,39 @@ def _stop_profiling(self, do_profile: bool) -> None: if self.use_rm: self.rm_wg.stop_profile() - def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"): + def _balance_batch( + self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False + ): """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" data = asyncio.run(data_system_client.async_get_data(batch)) attention_mask = data["attention_mask"] batch_size = attention_mask.shape[0] - global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,) + global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,) + global_seqlen_lst = calculate_workload(global_seqlen_lst) world_size = self.actor_rollout_wg.world_size - global_partition_lst = get_seqlen_balanced_partitions( - global_seqlen_lst, k_partitions=world_size, equal_size=True - ) + if keep_minibatch: + # Decouple the DP balancing and mini-batching. 
+ minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size") + minibatch_num = len(global_seqlen_lst) // minibatch_size + global_partition_lst = [[] for _ in range(world_size)] + for i in range(minibatch_num): + rearrange_minibatch_lst = get_seqlen_balanced_partitions( + global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size], + k_partitions=world_size, + equal_size=True, + ) + for j, part in enumerate(rearrange_minibatch_lst): + global_partition_lst[j].extend([x + minibatch_size * i for x in part]) + else: + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=world_size, equal_size=True + ) + # Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel. + for idx, partition in enumerate(global_partition_lst): + partition.sort(key=lambda x: (global_seqlen_lst[x], x)) + ordered_partition = partition[::2] + partition[1::2][::-1] + global_partition_lst[idx] = ordered_partition # reorder based on index. 
The data will be automatically equally partitioned by dispatch function global_idx = [j for partition in global_partition_lst for j in partition] global_balance_stats = log_seqlen_unbalance( @@ -1248,7 +1265,6 @@ def fit(self): base_get_meta_kwargs = dict( batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1 - get_n_samples=False, ) with marked_timer("start_profile", timing_raw): @@ -1646,7 +1662,6 @@ def fit(self): batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, partition_id=f"train_{self.global_steps - 1}", - get_n_samples=False, task_name="update_actor", ) ) @@ -1672,7 +1687,6 @@ def fit(self): data_fields=data_fields, batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, partition_id=f"train_{self.global_steps - 1}", - get_n_samples=False, task_name="log_rollout", ) ) From 4f8fafb2034f4b098981c750720638521ab83cce Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Mon, 17 Nov 2025 10:42:36 +0800 Subject: [PATCH 2/3] update requirements_transferqueue --- requirements_transferqueue.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_transferqueue.txt b/requirements_transferqueue.txt index 387a61a456f..621682abbf7 100644 --- a/requirements_transferqueue.txt +++ b/requirements_transferqueue.txt @@ -1,2 +1,2 @@ # requirements.txt records the full set of dependencies for development -git+https://github.com/TransferQueue/TransferQueue.git@862b74a +transferqueue==0.1.1.dev2 From 528449b235af243df204456dc7b703bdb9c09b67 Mon Sep 17 00:00:00 2001 From: liuximeng <13073314+liuximeng18772102439@user.noreply.gitee.com> Date: Mon, 17 Nov 2025 10:47:43 +0800 Subject: [PATCH 3/3] fix codecheck --- recipe/transfer_queue/ray_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 0200c1cc5aa..a04e6edcc53 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -1094,7 +1094,7 @@ def _stop_profiling(self, do_profile: bool) -> None: self.rm_wg.stop_profile() def _balance_batch( - self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False + self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False ): """Reorder the batchmeta on single controller such that each dp rank gets similar total tokens""" data = asyncio.run(data_system_client.async_get_data(batch))