diff --git a/fxmcpserver/fxmcpservertest/server_test.go b/fxmcpserver/fxmcpservertest/server_test.go index 1deaacaf..79b301c5 100644 --- a/fxmcpserver/fxmcpservertest/server_test.go +++ b/fxmcpserver/fxmcpservertest/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/ankorstore/yokai/log/logtest" "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/sdk/trace" ) @@ -27,11 +28,13 @@ func TestMCPSSETestServer(t *testing.T) { tp := trace.NewTracerProvider() + tmp := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + lb := logtest.NewDefaultTestLogBuffer() lg, err := log.NewDefaultLoggerFactory().Create(log.WithOutputWriter(lb)) assert.NoError(t, err) - hdl := sse.NewDefaultMCPSSEServerContextHandler(gm, tp, lg) + hdl := sse.NewDefaultMCPSSEServerContextHandler(gm, tp, tmp, lg) mcpSrv := server.NewMCPServer("test-server", "1.0.0") diff --git a/fxmcpserver/module.go b/fxmcpserver/module.go index c89e9ce7..ecb5818d 100644 --- a/fxmcpserver/module.go +++ b/fxmcpserver/module.go @@ -12,6 +12,7 @@ import ( "github.com/ankorstore/yokai/log" "github.com/mark3labs/mcp-go/server" "github.com/prometheus/client_golang/prometheus" + "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" "go.uber.org/fx" ) @@ -135,9 +136,15 @@ type ProvideDefaultMCPSSEContextHandlerParam struct { // ProvideDefaultMCPSSEServerContextHandler provides the default sse.MCPSSEServerContextHandler instance. func ProvideDefaultMCPSSEServerContextHandler(p ProvideDefaultMCPSSEContextHandlerParam) *sse.DefaultMCPSSEServerContextHandler { + textMapPropagator := propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + ) + return sse.NewDefaultMCPSSEServerContextHandler( p.Generator, p.TracerProvider, + textMapPropagator, p.Logger, p.MCPSSEServerContextHooks..., ) diff --git a/fxmcpserver/server/sse/context.go b/fxmcpserver/server/sse/context.go index 42351291..c98d184d 100644 --- a/fxmcpserver/server/sse/context.go +++ b/fxmcpserver/server/sse/context.go @@ -11,6 +11,7 @@ import ( "github.com/ankorstore/yokai/trace" "github.com/mark3labs/mcp-go/server" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/propagation" ot "go.opentelemetry.io/otel/trace" ) @@ -28,24 +29,27 @@ type MCPSSEServerContextHandler interface { // DefaultMCPSSEServerContextHandler is the default MCPSSEServerContextHandler implementation. type DefaultMCPSSEServerContextHandler struct { - generator uuid.UuidGenerator - tracerProvider ot.TracerProvider - logger *log.Logger - contextHooks []MCPSSEServerContextHook + generator uuid.UuidGenerator + tracerProvider ot.TracerProvider + textMapPropagator propagation.TextMapPropagator + logger *log.Logger + contextHooks []MCPSSEServerContextHook } // NewDefaultMCPSSEServerContextHandler returns a new DefaultMCPSSEServerContextHandler instance. func NewDefaultMCPSSEServerContextHandler( generator uuid.UuidGenerator, tracerProvider ot.TracerProvider, + textMapPropagator propagation.TextMapPropagator, logger *log.Logger, contextHooks ...MCPSSEServerContextHook, ) *DefaultMCPSSEServerContextHandler { return &DefaultMCPSSEServerContextHandler{ - generator: generator, - tracerProvider: tracerProvider, - logger: logger, - contextHooks: contextHooks, + generator: generator, + tracerProvider: tracerProvider, + textMapPropagator: textMapPropagator, + logger: logger, + contextHooks: contextHooks, } } @@ -71,6 +75,8 @@ func (h *DefaultMCPSSEServerContextHandler) Handle() server.SSEContextFunc { ctx = fsc.WithRequestID(ctx, rID) // tracer propagation + ctx = h.textMapPropagator.Extract(ctx, propagation.HeaderCarrier(req.Header)) + ctx = trace.WithContext(ctx, h.tracerProvider) ctx, span := trace.CtxTracer(ctx).Start( diff --git a/fxmcpserver/server/sse/context_test.go b/fxmcpserver/server/sse/context_test.go index d19a550a..34bada89 100644 --- a/fxmcpserver/server/sse/context_test.go +++ b/fxmcpserver/server/sse/context_test.go @@ -13,6 +13,7 @@ import ( "github.com/ankorstore/yokai/log/logtest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/sdk/trace" ) @@ -36,11 +37,13 @@ func TestDefaultMCPSSEServerContextHandler_Handle(t *testing.T) { tp := trace.NewTracerProvider() + tmp := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + lb := logtest.NewDefaultTestLogBuffer() lg, err := log.NewDefaultLoggerFactory().Create(log.WithOutputWriter(lb)) assert.NoError(t, err) - handler := sse.NewDefaultMCPSSEServerContextHandler(gm, tp, lg) + handler := sse.NewDefaultMCPSSEServerContextHandler(gm, tp, tmp, lg) req := httptest.NewRequest(http.MethodGet, "/sse", nil) @@ -91,13 +94,15 @@ func TestDefaultMCPSSEServerContextHandler_Handle(t *testing.T) { tp := trace.NewTracerProvider() + tmp := propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}) + lb := logtest.NewDefaultTestLogBuffer() lg, err := log.NewDefaultLoggerFactory().Create(log.WithOutputWriter(lb)) assert.NoError(t, err) hk := hook.NewSimpleMCPSSEServerContextHook() - handler := sse.NewDefaultMCPSSEServerContextHandler(gm, tp, lg, hk) + handler := sse.NewDefaultMCPSSEServerContextHandler(gm, tp, tmp, lg, hk) req := httptest.NewRequest(http.MethodGet, "/sse?sessionId=test-session-id", nil) req.Header.Set("X-Request-Id", "test-request-id")