Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions servers/gateway/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package gateway

import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/7cav/api/datastores"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// fakeAuthDatastore embeds the Datastore interface so it satisfies the type
// without implementing every method; only ValidateApiKey is exercised by
// authMiddleware. Any other call panics (nil method) — a loud failure if a
// test accidentally reaches further into the datastore.
type fakeAuthDatastore struct {
datastores.Datastore
validateApiKey func(string) (*datastores.ApiKeyResult, error)
}

func (f *fakeAuthDatastore) ValidateApiKey(rawKey string) (*datastores.ApiKeyResult, error) {
return f.validateApiKey(rawKey)
}

// callMiddleware runs authMiddleware in front of a handler that records whether
// it was reached, and returns the recorded response plus the next-called flag.
func callMiddleware(t *testing.T, ds datastores.Datastore, authHeader string) (*httptest.ResponseRecorder, bool) {
t.Helper()
nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})
h := authMiddleware(ds, next)

req := httptest.NewRequest(http.MethodGet, "/api/v1/whatever", nil)
if authHeader != "" {
req.Header.Set("Authorization", authHeader)
}
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
return rr, nextCalled
}

func TestAuthMiddleware_NoAuthHeader_NamesBearerScheme(t *testing.T) {
ds := &fakeAuthDatastore{validateApiKey: func(string) (*datastores.ApiKeyResult, error) {
t.Fatal("ValidateApiKey must not be called when no bearer token is present")
return nil, nil
}}
rr, nextCalled := callMiddleware(t, ds, "")

assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.False(t, nextCalled)
body := rr.Body.String()
assert.Contains(t, body, "Bearer")
assert.Contains(t, body, "Authorization")
}

func TestAuthMiddleware_RawKeyNoBearerPrefix_NamesBearerScheme(t *testing.T) {
ds := &fakeAuthDatastore{validateApiKey: func(string) (*datastores.ApiKeyResult, error) {
t.Fatal("ValidateApiKey must not be called when the Bearer scheme is absent")
return nil, nil
}}
rr, nextCalled := callMiddleware(t, ds, "cav7_rawkeywithoutprefix")

assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.False(t, nextCalled)
body := rr.Body.String()
assert.Contains(t, body, "Bearer")
assert.Contains(t, body, "Authorization")
}

func TestAuthMiddleware_BadKey_GenericUnauthorizedNoLeak(t *testing.T) {
ds := &fakeAuthDatastore{validateApiKey: func(token string) (*datastores.ApiKeyResult, error) {
assert.Equal(t, "cav7_badkey", token)
return nil, nil // zero rows → nil result, no error
}}
rr, nextCalled := callMiddleware(t, ds, "Bearer cav7_badkey")

assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.False(t, nextCalled)
body := strings.TrimSpace(rr.Body.String())
// Generic — must NOT name the Bearer scheme (that branch is for scheme errors)
// and must not leak anything about the key.
assert.Equal(t, "Unauthorized", body)
assert.NotContains(t, body, "Bearer")
}

func TestAuthMiddleware_ValidKey_CallsNext(t *testing.T) {
ds := &fakeAuthDatastore{validateApiKey: func(string) (*datastores.ApiKeyResult, error) {
return &datastores.ApiKeyResult{KeyId: 1, UserId: 2}, nil
}}
rr, nextCalled := callMiddleware(t, ds, "Bearer cav7_goodkey")

require.True(t, nextCalled)
assert.Equal(t, http.StatusOK, rr.Code)
}

func TestAuthMiddleware_ValidateError_GenericUnauthorized(t *testing.T) {
ds := &fakeAuthDatastore{validateApiKey: func(string) (*datastores.ApiKeyResult, error) {
return nil, io.ErrUnexpectedEOF
}}
rr, nextCalled := callMiddleware(t, ds, "Bearer cav7_anykey")

assert.Equal(t, http.StatusUnauthorized, rr.Code)
assert.False(t, nextCalled)
assert.Equal(t, "Unauthorized", strings.TrimSpace(rr.Body.String()))
}
12 changes: 10 additions & 2 deletions servers/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,20 @@ func getOpenAPIHandler() http.Handler {
// cav7_ prefix (5) + 64 hex chars = 69; 128 gives generous headroom.
const maxTokenLen = 128

// errBearerScheme is the 401 body returned when the Authorization header is
// missing or doesn't carry a usable Bearer token (no/empty/oversized token).
// It names the expected format so callers who paste a raw key without the
// "Bearer " prefix get a self-explanatory error. The key-validation-failure
// branch stays the generic "Unauthorized" so it leaks nothing about whether a
// key exists, is expired, or lacks scopes.
const errBearerScheme = "Unauthorized: expected 'Authorization: Bearer <key>' header"

func authMiddleware(ds datastores.Datastore, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := datastores.ParseBearerToken(r.Header.Get("Authorization"), maxTokenLen)
if token == "" {
Warn.Printf("Unauthorized HTTP access attempt from %s", r.RemoteAddr)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
Warn.Printf("Unauthorized HTTP access attempt (bad bearer scheme) from %s", r.RemoteAddr)
http.Error(w, errBearerScheme, http.StatusUnauthorized)
return
}

Expand Down
77 changes: 77 additions & 0 deletions servers/grpc/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,83 @@ func TestAuthInterceptor_LogsRequestOnSuccess(t *testing.T) {
assert.Contains(t, logged, "key_id=17")
}

// rawAuthCtx sets the authorization metadata verbatim (no implicit "Bearer "
// prefix), so scheme-problem cases can be exercised.
func rawAuthCtx(authHeader, peerIP string) context.Context {
var ctx context.Context
if authHeader == "" {
ctx = context.Background()
} else {
ctx = metadata.NewIncomingContext(
context.Background(),
metadata.Pairs("authorization", authHeader),
)
}
return peer.NewContext(ctx, &peer.Peer{
Addr: &net.TCPAddr{IP: net.ParseIP(peerIP), Port: 4242},
})
}

func runInterceptor(t *testing.T, ds *fakeDatastore, ctx context.Context) (any, error, bool) {
t.Helper()
interceptor := NewAuthInterceptor(ds)
info := &grpc.UnaryServerInfo{FullMethod: "/proto.MilpacService/GetProfile"}
handlerCalled := false
resp, err := interceptor(ctx, "req", info, func(ctx context.Context, req any) (any, error) {
handlerCalled = true
return "ok", nil
})
return resp, err, handlerCalled
}

func TestAuthInterceptor_NoAuthHeader_NamesBearerScheme(t *testing.T) {
ds := &fakeDatastore{validateApiKey: func(string) (*datastores.ApiKeyResult, error) {
t.Fatal("ValidateApiKey must not be called when no authorization metadata is present")
return nil, nil
}}
_, err, called := runInterceptor(t, ds, rawAuthCtx("", "10.0.0.5"))

require.Error(t, err)
assert.False(t, called)
st, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Unauthenticated, st.Code())
assert.Contains(t, st.Message(), "Bearer")
assert.Contains(t, st.Message(), "Authorization")
}

func TestAuthInterceptor_RawKeyNoBearerPrefix_NamesBearerScheme(t *testing.T) {
ds := &fakeDatastore{validateApiKey: func(string) (*datastores.ApiKeyResult, error) {
t.Fatal("ValidateApiKey must not be called when the Bearer scheme is absent")
return nil, nil
}}
_, err, called := runInterceptor(t, ds, rawAuthCtx("cav7_rawkeynoprefix", "10.0.0.5"))

require.Error(t, err)
assert.False(t, called)
st, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Unauthenticated, st.Code())
assert.Contains(t, st.Message(), "Bearer")
assert.Contains(t, st.Message(), "Authorization")
}

func TestAuthInterceptor_BadKey_GenericNoLeak(t *testing.T) {
ds := &fakeDatastore{validateApiKey: func(token string) (*datastores.ApiKeyResult, error) {
assert.Equal(t, "cav7_badkey", token)
return nil, nil
}}
_, err, called := runInterceptor(t, ds, buildAuthCtx("cav7_badkey", "10.0.0.5"))

require.Error(t, err)
assert.False(t, called)
st, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Unauthenticated, st.Code())
// Generic — must not name the Bearer scheme (that's reserved for scheme errors).
assert.NotContains(t, st.Message(), "Bearer")
}

func TestAuthInterceptor_LogsRequestOnAuthFailure(t *testing.T) {
ds := &fakeDatastore{
validateApiKey: func(token string) (*datastores.ApiKeyResult, error) {
Expand Down
15 changes: 11 additions & 4 deletions servers/grpc/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ import (

const maxTokenLen = 128

// errBearerScheme mirrors the HTTP gateway's scheme-problem message: it names
// the expected Authorization format so callers missing the "Bearer " prefix
// get a self-explanatory error. Returned for any Bearer-scheme problem (no
// metadata, no authorization header, or an empty/oversized token). The
// key-validation-failure branch stays generic to leak nothing about the key.
const errBearerScheme = "Unauthenticated: expected 'Authorization: Bearer <key>' metadata"

func NewAuthInterceptor(ds datastores.Datastore) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
peerAddr := "unknown"
Expand All @@ -45,22 +52,22 @@ func NewAuthInterceptor(ds datastores.Datastore) grpc.UnaryServerInterceptor {

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "missing metadata")
return nil, status.Error(codes.Unauthenticated, errBearerScheme)
}

authHeaders := md.Get("authorization")
if len(authHeaders) < 1 {
return nil, status.Errorf(codes.Unauthenticated, "missing authorization token")
return nil, status.Error(codes.Unauthenticated, errBearerScheme)
}

token := datastores.ParseBearerToken(authHeaders[0], maxTokenLen)
if token == "" {
return nil, status.Errorf(codes.Unauthenticated, "missing authorization token")
return nil, status.Error(codes.Unauthenticated, errBearerScheme)
}

key, err := ds.ValidateApiKey(token)
if err != nil || key == nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid api key")
return nil, status.Error(codes.Unauthenticated, "invalid api key")
}

keyID = strconv.FormatUint(uint64(key.KeyId), 10)
Expand Down