-
Notifications
You must be signed in to change notification settings - Fork 19
Set up generator class and auth for Data Engineering Agent #404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
minzznguyen
wants to merge
11
commits into
GoogleCloudPlatform:main
Choose a base branch
from
minzznguyen:data_engineering_agent
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
aee3cd5
feat(dea): define DataEngineeringAgentGenerator and integrate A2A SDK…
f9eb773
Merge remote-tracking branch 'origin/main' into data_engineering_agent
008abf9
feat(dea): integrate A2A SDK dependency and add setup unit tests
cbad0bc
refactor(dea): remove hardcoded defaults and add config validation
e1bc087
style: remove comments from data_engineering_agent.py
61b7009
test(dea): verify credential scheme check and clean up imports
b4bcb2a
fix(dea): revert relative import of QueryGenerator in factory
d752578
feat(dea): throw auth errors and add error resilience tests
3c4e234
style: resolve all pycodestyle violations in tests
5c338fe
Merge branch 'main' into data_engineering_agent
minzznguyen 74f2933
feat(dea): make GCP ADC credential retrieval thread-safe and non-bloc…
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| import asyncio | ||
| import logging | ||
| from typing import Any | ||
|
|
||
| from a2a.client import ClientCallContext | ||
| from a2a.client.auth import AuthInterceptor, CredentialService | ||
| import google.auth | ||
| from google.auth.exceptions import DefaultCredentialsError, RefreshError | ||
| from google.auth.transport.requests import Request | ||
|
|
||
| from .generator import QueryGenerator | ||
|
|
||
|
|
||
| class GcpAdcCredentialService(CredentialService): | ||
| """GCP Application Default Credentials (ADC) service for A2A SDK. | ||
|
|
||
| This provider is concurrency-safe, non-blocking, and dynamically caches | ||
| tokens to prevent redundant refreshes on the hot path. It intentionally | ||
| only services OAuth/OAuth2 schemes. | ||
| """ | ||
|
|
||
| def __init__(self): | ||
| self.logger = logging.getLogger(__name__) | ||
| self.credentials = None | ||
| self._lock = None | ||
|
|
||
| async def get_credentials( | ||
| self, | ||
| security_scheme_name: str, | ||
| context: ClientCallContext | None, | ||
| ) -> str: | ||
| if security_scheme_name.lower() not in ("oauth", "oauth2"): | ||
| raise ValueError( | ||
| f"GcpAdcCredentialService only services 'oauth' or 'oauth2' " | ||
| f"schemes, got '{security_scheme_name}'" | ||
| ) | ||
|
|
||
| if self._lock is None: | ||
| self._lock = asyncio.Lock() | ||
|
|
||
| try: | ||
| async with self._lock: | ||
| if self.credentials is None: | ||
| credentials, _ = await asyncio.to_thread( | ||
| google.auth.default, | ||
| scopes=[ | ||
| "https://www.googleapis.com/auth/cloud-platform" | ||
| ] | ||
| ) | ||
| self.credentials = credentials | ||
|
|
||
| if not self.credentials.valid: | ||
| await asyncio.to_thread( | ||
| self.credentials.refresh, Request() | ||
| ) | ||
|
|
||
| self.logger.debug("Retrieved GCP ADC token successfully.") | ||
| return self.credentials.token | ||
|
|
||
| except (DefaultCredentialsError, RefreshError) as e: | ||
| self.logger.error( | ||
| "Failed to retrieve or refresh GCP Application Default " | ||
| "Credentials: %s", | ||
| e, | ||
| ) | ||
| raise | ||
| except Exception as e: | ||
| self.logger.exception( | ||
| "Unexpected error while fetching GCP ADC credentials: %s", e | ||
| ) | ||
| raise | ||
|
|
||
|
|
||
| class DataEngineeringAgentGenerator(QueryGenerator): | ||
| """Data Engineering Agent (DEA) Query Generator using the A2A SDK.""" | ||
|
|
||
| def __init__(self, querygenerator_config: dict[str, Any]): | ||
| super().__init__(querygenerator_config) | ||
| self.name = "data_engineering_agent" | ||
| self.endpoint = querygenerator_config.get("endpoint", "") | ||
| self.target_workspace = querygenerator_config.get( | ||
| "target_workspace", "" | ||
| ) | ||
|
|
||
| if not self.endpoint: | ||
| raise ValueError( | ||
| "Configuration key 'endpoint' is required for " | ||
| "DataEngineeringAgentGenerator." | ||
| ) | ||
| if not self.target_workspace: | ||
| raise ValueError( | ||
| "Configuration key 'target_workspace' is required for " | ||
| "DataEngineeringAgentGenerator." | ||
| ) | ||
|
|
||
| self.logger = logging.getLogger(__name__) | ||
|
|
||
| self.auth_interceptor = AuthInterceptor(GcpAdcCredentialService()) | ||
| self.logger.info( | ||
| "A2A AuthInterceptor successfully configured with " | ||
| "GcpAdcCredentialService." | ||
| ) | ||
|
|
||
| def generate_internal(self, prompt: str) -> Any: | ||
| """Stubbed messaging logic for WIP scaffolding (Task 1.3).""" | ||
| raise NotImplementedError( | ||
| "Task 1.3 DEA A2A messaging logic in generate_internal is " | ||
| "not yet implemented." | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| import os | ||
| import sys | ||
| from unittest.mock import MagicMock, patch | ||
| import pytest | ||
| from google.auth.exceptions import DefaultCredentialsError, RefreshError | ||
|
|
||
| # Add generators path to system path | ||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
|
|
||
| from generators.models import get_generator # noqa: E402 | ||
| from generators.models.data_engineering_agent import ( # noqa: E402 | ||
| DataEngineeringAgentGenerator, | ||
| GcpAdcCredentialService, | ||
| ) | ||
|
|
||
|
|
||
| def test_data_engineering_agent_generator_setup(): | ||
| config = { | ||
| "generator": "data_engineering_agent", | ||
| "endpoint": ( | ||
| "https://geminidataanalytics.googleapis.com/v1/a2a/" | ||
| "projects/test/locations/us-west4/agents/" | ||
| "dataengineeringagent" | ||
| ), | ||
| "target_workspace": ( | ||
| "projects/test/locations/us-west4/repositories/" | ||
| "test-repo/workspaces/test-workspace" | ||
| ), | ||
| } | ||
|
|
||
| # Mock google.auth.default during initialization | ||
| with patch("google.auth.default") as mock_auth_default: | ||
| mock_creds = MagicMock() | ||
| mock_creds.valid = True | ||
| mock_auth_default.return_value = (mock_creds, "test-project") | ||
|
|
||
| generator = DataEngineeringAgentGenerator(config) | ||
|
|
||
| assert generator.name == "data_engineering_agent" | ||
| assert generator.endpoint == config["endpoint"] | ||
| assert generator.target_workspace == config["target_workspace"] | ||
| assert generator.auth_interceptor is not None | ||
|
|
||
|
|
||
| @pytest.mark.anyio | ||
| async def test_get_credentials_invalid_scheme(): | ||
| service = GcpAdcCredentialService() | ||
|
|
||
| with pytest.raises(ValueError) as excinfo: | ||
| await service.get_credentials("basic", None) | ||
|
|
||
| assert "only services 'oauth' or 'oauth2'" in str(excinfo.value) | ||
|
|
||
|
|
||
| def test_generator_setup_missing_endpoint(): | ||
| config = { | ||
| "generator": "data_engineering_agent", | ||
| "target_workspace": "projects/test-workspace", | ||
| } | ||
| with pytest.raises(ValueError) as excinfo: | ||
| DataEngineeringAgentGenerator(config) | ||
| assert "endpoint' is required" in str(excinfo.value) | ||
|
|
||
|
|
||
| def test_generator_setup_missing_workspace(): | ||
| config = { | ||
| "generator": "data_engineering_agent", | ||
| "endpoint": ( | ||
| "https://geminidataanalytics.googleapis.com/v1/a2a/" | ||
| "projects/test/locations/us-west4/agents/" | ||
| "dataengineeringagent" | ||
| ), | ||
| } | ||
| with pytest.raises(ValueError) as excinfo: | ||
| DataEngineeringAgentGenerator(config) | ||
| assert "target_workspace' is required" in str(excinfo.value) | ||
|
|
||
|
|
||
| @pytest.mark.anyio | ||
| @patch("google.auth.default") | ||
| async def test_get_credentials_error_resiliency_default(mock_auth_default): | ||
| mock_auth_default.side_effect = DefaultCredentialsError( | ||
| "Credentials missing." | ||
| ) | ||
| service = GcpAdcCredentialService() | ||
|
|
||
| with pytest.raises(DefaultCredentialsError): | ||
| await service.get_credentials("oauth", None) | ||
|
|
||
|
|
||
| @pytest.mark.anyio | ||
| @patch("google.auth.default") | ||
| async def test_get_credentials_error_resiliency_refresh(mock_auth_default): | ||
| mock_creds = MagicMock() | ||
| mock_creds.valid = False | ||
| mock_creds.refresh.side_effect = RefreshError("Network timed out.") | ||
| mock_auth_default.return_value = (mock_creds, "test-project") | ||
|
|
||
| service = GcpAdcCredentialService() | ||
|
|
||
| with pytest.raises(RefreshError): | ||
| await service.get_credentials("oauth", None) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,6 +50,7 @@ dependencies = [ | |
| "dbt-core", | ||
| "dbt-bigquery", | ||
| "dbt-postgres", | ||
| "a2a-sdk>=1.0.3", | ||
| ] | ||
|
|
||
| [tool.uv.sources] | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's also add tests for the bad paths, i.e. when endpoint and/or target_workspace are missing and ensure the value error is being raised
similarly, let's also try to cover the bad paths in GcpAdcCredentialService.get_credentials