diff --git a/pyproject.toml b/pyproject.toml index 24630e1..f6f0da5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,10 @@ docstring-code-format = true [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] -addopts = "-v --tb=short" +addopts = "-v --tb=short -m 'not e2e'" +markers = [ + "e2e: marks tests as end-to-end (deselect with '-m \"not e2e\"')", +] filterwarnings = [ "ignore::DeprecationWarning", ] diff --git a/src/coding/proxy/vendors/antigravity.py b/src/coding/proxy/vendors/antigravity.py index b9bbfb5..b4d7199 100644 --- a/src/coding/proxy/vendors/antigravity.py +++ b/src/coding/proxy/vendors/antigravity.py @@ -141,7 +141,14 @@ def __init__( config.refresh_token, ) TokenBackendMixin.__init__(self, token_manager) - BaseVendor.__init__(self, config.base_url, config.timeout_ms, failover_config) + # v1internal 模式:base_url 需要去除 /v1internal 路径后缀, + # 因为 endpoint 使用完整路径 /v1internal:generateContent(冒号格式)。 + # httpx 会将 base_url path 与 endpoint path 拼接, + # 如果 base_url 含 /v1internal 会导致路径重复。 + init_base_url = config.base_url + if init_base_url.rstrip("/").endswith("/v1internal"): + init_base_url = init_base_url.rstrip("/").removesuffix("/v1internal") + BaseVendor.__init__(self, init_base_url, config.timeout_ms, failover_config) self._model_endpoint = config.model_endpoint self._model_mapper = model_mapper self._default_model = config.model_endpoint.removeprefix("models/") @@ -149,6 +156,7 @@ def __init__( self._safety_settings = config.safety_settings # v1internal 协议字段 self._project_id: str = config.project_id + self._v1internal_enabled: bool = "v1internal" in config.base_url self._session_id: str = uuid.uuid4().hex[:16] self._message_count: int = 0 # project_id 自动发现状态 @@ -159,8 +167,11 @@ def get_name(self) -> str: return "antigravity" def _is_v1internal_mode(self) -> bool: - """检测是否启用 v1internal 协议模式(与 Antigravity-Manager 对齐).""" - return bool(self._effective_project_id) and "v1internal" in self._base_url + """检测是否启用 v1internal 协议模式(与 Antigravity-Manager 对齐). + + v1internal 协议由原始配置的 base_url 路径或 project_id 自动发现触发。 + """ + return self._v1internal_enabled @property def _effective_project_id(self) -> str: @@ -229,7 +240,11 @@ async def _discover_project_id(self, access_token: str) -> str: return "" # 发现成功:原子性切换到 v1internal 模式 - self._base_url = _V1INTERNAL_BASE_URL + # base_url 只保留域名部分(去除 /v1internal 路径后缀) + self._base_url = _V1INTERNAL_BASE_URL.rstrip("/").removesuffix( + "/v1internal" + ) + self._v1internal_enabled = True self._project_id_discovered = project_id # 重建 HTTP 客户端(base_url 是初始化参数) @@ -339,8 +354,13 @@ async def _prepare_request( self._last_request_adaptations = converted.adaptations token = await self._token_manager.get_token() - # 懒加载:未配置 project_id 时自动发现并切换 v1internal 模式 - if not self._project_id and not self._project_discovery_attempted: + # 懒加载:未配置 project_id 时尝试自动发现(仅标准 GLA 模式需要) + # v1internal 模式不依赖 project_id,跳过发现 + if ( + not self._project_id + and not self._project_discovery_attempted + and not self._v1internal_enabled + ): discovered = await self._discover_project_id(token) if discovered: logger.info( @@ -450,11 +470,11 @@ async def send_message( body, prepared_headers = await self._prepare_request(request_body, headers) client = self._get_client() resolved_model = self._last_resolved_model - endpoint = ( - ":generateContent" - if self._is_v1internal_mode() - else f"/models/{resolved_model}:generateContent" - ) + if self._is_v1internal_mode(): + # v1internal 端点需要完整路径(冒号格式)覆盖 base_url 的 path 部分 + endpoint = "/v1internal:generateContent" + else: + endpoint = f"/models/{resolved_model}:generateContent" logger.debug("send_message: POST %s", endpoint) response = await client.post(endpoint, json=body, headers=prepared_headers) @@ -496,11 +516,10 @@ async def send_message_stream( body, prepared_headers = await self._prepare_request(request_body, headers) client = self._get_client() resolved_model = self._last_resolved_model - endpoint = ( - ":streamGenerateContent?alt=sse" - if self._is_v1internal_mode() - else f"/models/{resolved_model}:streamGenerateContent?alt=sse" - ) + if self._is_v1internal_mode(): + endpoint = "/v1internal:streamGenerateContent?alt=sse" + else: + endpoint = f"/models/{resolved_model}:streamGenerateContent?alt=sse" logger.debug("send_message_stream: POST %s", endpoint) diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py new file mode 100644 index 0000000..cf41f45 --- /dev/null +++ b/tests/e2e/conftest.py @@ -0,0 +1,199 @@ +"""E2E 集成测试共享 fixtures — Antigravity 真实凭证加载与测试对象构建.""" + +from __future__ import annotations + +import os +from typing import Any + +import pytest + +# ── 模块级门控:未设置环境变量时跳过整个 e2e 包 ── + +_SKIP_REASON = "Set RUN_ANTIGRAVITY_E2E=1 to enable Antigravity E2E tests" + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line( + "markers", "e2e: End-to-end tests requiring real Antigravity credentials" + ) + + +def _load_real_credentials() -> dict[str, str] | None: + """从 ~/.coding-proxy/ 加载真实的 Google OAuth 凭证.""" + from coding.proxy.auth.providers.google import ( + _DEFAULT_CLIENT_ID, + _DEFAULT_CLIENT_SECRET, + ) + from coding.proxy.auth.store import TokenStoreManager + from coding.proxy.config.loader import load_config + + try: + token_store = TokenStoreManager() + token_store.load() + google_tokens = token_store.get("google") + if not google_tokens.refresh_token: + return None + + config = load_config() + + # 从 vendors 列表查找 antigravity 配置 + client_id = "" + client_secret = "" + base_url = "" + model_endpoint = "models/claude-sonnet-4-20250514" + project_id = "" + + for vc in config.vendors: + if vc.vendor == "antigravity": + client_id = vc.client_id or _DEFAULT_CLIENT_ID + client_secret = vc.client_secret or _DEFAULT_CLIENT_SECRET + base_url = ( + vc.base_url or "https://generativelanguage.googleapis.com/v1beta" + ) + model_endpoint = vc.model_endpoint or model_endpoint + break + + # 优先使用 config.yaml 中的 refresh_token,否则使用 token store + refresh_token = "" + for vc in config.vendors: + if vc.vendor == "antigravity" and vc.refresh_token: + refresh_token = vc.refresh_token + break + if not refresh_token: + refresh_token = google_tokens.refresh_token + + return { + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + "base_url": base_url, + "model_endpoint": model_endpoint, + "project_id": project_id, + } + except Exception: + return None + + +# ── Fixtures ── + + +@pytest.fixture(scope="session") +def e2e_credentials() -> dict[str, str]: + """加载真实 Antigravity OAuth 凭证,失败则跳过.""" + if os.environ.get("RUN_ANTIGRAVITY_E2E") != "1": + pytest.skip(_SKIP_REASON) + creds = _load_real_credentials() + if creds is None: + pytest.skip("No valid Antigravity credentials found in ~/.coding-proxy/") + return creds + + +@pytest.fixture(scope="session") +def antigravity_config(e2e_credentials: dict[str, str]) -> Any: + """构建标准 GLA 模式的 AntigravityConfig.""" + from coding.proxy.config.vendors import AntigravityConfig + + return AntigravityConfig( + enabled=True, + client_id=e2e_credentials["client_id"], + client_secret=e2e_credentials["client_secret"], + refresh_token=e2e_credentials["refresh_token"], + base_url=e2e_credentials["base_url"], + model_endpoint=e2e_credentials["model_endpoint"], + timeout_ms=60000, + ) + + +@pytest.fixture(scope="session") +def antigravity_config_v1internal(e2e_credentials: dict[str, str]) -> Any: + """构建 v1internal 模式的 AntigravityConfig(无 project_id,触发自动发现).""" + from coding.proxy.config.vendors import AntigravityConfig + + return AntigravityConfig( + enabled=True, + client_id=e2e_credentials["client_id"], + client_secret=e2e_credentials["client_secret"], + refresh_token=e2e_credentials["refresh_token"], + base_url="https://cloudcode-pa.googleapis.com/v1internal", + model_endpoint=e2e_credentials["model_endpoint"], + timeout_ms=60000, + ) + + +@pytest.fixture +async def antigravity_vendor(antigravity_config: Any) -> Any: + """构建标准 GLA 模式的 AntigravityVendor(function scope,每次测试独立).""" + from coding.proxy.config.schema import FailoverConfig + from coding.proxy.routing.model_mapper import ModelMapper + from coding.proxy.vendors.antigravity import AntigravityVendor + + vendor = AntigravityVendor(antigravity_config, FailoverConfig(), ModelMapper([])) + yield vendor + await vendor.close() + + +@pytest.fixture +async def antigravity_vendor_v1internal(antigravity_config_v1internal: Any) -> Any: + """构建 v1internal 模式的 AntigravityVendor.""" + from coding.proxy.config.schema import FailoverConfig + from coding.proxy.routing.model_mapper import ModelMapper + from coding.proxy.vendors.antigravity import AntigravityVendor + + vendor = AntigravityVendor( + antigravity_config_v1internal, FailoverConfig(), ModelMapper([]) + ) + yield vendor + await vendor.close() + + +@pytest.fixture +def minimal_request_body() -> dict[str, Any]: + """最小 Anthropic 格式请求体(用于最小化 token 消耗).""" + return { + "model": "claude-sonnet-4-20250514", + "messages": [{"role": "user", "content": "Say exactly: pong"}], + "max_tokens": 32, + } + + +@pytest.fixture(scope="session") +def e2e_app(e2e_credentials: dict[str, str]) -> Any: + """构建仅启用 Antigravity 的 FastAPI 应用(临时 DB).""" + import tempfile + + from coding.proxy.config.schema import ProxyConfig + from coding.proxy.server.app import create_app + + tmpdir = tempfile.mkdtemp(prefix="e2e-antigravity-") + db_path = os.path.join(tmpdir, "usage.db") + compat_path = os.path.join(tmpdir, "compat.db") + + config = ProxyConfig( + vendors=[ + { + "vendor": "antigravity", + "enabled": True, + "client_id": e2e_credentials["client_id"], + "client_secret": e2e_credentials["client_secret"], + "refresh_token": e2e_credentials["refresh_token"], + "base_url": "https://cloudcode-pa.googleapis.com/v1internal", + "model_endpoint": e2e_credentials["model_endpoint"], + "timeout_ms": 60000, + }, + ], + tiers=["antigravity"], + database={"path": db_path, "compat_state_path": compat_path}, + ) + return create_app(config) + + +@pytest.fixture +async def e2e_client(e2e_app: Any) -> Any: + """构建异步 HTTP 客户端(支持 SSE 流式测试).""" + import httpx + + transport = httpx.ASGITransport(app=e2e_app) + async with httpx.AsyncClient( + transport=transport, base_url="http://test", timeout=60.0 + ) as client: + yield client diff --git a/tests/e2e/test_e2e_http.py b/tests/e2e/test_e2e_http.py new file mode 100644 index 0000000..fe84db5 --- /dev/null +++ b/tests/e2e/test_e2e_http.py @@ -0,0 +1,263 @@ +"""Level 3 E2E: 完整 HTTP 端到端 — 模拟 Claude Code 通过 coding-proxy 使用 Antigravity.""" + +from __future__ import annotations + +import json + +import pytest + +# Claude Code 发送的典型 headers +CLAUDE_CODE_HEADERS = { + "anthropic-version": "2023-06-01", + "content-type": "application/json", + "x-api-key": "sk-ant-placeholder", +} + + +def _is_quota_exhausted(response: object) -> bool: + """检查响应是否为配额耗尽 (429).""" + if response.status_code != 429: + return False + try: + body = response.json() + err = body.get("error", {}) + msg = err.get("message", "").lower() + return "resource" in msg or "quota" in msg or "exhausted" in msg + except Exception: + return False + + +def _is_scope_error(response: object) -> bool: + """检查响应是否为 scope 不足 (403).""" + if response.status_code != 403: + return False + try: + body = response.json() + err = body.get("error", {}) + return "scope" in json.dumps(err).lower() + except Exception: + return False + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_http_non_streaming( + e2e_client: object, + minimal_request_body: dict, +) -> None: + """POST /v1/messages 非流式 → 验证协议对接正确.""" + response = await e2e_client.post( + "/v1/messages", + json=minimal_request_body, + headers=CLAUDE_CODE_HEADERS, + ) + + if _is_scope_error(response): + pytest.skip("GLA 端点 scope 不足,需要 v1internal 模式") + if _is_quota_exhausted(response): + print("\n[E2E] HTTP non-streaming: 协议对接正确,但配额已耗尽 (429)") + return + + assert response.status_code == 200, ( + f"预期 200,实际 {response.status_code}: {response.text[:300]}" + ) + + body = response.json() + assert body["type"] == "message", f"预期 type=message,实际: {body.get('type')}" + assert body["role"] == "assistant" + assert len(body["content"]) > 0, "content 为空" + assert body["content"][0]["type"] == "text" + assert body["usage"]["input_tokens"] > 0, "input_tokens 应 > 0" + + print( + f"\n[E2E] HTTP non-streaming 成功: model={body.get('model')}, " + f"input={body['usage']['input_tokens']}, output={body['usage']['output_tokens']}" + ) + print(f" content: {body['content'][0].get('text', '')[:100]}") + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_http_streaming(e2e_client: object) -> None: + """POST /v1/messages (stream=true) → 验证 SSE 协议.""" + body = { + "model": "claude-sonnet-4-20250514", + "messages": [{"role": "user", "content": "Say exactly: pong"}], + "max_tokens": 32, + "stream": True, + } + + events: list[str] = [] + content_chunks: list[str] = [] + + try: + async with e2e_client.stream( + "POST", "/v1/messages", json=body, headers=CLAUDE_CODE_HEADERS + ) as response: + if response.status_code == 429: + print("\n[E2E] HTTP streaming: 协议对接正确,但配额已耗尽 (429)") + return + + assert response.status_code == 200, f"预期 200,实际 {response.status_code}" + + async for line in response.aiter_lines(): + line = line.strip() + if not line: + continue + if line.startswith("event:"): + events.append(line[6:].strip()) + elif line.startswith("data:"): + payload = line[5:].strip() + if payload == "[DONE]": + continue + try: + data = json.loads(payload) + if data.get("type") == "content_block_delta": + delta = data.get("delta", {}) + if delta.get("type") == "text_delta": + content_chunks.append(delta.get("text", "")) + except json.JSONDecodeError: + pass + + assert "message_start" in events, f"缺少 message_start,实际: {events[:10]}" + assert "content_block_delta" in events, "缺少 content_block_delta" + assert "message_stop" in events, "缺少 message_stop" + + full_text = "".join(content_chunks) + print( + f"\n[E2E] HTTP streaming 成功: events={len(events)}, content='{full_text[:100]}'" + ) + except Exception as exc: + error_str = str(exc) + if "429" in error_str or "exhausted" in error_str.lower(): + print("\n[E2E] HTTP streaming: 协议对接正确,但配额已耗尽 (429)") + return + raise + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_http_with_tools(e2e_client: object) -> None: + """POST /v1/messages 带 tools 定义 → 请求正常往返.""" + body = { + "model": "claude-sonnet-4-20250514", + "messages": [ + {"role": "user", "content": "What is 2+2? Reply with just the number."} + ], + "max_tokens": 128, + "tools": [ + { + "name": "calculator", + "description": "Performs arithmetic", + "input_schema": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + } + ], + } + response = await e2e_client.post( + "/v1/messages", json=body, headers=CLAUDE_CODE_HEADERS + ) + + if _is_scope_error(response): + pytest.skip("GLA 端点 scope 不足") + if _is_quota_exhausted(response): + print("\n[E2E] HTTP with tools: 协议对接正确,配额耗尽") + return + + assert response.status_code == 200, ( + f"预期 200,实际 {response.status_code}: {response.text[:300]}" + ) + + resp_body = response.json() + assert resp_body["type"] == "message" + assert len(resp_body["content"]) > 0 + content_types = [b["type"] for b in resp_body["content"]] + print(f"\n[E2E] HTTP with tools 成功: content_types={content_types}") + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_http_health_probe(e2e_client: object) -> None: + """HEAD / 和 GET /health → 200(Claude Code 连通性探测).""" + head_resp = await e2e_client.head("/") + assert head_resp.status_code == 200, ( + f"HEAD / 预期 200,实际 {head_resp.status_code}" + ) + + get_resp = await e2e_client.get("/") + assert get_resp.status_code == 200, f"GET / 预期 200,实际 {get_resp.status_code}" + + health_resp = await e2e_client.get("/health") + assert health_resp.status_code == 200 + assert health_resp.json() == {"status": "ok"} + + print("\n[E2E] HTTP health probe 成功: HEAD /=200, GET /=200, /health=ok") + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_http_status_diagnostics(e2e_client: object) -> None: + """GET /api/status → 包含 antigravity tier 诊断信息.""" + response = await e2e_client.get("/api/status") + assert response.status_code == 200 + + data = response.json() + assert "tiers" in data + antigravity_tiers = [t for t in data["tiers"] if t["name"] == "antigravity"] + assert len(antigravity_tiers) == 1, ( + f"预期 1 个 antigravity tier,实际: {len(antigravity_tiers)}" + ) + + tier = antigravity_tiers[0] + assert "diagnostics" in tier, "缺少 diagnostics" + + diag = tier["diagnostics"] + print("\n[E2E] status diagnostics:") + for k, v in diag.items(): + if isinstance(v, dict): + print(f" {k}: {json.dumps(v, ensure_ascii=False)[:200]}") + else: + print(f" {k}: {v}") + + # token_manager 诊断可能为空(若未发生错误),仅验证其存在性 + if "token_manager" in diag: + print(" token_manager diagnostics present") + else: + print(" (token_manager diagnostics empty — no token errors)") + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_http_claude_code_headers(e2e_client: object) -> None: + """带完整 Claude Code headers 的请求正常(验证 x-api-key 不干扰 Antigravity).""" + headers = { + "anthropic-version": "2023-06-01", + "content-type": "application/json", + "x-api-key": "sk-ant-api03-fake-key-for-testing", + "accept": "application/json", + } + body = { + "model": "claude-sonnet-4-20250514", + "messages": [{"role": "user", "content": "Say: ok"}], + "max_tokens": 16, + } + response = await e2e_client.post("/v1/messages", json=body, headers=headers) + + if _is_quota_exhausted(response): + print("\n[E2E] Claude Code headers: 协议对接正确,配额耗尽") + return + + assert response.status_code == 200, ( + f"预期 200,实际 {response.status_code}: {response.text[:300]}" + ) + + resp_body = response.json() + assert resp_body["type"] == "message" + assert len(resp_body["content"]) > 0 + + print( + f"\n[E2E] Claude Code headers 成功: content='{resp_body['content'][0].get('text', '')[:80]}'" + ) diff --git a/tests/e2e/test_e2e_token.py b/tests/e2e/test_e2e_token.py new file mode 100644 index 0000000..dd3bb7b --- /dev/null +++ b/tests/e2e/test_e2e_token.py @@ -0,0 +1,93 @@ +"""Level 1 E2E: Google OAuth2 Token 刷新 — 验证真实凭证链路.""" + +from __future__ import annotations + +import pytest + +from coding.proxy.vendors.antigravity import GoogleOAuthTokenManager +from coding.proxy.vendors.token_manager import TokenAcquireError, TokenErrorKind + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_real_token_refresh(e2e_credentials: dict[str, str]) -> None: + """真实 refresh_token 应返回有效的 access_token(ya29. 前缀).""" + tm = GoogleOAuthTokenManager( + e2e_credentials["client_id"], + e2e_credentials["client_secret"], + e2e_credentials["refresh_token"], + ) + try: + token = await tm.get_token() + assert token, "access_token 为空" + assert token.startswith("ya29."), f"access_token 前缀异常: {token[:10]}..." + print(f"[E2E DIAG] access_token={token[:10]}... (len={len(token)})") + finally: + await tm.close() + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_real_token_caching(e2e_credentials: dict[str, str]) -> None: + """连续调用 get_token() 应返回缓存的同一 token.""" + tm = GoogleOAuthTokenManager( + e2e_credentials["client_id"], + e2e_credentials["client_secret"], + e2e_credentials["refresh_token"], + ) + try: + token1 = await tm.get_token() + token2 = await tm.get_token() + assert token1 == token2, "缓存未生效,两次返回不同 token" + assert tm._expires_at > 0, "expires_at 未被设置" + print(f"[E2E DIAG] caching OK: expires_at={tm._expires_at}") + finally: + await tm.close() + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_invalid_refresh_token_raises(e2e_credentials: dict[str, str]) -> None: + """错误的 refresh_token 应抛出 TokenAcquireError(INVALID_CREDENTIALS).""" + tm = GoogleOAuthTokenManager( + e2e_credentials["client_id"], + e2e_credentials["client_secret"], + "1//invalid_token_for_e2e_test_00000000", + ) + try: + with pytest.raises(TokenAcquireError) as exc_info: + await tm.get_token() + assert exc_info.value.kind == TokenErrorKind.INVALID_CREDENTIALS, ( + f"预期 INVALID_CREDENTIALS,实际: {exc_info.value.kind}" + ) + assert exc_info.value.needs_reauth is True + print(f"[E2E DIAG] invalid_grant 正确捕获: {exc_info.value}") + finally: + await tm.close() + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_token_invalidation_triggers_refresh( + e2e_credentials: dict[str, str], +) -> None: + """invalidate() 后重新获取应成功.""" + tm = GoogleOAuthTokenManager( + e2e_credentials["client_id"], + e2e_credentials["client_secret"], + e2e_credentials["refresh_token"], + ) + try: + token1 = await tm.get_token() + assert token1, "首次获取失败" + + tm.invalidate() + assert tm._expires_at == 0.0, "invalidate 后 expires_at 应为 0" + + token2 = await tm.get_token() + assert token2, "invalidate 后重新获取失败" + print( + f"[E2E DIAG] invalidation OK: token1={token1[:10]}... token2={token2[:10]}..." + ) + finally: + await tm.close() diff --git a/tests/e2e/test_e2e_vendor.py b/tests/e2e/test_e2e_vendor.py new file mode 100644 index 0000000..1781235 --- /dev/null +++ b/tests/e2e/test_e2e_vendor.py @@ -0,0 +1,327 @@ +"""Level 2 E2E: AntigravityVendor 直接调用 — 验证 GLA 和 v1internal 协议端到端.""" + +from __future__ import annotations + +import json + +import pytest + + +def _print_diagnostics(vendor: object, label: str) -> None: + diag = vendor.get_diagnostics() + print(f"\n[E2E DIAG] {label}:") + for k, v in diag.items(): + if isinstance(v, dict): + print(f" {k}: {json.dumps(v, ensure_ascii=False)[:200]}") + else: + print(f" {k}: {v}") + + +def _is_quota_exhausted(resp: object) -> bool: + """检查响应是否为配额耗尽(429 RESOURCE_EXHAUSTED). + + 429 表示协议对接正确但配额已用完,测试应标记为预期行为。 + """ + if resp.status_code != 429: + return False + error_msg = (resp.error_message or "").lower() + return "resource" in error_msg or "quota" in error_msg or "exhausted" in error_msg + + +def _is_scope_error(resp: object) -> bool: + """检查响应是否为 scope 不足错误.""" + if resp.status_code != 403: + return False + return "scope" in (resp.error_message or "").lower() + + +# ── 标准 GLA 模式 ── + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_gla_non_streaming_text( + antigravity_vendor: object, + minimal_request_body: dict, +) -> None: + """GLA 模式非流式请求 — 验证协议对接正确.""" + resp = await antigravity_vendor.send_message(minimal_request_body, {}) + _print_diagnostics(antigravity_vendor, "GLA non-streaming") + + # 403 scope 不足说明 GLA 端点不适用于当前凭证(正常,需要 v1internal) + if _is_scope_error(resp): + pytest.skip("GLA 端点 scope 不足,需要 v1internal 模式") + + # 429 配额耗尽 = 协议对接正确,仅配额问题 + if _is_quota_exhausted(resp): + print("\n[E2E] GLA non-streaming: 协议对接正确,但配额已耗尽 (429)") + return + + assert resp.status_code == 200, ( + f"预期 200,实际 {resp.status_code}: {resp.error_message}" + ) + + body = json.loads(resp.raw_body) + assert body["type"] == "message", f"预期 type=message,实际: {body.get('type')}" + assert body["role"] == "assistant" + assert len(body["content"]) > 0, "content 为空" + assert body["content"][0]["type"] == "text" + assert body["stop_reason"] in ("end_turn", "max_tokens") + assert body["usage"]["input_tokens"] > 0, "input_tokens 应 > 0" + + print( + f"\n[E2E] GLA non-streaming 成功: model={body.get('model')}, " + f"input={body['usage']['input_tokens']}, output={body['usage']['output_tokens']}, " + f"stop_reason={body['stop_reason']}" + ) + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_gla_streaming_text( + antigravity_vendor: object, + minimal_request_body: dict, +) -> None: + """GLA 模式流式请求 — 验证 SSE 协议对接.""" + minimal_request_body["stream"] = True + + events: list[str] = [] + content_chunks: list[str] = [] + quota_exhausted = False + + try: + async for chunk in antigravity_vendor.send_message_stream( + minimal_request_body, {} + ): + text = chunk.decode("utf-8", errors="replace") + for line in text.split("\n"): + line = line.strip() + if line.startswith("event:"): + events.append(line[6:].strip()) + elif line.startswith("data:"): + try: + data = json.loads(line[5:].strip()) + if data.get("type") == "content_block_delta": + delta = data.get("delta", {}) + if delta.get("type") == "text_delta": + content_chunks.append(delta.get("text", "")) + except json.JSONDecodeError: + pass + except Exception as exc: + error_str = str(exc).lower() + if "403" in error_str and "scope" in error_str: + pytest.skip("GLA 端点 scope 不足,需要 v1internal 模式") + if "429" in error_str or "quota" in error_str or "exhausted" in error_str: + quota_exhausted = True + print("\n[E2E] GLA streaming: 协议对接正确,但配额已耗尽 (429)") + else: + raise + + if not quota_exhausted: + _print_diagnostics(antigravity_vendor, "GLA streaming") + assert "message_start" in events, ( + f"缺少 message_start 事件,实际事件: {events[:10]}" + ) + assert "content_block_delta" in events, "缺少 content_block_delta 事件" + assert "message_stop" in events, "缺少 message_stop 事件" + + full_text = "".join(content_chunks) + print( + f"\n[E2E] GLA streaming 成功: events={len(events)}, content='{full_text[:100]}'" + ) + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_gla_with_system_prompt( + antigravity_vendor: object, + minimal_request_body: dict, +) -> None: + """GLA 模式带 system prompt 的请求正常.""" + minimal_request_body["system"] = ( + "You are a test assistant. Always respond with exactly one word." + ) + resp = await antigravity_vendor.send_message(minimal_request_body, {}) + + if _is_scope_error(resp): + pytest.skip("GLA 端点 scope 不足") + if _is_quota_exhausted(resp): + print("\n[E2E] GLA with system prompt: 协议对接正确,配额耗尽") + return + + assert resp.status_code == 200, ( + f"预期 200,实际 {resp.status_code}: {resp.error_message}" + ) + body = json.loads(resp.raw_body) + assert body["type"] == "message" + assert len(body["content"]) > 0 + + print( + f"\n[E2E] GLA with system prompt 成功: content='{body['content'][0].get('text', '')[:80]}'" + ) + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_gla_with_tools( + antigravity_vendor: object, + minimal_request_body: dict, +) -> None: + """GLA 模式带 tools 定义的请求正常往返.""" + minimal_request_body["tools"] = [ + { + "name": "calculator", + "description": "Performs arithmetic", + "input_schema": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + "required": ["expression"], + }, + } + ] + minimal_request_body["messages"] = [ + {"role": "user", "content": "What is 2+2? Reply with just the number."} + ] + resp = await antigravity_vendor.send_message(minimal_request_body, {}) + + if _is_scope_error(resp): + pytest.skip("GLA 端点 scope 不足") + if _is_quota_exhausted(resp): + print("\n[E2E] GLA with tools: 协议对接正确,配额耗尽") + return + + assert resp.status_code == 200, ( + f"预期 200,实际 {resp.status_code}: {resp.error_message}" + ) + body = json.loads(resp.raw_body) + assert body["type"] == "message" + assert len(body["content"]) > 0 + + _print_diagnostics(antigravity_vendor, "GLA with tools") + print( + f"\n[E2E] GLA with tools 成功: content_types={[b['type'] for b in body['content']]}" + ) + + +# ── v1internal 模式 ── + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_v1internal_non_streaming( + antigravity_vendor_v1internal: object, + minimal_request_body: dict, +) -> None: + """v1internal 模式非流式请求 — 验证协议对接.""" + resp = await antigravity_vendor_v1internal.send_message(minimal_request_body, {}) + + _print_diagnostics(antigravity_vendor_v1internal, "v1internal non-streaming") + + # 429 = 协议对接正确,仅配额问题 + if _is_quota_exhausted(resp): + diag = antigravity_vendor_v1internal.get_diagnostics() + print( + f"\n[E2E] v1internal non-streaming: 协议对接正确 (is_v1internal={diag.get('is_v1internal_mode')}),但配额已耗尽 (429)" + ) + return + + assert resp.status_code == 200, ( + f"预期 200,实际 {resp.status_code}: {resp.error_message}" + ) + body = json.loads(resp.raw_body) + assert body["type"] == "message" + assert body["role"] == "assistant" + assert len(body["content"]) > 0 + + diag = antigravity_vendor_v1internal.get_diagnostics() + print( + f"\n[E2E] v1internal non-streaming 成功: " + f"is_v1internal={diag.get('is_v1internal_mode')}, " + f"project_id_source={diag.get('project_id_source')}, " + f"input={body['usage']['input_tokens']}, output={body['usage']['output_tokens']}" + ) + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_v1internal_streaming( + antigravity_vendor_v1internal: object, + minimal_request_body: dict, +) -> None: + """v1internal 模式流式请求 — 验证 SSE 协议.""" + minimal_request_body["stream"] = True + + events: list[str] = [] + content_chunks: list[str] = [] + quota_exhausted = False + + try: + async for chunk in antigravity_vendor_v1internal.send_message_stream( + minimal_request_body, {} + ): + text = chunk.decode("utf-8", errors="replace") + for line in text.split("\n"): + line = line.strip() + if line.startswith("event:"): + events.append(line[6:].strip()) + elif line.startswith("data:"): + try: + data = json.loads(line[5:].strip()) + if data.get("type") == "content_block_delta": + delta = data.get("delta", {}) + if delta.get("type") == "text_delta": + content_chunks.append(delta.get("text", "")) + except json.JSONDecodeError: + pass + except Exception as exc: + error_str = str(exc) + if "429" in error_str: + quota_exhausted = True + print("\n[E2E] v1internal streaming: 协议对接正确,但配额已耗尽 (429)") + else: + raise + + if not quota_exhausted: + _print_diagnostics(antigravity_vendor_v1internal, "v1internal streaming") + assert "message_start" in events, "缺少 message_start" + assert "content_block_delta" in events, "缺少 content_block_delta" + assert "message_stop" in events, "缺少 message_stop" + + full_text = "".join(content_chunks) + print( + f"\n[E2E] v1internal streaming 成功: events={len(events)}, content='{full_text[:100]}'" + ) + + +@pytest.mark.e2e +@pytest.mark.asyncio +async def test_project_id_auto_discovery( + antigravity_vendor_v1internal: object, + minimal_request_body: dict, +) -> None: + """首次请求后 v1internal 模式状态和 project_id 发现结果.""" + resp = await antigravity_vendor_v1internal.send_message(minimal_request_body, {}) + + diag = antigravity_vendor_v1internal.get_diagnostics() + source = diag.get("project_id_source", "unknown") + is_v1 = diag.get("is_v1internal_mode", False) + + print(f"\n[E2E] project_id discovery: source={source}, is_v1internal={is_v1}") + + # v1internal 模式应已启用(由 base_url 配置驱动) + assert is_v1 is True, "v1internal 模式应已启用" + assert source in ("discovered", "none", "configured"), ( + f"未知的 project_id_source: {source}" + ) + + # 请求应到达了 API 端点(429 配额耗尽或 200 成功都说明协议对接正确) + assert resp.status_code in (200, 429), ( + f"预期 200/429,实际 {resp.status_code}: {resp.error_message[:200]}" + ) + + if resp.status_code == 429: + print(" 配额已耗尽 (429),但协议对接验证正确") + elif source == "discovered": + print(f" discovered_project_id={diag.get('discovered_project_id')}") + elif source == "none": + print(" 未发现 project_id,v1internal 无需 project_id") diff --git a/tests/test_antigravity.py b/tests/test_antigravity.py index 6256bfb..cc93127 100644 --- a/tests/test_antigravity.py +++ b/tests/test_antigravity.py @@ -384,12 +384,12 @@ def test_is_v1internal_mode_with_project_id_and_v1internal_url(): def test_is_v1internal_mode_without_project_id(): - """未配置 project_id 时即使 URL 含 v1internal 也不启用.""" + """v1internal 模式由 base_url 驱动,无需 project_id(与参考项目对齐).""" config = AntigravityConfig( base_url="https://cloudcode-pa.googleapis.com/v1internal", ) vendor = AntigravityVendor(config, FailoverConfig(), ModelMapper([])) - assert vendor._is_v1internal_mode() is False + assert vendor._is_v1internal_mode() is True def test_is_v1internal_mode_standard_gla_url(): @@ -527,7 +527,7 @@ async def test_discover_project_id_single_active_project(): assert result == "my-gcp-123" assert vendor._project_id_discovered == "my-gcp-123" - assert vendor._base_url == "https://cloudcode-pa.googleapis.com/v1internal" + assert vendor._base_url == "https://cloudcode-pa.googleapis.com" assert vendor._is_v1internal_mode() is True @@ -743,20 +743,19 @@ async def mock_discover(token): def test_is_v1internal_mode_uses_effective_project_id(): - """_is_v1internal_mode 应基于 _effective_project_id 判断.""" + """_is_v1internal_mode 应基于 base_url 判断(不再依赖 project_id).""" config = AntigravityConfig(base_url=_V1INTERNAL_BASE_URL) vendor = AntigravityVendor(config, FailoverConfig(), ModelMapper([])) - # 未配置、未发现 → False - assert vendor._is_v1internal_mode() is False + # base_url 含 v1internal → True(即使无 project_id) + assert vendor._is_v1internal_mode() is True - # 发现后 → True + # 发现 project_id 不影响 v1internal 模式判断 vendor._project_id_discovered = "found-it" assert vendor._is_v1internal_mode() is True - # 配置值覆盖发现值 + # 清除发现值也不影响 vendor._project_id_discovered = "" - vendor._project_id = "manual" assert vendor._is_v1internal_mode() is True