diff --git a/README.md b/README.md index ec8018a0..c869d191 100644 --- a/README.md +++ b/README.md @@ -408,6 +408,16 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description - `page`: Page number (number, optional) - `perPage`: Results per page (number, optional) +### Organizations + +- **list_repositories** - Get list of repositories in a GitHub organization. + - `org`: Organization name (string, required) + - `type`: Type of repositories to list (string, optional) + - `sort`: Sort field (string, optional) + - `direction`: Sort order (string, optional) + - `page`: Page number (number, optional) + - `perPage`: Results per page (number, optional) + ### Code Scanning - **get_code_scanning_alert** - Get a code scanning alert diff --git a/pkg/github/organizations.go b/pkg/github/organizations.go new file mode 100644 index 00000000..23679746 --- /dev/null +++ b/pkg/github/organizations.go @@ -0,0 +1,104 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "io" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// ListCommits creates a tool to get commits of a branch in a repository. +func ListRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_repositories", + mcp.WithDescription(t("TOOL_LIST_REPOSITORIES_DESCRIPTION", "Get list of repositories in a GitHub organization")), + mcp.WithString("org", + mcp.Required(), + mcp.Description("Organization name"), + ), + mcp.WithString("type", + mcp.Description("Type of repositories to list."), + mcp.Enum("all", "public", "private", "forks", "sources", "member"), + mcp.DefaultString("all"), + ), + mcp.WithString("sort", + mcp.Description("How to sort the repository list."), + mcp.Enum("created", "updated", "pushed", "full_name"), + mcp.DefaultString("created"), + ), + mcp.WithString("direction", + mcp.Description("Direction in which to sort repositories. Default when using full_name: asc; otherwise desc."), + mcp.Enum("asc", "desc"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + org, err := requiredParam[string](request, "org") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + pagination, err := OptionalPaginationParams(request) + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + opts := &github.RepositoryListByOrgOptions{ + ListOptions: github.ListOptions{ + Page: pagination.page, + PerPage: pagination.perPage, + }, + } + + repoType, err := OptionalParam[string](request, "type") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if repoType != "" { + opts.Type = repoType + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if sort != "" { + opts.Sort = sort + } + direction, err := OptionalParam[string](request, "direction") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + if direction != "" { + opts.Direction = direction + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + repos, resp, err := client.Repositories.ListByOrg(ctx, org, opts) + if err != nil { + return nil, fmt.Errorf("failed to list repositories: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list repositories: %s", string(body))), nil + } + + r, err := json.Marshal(repos) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/organizations_test.go b/pkg/github/organizations_test.go new file mode 100644 index 00000000..d3351f1c --- /dev/null +++ b/pkg/github/organizations_test.go @@ -0,0 +1,209 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v69/github" + "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ListRepositories(t *testing.T) { + // Verify tool definition once + mockClient := github.NewClient(nil) + tool, _ := ListRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_repositories", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "org") + assert.Contains(t, tool.InputSchema.Properties, "type") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "direction") + assert.Contains(t, tool.InputSchema.Properties, "perPage") + assert.Contains(t, tool.InputSchema.Properties, "page") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"org"}) + + // Setup mock repos for success case + mockRepos := []*github.Repository{ + { + ID: github.Ptr(int64(1001)), + Name: github.Ptr("repo1"), + FullName: github.Ptr("testorg/repo1"), + Description: github.Ptr("Test repo 1"), + HTMLURL: github.Ptr("https://github.com/testorg/repo1"), + Private: github.Ptr(false), + Fork: github.Ptr(false), + }, + { + ID: github.Ptr(int64(1002)), + Name: github.Ptr("repo2"), + FullName: github.Ptr("testorg/repo2"), + Description: github.Ptr("Test repo 2"), + HTMLURL: github.Ptr("https://github.com/testorg/repo2"), + Private: github.Ptr(true), + Fork: github.Ptr(false), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedRepos []*github.Repository + expectedErrMsg string + }{ + { + name: "successful repositories listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsReposByOrg, + expectQueryParams(t, map[string]string{ + "type": "all", + "sort": "created", + "direction": "desc", + "per_page": "30", + "page": "1", + }).andThen( + mockResponse(t, http.StatusOK, mockRepos), + ), + ), + ), + requestArgs: map[string]interface{}{ + "org": "testorg", + "type": "all", + "sort": "created", + "direction": "desc", + "perPage": float64(30), + "page": float64(1), + }, + expectError: false, + expectedRepos: mockRepos, + }, + { + name: "successful repos listing with defaults", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsReposByOrg, + expectQueryParams(t, map[string]string{ + "per_page": "30", + "page": "1", + }).andThen( + mockResponse(t, http.StatusOK, mockRepos), + ), + ), + ), + requestArgs: map[string]interface{}{ + "org": "testorg", + // Using defaults for other parameters + }, + expectError: false, + expectedRepos: mockRepos, + }, + { + name: "custom pagination and filtering", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsReposByOrg, + expectQueryParams(t, map[string]string{ + "type": "public", + "sort": "updated", + "direction": "asc", + "per_page": "10", + "page": "2", + }).andThen( + mockResponse(t, http.StatusOK, mockRepos), + ), + ), + ), + requestArgs: map[string]interface{}{ + "org": "testorg", + "type": "public", + "sort": "updated", + "direction": "asc", + "perPage": float64(10), + "page": float64(2), + }, + expectError: false, + expectedRepos: mockRepos, + }, + { + name: "API error response", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsReposByOrg, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message": "Not Found"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "org": "nonexistentorg", + }, + expectError: true, + expectedErrMsg: "failed to list repositories", + }, + { + name: "rate limit exceeded", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsReposByOrg, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message": "API rate limit exceeded"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "org": "testorg", + }, + expectError: true, + expectedErrMsg: "failed to list repositories", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Setup client with mock + client := github.NewClient(tc.mockedClient) + _, handler := ListRepositories(stubGetClientFn(client), translations.NullTranslationHelper) + + // Create call request + request := createMCPRequest(tc.requestArgs) + + // Call handler + result, err := handler(context.Background(), request) + + // Verify results + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + + // Parse the result and get the text content if no error + textContent := getTextResult(t, result) + + // Unmarshal and verify the result + var returnedRepos []*github.Repository + err = json.Unmarshal([]byte(textContent.Text), &returnedRepos) + require.NoError(t, err) + assert.Len(t, returnedRepos, len(tc.expectedRepos)) + for i, repo := range returnedRepos { + assert.Equal(t, *tc.expectedRepos[i].ID, *repo.ID) + assert.Equal(t, *tc.expectedRepos[i].Name, *repo.Name) + assert.Equal(t, *tc.expectedRepos[i].FullName, *repo.FullName) + assert.Equal(t, *tc.expectedRepos[i].Private, *repo.Private) + assert.Equal(t, *tc.expectedRepos[i].HTMLURL, *repo.HTMLURL) + } + }) + } +} diff --git a/pkg/github/server.go b/pkg/github/server.go index da916b98..a4aefa53 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -80,6 +80,9 @@ func NewServer(getClient GetClientFn, version string, readOnly bool, t translati s.AddTool(PushFiles(getClient, t)) } + // Add GitHub tools - Organizations + s.AddTool(ListRepositories(getClient, t)) + // Add GitHub tools - Search s.AddTool(SearchCode(getClient, t)) s.AddTool(SearchUsers(getClient, t)) @@ -179,7 +182,6 @@ func requiredParam[T comparable](r mcp.CallToolRequest, p string) (T, error) { if r.Params.Arguments[p].(T) == zero { return zero, fmt.Errorf("missing required parameter: %s", p) - } return r.Params.Arguments[p].(T), nil