In [1]:
import os

# Route Hugging Face Hub traffic through the hf-mirror.com mirror. Done in the
# first cell, presumably so it is in place before any HF library is imported.
os.environ.update(HF_ENDPOINT="https://hf-mirror.com/")

In [2]:
# CLI-style arguments for load_expr_config (a notebook has no real argv).
args = "--config train_config.yaml".split()

In [3]:
import asyncio
import os
import sys
import uuid
import json
import gc
import torch
import torch.distributed as dist
import numpy as np
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast

from dataclasses import dataclass, field

from areal.api.cli_args import (
    GenerationHyperparameters,
    GRPOConfig,
    load_expr_config,
)
from areal.api.io_struct import (
    FinetuneSpec,
    ModelRequest,
    WeightUpdateMeta,
)
from areal.api.workflow_api import RolloutWorkflow
from areal.api.cli_args import GRPOConfig
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.utils.data import concat_padded_tensors
from areal.utils.device import log_gpu_stats
from areal.utils.saver import Saver
from areal.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, seeding, stats_tracker

# Notebook-wide logger (realhf logging); its records are tagged "TIR".
logger = logging.getLogger("TIR")

  from .autonotebook import tqdm as notebook_tqdm
2025-09-02 15:16:33,085	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [4]:

@dataclass
class AgentRLConfig(GRPOConfig):
    """GRPO configuration extended with the settings of the TIR
    (tool-integrated reasoning) agent trained in this notebook.

    Every extra field has a default and can be overridden by the experiment
    config that ``load_expr_config`` parses.
    """

    # --- token budgets ---------------------------------------------------
    max_tokens_per_traj: int = field(
        default=32000,
        metadata={"help": "maximum number of tokens per trajectory"},
    )
    max_tokens: int = field(
        default=32000,
        metadata={"help": "maximum number of tokens (including input and output) for the model"},
    )

    # --- agent loop --------------------------------------------------------
    # NOTE(review): the help text still says "search agent"; in this notebook a
    # "turn" is one code execution inside a trajectory.
    max_turns: int = field(
        default=128,
        metadata={"help": "maximum number of turns for search agent"},
    )
    n_trajs: int = field(
        default=1,
        metadata={"help": "We could collect multiple trajectories for a single query. By default n_trajs=1."},
    )

    # --- external services & outputs ---------------------------------------
    executor_url: str = field(
        default="http://localhost:1451",
        metadata={"help": "URL of the code executor service"},
    )
    dump_dir: str = field(
        default="./dump",
        metadata={"help": "directory to dump the trajectories"},
    )
    verbose: bool = field(
        default=True,
        metadata={"help": "whether to print verbose information"},
    )
    recover_start_step: int = field(
        default=0,
        metadata={"help": "step to start recovering from, useful for resuming training"},
    )

In [5]:
# Parse the experiment config referenced by `args` into the extended AgentRLConfig.
config: AgentRLConfig
config, _ = load_expr_config(args, AgentRLConfig)

# Trajectory dumps always go to "<trial log dir>/generated", overriding any
# dump_dir value coming from the config file.
config.dump_dir = os.path.join(StatsLogger.get_log_path(config.stats_logger), "generated")

config.dump_dir

'/home/liangchengwei/lcw/ZERO-TIR-RL/experiments/logs/liangchengwei/tir-grpo/trial0/generated'

In [6]:
from areal.utils.network import find_free_ports

# Fixed ports for the SGLang inference server and the torch.distributed rendezvous.
SGLANG_PORT, MASTER_PORT = 11451, 14514

SGLANG_HOST = "127.0.0.1"

# Environment variables read by the inference/train engines: a single-process,
# single-rank setup talking to one local SGLang server.
import os

os.environ.update(
    {
        "AREAL_LLM_SERVER_ADDRS": f"{SGLANG_HOST}:{SGLANG_PORT}",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(MASTER_PORT),
        "RANK": str(0),
        "WORLD_SIZE": str(1),
        "TOKENIZERS_PARALLELISM": "true",
        "LOCAL_RANK": str(0),
    }
)

In [7]:
# import subprocess
# import sys
# import threading
# import time

# # 启动sglang server
# from areal.api.cli_args import SGLangConfig
# from areal.utils.network import find_free_ports

# config.sglang.log_level = "info"
# config.sglang.decode_log_interval = 10
# sglang_cmd = SGLangConfig.build_cmd(
#     config.sglang,
#     tp_size=1,
#     base_gpu_id=1,
#     host=SGLANG_HOST,
#     port=SGLANG_PORT,
# )

# def read_pipe(pipe, prefix):
#     """实时读取管道输出并在notebook中显示"""
#     for line in iter(pipe.readline, b''):
#         try:
#             line_str = line.decode('utf-8').rstrip()
#             if line_str:
#                 print(f"[{prefix}] {line_str}")
#         except UnicodeDecodeError:
#             print(f"[{prefix}] <binary output>")
#     pipe.close()

# sglang_process = subprocess.Popen(
#     sglang_cmd,
#     shell=True,
#     stdout=subprocess.PIPE,
#     stderr=subprocess.PIPE,
#     bufsize=1,
#     universal_newlines=False
# )

# # 获取并打印进程号
# print(f"SGLang服务器已启动，进程号(PID): {sglang_process.pid}")

# # 启动线程实时读取stdout和stderr
# stdout_thread = threading.Thread(target=read_pipe, args=(sglang_process.stdout, "STDOUT"))
# stderr_thread = threading.Thread(target=read_pipe, args=(sglang_process.stderr, "STDERR"))
# stdout_thread.daemon = True
# stderr_thread.daemon = True
# stdout_thread.start()
# stderr_thread.start()

# print("SGLang服务器启动中，请等待初始化完成...")

In [8]:
import json

# Raw ORZ-math collection: a list of [human_turn, assistant_turn] records.
with open('orz_math_57k_collected.json') as f:
    raw_data = json.load(f)

print(">>>", len(raw_data))  # 56878
print(">>>", raw_data[2])  # [{'from': 'human', 'value': '...'}, {'from': 'assistant', 'ground_truth': {'value': '431'}}]

def process_raw_data(item):
    """Flatten one raw record into ``{"question": ..., "answer": ...}``.

    ``item[0]`` is the human turn whose ``value`` is the question text;
    ``item[1]`` is the assistant turn whose ``ground_truth.value`` is the
    reference answer.
    """
    human_turn = item[0]
    question = human_turn['value']
    assistant_turn = item[1]
    answer = assistant_turn['ground_truth']['value']
    return {"question": question, "answer": answer}

# Convert every raw record into a {"question", "answer"} dict.
dataset = list(map(process_raw_data, raw_data))

print(">>>", dataset[2])

>>> 56878
>>> [{'from': 'human', 'value': 'Consider all 1000-element subsets of the set $\\{1, 2, 3, ... , 2015\\}$.  From each such subset choose the least element.  The arithmetic mean of all of these least elements is $\\frac{p}{q}$, where $p$ and $q$ are relatively prime positive integers.  Find $p + q$.'}, {'from': 'assistant', 'ground_truth': {'value': '431'}}]
>>> {'question': 'Consider all 1000-element subsets of the set $\\{1, 2, 3, ... , 2015\\}$.  From each such subset choose the least element.  The arithmetic mean of all of these least elements is $\\frac{p}{q}$, where $p$ and $q$ are relatively prime positive integers.  Find $p + q$.', 'answer': '431'}


In [9]:
dataloader = StatefulDataLoader(
    dataset,
    batch_size=config.train_dataset.batch_size,
    shuffle=config.train_dataset.shuffle,
    num_workers=config.train_dataset.num_workers,
    # Keep each batch as a plain list of {"question", "answer"} dicts.
    collate_fn=lambda x: x,
    drop_last=config.train_dataset.drop_last,
)


def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass each epoch.

    ``itertools.cycle`` would cache every batch of the first pass and replay
    exactly those batches afterwards, so shuffling would never change after
    epoch 0 and the whole epoch would stay in memory. Re-iterating ``loader``
    instead starts a new pass each epoch (letting ``shuffle=True`` produce a
    fresh order) and caches nothing. If a pass yields nothing, the generator
    stops (the next ``next()`` raises StopIteration) instead of spinning
    forever on an empty loader.
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            return


data_generator = _endless_batches(dataloader)

ft_spec = FinetuneSpec(
    total_train_epochs=config.total_train_epochs,
    dataset_size=len(dataloader) * config.train_dataset.batch_size,
    train_batch_size=config.train_dataset.batch_size,
)

# Peek at one batch to check the collate output (this consumes a batch).
example_batch = next(data_generator)
print(">>>", len(example_batch))
print(">>>", example_batch[0])

>>> 1
>>> {'question': 'Find $a+b+c$ if the graph of the equation $y=ax^2+bx+c$ is a parabola with vertex $(5,3)$, vertical axis of symmetry, and contains the point $(2,0)$.', 'answer': '-\\frac73'}


In [10]:
from concurrent.futures import ProcessPoolExecutor

# Process pool for reward scoring, so answer parsing/comparison runs in separate
# worker processes (it is passed to loop.run_in_executor by the rollout workflow).
rw_executor = ProcessPoolExecutor(max_workers=4)

from realhf.impl.dataset.math_parser import extract_answer, math_equal

# Upper bound (seconds) intended for a single reward computation.
# NOTE(review): make sure the caller actually enforces it (e.g. via asyncio.wait_for).
REWARD_TIMEOUT_SECONDS = 15


def reward_fn(generated, answer):
    """Return 1.0 if ``generated`` and ``answer`` denote the same math answer, else 0.0.

    Both strings are passed through
    ``extract_answer(..., "math", use_last_number=True)`` and the extracted
    values are compared with ``math_equal(..., timeout=False)``.

    Returns 0.0 when either extraction is missing/empty (``None``, ``"None"``,
    ``"none"`` or blank), and also when parsing/comparison raises an
    ``Exception``, so one malformed sample cannot crash the rollout.
    ``KeyboardInterrupt``/``SystemExit`` are deliberately *not* caught (a bare
    ``except:`` used to swallow them), so reward workers stay interruptible.
    """
    try:
        x = extract_answer(generated, "math", use_last_number=True)
        y = extract_answer(answer, "math", use_last_number=True)

        if x is None or x.strip() in ["None", "none", ""]:
            return 0.0
        elif y is None or y.strip() in ["None", "none", ""]:
            return 0.0
        return float(math_equal(x, y, timeout=False))
    except Exception:
        return 0.0


# TODO: examine reward function
# Sanity check: a bare \boxed{72} prediction against a GSM8K-style reference
# ending in "#### 72" is expected to score 1.0.
# The prediction must be a raw string: in a plain literal "\b" is a backspace
# character, so "\boxed{72}" would never actually contain "\boxed".
reward_fn(
    r"\boxed{72}",
    "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72",
)

1.0

In [11]:
import asyncio
import functools
import os
import time
import uuid

import colorama
import torch
from tensordict import TensorDict
from transformers import AutoTokenizer, PreTrainedTokenizerFast

from areal.api.cli_args import GenerationHyperparameters
from areal.api.engine_api import InferenceEngine
from areal.api.io_struct import (
    AllocationMode,
    FinetuneSpec,
    ModelRequest,
    WeightUpdateMeta,
)
from areal.api.workflow_api import RolloutWorkflow
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.utils.data import concat_padded_tensors
from areal.utils.device import log_gpu_stats

tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)

# Tokenizers without a pad token reuse EOS for padding.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Generation must stop on PAD and on EOS. Register each id at most once (PAD
# first, then EOS) so re-running this cell does not duplicate entries.
for special_token_id in (tokenizer.pad_token_id, tokenizer.eos_token_id):
    if special_token_id not in config.gconfig.stop_token_ids:
        config.gconfig.stop_token_ids.append(special_token_id)

In [14]:
import re

# Scratch check of the <answer>...</answer> extraction pattern that the rollout
# workflow applies to each completion.
answer_test_str = r"""
<answer>72
<answer>82</answer>
<answer>line1
line2</answer>
"""

# Non-greedy + DOTALL: multi-line answers are captured, but an unclosed
# <answer> swallows text up to the next </answer> (first match in the output).
answer_matches = re.findall(r'<answer>(.*?)</answer>', answer_test_str, re.DOTALL)

print(">>>", answer_matches)

# Sample model output containing several fenced python blocks.
code_test_str = r"""
```python
import numpy as np

result = 2**31 - 1
print(result)
```

```python
def foo(x):
    return x + 1
print(foo(1))
```

```python
print(123)
print(foo(1))
```

```python
with open("test.txt", "w") as f:
    f.write("1 2 3")

with open("test.txt", "r") as f:
    content = f.read()
    print(content)
```
"""

# Same code-block pattern the workflow uses: captures the body of each
# ```python ... ``` block (leading/trailing newlines included).
code_matches = re.findall(r'```python(.*?)```', code_test_str, re.DOTALL)

print(">>>", code_matches)

>>> ['72\n<answer>82', 'line1\nline2']
>>> ['\nimport numpy as np\n\nresult = 2**31 - 1\nprint(result)\n', '\ndef foo(x):\n    return x + 1\nprint(foo(1))\n', '\nprint(123)\nprint(foo(1))\n', '\nwith open("test.txt", "w") as f:\n    f.write("1 2 3")\n\nwith open("test.txt", "r") as f:\n    content = f.read()\n    print(content)\n']


In [15]:
# Run each extracted snippet through the local executor (timeout=5).
# Snippets do not share state: the third one fails with NameError because `foo`
# from the second snippet is not defined in its run (see the output below).
from code_executor import execute_code
for code in code_matches:
    print(">>>", execute_code(code, timeout=5))

>>> 2147483647

>>> 2

>>> 123
Traceback (most recent call last):
  File "/tmp/tmpl3oyv9xp/user_script.py", line 6, in <module>
    print(foo(1))
          ^^^
NameError: name 'foo' is not defined

>>> 1 2 3



In [None]:
# import aiohttp
# from typing import Dict, Any
# class CodeExecutorClient:
#     def __init__(self, server_url: str, timeout = 10, max_retries: int = 3):
#         self.server_url = server_url
#         self.timeout = timeout
#         self.max_retries = max_retries
#         self.session = None
    
#     async def __aenter__(self):
#         self.session = aiohttp.ClientSession()
#         return self
    
#     async def __aexit__(self, exc_type, exc_val, exc_tb):
#         if self.session:
#             await self.session.close()
    
#     async def execute_code(self, code: str, timeout: int = 10, traj_rid=None) -> Dict[str, Any]:
#         if not self.session:
#             self.session = aiohttp.ClientSession()
#         for _ in range(self.max_retries):
#             try:
#                 async with self.session.post(
#                     f"{self.server_url}/execute",
#                     json={"code": code, "timeout": timeout, "traj_rid": traj_rid},
#                     timeout=aiohttp.ClientTimeout(total=timeout + 5)
#                 ) as response:
#                     result = await response.json()

#             except Exception as e:
#                 result = {
#                     "success": False,
#                     "stdout": "",
#                     "stderr": "",
#                     "error": {
#                         "type": type(e).__name__,
#                         "message": str(e),
#                         "traceback": ""
#                     }
#                 }
#             if result.get("success", False):
#                 return result
#         return result
    
#     async def health_check(self) -> bool:
#         """
#         检查服务器是否健康
        
#         Returns:
#             服务器是否可用
#         """
#         if not self.session:
#             self.session = aiohttp.ClientSession()
        
#         try:
#             async with self.session.get(
#                 f"{self.server_url}/health",
#                 timeout=aiohttp.ClientTimeout(total=5)
#             ) as response:
#                 if response.status == 200:
#                     result = await response.json()
#                     return result.get("status") == "healthy"
#                 return False
#         except:
#             return False

# code_executor = CodeExecutorClient(config.executor_url)

In [None]:
# print(">>>", await code_executor.health_check())
# for code in code_matches:
#     print(">>>", await code_executor.execute_code(code))

>>> True
>>> {'success': True, 'content': '2147483647\n'}
>>> {'success': True, 'content': '2\n'}
>>> {'success': True, 'content': "123\nTraceback (most recent call last):\n   line 3, in <module>\nNameError: name 'foo' is not defined\n"}


In [17]:
# System prompt for the TIR agent. This is a plain (non-format) string, so
# braces are written once: the previous "\\boxed{{}}" reached the model
# literally as "\boxed{{}}".
SYSTEM_PROMPT = """
You are a helpful assistant. The User asks a question, and the Assistant solves it. 
The Assistant first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning process is enclosed within <think> </think> and answer is enclosed within <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. And your final answer will be extracted automatically by the \\boxed{} tag.
In your reasoning-process, You can use python-code to solve your problem. Put the code within ```python and ``` tags. The script will be executed immediately and output will be returned.
"""

def get_prompt(tokenizer, query):
    """Build the generation prompt for ``query``.

    Renders ``SYSTEM_PROMPT`` plus the user query with the tokenizer's chat
    template (untokenized, with the assistant generation prompt) and pre-opens
    the reasoning section with ``"<think>\\n"``.

    Args:
        tokenizer: object exposing ``apply_chat_template`` (a HF tokenizer here).
        query: the question text.

    Returns:
        The prompt string.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": query}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    text += "<think>\n"
    return text

# Render the chat template for a toy question to eyeball the final prompt.
example_prompt = get_prompt(tokenizer=tokenizer, query="What is 1+1?")
print(">>>", example_prompt)

>>> <|im_start|>system

You are a helpful assistant. The User asks a question, and the Assistant solves it. 
The Assistant first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning process is enclosed within <think> </think> and answer is enclosed within <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. And your final answer will be extracted automatically by the \boxed{{}} tag.
In your reasoning-process, You can use python-code to solve your problem. Put the code within ```python and ``` tags. The script will be executed immediately and output will be returned.
<|im_end|>
<|im_start|>user
What is 1+1?<|im_end|>
<|im_start|>assistant
<think>



In [None]:
import re
import threading
from code_executor_client import execute_python_code

class TIRWorkflow(RolloutWorkflow):
    """Tool-integrated-reasoning (TIR) rollout workflow.

    The policy alternates between free-form generation and Python execution.
    When it opens a ```python block, generation is stopped at the closing
    fence. The code is then sent to ``code_executor``, and its output is
    appended as a ```output block (masked out of the loss) before generation
    resumes.

    A trajectory ends on:
    - ``</answer>`` with a matching ``<answer>`` (scored with ``reward_fn``);
    - EOS/PAD;
    - an empty completion;
    - ``config.max_turns`` code executions;
    - the next generation possibly exceeding ``config.max_tokens_per_traj``.
    """

    def __init__(
        self,
        config: AgentRLConfig,
        tokenizer: PreTrainedTokenizerFast,
        code_executor,
    ):
        self.config = config
        self.gconfig = config.gconfig
        self.tokenizer = tokenizer
        # One lock per question id; serializes writes to that question's dump file.
        self._qid_locks = {}
        self._locks_lock = threading.Lock()
        # Expected to provide `await execute_code(code, traj_rid=...)` returning a
        # dict with "success" and "content" / "error" keys — TODO confirm against
        # the client actually passed in.
        self.code_executor = code_executor
        self.current_trajs = 0

    def _get_qid_lock(self, qid):
        """Return the lock for ``qid``, creating it on first use."""
        with self._locks_lock:
            if qid not in self._qid_locks:
                self._qid_locks[qid] = threading.Lock()
            return self._qid_locks[qid]

    def _release_qid_lock(self, qid):
        """Forget the lock for ``qid`` so ``_qid_locks`` does not grow forever."""
        with self._locks_lock:
            self._qid_locks.pop(qid, None)

    async def _score(self, loop, result, answer):
        """Score ``result`` against ``answer`` with ``reward_fn`` in the reward pool.

        The wait is capped at ``REWARD_TIMEOUT_SECONDS`` so one hanging
        comparison cannot stall the trajectory; a timeout scores 0.0. The
        worker process itself keeps running until its call returns.
        """
        try:
            return await asyncio.wait_for(
                loop.run_in_executor(
                    rw_executor,
                    functools.partial(reward_fn, result, answer),
                ),
                timeout=REWARD_TIMEOUT_SECONDS,
            )
        except asyncio.TimeoutError:
            logger.warning(
                f"Reward computation timed out after {REWARD_TIMEOUT_SECONDS}s; scoring 0.0."
            )
            return 0.0

    async def _execute(self, code, traj_rid):
        """Run ``code`` via the executor and return the text to show the model."""
        exec_result = await self.code_executor.execute_code(code, traj_rid=traj_rid)
        if exec_result.get("success", False):
            return exec_result.get("content", "")
        # Server-side failure: log it and tell the model execution failed.
        # (The literal returned below is Chinese for "Code execution failed.")
        error = exec_result.get("error") or {}
        message = error.get("message", error) if isinstance(error, dict) else error
        logger.error(f"Code execution failed: {message}")
        return "代码执行失败。"

    def _dump_trajectory(self, qid, res_dump):
        """Append ``res_dump`` as one JSON line to ``<config.dump_dir>/<qid>.jsonl``.

        Does nothing when ``config.dump_dir`` is empty. All trajectories of one
        question share a file; the per-qid lock serializes their writes. Dumping
        is best-effort: an ``OSError`` is logged and does not fail the rollout.
        """
        dump_dir = self.config.dump_dir
        if not dump_dir:
            return
        path = os.path.join(dump_dir, f"{qid}.jsonl")
        try:
            os.makedirs(dump_dir, exist_ok=True)
            with self._get_qid_lock(qid):
                with open(path, "a", encoding="utf-8") as f:
                    f.write(json.dumps(res_dump, ensure_ascii=False) + "\n")
        except OSError as e:
            logger.warning(f"Failed to dump trajectory to {path}: {e}")

    async def collect_agent_trajectory(self, qid, prompt, answer, engine):
        """Roll out one multi-turn TIR trajectory for ``prompt``.

        Args:
            qid: id shared by all trajectories of the same question (names the dump file).
            prompt: fully templated prompt string (see ``get_prompt``).
            answer: ground-truth answer handed to ``reward_fn``.
            engine: inference engine exposing ``await agenerate(ModelRequest)``.

        Returns:
            ``TensorDict`` with batch size 1 holding ``input_ids``, ``logprobs``,
            ``loss_mask`` (True only on model-generated tokens), ``rewards``,
            ``code_reward`` (1.0 if any code ran), ``code_in_correct`` (1.0 if
            code ran and the reward is positive) and ``attention_mask``.

        Raises:
            AssertionError: if the trajectory ends up longer than
                ``config.max_tokens_per_traj``.
        """
        traj_rid = uuid.uuid4().hex
        loop = asyncio.get_running_loop()
        result = None
        reward = 0.0

        num_turns = 0
        input_str = prompt
        input_ids = self.tokenizer.encode(input_str, add_special_tokens=False)
        # Prompt tokens were not sampled by the policy: logprob 0, excluded from the loss.
        logprobs = [0.0] * len(input_ids)
        loss_mask = [0] * len(input_ids)
        # stops[0] alternates between "```python" (waiting for a code block to
        # open) and "```" (waiting for the open block to close).
        stops = ["```python", "</answer>"]
        total_gen_time = 0
        total_exec_time = 0
        start_time = time.time()
        while num_turns < self.config.max_turns:
            req = ModelRequest(
                rid=traj_rid,
                input_ids=input_ids,
                gconfig=self.gconfig.new(n_samples=1),
            )
            req.gconfig.stop = stops
            # Never start a generation that could overflow the trajectory budget.
            if len(input_ids) + self.gconfig.max_new_tokens >= self.config.max_tokens_per_traj:
                break

            gen_start_time = time.time()
            resp = await engine.agenerate(req)
            total_gen_time += time.time() - gen_start_time
            # An empty completion makes no progress and would crash the EOS
            # check below (output_tokens[-1]); end the trajectory instead.
            if not resp.output_tokens:
                break
            completion_str = self.tokenizer.decode(resp.output_tokens)

            input_str += completion_str
            input_ids += resp.output_tokens
            logprobs += resp.output_logprobs
            loss_mask += [1] * len(resp.output_tokens)

            if "</answer>" in completion_str:
                matches = re.findall(r"<answer>(.*?)</answer>", completion_str, re.DOTALL)
                if matches:
                    result = matches[-1]
                    reward = await self._score(loop, result, answer)
                    break
            elif stops[0] == "```python" and "```python" in completion_str:
                # A code block was opened: next, generate until it is closed.
                stops[0] = "```"
            elif stops[0] == "```" and "```" in completion_str:
                # The block is closed: execute the most recent ```python ... ``` block.
                matches = re.findall(r'```python(.*?)```', input_str, re.DOTALL)
                if matches:
                    exec_start_time = time.time()
                    execution_output = await self._execute(matches[-1], traj_rid)
                    total_exec_time += time.time() - exec_start_time

                    num_turns += 1
                    execution_output = "\n```output\n" + execution_output + "\n```\n"
                    input_str += execution_output
                    exec_tokens = self.tokenizer.encode(execution_output, add_special_tokens=False)
                    # Truncate tool output so the trajectory stays within budget.
                    if len(input_ids) + len(exec_tokens) >= self.config.max_tokens_per_traj:
                        exec_tokens = exec_tokens[:self.config.max_tokens_per_traj - len(input_ids) - 1]
                    # Tool output is context, not policy output: logprob 0, no loss.
                    input_ids += exec_tokens
                    logprobs += [0.0] * len(exec_tokens)
                    loss_mask += [0] * len(exec_tokens)
                stops[0] = "```python"

            if resp.output_tokens[-1] in [self.tokenizer.eos_token_id, self.tokenizer.pad_token_id]:
                break

        total_time = time.time() - start_time

        if len(input_ids) > self.config.max_tokens_per_traj:
            # Explicit raise instead of `assert False` so the check survives `python -O`.
            raise AssertionError(
                f"Trajectory {traj_rid} exceeds max tokens {self.config.max_tokens_per_traj} with {len(input_ids)} tokens."
            )

        res = dict(
            input_ids=torch.tensor(input_ids),
            logprobs=torch.tensor(logprobs),
            loss_mask=torch.tensor(loss_mask, dtype=torch.bool),
            rewards=torch.tensor(float(reward)),
            code_reward=torch.tensor(float(num_turns>0)),
            code_in_correct=torch.tensor(float(num_turns>0 and reward>0)),
            attention_mask=torch.ones(len(input_ids), dtype=torch.bool),
        )

        res_dump = {k: v.tolist() for k, v in res.items() if k != 'attention_mask' and k != 'rewards'}
        res_dump['input_str'] = input_str
        res_dump['metadata'] = {
            "reward": reward,
            "traj_rid": traj_rid,
            "num_turns": num_turns,
            "length": len(input_ids),
            "total_time": f"{total_time:.2f}s",
            "gen_time_ratio": f"{total_gen_time / total_time:.2f}" if total_time > 0 else "0.00",
            "exec_time_ratio": f"{total_exec_time / total_time:.2f}" if total_time > 0 else "0.00",
            "answer": answer,
            "result": result,
        }
        # res_dump used to be built and then discarded; persist it to dump_dir.
        self._dump_trajectory(qid, res_dump)

        res = {k: v.unsqueeze(0) for k, v in res.items()}
        return TensorDict(res, batch_size=[1])

    async def arun_episode(self, engine, data):
        """Collect ``config.n_trajs`` trajectories for one ``{"question", "answer"}`` item.

        Returns the trajectories concatenated (with padding) along the batch dim.
        """
        qid = uuid.uuid4().hex
        prompt = get_prompt(self.tokenizer, data["question"])

        try:
            trajs = await asyncio.gather(*[
                self.collect_agent_trajectory(qid, prompt, data["answer"], engine)
                for _ in range(self.config.n_trajs)
            ])
        finally:
            # All trajectories of this question are done; drop its dump lock.
            self._release_qid_lock(qid)

        return concat_padded_tensors(trajs)

# Workflow instance shared by the rollout / training cells below.
# NOTE(review): `code_executor` is only defined in the commented-out
# CodeExecutorClient cell, so on a fresh kernel this line raises NameError.
# Define a client exposing `await execute_code(code, traj_rid=...)` first.
# The earlier run also logged "attached to a different loop" errors. That
# suggests an aiohttp session created on one event loop being reused from the
# rollout's loop — TODO confirm the client creates its session on the loop
# that uses it.
workflow = TIRWorkflow(
    config=config,
    tokenizer=tokenizer,
    code_executor=code_executor
)

In [20]:
# Connect to the already running SGLang server(s).
rollout = RemoteSGLangEngine(config.rollout)
rollout.initialize(None, None)

# Smoke test: roll out the first two samples of the next batch. The engine is
# intentionally not destroyed here because later cells keep using `rollout`.
sample_data = next(data_generator)[:2]
res = rollout.rollout_batch(sample_data, workflow=workflow)
print(">>>", res)

[37m20250830-17:58:13.117 areal.engine.sglang_remote INFO: Waiting for server ready...[0m
[37m20250830-17:58:13.124 areal.engine.sglang_remote INFO: Servers are all ready![0m
[31m20250830-17:58:18.685 TIR ERROR: Code execution failed: Task <Task pending name='Task-55' coro=<TCPConnector._resolve_host_with_throttle() running at /home/liangchengwei/miniconda3/envs/test/lib/python3.12/site-packages/aiohttp/connector.py:1179>> got Future <Future pending cb=[_chain_future.<locals>._call_check_cancel() at /home/liangchengwei/miniconda3/envs/test/lib/python3.12/asyncio/futures.py:389]> attached to a different loop[0m
>>> TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([16, 4000]), device=cpu, dtype=torch.bool, is_shared=False),
        code_in_correct: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.float32, is_shared=False),
        code_reward: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.float32, is_shared=False),
        input_ids: Tensor(sh

In [None]:
# Debug view of the first rolled-out trajectory: one line per token with its id,
# decoded text, logprob and masks. Prompt and tool-output tokens are expected to
# show logprob 0 and loss_mask False.
for i in range(len(res[0]['input_ids'])):
    print(f"id: {res[0]['input_ids'][i]} token: {tokenizer.decode(res[0]['input_ids'][i])}                  logprob: {res[0]['logprobs'][i]:.4f} loss_mask: {res[0]['loss_mask'][i]} attention_mask: {res[0]['attention_mask'][i]}")

id: 151644 token: <|im_start|>                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 8948 token: system                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 271 token: 

                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 2610 token: You                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 525 token:  are                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 264 token:  a                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 10950 token:  helpful                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 17847 token:  assistant                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 13 token: .                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 576 token:  The                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 2657 token:  User                  logprob: 0.0000 loss_mask: Fals

In [12]:
# Load a previously pickled training batch for offline inspection.
# NOTE(review): pickle.load can execute arbitrary code; only load files you created yourself.
with open("debug_batch.pkl", "rb") as f:
    import pickle
    debug_batch = pickle.load(f)

debug_batch

TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([16, 706]), device=cpu, dtype=torch.bool, is_shared=False),
        code_in_correct: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.float32, is_shared=False),
        code_reward: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.float32, is_shared=False),
        input_ids: Tensor(shape=torch.Size([16, 706]), device=cpu, dtype=torch.int64, is_shared=False),
        logprobs: Tensor(shape=torch.Size([16, 706]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_mask: Tensor(shape=torch.Size([16, 706]), device=cpu, dtype=torch.bool, is_shared=False),
        rewards: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.float32, is_shared=False),
        score: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([16]),
    device=None,
    is_shared=False)

In [13]:
# Print every decoded sequence of the debug batch with its reward and score.
for i in range(len(debug_batch)):
    print(">>>", tokenizer.decode(debug_batch[i]['input_ids']))
    print(">>>", debug_batch[i]['rewards'])
    print(">>>", debug_batch[i]['score'])

>>> <|im_start|>system

You are a helpful assistant. The User asks a question, and the Assistant solves it. 
The Assistant first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning process is enclosed within <think> </think> and answer is enclosed within <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. And your final answer will be extracted automatically by the \boxed{{}} tag.
In your reasoning-process, You can use python-code to solve your problem. Put the code within ```python and ``` tags. The script will be executed immediately and output will be returned.
<|im_end|>
<|im_start|>user
Suppose that we have a right triangle $DEF$ with the right angle at $E$ such that $DF = \sqrt{85}$ and $DE = 7$. A circle is drawn with its center on $DE$ such that the circle is tangent to $DF$ and $EF$. If $Q$ is the point where the circle and side $DF$ meet, then what is $FQ$?<|i

In [14]:
# Per-token debug view (id, text, logprob, masks) of sample `idx` in the debug batch.
idx = 1
for i in range(len(debug_batch[idx]['input_ids'])):
    print(f"id: {debug_batch[idx]['input_ids'][i]} token: {tokenizer.decode(debug_batch[idx]['input_ids'][i])}                  logprob: {debug_batch[idx]['logprobs'][i]:.4f} loss_mask: {debug_batch[idx]['loss_mask'][i]} attention_mask: {debug_batch[idx]['attention_mask'][i]}")

id: 151644 token: <|im_start|>                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 8948 token: system                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 271 token: 

                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 2610 token: You                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 525 token:  are                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 264 token:  a                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 10950 token:  helpful                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 17847 token:  assistant                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 13 token: .                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 576 token:  The                  logprob: 0.0000 loss_mask: False attention_mask: True
id: 2657 token:  User                  logprob: 0.0000 loss_mask: Fals

In [17]:
# Scratch check: solve x**3 + 1/x**3 = -52 and evaluate x + 1/x at the first root
# sympy returns. The printed expression is not simplified (see output).
from sympy import symbols, Eq, solve

x = symbols('x')
eq = Eq(x**3 + 1/x**3, -52)

solution = solve(eq, x)

x_value = solution[0] # solution of the equation
expr = x_value + 1/x_value

print(expr)

-2 - sqrt(3) + 1/(-2 - sqrt(3))


In [18]:
# Trainer setup for the single-process run; RANK / WORLD_SIZE come from the env
# vars set in the environment cell.
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
assert world_size == 1, "This script is designed to run in a single process environment."
seeding.set_random_seed(config.seed, key=f"trainer{rank}")

# NOTE(review): not referenced by any visible cell.
worker_batch_size = config.train_dataset.batch_size

# FSDP PPO actor. With ref = None the training loop skips the ref_logp step.
actor = FSDPPPOActor(config=config.actor)
actor.initialize(None, ft_spec)
ref = None

# Metadata for pushing actor weights to the rollout server, presumably through
# files under config.saver.fileroot (from_disk) — TODO confirm.
weight_update_meta = WeightUpdateMeta.from_disk(
    experiment_name=config.saver.experiment_name,
    trial_name=config.saver.trial_name,
    file_root=config.saver.fileroot,
)

saver = Saver(config.saver, ft_spec)
stat_logger = StatsLogger(config.stats_logger, ft_spec)

# One training step per dataloader batch.
total_epochs = config.total_train_epochs
steps_per_epoch = len(dataloader)
max_steps = total_epochs * steps_per_epoch
start_step = config.recover_start_step or 0

[37m20250827-19:00:50.282 Base HF Engine INFO: Model creation and loading time: 2.626308770850301[0m
[37m20250827-19:00:50.390 FSDPEngine INFO: Applying FSDP2 time: 0.1044454537332058[0m
[37m20250827-19:00:50.393 Base HF Engine INFO: Create optimizer time: 0.0011926889419555664[0m




In [19]:
# Dry run of the rollout stage of step 0 (mirrors the start of the training loop).
# NOTE(review): in sync mode this consumes a batch from data_generator, so the
# loop below starts one batch later.
global_step = 0
epoch = global_step // steps_per_epoch
step = global_step % steps_per_epoch
print(f"Epoch {epoch}. Step: {step}/{steps_per_epoch}")

with stats_tracker.record_timing("rollout"):
    if config.async_training:
        batch = rollout.prepare_batch(dataloader, workflow=workflow)
    else:
        try:
            data = next(data_generator)
        except StopIteration:
            data_generator = iter(dataloader)
            data = next(data_generator)
        batch = rollout.rollout_batch(data, workflow=workflow)

print(batch[0])

# Move the rolled-out batch onto the actor's device.
batch = batch.to(actor.device)

print(len(batch))

Epoch 0. Step: 0/28439
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([491]), device=cpu, dtype=torch.bool, is_shared=False),
        input_ids: Tensor(shape=torch.Size([491]), device=cpu, dtype=torch.int64, is_shared=False),
        logprobs: Tensor(shape=torch.Size([491]), device=cpu, dtype=torch.float32, is_shared=False),
        loss_mask: Tensor(shape=torch.Size([491]), device=cpu, dtype=torch.bool, is_shared=False),
        rewards: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
32


In [26]:
# Sequence length of the third sample in the rolled-out batch.
print(len(batch[2]["input_ids"]))

491


In [20]:
# with torch.no_grad():
#     actor.actor.engine.eval()
#     print(">>> set eval mode")
#     assert hasattr(actor.actor.engine, "forward") and callable(getattr(actor.actor.engine, "forward")), "actor.actor.engine does not have a callable forward method"
#     print(">>> actor.actor.engine has forward")
#     print(">>>", actor.actor.engine.forward)

In [None]:
# if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
#     logp = actor.compute_logp(batch)
#     batch["prox_logp"] = logp

In [None]:


# Main GRPO loop. Each step does: rollout -> (optional) proximal / reference
# logprobs -> advantages -> PPO update -> weight sync to the rollout server ->
# checkpoint -> stats.
for global_step in range(start_step, max_steps):
    epoch = global_step // steps_per_epoch
    step = global_step % steps_per_epoch
    print(f"Epoch {epoch}. Step: {step}/{steps_per_epoch}")

    with stats_tracker.record_timing("rollout"):
        if config.async_training:
            batch = rollout.prepare_batch(dataloader, workflow=workflow)
        else:
            try:
                data = next(data_generator)
            except StopIteration:
                data_generator = iter(dataloader)
                data = next(data_generator)
            batch = rollout.rollout_batch(data, workflow=workflow)

    batch = batch.to(actor.device)

    if config.actor.recompute_logprob or config.actor.use_decoupled_loss:
        with stats_tracker.record_timing("recompute_logp"):
            logp = actor.compute_logp(batch)
            batch["prox_logp"] = logp
            log_gpu_stats("recompute logp")

    if ref is not None:
        with stats_tracker.record_timing("ref_logp"):
            batch["ref_logp"] = ref.compute_logp(batch)
            log_gpu_stats("ref logp")

    with stats_tracker.record_timing("compute_advantage"):
        actor.compute_advantages(batch)
        log_gpu_stats("compute advantages")

    # Release cached GPU memory between the rollout/logprob phase and the update.
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()

    with (
        stats_tracker.record_timing("train_step"),
        stats_tracker.scope("grpo_actor"),
    ):
        stats = actor.ppo_update(batch)
        actor.step_lr_scheduler()
        log_gpu_stats("ppo update")

    # Pause generation, push the new weights to the rollout server, then resume
    # with both sides tagged with the new policy version.
    with stats_tracker.record_timing("update_weights"):
        rollout.pause()
        future = rollout.update_weights(weight_update_meta)
        actor.upload_weights(weight_update_meta)
        future.result()
        dist.barrier(device_ids=[actor.device.index])
        torch.cuda.synchronize()
        rollout.resume()
        actor.set_version(global_step + 1)
        rollout.set_version(global_step + 1)

    # Checkpoint the actor. Without this call the Saver built in the setup cell
    # was never used and no checkpoints were written. How often it actually
    # saves is presumably governed by config.saver — TODO confirm.
    with stats_tracker.record_timing("save"):
        saver.save(actor, epoch, step, global_step)

    stat_logger.commit(epoch, step, global_step, stats)

stat_logger.close()
rollout.destroy()
if ref is not None:
    ref.destroy()
actor.destroy()