Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ docstring-code-format = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
addopts = "-v --tb=short"
addopts = "-v --tb=short -m 'not e2e'"
markers = [
"e2e: marks tests as end-to-end (deselect with '-m \"not e2e\"')",
]
filterwarnings = [
"ignore::DeprecationWarning",
]
51 changes: 35 additions & 16 deletions src/coding/proxy/vendors/antigravity.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,22 @@ def __init__(
config.refresh_token,
)
TokenBackendMixin.__init__(self, token_manager)
BaseVendor.__init__(self, config.base_url, config.timeout_ms, failover_config)
# v1internal 模式:base_url 需要去除 /v1internal 路径后缀,
# 因为 endpoint 使用完整路径 /v1internal:generateContent(冒号格式)。
# httpx 会将 base_url path 与 endpoint path 拼接,
# 如果 base_url 含 /v1internal 会导致路径重复。
init_base_url = config.base_url
if init_base_url.rstrip("/").endswith("/v1internal"):
init_base_url = init_base_url.rstrip("/").removesuffix("/v1internal")
BaseVendor.__init__(self, init_base_url, config.timeout_ms, failover_config)
self._model_endpoint = config.model_endpoint
self._model_mapper = model_mapper
self._default_model = config.model_endpoint.removeprefix("models/")
self._last_request_adaptations: list[str] = []
self._safety_settings = config.safety_settings
# v1internal 协议字段
self._project_id: str = config.project_id
self._v1internal_enabled: bool = "v1internal" in config.base_url
self._session_id: str = uuid.uuid4().hex[:16]
self._message_count: int = 0
# project_id 自动发现状态
Expand All @@ -159,8 +167,11 @@ def get_name(self) -> str:
return "antigravity"

def _is_v1internal_mode(self) -> bool:
"""检测是否启用 v1internal 协议模式(与 Antigravity-Manager 对齐)."""
return bool(self._effective_project_id) and "v1internal" in self._base_url
"""检测是否启用 v1internal 协议模式(与 Antigravity-Manager 对齐).

v1internal 协议由原始配置的 base_url 路径或 project_id 自动发现触发。
"""
return self._v1internal_enabled

@property
def _effective_project_id(self) -> str:
Expand Down Expand Up @@ -229,7 +240,11 @@ async def _discover_project_id(self, access_token: str) -> str:
return ""

# 发现成功:原子性切换到 v1internal 模式
self._base_url = _V1INTERNAL_BASE_URL
# base_url 只保留域名部分(去除 /v1internal 路径后缀)
self._base_url = _V1INTERNAL_BASE_URL.rstrip("/").removesuffix(
"/v1internal"
)
self._v1internal_enabled = True
self._project_id_discovered = project_id

# 重建 HTTP 客户端(base_url 是初始化参数)
Expand Down Expand Up @@ -339,8 +354,13 @@ async def _prepare_request(
self._last_request_adaptations = converted.adaptations
token = await self._token_manager.get_token()

# 懒加载:未配置 project_id 时自动发现并切换 v1internal 模式
if not self._project_id and not self._project_discovery_attempted:
# 懒加载:未配置 project_id 时尝试自动发现(仅标准 GLA 模式需要)
# v1internal 模式不依赖 project_id,跳过发现
if (
not self._project_id
and not self._project_discovery_attempted
and not self._v1internal_enabled
):
discovered = await self._discover_project_id(token)
if discovered:
logger.info(
Expand Down Expand Up @@ -450,11 +470,11 @@ async def send_message(
body, prepared_headers = await self._prepare_request(request_body, headers)
client = self._get_client()
resolved_model = self._last_resolved_model
endpoint = (
":generateContent"
if self._is_v1internal_mode()
else f"/models/{resolved_model}:generateContent"
)
if self._is_v1internal_mode():
# v1internal 端点需要完整路径(冒号格式)覆盖 base_url 的 path 部分
endpoint = "/v1internal:generateContent"
else:
endpoint = f"/models/{resolved_model}:generateContent"

logger.debug("send_message: POST %s", endpoint)
response = await client.post(endpoint, json=body, headers=prepared_headers)
Expand Down Expand Up @@ -496,11 +516,10 @@ async def send_message_stream(
body, prepared_headers = await self._prepare_request(request_body, headers)
client = self._get_client()
resolved_model = self._last_resolved_model
endpoint = (
":streamGenerateContent?alt=sse"
if self._is_v1internal_mode()
else f"/models/{resolved_model}:streamGenerateContent?alt=sse"
)
if self._is_v1internal_mode():
endpoint = "/v1internal:streamGenerateContent?alt=sse"
else:
endpoint = f"/models/{resolved_model}:streamGenerateContent?alt=sse"

logger.debug("send_message_stream: POST %s", endpoint)

Expand Down
Empty file added tests/e2e/__init__.py
Empty file.
199 changes: 199 additions & 0 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""E2E 集成测试共享 fixtures — Antigravity 真实凭证加载与测试对象构建."""

from __future__ import annotations

import os
from typing import Any

import pytest

# ── 模块级门控:未设置环境变量时跳过整个 e2e 包 ──

_SKIP_REASON = "Set RUN_ANTIGRAVITY_E2E=1 to enable Antigravity E2E tests"


def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line(
"markers", "e2e: End-to-end tests requiring real Antigravity credentials"
)


def _load_real_credentials() -> dict[str, str] | None:
"""从 ~/.coding-proxy/ 加载真实的 Google OAuth 凭证."""
from coding.proxy.auth.providers.google import (
_DEFAULT_CLIENT_ID,
_DEFAULT_CLIENT_SECRET,
)
from coding.proxy.auth.store import TokenStoreManager
from coding.proxy.config.loader import load_config

try:
token_store = TokenStoreManager()
token_store.load()
google_tokens = token_store.get("google")
if not google_tokens.refresh_token:
return None

config = load_config()

# 从 vendors 列表查找 antigravity 配置
client_id = ""
client_secret = ""
base_url = ""
model_endpoint = "models/claude-sonnet-4-20250514"
project_id = ""

for vc in config.vendors:
if vc.vendor == "antigravity":
client_id = vc.client_id or _DEFAULT_CLIENT_ID
client_secret = vc.client_secret or _DEFAULT_CLIENT_SECRET
base_url = (
vc.base_url or "https://generativelanguage.googleapis.com/v1beta"
)
model_endpoint = vc.model_endpoint or model_endpoint
break

# 优先使用 config.yaml 中的 refresh_token,否则使用 token store
refresh_token = ""
for vc in config.vendors:
if vc.vendor == "antigravity" and vc.refresh_token:
refresh_token = vc.refresh_token
break
if not refresh_token:
refresh_token = google_tokens.refresh_token

return {
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": refresh_token,
"base_url": base_url,
"model_endpoint": model_endpoint,
"project_id": project_id,
}
except Exception:
return None


# ── Fixtures ──


@pytest.fixture(scope="session")
def e2e_credentials() -> dict[str, str]:
"""加载真实 Antigravity OAuth 凭证,失败则跳过."""
if os.environ.get("RUN_ANTIGRAVITY_E2E") != "1":
pytest.skip(_SKIP_REASON)
creds = _load_real_credentials()
if creds is None:
pytest.skip("No valid Antigravity credentials found in ~/.coding-proxy/")
return creds


@pytest.fixture(scope="session")
def antigravity_config(e2e_credentials: dict[str, str]) -> Any:
"""构建标准 GLA 模式的 AntigravityConfig."""
from coding.proxy.config.vendors import AntigravityConfig

return AntigravityConfig(
enabled=True,
client_id=e2e_credentials["client_id"],
client_secret=e2e_credentials["client_secret"],
refresh_token=e2e_credentials["refresh_token"],
base_url=e2e_credentials["base_url"],
model_endpoint=e2e_credentials["model_endpoint"],
timeout_ms=60000,
)


@pytest.fixture(scope="session")
def antigravity_config_v1internal(e2e_credentials: dict[str, str]) -> Any:
"""构建 v1internal 模式的 AntigravityConfig(无 project_id,触发自动发现)."""
from coding.proxy.config.vendors import AntigravityConfig

return AntigravityConfig(
enabled=True,
client_id=e2e_credentials["client_id"],
client_secret=e2e_credentials["client_secret"],
refresh_token=e2e_credentials["refresh_token"],
base_url="https://cloudcode-pa.googleapis.com/v1internal",
model_endpoint=e2e_credentials["model_endpoint"],
timeout_ms=60000,
)


@pytest.fixture
async def antigravity_vendor(antigravity_config: Any) -> Any:
"""构建标准 GLA 模式的 AntigravityVendor(function scope,每次测试独立)."""
from coding.proxy.config.schema import FailoverConfig
from coding.proxy.routing.model_mapper import ModelMapper
from coding.proxy.vendors.antigravity import AntigravityVendor

vendor = AntigravityVendor(antigravity_config, FailoverConfig(), ModelMapper([]))
yield vendor
await vendor.close()


@pytest.fixture
async def antigravity_vendor_v1internal(antigravity_config_v1internal: Any) -> Any:
"""构建 v1internal 模式的 AntigravityVendor."""
from coding.proxy.config.schema import FailoverConfig
from coding.proxy.routing.model_mapper import ModelMapper
from coding.proxy.vendors.antigravity import AntigravityVendor

vendor = AntigravityVendor(
antigravity_config_v1internal, FailoverConfig(), ModelMapper([])
)
yield vendor
await vendor.close()


@pytest.fixture
def minimal_request_body() -> dict[str, Any]:
"""最小 Anthropic 格式请求体(用于最小化 token 消耗)."""
return {
"model": "claude-sonnet-4-20250514",
"messages": [{"role": "user", "content": "Say exactly: pong"}],
"max_tokens": 32,
}


@pytest.fixture(scope="session")
def e2e_app(e2e_credentials: dict[str, str]) -> Any:
"""构建仅启用 Antigravity 的 FastAPI 应用(临时 DB)."""
import tempfile

from coding.proxy.config.schema import ProxyConfig
from coding.proxy.server.app import create_app

tmpdir = tempfile.mkdtemp(prefix="e2e-antigravity-")
db_path = os.path.join(tmpdir, "usage.db")
compat_path = os.path.join(tmpdir, "compat.db")

config = ProxyConfig(
vendors=[
{
"vendor": "antigravity",
"enabled": True,
"client_id": e2e_credentials["client_id"],
"client_secret": e2e_credentials["client_secret"],
"refresh_token": e2e_credentials["refresh_token"],
"base_url": "https://cloudcode-pa.googleapis.com/v1internal",
"model_endpoint": e2e_credentials["model_endpoint"],
"timeout_ms": 60000,
},
],
tiers=["antigravity"],
database={"path": db_path, "compat_state_path": compat_path},
)
return create_app(config)


@pytest.fixture
async def e2e_client(e2e_app: Any) -> Any:
"""构建异步 HTTP 客户端(支持 SSE 流式测试)."""
import httpx

transport = httpx.ASGITransport(app=e2e_app)
async with httpx.AsyncClient(
transport=transport, base_url="http://test", timeout=60.0
) as client:
yield client
Loading
Loading