# In[ ]:
import importlib
import os
import traceback
from datetime import datetime

from src.pipeline.hyper_heuristics.llm_selection import LLMSelectionHyperHeuristic
from src.pipeline.hyper_heuristics.random import RandomHyperHeuristic
from src.pipeline.hyper_heuristics.single import SingleHyperHeuristic
from src.util.llm_client.get_llm_client import get_llm_client
from src.util.util import search_file

# ==========================================================
# Notebook parameters (these stand in for command-line flags)
# ==========================================================

# --- What to run -------------------------------------------------------------
# Package name under src/problems, e.g. "tsp" or "mkp".
problem = "tsp"
# A concrete heuristic name, or one of "llm_hh", "random_hh", "or_solver".
heuristic = "llm_hh"

# --- Where things live -------------------------------------------------------
# Sub-directory of src/problems/<problem>/heuristics holding the heuristic pool.
heuristic_dir = "basic_heuristics"
# "test_data" runs every entry of the problem's test-data directory;
# any other value is treated as a single data name.
test_data = "test_data"
# Directory name inserted into output/<problem>/<result_dir>/... for results.
result_dir = "result"

# --- Search / LLM settings ---------------------------------------------------
# LLM configuration file; only loaded when heuristic == "llm_hh".
llm_config_file = os.path.join("output", "llm_config", "azure_gpt_4o.json")
# Scaling factor for the number of iterations (used by llm_hh and random_hh).
iterations_scale_factor = 2.0
# The following three are only forwarded to the LLM hyper-heuristic:
steps_per_selection = 5  # steps executed after each heuristic selection
num_candidate_heuristics = 1  # candidate heuristics considered per selection
rollout_budget = 0  # number of Monte Carlo evaluations

# --------------------------
# Core execution logic (normally no need to edit)
# --------------------------
def _list_test_data(test_data, problem):
    """Return the data names to run, sorted for a deterministic run order.

    ``test_data == "test_data"`` means "every entry of the problem's test-data
    directory", located through ``search_file("test_data", problem)``. Any
    other value is treated as a single data name.

    Raises:
        FileNotFoundError: if the test-data directory cannot be located.
    """
    if test_data != "test_data":
        return [test_data]
    test_data_dir = search_file("test_data", problem)
    # Guard against a failed lookup (presumably search_file returns None when
    # nothing matches -- confirm): os.listdir(None) would otherwise silently
    # list the current working directory and run every file in it.
    if not test_data_dir or not os.path.isdir(test_data_dir):
        raise FileNotFoundError(f"找不到问题 {problem} 的测试数据目录: {test_data_dir}")
    return sorted(os.listdir(test_data_dir))


def run_hyper_heuristic(
    problem,
    heuristic,
    llm_config_file,
    heuristic_dir,
    test_data,
    iterations_scale_factor,
    steps_per_selection,
    num_candidate_heuristics,
    rollout_budget,
    result_dir
):
    """Run a hyper-heuristic, single heuristic, or OR solver over test data.

    Args:
        problem: Problem package name under ``src/problems`` (e.g. "tsp").
        heuristic: "llm_hh", "random_hh", "or_solver", or a single heuristic
            name/path. Only the file stem (text before the first ".") selects
            the mode.
        llm_config_file: LLM config file; only loaded for "llm_hh".
        heuristic_dir: Sub-directory of ``src/problems/<problem>/heuristics``
            whose entries form the heuristic pool.
        test_data: "test_data" to run every entry of the problem's test-data
            directory, otherwise a single data name passed to ``Env``.
        iterations_scale_factor: Forwarded to the LLM and random hyper-heuristics.
        steps_per_selection: Forwarded to the LLM hyper-heuristic only.
        num_candidate_heuristics: Forwarded to the LLM hyper-heuristic only.
        rollout_budget: Forwarded to the LLM hyper-heuristic only.
        result_dir: Directory name inserted into each instance's output path.

    Raises:
        FileNotFoundError: if the heuristic directory or the test-data
            directory does not exist.
        ImportError: if ``src.problems.<problem>.env`` cannot provide ``Env``.

    A failure on one data instance is printed together with its traceback and
    does not stop the remaining instances.
    """
    datetime_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    # basename also splits on "/" under Windows, unlike split(os.sep).
    heuristic_name = os.path.basename(heuristic).split(".")[0]

    # Heuristic pool: every entry of the chosen heuristics directory.
    heuristic_pool_path = os.path.join("src", "problems", problem, "heuristics", heuristic_dir)
    if not os.path.isdir(heuristic_pool_path):
        raise FileNotFoundError(f"启发式目录不存在: {heuristic_pool_path}")
    # NOTE(review): listdir returns every entry (e.g. __pycache__ if present);
    # confirm the hyper-heuristics ignore non-heuristic entries.
    heuristic_pool = os.listdir(heuristic_pool_path)

    # Output root: <AMLT_OUTPUT_DIR>/../../output when that variable is set,
    # otherwise ./output.
    amlt_output_dir = os.getenv("AMLT_OUTPUT_DIR")
    base_output_dir = (
        os.path.join(amlt_output_dir, "..", "..", "output") if amlt_output_dir else "output"
    )

    # Build the solver for the selected mode.
    llm_client = None
    if heuristic_name == "llm_hh":
        prompt_dir = os.path.join("src", "problems", "base", "prompt")
        llm_client = get_llm_client(llm_config_file, prompt_dir, None)
        llm_name = os.path.splitext(os.path.basename(llm_config_file))[0]
        experiment_name = f"{heuristic_name}.{heuristic_dir}.{llm_name}.n{iterations_scale_factor}m{steps_per_selection}c{num_candidate_heuristics}b{rollout_budget}.{datetime_str}"
        hyper_heuristic = LLMSelectionHyperHeuristic(
            llm_client=llm_client,
            heuristic_pool=heuristic_pool,
            problem=problem,
            iterations_scale_factor=iterations_scale_factor,
            steps_per_selection=steps_per_selection,
            num_candidate_heuristics=num_candidate_heuristics,
            rollout_budget=rollout_budget,
        )
    elif heuristic_name == "random_hh":
        experiment_name = f"{heuristic_name}.{heuristic_dir}.{datetime_str}"
        hyper_heuristic = RandomHyperHeuristic(
            heuristic_pool=heuristic_pool,
            problem=problem,
            iterations_scale_factor=iterations_scale_factor
        )
    elif heuristic_name == "or_solver":
        experiment_name = "or_solver"
        module = importlib.import_module(f"src.problems.{problem}.or_solver")
        ORSolver = getattr(module, "ORSolver")
        hyper_heuristic = ORSolver(problem=problem)
    else:
        # Anything else is treated as the name of a single heuristic.
        experiment_name = heuristic_name
        hyper_heuristic = SingleHyperHeuristic(heuristic=heuristic_name, problem=problem)

    # Problem environment class.
    try:
        module = importlib.import_module(f"src.problems.{problem}.env")
        Env = getattr(module, "Env")
    except Exception as e:
        raise ImportError(f"无法导入问题 {problem} 的环境类: {str(e)}") from e

    test_data_list = _list_test_data(test_data, problem)

    # Run every data instance; a failure on one must not abort the others.
    for data_name in test_data_list:
        try:
            env = Env(data_name=data_name)
            output_dir = os.path.join(base_output_dir, problem, result_dir, env.data_ref_name, experiment_name)
            env.reset(output_dir)

            # Record the run parameters next to the results.
            params = {
                "problem": problem,
                "heuristic": heuristic,
                "llm_config_file": llm_config_file,
                "heuristic_dir": heuristic_dir,
                "test_data": test_data,
                "iterations_scale_factor": iterations_scale_factor,
                "steps_per_selection": steps_per_selection,
                "num_candidate_heuristics": num_candidate_heuristics,
                "rollout_budget": rollout_budget,
                "result_dir": result_dir,
                "data_path": env.data_path
            }
            # Explicit UTF-8: paths/values may be non-ASCII and the platform
            # default encoding is not guaranteed to handle them.
            with open(os.path.join(env.output_dir, "parameters.txt"), 'w', encoding="utf-8") as f:
                f.write('\n'.join(f'{k}={v}' for k, v in params.items()))

            if heuristic_name == "llm_hh":
                llm_client.reset(env.output_dir)
            validation_result = hyper_heuristic.run(env)

            if validation_result:
                env.dump_result()
                print(f"成功: {os.path.join(env.output_dir, 'result.txt')} | {heuristic_name} | {data_name} | {env.key_item}: {env.key_value}")
            else:
                print(f"失败: 无效解 | {heuristic_name} | {data_name}")
        except Exception as e:
            print(f"处理 {data_name} 时出错: {str(e)}")
            # Keep the traceback; the one-line message alone hides the cause.
            traceback.print_exc()

# --------------------------
# Entry point: executing this cell launches the run configured above.
# --------------------------
_run_kwargs = {
    "problem": problem,
    "heuristic": heuristic,
    "llm_config_file": llm_config_file,
    "heuristic_dir": heuristic_dir,
    "test_data": test_data,
    "iterations_scale_factor": iterations_scale_factor,
    "steps_per_selection": steps_per_selection,
    "num_candidate_heuristics": num_candidate_heuristics,
    "rollout_budget": rollout_budget,
    "result_dir": result_dir,
}
run_hyper_heuristic(**_run_kwargs)