Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions cmd/mcp-http/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"os/signal"
"regexp"
"slices"
"strings"
"syscall"
"time"
Expand Down Expand Up @@ -173,13 +174,41 @@ func tracerMiddleware(resources config.Resources, next http.Handler) http.Handle
}

func authMiddleware(resources config.Resources, next http.Handler) http.Handler {
whitelistEndpoints := map[string][]string{
// health checks don't require authentication
"/api/health": {http.MethodGet, http.MethodOptions},

// allow some protocol methods to bypass authentication
//
// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#listing-tools
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#listing-resources
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#resource-templates
// https://modelcontextprotocol.io/specification/2025-06-18/server/prompts#listing-prompts
"/": {http.MethodPost},
"/tools/list": {http.MethodPost},
"/resources/list": {http.MethodPost},
"/resources/templates/list": {http.MethodPost},
"/prompts/list": {http.MethodPost},
}

whitelistPrefixEndpoints := map[string][]string{
// OAuth2 endpoints cannot require authentication
"/.well-known": {"GET", "OPTIONS"},
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// some endpoints don't require auth
if (r.URL.Path == "/api/health" || strings.HasPrefix(r.URL.Path, "/.well-known")) &&
(r.Method == http.MethodGet || r.Method == http.MethodOptions) {
if methods, ok := whitelistEndpoints[r.URL.Path]; ok && slices.Contains(methods, r.Method) {
next.ServeHTTP(w, r)
return
}
for prefix, methods := range whitelistPrefixEndpoints {
if strings.HasPrefix(r.URL.Path, prefix) && slices.Contains(methods, r.Method) {
next.ServeHTTP(w, r)
return
}
}

requestLogger := resources.Logger().With(
slog.String("method", r.Method),
Expand Down
103 changes: 80 additions & 23 deletions cmd/mcp-stdio/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"log/slog"
"os"
"slices"
"strings"

"github.com/mark3labs/mcp-go/mcp"
Expand All @@ -20,7 +21,23 @@ import (
)

var (
methods = methodsInput([]toolsets.Method{toolsets.MethodAll})
methods = methodsInput([]toolsets.Method{toolsets.MethodAll})
methodsWhitelist = []string{
// allow some protocol methods to bypass authentication
//
// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#listing-tools
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#listing-resources
// https://modelcontextprotocol.io/specification/2025-06-18/server/resources#resource-templates
// https://modelcontextprotocol.io/specification/2025-06-18/server/prompts#listing-prompts
"initialize",
"notifications/initialized",
"logging/setLevel",
"tools/list",
"resources/list",
"resources/templates/list",
"prompts/list",
}
readOnly bool
)

Expand All @@ -34,33 +51,31 @@ func main() {
flag.BoolVar(&readOnly, "read-only", false, "Restrict the server to read-only operations")
flag.Parse()

if resources.Info.BearerToken == "" {
mcpError(resources, errors.New("TW_MCP_BEARER_TOKEN environment variable is not set"), mcp.INVALID_PARAMS)
exit(exitCodeSetupFailure)
}

ctx := context.Background()

// detect the installation from the bearer token
info, err := auth.GetBearerInfo(ctx, resources, resources.Info.BearerToken)
if err != nil {
mcpError(resources, fmt.Errorf("failed to authenticate: %s", err), mcp.INVALID_PARAMS)
exit(exitCodeSetupFailure)
}
if resources.Info.BearerToken != "" {
// detect the installation from the bearer token
info, err := auth.GetBearerInfo(ctx, resources, resources.Info.BearerToken)
if err != nil {
mcpError(resources.Logger(), fmt.Errorf("failed to authenticate: %s", err), mcp.INVALID_PARAMS)
exit(exitCodeSetupFailure)
}

// inject customer URL in the context
ctx = config.WithCustomerURL(ctx, info.URL)
// inject bearer token in the context
ctx = session.WithBearerTokenContext(ctx, session.NewBearerToken(resources.Info.BearerToken, info.URL))
// inject customer URL in the context
ctx = config.WithCustomerURL(ctx, info.URL)
// inject bearer token in the context
ctx = session.WithBearerTokenContext(ctx, session.NewBearerToken(resources.Info.BearerToken, info.URL))
}

mcpServer, err := newMCPServer(resources)
if err != nil {
mcpError(resources, fmt.Errorf("failed to create MCP server: %s", err), mcp.INTERNAL_ERROR)
mcpError(resources.Logger(), fmt.Errorf("failed to create MCP server: %s", err), mcp.INTERNAL_ERROR)
exit(exitCodeSetupFailure)
}
mcpSTDIOServer := server.NewStdioServer(mcpServer)
if err := mcpSTDIOServer.Listen(ctx, os.Stdin, os.Stdout); err != nil {
mcpError(resources, fmt.Errorf("failed to serve: %s", err), mcp.INTERNAL_ERROR)
stdinWrapper := newStdinWrapper(resources.Logger(), resources.Info.BearerToken != "", methodsWhitelist)
if err := mcpSTDIOServer.Listen(ctx, stdinWrapper, os.Stdout); err != nil {
mcpError(resources.Logger(), fmt.Errorf("failed to serve: %s", err), mcp.INTERNAL_ERROR)
exit(exitCodeSetupFailure)
}
}
Expand All @@ -73,14 +88,16 @@ func newMCPServer(resources config.Resources) (*server.MCPServer, error) {
return config.NewMCPServer(resources, group), nil
}

func mcpError(resources config.Resources, err error, code int) {
func mcpError(logger *slog.Logger, err error, code int) {
mcpError := mcp.NewJSONRPCError(mcp.NewRequestId("startup"), code, err.Error(), nil)
encoder := json.NewEncoder(os.Stdout)
if err := encoder.Encode(mcpError); err != nil {
resources.Logger().Error("failed to encode error",
encoded, err := json.Marshal(mcpError)
if err != nil {
logger.Error("failed to encode error",
slog.String("error", err.Error()),
)
return
}
fmt.Printf("%s\n", string(encoded))
}

type methodsInput []toolsets.Method
Expand Down Expand Up @@ -110,6 +127,46 @@ func (t *methodsInput) Set(value string) error {
return errs
}

type stdinWrapper struct {
logger *slog.Logger
authenticated bool
methodsWhitelist []string
}

func newStdinWrapper(logger *slog.Logger, authenticated bool, methods []string) stdinWrapper {
return stdinWrapper{
logger: logger,
authenticated: authenticated,
methodsWhitelist: methods,
}
}

func (s stdinWrapper) Read(p []byte) (n int, err error) {
if s.authenticated {
return os.Stdin.Read(p)
}
buffer := make([]byte, len(p))
n, err = os.Stdin.Read(buffer)
if err != nil {
return n, err
}
content := buffer[:n]
if len(content) == 0 {
return n, err
}
var baseMessage struct {
Method string `json:"method"`
}
if err := json.Unmarshal(content, &baseMessage); err != nil {
return 0, errors.New("parse error")
}
if !slices.Contains(s.methodsWhitelist, baseMessage.Method) {
return 0, errors.New("not authenticated")
}
copy(p, buffer)
return n, err
}

type exitCode int

const (
Expand Down