From 5c630efa37914af70655fa5044170ac73348fa07 Mon Sep 17 00:00:00 2001 From: eloe868 Date: Tue, 23 Nov 2021 10:03:48 +0800 Subject: [PATCH] reponse sent handler --- header.go | 5 ++++- http.go | 35 ++++++++++++++++++++++------------- server.go | 23 ++++++++++++++++++----- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/header.go b/header.go index 4cbe9f3832..4cc3c1111d 100644 --- a/header.go +++ b/header.go @@ -45,6 +45,8 @@ type ResponseHeader struct { bufKV argsKV cookies []argsKV + + written int64 } // RequestHeader represents HTTP request header. @@ -1621,7 +1623,8 @@ func refreshServerDate() { // Write writes response header to w. func (h *ResponseHeader) Write(w *bufio.Writer) error { - _, err := w.Write(h.Header()) + n, err := w.Write(h.Header()) + h.written = int64(n) return err } diff --git a/http.go b/http.go index aca5f8acaa..a06102916c 100644 --- a/http.go +++ b/http.go @@ -99,6 +99,8 @@ type Response struct { laddr net.Addr manualBodyReader io.ReadCloser + + written int64 } // SetHost sets host for the request. @@ -492,11 +494,13 @@ func (req *Request) BodyWriteTo(w io.Writer) error { // BodyWriteTo writes response body to w. func (resp *Response) BodyWriteTo(w io.Writer) error { if resp.bodyStream != nil { - _, err := copyZeroAlloc(w, resp.bodyStream) + n, err := copyZeroAlloc(w, resp.bodyStream) + resp.written += n resp.closeBodyStream() //nolint:errcheck return err } - _, err := w.Write(resp.bodyBytes()) + n, err := w.Write(resp.bodyBytes()) + resp.written += int64(n) return err } @@ -1673,7 +1677,7 @@ func (w *flushWriter) Write(p []byte) (int, error) { // See also WriteTo. func (resp *Response) Write(w *bufio.Writer) error { sendBody := !resp.mustSkipBody() - + resp.written = 0 if resp.bodyStream != nil { return resp.writeBodyStream(w, sendBody) } @@ -1684,12 +1688,16 @@ func (resp *Response) Write(w *bufio.Writer) error { resp.Header.SetContentLength(bodyLen) } if err := resp.Header.Write(w); err != nil { + resp.written += resp.Header.written return err } + resp.written += resp.Header.written if sendBody { - if _, err := w.Write(body); err != nil { + if n, err := w.Write(body); err != nil { + resp.written += int64(n) return err } + resp.written += int64(bodyLen) } return nil } @@ -1712,12 +1720,12 @@ func (req *Request) writeBodyStream(w *bufio.Writer) error { } if contentLength >= 0 { if err = req.Header.Write(w); err == nil { - err = writeBodyFixedSize(w, req.bodyStream, int64(contentLength)) + _, err = writeBodyFixedSize(w, req.bodyStream, int64(contentLength)) } } else { req.Header.SetContentLength(-1) if err = req.Header.Write(w); err == nil { - err = writeBodyChunked(w, req.bodyStream) + _, err = writeBodyChunked(w, req.bodyStream) } } err1 := req.closeBodyStream() @@ -1760,7 +1768,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error err = w.Flush() } if err == nil && sendBody { - err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)) + resp.written, err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)) } } } else { @@ -1770,7 +1778,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error err = w.Flush() } if err == nil && sendBody { - err = writeBodyChunked(w, resp.bodyStream) + resp.written, err = writeBodyChunked(w, resp.bodyStream) } } } @@ -1778,6 +1786,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error if err == nil { err = err1 } + resp.written += resp.Header.written return err } @@ -1845,7 +1854,7 @@ type httpWriter interface { Write(w *bufio.Writer) error } -func writeBodyChunked(w *bufio.Writer, r io.Reader) error { +func writeBodyChunked(w *bufio.Writer, r io.Reader) (int64, error) { vbuf := copyBufPool.Get() buf := vbuf.([]byte) @@ -1871,7 +1880,7 @@ func writeBodyChunked(w *bufio.Writer, r io.Reader) error { } copyBufPool.Put(vbuf) - return err + return int64(n), err } func limitedReaderSize(r io.Reader) int64 { @@ -1882,12 +1891,12 @@ func limitedReaderSize(r io.Reader) int64 { return lr.N } -func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) error { +func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) (int64, error) { if size > maxSmallFileSize { // w buffer must be empty for triggering // sendfile path in bufio.Writer.ReadFrom. if err := w.Flush(); err != nil { - return err + return 0, err } } @@ -1896,7 +1905,7 @@ func writeBodyFixedSize(w *bufio.Writer, r io.Reader, size int64) error { if n != size && err == nil { err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size) } - return err + return n, err } func copyZeroAlloc(w io.Writer, r io.Reader) (int64, error) { diff --git a/server.go b/server.go index cb4c823dfe..caa7e6f3fe 100644 --- a/server.go +++ b/server.go @@ -135,6 +135,8 @@ func ListenAndServeTLSEmbed(addr string, certData, keyData []byte, handler Reque // must be limited. type RequestHandler func(ctx *RequestCtx) +type ResponseSentHandler func(ctx *RequestCtx, size int64) + // ServeHandler must process tls.Config.NextProto negotiated requests. type ServeHandler func(c net.Conn) error @@ -156,6 +158,8 @@ type Server struct { // Instead the user should use `recover` to handle these situations. Handler RequestHandler + RespSentHandler ResponseSentHandler + // ErrorHandler for returning a response in case of an error while receiving or parsing the request. // // The following is a non-exhaustive list of errors that can be expected as argument: @@ -2297,8 +2301,8 @@ func (s *Server) serveConn(c net.Conn) (err error) { if !ctx.IsGet() && ctx.IsHead() { ctx.Response.SkipBody = true } - reqReset = true - ctx.Request.Reset() + //reqReset = true + //ctx.Request.Reset() hijackHandler = ctx.hijackHandler ctx.hijackHandler = nil @@ -2341,7 +2345,10 @@ func (s *Server) serveConn(c net.Conn) (err error) { if bw == nil { bw = acquireWriter(ctx) } - if err = writeResponse(ctx, bw); err != nil { + err = writeResponse(ctx, bw, s.RespSentHandler) + reqReset = true + ctx.Request.Reset() + if err != nil { break } @@ -2363,6 +2370,9 @@ func (s *Server) serveConn(c net.Conn) (err error) { releaseWriter(s, bw) bw = nil } + }else{ + reqReset = true + ctx.Request.Reset() } if hijackHandler != nil { @@ -2499,11 +2509,14 @@ func (ctx *RequestCtx) LastTimeoutErrorResponse() *Response { return ctx.timeoutResponse } -func writeResponse(ctx *RequestCtx, w *bufio.Writer) error { +func writeResponse(ctx *RequestCtx, w *bufio.Writer, sent ResponseSentHandler) error { if ctx.timeoutResponse != nil { panic("BUG: cannot write timed out response") } err := ctx.Response.Write(w) + if sent != nil { + sent(ctx, ctx.Response.written) + } ctx.Response.Reset() return err } @@ -2789,7 +2802,7 @@ func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverNam if bw == nil { bw = acquireWriter(ctx) } - writeResponse(ctx, bw) //nolint:errcheck + writeResponse(ctx, bw, s.RespSentHandler) //nolint:errcheck bw.Flush() return bw }