Skip to content

Commit

Permalink
Fix new subrouter to inherit middlewares (#113)
Browse files Browse the repository at this point in the history
* Fix new subrouter to inherit middlewares

* zrouter: fix Run method
  • Loading branch information
lucaslopezf committed Apr 8, 2024
1 parent 2afb299 commit 6e35bd7
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
26 changes: 21 additions & 5 deletions pkg/zrouter/zrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ func (r *zrouter) NewSubRouter() ZRouter {
config: r.config,
}

for _, middleware := range r.middlewares {
newRouter.router.Use(middleware)
}

for _, middleware := range r.defaultMiddlewares {
newRouter.router.Use(middleware)
}

return newRouter
}

Expand Down Expand Up @@ -178,11 +186,7 @@ func (r *zrouter) Group(prefix string) Routes {
}

func (r *zrouter) Run(addr ...string) error {
address := defaultAddress
if len(addr) > 0 {
address = addr[0]
}

address := formatAddress(addr...)
r.config.Logger.Infof("Start server at %v", address)

if r.config.JWTUsageMetricsConfig.Enable {
Expand Down Expand Up @@ -314,6 +318,18 @@ func (r *zrouter) ServeFiles(routePattern string, httpHandler http.Handler) {
r.router.Handle(routePattern, httpHandler)
}

func formatAddress(addr ...string) string {
if len(addr) > 0 {
address := addr[0]
// Ensure the address starts with a colon if only a port number is provided
if !strings.Contains(address, ":") {
address = ":" + address
}
return address
}
return defaultAddress
}

func LogTopJWTPathMetrics(ctx context.Context, zCache zcache.RemoteCache, updateInterval time.Duration, topN int) {
if updateInterval == 0 {
updateInterval = defaultUpdateInterval
Expand Down
70 changes: 70 additions & 0 deletions pkg/zrouter/zrouter_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
package zrouter

import (
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"github.com/zondax/golem/pkg/logger"
"github.com/zondax/golem/pkg/zrouter/domain"
"github.com/zondax/golem/pkg/zrouter/zmiddlewares"
"net/http"
"net/http/httptest"
"testing"
)

const (
headerDefaultMiddleware = "X-Default-Middleware"
headerCustomMiddleware = "X-Custom-Middleware"
testValue = "applied"
)

type ZRouterSuite struct {
suite.Suite
router ZRouter
Expand Down Expand Up @@ -75,3 +83,65 @@ func TestValidateAppVersionAndRevision(t *testing.T) {

New(nil, nil)
}

func dummyDefaultMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(headerDefaultMiddleware, testValue)
next.ServeHTTP(w, r)
})
}

func dummyCustomMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(headerCustomMiddleware, testValue)
next.ServeHTTP(w, r)
})
}

func TestNewSubRouterWithMiddleware(t *testing.T) {
mainRouter := &zrouter{
router: chi.NewRouter(),
middlewares: []zmiddlewares.Middleware{dummyDefaultMiddleware},
defaultMiddlewares: []zmiddlewares.Middleware{dummyCustomMiddleware},
}

subRouter := mainRouter.NewSubRouter()

subRouter.(*zrouter).router.Get("/test", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

req, _ := http.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()

subRouter.(*zrouter).router.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code, "The request should be processed correctly")

assert.Equal(t, testValue, w.Header().Get(headerDefaultMiddleware))
assert.Equal(t, testValue, w.Header().Get(headerCustomMiddleware))
}

func TestFormatAddress(t *testing.T) {
var tests = []struct {
name string
input []string
expected string
}{
{"No input", []string{}, defaultAddress},
{"Just port", []string{"3030"}, ":3030"},
{"Colon and port", []string{":3030"}, ":3030"},
{"IP and port", []string{"127.0.0.1:3030"}, "127.0.0.1:3030"},
{"Hostname and port", []string{"localhost:3030"}, "localhost:3030"},
{"IPv6 with port", []string{"[::1]:3030"}, "[::1]:3030"},
{"IPv6 without port", []string{"::1"}, "::1"},
{"Full address with colon", []string{"http://localhost:3030"}, "http://localhost:3030"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := formatAddress(tt.input...)
assert.Equal(t, tt.expected, result, "Formatted address should match expected value")
})
}
}

0 comments on commit 6e35bd7

Please sign in to comment.