From e66dbded3a57df8be17c678956abc305345305fc Mon Sep 17 00:00:00 2001 From: "D. Can Celasun" Date: Fri, 2 Jun 2017 14:33:32 +0200 Subject: [PATCH] THRIFT-4215 Make transport factories return errors This commit changes the signature of TTransportFactory.GetTransport from GetTransport(trans TTransport) TTransport to GetTransport(trans TTransport) (TTransport, error) so the factory can pass any underlying error to the caller (previously such errors were ignored). This is a backwards incompatible change for anyone implementing custom transports, but it shouldn't effect anyone using the ones in this library. Fixes THRIFT-4215. Fixes THRIFT-4216. --- lib/go/test/tests/protocols_test.go | 5 +- lib/go/thrift/buffered_transport.go | 4 +- lib/go/thrift/compact_protocol_test.go | 48 ++--- lib/go/thrift/framed_transport.go | 11 +- lib/go/thrift/http_client.go | 18 +- lib/go/thrift/iostream_transport.go | 25 ++- lib/go/thrift/lowlevel_benchmarks_test.go | 240 +++++++++++++++++----- lib/go/thrift/memory_buffer.go | 6 +- lib/go/thrift/protocol.go | 6 +- lib/go/thrift/protocol_exception.go | 3 +- lib/go/thrift/protocol_test.go | 54 ++++- lib/go/thrift/rich_transport.go | 1 - lib/go/thrift/rich_transport_test.go | 6 +- lib/go/thrift/serializer_types_test.go | 2 +- lib/go/thrift/simple_server.go | 10 +- lib/go/thrift/socket.go | 3 +- lib/go/thrift/ssl_socket.go | 3 +- lib/go/thrift/transport.go | 3 - lib/go/thrift/transport_factory.go | 6 +- lib/go/thrift/zlib_transport.go | 5 +- tutorial/go/src/client.go | 5 +- 21 files changed, 325 insertions(+), 139 deletions(-) diff --git a/lib/go/test/tests/protocols_test.go b/lib/go/test/tests/protocols_test.go index 422b5c889c6..1580678ebc9 100644 --- a/lib/go/test/tests/protocols_test.go +++ b/lib/go/test/tests/protocols_test.go @@ -42,7 +42,10 @@ func RunSocketTestSuite(t *testing.T, protocolFactory thrift.TProtocolFactory, // client var transport thrift.TTransport = thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT) - transport = transportFactory.GetTransport(transport) + transport, err = transportFactory.GetTransport(transport) + if err != nil { + t.Fatal(err) + } var protocol thrift.TProtocol = protocolFactory.GetProtocol(transport) thriftTestClient := thrifttest.NewThriftTestClientProtocol(transport, protocol, protocol) err = transport.Open() diff --git a/lib/go/thrift/buffered_transport.go b/lib/go/thrift/buffered_transport.go index f73a98b6c7e..b754f925d08 100644 --- a/lib/go/thrift/buffered_transport.go +++ b/lib/go/thrift/buffered_transport.go @@ -32,8 +32,8 @@ type TBufferedTransport struct { tp TTransport } -func (p *TBufferedTransportFactory) GetTransport(trans TTransport) TTransport { - return NewTBufferedTransport(trans, p.size) +func (p *TBufferedTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + return NewTBufferedTransport(trans, p.size), nil } func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory { diff --git a/lib/go/thrift/compact_protocol_test.go b/lib/go/thrift/compact_protocol_test.go index 72812f9cb0d..f940b4e15a8 100644 --- a/lib/go/thrift/compact_protocol_test.go +++ b/lib/go/thrift/compact_protocol_test.go @@ -26,28 +26,28 @@ import ( func TestReadWriteCompactProtocol(t *testing.T) { ReadWriteProtocolTest(t, NewTCompactProtocolFactory()) - transports := []TTransport{ - NewTMemoryBuffer(), - NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))), - NewTFramedTransport(NewTMemoryBuffer()), - } - for _, trans := range transports { - p := NewTCompactProtocol(trans); - ReadWriteBool(t, p, trans); - p = NewTCompactProtocol(trans); - ReadWriteByte(t, p, trans); - p = NewTCompactProtocol(trans); - ReadWriteI16(t, p, trans); - p = NewTCompactProtocol(trans); - ReadWriteI32(t, p, trans); - p = NewTCompactProtocol(trans); - ReadWriteI64(t, p, trans); - p = NewTCompactProtocol(trans); - ReadWriteDouble(t, p, trans); - p = NewTCompactProtocol(trans); - ReadWriteString(t, p, trans); - p = NewTCompactProtocol(trans); - ReadWriteBinary(t, p, trans); - trans.Close(); - } + transports := []TTransport{ + NewTMemoryBuffer(), + NewStreamTransportRW(bytes.NewBuffer(make([]byte, 0, 16384))), + NewTFramedTransport(NewTMemoryBuffer()), + } + for _, trans := range transports { + p := NewTCompactProtocol(trans) + ReadWriteBool(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteByte(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI16(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI32(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteI64(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteDouble(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteString(t, p, trans) + p = NewTCompactProtocol(trans) + ReadWriteBinary(t, p, trans) + trans.Close() + } } diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go index d0bae21bc25..4ae14257e54 100644 --- a/lib/go/thrift/framed_transport.go +++ b/lib/go/thrift/framed_transport.go @@ -48,11 +48,15 @@ func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory { } func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory { - return &tFramedTransportFactory{factory: factory, maxLength: maxLength} + return &tFramedTransportFactory{factory: factory, maxLength: maxLength} } -func (p *tFramedTransportFactory) GetTransport(base TTransport) TTransport { - return NewTFramedTransportMaxLength(p.factory.GetTransport(base), p.maxLength) +func (p *tFramedTransportFactory) GetTransport(base TTransport) (TTransport, error) { + tt, err := p.factory.GetTransport(base) + if err != nil { + return nil, err + } + return NewTFramedTransportMaxLength(tt, p.maxLength), nil } func NewTFramedTransport(transport TTransport) *TFramedTransport { @@ -164,4 +168,3 @@ func (p *TFramedTransport) readFrameHeader() (uint32, error) { func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) { return uint64(p.frameSize) } - diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go index 88eb2c12873..b7ca2746767 100644 --- a/lib/go/thrift/http_client.go +++ b/lib/go/thrift/http_client.go @@ -49,24 +49,20 @@ type THttpClientTransportFactory struct { isPost bool } -func (p *THttpClientTransportFactory) GetTransport(trans TTransport) TTransport { +func (p *THttpClientTransportFactory) GetTransport(trans TTransport) (TTransport, error) { if trans != nil { t, ok := trans.(*THttpClient) if ok && t.url != nil { if t.requestBuffer != nil { - t2, _ := NewTHttpPostClientWithOptions(t.url.String(), p.options) - return t2 + return NewTHttpPostClientWithOptions(t.url.String(), p.options) } - t2, _ := NewTHttpClientWithOptions(t.url.String(), p.options) - return t2 + return NewTHttpClientWithOptions(t.url.String(), p.options) } } if p.isPost { - s, _ := NewTHttpPostClientWithOptions(p.url, p.options) - return s + return NewTHttpPostClientWithOptions(p.url, p.options) } - s, _ := NewTHttpClientWithOptions(p.url, p.options) - return s + return NewTHttpClientWithOptions(p.url, p.options) } type THttpClientOptions struct { @@ -103,7 +99,7 @@ func NewTHttpClientWithOptions(urlstr string, options THttpClientOptions) (TTran if client == nil { client = DefaultHttpClient } - httpHeader := map[string][]string{"Content-Type": []string{"application/x-thrift"}} + httpHeader := map[string][]string{"Content-Type": {"application/x-thrift"}} return &THttpClient{client: client, response: response, url: parsedURL, header: httpHeader}, nil } @@ -121,7 +117,7 @@ func NewTHttpPostClientWithOptions(urlstr string, options THttpClientOptions) (T if client == nil { client = DefaultHttpClient } - httpHeader := map[string][]string{"Content-Type": []string{"application/x-thrift"}} + httpHeader := map[string][]string{"Content-Type": {"application/x-thrift"}} return &THttpClient{client: client, url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: httpHeader}, nil } diff --git a/lib/go/thrift/iostream_transport.go b/lib/go/thrift/iostream_transport.go index 794872ff126..b18be81c46f 100644 --- a/lib/go/thrift/iostream_transport.go +++ b/lib/go/thrift/iostream_transport.go @@ -38,38 +38,38 @@ type StreamTransportFactory struct { isReadWriter bool } -func (p *StreamTransportFactory) GetTransport(trans TTransport) TTransport { +func (p *StreamTransportFactory) GetTransport(trans TTransport) (TTransport, error) { if trans != nil { t, ok := trans.(*StreamTransport) if ok { if t.isReadWriter { - return NewStreamTransportRW(t.Reader.(io.ReadWriter)) + return NewStreamTransportRW(t.Reader.(io.ReadWriter)), nil } if t.Reader != nil && t.Writer != nil { - return NewStreamTransport(t.Reader, t.Writer) + return NewStreamTransport(t.Reader, t.Writer), nil } if t.Reader != nil && t.Writer == nil { - return NewStreamTransportR(t.Reader) + return NewStreamTransportR(t.Reader), nil } if t.Reader == nil && t.Writer != nil { - return NewStreamTransportW(t.Writer) + return NewStreamTransportW(t.Writer), nil } - return &StreamTransport{} + return &StreamTransport{}, nil } } if p.isReadWriter { - return NewStreamTransportRW(p.Reader.(io.ReadWriter)) + return NewStreamTransportRW(p.Reader.(io.ReadWriter)), nil } if p.Reader != nil && p.Writer != nil { - return NewStreamTransport(p.Reader, p.Writer) + return NewStreamTransport(p.Reader, p.Writer), nil } if p.Reader != nil && p.Writer == nil { - return NewStreamTransportR(p.Reader) + return NewStreamTransportR(p.Reader), nil } if p.Reader == nil && p.Writer != nil { - return NewStreamTransportW(p.Writer) + return NewStreamTransportW(p.Writer), nil } - return &StreamTransport{} + return &StreamTransport{}, nil } func NewStreamTransportFactory(reader io.Reader, writer io.Writer, isReadWriter bool) *StreamTransportFactory { @@ -209,6 +209,5 @@ func (p *StreamTransport) WriteString(s string) (n int, err error) { func (p *StreamTransport) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) - return maxSize // the thruth is, we just don't know unless framed is used + return maxSize // the thruth is, we just don't know unless framed is used } - diff --git a/lib/go/thrift/lowlevel_benchmarks_test.go b/lib/go/thrift/lowlevel_benchmarks_test.go index a5094ae97ce..e1736557b14 100644 --- a/lib/go/thrift/lowlevel_benchmarks_test.go +++ b/lib/go/thrift/lowlevel_benchmarks_test.go @@ -36,7 +36,10 @@ var tfv = []TTransportFactory{ } func BenchmarkBinaryBool_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) @@ -44,7 +47,10 @@ func BenchmarkBinaryBool_0(b *testing.B) { } func BenchmarkBinaryByte_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) @@ -52,7 +58,10 @@ func BenchmarkBinaryByte_0(b *testing.B) { } func BenchmarkBinaryI16_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) @@ -60,35 +69,50 @@ func BenchmarkBinaryI16_0(b *testing.B) { } func BenchmarkBinaryI32_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkBinaryI64_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkBinaryDouble_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkBinaryString_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkBinaryBinary_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) @@ -96,7 +120,10 @@ func BenchmarkBinaryBinary_0(b *testing.B) { } func BenchmarkBinaryBool_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) @@ -104,7 +131,10 @@ func BenchmarkBinaryBool_1(b *testing.B) { } func BenchmarkBinaryByte_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) @@ -112,7 +142,10 @@ func BenchmarkBinaryByte_1(b *testing.B) { } func BenchmarkBinaryI16_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) @@ -120,35 +153,50 @@ func BenchmarkBinaryI16_1(b *testing.B) { } func BenchmarkBinaryI32_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkBinaryI64_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkBinaryDouble_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkBinaryString_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkBinaryBinary_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) @@ -156,7 +204,10 @@ func BenchmarkBinaryBinary_1(b *testing.B) { } func BenchmarkBinaryBool_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) @@ -164,7 +215,10 @@ func BenchmarkBinaryBool_2(b *testing.B) { } func BenchmarkBinaryByte_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) @@ -172,7 +226,10 @@ func BenchmarkBinaryByte_2(b *testing.B) { } func BenchmarkBinaryI16_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) @@ -180,35 +237,50 @@ func BenchmarkBinaryI16_2(b *testing.B) { } func BenchmarkBinaryI32_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkBinaryI64_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkBinaryDouble_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkBinaryString_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkBinaryBinary_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := binaryProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) @@ -216,7 +288,10 @@ func BenchmarkBinaryBinary_2(b *testing.B) { } func BenchmarkCompactBool_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) @@ -224,7 +299,10 @@ func BenchmarkCompactBool_0(b *testing.B) { } func BenchmarkCompactByte_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) @@ -232,7 +310,10 @@ func BenchmarkCompactByte_0(b *testing.B) { } func BenchmarkCompactI16_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) @@ -240,35 +321,50 @@ func BenchmarkCompactI16_0(b *testing.B) { } func BenchmarkCompactI32_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkCompactI64_0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkCompactDouble0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkCompactString0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkCompactBinary0(b *testing.B) { - trans := tfv[0].GetTransport(nil) + trans, err := tfv[0].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) @@ -276,7 +372,10 @@ func BenchmarkCompactBinary0(b *testing.B) { } func BenchmarkCompactBool_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) @@ -284,7 +383,10 @@ func BenchmarkCompactBool_1(b *testing.B) { } func BenchmarkCompactByte_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) @@ -292,7 +394,10 @@ func BenchmarkCompactByte_1(b *testing.B) { } func BenchmarkCompactI16_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) @@ -300,35 +405,50 @@ func BenchmarkCompactI16_1(b *testing.B) { } func BenchmarkCompactI32_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkCompactI64_1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkCompactDouble1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkCompactString1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkCompactBinary1(b *testing.B) { - trans := tfv[1].GetTransport(nil) + trans, err := tfv[1].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) @@ -336,7 +456,10 @@ func BenchmarkCompactBinary1(b *testing.B) { } func BenchmarkCompactBool_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBool(b, p, trans) @@ -344,7 +467,10 @@ func BenchmarkCompactBool_2(b *testing.B) { } func BenchmarkCompactByte_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteByte(b, p, trans) @@ -352,7 +478,10 @@ func BenchmarkCompactByte_2(b *testing.B) { } func BenchmarkCompactI16_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI16(b, p, trans) @@ -360,35 +489,50 @@ func BenchmarkCompactI16_2(b *testing.B) { } func BenchmarkCompactI32_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI32(b, p, trans) } } func BenchmarkCompactI64_2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteI64(b, p, trans) } } func BenchmarkCompactDouble2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteDouble(b, p, trans) } } func BenchmarkCompactString2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteString(b, p, trans) } } func BenchmarkCompactBinary2(b *testing.B) { - trans := tfv[2].GetTransport(nil) + trans, err := tfv[2].GetTransport(nil) + if err != nil { + b.Fatal(err) + } p := compactProtoF.GetProtocol(trans) for i := 0; i < b.N; i++ { ReadWriteBinary(b, p, trans) diff --git a/lib/go/thrift/memory_buffer.go b/lib/go/thrift/memory_buffer.go index b62fd56f063..97a4edfa5db 100644 --- a/lib/go/thrift/memory_buffer.go +++ b/lib/go/thrift/memory_buffer.go @@ -33,14 +33,14 @@ type TMemoryBufferTransportFactory struct { size int } -func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) TTransport { +func (p *TMemoryBufferTransportFactory) GetTransport(trans TTransport) (TTransport, error) { if trans != nil { t, ok := trans.(*TMemoryBuffer) if ok && t.size > 0 { - return NewTMemoryBufferLen(t.size) + return NewTMemoryBufferLen(t.size), nil } } - return NewTMemoryBufferLen(p.size) + return NewTMemoryBufferLen(p.size), nil } func NewTMemoryBufferTransportFactory(size int) *TMemoryBufferTransportFactory { diff --git a/lib/go/thrift/protocol.go b/lib/go/thrift/protocol.go index 32bb7b31745..25e6d24b904 100644 --- a/lib/go/thrift/protocol.go +++ b/lib/go/thrift/protocol.go @@ -89,9 +89,9 @@ func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) { // Skips over the next data element from the provided input TProtocol object. func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { - - if maxDepth <= 0 { - return NewTProtocolExceptionWithType( DEPTH_LIMIT, errors.New("Depth limit exceeded")) + + if maxDepth <= 0 { + return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded")) } switch fieldType { diff --git a/lib/go/thrift/protocol_exception.go b/lib/go/thrift/protocol_exception.go index 6e357ee890d..29ab75d9215 100644 --- a/lib/go/thrift/protocol_exception.go +++ b/lib/go/thrift/protocol_exception.go @@ -60,7 +60,7 @@ func NewTProtocolException(err error) TProtocolException { if err == nil { return nil } - if e,ok := err.(TProtocolException); ok { + if e, ok := err.(TProtocolException); ok { return e } if _, ok := err.(base64.CorruptInputError); ok { @@ -75,4 +75,3 @@ func NewTProtocolExceptionWithType(errType int, err error) TProtocolException { } return &tProtocolException{errType, err.Error()} } - diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go index 613eae6bc88..2573312d1d3 100644 --- a/lib/go/thrift/protocol_test.go +++ b/lib/go/thrift/protocol_test.go @@ -123,55 +123,91 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { NewTHttpPostClientTransportFactory("http://" + addr.String()), } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteBool(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteByte(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteI16(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteI32(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteI64(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteDouble(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteString(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteBinary(t, p, trans) trans.Close() } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } p := protocolFactory.GetProtocol(trans) ReadWriteI64(t, p, trans) ReadWriteDouble(t, p, trans) diff --git a/lib/go/thrift/rich_transport.go b/lib/go/thrift/rich_transport.go index 8e296a99b5f..4025bebeaa4 100644 --- a/lib/go/thrift/rich_transport.go +++ b/lib/go/thrift/rich_transport.go @@ -66,4 +66,3 @@ func writeByte(w io.Writer, c byte) error { _, err := w.Write(v[0:1]) return err } - diff --git a/lib/go/thrift/rich_transport_test.go b/lib/go/thrift/rich_transport_test.go index 41513f812b7..25c3fd5aa82 100644 --- a/lib/go/thrift/rich_transport_test.go +++ b/lib/go/thrift/rich_transport_test.go @@ -37,7 +37,11 @@ func TestEnsureTransportsAreRich(t *testing.T) { NewTHttpPostClientTransportFactory("http://127.0.0.1"), } for _, tf := range transports { - trans := tf.GetTransport(nil) + trans, err := tf.GetTransport(nil) + if err != nil { + t.Error(err) + continue + } _, ok := trans.(TRichTransport) if !ok { t.Errorf("Transport %s does not implement TRichTransport interface", reflect.ValueOf(trans)) diff --git a/lib/go/thrift/serializer_types_test.go b/lib/go/thrift/serializer_types_test.go index 38ab8d6d600..c8e3b3be466 100644 --- a/lib/go/thrift/serializer_types_test.go +++ b/lib/go/thrift/serializer_types_test.go @@ -598,7 +598,7 @@ func (p *MyTestStruct) writeField11(oprot TProtocol) (err error) { if err := oprot.WriteSetBegin(STRING, len(p.StringSet)); err != nil { return PrependError("error writing set begin: ", err) } - for v, _ := range p.StringSet { + for v := range p.StringSet { if err := oprot.WriteString(string(v)); err != nil { return PrependError(fmt.Sprintf("%T. (0) field write error: ", p), err) } diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index 5c848f22636..be031e8db28 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -171,8 +171,14 @@ func (p *TSimpleServer) processRequests(client TTransport) error { defer p.Done() processor := p.processorFactory.GetProcessor(client) - inputTransport := p.inputTransportFactory.GetTransport(client) - outputTransport := p.outputTransportFactory.GetTransport(client) + inputTransport, err := p.inputTransportFactory.GetTransport(client) + if err != nil { + return err + } + outputTransport, err := p.outputTransportFactory.GetTransport(client) + if err != nil { + return err + } inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport) outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport) defer func() { diff --git a/lib/go/thrift/socket.go b/lib/go/thrift/socket.go index 82e28b4b182..383b1fe3e97 100644 --- a/lib/go/thrift/socket.go +++ b/lib/go/thrift/socket.go @@ -161,6 +161,5 @@ func (p *TSocket) Interrupt() error { func (p *TSocket) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) - return maxSize // the thruth is, we just don't know unless framed is used + return maxSize // the thruth is, we just don't know unless framed is used } - diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go index 2395986aab5..8272703cf83 100644 --- a/lib/go/thrift/ssl_socket.go +++ b/lib/go/thrift/ssl_socket.go @@ -169,6 +169,5 @@ func (p *TSSLSocket) Interrupt() error { func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) - return maxSize // the thruth is, we just don't know unless framed is used + return maxSize // the thruth is, we just don't know unless framed is used } - diff --git a/lib/go/thrift/transport.go b/lib/go/thrift/transport.go index 453899651fc..70a85a84895 100644 --- a/lib/go/thrift/transport.go +++ b/lib/go/thrift/transport.go @@ -34,7 +34,6 @@ type ReadSizeProvider interface { RemainingBytes() (num_bytes uint64) } - // Encapsulates the I/O layer type TTransport interface { io.ReadWriteCloser @@ -52,7 +51,6 @@ type stringWriter interface { WriteString(s string) (n int, err error) } - // This is "enchanced" transport with extra capabilities. You need to use one of these // to construct protocol. // Notably, TSocket does not implement this interface, and it is always a mistake to use @@ -65,4 +63,3 @@ type TRichTransport interface { Flusher ReadSizeProvider } - diff --git a/lib/go/thrift/transport_factory.go b/lib/go/thrift/transport_factory.go index 533d1b43753..c805807940a 100644 --- a/lib/go/thrift/transport_factory.go +++ b/lib/go/thrift/transport_factory.go @@ -24,14 +24,14 @@ package thrift // a ServerTransport and then may want to mutate them (i.e. create // a BufferedTransport from the underlying base transport) type TTransportFactory interface { - GetTransport(trans TTransport) TTransport + GetTransport(trans TTransport) (TTransport, error) } type tTransportFactory struct{} // Return a wrapped instance of the base Transport. -func (p *tTransportFactory) GetTransport(trans TTransport) TTransport { - return trans +func (p *tTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + return trans, nil } func NewTTransportFactory() TTransportFactory { diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go index e47455fe163..6f477ca1df0 100644 --- a/lib/go/thrift/zlib_transport.go +++ b/lib/go/thrift/zlib_transport.go @@ -38,9 +38,8 @@ type TZlibTransport struct { } // GetTransport constructs a new instance of NewTZlibTransport -func (p *TZlibTransportFactory) GetTransport(trans TTransport) TTransport { - t, _ := NewTZlibTransport(trans, p.level) - return t +func (p *TZlibTransportFactory) GetTransport(trans TTransport) (TTransport, error) { + return NewTZlibTransport(trans, p.level) } // NewTZlibTransportFactory constructs a new instance of NewTZlibTransportFactory diff --git a/tutorial/go/src/client.go b/tutorial/go/src/client.go index a497d7f8b19..9106ac9487c 100644 --- a/tutorial/go/src/client.go +++ b/tutorial/go/src/client.go @@ -90,7 +90,10 @@ func runClient(transportFactory thrift.TTransportFactory, protocolFactory thrift fmt.Println("Error opening socket:", err) return err } - transport = transportFactory.GetTransport(transport) + transport, err = transportFactory.GetTransport(transport) + if err != nil { + return err + } defer transport.Close() if err := transport.Open(); err != nil { return err