From f1740ed8faf767a486d2ab38be2401dcb2288cee Mon Sep 17 00:00:00 2001 From: Tyler Treat Date: Fri, 11 Mar 2016 15:15:17 -0600 Subject: [PATCH] THRIFT-3735 JSON protocol left in incorrect state when an exception is thrown during read or write operations The JSON context stack may be left in an incorrect state when an exception is thrown during read or write operations. This leads to further errors while writing/reading the NEXT message, because incorrect characters may be written or expected. This is related to THRIFT-1473, but there is an additional issue in that the bufio.Writer needs to be reset if an error occurs. --- lib/go/thrift/json_protocol.go | 7 +++-- lib/go/thrift/simple_json_protocol.go | 37 +++++++++++++++++++-------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go index ec549b764ee..442fa9144d4 100644 --- a/lib/go/thrift/json_protocol.go +++ b/lib/go/thrift/json_protocol.go @@ -60,6 +60,7 @@ func NewTJSONProtocolFactory() *TJSONProtocolFactory { } func (p *TJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + p.resetContextStack() // THRIFT-3735 if e := p.OutputListBegin(); e != nil { return e } @@ -202,17 +203,18 @@ func (p *TJSONProtocol) WriteBinary(v []byte) error { if e := p.OutputPreValue(); e != nil { return e } - if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil { + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } writer := base64.NewEncoder(base64.StdEncoding, p.writer) if _, e := writer.Write(v); e != nil { + p.writer.Reset(p.trans) // THRIFT-3735 return NewTProtocolException(e) } if e := writer.Close(); e != nil { return NewTProtocolException(e) } - if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil { + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } return p.OutputPostValue() @@ -220,6 +222,7 @@ func (p *TJSONProtocol) WriteBinary(v []byte) error { // Reading methods. func (p *TJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + p.resetContextStack() // THRIFT-3735 if isNull, err := p.ParseListBegin(); isNull || err != nil { return name, typeId, seqId, err } diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go index e739be9e870..73533223182 100644 --- a/lib/go/thrift/simple_json_protocol.go +++ b/lib/go/thrift/simple_json_protocol.go @@ -156,6 +156,7 @@ func mismatch(expected, actual string) error { } func (p *TSimpleJSONProtocol) WriteMessageBegin(name string, typeId TMessageType, seqId int32) error { + p.resetContextStack() // THRIFT-3735 if e := p.OutputListBegin(); e != nil { return e } @@ -269,17 +270,18 @@ func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error { if e := p.OutputPreValue(); e != nil { return e } - if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil { + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } writer := base64.NewEncoder(base64.StdEncoding, p.writer) if _, e := writer.Write(v); e != nil { + p.writer.Reset(p.trans) // THRIFT-3735 return NewTProtocolException(e) } if e := writer.Close(); e != nil { return NewTProtocolException(e) } - if _, e := p.writer.Write(JSON_QUOTE_BYTES); e != nil { + if _, e := p.write(JSON_QUOTE_BYTES); e != nil { return NewTProtocolException(e) } return p.OutputPostValue() @@ -287,6 +289,7 @@ func (p *TSimpleJSONProtocol) WriteBinary(v []byte) error { // Reading methods. func (p *TSimpleJSONProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqId int32, err error) { + p.resetContextStack() // THRIFT-3735 if isNull, err := p.ParseListBegin(); isNull || err != nil { return name, typeId, seqId, err } @@ -565,12 +568,12 @@ func (p *TSimpleJSONProtocol) OutputPreValue() error { cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1]) switch cxt { case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY: - if _, e := p.writer.Write(JSON_COMMA); e != nil { + if _, e := p.write(JSON_COMMA); e != nil { return NewTProtocolException(e) } break case _CONTEXT_IN_OBJECT_NEXT_VALUE: - if _, e := p.writer.Write(JSON_COLON); e != nil { + if _, e := p.write(JSON_COLON); e != nil { return NewTProtocolException(e) } break @@ -626,7 +629,7 @@ func (p *TSimpleJSONProtocol) OutputNull() error { if e := p.OutputPreValue(); e != nil { return e } - if _, e := p.writer.Write(JSON_NULL); e != nil { + if _, e := p.write(JSON_NULL); e != nil { return NewTProtocolException(e) } return p.OutputPostValue() @@ -684,7 +687,7 @@ func (p *TSimpleJSONProtocol) OutputString(s string) error { } func (p *TSimpleJSONProtocol) OutputStringData(s string) error { - _, e := p.writer.Write([]byte(s)) + _, e := p.write([]byte(s)) return NewTProtocolException(e) } @@ -692,7 +695,7 @@ func (p *TSimpleJSONProtocol) OutputObjectBegin() error { if e := p.OutputPreValue(); e != nil { return e } - if _, e := p.writer.Write(JSON_LBRACE); e != nil { + if _, e := p.write(JSON_LBRACE); e != nil { return NewTProtocolException(e) } p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST)) @@ -700,7 +703,7 @@ func (p *TSimpleJSONProtocol) OutputObjectBegin() error { } func (p *TSimpleJSONProtocol) OutputObjectEnd() error { - if _, e := p.writer.Write(JSON_RBRACE); e != nil { + if _, e := p.write(JSON_RBRACE); e != nil { return NewTProtocolException(e) } p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] @@ -714,7 +717,7 @@ func (p *TSimpleJSONProtocol) OutputListBegin() error { if e := p.OutputPreValue(); e != nil { return e } - if _, e := p.writer.Write(JSON_LBRACKET); e != nil { + if _, e := p.write(JSON_LBRACKET); e != nil { return NewTProtocolException(e) } p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST)) @@ -722,7 +725,7 @@ func (p *TSimpleJSONProtocol) OutputListBegin() error { } func (p *TSimpleJSONProtocol) OutputListEnd() error { - if _, e := p.writer.Write(JSON_RBRACKET); e != nil { + if _, e := p.write(JSON_RBRACKET); e != nil { return NewTProtocolException(e) } p.dumpContext = p.dumpContext[:len(p.dumpContext)-1] @@ -1318,3 +1321,17 @@ func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool { } return true } + +// Reset the context stack to its initial state. +func (p *TSimpleJSONProtocol) resetContextStack() { + p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)} + p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)} +} + +func (p *TSimpleJSONProtocol) write(b []byte) (int, error) { + n, err := p.writer.Write(b) + if err != nil { + p.writer.Reset(p.trans) // THRIFT-3735 + } + return n, err +}