From 5e499e65290f86c0a5608f891782eb7a3de55194 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Mon, 6 Apr 2026 11:37:15 +0530 Subject: [PATCH 1/3] fix(deps): address CVE-2026-34986 and drop go-jose/v3 - Bump github.com/go-jose/go-jose/v4 to v4.1.4 (patched for CVE-2026-34986). - Upgrade github.com/coreos/go-oidc/v3 to v3.17.0 so the OIDC stack uses go-jose/v4 only; removes the indirect go-jose/v3 dependency. Made-with: Cursor --- go.mod | 5 ++--- go.sum | 15 ++++----------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index 9d5314c6..d46afaab 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,12 @@ require ( github.com/99designs/gqlgen v0.17.73 github.com/arangodb/go-driver v1.6.0 github.com/aws/aws-sdk-go v1.47.4 - github.com/coreos/go-oidc/v3 v3.6.0 + github.com/coreos/go-oidc/v3 v3.17.0 github.com/couchbase/gocb/v2 v2.6.4 github.com/ekristen/gorm-libsql v0.0.0-20231101204708-6e113112bcc2 github.com/gin-gonic/gin v1.9.1 github.com/glebarez/sqlite v1.10.0 - github.com/go-jose/go-jose/v4 v4.1.3 + github.com/go-jose/go-jose/v4 v4.1.4 github.com/gocql/gocql v1.6.0 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/google/uuid v1.6.0 @@ -55,7 +55,6 @@ require ( github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect - github.com/go-jose/go-jose/v3 v3.0.4 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect diff --git a/go.sum b/go.sum index 347e8ed5..6e5bc493 100644 --- a/go.sum +++ b/go.sum @@ -60,8 +60,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o= -github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc= +github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= +github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/couchbase/gocb/v2 v2.6.4 h1:o5k5JnxYkgamVL9svx+vbXc7vKF5X72tNt/qORs+L30= github.com/couchbase/gocb/v2 v2.6.4/go.mod h1:W/cHlBGfendPh53WzRaF1KIXTovI9DaI7EPeeqIsnmc= @@ -98,10 +98,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= github.com/glebarez/sqlite v1.10.0 h1:u4gt8y7OND/cCei/NMHmfbLxF6xP2wgKcT/BJf2pYkc= github.com/glebarez/sqlite v1.10.0/go.mod h1:IJ+lfSOmiekhQsFTJRx/lHtGYmCdtAiTaf5wI9u5uHA= -github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= -github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= -github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= -github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= @@ -153,7 +151,6 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -350,7 +347,6 @@ golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58 golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= @@ -403,7 +399,6 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -413,7 +408,6 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -424,7 +418,6 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 002624ad35ad1f05473b85065a85ab43c7f520f1 Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Mon, 6 Apr 2026 17:43:57 +0530 Subject: [PATCH 2/3] chore: update aws --- .github/workflows/release.yaml | 2 +- CLAUDE.md | 2 + Dockerfile | 9 + Makefile | 22 +- cmd/root.go | 4 +- docs/storage-optional-null-fields.md | 46 +++ go.mod | 23 +- go.sum | 55 ++- internal/config/config.go | 2 +- internal/integration_tests/test_helper.go | 24 +- internal/memory_store/db/provider_test.go | 172 +++++----- internal/memory_store/db/test_config_test.go | 114 +++++++ internal/memory_store/provider_test.go | 24 +- internal/rate_limit/provider.go | 4 + internal/rate_limit/redis.go | 4 +- internal/storage/db/dynamodb/audit_log.go | 166 +++++---- internal/storage/db/dynamodb/authenticator.go | 30 +- .../storage/db/dynamodb/email_template.go | 78 ++--- internal/storage/db/dynamodb/env.go | 32 +- internal/storage/db/dynamodb/health_check.go | 14 +- internal/storage/db/dynamodb/marshal.go | 141 ++++++++ internal/storage/db/dynamodb/ops.go | 316 ++++++++++++++++++ internal/storage/db/dynamodb/otp.go | 50 ++- internal/storage/db/dynamodb/provider.go | 95 +++--- internal/storage/db/dynamodb/session.go | 4 +- internal/storage/db/dynamodb/session_token.go | 235 ++++++------- internal/storage/db/dynamodb/shared.go | 40 --- internal/storage/db/dynamodb/tables.go | 184 ++++++++++ internal/storage/db/dynamodb/user.go | 182 ++++++---- .../db/dynamodb/verification_requests.go | 89 ++--- internal/storage/db/dynamodb/webhook.go | 106 +++--- internal/storage/db/dynamodb/webhook_log.go | 65 ++-- internal/storage/provider_test.go | 3 +- internal/storage/schemas/audit_log.go | 4 +- internal/storage/schemas/otp.go | 2 +- internal/storage/schemas/user.go | 3 + 36 files changed, 1647 insertions(+), 699 deletions(-) create mode 100644 docs/storage-optional-null-fields.md create mode 100644 internal/memory_store/db/test_config_test.go create mode 100644 internal/storage/db/dynamodb/marshal.go create mode 100644 internal/storage/db/dynamodb/ops.go delete mode 100644 internal/storage/db/dynamodb/shared.go create mode 100644 internal/storage/db/dynamodb/tables.go diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index d12b2ebe..9b9386eb 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -75,7 +75,7 @@ jobs: output: "trivy-results.sarif" severity: "CRITICAL,HIGH" - name: Upload Trivy results - uses: github/codeql-action/upload-sarif@v3 + uses: github/codeql-action/upload-sarif@v4 if: always() with: sarif_file: "trivy-results.sarif" diff --git a/CLAUDE.md b/CLAUDE.md index 328621dd..0e2a4718 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -20,6 +20,7 @@ make build-dashboard # Build admin UI (web/dashboard) make generate-graphql # Regenerate after schema.graphqls change # Testing (TEST_DBS env var selects databases, default: postgres) +# Optional: TEST_ENABLE_REDIS=1 runs Redis memory_store unit tests (Redis on localhost:6380). make test # Docker Postgres (default) make test-sqlite # SQLite in-memory (no Docker) make test-mongodb # Docker MongoDB @@ -40,6 +41,7 @@ go clean --testcache && TEST_DBS="sqlite,postgres" go test -p 1 -v -run TestSign - Token management: `internal/token/` - Tests: `internal/integration_tests/` - Frontend: `web/app/` (user UI) | `web/dashboard/` (admin UI) +- Optional NULL semantics across SQL/document DBs and DynamoDB: `docs/storage-optional-null-fields.md` **Pattern**: Every subsystem uses `Dependencies` struct + `New()` → `Provider` interface. diff --git a/Dockerfile b/Dockerfile index 9e061416..038f2091 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,11 @@ # syntax=docker/dockerfile:1.4 # Use BuildKit for cache mounts (faster CI: DOCKER_BUILDKIT=1) +# +# Alpine v3.23 main still ships busybox 1.37.0-r30 (e.g. CVE-2025-60876); edge/main has r31+. +# Pin busybox from edge until the stable branch backports it. See alpine/aports work item #17940. FROM golang:1.25-alpine3.23 AS go-builder +ARG ALPINE_EDGE_MAIN=https://dl-cdn.alpinelinux.org/alpine/edge/main +RUN apk add --no-cache -X "${ALPINE_EDGE_MAIN}" "busybox>=1.37.0-r31" WORKDIR /authorizer ARG TARGETPLATFORM @@ -33,6 +38,8 @@ RUN --mount=type=cache,target=/go/pkg/mod \ chmod 755 build/${GOOS}/${GOARCH}/authorizer FROM alpine:3.23.3 AS node-builder +ARG ALPINE_EDGE_MAIN=https://dl-cdn.alpinelinux.org/alpine/edge/main +RUN apk add --no-cache -X "${ALPINE_EDGE_MAIN}" "busybox>=1.37.0-r31" WORKDIR /authorizer COPY web/app/package*.json web/app/ COPY web/dashboard/package*.json web/dashboard/ @@ -47,6 +54,8 @@ COPY web/dashboard web/dashboard RUN cd web/app && npm run build && cd ../dashboard && npm run build FROM alpine:3.23.3 +ARG ALPINE_EDGE_MAIN=https://dl-cdn.alpinelinux.org/alpine/edge/main +RUN apk add --no-cache -X "${ALPINE_EDGE_MAIN}" "busybox>=1.37.0-r31" ARG TARGETARCH=amd64 diff --git a/Makefile b/Makefile index 3c42d81b..274fbb05 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,10 @@ DEFAULT_VERSION=0.1.0-local VERSION := $(or $(VERSION),$(DEFAULT_VERSION)) DOCKER_IMAGE ?= authorizerdev/authorizer:$(VERSION) +# Full module test run. TEST_DBS selects storage backends (storage, integration_tests, +# memory_store/db). Redis memory_store tests run only when TEST_ENABLE_REDIS=1. +GO_TEST_ALL := go test -p 1 -v ./... + .PHONY: all bootstrap build build-app build-dashboard build-local-image build-push-image trivy-scan all: build build-app build-dashboard @@ -40,50 +44,50 @@ dev: go run main.go --database-type=sqlite --database-url=test.db --jwt-type=HS256 --jwt-secret=test --admin-secret=admin --client-id=123456 --client-secret=secret test: test-cleanup test-docker-up - go clean --testcache && TEST_DBS="postgres" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="postgres" $(GO_TEST_ALL) $(MAKE) test-cleanup test-postgres: test-cleanup-postgres docker run -d --name authorizer_postgres -p 5434:5432 -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres postgres sleep 3 - go clean --testcache && TEST_DBS="postgres" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="postgres" $(GO_TEST_ALL) docker rm -vf authorizer_postgres test-sqlite: - go clean --testcache && TEST_DBS="sqlite" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="sqlite" $(GO_TEST_ALL) test-mongodb: test-cleanup-mongodb docker run -d --name authorizer_mongodb_db -p 27017:27017 mongo:4.4.15 sleep 3 - go clean --testcache && TEST_DBS="mongodb" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="mongodb" $(GO_TEST_ALL) docker rm -vf authorizer_mongodb_db test-scylladb: test-cleanup-scylladb docker run -d --name authorizer_scylla_db -p 9042:9042 scylladb/scylla sleep 15 - go clean --testcache && TEST_DBS="scylladb" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="scylladb" $(GO_TEST_ALL) docker rm -vf authorizer_scylla_db test-arangodb: test-cleanup-arangodb docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.10.3 sleep 5 - go clean --testcache && TEST_DBS="arangodb" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="arangodb" $(GO_TEST_ALL) docker rm -vf authorizer_arangodb test-dynamodb: test-cleanup-dynamodb docker run -d --name authorizer_dynamodb -p 8000:8000 amazon/dynamodb-local:latest sleep 3 - go clean --testcache && TEST_DBS="dynamodb" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="dynamodb" $(GO_TEST_ALL) docker rm -vf authorizer_dynamodb test-couchbase: test-cleanup-couchbase docker run -d --name authorizer_couchbase -p 8091-8097:8091-8097 -p 11210:11210 -p 11207:11207 -p 18091-18095:18091-18095 -p 18096:18096 -p 18097:18097 couchbase:latest sh scripts/couchbase-test.sh - go clean --testcache && TEST_DBS="couchbase" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="couchbase" $(GO_TEST_ALL) docker rm -vf authorizer_couchbase test-all-db: test-cleanup test-docker-up - go clean --testcache && TEST_DBS="postgres,sqlite,mongodb,arangodb,scylladb,dynamodb,couchbase" go test -p 1 -v ./... + go clean --testcache && TEST_DBS="postgres,sqlite,mongodb,arangodb,scylladb,dynamodb,couchbase" $(GO_TEST_ALL) $(MAKE) test-cleanup # Start all test database containers diff --git a/cmd/root.go b/cmd/root.go index a00dcb00..3ac48f8d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -55,7 +55,7 @@ var ( defaultTwitterScopes = []string{"tweet.read", "users.read"} defaultRobloxScopes = []string{"openid", "profile"} // Default RPS cap per IP; raised from 10 to reduce false positives on busy UIs. - defaultRateLimitRPS = float64(30) + defaultRateLimitRPS = 30 defaultRateLimitBurst = 20 ) @@ -161,7 +161,7 @@ func init() { f.BoolVar(&rootArgs.config.DisableAdminHeaderAuth, "disable-admin-header-auth", false, "Disable admin authentication via X-Authorizer-Admin-Secret header") // Rate limiting flags - f.Float64Var(&rootArgs.config.RateLimitRPS, "rate-limit-rps", defaultRateLimitRPS, "Maximum requests per second per IP for rate limiting") + f.IntVar(&rootArgs.config.RateLimitRPS, "rate-limit-rps", defaultRateLimitRPS, "Maximum requests per second per IP for rate limiting") f.IntVar(&rootArgs.config.RateLimitBurst, "rate-limit-burst", defaultRateLimitBurst, "Maximum burst size per IP for rate limiting") f.BoolVar(&rootArgs.config.RateLimitFailClosed, "rate-limit-fail-closed", false, "On rate-limit backend errors, reject with 503 instead of allowing the request") diff --git a/docs/storage-optional-null-fields.md b/docs/storage-optional-null-fields.md new file mode 100644 index 00000000..469dddfd --- /dev/null +++ b/docs/storage-optional-null-fields.md @@ -0,0 +1,46 @@ +# Optional fields and NULL semantics across storage providers + +This document explains how **nullable** fields (especially `*int64` timestamps such as `email_verified_at`, `phone_number_verified_at`, `revoked_timestamp`) are **stored**, **updated**, and why we **do not** add broad `json`/`bson` **`omitempty`** tags to those struct fields. + +## Goal + +Application code uses Go **nil pointers** to mean “unset / not verified / not revoked.” Updates must **clear** a previously set value when the pointer is **nil**, matching SQL **NULL** semantics. + +## Behaviour by provider + +| Provider | Typical update path | Nil pointer on update | +|----------|---------------------|------------------------| +| **SQL (GORM)** | `Save(&user)` | Written as **SQL NULL**. | +| **Cassandra** | JSON → map → `UPDATE` | Nil map values become **`= null`** in CQL. | +| **MongoDB** | `UpdateOne` with `$set` and the `User` struct | Driver marshals nil pointers as **BSON Null** when the field is **not** `omitempty`, so the field is cleared in the document. | +| **Couchbase** | `Upsert` full document | `encoding/json` encodes nil pointers as JSON **`null`** unless the field uses `json:",omitempty"`, in which case the key is **omitted** and old values can persist. | +| **ArangoDB** | `UpdateDocument` with struct | Encoding follows JSON-style rules; nil pointers become **`null`** when not omitted by tags. | +| **DynamoDB** | `UpdateItem` with **SET** from marshalled attributes | Nil pointers are **omitted from SET** (see `internal/storage/db/dynamodb/marshal.go`). Attributes are **not** removed automatically, so **explicit REMOVE** is required to clear a previously stored attribute. Implemented for users in `internal/storage/db/dynamodb/user.go` (`updateByHashKeyWithRemoves`, `userDynamoRemoveAttrsIfNil`). Reads may normalize `0` → unset via `normalizeUserOptionalPtrs`. | + +## Why not use `omitempty` on `json` / `bson` for nullable auth fields? + +For **document** databases, **`omitempty`** means: *if this pointer is nil, do not include this key in the encoded payload.* + +During an **update**, omitting a key usually means **“do not change this field”**, not **“set to null.”** That reproduces the DynamoDB-class bug: the old value remains. + +Therefore, for fields where **nil must clear** stored state, keep **`json` / `bson` tags without `omitempty`** (as in `internal/storage/schemas/user.go`) unless every call site is proven to do a **full document replace** and you have verified the driver behaviour end-to-end. + +MongoDB’s own guidance aligns with this: `omitempty` skips marshaling empty values, which is wrong when you need to persist **null** to clear a field in `$set`. + +## DynamoDB specifics + +- **PutItem**: Omitting nil pointers keeps items small; optional attributes may be absent (same idea as “omitempty” on write, but implemented in custom `marshalStruct` by skipping nil pointers). +- **UpdateItem**: Only **SET** attributes present in the marshalled map. Clearing requires **`REMOVE`** for the corresponding attribute names when the Go field is nil. +- Do **not** rely on adding `dynamo:",omitempty"` alone to “fix” updates: the custom marshaller already skips nil pointers; the gap was **REMOVE** on update, not tag-based omission. + +## Related code + +- Schema: `internal/storage/schemas/user.go` +- DynamoDB user update + REMOVE list: `internal/storage/db/dynamodb/user.go` +- DynamoDB update helper: `internal/storage/db/dynamodb/ops.go` (`updateByHashKeyWithRemoves`) +- DynamoDB marshal/unmarshal: `internal/storage/db/dynamodb/marshal.go` + +## TEST_DBS and memory_store tests + +- **`internal/memory_store/db`**: Runs one subtest per entry in `TEST_DBS` (URLs aligned with `internal/integration_tests/test_helper.go` — keep `test_config_test.go` in sync when adding backends). +- **`internal/memory_store` (Redis / in-memory)**: Not driven by `TEST_DBS`. In-memory tests always run; **Redis** subtests run only when **`TEST_ENABLE_REDIS=1`** (or `true`) and Redis is reachable (e.g. `localhost:6380` in `provider_test.go`). See `redisMemoryStoreTestsEnabled` in `provider_test.go`. diff --git a/go.mod b/go.mod index d46afaab..8609a543 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,13 @@ go 1.25.5 require ( github.com/99designs/gqlgen v0.17.73 github.com/arangodb/go-driver v1.6.0 - github.com/aws/aws-sdk-go v1.47.4 + github.com/aws/aws-sdk-go-v2 v1.41.5 + github.com/aws/aws-sdk-go-v2/config v1.32.14 + github.com/aws/aws-sdk-go-v2/credentials v1.19.14 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.37 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.37 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1 + github.com/aws/smithy-go v1.24.3 github.com/coreos/go-oidc/v3 v3.17.0 github.com/couchbase/gocb/v2 v2.6.4 github.com/ekristen/gorm-libsql v0.0.0-20231101204708-6e113112bcc2 @@ -15,7 +21,6 @@ require ( github.com/gocql/gocql v1.6.0 github.com/golang-jwt/jwt/v4 v4.5.2 github.com/google/uuid v1.6.0 - github.com/guregu/dynamo v1.20.2 github.com/pquerna/otp v1.4.0 github.com/prometheus/client_golang v1.23.2 github.com/redis/go-redis/v9 v9.6.3 @@ -41,10 +46,21 @@ require ( github.com/agnivade/levenshtein v1.2.1 // indirect github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230512164433-5d1fd1a340c9 // indirect github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.14 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.21 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect - github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/couchbase/gocbcore/v10 v10.2.8 // indirect @@ -76,7 +92,6 @@ require ( github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect diff --git a/go.sum b/go.sum index 6e5bc493..53c54f41 100644 --- a/go.sum +++ b/go.sum @@ -34,9 +34,44 @@ github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e h1:Xg+hGrY2 github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e/go.mod h1:mq7Shfa/CaixoDxiyAAc5jZ6CVBAyPaNQCGS7mkj4Ho= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= -github.com/aws/aws-sdk-go v1.44.306/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= -github.com/aws/aws-sdk-go v1.47.4 h1:IyhNbmPt+5ldi5HNzv7ZnXiqSglDMaJiZlzj4Yq3qnk= -github.com/aws/aws-sdk-go v1.47.4/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY= +github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= +github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.37 h1:5jh3kI8vDKuAcNa87z3eytYvBCE4Tyk2S8vjdcLoMek= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.37/go.mod h1:Q1MNQdT5LEs31od7h6zHZF2a6jjl+oI6/kBH3QYipoY= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.37 h1:8clVLUTp0bBO7SZx5Nw0Q1XhF4ItE8W1hboDK9eHIN0= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.37/go.mod h1:U0KNwNOuLUPi/vdIWh6qYVBK4v8zKjN6BPpbsHSmPPA= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1 h1:Vk+a1j2pXZHkkYqHmEdpwe8eX6NDtFSBGfzuauMEWYQ= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1/go.mod h1:wHrWCwhXZrl2PuCP5t36UTacy9fCHDJ+vw1r3qxTL5M= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.14 h1:Cnlebj/RmCf/4O3q4suVLLB/SBhbQf4zCQre6Dav+4E= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.14/go.mod h1:lB9U9zBLviMTUHcHaaJ/vDBkRpHxV5775VJcdnm1DFk= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.21 h1:FTg+rVAPx1W21jsO57pxDS1ESy9a/JLFoaHeFubflJA= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.21/go.mod h1:92xP4VIS1yO3eF2NPBaHGF4cmyZow8TmFzSaz1nNgzo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= +github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg= +github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= @@ -53,8 +88,6 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= -github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= @@ -165,8 +198,6 @@ github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/z github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/guregu/dynamo v1.20.2 h1:b3DZX68Nv0kCWGbMMRbLukWttkALUoiomtzSrDCDiJo= -github.com/guregu/dynamo v1.20.2/go.mod h1:rNSE8PT6IaNbcEno0/i0y6E5XFHDWUyLRxpGlF8O5CU= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= @@ -193,10 +224,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= @@ -363,7 +390,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= @@ -377,7 +403,6 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -393,7 +418,6 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -403,7 +427,6 @@ golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= @@ -413,7 +436,6 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= @@ -453,7 +475,6 @@ gopkg.in/sourcemap.v1 v1.0.5/go.mod h1:2RlvNNSMglmRrcvhfuzp4hQHwOtjxlbjX7UPY/GXb gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/config/config.go b/internal/config/config.go index 77f94795..a43984e2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -254,7 +254,7 @@ type Config struct { // Rate Limiting // RateLimitRPS is the maximum requests per second per IP - RateLimitRPS float64 + RateLimitRPS int // RateLimitBurst is the maximum burst size per IP RateLimitBurst int // RateLimitFailClosed rejects requests when the rate limit backend errors (default: fail-open). diff --git a/internal/integration_tests/test_helper.go b/internal/integration_tests/test_helper.go index 8b2cd28c..32878c91 100644 --- a/internal/integration_tests/test_helper.go +++ b/internal/integration_tests/test_helper.go @@ -124,10 +124,17 @@ func getDBURL(dbType string) string { } } -// getTestConfig returns a test config for the default database (postgres). -// For multi-DB testing, use runForEachDB instead. +// getTestConfig returns config for integration tests that expect a single storage backend. +// When TEST_DBS is unset, it behaves like "postgres" (one entry). When TEST_DBS lists +// exactly one database (e.g. make test-dynamodb sets dynamodb), that backend is used. +// If TEST_DBS lists multiple databases, Postgres is used so legacy tests keep a predictable default; +// use runForEachDB to exercise more than one DB. func getTestConfig() *config.Config { - return getTestConfigForDB(constants.DbTypePostgres, "postgres://postgres:postgres@localhost:5434/postgres") + dbs := getTestDBs() + if len(dbs) == 1 { + return getTestConfigForDB(dbs[0].DbType, dbs[0].DbURL) + } + return getTestConfigForDB(constants.DbTypeSqlite, "test.db") } // getTestConfigForDB returns a test config for a specific database type and URL @@ -168,6 +175,11 @@ func getTestConfigForDB(dbType, dbURL string) *config.Config { cfg.CouchBaseBucket = "authorizer_test" } + // DynamoDB Local (and AWS) expect a region for signing; avoid picking up real AWS keys in tests. + if dbType == constants.DbTypeDynamoDB { + cfg.AWSRegion = "us-east-1" + } + return cfg } @@ -203,6 +215,12 @@ func initTestSetup(t *testing.T, cfg *config.Config) *testSetup { // Initialize logger logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() + if cfg.DatabaseType == constants.DbTypeDynamoDB { + // Match storage tests: use static creds from config instead of ambient AWS_* env. + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + } + if cfg.DatabaseType == constants.DbTypeSqlite || cfg.DatabaseType == constants.DbTypeLibSQL { cfg.DatabaseURL = filepath.Join(t.TempDir(), "authorizer_integration.db") } diff --git a/internal/memory_store/db/provider_test.go b/internal/memory_store/db/provider_test.go index 347465c7..347f4aaa 100644 --- a/internal/memory_store/db/provider_test.go +++ b/internal/memory_store/db/provider_test.go @@ -1,126 +1,128 @@ package db import ( + "os" "path/filepath" "testing" "time" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/authorizerdev/authorizer/internal/config" + "github.com/authorizerdev/authorizer/internal/constants" "github.com/authorizerdev/authorizer/internal/storage" ) -// TestDBMemoryStoreProvider tests the database-backed memory store provider -// This test requires a database to be configured +// TestDBMemoryStoreProvider tests the database-backed memory store against each database in +// TEST_DBS (same semantics as integration_tests). Redis is irrelevant here. func TestDBMemoryStoreProvider(t *testing.T) { - dbFile := filepath.Join(t.TempDir(), "authorizer_test.db") - cfg := &config.Config{ - DatabaseType: "sqlite", - DatabaseURL: dbFile, - Env: "test", + entries := storageTestDBEntriesFromEnv() + if len(entries) == 0 { + t.Fatal("TEST_DBS produced no database configurations") } - // Create storage provider - storageProvider, err := storage.New(cfg, &storage.Dependencies{}) - if err != nil { - t.Skipf("Skipping test: failed to create storage provider: %v", err) - return - } + for _, e := range entries { + t.Run("db="+e.dbType, func(t *testing.T) { + tempSQLite := filepath.Join(t.TempDir(), "memory_store_test.db") + dbURL := resolveSQLiteTestURL(e.dbType, e.dbURL, tempSQLite) + cfg := buildStorageTestConfigForMemoryStore(e.dbType, dbURL) + if cfg.DatabaseType == constants.DbTypeDynamoDB { + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + } + + log := zerolog.New(zerolog.NewTestWriter(t)) + storageProvider, err := storage.New(cfg, &storage.Dependencies{Log: &log}) + if err != nil { + t.Skipf("skipping: storage provider for %s: %v", e.dbType, err) + return + } - // Create DB memory store provider - p, err := NewDBProvider(cfg, &Dependencies{ - StorageProvider: storageProvider, - }) - require.NoError(t, err) - require.NotNil(t, p) + p, err := NewDBProvider(cfg, &Dependencies{ + Log: &log, + StorageProvider: storageProvider, + }) + require.NoError(t, err) + require.NotNil(t, p) - // Test SetUserSession and GetUserSession - err = p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix()) - assert.NoError(t, err) + err = p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) - err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix()) - assert.NoError(t, err) + err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) - // Get session - key, err := p.GetUserSession("auth_provider:123", "session_token_key") - assert.NoError(t, err) - assert.Equal(t, "test_hash123", key) + key, err := p.GetUserSession("auth_provider:123", "session_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_hash123", key) - key, err = p.GetUserSession("auth_provider:123", "access_token_key") - assert.NoError(t, err) - assert.Equal(t, "test_jwt123", key) + key, err = p.GetUserSession("auth_provider:123", "access_token_key") + assert.NoError(t, err) + assert.Equal(t, "test_jwt123", key) - // Test expiration - err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(1*time.Second).Unix()) - assert.NoError(t, err) + err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(1*time.Second).Unix()) + assert.NoError(t, err) - time.Sleep(2 * time.Second) + time.Sleep(2 * time.Second) - key, err = p.GetUserSession("auth_provider:124", "session_token_key") - assert.Empty(t, key) - assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:124", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) - // Test DeleteUserSession: for DB-backed store, the key argument is the suffix - // (e.g. \"key\"), while the stored keys are \"session_token_key\", \"access_token_key\", etc. - // This matches the in-memory provider behavior and real usage in the codebase. - err = p.DeleteUserSession("auth_provider:123", "key") - assert.NoError(t, err) + err = p.DeleteUserSession("auth_provider:123", "key") + assert.NoError(t, err) - key, err = p.GetUserSession("auth_provider:123", "session_token_key") - assert.Empty(t, key) - assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) - // Test DeleteAllUserSessions - err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix()) - assert.NoError(t, err) + err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) - err = p.DeleteAllUserSessions("123") - assert.NoError(t, err) + err = p.DeleteAllUserSessions("123") + assert.NoError(t, err) - key, err = p.GetUserSession("auth_provider:123", "session_token_key1") - assert.Empty(t, key) - assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:123", "session_token_key1") + assert.Empty(t, key) + assert.Error(t, err) - // Test DeleteSessionForNamespace - err = p.SetUserSession("auth_provider:125", "session_token_key", "test_hash125", time.Now().Add(60*time.Second).Unix()) - assert.NoError(t, err) + err = p.SetUserSession("auth_provider:125", "session_token_key", "test_hash125", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) - err = p.DeleteSessionForNamespace("auth_provider") - assert.NoError(t, err) + err = p.DeleteSessionForNamespace("auth_provider") + assert.NoError(t, err) - key, err = p.GetUserSession("auth_provider:125", "session_token_key") - assert.Empty(t, key) - assert.Error(t, err) + key, err = p.GetUserSession("auth_provider:125", "session_token_key") + assert.Empty(t, key) + assert.Error(t, err) - // Test MFA sessions - err = p.SetMfaSession("auth_provider:123", "session123", time.Now().Add(60*time.Second).Unix()) - assert.NoError(t, err) + err = p.SetMfaSession("auth_provider:123", "session123", time.Now().Add(60*time.Second).Unix()) + assert.NoError(t, err) - key, err = p.GetMfaSession("auth_provider:123", "session123") - assert.NoError(t, err) - assert.Equal(t, "auth_provider:123", key) + key, err = p.GetMfaSession("auth_provider:123", "session123") + assert.NoError(t, err) + assert.Equal(t, "auth_provider:123", key) - err = p.DeleteMfaSession("auth_provider:123", "session123") - assert.NoError(t, err) + err = p.DeleteMfaSession("auth_provider:123", "session123") + assert.NoError(t, err) - key, err = p.GetMfaSession("auth_provider:123", "session123") - assert.Error(t, err) - assert.Empty(t, key) + key, err = p.GetMfaSession("auth_provider:123", "session123") + assert.Error(t, err) + assert.Empty(t, key) - // Test OAuth state - err = p.SetState("test_state_key", "test_state_value") - assert.NoError(t, err) + err = p.SetState("test_state_key", "test_state_value") + assert.NoError(t, err) - state, err := p.GetState("test_state_key") - assert.NoError(t, err) - assert.Equal(t, "test_state_value", state) + state, err := p.GetState("test_state_key") + assert.NoError(t, err) + assert.Equal(t, "test_state_value", state) - err = p.RemoveState("test_state_key") - assert.NoError(t, err) + err = p.RemoveState("test_state_key") + assert.NoError(t, err) - state, err = p.GetState("test_state_key") - assert.Error(t, err) - assert.Empty(t, state) + state, err = p.GetState("test_state_key") + assert.Error(t, err) + assert.Empty(t, state) + }) + } } diff --git a/internal/memory_store/db/test_config_test.go b/internal/memory_store/db/test_config_test.go new file mode 100644 index 00000000..0dc48b95 --- /dev/null +++ b/internal/memory_store/db/test_config_test.go @@ -0,0 +1,114 @@ +package db + +import ( + "os" + "strings" + + "github.com/authorizerdev/authorizer/internal/config" + "github.com/authorizerdev/authorizer/internal/constants" +) + +// storageDBEntry matches one entry from TEST_DBS (same URLs as internal/integration_tests +// getTestDBs / getDBURL — keep these in sync when adding backends). +type storageDBEntry struct { + dbType string + dbURL string +} + +func storageTestDBEntriesFromEnv() []storageDBEntry { + testDBsEnv := os.Getenv("TEST_DBS") + if testDBsEnv == "" { + testDBsEnv = "postgres" + } + var out []storageDBEntry + for _, dbType := range strings.Split(testDBsEnv, ",") { + dbType = strings.TrimSpace(dbType) + if dbType == "" { + continue + } + u := dbURLForMemoryStoreStorageTest(dbType) + if u == "" { + continue + } + out = append(out, storageDBEntry{dbType: dbType, dbURL: u}) + } + return out +} + +func dbURLForMemoryStoreStorageTest(dbType string) string { + switch dbType { + case constants.DbTypePostgres: + return "postgres://postgres:postgres@localhost:5434/postgres" + case constants.DbTypeSqlite: + return "test.db" + case constants.DbTypeLibSQL: + return "test.db" + case constants.DbTypeMysql: + return "root:password@tcp(localhost:3306)/authorizer" + case constants.DbTypeMariaDB: + return "root:password@tcp(localhost:3307)/authorizer" + case constants.DbTypeSqlserver: + return "sqlserver://sa:Password123@localhost:1433?database=authorizer" + case constants.DbTypeMongoDB: + return "mongodb://localhost:27017" + case constants.DbTypeArangoDB: + return "http://localhost:8529" + case constants.DbTypeScyllaDB, constants.DbTypeCassandraDB: + return "127.0.0.1:9042" + case constants.DbTypeDynamoDB: + return "http://127.0.0.1:8000" + case constants.DbTypeCouchbaseDB: + return "couchbase://localhost" + default: + return "" + } +} + +func resolveSQLiteTestURL(dbType, mappedURL, tempPath string) string { + if dbType == constants.DbTypeSqlite || dbType == constants.DbTypeLibSQL { + return tempPath + } + return mappedURL +} + +func buildStorageTestConfigForMemoryStore(dbType, dbURL string) *config.Config { + cfg := &config.Config{ + Env: constants.TestEnv, + SkipTestEndpointSSRFValidation: true, + DatabaseType: dbType, + DatabaseURL: dbURL, + JWTSecret: "test-secret", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + AllowedOrigins: []string{"http://localhost:3000"}, + JWTType: "HS256", + AdminSecret: "test-admin-secret", + TwilioAPISecret: "test-twilio-api-secret", + TwilioAPIKey: "test-twilio-api-key", + TwilioAccountSID: "test-twilio-account-sid", + TwilioSender: "test-twilio-sender", + DefaultRoles: []string{"user"}, + EnableSignup: true, + EnableBasicAuthentication: true, + EnableMobileBasicAuthentication: true, + EnableLoginPage: true, + EnableStrongPassword: true, + IsSMSServiceEnabled: true, + } + + if dbType == constants.DbTypeMongoDB { + cfg.DatabaseName = "authorizer_test" + } + + if dbType == constants.DbTypeCouchbaseDB { + cfg.DatabaseUsername = "Administrator" + cfg.DatabasePassword = "password" + cfg.CouchBaseBucket = "authorizer_test" + } + + if dbType == constants.DbTypeDynamoDB { + cfg.AWSRegion = "us-east-1" + } + + return cfg +} diff --git a/internal/memory_store/provider_test.go b/internal/memory_store/provider_test.go index 244bb375..bc7c00bf 100644 --- a/internal/memory_store/provider_test.go +++ b/internal/memory_store/provider_test.go @@ -1,6 +1,8 @@ package memory_store import ( + "os" + "strings" "testing" "time" @@ -17,9 +19,13 @@ const ( memoryStoreTypeDB = "db" ) -var memoryStoreTypes = []string{ - memoryStoreTypeRedis, - memoryStoreTypeInMemory, +func memoryStoreTypesForTest() []string { + var types []string + if redisMemoryStoreTestsEnabled() { + types = append(types, memoryStoreTypeRedis) + } + types = append(types, memoryStoreTypeInMemory) + return types } func getTestMemoryStorageConfig(storageType string) *config.Config { @@ -40,9 +46,10 @@ func getTestMemoryStorageConfig(storageType string) *config.Config { return cfg } -// TestMemoryStoreProvider tests the memory store provider +// TestMemoryStoreProvider tests the in-memory provider always; Redis only when TEST_ENABLE_REDIS=1. +// TEST_DBS does not apply (these are not storage-backend tests). func TestMemoryStoreProvider(t *testing.T) { - for _, storeType := range memoryStoreTypes { + for _, storeType := range memoryStoreTypesForTest() { t.Run("should test memory store provider for "+storeType, func(t *testing.T) { cfg := getTestMemoryStorageConfig(storeType) logger := zerolog.Nop() @@ -50,7 +57,7 @@ func TestMemoryStoreProvider(t *testing.T) { Log: &logger, }) if storeType == memoryStoreTypeRedis && err != nil { - t.Skipf("skipping redis memory store test: %v", err) + t.Skipf("skipping redis memory store test (is Redis running on localhost:6380?): %v", err) } require.NoError(t, err) require.NotNil(t, p) @@ -170,3 +177,8 @@ func TestMemoryStoreProvider(t *testing.T) { }) } } + +func redisMemoryStoreTestsEnabled() bool { + v := strings.TrimSpace(os.Getenv("TEST_ENABLE_REDIS")) + return v == "1" || strings.EqualFold(v, "true") +} diff --git a/internal/rate_limit/provider.go b/internal/rate_limit/provider.go index c4fa73a8..a01df503 100644 --- a/internal/rate_limit/provider.go +++ b/internal/rate_limit/provider.go @@ -29,6 +29,10 @@ type Dependencies struct { // New creates a new rate limit provider based on available infrastructure. // Uses Redis when RedisStore is provided, falls back to in-memory. func New(cfg *config.Config, deps *Dependencies) (Provider, error) { + deps.Log.Info(). + Int("rate_limit_rps", cfg.RateLimitRPS). + Int("rate_limit_burst", cfg.RateLimitBurst). + Msg("Creating rate limit provider") if deps.RedisStore != nil { return newRedisProvider(cfg, deps) } diff --git a/internal/rate_limit/redis.go b/internal/rate_limit/redis.go index 8e666f27..f1d27121 100644 --- a/internal/rate_limit/redis.go +++ b/internal/rate_limit/redis.go @@ -42,9 +42,9 @@ func newRedisProvider(cfg *config.Config, deps *Dependencies) (*redisProvider, e // Window = burst / rps, minimum 1 second window := 1 if cfg.RateLimitRPS > 0 { - w := float64(cfg.RateLimitBurst) / cfg.RateLimitRPS + w := int(cfg.RateLimitBurst / cfg.RateLimitRPS) if w > 1 { - window = int(w) + window = w } } return &redisProvider{ diff --git a/internal/storage/db/dynamodb/audit_log.go b/internal/storage/db/dynamodb/audit_log.go index 472d7108..5036aa7c 100644 --- a/internal/storage/db/dynamodb/audit_log.go +++ b/internal/storage/db/dynamodb/audit_log.go @@ -2,10 +2,12 @@ package dynamodb import ( "context" + "sort" "time" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/uuid" - "github.com/guregu/dynamo" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -13,7 +15,6 @@ import ( // AddAuditLog adds an audit log entry func (p *provider) AddAuditLog(ctx context.Context, auditLog *schemas.AuditLog) error { - collection := p.db.Table(schemas.Collections.AuditLog) if auditLog.ID == "" { auditLog.ID = uuid.New().String() } @@ -21,85 +22,138 @@ func (p *provider) AddAuditLog(ctx context.Context, auditLog *schemas.AuditLog) if auditLog.CreatedAt == 0 { auditLog.CreatedAt = time.Now().Unix() } - err := collection.Put(auditLog).RunWithContext(ctx) - if err != nil { - return err - } - return nil + return p.putItem(ctx, schemas.Collections.AuditLog, auditLog) } -// ListAuditLogs queries audit logs with filters and pagination -func (p *provider) ListAuditLogs(ctx context.Context, pagination *model.Pagination, filter map[string]interface{}) ([]*schemas.AuditLog, *model.Pagination, error) { - auditLogs := []*schemas.AuditLog{} - var auditLog *schemas.AuditLog - var lastEval dynamo.PagingKey - var iter dynamo.PagingIter - var iteration int64 = 0 - var err error - - collection := p.db.Table(schemas.Collections.AuditLog) - paginationClone := *pagination - scanner := collection.Scan() +func int64FromFilter(v interface{}) (int64, bool) { + switch x := v.(type) { + case int64: + return x, true + case int: + return int64(x), true + case int32: + return int64(x), true + case float64: + return int64(x), true + default: + return 0, false + } +} - // Apply filters - if action, ok := filter["action"]; ok && action != "" { - scanner = scanner.Filter("'action' = ?", action) +// auditLogExtraFilter builds a filter for attributes not covered by the primary key condition +// (omitKey is "action" or "actor_id" when that attribute is the partition key for the Query). +func auditLogExtraFilter(filter map[string]interface{}, omitKey string) *expression.ConditionBuilder { + var conds []expression.ConditionBuilder + if omitKey != "actor_id" { + if actorID, ok := filter["actor_id"]; ok && actorID != "" { + conds = append(conds, expression.Name("actor_id").Equal(expression.Value(actorID))) + } } - if actorID, ok := filter["actor_id"]; ok && actorID != "" { - scanner = scanner.Filter("'actor_id' = ?", actorID) + if omitKey != "action" { + if action, ok := filter["action"]; ok && action != "" { + conds = append(conds, expression.Name("action").Equal(expression.Value(action))) + } } if resourceType, ok := filter["resource_type"]; ok && resourceType != "" { - scanner = scanner.Filter("'resource_type' = ?", resourceType) + conds = append(conds, expression.Name("resource_type").Equal(expression.Value(resourceType))) } - - for (paginationClone.Offset + paginationClone.Limit) > iteration { - iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() - for iter.NextWithContext(ctx, &auditLog) { - if paginationClone.Offset == iteration { - auditLogs = append(auditLogs, auditLog) - } + if resourceID, ok := filter["resource_id"]; ok && resourceID != "" { + conds = append(conds, expression.Name("resource_id").Equal(expression.Value(resourceID))) + } + if v, ok := filter["from_timestamp"]; ok { + if ts, ok := int64FromFilter(v); ok { + conds = append(conds, expression.Name("created_at").GreaterThanEqual(expression.Value(ts))) } - err = iter.Err() - if err != nil { - return nil, nil, err + } + if v, ok := filter["to_timestamp"]; ok { + if ts, ok := int64FromFilter(v); ok { + conds = append(conds, expression.Name("created_at").LessThanEqual(expression.Value(ts))) } - lastEval = iter.LastEvaluatedKey() - iteration += paginationClone.Limit } + if len(conds) == 0 { + return nil + } + merged := conds[0] + for i := 1; i < len(conds); i++ { + merged = merged.And(conds[i]) + } + return &merged +} + +// ListAuditLogs queries audit logs with filters and pagination +func (p *provider) ListAuditLogs(ctx context.Context, pagination *model.Pagination, filter map[string]interface{}) ([]*schemas.AuditLog, *model.Pagination, error) { + paginationClone := *pagination - // Count total matching documents - var total int64 - countScanner := collection.Scan() - if action, ok := filter["action"]; ok && action != "" { - countScanner = countScanner.Filter("'action' = ?", action) + var actionVal, actorVal string + if a, ok := filter["action"]; ok && a != "" { + if s, ok := a.(string); ok { + actionVal = s + } } - if actorID, ok := filter["actor_id"]; ok && actorID != "" { - countScanner = countScanner.Filter("'actor_id' = ?", actorID) + if a, ok := filter["actor_id"]; ok && a != "" { + if s, ok := a.(string); ok { + actorVal = s + } } - if resourceType, ok := filter["resource_type"]; ok && resourceType != "" { - countScanner = countScanner.Filter("'resource_type' = ?", resourceType) + + var items []map[string]types.AttributeValue + var err error + table := schemas.Collections.AuditLog + + switch { + case actionVal != "": + extra := auditLogExtraFilter(filter, "action") + items, err = p.queryEq(ctx, table, "action", "action", actionVal, extra) + case actorVal != "": + extra := auditLogExtraFilter(filter, "actor_id") + items, err = p.queryEq(ctx, table, "actor_id", "actor_id", actorVal, extra) + default: + extra := auditLogExtraFilter(filter, "") + items, err = p.scanAllRaw(ctx, table, nil, extra) } - var countItems []*schemas.AuditLog - if err = countScanner.AllWithContext(ctx, &countItems); err != nil { + if err != nil { return nil, nil, err } - total = int64(len(countItems)) + + var logs []*schemas.AuditLog + for _, it := range items { + var a schemas.AuditLog + if err := unmarshalItem(it, &a); err != nil { + return nil, nil, err + } + logs = append(logs, &a) + } + + sort.Slice(logs, func(i, j int) bool { return logs[i].CreatedAt > logs[j].CreatedAt }) + + total := int64(len(logs)) paginationClone.Total = total - return auditLogs, &paginationClone, nil + start := int(pagination.Offset) + if start >= len(logs) { + return []*schemas.AuditLog{}, &paginationClone, nil + } + end := start + int(pagination.Limit) + if end > len(logs) { + end = len(logs) + } + + return logs[start:end], &paginationClone, nil } // DeleteAuditLogsBefore removes logs older than a timestamp func (p *provider) DeleteAuditLogsBefore(ctx context.Context, before int64) error { - collection := p.db.Table(schemas.Collections.AuditLog) - var auditLogs []*schemas.AuditLog - err := collection.Scan().Filter("'created_at' < ?", before).AllWithContext(ctx, &auditLogs) + f := expression.Name("created_at").LessThan(expression.Value(before)) + items, err := p.scanFilteredAll(ctx, schemas.Collections.AuditLog, nil, &f) if err != nil { return err } - for _, auditLog := range auditLogs { - err := collection.Delete("id", auditLog.ID).RunWithContext(ctx) - if err != nil { + for _, it := range items { + var a schemas.AuditLog + if err := unmarshalItem(it, &a); err != nil { + return err + } + if err := p.deleteItemByHash(ctx, schemas.Collections.AuditLog, "id", a.ID); err != nil { return err } } diff --git a/internal/storage/db/dynamodb/authenticator.go b/internal/storage/db/dynamodb/authenticator.go index 88e349ad..673a6510 100644 --- a/internal/storage/db/dynamodb/authenticator.go +++ b/internal/storage/db/dynamodb/authenticator.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" "github.com/google/uuid" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -14,44 +15,39 @@ func (p *provider) AddAuthenticator(ctx context.Context, authenticators *schemas if exists != nil { return authenticators, nil } - - collection := p.db.Table(schemas.Collections.Authenticators) if authenticators.ID == "" { authenticators.ID = uuid.New().String() } - authenticators.CreatedAt = time.Now().Unix() authenticators.UpdatedAt = time.Now().Unix() - err := collection.Put(authenticators).RunWithContext(ctx) - if err != nil { + if err := p.putItem(ctx, schemas.Collections.Authenticators, authenticators); err != nil { return nil, err } return authenticators, nil } func (p *provider) UpdateAuthenticator(ctx context.Context, authenticators *schemas.Authenticator) (*schemas.Authenticator, error) { - collection := p.db.Table(schemas.Collections.Authenticators) if authenticators.ID != "" { authenticators.UpdatedAt = time.Now().Unix() - err := UpdateByHashKey(collection, "id", authenticators.ID, authenticators) - if err != nil { + if err := p.updateByHashKey(ctx, schemas.Collections.Authenticators, "id", authenticators.ID, authenticators); err != nil { return nil, err } } return authenticators, nil - } func (p *provider) GetAuthenticatorDetailsByUserId(ctx context.Context, userId string, authenticatorType string) (*schemas.Authenticator, error) { - var authenticators *schemas.Authenticator - collection := p.db.Table(schemas.Collections.Authenticators) - iter := collection.Scan().Filter("'user_id' = ?", userId).Filter("'method' = ?", authenticatorType).Iter() - for iter.NextWithContext(ctx, &authenticators) { - return authenticators, nil - } - err := iter.Err() + f := expression.Name("user_id").Equal(expression.Value(userId)).And(expression.Name("method").Equal(expression.Value(authenticatorType))) + items, err := p.scanFilteredAll(ctx, schemas.Collections.Authenticators, nil, &f) if err != nil { return nil, err } - return authenticators, nil + if len(items) == 0 { + return nil, nil + } + var a schemas.Authenticator + if err := unmarshalItem(items[0], &a); err != nil { + return nil, err + } + return &a, nil } diff --git a/internal/storage/db/dynamodb/email_template.go b/internal/storage/db/dynamodb/email_template.go index 4dfe25e1..be936378 100644 --- a/internal/storage/db/dynamodb/email_template.go +++ b/internal/storage/db/dynamodb/email_template.go @@ -5,8 +5,8 @@ import ( "errors" "time" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/uuid" - "github.com/guregu/dynamo" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -14,29 +14,22 @@ import ( // AddEmailTemplate to add EmailTemplate func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *schemas.EmailTemplate) (*schemas.EmailTemplate, error) { - collection := p.db.Table(schemas.Collections.EmailTemplate) if emailTemplate.ID == "" { emailTemplate.ID = uuid.New().String() } - emailTemplate.Key = emailTemplate.ID emailTemplate.CreatedAt = time.Now().Unix() emailTemplate.UpdatedAt = time.Now().Unix() - err := collection.Put(emailTemplate).RunWithContext(ctx) - - if err != nil { + if err := p.putItem(ctx, schemas.Collections.EmailTemplate, emailTemplate); err != nil { return nil, err } - return emailTemplate, nil } // UpdateEmailTemplate to update EmailTemplate func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *schemas.EmailTemplate) (*schemas.EmailTemplate, error) { - collection := p.db.Table(schemas.Collections.EmailTemplate) emailTemplate.UpdatedAt = time.Now().Unix() - err := UpdateByHashKey(collection, "id", emailTemplate.ID, emailTemplate) - if err != nil { + if err := p.updateByHashKey(ctx, schemas.Collections.EmailTemplate, "id", emailTemplate.ID, emailTemplate); err != nil { return nil, err } return emailTemplate, nil @@ -44,27 +37,35 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *schem // ListEmailTemplates to list EmailTemplate func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) ([]*schemas.EmailTemplate, *model.Pagination, error) { - var emailTemplate *schemas.EmailTemplate - var iter dynamo.PagingIter - var lastEval dynamo.PagingKey - var iteration int64 = 0 - collection := p.db.Table(schemas.Collections.EmailTemplate) - emailTemplates := []*schemas.EmailTemplate{} + var lastKey map[string]types.AttributeValue + var iteration int64 paginationClone := pagination - scanner := collection.Scan() - count, err := scanner.Count() + var emailTemplates []*schemas.EmailTemplate + + count, err := p.scanCount(ctx, schemas.Collections.EmailTemplate, nil) if err != nil { return nil, nil, err } + for (paginationClone.Offset + paginationClone.Limit) > iteration { - iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() - for iter.NextWithContext(ctx, &emailTemplate) { + items, next, err := p.scanPageIter(ctx, schemas.Collections.EmailTemplate, nil, int32(paginationClone.Limit), lastKey) + if err != nil { + return nil, nil, err + } + for _, it := range items { + var e schemas.EmailTemplate + if err := unmarshalItem(it, &e); err != nil { + return nil, nil, err + } if paginationClone.Offset == iteration { - emailTemplates = append(emailTemplates, emailTemplate) + emailTemplates = append(emailTemplates, &e) } } - lastEval = iter.LastEvaluatedKey() + lastKey = next iteration += paginationClone.Limit + if lastKey == nil { + break + } } paginationClone.Total = count return emailTemplates, paginationClone, nil @@ -72,39 +73,34 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagi // GetEmailTemplateByID to get EmailTemplate by id func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*schemas.EmailTemplate, error) { - collection := p.db.Table(schemas.Collections.EmailTemplate) - var emailTemplate *schemas.EmailTemplate - err := collection.Get("id", emailTemplateID).OneWithContext(ctx, &emailTemplate) - if err != nil { + var e schemas.EmailTemplate + if err := p.getItemByHash(ctx, schemas.Collections.EmailTemplate, "id", emailTemplateID, &e); err != nil { return nil, err } - return emailTemplate, nil + return &e, nil } // GetEmailTemplateByEventName to get EmailTemplate by event_name func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*schemas.EmailTemplate, error) { - collection := p.db.Table(schemas.Collections.EmailTemplate) - var emailTemplates []*schemas.EmailTemplate - var emailTemplate *schemas.EmailTemplate - err := collection.Scan().Index("event_name").Filter("'event_name' = ?", eventName).Limit(1).AllWithContext(ctx, &emailTemplates) + // Query the event_name GSI — Scan+Limit applies before filters and can return zero matching items. + items, err := p.queryEq(ctx, schemas.Collections.EmailTemplate, "event_name", "event_name", eventName, nil) if err != nil { return nil, err } - if len(emailTemplates) == 0 { + if len(items) == 0 { return nil, errors.New("no record found") - } - emailTemplate = emailTemplates[0] - return emailTemplate, nil + var e schemas.EmailTemplate + if err := unmarshalItem(items[0], &e); err != nil { + return nil, err + } + return &e, nil } // DeleteEmailTemplate to delete EmailTemplate func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *schemas.EmailTemplate) error { - collection := p.db.Table(schemas.Collections.EmailTemplate) - err := collection.Delete("id", emailTemplate.ID).RunWithContext(ctx) - if err != nil { - return err + if emailTemplate == nil { + return nil } - - return nil + return p.deleteItemByHash(ctx, schemas.Collections.EmailTemplate, "id", emailTemplate.ID) } diff --git a/internal/storage/db/dynamodb/env.go b/internal/storage/db/dynamodb/env.go index 7d0318c1..320dcc82 100644 --- a/internal/storage/db/dynamodb/env.go +++ b/internal/storage/db/dynamodb/env.go @@ -13,15 +13,13 @@ import ( // AddEnv to save environment information in database func (p *provider) AddEnv(ctx context.Context, env *schemas.Env) (*schemas.Env, error) { - collection := p.db.Table(schemas.Collections.Env) if env.ID == "" { env.ID = uuid.New().String() } env.Key = env.ID env.CreatedAt = time.Now().Unix() env.UpdatedAt = time.Now().Unix() - err := collection.Put(env).RunWithContext(ctx) - if err != nil { + if err := p.putItem(ctx, schemas.Collections.Env, env); err != nil { return nil, err } return env, nil @@ -29,10 +27,8 @@ func (p *provider) AddEnv(ctx context.Context, env *schemas.Env) (*schemas.Env, // UpdateEnv to update environment information in database func (p *provider) UpdateEnv(ctx context.Context, env *schemas.Env) (*schemas.Env, error) { - collection := p.db.Table(schemas.Collections.Env) env.UpdatedAt = time.Now().Unix() - err := UpdateByHashKey(collection, "id", env.ID, env) - if err != nil { + if err := p.updateByHashKey(ctx, schemas.Collections.Env, "id", env.ID, env); err != nil { return nil, err } return env, nil @@ -40,20 +36,16 @@ func (p *provider) UpdateEnv(ctx context.Context, env *schemas.Env) (*schemas.En // GetEnv to get environment information from database func (p *provider) GetEnv(ctx context.Context) (*schemas.Env, error) { - var env *schemas.Env - collection := p.db.Table(schemas.Collections.Env) - // As there is no Find one supported. - iter := collection.Scan().Limit(1).Iter() - for iter.NextWithContext(ctx, &env) { - if env == nil { - return nil, errors.New("no documets found") - } else { - return env, nil - } - } - err := iter.Err() + items, err := p.scanFilteredLimit(ctx, schemas.Collections.Env, nil, nil, 1) if err != nil { - return env, fmt.Errorf("config not found") + return nil, err } - return env, nil + if len(items) == 0 { + return nil, errors.New("no documets found") + } + var env schemas.Env + if err := unmarshalItem(items[0], &env); err != nil { + return nil, fmt.Errorf("config not found") + } + return &env, nil } diff --git a/internal/storage/db/dynamodb/health_check.go b/internal/storage/db/dynamodb/health_check.go index 17fbb1a6..56173e5b 100644 --- a/internal/storage/db/dynamodb/health_check.go +++ b/internal/storage/db/dynamodb/health_check.go @@ -9,5 +9,17 @@ import ( // HealthCheck verifies that the DynamoDB backend is reachable and responsive func (p *provider) HealthCheck(ctx context.Context) error { var envs []schemas.Env - return p.db.Table(schemas.Collections.Env).Scan().Limit(1).AllWithContext(ctx, &envs) + items, err := p.scanFilteredLimit(ctx, schemas.Collections.Env, nil, nil, 1) + if err != nil { + return err + } + for _, it := range items { + var e schemas.Env + if err := unmarshalItem(it, &e); err != nil { + return err + } + envs = append(envs, e) + } + _ = envs + return nil } diff --git a/internal/storage/db/dynamodb/marshal.go b/internal/storage/db/dynamodb/marshal.go new file mode 100644 index 00000000..944525a5 --- /dev/null +++ b/internal/storage/db/dynamodb/marshal.go @@ -0,0 +1,141 @@ +package dynamodb + +import ( + "fmt" + "reflect" + "strings" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func dynamoFieldNameFromTag(tag string) string { + if tag == "" || tag == "-" { + return "" + } + parts := strings.Split(tag, ",") + name := strings.TrimSpace(parts[0]) + if name == "" { + return "" + } + return name +} + +func dynamoOmitEmpty(tag string) bool { + return strings.Contains(tag, "omitempty") +} + +func marshalStruct(v interface{}) (map[string]types.AttributeValue, error) { + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + return nil, fmt.Errorf("nil value") + } + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + return nil, fmt.Errorf("expected struct or pointer to struct") + } + rt := rv.Type() + out := make(map[string]types.AttributeValue) + for i := 0; i < rv.NumField(); i++ { + sf := rt.Field(i) + if !sf.IsExported() { + continue + } + tag := sf.Tag.Get("dynamo") + name := dynamoFieldNameFromTag(tag) + if name == "" { + continue + } + fv := rv.Field(i) + if dynamoOmitEmpty(tag) && isEmptyValue(fv) { + continue + } + if fv.Kind() == reflect.Ptr && fv.IsNil() { + continue + } + av, err := attributevalue.Marshal(fv.Interface()) + if err != nil { + return nil, err + } + out[name] = av + } + return out, nil +} + +func isEmptyValue(v reflect.Value) bool { + if !v.IsValid() { + return true + } + switch v.Kind() { + case reflect.String: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Ptr, reflect.Interface: + return v.IsNil() + default: + return false + } +} + +func marshalMapStringInterface(m map[string]interface{}) (map[string]types.AttributeValue, error) { + out := make(map[string]types.AttributeValue, len(m)) + for k, v := range m { + av, err := attributevalue.Marshal(v) + if err != nil { + return nil, err + } + out[k] = av + } + return out, nil +} + +func unmarshalItem(av map[string]types.AttributeValue, out interface{}) error { + rv := reflect.ValueOf(out) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("out must be non-nil pointer") + } + ev := rv.Elem() + if ev.Kind() != reflect.Struct { + return fmt.Errorf("out must be pointer to struct") + } + et := ev.Type() + for i := 0; i < ev.NumField(); i++ { + sf := et.Field(i) + if !sf.IsExported() { + continue + } + tag := sf.Tag.Get("dynamo") + name := dynamoFieldNameFromTag(tag) + if name == "" { + continue + } + dv, ok := av[name] + if !ok { + continue + } + field := ev.Field(i) + if !field.CanSet() { + continue + } + if field.Kind() == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + if err := attributevalue.Unmarshal(dv, field.Interface()); err != nil { + return err + } + continue + } + if err := attributevalue.Unmarshal(dv, field.Addr().Interface()); err != nil { + return err + } + } + return nil +} diff --git a/internal/storage/db/dynamodb/ops.go b/internal/storage/db/dynamodb/ops.go new file mode 100644 index 00000000..53f30c1e --- /dev/null +++ b/internal/storage/db/dynamodb/ops.go @@ -0,0 +1,316 @@ +package dynamodb + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func (p *provider) putItem(ctx context.Context, table string, v interface{}) error { + item, err := marshalStruct(v) + if err != nil { + return err + } + _, err = p.client.PutItem(ctx, &dynamodb.PutItemInput{ + TableName: aws.String(table), + Item: item, + }) + return err +} + +func (p *provider) getItemByHash(ctx context.Context, table, hashKey, hashValue string, out interface{}) error { + res, err := p.client.GetItem(ctx, &dynamodb.GetItemInput{ + TableName: aws.String(table), + Key: map[string]types.AttributeValue{ + hashKey: &types.AttributeValueMemberS{Value: hashValue}, + }, + }) + if err != nil { + return err + } + if len(res.Item) == 0 { + return fmt.Errorf("no record found") + } + return unmarshalItem(res.Item, out) +} + +func (p *provider) deleteItemByHash(ctx context.Context, table, hashKey, hashValue string) error { + _, err := p.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{ + TableName: aws.String(table), + Key: map[string]types.AttributeValue{ + hashKey: &types.AttributeValueMemberS{Value: hashValue}, + }, + }) + return err +} + +func (p *provider) scanAllRaw(ctx context.Context, table string, index *string, filter *expression.ConditionBuilder) ([]map[string]types.AttributeValue, error) { + var built expression.Expression + var hasFilter bool + if filter != nil { + e, err := expression.NewBuilder().WithFilter(*filter).Build() + if err != nil { + return nil, err + } + built = e + hasFilter = true + } + var out []map[string]types.AttributeValue + var start map[string]types.AttributeValue + for { + in := &dynamodb.ScanInput{ + TableName: aws.String(table), + ExclusiveStartKey: start, + } + if index != nil { + in.IndexName = index + } + if hasFilter { + in.FilterExpression = built.Filter() + in.ExpressionAttributeNames = built.Names() + in.ExpressionAttributeValues = built.Values() + } + res, err := p.client.Scan(ctx, in) + if err != nil { + return nil, err + } + out = append(out, res.Items...) + if res.LastEvaluatedKey == nil { + break + } + start = res.LastEvaluatedKey + } + return out, nil +} + +func (p *provider) scanCount(ctx context.Context, table string, filter *expression.ConditionBuilder) (int64, error) { + var built expression.Expression + var hasFilter bool + if filter != nil { + e, err := expression.NewBuilder().WithFilter(*filter).Build() + if err != nil { + return 0, err + } + built = e + hasFilter = true + } + var total int64 + var start map[string]types.AttributeValue + for { + in := &dynamodb.ScanInput{ + TableName: aws.String(table), + Select: types.SelectCount, + ExclusiveStartKey: start, + } + if hasFilter { + in.FilterExpression = built.Filter() + in.ExpressionAttributeNames = built.Names() + in.ExpressionAttributeValues = built.Values() + } + res, err := p.client.Scan(ctx, in) + if err != nil { + return 0, err + } + total += int64(res.Count) + if res.LastEvaluatedKey == nil { + break + } + start = res.LastEvaluatedKey + } + return total, nil +} + +func (p *provider) queryEq(ctx context.Context, table, indexName, pkAttr, pkVal string, filter *expression.ConditionBuilder) ([]map[string]types.AttributeValue, error) { + kc := expression.Key(pkAttr).Equal(expression.Value(pkVal)) + var eb expression.Builder + if filter != nil { + eb = expression.NewBuilder().WithKeyCondition(kc).WithFilter(*filter) + } else { + eb = expression.NewBuilder().WithKeyCondition(kc) + } + expr, err := eb.Build() + if err != nil { + return nil, err + } + var out []map[string]types.AttributeValue + var start map[string]types.AttributeValue + for { + in := &dynamodb.QueryInput{ + TableName: aws.String(table), + IndexName: aws.String(indexName), + KeyConditionExpression: expr.KeyCondition(), + ExpressionAttributeNames: expr.Names(), + ExpressionAttributeValues: expr.Values(), + ExclusiveStartKey: start, + } + if filter != nil { + in.FilterExpression = expr.Filter() + } + res, err := p.client.Query(ctx, in) + if err != nil { + return nil, err + } + out = append(out, res.Items...) + if res.LastEvaluatedKey == nil { + break + } + start = res.LastEvaluatedKey + } + return out, nil +} + +func (p *provider) queryEqLimit(ctx context.Context, table, indexName, pkAttr, pkVal string, filter *expression.ConditionBuilder, limit int32) ([]map[string]types.AttributeValue, error) { + kc := expression.Key(pkAttr).Equal(expression.Value(pkVal)) + var eb expression.Builder + if filter != nil { + eb = expression.NewBuilder().WithKeyCondition(kc).WithFilter(*filter) + } else { + eb = expression.NewBuilder().WithKeyCondition(kc) + } + expr, err := eb.Build() + if err != nil { + return nil, err + } + in := &dynamodb.QueryInput{ + TableName: aws.String(table), + IndexName: aws.String(indexName), + KeyConditionExpression: expr.KeyCondition(), + ExpressionAttributeNames: expr.Names(), + ExpressionAttributeValues: expr.Values(), + Limit: aws.Int32(limit), + } + if filter != nil { + in.FilterExpression = expr.Filter() + } + res, err := p.client.Query(ctx, in) + if err != nil { + return nil, err + } + return res.Items, nil +} + +func (p *provider) scanFilteredLimit(ctx context.Context, table string, index *string, filter *expression.ConditionBuilder, limit int32) ([]map[string]types.AttributeValue, error) { + in := &dynamodb.ScanInput{ + TableName: aws.String(table), + Limit: aws.Int32(limit), + } + if index != nil { + in.IndexName = index + } + if filter != nil { + expr, err := expression.NewBuilder().WithFilter(*filter).Build() + if err != nil { + return nil, err + } + in.FilterExpression = expr.Filter() + in.ExpressionAttributeNames = expr.Names() + in.ExpressionAttributeValues = expr.Values() + } + res, err := p.client.Scan(ctx, in) + if err != nil { + return nil, err + } + return res.Items, nil +} + +func (p *provider) scanFilteredAll(ctx context.Context, table string, index *string, filter *expression.ConditionBuilder) ([]map[string]types.AttributeValue, error) { + return p.scanAllRaw(ctx, table, index, filter) +} + +func (p *provider) updateByHashKey(ctx context.Context, tableName, hashKeyName, hashValue string, item interface{}) error { + return p.updateByHashKeyWithRemoves(ctx, tableName, hashKeyName, hashValue, item, nil) +} + +// updateByHashKeyWithRemoves runs UpdateItem with SET from marshalled fields and optional REMOVE +// of attribute names (e.g. when mapping SQL NULL for optional pointer fields — nil is omitted from +// SET but the old DynamoDB attribute must be explicitly removed). +func (p *provider) updateByHashKeyWithRemoves(ctx context.Context, tableName, hashKeyName, hashValue string, item interface{}, removeAttrs []string) error { + var attrs map[string]types.AttributeValue + var err error + switch m := item.(type) { + case map[string]interface{}: + attrs, err = marshalMapStringInterface(m) + default: + attrs, err = marshalStruct(item) + } + if err != nil { + return err + } + delete(attrs, hashKeyName) + + names := map[string]string{} + vals := map[string]types.AttributeValue{} + var sets []string + i := 0 + for k, v := range attrs { + nk := "#n" + fmt.Sprint(i) + vk := ":v" + fmt.Sprint(i) + names[nk] = k + vals[vk] = v + sets = append(sets, nk+" = "+vk) + i++ + } + + var removeParts []string + for j, attr := range removeAttrs { + rk := "#r" + fmt.Sprint(j) + names[rk] = attr + removeParts = append(removeParts, rk) + } + + if len(sets) == 0 && len(removeParts) == 0 { + return nil + } + + var exprParts []string + if len(sets) > 0 { + exprParts = append(exprParts, "SET "+strings.Join(sets, ", ")) + } + if len(removeParts) > 0 { + exprParts = append(exprParts, "REMOVE "+strings.Join(removeParts, ", ")) + } + updateExpr := strings.Join(exprParts, " ") + + in := &dynamodb.UpdateItemInput{ + TableName: aws.String(tableName), + Key: map[string]types.AttributeValue{ + hashKeyName: &types.AttributeValueMemberS{Value: hashValue}, + }, + UpdateExpression: aws.String(updateExpr), + ExpressionAttributeNames: names, + } + if len(vals) > 0 { + in.ExpressionAttributeValues = vals + } + _, err = p.client.UpdateItem(ctx, in) + return err +} + +func (p *provider) scanPageIter(ctx context.Context, table string, filter *expression.ConditionBuilder, pageLimit int32, startKey map[string]types.AttributeValue) ([]map[string]types.AttributeValue, map[string]types.AttributeValue, error) { + in := &dynamodb.ScanInput{ + TableName: aws.String(table), + Limit: aws.Int32(pageLimit), + ExclusiveStartKey: startKey, + } + if filter != nil { + expr, err := expression.NewBuilder().WithFilter(*filter).Build() + if err != nil { + return nil, nil, err + } + in.FilterExpression = expr.Filter() + in.ExpressionAttributeNames = expr.Names() + in.ExpressionAttributeValues = expr.Values() + } + res, err := p.client.Scan(ctx, in) + if err != nil { + return nil, nil, err + } + return res.Items, res.LastEvaluatedKey, nil +} + +func strPtr(s string) *string { return &s } diff --git a/internal/storage/db/dynamodb/otp.go b/internal/storage/db/dynamodb/otp.go index 5e82eaf0..e305fead 100644 --- a/internal/storage/db/dynamodb/otp.go +++ b/internal/storage/db/dynamodb/otp.go @@ -5,6 +5,7 @@ import ( "errors" "time" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" "github.com/google/uuid" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -12,7 +13,6 @@ import ( // UpsertOTP to add or update otp func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schemas.OTP, error) { - // check if email or phone number is present if otpParam.Email == "" && otpParam.PhoneNumber == "" { return nil, errors.New("email or phone_number is required") } @@ -43,13 +43,12 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schem otp.Otp = otpParam.Otp otp.ExpiresAt = otpParam.ExpiresAt } - collection := p.db.Table(schemas.Collections.OTP) otp.UpdatedAt = time.Now().Unix() var err error if shouldCreate { - err = collection.Put(otp).RunWithContext(ctx) + err = p.putItem(ctx, schemas.Collections.OTP, otp) } else { - err = UpdateByHashKey(collection, "id", otp.ID, otp) + err = p.updateByHashKey(ctx, schemas.Collections.OTP, "id", otp.ID, otp) } if err != nil { return nil, err @@ -59,44 +58,41 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schem // GetOTPByEmail to get otp for a given email address func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*schemas.OTP, error) { - var otps []schemas.OTP - var otp schemas.OTP - collection := p.db.Table(schemas.Collections.OTP) - err := collection.Scan().Index("email").Filter("'email' = ?", emailAddress).Limit(1).AllWithContext(ctx, &otps) + items, err := p.queryEqLimit(ctx, schemas.Collections.OTP, "email", "email", emailAddress, nil, 1) if err != nil { return nil, err } - if len(otps) > 0 { - otp = otps[0] - return &otp, nil + if len(items) == 0 { + return nil, errors.New("no docuemnt found") } - return nil, errors.New("no docuemnt found") + var otp schemas.OTP + if err := unmarshalItem(items[0], &otp); err != nil { + return nil, err + } + return &otp, nil } // GetOTPByPhoneNumber to get otp for a given phone number func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*schemas.OTP, error) { - var otps []schemas.OTP - var otp schemas.OTP - collection := p.db.Table(schemas.Collections.OTP) - err := collection.Scan().Filter("'phone_number' = ?", phoneNumber).Limit(1).AllWithContext(ctx, &otps) + f := expression.Name("phone_number").Equal(expression.Value(phoneNumber)) + items, err := p.scanAllRaw(ctx, schemas.Collections.OTP, nil, &f) if err != nil { return nil, err } - if len(otps) > 0 { - otp = otps[0] - return &otp, nil + if len(items) == 0 { + return nil, errors.New("no docuemnt found") } - return nil, errors.New("no docuemnt found") + var otp schemas.OTP + if err := unmarshalItem(items[0], &otp); err != nil { + return nil, err + } + return &otp, nil } // DeleteOTP to delete otp func (p *provider) DeleteOTP(ctx context.Context, otp *schemas.OTP) error { - collection := p.db.Table(schemas.Collections.OTP) - if otp.ID != "" { - err := collection.Delete("id", otp.ID).RunWithContext(ctx) - if err != nil { - return err - } + if otp == nil || otp.ID == "" { + return nil } - return nil + return p.deleteItemByHash(ctx, schemas.Collections.OTP, "id", otp.ID) } diff --git a/internal/storage/db/dynamodb/provider.go b/internal/storage/db/dynamodb/provider.go index 7a305309..798f0a7b 100644 --- a/internal/storage/db/dynamodb/provider.go +++ b/internal/storage/db/dynamodb/provider.go @@ -5,14 +5,13 @@ import ( "fmt" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/guregu/dynamo" + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/rs/zerolog" "github.com/authorizerdev/authorizer/internal/config" - "github.com/authorizerdev/authorizer/internal/storage/schemas" ) // Dependencies struct the dynamodb data store provider @@ -23,75 +22,75 @@ type Dependencies struct { type provider struct { config *config.Config dependencies *Dependencies - db *dynamo.DB + client *dynamodb.Client } -// NewProvider returns a new Dynamo provider +// NewProvider returns a new Dynamo provider using AWS SDK for Go v2. func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) { dbURL := cfg.DatabaseURL awsRegion := cfg.AWSRegion awsAccessKeyID := cfg.AWSAccessKeyID awsSecretAccessKey := cfg.AWSSecretAccessKey - awsCfg := aws.Config{ - MaxRetries: aws.Int(3), - CredentialsChainVerboseErrors: aws.Bool(true), // for full error logs + region := awsRegion + if region == "" { + region = "us-east-1" } - if awsRegion != "" { - awsCfg.Region = aws.String(awsRegion) + loadOpts := []func(*awsconfig.LoadOptions) error{ + awsconfig.WithRegion(region), } - // custom awsAccessKeyID, awsSecretAccessKey took first priority, if not then fetch config from aws credentials + if awsAccessKeyID != "" && awsSecretAccessKey != "" { - awsCfg.Credentials = credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, "") + loadOpts = append(loadOpts, awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(awsAccessKeyID, awsSecretAccessKey, ""))) } else if dbURL != "" { deps.Log.Info().Msg("Using DB URL for dynamodb") - // static config in case of testing or local-setup - awsCfg.Credentials = credentials.NewStaticCredentials("key", "key", "") - awsCfg.Endpoint = aws.String(dbURL) + loadOpts = append(loadOpts, awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider("key", "key", ""))) } else { deps.Log.Info().Msg("Using default AWS credentials config from system for dynamodb") } - sess, err := session.NewSession(&awsCfg) + + if dbURL != "" { + resolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if service == dynamodb.ServiceID { + return aws.Endpoint{ + URL: dbURL, + HostnameImmutable: true, + }, nil + } + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) + loadOpts = append(loadOpts, awsconfig.WithEndpointResolverWithOptions(resolver)) + } + + awsCfg, err := awsconfig.LoadDefaultConfig(context.Background(), loadOpts...) if err != nil { - return nil, fmt.Errorf("dynamodb session: %w", err) + return nil, fmt.Errorf("aws config: %w", err) + } + + client := dynamodb.NewFromConfig(awsCfg, func(o *dynamodb.Options) { + o.RetryMaxAttempts = 3 + }) + + p := &provider{ + client: client, + config: cfg, + dependencies: deps, } - db := dynamo.New(sess) createCtx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() - tables := []struct { - name string - model interface{} - }{ - {schemas.Collections.User, schemas.User{}}, - {schemas.Collections.Session, schemas.Session{}}, - {schemas.Collections.EmailTemplate, schemas.EmailTemplate{}}, - {schemas.Collections.Env, schemas.Env{}}, - {schemas.Collections.OTP, schemas.OTP{}}, - {schemas.Collections.VerificationRequest, schemas.VerificationRequest{}}, - {schemas.Collections.Webhook, schemas.Webhook{}}, - {schemas.Collections.WebhookLog, schemas.WebhookLog{}}, - {schemas.Collections.Authenticators, schemas.Authenticator{}}, - {schemas.Collections.SessionToken, schemas.SessionToken{}}, - {schemas.Collections.MFASession, schemas.MFASession{}}, - {schemas.Collections.OAuthState, schemas.OAuthState{}}, - {schemas.Collections.AuditLog, schemas.AuditLog{}}, - } - for _, tbl := range tables { - if werr := db.CreateTable(tbl.name, tbl.model).WaitWithContext(createCtx); werr != nil { - return nil, fmt.Errorf("dynamodb create/wait table %q: %w", tbl.name, werr) - } + if err := p.ensureTables(createCtx); err != nil { + return nil, err } - return &provider{ - db: db, - config: cfg, - dependencies: deps, - }, nil + + return p, nil } -// Close is a no-op; the AWS SDK session needs no explicit shutdown for typical use. +// Close is a no-op; the AWS SDK v2 client needs no explicit shutdown for typical use. func (p *provider) Close() error { return nil } diff --git a/internal/storage/db/dynamodb/session.go b/internal/storage/db/dynamodb/session.go index 3a06ae40..4d008747 100644 --- a/internal/storage/db/dynamodb/session.go +++ b/internal/storage/db/dynamodb/session.go @@ -11,14 +11,12 @@ import ( // AddSession to save session information in database func (p *provider) AddSession(ctx context.Context, session *schemas.Session) error { - collection := p.db.Table(schemas.Collections.Session) if session.ID == "" { session.ID = uuid.New().String() } session.CreatedAt = time.Now().Unix() session.UpdatedAt = time.Now().Unix() - err := collection.Put(session).RunWithContext(ctx) - return err + return p.putItem(ctx, schemas.Collections.Session, session) } // DeleteSession to delete session information from database diff --git a/internal/storage/db/dynamodb/session_token.go b/internal/storage/db/dynamodb/session_token.go index f6c2be3f..b004a74f 100644 --- a/internal/storage/db/dynamodb/session_token.go +++ b/internal/storage/db/dynamodb/session_token.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" "github.com/google/uuid" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -23,47 +24,44 @@ func (p *provider) AddSessionToken(ctx context.Context, token *schemas.SessionTo if token.UpdatedAt == 0 { token.UpdatedAt = time.Now().Unix() } - collection := p.db.Table(schemas.Collections.SessionToken) - return collection.Put(token).RunWithContext(ctx) + return p.putItem(ctx, schemas.Collections.SessionToken, token) } // GetSessionTokenByUserIDAndKey retrieves a session token by user ID and key func (p *provider) GetSessionTokenByUserIDAndKey(ctx context.Context, userId, key string) (*schemas.SessionToken, error) { - var tokens []schemas.SessionToken - collection := p.db.Table(schemas.Collections.SessionToken) - err := collection.Scan(). - Index("user_id"). - Filter("'user_id' = ? AND 'key_name' = ?", userId, key). - Limit(1). - AllWithContext(ctx, &tokens) + f := expression.Name("key_name").Equal(expression.Value(key)) + items, err := p.queryEqLimit(ctx, schemas.Collections.SessionToken, "user_id", "user_id", userId, &f, 1) if err != nil { return nil, err } - if len(tokens) == 0 { + if len(items) == 0 { return nil, errors.New("session token not found") } - return &tokens[0], nil + var t schemas.SessionToken + if err := unmarshalItem(items[0], &t); err != nil { + return nil, err + } + return &t, nil } // DeleteSessionToken deletes a session token by ID func (p *provider) DeleteSessionToken(ctx context.Context, id string) error { - collection := p.db.Table(schemas.Collections.SessionToken) - return collection.Delete("id", id).RunWithContext(ctx) + return p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", id) } // DeleteSessionTokenByUserIDAndKey deletes a session token by user ID and key func (p *provider) DeleteSessionTokenByUserIDAndKey(ctx context.Context, userId, key string) error { - var tokens []schemas.SessionToken - collection := p.db.Table(schemas.Collections.SessionToken) - err := collection.Scan(). - Index("user_id"). - Filter("'user_id' = ? AND 'key_name' = ?", userId, key). - AllWithContext(ctx, &tokens) + f := expression.Name("key_name").Equal(expression.Value(key)) + items, err := p.queryEq(ctx, schemas.Collections.SessionToken, "user_id", "user_id", userId, &f) if err != nil { return err } - for _, token := range tokens { - if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil { + for _, it := range items { + var t schemas.SessionToken + if err := unmarshalItem(it, &t); err != nil { + return err + } + if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil { return err } } @@ -72,15 +70,17 @@ func (p *provider) DeleteSessionTokenByUserIDAndKey(ctx context.Context, userId, // DeleteAllSessionTokensByUserID deletes all session tokens for a user ID func (p *provider) DeleteAllSessionTokensByUserID(ctx context.Context, userId string) error { - var tokens []schemas.SessionToken - collection := p.db.Table(schemas.Collections.SessionToken) - err := collection.Scan().AllWithContext(ctx, &tokens) + items, err := p.scanAllRaw(ctx, schemas.Collections.SessionToken, nil, nil) if err != nil { return err } - for _, token := range tokens { - if strings.Contains(token.UserID, userId) { - if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil { + for _, it := range items { + var t schemas.SessionToken + if err := unmarshalItem(it, &t); err != nil { + return err + } + if strings.Contains(t.UserID, userId) { + if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil { return err } } @@ -91,15 +91,17 @@ func (p *provider) DeleteAllSessionTokensByUserID(ctx context.Context, userId st // DeleteSessionTokensByNamespace deletes all session tokens for a namespace func (p *provider) DeleteSessionTokensByNamespace(ctx context.Context, namespace string) error { prefix := namespace + ":" - var tokens []schemas.SessionToken - collection := p.db.Table(schemas.Collections.SessionToken) - err := collection.Scan().AllWithContext(ctx, &tokens) + items, err := p.scanAllRaw(ctx, schemas.Collections.SessionToken, nil, nil) if err != nil { return err } - for _, token := range tokens { - if strings.HasPrefix(token.UserID, prefix) { - if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil { + for _, it := range items { + var t schemas.SessionToken + if err := unmarshalItem(it, &t); err != nil { + return err + } + if strings.HasPrefix(t.UserID, prefix) { + if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil { return err } } @@ -110,14 +112,17 @@ func (p *provider) DeleteSessionTokensByNamespace(ctx context.Context, namespace // CleanExpiredSessionTokens removes expired session tokens from the database func (p *provider) CleanExpiredSessionTokens(ctx context.Context) error { currentTime := time.Now().Unix() - var tokens []schemas.SessionToken - collection := p.db.Table(schemas.Collections.SessionToken) - err := collection.Scan().Filter("'expires_at' < ?", currentTime).AllWithContext(ctx, &tokens) + f := expression.Name("expires_at").LessThan(expression.Value(currentTime)) + items, err := p.scanFilteredAll(ctx, schemas.Collections.SessionToken, nil, &f) if err != nil { return err } - for _, token := range tokens { - if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil { + for _, it := range items { + var t schemas.SessionToken + if err := unmarshalItem(it, &t); err != nil { + return err + } + if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil { return err } } @@ -126,15 +131,17 @@ func (p *provider) CleanExpiredSessionTokens(ctx context.Context) error { // GetAllSessionTokens retrieves all session tokens (for testing) func (p *provider) GetAllSessionTokens(ctx context.Context) ([]*schemas.SessionToken, error) { - var tokens []schemas.SessionToken - collection := p.db.Table(schemas.Collections.SessionToken) - err := collection.Scan().AllWithContext(ctx, &tokens) + items, err := p.scanAllRaw(ctx, schemas.Collections.SessionToken, nil, nil) if err != nil { return nil, err } var result []*schemas.SessionToken - for i := range tokens { - result = append(result, &tokens[i]) + for _, it := range items { + var t schemas.SessionToken + if err := unmarshalItem(it, &t); err != nil { + return nil, err + } + result = append(result, &t) } return result, nil } @@ -151,47 +158,44 @@ func (p *provider) AddMFASession(ctx context.Context, session *schemas.MFASessio if session.UpdatedAt == 0 { session.UpdatedAt = time.Now().Unix() } - collection := p.db.Table(schemas.Collections.MFASession) - return collection.Put(session).RunWithContext(ctx) + return p.putItem(ctx, schemas.Collections.MFASession, session) } // GetMFASessionByUserIDAndKey retrieves an MFA session by user ID and key func (p *provider) GetMFASessionByUserIDAndKey(ctx context.Context, userId, key string) (*schemas.MFASession, error) { - var sessions []schemas.MFASession - collection := p.db.Table(schemas.Collections.MFASession) - err := collection.Scan(). - Index("user_id"). - Filter("'user_id' = ? AND 'key_name' = ?", userId, key). - Limit(1). - AllWithContext(ctx, &sessions) + f := expression.Name("key_name").Equal(expression.Value(key)) + items, err := p.queryEqLimit(ctx, schemas.Collections.MFASession, "user_id", "user_id", userId, &f, 1) if err != nil { return nil, err } - if len(sessions) == 0 { + if len(items) == 0 { return nil, errors.New("MFA session not found") } - return &sessions[0], nil + var s schemas.MFASession + if err := unmarshalItem(items[0], &s); err != nil { + return nil, err + } + return &s, nil } // DeleteMFASession deletes an MFA session by ID func (p *provider) DeleteMFASession(ctx context.Context, id string) error { - collection := p.db.Table(schemas.Collections.MFASession) - return collection.Delete("id", id).RunWithContext(ctx) + return p.deleteItemByHash(ctx, schemas.Collections.MFASession, "id", id) } // DeleteMFASessionByUserIDAndKey deletes an MFA session by user ID and key func (p *provider) DeleteMFASessionByUserIDAndKey(ctx context.Context, userId, key string) error { - var sessions []schemas.MFASession - collection := p.db.Table(schemas.Collections.MFASession) - err := collection.Scan(). - Index("user_id"). - Filter("'user_id' = ? AND 'key_name' = ?", userId, key). - AllWithContext(ctx, &sessions) + f := expression.Name("key_name").Equal(expression.Value(key)) + items, err := p.queryEq(ctx, schemas.Collections.MFASession, "user_id", "user_id", userId, &f) if err != nil { return err } - for _, session := range sessions { - if err := collection.Delete("id", session.ID).RunWithContext(ctx); err != nil { + for _, it := range items { + var s schemas.MFASession + if err := unmarshalItem(it, &s); err != nil { + return err + } + if err := p.deleteItemByHash(ctx, schemas.Collections.MFASession, "id", s.ID); err != nil { return err } } @@ -200,18 +204,17 @@ func (p *provider) DeleteMFASessionByUserIDAndKey(ctx context.Context, userId, k // GetAllMFASessionsByUserID retrieves all MFA sessions for a user ID func (p *provider) GetAllMFASessionsByUserID(ctx context.Context, userId string) ([]*schemas.MFASession, error) { - var sessions []schemas.MFASession - collection := p.db.Table(schemas.Collections.MFASession) - err := collection.Scan(). - Index("user_id"). - Filter("'user_id' = ?", userId). - AllWithContext(ctx, &sessions) + items, err := p.queryEq(ctx, schemas.Collections.MFASession, "user_id", "user_id", userId, nil) if err != nil { return nil, err } var result []*schemas.MFASession - for i := range sessions { - result = append(result, &sessions[i]) + for _, it := range items { + var s schemas.MFASession + if err := unmarshalItem(it, &s); err != nil { + return nil, err + } + result = append(result, &s) } return result, nil } @@ -219,14 +222,17 @@ func (p *provider) GetAllMFASessionsByUserID(ctx context.Context, userId string) // CleanExpiredMFASessions removes expired MFA sessions from the database func (p *provider) CleanExpiredMFASessions(ctx context.Context) error { currentTime := time.Now().Unix() - var sessions []schemas.MFASession - collection := p.db.Table(schemas.Collections.MFASession) - err := collection.Scan().Filter("'expires_at' < ?", currentTime).AllWithContext(ctx, &sessions) + f := expression.Name("expires_at").LessThan(expression.Value(currentTime)) + items, err := p.scanFilteredAll(ctx, schemas.Collections.MFASession, nil, &f) if err != nil { return err } - for _, session := range sessions { - if err := collection.Delete("id", session.ID).RunWithContext(ctx); err != nil { + for _, it := range items { + var s schemas.MFASession + if err := unmarshalItem(it, &s); err != nil { + return err + } + if err := p.deleteItemByHash(ctx, schemas.Collections.MFASession, "id", s.ID); err != nil { return err } } @@ -235,15 +241,17 @@ func (p *provider) CleanExpiredMFASessions(ctx context.Context) error { // GetAllMFASessions retrieves all MFA sessions (for testing) func (p *provider) GetAllMFASessions(ctx context.Context) ([]*schemas.MFASession, error) { - var sessions []schemas.MFASession - collection := p.db.Table(schemas.Collections.MFASession) - err := collection.Scan().AllWithContext(ctx, &sessions) + items, err := p.scanAllRaw(ctx, schemas.Collections.MFASession, nil, nil) if err != nil { return nil, err } var result []*schemas.MFASession - for i := range sessions { - result = append(result, &sessions[i]) + for _, it := range items { + var s schemas.MFASession + if err := unmarshalItem(it, &s); err != nil { + return nil, err + } + result = append(result, &s) } return result, nil } @@ -260,50 +268,45 @@ func (p *provider) AddOAuthState(ctx context.Context, state *schemas.OAuthState) if state.UpdatedAt == 0 { state.UpdatedAt = time.Now().Unix() } - // Delete existing state with same state_key first (upsert behavior) - var existing []schemas.OAuthState - collection := p.db.Table(schemas.Collections.OAuthState) - collection.Scan(). - Index("state_key"). - Filter("'state_key' = ?", state.StateKey). - AllWithContext(ctx, &existing) - for _, e := range existing { - collection.Delete("id", e.ID).RunWithContext(ctx) - } - return collection.Put(state).RunWithContext(ctx) + existing, _ := p.queryEq(ctx, schemas.Collections.OAuthState, "state_key", "state_key", state.StateKey, nil) + for _, it := range existing { + var e schemas.OAuthState + if err := unmarshalItem(it, &e); err != nil { + continue + } + _ = p.deleteItemByHash(ctx, schemas.Collections.OAuthState, "id", e.ID) + } + return p.putItem(ctx, schemas.Collections.OAuthState, state) } // GetOAuthStateByKey retrieves an OAuth state by key func (p *provider) GetOAuthStateByKey(ctx context.Context, key string) (*schemas.OAuthState, error) { - var states []schemas.OAuthState - collection := p.db.Table(schemas.Collections.OAuthState) - err := collection.Scan(). - Index("state_key"). - Filter("'state_key' = ?", key). - Limit(1). - AllWithContext(ctx, &states) + items, err := p.queryEqLimit(ctx, schemas.Collections.OAuthState, "state_key", "state_key", key, nil, 1) if err != nil { return nil, err } - if len(states) == 0 { + if len(items) == 0 { return nil, errors.New("OAuth state not found") } - return &states[0], nil + var s schemas.OAuthState + if err := unmarshalItem(items[0], &s); err != nil { + return nil, err + } + return &s, nil } // DeleteOAuthStateByKey deletes an OAuth state by key func (p *provider) DeleteOAuthStateByKey(ctx context.Context, key string) error { - var states []schemas.OAuthState - collection := p.db.Table(schemas.Collections.OAuthState) - err := collection.Scan(). - Index("state_key"). - Filter("'state_key' = ?", key). - AllWithContext(ctx, &states) + items, err := p.queryEq(ctx, schemas.Collections.OAuthState, "state_key", "state_key", key, nil) if err != nil { return err } - for _, state := range states { - if err := collection.Delete("id", state.ID).RunWithContext(ctx); err != nil { + for _, it := range items { + var s schemas.OAuthState + if err := unmarshalItem(it, &s); err != nil { + return err + } + if err := p.deleteItemByHash(ctx, schemas.Collections.OAuthState, "id", s.ID); err != nil { return err } } @@ -312,15 +315,17 @@ func (p *provider) DeleteOAuthStateByKey(ctx context.Context, key string) error // GetAllOAuthStates retrieves all OAuth states (for testing) func (p *provider) GetAllOAuthStates(ctx context.Context) ([]*schemas.OAuthState, error) { - var states []schemas.OAuthState - collection := p.db.Table(schemas.Collections.OAuthState) - err := collection.Scan().AllWithContext(ctx, &states) + items, err := p.scanAllRaw(ctx, schemas.Collections.OAuthState, nil, nil) if err != nil { return nil, err } var result []*schemas.OAuthState - for i := range states { - result = append(result, &states[i]) + for _, it := range items { + var s schemas.OAuthState + if err := unmarshalItem(it, &s); err != nil { + return nil, err + } + result = append(result, &s) } return result, nil } diff --git a/internal/storage/db/dynamodb/shared.go b/internal/storage/db/dynamodb/shared.go deleted file mode 100644 index 5597c0ad..00000000 --- a/internal/storage/db/dynamodb/shared.go +++ /dev/null @@ -1,40 +0,0 @@ -package dynamodb - -import ( - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/guregu/dynamo" -) - -// As updpate all item not supported so set manually via Set and SetNullable for empty field -func UpdateByHashKey(table dynamo.Table, hashKey string, hashValue string, item interface{}) error { - existingValue, err := dynamo.MarshalItem(item) - var i interface{} - if err != nil { - return err - } - nullableValue, err := dynamodbattribute.MarshalMap(item) - if err != nil { - return err - } - u := table.Update(hashKey, hashValue) - for k, v := range existingValue { - if k == hashKey { - continue - } - u = u.Set(k, v) - } - for k, v := range nullableValue { - if k == hashKey { - continue - } - dynamodbattribute.Unmarshal(v, &i) - if i == nil { - u = u.SetNullable(k, v) - } - } - err = u.Run() - if err != nil { - return err - } - return nil -} diff --git a/internal/storage/db/dynamodb/tables.go b/internal/storage/db/dynamodb/tables.go new file mode 100644 index 00000000..81b8b705 --- /dev/null +++ b/internal/storage/db/dynamodb/tables.go @@ -0,0 +1,184 @@ +package dynamodb + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/authorizerdev/authorizer/internal/storage/schemas" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/aws/smithy-go" +) + +func gsi(name, hashAttr string) types.GlobalSecondaryIndex { + return types.GlobalSecondaryIndex{ + IndexName: aws.String(name), + KeySchema: []types.KeySchemaElement{ + {AttributeName: aws.String(hashAttr), KeyType: types.KeyTypeHash}, + }, + Projection: &types.Projection{ProjectionType: types.ProjectionTypeAll}, + } +} + +func createTable(ctx context.Context, client *dynamodb.Client, name string, hashAttr string, attrs []types.AttributeDefinition, gsis []types.GlobalSecondaryIndex) error { + in := &dynamodb.CreateTableInput{ + TableName: aws.String(name), + BillingMode: types.BillingModePayPerRequest, + AttributeDefinitions: attrs, + KeySchema: []types.KeySchemaElement{ + {AttributeName: aws.String(hashAttr), KeyType: types.KeyTypeHash}, + }, + } + if len(gsis) > 0 { + in.GlobalSecondaryIndexes = gsis + } + _, err := client.CreateTable(ctx, in) + if err != nil { + var apiErr smithy.APIError + if errors.As(err, &apiErr) && apiErr.ErrorCode() == "ResourceInUseException" { + return nil + } + return err + } + w := dynamodb.NewTableExistsWaiter(client) + return w.Wait(ctx, &dynamodb.DescribeTableInput{TableName: aws.String(name)}, 5*time.Minute) +} + +func (p *provider) ensureTables(ctx context.Context) error { + tables := []struct { + name string + hash string + attr []types.AttributeDefinition + gsi []types.GlobalSecondaryIndex + }{ + { + name: schemas.Collections.User, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("email"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("email", "email")}, + }, + { + name: schemas.Collections.Session, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")}, + }, + { + name: schemas.Collections.Env, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + }, + }, + { + name: schemas.Collections.Webhook, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("event_name"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("event_name", "event_name")}, + }, + { + name: schemas.Collections.WebhookLog, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("webhook_id"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("webhook_id", "webhook_id")}, + }, + { + name: schemas.Collections.EmailTemplate, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("event_name"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("event_name", "event_name")}, + }, + { + name: schemas.Collections.OTP, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("email"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("email", "email")}, + }, + { + name: schemas.Collections.VerificationRequest, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("token"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("token", "token")}, + }, + { + name: schemas.Collections.Authenticators, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")}, + }, + { + name: schemas.Collections.SessionToken, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")}, + }, + { + name: schemas.Collections.MFASession, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")}, + }, + { + name: schemas.Collections.OAuthState, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("state_key"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{gsi("state_key", "state_key")}, + }, + { + name: schemas.Collections.AuditLog, + hash: "id", + attr: []types.AttributeDefinition{ + {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("actor_id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: aws.String("action"), AttributeType: types.ScalarAttributeTypeS}, + }, + gsi: []types.GlobalSecondaryIndex{ + gsi("actor_id", "actor_id"), + gsi("action", "action"), + }, + }, + } + + for _, t := range tables { + if err := createTable(ctx, p.client, t.name, t.hash, t.attr, t.gsi); err != nil { + return fmt.Errorf("create table %s: %w", t.name, err) + } + } + return nil +} diff --git a/internal/storage/db/dynamodb/user.go b/internal/storage/db/dynamodb/user.go index da517e4b..6f2bad70 100644 --- a/internal/storage/db/dynamodb/user.go +++ b/internal/storage/db/dynamodb/user.go @@ -7,17 +7,34 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/uuid" - "github.com/guregu/dynamo" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/refs" "github.com/authorizerdev/authorizer/internal/storage/schemas" ) +// normalizeUserOptionalPtrs clears *int64 fields that round-trip as 0 from DynamoDB where other +// providers use SQL NULL — tests and handlers treat nil as "unset" for these fields. +func normalizeUserOptionalPtrs(u *schemas.User) { + if u == nil { + return + } + if u.EmailVerifiedAt != nil && *u.EmailVerifiedAt == 0 { + u.EmailVerifiedAt = nil + } + if u.PhoneNumberVerifiedAt != nil && *u.PhoneNumberVerifiedAt == 0 { + u.PhoneNumberVerifiedAt = nil + } + if u.RevokedTimestamp != nil && *u.RevokedTimestamp == 0 { + u.RevokedTimestamp = nil + } +} + // AddUser to save user information in database func (p *provider) AddUser(ctx context.Context, user *schemas.User) (*schemas.User, error) { - collection := p.db.Table(schemas.Collections.User) if user.ID == "" { user.ID = uuid.New().String() } @@ -35,20 +52,37 @@ func (p *provider) AddUser(ctx context.Context, user *schemas.User) (*schemas.Us } user.CreatedAt = time.Now().Unix() user.UpdatedAt = time.Now().Unix() - err := collection.Put(user).RunWithContext(ctx) - if err != nil { + if err := p.putItem(ctx, schemas.Collections.User, user); err != nil { return nil, err } return user, nil } +// userDynamoRemoveAttrsIfNil lists attribute names to REMOVE so that optional nil-pointer fields +// match SQL NULL semantics (omitting them from SET in DynamoDB would otherwise leave old values). +func userDynamoRemoveAttrsIfNil(u *schemas.User) []string { + if u == nil { + return nil + } + var remove []string + if u.EmailVerifiedAt == nil { + remove = append(remove, "email_verified_at") + } + if u.PhoneNumberVerifiedAt == nil { + remove = append(remove, "phone_number_verified_at") + } + if u.RevokedTimestamp == nil { + remove = append(remove, "revoked_timestamp") + } + return remove +} + // UpdateUser to update user information in database func (p *provider) UpdateUser(ctx context.Context, user *schemas.User) (*schemas.User, error) { - collection := p.db.Table(schemas.Collections.User) if user.ID != "" { user.UpdatedAt = time.Now().Unix() - err := UpdateByHashKey(collection, "id", user.ID, user) - if err != nil { + remove := userDynamoRemoveAttrsIfNil(user) + if err := p.updateByHashKeyWithRemoves(ctx, schemas.Collections.User, "id", user.ID, user, remove); err != nil { return nil, err } } @@ -57,15 +91,22 @@ func (p *provider) UpdateUser(ctx context.Context, user *schemas.User) (*schemas // DeleteUser to delete user information from database func (p *provider) DeleteUser(ctx context.Context, user *schemas.User) error { - collection := p.db.Table(schemas.Collections.User) - sessionCollection := p.db.Table(schemas.Collections.Session) - if user.ID != "" { - err := collection.Delete("id", user.ID).Run() - if err != nil { + if user.ID == "" { + return nil + } + if err := p.deleteItemByHash(ctx, schemas.Collections.User, "id", user.ID); err != nil { + return err + } + items, err := p.queryEq(ctx, schemas.Collections.Session, "user_id", "user_id", user.ID, nil) + if err != nil { + return err + } + for _, it := range items { + var s schemas.Session + if err := unmarshalItem(it, &s); err != nil { return err } - _, err = sessionCollection.Batch("id").Write().Delete(dynamo.Keys{"user_id", user.ID}).RunWithContext(ctx) - if err != nil { + if err := p.deleteItemByHash(ctx, schemas.Collections.Session, "id", s.ID); err != nil { return err } } @@ -74,31 +115,36 @@ func (p *provider) DeleteUser(ctx context.Context, user *schemas.User) error { // ListUsers to get list of users from database func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) ([]*schemas.User, *model.Pagination, error) { - var user *schemas.User - var lastEval dynamo.PagingKey - var iter dynamo.PagingIter - var iteration int64 = 0 - collection := p.db.Table(schemas.Collections.User) - var users []*schemas.User + var lastKey map[string]types.AttributeValue + var iteration int64 paginationClone := pagination - scanner := collection.Scan() - count, err := scanner.Count() + var users []*schemas.User + + count, err := p.scanCount(ctx, schemas.Collections.User, nil) if err != nil { return nil, nil, err } + for (paginationClone.Offset + paginationClone.Limit) > iteration { - iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() - for iter.NextWithContext(ctx, &user) { + items, next, err := p.scanPageIter(ctx, schemas.Collections.User, nil, int32(paginationClone.Limit), lastKey) + if err != nil { + return nil, nil, err + } + for _, it := range items { + var u schemas.User + if err := unmarshalItem(it, &u); err != nil { + return nil, nil, err + } + normalizeUserOptionalPtrs(&u) if paginationClone.Offset == iteration { - users = append(users, user) + users = append(users, &u) } } - lastEval = iter.LastEvaluatedKey() + lastKey = next iteration += paginationClone.Limit - } - err = iter.Err() - if err != nil { - return nil, nil, err + if lastKey == nil { + break + } } paginationClone.Total = count return users, paginationClone, nil @@ -106,79 +152,77 @@ func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) // GetUserByEmail to get user information from database using email address func (p *provider) GetUserByEmail(ctx context.Context, email string) (*schemas.User, error) { - var users []*schemas.User - var user *schemas.User - collection := p.db.Table(schemas.Collections.User) - err := collection.Scan().Index("email").Filter("'email' = ?", email).AllWithContext(ctx, &users) + items, err := p.queryEq(ctx, schemas.Collections.User, "email", "email", email, nil) if err != nil { - return user, nil + return nil, err } - if len(users) > 0 { - user = users[0] - return user, nil - } else { + if len(items) == 0 { return nil, errors.New("no record found") } + var u schemas.User + if err := unmarshalItem(items[0], &u); err != nil { + return nil, err + } + normalizeUserOptionalPtrs(&u) + return &u, nil } // GetUserByID to get user information from database using user ID func (p *provider) GetUserByID(ctx context.Context, id string) (*schemas.User, error) { - collection := p.db.Table(schemas.Collections.User) - var user *schemas.User - err := collection.Get("id", id).OneWithContext(ctx, &user) + var user schemas.User + err := p.getItemByHash(ctx, schemas.Collections.User, "id", id, &user) if err != nil { - if refs.StringValue(user.Email) == "" { - return nil, errors.New("no documets found") - } else { - return user, nil - } + return nil, errors.New("no documets found") } - return user, nil + normalizeUserOptionalPtrs(&user) + return &user, nil } // UpdateUsers to update multiple users, with parameters of user IDs slice -// If ids set to nil / empty all the users will be updated func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error { - // set updated_at time for all users - userCollection := p.db.Table(schemas.Collections.User) - var allUsers []schemas.User - var res int64 = 0 + var res int64 var err error if len(ids) > 0 { for _, v := range ids { - err = UpdateByHashKey(userCollection, "id", v, data) + err = p.updateByHashKey(ctx, schemas.Collections.User, "id", v, data) } } else { - // as there is no facility to update all doc - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/SQLtoNoSQL.UpdateData.html - userCollection.Scan().All(&allUsers) - for _, user := range allUsers { - err = UpdateByHashKey(userCollection, "id", user.ID, data) + items, errScan := p.scanAllRaw(ctx, schemas.Collections.User, nil, nil) + if errScan != nil { + return errScan + } + for _, it := range items { + var user schemas.User + if err := unmarshalItem(it, &user); err != nil { + return err + } + err = p.updateByHashKey(ctx, schemas.Collections.User, "id", user.ID, data) if err == nil { - res = res + 1 + res++ } } } if err != nil { return err - } else { - p.dependencies.Log.Info().Int64("modified_count", res).Msg("users updated") } + p.dependencies.Log.Info().Int64("modified_count", res).Msg("users updated") return nil } // GetUserByPhoneNumber to get user information from database using phone number func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*schemas.User, error) { - var users []*schemas.User - var user *schemas.User - collection := p.db.Table(schemas.Collections.User) - err := collection.Scan().Filter("'phone_number' = ?", phoneNumber).AllWithContext(ctx, &users) + f := expression.Name("phone_number").Equal(expression.Value(phoneNumber)) + items, err := p.scanFilteredAll(ctx, schemas.Collections.User, nil, &f) if err != nil { return nil, err } - if len(users) > 0 { - user = users[0] - return user, nil - } else { + if len(items) == 0 { return nil, errors.New("no record found") } + var u schemas.User + if err := unmarshalItem(items[0], &u); err != nil { + return nil, err + } + normalizeUserOptionalPtrs(&u) + return &u, nil } diff --git a/internal/storage/db/dynamodb/verification_requests.go b/internal/storage/db/dynamodb/verification_requests.go index c44dd52e..0d959190 100644 --- a/internal/storage/db/dynamodb/verification_requests.go +++ b/internal/storage/db/dynamodb/verification_requests.go @@ -4,8 +4,9 @@ import ( "context" "time" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/uuid" - "github.com/guregu/dynamo" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -13,13 +14,11 @@ import ( // AddVerification to save verification request in database func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *schemas.VerificationRequest) (*schemas.VerificationRequest, error) { - collection := p.db.Table(schemas.Collections.VerificationRequest) if verificationRequest.ID == "" { verificationRequest.ID = uuid.New().String() verificationRequest.CreatedAt = time.Now().Unix() verificationRequest.UpdatedAt = time.Now().Unix() - err := collection.Put(verificationRequest).RunWithContext(ctx) - if err != nil { + if err := p.putItem(ctx, schemas.Collections.VerificationRequest, verificationRequest); err != nil { return nil, err } } @@ -28,61 +27,68 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque // GetVerificationRequestByToken to get verification request from database using token func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*schemas.VerificationRequest, error) { - collection := p.db.Table(schemas.Collections.VerificationRequest) - var verificationRequest *schemas.VerificationRequest - iter := collection.Scan().Filter("'token' = ?", token).Iter() - for iter.NextWithContext(ctx, &verificationRequest) { - return verificationRequest, nil - } - err := iter.Err() + items, err := p.queryEq(ctx, schemas.Collections.VerificationRequest, "token", "token", token, nil) if err != nil { return nil, err } - return verificationRequest, nil + if len(items) == 0 { + return nil, nil + } + var v schemas.VerificationRequest + if err := unmarshalItem(items[0], &v); err != nil { + return nil, err + } + return &v, nil } // GetVerificationRequestByEmail to get verification request by email from database func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*schemas.VerificationRequest, error) { - var verificationRequest *schemas.VerificationRequest - collection := p.db.Table(schemas.Collections.VerificationRequest) - iter := collection.Scan().Filter("'email' = ?", email).Filter("'identifier' = ?", identifier).Iter() - for iter.NextWithContext(ctx, &verificationRequest) { - return verificationRequest, nil - } - err := iter.Err() + f := expression.Name("email").Equal(expression.Value(email)).And(expression.Name("identifier").Equal(expression.Value(identifier))) + items, err := p.scanFilteredAll(ctx, schemas.Collections.VerificationRequest, nil, &f) if err != nil { return nil, err } - return verificationRequest, nil + if len(items) == 0 { + return nil, nil + } + var v schemas.VerificationRequest + if err := unmarshalItem(items[0], &v); err != nil { + return nil, err + } + return &v, nil } // ListVerificationRequests to get list of verification requests from database func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) ([]*schemas.VerificationRequest, *model.Pagination, error) { - var verificationRequests []*schemas.VerificationRequest - var verificationRequest *schemas.VerificationRequest - var lastEval dynamo.PagingKey - var iter dynamo.PagingIter - var iteration int64 = 0 - collection := p.db.Table(schemas.Collections.VerificationRequest) + var lastKey map[string]types.AttributeValue + var iteration int64 paginationClone := pagination - scanner := collection.Scan() - count, err := scanner.Count() + var verificationRequests []*schemas.VerificationRequest + + count, err := p.scanCount(ctx, schemas.Collections.VerificationRequest, nil) if err != nil { return nil, nil, err } + for (paginationClone.Offset + paginationClone.Limit) > iteration { - iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() - for iter.NextWithContext(ctx, &verificationRequest) { - if paginationClone.Offset == iteration { - verificationRequests = append(verificationRequests, verificationRequest) - } - } - err = iter.Err() + items, next, err := p.scanPageIter(ctx, schemas.Collections.VerificationRequest, nil, int32(paginationClone.Limit), lastKey) if err != nil { return nil, nil, err } - lastEval = iter.LastEvaluatedKey() + for _, it := range items { + var v schemas.VerificationRequest + if err := unmarshalItem(it, &v); err != nil { + return nil, nil, err + } + if paginationClone.Offset == iteration { + verificationRequests = append(verificationRequests, &v) + } + } + lastKey = next iteration += paginationClone.Limit + if lastKey == nil { + break + } } paginationClone.Total = count return verificationRequests, paginationClone, nil @@ -90,13 +96,8 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination *mod // DeleteVerificationRequest to delete verification request from database func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *schemas.VerificationRequest) error { - collection := p.db.Table(schemas.Collections.VerificationRequest) - if verificationRequest != nil { - err := collection.Delete("id", verificationRequest.ID).RunWithContext(ctx) - - if err != nil { - return err - } + if verificationRequest == nil { + return nil } - return nil + return p.deleteItemByHash(ctx, schemas.Collections.VerificationRequest, "id", verificationRequest.ID) } diff --git a/internal/storage/db/dynamodb/webhook.go b/internal/storage/db/dynamodb/webhook.go index e85ef964..53b95715 100644 --- a/internal/storage/db/dynamodb/webhook.go +++ b/internal/storage/db/dynamodb/webhook.go @@ -7,8 +7,9 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/uuid" - "github.com/guregu/dynamo" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -16,17 +17,14 @@ import ( // AddWebhook to add webhook func (p *provider) AddWebhook(ctx context.Context, webhook *schemas.Webhook) (*schemas.Webhook, error) { - collection := p.db.Table(schemas.Collections.Webhook) if webhook.ID == "" { webhook.ID = uuid.New().String() } webhook.Key = webhook.ID webhook.CreatedAt = time.Now().Unix() webhook.UpdatedAt = time.Now().Unix() - // Add timestamp to make event name unique for legacy version webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) - err := collection.Put(webhook).RunWithContext(ctx) - if err != nil { + if err := p.putItem(ctx, schemas.Collections.Webhook, webhook); err != nil { return nil, err } return webhook, nil @@ -35,13 +33,10 @@ func (p *provider) AddWebhook(ctx context.Context, webhook *schemas.Webhook) (*s // UpdateWebhook to update webhook func (p *provider) UpdateWebhook(ctx context.Context, webhook *schemas.Webhook) (*schemas.Webhook, error) { webhook.UpdatedAt = time.Now().Unix() - // Event is changed if !strings.Contains(webhook.EventName, "-") { webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix()) } - collection := p.db.Table(schemas.Collections.Webhook) - err := UpdateByHashKey(collection, "id", webhook.ID, webhook) - if err != nil { + if err := p.updateByHashKey(ctx, schemas.Collections.Webhook, "id", webhook.ID, webhook); err != nil { return nil, err } return webhook, nil @@ -49,31 +44,35 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook *schemas.Webhook) // ListWebhooks to list webhook func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) ([]*schemas.Webhook, *model.Pagination, error) { - webhooks := []*schemas.Webhook{} - var webhook *schemas.Webhook - var lastEval dynamo.PagingKey - var iter dynamo.PagingIter - var iteration int64 = 0 - collection := p.db.Table(schemas.Collections.Webhook) + var lastKey map[string]types.AttributeValue + var iteration int64 paginationClone := pagination - scanner := collection.Scan() - count, err := scanner.Count() + var webhooks []*schemas.Webhook + + count, err := p.scanCount(ctx, schemas.Collections.Webhook, nil) if err != nil { return nil, nil, err } + for (paginationClone.Offset + paginationClone.Limit) > iteration { - iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() - for iter.NextWithContext(ctx, &webhook) { - if paginationClone.Offset == iteration { - webhooks = append(webhooks, webhook) - } - } - err = iter.Err() + items, next, err := p.scanPageIter(ctx, schemas.Collections.Webhook, nil, int32(paginationClone.Limit), lastKey) if err != nil { return nil, nil, err } - lastEval = iter.LastEvaluatedKey() + for _, it := range items { + var w schemas.Webhook + if err := unmarshalItem(it, &w); err != nil { + return nil, nil, err + } + if paginationClone.Offset == iteration { + webhooks = append(webhooks, &w) + } + } + lastKey = next iteration += paginationClone.Limit + if lastKey == nil { + break + } } paginationClone.Total = count return webhooks, paginationClone, nil @@ -81,51 +80,52 @@ func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination // GetWebhookByID to get webhook by id func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*schemas.Webhook, error) { - collection := p.db.Table(schemas.Collections.Webhook) - var webhook *schemas.Webhook - err := collection.Get("id", webhookID).OneWithContext(ctx, &webhook) + var webhook schemas.Webhook + err := p.getItemByHash(ctx, schemas.Collections.Webhook, "id", webhookID, &webhook) if err != nil { return nil, err } if webhook.ID == "" { return nil, errors.New("no document found") } - return webhook, nil + return &webhook, nil } // GetWebhookByEventName to get webhook by event_name func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*schemas.Webhook, error) { - webhooks := []*schemas.Webhook{} - collection := p.db.Table(schemas.Collections.Webhook) - err := collection.Scan().Index("event_name").Filter("contains(event_name, ?)", eventName).AllWithContext(ctx, &webhooks) + // Match SQL LIKE 'eventName%' (see sql/webhook.go); do not use Contains (substring match). + f := expression.Name("event_name").BeginsWith(eventName) + items, err := p.scanFilteredAll(ctx, schemas.Collections.Webhook, strPtr("event_name"), &f) if err != nil { return nil, err } - return webhooks, nil + var out []*schemas.Webhook + for _, it := range items { + var w schemas.Webhook + if err := unmarshalItem(it, &w); err != nil { + return nil, err + } + out = append(out, &w) + } + return out, nil } // DeleteWebhook to delete webhook func (p *provider) DeleteWebhook(ctx context.Context, webhook *schemas.Webhook) error { - // Also delete webhook logs for given webhook id - if webhook != nil { - webhookCollection := p.db.Table(schemas.Collections.Webhook) - webhookLogCollection := p.db.Table(schemas.Collections.WebhookLog) - err := webhookCollection.Delete("id", webhook.ID).RunWithContext(ctx) - if err != nil { - return err - } - pagination := &model.Pagination{} - webhookLogs, _, err := p.ListWebhookLogs(ctx, pagination, webhook.ID) - if err != nil { - p.dependencies.Log.Debug().Err(err).Msg("failed to list webhook logs") - } else { - for _, webhookLog := range webhookLogs { - err = webhookLogCollection.Delete("id", webhookLog.ID).RunWithContext(ctx) - if err != nil { - p.dependencies.Log.Debug().Err(err).Msg("failed to delete webhook log") - // continue - } - } + if webhook == nil { + return nil + } + if err := p.deleteItemByHash(ctx, schemas.Collections.Webhook, "id", webhook.ID); err != nil { + return err + } + logs, _, err := p.ListWebhookLogs(ctx, &model.Pagination{}, webhook.ID) + if err != nil { + p.dependencies.Log.Debug().Err(err).Msg("failed to list webhook logs") + return nil + } + for _, wl := range logs { + if err := p.deleteItemByHash(ctx, schemas.Collections.WebhookLog, "id", wl.ID); err != nil { + p.dependencies.Log.Debug().Err(err).Msg("failed to delete webhook log") } } return nil diff --git a/internal/storage/db/dynamodb/webhook_log.go b/internal/storage/db/dynamodb/webhook_log.go index 2092b1d1..9b794829 100644 --- a/internal/storage/db/dynamodb/webhook_log.go +++ b/internal/storage/db/dynamodb/webhook_log.go @@ -4,8 +4,8 @@ import ( "context" "time" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/google/uuid" - "github.com/guregu/dynamo" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/storage/schemas" @@ -13,15 +13,13 @@ import ( // AddWebhookLog to add webhook log func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *schemas.WebhookLog) (*schemas.WebhookLog, error) { - collection := p.db.Table(schemas.Collections.WebhookLog) if webhookLog.ID == "" { webhookLog.ID = uuid.New().String() } webhookLog.Key = webhookLog.ID webhookLog.CreatedAt = time.Now().Unix() webhookLog.UpdatedAt = time.Now().Unix() - err := collection.Put(webhookLog).RunWithContext(ctx) - if err != nil { + if err := p.putItem(ctx, schemas.Collections.WebhookLog, webhookLog); err != nil { return nil, err } return webhookLog, nil @@ -29,43 +27,48 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *schemas.Webhoo // ListWebhookLogs to list webhook logs func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) ([]*schemas.WebhookLog, *model.Pagination, error) { + paginationClone := pagination + // Non-nil empty slice: callers/tests expect a slice value even when there are no rows. webhookLogs := []*schemas.WebhookLog{} - var webhookLog *schemas.WebhookLog - var lastEval dynamo.PagingKey - var iter dynamo.PagingIter - var iteration int64 = 0 - var err error - var count int64 - collection := p.db.Table(schemas.Collections.WebhookLog) - paginationClone := pagination - scanner := collection.Scan() if webhookID != "" { - iter = scanner.Index("webhook_id").Filter("'webhook_id' = ?", webhookID).Iter() - for iter.NextWithContext(ctx, &webhookLog) { - webhookLogs = append(webhookLogs, webhookLog) - } - err = iter.Err() + items, err := p.queryEq(ctx, schemas.Collections.WebhookLog, "webhook_id", "webhook_id", webhookID, nil) if err != nil { return nil, nil, err } - } else { - for (paginationClone.Offset + paginationClone.Limit) > iteration { - iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter() - for iter.NextWithContext(ctx, &webhookLog) { - if paginationClone.Offset == iteration { - webhookLogs = append(webhookLogs, webhookLog) - } + for _, it := range items { + var wl schemas.WebhookLog + if err := unmarshalItem(it, &wl); err != nil { + return nil, nil, err } - err = iter.Err() - if err != nil { + webhookLogs = append(webhookLogs, &wl) + } + paginationClone.Total = 0 + return webhookLogs, paginationClone, nil + } + + var lastKey map[string]types.AttributeValue + var iteration int64 + for (paginationClone.Offset + paginationClone.Limit) > iteration { + items, next, err := p.scanPageIter(ctx, schemas.Collections.WebhookLog, nil, int32(paginationClone.Limit), lastKey) + if err != nil { + return nil, nil, err + } + for _, it := range items { + var wl schemas.WebhookLog + if err := unmarshalItem(it, &wl); err != nil { return nil, nil, err } - lastEval = iter.LastEvaluatedKey() - iteration += paginationClone.Limit + if paginationClone.Offset == iteration { + webhookLogs = append(webhookLogs, &wl) + } + } + lastKey = next + iteration += paginationClone.Limit + if lastKey == nil { + break } } - paginationClone.Total = count - // paginationClone.Cursor = iter.LastEvaluatedKey() + paginationClone.Total = 0 return webhookLogs, paginationClone, nil } diff --git a/internal/storage/provider_test.go b/internal/storage/provider_test.go index cab6f6e6..721618e2 100644 --- a/internal/storage/provider_test.go +++ b/internal/storage/provider_test.go @@ -74,7 +74,8 @@ func getTestDBConfig(dbType string) *config.Config { // Allow extra time for Couchbase container to become ready in tests cfg.CouchBaseWaitTimeout = 120 case constants.DbTypeDynamoDB: - cfg.DatabaseURL = "http://0.0.0.0:8000" + // Must be a client-routable host (not bind address 0.0.0.0); matches integration_tests getDBURL. + cfg.DatabaseURL = "http://127.0.0.1:8000" } return cfg diff --git a/internal/storage/schemas/audit_log.go b/internal/storage/schemas/audit_log.go index 116f49f0..bd004c10 100644 --- a/internal/storage/schemas/audit_log.go +++ b/internal/storage/schemas/audit_log.go @@ -13,10 +13,10 @@ import ( type AuditLog struct { Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` - ActorID string `gorm:"type:char(36)" json:"actor_id" bson:"actor_id" cql:"actor_id" dynamo:"actor_id" index:"actor_id,hash"` + ActorID string `gorm:"type:char(36)" json:"actor_id" bson:"actor_id" cql:"actor_id" dynamo:"actor_id,omitempty" index:"actor_id,hash"` ActorType string `gorm:"type:varchar(30)" json:"actor_type" bson:"actor_type" cql:"actor_type" dynamo:"actor_type"` ActorEmail string `gorm:"type:varchar(256)" json:"actor_email" bson:"actor_email" cql:"actor_email" dynamo:"actor_email"` - Action string `gorm:"type:varchar(100)" json:"action" bson:"action" cql:"action" dynamo:"action" index:"action,hash"` + Action string `gorm:"type:varchar(100)" json:"action" bson:"action" cql:"action" dynamo:"action,omitempty" index:"action,hash"` ResourceType string `gorm:"type:varchar(50)" json:"resource_type" bson:"resource_type" cql:"resource_type" dynamo:"resource_type"` ResourceID string `gorm:"type:char(36)" json:"resource_id" bson:"resource_id" cql:"resource_id" dynamo:"resource_id"` IPAddress string `gorm:"type:varchar(45)" json:"ip_address" bson:"ip_address" cql:"ip_address" dynamo:"ip_address"` diff --git a/internal/storage/schemas/otp.go b/internal/storage/schemas/otp.go index 4fb94ead..1985f285 100644 --- a/internal/storage/schemas/otp.go +++ b/internal/storage/schemas/otp.go @@ -11,7 +11,7 @@ const ( type OTP struct { Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"` - Email string `gorm:"index" json:"email" bson:"email" cql:"email" dynamo:"email" index:"email,hash"` + Email string `gorm:"index" json:"email" bson:"email" cql:"email" dynamo:"email,omitempty" index:"email,hash"` PhoneNumber string `gorm:"index" json:"phone_number" bson:"phone_number" cql:"phone_number" dynamo:"phone_number"` Otp string `json:"otp" bson:"otp" cql:"otp" dynamo:"otp"` ExpiresAt int64 `json:"expires_at" bson:"expires_at" cql:"expires_at" dynamo:"expires_at"` diff --git a/internal/storage/schemas/user.go b/internal/storage/schemas/user.go index 8c024458..674c500d 100644 --- a/internal/storage/schemas/user.go +++ b/internal/storage/schemas/user.go @@ -9,6 +9,9 @@ import ( ) // Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation +// +// Nullable pointers (*int64, *string, etc.): do not add json/bson omitempty to fields that must +// clear stored values when nil — see docs/storage-optional-null-fields.md. // User model for db type User struct { From 347d762e3f851bd54fda83cffcd76d7e138a034f Mon Sep 17 00:00:00 2001 From: Lakhan Samani Date: Mon, 6 Apr 2026 19:39:14 +0530 Subject: [PATCH 3/3] chore: fix tests --- .claude/settings.local.json | 5 +- CLAUDE.md | 15 +- Makefile | 12 +- internal/integration_tests/audit_logs_test.go | 243 -------- .../custom_access_token_script_test.go | 474 ++++++++------- internal/integration_tests/health_test.go | 156 +++-- internal/integration_tests/metrics_test.go | 559 +++++++++--------- internal/integration_tests/rate_limit_test.go | 203 ++++--- .../redirect_uri_validation_test.go | 181 +++--- internal/integration_tests/test_helper.go | 110 +--- internal/memory_store/db/provider_test.go | 11 +- internal/memory_store/db/test_config_test.go | 69 +-- internal/storage/db/couchbase/health_check.go | 4 +- 13 files changed, 810 insertions(+), 1232 deletions(-) delete mode 100644 internal/integration_tests/audit_logs_test.go diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 0958a2a1..35566b3a 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -37,7 +37,10 @@ "Bash(do echo:*)", "Read(//Users/lakhansamani/personal/authorizer/authorizer/internal/storage/db/**)", "Bash(done)", - "Bash(grep:*)" + "Bash(grep:*)", + "Bash(TEST_DBS=\"sqlite\" go test -p 1 -v -count=1 ./internal/integration_tests/)", + "Bash(TEST_DBS=\"sqlite\" go test -p 1 -v -count=1 ./internal/memory_store/...)", + "Bash(TEST_DBS=\"sqlite\" go test -p 1 -v -count=1 ./internal/storage/)" ] } } diff --git a/CLAUDE.md b/CLAUDE.md index 0e2a4718..a6c7a3e0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -19,15 +19,16 @@ make build-app # Build login UI (web/app) make build-dashboard # Build admin UI (web/dashboard) make generate-graphql # Regenerate after schema.graphqls change -# Testing (TEST_DBS env var selects databases, default: postgres) +# Testing +# Integration tests always use SQLite (no Docker needed). +# Storage provider tests honour TEST_DBS (default: all 7 DBs, needs Docker). # Optional: TEST_ENABLE_REDIS=1 runs Redis memory_store unit tests (Redis on localhost:6380). -make test # Docker Postgres (default) -make test-sqlite # SQLite in-memory (no Docker) -make test-mongodb # Docker MongoDB +make test # SQLite integration + storage (via TEST_DBS) +make test-sqlite # SQLite everywhere (no Docker) make test-all-db # ALL 7 databases (postgres,sqlite,mongodb,arangodb,scylladb,dynamodb,couchbase) -# Single test against specific DBs -go clean --testcache && TEST_DBS="sqlite,postgres" go test -p 1 -v -run TestSignup ./internal/integration_tests/ +# Single test +go clean --testcache && TEST_DBS="sqlite" go test -p 1 -v -run TestSignup ./internal/integration_tests/ ``` ## Architecture (Quick Reference) @@ -51,7 +52,7 @@ go clean --testcache && TEST_DBS="sqlite,postgres" go test -p 1 -v -run TestSign 2. **Schema changes must update ALL 13+ database providers** 3. **Run `make generate-graphql`** after editing `schema.graphqls` 4. **Security**: parameterized queries only, `crypto/rand` for tokens, `crypto/subtle` for comparisons, never log secrets -5. **Tests**: integration tests with real DBs, table-driven subtests, testify assertions +5. **Tests**: integration tests use SQLite via `getTestConfig()` (no `runForEachDB`); storage tests cover all DBs via `TEST_DBS`; testify assertions 6. **NEVER commit to main** — always work on a feature branch (`feat/`, `fix/`, `security/`, `chore/`), push to the branch, and create a merge request. Main must stay deployable. ## AI Agent Roles diff --git a/Makefile b/Makefile index 274fbb05..5d587130 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,9 @@ DEFAULT_VERSION=0.1.0-local VERSION := $(or $(VERSION),$(DEFAULT_VERSION)) DOCKER_IMAGE ?= authorizerdev/authorizer:$(VERSION) -# Full module test run. TEST_DBS selects storage backends (storage, integration_tests, -# memory_store/db). Redis memory_store tests run only when TEST_ENABLE_REDIS=1. +# Full module test run. Storage provider tests honour TEST_DBS (defaults to all). +# Integration tests and memory_store/db tests always use SQLite. +# Redis memory_store tests run only when TEST_ENABLE_REDIS=1. GO_TEST_ALL := go test -p 1 -v ./... .PHONY: all bootstrap build build-app build-dashboard build-local-image build-push-image trivy-scan @@ -43,9 +44,8 @@ clean: dev: go run main.go --database-type=sqlite --database-url=test.db --jwt-type=HS256 --jwt-secret=test --admin-secret=admin --client-id=123456 --client-secret=secret -test: test-cleanup test-docker-up - go clean --testcache && TEST_DBS="postgres" $(GO_TEST_ALL) - $(MAKE) test-cleanup +test: + go clean --testcache && TEST_DBS="sqlite" $(GO_TEST_ALL) test-postgres: test-cleanup-postgres docker run -d --name authorizer_postgres -p 5434:5432 -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres postgres @@ -86,7 +86,7 @@ test-couchbase: test-cleanup-couchbase go clean --testcache && TEST_DBS="couchbase" $(GO_TEST_ALL) docker rm -vf authorizer_couchbase -test-all-db: test-cleanup test-docker-up +test-all-db: test-cleanup test-docker-up test-cleanup go clean --testcache && TEST_DBS="postgres,sqlite,mongodb,arangodb,scylladb,dynamodb,couchbase" $(GO_TEST_ALL) $(MAKE) test-cleanup diff --git a/internal/integration_tests/audit_logs_test.go b/internal/integration_tests/audit_logs_test.go deleted file mode 100644 index 8155d051..00000000 --- a/internal/integration_tests/audit_logs_test.go +++ /dev/null @@ -1,243 +0,0 @@ -package integration_tests - -import ( - "testing" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/authorizerdev/authorizer/internal/config" - "github.com/authorizerdev/authorizer/internal/constants" - "github.com/authorizerdev/authorizer/internal/graph/model" - "github.com/authorizerdev/authorizer/internal/storage/schemas" -) - -func TestAuditLogs(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - t.Run("should add and list audit logs", func(t *testing.T) { - auditLog := &schemas.AuditLog{ - ActorID: uuid.New().String(), - ActorType: constants.AuditActorTypeUser, - ActorEmail: "test@example.com", - Action: constants.AuditLoginSuccessEvent, - ResourceType: constants.AuditResourceTypeSession, - ResourceID: uuid.New().String(), - IPAddress: "127.0.0.1", - UserAgent: "test-agent", - } - - err := ts.StorageProvider.AddAuditLog(ctx, auditLog) - require.NoError(t, err) - assert.NotEmpty(t, auditLog.ID) - assert.NotZero(t, auditLog.CreatedAt) - - pagination := &model.Pagination{ - Limit: 10, - Offset: 0, - } - logs, pag, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{}) - require.NoError(t, err) - assert.NotNil(t, pag) - assert.GreaterOrEqual(t, len(logs), 1) - }) - - t.Run("should filter audit logs by action", func(t *testing.T) { - uniqueAction := "test_action_" + uuid.New().String()[:8] - - auditLog := &schemas.AuditLog{ - ActorID: uuid.New().String(), - ActorType: constants.AuditActorTypeUser, - ActorEmail: "filter@example.com", - Action: uniqueAction, - ResourceType: constants.AuditResourceTypeUser, - } - err := ts.StorageProvider.AddAuditLog(ctx, auditLog) - require.NoError(t, err) - - pagination := &model.Pagination{ - Limit: 10, - Offset: 0, - } - logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{ - "action": uniqueAction, - }) - require.NoError(t, err) - assert.Equal(t, 1, len(logs)) - assert.Equal(t, uniqueAction, logs[0].Action) - }) - - t.Run("should filter audit logs by actor_id", func(t *testing.T) { - actorID := uuid.New().String() - - auditLog := &schemas.AuditLog{ - ActorID: actorID, - ActorType: constants.AuditActorTypeAdmin, - ActorEmail: "admin@example.com", - Action: constants.AuditAdminUserUpdatedEvent, - ResourceType: constants.AuditResourceTypeUser, - } - err := ts.StorageProvider.AddAuditLog(ctx, auditLog) - require.NoError(t, err) - - pagination := &model.Pagination{ - Limit: 10, - Offset: 0, - } - logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{ - "actor_id": actorID, - }) - require.NoError(t, err) - assert.Equal(t, 1, len(logs)) - assert.Equal(t, actorID, logs[0].ActorID) - }) - - t.Run("should filter audit logs by resource_type", func(t *testing.T) { - uniqueAction := "res_filter_" + uuid.New().String()[:8] - - auditLog := &schemas.AuditLog{ - ActorID: uuid.New().String(), - ActorType: constants.AuditActorTypeAdmin, - Action: uniqueAction, - ResourceType: constants.AuditResourceTypeWebhook, - ResourceID: uuid.New().String(), - } - err := ts.StorageProvider.AddAuditLog(ctx, auditLog) - require.NoError(t, err) - - pagination := &model.Pagination{ - Limit: 10, - Offset: 0, - } - logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{ - "resource_type": constants.AuditResourceTypeWebhook, - "action": uniqueAction, - }) - require.NoError(t, err) - assert.Equal(t, 1, len(logs)) - assert.Equal(t, constants.AuditResourceTypeWebhook, logs[0].ResourceType) - }) - - t.Run("should filter audit logs by timestamp range", func(t *testing.T) { - uniqueAction := "ts_filter_" + uuid.New().String()[:8] - now := time.Now().Unix() - - auditLog := &schemas.AuditLog{ - ActorID: uuid.New().String(), - ActorType: constants.AuditActorTypeUser, - Action: uniqueAction, - ResourceType: constants.AuditResourceTypeSession, - } - err := ts.StorageProvider.AddAuditLog(ctx, auditLog) - require.NoError(t, err) - - pagination := &model.Pagination{ - Limit: 10, - Offset: 0, - } - // Filter from 1 second ago to now+1 - fromTs := now - 1 - toTs := now + 1 - logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{ - "action": uniqueAction, - "from_timestamp": fromTs, - "to_timestamp": toTs, - }) - require.NoError(t, err) - assert.Equal(t, 1, len(logs)) - - // Filter with future range should return no results - futureTs := now + 3600 - logs, _, err = ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{ - "action": uniqueAction, - "from_timestamp": futureTs, - }) - require.NoError(t, err) - assert.Equal(t, 0, len(logs)) - }) - - t.Run("should not mutate caller pagination", func(t *testing.T) { - pagination := &model.Pagination{ - Limit: 10, - Offset: 0, - } - _, returnedPag, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{}) - require.NoError(t, err) - assert.NotSame(t, pagination, returnedPag, "should return a new pagination object") - }) - - t.Run("should delete audit logs before created_at", func(t *testing.T) { - uniqueAction := "cleanup_test_" + uuid.New().String()[:8] - - oldLog := &schemas.AuditLog{ - ActorID: uuid.New().String(), - ActorType: constants.AuditActorTypeUser, - ActorEmail: "system@example.com", - Action: uniqueAction, - CreatedAt: time.Now().Add(-24 * time.Hour).Unix(), - ResourceType: constants.AuditResourceTypeUser, - } - err := ts.StorageProvider.AddAuditLog(ctx, oldLog) - require.NoError(t, err) - - // Delete logs older than 1 hour ago - before := time.Now().Add(-1 * time.Hour).Unix() - err = ts.StorageProvider.DeleteAuditLogsBefore(ctx, before) - require.NoError(t, err) - - // Verify the old log is deleted by filtering for its unique action - pagination := &model.Pagination{ - Limit: 10, - Offset: 0, - } - logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{ - "action": uniqueAction, - }) - require.NoError(t, err) - assert.Equal(t, 0, len(logs)) - }) - - t.Run("should preserve all fields in audit log round-trip", func(t *testing.T) { - actorID := uuid.New().String() - resourceID := uuid.New().String() - uniqueAction := "roundtrip_" + uuid.New().String()[:8] - - auditLog := &schemas.AuditLog{ - ActorID: actorID, - ActorType: constants.AuditActorTypeAdmin, - ActorEmail: "admin@test.com", - Action: uniqueAction, - ResourceType: constants.AuditResourceTypeEmailTemplate, - ResourceID: resourceID, - IPAddress: "192.168.1.1", - UserAgent: "Mozilla/5.0 Test", - Metadata: `{"key":"value"}`, - } - err := ts.StorageProvider.AddAuditLog(ctx, auditLog) - require.NoError(t, err) - - pagination := &model.Pagination{Limit: 10, Offset: 0} - logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{ - "action": uniqueAction, - }) - require.NoError(t, err) - require.Len(t, logs, 1) - - got := logs[0] - assert.Equal(t, actorID, got.ActorID) - assert.Equal(t, constants.AuditActorTypeAdmin, got.ActorType) - assert.Equal(t, "admin@test.com", got.ActorEmail) - assert.Equal(t, uniqueAction, got.Action) - assert.Equal(t, constants.AuditResourceTypeEmailTemplate, got.ResourceType) - assert.Equal(t, resourceID, got.ResourceID) - assert.Equal(t, "192.168.1.1", got.IPAddress) - assert.Equal(t, "Mozilla/5.0 Test", got.UserAgent) - assert.Equal(t, `{"key":"value"}`, got.Metadata) - assert.NotZero(t, got.CreatedAt) - }) - }) -} diff --git a/internal/integration_tests/custom_access_token_script_test.go b/internal/integration_tests/custom_access_token_script_test.go index 80a58fbf..9b87dc73 100644 --- a/internal/integration_tests/custom_access_token_script_test.go +++ b/internal/integration_tests/custom_access_token_script_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/authorizerdev/authorizer/internal/config" "github.com/authorizerdev/authorizer/internal/graph/model" ) @@ -32,269 +31,268 @@ func parseTestJWTClaims(t *testing.T, tokenString string) jwt.MapClaims { // TestCustomAccessTokenScript tests the custom access token script functionality // including the 5-second execution timeout added for DoS protection. func TestCustomAccessTokenScript(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - t.Run("should_add_custom_claims_from_script", func(t *testing.T) { - cfg.CustomAccessTokenScript = `function(user, tokenPayload) { - return { custom_claim: "hello", user_email: user.email }; - }` - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "custom_script_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - - loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ - Email: &email, - Password: password, - }) - require.NoError(t, err) - require.NotNil(t, loginRes) - require.NotNil(t, loginRes.AccessToken) - - // Parse the access token and verify custom claims are present - claims := parseTestJWTClaims(t, *loginRes.AccessToken) - assert.Equal(t, "hello", claims["custom_claim"]) - assert.Equal(t, email, claims["user_email"]) + cfg := getTestConfig() + + t.Run("should_add_custom_claims_from_script", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + return { custom_claim: "hello", user_email: user.email }; + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "custom_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + + // Parse the access token and verify custom claims are present + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + assert.Equal(t, "hello", claims["custom_claim"]) + assert.Equal(t, email, claims["user_email"]) + }) + + t.Run("should_not_override_reserved_claims", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + return { sub: "hacked", iss: "hacked", roles: ["admin"], custom_field: "allowed" }; + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "reserved_claims_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, }) + require.NoError(t, err) - t.Run("should_not_override_reserved_claims", func(t *testing.T) { - cfg.CustomAccessTokenScript = `function(user, tokenPayload) { - return { sub: "hacked", iss: "hacked", roles: ["admin"], custom_field: "allowed" }; - }` - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "reserved_claims_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - - loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ - Email: &email, - Password: password, - }) - require.NoError(t, err) - require.NotNil(t, loginRes) - - claims := parseTestJWTClaims(t, *loginRes.AccessToken) - // Reserved claims must NOT be overridden - assert.NotEqual(t, "hacked", claims["sub"]) - assert.NotEqual(t, "hacked", claims["iss"]) - // Roles should NOT be overridden to admin - roles, ok := claims["roles"].([]interface{}) - if ok { - for _, r := range roles { - assert.NotEqual(t, "admin", r, "reserved 'roles' claim must not be overridden by script") - } + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + // Reserved claims must NOT be overridden + assert.NotEqual(t, "hacked", claims["sub"]) + assert.NotEqual(t, "hacked", claims["iss"]) + // Roles should NOT be overridden to admin + roles, ok := claims["roles"].([]interface{}) + if ok { + for _, r := range roles { + assert.NotEqual(t, "admin", r, "reserved 'roles' claim must not be overridden by script") } - // Custom (non-reserved) claims should be added - assert.Equal(t, "allowed", claims["custom_field"]) + } + // Custom (non-reserved) claims should be added + assert.Equal(t, "allowed", claims["custom_field"]) + }) + + t.Run("should_timeout_infinite_loop_script", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + while(true) {} // infinite loop — should be killed after 5 seconds + return { never: "reached" }; + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "timeout_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, }) + require.NoError(t, err) + + // Measure execution time to verify the timeout works + start := time.Now() - t.Run("should_timeout_infinite_loop_script", func(t *testing.T) { - cfg.CustomAccessTokenScript = `function(user, tokenPayload) { - while(true) {} // infinite loop — should be killed after 5 seconds - return { never: "reached" }; - }` - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "timeout_script_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - - // Measure execution time to verify the timeout works - start := time.Now() - - // Login should still succeed — the timeout is handled gracefully, - // custom claims are skipped but the token is still created. - loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ - Email: &email, - Password: password, - }) - elapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, loginRes) - require.NotNil(t, loginRes.AccessToken) - - // The token should be valid but without the custom claim from the timed-out script - claims := parseTestJWTClaims(t, *loginRes.AccessToken) - assert.Nil(t, claims["never"], "timed-out script claims must not appear in token") - // Standard claims should still be present - assert.NotEmpty(t, claims["sub"]) - assert.NotEmpty(t, claims["iss"]) - - // Verify the timeout kicked in within a reasonable range (5-8 seconds for the access + id token) - assert.Less(t, elapsed, 20*time.Second, "login with infinite loop script should complete within 20 seconds (two 5s timeouts + overhead)") + // Login should still succeed — the timeout is handled gracefully, + // custom claims are skipped but the token is still created. + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, }) + elapsed := time.Since(start) + + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + + // The token should be valid but without the custom claim from the timed-out script + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + assert.Nil(t, claims["never"], "timed-out script claims must not appear in token") + // Standard claims should still be present + assert.NotEmpty(t, claims["sub"]) + assert.NotEmpty(t, claims["iss"]) + + // Verify the timeout kicked in within a reasonable range (5-8 seconds for the access + id token) + assert.Less(t, elapsed, 20*time.Second, "login with infinite loop script should complete within 20 seconds (two 5s timeouts + overhead)") + }) + + t.Run("should_handle_script_error_gracefully", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + throw new Error("intentional error"); + }` + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) - t.Run("should_handle_script_error_gracefully", func(t *testing.T) { - cfg.CustomAccessTokenScript = `function(user, tokenPayload) { - throw new Error("intentional error"); - }` - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "error_script_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - - // Login should still succeed even with a broken script - loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ - Email: &email, - Password: password, - }) - require.NoError(t, err) - require.NotNil(t, loginRes) - require.NotNil(t, loginRes.AccessToken) + email := "error_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, }) + require.NoError(t, err) - t.Run("should_work_without_custom_script", func(t *testing.T) { - cfg.CustomAccessTokenScript = "" - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "no_script_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - - loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ - Email: &email, - Password: password, - }) - require.NoError(t, err) - require.NotNil(t, loginRes) - require.NotNil(t, loginRes.AccessToken) - - claims := parseTestJWTClaims(t, *loginRes.AccessToken) - assert.NotEmpty(t, claims["sub"]) - // Ensure no unexpected claims were added - assert.Nil(t, claims["custom_claim"]) + // Login should still succeed even with a broken script + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + }) - t.Run("should_have_custom_claims_in_id_token_too", func(t *testing.T) { - cfg.CustomAccessTokenScript = `function(user, tokenPayload) { - return { team: "engineering" }; - }` - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "id_token_script_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - - loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ - Email: &email, - Password: password, - }) - require.NoError(t, err) - require.NotNil(t, loginRes) - require.NotNil(t, loginRes.IDToken) - - // The custom script runs for both access token and ID token - claims := parseTestJWTClaims(t, *loginRes.IDToken) - assert.Equal(t, "engineering", claims["team"]) + t.Run("should_work_without_custom_script", func(t *testing.T) { + cfg.CustomAccessTokenScript = "" + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "no_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, }) + require.NoError(t, err) + + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.AccessToken) + + claims := parseTestJWTClaims(t, *loginRes.AccessToken) + assert.NotEmpty(t, claims["sub"]) + // Ensure no unexpected claims were added + assert.Nil(t, claims["custom_claim"]) }) -} -// TestClientIDMismatchMetric verifies that client ID mismatch records a security metric. -func TestClientIDMismatchMetric(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { + t.Run("should_have_custom_claims_in_id_token_too", func(t *testing.T) { + cfg.CustomAccessTokenScript = `function(user, tokenPayload) { + return { team: "engineering" }; + }` ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) - router := setupTestRouter(ts) + email := "id_token_script_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" - t.Run("records_metric_on_client_id_mismatch", func(t *testing.T) { - // Send request with wrong client ID to /graphql (not dashboard/app) - body := `{"query":"{ meta { version } }"}` - w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{ - "Content-Type": "application/json", - "X-Authorizer-Client-ID": "wrong-client-id", - "X-Authorizer-URL": "http://localhost:8080", - "Origin": "http://localhost:3000", - }) + _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + + loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{ + Email: &email, + Password: password, + }) + require.NoError(t, err) + require.NotNil(t, loginRes) + require.NotNil(t, loginRes.IDToken) - assert.Equal(t, 400, w.Code) - assert.Contains(t, w.Body.String(), "invalid_client_id") + // The custom script runs for both access token and ID token + claims := parseTestJWTClaims(t, *loginRes.IDToken) + assert.Equal(t, "engineering", claims["team"]) + }) +} - // Check that the security metric was recorded - metricsBody := getMetricsBody(t, router) - assert.Contains(t, metricsBody, `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}`) +// TestClientIDMismatchMetric verifies that client ID mismatch records a security metric. +func TestClientIDMismatchMetric(t *testing.T) { + cfg := getTestConfig() + ts := initTestSetup(t, cfg) + + router := setupTestRouter(ts) + + t.Run("records_metric_on_client_id_mismatch", func(t *testing.T) { + // Send request with wrong client ID to /graphql (not dashboard/app) + body := `{"query":"{ meta { version } }"}` + w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{ + "Content-Type": "application/json", + "X-Authorizer-Client-ID": "wrong-client-id", + "X-Authorizer-URL": "http://localhost:8080", + "Origin": "http://localhost:3000", }) - t.Run("no_metric_for_valid_client_id", func(t *testing.T) { - body := `{"query":"{ meta { version } }"}` - w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{ - "Content-Type": "application/json", - "X-Authorizer-Client-ID": cfg.ClientID, - "X-Authorizer-URL": "http://localhost:8080", - "Origin": "http://localhost:3000", - }) - - // Should not be 400 - assert.NotEqual(t, 400, w.Code) + assert.Equal(t, 400, w.Code) + assert.Contains(t, w.Body.String(), "invalid_client_id") + + // Check that the security metric was recorded + metricsBody := getMetricsBody(t, router) + assert.Contains(t, metricsBody, `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}`) + }) + + t.Run("no_metric_for_valid_client_id", func(t *testing.T) { + body := `{"query":"{ meta { version } }"}` + w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{ + "Content-Type": "application/json", + "X-Authorizer-Client-ID": cfg.ClientID, + "X-Authorizer-URL": "http://localhost:8080", + "Origin": "http://localhost:3000", }) - t.Run("no_metric_for_dashboard_path_mismatch", func(t *testing.T) { - mark := `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}` - before := prometheusCounterSample(t, getMetricsBody(t, router), mark) - w := sendTestRequest(t, router, "GET", "/dashboard/", "", map[string]string{ - "X-Authorizer-Client-ID": "wrong-client-id", - }) - assert.Equal(t, 400, w.Code) - after := prometheusCounterSample(t, getMetricsBody(t, router), mark) - assert.Equal(t, before, after, "dashboard path mismatch must not increment client_id_mismatch metric") + // Should not be 400 + assert.NotEqual(t, 400, w.Code) + }) + + t.Run("no_metric_for_dashboard_path_mismatch", func(t *testing.T) { + mark := `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}` + before := prometheusCounterSample(t, getMetricsBody(t, router), mark) + w := sendTestRequest(t, router, "GET", "/dashboard/", "", map[string]string{ + "X-Authorizer-Client-ID": "wrong-client-id", }) + assert.Equal(t, 400, w.Code) + after := prometheusCounterSample(t, getMetricsBody(t, router), mark) + assert.Equal(t, before, after, "dashboard path mismatch must not increment client_id_mismatch metric") + }) - t.Run("records_client_id_header_missing_metric", func(t *testing.T) { - mark := "authorizer_client_id_header_missing_total" - before := prometheusCounterSample(t, getMetricsBody(t, router), mark) - sendTestRequest(t, router, "POST", "/graphql", `{"query":"{ meta { version } }"}`, map[string]string{ - "Content-Type": "application/json", - "X-Authorizer-URL": "http://localhost:8080", - "Origin": "http://localhost:3000", - }) - after := prometheusCounterSample(t, getMetricsBody(t, router), mark) - assert.Greater(t, after, before) + t.Run("records_client_id_header_missing_metric", func(t *testing.T) { + mark := "authorizer_client_id_header_missing_total" + before := prometheusCounterSample(t, getMetricsBody(t, router), mark) + sendTestRequest(t, router, "POST", "/graphql", `{"query":"{ meta { version } }"}`, map[string]string{ + "Content-Type": "application/json", + "X-Authorizer-URL": "http://localhost:8080", + "Origin": "http://localhost:3000", }) + after := prometheusCounterSample(t, getMetricsBody(t, router), mark) + assert.Greater(t, after, before) }) } diff --git a/internal/integration_tests/health_test.go b/internal/integration_tests/health_test.go index cdb17163..213d0df1 100644 --- a/internal/integration_tests/health_test.go +++ b/internal/integration_tests/health_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/authorizerdev/authorizer/internal/config" "github.com/authorizerdev/authorizer/internal/http_handlers" "github.com/authorizerdev/authorizer/internal/storage" ) @@ -29,108 +28,105 @@ func (*failingHealthStorage) HealthCheck(ctx context.Context) error { // TestHealthHandler verifies the /healthz liveness probe endpoint behaviour. func TestHealthHandler(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) - router := gin.New() - router.GET("/healthz", ts.HttpProvider.HealthHandler()) + router := gin.New() + router.GET("/healthz", ts.HttpProvider.HealthHandler()) - t.Run("returns_200_when_storage_is_healthy", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/healthz", nil) - require.NoError(t, err) + t.Run("returns_200_when_storage_is_healthy", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/healthz", nil) + require.NoError(t, err) - router.ServeHTTP(w, req) + router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, http.StatusOK, w.Code) - var body map[string]interface{} - err = json.Unmarshal(w.Body.Bytes(), &body) - require.NoError(t, err) - assert.Equal(t, "ok", body["status"], "healthy response must contain status=ok") - }) + var body map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &body) + require.NoError(t, err) + assert.Equal(t, "ok", body["status"], "healthy response must contain status=ok") }) } // TestReadyHandler verifies the /readyz readiness probe endpoint behaviour. func TestReadyHandler(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) - router := gin.New() - router.GET("/readyz", ts.HttpProvider.ReadyHandler()) + router := gin.New() + router.GET("/readyz", ts.HttpProvider.ReadyHandler()) - t.Run("returns_200_when_storage_is_ready", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/readyz", nil) - require.NoError(t, err) + t.Run("returns_200_when_storage_is_ready", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/readyz", nil) + require.NoError(t, err) - router.ServeHTTP(w, req) + router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, http.StatusOK, w.Code) - var body map[string]interface{} - err = json.Unmarshal(w.Body.Bytes(), &body) - require.NoError(t, err) - assert.Equal(t, "ready", body["status"], "readiness response must contain status=ready") - }) + var body map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &body) + require.NoError(t, err) + assert.Equal(t, "ready", body["status"], "readiness response must contain status=ready") }) } // TestHealthHandlersUnhealthyStorage verifies liveness/readiness and DB metrics when HealthCheck fails. func TestHealthHandlersUnhealthyStorage(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() - realStorage, err := storage.New(cfg, &storage.Dependencies{Log: &logger}) + cfg := getTestConfig() + logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() + realStorage, err := storage.New(cfg, &storage.Dependencies{Log: &logger}) + require.NoError(t, err) + t.Cleanup(func() { _ = realStorage.Close() }) + + wrapped := &failingHealthStorage{Provider: realStorage} + httpProv, err := http_handlers.New(cfg, &http_handlers.Dependencies{ + Log: &logger, + StorageProvider: wrapped, + }) + require.NoError(t, err) + + router := gin.New() + router.GET("/healthz", httpProv.HealthHandler()) + router.GET("/readyz", httpProv.ReadyHandler()) + router.GET("/metrics", httpProv.MetricsHandler()) + + t.Run("healthz_returns_503", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/healthz", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + var body map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + assert.Equal(t, "unhealthy", body["status"]) + }) + + t.Run("readyz_returns_503", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/readyz", nil) require.NoError(t, err) - t.Cleanup(func() { _ = realStorage.Close() }) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + var body map[string]interface{} + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + assert.Equal(t, "not ready", body["status"]) + }) - wrapped := &failingHealthStorage{Provider: realStorage} - httpProv, err := http_handlers.New(cfg, &http_handlers.Dependencies{ - Log: &logger, - StorageProvider: wrapped, - }) + t.Run("records_unhealthy_db_check_metric", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/healthz", nil) require.NoError(t, err) + router.ServeHTTP(w, req) + require.Equal(t, http.StatusServiceUnavailable, w.Code) - router := gin.New() - router.GET("/healthz", httpProv.HealthHandler()) - router.GET("/readyz", httpProv.ReadyHandler()) - router.GET("/metrics", httpProv.MetricsHandler()) - - t.Run("healthz_returns_503", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/healthz", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - assert.Equal(t, http.StatusServiceUnavailable, w.Code) - var body map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) - assert.Equal(t, "unhealthy", body["status"]) - }) - - t.Run("readyz_returns_503", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/readyz", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - assert.Equal(t, http.StatusServiceUnavailable, w.Code) - var body map[string]interface{} - require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) - assert.Equal(t, "not ready", body["status"]) - }) - - t.Run("records_unhealthy_db_check_metric", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/healthz", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - require.Equal(t, http.StatusServiceUnavailable, w.Code) - - w2 := httptest.NewRecorder() - req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w2, req2) - assert.Contains(t, w2.Body.String(), `authorizer_db_health_check_total{status="unhealthy"}`) - }) + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + assert.Contains(t, w2.Body.String(), `authorizer_db_health_check_total{status="unhealthy"}`) }) } diff --git a/internal/integration_tests/metrics_test.go b/internal/integration_tests/metrics_test.go index c3672ff5..e3ed4b1f 100644 --- a/internal/integration_tests/metrics_test.go +++ b/internal/integration_tests/metrics_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/authorizerdev/authorizer/internal/config" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/metrics" "github.com/authorizerdev/authorizer/internal/refs" @@ -21,263 +20,231 @@ import ( // TestMetricsEndpoint verifies the /metrics endpoint serves Prometheus format. func TestMetricsEndpoint(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - - router := gin.New() - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - - t.Run("returns_200_with_prometheus_format", func(t *testing.T) { - // Trigger some metrics so they appear in output - metrics.RecordAuthEvent("test", "test") - metrics.RecordSecurityEvent("test", "test") - metrics.RecordGraphQLError("test") - metrics.RecordClientIDHeaderMissing() - metrics.DBHealthCheckTotal.WithLabelValues("test").Inc() - - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - body := w.Body.String() - // Gauge metrics always appear - assert.Contains(t, body, "authorizer_active_sessions") - // Counter/histogram metrics appear after first increment - assert.Contains(t, body, "authorizer_auth_events_total") - assert.Contains(t, body, "authorizer_security_events_total") - assert.Contains(t, body, "authorizer_graphql_errors_total") - assert.Contains(t, body, "authorizer_db_health_check_total") - assert.Contains(t, body, "authorizer_client_id_header_missing_total") - }) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) - t.Run("post_metrics_is_not_get_ok", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodPost, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - assert.NotEqual(t, http.StatusOK, w.Code) - }) + router := gin.New() + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + + t.Run("returns_200_with_prometheus_format", func(t *testing.T) { + // Trigger some metrics so they appear in output + metrics.RecordAuthEvent("test", "test") + metrics.RecordSecurityEvent("test", "test") + metrics.RecordGraphQLError("test") + metrics.RecordClientIDHeaderMissing() + metrics.DBHealthCheckTotal.WithLabelValues("test").Inc() + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + body := w.Body.String() + // Gauge metrics always appear + assert.Contains(t, body, "authorizer_active_sessions") + // Counter/histogram metrics appear after first increment + assert.Contains(t, body, "authorizer_auth_events_total") + assert.Contains(t, body, "authorizer_security_events_total") + assert.Contains(t, body, "authorizer_graphql_errors_total") + assert.Contains(t, body, "authorizer_db_health_check_total") + assert.Contains(t, body, "authorizer_client_id_header_missing_total") + }) + + t.Run("post_metrics_is_not_get_ok", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.NotEqual(t, http.StatusOK, w.Code) }) } // TestMetricsMiddleware verifies the HTTP metrics middleware records request count. func TestMetricsMiddleware(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - - router := gin.New() - router.Use(ts.HttpProvider.MetricsMiddleware()) - router.GET("/healthz", ts.HttpProvider.HealthHandler()) - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - - t.Run("records_http_request_metrics", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/healthz", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - - // Check metrics endpoint has recorded it - w2 := httptest.NewRecorder() - req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w2, req2) - - body := w2.Body.String() - assert.Contains(t, body, `authorizer_http_requests_total{method="GET",path="/healthz",status="200"}`) - }) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) - t.Run("skips_http_metrics_for_excluded_paths", func(t *testing.T) { - router.GET("/app/foo", func(c *gin.Context) { - c.Status(http.StatusOK) - }) + router := gin.New() + router.Use(ts.HttpProvider.MetricsMiddleware()) + router.GET("/healthz", ts.HttpProvider.HealthHandler()) + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/app/foo", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + t.Run("records_http_request_metrics", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/healthz", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) - w2 := httptest.NewRecorder() - req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w2, req2) + // Check metrics endpoint has recorded it + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + + body := w2.Body.String() + assert.Contains(t, body, `authorizer_http_requests_total{method="GET",path="/healthz",status="200"}`) + }) - body := w2.Body.String() - assert.NotContains(t, body, `authorizer_http_requests_total{method="GET",path="/app/foo"`) + t.Run("skips_http_metrics_for_excluded_paths", func(t *testing.T) { + router.GET("/app/foo", func(c *gin.Context) { + c.Status(http.StatusOK) }) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/app/foo", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + + body := w2.Body.String() + assert.NotContains(t, body, `authorizer_http_requests_total{method="GET",path="/app/foo"`) }) } // TestDBHealthCheckMetrics verifies health check outcomes are tracked. func TestDBHealthCheckMetrics(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - - router := gin.New() - router.GET("/healthz", ts.HttpProvider.HealthHandler()) - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - - t.Run("records_healthy_db_check", func(t *testing.T) { - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/healthz", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - - var body map[string]interface{} - err = json.Unmarshal(w.Body.Bytes(), &body) - require.NoError(t, err) - assert.Equal(t, "ok", body["status"]) - - // Verify metric was recorded - w2 := httptest.NewRecorder() - req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w2, req2) - - metricsBody := w2.Body.String() - assert.Contains(t, metricsBody, `authorizer_db_health_check_total{status="healthy"}`) - }) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) + + router := gin.New() + router.GET("/healthz", ts.HttpProvider.HealthHandler()) + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + + t.Run("records_healthy_db_check", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/healthz", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + + var body map[string]interface{} + err = json.Unmarshal(w.Body.Bytes(), &body) + require.NoError(t, err) + assert.Equal(t, "ok", body["status"]) + + // Verify metric was recorded + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + + metricsBody := w2.Body.String() + assert.Contains(t, metricsBody, `authorizer_db_health_check_total{status="healthy"}`) }) } // TestAuthEventMetrics verifies that auth events are recorded in metrics. func TestAuthEventMetrics(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - router := gin.New() - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - - email := "metrics_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - t.Run("records_signup_and_login_success", func(t *testing.T) { - signupReq := &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - } - res, err := ts.GraphQLProvider.SignUp(ctx, signupReq) - require.NoError(t, err) - assert.NotNil(t, res) - - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="signup",status="success"}`) - - loginReq := &model.LoginRequest{ - Email: &email, - Password: password, - } - loginRes, err := ts.GraphQLProvider.Login(ctx, loginReq) - require.NoError(t, err) - assert.NotNil(t, loginRes) - - w2 := httptest.NewRecorder() - req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w2, req2) - assert.Contains(t, w2.Body.String(), `authorizer_auth_events_total{event="login",status="success"}`) - }) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + router := gin.New() + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + + email := "metrics_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + t.Run("records_signup_and_login_success", func(t *testing.T) { + signupReq := &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + } + res, err := ts.GraphQLProvider.SignUp(ctx, signupReq) + require.NoError(t, err) + assert.NotNil(t, res) - t.Run("records_login_failure_on_bad_credentials", func(t *testing.T) { - loginReq := &model.LoginRequest{ - Email: &email, - Password: "wrong_password", - } - _, err := ts.GraphQLProvider.Login(ctx, loginReq) - assert.Error(t, err) - - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - - body := w.Body.String() - assert.Contains(t, body, `authorizer_auth_events_total{event="login",status="failure"}`) - assert.Contains(t, body, `authorizer_security_events_total{event="invalid_credentials"`) - }) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="signup",status="success"}`) + + loginReq := &model.LoginRequest{ + Email: &email, + Password: password, + } + loginRes, err := ts.GraphQLProvider.Login(ctx, loginReq) + require.NoError(t, err) + assert.NotNil(t, loginRes) + + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + assert.Contains(t, w2.Body.String(), `authorizer_auth_events_total{event="login",status="success"}`) + }) + + t.Run("records_login_failure_on_bad_credentials", func(t *testing.T) { + loginReq := &model.LoginRequest{ + Email: &email, + Password: "wrong_password", + } + _, err := ts.GraphQLProvider.Login(ctx, loginReq) + assert.Error(t, err) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + + body := w.Body.String() + assert.Contains(t, body, `authorizer_auth_events_total{event="login",status="failure"}`) + assert.Contains(t, body, `authorizer_security_events_total{event="invalid_credentials"`) }) } // TestGraphQLErrorMetrics verifies that GraphQL errors in 200 responses are captured. func TestGraphQLErrorMetrics(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - - router := gin.New() - router.Use(ts.HttpProvider.ContextMiddleware()) - router.Use(ts.HttpProvider.CORSMiddleware()) - router.POST("/graphql", ts.HttpProvider.GraphqlHandler()) - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - - t.Run("captures_graphql_errors_in_200_responses", func(t *testing.T) { - body := `{"query":"mutation { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}` - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-authorizer-url", "http://localhost:8080") - req.Header.Set("Origin", "http://localhost:3000") - router.ServeHTTP(w, req) - - // GraphQL always returns 200 even with errors - assert.Equal(t, http.StatusOK, w.Code) - - // Check metrics endpoint - w2 := httptest.NewRecorder() - req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w2, req2) - - metricsBody := w2.Body.String() - assert.Contains(t, metricsBody, "authorizer_graphql_request_duration_seconds") - assert.Contains(t, metricsBody, `authorizer_graphql_errors_total{operation="anonymous"}`) - }) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) - t.Run("captures_graphql_errors_with_named_operation", func(t *testing.T) { - body := `{"operationName":"LoginOp","query":"mutation LoginOp { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}` - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-authorizer-url", "http://localhost:8080") - req.Header.Set("Origin", "http://localhost:3000") - router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - - w2 := httptest.NewRecorder() - req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w2, req2) - loginOpLabel := metrics.GraphQLOperationPrometheusLabel("LoginOp") - assert.Contains(t, w2.Body.String(), `authorizer_graphql_errors_total{operation="`+loginOpLabel+`"}`) - }) - }) -} + router := gin.New() + router.Use(ts.HttpProvider.ContextMiddleware()) + router.Use(ts.HttpProvider.CORSMiddleware()) + router.POST("/graphql", ts.HttpProvider.GraphqlHandler()) + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) -// TestClientIDHeaderMissingMiddlewareMetric verifies empty X-Authorizer-Client-ID increments the counter. -func TestClientIDHeaderMissingMiddlewareMetric(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) + t.Run("captures_graphql_errors_in_200_responses", func(t *testing.T) { + body := `{"query":"mutation { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}` + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-authorizer-url", "http://localhost:8080") + req.Header.Set("Origin", "http://localhost:3000") + router.ServeHTTP(w, req) - router := gin.New() - router.Use(ts.HttpProvider.ClientCheckMiddleware()) - router.GET("/probe", func(c *gin.Context) { - c.Status(http.StatusOK) - }) - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + // GraphQL always returns 200 even with errors + assert.Equal(t, http.StatusOK, w.Code) + // Check metrics endpoint + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + + metricsBody := w2.Body.String() + assert.Contains(t, metricsBody, "authorizer_graphql_request_duration_seconds") + assert.Contains(t, metricsBody, `authorizer_graphql_errors_total{operation="anonymous"}`) + }) + + t.Run("captures_graphql_errors_with_named_operation", func(t *testing.T) { + body := `{"operationName":"LoginOp","query":"mutation LoginOp { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}` w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/probe", nil) + req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body)) require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-authorizer-url", "http://localhost:8080") + req.Header.Set("Origin", "http://localhost:3000") router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -285,8 +252,34 @@ func TestClientIDHeaderMissingMiddlewareMetric(t *testing.T) { req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) require.NoError(t, err) router.ServeHTTP(w2, req2) - assert.Contains(t, w2.Body.String(), "authorizer_client_id_header_missing_total") + loginOpLabel := metrics.GraphQLOperationPrometheusLabel("LoginOp") + assert.Contains(t, w2.Body.String(), `authorizer_graphql_errors_total{operation="`+loginOpLabel+`"}`) + }) +} + +// TestClientIDHeaderMissingMiddlewareMetric verifies empty X-Authorizer-Client-ID increments the counter. +func TestClientIDHeaderMissingMiddlewareMetric(t *testing.T) { + cfg := getTestConfig() + ts := initTestSetup(t, cfg) + + router := gin.New() + router.Use(ts.HttpProvider.ClientCheckMiddleware()) + router.GET("/probe", func(c *gin.Context) { + c.Status(http.StatusOK) }) + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/probe", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + w2 := httptest.NewRecorder() + req2, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w2, req2) + assert.Contains(t, w2.Body.String(), "authorizer_client_id_header_missing_total") } // TestRecordAuthEventHelpers verifies the helper functions work correctly. @@ -309,71 +302,69 @@ func TestRecordAuthEventHelpers(t *testing.T) { // TestAdminLoginMetrics verifies admin login records metrics. func TestAdminLoginMetrics(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - router := gin.New() - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - - t.Run("records_admin_login_failure", func(t *testing.T) { - loginReq := &model.AdminLoginRequest{ - AdminSecret: "wrong-secret", - } - _, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq) - assert.Error(t, err) - - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - - body := w.Body.String() - assert.Contains(t, body, `authorizer_auth_events_total{event="admin_login",status="failure"}`) - assert.Contains(t, body, `authorizer_security_events_total{event="invalid_admin_secret"`) - }) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) - t.Run("records_admin_login_success", func(t *testing.T) { - loginReq := &model.AdminLoginRequest{ - AdminSecret: cfg.AdminSecret, - } - res, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq) - require.NoError(t, err) - assert.NotNil(t, res) + router := gin.New() + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) + t.Run("records_admin_login_failure", func(t *testing.T) { + loginReq := &model.AdminLoginRequest{ + AdminSecret: "wrong-secret", + } + _, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq) + assert.Error(t, err) - assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="admin_login",status="success"}`) - }) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + + body := w.Body.String() + assert.Contains(t, body, `authorizer_auth_events_total{event="admin_login",status="failure"}`) + assert.Contains(t, body, `authorizer_security_events_total{event="invalid_admin_secret"`) + }) + + t.Run("records_admin_login_success", func(t *testing.T) { + loginReq := &model.AdminLoginRequest{ + AdminSecret: cfg.AdminSecret, + } + res, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq) + require.NoError(t, err) + assert.NotNil(t, res) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + + assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="admin_login",status="success"}`) }) } // TestForgotPasswordMetrics verifies forgot password records metrics. func TestForgotPasswordMetrics(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - router := gin.New() - router.GET("/metrics", ts.HttpProvider.MetricsHandler()) - - t.Run("records_forgot_password_failure_for_nonexistent_user", func(t *testing.T) { - nonExistentEmail := "nonexistent_metrics@authorizer.dev" - forgotReq := &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(nonExistentEmail), - } - _, err := ts.GraphQLProvider.ForgotPassword(ctx, forgotReq) - assert.Error(t, err) - - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/metrics", nil) - require.NoError(t, err) - router.ServeHTTP(w, req) - - assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="forgot_password",status="failure"}`) - }) + cfg := getTestConfig() + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + router := gin.New() + router.GET("/metrics", ts.HttpProvider.MetricsHandler()) + + t.Run("records_forgot_password_failure_for_nonexistent_user", func(t *testing.T) { + nonExistentEmail := "nonexistent_metrics@authorizer.dev" + forgotReq := &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(nonExistentEmail), + } + _, err := ts.GraphQLProvider.ForgotPassword(ctx, forgotReq) + assert.Error(t, err) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/metrics", nil) + require.NoError(t, err) + router.ServeHTTP(w, req) + + assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="forgot_password",status="failure"}`) }) } diff --git a/internal/integration_tests/rate_limit_test.go b/internal/integration_tests/rate_limit_test.go index 968e1ff1..ad261214 100644 --- a/internal/integration_tests/rate_limit_test.go +++ b/internal/integration_tests/rate_limit_test.go @@ -16,132 +16,131 @@ import ( ) func TestRateLimitMiddleware(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - // Set low rate limit for testing - cfg.RateLimitRPS = 5 - cfg.RateLimitBurst = 5 - - ts := initTestSetup(t, cfg) - - t.Run("should_allow_requests_within_limit", func(t *testing.T) { - w := httptest.NewRecorder() - _, router := gin.CreateTestContext(w) - router.Use(ts.HttpProvider.RateLimitMiddleware()) - router.POST("/graphql", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"data": "ok"}) - }) - - // First request should succeed + cfg := getTestConfig() + // Set low rate limit for testing + cfg.RateLimitRPS = 5 + cfg.RateLimitBurst = 5 + + ts := initTestSetup(t, cfg) + + t.Run("should_allow_requests_within_limit", func(t *testing.T) { + w := httptest.NewRecorder() + _, router := gin.CreateTestContext(w) + router.Use(ts.HttpProvider.RateLimitMiddleware()) + router.POST("/graphql", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"data": "ok"}) + }) + + // First request should succeed + req, err := http.NewRequest(http.MethodPost, "/graphql", nil) + require.NoError(t, err) + req.RemoteAddr = "192.168.1.1:1234" + + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("should_reject_requests_over_limit", func(t *testing.T) { + w := httptest.NewRecorder() + _, router := gin.CreateTestContext(w) + router.Use(ts.HttpProvider.RateLimitMiddleware()) + router.POST("/graphql", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"data": "ok"}) + }) + + // Exhaust the burst + for i := 0; i < 5; i++ { req, err := http.NewRequest(http.MethodPost, "/graphql", nil) require.NoError(t, err) - req.RemoteAddr = "192.168.1.1:1234" - + req.RemoteAddr = "10.0.0.1:1234" w = httptest.NewRecorder() router.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) + } + + // Next request should be rejected + req, err := http.NewRequest(http.MethodPost, "/graphql", nil) + require.NoError(t, err) + req.RemoteAddr = "10.0.0.1:1234" + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusTooManyRequests, w.Code) + assert.Contains(t, w.Header().Get("Retry-After"), "1") + }) + + t.Run("should_not_rate_limit_exempt_paths", func(t *testing.T) { + w := httptest.NewRecorder() + _, router := gin.CreateTestContext(w) + router.Use(ts.HttpProvider.RateLimitMiddleware()) + router.GET("/health", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + router.GET("/.well-known/openid-configuration", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"issuer": "test"}) }) - t.Run("should_reject_requests_over_limit", func(t *testing.T) { - w := httptest.NewRecorder() - _, router := gin.CreateTestContext(w) - router.Use(ts.HttpProvider.RateLimitMiddleware()) - router.POST("/graphql", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"data": "ok"}) - }) - - // Exhaust the burst - for i := 0; i < 5; i++ { - req, err := http.NewRequest(http.MethodPost, "/graphql", nil) + exemptPaths := []string{"/health", "/.well-known/openid-configuration"} + for _, path := range exemptPaths { + // Make many requests - none should be limited + for i := 0; i < 10; i++ { + req, err := http.NewRequest(http.MethodGet, path, nil) require.NoError(t, err) - req.RemoteAddr = "10.0.0.1:1234" + req.RemoteAddr = "10.0.0.2:1234" w = httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, http.StatusOK, w.Code, "path %s request %d should not be rate limited", path, i) } + } + }) - // Next request should be rejected + t.Run("should_isolate_rate_limits_per_ip", func(t *testing.T) { + w := httptest.NewRecorder() + _, router := gin.CreateTestContext(w) + router.Use(ts.HttpProvider.RateLimitMiddleware()) + router.POST("/graphql", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"data": "ok"}) + }) + + // Exhaust burst for IP A + for i := 0; i < 5; i++ { req, err := http.NewRequest(http.MethodPost, "/graphql", nil) require.NoError(t, err) - req.RemoteAddr = "10.0.0.1:1234" + req.RemoteAddr = "10.0.0.3:1234" w = httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusTooManyRequests, w.Code) - assert.Contains(t, w.Header().Get("Retry-After"), "1") - }) + } - t.Run("should_not_rate_limit_exempt_paths", func(t *testing.T) { - w := httptest.NewRecorder() - _, router := gin.CreateTestContext(w) - router.Use(ts.HttpProvider.RateLimitMiddleware()) - router.GET("/health", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"status": "ok"}) - }) - router.GET("/.well-known/openid-configuration", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"issuer": "test"}) - }) - - exemptPaths := []string{"/health", "/.well-known/openid-configuration"} - for _, path := range exemptPaths { - // Make many requests - none should be limited - for i := 0; i < 10; i++ { - req, err := http.NewRequest(http.MethodGet, path, nil) - require.NoError(t, err) - req.RemoteAddr = "10.0.0.2:1234" - w = httptest.NewRecorder() - router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code, "path %s request %d should not be rate limited", path, i) - } - } - }) + // IP B should still be allowed + req, err := http.NewRequest(http.MethodPost, "/graphql", nil) + require.NoError(t, err) + req.RemoteAddr = "10.0.0.4:1234" + w = httptest.NewRecorder() + router.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + }) - t.Run("should_isolate_rate_limits_per_ip", func(t *testing.T) { - w := httptest.NewRecorder() - _, router := gin.CreateTestContext(w) - router.Use(ts.HttpProvider.RateLimitMiddleware()) - router.POST("/graphql", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"data": "ok"}) - }) - - // Exhaust burst for IP A - for i := 0; i < 5; i++ { - req, err := http.NewRequest(http.MethodPost, "/graphql", nil) - require.NoError(t, err) - req.RemoteAddr = "10.0.0.3:1234" - w = httptest.NewRecorder() - router.ServeHTTP(w, req) - } + t.Run("should_return_correct_error_format", func(t *testing.T) { + w := httptest.NewRecorder() + _, router := gin.CreateTestContext(w) + router.Use(ts.HttpProvider.RateLimitMiddleware()) + router.POST("/graphql", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"data": "ok"}) + }) - // IP B should still be allowed + // Exhaust the burst + for i := 0; i < 6; i++ { req, err := http.NewRequest(http.MethodPost, "/graphql", nil) require.NoError(t, err) - req.RemoteAddr = "10.0.0.4:1234" + req.RemoteAddr = "10.0.0.5:1234" w = httptest.NewRecorder() router.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - }) - - t.Run("should_return_correct_error_format", func(t *testing.T) { - w := httptest.NewRecorder() - _, router := gin.CreateTestContext(w) - router.Use(ts.HttpProvider.RateLimitMiddleware()) - router.POST("/graphql", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"data": "ok"}) - }) - - // Exhaust the burst - for i := 0; i < 6; i++ { - req, err := http.NewRequest(http.MethodPost, "/graphql", nil) - require.NoError(t, err) - req.RemoteAddr = "10.0.0.5:1234" - w = httptest.NewRecorder() - router.ServeHTTP(w, req) - } + } - // Check the 429 response body has OAuth2 error format - assert.Equal(t, http.StatusTooManyRequests, w.Code) - assert.Contains(t, w.Body.String(), "rate_limit_exceeded") - assert.Contains(t, w.Body.String(), "error_description") - }) + // Check the 429 response body has OAuth2 error format + assert.Equal(t, http.StatusTooManyRequests, w.Code) + assert.Contains(t, w.Body.String(), "rate_limit_exceeded") + assert.Contains(t, w.Body.String(), "error_description") }) } diff --git a/internal/integration_tests/redirect_uri_validation_test.go b/internal/integration_tests/redirect_uri_validation_test.go index 3dfc830c..56573138 100644 --- a/internal/integration_tests/redirect_uri_validation_test.go +++ b/internal/integration_tests/redirect_uri_validation_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/authorizerdev/authorizer/internal/config" "github.com/authorizerdev/authorizer/internal/graph/model" "github.com/authorizerdev/authorizer/internal/refs" ) @@ -15,43 +14,42 @@ import ( // TestRedirectURIRejectsAttacker verifies that forgot_password rejects // attacker-controlled redirect_uri values with explicit AllowedOrigins. func TestRedirectURIRejectsAttacker(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - cfg.AllowedOrigins = []string{"http://localhost:3000"} - cfg.EnableBasicAuthentication = true - cfg.EnableEmailVerification = false - cfg.EnableSignup = true - - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "redirect_test_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - require.NotNil(t, signupRes) - - t.Run("rejects attacker redirect_uri", func(t *testing.T) { - res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(email), - RedirectURI: refs.NewStringRef("https://attacker.com/steal"), - }) - assert.Error(t, err) - assert.Nil(t, res) - assert.Contains(t, err.Error(), "invalid redirect URI") + cfg := getTestConfig() + cfg.AllowedOrigins = []string{"http://localhost:3000"} + cfg.EnableBasicAuthentication = true + cfg.EnableEmailVerification = false + cfg.EnableSignup = true + + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "redirect_test_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + require.NotNil(t, signupRes) + + t.Run("rejects attacker redirect_uri", func(t *testing.T) { + res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(email), + RedirectURI: refs.NewStringRef("https://attacker.com/steal"), }) + assert.Error(t, err) + assert.Nil(t, res) + assert.Contains(t, err.Error(), "invalid redirect URI") + }) - t.Run("accepts valid redirect_uri", func(t *testing.T) { - res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(email), - RedirectURI: refs.NewStringRef("http://localhost:3000/reset"), - }) - assert.NoError(t, err) - assert.NotNil(t, res) + t.Run("accepts valid redirect_uri", func(t *testing.T) { + res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(email), + RedirectURI: refs.NewStringRef("http://localhost:3000/reset"), }) + assert.NoError(t, err) + assert.NotNil(t, res) }) } @@ -59,71 +57,70 @@ func TestRedirectURIRejectsAttacker(t *testing.T) { // vulnerability (issue #540). When allowed_origins=["*"] (the default config), // attacker-controlled redirect_uri values must still be rejected. func TestRedirectURIWildcardOrigins(t *testing.T) { - runForEachDB(t, func(t *testing.T, cfg *config.Config) { - cfg.AllowedOrigins = []string{"*"} - cfg.EnableBasicAuthentication = true - cfg.EnableEmailVerification = false - cfg.EnableSignup = true - ts := initTestSetup(t, cfg) - _, ctx := createContext(ts) - - email := "wildcard_redirect_" + uuid.New().String() + "@authorizer.dev" - password := "Password@123" - - signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ - Email: &email, - Password: password, - ConfirmPassword: password, - }) - require.NoError(t, err) - require.NotNil(t, signupRes) - - t.Run("rejects attacker redirect_uri with wildcard origins", func(t *testing.T) { - res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(email), - RedirectURI: refs.NewStringRef("https://attacker.com/capture"), - }) - assert.Error(t, err) - assert.Nil(t, res) - assert.Contains(t, err.Error(), "invalid redirect URI") + cfg := getTestConfig() + cfg.AllowedOrigins = []string{"*"} + cfg.EnableBasicAuthentication = true + cfg.EnableEmailVerification = false + cfg.EnableSignup = true + ts := initTestSetup(t, cfg) + _, ctx := createContext(ts) + + email := "wildcard_redirect_" + uuid.New().String() + "@authorizer.dev" + password := "Password@123" + + signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{ + Email: &email, + Password: password, + ConfirmPassword: password, + }) + require.NoError(t, err) + require.NotNil(t, signupRes) + + t.Run("rejects attacker redirect_uri with wildcard origins", func(t *testing.T) { + res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(email), + RedirectURI: refs.NewStringRef("https://attacker.com/capture"), }) + assert.Error(t, err) + assert.Nil(t, res) + assert.Contains(t, err.Error(), "invalid redirect URI") + }) - t.Run("allows self-origin redirect_uri with wildcard origins", func(t *testing.T) { - selfURI := "http://" + ts.HttpServer.Listener.Addr().String() + "/app/reset-password" - res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(email), - RedirectURI: refs.NewStringRef(selfURI), - }) - assert.NoError(t, err) - assert.NotNil(t, res) - assert.NotEmpty(t, res.Message) + t.Run("allows self-origin redirect_uri with wildcard origins", func(t *testing.T) { + selfURI := "http://" + ts.HttpServer.Listener.Addr().String() + "/app/reset-password" + res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(email), + RedirectURI: refs.NewStringRef(selfURI), }) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.NotEmpty(t, res.Message) + }) - t.Run("works without redirect_uri (uses default)", func(t *testing.T) { - res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(email), - }) - assert.NoError(t, err) - assert.NotNil(t, res) - assert.NotEmpty(t, res.Message) + t.Run("works without redirect_uri (uses default)", func(t *testing.T) { + res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(email), }) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.NotEmpty(t, res.Message) + }) - t.Run("rejects javascript scheme", func(t *testing.T) { - res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(email), - RedirectURI: refs.NewStringRef("javascript:alert(1)"), - }) - assert.Error(t, err) - assert.Nil(t, res) + t.Run("rejects javascript scheme", func(t *testing.T) { + res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(email), + RedirectURI: refs.NewStringRef("javascript:alert(1)"), }) + assert.Error(t, err) + assert.Nil(t, res) + }) - t.Run("rejects data scheme", func(t *testing.T) { - res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ - Email: refs.NewStringRef(email), - RedirectURI: refs.NewStringRef("data:text/html,

evil

"), - }) - assert.Error(t, err) - assert.Nil(t, res) + t.Run("rejects data scheme", func(t *testing.T) { + res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{ + Email: refs.NewStringRef(email), + RedirectURI: refs.NewStringRef("data:text/html,

evil

"), }) + assert.Error(t, err) + assert.Nil(t, res) }) } diff --git a/internal/integration_tests/test_helper.go b/internal/integration_tests/test_helper.go index 32878c91..fa6123c2 100644 --- a/internal/integration_tests/test_helper.go +++ b/internal/integration_tests/test_helper.go @@ -6,7 +6,6 @@ import ( "net/http/httptest" "os" "path/filepath" - "strings" "testing" "github.com/gin-gonic/gin" @@ -58,82 +57,9 @@ func createContext(s *testSetup) (*http.Request, context.Context) { return req, ctx } -// dbTestConfig holds database-specific test configuration -type dbTestConfig struct { - DbType string - DbURL string -} - -// getTestDBs returns the list of database configurations to test against. -// It reads the TEST_DBS environment variable (comma-separated list of db types). -// Defaults to "postgres" if not set. -func getTestDBs() []dbTestConfig { - testDBsEnv := os.Getenv("TEST_DBS") - if testDBsEnv == "" { - testDBsEnv = "postgres" - } - - dbTypes := strings.Split(testDBsEnv, ",") - var configs []dbTestConfig - - for _, dbType := range dbTypes { - dbType = strings.TrimSpace(dbType) - if dbType == "" { - continue - } - - dbURL := getDBURL(dbType) - if dbURL != "" { - configs = append(configs, dbTestConfig{ - DbType: dbType, - DbURL: dbURL, - }) - } - } - - return configs -} - -// getDBURL returns the connection URL for a given database type -func getDBURL(dbType string) string { - switch dbType { - case constants.DbTypePostgres: - return "postgres://postgres:postgres@localhost:5434/postgres" - case constants.DbTypeSqlite: - return "test.db" - case constants.DbTypeLibSQL: - return "test.db" - case constants.DbTypeMysql: - return "root:password@tcp(localhost:3306)/authorizer" - case constants.DbTypeMariaDB: - return "root:password@tcp(localhost:3307)/authorizer" - case constants.DbTypeSqlserver: - return "sqlserver://sa:Password123@localhost:1433?database=authorizer" - case constants.DbTypeMongoDB: - return "mongodb://localhost:27017" - case constants.DbTypeArangoDB: - return "http://localhost:8529" - case constants.DbTypeScyllaDB, constants.DbTypeCassandraDB: - return "127.0.0.1:9042" - case constants.DbTypeDynamoDB: - return "http://localhost:8000" - case constants.DbTypeCouchbaseDB: - return "couchbase://localhost" - default: - return "" - } -} - -// getTestConfig returns config for integration tests that expect a single storage backend. -// When TEST_DBS is unset, it behaves like "postgres" (one entry). When TEST_DBS lists -// exactly one database (e.g. make test-dynamodb sets dynamodb), that backend is used. -// If TEST_DBS lists multiple databases, Postgres is used so legacy tests keep a predictable default; -// use runForEachDB to exercise more than one DB. +// getTestConfig returns config for integration tests using SQLite. +// Integration tests validate business logic, not storage compatibility. func getTestConfig() *config.Config { - dbs := getTestDBs() - if len(dbs) == 1 { - return getTestConfigForDB(dbs[0].DbType, dbs[0].DbURL) - } return getTestConfigForDB(constants.DbTypeSqlite, "test.db") } @@ -163,8 +89,9 @@ func getTestConfigForDB(dbType, dbURL string) *config.Config { IsSMSServiceEnabled: true, } - // Set MongoDB-specific config - if dbType == constants.DbTypeMongoDB { + // MongoDB, ArangoDB, Cassandra/Scylla require DatabaseName (keyspace / DB name); see storage New(). + if dbType == constants.DbTypeMongoDB || dbType == constants.DbTypeArangoDB || + dbType == constants.DbTypeScyllaDB || dbType == constants.DbTypeCassandraDB { cfg.DatabaseName = "authorizer_test" } @@ -183,33 +110,6 @@ func getTestConfigForDB(dbType, dbURL string) *config.Config { return cfg } -// runForEachDB runs the given test function against each database specified in TEST_DBS. -// This is the primary way to run tests across multiple database providers. -// -// Usage: -// -// func TestFeature(t *testing.T) { -// runForEachDB(t, func(t *testing.T, cfg *config.Config) { -// ts := initTestSetup(t, cfg) -// _, ctx := createContext(ts) -// // ... test logic -// }) -// } -func runForEachDB(t *testing.T, testFn func(t *testing.T, cfg *config.Config)) { - t.Helper() - dbConfigs := getTestDBs() - if len(dbConfigs) == 0 { - t.Fatal("TEST_DBS produced no runnable database configurations; check TEST_DBS and that each database type resolves to a non-empty URL") - } - - for _, dbCfg := range dbConfigs { - t.Run("db="+dbCfg.DbType, func(t *testing.T) { - cfg := getTestConfigForDB(dbCfg.DbType, dbCfg.DbURL) - testFn(t, cfg) - }) - } -} - // initTestSetup initializes the test setup func initTestSetup(t *testing.T, cfg *config.Config) *testSetup { // Initialize logger diff --git a/internal/memory_store/db/provider_test.go b/internal/memory_store/db/provider_test.go index 347f4aaa..a5a43140 100644 --- a/internal/memory_store/db/provider_test.go +++ b/internal/memory_store/db/provider_test.go @@ -1,7 +1,6 @@ package db import ( - "os" "path/filepath" "testing" "time" @@ -10,16 +9,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/authorizerdev/authorizer/internal/constants" "github.com/authorizerdev/authorizer/internal/storage" ) -// TestDBMemoryStoreProvider tests the database-backed memory store against each database in -// TEST_DBS (same semantics as integration_tests). Redis is irrelevant here. +// TestDBMemoryStoreProvider tests the database-backed memory store against SQLite. func TestDBMemoryStoreProvider(t *testing.T) { entries := storageTestDBEntriesFromEnv() if len(entries) == 0 { - t.Fatal("TEST_DBS produced no database configurations") + t.Fatal("no database configurations for memory store DB tests") } for _, e := range entries { @@ -27,10 +24,6 @@ func TestDBMemoryStoreProvider(t *testing.T) { tempSQLite := filepath.Join(t.TempDir(), "memory_store_test.db") dbURL := resolveSQLiteTestURL(e.dbType, e.dbURL, tempSQLite) cfg := buildStorageTestConfigForMemoryStore(e.dbType, dbURL) - if cfg.DatabaseType == constants.DbTypeDynamoDB { - os.Unsetenv("AWS_ACCESS_KEY_ID") - os.Unsetenv("AWS_SECRET_ACCESS_KEY") - } log := zerolog.New(zerolog.NewTestWriter(t)) storageProvider, err := storage.New(cfg, &storage.Dependencies{Log: &log}) diff --git a/internal/memory_store/db/test_config_test.go b/internal/memory_store/db/test_config_test.go index 0dc48b95..fdd6982a 100644 --- a/internal/memory_store/db/test_config_test.go +++ b/internal/memory_store/db/test_config_test.go @@ -1,66 +1,21 @@ package db import ( - "os" - "strings" - "github.com/authorizerdev/authorizer/internal/config" "github.com/authorizerdev/authorizer/internal/constants" ) -// storageDBEntry matches one entry from TEST_DBS (same URLs as internal/integration_tests -// getTestDBs / getDBURL — keep these in sync when adding backends). +// storageDBEntry matches one entry for memory store DB tests. +// Memory store DB tests only run against SQLite — storage-layer compatibility +// is covered by internal/storage tests. type storageDBEntry struct { dbType string dbURL string } func storageTestDBEntriesFromEnv() []storageDBEntry { - testDBsEnv := os.Getenv("TEST_DBS") - if testDBsEnv == "" { - testDBsEnv = "postgres" - } - var out []storageDBEntry - for _, dbType := range strings.Split(testDBsEnv, ",") { - dbType = strings.TrimSpace(dbType) - if dbType == "" { - continue - } - u := dbURLForMemoryStoreStorageTest(dbType) - if u == "" { - continue - } - out = append(out, storageDBEntry{dbType: dbType, dbURL: u}) - } - return out -} - -func dbURLForMemoryStoreStorageTest(dbType string) string { - switch dbType { - case constants.DbTypePostgres: - return "postgres://postgres:postgres@localhost:5434/postgres" - case constants.DbTypeSqlite: - return "test.db" - case constants.DbTypeLibSQL: - return "test.db" - case constants.DbTypeMysql: - return "root:password@tcp(localhost:3306)/authorizer" - case constants.DbTypeMariaDB: - return "root:password@tcp(localhost:3307)/authorizer" - case constants.DbTypeSqlserver: - return "sqlserver://sa:Password123@localhost:1433?database=authorizer" - case constants.DbTypeMongoDB: - return "mongodb://localhost:27017" - case constants.DbTypeArangoDB: - return "http://localhost:8529" - case constants.DbTypeScyllaDB, constants.DbTypeCassandraDB: - return "127.0.0.1:9042" - case constants.DbTypeDynamoDB: - return "http://127.0.0.1:8000" - case constants.DbTypeCouchbaseDB: - return "couchbase://localhost" - default: - return "" + return []storageDBEntry{ + {dbType: constants.DbTypeSqlite, dbURL: "test.db"}, } } @@ -96,19 +51,5 @@ func buildStorageTestConfigForMemoryStore(dbType, dbURL string) *config.Config { IsSMSServiceEnabled: true, } - if dbType == constants.DbTypeMongoDB { - cfg.DatabaseName = "authorizer_test" - } - - if dbType == constants.DbTypeCouchbaseDB { - cfg.DatabaseUsername = "Administrator" - cfg.DatabasePassword = "password" - cfg.CouchBaseBucket = "authorizer_test" - } - - if dbType == constants.DbTypeDynamoDB { - cfg.AWSRegion = "us-east-1" - } - return cfg } diff --git a/internal/storage/db/couchbase/health_check.go b/internal/storage/db/couchbase/health_check.go index 92ae64f8..7a5666ac 100644 --- a/internal/storage/db/couchbase/health_check.go +++ b/internal/storage/db/couchbase/health_check.go @@ -5,11 +5,13 @@ import ( "fmt" "github.com/couchbase/gocb/v2" + + "github.com/authorizerdev/authorizer/internal/storage/schemas" ) // HealthCheck verifies that the Couchbase backend is reachable and responsive func (p *provider) HealthCheck(ctx context.Context) error { - query := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1", p.scopeName) + query := fmt.Sprintf("SELECT 1 FROM %s.%s LIMIT 1", p.scopeName, schemas.Collections.User) _, err := p.db.Query(query, &gocb.QueryOptions{ Context: ctx, })