Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions internal/tests/mock_services.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,45 @@ func (m *MockServiceRegistry) Get(name string) any {
args := m.Called(name)
return args.Get(0)
}

type MockSecondaryStorage struct {
mock.Mock
}

func (m *MockSecondaryStorage) Get(ctx context.Context, key string) (any, error) {
args := m.Called(ctx, key)
return args.Get(0), args.Error(1)
}

func (m *MockSecondaryStorage) Set(ctx context.Context, key string, value any, ttl *time.Duration) error {
args := m.Called(ctx, key, value, ttl)
return args.Error(0)
}

func (m *MockSecondaryStorage) Delete(ctx context.Context, key string) error {
args := m.Called(ctx, key)
return args.Error(0)
}

func (m *MockSecondaryStorage) Incr(ctx context.Context, key string, ttl *time.Duration) (int, error) {
args := m.Called(ctx, key, ttl)
return args.Int(0), args.Error(1)
}

func (m *MockSecondaryStorage) TTL(ctx context.Context, key string) (*time.Duration, error) {
args := m.Called(ctx, key)
if v := args.Get(0); v != nil {
return v.(*time.Duration), args.Error(1)
}
return nil, args.Error(1)
}

func (m *MockSecondaryStorage) Scan(ctx context.Context, prefix string) ([]string, error) {
args := m.Called(ctx, prefix)
return args.Get(0).([]string), args.Error(1)
}

func (m *MockSecondaryStorage) Close() error {
args := m.Called()
return args.Error(0)
}
2 changes: 2 additions & 0 deletions models/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type SecondaryStorage interface {
Incr(ctx context.Context, key string, ttl *time.Duration) (int, error)
// TTL retrieves the time-to-live (TTL) for the given key.
TTL(ctx context.Context, key string) (*time.Duration, error)
// Scan returns all keys matching the given prefix that have not expired.
Scan(ctx context.Context, prefix string) ([]string, error)
// Close closes the storage and releases any resources.
Close() error
}
14 changes: 4 additions & 10 deletions plugins/bearer/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (p *BearerPlugin) validateBearerToken(reqCtx *models.RequestContext) error
return nil
}

userID, err := p.jwtService.ValidateToken(token)
actor, err := p.jwtService.ValidateToken(reqCtx.Request.Context(), token)
if err != nil {
reqCtx.SetJSONResponse(http.StatusUnauthorized, map[string]any{
"message": "Bearer token invalid or expired",
Expand All @@ -58,10 +58,7 @@ func (p *BearerPlugin) validateBearerToken(reqCtx *models.RequestContext) error
return nil
}

reqCtx.SetActorInContext(&models.Actor{
ID: userID,
Type: models.ActorUser,
})
reqCtx.SetActorInContext(actor)

return nil
}
Expand All @@ -78,15 +75,12 @@ func (p *BearerPlugin) validateBearerTokenOptional(reqCtx *models.RequestContext
return nil
}

userID, err := p.jwtService.ValidateToken(token)
actor, err := p.jwtService.ValidateToken(reqCtx.Request.Context(), token)
if err != nil {
return nil
}

reqCtx.SetActorInContext(&models.Actor{
ID: userID,
Type: models.ActorUser,
})
reqCtx.SetActorInContext(actor)

return nil
}
179 changes: 179 additions & 0 deletions plugins/bearer/hooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package bearer

import (
"errors"
"net/http"
"testing"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

internaltests "github.com/Authula/authula/internal/tests"
"github.com/Authula/authula/models"
bearertests "github.com/Authula/authula/plugins/bearer/tests"
)

func newTestBearerPlugin(jwtSvc *bearertests.MockJWTService) *BearerPlugin {
return &BearerPlugin{
config: BearerPluginConfig{HeaderName: "Authorization"},
jwtService: jwtSvc,
}
}

func newBearerRequestCtx(t *testing.T, header string) *models.RequestContext {
t.Helper()
req, _, reqCtx := internaltests.NewHandlerRequestWithActor(t, http.MethodGet, "/test", nil, nil)
if header != "" {
req.Header.Set("Authorization", header)
reqCtx.Headers = req.Header
}
return reqCtx
}

func TestValidateBearerToken(t *testing.T) {
t.Parallel()

tests := []struct {
name string
header string
setupMock func(*bearertests.MockJWTService)
preSetActor *models.Actor
wantHandled bool
wantStatus int
wantActor *models.Actor
}{
{
name: "actor_already_set",
preSetActor: &models.Actor{ID: "existing-user", Type: models.ActorUser},
setupMock: func(m *bearertests.MockJWTService) {
},
wantHandled: false,
wantActor: &models.Actor{ID: "existing-user", Type: models.ActorUser},
},
{
name: "no_token",
header: "",
setupMock: func(m *bearertests.MockJWTService) {
},
wantHandled: true,
wantStatus: http.StatusUnauthorized,
},
{
name: "invalid_token",
header: "Bearer invalid-token",
setupMock: func(m *bearertests.MockJWTService) {
m.On("ValidateToken", mock.Anything, "invalid-token").Return(nil, errors.New("invalid token")).Once()
},
wantHandled: true,
wantStatus: http.StatusUnauthorized,
},
{
name: "valid_user_token",
header: "Bearer valid-user-token",
setupMock: func(m *bearertests.MockJWTService) {
m.On("ValidateToken", mock.Anything, "valid-user-token").Return(&models.Actor{ID: "user-1", Type: models.ActorUser}, nil).Once()
},
wantActor: &models.Actor{ID: "user-1", Type: models.ActorUser, Scopes: []string{}, Metadata: map[string]any{}},
},
{
name: "valid_machine_token",
header: "Bearer valid-machine-token",
setupMock: func(m *bearertests.MockJWTService) {
m.On("ValidateToken", mock.Anything, "valid-machine-token").Return(&models.Actor{ID: "client-1", Type: models.ActorMachine, OrganizationID: internaltests.PtrString("org-1"), Scopes: []string{"read"}}, nil).Once()
},
wantActor: &models.Actor{ID: "client-1", Type: models.ActorMachine, OrganizationID: internaltests.PtrString("org-1"), Scopes: []string{"read"}, Metadata: map[string]any{}},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSvc := &bearertests.MockJWTService{}
tt.setupMock(mockSvc)

p := newTestBearerPlugin(mockSvc)
reqCtx := newBearerRequestCtx(t, tt.header)
if tt.preSetActor != nil {
reqCtx.Actor = tt.preSetActor
}

err := p.validateBearerToken(reqCtx)
require.NoError(t, err)

require.Equal(t, tt.wantHandled, reqCtx.Handled)
if tt.wantStatus != 0 {
require.Equal(t, tt.wantStatus, reqCtx.ResponseStatus)
}
if tt.wantActor != nil {
require.Equal(t, tt.wantActor, reqCtx.Actor)
}
mockSvc.AssertExpectations(t)
})
}
}

func TestValidateBearerTokenOptional(t *testing.T) {
t.Parallel()

tests := []struct {
name string
header string
setupMock func(*bearertests.MockJWTService)
preSetActor *models.Actor
wantHandled bool
wantActor *models.Actor
}{
{
name: "actor_already_set_optional",
preSetActor: &models.Actor{ID: "existing-user", Type: models.ActorUser},
setupMock: func(m *bearertests.MockJWTService) {
},
wantHandled: false,
wantActor: &models.Actor{ID: "existing-user", Type: models.ActorUser},
},
{
name: "no_token_optional",
header: "",
setupMock: func(m *bearertests.MockJWTService) {
},
wantHandled: false,
},
{
name: "invalid_token_optional",
header: "Bearer invalid-token",
setupMock: func(m *bearertests.MockJWTService) {
m.On("ValidateToken", mock.Anything, "invalid-token").Return(nil, errors.New("invalid token")).Once()
},
wantHandled: false,
},
{
name: "valid_token_optional",
header: "Bearer valid-user-token",
setupMock: func(m *bearertests.MockJWTService) {
m.On("ValidateToken", mock.Anything, "valid-user-token").Return(&models.Actor{ID: "user-1", Type: models.ActorUser}, nil).Once()
},
wantActor: &models.Actor{ID: "user-1", Type: models.ActorUser, Scopes: []string{}, Metadata: map[string]any{}},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSvc := &bearertests.MockJWTService{}
tt.setupMock(mockSvc)

p := newTestBearerPlugin(mockSvc)
reqCtx := newBearerRequestCtx(t, tt.header)
if tt.preSetActor != nil {
reqCtx.Actor = tt.preSetActor
}

err := p.validateBearerTokenOptional(reqCtx)
require.NoError(t, err)

require.Equal(t, tt.wantHandled, reqCtx.Handled)
if tt.wantActor != nil {
require.Equal(t, tt.wantActor, reqCtx.Actor)
}
mockSvc.AssertExpectations(t)
})
}
}
14 changes: 10 additions & 4 deletions plugins/bearer/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,16 @@ func (p *BearerPlugin) AuthMiddleware() func(http.Handler) http.Handler {
return
}

userID, err := p.jwtService.ValidateToken(token)
actor, err := p.jwtService.ValidateToken(r.Context(), token)
if err != nil {
p.writeUnauthorized(w, err)
return
}

ctx := context.WithValue(r.Context(), models.ContextUserID, userID)
ctx := context.WithValue(r.Context(), models.ContextAuthActor, actor)
if actor.ID != "" {
ctx = context.WithValue(ctx, models.ContextUserID, actor.ID)
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
Expand All @@ -90,8 +93,11 @@ func (p *BearerPlugin) OptionalAuthMiddleware() func(http.Handler) http.Handler
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, err := p.extractToken(r)
if err == nil && token != "" {
if userID, validateErr := p.jwtService.ValidateToken(token); validateErr == nil {
ctx := context.WithValue(r.Context(), models.ContextUserID, userID)
if actor, validateErr := p.jwtService.ValidateToken(r.Context(), token); validateErr == nil {
ctx := context.WithValue(r.Context(), models.ContextAuthActor, actor)
if actor.ID != "" {
ctx = context.WithValue(ctx, models.ContextUserID, actor.ID)
}
r = r.WithContext(ctx)
}
}
Expand Down
68 changes: 68 additions & 0 deletions plugins/bearer/plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package bearer

import (
"testing"

"github.com/stretchr/testify/require"

internaltests "github.com/Authula/authula/internal/tests"
"github.com/Authula/authula/models"
bearertests "github.com/Authula/authula/plugins/bearer/tests"
)

func TestBearerPlugin_Metadata(t *testing.T) {
t.Parallel()

plugin := New(BearerPluginConfig{})
metadata := plugin.Metadata()

require.NotEmpty(t, metadata.ID)
require.NotEmpty(t, metadata.Version)
require.NotEmpty(t, metadata.Description)
}

func TestBearerPlugin_Config(t *testing.T) {
t.Parallel()

cfg := BearerPluginConfig{HeaderName: "Custom-Auth", Enabled: true}
plugin := New(cfg)

returnedCfg := plugin.Config()
require.Equal(t, cfg, returnedCfg)
}

func TestBearerPlugin_Init(t *testing.T) {
t.Parallel()

t.Run("missing_jwt_service", func(t *testing.T) {
t.Parallel()
reg := &internaltests.MockServiceRegistry{}
reg.On("Get", models.ServiceJWT.String()).Return(nil).Once()

plugin := New(BearerPluginConfig{})
err := plugin.Init(&models.PluginContext{
Logger: &internaltests.MockLogger{},
ServiceRegistry: reg,
GetConfig: func() *models.Config { return &models.Config{} },
})
require.Error(t, err)
reg.AssertExpectations(t)
})

t.Run("success", func(t *testing.T) {
t.Parallel()
mockSvc := &bearertests.MockJWTService{}
reg := &internaltests.MockServiceRegistry{}
reg.On("Get", models.ServiceJWT.String()).Return(mockSvc).Once()

plugin := New(BearerPluginConfig{})
err := plugin.Init(&models.PluginContext{
Logger: &internaltests.MockLogger{},
ServiceRegistry: reg,
GetConfig: func() *models.Config { return &models.Config{} },
})
require.NoError(t, err)
require.Equal(t, mockSvc, plugin.jwtService)
reg.AssertExpectations(t)
})
}
21 changes: 21 additions & 0 deletions plugins/bearer/tests/mocks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package tests

import (
"context"

"github.com/stretchr/testify/mock"

"github.com/Authula/authula/models"
)

type MockJWTService struct {
mock.Mock
}

func (m *MockJWTService) ValidateToken(ctx context.Context, token string) (*models.Actor, error) {
args := m.Called(ctx, token)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(*models.Actor), args.Error(1)
}
Loading