Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed Diffguard.zip
Binary file not shown.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ services:
context: ./services/gateway
dockerfile: Dockerfile
ports:
- "8090:8080" # webhook server
- "8080:8080" # webhook server
- "9090:9090" # tool server
- "9091:9091" # metrics (prometheus)
environment:
Expand Down
2 changes: 1 addition & 1 deletion services/agent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
COPY pyproject.toml ./

# Install dependencies
RUN uv pip install --system -i https://pypi.tuna.tsinghua.edu.cn/simple -r pyproject.toml
RUN uv pip install --system -r pyproject.toml

# Copy application code
COPY app/ ./app/
Expand Down
2 changes: 1 addition & 1 deletion services/agent/app/agent/multi_agent_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from app.agent.base import AgentReviewResult
from app.agent.memory import AgentMemory
from app.agent.utils import create_llm as _create_llm
from app.agent.pipeline_orchestrator import _create_llm
from app.agent.registry import AgentRegistry
from app.agent.strategy_planner import StrategyPlanner
from app.models.schemas import (
Expand Down
2 changes: 1 addition & 1 deletion services/agent/app/agent/pipeline/stages/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from app.models.schemas import IssuePayload
from app.agent.pipeline.stages.base import PipelineContext, PipelineStage
from app.agent.utils import load_prompt as _load_prompt
from app.agent.pipeline_orchestrator import _load_prompt

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion services/agent/app/agent/pipeline/stages/reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from app.models.schemas import IssuePayload
from app.agent.pipeline.stages.base import PipelineContext, PipelineStage
from app.agent.utils import load_prompt as _load_prompt
from app.agent.pipeline_orchestrator import _load_prompt
from app.tools.definitions import (
make_call_graph_tool,
make_diff_context_tool,
Expand Down
2 changes: 1 addition & 1 deletion services/agent/app/agent/pipeline/stages/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel, Field

from app.agent.pipeline.stages.base import PipelineContext, PipelineStage
from app.agent.utils import load_prompt as _load_prompt
from app.agent.pipeline_orchestrator import _load_prompt

logger = logging.getLogger(__name__)

Expand Down
55 changes: 54 additions & 1 deletion services/agent/app/agent/pipeline_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import logging
from typing import Any

from app.agent.utils import create_llm as _create_llm, load_prompt as _load_prompt
from langchain_core.language_models.chat_models import BaseChatModel

from app.agent.pipeline.stages.base import PipelineContext, PipelineStage
from app.agent.pipeline.stages.summary import SummaryStage
from app.agent.pipeline.stages.reviewer import ReviewerStage
Expand All @@ -24,6 +25,55 @@

logger = logging.getLogger(__name__)

_PROMPTS_DIR = "app/prompts"


# --- Shared utilities (used by stages and agents) ---


def _create_llm(config: Any) -> BaseChatModel:
"""Create a LangChain ChatModel from the LLM config in the request."""
llm_cfg = config.llm_config
if llm_cfg.provider == "claude":
from langchain_anthropic import ChatAnthropic

kwargs: dict[str, Any] = {
"model": llm_cfg.model,
"max_tokens": llm_cfg.max_tokens,
"temperature": llm_cfg.temperature,
"timeout": llm_cfg.timeout_seconds,
}
if llm_cfg.api_key:
kwargs["api_key"] = llm_cfg.api_key
if llm_cfg.base_url:
kwargs["anthropic_api_url"] = llm_cfg.base_url
return ChatAnthropic(**kwargs)
else:
from langchain_openai import ChatOpenAI

kwargs = {
"model": llm_cfg.model,
"max_tokens": llm_cfg.max_tokens,
"temperature": llm_cfg.temperature,
"timeout": llm_cfg.timeout_seconds,
}
if llm_cfg.api_key:
kwargs["api_key"] = llm_cfg.api_key
if llm_cfg.base_url:
kwargs["base_url"] = llm_cfg.base_url
return ChatOpenAI(**kwargs)


def _load_prompt(name: str) -> str:
"""Load a prompt template from the prompts directory."""
from pathlib import Path

path = Path(_PROMPTS_DIR) / name
return path.read_text(encoding="utf-8")


# --- Default pipeline builder ---


def build_default_pipeline() -> list[PipelineStage]:
"""Build the standard 3-stage review pipeline."""
Expand All @@ -34,6 +84,9 @@ def build_default_pipeline() -> list[PipelineStage]:
]


# --- Orchestrator ---


class PipelineOrchestrator:
"""Composable multi-stage pipeline orchestrator.

Expand Down
49 changes: 0 additions & 49 deletions services/agent/app/agent/utils.py

This file was deleted.

1 change: 0 additions & 1 deletion services/gateway/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ COPY target/diffguard-*.jar diffguard.jar
EXPOSE 8080 9090

ENTRYPOINT ["java", "-jar", "diffguard.jar"]
CMD ["server"]
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void run() {
*/
public static int execute(String[] args) {
DiffGuardMain main = new DiffGuardMain();
int exitCode = new CommandLine(main).execute(args);
return exitCode != 0 ? exitCode : main.getExitCode();
new CommandLine(main).execute(args);
return main.getExitCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ public ReviewResult executeBatch(List<PromptBuilder.PromptContent> prompts,
Function<PromptBuilder.PromptContent, LlmResponse> singlePromptRunner) throws LlmApiException {
ReviewResult result = new ReviewResult();

if (prompts.isEmpty()) {
return result;
}

List<Future<LlmResponse>> futures = new ArrayList<>();
for (int i = 0; i < prompts.size(); i++) {
final int idx = i;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,7 @@ public static ReviewTaskMessage fromJson(byte[] data) {
*/
public static String extractTaskId(byte[] data) {
try {
if (data == null || data.length == 0) {
return null;
}
String value = MAPPER.readTree(data).path("task_id").asText(null);
return value != null && !value.isEmpty() ? value : null;
return MAPPER.readTree(data).path("task_id").asText();
} catch (Exception e) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ private PromptLoader() {}
* @return 模板内容
*/
public static String load(String resourcePath, String fallback) {
if (resourcePath == null || resourcePath.isEmpty()) {
return fallback;
}
try (InputStream is = PromptLoader.class.getResourceAsStream(resourcePath)) {
if (is == null) {
log.warn("Prompt 模板未找到: {},使用回退内容", resourcePath);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void doesNotRemoveNonDiffGuardHooks() throws IOException {
// Create a non-DiffGuard hook
Path hooksDir = gitRepo.resolve(".git/hooks");
Path customHook = hooksDir.resolve("pre-commit");
String customContent = "#!/bin/sh\n# Custom hook - my own script\necho custom";
String customContent = "#!/bin/sh\n# Custom hook - not DiffGuard\necho custom";
Files.writeString(customHook, customContent);

GitHookInstaller.uninstall(gitRepo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void findsCallers() {
ToolResult result = tool.execute("callers process", context);

assertTrue(result.isSuccess());
assertTrue(result.getOutput().contains("handle"));
assertTrue(result.getOutput().contains("Controller.handle"));
}

@Test
Expand Down Expand Up @@ -152,7 +152,7 @@ void computesImpactForClassDotMethod() {
ToolResult result = tool.execute("impact Service.process", context);

assertTrue(result.isSuccess());
assertTrue(result.getOutput().contains("handle"));
assertTrue(result.getOutput().contains("Controller.handle"));
}

@Test
Expand All @@ -161,7 +161,7 @@ void computesImpactForMethodName() {
ToolResult result = tool.execute("impact save", context);

assertTrue(result.isSuccess());
assertTrue(result.getOutput().contains("process"));
assertTrue(result.getOutput().contains("Service.process"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ class RegisterProvider {
@Test
@DisplayName("注册自定义 Provider 后可获取")
void registerAndRetrieve() {
// Create a mock provider for Rust
LanguageASTProvider rustProvider = new LanguageASTProvider() {
// Create a mock provider for Python
LanguageASTProvider pythonProvider = new LanguageASTProvider() {
@Override
public Language language() { return Language.RUST; }
public Language language() { return Language.PYTHON; }

@Override
public List<ASTNodeInfo> parse(String sourceCode, String filePath) {
Expand All @@ -127,8 +127,8 @@ public List<CallEdgeInfo> extractCallEdges(String sourceCode, String filePath) {
}
};

ASTProviderRegistry.register(rustProvider);
Optional<LanguageASTProvider> found = ASTProviderRegistry.getProvider(Language.RUST);
ASTProviderRegistry.register(pythonProvider);
Optional<LanguageASTProvider> found = ASTProviderRegistry.getProvider(Language.PYTHON);

assertTrue(found.isPresent());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ class ReviewTest {
@Test
@DisplayName("数据库未启用时抛出异常")
void databaseNotEnabledThrows() {
ReviewConfig.LlmConfig llmConfig = new ReviewConfig.LlmConfig();
llmConfig.setTimeoutSeconds(30);
when(config.getLlm()).thenReturn(llmConfig);

ReviewConfig.DatabaseConfigHolder dbConfig = mock(ReviewConfig.DatabaseConfigHolder.class);
when(dbConfig.isEnabled()).thenReturn(false);
when(config.getDatabase()).thenReturn(dbConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ void emptyContentReturnsEmpty() {
@DisplayName("多条规则同时触发")
void multipleRulesTriggered() {
String content = "Runtime.getRuntime().exec(cmd);\n" +
"String query = \" + userId + \" FROM users WHERE id = 1;";
"String query = \"SELECT * FROM users WHERE id = \" + userId;";
List<ReviewIssue> issues = engine.scan(List.of(entry("Service.java", content)));

assertTrue(issues.size() >= 2, "Should have at least 2 issues from different rules");
Expand All @@ -279,7 +279,7 @@ void multipleRulesTriggered() {
@Test
@DisplayName("多个文件分别扫描")
void multipleFilesScanned() {
String sqlContent = "String query = \" + id + \" FROM t WHERE id = 1;";
String sqlContent = "String query = \"SELECT * FROM t WHERE id = \" + id;";
String safeContent = "int x = 1;";
List<ReviewIssue> issues = engine.scan(List.of(
entry("A.java", sqlContent),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ void returnsNonNullConfig() {

assertNotNull(config);
assertNotNull(config.getLlm());
// application.yml 中配置的是 claude provider
assertEquals("claude", config.getLlm().getProvider());
// application.yml 中配置的是 openai provider
assertEquals("openai", config.getLlm().getProvider());
}

@Test
Expand Down Expand Up @@ -119,8 +119,8 @@ void fallsBackToDefaults() throws Exception {
ReviewConfig config = ConfigLoader.load(tempDir);

assertNotNull(config);
// application.yml 中默认 provider 是 claude
assertEquals("claude", config.getLlm().getProvider());
// application.yml 中默认 provider 是 openai
assertEquals("openai", config.getLlm().getProvider());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ void saveResultTruncatesLongFilePath() throws SQLException {

repository.saveResult("task-5", result);

verify(preparedStatement).setString(eq(4), argThat(s -> s != null && s.length() <= 500));
verify(preparedStatement).setString(4, argThat(s -> s != null && s.length() <= 500));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ void updateErrorTruncatesLongMessage() throws SQLException {
String longMessage = "x".repeat(5000);
repository.updateError("task-11", longMessage);

verify(preparedStatement).setString(eq(1), argThat(msg ->
verify(preparedStatement).setString(1, argThat(msg ->
msg != null && msg.length() <= 4000));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void rulesSectionPopulated() {
String userPrompt = prompts.get(0).getUserPrompt();
// 默认配置启用 security, bug-risk, code-style, performance
assertTrue(userPrompt.contains("安全"));
assertTrue(userPrompt.contains("逻辑"));
assertTrue(userPrompt.contains("逻辑错误"));
assertTrue(userPrompt.contains("代码质量"));
assertTrue(userPrompt.contains("性能"));
}
Expand All @@ -89,7 +89,7 @@ void languageSettingPopulated() {
DiffFileEntry entry = makeEntry("A.java", "content");
List<PromptBuilder.PromptContent> prompts = builder.buildPrompts(List.of(entry));

assertFalse(prompts.get(0).getUserPrompt().isEmpty());
assertTrue(prompts.get(0).getUserPrompt().contains("zh"));
}

@Test
Expand Down Expand Up @@ -122,6 +122,7 @@ void systemPromptNotEmpty() {

String sys = prompts.get(0).getSystemPrompt();
assertFalse(sys.isBlank());
assertTrue(sys.contains("JSON"));
}
}

Expand Down
Loading
Loading