From 192a38a6b1ee658664ba67bab46c2933b7cd5d36 Mon Sep 17 00:00:00 2001 From: William Martin Date: Fri, 25 Apr 2025 22:14:34 +0200 Subject: [PATCH 1/2] Split PR review creation, commenting, submission and deletion --- e2e/e2e_test.go | 544 ++++++++++ go.mod | 3 + go.sum | 6 + internal/ghmcp/server.go | 12 + pkg/github/pullrequests.go | 989 ++++++++++++------ pkg/github/pullrequests_test.go | 568 ---------- pkg/github/server.go | 59 +- pkg/github/server_test.go | 126 ++- pkg/github/tools.go | 13 +- third-party-licenses.darwin.md | 3 + third-party-licenses.linux.md | 3 + third-party-licenses.windows.md | 3 + .../github.com/shurcooL/githubv4/LICENSE | 21 + .../github.com/shurcooL/graphql/LICENSE | 21 + third-party/golang.org/x/oauth2/LICENSE | 27 + 15 files changed, 1477 insertions(+), 921 deletions(-) create mode 100644 third-party/github.com/shurcooL/githubv4/LICENSE create mode 100644 third-party/github.com/shurcooL/graphql/LICENSE create mode 100644 third-party/golang.org/x/oauth2/LICENSE diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 489681e9..c609c2c5 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -772,3 +772,547 @@ func TestDirectoryDeletion(t *testing.T) { require.Equal(t, "test-dir/test-file.txt", trimmedGetCommitText.Files[0].Filename, "expected filename to match") require.Equal(t, 1, trimmedGetCommitText.Files[0].Deletions, "expected one deletion") } + +func TestPullRequestAtomicCreateAndSubmit(t *testing.T) { + t.Parallel() + + mcpClient := setupMCPClient(t) + + ctx := context.Background() + + // First, who am I + getMeRequest := mcp.CallToolRequest{} + getMeRequest.Params.Name = "get_me" + + t.Log("Getting current user...") + resp, err := mcpClient.CallTool(ctx, getMeRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + require.False(t, resp.IsError, "expected result not to be an error") + require.Len(t, resp.Content, 1, "expected content to have one item") + + textContent, ok := resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedGetMeText struct { + Login string `json:"login"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedGetMeText) + require.NoError(t, err, "expected to unmarshal text content successfully") + + currentOwner := trimmedGetMeText.Login + + // Then create a repository with a README (via autoInit) + repoName := fmt.Sprintf("github-mcp-server-e2e-%s-%d", t.Name(), time.Now().UnixMilli()) + createRepoRequest := mcp.CallToolRequest{} + createRepoRequest.Params.Name = "create_repository" + createRepoRequest.Params.Arguments = map[string]any{ + "name": repoName, + "private": true, + "autoInit": true, + } + + t.Logf("Creating repository %s/%s...", currentOwner, repoName) + _, err = mcpClient.CallTool(ctx, createRepoRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Cleanup the repository after the test + t.Cleanup(func() { + // MCP Server doesn't support deletions, but we can use the GitHub Client + ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t)) + t.Logf("Deleting repository %s/%s...", currentOwner, repoName) + _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName) + require.NoError(t, err, "expected to delete repository successfully") + }) + + // Create a branch on which to create a new commit + createBranchRequest := mcp.CallToolRequest{} + createBranchRequest.Params.Name = "create_branch" + createBranchRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "branch": "test-branch", + "from_branch": "main", + } + + t.Logf("Creating branch in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, createBranchRequest) + require.NoError(t, err, "expected to call 'create_branch' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Create a commit with a new file + commitRequest := mcp.CallToolRequest{} + commitRequest.Params.Name = "create_or_update_file" + commitRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "path": "test-file.txt", + "content": fmt.Sprintf("Created by e2e test %s", t.Name()), + "message": "Add test file", + "branch": "test-branch", + } + + t.Logf("Creating commit with new file in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, commitRequest) + require.NoError(t, err, "expected to call 'create_or_update_file' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedCommitText struct { + SHA string `json:"sha"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedCommitText) + require.NoError(t, err, "expected to unmarshal text content successfully") + commitId := trimmedCommitText.SHA + + // Create a pull request + prRequest := mcp.CallToolRequest{} + prRequest.Params.Name = "create_pull_request" + prRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "title": "Test PR", + "body": "This is a test PR", + "head": "test-branch", + "base": "main", + "commitId": commitId, + } + + t.Logf("Creating pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, prRequest) + require.NoError(t, err, "expected to call 'create_pull_request' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Create and submit a review + createAndSubmitReviewRequest := mcp.CallToolRequest{} + createAndSubmitReviewRequest.Params.Name = "create_and_submit_pull_request_review" + createAndSubmitReviewRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + "event": "COMMENT", // the only event we can use as the creator of the PR + "body": "Looks good if you like bad code I guess!", + "commitId": commitId, + } + + t.Logf("Creating and submitting review for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, createAndSubmitReviewRequest) + require.NoError(t, err, "expected to call 'create_and_submit_pull_request_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Finally, get the list of reviews and see that our review has been submitted + getPullRequestsReview := mcp.CallToolRequest{} + getPullRequestsReview.Params.Name = "get_pull_request_reviews" + getPullRequestsReview.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + } + + t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, getPullRequestsReview) + require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var reviews []struct { + State string `json:"state"` + } + err = json.Unmarshal([]byte(textContent.Text), &reviews) + require.NoError(t, err, "expected to unmarshal text content successfully") + + // Check that there is one review + require.Len(t, reviews, 1, "expected to find one review") + require.Equal(t, "COMMENTED", reviews[0].State, "expected review state to be COMMENTED") +} + +func TestPullRequestReviewCommentSubmit(t *testing.T) { + t.Parallel() + + mcpClient := setupMCPClient(t) + + ctx := context.Background() + + // First, who am I + getMeRequest := mcp.CallToolRequest{} + getMeRequest.Params.Name = "get_me" + + t.Log("Getting current user...") + resp, err := mcpClient.CallTool(ctx, getMeRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + require.False(t, resp.IsError, "expected result not to be an error") + require.Len(t, resp.Content, 1, "expected content to have one item") + + textContent, ok := resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedGetMeText struct { + Login string `json:"login"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedGetMeText) + require.NoError(t, err, "expected to unmarshal text content successfully") + + currentOwner := trimmedGetMeText.Login + + // Then create a repository with a README (via autoInit) + repoName := fmt.Sprintf("github-mcp-server-e2e-%s-%d", t.Name(), time.Now().UnixMilli()) + createRepoRequest := mcp.CallToolRequest{} + createRepoRequest.Params.Name = "create_repository" + createRepoRequest.Params.Arguments = map[string]any{ + "name": repoName, + "private": true, + "autoInit": true, + } + + t.Logf("Creating repository %s/%s...", currentOwner, repoName) + _, err = mcpClient.CallTool(ctx, createRepoRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Cleanup the repository after the test + t.Cleanup(func() { + // MCP Server doesn't support deletions, but we can use the GitHub Client + ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t)) + t.Logf("Deleting repository %s/%s...", currentOwner, repoName) + _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName) + require.NoError(t, err, "expected to delete repository successfully") + }) + + // Create a branch on which to create a new commit + createBranchRequest := mcp.CallToolRequest{} + createBranchRequest.Params.Name = "create_branch" + createBranchRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "branch": "test-branch", + "from_branch": "main", + } + + t.Logf("Creating branch in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, createBranchRequest) + require.NoError(t, err, "expected to call 'create_branch' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Create a commit with a new file + commitRequest := mcp.CallToolRequest{} + commitRequest.Params.Name = "create_or_update_file" + commitRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "path": "test-file.txt", + "content": fmt.Sprintf("Created by e2e test %s", t.Name()), + "message": "Add test file", + "branch": "test-branch", + } + + t.Logf("Creating commit with new file in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, commitRequest) + require.NoError(t, err, "expected to call 'create_or_update_file' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedCommitText struct { + SHA string `json:"sha"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedCommitText) + require.NoError(t, err, "expected to unmarshal text content successfully") + commitId := trimmedCommitText.SHA + + // Create a pull request + prRequest := mcp.CallToolRequest{} + prRequest.Params.Name = "create_pull_request" + prRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "title": "Test PR", + "body": "This is a test PR", + "head": "test-branch", + "base": "main", + "commitId": commitId, + } + + t.Logf("Creating pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, prRequest) + require.NoError(t, err, "expected to call 'create_pull_request' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Create a review for the pull request, but we can't approve it + // because the current owner also owns the PR. + createPendingPullRequestReviewRequest := mcp.CallToolRequest{} + createPendingPullRequestReviewRequest.Params.Name = "create_pending_pull_request_review" + createPendingPullRequestReviewRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + } + + t.Logf("Creating pending review for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, createPendingPullRequestReviewRequest) + require.NoError(t, err, "expected to call 'create_pending_pull_request_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + require.Equal(t, "", textContent.Text, "expected content to be empty") + + // Add a review comment + addReviewCommentRequest := mcp.CallToolRequest{} + addReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review" + addReviewCommentRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + "path": "test-file.txt", + "subjectType": "LINE", + "body": "Very nice!", + "line": 1, + } + + t.Logf("Adding review comment to pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, addReviewCommentRequest) + require.NoError(t, err, "expected to call 'add_pull_request_review_comment_to_pending_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Submit the review + submitReviewRequest := mcp.CallToolRequest{} + submitReviewRequest.Params.Name = "submit_pending_pull_request_review" + submitReviewRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + "event": "COMMENT", // the only event we can use as the creator of the PR + "body": "Looks good if you like bad code I guess!", + } + + t.Logf("Submitting review for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, submitReviewRequest) + require.NoError(t, err, "expected to call 'submit_pending_pull_request_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Finally, get the review and see that it has been created + getPullRequestsReview := mcp.CallToolRequest{} + getPullRequestsReview.Params.Name = "get_pull_request_reviews" + getPullRequestsReview.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + } + + t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, getPullRequestsReview) + require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var reviews []struct { + State string `json:"state"` + } + err = json.Unmarshal([]byte(textContent.Text), &reviews) + require.NoError(t, err, "expected to unmarshal text content successfully") + + // Check that there is one review + require.Len(t, reviews, 1, "expected to find one review") + require.Equal(t, "COMMENTED", reviews[0].State, "expected review state to be COMMENTED") +} + +func TestPullRequestReviewDeletion(t *testing.T) { + t.Parallel() + + mcpClient := setupMCPClient(t) + + ctx := context.Background() + + // First, who am I + getMeRequest := mcp.CallToolRequest{} + getMeRequest.Params.Name = "get_me" + + t.Log("Getting current user...") + resp, err := mcpClient.CallTool(ctx, getMeRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + require.False(t, resp.IsError, "expected result not to be an error") + require.Len(t, resp.Content, 1, "expected content to have one item") + + textContent, ok := resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedGetMeText struct { + Login string `json:"login"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedGetMeText) + require.NoError(t, err, "expected to unmarshal text content successfully") + + currentOwner := trimmedGetMeText.Login + + // Then create a repository with a README (via autoInit) + repoName := fmt.Sprintf("github-mcp-server-e2e-%s-%d", t.Name(), time.Now().UnixMilli()) + createRepoRequest := mcp.CallToolRequest{} + createRepoRequest.Params.Name = "create_repository" + createRepoRequest.Params.Arguments = map[string]any{ + "name": repoName, + "private": true, + "autoInit": true, + } + + t.Logf("Creating repository %s/%s...", currentOwner, repoName) + _, err = mcpClient.CallTool(ctx, createRepoRequest) + require.NoError(t, err, "expected to call 'get_me' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Cleanup the repository after the test + t.Cleanup(func() { + // MCP Server doesn't support deletions, but we can use the GitHub Client + ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t)) + t.Logf("Deleting repository %s/%s...", currentOwner, repoName) + _, err := ghClient.Repositories.Delete(context.Background(), currentOwner, repoName) + require.NoError(t, err, "expected to delete repository successfully") + }) + + // Create a branch on which to create a new commit + createBranchRequest := mcp.CallToolRequest{} + createBranchRequest.Params.Name = "create_branch" + createBranchRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "branch": "test-branch", + "from_branch": "main", + } + + t.Logf("Creating branch in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, createBranchRequest) + require.NoError(t, err, "expected to call 'create_branch' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Create a commit with a new file + commitRequest := mcp.CallToolRequest{} + commitRequest.Params.Name = "create_or_update_file" + commitRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "path": "test-file.txt", + "content": fmt.Sprintf("Created by e2e test %s", t.Name()), + "message": "Add test file", + "branch": "test-branch", + } + + t.Logf("Creating commit with new file in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, commitRequest) + require.NoError(t, err, "expected to call 'create_or_update_file' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var trimmedCommitText struct { + SHA string `json:"sha"` + } + err = json.Unmarshal([]byte(textContent.Text), &trimmedCommitText) + require.NoError(t, err, "expected to unmarshal text content successfully") + commitId := trimmedCommitText.SHA + + // Create a pull request + prRequest := mcp.CallToolRequest{} + prRequest.Params.Name = "create_pull_request" + prRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "title": "Test PR", + "body": "This is a test PR", + "head": "test-branch", + "base": "main", + "commitId": commitId, + } + + t.Logf("Creating pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, prRequest) + require.NoError(t, err, "expected to call 'create_pull_request' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Create a review for the pull request, but we can't approve it + // because the current owner also owns the PR. + createPendingPullRequestReviewRequest := mcp.CallToolRequest{} + createPendingPullRequestReviewRequest.Params.Name = "create_pending_pull_request_review" + createPendingPullRequestReviewRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + } + + t.Logf("Creating pending review for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, createPendingPullRequestReviewRequest) + require.NoError(t, err, "expected to call 'create_pending_pull_request_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + require.Equal(t, "", textContent.Text, "expected content to be empty") + + // See that there is a pending review + getPullRequestsReview := mcp.CallToolRequest{} + getPullRequestsReview.Params.Name = "get_pull_request_reviews" + getPullRequestsReview.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + } + + t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, getPullRequestsReview) + require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var reviews []struct { + State string `json:"state"` + } + err = json.Unmarshal([]byte(textContent.Text), &reviews) + require.NoError(t, err, "expected to unmarshal text content successfully") + + // Check that there is one review + require.Len(t, reviews, 1, "expected to find one review") + require.Equal(t, "PENDING", reviews[0].State, "expected review state to be PENDING") + + // Delete the review + deleteReviewRequest := mcp.CallToolRequest{} + deleteReviewRequest.Params.Name = "delete_pending_pull_request_review" + deleteReviewRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + } + + t.Logf("Deleting review for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, deleteReviewRequest) + require.NoError(t, err, "expected to call 'delete_pending_pull_request_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // See that there are no reviews + t.Logf("Getting reviews for pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, getPullRequestsReview) + require.NoError(t, err, "expected to call 'get_pull_request_reviews' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + textContent, ok = resp.Content[0].(mcp.TextContent) + require.True(t, ok, "expected content to be of type TextContent") + + var noReviews []struct{} + err = json.Unmarshal([]byte(textContent.Text), &noReviews) + require.NoError(t, err, "expected to unmarshal text content successfully") + require.Len(t, noReviews, 0, "expected to find no reviews") + +} diff --git a/go.mod b/go.mod index 7b850829..1505656b 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,8 @@ require ( github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/sagikazarmark/locafero v0.9.0 // indirect + github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 + github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.14.0 // indirect github.com/spf13/cast v1.7.1 // indirect @@ -32,6 +34,7 @@ require ( github.com/subosito/gotenv v1.6.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/oauth2 v0.29.0 golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/go.sum b/go.sum index 8b960ad5..411dd957 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,10 @@ github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWN github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= +github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 h1:cYCy18SHPKRkvclm+pWm1Lk4YrREb4IOIb/YdFO0p2M= +github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8= +github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0= +github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466/go.mod h1:9dIRpgIY7hVhoqfe0/FcYp0bpInZaT7dc3BYOprrIUE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= @@ -69,6 +73,8 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/oauth2 v0.29.0 h1:WdYw2tdTK1S8olAzWHdgeqfy+Mtm9XNhv/xJsY65d98= +golang.org/x/oauth2 v0.29.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 3434d9cd..eb305f33 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -15,7 +15,9 @@ import ( gogithub "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/shurcooL/githubv4" "github.com/sirupsen/logrus" + "golang.org/x/oauth2" ) type MCPServerConfig struct { @@ -86,11 +88,21 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { return ghClient, nil // closing over client } + getGQLClient := func(_ context.Context) (*githubv4.Client, error) { + // TODO: Enterprise support + src := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: cfg.Token}, + ) + httpClient := oauth2.NewClient(context.Background(), src) + return githubv4.NewClient(httpClient), nil + } + // Create default toolsets toolsets, err := github.InitToolsets( enabledToolsets, cfg.ReadOnly, getClient, + getGQLClient, cfg.Translator, ) if err != nil { diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index f4470b7b..b04abaf3 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -11,10 +11,11 @@ import ( "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "github.com/shurcooL/githubv4" ) // GetPullRequest creates a tool to get details of a specific pull request. -func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_DESCRIPTION", "Get details of a specific pull request in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -75,8 +76,123 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) } } +// CreatePullRequest creates a tool to create a new pull request. +func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("create_pull_request", + mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_CREATE_PULL_REQUEST_USER_TITLE", "Open new pull request"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithString("title", + mcp.Required(), + mcp.Description("PR title"), + ), + mcp.WithString("body", + mcp.Description("PR description"), + ), + mcp.WithString("head", + mcp.Required(), + mcp.Description("Branch containing changes"), + ), + mcp.WithString("base", + mcp.Required(), + mcp.Description("Branch to merge into"), + ), + mcp.WithBoolean("draft", + mcp.Description("Create as draft PR"), + ), + mcp.WithBoolean("maintainer_can_modify", + mcp.Description("Allow maintainer edits"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + title, err := requiredParam[string](request, "title") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + head, err := requiredParam[string](request, "head") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + base, err := requiredParam[string](request, "base") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + body, err := OptionalParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + draft, err := OptionalParam[bool](request, "draft") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + maintainerCanModify, err := OptionalParam[bool](request, "maintainer_can_modify") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + newPR := &github.NewPullRequest{ + Title: github.Ptr(title), + Head: github.Ptr(head), + Base: github.Ptr(base), + } + + if body != "" { + newPR.Body = github.Ptr(body) + } + + newPR.Draft = github.Ptr(draft) + newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) + if err != nil { + return nil, fmt.Errorf("failed to create pull request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil + } + + r, err := json.Marshal(pr) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -197,7 +313,7 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("list_pull_requests", mcp.WithDescription(t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List pull requests in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -306,7 +422,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("merge_pull_request", mcp.WithDescription(t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request in a GitHub repository.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -395,7 +511,7 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun } // GetPullRequestFiles creates a tool to get the list of files changed in a pull request. -func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_files", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_FILES_DESCRIPTION", "Get the files changed in a specific pull request.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -458,7 +574,7 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper } // GetPullRequestStatus creates a tool to get the combined status of all status checks for a pull request. -func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_status", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_STATUS_DESCRIPTION", "Get the status of a specific pull request.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -535,7 +651,7 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request_branch", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update the branch of a pull request with the latest changes from the base branch.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -613,7 +729,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe } // GetPullRequestComments creates a tool to get the review comments on a pull request. -func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_comments", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_COMMENTS_DESCRIPTION", "Get comments for a specific pull request.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ @@ -680,13 +796,13 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel } } -// AddPullRequestReviewComment creates a tool to add a review comment to a pull request. -func AddPullRequestReviewComment(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("add_pull_request_review_comment", - mcp.WithDescription(t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_DESCRIPTION", "Add a review comment to a pull request.")), +// GetPullRequestReviews creates a tool to get the reviews on a pull request. +func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("get_pull_request_reviews", + mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_REVIEWS_DESCRIPTION", "Get reviews for a specific pull request.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ - Title: t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_USER_TITLE", "Add review comment to pull request"), - ReadOnlyHint: toBoolPtr(false), + Title: t("TOOL_GET_PULL_REQUEST_REVIEWS_USER_TITLE", "Get pull request reviews"), + ReadOnlyHint: toBoolPtr(true), }), mcp.WithString("owner", mcp.Required(), @@ -696,41 +812,10 @@ func AddPullRequestReviewComment(getClient GetClientFn, t translations.Translati mcp.Required(), mcp.Description("Repository name"), ), - mcp.WithNumber("pull_number", + mcp.WithNumber("pullNumber", mcp.Required(), mcp.Description("Pull request number"), ), - mcp.WithString("body", - mcp.Required(), - mcp.Description("The text of the review comment"), - ), - mcp.WithString("commit_id", - mcp.Description("The SHA of the commit to comment on. Required unless in_reply_to is specified."), - ), - mcp.WithString("path", - mcp.Description("The relative path to the file that necessitates a comment. Required unless in_reply_to is specified."), - ), - mcp.WithString("subject_type", - mcp.Description("The level at which the comment is targeted"), - mcp.Enum("line", "file"), - ), - mcp.WithNumber("line", - mcp.Description("The line of the blob in the pull request diff that the comment applies to. For multi-line comments, the last line of the range"), - ), - mcp.WithString("side", - mcp.Description("The side of the diff to comment on"), - mcp.Enum("LEFT", "RIGHT"), - ), - mcp.WithNumber("start_line", - mcp.Description("For multi-line comments, the first line of the range that the comment applies to"), - ), - mcp.WithString("start_side", - mcp.Description("For multi-line comments, the starting side of the diff that the comment applies to"), - mcp.Enum("LEFT", "RIGHT"), - ), - mcp.WithNumber("in_reply_to", - mcp.Description("The ID of the review comment to reply to. When specified, only body is required and all other parameters are ignored"), - ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") @@ -741,11 +826,7 @@ func AddPullRequestReviewComment(getClient GetClientFn, t translations.Translati if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := RequiredInt(request, "pull_number") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - body, err := requiredParam[string](request, "body") + pullNumber, err := RequiredInt(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -754,114 +835,157 @@ func AddPullRequestReviewComment(getClient GetClientFn, t translations.Translati if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } + reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) + if err != nil { + return nil, fmt.Errorf("failed to get pull request reviews: %w", err) + } + defer func() { _ = resp.Body.Close() }() - // Check if this is a reply to an existing comment - if replyToFloat, ok := request.Params.Arguments["in_reply_to"].(float64); ok { - // Use the specialized method for reply comments due to inconsistency in underlying go-github library: https://github.com/google/go-github/pull/950 - commentID := int64(replyToFloat) - createdReply, resp, err := client.PullRequests.CreateCommentInReplyTo(ctx, owner, repo, pullNumber, body, commentID) + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to reply to pull request comment: %w", err) + return nil, fmt.Errorf("failed to read response body: %w", err) } - defer func() { _ = resp.Body.Close() }() + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil + } - if resp.StatusCode != http.StatusCreated { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to reply to pull request comment: %s", string(respBody))), nil - } + r, err := json.Marshal(reviews) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } - r, err := json.Marshal(createdReply) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) - } + return mcp.NewToolResultText(string(r)), nil + } +} - return mcp.NewToolResultText(string(r)), nil +func CreateAndSubmitPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("create_and_submit_pull_request_review", + mcp.WithDescription(t("TOOL_CREATE_AND_SUBMIT_PULL_REQUEST_REVIEW_DESCRIPTION", "Create and submit a review for a pull request without review comments.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_CREATE_AND_SUBMIT_PULL_REQUEST_REVIEW_USER_TITLE", "Create and submit a pull request review without comments"), + ReadOnlyHint: toBoolPtr(false), + }), + // Either we need the PR GQL Id directly, or we need owner, repo and PR number to look it up. + // Since our other Pull Request tools are working with the REST Client, will handle the lookup + // internally for now. + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("pullNumber", + mcp.Required(), + mcp.Description("Pull request number"), + ), + mcp.WithString("body", + mcp.Required(), + mcp.Description("Review comment text"), + ), + mcp.WithString("event", + mcp.Required(), + mcp.Description("Review action to perform"), + mcp.Enum("APPROVE", "REQUEST_CHANGES", "COMMENT"), + ), + mcp.WithString("commitId", + mcp.Description("SHA of commit to review"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - // This is a new comment, not a reply - // Verify required parameters for a new comment - commitID, err := requiredParam[string](request, "commit_id") + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - path, err := requiredParam[string](request, "path") + + pullNumber, err := RequiredInt32Param(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - comment := &github.PullRequestComment{ - Body: github.Ptr(body), - CommitID: github.Ptr(commitID), - Path: github.Ptr(path), + body, err := requiredParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - subjectType, err := OptionalParam[string](request, "subject_type") + event, err := requiredParam[string](request, "event") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - if subjectType != "file" { - line, lineExists := request.Params.Arguments["line"].(float64) - startLine, startLineExists := request.Params.Arguments["start_line"].(float64) - side, sideExists := request.Params.Arguments["side"].(string) - startSide, startSideExists := request.Params.Arguments["start_side"].(string) - if !lineExists { - return mcp.NewToolResultError("line parameter is required unless using subject_type:file"), nil - } - - comment.Line = github.Ptr(int(line)) - if sideExists { - comment.Side = github.Ptr(side) - } - if startLineExists { - comment.StartLine = github.Ptr(int(startLine)) - } - if startSideExists { - comment.StartSide = github.Ptr(startSide) - } - - if startLineExists && !lineExists { - return mcp.NewToolResultError("if start_line is provided, line must also be provided"), nil - } - if startSideExists && !sideExists { - return mcp.NewToolResultError("if start_side is provided, side must also be provided"), nil - } + commitID, err := OptionalParam[string](request, "commitId") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - createdComment, resp, err := client.PullRequests.CreateComment(ctx, owner, repo, pullNumber, comment) + // Given our owner, repo and PR number, lookup the GQL ID of the PR. + client, err := getGQLClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to create pull request comment: %w", err) + return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request comment: %s", string(respBody))), nil + var getPullRequestQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + if err := client.Query(ctx, &getPullRequestQuery, map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), + }); err != nil { + return mcp.NewToolResultError(err.Error()), nil } - r, err := json.Marshal(createdComment) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + // Now we have the GQL ID, we can create a review + var addPullRequestReviewMutation struct { + AddPullRequestReview struct { + PullRequestReview struct { + ID githubv4.ID // We don't need this, but a selector is required or GQL complains. + } + } `graphql:"addPullRequestReview(input: $input)"` } - return mcp.NewToolResultText(string(r)), nil + if err := client.Mutate( + ctx, + &addPullRequestReviewMutation, + githubv4.AddPullRequestReviewInput{ + PullRequestID: getPullRequestQuery.Repository.PullRequest.ID, + Body: newGQLStringlike[githubv4.String](body), + Event: newGQLStringlike[githubv4.PullRequestReviewEvent](event), + CommitOID: newGQLStringlike[githubv4.GitObjectID](commitID), + }, + nil, + ); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Return nothing, just indicate success for the time being. + // In future, we may want to return the review ID, but for the moment, we're not leaking + // API implementation details to the LLM. + return mcp.NewToolResultText(""), nil } } -// GetPullRequestReviews creates a tool to get the reviews on a pull request. -func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("get_pull_request_reviews", - mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_REVIEWS_DESCRIPTION", "Get reviews for a specific pull request.")), +// CreatePendingPullRequestReview creates a tool to create a pending review on a pull request. +func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("create_pending_pull_request_review", + mcp.WithDescription(t("TOOL_CREATE_PENDING_PULL_REQUEST_REVIEW_DESCRIPTION", "Create a pending review for a pull request. Call this first before attempting to add comments to a pending review, and ultimately submitting it. A pending pull request review means a pull request review, it is pending because you create it first and submit it later, and the PR author will not see it until it is submitted.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ - Title: t("TOOL_GET_PULL_REQUEST_REVIEWS_USER_TITLE", "Get pull request reviews"), - ReadOnlyHint: toBoolPtr(true), + Title: t("TOOL_CREATE_PENDING_PULL_REQUEST_REVIEW_USER_TITLE", "Create pending pull request review"), + ReadOnlyHint: toBoolPtr(false), }), + // Either we need the PR GQL Id directly, or we need owner, repo and PR number to look it up. + // Since our other Pull Request tools are working with the REST Client, will handle the lookup + // internally for now. mcp.WithString("owner", mcp.Required(), mcp.Description("Repository owner"), @@ -874,56 +998,99 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp mcp.Required(), mcp.Description("Pull request number"), ), + mcp.WithString("commitID", + mcp.Description("SHA of commit to review"), + ), + // Event is omitted here because we always want to create a pending review. + // Threads are omitted for the moment, and we'll see if the LLM can use the appropriate tool. ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { owner, err := requiredParam[string](request, "owner") if err != nil { return mcp.NewToolResultError(err.Error()), nil } + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := RequiredInt(request, "pullNumber") + + pullNumber, err := RequiredInt32Param(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + commitID, err := OptionalParam[string](request, "commitID") if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return mcp.NewToolResultError(err.Error()), nil } - reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) + + // Given our owner, repo and PR number, lookup the GQL ID of the PR. + client, err := getGQLClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to get pull request reviews: %w", err) + return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request reviews: %s", string(body))), nil + var getPullRequestQuery struct { + Repository struct { + PullRequest struct { + ID githubv4.ID + } `graphql:"pullRequest(number: $prNum)"` + } `graphql:"repository(owner: $owner, name: $repo)"` + } + if err := client.Query(ctx, &getPullRequestQuery, map[string]any{ + "owner": githubv4.String(owner), + "repo": githubv4.String(repo), + "prNum": githubv4.Int(pullNumber), + }); err != nil { + return mcp.NewToolResultError(err.Error()), nil } - r, err := json.Marshal(reviews) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + // Now we have the GQL ID, we can create a pending review + var addPullRequestReviewMutation struct { + AddPullRequestReview struct { + PullRequestReview struct { + ID githubv4.ID // We don't need this, but a selector is required or GQL complains. + } + } `graphql:"addPullRequestReview(input: $input)"` } - return mcp.NewToolResultText(string(r)), nil + if err := client.Mutate( + ctx, + &addPullRequestReviewMutation, + githubv4.AddPullRequestReviewInput{ + PullRequestID: getPullRequestQuery.Repository.PullRequest.ID, + CommitOID: newGQLStringlike[githubv4.GitObjectID](commitID), + }, + nil, + ); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Return nothing, just indicate success for the time being. + // In future, we may want to return the review ID, but for the moment, we're not leaking + // API implementation details to the LLM. + return mcp.NewToolResultText(""), nil } } -// CreatePullRequestReview creates a tool to submit a review on a pull request. -func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_pull_request_review", - mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_REVIEW_DESCRIPTION", "Create a review for a pull request.")), +// AddPullRequestReviewCommentToPendingReview creates a tool to add a comment to a pull request review. +func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("add_pull_request_review_comment_to_pending_review", + mcp.WithDescription(t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add a comment to the requester's latest pending pull request review, a pending review needs to already exist to call this (check with the user if not sure).")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ - Title: t("TOOL_CREATE_PULL_REQUEST_REVIEW_USER_TITLE", "Submit pull request review"), + Title: t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_TO_PENDING_REVIEW_USER_TITLE", "Add comment to the requester's latest pending pull request review"), ReadOnlyHint: toBoolPtr(false), }), + // Ideally, for performance sake this would just accept the pullRequestReviewID. However, we would need to + // add a new tool to get that ID for clients that aren't in the same context as the original pending review + // creation. So for now, we'll just accept the owner, repo and pull number and assume this is adding a comment + // the latest review from a user, since only one can be active at a time. It can later be extended with + // a pullRequestReviewID parameter if targeting other reviews is desired: + // mcp.WithString("pullRequestReviewID", + // mcp.Required(), + // mcp.Description("The ID of the pull request review to add a comment to"), + // ), mcp.WithString("owner", mcp.Required(), mcp.Description("Repository owner"), @@ -936,71 +1103,32 @@ func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHe mcp.Required(), mcp.Description("Pull request number"), ), + mcp.WithString("path", + mcp.Required(), + mcp.Description("The relative path to the file that necessitates a comment"), + ), mcp.WithString("body", - mcp.Description("Review comment text"), + mcp.Required(), + mcp.Description("The text of the review comment"), ), - mcp.WithString("event", + mcp.WithString("subjectType", mcp.Required(), - mcp.Description("Review action to perform"), - mcp.Enum("APPROVE", "REQUEST_CHANGES", "COMMENT"), + mcp.Description("The level at which the comment is targeted"), + mcp.Enum("FILE", "LINE"), ), - mcp.WithString("commitId", - mcp.Description("SHA of commit to review"), + mcp.WithNumber("line", + mcp.Description("The line of the blob in the pull request diff that the comment applies to. For multi-line comments, the last line of the range"), + ), + mcp.WithString("side", + mcp.Description("The side of the diff to comment on"), + mcp.Enum("LEFT", "RIGHT"), ), - mcp.WithArray("comments", - mcp.Items( - map[string]interface{}{ - "type": "object", - "additionalProperties": false, - "required": []string{"path", "body", "position", "line", "side", "start_line", "start_side"}, - "properties": map[string]interface{}{ - "path": map[string]interface{}{ - "type": "string", - "description": "path to the file", - }, - "position": map[string]interface{}{ - "anyOf": []interface{}{ - map[string]string{"type": "number"}, - map[string]string{"type": "null"}, - }, - "description": "position of the comment in the diff", - }, - "line": map[string]interface{}{ - "anyOf": []interface{}{ - map[string]string{"type": "number"}, - map[string]string{"type": "null"}, - }, - "description": "line number in the file to comment on. For multi-line comments, the end of the line range", - }, - "side": map[string]interface{}{ - "anyOf": []interface{}{ - map[string]string{"type": "string"}, - map[string]string{"type": "null"}, - }, - "description": "The side of the diff on which the line resides. For multi-line comments, this is the side for the end of the line range. (LEFT or RIGHT)", - }, - "start_line": map[string]interface{}{ - "anyOf": []interface{}{ - map[string]string{"type": "number"}, - map[string]string{"type": "null"}, - }, - "description": "The first line of the range to which the comment refers. Required for multi-line comments.", - }, - "start_side": map[string]interface{}{ - "anyOf": []interface{}{ - map[string]string{"type": "string"}, - map[string]string{"type": "null"}, - }, - "description": "The side of the diff on which the start line resides for multi-line comments. (LEFT or RIGHT)", - }, - "body": map[string]interface{}{ - "type": "string", - "description": "comment body", - }, - }, - }, - ), - mcp.Description("Line-specific comments array of objects to place comments on pull request changes. Requires path and body. For line comments use line or position. For multi-line comments use start_line and line with optional side parameters."), + mcp.WithNumber("startLine", + mcp.Description("For multi-line comments, the first line of the range that the comment applies to"), + ), + mcp.WithString("startSide", + mcp.Description("For multi-line comments, the starting side of the diff that the comment applies to"), + mcp.Enum("LEFT", "RIGHT"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -1008,138 +1136,151 @@ func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHe if err != nil { return mcp.NewToolResultError(err.Error()), nil } + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - pullNumber, err := RequiredInt(request, "pullNumber") + + pullNumber, err := RequiredInt32Param(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - event, err := requiredParam[string](request, "event") + + path, err := requiredParam[string](request, "path") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - // Create review request - reviewRequest := &github.PullRequestReviewRequest{ - Event: github.Ptr(event), + body, err := requiredParam[string](request, "body") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - // Add body if provided - body, err := OptionalParam[string](request, "body") + subjectType, err := requiredParam[string](request, "subjectType") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - if body != "" { - reviewRequest.Body = github.Ptr(body) - } - // Add commit ID if provided - commitID, err := OptionalParam[string](request, "commitId") + line, err := OptionalInt32Param(request, "line") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - if commitID != "" { - reviewRequest.CommitID = github.Ptr(commitID) - } - // Add comments if provided - if commentsObj, ok := request.Params.Arguments["comments"].([]interface{}); ok && len(commentsObj) > 0 { - comments := []*github.DraftReviewComment{} - - for _, c := range commentsObj { - commentMap, ok := c.(map[string]interface{}) - if !ok { - return mcp.NewToolResultError("each comment must be an object with path and body"), nil - } + side, err := OptionalParam[string](request, "side") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - path, ok := commentMap["path"].(string) - if !ok || path == "" { - return mcp.NewToolResultError("each comment must have a path"), nil - } + startLine, err := OptionalInt32Param(request, "startLine") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - body, ok := commentMap["body"].(string) - if !ok || body == "" { - return mcp.NewToolResultError("each comment must have a body"), nil - } + startSide, err := OptionalParam[string](request, "startSide") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - _, hasPosition := commentMap["position"].(float64) - _, hasLine := commentMap["line"].(float64) - _, hasSide := commentMap["side"].(string) - _, hasStartLine := commentMap["start_line"].(float64) - _, hasStartSide := commentMap["start_side"].(string) - - switch { - case !hasPosition && !hasLine: - return mcp.NewToolResultError("each comment must have either position or line"), nil - case hasPosition && (hasLine || hasSide || hasStartLine || hasStartSide): - return mcp.NewToolResultError("position cannot be combined with line, side, start_line, or start_side"), nil - case hasStartSide && !hasSide: - return mcp.NewToolResultError("if start_side is provided, side must also be provided"), nil - } + client, err := getGQLClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) + } - comment := &github.DraftReviewComment{ - Path: github.Ptr(path), - Body: github.Ptr(body), - } + // First we'll get the current user + var getViewerQuery struct { + Viewer struct { + Login githubv4.String + } + } - if positionFloat, ok := commentMap["position"].(float64); ok { - comment.Position = github.Ptr(int(positionFloat)) - } else if lineFloat, ok := commentMap["line"].(float64); ok { - comment.Line = github.Ptr(int(lineFloat)) - } - if side, ok := commentMap["side"].(string); ok { - comment.Side = github.Ptr(side) - } - if startLineFloat, ok := commentMap["start_line"].(float64); ok { - comment.StartLine = github.Ptr(int(startLineFloat)) - } - if startSide, ok := commentMap["start_side"].(string); ok { - comment.StartSide = github.Ptr(startSide) - } + if err := client.Query(ctx, &getViewerQuery, nil); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - comments = append(comments, comment) - } + // Then let's get the ID of the review (but maybe we should just get the ID of the review itself: TODO) + var getLatestReviewForViewerQuery struct { + Repository struct { + PullRequest struct { + Reviews struct { + Nodes []struct { + ID githubv4.ID + State githubv4.PullRequestReviewState + URL githubv4.URI + } + } `graphql:"reviews(first: 1, author: $author)"` + } `graphql:"pullRequest(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } - reviewRequest.Comments = comments + vars := map[string]interface{}{ + "author": githubv4.String(getViewerQuery.Viewer.Login), + "owner": githubv4.String(owner), + "name": githubv4.String(repo), + "number": githubv4.Int(pullNumber), } - client, err := getClient(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { + return mcp.NewToolResultError(err.Error()), nil } - review, resp, err := client.PullRequests.CreateReview(ctx, owner, repo, pullNumber, reviewRequest) - if err != nil { - return nil, fmt.Errorf("failed to create pull request review: %w", err) + + // Validate there is one review and the state is pending + if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 { + return mcp.NewToolResultError("No pending review found for the viewer"), nil } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request review: %s", string(body))), nil + review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0] + if review.State != githubv4.PullRequestReviewStatePending { + errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL) + return mcp.NewToolResultError(errText), nil } - r, err := json.Marshal(review) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + // Then we can create a new review thread comment on the review. + var addPullRequestReviewThreadMutation struct { + AddPullRequestReviewThread struct { + Thread struct { + ID githubv4.ID // We don't need this, but a selector is required or GQL complains. + } + } `graphql:"addPullRequestReviewThread(input: $input)"` + } + + if err := client.Mutate( + ctx, + &addPullRequestReviewThreadMutation, + githubv4.AddPullRequestReviewThreadInput{ + Path: githubv4.String(path), + Body: githubv4.String(body), + SubjectType: newGQLStringlike[githubv4.PullRequestReviewThreadSubjectType](subjectType), + Line: githubv4.NewInt(githubv4.Int(line)), + Side: newGQLStringlike[githubv4.DiffSide](side), + StartLine: githubv4.NewInt(githubv4.Int(startLine)), + StartSide: newGQLStringlike[githubv4.DiffSide](startSide), + PullRequestReviewID: &review.ID, + }, + nil, + ); err != nil { + return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultText(string(r)), nil + // Return nothing, just indicate success for the time being. + // In future, we may want to return the review ID, but for the moment, we're not leaking + // API implementation details to the LLM. + return mcp.NewToolResultText(""), nil } } -// CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { - return mcp.NewTool("create_pull_request", - mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository.")), +// SubmitPendingPullRequestReview creates a tool to submit a pull request review. +func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("submit_pending_pull_request_review", + mcp.WithDescription(t("TOOL_SUBMIT_PENDING_PULL_REQUEST_REVIEW_DESCRIPTION", "Submit the requester's latest pending pull request review, normally this is a final step after creating a pending review, adding comments first, unless you know that the user already did the first two steps, you should check before calling this.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ - Title: t("TOOL_CREATE_PULL_REQUEST_USER_TITLE", "Open new pull request"), + Title: t("TOOL_SUBMIT_PENDING_PULL_REQUEST_REVIEW_USER_TITLE", "Submit the requester's latest pending pull request review"), ReadOnlyHint: toBoolPtr(false), }), + // Ideally, for performance sake this would just accept the pullRequestReviewID. However, we would need to + // add a new tool to get that ID for clients that aren't in the same context as the original pending review + // creation. So for now, we'll just accept the owner, repo and pull number and assume this is submitting + // the latest review from a user, since only one can be active at a time. mcp.WithString("owner", mcp.Required(), mcp.Description("Repository owner"), @@ -1148,26 +1289,17 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu mcp.Required(), mcp.Description("Repository name"), ), - mcp.WithString("title", - mcp.Required(), - mcp.Description("PR title"), - ), - mcp.WithString("body", - mcp.Description("PR description"), - ), - mcp.WithString("head", + mcp.WithNumber("pullNumber", mcp.Required(), - mcp.Description("Branch containing changes"), + mcp.Description("Pull request number"), ), - mcp.WithString("base", + mcp.WithString("event", mcp.Required(), - mcp.Description("Branch to merge into"), - ), - mcp.WithBoolean("draft", - mcp.Description("Create as draft PR"), + mcp.Description("The event to perform"), + mcp.Enum("APPROVE", "REQUEST_CHANGES", "COMMENT"), ), - mcp.WithBoolean("maintainer_can_modify", - mcp.Description("Allow maintainer edits"), + mcp.WithString("body", + mcp.Description("The text of the review comment"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -1175,74 +1307,251 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu if err != nil { return mcp.NewToolResultError(err.Error()), nil } + repo, err := requiredParam[string](request, "repo") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - title, err := requiredParam[string](request, "title") + + pullNumber, err := RequiredInt32Param(request, "pullNumber") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - head, err := requiredParam[string](request, "head") + + event, err := requiredParam[string](request, "event") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - base, err := requiredParam[string](request, "base") + + body, err := OptionalParam[string](request, "body") if err != nil { return mcp.NewToolResultError(err.Error()), nil } - body, err := OptionalParam[string](request, "body") + client, err := getGQLClient(ctx) if err != nil { + return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) + } + + // First we'll get the current user + var getViewerQuery struct { + Viewer struct { + Login githubv4.String + } + } + + if err := client.Query(ctx, &getViewerQuery, nil); err != nil { return mcp.NewToolResultError(err.Error()), nil } - draft, err := OptionalParam[bool](request, "draft") - if err != nil { + // Then let's get the ID of the review (but maybe we should just get the ID of the review itself: TODO) + var getLatestReviewForViewerQuery struct { + Repository struct { + PullRequest struct { + Reviews struct { + Nodes []struct { + ID githubv4.ID + Author struct { + Login githubv4.String + } + State githubv4.PullRequestReviewState + SubmittedAt githubv4.DateTime + Body githubv4.String + URL githubv4.URI + } + } `graphql:"reviews(first: 1, author: $author)"` + } `graphql:"pullRequest(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + + vars := map[string]interface{}{ + "author": githubv4.String(getViewerQuery.Viewer.Login), + "owner": githubv4.String(owner), + "name": githubv4.String(repo), + "number": githubv4.Int(pullNumber), + } + + if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { return mcp.NewToolResultError(err.Error()), nil } - maintainerCanModify, err := OptionalParam[bool](request, "maintainer_can_modify") - if err != nil { + // Validate there is one review and the state is pending + if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 { + return mcp.NewToolResultError("No pending review found for the viewer"), nil + } + + review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0] + if review.State != githubv4.PullRequestReviewStatePending { + errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL) + return mcp.NewToolResultError(errText), nil + } + + // Prepare the mutation + var submitPullRequestReviewMutation struct { + SubmitPullRequestReview struct { + PullRequestReview struct { + State githubv4.PullRequestReviewState + SubmittedAt githubv4.DateTime + } + } `graphql:"submitPullRequestReview(input: $input)"` + } + + if err := client.Mutate( + ctx, + &submitPullRequestReviewMutation, + githubv4.SubmitPullRequestReviewInput{ + PullRequestReviewID: &review.ID, + Event: githubv4.PullRequestReviewEvent(event), + Body: newGQLStringlike[githubv4.String](body), + }, + nil, + ); err != nil { return mcp.NewToolResultError(err.Error()), nil } - newPR := &github.NewPullRequest{ - Title: github.Ptr(title), - Head: github.Ptr(head), - Base: github.Ptr(base), + // Return the state and submitted at time of the review as a receipt for the LLM. + r, err := json.Marshal(submitPullRequestReviewMutation.SubmitPullRequestReview.PullRequestReview) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) } - if body != "" { - newPR.Body = github.Ptr(body) + return mcp.NewToolResultText(string(r)), nil + } +} + +func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("delete_pending_pull_request_review", + mcp.WithDescription(t("TOOL_DELETE_PENDING_PULL_REQUEST_REVIEW_DESCRIPTION", "Delete the requester's latest pending pull request review. Use this after the user decides not to submit a pending review, if you don't know if they already created one then check first.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_DELETE_PENDING_PULL_REQUEST_REVIEW_USER_TITLE", "Delete the requester's latest pending pull request review"), + ReadOnlyHint: toBoolPtr(false), + }), + // Ideally, for performance sake this would just accept the pullRequestReviewID. However, we would need to + // add a new tool to get that ID for clients that aren't in the same context as the original pending review + // creation. So for now, we'll just accept the owner, repo and pull number and assume this is deleting + // the latest pending review from a user, since only one can be active at a time. + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("pullNumber", + mcp.Required(), + mcp.Description("Pull request number"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil } - newPR.Draft = github.Ptr(draft) - newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } - client, err := getClient(ctx) + pullNumber, err := RequiredInt32Param(request, "pullNumber") if err != nil { - return nil, fmt.Errorf("failed to get GitHub client: %w", err) + return mcp.NewToolResultError(err.Error()), nil } - pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) + + client, err := getGQLClient(ctx) if err != nil { - return nil, fmt.Errorf("failed to create pull request: %w", err) + return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) } - defer func() { _ = resp.Body.Close() }() - if resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + // First we'll get the current user + var getViewerQuery struct { + Viewer struct { + Login githubv4.String } - return mcp.NewToolResultError(fmt.Sprintf("failed to create pull request: %s", string(body))), nil } - r, err := json.Marshal(pr) - if err != nil { - return nil, fmt.Errorf("failed to marshal response: %w", err) + if err := client.Query(ctx, &getViewerQuery, nil); err != nil { + return mcp.NewToolResultError(err.Error()), nil } - return mcp.NewToolResultText(string(r)), nil + // Then let's get the ID of the review (but maybe we should just get the ID of the review itself: TODO) + var getLatestReviewForViewerQuery struct { + Repository struct { + PullRequest struct { + Reviews struct { + Nodes []struct { + ID githubv4.ID + Author struct { + Login githubv4.String + } + State githubv4.PullRequestReviewState + SubmittedAt githubv4.DateTime + Body githubv4.String + URL githubv4.URI + } + } `graphql:"reviews(first: 1, author: $author)"` + } `graphql:"pullRequest(number: $number)"` + } `graphql:"repository(owner: $owner, name: $name)"` + } + + vars := map[string]interface{}{ + "author": githubv4.String(getViewerQuery.Viewer.Login), + "owner": githubv4.String(owner), + "name": githubv4.String(repo), + "number": githubv4.Int(pullNumber), + } + + if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Validate there is one review and the state is pending + if len(getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes) == 0 { + return mcp.NewToolResultError("No pending review found for the viewer"), nil + } + + review := getLatestReviewForViewerQuery.Repository.PullRequest.Reviews.Nodes[0] + if review.State != githubv4.PullRequestReviewStatePending { + errText := fmt.Sprintf("The latest review, found at %s is not pending", review.URL) + return mcp.NewToolResultError(errText), nil + } + + // Prepare the mutation + var deletePullRequestReviewMutation struct { + DeletePullRequestReview struct { + PullRequestReview struct { + ID githubv4.ID // We don't need this, but a selector is required or GQL complains. + } + } `graphql:"deletePullRequestReview(input: $input)"` + } + + if err := client.Mutate( + ctx, + &deletePullRequestReviewMutation, + githubv4.DeletePullRequestReviewInput{ + PullRequestReviewID: &review.ID, + }, + nil, + ); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Return nothing, just indicate success for the time being. + // In future, we may want to return the review ID, but for the moment, we're not leaking + // API implementation details to the LLM. + return mcp.NewToolResultText(""), nil } } + +// newGQLString like takes something that approximates a string (of which there are many types in shurcooL/githubv4) +// and constructs a pointer to it, or nil if the string is empty. This is extremely useful because when we parse +// params from the MCP request, we need to convert them to types that are pointers of type def strings and it's +// not possible to take a pointer of an anonymous value e.g. &githubv4.String("foo"). +func newGQLStringlike[T ~string](s string) *T { + if s == "" { + return nil + } + stringlike := T(s) + return &stringlike +} diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index bb372624..dad1c226 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -1192,377 +1192,6 @@ func Test_GetPullRequestReviews(t *testing.T) { } } -func Test_CreatePullRequestReview(t *testing.T) { - // Verify tool definition once - mockClient := github.NewClient(nil) - tool, _ := CreatePullRequestReview(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "create_pull_request_review", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pullNumber") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.Contains(t, tool.InputSchema.Properties, "event") - assert.Contains(t, tool.InputSchema.Properties, "commitId") - assert.Contains(t, tool.InputSchema.Properties, "comments") - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pullNumber", "event"}) - - // Setup mock review for success case - mockReview := &github.PullRequestReview{ - ID: github.Ptr(int64(301)), - State: github.Ptr("APPROVED"), - Body: github.Ptr("Looks good!"), - HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42#pullrequestreview-301"), - User: &github.User{ - Login: github.Ptr("reviewer"), - }, - CommitID: github.Ptr("abcdef123456"), - SubmittedAt: &github.Timestamp{Time: time.Now()}, - } - - tests := []struct { - name string - mockedClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedReview *github.PullRequestReview - expectedErrMsg string - }{ - { - name: "successful review creation with body only", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - expectRequestBody(t, map[string]interface{}{ - "body": "Looks good!", - "event": "APPROVE", - }).andThen( - mockResponse(t, http.StatusOK, mockReview), - ), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "body": "Looks good!", - "event": "APPROVE", - }, - expectError: false, - expectedReview: mockReview, - }, - { - name: "successful review creation with commitId", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - expectRequestBody(t, map[string]interface{}{ - "body": "Looks good!", - "event": "APPROVE", - "commit_id": "abcdef123456", - }).andThen( - mockResponse(t, http.StatusOK, mockReview), - ), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "body": "Looks good!", - "event": "APPROVE", - "commitId": "abcdef123456", - }, - expectError: false, - expectedReview: mockReview, - }, - { - name: "successful review creation with comments", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - expectRequestBody(t, map[string]interface{}{ - "body": "Some issues to fix", - "event": "REQUEST_CHANGES", - "comments": []interface{}{ - map[string]interface{}{ - "path": "file1.go", - "position": float64(10), - "body": "This needs to be fixed", - }, - map[string]interface{}{ - "path": "file2.go", - "position": float64(20), - "body": "Consider a different approach here", - }, - }, - }).andThen( - mockResponse(t, http.StatusOK, mockReview), - ), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "body": "Some issues to fix", - "event": "REQUEST_CHANGES", - "comments": []interface{}{ - map[string]interface{}{ - "path": "file1.go", - "position": float64(10), - "body": "This needs to be fixed", - }, - map[string]interface{}{ - "path": "file2.go", - "position": float64(20), - "body": "Consider a different approach here", - }, - }, - }, - expectError: false, - expectedReview: mockReview, - }, - { - name: "invalid comment format", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnprocessableEntity) - _, _ = w.Write([]byte(`{"message": "Invalid comment format"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "event": "REQUEST_CHANGES", - "comments": []interface{}{ - map[string]interface{}{ - "path": "file1.go", - // missing position - "body": "This needs to be fixed", - }, - }, - }, - expectError: false, - expectedErrMsg: "each comment must have either position or line", - }, - { - name: "successful review creation with line parameter", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - expectRequestBody(t, map[string]interface{}{ - "body": "Code review comments", - "event": "COMMENT", - "comments": []interface{}{ - map[string]interface{}{ - "path": "main.go", - "line": float64(42), - "body": "Consider adding a comment here", - }, - }, - }).andThen( - mockResponse(t, http.StatusOK, mockReview), - ), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "body": "Code review comments", - "event": "COMMENT", - "comments": []interface{}{ - map[string]interface{}{ - "path": "main.go", - "line": float64(42), - "body": "Consider adding a comment here", - }, - }, - }, - expectError: false, - expectedReview: mockReview, - }, - { - name: "successful review creation with multi-line comment", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - expectRequestBody(t, map[string]interface{}{ - "body": "Multi-line comment review", - "event": "COMMENT", - "comments": []interface{}{ - map[string]interface{}{ - "path": "main.go", - "start_line": float64(10), - "line": float64(15), - "side": "RIGHT", - "body": "This entire block needs refactoring", - }, - }, - }).andThen( - mockResponse(t, http.StatusOK, mockReview), - ), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "body": "Multi-line comment review", - "event": "COMMENT", - "comments": []interface{}{ - map[string]interface{}{ - "path": "main.go", - "start_line": float64(10), - "line": float64(15), - "side": "RIGHT", - "body": "This entire block needs refactoring", - }, - }, - }, - expectError: false, - expectedReview: mockReview, - }, - { - name: "invalid multi-line comment - missing line parameter", - mockedClient: mock.NewMockedHTTPClient(), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "event": "COMMENT", - "comments": []interface{}{ - map[string]interface{}{ - "path": "main.go", - "start_line": float64(10), - // missing line parameter - "body": "Invalid multi-line comment", - }, - }, - }, - expectError: false, - expectedErrMsg: "each comment must have either position or line", // Updated error message - }, - { - name: "invalid comment - mixing position with line parameters", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatch( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - mockReview, - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "event": "COMMENT", - "comments": []interface{}{ - map[string]interface{}{ - "path": "main.go", - "position": float64(5), - "line": float64(42), - "body": "Invalid parameter combination", - }, - }, - }, - expectError: false, - expectedErrMsg: "position cannot be combined with line, side, start_line, or start_side", - }, - { - name: "invalid multi-line comment - missing side parameter", - mockedClient: mock.NewMockedHTTPClient(), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "event": "COMMENT", - "comments": []interface{}{ - map[string]interface{}{ - "path": "main.go", - "start_line": float64(10), - "line": float64(15), - "start_side": "LEFT", - // missing side parameter - "body": "Invalid multi-line comment", - }, - }, - }, - expectError: false, - expectedErrMsg: "if start_side is provided, side must also be provided", - }, - { - name: "review creation fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsReviewsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnprocessableEntity) - _, _ = w.Write([]byte(`{"message": "Invalid comment format"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pullNumber": float64(42), - "body": "Looks good!", - "event": "APPROVE", - }, - expectError: true, - expectedErrMsg: "failed to create pull request review", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - // Setup client with mock - client := github.NewClient(tc.mockedClient) - _, handler := CreatePullRequestReview(stubGetClientFn(client), translations.NullTranslationHelper) - - // Create call request - request := createMCPRequest(tc.requestArgs) - - // Call handler - result, err := handler(context.Background(), request) - - // Verify results - if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) - return - } - - require.NoError(t, err) - - // For error messages in the result - if tc.expectedErrMsg != "" { - textContent := getTextResult(t, result) - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - return - } - - // Parse the result and get the text content if no error - textContent := getTextResult(t, result) - - // Unmarshal and verify the result - var returnedReview github.PullRequestReview - err = json.Unmarshal([]byte(textContent.Text), &returnedReview) - require.NoError(t, err) - assert.Equal(t, *tc.expectedReview.ID, *returnedReview.ID) - assert.Equal(t, *tc.expectedReview.State, *returnedReview.State) - assert.Equal(t, *tc.expectedReview.Body, *returnedReview.Body) - assert.Equal(t, *tc.expectedReview.User.Login, *returnedReview.User.Login) - assert.Equal(t, *tc.expectedReview.HTMLURL, *returnedReview.HTMLURL) - }) - } -} - func Test_CreatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) @@ -1719,200 +1348,3 @@ func Test_CreatePullRequest(t *testing.T) { }) } } - -func Test_AddPullRequestReviewComment(t *testing.T) { - mockClient := github.NewClient(nil) - tool, _ := AddPullRequestReviewComment(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - assert.Equal(t, "add_pull_request_review_comment", tool.Name) - assert.NotEmpty(t, tool.Description) - assert.Contains(t, tool.InputSchema.Properties, "owner") - assert.Contains(t, tool.InputSchema.Properties, "repo") - assert.Contains(t, tool.InputSchema.Properties, "pull_number") - assert.Contains(t, tool.InputSchema.Properties, "body") - assert.Contains(t, tool.InputSchema.Properties, "commit_id") - assert.Contains(t, tool.InputSchema.Properties, "path") - // Since we've updated commit_id and path to be optional when using in_reply_to - assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "pull_number", "body"}) - - mockComment := &github.PullRequestComment{ - ID: github.Ptr(int64(123)), - Body: github.Ptr("Great stuff!"), - Path: github.Ptr("file1.txt"), - Line: github.Ptr(2), - Side: github.Ptr("RIGHT"), - } - - mockReply := &github.PullRequestComment{ - ID: github.Ptr(int64(456)), - Body: github.Ptr("Good point, will fix!"), - } - - tests := []struct { - name string - mockedClient *http.Client - requestArgs map[string]interface{} - expectError bool - expectedComment *github.PullRequestComment - expectedErrMsg string - }{ - { - name: "successful line comment creation", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsCommentsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusCreated) - err := json.NewEncoder(w).Encode(mockComment) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pull_number": float64(1), - "body": "Great stuff!", - "commit_id": "6dcb09b5b57875f334f61aebed695e2e4193db5e", - "path": "file1.txt", - "line": float64(2), - "side": "RIGHT", - }, - expectError: false, - expectedComment: mockComment, - }, - { - name: "successful reply using in_reply_to", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsCommentsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusCreated) - err := json.NewEncoder(w).Encode(mockReply) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pull_number": float64(1), - "body": "Good point, will fix!", - "in_reply_to": float64(123), - }, - expectError: false, - expectedComment: mockReply, - }, - { - name: "comment creation fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsCommentsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusUnprocessableEntity) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"message": "Validation Failed"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pull_number": float64(1), - "body": "Great stuff!", - "commit_id": "6dcb09b5b57875f334f61aebed695e2e4193db5e", - "path": "file1.txt", - "line": float64(2), - }, - expectError: true, - expectedErrMsg: "failed to create pull request comment", - }, - { - name: "reply creation fails", - mockedClient: mock.NewMockedHTTPClient( - mock.WithRequestMatchHandler( - mock.PostReposPullsCommentsByOwnerByRepoByPullNumber, - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNotFound) - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"message": "Comment not found"}`)) - }), - ), - ), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pull_number": float64(1), - "body": "Good point, will fix!", - "in_reply_to": float64(999), - }, - expectError: true, - expectedErrMsg: "failed to reply to pull request comment", - }, - { - name: "missing required parameters for comment", - mockedClient: mock.NewMockedHTTPClient(), - requestArgs: map[string]interface{}{ - "owner": "owner", - "repo": "repo", - "pull_number": float64(1), - "body": "Great stuff!", - // missing commit_id and path - }, - expectError: false, - expectedErrMsg: "missing required parameter: commit_id", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockClient := github.NewClient(tc.mockedClient) - - _, handler := AddPullRequestReviewComment(stubGetClientFn(mockClient), translations.NullTranslationHelper) - - request := createMCPRequest(tc.requestArgs) - - result, err := handler(context.Background(), request) - - if tc.expectError { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.expectedErrMsg) - return - } - - require.NoError(t, err) - assert.NotNil(t, result) - require.Len(t, result.Content, 1) - - textContent := getTextResult(t, result) - if tc.expectedErrMsg != "" { - assert.Contains(t, textContent.Text, tc.expectedErrMsg) - return - } - - var returnedComment github.PullRequestComment - err = json.Unmarshal([]byte(getTextResult(t, result).Text), &returnedComment) - require.NoError(t, err) - - assert.Equal(t, *tc.expectedComment.ID, *returnedComment.ID) - assert.Equal(t, *tc.expectedComment.Body, *returnedComment.Body) - - // Only check Path, Line, and Side if they exist in the expected comment - if tc.expectedComment.Path != nil { - assert.Equal(t, *tc.expectedComment.Path, *returnedComment.Path) - } - if tc.expectedComment.Line != nil { - assert.Equal(t, *tc.expectedComment.Line, *returnedComment.Line) - } - if tc.expectedComment.Side != nil { - assert.Equal(t, *tc.expectedComment.Side, *returnedComment.Side) - } - }) - } -} diff --git a/pkg/github/server.go b/pkg/github/server.go index e4c24171..badfb1e5 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -3,6 +3,7 @@ package github import ( "errors" "fmt" + "math" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" @@ -68,21 +69,22 @@ func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { var zero T // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { + param, ok := r.Params.Arguments[p] + if !ok { return zero, fmt.Errorf("missing required parameter: %s", p) } // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(T); !ok { + typedParam, ok := param.(T) + if !ok { return zero, fmt.Errorf("parameter %s is not of type %T", p, zero) } - if r.Params.Arguments[p].(T) == zero { + if typedParam == zero { return zero, fmt.Errorf("missing required parameter: %s", p) - } - return r.Params.Arguments[p].(T), nil + return typedParam, nil } // RequiredInt is a helper function that can be used to fetch a requested parameter from the request. @@ -98,6 +100,26 @@ func RequiredInt(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } +// RequiredInt32Param is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request. +// 2. Checks if the parameter is of the expected type. +// 3. Checks if the parameter is not empty, i.e: non-zero value +// 4. Checks if the parameter is within the int32 range +func RequiredInt32Param(r mcp.CallToolRequest, p string) (int32, error) { + v, err := RequiredInt(r, p) + if err != nil { + return 0, err + } + + // Check if the parameter is within the int32 range + if v < math.MinInt32 || v > math.MaxInt32 { + return 0, fmt.Errorf("parameter %s is out of int32 range", p) + } + + return int32(v), nil +} + // OptionalParam is a helper function that can be used to fetch a requested parameter from the request. // It does the following checks: // 1. Checks if the parameter is present in the request, if not, it returns its zero-value @@ -106,16 +128,18 @@ func OptionalParam[T any](r mcp.CallToolRequest, p string) (T, error) { var zero T // Check if the parameter is present in the request - if _, ok := r.Params.Arguments[p]; !ok { + param, ok := r.Params.Arguments[p] + if !ok { return zero, nil } // Check if the parameter is of the expected type - if _, ok := r.Params.Arguments[p].(T); !ok { + typedParam, ok := param.(T) + if !ok { return zero, fmt.Errorf("parameter %s is not of type %T, is %T", p, zero, r.Params.Arguments[p]) } - return r.Params.Arguments[p].(T), nil + return typedParam, nil } // OptionalIntParam is a helper function that can be used to fetch a requested parameter from the request. @@ -130,6 +154,25 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) { return int(v), nil } +// OptionalInt32Param is a helper function that can be used to fetch a requested parameter from the request. +// It does the following checks: +// 1. Checks if the parameter is present in the request, if not, it returns its zero-value +// 2. If it is present, it checks if the parameter is of the expected type +// 3. Checks if the parameter is within the int32 range +func OptionalInt32Param(r mcp.CallToolRequest, p string) (int32, error) { + v, err := OptionalIntParam(r, p) + if err != nil { + return 0, err + } + + // Check if the parameter is within the int32 range + if v < math.MinInt32 || v > math.MaxInt32 { + return 0, fmt.Errorf("parameter %s is out of int32 range", p) + } + + return int32(v), nil +} + // OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request // similar to optionalIntParam, but it also takes a default value. func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) { diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 58bcb9db..63a3c487 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -3,6 +3,7 @@ package github import ( "context" "fmt" + "math" "testing" "github.com/google/go-github/v69/github" @@ -157,7 +158,7 @@ func Test_OptionalStringParam(t *testing.T) { } } -func Test_RequiredNumberParam(t *testing.T) { +func Test_RequiredInt(t *testing.T) { tests := []struct { name string params map[string]interface{} @@ -203,7 +204,71 @@ func Test_RequiredNumberParam(t *testing.T) { } } -func Test_OptionalNumberParam(t *testing.T) { +func Test_RequiredInt32Param(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int32 + expectError bool + }{ + { + name: "valid int32 parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "too small parameter", + params: map[string]interface{}{ + "count": float64(math.MinInt32 - 1), + }, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "too large parameter", + params: map[string]interface{}{ + "count": float64(math.MaxInt32 + 1), + }, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "wrong type parameter", + params: map[string]interface{}{"count": "not-a-number"}, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := RequiredInt32Param(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + +func Test_OptionalIntParam(t *testing.T) { tests := []struct { name string params map[string]interface{} @@ -256,6 +321,63 @@ func Test_OptionalNumberParam(t *testing.T) { } } +func Test_OptionalInt32Param(t *testing.T) { + tests := []struct { + name string + params map[string]interface{} + paramName string + expected int32 + expectError bool + }{ + { + name: "valid int32 parameter", + params: map[string]interface{}{"count": float64(42)}, + paramName: "count", + expected: 42, + expectError: false, + }, + { + name: "missing parameter", + params: map[string]interface{}{}, + paramName: "count", + expected: 0, + expectError: false, + }, + { + name: "too small parameter", + params: map[string]interface{}{ + "count": float64(math.MinInt32 - 1), + }, + paramName: "count", + expected: 0, + expectError: true, + }, + { + name: "too large parameter", + params: map[string]interface{}{ + "count": float64(math.MaxInt32 + 1), + }, + paramName: "count", + expected: 0, + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request := createMCPRequest(tc.params) + result, err := OptionalInt32Param(request, tc.paramName) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, result) + } + }) + } +} + func Test_OptionalNumberParamWithDefault(t *testing.T) { tests := []struct { name string diff --git a/pkg/github/tools.go b/pkg/github/tools.go index faef86ce..c18b8d74 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -7,13 +7,15 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/server" + "github.com/shurcooL/githubv4" ) type GetClientFn func(context.Context) (*github.Client, error) +type GetGQLClientFn func(context.Context) (*githubv4.Client, error) var DefaultTools = []string{"all"} -func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { +func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { // Create a new toolset group tsg := toolsets.NewToolsetGroup(readOnly) @@ -66,10 +68,15 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, AddWriteTools( toolsets.NewServerTool(MergePullRequest(getClient, t)), toolsets.NewServerTool(UpdatePullRequestBranch(getClient, t)), - toolsets.NewServerTool(CreatePullRequestReview(getClient, t)), toolsets.NewServerTool(CreatePullRequest(getClient, t)), toolsets.NewServerTool(UpdatePullRequest(getClient, t)), - toolsets.NewServerTool(AddPullRequestReviewComment(getClient, t)), + + // Reviews + toolsets.NewServerTool(CreateAndSubmitPullRequestReview(getGQLClient, t)), + toolsets.NewServerTool(CreatePendingPullRequestReview(getGQLClient, t)), + toolsets.NewServerTool(AddPullRequestReviewCommentToPendingReview(getGQLClient, t)), + toolsets.NewServerTool(SubmitPendingPullRequestReview(getGQLClient, t)), + toolsets.NewServerTool(DeletePendingPullRequestReview(getGQLClient, t)), ) codeSecurity := toolsets.NewToolset("code_security", "Code security related tools, such as GitHub Code Scanning"). AddReadTools( diff --git a/third-party-licenses.darwin.md b/third-party-licenses.darwin.md index 18c0379e..e8aab63a 100644 --- a/third-party-licenses.darwin.md +++ b/third-party-licenses.darwin.md @@ -16,6 +16,8 @@ Some packages may only be included on certain architectures or operating systems - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.27.0/LICENSE)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE)) + - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) + - [github.com/shurcooL/graphql](https://pkg.go.dev/github.com/shurcooL/graphql) ([MIT](https://github.com/shurcooL/graphql/blob/ed46e5a46466/LICENSE)) - [github.com/sirupsen/logrus](https://pkg.go.dev/github.com/sirupsen/logrus) ([MIT](https://github.com/sirupsen/logrus/blob/v1.9.3/LICENSE)) - [github.com/sourcegraph/conc](https://pkg.go.dev/github.com/sourcegraph/conc) ([MIT](https://github.com/sourcegraph/conc/blob/v0.3.0/LICENSE)) - [github.com/spf13/afero](https://pkg.go.dev/github.com/spf13/afero) ([Apache-2.0](https://github.com/spf13/afero/blob/v1.14.0/LICENSE.txt)) @@ -25,6 +27,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/spf13/viper](https://pkg.go.dev/github.com/spf13/viper) ([MIT](https://github.com/spf13/viper/blob/v1.20.1/LICENSE)) - [github.com/subosito/gotenv](https://pkg.go.dev/github.com/subosito/gotenv) ([MIT](https://github.com/subosito/gotenv/blob/v1.6.0/LICENSE)) - [github.com/yosida95/uritemplate/v3](https://pkg.go.dev/github.com/yosida95/uritemplate/v3) ([BSD-3-Clause](https://github.com/yosida95/uritemplate/blob/v3.0.2/LICENSE)) + - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.29.0:LICENSE)) - [golang.org/x/sys/unix](https://pkg.go.dev/golang.org/x/sys/unix) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.31.0:LICENSE)) - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.23.0:LICENSE)) - [gopkg.in/yaml.v3](https://pkg.go.dev/gopkg.in/yaml.v3) ([MIT](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE)) diff --git a/third-party-licenses.linux.md b/third-party-licenses.linux.md index 18c0379e..e8aab63a 100644 --- a/third-party-licenses.linux.md +++ b/third-party-licenses.linux.md @@ -16,6 +16,8 @@ Some packages may only be included on certain architectures or operating systems - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.27.0/LICENSE)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE)) + - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) + - [github.com/shurcooL/graphql](https://pkg.go.dev/github.com/shurcooL/graphql) ([MIT](https://github.com/shurcooL/graphql/blob/ed46e5a46466/LICENSE)) - [github.com/sirupsen/logrus](https://pkg.go.dev/github.com/sirupsen/logrus) ([MIT](https://github.com/sirupsen/logrus/blob/v1.9.3/LICENSE)) - [github.com/sourcegraph/conc](https://pkg.go.dev/github.com/sourcegraph/conc) ([MIT](https://github.com/sourcegraph/conc/blob/v0.3.0/LICENSE)) - [github.com/spf13/afero](https://pkg.go.dev/github.com/spf13/afero) ([Apache-2.0](https://github.com/spf13/afero/blob/v1.14.0/LICENSE.txt)) @@ -25,6 +27,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/spf13/viper](https://pkg.go.dev/github.com/spf13/viper) ([MIT](https://github.com/spf13/viper/blob/v1.20.1/LICENSE)) - [github.com/subosito/gotenv](https://pkg.go.dev/github.com/subosito/gotenv) ([MIT](https://github.com/subosito/gotenv/blob/v1.6.0/LICENSE)) - [github.com/yosida95/uritemplate/v3](https://pkg.go.dev/github.com/yosida95/uritemplate/v3) ([BSD-3-Clause](https://github.com/yosida95/uritemplate/blob/v3.0.2/LICENSE)) + - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.29.0:LICENSE)) - [golang.org/x/sys/unix](https://pkg.go.dev/golang.org/x/sys/unix) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.31.0:LICENSE)) - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.23.0:LICENSE)) - [gopkg.in/yaml.v3](https://pkg.go.dev/gopkg.in/yaml.v3) ([MIT](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE)) diff --git a/third-party-licenses.windows.md b/third-party-licenses.windows.md index 72f669db..931ef5bf 100644 --- a/third-party-licenses.windows.md +++ b/third-party-licenses.windows.md @@ -17,6 +17,8 @@ Some packages may only be included on certain architectures or operating systems - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.27.0/LICENSE)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE)) + - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) + - [github.com/shurcooL/graphql](https://pkg.go.dev/github.com/shurcooL/graphql) ([MIT](https://github.com/shurcooL/graphql/blob/ed46e5a46466/LICENSE)) - [github.com/sirupsen/logrus](https://pkg.go.dev/github.com/sirupsen/logrus) ([MIT](https://github.com/sirupsen/logrus/blob/v1.9.3/LICENSE)) - [github.com/sourcegraph/conc](https://pkg.go.dev/github.com/sourcegraph/conc) ([MIT](https://github.com/sourcegraph/conc/blob/v0.3.0/LICENSE)) - [github.com/spf13/afero](https://pkg.go.dev/github.com/spf13/afero) ([Apache-2.0](https://github.com/spf13/afero/blob/v1.14.0/LICENSE.txt)) @@ -26,6 +28,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/spf13/viper](https://pkg.go.dev/github.com/spf13/viper) ([MIT](https://github.com/spf13/viper/blob/v1.20.1/LICENSE)) - [github.com/subosito/gotenv](https://pkg.go.dev/github.com/subosito/gotenv) ([MIT](https://github.com/subosito/gotenv/blob/v1.6.0/LICENSE)) - [github.com/yosida95/uritemplate/v3](https://pkg.go.dev/github.com/yosida95/uritemplate/v3) ([BSD-3-Clause](https://github.com/yosida95/uritemplate/blob/v3.0.2/LICENSE)) + - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.29.0:LICENSE)) - [golang.org/x/sys/windows](https://pkg.go.dev/golang.org/x/sys/windows) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.31.0:LICENSE)) - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.23.0:LICENSE)) - [gopkg.in/yaml.v3](https://pkg.go.dev/gopkg.in/yaml.v3) ([MIT](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE)) diff --git a/third-party/github.com/shurcooL/githubv4/LICENSE b/third-party/github.com/shurcooL/githubv4/LICENSE new file mode 100644 index 00000000..ca4c7764 --- /dev/null +++ b/third-party/github.com/shurcooL/githubv4/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Dmitri Shuralyov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third-party/github.com/shurcooL/graphql/LICENSE b/third-party/github.com/shurcooL/graphql/LICENSE new file mode 100644 index 00000000..ca4c7764 --- /dev/null +++ b/third-party/github.com/shurcooL/graphql/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Dmitri Shuralyov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third-party/golang.org/x/oauth2/LICENSE b/third-party/golang.org/x/oauth2/LICENSE new file mode 100644 index 00000000..2a7cf70d --- /dev/null +++ b/third-party/golang.org/x/oauth2/LICENSE @@ -0,0 +1,27 @@ +Copyright 2009 The Go Authors. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google LLC nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. From f0c9d525f5aa6bd83cc6b8d3a9b32ce17d0f9e41 Mon Sep 17 00:00:00 2001 From: William Martin Date: Wed, 14 May 2025 18:10:00 +0200 Subject: [PATCH 2/2] WIP --- e2e/e2e_test.go | 66 ++++++- go.mod | 2 +- pkg/github/pullrequests.go | 293 +++++++++++++++++++++++++------- pkg/github/pullrequests_test.go | 1 + pkg/github/tools.go | 1 + 5 files changed, 292 insertions(+), 71 deletions(-) diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index c609c2c5..ac9bf106 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -1008,7 +1008,7 @@ func TestPullRequestReviewCommentSubmit(t *testing.T) { "owner": currentOwner, "repo": repoName, "path": "test-file.txt", - "content": fmt.Sprintf("Created by e2e test %s", t.Name()), + "content": fmt.Sprintf("Created by e2e test %s\nwith multiple lines", t.Name()), "message": "Add test file", "branch": "test-branch", } @@ -1065,21 +1065,62 @@ func TestPullRequestReviewCommentSubmit(t *testing.T) { require.True(t, ok, "expected content to be of type TextContent") require.Equal(t, "", textContent.Text, "expected content to be empty") - // Add a review comment - addReviewCommentRequest := mcp.CallToolRequest{} - addReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review" - addReviewCommentRequest.Params.Arguments = map[string]any{ + // Add a file review comment + addFileReviewCommentRequest := mcp.CallToolRequest{} + addFileReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review" + addFileReviewCommentRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + "path": "test-file.txt", + "subjectType": "FILE", + "body": "File review comment", + } + + t.Logf("Adding file review comment to pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, addFileReviewCommentRequest) + require.NoError(t, err, "expected to call 'add_pull_request_review_comment_to_pending_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Add a single line review comment + addSingleLineReviewCommentRequest := mcp.CallToolRequest{} + addSingleLineReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review" + addSingleLineReviewCommentRequest.Params.Arguments = map[string]any{ "owner": currentOwner, "repo": repoName, "pullNumber": 1, "path": "test-file.txt", "subjectType": "LINE", - "body": "Very nice!", + "body": "Single line review comment", "line": 1, + "side": "RIGHT", + "commitId": commitId, } - t.Logf("Adding review comment to pull request in %s/%s...", currentOwner, repoName) - resp, err = mcpClient.CallTool(ctx, addReviewCommentRequest) + t.Logf("Adding single line review comment to pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, addSingleLineReviewCommentRequest) + require.NoError(t, err, "expected to call 'add_pull_request_review_comment_to_pending_review' tool successfully") + require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) + + // Add a multiline review comment + addMultilineReviewCommentRequest := mcp.CallToolRequest{} + addMultilineReviewCommentRequest.Params.Name = "add_pull_request_review_comment_to_pending_review" + addMultilineReviewCommentRequest.Params.Arguments = map[string]any{ + "owner": currentOwner, + "repo": repoName, + "pullNumber": 1, + "path": "test-file.txt", + "subjectType": "LINE", + "body": "Multiline review comment", + "startLine": 1, + "line": 2, + "startSide": "RIGHT", + "side": "RIGHT", + "commitId": commitId, + } + + t.Logf("Adding multi line review comment to pull request in %s/%s...", currentOwner, repoName) + resp, err = mcpClient.CallTool(ctx, addMultilineReviewCommentRequest) require.NoError(t, err, "expected to call 'add_pull_request_review_comment_to_pending_review' tool successfully") require.False(t, resp.IsError, fmt.Sprintf("expected result not to be an error: %+v", resp)) @@ -1117,6 +1158,7 @@ func TestPullRequestReviewCommentSubmit(t *testing.T) { require.True(t, ok, "expected content to be of type TextContent") var reviews []struct { + ID int `json:"id"` State string `json:"state"` } err = json.Unmarshal([]byte(textContent.Text), &reviews) @@ -1125,6 +1167,13 @@ func TestPullRequestReviewCommentSubmit(t *testing.T) { // Check that there is one review require.Len(t, reviews, 1, "expected to find one review") require.Equal(t, "COMMENTED", reviews[0].State, "expected review state to be COMMENTED") + + // Check that there are three review comments + // MCP Server doesn't support this, but we can use the GitHub Client + ghClient := gogithub.NewClient(nil).WithAuthToken(getE2EToken(t)) + comments, _, err := ghClient.PullRequests.ListReviewComments(context.Background(), currentOwner, repoName, 1, int64(reviews[0].ID), nil) + require.NoError(t, err, "expected to list review comments successfully") + require.Equal(t, 3, len(comments), "expected to find three review comments") } func TestPullRequestReviewDeletion(t *testing.T) { @@ -1314,5 +1363,4 @@ func TestPullRequestReviewDeletion(t *testing.T) { err = json.Unmarshal([]byte(textContent.Text), &noReviews) require.NoError(t, err, "expected to unmarshal text content successfully") require.Len(t, noReviews, 0, "expected to find no reviews") - } diff --git a/go.mod b/go.mod index 1505656b..0d5736fb 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/fsnotify/fsnotify v1.8.0 // indirect - github.com/go-viper/mapstructure/v2 v2.2.1 // indirect + github.com/go-viper/mapstructure/v2 v2.2.1 github.com/google/go-github/v71 v71.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index b04abaf3..9a71e839 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -5,9 +5,12 @@ import ( "encoding/json" "fmt" "io" + "math" "net/http" + "reflect" "github.com/github/github-mcp-server/pkg/translations" + "github.com/go-viper/mapstructure/v2" "github.com/google/go-github/v69/github" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -1074,10 +1077,148 @@ func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } } +// Don't look too closely at this, I'm trying out a pattern for more complex param parsing. +type addPullRequestReviewCommentToPendingReviewParams struct { + Owner string + Repo string + PullNumber int32 + Path string + Body string + SubjectType string + Line *int32 + Side *string + StartLine *int32 + StartSide *string +} + +func (p *addPullRequestReviewCommentToPendingReviewParams) parse(args map[string]any) error { + requiredFields := []string{"owner", "repo", "pullNumber", "path", "body", "subjectType"} + for _, field := range requiredFields { + if _, ok := args[field]; !ok { + return fmt.Errorf("missing required field: %s", field) + } + } + + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + TagName: "json", + Result: p, + DecodeHook: composeHooks(requireWholeNumberHook, int32BoundsHook), + }) + if err != nil { + return fmt.Errorf("error creating decoder: %w", err) + } + + if err := decoder.Decode(args); err != nil { + return fmt.Errorf("error decoding input: %w", err) + } + + if p.PullNumber < 0 { + return fmt.Errorf("pull number must be positive") + } + + hasLineInfo := p.Line != nil || p.StartLine != nil || p.Side != nil || p.StartSide != nil + if p.SubjectType != "LINE" && hasLineInfo { + return fmt.Errorf("line numbers or sides can only be used with LINE subject type") + } + + hasNoLineNumberOrSide := p.Line == nil || p.Side == nil + if p.SubjectType == "LINE" && hasNoLineNumberOrSide { + return fmt.Errorf("at least a line number and side must be provided for LINE subject type") + } + + if p.Line != nil && *p.Line < 0 { + return fmt.Errorf("line number must be positive") + } + + if p.StartLine != nil && *p.StartLine < 0 { + return fmt.Errorf("start line number must be positive") + } + + return nil +} + +var requireWholeNumberHook mapstructure.DecodeHookFuncType = func(from reflect.Type, to reflect.Type, data any) (any, error) { + // Only care about source float64s + if from.Kind() != reflect.Float64 { + return data, nil + } + + // Unwrap pointers if necessary + target := to + if to.Kind() == reflect.Ptr { + target = to.Elem() + } + + switch target.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v := data.(float64) + if v != float64(int64(v)) { + return nil, fmt.Errorf("value %v is not a whole number", v) + } + } + + return data, nil +} + +var int32BoundsHook mapstructure.DecodeHookFuncType = func(_ reflect.Type, to reflect.Type, data any) (any, error) { + var isPtr bool + var targetType reflect.Type + + if to.Kind() == reflect.Ptr { + isPtr = true + targetType = to.Elem() + } else { + targetType = to + } + + // Only handle int32 (or *int32) targets + if targetType != reflect.TypeOf(int32(0)) { + return data, nil + } + + var val float64 + switch v := data.(type) { + case float64: + if v != float64(int64(v)) { + return nil, fmt.Errorf("value %v is not a whole number", v) + } + val = v + case int: + val = float64(v) + default: + return data, nil // Not a numeric type we handle + } + + if val < math.MinInt32 || val > math.MaxInt32 { + return nil, fmt.Errorf("value %v is out of int32 bounds", val) + } + + // Convert and optionally return pointer + i := int32(val) + if isPtr { + return &i, nil + } + return i, nil +} + +func composeHooks(hooks ...mapstructure.DecodeHookFuncType) mapstructure.DecodeHookFunc { + return func(f reflect.Type, t reflect.Type, data any) (any, error) { + var err error + for _, hook := range hooks { + data, err = hook(f, t, data) + if err != nil { + return nil, err + } + } + return data, nil + } +} + // AddPullRequestReviewCommentToPendingReview creates a tool to add a comment to a pull request review. func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { return mcp.NewTool("add_pull_request_review_comment_to_pending_review", - mcp.WithDescription(t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add a comment to the requester's latest pending pull request review, a pending review needs to already exist to call this (check with the user if not sure).")), + mcp.WithDescription(t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_TO_PENDING_REVIEW_DESCRIPTION", "Add a comment to the requester's latest pending pull request review, a pending review needs to already exist to call this (check with the user if not sure). If you are using the LINE subjectType, consider getting a diff of the Pull Request to be certain of line numbers.")), mcp.WithToolAnnotation(mcp.ToolAnnotation{ Title: t("TOOL_ADD_PULL_REQUEST_REVIEW_COMMENT_TO_PENDING_REVIEW_USER_TITLE", "Add comment to the requester's latest pending pull request review"), ReadOnlyHint: toBoolPtr(false), @@ -1120,65 +1261,20 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t mcp.Description("The line of the blob in the pull request diff that the comment applies to. For multi-line comments, the last line of the range"), ), mcp.WithString("side", - mcp.Description("The side of the diff to comment on"), + mcp.Description("The side of the diff to comment on. LEFT indicates the previous state, RIGHT indicates the new state"), mcp.Enum("LEFT", "RIGHT"), ), mcp.WithNumber("startLine", mcp.Description("For multi-line comments, the first line of the range that the comment applies to"), ), mcp.WithString("startSide", - mcp.Description("For multi-line comments, the starting side of the diff that the comment applies to"), + mcp.Description("For multi-line comments, the starting side of the diff that the comment applies to. LEFT indicates the previous state, RIGHT indicates the new state"), mcp.Enum("LEFT", "RIGHT"), ), ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - owner, err := requiredParam[string](request, "owner") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - repo, err := requiredParam[string](request, "repo") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - pullNumber, err := RequiredInt32Param(request, "pullNumber") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - path, err := requiredParam[string](request, "path") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - body, err := requiredParam[string](request, "body") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - subjectType, err := requiredParam[string](request, "subjectType") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - line, err := OptionalInt32Param(request, "line") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - side, err := OptionalParam[string](request, "side") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - startLine, err := OptionalInt32Param(request, "startLine") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil - } - - startSide, err := OptionalParam[string](request, "startSide") - if err != nil { + var params addPullRequestReviewCommentToPendingReviewParams + if err := params.parse(request.Params.Arguments); err != nil { return mcp.NewToolResultError(err.Error()), nil } @@ -1213,11 +1309,11 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t } `graphql:"repository(owner: $owner, name: $name)"` } - vars := map[string]interface{}{ + vars := map[string]any{ "author": githubv4.String(getViewerQuery.Viewer.Login), - "owner": githubv4.String(owner), - "name": githubv4.String(repo), - "number": githubv4.Int(pullNumber), + "owner": githubv4.String(params.Owner), + "name": githubv4.String(params.Repo), + "number": githubv4.Int(params.PullNumber), } if err := client.Query(context.Background(), &getLatestReviewForViewerQuery, vars); err != nil { @@ -1248,13 +1344,13 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t ctx, &addPullRequestReviewThreadMutation, githubv4.AddPullRequestReviewThreadInput{ - Path: githubv4.String(path), - Body: githubv4.String(body), - SubjectType: newGQLStringlike[githubv4.PullRequestReviewThreadSubjectType](subjectType), - Line: githubv4.NewInt(githubv4.Int(line)), - Side: newGQLStringlike[githubv4.DiffSide](side), - StartLine: githubv4.NewInt(githubv4.Int(startLine)), - StartSide: newGQLStringlike[githubv4.DiffSide](startSide), + Path: githubv4.String(params.Path), + Body: githubv4.String(params.Body), + SubjectType: newGQLStringlikePtr[githubv4.PullRequestReviewThreadSubjectType](¶ms.SubjectType), + Line: newGQLIntPtr(params.Line), + Side: newGQLStringlikePtr[githubv4.DiffSide](params.Side), + StartLine: newGQLIntPtr(params.StartLine), + StartSide: newGQLStringlikePtr[githubv4.DiffSide](params.StartSide), PullRequestReviewID: &review.ID, }, nil, @@ -1544,6 +1640,65 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } } +func GetPullRequestDiff(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, server.ToolHandlerFunc) { + return mcp.NewTool("get_pull_request_diff", + mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_DIFF_DESCRIPTION", "Get the diff of a pull request.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_GET_PULL_REQUEST_DIFF_USER_TITLE", "Get pull request diff"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("Repository owner"), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("Repository name"), + ), + mcp.WithNumber("pullNumber", + mcp.Required(), + mcp.Description("Pull request number"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pullNumber, err := RequiredInt(request, "pullNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + raw, resp, err := client.PullRequests.GetRaw(ctx, owner, repo, pullNumber, github.RawOptions{Type: github.Diff}) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to get pull request diff: %s", string(body))), nil + } + + defer func() { _ = resp.Body.Close() }() + + // Return the raw response + return mcp.NewToolResultText(string(raw)), nil + } +} + // newGQLString like takes something that approximates a string (of which there are many types in shurcooL/githubv4) // and constructs a pointer to it, or nil if the string is empty. This is extremely useful because when we parse // params from the MCP request, we need to convert them to types that are pointers of type def strings and it's @@ -1555,3 +1710,19 @@ func newGQLStringlike[T ~string](s string) *T { stringlike := T(s) return &stringlike } + +func newGQLStringlikePtr[T ~string](s *string) *T { + if s == nil { + return nil + } + stringlike := T(*s) + return &stringlike +} + +func newGQLIntPtr(i *int32) *githubv4.Int { + if i == nil { + return nil + } + gi := githubv4.Int(*i) + return &gi +} diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index dad1c226..17cc343a 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -9,6 +9,7 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/pkg/github/tools.go b/pkg/github/tools.go index c18b8d74..06bdbe58 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -64,6 +64,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, toolsets.NewServerTool(GetPullRequestStatus(getClient, t)), toolsets.NewServerTool(GetPullRequestComments(getClient, t)), toolsets.NewServerTool(GetPullRequestReviews(getClient, t)), + toolsets.NewServerTool(GetPullRequestDiff(getClient, t)), ). AddWriteTools( toolsets.NewServerTool(MergePullRequest(getClient, t)),