From e0cde718c1f9e4f303fde05b8992382824a3e1c9 Mon Sep 17 00:00:00 2001
From: Martina Jireckova <martinajir@github.com>
Date: Wed, 16 Apr 2025 08:21:45 +0000
Subject: [PATCH 1/4] WIP: List notifications tool

---
 pkg/github/notifications.go      |  92 ++++++++++++++++++++++
 pkg/github/notifications_test.go | 127 +++++++++++++++++++++++++++++++
 pkg/github/server.go             |  14 ++++
 pkg/github/tools.go              |   6 ++
 4 files changed, 239 insertions(+)
 create mode 100644 pkg/github/notifications.go
 create mode 100644 pkg/github/notifications_test.go

diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go
new file mode 100644
index 00000000..eb664b9b
--- /dev/null
+++ b/pkg/github/notifications.go
@@ -0,0 +1,92 @@
+package github
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/github/github-mcp-server/pkg/translations"
+	"github.com/google/go-github/v69/github"
+	"github.com/mark3labs/mcp-go/mcp"
+	"github.com/mark3labs/mcp-go/server"
+)
+
+// ListNotifications creates a tool to list notifications for a GitHub user.
+func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+	return mcp.NewTool("list_notifications",
+			mcp.WithDescription(t("TOOL_LIST_NOTIFICATIONS_DESCRIPTION", "List notifications for a GitHub user")),
+			mcp.WithNumber("page",
+				mcp.Description("Page number"),
+			),
+			mcp.WithNumber("per_page",
+				mcp.Description("Number of records per page"),
+			),
+			mcp.WithBoolean("all",
+				mcp.Description("Whether to fetch all notifications, including read ones"),
+			),
+		),
+		func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+			page, err := OptionalIntParamWithDefault(request, "page", 1)
+			if err != nil {
+				return mcp.NewToolResultError(err.Error()), nil
+			}
+			perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
+			if err != nil {
+				return mcp.NewToolResultError(err.Error()), nil
+			}
+			all, err := OptionalBoolParamWithDefault(request, "all", false) // Default to false unless specified
+			if err != nil {
+				return mcp.NewToolResultError(err.Error()), nil
+			}
+
+			if request.Params.Arguments["all"] == true {
+				all = true // Set to true if user explicitly asks for all notifications
+			}
+
+			opts := &github.NotificationListOptions{
+				ListOptions: github.ListOptions{
+					Page:    page,
+					PerPage: perPage,
+				},
+				All: all, // Include all notifications, even those already read.
+			}
+
+			client, err := getClient(ctx)
+			if err != nil {
+				return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+			}
+			notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
+			if err != nil {
+				return nil, fmt.Errorf("failed to list notifications: %w", err)
+			}
+			defer func() { _ = resp.Body.Close() }()
+
+			if resp.StatusCode != http.StatusOK {
+				body, err := io.ReadAll(resp.Body)
+				if err != nil {
+					return nil, fmt.Errorf("failed to read response body: %w", err)
+				}
+				return mcp.NewToolResultError(fmt.Sprintf("failed to list notifications: %s", string(body))), nil
+			}
+
+			// Extract the notification title in addition to reason, url, and timestamp.
+			var extractedNotifications []map[string]interface{}
+			for _, notification := range notifications {
+				extractedNotifications = append(extractedNotifications, map[string]interface{}{
+					"title":     notification.GetSubject().GetTitle(),
+					"reason":    notification.GetReason(),
+					"url":       notification.GetURL(),
+					"timestamp": notification.GetUpdatedAt(),
+				})
+			}
+
+			r, err := json.Marshal(extractedNotifications)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal notifications: %w", err)
+			}
+
+			return mcp.NewToolResultText(string(r)), nil
+		}
+}
diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go
new file mode 100644
index 00000000..2d663c7a
--- /dev/null
+++ b/pkg/github/notifications_test.go
@@ -0,0 +1,127 @@
+package github
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"testing"
+	"time"
+
+	"github.com/github/github-mcp-server/pkg/translations"
+	"github.com/google/go-github/v69/github"
+	"github.com/migueleliasweb/go-github-mock/src/mock"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func Test_ListNotifications(t *testing.T) {
+	// Verify tool definition
+	mockClient := github.NewClient(nil)
+	tool, _ := ListNotifications(stubGetClientFn(mockClient), translations.NullTranslationHelper)
+
+	assert.Equal(t, "list_notifications", tool.Name)
+	assert.NotEmpty(t, tool.Description)
+	assert.Contains(t, tool.InputSchema.Properties, "page")
+	assert.Contains(t, tool.InputSchema.Properties, "per_page")
+	assert.Contains(t, tool.InputSchema.Properties, "all")
+
+	// Setup mock notifications
+	mockNotifications := []*github.Notification{
+		{
+			ID:     github.String("1"),
+			Reason: github.String("mention"),
+			Subject: &github.NotificationSubject{
+				Title: github.String("Test Notification 1"),
+			},
+			UpdatedAt: &github.Timestamp{Time: time.Now()},
+			URL:       github.String("https://example.com/notifications/threads/1"),
+		},
+		{
+			ID:     github.String("2"),
+			Reason: github.String("team_mention"),
+			Subject: &github.NotificationSubject{
+				Title: github.String("Test Notification 2"),
+			},
+			UpdatedAt: &github.Timestamp{Time: time.Now()},
+			URL:       github.String("https://example.com/notifications/threads/1"),
+		},
+	}
+
+	tests := []struct {
+		name             string
+		mockedClient     *http.Client
+		requestArgs      map[string]interface{}
+		expectError      bool
+		expectedResponse []*github.Notification
+		expectedErrMsg   string
+	}{
+		{
+			name: "list all notifications",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatch(
+					mock.GetNotifications,
+					mockNotifications,
+				),
+			),
+			requestArgs: map[string]interface{}{
+				"all": true,
+			},
+			expectError:      false,
+			expectedResponse: mockNotifications,
+		},
+		{
+			name: "list unread notifications",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatch(
+					mock.GetNotifications,
+					mockNotifications[:1], // Only the first notification
+				),
+			),
+			requestArgs: map[string]interface{}{
+				"all": false,
+			},
+			expectError:      false,
+			expectedResponse: mockNotifications[:1],
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			// Setup client with mock
+			client := github.NewClient(tc.mockedClient)
+			_, handler := ListNotifications(stubGetClientFn(client), translations.NullTranslationHelper)
+
+			// Create call request
+			request := createMCPRequest(tc.requestArgs)
+			// Call handler
+			result, err := handler(context.Background(), request)
+
+			// Verify results
+			if tc.expectError {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tc.expectedErrMsg)
+				return
+			}
+
+			require.NoError(t, err)
+			textContent := getTextResult(t, result)
+
+			// Unmarshal and verify the result
+			var returnedNotifications []*github.Notification
+			err = json.Unmarshal([]byte(textContent.Text), &returnedNotifications)
+			require.NoError(t, err)
+			assert.Equal(t, len(tc.expectedResponse), len(returnedNotifications))
+			for i, notification := range returnedNotifications {
+				// Ensure all required fields are mocked
+				assert.NotNil(t, notification.Subject, "Subject should not be nil")
+				assert.NotNil(t, notification.Subject.Title, "Title should not be nil")
+				assert.NotNil(t, notification.Reason, "Reason should not be nil")
+				assert.NotNil(t, notification.URL, "URL should not be nil")
+				assert.NotNil(t, notification.UpdatedAt, "UpdatedAt should not be nil")
+				// assert.Equal(t, *tc.expectedResponse[i].ID, *notification.ID)
+				assert.Equal(t, *tc.expectedResponse[i].Reason, *notification.Reason)
+				// assert.Equal(t, *tc.expectedResponse[i].Subject.Title, *notification.Subject.Title)
+			}
+		})
+	}
+}
diff --git a/pkg/github/server.go b/pkg/github/server.go
index e4c24171..b2413f54 100644
--- a/pkg/github/server.go
+++ b/pkg/github/server.go
@@ -130,6 +130,20 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) {
 	return int(v), nil
 }
 
+// OptionalBoolParamWithDefault is a helper function that retrieves a boolean parameter from the request.
+// If the parameter is not present, it returns the provided default value. If the parameter is present,
+// it validates its type and returns the value.
+func OptionalBoolParamWithDefault(request mcp.CallToolRequest, s string, b bool) (bool, error) {
+	v, err := OptionalParam[bool](request, s)
+	if err != nil {
+		return false, err
+	}
+	if b == false {
+		return b, nil
+	}
+	return v, nil
+}
+
 // OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
 // similar to optionalIntParam, but it also takes a default value.
 func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {
diff --git a/pkg/github/tools.go b/pkg/github/tools.go
index ce10c4ad..4ac5d3bc 100644
--- a/pkg/github/tools.go
+++ b/pkg/github/tools.go
@@ -76,6 +76,11 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
 	// Keep experiments alive so the system doesn't error out when it's always enabled
 	experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet")
 
+	notifications := toolsets.NewToolset("notifications", "GitHub Notifications related tools").
+		AddReadTools(
+			toolsets.NewServerTool(ListNotifications(getClient, t)),
+		)
+
 	// Add toolsets to the group
 	tsg.AddToolset(repos)
 	tsg.AddToolset(issues)
@@ -83,6 +88,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
 	tsg.AddToolset(pullRequests)
 	tsg.AddToolset(codeSecurity)
 	tsg.AddToolset(experiments)
+	tsg.AddToolset(notifications)
 	// Enable the requested features
 
 	if err := tsg.EnableToolsets(passedToolsets); err != nil {

From 2e52386ba5c8d9234de80c6caf6610b3d41bf8ef Mon Sep 17 00:00:00 2001
From: Martina Jireckova <martinajir@github.com>
Date: Tue, 22 Apr 2025 13:01:04 +0000
Subject: [PATCH 2/4] Improve testing and mapping code in request and response

---
 README.md                        |  8 ++++++++
 pkg/github/notifications.go      | 17 +----------------
 pkg/github/notifications_test.go | 26 ++++++++++----------------
 3 files changed, 19 insertions(+), 32 deletions(-)

diff --git a/README.md b/README.md
index 6bfc6ab5..4d04f6fc 100644
--- a/README.md
+++ b/README.md
@@ -437,6 +437,14 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description
   - `state`: Alert state (string, optional)
   - `severity`: Alert severity (string, optional)
 
+### Notifications
+
+- **list_notifications** - List notifications for a GitHub user
+
+  - `page`: Page number (number, optional, default: 1)
+  - `per_page`: Number of records per page (number, optional, default: 30)
+  - `all`: Whether to fetch all notifications, including read ones (boolean, optional, default: false)
+
 ## Resources
 
 ### Repository Content
diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go
index eb664b9b..10a98be2 100644
--- a/pkg/github/notifications.go
+++ b/pkg/github/notifications.go
@@ -41,10 +41,6 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu
 				return mcp.NewToolResultError(err.Error()), nil
 			}
 
-			if request.Params.Arguments["all"] == true {
-				all = true // Set to true if user explicitly asks for all notifications
-			}
-
 			opts := &github.NotificationListOptions{
 				ListOptions: github.ListOptions{
 					Page:    page,
@@ -71,18 +67,7 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu
 				return mcp.NewToolResultError(fmt.Sprintf("failed to list notifications: %s", string(body))), nil
 			}
 
-			// Extract the notification title in addition to reason, url, and timestamp.
-			var extractedNotifications []map[string]interface{}
-			for _, notification := range notifications {
-				extractedNotifications = append(extractedNotifications, map[string]interface{}{
-					"title":     notification.GetSubject().GetTitle(),
-					"reason":    notification.GetReason(),
-					"url":       notification.GetURL(),
-					"timestamp": notification.GetUpdatedAt(),
-				})
-			}
-
-			r, err := json.Marshal(extractedNotifications)
+			r, err := json.Marshal(notifications)
 			if err != nil {
 				return nil, fmt.Errorf("failed to marshal notifications: %w", err)
 			}
diff --git a/pkg/github/notifications_test.go b/pkg/github/notifications_test.go
index 2d663c7a..20c7967b 100644
--- a/pkg/github/notifications_test.go
+++ b/pkg/github/notifications_test.go
@@ -28,22 +28,22 @@ func Test_ListNotifications(t *testing.T) {
 	// Setup mock notifications
 	mockNotifications := []*github.Notification{
 		{
-			ID:     github.String("1"),
-			Reason: github.String("mention"),
+			ID:     github.Ptr("1"),
+			Reason: github.Ptr("mention"),
 			Subject: &github.NotificationSubject{
-				Title: github.String("Test Notification 1"),
+				Title: github.Ptr("Test Notification 1"),
 			},
 			UpdatedAt: &github.Timestamp{Time: time.Now()},
-			URL:       github.String("https://example.com/notifications/threads/1"),
+			URL:       github.Ptr("https://example.com/notifications/threads/1"),
 		},
 		{
-			ID:     github.String("2"),
-			Reason: github.String("team_mention"),
+			ID:     github.Ptr("2"),
+			Reason: github.Ptr("team_mention"),
 			Subject: &github.NotificationSubject{
-				Title: github.String("Test Notification 2"),
+				Title: github.Ptr("Test Notification 2"),
 			},
 			UpdatedAt: &github.Timestamp{Time: time.Now()},
-			URL:       github.String("https://example.com/notifications/threads/1"),
+			URL:       github.Ptr("https://example.com/notifications/threads/1"),
 		},
 	}
 
@@ -112,15 +112,9 @@ func Test_ListNotifications(t *testing.T) {
 			require.NoError(t, err)
 			assert.Equal(t, len(tc.expectedResponse), len(returnedNotifications))
 			for i, notification := range returnedNotifications {
-				// Ensure all required fields are mocked
-				assert.NotNil(t, notification.Subject, "Subject should not be nil")
-				assert.NotNil(t, notification.Subject.Title, "Title should not be nil")
-				assert.NotNil(t, notification.Reason, "Reason should not be nil")
-				assert.NotNil(t, notification.URL, "URL should not be nil")
-				assert.NotNil(t, notification.UpdatedAt, "UpdatedAt should not be nil")
-				// assert.Equal(t, *tc.expectedResponse[i].ID, *notification.ID)
+				assert.Equal(t, *tc.expectedResponse[i].ID, *notification.ID)
 				assert.Equal(t, *tc.expectedResponse[i].Reason, *notification.Reason)
-				// assert.Equal(t, *tc.expectedResponse[i].Subject.Title, *notification.Subject.Title)
+				assert.Equal(t, *tc.expectedResponse[i].Subject.Title, *notification.Subject.Title)
 			}
 		})
 	}

From ea562ae82a21128777ee70e3d8e06f76d7836505 Mon Sep 17 00:00:00 2001
From: Martina Jireckova <martinajir@github.com>
Date: Tue, 22 Apr 2025 13:31:28 +0000
Subject: [PATCH 3/4] Fix lint

---
 pkg/github/server.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pkg/github/server.go b/pkg/github/server.go
index b2413f54..91328499 100644
--- a/pkg/github/server.go
+++ b/pkg/github/server.go
@@ -138,7 +138,7 @@ func OptionalBoolParamWithDefault(request mcp.CallToolRequest, s string, b bool)
 	if err != nil {
 		return false, err
 	}
-	if b == false {
+	if !b {
 		return b, nil
 	}
 	return v, nil

From 81f093e7be6c95f38530f7cf36188f425b4a33ca Mon Sep 17 00:00:00 2001
From: Martina Jireckova <martinajir@github.com>
Date: Tue, 22 Apr 2025 15:36:43 +0000
Subject: [PATCH 4/4] Inline optional bool function for all notifications

---
 pkg/github/notifications.go |  6 +++---
 pkg/github/server.go        | 14 --------------
 2 files changed, 3 insertions(+), 17 deletions(-)

diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go
index 10a98be2..512452d2 100644
--- a/pkg/github/notifications.go
+++ b/pkg/github/notifications.go
@@ -36,9 +36,9 @@ func ListNotifications(getClient GetClientFn, t translations.TranslationHelperFu
 			if err != nil {
 				return mcp.NewToolResultError(err.Error()), nil
 			}
-			all, err := OptionalBoolParamWithDefault(request, "all", false) // Default to false unless specified
-			if err != nil {
-				return mcp.NewToolResultError(err.Error()), nil
+			all := false
+			if val, err := OptionalParam[bool](request, "all"); err == nil {
+				all = val
 			}
 
 			opts := &github.NotificationListOptions{
diff --git a/pkg/github/server.go b/pkg/github/server.go
index 91328499..e4c24171 100644
--- a/pkg/github/server.go
+++ b/pkg/github/server.go
@@ -130,20 +130,6 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) {
 	return int(v), nil
 }
 
-// OptionalBoolParamWithDefault is a helper function that retrieves a boolean parameter from the request.
-// If the parameter is not present, it returns the provided default value. If the parameter is present,
-// it validates its type and returns the value.
-func OptionalBoolParamWithDefault(request mcp.CallToolRequest, s string, b bool) (bool, error) {
-	v, err := OptionalParam[bool](request, s)
-	if err != nil {
-		return false, err
-	}
-	if !b {
-		return b, nil
-	}
-	return v, nil
-}
-
 // OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
 // similar to optionalIntParam, but it also takes a default value.
 func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {