Skip to content

Commit

Permalink
add logger unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aradwann committed Jan 6, 2024
1 parent e567be5 commit 87a81c8
Show file tree
Hide file tree
Showing 4 changed files with 303 additions and 47 deletions.
40 changes: 33 additions & 7 deletions gapi/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,36 @@ package gapi

import (
"context"
"log/slog"
"net/http"
"os"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func GrpcLogger(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
"log/slog"
)

// GrpcLogger logs gRPC requests and responses.
func GrpcLogger(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
startTime := time.Now()
res, err := handler(ctx, req)
duration := time.Since(startTime)
statusCode := codes.Unknown

if st, ok := status.FromError(err); ok {
statusCode = st.Code()
}

logLevel := slog.LevelInfo
var errForLog slog.Attr

if err != nil {
logLevel = slog.LevelError
errForLog = slog.String("error", err.Error())
}

slog.LogAttrs(context.Background(),
logLevel,
"received grpc req",
Expand All @@ -40,19 +46,21 @@ func GrpcLogger(ctx context.Context, req any, info *grpc.UnaryServerInfo, handle
return res, err
}

// HttpLogger logs HTTP requests and responses.
func HttpLogger(handler http.Handler) http.Handler {

return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
startTime := time.Now()
rec := &ResponseRecorder{
ResponseWriter: res,
StatusCode: http.StatusOK,
}

handler.ServeHTTP(rec, req)
duration := time.Since(startTime)

logLevel := slog.LevelInfo
var errForLog slog.Attr

if rec.StatusCode != http.StatusOK {
logLevel = slog.LevelError
errForLog = slog.String("body", string(rec.Body))
Expand All @@ -63,29 +71,47 @@ func HttpLogger(handler http.Handler) http.Handler {
"received http req",
slog.String("protocol", "http"),
slog.String("method", req.Method),
slog.String("method", req.RequestURI),
slog.String("uri", req.RequestURI),
slog.Int("status_code", rec.StatusCode),
slog.String("status_text", http.StatusText(rec.StatusCode)),
errForLog,
slog.Duration("duration", duration),
)

})
}

// ResponseRecorder to get the status code from the original response writer
// ResponseRecorder is used to get the status code from the original response writer.
type ResponseRecorder struct {
http.ResponseWriter
StatusCode int
Body []byte
}

// WriteHeader captures the status code.
func (rec *ResponseRecorder) WriteHeader(statusCode int) {
rec.StatusCode = statusCode
rec.ResponseWriter.WriteHeader(statusCode)
}

// Write captures the response body.
func (rec *ResponseRecorder) Write(body []byte) (int, error) {
rec.Body = body
return rec.ResponseWriter.Write(body)
}

func NewDevelopmentLoggerHandler() slog.Handler {
return slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
AddSource: false,
Level: nil,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
if a.Key == slog.TimeKey && len(groups) == 0 {
return slog.Attr{}
}
return a
},
})
}

func NewProductionLoggerHandler() slog.Handler {
return slog.NewJSONHandler(os.Stdout, nil)
}
154 changes: 154 additions & 0 deletions gapi/logger_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package gapi

import (
"bytes"
"context"
"io"
"log"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)

func TestGrpcLogger(t *testing.T) {
// Mock gRPC handler function
mockHandler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
}

// Mock gRPC UnaryServerInfo
mockInfo := &grpc.UnaryServerInfo{
FullMethod: "/example.Service/Method",
}

// Execute GrpcLogger
_, err := GrpcLogger(context.Background(), nil, mockInfo, mockHandler)
require.NoError(t, err)

// Capture logs during the test
logs := make([]string, 0)
captureLogs(func() {
// Execute GrpcLogger
_, err := GrpcLogger(context.Background(), nil, mockInfo, mockHandler)
require.NoError(t, err)

}, func(log string) {
logs = append(logs, log)
})

// Verify the log output
assert.Contains(t, logs[0], "received grpc req")
assert.Contains(t, logs[0], "protocol=grpc")
assert.Contains(t, logs[0], "method=/example.Service/Method")
assert.Contains(t, logs[0], "status_code=0") // The default value for codes.OK
assert.Contains(t, logs[0], "status_text=OK")
assert.Contains(t, logs[0], "duration=")

}

func TestHttpLogger(t *testing.T) {
// Mock HTTP handler function
mockHandler := http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
res.WriteHeader(http.StatusOK)
res.Write([]byte("OK"))
})

// Create a test HTTP request
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()

// Capture logs during the test
logs := make([]string, 0)
captureLogs(func() {
// Execute HttpLogger
HttpLogger(mockHandler).ServeHTTP(rec, req)
}, func(log string) {
logs = append(logs, log)
})

// Verify the log output
assert.Contains(t, logs[0], "received http req")
assert.Contains(t, logs[0], "status_code=200")
assert.Contains(t, logs[0], "status_text=OK")
}

func TestHttpLoggerErr(t *testing.T) {
// Mock HTTP handler function
mockHandler := http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
res.WriteHeader(http.StatusInternalServerError)
res.Write([]byte("Error"))
})

// Create a test HTTP request
req := httptest.NewRequest("GET", "/test", nil)
rec := httptest.NewRecorder()

// Capture logs during the test
logs := make([]string, 0)
captureLogs(func() {
// Execute HttpLogger
HttpLogger(mockHandler).ServeHTTP(rec, req)
}, func(log string) {
logs = append(logs, log)
})

// Verify the log output using substring matching
assert.Contains(t, logs[0], "received http req")
assert.Contains(t, logs[0], "status_code=500")
assert.Contains(t, logs[0], "status_text=\"Internal Server Error\"")
assert.Contains(t, logs[0], "body=Error")
}

func TestResponseRecorder(t *testing.T) {
// Mock ResponseWriter
mockResponseWriter := httptest.NewRecorder()

// Create a ResponseRecorder
rec := &ResponseRecorder{
ResponseWriter: mockResponseWriter,
StatusCode: http.StatusOK,
}

// Write some data to ResponseRecorder
rec.Write([]byte("Test Body"))

// Verify the StatusCode
require.Equal(t, rec.StatusCode, http.StatusOK)

// Verify the captured body
require.Equal(t, string(rec.Body), "Test Body")

}

// captureLogs captures logs produced during the execution of a function.
func captureLogs(fn func(), logCallback func(string)) {
originalOutput := log.Writer()
defer func() { log.SetOutput(originalOutput) }()

r, w, err := os.Pipe()
if err != nil {
panic(err)
}
log.SetOutput(w)

var wg sync.WaitGroup
wg.Add(1)

go func() {
defer wg.Done()
var buf bytes.Buffer
io.Copy(&buf, r)
logCallback(buf.String())
}()

fn()

w.Close()
wg.Wait()
}
72 changes: 72 additions & 0 deletions gapi/metadata_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package gapi

import (
"context"
"net"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
)

func TestExtractMetadata(t *testing.T) {
testCases := []struct {
name string
headers metadata.MD
peerAddr net.IP
expected Metadata
}{
{
name: "ExtractMetadataFromGateway",
headers: metadata.Pairs(
grpcGatewayUserAgentHeader, "grpc-gateway-user-agent-value",
xForwardedForHeader, "127.0.0.1",
),
expected: Metadata{
UserAgent: "grpc-gateway-user-agent-value",
ClientIP: "127.0.0.1",
},
},
{
name: "ExtractMetadataFromGrpc",
headers: metadata.Pairs(
userAgentHeader, "user-agent-value",
),
peerAddr: net.ParseIP("127.0.0.1"),
expected: Metadata{
UserAgent: "user-agent-value",
ClientIP: "127.0.0.1",
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
server := &Server{}
ctx := context.Background()

// Set metadata in context
if len(tc.headers) > 0 {
ctx = metadata.NewIncomingContext(ctx, tc.headers)
}

// Set peer information in context
if tc.peerAddr != nil {
p := &peer.Peer{
Addr: &net.IPAddr{
IP: tc.peerAddr,
},
}
ctx = peer.NewContext(ctx, p)
}

// Execute the extractMetadata method
result := server.extractMetadata(ctx)

// Verify the extracted metadata
assert.Equal(t, tc.expected.UserAgent, result.UserAgent)
assert.Equal(t, tc.expected.ClientIP, result.ClientIP)
})
}
}
Loading

0 comments on commit 87a81c8

Please sign in to comment.