diff --git a/examples/ablation/run_agent_loop_benchmarks.py b/examples/ablation/run_agent_loop_benchmarks.py index 6c75ac0..4add631 100644 --- a/examples/ablation/run_agent_loop_benchmarks.py +++ b/examples/ablation/run_agent_loop_benchmarks.py @@ -29,6 +29,12 @@ REPO_ROOT = Path(__file__).resolve().parents[2] OUT_DIR = Path(__file__).parent / "diagnostics" DEFAULT_MSMARCO_PATH = REPO_ROOT / "tests" / "benchmark" / "data" / "msmarco_passage.json" +DEFAULT_LLM_BASE_URL = "http://localhost:8012/v1" +DEFAULT_LLM_MODEL = "Qwen3.6-27B" +DEFAULT_API_KEY_ENV = "OPENAI_API_KEY" +DEEPSEEK_LLM_BASE_URL = "https://api.deepseek.com/v1" +DEEPSEEK_LLM_MODEL = "deepseek-v4-flash" +DEEPSEEK_API_KEY_ENV = "DEEPSEEK_API_KEY" _AGENT_LOOP_EXTRA_CONTEXT = """Benchmark context: - You are evaluating retrieval, not general knowledge. @@ -66,6 +72,27 @@ def _load_local_env(paths: list[Path] | None = None) -> None: os.environ[key] = value.strip().strip("\"'") +def _resolve_llm_settings( + *, + preset: str, + llm_base_url: str | None, + model: str | None, + api_key_env: str | None, +) -> tuple[str, str, str]: + """Resolve provider defaults while preserving explicit CLI overrides.""" + if preset == "deepseek": + return ( + llm_base_url or DEEPSEEK_LLM_BASE_URL, + model or DEEPSEEK_LLM_MODEL, + api_key_env or DEEPSEEK_API_KEY_ENV, + ) + return ( + llm_base_url or DEFAULT_LLM_BASE_URL, + model or DEFAULT_LLM_MODEL, + api_key_env or DEFAULT_API_KEY_ENV, + ) + + @dataclass(slots=True) class AgentLoopRow: qid: str @@ -493,9 +520,18 @@ async def amain(argv: list[str] | None = None) -> int: parser.add_argument("--sqlite-db-path", type=Path, required=True) parser.add_argument("--subset", type=int, default=20) parser.add_argument("--corpus-limit", type=int, default=0) - parser.add_argument("--llm-base-url", default="http://localhost:8012/v1") - parser.add_argument("--model", default="Qwen3.6-27B") - parser.add_argument("--api-key-env", default="OPENAI_API_KEY") + parser.add_argument( + "--llm-preset", + choices=("local", "deepseek"), + default="local", + help=( + "Provider preset for omitted LLM settings. " + "deepseek => api.deepseek.com/v1, deepseek-v4-flash, DEEPSEEK_API_KEY." + ), + ) + parser.add_argument("--llm-base-url", default=None) + parser.add_argument("--model", default=None) + parser.add_argument("--api-key-env", default=None) parser.add_argument("--max-turns", type=int, default=5) parser.add_argument( "--llm-timeout", @@ -543,6 +579,12 @@ async def amain(argv: list[str] | None = None) -> int: raise SystemExit("--preflight-timeout must be positive") if args.resume and args.out_jsonl is None: raise SystemExit("--resume requires --out-jsonl") + args.llm_base_url, args.model, args.api_key_env = _resolve_llm_settings( + preset=args.llm_preset, + llm_base_url=args.llm_base_url, + model=args.model, + api_key_env=args.api_key_env, + ) if not args.msmarco_path.exists(): raise SystemExit(f"{args.msmarco_path} does not exist") if not args.sqlite_db_path.exists(): diff --git a/tests/test_agent_search_benchmarks.py b/tests/test_agent_search_benchmarks.py index bdaef45..02d39f6 100644 --- a/tests/test_agent_search_benchmarks.py +++ b/tests/test_agent_search_benchmarks.py @@ -173,6 +173,45 @@ def test_agent_loop_load_local_env_without_overriding_shell_env( assert os.environ["DEEPSEEK_API_KEY"] == "from_shell" +def test_agent_loop_deepseek_preset_resolves_provider_defaults() -> None: + assert loop_runner._resolve_llm_settings( + preset="deepseek", + llm_base_url=None, + model=None, + api_key_env=None, + ) == ( + "https://api.deepseek.com/v1", + "deepseek-v4-flash", + "DEEPSEEK_API_KEY", + ) + + +def test_agent_loop_llm_preset_preserves_explicit_overrides() -> None: + assert loop_runner._resolve_llm_settings( + preset="deepseek", + llm_base_url="https://example.test/v1", + model="custom-model", + api_key_env="CUSTOM_KEY", + ) == ( + "https://example.test/v1", + "custom-model", + "CUSTOM_KEY", + ) + + +def test_agent_loop_local_preset_keeps_existing_defaults() -> None: + assert loop_runner._resolve_llm_settings( + preset="local", + llm_base_url=None, + model=None, + api_key_env=None, + ) == ( + "http://localhost:8012/v1", + "Qwen3.6-27B", + "OPENAI_API_KEY", + ) + + def test_llm_preflight_error_message_names_endpoint_and_skip_hint() -> None: msg = loop_runner._llm_preflight_error_message( "http://127.0.0.1:18012/v1",