Add StepPO recipes and dataset docs#93
Conversation
There was a problem hiding this comment.
Code Review
This pull request significantly expands the Agent-R1 framework by integrating multiple new benchmarks (ALFWorld, HotpotQA, Paper Search, and WebShop) along with their respective data preparation scripts, agent flows, and reward functions. It also introduces several new reinforcement learning algorithms and advantage estimators, including token-level GAE, REINFORCE++, RLOO, and GiGPO, alongside a GSPO sequence-level policy loss. Feedback on the implementation highlights opportunities to prevent undefined behavior by initializing object arrays with np.full instead of np.empty, to fully vectorize group score normalization in PyTorch to avoid Python loop overhead, and to refine the critic value mask to prevent training on padded tokens.
|
|
||
| step_group_uids = np.empty(len(anchor_obs), dtype=object) | ||
| for prompt_idx in np.unique(index): | ||
| locs = np.where(index == prompt_idx)[0] |
There was a problem hiding this comment.
Using np.empty with dtype=object creates an array of uninitialized object pointers. If any elements are not explicitly overwritten, attempting to read them or compare them (e.g., step_group_uids == None) can lead to undefined behavior, silent bugs, or segmentation faults due to dereferencing garbage memory addresses.
It is highly recommended to use np.full(..., None, dtype=object) to safely initialize the array with None elements.
| locs = np.where(index == prompt_idx)[0] | |
| step_group_uids = np.full(len(anchor_obs), None, dtype=object) |
| group_uids: np.ndarray, | ||
| epsilon: float, | ||
| remove_std: bool, | ||
| single_mean_zero: bool = False, | ||
| ) -> torch.Tensor: | ||
| id2score = defaultdict(list) | ||
| id2mean: dict[object, torch.Tensor] = {} | ||
| id2std: dict[object, torch.Tensor] = {} | ||
|
|
||
| for i in range(scores.shape[0]): | ||
| id2score[group_uids[i]].append(scores[i]) | ||
|
|
||
| for group_uid, group_scores in id2score.items(): | ||
| stacked = torch.stack(group_scores) | ||
| if single_mean_zero and len(group_scores) == 1: | ||
| id2mean[group_uid] = scores.new_tensor(0.0) | ||
| else: | ||
| id2mean[group_uid] = torch.mean(stacked) | ||
| id2std[group_uid] = torch.std(stacked) if len(group_scores) > 1 else scores.new_tensor(1.0) | ||
|
|
||
| normalized = scores.clone() | ||
| for i in range(scores.shape[0]): | ||
| group_uid = group_uids[i] | ||
| if remove_std: | ||
| normalized[i] = scores[i] - id2mean[group_uid] | ||
| else: | ||
| normalized[i] = (scores[i] - id2mean[group_uid]) / (id2std[group_uid] + epsilon) | ||
| return normalized | ||
|
|
||
|
|
There was a problem hiding this comment.
The current implementation of _normalize_group_scores uses nested Python loops to group, stack, and compute the mean and standard deviation of scores for each group, and then loops again to normalize them. This introduces significant Python overhead and completely bypasses PyTorch's vectorized execution, which will severely bottleneck GPU training performance on larger batch sizes.
We can fully vectorize this operation using np.unique(..., return_inverse=True) to map group UIDs to unique integer indices, and then use PyTorch's scatter_add_ to compute group sums, counts, and standard deviations in a highly efficient, vectorized manner.
def _normalize_group_scores(
scores: torch.Tensor,
group_uids: np.ndarray,
epsilon: float,
remove_std: bool,
single_mean_zero: bool = False,
) -> torch.Tensor:
if scores.numel() == 0:
return scores.clone()
# Map group UIDs to unique integer indices
unique_groups, group_indices_np = np.unique(group_uids, return_inverse=True)
group_indices = torch.as_tensor(group_indices_np, dtype=torch.long, device=scores.device)
num_groups = len(unique_groups)
# Compute group sums and counts in a vectorized manner
group_sums = torch.zeros(num_groups, dtype=scores.dtype, device=scores.device)
group_sums.scatter_add_(0, group_indices, scores)
group_counts = torch.zeros(num_groups, dtype=scores.dtype, device=scores.device)
group_counts.scatter_add_(0, group_indices, torch.ones_like(scores))
group_means = group_sums / group_counts
if single_mean_zero:
group_means[group_counts == 1] = 0.0
if remove_std:
normalized = scores - group_means[group_indices]
else:
# Compute standard deviation: std = sqrt(sum((x - mean)^2) / (count - 1))
sq_diff = (scores - group_means[group_indices]) ** 2
group_sq_sum = torch.zeros(num_groups, dtype=scores.dtype, device=scores.device)
group_sq_sum.scatter_add_(0, group_indices, sq_diff)
group_stds = torch.ones(num_groups, dtype=scores.dtype, device=scores.device)
mask = group_counts > 1
group_stds[mask] = torch.sqrt(group_sq_sum[mask] / (group_counts[mask] - 1))
normalized = (scores - group_means[group_indices]) / (group_stds[group_indices] + epsilon)
return normalized| value_mask = torch.zeros_like(response_mask) | ||
| value_mask[:, 0] = 1 |
There was a problem hiding this comment.
Setting value_mask[:, 0] = 1 unconditionally forces the first token of every sequence to be active for the critic value loss. However, if a sequence is padded or has an empty response (where response_mask[:, 0] is 0), this will incorrectly train the critic on invalid/padded tokens.
A safer approach is to mask the first token based on its actual validity in response_mask by setting value_mask[:, 0] = response_mask[:, 0].
| value_mask = torch.zeros_like(response_mask) | |
| value_mask[:, 0] = 1 | |
| value_mask = torch.zeros_like(response_mask) | |
| value_mask[:, 0] = response_mask[:, 0] |
No description provided.