
add gsm8k_with_tool agent_loop as example #1543

Merged
jayhenry merged 4 commits into InternLM:rl_design from YanhuiDua:dev_tool_agent_loop
Mar 9, 2026

Conversation

@YanhuiDua
Collaborator

No description provided.

@YanhuiDua YanhuiDua force-pushed the dev_tool_agent_loop branch 2 times, most recently from fe3009e to 0c2c2e9 Compare March 6, 2026 10:11
@YanhuiDua YanhuiDua force-pushed the dev_tool_agent_loop branch from 0c2c2e9 to 4cf124b Compare March 6, 2026 10:41
@YanhuiDua YanhuiDua requested review from hhaAndroid and jayhenry March 6, 2026 10:44
@YanhuiDua
Collaborator Author

@claude review

Comment thread xtuner/v1/data_proto/rl_data.py
Comment thread xtuner/v1/ray/rollout/lmdeploy.py
Comment on lines +27 to +33
},
"required": ["answer"],
},
},
},
}
]
Contributor


Claude: Warning — "required" is nested inside "properties" instead of being a sibling

The JSON Schema for this tool definition is malformed. "required" should be at the same level as "properties", not nested inside it. This may cause the LLM to not understand the tool schema correctly.

Suggested change — move "required" out of "properties" so it is a sibling key:

"properties": {
    "answer": {
        "type": "string",
        "description": "The model's answer to the GSM8K math problem, must be digits",
    },
},
"required": ["answer"],
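For reference, a minimal sketch of the full corrected tool definition (the structure and `calc_gsm8k_reward` name mirror this PR; the description wording is illustrative). The key point is that `"required"` sits at the same level as `"properties"` inside `"parameters"`:

```python
# Sketch of a well-formed function-tool schema: "required" is a sibling
# of "properties" inside "parameters", not nested inside "properties".
tools = [
    {
        "type": "function",
        "function": {
            "name": "calc_gsm8k_reward",
            "description": "Score the model's answer to a GSM8K problem.",
            "parameters": {
                "type": "object",
                "properties": {
                    "answer": {
                        "type": "string",
                        "description": "The model's answer, must be digits.",
                    },
                },
                "required": ["answer"],  # sibling of "properties"
            },
        },
    }
]
```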

Comment thread xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
Comment thread xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
Comment thread xtuner/v1/rl/agent_loop/gsm8k_with_tool.py
Comment thread xtuner/v1/rl/agent_loop/gsm8k_with_tool.py Outdated
Comment on lines +141 to +176

_, function_calls = self.extract_tool_calls(rollout_state)
if not function_calls:
break

tool_messages = []
for function_call in function_calls:
tool_name = function_call.name
tool_args = function_call.arguments
if tool_name == "calc_gsm8k_reward":
answer = tool_args
ground_truth = cast(dict, rollout_state.reward_model).get("ground_truth", "")
function_results = self.calc_gsm8k_reward(answer, ground_truth)
tool_message = {
"role": "tool",
"content": json.dumps({"result": function_results}, ensure_ascii=False),
}
tool_messages.append(tool_message)

# Handle the outputs of the tool calls
tools_response_ids = self.tokenizer.apply_chat_template(tool_messages, remove_system_prompt=True)
cur_turn_tokens.extend(tools_response_ids)
final_response_ids.extend(tools_response_ids)
final_logprobs.extend([0.0] * len(tools_response_ids))
final_response_mask.extend([0] * len(tools_response_ids))

tokens = cast(list[int], rollout_state.tokens)
raw_prompt_ids = len(tokens) - len(final_response_mask)
final_tokens = tokens[:raw_prompt_ids]
final_response_ids = tokens[raw_prompt_ids:]
rollout_state.tokens = final_tokens
rollout_state.response_ids = final_response_ids
rollout_state.response_mask = final_response_mask
rollout_state.logprobs = final_logprobs
rollout_state.response = self.tokenizer.decode(rollout_state.response_ids)
assert len(rollout_state.response_ids) == len(rollout_state.response_mask) == len(rollout_state.logprobs), (
Contributor


Claude: Warning — final_response_ids is built up incrementally then overwritten

Lines 141 and 167-168 build final_response_ids by extending it across turns, but then line 174 overwrites it entirely with tokens[raw_prompt_ids:]. This means the carefully accumulated final_response_ids is discarded. This works because tokens is also being extended in lock-step, but the intermediate final_response_ids variable is misleading — it's only used for the length assertion at line 180. Consider simplifying by not accumulating final_response_ids at all, or adding a comment explaining the intent.
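One way to realize the suggested simplification (a hypothetical sketch, not the PR's code): since `tokens` is extended in lock-step with the per-turn response ids, the response span can be recovered by slicing `tokens` against the accumulated mask length, with no throwaway accumulator:

```python
# Hypothetical helper: recover (prompt, response) from the full token
# list by using the response length implied by final_response_mask,
# instead of maintaining a separate final_response_ids accumulator.
def split_prompt_response(tokens: list[int], final_response_mask: list[int]):
    """Split tokens into (prompt, response); one mask entry per response token."""
    prompt_len = len(tokens) - len(final_response_mask)
    return tokens[:prompt_len], tokens[prompt_len:]

tokens = [101, 102, 103, 7, 8, 9, 10]  # 3 prompt tokens + 4 response tokens
mask = [1, 1, 0, 0]                    # loss mask over the response tokens
prompt, response = split_prompt_response(tokens, mask)
assert prompt == [101, 102, 103]
assert response == [7, 8, 9, 10]
```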

Comment thread xtuner/v1/train/rl_colocate_trainer.py Outdated
Comment on lines +466 to +472
response_labels.append(-100)
else:
response_labels.append(response_ids[idx])

# Compute new advantages based on response_mask
actual_advantages: list[float] = [advantages[i].item()] * len(prompt_ids)
for mask in response_mask:
Contributor


Claude: Critical — actual_advantages includes prompt tokens, causing length mismatch with packing

actual_advantages is constructed with len(prompt_ids) entries for the prompt plus len(response_mask) entries for the response, then sliced with [:-1]. This means each element in advantages_list for a sample has length len(prompt_ids) + len(response_mask) - 1 = len(input_ids).

However, in _packing (controller.py), after this PR, advantage_list collects these per-sample lists and for padding appends pad_advantages = [-100] * pad_len. Then it flattens everything. The old code used repeat_interleave with cu_seq_lens_q to expand per-sample scalars to per-token values. The new code assumes each sample's advantage list is already per-token, but the packing code's pad_advantages uses pad_len (number of padding tokens) while the real samples use len(input_ids) tokens.

Please verify that the flattened advantage_flat length matches shifted_labels length (1, pack_max_length) in all cases. The old repeat_interleave approach guaranteed alignment; the new approach requires manual length tracking.
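The length accounting can be checked with a small sketch (plain lists standing in for tensors; the helper name and values are illustrative, not the trainer's code):

```python
# Illustrative length check for the per-token advantage layout described
# above: each real sample contributes len(prompt_ids) + len(response_mask)
# entries sliced by [:-1] to align with shifted labels; padding contributes
# pad_len entries. The flattened list must equal pack_max_length.
def build_advantage_flat(samples, pack_max_length, pad_value=-100.0):
    flat = []
    for prompt_ids, response_mask, adv in samples:
        per_token = [adv] * (len(prompt_ids) + len(response_mask))
        flat.extend(per_token[:-1])  # drop last entry to match shifted labels
    pad_len = pack_max_length - len(flat)
    assert pad_len >= 0, "samples overflow the pack"
    flat.extend([pad_value] * pad_len)
    return flat

samples = [([1, 2, 3], [1, 1, 0], 0.5), ([4, 5], [1, 1], -0.25)]
flat = build_advantage_flat(samples, pack_max_length=12)
assert len(flat) == 12  # 5 + 3 real entries + 4 padding entries
```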

Comment thread xtuner/v1/train/rl_colocate_trainer.py Outdated
continue

prompt_ids = group[0].prompt_ids
prompt_ids = group[0].tokens
Contributor


Claude: Warning — prompt_ids = group[0].tokens semantic change

The old code used group[0].prompt_ids (which was presumably just the prompt portion). The new code uses group[0].tokens, which in the agent loop is set to the full token sequence minus response. Please verify this is semantically equivalent in all code paths, especially for the non-tool (single-turn) agent loop where tokens might still include response tokens at this point.

@claude
Contributor

claude Bot commented Mar 6, 2026

Claude: ## Summary

This PR adds a GSM8K with tool-calling agent loop for RL training, introducing multi-turn tool interaction support with response_mask for selective loss masking, and refactors advantage handling from per-sample scalars to per-token lists.

Issues

Critical

  • [lmdeploy.py:116-117] tools and tool_choice are referenced but undefined in _get_request_payload — will raise NameError at runtime.
  • [rl_colocate_trainer.py:466-472] The advantages refactor (per-token lists + flattening in packing) needs verification that the flattened length always matches pack_max_length. The old repeat_interleave approach guaranteed alignment; the new manual approach is error-prone.

Warning

  • [gsm8k_with_tool.py:27-33] Tool JSON Schema is malformed — "required" is nested inside "properties" instead of being a sibling.
  • [gsm8k_with_tool.py:93] No-op assignment actual_answer = actual_answer.
  • [gsm8k_with_tool.py:75] self.cur_turn instance variable set but never used (local cur_turn used instead).
  • [gsm8k_with_tool.py:141-176] final_response_ids is accumulated then fully overwritten — misleading intermediate state.
  • [rl_colocate_trainer.py:412] group[0].prompt_ids → group[0].tokens semantic change needs verification for non-tool agent loops.
  • [rl_data.py] loss_mask → response_mask rename is a breaking change if anything else on rl_design reads .loss_mask.

Nit

  • [gsm8k_with_tool.py:111] Use logger.warning() instead of print().
  • [test_tool_agent_loop.py] Uses single quotes, has trailing whitespace, uses unittest instead of pytest, and has commented-out code.

Verdict

REQUEST_CHANGES

@YanhuiDua YanhuiDua force-pushed the dev_tool_agent_loop branch from 6ca55dd to dcfc945 Compare March 9, 2026 02:47
Comment thread examples/v1/config/rl_grpo_gsm8k_with_tool.py Outdated
Comment thread examples/v1/config/rl_grpo_gsm8k_with_tool.py
)

# 3. judger
judger_config = GSM8KRouterJudgerConfig(judger_name="openai/gsm8k")
Collaborator


In a follow-up PR, make the router a generic class rather than one hard-wired to a specific judger.

Collaborator Author


ok

Comment thread xtuner/v1/rl/agent_loop/gsm8k_with_tool.py Outdated
Comment thread xtuner/v1/rl/agent_loop/gsm8k_with_tool.py Outdated
Comment thread xtuner/v1/rl/agent_loop/gsm8k_with_tool.py Outdated
Comment thread xtuner/v1/train/rl_colocate_trainer.py Outdated
Comment thread xtuner/v1/train/rl_colocate_trainer.py Outdated
Comment thread xtuner/v1/train/rl_colocate_trainer.py Outdated
@YanhuiDua YanhuiDua force-pushed the dev_tool_agent_loop branch from 515e725 to 92970be Compare March 9, 2026 04:16
return content, function_calls

async def generate_sample(self, rollout_state: RolloutState) -> RolloutState:
    # NOTE: in practice it is easy to forget to pass sample_params to rollout_state
Collaborator


It is already passed in generate_group.

@jayhenry jayhenry merged commit 4d1e29d into InternLM:rl_design Mar 9, 2026
3 of 6 checks passed
@YanhuiDua YanhuiDua deleted the dev_tool_agent_loop branch March 17, 2026 09:41
YanhuiDua added a commit that referenced this pull request Apr 27, 2026
* add gsm8k_with_tool agent_loop as example

* fix claude comments

* fix haian comments

* add data_preprocess for gsm8k_with_tool