Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 28 additions & 14 deletions recipe/transfer_queue/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
from verl.utils.metric import reduce_metrics
from verl.utils.rollout_skip import RolloutSkip
from verl.utils.seqlen_balancing import (
calculate_workload,
get_seqlen_balanced_partitions,
log_seqlen_unbalance,
)
Expand Down Expand Up @@ -678,7 +679,6 @@ def _validate(self):
data_fields=["input_ids", "uid", "reward_model"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
partition_id=f"val_{self.global_steps - 1}",
get_n_samples=False,
task_name="get_data",
)
)
Expand All @@ -697,7 +697,6 @@ def _validate(self):
data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="generate_sequences",
)
)
Expand Down Expand Up @@ -727,7 +726,6 @@ def _validate(self):
data_fields=["responses"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_response",
)
)
Expand Down Expand Up @@ -756,7 +754,6 @@ def _validate(self):
data_fields=compute_reward_fields,
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
partition_id=f"val_{self.global_steps - 1}",
get_n_samples=False,
task_name="compute_reward",
)
)
Expand All @@ -780,7 +777,6 @@ def _validate(self):
data_fields=["__num_turns__"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_num_turns",
)
)
Expand All @@ -794,7 +790,6 @@ def _validate(self):
data_fields=["data_source"],
batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n,
partition_id=f"val_{self.global_steps - 1}", # self.global_steps start from 1
get_n_samples=False,
task_name="get_data_source",
)
)
Expand Down Expand Up @@ -1098,17 +1093,39 @@ def _stop_profiling(self, do_profile: bool) -> None:
if self.use_rm:
self.rm_wg.stop_profile()

def _balance_batch(self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen"):
def _balance_batch(
self, batch: BatchMeta, data_system_client, metrics, logging_prefix="global_seqlen", keep_minibatch=False
):
"""Reorder the batchmeta on single controller such that each dp rank gets similar total tokens"""
data = asyncio.run(data_system_client.async_get_data(batch))

attention_mask = data["attention_mask"]
batch_size = attention_mask.shape[0]
global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1).tolist() # (train_batch_size,)
global_seqlen_lst = data["attention_mask"].view(batch_size, -1).sum(-1) # (train_batch_size,)
global_seqlen_lst = calculate_workload(global_seqlen_lst)
world_size = self.actor_rollout_wg.world_size
global_partition_lst = get_seqlen_balanced_partitions(
global_seqlen_lst, k_partitions=world_size, equal_size=True
)
if keep_minibatch:
# Decouple the DP balancing and mini-batching.
minibatch_size = self.config.actor_rollout_ref.actor.get("ppo_mini_batch_size")
minibatch_num = len(global_seqlen_lst) // minibatch_size
global_partition_lst = [[] for _ in range(world_size)]
for i in range(minibatch_num):
rearrange_minibatch_lst = get_seqlen_balanced_partitions(
global_seqlen_lst[i * minibatch_size : (i + 1) * minibatch_size],
k_partitions=world_size,
equal_size=True,
)
for j, part in enumerate(rearrange_minibatch_lst):
global_partition_lst[j].extend([x + minibatch_size * i for x in part])
else:
global_partition_lst = get_seqlen_balanced_partitions(
global_seqlen_lst, k_partitions=world_size, equal_size=True
)
# Place smaller micro-batches at both ends to reduce the bubbles in pipeline parallel.
for idx, partition in enumerate(global_partition_lst):
partition.sort(key=lambda x: (global_seqlen_lst[x], x))
ordered_partition = partition[::2] + partition[1::2][::-1]
global_partition_lst[idx] = ordered_partition
# reorder based on index. The data will be automatically equally partitioned by dispatch function
global_idx = [j for partition in global_partition_lst for j in partition]
global_balance_stats = log_seqlen_unbalance(
Expand Down Expand Up @@ -1248,7 +1265,6 @@ def fit(self):
base_get_meta_kwargs = dict(
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
partition_id=f"train_{self.global_steps - 1}", # self.global_steps starts from 1
get_n_samples=False,
)

with marked_timer("start_profile", timing_raw):
Expand Down Expand Up @@ -1646,7 +1662,6 @@ def fit(self):
batch_size=self.config.data.train_batch_size
* self.config.actor_rollout_ref.rollout.n,
partition_id=f"train_{self.global_steps - 1}",
get_n_samples=False,
task_name="update_actor",
)
)
Expand All @@ -1672,7 +1687,6 @@ def fit(self):
data_fields=data_fields,
batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
partition_id=f"train_{self.global_steps - 1}",
get_n_samples=False,
task_name="log_rollout",
)
)
Expand Down
2 changes: 1 addition & 1 deletion requirements_transferqueue.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# requirements.txt records the full set of dependencies for development
git+https://github.com/TransferQueue/TransferQueue.git@862b74a
transferqueue==0.1.1.dev2