# Multi-step GRPO (vLLM) for Display Semiconductor Q&A

This notebook runs a multi-step GRPO-style loop:
- Step-wise rollouts with experience injection (per-problem)
- Grading per attempt → experience generation per problem
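- Artifacts: per-step critique ops are saved under `<SEMI_EXP_DIR>/ops/step_{k}`; experiences are kept under `<EXPERIENCES_ROOT>/old` (adopted baseline) and `<EXPERIENCES_ROOT>/new` (candidate for the current step)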


In [None]:
import os
import json
import random
import shutil
import glob
import hashlib
from collections import defaultdict
from typing import List, Dict, Set

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

from training_free_grpo.semiconductor.prompts import (
    PROBLEM_WITH_EXPERIENCE_TEMPLATE,
    SINGLE_ROLLOUT_GRADING_TEMPLATE,
    SINGLE_QUERY_CRITIQUE_TEMPLATE,
)

from training_free_grpo.semiconductor.utils import (
    load_local_dataset,
    load_experiences_for_prompt,
    save_experiences_for_problem,
    safe_json_obj,
    safe_json_array,
    compute_problem_key,
    extract_final_answer,
    to_qwen_thinking_chat,
    run_search,
    format_experiences_for_prompt,
    format_requirements_block,
    build_or_load_index,
)
from training_free_grpo.semiconductor.embeddings import build_model as build_embed_engine


In [None]:
# ---- Config ----
# Every setting below can be overridden through the environment variable named
# in the corresponding os.environ.get(...) call; plain literals are fixed.


def _env_flag(name: str, default: str) -> bool:
    """Return True when env var *name* (or *default*) is "1", "true" or "yes", case-insensitively."""
    return os.environ.get(name, default).lower() in ("1", "true", "yes")


MODEL_PATH = os.environ.get("VLLM_MODEL", "/mnt/storage/models/Qwen3/Qwen3-32B")
GRADING_MODEL_PATH = os.environ.get("VLLM_MODEL_GRADING", "/mnt/storage/models/Qwen3/Qwen3-Next-80B-A3B-Thinking")
DATASET_PATH = os.environ.get("SEMI_DATASET", "/mnt/workspace/MLLM/zz/training_free_grpo/semiconductor/data/3.5k_filtered_processed_data.json")
EXPERIMENT_DIR = os.environ.get("SEMI_EXP_DIR", "/mnt/workspace/MLLM/zz/training_free_grpo/semiconductor")

BATCH_SIZE = 512        # problems per batch
GRPO_N = 5              # rollouts sampled per problem
TEMPERATURE = 0.7
MAX_NEW_TOKENS = 16384
NUM_STEPS = 3

# Randomization
RANDOM_SEED = int(os.environ.get("SEED", "42"))
SHUFFLE_EACH_STEP = _env_flag("SHUFFLE_EACH_STEP", "True")

# Grading/critique generation params
GRADING_TEMPERATURE = 0.1
GRADING_MAX_NEW_TOKENS = 16384

# GPU/engine configurations
TENSOR_PARALLEL_SIZE = int(os.environ.get("TP_SIZE", "8"))
MAX_MODEL_LEN = int(os.environ.get("MAX_MODEL_LEN", "32768"))
DTYPE = os.environ.get("DTYPE", "auto")  # "auto", "float16", "bfloat16"

# Optional separate grading engine configs (fallback to main if unset)
GRADING_TP_SIZE = int(os.environ.get("GRADING_TP_SIZE", str(TENSOR_PARALLEL_SIZE)))
GRADING_MAX_MODEL_LEN = int(os.environ.get("GRADING_MAX_MODEL_LEN", str(MAX_MODEL_LEN)))
GRADING_DTYPE = os.environ.get("GRADING_DTYPE", DTYPE)

# Per-problem experiences base directory (single path)
# Pass via env EXPERIENCES_ROOT to set the base directory. We'll create two subfolders: old/ and new/
EXPERIENCES_BASE = os.environ.get("EXPERIENCES_ROOT", os.path.join(EXPERIMENT_DIR, "experiences"))
EXPERIENCES_OLD_ROOT = os.path.join(EXPERIENCES_BASE, "old")
EXPERIENCES_NEW_ROOT = os.path.join(EXPERIENCES_BASE, "new")

# Retrieval config
TOP_K_EXPERIENCES = int(os.environ.get("TOP_K_EXPERIENCES", "5"))
EMBED_MODEL_ID = os.environ.get("EMBED_MODEL_ID", "Qwen/Qwen3-Embedding-8B")
EMBED_DEVICE = os.environ.get("EMBED_DEVICE")  # e.g., "cuda" or None
EMBED_BATCH_SIZE = int(os.environ.get("EMBED_BATCH_SIZE", "256"))
EMBED_USE_FA2 = _env_flag("EMBED_USE_FA2", "false")
# An empty or whitespace-only instruction is treated as "no instruction".
EMBED_INSTRUCTION = os.environ.get("EMBED_INSTRUCTION", "").strip() or None
EXPERIENCES_INDEX_DIR = os.environ.get("EXPERIENCES_INDEX_DIR", os.path.join(EXPERIENCES_BASE, "index"))
EMBED_REBUILD_INDEX = _env_flag("EMBED_REBUILD_INDEX", "false")

# Scoring outputs
SCORES_DIR = os.path.join(EXPERIMENT_DIR, "scores")

# Apply CUDA visibility.
# The value is a comma-separated id list; it is written without embedded
# spaces so that consumers which split on "," do not see ids such as " 1".
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3,4,5,6,7'

# Make sure every output directory exists before the first write.
for _required_dir in (EXPERIMENT_DIR, EXPERIENCES_OLD_ROOT, EXPERIENCES_NEW_ROOT, SCORES_DIR):
    os.makedirs(_required_dir, exist_ok=True)


In [None]:
# ---- Load dataset ----
data: List[Dict] = load_local_dataset(DATASET_PATH)
print(f"Loaded {len(data)} records")

# Seed before touching the RNG so the optional one-off shuffle is reproducible.
random.seed(RANDOM_SEED)
data_indices = list(range(len(data)))
if SHUFFLE_EACH_STEP:
    # Nothing to do now: the order is re-drawn at the start of every step.
    print("Will reshuffle indices each step.")
else:
    random.shuffle(data_indices)
    print("Shuffled once at start.")

# Ceiling division: a trailing partial batch still counts as one batch.
_full_batches, _leftover = divmod(len(data), BATCH_SIZE)
num_batches = _full_batches + (1 if _leftover else 0)
print(f"Batching with BATCH_SIZE={BATCH_SIZE} → {num_batches} batches")


In [None]:
# ---- Build generators ----
print("Building vLLM engines ...")

# Constructor arguments of the grading engine; they are also used for the
# single shared engine when rollout and grading use the same checkpoint.
_grading_engine_kwargs = dict(
    model=GRADING_MODEL_PATH,
    trust_remote_code=True,
    tensor_parallel_size=GRADING_TP_SIZE,
    max_model_len=GRADING_MAX_MODEL_LEN,
    dtype=GRADING_DTYPE,
    gpu_memory_utilization=0.7,
)

if MODEL_PATH != GRADING_MODEL_PATH:
    # Two distinct checkpoints: the rollout engine is built first with a small
    # memory share (0.2) and the grading engine second with 0.7.
    # NOTE(review): presumably sized so both engines fit on the same GPUs — confirm.
    gen = LLM(
        model=MODEL_PATH,
        trust_remote_code=True,
        tensor_parallel_size=TENSOR_PARALLEL_SIZE,
        max_model_len=MAX_MODEL_LEN,
        dtype=DTYPE,
        gpu_memory_utilization=0.2,
    )
    grade_gen = LLM(**_grading_engine_kwargs)
else:
    # Same checkpoint for both roles: build once and alias it.
    gen = LLM(**_grading_engine_kwargs)
    grade_gen = gen

# Preload tokenizers to pass directly
gen_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
grade_tokenizer = AutoTokenizer.from_pretrained(GRADING_MODEL_PATH)

# Build embedding engine (once, before main loop)
print("Building embedding engine ...")
embed_engine = build_embed_engine(EMBED_MODEL_ID)
print(f"Embedding engine ready: {EMBED_MODEL_ID}")

In [None]:
# ---- Multi-step rollout & experience update ----
# Each step: (1) sample GRPO_N rollouts per problem with retrieved experiences
# from the adopted pool (OLD), (2) grade every rollout, (3) ask the grading
# model for experience edits ("ops") per problem, and after the last batch
# (4) rebuild the candidate pool (NEW) from OLD plus ops, re-score it, and copy
# a problem's NEW file into OLD only when its mean reward did not drop.

# Nucleus / top-k settings shared by every generation call below.
_TOP_P = 0.95
_TOP_K = 20
# Settings used when re-scoring the candidate (NEW) experiences.
GRPO_N_EVAL = 1
EVAL_TEMPERATURE = 0.1


def _read_json(path: str):
    """Parse a UTF-8 JSON file; the handle is closed even when parsing fails."""
    with open(path, 'r', encoding='utf-8') as fh:
        return json.load(fh)


def _batch_records(indices: List[int], batch_idx: int) -> List[Dict]:
    """Return the dataset records of batch *batch_idx* in the order given by *indices*."""
    start = batch_idx * BATCH_SIZE
    end = min(len(data), (batch_idx + 1) * BATCH_SIZE)
    return [data[i] for i in indices[start:end]]


def _build_prompts(batch_data: List[Dict], experiences_dir: str) -> List[Dict]:
    """Attach a retrieval-augmented "prompt" to every record of *batch_data*.

    Experiences whose ``problem_key`` equals the record's own key are dropped,
    then the first TOP_K_EXPERIENCES remaining hits are formatted into the prompt.
    """
    problems = [sample['problem'] for sample in batch_data]
    # Over-fetch then filter to ensure we still have K after exclusion.
    # NOTE(review): index_dir is shared by OLD and NEW searches and
    # rebuild_index is always False; confirm run_search does not serve a stale
    # index when experiences_dir differs from the directory it was built from.
    retrieved_lists = run_search(
        experiences_dir=experiences_dir,
        query=problems,
        top_k=max(TOP_K_EXPERIENCES * 3, TOP_K_EXPERIENCES),
        device=None,
        batch_size=EMBED_BATCH_SIZE,
        use_flash_attention_2=False,
        instruction=EMBED_INSTRUCTION,
        index_dir=EXPERIENCES_INDEX_DIR,
        rebuild_index=False,
        engine=embed_engine,
    )
    # A missing or short result (e.g. an empty experience pool) must not make
    # zip() drop problems: pad so every problem still gets a prompt.
    retrieved_lists = list(retrieved_lists or [])
    shortfall = len(problems) - len(retrieved_lists)
    if shortfall > 0:
        retrieved_lists.extend([] for _ in range(shortfall))

    formatted: List[Dict] = []
    for sample, problem_text, retrieved in zip(batch_data, problems, retrieved_lists):
        current_key = compute_problem_key(problem_text)
        kept = [it for it in (retrieved or []) if it.get('problem_key') != current_key]
        prompt = PROBLEM_WITH_EXPERIENCE_TEMPLATE.format(
            experiences=format_experiences_for_prompt(kept[:TOP_K_EXPERIENCES]),
            problem=problem_text,
        )
        formatted.append({"prompt": prompt, **sample})
    return formatted


def _rollout(formatted: List[Dict], temperature: float) -> List[Dict]:
    """Generate one response per entry of *formatted*; return copies with "response" and a zero "reward"."""
    chat_prompts = [to_qwen_thinking_chat(item["prompt"], gen_tokenizer) for item in formatted]
    params = SamplingParams(
        temperature=temperature,
        max_tokens=MAX_NEW_TOKENS,
        top_p=_TOP_P,
        top_k=_TOP_K,
    )
    outputs = gen.generate(chat_prompts, params)
    responses = [extract_final_answer(o.outputs[0].text) for o in outputs]
    results: List[Dict] = []
    for item, response in zip(formatted, responses):
        rollout = dict(item)
        rollout["response"] = response
        rollout["reward"] = 0.0  # filled in by grading
        results.append(rollout)
    return results


def _group_by_problem(rollouts: List[Dict]) -> Dict[str, List[Dict]]:
    """Group rollouts by their "problem" text, preserving first-seen order."""
    grouped: Dict[str, List[Dict]] = {}
    for rollout in rollouts:
        grouped.setdefault(rollout["problem"], []).append(rollout)
    return grouped


def _sum_grades(grading) -> float:
    """Sum the "grade" values of a parsed grading object.

    Numeric grades are added as-is; string grades contribute the integer formed
    by their digit characters. Anything else (including a non-dict *grading*)
    contributes nothing.
    """
    if not isinstance(grading, dict):
        return 0.0
    total = 0.0
    for entry in grading.values():
        grade = entry.get("grade") if isinstance(entry, dict) else None
        if isinstance(grade, (int, float)):
            total += float(grade)
        elif isinstance(grade, str):
            digits = "".join(ch for ch in grade if ch.isdigit())
            if digits:
                try:
                    total += float(int(digits))
                except ValueError:
                    # isdigit() also accepts characters int() rejects (e.g. superscripts).
                    pass
    return total


def _normalized_reward(grading, keypoints) -> float:
    """Return the grade sum divided by 4 * number of keypoints (the summary line is excluded)."""
    kp = keypoints or []
    num_points = max(len(kp) - 1, 1) if isinstance(kp, list) else 1
    return _sum_grades(grading) / (4.0 * float(num_points))


def _grade(problem_to_rollouts: Dict[str, List[Dict]]) -> Dict[str, List[str]]:
    """Grade every rollout in one batched call and store its "reward" in place.

    Returns, per problem, one JSON line per attempt: {"attempt": i, "grading": ...}.
    """
    prompts: List[str] = []
    refs = []  # (problem, attempt_idx, rollout_ref), aligned with prompts
    for problem, rs in problem_to_rollouts.items():
        keypoints = rs[0].get("keypoints")
        req_text = format_requirements_block(keypoints) if isinstance(keypoints, list) else ""
        for attempt, rollout in enumerate(rs):
            prompts.append(
                SINGLE_ROLLOUT_GRADING_TEMPLATE.format(
                    problem=problem,
                    response=str(rollout.get("response", "")),
                    requirements=req_text,
                )
            )
            refs.append((problem, attempt, rollout))

    chat_prompts = [to_qwen_thinking_chat(p, grade_tokenizer) for p in prompts]
    params = SamplingParams(
        temperature=GRADING_TEMPERATURE,
        max_tokens=GRADING_MAX_NEW_TOKENS,
        top_p=_TOP_P,
        top_k=_TOP_K,
    )
    outputs = grade_gen.generate(chat_prompts, params)

    pieces: Dict[str, List[str]] = {}
    for (problem, attempt, rollout), out in zip(refs, outputs):
        text = extract_final_answer(out.outputs[0].text) if out and out.outputs else ""
        grading = safe_json_obj(text)
        rollout["reward"] = _normalized_reward(grading, rollout.get("keypoints"))
        pieces.setdefault(problem, []).append(
            json.dumps({"attempt": attempt, "grading": grading}, ensure_ascii=False)
        )
    return pieces


def _accumulate(problem_to_rollouts: Dict[str, List[Dict]],
                totals: Dict[str, float], counts: Dict[str, int]) -> None:
    """Add each problem's reward sum and rollout count to *totals* / *counts*, keyed by problem key."""
    for problem, rs in problem_to_rollouts.items():
        pkey = compute_problem_key(problem)
        totals[pkey] = totals.get(pkey, 0.0) + sum(r["reward"] for r in rs)
        counts[pkey] = counts.get(pkey, 0) + len(rs)


def _mean_rewards(totals: Dict[str, float], counts: Dict[str, int]) -> Dict[str, float]:
    """Return the per-rollout mean reward for every key with at least one rollout."""
    return {k: total / counts[k] for k, total in totals.items() if counts.get(k)}


def _append_ops(ops_dir: str, problem: str, ops) -> None:
    """Append *ops* to the problem's ops file in *ops_dir*, creating or repairing the file as needed."""
    fp = os.path.join(ops_dir, f"{compute_problem_key(problem)}.json")
    payload = None
    if os.path.exists(fp):
        try:
            payload = _read_json(fp)
        except (OSError, ValueError):
            payload = None  # unreadable or corrupt file: start over
    if not isinstance(payload, dict):
        payload = {"problem": problem, "ops": []}
    if not payload.get("problem"):
        payload["problem"] = problem
    previous = payload.get("ops")
    if not isinstance(previous, list):
        previous = []
    payload["ops"] = previous + list(ops)
    with open(fp, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _critique_and_record_ops(step: int, ops_dir: str,
                             problem_to_rollouts: Dict[str, List[Dict]],
                             grading_text: Dict[str, str]) -> None:
    """Ask the grading model for experience ops per problem and persist them under *ops_dir*."""
    prompts: List[str] = []
    problems: List[str] = []
    for problem, rs in problem_to_rollouts.items():
        # NOTE(review): existing experiences are read from NEW, while ops are
        # later applied on top of a fresh clone of OLD; confirm that "modify"
        # ids produced here are meant to refer to that clone.
        existing = load_experiences_for_prompt(problem, EXPERIENCES_NEW_ROOT)
        prompts.append(
            SINGLE_QUERY_CRITIQUE_TEMPLATE.format(
                problem=problem,
                grading=grading_text.get(problem, "[]"),
                answer=rs[0].get("groundtruth", ""),
                experiences=json.dumps(existing or {}, ensure_ascii=False),
            )
        )
        problems.append(problem)

    chat_prompts = [to_qwen_thinking_chat(p, grade_tokenizer) for p in prompts]
    params = SamplingParams(
        temperature=TEMPERATURE,
        max_tokens=MAX_NEW_TOKENS,
        top_p=_TOP_P,
        top_k=_TOP_K,
    )
    outputs = grade_gen.generate(chat_prompts, params)

    for problem, out in zip(problems, outputs):
        ops_text = extract_final_answer(out.outputs[0].text) if out and out.outputs else "[]"
        ops = safe_json_array(ops_text)
        if not ops:
            continue
        try:
            _append_ops(ops_dir, problem, ops)
        except Exception as exc:  # best effort: one bad file must not stop the step
            print(f"[Step {step}] WARNING: failed to write ops for problem: {problem}: {exc}")


def _next_experience_id(experiences: Dict) -> int:
    """Return one more than the largest numeric id found in the keys (0 for an empty mapping)."""
    if not experiences:
        return 0
    try:
        return max(int("".join(c for c in key if c.isdigit()) or -1) for key in experiences) + 1
    except (TypeError, ValueError):
        return len(experiences)


def _rebuild_new_from_old(step: int, ops_dir: str) -> None:
    """Reset NEW to a copy of OLD, then apply every recorded op file from *ops_dir* on top."""
    shutil.rmtree(EXPERIENCES_NEW_ROOT, ignore_errors=True)
    os.makedirs(EXPERIENCES_NEW_ROOT, exist_ok=True)
    # Clone baseline -> NEW
    for src in glob.glob(os.path.join(EXPERIENCES_OLD_ROOT, "*.json")):
        shutil.copy2(src, os.path.join(EXPERIENCES_NEW_ROOT, os.path.basename(src)))

    try:
        for fp in glob.glob(os.path.join(ops_dir, "*.json")):
            try:
                rec = _read_json(fp)
            except (OSError, ValueError):
                continue
            if not isinstance(rec, dict):
                continue
            problem = rec.get("problem")
            ops = rec.get("ops") or []
            if not problem or not isinstance(ops, list) or not ops:
                continue
            # Current NEW experiences for this problem (already cloned from baseline).
            updated = dict(load_experiences_for_prompt(problem, EXPERIENCES_NEW_ROOT) or {})
            next_id = _next_experience_id(updated)
            for op in ops:
                if not isinstance(op, dict):
                    continue
                opt = str(op.get("option", "")).lower()
                raw_content = op.get("experience")
                # A JSON null must not become the literal text "None".
                content = "" if raw_content is None else str(raw_content).strip()
                if opt == "add" and content:
                    updated[str(next_id)] = content
                    next_id += 1
                elif opt == "modify":
                    src_id = str(op.get("modified_from", "")).strip()
                    if src_id in updated and content:
                        updated[src_id] = content
            save_experiences_for_problem(problem, updated, EXPERIENCES_NEW_ROOT)
    except Exception as exc:
        print(f"[Step {step}] WARNING: failed applying ops from ops dir: {exc}")
    print(f"[Step {step}] NEW experiences rebuilt from baseline with ops applied.")


def _evaluate_and_adopt(step: int, step_indices: List[int],
                        train_totals: Dict[str, float], train_counts: Dict[str, int]):
    """Re-score all problems with NEW experiences and copy non-regressing files into OLD.

    Rewards are compared as per-rollout means, because the training pass samples
    GRPO_N rollouts per problem while this pass samples GRPO_N_EVAL.

    Returns (eval_totals, per_problem_delta), where the delta is
    mean(NEW) - mean(OLD) and a side with no rollouts counts as 0.0.
    """
    eval_totals: Dict[str, float] = {}
    eval_counts: Dict[str, int] = {}
    for eval_batch_idx in range(num_batches):
        records = _batch_records(step_indices, eval_batch_idx)
        formatted = _build_prompts(records, EXPERIENCES_NEW_ROOT) * GRPO_N_EVAL
        grouped = _group_by_problem(_rollout(formatted, EVAL_TEMPERATURE))
        _grade(grouped)
        _accumulate(grouped, eval_totals, eval_counts)

    new_means = _mean_rewards(eval_totals, eval_counts)
    old_means = _mean_rewards(train_totals, train_counts)
    deltas = {
        pkey: new_means.get(pkey, 0.0) - old_means.get(pkey, 0.0)
        for pkey in set(new_means) | set(old_means)
    }

    # Selectively update OLD with NEW only when delta >= 0 for that problem.
    # Files whose stem has no recorded delta default to 0.0 and are copied.
    os.makedirs(EXPERIENCES_OLD_ROOT, exist_ok=True)
    adopted, skipped = 0, 0
    for src in sorted(glob.glob(os.path.join(EXPERIENCES_NEW_ROOT, '*.json'))):
        pkey = os.path.splitext(os.path.basename(src))[0]
        if float(deltas.get(pkey, 0.0)) >= 0.0:
            shutil.copy2(src, os.path.join(EXPERIENCES_OLD_ROOT, os.path.basename(src)))
            adopted += 1
        else:
            skipped += 1
    print(f"[Step {step}] OLD selectively updated from NEW: updated={adopted}, skipped={skipped}")
    return eval_totals, deltas


for step in range(NUM_STEPS):
    # Build (or load) the retrieval index over the adopted (OLD) experience pool.
    _records, _emb = build_or_load_index(
        experiences_dir=EXPERIENCES_OLD_ROOT,
        index_dir=EXPERIENCES_INDEX_DIR,
        model=embed_engine,
        batch_size=EMBED_BATCH_SIZE,
        instruction=EMBED_INSTRUCTION
    )
    print(f"[Step {step}] New experiences embeddings built: {_emb.shape if _emb is not None else 'n/a'}")

    # Prepare per-step state; the ops directory is recreated empty for this step.
    step_total = 0.0
    ops_dir = os.path.join(EXPERIMENT_DIR, "ops", f"step_{step}")
    shutil.rmtree(ops_dir, ignore_errors=True)
    os.makedirs(ops_dir, exist_ok=True)

    # Scoring state: reward sums and rollout counts per problem key.
    cur_problem_totals: Dict[str, float] = {}
    cur_problem_counts: Dict[str, int] = {}

    # Prepare per-step randomized indices
    if SHUFFLE_EACH_STEP:
        random.seed(RANDOM_SEED + step)
        step_indices = list(data_indices)
        random.shuffle(step_indices)
    else:
        step_indices = data_indices

    for batch_idx in range(num_batches):
        batch_tag = f"[Step {step}][Batch {batch_idx+1}/{num_batches}]"
        batch_data = _batch_records(step_indices, batch_idx)

        # 1-2) Prompts with top-k retrieved experiences, duplicated GRPO_N times
        formatted_batch = _build_prompts(batch_data, EXPERIENCES_OLD_ROOT) * GRPO_N
        print(f"{batch_tag} Generating for {len(formatted_batch)} prompts ...")

        # 3-5) Generate, attach outputs, group by problem
        rollouts = _rollout(formatted_batch, TEMPERATURE)
        problem_to_rollouts = _group_by_problem(rollouts)

        # 6) Batched grading for all attempts in this batch
        problem_to_grading_text: Dict[str, str] = {
            p: "\n".join(pieces) for p, pieces in _grade(problem_to_rollouts).items()
        }
        _accumulate(problem_to_rollouts, cur_problem_totals, cur_problem_counts)

        # After grading, report avg reward for batch and accumulate step total
        batch_total = sum(x["reward"] for x in rollouts)
        step_total += batch_total
        avg_reward = batch_total / max(1, len(rollouts))
        print(f"{batch_tag} Avg reward (grading-sum): {avg_reward:.4f}; Batch total: {batch_total:.4f}")

        # 7) Batched critique prompts per problem; persist the resulting ops
        _critique_and_record_ops(step, ops_dir, problem_to_rollouts, problem_to_grading_text)

        # End-of-step processing on the last batch
        if (batch_idx + 1) == num_batches:
            if os.path.isdir(ops_dir):
                n_ops_files = len(glob.glob(os.path.join(ops_dir, "*.json")))
                print(f"[Step {step}] Ops directory ready: {ops_dir} ({n_ops_files} files)")
            else:
                print(f"[Step {step}] WARNING: ops dir missing: {ops_dir}")

            # Rebuild NEW from the OLD baseline, then apply this step's ops
            _rebuild_new_from_old(step, ops_dir)

            # Evaluate NEW vs OLD per problem; update OLD only when delta >= 0
            try:
                eval_problem_totals, per_problem_delta = _evaluate_and_adopt(
                    step, step_indices, cur_problem_totals, cur_problem_counts
                )
            except Exception as exc:  # keep the run alive; OLD is simply left unchanged
                print(f"[Step {step}] WARNING: NEW evaluation/selective update failed: {exc}")

        print(f"{batch_tag} Experiences updated.")
