Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions evalbench/generators/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .gemini_cli import GeminiCliGenerator
from .claude_code import ClaudeCodeGenerator
from .codex_cli import CodexCliGenerator
from .data_engineering_agent import DataEngineeringAgentGenerator
from util.config import load_yaml_config


Expand Down Expand Up @@ -42,6 +43,8 @@ def get_generator(global_models, model_config_path: str, db: DB = None):
model = ClaudeCodeGenerator(config)
if config["generator"] == "codex_cli":
model = CodexCliGenerator(config)
if config["generator"] == "data_engineering_agent":
model = DataEngineeringAgentGenerator(config)
if not model:
raise ValueError(f"Unknown Generator {config['generator']}")

Expand Down
109 changes: 109 additions & 0 deletions evalbench/generators/models/data_engineering_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import asyncio
import logging
from typing import Any

from a2a.client import ClientCallContext
from a2a.client.auth import AuthInterceptor, CredentialService
import google.auth
from google.auth.exceptions import DefaultCredentialsError, RefreshError
from google.auth.transport.requests import Request

from .generator import QueryGenerator


class GcpAdcCredentialService(CredentialService):
    """Supply GCP Application Default Credentials (ADC) tokens to the A2A SDK.

    The credentials object returned by ``google.auth.default`` is kept on
    the instance after the first successful lookup and is refreshed only
    when it reports itself as not valid. Both the lookup and the refresh
    are dispatched through ``asyncio.to_thread`` and are serialized by an
    ``asyncio.Lock``. Only the scheme names ``oauth`` and ``oauth2``
    (compared case-insensitively) are served.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # Populated by the first successful get_credentials() call.
        self.credentials = None
        # Created lazily inside get_credentials(), i.e. from within a
        # coroutine rather than at construction time.
        self._lock = None

    async def _token_under_lock(self) -> str:
        """Load and/or refresh the cached credentials; return their token."""
        async with self._lock:
            if self.credentials is None:
                adc, _project = await asyncio.to_thread(
                    google.auth.default,
                    scopes=[
                        "https://www.googleapis.com/auth/cloud-platform"
                    ],
                )
                self.credentials = adc

            if not self.credentials.valid:
                await asyncio.to_thread(self.credentials.refresh, Request())

            self.logger.debug("Retrieved GCP ADC token successfully.")
            return self.credentials.token

    async def get_credentials(
        self,
        security_scheme_name: str,
        context: ClientCallContext | None,
    ) -> str:
        """Return an access token for the given security scheme.

        Args:
            security_scheme_name: Name of the scheme a token is requested
                for; must be ``oauth`` or ``oauth2`` in any letter case.
            context: Call context handed over by the caller; not read here.

        Returns:
            The ``token`` attribute of the cached credentials object.

        Raises:
            ValueError: If the scheme name is neither ``oauth`` nor
                ``oauth2``.
            DefaultCredentialsError: Propagated from the ADC lookup after
                being logged.
            RefreshError: Propagated from the credential refresh after
                being logged.
        """
        scheme = security_scheme_name.lower()
        if scheme != "oauth" and scheme != "oauth2":
            raise ValueError(
                f"GcpAdcCredentialService only services 'oauth' or 'oauth2' "
                f"schemes, got '{security_scheme_name}'"
            )

        if self._lock is None:
            self._lock = asyncio.Lock()

        try:
            return await self._token_under_lock()
        except (DefaultCredentialsError, RefreshError) as auth_error:
            # Expected credential failures: log without a traceback and
            # let the caller see the original exception.
            self.logger.error(
                "Failed to retrieve or refresh GCP Application Default "
                "Credentials: %s",
                auth_error,
            )
            raise
        except Exception as unexpected:
            # Anything else is logged with its traceback, then re-raised.
            self.logger.exception(
                "Unexpected error while fetching GCP ADC credentials: %s",
                unexpected,
            )
            raise


class DataEngineeringAgentGenerator(QueryGenerator):
    """Query generator for the Data Engineering Agent (DEA) via the A2A SDK.

    Construction validates the required configuration keys and prepares an
    ``AuthInterceptor`` backed by ``GcpAdcCredentialService``. Message
    exchange with the agent is not implemented yet.
    """

    def __init__(self, querygenerator_config: dict[str, Any]):
        """Read and validate the generator configuration.

        Args:
            querygenerator_config: Generator settings; must contain
                non-empty ``endpoint`` and ``target_workspace`` entries.

        Raises:
            ValueError: If ``endpoint`` or ``target_workspace`` is missing
                or empty.
        """
        super().__init__(querygenerator_config)
        self.name = "data_engineering_agent"

        endpoint = querygenerator_config.get("endpoint", "")
        workspace = querygenerator_config.get("target_workspace", "")
        self.endpoint = endpoint
        self.target_workspace = workspace

        # Guard clauses: endpoint is checked first, then the workspace.
        if not endpoint:
            raise ValueError(
                "Configuration key 'endpoint' is required for "
                "DataEngineeringAgentGenerator."
            )
        if not workspace:
            raise ValueError(
                "Configuration key 'target_workspace' is required for "
                "DataEngineeringAgentGenerator."
            )

        self.logger = logging.getLogger(__name__)

        credential_service = GcpAdcCredentialService()
        self.auth_interceptor = AuthInterceptor(credential_service)
        self.logger.info(
            "A2A AuthInterceptor successfully configured with "
            "GcpAdcCredentialService."
        )

    def generate_internal(self, prompt: str) -> Any:
        """Placeholder for the DEA messaging logic (Task 1.3).

        Raises:
            NotImplementedError: Always; the messaging logic is still a
                stub.
        """
        raise NotImplementedError(
            "Task 1.3 DEA A2A messaging logic in generate_internal is "
            "not yet implemented."
        )
102 changes: 102 additions & 0 deletions evalbench/test/data_engineering_agent_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
from google.auth.exceptions import DefaultCredentialsError, RefreshError

# Add generators path to system path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from generators.models import get_generator # noqa: E402
from generators.models.data_engineering_agent import ( # noqa: E402
DataEngineeringAgentGenerator,
GcpAdcCredentialService,
)


def test_data_engineering_agent_generator_setup():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also add tests for the bad paths, i.e., when `endpoint` and/or `target_workspace` are missing, and ensure that a `ValueError` is raised.

Similarly, let's also try to cover the bad paths in `GcpAdcCredentialService.get_credentials`.

config = {
"generator": "data_engineering_agent",
"endpoint": (
"https://geminidataanalytics.googleapis.com/v1/a2a/"
"projects/test/locations/us-west4/agents/"
"dataengineeringagent"
),
"target_workspace": (
"projects/test/locations/us-west4/repositories/"
"test-repo/workspaces/test-workspace"
),
}

# Mock google.auth.default during initialization
with patch("google.auth.default") as mock_auth_default:
mock_creds = MagicMock()
mock_creds.valid = True
mock_auth_default.return_value = (mock_creds, "test-project")

generator = DataEngineeringAgentGenerator(config)

assert generator.name == "data_engineering_agent"
assert generator.endpoint == config["endpoint"]
assert generator.target_workspace == config["target_workspace"]
assert generator.auth_interceptor is not None


@pytest.mark.anyio
async def test_get_credentials_invalid_scheme():
    """A scheme name other than oauth/oauth2 is rejected with ValueError."""
    credential_service = GcpAdcCredentialService()

    with pytest.raises(ValueError) as raised:
        await credential_service.get_credentials("basic", None)

    message = str(raised.value)
    assert "only services 'oauth' or 'oauth2'" in message


def test_generator_setup_missing_endpoint():
    """Constructing the generator without 'endpoint' raises ValueError."""
    config_without_endpoint = {
        "generator": "data_engineering_agent",
        "target_workspace": "projects/test-workspace",
    }

    with pytest.raises(ValueError) as raised:
        DataEngineeringAgentGenerator(config_without_endpoint)

    # Checked after the context manager exits so the assertion is reached.
    message = str(raised.value)
    assert "endpoint' is required" in message


def test_generator_setup_missing_workspace():
    """Constructing the generator without 'target_workspace' raises ValueError."""
    agent_endpoint = (
        "https://geminidataanalytics.googleapis.com/v1/a2a/"
        "projects/test/locations/us-west4/agents/"
        "dataengineeringagent"
    )
    config_without_workspace = {
        "generator": "data_engineering_agent",
        "endpoint": agent_endpoint,
    }

    with pytest.raises(ValueError) as raised:
        DataEngineeringAgentGenerator(config_without_workspace)

    # Checked after the context manager exits so the assertion is reached.
    message = str(raised.value)
    assert "target_workspace' is required" in message


@pytest.mark.anyio
@patch("google.auth.default")
async def test_get_credentials_error_resiliency_default(mock_auth_default):
    """A DefaultCredentialsError from the ADC lookup reaches the caller."""
    lookup_failure = DefaultCredentialsError("Credentials missing.")
    mock_auth_default.side_effect = lookup_failure

    credential_service = GcpAdcCredentialService()

    with pytest.raises(DefaultCredentialsError):
        await credential_service.get_credentials("oauth", None)


@pytest.mark.anyio
@patch("google.auth.default")
async def test_get_credentials_error_resiliency_refresh(mock_auth_default):
    """A RefreshError raised while refreshing reaches the caller."""
    # Credentials that report as not valid force the refresh path, and the
    # refresh itself is rigged to fail.
    stale_credentials = MagicMock()
    stale_credentials.valid = False
    stale_credentials.refresh.side_effect = RefreshError("Network timed out.")
    mock_auth_default.return_value = (stale_credentials, "test-project")

    credential_service = GcpAdcCredentialService()

    with pytest.raises(RefreshError):
        await credential_service.get_credentials("oauth", None)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ dependencies = [
"dbt-core",
"dbt-bigquery",
"dbt-postgres",
"a2a-sdk>=1.0.3",
]

[tool.uv.sources]
Expand Down
Loading