Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions graph_net/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
from Hugging Face models.
"""

from graph_net.agent.graph_net_agent import GraphNetAgent
from graph_net.agent.graph_net_agent import ExtractionStatus, GraphNetAgent

__all__ = ["GraphNetAgent"]
__all__ = ["GraphNetAgent", "ExtractionStatus"]
6 changes: 3 additions & 3 deletions graph_net/agent/graph_extractor/subprocess_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
class SubprocessGraphExtractor(BaseGraphExtractor):
"""Extractor that runs script in subprocess"""

def __init__(self, workspace: str, timeout: int = DEFAULT_TIMEOUT):
def __init__(self, workspace: str, timeout: int | None = None):
"""
Args:
workspace: Workspace root directory
timeout: Timeout in seconds for script execution
timeout: Timeout in seconds for script execution (default 1000s)
"""
self.workspace = Path(workspace)
self.timeout = timeout
self.timeout = timeout if timeout is not None else DEFAULT_TIMEOUT
self.logger = logging.getLogger(self.__class__.__name__)

def extract(self, code_path: Path, model_id: str) -> Path:
Expand Down
53 changes: 38 additions & 15 deletions graph_net/agent/graph_net_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import os
from enum import Enum
from pathlib import Path
from typing import Optional

Expand All @@ -23,6 +24,15 @@
from graph_net.agent.sample_verifier import ForwardVerifier


class ExtractionStatus(str, Enum):
    """Extraction result status for a single model.

    The ``str`` mix-in makes every member an instance of ``str`` that compares
    equal to its raw value (e.g. ``ExtractionStatus.OK == "ok"``), so a status
    can be compared against, or looked up from, a plain string.
    """

    # Extraction and verification both passed.
    OK = "ok"
    # Extraction succeeded but verification of the sample failed.
    VERIFY_FAILED = "verify_failed"
    # Extraction (or a pre-extraction step) failed.
    EXTRACT_FAILED = "extract_failed"
    # An unexpected error occurred.
    ERROR = "error"


class GraphNetAgent:
"""GraphNet automatic sample extraction agent"""

Expand All @@ -31,16 +41,22 @@ def __init__(
workspace: Optional[str] = None,
hf_token: Optional[str] = None,
llm_retry: bool = True,
extract_timeout: Optional[int] = None,
verify_timeout: Optional[int] = None,
):
"""
Initialize GraphNet Agent

Args:
workspace: Workspace root directory. Defaults to
$GRAPH_NET_EXTRACT_WORKSPACE or ~/graphnet_workspace.
hf_token: HuggingFace API token (optional)
llm_retry: If True and ducc/claude CLI is available, retry failed
extractions up to 2 times with LLM-fixed scripts.
workspace: Workspace root directory. Defaults to
$GRAPH_NET_EXTRACT_WORKSPACE or ~/graphnet_workspace.
hf_token: HuggingFace API token (optional)
llm_retry: If True and ducc/claude CLI is available, retry failed
extractions up to 2 times with LLM-fixed scripts.
extract_timeout: Timeout in seconds for graph extraction subprocess
(default None -> 1000s).
verify_timeout: Timeout in seconds for forward verification subprocess
(default None -> 300s).
"""
if workspace is None:
workspace = os.environ.get(
Expand All @@ -63,14 +79,15 @@ def __init__(
self.metadata_analyzer = ConfigMetadataAnalyzer()
self.code_generator = TemplateCodeGenerator()
self.graph_extractor = SubprocessGraphExtractor(
workspace=str(self.workspace.workspace_root)
workspace=str(self.workspace.workspace_root),
timeout=extract_timeout,
)
self.sample_verifier = ForwardVerifier()
self.sample_verifier = ForwardVerifier(timeout=verify_timeout)

# LLM fixer — only created when llm_retry is requested
self.llm_fixer: Optional[LLMCodeFixer] = LLMCodeFixer() if llm_retry else None

def extract_sample(self, model_id: str) -> bool:
def extract_sample(self, model_id: str) -> ExtractionStatus:
"""
Execute complete sample extraction pipeline from HuggingFace model ID.

Expand All @@ -82,7 +99,10 @@ def extract_sample(self, model_id: str) -> bool:
model_id: HuggingFace model ID (e.g., "bert-base-uncased")

Returns:
True if sample extraction succeeded, False otherwise
ExtractionStatus.OK – extraction and verification both passed
ExtractionStatus.VERIFY_FAILED – extraction succeeded but verification failed
ExtractionStatus.EXTRACT_FAILED – extraction (or pre-extraction) failed
ExtractionStatus.ERROR – unexpected error
"""
try:
self.logger.info(f"Starting extraction for model: {model_id}")
Expand All @@ -104,21 +124,24 @@ def extract_sample(self, model_id: str) -> bool:

if self.is_duplicate_sample(sample_dir):
self.logger.info("Duplicate sample detected, skipping verification")
return True
return ExtractionStatus.OK

if not self.sample_verifier.verify(sample_dir):
self.logger.error("Sample verification failed")
return False
return ExtractionStatus.VERIFY_FAILED

self.logger.info(f"Successfully extracted sample for {model_id}")
return True
return ExtractionStatus.OK

except (AnalysisError, CodeGenError, ExtractionError, VerificationError) as e:
except VerificationError as e:
self.logger.error(f"Extraction failed for {model_id}: {e}")
return False
return ExtractionStatus.VERIFY_FAILED
except (AnalysisError, CodeGenError, ExtractionError) as e:
self.logger.error(f"Extraction failed for {model_id}: {e}")
return ExtractionStatus.EXTRACT_FAILED
except Exception as e:
self.logger.error(f"Unexpected error for {model_id}: {e}", exc_info=True)
return False
return ExtractionStatus.ERROR

def _llm_retry(
self,
Expand Down
Loading
Loading