Skip to content

Commit

Permalink
Pass Logger into handler
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamSLevy committed Oct 14, 2019
1 parent 34ada31 commit 0f65c00
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 27 deletions.
1 change: 1 addition & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
// Client.DebugRequest is true.
type Logger interface {
Println(...interface{})
Printf(string, ...interface{})
}

// Client embeds http.Client and provides a convenient way to make JSON-RPC 2.0
Expand Down
3 changes: 2 additions & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
//
// func StartServer() {
// methods := jsonrpc2.MethodMap{"version": versionMethod}
// http.ListenAndServe(":8080", jsonrpc2.HTTPRequestHandler(methods))
// http.ListenAndServe(":8080", jsonrpc2.HTTPRequestHandler(methods,
// log.New(os.Stderr, "", 0)))
// }
package jsonrpc2
4 changes: 3 additions & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"

// Specify the package name to avoid goimports from reverting this
// import to an older version.
Expand Down Expand Up @@ -124,7 +126,7 @@ func Example() {
"get_data": getData,
}
jsonrpc2.DebugMethodFunc = true
handler := jsonrpc2.HTTPRequestHandler(methods)
handler := jsonrpc2.HTTPRequestHandler(methods, log.New(os.Stdout, "", 0))
http.ListenAndServe(":18888", handler)
}()

Expand Down
27 changes: 18 additions & 9 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
)

Expand All @@ -44,33 +46,40 @@ import (
//
// This will panic if a method name beginning with "rpc." is used. See
// MethodMap for more details.
func HTTPRequestHandler(methods MethodMap) http.HandlerFunc {
//
// The handler will use lgr to log any errors and debug information, if
// DebugMethodFunc is true. If lgr is nil, the default Logger from the log
// package is used.
func HTTPRequestHandler(methods MethodMap, lgr Logger) http.HandlerFunc {
for name := range methods {
if strings.HasPrefix(name, "rpc.") {
panic(fmt.Errorf("invalid method name: %v", name))
}
}
if lgr == nil {
lgr = log.New(os.Stderr, "", log.LstdFlags)
}

return func(w http.ResponseWriter, req *http.Request) {
res := handle(methods, req)
res := handle(methods, req, lgr)
if res == nil {
return
}
enc := json.NewEncoder(w)
// We should never have a JSON encoding related error because
// MethodFunc.call() already Marshaled any user provided Data
// or Result, and everything else is marshalable.
//
// However an error can be returned related to w.Write, which
// there is nothing we can do about, so we just log it here.
enc := json.NewEncoder(w)
if err := enc.Encode(res); err != nil {
logger.Printf("error writing response: %v", err)
lgr.Printf("req.Body.Write(): %v", err)
}
}
}

// handle an http.Request for the given methods.
func handle(methods MethodMap, req *http.Request) interface{} {
func handle(methods MethodMap, req *http.Request, lgr Logger) interface{} {
// Read all bytes of HTTP request body.
reqBytes, err := ioutil.ReadAll(req.Body)
if err != nil {
Expand Down Expand Up @@ -102,7 +111,7 @@ func handle(methods MethodMap, req *http.Request) interface{} {
// Process each Request, omitting any returned Response that is empty.
responses := make(BatchResponse, 0, len(rawReqs))
for _, rawReq := range rawReqs {
res := processRequest(req.Context(), methods, rawReq)
res := processRequest(req.Context(), methods, rawReq, lgr)
if res == (Response{}) {
// Don't respond to Notifications.
continue
Expand All @@ -128,7 +137,7 @@ func handle(methods MethodMap, req *http.Request) interface{} {
// using the methods defined in methods. If res is zero valued, then the
// Request was a Notification and should not be responded to.
func processRequest(ctx context.Context,
methods MethodMap, rawReq json.RawMessage) (res Response) {
methods MethodMap, rawReq json.RawMessage, lgr Logger) (res Response) {

// Unmarshal into req with an error on any unknown fields.
var req Request
Expand Down Expand Up @@ -162,12 +171,12 @@ func processRequest(ctx context.Context,
if !ok {
return Response{Error: errorMethodNotFound(req.Method)}
}
res = method.call(ctx, req.Method, params)
res = method.call(ctx, req.Method, params, lgr)

// Log the method name if debugging is enabled and the method had an
// internal error.
if DebugMethodFunc && res.HasError() && res.Error.Code == ErrorCodeInternal {
logger.Printf("Method: %#v\n\n", req.Method)
lgr.Printf("Method: %#v\n\n", req.Method)
}

return res
Expand Down
12 changes: 4 additions & 8 deletions methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"runtime"
)

Expand All @@ -36,8 +34,6 @@ import (
// MethodFunc.
var DebugMethodFunc = false

var logger = log.New(os.Stdout, "", 0)

// MethodMap associates method names with MethodFuncs and is passed to
// HTTPRequestHandler to generate a corresponding http.HandlerFunc.
//
Expand Down Expand Up @@ -88,7 +84,7 @@ type MethodFunc func(ctx context.Context, params json.RawMessage) interface{}
// and validate and sanitize the returned Response. If the method panics or
// returns an invalid Response, an Internal Error is returned.
func (method MethodFunc) call(ctx context.Context,
name string, params json.RawMessage) (res Response) {
name string, params json.RawMessage, lgr Logger) (res Response) {

var result interface{}
defer func() {
Expand All @@ -99,10 +95,10 @@ func (method MethodFunc) call(ctx context.Context,
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
logger.Printf("jsonrpc2: panic running method %#v: %v\n%s",
lgr.Printf("jsonrpc2: panic running method %#v: %v\n%s",
method, err, buf)
logger.Printf("jsonrpc2: Params: %v", string(params))
logger.Printf("jsonrpc2: Return: %#v", result)
lgr.Printf("jsonrpc2: Params: %v", string(params))
lgr.Printf("jsonrpc2: Return: %#v", result)
}
}
}()
Expand Down
12 changes: 4 additions & 8 deletions methods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"encoding/json"
"fmt"
"log"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -80,14 +79,11 @@ func TestMethodFuncCall(t *testing.T) {
assert := assert.New(t)

var buf bytes.Buffer
logger.SetOutput(&buf) // hide output
log := log.New(&buf, "", 0) // record output
DebugMethodFunc = true
defer func() {
logger = log.New(os.Stdout, "", 0)
}()

for _, test := range testMethods {
res := test.Func.call(context.Background(), "", nil)
res := test.Func.call(context.Background(), "", nil, log)
if test.Error == nil {
assert.Equal(errorInternal(nil), res.Error, test.Name)
} else {
Expand All @@ -101,7 +97,7 @@ func TestMethodFuncCall(t *testing.T) {
var f MethodFunc = func(_ context.Context, _ json.RawMessage) interface{} {
return Error{100, "custom", "data"}
}
res := f.call(context.Background(), "", nil)
res := f.call(context.Background(), "", nil, log)
if assert.NotNil(res.Error) {
assert.Equal(Error{
Code: 100,
Expand All @@ -114,7 +110,7 @@ func TestMethodFuncCall(t *testing.T) {
f = func(_ context.Context, _ json.RawMessage) interface{} {
return ErrorInvalidParams("data")
}
res = f.call(context.Background(), "", nil)
res = f.call(context.Background(), "", nil, log)
if assert.NotNil(res.Error) {
e := ErrorInvalidParams(json.RawMessage(`"data"`))
assert.Equal(e, res.Error)
Expand Down

0 comments on commit 0f65c00

Please sign in to comment.