# FoldAgent LocalSearch Test (bc_test_emh.parquet)

This notebook tests your FoldAgent loop in the LocalSearch environment:
- Starts a standalone rollout server (vLLM by default).
- Reads a sample from `/opt/tiger/verl_context_folding/bc_test_emh.parquet`.
- Uses `LOCAL_SEARCH_URL` (set below) to hit your local search server.
- Feeds one sample to `FoldAgentLoop` and prints a compact summary.

Note: Replace `LOCAL_SEARCH_URL` with your actual local search server endpoint (e.g., `http://127.0.0.1:8000`).

In [1]:
import asyncio
import os
import uuid

import numpy as np
import pandas as pd

import ray
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

import verl
from verl import DataProto
from verl.experimental.agent_loop import AgentLoopWorker
from verl.experimental.agent_loop.FoldAgent import FoldAgentLoop  # Ensures @register("fold_agent") runs
from verl.workers.rollout.replica import get_rollout_replica_class
from huggingface_hub import snapshot_download

# Fast Ray init; tweak as needed. VLLM_USE_V1=1 is passed to Ray workers via runtime_env.
ray.init(runtime_env={"env_vars": {"VLLM_USE_V1": "1"}}, ignore_reinit_error=True)
verl_config_dir = os.path.join(os.path.dirname(verl.__file__), "trainer/config")

rollout_name = "vllm"  # or "sglang"

# Download the model weights into model_path (adjust if you already have a local copy).
model_path = os.path.expanduser("~/Qwen/Qwen3-4B")
snapshot_download(repo_id="Qwen/Qwen3-4B", repo_type="model", local_dir=model_path)

# Compose verl's `ppo_trainer` config with overrides for one async rollout replica
# (TP=DP=PP=1) running the `fold_agent` loop.
# version_base="1.1" makes explicit the default Hydra would otherwise assume (and warn about).
with initialize_config_dir(config_dir=verl_config_dir, version_base="1.1"):
    config = compose(
        config_name="ppo_trainer",
        overrides=[
            # rollout engine
            "actor_rollout_ref.rollout.name=" + rollout_name,
            "actor_rollout_ref.rollout.mode=async",
            "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
            "actor_rollout_ref.rollout.data_parallel_size=1",
            "actor_rollout_ref.rollout.pipeline_model_parallel_size=1",
            "actor_rollout_ref.rollout.skip_tokenizer_init=False",
            "actor_rollout_ref.rollout.prompt_length=4096",
            "actor_rollout_ref.rollout.response_length=32768",
            # model
            "actor_rollout_ref.model.path=" + model_path,
            # agent loop: use our FoldAgent
            "actor_rollout_ref.rollout.agent.default_agent_loop=fold_agent",
            "actor_rollout_ref.rollout.agent.num_workers=1",
            # trainer sizing
            "trainer.n_gpus_per_node=8",
            "trainer.nnodes=1",
            "trainer.logger=['console']",
            "trainer.project_name=verl",
            "trainer.experiment_name=" + os.path.basename(model_path)
        ],
    )

# Make a safe copy of the trainer config to attach plugin without affecting rollout server instantiation
trainer_config_with_plugin = OmegaConf.create(OmegaConf.to_container(config, resolve=False))
OmegaConf.set_struct(trainer_config_with_plugin.actor_rollout_ref.rollout, False)

# Inject FoldAgent plugin fields on the copied config.
# NOTE: a later cell in this notebook replaces this plugin config before the agent loop runs.
trainer_config_with_plugin.actor_rollout_ref.rollout.plugin = OmegaConf.create({
    "workflow": "search",
    "max_turn": 20,
    "retry_cjk": 0,
    "turn_max_new_tokens": 2048,
    "max_session": 3,
    "val_max_session": 3,
    "session_timeout": 3600,
    "enable_summary": False,
    "branch_len": 256,
    "process_reward": "flat,scope",
    "max_traj": 4,
    "must_finish": False,
    "double_check": False,
    "must_search": True,
    "val_max_turn": 32,
    "val_response_length": 1024,
})

print("Plugin config:", OmegaConf.to_container(trainer_config_with_plugin.actor_rollout_ref.rollout.plugin, resolve=True))

print("Model:", config.actor_rollout_ref.model.path)
print("Rollout:", config.actor_rollout_ref.rollout.name)
print("Agent loop:", config.actor_rollout_ref.rollout.agent.default_agent_loop)

[2025-12-30 08:57:54,529 I 1467692 1467692] gcs_rpc_client.h:648: successful connect gcs: 10.122.253.153:59548
[2025-12-30 08:57:54,817 I 1467692 1467692] gcs_rpc_client.h:648: successful connect gcs: 10.122.253.153:59548
[2025-12-30 08:57:57,306 I 1467692 1467692] gcs_rpc_client.h:648: successful connect gcs: 10.122.253.153:59548
[2025-12-30 08:57:57,308 I 1467692 1467692] gcs_rpc_client.h:648: successful connect gcs: 10.122.253.153:59548
2025-12-30 08:57:57,309	INFO worker.py:1887 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m
[2025-12-30 08:57:57,311 I 1467692 1467692] gcs_rpc_client.h:648: successful connect gcs: 10.122.253.153:59548


Logs are printed to python-core-driver-01000000ffffffffffffffffffffffffffffffffffffffffffffffff_1467692.log


Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize_config_dir(config_dir=verl_config_dir):


Plugin config: {'workflow': 'search', 'max_turn': 20, 'retry_cjk': 0, 'turn_max_new_tokens': 2048, 'max_session': 3, 'val_max_session': 3, 'session_timeout': 3600, 'enable_summary': False, 'branch_len': 256, 'process_reward': 'flat,scope', 'max_traj': 4, 'must_finish': False, 'double_check': False, 'must_search': True, 'val_max_turn': 32, 'val_response_length': 1024}
Model: /home/tiger/Qwen/Qwen3-4B
Rollout: vllm
Agent loop: fold_agent


## Start standalone rollout server
Initializes a single-node rollout and exposes `server_handle` for token generation.

In [2]:
rollout_server_class = get_rollout_replica_class(config.actor_rollout_ref.rollout.name)
rollout_server = rollout_server_class(
    replica_rank=0,
    config=config.actor_rollout_ref.rollout,
    model_config=config.actor_rollout_ref.model,
    gpus_per_node=config.trainer.n_gpus_per_node,
)
await rollout_server.init_standalone()
print("Rollout server address:", rollout_server.server_address)

INFO 12-30 08:58:25 [__init__.py:235] Automatically detected platform cuda.




(pid=1479125, ip=10.122.253.153) INFO 12-30 08:58:38 [__init__.py:235] Automatically detected platform cuda.
(pid=1479328, ip=10.122.253.153) INFO 12-30 08:58:48 [__init__.py:235] Automatically detected platform cuda.
(vLLMHttpServer pid=1479328, ip=10.122.253.153) ['serve',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '/home/tiger/Qwen/Qwen3-4B',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '--dtype',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  'bfloat16',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '--load_format',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  'auto',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '--max_model_len',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '36864',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '--max_num_seqs',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '1024',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '--enable_chunked_prefill',
(vLLMHttpServer pid=1479328, ip=10.122.253.153)  '--max_num_batched_tokens'

(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO:2025-12-30 08:58:54,206:vLLMHttpServer, replica_rank: 0, master address: 10.122.253.153, master port: 44735, data parallel master port: 40587
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO:2025-12-30 08:58:54,213:override_generation_config: {'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'repetition_penalty': 1.0, 'max_new_tokens': 32768}
(vLLMHttpServer pid=1479328, ip=10.122.253.153) Using blocking ray.get inside async actor. This blocks the event loop. Please use `await` on object ref with asyncio.gather if you want to yield execution to the event loop instead.
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO:2025-12-30 08:58:54,925:replica_rank=0, node_rank=0, nnodes=1, get worker zmq addresses: ['ipc:///tmp/verl_vllm_zmq_1479125_tiger.ipc']


(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:00 [config.py:1604] Using max model len 36864
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:00 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=8192.
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:06 [__init__.py:235] Automatically detected platform cuda.
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:11 [core.py:572] Waiting for init message from front-end.
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:11 [core.py:71] Initializing a V1 LLM engine (v0.10.0) with config: model='/home/tiger/Qwen/Qwen3-4B', speculative_config=None, tokenizer='/home/tiger/Qwen/Qwen3-4B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=36864, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_pa

Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  33% Completed | 1/3 [00:00<00:01,  1.39it/s]
Loading safetensors checkpoint shards:  67% Completed | 2/3 [00:01<00:00,  1.24it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:01<00:00,  1.97it/s]
Loading safetensors checkpoint shards: 100% Completed | 3/3 [00:01<00:00,  1.73it/s]
(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) 


(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:17 [default_loader.py:262] Loading weights took 1.74 seconds
(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:18 [gpu_model_runner.py:1892] Model loading took 7.5552 GiB and 1.902710 seconds
(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:28 [backends.py:530] Using cache directory: /home/tiger/.cache/vllm/torch_compile_cache/bd66293048/rank_0_0/backbone for vLLM's torch.compile
(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:28 [backends.py:541] Dynamo bytecode transform time: 9.64 s
(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:36 [backends.py:161] Directly load the compiled graph(s) for dynamic shape from the cache, took 7.685 s
(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:39 [monitor.py:34] torch.compile takes 9.64 s in total
(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:40 [gpu_worker.py:255] Avail

Capturing CUDA graph shapes:   0%|          | 0/67 [00:00<?, ?it/s]
Capturing CUDA graph shapes:   4%|▍         | 3/67 [00:00<00:02, 23.87it/s]
Capturing CUDA graph shapes:   9%|▉         | 6/67 [00:00<00:02, 24.15it/s]
Capturing CUDA graph shapes:  13%|█▎        | 9/67 [00:00<00:02, 24.19it/s]
Capturing CUDA graph shapes:  18%|█▊        | 12/67 [00:00<00:02, 24.11it/s]
Capturing CUDA graph shapes:  22%|██▏       | 15/67 [00:00<00:02, 23.97it/s]
Capturing CUDA graph shapes:  27%|██▋       | 18/67 [00:00<00:02, 24.11it/s]
Capturing CUDA graph shapes:  31%|███▏      | 21/67 [00:00<00:01, 23.80it/s]
Capturing CUDA graph shapes:  36%|███▌      | 24/67 [00:01<00:01, 23.93it/s]
Capturing CUDA graph shapes:  40%|████      | 27/67 [00:01<00:01, 24.05it/s]
Capturing CUDA graph shapes:  45%|████▍     | 30/67 [00:01<00:01, 23.68it/s]
Capturing CUDA graph shapes:  49%|████▉     | 33/67 [00:01<00:01, 24.09it/s]
Capturing CUDA graph shapes:  54%|█████▎    | 36/67 [00:01<00:01, 24.78it/s]
Capturing C

(vLLMAsyncRollout pid=1479125, ip=10.122.253.153) INFO 12-30 08:59:44 [gpu_model_runner.py:2485] Graph capturing finished in 3 secs, took 0.61 GiB
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:44 [core.py:193] init engine (profile, create kv cache, warmup model) took 25.95 seconds
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:44 [serving_responses.py:89] Using default chat sampling params from model: {'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'max_tokens': 32768}
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:44 [serving_chat.py:122] Using default chat sampling params from model: {'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'max_tokens': 32768}
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 08:59:44 [serving_completion.py:77] Using default completion sampling params from model: {'repetition_penalty': 1.0, 'temperature': 1.0, 'top_k': -1, 'top_p': 1, 'max_tokens'

(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO:2025-12-30 08:59:44,732:Initializing a V1 LLM engine with config: model='/home/tiger/Qwen/Qwen3-4B', speculative_config=None, tokenizer='/home/tiger/Qwen/Qwen3-4B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=36864, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=/home/tiger/Qwen/Qwen3-4B, num_scheduler_steps=1, multi_step_stre

Rollout server address: 10.122.253.153:45141


## Read dataset and build one test sample
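Loads `bc_test_emh.parquet`, takes one sample's `extra_info` record (which carries the ground-truth `answer` and evidence docs), and wraps it in a single-row `DataProto`. The LocalSearch URL is set later, when the agent-loop worker is created.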

In [8]:
dataset_path = "/opt/tiger/verl_context_folding/bc_test_emh.parquet"
assert os.path.exists(dataset_path), f"Dataset not found: {dataset_path}"
df = pd.read_parquet(dataset_path)

# Choose a sample
sample_idx = 0

# The dataset's `extra_info` record carries the ground-truth `answer` (plus evidence docs).
extra_info = df.iloc[sample_idx]['extra_info']
uid = f"bc_test_emh:{sample_idx}"
reward_model = "default"
# Placeholder only: this cell does not build a chat-templated prompt.
raw_prompt = "DUMMY RAW PROMPR (SHOULD BE AFTER CHAT TEMPLATE)"

# Single-row batch; every non-tensor field is a length-1 object array.
batch = DataProto.from_dict(
    tensors={},
    non_tensors={
        "raw_prompt": np.array([raw_prompt], dtype=object),
        "extra_info": np.array([extra_info], dtype=object),
        "uid": np.array([uid], dtype=object),
        "reward_model": np.array([reward_model], dtype=object),
        "ability": np.array(["LocalSearch"], dtype=object),
        "agent_name": np.array(["fold_agent"], dtype=object),
        # Keep `index` consistent with `uid` when a different sample is chosen.
        "index": np.array([sample_idx], dtype=object),
    },
    meta_info={"validate": False, "global_steps": 0},
)
batch.non_tensor_batch['extra_info'][0]

{'answer': 'Emmanuel Kwesi Danso Arthur Junior ',
 'evidence_docs': array([{'docid': '8354', 'url': 'https://en.wikipedia.org/wiki/Kwesi_Arthur'},
        {'docid': '81842', 'url': 'https://en.wikipedia.org/wiki/Ghana_Independence_Act_1957'},
        {'docid': '38083', 'url': 'https://music.apple.com/gb/album/live-from-nkrumah-krom-ep/1229290143'},
        {'docid': '32920', 'url': 'https://en.wikipedia.org/wiki/Sapphire_Jubilee_of_Elizabeth_II'},
        {'docid': '7675', 'url': 'https://www.myjoyonline.com/kwesi-arthur-releases-much-anticipated-son-of-jacob-album/'},
        {'docid': '75392', 'url': 'https://pan-african-music.com/en/kwesi-arthur-got-something-to-say-on-live-from-nkrumah-krom-vol-2/'},
        {'docid': '61482', 'url': 'https://www.newwavemagazine.com/single-post/kwesi-arthur-an-artist-inspiring-african-youth-one-story-at-a-time-interview'},
        {'docid': '86319', 'url': 'https://www.viberate.com/artist/kwesi-arthur/'},
        {'docid': '79213', 'url': 'https://

## Run the agent loop and summarize output
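Overrides the FoldAgent plugin config (switching to the `search_branch` workflow), creates an `AgentLoopWorker` bound to the rollout server, runs `generate_sequences`, and prints a per-trajectory summary including the full transcript.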

In [9]:
# Replace the FoldAgent plugin config set in the first cell. Compared with that one,
# this run uses the `search_branch` workflow and smaller max_turn / max_session /
# max_traj values (10 / 2 / 3 instead of 20 / 3 / 4).
fold_plugin_settings = {
    "workflow": "search_branch",
    "max_turn": 10,
    "retry_cjk": 0,
    "turn_max_new_tokens": 2048,
    "max_session": 2,
    "val_max_session": 3,
    "session_timeout": 3600,
    "enable_summary": False,
    "branch_len": 256,
    "process_reward": "flat,scope",
    "max_traj": 3,
    "must_finish": False,
    "double_check": False,
    "must_search": True,
    "val_max_turn": 32,
    "val_response_length": 1024,
}
trainer_config_with_plugin.actor_rollout_ref.rollout.plugin = OmegaConf.create(fold_plugin_settings)

# Display the effective plugin config (the one printed by the first cell is now stale).
OmegaConf.to_container(trainer_config_with_plugin.actor_rollout_ref.rollout.plugin, resolve=True)


In [10]:
# Local search server endpoint. Override it with the LOCAL_SEARCH_URL environment
# variable; a trailing slash is stripped so the server code can append paths.
LOCAL_SEARCH_URL = os.environ.get(
    "LOCAL_SEARCH_URL", "http://[2605:340:cd51:7700:912f:284d:9dd7:367f]:8000"
).rstrip("/")
print("Using LOCAL_SEARCH_URL:", LOCAL_SEARCH_URL)

# Make the cell re-runnable: release the worker from a previous run (if any) and give
# the new one a unique Ray actor name, so it cannot collide with an existing named actor.
_previous_worker = globals().get("alm_worker")
if _previous_worker is not None:
    ray.kill(_previous_worker)

alm_worker = AgentLoopWorker.options(
    name=f"fold_agent_local_search_worker_branch_no_summary_{uuid.uuid4().hex[:8]}",
    runtime_env={"env_vars": {"LOCAL_SEARCH_URL": LOCAL_SEARCH_URL}},
).remote(
    trainer_config_with_plugin,
    [rollout_server.server_handle],
    None,
)
output = ray.get(alm_worker.generate_sequences.remote(batch))

# NOTE: `output.batch` is presumably a tensordict.TensorDict, which rejects `key in td`;
# check membership against `.keys()` instead (also correct for a plain dict).
print("Reward score tensor present?", 'rm_scores' in output.batch.keys())
print("Trajectories:", output.batch['responses'].shape[0])

Using LOCAL_SEARCH_URL: http://[2605:340:cd51:7700:912f:284d:9dd7:367f]:8000


(AgentLoopWorker pid=1483041, ip=10.122.253.153) INFO 12-30 09:04:25 [__init__.py:235] Automatically detected platform cuda.
(vLLMHttpServer pid=1479328, ip=10.122.253.153) INFO 12-30 09:04:27 [async_llm.py:269] Added request 6608905b23d84ae69d16796ec7319270.
(AgentLoopWorker pid=1483041, ip=10.122.253.153) {'extra_info': [{'answer': 'Emmanuel Kwesi Danso Arthur Junior ', 'evidence_docs': array([{'docid': '8354', 'url': 'https://en.wikipedia.org/wiki/Kwesi_Arthur'},
(AgentLoopWorker pid=1483041, ip=10.122.253.153)        {'docid': '81842', 'url': 'https://en.wikipedia.org/wiki/Ghana_Independence_Act_1957'},
(AgentLoopWorker pid=1483041, ip=10.122.253.153)        {'docid': '38083', 'url': 'https://music.apple.com/gb/album/live-from-nkrumah-krom-ep/1229290143'},
(AgentLoopWorker pid=1483041, ip=10.122.253.153)        {'docid': '32920', 'url': 'https://en.wikipedia.org/wiki/Sapphire_Jubilee_of_Elizabeth_II'},
(AgentLoopWorker pid=1483041, ip=10.122.253.153)        {'docid': '7675', 'url':

(AgentLoopWorker pid=1483041, ip=10.122.253.153) You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
*** SIGTERM received at time=1767057578 on cpu 82 ***
PC: @     0x7f9407299ee6  (unknown)  epoll_wait
    @     0x7f94071cd050  (unknown)  (unknown)
[2025-12-30 09:19:38,368 E 1467692 1467692] logging.cc:367: *** SIGTERM received at time=1767057578 on cpu 82 ***
[2025-12-30 09:19:38,368 E 1467692 1467692] logging.cc:367: PC: @     0x7f9407299ee6  (unknown)  epoll_wait
[2025-12-30 09:19:38,368 E 1467692 1467692] logging.cc:367:     @     0x7f94071cd050  (unknown)  (unknown)


In [6]:

def summarize_fold_output(dp):
    """Print a per-trajectory summary of a FoldAgent ``generate_sequences`` output.

    For every row of ``dp.batch['responses']`` this prints, when the field exists in
    ``dp.non_tensor_batch``: the env-stats keys and full contents, the agent name,
    ``mask_rollout``, ``is_finish``, non-empty branch names, and every transcript message.

    Args:
        dp: A ``DataProto``-like object exposing ``batch['responses']`` (first dimension
            = number of trajectories) and a ``non_tensor_batch`` mapping from field
            name to a per-trajectory sequence.
    """

    def field(name, i):
        # A missing field is None for every trajectory. (Indexing a one-element
        # ``[None]`` default raised IndexError for every trajectory after the first.)
        values = dp.non_tensor_batch.get(name)
        return None if values is None else values[i]

    def is_nonempty(value):
        # Plain truthiness raises ValueError on multi-element numpy arrays, which
        # object-dtype non-tensor fields can hold.
        if value is None:
            return False
        try:
            return len(value) > 0
        except TypeError:
            return bool(value)

    print("\n=== Summary ===")
    n = dp.batch['responses'].shape[0]
    for i in range(n):
        print(f"-- trajectory {i} --")
        env_stats = field("env_stats", i)
        agent_name = field("agent_name", i)
        mask_rollout = field("mask_rollout", i)
        is_finish = field("is_finish", i)
        branch_names = field("branch_names", i)
        if isinstance(env_stats, dict):
            print("env_stats keys:", list(env_stats.keys()))
            print(env_stats)
        print("agent_name:", agent_name, "mask_rollout:", mask_rollout, "is_finish:", is_finish)
        if is_nonempty(branch_names):
            print("branch_names:", branch_names)
        messages = field("messages", i)
        if isinstance(messages, (list, tuple)):
            print("Transcript preview:")
            for m in messages:
                role = m.get("role")
                content = m.get("content")
                print(f"- {role}: {content}")
    print("=== End Summary ===")

# Print the per-trajectory summary of the `output` produced by the agent-loop cell.
summarize_fold_output(output)


=== Summary ===
-- trajectory 0 --
env_stats keys: ['finish', 'search', 'open_page', 'change_answer', 'is_search', 'is_open', 'is_finish', 'visit_pages', 'action', 'session_time', 'get_final_score', 'traj_num', 'main_len', 'total_token', 'main_turn', 'is_branch', 'branch_success', 'use_all_branch']
Counter({'total_token': 28914, 'main_len': 17750, 'session_time': 176.4272162914276, 'visit_pages': 20, 'main_turn': 8, 'action': 3, 'search': 2, 'open_page': 1, 'is_search': 1, 'is_open': 1, 'traj_num': 1, 'finish': 0, 'change_answer': 0, 'is_finish': 0, 'get_final_score': 0, 'is_branch': 0, 'branch_success': 0, 'use_all_branch': 0})
agent_name: main mask_rollout: False is_finish: False
Transcript preview:
- system: You are a **Multi-Role Research Agent**, an advanced AI designed to conduct comprehensive, multi-step research. Your purpose is to deliver a thorough, accurate, and well-supported report in response to a user's query.

You operate in one of two modes: **MAIN** or **BRANCH**. Yo

In [None]:
# Choose a sample and build the batch from the dataset's top-level `prompt` / `answer`
# columns. NOTE(review): the earlier batch cell reads the answer from `extra_info`;
# confirm these top-level columns exist in bc_test_emh.parquet.
sample_idx = 0
sample_row = df.iloc[sample_idx]
raw_prompt = sample_row["prompt"]  # raw prompt
answer = str(sample_row["answer"])  # ground-truth label

print(answer)

extra_info = {
    "workflow": "search_branch",
    "raw_prompt": raw_prompt,
    "answer": answer,
}
uid = f"bc_test_emh:{sample_idx}"
reward_model = "default"

# Single-row batch; every non-tensor field is a length-1 object array.
batch = DataProto.from_dict(
    tensors={},
    non_tensors={
        "raw_prompt": np.array([raw_prompt], dtype=object),
        "extra_info": np.array([extra_info], dtype=object),
        "uid": np.array([uid], dtype=object),
        "reward_model": np.array([reward_model], dtype=object),
        "ability": np.array(["LocalSearch"], dtype=object),
        "agent_name": np.array(["fold_agent"], dtype=object),
        # Keep `index` consistent with `uid` when a different sample is chosen.
        "index": np.array([sample_idx], dtype=object),
    },
    meta_info={"validate": False, "global_steps": 0},
)
batch.non_tensor_batch['extra_info'][0]['answer']

In [None]:
# Run the existing agent-loop worker on the current `batch`.
output = ray.get(alm_worker.generate_sequences.remote(batch))

# NOTE: `output.batch` is presumably a tensordict.TensorDict, which rejects `key in td`;
# check membership against `.keys()` instead (also correct for a plain dict).
print("Reward score tensor present?", 'rm_scores' in output.batch.keys())
print("Trajectories:", output.batch['responses'].shape[0])

In [None]:

def summarize_fold_output(dp):
    """Print a compact per-trajectory summary of a FoldAgent ``generate_sequences`` output.

    Redefines the helper of the same name from the earlier summary cell. This version
    prints only the env-stats keys (not their values), plus the agent name,
    ``mask_rollout``, ``is_finish``, non-empty branch names, and every transcript message.

    Args:
        dp: A ``DataProto``-like object exposing ``batch['responses']`` (first dimension
            = number of trajectories) and a ``non_tensor_batch`` mapping from field
            name to a per-trajectory sequence.
    """

    def field(name, i):
        # A missing field is None for every trajectory. (Indexing a one-element
        # ``[None]`` default raised IndexError for every trajectory after the first.)
        values = dp.non_tensor_batch.get(name)
        return None if values is None else values[i]

    def is_nonempty(value):
        # Plain truthiness raises ValueError on multi-element numpy arrays, which
        # object-dtype non-tensor fields can hold.
        if value is None:
            return False
        try:
            return len(value) > 0
        except TypeError:
            return bool(value)

    print("\n=== Summary ===")
    n = dp.batch['responses'].shape[0]
    for i in range(n):
        print(f"-- trajectory {i} --")
        env_stats = field("env_stats", i)
        agent_name = field("agent_name", i)
        mask_rollout = field("mask_rollout", i)
        is_finish = field("is_finish", i)
        branch_names = field("branch_names", i)
        if isinstance(env_stats, dict):
            print("env_stats keys:", list(env_stats.keys()))
        print("agent_name:", agent_name, "mask_rollout:", mask_rollout, "is_finish:", is_finish)
        if is_nonempty(branch_names):
            print("branch_names:", branch_names)
        messages = field("messages", i)
        if isinstance(messages, (list, tuple)):
            print("Transcript preview:")
            for m in messages:
                role = m.get("role")
                content = m.get("content")
                print(f"- {role}: {content}")
    print("=== End Summary ===")

# Print the per-trajectory summary of the `output` produced by the re-run cell above.
summarize_fold_output(output)