In [1]:
import os

# Point the HF_ENDPOINT environment variable at the hf-mirror.com mirror.
# NOTE(review): presumably this has to run before the Hugging Face-backed
# libraries (datasets / transformers) are imported in a later cell — confirm.
os.environ.update({"HF_ENDPOINT": "https://hf-mirror.com/"})

In [2]:
# Command-line style arguments; passed to `load_expr_config` when the
# experiment configuration is loaded further down in this notebook.
args = [
    "--config",
    "train_config.yaml",
]

In [3]:
import asyncio
import os
import sys
import uuid
import json
import gc
import torch
import torch.distributed as dist
import numpy as np
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from tensordict import TensorDict
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import PreTrainedTokenizerFast

from dataclasses import dataclass, field

from areal.api.cli_args import (
    GenerationHyperparameters,
    GRPOConfig,
    load_expr_config,
)
from areal.api.io_struct import (
    FinetuneSpec,
    ModelRequest,
    WeightUpdateMeta,
)
from areal.api.workflow_api import RolloutWorkflow
from areal.api.cli_args import GRPOConfig
from areal.engine.ppo.actor import FSDPPPOActor
from areal.engine.sglang_remote import RemoteSGLangEngine
from areal.utils.data import concat_padded_tensors
from areal.utils.device import log_gpu_stats
from areal.utils.saver import Saver
from areal.utils.stats_logger import StatsLogger
from realhf.api.core.data_api import load_hf_tokenizer
from realhf.base import logging, seeding, stats_tracker

# Notebook-wide logger named "TIR". Note that `logging` here is the
# `realhf.base.logging` module imported above, not the standard-library one.
logger = logging.getLogger("TIR")

  from .autonotebook import tqdm as notebook_tqdm
2025-09-02 14:59:50,136	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [4]:
@dataclass
class AgentRLConfig(GRPOConfig):
    """`GRPOConfig` extended with the options used by this agent RL notebook.

    Adds token and turn budgets, the number of trajectories collected per
    query, the code executor endpoint, the trajectory dump directory, a
    verbosity switch and the step from which to resume training. The `help`
    entry in each field's metadata is the authoritative description.
    """

    # --- Token budgets -----------------------------------------------------
    max_tokens_per_traj: int = field(
        default=32000, metadata={"help": "maximum number of tokens per trajectory"}
    )
    max_tokens: int = field(
        default=32000,
        metadata={"help": "maximum number of tokens (including input and output) for the model"},
    )

    # --- Rollout shape -----------------------------------------------------
    max_turns: int = field(
        default=128, metadata={"help": "maximum number of turns for search agent"}
    )
    n_trajs: int = field(
        default=1,
        metadata={"help": "We could collect multiple trajectories for a single query. By default n_trajs=1."},
    )

    # --- External services -------------------------------------------------
    executor_url: str = field(
        default="http://localhost:1451", metadata={"help": "URL of the code executor service"}
    )

    # --- Output, logging and recovery ---------------------------------------
    dump_dir: str = field(
        default="./dump", metadata={"help": "directory to dump the trajectories"}
    )
    verbose: bool = field(
        default=True, metadata={"help": "whether to print verbose information"}
    )
    recover_start_step: int = field(
        default=0,
        metadata={"help": "step to start recovering from, useful for resuming training"},
    )

In [5]:
from hydra.core.global_hydra import GlobalHydra

# Reset Hydra's global singleton before loading the configuration.
# NOTE(review): presumably this is what makes the cell safe to re-run in the
# same kernel — confirm against `load_expr_config`.
GlobalHydra.instance().clear()

# Declared up front purely as a type hint for editors; the value is bound on
# the next line.
config: AgentRLConfig
config, _ = load_expr_config(args, AgentRLConfig)

# Replace the dataclass default ("./dump") with a "generated" folder located
# under the stats logger's log path.
config.dump_dir = os.path.join(StatsLogger.get_log_path(config.stats_logger), "generated")

# Last expression of the cell: display the resolved dump directory.
config.dump_dir

'/home/liangchengwei/lcw/ZERO-TIR-RL/experiments/logs/liangchengwei/tir-grpo/trial0/generated'

In [6]:
from areal.utils.network import find_free_ports

# Environment variables used by inference/train engines
import os

# Fixed address of the SGLang server and fixed rendezvous port.
# NOTE(review): `find_free_ports` is imported above but not used here; the
# ports are hard-coded, so they must be free on this machine.
SGLANG_HOST = "127.0.0.1"
SGLANG_PORT, MASTER_PORT = 11451, 14514

# Single-process setup: rank 0 of a world of size 1, everything on localhost.
os.environ.update(
    {
        "AREAL_LLM_SERVER_ADDRS": f"{SGLANG_HOST}:{SGLANG_PORT}",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(MASTER_PORT),
        "RANK": str(0),
        "WORLD_SIZE": str(1),
        "TOKENIZERS_PARALLELISM": "true",
        "LOCAL_RANK": str(0),
    }
)

In [7]:
# Override which checkpoint the SGLang server loads, and tell it to skip
# tokenizer initialisation.
# NOTE(review): the checkpoint path below is an absolute, machine-specific
# path; anyone else running this notebook has to change it.
# Disabled alternative kept from the original (tokenizer_path is not set here):
# config.sglang.tokenizer_path = config.sglang.model_path
config.sglang.model_path = "/home/liangchengwei/lcw/ZERO-TIR-RL/experiments/checkpoints/liangchengwei/tir/debug/default/epoch0epochstep399globalstep399"
config.sglang.skip_tokenizer_init = True

In [8]:
import subprocess
import sys
import threading
import time

# Start the SGLang server
from areal.api.cli_args import SGLangConfig
from areal.utils.network import find_free_ports

# Server logging settings applied before the launch command is built.
config.sglang.log_level = "info"
config.sglang.decode_log_interval = 10

# Build the launch command for one tensor-parallel worker with base GPU id 1,
# listening on the host/port chosen in the environment-setup cell.
sglang_cmd = SGLangConfig.build_cmd(
    config.sglang, tp_size=1, base_gpu_id=1, host=SGLANG_HOST, port=SGLANG_PORT
)

def read_pipe(pipe, prefix):
    """Stream a subprocess pipe into the notebook output, line by line.

    Each line read from ``pipe`` is decoded as UTF-8, stripped of trailing
    whitespace and, if anything is left, printed as ``[<prefix>] <line>``.
    Lines that are not valid UTF-8 are reported as ``<binary output>``.

    Args:
        pipe: Binary stream providing ``readline()`` and ``close()``. Reading
            stops as soon as ``readline()`` returns ``b''``.
        prefix: Label shown in square brackets in front of every printed line.

    The pipe is closed on exit in every case, including when reading or
    printing raises, so the underlying file descriptor is never leaked.
    """
    try:
        for line in iter(pipe.readline, b''):
            try:
                line_str = line.decode('utf-8').rstrip()
            except UnicodeDecodeError:
                print(f"[{prefix}] <binary output>")
                continue
            # Whitespace-only lines are dropped to keep the output compact.
            if line_str:
                print(f"[{prefix}] {line_str}")
    finally:
        pipe.close()

# Launch the SGLang server as a child process, with both output streams piped
# back so they can be echoed into the notebook.
# The pipes are binary (universal_newlines=False). Line buffering (bufsize=1)
# is not supported in binary mode: Python ignores it and emits a
# RuntimeWarning, so the default buffer size is used explicitly instead.
sglang_process = subprocess.Popen(
    sglang_cmd,
    shell=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    universal_newlines=False,
)

# Report the process id so the server can be located later.
print(f"SGLang服务器已启动，进程号(PID): {sglang_process.pid}")

# Echo stdout and stderr in real time from background threads. They are
# daemon threads, so they do not keep the interpreter alive on shutdown.
stdout_thread = threading.Thread(
    target=read_pipe, args=(sglang_process.stdout, "STDOUT"), daemon=True
)
stderr_thread = threading.Thread(
    target=read_pipe, args=(sglang_process.stderr, "STDERR"), daemon=True
)
stdout_thread.start()
stderr_thread.start()

print("SGLang服务器启动中，请等待初始化完成...")

SGLang服务器已启动，进程号(PID): 2373777
SGLang服务器启动中，请等待初始化完成...


  self.stdout = io.open(c2pread, 'rb', bufsize)
  self.stderr = io.open(errread, 'rb', bufsize)


[STDERR] [2025-09-02 15:00:07] Init torch distributed begin.
[STDERR] [2025-09-02 15:00:07] Init torch distributed ends. mem usage=0.00 GB
[STDERR] [2025-09-02 15:00:08] Load weight begin. avail mem=48.49 GB
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.75it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.75it/s]
[STDERR] [2025-09-02 15:00:09] Load weight end. type=Qwen2ForCausalLM, dtype=torch.bfloat16, avail mem=45.42 GB, mem usage=3.07 GB.
[STDERR] [2025-09-02 15:00:09] KV Cache is allocated. #tokens: 1337086, K size: 17.85 GB, V size: 17.85 GB
[STDERR] [2025-09-02 15:00:09] Memory pool end. avail mem=9.12 GB
[STDERR] [2025-09-02 15:00:09] Capture cuda graph begin. This can take up to several minutes. avail mem=9.10 GB
[STDERR] [2025-09-02 15:00:09] Capture cuda graph bs [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128