diff --git a/interceptor/opencensus/opencensus.go b/interceptor/opencensus/opencensus.go index 3c473dea..b04d719a 100644 --- a/interceptor/opencensus/opencensus.go +++ b/interceptor/opencensus/opencensus.go @@ -16,7 +16,6 @@ package ocinterceptor import ( "errors" - "fmt" "time" "google.golang.org/api/support/bundler" @@ -58,28 +57,19 @@ type spansAndNode struct { node *commonpb.Node } +var errTraceExportProtocolViolation = errors.New("protocol violation: Export's first message must have a Node") + // Export is the gRPC method that receives streamed traces from // OpenCensus-traceproto compatible libraries/applications. func (oci *OCInterceptor) Export(tes agenttracepb.TraceService_ExportServer) error { - // Firstly we MUST receive the node identifier to initiate - // the service and start accepting exported spans. - const maxTraceInitRetries = 15 // Arbitrary value - - var initiatingNode *commonpb.Node - for i := 0; i < maxTraceInitRetries; i++ { - recv, err := tes.Recv() - if err != nil { - return err - } - - if nd := recv.Node; nd != nil { - initiatingNode = nd - break - } + // The first message MUST have a non-nil Node. + firstMessage, err := tes.Recv() + if err != nil { + return err } - if initiatingNode == nil { - return fmt.Errorf("failed to receive a non-nil initiating Node even after %d retries", maxTraceInitRetries) + if firstMessage.Node == nil { + return errTraceExportProtocolViolation } // Now that we've got the node, we can start to receive streamed up spans. @@ -98,7 +88,14 @@ func (oci *OCInterceptor) Export(tes agenttracepb.TraceService_ExportServer) err traceBundler.DelayThreshold = spanBufferPeriod traceBundler.BundleCountThreshold = spanBufferCount - var lastNonNilNode *commonpb.Node = initiatingNode + var lastNonNilNode *commonpb.Node = firstMessage.Node + + // If the firstMessage has spans, we MUST add them + // See https://github.com/census-instrumentation/opencensus-service/issues/51 + if len(firstMessage.Spans) > 0 { + firstPayload := &spansAndNode{node: lastNonNilNode, spans: firstMessage.Spans} + traceBundler.Add(firstPayload, len(firstPayload.spans)) + } for { recv, err := tes.Recv() diff --git a/interceptor/opencensus/opencensus_test.go b/interceptor/opencensus/opencensus_test.go index 893cc334..e0356676 100644 --- a/interceptor/opencensus/opencensus_test.go +++ b/interceptor/opencensus/opencensus_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "reflect" "strconv" @@ -170,20 +171,13 @@ func TestExportMultiplexing(t *testing.T) { _, port, doneFn := ocInterceptorOnGRPCServer(t, spanSink, ocinterceptor.WithSpanBufferPeriod(90*time.Millisecond)) defer doneFn() - addr := fmt.Sprintf(":%d", port) - cc, err := grpc.Dial(addr, grpc.WithInsecure(), grpc.WithBlock()) - if err != nil { - t.Fatalf("Failed to create the gRPC client connection: %v", err) - } - defer cc.Close() - - svc := agenttracepb.NewTraceServiceClient(cc) - traceClient, err := svc.Export(context.Background()) + traceClient, traceClientDoneFn, err := makeTraceServiceClient(port) if err != nil { - t.Fatalf("Failed to create the traceClient: %v", err) + t.Fatalf("Failed to create the gRPC TraceService_ExportClient: %v", err) } + defer traceClientDoneFn() - // Step 1) The intiation + // Step 1) The initiation. initiatingNode := &commonpb.Node{ Identifier: &commonpb.ProcessIdentifier{ Pid: 1, @@ -288,6 +282,152 @@ func TestExportMultiplexing(t *testing.T) { } } +// The first message without a Node MUST be rejected and teardown the connection. +// See https://github.com/census-instrumentation/opencensus-service/issues/53 +func TestExportProtocolViolations_nodelessFirstMessage(t *testing.T) { + spanSink := newSpanAppender() + + _, port, doneFn := ocInterceptorOnGRPCServer(t, spanSink, ocinterceptor.WithSpanBufferPeriod(90*time.Millisecond)) + defer doneFn() + + traceClient, traceClientDoneFn, err := makeTraceServiceClient(port) + if err != nil { + t.Fatalf("Failed to create the gRPC TraceService_ExportClient: %v", err) + } + defer traceClientDoneFn() + + // Send a Nodeless first message + if err := traceClient.Send(&agenttracepb.ExportTraceServiceRequest{Node: nil}); err != nil { + t.Fatalf("Unexpectedly failed to send the first message: %v", err) + } + + longDuration := 2 * time.Second + testDone := make(chan bool, 1) + go func() { + // Our insurance policy to ensure that this test doesn't hang + // forever and should quickly report if/when we regress. + select { + case <-testDone: + t.Log("Test ended early enough") + case <-time.After(longDuration): + traceClientDoneFn() + t.Errorf("Test took too long (%s) and is likely still hanging so this is a regression", longDuration) + } + }() + + // Now the response should return an error and should have been torn down + // regardless of the number of times after invocation below, or any attempt + // to send the proper/corrective data should be rejected. + for i := 0; i < 10; i++ { + recv, err := traceClient.Recv() + if recv != nil { + t.Errorf("Iteration #%d: Unexpectedly got back a response: %#v", i, recv) + } + if err == nil { + t.Errorf("Iteration #%d: Unexpectedly got back a nil error", i) + continue + } + + wantSubStr := "protocol violation: Export's first message must have a Node" + if g := err.Error(); !strings.Contains(g, wantSubStr) { + t.Errorf("Iteration #%d: Got error:\n\t%s\nWant substring:\n\t%s\n", i, g, wantSubStr) + } + + // The connection should be invalid at this point and + // no attempt to send corrections should succeeed. + n1 := &commonpb.Node{ + Identifier: &commonpb.ProcessIdentifier{Pid: 9489, HostName: "nodejs-host"}, + LibraryInfo: &commonpb.LibraryInfo{Language: commonpb.LibraryInfo_NODE_JS}, + } + if err = traceClient.Send(&agenttracepb.ExportTraceServiceRequest{Node: n1}); err == nil { + t.Errorf("Iteration #%d: Unexpectedly succeeded in sending a message upstream. Connection must be in terminal state", i) + } else if g, w := err, io.EOF; g != w { + t.Errorf("Iteration #%d:\nGot error %q\nWant error %q", i, g, w) + } + } + + close(testDone) +} + +// If the first message is valid (has a non-nil Node) and has spans, those +// spans should be received and NEVER discarded. +// See https://github.com/census-instrumentation/opencensus-service/issues/51 +func TestExportProtocolConformation_spansInFirstMessage(t *testing.T) { + spanSink := newSpanAppender() + + _, port, doneFn := ocInterceptorOnGRPCServer(t, spanSink, ocinterceptor.WithSpanBufferPeriod(70*time.Millisecond)) + defer doneFn() + + traceClient, traceClientDoneFn, err := makeTraceServiceClient(port) + if err != nil { + t.Fatalf("Failed to create the gRPC TraceService_ExportClient: %v", err) + } + defer traceClientDoneFn() + + sLi := []*tracepb.Span{{TraceId: []byte("1234567890abcde")}, {TraceId: []byte("XXXXXXXXXXabcde")}} + ni := &commonpb.Node{ + Identifier: &commonpb.ProcessIdentifier{Pid: 1}, + LibraryInfo: &commonpb.LibraryInfo{Language: commonpb.LibraryInfo_JAVA}, + } + if err := traceClient.Send(&agenttracepb.ExportTraceServiceRequest{Node: ni, Spans: sLi}); err != nil { + t.Fatalf("Failed to send the first message: %v", err) + } + + // Give it time to be sent over the wire, then exported. + <-time.After(100 * time.Millisecond) + + // Examination time! + resultsMapping := make(map[string][]*tracepb.Span) + spanSink.forEachEntry(func(node *commonpb.Node, spans []*tracepb.Span) { + resultsMapping[nodeToKey(node)] = spans + }) + + if g, w := len(resultsMapping), 1; g != w { + t.Errorf("Results mapping: Got len(keys) %d Want %d", g, w) + } + + // Check for the keys + wantLengths := map[string]int{ + nodeToKey(ni): 2, + } + for key, wantLength := range wantLengths { + gotLength := len(resultsMapping[key]) + if gotLength != wantLength { + t.Errorf("Exported spans:: Key: %s\nGot length %d\nWant length %d", key, gotLength, wantLength) + } + } + + // And finally ensure that the protos' serializations are equivalent to the expected + wantContents := map[string][]*tracepb.Span{ + nodeToKey(ni): sLi, + } + + gotBlob, _ := json.Marshal(resultsMapping) + wantBlob, _ := json.Marshal(wantContents) + if !bytes.Equal(gotBlob, wantBlob) { + t.Errorf("Unequal serialization results\nGot:\n\t%s\nWant:\n\t%s\n", gotBlob, wantBlob) + } +} + +// Helper functions from here on below +func makeTraceServiceClient(port int) (agenttracepb.TraceService_ExportClient, func(), error) { + addr := fmt.Sprintf(":%d", port) + cc, err := grpc.Dial(addr, grpc.WithInsecure(), grpc.WithBlock()) + if err != nil { + return nil, nil, err + } + + svc := agenttracepb.NewTraceServiceClient(cc) + traceClient, err := svc.Export(context.Background()) + if err != nil { + _ = cc.Close() + return nil, nil, err + } + + doneFn := func() { _ = cc.Close() } + return traceClient, doneFn, nil +} + func nodeToKey(n *commonpb.Node) string { blob, _ := proto.Marshal(n) return string(blob)