99 changes: 86 additions & 13 deletions node/rpcstack.go
@@ -21,6 +21,7 @@ import (
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"

@@ -105,17 +106,94 @@ var gzPool = sync.Pool{
}

type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
resp http.ResponseWriter

gz *gzip.Writer
contentLength uint64 // total length of the uncompressed response
written uint64 // amount of written bytes from the uncompressed response
hasLength bool // true if uncompressed response had Content-Length
inited bool // true after init was called for the first time
}

// init runs just before response headers are written. Among other things, this function
// also decides whether compression will be applied at all.
func (w *gzipResponseWriter) init() {
if w.inited {
return
}
w.inited = true

hdr := w.resp.Header()
length := hdr.Get("content-length")
if len(length) > 0 {
if n, err := strconv.ParseUint(length, 10, 64); err == nil {
w.hasLength = true
w.contentLength = n
}
}

// Setting Transfer-Encoding to "identity" explicitly disables compression. net/http
// also recognizes this header value and uses it to disable "chunked" transfer
// encoding, trimming the header from the response. This means downstream handlers can
// set this without harm, even if they aren't wrapped by newGzipHandler.
//
// In go-ethereum, we use this signal to disable compression for certain error
// responses which are flushed out close to the write deadline of the response. For
// these cases, we want to avoid chunked transfer encoding and compression because
// they require additional output that may not get written in time.
passthrough := hdr.Get("transfer-encoding") == "identity"
if !passthrough {
w.gz = gzPool.Get().(*gzip.Writer)
w.gz.Reset(w.resp)
hdr.Del("content-length")
hdr.Set("content-encoding", "gzip")
}
}

func (w *gzipResponseWriter) Header() http.Header {
return w.resp.Header()
}

func (w *gzipResponseWriter) WriteHeader(status int) {
w.Header().Del("Content-Length")
w.ResponseWriter.WriteHeader(status)
w.init()
w.resp.WriteHeader(status)
}

func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
w.init()

if w.gz == nil {
// Compression is disabled.
return w.resp.Write(b)
}

n, err := w.gz.Write(b)
w.written += uint64(n)
if w.hasLength && w.written >= w.contentLength {
// The HTTP handler has finished writing the entire uncompressed response. Close
// the gzip stream to ensure the footer will be seen by the client in case the
// response is flushed after this call to write.
err = w.gz.Close()
}
return n, err
}

func (w *gzipResponseWriter) Flush() {
if w.gz != nil {
w.gz.Flush()
}
if f, ok := w.resp.(http.Flusher); ok {
f.Flush()
}
}

func (w *gzipResponseWriter) close() {
if w.gz == nil {
return
}
w.gz.Close()
gzPool.Put(w.gz)
w.gz = nil
}

func newGzipHandler(next http.Handler) http.Handler {
@@ -125,15 +203,10 @@ func newGzipHandler(next http.Handler) http.Handler {
return
}

w.Header().Set("Content-Encoding", "gzip")

gz := gzPool.Get().(*gzip.Writer)
defer gzPool.Put(gz)

gz.Reset(w)
defer gz.Close()
wrapper := &gzipResponseWriter{resp: w}
defer wrapper.close()

next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
next.ServeHTTP(wrapper, r)
})
}

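The init comment in the rpcstack.go hunk above describes an opt-out path: a downstream handler can set Transfer-Encoding to "identity" to skip compression entirely. Below is a minimal sketch of such a handler, not part of the change; the handler name and payload are illustrative.

```go
package node // sketch only; not part of the diff

import "net/http"

// writeUncompressedError opts out of gzip compression. When wrapped by
// newGzipHandler, the "identity" value makes init() take the passthrough
// branch, so the body is written uncompressed and without chunked encoding.
func writeUncompressedError(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Transfer-Encoding", "identity")
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusServiceUnavailable)
	w.Write([]byte(`{"jsonrpc":"2.0","error":{"code":-32000,"message":"shutting down"}}`))
}
```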
2 changes: 1 addition & 1 deletion p2p/simulations/adapters/inproc.go
@@ -218,7 +218,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error {
if err != nil {
return err
}
codec := rpc.NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON)
codec := rpc.NewFuncCodec(conn, func(v any, _ bool) error { return conn.WriteJSON(v) }, conn.ReadJSON)
handler.ServeCodec(codec, 0)
return nil
}
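The one-line change above exists because rpc.NewFuncCodec now expects a write callback that carries an error-response flag, which websocket's WriteJSON does not have. A hedged sketch of the assumed callback shape and the adapter pattern used here; the type and helper names are illustrative.

```go
package adapters // sketch only; not part of the diff

// writeFunc is the assumed shape of the write callback accepted by
// rpc.NewFuncCodec after this change; the boolean marks error responses.
type writeFunc func(v any, isErrorResponse bool) error

// adapt wraps a plain JSON writer (such as websocket.Conn.WriteJSON) that has
// no notion of error responses, discarding the flag.
func adapt(writeJSON func(v any) error) writeFunc {
	return func(v any, _ bool) error { return writeJSON(v) }
}
```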
5 changes: 3 additions & 2 deletions rpc/client.go
@@ -514,7 +514,7 @@ func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
return err
}
}
err := c.writeConn.writeJSON(ctx, msg)
err := c.writeConn.writeJSON(ctx, msg, false)
if err != nil {
c.writeConn = nil
if !retry {
@@ -647,7 +647,8 @@ func (c *Client) read(codec ServerCodec) {
for {
msgs, batch, err := codec.readBatch()
if _, ok := err.(*json.SyntaxError); ok {
codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()}))
msg := errorMessage(&parseError{err.Error()})
codec.writeJSON(context.Background(), msg, true)
}
if err != nil {
c.readErr <- err
38 changes: 30 additions & 8 deletions rpc/errors.go
@@ -56,7 +56,17 @@ var (
_ Error = new(invalidParamsError)
)

const defaultErrorCode = -32000
const (
defaultErrorCode = -32000
errcodeNotificationsUnsupported = -32001
errcodeTimeout = -32002
errcodePanic = -32603
errcodeMarshalError = -32603
)

const (
errMsgTimeout = "request timed out"
)

type methodNotFoundError struct{ method string }

@@ -81,13 +91,6 @@ func (e *parseError) ErrorCode() int { return -32700 }

func (e *parseError) Error() string { return e.message }

// received message isn't a valid request
type invalidRequestError struct{ message string }

func (e *invalidRequestError) ErrorCode() int { return -32600 }

func (e *invalidRequestError) Error() string { return e.message }

// received message is invalid
type invalidMessageError struct{ message string }

@@ -98,6 +101,25 @@ func (e *invalidMessageError) Error() string { return e.message }
// unable to decode supplied params, or an invalid number of parameters
type invalidParamsError struct{ message string }

// received message isn't a valid request
type invalidRequestError struct{ message string }

func (e *invalidRequestError) ErrorCode() int { return -32600 }

func (e *invalidRequestError) Error() string { return e.message }

// unable to decode supplied params, or an invalid number of parameters

func (e *invalidParamsError) ErrorCode() int { return -32602 }

func (e *invalidParamsError) Error() string { return e.message }

// internalServerError is used for server errors during request processing.
type internalServerError struct {
code int
message string
}

func (e *internalServerError) ErrorCode() int { return e.code }

func (e *internalServerError) Error() string { return e.message }
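The handler changes below pair internalServerError with the new constants to report timeouts. A minimal sketch of that pairing inside package rpc, assuming the package's Error interface (message plus error code); the wire shape noted in the comment is an assumption, not shown in this diff.

```go
// Sketch: the timeout error sent when a call exceeds its deadline. A client is
// assumed to see it as a JSON-RPC error object with code -32002 and message
// "request timed out".
var timeoutErr Error = &internalServerError{
	code:    errcodeTimeout,
	message: errMsgTimeout,
}
```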
141 changes: 132 additions & 9 deletions rpc/handler.go
@@ -91,12 +91,83 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
return h
}

// batchCallBuffer manages in progress call messages and their responses during a batch
// call. Calls need to be synchronized between the processing and timeout-triggering
// goroutines.
type batchCallBuffer struct {
mutex sync.Mutex
calls []*jsonrpcMessage
resp []*jsonrpcMessage
wrote bool
}

// nextCall returns the next unprocessed message.
func (b *batchCallBuffer) nextCall() *jsonrpcMessage {
b.mutex.Lock()
defer b.mutex.Unlock()

if len(b.calls) == 0 {
return nil
}
// The popping happens in `pushResponse`. The in progress call is kept
// so we can return an error for it in case of timeout.
msg := b.calls[0]
return msg
}

// pushResponse adds the response to the last call returned by nextCall.
func (b *batchCallBuffer) pushResponse(answer *jsonrpcMessage) {
b.mutex.Lock()
defer b.mutex.Unlock()

if answer != nil {
b.resp = append(b.resp, answer)
}
b.calls = b.calls[1:]
}

// write sends the responses.
func (b *batchCallBuffer) write(ctx context.Context, conn jsonWriter) {
b.mutex.Lock()
defer b.mutex.Unlock()

b.doWrite(ctx, conn, false)
}

// timeout sends the responses added so far. For the remaining unanswered call
// messages, it sends a timeout error response.
func (b *batchCallBuffer) timeout(ctx context.Context, conn jsonWriter) {
b.mutex.Lock()
defer b.mutex.Unlock()

for _, msg := range b.calls {
if !msg.isNotification() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
b.resp = append(b.resp, resp)
}
}
b.doWrite(ctx, conn, true)
}

// doWrite actually writes the response.
// This assumes b.mutex is held.
func (b *batchCallBuffer) doWrite(ctx context.Context, conn jsonWriter, isErrorResponse bool) {
if b.wrote {
return
}
b.wrote = true // can only write once
if len(b.resp) > 0 {
conn.writeJSON(ctx, b.resp, isErrorResponse)
}
}

// handleBatch executes all messages in a batch and returns the responses.
func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
// Emit error response for empty batches:
if len(msgs) == 0 {
h.startCallProc(func(cp *callProc) {
h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
resp := errorMessage(&invalidRequestError{"empty batch"})
h.conn.writeJSON(cp.ctx, resp, true)
})
return
}
@@ -113,16 +184,42 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
}
// Process calls on a goroutine because they may block indefinitely:
h.startCallProc(func(cp *callProc) {
answers := make([]*jsonrpcMessage, 0, len(msgs))
for _, msg := range calls {
if answer := h.handleCallMsg(cp, msg); answer != nil {
answers = append(answers, answer)
var (
timer *time.Timer
cancel context.CancelFunc
callBuffer = &batchCallBuffer{calls: calls, resp: make([]*jsonrpcMessage, 0, len(calls))}
)

cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()

// Cancel the request context after timeout and send an error response. Since the
// currently-running method might not return immediately on timeout, we must wait
// for the timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
callBuffer.timeout(cp.ctx, h.conn)
})
}

for {
// No need to handle the rest of the calls if timed out.
if cp.ctx.Err() != nil {
break
}
msg := callBuffer.nextCall()
if msg == nil {
break
}
resp := h.handleCallMsg(cp, msg)
callBuffer.pushResponse(resp)
}
h.addSubscriptions(cp.notifiers)
if len(answers) > 0 {
h.conn.writeJSON(cp.ctx, answers)
if timer != nil {
timer.Stop()
}
callBuffer.write(cp.ctx, h.conn)
h.addSubscriptions(cp.notifiers)
for _, n := range cp.notifiers {
n.activate()
}
@@ -135,10 +232,36 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
return
}
h.startCallProc(func(cp *callProc) {
var (
responded sync.Once
timer *time.Timer
cancel context.CancelFunc
)
cp.ctx, cancel = context.WithCancel(cp.ctx)
defer cancel()

// Cancel the request context after timeout and send an error response. Since the
// running method might not return immediately on timeout, we must wait for the
// timeout concurrently with processing the request.
if timeout, ok := ContextRequestTimeout(cp.ctx); ok {
timer = time.AfterFunc(timeout, func() {
cancel()
responded.Do(func() {
resp := msg.errorResponse(&internalServerError{errcodeTimeout, errMsgTimeout})
h.conn.writeJSON(cp.ctx, resp, true)
})
})
}

answer := h.handleCallMsg(cp, msg)
if timer != nil {
timer.Stop()
}
h.addSubscriptions(cp.notifiers)
if answer != nil {
h.conn.writeJSON(cp.ctx, answer)
responded.Do(func() {
h.conn.writeJSON(cp.ctx, answer, false)
})
}
for _, n := range cp.notifiers {
n.activate()
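Both handleBatch and handleMsg arm their timeout timers from ContextRequestTimeout, which is not part of this diff. A minimal sketch of the behaviour assumed here, deriving the timeout from the request context's deadline; the real helper may consult additional sources such as server configuration.

```go
package rpc // sketch only; not part of the diff

import (
	"context"
	"time"
)

// contextRequestTimeoutSketch approximates what ContextRequestTimeout is
// assumed to do in the handlers above: report a timeout derived from the
// context deadline, and whether one exists.
func contextRequestTimeoutSketch(ctx context.Context) (time.Duration, bool) {
	if deadline, ok := ctx.Deadline(); ok {
		return time.Until(deadline), true
	}
	return 0, false
}
```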