diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go index 1924a1ae268..19c63a98512 100644 --- a/lib/go/thrift/http_client.go +++ b/lib/go/thrift/http_client.go @@ -197,6 +197,14 @@ func (p *THttpClient) Flush(ctx context.Context) error { // Close any previous response body to avoid leaking connections. p.closeResponse() + // Request might not have been fully read by http client. + // Reset so we don't send the remains on next call. + defer func() { + if p.requestBuffer != nil { + p.requestBuffer.Reset() + } + }() + req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer) if err != nil { return NewTTransportExceptionFromError(err) diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go index a7977a38595..eba366815ca 100644 --- a/lib/go/thrift/http_client_test.go +++ b/lib/go/thrift/http_client_test.go @@ -20,6 +20,8 @@ package thrift import ( + "bytes" + "context" "net/http" "testing" ) @@ -32,14 +34,14 @@ func TestHttpClient(t *testing.T) { trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportTest(t, trans, trans) t.Run("nilBuffer", func(t *testing.T) { _ = trans.Close() if _, err = trans.Write([]byte{1, 2, 3, 4}); err == nil { - t.Fatalf("writing to a closed transport did not result in an error") + t.Fatal("writing to a closed transport did not result in an error") } }) } @@ -52,7 +54,7 @@ func TestHttpClientHeaders(t *testing.T) { trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportHeaderTest(t, trans, trans) } @@ -72,7 +74,7 @@ func TestHttpCustomClient(t *testing.T) { }) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportHeaderTest(t, trans, trans) @@ -94,7 +96,7 @@ func TestHttpCustomClientPackageScope(t *testing.T) { trans, err := NewTHttpPostClient("http://" + addr.String()) if err != nil { l.Close() - t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) } TransportHeaderTest(t, trans, trans) @@ -103,6 +105,54 @@ func TestHttpCustomClientPackageScope(t *testing.T) { } } +func TestHTTPClientFlushesRequestBufferOnErrors(t *testing.T) { + var ( + write1 = []byte("write 1") + write2 = []byte("write 2") + ) + + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + t.Fatalf("Unable to connect to %s: %v", addr.String(), err) + } + defer trans.Close() + + _, err = trans.Write(write1) + if err != nil { + t.Fatalf("Failed to write to transport: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err = trans.Flush(ctx) + if err == nil { + t.Fatal("Expected flush error") + } + + _, err = trans.Write(write2) + if err != nil { + t.Fatalf("Failed to write to transport: %v", err) + } + err = trans.Flush(context.Background()) + if err != nil { + t.Fatalf("Failed to flush: %v", err) + } + + data := make([]byte, 1024) + n, err := trans.Read(data) + if err != nil { + t.Fatalf("Failed to read: %v", err) + } + + data = data[:n] + if !bytes.Equal(data, write2) { + t.Fatalf("Received unexpected data: %q, expected: %q", data, write2) + } +} + type customHttpTransport struct { hit bool }