diff --git a/.claude/settings.local.json b/.claude/settings.local.json
index 0958a2a17..35566b3af 100644
--- a/.claude/settings.local.json
+++ b/.claude/settings.local.json
@@ -37,7 +37,10 @@
"Bash(do echo:*)",
"Read(//Users/lakhansamani/personal/authorizer/authorizer/internal/storage/db/**)",
"Bash(done)",
- "Bash(grep:*)"
+ "Bash(grep:*)",
+ "Bash(TEST_DBS=\"sqlite\" go test -p 1 -v -count=1 ./internal/integration_tests/)",
+ "Bash(TEST_DBS=\"sqlite\" go test -p 1 -v -count=1 ./internal/memory_store/...)",
+ "Bash(TEST_DBS=\"sqlite\" go test -p 1 -v -count=1 ./internal/storage/)"
]
}
}
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index d12b2ebee..9b9386eba 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -75,7 +75,7 @@ jobs:
output: "trivy-results.sarif"
severity: "CRITICAL,HIGH"
- name: Upload Trivy results
- uses: github/codeql-action/upload-sarif@v3
+ uses: github/codeql-action/upload-sarif@v4
if: always()
with:
sarif_file: "trivy-results.sarif"
diff --git a/CLAUDE.md b/CLAUDE.md
index 328621ddf..a6c7a3e0b 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -19,14 +19,16 @@ make build-app # Build login UI (web/app)
make build-dashboard # Build admin UI (web/dashboard)
make generate-graphql # Regenerate after schema.graphqls change
-# Testing (TEST_DBS env var selects databases, default: postgres)
-make test # Docker Postgres (default)
-make test-sqlite # SQLite in-memory (no Docker)
-make test-mongodb # Docker MongoDB
+# Testing
+# Integration tests always use SQLite (no Docker needed).
+# Storage provider tests honour TEST_DBS (default: all 7 DBs, needs Docker).
+# Optional: TEST_ENABLE_REDIS=1 runs Redis memory_store unit tests (Redis on localhost:6380).
+make test # SQLite integration + storage (via TEST_DBS)
+make test-sqlite # SQLite everywhere (no Docker)
make test-all-db # ALL 7 databases (postgres,sqlite,mongodb,arangodb,scylladb,dynamodb,couchbase)
-# Single test against specific DBs
-go clean --testcache && TEST_DBS="sqlite,postgres" go test -p 1 -v -run TestSignup ./internal/integration_tests/
+# Single test
+go clean --testcache && TEST_DBS="sqlite" go test -p 1 -v -run TestSignup ./internal/integration_tests/
```
## Architecture (Quick Reference)
@@ -40,6 +42,7 @@ go clean --testcache && TEST_DBS="sqlite,postgres" go test -p 1 -v -run TestSign
- Token management: `internal/token/`
- Tests: `internal/integration_tests/`
- Frontend: `web/app/` (user UI) | `web/dashboard/` (admin UI)
+- Optional NULL semantics across SQL/document DBs and DynamoDB: `docs/storage-optional-null-fields.md`
**Pattern**: Every subsystem uses `Dependencies` struct + `New()` → `Provider` interface.
@@ -49,7 +52,7 @@ go clean --testcache && TEST_DBS="sqlite,postgres" go test -p 1 -v -run TestSign
2. **Schema changes must update ALL 13+ database providers**
3. **Run `make generate-graphql`** after editing `schema.graphqls`
4. **Security**: parameterized queries only, `crypto/rand` for tokens, `crypto/subtle` for comparisons, never log secrets
-5. **Tests**: integration tests with real DBs, table-driven subtests, testify assertions
+5. **Tests**: integration tests use SQLite via `getTestConfig()` (no `runForEachDB`); storage tests cover all DBs via `TEST_DBS`; testify assertions
6. **NEVER commit to main** — always work on a feature branch (`feat/`, `fix/`, `security/`, `chore/`), push to the branch, and create a merge request. Main must stay deployable.
## AI Agent Roles
diff --git a/Dockerfile b/Dockerfile
index 9e0614167..038f20919 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,11 @@
# syntax=docker/dockerfile:1.4
# Use BuildKit for cache mounts (faster CI: DOCKER_BUILDKIT=1)
+#
+# Alpine v3.23 main still ships busybox 1.37.0-r30 (e.g. CVE-2025-60876); edge/main has r31+.
+# Pin busybox from edge until the stable branch backports it. See alpine/aports work item #17940.
FROM golang:1.25-alpine3.23 AS go-builder
+ARG ALPINE_EDGE_MAIN=https://dl-cdn.alpinelinux.org/alpine/edge/main
+RUN apk add --no-cache -X "${ALPINE_EDGE_MAIN}" "busybox>=1.37.0-r31"
WORKDIR /authorizer
ARG TARGETPLATFORM
@@ -33,6 +38,8 @@ RUN --mount=type=cache,target=/go/pkg/mod \
chmod 755 build/${GOOS}/${GOARCH}/authorizer
FROM alpine:3.23.3 AS node-builder
+ARG ALPINE_EDGE_MAIN=https://dl-cdn.alpinelinux.org/alpine/edge/main
+RUN apk add --no-cache -X "${ALPINE_EDGE_MAIN}" "busybox>=1.37.0-r31"
WORKDIR /authorizer
COPY web/app/package*.json web/app/
COPY web/dashboard/package*.json web/dashboard/
@@ -47,6 +54,8 @@ COPY web/dashboard web/dashboard
RUN cd web/app && npm run build && cd ../dashboard && npm run build
FROM alpine:3.23.3
+ARG ALPINE_EDGE_MAIN=https://dl-cdn.alpinelinux.org/alpine/edge/main
+RUN apk add --no-cache -X "${ALPINE_EDGE_MAIN}" "busybox>=1.37.0-r31"
ARG TARGETARCH=amd64
diff --git a/Makefile b/Makefile
index 3c42d81b6..5d5871301 100644
--- a/Makefile
+++ b/Makefile
@@ -3,6 +3,11 @@ DEFAULT_VERSION=0.1.0-local
VERSION := $(or $(VERSION),$(DEFAULT_VERSION))
DOCKER_IMAGE ?= authorizerdev/authorizer:$(VERSION)
+# Full module test run. Storage provider tests honour TEST_DBS (defaults to all).
+# Integration tests and memory_store/db tests always use SQLite.
+# Redis memory_store tests run only when TEST_ENABLE_REDIS=1.
+GO_TEST_ALL := go test -p 1 -v ./...
+
.PHONY: all bootstrap build build-app build-dashboard build-local-image build-push-image trivy-scan
all: build build-app build-dashboard
@@ -39,51 +44,50 @@ clean:
dev:
go run main.go --database-type=sqlite --database-url=test.db --jwt-type=HS256 --jwt-secret=test --admin-secret=admin --client-id=123456 --client-secret=secret
-test: test-cleanup test-docker-up
- go clean --testcache && TEST_DBS="postgres" go test -p 1 -v ./...
- $(MAKE) test-cleanup
+test:
+ go clean --testcache && TEST_DBS="sqlite" $(GO_TEST_ALL)
test-postgres: test-cleanup-postgres
docker run -d --name authorizer_postgres -p 5434:5432 -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=postgres postgres
sleep 3
- go clean --testcache && TEST_DBS="postgres" go test -p 1 -v ./...
+ go clean --testcache && TEST_DBS="postgres" $(GO_TEST_ALL)
docker rm -vf authorizer_postgres
test-sqlite:
- go clean --testcache && TEST_DBS="sqlite" go test -p 1 -v ./...
+ go clean --testcache && TEST_DBS="sqlite" $(GO_TEST_ALL)
test-mongodb: test-cleanup-mongodb
docker run -d --name authorizer_mongodb_db -p 27017:27017 mongo:4.4.15
sleep 3
- go clean --testcache && TEST_DBS="mongodb" go test -p 1 -v ./...
+ go clean --testcache && TEST_DBS="mongodb" $(GO_TEST_ALL)
docker rm -vf authorizer_mongodb_db
test-scylladb: test-cleanup-scylladb
docker run -d --name authorizer_scylla_db -p 9042:9042 scylladb/scylla
sleep 15
- go clean --testcache && TEST_DBS="scylladb" go test -p 1 -v ./...
+ go clean --testcache && TEST_DBS="scylladb" $(GO_TEST_ALL)
docker rm -vf authorizer_scylla_db
test-arangodb: test-cleanup-arangodb
docker run -d --name authorizer_arangodb -p 8529:8529 -e ARANGO_NO_AUTH=1 arangodb/arangodb:3.10.3
sleep 5
- go clean --testcache && TEST_DBS="arangodb" go test -p 1 -v ./...
+ go clean --testcache && TEST_DBS="arangodb" $(GO_TEST_ALL)
docker rm -vf authorizer_arangodb
test-dynamodb: test-cleanup-dynamodb
docker run -d --name authorizer_dynamodb -p 8000:8000 amazon/dynamodb-local:latest
sleep 3
- go clean --testcache && TEST_DBS="dynamodb" go test -p 1 -v ./...
+ go clean --testcache && TEST_DBS="dynamodb" $(GO_TEST_ALL)
docker rm -vf authorizer_dynamodb
test-couchbase: test-cleanup-couchbase
docker run -d --name authorizer_couchbase -p 8091-8097:8091-8097 -p 11210:11210 -p 11207:11207 -p 18091-18095:18091-18095 -p 18096:18096 -p 18097:18097 couchbase:latest
sh scripts/couchbase-test.sh
- go clean --testcache && TEST_DBS="couchbase" go test -p 1 -v ./...
+ go clean --testcache && TEST_DBS="couchbase" $(GO_TEST_ALL)
docker rm -vf authorizer_couchbase
-test-all-db: test-cleanup test-docker-up
- go clean --testcache && TEST_DBS="postgres,sqlite,mongodb,arangodb,scylladb,dynamodb,couchbase" go test -p 1 -v ./...
+test-all-db: test-cleanup test-docker-up test-cleanup
+ go clean --testcache && TEST_DBS="postgres,sqlite,mongodb,arangodb,scylladb,dynamodb,couchbase" $(GO_TEST_ALL)
$(MAKE) test-cleanup
# Start all test database containers
diff --git a/cmd/root.go b/cmd/root.go
index a00dcb006..3ac48f8db 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -55,7 +55,7 @@ var (
defaultTwitterScopes = []string{"tweet.read", "users.read"}
defaultRobloxScopes = []string{"openid", "profile"}
// Default RPS cap per IP; raised from 10 to reduce false positives on busy UIs.
- defaultRateLimitRPS = float64(30)
+ defaultRateLimitRPS = 30
defaultRateLimitBurst = 20
)
@@ -161,7 +161,7 @@ func init() {
f.BoolVar(&rootArgs.config.DisableAdminHeaderAuth, "disable-admin-header-auth", false, "Disable admin authentication via X-Authorizer-Admin-Secret header")
// Rate limiting flags
- f.Float64Var(&rootArgs.config.RateLimitRPS, "rate-limit-rps", defaultRateLimitRPS, "Maximum requests per second per IP for rate limiting")
+ f.IntVar(&rootArgs.config.RateLimitRPS, "rate-limit-rps", defaultRateLimitRPS, "Maximum requests per second per IP for rate limiting")
f.IntVar(&rootArgs.config.RateLimitBurst, "rate-limit-burst", defaultRateLimitBurst, "Maximum burst size per IP for rate limiting")
f.BoolVar(&rootArgs.config.RateLimitFailClosed, "rate-limit-fail-closed", false, "On rate-limit backend errors, reject with 503 instead of allowing the request")
diff --git a/docs/storage-optional-null-fields.md b/docs/storage-optional-null-fields.md
new file mode 100644
index 000000000..469dddfd0
--- /dev/null
+++ b/docs/storage-optional-null-fields.md
@@ -0,0 +1,46 @@
+# Optional fields and NULL semantics across storage providers
+
+This document explains how **nullable** fields (especially `*int64` timestamps such as `email_verified_at`, `phone_number_verified_at`, `revoked_timestamp`) are **stored**, **updated**, and why we **do not** add broad `json`/`bson` **`omitempty`** tags to those struct fields.
+
+## Goal
+
+Application code uses Go **nil pointers** to mean “unset / not verified / not revoked.” Updates must **clear** a previously set value when the pointer is **nil**, matching SQL **NULL** semantics.
+
+## Behaviour by provider
+
+| Provider | Typical update path | Nil pointer on update |
+|----------|---------------------|------------------------|
+| **SQL (GORM)** | `Save(&user)` | Written as **SQL NULL**. |
+| **Cassandra** | JSON → map → `UPDATE` | Nil map values become **`= null`** in CQL. |
+| **MongoDB** | `UpdateOne` with `$set` and the `User` struct | Driver marshals nil pointers as **BSON Null** when the field is **not** `omitempty`, so the field is cleared in the document. |
+| **Couchbase** | `Upsert` full document | `encoding/json` encodes nil pointers as JSON **`null`** unless the field uses `json:",omitempty"`, in which case the key is **omitted** and old values can persist. |
+| **ArangoDB** | `UpdateDocument` with struct | Encoding follows JSON-style rules; nil pointers become **`null`** when not omitted by tags. |
+| **DynamoDB** | `UpdateItem` with **SET** from marshalled attributes | Nil pointers are **omitted from SET** (see `internal/storage/db/dynamodb/marshal.go`). Attributes are **not** removed automatically, so **explicit REMOVE** is required to clear a previously stored attribute. Implemented for users in `internal/storage/db/dynamodb/user.go` (`updateByHashKeyWithRemoves`, `userDynamoRemoveAttrsIfNil`). Reads may normalize `0` → unset via `normalizeUserOptionalPtrs`. |
+
+## Why not use `omitempty` on `json` / `bson` for nullable auth fields?
+
+For **document** databases, **`omitempty`** means: *if this pointer is nil, do not include this key in the encoded payload.*
+
+During an **update**, omitting a key usually means **“do not change this field”**, not **“set to null.”** That reproduces the DynamoDB-class bug: the old value remains.
+
+Therefore, for fields where **nil must clear** stored state, keep **`json` / `bson` tags without `omitempty`** (as in `internal/storage/schemas/user.go`) unless every call site is proven to do a **full document replace** and you have verified the driver behaviour end-to-end.
+
+MongoDB’s own guidance aligns with this: `omitempty` skips marshaling empty values, which is wrong when you need to persist **null** to clear a field in `$set`.
+
+## DynamoDB specifics
+
+- **PutItem**: Omitting nil pointers keeps items small; optional attributes may be absent (same idea as “omitempty” on write, but implemented in custom `marshalStruct` by skipping nil pointers).
+- **UpdateItem**: Only **SET** attributes present in the marshalled map. Clearing requires **`REMOVE`** for the corresponding attribute names when the Go field is nil.
+- Do **not** rely on adding `dynamo:",omitempty"` alone to “fix” updates: the custom marshaller already skips nil pointers; the gap was **REMOVE** on update, not tag-based omission.
+
+## Related code
+
+- Schema: `internal/storage/schemas/user.go`
+- DynamoDB user update + REMOVE list: `internal/storage/db/dynamodb/user.go`
+- DynamoDB update helper: `internal/storage/db/dynamodb/ops.go` (`updateByHashKeyWithRemoves`)
+- DynamoDB marshal/unmarshal: `internal/storage/db/dynamodb/marshal.go`
+
+## TEST_DBS and memory_store tests
+
+- **`internal/memory_store/db`**: Runs one subtest per entry in `TEST_DBS` (URLs aligned with `internal/integration_tests/test_helper.go` — keep `test_config_test.go` in sync when adding backends).
+- **`internal/memory_store` (Redis / in-memory)**: Not driven by `TEST_DBS`. In-memory tests always run; **Redis** subtests run only when **`TEST_ENABLE_REDIS=1`** (or `true`) and Redis is reachable (e.g. `localhost:6380` in `provider_test.go`). See `redisMemoryStoreTestsEnabled` in `provider_test.go`.
diff --git a/go.mod b/go.mod
index 9d5314c61..8609a543b 100644
--- a/go.mod
+++ b/go.mod
@@ -5,17 +5,22 @@ go 1.25.5
require (
github.com/99designs/gqlgen v0.17.73
github.com/arangodb/go-driver v1.6.0
- github.com/aws/aws-sdk-go v1.47.4
- github.com/coreos/go-oidc/v3 v3.6.0
+ github.com/aws/aws-sdk-go-v2 v1.41.5
+ github.com/aws/aws-sdk-go-v2/config v1.32.14
+ github.com/aws/aws-sdk-go-v2/credentials v1.19.14
+ github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.37
+ github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.37
+ github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1
+ github.com/aws/smithy-go v1.24.3
+ github.com/coreos/go-oidc/v3 v3.17.0
github.com/couchbase/gocb/v2 v2.6.4
github.com/ekristen/gorm-libsql v0.0.0-20231101204708-6e113112bcc2
github.com/gin-gonic/gin v1.9.1
github.com/glebarez/sqlite v1.10.0
- github.com/go-jose/go-jose/v4 v4.1.3
+ github.com/go-jose/go-jose/v4 v4.1.4
github.com/gocql/gocql v1.6.0
github.com/golang-jwt/jwt/v4 v4.5.2
github.com/google/uuid v1.6.0
- github.com/guregu/dynamo v1.20.2
github.com/pquerna/otp v1.4.0
github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.6.3
@@ -41,10 +46,21 @@ require (
github.com/agnivade/levenshtein v1.2.1 // indirect
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230512164433-5d1fd1a340c9 // indirect
github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e // indirect
+ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
+ github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.14 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.21 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
+ github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
+ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect
- github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/couchbase/gocbcore/v10 v10.2.8 // indirect
@@ -55,7 +71,6 @@ require (
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
- github.com/go-jose/go-jose/v3 v3.0.4 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
@@ -77,7 +92,6 @@ require (
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
- github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
diff --git a/go.sum b/go.sum
index 347e8ed58..53c54f418 100644
--- a/go.sum
+++ b/go.sum
@@ -34,9 +34,44 @@ github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e h1:Xg+hGrY2
github.com/arangodb/go-velocypack v0.0.0-20200318135517-5af53c29c67e/go.mod h1:mq7Shfa/CaixoDxiyAAc5jZ6CVBAyPaNQCGS7mkj4Ho=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
-github.com/aws/aws-sdk-go v1.44.306/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI=
-github.com/aws/aws-sdk-go v1.47.4 h1:IyhNbmPt+5ldi5HNzv7ZnXiqSglDMaJiZlzj4Yq3qnk=
-github.com/aws/aws-sdk-go v1.47.4/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
+github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
+github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
+github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
+github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
+github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.37 h1:5jh3kI8vDKuAcNa87z3eytYvBCE4Tyk2S8vjdcLoMek=
+github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.37/go.mod h1:Q1MNQdT5LEs31od7h6zHZF2a6jjl+oI6/kBH3QYipoY=
+github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.37 h1:8clVLUTp0bBO7SZx5Nw0Q1XhF4ItE8W1hboDK9eHIN0=
+github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.37/go.mod h1:U0KNwNOuLUPi/vdIWh6qYVBK4v8zKjN6BPpbsHSmPPA=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
+github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1 h1:Vk+a1j2pXZHkkYqHmEdpwe8eX6NDtFSBGfzuauMEWYQ=
+github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1/go.mod h1:wHrWCwhXZrl2PuCP5t36UTacy9fCHDJ+vw1r3qxTL5M=
+github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.14 h1:Cnlebj/RmCf/4O3q4suVLLB/SBhbQf4zCQre6Dav+4E=
+github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.14/go.mod h1:lB9U9zBLviMTUHcHaaJ/vDBkRpHxV5775VJcdnm1DFk=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
+github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.21 h1:FTg+rVAPx1W21jsO57pxDS1ESy9a/JLFoaHeFubflJA=
+github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.21/go.mod h1:92xP4VIS1yO3eF2NPBaHGF4cmyZow8TmFzSaz1nNgzo=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
+github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
+github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
+github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
+github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
+github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
+github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
+github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
+github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
@@ -53,15 +88,13 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
-github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
-github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
-github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o=
-github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc=
+github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
+github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/couchbase/gocb/v2 v2.6.4 h1:o5k5JnxYkgamVL9svx+vbXc7vKF5X72tNt/qORs+L30=
github.com/couchbase/gocb/v2 v2.6.4/go.mod h1:W/cHlBGfendPh53WzRaF1KIXTovI9DaI7EPeeqIsnmc=
@@ -98,10 +131,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.10.0 h1:u4gt8y7OND/cCei/NMHmfbLxF6xP2wgKcT/BJf2pYkc=
github.com/glebarez/sqlite v1.10.0/go.mod h1:IJ+lfSOmiekhQsFTJRx/lHtGYmCdtAiTaf5wI9u5uHA=
-github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY=
-github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ=
-github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
-github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
+github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
+github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
@@ -153,7 +184,6 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -168,8 +198,6 @@ github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/z
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/guregu/dynamo v1.20.2 h1:b3DZX68Nv0kCWGbMMRbLukWttkALUoiomtzSrDCDiJo=
-github.com/guregu/dynamo v1.20.2/go.mod h1:rNSE8PT6IaNbcEno0/i0y6E5XFHDWUyLRxpGlF8O5CU=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
@@ -196,10 +224,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
-github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
-github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
-github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
-github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
@@ -350,7 +374,6 @@ golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
-golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA=
@@ -367,7 +390,6 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
-golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
@@ -381,7 +403,6 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -397,34 +418,28 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
-golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
-golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
-golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -460,7 +475,6 @@ gopkg.in/sourcemap.v1 v1.0.5/go.mod h1:2RlvNNSMglmRrcvhfuzp4hQHwOtjxlbjX7UPY/GXb
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
diff --git a/internal/config/config.go b/internal/config/config.go
index 77f947950..a43984e25 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -254,7 +254,7 @@ type Config struct {
// Rate Limiting
// RateLimitRPS is the maximum requests per second per IP
- RateLimitRPS float64
+ RateLimitRPS int
// RateLimitBurst is the maximum burst size per IP
RateLimitBurst int
// RateLimitFailClosed rejects requests when the rate limit backend errors (default: fail-open).
diff --git a/internal/integration_tests/audit_logs_test.go b/internal/integration_tests/audit_logs_test.go
deleted file mode 100644
index 8155d0513..000000000
--- a/internal/integration_tests/audit_logs_test.go
+++ /dev/null
@@ -1,243 +0,0 @@
-package integration_tests
-
-import (
- "testing"
- "time"
-
- "github.com/google/uuid"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- "github.com/authorizerdev/authorizer/internal/config"
- "github.com/authorizerdev/authorizer/internal/constants"
- "github.com/authorizerdev/authorizer/internal/graph/model"
- "github.com/authorizerdev/authorizer/internal/storage/schemas"
-)
-
-func TestAuditLogs(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- t.Run("should add and list audit logs", func(t *testing.T) {
- auditLog := &schemas.AuditLog{
- ActorID: uuid.New().String(),
- ActorType: constants.AuditActorTypeUser,
- ActorEmail: "test@example.com",
- Action: constants.AuditLoginSuccessEvent,
- ResourceType: constants.AuditResourceTypeSession,
- ResourceID: uuid.New().String(),
- IPAddress: "127.0.0.1",
- UserAgent: "test-agent",
- }
-
- err := ts.StorageProvider.AddAuditLog(ctx, auditLog)
- require.NoError(t, err)
- assert.NotEmpty(t, auditLog.ID)
- assert.NotZero(t, auditLog.CreatedAt)
-
- pagination := &model.Pagination{
- Limit: 10,
- Offset: 0,
- }
- logs, pag, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{})
- require.NoError(t, err)
- assert.NotNil(t, pag)
- assert.GreaterOrEqual(t, len(logs), 1)
- })
-
- t.Run("should filter audit logs by action", func(t *testing.T) {
- uniqueAction := "test_action_" + uuid.New().String()[:8]
-
- auditLog := &schemas.AuditLog{
- ActorID: uuid.New().String(),
- ActorType: constants.AuditActorTypeUser,
- ActorEmail: "filter@example.com",
- Action: uniqueAction,
- ResourceType: constants.AuditResourceTypeUser,
- }
- err := ts.StorageProvider.AddAuditLog(ctx, auditLog)
- require.NoError(t, err)
-
- pagination := &model.Pagination{
- Limit: 10,
- Offset: 0,
- }
- logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{
- "action": uniqueAction,
- })
- require.NoError(t, err)
- assert.Equal(t, 1, len(logs))
- assert.Equal(t, uniqueAction, logs[0].Action)
- })
-
- t.Run("should filter audit logs by actor_id", func(t *testing.T) {
- actorID := uuid.New().String()
-
- auditLog := &schemas.AuditLog{
- ActorID: actorID,
- ActorType: constants.AuditActorTypeAdmin,
- ActorEmail: "admin@example.com",
- Action: constants.AuditAdminUserUpdatedEvent,
- ResourceType: constants.AuditResourceTypeUser,
- }
- err := ts.StorageProvider.AddAuditLog(ctx, auditLog)
- require.NoError(t, err)
-
- pagination := &model.Pagination{
- Limit: 10,
- Offset: 0,
- }
- logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{
- "actor_id": actorID,
- })
- require.NoError(t, err)
- assert.Equal(t, 1, len(logs))
- assert.Equal(t, actorID, logs[0].ActorID)
- })
-
- t.Run("should filter audit logs by resource_type", func(t *testing.T) {
- uniqueAction := "res_filter_" + uuid.New().String()[:8]
-
- auditLog := &schemas.AuditLog{
- ActorID: uuid.New().String(),
- ActorType: constants.AuditActorTypeAdmin,
- Action: uniqueAction,
- ResourceType: constants.AuditResourceTypeWebhook,
- ResourceID: uuid.New().String(),
- }
- err := ts.StorageProvider.AddAuditLog(ctx, auditLog)
- require.NoError(t, err)
-
- pagination := &model.Pagination{
- Limit: 10,
- Offset: 0,
- }
- logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{
- "resource_type": constants.AuditResourceTypeWebhook,
- "action": uniqueAction,
- })
- require.NoError(t, err)
- assert.Equal(t, 1, len(logs))
- assert.Equal(t, constants.AuditResourceTypeWebhook, logs[0].ResourceType)
- })
-
- t.Run("should filter audit logs by timestamp range", func(t *testing.T) {
- uniqueAction := "ts_filter_" + uuid.New().String()[:8]
- now := time.Now().Unix()
-
- auditLog := &schemas.AuditLog{
- ActorID: uuid.New().String(),
- ActorType: constants.AuditActorTypeUser,
- Action: uniqueAction,
- ResourceType: constants.AuditResourceTypeSession,
- }
- err := ts.StorageProvider.AddAuditLog(ctx, auditLog)
- require.NoError(t, err)
-
- pagination := &model.Pagination{
- Limit: 10,
- Offset: 0,
- }
- // Filter from 1 second ago to now+1
- fromTs := now - 1
- toTs := now + 1
- logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{
- "action": uniqueAction,
- "from_timestamp": fromTs,
- "to_timestamp": toTs,
- })
- require.NoError(t, err)
- assert.Equal(t, 1, len(logs))
-
- // Filter with future range should return no results
- futureTs := now + 3600
- logs, _, err = ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{
- "action": uniqueAction,
- "from_timestamp": futureTs,
- })
- require.NoError(t, err)
- assert.Equal(t, 0, len(logs))
- })
-
- t.Run("should not mutate caller pagination", func(t *testing.T) {
- pagination := &model.Pagination{
- Limit: 10,
- Offset: 0,
- }
- _, returnedPag, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{})
- require.NoError(t, err)
- assert.NotSame(t, pagination, returnedPag, "should return a new pagination object")
- })
-
- t.Run("should delete audit logs before created_at", func(t *testing.T) {
- uniqueAction := "cleanup_test_" + uuid.New().String()[:8]
-
- oldLog := &schemas.AuditLog{
- ActorID: uuid.New().String(),
- ActorType: constants.AuditActorTypeUser,
- ActorEmail: "system@example.com",
- Action: uniqueAction,
- CreatedAt: time.Now().Add(-24 * time.Hour).Unix(),
- ResourceType: constants.AuditResourceTypeUser,
- }
- err := ts.StorageProvider.AddAuditLog(ctx, oldLog)
- require.NoError(t, err)
-
- // Delete logs older than 1 hour ago
- before := time.Now().Add(-1 * time.Hour).Unix()
- err = ts.StorageProvider.DeleteAuditLogsBefore(ctx, before)
- require.NoError(t, err)
-
- // Verify the old log is deleted by filtering for its unique action
- pagination := &model.Pagination{
- Limit: 10,
- Offset: 0,
- }
- logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{
- "action": uniqueAction,
- })
- require.NoError(t, err)
- assert.Equal(t, 0, len(logs))
- })
-
- t.Run("should preserve all fields in audit log round-trip", func(t *testing.T) {
- actorID := uuid.New().String()
- resourceID := uuid.New().String()
- uniqueAction := "roundtrip_" + uuid.New().String()[:8]
-
- auditLog := &schemas.AuditLog{
- ActorID: actorID,
- ActorType: constants.AuditActorTypeAdmin,
- ActorEmail: "admin@test.com",
- Action: uniqueAction,
- ResourceType: constants.AuditResourceTypeEmailTemplate,
- ResourceID: resourceID,
- IPAddress: "192.168.1.1",
- UserAgent: "Mozilla/5.0 Test",
- Metadata: `{"key":"value"}`,
- }
- err := ts.StorageProvider.AddAuditLog(ctx, auditLog)
- require.NoError(t, err)
-
- pagination := &model.Pagination{Limit: 10, Offset: 0}
- logs, _, err := ts.StorageProvider.ListAuditLogs(ctx, pagination, map[string]interface{}{
- "action": uniqueAction,
- })
- require.NoError(t, err)
- require.Len(t, logs, 1)
-
- got := logs[0]
- assert.Equal(t, actorID, got.ActorID)
- assert.Equal(t, constants.AuditActorTypeAdmin, got.ActorType)
- assert.Equal(t, "admin@test.com", got.ActorEmail)
- assert.Equal(t, uniqueAction, got.Action)
- assert.Equal(t, constants.AuditResourceTypeEmailTemplate, got.ResourceType)
- assert.Equal(t, resourceID, got.ResourceID)
- assert.Equal(t, "192.168.1.1", got.IPAddress)
- assert.Equal(t, "Mozilla/5.0 Test", got.UserAgent)
- assert.Equal(t, `{"key":"value"}`, got.Metadata)
- assert.NotZero(t, got.CreatedAt)
- })
- })
-}
diff --git a/internal/integration_tests/custom_access_token_script_test.go b/internal/integration_tests/custom_access_token_script_test.go
index 80a58fbf0..9b87dc734 100644
--- a/internal/integration_tests/custom_access_token_script_test.go
+++ b/internal/integration_tests/custom_access_token_script_test.go
@@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/authorizerdev/authorizer/internal/config"
"github.com/authorizerdev/authorizer/internal/graph/model"
)
@@ -32,269 +31,268 @@ func parseTestJWTClaims(t *testing.T, tokenString string) jwt.MapClaims {
// TestCustomAccessTokenScript tests the custom access token script functionality
// including the 5-second execution timeout added for DoS protection.
func TestCustomAccessTokenScript(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- t.Run("should_add_custom_claims_from_script", func(t *testing.T) {
- cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
- return { custom_claim: "hello", user_email: user.email };
- }`
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "custom_script_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
-
- loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
- Email: &email,
- Password: password,
- })
- require.NoError(t, err)
- require.NotNil(t, loginRes)
- require.NotNil(t, loginRes.AccessToken)
-
- // Parse the access token and verify custom claims are present
- claims := parseTestJWTClaims(t, *loginRes.AccessToken)
- assert.Equal(t, "hello", claims["custom_claim"])
- assert.Equal(t, email, claims["user_email"])
+ cfg := getTestConfig()
+
+ t.Run("should_add_custom_claims_from_script", func(t *testing.T) {
+ cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
+ return { custom_claim: "hello", user_email: user.email };
+ }`
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ email := "custom_script_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+
+ _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
+ })
+ require.NoError(t, err)
+
+ loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
+ Email: &email,
+ Password: password,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, loginRes)
+ require.NotNil(t, loginRes.AccessToken)
+
+ // Parse the access token and verify custom claims are present
+ claims := parseTestJWTClaims(t, *loginRes.AccessToken)
+ assert.Equal(t, "hello", claims["custom_claim"])
+ assert.Equal(t, email, claims["user_email"])
+ })
+
+ t.Run("should_not_override_reserved_claims", func(t *testing.T) {
+ cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
+ return { sub: "hacked", iss: "hacked", roles: ["admin"], custom_field: "allowed" };
+ }`
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ email := "reserved_claims_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+
+ _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
})
+ require.NoError(t, err)
- t.Run("should_not_override_reserved_claims", func(t *testing.T) {
- cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
- return { sub: "hacked", iss: "hacked", roles: ["admin"], custom_field: "allowed" };
- }`
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "reserved_claims_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
-
- loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
- Email: &email,
- Password: password,
- })
- require.NoError(t, err)
- require.NotNil(t, loginRes)
-
- claims := parseTestJWTClaims(t, *loginRes.AccessToken)
- // Reserved claims must NOT be overridden
- assert.NotEqual(t, "hacked", claims["sub"])
- assert.NotEqual(t, "hacked", claims["iss"])
- // Roles should NOT be overridden to admin
- roles, ok := claims["roles"].([]interface{})
- if ok {
- for _, r := range roles {
- assert.NotEqual(t, "admin", r, "reserved 'roles' claim must not be overridden by script")
- }
+ loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
+ Email: &email,
+ Password: password,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, loginRes)
+
+ claims := parseTestJWTClaims(t, *loginRes.AccessToken)
+ // Reserved claims must NOT be overridden
+ assert.NotEqual(t, "hacked", claims["sub"])
+ assert.NotEqual(t, "hacked", claims["iss"])
+ // Roles should NOT be overridden to admin
+ roles, ok := claims["roles"].([]interface{})
+ if ok {
+ for _, r := range roles {
+ assert.NotEqual(t, "admin", r, "reserved 'roles' claim must not be overridden by script")
}
- // Custom (non-reserved) claims should be added
- assert.Equal(t, "allowed", claims["custom_field"])
+ }
+ // Custom (non-reserved) claims should be added
+ assert.Equal(t, "allowed", claims["custom_field"])
+ })
+
+ t.Run("should_timeout_infinite_loop_script", func(t *testing.T) {
+ cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
+ while(true) {} // infinite loop — should be killed after 5 seconds
+ return { never: "reached" };
+ }`
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ email := "timeout_script_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+
+ _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
})
+ require.NoError(t, err)
+
+ // Measure execution time to verify the timeout works
+ start := time.Now()
- t.Run("should_timeout_infinite_loop_script", func(t *testing.T) {
- cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
- while(true) {} // infinite loop — should be killed after 5 seconds
- return { never: "reached" };
- }`
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "timeout_script_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
-
- // Measure execution time to verify the timeout works
- start := time.Now()
-
- // Login should still succeed — the timeout is handled gracefully,
- // custom claims are skipped but the token is still created.
- loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
- Email: &email,
- Password: password,
- })
- elapsed := time.Since(start)
-
- require.NoError(t, err)
- require.NotNil(t, loginRes)
- require.NotNil(t, loginRes.AccessToken)
-
- // The token should be valid but without the custom claim from the timed-out script
- claims := parseTestJWTClaims(t, *loginRes.AccessToken)
- assert.Nil(t, claims["never"], "timed-out script claims must not appear in token")
- // Standard claims should still be present
- assert.NotEmpty(t, claims["sub"])
- assert.NotEmpty(t, claims["iss"])
-
- // Verify the timeout kicked in within a reasonable range (5-8 seconds for the access + id token)
- assert.Less(t, elapsed, 20*time.Second, "login with infinite loop script should complete within 20 seconds (two 5s timeouts + overhead)")
+ // Login should still succeed — the timeout is handled gracefully,
+ // custom claims are skipped but the token is still created.
+ loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
+ Email: &email,
+ Password: password,
})
+ elapsed := time.Since(start)
+
+ require.NoError(t, err)
+ require.NotNil(t, loginRes)
+ require.NotNil(t, loginRes.AccessToken)
+
+ // The token should be valid but without the custom claim from the timed-out script
+ claims := parseTestJWTClaims(t, *loginRes.AccessToken)
+ assert.Nil(t, claims["never"], "timed-out script claims must not appear in token")
+ // Standard claims should still be present
+ assert.NotEmpty(t, claims["sub"])
+ assert.NotEmpty(t, claims["iss"])
+
+ // Verify the timeout kicked in within a reasonable range (5-8 seconds for the access + id token)
+ assert.Less(t, elapsed, 20*time.Second, "login with infinite loop script should complete within 20 seconds (two 5s timeouts + overhead)")
+ })
+
+ t.Run("should_handle_script_error_gracefully", func(t *testing.T) {
+ cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
+ throw new Error("intentional error");
+ }`
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
- t.Run("should_handle_script_error_gracefully", func(t *testing.T) {
- cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
- throw new Error("intentional error");
- }`
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "error_script_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
-
- // Login should still succeed even with a broken script
- loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
- Email: &email,
- Password: password,
- })
- require.NoError(t, err)
- require.NotNil(t, loginRes)
- require.NotNil(t, loginRes.AccessToken)
+ email := "error_script_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+
+ _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
})
+ require.NoError(t, err)
- t.Run("should_work_without_custom_script", func(t *testing.T) {
- cfg.CustomAccessTokenScript = ""
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "no_script_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
-
- loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
- Email: &email,
- Password: password,
- })
- require.NoError(t, err)
- require.NotNil(t, loginRes)
- require.NotNil(t, loginRes.AccessToken)
-
- claims := parseTestJWTClaims(t, *loginRes.AccessToken)
- assert.NotEmpty(t, claims["sub"])
- // Ensure no unexpected claims were added
- assert.Nil(t, claims["custom_claim"])
+ // Login should still succeed even with a broken script
+ loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
+ Email: &email,
+ Password: password,
})
+ require.NoError(t, err)
+ require.NotNil(t, loginRes)
+ require.NotNil(t, loginRes.AccessToken)
+ })
- t.Run("should_have_custom_claims_in_id_token_too", func(t *testing.T) {
- cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
- return { team: "engineering" };
- }`
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "id_token_script_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
-
- loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
- Email: &email,
- Password: password,
- })
- require.NoError(t, err)
- require.NotNil(t, loginRes)
- require.NotNil(t, loginRes.IDToken)
-
- // The custom script runs for both access token and ID token
- claims := parseTestJWTClaims(t, *loginRes.IDToken)
- assert.Equal(t, "engineering", claims["team"])
+ t.Run("should_work_without_custom_script", func(t *testing.T) {
+ cfg.CustomAccessTokenScript = ""
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ email := "no_script_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+
+ _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
})
+ require.NoError(t, err)
+
+ loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
+ Email: &email,
+ Password: password,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, loginRes)
+ require.NotNil(t, loginRes.AccessToken)
+
+ claims := parseTestJWTClaims(t, *loginRes.AccessToken)
+ assert.NotEmpty(t, claims["sub"])
+ // Ensure no unexpected claims were added
+ assert.Nil(t, claims["custom_claim"])
})
-}
-// TestClientIDMismatchMetric verifies that client ID mismatch records a security metric.
-func TestClientIDMismatchMetric(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
+ t.Run("should_have_custom_claims_in_id_token_too", func(t *testing.T) {
+ cfg.CustomAccessTokenScript = `function(user, tokenPayload) {
+ return { team: "engineering" };
+ }`
ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
- router := setupTestRouter(ts)
+ email := "id_token_script_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
- t.Run("records_metric_on_client_id_mismatch", func(t *testing.T) {
- // Send request with wrong client ID to /graphql (not dashboard/app)
- body := `{"query":"{ meta { version } }"}`
- w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{
- "Content-Type": "application/json",
- "X-Authorizer-Client-ID": "wrong-client-id",
- "X-Authorizer-URL": "http://localhost:8080",
- "Origin": "http://localhost:3000",
- })
+ _, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
+ })
+ require.NoError(t, err)
+
+ loginRes, err := ts.GraphQLProvider.Login(ctx, &model.LoginRequest{
+ Email: &email,
+ Password: password,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, loginRes)
+ require.NotNil(t, loginRes.IDToken)
- assert.Equal(t, 400, w.Code)
- assert.Contains(t, w.Body.String(), "invalid_client_id")
+ // The custom script runs for both access token and ID token
+ claims := parseTestJWTClaims(t, *loginRes.IDToken)
+ assert.Equal(t, "engineering", claims["team"])
+ })
+}
- // Check that the security metric was recorded
- metricsBody := getMetricsBody(t, router)
- assert.Contains(t, metricsBody, `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}`)
+// TestClientIDMismatchMetric verifies that client ID mismatch records a security metric.
+func TestClientIDMismatchMetric(t *testing.T) {
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
+
+ router := setupTestRouter(ts)
+
+ t.Run("records_metric_on_client_id_mismatch", func(t *testing.T) {
+ // Send request with wrong client ID to /graphql (not dashboard/app)
+ body := `{"query":"{ meta { version } }"}`
+ w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{
+ "Content-Type": "application/json",
+ "X-Authorizer-Client-ID": "wrong-client-id",
+ "X-Authorizer-URL": "http://localhost:8080",
+ "Origin": "http://localhost:3000",
})
- t.Run("no_metric_for_valid_client_id", func(t *testing.T) {
- body := `{"query":"{ meta { version } }"}`
- w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{
- "Content-Type": "application/json",
- "X-Authorizer-Client-ID": cfg.ClientID,
- "X-Authorizer-URL": "http://localhost:8080",
- "Origin": "http://localhost:3000",
- })
-
- // Should not be 400
- assert.NotEqual(t, 400, w.Code)
+ assert.Equal(t, 400, w.Code)
+ assert.Contains(t, w.Body.String(), "invalid_client_id")
+
+ // Check that the security metric was recorded
+ metricsBody := getMetricsBody(t, router)
+ assert.Contains(t, metricsBody, `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}`)
+ })
+
+ t.Run("no_metric_for_valid_client_id", func(t *testing.T) {
+ body := `{"query":"{ meta { version } }"}`
+ w := sendTestRequest(t, router, "POST", "/graphql", body, map[string]string{
+ "Content-Type": "application/json",
+ "X-Authorizer-Client-ID": cfg.ClientID,
+ "X-Authorizer-URL": "http://localhost:8080",
+ "Origin": "http://localhost:3000",
})
- t.Run("no_metric_for_dashboard_path_mismatch", func(t *testing.T) {
- mark := `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}`
- before := prometheusCounterSample(t, getMetricsBody(t, router), mark)
- w := sendTestRequest(t, router, "GET", "/dashboard/", "", map[string]string{
- "X-Authorizer-Client-ID": "wrong-client-id",
- })
- assert.Equal(t, 400, w.Code)
- after := prometheusCounterSample(t, getMetricsBody(t, router), mark)
- assert.Equal(t, before, after, "dashboard path mismatch must not increment client_id_mismatch metric")
+ // Should not be 400
+ assert.NotEqual(t, 400, w.Code)
+ })
+
+ t.Run("no_metric_for_dashboard_path_mismatch", func(t *testing.T) {
+ mark := `authorizer_security_events_total{event="client_id_mismatch",reason="invalid_client_id"}`
+ before := prometheusCounterSample(t, getMetricsBody(t, router), mark)
+ w := sendTestRequest(t, router, "GET", "/dashboard/", "", map[string]string{
+ "X-Authorizer-Client-ID": "wrong-client-id",
})
+ assert.Equal(t, 400, w.Code)
+ after := prometheusCounterSample(t, getMetricsBody(t, router), mark)
+ assert.Equal(t, before, after, "dashboard path mismatch must not increment client_id_mismatch metric")
+ })
- t.Run("records_client_id_header_missing_metric", func(t *testing.T) {
- mark := "authorizer_client_id_header_missing_total"
- before := prometheusCounterSample(t, getMetricsBody(t, router), mark)
- sendTestRequest(t, router, "POST", "/graphql", `{"query":"{ meta { version } }"}`, map[string]string{
- "Content-Type": "application/json",
- "X-Authorizer-URL": "http://localhost:8080",
- "Origin": "http://localhost:3000",
- })
- after := prometheusCounterSample(t, getMetricsBody(t, router), mark)
- assert.Greater(t, after, before)
+ t.Run("records_client_id_header_missing_metric", func(t *testing.T) {
+ mark := "authorizer_client_id_header_missing_total"
+ before := prometheusCounterSample(t, getMetricsBody(t, router), mark)
+ sendTestRequest(t, router, "POST", "/graphql", `{"query":"{ meta { version } }"}`, map[string]string{
+ "Content-Type": "application/json",
+ "X-Authorizer-URL": "http://localhost:8080",
+ "Origin": "http://localhost:3000",
})
+ after := prometheusCounterSample(t, getMetricsBody(t, router), mark)
+ assert.Greater(t, after, before)
})
}
diff --git a/internal/integration_tests/health_test.go b/internal/integration_tests/health_test.go
index cdb171639..213d0df11 100644
--- a/internal/integration_tests/health_test.go
+++ b/internal/integration_tests/health_test.go
@@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/authorizerdev/authorizer/internal/config"
"github.com/authorizerdev/authorizer/internal/http_handlers"
"github.com/authorizerdev/authorizer/internal/storage"
)
@@ -29,108 +28,105 @@ func (*failingHealthStorage) HealthCheck(ctx context.Context) error {
// TestHealthHandler verifies the /healthz liveness probe endpoint behaviour.
func TestHealthHandler(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
- router := gin.New()
- router.GET("/healthz", ts.HttpProvider.HealthHandler())
+ router := gin.New()
+ router.GET("/healthz", ts.HttpProvider.HealthHandler())
- t.Run("returns_200_when_storage_is_healthy", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
- require.NoError(t, err)
+ t.Run("returns_200_when_storage_is_healthy", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
+ require.NoError(t, err)
- router.ServeHTTP(w, req)
+ router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code)
+ assert.Equal(t, http.StatusOK, w.Code)
- var body map[string]interface{}
- err = json.Unmarshal(w.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, "ok", body["status"], "healthy response must contain status=ok")
- })
+ var body map[string]interface{}
+ err = json.Unmarshal(w.Body.Bytes(), &body)
+ require.NoError(t, err)
+ assert.Equal(t, "ok", body["status"], "healthy response must contain status=ok")
})
}
// TestReadyHandler verifies the /readyz readiness probe endpoint behaviour.
func TestReadyHandler(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
- router := gin.New()
- router.GET("/readyz", ts.HttpProvider.ReadyHandler())
+ router := gin.New()
+ router.GET("/readyz", ts.HttpProvider.ReadyHandler())
- t.Run("returns_200_when_storage_is_ready", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/readyz", nil)
- require.NoError(t, err)
+ t.Run("returns_200_when_storage_is_ready", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/readyz", nil)
+ require.NoError(t, err)
- router.ServeHTTP(w, req)
+ router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code)
+ assert.Equal(t, http.StatusOK, w.Code)
- var body map[string]interface{}
- err = json.Unmarshal(w.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, "ready", body["status"], "readiness response must contain status=ready")
- })
+ var body map[string]interface{}
+ err = json.Unmarshal(w.Body.Bytes(), &body)
+ require.NoError(t, err)
+ assert.Equal(t, "ready", body["status"], "readiness response must contain status=ready")
})
}
// TestHealthHandlersUnhealthyStorage verifies liveness/readiness and DB metrics when HealthCheck fails.
func TestHealthHandlersUnhealthyStorage(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger()
- realStorage, err := storage.New(cfg, &storage.Dependencies{Log: &logger})
+ cfg := getTestConfig()
+ logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger()
+ realStorage, err := storage.New(cfg, &storage.Dependencies{Log: &logger})
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = realStorage.Close() })
+
+ wrapped := &failingHealthStorage{Provider: realStorage}
+ httpProv, err := http_handlers.New(cfg, &http_handlers.Dependencies{
+ Log: &logger,
+ StorageProvider: wrapped,
+ })
+ require.NoError(t, err)
+
+ router := gin.New()
+ router.GET("/healthz", httpProv.HealthHandler())
+ router.GET("/readyz", httpProv.ReadyHandler())
+ router.GET("/metrics", httpProv.MetricsHandler())
+
+ t.Run("healthz_returns_503", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusServiceUnavailable, w.Code)
+ var body map[string]interface{}
+ require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
+ assert.Equal(t, "unhealthy", body["status"])
+ })
+
+ t.Run("readyz_returns_503", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/readyz", nil)
require.NoError(t, err)
- t.Cleanup(func() { _ = realStorage.Close() })
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusServiceUnavailable, w.Code)
+ var body map[string]interface{}
+ require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
+ assert.Equal(t, "not ready", body["status"])
+ })
- wrapped := &failingHealthStorage{Provider: realStorage}
- httpProv, err := http_handlers.New(cfg, &http_handlers.Dependencies{
- Log: &logger,
- StorageProvider: wrapped,
- })
+ t.Run("records_unhealthy_db_check_metric", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
require.NoError(t, err)
+ router.ServeHTTP(w, req)
+ require.Equal(t, http.StatusServiceUnavailable, w.Code)
- router := gin.New()
- router.GET("/healthz", httpProv.HealthHandler())
- router.GET("/readyz", httpProv.ReadyHandler())
- router.GET("/metrics", httpProv.MetricsHandler())
-
- t.Run("healthz_returns_503", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusServiceUnavailable, w.Code)
- var body map[string]interface{}
- require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
- assert.Equal(t, "unhealthy", body["status"])
- })
-
- t.Run("readyz_returns_503", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/readyz", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusServiceUnavailable, w.Code)
- var body map[string]interface{}
- require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
- assert.Equal(t, "not ready", body["status"])
- })
-
- t.Run("records_unhealthy_db_check_metric", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
- require.Equal(t, http.StatusServiceUnavailable, w.Code)
-
- w2 := httptest.NewRecorder()
- req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w2, req2)
- assert.Contains(t, w2.Body.String(), `authorizer_db_health_check_total{status="unhealthy"}`)
- })
+ w2 := httptest.NewRecorder()
+ req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w2, req2)
+ assert.Contains(t, w2.Body.String(), `authorizer_db_health_check_total{status="unhealthy"}`)
})
}
diff --git a/internal/integration_tests/metrics_test.go b/internal/integration_tests/metrics_test.go
index c3672ff59..e3ed4b1f7 100644
--- a/internal/integration_tests/metrics_test.go
+++ b/internal/integration_tests/metrics_test.go
@@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/authorizerdev/authorizer/internal/config"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/metrics"
"github.com/authorizerdev/authorizer/internal/refs"
@@ -21,263 +20,231 @@ import (
// TestMetricsEndpoint verifies the /metrics endpoint serves Prometheus format.
func TestMetricsEndpoint(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
-
- router := gin.New()
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-
- t.Run("returns_200_with_prometheus_format", func(t *testing.T) {
- // Trigger some metrics so they appear in output
- metrics.RecordAuthEvent("test", "test")
- metrics.RecordSecurityEvent("test", "test")
- metrics.RecordGraphQLError("test")
- metrics.RecordClientIDHeaderMissing()
- metrics.DBHealthCheckTotal.WithLabelValues("test").Inc()
-
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
-
- router.ServeHTTP(w, req)
-
- assert.Equal(t, http.StatusOK, w.Code)
- body := w.Body.String()
- // Gauge metrics always appear
- assert.Contains(t, body, "authorizer_active_sessions")
- // Counter/histogram metrics appear after first increment
- assert.Contains(t, body, "authorizer_auth_events_total")
- assert.Contains(t, body, "authorizer_security_events_total")
- assert.Contains(t, body, "authorizer_graphql_errors_total")
- assert.Contains(t, body, "authorizer_db_health_check_total")
- assert.Contains(t, body, "authorizer_client_id_header_missing_total")
- })
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
- t.Run("post_metrics_is_not_get_ok", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodPost, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
- assert.NotEqual(t, http.StatusOK, w.Code)
- })
+ router := gin.New()
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
+
+ t.Run("returns_200_with_prometheus_format", func(t *testing.T) {
+ // Trigger some metrics so they appear in output
+ metrics.RecordAuthEvent("test", "test")
+ metrics.RecordSecurityEvent("test", "test")
+ metrics.RecordGraphQLError("test")
+ metrics.RecordClientIDHeaderMissing()
+ metrics.DBHealthCheckTotal.WithLabelValues("test").Inc()
+
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+
+ router.ServeHTTP(w, req)
+
+ assert.Equal(t, http.StatusOK, w.Code)
+ body := w.Body.String()
+ // Gauge metrics always appear
+ assert.Contains(t, body, "authorizer_active_sessions")
+ // Counter/histogram metrics appear after first increment
+ assert.Contains(t, body, "authorizer_auth_events_total")
+ assert.Contains(t, body, "authorizer_security_events_total")
+ assert.Contains(t, body, "authorizer_graphql_errors_total")
+ assert.Contains(t, body, "authorizer_db_health_check_total")
+ assert.Contains(t, body, "authorizer_client_id_header_missing_total")
+ })
+
+ t.Run("post_metrics_is_not_get_ok", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodPost, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+ assert.NotEqual(t, http.StatusOK, w.Code)
})
}
// TestMetricsMiddleware verifies the HTTP metrics middleware records request count.
func TestMetricsMiddleware(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
-
- router := gin.New()
- router.Use(ts.HttpProvider.MetricsMiddleware())
- router.GET("/healthz", ts.HttpProvider.HealthHandler())
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-
- t.Run("records_http_request_metrics", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code)
-
- // Check metrics endpoint has recorded it
- w2 := httptest.NewRecorder()
- req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w2, req2)
-
- body := w2.Body.String()
- assert.Contains(t, body, `authorizer_http_requests_total{method="GET",path="/healthz",status="200"}`)
- })
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
- t.Run("skips_http_metrics_for_excluded_paths", func(t *testing.T) {
- router.GET("/app/foo", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
+ router := gin.New()
+ router.Use(ts.HttpProvider.MetricsMiddleware())
+ router.GET("/healthz", ts.HttpProvider.HealthHandler())
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/app/foo", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code)
+ t.Run("records_http_request_metrics", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusOK, w.Code)
- w2 := httptest.NewRecorder()
- req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w2, req2)
+ // Check metrics endpoint has recorded it
+ w2 := httptest.NewRecorder()
+ req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w2, req2)
+
+ body := w2.Body.String()
+ assert.Contains(t, body, `authorizer_http_requests_total{method="GET",path="/healthz",status="200"}`)
+ })
- body := w2.Body.String()
- assert.NotContains(t, body, `authorizer_http_requests_total{method="GET",path="/app/foo"`)
+ t.Run("skips_http_metrics_for_excluded_paths", func(t *testing.T) {
+ router.GET("/app/foo", func(c *gin.Context) {
+ c.Status(http.StatusOK)
})
+
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/app/foo", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusOK, w.Code)
+
+ w2 := httptest.NewRecorder()
+ req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w2, req2)
+
+ body := w2.Body.String()
+ assert.NotContains(t, body, `authorizer_http_requests_total{method="GET",path="/app/foo"`)
})
}
// TestDBHealthCheckMetrics verifies health check outcomes are tracked.
func TestDBHealthCheckMetrics(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
-
- router := gin.New()
- router.GET("/healthz", ts.HttpProvider.HealthHandler())
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-
- t.Run("records_healthy_db_check", func(t *testing.T) {
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
-
- var body map[string]interface{}
- err = json.Unmarshal(w.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, "ok", body["status"])
-
- // Verify metric was recorded
- w2 := httptest.NewRecorder()
- req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w2, req2)
-
- metricsBody := w2.Body.String()
- assert.Contains(t, metricsBody, `authorizer_db_health_check_total{status="healthy"}`)
- })
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
+
+ router := gin.New()
+ router.GET("/healthz", ts.HttpProvider.HealthHandler())
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
+
+ t.Run("records_healthy_db_check", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/healthz", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+
+ var body map[string]interface{}
+ err = json.Unmarshal(w.Body.Bytes(), &body)
+ require.NoError(t, err)
+ assert.Equal(t, "ok", body["status"])
+
+ // Verify metric was recorded
+ w2 := httptest.NewRecorder()
+ req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w2, req2)
+
+ metricsBody := w2.Body.String()
+ assert.Contains(t, metricsBody, `authorizer_db_health_check_total{status="healthy"}`)
})
}
// TestAuthEventMetrics verifies that auth events are recorded in metrics.
func TestAuthEventMetrics(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- router := gin.New()
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-
- email := "metrics_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- t.Run("records_signup_and_login_success", func(t *testing.T) {
- signupReq := &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- }
- res, err := ts.GraphQLProvider.SignUp(ctx, signupReq)
- require.NoError(t, err)
- assert.NotNil(t, res)
-
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
- assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="signup",status="success"}`)
-
- loginReq := &model.LoginRequest{
- Email: &email,
- Password: password,
- }
- loginRes, err := ts.GraphQLProvider.Login(ctx, loginReq)
- require.NoError(t, err)
- assert.NotNil(t, loginRes)
-
- w2 := httptest.NewRecorder()
- req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w2, req2)
- assert.Contains(t, w2.Body.String(), `authorizer_auth_events_total{event="login",status="success"}`)
- })
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ router := gin.New()
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
+
+ email := "metrics_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+
+ t.Run("records_signup_and_login_success", func(t *testing.T) {
+ signupReq := &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
+ }
+ res, err := ts.GraphQLProvider.SignUp(ctx, signupReq)
+ require.NoError(t, err)
+ assert.NotNil(t, res)
- t.Run("records_login_failure_on_bad_credentials", func(t *testing.T) {
- loginReq := &model.LoginRequest{
- Email: &email,
- Password: "wrong_password",
- }
- _, err := ts.GraphQLProvider.Login(ctx, loginReq)
- assert.Error(t, err)
-
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
-
- body := w.Body.String()
- assert.Contains(t, body, `authorizer_auth_events_total{event="login",status="failure"}`)
- assert.Contains(t, body, `authorizer_security_events_total{event="invalid_credentials"`)
- })
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+ assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="signup",status="success"}`)
+
+ loginReq := &model.LoginRequest{
+ Email: &email,
+ Password: password,
+ }
+ loginRes, err := ts.GraphQLProvider.Login(ctx, loginReq)
+ require.NoError(t, err)
+ assert.NotNil(t, loginRes)
+
+ w2 := httptest.NewRecorder()
+ req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w2, req2)
+ assert.Contains(t, w2.Body.String(), `authorizer_auth_events_total{event="login",status="success"}`)
+ })
+
+ t.Run("records_login_failure_on_bad_credentials", func(t *testing.T) {
+ loginReq := &model.LoginRequest{
+ Email: &email,
+ Password: "wrong_password",
+ }
+ _, err := ts.GraphQLProvider.Login(ctx, loginReq)
+ assert.Error(t, err)
+
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+
+ body := w.Body.String()
+ assert.Contains(t, body, `authorizer_auth_events_total{event="login",status="failure"}`)
+ assert.Contains(t, body, `authorizer_security_events_total{event="invalid_credentials"`)
})
}
// TestGraphQLErrorMetrics verifies that GraphQL errors in 200 responses are captured.
func TestGraphQLErrorMetrics(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
-
- router := gin.New()
- router.Use(ts.HttpProvider.ContextMiddleware())
- router.Use(ts.HttpProvider.CORSMiddleware())
- router.POST("/graphql", ts.HttpProvider.GraphqlHandler())
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-
- t.Run("captures_graphql_errors_in_200_responses", func(t *testing.T) {
- body := `{"query":"mutation { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}`
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body))
- require.NoError(t, err)
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("x-authorizer-url", "http://localhost:8080")
- req.Header.Set("Origin", "http://localhost:3000")
- router.ServeHTTP(w, req)
-
- // GraphQL always returns 200 even with errors
- assert.Equal(t, http.StatusOK, w.Code)
-
- // Check metrics endpoint
- w2 := httptest.NewRecorder()
- req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w2, req2)
-
- metricsBody := w2.Body.String()
- assert.Contains(t, metricsBody, "authorizer_graphql_request_duration_seconds")
- assert.Contains(t, metricsBody, `authorizer_graphql_errors_total{operation="anonymous"}`)
- })
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
- t.Run("captures_graphql_errors_with_named_operation", func(t *testing.T) {
- body := `{"operationName":"LoginOp","query":"mutation LoginOp { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}`
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body))
- require.NoError(t, err)
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("x-authorizer-url", "http://localhost:8080")
- req.Header.Set("Origin", "http://localhost:3000")
- router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code)
-
- w2 := httptest.NewRecorder()
- req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w2, req2)
- loginOpLabel := metrics.GraphQLOperationPrometheusLabel("LoginOp")
- assert.Contains(t, w2.Body.String(), `authorizer_graphql_errors_total{operation="`+loginOpLabel+`"}`)
- })
- })
-}
+ router := gin.New()
+ router.Use(ts.HttpProvider.ContextMiddleware())
+ router.Use(ts.HttpProvider.CORSMiddleware())
+ router.POST("/graphql", ts.HttpProvider.GraphqlHandler())
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-// TestClientIDHeaderMissingMiddlewareMetric verifies empty X-Authorizer-Client-ID increments the counter.
-func TestClientIDHeaderMissingMiddlewareMetric(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
+ t.Run("captures_graphql_errors_in_200_responses", func(t *testing.T) {
+ body := `{"query":"mutation { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}`
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body))
+ require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("x-authorizer-url", "http://localhost:8080")
+ req.Header.Set("Origin", "http://localhost:3000")
+ router.ServeHTTP(w, req)
- router := gin.New()
- router.Use(ts.HttpProvider.ClientCheckMiddleware())
- router.GET("/probe", func(c *gin.Context) {
- c.Status(http.StatusOK)
- })
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
+ // GraphQL always returns 200 even with errors
+ assert.Equal(t, http.StatusOK, w.Code)
+ // Check metrics endpoint
+ w2 := httptest.NewRecorder()
+ req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w2, req2)
+
+ metricsBody := w2.Body.String()
+ assert.Contains(t, metricsBody, "authorizer_graphql_request_duration_seconds")
+ assert.Contains(t, metricsBody, `authorizer_graphql_errors_total{operation="anonymous"}`)
+ })
+
+ t.Run("captures_graphql_errors_with_named_operation", func(t *testing.T) {
+ body := `{"operationName":"LoginOp","query":"mutation LoginOp { login(params: {email: \"nonexistent@test.com\", password: \"wrong\"}) { message } }"}`
w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/probe", nil)
+ req, err := http.NewRequest(http.MethodPost, "/graphql", strings.NewReader(body))
require.NoError(t, err)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("x-authorizer-url", "http://localhost:8080")
+ req.Header.Set("Origin", "http://localhost:3000")
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
@@ -285,8 +252,34 @@ func TestClientIDHeaderMissingMiddlewareMetric(t *testing.T) {
req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
require.NoError(t, err)
router.ServeHTTP(w2, req2)
- assert.Contains(t, w2.Body.String(), "authorizer_client_id_header_missing_total")
+ loginOpLabel := metrics.GraphQLOperationPrometheusLabel("LoginOp")
+ assert.Contains(t, w2.Body.String(), `authorizer_graphql_errors_total{operation="`+loginOpLabel+`"}`)
+ })
+}
+
+// TestClientIDHeaderMissingMiddlewareMetric verifies empty X-Authorizer-Client-ID increments the counter.
+func TestClientIDHeaderMissingMiddlewareMetric(t *testing.T) {
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
+
+ router := gin.New()
+ router.Use(ts.HttpProvider.ClientCheckMiddleware())
+ router.GET("/probe", func(c *gin.Context) {
+ c.Status(http.StatusOK)
})
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
+
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/probe", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusOK, w.Code)
+
+ w2 := httptest.NewRecorder()
+ req2, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w2, req2)
+ assert.Contains(t, w2.Body.String(), "authorizer_client_id_header_missing_total")
}
// TestRecordAuthEventHelpers verifies the helper functions work correctly.
@@ -309,71 +302,69 @@ func TestRecordAuthEventHelpers(t *testing.T) {
// TestAdminLoginMetrics verifies admin login records metrics.
func TestAdminLoginMetrics(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- router := gin.New()
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-
- t.Run("records_admin_login_failure", func(t *testing.T) {
- loginReq := &model.AdminLoginRequest{
- AdminSecret: "wrong-secret",
- }
- _, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq)
- assert.Error(t, err)
-
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
-
- body := w.Body.String()
- assert.Contains(t, body, `authorizer_auth_events_total{event="admin_login",status="failure"}`)
- assert.Contains(t, body, `authorizer_security_events_total{event="invalid_admin_secret"`)
- })
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
- t.Run("records_admin_login_success", func(t *testing.T) {
- loginReq := &model.AdminLoginRequest{
- AdminSecret: cfg.AdminSecret,
- }
- res, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq)
- require.NoError(t, err)
- assert.NotNil(t, res)
+ router := gin.New()
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
+ t.Run("records_admin_login_failure", func(t *testing.T) {
+ loginReq := &model.AdminLoginRequest{
+ AdminSecret: "wrong-secret",
+ }
+ _, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq)
+ assert.Error(t, err)
- assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="admin_login",status="success"}`)
- })
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+
+ body := w.Body.String()
+ assert.Contains(t, body, `authorizer_auth_events_total{event="admin_login",status="failure"}`)
+ assert.Contains(t, body, `authorizer_security_events_total{event="invalid_admin_secret"`)
+ })
+
+ t.Run("records_admin_login_success", func(t *testing.T) {
+ loginReq := &model.AdminLoginRequest{
+ AdminSecret: cfg.AdminSecret,
+ }
+ res, err := ts.GraphQLProvider.AdminLogin(ctx, loginReq)
+ require.NoError(t, err)
+ assert.NotNil(t, res)
+
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+
+ assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="admin_login",status="success"}`)
})
}
// TestForgotPasswordMetrics verifies forgot password records metrics.
func TestForgotPasswordMetrics(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- router := gin.New()
- router.GET("/metrics", ts.HttpProvider.MetricsHandler())
-
- t.Run("records_forgot_password_failure_for_nonexistent_user", func(t *testing.T) {
- nonExistentEmail := "nonexistent_metrics@authorizer.dev"
- forgotReq := &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(nonExistentEmail),
- }
- _, err := ts.GraphQLProvider.ForgotPassword(ctx, forgotReq)
- assert.Error(t, err)
-
- w := httptest.NewRecorder()
- req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
- require.NoError(t, err)
- router.ServeHTTP(w, req)
-
- assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="forgot_password",status="failure"}`)
- })
+ cfg := getTestConfig()
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ router := gin.New()
+ router.GET("/metrics", ts.HttpProvider.MetricsHandler())
+
+ t.Run("records_forgot_password_failure_for_nonexistent_user", func(t *testing.T) {
+ nonExistentEmail := "nonexistent_metrics@authorizer.dev"
+ forgotReq := &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(nonExistentEmail),
+ }
+ _, err := ts.GraphQLProvider.ForgotPassword(ctx, forgotReq)
+ assert.Error(t, err)
+
+ w := httptest.NewRecorder()
+ req, err := http.NewRequest(http.MethodGet, "/metrics", nil)
+ require.NoError(t, err)
+ router.ServeHTTP(w, req)
+
+ assert.Contains(t, w.Body.String(), `authorizer_auth_events_total{event="forgot_password",status="failure"}`)
})
}
diff --git a/internal/integration_tests/rate_limit_test.go b/internal/integration_tests/rate_limit_test.go
index 968e1ff15..ad2612143 100644
--- a/internal/integration_tests/rate_limit_test.go
+++ b/internal/integration_tests/rate_limit_test.go
@@ -16,132 +16,131 @@ import (
)
func TestRateLimitMiddleware(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- // Set low rate limit for testing
- cfg.RateLimitRPS = 5
- cfg.RateLimitBurst = 5
-
- ts := initTestSetup(t, cfg)
-
- t.Run("should_allow_requests_within_limit", func(t *testing.T) {
- w := httptest.NewRecorder()
- _, router := gin.CreateTestContext(w)
- router.Use(ts.HttpProvider.RateLimitMiddleware())
- router.POST("/graphql", func(c *gin.Context) {
- c.JSON(http.StatusOK, gin.H{"data": "ok"})
- })
-
- // First request should succeed
+ cfg := getTestConfig()
+ // Set low rate limit for testing
+ cfg.RateLimitRPS = 5
+ cfg.RateLimitBurst = 5
+
+ ts := initTestSetup(t, cfg)
+
+ t.Run("should_allow_requests_within_limit", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ _, router := gin.CreateTestContext(w)
+ router.Use(ts.HttpProvider.RateLimitMiddleware())
+ router.POST("/graphql", func(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{"data": "ok"})
+ })
+
+ // First request should succeed
+ req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
+ require.NoError(t, err)
+ req.RemoteAddr = "192.168.1.1:1234"
+
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusOK, w.Code)
+ })
+
+ t.Run("should_reject_requests_over_limit", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ _, router := gin.CreateTestContext(w)
+ router.Use(ts.HttpProvider.RateLimitMiddleware())
+ router.POST("/graphql", func(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{"data": "ok"})
+ })
+
+ // Exhaust the burst
+ for i := 0; i < 5; i++ {
req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
require.NoError(t, err)
- req.RemoteAddr = "192.168.1.1:1234"
-
+ req.RemoteAddr = "10.0.0.1:1234"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
+ }
+
+ // Next request should be rejected
+ req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
+ require.NoError(t, err)
+ req.RemoteAddr = "10.0.0.1:1234"
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusTooManyRequests, w.Code)
+ assert.Contains(t, w.Header().Get("Retry-After"), "1")
+ })
+
+ t.Run("should_not_rate_limit_exempt_paths", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ _, router := gin.CreateTestContext(w)
+ router.Use(ts.HttpProvider.RateLimitMiddleware())
+ router.GET("/health", func(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
+ })
+ router.GET("/.well-known/openid-configuration", func(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{"issuer": "test"})
})
- t.Run("should_reject_requests_over_limit", func(t *testing.T) {
- w := httptest.NewRecorder()
- _, router := gin.CreateTestContext(w)
- router.Use(ts.HttpProvider.RateLimitMiddleware())
- router.POST("/graphql", func(c *gin.Context) {
- c.JSON(http.StatusOK, gin.H{"data": "ok"})
- })
-
- // Exhaust the burst
- for i := 0; i < 5; i++ {
- req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
+ exemptPaths := []string{"/health", "/.well-known/openid-configuration"}
+ for _, path := range exemptPaths {
+ // Make many requests - none should be limited
+ for i := 0; i < 10; i++ {
+ req, err := http.NewRequest(http.MethodGet, path, nil)
require.NoError(t, err)
- req.RemoteAddr = "10.0.0.1:1234"
+ req.RemoteAddr = "10.0.0.2:1234"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code)
+ assert.Equal(t, http.StatusOK, w.Code, "path %s request %d should not be rate limited", path, i)
}
+ }
+ })
- // Next request should be rejected
+ t.Run("should_isolate_rate_limits_per_ip", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ _, router := gin.CreateTestContext(w)
+ router.Use(ts.HttpProvider.RateLimitMiddleware())
+ router.POST("/graphql", func(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{"data": "ok"})
+ })
+
+ // Exhaust burst for IP A
+ for i := 0; i < 5; i++ {
req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
require.NoError(t, err)
- req.RemoteAddr = "10.0.0.1:1234"
+ req.RemoteAddr = "10.0.0.3:1234"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusTooManyRequests, w.Code)
- assert.Contains(t, w.Header().Get("Retry-After"), "1")
- })
+ }
- t.Run("should_not_rate_limit_exempt_paths", func(t *testing.T) {
- w := httptest.NewRecorder()
- _, router := gin.CreateTestContext(w)
- router.Use(ts.HttpProvider.RateLimitMiddleware())
- router.GET("/health", func(c *gin.Context) {
- c.JSON(http.StatusOK, gin.H{"status": "ok"})
- })
- router.GET("/.well-known/openid-configuration", func(c *gin.Context) {
- c.JSON(http.StatusOK, gin.H{"issuer": "test"})
- })
-
- exemptPaths := []string{"/health", "/.well-known/openid-configuration"}
- for _, path := range exemptPaths {
- // Make many requests - none should be limited
- for i := 0; i < 10; i++ {
- req, err := http.NewRequest(http.MethodGet, path, nil)
- require.NoError(t, err)
- req.RemoteAddr = "10.0.0.2:1234"
- w = httptest.NewRecorder()
- router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code, "path %s request %d should not be rate limited", path, i)
- }
- }
- })
+ // IP B should still be allowed
+ req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
+ require.NoError(t, err)
+ req.RemoteAddr = "10.0.0.4:1234"
+ w = httptest.NewRecorder()
+ router.ServeHTTP(w, req)
+ assert.Equal(t, http.StatusOK, w.Code)
+ })
- t.Run("should_isolate_rate_limits_per_ip", func(t *testing.T) {
- w := httptest.NewRecorder()
- _, router := gin.CreateTestContext(w)
- router.Use(ts.HttpProvider.RateLimitMiddleware())
- router.POST("/graphql", func(c *gin.Context) {
- c.JSON(http.StatusOK, gin.H{"data": "ok"})
- })
-
- // Exhaust burst for IP A
- for i := 0; i < 5; i++ {
- req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
- require.NoError(t, err)
- req.RemoteAddr = "10.0.0.3:1234"
- w = httptest.NewRecorder()
- router.ServeHTTP(w, req)
- }
+ t.Run("should_return_correct_error_format", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ _, router := gin.CreateTestContext(w)
+ router.Use(ts.HttpProvider.RateLimitMiddleware())
+ router.POST("/graphql", func(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{"data": "ok"})
+ })
- // IP B should still be allowed
+ // Exhaust the burst
+ for i := 0; i < 6; i++ {
req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
require.NoError(t, err)
- req.RemoteAddr = "10.0.0.4:1234"
+ req.RemoteAddr = "10.0.0.5:1234"
w = httptest.NewRecorder()
router.ServeHTTP(w, req)
- assert.Equal(t, http.StatusOK, w.Code)
- })
-
- t.Run("should_return_correct_error_format", func(t *testing.T) {
- w := httptest.NewRecorder()
- _, router := gin.CreateTestContext(w)
- router.Use(ts.HttpProvider.RateLimitMiddleware())
- router.POST("/graphql", func(c *gin.Context) {
- c.JSON(http.StatusOK, gin.H{"data": "ok"})
- })
-
- // Exhaust the burst
- for i := 0; i < 6; i++ {
- req, err := http.NewRequest(http.MethodPost, "/graphql", nil)
- require.NoError(t, err)
- req.RemoteAddr = "10.0.0.5:1234"
- w = httptest.NewRecorder()
- router.ServeHTTP(w, req)
- }
+ }
- // Check the 429 response body has OAuth2 error format
- assert.Equal(t, http.StatusTooManyRequests, w.Code)
- assert.Contains(t, w.Body.String(), "rate_limit_exceeded")
- assert.Contains(t, w.Body.String(), "error_description")
- })
+ // Check the 429 response body has OAuth2 error format
+ assert.Equal(t, http.StatusTooManyRequests, w.Code)
+ assert.Contains(t, w.Body.String(), "rate_limit_exceeded")
+ assert.Contains(t, w.Body.String(), "error_description")
})
}
diff --git a/internal/integration_tests/redirect_uri_validation_test.go b/internal/integration_tests/redirect_uri_validation_test.go
index 3dfc830cb..565731387 100644
--- a/internal/integration_tests/redirect_uri_validation_test.go
+++ b/internal/integration_tests/redirect_uri_validation_test.go
@@ -7,7 +7,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/authorizerdev/authorizer/internal/config"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/refs"
)
@@ -15,43 +14,42 @@ import (
// TestRedirectURIRejectsAttacker verifies that forgot_password rejects
// attacker-controlled redirect_uri values with explicit AllowedOrigins.
func TestRedirectURIRejectsAttacker(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- cfg.AllowedOrigins = []string{"http://localhost:3000"}
- cfg.EnableBasicAuthentication = true
- cfg.EnableEmailVerification = false
- cfg.EnableSignup = true
-
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "redirect_test_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
- signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
- require.NotNil(t, signupRes)
-
- t.Run("rejects attacker redirect_uri", func(t *testing.T) {
- res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(email),
- RedirectURI: refs.NewStringRef("https://attacker.com/steal"),
- })
- assert.Error(t, err)
- assert.Nil(t, res)
- assert.Contains(t, err.Error(), "invalid redirect URI")
+ cfg := getTestConfig()
+ cfg.AllowedOrigins = []string{"http://localhost:3000"}
+ cfg.EnableBasicAuthentication = true
+ cfg.EnableEmailVerification = false
+ cfg.EnableSignup = true
+
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ email := "redirect_test_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+ signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, signupRes)
+
+ t.Run("rejects attacker redirect_uri", func(t *testing.T) {
+ res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(email),
+ RedirectURI: refs.NewStringRef("https://attacker.com/steal"),
})
+ assert.Error(t, err)
+ assert.Nil(t, res)
+ assert.Contains(t, err.Error(), "invalid redirect URI")
+ })
- t.Run("accepts valid redirect_uri", func(t *testing.T) {
- res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(email),
- RedirectURI: refs.NewStringRef("http://localhost:3000/reset"),
- })
- assert.NoError(t, err)
- assert.NotNil(t, res)
+ t.Run("accepts valid redirect_uri", func(t *testing.T) {
+ res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(email),
+ RedirectURI: refs.NewStringRef("http://localhost:3000/reset"),
})
+ assert.NoError(t, err)
+ assert.NotNil(t, res)
})
}
@@ -59,71 +57,70 @@ func TestRedirectURIRejectsAttacker(t *testing.T) {
// vulnerability (issue #540). When allowed_origins=["*"] (the default config),
// attacker-controlled redirect_uri values must still be rejected.
func TestRedirectURIWildcardOrigins(t *testing.T) {
- runForEachDB(t, func(t *testing.T, cfg *config.Config) {
- cfg.AllowedOrigins = []string{"*"}
- cfg.EnableBasicAuthentication = true
- cfg.EnableEmailVerification = false
- cfg.EnableSignup = true
- ts := initTestSetup(t, cfg)
- _, ctx := createContext(ts)
-
- email := "wildcard_redirect_" + uuid.New().String() + "@authorizer.dev"
- password := "Password@123"
-
- signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
- Email: &email,
- Password: password,
- ConfirmPassword: password,
- })
- require.NoError(t, err)
- require.NotNil(t, signupRes)
-
- t.Run("rejects attacker redirect_uri with wildcard origins", func(t *testing.T) {
- res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(email),
- RedirectURI: refs.NewStringRef("https://attacker.com/capture"),
- })
- assert.Error(t, err)
- assert.Nil(t, res)
- assert.Contains(t, err.Error(), "invalid redirect URI")
+ cfg := getTestConfig()
+ cfg.AllowedOrigins = []string{"*"}
+ cfg.EnableBasicAuthentication = true
+ cfg.EnableEmailVerification = false
+ cfg.EnableSignup = true
+ ts := initTestSetup(t, cfg)
+ _, ctx := createContext(ts)
+
+ email := "wildcard_redirect_" + uuid.New().String() + "@authorizer.dev"
+ password := "Password@123"
+
+ signupRes, err := ts.GraphQLProvider.SignUp(ctx, &model.SignUpRequest{
+ Email: &email,
+ Password: password,
+ ConfirmPassword: password,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, signupRes)
+
+ t.Run("rejects attacker redirect_uri with wildcard origins", func(t *testing.T) {
+ res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(email),
+ RedirectURI: refs.NewStringRef("https://attacker.com/capture"),
})
+ assert.Error(t, err)
+ assert.Nil(t, res)
+ assert.Contains(t, err.Error(), "invalid redirect URI")
+ })
- t.Run("allows self-origin redirect_uri with wildcard origins", func(t *testing.T) {
- selfURI := "http://" + ts.HttpServer.Listener.Addr().String() + "/app/reset-password"
- res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(email),
- RedirectURI: refs.NewStringRef(selfURI),
- })
- assert.NoError(t, err)
- assert.NotNil(t, res)
- assert.NotEmpty(t, res.Message)
+ t.Run("allows self-origin redirect_uri with wildcard origins", func(t *testing.T) {
+ selfURI := "http://" + ts.HttpServer.Listener.Addr().String() + "/app/reset-password"
+ res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(email),
+ RedirectURI: refs.NewStringRef(selfURI),
})
+ assert.NoError(t, err)
+ assert.NotNil(t, res)
+ assert.NotEmpty(t, res.Message)
+ })
- t.Run("works without redirect_uri (uses default)", func(t *testing.T) {
- res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(email),
- })
- assert.NoError(t, err)
- assert.NotNil(t, res)
- assert.NotEmpty(t, res.Message)
+ t.Run("works without redirect_uri (uses default)", func(t *testing.T) {
+ res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(email),
})
+ assert.NoError(t, err)
+ assert.NotNil(t, res)
+ assert.NotEmpty(t, res.Message)
+ })
- t.Run("rejects javascript scheme", func(t *testing.T) {
- res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(email),
- RedirectURI: refs.NewStringRef("javascript:alert(1)"),
- })
- assert.Error(t, err)
- assert.Nil(t, res)
+ t.Run("rejects javascript scheme", func(t *testing.T) {
+ res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(email),
+ RedirectURI: refs.NewStringRef("javascript:alert(1)"),
})
+ assert.Error(t, err)
+ assert.Nil(t, res)
+ })
- t.Run("rejects data scheme", func(t *testing.T) {
- res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
- Email: refs.NewStringRef(email),
- RedirectURI: refs.NewStringRef("data:text/html,
evil
"),
- })
- assert.Error(t, err)
- assert.Nil(t, res)
+ t.Run("rejects data scheme", func(t *testing.T) {
+ res, err := ts.GraphQLProvider.ForgotPassword(ctx, &model.ForgotPasswordRequest{
+ Email: refs.NewStringRef(email),
+ RedirectURI: refs.NewStringRef("data:text/html,evil
"),
})
+ assert.Error(t, err)
+ assert.Nil(t, res)
})
}
diff --git a/internal/integration_tests/test_helper.go b/internal/integration_tests/test_helper.go
index 8b2cd28c5..fa6123c22 100644
--- a/internal/integration_tests/test_helper.go
+++ b/internal/integration_tests/test_helper.go
@@ -6,7 +6,6 @@ import (
"net/http/httptest"
"os"
"path/filepath"
- "strings"
"testing"
"github.com/gin-gonic/gin"
@@ -58,76 +57,10 @@ func createContext(s *testSetup) (*http.Request, context.Context) {
return req, ctx
}
-// dbTestConfig holds database-specific test configuration
-type dbTestConfig struct {
- DbType string
- DbURL string
-}
-
-// getTestDBs returns the list of database configurations to test against.
-// It reads the TEST_DBS environment variable (comma-separated list of db types).
-// Defaults to "postgres" if not set.
-func getTestDBs() []dbTestConfig {
- testDBsEnv := os.Getenv("TEST_DBS")
- if testDBsEnv == "" {
- testDBsEnv = "postgres"
- }
-
- dbTypes := strings.Split(testDBsEnv, ",")
- var configs []dbTestConfig
-
- for _, dbType := range dbTypes {
- dbType = strings.TrimSpace(dbType)
- if dbType == "" {
- continue
- }
-
- dbURL := getDBURL(dbType)
- if dbURL != "" {
- configs = append(configs, dbTestConfig{
- DbType: dbType,
- DbURL: dbURL,
- })
- }
- }
-
- return configs
-}
-
-// getDBURL returns the connection URL for a given database type
-func getDBURL(dbType string) string {
- switch dbType {
- case constants.DbTypePostgres:
- return "postgres://postgres:postgres@localhost:5434/postgres"
- case constants.DbTypeSqlite:
- return "test.db"
- case constants.DbTypeLibSQL:
- return "test.db"
- case constants.DbTypeMysql:
- return "root:password@tcp(localhost:3306)/authorizer"
- case constants.DbTypeMariaDB:
- return "root:password@tcp(localhost:3307)/authorizer"
- case constants.DbTypeSqlserver:
- return "sqlserver://sa:Password123@localhost:1433?database=authorizer"
- case constants.DbTypeMongoDB:
- return "mongodb://localhost:27017"
- case constants.DbTypeArangoDB:
- return "http://localhost:8529"
- case constants.DbTypeScyllaDB, constants.DbTypeCassandraDB:
- return "127.0.0.1:9042"
- case constants.DbTypeDynamoDB:
- return "http://localhost:8000"
- case constants.DbTypeCouchbaseDB:
- return "couchbase://localhost"
- default:
- return ""
- }
-}
-
-// getTestConfig returns a test config for the default database (postgres).
-// For multi-DB testing, use runForEachDB instead.
+// getTestConfig returns config for integration tests using SQLite.
+// Integration tests validate business logic, not storage compatibility.
func getTestConfig() *config.Config {
- return getTestConfigForDB(constants.DbTypePostgres, "postgres://postgres:postgres@localhost:5434/postgres")
+ return getTestConfigForDB(constants.DbTypeSqlite, "test.db")
}
// getTestConfigForDB returns a test config for a specific database type and URL
@@ -156,8 +89,9 @@ func getTestConfigForDB(dbType, dbURL string) *config.Config {
IsSMSServiceEnabled: true,
}
- // Set MongoDB-specific config
- if dbType == constants.DbTypeMongoDB {
+ // MongoDB, ArangoDB, Cassandra/Scylla require DatabaseName (keyspace / DB name); see storage New().
+ if dbType == constants.DbTypeMongoDB || dbType == constants.DbTypeArangoDB ||
+ dbType == constants.DbTypeScyllaDB || dbType == constants.DbTypeCassandraDB {
cfg.DatabaseName = "authorizer_test"
}
@@ -168,34 +102,12 @@ func getTestConfigForDB(dbType, dbURL string) *config.Config {
cfg.CouchBaseBucket = "authorizer_test"
}
- return cfg
-}
-
-// runForEachDB runs the given test function against each database specified in TEST_DBS.
-// This is the primary way to run tests across multiple database providers.
-//
-// Usage:
-//
-// func TestFeature(t *testing.T) {
-// runForEachDB(t, func(t *testing.T, cfg *config.Config) {
-// ts := initTestSetup(t, cfg)
-// _, ctx := createContext(ts)
-// // ... test logic
-// })
-// }
-func runForEachDB(t *testing.T, testFn func(t *testing.T, cfg *config.Config)) {
- t.Helper()
- dbConfigs := getTestDBs()
- if len(dbConfigs) == 0 {
- t.Fatal("TEST_DBS produced no runnable database configurations; check TEST_DBS and that each database type resolves to a non-empty URL")
+ // DynamoDB Local (and AWS) expect a region for signing; avoid picking up real AWS keys in tests.
+ if dbType == constants.DbTypeDynamoDB {
+ cfg.AWSRegion = "us-east-1"
}
- for _, dbCfg := range dbConfigs {
- t.Run("db="+dbCfg.DbType, func(t *testing.T) {
- cfg := getTestConfigForDB(dbCfg.DbType, dbCfg.DbURL)
- testFn(t, cfg)
- })
- }
+ return cfg
}
// initTestSetup initializes the test setup
@@ -203,6 +115,12 @@ func initTestSetup(t *testing.T, cfg *config.Config) *testSetup {
// Initialize logger
logger := zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger()
+ if cfg.DatabaseType == constants.DbTypeDynamoDB {
+ // Match storage tests: use static creds from config instead of ambient AWS_* env.
+ os.Unsetenv("AWS_ACCESS_KEY_ID")
+ os.Unsetenv("AWS_SECRET_ACCESS_KEY")
+ }
+
if cfg.DatabaseType == constants.DbTypeSqlite || cfg.DatabaseType == constants.DbTypeLibSQL {
cfg.DatabaseURL = filepath.Join(t.TempDir(), "authorizer_integration.db")
}
diff --git a/internal/memory_store/db/provider_test.go b/internal/memory_store/db/provider_test.go
index 347465c70..a5a431405 100644
--- a/internal/memory_store/db/provider_test.go
+++ b/internal/memory_store/db/provider_test.go
@@ -5,122 +5,117 @@ import (
"testing"
"time"
+ "github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/authorizerdev/authorizer/internal/config"
"github.com/authorizerdev/authorizer/internal/storage"
)
-// TestDBMemoryStoreProvider tests the database-backed memory store provider
-// This test requires a database to be configured
+// TestDBMemoryStoreProvider tests the database-backed memory store against SQLite.
func TestDBMemoryStoreProvider(t *testing.T) {
- dbFile := filepath.Join(t.TempDir(), "authorizer_test.db")
- cfg := &config.Config{
- DatabaseType: "sqlite",
- DatabaseURL: dbFile,
- Env: "test",
+ entries := storageTestDBEntriesFromEnv()
+ if len(entries) == 0 {
+ t.Fatal("no database configurations for memory store DB tests")
}
- // Create storage provider
- storageProvider, err := storage.New(cfg, &storage.Dependencies{})
- if err != nil {
- t.Skipf("Skipping test: failed to create storage provider: %v", err)
- return
- }
+ for _, e := range entries {
+ t.Run("db="+e.dbType, func(t *testing.T) {
+ tempSQLite := filepath.Join(t.TempDir(), "memory_store_test.db")
+ dbURL := resolveSQLiteTestURL(e.dbType, e.dbURL, tempSQLite)
+ cfg := buildStorageTestConfigForMemoryStore(e.dbType, dbURL)
+
+ log := zerolog.New(zerolog.NewTestWriter(t))
+ storageProvider, err := storage.New(cfg, &storage.Dependencies{Log: &log})
+ if err != nil {
+ t.Skipf("skipping: storage provider for %s: %v", e.dbType, err)
+ return
+ }
- // Create DB memory store provider
- p, err := NewDBProvider(cfg, &Dependencies{
- StorageProvider: storageProvider,
- })
- require.NoError(t, err)
- require.NotNil(t, p)
+ p, err := NewDBProvider(cfg, &Dependencies{
+ Log: &log,
+ StorageProvider: storageProvider,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, p)
- // Test SetUserSession and GetUserSession
- err = p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix())
- assert.NoError(t, err)
+ err = p.SetUserSession("auth_provider:123", "session_token_key", "test_hash123", time.Now().Add(60*time.Second).Unix())
+ assert.NoError(t, err)
- err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix())
- assert.NoError(t, err)
+ err = p.SetUserSession("auth_provider:123", "access_token_key", "test_jwt123", time.Now().Add(60*time.Second).Unix())
+ assert.NoError(t, err)
- // Get session
- key, err := p.GetUserSession("auth_provider:123", "session_token_key")
- assert.NoError(t, err)
- assert.Equal(t, "test_hash123", key)
+ key, err := p.GetUserSession("auth_provider:123", "session_token_key")
+ assert.NoError(t, err)
+ assert.Equal(t, "test_hash123", key)
- key, err = p.GetUserSession("auth_provider:123", "access_token_key")
- assert.NoError(t, err)
- assert.Equal(t, "test_jwt123", key)
+ key, err = p.GetUserSession("auth_provider:123", "access_token_key")
+ assert.NoError(t, err)
+ assert.Equal(t, "test_jwt123", key)
- // Test expiration
- err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(1*time.Second).Unix())
- assert.NoError(t, err)
+ err = p.SetUserSession("auth_provider:124", "session_token_key", "test_hash124", time.Now().Add(1*time.Second).Unix())
+ assert.NoError(t, err)
- time.Sleep(2 * time.Second)
+ time.Sleep(2 * time.Second)
- key, err = p.GetUserSession("auth_provider:124", "session_token_key")
- assert.Empty(t, key)
- assert.Error(t, err)
+ key, err = p.GetUserSession("auth_provider:124", "session_token_key")
+ assert.Empty(t, key)
+ assert.Error(t, err)
- // Test DeleteUserSession: for DB-backed store, the key argument is the suffix
- // (e.g. \"key\"), while the stored keys are \"session_token_key\", \"access_token_key\", etc.
- // This matches the in-memory provider behavior and real usage in the codebase.
- err = p.DeleteUserSession("auth_provider:123", "key")
- assert.NoError(t, err)
+ err = p.DeleteUserSession("auth_provider:123", "key")
+ assert.NoError(t, err)
- key, err = p.GetUserSession("auth_provider:123", "session_token_key")
- assert.Empty(t, key)
- assert.Error(t, err)
+ key, err = p.GetUserSession("auth_provider:123", "session_token_key")
+ assert.Empty(t, key)
+ assert.Error(t, err)
- // Test DeleteAllUserSessions
- err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix())
- assert.NoError(t, err)
+ err = p.SetUserSession("auth_provider:123", "session_token_key1", "test_hash1123", time.Now().Add(60*time.Second).Unix())
+ assert.NoError(t, err)
- err = p.DeleteAllUserSessions("123")
- assert.NoError(t, err)
+ err = p.DeleteAllUserSessions("123")
+ assert.NoError(t, err)
- key, err = p.GetUserSession("auth_provider:123", "session_token_key1")
- assert.Empty(t, key)
- assert.Error(t, err)
+ key, err = p.GetUserSession("auth_provider:123", "session_token_key1")
+ assert.Empty(t, key)
+ assert.Error(t, err)
- // Test DeleteSessionForNamespace
- err = p.SetUserSession("auth_provider:125", "session_token_key", "test_hash125", time.Now().Add(60*time.Second).Unix())
- assert.NoError(t, err)
+ err = p.SetUserSession("auth_provider:125", "session_token_key", "test_hash125", time.Now().Add(60*time.Second).Unix())
+ assert.NoError(t, err)
- err = p.DeleteSessionForNamespace("auth_provider")
- assert.NoError(t, err)
+ err = p.DeleteSessionForNamespace("auth_provider")
+ assert.NoError(t, err)
- key, err = p.GetUserSession("auth_provider:125", "session_token_key")
- assert.Empty(t, key)
- assert.Error(t, err)
+ key, err = p.GetUserSession("auth_provider:125", "session_token_key")
+ assert.Empty(t, key)
+ assert.Error(t, err)
- // Test MFA sessions
- err = p.SetMfaSession("auth_provider:123", "session123", time.Now().Add(60*time.Second).Unix())
- assert.NoError(t, err)
+ err = p.SetMfaSession("auth_provider:123", "session123", time.Now().Add(60*time.Second).Unix())
+ assert.NoError(t, err)
- key, err = p.GetMfaSession("auth_provider:123", "session123")
- assert.NoError(t, err)
- assert.Equal(t, "auth_provider:123", key)
+ key, err = p.GetMfaSession("auth_provider:123", "session123")
+ assert.NoError(t, err)
+ assert.Equal(t, "auth_provider:123", key)
- err = p.DeleteMfaSession("auth_provider:123", "session123")
- assert.NoError(t, err)
+ err = p.DeleteMfaSession("auth_provider:123", "session123")
+ assert.NoError(t, err)
- key, err = p.GetMfaSession("auth_provider:123", "session123")
- assert.Error(t, err)
- assert.Empty(t, key)
+ key, err = p.GetMfaSession("auth_provider:123", "session123")
+ assert.Error(t, err)
+ assert.Empty(t, key)
- // Test OAuth state
- err = p.SetState("test_state_key", "test_state_value")
- assert.NoError(t, err)
+ err = p.SetState("test_state_key", "test_state_value")
+ assert.NoError(t, err)
- state, err := p.GetState("test_state_key")
- assert.NoError(t, err)
- assert.Equal(t, "test_state_value", state)
+ state, err := p.GetState("test_state_key")
+ assert.NoError(t, err)
+ assert.Equal(t, "test_state_value", state)
- err = p.RemoveState("test_state_key")
- assert.NoError(t, err)
+ err = p.RemoveState("test_state_key")
+ assert.NoError(t, err)
- state, err = p.GetState("test_state_key")
- assert.Error(t, err)
- assert.Empty(t, state)
+ state, err = p.GetState("test_state_key")
+ assert.Error(t, err)
+ assert.Empty(t, state)
+ })
+ }
}
diff --git a/internal/memory_store/db/test_config_test.go b/internal/memory_store/db/test_config_test.go
new file mode 100644
index 000000000..fdd6982a4
--- /dev/null
+++ b/internal/memory_store/db/test_config_test.go
@@ -0,0 +1,55 @@
+package db
+
+import (
+ "github.com/authorizerdev/authorizer/internal/config"
+ "github.com/authorizerdev/authorizer/internal/constants"
+)
+
+// storageDBEntry matches one entry for memory store DB tests.
+// Memory store DB tests only run against SQLite — storage-layer compatibility
+// is covered by internal/storage tests.
+type storageDBEntry struct {
+ dbType string
+ dbURL string
+}
+
+func storageTestDBEntriesFromEnv() []storageDBEntry {
+ return []storageDBEntry{
+ {dbType: constants.DbTypeSqlite, dbURL: "test.db"},
+ }
+}
+
+func resolveSQLiteTestURL(dbType, mappedURL, tempPath string) string {
+ if dbType == constants.DbTypeSqlite || dbType == constants.DbTypeLibSQL {
+ return tempPath
+ }
+ return mappedURL
+}
+
+func buildStorageTestConfigForMemoryStore(dbType, dbURL string) *config.Config {
+ cfg := &config.Config{
+ Env: constants.TestEnv,
+ SkipTestEndpointSSRFValidation: true,
+ DatabaseType: dbType,
+ DatabaseURL: dbURL,
+ JWTSecret: "test-secret",
+ ClientID: "test-client-id",
+ ClientSecret: "test-client-secret",
+ AllowedOrigins: []string{"http://localhost:3000"},
+ JWTType: "HS256",
+ AdminSecret: "test-admin-secret",
+ TwilioAPISecret: "test-twilio-api-secret",
+ TwilioAPIKey: "test-twilio-api-key",
+ TwilioAccountSID: "test-twilio-account-sid",
+ TwilioSender: "test-twilio-sender",
+ DefaultRoles: []string{"user"},
+ EnableSignup: true,
+ EnableBasicAuthentication: true,
+ EnableMobileBasicAuthentication: true,
+ EnableLoginPage: true,
+ EnableStrongPassword: true,
+ IsSMSServiceEnabled: true,
+ }
+
+ return cfg
+}
diff --git a/internal/memory_store/provider_test.go b/internal/memory_store/provider_test.go
index 244bb3752..bc7c00bf0 100644
--- a/internal/memory_store/provider_test.go
+++ b/internal/memory_store/provider_test.go
@@ -1,6 +1,8 @@
package memory_store
import (
+ "os"
+ "strings"
"testing"
"time"
@@ -17,9 +19,13 @@ const (
memoryStoreTypeDB = "db"
)
-var memoryStoreTypes = []string{
- memoryStoreTypeRedis,
- memoryStoreTypeInMemory,
+func memoryStoreTypesForTest() []string {
+ var types []string
+ if redisMemoryStoreTestsEnabled() {
+ types = append(types, memoryStoreTypeRedis)
+ }
+ types = append(types, memoryStoreTypeInMemory)
+ return types
}
func getTestMemoryStorageConfig(storageType string) *config.Config {
@@ -40,9 +46,10 @@ func getTestMemoryStorageConfig(storageType string) *config.Config {
return cfg
}
-// TestMemoryStoreProvider tests the memory store provider
+// TestMemoryStoreProvider tests the in-memory provider always; Redis only when TEST_ENABLE_REDIS=1.
+// TEST_DBS does not apply (these are not storage-backend tests).
func TestMemoryStoreProvider(t *testing.T) {
- for _, storeType := range memoryStoreTypes {
+ for _, storeType := range memoryStoreTypesForTest() {
t.Run("should test memory store provider for "+storeType, func(t *testing.T) {
cfg := getTestMemoryStorageConfig(storeType)
logger := zerolog.Nop()
@@ -50,7 +57,7 @@ func TestMemoryStoreProvider(t *testing.T) {
Log: &logger,
})
if storeType == memoryStoreTypeRedis && err != nil {
- t.Skipf("skipping redis memory store test: %v", err)
+ t.Skipf("skipping redis memory store test (is Redis running on localhost:6380?): %v", err)
}
require.NoError(t, err)
require.NotNil(t, p)
@@ -170,3 +177,8 @@ func TestMemoryStoreProvider(t *testing.T) {
})
}
}
+
+func redisMemoryStoreTestsEnabled() bool {
+ v := strings.TrimSpace(os.Getenv("TEST_ENABLE_REDIS"))
+ return v == "1" || strings.EqualFold(v, "true")
+}
diff --git a/internal/rate_limit/provider.go b/internal/rate_limit/provider.go
index c4fa73a86..a01df503f 100644
--- a/internal/rate_limit/provider.go
+++ b/internal/rate_limit/provider.go
@@ -29,6 +29,10 @@ type Dependencies struct {
// New creates a new rate limit provider based on available infrastructure.
// Uses Redis when RedisStore is provided, falls back to in-memory.
func New(cfg *config.Config, deps *Dependencies) (Provider, error) {
+ deps.Log.Info().
+ Int("rate_limit_rps", cfg.RateLimitRPS).
+ Int("rate_limit_burst", cfg.RateLimitBurst).
+ Msg("Creating rate limit provider")
if deps.RedisStore != nil {
return newRedisProvider(cfg, deps)
}
diff --git a/internal/rate_limit/redis.go b/internal/rate_limit/redis.go
index 8e666f27f..f1d27121e 100644
--- a/internal/rate_limit/redis.go
+++ b/internal/rate_limit/redis.go
@@ -42,9 +42,9 @@ func newRedisProvider(cfg *config.Config, deps *Dependencies) (*redisProvider, e
// Window = burst / rps, minimum 1 second
window := 1
if cfg.RateLimitRPS > 0 {
- w := float64(cfg.RateLimitBurst) / cfg.RateLimitRPS
+ w := int(cfg.RateLimitBurst / cfg.RateLimitRPS)
if w > 1 {
- window = int(w)
+ window = w
}
}
return &redisProvider{
diff --git a/internal/storage/db/couchbase/health_check.go b/internal/storage/db/couchbase/health_check.go
index 92ae64f83..7a5666ac2 100644
--- a/internal/storage/db/couchbase/health_check.go
+++ b/internal/storage/db/couchbase/health_check.go
@@ -5,11 +5,13 @@ import (
"fmt"
"github.com/couchbase/gocb/v2"
+
+ "github.com/authorizerdev/authorizer/internal/storage/schemas"
)
// HealthCheck verifies that the Couchbase backend is reachable and responsive
func (p *provider) HealthCheck(ctx context.Context) error {
- query := fmt.Sprintf("SELECT 1 FROM %s LIMIT 1", p.scopeName)
+ query := fmt.Sprintf("SELECT 1 FROM %s.%s LIMIT 1", p.scopeName, schemas.Collections.User)
_, err := p.db.Query(query, &gocb.QueryOptions{
Context: ctx,
})
diff --git a/internal/storage/db/dynamodb/audit_log.go b/internal/storage/db/dynamodb/audit_log.go
index 472d7108a..5036aa7c0 100644
--- a/internal/storage/db/dynamodb/audit_log.go
+++ b/internal/storage/db/dynamodb/audit_log.go
@@ -2,10 +2,12 @@ package dynamodb
import (
"context"
+ "sort"
"time"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
- "github.com/guregu/dynamo"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -13,7 +15,6 @@ import (
// AddAuditLog adds an audit log entry
func (p *provider) AddAuditLog(ctx context.Context, auditLog *schemas.AuditLog) error {
- collection := p.db.Table(schemas.Collections.AuditLog)
if auditLog.ID == "" {
auditLog.ID = uuid.New().String()
}
@@ -21,85 +22,138 @@ func (p *provider) AddAuditLog(ctx context.Context, auditLog *schemas.AuditLog)
if auditLog.CreatedAt == 0 {
auditLog.CreatedAt = time.Now().Unix()
}
- err := collection.Put(auditLog).RunWithContext(ctx)
- if err != nil {
- return err
- }
- return nil
+ return p.putItem(ctx, schemas.Collections.AuditLog, auditLog)
}
-// ListAuditLogs queries audit logs with filters and pagination
-func (p *provider) ListAuditLogs(ctx context.Context, pagination *model.Pagination, filter map[string]interface{}) ([]*schemas.AuditLog, *model.Pagination, error) {
- auditLogs := []*schemas.AuditLog{}
- var auditLog *schemas.AuditLog
- var lastEval dynamo.PagingKey
- var iter dynamo.PagingIter
- var iteration int64 = 0
- var err error
-
- collection := p.db.Table(schemas.Collections.AuditLog)
- paginationClone := *pagination
- scanner := collection.Scan()
+func int64FromFilter(v interface{}) (int64, bool) {
+ switch x := v.(type) {
+ case int64:
+ return x, true
+ case int:
+ return int64(x), true
+ case int32:
+ return int64(x), true
+ case float64:
+ return int64(x), true
+ default:
+ return 0, false
+ }
+}
- // Apply filters
- if action, ok := filter["action"]; ok && action != "" {
- scanner = scanner.Filter("'action' = ?", action)
+// auditLogExtraFilter builds a filter for attributes not covered by the primary key condition
+// (omitKey is "action" or "actor_id" when that attribute is the partition key for the Query).
+func auditLogExtraFilter(filter map[string]interface{}, omitKey string) *expression.ConditionBuilder {
+ var conds []expression.ConditionBuilder
+ if omitKey != "actor_id" {
+ if actorID, ok := filter["actor_id"]; ok && actorID != "" {
+ conds = append(conds, expression.Name("actor_id").Equal(expression.Value(actorID)))
+ }
}
- if actorID, ok := filter["actor_id"]; ok && actorID != "" {
- scanner = scanner.Filter("'actor_id' = ?", actorID)
+ if omitKey != "action" {
+ if action, ok := filter["action"]; ok && action != "" {
+ conds = append(conds, expression.Name("action").Equal(expression.Value(action)))
+ }
}
if resourceType, ok := filter["resource_type"]; ok && resourceType != "" {
- scanner = scanner.Filter("'resource_type' = ?", resourceType)
+ conds = append(conds, expression.Name("resource_type").Equal(expression.Value(resourceType)))
}
-
- for (paginationClone.Offset + paginationClone.Limit) > iteration {
- iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter()
- for iter.NextWithContext(ctx, &auditLog) {
- if paginationClone.Offset == iteration {
- auditLogs = append(auditLogs, auditLog)
- }
+ if resourceID, ok := filter["resource_id"]; ok && resourceID != "" {
+ conds = append(conds, expression.Name("resource_id").Equal(expression.Value(resourceID)))
+ }
+ if v, ok := filter["from_timestamp"]; ok {
+ if ts, ok := int64FromFilter(v); ok {
+ conds = append(conds, expression.Name("created_at").GreaterThanEqual(expression.Value(ts)))
}
- err = iter.Err()
- if err != nil {
- return nil, nil, err
+ }
+ if v, ok := filter["to_timestamp"]; ok {
+ if ts, ok := int64FromFilter(v); ok {
+ conds = append(conds, expression.Name("created_at").LessThanEqual(expression.Value(ts)))
}
- lastEval = iter.LastEvaluatedKey()
- iteration += paginationClone.Limit
}
+ if len(conds) == 0 {
+ return nil
+ }
+ merged := conds[0]
+ for i := 1; i < len(conds); i++ {
+ merged = merged.And(conds[i])
+ }
+ return &merged
+}
+
+// ListAuditLogs queries audit logs with filters and pagination
+func (p *provider) ListAuditLogs(ctx context.Context, pagination *model.Pagination, filter map[string]interface{}) ([]*schemas.AuditLog, *model.Pagination, error) {
+ paginationClone := *pagination
- // Count total matching documents
- var total int64
- countScanner := collection.Scan()
- if action, ok := filter["action"]; ok && action != "" {
- countScanner = countScanner.Filter("'action' = ?", action)
+ var actionVal, actorVal string
+ if a, ok := filter["action"]; ok && a != "" {
+ if s, ok := a.(string); ok {
+ actionVal = s
+ }
}
- if actorID, ok := filter["actor_id"]; ok && actorID != "" {
- countScanner = countScanner.Filter("'actor_id' = ?", actorID)
+ if a, ok := filter["actor_id"]; ok && a != "" {
+ if s, ok := a.(string); ok {
+ actorVal = s
+ }
}
- if resourceType, ok := filter["resource_type"]; ok && resourceType != "" {
- countScanner = countScanner.Filter("'resource_type' = ?", resourceType)
+
+ var items []map[string]types.AttributeValue
+ var err error
+ table := schemas.Collections.AuditLog
+
+ switch {
+ case actionVal != "":
+ extra := auditLogExtraFilter(filter, "action")
+ items, err = p.queryEq(ctx, table, "action", "action", actionVal, extra)
+ case actorVal != "":
+ extra := auditLogExtraFilter(filter, "actor_id")
+ items, err = p.queryEq(ctx, table, "actor_id", "actor_id", actorVal, extra)
+ default:
+ extra := auditLogExtraFilter(filter, "")
+ items, err = p.scanAllRaw(ctx, table, nil, extra)
}
- var countItems []*schemas.AuditLog
- if err = countScanner.AllWithContext(ctx, &countItems); err != nil {
+ if err != nil {
return nil, nil, err
}
- total = int64(len(countItems))
+
+ var logs []*schemas.AuditLog
+ for _, it := range items {
+ var a schemas.AuditLog
+ if err := unmarshalItem(it, &a); err != nil {
+ return nil, nil, err
+ }
+ logs = append(logs, &a)
+ }
+
+ sort.Slice(logs, func(i, j int) bool { return logs[i].CreatedAt > logs[j].CreatedAt })
+
+ total := int64(len(logs))
paginationClone.Total = total
- return auditLogs, &paginationClone, nil
+ start := int(pagination.Offset)
+ if start >= len(logs) {
+ return []*schemas.AuditLog{}, &paginationClone, nil
+ }
+ end := start + int(pagination.Limit)
+ if end > len(logs) {
+ end = len(logs)
+ }
+
+ return logs[start:end], &paginationClone, nil
}
// DeleteAuditLogsBefore removes logs older than a timestamp
func (p *provider) DeleteAuditLogsBefore(ctx context.Context, before int64) error {
- collection := p.db.Table(schemas.Collections.AuditLog)
- var auditLogs []*schemas.AuditLog
- err := collection.Scan().Filter("'created_at' < ?", before).AllWithContext(ctx, &auditLogs)
+ f := expression.Name("created_at").LessThan(expression.Value(before))
+ items, err := p.scanFilteredAll(ctx, schemas.Collections.AuditLog, nil, &f)
if err != nil {
return err
}
- for _, auditLog := range auditLogs {
- err := collection.Delete("id", auditLog.ID).RunWithContext(ctx)
- if err != nil {
+ for _, it := range items {
+ var a schemas.AuditLog
+ if err := unmarshalItem(it, &a); err != nil {
+ return err
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.AuditLog, "id", a.ID); err != nil {
return err
}
}
diff --git a/internal/storage/db/dynamodb/authenticator.go b/internal/storage/db/dynamodb/authenticator.go
index 88e349ade..673a6510c 100644
--- a/internal/storage/db/dynamodb/authenticator.go
+++ b/internal/storage/db/dynamodb/authenticator.go
@@ -4,6 +4,7 @@ import (
"context"
"time"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
"github.com/google/uuid"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -14,44 +15,39 @@ func (p *provider) AddAuthenticator(ctx context.Context, authenticators *schemas
if exists != nil {
return authenticators, nil
}
-
- collection := p.db.Table(schemas.Collections.Authenticators)
if authenticators.ID == "" {
authenticators.ID = uuid.New().String()
}
-
authenticators.CreatedAt = time.Now().Unix()
authenticators.UpdatedAt = time.Now().Unix()
- err := collection.Put(authenticators).RunWithContext(ctx)
- if err != nil {
+ if err := p.putItem(ctx, schemas.Collections.Authenticators, authenticators); err != nil {
return nil, err
}
return authenticators, nil
}
func (p *provider) UpdateAuthenticator(ctx context.Context, authenticators *schemas.Authenticator) (*schemas.Authenticator, error) {
- collection := p.db.Table(schemas.Collections.Authenticators)
if authenticators.ID != "" {
authenticators.UpdatedAt = time.Now().Unix()
- err := UpdateByHashKey(collection, "id", authenticators.ID, authenticators)
- if err != nil {
+ if err := p.updateByHashKey(ctx, schemas.Collections.Authenticators, "id", authenticators.ID, authenticators); err != nil {
return nil, err
}
}
return authenticators, nil
-
}
func (p *provider) GetAuthenticatorDetailsByUserId(ctx context.Context, userId string, authenticatorType string) (*schemas.Authenticator, error) {
- var authenticators *schemas.Authenticator
- collection := p.db.Table(schemas.Collections.Authenticators)
- iter := collection.Scan().Filter("'user_id' = ?", userId).Filter("'method' = ?", authenticatorType).Iter()
- for iter.NextWithContext(ctx, &authenticators) {
- return authenticators, nil
- }
- err := iter.Err()
+ f := expression.Name("user_id").Equal(expression.Value(userId)).And(expression.Name("method").Equal(expression.Value(authenticatorType)))
+ items, err := p.scanFilteredAll(ctx, schemas.Collections.Authenticators, nil, &f)
if err != nil {
return nil, err
}
- return authenticators, nil
+ if len(items) == 0 {
+ return nil, nil
+ }
+ var a schemas.Authenticator
+ if err := unmarshalItem(items[0], &a); err != nil {
+ return nil, err
+ }
+ return &a, nil
}
diff --git a/internal/storage/db/dynamodb/email_template.go b/internal/storage/db/dynamodb/email_template.go
index 4dfe25e1c..be9363789 100644
--- a/internal/storage/db/dynamodb/email_template.go
+++ b/internal/storage/db/dynamodb/email_template.go
@@ -5,8 +5,8 @@ import (
"errors"
"time"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
- "github.com/guregu/dynamo"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -14,29 +14,22 @@ import (
// AddEmailTemplate to add EmailTemplate
func (p *provider) AddEmailTemplate(ctx context.Context, emailTemplate *schemas.EmailTemplate) (*schemas.EmailTemplate, error) {
- collection := p.db.Table(schemas.Collections.EmailTemplate)
if emailTemplate.ID == "" {
emailTemplate.ID = uuid.New().String()
}
-
emailTemplate.Key = emailTemplate.ID
emailTemplate.CreatedAt = time.Now().Unix()
emailTemplate.UpdatedAt = time.Now().Unix()
- err := collection.Put(emailTemplate).RunWithContext(ctx)
-
- if err != nil {
+ if err := p.putItem(ctx, schemas.Collections.EmailTemplate, emailTemplate); err != nil {
return nil, err
}
-
return emailTemplate, nil
}
// UpdateEmailTemplate to update EmailTemplate
func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *schemas.EmailTemplate) (*schemas.EmailTemplate, error) {
- collection := p.db.Table(schemas.Collections.EmailTemplate)
emailTemplate.UpdatedAt = time.Now().Unix()
- err := UpdateByHashKey(collection, "id", emailTemplate.ID, emailTemplate)
- if err != nil {
+ if err := p.updateByHashKey(ctx, schemas.Collections.EmailTemplate, "id", emailTemplate.ID, emailTemplate); err != nil {
return nil, err
}
return emailTemplate, nil
@@ -44,27 +37,35 @@ func (p *provider) UpdateEmailTemplate(ctx context.Context, emailTemplate *schem
// ListEmailTemplates to list EmailTemplate
func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagination) ([]*schemas.EmailTemplate, *model.Pagination, error) {
- var emailTemplate *schemas.EmailTemplate
- var iter dynamo.PagingIter
- var lastEval dynamo.PagingKey
- var iteration int64 = 0
- collection := p.db.Table(schemas.Collections.EmailTemplate)
- emailTemplates := []*schemas.EmailTemplate{}
+ var lastKey map[string]types.AttributeValue
+ var iteration int64
paginationClone := pagination
- scanner := collection.Scan()
- count, err := scanner.Count()
+ var emailTemplates []*schemas.EmailTemplate
+
+ count, err := p.scanCount(ctx, schemas.Collections.EmailTemplate, nil)
if err != nil {
return nil, nil, err
}
+
for (paginationClone.Offset + paginationClone.Limit) > iteration {
- iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter()
- for iter.NextWithContext(ctx, &emailTemplate) {
+ items, next, err := p.scanPageIter(ctx, schemas.Collections.EmailTemplate, nil, int32(paginationClone.Limit), lastKey)
+ if err != nil {
+ return nil, nil, err
+ }
+ for _, it := range items {
+ var e schemas.EmailTemplate
+ if err := unmarshalItem(it, &e); err != nil {
+ return nil, nil, err
+ }
if paginationClone.Offset == iteration {
- emailTemplates = append(emailTemplates, emailTemplate)
+ emailTemplates = append(emailTemplates, &e)
}
}
- lastEval = iter.LastEvaluatedKey()
+ lastKey = next
iteration += paginationClone.Limit
+ if lastKey == nil {
+ break
+ }
}
paginationClone.Total = count
return emailTemplates, paginationClone, nil
@@ -72,39 +73,34 @@ func (p *provider) ListEmailTemplate(ctx context.Context, pagination *model.Pagi
// GetEmailTemplateByID to get EmailTemplate by id
func (p *provider) GetEmailTemplateByID(ctx context.Context, emailTemplateID string) (*schemas.EmailTemplate, error) {
- collection := p.db.Table(schemas.Collections.EmailTemplate)
- var emailTemplate *schemas.EmailTemplate
- err := collection.Get("id", emailTemplateID).OneWithContext(ctx, &emailTemplate)
- if err != nil {
+ var e schemas.EmailTemplate
+ if err := p.getItemByHash(ctx, schemas.Collections.EmailTemplate, "id", emailTemplateID, &e); err != nil {
return nil, err
}
- return emailTemplate, nil
+ return &e, nil
}
// GetEmailTemplateByEventName to get EmailTemplate by event_name
func (p *provider) GetEmailTemplateByEventName(ctx context.Context, eventName string) (*schemas.EmailTemplate, error) {
- collection := p.db.Table(schemas.Collections.EmailTemplate)
- var emailTemplates []*schemas.EmailTemplate
- var emailTemplate *schemas.EmailTemplate
- err := collection.Scan().Index("event_name").Filter("'event_name' = ?", eventName).Limit(1).AllWithContext(ctx, &emailTemplates)
+ // Query the event_name GSI — Scan+Limit applies before filters and can return zero matching items.
+ items, err := p.queryEq(ctx, schemas.Collections.EmailTemplate, "event_name", "event_name", eventName, nil)
if err != nil {
return nil, err
}
- if len(emailTemplates) == 0 {
+ if len(items) == 0 {
return nil, errors.New("no record found")
-
}
- emailTemplate = emailTemplates[0]
- return emailTemplate, nil
+ var e schemas.EmailTemplate
+ if err := unmarshalItem(items[0], &e); err != nil {
+ return nil, err
+ }
+ return &e, nil
}
// DeleteEmailTemplate to delete EmailTemplate
func (p *provider) DeleteEmailTemplate(ctx context.Context, emailTemplate *schemas.EmailTemplate) error {
- collection := p.db.Table(schemas.Collections.EmailTemplate)
- err := collection.Delete("id", emailTemplate.ID).RunWithContext(ctx)
- if err != nil {
- return err
+ if emailTemplate == nil {
+ return nil
}
-
- return nil
+ return p.deleteItemByHash(ctx, schemas.Collections.EmailTemplate, "id", emailTemplate.ID)
}
diff --git a/internal/storage/db/dynamodb/env.go b/internal/storage/db/dynamodb/env.go
index 7d0318c1b..320dcc82a 100644
--- a/internal/storage/db/dynamodb/env.go
+++ b/internal/storage/db/dynamodb/env.go
@@ -13,15 +13,13 @@ import (
// AddEnv to save environment information in database
func (p *provider) AddEnv(ctx context.Context, env *schemas.Env) (*schemas.Env, error) {
- collection := p.db.Table(schemas.Collections.Env)
if env.ID == "" {
env.ID = uuid.New().String()
}
env.Key = env.ID
env.CreatedAt = time.Now().Unix()
env.UpdatedAt = time.Now().Unix()
- err := collection.Put(env).RunWithContext(ctx)
- if err != nil {
+ if err := p.putItem(ctx, schemas.Collections.Env, env); err != nil {
return nil, err
}
return env, nil
@@ -29,10 +27,8 @@ func (p *provider) AddEnv(ctx context.Context, env *schemas.Env) (*schemas.Env,
// UpdateEnv to update environment information in database
func (p *provider) UpdateEnv(ctx context.Context, env *schemas.Env) (*schemas.Env, error) {
- collection := p.db.Table(schemas.Collections.Env)
env.UpdatedAt = time.Now().Unix()
- err := UpdateByHashKey(collection, "id", env.ID, env)
- if err != nil {
+ if err := p.updateByHashKey(ctx, schemas.Collections.Env, "id", env.ID, env); err != nil {
return nil, err
}
return env, nil
@@ -40,20 +36,16 @@ func (p *provider) UpdateEnv(ctx context.Context, env *schemas.Env) (*schemas.En
// GetEnv to get environment information from database
func (p *provider) GetEnv(ctx context.Context) (*schemas.Env, error) {
- var env *schemas.Env
- collection := p.db.Table(schemas.Collections.Env)
- // As there is no Find one supported.
- iter := collection.Scan().Limit(1).Iter()
- for iter.NextWithContext(ctx, &env) {
- if env == nil {
- return nil, errors.New("no documets found")
- } else {
- return env, nil
- }
- }
- err := iter.Err()
+ items, err := p.scanFilteredLimit(ctx, schemas.Collections.Env, nil, nil, 1)
if err != nil {
- return env, fmt.Errorf("config not found")
+ return nil, err
}
- return env, nil
+ if len(items) == 0 {
+ return nil, errors.New("no documets found")
+ }
+ var env schemas.Env
+ if err := unmarshalItem(items[0], &env); err != nil {
+ return nil, fmt.Errorf("config not found")
+ }
+ return &env, nil
}
diff --git a/internal/storage/db/dynamodb/health_check.go b/internal/storage/db/dynamodb/health_check.go
index 17fbb1a67..56173e5b6 100644
--- a/internal/storage/db/dynamodb/health_check.go
+++ b/internal/storage/db/dynamodb/health_check.go
@@ -9,5 +9,17 @@ import (
// HealthCheck verifies that the DynamoDB backend is reachable and responsive
func (p *provider) HealthCheck(ctx context.Context) error {
var envs []schemas.Env
- return p.db.Table(schemas.Collections.Env).Scan().Limit(1).AllWithContext(ctx, &envs)
+ items, err := p.scanFilteredLimit(ctx, schemas.Collections.Env, nil, nil, 1)
+ if err != nil {
+ return err
+ }
+ for _, it := range items {
+ var e schemas.Env
+ if err := unmarshalItem(it, &e); err != nil {
+ return err
+ }
+ envs = append(envs, e)
+ }
+ _ = envs
+ return nil
}
diff --git a/internal/storage/db/dynamodb/marshal.go b/internal/storage/db/dynamodb/marshal.go
new file mode 100644
index 000000000..944525a59
--- /dev/null
+++ b/internal/storage/db/dynamodb/marshal.go
@@ -0,0 +1,141 @@
+package dynamodb
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
+)
+
+func dynamoFieldNameFromTag(tag string) string {
+ if tag == "" || tag == "-" {
+ return ""
+ }
+ parts := strings.Split(tag, ",")
+ name := strings.TrimSpace(parts[0])
+ if name == "" {
+ return ""
+ }
+ return name
+}
+
+func dynamoOmitEmpty(tag string) bool {
+ return strings.Contains(tag, "omitempty")
+}
+
+func marshalStruct(v interface{}) (map[string]types.AttributeValue, error) {
+ rv := reflect.ValueOf(v)
+ if rv.Kind() == reflect.Ptr {
+ if rv.IsNil() {
+ return nil, fmt.Errorf("nil value")
+ }
+ rv = rv.Elem()
+ }
+ if rv.Kind() != reflect.Struct {
+ return nil, fmt.Errorf("expected struct or pointer to struct")
+ }
+ rt := rv.Type()
+ out := make(map[string]types.AttributeValue)
+ for i := 0; i < rv.NumField(); i++ {
+ sf := rt.Field(i)
+ if !sf.IsExported() {
+ continue
+ }
+ tag := sf.Tag.Get("dynamo")
+ name := dynamoFieldNameFromTag(tag)
+ if name == "" {
+ continue
+ }
+ fv := rv.Field(i)
+ if dynamoOmitEmpty(tag) && isEmptyValue(fv) {
+ continue
+ }
+ if fv.Kind() == reflect.Ptr && fv.IsNil() {
+ continue
+ }
+ av, err := attributevalue.Marshal(fv.Interface())
+ if err != nil {
+ return nil, err
+ }
+ out[name] = av
+ }
+ return out, nil
+}
+
+func isEmptyValue(v reflect.Value) bool {
+ if !v.IsValid() {
+ return true
+ }
+ switch v.Kind() {
+ case reflect.String:
+ return v.Len() == 0
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return v.Int() == 0
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return v.Uint() == 0
+ case reflect.Bool:
+ return !v.Bool()
+ case reflect.Ptr, reflect.Interface:
+ return v.IsNil()
+ default:
+ return false
+ }
+}
+
+func marshalMapStringInterface(m map[string]interface{}) (map[string]types.AttributeValue, error) {
+ out := make(map[string]types.AttributeValue, len(m))
+ for k, v := range m {
+ av, err := attributevalue.Marshal(v)
+ if err != nil {
+ return nil, err
+ }
+ out[k] = av
+ }
+ return out, nil
+}
+
+func unmarshalItem(av map[string]types.AttributeValue, out interface{}) error {
+ rv := reflect.ValueOf(out)
+ if rv.Kind() != reflect.Ptr || rv.IsNil() {
+ return fmt.Errorf("out must be non-nil pointer")
+ }
+ ev := rv.Elem()
+ if ev.Kind() != reflect.Struct {
+ return fmt.Errorf("out must be pointer to struct")
+ }
+ et := ev.Type()
+ for i := 0; i < ev.NumField(); i++ {
+ sf := et.Field(i)
+ if !sf.IsExported() {
+ continue
+ }
+ tag := sf.Tag.Get("dynamo")
+ name := dynamoFieldNameFromTag(tag)
+ if name == "" {
+ continue
+ }
+ dv, ok := av[name]
+ if !ok {
+ continue
+ }
+ field := ev.Field(i)
+ if !field.CanSet() {
+ continue
+ }
+ if field.Kind() == reflect.Ptr {
+ if field.IsNil() {
+ field.Set(reflect.New(field.Type().Elem()))
+ }
+ if err := attributevalue.Unmarshal(dv, field.Interface()); err != nil {
+ return err
+ }
+ continue
+ }
+ if err := attributevalue.Unmarshal(dv, field.Addr().Interface()); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/internal/storage/db/dynamodb/ops.go b/internal/storage/db/dynamodb/ops.go
new file mode 100644
index 000000000..53f30c1ed
--- /dev/null
+++ b/internal/storage/db/dynamodb/ops.go
@@ -0,0 +1,316 @@
+package dynamodb
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
+)
+
+func (p *provider) putItem(ctx context.Context, table string, v interface{}) error {
+ item, err := marshalStruct(v)
+ if err != nil {
+ return err
+ }
+ _, err = p.client.PutItem(ctx, &dynamodb.PutItemInput{
+ TableName: aws.String(table),
+ Item: item,
+ })
+ return err
+}
+
+func (p *provider) getItemByHash(ctx context.Context, table, hashKey, hashValue string, out interface{}) error {
+ res, err := p.client.GetItem(ctx, &dynamodb.GetItemInput{
+ TableName: aws.String(table),
+ Key: map[string]types.AttributeValue{
+ hashKey: &types.AttributeValueMemberS{Value: hashValue},
+ },
+ })
+ if err != nil {
+ return err
+ }
+ if len(res.Item) == 0 {
+ return fmt.Errorf("no record found")
+ }
+ return unmarshalItem(res.Item, out)
+}
+
+func (p *provider) deleteItemByHash(ctx context.Context, table, hashKey, hashValue string) error {
+ _, err := p.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{
+ TableName: aws.String(table),
+ Key: map[string]types.AttributeValue{
+ hashKey: &types.AttributeValueMemberS{Value: hashValue},
+ },
+ })
+ return err
+}
+
+func (p *provider) scanAllRaw(ctx context.Context, table string, index *string, filter *expression.ConditionBuilder) ([]map[string]types.AttributeValue, error) {
+ var built expression.Expression
+ var hasFilter bool
+ if filter != nil {
+ e, err := expression.NewBuilder().WithFilter(*filter).Build()
+ if err != nil {
+ return nil, err
+ }
+ built = e
+ hasFilter = true
+ }
+ var out []map[string]types.AttributeValue
+ var start map[string]types.AttributeValue
+ for {
+ in := &dynamodb.ScanInput{
+ TableName: aws.String(table),
+ ExclusiveStartKey: start,
+ }
+ if index != nil {
+ in.IndexName = index
+ }
+ if hasFilter {
+ in.FilterExpression = built.Filter()
+ in.ExpressionAttributeNames = built.Names()
+ in.ExpressionAttributeValues = built.Values()
+ }
+ res, err := p.client.Scan(ctx, in)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, res.Items...)
+ if res.LastEvaluatedKey == nil {
+ break
+ }
+ start = res.LastEvaluatedKey
+ }
+ return out, nil
+}
+
+func (p *provider) scanCount(ctx context.Context, table string, filter *expression.ConditionBuilder) (int64, error) {
+ var built expression.Expression
+ var hasFilter bool
+ if filter != nil {
+ e, err := expression.NewBuilder().WithFilter(*filter).Build()
+ if err != nil {
+ return 0, err
+ }
+ built = e
+ hasFilter = true
+ }
+ var total int64
+ var start map[string]types.AttributeValue
+ for {
+ in := &dynamodb.ScanInput{
+ TableName: aws.String(table),
+ Select: types.SelectCount,
+ ExclusiveStartKey: start,
+ }
+ if hasFilter {
+ in.FilterExpression = built.Filter()
+ in.ExpressionAttributeNames = built.Names()
+ in.ExpressionAttributeValues = built.Values()
+ }
+ res, err := p.client.Scan(ctx, in)
+ if err != nil {
+ return 0, err
+ }
+ total += int64(res.Count)
+ if res.LastEvaluatedKey == nil {
+ break
+ }
+ start = res.LastEvaluatedKey
+ }
+ return total, nil
+}
+
+func (p *provider) queryEq(ctx context.Context, table, indexName, pkAttr, pkVal string, filter *expression.ConditionBuilder) ([]map[string]types.AttributeValue, error) {
+ kc := expression.Key(pkAttr).Equal(expression.Value(pkVal))
+ var eb expression.Builder
+ if filter != nil {
+ eb = expression.NewBuilder().WithKeyCondition(kc).WithFilter(*filter)
+ } else {
+ eb = expression.NewBuilder().WithKeyCondition(kc)
+ }
+ expr, err := eb.Build()
+ if err != nil {
+ return nil, err
+ }
+ var out []map[string]types.AttributeValue
+ var start map[string]types.AttributeValue
+ for {
+ in := &dynamodb.QueryInput{
+ TableName: aws.String(table),
+ IndexName: aws.String(indexName),
+ KeyConditionExpression: expr.KeyCondition(),
+ ExpressionAttributeNames: expr.Names(),
+ ExpressionAttributeValues: expr.Values(),
+ ExclusiveStartKey: start,
+ }
+ if filter != nil {
+ in.FilterExpression = expr.Filter()
+ }
+ res, err := p.client.Query(ctx, in)
+ if err != nil {
+ return nil, err
+ }
+ out = append(out, res.Items...)
+ if res.LastEvaluatedKey == nil {
+ break
+ }
+ start = res.LastEvaluatedKey
+ }
+ return out, nil
+}
+
+func (p *provider) queryEqLimit(ctx context.Context, table, indexName, pkAttr, pkVal string, filter *expression.ConditionBuilder, limit int32) ([]map[string]types.AttributeValue, error) {
+ kc := expression.Key(pkAttr).Equal(expression.Value(pkVal))
+ var eb expression.Builder
+ if filter != nil {
+ eb = expression.NewBuilder().WithKeyCondition(kc).WithFilter(*filter)
+ } else {
+ eb = expression.NewBuilder().WithKeyCondition(kc)
+ }
+ expr, err := eb.Build()
+ if err != nil {
+ return nil, err
+ }
+ in := &dynamodb.QueryInput{
+ TableName: aws.String(table),
+ IndexName: aws.String(indexName),
+ KeyConditionExpression: expr.KeyCondition(),
+ ExpressionAttributeNames: expr.Names(),
+ ExpressionAttributeValues: expr.Values(),
+ Limit: aws.Int32(limit),
+ }
+ if filter != nil {
+ in.FilterExpression = expr.Filter()
+ }
+ res, err := p.client.Query(ctx, in)
+ if err != nil {
+ return nil, err
+ }
+ return res.Items, nil
+}
+
+func (p *provider) scanFilteredLimit(ctx context.Context, table string, index *string, filter *expression.ConditionBuilder, limit int32) ([]map[string]types.AttributeValue, error) {
+ in := &dynamodb.ScanInput{
+ TableName: aws.String(table),
+ Limit: aws.Int32(limit),
+ }
+ if index != nil {
+ in.IndexName = index
+ }
+ if filter != nil {
+ expr, err := expression.NewBuilder().WithFilter(*filter).Build()
+ if err != nil {
+ return nil, err
+ }
+ in.FilterExpression = expr.Filter()
+ in.ExpressionAttributeNames = expr.Names()
+ in.ExpressionAttributeValues = expr.Values()
+ }
+ res, err := p.client.Scan(ctx, in)
+ if err != nil {
+ return nil, err
+ }
+ return res.Items, nil
+}
+
+func (p *provider) scanFilteredAll(ctx context.Context, table string, index *string, filter *expression.ConditionBuilder) ([]map[string]types.AttributeValue, error) {
+ return p.scanAllRaw(ctx, table, index, filter)
+}
+
+func (p *provider) updateByHashKey(ctx context.Context, tableName, hashKeyName, hashValue string, item interface{}) error {
+ return p.updateByHashKeyWithRemoves(ctx, tableName, hashKeyName, hashValue, item, nil)
+}
+
+// updateByHashKeyWithRemoves runs UpdateItem with SET from marshalled fields and optional REMOVE
+// of attribute names (e.g. when mapping SQL NULL for optional pointer fields — nil is omitted from
+// SET but the old DynamoDB attribute must be explicitly removed).
+func (p *provider) updateByHashKeyWithRemoves(ctx context.Context, tableName, hashKeyName, hashValue string, item interface{}, removeAttrs []string) error {
+ var attrs map[string]types.AttributeValue
+ var err error
+ switch m := item.(type) {
+ case map[string]interface{}:
+ attrs, err = marshalMapStringInterface(m)
+ default:
+ attrs, err = marshalStruct(item)
+ }
+ if err != nil {
+ return err
+ }
+ delete(attrs, hashKeyName)
+
+ names := map[string]string{}
+ vals := map[string]types.AttributeValue{}
+ var sets []string
+ i := 0
+ for k, v := range attrs {
+ nk := "#n" + fmt.Sprint(i)
+ vk := ":v" + fmt.Sprint(i)
+ names[nk] = k
+ vals[vk] = v
+ sets = append(sets, nk+" = "+vk)
+ i++
+ }
+
+ var removeParts []string
+ for j, attr := range removeAttrs {
+ rk := "#r" + fmt.Sprint(j)
+ names[rk] = attr
+ removeParts = append(removeParts, rk)
+ }
+
+ if len(sets) == 0 && len(removeParts) == 0 {
+ return nil
+ }
+
+ var exprParts []string
+ if len(sets) > 0 {
+ exprParts = append(exprParts, "SET "+strings.Join(sets, ", "))
+ }
+ if len(removeParts) > 0 {
+ exprParts = append(exprParts, "REMOVE "+strings.Join(removeParts, ", "))
+ }
+ updateExpr := strings.Join(exprParts, " ")
+
+ in := &dynamodb.UpdateItemInput{
+ TableName: aws.String(tableName),
+ Key: map[string]types.AttributeValue{
+ hashKeyName: &types.AttributeValueMemberS{Value: hashValue},
+ },
+ UpdateExpression: aws.String(updateExpr),
+ ExpressionAttributeNames: names,
+ }
+ if len(vals) > 0 {
+ in.ExpressionAttributeValues = vals
+ }
+ _, err = p.client.UpdateItem(ctx, in)
+ return err
+}
+
+func (p *provider) scanPageIter(ctx context.Context, table string, filter *expression.ConditionBuilder, pageLimit int32, startKey map[string]types.AttributeValue) ([]map[string]types.AttributeValue, map[string]types.AttributeValue, error) {
+ in := &dynamodb.ScanInput{
+ TableName: aws.String(table),
+ Limit: aws.Int32(pageLimit),
+ ExclusiveStartKey: startKey,
+ }
+ if filter != nil {
+ expr, err := expression.NewBuilder().WithFilter(*filter).Build()
+ if err != nil {
+ return nil, nil, err
+ }
+ in.FilterExpression = expr.Filter()
+ in.ExpressionAttributeNames = expr.Names()
+ in.ExpressionAttributeValues = expr.Values()
+ }
+ res, err := p.client.Scan(ctx, in)
+ if err != nil {
+ return nil, nil, err
+ }
+ return res.Items, res.LastEvaluatedKey, nil
+}
+
+func strPtr(s string) *string { return &s }
diff --git a/internal/storage/db/dynamodb/otp.go b/internal/storage/db/dynamodb/otp.go
index 5e82eaf02..e305feadb 100644
--- a/internal/storage/db/dynamodb/otp.go
+++ b/internal/storage/db/dynamodb/otp.go
@@ -5,6 +5,7 @@ import (
"errors"
"time"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
"github.com/google/uuid"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -12,7 +13,6 @@ import (
// UpsertOTP to add or update otp
func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schemas.OTP, error) {
- // check if email or phone number is present
if otpParam.Email == "" && otpParam.PhoneNumber == "" {
return nil, errors.New("email or phone_number is required")
}
@@ -43,13 +43,12 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schem
otp.Otp = otpParam.Otp
otp.ExpiresAt = otpParam.ExpiresAt
}
- collection := p.db.Table(schemas.Collections.OTP)
otp.UpdatedAt = time.Now().Unix()
var err error
if shouldCreate {
- err = collection.Put(otp).RunWithContext(ctx)
+ err = p.putItem(ctx, schemas.Collections.OTP, otp)
} else {
- err = UpdateByHashKey(collection, "id", otp.ID, otp)
+ err = p.updateByHashKey(ctx, schemas.Collections.OTP, "id", otp.ID, otp)
}
if err != nil {
return nil, err
@@ -59,44 +58,41 @@ func (p *provider) UpsertOTP(ctx context.Context, otpParam *schemas.OTP) (*schem
// GetOTPByEmail to get otp for a given email address
func (p *provider) GetOTPByEmail(ctx context.Context, emailAddress string) (*schemas.OTP, error) {
- var otps []schemas.OTP
- var otp schemas.OTP
- collection := p.db.Table(schemas.Collections.OTP)
- err := collection.Scan().Index("email").Filter("'email' = ?", emailAddress).Limit(1).AllWithContext(ctx, &otps)
+ items, err := p.queryEqLimit(ctx, schemas.Collections.OTP, "email", "email", emailAddress, nil, 1)
if err != nil {
return nil, err
}
- if len(otps) > 0 {
- otp = otps[0]
- return &otp, nil
+ if len(items) == 0 {
+ return nil, errors.New("no docuemnt found")
}
- return nil, errors.New("no docuemnt found")
+ var otp schemas.OTP
+ if err := unmarshalItem(items[0], &otp); err != nil {
+ return nil, err
+ }
+ return &otp, nil
}
// GetOTPByPhoneNumber to get otp for a given phone number
func (p *provider) GetOTPByPhoneNumber(ctx context.Context, phoneNumber string) (*schemas.OTP, error) {
- var otps []schemas.OTP
- var otp schemas.OTP
- collection := p.db.Table(schemas.Collections.OTP)
- err := collection.Scan().Filter("'phone_number' = ?", phoneNumber).Limit(1).AllWithContext(ctx, &otps)
+ f := expression.Name("phone_number").Equal(expression.Value(phoneNumber))
+ items, err := p.scanAllRaw(ctx, schemas.Collections.OTP, nil, &f)
if err != nil {
return nil, err
}
- if len(otps) > 0 {
- otp = otps[0]
- return &otp, nil
+ if len(items) == 0 {
+ return nil, errors.New("no docuemnt found")
}
- return nil, errors.New("no docuemnt found")
+ var otp schemas.OTP
+ if err := unmarshalItem(items[0], &otp); err != nil {
+ return nil, err
+ }
+ return &otp, nil
}
// DeleteOTP to delete otp
func (p *provider) DeleteOTP(ctx context.Context, otp *schemas.OTP) error {
- collection := p.db.Table(schemas.Collections.OTP)
- if otp.ID != "" {
- err := collection.Delete("id", otp.ID).RunWithContext(ctx)
- if err != nil {
- return err
- }
+ if otp == nil || otp.ID == "" {
+ return nil
}
- return nil
+ return p.deleteItemByHash(ctx, schemas.Collections.OTP, "id", otp.ID)
}
diff --git a/internal/storage/db/dynamodb/provider.go b/internal/storage/db/dynamodb/provider.go
index 7a305309a..798f0a7b2 100644
--- a/internal/storage/db/dynamodb/provider.go
+++ b/internal/storage/db/dynamodb/provider.go
@@ -5,14 +5,13 @@ import (
"fmt"
"time"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/credentials"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/guregu/dynamo"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ awsconfig "github.com/aws/aws-sdk-go-v2/config"
+ "github.com/aws/aws-sdk-go-v2/credentials"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/rs/zerolog"
"github.com/authorizerdev/authorizer/internal/config"
- "github.com/authorizerdev/authorizer/internal/storage/schemas"
)
// Dependencies struct the dynamodb data store provider
@@ -23,75 +22,75 @@ type Dependencies struct {
type provider struct {
config *config.Config
dependencies *Dependencies
- db *dynamo.DB
+ client *dynamodb.Client
}
-// NewProvider returns a new Dynamo provider
+// NewProvider returns a new Dynamo provider using AWS SDK for Go v2.
func NewProvider(cfg *config.Config, deps *Dependencies) (*provider, error) {
dbURL := cfg.DatabaseURL
awsRegion := cfg.AWSRegion
awsAccessKeyID := cfg.AWSAccessKeyID
awsSecretAccessKey := cfg.AWSSecretAccessKey
- awsCfg := aws.Config{
- MaxRetries: aws.Int(3),
- CredentialsChainVerboseErrors: aws.Bool(true), // for full error logs
+ region := awsRegion
+ if region == "" {
+ region = "us-east-1"
}
- if awsRegion != "" {
- awsCfg.Region = aws.String(awsRegion)
+ loadOpts := []func(*awsconfig.LoadOptions) error{
+ awsconfig.WithRegion(region),
}
- // custom awsAccessKeyID, awsSecretAccessKey took first priority, if not then fetch config from aws credentials
+
if awsAccessKeyID != "" && awsSecretAccessKey != "" {
- awsCfg.Credentials = credentials.NewStaticCredentials(awsAccessKeyID, awsSecretAccessKey, "")
+ loadOpts = append(loadOpts, awsconfig.WithCredentialsProvider(
+ credentials.NewStaticCredentialsProvider(awsAccessKeyID, awsSecretAccessKey, "")))
} else if dbURL != "" {
deps.Log.Info().Msg("Using DB URL for dynamodb")
- // static config in case of testing or local-setup
- awsCfg.Credentials = credentials.NewStaticCredentials("key", "key", "")
- awsCfg.Endpoint = aws.String(dbURL)
+ loadOpts = append(loadOpts, awsconfig.WithCredentialsProvider(
+ credentials.NewStaticCredentialsProvider("key", "key", "")))
} else {
deps.Log.Info().Msg("Using default AWS credentials config from system for dynamodb")
}
- sess, err := session.NewSession(&awsCfg)
+
+ if dbURL != "" {
+ resolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
+ if service == dynamodb.ServiceID {
+ return aws.Endpoint{
+ URL: dbURL,
+ HostnameImmutable: true,
+ }, nil
+ }
+ return aws.Endpoint{}, &aws.EndpointNotFoundError{}
+ })
+ loadOpts = append(loadOpts, awsconfig.WithEndpointResolverWithOptions(resolver))
+ }
+
+ awsCfg, err := awsconfig.LoadDefaultConfig(context.Background(), loadOpts...)
if err != nil {
- return nil, fmt.Errorf("dynamodb session: %w", err)
+ return nil, fmt.Errorf("aws config: %w", err)
+ }
+
+ client := dynamodb.NewFromConfig(awsCfg, func(o *dynamodb.Options) {
+ o.RetryMaxAttempts = 3
+ })
+
+ p := &provider{
+ client: client,
+ config: cfg,
+ dependencies: deps,
}
- db := dynamo.New(sess)
createCtx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
- tables := []struct {
- name string
- model interface{}
- }{
- {schemas.Collections.User, schemas.User{}},
- {schemas.Collections.Session, schemas.Session{}},
- {schemas.Collections.EmailTemplate, schemas.EmailTemplate{}},
- {schemas.Collections.Env, schemas.Env{}},
- {schemas.Collections.OTP, schemas.OTP{}},
- {schemas.Collections.VerificationRequest, schemas.VerificationRequest{}},
- {schemas.Collections.Webhook, schemas.Webhook{}},
- {schemas.Collections.WebhookLog, schemas.WebhookLog{}},
- {schemas.Collections.Authenticators, schemas.Authenticator{}},
- {schemas.Collections.SessionToken, schemas.SessionToken{}},
- {schemas.Collections.MFASession, schemas.MFASession{}},
- {schemas.Collections.OAuthState, schemas.OAuthState{}},
- {schemas.Collections.AuditLog, schemas.AuditLog{}},
- }
- for _, tbl := range tables {
- if werr := db.CreateTable(tbl.name, tbl.model).WaitWithContext(createCtx); werr != nil {
- return nil, fmt.Errorf("dynamodb create/wait table %q: %w", tbl.name, werr)
- }
+ if err := p.ensureTables(createCtx); err != nil {
+ return nil, err
}
- return &provider{
- db: db,
- config: cfg,
- dependencies: deps,
- }, nil
+
+ return p, nil
}
-// Close is a no-op; the AWS SDK session needs no explicit shutdown for typical use.
+// Close is a no-op; the AWS SDK v2 client needs no explicit shutdown for typical use.
func (p *provider) Close() error {
return nil
}
diff --git a/internal/storage/db/dynamodb/session.go b/internal/storage/db/dynamodb/session.go
index 3a06ae406..4d008747b 100644
--- a/internal/storage/db/dynamodb/session.go
+++ b/internal/storage/db/dynamodb/session.go
@@ -11,14 +11,12 @@ import (
// AddSession to save session information in database
func (p *provider) AddSession(ctx context.Context, session *schemas.Session) error {
- collection := p.db.Table(schemas.Collections.Session)
if session.ID == "" {
session.ID = uuid.New().String()
}
session.CreatedAt = time.Now().Unix()
session.UpdatedAt = time.Now().Unix()
- err := collection.Put(session).RunWithContext(ctx)
- return err
+ return p.putItem(ctx, schemas.Collections.Session, session)
}
// DeleteSession to delete session information from database
diff --git a/internal/storage/db/dynamodb/session_token.go b/internal/storage/db/dynamodb/session_token.go
index f6c2be3f3..b004a74f6 100644
--- a/internal/storage/db/dynamodb/session_token.go
+++ b/internal/storage/db/dynamodb/session_token.go
@@ -6,6 +6,7 @@ import (
"strings"
"time"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
"github.com/google/uuid"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -23,47 +24,44 @@ func (p *provider) AddSessionToken(ctx context.Context, token *schemas.SessionTo
if token.UpdatedAt == 0 {
token.UpdatedAt = time.Now().Unix()
}
- collection := p.db.Table(schemas.Collections.SessionToken)
- return collection.Put(token).RunWithContext(ctx)
+ return p.putItem(ctx, schemas.Collections.SessionToken, token)
}
// GetSessionTokenByUserIDAndKey retrieves a session token by user ID and key
func (p *provider) GetSessionTokenByUserIDAndKey(ctx context.Context, userId, key string) (*schemas.SessionToken, error) {
- var tokens []schemas.SessionToken
- collection := p.db.Table(schemas.Collections.SessionToken)
- err := collection.Scan().
- Index("user_id").
- Filter("'user_id' = ? AND 'key_name' = ?", userId, key).
- Limit(1).
- AllWithContext(ctx, &tokens)
+ f := expression.Name("key_name").Equal(expression.Value(key))
+ items, err := p.queryEqLimit(ctx, schemas.Collections.SessionToken, "user_id", "user_id", userId, &f, 1)
if err != nil {
return nil, err
}
- if len(tokens) == 0 {
+ if len(items) == 0 {
return nil, errors.New("session token not found")
}
- return &tokens[0], nil
+ var t schemas.SessionToken
+ if err := unmarshalItem(items[0], &t); err != nil {
+ return nil, err
+ }
+ return &t, nil
}
// DeleteSessionToken deletes a session token by ID
func (p *provider) DeleteSessionToken(ctx context.Context, id string) error {
- collection := p.db.Table(schemas.Collections.SessionToken)
- return collection.Delete("id", id).RunWithContext(ctx)
+ return p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", id)
}
// DeleteSessionTokenByUserIDAndKey deletes a session token by user ID and key
func (p *provider) DeleteSessionTokenByUserIDAndKey(ctx context.Context, userId, key string) error {
- var tokens []schemas.SessionToken
- collection := p.db.Table(schemas.Collections.SessionToken)
- err := collection.Scan().
- Index("user_id").
- Filter("'user_id' = ? AND 'key_name' = ?", userId, key).
- AllWithContext(ctx, &tokens)
+ f := expression.Name("key_name").Equal(expression.Value(key))
+ items, err := p.queryEq(ctx, schemas.Collections.SessionToken, "user_id", "user_id", userId, &f)
if err != nil {
return err
}
- for _, token := range tokens {
- if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil {
+ for _, it := range items {
+ var t schemas.SessionToken
+ if err := unmarshalItem(it, &t); err != nil {
+ return err
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil {
return err
}
}
@@ -72,15 +70,17 @@ func (p *provider) DeleteSessionTokenByUserIDAndKey(ctx context.Context, userId,
// DeleteAllSessionTokensByUserID deletes all session tokens for a user ID
func (p *provider) DeleteAllSessionTokensByUserID(ctx context.Context, userId string) error {
- var tokens []schemas.SessionToken
- collection := p.db.Table(schemas.Collections.SessionToken)
- err := collection.Scan().AllWithContext(ctx, &tokens)
+ items, err := p.scanAllRaw(ctx, schemas.Collections.SessionToken, nil, nil)
if err != nil {
return err
}
- for _, token := range tokens {
- if strings.Contains(token.UserID, userId) {
- if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil {
+ for _, it := range items {
+ var t schemas.SessionToken
+ if err := unmarshalItem(it, &t); err != nil {
+ return err
+ }
+ if strings.Contains(t.UserID, userId) {
+ if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil {
return err
}
}
@@ -91,15 +91,17 @@ func (p *provider) DeleteAllSessionTokensByUserID(ctx context.Context, userId st
// DeleteSessionTokensByNamespace deletes all session tokens for a namespace
func (p *provider) DeleteSessionTokensByNamespace(ctx context.Context, namespace string) error {
prefix := namespace + ":"
- var tokens []schemas.SessionToken
- collection := p.db.Table(schemas.Collections.SessionToken)
- err := collection.Scan().AllWithContext(ctx, &tokens)
+ items, err := p.scanAllRaw(ctx, schemas.Collections.SessionToken, nil, nil)
if err != nil {
return err
}
- for _, token := range tokens {
- if strings.HasPrefix(token.UserID, prefix) {
- if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil {
+ for _, it := range items {
+ var t schemas.SessionToken
+ if err := unmarshalItem(it, &t); err != nil {
+ return err
+ }
+ if strings.HasPrefix(t.UserID, prefix) {
+ if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil {
return err
}
}
@@ -110,14 +112,17 @@ func (p *provider) DeleteSessionTokensByNamespace(ctx context.Context, namespace
// CleanExpiredSessionTokens removes expired session tokens from the database
func (p *provider) CleanExpiredSessionTokens(ctx context.Context) error {
currentTime := time.Now().Unix()
- var tokens []schemas.SessionToken
- collection := p.db.Table(schemas.Collections.SessionToken)
- err := collection.Scan().Filter("'expires_at' < ?", currentTime).AllWithContext(ctx, &tokens)
+ f := expression.Name("expires_at").LessThan(expression.Value(currentTime))
+ items, err := p.scanFilteredAll(ctx, schemas.Collections.SessionToken, nil, &f)
if err != nil {
return err
}
- for _, token := range tokens {
- if err := collection.Delete("id", token.ID).RunWithContext(ctx); err != nil {
+ for _, it := range items {
+ var t schemas.SessionToken
+ if err := unmarshalItem(it, &t); err != nil {
+ return err
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.SessionToken, "id", t.ID); err != nil {
return err
}
}
@@ -126,15 +131,17 @@ func (p *provider) CleanExpiredSessionTokens(ctx context.Context) error {
// GetAllSessionTokens retrieves all session tokens (for testing)
func (p *provider) GetAllSessionTokens(ctx context.Context) ([]*schemas.SessionToken, error) {
- var tokens []schemas.SessionToken
- collection := p.db.Table(schemas.Collections.SessionToken)
- err := collection.Scan().AllWithContext(ctx, &tokens)
+ items, err := p.scanAllRaw(ctx, schemas.Collections.SessionToken, nil, nil)
if err != nil {
return nil, err
}
var result []*schemas.SessionToken
- for i := range tokens {
- result = append(result, &tokens[i])
+ for _, it := range items {
+ var t schemas.SessionToken
+ if err := unmarshalItem(it, &t); err != nil {
+ return nil, err
+ }
+ result = append(result, &t)
}
return result, nil
}
@@ -151,47 +158,44 @@ func (p *provider) AddMFASession(ctx context.Context, session *schemas.MFASessio
if session.UpdatedAt == 0 {
session.UpdatedAt = time.Now().Unix()
}
- collection := p.db.Table(schemas.Collections.MFASession)
- return collection.Put(session).RunWithContext(ctx)
+ return p.putItem(ctx, schemas.Collections.MFASession, session)
}
// GetMFASessionByUserIDAndKey retrieves an MFA session by user ID and key
func (p *provider) GetMFASessionByUserIDAndKey(ctx context.Context, userId, key string) (*schemas.MFASession, error) {
- var sessions []schemas.MFASession
- collection := p.db.Table(schemas.Collections.MFASession)
- err := collection.Scan().
- Index("user_id").
- Filter("'user_id' = ? AND 'key_name' = ?", userId, key).
- Limit(1).
- AllWithContext(ctx, &sessions)
+ f := expression.Name("key_name").Equal(expression.Value(key))
+ items, err := p.queryEqLimit(ctx, schemas.Collections.MFASession, "user_id", "user_id", userId, &f, 1)
if err != nil {
return nil, err
}
- if len(sessions) == 0 {
+ if len(items) == 0 {
return nil, errors.New("MFA session not found")
}
- return &sessions[0], nil
+ var s schemas.MFASession
+ if err := unmarshalItem(items[0], &s); err != nil {
+ return nil, err
+ }
+ return &s, nil
}
// DeleteMFASession deletes an MFA session by ID
func (p *provider) DeleteMFASession(ctx context.Context, id string) error {
- collection := p.db.Table(schemas.Collections.MFASession)
- return collection.Delete("id", id).RunWithContext(ctx)
+ return p.deleteItemByHash(ctx, schemas.Collections.MFASession, "id", id)
}
// DeleteMFASessionByUserIDAndKey deletes an MFA session by user ID and key
func (p *provider) DeleteMFASessionByUserIDAndKey(ctx context.Context, userId, key string) error {
- var sessions []schemas.MFASession
- collection := p.db.Table(schemas.Collections.MFASession)
- err := collection.Scan().
- Index("user_id").
- Filter("'user_id' = ? AND 'key_name' = ?", userId, key).
- AllWithContext(ctx, &sessions)
+ f := expression.Name("key_name").Equal(expression.Value(key))
+ items, err := p.queryEq(ctx, schemas.Collections.MFASession, "user_id", "user_id", userId, &f)
if err != nil {
return err
}
- for _, session := range sessions {
- if err := collection.Delete("id", session.ID).RunWithContext(ctx); err != nil {
+ for _, it := range items {
+ var s schemas.MFASession
+ if err := unmarshalItem(it, &s); err != nil {
+ return err
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.MFASession, "id", s.ID); err != nil {
return err
}
}
@@ -200,18 +204,17 @@ func (p *provider) DeleteMFASessionByUserIDAndKey(ctx context.Context, userId, k
// GetAllMFASessionsByUserID retrieves all MFA sessions for a user ID
func (p *provider) GetAllMFASessionsByUserID(ctx context.Context, userId string) ([]*schemas.MFASession, error) {
- var sessions []schemas.MFASession
- collection := p.db.Table(schemas.Collections.MFASession)
- err := collection.Scan().
- Index("user_id").
- Filter("'user_id' = ?", userId).
- AllWithContext(ctx, &sessions)
+ items, err := p.queryEq(ctx, schemas.Collections.MFASession, "user_id", "user_id", userId, nil)
if err != nil {
return nil, err
}
var result []*schemas.MFASession
- for i := range sessions {
- result = append(result, &sessions[i])
+ for _, it := range items {
+ var s schemas.MFASession
+ if err := unmarshalItem(it, &s); err != nil {
+ return nil, err
+ }
+ result = append(result, &s)
}
return result, nil
}
@@ -219,14 +222,17 @@ func (p *provider) GetAllMFASessionsByUserID(ctx context.Context, userId string)
// CleanExpiredMFASessions removes expired MFA sessions from the database
func (p *provider) CleanExpiredMFASessions(ctx context.Context) error {
currentTime := time.Now().Unix()
- var sessions []schemas.MFASession
- collection := p.db.Table(schemas.Collections.MFASession)
- err := collection.Scan().Filter("'expires_at' < ?", currentTime).AllWithContext(ctx, &sessions)
+ f := expression.Name("expires_at").LessThan(expression.Value(currentTime))
+ items, err := p.scanFilteredAll(ctx, schemas.Collections.MFASession, nil, &f)
if err != nil {
return err
}
- for _, session := range sessions {
- if err := collection.Delete("id", session.ID).RunWithContext(ctx); err != nil {
+ for _, it := range items {
+ var s schemas.MFASession
+ if err := unmarshalItem(it, &s); err != nil {
+ return err
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.MFASession, "id", s.ID); err != nil {
return err
}
}
@@ -235,15 +241,17 @@ func (p *provider) CleanExpiredMFASessions(ctx context.Context) error {
// GetAllMFASessions retrieves all MFA sessions (for testing)
func (p *provider) GetAllMFASessions(ctx context.Context) ([]*schemas.MFASession, error) {
- var sessions []schemas.MFASession
- collection := p.db.Table(schemas.Collections.MFASession)
- err := collection.Scan().AllWithContext(ctx, &sessions)
+ items, err := p.scanAllRaw(ctx, schemas.Collections.MFASession, nil, nil)
if err != nil {
return nil, err
}
var result []*schemas.MFASession
- for i := range sessions {
- result = append(result, &sessions[i])
+ for _, it := range items {
+ var s schemas.MFASession
+ if err := unmarshalItem(it, &s); err != nil {
+ return nil, err
+ }
+ result = append(result, &s)
}
return result, nil
}
@@ -260,50 +268,45 @@ func (p *provider) AddOAuthState(ctx context.Context, state *schemas.OAuthState)
if state.UpdatedAt == 0 {
state.UpdatedAt = time.Now().Unix()
}
- // Delete existing state with same state_key first (upsert behavior)
- var existing []schemas.OAuthState
- collection := p.db.Table(schemas.Collections.OAuthState)
- collection.Scan().
- Index("state_key").
- Filter("'state_key' = ?", state.StateKey).
- AllWithContext(ctx, &existing)
- for _, e := range existing {
- collection.Delete("id", e.ID).RunWithContext(ctx)
- }
- return collection.Put(state).RunWithContext(ctx)
+ existing, _ := p.queryEq(ctx, schemas.Collections.OAuthState, "state_key", "state_key", state.StateKey, nil)
+ for _, it := range existing {
+ var e schemas.OAuthState
+ if err := unmarshalItem(it, &e); err != nil {
+ continue
+ }
+ _ = p.deleteItemByHash(ctx, schemas.Collections.OAuthState, "id", e.ID)
+ }
+ return p.putItem(ctx, schemas.Collections.OAuthState, state)
}
// GetOAuthStateByKey retrieves an OAuth state by key
func (p *provider) GetOAuthStateByKey(ctx context.Context, key string) (*schemas.OAuthState, error) {
- var states []schemas.OAuthState
- collection := p.db.Table(schemas.Collections.OAuthState)
- err := collection.Scan().
- Index("state_key").
- Filter("'state_key' = ?", key).
- Limit(1).
- AllWithContext(ctx, &states)
+ items, err := p.queryEqLimit(ctx, schemas.Collections.OAuthState, "state_key", "state_key", key, nil, 1)
if err != nil {
return nil, err
}
- if len(states) == 0 {
+ if len(items) == 0 {
return nil, errors.New("OAuth state not found")
}
- return &states[0], nil
+ var s schemas.OAuthState
+ if err := unmarshalItem(items[0], &s); err != nil {
+ return nil, err
+ }
+ return &s, nil
}
// DeleteOAuthStateByKey deletes an OAuth state by key
func (p *provider) DeleteOAuthStateByKey(ctx context.Context, key string) error {
- var states []schemas.OAuthState
- collection := p.db.Table(schemas.Collections.OAuthState)
- err := collection.Scan().
- Index("state_key").
- Filter("'state_key' = ?", key).
- AllWithContext(ctx, &states)
+ items, err := p.queryEq(ctx, schemas.Collections.OAuthState, "state_key", "state_key", key, nil)
if err != nil {
return err
}
- for _, state := range states {
- if err := collection.Delete("id", state.ID).RunWithContext(ctx); err != nil {
+ for _, it := range items {
+ var s schemas.OAuthState
+ if err := unmarshalItem(it, &s); err != nil {
+ return err
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.OAuthState, "id", s.ID); err != nil {
return err
}
}
@@ -312,15 +315,17 @@ func (p *provider) DeleteOAuthStateByKey(ctx context.Context, key string) error
// GetAllOAuthStates retrieves all OAuth states (for testing)
func (p *provider) GetAllOAuthStates(ctx context.Context) ([]*schemas.OAuthState, error) {
- var states []schemas.OAuthState
- collection := p.db.Table(schemas.Collections.OAuthState)
- err := collection.Scan().AllWithContext(ctx, &states)
+ items, err := p.scanAllRaw(ctx, schemas.Collections.OAuthState, nil, nil)
if err != nil {
return nil, err
}
var result []*schemas.OAuthState
- for i := range states {
- result = append(result, &states[i])
+ for _, it := range items {
+ var s schemas.OAuthState
+ if err := unmarshalItem(it, &s); err != nil {
+ return nil, err
+ }
+ result = append(result, &s)
}
return result, nil
}
diff --git a/internal/storage/db/dynamodb/shared.go b/internal/storage/db/dynamodb/shared.go
deleted file mode 100644
index 5597c0ad7..000000000
--- a/internal/storage/db/dynamodb/shared.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package dynamodb
-
-import (
- "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
- "github.com/guregu/dynamo"
-)
-
-// As updpate all item not supported so set manually via Set and SetNullable for empty field
-func UpdateByHashKey(table dynamo.Table, hashKey string, hashValue string, item interface{}) error {
- existingValue, err := dynamo.MarshalItem(item)
- var i interface{}
- if err != nil {
- return err
- }
- nullableValue, err := dynamodbattribute.MarshalMap(item)
- if err != nil {
- return err
- }
- u := table.Update(hashKey, hashValue)
- for k, v := range existingValue {
- if k == hashKey {
- continue
- }
- u = u.Set(k, v)
- }
- for k, v := range nullableValue {
- if k == hashKey {
- continue
- }
- dynamodbattribute.Unmarshal(v, &i)
- if i == nil {
- u = u.SetNullable(k, v)
- }
- }
- err = u.Run()
- if err != nil {
- return err
- }
- return nil
-}
diff --git a/internal/storage/db/dynamodb/tables.go b/internal/storage/db/dynamodb/tables.go
new file mode 100644
index 000000000..81b8b7054
--- /dev/null
+++ b/internal/storage/db/dynamodb/tables.go
@@ -0,0 +1,184 @@
+package dynamodb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "github.com/authorizerdev/authorizer/internal/storage/schemas"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
+ "github.com/aws/smithy-go"
+)
+
+func gsi(name, hashAttr string) types.GlobalSecondaryIndex {
+ return types.GlobalSecondaryIndex{
+ IndexName: aws.String(name),
+ KeySchema: []types.KeySchemaElement{
+ {AttributeName: aws.String(hashAttr), KeyType: types.KeyTypeHash},
+ },
+ Projection: &types.Projection{ProjectionType: types.ProjectionTypeAll},
+ }
+}
+
+func createTable(ctx context.Context, client *dynamodb.Client, name string, hashAttr string, attrs []types.AttributeDefinition, gsis []types.GlobalSecondaryIndex) error {
+ in := &dynamodb.CreateTableInput{
+ TableName: aws.String(name),
+ BillingMode: types.BillingModePayPerRequest,
+ AttributeDefinitions: attrs,
+ KeySchema: []types.KeySchemaElement{
+ {AttributeName: aws.String(hashAttr), KeyType: types.KeyTypeHash},
+ },
+ }
+ if len(gsis) > 0 {
+ in.GlobalSecondaryIndexes = gsis
+ }
+ _, err := client.CreateTable(ctx, in)
+ if err != nil {
+ var apiErr smithy.APIError
+ if errors.As(err, &apiErr) && apiErr.ErrorCode() == "ResourceInUseException" {
+ return nil
+ }
+ return err
+ }
+ w := dynamodb.NewTableExistsWaiter(client)
+ return w.Wait(ctx, &dynamodb.DescribeTableInput{TableName: aws.String(name)}, 5*time.Minute)
+}
+
+func (p *provider) ensureTables(ctx context.Context) error {
+ tables := []struct {
+ name string
+ hash string
+ attr []types.AttributeDefinition
+ gsi []types.GlobalSecondaryIndex
+ }{
+ {
+ name: schemas.Collections.User,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("email"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("email", "email")},
+ },
+ {
+ name: schemas.Collections.Session,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")},
+ },
+ {
+ name: schemas.Collections.Env,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ },
+ {
+ name: schemas.Collections.Webhook,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("event_name"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("event_name", "event_name")},
+ },
+ {
+ name: schemas.Collections.WebhookLog,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("webhook_id"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("webhook_id", "webhook_id")},
+ },
+ {
+ name: schemas.Collections.EmailTemplate,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("event_name"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("event_name", "event_name")},
+ },
+ {
+ name: schemas.Collections.OTP,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("email"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("email", "email")},
+ },
+ {
+ name: schemas.Collections.VerificationRequest,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("token"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("token", "token")},
+ },
+ {
+ name: schemas.Collections.Authenticators,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")},
+ },
+ {
+ name: schemas.Collections.SessionToken,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")},
+ },
+ {
+ name: schemas.Collections.MFASession,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("user_id"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("user_id", "user_id")},
+ },
+ {
+ name: schemas.Collections.OAuthState,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("state_key"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{gsi("state_key", "state_key")},
+ },
+ {
+ name: schemas.Collections.AuditLog,
+ hash: "id",
+ attr: []types.AttributeDefinition{
+ {AttributeName: aws.String("id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("actor_id"), AttributeType: types.ScalarAttributeTypeS},
+ {AttributeName: aws.String("action"), AttributeType: types.ScalarAttributeTypeS},
+ },
+ gsi: []types.GlobalSecondaryIndex{
+ gsi("actor_id", "actor_id"),
+ gsi("action", "action"),
+ },
+ },
+ }
+
+ for _, t := range tables {
+ if err := createTable(ctx, p.client, t.name, t.hash, t.attr, t.gsi); err != nil {
+ return fmt.Errorf("create table %s: %w", t.name, err)
+ }
+ }
+ return nil
+}
diff --git a/internal/storage/db/dynamodb/user.go b/internal/storage/db/dynamodb/user.go
index da517e4bf..6f2bad702 100644
--- a/internal/storage/db/dynamodb/user.go
+++ b/internal/storage/db/dynamodb/user.go
@@ -7,17 +7,34 @@ import (
"strings"
"time"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
- "github.com/guregu/dynamo"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/refs"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
)
+// normalizeUserOptionalPtrs clears *int64 fields that round-trip as 0 from DynamoDB where other
+// providers use SQL NULL — tests and handlers treat nil as "unset" for these fields.
+func normalizeUserOptionalPtrs(u *schemas.User) {
+ if u == nil {
+ return
+ }
+ if u.EmailVerifiedAt != nil && *u.EmailVerifiedAt == 0 {
+ u.EmailVerifiedAt = nil
+ }
+ if u.PhoneNumberVerifiedAt != nil && *u.PhoneNumberVerifiedAt == 0 {
+ u.PhoneNumberVerifiedAt = nil
+ }
+ if u.RevokedTimestamp != nil && *u.RevokedTimestamp == 0 {
+ u.RevokedTimestamp = nil
+ }
+}
+
// AddUser to save user information in database
func (p *provider) AddUser(ctx context.Context, user *schemas.User) (*schemas.User, error) {
- collection := p.db.Table(schemas.Collections.User)
if user.ID == "" {
user.ID = uuid.New().String()
}
@@ -35,20 +52,37 @@ func (p *provider) AddUser(ctx context.Context, user *schemas.User) (*schemas.Us
}
user.CreatedAt = time.Now().Unix()
user.UpdatedAt = time.Now().Unix()
- err := collection.Put(user).RunWithContext(ctx)
- if err != nil {
+ if err := p.putItem(ctx, schemas.Collections.User, user); err != nil {
return nil, err
}
return user, nil
}
+// userDynamoRemoveAttrsIfNil lists attribute names to REMOVE so that optional nil-pointer fields
+// match SQL NULL semantics (omitting them from SET in DynamoDB would otherwise leave old values).
+func userDynamoRemoveAttrsIfNil(u *schemas.User) []string {
+ if u == nil {
+ return nil
+ }
+ var remove []string
+ if u.EmailVerifiedAt == nil {
+ remove = append(remove, "email_verified_at")
+ }
+ if u.PhoneNumberVerifiedAt == nil {
+ remove = append(remove, "phone_number_verified_at")
+ }
+ if u.RevokedTimestamp == nil {
+ remove = append(remove, "revoked_timestamp")
+ }
+ return remove
+}
+
// UpdateUser to update user information in database
func (p *provider) UpdateUser(ctx context.Context, user *schemas.User) (*schemas.User, error) {
- collection := p.db.Table(schemas.Collections.User)
if user.ID != "" {
user.UpdatedAt = time.Now().Unix()
- err := UpdateByHashKey(collection, "id", user.ID, user)
- if err != nil {
+ remove := userDynamoRemoveAttrsIfNil(user)
+ if err := p.updateByHashKeyWithRemoves(ctx, schemas.Collections.User, "id", user.ID, user, remove); err != nil {
return nil, err
}
}
@@ -57,15 +91,22 @@ func (p *provider) UpdateUser(ctx context.Context, user *schemas.User) (*schemas
// DeleteUser to delete user information from database
func (p *provider) DeleteUser(ctx context.Context, user *schemas.User) error {
- collection := p.db.Table(schemas.Collections.User)
- sessionCollection := p.db.Table(schemas.Collections.Session)
- if user.ID != "" {
- err := collection.Delete("id", user.ID).Run()
- if err != nil {
+ if user.ID == "" {
+ return nil
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.User, "id", user.ID); err != nil {
+ return err
+ }
+ items, err := p.queryEq(ctx, schemas.Collections.Session, "user_id", "user_id", user.ID, nil)
+ if err != nil {
+ return err
+ }
+ for _, it := range items {
+ var s schemas.Session
+ if err := unmarshalItem(it, &s); err != nil {
return err
}
- _, err = sessionCollection.Batch("id").Write().Delete(dynamo.Keys{"user_id", user.ID}).RunWithContext(ctx)
- if err != nil {
+ if err := p.deleteItemByHash(ctx, schemas.Collections.Session, "id", s.ID); err != nil {
return err
}
}
@@ -74,31 +115,36 @@ func (p *provider) DeleteUser(ctx context.Context, user *schemas.User) error {
// ListUsers to get list of users from database
func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination) ([]*schemas.User, *model.Pagination, error) {
- var user *schemas.User
- var lastEval dynamo.PagingKey
- var iter dynamo.PagingIter
- var iteration int64 = 0
- collection := p.db.Table(schemas.Collections.User)
- var users []*schemas.User
+ var lastKey map[string]types.AttributeValue
+ var iteration int64
paginationClone := pagination
- scanner := collection.Scan()
- count, err := scanner.Count()
+ var users []*schemas.User
+
+ count, err := p.scanCount(ctx, schemas.Collections.User, nil)
if err != nil {
return nil, nil, err
}
+
for (paginationClone.Offset + paginationClone.Limit) > iteration {
- iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter()
- for iter.NextWithContext(ctx, &user) {
+ items, next, err := p.scanPageIter(ctx, schemas.Collections.User, nil, int32(paginationClone.Limit), lastKey)
+ if err != nil {
+ return nil, nil, err
+ }
+ for _, it := range items {
+ var u schemas.User
+ if err := unmarshalItem(it, &u); err != nil {
+ return nil, nil, err
+ }
+ normalizeUserOptionalPtrs(&u)
if paginationClone.Offset == iteration {
- users = append(users, user)
+ users = append(users, &u)
}
}
- lastEval = iter.LastEvaluatedKey()
+ lastKey = next
iteration += paginationClone.Limit
- }
- err = iter.Err()
- if err != nil {
- return nil, nil, err
+ if lastKey == nil {
+ break
+ }
}
paginationClone.Total = count
return users, paginationClone, nil
@@ -106,79 +152,77 @@ func (p *provider) ListUsers(ctx context.Context, pagination *model.Pagination)
// GetUserByEmail to get user information from database using email address
func (p *provider) GetUserByEmail(ctx context.Context, email string) (*schemas.User, error) {
- var users []*schemas.User
- var user *schemas.User
- collection := p.db.Table(schemas.Collections.User)
- err := collection.Scan().Index("email").Filter("'email' = ?", email).AllWithContext(ctx, &users)
+ items, err := p.queryEq(ctx, schemas.Collections.User, "email", "email", email, nil)
if err != nil {
- return user, nil
+ return nil, err
}
- if len(users) > 0 {
- user = users[0]
- return user, nil
- } else {
+ if len(items) == 0 {
return nil, errors.New("no record found")
}
+ var u schemas.User
+ if err := unmarshalItem(items[0], &u); err != nil {
+ return nil, err
+ }
+ normalizeUserOptionalPtrs(&u)
+ return &u, nil
}
// GetUserByID to get user information from database using user ID
func (p *provider) GetUserByID(ctx context.Context, id string) (*schemas.User, error) {
- collection := p.db.Table(schemas.Collections.User)
- var user *schemas.User
- err := collection.Get("id", id).OneWithContext(ctx, &user)
+ var user schemas.User
+ err := p.getItemByHash(ctx, schemas.Collections.User, "id", id, &user)
if err != nil {
- if refs.StringValue(user.Email) == "" {
- return nil, errors.New("no documets found")
- } else {
- return user, nil
- }
+ return nil, errors.New("no documets found")
}
- return user, nil
+ normalizeUserOptionalPtrs(&user)
+ return &user, nil
}
// UpdateUsers to update multiple users, with parameters of user IDs slice
-// If ids set to nil / empty all the users will be updated
func (p *provider) UpdateUsers(ctx context.Context, data map[string]interface{}, ids []string) error {
- // set updated_at time for all users
- userCollection := p.db.Table(schemas.Collections.User)
- var allUsers []schemas.User
- var res int64 = 0
+ var res int64
var err error
if len(ids) > 0 {
for _, v := range ids {
- err = UpdateByHashKey(userCollection, "id", v, data)
+ err = p.updateByHashKey(ctx, schemas.Collections.User, "id", v, data)
}
} else {
- // as there is no facility to update all doc - https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/SQLtoNoSQL.UpdateData.html
- userCollection.Scan().All(&allUsers)
- for _, user := range allUsers {
- err = UpdateByHashKey(userCollection, "id", user.ID, data)
+ items, errScan := p.scanAllRaw(ctx, schemas.Collections.User, nil, nil)
+ if errScan != nil {
+ return errScan
+ }
+ for _, it := range items {
+ var user schemas.User
+ if err := unmarshalItem(it, &user); err != nil {
+ return err
+ }
+ err = p.updateByHashKey(ctx, schemas.Collections.User, "id", user.ID, data)
if err == nil {
- res = res + 1
+ res++
}
}
}
if err != nil {
return err
- } else {
- p.dependencies.Log.Info().Int64("modified_count", res).Msg("users updated")
}
+ p.dependencies.Log.Info().Int64("modified_count", res).Msg("users updated")
return nil
}
// GetUserByPhoneNumber to get user information from database using phone number
func (p *provider) GetUserByPhoneNumber(ctx context.Context, phoneNumber string) (*schemas.User, error) {
- var users []*schemas.User
- var user *schemas.User
- collection := p.db.Table(schemas.Collections.User)
- err := collection.Scan().Filter("'phone_number' = ?", phoneNumber).AllWithContext(ctx, &users)
+ f := expression.Name("phone_number").Equal(expression.Value(phoneNumber))
+ items, err := p.scanFilteredAll(ctx, schemas.Collections.User, nil, &f)
if err != nil {
return nil, err
}
- if len(users) > 0 {
- user = users[0]
- return user, nil
- } else {
+ if len(items) == 0 {
return nil, errors.New("no record found")
}
+ var u schemas.User
+ if err := unmarshalItem(items[0], &u); err != nil {
+ return nil, err
+ }
+ normalizeUserOptionalPtrs(&u)
+ return &u, nil
}
diff --git a/internal/storage/db/dynamodb/verification_requests.go b/internal/storage/db/dynamodb/verification_requests.go
index c44dd52e6..0d9591901 100644
--- a/internal/storage/db/dynamodb/verification_requests.go
+++ b/internal/storage/db/dynamodb/verification_requests.go
@@ -4,8 +4,9 @@ import (
"context"
"time"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
- "github.com/guregu/dynamo"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -13,13 +14,11 @@ import (
// AddVerification to save verification request in database
func (p *provider) AddVerificationRequest(ctx context.Context, verificationRequest *schemas.VerificationRequest) (*schemas.VerificationRequest, error) {
- collection := p.db.Table(schemas.Collections.VerificationRequest)
if verificationRequest.ID == "" {
verificationRequest.ID = uuid.New().String()
verificationRequest.CreatedAt = time.Now().Unix()
verificationRequest.UpdatedAt = time.Now().Unix()
- err := collection.Put(verificationRequest).RunWithContext(ctx)
- if err != nil {
+ if err := p.putItem(ctx, schemas.Collections.VerificationRequest, verificationRequest); err != nil {
return nil, err
}
}
@@ -28,61 +27,68 @@ func (p *provider) AddVerificationRequest(ctx context.Context, verificationReque
// GetVerificationRequestByToken to get verification request from database using token
func (p *provider) GetVerificationRequestByToken(ctx context.Context, token string) (*schemas.VerificationRequest, error) {
- collection := p.db.Table(schemas.Collections.VerificationRequest)
- var verificationRequest *schemas.VerificationRequest
- iter := collection.Scan().Filter("'token' = ?", token).Iter()
- for iter.NextWithContext(ctx, &verificationRequest) {
- return verificationRequest, nil
- }
- err := iter.Err()
+ items, err := p.queryEq(ctx, schemas.Collections.VerificationRequest, "token", "token", token, nil)
if err != nil {
return nil, err
}
- return verificationRequest, nil
+ if len(items) == 0 {
+ return nil, nil
+ }
+ var v schemas.VerificationRequest
+ if err := unmarshalItem(items[0], &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
}
// GetVerificationRequestByEmail to get verification request by email from database
func (p *provider) GetVerificationRequestByEmail(ctx context.Context, email string, identifier string) (*schemas.VerificationRequest, error) {
- var verificationRequest *schemas.VerificationRequest
- collection := p.db.Table(schemas.Collections.VerificationRequest)
- iter := collection.Scan().Filter("'email' = ?", email).Filter("'identifier' = ?", identifier).Iter()
- for iter.NextWithContext(ctx, &verificationRequest) {
- return verificationRequest, nil
- }
- err := iter.Err()
+ f := expression.Name("email").Equal(expression.Value(email)).And(expression.Name("identifier").Equal(expression.Value(identifier)))
+ items, err := p.scanFilteredAll(ctx, schemas.Collections.VerificationRequest, nil, &f)
if err != nil {
return nil, err
}
- return verificationRequest, nil
+ if len(items) == 0 {
+ return nil, nil
+ }
+ var v schemas.VerificationRequest
+ if err := unmarshalItem(items[0], &v); err != nil {
+ return nil, err
+ }
+ return &v, nil
}
// ListVerificationRequests to get list of verification requests from database
func (p *provider) ListVerificationRequests(ctx context.Context, pagination *model.Pagination) ([]*schemas.VerificationRequest, *model.Pagination, error) {
- var verificationRequests []*schemas.VerificationRequest
- var verificationRequest *schemas.VerificationRequest
- var lastEval dynamo.PagingKey
- var iter dynamo.PagingIter
- var iteration int64 = 0
- collection := p.db.Table(schemas.Collections.VerificationRequest)
+ var lastKey map[string]types.AttributeValue
+ var iteration int64
paginationClone := pagination
- scanner := collection.Scan()
- count, err := scanner.Count()
+ var verificationRequests []*schemas.VerificationRequest
+
+ count, err := p.scanCount(ctx, schemas.Collections.VerificationRequest, nil)
if err != nil {
return nil, nil, err
}
+
for (paginationClone.Offset + paginationClone.Limit) > iteration {
- iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter()
- for iter.NextWithContext(ctx, &verificationRequest) {
- if paginationClone.Offset == iteration {
- verificationRequests = append(verificationRequests, verificationRequest)
- }
- }
- err = iter.Err()
+ items, next, err := p.scanPageIter(ctx, schemas.Collections.VerificationRequest, nil, int32(paginationClone.Limit), lastKey)
if err != nil {
return nil, nil, err
}
- lastEval = iter.LastEvaluatedKey()
+ for _, it := range items {
+ var v schemas.VerificationRequest
+ if err := unmarshalItem(it, &v); err != nil {
+ return nil, nil, err
+ }
+ if paginationClone.Offset == iteration {
+ verificationRequests = append(verificationRequests, &v)
+ }
+ }
+ lastKey = next
iteration += paginationClone.Limit
+ if lastKey == nil {
+ break
+ }
}
paginationClone.Total = count
return verificationRequests, paginationClone, nil
@@ -90,13 +96,8 @@ func (p *provider) ListVerificationRequests(ctx context.Context, pagination *mod
// DeleteVerificationRequest to delete verification request from database
func (p *provider) DeleteVerificationRequest(ctx context.Context, verificationRequest *schemas.VerificationRequest) error {
- collection := p.db.Table(schemas.Collections.VerificationRequest)
- if verificationRequest != nil {
- err := collection.Delete("id", verificationRequest.ID).RunWithContext(ctx)
-
- if err != nil {
- return err
- }
+ if verificationRequest == nil {
+ return nil
}
- return nil
+ return p.deleteItemByHash(ctx, schemas.Collections.VerificationRequest, "id", verificationRequest.ID)
}
diff --git a/internal/storage/db/dynamodb/webhook.go b/internal/storage/db/dynamodb/webhook.go
index e85ef9640..53b957150 100644
--- a/internal/storage/db/dynamodb/webhook.go
+++ b/internal/storage/db/dynamodb/webhook.go
@@ -7,8 +7,9 @@ import (
"strings"
"time"
+ "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
- "github.com/guregu/dynamo"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -16,17 +17,14 @@ import (
// AddWebhook to add webhook
func (p *provider) AddWebhook(ctx context.Context, webhook *schemas.Webhook) (*schemas.Webhook, error) {
- collection := p.db.Table(schemas.Collections.Webhook)
if webhook.ID == "" {
webhook.ID = uuid.New().String()
}
webhook.Key = webhook.ID
webhook.CreatedAt = time.Now().Unix()
webhook.UpdatedAt = time.Now().Unix()
- // Add timestamp to make event name unique for legacy version
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
- err := collection.Put(webhook).RunWithContext(ctx)
- if err != nil {
+ if err := p.putItem(ctx, schemas.Collections.Webhook, webhook); err != nil {
return nil, err
}
return webhook, nil
@@ -35,13 +33,10 @@ func (p *provider) AddWebhook(ctx context.Context, webhook *schemas.Webhook) (*s
// UpdateWebhook to update webhook
func (p *provider) UpdateWebhook(ctx context.Context, webhook *schemas.Webhook) (*schemas.Webhook, error) {
webhook.UpdatedAt = time.Now().Unix()
- // Event is changed
if !strings.Contains(webhook.EventName, "-") {
webhook.EventName = fmt.Sprintf("%s-%d", webhook.EventName, time.Now().Unix())
}
- collection := p.db.Table(schemas.Collections.Webhook)
- err := UpdateByHashKey(collection, "id", webhook.ID, webhook)
- if err != nil {
+ if err := p.updateByHashKey(ctx, schemas.Collections.Webhook, "id", webhook.ID, webhook); err != nil {
return nil, err
}
return webhook, nil
@@ -49,31 +44,35 @@ func (p *provider) UpdateWebhook(ctx context.Context, webhook *schemas.Webhook)
// ListWebhooks to list webhook
func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination) ([]*schemas.Webhook, *model.Pagination, error) {
- webhooks := []*schemas.Webhook{}
- var webhook *schemas.Webhook
- var lastEval dynamo.PagingKey
- var iter dynamo.PagingIter
- var iteration int64 = 0
- collection := p.db.Table(schemas.Collections.Webhook)
+ var lastKey map[string]types.AttributeValue
+ var iteration int64
paginationClone := pagination
- scanner := collection.Scan()
- count, err := scanner.Count()
+ var webhooks []*schemas.Webhook
+
+ count, err := p.scanCount(ctx, schemas.Collections.Webhook, nil)
if err != nil {
return nil, nil, err
}
+
for (paginationClone.Offset + paginationClone.Limit) > iteration {
- iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter()
- for iter.NextWithContext(ctx, &webhook) {
- if paginationClone.Offset == iteration {
- webhooks = append(webhooks, webhook)
- }
- }
- err = iter.Err()
+ items, next, err := p.scanPageIter(ctx, schemas.Collections.Webhook, nil, int32(paginationClone.Limit), lastKey)
if err != nil {
return nil, nil, err
}
- lastEval = iter.LastEvaluatedKey()
+ for _, it := range items {
+ var w schemas.Webhook
+ if err := unmarshalItem(it, &w); err != nil {
+ return nil, nil, err
+ }
+ if paginationClone.Offset == iteration {
+ webhooks = append(webhooks, &w)
+ }
+ }
+ lastKey = next
iteration += paginationClone.Limit
+ if lastKey == nil {
+ break
+ }
}
paginationClone.Total = count
return webhooks, paginationClone, nil
@@ -81,51 +80,52 @@ func (p *provider) ListWebhook(ctx context.Context, pagination *model.Pagination
// GetWebhookByID to get webhook by id
func (p *provider) GetWebhookByID(ctx context.Context, webhookID string) (*schemas.Webhook, error) {
- collection := p.db.Table(schemas.Collections.Webhook)
- var webhook *schemas.Webhook
- err := collection.Get("id", webhookID).OneWithContext(ctx, &webhook)
+ var webhook schemas.Webhook
+ err := p.getItemByHash(ctx, schemas.Collections.Webhook, "id", webhookID, &webhook)
if err != nil {
return nil, err
}
if webhook.ID == "" {
return nil, errors.New("no document found")
}
- return webhook, nil
+ return &webhook, nil
}
// GetWebhookByEventName to get webhook by event_name
func (p *provider) GetWebhookByEventName(ctx context.Context, eventName string) ([]*schemas.Webhook, error) {
- webhooks := []*schemas.Webhook{}
- collection := p.db.Table(schemas.Collections.Webhook)
- err := collection.Scan().Index("event_name").Filter("contains(event_name, ?)", eventName).AllWithContext(ctx, &webhooks)
+ // Match SQL LIKE 'eventName%' (see sql/webhook.go); do not use Contains (substring match).
+ f := expression.Name("event_name").BeginsWith(eventName)
+ items, err := p.scanFilteredAll(ctx, schemas.Collections.Webhook, strPtr("event_name"), &f)
if err != nil {
return nil, err
}
- return webhooks, nil
+ var out []*schemas.Webhook
+ for _, it := range items {
+ var w schemas.Webhook
+ if err := unmarshalItem(it, &w); err != nil {
+ return nil, err
+ }
+ out = append(out, &w)
+ }
+ return out, nil
}
// DeleteWebhook to delete webhook
func (p *provider) DeleteWebhook(ctx context.Context, webhook *schemas.Webhook) error {
- // Also delete webhook logs for given webhook id
- if webhook != nil {
- webhookCollection := p.db.Table(schemas.Collections.Webhook)
- webhookLogCollection := p.db.Table(schemas.Collections.WebhookLog)
- err := webhookCollection.Delete("id", webhook.ID).RunWithContext(ctx)
- if err != nil {
- return err
- }
- pagination := &model.Pagination{}
- webhookLogs, _, err := p.ListWebhookLogs(ctx, pagination, webhook.ID)
- if err != nil {
- p.dependencies.Log.Debug().Err(err).Msg("failed to list webhook logs")
- } else {
- for _, webhookLog := range webhookLogs {
- err = webhookLogCollection.Delete("id", webhookLog.ID).RunWithContext(ctx)
- if err != nil {
- p.dependencies.Log.Debug().Err(err).Msg("failed to delete webhook log")
- // continue
- }
- }
+ if webhook == nil {
+ return nil
+ }
+ if err := p.deleteItemByHash(ctx, schemas.Collections.Webhook, "id", webhook.ID); err != nil {
+ return err
+ }
+ logs, _, err := p.ListWebhookLogs(ctx, &model.Pagination{}, webhook.ID)
+ if err != nil {
+ p.dependencies.Log.Debug().Err(err).Msg("failed to list webhook logs")
+ return nil
+ }
+ for _, wl := range logs {
+ if err := p.deleteItemByHash(ctx, schemas.Collections.WebhookLog, "id", wl.ID); err != nil {
+ p.dependencies.Log.Debug().Err(err).Msg("failed to delete webhook log")
}
}
return nil
diff --git a/internal/storage/db/dynamodb/webhook_log.go b/internal/storage/db/dynamodb/webhook_log.go
index 2092b1d18..9b794829b 100644
--- a/internal/storage/db/dynamodb/webhook_log.go
+++ b/internal/storage/db/dynamodb/webhook_log.go
@@ -4,8 +4,8 @@ import (
"context"
"time"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/google/uuid"
- "github.com/guregu/dynamo"
"github.com/authorizerdev/authorizer/internal/graph/model"
"github.com/authorizerdev/authorizer/internal/storage/schemas"
@@ -13,15 +13,13 @@ import (
// AddWebhookLog to add webhook log
func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *schemas.WebhookLog) (*schemas.WebhookLog, error) {
- collection := p.db.Table(schemas.Collections.WebhookLog)
if webhookLog.ID == "" {
webhookLog.ID = uuid.New().String()
}
webhookLog.Key = webhookLog.ID
webhookLog.CreatedAt = time.Now().Unix()
webhookLog.UpdatedAt = time.Now().Unix()
- err := collection.Put(webhookLog).RunWithContext(ctx)
- if err != nil {
+ if err := p.putItem(ctx, schemas.Collections.WebhookLog, webhookLog); err != nil {
return nil, err
}
return webhookLog, nil
@@ -29,43 +27,48 @@ func (p *provider) AddWebhookLog(ctx context.Context, webhookLog *schemas.Webhoo
// ListWebhookLogs to list webhook logs
func (p *provider) ListWebhookLogs(ctx context.Context, pagination *model.Pagination, webhookID string) ([]*schemas.WebhookLog, *model.Pagination, error) {
+ paginationClone := pagination
+ // Non-nil empty slice: callers/tests expect a slice value even when there are no rows.
webhookLogs := []*schemas.WebhookLog{}
- var webhookLog *schemas.WebhookLog
- var lastEval dynamo.PagingKey
- var iter dynamo.PagingIter
- var iteration int64 = 0
- var err error
- var count int64
- collection := p.db.Table(schemas.Collections.WebhookLog)
- paginationClone := pagination
- scanner := collection.Scan()
if webhookID != "" {
- iter = scanner.Index("webhook_id").Filter("'webhook_id' = ?", webhookID).Iter()
- for iter.NextWithContext(ctx, &webhookLog) {
- webhookLogs = append(webhookLogs, webhookLog)
- }
- err = iter.Err()
+ items, err := p.queryEq(ctx, schemas.Collections.WebhookLog, "webhook_id", "webhook_id", webhookID, nil)
if err != nil {
return nil, nil, err
}
- } else {
- for (paginationClone.Offset + paginationClone.Limit) > iteration {
- iter = scanner.StartFrom(lastEval).Limit(paginationClone.Limit).Iter()
- for iter.NextWithContext(ctx, &webhookLog) {
- if paginationClone.Offset == iteration {
- webhookLogs = append(webhookLogs, webhookLog)
- }
+ for _, it := range items {
+ var wl schemas.WebhookLog
+ if err := unmarshalItem(it, &wl); err != nil {
+ return nil, nil, err
}
- err = iter.Err()
- if err != nil {
+ webhookLogs = append(webhookLogs, &wl)
+ }
+ paginationClone.Total = 0
+ return webhookLogs, paginationClone, nil
+ }
+
+ var lastKey map[string]types.AttributeValue
+ var iteration int64
+ for (paginationClone.Offset + paginationClone.Limit) > iteration {
+ items, next, err := p.scanPageIter(ctx, schemas.Collections.WebhookLog, nil, int32(paginationClone.Limit), lastKey)
+ if err != nil {
+ return nil, nil, err
+ }
+ for _, it := range items {
+ var wl schemas.WebhookLog
+ if err := unmarshalItem(it, &wl); err != nil {
return nil, nil, err
}
- lastEval = iter.LastEvaluatedKey()
- iteration += paginationClone.Limit
+ if paginationClone.Offset == iteration {
+ webhookLogs = append(webhookLogs, &wl)
+ }
+ }
+ lastKey = next
+ iteration += paginationClone.Limit
+ if lastKey == nil {
+ break
}
}
- paginationClone.Total = count
- // paginationClone.Cursor = iter.LastEvaluatedKey()
+ paginationClone.Total = 0
return webhookLogs, paginationClone, nil
}
diff --git a/internal/storage/provider_test.go b/internal/storage/provider_test.go
index cab6f6e67..721618e22 100644
--- a/internal/storage/provider_test.go
+++ b/internal/storage/provider_test.go
@@ -74,7 +74,8 @@ func getTestDBConfig(dbType string) *config.Config {
// Allow extra time for Couchbase container to become ready in tests
cfg.CouchBaseWaitTimeout = 120
case constants.DbTypeDynamoDB:
- cfg.DatabaseURL = "http://0.0.0.0:8000"
+ // Must be a client-routable host (not bind address 0.0.0.0); matches integration_tests getDBURL.
+ cfg.DatabaseURL = "http://127.0.0.1:8000"
}
return cfg
diff --git a/internal/storage/schemas/audit_log.go b/internal/storage/schemas/audit_log.go
index 116f49f05..bd004c10f 100644
--- a/internal/storage/schemas/audit_log.go
+++ b/internal/storage/schemas/audit_log.go
@@ -13,10 +13,10 @@ import (
type AuditLog struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"`
- ActorID string `gorm:"type:char(36)" json:"actor_id" bson:"actor_id" cql:"actor_id" dynamo:"actor_id" index:"actor_id,hash"`
+ ActorID string `gorm:"type:char(36)" json:"actor_id" bson:"actor_id" cql:"actor_id" dynamo:"actor_id,omitempty" index:"actor_id,hash"`
ActorType string `gorm:"type:varchar(30)" json:"actor_type" bson:"actor_type" cql:"actor_type" dynamo:"actor_type"`
ActorEmail string `gorm:"type:varchar(256)" json:"actor_email" bson:"actor_email" cql:"actor_email" dynamo:"actor_email"`
- Action string `gorm:"type:varchar(100)" json:"action" bson:"action" cql:"action" dynamo:"action" index:"action,hash"`
+ Action string `gorm:"type:varchar(100)" json:"action" bson:"action" cql:"action" dynamo:"action,omitempty" index:"action,hash"`
ResourceType string `gorm:"type:varchar(50)" json:"resource_type" bson:"resource_type" cql:"resource_type" dynamo:"resource_type"`
ResourceID string `gorm:"type:char(36)" json:"resource_id" bson:"resource_id" cql:"resource_id" dynamo:"resource_id"`
IPAddress string `gorm:"type:varchar(45)" json:"ip_address" bson:"ip_address" cql:"ip_address" dynamo:"ip_address"`
diff --git a/internal/storage/schemas/otp.go b/internal/storage/schemas/otp.go
index 4fb94ead6..1985f2852 100644
--- a/internal/storage/schemas/otp.go
+++ b/internal/storage/schemas/otp.go
@@ -11,7 +11,7 @@ const (
type OTP struct {
Key string `json:"_key,omitempty" bson:"_key,omitempty" cql:"_key,omitempty" dynamo:"key,omitempty"` // for arangodb
ID string `gorm:"primaryKey;type:char(36)" json:"_id" bson:"_id" cql:"id" dynamo:"id,hash"`
- Email string `gorm:"index" json:"email" bson:"email" cql:"email" dynamo:"email" index:"email,hash"`
+ Email string `gorm:"index" json:"email" bson:"email" cql:"email" dynamo:"email,omitempty" index:"email,hash"`
PhoneNumber string `gorm:"index" json:"phone_number" bson:"phone_number" cql:"phone_number" dynamo:"phone_number"`
Otp string `json:"otp" bson:"otp" cql:"otp" dynamo:"otp"`
ExpiresAt int64 `json:"expires_at" bson:"expires_at" cql:"expires_at" dynamo:"expires_at"`
diff --git a/internal/storage/schemas/user.go b/internal/storage/schemas/user.go
index 8c0244585..674c500de 100644
--- a/internal/storage/schemas/user.go
+++ b/internal/storage/schemas/user.go
@@ -9,6 +9,9 @@ import (
)
// Note: any change here should be reflected in providers/casandra/provider.go as it does not have model support in collection creation
+//
+// Nullable pointers (*int64, *string, etc.): do not add json/bson omitempty to fields that must
+// clear stored values when nil — see docs/storage-optional-null-fields.md.
// User model for db
type User struct {